Code example #1
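A rough orientation, inferred from the code itself: this Backward override walks the unrolled steps of a recurrence in reverse, reusing a single sub-executor for the per-step graph. Because the weight tensor is shared across all steps, the gradient aggregation counters are managed by hand rather than being cleared by the sub-executor.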
            public override void Backward(Executor executor)
            {
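                // Backward pass of an unrolled recurrence: replay the forward steps in
                // reverse order through a shared per-step sub-executor, pointing its
                // tensors and gradients at slices of this executor's buffers.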
                var input  = executor.GetTensor(Input);
                var states = executor.GetTensor(States);
                var weight = executor.GetTensor(Weight);

                Util.EnsureTrue(input.Shape.Rank == 2);
                Util.EnsureTrue(states.Shape.Rank == 3, "states shape: (steps, n, n)");
                Util.EnsureTrue(states.Shape[1] == states.Shape[2], "states shape: (steps, n, n)");
                var steps        = (int)states.Shape[0];
                var n            = states.Shape[1];
                var intermediate = executor.GetTensor(Intermediate);
                var output       = executor.GetTensor(Output);

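                // Incoming gradient of the output, plus gradient buffers (with the
                // matching forward shapes) for every tensor this node writes into.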
                var dOutput       = executor.GetGradient(Output);
                var dIntermediate = executor.GetGradient(Intermediate, intermediate.Shape);
                var dStates       = executor.GetGradient(States, states.Shape);
                var dWeight       = executor.GetGradient(Weight, weight.Shape);
                var dInput        = executor.GetGradient(Input, input.Shape);

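                // Current aggregation counters of the outer executor; they tell the
                // sub-executor whether a gradient buffer already holds accumulated values.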
                var counterInput        = executor.GetGradientAggregationCounter(Input);
                var counterWeight       = executor.GetGradientAggregationCounter(Weight);
                var counterStates       = executor.GetGradientAggregationCounter(States);
                var counterIntermediate = executor.GetGradientAggregationCounter(Intermediate);

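                // The per-step graph is run by a dedicated sub-executor stored in the
                // executor's Objects map.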
                var subExecutor = (Executor)executor.Objects[SubExecutor];

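                // Walk the steps backwards; step i reads input_i (the previous step's
                // intermediate, or the original input at i == 0) and produced output_i
                // (the next intermediate, or the final output at the last step).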
                for (var i = steps - 1; i >= 0; --i)
                {
                    // We need to set both the input and output tensors and their gradients for this step.

                    var input_i  = i == 0 ? input : intermediate.Slice(i - 1).Reshape(n, n);
                    var state_i  = states.Slice(i).Reshape(n, n);
                    var output_i = i == steps - 1 ? output : intermediate.Slice(i).Reshape(n, n);

                    subExecutor.SetTensor(SubInput, input_i);
                    subExecutor.SetTensor(SubWeight, weight);
                    subExecutor.SetTensor(SubState, state_i);
                    subExecutor.SetTensor(SubOutput, output_i);

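                    // Gradient views that mirror the tensor slices chosen above.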
                    var dInput_i  = i == 0 ? dInput : dIntermediate.Slice(i - 1).Reshape(n, n);
                    var dState_i  = dStates.Slice(i).Reshape(n, n);
                    var dOutput_i = i == steps - 1 ? dOutput : dIntermediate.Slice(i).Reshape(n, n);

                    // Because the weight is a single variable shared by every step, we have
                    // to update its gradient aggregation counter ourselves.
                    // Passing counter = 0 only points the gradient memory at another tensor;
                    // it signals that the tensor holds no value to aggregate into yet.
                    // The shared weight, however, receives a gradient on every step, so its
                    // counter must also account for the (steps - 1 - i) writes already done
                    // in this loop.
                    subExecutor.ClearGradientAggregationCounters();
                    subExecutor.SetGradient(SubInput, dInput_i, counter: i == 0 ? counterInput : counterIntermediate);
                    subExecutor.SetGradient(SubWeight, dWeight, counter: counterWeight + steps - 1 - i);
                    subExecutor.SetGradient(SubState, dState_i, counter: counterStates);
                    subExecutor.SetGradient(SubOutput, dOutput_i);

                    // Run backward without clearing the counters, because we set them ourselves above.
                    subExecutor.Backward(clearGradientAggretionCounter: false);
                }

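                // The outer gradients now hold aggregated values; bump their counters so
                // later contributions accumulate into them instead of overwriting.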
                executor.IncreaseGradientAggregationCounter(Input);
                executor.IncreaseGradientAggregationCounter(Weight);
                executor.IncreaseGradientAggregationCounter(States);
                executor.IncreaseGradientAggregationCounter(Intermediate);
            }