/// <summary>
/// Creates a trainer for the given network using a single SGD learner,
/// <c>CNTKLib.SquaredError</c> as the loss function and <c>StatMetrics.RMSError</c>
/// as the evaluation function.
/// </summary>
/// <param name="network">Network (model) function to be trained.</param>
/// <param name="target">Variable holding the expected output of the network.</param>
/// <param name="lrvalue">Value used to build the learning rate schedule.</param>
/// <returns>Trainer created for the network with one SGD learner.</returns>
public Trainer createTrainer(Function network, Variable target, double lrvalue)
{
    //learning rate schedule and regularization options (L1 and L2 weights are fixed here)
    var learningRate = new TrainingParameterScheduleDouble(lrvalue);
    var options = new AdditionalLearningOptions()
    {
        l1RegularizationWeight = 0.001,
        l2RegularizationWeight = 0.1
    };

    //create loss and eval
    Function loss = CNTKLib.SquaredError(network, target);
    Function eval = StatMetrics.RMSError(network, target);

    //learners
    var learners = new List<Learner>
    {
        Learner.SGDLearner(network.Parameters(), learningRate, options)
    };

    //trainer
    return Trainer.CreateTrainer(network, loss, eval, learners);
}
/// <summary>
/// Creates the CNTK function (loss or evaluation) that corresponds to the requested
/// <see cref="EFunction"/> value. The created function is named after the enum member.
/// </summary>
/// <param name="function">Kind of function to create.</param>
/// <param name="prediction">First argument of the function: the network output.</param>
/// <param name="target">Second argument of the function: the expected output.</param>
/// <returns>The created function applied to <paramref name="prediction"/> and <paramref name="target"/>.</returns>
/// <exception cref="NotSupportedException">Thrown when <paramref name="function"/> has no matching implementation.</exception>
private Function createFunction(EFunction function, Function prediction, Variable target)
{
    //the enum member name is used as the name of the created function
    string name = function.ToString();

    switch (function)
    {
        case EFunction.BinaryCrossEntropy:
            return CNTKLib.BinaryCrossEntropy(prediction, target, name);

        case EFunction.CrossEntropyWithSoftmax:
            return CNTKLib.CrossEntropyWithSoftmax(prediction, target, name);

        case EFunction.ClassificationError:
            return CNTKLib.ClassificationError(prediction, target, name);

        case EFunction.SquaredError:
            return CNTKLib.SquaredError(prediction, target, name);

        case EFunction.RMSError:
            return StatMetrics.RMSError(prediction, target, name);

        case EFunction.MSError:
            return StatMetrics.MSError(prediction, target, name);

        case EFunction.ClassificationAccuracy:
            return StatMetrics.ClassificationAccuracy(prediction, target, name);

        default:
            //NotSupportedException derives from Exception, so existing catch (Exception) handlers still work
            throw new NotSupportedException($"The '{function}' function is not supported!");
    }
}
/// <summary>
/// Evaluates <c>StatMetrics.WeightedSE</c> on a fixed batch of 15 actual/predicted values with
/// per-element weights and asserts the result against a hand-computed value.
/// </summary>
private static void createWeightedSE()
{
    NDShape shape = new int[] { 15 };

    //sample data
    var actual = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
    var predicted = new float[] { -5, -2, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37 };
    var weights = new NDArrayView(shape, new float[] { 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 }, device);

    var aValue = Value.CreateBatch(shape, actual, device, false);
    var pValue = Value.CreateBatch(shape, predicted, device, false);
    var pWeights = new Constant(weights);

    //define variables
    var vactual = Variable.InputVariable(shape, DataType.Float);
    var vpredicted = Variable.InputVariable(shape, DataType.Float);

    //calculate weighted squared error
    var loss = StatMetrics.WeightedSE(vpredicted, vactual, pWeights);

    //evaluate function
    var inMap = new Dictionary<Variable, Value>() { { vpredicted, pValue }, { vactual, aValue } };
    var outMap = new Dictionary<Variable, Value>() { { loss, null } };
    loss.Evaluate(inMap, outMap, device);

    var result = outMap[loss].GetDenseData<float>(loss);

    //sum over i of weight[i] * (predicted[i] - actual[i])^2 for the data above is 7680;
    //expected value goes first so a failure message reports expected/actual correctly
    Assert.Equal(7680.0, (double)result[0][0], 2);
}
/// <summary>
/// Creates the trainer before training starts. When continuation of training is requested and
/// a checkpoint file exists, the trainer state and the training history are restored from disk;
/// otherwise any existing checkpoint and history files are deleted.
/// </summary>
/// <param name="network">Network (model) function to be trained.</param>
/// <param name="lrParams">Learning parameters passed to the trainer factory.</param>
/// <param name="trParams">Training parameters; <c>ContinueTraining</c> decides whether a checkpoint is restored.</param>
/// <param name="modelCheckPoint">Path of the trainer checkpoint file.</param>
/// <param name="historyPath">Path of the training history file.</param>
/// <returns>The created (and possibly restored) trainer.</returns>
/// <exception cref="Exception">
/// Thrown when restoring the checkpoint or loading the history fails; the original failure is
/// available through <see cref="Exception.InnerException"/>.
/// </exception>
public Trainer CreateTrainer(Function network, LearningParameters lrParams, TrainingParameters trParams, string modelCheckPoint, string historyPath)
{
    //create trainer
    var trainer = createTrainer(network, lrParams, trParams);

    //set initial value for the evaluation value: the worst possible one for the optimization direction
    bool isMinimize = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction());
    double worstEval = isMinimize ? double.MaxValue : double.MinValue;
    m_PrevTrainingEval = worstEval;
    m_PrevValidationEval = worstEval;

    //in case modelCheckpoint is saved and user select re-training existing trainer
    //first check if the checkpoint is available
    bool restoreRequested = trParams.ContinueTraining
                            && !string.IsNullOrEmpty(modelCheckPoint)
                            && File.Exists(modelCheckPoint);

    if (restoreRequested)
    {
        //if the network model changed checkpoint state will throw exception
        //in that case throw exception that re-training is not possible
        try
        {
            trainer.RestoreFromCheckpoint(modelCheckPoint);

            //load history of training in case continuation of training is requested
            m_trainingHistory = loadTrainingHistory(historyPath);
        }
        catch (Exception ex)
        {
            //keep the original failure as inner exception so the root cause is not lost
            throw new Exception("The Trainer cannot be restored from the previous state probably because the network has changed." +
                "\n Uncheck 'Continue Training' and train the model from scratch.", ex);
        }
    }
    else
    {
        //delete checkpoint and history if they exist in case no retraining is required,
        //so the next checkpoint saving is free of previous checkpoints
        if (File.Exists(modelCheckPoint))
        {
            File.Delete(modelCheckPoint);
        }

        if (File.Exists(historyPath))
        {
            File.Delete(historyPath);
        }
    }

    return trainer;
}
/// <summary>
/// Callback from the training loop: evaluates the current model, optionally saves it to the
/// temporary model location and informs the user about training progress.
/// </summary>
/// <param name="trParams">Training parameters (epochs, progress frequency, save options, temp location).</param>
/// <param name="trainer">Trainer whose last minibatch averages and evaluation function are used.</param>
/// <param name="network">Network that is saved when the evaluation improved.</param>
/// <param name="mbs">Minibatch source providing the training/validation data files and stream configurations.</param>
/// <param name="epoch">Current epoch number.</param>
/// <param name="progress">Progress callback; may be null, in which case nothing is reported or added to the history.</param>
/// <param name="device">Device used for loading data and evaluation.</param>
/// <returns>Progress data describing the current epoch.</returns>
protected virtual ProgressData progressTraining(TrainingParameters trParams, Trainer trainer, Function network, MinibatchSourceEx mbs, int epoch, TrainingProgress progress, DeviceDescriptor device)
{
    //calculate average training loss and evaluation
    var mbAvgLoss = trainer.PreviousMinibatchLossAverage();
    var mbAvgEval = trainer.PreviousMinibatchEvaluationAverage();
    var vars = InputVariables.Union(OutputVariables).ToList();

    //by default the training evaluation is the average of the last minibatch
    double trainEval = mbAvgEval;

    //sometimes when the data set is huge validation model against
    // full training dataset could take time, so we can skip it by setting parameter 'FullTrainingSetEval'
    if (trParams.FullTrainingSetEval)
    {
        //(re)load the cached full training dataset only when it is missing or no longer valid
        if (m_TrainData == null || m_TrainData.Values.Any(x => x.data.IsValid == false))
        {
            using (var streamData = MinibatchSourceEx.GetFullBatch(mbs.Type, mbs.TrainingDataFile, mbs.StreamConfigurations, device))
            {
                //get full training dataset
                m_TrainData = MinibatchSourceEx.ToMinibatchData(streamData, vars, mbs.Type);
            }
        }

        //perform evaluation of the current model on whole training dataset;
        //this must run on every call, not only when the cache was just reloaded
        trainEval = trainer.TestMinibatch(m_TrainData, device);
    }

    string bestModelPath = m_bestModelPath;
    double validEval = 0;

    //in case validation data set is empty don't perform test-minibatch
    if (!string.IsNullOrEmpty(mbs.ValidationDataFile))
    {
        if (m_ValidationData == null || m_ValidationData.Values.Any(x => x.data.IsValid == false))
        {
            //get validation dataset
            using (var streamData = MinibatchSourceEx.GetFullBatch(mbs.Type, mbs.ValidationDataFile, mbs.StreamConfigurations, device))
            {
                //store validation data for future testing
                m_ValidationData = MinibatchSourceEx.ToMinibatchData(streamData, vars, mbs.Type);
            }
        }

        //perform evaluation of the current model with validation dataset
        validEval = trainer.TestMinibatch(m_ValidationData, device);
    }

    //here we should decide if the current model is worth saving into the temp location,
    // depending on the evaluation function, which for some metrics is better when greater (e.g. ClassificationAccuracy)
    if (isBetterThanPrevious(trainEval, validEval, StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction())) && trParams.SaveModelWhileTraining)
    {
        //save model
        var strFilePath = $"{trParams.ModelTempLocation}\\model_at_{epoch}of{trParams.Epochs}_epochs_TimeSpan_{DateTime.Now.Ticks}";
        if (!Directory.Exists(trParams.ModelTempLocation))
        {
            Directory.CreateDirectory(trParams.ModelTempLocation);
        }

        //save temp model
        network.Save(strFilePath);

        //remember the evaluation values of the saved model as the previous state
        m_PrevTrainingEval = trainEval;
        m_PrevValidationEval = validEval;
        bestModelPath = strFilePath;

        var tpl = Tuple.Create<double, double, string>(trainEval, validEval, strFilePath);
        m_ModelEvaluations.Add(tpl);
    }

    m_bestModelPath = bestModelPath;

    //create progressData object
    var prData = new ProgressData();
    prData.EpochTotal = trParams.Epochs;
    prData.EpochCurrent = epoch;
    prData.EvaluationFunName = trainer.EvaluationFunction().Name;
    prData.TrainEval = trainEval;
    prData.ValidationEval = validEval;
    prData.MinibatchAverageEval = mbAvgEval;
    prData.MinibatchAverageLoss = mbAvgLoss;

    //the progress is only reported for the first epoch, the last epoch and every ProgressFrequency-th epoch
    if (progress != null && (epoch % trParams.ProgressFrequency == 0 || epoch == 1 || epoch == trParams.Epochs))
    {
        //add info to the history
        m_trainingHistory.Add(new Tuple<int, float, float, float, float>(epoch, (float)mbAvgLoss, (float)mbAvgEval, (float)trainEval, (float)validEval));

        //send progress
        progress(prData);
    }

    //return progress data
    return prData;
}
/// <summary>
/// Main method for training. Trains the network minibatch by minibatch until the configured
/// number of epochs is reached or cancellation is requested, then saves the checkpoint,
/// the training history and the best (or last) model.
/// </summary>
/// <param name="trainer">Trainer used to train the network.</param>
/// <param name="network">Network (model) function being trained.</param>
/// <param name="trParams">Training parameters (batch size, number of epochs, ...).</param>
/// <param name="miniBatchSource">Source of the training minibatches.</param>
/// <param name="device">Device the training runs on.</param>
/// <param name="token">Token checked after every minibatch to stop the training on user request.</param>
/// <param name="progress">Progress callback; may be null.</param>
/// <param name="modelCheckPoint">Path of the trainer checkpoint file; the checkpoint is not saved when null or empty.</param>
/// <param name="historyPath">Path of the training history file; the history is not saved when null or empty.</param>
/// <returns>Result describing how the training ended, the last epoch and the best model file.</returns>
public override TrainResult Train(Trainer trainer, Function network, TrainingParameters trParams, MinibatchSourceEx miniBatchSource, DeviceDescriptor device, CancellationToken token, TrainingProgress progress, string modelCheckPoint, string historyPath)
{
    //create trainer result.
    // the variable indicate how training process is ended
    // completed, stopped, crashed,
    var trainResult = new TrainResult();

    // NOTE(review): historyFile is never assigned a real path, so TrainResult.TrainingHistoryFile
    // is always empty - confirm whether historyPath was intended here
    var historyFile = "";

    //create training process evaluation collection
    //for each iteration it is stored evaluationValue for training, and validation set with the model
    m_ModelEvaluations = new List<Tuple<double, double, string>>();

    //check what is the optimization (Minimization (error) or maximization (accuracy))
    bool isMinimize = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction());

    //setup first iteration
    if (m_trainingHistory == null)
    {
        m_trainingHistory = new List<Tuple<int, float, float, float, float>>();
    }

    //in case of continuation of training, iteration must start right after the last epoch of the previous training process
    int epoch = (m_trainingHistory.Count > 0) ? m_trainingHistory.Last().Item1 + 1 : 1;

    //define progressData
    ProgressData prData = null;

    //define helper variable collection
    var vars = InputVariables.Union(OutputVariables).ToList();

    //training process
    while (true)
    {
        //get mini batch data
        var args = miniBatchSource.GetNextMinibatch(trParams.BatchSize, device);
        var arguments = MinibatchSourceEx.ToMinibatchData(args, vars, miniBatchSource.Type);

        trainer.TrainMinibatch(arguments, device);

        //make progress when any stream reports the end of a sweep
        if (args.Any(a => a.Value.sweepEnd))
        {
            //check the progress of the training process
            prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);

            //check if training process ends
            if (epoch >= trParams.Epochs)
            {
                //save training checkpoint state and training history
                persistTrainingState(trainer, modelCheckPoint, historyPath);

                //save best or last trained model and send report last time before trainer completes
                var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);

                if (progress != null)
                {
                    progress(prData);
                }

                trainResult.Iteration = epoch;
                trainResult.ProcessState = ProcessState.Compleated;
                trainResult.BestModelFile = bestModelPath;
                trainResult.TrainingHistoryFile = historyFile;
                break;
            }
            else
            {
                epoch++;
            }
        }

        //stop in case user request it
        if (token.IsCancellationRequested)
        {
            //save training checkpoint state and training history
            persistTrainingState(trainer, modelCheckPoint, historyPath);

            //sometime stopping training process can be before first epoch passed so make a incomplete progress
            if (prData == null)
            {
                //check the progress of the training process
                prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);
            }

            //save best or last trained model and send report last time before trainer terminates
            var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);

            if (progress != null)
            {
                progress(prData);
            }

            //setup training result
            trainResult.Iteration = prData.EpochCurrent;
            trainResult.ProcessState = ProcessState.Stopped;
            trainResult.BestModelFile = bestModelPath;
            trainResult.TrainingHistoryFile = historyFile;
            break;
        }
    }

    return trainResult;
}

/// <summary>
/// Saves the trainer checkpoint and the training history; each one is skipped when its path is null or empty.
/// </summary>
/// <param name="trainer">Trainer whose checkpoint is saved and whose function names form the history header.</param>
/// <param name="modelCheckPoint">Path of the trainer checkpoint file.</param>
/// <param name="historyPath">Path of the training history file.</param>
private void persistTrainingState(Trainer trainer, string modelCheckPoint, string historyPath)
{
    //save training checkpoint state
    if (!string.IsNullOrEmpty(modelCheckPoint))
    {
        trainer.SaveCheckpoint(modelCheckPoint);
    }

    //save training history
    if (!string.IsNullOrEmpty(historyPath))
    {
        string header = $"{trainer.LossFunction().Name};{trainer.EvaluationFunction().Name};";
        saveTrainingHistory(m_trainingHistory, header, historyPath);
    }
}
/// <summary>
/// Callback from the training loop: evaluates the current model through <c>MLEvaluator</c>,
/// optionally saves it to the temporary model location and informs the user about training progress.
/// </summary>
/// <param name="trParams">Training parameters (batch size, epochs, progress frequency, save options, temp location).</param>
/// <param name="trainer">Trainer whose model, last minibatch averages and evaluation function are used.</param>
/// <param name="network">Network that is saved when the evaluation improved.</param>
/// <param name="mbs">Minibatch source providing the source type and the training/validation data files.</param>
/// <param name="epoch">Current epoch number.</param>
/// <param name="progress">Progress callback; may be null, in which case nothing is reported or added to the history.</param>
/// <param name="device">Device used for evaluation.</param>
/// <returns>Progress data describing the current epoch.</returns>
protected virtual ProgressData progressTraining(TrainingParameters trParams, Trainer trainer, Function network, MinibatchSourceEx mbs, int epoch, TrainingProgress progress, DeviceDescriptor device)
{
    //calculate average training loss and evaluation
    var mbAvgLoss = trainer.PreviousMinibatchLossAverage();
    var mbAvgEval = trainer.PreviousMinibatchEvaluationAverage();

    //by default the training evaluation is the average of the last minibatch
    double trainEval = mbAvgEval;

    //sometimes when the data set is huge validation model against
    // full training dataset could take time, so we can skip it by setting parameter 'FullTrainingSetEval'
    if (trParams.FullTrainingSetEval)
    {
        trainEval = evaluateModelOnFile(trainer, mbs, mbs.TrainingDataFile, trParams.BatchSize, device);
    }

    string bestModelPath = m_bestModelPath;
    double validEval = 0;

    //in case validation data set is empty don't perform the evaluation
    if (!string.IsNullOrEmpty(mbs.ValidationDataFile))
    {
        validEval = evaluateModelOnFile(trainer, mbs, mbs.ValidationDataFile, trParams.BatchSize, device);
    }

    //here we should decide if the current model is worth saving into the temp location,
    // depending on the evaluation function, which for some metrics is better when greater (e.g. ClassificationAccuracy)
    if (isBetterThanPrevious(trainEval, validEval, StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction())) && trParams.SaveModelWhileTraining)
    {
        //save model
        var strFilePath = $"{trParams.ModelTempLocation}\\model_at_{epoch}of{trParams.Epochs}_epochs_TimeSpan_{DateTime.Now.Ticks}";
        if (!Directory.Exists(trParams.ModelTempLocation))
        {
            Directory.CreateDirectory(trParams.ModelTempLocation);
        }

        //save temp model
        network.Save(strFilePath);

        //remember the evaluation values of the saved model as the previous state
        m_PrevTrainingEval = trainEval;
        m_PrevValidationEval = validEval;
        bestModelPath = strFilePath;

        var tpl = Tuple.Create<double, double, string>(trainEval, validEval, strFilePath);
        m_ModelEvaluations.Add(tpl);
    }

    m_bestModelPath = bestModelPath;

    //create progressData object
    var prData = new ProgressData();
    prData.EpochTotal = trParams.Epochs;
    prData.EpochCurrent = epoch;
    prData.EvaluationFunName = trainer.EvaluationFunction().Name;
    prData.TrainEval = trainEval;
    prData.ValidationEval = validEval;
    prData.MinibatchAverageEval = mbAvgEval;
    prData.MinibatchAverageLoss = mbAvgLoss;

    //the progress is only reported for the first epoch, the last epoch and every ProgressFrequency-th epoch
    if (progress != null && (epoch % trParams.ProgressFrequency == 0 || epoch == 1 || epoch == trParams.Epochs))
    {
        //add info to the history
        m_trainingHistory.Add(new Tuple<int, float, float, float, float>(epoch, (float)mbAvgLoss, (float)mbAvgEval, (float)trainEval, (float)validEval));

        //send progress
        progress(prData);
    }

    //return progress data
    return prData;
}

/// <summary>
/// Evaluates the trainer's current model on one data file and calculates the metric named
/// after the trainer's evaluation function from the actual and predicted values.
/// </summary>
/// <param name="trainer">Trainer providing the model and the evaluation function name.</param>
/// <param name="mbs">Minibatch source whose type is reused for the evaluation source.</param>
/// <param name="dataFile">Path of the data file to evaluate on.</param>
/// <param name="batchSize">Minibatch size used during evaluation.</param>
/// <param name="device">Device used for evaluation.</param>
/// <returns>Value of the evaluation metric for the given data file.</returns>
private double evaluateModelOnFile(Trainer trainer, MinibatchSourceEx mbs, string dataFile, int batchSize, DeviceDescriptor device)
{
    var evParams = new EvaluationParameters()
    {
        MinibatchSize = batchSize,
        MBSource = new MinibatchSourceEx(mbs.Type, this.StreamConfigurations.ToArray(), this.InputVariables, this.OutputVariables, dataFile, null, MinibatchSource.FullDataSweep, false, 0),
        Ouptut = OutputVariables,
        Input = InputVariables,
    };

    var result = MLEvaluator.EvaluateFunction(trainer.Model(), evParams, device);
    return MLEvaluator.CalculateMetrics(trainer.EvaluationFunction().Name, result.actual, result.predicted, device);
}