public void ClassificationNeuralNetModel_PredictProbability_Single()
        {
            // Verifies that single-row PredictProbability calls on a model loaded from
            // m_classificationNeuralNetModelText reproduce the expected total error.
            var numberOfObservations = 500;
            var numberOfFeatures     = 5;
            var numberOfClasses      = 5;

            // Fixed seed so observations and targets are identical on every run.
            var random       = new Random(32);
            var observations = new F64Matrix(numberOfObservations, numberOfFeatures);

            observations.Map(() => random.NextDouble());
            var targets = Enumerable.Range(0, numberOfObservations)
                          .Select(i => (double)random.Next(0, numberOfClasses)).ToArray();

            var sut = ClassificationNeuralNetModel.Load(() => new StringReader(m_classificationNeuralNetModelText));

            // Predict one row at a time to exercise the single-observation overload.
            var predictions = new ProbabilityPrediction[numberOfObservations];

            for (int i = 0; i < numberOfObservations; i++)
            {
                predictions[i] = sut.PredictProbability(observations.Row(i));
            }

            var evaluator = new TotalErrorClassificationMetric <double>();
            var actual    = evaluator.Error(targets, predictions.Select(p => p.Prediction).ToArray());

            // The error is a computed double; compare with a tolerance, as the Save test does,
            // rather than relying on bit-exact floating-point equality.
            Assert.AreEqual(0.762, actual, 0.0000001);
        }
Пример #2
0
        public void ClassificationNeuralNetModel_Save()
        {
            // Trains a small network, serializes it to a string, reloads it, and checks
            // that the reloaded model's total error on the training data is 0.762.
            const int observationCount = 500;
            const int featureCount     = 5;
            const int classCount       = 5;

            // Arrange: deterministic features first, then labels, from one seeded generator.
            var rng  = new Random(32);
            var data = new F64Matrix(observationCount, featureCount);
            data.Map(() => rng.NextDouble());

            var labels = new double[observationCount];
            for (var row = 0; row < observationCount; row++)
            {
                labels[row] = rng.Next(0, classCount);
            }

            var network = new NeuralNet();
            network.Add(new InputLayer(featureCount));
            network.Add(new DenseLayer(10));
            network.Add(new SvmLayer(classCount));

            var trained = new ClassificationNeuralNetLearner(network, new AccuracyLoss())
                .Learn(data, labels);

            // Act: round-trip the trained model through an in-memory string.
            var buffer = new StringWriter();
            trained.Save(() => buffer);

            var reloaded  = ClassificationNeuralNetModel.Load(() => new StringReader(buffer.ToString()));
            var predicted = reloaded.Predict(data);

            // Assert: total error of the reloaded model, within a small tolerance.
            var error = new TotalErrorClassificationMetric <double>().Error(labels, predicted);
            Assert.AreEqual(0.762, error, 0.0000001);
        }
Пример #3
0
 /// <summary>
 /// Deserializes the trained model stored in the private field <c>allAnnModel</c> into a
 /// <see cref="ClassificationNeuralNetModel"/> and stores it in <c>annModel</c>.
 /// </summary>
 /// <remarks>
 /// Does nothing once <c>isAnnLoaded</c> is true. <c>isAnnLoaded</c> is set only after
 /// <c>Load</c> returns, so a failed load leaves it false (a later call will try again) and
 /// the original exception propagates unchanged, keeping its type, stack trace and details.
 /// </remarks>
 private void LoadAnn()
 {
     if (isAnnLoaded)
     {
         return;
     }

     // Previously the failure was rethrown as `new Exception(ex.Message)`, which discarded the
     // exception type, stack trace and inner exception; letting it propagate keeps them all.
     annModel    = ClassificationNeuralNetModel.Load(() => new StringReader(allAnnModel));
     isAnnLoaded = true;
 }
Пример #4
0
        public void ClassificationNeuralNetModel_Predict_Multiple()
        {
            // Verifies that batch Predict on a model loaded from ClassificationNeuralNetModelText
            // reproduces the expected total error.
            var numberOfObservations = 500;
            var numberOfFeatures     = 5;
            var numberOfClasses      = 5;

            // Fixed seed so observations and targets are identical on every run.
            var random       = new Random(32);
            var observations = new F64Matrix(numberOfObservations, numberOfFeatures);

            observations.Map(() => random.NextDouble());
            var targets = Enumerable.Range(0, numberOfObservations).Select(i => (double)random.Next(0, numberOfClasses)).ToArray();

            var sut = ClassificationNeuralNetModel.Load(() => new StringReader(ClassificationNeuralNetModelText));

            var predictions = sut.Predict(observations);

            var evaluator = new TotalErrorClassificationMetric <double>();
            var actual    = evaluator.Error(targets, predictions);

            // The error is a computed double; compare with a tolerance, as the Save test does,
            // rather than relying on bit-exact floating-point equality.
            Assert.AreEqual(0.762, actual, 0.0000001);
        }
Пример #5
0
        /// <summary>
        /// Classifies each candidate collection of extracted letters with the neural-net model in
        /// "network.xml" and stops at the first decoding whose checksum matches
        /// <c>qrData.MetaData.Checksum</c> (checked via <c>Sha1Helper.IsChecksumCorrect</c>).
        /// </summary>
        /// <param name="lettersCollections">
        /// Candidate letter sets, each parsed as CSV rows. Every column except "class" is used as
        /// a feature. Feature values are divided by 255; presumably they are 0..255 pixel
        /// intensities (confirm against DataOutput.GetCsvParser callers).
        /// </param>
        /// <param name="qrData">Supplies the expected checksum.</param>
        /// <returns>
        /// The first decoding whose checksum matches, with <c>IsChecksumOk</c> = true. If no
        /// collection matches, this is the last decoding attempted (null if
        /// <paramref name="lettersCollections"/> is empty), with <c>IsChecksumOk</c> = false.
        /// </returns>
        private DecodedLetters DecodeExtractedLetters(IEnumerable <IEnumerable <CsvRow> > lettersCollections, QrReaderData qrData)
        {
            const string targetName = "class";

            string currentResult     = null;
            var    isChecksumCorrect = false;

            // The model and the alphabet are the same for every candidate. They used to be
            // reloaded and recomputed on each iteration (the model per collection, the alphabet per
            // character). Both are now created once, on first use, so an empty input still never
            // touches network.xml.
            ClassificationNeuralNetModel model    = null;
            string                       alphabet = null;

            foreach (var letters in lettersCollections)
            {
                var csvReader        = DataOutput.GetCsvParser(letters);
                var featureNames     = csvReader.EnumerateRows(c => c != targetName).First().ColumnNameToIndex.Keys.ToArray();
                var testObservations = csvReader.EnumerateRows(featureNames).ToF64Matrix();
                testObservations.Map(p => p / 255);

                if (model == null)
                {
                    model    = ClassificationNeuralNetModel.Load(() => new StreamReader("network.xml"));
                    alphabet = Alphabet.Base32Alphabet.ToString();
                }

                var predictions = model.Predict(testObservations);

                // A prediction of (approximately) -1 is rendered as '-'; any other value is used as
                // an index into the Base32 alphabet.
                var stringBuilder = new StringBuilder();
                foreach (var prediction in predictions)
                {
                    stringBuilder.Append(Math.Abs(prediction + 1) < 0.01 ? '-' : alphabet[(int)prediction]);
                }

                currentResult = stringBuilder.ToString();

                isChecksumCorrect = Sha1Helper.IsChecksumCorrect(currentResult, qrData.MetaData.Checksum);

                if (isChecksumCorrect)
                {
                    break;
                }
            }

            return new DecodedLetters {
                Letters = currentResult, IsChecksumOk = isChecksumCorrect
            };
        }