/// <summary>
/// Runs the backward pass of the unrolled loop by invoking the sub-executor's
/// backward pass once per step, walking from the last step down to the first.
/// </summary>
/// <remarks>
/// For step <c>i</c> the sub-executor is wired as follows:
/// its input is <c>Input</c> when <c>i == 0</c>, otherwise slice <c>i - 1</c> of
/// <c>Intermediate</c>; its output is <c>Output</c> for the last step, otherwise
/// slice <c>i</c> of <c>Intermediate</c>; its state is slice <c>i</c> of
/// <c>States</c>. The gradients are wired the same way. <c>Weight</c> and its
/// gradient are shared by every step.
/// </remarks>
/// <param name="executor">
/// Executor that provides the tensors, gradients and gradient aggregation
/// counters of this operator, and holds the sub-executor in its <c>Objects</c>.
/// </param>
public override void Backward(Executor executor)
{
    const string statesShapeMessage = "states shape: (steps, n, n)";

    var inputTensor = executor.GetTensor(Input);
    var statesTensor = executor.GetTensor(States);
    var weightTensor = executor.GetTensor(Weight);

    // Validate the layout: input is a matrix, states is a stack of square matrices.
    Util.EnsureTrue(inputTensor.Shape.Rank == 2);
    Util.EnsureTrue(statesTensor.Shape.Rank == 3, statesShapeMessage);
    var n = statesTensor.Shape[1];
    Util.EnsureTrue(n == statesTensor.Shape[2], statesShapeMessage);
    var steps = (int)statesTensor.Shape[0];
    var lastStep = steps - 1;

    var intermediateTensor = executor.GetTensor(Intermediate);
    var outputTensor = executor.GetTensor(Output);

    var outputGrad = executor.GetGradient(Output);
    var intermediateGrad = executor.GetGradient(Intermediate, intermediateTensor.Shape);
    var statesGrad = executor.GetGradient(States, statesTensor.Shape);
    var weightGrad = executor.GetGradient(Weight, weightTensor.Shape);
    var inputGrad = executor.GetGradient(Input, inputTensor.Shape);

    // Snapshot the outer aggregation counters once; they seed the counters
    // handed to the sub-executor on every step.
    var inputCounter = executor.GetGradientAggregationCounter(Input);
    var weightCounter = executor.GetGradientAggregationCounter(Weight);
    var statesCounter = executor.GetGradientAggregationCounter(States);
    var intermediateCounter = executor.GetGradientAggregationCounter(Intermediate);

    var stepExecutor = (Executor)executor.Objects[SubExecutor];

    for (var step = lastStep; step >= 0; step--)
    {
        var isFirstStep = step == 0;
        var isLastStep = step == lastStep;

        // Both the tensors and their gradients must be bound for this step.
        var stepInput = isFirstStep
            ? inputTensor
            : intermediateTensor.Slice(step - 1).Reshape(n, n);
        var stepState = statesTensor.Slice(step).Reshape(n, n);
        var stepOutput = isLastStep
            ? outputTensor
            : intermediateTensor.Slice(step).Reshape(n, n);

        stepExecutor.SetTensor(SubInput, stepInput);
        stepExecutor.SetTensor(SubWeight, weightTensor);
        stepExecutor.SetTensor(SubState, stepState);
        stepExecutor.SetTensor(SubOutput, stepOutput);

        var stepInputGrad = isFirstStep
            ? inputGrad
            : intermediateGrad.Slice(step - 1).Reshape(n, n);
        var stepStateGrad = statesGrad.Slice(step).Reshape(n, n);
        var stepOutputGrad = isLastStep
            ? outputGrad
            : intermediateGrad.Slice(step).Reshape(n, n);

        // The aggregation counters are managed by hand here. Per the original
        // author's note, a counter of 0 means the gradient memory is merely
        // pointed at another tensor and holds no value to aggregate with.
        // The weight gradient is the one variable shared by all steps, so its
        // counter is advanced by one for every step already processed in this
        // backward pass.
        var stepInputCounter = isFirstStep ? inputCounter : intermediateCounter;
        var stepsAlreadyProcessed = lastStep - step;

        stepExecutor.ClearGradientAggregationCounters();
        stepExecutor.SetGradient(SubInput, stepInputGrad, counter: stepInputCounter);
        stepExecutor.SetGradient(SubWeight, weightGrad, counter: weightCounter + stepsAlreadyProcessed);
        stepExecutor.SetGradient(SubState, stepStateGrad, counter: statesCounter);
        stepExecutor.SetGradient(SubOutput, stepOutputGrad);

        // The counters were set explicitly above, so the sub-executor must not reset them.
        stepExecutor.Backward(clearGradientAggretionCounter: false);
    }

    // Record on the outer executor that each of these gradients has been written.
    executor.IncreaseGradientAggregationCounter(Input);
    executor.IncreaseGradientAggregationCounter(Weight);
    executor.IncreaseGradientAggregationCounter(States);
    executor.IncreaseGradientAggregationCounter(Intermediate);
}