/// <summary>
/// Evaluates the loss and metric over one sweep of <paramref name="minibatchSource"/>.
/// Minibatches are pulled one sample at a time until the source reports the end of a sweep.
/// </summary>
/// <param name="minibatchSource">Data to evaluate on; must not be null.</param>
/// <returns>The loss and metric values reported by the evaluators after the sweep.</returns>
/// <exception cref="System.ArgumentNullException"><paramref name="minibatchSource"/> is null.</exception>
public (double loss, double metric) Evaluate(IMinibatchSource minibatchSource)
{
    if (minibatchSource == null)
    {
        throw new System.ArgumentNullException(nameof(minibatchSource));
    }

    // TODO: Add Support for other evaluation batch sizes.
    const int evaluationBatchSize = 1;

    // Create loss and metric evaluators; both are disposed when the sweep completes.
    using (var lossEvaluator = new MetricEvaluator(m_loss, m_device))
    using (var metricEvaluator = new MetricEvaluator(m_metric, m_device))
    {
        bool isSweepEnd = false;
        while (!isSweepEnd)
        {
            var (minibatch, sweepEnded) = minibatchSource.GetNextMinibatch(evaluationBatchSize, m_device);
            isSweepEnd = sweepEnded;

            // NOTE: "EvalauteNextStep" is the evaluator's actual (misspelled) method name.
            lossEvaluator.EvalauteNextStep(minibatch, evaluationBatchSize);
            metricEvaluator.EvalauteNextStep(minibatch, evaluationBatchSize);
        }

        // Returned without narrowing to float, so the double return type keeps full precision.
        return (lossEvaluator.CurrentMetric, metricEvaluator.CurrentMetric);
    }
}
/// <summary>
/// Trains the network for <paramref name="epochs"/> epochs. One epoch ends each time
/// <paramref name="trainMinibatchSource"/> reports the end of a sweep.
/// </summary>
/// <param name="trainMinibatchSource">
/// Training data. Required: the <c>null</c> default is kept only so existing call sites
/// keep compiling; passing null throws <see cref="System.ArgumentNullException"/>.
/// </param>
/// <param name="batchSize">Number of samples requested from the source per training step.</param>
/// <param name="epochs">Number of training sweeps to run.</param>
/// <param name="validationMinibatchSource">
/// Optional validation data, evaluated with <c>Evaluate</c> at the end of every epoch.
/// When null, the validation entries of the returned history stay empty.
/// </param>
/// <returns>
/// Per-epoch history keyed by the loss, metric, validation-loss and validation-metric names.
/// </returns>
/// <exception cref="System.ArgumentNullException"><paramref name="trainMinibatchSource"/> is null.</exception>
public Dictionary<string, List<double>> Fit(IMinibatchSource trainMinibatchSource = null, int batchSize = 32, int epochs = 1, IMinibatchSource validationMinibatchSource = null)
{
    // Fail fast with a descriptive exception instead of a NullReferenceException
    // on the first GetNextMinibatch call.
    if (trainMinibatchSource == null)
    {
        throw new System.ArgumentNullException(nameof(trainMinibatchSource));
    }

    // Setup fitter.
    var fitter = new Fitter(m_trainer, m_device);

    // TODO: Refactor to callback style reporting for each metric, instead of returning a dictionary.
    // Store epoch history.
    var lossValidationHistory = new Dictionary<string, List<double>>
    {
        { m_lossName, new List<double>() },
        { m_metricName, new List<double>() },
        { m_validationLossName, new List<double>() },
        { m_validationMetricName, new List<double>() },
    };

    // The epoch counter only advances when the training source reports the end of a sweep.
    for (int epoch = 0; epoch < epochs;)
    {
        var (minibatch, isSweepEnd) = trainMinibatchSource.GetNextMinibatch(batchSize, m_device);
        fitter.FitNextStep(minibatch, batchSize);

        if (!isSweepEnd)
        {
            continue;
        }

        // Get current loss and metric, reset accumulators for next epoch.
        var currentLoss = fitter.CurrentLoss;
        var currentMetric = fitter.CurrentMetric;
        fitter.ResetLossAndMetricAccumulators();

        lossValidationHistory[m_lossName].Add(currentLoss);
        lossValidationHistory[m_metricName].Add(currentMetric);

        // Report the 1-based epoch number before advancing the counter.
        var traceOutput = $"Epoch: {epoch + 1:000} Loss = {currentLoss:F8}, Metric = {currentMetric:F8}";
        ++epoch;

        if (validationMinibatchSource != null)
        {
            var (validationLoss, validationMetric) = Evaluate(validationMinibatchSource);
            traceOutput += $" - ValidationLoss = {validationLoss:F8}, ValidationMetric = {validationMetric:F8}";

            lossValidationHistory[m_validationLossName].Add(validationLoss);
            lossValidationHistory[m_validationMetricName].Add(validationMetric);
        }

        Trace.WriteLine(traceOutput);
    }

    return lossValidationHistory;
}
/// <summary>
/// Runs the network over one sweep of <paramref name="minibatchSource"/> and collects its outputs.
/// </summary>
/// <param name="minibatchSource">Data to predict on; read one sample at a time until a sweep ends.</param>
/// <returns>All per-step predictions, in the order the source produced them.</returns>
public IList<IList<float>> Predict(IMinibatchSource minibatchSource)
{
    // TODO: Add Support for other prediction batch sizes.
    const int PredictionBatchSize = 1;

    var allPredictions = new List<IList<float>>();
    var predictor = new Predictor(Network, m_device);

    // The source is always read at least once, then until it signals the end of a sweep.
    bool reachedSweepEnd;
    do
    {
        var (batch, sweepEnded) = minibatchSource.GetNextMinibatch(PredictionBatchSize, m_device);
        reachedSweepEnd = sweepEnded;

        allPredictions.AddRange(predictor.PredictNextStep(batch));
    }
    while (!reachedSweepEnd);

    return allPredictions;
}