public DAGSVM(ClassificationProblem problem)
{
    // Keep the problem and derive the pairwise-classifier count:
    // n categories need k = n(n-1)/2 one-vs-one binary classifiers.
    m_problem = problem;
    m_n = problem.CategoryCount;
    m_k = m_n * (m_n - 1) / 2;

    // Root of the decision DAG covers the full category range [0, n-1].
    m_dag = new Node(problem, 0, m_n - 1);
}
public LinearLeraningMachine(ClassificationProblem problem)
{
    // Hold on to the problem and its training set.
    m_problem = problem;
    m_t_set = problem.TrainingSet;
    //this.m_problem.RetrieveVocabulary(out this.m_voc);

    // Number of training examples.
    m_l = m_t_set.Examples.Count;
    //this.m_weight = new SparseVector(m_voc.Count);
}
/// <summary>
/// foamliu, 2009/12/21, please make sure you've uncompressed "2_newsgroups.7z" in the "data" folder.
/// Builds a binary text-classification problem from the two newsgroup
/// folders, splitting each category's files into training and validation
/// sets according to Constants.TrainingSetRatio.
/// </summary>
/// <returns>The populated classification problem.</returns>
private static ClassificationProblem CreateText()
{
    const string DataFolder = @"..\data\2_newsgroups";

    ClassificationProblem problem = new ClassificationProblem();
    ExampleSet t_set = new ExampleSet();
    ExampleSet v_set = new ExampleSet();

    CategoryCollection collect = new CategoryCollection();
    collect.Add(new Category(+1, "+1"));
    collect.Add(new Category(-1, "-1"));

    problem.Dimension = 2;
    problem.CategoryCollection = collect;

    DirectoryInfo dataFolder = new DirectoryInfo(DataFolder);
    DirectoryInfo[] subfolders = dataFolder.GetDirectories();

    for (int i = 0; i < subfolders.Length; i++)
    {
        DirectoryInfo categoryFolder = subfolders[i];

        // Folder index 0 maps to category -1, index 1 to +1.
        // NOTE(review): 'cat' is computed but never assigned to the
        // examples below — confirm whether e.Label should be set from it.
        int cat = i * 2 - 1;

        // Split this category's files: the first trainSetCount files go
        // to training, the rest to validation.
        FileInfo[] files = categoryFolder.GetFiles();
        int trainSetCount = Convert.ToInt32(Constants.TrainingSetRatio * files.Length);
        int count = 0;

        for (int j = 0; j < files.Length; j++)
        {
            Example e = new Example();

            // Fixed off-by-one: post-increment so exactly trainSetCount
            // files land in the training set (the original '++count <'
            // put only trainSetCount - 1 files into training).
            if (count++ < trainSetCount)
            {
                t_set.AddExample(e);
            }
            else
            {
                v_set.AddExample(e);
            }
        }
    }

    problem.TrainingSet = t_set;
    problem.ValidationSet = v_set;
    return problem;
}
private static ClassificationProblem CreateChessBoard()
{
    // Two-class problem (+1 / -1) over 2-D points.
    CategoryCollection categories = new CategoryCollection();
    categories.Add(new Category(+1, "+1"));
    categories.Add(new Category(-1, "-1"));

    ClassificationProblem problem = new ClassificationProblem();
    problem.Dimension = 2;
    problem.CategoryCollection = categories;

    // Independent example draws for the training and validation sets.
    problem.TrainingSet = GetExamples(categories);
    problem.ValidationSet = GetExamples(categories);
    return problem;
}
//private bool m_isleaf;

public Node(ClassificationProblem problem, int first, int second)
{
    m_problem = problem;
    m_first = first;
    m_second = second;

    // Train the 1-vs-1 linear machine for the category pair (first, second).
    m_llm = new LinearLeraningMachine(problem);
    m_llm.Train();

    // Internal node: build children for the two sub-ranges of the DAG.
    // A node is a leaf exactly when second == first + 1, in which case
    // no children are created.
    if (second > first + 1)
    {
        m_leftChild = new Node(problem, first + 1, second);
        m_rightChild = new Node(problem, first, second - 1);
    }
}
private SparseVector m_weight;  // weight vector

#endregion Fields

#region Constructors

public Binary_SVM_SMO(ClassificationProblem problem)
{
    m_problem = problem;
    m_t_set = m_problem.TrainingSet;
    // this.m_problem.RetrieveVocabulary(out this.m_voc);

    // Per-example optimisation state sized to the training set.
    m_l = m_t_set.Examples.Count;
    m_alpha = new double[m_l];
    m_error = new double[m_l];

    m_kernel = new LinearKernel();
    m_NonBound = new List<int>();
    m_rand = new Random();
    m_weight = new SparseVector(problem.Dimension);

    // foamliu, 2009/01/12, default hyper-parameter values.
    m_c = Constants.SVM_C;
    m_eta = Constants.SVM_Eta;
    m_tolerance = Constants.SVM_Tolerance;
    m_epsilon = Constants.SVM_Epsilon;
}
/// <summary>
/// Evaluates the classifier on the problem's validation set and returns
/// the fraction of correctly classified examples.
/// </summary>
/// <param name="problem">Problem supplying the training and validation sets.</param>
/// <returns>
/// Correct-classification ratio in [0, 1]; 0.0 when the validation set
/// is empty.
/// </returns>
public double CrossValidate(ClassificationProblem problem)
{
    //Logging.Info("Retrieving training set");
    ExampleSet t_Set = problem.TrainingSet;
    //Logging.Info("Retrieving validation set");
    ExampleSet v_Set = problem.ValidationSet;   // validation set

    int numExample = v_Set.Examples.Count;

    // Guard against an empty validation set: the original computed
    // 0.0 / 0 here and returned NaN.
    if (numExample == 0)
    {
        Logger.Info("Validation set is empty; correct ratio reported as 0.");
        return 0.0;
    }

    int numCorrect = 0;

    //Logging.Info("Cross Validating on validation set");
    foreach (Example example in v_Set.Examples)
    {
        ClassificationResult result = new ClassificationResult();
        this.PredictText(t_Set, example, ref result);
        if (result.ResultCategoryId == example.Label.Id)
        {
            numCorrect++;
        }
    }

    double correctRatio = 1.0 * numCorrect / numExample;
    Logger.Info(string.Format("Correct ratio: {0}", correctRatio));
    return correctRatio;
}
//private void BuildExample(TextExample example, Vocabulary voc, int exampleCount) //{ // int dimension = voc.Count; // SparseVector vector = new SparseVector(dimension); // foreach (string word in example.Tokens.Keys) // { // int pos = voc.GetWordPosition(word); // if (pos == Constants.KEY_NOT_FOUND) // continue; // // phi i(x) = tfi log(idfi) /k // // tfi: number of occurences of the term i in the document x // // idfi: the ratio between the total number of documents and the // // number of documents containing the term // // k: normalisation constant ensuring that ||phi|| = 1 // double phi = example.Tokens[word] * Math.Log(exampleCount / voc.WordExampleOccurMap[word]); // vector.Components.Add(pos, phi); // } // vector.Normalize(); // example.X = vector; //} //private void Preprocess(ClassificationProblem problem) //{ // Vocabulary voc; // problem.RetrieveVocabulary(out voc); // foreach (Category c in problem.CategoryCollection.Collection) // { // foreach (TextExample e in c.Examples) // { // BuildExample(e, voc, problem.ExampleCount); // } // } // m_weight = new SparseVector(voc.Count); //} /// <summary> /// simple on-line algorithm for the 1-norm soft margin: /// training SVMs in the non-bias case. 
/// </summary>
/// <param name="problem">Problem whose training set is used to fit the SVM.</param>
public void Train(ClassificationProblem problem)
{
    ExampleSet t_Set;   // training set
    //Logging.Info("Retrieving training set");
    t_Set = problem.TrainingSet;
    l = t_Set.Examples.Count;

    //Logging.Info("Preprocessing all the examples");
    //this.Preprocess(problem);

    // Dual variables: m_Alpha holds the current iterate, m_newalpha the
    // freshly computed one (committed via CopyAlphas below).
    m_Alpha = new double[l];
    m_newalpha = new double[l];

    // Start from alpha = 0, a feasible point of the box-constrained dual.
    for (int i=0;i<m_Alpha.Length;i++)
    {
        m_Alpha[i] = 0.0;
    }

    //Logging.Info("Gradient descent");
    while (true)
    {
        for (int i = 0; i < l; i++)
        {
            // temp = sum_j alpha_j * y_j * K(x_i, x_j)
            double temp = 0.0;
            for (int j=0;j<l;j++)
            {
                temp += m_Alpha[j] * t_Set.Examples[j].Label.Id * m_kernel.Compute(t_Set.Examples[i].X, t_Set.Examples[j].X);
            }

            // Gradient step on alpha_i with learning rate SVM_Eta, then
            // clip to the box [0, C] (1-norm soft margin constraint).
            m_newalpha[i] = m_Alpha[i] + Constants.SVM_Eta * (1.0 - t_Set.Examples[i].Label.Id * temp);
            if (m_newalpha[i] < 0.0)
            {
                m_newalpha[i] = 0.0;
            }
            else if (m_newalpha[i] > Constants.SVM_C)
            {
                m_newalpha[i] = Constants.SVM_C;
            }
        }

        // Commit the new iterate and evaluate the objective W.
        this.CopyAlphas();
        W = this.CalculateSVM_W(t_Set);

        // Stop when the relative change in W drops below the tolerance.
        // NOTE(review): if W is 0 this divides by zero; the NaN comparison
        // is false so iteration just continues — confirm W stays nonzero,
        // and note there is no iteration cap if convergence never occurs.
        if (Math.Abs((W - old_W) / W) < Constants.SVM_Tolerance)
        {
            break;
        }
        Logger.Info(string.Format("SVM W = {0}", W));
        old_W = W;
    }

    // Derive the primal weight vector from the final alphas.
    this.CalculateWeight(t_Set);
    //this.CalculateB(t_Set);
}