Code example #1
0
        /// <summary>
        /// Builds the TensorFlow graph node that evaluates the requested loss / error metric
        /// over <paramref name="labels"/> and <paramref name="calculatedOutputs"/>.
        /// </summary>
        /// <param name="lossFunction">Which metric to materialize in the graph.</param>
        /// <param name="labels">Expected output tensor (ground truth).</param>
        /// <param name="calculatedOutputs">Network output tensor (logits / predictions).</param>
        /// <returns>A scalar tensor representing the metric.</returns>
        /// <exception cref="InvalidOperationException">Thrown for an unrecognized enum value.</exception>
        public static Tensor GetEvalFunction(NeuralNetworkEvalFunction lossFunction, Tensor labels, Tensor calculatedOutputs)
        {
            // NOTE: the tf.nn.* cross-entropy helpers are called as (labels, outputs), while the
            // NetworkBuilder helpers are called as (outputs, labels) — this mirrors each helper's
            // own parameter order as used elsewhere in the project.
            return lossFunction switch
            {
                NeuralNetworkEvalFunction.sigmoid_cross_entropy_with_logits => tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels, calculatedOutputs)),
                NeuralNetworkEvalFunction.softmax_cross_entropy_with_logits => tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels, calculatedOutputs)),
                NeuralNetworkEvalFunction.softmax_cross_entropy_with_logits_v2 => tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels, calculatedOutputs)),
                NeuralNetworkEvalFunction.ClassificationError => NetworkBuilder.ClassificationError(calculatedOutputs, labels),
                NeuralNetworkEvalFunction.MeanSquaredError => MeanSquaredError(labels, calculatedOutputs),
                NeuralNetworkEvalFunction.MeanAbsoluteError => NetworkBuilder.MeanAbsoluteError(calculatedOutputs, labels),
                NeuralNetworkEvalFunction.MeanAbsolutePercentageError => NetworkBuilder.MeanAbsolutePercentageError(calculatedOutputs, labels),
                _ => throw new InvalidOperationException("Unexpected " + lossFunction),
            };
        }
Code example #2
0
        //Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
        /// <summary>
        /// Trains the predictor's neural network with TensorFlow: builds the dense-layer graph from
        /// the entity's settings, runs minibatch gradient descent, periodically evaluates
        /// training/validation loss, saves model checkpoints near the end, and persists the best
        /// candidate (lowest validation loss) back onto the predictor entity.
        /// </summary>
        /// <param name="ctx">Training context holding the predictor entity, codifications, data split and progress queue.</param>
        public void Train(PredictorTrainingContext ctx)
        {
            InitialSetup();

            // TF1-style graph mode: placeholders + explicit session.run.
            tf.compat.v1.disable_eager_execution();
            var p = ctx.Predictor;

            var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

            // -1 in the first dimension = variable batch size.
            Tensor inputPlaceholder  = tf.placeholder(tf.float32, new[] { -1, ctx.InputCodifications.Count }, "inputPlaceholder");
            Tensor outputPlaceholder = tf.placeholder(tf.float32, new[] { -1, ctx.OutputCodifications.Count }, "outputPlaceholder");

            Tensor currentTensor = inputPlaceholder;

            // Stack the configured hidden layers, threading the output of each into the next.
            nn.HiddenLayers.ForEach((layer, i) =>
            {
                currentTensor = NetworkBuilder.DenseLayer(currentTensor, layer.Size, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
            });
            Tensor output           = NetworkBuilder.DenseLayer(currentTensor, ctx.OutputCodifications.Count, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");
            // Named identity node so the output can be looked up by name when the model is reloaded for prediction.
            Tensor calculatedOutput = tf.identity(output, "calculatedOutput");

            // Loss drives the optimizer; accuracy is a separate metric reported alongside it.
            Tensor loss     = NetworkBuilder.GetEvalFunction(nn.LossFunction, outputPlaceholder, calculatedOutput);
            Tensor accuracy = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, outputPlaceholder, calculatedOutput);

            // prepare for training
            Optimizer optimizer = NetworkBuilder.GetOptimizer(nn);

            Operation trainOperation = optimizer.minimize(loss);

            // Seeded Random gives reproducible minibatch sampling when a seed is configured.
            Random rand = p.Settings.Seed == null ?
                          new Random() :
                          new Random(p.Settings.Seed.Value);

            var(training, validation) = ctx.SplitTrainValidation(rand);

            var minibachtSize  = nn.MinibatchSize;
            var numMinibatches = nn.NumMinibatches;


            Stopwatch             sw        = Stopwatch.StartNew();
            List <FinalCandidate> candidate = new List <FinalCandidate>();

            // Single-threaded TF execution for deterministic behavior / reduced contention.
            var config = new ConfigProto
            {
                IntraOpParallelismThreads = 1,
                InterOpParallelismThreads = 1,
                LogDevicePlacement        = true
            };

            // Start from a clean model directory for this predictor.
            ctx.ReportProgress($"Deleting Files");
            var dir = PredictorDirectory(ctx.Predictor);

            if (Directory.Exists(dir))
            {
                Directory.Delete(dir, true);
            }

            Directory.CreateDirectory(dir);

            ctx.ReportProgress($"Starting training...");

            var saver = tf.train.Saver();

            using (var sess = tf.Session(config))
            {
                sess.run(tf.global_variables_initializer());

                for (int i = 0; i < numMinibatches; i++)
                {
                    using (HeavyProfiler.Log("MiniBatch", () => i.ToString()))
                    {
                        // Sample the minibatch with replacement from the training split.
                        var trainMinibatch = 0.To(minibachtSize).Select(_ => rand.NextElement(training)).ToList();

                        var inputValue  = CreateNDArray(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn);
                        var outputValue = CreateNDArray(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn);

                        using (HeavyProfiler.Log("TrainMinibatch", () => i.ToString()))
                        {
                            // One gradient-descent step.
                            sess.run(trainOperation,
                                     (inputPlaceholder, inputValue),
                                     (outputPlaceholder, outputValue));
                        }

                        if (ctx.StopTraining)
                        {
                            // NOTE(review): re-fetches the predictor entity when a stop is requested —
                            // presumably to pick up concurrent edits before the final save; confirm intent.
                            p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();
                        }

                        // The last BestResultFromLast minibatches are all checkpoint candidates.
                        var isLast = numMinibatches - nn.BestResultFromLast <= i;
                        if (isLast || (i % nn.SaveProgressEvery) == 0 || ctx.StopTraining)
                        {
                            float loss_val;
                            float accuracy_val;

                            // Evaluate metrics on the same minibatch just used for the training step.
                            using (HeavyProfiler.Log("EvalTraining", () => i.ToString()))
                            {
                                (loss_val, accuracy_val) = sess.run((loss, accuracy),
                                                                    (inputPlaceholder, inputValue),
                                                                    (outputPlaceholder, outputValue));
                            }

                            var ep = new EpochProgress
                            {
                                Ellapsed           = sw.ElapsedMilliseconds,
                                Epoch              = i,
                                TrainingExamples   = i * minibachtSize,
                                LossTraining       = loss_val,
                                AccuracyTraining   = accuracy_val,
                                LossValidation     = null,
                                AccuracyValidation = null,
                            };

                            ctx.ReportProgress($"Training Minibatches Loss:{loss_val} / Accuracy:{accuracy_val}", (i + 1) / (decimal)numMinibatches);

                            ctx.Progresses.Enqueue(ep);

                            // Validation is evaluated less frequently than training metrics.
                            if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || ctx.StopTraining)
                            {
                                using (HeavyProfiler.LogNoStackTrace("EvalValidation"))
                                {
                                    var validateMinibatch = 0.To(minibachtSize).Select(_ => rand.NextElement(validation)).ToList();

                                    var inputValValue  = CreateNDArray(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn);
                                    var outputValValue = CreateNDArray(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn);

                                    (loss_val, accuracy_val) = sess.run((loss, accuracy),
                                                                        (inputPlaceholder, inputValValue),
                                                                        (outputPlaceholder, outputValValue));


                                    ep.LossValidation     = loss_val;
                                    ep.AccuracyValidation = accuracy_val;
                                }
                            }

                            var progress = ep.SaveEntity(ctx.Predictor);

                            if (isLast || ctx.StopTraining)
                            {
                                // Checkpoint the model for this minibatch and record its metrics
                                // so the best candidate can be picked after the loop.
                                Directory.CreateDirectory(TrainingModelDirectory(ctx.Predictor, i));
                                var save = saver.save(sess, Path.Combine(TrainingModelDirectory(ctx.Predictor, i), ModelFileName));

                                using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                                {
                                    candidate.Add(new FinalCandidate
                                    {
                                        ModelIndex     = i,
                                        ResultTraining = new PredictorMetricsEmbedded {
                                            Accuracy = progress.AccuracyTraining, Loss = progress.LossTraining
                                        },
                                        ResultValidation = new PredictorMetricsEmbedded {
                                            Accuracy = progress.AccuracyValidation, Loss = progress.LossValidation
                                        },
                                    });
                                }
                            }
                        }

                        if (ctx.StopTraining)
                        {
                            break;
                        }
                    }
                }
            }

            // Pick the checkpoint with the lowest validation loss. NOTE(review): assumes at least one
            // candidate was added (isLast / StopTraining paths) and that its validation loss is non-null.
            var best = candidate.WithMin(a => a.ResultValidation.Loss !.Value);

            p.ResultTraining   = best.ResultTraining;
            p.ResultValidation = best.ResultValidation;

            // Attach the winning checkpoint's files to the predictor entity and persist it.
            var files = Directory.GetFiles(TrainingModelDirectory(ctx.Predictor, best.ModelIndex));

            p.Files.AddRange(files.Select(p => new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, p)));

            using (OperationLogic.AllowSave <PredictorEntity>())
                p.Save();
        }