예제 #1
0
        /// <summary>
        /// Creates a trainer for <paramref name="network"/>. When <paramref name="trParams"/> requests continued
        /// training and the checkpoint file exists, the trainer state is restored from that checkpoint and the
        /// training history is loaded from <paramref name="historyPath"/>. Otherwise any existing checkpoint and
        /// history files are deleted, so the next checkpoint save is not mixed with a previous run.
        /// </summary>
        /// <param name="network">Network (model function) to be trained.</param>
        /// <param name="lrParams">Learning parameters forwarded to the trainer factory.</param>
        /// <param name="trParams">Training parameters; <c>ContinueTraining</c> selects restore vs. fresh start.</param>
        /// <param name="modelCheckPoint">Path of the trainer checkpoint file; may be null or empty (no checkpoint).</param>
        /// <param name="historyPath">Path of the training-history file: loaded when continuing, deleted otherwise.</param>
        /// <returns>The newly created trainer, restored from the checkpoint when continuation was possible.</returns>
        /// <exception cref="InvalidOperationException">
        /// Restoring from the checkpoint or loading the training history failed, for example because the network
        /// changed since the checkpoint was written. The original error is preserved in
        /// <see cref="Exception.InnerException"/>.
        /// </exception>
        public Trainer CreateTrainer(Function network, LearningParameters lrParams, TrainingParameters trParams, string modelCheckPoint, string historyPath)
        {
            var trainer = createTrainer(network, lrParams, trParams);

            // Seed the previous evaluation values with the worst possible score for the metric's direction
            // (presumably so the first real evaluation is always treated as an improvement).
            var initialEval = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction()) ? double.MaxValue : double.MinValue;
            m_PrevTrainingEval = initialEval;
            m_PrevValidationEval = initialEval;

            // Continuation is only possible when requested AND a checkpoint file is actually present.
            bool canContinue = trParams.ContinueTraining
                               && !string.IsNullOrEmpty(modelCheckPoint)
                               && File.Exists(modelCheckPoint);

            if (canContinue)
            {
                try
                {
                    // Restoring throws if the network no longer matches the checkpointed state.
                    trainer.RestoreFromCheckpoint(modelCheckPoint);
                    m_trainingHistory = MLFactory.LoadTrainingHistory(historyPath);
                }
                catch (Exception ex)
                {
                    // Keep the original failure as InnerException so the real cause is not lost.
                    throw new InvalidOperationException("The Trainer cannot be restored from the previous state probably because the network has changed." +
                                                        "\n Uncheck 'Continue Training' and train the model from scratch.", ex);
                }
            }
            else
            {
                // Fresh start: remove stale checkpoint and history files (File.Exists is false for null/empty paths).
                if (File.Exists(modelCheckPoint))
                {
                    File.Delete(modelCheckPoint);
                }

                if (File.Exists(historyPath))
                {
                    File.Delete(historyPath);
                }
            }

            return trainer;
        }