//Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
public void Train(PredictorTrainingContext ctx)
{
    InitialSetup();
    tf.compat.v1.disable_eager_execution();

    var p = ctx.Predictor;
    var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

    // Placeholders for the codified input/output rows (-1 = variable batch size)
    Tensor inputPlaceholder = tf.placeholder(tf.float32, new[] { -1, ctx.InputCodifications.Count }, "inputPlaceholder");
    Tensor outputPlaceholder = tf.placeholder(tf.float32, new[] { -1, ctx.OutputCodifications.Count }, "outputPlaceholder");

    // Stack the configured hidden layers, then the output layer
    Tensor currentTensor = inputPlaceholder;
    nn.HiddenLayers.ForEach((layer, i) =>
    {
        currentTensor = NetworkBuilder.DenseLayer(currentTensor, layer.Size, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
    });
    Tensor output = NetworkBuilder.DenseLayer(currentTensor, ctx.OutputCodifications.Count, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");
    Tensor calculatedOutput = tf.identity(output, "calculatedOutput");

    Tensor loss = NetworkBuilder.GetEvalFunction(nn.LossFunction, outputPlaceholder, calculatedOutput);
    Tensor accuracy = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, outputPlaceholder, calculatedOutput);

    // prepare for training
    Optimizer optimizer = NetworkBuilder.GetOptimizer(nn);
    Operation trainOperation = optimizer.minimize(loss);

    Random rand = p.Settings.Seed == null ? new Random() : new Random(p.Settings.Seed.Value);
    var (training, validation) = ctx.SplitTrainValidation(rand);

    var minibatchSize = nn.MinibatchSize;
    var numMinibatches = nn.NumMinibatches;

    Stopwatch sw = Stopwatch.StartNew();
    List<FinalCandidate> candidates = new List<FinalCandidate>();

    var config = new ConfigProto
    {
        IntraOpParallelismThreads = 1,
        InterOpParallelismThreads = 1,
        LogDevicePlacement = true,
    };

    ctx.ReportProgress("Deleting Files");
    var dir = PredictorDirectory(ctx.Predictor);
    if (Directory.Exists(dir))
        Directory.Delete(dir, true);
    Directory.CreateDirectory(dir);

    ctx.ReportProgress("Starting training...");

    var saver = tf.train.Saver();
    using (var sess = tf.Session(config))
    {
        sess.run(tf.global_variables_initializer());

        for (int i = 0; i < numMinibatches; i++)
        {
            using (HeavyProfiler.Log("MiniBatch", () => i.ToString()))
            {
                // Sample a minibatch (with replacement) from the training set
                var trainMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(training)).ToList();
                var inputValue = CreateNDArray(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn);
                var outputValue = CreateNDArray(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn);

                using (HeavyProfiler.Log("TrainMinibatch", () => i.ToString()))
                {
                    sess.run(trainOperation,
                        (inputPlaceholder, inputValue),
                        (outputPlaceholder, outputValue));
                }

                if (ctx.StopTraining)
                    p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();

                // The last BestResultFromLast minibatches are always evaluated and saved as candidates
                var isLast = numMinibatches - nn.BestResultFromLast <= i;
                if (isLast || (i % nn.SaveProgressEvery) == 0 || ctx.StopTraining)
                {
                    float loss_val;
                    float accuracy_val;
                    using (HeavyProfiler.Log("EvalTraining", () => i.ToString()))
                    {
                        (loss_val, accuracy_val) = sess.run((loss, accuracy),
                            (inputPlaceholder, inputValue),
                            (outputPlaceholder, outputValue));
                    }

                    var ep = new EpochProgress
                    {
                        Ellapsed = sw.ElapsedMilliseconds,
                        Epoch = i,
                        TrainingExamples = i * minibatchSize,
                        LossTraining = loss_val,
                        AccuracyTraining = accuracy_val,
                        LossValidation = null,
                        AccuracyValidation = null,
                    };

                    ctx.ReportProgress($"Training Minibatches Loss:{loss_val} / Accuracy:{accuracy_val}", (i + 1) / (decimal)numMinibatches);
                    ctx.Progresses.Enqueue(ep);
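                    // Periodically (and always for the final candidates) the model is also
                    // scored on a random validation minibatch, so the candidate selection
                    // below can compare held-out loss rather than training loss.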
                    if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || ctx.StopTraining)
                    {
                        using (HeavyProfiler.LogNoStackTrace("EvalValidation"))
                        {
                            var validateMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(validation)).ToList();
                            var inputValValue = CreateNDArray(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn);
                            var outputValValue = CreateNDArray(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn);

                            (loss_val, accuracy_val) = sess.run((loss, accuracy),
                                (inputPlaceholder, inputValValue),
                                (outputPlaceholder, outputValValue));

                            ep.LossValidation = loss_val;
                            ep.AccuracyValidation = accuracy_val;
                        }
                    }

                    var progress = ep.SaveEntity(ctx.Predictor);

                    if (isLast || ctx.StopTraining)
                    {
                        // Save a checkpoint and remember its metrics as a final candidate
                        Directory.CreateDirectory(TrainingModelDirectory(ctx.Predictor, i));
                        var save = saver.save(sess, Path.Combine(TrainingModelDirectory(ctx.Predictor, i), ModelFileName));

                        using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                        {
                            candidates.Add(new FinalCandidate
                            {
                                ModelIndex = i,
                                ResultTraining = new PredictorMetricsEmbedded { Accuracy = progress.AccuracyTraining, Loss = progress.LossTraining },
                                ResultValidation = new PredictorMetricsEmbedded { Accuracy = progress.AccuracyValidation, Loss = progress.LossValidation },
                            });
                        }
                    }
                }

                if (ctx.StopTraining)
                    break;
            }
        }
    }

    // Keep the candidate with the lowest validation loss and attach its files to the predictor
    var best = candidates.MinBy(a => a.ResultValidation.Loss!.Value)!;
    p.ResultTraining = best.ResultTraining;
    p.ResultValidation = best.ResultValidation;

    var files = Directory.GetFiles(TrainingModelDirectory(ctx.Predictor, best.ModelIndex));
    p.Files.AddRange(files.Select(f => new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, f)));

    using (OperationLogic.AllowSave<PredictorEntity>())
        p.Save();
}
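
// A minimal sketch, not part of the original source: how a checkpoint written by
// saver.save above could be reloaded with TensorFlow.NET's v1 graph API before
// evaluating "calculatedOutput" for predictions. The method name is hypothetical;
// ModelFileName and the directory layout match the training code above.
public Session RestoreSessionSketch(string trainingModelDirectory)
{
    tf.compat.v1.disable_eager_execution();
    var graph = tf.Graph().as_default();
    var sess = tf.Session(graph);

    // import_meta_graph rebuilds the graph definition stored next to the
    // checkpoint; restore then loads the trained variable values into it.
    var saver = tf.train.import_meta_graph(Path.Combine(trainingModelDirectory, ModelFileName + ".meta"));
    saver.restore(sess, Path.Combine(trainingModelDirectory, ModelFileName));
    return sess;
}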