예제 #1
0
        /// <summary>
        /// Trains this binary logistic classifier in place on <paramref name="data"/>, handing
        /// <paramref name="dataWeights"/> to the objective function.
        /// </summary>
        /// <remarks>
        /// Use LogisticClassifierFactory to train instead.
        /// On success the <c>weights</c>, <c>featureIndex</c> and <c>classes</c> members are overwritten.
        /// A plain Exception is thrown when the label index does not hold exactly two labels.
        /// </remarks>
        /// <param name="data">Training data; its label index must contain exactly two labels.</param>
        /// <param name="dataWeights">Passed unchanged to the LogisticObjectiveFunction constructor.</param>
        /// <exception cref="System.ArgumentException">
        /// <paramref name="data"/> is neither a Dataset nor an RVFDataset.
        /// </exception>
        public virtual void TrainWeightedData(GeneralDataset <L, F> data, float[] dataWeights)
        {
            if (data.labelIndex.Size() != 2)
            {
                throw new Exception("LogisticClassifier is only for binary classification!");
            }

            // Generic classes are invariant: a test against <object, object> only succeeds when
            // L and F are both object, so the checks use the method's own type arguments.
            LogisticObjectiveFunction lof;
            if (data is Dataset <L, F> )
            {
                lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetLabelsArray(), prior, dataWeights);
            }
            else if (data is RVFDataset <L, F> )
            {
                lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetValuesArray(), data.GetLabelsArray(), prior, dataWeights);
            }
            else
            {
                // Fail here instead of handing a null objective to the minimizer.
                throw new System.ArgumentException("Unsupported dataset type: " + data.GetType(), "data");
            }

            IMinimizer <IDiffFunction> minim = new QNMinimizer(lof);
            weights      = minim.Minimize(lof, 1e-4, new double[data.NumFeatureTypes()]);
            featureIndex = data.featureIndex;
            classes[0]   = data.labelIndex.Get(0);
            classes[1]   = data.labelIndex.Get(1);
        }
예제 #2
0
        /// <summary>Trains a naive Bayes classifier from examples read as RVFDatum instances.</summary>
        /// <remarks>
        /// The examples are assumed to be a list of RVFDatum.
        /// The datums are assumed to not contain the zeroes and then they are added to each instance:
        /// every example gets a dense row of <c>featureSet.Count</c> entries, and entries for features
        /// the datum does not mention stay 0.
        /// The <c>labelIndex</c> and <c>featureIndex</c> members are replaced by fresh indices.
        /// </remarks>
        /// <param name="examples">Dataset whose items are read through <c>GetRVFDatum</c>.</param>
        /// <param name="featureSet">Features to index; each datum feature is looked up in this index.</param>
        /// <returns>The classifier produced by the array-based TrainClassifier overload.</returns>
        public virtual NaiveBayesClassifier <L, F> TrainClassifier(GeneralDataset <L, F> examples, ICollection <F> featureSet)
        {
            int numFeatures = featureSet.Count;
            int numExamples = examples.Size();

            // Allocate one row per example up front; an empty jagged array would make
            // every data[d][fNo] write below throw.
            int[][] data = new int[numExamples][];
            for (int row = 0; row < numExamples; row++)
            {
                data[row] = new int[numFeatures];
            }
            int[] labels = new int[numExamples];
            labelIndex   = new HashIndex <L>();
            featureIndex = new HashIndex <F>();
            foreach (F feat in featureSet)
            {
                featureIndex.Add(feat);
            }
            for (int d = 0; d < numExamples; d++)
            {
                RVFDatum <L, F> datum = examples.GetRVFDatum(d);
                ICounter <F>    c     = datum.AsFeaturesCounter();
                foreach (F feature in c.KeySet())
                {
                    int fNo   = featureIndex.IndexOf(feature);
                    int value = (int)c.GetCount(feature);   // counts are truncated to int
                    data[d][fNo] = value;
                }
                labelIndex.Add(datum.Label());
                labels[d] = labelIndex.IndexOf(datum.Label());
            }
            int numClasses = labelIndex.Size();

            return(TrainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex));
        }
 /// <summary>
 /// Trains a <see cref="LinearClassifier{L, F}"/> on the given dataset.
 /// </summary>
 /// <param name="data">Dataset supplying the label index, the feature index and the training examples.</param>
 /// <returns>
 /// A <see cref="LinearClassifier{L, F}"/> built from the trained weights and the dataset's indices.
 /// </returns>
 public virtual LinearClassifier<L, F> TrainClassifier(GeneralDataset<L, F> data)
 {
     // The index members are assigned before TrainWeights runs.
     labelIndex = data.LabelIndex();
     featureIndex = data.FeatureIndex();

     double[][] trained = TrainWeights(data);
     LinearClassifier<L, F> result = new LinearClassifier<L, F>(trained, featureIndex, labelIndex);
     return result;
 }
예제 #4
0
        /// <summary>
        /// Trains a binary logistic classifier on <paramref name="data"/> with a quadratic prior,
        /// handing <paramref name="dataWeights"/> to the objective function.
        /// </summary>
        /// <remarks>
        /// The <c>weights</c>, <c>featureIndex</c> and <c>classes</c> members are overwritten as a side effect.
        /// A plain Exception is thrown when the label index does not hold exactly two labels.
        /// </remarks>
        /// <param name="data">Training data; its label index must contain exactly two labels.</param>
        /// <param name="dataWeights">Passed unchanged to the LogisticObjectiveFunction constructor.</param>
        /// <returns>A classifier built from the trained weights, the feature index and the two classes.</returns>
        /// <exception cref="System.ArgumentException">
        /// <paramref name="data"/> is neither a Dataset nor an RVFDataset.
        /// </exception>
        public virtual LogisticClassifier <L, F> TrainWeightedData(GeneralDataset <L, F> data, float[] dataWeights)
        {
            // The type test names the same closed generic type as the cast that follows it.
            if (data is RVFDataset <L, F> )
            {
                ((RVFDataset <L, F>)data).EnsureRealValues();
            }
            if (data.labelIndex.Size() != 2)
            {
                throw new Exception("LogisticClassifier is only for binary classification!");
            }

            // Generic classes are invariant: a test against <object, object> only succeeds when
            // L and F are both object, so the checks use the method's own type arguments.
            LogisticObjectiveFunction lof;
            if (data is Dataset <L, F> )
            {
                lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetLabelsArray(), new LogPrior(LogPrior.LogPriorType.Quadratic), dataWeights);
            }
            else if (data is RVFDataset <L, F> )
            {
                lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetValuesArray(), data.GetLabelsArray(), new LogPrior(LogPrior.LogPriorType.Quadratic), dataWeights);
            }
            else
            {
                // Fail here instead of handing a null objective to the minimizer.
                throw new System.ArgumentException("Unsupported dataset type: " + data.GetType(), "data");
            }

            IMinimizer <IDiffFunction> minim = new QNMinimizer(lof);
            weights      = minim.Minimize(lof, 1e-4, new double[data.NumFeatureTypes()]);
            featureIndex = data.featureIndex;
            classes[0]   = data.labelIndex.Get(0);
            classes[1]   = data.labelIndex.Get(1);
            return(new LogisticClassifier <L, F>(weights, featureIndex, classes));
        }
예제 #5
0
 //  public static void main(String[] args) {
 //    List examples = new ArrayList();
 //    String leftLight = "leftLight";
 //    String rightLight = "rightLight";
 //    String broken = "BROKEN";
 //    String ok = "OK";
 //    Counter c1 = new ClassicCounter<>();
 //    c1.incrementCount(leftLight, 0);
 //    c1.incrementCount(rightLight, 0);
 //    RVFDatum d1 = new RVFDatum(c1, broken);
 //    examples.add(d1);
 //    Counter c2 = new ClassicCounter<>();
 //    c2.incrementCount(leftLight, 1);
 //    c2.incrementCount(rightLight, 1);
 //    RVFDatum d2 = new RVFDatum(c2, ok);
 //    examples.add(d2);
 //    Counter c3 = new ClassicCounter<>();
 //    c3.incrementCount(leftLight, 0);
 //    c3.incrementCount(rightLight, 1);
 //    RVFDatum d3 = new RVFDatum(c3, ok);
 //    examples.add(d3);
 //    Counter c4 = new ClassicCounter<>();
 //    c4.incrementCount(leftLight, 1);
 //    c4.incrementCount(rightLight, 0);
 //    RVFDatum d4 = new RVFDatum(c4, ok);
 //    examples.add(d4);
 //    Dataset data = new Dataset(examples.size());
 //    data.addAll(examples);
 //    NaiveBayesClassifier classifier = (NaiveBayesClassifier)
 //        new NaiveBayesClassifierFactory(200, 200, 1.0,
 //              LogPrior.LogPriorType.QUADRATIC.ordinal(),
 //              NaiveBayesClassifierFactory.CL)
 //            .trainClassifier(data);
 //    classifier.print();
 //    //now classify
 //    for (int i = 0; i < examples.size(); i++) {
 //      RVFDatum d = (RVFDatum) examples.get(i);
 //      Counter scores = classifier.scoresOf(d);
 //      System.out.println("for datum " + d + " scores are " + scores.toString());
 //      System.out.println(" class is " + Counters.topKeys(scores, 1));
 //      System.out.println(" class should be " + d.label());
 //    }
 //  }
 //    String trainFile = args[0];
 //    String testFile = args[1];
 //    NominalDataReader nR = new NominalDataReader();
 //    Map<Integer, Index<String>> indices = Generics.newHashMap();
 //    List<RVFDatum<String, Integer>> train = nR.readData(trainFile, indices);
 //    List<RVFDatum<String, Integer>> test = nR.readData(testFile, indices);
 //    System.out.println("Constrained conditional likelihood no prior :");
 //    for (int j = 0; j < 100; j++) {
 //      NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.CL).trainClassifier(train);
 //      classifier.print();
 //      //now classify
 //
 //      float accTrain = classifier.accuracy(train.iterator());
 //      log.info("training accuracy " + accTrain);
 //      float accTest = classifier.accuracy(test.iterator());
 //      log.info("test accuracy " + accTest);
 //
 //    }
 //    System.out.println("Unconstrained conditional likelihood no prior :");
 //    for (int j = 0; j < 100; j++) {
 //      NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.UCL).trainClassifier(train);
 //      classifier.print();
 //      //now classify
 //
 //      float accTrain = classifier.accuracy(train.iterator());
 //      log.info("training accuracy " + accTrain);
 //      float accTest = classifier.accuracy(test.iterator());
 //      log.info("test accuracy " + accTest);
 //    }
 //  }
 /// <summary>
 /// Trains a naive Bayes classifier from the arrays and indices held by <paramref name="dataset"/>.
 /// </summary>
 /// <remarks>
 /// RVFDataset input is rejected with a plain Exception; the message says it is unverified
 /// whether such datasets run correctly through this method.
 /// </remarks>
 /// <param name="dataset">Dataset supplying the data array, labels, counts and indices.</param>
 /// <returns>The classifier produced by the array-based TrainClassifier overload.</returns>
 public virtual NaiveBayesClassifier <L, F> TrainClassifier(GeneralDataset <L, F> dataset)
 {
     // RVFDataset is generic, so the test must name the closed type RVFDataset<L, F>.
     if (dataset is RVFDataset <L, F> )
     {
         throw new Exception("Not sure if RVFDataset runs correctly in this method. Please update this code if it does.");
     }
     return(TrainClassifier(dataset.GetDataArray(), dataset.labels, dataset.NumFeatures(), dataset.NumClasses(), dataset.labelIndex, dataset.featureIndex));
 }
예제 #6
0
 /// <summary>
 /// Creates a cross validator over <paramref name="trainData"/> with <paramref name="kFold"/> folds
 /// and one freshly constructed saved state per fold.
 /// </summary>
 /// <param name="trainData">Dataset stored as the original training data.</param>
 /// <param name="kFold">Number of folds; also the length of the saved-state array.</param>
 public CrossValidator(GeneralDataset<L, F> trainData, int kFold)
 {
     originalTrainData = trainData;
     this.kFold = kFold;

     // Build the per-fold states locally, then publish the filled array.
     CrossValidator.SavedState[] perFold = new CrossValidator.SavedState[kFold];
     int fold = 0;
     while (fold < perFold.Length)
     {
         perFold[fold] = new CrossValidator.SavedState();
         fold++;
     }
     savedStates = perFold;
 }
예제 #7
0
        /// <summary>
        /// Trains a binary logistic classifier, optionally with the biased objective and/or an
        /// L1-regularized minimizer.
        /// </summary>
        /// <remarks>
        /// The <c>weights</c>, <c>featureIndex</c> and <c>classes</c> members are overwritten as a side effect.
        /// A plain Exception is thrown when the label index does not hold exactly two labels.
        /// </remarks>
        /// <param name="data">Training data; its label index must contain exactly two labels.</param>
        /// <param name="l1reg">When greater than 0.0, the OWLQN minimizer is loaded by reflection with this value; otherwise QNMinimizer is used.</param>
        /// <param name="tol">Tolerance passed to the minimizer.</param>
        /// <param name="prior">Prior passed to the objective function.</param>
        /// <param name="biased">When true, BiasedLogisticObjectiveFunction is used instead of LogisticObjectiveFunction.</param>
        /// <returns>A classifier built from the trained weights, the feature index and the two classes.</returns>
        /// <exception cref="System.ArgumentException">
        /// <paramref name="biased"/> is false and <paramref name="data"/> is neither a Dataset nor an RVFDataset.
        /// </exception>
        public virtual LogisticClassifier <L, F> TrainClassifier(GeneralDataset <L, F> data, double l1reg, double tol, LogPrior prior, bool biased)
        {
            // The type test names the same closed generic type as the cast that follows it.
            if (data is RVFDataset <L, F> )
            {
                ((RVFDataset <L, F>)data).EnsureRealValues();
            }
            if (data.labelIndex.Size() != 2)
            {
                throw new Exception("LogisticClassifier is only for binary classification!");
            }
            IMinimizer <IDiffFunction> minim;

            if (!biased)
            {
                // Generic classes are invariant: a test against <object, object> only succeeds when
                // L and F are both object, so the checks use the method's own type arguments.
                LogisticObjectiveFunction lof;
                if (data is Dataset <L, F> )
                {
                    lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetLabelsArray(), prior);
                }
                else if (data is RVFDataset <L, F> )
                {
                    lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetValuesArray(), data.GetLabelsArray(), prior);
                }
                else
                {
                    // Fail here instead of handing a null objective to the minimizer.
                    throw new System.ArgumentException("Unsupported dataset type: " + data.GetType(), "data");
                }
                if (l1reg > 0.0)
                {
                    minim = ReflectionLoading.LoadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg);
                }
                else
                {
                    minim = new QNMinimizer(lof);
                }
                weights = minim.Minimize(lof, tol, new double[data.NumFeatureTypes()]);
            }
            else
            {
                // The biased objective is built from the data and labels arrays only.
                BiasedLogisticObjectiveFunction lof = new BiasedLogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetLabelsArray(), prior);
                if (l1reg > 0.0)
                {
                    minim = ReflectionLoading.LoadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg);
                }
                else
                {
                    minim = new QNMinimizer(lof);
                }
                weights = minim.Minimize(lof, tol, new double[data.NumFeatureTypes()]);
            }
            featureIndex = data.featureIndex;
            classes[0]   = data.labelIndex.Get(0);
            classes[1]   = data.labelIndex.Get(1);
            return(new LogisticClassifier <L, F>(weights, featureIndex, classes));
        }
예제 #8
0
 /// <summary>
 /// Builds the objective from a labeled dataset, a list of unlabeled datums and the GE features,
 /// printing the three sizes to standard output first.
 /// </summary>
 /// <param name="labeledDataset">Labeled data; supplies the feature/class counts and both indices.</param>
 /// <param name="unlabeledDataList">Unlabeled datums, stored as given.</param>
 /// <param name="geFeatures">GE features, stored as given and passed to ComputeEmpiricalStatistics.</param>
 public GeneralizedExpectationObjectiveFunction(GeneralDataset<L, F> labeledDataset, IList<IDatum<L, F>> unlabeledDataList, IList<F> geFeatures)
 {
     string sizeReport = "Number of labeled examples:" + labeledDataset.size + "\nNumber of unlabeled examples:" + unlabeledDataList.Count;
     System.Console.Out.WriteLine(sizeReport);
     string geReport = "Number of GE features:" + geFeatures.Count;
     System.Console.Out.WriteLine(geReport);

     this.numFeatures = labeledDataset.NumFeatures();
     this.numClasses = labeledDataset.NumClasses();

     this.labeledDataset = labeledDataset;
     this.unlabeledDataList = unlabeledDataList;
     this.geFeatures = geFeatures;

     // The classifier starts with a null first argument and the labeled dataset's indices.
     this.classifier = new LinearClassifier<L, F>(null, labeledDataset.featureIndex, labeledDataset.labelIndex);

     ComputeEmpiricalStatistics(geFeatures);
 }
예제 #9
0
        /// <summary>
        /// Sets C to the value returned by the given line searcher, intended to be tuned
        /// against a held-out development set.
        /// </summary>
        /// <remarks>
        /// While the minimizer runs, <c>useAlphaFile</c> is forced to true and <c>useSigmoid</c> to false.
        /// Afterwards <c>useAlphaFile</c> is set to false and <c>useSigmoid</c> gets its previous value back,
        /// also when the minimizer throws. The C is then saved, so that if you train a classifier after
        /// calling this method, that C will be used.
        /// NOTE(review): the function handed to the minimizer is null in this version, and
        /// <paramref name="trainSet"/>, <paramref name="devSet"/> and <paramref name="scorer"/> are not
        /// read here. Presumably the scoring callback was lost in translation; confirm and restore it.
        /// </remarks>
        /// <param name="trainSet">Training portion (currently unused in this body).</param>
        /// <param name="devSet">Held-out portion (currently unused in this body).</param>
        /// <param name="scorer">How to score a candidate C, e.g. F-score or accuracy (currently unused in this body).</param>
        /// <param name="minimizer">Line searcher whose result is stored in C.</param>
        public virtual void HeldOutSetC(GeneralDataset <L, F> trainSet, GeneralDataset <L, F> devSet, IScorer <L> scorer, ILineSearcher minimizer)
        {
            bool oldUseSigmoid = useSigmoid;

            useAlphaFile = true;
            useSigmoid   = false;
            try
            {
                IDoubleUnaryOperator negativeScorer = null;

                C = minimizer.Minimize(negativeScorer);
            }
            finally
            {
                // Restore the temporary flags even if the search fails, so a failed
                // tuning run does not leave the factory in tuning mode.
                useAlphaFile = false;
                useSigmoid   = oldUseSigmoid;
            }
        }
예제 #10
0
 /// <summary>
 /// Optionally tunes C, then trains an SVM-light classifier on <paramref name="dataset"/>.
 /// </summary>
 /// <remarks>
 /// Held-out tuning takes precedence: cross-validation tuning is only considered when
 /// <c>tuneHeldOut</c> is false and <c>tuneCV</c> is true.
 /// </remarks>
 /// <param name="dataset">Data used for tuning (if enabled) and for the final training.</param>
 /// <returns>The classifier returned by TrainClassifierBasic.</returns>
 public virtual SVMLightClassifier<L, F> TrainClassifier(GeneralDataset<L, F> dataset)
 {
     if (tuneHeldOut)
     {
         HeldOutSetC(dataset, heldOutPercent, scorer, tuneMinimizer);
     }
     else if (tuneCV)
     {
         CrossValidateSetC(dataset, folds, scorer, tuneMinimizer);
     }

     return TrainClassifierBasic(dataset);
 }
예제 #11
0
        /// <summary>
        /// This method will cross validate on the given data and number of folds
        /// to find the optimal C.
        /// </summary>
        /// <remarks>
        /// The scorer is how you determine what to optimize for (F-score, accuracy, etc).
        /// The C is then saved, so that if you train a classifier after calling this method,
        /// that C will be used.
        /// While the minimizer runs, <c>useAlphaFile</c> is forced to true and <c>useSigmoid</c> to false.
        /// Afterwards <c>useAlphaFile</c> is set to false and <c>useSigmoid</c> gets its previous value back,
        /// also when construction of the cross validator or the minimizer throws.
        /// NOTE(review): the function handed to the minimizer is null in this version, and the
        /// cross validator and <paramref name="scorer"/> are not used by it. Presumably the scoring
        /// callbacks were lost in translation; confirm and restore them.
        /// </remarks>
        /// <param name="dataset">Data the cross validator is built over.</param>
        /// <param name="numFolds">Number of folds passed to the cross validator.</param>
        /// <param name="scorer">How to score a candidate C (currently unused in this body).</param>
        /// <param name="minimizer">Line searcher whose result is stored in C.</param>
        public virtual void CrossValidateSetC(GeneralDataset <L, F> dataset, int numFolds, IScorer <L> scorer, ILineSearcher minimizer)
        {
            System.Console.Out.WriteLine("in Cross Validate");
            bool oldUseSigmoid = useSigmoid;

            useAlphaFile = true;
            useSigmoid   = false;
            try
            {
                // Constructed as before, although nothing below consumes it yet.
                CrossValidator <L, F> crossValidator = new CrossValidator <L, F>(dataset, numFolds);
                //train(trainSet,true,true);
                IDoubleUnaryOperator negativeScorer = null;

                C = minimizer.Minimize(negativeScorer);
            }
            finally
            {
                // Restore the temporary flags even if the search fails, so a failed
                // tuning run does not leave the factory in tuning mode.
                useAlphaFile = false;
                useSigmoid   = oldUseSigmoid;
            }
        }
예제 #12
0
        /// <summary>
        /// Scores <paramref name="example"/> against every label in the label index using that
        /// label's binary classifier.
        /// </summary>
        /// <param name="example">Datum to score.</param>
        /// <returns>
        /// A counter holding, for each label, the count the label's binary classifier assigns to
        /// <c>PosLabel</c> for the relabeled datum.
        /// </returns>
        public virtual ICounter<L> ScoresOf(IDatum<L, F> example)
        {
            ICounter<L> result = new ClassicCounter<L>();

            foreach (L candidate in labelIndex)
            {
                // Single-entry map: only the candidate label maps to PosLabel;
                // NegLabel is supplied to MapDatum separately.
                IDictionary<L, string> positiveOnly = new ArrayMap<L, string>();
                positiveOnly[candidate] = PosLabel;

                IDatum<string, F> relabeled = GeneralDataset.MapDatum(example, positiveOnly, NegLabel);
                IClassifier<string, F> oneVsRest = GetBinaryClassifier(candidate);
                ICounter<string> binaryScores = oneVsRest.ScoresOf(relabeled);

                result.SetCount(candidate, binaryScores.GetCount(PosLabel));
            }

            return result;
        }
 /// <summary>
 /// Trains a multinomial logistic classifier on <paramref name="dataset"/>.
 /// </summary>
 /// <remarks>
 /// Overwrites the <c>numClasses</c>, <c>numFeatures</c>, <c>data</c>, <c>dataValues</c> and
 /// <c>labels</c> members before calling TrainWeights. For an RVFDataset the dataset's own values
 /// array is used; for any other dataset the values come from LogisticUtils.InitializeDataValues.
 /// </remarks>
 /// <param name="dataset">Training data.</param>
 /// <returns>A classifier built from the trained weights and the dataset's indices.</returns>
 public virtual MultinomialLogisticClassifier <L, F> TrainClassifier(GeneralDataset <L, F> dataset)
 {
     numClasses  = dataset.NumClasses();
     numFeatures = dataset.NumFeatures();
     data        = dataset.GetDataArray();
     // Generic classes are invariant: a test against RVFDataset<object, object> only succeeds
     // when L and F are both object, which would silently discard the real values otherwise.
     if (dataset is RVFDataset <L, F> )
     {
         dataValues = dataset.GetValuesArray();
     }
     else
     {
         dataValues = LogisticUtils.InitializeDataValues(data);
     }
     AugmentFeatureMatrix(data, dataValues);
     labels = dataset.GetLabelsArray();
     return(new MultinomialLogisticClassifier <L, F>(TrainWeights(), dataset.featureIndex, dataset.labelIndex));
 }
예제 #14
0
        /// <summary>
        /// Trains one binary classifier per label in <paramref name="trainLabels"/> and bundles them
        /// into a one-vs-all classifier.
        /// </summary>
        /// <param name="classifierFactory">Factory that trains each binary (string-labeled) classifier.</param>
        /// <param name="dataset">Source dataset; remapped to a binary dataset once per label.</param>
        /// <param name="trainLabels">Labels for which a binary classifier is trained.</param>
        /// <returns>A one-vs-all classifier over the dataset's feature and label indices.</returns>
        public static Edu.Stanford.Nlp.Classify.OneVsAllClassifier<L, F> Train<L, F>(IClassifierFactory<string, F, IClassifier<string, F>> classifierFactory, GeneralDataset<L, F> dataset, ICollection<L> trainLabels)
        {
            IIndex<L> labels = dataset.LabelIndex();
            IIndex<F> features = dataset.FeatureIndex();
            IDictionary<L, IClassifier<string, F>> perLabel = Generics.NewHashMap();

            foreach (L label in trainLabels)
            {
                int labelPosition = labels.IndexOf(label);
                logger.Info("Training " + label + " = " + labelPosition + ", posIndex = " + posIndex);

                // Relabel the data: the current label maps to PosLabel, NegLabel is passed as the fallback.
                IDictionary<L, string> positiveOnly = new ArrayMap<L, string>();
                positiveOnly[label] = PosLabel;
                GeneralDataset<string, F> relabeled = dataset.MapDataset(dataset, binaryIndex, positiveOnly, NegLabel);

                perLabel[label] = classifierFactory.TrainClassifier(relabeled);
            }

            return new Edu.Stanford.Nlp.Classify.OneVsAllClassifier<L, F>(features, labels, perLabel);
        }
예제 #15
0
 /// <summary>
 /// Creates the objective over <paramref name="dataset"/> with <paramref name="prior"/> and stores
 /// the result of <c>To1D(weights)</c> in the <c>weights</c> member.
 /// </summary>
 /// <param name="dataset">Dataset forwarded to the base constructor.</param>
 /// <param name="prior">Prior forwarded to the base constructor.</param>
 /// <param name="weights">Two-dimensional weights, converted with To1D before being stored.</param>
 public AdaptedGaussianPriorObjectiveFunction(GeneralDataset<L, F> dataset, LogPrior prior, double[][] weights)
     : base(dataset, prior)
 {
     this.weights = To1D(weights);
 }
예제 #16
0
 /// <summary>Creates the enumerator and remembers the dataset it was created for.</summary>
 /// <param name="_enclosing">Dataset instance stored in the <c>_enclosing</c> member.</param>
 public _IEnumerator_596(GeneralDataset<L, F> _enclosing)
 {
     this._enclosing = _enclosing;
 }
예제 #17
0
 /// <summary>Creates a cross validator over <paramref name="trainData"/> with 10 folds.</summary>
 /// <param name="trainData">Dataset forwarded to the two-argument constructor.</param>
 public CrossValidator(GeneralDataset<L, F> trainData)
     : this(trainData, 10)
 {
 }
예제 #18
0
        /// <summary>
        /// Trains a one-vs-all classifier for every label in the dataset's label index.
        /// </summary>
        /// <param name="classifierFactory">Factory that trains each binary (string-labeled) classifier.</param>
        /// <param name="dataset">Source dataset; its label index supplies the labels to train.</param>
        /// <returns>The classifier returned by the three-argument Train overload.</returns>
        public static Edu.Stanford.Nlp.Classify.OneVsAllClassifier<L, F> Train<L, F>(IClassifierFactory<string, F, IClassifier<string, F>> classifierFactory, GeneralDataset<L, F> dataset)
        {
            // Delegate with the full label list taken from the dataset's label index.
            return Train(classifierFactory, dataset, dataset.LabelIndex().ObjectsList());
        }
예제 #19
0
 /// <summary>
 /// Trains a logistic classifier with a quadratic prior.
 /// </summary>
 /// <param name="data">Training data.</param>
 /// <param name="l1reg">Forwarded unchanged to the five-argument overload.</param>
 /// <param name="tol">Forwarded unchanged to the five-argument overload.</param>
 /// <param name="biased">Forwarded unchanged to the five-argument overload.</param>
 /// <returns>The classifier returned by the five-argument overload.</returns>
 public virtual LogisticClassifier<L, F> TrainClassifier(GeneralDataset<L, F> data, double l1reg, double tol, bool biased)
 {
     LogPrior quadraticPrior = new LogPrior(LogPrior.LogPriorType.Quadratic);
     return TrainClassifier(data, l1reg, tol, quadraticPrior, biased);
 }
예제 #20
0
 /// <summary>
 /// Trains a logistic classifier with the given prior and the <c>biased</c> flag set to false.
 /// </summary>
 /// <param name="data">Training data.</param>
 /// <param name="l1reg">Forwarded unchanged to the five-argument overload.</param>
 /// <param name="tol">Forwarded unchanged to the five-argument overload.</param>
 /// <param name="prior">Forwarded unchanged to the five-argument overload.</param>
 /// <returns>The classifier returned by the five-argument overload.</returns>
 public virtual LogisticClassifier<L, F> TrainClassifier(GeneralDataset<L, F> data, double l1reg, double tol, LogPrior prior)
 {
     const bool biased = false;
     return TrainClassifier(data, l1reg, tol, prior, biased);
 }
예제 #21
0
        /// <summary>Builds a sigmoid model to turn the classifier outputs into probabilities.</summary>
        /// <remarks>
        /// Every datum of <paramref name="dataset"/> is scored with <paramref name="classifier"/>; the
        /// score counter (plus one increment under the null key) becomes the feature counter of a new
        /// datum carrying the original label. A linear classifier with a Null-type prior is trained
        /// on those datums.
        /// </remarks>
        /// <param name="classifier">Classifier whose scores are used as features.</param>
        /// <param name="dataset">Data to score.</param>
        /// <returns>The linear classifier trained on the score datums.</returns>
        private LinearClassifier<L, L> FitSigmoid(SVMLightClassifier<L, F> classifier, GeneralDataset<L, F> dataset)
        {
            RVFDataset<L, L> scoreData = new RVFDataset<L, L>();

            int position = 0;
            while (position < dataset.Size())
            {
                RVFDatum<L, F> example = dataset.GetRVFDatum(position);
                ICounter<L> rawScores = classifier.ScoresOf((IDatum<L, F>)example);
                // Presumably the null key acts as a bias feature; confirm against the linear classifier.
                rawScores.IncrementCount(null);
                scoreData.Add(new RVFDatum<L, L>(rawScores, example.Label()));
                position++;
            }

            LinearClassifierFactory<L, L> sigmoidFactory = new LinearClassifierFactory<L, L>();
            sigmoidFactory.SetPrior(new LogPrior(LogPrior.LogPriorType.Null));
            return sigmoidFactory.TrainClassifier(scoreData);
        }
예제 #22
0
 /// <summary>
 /// Trains a logistic classifier with an L1 value of 0.0 and a tolerance of 1e-4.
 /// </summary>
 /// <param name="data">Training data.</param>
 /// <param name="prior">Forwarded unchanged to the five-argument overload.</param>
 /// <param name="biased">Forwarded unchanged to the five-argument overload.</param>
 /// <returns>The classifier returned by the five-argument overload.</returns>
 public virtual LogisticClassifier<L, F> TrainClassifier(GeneralDataset<L, F> data, LogPrior prior, bool biased)
 {
     const double l1reg = 0.0;
     const double tol = 1e-4;
     return TrainClassifier(data, l1reg, tol, prior, biased);
 }
예제 #23
0
 /// <summary>Trains a logistic classifier with an L1 value of 0.0.</summary>
 /// <param name="data">Training data.</param>
 /// <returns>The classifier returned by the (data, l1reg) overload.</returns>
 public virtual LogisticClassifier<L, F> TrainClassifier(GeneralDataset<L, F> data)
 {
     const double l1reg = 0.0;
     return TrainClassifier(data, l1reg);
 }
 /// <summary>
 /// Creates the objective from a dataset by unpacking its feature count, class count, data array
 /// and labels array for the array-based constructor.
 /// </summary>
 /// <param name="dataset">Dataset to unpack.</param>
 /// <param name="confusionMatrix">Forwarded unchanged to the array-based constructor.</param>
 /// <param name="prior">Forwarded unchanged to the array-based constructor.</param>
 public BiasedLogConditionalObjectiveFunction(GeneralDataset<object, object> dataset, double[][] confusionMatrix, LogPrior prior)
     : this(dataset.NumFeatures(),
            dataset.NumClasses(),
            dataset.GetDataArray(),
            dataset.GetLabelsArray(),
            confusionMatrix,
            prior)
 {
 }
 /// <summary>Creates the objective from a dataset using a quadratic prior.</summary>
 /// <param name="dataset">Forwarded unchanged to the three-argument constructor.</param>
 /// <param name="confusionMatrix">Forwarded unchanged to the three-argument constructor.</param>
 public BiasedLogConditionalObjectiveFunction(GeneralDataset<object, object> dataset, double[][] confusionMatrix)
     : this(dataset,
            confusionMatrix,
            new LogPrior(LogPrior.LogPriorType.Quadratic))
 {
 }
 /// <summary>Trains weights on the given dataset; implemented by subclasses.</summary>
 /// <param name="dataset">Training data.</param>
 /// <returns>The trained weights as a jagged array (layout is defined by the implementation).</returns>
 protected internal abstract double[][] TrainWeights(GeneralDataset <L, F> dataset);
예제 #27
0
 /// <summary>Trains on <paramref name="data"/> with the values 0.0 and 1e-4 for the remaining arguments.</summary>
 /// <remarks>Use LogisticClassifierFactory to train instead.</remarks>
 /// <param name="data">Training data.</param>
 public virtual void Train(GeneralDataset<L, F> data)
 {
     const double l1reg = 0.0;
     const double tol = 1e-4;
     Train(data, l1reg, tol);
 }
예제 #28
0
 /// <summary>Trains a logistic classifier with the given L1 value and a tolerance of 1e-4.</summary>
 /// <param name="data">Training data.</param>
 /// <param name="l1reg">Forwarded unchanged to the three-argument overload.</param>
 /// <returns>The classifier returned by the three-argument overload.</returns>
 public virtual LogisticClassifier<L, F> TrainClassifier(GeneralDataset<L, F> data, double l1reg)
 {
     const double tol = 1e-4;
     return TrainClassifier(data, l1reg, tol);
 }
예제 #29
0
        /// <summary>
        /// Trains an SVM-light classifier by writing the dataset to a temp file, running the
        /// external learner command on it, and reading the resulting model file back.
        /// </summary>
        /// <remarks>
        /// Steps, in order: create temp model and data files; print the dataset in SVM-light format;
        /// build and run the learner command line; optionally run the external classify command
        /// (<c>doEval</c>); read the model and convert its weights; optionally write this classifier's
        /// own scores to another temp file (<c>doEval</c>); optionally fit a sigmoid (<c>useSigmoid</c>).
        /// When <c>useAlphaFile</c> is set, the <c>alphaFile</c> member is replaced by a new temp file.
        /// Any exception raised inside is caught and rethrown wrapped in a new Exception.
        /// </remarks>
        /// <param name="dataset">Training data; also scored again when <c>doEval</c> or <c>useSigmoid</c> is set.</param>
        /// <returns>The classifier built from the weights and thresholds read from the model file.</returns>
        public virtual SVMLightClassifier <L, F> TrainClassifierBasic(GeneralDataset <L, F> dataset)
        {
            IIndex <L> labelIndex   = dataset.LabelIndex();
            IIndex <F> featureIndex = dataset.featureIndex;
            // More than two classes selects the svmStruct commands below.
            bool       multiclass   = (dataset.NumClasses() > 2);

            try
            {
                // this is the file that the model will be saved to
                File modelFile = File.CreateTempFile("svm-", ".model");
                if (deleteTempFilesOnExit)
                {
                    modelFile.DeleteOnExit();
                }
                // this is the file that the svm light formatted dataset
                // will be printed to
                File dataFile = File.CreateTempFile("svm-", ".data");
                if (deleteTempFilesOnExit)
                {
                    dataFile.DeleteOnExit();
                }
                // print the dataset
                PrintWriter pw = new PrintWriter(new FileWriter(dataFile));
                dataset.PrintSVMLightFormat(pw);
                pw.Close();
                // -v 0 makes it not verbose
                // -m 400 gives it a larger cache, for faster training
                string cmd = (multiclass ? svmStructLearn : (useSVMPerf ? svmPerfLearn : svmLightLearn)) + " -v " + svmLightVerbosity + " -m 400 ";
                // set the value of C if we have one specified
                if (C > 0.0)
                {
                    cmd = cmd + " -c " + C + " ";
                }
                else
                {
                    // C value
                    if (useSVMPerf)
                    {
                        cmd = cmd + " -c " + 0.01 + " ";
                    }
                }
                //It's required to specify this parameter for SVM perf
                // Alpha File
                if (useAlphaFile)
                {
                    File newAlphaFile = File.CreateTempFile("svm-", ".alphas");
                    if (deleteTempFilesOnExit)
                    {
                        newAlphaFile.DeleteOnExit();
                    }
                    cmd = cmd + " -a " + newAlphaFile.GetAbsolutePath();
                    // An alpha file remembered from an earlier call is passed with -y.
                    if (alphaFile != null)
                    {
                        cmd = cmd + " -y " + alphaFile.GetAbsolutePath();
                    }
                    alphaFile = newAlphaFile;
                }
                // File and Model Data
                cmd = cmd + " " + dataFile.GetAbsolutePath() + " " + modelFile.GetAbsolutePath();
                if (verbose)
                {
                    logger.Info("<< " + cmd + " >>");
                }

                /*Process p = Runtime.getRuntime().exec(cmd);
                 *
                 * p.waitFor();
                 *
                 * if (p.exitValue() != 0) throw new RuntimeException("Error Training SVM Light exit value: " + p.exitValue());
                 * p.destroy();   */
                // The command string is split with whitespacePattern into the process arguments;
                // both writers handed to SystemUtils.Run wrap System.Console.Error.
                SystemUtils.Run(new ProcessBuilder(whitespacePattern.Split(cmd)), new PrintWriter(System.Console.Error), new PrintWriter(System.Console.Error));
                if (doEval)
                {
                    // Run the external classify command on the training data file.
                    File predictFile = File.CreateTempFile("svm-", ".pred");
                    if (deleteTempFilesOnExit)
                    {
                        predictFile.DeleteOnExit();
                    }
                    string evalCmd = (multiclass ? svmStructClassify : (useSVMPerf ? svmPerfClassify : svmLightClassify)) + " " + dataFile.GetAbsolutePath() + " " + modelFile.GetAbsolutePath() + " " + predictFile.GetAbsolutePath();
                    if (verbose)
                    {
                        logger.Info("<< " + evalCmd + " >>");
                    }
                    SystemUtils.Run(new ProcessBuilder(whitespacePattern.Split(evalCmd)), new PrintWriter(System.Console.Error), new PrintWriter(System.Console.Error));
                }
                // read in the model file
                Pair <double, ClassicCounter <int> > weightsAndThresh = ReadModel(modelFile, multiclass);
                double threshold = weightsAndThresh.First();
                ClassicCounter <Pair <F, L> > weights    = ConvertWeights(weightsAndThresh.Second(), featureIndex, labelIndex, multiclass);
                ClassicCounter <L>            thresholds = new ClassicCounter <L>();
                // Binary case only: label 0 gets -threshold and label 1 gets +threshold.
                if (!multiclass)
                {
                    thresholds.SetCount(labelIndex.Get(0), -threshold);
                    thresholds.SetCount(labelIndex.Get(1), threshold);
                }
                SVMLightClassifier <L, F> classifier = new SVMLightClassifier <L, F>(weights, thresholds);
                if (doEval)
                {
                    // Write this classifier's own scores for every datum to a second temp file.
                    File predictFile = File.CreateTempFile("svm-", ".pred2");
                    if (deleteTempFilesOnExit)
                    {
                        predictFile.DeleteOnExit();
                    }
                    PrintWriter  pw2 = new PrintWriter(predictFile);
                    NumberFormat nf  = NumberFormat.GetNumberInstance();
                    nf.SetMaximumFractionDigits(5);
                    foreach (IDatum <L, F> datum in dataset)
                    {
                        ICounter <L> scores = classifier.ScoresOf(datum);
                        pw2.Println(Counters.ToString(scores, nf));
                    }
                    pw2.Close();
                }
                if (useSigmoid)
                {
                    if (verbose)
                    {
                        System.Console.Out.Write("fitting sigmoid...");
                    }
                    classifier.SetPlatt(FitSigmoid(classifier, dataset));
                    if (verbose)
                    {
                        System.Console.Out.WriteLine("done");
                    }
                }
                return(classifier);
            }
            catch (Exception e)
            {
                // Every failure in the try block is wrapped and rethrown.
                throw new Exception(e);
            }
        }
예제 #30
0
        /// <summary>
        /// Splits <paramref name="train"/> by <paramref name="percentHeldOut"/> and tunes C on the two parts.
        /// </summary>
        /// <param name="train">Dataset to split.</param>
        /// <param name="percentHeldOut">Value passed to the dataset's Split method.</param>
        /// <param name="scorer">Forwarded unchanged to the dataset-pair overload.</param>
        /// <param name="minimizer">Forwarded unchanged to the dataset-pair overload.</param>
        public virtual void HeldOutSetC(GeneralDataset<L, F> train, double percentHeldOut, IScorer<L> scorer, ILineSearcher minimizer)
        {
            Pair<GeneralDataset<L, F>, GeneralDataset<L, F>> parts = train.Split(percentHeldOut);

            // The first part of the split goes in as the first dataset argument, the second as the second.
            GeneralDataset<L, F> firstPart = parts.First();
            GeneralDataset<L, F> secondPart = parts.Second();
            HeldOutSetC(firstPart, secondPart, scorer, minimizer);
        }