Пример #1
0
        /// <summary>
        /// Builds a four-row table and checks that the default data source vectorises the
        /// requested row's inputs and one-hot target as expected.
        /// </summary>
        public void DefaultDataSource()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "val3");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.Add(0.5f, 1.1, "d", "a");
            builder.Add(0.2f, 1.5, "c", "b");
            builder.Add(0.7f, 0.5, "b", "c");
            builder.Add(0.2f, 0.6, "a", "d");
            var table          = builder.Build();
            var vectoriser     = table.GetVectoriser();
            var graph          = new GraphFactory(_lap);
            var dataSource     = graph.CreateDataSource(table, vectoriser);

            // mini batch for row index 1; the expectations below match the second row
            // (0.2f, 1.5, target "b")
            var miniBatch      = dataSource.Get(null, new[] { 1 });
            var input          = miniBatch.CurrentSequence.Input[0].GetMatrix().Row(0).AsIndexable();
            var expectedOutput = miniBatch.CurrentSequence.Target.GetMatrix().Row(0).AsIndexable();

            // Assert.AreEqual takes (expected, actual); keep that order so failure messages are accurate
            Assert.AreEqual(0.2f, input[0]);
            Assert.AreEqual(1.5f, input[1]);
            Assert.AreEqual(4, expectedOutput.Count);
            Assert.AreEqual("b", vectoriser.GetOutputLabel(2, expectedOutput.MaximumIndex()));
        }
Пример #2
0
        /// <summary>
        /// Trains an LSTM network that maps each generated character sequence to a single summary
        /// vector ("many to one"), then prints the mean error over the test rows that have a target.
        /// </summary>
        static void ManyToOne()
        {
            var grammar   = new SequenceClassification(dictionarySize: 10, minSize: 5, maxSize: 5, noRepeat: true, isStochastic: false);
            var sequences = grammar.GenerateSequences().Take(1000).ToList();
            var builder   = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Matrix, "Sequence");
            builder.AddColumn(ColumnType.Vector, "Summary");

            foreach (var sequence in sequences)
            {
                // each matrix row encodes the set of distinct characters seen so far; the final
                // row (every distinct character) is also stored in the Summary column
                var list    = new List <FloatVector>();
                var charSet = new HashSet <char>();
                foreach (var ch in sequence)
                {
                    charSet.Add(ch);
                    var row = grammar.Encode(charSet.Select(ch2 => (ch2, 1f)));
                    list.Add(row);
                }
                builder.Add(new FloatMatrix {
                    Row = list.ToArray()
                }, list.Last());
            }
            var data = builder.Build().Split(0);

            using (var lap = BrightWireProvider.CreateLinearAlgebra(false)) {
                var graph       = new GraphFactory(lap);
                var errorMetric = graph.ErrorMetric.BinaryClassification;

                // select RMSProp gradient descent and Xavier weight initialisation
                graph.CurrentPropertySet
                    .Use(graph.GradientDescent.RmsProp)
                    .Use(graph.WeightInitialisation.Xavier)
                ;

                // create the engine
                var trainingData = graph.CreateDataSource(data.Training);
                var testData     = trainingData.CloneWith(data.Test);
                var engine       = graph.CreateTrainingEngine(trainingData, 0.03f, 8);

                // build the network; a simple recurrent layer
                // (.AddSimpleRecurrent(graph.ReluActivation(), memory)) can replace .AddLstm(memory)
                const int HIDDEN_LAYER_SIZE = 128;
                var       memory            = new float[HIDDEN_LAYER_SIZE];
                graph.Connect(engine)
                    .AddLstm(memory)
                    .AddFeedForward(engine.DataSource.OutputSize)
                    .Add(graph.SigmoidActivation())
                    .AddBackpropagationThroughTime(errorMetric)
                ;

                engine.Train(10, testData, errorMetric);

                var networkGraph    = engine.Graph;
                var executionEngine = graph.CreateEngine(networkGraph);

                // only results that carry a target contribute to the reported error
                var output = executionEngine.Execute(testData);
                Console.WriteLine(output.Where(o => o.Target != null).Average(o => o.CalculateError(errorMetric)));
            }
        }
Пример #3
0
        /// <summary>
        /// Writes a 33,000-row table to an in-memory stream, then re-creates it from that stream
        /// both with the separately written index and without one (null), comparing each copy
        /// against the originally built table.
        /// </summary>
        public void TestIndexHydration()
        {
            using (var dataStream = new MemoryStream())
            using (var indexStream = new MemoryStream())
            {
                var builder = BrightWireProvider.CreateDataTableBuilder(dataStream);
                builder.AddColumn(ColumnType.Boolean, "target", true);
                builder.AddColumn(ColumnType.Int, "val");
                builder.AddColumn(ColumnType.String, "label");

                const int rowCount = 33000;
                for (var rowIndex = 0; rowIndex < rowCount; rowIndex++)
                {
                    var isEven = rowIndex % 2 == 0;
                    builder.Add(isEven, rowIndex, rowIndex.ToString());
                }

                var original = builder.Build();
                builder.WriteIndexTo(indexStream);

                // rewind both streams before reading the table back with its index
                dataStream.Seek(0, SeekOrigin.Begin);
                indexStream.Seek(0, SeekOrigin.Begin);
                var fromIndex = BrightWireProvider.CreateDataTable(dataStream, indexStream);
                _CompareTables(original, fromIndex);

                // read it back again, this time passing no index
                dataStream.Seek(0, SeekOrigin.Begin);
                var withoutIndex = BrightWireProvider.CreateDataTable(dataStream, null);
                _CompareTables(original, withoutIndex);
            }
        }
Пример #4
0
        /// <summary>
        /// Trains a k-nearest-neighbours model on the Wikipedia naive-Bayes sample and checks that
        /// the highest-weighted classification of an unseen row is "female".
        /// </summary>
        public void KNN()
        {
            var dataTable = BrightWireProvider.CreateDataTableBuilder();

            dataTable.AddColumn(ColumnType.Float, "height");
            dataTable.AddColumn(ColumnType.Int, "weight").IsContinuous    = true;
            dataTable.AddColumn(ColumnType.Int, "foot-size").IsContinuous = true;
            dataTable.AddColumn(ColumnType.String, "gender", true);

            // sample data from: https://en.wikipedia.org/wiki/Naive_Bayes_classifier
            dataTable.Add(6f, 180, 12, "male");
            dataTable.Add(5.92f, 190, 11, "male");
            dataTable.Add(5.58f, 170, 12, "male");
            dataTable.Add(5.92f, 165, 10, "male");
            dataTable.Add(5f, 100, 6, "female");
            dataTable.Add(5.5f, 150, 8, "female");
            dataTable.Add(5.42f, 130, 7, "female");
            dataTable.Add(5.75f, 150, 9, "female");
            var index = dataTable.Build();

            // the row to classify uses the same columns; "?" is a placeholder for the unknown label
            var testData = BrightWireProvider.CreateDataTableBuilder(dataTable.Columns);
            var row      = testData.Add(6f, 130, 8, "?");

            var model          = index.TrainKNearestNeighbours();
            var classifier     = model.CreateClassifier(_lap, 2);
            var classification = classifier.Classify(row);

            // AreEqual reports the actual label on failure, unlike IsTrue(... == ...)
            Assert.AreEqual("female", classification.OrderByDescending(c => c.Weight).First().Label);
        }
Пример #5
0
        /// <summary>
        /// Selects columns 1-3 from a four-column table and checks that the target column index,
        /// the shape, and the numeric values of the new first column are as expected.
        /// </summary>
        public void SelectColumns()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.AddColumn(ColumnType.String, "cls2");

            builder.Add(0.5f, 1.1, "a", "a2");
            builder.Add(0.2f, 1.5, "b", "b2");
            builder.Add(0.7f, 0.5, "c", "c2");
            builder.Add(0.2f, 0.6, "d", "d2");

            var table  = builder.Build();
            var table2 = table.SelectColumns(new[] { 1, 2, 3 });

            // dropping val1 moves the "cls" target column from index 2 to index 1
            Assert.AreEqual(1, table2.TargetColumnIndex);
            Assert.AreEqual(4, table2.RowCount);
            Assert.AreEqual(3, table2.ColumnCount);

            // column 0 of the projected table is the original val2 column
            var column = table2.GetNumericColumns(new[] { 0 }).Select(r => _lap.CreateVector(r)).First().AsIndexable();

            Assert.AreEqual(1.1f, column[0]);
            Assert.AreEqual(1.5f, column[1]);
        }
Пример #6
0
        /// <summary>
        /// Round-trips one value of every supported column type through a data table and checks
        /// that each field reads back unchanged.
        /// </summary>
        public void TestColumnTypes()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Boolean, "boolean");
            builder.AddColumn(ColumnType.Byte, "byte");
            builder.AddColumn(ColumnType.Date, "date");
            builder.AddColumn(ColumnType.Double, "double");
            builder.AddColumn(ColumnType.Float, "float");
            builder.AddColumn(ColumnType.Int, "int");
            builder.AddColumn(ColumnType.Long, "long");
            builder.AddColumn(ColumnType.Null, "null");
            builder.AddColumn(ColumnType.String, "string");

            var now = DateTime.Now;

            builder.Add(true, (byte)100, now, 1.0 / 3, 0.5f, int.MaxValue, long.MaxValue, null, "test");
            var dataTable = builder.Build();

            var firstRow = dataTable.GetRow(0);

            // Assert.AreEqual takes (expected, actual)
            Assert.AreEqual(true, firstRow.GetField <bool>(0));
            Assert.AreEqual(100, firstRow.GetField <byte>(1));
            Assert.AreEqual(now, firstRow.GetField <DateTime>(2));
            Assert.AreEqual(1.0 / 3, firstRow.GetField <double>(3));
            Assert.AreEqual(0.5f, firstRow.GetField <float>(4));
            Assert.AreEqual(int.MaxValue, firstRow.GetField <int>(5));
            Assert.AreEqual(long.MaxValue, firstRow.GetField <long>(6));
            Assert.IsNull(firstRow.GetField <object>(7));
            Assert.AreEqual("test", firstRow.GetField <string>(8));
        }
Пример #7
0
        /// <summary>
        /// Checks that a vectoriser rebuilt from an extracted vectorisation model produces the same
        /// input vectors, row for row, as the vectoriser the model was taken from.
        /// </summary>
        public void DeserialiseVectorisationModel()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.String, "label");
            builder.AddColumn(ColumnType.String, "output", true);

            builder.Add("a", "0");
            builder.Add("b", "0");
            builder.Add("c", "1");

            var dataTable  = builder.Build();
            var vectoriser = dataTable.GetVectoriser();
            var model      = vectoriser.GetVectorisationModel();

            var vectorList = new List <FloatVector>();

            dataTable.ForEach(row => vectorList.Add(vectoriser.GetInput(row)));

            // rebuild a vectoriser from the extracted model and vectorise the same rows again
            var vectoriser2 = dataTable.GetVectoriser(model);
            var vectorList2 = new List <FloatVector>();

            dataTable.ForEach(row => vectorList2.Add(vectoriser2.GetInput(row)));

            // Zip stops at the shorter sequence, so compare lengths first; otherwise a missing
            // vector would go unnoticed
            Assert.AreEqual(vectorList.Count, vectorList2.Count);
            foreach (var item in vectorList.Zip(vectorList2, (v1, v2) => (v1, v2)))
            {
                _AssertEqual(item.Item1.Data, item.Item2.Data);
            }
        }
Пример #8
0
        /// <summary>
        /// Trains multinomial logistic regression on the Wikipedia naive-Bayes sample and checks
        /// that the best classification of an unseen row is "female".
        /// </summary>
        public void TestMultinomialLogisticRegression()
        {
            var dataTable = BrightWireProvider.CreateDataTableBuilder();

            dataTable.AddColumn(ColumnType.Float, "height");
            dataTable.AddColumn(ColumnType.Int, "weight");
            dataTable.AddColumn(ColumnType.Int, "foot-size");
            dataTable.AddColumn(ColumnType.String, "gender", true);

            // sample data from: https://en.wikipedia.org/wiki/Naive_Bayes_classifier
            dataTable.Add(6f, 180, 12, "male");
            dataTable.Add(5.92f, 190, 11, "male");
            dataTable.Add(5.58f, 170, 12, "male");
            dataTable.Add(5.92f, 165, 10, "male");
            dataTable.Add(5f, 100, 6, "female");
            dataTable.Add(5.5f, 150, 8, "female");
            dataTable.Add(5.42f, 130, 7, "female");
            dataTable.Add(5.75f, 150, 9, "female");
            var index = dataTable.Build();

            // the row to classify uses the same columns; "?" is a placeholder for the unknown label
            var testData = BrightWireProvider.CreateDataTableBuilder(dataTable.Columns);
            var row      = testData.Add(6f, 130, 8, "?");

            var model          = index.TrainMultinomialLogisticRegression(_lap, 100, 0.1f);
            var classifier     = model.CreateClassifier(_lap);
            var classification = classifier.Classify(row);

            // AreEqual reports the actual label on failure, unlike IsTrue(... == ...)
            Assert.AreEqual("female", classification.GetBestClassification());
        }
Пример #9
0
        /// <summary>
        /// Trains linear regression by gradient descent on data where result = 2 * value and checks
        /// that, after rounding, the predictor maps 3 to 6 for both single and batch prediction.
        /// </summary>
        public void TestRegression()
        {
            var dataTable = BrightWireProvider.CreateDataTableBuilder();

            dataTable.AddColumn(ColumnType.Float, "value");
            dataTable.AddColumn(ColumnType.Float, "result", true);

            // simple linear relationship: result is twice value
            dataTable.Add(1f, 2f);
            dataTable.Add(2f, 4f);
            dataTable.Add(4f, 8f);
            dataTable.Add(8f, 16f);
            var index = dataTable.Build();

            // NOTE: the trainer also appears to offer a closed-form Solve(); only gradient descent is tested here
            var trainer    = index.CreateLinearRegressionTrainer(_lap);
            var theta      = trainer.GradientDescent(20, 0.01f, 0.1f, cost => true);
            var predictor  = theta.CreatePredictor(_lap);
            var prediction = predictor.Predict(3f);

            // AreEqual reports the actual rounded prediction on failure
            Assert.AreEqual(6.0, Math.Round(prediction));

            var prediction3 = predictor.Predict(new[] {
                new float[] { 10f },
                new float[] { 3f }
            });

            Assert.AreEqual(6.0, Math.Round(prediction3[1]));
        }
Пример #10
0
        /// <summary>
        /// Summarises a two-row table into a single row and checks the summarised boolean, byte,
        /// double and string fields.
        /// </summary>
        public void TableSummarise()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Boolean, "boolean");
            builder.AddColumn(ColumnType.Byte, "byte");
            builder.AddColumn(ColumnType.Date, "date");
            builder.AddColumn(ColumnType.Double, "double");
            builder.AddColumn(ColumnType.Float, "float");
            builder.AddColumn(ColumnType.Int, "int");
            builder.AddColumn(ColumnType.Long, "long");
            builder.AddColumn(ColumnType.String, "string");

            var now = DateTime.Now;

            builder.Add(true, (sbyte)100, now, 1.0 / 2, 0.5f, int.MaxValue, long.MaxValue, "test");
            builder.Add(true, (sbyte)0, now, 0.0, 0f, int.MinValue, long.MinValue, "test");
            var dataTable = builder.Build();

            var summarisedRow = dataTable.Summarise(1).GetRow(0);

            // expected values are the averages of the two rows (100 and 0 -> 50, 0.5 and 0.0 -> 0.25);
            // Assert.AreEqual takes (expected, actual)
            Assert.AreEqual(true, summarisedRow.GetField <bool>(0));
            Assert.AreEqual((sbyte)50, summarisedRow.GetField <sbyte>(1));
            Assert.AreEqual(0.25, summarisedRow.GetField <double>(3));
            Assert.AreEqual("test", summarisedRow.GetField <string>(7));
        }
Пример #11
0
        /// <summary>
        /// Trains logistic regression on the Wikipedia hours-studied / exam-pass sample and checks
        /// that predicted pass probabilities fall on the expected side of 0.5, both for single
        /// inputs (2 and 4 hours) and for a batch of 1-5 hours.
        /// </summary>
        public void TestLogisticRegression()
        {
            var dataTable = BrightWireProvider.CreateDataTableBuilder();

            dataTable.AddColumn(ColumnType.Float, "hours");
            dataTable.AddColumn(ColumnType.Boolean, "pass", true);

            // sample data from: https://en.wikipedia.org/wiki/Logistic_regression
            // (hours studied, passed), added to the table in this order
            var samples = new[] {
                (0.5f, false),
                (0.75f, false),
                (1f, false),
                (1.25f, false),
                (1.5f, false),
                (1.75f, false),
                (1.75f, true),
                (2f, false),
                (2.25f, true),
                (2.5f, false),
                (2.75f, true),
                (3f, false),
                (3.25f, true),
                (3.5f, false),
                (4f, true),
                (4.25f, true),
                (4.5f, true),
                (4.75f, true),
                (5f, true),
                (5.5f, true)
            };
            foreach (var (hours, passed) in samples)
            {
                dataTable.Add(hours, passed);
            }
            var index = dataTable.Build();

            var trainer   = index.CreateLogisticRegressionTrainer(_lap);
            var theta     = trainer.GradientDescent(1000, 0.1f);
            var predictor = theta.CreatePredictor(_lap);

            Assert.IsTrue(predictor.Predict(2f) < 0.5f);
            Assert.IsTrue(predictor.Predict(4f) >= 0.5f);

            var batchProbabilities = predictor.Predict(new[] {
                new float[] { 1f },
                new float[] { 2f },
                new float[] { 3f },
                new float[] { 4f },
                new float[] { 5f },
            });

            // 1 and 2 hours should be predicted to fail; 3, 4 and 5 hours to pass
            var expectPass = new[] { false, false, true, true, true };
            for (var i = 0; i < expectPass.Length; i++)
            {
                if (expectPass[i])
                {
                    Assert.IsTrue(batchProbabilities[i] >= 0.5f);
                }
                else
                {
                    Assert.IsTrue(batchProbabilities[i] <= 0.5f);
                }
            }
        }
Пример #12
0
        /// <summary>
        /// Builds a single-column ("val") integer table holding the values 0 through 9,999 in order.
        /// </summary>
        IDataTable _GetSimpleTable()
        {
            const int rowCount = 10000;
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Int, "val");
            foreach (var value in Enumerable.Range(0, rowCount))
            {
                builder.Add(value);
            }

            return builder.Build();
        }
Пример #13
0
        /// <summary>
        /// Trains an LSTM network that expands a summary vector of character counts into a sequence
        /// with one row per distinct character ("one to many"), then prints the mean test error.
        /// </summary>
        static void OneToMany()
        {
            var grammar = new SequenceGenerator(dictionarySize: 10, minSize: 5, maxSize: 5,
                                                noRepeat: true, isStochastic: false);
            var sequences = grammar.GenerateSequences().Take(1000).ToList();
            var builder   = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Vector, "Summary");
            builder.AddColumn(ColumnType.Matrix, "Sequence");
            foreach (var sequence in sequences)
            {
                // number of occurrences of each character in the sequence
                var sequenceData = sequence.GroupBy(ch => ch)
                                           .ToDictionary(g => g.Key, g => (float)g.Count());
                var summary = grammar.Encode(sequenceData.Select(kv => (kv.Key, kv.Value)));

                // one matrix row per distinct character, ordered by character
                var list = new List <FloatVector>();
                foreach (var item in sequenceData.OrderBy(kv => kv.Key))
                {
                    var row = grammar.Encode(item.Key, item.Value);
                    list.Add(row);
                }

                builder.Add(summary, FloatMatrix.Create(list.ToArray()));
            }

            var data = builder.Build().Split(0);

            using var lap = BrightWireProvider.CreateLinearAlgebra(false);
            var graph       = new GraphFactory(lap);
            var errorMetric = graph.ErrorMetric.BinaryClassification;

            // select RMSProp gradient descent and Xavier weight initialisation
            graph.CurrentPropertySet.Use(graph.GradientDescent.RmsProp).
            Use(graph.WeightInitialisation.Xavier);

            // create the engine
            const float TRAINING_RATE = 0.1f;
            var         trainingData  = graph.CreateDataSource(data.Training);
            var         testData      = trainingData.CloneWith(data.Test);
            var         engine        = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 8);

            // lower the learning rate to a third of its initial value at iteration 30
            engine.LearningContext.ScheduleLearningRate(30, TRAINING_RATE / 3);

            // build the network
            const int HIDDEN_LAYER_SIZE = 128;

            graph.Connect(engine).AddLstm(HIDDEN_LAYER_SIZE).AddFeedForward(engine.DataSource.OutputSize).
            Add(graph.SigmoidActivation()).AddBackpropagation(errorMetric);
            engine.Train(40, testData, errorMetric);
            var networkGraph    = engine.Graph;
            var executionEngine = graph.CreateEngine(networkGraph);
            var output          = executionEngine.Execute(testData);

            Console.WriteLine(output.Average(o => o.CalculateError(errorMetric)));
        }
Пример #14
0
        /// <summary>
        /// Trains a feed forward neural net on the emotion dataset
        /// http://lpis.csd.auth.gr/publications/tsoumakas-ismir08.pdf
        /// The data files can be downloaded from https://downloads.sourceforge.net/project/mulan/datasets/emotions.rar
        /// </summary>
        /// <param name="dataFilePath">Path of the emotion data file, passed to <c>_LoadEmotionData</c>.</param>
        public static void MultiLabelSingleClassifier(string dataFilePath)
        {
            var emotionData           = _LoadEmotionData(dataFilePath);

            // the last CLASSIFICATION_COUNT columns are the labels; every earlier column is an attribute
            var attributeColumns      = Enumerable.Range(0, emotionData.ColumnCount - CLASSIFICATION_COUNT).ToList();
            var classificationColumns = Enumerable.Range(emotionData.ColumnCount - CLASSIFICATION_COUNT, CLASSIFICATION_COUNT).ToList();

            // create a new data table with a vector input column and a vector output column
            var dataTableBuilder = BrightWireProvider.CreateDataTableBuilder();

            dataTableBuilder.AddColumn(ColumnType.Vector, "Attributes");
            dataTableBuilder.AddColumn(ColumnType.Vector, "Target", isTarget: true);
            emotionData.ForEach(row => {
                var input  = FloatVector.Create(row.GetFields <float>(attributeColumns).ToArray());
                var target = FloatVector.Create(row.GetFields <float>(classificationColumns).ToArray());
                dataTableBuilder.Add(input, target);
                return(true);
            });
            var data = dataTableBuilder.Build().Split(0);

            // train a neural network
            using (var lap = BrightWireProvider.CreateLinearAlgebra(false)) {
                var graph = new GraphFactory(lap);

                // binary classification rounds each output to 0 or 1 and compares each output against the binary classification targets
                var errorMetric = graph.ErrorMetric.BinaryClassification;

                // configure the network properties
                graph.CurrentPropertySet
                .Use(graph.GradientDescent.Adam)
                .Use(graph.WeightInitialisation.Xavier)
                ;

                // create a training engine
                const float TRAINING_RATE = 0.3f;
                var         trainingData  = graph.CreateDataSource(data.Training);
                var         testData      = trainingData.CloneWith(data.Test);
                var         engine        = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 128);

                // build the network (the builder's return value is not needed; the engine holds the graph)
                const int HIDDEN_LAYER_SIZE = 64, TRAINING_ITERATIONS = 2000;
                graph.Connect(engine)
                .AddFeedForward(HIDDEN_LAYER_SIZE)
                .Add(graph.SigmoidActivation())
                .AddDropOut(dropOutPercentage: 0.5f)
                .AddFeedForward(engine.DataSource.OutputSize)
                .Add(graph.SigmoidActivation())
                .AddBackpropagation(errorMetric)
                ;

                // train the network
                engine.Train(TRAINING_ITERATIONS, testData, errorMetric, null, 50);
            }
        }
Пример #15
0
        /// <summary>
        /// Fills a two-column (actual, expected) table with known counts of label pairs, builds a
        /// confusion matrix from it and checks a selection of the resulting cell counts.
        /// </summary>
        public void TableConfusionMatrix()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.String, "actual");
            builder.AddColumn(ColumnType.String, "expected");

            const int CAT_CAT       = 5;
            const int CAT_DOG       = 2;
            const int DOG_CAT       = 3;
            const int DOG_DOG       = 5;
            const int DOG_RABBIT    = 2;
            const int RABBIT_DOG    = 1;
            const int RABBIT_RABBIT = 11;

            // (actual, expected, number of rows); groups are added in exactly this order
            var rowGroups = new[] {
                ("cat", "cat", CAT_CAT),
                ("cat", "dog", CAT_DOG),
                ("dog", "cat", DOG_CAT),
                ("dog", "dog", DOG_DOG),
                ("dog", "rabbit", DOG_RABBIT),
                ("rabbit", "dog", RABBIT_DOG),
                ("rabbit", "rabbit", RABBIT_RABBIT)
            };
            foreach (var (actual, expected, count) in rowGroups)
            {
                for (var n = 0; n < count; n++)
                {
                    builder.Add(actual, expected);
                }
            }

            var table           = builder.Build();
            var confusionMatrix = table.CreateConfusionMatrix(1, 0);

            // read the XML form as a smoke test; its content is not checked
            _ = confusionMatrix.AsXml;

            Assert.AreEqual((uint)CAT_DOG, confusionMatrix.GetCount("cat", "dog"));
            Assert.AreEqual((uint)DOG_RABBIT, confusionMatrix.GetCount("dog", "rabbit"));
            Assert.AreEqual((uint)RABBIT_RABBIT, confusionMatrix.GetCount("rabbit", "rabbit"));
        }
Пример #16
0
        /// <summary>
        /// Checks that marking the middle of three columns as the target is reflected in
        /// TargetColumnIndex, and that the row and column counts are correct.
        /// </summary>
        public void TestTargetColumnIndex()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.String, "a");
            builder.AddColumn(ColumnType.String, "b", true);
            builder.AddColumn(ColumnType.String, "c");
            builder.Add("a", "b", "c");
            var table = builder.Build();

            // Assert.AreEqual takes (expected, actual)
            Assert.AreEqual(1, table.TargetColumnIndex);
            Assert.AreEqual(1, table.RowCount);
            Assert.AreEqual(3, table.ColumnCount);
        }
Пример #17
0
        /// <summary>
        /// Trains a small autoencoder with tied weights (the decoder layer reuses the encoder layer
        /// named "layer") to reproduce random input vectors, then runs the trained graph on the
        /// first input row.
        /// </summary>
        public void TiedAutoEncoder()
        {
            const int DATA_SIZE = 1000, REDUCED_SIZE = 200;

            // create some random data; a fixed seed keeps the test reproducible between runs
            var rand    = new Random(0);
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddVectorColumn(DATA_SIZE, "Input");
            builder.AddVectorColumn(DATA_SIZE, "Output", true);
            for (var i = 0; i < 100; i++)
            {
                var vector = new FloatVector {
                    Data = Enumerable.Range(0, DATA_SIZE).Select(j => Convert.ToSingle(rand.NextDouble())).ToArray()
                };

                // the target is the input itself
                builder.Add(vector, vector);
            }
            var dataTable = builder.Build();

            // build the autoencoder with tied weights
            var graph       = new GraphFactory(_lap);
            var dataSource  = graph.CreateDataSource(dataTable);
            var engine      = graph.CreateTrainingEngine(dataSource, 0.03f, 32);
            var errorMetric = graph.ErrorMetric.Quadratic;

            graph.CurrentPropertySet
            .Use(graph.RmsProp())
            .Use(graph.WeightInitialisation.Xavier)
            ;

            graph.Connect(engine)
            .AddFeedForward(REDUCED_SIZE, "layer")
            .Add(graph.TanhActivation())
            // a direct cast fails immediately if the named node is not a feed-forward layer,
            // whereas `as` would silently pass null on to AddTiedFeedForward
            .AddTiedFeedForward((IFeedForward)engine.Start.FindByName("layer"))
            .Add(graph.TanhActivation())
            .AddBackpropagation(errorMetric)
            ;
            using (var executionContext = graph.CreateExecutionContext()) {
                for (var i = 0; i < 2; i++)
                {
                    engine.Train(executionContext);
                }
            }
            var networkGraph    = engine.Graph;
            var executionEngine = graph.CreateEngine(networkGraph);
            var results         = executionEngine.Execute(dataTable.GetRow(0).GetField <FloatVector>(0).Data);

            Assert.IsNotNull(results);
        }
Пример #18
0
        /// <summary>
        /// Extracts the numeric values of column 1 ("val2") and checks its first two entries.
        /// </summary>
        public void GetNumericColumns2()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.Add(0.5f, 1.1, "a");
            builder.Add(0.2f, 1.5, "b");
            builder.Add(0.7f, 0.5, "c");
            builder.Add(0.2f, 0.6, "d");
            var table  = builder.Build();
            var column = table.GetNumericColumns(new[] { 1 }).First();

            // Assert.AreEqual takes (expected, actual)
            Assert.AreEqual(1.1f, column[0]);
            Assert.AreEqual(1.5f, column[1]);
        }
Пример #19
0
        /// <summary>
        /// Splits a four-row table into four folds and checks that each fold has three training
        /// rows and one validation row.
        /// </summary>
        public void Fold()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.Add(0.5f, 1.1, "a");
            builder.Add(0.2f, 1.5, "b");
            builder.Add(0.7f, 0.5, "c");
            builder.Add(0.2f, 0.6, "d");
            var table = builder.Build();
            var folds = table.Fold(4, 0, false).ToList();

            // Assert.AreEqual takes (expected, actual)
            Assert.AreEqual(4, folds.Count);
            Assert.IsTrue(folds.All(r => r.Training.RowCount == 3 && r.Validation.RowCount == 1));
        }
Пример #20
0
        /// <summary>
        /// Projects a table through a function that returns null for the row whose class is "b",
        /// and checks that the column count is unchanged and only three rows remain.
        /// </summary>
        public void TableFilter()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.Add(0.5f, 1.1, "a");
            builder.Add(0.2f, 1.5, "b");
            builder.Add(0.7f, 0.5, "c");
            builder.Add(0.2f, 0.6, "d");
            var table          = builder.Build();
            var projectedTable = table.Project(r => r.GetField <string>(2) == "b" ? null : r.Data);

            // Assert.AreEqual takes (expected, actual)
            Assert.AreEqual(table.ColumnCount, projectedTable.ColumnCount);
            Assert.AreEqual(3, projectedTable.RowCount);
        }
Пример #21
0
        /// <summary>
        /// Extracts each row's numeric value of column 1 ("val2") as a vector and checks the first
        /// two rows.
        /// </summary>
        public void GetNumericRows()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.Add(0.5f, 1.1, "a");
            builder.Add(0.2f, 1.5, "b");
            builder.Add(0.7f, 0.5, "c");
            builder.Add(0.2f, 0.6, "d");
            var table = builder.Build();
            var rows  = table.GetNumericRows(new[] { 1 }).Select(r => _lap.CreateVector(r)).
                        Select(r => r.AsIndexable()).ToList();

            // Assert.AreEqual takes (expected, actual)
            Assert.AreEqual(1.1f, rows[0][0]);
            Assert.AreEqual(1.5f, rows[1][0]);
        }
Пример #22
0
        /// <summary>
        /// Builds a ten-row table with boolean, byte, date, double, float, int, long and string
        /// columns. Row i (1..10) holds whether i is even, i in every numeric column, the current
        /// time, and i as a string.
        /// </summary>
        static IDataTable _CreateComplexTable()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Boolean, "boolean");
            builder.AddColumn(ColumnType.Byte, "byte");
            builder.AddColumn(ColumnType.Date, "date");
            builder.AddColumn(ColumnType.Double, "double");
            builder.AddColumn(ColumnType.Float, "float");
            builder.AddColumn(ColumnType.Int, "int");
            builder.AddColumn(ColumnType.Long, "long");
            builder.AddColumn(ColumnType.String, "string");

            foreach (var value in Enumerable.Range(1, 10))
            {
                var isEven    = value % 2 == 0;
                var timestamp = DateTime.Now;
                builder.Add(isEven, (sbyte)value, timestamp, (double)value, (float)value, value, (long)value, value.ToString());
            }

            return builder.Build();
        }
Пример #23
0
        /// <summary>
        /// Converts MNIST images into a data table with a tensor "Image" column and a vector
        /// "Target" column, then wraps that table in a data source.
        /// </summary>
        /// <param name="graph">Graph factory used to create a new data source when <paramref name="existing"/> is null.</param>
        /// <param name="existing">Optional data source; when supplied, its CloneWith is called with the new table.</param>
        /// <param name="images">Images to convert into table rows.</param>
        public static IDataSource BuildTensors(GraphFactory graph, IDataSource existing, IReadOnlyList <Mnist.Image> images)
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Tensor, "Image");
            builder.AddColumn(ColumnType.Vector, "Target", true);
            foreach (var image in images)
            {
                var tensorData = image.AsFloatTensor;
                builder.Add(tensorData.Tensor, tensorData.Label);
            }

            var table = builder.Build();
            if (existing == null)
            {
                return graph.CreateDataSource(table);
            }
            return existing.CloneWith(table);
        }
Пример #24
0
        /// <summary>
        /// Analyses a ten-row table covering every column type and checks the distinct counts and
        /// summary statistics reported for the boolean, numeric and string columns.
        /// </summary>
        public void TestDataTableAnalysis()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Boolean, "boolean");
            builder.AddColumn(ColumnType.Byte, "byte");
            builder.AddColumn(ColumnType.Date, "date");
            builder.AddColumn(ColumnType.Double, "double");
            builder.AddColumn(ColumnType.Float, "float");
            builder.AddColumn(ColumnType.Int, "int");
            builder.AddColumn(ColumnType.Long, "long");
            builder.AddColumn(ColumnType.Null, "null");
            builder.AddColumn(ColumnType.String, "string");

            for (var i = 1; i <= 10; i++)
            {
                builder.Add(i % 2 == 0, (byte)i, DateTime.Now, (double)i, (float)i, i, (long)i, null, i.ToString());
            }
            var table    = builder.Build();
            var analysis = table.GetAnalysis();

            // read the XML form as a smoke test; its content is not checked
            var xml      = analysis.AsXml;

            // `as` yields null on a type mismatch, so assert non-null first; a mismatch then fails
            // with a clear assertion instead of a NullReferenceException
            var boolAnalysis = analysis[0] as INumericColumnInfo;

            Assert.IsNotNull(boolAnalysis);
            Assert.IsTrue(boolAnalysis.NumDistinct == 2);
            Assert.IsTrue(boolAnalysis.Mean == 0.5);

            // byte, double, float, int and long columns, each holding the values 1..10
            var numericAnalysis = new[] { 1, 3, 4, 5, 6 }.Select(i => analysis[i] as INumericColumnInfo).ToList();

            Assert.IsTrue(numericAnalysis.All(a => a != null));
            Assert.IsTrue(numericAnalysis.All(a => a.NumDistinct == 10));
            Assert.IsTrue(numericAnalysis.All(a => a.Min == 1));
            Assert.IsTrue(numericAnalysis.All(a => a.Max == 10));
            Assert.IsTrue(numericAnalysis.All(a => a.Mean == 5.5));
            Assert.IsTrue(numericAnalysis.All(a => a.Median.Value == 5));
            Assert.IsTrue(numericAnalysis.All(a => Math.Round(a.StdDev.Value) == 3));

            var stringAnalysis = analysis[8] as IStringColumnInfo;

            Assert.IsNotNull(stringAnalysis);
            Assert.IsTrue(stringAnalysis.NumDistinct == 10);
            Assert.IsTrue(stringAnalysis.MaxLength == 2);
        }
Пример #25
0
        /// <summary>
        /// Converts MNIST images into a data source of 3D image tensors mapped to target vectors,
        /// which is the input/output shape a convolutional network expects.
        /// </summary>
        /// <param name="graph">Factory used to create a new data source when no existing one is supplied</param>
        /// <param name="existing">Training data source to clone from when building test data; may be null</param>
        /// <param name="images">Images to convert</param>
        /// <returns>The resulting data source</returns>
        static IDataSource _BuildTensors(GraphFactory graph, IDataSource existing,
                                         IReadOnlyList <Mnist.Image> images)
        {
            var tableBuilder = BrightWireProvider.CreateDataTableBuilder();
            tableBuilder.AddColumn(ColumnType.Tensor, "Image");
            tableBuilder.AddColumn(ColumnType.Vector, "Target", isTarget: true);

            foreach (var img in images)
            {
                var floatTensor = img.AsFloatTensor;
                tableBuilder.Add(floatTensor.Tensor, floatTensor.Label);
            }

            var builtTable = tableBuilder.Build();

            // no existing source: this is the training data, so create a brand new source
            if (existing == null)
            {
                return graph.CreateDataSource(builtTable);
            }

            // otherwise derive from the training source so the test data matches its configuration
            return existing.CloneWith(builtTable);
        }
Пример #26
0
        /// <summary>
        /// Fits a linear regression (capital, labour and energy costs => output) by gradient descent
        /// on a small hard-coded dataset, printing each reported cost and the final theta.
        /// </summary>
        public static void SimpleLinearTest()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();
            builder.AddColumn(ColumnType.Float, "capital costs");
            builder.AddColumn(ColumnType.Float, "labour costs");
            builder.AddColumn(ColumnType.Float, "energy costs");
            builder.AddColumn(ColumnType.Float, "output", true);

            // each sample: capital, labour, energy, output
            float[][] samples =
            {
                new[] { 98.288f, 0.386f, 13.219f, 1.270f },
                new[] { 255.068f, 1.179f, 49.145f, 4.597f },
                new[] { 208.904f, 0.532f, 18.005f, 1.985f },
                new[] { 528.864f, 1.836f, 75.639f, 9.897f },
                new[] { 307.419f, 1.136f, 52.234f, 5.907f },
                new[] { 138.283f, 1.085f, 9.027f, 1.832f },
                new[] { 418.883f, 2.390f, 1.676f, 4.865f },
                new[] { 247.439f, 1.356f, 31.244f, 2.728f },
                new[] { 19.478f, 0.115f, 1.739f, 0.125f },
                new[] { 537.540f, 2.591f, 104.584f, 9.685f },
                new[] { 605.507f, 2.789f, 82.296f, 8.727f },
                new[] { 174.765f, 0.933f, 21.990f, 2.239f },
                new[] { 946.766f, 4.004f, 125.351f, 10.077f },
                new[] { 296.490f, 1.513f, 43.232f, 4.477f },
                new[] { 645.690f, 2.540f, 75.581f, 7.037f },
                new[] { 288.975f, 1.416f, 42.037f, 3.507f },
            };
            foreach (var sample in samples)
            {
                builder.Add(sample[0], sample[1], sample[2], sample[3]);
            }

            // standard normalisation puts the features on comparable scales before gradient descent
            var normalisedTable = builder.Build().Normalise(NormalisationType.Standard);

            using (var lap = BrightWireProvider.CreateLinearAlgebra(false))
            {
                var regressionTrainer = normalisedTable.CreateLinearRegressionTrainer(lap);
                var model = regressionTrainer.GradientDescent(20, 0.03f, 0.1f, cost =>
                {
                    // log the cost and keep training
                    Console.WriteLine(cost);
                    return true;
                });
                Console.WriteLine(model.Theta);
            }
        }
Пример #27
0
        /// <summary>
        /// Uses a recurrent LSTM neural network to predict stock price movements
        /// Data can be downloaded from https://raw.githubusercontent.com/plotly/datasets/master/stockdata.csv
        /// </summary>
        /// <remarks>
        /// Each training row is a window of LAST_X_DAYS consecutive (normalised, non-date) rows,
        /// and its target is the row immediately following that window.
        /// </remarks>
        /// <param name="dataFilePath">Path to the CSV file (comma separated, with a header row)</param>
        static void StockData(string dataFilePath)
        {
            // build the data table with a window of input data and the prediction as the following value
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Matrix, "Past");
            builder.AddColumn(ColumnType.Vector, "Future");
            const int LAST_X_DAYS = 14;

            // the reader is disposed as soon as the windows have been built (it was previously never closed)
            using (var reader = new StreamReader(dataFilePath))
            {
                // load and normalise the data
                var dataSet    = reader.ParseCSV(',', true);
                var normalised = dataSet.Normalise(NormalisationType.FeatureScale);
                var rows       = normalised.GetNumericRows(dataSet.Columns.Where(c => c.Name != "Date").Select(c => c.Index));

                // window = rows[i .. i + LAST_X_DAYS - 1]; target = rows[i + LAST_X_DAYS], the next row.
                // (previously the target was rows[i + LAST_X_DAYS + 1], which skipped a day)
                for (var i = 0; i + LAST_X_DAYS < rows.Count; i++)
                {
                    var inputVector = new List <FloatVector>();
                    for (var j = 0; j < LAST_X_DAYS; j++)
                    {
                        inputVector.Add(FloatVector.Create(rows[i + j]));
                    }
                    var input  = FloatMatrix.Create(inputVector.ToArray());
                    var target = FloatVector.Create(rows[i + LAST_X_DAYS]);
                    builder.Add(input, target);
                }
            }

            // NOTE(review): a training percentage of 0.2 trains on only a fifth of the windows — confirm this is intended
            var data = builder.Build().Split(trainingPercentage: 0.2);

            using (var lap = BrightWireProvider.CreateLinearAlgebra()) {
                var graph       = new GraphFactory(lap);
                var errorMetric = graph.ErrorMetric.Quadratic;

                // create the property set
                graph.CurrentPropertySet
                .Use(graph.GradientDescent.Adam)
                .Use(graph.WeightInitialisation.Xavier);

                // create the engine; the test source is cloned from the training source
                var trainingData = graph.CreateDataSource(data.Training);
                var testData     = trainingData.CloneWith(data.Test);
                var engine       = graph.CreateTrainingEngine(trainingData, learningRate: 0.03f, batchSize: 128);

                // build the network
                const int HIDDEN_LAYER_SIZE = 256;
                graph.Connect(engine)
                .AddLstm(HIDDEN_LAYER_SIZE)
                .AddFeedForward(engine.DataSource.OutputSize)
                .Add(graph.TanhActivation())
                .AddBackpropagationThroughTime(errorMetric);

                // train the network and restore the best result
                GraphModel bestNetwork = null;
                engine.Train(50, testData, errorMetric, model => bestNetwork = model);
                if (bestNetwork != null)
                {
                    // execute each row of the test data on an execution engine
                    var executionEngine = graph.CreateEngine(bestNetwork.Graph);
                    var results         = executionEngine.Execute(testData).OrderSequentialOutput();
                    var expectedOutput  = data.Test.GetColumn <FloatVector>(1);

                    // score each sequence by its final output against the expected "Future" vector
                    var score = results.Select((r, i) => errorMetric.Compute(r.Last(), expectedOutput[i])).Average();
                    Console.WriteLine(score);
                }
            }
        }
Пример #28
0
        /// <summary>
        /// Trains multiple classifiers on the emotion data set
        /// http://lpis.csd.auth.gr/publications/tsoumakas-ismir08.pdf
        /// The data files can be downloaded from https://downloads.sourceforge.net/project/mulan/datasets/emotions.rar
        /// </summary>
        /// <remarks>
        /// One single-target dataset is derived per emotion label. Naive bayes, logistic regression,
        /// k nearest neighbours and a feed forward network are then trained and evaluated on each one,
        /// and their accuracies are written to the console.
        /// </remarks>
        /// <param name="dataFilePath">Path to the emotion data file, loaded via <c>_LoadEmotionData</c></param>
        public static void MultiLabelMultiClassifiers(string dataFilePath)
        {
            // the last CLASSIFICATION_COUNT columns are treated as label columns; everything before them is an attribute
            var emotionData           = _LoadEmotionData(dataFilePath);
            var attributeCount        = emotionData.ColumnCount - CLASSIFICATION_COUNT;
            var attributeColumns      = Enumerable.Range(0, attributeCount).ToList();
            // NOTE(review): classificationColumns is never used in this method
            var classificationColumns = Enumerable.Range(emotionData.ColumnCount - CLASSIFICATION_COUNT, CLASSIFICATION_COUNT).ToList();
            // display names for each label column, indexed in label-column order
            // (spellings kept as-is; presumably they mirror the dataset's own label names — confirm)
            var classificationLabel   = new[] {
                "amazed-suprised",
                "happy-pleased",
                "relaxing-calm",
                "quiet-still",
                "sad-lonely",
                "angry-aggresive"
            };

            // create six separate datasets to train, each with a separate classification column
            var dataSets = Enumerable.Range(attributeCount, CLASSIFICATION_COUNT).Select(targetIndex => {
                // NOTE(review): this builder (with its final target column) is configured but never used —
                // the returned table comes from emotionData.Project below, so its isTarget flag is not applied.
                // Confirm whether the projected table should carry this column metadata.
                var dataTableBuider = BrightWireProvider.CreateDataTableBuilder();
                for (var i = 0; i < attributeCount; i++)
                {
                    dataTableBuider.AddColumn(ColumnType.Float);
                }
                dataTableBuider.AddColumn(ColumnType.Float, "", true);

                // each projected row = all attribute values followed by the single label column at targetIndex
                return(emotionData.Project(row => row.GetFields <float>(attributeColumns)
                                           .Concat(new[] { row.GetField <float>(targetIndex) })
                                           .Cast <object>()
                                           .ToList()
                                           ));
            }).Select(ds => ds.Split(0)).ToList(); // Split(0): first positional argument of Split (presumably a random seed — confirm)

            // train classifiers on each training set
            using (var lap = BrightWireProvider.CreateLinearAlgebra(false)) {
                var graph = new GraphFactory(lap);

                // binary classification rounds each output to 0 or 1 and compares each output against the binary classification targets
                var errorMetric = graph.ErrorMetric.BinaryClassification;

                // configure the network properties
                graph.CurrentPropertySet
                .Use(graph.GradientDescent.Adam)
                .Use(graph.WeightInitialisation.Xavier)
                ;

                for (var i = 0; i < CLASSIFICATION_COUNT; i++)
                {
                    var trainingSet = dataSets[i].Training;
                    var testSet     = dataSets[i].Test;
                    Console.WriteLine("Training on {0}", classificationLabel[i]);

                    // accuracy below = fraction of test rows whose label column (index attributeCount in the
                    // projected table), read as a string, equals the predicted classification

                    // train and evaluate a naive bayes classifier
                    var naiveBayes = trainingSet.TrainNaiveBayes().CreateClassifier();
                    Console.WriteLine("\tNaive bayes accuracy: {0:P}", testSet
                                      .Classify(naiveBayes)
                                      .Average(d => d.Row.GetField <string>(attributeCount) == d.Classification ? 1.0 : 0.0)
                                      );

                    // train a logistic regression classifier
                    var logisticRegression = trainingSet
                                             .TrainLogisticRegression(lap, 2500, 0.25f, 0.01f)
                                             .CreatePredictor(lap)
                                             .ConvertToRowClassifier(attributeColumns)
                    ;
                    Console.WriteLine("\tLogistic regression accuracy: {0:P}", testSet
                                      .Classify(logisticRegression)
                                      .Average(d => d.Row.GetField <string>(attributeCount) == d.Classification ? 1.0 : 0.0)
                                      );

                    // train and evaluate k nearest neighbours
                    var knn = trainingSet.TrainKNearestNeighbours().CreateClassifier(lap, 10);
                    Console.WriteLine("\tK nearest neighbours accuracy: {0:P}", testSet
                                      .Classify(knn)
                                      .Average(d => d.Row.GetField <string>(attributeCount) == d.Classification ? 1.0 : 0.0)
                                      );

                    // create a training engine; the test source is cloned from the training source
                    const float TRAINING_RATE = 0.1f;
                    var         trainingData  = graph.CreateDataSource(trainingSet);
                    var         testData      = trainingData.CloneWith(testSet);
                    var         engine        = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 64);

                    // build the network
                    // NOTE(review): the `network` local is never read; only the wiring onto `engine` matters
                    const int HIDDEN_LAYER_SIZE = 64, TRAINING_ITERATIONS = 2000;
                    var       network = graph.Connect(engine)
                                        .AddFeedForward(HIDDEN_LAYER_SIZE)
                                        .Add(graph.SigmoidActivation())
                                        .AddDropOut(dropOutPercentage: 0.5f)
                                        .AddFeedForward(engine.DataSource.OutputSize)
                                        .Add(graph.SigmoidActivation())
                                        .AddBackpropagation(errorMetric)
                    ;

                    // train the network (no best-model callback; 200 is presumably the progress-report interval — confirm)
                    engine.Train(TRAINING_ITERATIONS, testData, errorMetric, null, 200);
                }
            }
        }
Пример #29
0
        /// <summary>
        /// Trains an LSTM encoder/decoder on generated sequences. The target for each sequence is its
        /// encoded rows in reverse order, truncated to SEQUENCE_LENGTH - 1 rows.
        /// </summary>
        /// <remarks>
        /// The encoder writes its LSTM memory to the "shared-memory" slot. The decoder joins that slot
        /// with its own input before its LSTM layer.
        /// </remarks>
        static void SequenceToSequence()
        {
            const int SEQUENCE_LENGTH = 5;
            var       grammar         = new SequenceClassification(8, SEQUENCE_LENGTH, SEQUENCE_LENGTH, true, false);
            var       sequences       = grammar.GenerateSequences().Take(2000).ToList();
            var       builder         = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Matrix, "Input");
            builder.AddColumn(ColumnType.Matrix, "Output");

            foreach (var sequence in sequences)
            {
                var encodedSequence  = grammar.Encode(sequence);
                // target: encoded rows reversed, keeping only the first SEQUENCE_LENGTH - 1 of them
                var reversedSequence = new FloatMatrix {
                    Row = encodedSequence.Row.Reverse().Take(SEQUENCE_LENGTH - 1).ToArray()
                };
                builder.Add(encodedSequence, reversedSequence);
            }
            var data = builder.Build().Split(0);

            using (var lap = BrightWireProvider.CreateLinearAlgebra()) {
                var graph       = new GraphFactory(lap);
                var errorMetric = graph.ErrorMetric.BinaryClassification;

                // configure the property set (the returned set was previously stored in an unused local)
                graph.CurrentPropertySet
                    .Use(graph.GradientDescent.RmsProp)
                    .Use(graph.WeightInitialisation.Xavier);

                // HIDDEN_LAYER_SIZE is now const like its neighbours (it was a mutable local named as a constant)
                const int   BATCH_SIZE        = 16;
                const int   HIDDEN_LAYER_SIZE = 64;
                const float TRAINING_RATE     = 0.1f;

                // create the encoder; its LSTM memory is written to the "shared-memory" slot for the decoder
                var encoderLearningContext = graph.CreateLearningContext(TRAINING_RATE, BATCH_SIZE, TrainingErrorCalculation.Fast, true);
                var encoderMemory          = new float[HIDDEN_LAYER_SIZE];
                var trainingData           = graph.CreateDataSource(data.Training, encoderLearningContext, wb => wb
                                                                    .AddLstm(encoderMemory, "encoder")
                                                                    .WriteNodeMemoryToSlot("shared-memory", wb.Find("encoder") as IHaveMemoryNode)
                                                                    .AddFeedForward(grammar.DictionarySize)
                                                                    .Add(graph.SigmoidActivation())
                                                                    .AddBackpropagationThroughTime(errorMetric)
                                                                    );
                var testData = trainingData.CloneWith(data.Test);

                // create the engine and reduce the learning rate at epochs 30 and 40
                var engine = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, BATCH_SIZE);
                engine.LearningContext.ScheduleLearningRate(30, TRAINING_RATE / 3);
                engine.LearningContext.ScheduleLearningRate(40, TRAINING_RATE / 9);

                // create the decoder: its input is joined with the encoder memory, widening it by HIDDEN_LAYER_SIZE
                var decoderMemory = new float[HIDDEN_LAYER_SIZE];
                graph.Connect(engine)
                    .JoinInputWithMemory("shared-memory")
                    .IncrementSizeBy(HIDDEN_LAYER_SIZE)
                    .AddLstm(decoderMemory, "decoder")
                    .AddFeedForward(trainingData.OutputSize)
                    .Add(graph.SigmoidActivation())
                    .AddBackpropagationThroughTime(errorMetric);

                engine.Train(50, testData, errorMetric);

                // disabled: evaluate the trained graph on the test set with a separate execution engine
                //var dataSourceModel = (trainingData as IAdaptiveDataSource).GetModel();
                //var testData2 = graph.CreateDataSource(data.Test, dataSourceModel);
                //var networkGraph = engine.Graph;
                //var executionEngine = graph.CreateEngine(networkGraph);

                //var output = executionEngine.Execute(testData2);
                //Console.WriteLine(output.Average(o => o.CalculateError(errorMetric)));
            }
        }
Пример #30
0
        /// <summary>
        /// Trains a feed forward neural net on the emotion dataset
        /// http://lpis.csd.auth.gr/publications/tsoumakas-ismir08.pdf
        /// The data files can be downloaded from https://downloads.sourceforge.net/project/mulan/datasets/emotions.rar
        /// </summary>
        /// <remarks>
        /// All labels are predicted by one network. The attribute columns form a single input vector,
        /// and the label columns form a single multi-output target vector. Per-row predictions for
        /// the test set are printed to the console.
        /// </remarks>
        /// <param name="dataFilePath">Path to the emotion data file, loaded via <c>_LoadEmotionData</c></param>
        public static void MultiLabelSingleClassifier(string dataFilePath)
        {
            // the last CLASSIFICATION_COUNT columns are the labels; everything before them is an attribute
            var emotionData           = _LoadEmotionData(dataFilePath);
            var attributeColumns      = Enumerable.Range(0, emotionData.ColumnCount - CLASSIFICATION_COUNT).ToList();
            var classificationColumns = Enumerable.Range(emotionData.ColumnCount - CLASSIFICATION_COUNT, CLASSIFICATION_COUNT).ToList();

            // create a new data table with a vector input column and a vector output column
            var dataTableBuilder = BrightWireProvider.CreateDataTableBuilder();

            dataTableBuilder.AddColumn(ColumnType.Vector, "Attributes");
            dataTableBuilder.AddColumn(ColumnType.Vector, "Target", isTarget: true);
            emotionData.ForEach(row => {
                var input  = FloatVector.Create(row.GetFields <float>(attributeColumns).ToArray());
                var target = FloatVector.Create(row.GetFields <float>(classificationColumns).ToArray());
                dataTableBuilder.Add(input, target);
                return(true);
            });
            var data = dataTableBuilder.Build().Split(0);

            // train a neural network
            using (var lap = BrightWireProvider.CreateLinearAlgebra(false)) {
                var graph = new GraphFactory(lap);

                // binary classification rounds each output to 0 or 1 and compares each output against the binary classification targets
                var errorMetric = graph.ErrorMetric.BinaryClassification;

                // configure the network properties
                graph.CurrentPropertySet
                .Use(graph.GradientDescent.Adam)
                .Use(graph.WeightInitialisation.Xavier)
                ;

                // create a training engine; the test source is cloned from the training source
                const float TRAINING_RATE = 0.3f;
                var         trainingData  = graph.CreateDataSource(data.Training);
                var         testData      = trainingData.CloneWith(data.Test);
                var         engine        = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 128);

                // build the network; the returned builder was never used, so it is no longer stored
                const int HIDDEN_LAYER_SIZE = 64, TRAINING_ITERATIONS = 2000;
                graph.Connect(engine)
                    .AddFeedForward(HIDDEN_LAYER_SIZE)
                    .Add(graph.SigmoidActivation())
                    .AddDropOut(dropOutPercentage: 0.5f)
                    .AddFeedForward(engine.DataSource.OutputSize)
                    .Add(graph.SigmoidActivation())
                    .AddBackpropagation(errorMetric);

                // train the network, remembering the graph of each model reported by the callback
                Models.ExecutionGraph bestGraph = null;
                engine.Train(TRAINING_ITERATIONS, testData, errorMetric, model => bestGraph = model.Graph, 50);

                // execute the best model (or the final engine graph if none was reported) on the test set
                var executionEngine = graph.CreateEngine(bestGraph ?? engine.Graph);
                var output          = executionEngine.Execute(testData);

                // output the results: for every row, each label's prediction vs. expectation (thresholded at 0.5)
                var rowIndex = 0;
                foreach (var item in output)
                {
                    var sb = new StringBuilder();
                    foreach (var classification in item.Output.Zip(item.Target, (o, t) => (Output: o, Target: t)))
                    {
                        var columnIndex = 0;
                        sb.AppendLine($"{rowIndex++}) ");
                        foreach (var column in classification.Output.Data.Zip(classification.Target.Data,
                                                                              (o, t) => (Output: o, Target: t)))
                        {
                            var prediction = column.Output >= 0.5f ? "true" : "false";
                            var actual     = column.Target >= 0.5f ? "true" : "false";
                            sb.AppendLine($"\t{columnIndex++}) predicted {prediction} (expected {actual})");
                        }
                    }
                    Console.WriteLine(sb.ToString());
                }
            }
        }