/// <summary>
/// Translates a <see cref="NeuralNetworkEvalFunction"/> setting into the CNTK function that
/// compares the network output against the expected output variable.
/// </summary>
/// <param name="lossFunction">Which loss / evaluation criterion to build.</param>
/// <param name="calculatedOutputs">The network's computed output node.</param>
/// <param name="outputVariable">The input variable holding the expected outputs.</param>
/// <returns>The CNTK function implementing the requested criterion.</returns>
/// <exception cref="InvalidOperationException">The enum value is not handled here.</exception>
public static Function GetEvalFunction(NeuralNetworkEvalFunction lossFunction, Function calculatedOutputs, Variable outputVariable) =>
    lossFunction switch
    {
        // Built-in CNTK criteria.
        NeuralNetworkEvalFunction.CrossEntropyWithSoftmax => CNTKLib.CrossEntropyWithSoftmax(calculatedOutputs, outputVariable),
        NeuralNetworkEvalFunction.ClassificationError => CNTKLib.ClassificationError(calculatedOutputs, outputVariable),
        NeuralNetworkEvalFunction.SquaredError => CNTKLib.SquaredError(calculatedOutputs, outputVariable),

        // Criteria composed by NetworkBuilder itself.
        NeuralNetworkEvalFunction.MeanAbsoluteError => NetworkBuilder.MeanAbsoluteError(calculatedOutputs, outputVariable),
        NeuralNetworkEvalFunction.MeanAbsolutePercentageError => NetworkBuilder.MeanAbsolutePercentageError(calculatedOutputs, outputVariable),

        _ => throw new InvalidOperationException("Unexpected " + lossFunction),
    };
//Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
/// <summary>
/// Builds a dense feed-forward network from the predictor's <see cref="NeuralNetworkSettingsEntity"/>,
/// trains it on <c>NumMinibatches</c> minibatches sampled (with replacement) from the training split,
/// periodically records training/validation progress, and finally stores the collected candidate
/// model with the lowest validation loss as the predictor's model file.
/// </summary>
/// <param name="ctx">Training context: predictor, codifications, data split, progress queue and stop flag.</param>
/// <exception cref="InvalidOperationException">
/// No final candidate model was collected (e.g. <c>NumMinibatches</c> is not positive, or
/// <c>BestResultFromLast</c> is not positive and training was not stopped).
/// </exception>
public void Train(PredictorTrainingContext ctx)
{
    InitialSetup();

    var p = ctx.Predictor;
    var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

    DeviceDescriptor device = GetDevice(nn);
    Variable inputVariable = Variable.InputVariable(new[] { ctx.InputCodifications.Count }, DataType.Float, "input");
    Variable outputVariable = Variable.InputVariable(new[] { ctx.OutputCodifications.Count }, DataType.Float, "output");

    // Chain the hidden layers: each one consumes the previous layer's output.
    Variable currentVar = inputVariable;
    nn.HiddenLayers.ForEach((layer, i) =>
    {
        currentVar = NetworkBuilder.DenseLayer(currentVar, layer.Size, device, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
    });
    Function calculatedOutputs = NetworkBuilder.DenseLayer(currentVar, ctx.OutputCodifications.Count, device, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");

    Function loss = NetworkBuilder.GetEvalFunction(nn.LossFunction, calculatedOutputs, outputVariable);
    Function evalError = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, calculatedOutputs, outputVariable);

    // prepare for training
    Learner learner = NetworkBuilder.GetInitializer(calculatedOutputs.Parameters(), nn);
    Trainer trainer = Trainer.CreateTrainer(calculatedOutputs, loss, evalError, new List<Learner>() { learner });

    // A fixed seed makes the train/validation split and minibatch sampling reproducible.
    Random rand = p.Settings.Seed == null ? new Random() : new Random(p.Settings.Seed.Value);

    var (training, validation) = ctx.SplitTrainValidation(rand);

    var minibatchSize = nn.MinibatchSize;
    var numMinibatches = nn.NumMinibatches;

    Stopwatch sw = Stopwatch.StartNew();
    List<FinalCandidate> candidates = new List<FinalCandidate>();
    for (int i = 0; i < numMinibatches; i++)
    {
        using (HeavyProfiler.Log("MiniBatch", () => i.ToString()))
        {
            ctx.ReportProgress("Training Minibatches", (i + 1) / (decimal)numMinibatches);

            {
                var trainMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(training)).ToList();
                using (Value inputValue = CreateValue(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                using (Value outputValue = CreateValue(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                {
                    using (HeavyProfiler.Log("TrainMinibatch", () => i.ToString()))
                    {
                        trainer.TrainMinibatch(new Dictionary<Variable, Value>()
                        {
                            { inputVariable, inputValue },
                            { outputVariable, outputValue },
                        }, false, device);
                    }
                }
            }

            var ep = new EpochProgress
            {
                Ellapsed = sw.ElapsedMilliseconds,
                Epoch = i,
                TrainingExamples = (int)trainer.TotalNumberOfSamplesSeen(),
                LossTraining = trainer.PreviousMinibatchLossAverage(),
                EvaluationTraining = trainer.PreviousMinibatchEvaluationAverage(),
                LossValidation = null,
                EvaluationValidation = null,
            };

            ctx.Progresses.Enqueue(ep);

            // Read the stop flag ONCE per iteration. It is set outside this method, so re-reading it
            // below could make the checks disagree, e.g. breaking out of the loop without having
            // collected a final candidate for this iteration.
            bool stopTraining = ctx.StopTraining;
            if (stopTraining)
            {
                p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();
            }

            // The last BestResultFromLast iterations are all kept as final candidates.
            var isLast = numMinibatches - nn.BestResultFromLast <= i;
            if (isLast || (i % nn.SaveProgressEvery) == 0 || stopTraining)
            {
                if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || stopTraining)
                {
                    using (HeavyProfiler.LogNoStackTrace("Validation"))
                    {
                        var validateMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(validation)).ToList();

                        using (Value inputValValue = CreateValue(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                        using (Value outputValValue = CreateValue(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                        {
                            var inputs = new Dictionary<Variable, Value>()
                            {
                                { inputVariable, inputValValue },
                                { outputVariable, outputValValue },
                            };

                            ep.LossValidation = loss.EvaluateAvg(inputs, device);
                            ep.EvaluationValidation = evalError.EvaluateAvg(inputs, device);
                        }
                    }
                }

                var progress = ep.SaveEntity(ctx.Predictor);

                if (isLast || stopTraining)
                {
                    using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                    {
                        candidates.Add(new FinalCandidate
                        {
                            Model = calculatedOutputs.Save(),

                            ResultTraining = new PredictorMetricsEmbedded { Evaluation = progress.EvaluationTraining, Loss = progress.LossTraining },
                            ResultValidation = new PredictorMetricsEmbedded { Evaluation = progress.EvaluationValidation, Loss = progress.LossValidation },
                        });
                    }
                }
            }

            if (stopTraining)
            {
                break;
            }
        }
    }

    // Fail with a clear message instead of dereferencing a missing "best" candidate below.
    if (candidates.Count == 0)
    {
        throw new InvalidOperationException("No final candidate model was collected during training; check NumMinibatches and BestResultFromLast");
    }

    // Every candidate was validated (isLast / stopTraining force validation), so Loss is set.
    var best = candidates.WithMin(a => a.ResultValidation.Loss!.Value);

    p.ResultTraining = best.ResultTraining;
    p.ResultValidation = best.ResultValidation;

    var fp = new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, "Model.cntk", best.Model);

    p.Files.Add(fp);

    using (OperationLogic.AllowSave<PredictorEntity>())
    {
        p.Save();
    }
}