Example 1
 public Dictionary <Variable, Value> GetNextMinibatch(uint minibatchSizeInSamples, ref bool sweepEnd, List <Variable> vars, DeviceDescriptor device)
 {
     if (Type == MinibatchType.Default || Type == MinibatchType.Image)
     {
         var args = defaultmb.GetNextMinibatch(minibatchSizeInSamples, device);
         sweepEnd = args.Any(x => x.Value.sweepEnd);
         //
         var arguments = MinibatchSourceEx.ToMinibatchValueData(args, vars);
         return(arguments);
     }
     else if (Type == MinibatchType.Custom)
     {
         var retVal = nextBatch(custommb, StreamConfigurations, (int)minibatchSizeInSamples);
         var mb     = new Dictionary <Variable, Value>();
         sweepEnd = custommb.EndOfStream;
         //create minibatch
         foreach (var d in retVal)
         {
             var v = Value.CreateBatchOfSequences <float>(new NDShape(1, d.Key.m_dim), d.Value, device);
             //
              var variable = vars.Where(x => x.Name == d.Key.m_streamName).FirstOrDefault();
              if (variable == null)
              {
                  throw new Exception("Variable cannot be null!");
              }
              //
              mb.Add(variable, v);
         }
         return(mb);
     }
     else
     {
         throw new Exception("Unsupported minibatch source type!");
     }
 }
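
The following is a minimal usage sketch for the method above, not part of the original source: the mbSource instance, the inputVars list, and the batch size of 64 are all illustrative assumptions.

 //hypothetical usage sketch: mbSource and inputVars are assumptions, not part of the original code
 bool sweepEnd = false;
 var device = DeviceDescriptor.UseDefaultDevice();
 while (!sweepEnd)
 {
     var batch = mbSource.GetNextMinibatch(64, ref sweepEnd, inputVars, device);
     //the returned Dictionary<Variable, Value> can then be fed to a trainer,
     //e.g. trainer.TrainMinibatch(batch, sweepEnd, device);
 }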
Example 2
        /// <summary>
        /// Returns the part of the ML dataset containing the feature and label columns; this is needed when an Excel export is performed.
        /// </summary>
        /// <param name="fun"></param>
        /// <param name="evParam"></param>
        /// <param name="device"></param>
        /// <returns></returns>
        public static Dictionary <string, List <List <float> > > FeaturesAndLabels(Function fun, EvaluationParameters evParam, DeviceDescriptor device)
        {
            try
            {
                //declare return vars
                var featDic = new Dictionary <string, List <List <float> > >();

                while (true)
                {
                    //get one minibatch of data for evaluation
                    var mbData   = evParam.MBSource.GetNextMinibatch(evParam.MinibatchSize, device);
                    var mdDataEx = MinibatchSourceEx.ToMinibatchValueData(mbData, evParam.Input.Union(evParam.Ouptut).ToList());
                    var inMap    = new Dictionary <Variable, Value>();

                    //input
                    var vars = evParam.Input;
                    for (int i = 0; i < vars.Count() /*mdDataEx.Count*/; i++)
                    {
                        var vv = vars.ElementAt(i);
                        var d  = mdDataEx.Where(x => x.Key.Name.Equals(vv.Name)).FirstOrDefault();
                        //
                        var fv = MLValue.GetValues(d.Key, d.Value);
                        if (featDic.ContainsKey(d.Key.Name))
                        {
                            featDic[d.Key.Name].AddRange(fv);
                        }
                        else
                        {
                            featDic.Add(d.Key.Name, fv);
                        }
                    }
                    //output
                    var varso = evParam.Ouptut;
                    for (int i = 0; i < varso.Count() /*mdDataEx.Count*/; i++)
                    {
                        var vv = varso.ElementAt(i);
                        var d  = mdDataEx.Where(x => x.Key.Name.Equals(vv.Name)).FirstOrDefault();
                        //
                        var fv = MLValue.GetValues(d.Key, d.Value);
                        if (vv.Shape.Dimensions.Last() == 1)
                        {
                            var value = fv.Select(l => new List <float>()
                            {
                                l.First()
                            }).ToList();
                            if (featDic.ContainsKey(d.Key.Name))
                            {
                                featDic[d.Key.Name].AddRange(value);
                            }
                            else
                            {
                                featDic.Add(d.Key.Name, value);
                            }
                        }
                        else
                        {
                            var value = fv.Select(l => new List <float>()
                            {
                                l.IndexOf(l.Max())
                            }).ToList();
                            if (featDic.ContainsKey(d.Key.Name))
                            {
                                featDic[d.Key.Name].AddRange(value);
                            }
                            else
                            {
                                featDic.Add(d.Key.Name, value);
                            }
                        }
                    }

                    // check if sweep end reached
                    if (mbData.Any(x => x.Value.sweepEnd))
                    {
                        break;
                    }
                }

                return(featDic);
            }
            catch (Exception)
            {
                throw;
            }
        }
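
Below is a hedged sketch of how the dictionary returned by FeaturesAndLabels might be consumed when preparing rows for an Excel export; the model, evParam and device objects are assumed to be constructed elsewhere, and the declaring class is assumed to be in scope.

            //sketch only: 'model', 'evParam' and 'device' are assumed to exist
            var columns = FeaturesAndLabels(model, evParam, device);
            foreach (var column in columns)
            {
                //column.Key is the variable name; column.Value holds one List<float> per sample
                Console.WriteLine($"{column.Key}: {column.Value.Count} rows");
            }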
Example 3
        public static (List <List <float> > actual, List <List <float> > predicted) EvaluateFunctionEx(Function fun, EvaluationParameters evParam, DeviceDescriptor device)
        {
            try
            {
                //declare return vars
                List <List <float> > actualLst    = new List <List <float> >();
                List <List <float> > predictedLst = new List <List <float> >();

                while (true)
                {
                    Value predicted = null;
                    //get one minibatch of data for evaluation
                    var mbData   = evParam.MBSource.GetNextMinibatch(evParam.MinibatchSize, device);
                    var mbDataEx = MinibatchSourceEx.ToMinibatchValueData(mbData, evParam.Input.Union(evParam.Ouptut).ToList());
                    var inMap    = new Dictionary <Variable, Value>();
                    //
                    var vars = fun.Arguments.Union(fun.Outputs);
                    for (int i = 0; i < vars.Count() /* mbDataEx.Count*/; i++)
                    {
                        var d = mbDataEx.ElementAt(i);
                        var v = vars.Where(x => x.Name.Equals(d.Key.Name)).First();
                        //skip output data
                        if (!evParam.Ouptut.Any(x => x.Name.Equals(v.Name)))
                        {
                            inMap.Add(v, d.Value);
                        }
                    }

                    //actual data, if it is available
                    var actualVar = mbDataEx.Keys.Where(x => x.Name.Equals(evParam.Ouptut.First().Name)).FirstOrDefault();
                    var act       = mbDataEx[actualVar].GetDenseData <float>(actualVar).Select(l => l.ToList());
                    actualLst.AddRange(act);

                    //predicted data
                    //map variables and data
                    var predictedDataMap = new Dictionary <Variable, Value>()
                    {
                        { fun, null }
                    };

                    //evaluates model
                    fun.Evaluate(inMap, predictedDataMap, device);
                    predicted = predictedDataMap.Values.First();
                    var pred = predicted.GetDenseData <float>(fun).Select(l => l.ToList());
                    predicted.Erase();
                    predicted.Dispose();
                    predictedLst.AddRange(pred);

                    // check if sweep end reached
                    if (mbData.Any(x => x.Value.sweepEnd))
                    {
                        break;
                    }
                }

                return(actualLst, predictedLst);
            }
            catch (Exception)
            {
                throw;
            }
        }
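
A short sketch of using the tuple returned by EvaluateFunctionEx, assuming a regression-style model so that each inner list holds a single value; the RMSE computation and the model, evParam and device names are illustrative assumptions, not part of the original code.

            //sketch: compute an RMSE over the evaluation results (assumes one value per row)
            var (actual, predicted) = EvaluateFunctionEx(model, evParam, device);
            double sse = 0;
            for (int i = 0; i < actual.Count; i++)
            {
                double diff = actual[i][0] - predicted[i][0];
                sse += diff * diff;
            }
            double rmse = Math.Sqrt(sse / actual.Count);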
Example 4
        /// <summary>
        /// Main method for training
        /// </summary>
        /// <param name="trainer"></param>
        /// <param name="network"></param>
        /// <param name="trParams"></param>
        /// <param name="miniBatchSource"></param>
        /// <param name="device"></param>
        /// <param name="token"></param>
        /// <param name="progress"></param>
        /// <param name="modelCheckPoint"></param>
        /// <returns></returns>
        public override TrainResult Train(Trainer trainer, Function network, TrainingParameters trParams,
                                          MinibatchSourceEx miniBatchSource, DeviceDescriptor device, CancellationToken token, TrainingProgress progress, string modelCheckPoint, string historyPath)
        {
            try
            {
                //create the trainer result.
                //the variable indicates how the training process ended:
                //completed, stopped, or crashed.
                var trainResult = new TrainResult();
                var historyFile = "";
                //create the training process evaluation collection;
                //for each iteration, the evaluation value for the training and validation sets is stored together with the model
                m_ModelEvaluations = new List <Tuple <double, double, string> >();

                //check whether the goal is to minimize (error) or maximize (accuracy)
                bool isMinimize = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction());

                //setup first iteration
                if (m_trainingHistory == null)
                {
                    m_trainingHistory = new List <Tuple <int, float, float, float, float> >();
                }
                //when training is continued, the iteration must start from the last epoch of the previous training process
                int epoch = (m_trainingHistory.Count > 0) ? m_trainingHistory.Last().Item1 + 1 : 1;

                //define progressData
                ProgressData prData = null;

                //define helper variable collection
                var vars = InputVariables.Union(OutputVariables).ToList();

                //training process
                while (true)
                {
                    //get next mini batch data
                    var args       = miniBatchSource.GetNextMinibatch(trParams.BatchSize, device);
                    var isSweepEnd = args.Any(a => a.Value.sweepEnd);
                    //prepare the data for trainer
                    var arguments = MinibatchSourceEx.ToMinibatchValueData(args, vars);
                    GC.Collect();//remove this line after testing phase
                    trainer.TrainMinibatch(arguments, isSweepEnd, device);

                    //make progress
                    if (isSweepEnd)
                    {
                        //check the progress of the training process
                        prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);
                        //check if training process ends
                        if (epoch >= trParams.Epochs)
                        {
                            //save training checkpoint state
                            if (!string.IsNullOrEmpty(modelCheckPoint))
                            {
                                trainer.SaveCheckpoint(modelCheckPoint);
                            }

                            //save training history
                            if (!string.IsNullOrEmpty(historyPath))
                            {
                                string header = $"{trainer.LossFunction().Name};{trainer.EvaluationFunction().Name};";
                                saveTrainingHistory(m_trainingHistory, header, historyPath);
                            }

                            //save the best or last trained model and send the report one last time before the trainer completes
                            var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);
                            //
                            if (progress != null)
                            {
                                progress(prData);
                            }
                            //
                            trainResult.Iteration           = epoch;
                            trainResult.ProcessState        = ProcessState.Compleated;
                            trainResult.BestModelFile       = bestModelPath;
                            trainResult.TrainingHistoryFile = historyFile;
                            break;
                        }
                        else
                        {
                            epoch++;
                        }
                    }
                    //stop in case user request it
                    if (token.IsCancellationRequested)
                    {
                        if (!string.IsNullOrEmpty(modelCheckPoint))
                        {
                            trainer.SaveCheckpoint(modelCheckPoint);
                        }

                        //save training history
                        if (!string.IsNullOrEmpty(historyPath))
                        {
                            string header = $"{trainer.LossFunction().Name};{trainer.EvaluationFunction().Name};";
                            saveTrainingHistory(m_trainingHistory, header, historyPath);
                        }

                        //sometimes the training process is stopped before the first epoch has passed, so create an incomplete progress report
                        if (prData == null)//check the progress of the training process
                        {
                            prData = progressTraining(trParams, trainer, network, miniBatchSource, epoch, progress, device);
                        }

                        //save the best or last trained model and send the report one last time before the trainer terminates
                        var bestModelPath = saveBestModel(trParams, trainer.Model(), epoch, isMinimize);
                        //
                        if (progress != null)
                        {
                            progress(prData);
                        }

                        //setup training result
                        trainResult.Iteration           = prData.EpochCurrent;
                        trainResult.ProcessState        = ProcessState.Stopped;
                        trainResult.BestModelFile       = bestModelPath;
                        trainResult.TrainingHistoryFile = historyFile;
                        break;
                    }
                }

                return(trainResult);
            }
            catch (Exception)
            {
                throw;
            }
        }
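
Finally, a hedged sketch of invoking this Train override; the learner instance, the trainer, network, trParams and mbSource objects, the TrainingProgress delegate signature, and the file paths are all assumptions made for illustration only.

            //sketch only: every object below is assumed to be prepared by the caller
            var cts = new CancellationTokenSource();
            TrainingProgress progress = p => Console.WriteLine($"Epoch {p.EpochCurrent}");
            var result = learner.Train(trainer, network, trParams, mbSource,
                                       DeviceDescriptor.UseDefaultDevice(), cts.Token,
                                       progress, "model.checkpoint", "history.csv");
            if (result.ProcessState == ProcessState.Compleated)
                Console.WriteLine($"Best model: {result.BestModelFile}");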