public override double CalculateDocumentLogProbability(Corpus testCorpus)
{
    double corpusLogP = 0.0;

    // Sentence probability = product of all word probabilities, i.e. a sum in log space
    foreach (var sentence in testCorpus.AllTokenizedSentences)
    {
        // Initialize x_{-1}, x_{-2} to START
        var u = "<s>";
        var v = "<s>";

        double logPs = 0.0;
        foreach (var w in sentence)
        {
            double qWuv = ComputeWordProbability(u, v, w);

            // Add to sentence log-probability
            logPs += Math.Log2(qWuv);

            // Replace previous tokens
            u = v;
            v = w;
        }

        corpusLogP += logPs;
    }

    return corpusLogP;
}
public override void TrainLanguageModel(Corpus trainingCorpus)
{
    // Since we interpolate all three q's for every n-gram, preparing this model
    // means preparing all the base n-gram models
    this.UnigramLM.TrainLanguageModel(trainingCorpus);
    this.BigramLM.TrainLanguageModel(trainingCorpus);
    this.TrigramLM.TrainLanguageModel(trainingCorpus);
}
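// For reference, a minimal sketch of how the interpolated ComputeWordProbability
// could combine the three models above via linear interpolation:
// q(w | u, v) = l1 * q_trigram(w | u, v) + l2 * q_bigram(w | v) + l3 * q_unigram(w).
// The Lambda1/Lambda2/Lambda3 weight properties (assumed to sum to 1) are
// illustrative assumptions, not necessarily the repo's actual field names.
public override double ComputeWordProbability(string u, string v, string w)
{
    return this.Lambda1 * this.TrigramLM.ComputeWordProbability(u, v, w)
         + this.Lambda2 * this.BigramLM.ComputeWordProbability(string.Empty, v, w)
         + this.Lambda3 * this.UnigramLM.ComputeWordProbability(string.Empty, string.Empty, w);
}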
public static void AddStopTokens(Corpus corpus)
{
    foreach (var line in corpus.AllTokenizedSentences)
    {
        line.Add("</s>");
    }

    // Token counts change once STOP tokens are appended, so refresh the total
    corpus.ComputeTotalWordsCount();
}
public static void UnkCorpus(Corpus corpus, HashSet<string> validVocabulary)
{
    foreach (var line in corpus.AllTokenizedSentences)
    {
        for (var i = 0; i < line.Count; i++)
        {
            // Replace any out-of-vocabulary token with the <unk> marker
            if (!validVocabulary.Contains(line[i]))
            {
                line[i] = "<unk>";
            }
        }
    }
}
public static HashSet<string> GetValidVocabulary(Corpus corpus, double unkRatio)
{
    // Implement a naive unk strategy: unk the top n% of lowest-count words
    // TODO parameterize the strategy
    // Order by word count, then alphabetically to ensure determinism. TODO I may want to randomize
    var uniqueSortedTokens = corpus.AllTokenizedSentences
        .SelectMany(s => s)
        .GroupBy(w => w)
        .OrderByDescending(g => g.Count())
        .ThenBy(g => g.Key);

    // Keep only the most frequent (1 - unkRatio) share of the unique words
    var ratioToKeep = 1.0 - unkRatio;
    var wordsToKeep = (int)Math.Floor(uniqueSortedTokens.Count() * ratioToKeep);
    var validVocabulary = uniqueSortedTokens.Take(wordsToKeep).Select(g => g.Key).ToHashSet();

    return validVocabulary;
}
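// Example of how the utilities above fit together: derive the vocabulary from the
// training corpus, unk the corpus against it, then append STOP tokens (the same
// order TrainAllLanguageModels uses below). The 0.05 unk ratio is an arbitrary
// illustrative value.
var validVocabulary = TextProcessingUtilities.GetValidVocabulary(trainingCorpus, 0.05);
TextProcessingUtilities.UnkCorpus(trainingCorpus, validVocabulary);
TextProcessingUtilities.AddStopTokens(trainingCorpus);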
public override void TrainLanguageModel(Corpus trainingCorpus)
{
    // We are fine flattening all sentences into one big stream, as every word probability
    // is independent of any previous one, so there is no risk of wrapping sentences.
    // Enumerate once so we don't keep doing it later on
    var flattenedTokenizedAndProcessedSentences = trainingCorpus.AllTokenizedSentences.SelectMany(s => s).ToList();

    // Add n-gram counts (used in the numerator)
    this.NGramCounts = flattenedTokenizedAndProcessedSentences
        .GroupBy(w => w)
        .ToDictionary(g => new Unigram { w = g.Key }.GetComparisonKey(), g => g.Count());
    this.UniqueNGramsCount = NGramCounts.Count;

    // Add (n-1)-gram counts (used in the denominator).
    // In the case of unigrams, we care about the total token count, including STOP tokens
    this.NGramCounts[new Unigram { w = string.Empty }.GetComparisonKey()] = trainingCorpus.TotalWordsCount;
}
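// A sketch of the matching unigram lookup: the unsmoothed maximum-likelihood
// estimate q(w) = c(w) / totalTokens, built from the keys stored above. The real
// model may smooth these counts; this version returns 0 for unseen words, which
// the <unk> preprocessing should normally prevent.
public override double ComputeWordProbability(string u, string v, string w)
{
    this.NGramCounts.TryGetValue(new Unigram { w = w }.GetComparisonKey(), out int wCount);
    var totalTokens = this.NGramCounts[new Unigram { w = string.Empty }.GetComparisonKey()];
    return (double)wCount / totalTokens;
}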
public override void TrainLanguageModel(Corpus trainingCorpus)
{
    // We need to process sentence by sentence to avoid wrapping sentences, i.e. counting (STOP, v, w) trigrams
    foreach (var sentence in trainingCorpus.AllTokenizedSentences)
    {
        // Initialize x_{-1}, x_{-2} to START
        var u = "<s>";
        var v = "<s>";

        // We now need to store all counts of c(u, v, w) and c(u, v)
        foreach (var w in sentence)
        {
            var uvBigram = new Bigram { v = u, w = v };
            var uvwTrigram = new Trigram { u = u, v = v, w = w };

            // +1 to the current count; it will be 0 if not found, thus starting at 1 as expected
            this.NGramCounts.TryGetValue(uvBigram.GetComparisonKey(), out int uvCount);
            uvCount++;
            this.NGramCounts[uvBigram.GetComparisonKey()] = uvCount;

            var isNewNgram = !this.NGramCounts.TryGetValue(uvwTrigram.GetComparisonKey(), out int uvwCount);
            uvwCount++;
            this.NGramCounts[uvwTrigram.GetComparisonKey()] = uvwCount;

            if (isNewNgram)
            {
                this.UniqueNGramsCount++;
            }

            // Replace previous tokens
            u = v;
            v = w;
        }
    }
}
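// The corresponding maximum-likelihood estimate, as a sketch:
// q(w | u, v) = c(u, v, w) / c(u, v), reusing the key conventions above. The real
// model likely smooths this; a raw 0 for an unseen context would surface as
// -Infinity once Math.Log2 is applied downstream.
public override double ComputeWordProbability(string u, string v, string w)
{
    this.NGramCounts.TryGetValue(new Trigram { u = u, v = v, w = w }.GetComparisonKey(), out int uvwCount);
    this.NGramCounts.TryGetValue(new Bigram { v = u, w = v }.GetComparisonKey(), out int uvCount);
    return uvCount == 0 ? 0.0 : (double)uvwCount / uvCount;
}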
public override double CalculateDocumentLogProbability(Corpus testCorpus)
{
    double corpusLogP = 0.0;

    // Sentence probability = product of all word probabilities, i.e. a sum in log space
    foreach (var sentence in testCorpus.AllTokenizedSentences)
    {
        double logPs = 0.0;
        foreach (var w in sentence)
        {
            double qW = ComputeWordProbability(string.Empty, string.Empty, w);

            // Add to sentence log-probability
            logPs += Math.Log2(qW);
        }

        corpusLogP += logPs;
    }

    return corpusLogP;
}
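// Since the log-probabilities above are base 2, they convert to perplexity as
// 2^(-logP / M), where M is the total token count. A hypothetical helper, not
// part of the original interface:
public double CalculatePerplexity(Corpus testCorpus)
{
    double corpusLogP = CalculateDocumentLogProbability(testCorpus);
    return Math.Pow(2.0, -corpusLogP / testCorpus.TotalWordsCount);
}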
private static void TrainAllLanguageModels(LanguageModelHyperparameters hyperparameters, string crossValIterationPath, Corpus preProcessedCollectionCorpus)
{
    var stopwatch = new Stopwatch();
    var i = 1;
    foreach (var categoryLanguageModel in hyperparameters.CategoryNGramLanguageModelsMap
        .Append(new KeyValuePair<string, INGramLanguageModel>("ALLCATEGORIES", hyperparameters.CollectionLevelLanguageModel)))
    {
        var category = categoryLanguageModel.Key;
        var languageModel = categoryLanguageModel.Value;

        stopwatch.Restart();

        Corpus preProcessedCategoryTrainingCorpus;
        if (category.Equals("ALLCATEGORIES"))
        {
            preProcessedCategoryTrainingCorpus = preProcessedCollectionCorpus;
        }
        else
        {
            preProcessedCategoryTrainingCorpus = new Corpus();
            preProcessedCategoryTrainingCorpus.InitializeAndPreprocessCategoryCorpus(Path.Combine(crossValIterationPath, "training"), category, hyperparameters);
        }

        TextProcessingUtilities.UnkCorpus(preProcessedCategoryTrainingCorpus, Corpus.ValidVocabulary);
        TextProcessingUtilities.AddStopTokens(preProcessedCategoryTrainingCorpus);

        languageModel.TrainLanguageModel(preProcessedCategoryTrainingCorpus);

        stopwatch.Stop();
        //Console.WriteLine($@"LanguageModel for category {category} trained in {stopwatch.ElapsedMilliseconds} ms. {i}/{hyperparameters.CategoryNGramLanguageModelsMap.Count} done");
        i++;
    }
}
static void Main(string[] args)
{
    var appConfigName = "app.config";
    try
    {
        if (!File.Exists(Path.Combine(args[0], appConfigName)))
        {
            Console.WriteLine($"{appConfigName} not found in {args[0]}");
            // Nothing else can run without the config file, so bail out early
            return;
        }
    }
    catch (Exception)
    {
        Console.WriteLine($"{appConfigName} not found in {args[0]}");
        return;
    }

    var configPath = args[0];
    var allRuns = File.ReadAllLines(Path.Combine(configPath, appConfigName)).Where(s => !string.IsNullOrWhiteSpace(s) && !s.StartsWith("##"));

    string dataset = args.Length > 1 && args[1].ToLower().Equals("-usesongs") ? "songs" : "reuters";
    string datasetCrossValRootPath = Path.Combine(configPath, $@"Dataset/{dataset}/CrossVal/");
    int crossValidationValue = new DirectoryInfo(datasetCrossValRootPath).GetDirectories().Length;

    for (int i = 0; i < crossValidationValue; i++)
    {
        Console.WriteLine($@"Cross validation iteration {i + 1}");

        var allHyperparameters = allRuns.Select(r => LanguageModelHyperparameters.GenerateFromArguments(r));
        var crossValIterationPath = Path.Combine(datasetCrossValRootPath, $@"{i + 1}");

        // Our corpus' existing classification is independent of training
        Corpus.InitializeAndFillCategoriesMap(crossValIterationPath);
        NaiveBayesClassifier.InitializeAndFillCategoryTrainingCounts(Corpus.CategoriesMap);

        // Delete previous predictions files
        var dir = new DirectoryInfo(crossValIterationPath);
        foreach (var file in dir.EnumerateFiles("predictions*"))
        {
            file.Delete();
        }

        var runId = 1;
        foreach (var hyperparameters in allHyperparameters)
        {
            var globalStopwatch = new Stopwatch();
            globalStopwatch.Start();

            // We do this here as the vocabulary can change depending on hyperparams
            //Console.WriteLine($@"Parsing all training documents to get valid vocabulary and train collection level unigram model (used by some smoothing techniques)...");
            var allCategoriesTrainingCorpus = new Corpus();
            allCategoriesTrainingCorpus.InitializeAndPreprocessCategoryCorpus(Path.Combine(crossValIterationPath, "training"), "ALLCATEGORIES", hyperparameters);
            Corpus.InitializeAndFillValidVocabulary(allCategoriesTrainingCorpus, hyperparameters);
            //Console.WriteLine($@"Generated valid vocabulary. Elapsed time: {globalStopwatch.ElapsedMilliseconds}");

            TrainAllLanguageModels(hyperparameters, crossValIterationPath, allCategoriesTrainingCorpus);

            //Console.WriteLine();
            //Console.WriteLine($@"Training done in {globalStopwatch.ElapsedMilliseconds} ms");
            //Console.WriteLine();
            //Console.WriteLine($@"Classifying documents");

            var allPredictions = ClassifyAllTestDocuments(hyperparameters, crossValIterationPath);
            File.WriteAllLines(Path.Combine(crossValIterationPath, $@"predictions{runId}"), allPredictions);

            //Console.WriteLine($@"Elapsed time: {globalStopwatch.ElapsedMilliseconds} ms");
            runId++;
        }
    }
}
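// Example invocation (assuming the project is run via the standard dotnet CLI):
//   dotnet run -- /path/to/configDir             -> runs against the reuters dataset
//   dotnet run -- /path/to/configDir -usesongs   -> runs against the songs dataset
// where configDir contains app.config (one run configuration per line, lines
// starting with ## treated as comments) and the Dataset/<dataset>/CrossVal/
// folder structure read above.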
public abstract double CalculateDocumentLogProbability(Corpus corpus);
public abstract void TrainLanguageModel(Corpus corpus);