コード例 #1
0
        /// <summary>
        /// Saves the model, then trains it on a background thread via <c>Project.TrainModel</c> and
        /// reports the outcome to the <see cref="AppController"/> bound to the main window.
        /// When training is not continued, the training progress and graphs are reset first.
        /// If the recorded progress already reached <c>TrainingParameters.Epochs</c>, no training
        /// is started and a "completed" result is reported immediately.
        /// </summary>
        /// <param name="token">Cancellation token forwarded to <c>Project.TrainModel</c>.</param>
        /// <remarks>
        /// This method is <c>async void</c> (fire-and-forget), so every exception is caught here and
        /// routed to <c>AppController.ReportException</c> with a <see cref="ProcessState.Crashed"/>
        /// result. <c>IsTrainRunning</c> is always cleared in the <c>finally</c> block.
        /// </remarks>
        internal async void RunTraining(CancellationToken token)
        {
            try
            {
                IsTrainRunning = true;

                // Check that the model parameters are valid for training.
                // NOTE(review): the return value (if any) is ignored, so presumably this throws on
                // invalid parameters — confirm.
                isModelParametersValid();

                // Persist the model before training starts.
                Save();

                // Switch the tree-view icon to "running"; some controls depend on IsRunning.
                IconUri = "Images/runningmodel.png";
                RaisePropertyChangedEvent("IsRunning");

                // A fresh run (not continuing a previous one) starts from empty progress and graphs.
                if (TrainingParameters.ContinueTraining == false)
                {
                    TrainingProgress = new TrainingProgress()
                    {
                        Iteration           = "0",
                        MBLossValue         = new List <PointPair>(),
                        MBEvaluationValue   = new List <PointPair>(),
                        TrainEvalValue      = new List <PointPair>(),
                        ValidationEvalValue = new List <PointPair>(),
                        TrainingLoss        = "0",
                    };
                    UpdateTrainingtGraphs?.Invoke(0, 0, 0, 0, 0);
                }

                // Train only if nothing is recorded yet or the last recorded iteration is below the epoch count.
                if (TrainingProgress.MBLossValue.Count == 0 || TrainingParameters.Epochs > TrainingProgress.MBLossValue.Last().X)
                {
                    // Resolve the path of this model's ML configuration file.
                    var mlconfigPath = Project.GetMLConfigPath(Settings, Name);

                    progressStartTraining(trainingProgress);

                    // Run the training itself off the UI thread.
                    var res = await Task.Run <TrainResult>(() => Project.TrainModel(mlconfigPath, token, trainingProgress, m_device));

                    // Keep the best model produced by this run.
                    TrainingParameters.LastBestModel = Project.ReplaceBestModel(TrainingParameters, mlconfigPath, res.BestModelFile);

                    reportTrainingFinished(res);

                    // Save the mlconfig file after the training process is over.
                    Save();
                }
                else
                {
                    // Nothing left to train: report completion at the configured epoch count.
                    reportTrainingFinished(new TrainResult()
                    {
                        ProcessState = ProcessState.Compleated, Iteration = TrainingParameters.Epochs
                    });
                }
            }
            catch (Exception ex)
            {
                var res = new TrainResult()
                {
                    BestModelFile = null,
                    Iteration     = 0,
                    ProcessState  = ProcessState.Crashed,
                };

                // A crashed run leaves no best model. Guard against TrainingParameters being the
                // cause of the failure, so the catch block cannot throw out of this async void method.
                if (TrainingParameters != null)
                {
                    TrainingParameters.LastBestModel = res.BestModelFile;
                }

                // Also restores the idle icon, which previously stayed on "running" after a crash.
                reportTrainingFinished(res);
                getMainAppController()?.ReportException(ex);
            }
            finally
            {
                IsTrainRunning = false;
            }
        }

        // Returns the AppController bound to the main window, or null when there is no main window or its DataContext is not an AppController.
        private static AppController getMainAppController()
        {
            return anndotnet.wnd.App.Current?.MainWindow?.DataContext as AppController;
        }

        // Restores the idle model icon, notifies IsRunning listeners and forwards the result to the AppController.
        private void reportTrainingFinished(TrainResult result)
        {
            IconUri = "Images/model.png";
            RaisePropertyChangedEvent("IsRunning");
            getMainAppController()?.TrainingCompleated(result);
        }
コード例 #2
0
        /// <summary>
        /// Builds the initial <see cref="TrainingProgress"/> for this model. When the model's ML
        /// configuration has an id and a training-history file exists for it, the progress is
        /// pre-populated from that file.
        /// </summary>
        /// <returns>
        /// A new <see cref="TrainingProgress"/>. It keeps empty point lists and
        /// <c>Iteration</c>/<c>TrainingLoss</c> of "0" when the config id is null, the history
        /// file is missing, or the file is empty. Otherwise every non-blank data row adds one
        /// point to each of the four point lists.
        /// </returns>
        /// <exception cref="FormatException">A history field is not a valid number.</exception>
        /// <exception cref="IndexOutOfRangeException">
        /// The header line has no fields, or a non-blank data row has fewer than five fields.
        /// </exception>
        private TrainingProgress inittrainingProgress()
        {
            var trProg = new TrainingProgress()
            {
                Iteration           = "0",
                MBLossValue         = new List <PointPair>(),
                MBEvaluationValue   = new List <PointPair>(),
                TrainEvalValue      = new List <PointPair>(),
                ValidationEvalValue = new List <PointPair>(),
                TrainingLoss        = "0",
            };

            var configPath = Project.GetMLConfigPath(Settings, Name);
            var configId   = Project.GetMLConfigId(configPath);
            if (configId == null)
            {
                return(trProg);
            }

            var strhistoryPath = Project.GetTrainingHistoryPath(configPath, configId);
            if (!File.Exists(strhistoryPath))
            {
                return(trProg);
            }

            // Read the file once. A lazy File.ReadLines sequence would reopen the file on every enumeration.
            var lines = File.ReadAllLines(strhistoryPath);
            if (lines.Length == 0)
            {
                return(trProg);
            }

            var separator = new char[] { ';' };

            // The first field of the header line is stored as the training-loss name.
            var header   = lines[0].Split(separator, StringSplitOptions.RemoveEmptyEntries);
            var lossFunc = header[0];

            // Data rows: iteration; minibatch loss; minibatch eval; training eval; validation eval.
            // NOTE(review): numbers are parsed with the current culture, as before; confirm that the
            // writer of the history file uses the same culture.
            foreach (var line in lines.Skip(1))
            {
                // Tolerate blank lines (e.g. a stray empty line inside the file).
                if (string.IsNullOrWhiteSpace(line))
                {
                    continue;
                }

                var   row       = line.Split(separator, StringSplitOptions.RemoveEmptyEntries);
                int   it        = int.Parse(row[0]);
                float lossMB    = (float)double.Parse(row[1]);
                float evalMB    = (float)double.Parse(row[2]);
                float trainEval = (float)double.Parse(row[3]);
                float valiEval  = (float)double.Parse(row[4]);
                trProg.MBLossValue.Add(new PointPair(it, lossMB));
                trProg.MBEvaluationValue.Add(new PointPair(it, evalMB));
                trProg.TrainEvalValue.Add(new PointPair(it, trainEval));
                trProg.ValidationEvalValue.Add(new PointPair(it, valiEval));
            }

            trProg.TrainingLoss = lossFunc;

            // A header-only file has no points. Calling Last() on it would throw, so Iteration keeps "0".
            if (trProg.MBLossValue.Count > 0)
            {
                // NOTE(review): both placeholders use the last iteration, so this reads "N of N".
                // Kept as-is; confirm whether the total epoch count was intended as the second value.
                var lastIteration = trProg.MBLossValue.Last().X;
                trProg.Iteration = $"{lastIteration} of {lastIteration}";
            }

            return(trProg);
        }