Esempio n. 1
0
        /// <summary>
        /// Verifies that <c>Linear.predictProbability</c> refuses a model trained with a
        /// non-logistic-regression solver (L2R_L1LOSS_SVC_DUAL here) by throwing an
        /// <see cref="ArgumentException"/> whose message lists the supported solvers.
        /// </summary>
        public void testPredictProbabilityWrongSolver()
        {
            Problem prob = new Problem();

            prob.l = 1;
            prob.n = 1;
            prob.x = new Feature[prob.l][];
            prob.y = new double[prob.l];
            for (int i = 0; i < prob.l; i++)
            {
                // Empty feature vectors are enough: only the solver type matters for this check.
                prob.x[i] = new Feature[] {};
                prob.y[i] = i;
            }

            SolverType solverType = SolverType.getById(SolverType.L2R_L1LOSS_SVC_DUAL);
            Parameter  param      = new Parameter(solverType, 10, 0.1);
            Model      model      = Linear.train(prob, param);

            try
            {
                Linear.predictProbability(model, prob.x[0], new double[1]);
                // Message names the .NET exception type actually expected (was a Java leftover).
                Assert.Fail("ArgumentException expected");
            }
            catch (ArgumentException e)
            {
                Assert.AreEqual("probability output is only supported for logistic regression."
                                + " This is currently only supported by the following solvers:"
                                + " L2R_LR, L1R_LR, L2R_LR_DUAL", e.Message);
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Verifies that <c>Linear.train</c> rejects a problem whose feature nodes are not
        /// sorted by index in ascending order, with an <see cref="ArgumentException"/> whose
        /// message mentions nodes, sorting, ascending and order.
        /// </summary>
        public void testTrainUnsortedProblem()
        {
            Problem prob = new Problem();

            prob.bias = -1;
            prob.l    = 1;
            prob.n    = 2;

            // Size the instance arrays to exactly prob.l. They were previously over-allocated
            // (length 4), which left trailing null rows inconsistent with prob.l that a
            // validation pass over the whole array could trip over.
            prob.x    = new Feature[prob.l][];
            prob.y    = new double[prob.l];

            // Feature indices 2 then 1: deliberately out of ascending order.
            prob.x[0] = new Feature[] { new Feature(2, 1), new Feature(1, 1) };
            prob.y[0] = 0;

            Parameter param = new Parameter(SolverType.getById(SolverType.L2R_LR), 10, 0.1);

            try
            {
                Linear.train(prob, param);
                Assert.Fail("ArgumentException expected");
            }
            catch (ArgumentException e)
            {
                foreach (string keyword in new string[] { "nodes", "sorted", "ascending", "order" })
                {
                    Assert.IsTrue(e.Message.Contains(keyword),
                                  "expected '" + keyword + "' in message: " + e.Message);
                }
            }
        }
Esempio n. 3
0
        /// <summary>
        /// Maps a (penalty norm, loss, dual flag, multiclass strategy) combination to the
        /// matching LIBLINEAR solver type.
        /// </summary>
        /// <remarks>
        /// Crammer-Singer always selects MCSVM_CS regardless of the other arguments.
        /// Any strategy other than CrammerSinger or Ovr is rejected.
        /// </remarks>
        /// <exception cref="ArgumentException">
        /// The multiclass value is unknown, or the norm/loss/dual combination has no solver.
        /// </exception>
        private SolverType GetSolverType(Norm norm, Loss loss, bool dual, Multiclass multiclass)
        {
            if (multiclass == Multiclass.CrammerSinger)
            {
                return SolverType.getById(SolverType.MCSVM_CS);
            }
            if (multiclass != Multiclass.Ovr)
            {
                throw new ArgumentException("Invalid multiclass value");
            }

            if (norm == Norm.L2)
            {
                if (loss == Loss.LogisticRegression)
                {
                    return SolverType.getById(dual ? SolverType.L2R_LR_DUAL : SolverType.L2R_LR);
                }
                if (loss == Loss.L2)
                {
                    return SolverType.getById(dual ? SolverType.L2R_L2LOSS_SVC_DUAL : SolverType.L2R_L2LOSS_SVC);
                }
                if (loss == Loss.L1 && dual)
                {
                    // L1 (hinge) loss with L2 penalty is only available in dual form.
                    return SolverType.getById(SolverType.L2R_L1LOSS_SVC_DUAL);
                }
            }
            else if (norm == Norm.L1 && !dual)
            {
                // L1 penalty is only available in primal form.
                if (loss == Loss.L2)
                {
                    return SolverType.getById(SolverType.L1R_L2LOSS_SVC);
                }
                if (loss == Loss.LogisticRegression)
                {
                    return SolverType.getById(SolverType.L1R_LR);
                }
            }

            throw new ArgumentException("Given combination of penalty, loss, dual params is not supported");
        }
Esempio n. 4
0
        /// <summary>
        /// Runs 10-fold cross-validation with the L2R_LR solver on a randomly generated problem
        /// and checks that every predicted value lies in the closed range [0, numClasses].
        /// </summary>
        public void testCrossValidation()
        {
            int numClasses = random.Next(10) + 1;
            Problem prob = createRandomProblem(numClasses);

            Parameter param = new Parameter(SolverType.getById(SolverType.L2R_LR), 10, 0.01);
            const int foldCount = 10;
            double[] predictions = new double[prob.l];

            Linear.crossValidation(prob, param, foldCount, predictions);

            for (int idx = 0; idx < predictions.Length; idx++)
            {
                double predicted = predictions[idx];
                Assert.IsTrue(predicted >= 0);
                Assert.IsTrue(predicted <= numClasses);
            }
        }
Esempio n. 5
0
        /// <summary>
        /// Builds an L2R_LR model with bias 2, three labels (1, int.MaxValue, 2) and
        /// 3 * 300 random weights, suitable for save/load round-trip tests.
        /// </summary>
        /// <returns>A populated <c>Model</c> with nr_feature = 299 and nr_class = 3.</returns>
        public static Model createRandomModel()
        {
            int[] labels = new int[] { 1, int.MaxValue, 2 };
            double[] weights = new double[labels.Length * 300];

            // Each weight is in [0, 10] and rounded to a multiple of 1e-4, so
            // serialized values need at most 4 decimal places of precision.
            for (int k = 0; k < weights.Length; ++k)
            {
                weights[k] = Math.Round(random.NextDouble() * 100000.0) / 10000.0;
            }

            // Ensure zero entries appear, including a negative zero.
            weights[random.Next(weights.Length)] = 0.0;
            weights[random.Next(weights.Length)] = -0.0;

            Model model = new Model();
            model.solverType = SolverType.getById(SolverType.L2R_LR);
            model.bias       = 2;
            model.label      = labels;
            model.w          = weights;
            model.nr_feature = weights.Length / labels.Length - 1;
            model.nr_class   = labels.Length;
            return model;
        }
Esempio n. 6
0
 /// <summary>
 /// Prepares the shared parameter: L2R_L1LOSS_SVC_DUAL solver, C = 100, eps = 0.001.
 /// </summary>
 public void setUp()
 {
     SolverType dualSvc = SolverType.getById(SolverType.L2R_L1LOSS_SVC_DUAL);
     _param = new Parameter(dualSvc, 100, 1e-3);
 }