Пример #1
0
        /// <summary>
        /// Trains a binary logistic classifier on per-datum weighted data using a
        /// quadratic (L2) prior and quasi-Newton optimization.
        /// </summary>
        /// <param name="data">Binary-labeled dataset; must have exactly two labels.</param>
        /// <param name="dataWeights">Per-datum weights, parallel to the dataset.</param>
        /// <returns>A trained LogisticClassifier over the dataset's features.</returns>
        public virtual LogisticClassifier <L, F> TrainWeightedData(GeneralDataset <L, F> data, float[] dataWeights)
        {
            // Make sure real-valued datasets have materialized their values before training.
            if (data is RVFDataset)
            {
                ((RVFDataset <L, F>)data).EnsureRealValues();
            }
            if (data.labelIndex.Size() != 2)
            {
                throw new Exception("LogisticClassifier is only for binary classification!");
            }
            LogisticObjectiveFunction lof = null;
            if (data is Dataset <object, object> )
            {
                lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetLabelsArray(), new LogPrior(LogPrior.LogPriorType.Quadratic), dataWeights);
            }
            else if (data is RVFDataset <object, object> )
            {
                lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetValuesArray(), data.GetLabelsArray(), new LogPrior(LogPrior.LogPriorType.Quadratic), dataWeights);
            }
            // BUGFIX: previously a null lof was passed straight to QNMinimizer when the
            // dataset was neither a Dataset nor an RVFDataset, which surfaced as an
            // opaque NullReferenceException. Fail fast with a clear message instead.
            if (lof == null)
            {
                throw new Exception("TrainWeightedData: unsupported dataset type " + data.GetType());
            }
            IMinimizer <IDiffFunction> minim = new QNMinimizer(lof);
            weights      = minim.Minimize(lof, 1e-4, new double[data.NumFeatureTypes()]);
            featureIndex = data.featureIndex;
            classes[0]   = data.labelIndex.Get(0);
            classes[1]   = data.labelIndex.Get(1);
            return(new LogisticClassifier <L, F>(weights, featureIndex, classes));
        }
Пример #2
0
        /// <summary>
        /// Trains this classifier in place on per-datum weighted data using the
        /// configured prior and quasi-Newton optimization.
        /// </summary>
        /// <param name="data">Binary-labeled dataset; must have exactly two labels.</param>
        /// <param name="dataWeights">Per-datum weights, parallel to the dataset.</param>
        public virtual void TrainWeightedData(GeneralDataset <L, F> data, float[] dataWeights)
        {
            //Use LogisticClassifierFactory to train instead.
            if (data.labelIndex.Size() != 2)
            {
                throw new Exception("LogisticClassifier is only for binary classification!");
            }
            LogisticObjectiveFunction lof = null;
            if (data is Dataset <object, object> )
            {
                lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetLabelsArray(), prior, dataWeights);
            }
            else if (data is RVFDataset <object, object> )
            {
                lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetValuesArray(), data.GetLabelsArray(), prior, dataWeights);
            }
            // BUGFIX: previously a null lof was passed straight to QNMinimizer when the
            // dataset was neither a Dataset nor an RVFDataset, which surfaced as an
            // opaque NullReferenceException. Fail fast with a clear message instead.
            if (lof == null)
            {
                throw new Exception("TrainWeightedData: unsupported dataset type " + data.GetType());
            }
            IMinimizer <IDiffFunction> minim = new QNMinimizer(lof);
            weights      = minim.Minimize(lof, 1e-4, new double[data.NumFeatureTypes()]);
            featureIndex = data.featureIndex;
            classes[0]   = data.labelIndex.Get(0);
            classes[1]   = data.labelIndex.Get(1);
        }
Пример #3
0
        /// <summary>
        /// Trains a binary logistic classifier, optionally with L1 regularization
        /// (via the reflectively loaded OWL-QN minimizer) and optionally with a
        /// biased objective function.
        /// </summary>
        /// <param name="data">Binary-labeled dataset; must have exactly two labels.</param>
        /// <param name="l1reg">L1 regularization strength; &gt; 0 selects OWL-QN, otherwise plain QN.</param>
        /// <param name="tol">Convergence tolerance passed to the minimizer.</param>
        /// <param name="prior">Log-prior used by the objective function.</param>
        /// <param name="biased">If true, optimizes a BiasedLogisticObjectiveFunction.</param>
        /// <returns>A trained LogisticClassifier over the dataset's features.</returns>
        public virtual LogisticClassifier <L, F> TrainClassifier(GeneralDataset <L, F> data, double l1reg, double tol, LogPrior prior, bool biased)
        {
            // Make sure real-valued datasets have materialized their values before training.
            if (data is RVFDataset)
            {
                ((RVFDataset <L, F>)data).EnsureRealValues();
            }
            if (data.labelIndex.Size() != 2)
            {
                throw new Exception("LogisticClassifier is only for binary classification!");
            }
            IMinimizer <IDiffFunction> minim;

            if (!biased)
            {
                LogisticObjectiveFunction lof = null;
                if (data is Dataset <object, object> )
                {
                    lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetLabelsArray(), prior);
                }
                else if (data is RVFDataset <object, object> )
                {
                    lof = new LogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetValuesArray(), data.GetLabelsArray(), prior);
                }
                // BUGFIX: previously a null lof fell through to the minimizer when the
                // dataset was neither a Dataset nor an RVFDataset, which surfaced as an
                // opaque NullReferenceException. Fail fast with a clear message instead.
                if (lof == null)
                {
                    throw new Exception("TrainClassifier: unsupported dataset type " + data.GetType());
                }
                if (l1reg > 0.0)
                {
                    // OWL-QN handles the non-smooth L1 term; loaded reflectively to keep the dependency optional.
                    minim = ReflectionLoading.LoadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg);
                }
                else
                {
                    minim = new QNMinimizer(lof);
                }
                weights = minim.Minimize(lof, tol, new double[data.NumFeatureTypes()]);
            }
            else
            {
                BiasedLogisticObjectiveFunction lof = new BiasedLogisticObjectiveFunction(data.NumFeatureTypes(), data.GetDataArray(), data.GetLabelsArray(), prior);
                if (l1reg > 0.0)
                {
                    minim = ReflectionLoading.LoadByReflection("edu.stanford.nlp.optimization.OWLQNMinimizer", l1reg);
                }
                else
                {
                    minim = new QNMinimizer(lof);
                }
                weights = minim.Minimize(lof, tol, new double[data.NumFeatureTypes()]);
            }
            featureIndex = data.featureIndex;
            classes[0]   = data.labelIndex.Get(0);
            classes[1]   = data.labelIndex.Get(1);
            return(new LogisticClassifier <L, F>(weights, featureIndex, classes));
        }
Пример #4
0
 //  public static void main(String[] args) {
 //    List examples = new ArrayList();
 //    String leftLight = "leftLight";
 //    String rightLight = "rightLight";
 //    String broken = "BROKEN";
 //    String ok = "OK";
 //    Counter c1 = new ClassicCounter<>();
 //    c1.incrementCount(leftLight, 0);
 //    c1.incrementCount(rightLight, 0);
 //    RVFDatum d1 = new RVFDatum(c1, broken);
 //    examples.add(d1);
 //    Counter c2 = new ClassicCounter<>();
 //    c2.incrementCount(leftLight, 1);
 //    c2.incrementCount(rightLight, 1);
 //    RVFDatum d2 = new RVFDatum(c2, ok);
 //    examples.add(d2);
 //    Counter c3 = new ClassicCounter<>();
 //    c3.incrementCount(leftLight, 0);
 //    c3.incrementCount(rightLight, 1);
 //    RVFDatum d3 = new RVFDatum(c3, ok);
 //    examples.add(d3);
 //    Counter c4 = new ClassicCounter<>();
 //    c4.incrementCount(leftLight, 1);
 //    c4.incrementCount(rightLight, 0);
 //    RVFDatum d4 = new RVFDatum(c4, ok);
 //    examples.add(d4);
 //    Dataset data = new Dataset(examples.size());
 //    data.addAll(examples);
 //    NaiveBayesClassifier classifier = (NaiveBayesClassifier)
 //        new NaiveBayesClassifierFactory(200, 200, 1.0,
 //              LogPrior.LogPriorType.QUADRATIC.ordinal(),
 //              NaiveBayesClassifierFactory.CL)
 //            .trainClassifier(data);
 //    classifier.print();
 //    //now classify
 //    for (int i = 0; i < examples.size(); i++) {
 //      RVFDatum d = (RVFDatum) examples.get(i);
 //      Counter scores = classifier.scoresOf(d);
 //      System.out.println("for datum " + d + " scores are " + scores.toString());
 //      System.out.println(" class is " + Counters.topKeys(scores, 1));
 //      System.out.println(" class should be " + d.label());
 //    }
 //  }
 //    String trainFile = args[0];
 //    String testFile = args[1];
 //    NominalDataReader nR = new NominalDataReader();
 //    Map<Integer, Index<String>> indices = Generics.newHashMap();
 //    List<RVFDatum<String, Integer>> train = nR.readData(trainFile, indices);
 //    List<RVFDatum<String, Integer>> test = nR.readData(testFile, indices);
 //    System.out.println("Constrained conditional likelihood no prior :");
 //    for (int j = 0; j < 100; j++) {
 //      NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.CL).trainClassifier(train);
 //      classifier.print();
 //      //now classify
 //
 //      float accTrain = classifier.accuracy(train.iterator());
 //      log.info("training accuracy " + accTrain);
 //      float accTest = classifier.accuracy(test.iterator());
 //      log.info("test accuracy " + accTest);
 //
 //    }
 //    System.out.println("Unconstrained conditional likelihood no prior :");
 //    for (int j = 0; j < 100; j++) {
 //      NaiveBayesClassifier<String, Integer> classifier = new NaiveBayesClassifierFactory<String, Integer>(0.1, 0.01, 0.6, LogPrior.LogPriorType.NULL.ordinal(), NaiveBayesClassifierFactory.UCL).trainClassifier(train);
 //      classifier.print();
 //      //now classify
 //
 //      float accTrain = classifier.accuracy(train.iterator());
 //      log.info("training accuracy " + accTrain);
 //      float accTest = classifier.accuracy(test.iterator());
 //      log.info("test accuracy " + accTest);
 //    }
 //  }
 /// <summary>
 /// Trains a Naive Bayes classifier from a general dataset by unpacking its
 /// arrays and indices and delegating to the array-based trainer.
 /// </summary>
 /// <param name="dataset">Training data; real-valued (RVF) datasets are rejected.</param>
 /// <returns>A trained NaiveBayesClassifier.</returns>
 public virtual NaiveBayesClassifier <L, F> TrainClassifier(GeneralDataset <L, F> dataset)
 {
     // Real-valued datasets are untested on this path; refuse rather than risk silent misbehavior.
     if (dataset is RVFDataset)
     {
         throw new Exception("Not sure if RVFDataset runs correctly in this method. Please update this code if it does.");
     }
     var dataArray = dataset.GetDataArray();
     var labelArray = dataset.labels;
     return TrainClassifier(dataArray, labelArray, dataset.NumFeatures(), dataset.NumClasses(), dataset.labelIndex, dataset.featureIndex);
 }
 /// <summary>
 /// Trains a multinomial logistic (softmax) classifier: caches the dataset's
 /// dimensions, data, values, and labels on this factory, then optimizes the
 /// weights and wraps them in a classifier.
 /// </summary>
 /// <param name="dataset">Training data; may be binary-featured or real-valued (RVF).</param>
 /// <returns>A trained MultinomialLogisticClassifier.</returns>
 public virtual MultinomialLogisticClassifier <L, F> TrainClassifier(GeneralDataset <L, F> dataset)
 {
     numClasses  = dataset.NumClasses();
     numFeatures = dataset.NumFeatures();
     data        = dataset.GetDataArray();
     // RVF datasets carry explicit feature values; binary datasets get default-initialized values.
     dataValues = dataset is RVFDataset <object, object>
         ? dataset.GetValuesArray()
         : LogisticUtils.InitializeDataValues(data);
     AugmentFeatureMatrix(data, dataValues);
     labels = dataset.GetLabelsArray();
     return new MultinomialLogisticClassifier <L, F>(TrainWeights(), dataset.featureIndex, dataset.labelIndex);
 }
 /// <summary>
 /// Convenience constructor: extracts the feature count, class count, data
 /// array, and label array from <paramref name="dataset"/> and delegates to
 /// the full constructor.
 /// </summary>
 /// <param name="dataset">Source of dimensions, data, and labels.</param>
 /// <param name="confusionMatrix">Bias (confusion) matrix applied by the objective.</param>
 /// <param name="prior">Log-prior used for regularization.</param>
 public BiasedLogConditionalObjectiveFunction(GeneralDataset <object, object> dataset, double[][] confusionMatrix, LogPrior prior)
     : this(dataset.NumFeatures(), dataset.NumClasses(), dataset.GetDataArray(), dataset.GetLabelsArray(), confusionMatrix, prior)
 {
 }
Пример #7
0
 // amount of add-k smoothing of evidence
 // fudge to keep nonzero
 /// <summary>
 /// Trains weights from a dataset by unpacking its data and label arrays and
 /// delegating to the array-based overload.
 /// </summary>
 /// <param name="data">Training dataset supplying the data and label arrays.</param>
 /// <returns>The trained weight matrix.</returns>
 protected internal override double[][] TrainWeights(GeneralDataset <L, F> data)
     => TrainWeights(data.GetDataArray(), data.GetLabelsArray());