/// <summary>
/// Shuffles sentence pairs in place while keeping pairs that share the same
/// source-side length adjacent: pairs are bucketed by source length, shuffled
/// inside each bucket, regrouped by bucket size, and the groups are emitted in
/// random order.
/// </summary>
/// <param name="rawSntPairs">Sentence pairs to shuffle; mutated in place.</param>
private void Shuffle(List<RawSntPair> rawSntPairs)
{
    // Put sentence pairs with the same source length into the same bucket.
    // <source sentence length, sentence pair set>
    Dictionary<int, List<RawSntPair>> dict = new Dictionary<int, List<RawSntPair>>();
    foreach (RawSntPair item in rawSntPairs)
    {
        int length = item.SrcLength;
        // TryGetValue avoids the ContainsKey + indexer double lookup.
        if (dict.TryGetValue(length, out List<RawSntPair> bucket) == false)
        {
            bucket = new List<RawSntPair>();
            dict.Add(length, bucket);
        }
        bucket.Add(item);
    }

    // Randomize the order of sentence pairs inside each bucket, in parallel.
    // Unbiased Fisher-Yates: idx is drawn from [i, Count). The original drew
    // from [0, Count) for every i, which yields a non-uniform permutation.
    Parallel.ForEach(dict, pair =>
    {
        // Per-bucket RNG; pair.Key keeps seeds distinct across parallel buckets.
        Random rnd2 = new Random(DateTime.Now.Millisecond + pair.Key);
        List<RawSntPair> sntPairList = pair.Value;
        for (int i = 0; i < sntPairList.Count - 1; i++)
        {
            int idx = rnd2.Next(i, sntPairList.Count);
            RawSntPair tmp = sntPairList[i];
            sntPairList[i] = sntPairList[idx];
            sntPairList[idx] = tmp;
        }
    });

    // Regroup buckets by their size. <bucket size, sentence pair set>
    SortedDictionary<int, List<RawSntPair>> sdict = new SortedDictionary<int, List<RawSntPair>>();
    foreach (KeyValuePair<int, List<RawSntPair>> pair in dict)
    {
        if (sdict.TryGetValue(pair.Value.Count, out List<RawSntPair> group) == false)
        {
            group = new List<RawSntPair>();
            sdict.Add(pair.Value.Count, group);
        }
        group.AddRange(pair.Value);
    }

    // Rebuild the input list: shuffle the size keys (unbiased, using the
    // shared instance RNG), then emit each group in that random order.
    rawSntPairs.Clear();
    int[] keys = sdict.Keys.ToArray();
    for (int i = 0; i < keys.Length - 1; i++)
    {
        int idx = rnd.Next(i, keys.Length);
        int tmp = keys[i];
        keys[i] = keys[idx];
        keys[idx] = tmp;
    }
    foreach (int key in keys)
    {
        rawSntPairs.AddRange(sdict[key]);
    }
}
/// <summary>
/// Shuffles the whole parallel corpus block by block into a pair of temporary
/// files and returns their paths. Pairs whose source side exceeds
/// m_maxSentLength tokens are dropped (and counted), and token-length
/// histograms are collected for optional reporting.
/// </summary>
/// <returns>Paths of the shuffled source and target temp files.</returns>
private (string, string) ShuffleAll()
{
    // Token-length histograms, bucketed by 100 tokens. <length / 100, count>
    SortedDictionary<int, int> dictSrcLenDist = new SortedDictionary<int, int>();
    SortedDictionary<int, int> dictTgtLenDist = new SortedDictionary<int, int>();

    string srcShuffledFilePath = Path.Combine(Directory.GetCurrentDirectory(), Path.GetRandomFileName() + ".tmp");
    string tgtShuffledFilePath = Path.Combine(Directory.GetCurrentDirectory(), Path.GetRandomFileName() + ".tmp");

    Logger.WriteLine($"Shuffling corpus for '{m_srcFileList.Count}' files.");

    CorpusSize = 0;
    int tooLongSrcSntCnt = 0;

    // using guarantees the temp-file writers (and corpus readers below) are
    // closed and flushed even if reading throws — the original leaked them
    // on any exception path.
    using (StreamWriter swSrc = new StreamWriter(srcShuffledFilePath, false))
    using (StreamWriter swTgt = new StreamWriter(tgtShuffledFilePath, false))
    {
        List<RawSntPair> sntPairs = new List<RawSntPair>();
        for (int i = 0; i < m_srcFileList.Count; i++)
        {
            if (m_showTokenDist)
            {
                Logger.WriteLine($"Process file '{m_srcFileList[i]}' and '{m_tgtFileList[i]}'");
            }

            using (StreamReader srSrc = new StreamReader(m_srcFileList[i]))
            using (StreamReader srTgt = new StreamReader(m_tgtFileList[i]))
            {
                while (true)
                {
                    if (srSrc.EndOfStream && srTgt.EndOfStream)
                    {
                        break;
                    }

                    RawSntPair rawSntPair = new RawSntPair(srSrc.ReadLine(), srTgt.ReadLine());
                    if (rawSntPair.IsEmptyPair())
                    {
                        break;
                    }

                    // TryGetValue avoids the ContainsKey + indexer double lookup;
                    // a missing key reads as 0.
                    dictSrcLenDist.TryGetValue(rawSntPair.SrcLength / 100, out int srcCnt);
                    dictSrcLenDist[rawSntPair.SrcLength / 100] = srcCnt + 1;
                    dictTgtLenDist.TryGetValue(rawSntPair.TgtLength / 100, out int tgtCnt);
                    dictTgtLenDist[rawSntPair.TgtLength / 100] = tgtCnt + 1;

                    // Drop pairs whose source side is too long (still counted in
                    // the histograms above, matching the original behavior).
                    if (rawSntPair.SrcLength > m_maxSentLength)
                    {
                        tooLongSrcSntCnt++;
                        continue;
                    }

                    sntPairs.Add(rawSntPair);
                    CorpusSize++;
                    if (m_blockSize > 0 && sntPairs.Count >= m_blockSize)
                    {
                        WriteShuffledBlock(sntPairs, swSrc, swTgt);
                    }
                }
            }
        }

        // Flush the final partial block.
        if (sntPairs.Count > 0)
        {
            WriteShuffledBlock(sntPairs, swSrc, swTgt);
        }
    }

    Logger.WriteLine($"Shuffled '{CorpusSize}' sentence pairs to file '{srcShuffledFilePath}' and '{tgtShuffledFilePath}'.");
    if (tooLongSrcSntCnt > 0)
    {
        Logger.WriteLine(Logger.Level.warn, ConsoleColor.Yellow, $"Found {tooLongSrcSntCnt} source sentences are longer than '{m_maxSentLength}' tokens, ignore them.");
    }

    if (m_showTokenDist)
    {
        ReportTokenLengthDistribution("Src", dictSrcLenDist);
        ReportTokenLengthDistribution("Tgt", dictTgtLenDist);
        // Only report the distribution once per run.
        m_showTokenDist = false;
    }

    return (srcShuffledFilePath, tgtShuffledFilePath);
}

// Shuffles one block of sentence pairs, appends them to the temp writers,
// and clears the block for reuse.
private void WriteShuffledBlock(List<RawSntPair> sntPairs, StreamWriter swSrc, StreamWriter swTgt)
{
    Shuffle(sntPairs);
    foreach (RawSntPair item in sntPairs)
    {
        swSrc.WriteLine(item.SrcSnt);
        swTgt.WriteLine(item.TgtSnt);
    }
    sntPairs.Clear();
}

// Logs a token-length histogram (buckets of 100 tokens) with the accumulated
// percentage per bucket. side is "Src" or "Tgt".
private static void ReportTokenLengthDistribution(string side, SortedDictionary<int, int> dist)
{
    Logger.WriteLine($"{side} token length distribution");
    int totalNum = 0;
    foreach (var pair in dist)
    {
        totalNum += pair.Value;
    }
    int accNum = 0;
    foreach (var pair in dist)
    {
        accNum += pair.Value;
        Logger.WriteLine($"{pair.Key * 100} ~ {(pair.Key + 1) * 100}: {pair.Value} (acc: {(100.0f * (float)accNum / (float)totalNum).ToString("F")}%)");
    }
}
/// <summary>
/// Shuffles sentence pairs in place. In Random mode the whole list is
/// permuted uniformly. Otherwise pairs are bucketed by source or target
/// length (per m_shuffleEnums), shuffled inside each bucket, split into
/// batch-sized chunks, and the chunks are emitted in random order so each
/// mini-batch contains equal-length sentences and needs no padding.
/// </summary>
/// <param name="rawSntPairs">Sentence pairs to shuffle; mutated in place.</param>
private void Shuffle(List<RawSntPair> rawSntPairs)
{
    if (m_shuffleEnums == ShuffleEnums.Random)
    {
        // Unbiased Fisher-Yates: idx drawn from [i, Count). The original drew
        // from [0, Count) for every i, which yields a non-uniform permutation.
        for (int i = 0; i < rawSntPairs.Count - 1; i++)
        {
            int idx = rnd.Next(i, rawSntPairs.Count);
            RawSntPair tmp = rawSntPairs[i];
            rawSntPairs[i] = rawSntPairs[idx];
            rawSntPairs[idx] = tmp;
        }
        return;
    }

    // Put sentence pairs with the same length (on the side to be padded) into
    // the same bucket. <sentence length, sentence pair set>
    Dictionary<int, List<RawSntPair>> dict = new Dictionary<int, List<RawSntPair>>();
    foreach (RawSntPair item in rawSntPairs)
    {
        int length = m_shuffleEnums == ShuffleEnums.NoPaddingInSrc ? item.SrcLength : item.TgtLength;
        // TryGetValue avoids the ContainsKey + indexer double lookup.
        if (dict.TryGetValue(length, out List<RawSntPair> bucket) == false)
        {
            bucket = new List<RawSntPair>();
            dict.Add(length, bucket);
        }
        bucket.Add(item);
    }

    // Randomize the order of sentence pairs inside each bucket, in parallel,
    // using an unbiased Fisher-Yates shuffle.
    Parallel.ForEach(dict, pair =>
    {
        // Per-bucket RNG; pair.Key keeps seeds distinct across parallel buckets.
        Random rnd2 = new Random(DateTime.Now.Millisecond + pair.Key);
        List<RawSntPair> sntPairList = pair.Value;
        for (int i = 0; i < sntPairList.Count - 1; i++)
        {
            int idx = rnd2.Next(i, sntPairList.Count);
            RawSntPair tmp = sntPairList[i];
            sntPairList[i] = sntPairList[idx];
            sntPairList[idx] = tmp;
        }
    });

    // Split buckets larger than one batch into batch-sized chunks. Each chunk
    // gets a fresh key from a running counter: the original used
    // pair.Key + 10000 * i, which could collide with a real length key >= 10000
    // and make Add throw. The key values only drive the random ordering below,
    // so a counter is behavior-safe.
    Dictionary<int, List<RawSntPair>> dictSB = new Dictionary<int, List<RawSntPair>>();
    int chunkKey = 0;
    foreach (var pair in dict)
    {
        if (pair.Value.Count <= m_batchSize)
        {
            dictSB.Add(chunkKey++, pair.Value);
        }
        else
        {
            int N = pair.Value.Count / m_batchSize;
            for (int i = 0; i < N; i++)
            {
                dictSB.Add(chunkKey++, pair.Value.GetRange(i * m_batchSize, m_batchSize));
            }
            // Remaining partial chunk, if the bucket size is not a multiple
            // of the batch size.
            if (pair.Value.Count % m_batchSize != 0)
            {
                dictSB.Add(chunkKey++, pair.Value.GetRange(m_batchSize * N, pair.Value.Count % m_batchSize));
            }
        }
    }

    // Rebuild the input list: shuffle the chunk keys (unbiased, using the
    // shared instance RNG), then emit the chunks in that random order.
    rawSntPairs.Clear();
    int[] keys = dictSB.Keys.ToArray();
    for (int i = 0; i < keys.Length - 1; i++)
    {
        int idx = rnd.Next(i, keys.Length);
        int tmp = keys[i];
        keys[i] = keys[idx];
        keys[idx] = tmp;
    }
    foreach (int key in keys)
    {
        rawSntPairs.AddRange(dictSB[key]);
    }
}