/// <summary>
/// Counts every adjacent symbol pair over all words, weighting each occurrence by the frequency of the
/// word it occurs in, and records which words contain each pair.
/// </summary>
/// <typeparam name="TList">Symbol-sequence type of a word.</typeparam>
/// <param name="sortedWordCounts">
/// Words with their frequencies. A word's position in this array is the word index used as the key of
/// the inner dictionaries of <paramref name="indices"/>.
/// </param>
/// <param name="indices">
/// Receives, for each pair, a map from word index to the number of times the pair occurs in that word.
/// </param>
/// <returns>Map from each pair to the sum, over its occurrences, of the containing word's frequency.</returns>
/// <remarks>Words with fewer than two symbols (including empty words) contribute no pairs.</remarks>
static Dictionary<TokenPair<T>, long> GetPairStatistics<TList>(KeyValuePair<TList, long>[] sortedWordCounts, out Dictionary<TokenPair<T>, Dictionary<int, long>> indices)
    where TList : IList<T>
{
    var stats = new Dictionary<TokenPair<T>, long>();
    indices = new Dictionary<TokenPair<T>, Dictionary<int, long>>();
    for (int i = 0; i < sortedWordCounts.Length; i++)
    {
        var word = sortedWordCounts[i].Key;
        long wordFrequency = sortedWordCounts[i].Value;

        // Each adjacent pair is (word[k - 1], word[k]). Starting at k = 1 means an empty word simply
        // produces no iterations instead of throwing on word[0].
        for (int k = 1; k < word.Count; k++)
        {
            var pair = TokenPair.Create(word[k - 1], word[k]);

            stats.TryGetValue(pair, out long freq);
            stats[pair] = freq + wordFrequency;

            if (indices.TryGetValue(pair, out var index))
            {
                index.TryGetValue(i, out long indexCounter);
                index[i] = indexCounter + 1;
            }
            else
            {
                indices[pair] = new Dictionary<int, long> { [i] = 1 };
            }
        }
    }

    return stats;
}
/// <summary>
/// Stub for the merge step of the learner. Its result is consumed by UpdatePairStatistics as
/// (word index, new word, old word, frequency) tuples — presumably one per word rewritten by merging
/// <paramref name="pair"/>; confirm when implementing.
/// </summary>
/// <param name="pair">The pair being merged.</param>
/// <param name="sorted">Vocabulary words with their frequencies.</param>
/// <param name="indices">Per-pair map from word index to occurrence count.</param>
/// <exception cref="NotImplementedException">Always; this method is not a lazy iterator, so it throws as soon as it is called.</exception>
static IEnumerable<ValueTuple<int, T, T, int>> ReplacePair(TokenPair<T> pair, KeyValuePair<BytePairToken<T>, long>[] sorted, Dictionary<TokenPair<T>, Dictionary<int, long>> indices)
{
    // The two halves of the pair; unused until the merge is implemented.
    var (left, right) = pair;

    throw new NotImplementedException();
}
/// <summary>
/// Learns byte-pair merge operations from a vocabulary of words and their frequencies.
/// </summary>
/// <param name="vocabulary">Words (as symbol sequences) mapped to their frequencies.</param>
/// <param name="numSymbols">Maximum number of merge operations to produce.</param>
/// <param name="minFrequency">Lowest pair frequency that may still be merged.</param>
/// <returns>
/// A lazily evaluated sequence of merged pairs. Each pair is yielded before the statistics are updated
/// for it, so the work for a merge runs when the consumer requests the next item. The sequence is empty
/// when no word contains a pair of symbols.
/// </returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="vocabulary"/> is null. Thrown immediately by this call, not deferred to enumeration.
/// </exception>
/// <exception cref="ArgumentException">
/// Thrown during enumeration when the most frequent remaining pair has a frequency below
/// <paramref name="minFrequency"/>.
/// </exception>
public static IEnumerable<TokenPair<T>> Learn(IReadOnlyDictionary<BytePairToken<T>, long> vocabulary, int numSymbols, int minFrequency)
{
    // Validate outside the iterator: an iterator body does not start running until the first MoveNext,
    // so a null check inside it would be deferred to whoever enumerates the result.
    if (vocabulary == null)
    {
        throw new ArgumentNullException(nameof(vocabulary));
    }

    return LearnIterator();

    IEnumerable<TokenPair<T>> LearnIterator()
    {
        // Words in descending frequency order; a word's position in this array is its word index in `indices`.
        var sorted = vocabulary.OrderByDescending(token => token.Value).ToArray();
        var stats = GetPairStatistics(sorted, out var indices);
        if (stats.Count == 0)
        {
            // No word has two symbols, so there is nothing to merge (and Max() below would throw
            // on an empty sequence).
            yield break;
        }

        // bigStats holds the full statistics; `stats` is the working copy trimmed by PruneStats
        // (NOTE(review): assumes PruneStats keeps pruned pairs in bigStats — confirm).
        var bigStats = Copy(stats);
        double threshold = stats.Values.Max() / 10.0;
        for (int symbolIndex = 0; symbolIndex < numSymbols; symbolIndex++)
        {
            TokenPair<T> mostFrequent = default;
            if (stats.Count != 0)
            {
                mostFrequent = MostFrequent(stats);
            }

            // Working statistics are empty, or (after the first merge) their best pair fell below the
            // threshold: restore them from bigStats and recompute the threshold.
            if (stats.Count == 0 || (symbolIndex > 0 && stats[mostFrequent] < threshold))
            {
                PruneStats(stats, bigStats, threshold);
                stats = Copy(bigStats);
                mostFrequent = MostFrequent(stats);
                threshold = checked(stats[mostFrequent] * symbolIndex) / (symbolIndex + 10000.0);
                PruneStats(stats, bigStats, threshold);
            }

            if (stats[mostFrequent] < minFrequency)
            {
                throw new ArgumentException("Inconsistent input: no pair has required frequency");
            }

            yield return mostFrequent;

            var changes = ReplacePair(mostFrequent, sorted, indices);
            UpdatePairStatistics(mostFrequent, changes, stats, indices);
            stats[mostFrequent] = 0;

            // Prune the working statistics every 100 merges.
            if (symbolIndex % 100 == 99)
            {
                PruneStats(stats, bigStats, threshold);
            }
        }
    }
}
/// <summary>
/// Value equality: both tokens must match under <see cref="EqualityComparer{T}.Default"/>.
/// </summary>
/// <param name="other">The pair to compare with.</param>
/// <returns><c>true</c> when Token1 and Token2 are both equal; Token2 is only compared if Token1 matches.</returns>
public bool Equals(TokenPair<T> other)
{
    var comparer = EqualityComparer<T>.Default;
    if (!comparer.Equals(this.Token1, other.Token1))
    {
        return false;
    }

    return comparer.Equals(this.Token2, other.Token2);
}
/// <summary>
/// Incrementally updates pair frequencies and per-word pair counts after <paramref name="pair"/> has been
/// merged. Only pairs adjacent to occurrences of the merged pair are touched.
/// </summary>
/// <param name="pair">The pair that was merged; its frequency and index are reset.</param>
/// <param name="changed">
/// One entry per rewritten word: (word index, word after the merge, word before the merge, word frequency).
/// </param>
/// <param name="stats">Pair frequencies, updated in place.</param>
/// <param name="indices">Per-pair map from word index to occurrence count, updated in place.</param>
/// <remarks>
/// Pairs missing from <paramref name="stats"/> or <paramref name="indices"/> are treated as count 0 and
/// created on first update. Pairs containing the merged symbol never exist before the merge, and entries
/// may be absent from <paramref name="stats"/> or reset in <paramref name="indices"/>. Plain indexer
/// updates would therefore throw <see cref="KeyNotFoundException"/>.
/// </remarks>
static void UpdatePairStatistics(TokenPair<T> pair, IEnumerable<ValueTuple<int, T, T, int>> changed, Dictionary<TokenPair<T>, long> stats, Dictionary<TokenPair<T>, Dictionary<int, long>> indices)
{
    stats[pair] = 0;
    indices[pair] = new Dictionary<int, long>();
    var (first, second) = pair;
    var newPair = first.Append(second);
    foreach (var (j, word, oldWord, freq) in changed)
    {
        // Remove the pairs that overlapped each occurrence of "first second" in the old word.
        int i = 0;
        while (true)
        {
            i = oldWord.IndexOf(first, startIndex: i);
            if (i < 0)
            {
                break;
            }

            if (i < oldWord.Count - 1 && oldWord[i + 1].Equals(second))
            {
                // assuming a symbol sequence "A B C", if "B C" is merged, reduce the frequency of "A B".
                if (i > 0)
                {
                    Adjust(oldWord.GetPair(i - 1), j, freq, -1);
                }

                if (i < oldWord.Count - 2)
                {
                    //assuming a symbol sequence "A B C B", if "B C" is merged, reduce the frequency of "C B".
                    //however, skip this if the sequence is A B C B C, because the frequency of "C B" will be reduced by the previous code block
                    if (!oldWord[i + 2].Equals(first) || i >= oldWord.Count - 3 || !oldWord[i + 3].Equals(second))
                    {
                        Adjust(oldWord.GetPair(i + 1), j, freq, -1);
                    }
                }

                i += 2;
            }
            else
            {
                i++;
            }
        }

        // Add the pairs formed around each occurrence of the merged symbol in the new word.
        i = 0;
        while (true)
        {
            i = word.IndexOf(newPair, i);
            if (i < 0)
            {
                break;
            }

            if (i > 0)
            {
                Adjust(word.GetPair(i - 1), j, freq, 1);
            }

            // Skip when the next symbol is also the merged symbol: that pair was already counted as
            // the "previous" pair of the next occurrence.
            if (i < word.Count - 1 && !word[i + 1].Equals(newPair))
            {
                Adjust(word.GetPair(i), j, freq, 1);
            }

            i++;
        }
    }

    // Adds direction * frequency to the pair's total and direction to its count for word `wordIndex`,
    // treating missing entries (outer or inner) as zero.
    void Adjust(TokenPair<T> target, int wordIndex, long frequency, int direction)
    {
        stats.TryGetValue(target, out long total);
        stats[target] = total + direction * frequency;

        if (!indices.TryGetValue(target, out var perWord))
        {
            perWord = new Dictionary<int, long>();
            indices[target] = perWord;
        }

        perWord.TryGetValue(wordIndex, out long count);
        perWord[wordIndex] = count + direction;
    }
}