public void OneVariableLogisticRegressionRegularizedTest(double lambda)
        {
            // Fits a logistic regression on ResultExamsAdmission.txt using the supplied
            // regularization weight (`lambda`), then compares the fitted theta with
            // reference values. Column 0 of the data is the label; columns 1 and 2 are
            // the features.
            // NOTE(review): the expected theta below is the same for every `lambda`; if this
            // test is run for several lambda values, at most one can match — confirm against
            // the test-case attributes.
            var admissionDataPath = FileManagerTest.GetTestDataByName("ResultExamsAdmission.txt");
            var admissionReader   = new CsvReader<ResultExamsAdmission>();
            var admissionRows     = admissionReader.GetData(admissionDataPath, ",");

            var samples        = new Matrix(admissionRows.Select(row => new ResultExamsAdmissionRow(row)));
            var labels         = samples.Columns[0].ToMatrix();
            var featureColumns = new[] { 1, 2 };
            var features       = samples.ToMatrix(featureColumns);

            var solver = new LogisticRegression();

            // A freshly allocated double[3, 1] is already all zeros. Theta has one entry more
            // than the two feature columns (presumably an intercept added by the solver — confirm).
            var initialTheta = new Matrix(new double[3, 1]);

            var parameters = new LogisticRegressionParameters
            {
                IterationCount = 100,
                ThetaInit      = initialTheta,
                X              = features,
                Y              = labels,
                Lambda         = lambda
            };

            var fitted = solver.Compute(parameters);

            var expectedTheta = new Matrix(new double[3, 1]
            {
                { -25.052148050018360 },
                { 0.205354461994741 },
                { 0.200583555605940 }
            });

            AssertMatrixAreEqual(expectedTheta, fitted.Theta);
        }
Exemplo n.º 2
0
        public void SimpleNeuralNetworkTest()
        {
            // Trains a SimpleNeuralNetwork (400 inputs, 25 hidden units, 10 labels, 10 iterations)
            // on the first `sampleCount` rows of the HandwrittenDigit data, then runs Predict on
            // those same rows.
            // NOTE(review): `prediction` is never asserted, so this test only checks that training
            // and prediction finish without throwing. Consider comparing it with `labels` once
            // Predict's return type is confirmed.
            const int sampleCount = 100;
            const int labelCount  = 10;

            var pixelsPath = FileManagerTest.GetTestDataByName("HandwrittenDigit", "dataX.csv");
            var pixelRows  = new CsvDoubleReader().GetData(pixelsPath, ",");
            var inputs     = new Matrix(pixelRows.Take(sampleCount));

            var labelsPath = FileManagerTest.GetTestDataByName("HandwrittenDigit", "datay.csv");
            var labelRows  = new CsvIntReader().GetData(labelsPath, ",").Take(sampleCount).ToArray();
            // Each label row is read as a one-element record; keep only its first value.
            var labels     = labelRows.Select(row => row[0]).ToArray();
            // Presumably a one-hot encoding with one column per label — confirm in MatrixUtils.
            var targets    = MatrixUtils.VectorToBinaryMatrix(labels, labelCount);

            var parameters = new SimpleNeuralNetworkParameters
            {
                Alpha           = 1,
                HiddenLayerSize = 25,
                InputLayerSize  = 400,
                IterationCount  = 10,
                LabelCount      = labelCount,
                Lambda          = 1,
                X               = inputs,
                Y               = targets
            };

            var network  = new SimpleNeuralNetwork();
            var training = network.Compute(parameters);

            var prediction = network.Predict(training.Theta1, training.Theta2, inputs, inputs.RowCount);
        }
        public void OneVariableLogisticRegressionTest()
        {
            // Fits a logistic regression on ResultExamsAdmission.txt and compares the fitted theta
            // with reference values. Column 0 of the data is the label; columns 1 and 2 are the
            // features. Lambda is not set here, so LogisticRegressionParameters' default applies
            // (presumably no regularization — confirm).
            var admissionDataPath = FileManagerTest.GetTestDataByName("ResultExamsAdmission.txt");
            var admissionReader   = new CsvReader<ResultExamsAdmission>();
            var admissionRows     = admissionReader.GetData(admissionDataPath, ",");

            var samples        = new Matrix(admissionRows.Select(row => new ResultExamsAdmissionRow(row)));
            var labels         = samples.Columns[0].ToMatrix();
            var featureColumns = new[] { 1, 2 };
            var features       = samples.ToMatrix(featureColumns);

            var solver = new LogisticRegression();

            // A freshly allocated double[3, 1] is already all zeros.
            var initialTheta = new Matrix(new double[3, 1]);

            var parameters = new LogisticRegressionParameters
            {
                IterationCount = 100,
                ThetaInit      = initialTheta,
                X              = features,
                Y              = labels
            };

            var fitted = solver.Compute(parameters);

            var expectedTheta = new Matrix(new double[3, 1]
            {
                { -25.161333566639530 },
                { 0.206231713293983 },
                { 0.201471600441963 }
            });

            AssertMatrixAreEqual(expectedTheta, fitted.Theta);
        }