public void TrainWord2Sense(IEnumerable<IDocument> documents, ParallelOptions parallelOptions, int ngrams = 3, double tooRare = 1E-5, double tooCommon = 0.1, Word2SenseTrainingData trainingData = null)
{
    if (trainingData == null) { trainingData = new Word2SenseTrainingData(); }

    var stopwords = new HashSet<ulong>(StopWords.Spacy.For(Language)
                                                .Select(w => Data.IgnoreCase ? IgnoreCaseHash64(w.AsSpan()) : Hash64(w.AsSpan()))
                                                .ToArray());

    // Character sets for the ordinal-number heuristic below, hoisted out of the
    // token loop so they are allocated once instead of once per token
    var digits             = new char[] { '1', '2', '3', '4', '5', '6', '7', '8', '9', '0' };
    var ordinalSuffixChars = new char[] { 't', 'h', 's', 'r', 'd' };
    var nonOrdinalLetters  = new char[] { 'a', 'b', 'c', 'e', 'f', 'g', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'u', 'v', 'w', 'x', 'y', 'z' };

    int docCount = 0, tkCount = 0;
    var sw = Stopwatch.StartNew();

    TrainLock.EnterWriteLock();
    try
    {
        // NOTE: trainingData.Words, HashCount and Senses are plain dictionaries shared by all
        // worker threads below; Dictionary<TKey,TValue> is not safe for concurrent writes, so
        // entries can be lost or corrupted under contention. The revised version of this method
        // below replaces them with ConcurrentDictionary.
        Parallel.ForEach(documents, parallelOptions, doc =>
        {
            try
            {
                var stack = new Queue<ulong>(ngrams);

                if (doc.TokensCount < ngrams) { return; } // Ignore documents that are too small

                Interlocked.Add(ref tkCount, doc.TokensCount);

                foreach (var span in doc)
                {
                    var tokens = span.GetTokenized().ToArray();

                    for (int i = 0; i < tokens.Length; i++)
                    {
                        var tk = tokens[i];

                        var hash = Data.IgnoreCase ? IgnoreCaseHash64(tk.ValueAsSpan) : Hash64(tk.ValueAsSpan);

                        bool filterPartOfSpeech        = !(tk.POS == PartOfSpeech.ADJ || tk.POS == PartOfSpeech.NOUN || tk.POS == PartOfSpeech.PROPN);
                        bool skipIfHasUpperCase        = !Data.IgnoreCase && !tk.ValueAsSpan.IsAllLowerCase();
                        bool skipIfTooSmall            = tk.Length < 3;
                        bool skipIfNotAllLetterOrDigit = !tk.ValueAsSpan.IsAllLetterOrDigit();
                        bool skipIfStopWordOrEntity    = stopwords.Contains(hash) || tk.EntityTypes.Any();

                        // Heuristic for ordinal numbers (i.e. 1st, 2nd, 33rd, etc.): contains a digit
                        // and an ordinal-suffix letter, but no other letters
                        bool skipIfMaybeOrdinal = tk.ValueAsSpan.IndexOfAny(digits, 0) >= 0 &&
                                                  tk.ValueAsSpan.IndexOfAny(ordinalSuffixChars, 0) >= 0 &&
                                                  tk.ValueAsSpan.IndexOfAny(nonOrdinalLetters, 0) < 0;

                        bool skipThisToken = filterPartOfSpeech || skipIfHasUpperCase || skipIfTooSmall ||
                                             skipIfNotAllLetterOrDigit || skipIfStopWordOrEntity || skipIfMaybeOrdinal;

                        if (skipThisToken)
                        {
                            // A skipped token breaks the current n-gram window
                            stack.Clear();
                            continue;
                        }

                        if (!trainingData.Words.ContainsKey(hash))
                        {
                            trainingData.Words[hash] = Data.IgnoreCase ? tk.Value.ToLowerInvariant() : tk.Value;
                        }

                        stack.Enqueue(hash);

                        // Count every n-gram (length 2..stack.Count) that ends at the current token,
                        // remembering the token hashes of each n-gram the first time it is seen
                        ulong combined = stack.ElementAt(0);
                        for (int j = 1; j < stack.Count; j++)
                        {
                            combined = HashCombine64(combined, stack.ElementAt(j));
                            if (trainingData.HashCount.ContainsKey(combined))
                            {
                                trainingData.HashCount[combined]++;
                            }
                            else
                            {
                                trainingData.Senses[combined] = stack.Take(j + 1).ToArray();
                                trainingData.HashCount[combined] = 1;
                            }
                        }

                        if (stack.Count > ngrams) { stack.Dequeue(); }
                    }
                }

                int count = Interlocked.Increment(ref docCount);

                if (count % 1000 == 0)
                {
                    var mem = GC.GetTotalMemory(false);
                    var kTokensPerSecond = sw.ElapsedMilliseconds > 0 ? tkCount / sw.ElapsedMilliseconds : 0; // Guard against division by zero
                    Logger.LogInformation("[MEM: {MEMORY} MB] Training Word2Sense model - at {DOCCOUNT} documents, {TKCOUNT} tokens - elapsed {ELAPSED} seconds at {KTKS} kTk/s",
                                          Math.Round(mem / 1048576.0, 2), docCount, tkCount, sw.Elapsed.TotalSeconds, kTokensPerSecond);
                }
            }
            catch (Exception E)
            {
                Logger.LogError(E, "Error during training Word2Sense model");
            }
        });
    }
    catch (OperationCanceledException)
    {
        return;
    }
    finally
    {
        TrainLock.ExitWriteLock();
    }

    Logger.LogInformation("Finished parsing documents for Word2Sense model");

    // Keep only n-grams that are neither too rare nor too common, relative to the number of documents seen
    int thresholdRare   = (int)Math.Floor(tooRare * docCount);
    int thresholdCommon = (int)Math.Floor(tooCommon * docCount);

    var toKeep = trainingData.HashCount.Where(kv => kv.Value >= thresholdRare && kv.Value <= thresholdCommon)
                                       .OrderByDescending(kv => kv.Value)
                                       .Select(kv => kv.Key)
                                       .ToArray();

    foreach (var key in toKeep)
    {
        var hashes = trainingData.Senses[key];
        Data.Hashes.Add(key);

        // MultiGramHashes[i] holds the hashes that appear at position i of any kept n-gram
        for (int i = 0; i < hashes.Length; i++)
        {
            if (Data.MultiGramHashes.Count <= i)
            {
                Data.MultiGramHashes.Add(new HashSet<ulong>());
            }
            Data.MultiGramHashes[i].Add(hashes[i]);
        }
    }

    Logger.LogInformation("Finished training Word2Sense model");
}
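// --- Usage sketch (illustrative; not part of the original source) ---
// A minimal example of driving the method above. It assumes an instance of the
// containing class, here given the placeholder type name SenseModel, and any
// IEnumerable<IDocument> corpus; all names below are illustrative. Cancellation
// requested through ParallelOptions surfaces as an early return, because
// TrainWord2Sense catches OperationCanceledException itself.
public static void TrainExample(SenseModel model, IEnumerable<IDocument> corpus)
{
    using var cts = new CancellationTokenSource();

    var options = new ParallelOptions
    {
        MaxDegreeOfParallelism = Environment.ProcessorCount,
        CancellationToken = cts.Token
    };

    // Count candidate n-grams up to length 3, then keep those whose occurrence count
    // falls between the tooRare and tooCommon fractions of the number of documents
    // (the default thresholds, shown here explicitly)
    model.TrainWord2Sense(corpus, options, ngrams: 3, tooRare: 1E-5, tooCommon: 0.1);
}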
public void TrainWord2Sense(IEnumerable<IDocument> documents, ParallelOptions parallelOptions, int ngrams = 3, double tooRare = 1E-5, double tooCommon = 0.1, Word2SenseTrainingData trainingData = null)
{
    // Seed thread-safe dictionaries from any previous training state, so this call can
    // resume from (and later write back to) the supplied trainingData
    var hashCount     = new ConcurrentDictionary<ulong, int>(trainingData?.HashCount ?? new Dictionary<ulong, int>());
    var senses        = new ConcurrentDictionary<ulong, ulong[]>(trainingData?.Senses ?? new Dictionary<ulong, ulong[]>());
    var words         = new ConcurrentDictionary<ulong, string>(trainingData?.Words ?? new Dictionary<ulong, string>());
    var shapes        = new ConcurrentDictionary<string, ulong>(trainingData?.Shapes ?? new Dictionary<string, ulong>());
    var shapeExamples = new ConcurrentDictionary<string, string[]>(trainingData?.ShapeExamples ?? new Dictionary<string, string[]>());

    long totalDocCount   = trainingData?.SeenDocuments ?? 0;
    long totalTokenCount = trainingData?.SeenTokens ?? 0;

    bool ignoreCase        = Data.IgnoreCase;
    bool ignoreOnlyNumeric = Data.IgnoreOnlyNumeric;

    var stopwords = new HashSet<ulong>(StopWords.Spacy.For(Language)
                                                .Select(w => ignoreCase ? IgnoreCaseHash64(w.AsSpan()) : Hash64(w.AsSpan()))
                                                .ToArray());

    // Character sets for the ordinal-number heuristic below, hoisted out of the
    // token loop so they are allocated once instead of once per token
    var digits             = new char[] { '1', '2', '3', '4', '5', '6', '7', '8', '9', '0' };
    var ordinalSuffixChars = new char[] { 't', 'h', 's', 'r', 'd' };
    var nonOrdinalLetters  = new char[] { 'a', 'b', 'c', 'e', 'f', 'g', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'u', 'v', 'w', 'x', 'y', 'z' };

    int docCount = 0, tkCount = 0;
    var sw = Stopwatch.StartNew();

    TrainLock.EnterWriteLock();
    try
    {
        Parallel.ForEach(documents, parallelOptions, doc =>
        {
            try
            {
                var stack = new Queue<ulong>(ngrams);

                if (doc.TokensCount < ngrams) { return; } // Ignore documents that are too small

                Interlocked.Add(ref tkCount, doc.TokensCount);

                foreach (var span in doc)
                {
                    var tokens = span.GetCapturedTokens().ToArray();

                    for (int i = 0; i < tokens.Length; i++)
                    {
                        var tk = tokens[i];

                        // Collect shape statistics (with up to 50 distinct example values per shape)
                        // for single tokens only - captured multi-token spans (Tokens) are skipped
                        if (!(tk is Tokens))
                        {
                            var shape = tk.ValueAsSpan.Shape(compact: false);
                            shapes.AddOrUpdate(shape, 1, (k, v) => v + 1);
                            shapeExamples.AddOrUpdate(shape, (k) => new[] { tk.Value }, (k, v) =>
                            {
                                if (v.Length < 50)
                                {
                                    v = v.Concat(new[] { tk.Value }).Distinct().ToArray();
                                }
                                return v;
                            });
                        }

                        var hash = ignoreCase ? IgnoreCaseHash64(tk.ValueAsSpan) : Hash64(tk.ValueAsSpan);

                        bool filterPartOfSpeech        = !(tk.POS == PartOfSpeech.ADJ || tk.POS == PartOfSpeech.NOUN);
                        bool skipIfHasUpperCase        = !ignoreCase && !tk.ValueAsSpan.IsAllLowerCase();
                        bool skipIfTooSmall            = tk.Length < 3;
                        bool skipIfNotAllLetterOrDigit = !tk.ValueAsSpan.IsAllLetterOrDigit();
                        bool skipIfStopWordOrEntity    = stopwords.Contains(hash) || tk.EntityTypes.Any();

                        // Heuristic for ordinal numbers (i.e. 1st, 2nd, 33rd, etc.): contains a digit
                        // and an ordinal-suffix letter, but no other letters
                        bool skipIfMaybeOrdinal = tk.ValueAsSpan.IndexOfAny(digits, 0) >= 0 &&
                                                  tk.ValueAsSpan.IndexOfAny(ordinalSuffixChars, 0) >= 0 &&
                                                  tk.ValueAsSpan.IndexOfAny(nonOrdinalLetters, 0) < 0;

                        bool skipIfOnlyNumeric = ignoreOnlyNumeric && !tk.ValueAsSpan.IsLetter();

                        // Only filter on part of speech if the language is known; for Language.Any
                        // there is no POS information to filter on
                        bool skipThisToken = (filterPartOfSpeech && Language != Language.Any) || skipIfHasUpperCase || skipIfTooSmall ||
                                             skipIfNotAllLetterOrDigit || skipIfStopWordOrEntity || skipIfMaybeOrdinal || skipIfOnlyNumeric;

                        if (skipThisToken)
                        {
                            // A skipped token breaks the current n-gram window
                            stack.Clear();
                            continue;
                        }

                        words.TryAdd(hash, ignoreCase ? tk.Value.ToLowerInvariant() : tk.Value);

                        stack.Enqueue(hash);

                        // Count every n-gram (length 2..stack.Count) that ends at the current token.
                        // AddOrUpdate keeps the count atomic across threads; a returned count of 1
                        // signals a new n-gram, whose token hashes are then recorded in senses
                        ulong combined = stack.ElementAt(0);
                        for (int j = 1; j < stack.Count; j++)
                        {
                            combined = HashCombine64(combined, stack.ElementAt(j));
                            if (hashCount.AddOrUpdate(combined, 1, (k, v) => v + 1) == 1)
                            {
                                senses[combined] = stack.Take(j + 1).ToArray();
                            }
                        }

                        if (stack.Count > ngrams) { stack.Dequeue(); }
                    }
                }

                int count = Interlocked.Increment(ref docCount);

                if (count % 1000 == 0)
                {
                    var kTokensPerSecond = sw.ElapsedMilliseconds > 0 ? tkCount / sw.ElapsedMilliseconds : 0; // Guard against division by zero
                    Logger.LogInformation("Training Word2Sense model - at {DOCCOUNT} documents, {TKCOUNT} tokens - elapsed {ELAPSED} seconds at {KTKS} kTk/s",
                                          docCount, tkCount, sw.Elapsed.TotalSeconds, kTokensPerSecond);
                }
            }
            catch (Exception E)
            {
                Logger.LogError(E, "Error during training Word2Sense model");
            }
        });
    }
    catch (OperationCanceledException)
    {
        return;
    }
    finally
    {
        TrainLock.ExitWriteLock();
    }

    Logger.LogInformation("Finished parsing documents for Word2Sense model");

    totalDocCount   += docCount;
    totalTokenCount += tkCount;

    // Keep only n-grams that are neither too rare (and seen at least twice) nor too common,
    // relative to the total number of tokens seen across all training calls
    int thresholdRare   = Math.Max(2, (int)Math.Floor(tooRare * totalTokenCount));
    int thresholdCommon = (int)Math.Floor(tooCommon * totalTokenCount);

    var toKeep = hashCount.Where(kv => kv.Value >= thresholdRare && kv.Value <= thresholdCommon)
                          .OrderByDescending(kv => kv.Value)
                          .Select(kv => kv.Key)
                          .ToArray();

    foreach (var key in toKeep)
    {
        if (senses.TryGetValue(key, out var hashes))
        {
            Data.Hashes.Add(key);

            // MultiGramHashes[i] holds the hashes that appear at position i of any kept n-gram
            for (int i = 0; i < hashes.Length; i++)
            {
                if (Data.MultiGramHashes.Count <= i)
                {
                    Data.MultiGramHashes.Add(new HashSet<ulong>());
                }
                Data.MultiGramHashes[i].Add(hashes[i]);
            }
        }
    }

    // Write the accumulated state back so the caller can continue training from it later
    if (trainingData is object)
    {
        trainingData.HashCount     = new Dictionary<ulong, int>(hashCount);
        trainingData.Senses        = new Dictionary<ulong, ulong[]>(senses);
        trainingData.Words         = new Dictionary<ulong, string>(words);
        trainingData.SeenDocuments = totalDocCount;
        trainingData.SeenTokens    = totalTokenCount;
        trainingData.Shapes        = new Dictionary<string, ulong>(shapes);
        trainingData.ShapeExamples = new Dictionary<string, string[]>(shapeExamples);
    }

    Logger.LogInformation("Finished training Word2Sense model");
}
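// --- Incremental-training sketch (illustrative; not part of the original source) ---
// The revised method above writes its accumulated counts back into the supplied
// Word2SenseTrainingData, so training can be resumed batch by batch. SenseModel,
// model and the batch parameters are placeholder names.
public static void TrainIncrementally(SenseModel model, IEnumerable<IDocument> firstBatch, IEnumerable<IDocument> secondBatch)
{
    var options = new ParallelOptions { MaxDegreeOfParallelism = Environment.ProcessorCount };
    var state = new Word2SenseTrainingData();

    // First pass: state.SeenDocuments and state.SeenTokens start at zero and are
    // updated in place once the call completes
    model.TrainWord2Sense(firstBatch, options, trainingData: state);

    // Second pass: the hash counts, senses, words and shape statistics from the
    // first pass seed the ConcurrentDictionary instances, and the rare/common
    // thresholds are recomputed over the combined token count
    model.TrainWord2Sense(secondBatch, options, trainingData: state);
}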