Example #1
0
        /// <summary>
        /// Runs <paramref name="model"/> over every instance in <paramref name="problem"/>, optionally writes the
        /// predictions to <paramref name="outputFile"/>, and returns a quality score for the whole set.
        /// </summary>
        /// <param name="problem">Test instances: <c>X[j]</c> is the feature vector and <c>Y[j]</c> the true target of instance j.</param>
        /// <param name="outputFile">Path of the prediction file to create, or <c>null</c> to write nothing.</param>
        /// <param name="model">The trained model to evaluate.</param>
        /// <param name="predict_probability">
        /// When <c>true</c>: for EPSILON_SVR / NU_SVR the SVR probability sigma is printed to the console; for any
        /// other SVM type a "labels ..." header line is written, and for C_SVC / NU_SVC each output line holds the
        /// prediction followed by the per-class probability estimates.
        /// </param>
        /// <param name="MaxClassCount">
        /// Only used outside probability mode. When 1 (the default) each output line is the predicted value.
        /// Otherwise each line lists, tab-separated, up to this many indices into the vote array returned by
        /// <c>svm_predict_multi</c>, ordered by descending vote count.
        /// </param>
        /// <returns>
        /// For EPSILON_SVR / NU_SVR: the Pearson correlation coefficient r (not squared) between predicted and true
        /// values. For all other types: the fraction of instances whose prediction equals the true value.
        /// Returns NaN for an empty problem.
        /// </returns>
        public static double Predict(Problem problem, string outputFile, Model model, bool predict_probability, int MaxClassCount = 1)
        {
            // 'using' guarantees the file is closed even if a prediction throws; a null resource is legal and skipped.
            using (StreamWriter writer = (outputFile != null) ? new StreamWriter(outputFile) : null)
            {
                SvmType svmType = Procedures.svm_get_svm_type(model);
                int classCount = Procedures.svm_get_nr_class(model);
                bool isRegression = svmType == SvmType.EPSILON_SVR || svmType == SvmType.NU_SVR;
                bool isProbabilityClassifier = predict_probability && (svmType == SvmType.C_SVC || svmType == SvmType.NU_SVC);

                double[] probEstimates = null;
                if (predict_probability)
                {
                    if (isRegression)
                    {
                        Console.WriteLine("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=" + Procedures.svm_get_svr_probability(model));
                    }
                    else
                    {
                        int[] labels = new int[classCount];
                        Procedures.svm_get_labels(model, labels);
                        probEstimates = new double[classCount];
                        if (writer != null)
                        {
                            writer.Write("labels");
                            for (int i = 0; i < classCount; i++)
                            {
                                writer.Write(" " + labels[i]);
                            }
                            writer.Write("\n");
                        }
                    }
                }

                int correct = 0;
                int total = 0;
                double sumPredicted = 0.0;
                double sumTarget = 0.0;
                double sumPredictedSq = 0.0;
                double sumTargetSq = 0.0;
                double sumCross = 0.0;

                for (int j = 0; j < problem.Count; j++)
                {
                    double target = problem.Y[j];
                    Node[] x = problem.X[j];
                    double predicted;

                    if (isProbabilityClassifier)
                    {
                        predicted = Procedures.svm_predict_probability(model, x, probEstimates);
                        if (writer != null)
                        {
                            writer.Write(predicted + " ");
                            for (int k = 0; k < classCount; k++)
                            {
                                writer.Write(probEstimates[k] + " ");
                            }
                            writer.Write("\n");
                        }
                    }
                    else
                    {
                        predicted = Procedures.svm_predict(model, x);
                        if (MaxClassCount == 1)
                        {
                            if (writer != null)
                            {
                                writer.Write(predicted + "\n");
                            }
                        }
                        else if (writer != null)
                        {
                            // The ranked vote list is only ever written to the file, so it is computed only
                            // when there is a writer (previously a null writer crashed here).
                            int[] votes;
                            Procedures.svm_predict_multi(model, x, out votes);
                            List<KeyValuePair<int, int>> ranked = new List<KeyValuePair<int, int>>(votes.Length);
                            for (int l = 0; l < votes.Length; l++)
                            {
                                ranked.Add(new KeyValuePair<int, int>(l, votes[l]));
                            }
                            // Highest vote count first.
                            ranked.Sort((a, b) => b.Value.CompareTo(a.Value));
                            int shown = Math.Min(MaxClassCount, ranked.Count);
                            for (int m = 0; m < shown; m++)
                            {
                                if (m > 0)
                                {
                                    writer.Write('\t');
                                }
                                writer.Write(ranked[m].Key);
                            }
                            writer.Write("\n");
                        }
                    }

                    // Exact equality is intended: it counts predictions that hit the true label.
                    if (predicted == target)
                    {
                        correct++;
                    }
                    sumPredicted += predicted;
                    sumTarget += target;
                    sumPredictedSq += predicted * predicted;
                    sumTargetSq += target * target;
                    sumCross += predicted * target;
                    total++;
                }

                if (!isRegression)
                {
                    return (double)correct / (double)total;
                }
                double n = total;
                return (n * sumCross - sumPredicted * sumTarget)
                    / (Math.Sqrt(n * sumPredictedSq - sumPredicted * sumPredicted) * Math.Sqrt(n * sumTargetSq - sumTarget * sumTarget));
            }
        }
Example #2
0
		/// <summary>
		/// Parses svm-train style arguments: a run of "-x value" options, then the input data file, then an
		/// optional model file name.
		/// </summary>
		/// <remarks>
		/// Options: -s svm type, -t kernel type, -d degree, -g gamma, -r coef0, -n nu, -m cache size, -c C,
		/// -e eps, -p P, -h shrinking (1 = on), -b probability (1 = on), -v n-fold cross validation,
		/// -wN value (weight for index N). Numbers are parsed with the invariant culture; malformed numbers
		/// surface as the FormatException/OverflowException thrown by int.Parse/double.Parse.
		/// </remarks>
		/// <param name="args">Raw command-line arguments.</param>
		/// <param name="parameters">Receives the parameters built from the options.</param>
		/// <param name="problem">Receives the problem loaded by <c>Problem.Read</c> from the input file argument.</param>
		/// <param name="crossValidation"><c>true</c> when -v was given.</param>
		/// <param name="nrfold">The -v fold count, or 0 when -v was not given.</param>
		/// <param name="modelFilename">The explicit model file argument if present; otherwise the input file's
		/// name with its directory stripped and ".model" appended.</param>
		/// <exception cref="ArgumentException">An option is unknown or lacks its value, -v is below 2, or no
		/// input file follows the options.</exception>
		private static void parseCommandLine(string[] args, out Parameter parameters, out Problem problem, out bool crossValidation, out int nrfold, out string modelFilename)
		{
			// Option values are machine input: parse them independently of the machine's locale so that
			// "-c 0.5" is not misread where ',' is the decimal separator.
			IFormatProvider invariant = System.Globalization.CultureInfo.InvariantCulture;
			parameters = new Parameter();
			crossValidation = false;
			nrfold = 0;
			int num = 0;
			while (num < args.Length && args[num].Length > 0 && args[num][0] == '-')
			{
				string option = args[num];
				if (option.Length < 2)
				{
					// A bare "-" names no option.
					throw new ArgumentException("Unknown Parameter");
				}
				num++;
				if (num >= args.Length)
				{
					// Every option consumes the following argument as its value.
					throw new ArgumentException("Missing value for option " + option);
				}
				string value = args[num];
				switch (option[1])
				{
				case 's':
					parameters.SvmType = (SvmType)int.Parse(value, invariant);
					break;
				case 't':
					parameters.KernelType = (KernelType)int.Parse(value, invariant);
					break;
				case 'd':
					parameters.Degree = int.Parse(value, invariant);
					break;
				case 'g':
					parameters.Gamma = double.Parse(value, invariant);
					break;
				case 'r':
					parameters.Coefficient0 = double.Parse(value, invariant);
					break;
				case 'n':
					parameters.Nu = double.Parse(value, invariant);
					break;
				case 'm':
					parameters.CacheSize = double.Parse(value, invariant);
					break;
				case 'c':
					parameters.C = double.Parse(value, invariant);
					break;
				case 'e':
					parameters.EPS = double.Parse(value, invariant);
					break;
				case 'p':
					parameters.P = double.Parse(value, invariant);
					break;
				case 'h':
					parameters.Shrinking = (int.Parse(value, invariant) == 1);
					break;
				case 'b':
					parameters.Probability = (int.Parse(value, invariant) == 1);
					break;
				case 'v':
					crossValidation = true;
					nrfold = int.Parse(value, invariant);
					if (nrfold < 2)
					{
						throw new ArgumentException("n-fold cross validation: n must >= 2");
					}
					break;
				case 'w':
					// "-wN value": the index N is embedded in the option itself; the weight is the option's
					// value (previously args[1] was read by mistake).
					parameters.Weights[int.Parse(option.Substring(2), invariant)] = double.Parse(value, invariant);
					break;
				default:
					throw new ArgumentException("Unknown Parameter");
				}
				num++;
			}
			if (num >= args.Length)
			{
				throw new ArgumentException("No input file specified");
			}
			string inputFile = args[num];
			problem = Problem.Read(inputFile);
			// A gamma of 0 is treated as "not set" and defaulted to 1/MaxIndex; the MaxIndex > 0 guard
			// avoids producing an infinite or negative gamma.
			if (parameters.Gamma == 0.0 && problem.MaxIndex > 0)
			{
				parameters.Gamma = 1.0 / (double)problem.MaxIndex;
			}
			if (num < args.Length - 1)
			{
				modelFilename = args[num + 1];
			}
			else
			{
				// Strip the directory using the platform's separators (both '/' and '\' on Windows).
				modelFilename = System.IO.Path.GetFileName(inputFile) + ".model";
			}
		}