Example 1
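Loop.Patch rewrites a symbolic loop under a substitution map. It first checks the map for a cached result; it then patches the length, the sequences, the output infos and the body expressions, reusing the current loop when nothing changed and otherwise building a fresh Loop with new hidden variables, re-registered sequences and re-patched fors. In both cases the result is recorded in the map, so shared subgraphs are patched only once.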
        internal Loop Patch(Patch substitutions)
        {
            Loop result;

            if (substitutions.TryGetValue(this, out result))
            {
                return result;
            }

            // patch the length
            var patchLength = (Scalar<int>)Length.Patch(substitutions);

            // patch the sequences and outputInfo from the loop
            var patchSequences   = Sequences.Patch(substitutions);
            var outputsInfo      = Fors.Select(f => f.OutputInfo).ToArray();
            var patchOutputsInfo = outputsInfo.Patch(substitutions);

            // the expressions use the loop variables;
            // we do a first patch to detect whether we need to create a new loop
            var blankSubstitutions = new Patch(substitutions);
            var expressions        = Fors.Select(f => f.Expression).ToArray();
            var patchExpressions   = expressions.Patch(blankSubstitutions);

            if (patchLength == Length &&
                patchOutputsInfo == outputsInfo &&
                patchSequences == Sequences &&
                patchExpressions == expressions)
            {
                // nothing changed: record each expression as already patched
                foreach (var (expr, value) in (expressions, patchExpressions).Zip())
                {
                    substitutions.Add_(expr, value);
                }
                result = this;
            }
            else
            {
                // we need to create a new loop; in particular, we need new hidden variables for the loop
                var loopName = LoopName();
                // create the patched Loop
                result = new Loop(loopName, patchLength);

                // created variables are automatically added to `substitutions`
                var createdVariables = Variables.Select(v => _patchVar(v, substitutions)).ToList();

                for (int i = 0; i < Sequences.Count; ++i)
                {
                    var seq      = patchSequences[i];
                    var patchVar = createdVariables[i];
                    patchVar.Match(
                        (Tensor<float>.Var varF) => (ITensorVar)result.AddSeq((Tensor<float>)seq, varF, SequenceAxes[i]),
                        (Tensor<int>.Var varI) => result.AddSeq((Tensor<int>)seq, varI, SequenceAxes[i]),
                        null
                        );
                }

                foreach (var @for in Fors)
                {
                    IFor patchFor;

                    // we need to patch the expression to use the new variables
                    var patchExpr = (ITensor)@for.Expression.Patch(substitutions);
                    if (@for.IsRecursive)
                    {
                        // the output has already been patched; it should be in substitutions
                        var patchOutput = (ITensor)@for.OutputInfo.Patch(substitutions);
                        // we just put the patched version of the variable in substitutions
                        var patchVar = (ITensorVar)@for.RecursiveVariable.Patch(substitutions);

                        patchFor = result.AddRecursive_(patchExpr, patchOutput, patchVar);
                    }
                    else
                    {
                        patchFor = result.AddOutput_(patchExpr);
                    }

                    substitutions.Add_(@for, patchFor);
                }
            }
            substitutions.Add(this, result);
            return result;
        }
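
The method above is an instance of a memoized-rewrite pattern: the substitution dictionary doubles as a cache, children are patched first, and a new node is allocated only when a child actually changed (the reference-equality checks). Here is a minimal self-contained sketch of that pattern; the Node, Var and Add types are hypothetical stand-ins, not the library's API:

    using System.Collections.Generic;

    // Hypothetical expression nodes; only the rewrite pattern matters here.
    abstract class Node
    {
        public abstract Node Rewrite(Dictionary<Node, Node> substitutions);
    }

    class Var : Node
    {
        public string Name;

        // A leaf is replaced only if the map says so.
        public override Node Rewrite(Dictionary<Node, Node> substitutions)
            => substitutions.TryGetValue(this, out var r) ? r : this;
    }

    class Add : Node
    {
        public Node Left, Right;

        public override Node Rewrite(Dictionary<Node, Node> substitutions)
        {
            // cache hit: this node was already rewritten (or explicitly substituted)
            if (substitutions.TryGetValue(this, out var result))
                return result;

            // rewrite the children first
            var left  = Left.Rewrite(substitutions);
            var right = Right.Rewrite(substitutions);

            // allocate a new node only if a child changed,
            // mirroring the reference-equality checks in Loop.Patch
            result = left == Left && right == Right
                ? this
                : new Add { Left = left, Right = right };

            // record the result so shared subexpressions are rewritten once
            substitutions.Add(this, result);
            return result;
        }
    }

Because the cache is keyed by reference, a node reachable through several paths is rewritten once and its patched version is shared, which is exactly what substitutions.Add(this, result) achieves in Loop.Patch.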
Example 2
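Backward backpropagates through the loop. It builds a mirrored loop, backLoop, that consumes the forward's sequences and computed outputs in reverse order, carries one recursive delta accumulator per loop variable, and replays the forward's intermediate results instead of recomputing them. Once the backward loop is assembled, the accumulated gradients are pushed to the input sequences, to the seeds (OutputInfo), and to every other expression captured by the loop.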
            public override void Backward(Tensor<Type> deltas, Backpropagation bp)
            {
                deltas.AssertOfShape(Shape);

                var deltaFromRecursive = OutputInfo != null;

                // var in the forward -> for in the backward
                var forsDic = new Dictionary<ISymbol, IFor>();   // ITensorSymbol

                var backLoop = new Loop("d" + Loop.Name);

                backLoop.Length = Loop.Length;
                var substitution = new Patch(preserveShape: true);

                // add the sequences used by the forward, reversed (Step_m1 slices with step -1) so the backward loop runs over time in reverse
                int fwdSeqCount = Loop.Sequences.Count;

                for (int i = 0; i < fwdSeqCount; i++)
                {
                    var seq      = Loop.Sequences[i];
                    var variable = Loop.Variable(seq);
                    var alias    = seq.Match(
                        (Tensor<float> s) =>
                        backLoop.AddSeq(s[Step_m1], variable.Name + "_", Loop.SequenceAxes[i]),
                        (Tensor<int> s) =>
                        backLoop.AddSeq(s[Step_m1], variable.Name + "_", Loop.SequenceAxes[i]),
                        (Func<ITensor>)null
                        );
                    substitution.Add_(variable, alias);
                }

                // add the sequences computed by the forward
                foreach (var @for in Loop.Fors)
                {
                    if (@for.IsRecursive)
                    {
                        var variable = @for.RecursiveVariable;
                        var alias    = @for.Match(
                            (Tensor<float>.For f) =>
                            backLoop.AddSeq(new Insert<float>(f, 0, f.OutputInfo, 0)[From_m2_Step_m1], variable.Name + "_", axis: 0),
                            (Tensor<int>.For f) =>
                            backLoop.AddSeq(new Insert<int>(f, 0, f.OutputInfo, 0)[From_m2_Step_m1], variable.Name + "_", axis: 0),
                            (Func<ITensor>)null
                            );
                        substitution.Add_(variable, alias);
                    }
                    else
                    {
                        var alias = @for.Match(
                            (Tensor<float>.For f) =>
                            backLoop.AddSeq(f[Step_m1], @for.Name + "_"),
                            (Tensor<int>.For f) =>
                            backLoop.AddSeq(f[Step_m1], @for.Name + "_"),
                            (Func<ITensor>)null
                            );
                        substitution.Add_(@for.Expression, alias);
                    }
                }

                // add the retropropagated delta
                var deltaOut = backLoop.AddSeq(deltas[Step_m1], $"delta_{RecursiveVariable?.ToString() ?? "f" + Index}_", axis: 0);

                // d_ maps each forward variable to its delta variable; it avoids creating duplicated variables with the same name
                var d_ = new Dictionary<IVar, IVar>();

                // add a delta accumulator for each loop variable (sequence inputs and recursive outputs of the forward), initialized to zero

                foreach (var varFwd in Loop.Variables)
                {
                    var zeros = varFwd.Match(
                        (Tensor<float>.Var x) => Op.ZerosLike(x),
                        (Tensor<int>.Var x) => Op.ZerosLike(x),
                        (Func<ITensor>)null
                        );
                    var @for = backLoop.AddRecursive_(zeros, zeros, $"d{varFwd.Name}_");
                    @for.Comment = $"dL/d{varFwd}";

                    d_[varFwd]      = @for.RecursiveVariable;
                    forsDic[varFwd] = @for;
                }

                // `others` collects the gradients pushed to expressions of the loop that are neither sequences nor variables
                var others = new Dictionary<IExpr, IFor>();

                AddDeltaFromBackpropagate(backLoop, others, forsDic, Backpropagation.Backward(Expression, deltaFromRecursive ? deltaOut + (Var)d_[RecursiveVariable] : deltaOut));

                foreach (var @for in Loop.RecursiveFors)
                {
                    var variable = @for.RecursiveVariable;

                    if (!deltaFromRecursive || @for != this)
                    {
                        var gradExpr = @for.Match(
                            (Tensor<float>.For f) => Backpropagation.Backward(f.Expression, (Tensor<float>)d_[f.RecursiveVariable]),
                            (Tensor<int>.For f) => Backpropagation.Backward(f.Expression, (Tensor<int>)d_[f.RecursiveVariable]),
                            null
                            );

                        AddDeltaFromBackpropagate(backLoop, others, forsDic, gradExpr);
                    }
                    // else: we already added the delta prior to the loop

                    // reuse results computed during the forward inside the backward
                    var alias_tp1 = backLoop.AddRecursive_(variable, @for[-1], variable.Name + "_tp1").RecursiveVariable;
                    substitution.Add_(@for.Expression, alias_tp1);
                }

                // substitute the forward variables with their backward aliases in the fors
                foreach (var @for in backLoop.Fors)
                {
                    var comment = @for.Expression.Comment;
                    @for.Expression         = (ITensor)@for.Expression.Patch(substitution);
                    @for.Expression.Comment = comment;
                }

                // deltas of sequences
                for (int i = 0; i < Loop.Sequences.Count; ++i)
                {
                    if (Loop.Sequences[i] is Tensor<float>)
                    {
                        bp.PushGradientTo((Tensor<float>)Loop.Sequences[i], ((Tensor<float>)backLoop.Fors[i])[Step_m1]);
                    }
                    else
                    {
                        throw new NotImplementedException();
                    }
                }

                // deltas of seed
                foreach (var @for in Loop.RecursiveFors)
                {
                    if (@for is Tensor<float>)
                    {
                        bp.PushGradientTo((Tensor<float>)@for.OutputInfo, ((Tensor<float>)forsDic[@for.RecursiveVariable])[-1]);
                    }
                    else
                    {
                        throw new NotImplementedException();
                    }
                }

                // other deltas
                foreach (var W_dW in others)
                {
                    var W  = W_dW.Key;
                    var dW = W_dW.Value;
                    if (W is Tensor<float>)
                    {
                        bp.PushGradientTo((Tensor<float>)W, Op.Sum((Tensor<float>)dW, axis: 0));
                    }
                    else
                    {
                        throw new NotImplementedException();
                    }
                }
            }
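
Stripped of the symbolic machinery, Backward performs ordinary reverse-mode differentiation of a scan: run the recurrence forward while keeping the intermediate states, then walk the steps in reverse, carrying a delta for the recursive state and accumulating into the deltas of the inputs and weights. A plain, self-contained illustration on the scalar recurrence h[t] = w * h[t-1] + x[t] (a conceptual sketch, not the library's API):

    using System;

    class ScanGradientDemo
    {
        static void Main()
        {
            double w = 0.5;
            double[] x = { 1.0, 2.0, 3.0 };
            int T = x.Length;

            // forward pass: h[t] = w * h[t-1] + x[t] with h[-1] = 0;
            // the states are kept for the backward pass
            var h = new double[T];
            double prev = 0.0;
            for (int t = 0; t < T; t++)
            {
                h[t] = w * prev + x[t];
                prev = h[t];
            }

            // take the loss L = h[T-1], so dL/dh[T-1] = 1
            double dh = 1.0, dw = 0.0;
            var dx = new double[T];

            // backward pass: iterate in reverse, like backLoop fed with reversed sequences
            for (int t = T - 1; t >= 0; t--)
            {
                double hPrev = t > 0 ? h[t - 1] : 0.0;  // reused forward state, cf. alias_tp1
                dw    += dh * hPrev;  // weight delta accumulates over steps, cf. Op.Sum(..., axis: 0)
                dx[t]  = dh;          // delta of the input sequence at step t
                dh    *= w;           // push the delta through the recurrence to h[t-1]
            }

            // prints dL/dw = 3, dL/dx = [0.25, 0.5, 1]
            Console.WriteLine($"dL/dw = {dw}, dL/dx = [{string.Join(", ", dx)}]");
        }
    }

The three gradient pushes at the end of Backward map onto this sketch: dx corresponds to the deltas of the sequences, the value of dh after the loop to the delta of the seed, and dw to the "other" deltas summed over the time axis.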