Ejemplo n.º 1
0
        /// <summary>
        /// Builds the CNTK node that scores <paramref name="calculatedOutputs"/> against the expected
        /// values in <paramref name="outputVariable"/> using the requested metric.
        /// </summary>
        /// <remarks>
        /// Used for both the training loss and the evaluation-error metric; the parameter keeps the
        /// name <c>lossFunction</c> so callers using named arguments keep compiling.
        /// </remarks>
        /// <param name="lossFunction">Which metric to build.</param>
        /// <param name="calculatedOutputs">The network's output node.</param>
        /// <param name="outputVariable">The variable fed with the expected outputs.</param>
        /// <returns>A CNTK function computing the chosen metric.</returns>
        /// <exception cref="InvalidOperationException">The enum value is not one of the handled metrics.</exception>
        public static Function GetEvalFunction(NeuralNetworkEvalFunction lossFunction, Function calculatedOutputs, Variable outputVariable) =>
            lossFunction switch
            {
                // Built-in CNTK criteria.
                NeuralNetworkEvalFunction.CrossEntropyWithSoftmax => CNTKLib.CrossEntropyWithSoftmax(calculatedOutputs, outputVariable),
                NeuralNetworkEvalFunction.ClassificationError => CNTKLib.ClassificationError(calculatedOutputs, outputVariable),
                NeuralNetworkEvalFunction.SquaredError => CNTKLib.SquaredError(calculatedOutputs, outputVariable),

                // Metrics composed by NetworkBuilder itself.
                NeuralNetworkEvalFunction.MeanAbsoluteError => NetworkBuilder.MeanAbsoluteError(calculatedOutputs, outputVariable),
                NeuralNetworkEvalFunction.MeanAbsolutePercentageError => NetworkBuilder.MeanAbsolutePercentageError(calculatedOutputs, outputVariable),

                _ => throw new InvalidOperationException("Unexpected " + lossFunction),
            };
        //Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
        /// <summary>
        /// Trains a dense CNTK network for <c>ctx.Predictor</c> using randomly sampled minibatches.
        /// Progress (and, periodically, validation metrics) is recorded. The model from the final
        /// iterations (or from the iteration where a stop was requested) with the lowest validation
        /// loss is attached to the predictor as <c>Model.cntk</c>, and the predictor is saved.
        /// </summary>
        /// <param name="ctx">
        /// Training context: predictor, input/output codifications, train/validation split,
        /// progress reporting and the stop flag.
        /// </param>
        /// <exception cref="InvalidOperationException">
        /// No final candidate was produced, e.g. <c>NumMinibatches</c> or <c>BestResultFromLast</c>
        /// is not positive and no stop was requested.
        /// </exception>
        public void Train(PredictorTrainingContext ctx)
        {
            InitialSetup();
            var p = ctx.Predictor;

            var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

            DeviceDescriptor device = GetDevice(nn);
            Variable inputVariable = Variable.InputVariable(new[] { ctx.InputCodifications.Count }, DataType.Float, "input");
            Variable outputVariable = Variable.InputVariable(new[] { ctx.OutputCodifications.Count }, DataType.Float, "output");

            // Stack the configured hidden dense layers on top of the input, then the output layer.
            Variable currentVar = inputVariable;
            nn.HiddenLayers.ForEach((layer, i) =>
            {
                currentVar = NetworkBuilder.DenseLayer(currentVar, layer.Size, device, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
            });
            Function calculatedOutputs = NetworkBuilder.DenseLayer(currentVar, ctx.OutputCodifications.Count, device, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");

            Function loss = NetworkBuilder.GetEvalFunction(nn.LossFunction, calculatedOutputs, outputVariable);
            Function evalError = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, calculatedOutputs, outputVariable);

            // prepare for training
            Learner learner = NetworkBuilder.GetInitializer(calculatedOutputs.Parameters(), nn);

            Trainer trainer = Trainer.CreateTrainer(calculatedOutputs, loss, evalError, new List<Learner>()
            {
                learner
            });

            // With a configured seed, the split and the minibatch sampling below are reproducible.
            Random rand = p.Settings.Seed == null ?
                          new Random() :
                          new Random(p.Settings.Seed.Value);

            var (training, validation) = ctx.SplitTrainValidation(rand);

            var minibatchSize = nn.MinibatchSize;
            var numMinibatches = nn.NumMinibatches;

            Stopwatch sw = Stopwatch.StartNew();
            List<FinalCandidate> candidates = new List<FinalCandidate>();

            for (int i = 0; i < numMinibatches; i++)
            {
                using (HeavyProfiler.Log("MiniBatch", () => i.ToString()))
                {
                    ctx.ReportProgress("Training Minibatches", (i + 1) / (decimal)numMinibatches);

                    // Draw minibatchSize independent random picks from the training rows.
                    var trainMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(training)).ToList();
                    using (Value inputValue = CreateValue(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                    using (Value outputValue = CreateValue(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                    using (HeavyProfiler.Log("TrainMinibatch", () => i.ToString()))
                    {
                        trainer.TrainMinibatch(new Dictionary<Variable, Value>()
                        {
                            { inputVariable, inputValue },
                            { outputVariable, outputValue },
                        }, false, device);
                    }

                    var ep = new EpochProgress
                    {
                        Ellapsed = sw.ElapsedMilliseconds,
                        Epoch = i,
                        TrainingExamples = (int)trainer.TotalNumberOfSamplesSeen(),
                        LossTraining = trainer.PreviousMinibatchLossAverage(),
                        EvaluationTraining = trainer.PreviousMinibatchEvaluationAverage(),
                        LossValidation = null,
                        EvaluationValidation = null,
                    };

                    ctx.Progresses.Enqueue(ep);

                    // Read the stop flag exactly once per iteration. NOTE(review): it appears to be set
                    // from outside this loop (possibly another thread) — if it flipped between separate
                    // reads, we could skip the reload, add a candidate without validation metrics
                    // (making Loss!.Value throw below), or break without adding any candidate.
                    bool stopRequested = ctx.StopTraining;

                    if (stopRequested)
                    {
                        // Re-fetch the predictor before the final save; presumably to pick up changes
                        // made by whoever requested the stop (TODO confirm).
                        p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();
                    }

                    // The last BestResultFromLast iterations are eligible to become the final model.
                    var isLast = numMinibatches - nn.BestResultFromLast <= i;
                    if (isLast || (i % nn.SaveProgressEvery) == 0 || stopRequested)
                    {
                        // Always validate when a candidate will be recorded, so its ResultValidation.Loss is set.
                        if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || stopRequested)
                        {
                            using (HeavyProfiler.LogNoStackTrace("Validation"))
                            {
                                var validateMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(validation)).ToList();

                                using (Value inputValValue = CreateValue(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                                using (Value outputValValue = CreateValue(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                                {
                                    var inputs = new Dictionary<Variable, Value>()
                                    {
                                        { inputVariable, inputValValue },
                                        { outputVariable, outputValValue },
                                    };

                                    ep.LossValidation = loss.EvaluateAvg(inputs, device);
                                    ep.EvaluationValidation = evalError.EvaluateAvg(inputs, device);
                                }
                            }
                        }

                        var progress = ep.SaveEntity(ctx.Predictor);

                        if (isLast || stopRequested)
                        {
                            using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                            {
                                candidates.Add(new FinalCandidate
                                {
                                    Model = calculatedOutputs.Save(),

                                    ResultTraining = new PredictorMetricsEmbedded
                                    {
                                        Evaluation = progress.EvaluationTraining,
                                        Loss = progress.LossTraining
                                    },
                                    ResultValidation = new PredictorMetricsEmbedded
                                    {
                                        Evaluation = progress.EvaluationValidation,
                                        Loss = progress.LossValidation
                                    },
                                });
                            }
                        }
                    }

                    if (stopRequested)
                    {
                        break;
                    }
                }
            }

            // Without this guard an empty candidate list would fail later on a missing best result.
            if (candidates.Count == 0)
            {
                throw new InvalidOperationException("Training produced no final candidate; check that NumMinibatches and BestResultFromLast are greater than zero.");
            }

            var best = candidates.WithMin(a => a.ResultValidation.Loss!.Value);

            p.ResultTraining = best.ResultTraining;
            p.ResultValidation = best.ResultValidation;

            var fp = new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, "Model.cntk", best.Model);

            p.Files.Add(fp);

            using (OperationLogic.AllowSave<PredictorEntity>())
            {
                p.Save();
            }
        }