/// <summary>
/// Main method for training. Feeds minibatches from <paramref name="miniBatchSource"/> into
/// <paramref name="trainer"/> until either the configured number of epochs
/// (<c>trParams.Epochs</c>) has been reached or cancellation is requested via <paramref name="token"/>.
/// </summary>
/// <param name="trainer">Trainer that performs the minibatch updates.</param>
/// <param name="network">Network being trained; passed through to the progress evaluation.</param>
/// <param name="trParams">Training parameters; <c>BatchSize</c> and <c>Epochs</c> are read here.</param>
/// <param name="miniBatchSource">Source of training minibatches.</param>
/// <param name="device">Device on which minibatches are created and training runs.</param>
/// <param name="token">Cancellation token, checked after every minibatch; when signalled training stops with <c>ProcessState.Stopped</c>.</param>
/// <param name="progress">Progress callback invoked with the final progress data; may be null.</param>
/// <param name="modelCheckPoint">Path for the trainer checkpoint; no checkpoint is saved when null or empty.</param>
/// <param name="historyPath">Path for the training history; no history is saved when null or empty.</param>
/// <returns>
/// A <see cref="TrainResult"/> holding how training ended (completed or stopped), the last epoch,
/// the path returned by <c>saveBestModel</c>, and the training history file
/// (<paramref name="historyPath"/> when the history was saved, otherwise an empty string).
/// </returns>
public override TrainResult Train(Trainer trainer, Function network, TrainingParameters trParams, MinibatchSourceEx miniBatchSource, DeviceDescriptor device, CancellationToken token, TrainingProgress progress, string modelCheckPoint, string historyPath)
{
    var trainResult = new TrainResult();

    // Per-iteration evaluation values (training set, validation set) together with the model.
    m_ModelEvaluations = new List<Tuple<double, double, string>>();

    // Whether the evaluation function is to be minimized (e.g. error) or maximized (e.g. accuracy).
    bool isMinimize = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction());

    if (m_trainingHistory == null)
    {
        m_trainingHistory = new List<Tuple<int, float, float, float, float>>();
    }

    // When continuing a previous training run, start at the epoch after the last recorded one.
    int epoch = m_trainingHistory.Count > 0 ? m_trainingHistory.Last().Item1 + 1 : 1;

    ProgressData prData = null;

    // Variables used to map minibatch streams onto the model's inputs and outputs.
    var vars = InputVariables.Union(OutputVariables).ToList();

    while (true)
    {
        var isSweepEnd = false;
        var arguments = miniBatchSource.GetNextMinibatch(trParams.BatchSize, ref isSweepEnd, vars, device);
        trainer.TrainMinibatch(arguments, isSweepEnd, device);

        // An epoch ends when the minibatch source reports the end of a sweep.
        if (isSweepEnd)
        {
            prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);

            if (epoch >= trParams.Epochs)
            {
                var historyFile = saveCheckpointAndHistory(trainer, modelCheckPoint, historyPath);

                // Save the best (or last) trained model and report one final time before completing.
                var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);
                progress?.Invoke(prData);

                trainResult.Iteration = epoch;
                trainResult.ProcessState = ProcessState.Compleated;
                trainResult.BestModelFile = bestModelPath;
                trainResult.TrainingHistoryFile = historyFile;
                return trainResult;
            }

            epoch++;
        }

        if (token.IsCancellationRequested)
        {
            var historyFile = saveCheckpointAndHistory(trainer, modelCheckPoint, historyPath);

            // Cancellation may arrive before the first epoch has completed; build an (incomplete) progress report then.
            if (prData == null)
            {
                prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);
            }

            // Save the best (or last) trained model and report one final time before terminating.
            var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);
            progress?.Invoke(prData);

            trainResult.Iteration = prData.EpochCurrent;
            trainResult.ProcessState = ProcessState.Stopped;
            trainResult.BestModelFile = bestModelPath;
            trainResult.TrainingHistoryFile = historyFile;
            return trainResult;
        }
    }
}

/// <summary>
/// Saves the trainer checkpoint and the training history. Each save is skipped when its path is null or empty.
/// </summary>
/// <param name="trainer">Trainer whose checkpoint, loss and evaluation function names are used.</param>
/// <param name="modelCheckPoint">Checkpoint path; may be null or empty.</param>
/// <param name="historyPath">Training history path; may be null or empty.</param>
/// <returns><paramref name="historyPath"/> when the history was saved; otherwise an empty string.</returns>
private string saveCheckpointAndHistory(Trainer trainer, string modelCheckPoint, string historyPath)
{
    if (!string.IsNullOrEmpty(modelCheckPoint))
    {
        trainer.SaveCheckpoint(modelCheckPoint);
    }

    if (string.IsNullOrEmpty(historyPath))
    {
        return "";
    }

    string header = $"{trainer.LossFunction().Name};{trainer.EvaluationFunction().Name};";
    MLFactory.SaveTrainingHistory(m_trainingHistory, header, historyPath);
    return historyPath;
}