Example #1
0
        //Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
        /// <summary>
        /// Trains the predictor's neural network using the TensorFlow.NET backend:
        /// builds a dense feed-forward graph from the configured hidden layers, runs
        /// minibatch gradient descent, periodically evaluates loss/accuracy on training
        /// and validation samples, checkpoints the last few models as final candidates,
        /// and persists the candidate with the lowest validation loss on the predictor.
        /// </summary>
        /// <param name="ctx">Training context carrying the predictor entity, codified input/output columns and progress reporting.</param>
        public void Train(PredictorTrainingContext ctx)
        {
            InitialSetup();

            // TF1-style graph mode is required to build placeholders and run sessions.
            tf.compat.v1.disable_eager_execution();
            var p = ctx.Predictor;

            var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

            // -1 in the first dimension lets the batch size vary between train/eval runs.
            Tensor inputPlaceholder  = tf.placeholder(tf.float32, new[] { -1, ctx.InputCodifications.Count }, "inputPlaceholder");
            Tensor outputPlaceholder = tf.placeholder(tf.float32, new[] { -1, ctx.OutputCodifications.Count }, "outputPlaceholder");

            Tensor currentTensor = inputPlaceholder;

            // Stack the configured hidden layers on top of the input.
            nn.HiddenLayers.ForEach((layer, i) =>
            {
                currentTensor = NetworkBuilder.DenseLayer(currentTensor, layer.Size, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
            });
            Tensor output           = NetworkBuilder.DenseLayer(currentTensor, ctx.OutputCodifications.Count, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");
            // Named identity node so the saved model's output can be located by name when reloaded.
            Tensor calculatedOutput = tf.identity(output, "calculatedOutput");

            Tensor loss     = NetworkBuilder.GetEvalFunction(nn.LossFunction, outputPlaceholder, calculatedOutput);
            Tensor accuracy = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, outputPlaceholder, calculatedOutput);

            // prepare for training
            Optimizer optimizer = NetworkBuilder.GetOptimizer(nn);

            Operation trainOperation = optimizer.minimize(loss);

            // A fixed seed makes the minibatch sampling (and hence training) reproducible.
            Random rand = p.Settings.Seed == null ?
                          new Random() :
                          new Random(p.Settings.Seed.Value);

            var (training, validation) = ctx.SplitTrainValidation(rand);

            var minibatchSize  = nn.MinibatchSize;
            var numMinibatches = nn.NumMinibatches;


            Stopwatch            sw         = Stopwatch.StartNew();
            List<FinalCandidate> candidates = new List<FinalCandidate>();

            var config = new ConfigProto
            {
                IntraOpParallelismThreads = 1,
                InterOpParallelismThreads = 1,
                LogDevicePlacement        = true
            };

            // Start from a clean model directory: remove any output of a previous run.
            ctx.ReportProgress("Deleting Files");
            var dir = PredictorDirectory(ctx.Predictor);

            if (Directory.Exists(dir))
            {
                Directory.Delete(dir, true);
            }

            Directory.CreateDirectory(dir);

            ctx.ReportProgress("Starting training...");

            var saver = tf.train.Saver();

            using (var sess = tf.Session(config))
            {
                sess.run(tf.global_variables_initializer());

                for (int i = 0; i < numMinibatches; i++)
                {
                    using (HeavyProfiler.Log("MiniBatch", () => i.ToString()))
                    {
                        // Sample a minibatch (with replacement) from the training partition.
                        var trainMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(training)).ToList();

                        var inputValue  = CreateNDArray(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn);
                        var outputValue = CreateNDArray(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn);

                        using (HeavyProfiler.Log("TrainMinibatch", () => i.ToString()))
                        {
                            sess.run(trainOperation,
                                     (inputPlaceholder, inputValue),
                                     (outputPlaceholder, outputValue));
                        }

                        if (ctx.StopTraining)
                        {
                            // Re-fetch the predictor entity; it may have been modified externally while stopping.
                            p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();
                        }

                        // The last BestResultFromLast minibatches all become final-model candidates.
                        var isLast = numMinibatches - nn.BestResultFromLast <= i;
                        if (isLast || (i % nn.SaveProgressEvery) == 0 || ctx.StopTraining)
                        {
                            float loss_val;
                            float accuracy_val;

                            using (HeavyProfiler.Log("EvalTraining", () => i.ToString()))
                            {
                                (loss_val, accuracy_val) = sess.run((loss, accuracy),
                                                                    (inputPlaceholder, inputValue),
                                                                    (outputPlaceholder, outputValue));
                            }

                            var ep = new EpochProgress
                            {
                                Ellapsed           = sw.ElapsedMilliseconds,
                                Epoch              = i,
                                TrainingExamples   = i * minibatchSize,
                                LossTraining       = loss_val,
                                AccuracyTraining   = accuracy_val,
                                LossValidation     = null,
                                AccuracyValidation = null,
                            };

                            ctx.ReportProgress($"Training Minibatches Loss:{loss_val} / Accuracy:{accuracy_val}", (i + 1) / (decimal)numMinibatches);

                            ctx.Progresses.Enqueue(ep);

                            // Validation is evaluated less frequently than training metrics.
                            if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || ctx.StopTraining)
                            {
                                using (HeavyProfiler.LogNoStackTrace("EvalValidation"))
                                {
                                    var validateMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(validation)).ToList();

                                    var inputValValue  = CreateNDArray(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn);
                                    var outputValValue = CreateNDArray(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn);

                                    (loss_val, accuracy_val) = sess.run((loss, accuracy),
                                                                        (inputPlaceholder, inputValValue),
                                                                        (outputPlaceholder, outputValValue));

                                    ep.LossValidation     = loss_val;
                                    ep.AccuracyValidation = accuracy_val;
                                }
                            }

                            var progress = ep.SaveEntity(ctx.Predictor);

                            if (isLast || ctx.StopTraining)
                            {
                                // Checkpoint this model to disk and record its metrics as a final candidate.
                                Directory.CreateDirectory(TrainingModelDirectory(ctx.Predictor, i));
                                var save = saver.save(sess, Path.Combine(TrainingModelDirectory(ctx.Predictor, i), ModelFileName));

                                using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                                {
                                    candidates.Add(new FinalCandidate
                                    {
                                        ModelIndex     = i,
                                        ResultTraining = new PredictorMetricsEmbedded {
                                            Accuracy = progress.AccuracyTraining, Loss = progress.LossTraining
                                        },
                                        ResultValidation = new PredictorMetricsEmbedded {
                                            Accuracy = progress.AccuracyValidation, Loss = progress.LossValidation
                                        },
                                    });
                                }
                            }
                        }

                        if (ctx.StopTraining)
                        {
                            break;
                        }
                    }
                }
            }

            // Candidates always carry a validation loss (validation runs whenever isLast is true).
            var best = candidates.WithMin(a => a.ResultValidation.Loss!.Value);

            p.ResultTraining   = best.ResultTraining;
            p.ResultValidation = best.ResultValidation;

            var files = Directory.GetFiles(TrainingModelDirectory(ctx.Predictor, best.ModelIndex));

            // Lambda parameter renamed from 'p': it shadowed the enclosing local 'p' (CS0136).
            p.Files.AddRange(files.Select(file => new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, file)));

            using (OperationLogic.AllowSave<PredictorEntity>())
                p.Save();
        }
        //Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
        /// <summary>
        /// Trains the predictor's neural network using the CNTK backend: builds a dense
        /// feed-forward function from the configured hidden layers, trains it minibatch by
        /// minibatch, periodically evaluates loss/error on training and validation samples,
        /// keeps the last few serialized models as final candidates, and persists the
        /// candidate with the lowest validation loss on the predictor entity.
        /// </summary>
        /// <param name="ctx">Training context carrying the predictor entity, codified input/output columns and progress reporting.</param>
        public void Train(PredictorTrainingContext ctx)
        {
            InitialSetup();
            var p = ctx.Predictor;

            var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

            DeviceDescriptor device         = GetDevice(nn);
            Variable         inputVariable  = Variable.InputVariable(new[] { ctx.InputCodifications.Count }, DataType.Float, "input");
            Variable         outputVariable = Variable.InputVariable(new[] { ctx.OutputCodifications.Count }, DataType.Float, "output");

            Variable currentVar = inputVariable;

            // Stack the configured hidden layers on top of the input.
            nn.HiddenLayers.ForEach((layer, i) =>
            {
                currentVar = NetworkBuilder.DenseLayer(currentVar, layer.Size, device, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
            });
            Function calculatedOutputs = NetworkBuilder.DenseLayer(currentVar, ctx.OutputCodifications.Count, device, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");

            Function loss      = NetworkBuilder.GetEvalFunction(nn.LossFunction, calculatedOutputs, outputVariable);
            Function evalError = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, calculatedOutputs, outputVariable);

            // prepare for training
            Learner learner = NetworkBuilder.GetInitializer(calculatedOutputs.Parameters(), nn);

            Trainer trainer = Trainer.CreateTrainer(calculatedOutputs, loss, evalError, new List<Learner>()
            {
                learner
            });

            // A fixed seed makes the minibatch sampling (and hence training) reproducible.
            Random rand = p.Settings.Seed == null ?
                          new Random() :
                          new Random(p.Settings.Seed.Value);

            var (training, validation) = ctx.SplitTrainValidation(rand);

            var minibatchSize  = nn.MinibatchSize;
            var numMinibatches = nn.NumMinibatches;

            Stopwatch            sw         = Stopwatch.StartNew();
            List<FinalCandidate> candidates = new List<FinalCandidate>();

            for (int i = 0; i < numMinibatches; i++)
            {
                using (HeavyProfiler.Log("MiniBatch", () => i.ToString()))
                {
                    ctx.ReportProgress("Training Minibatches", (i + 1) / (decimal)numMinibatches);

                    {
                        // Sample a minibatch (with replacement) and train on it.
                        // The CNTK Values are native resources and must be disposed.
                        var trainMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(training)).ToList();
                        using (Value inputValue = CreateValue(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                            using (Value outputValue = CreateValue(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                            {
                                using (HeavyProfiler.Log("TrainMinibatch", () => i.ToString()))
                                    trainer.TrainMinibatch(new Dictionary<Variable, Value>()
                                    {
                                        { inputVariable, inputValue },
                                        { outputVariable, outputValue },
                                    }, false, device);
                            }
                    }

                    var ep = new EpochProgress
                    {
                        Ellapsed             = sw.ElapsedMilliseconds,
                        Epoch                = i,
                        TrainingExamples     = (int)trainer.TotalNumberOfSamplesSeen(),
                        LossTraining         = trainer.PreviousMinibatchLossAverage(),
                        EvaluationTraining   = trainer.PreviousMinibatchEvaluationAverage(),
                        LossValidation       = null,
                        EvaluationValidation = null,
                    };

                    ctx.Progresses.Enqueue(ep);

                    if (ctx.StopTraining)
                    {
                        // Re-fetch the predictor entity; it may have been modified externally while stopping.
                        p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();
                    }

                    // The last BestResultFromLast minibatches all become final-model candidates.
                    var isLast = numMinibatches - nn.BestResultFromLast <= i;
                    if (isLast || (i % nn.SaveProgressEvery) == 0 || ctx.StopTraining)
                    {
                        // Validation is evaluated less frequently than training metrics.
                        if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || ctx.StopTraining)
                        {
                            using (HeavyProfiler.LogNoStackTrace("Validation"))
                            {
                                var validateMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(validation)).ToList();

                                using (Value inputValValue = CreateValue(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                                    using (Value outputValValue = CreateValue(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                                    {
                                        var inputs = new Dictionary<Variable, Value>()
                                        {
                                            { inputVariable, inputValValue },
                                            { outputVariable, outputValValue },
                                        };

                                        ep.LossValidation       = loss.EvaluateAvg(inputs, device);
                                        ep.EvaluationValidation = evalError.EvaluateAvg(inputs, device);
                                    }
                            }
                        }

                        var progress = ep.SaveEntity(ctx.Predictor);

                        if (isLast || ctx.StopTraining)
                        {
                            // Serialize the current model and record its metrics as a final candidate.
                            using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                            {
                                candidates.Add(new FinalCandidate
                                {
                                    Model = calculatedOutputs.Save(),

                                    ResultTraining = new PredictorMetricsEmbedded {
                                        Evaluation = progress.EvaluationTraining, Loss = progress.LossTraining
                                    },
                                    ResultValidation = new PredictorMetricsEmbedded {
                                        Evaluation = progress.EvaluationValidation, Loss = progress.LossValidation
                                    },
                                });
                            }
                        }
                    }

                    if (ctx.StopTraining)
                    {
                        break;
                    }
                }
            }

            // Candidates always carry a validation loss (validation runs whenever isLast is true).
            var best = candidates.WithMin(a => a.ResultValidation.Loss!.Value);

            p.ResultTraining   = best.ResultTraining;
            p.ResultValidation = best.ResultValidation;

            var fp = new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, "Model.cntk", best.Model);

            p.Files.Add(fp);

            using (OperationLogic.AllowSave<PredictorEntity>())
                p.Save();
        }