예제 #1
0
        /// <summary>
        /// Builds a small labelled table of height / weight / foot-size samples, trains a
        /// k-nearest-neighbours model on it and checks that the highest weighted classification
        /// of an unlabelled row is "female".
        /// </summary>
        public void KNN()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "height");
            var weightColumn = builder.AddColumn(ColumnType.Int, "weight");
            weightColumn.IsContinuous = true;
            var footSizeColumn = builder.AddColumn(ColumnType.Int, "foot-size");
            footSizeColumn.IsContinuous = true;
            builder.AddColumn(ColumnType.String, "gender", true);

            // sample data from: https://en.wikipedia.org/wiki/Naive_Bayes_classifier
            var samples = new (float Height, int Weight, int FootSize, string Gender)[]
            {
                (6f, 180, 12, "male"),
                (5.92f, 190, 11, "male"),
                (5.58f, 170, 12, "male"),
                (5.92f, 165, 10, "male"),
                (5f, 100, 6, "female"),
                (5.5f, 150, 8, "female"),
                (5.42f, 130, 7, "female"),
                (5.75f, 150, 9, "female"),
            };
            foreach (var (height, weight, footSize, gender) in samples)
            {
                builder.Add(height, weight, footSize, gender);
            }
            var trainingTable = builder.Build();

            // the query row is added to a second builder that re-uses the training columns;
            // "?" stands in for the unknown label
            var queryBuilder = BrightWireProvider.CreateDataTableBuilder(builder.Columns);
            var queryRow     = queryBuilder.Add(6f, 130, 8, "?");

            var knnModel      = trainingTable.TrainKNearestNeighbours();
            var knnClassifier = knnModel.CreateClassifier(_lap, 2); // presumably k = 2 - TODO confirm
            var bestMatch     = knnClassifier.Classify(queryRow)
                                .OrderByDescending(c => c.Weight)
                                .First();

            Assert.IsTrue(bestMatch.Label == "female");
        }
예제 #2
0
        /// <summary>
        /// Trains a linear regression model to predict bicycle sharing patterns and writes
        /// the mean squared error of its predictions on the test split to the console.
        /// Files can be downloaded from https://archive.ics.uci.edu/ml/machine-learning-databases/00275/
        /// </summary>
        /// <param name="dataFilePath">The path to the csv file</param>
        public static void PredictBicyclesWithLinearModel(string dataFilePath)
        {
            var dataTable = _LoadBicyclesDataTable(dataFilePath);
            var split     = dataTable.Split(0);

            using var lap = BrightWireProvider.CreateLinearAlgebra(false);
            var trainer   = split.Training.CreateLinearRegressionTrainer(lap);
            int iteration = 0;
            var theta     = trainer.GradientDescent(500, 0.000025f, 0f, cost =>
            {
                // only log the cost on every 20th callback to keep the output readable
                if (iteration++ % 20 == 0)
                {
                    Console.WriteLine(cost);
                }
                return true;
            });

            Console.WriteLine(theta.Theta);

            // every column except the last is used as an input feature
            var testData  = split.Test.GetNumericRows(Enumerable.Range(0, dataTable.ColumnCount - 1));
            var predictor = theta.CreatePredictor(lap);
            int index     = 0;

            // previously the prediction and the actual value were computed and then discarded;
            // accumulate the squared error so the evaluation produces a visible result
            double squaredErrorSum = 0;
            foreach (var row in testData)
            {
                var prediction = predictor.Predict(row);
                var actual     = split.Test.GetRow(index++).GetField<float>(split.Test.TargetColumnIndex);
                var error      = prediction - actual;
                squaredErrorSum += error * error;
            }

            // guard against an empty test split (avoids printing NaN)
            if (index > 0)
            {
                Console.WriteLine($"Test MSE over {index} rows: {squaredErrorSum / index}");
            }
        }
예제 #3
0
        /// <summary>
        /// Builds a 33,000 row table against a memory stream, writes its index to a second
        /// memory stream, then checks that tables re-created from those streams - once with the
        /// index stream and once with a null index - compare equal to the original table.
        /// </summary>
        public void TestIndexHydration()
        {
            const int rowCount = 33000;

            using var tableStream    = new MemoryStream();
            using var rowIndexStream = new MemoryStream();

            var tableBuilder = BrightWireProvider.CreateDataTableBuilder(tableStream);
            tableBuilder.AddColumn(ColumnType.Boolean, "target", true);
            tableBuilder.AddColumn(ColumnType.Int, "val");
            tableBuilder.AddColumn(ColumnType.String, "label");

            for (var value = 0; value < rowCount; value++)
            {
                var isEven = value % 2 == 0;
                tableBuilder.Add(isEven, value, value.ToString());
            }

            var original = tableBuilder.Build();
            tableBuilder.WriteIndexTo(rowIndexStream);

            // rewind both streams before reading the table back with its index
            tableStream.Position    = 0;
            rowIndexStream.Position = 0;
            var withIndex = BrightWireProvider.CreateDataTable(tableStream, rowIndexStream);
            _CompareTables(original, withIndex);

            // rewind the data stream again and read the table back without an index
            tableStream.Position = 0;
            var withoutIndex = BrightWireProvider.CreateDataTable(tableStream, null);
            _CompareTables(original, withoutIndex);
        }
예제 #4
0
        /// <summary>
        /// Generates 1000 sequences, encodes each one as a matrix (one row per step) whose
        /// target is the final row, trains an LSTM network on the training split and writes
        /// the average error on the test split to the console.
        /// </summary>
        static void ManyToOne()
        {
            var generator = new SequenceClassification(dictionarySize: 10, minSize: 5, maxSize: 5, noRepeat: true, isStochastic: false);
            var generated = generator.GenerateSequences().Take(1000).ToList();

            var tableBuilder = BrightWireProvider.CreateDataTableBuilder();
            tableBuilder.AddColumn(ColumnType.Matrix, "Sequence");
            tableBuilder.AddColumn(ColumnType.Vector, "Summary");

            foreach (var sequence in generated)
            {
                // each step encodes every distinct character seen so far, so the last
                // step summarises the whole sequence
                var seen  = new HashSet<char>();
                var steps = new List<FloatVector>();
                foreach (var ch in sequence)
                {
                    seen.Add(ch);
                    var encoded = generator.Encode(seen.Select(c => (c, 1f)));
                    steps.Add(encoded);
                }

                var matrix = new FloatMatrix {
                    Row = steps.ToArray()
                };
                tableBuilder.Add(matrix, steps.Last());
            }
            var split = tableBuilder.Build().Split(0);

            using var lap   = BrightWireProvider.CreateLinearAlgebra(false);
            var graph       = new GraphFactory(lap);
            var errorMetric = graph.ErrorMetric.BinaryClassification;

            // configure the property set
            graph.CurrentPropertySet
                .Use(graph.GradientDescent.RmsProp)
                .Use(graph.WeightInitialisation.Xavier);

            // create the engine
            var trainingSource = graph.CreateDataSource(split.Training);
            var testSource     = trainingSource.CloneWith(split.Test);
            var engine         = graph.CreateTrainingEngine(trainingSource, 0.03f, 8);

            // build the network: LSTM -> feed forward -> sigmoid
            const int hiddenLayerSize = 128;
            var lstmMemory = new float[hiddenLayerSize];
            graph.Connect(engine)
                .AddLstm(lstmMemory)
                .AddFeedForward(engine.DataSource.OutputSize)
                .Add(graph.SigmoidActivation())
                .AddBackpropagationThroughTime(errorMetric);

            engine.Train(10, testSource, errorMetric);

            // run the trained graph over the test data and report the average error
            var executionEngine = graph.CreateEngine(engine.Graph);
            var results         = executionEngine.Execute(testSource);
            var averageError    = results
                                  .Where(r => r.Target != null)
                                  .Average(r => r.CalculateError(errorMetric));
            Console.WriteLine(averageError);
        }
예제 #5
0
        /// <summary>
        /// Adds one row containing a value for each of nine column types and checks that
        /// every field read back from the built table equals the value that was added.
        /// </summary>
        public void TestColumnTypes()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Boolean, "boolean");
            builder.AddColumn(ColumnType.Byte, "byte");
            builder.AddColumn(ColumnType.Date, "date");
            builder.AddColumn(ColumnType.Double, "double");
            builder.AddColumn(ColumnType.Float, "float");
            builder.AddColumn(ColumnType.Int, "int");
            builder.AddColumn(ColumnType.Long, "long");
            builder.AddColumn(ColumnType.Null, "null");
            builder.AddColumn(ColumnType.String, "string");

            var now = DateTime.Now;

            builder.Add(true, (byte)100, now, 1.0 / 3, 0.5f, int.MaxValue, long.MaxValue, null, "test");
            var dataTable = builder.Build();

            var firstRow = dataTable.GetRow(0);

            // Assert.AreEqual takes (expected, actual); keeping that order makes
            // failure messages report the right value as the expectation
            Assert.AreEqual(true, firstRow.GetField<bool>(0));
            Assert.AreEqual(100, firstRow.GetField<byte>(1));
            Assert.AreEqual(now, firstRow.GetField<DateTime>(2));
            Assert.AreEqual(1.0 / 3, firstRow.GetField<double>(3));
            Assert.AreEqual(0.5f, firstRow.GetField<float>(4));
            Assert.AreEqual(int.MaxValue, firstRow.GetField<int>(5));
            Assert.AreEqual(long.MaxValue, firstRow.GetField<long>(6));
            Assert.AreEqual(null, firstRow.GetField<object>(7));
            Assert.AreEqual("test", firstRow.GetField<string>(8));
        }
예제 #6
0
        /// <summary>
        /// Downloads the iris data set, converts its first four columns into vectors and
        /// writes the results of hierarchical and k-means clustering (three clusters each)
        /// to the console.
        /// </summary>
        public static void IrisClustering()
        {
            const string irisUrl = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data";

            // download the iris data set
            byte[] csvBytes;
            using (var webClient = new WebClient())
            {
                csvBytes = webClient.DownloadData(irisUrl);
            }

            // parse the iris CSV into a data table
            var irisTable = new StreamReader(new MemoryStream(csvBytes)).ParseCSV(',');

            // the last column is the classification target ("Iris-setosa", "Iris-versicolor", or "Iris-virginica")
            var targetColumnIndex = irisTable.TargetColumnIndex = irisTable.ColumnCount - 1;
            var featureColumns    = Enumerable.Range(0, 4).ToList();

            using var lap = BrightWireProvider.CreateLinearAlgebra();

            // convert each numeric row to a vector and remember which label belongs to it
            var vectors = irisTable.GetNumericRows(featureColumns)
                          .Select(values => lap.CreateVector(values))
                          .ToList();
            var speciesNames  = irisTable.GetColumn<string>(targetColumnIndex);
            var labelByVector = vectors
                                .Zip(speciesNames, (vector, species) => Tuple.Create(vector, species))
                                .ToDictionary(pair => pair.Item1, pair => pair.Item2);

            Console.WriteLine("Hierachical Clustering...");
            _WriteClusters(vectors.HierachicalCluster(3), labelByVector);
            Console.WriteLine();

            Console.WriteLine("K Means Clustering...");
            _WriteClusters(vectors.KMeans(3), labelByVector);
            Console.WriteLine();
        }
예제 #7
0
        /// <summary>
        /// Checks that a type converter created with float.NaN converts "45.5" to a
        /// non-NaN float and converts the non-numeric text "sdf" to NaN.
        /// </summary>
        public void TestFloatConverter()
        {
            // NOTE(review): float.NaN is presumably the fallback for unparseable input - confirm
            var floatConverter = BrightWireProvider.CreateTypeConverter(float.NaN);

            var parsedNumber = (float)floatConverter.ConvertValue("45.5").ConvertedValue;
            Assert.IsFalse(float.IsNaN(parsedNumber));

            var parsedText = (float)floatConverter.ConvertValue("sdf").ConvertedValue;
            Assert.IsTrue(float.IsNaN(parsedText));
        }
예제 #8
0
        /// <summary>
        /// Builds a two row table, summarises it and checks selected fields of the first
        /// summarised row: the boolean and string columns keep their shared value, while
        /// the byte and double columns hold the midpoint of the two input values.
        /// </summary>
        public void TableSummarise()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Boolean, "boolean");
            builder.AddColumn(ColumnType.Byte, "byte");
            builder.AddColumn(ColumnType.Date, "date");
            builder.AddColumn(ColumnType.Double, "double");
            builder.AddColumn(ColumnType.Float, "float");
            builder.AddColumn(ColumnType.Int, "int");
            builder.AddColumn(ColumnType.Long, "long");
            builder.AddColumn(ColumnType.String, "string");

            var now = DateTime.Now;

            builder.Add(true, (sbyte)100, now, 1.0 / 2, 0.5f, int.MaxValue, long.MaxValue, "test");
            builder.Add(true, (sbyte)0, now, 0.0, 0f, int.MinValue, long.MinValue, "test");
            var dataTable = builder.Build();

            // NOTE(review): Summarise(1) presumably collapses the table to a single row - confirm
            var summarisedRow = dataTable.Summarise(1).GetRow(0);

            // Assert.AreEqual takes (expected, actual); keeping that order makes
            // failure messages report the right value as the expectation
            Assert.AreEqual(true, summarisedRow.GetField<bool>(0));
            Assert.AreEqual((sbyte)50, summarisedRow.GetField<sbyte>(1));
            Assert.AreEqual(0.25, summarisedRow.GetField<double>(3));
            Assert.AreEqual("test", summarisedRow.GetField<string>(7));
        }
예제 #9
0
        /// <summary>
        /// Selects columns 1, 2 and 3 of a four column table and checks the resulting
        /// target column index, row count, column count and the first two values of the
        /// first numeric column.
        /// </summary>
        public void SelectColumns()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.AddColumn(ColumnType.String, "cls2");

            builder.Add(0.5f, 1.1, "a", "a2");
            builder.Add(0.2f, 1.5, "b", "b2");
            builder.Add(0.7f, 0.5, "c", "c2");
            builder.Add(0.2f, 0.6, "d", "d2");

            var table  = builder.Build();
            var table2 = table.SelectColumns(new[] { 1, 2, 3 });

            // Assert.AreEqual takes (expected, actual); keeping that order makes
            // failure messages report the right value as the expectation.
            // "cls" was column 2 of the source table and is expected at index 1 after the selection
            Assert.AreEqual(1, table2.TargetColumnIndex);
            Assert.AreEqual(4, table2.RowCount);
            Assert.AreEqual(3, table2.ColumnCount);

            // column 0 of the selection is "val2" from the source table
            var column = table2.GetNumericColumns(new[] { 0 }).Select(r => _lap.CreateVector(r)).First().AsIndexable();

            Assert.AreEqual(1.1f, column[0]);
            Assert.AreEqual(1.5f, column[1]);
        }
예제 #10
0
        /// <summary>
        /// Trains a multinomial logistic regression model on the height / weight / foot-size
        /// sample and checks that the best classification of an unlabelled row is "female".
        /// </summary>
        public void TestMultinomialLogisticRegression()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "height");
            builder.AddColumn(ColumnType.Int, "weight");
            builder.AddColumn(ColumnType.Int, "foot-size");
            builder.AddColumn(ColumnType.String, "gender", true);

            // sample data from: https://en.wikipedia.org/wiki/Naive_Bayes_classifier
            var samples = new (float Height, int Weight, int FootSize, string Gender)[]
            {
                (6f, 180, 12, "male"),
                (5.92f, 190, 11, "male"),
                (5.58f, 170, 12, "male"),
                (5.92f, 165, 10, "male"),
                (5f, 100, 6, "female"),
                (5.5f, 150, 8, "female"),
                (5.42f, 130, 7, "female"),
                (5.75f, 150, 9, "female"),
            };
            foreach (var (height, weight, footSize, gender) in samples)
            {
                builder.Add(height, weight, footSize, gender);
            }
            var trainingTable = builder.Build();

            // the query row is added to a second builder that re-uses the training columns;
            // "?" stands in for the unknown label
            var queryBuilder = BrightWireProvider.CreateDataTableBuilder(builder.Columns);
            var queryRow     = queryBuilder.Add(6f, 130, 8, "?");

            var regressionModel = trainingTable.TrainMultinomialLogisticRegression(_lap, 100, 0.1f);
            var classifier      = regressionModel.CreateClassifier(_lap);
            var bestLabel       = classifier.Classify(queryRow).GetBestClassification();

            Assert.IsTrue(bestLabel == "female");
        }
예제 #11
0
        /// <summary>
        /// Checks that a vectoriser re-created from another vectoriser's vectorisation
        /// model produces the same input vectors for every row of the table.
        /// </summary>
        public void DeserialiseVectorisationModel()
        {
            var tableBuilder = BrightWireProvider.CreateDataTableBuilder();
            tableBuilder.AddColumn(ColumnType.String, "label");
            tableBuilder.AddColumn(ColumnType.String, "output", true);

            tableBuilder.Add("a", "0");
            tableBuilder.Add("b", "0");
            tableBuilder.Add("c", "1");
            var table = tableBuilder.Build();

            // vectorise every row with the vectoriser the table creates itself
            var originalVectoriser = table.GetVectoriser();
            var vectorisationModel = originalVectoriser.GetVectorisationModel();
            var originalVectors    = new List<FloatVector>();
            table.ForEach(row => originalVectors.Add(originalVectoriser.GetInput(row)));

            // vectorise the same rows again with a vectoriser built from the extracted model
            var restoredVectoriser = table.GetVectoriser(vectorisationModel);
            var restoredVectors    = new List<FloatVector>();
            table.ForEach(row => restoredVectors.Add(restoredVectoriser.GetInput(row)));

            var pairs = originalVectors.Zip(restoredVectors, (first, second) => (first, second));
            foreach (var (expected, actual) in pairs)
            {
                _AssertEqual(expected.Data, actual.Data);
            }
        }
예제 #12
0
        /// <summary>
        /// Trains a window-of-three markov model and samples tokens from it until an end of
        /// sentence token is produced, asserting that this happens in fewer than 1024 tokens.
        /// </summary>
        public void TrainModel3()
        {
            // upper bound on the number of sampled tokens; also used by the final assertion
            const int maxTokens = 1024;

            var trainer = BrightWireProvider.CreateMarkovTrainer3<string>();

            _Train(trainer);
            var model = trainer.Build().AsDictionary;

            // generate some text, starting from an empty (null) history
            string prevPrev = null, prev = null, curr = null;
            var    output   = new List<string>();

            for (var i = 0; i < maxTokens; i++)
            {
                // sample the next state in proportion to the transition probabilities
                var transitions  = model.GetTransitions(prevPrev, prev, curr);
                var distribution = new Categorical(transitions.Select(d => Convert.ToDouble(d.Probability)).ToArray());
                var next         = transitions[distribution.Sample()].NextState;
                output.Add(next);
                if (SimpleTokeniser.IsEndOfSentence(next))
                {
                    break;
                }

                // slide the three token window forward
                prevPrev = prev;
                prev     = curr;
                curr     = next;
            }
            Assert.IsTrue(output.Count < maxTokens);
        }
예제 #13
0
        /// <summary>
        /// Fits a linear regression by gradient descent to four points where the result is
        /// twice the value, then checks that the prediction for 3 rounds to 6 both for a
        /// single input and as part of a batch.
        /// </summary>
        public void TestRegression()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();
            builder.AddColumn(ColumnType.Float, "value");
            builder.AddColumn(ColumnType.Float, "result", true);

            // simple linear relationship: result is twice value
            foreach (var value in new[] { 1f, 2f, 4f, 8f })
            {
                builder.Add(value, value * 2);
            }
            var table = builder.Build();

            // (an earlier, commented-out variant of this test used trainer.Solve() here)
            var trainer   = table.CreateLinearRegressionTrainer(_lap);
            var theta     = trainer.GradientDescent(20, 0.01f, 0.1f, cost => true);
            var predictor = theta.CreatePredictor(_lap);

            // single input
            var single = predictor.Predict(3f);
            Assert.IsTrue(Math.Round(single) == 6f);

            // batch input: the second row is the same query as above
            var batchInput = new[] {
                new float[] { 10f },
                new float[] { 3f }
            };
            var batch = predictor.Predict(batchInput);
            Assert.IsTrue(Math.Round(batch[1]) == 6f);
        }
예제 #14
0
        /// <summary>
        /// Trains a feed forward neural network (one sigmoid hidden layer) to predict bicycle sharing patterns
        /// Files can be downloaded from https://archive.ics.uci.edu/ml/machine-learning-databases/00275/
        /// </summary>
        /// <param name="dataFilePath">The path to the csv file</param>
        public static void PredictBicyclesWithNeuralNetwork(string dataFilePath)
        {
            const int hiddenLayerSize = 16;

            var bicycleTable = _LoadBicyclesDataTable(dataFilePath);
            var split        = bicycleTable.Split(0);

            using var lap = BrightWireProvider.CreateLinearAlgebra(false);
            var graph          = new GraphFactory(lap);
            var errorMetric    = graph.ErrorMetric.Quadratic;
            var trainingSource = graph.CreateDataSource(split.Training);
            var testSource     = trainingSource.CloneWith(split.Test);

            // use Adam for gradient descent
            graph.CurrentPropertySet.Use(graph.Adam());

            // hidden layer -> sigmoid -> output layer, with no activation on the output
            var engine = graph.CreateTrainingEngine(trainingSource, 1.3f, 128);
            graph.Connect(engine)
                .AddFeedForward(hiddenLayerSize)
                .Add(graph.SigmoidActivation())
                .AddFeedForward(engine.DataSource.OutputSize)
                .AddBackpropagation(errorMetric);

            engine.Train(500, testSource, errorMetric);
        }
예제 #15
0
        /// <summary>
        /// Creates a graph data source from a vectorised table, fetches a mini batch holding
        /// row 1 and checks its first two input values, the size of its target vector and
        /// the label of the target's maximum index.
        /// </summary>
        public void DefaultDataSource()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "val3");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.Add(0.5f, 1.1, "d", "a");
            builder.Add(0.2f, 1.5, "c", "b");
            builder.Add(0.7f, 0.5, "b", "c");
            builder.Add(0.2f, 0.6, "a", "d");
            var table      = builder.Build();
            var vectoriser = table.GetVectoriser();
            var graph      = new GraphFactory(_lap);
            var dataSource = graph.CreateDataSource(table, vectoriser);

            // request only row 1, i.e. (0.2f, 1.5, "c", "b")
            var miniBatch      = dataSource.Get(null, new[] { 1 });
            var input          = miniBatch.CurrentSequence.Input[0].GetMatrix().Row(0).AsIndexable();
            var expectedOutput = miniBatch.CurrentSequence.Target.GetMatrix().Row(0).AsIndexable();

            // Assert.AreEqual takes (expected, actual); keeping that order makes
            // failure messages report the right value as the expectation
            Assert.AreEqual(0.2f, input[0]);
            Assert.AreEqual(1.5f, input[1]);
            Assert.AreEqual(4, expectedOutput.Count);
            Assert.AreEqual("b", vectoriser.GetOutputLabel(2, expectedOutput.MaximumIndex()));
        }
예제 #16
0
        /// <summary>
        /// Builds a n-gram based language model and generates new text from the model
        /// </summary>
        public static void MarkovChains()
        {
            // tokenise the novel "The Beautiful and the Damned" by F. Scott Fitzgerald
            List <IReadOnlyList <string> > sentences;

            using (var client = new WebClient())
            {
                var data = client.DownloadString("http://www.gutenberg.org/cache/epub/9830/pg9830.txt");
                var pos  = data.IndexOf("CHAPTER I");
                sentences = SimpleTokeniser.FindSentences(SimpleTokeniser.Tokenise(data.Substring(pos))).
                            ToList();
            }

            // create a markov trainer that uses a window of size 3
            var trainer = BrightWireProvider.CreateMarkovTrainer3 <string>();

            foreach (var sentence in sentences)
            {
                trainer.Add(sentence);
            }
            var model = trainer.Build().AsDictionary;

            // generate some text
            var rand = new Random();

            for (var i = 0; i < 50; i++)
            {
                var    sb = new StringBuilder();
                string prevPrev = default, prev = default, curr = default;
예제 #17
0
        /// <summary>
        /// Trains a window-of-two markov model, round-trips the trainer through a memory
        /// stream, then samples tokens from the built model until an end of sentence token
        /// is produced, asserting that this happens in fewer than 1024 tokens.
        /// </summary>
        public void TrainModel2()
        {
            // upper bound on the number of sampled tokens; also used by the final assertion
            const int maxTokens = 1024;

            var trainer = BrightWireProvider.CreateMarkovTrainer2<string>();

            _Train(trainer);

            // test serialisation/deserialisation
            using (var buffer = new MemoryStream()) {
                trainer.SerialiseTo(buffer);
                buffer.Position = 0; // rewind so the trainer reads what was just written
                trainer.DeserialiseFrom(buffer, true);
            }
            var dictionary = trainer.Build().AsDictionary;

            // generate some text, starting from an empty (null) history
            string prev   = null, curr = null;
            var    output = new List<string>();

            for (var i = 0; i < maxTokens; i++)
            {
                // sample the next state in proportion to the transition probabilities
                var transitions  = dictionary.GetTransitions(prev, curr);
                var distribution = new Categorical(transitions.Select(d => Convert.ToDouble(d.Probability)).ToArray());
                var next         = transitions[distribution.Sample()].NextState;
                output.Add(next);
                if (SimpleTokeniser.IsEndOfSentence(next))
                {
                    break;
                }

                // slide the two token window forward
                prev = curr;
                curr = next;
            }
            Assert.IsTrue(output.Count < maxTokens);
        }
예제 #18
0
        /// <summary>
        /// Trains a six hidden layer network that uses batch normalisation and the custom
        /// SELU activation on a normalised CSV data set, with a softmax output layer.
        /// </summary>
        /// <param name="dataFilesPath">The path to the csv file</param>
        public static void TrainWithSelu(string dataFilesPath)
        {
            const float trainingRate       = 0.1f;
            const int   layerSize          = 64;
            const int   trainingIterations = 500;

            // every hidden layer gets its own instance of the custom activation node
            INode CreateActivation() => new SeluActivation();

            using var lap = BrightWireProvider.CreateLinearAlgebra();
            var graph = new GraphFactory(lap);

            // parse the iris CSV into a data table and normalise
            var csvBytes  = File.ReadAllBytes(dataFilesPath);
            var dataTable = new StreamReader(new MemoryStream(csvBytes))
                            .ParseCSV(',')
                            .Normalise(NormalisationType.Standard);

            // split the data table into training and test tables
            var split          = dataTable.Split(0);
            var trainingSource = graph.CreateDataSource(split.Training);
            var testSource     = graph.CreateDataSource(split.Test);

            // one hot encoding uses the index of the output vector's maximum value as the classification label
            var errorMetric = graph.ErrorMetric.OneHotEncoding;

            // configure the network properties
            graph.CurrentPropertySet
                .Use(graph.GradientDescent.RmsProp)
                .Use(graph.GaussianWeightInitialisation(true, 0.1f, GaussianVarianceCalibration.SquareRoot2N, GaussianVarianceCount.FanInFanOut));

            // create the training engine
            var engine = graph.CreateTrainingEngine(trainingSource, trainingRate, batchSize: 128);

            // create the network with the custom activation function:
            // six blocks of (feed forward -> batch normalisation -> activation), then softmax
            graph.Connect(engine)
                .AddFeedForward(layerSize).AddBatchNormalisation().Add(CreateActivation())
                .AddFeedForward(layerSize).AddBatchNormalisation().Add(CreateActivation())
                .AddFeedForward(layerSize).AddBatchNormalisation().Add(CreateActivation())
                .AddFeedForward(layerSize).AddBatchNormalisation().Add(CreateActivation())
                .AddFeedForward(layerSize).AddBatchNormalisation().Add(CreateActivation())
                .AddFeedForward(layerSize).AddBatchNormalisation().Add(CreateActivation())
                .AddFeedForward(trainingSource.OutputSize)
                .Add(graph.SoftMaxActivation())
                .AddBackpropagation(errorMetric);

            engine.Train(trainingIterations, testSource, errorMetric, null, 50);
        }
예제 #19
0
        /// <summary>
        /// Trains a logistic regression on the hours-studied / passed sample and checks
        /// that the predicted probabilities fall on the expected side of 0.5 for single
        /// and batched inputs.
        /// </summary>
        public void TestLogisticRegression()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();
            builder.AddColumn(ColumnType.Float, "hours");
            builder.AddColumn(ColumnType.Boolean, "pass", true);

            // sample data from: https://en.wikipedia.org/wiki/Logistic_regression
            var samples = new (float Hours, bool Passed)[]
            {
                (0.5f, false), (0.75f, false), (1f, false), (1.25f, false), (1.5f, false),
                (1.75f, false), (1.75f, true), (2f, false), (2.25f, true), (2.5f, false),
                (2.75f, true), (3f, false), (3.25f, true), (3.5f, false), (4f, true),
                (4.25f, true), (4.5f, true), (4.75f, true), (5f, true), (5.5f, true),
            };
            foreach (var (hours, passed) in samples)
            {
                builder.Add(hours, passed);
            }
            var table = builder.Build();

            var trainer   = table.CreateLogisticRegressionTrainer(_lap);
            var theta     = trainer.GradientDescent(1000, 0.1f);
            var predictor = theta.CreatePredictor(_lap);

            // single predictions
            var afterTwoHours = predictor.Predict(2f);
            Assert.IsTrue(afterTwoHours < 0.5f);

            var afterFourHours = predictor.Predict(4f);
            Assert.IsTrue(afterFourHours >= 0.5f);

            // batch prediction for 1 to 5 hours
            var batchInput = new[] {
                new float[] { 1f },
                new float[] { 2f },
                new float[] { 3f },
                new float[] { 4f },
                new float[] { 5f },
            };
            var batch = predictor.Predict(batchInput);

            Assert.IsTrue(batch[0] <= 0.5f);
            Assert.IsTrue(batch[1] <= 0.5f);
            Assert.IsTrue(batch[2] >= 0.5f);
            Assert.IsTrue(batch[3] >= 0.5f);
            Assert.IsTrue(batch[4] >= 0.5f);
        }
예제 #20
0
        /// <summary>
        /// Builds a table with a single int column named "val" holding the values 0 to 9999, one per row.
        /// </summary>
        /// <returns>The built data table</returns>
        IDataTable _GetSimpleTable()
        {
            const int rowCount = 10000;

            var tableBuilder = BrightWireProvider.CreateDataTableBuilder();
            tableBuilder.AddColumn(ColumnType.Int, "val");

            var value = 0;
            while (value < rowCount)
            {
                tableBuilder.Add(value);
                value++;
            }

            return tableBuilder.Build();
        }
예제 #21
0
        /// <summary>
        /// Generates 1000 sequences, encodes each one as a summary vector (input) and a
        /// matrix with one row per distinct character (output), trains an LSTM network on
        /// the training split and writes the average error on the test split to the console.
        /// </summary>
        static void OneToMany()
        {
            const float trainingRate    = 0.1f;
            const int   hiddenLayerSize = 128;

            var generator = new SequenceGenerator(dictionarySize: 10, minSize: 5, maxSize: 5, noRepeat: true, isStochastic: false);
            var generated = generator.GenerateSequences().Take(1000).ToList();

            var tableBuilder = BrightWireProvider.CreateDataTableBuilder();
            tableBuilder.AddColumn(ColumnType.Vector, "Summary");
            tableBuilder.AddColumn(ColumnType.Matrix, "Sequence");

            foreach (var sequence in generated)
            {
                // count how often each character occurs in the sequence
                var countByChar = sequence
                                  .GroupBy(ch => ch)
                                  .ToDictionary(g => g.Key, g => (float)g.Count());
                var summary = generator.Encode(countByChar.Select(kv => (kv.Key, kv.Value)));

                // one encoded row per distinct character, ordered by character
                var rows = new List<FloatVector>();
                foreach (var entry in countByChar.OrderBy(kv => kv.Key))
                {
                    var encoded = generator.Encode(entry.Key, entry.Value);
                    rows.Add(encoded);
                }

                tableBuilder.Add(summary, FloatMatrix.Create(rows.ToArray()));
            }
            var split = tableBuilder.Build().Split(0);

            using var lap = BrightWireProvider.CreateLinearAlgebra(false);
            var graph       = new GraphFactory(lap);
            var errorMetric = graph.ErrorMetric.BinaryClassification;

            // configure the property set
            graph.CurrentPropertySet
                .Use(graph.GradientDescent.RmsProp)
                .Use(graph.WeightInitialisation.Xavier);

            // create the engine
            var trainingSource = graph.CreateDataSource(split.Training);
            var testSource     = trainingSource.CloneWith(split.Test);
            var engine         = graph.CreateTrainingEngine(trainingSource, trainingRate, 8);
            engine.LearningContext.ScheduleLearningRate(30, trainingRate / 3);

            // build the network: LSTM -> feed forward -> sigmoid
            graph.Connect(engine)
                .AddLstm(hiddenLayerSize)
                .AddFeedForward(engine.DataSource.OutputSize)
                .Add(graph.SigmoidActivation())
                .AddBackpropagation(errorMetric);

            engine.Train(40, testSource, errorMetric);

            // run the trained graph over the test data and report the average error
            var executionEngine = graph.CreateEngine(engine.Graph);
            var results         = executionEngine.Execute(testSource);
            var averageError    = results.Average(r => r.CalculateError(errorMetric));

            Console.WriteLine(averageError);
        }
예제 #22
0
        /// <summary>
        /// Creates the CPU linear algebra provider, trains a single feed forward layer with
        /// a sigmoid activation for 300 iterations using the custom error metric, storing each
        /// network reported by the training callback in <c>bestNetwork</c>, then checks the
        /// engine's results.
        /// </summary>
        public static void Load()
        {
            // assigned before MakeGraphAndData() is called, as in the original ordering
            _cpu = BrightWireProvider.CreateLinearAlgebra(false);

            var (graph, data) = MakeGraphAndData();
            var trainingEngine = graph.CreateTrainingEngine(data);
            var customMetric   = new CustomErrorMetric();

            graph.Connect(trainingEngine)
                .AddFeedForward(1)
                .Add(graph.SigmoidActivation())
                .AddBackpropagation(customMetric);

            trainingEngine.Train(300, data, customMetric, best => bestNetwork = best);
            AssertEngineGetsGoodResults(trainingEngine, data);
        }
예제 #23
0
        /// <summary>
        /// Trains a feed forward neural net on the emotion dataset
        /// http://lpis.csd.auth.gr/publications/tsoumakas-ismir08.pdf
        /// The data files can be downloaded from https://downloads.sourceforge.net/project/mulan/datasets/emotions.rar
        /// </summary>
        /// <param name="dataFilePath">Path of the emotion data file; passed unchanged to _LoadEmotionData</param>
        public static void MultiLabelSingleClassifier(string dataFilePath)
        {
            // the last CLASSIFICATION_COUNT columns are the classification targets, everything before them is input
            var emotionData           = _LoadEmotionData(dataFilePath);
            var attributeColumns      = Enumerable.Range(0, emotionData.ColumnCount - CLASSIFICATION_COUNT).ToList();
            var classificationColumns = Enumerable.Range(emotionData.ColumnCount - CLASSIFICATION_COUNT, CLASSIFICATION_COUNT).ToList();

            // create a new data table with a vector input column and a vector output column
            var dataTableBuilder = BrightWireProvider.CreateDataTableBuilder();

            dataTableBuilder.AddColumn(ColumnType.Vector, "Attributes");
            dataTableBuilder.AddColumn(ColumnType.Vector, "Target", isTarget: true);
            emotionData.ForEach(row => {
                var input  = FloatVector.Create(row.GetFields <float>(attributeColumns).ToArray());
                var target = FloatVector.Create(row.GetFields <float>(classificationColumns).ToArray());
                dataTableBuilder.Add(input, target);
                return(true);
            });
            var data = dataTableBuilder.Build().Split(0);

            // train a neural network
            using (var lap = BrightWireProvider.CreateLinearAlgebra(false)) {
                var graph = new GraphFactory(lap);

                // binary classification rounds each output to 0 or 1 and compares each output against the binary classification targets
                var errorMetric = graph.ErrorMetric.BinaryClassification;

                // configure the network properties
                graph.CurrentPropertySet
                    .Use(graph.GradientDescent.Adam)
                    .Use(graph.WeightInitialisation.Xavier);

                // create a training engine
                const float TRAINING_RATE = 0.3f;
                var         trainingData  = graph.CreateDataSource(data.Training);
                var         testData      = trainingData.CloneWith(data.Test);
                var         engine        = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 128);

                // build the network; the chain is evaluated for its wiring effect on the engine,
                // so its return value is not kept
                const int HIDDEN_LAYER_SIZE = 64, TRAINING_ITERATIONS = 2000;
                graph.Connect(engine)
                    .AddFeedForward(HIDDEN_LAYER_SIZE)
                    .Add(graph.SigmoidActivation())
                    .AddDropOut(dropOutPercentage: 0.5f)
                    .AddFeedForward(engine.DataSource.OutputSize)
                    .Add(graph.SigmoidActivation())
                    .AddBackpropagation(errorMetric);

                // train the network
                engine.Train(TRAINING_ITERATIONS, testData, errorMetric, null, 50);
            }
        }
예제 #24
0
        /// <summary>
        /// Builds a n-gram based language model and generates new text from the model.
        /// Downloads the source text, trains a window-of-three markov model on its sentences and
        /// writes 50 generated sentences (each capped at 256 tokens) to the console.
        /// </summary>
        public static void MarkovChains()
        {
            // tokenise the novel "The Beautiful and the Damned" by F. Scott Fitzgerald
            List <IReadOnlyList <string> > sentences;

            using (var client = new WebClient()) {
                var data = client.DownloadString("http://www.gutenberg.org/cache/epub/9830/pg9830.txt");

                // the marker is a fixed token rather than linguistic text, so search ordinally
                var pos = data.IndexOf("CHAPTER I", StringComparison.Ordinal);

                // IndexOf returns -1 when the marker is absent, which Substring would reject;
                // fall back to tokenising the whole download in that case
                var text = pos >= 0 ? data.Substring(pos) : data;
                sentences = SimpleTokeniser.FindSentences(SimpleTokeniser.Tokenise(text)).ToList();
            }

            // create a markov trainer that uses a window of size 3
            var trainer = BrightWireProvider.CreateMarkovTrainer3 <string>();

            foreach (var sentence in sentences)
            {
                trainer.Add(sentence);
            }
            var model = trainer.Build().AsDictionary;

            // generate some text
            for (var i = 0; i < 50; i++)
            {
                var    sb = new StringBuilder();
                string prevPrev = default(string), prev = default(string), curr = default(string);
                for (var j = 0; j < 256; j++)
                {
                    var transitions = model.GetTransitions(prevPrev, prev, curr);

                    // nothing to sample from for this window: end the sentence here instead of
                    // failing while building the distribution or indexing the list
                    if (transitions == null || !transitions.Any())
                    {
                        break;
                    }

                    // sample the next token according to the observed transition probabilities
                    var distribution = new Categorical(transitions.Select(d => Convert.ToDouble(d.Probability)).ToArray());
                    var next         = transitions[distribution.Sample()].NextState;

                    // separate words with a space, except directly after an apostrophe or hyphen
                    if (!string.IsNullOrEmpty(next) && Char.IsLetterOrDigit(next[0]) && sb.Length > 0)
                    {
                        var lastChar = sb[sb.Length - 1];
                        if (lastChar != '\'' && lastChar != '-')
                        {
                            sb.Append(' ');
                        }
                    }
                    sb.Append(next);

                    if (SimpleTokeniser.IsEndOfSentence(next))
                    {
                        break;
                    }

                    // slide the three token window forward
                    prevPrev = prev;
                    prev     = curr;
                    curr     = next;
                }
                Console.WriteLine(sb.ToString());
            }
        }
예제 #25
0
        /// <summary>
        /// Fills a two column ("actual", "expected") string table with a known number of rows per
        /// label pair, builds a confusion matrix from it and checks three of the resulting counts
        /// </summary>
        public void TableConfusionMatrix()
        {
            // number of rows to add for each (actual, expected) pair
            const int CAT_CAT       = 5;
            const int CAT_DOG       = 2;
            const int DOG_CAT       = 3;
            const int DOG_DOG       = 5;
            const int DOG_RABBIT    = 2;
            const int RABBIT_DOG    = 1;
            const int RABBIT_RABBIT = 11;

            var tableBuilder = BrightWireProvider.CreateDataTableBuilder();
            tableBuilder.AddColumn(ColumnType.String, "actual");
            tableBuilder.AddColumn(ColumnType.String, "expected");

            // appends the same row the requested number of times
            void AddRows(string actual, string expected, int count)
            {
                for (var added = 0; added < count; added++)
                {
                    tableBuilder.Add(actual, expected);
                }
            }

            AddRows("cat", "cat", CAT_CAT);
            AddRows("cat", "dog", CAT_DOG);
            AddRows("dog", "cat", DOG_CAT);
            AddRows("dog", "dog", DOG_DOG);
            AddRows("dog", "rabbit", DOG_RABBIT);
            AddRows("rabbit", "dog", RABBIT_DOG);
            AddRows("rabbit", "rabbit", RABBIT_RABBIT);

            var confusionMatrix = tableBuilder.Build().CreateConfusionMatrix(1, 0);

            // the xml form is read but its content is not inspected
            var xml = confusionMatrix.AsXml;

            Assert.AreEqual((uint)CAT_DOG, confusionMatrix.GetCount("cat", "dog"));
            Assert.AreEqual((uint)DOG_RABBIT, confusionMatrix.GetCount("dog", "rabbit"));
            Assert.AreEqual((uint)RABBIT_RABBIT, confusionMatrix.GetCount("rabbit", "rabbit"));
        }
예제 #26
0
        /// <summary>
        /// Builds a one row table with three string columns, the second flagged as the target,
        /// and checks the reported target column index, row count and column count
        /// </summary>
        public void TestTargetColumnIndex()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.String, "a");
            builder.AddColumn(ColumnType.String, "b", true);
            builder.AddColumn(ColumnType.String, "c");
            builder.Add("a", "b", "c");
            var table = builder.Build();

            // expected value first, actual second, so a failure message reports them the right way round
            Assert.AreEqual(1, table.TargetColumnIndex);
            Assert.AreEqual(1, table.RowCount);
            Assert.AreEqual(3, table.ColumnCount);
        }
예제 #27
0
 /// <summary>
 /// Obtains an engine from LoadTestingNetwork(modelPath, graph), executes its graph on the test
 /// data of the set built from testImage and returns GetLargestPercent of the first result
 /// </summary>
 public static int TestItems(string modelPath, Bitmap testImage)
 {
     using (var lap = BrightWireProvider.CreateLinearAlgebra(false))
     {
         var     graph       = new GraphFactory(lap);
         DataSet testDataset = BuildTestSet(graph, testImage);

         // the configuration is populated here but not handed to any call below
         var oneHotMetric = graph.ErrorMetric.OneHotEncoding;
         var config       = new NetworkConfig { ERROR_METRIC = oneHotMetric };

         var loadedEngine = LoadTestingNetwork(modelPath, graph);
         var results      = graph.CreateEngine(loadedEngine.Graph).Execute(testDataset.TestData);
         return GetLargestPercent(results[0]);
     }
 }
예제 #28
0
 /// <summary>
 /// Creates a dataset from dataFolderPath, builds a network with BuildNetwork and trains it with
 /// TrainModel, then executes the resulting graph on the dataset's test data
 /// </summary>
 /// <returns>The average one hot encoding error over the execution results</returns>
 public static float TrainCNN(string dataFolderPath, string outputModelPath)
 {
     using (var lap = BrightWireProvider.CreateLinearAlgebra(false))
     {
         var graph        = new GraphFactory(lap);
         var dataset      = CreateDataset(graph, dataFolderPath);
         var oneHotMetric = graph.ErrorMetric.OneHotEncoding;
         var config       = new NetworkConfig { ERROR_METRIC = oneHotMetric };

         var trainingEngine = BuildNetwork(config, graph, dataset, outputModelPath);
         var trainedGraph   = TrainModel(trainingEngine, config, dataset, outputModelPath);

         // use the engine's own graph when TrainModel returned null
         var graphToRun = trainedGraph ?? trainingEngine.Graph;
         var results    = graph.CreateEngine(graphToRun).Execute(dataset.TestData);
         return results.Average(result => result.CalculateError(oneHotMetric));
     }
 }
예제 #29
0
        /// <summary>
        /// Builds an autoencoder whose second layer is tied to the first, trains it for two passes
        /// on random vectors and executes the trained graph on the first row's input.
        /// Contains no assertions; it only exercises the calls.
        /// </summary>
        public void TiedAutoEncoder()
        {
            const int DATA_SIZE = 1000, REDUCED_SIZE = 200;

            // 100 rows of random data; each row uses the same vector as both input and target
            var random       = new Random();
            var tableBuilder = BrightWireProvider.CreateDataTableBuilder();

            tableBuilder.AddVectorColumn(DATA_SIZE, "Input");
            tableBuilder.AddVectorColumn(DATA_SIZE, "Output", true);
            for (var rowIndex = 0; rowIndex < 100; rowIndex++)
            {
                var sample = new FloatVector {
                    Data = Enumerable.Range(0, DATA_SIZE).Select(j => Convert.ToSingle(random.NextDouble())).ToArray()
                };
                tableBuilder.Add(sample, sample);
            }
            var table = tableBuilder.Build();

            // training engine over the table: learning rate 0.03, batch size 32
            var graph          = new GraphFactory(_lap);
            var source         = graph.CreateDataSource(table);
            var trainingEngine = graph.CreateTrainingEngine(source, 0.03f, 32);
            var metric         = graph.ErrorMetric.Quadratic;

            graph.CurrentPropertySet
                .Use(graph.RmsProp())
                .Use(graph.WeightInitialisation.Xavier);

            // the tied layer is looked up by name inside the chain, i.e. after "layer" has been added
            graph.Connect(trainingEngine)
                .AddFeedForward(REDUCED_SIZE, "layer")
                .Add(graph.TanhActivation())
                .AddTiedFeedForward(trainingEngine.Start.FindByName("layer") as IFeedForward)
                .Add(graph.TanhActivation())
                .AddBackpropagation(metric);

            // two training passes sharing one execution context
            using (var context = graph.CreateExecutionContext()) {
                for (var pass = 0; pass < 2; pass++)
                {
                    trainingEngine.Train(context);
                }
            }

            // run the trained graph on the first row's input vector
            var runner     = graph.CreateEngine(trainingEngine.Graph);
            var firstInput = table.GetRow(0).GetField <FloatVector>(0).Data;
            runner.Execute(firstInput);
        }
예제 #30
0
        /// <summary>
        /// Projects a four row table through a function that returns null for the row whose class
        /// is "b" and checks that the projected table keeps the column count and has three rows
        /// </summary>
        public void TableFilter()
        {
            var builder = BrightWireProvider.CreateDataTableBuilder();

            builder.AddColumn(ColumnType.Float, "val1");
            builder.AddColumn(ColumnType.Double, "val2");
            builder.AddColumn(ColumnType.String, "cls", true);
            builder.Add(0.5f, 1.1, "a");
            builder.Add(0.2f, 1.5, "b");
            builder.Add(0.7f, 0.5, "c");
            builder.Add(0.2f, 0.6, "d");
            var table          = builder.Build();
            var projectedTable = table.Project(r => r.GetField <string>(2) == "b" ? null : r.Data);

            // expected value first, actual second, so a failure message reports them the right way round
            Assert.AreEqual(table.ColumnCount, projectedTable.ColumnCount);
            Assert.AreEqual(3, projectedTable.RowCount);
        }