コード例 #1
0
        /// <summary>
        /// Main method to perform the training process: loads the ML configuration file, prepares the
        /// network and data, trains the model, and writes the updated "training" section (including the
        /// best model) back to the configuration file.
        /// </summary>
        /// <param name="mlconfigPath">Path to the ML configuration file.</param>
        /// <param name="trainingProgress">Training progress object passed through to the trainer.</param>
        /// <param name="token">Cancellation token for interrupting the training process on user request.</param>
        /// <param name="customModel">Custom neural network model, if available; may be <c>null</c>.</param>
        /// <returns>The <c>TrainResult</c> produced by the trainer.</returns>
        /// <remarks>
        /// The computation device (GPU/CPU) is not a parameter of this method; the <c>Device</c> member of
        /// the enclosing type is used for both data preparation and training.
        /// </remarks>
        public static TrainResult Train(string mlconfigPath, TrainingProgress trainingProgress, CancellationToken token, CreateCustomModel customModel = null)
        {
            // Load the ML configuration file.
            var dicMParameters = MLFactory.LoadMLConfiguration(mlconfigPath);

            var fi         = new FileInfo(mlconfigPath);
            var folderPath = MLFactory.GetMLConfigFolder(fi.FullName);

            // Add the path of the model folder so data preparation can resolve relative paths.
            dicMParameters.Add("root", folderPath);

            var retVal = MLFactory.PrepareNNData(dicMParameters, customModel, Device);

            // Create the trainer wrapper from the prepared stream configuration and variables.
            var tr = new MLTrainer(retVal.f.StreamConfigurations, retVal.f.InputVariables, retVal.f.OutputVariables);

            // Model checkpoint and training history paths are only set when the configuration
            // contains a "configid" entry; otherwise both stay null.
            string modelCheckPoint = null;
            string historyPath     = null;
            if (dicMParameters.TryGetValue("configid", out var configId))
            {
                var id = configId.Trim(' ');
                modelCheckPoint = MLFactory.GetModelCheckPointPath(mlconfigPath, id);
                historyPath     = MLFactory.GetTrainingHistoryPath(mlconfigPath, id);
            }

            // Create the underlying trainer.
            var trainer = tr.CreateTrainer(retVal.nnModel, retVal.lrData, retVal.trData, modelCheckPoint, historyPath);

            // Perform training.
            var result = tr.Train(trainer, retVal.nnModel, retVal.trData, retVal.mbs, Device, token, trainingProgress, modelCheckPoint, historyPath);

            // Replace the previous best model with the one produced by this run and record it
            // in the training data before the training section is serialized.
            retVal.trData.LastBestModel = MLFactory.ReplaceBestModel(retVal.trData, mlconfigPath, result.BestModelFile);

            // Save the updated training section (including the best model) to the ML config file.
            var d = new Dictionary<string, string>
            {
                { "training", retVal.trData.ToString() },
            };
            MLFactory.SaveMLConfiguration(mlconfigPath, d);

            return result;
        }
コード例 #2
0
ファイル: Project.cs プロジェクト: bhrnjica/anndotnet
        /// <summary>
        /// Main method for model training: loads the ML configuration file, prepares the network and data
        /// on the requested device, and runs the trainer.
        /// </summary>
        /// <param name="mlconfigPath">Path to the ML configuration file.</param>
        /// <param name="token">Cancellation token for interrupting training on user request.</param>
        /// <param name="trainingProgress">Training progress object passed through to the trainer.</param>
        /// <param name="pdevice">Processing device; resolved to a <c>DeviceDescriptor</c> via <c>MLFactory.GetDevice</c>.</param>
        /// <returns>The <c>TrainResult</c> produced by the trainer.</returns>
        /// <remarks>
        /// This method does not write the best model back to the ML configuration file.
        /// Any exception is propagated to the caller unchanged.
        /// </remarks>
        public static TrainResult TrainModel(string mlconfigPath, CancellationToken token, TrainingProgress trainingProgress, ProcessDevice pdevice)
        {
            // Resolve the computation device.
            DeviceDescriptor device = MLFactory.GetDevice(pdevice);

            // Load the ML configuration file and add the path of the model folder.
            var dicMParameters = MLFactory.LoadMLConfiguration(mlconfigPath);
            dicMParameters.Add("root", Project.GetMLConfigFolder(mlconfigPath));

            // Prepare NN data using the project's custom model entry point.
            var retVal = MLFactory.PrepareNNData(dicMParameters, CustomNNModels.CustomModelCallEntryPoint, device);

            // Create the trainer wrapper from the prepared stream configuration and variables.
            var tr = new MLTrainer(retVal.f.StreamConfigurations, retVal.f.InputVariables, retVal.f.OutputVariables);

            // Model checkpoint and training history paths are only set when the configuration
            // contains a "configid" entry; otherwise both stay null.
            string modelCheckPoint = null;
            string historyPath     = null;
            if (dicMParameters.TryGetValue("configid", out var configId))
            {
                var id = configId.Trim(' ');
                modelCheckPoint = MLFactory.GetModelCheckPointPath(mlconfigPath, id);
                historyPath     = MLFactory.GetTrainingHistoryPath(mlconfigPath, id);
            }

            // Create the underlying trainer.
            var trainer = tr.CreateTrainer(retVal.nnModel, retVal.lrData, retVal.trData, modelCheckPoint, historyPath);

            // Perform training and return its result.
            return tr.Train(trainer, retVal.nnModel, retVal.trData, retVal.mbs, device, token, trainingProgress, modelCheckPoint, historyPath);
        }