Example #1
        public void TrainWord2Sense(IEnumerable<IDocument> documents, ParallelOptions parallelOptions, int ngrams = 3, double tooRare = 1E-5, double tooCommon = 0.1, Word2SenseTrainingData trainingData = null)
        {
            if (trainingData == null)
            {
                trainingData = new Word2SenseTrainingData();
            }

            // Precompute the hashed stop-word set for this language, respecting the model's casing mode
            var stopwords = new HashSet<ulong>(StopWords.Spacy.For(Language).Select(w => Data.IgnoreCase ? IgnoreCaseHash64(w.AsSpan()) : Hash64(w.AsSpan())).ToArray());

            int docCount = 0, tkCount = 0;

            var sw = Stopwatch.StartNew();

            TrainLock.EnterWriteLock();
            try
            {
                Parallel.ForEach(documents, parallelOptions, doc =>
                {
                    try
                    {
                        // Sliding FIFO window over the hashes of the most recent kept tokens
                        var Stack = new Queue<ulong>(ngrams);

                        // Ignore documents that are too small to form an n-gram
                        if (doc.TokensCount < ngrams)
                        {
                            return;
                        }

                        Interlocked.Add(ref tkCount, doc.TokensCount);
                        foreach (var span in doc)
                        {
                            var tokens = span.GetTokenized().ToArray();

                            for (int i = 0; i < tokens.Length; i++)
                            {
                                var tk = tokens[i];

                                var hash = Data.IgnoreCase ? IgnoreCaseHash64(tk.ValueAsSpan) : Hash64(tk.ValueAsSpan);

                                // Keep only adjectives, nouns and proper nouns as n-gram candidates
                                bool filterPartOfSpeech = !(tk.POS == PartOfSpeech.ADJ || tk.POS == PartOfSpeech.NOUN || tk.POS == PartOfSpeech.PROPN);

                                bool skipIfHasUpperCase = (!Data.IgnoreCase && !tk.ValueAsSpan.IsAllLowerCase());

                                bool skipIfTooSmall = (tk.Length < 3);

                                bool skipIfNotAllLetterOrDigit = !(tk.ValueAsSpan.IsAllLetterOrDigit());

                                bool skipIfStopWordOrEntity = stopwords.Contains(hash) || tk.EntityTypes.Any();

                                //Heuristic for ordinal numbers (e.g. 1st, 2nd, 33rd): a digit plus ordinal-suffix letters, and no other letters
                                bool skipIfMaybeOrdinal = (tk.ValueAsSpan.IndexOfAny(new char[] { '1', '2', '3', '4', '5', '6', '7', '8', '9', '0' }, 0) >= 0 &&
                                                           tk.ValueAsSpan.IndexOfAny(new char[] { 't', 'h', 's', 'r', 'd' }, 0) >= 0 &&
                                                           tk.ValueAsSpan.IndexOfAny(new char[] { 'a', 'b', 'c', 'e', 'f', 'g', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'u', 'v', 'w', 'x', 'y', 'z' }, 0) < 0);

                                bool skipThisToken = filterPartOfSpeech || skipIfHasUpperCase || skipIfTooSmall || skipIfNotAllLetterOrDigit || skipIfStopWordOrEntity || skipIfMaybeOrdinal;

                                if (skipThisToken)
                                {
                                    Stack.Clear();
                                    continue;
                                }

                                // NOTE: Words/HashCount/Senses on trainingData are plain dictionaries mutated
                                // from parallel worker threads here; unlike Example #2 below, this variant is
                                // only safe when parallelOptions limits execution to a single thread.
                                if (!trainingData.Words.ContainsKey(hash))
                                {
                                    trainingData.Words[hash] = Data.IgnoreCase ? tk.Value.ToLowerInvariant() : tk.Value;
                                }

                                Stack.Enqueue(hash);
                                ulong combined = Stack.ElementAt(0);

                                // Combine hashes to count every multi-token sequence ending at this token
                                for (int j = 1; j < Stack.Count; j++)
                                {
                                    combined = HashCombine64(combined, Stack.ElementAt(j));
                                    if (trainingData.HashCount.ContainsKey(combined))
                                    {
                                        trainingData.HashCount[combined]++;
                                    }
                                    else
                                    {
                                        trainingData.Senses[combined]    = Stack.Take(j + 1).ToArray();
                                        trainingData.HashCount[combined] = 1;
                                    }
                                }

                                // Trim the window back to `ngrams` entries; because this runs after the
                                // combine loop, sequences of up to (ngrams + 1) tokens are counted
                                if (Stack.Count > ngrams)
                                {
                                    Stack.Dequeue();
                                }
                            }
                        }

                        int count = Interlocked.Increment(ref docCount);

                        if (count % 1000 == 0)
                        {
                            var mem = GC.GetTotalMemory(false);
                            Logger.LogInformation("[MEM: {MEMORY} MB]  Training Word2Sense model - at {DOCCOUNT} documents, {TKCOUNT} tokens - elapsed {ELAPSED} seconds at {KTKS} kTk/s", Math.Round(mem / 1048576.0, 2), docCount, tkCount, sw.Elapsed.TotalSeconds, tkCount / Math.Max(1, sw.ElapsedMilliseconds));
                        }
                    }
                    catch (Exception E)
                    {
                        Logger.LogError(E, "Error during training Word2Sense model");
                    }
                });
            }
            catch (OperationCanceledException)
            {
                // Cancellation was requested via parallelOptions; stop without rethrowing
                return;
            }
            finally
            {
                TrainLock.ExitWriteLock();
            }

            Logger.LogInformation("Finish parsing documents for Word2Sense model");

            // Derive the rare/common cut-offs from the number of documents processed;
            // n-gram counts outside [thresholdRare, thresholdCommon] are discarded below
            int thresholdRare   = (int)Math.Floor(tooRare * docCount);
            int thresholdCommon = (int)Math.Floor(tooCommon * docCount);

            var toKeep = trainingData.HashCount.Where(kv => kv.Value >= thresholdRare && kv.Value <= thresholdCommon).OrderByDescending(kv => kv.Value)
                         .Select(kv => kv.Key).ToArray();

            foreach (var key in toKeep)
            {
                var hashes = trainingData.Senses[key];

                Data.Hashes.Add(key);

                // Record each constituent token hash at its position within the n-gram
                for (int i = 0; i < hashes.Length; i++)
                {
                    if (Data.MultiGramHashes.Count <= i)
                    {
                        Data.MultiGramHashes.Add(new HashSet<ulong>());
                    }
                    Data.MultiGramHashes[i].Add(hashes[i]);
                }
            }

            Logger.LogInformation("Finish training Word2Sense model");
        }
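A minimal usage sketch for the example above. The `spotter` instance, the `LoadDocuments()` helper and the variable names are illustrative assumptions, not part of the original source:

        // Hedged usage sketch - `spotter` stands for whatever model class exposes
        // TrainWord2Sense, and LoadDocuments() is a hypothetical corpus loader
        // yielding pre-tokenized IDocument instances.
        var parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = Environment.ProcessorCount };
        IEnumerable<IDocument> documents = LoadDocuments();

        // Passing no trainingData is fine here: the method allocates a fresh
        // Word2SenseTrainingData internally when the argument is null.
        spotter.TrainWord2Sense(documents, parallelOptions, ngrams: 3, tooRare: 1E-5, tooCommon: 0.1);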
Example #2
        public void TrainWord2Sense(IEnumerable<IDocument> documents, ParallelOptions parallelOptions, int ngrams = 3, double tooRare = 1E-5, double tooCommon = 0.1, Word2SenseTrainingData trainingData = null)
        {
            // Seed thread-safe working copies from any previously accumulated training data
            var hashCount     = new ConcurrentDictionary<ulong, int>(trainingData?.HashCount ?? new Dictionary<ulong, int>());
            var senses        = new ConcurrentDictionary<ulong, ulong[]>(trainingData?.Senses ?? new Dictionary<ulong, ulong[]>());
            var words         = new ConcurrentDictionary<ulong, string>(trainingData?.Words ?? new Dictionary<ulong, string>());
            var shapes        = new ConcurrentDictionary<string, ulong>(trainingData?.Shapes ?? new Dictionary<string, ulong>());
            var shapeExamples = new ConcurrentDictionary<string, string[]>(trainingData?.ShapeExamples ?? new Dictionary<string, string[]>());

            long totalDocCount   = trainingData?.SeenDocuments ?? 0;
            long totalTokenCount = trainingData?.SeenTokens ?? 0;

            bool ignoreCase        = Data.IgnoreCase;
            bool ignoreOnlyNumeric = Data.IgnoreOnlyNumeric;
            var  stopwords         = new HashSet<ulong>(StopWords.Spacy.For(Language).Select(w => ignoreCase ? IgnoreCaseHash64(w.AsSpan()) : Hash64(w.AsSpan())).ToArray());

            int docCount = 0, tkCount = 0;

            var sw = Stopwatch.StartNew();

            TrainLock.EnterWriteLock();
            try
            {
                Parallel.ForEach(documents, parallelOptions, doc =>
                {
                    try
                    {
                        // Sliding FIFO window over the hashes of the most recent kept tokens
                        var stack = new Queue<ulong>(ngrams);

                        // Ignore documents that are too small to form an n-gram
                        if (doc.TokensCount < ngrams)
                        {
                            return;
                        }

                        Interlocked.Add(ref tkCount, doc.TokensCount);

                        foreach (var span in doc)
                        {
                            var tokens = span.GetCapturedTokens().ToArray();

                            for (int i = 0; i < tokens.Length; i++)
                            {
                                var tk = tokens[i];

                                // Collect shape statistics for single tokens only (multi-token captures are skipped)
                                if (!(tk is Tokens))
                                {
                                    var shape = tk.ValueAsSpan.Shape(compact: false);
                                    shapes.AddOrUpdate(shape, 1, (k, v) => v + 1);

                                    shapeExamples.AddOrUpdate(shape, (k) => new[] { tk.Value }, (k, v) =>
                                    {
                                        if (v.Length < 50)
                                        {
                                            v = v.Concat(new[] { tk.Value }).Distinct().ToArray();
                                        }
                                        return v;
                                    });
                                }

                                var hash = ignoreCase ? IgnoreCaseHash64(tk.ValueAsSpan) : Hash64(tk.ValueAsSpan);

                                // Keep only adjectives and nouns as n-gram candidates (applied below only when the language is known)
                                bool filterPartOfSpeech = !(tk.POS == PartOfSpeech.ADJ || tk.POS == PartOfSpeech.NOUN);

                                bool skipIfHasUpperCase = (!ignoreCase && !tk.ValueAsSpan.IsAllLowerCase());

                                bool skipIfTooSmall = (tk.Length < 3);

                                bool skipIfNotAllLetterOrDigit = !(tk.ValueAsSpan.IsAllLetterOrDigit());

                                bool skipIfStopWordOrEntity = stopwords.Contains(hash) || tk.EntityTypes.Any();

                                //Heuristic for ordinal numbers (e.g. 1st, 2nd, 33rd): a digit plus ordinal-suffix letters, and no other letters
                                bool skipIfMaybeOrdinal = (tk.ValueAsSpan.IndexOfAny(new char[] { '1', '2', '3', '4', '5', '6', '7', '8', '9', '0' }, 0) >= 0 &&
                                                           tk.ValueAsSpan.IndexOfAny(new char[] { 't', 'h', 's', 'r', 'd' }, 0) >= 0 &&
                                                           tk.ValueAsSpan.IndexOfAny(new char[] { 'a', 'b', 'c', 'e', 'f', 'g', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'u', 'v', 'w', 'x', 'y', 'z' }, 0) < 0);

                                bool skipIfOnlyNumeric = ignoreOnlyNumeric && !tk.ValueAsSpan.IsLetter();

                                //Only filter for POS if language != any, as otherwise we won't have the POS information
                                bool skipThisToken = (filterPartOfSpeech && Language != Language.Any) || skipIfHasUpperCase || skipIfTooSmall || skipIfNotAllLetterOrDigit || skipIfStopWordOrEntity || skipIfMaybeOrdinal || skipIfOnlyNumeric;

                                if (skipThisToken)
                                {
                                    stack.Clear();
                                    continue;
                                }

                                if (!words.ContainsKey(hash))
                                {
                                    words[hash] = ignoreCase ? tk.Value.ToLowerInvariant() : tk.Value;
                                }

                                stack.Enqueue(hash);
                                ulong combined = stack.ElementAt(0);

                                // Combine hashes to count every multi-token sequence ending at this token
                                for (int j = 1; j < stack.Count; j++)
                                {
                                    combined = HashCombine64(combined, stack.ElementAt(j));
                                    if (hashCount.ContainsKey(combined))
                                    {
                                        hashCount[combined]++;
                                    }
                                    else
                                    {
                                        senses[combined]    = stack.Take(j + 1).ToArray();
                                        hashCount[combined] = 1;
                                    }
                                }

                                // Trim the window back to `ngrams` entries; because this runs after the
                                // combine loop, sequences of up to (ngrams + 1) tokens are counted
                                if (stack.Count > ngrams)
                                {
                                    stack.Dequeue();
                                }
                            }
                        }

                        int count = Interlocked.Increment(ref docCount);

                        if (count % 1000 == 0)
                        {
                            Logger.LogInformation("Training Word2Sense model - at {DOCCOUNT} documents, {TKCOUNT} tokens - elapsed {ELAPSED} seconds at {KTKS} kTk/s)", docCount, tkCount, sw.Elapsed.TotalSeconds, (tkCount / sw.ElapsedMilliseconds));
                        }
                    }
                    catch (Exception E)
                    {
                        Logger.LogError(E, "Error during training Word2Sense model");
                    }
                });
            }
            catch (OperationCanceledException)
            {
                // Cancellation was requested via parallelOptions; stop without rethrowing
                return;
            }
            finally
            {
                TrainLock.ExitWriteLock();
            }

            Logger.LogInformation("Finish parsing documents for Word2Sense model");

            // Fold this run's counts into the running totals, then derive the frequency cut-offs
            // from the cumulative token count (an n-gram needs at least 2 occurrences to be kept)
            totalDocCount   += docCount;
            totalTokenCount += tkCount;

            int thresholdRare   = Math.Max(2, (int)Math.Floor(tooRare * totalTokenCount));
            int thresholdCommon = (int)Math.Floor(tooCommon * totalTokenCount);

            var toKeep = hashCount.Where(kv => kv.Value >= thresholdRare && kv.Value <= thresholdCommon).OrderByDescending(kv => kv.Value)
                         .Select(kv => kv.Key).ToArray();

            foreach (var key in toKeep)
            {
                if (senses.TryGetValue(key, out var hashes) && hashCount.TryGetValue(key, out var count))
                {
                    Data.Hashes.Add(key);
                    for (int i = 0; i < hashes.Length; i++)
                    {
                        if (Data.MultiGramHashes.Count <= i)
                        {
                            Data.MultiGramHashes.Add(new HashSet<ulong>());
                        }
                        Data.MultiGramHashes[i].Add(hashes[i]);
                    }
                }
            }

            // Persist the accumulated state back into the caller-supplied training data, if any
            if (trainingData is object)
            {
                trainingData.HashCount     = new Dictionary<ulong, int>(hashCount);
                trainingData.Senses        = new Dictionary<ulong, ulong[]>(senses);
                trainingData.Words         = new Dictionary<ulong, string>(words);
                trainingData.SeenDocuments = totalDocCount;
                trainingData.SeenTokens    = totalTokenCount;
                trainingData.Shapes        = new Dictionary<string, ulong>(shapes);
                trainingData.ShapeExamples = new Dictionary<string, string[]>(shapeExamples);
            }

            Logger.LogInformation("Finish training Word2Sense model");
        }
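Because this second variant folds its counts back into the caller-supplied `Word2SenseTrainingData` (including `SeenDocuments` and `SeenTokens`), training can be resumed across corpus batches. A sketch under the same assumptions as above (`spotter`, `LoadBatch1()` and `LoadBatch2()` are hypothetical stand-ins):

        // Hedged sketch of incremental training across two batches of documents
        var parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = Environment.ProcessorCount };
        var trainingData    = new Word2SenseTrainingData();

        // First batch: counts accumulate into trainingData at the end of the call
        spotter.TrainWord2Sense(LoadBatch1(), parallelOptions, trainingData: trainingData);

        // Second batch: the same trainingData seeds the concurrent dictionaries, so the
        // rare/common thresholds are derived from the cumulative token count of both batches
        spotter.TrainWord2Sense(LoadBatch2(), parallelOptions, trainingData: trainingData);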