/// <summary>
/// Validates and saves the model, runs the training on a background task and reports
/// the outcome to the <c>AppController</c> found in the main window's data context.
/// </summary>
/// <remarks>
/// The method is <c>async void</c>, so an exception that escapes it cannot be observed by
/// the caller. Every failure is therefore converted into a <c>ProcessState.Crashed</c>
/// result here, and <c>IsTrainRunning</c> is always reset in the <c>finally</c> block.
/// </remarks>
/// <param name="token">Cancellation token handed to <c>Project.TrainModel</c>.</param>
internal async void RunTraining(CancellationToken token)
{
    try
    {
        IsTrainRunning = true;

        //check if the model parameters are valid for training
        isModelParametersValid();

        //first save model
        Save();

        //change icon of treeView item to indicate this model is running,
        //and raise the event since some controls depend on this property
        IconUri = "Images/runningmodel.png";
        RaisePropertyChangedEvent("IsRunning");

        //when training is not continued, drop the previous progress and start from an empty one
        if (TrainingParameters.ContinueTraining == false)
        {
            TrainingProgress = new TrainingProgress()
            {
                Iteration = "0",
                MBLossValue = new List<PointPair>(),
                MBEvaluationValue = new List<PointPair>(),
                TrainEvalValue = new List<PointPair>(),
                ValidationEvalValue = new List<PointPair>(),
                TrainingLoss = "0",
            };

            //null-conditional invoke: no gap between the null check and the call
            UpdateTrainingtGraphs?.Invoke(0, 0, 0, 0, 0);
        }

        //train only when nothing has been recorded yet, or the requested number of epochs
        //is greater than the last recorded iteration
        if (TrainingProgress.MBLossValue.Count == 0 || TrainingParameters.Epochs > TrainingProgress.MBLossValue.Last().X)
        {
            //Load ML configuration file
            var mlconfigPath = Project.GetMLConfigPath(Settings, Name);
            progressStartTraining(trainingProgress);

            //the training itself runs on a thread-pool thread
            var res = await Task.Run<TrainResult>(() => Project.TrainModel(mlconfigPath, token, trainingProgress, m_device));

            //save best trained model
            TrainingParameters.LastBestModel = Project.ReplaceBestModel(TrainingParameters, mlconfigPath, res.BestModelFile);

            //once the training process completes inform the GUI about it
            notifyTrainingFinished(res);

            //save the mlconfig file after training process is over
            Save();
        }
        else
        {
            //nothing left to train: report the process as completed right away
            notifyTrainingFinished(new TrainResult() { ProcessState = ProcessState.Compleated, Iteration = TrainingParameters.Epochs });
        }
    }
    catch (Exception ex)
    {
        //report the failure as a crashed training without a best model
        var res = new TrainResult()
        {
            BestModelFile = null,
            Iteration = 0,
            ProcessState = ProcessState.Crashed,
        };

        //TrainingParameters itself may be the reason of the failure, so do not dereference it blindly
        if (TrainingParameters != null)
        {
            TrainingParameters.LastBestModel = res.BestModelFile;
        }

        //restores the icon as well, so a crashed model is no longer displayed as running
        var appCnt = notifyTrainingFinished(res);
        appCnt?.ReportException(ex);
    }
    finally
    {
        //single exit point for the running flag, reached on success, on failure,
        //and even when the failure reporting itself throws
        IsTrainRunning = false;
    }
}

//Restores the idle icon, raises the "IsRunning" change notification and passes the result to the
//AppController of the main window. Returns that controller, or null when it is not available.
private AppController notifyTrainingFinished(TrainResult res)
{
    var appCnt = anndotnet.wnd.App.Current?.MainWindow?.DataContext as AppController;

    //send note to GUI the training process is completed
    IconUri = "Images/model.png";
    RaisePropertyChangedEvent("IsRunning");

    appCnt?.TrainingCompleated(res);
    return appCnt;
}
/// <summary>
/// Creates a training progress object and, when a training history file exists for this
/// ml configuration, fills it with the data points stored in that file.
/// </summary>
/// <returns>
/// An empty progress when the configuration has no id, the history file does not exist or
/// it holds no data rows; otherwise the progress populated from the history rows.
/// </returns>
private TrainingProgress inittrainingProgress()
{
    var trProg = new TrainingProgress()
    {
        Iteration = "0",
        MBLossValue = new List<PointPair>(),
        MBEvaluationValue = new List<PointPair>(),
        TrainEvalValue = new List<PointPair>(),
        ValidationEvalValue = new List<PointPair>(),
        TrainingLoss = "0",
    };

    var configPath = Project.GetMLConfigPath(Settings, Name);
    var configId = Project.GetMLConfigId(configPath);
    if (configId == null)
    {
        return trProg;
    }

    var strhistoryPath = Project.GetTrainingHistoryPath(configPath, configId);
    if (!new FileInfo(strhistoryPath).Exists)
    {
        return trProg;
    }

    //read the file exactly once; a lazy File.ReadLines sequence would reopen
    //and reread the file on every enumeration
    var lines = File.ReadAllLines(strhistoryPath);
    var separator = new char[] { ';' };

    //the first line is the header row, it does not hold a data point
    foreach (var line in lines.Skip(1))
    {
        //tolerate blank lines instead of failing on a missing column
        if (string.IsNullOrWhiteSpace(line))
        {
            continue;
        }

        var row = line.Split(separator, StringSplitOptions.RemoveEmptyEntries);

        //NOTE(review): the numbers are parsed with the current culture, as before;
        //confirm the history writer uses the same culture
        int it = int.Parse(row[0]);
        float lossMB = (float)double.Parse(row[1]);
        float evalMB = (float)double.Parse(row[2]);
        float trainEval = (float)double.Parse(row[3]);
        float valiEval = (float)double.Parse(row[4]);

        trProg.MBLossValue.Add(new PointPair(it, lossMB));
        trProg.MBEvaluationValue.Add(new PointPair(it, evalMB));
        trProg.TrainEvalValue.Add(new PointPair(it, trainEval));
        trProg.ValidationEvalValue.Add(new PointPair(it, valiEval));
    }

    //a file holding only the header has no last point; keep the initial "0" in that case
    if (trProg.MBLossValue.Count > 0)
    {
        var lastIteration = trProg.MBLossValue.Last().X;
        trProg.Iteration = $"{lastIteration} of {lastIteration}";
    }

    return trProg;
}