Example #1
0
        /// <summary>
        /// Evaluates the loss and the metric over one full sweep of <paramref name="minibatchSource"/>.
        /// </summary>
        /// <param name="minibatchSource">
        /// Source of evaluation data. Minibatches are read until the source reports the end of a sweep;
        /// the minibatch flagged as the sweep end is still evaluated.
        /// </param>
        /// <returns>The values reported by the loss and metric evaluators after the sweep.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="minibatchSource"/> is null.</exception>
        public (double loss, double metric) Evaluate(IMinibatchSource minibatchSource)
        {
            if (minibatchSource is null)
            {
                throw new System.ArgumentNullException(nameof(minibatchSource));
            }

            // TODO: Add Support for other evaluation batch sizes.
            const int evaluationBatchSize = 1;

            // Create loss and metric evaluators; both are disposed when evaluation completes.
            using (var lossEvaluator = new MetricEvaluator(m_loss, m_device))
            using (var metricEvaluator = new MetricEvaluator(m_metric, m_device))
            {
                bool isSweepEnd = false;

                while (!isSweepEnd)
                {
                    var nextMinibatch = minibatchSource.GetNextMinibatch(evaluationBatchSize, m_device);
                    var minibatch     = nextMinibatch.minibatch;
                    isSweepEnd = nextMinibatch.isSweepEnd;

                    lossEvaluator.EvalauteNextStep(minibatch, evaluationBatchSize);
                    metricEvaluator.EvalauteNextStep(minibatch, evaluationBatchSize);
                }

                // Return the values as-is: narrowing them to float would throw away precision
                // that the declared (double, double) return type is meant to carry.
                return (lossEvaluator.CurrentMetric, metricEvaluator.CurrentMetric);
            }
        }
Example #2
0
        /// <summary>
        /// Trains the network for <paramref name="epochs"/> epochs, optionally evaluating on validation
        /// data at the end of each epoch, and writes a one-line summary per epoch via <c>Trace.WriteLine</c>.
        /// </summary>
        /// <param name="trainMinibatchSource">Source of training data. Required despite the <c>null</c> default.</param>
        /// <param name="batchSize">Number of samples requested from the training source for each training step.</param>
        /// <param name="epochs">Number of epochs to train; an epoch ends when the training source reports a sweep end.</param>
        /// <param name="validationMinibatchSource">
        /// Optional validation data; when non-null it is evaluated with <see cref="Evaluate"/> after every epoch.
        /// </param>
        /// <returns>
        /// Per-epoch history keyed by the loss, metric, validation-loss and validation-metric names.
        /// The validation lists remain empty when <paramref name="validationMinibatchSource"/> is null.
        /// </returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="trainMinibatchSource"/> is null.</exception>
        public Dictionary <string, List <double> > Fit(IMinibatchSource trainMinibatchSource      = null, int batchSize = 32, int epochs = 1,
                                                       IMinibatchSource validationMinibatchSource = null)
        {
            // The null default is kept so existing call sites still compile, but training needs data:
            // fail fast with a descriptive exception instead of a NullReferenceException in the loop.
            if (trainMinibatchSource is null)
            {
                throw new System.ArgumentNullException(nameof(trainMinibatchSource));
            }

            var fitter = new Fitter(m_trainer, m_device);

            // TODO: Refactor to callback style reporting for each metric, instead of returning a dictionary.
            // Per-epoch history for every reported quantity.
            var lossValidationHistory = new Dictionary <string, List <double> >
            {
                { m_lossName, new List <double>() },
                { m_metricName, new List <double>() },
                { m_validationLossName, new List <double>() },
                { m_validationMetricName, new List <double>() },
            };

            var epoch = 0;
            while (epoch < epochs)
            {
                var (minibatch, isSweepEnd) = trainMinibatchSource.GetNextMinibatch(batchSize, m_device);
                fitter.FitNextStep(minibatch, batchSize);

                // An epoch only ends when the source reports the end of a sweep; until then keep training.
                if (!isSweepEnd)
                {
                    continue;
                }

                // Read this epoch's accumulated loss/metric, then reset the accumulators for the next epoch.
                var currentLoss   = fitter.CurrentLoss;
                var currentMetric = fitter.CurrentMetric;
                fitter.ResetLossAndMetricAccumulators();

                lossValidationHistory[m_lossName].Add(currentLoss);
                lossValidationHistory[m_metricName].Add(currentMetric);

                // Epoch numbers are reported 1-based in the trace output.
                var traceOutput = $"Epoch: {epoch + 1:000} Loss = {currentLoss:F8}, Metric = {currentMetric:F8}";
                ++epoch;

                if (validationMinibatchSource != null)
                {
                    var (validationLoss, validationMetric) = Evaluate(validationMinibatchSource);
                    traceOutput += $" - ValidationLoss = {validationLoss:F8}, ValidationMetric = {validationMetric:F8}";

                    lossValidationHistory[m_validationLossName].Add(validationLoss);
                    lossValidationHistory[m_validationMetricName].Add(validationMetric);
                }

                Trace.WriteLine(traceOutput);
            }

            return lossValidationHistory;
        }
Example #3
0
        /// <summary>
        /// Runs the network over one full sweep of <paramref name="minibatchSource"/> and collects all predictions.
        /// </summary>
        /// <param name="minibatchSource">
        /// Source of input data. It is read one sample at a time until it reports the end of a sweep;
        /// the minibatch flagged as the sweep end is still predicted.
        /// </param>
        /// <returns>
        /// Everything returned by <c>PredictNextStep</c>, concatenated in the order the minibatches were read.
        /// </returns>
        public IList <IList <float> > Predict(IMinibatchSource minibatchSource)
        {
            // TODO: Add Support for other prediction batch sizes.
            const int predictionBatchSize = 1;

            var results   = new List <IList <float> >();
            var predictor = new Predictor(Network, m_device);

            for (var sweepDone = false; !sweepDone;)
            {
                var (batch, endOfSweep) = minibatchSource.GetNextMinibatch(predictionBatchSize, m_device);
                sweepDone = endOfSweep;

                results.AddRange(predictor.PredictNextStep(batch));
            }

            return results;
        }