/// <summary>
 /// Trains a <see cref="LinearClassifier{L, F}"/> on the supplied
 /// <see cref="GeneralDataset{L, F}"/>.
 /// </summary>
 /// <remarks>
 /// The dataset's label and feature indices are stored in <c>labelIndex</c> and
 /// <c>featureIndex</c> before <c>TrainWeights</c> is invoked.
 /// </remarks>
 /// <param name="data">The dataset to train on.</param>
 /// <returns>
 /// A <see cref="LinearClassifier{L, F}"/> constructed from the trained weights
 /// together with the feature and label indices.
 /// </returns>
 public virtual LinearClassifier <L, F> TrainClassifier(GeneralDataset <L, F> data)
 {
     // Record the indices first; the weight training call below runs after both are set.
     labelIndex   = data.LabelIndex();
     featureIndex = data.FeatureIndex();

     return new LinearClassifier<L, F>(TrainWeights(data), featureIndex, labelIndex);
 }
Ejemplo n.º 2
0
        /// <summary>
        /// Trains one binary classifier for each label in <paramref name="trainLabels"/> and
        /// combines them into a single one-vs-all classifier.
        /// </summary>
        /// <param name="classifierFactory">Factory used to train each binary (string-labelled) classifier.</param>
        /// <param name="dataset">The multi-label dataset the binary datasets are derived from.</param>
        /// <param name="trainLabels">The labels for which a binary classifier is trained.</param>
        /// <returns>
        /// A one-vs-all classifier built from the dataset's feature index, its label index,
        /// and the map from each trained label to its binary classifier.
        /// </returns>
        public static Edu.Stanford.Nlp.Classify.OneVsAllClassifier <L, F> Train <L, F>(IClassifierFactory <string, F, IClassifier <string, F> > classifierFactory, GeneralDataset <L, F> dataset, ICollection <L> trainLabels)
        {
            IIndex<L> labels   = dataset.LabelIndex();
            IIndex<F> features = dataset.FeatureIndex();
            IDictionary<L, IClassifier<string, F>> perLabelClassifiers = Generics.NewHashMap();

            foreach (L target in trainLabels)
            {
                int targetPosition = labels.IndexOf(target);
                logger.Info("Training " + target + " = " + targetPosition + ", posIndex = " + posIndex);

                // Build the binary view of the data for this label: the target label is mapped
                // to PosLabel, and NegLabel is handed to MapDataset as the other label
                // (presumably applied to every remaining label -- confirm against MapDataset).
                IDictionary<L, string> positiveMapping = new ArrayMap<L, string>();
                positiveMapping[target] = PosLabel;
                GeneralDataset<string, F> relabeled = dataset.MapDataset(dataset, binaryIndex, positiveMapping, NegLabel);

                perLabelClassifiers[target] = classifierFactory.TrainClassifier(relabeled);
            }

            return new Edu.Stanford.Nlp.Classify.OneVsAllClassifier<L, F>(features, labels, perLabelClassifiers);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Trains an SVM classifier by writing the dataset to a temporary file, running an
        /// external SVM learner on it, and reading the resulting model file back in.
        /// </summary>
        /// <remarks>
        /// The learner command is <c>svmStructLearn</c> when the dataset has more than two classes;
        /// otherwise it is <c>svmPerfLearn</c> or <c>svmLightLearn</c> depending on <c>useSVMPerf</c>.
        /// When <c>useAlphaFile</c> is set, the <c>alphaFile</c> field is replaced by a newly
        /// created temporary file. When <c>doEval</c> is set, the matching classify command is run
        /// on the training data and the classifier's own scores are written to a second temp file.
        /// </remarks>
        /// <param name="dataset">The dataset to train on.</param>
        /// <returns>
        /// The trained classifier; when <c>useSigmoid</c> is set, the result of <c>FitSigmoid</c>
        /// has been installed on it via <c>SetPlatt</c>.
        /// </returns>
        /// <exception cref="Exception">
        /// Wraps any exception raised while preparing the files, running the external tools,
        /// or reading the model; the original exception is available as the inner exception.
        /// </exception>
        public virtual SVMLightClassifier <L, F> TrainClassifierBasic(GeneralDataset <L, F> dataset)
        {
            IIndex <L> labelIndex   = dataset.LabelIndex();
            IIndex <F> featureIndex = dataset.featureIndex;
            bool       multiclass   = (dataset.NumClasses() > 2);

            try
            {
                // this is the file that the model will be saved to
                File modelFile = File.CreateTempFile("svm-", ".model");
                if (deleteTempFilesOnExit)
                {
                    modelFile.DeleteOnExit();
                }
                // this is the file that the svm light formatted dataset
                // will be printed to
                File dataFile = File.CreateTempFile("svm-", ".data");
                if (deleteTempFilesOnExit)
                {
                    dataFile.DeleteOnExit();
                }
                // print the dataset; close the writer even if printing fails so the
                // file handle is not leaked
                PrintWriter pw = new PrintWriter(new FileWriter(dataFile));
                try
                {
                    dataset.PrintSVMLightFormat(pw);
                }
                finally
                {
                    pw.Close();
                }
                // -v sets the verbosity (taken from svmLightVerbosity)
                // -m 400 gives it a larger cache, for faster training
                string cmd = (multiclass ? svmStructLearn : (useSVMPerf ? svmPerfLearn : svmLightLearn)) + " -v " + svmLightVerbosity + " -m 400 ";
                // set the value of C if we have one specified
                if (C > 0.0)
                {
                    // Format with the invariant culture: the value is parsed by an external
                    // tool, so a locale-specific decimal separator (e.g. "0,5") must not be used.
                    cmd = cmd + " -c " + C.ToString(System.Globalization.CultureInfo.InvariantCulture) + " ";
                }
                else
                {
                    // It's required to specify this parameter for SVM perf,
                    // so fall back to a fixed C value of 0.01.
                    if (useSVMPerf)
                    {
                        cmd = cmd + " -c 0.01 ";
                    }
                }
                // Alpha File
                if (useAlphaFile)
                {
                    File newAlphaFile = File.CreateTempFile("svm-", ".alphas");
                    if (deleteTempFilesOnExit)
                    {
                        newAlphaFile.DeleteOnExit();
                    }
                    // -a receives the new alpha file; -y receives the alpha file remembered
                    // from an earlier call, if there is one
                    cmd = cmd + " -a " + newAlphaFile.GetAbsolutePath();
                    if (alphaFile != null)
                    {
                        cmd = cmd + " -y " + alphaFile.GetAbsolutePath();
                    }
                    alphaFile = newAlphaFile;
                }
                // File and Model Data
                cmd = cmd + " " + dataFile.GetAbsolutePath() + " " + modelFile.GetAbsolutePath();
                if (verbose)
                {
                    logger.Info("<< " + cmd + " >>");
                }

                // Run the learner; both writers handed to SystemUtils.Run wrap the console error stream.
                SystemUtils.Run(new ProcessBuilder(whitespacePattern.Split(cmd)), new PrintWriter(System.Console.Error), new PrintWriter(System.Console.Error));
                if (doEval)
                {
                    File predictFile = File.CreateTempFile("svm-", ".pred");
                    if (deleteTempFilesOnExit)
                    {
                        predictFile.DeleteOnExit();
                    }
                    string evalCmd = (multiclass ? svmStructClassify : (useSVMPerf ? svmPerfClassify : svmLightClassify)) + " " + dataFile.GetAbsolutePath() + " " + modelFile.GetAbsolutePath() + " " + predictFile.GetAbsolutePath();
                    if (verbose)
                    {
                        logger.Info("<< " + evalCmd + " >>");
                    }
                    SystemUtils.Run(new ProcessBuilder(whitespacePattern.Split(evalCmd)), new PrintWriter(System.Console.Error), new PrintWriter(System.Console.Error));
                }
                // read in the model file
                Pair <double, ClassicCounter <int> > weightsAndThresh = ReadModel(modelFile, multiclass);
                double threshold = weightsAndThresh.First();
                ClassicCounter <Pair <F, L> > weights    = ConvertWeights(weightsAndThresh.Second(), featureIndex, labelIndex, multiclass);
                ClassicCounter <L>            thresholds = new ClassicCounter <L>();
                if (!multiclass)
                {
                    // binary case: the label at index 0 gets -threshold, the label at index 1 gets +threshold
                    thresholds.SetCount(labelIndex.Get(0), -threshold);
                    thresholds.SetCount(labelIndex.Get(1), threshold);
                }
                SVMLightClassifier <L, F> classifier = new SVMLightClassifier <L, F>(weights, thresholds);
                if (doEval)
                {
                    // write the scores produced by the in-memory classifier for every datum
                    File predictFile = File.CreateTempFile("svm-", ".pred2");
                    if (deleteTempFilesOnExit)
                    {
                        predictFile.DeleteOnExit();
                    }
                    PrintWriter pw2 = new PrintWriter(predictFile);
                    try
                    {
                        NumberFormat nf = NumberFormat.GetNumberInstance();
                        nf.SetMaximumFractionDigits(5);
                        foreach (IDatum <L, F> datum in dataset)
                        {
                            ICounter <L> scores = classifier.ScoresOf(datum);
                            pw2.Println(Counters.ToString(scores, nf));
                        }
                    }
                    finally
                    {
                        // close even when scoring or writing fails
                        pw2.Close();
                    }
                }
                if (useSigmoid)
                {
                    if (verbose)
                    {
                        System.Console.Out.Write("fitting sigmoid...");
                    }
                    classifier.SetPlatt(FitSigmoid(classifier, dataset));
                    if (verbose)
                    {
                        System.Console.Out.WriteLine("done");
                    }
                }
                return(classifier);
            }
            catch (Exception e)
            {
                // Wrap and rethrow, keeping the original exception as the inner exception.
                throw new Exception(e.Message, e);
            }
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Trains a one-vs-all classifier using every label in the dataset's label index.
        /// </summary>
        /// <param name="classifierFactory">Factory used to train each binary (string-labelled) classifier.</param>
        /// <param name="dataset">The dataset to train on; its label index supplies the labels to train.</param>
        /// <returns>The classifier returned by the three-argument <c>Train</c> overload.</returns>
        public static Edu.Stanford.Nlp.Classify.OneVsAllClassifier <L, F> Train <L, F>(IClassifierFactory <string, F, IClassifier <string, F> > classifierFactory, GeneralDataset <L, F> dataset)
        {
            // Delegate to the label-restricted overload, passing the full label list.
            var allLabels = dataset.LabelIndex().ObjectsList();
            return Train(classifierFactory, dataset, allLabels);
        }