예제 #1
0
        /// <summary>
        /// Main method for model training: loads the ML configuration file, prepares the neural-network
        /// data on the requested device, creates a trainer and runs the training loop.
        /// </summary>
        /// <param name="mlconfigPath">Path to the ML configuration file.</param>
        /// <param name="token">Cancellation token forwarded to the training loop to interrupt training.</param>
        /// <param name="trainingProgress">Training progress object forwarded to the training loop.</param>
        /// <param name="pdevice">Processing device selector, resolved via <c>MLFactory.GetDevice</c>.</param>
        /// <returns>The <see cref="TrainResult"/> returned by <c>MLTrainerEx.Train</c>.</returns>
        public static TrainResult TrainModel(string mlconfigPath, CancellationToken token, TrainingProgress trainingProgress, ProcessDevice pdevice)
        {
            // resolve the computation device (CPU/GPU)
            DeviceDescriptor device = MLFactory.GetDevice(pdevice);

            // load the ML configuration file
            var dicMParameters = MLFactory.LoadMLConfiguration(mlconfigPath);
            // add the path of the model folder under the "root" key
            dicMParameters.Add("root", Project.GetMLConfigFolder(mlconfigPath));

            // prepare NN data, using the project's custom model entry point
            var retVal = MLFactory.PrepareNNData(dicMParameters, CustomNNModels.CustomModelCallEntryPoint, device);

            // create the trainer wrapper from the prepared stream configuration and variables
            var tr = new MLTrainerEx(retVal.f.StreamConfigurations, retVal.f.InputVariables, retVal.f.OutputVariables);

            // The model checkpoint path and the training-history path are both derived from the optional
            // "configid" entry. Without that entry, both stay null.
            string modelCheckPoint = null;
            string historyPath = null;
            if (dicMParameters.TryGetValue("configid", out var rawConfigId))
            {
                var configId = rawConfigId.Trim(' ');
                modelCheckPoint = MLFactory.GetModelCheckPointPath(mlconfigPath, configId);
                historyPath = MLFactory.GetTrainingHistoryPath(mlconfigPath, configId);
            }

            // create the CNTK trainer
            var trainer = tr.CreateTrainer(retVal.nnModel, retVal.lrData, retVal.trData, modelCheckPoint, historyPath);

            // perform training; exceptions propagate unchanged to the caller
            return tr.Train(trainer, retVal.nnModel, retVal.trData, retVal.mbs, device, token, trainingProgress, modelCheckPoint, historyPath);
        }
예제 #2
0
        /// <summary>
        /// Main method to perform the training process. It trains the model described by the ML
        /// configuration file, replaces the stored best model, and writes the serialized training data
        /// back to the configuration file.
        /// </summary>
        /// <param name="mlconfigPath">Path to the ML configuration file.</param>
        /// <param name="device">Device of computation (GPU/CPU).</param>
        /// <param name="token">Cancellation token for interrupting the training process on user request.</param>
        /// <param name="trainingProgress">Training progress object.</param>
        /// <param name="customModel">Custom neural network model if available; <c>null</c> by default.</param>
        /// <returns>The <see cref="TrainResult"/> produced by the training run.</returns>
        public static TrainResult Run(string mlconfigPath, DeviceDescriptor device, CancellationToken token, TrainingProgress trainingProgress, CreateCustomModel customModel = null)
        {
            // load the ML configuration file
            var dicMParameters = MLFactory.LoadMLConfiguration(mlconfigPath);

            // add the path of the model folder, resolved from the full path of the configuration file
            var folderPath = MLFactory.GetMLConfigFolder(new FileInfo(mlconfigPath).FullName);
            dicMParameters.Add("root", folderPath);

            var retVal = MLFactory.PrepareNNData(dicMParameters, customModel, device);

            // create the trainer wrapper from the prepared stream configuration and variables
            var tr = new MLTrainerEx(retVal.f.StreamConfigurations, retVal.f.InputVariables, retVal.f.OutputVariables);

            // the model checkpoint path is only set when the configuration has a "configid" entry
            string modelCheckPoint = null;
            if (dicMParameters.TryGetValue("configid", out var configId))
            {
                modelCheckPoint = MLFactory.GetModelCheckPointPath(mlconfigPath, configId.Trim(' '));
            }

            // TODO: training-history path is not set up yet, so null is passed as the history path
            var trainer = tr.CreateTrainer(retVal.nnModel, retVal.lrData, retVal.trData, modelCheckPoint, null);

            // perform training
            var result = tr.Train(trainer, retVal.nnModel, retVal.trData, retVal.mbs, device, token, trainingProgress, modelCheckPoint, null);

            // Replace the previous best model before the training data values change. The value
            // returned by ReplaceBestModel becomes the new LastBestModel.
            retVal.trData.LastBestModel = MLFactory.ReplaceBestModel(retVal.trData, mlconfigPath, result.BestModelFile);

            // save the best model to the mlconfig file: pass the serialized training data under the "training" key
            var updatedSections = new Dictionary<string, string>
            {
                { "training", retVal.trData.ToString() },
            };
            MLFactory.SaveMLConfiguration(mlconfigPath, updatedSections);

            return result;
        }