/// <summary>
/// Scores a trained dependency parser on the given documents, reporting unlabeled (UAS),
/// labeled (LAS) and root attachment accuracy, and restores the gold annotations afterwards.
/// </summary>
private static double TestParser(List<IDocument> testDocuments, AveragePerceptronDependencyParser parser)
{
    // Evaluate only projective sentences with more than four tokens, skipping any containing e-mail addresses or URLs
    var sentences = testDocuments.SelectMany(d => d.Spans)
                                 .Where(s => s.IsProjective() && s.TokensCount > 4 && !s.Any(tk => tk.Value.Contains("@") || tk.Value.Contains("://")))
                                 .ToList();

    int correctUnlabeled = 0, correctLabeled = 0, total = 0, correctRoot = 0;

    var sw = new System.Diagnostics.Stopwatch();
    sw.Start();

    Parallel.ForEach(sentences, s =>
    {
        // Keep the gold heads and labels before the parser overwrites them
        var goldHeads  = s.Select(tk => tk.Head).ToArray();
        var goldLabels = s.Select(tk => tk.DependencyType).ToArray();

        parser.Predict(s);

        int UAS = 0, LAS = 0, tot = 0, ROOT = 0;

        for (int i = 0; i < s.TokensCount; i++)
        {
            //if (!s[i].Value.Any(c => char.IsPunctuation(c)))
            {
                if (goldHeads[i] == -1)
                {
                    ROOT += (goldHeads[i] == s[i].Head) ? 1 : 0;
                }

                bool correctHead  = goldHeads[i]  == s[i].Head;
                bool correctLabel = goldLabels[i] == s[i].DependencyType;

                if (correctHead)                 { UAS++; }
                if (correctHead && correctLabel) { LAS++; }
                tot++;
            }

            //Restore original values
            s[i].Head           = goldHeads[i];
            s[i].DependencyType = goldLabels[i];
        }

        Interlocked.Add(ref correctUnlabeled, UAS);
        Interlocked.Add(ref correctLabeled,   LAS);
        Interlocked.Add(ref total,            tot);
        Interlocked.Add(ref correctRoot,      ROOT);
    });

    sw.Stop();

    double UASscore = 100D * correctUnlabeled / total;
    Logger.LogInformation($"UAS:{UASscore:0.00}% & LAS:{100D * correctLabeled / total:0.00}% & R:{100D * correctRoot / sentences.Count:0.00}% @ {1000D * total / sw.ElapsedMilliseconds:0} tokens/second");
    return UASscore;
}
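/// <summary>
/// Trains part-of-speech tagger and dependency parser models for every language found in the
/// Universal Dependencies corpus at <paramref name="udSource"/>, optionally merging the English
/// OntoNotes 5.0 data from <paramref name="ontonotesSource"/>. The best-scoring model of each
/// kind is stored, and train/test scores for the stored models are reported at the end.
/// </summary>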
public static void Train(string udSource, string ontonotesSource)
{
    var trainFiles = Directory.GetFiles(udSource, "*-train.conllu", SearchOption.AllDirectories);
    var testFiles  = Directory.GetFiles(udSource, "*-dev.conllu",   SearchOption.AllDirectories);

    List<string> trainFilesOntonotesEnglish = null;

    if (!string.IsNullOrWhiteSpace(ontonotesSource))
    {
        trainFilesOntonotesEnglish = Directory.GetFiles(ontonotesSource, "*.parse.ddg", SearchOption.AllDirectories)
                                              .Where(fn => !fn.Contains("sel_") || int.Parse(Path.GetFileNameWithoutExtension(fn).Split(new char[] { '_', '.' }).Skip(1).First()) < 3654)
                                              .ToList();
    }

    // Group the corpus files by the language code encoded in the file name
    var trainFilesPerLanguage = trainFiles.Select(f => new { lang = Path.GetFileNameWithoutExtension(f).Replace("_", "-").Split(new char[] { '-' }).First(), file = f })
                                          .GroupBy(f => f.lang)
                                          .ToDictionary(g => g.Key, g => g.Select(f => f.file).ToList());
    var testFilesPerLanguage  = testFiles.Select(f => new { lang = Path.GetFileNameWithoutExtension(f).Replace("_", "-").Split(new char[] { '-' }).First(), file = f })
                                         .GroupBy(f => f.lang)
                                         .ToDictionary(g => g.Key, g => g.Select(f => f.file).ToList());

    var languages = trainFilesPerLanguage.Keys.ToList();

    Logger.LogInformation($"Found these languages for training: {string.Join(", ", languages)}");

    int N_training = 5;

    Parallel.ForEach(languages, lang =>
    {
        Language language;
        try
        {
            language = Languages.CodeToEnum(lang);
        }
        catch
        {
            Logger.LogWarning($"Unknown language {lang}");
            return;
        }

        var arcNames = new HashSet<string>();

        if (trainFilesPerLanguage.TryGetValue(lang, out var langTrainFiles) && testFilesPerLanguage.TryGetValue(lang, out var langTestFiles))
        {
            var trainDocuments = ReadCorpus(langTrainFiles, arcNames, language);
            var testDocuments  = ReadCorpus(langTestFiles,  arcNames, language);

            if (language == Language.English)
            {
                //Merge with Ontonotes 5.0 corpus
                trainDocuments.AddRange(ReadCorpus(trainFilesOntonotesEnglish, arcNames, language, isOntoNotes: true));
            }

            double bestScore = double.MinValue;

            for (int i = 0; i < N_training; i++)
            {
                var Tagger = new AveragePerceptronTagger(language, 0);
                Tagger.Train(trainDocuments.AsEnumerable(), (int)(5 + ThreadSafeRandom.Next(15)));

                var scoreTrain = TestTagger(trainDocuments, Tagger);
                var scoreTest  = TestTagger(testDocuments,  Tagger);

                if (scoreTest > bestScore)
                {
                    Logger.LogInformation($"\n>>>>> {lang}: NEW POS BEST: {scoreTest:0.0}%");
                    try
                    {
                        Tagger.StoreAsync().Wait();
                    }
                    catch (Exception E)
                    {
                        Logger.LogError(E, $"\n>>>>> {lang}: Failed to store model");
                    }
                    bestScore = scoreTest;
                }
                else
                {
                    Logger.LogInformation($"\n>>>>> {lang}: POS BEST IS STILL : {bestScore:0.0}%");
                }
            }

            bestScore = double.MinValue;

            for (int i = 0; i < N_training; i++)
            {
                var Parser = new AveragePerceptronDependencyParser(language, 0 /*, arcNames.ToList()*/);
                try
                {
                    Parser.Train(trainDocuments.AsEnumerable(), (int)(5 + ThreadSafeRandom.Next(10)), (float)(1D - ThreadSafeRandom.NextDouble() * ThreadSafeRandom.NextDouble()));
                }
                catch (Exception E)
                {
                    Logger.LogInformation("FAIL: " + E.Message);
                    continue;
                }

                //Training mutates the documents, so re-read the corpora before scoring
                trainDocuments = ReadCorpus(langTrainFiles, arcNames, language);
                testDocuments  = ReadCorpus(langTestFiles,  arcNames, language);

                if (language == Language.English)
                {
                    //Merge with Ontonotes 5.0 corpus
                    trainDocuments.AddRange(ReadCorpus(trainFilesOntonotesEnglish, arcNames, language, isOntoNotes: true));
                }

                var scoreTrain = TestParser(trainDocuments, Parser);
                var scoreTest  = TestParser(testDocuments,  Parser);

                if (scoreTest > bestScore)
                {
                    Logger.LogInformation($"\n>>>>> {lang}: NEW DEP BEST: {scoreTest:0.0}%");
                    try
                    {
                        Parser.StoreAsync().Wait();
                    }
                    catch (Exception E)
                    {
                        Logger.LogError(E, $"\n>>>>> {lang}: Failed to store model");
                    }
                    bestScore = scoreTest;
                }
                else
                {
                    Logger.LogInformation($"\n>>>>> {lang}: DEP BEST IS STILL : {bestScore:0.0}%");
                }

                Parser = null;
            }
        }
    });

    // Reload the stored models and report final train/test scores per language
    foreach (var lang in languages)
    {
        Language language;
        try
        {
            language = Languages.CodeToEnum(lang);
        }
        catch
        {
            Logger.LogInformation($"Unknown language {lang}");
            continue; // skip unknown languages instead of aborting the remaining report
        }

        var arcNames = new HashSet<string>();

        var trainDocuments = ReadCorpus(trainFilesPerLanguage[lang], arcNames, language);
        var testDocuments  = ReadCorpus(testFilesPerLanguage[lang],  arcNames, language);

        if (language == Language.English)
        {
            //Merge with Ontonotes 5.0 corpus
            var ontonotesDocuments = ReadCorpus(trainFilesOntonotesEnglish, arcNames, language, isOntoNotes: true);
            trainDocuments.AddRange(ontonotesDocuments);
        }

        var Tagger = AveragePerceptronTagger.FromStoreAsync(language, 0, "").WaitResult();
        Logger.LogInformation($"\n{lang} - TAGGER / TRAIN");
        TestTagger(trainDocuments, Tagger);
        Logger.LogInformation($"\n{lang} - TAGGER / TEST");
        TestTagger(testDocuments, Tagger);

        trainDocuments = ReadCorpus(trainFilesPerLanguage[lang], arcNames, language);
        testDocuments  = ReadCorpus(testFilesPerLanguage[lang],  arcNames, language);

        var Parser = AveragePerceptronDependencyParser.FromStoreAsync(language, 0, "").WaitResult();
        Logger.LogInformation($"\n{lang} - PARSER / TRAIN");
        TestParser(trainDocuments, Parser);
        Logger.LogInformation($"\n{lang} - PARSER / TEST");
        TestParser(testDocuments, Parser);
    }
}
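/// <summary>
/// Asynchronous variant of the training pipeline: trains tagger and parser models per language,
/// stores the best ones, and additionally writes tagger.bin/parser.bin plus *.score files under
/// <paramref name="languagesDirectory"/> for the NuGet-based model distribution whenever a model
/// scores above 80%.
/// </summary>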
public static async Task Train(string udSource, string ontonotesSource, string languagesDirectory)
{
    var trainFiles = Directory.GetFiles(udSource, "*-train.conllu", SearchOption.AllDirectories);
    var testFiles  = Directory.GetFiles(udSource, "*-dev.conllu",   SearchOption.AllDirectories);

    List<string> trainFilesOntonotesEnglish = null;

    if (!string.IsNullOrWhiteSpace(ontonotesSource))
    {
        trainFilesOntonotesEnglish = Directory.GetFiles(ontonotesSource, "*.parse.ddg", SearchOption.AllDirectories)
                                              .Where(fn => !fn.Contains("sel_") || int.Parse(Path.GetFileNameWithoutExtension(fn).Split(new char[] { '_', '.' }).Skip(1).First()) < 3654)
                                              .ToList();
    }

    // Group the corpus files by the language code encoded in the file name
    var trainFilesPerLanguage = trainFiles.Select(f => new { lang = Path.GetFileNameWithoutExtension(f).Replace("_", "-").Split(new char[] { '-' }).First(), file = f })
                                          .GroupBy(f => f.lang)
                                          .ToDictionary(g => g.Key, g => g.Select(f => f.file).ToList());
    var testFilesPerLanguage  = testFiles.Select(f => new { lang = Path.GetFileNameWithoutExtension(f).Replace("_", "-").Split(new char[] { '-' }).First(), file = f })
                                         .GroupBy(f => f.lang)
                                         .ToDictionary(g => g.Key, g => g.Select(f => f.file).ToList());

    var languages = new List<(Language language, string lang)>();

    foreach (var lang in trainFilesPerLanguage.Keys.Intersect(testFilesPerLanguage.Keys))
    {
        try
        {
            var language = Languages.CodeToEnum(lang);
            languages.Add((language, lang));
        }
        catch
        {
            Logger.LogWarning($"Unknown language {lang}");
        }
    }

    Logger.LogInformation($"Found these languages for training: {string.Join(", ", languages.Select(l => l.language))}");

    int attempts = 5;

    await Task.WhenAll(languages.Select(async v =>
    {
        await Task.Yield();
        var (language, lang) = (v.language, v.lang);
        var arcNames = new HashSet<string>();

        if (trainFilesPerLanguage.TryGetValue(lang, out var langTrainFiles) && testFilesPerLanguage.TryGetValue(lang, out var langTestFiles))
        {
            var trainDocuments = await ReadCorpusAsync(langTrainFiles, arcNames, language);
            var testDocuments  = await ReadCorpusAsync(langTestFiles,  arcNames, language);

            if (language == Language.English)
            {
                //Merge with Ontonotes 5.0 corpus, split to keep the same train/test ratio as the UD data
                var testToTrain = (int)((float)trainFilesOntonotesEnglish.Count * testDocuments.Count / trainDocuments.Count);
                trainDocuments.AddRange(await ReadCorpusAsync(trainFilesOntonotesEnglish.Take(testToTrain).ToList(), arcNames, language, isOntoNotes: true));
                testDocuments.AddRange(await ReadCorpusAsync(trainFilesOntonotesEnglish.Skip(testToTrain).ToList(), arcNames, language, isOntoNotes: true));
            }

            double bestScore = double.MinValue;

            for (int i = 0; i < attempts; i++)
            {
                await Task.Run(async () =>
                {
                    var tagger = new AveragePerceptronTagger(language, 0);
                    tagger.Train(trainDocuments, 5 + ThreadSafeRandom.Next(15));

                    var scoreTrain = TestTagger(trainDocuments, tagger);
                    var scoreTest  = TestTagger(testDocuments,  tagger);

                    if (scoreTest > bestScore)
                    {
                        Logger.LogInformation($"\n>>>>> {language}: NEW POS BEST: {scoreTest:0.0}%");
                        await tagger.StoreAsync();

                        if (scoreTest > 80)
                        {
                            //Prepare models for new nuget-based distribution
                            var resDir = Path.Combine(languagesDirectory, language.ToString(), "Resources");
                            Directory.CreateDirectory(resDir);
                            using (var f = File.OpenWrite(Path.Combine(resDir, "tagger.bin")))
                            {
                                await tagger.StoreAsync(f);
                            }
                            await File.WriteAllTextAsync(Path.Combine(resDir, "tagger.score"), $"{scoreTest:0.0}%");
                        }
                        bestScore = scoreTest;
                    }
                    else
                    {
                        Logger.LogInformation($"\n>>>>> {language}: POS BEST IS STILL : {bestScore:0.0}%");
                    }
                });
            }

            bestScore = double.MinValue;

            for (int i = 0; i < attempts; i++)
            {
                await Task.Run(async () =>
                {
                    var parser = new AveragePerceptronDependencyParser(language, 0 /*, arcNames.ToList()*/);
                    try
                    {
                        parser.Train(trainDocuments, 5 + ThreadSafeRandom.Next(10), (float)(1D - ThreadSafeRandom.NextDouble() * ThreadSafeRandom.NextDouble()));
                    }
                    catch (Exception E)
                    {
                        Logger.LogError(E, "FAIL");
                        return;
                    }

                    //Training mutates the documents, so re-read the corpora before scoring
                    trainDocuments = await ReadCorpusAsync(langTrainFiles, arcNames, language);
                    testDocuments  = await ReadCorpusAsync(langTestFiles,  arcNames, language);

                    if (language == Language.English)
                    {
                        //Merge with Ontonotes 5.0 corpus
                        var testToTrain = (int)((float)trainFilesOntonotesEnglish.Count * testDocuments.Count / trainDocuments.Count);
                        trainDocuments.AddRange(await ReadCorpusAsync(trainFilesOntonotesEnglish.Take(testToTrain).ToList(), arcNames, language, isOntoNotes: true));
                        testDocuments.AddRange(await ReadCorpusAsync(trainFilesOntonotesEnglish.Skip(testToTrain).ToList(), arcNames, language, isOntoNotes: true));
                    }

                    var scoreTrain = TestParser(trainDocuments, parser);
                    var scoreTest  = TestParser(testDocuments,  parser);

                    if (scoreTest > bestScore)
                    {
                        Logger.LogInformation($"\n>>>>> {language}: NEW DEP BEST: {scoreTest:0.0}%");

                        if (scoreTest > 80)
                        {
                            //Prepare models for new nuget-based distribution
                            var resDir = Path.Combine(languagesDirectory, language.ToString(), "Resources");
                            Directory.CreateDirectory(resDir);
                            using (var f = File.OpenWrite(Path.Combine(resDir, "parser.bin")))
                            {
                                await parser.StoreAsync(f);
                            }
                            await File.WriteAllTextAsync(Path.Combine(resDir, "parser.score"), $"{scoreTest:0.0}%");
                        }
                        bestScore = scoreTest;
                    }
                    else
                    {
                        Logger.LogInformation($"\n>>>>> {language}: DEP BEST IS STILL : {bestScore:0.0}%");
                    }
                    parser = null;
                });
            }
        }
    }));

    // Reload the stored models and report final train/test scores per language
    foreach (var (language, lang) in languages)
    {
        var arcNames = new HashSet<string>();

        var trainDocuments = await ReadCorpusAsync(trainFilesPerLanguage[lang], arcNames, language);
        var testDocuments  = await ReadCorpusAsync(testFilesPerLanguage[lang],  arcNames, language);

        if (language == Language.English)
        {
            //Merge with Ontonotes 5.0 corpus
            var testToTrain = (int)((float)trainFilesOntonotesEnglish.Count * testDocuments.Count / trainDocuments.Count);
            trainDocuments.AddRange(await ReadCorpusAsync(trainFilesOntonotesEnglish.Take(testToTrain).ToList(), arcNames, language, isOntoNotes: true));
            testDocuments.AddRange(await ReadCorpusAsync(trainFilesOntonotesEnglish.Skip(testToTrain).ToList(), arcNames, language, isOntoNotes: true));
        }

        var tagger = await AveragePerceptronTagger.FromStoreAsync(language, 0, "");
        Logger.LogInformation($"\n{lang} - TAGGER / TRAIN");
        TestTagger(trainDocuments, tagger);
        Logger.LogInformation($"\n{lang} - TAGGER / TEST");
        TestTagger(testDocuments, tagger);

        trainDocuments = await ReadCorpusAsync(trainFilesPerLanguage[lang], arcNames, language);
        testDocuments  = await ReadCorpusAsync(testFilesPerLanguage[lang],  arcNames, language);

        var parser = await AveragePerceptronDependencyParser.FromStoreAsync(language, 0, "");
        Logger.LogInformation($"\n{lang} - PARSER / TRAIN");
        TestParser(trainDocuments, parser);
        Logger.LogInformation($"\n{lang} - PARSER / TEST");
        TestParser(testDocuments, parser);
    }
}