Example #1
0
        // Runs one training epoch: walks the training set in consecutive mini-batches
        // of BatchSize rows and issues one optimizer update per mini-batch.
        // A trailing partial batch (fewer than BatchSize rows) is never visited.
        //
        //   model            - network run forward and backward on every mini-batch.
        //   criterion        - produces the gradient that is passed to model.Backward.
        //   optim            - optimizer; its Update is called once per mini-batch.
        //   trainingSet      - supplies inputs, targets and targetValues, each sliced
        //                      along dimension 0.
        //   numInputs        - not used by this method.
        //   useTargetClasses - true: the criterion receives the targetValues slice;
        //                      false: it receives the targets slice.
        private static void TrainEpoch(Sequential model, ICriterion criterion, SgdOptimizer optim, DataSet trainingSet, int numInputs, bool useTargetClasses)
        {
            // NOTE(review): SimpleTimer presumably reports elapsed time when disposed - confirm.
            using (new SimpleTimer("Training epoch completed in {0}ms"))
            {
                for (int offset = 0; offset <= trainingSet.inputs.Shape[0] - BatchSize; offset += BatchSize)
                {
                    // Progress marker: one dot per mini-batch.
                    Console.Write(".");

                    // Callback handed to the optimizer; it evaluates the current mini-batch
                    // and returns the model output together with the gradient arrays.
                    GradFunc evaluateBatch = parameters =>
                    {
                        using (var batchInputs = trainingSet.inputs.Narrow(0, offset, BatchSize))
                        using (var batchTargets = trainingSet.targets.Narrow(0, offset, BatchSize))
                        using (var batchClasses = trainingSet.targetValues.Narrow(0, offset, BatchSize))
                        {
                            // Zero every gradient array before this batch's backward pass.
                            foreach (var gradBuffer in model.GetGradParameters())
                            {
                                Ops.Fill(gradBuffer, 0);
                            }

                            var expected = useTargetClasses ? batchClasses : batchTargets;

                            var prediction = model.Forward(batchInputs, ModelMode.Train);

                            // The returned value is not needed here; the call itself is kept,
                            // in its original position ahead of UpdateGradInput.
                            criterion.UpdateOutput(prediction, expected);

                            var lossGrad = criterion.UpdateGradInput(prediction, expected);
                            model.Backward(batchInputs, lossGrad, ModelMode.Train);

                            return new OutputAndGrads
                            {
                                output = prediction,
                                grads = model.GetGradParameters().ToArray()
                            };
                        }
                    };

                    optim.Update(evaluateBatch, model.GetParameters().ToArray());
                }
            }

            // Terminate the line of progress dots.
            Console.WriteLine();
        }