/// <summary>
/// Performs one SGD-with-momentum step, modifying <paramref name="parameters"/> in place.
/// </summary>
/// <remarks>
/// The velocity buffers (<c>gradAcc</c>) are created on the first call, zero-filled, and
/// shaped like the gradients returned by <paramref name="grad"/>. They are reused on every
/// later call. Update rule per index i:
///   gradAcc[i] = gradAcc[i] * Momentum - LearningRate * gradients[i]
///   parameters[i] += gradAcc[i]
/// NOTE(review): assumes gradients[i] corresponds to parameters[i], and that the set of
/// parameters does not change between calls. Neither is checked here.
/// </remarks>
/// <param name="grad">Evaluates the model at <paramref name="parameters"/> and returns its output and gradients.</param>
/// <param name="parameters">Parameter arrays updated in place.</param>
/// <returns>The model output produced by <paramref name="grad"/>.</returns>
public NDArray Update(GradFunc grad, NDArray[] parameters)
{
    var evaluation = grad(parameters);
    NDArray modelOutput = evaluation.output;
    NDArray[] grads = evaluation.grads;

    if (gradAcc == null)
    {
        // Build into a local first so gradAcc is only assigned once every buffer exists.
        var velocities = new NDArray[grads.Length];
        for (int k = 0; k < grads.Length; ++k)
        {
            NDArray source = grads[k];
            var zeros = new NDArray(source.Allocator, source.ElementType, source.Shape);
            Ops.Fill(zeros, 0);
            velocities[k] = zeros;
        }
        gradAcc = velocities;
    }

    // Velocity update: decay the previous step, then add the scaled negative gradient.
    for (int idx = 0; idx < grads.Length; ++idx)
    {
        NDArray velocity = gradAcc[idx];
        Ops.Mul(velocity, velocity, config.Momentum);
        using (var scaledGrad = Ops.Mul(null, grads[idx], -config.LearningRate))
        {
            Ops.Add(velocity, velocity, scaledGrad);
        }
    }

    // Apply the accumulated step to each parameter in place.
    for (int idx = 0; idx < parameters.Length; ++idx)
    {
        Ops.Add(parameters[idx], parameters[idx], gradAcc[idx]);
    }

    return modelOutput;
}
/// <summary>
/// Runs a single epoch of training. It takes one optimizer step per complete mini-batch of
/// <c>BatchSize</c> rows, sliced along dimension 0 of <paramref name="trainingSet"/>.
/// </summary>
/// <remarks>
/// Rows after the last complete batch are not visited in this epoch. A "." is written per
/// batch, and a newline is written after the timed section ends.
/// </remarks>
/// <param name="model">Model whose gradients are zeroed, recomputed, and stepped for each batch.</param>
/// <param name="criterion">Loss used to produce the gradient fed into <c>model.Backward</c>.</param>
/// <param name="optim">Optimizer that applies the gradients to <c>model.GetParameters()</c>.</param>
/// <param name="trainingSet">Source of the <c>inputs</c>, <c>targets</c>, and <c>targetValues</c> slices.</param>
/// <param name="numInputs">Not read by this method.</param>
/// <param name="useTargetClasses">When true the criterion gets the <c>targetValues</c> slice; otherwise the <c>targets</c> slice.</param>
private static void TrainEpoch(Sequential model, ICriterion criterion, SgdOptimizer optim, DataSet trainingSet, int numInputs, bool useTargetClasses)
{
    using (new SimpleTimer("Training epoch completed in {0}ms"))
    {
        var lastBatchStart = trainingSet.inputs.Shape[0] - BatchSize;

        for (int start = 0; start <= lastBatchStart; start += BatchSize)
        {
            Console.Write(".");

            // Captures the loop variable `start`. NOTE(review): this relies on optim.Update
            // invoking the delegate before returning, which is what its grad(parameters) call does.
            GradFunc computeGradients = unusedParameters =>
            {
                using (var batchInputs = trainingSet.inputs.Narrow(0, start, BatchSize))
                using (var batchTargets = trainingSet.targets.Narrow(0, start, BatchSize))
                using (var batchTargetClasses = trainingSet.targetValues.Narrow(0, start, BatchSize))
                {
                    // Gradients accumulate in Backward, so clear them before each batch.
                    foreach (var gradBuffer in model.GetGradParameters())
                    {
                        Ops.Fill(gradBuffer, 0);
                    }

                    var prediction = model.Forward(batchInputs, ModelMode.Train);
                    var criterionTarget = useTargetClasses ? batchTargetClasses : batchTargets;

                    // The loss value itself is not used here; the call is kept for its effect on the criterion.
                    criterion.UpdateOutput(prediction, criterionTarget);
                    var lossGradient = criterion.UpdateGradInput(prediction, criterionTarget);
                    model.Backward(batchInputs, lossGradient, ModelMode.Train);

                    return new OutputAndGrads
                    {
                        output = prediction,
                        grads = model.GetGradParameters().ToArray(),
                    };
                }
            };

            optim.Update(computeGradients, model.GetParameters().ToArray());
        }
    }

    Console.WriteLine();
}