Exemplo n.º 1
0
            public override void Backward(Tensor <Type> deltas, Backpropagation bp)
            {
                deltas.AssertOfShape(Shape);

                var deltaFromRecursive = OutputInfo != null;

                // var in the forward -> for in the backward
                var forsDic = new Dictionary <ISymbol, IFor>();   // ITensorSymbol

                var backLoop = new Loop("d" + Loop.Name);

                backLoop.Length = Loop.Length;
                var substitution = new Patch(preserveShape: true);

                // add the sequences used by the forward
                int fwdSeqCount = Loop.Sequences.Count;

                for (int i = 0; i < fwdSeqCount; i++)
                {
                    var seq      = Loop.Sequences[i];
                    var variable = Loop.Variable(seq);
                    var alias    = Loop.Sequences[i].Match(
                        (Tensor <float> s) =>
                        backLoop.AddSeq(s[Step_m1], variable.Name + "_", Loop.SequenceAxes[i]),
                        (Tensor <int> s) =>
                        backLoop.AddSeq(s[Step_m1], variable.Name + "_", Loop.SequenceAxes[i]),
                        (Func <ITensor>)null
                        );
                    substitution.Add_(variable, alias);
                }

                // add the sequences computed by the forward
                foreach (var @for in Loop.Fors)
                {
                    if (@for.IsRecursive)
                    {
                        var variable = @for.RecursiveVariable;
                        var alias    = @for.Match(
                            (Tensor <float> .For f) =>
                            backLoop.AddSeq(new Insert <float>(f, 0, f.OutputInfo, 0)[From_m2_Step_m1], variable.Name + "_", axis: 0),
                            (Tensor <int> .For f) =>
                            backLoop.AddSeq(new Insert <int>(f, 0, f.OutputInfo, 0)[From_m2_Step_m1], variable.Name + "_", axis: 0),
                            (Func <ITensor>)null
                            );
                        substitution.Add_(variable, alias);
                    }
                    else
                    {
                        var alias = @for.Match(
                            (Tensor <float> .For f) =>
                            backLoop.AddSeq(f[Step_m1], @for.Name + "_"),
                            (Tensor <int> .For f) =>
                            backLoop.AddSeq(f[Step_m1], @for.Name + "_"),
                            (Func <ITensor>)null
                            );
                        substitution.Add_(@for.Expression, alias);
                    }
                }

                // add the retropropagated delta
                var deltaOut = backLoop.AddSeq(deltas[Step_m1], $"delta_{RecursiveVariable?.ToString() ?? "f" + Index}_", axis: 0);

                // d_ avoid duplicated variables with the same name.
                var d_ = new Dictionary <IVar, IVar>();

                // add deltas of sequences (inputs of and computed by the forward), initialized to zero
                var recVariables = Loop.RecursiveFors.Select(f => Loop.Variable(f));

                foreach (var varFwd in Loop.Variables)
                {
                    var zeros = varFwd.Match(
                        (Tensor <float> .Var x) => Op.ZerosLike(x),
                        (Tensor <int> .Var x) => Op.ZerosLike(x),
                        (Func <ITensor>)null
                        );
                    var @for = backLoop.AddRecursive_(zeros, zeros, $"d{varFwd.Name}_");
                    @for.Comment = $"dL/d{varFwd}";

                    d_[varFwd]      = @for.RecursiveVariable;
                    forsDic[varFwd] = @for;
                }

                // `others` collect gradients pushed to expressions of the loop that aren't sequences or variables.
                var others = new Dictionary <IExpr, IFor>();

                AddDeltaFromBackpropagate(backLoop, others, forsDic, Backpropagation.Backward(Expression, deltaFromRecursive ? deltaOut + (Var)d_[RecursiveVariable] : deltaOut));

                foreach (var @for in Loop.RecursiveFors)
                {
                    var variable = @for.RecursiveVariable;

                    if (!deltaFromRecursive || @for != this)
                    {
                        var gradExpr = @for.Match(
                            (Tensor <float> .For f) => Backpropagation.Backward(f.Expression, (Tensor <float>)d_[f.RecursiveVariable]),
                            (Tensor <int> .For f) => Backpropagation.Backward(f.Expression, (Tensor <int>)d_[f.RecursiveVariable]),
                            null
                            );

                        AddDeltaFromBackpropagate(backLoop, others, forsDic, gradExpr);
                    }
                    // else: we already added the delta prior to the loop

                    // reuse results computed during the forward inside the backward
                    var alias_tp1 = backLoop.AddRecursive_(variable, @for[-1], variable.Name + "_tp1").RecursiveVariable;
                    substitution.Add_(@for.Expression, alias_tp1);
                }

                // Substitute variable in fors
                foreach (var @for in backLoop.Fors)
                {
                    var comment = @for.Expression.Comment;
                    @for.Expression         = (ITensor)@for.Expression.Patch(substitution);
                    @for.Expression.Comment = comment;
                }

                // deltas of sequences
                for (int i = 0; i < Loop.Sequences.Count; ++i)
                {
                    if (Loop.Sequences[i] is Tensor <float> )
                    {
                        bp.PushGradientTo((Tensor <float>)Loop.Sequences[i], ((Tensor <float>)backLoop.Fors[i])[Step_m1]);
                    }
                    else
                    {
                        throw new NotImplementedException();
                    }
                }

                // deltas of seed
                foreach (var @for in Loop.RecursiveFors)
                {
                    if (@for is Tensor <float> )
                    {
                        bp.PushGradientTo((Tensor <float>)@for.OutputInfo, ((Tensor <float>)forsDic[@for.RecursiveVariable])[-1]);
                    }
                    else
                    {
                        throw new NotImplementedException();
                    }
                }

                // other deltas
                foreach (var W_dW in others)
                {
                    var W = W_dW.Key; var dW = W_dW.Value;
                    if (W is Tensor <float> )
                    {
                        bp.PushGradientTo((Tensor <float>)W, Op.Sum((Tensor <float>)dW, axis: 0));
                    }
                    else
                    {
                        throw new NotImplementedException();
                    }
                }
            }
Exemplo n.º 2
0
 public abstract IExpr Patch(Patch substitutions);
Exemplo n.º 3
0
 /// <remarks>This implementation only works while there is no inputs</remarks>
 public override IExpr Patch(Patch substitutions) => PatternMatching.GetOrElse(substitutions.TryGetValue, this, this);