Example #1
        /// <summary>
        ///   Creates a Support Vector Machine and estimates
        ///   its parameters using a learning algorithm.
        /// </summary>
        /// 
        private void btnRunTraining_Click(object sender, EventArgs e)
        {
            if (dgvTrainingSource.Rows.Count == 0)
            {
                MessageBox.Show("Please load the training data before clicking this button");
                return;
            }

            lbStatus.Text = "Gathering data. This may take a while...";
            Application.DoEvents();



            // Extract inputs and outputs
            int rows = dgvTrainingSource.Rows.Count;
            double[][] input = new double[rows][];
            int[] output = new int[rows];
            for (int i = 0; i < rows; i++)
            {
                input[i] = (double[])dgvTrainingSource.Rows[i].Cells["colTrainingFeatures"].Value;
                output[i] = (int)dgvTrainingSource.Rows[i].Cells["colTrainingLabel"].Value;
            }

            // Create the chosen kernel function 
            // using the user interface parameters
            //
            IKernel kernel = createKernel();

            // Extract training parameters from the interface
            double complexity = (double)numComplexity.Value;
            double tolerance = (double)numTolerance.Value;
            int cacheSize = (int)numCache.Value;
            SelectionStrategy strategy = (SelectionStrategy)cbStrategy.SelectedItem;

            // Create the learning algorithm using the machine and the training data
            var ml = new MulticlassSupportVectorLearning<IKernel>()
            {
                // Configure the learning algorithm
                Learner = (param) => new SequentialMinimalOptimization<IKernel>()
                {
                    Complexity = complexity,
                    Tolerance = tolerance,
                    CacheSize = cacheSize,
                    Strategy = strategy,
                    Kernel = kernel
                }
            };


            lbStatus.Text = "Training the classifiers. This may take a (very) significant amount of time...";
            Application.DoEvents();

            Stopwatch sw = Stopwatch.StartNew();

            // Train the machines. It should take a while.
            ksvm = ml.Learn(input, output);

            // If we created a linear machine, compress the support vectors 
            // into one single parameter vector for increased performance:
            if (ksvm.Kernel is Linear)
            {
                ksvm.Compress();
            }

            sw.Stop();

            // Compute the mean classification (zero-one) error on the training set
            double error = new ZeroOneLoss(output)
            {
                Mean = true
            }.Loss(ksvm.Decide(input));


            lbStatus.Text = String.Format(
                "Training complete ({0} ms, error: {1}). Click Classify to test the classifiers.",
                sw.ElapsedMilliseconds, error);

            // Update the interface status
            btnClassifyVoting.Enabled = true;
            btnClassifyElimination.Enabled = true;
            btnCalibration.Enabled = true;


            // Populate the information tab with the machines
            dgvMachines.Rows.Clear();
            int k = 1;
            for (int i = 0; i < ksvm.NumberOfClasses; i++)
            {
                for (int j = 0; j < i; j++, k++)
                {
                    var machine = ksvm[i, j];

                    int sv = machine.SupportVectors == null ? 0 : machine.SupportVectors.Length;

                    int c = dgvMachines.Rows.Add(k, i + "-vs-" + j, sv, machine.Threshold);
                    dgvMachines.Rows[c].Tag = machine;
                }
            }

            // Approximate size in bytes =
            //   number of unique support vectors * number of doubles per support vector (1024) * size of a double
            int bytes = ksvm.SupportVectorUniqueCount * 1024 * sizeof(double);
            float megabytes = bytes / (1024f * 1024f);
            lbSize.Text = String.Format("{0} ({1} MB)", ksvm.SupportVectorUniqueCount, megabytes);
        }
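
        // Usage sketch (not in the original sample): once training completes, the trained
        // 'ksvm' field can classify a new feature vector directly. The method and variable
        // names below ('classifyExample', 'features', 'predicted') are hypothetical and only
        // illustrate the multi-class Decide call already used above.
        private int classifyExample(double[] features)
        {
            // Decide runs the pairwise machines and returns the index of the winning class
            int predicted = ksvm.Decide(features);
            return predicted;
        }
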
        /// <summary>
        ///   Trains a multi-class SVM on a small artificial dataset, then calibrates
        ///   its outputs to produce class probabilities (legacy learning API).
        /// </summary>
        /// 
        public void RunTest2()
        {
            double[][] inputs =
            {
                new double[] { 0, 1, 1, 0 }, // 0
                new double[] { 0, 1, 0, 0 }, // 0
                new double[] { 0, 0, 1, 0 }, // 0
                new double[] { 0, 1, 1, 0 }, // 0
                new double[] { 0, 1, 0, 0 }, // 0
                new double[] { 1, 0, 0, 0 }, // 1
                new double[] { 1, 0, 0, 0 }, // 1
                new double[] { 1, 0, 0, 1 }, // 1
                new double[] { 0, 0, 0, 1 }, // 1
                new double[] { 0, 0, 0, 1 }, // 1
                new double[] { 1, 1, 1, 1 }, // 2
                new double[] { 1, 0, 1, 1 }, // 2
                new double[] { 1, 1, 0, 1 }, // 2
                new double[] { 0, 1, 1, 1 }, // 2
                new double[] { 1, 1, 1, 1 }, // 2
            };

            int[] outputs =
            {
                0, 0, 0, 0, 0,
                1, 1, 1, 1, 1,
                2, 2, 2, 2, 2,
            };

            // Create a multi-class SVM for 4-dimensional inputs and 3 classes, using a linear kernel
            IKernel kernel = new Linear();
            var machine = new MulticlassSupportVectorMachine(4, kernel, 3);

            // Create the learning algorithm (legacy API), solving each pairwise sub-problem with SMO
            var target = new MulticlassSupportVectorLearning(machine, inputs, outputs);

            target.Algorithm = (svm, classInputs, classOutputs, i, j) =>
                new SequentialMinimalOptimization(svm, classInputs, classOutputs);

            // Train the machines; the training error should be zero
            double error1 = target.Run();
            Assert.AreEqual(0, error1);

            // Check that every training sample is classified correctly
            int[] actual = new int[outputs.Length];
            var paths = new Decision[outputs.Length][];
            for (int i = 0; i < actual.Length; i++)
            {
                actual[i] = machine.Decide(inputs[i]);
                paths[i] = machine.GetLastDecisionPath();
                Assert.AreEqual(outputs[i], actual[i]);
            }

            // Keep a copy of the machines before calibration
            var original = (MulticlassSupportVectorMachine)machine.Clone();

            // Recalibrate the same machines using Platt's probabilistic output calibration
            target.Algorithm = (svm, classInputs, classOutputs, i, j) =>
                new ProbabilisticOutputCalibration(svm, classInputs, classOutputs);

            double error2 = target.Run();
            Assert.AreEqual(0, error2);

            // The calibrated machines should still classify every sample correctly
            int[] actual2 = new int[outputs.Length];
            var paths2 = new Decision[outputs.Length][];
            for (int i = 0; i < actual2.Length; i++)
            {
                actual2[i] = machine.Decide(inputs[i]);
                paths2[i] = machine.GetLastDecisionPath();
                Assert.AreEqual(outputs[i], actual2[i]);
            }

            // Calibration should have changed the weights of the 2-vs-1 machine,
            // but not the decision it makes for a sample belonging to class 2.
            var svm21 = machine[2, 1];
            var org21 = original[2, 1];
            var probe = inputs[12];
            var w21 = svm21.Weights;
            var o21 = org21.Weights;
            Assert.IsFalse(w21.IsEqual(o21, rtol: 1e-2));
            bool b = svm21.Decide(probe);
            bool a = org21.Decide(probe);
            Assert.AreEqual(a, b);

            // After calibration, the machine can produce class probability estimates
            double[][] probabilities = machine.Probabilities(inputs);

            //string str = probabilities.ToString(CSharpJaggedMatrixFormatProvider.InvariantCulture);

            // Expected calibrated probabilities for each input (reference values)
            double[][] expected = new double[][]
            {
                new double[] { 0.978013252309678, 0.00665988562670578, 0.015326862063616 },
                new double[] { 0.923373734751393, 0.0433240974867644, 0.033302167761843 },
                new double[] { 0.902265207121918, 0.0651939200306017, 0.0325408728474804 },
                new double[] { 0.978013252309678, 0.00665988562670578, 0.015326862063616 },
                new double[] { 0.923373734751393, 0.0433240974867644, 0.033302167761843 },
                new double[] { 0.0437508203303804, 0.79994737664453, 0.156301803025089 },
                new double[] { 0.0437508203303804, 0.79994737664453, 0.156301803025089 },
                new double[] { 0.0147601290467641, 0.948443224264852, 0.0367966466883842 },
                new double[] { 0.0920231845129213, 0.875878175972548, 0.0320986395145312 },
                new double[] { 0.0920231845129213, 0.875878175972548, 0.0320986395145312 },
                new double[] { 0.00868243281954335, 0.00491075178001821, 0.986406815400439 },
                new double[] { 0.0144769600209954, 0.0552754387307989, 0.930247601248206 },
                new double[] { 0.0144769600209954, 0.0552754387307989, 0.930247601248206 },
                new double[] { 0.0584631682316073, 0.0122104663095354, 0.929326365458857 },
                new double[] { 0.00868243281954335, 0.00491075178001821, 0.986406815400439 } 
            };

            Assert.IsTrue(probabilities.IsEqual(expected, rtol: 1e-8));
        }
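
        // Sketch (not in the original test): the same calibration expressed with the newer
        // generic learning API used in btnRunTraining_Click above. This assumes the
        // Accord.NET 3.x pattern of passing the trained machine through the Model property;
        // the method name 'CalibrateSketch' and its parameters are hypothetical.
        public void CalibrateSketch(MulticlassSupportVectorMachine<IKernel> machine,
            double[][] inputs, int[] outputs)
        {
            // Wrap Platt's probabilistic calibration around each existing pairwise machine
            var calibration = new MulticlassSupportVectorLearning<IKernel>()
            {
                Model = machine, // start from the already-trained machine
                Learner = (param) => new ProbabilisticOutputCalibration<IKernel>()
                {
                    Model = param.Model // calibrate the existing sub-machine in place
                }
            };

            // Run the calibration pass over the training data
            calibration.Learn(inputs, outputs);

            // The machine can now produce calibrated class probabilities
            double[][] probabilities = machine.Probabilities(inputs);
        }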