Esempio n. 1
0
        /// <summary>
        /// Main training loop. Feeds minibatches from <paramref name="miniBatchSource"/> to
        /// <paramref name="trainer"/> until either the configured number of epochs has been
        /// completed or cancellation is requested through <paramref name="token"/>.
        /// </summary>
        /// <param name="trainer">Trainer that performs the parameter updates.</param>
        /// <param name="network">Network being trained; forwarded to the per-epoch progress evaluation.</param>
        /// <param name="trParams">Training parameters; <c>BatchSize</c> and <c>Epochs</c> are read here.</param>
        /// <param name="miniBatchSource">Source of training minibatches.</param>
        /// <param name="device">Device on which minibatches are loaded and training runs.</param>
        /// <param name="token">Cancellation token, checked after every minibatch.</param>
        /// <param name="progress">Optional callback; invoked once more with the final progress data before returning.</param>
        /// <param name="modelCheckPoint">Path for the trainer checkpoint; nothing is saved when null or empty.</param>
        /// <param name="historyPath">Path for the training history file; nothing is saved when null or empty.</param>
        /// <returns>
        /// A <see cref="TrainResult"/> describing how training ended (completed or stopped), the last epoch,
        /// the best/last model file and the training history file (empty string when no history was saved).
        /// </returns>
        public override TrainResult Train(Trainer trainer, Function network, TrainingParameters trParams,
                                          MinibatchSourceEx miniBatchSource, DeviceDescriptor device, CancellationToken token, TrainingProgress progress, string modelCheckPoint, string historyPath)
        {
            // Describes how the training process ended (completed, stopped, ...).
            var trainResult = new TrainResult();

            // For every evaluated epoch the training and validation evaluation values are stored together with the model.
            m_ModelEvaluations = new List<Tuple<double, double, string>>();

            // Is the evaluation function an error to minimize or a metric (e.g. accuracy) to maximize?
            bool isMinimize = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction());

            if (m_trainingHistory == null)
            {
                m_trainingHistory = new List<Tuple<int, float, float, float, float>>();
            }

            // When continuing a previous training run, resume from the epoch after the last recorded one.
            int epoch = (m_trainingHistory.Count > 0) ? m_trainingHistory.Last().Item1 + 1 : 1;

            ProgressData prData = null;
            var vars = InputVariables.Union(OutputVariables).ToList();

            while (true)
            {
                var args = miniBatchSource.GetNextMinibatch(trParams.BatchSize, device);
                var arguments = MinibatchSourceEx.ToMinibatchData(args, vars, miniBatchSource.Type);
                trainer.TrainMinibatch(arguments, device);

                // A sweep end on any stream marks the end of an epoch.
                if (args.Any(a => a.Value.sweepEnd))
                {
                    prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);

                    if (epoch >= trParams.Epochs)
                    {
                        string historyFile = saveCheckpointAndHistory(trainer, modelCheckPoint, historyPath);

                        // Save the best (or last) trained model and report one last time before completing.
                        var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);
                        progress?.Invoke(prData);

                        trainResult.Iteration           = epoch;
                        trainResult.ProcessState        = ProcessState.Compleated;
                        trainResult.BestModelFile       = bestModelPath;
                        trainResult.TrainingHistoryFile = historyFile;
                        return trainResult;
                    }

                    epoch++;
                }

                // Stop when the caller requests it.
                if (token.IsCancellationRequested)
                {
                    string historyFile = saveCheckpointAndHistory(trainer, modelCheckPoint, historyPath);

                    // Cancellation can arrive before the first epoch has finished; evaluate the partial progress then.
                    if (prData == null)
                    {
                        prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);
                    }

                    // Save the best (or last) trained model and report one last time before terminating.
                    var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);
                    progress?.Invoke(prData);

                    trainResult.Iteration           = prData.EpochCurrent;
                    trainResult.ProcessState        = ProcessState.Stopped;
                    trainResult.BestModelFile       = bestModelPath;
                    trainResult.TrainingHistoryFile = historyFile;
                    return trainResult;
                }
            }
        }

        /// <summary>
        /// Saves the trainer checkpoint and the accumulated training history, each only when its path is provided.
        /// </summary>
        /// <param name="trainer">Trainer whose checkpoint and loss/evaluation function names are saved.</param>
        /// <param name="modelCheckPoint">Checkpoint path; skipped when null or empty.</param>
        /// <param name="historyPath">History file path; skipped when null or empty.</param>
        /// <returns>The path the training history was written to, or an empty string when no history was saved.</returns>
        private string saveCheckpointAndHistory(Trainer trainer, string modelCheckPoint, string historyPath)
        {
            if (!string.IsNullOrEmpty(modelCheckPoint))
            {
                trainer.SaveCheckpoint(modelCheckPoint);
            }

            if (string.IsNullOrEmpty(historyPath))
            {
                return "";
            }

            string header = $"{trainer.LossFunction().Name};{trainer.EvaluationFunction().Name};";
            saveTrainingHistory(m_trainingHistory, header, historyPath);
            return historyPath;
        }
Esempio n. 2
0
        /// <summary>
        /// Evaluates the CNTK model stored at <paramref name="modelPath"/> on the given images
        /// and returns one integer result per evaluated sample.
        /// </summary>
        /// <param name="modelPath">Path to the saved CNTK model file.</param>
        /// <param name="imagePaths">Paths of the images to evaluate.</param>
        /// <param name="device">Device on which the model is loaded and evaluated.</param>
        /// <returns>
        /// The integer results in the order produced by the minibatch source; an empty list when
        /// <paramref name="imagePaths"/> is empty.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="imagePaths"/> is null.</exception>
        /// <exception cref="FileNotFoundException">No file exists at <paramref name="modelPath"/>.</exception>
        public static List<int> TestModel(string modelPath, string[] imagePaths, DeviceDescriptor device)
        {
            if (imagePaths == null)
            {
                throw new ArgumentNullException(nameof(imagePaths));
            }

            FileInfo fi = new FileInfo(modelPath);
            if (!fi.Exists)
            {
                throw new FileNotFoundException($"The '{fi.FullName}' does not exist. Make sure the model is places at this location.", fi.FullName);
            }

            // Nothing to evaluate; a zero-sized minibatch request would never reach a sweep end.
            if (imagePaths.Length == 0)
            {
                return new List<int>();
            }

            // Load the model and build reader stream configurations from its inputs (features) and outputs (labels).
            var model = Function.Load(fi.FullName, device);
            var features = model.Arguments.ToList();
            var labels = model.Outputs.ToList();
            var streamsConfig = MLFactory.CreateStreamConfiguration(features, labels);

            // Map file for the image reader: one "<path>\t0" line per image (0 is a placeholder label).
            // NOTE(review): written to the current working directory under a fixed name and never deleted,
            // so concurrent calls overwrite each other's map file.
            var mapFile = "testMapFile";
            File.WriteAllLines(mapFile, imagePaths.Select(x => $"{x}\t0"));

            var testMB = new MinibatchSourceEx(MinibatchType.Image, streamsConfig.ToArray(), features, labels, mapFile, null, 30, false, 0);

            const int maxMinibatchSize = 30;
            uint mbSize = (uint)Math.Min(imagePaths.Length, maxMinibatchSize);

            var vars = features.Union(labels).ToList();
            var outputVar = labels.First();
            var retVal = new List<int>();

            while (true)
            {
                bool isSweepEnd = false;
                var inputMap = testMB.GetNextMinibatch(mbSize, ref isSweepEnd, vars, device);

                var outputMap = new Dictionary<Variable, Value> { { outputVar, null } };
                model.Evaluate(inputMap, outputMap, device);
                var result = outputMap[outputVar].GetDenseData<float>(outputVar);

                // MLValue.GetResult turns each output row into a numeric result, which is truncated to int.
                foreach (var r in result)
                {
                    retVal.Add((int)MLValue.GetResult(r));
                }

                if (isSweepEnd)
                {
                    break;
                }
            }

            return retVal;
        }