示例#1
0
        /// <summary>
        /// Builds a trainer for <paramref name="network"/> that uses squared error as the loss,
        /// RMS error as the evaluation metric and a single SGD learner.
        /// </summary>
        /// <param name="network">Model whose parameters are handed to the learner.</param>
        /// <param name="target">Variable the network output is compared against by the loss and evaluation functions.</param>
        /// <param name="lrvalue">Value used to build the learning-rate schedule.</param>
        /// <returns>The trainer returned by <c>Trainer.CreateTrainer</c>.</returns>
        public Trainer createTrainer(Function network, Variable target, double lrvalue)
        {
            //learning-rate schedule built from the caller-supplied value
            var learningRate = new TrainingParameterScheduleDouble(lrvalue);

            //L1/L2 regularization weights applied by the learner
            var options = new AdditionalLearningOptions()
            {
                l1RegularizationWeight = 0.001, l2RegularizationWeight = 0.1
            };

            //create loss and eval
            Function loss = CNTKLib.SquaredError(network, target);
            Function eval = StatMetrics.RMSError(network, target);

            //learners: one plain SGD learner over all network parameters (no momentum schedule is used)
            var learners = new List <Learner>
            {
                Learner.SGDLearner(network.Parameters(), learningRate, options)
            };

            //trainer
            return(Trainer.CreateTrainer(network, loss, eval, learners));
        }
示例#2
0
        /// <summary>
        /// Creates the loss or evaluation function selected by <paramref name="function"/>.
        /// The enum member's name is passed on as the name of the created function.
        /// </summary>
        /// <param name="function">Kind of function to create.</param>
        /// <param name="prediction">First argument forwarded to the created function.</param>
        /// <param name="target">Second argument forwarded to the created function.</param>
        /// <returns>The function produced by the matching CNTKLib or StatMetrics factory.</returns>
        /// <exception cref="Exception">Thrown when <paramref name="function"/> has no matching factory.</exception>
        private Function createFunction(EFunction function, Function prediction, Variable target)
        {
            //every created function is named after its enum member
            string name = function.ToString();

            //functions provided by CNTK itself
            if (function == EFunction.BinaryCrossEntropy)
            {
                return(CNTKLib.BinaryCrossEntropy(prediction, target, name));
            }
            if (function == EFunction.CrossEntropyWithSoftmax)
            {
                return(CNTKLib.CrossEntropyWithSoftmax(prediction, target, name));
            }
            if (function == EFunction.ClassificationError)
            {
                return(CNTKLib.ClassificationError(prediction, target, name));
            }
            if (function == EFunction.SquaredError)
            {
                return(CNTKLib.SquaredError(prediction, target, name));
            }

            //functions provided by StatMetrics
            if (function == EFunction.RMSError)
            {
                return(StatMetrics.RMSError(prediction, target, name));
            }
            if (function == EFunction.MSError)
            {
                return(StatMetrics.MSError(prediction, target, name));
            }
            if (function == EFunction.ClassificationAccuracy)
            {
                return(StatMetrics.ClassificationAccuracy(prediction, target, name));
            }

            //anything else is not handled
            throw new Exception($"The '{function}' function is not supported!");
        }
示例#3
0
        /// <summary>
        /// Evaluates <c>StatMetrics.WeightedSE</c> on fixed 15-element actual/predicted/weight vectors
        /// and asserts the result against a precomputed value.
        /// </summary>
        private static void createWeightedSE()
        {
            NDShape shape     = new int[] { 15 };
            var     actual    = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
            var     predicted = new float[] { -5, -2, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37 };
            var     weights   = new NDArrayView(shape, new float[] { 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 }, device);
            var     aValue    = Value.CreateBatch(shape, actual, device, false);
            var     pValue    = Value.CreateBatch(shape, predicted, device, false);
            //weights are supplied as a constant rather than as an input variable
            var     pWeights  = new Constant(weights);

            //define variables
            var vactual    = Variable.InputVariable(shape, DataType.Float);
            var vpredicted = Variable.InputVariable(shape, DataType.Float);

            //calculate weighted squared error
            var loss = StatMetrics.WeightedSE(vpredicted, vactual, pWeights);

            //evaluate function
            var inMap = new Dictionary <Variable, Value>()
            {
                { vpredicted, pValue }, { vactual, aValue }
            };

            //null value asks Evaluate to fill the output in
            var outMap = new Dictionary <Variable, Value>()
            {
                { loss, null }
            };

            loss.Evaluate(inMap, outMap, device);

            //
            var result = outMap[loss].GetDenseData <float>(loss);

            //7680 equals the sum of weight * (predicted - actual)^2 over the fixtures above
            //(presumably what WeightedSE computes - confirm against StatMetrics).
            //xUnit argument order is (expected, actual, precision).
            Assert.Equal((double)7680.0, (double)result[0][0], 2);
        }
示例#4
0
        /// <summary>
        /// Creates the trainer before training starts. When continuation of training is requested and a
        /// checkpoint file exists, the trainer is restored from it and the training history is loaded;
        /// otherwise any existing checkpoint and history files are deleted.
        /// </summary>
        /// <param name="network">Model to create the trainer for.</param>
        /// <param name="lrParams">Learning parameters forwarded to the trainer factory.</param>
        /// <param name="trParams">Training parameters; <c>ContinueTraining</c> selects restore versus clean start.</param>
        /// <param name="modelCheckPoint">Path of the trainer checkpoint file; may be null or empty.</param>
        /// <param name="historyPath">Path of the training-history file.</param>
        /// <returns>The created, and possibly restored, trainer.</returns>
        /// <exception cref="Exception">
        /// Thrown when restoring the checkpoint or loading the history fails; the original failure is
        /// available as the inner exception.
        /// </exception>
        public Trainer CreateTrainer(Function network, LearningParameters lrParams, TrainingParameters trParams, string modelCheckPoint, string historyPath)
        {
            //create trainer
            var trainer = createTrainer(network, lrParams, trParams);

            //set initial value for the evaluation value:
            //the largest double when the metric is minimized, the smallest one otherwise
            bool   isMinimize  = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction());
            double initialEval = isMinimize ? double.MaxValue : double.MinValue;
            m_PrevTrainingEval   = initialEval;
            m_PrevValidationEval = initialEval;

            //in case modelCheckpoint is saved and user select re-training existing trainer
            //first check if the checkpoint is available
            if (trParams.ContinueTraining && !string.IsNullOrEmpty(modelCheckPoint) && File.Exists(modelCheckPoint))
            {
                //if the network model changed checkpoint state will throw exception
                //in that case throw exception that re-training is not possible
                try
                {
                    trainer.RestoreFromCheckpoint(modelCheckPoint);
                    //load history of training in case continuation of training is requested
                    m_trainingHistory = loadTrainingHistory(historyPath);
                }
                catch (Exception ex)
                {
                    //keep the root cause as inner exception so it is not lost for diagnostics
                    throw new Exception("The Trainer cannot be restored from the previous state probably because the network has changed." +
                                        "\n Uncheck 'Continue Training' and train the model from scratch.", ex);
                }
            }
            else
            {
                //no retraining is required: delete leftovers so the next
                //checkpoint saving is free of previous checkpoints

                //delete checkpoint (File.Exists returns false for null or empty paths)
                if (File.Exists(modelCheckPoint))
                {
                    File.Delete(modelCheckPoint);
                }

                //delete history
                if (File.Exists(historyPath))
                {
                    File.Delete(historyPath);
                }
            }

            return(trainer);
        }
示例#5
0
        /// <summary>
        /// Callback from the training loop: evaluates the current model, optionally saves it to the
        /// temporary location when it improved, and informs the user about training progress.
        /// </summary>
        /// <param name="trParams">Training parameters (full training-set evaluation flag, temp model location, epochs, progress frequency).</param>
        /// <param name="trainer">Trainer whose minibatch averages and evaluation function are read.</param>
        /// <param name="network">Model that is saved to the temporary location when it improved.</param>
        /// <param name="mbs">Minibatch source describing the training and validation data files.</param>
        /// <param name="epoch">Current epoch number.</param>
        /// <param name="progress">Progress callback; when null nothing is reported and no history entry is added.</param>
        /// <param name="device">Device used for loading the data and for evaluation.</param>
        /// <returns>Progress data describing the current epoch.</returns>
        protected virtual ProgressData progressTraining(TrainingParameters trParams, Trainer trainer,
                                                        Function network, MinibatchSourceEx mbs, int epoch, TrainingProgress progress, DeviceDescriptor device)
        {
            //calculate average training loss and evaluation
            var mbAvgLoss = trainer.PreviousMinibatchLossAverage();
            var mbAvgEval = trainer.PreviousMinibatchEvaluationAverage();
            var vars      = InputVariables.Union(OutputVariables).ToList();

            //unless full training-set evaluation is requested, the minibatch average stands in for it
            double trainEval = mbAvgEval;

            //sometimes when the data set is huge validation model against
            // full training dataset could take time, so we can skip it by setting parameter 'FullTrainingSetEval'
            if (trParams.FullTrainingSetEval)
            {
                //(re)load the full training dataset only when it is not cached or the cache is no longer valid
                if (m_TrainData == null || m_TrainData.Values.Any(x => x.data.IsValid == false))
                {
                    using (var streamData = MinibatchSourceEx.GetFullBatch(mbs.Type, mbs.TrainingDataFile, mbs.StreamConfigurations, device))
                    {
                        //get full training dataset
                        m_TrainData = MinibatchSourceEx.ToMinibatchData(streamData, vars, mbs.Type);
                    }
                }

                //perform evaluation of the current model on whole training dataset;
                //this must run on every call, not only when the cache was just filled
                trainEval = trainer.TestMinibatch(m_TrainData, device);
            }

            string bestModelPath = m_bestModelPath;
            double validEval     = 0;

            //in case validation data set is empty don't perform test-minibatch
            if (!string.IsNullOrEmpty(mbs.ValidationDataFile))
            {
                if (m_ValidationData == null || m_ValidationData.Values.Any(x => x.data.IsValid == false))
                {
                    //get validation dataset
                    using (var streamData = MinibatchSourceEx.GetFullBatch(mbs.Type, mbs.ValidationDataFile, mbs.StreamConfigurations, device))
                    {
                        //store validation data for future testing
                        m_ValidationData = MinibatchSourceEx.ToMinibatchData(streamData, vars, mbs.Type);
                    }
                }
                //perform evaluation of the current model with validation dataset
                validEval = trainer.TestMinibatch(m_ValidationData, device);
            }

            //here we should decide if the current model is worth saving into the temp location,
            // depending on the evaluation function, which sometimes is better when greater than previous (e.g. ClassificationAccuracy)
            if (isBetterThanPrevious(trainEval, validEval, StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction())) && trParams.SaveModelWhileTraining)
            {
                //save model
                var strFilePath = $"{trParams.ModelTempLocation}\\model_at_{epoch}of{trParams.Epochs}_epochs_TimeSpan_{DateTime.Now.Ticks}";
                if (!Directory.Exists(trParams.ModelTempLocation))
                {
                    Directory.CreateDirectory(trParams.ModelTempLocation);
                }

                //save temp model
                network.Save(strFilePath);

                //remember the evaluation values of the saved model for the next comparison
                m_PrevTrainingEval   = trainEval;
                m_PrevValidationEval = validEval;
                bestModelPath        = strFilePath;

                var tpl = Tuple.Create <double, double, string>(trainEval, validEval, strFilePath);
                m_ModelEvaluations.Add(tpl);
            }


            m_bestModelPath = bestModelPath;

            //create progressData object
            var prData = new ProgressData();

            prData.EpochTotal           = trParams.Epochs;
            prData.EpochCurrent         = epoch;
            prData.EvaluationFunName    = trainer.EvaluationFunction().Name;
            prData.TrainEval            = trainEval;
            prData.ValidationEval       = validEval;
            prData.MinibatchAverageEval = mbAvgEval;
            prData.MinibatchAverageLoss = mbAvgLoss;

            //the progress is only reported on the first epoch, the last epoch and every ProgressFrequency-th epoch
            if (progress != null && (epoch % trParams.ProgressFrequency == 0 || epoch == 1 || epoch == trParams.Epochs))
            {
                //add info to the history
                m_trainingHistory.Add(new Tuple <int, float, float, float, float>(epoch, (float)mbAvgLoss, (float)mbAvgEval,
                                                                                  (float)trainEval, (float)validEval));

                //send progress
                progress(prData);
            }

            //return progress data
            return(prData);
        }
示例#6
0
        /// <summary>
        /// Main method for training. Trains minibatch by minibatch until the configured number of epochs
        /// is reached or cancellation is requested; in both cases the checkpoint, the training history
        /// and the best (or last) model are saved before returning.
        /// </summary>
        /// <param name="trainer">Trainer that performs the minibatch updates.</param>
        /// <param name="network">Model passed on to the progress callback.</param>
        /// <param name="trParams">Training parameters (batch size, number of epochs, ...).</param>
        /// <param name="miniBatchSource">Source of the training minibatches.</param>
        /// <param name="device">Device used for training and evaluation.</param>
        /// <param name="token">Token checked after every minibatch to stop the training on request.</param>
        /// <param name="progress">Progress callback; may be null.</param>
        /// <param name="modelCheckPoint">Path of the trainer checkpoint file; nothing is saved when null or empty.</param>
        /// <param name="historyPath">Path of the training-history file; nothing is saved when null or empty.</param>
        /// <returns>Result whose <c>ProcessState</c> is <c>Compleated</c> or <c>Stopped</c>.</returns>
        public override TrainResult Train(Trainer trainer, Function network, TrainingParameters trParams,
                                          MinibatchSourceEx miniBatchSource, DeviceDescriptor device, CancellationToken token, TrainingProgress progress, string modelCheckPoint, string historyPath)
        {
            //create trainer result.
            // the variable indicate how training process is ended
            // completed, stopped, crashed,
            var trainResult = new TrainResult();

            //NOTE(review): never assigned, so TrainingHistoryFile is always reported as empty.
            //Presumably historyPath was intended - confirm with the callers before changing.
            var historyFile = "";

            //create training process evaluation collection
            //for each iteration it is stored evaluationValue for training, and validation set with the model
            m_ModelEvaluations = new List <Tuple <double, double, string> >();

            //check what is the optimization (Minimization (error) or maximization (accuracy))
            bool isMinimize = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction());

            //setup first iteration
            if (m_trainingHistory == null)
            {
                m_trainingHistory = new List <Tuple <int, float, float, float, float> >();
            }
            //in case of continuation of training, the epoch counter continues after the last recorded epoch
            int epoch = (m_trainingHistory.Count > 0) ? m_trainingHistory.Last().Item1 + 1 : 1;

            //define progressData
            ProgressData prData = null;

            //define helper variable collection
            var vars = InputVariables.Union(OutputVariables).ToList();

            //training process
            while (true)
            {
                //get mini batch data
                var args = miniBatchSource.GetNextMinibatch(trParams.BatchSize, device);

                var arguments = MinibatchSourceEx.ToMinibatchData(args, vars, miniBatchSource.Type);
                //
                trainer.TrainMinibatch(arguments, device);

                //an epoch is finished when any stream reports the end of the sweep
                if (args.Any(a => a.Value.sweepEnd))
                {
                    //check the progress of the training process
                    prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);

                    //check if training process ends
                    if (epoch >= trParams.Epochs)
                    {
                        //save training checkpoint state and training history
                        saveCheckpointAndHistory(trainer, modelCheckPoint, historyPath);

                        //save best or last trained model and send report last time before trainer completes
                        var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);
                        //
                        if (progress != null)
                        {
                            progress(prData);
                        }
                        //
                        trainResult.Iteration           = epoch;
                        trainResult.ProcessState        = ProcessState.Compleated;
                        trainResult.BestModelFile       = bestModelPath;
                        trainResult.TrainingHistoryFile = historyFile;
                        break;
                    }

                    epoch++;
                }

                //stop in case user request it
                if (token.IsCancellationRequested)
                {
                    //save training checkpoint state and training history
                    saveCheckpointAndHistory(trainer, modelCheckPoint, historyPath);

                    //sometime stopping training process can be before first epoch passed so make a incomplete progress
                    if (prData == null)
                    {
                        prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);
                    }

                    //save best or last trained model and send report last time before trainer terminates
                    var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);
                    //
                    if (progress != null)
                    {
                        progress(prData);
                    }

                    //setup training result
                    trainResult.Iteration           = prData.EpochCurrent;
                    trainResult.ProcessState        = ProcessState.Stopped;
                    trainResult.BestModelFile       = bestModelPath;
                    trainResult.TrainingHistoryFile = historyFile;
                    break;
                }
            }

            return(trainResult);
        }

        /// <summary>
        /// Saves the trainer checkpoint and the training history; each one is skipped when its path is null or empty.
        /// </summary>
        /// <param name="trainer">Trainer whose checkpoint is saved and whose loss/evaluation function names form the history header.</param>
        /// <param name="modelCheckPoint">Path of the checkpoint file.</param>
        /// <param name="historyPath">Path of the training-history file.</param>
        private void saveCheckpointAndHistory(Trainer trainer, string modelCheckPoint, string historyPath)
        {
            //save training checkpoint state
            if (!string.IsNullOrEmpty(modelCheckPoint))
            {
                trainer.SaveCheckpoint(modelCheckPoint);
            }

            //save training history
            if (!string.IsNullOrEmpty(historyPath))
            {
                string header = $"{trainer.LossFunction().Name};{trainer.EvaluationFunction().Name};";
                saveTrainingHistory(m_trainingHistory, header, historyPath);
            }
        }
示例#7
0
        /// <summary>
        /// Evaluates the current model, saves it to the temporary location when it improved and saving
        /// is enabled, and reports the training progress.
        /// </summary>
        /// <param name="trParams">Training parameters (batch size, full training-set evaluation flag, temp model location, epochs, progress frequency).</param>
        /// <param name="trainer">Trainer providing the minibatch averages, the model and the evaluation function.</param>
        /// <param name="network">Model that is saved to the temporary location.</param>
        /// <param name="mbs">Minibatch source providing the source type and the training/validation data files.</param>
        /// <param name="epoch">Current epoch number.</param>
        /// <param name="progress">Progress callback; when null nothing is reported and no history entry is added.</param>
        /// <param name="device">Device used for evaluation.</param>
        /// <returns>Progress data describing the current epoch.</returns>
        protected virtual ProgressData progressTraining(TrainingParameters trParams, Trainer trainer,
                                                        Function network, MinibatchSourceEx mbs, int epoch, TrainingProgress progress, DeviceDescriptor device)
        {
            //averages of the last trained minibatch
            var lossAverage = trainer.PreviousMinibatchLossAverage();
            var evalAverage = trainer.PreviousMinibatchEvaluationAverage();

            //evaluating against the full training dataset can take time on huge data sets,
            //so it only happens when 'FullTrainingSetEval' is set; otherwise the minibatch average is used
            double trainEval = evalAverage;
            if (trParams.FullTrainingSetEval)
            {
                trainEval = evaluateOnDataFile(trParams, trainer, mbs, mbs.TrainingDataFile, device);
            }

            //without a validation data file the validation evaluation stays 0
            double validEval = 0;
            if (!string.IsNullOrEmpty(mbs.ValidationDataFile))
            {
                validEval = evaluateOnDataFile(trParams, trainer, mbs, mbs.ValidationDataFile, device);
            }

            //decide if the current model is worth saving into the temp location; depending on the
            //evaluation function "better" can mean smaller or greater than previous (e.g. ClassificationAccuracy)
            string currentBestPath = m_bestModelPath;
            bool   isMinimize      = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction());
            bool   improved        = isBetterThanPrevious(trainEval, validEval, isMinimize);
            if (improved && trParams.SaveModelWhileTraining)
            {
                var modelFile = $"{trParams.ModelTempLocation}\\model_at_{epoch}of{trParams.Epochs}_epochs_TimeSpan_{DateTime.Now.Ticks}";

                //make sure the temp location exists before saving
                if (!Directory.Exists(trParams.ModelTempLocation))
                {
                    Directory.CreateDirectory(trParams.ModelTempLocation);
                }
                network.Save(modelFile);

                //remember the evaluation values of the saved model for the next comparison
                m_PrevTrainingEval   = trainEval;
                m_PrevValidationEval = validEval;
                currentBestPath      = modelFile;

                m_ModelEvaluations.Add(Tuple.Create <double, double, string>(trainEval, validEval, modelFile));
            }
            m_bestModelPath = currentBestPath;

            //progress data of the current epoch
            var report = new ProgressData();
            report.EpochTotal           = trParams.Epochs;
            report.EpochCurrent         = epoch;
            report.EvaluationFunName    = trainer.EvaluationFunction().Name;
            report.TrainEval            = trainEval;
            report.ValidationEval       = validEval;
            report.MinibatchAverageEval = evalAverage;
            report.MinibatchAverageLoss = lossAverage;

            //nothing is reported without a callback
            if (progress == null)
            {
                return(report);
            }

            //report only on every ProgressFrequency-th epoch, the first epoch and the last epoch
            bool isReportEpoch = epoch % trParams.ProgressFrequency == 0 || epoch == 1 || epoch == trParams.Epochs;
            if (isReportEpoch)
            {
                //add info to the history
                var historyEntry = new Tuple <int, float, float, float, float>(epoch, (float)lossAverage, (float)evalAverage,
                                                                               (float)trainEval, (float)validEval);
                m_trainingHistory.Add(historyEntry);

                //send progress
                progress(report);
            }

            return(report);
        }

        /// <summary>
        /// Evaluates the trainer's model on one full sweep of <paramref name="dataFile"/> and
        /// calculates the metric named after the trainer's evaluation function.
        /// </summary>
        /// <param name="trParams">Training parameters; the batch size is used as evaluation minibatch size.</param>
        /// <param name="trainer">Trainer providing the model and the evaluation function name.</param>
        /// <param name="mbs">Minibatch source whose type is reused for the evaluation source.</param>
        /// <param name="dataFile">Data file to evaluate the model on.</param>
        /// <param name="device">Device used for evaluation.</param>
        /// <returns>Value returned by <c>MLEvaluator.CalculateMetrics</c>.</returns>
        private double evaluateOnDataFile(TrainingParameters trParams, Trainer trainer, MinibatchSourceEx mbs, string dataFile, DeviceDescriptor device)
        {
            var evParams = new EvaluationParameters()
            {
                MinibatchSize = trParams.BatchSize,
                MBSource      = new MinibatchSourceEx(mbs.Type, this.StreamConfigurations.ToArray(), this.InputVariables, this.OutputVariables, dataFile, null, MinibatchSource.FullDataSweep, false, 0),
                Ouptut        = OutputVariables,
                Input         = InputVariables,
            };

            var result = MLEvaluator.EvaluateFunction(trainer.Model(), evParams, device);
            return(MLEvaluator.CalculateMetrics(trainer.EvaluationFunction().Name, result.actual, result.predicted, device));
        }