public bool Equals(Model other) {
    if (ReferenceEquals(null, other)) return false;
    if (ReferenceEquals(this, other)) return true;
    // label and w may be null (label is null for regression models), so guard the sequence comparisons
    return other.bias.Equals(bias)
        && (other.label == label || (other.label != null && label != null && other.label.SequenceEqual(label)))
        && other.nr_class == nr_class
        && other.nr_feature == nr_feature
        && Equals(other.solverType, solverType)
        && (other.w == w || (other.w != null && w != null && other.w.SequenceEqual(w)));
}
/**
 * <p><b>Note: The streams are NOT closed</b></p>
 */
public static void doPredict(StreamReader reader, StreamWriter writer, Model model) {
    int correct = 0;
    int total = 0;
    double error = 0;
    double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;

    int nr_class = model.getNrClass();
    double[] prob_estimates = null;
    int n;
    int nr_feature = model.getNrFeature();
    if (model.bias >= 0)
        n = nr_feature + 1;
    else
        n = nr_feature;

    if (flag_predict_probability && !model.isProbabilityModel()) {
        throw new ArgumentException("probability output is only supported for logistic regression");
    }

    if (flag_predict_probability) {
        int[] labels = model.getLabels();
        prob_estimates = new double[nr_class];
        writer.Write("labels");
        for (int j = 0; j < nr_class; j++)
            writer.Write(" {0}", labels[j]);
        writer.WriteLine();
    }

    String line = null;
    while ((line = reader.ReadLine()) != null) {
        List<Feature> x = new List<Feature>();
        string[] parts = line.Split(new[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries);
        double target_label;
        if (parts.Length == 0) {
            throw new InvalidOperationException("Wrong input format at line " + (total + 1));
        }

        String label = parts[0];
        target_label = Linear.atof(label);

        foreach (var token in parts.Skip(1)) {
            string[] split = token.Split(':');
            if (split.Length < 2) {
                throw new InvalidOperationException("Wrong input format at line " + (total + 1));
            }

            try {
                int idx = Linear.atoi(split[0]);
                double val = Linear.atof(split[1]);

                // feature indices larger than those in training are not used
                if (idx <= nr_feature) {
                    Feature node = new Feature(idx, val);
                    x.Add(node);
                }
            } catch (FormatException e) {
                throw new InvalidOperationException("Wrong input format at line " + (total + 1), e);
            }
        }

        if (model.bias >= 0) {
            Feature node = new Feature(n, model.bias);
            x.Add(node);
        }

        Feature[] nodes = x.ToArray();

        double predict_label;
        if (flag_predict_probability) {
            Debug.Assert(prob_estimates != null);
            predict_label = Linear.predictProbability(model, nodes, prob_estimates);
            // write the prediction followed by the per-class probability estimates to the output writer
            writer.Write(predict_label);
            for (int j = 0; j < model.nr_class; j++)
                writer.Write(" {0}", prob_estimates[j]);
            writer.WriteLine();
        } else {
            predict_label = Linear.predict(model, nodes);
            writer.WriteLine("{0}", predict_label);
        }

        if (predict_label == target_label) {
            ++correct;
        }

        error += (predict_label - target_label) * (predict_label - target_label);
        sump += predict_label;
        sumt += target_label;
        sumpp += predict_label * predict_label;
        sumtt += target_label * target_label;
        sumpt += predict_label * target_label;
        ++total;
    }

    if (model.solverType.isSupportVectorRegression()) {
        Linear.info("Mean squared error = {0} (regression)", error / total);
        Linear.info("Squared correlation coefficient = {0} (regression)",
            ((total * sumpt - sump * sumt) * (total * sumpt - sump * sumt))
                / ((total * sumpp - sump * sump) * (total * sumtt - sumt * sumt)));
    } else {
        Linear.info("Accuracy = {0} ({1}/{2})", (double)correct / total * 100, correct, total);
    }
}
/**
 * Writes the model to the modelOutput.
 * It uses {@link java.util.Locale#ENGLISH} for number formatting.
 *
 * <p><b>Note: The modelOutput is flushed but NOT closed.</b></p>
 */
public static void saveModel(StreamWriter modelOutput, Model model) {
    int nr_feature = model.nr_feature;
    int w_size = nr_feature;
    if (model.bias >= 0) w_size++;

    int nr_w = model.nr_class;
    if (model.nr_class == 2 && model.solverType.getId() != SolverType.MCSVM_CS) nr_w = 1;

    modelOutput.WriteLine("solver_type {0}", model.solverType.Name);
    modelOutput.WriteLine("nr_class {0}", model.nr_class);

    if (model.label != null) {
        modelOutput.Write("label");
        for (int i = 0; i < model.nr_class; i++) {
            modelOutput.Write(" {0}", model.label[i]);
        }
        modelOutput.WriteLine();
    }

    modelOutput.WriteLine("nr_feature {0}", nr_feature);
    modelOutput.WriteLine("bias {0}", model.bias);

    modelOutput.WriteLine("w");
    for (int i = 0; i < w_size; i++) {
        for (int j = 0; j < nr_w; j++) {
            double value = model.w[i * nr_w + j];

            /** this optimization is the reason for {@link Model#equals(double[], double[])} */
            if (value == 0.0) {
                modelOutput.Write("{0} ", 0);
            } else {
                modelOutput.Write("{0} ", value);
            }
        }
        modelOutput.WriteLine();
    }

    modelOutput.Flush();
}
/**
 * Writes the model to the file with ISO-8859-1 charset.
 * It uses {@link java.util.Locale#ENGLISH} for number formatting.
 */
public static void saveModel(FileInfo modelFile, Model model) {
    // File.Create truncates an existing file; File.OpenWrite would leave stale bytes behind a shorter model
    using (StreamWriter sw = new StreamWriter(File.Create(modelFile.FullName), FILE_CHARSET)) {
        saveModel(sw, model);
    }
}
public static double predictValues(Model model, Feature[] x, double[] dec_values) {
    int n;
    if (model.bias >= 0)
        n = model.nr_feature + 1;
    else
        n = model.nr_feature;

    double[] w = model.w;

    int nr_w;
    if (model.nr_class == 2 && model.solverType.getId() != SolverType.MCSVM_CS)
        nr_w = 1;
    else
        nr_w = model.nr_class;

    for (int i = 0; i < nr_w; i++)
        dec_values[i] = 0;

    foreach (Feature lx in x) {
        int idx = lx.Index;
        // the dimension of testing data may exceed that of training
        if (idx <= n) {
            for (int i = 0; i < nr_w; i++) {
                dec_values[i] += w[(idx - 1) * nr_w + i] * lx.Value;
            }
        }
    }

    if (model.nr_class == 2) {
        if (model.solverType.isSupportVectorRegression())
            return dec_values[0];
        else
            return (dec_values[0] > 0) ? model.label[0] : model.label[1];
    } else {
        int dec_max_idx = 0;
        for (int i = 1; i < model.nr_class; i++) {
            if (dec_values[i] > dec_values[dec_max_idx]) dec_max_idx = i;
        }
        return model.label[dec_max_idx];
    }
}
/**
 * @throws ArgumentException if model is not probabilistic (see {@link Model#isProbabilityModel()})
 */
public static double predictProbability(Model model, Feature[] x, double[] prob_estimates) {
    if (!model.isProbabilityModel()) {
        StringBuilder sb = new StringBuilder("probability output is only supported for logistic regression");
        sb.Append(". This is currently only supported by the following solvers: ");
        int i = 0;
        foreach (SolverType solverType in SolverType.values()) {
            if (solverType.isLogisticRegressionSolver()) {
                if (i++ > 0) {
                    sb.Append(", ");
                }
                sb.Append(solverType.Name);
            }
        }
        throw new ArgumentException(sb.ToString());
    }

    int nr_class = model.nr_class;
    int nr_w;
    if (nr_class == 2)
        nr_w = 1;
    else
        nr_w = nr_class;

    double label = predictValues(model, x, prob_estimates);
    for (int i = 0; i < nr_w; i++)
        prob_estimates[i] = 1 / (1 + Math.Exp(-prob_estimates[i]));

    if (nr_class == 2) // for binary classification
        prob_estimates[1] = 1.0 - prob_estimates[0];
    else {
        double sum = 0;
        for (int i = 0; i < nr_class; i++)
            sum += prob_estimates[i];

        for (int i = 0; i < nr_class; i++)
            prob_estimates[i] = prob_estimates[i] / sum;
    }

    return label;
}
public static double predict(Model model, Feature[] x) {
    double[] dec_values = new double[model.nr_class];
    return predictValues(model, x, dec_values);
}
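/*
 * Illustrative sketch, not part of the ported library: shows how the prediction helpers above
 * are typically called. It assumes an already trained Model and reuses only members that appear
 * in this section (Feature(int, double), Linear.predict, Linear.predictProbability,
 * Model.isProbabilityModel, Model.getNrClass, Model.getLabels); the method name ExamplePredict
 * is made up for the example.
 */
private static void ExamplePredict(Model model) {
    // a sparse instance: 1-based feature indices, sorted in ascending order
    Feature[] instance = { new Feature(1, 0.5), new Feature(3, 1.0) };

    // plain prediction: the predicted class label (or the regression value for SVR models)
    double predictedLabel = Linear.predict(model, instance);
    Console.WriteLine("predicted label: {0}", predictedLabel);

    // probability estimates are only available for logistic-regression solvers
    if (model.isProbabilityModel()) {
        double[] probabilities = new double[model.getNrClass()];
        Linear.predictProbability(model, instance, probabilities);
        int[] labels = model.getLabels();
        for (int j = 0; j < labels.Length; j++) {
            Console.WriteLine("P(y = {0}) = {1}", labels[j], probabilities[j]);
        }
    }
}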
/**
 * Loads the model from inputReader.
 * It uses {@link java.util.Locale#ENGLISH} for number formatting.
 *
 * <p>Note: The inputReader is <b>NOT closed</b> after reading or in case of an exception.</p>
 */
public static Model loadModel(StreamReader inputReader) {
    Model model = new Model();

    model.label = null;

    String line = null;
    while ((line = inputReader.ReadLine()) != null) {
        string[] split = line.Split(new[] { ' ', '\t' }, StringSplitOptions.RemoveEmptyEntries);
        if (split[0].Equals("solver_type")) {
            SolverType solver = SolverType.values().FirstOrDefault(v => v.Name == split[1]);
            if (solver == null) {
                throw new InvalidOperationException("unknown solver type: " + split[1]);
            }
            model.solverType = solver;
        } else if (split[0].Equals("nr_class")) {
            model.nr_class = atoi(split[1]);
        } else if (split[0].Equals("nr_feature")) {
            model.nr_feature = atoi(split[1]);
        } else if (split[0].Equals("bias")) {
            model.bias = atof(split[1]);
        } else if (split[0].Equals("w")) {
            break;
        } else if (split[0].Equals("label")) {
            model.label = new int[model.nr_class];
            for (int i = 0; i < model.nr_class; i++) {
                model.label[i] = atoi(split[i + 1]);
            }
        } else {
            throw new InvalidOperationException("unknown text in model file: [" + line + "]");
        }
    }

    int w_size = model.nr_feature;
    if (model.bias >= 0) w_size++;

    int nr_w = model.nr_class;
    if (model.nr_class == 2 && model.solverType.getId() != SolverType.MCSVM_CS) nr_w = 1;

    model.w = new double[w_size * nr_w];
    char[] buffer = new char[128];

    for (int i = 0; i < w_size; i++) {
        for (int j = 0; j < nr_w; j++) {
            int b = 0;
            while (true) {
                int ch = inputReader.Read();
                if (ch == -1) {
                    throw new EndOfStreamException("unexpected EOF");
                }
                if (ch == '\r' || ch == '\n') {
                    // skip the line break that separates the rows of w
                    continue;
                }
                if (ch == ' ') {
                    model.w[i * nr_w + j] = atof(new string(buffer, 0, b));
                    break;
                } else {
                    buffer[b++] = (char)ch;
                }
            }
        }
    }

    return model;
}
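/*
 * Illustrative sketch, not part of the ported library: round-trips a model through saveModel
 * and loadModel using an in-memory stream. It assumes this helper sits in the same class as
 * the two methods above; MemoryStream, StreamWriter and StreamReader are standard .NET types.
 */
private static Model ExampleSaveAndReload(Model model) {
    using (MemoryStream buffer = new MemoryStream()) {
        StreamWriter writer = new StreamWriter(buffer);
        saveModel(writer, model);   // saveModel flushes but does not close the writer

        buffer.Position = 0;        // rewind before reading the model back
        StreamReader reader = new StreamReader(buffer);
        return loadModel(reader);   // loadModel does not close the reader either
    }
}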
/**
 * @throws ArgumentException if the feature nodes of prob are not sorted in ascending order
 */
public static Model train(Problem prob, Parameter param) {
    if (prob == null) throw new ArgumentNullException("prob", "problem must not be null");
    if (param == null) throw new ArgumentNullException("param", "parameter must not be null");
    if (prob.n == 0) throw new ArgumentException("problem has zero features");
    if (prob.l == 0) throw new ArgumentException("problem has zero instances");

    foreach (Feature[] nodes in prob.x) {
        int indexBefore = 0;
        foreach (Feature n_ in nodes) {
            if (n_.Index <= indexBefore) {
                throw new ArgumentException("feature nodes must be sorted by index in ascending order");
            }
            indexBefore = n_.Index;
        }
    }

    int l = prob.l;
    int n = prob.n;
    int w_size = prob.n;
    Model model = new Model();

    if (prob.bias >= 0)
        model.nr_feature = n - 1;
    else
        model.nr_feature = n;

    model.solverType = param.solverType;
    model.bias = prob.bias;

    if (param.solverType.getId() == SolverType.L2R_L2LOSS_SVR
        || param.solverType.getId() == SolverType.L2R_L1LOSS_SVR_DUAL
        || param.solverType.getId() == SolverType.L2R_L2LOSS_SVR_DUAL) {
        model.w = new double[w_size];
        model.nr_class = 2;
        model.label = null;

        checkProblemSize(n, model.nr_class);

        train_one(prob, param, model.w, 0, 0);
    } else {
        int[] perm = new int[l];

        // group training data of the same class
        GroupClassesReturn rv = groupClasses(prob, perm);
        int nr_class = rv.nr_class;
        int[] label = rv.label;
        int[] start = rv.start;
        int[] count = rv.count;

        checkProblemSize(n, nr_class);

        model.nr_class = nr_class;
        model.label = new int[nr_class];
        for (int i = 0; i < nr_class; i++)
            model.label[i] = label[i];

        // calculate weighted C
        double[] weighted_C = new double[nr_class];
        for (int i = 0; i < nr_class; i++)
            weighted_C[i] = param.C;

        for (int i = 0; i < param.getNumWeights(); i++) {
            int j;
            for (j = 0; j < nr_class; j++)
                if (param.weightLabel[i] == label[j]) break;
            if (j == nr_class) throw new ArgumentException("class label " + param.weightLabel[i] + " specified in weight is not found");

            weighted_C[j] *= param.weight[i];
        }

        // constructing the subproblem
        Feature[][] x = new Feature[l][];
        for (int i = 0; i < l; i++)
            x[i] = prob.x[perm[i]];

        Problem sub_prob = new Problem();
        sub_prob.l = l;
        sub_prob.n = n;
        sub_prob.x = new Feature[sub_prob.l][];
        sub_prob.y = new double[sub_prob.l];

        for (int k = 0; k < sub_prob.l; k++)
            sub_prob.x[k] = x[k];

        // multi-class svm by Crammer and Singer
        if (param.solverType.getId() == SolverType.MCSVM_CS) {
            model.w = new double[n * nr_class];
            for (int i = 0; i < nr_class; i++) {
                for (int j = start[i]; j < start[i] + count[i]; j++) {
                    sub_prob.y[j] = i;
                }
            }

            SolverMCSVM_CS solver = new SolverMCSVM_CS(sub_prob, nr_class, weighted_C, param.eps);
            solver.solve(model.w);
        } else {
            if (nr_class == 2) {
                model.w = new double[w_size];

                int e0 = start[0] + count[0];
                int k = 0;

                for (; k < e0; k++)
                    sub_prob.y[k] = +1;
                for (; k < sub_prob.l; k++)
                    sub_prob.y[k] = -1;

                train_one(sub_prob, param, model.w, weighted_C[0], weighted_C[1]);
            } else {
                model.w = new double[w_size * nr_class];
                double[] w = new double[w_size];
                for (int i = 0; i < nr_class; i++) {
                    int si = start[i];
                    int ei = si + count[i];

                    int k = 0;
                    for (; k < si; k++)
                        sub_prob.y[k] = -1;
                    for (; k < ei; k++)
                        sub_prob.y[k] = +1;
                    for (; k < sub_prob.l; k++)
                        sub_prob.y[k] = -1;

                    train_one(sub_prob, param, w, weighted_C[i], param.C);

                    for (int j = 0; j < n; j++)
                        model.w[j * nr_class + i] = w[j];
                }
            }
        }
    }
    return model;
}
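/*
 * Illustrative sketch, not part of the ported library: builds a tiny two-class Problem and
 * trains a model on it. Problem, Feature and SolverType are used as they appear above; the
 * Parameter constructor (SolverType, C, eps) is an assumption borrowed from other liblinear
 * ports and may need to be adapted, as may the availability of a solver named "L2R_LR".
 */
private static Model ExampleTrain() {
    Problem prob = new Problem();
    prob.l = 4;        // number of training instances
    prob.n = 2;        // number of features (prob.bias < 0, so no extra bias feature)
    prob.bias = -1;
    prob.y = new double[] { 1, 1, -1, -1 };
    prob.x = new Feature[][] {
        new[] { new Feature(1, 1.0), new Feature(2, 0.5) },   // indices must be ascending
        new[] { new Feature(1, 0.8) },
        new[] { new Feature(2, -1.0) },
        new[] { new Feature(1, -0.5), new Feature(2, -0.2) }
    };

    // pick a logistic-regression solver so that predictProbability can be used later
    SolverType solver = SolverType.values().First(s => s.Name == "L2R_LR");
    Parameter param = new Parameter(solver, 1.0, 0.01);   // assumed constructor: C = 1.0, eps = 0.01

    return Linear.train(prob, param);
}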