示例#1
0
        /// <summary>
        /// Demo entry point: trains a linear classifier on a tiny hand-built stop-light
        /// data set, dumps the learned weights, and prints the predicted class and its
        /// justification for two held-out instances.
        /// </summary>
        /// <param name="args">Command-line arguments (not used).</param>
        public static void Main(string[] args)
        {
            // Seven training examples, added in a fixed order:
            // three (Green, Red), three (Red, Green), and one (Red, Red).
            IList<IDatum<string, string>> trainingData = new List<IDatum<string, string>>
            {
                MakeStopLights(Green, Red),
                MakeStopLights(Green, Red),
                MakeStopLights(Green, Red),
                MakeStopLights(Red, Green),
                MakeStopLights(Red, Green),
                MakeStopLights(Red, Green),
                MakeStopLights(Red, Red)
            };

            // Held-out instances that are classified after training.
            IDatum<string, string> workingLights = MakeStopLights(Green, Red);
            IDatum<string, string> brokenLights = MakeStopLights(Red, Red);

            // Factory setup: conjugate-gradient optimizer, per-iteration convergence
            // output, and a small amount of smoothing (sigma = 10.0).
            var factory = new LinearClassifierFactory<string, string>();
            factory.UseConjugateGradientAscent();
            factory.SetVerbose(true);
            factory.SetSigma(10.0);

            var classifier = factory.TrainClassifier(trainingData);

            // Show the learned feature weights.
            classifier.Dump();

            // Classify each held-out instance, then ask the classifier to justify it.
            System.Console.Out.WriteLine("Working instance got: " + classifier.ClassOf(workingLights));
            classifier.JustificationOf(workingLights);

            System.Console.Out.WriteLine("Broken instance got: " + classifier.ClassOf(brokenLights));
            classifier.JustificationOf(brokenLights);
        }
        /// <summary>
        /// Builds a three-example symptom dataset and checks its feature, class, token and
        /// size counts before and after applying a feature-count threshold of 2. It then trains
        /// a linear classifier on the data. Finally it checks the label and class probabilities
        /// assigned to an unseen datum with the features "cough" and "fever".
        /// </summary>
        public static void TestDataset()
        {
            Dataset<string, string> data = new Dataset<string, string>();

            data.Add(new BasicDatum<string, string>(Arrays.AsList(new string[] { "fever", "cough", "congestion" }), "cold"));
            data.Add(new BasicDatum<string, string>(Arrays.AsList(new string[] { "fever", "cough", "nausea" }), "flu"));
            data.Add(new BasicDatum<string, string>(Arrays.AsList(new string[] { "cough", "congestion" }), "cold"));

            // Feature occurrences: fever x2, cough x3, congestion x2, nausea x1
            // -> 4 distinct features, 8 feature tokens, 2 classes, 3 examples.
            NUnit.Framework.Assert.AreEqual(4, data.NumFeatures());
            NUnit.Framework.Assert.AreEqual(4, data.NumFeatureTypes());
            NUnit.Framework.Assert.AreEqual(2, data.NumClasses());
            NUnit.Framework.Assert.AreEqual(8, data.NumFeatureTokens());
            NUnit.Framework.Assert.AreEqual(3, data.Size());

            // With a threshold of 2 the expected counts below correspond to dropping only
            // "nausea" (seen once): 3 distinct features and 7 tokens remain.
            data.ApplyFeatureCountThreshold(2);
            NUnit.Framework.Assert.AreEqual(3, data.NumFeatures());
            NUnit.Framework.Assert.AreEqual(3, data.NumFeatureTypes());
            NUnit.Framework.Assert.AreEqual(2, data.NumClasses());
            NUnit.Framework.Assert.AreEqual(7, data.NumFeatureTokens());
            NUnit.Framework.Assert.AreEqual(3, data.Size());

            LinearClassifierFactory<string, string> factory = new LinearClassifierFactory<string, string>();
            LinearClassifier<string, string> classifier = factory.TrainClassifier(data);
            IDatum<string, string> d = new BasicDatum<string, string>(Arrays.AsList(new string[] { "cough", "fever" }));

            // NUnit's AreEqual is (expected, actual[, delta], message). The JUnit-style
            // (message, expected, actual) order would compare the message text itself.
            NUnit.Framework.Assert.AreEqual("flu", classifier.ClassOf(d), "Classification incorrect");
            ICounter<string> probs = classifier.ProbabilityOf(d);

            NUnit.Framework.Assert.AreEqual(0.4553, probs.GetCount("cold"), 0.0001, "Returned probability incorrect");
            NUnit.Framework.Assert.AreEqual(0.5447, probs.GetCount("flu"), 0.0001, "Returned probability incorrect");
            System.Console.Out.WriteLine();
        }
示例#3
0
        /// <summary>
        /// Trains a classifier from the cheese/disease example files and round-trips it
        /// through an in-memory serialized byte array. It then prints, for every line of the
        /// test file, the predicted class and its score from both the original and the
        /// deserialized classifier so the two can be compared.
        /// </summary>
        /// <exception cref="System.IO.IOException"/>
        /// <exception cref="System.TypeLoadException"/>
        private static void DemonstrateSerialization()
        {
            System.Console.Out.WriteLine();
            System.Console.Out.WriteLine("Demonstrating working with a serialized classifier");
            ColumnDataClassifier cdc = new ColumnDataClassifier(where + "examples/cheese2007.prop");
            IClassifier<string, string> cl = cdc.MakeClassifier(cdc.ReadTrainingExamples(where + "examples/cheeseDisease.train"));

            // Exhibit serialization and deserialization working. Serialized to bytes in memory for simplicity
            System.Console.Out.WriteLine();
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(baos);
            try
            {
                oos.WriteObject(cl);
            }
            finally
            {
                // Close the stream even if WriteObject throws.
                oos.Close();
            }
            // Read the bytes only after the output stream has been closed (same order as before).
            byte[] @object = baos.ToByteArray();

            ByteArrayInputStream bais = new ByteArrayInputStream(@object);
            ObjectInputStream ois = new ObjectInputStream(bais);
            LinearClassifier<string, string> lc;
            try
            {
                lc = ErasureUtils.UncheckedCast(ois.ReadObject());
            }
            finally
            {
                // Close the stream even if ReadObject or the cast throws.
                ois.Close();
            }
            ColumnDataClassifier cdc2 = new ColumnDataClassifier(where + "examples/cheese2007.prop");

            // We compare the output of the deserialized classifier lc versus the original one cl.
            // Each classifier is scored on the datum built by its own ColumnDataClassifier.
            System.Console.Out.WriteLine();
            System.Console.Out.WriteLine("Making predictions with both classifiers");
            foreach (string line in ObjectBank.GetLineIterator(where + "examples/cheeseDisease.test", "utf-8"))
            {
                IDatum<string, string> d = cdc.MakeDatumFromLine(line);
                IDatum<string, string> d2 = cdc2.MakeDatumFromLine(line);

                string origClass = cl.ClassOf(d);
                System.Console.Out.Printf("%s  =origi=>  %s (%.4f)%n", line, origClass, cl.ScoresOf(d).GetCount(origClass));

                // lc must be scored on d2 (its own datum), matching the class printed for it.
                string deserClass = lc.ClassOf(d2);
                System.Console.Out.Printf("%s  =deser=>  %s (%.4f)%n", line, deserClass, lc.ScoresOf(d2).GetCount(deserClass));
            }
        }
示例#4
0
        /// <summary>Train a multinomial classifier off of the provided dataset.</summary>
        /// <remarks>
        /// The feature-count threshold and the randomization are both invoked directly on
        /// <paramref name="dataset"/>, so the caller's instance is modified. Accuracy of the
        /// trained classifier on the same (training) data is written to the log.
        /// </remarks>
        /// <param name="dataset">The dataset to train the classifier off of. Modified in place.</param>
        /// <param name="featureThreshold">Value passed to <c>ApplyFeatureCountThreshold</c> before training.</param>
        /// <param name="sigma">Value passed to <c>InitFactory</c> to configure the classifier factory.</param>
        /// <returns>A classifier.</returns>
        public static IClassifier<string, string> TrainMultinomialClassifier(GeneralDataset<string, string> dataset, int featureThreshold, double sigma)
        {
            // Fixed shuffle seed, presumably so repeated runs see the same example order.
            // Uppercase 'L' suffix: a lowercase 'l' is easily misread as '1' (CS0078).
            const long ShuffleSeed = 42L;

            // Set up the dataset and factory
            log.Info("Applying feature threshold (" + featureThreshold + ")...");
            dataset.ApplyFeatureCountThreshold(featureThreshold);
            log.Info("Randomizing dataset...");
            dataset.Randomize(ShuffleSeed);
            log.Info("Creating factory...");
            LinearClassifierFactory<string, string> factory = InitFactory(sigma);

            // Train the final classifier
            log.Info("BEGIN training");
            LinearClassifier<string, string> classifier = factory.TrainClassifier(dataset);
            log.Info("END training");

            // Debug: measure accuracy on the training data itself.
            KBPRelationExtractor.Accuracy trainAccuracy = new KBPRelationExtractor.Accuracy();
            foreach (IDatum<string, string> datum in dataset)
            {
                string guess = classifier.ClassOf(datum);
                trainAccuracy.Predict(Java.Util.Collections.Singleton(guess), Java.Util.Collections.Singleton(datum.Label()));
            }
            log.Info("Training accuracy:");
            log.Info(trainAccuracy.ToString());
            log.Info(string.Empty);

            return classifier;
        }