Пример #1
0
        /// <summary>
        /// Trains a linear model over the classifier's raw scores so that those scores
        /// can later be turned into probabilities.
        /// </summary>
        /// <param name="classifier">Classifier whose <c>ScoresOf</c> output supplies the features of each training example.</param>
        /// <param name="dataset">Data whose items are scored one by one and whose labels become the training targets.</param>
        /// <returns>The model trained by a <c>LinearClassifierFactory</c> configured with a <c>Null</c>-type log prior.</returns>
        private LinearClassifier <L, L> FitSigmoid(SVMLightClassifier <L, F> classifier, GeneralDataset <L, F> dataset)
        {
            RVFDataset <L, L> scoredExamples = new RVFDataset <L, L>();
            int position = 0;

            while (position < dataset.Size())
            {
                RVFDatum <L, F> example     = dataset.GetRVFDatum(position);
                ICounter <L>    labelScores = classifier.ScoresOf((IDatum <L, F>)example);

                // The null key is incremented on every example; presumably it acts as a
                // constant bias feature for the fitted model -- confirm against the counter implementation.
                labelScores.IncrementCount(null);

                RVFDatum <L, L> scoredExample = new RVFDatum <L, L>(labelScores, example.Label());
                scoredExamples.Add(scoredExample);
                position++;
            }

            LinearClassifierFactory <L, L> sigmoidTrainer = new LinearClassifierFactory <L, L>();
            sigmoidTrainer.SetPrior(new LogPrior(LogPrior.LogPriorType.Null));

            return sigmoidTrainer.TrainClassifier(scoredExamples);
        }
Пример #2
0
        /// <summary>
        /// Trains an SVM by writing the dataset to a temporary file in SVM-light format,
        /// running the external learner executable on it, and reading the produced model
        /// file back into an <c>SVMLightClassifier</c>.
        /// </summary>
        /// <remarks>
        /// The learner command is <c>svmStructLearn</c> when the dataset has more than two
        /// classes, otherwise <c>svmPerfLearn</c> (when <c>useSVMPerf</c> is set) or
        /// <c>svmLightLearn</c>. When <c>useAlphaFile</c> is set, the <c>alphaFile</c> field
        /// is replaced by a newly created temp file as a side effect. When <c>doEval</c> is
        /// set, the external classify command is run on the training data and this
        /// classifier's own scores are written to a second temp file.
        /// </remarks>
        /// <param name="dataset">Training data; also supplies the label and feature indices used to decode the model.</param>
        /// <returns>
        /// The classifier built from the weights and threshold read from the model file.
        /// When <c>useSigmoid</c> is set, a model from <c>FitSigmoid</c> is attached via <c>SetPlatt</c>.
        /// </returns>
        /// <exception cref="Exception">Any failure during training is rethrown wrapped, with the original exception as <c>InnerException</c>.</exception>
        public virtual SVMLightClassifier <L, F> TrainClassifierBasic(GeneralDataset <L, F> dataset)
        {
            IIndex <L> labelIndex   = dataset.LabelIndex();
            IIndex <F> featureIndex = dataset.featureIndex;
            bool       multiclass   = (dataset.NumClasses() > 2);

            try
            {
                // this is the file that the model will be saved to
                File modelFile = CreateSvmTempFile(".model");
                // this is the file that the svm light formatted dataset will be printed to
                File dataFile = CreateSvmTempFile(".data");

                // print the dataset; close the writer even if printing fails
                PrintWriter pw = new PrintWriter(new FileWriter(dataFile));
                try
                {
                    dataset.PrintSVMLightFormat(pw);
                }
                finally
                {
                    pw.Close();
                }

                // -v 0 makes it not verbose
                // -m 400 gives it a larger cache, for faster training
                string cmd = (multiclass ? svmStructLearn : (useSVMPerf ? svmPerfLearn : svmLightLearn)) + " -v " + svmLightVerbosity + " -m 400 ";

                // set the value of C if we have one specified
                if (C > 0.0)
                {
                    // Format with the invariant culture: the current culture may use a decimal
                    // comma, which would corrupt the command line.
                    cmd = cmd + " -c " + C.ToString(System.Globalization.CultureInfo.InvariantCulture) + " ";
                }
                else if (useSVMPerf)
                {
                    // It's required to specify this parameter for SVM perf
                    cmd = cmd + " -c 0.01 ";
                }

                // Alpha File
                if (useAlphaFile)
                {
                    File newAlphaFile = CreateSvmTempFile(".alphas");
                    cmd = cmd + " -a " + newAlphaFile.GetAbsolutePath();
                    if (alphaFile != null)
                    {
                        // pass the alpha file remembered from the previous call via -y
                        cmd = cmd + " -y " + alphaFile.GetAbsolutePath();
                    }
                    alphaFile = newAlphaFile;
                }

                // File and Model Data
                cmd = cmd + " " + dataFile.GetAbsolutePath() + " " + modelFile.GetAbsolutePath();
                if (verbose)
                {
                    logger.Info("<< " + cmd + " >>");
                }

                // run the learner, sending both of its output streams to standard error
                SystemUtils.Run(new ProcessBuilder(whitespacePattern.Split(cmd)), new PrintWriter(System.Console.Error), new PrintWriter(System.Console.Error));

                if (doEval)
                {
                    File predictFile = CreateSvmTempFile(".pred");
                    string evalCmd = (multiclass ? svmStructClassify : (useSVMPerf ? svmPerfClassify : svmLightClassify)) + " " + dataFile.GetAbsolutePath() + " " + modelFile.GetAbsolutePath() + " " + predictFile.GetAbsolutePath();
                    if (verbose)
                    {
                        logger.Info("<< " + evalCmd + " >>");
                    }
                    SystemUtils.Run(new ProcessBuilder(whitespacePattern.Split(evalCmd)), new PrintWriter(System.Console.Error), new PrintWriter(System.Console.Error));
                }

                // read in the model file
                Pair <double, ClassicCounter <int> > weightsAndThresh = ReadModel(modelFile, multiclass);
                double threshold = weightsAndThresh.First();
                ClassicCounter <Pair <F, L> > weights    = ConvertWeights(weightsAndThresh.Second(), featureIndex, labelIndex, multiclass);
                ClassicCounter <L>            thresholds = new ClassicCounter <L>();
                if (!multiclass)
                {
                    // binary case: the two labels get the threshold with opposite signs
                    thresholds.SetCount(labelIndex.Get(0), -threshold);
                    thresholds.SetCount(labelIndex.Get(1), threshold);
                }
                SVMLightClassifier <L, F> classifier = new SVMLightClassifier <L, F>(weights, thresholds);

                if (doEval)
                {
                    // write this classifier's own scores for every datum to a second temp file
                    File         predictFile = CreateSvmTempFile(".pred2");
                    PrintWriter  pw2         = new PrintWriter(predictFile);
                    try
                    {
                        NumberFormat nf = NumberFormat.GetNumberInstance();
                        nf.SetMaximumFractionDigits(5);
                        foreach (IDatum <L, F> datum in dataset)
                        {
                            ICounter <L> scores = classifier.ScoresOf(datum);
                            pw2.Println(Counters.ToString(scores, nf));
                        }
                    }
                    finally
                    {
                        pw2.Close();
                    }
                }

                if (useSigmoid)
                {
                    if (verbose)
                    {
                        System.Console.Out.Write("fitting sigmoid...");
                    }
                    classifier.SetPlatt(FitSigmoid(classifier, dataset));
                    if (verbose)
                    {
                        System.Console.Out.WriteLine("done");
                    }
                }
                return(classifier);
            }
            catch (Exception e)
            {
                // System.Exception has no constructor that takes only an exception;
                // pass the message and keep the original as InnerException.
                throw new Exception(e.Message, e);
            }
        }

        /// <summary>
        /// Creates a temp file with prefix "svm-" and the given suffix, and marks it
        /// delete-on-exit when <c>deleteTempFilesOnExit</c> is set.
        /// </summary>
        /// <param name="suffix">File name suffix, including the leading dot (for example ".model").</param>
        /// <returns>The newly created temp file.</returns>
        private File CreateSvmTempFile(string suffix)
        {
            File tempFile = File.CreateTempFile("svm-", suffix);
            if (deleteTempFilesOnExit)
            {
                tempFile.DeleteOnExit();
            }
            return(tempFile);
        }