/// <summary>
/// Reorders <paramref name="sntPairs"/> in place so that sentence pairs whose source
/// sentences have the same length stay adjacent, while the order inside each group
/// and the order of the groups themselves are randomized.
/// </summary>
/// <param name="sntPairs">Sentence pairs to reorder; the list is cleared and refilled.</param>
void Shuffle(List<SntPair> sntPairs)
{
    // Put sentence pairs with the same source length into the same bucket.
    // <source sentence length, sentence pair set>
    Dictionary<int, List<SntPair>> dict = new Dictionary<int, List<SntPair>>();
    foreach (SntPair item in sntPairs)
    {
        int srcLen = item.SrcSnt.Length;
        if (dict.TryGetValue(srcLen, out List<SntPair> bucket) == false)
        {
            bucket = new List<SntPair>();
            dict.Add(srcLen, bucket);
        }

        bucket.Add(item);
    }

    // Randomize the order of sentence pairs with the same source length.
    // Fisher-Yates: position i is swapped with a position drawn from [0, i], which
    // makes every permutation equally likely (drawing from the full range is biased).
    foreach (KeyValuePair<int, List<SntPair>> pair in dict)
    {
        List<SntPair> sntPairList = pair.Value;
        for (int i = sntPairList.Count - 1; i > 0; i--)
        {
            int idx = rnd.Next(0, i + 1);
            SntPair tmp = sntPairList[i];
            sntPairList[i] = sntPairList[idx];
            sntPairList[idx] = tmp;
        }
    }

    // Merge buckets that contain the same number of sentence pairs.
    // <bucket size, sentence pair set>
    SortedDictionary<int, List<SntPair>> sdict = new SortedDictionary<int, List<SntPair>>();
    foreach (KeyValuePair<int, List<SntPair>> pair in dict)
    {
        int bucketSize = pair.Value.Count;
        if (sdict.TryGetValue(bucketSize, out List<SntPair> merged) == false)
        {
            merged = new List<SntPair>();
            sdict.Add(bucketSize, merged);
        }

        merged.AddRange(pair.Value);
    }

    sntPairs.Clear();

    // Randomize the order in which the merged buckets are written back.
    int[] keys = sdict.Keys.ToArray();
    for (int i = keys.Length - 1; i > 0; i--)
    {
        int idx = rnd.Next(0, i + 1);
        int tmp = keys[i];
        keys[i] = keys[idx];
        keys[idx] = tmp;
    }

    foreach (int key in keys)
    {
        sntPairs.AddRange(sdict[key]);
    }
}
/// <summary>
/// Shuffles the given sentence pairs in place and returns the length of the longest source sentence.
/// </summary>
/// <param name="sntPairs">Sentence pairs to shuffle; the list is reordered in place.</param>
/// <returns>The largest <c>SrcSnt.Length</c> among the pairs, or 0 when the list is empty.</returns>
private int InnerShuffle(List<SntPair> sntPairs)
{
    // Compute the maximum in its own pass. Doing it while elements are being swapped
    // can skip a pair that is moved into an already-visited position.
    int maxSrcLen = 0;
    foreach (SntPair pair in sntPairs)
    {
        if (pair.SrcSnt.Length > maxSrcLen)
        {
            maxSrcLen = pair.SrcSnt.Length;
        }
    }

    // Fisher-Yates shuffle: swap position i with a position drawn from [0, i].
    for (int i = sntPairs.Count - 1; i > 0; i--)
    {
        int idx = rnd.Next(0, i + 1);
        SntPair tmp = sntPairs[i];
        sntPairs[i] = sntPairs[idx];
        sntPairs[idx] = tmp;
    }

    return maxSrcLen;
}
/// <summary>
/// Shuffles the corpus into two temporary files via <c>ShuffleAll</c>, then reads them back
/// line by line and yields batches of at most <c>m_batchSize</c> sentence pairs.
/// Consecutive pairs are collected until the source sentence length changes (or the
/// buffer exceeds <c>m_batchSize * 10000</c> pairs); each collected group is shuffled
/// with <c>InnerShuffle</c> before being split into batches.
/// </summary>
/// <returns>An enumerator over the generated batches.</returns>
/// <exception cref="InvalidDataException">
/// The shuffled target file has fewer lines than the shuffled source file.
/// </exception>
public IEnumerator<SntPairBatch> GetEnumerator()
{
    (string srcShuffledFilePath, string tgtShuffledFilePath) = ShuffleAll();

    // try/finally so the temporary files are removed even when the consumer stops
    // enumerating early (Dispose) or an exception is thrown while reading.
    try
    {
        using (StreamReader srSrc = new StreamReader(srcShuffledFilePath))
        using (StreamReader srTgt = new StreamReader(tgtShuffledFilePath))
        {
            int lastSrcSntLen = -1;
            int maxOutputsSize = m_batchSize * 10000;
            List<SntPair> outputs = new List<SntPair>();

            string srcLine;
            while ((srcLine = srSrc.ReadLine()) != null)
            {
                string tgtLine = srTgt.ReadLine();
                if (tgtLine == null)
                {
                    throw new InvalidDataException($"Shuffled target file '{tgtShuffledFilePath}' has fewer lines than shuffled source file '{srcShuffledFilePath}'.");
                }

                SntPair sntPair = new SntPair();

                srcLine = srcLine.ToLower().Trim();
                if (m_addBOSEOS)
                {
                    srcLine = $"{BOS} {srcLine} {EOS}";
                }
                sntPair.SrcSnt = srcLine.Split(' ');

                // Only EOS is appended on the target side.
                tgtLine = tgtLine.ToLower().Trim();
                if (m_addBOSEOS)
                {
                    tgtLine = $"{tgtLine} {EOS}";
                }
                sntPair.TgtSnt = tgtLine.Split(' ');

                // Flush the buffered group once the source length changes or the buffer is too large.
                if ((lastSrcSntLen > 0 && lastSrcSntLen != sntPair.SrcSnt.Length) || outputs.Count > maxOutputsSize)
                {
                    InnerShuffle(outputs);
                    for (int i = 0; i < outputs.Count; i += m_batchSize)
                    {
                        int size = Math.Min(m_batchSize, outputs.Count - i);
                        yield return new SntPairBatch(outputs.GetRange(i, size));
                    }

                    outputs.Clear();
                }

                outputs.Add(sntPair);
                lastSrcSntLen = sntPair.SrcSnt.Length;
            }

            // Emit whatever is left in the buffer after the last line.
            InnerShuffle(outputs);
            for (int i = 0; i < outputs.Count; i += m_batchSize)
            {
                int size = Math.Min(m_batchSize, outputs.Count - i);
                yield return new SntPairBatch(outputs.GetRange(i, size));
            }
        }
    }
    finally
    {
        File.Delete(srcShuffledFilePath);
        File.Delete(tgtShuffledFilePath);
    }
}
/// <summary>
/// Reads every source/target file pair in <c>m_srcFileList</c> / <c>m_tgtFileList</c>,
/// drops pairs where either side has <c>m_maxSentLength</c> or more tokens, shuffles the
/// remaining pairs with <c>Shuffle</c> (per block of <c>m_blockSize</c> pairs when
/// <c>m_blockSize</c> is positive, otherwise all at once) and writes them to two
/// randomly named files in the current directory. <c>CorpusSize</c> is reset and set to
/// the number of pairs written.
/// </summary>
/// <returns>The paths of the shuffled source file and the shuffled target file.</returns>
/// <exception cref="InvalidDataException">
/// A target file has fewer lines than its corresponding source file.
/// </exception>
private (string, string) ShuffleAll()
{
    string srcShuffledFilePath = Path.Combine(Directory.GetCurrentDirectory(), Path.GetRandomFileName());
    string tgtShuffledFilePath = Path.Combine(Directory.GetCurrentDirectory(), Path.GetRandomFileName());

    Logger.WriteLine("Shuffling corpus...");

    List<SntPair> sntPairs = new List<SntPair>();
    CorpusSize = 0;
    int tooLongSntCnt = 0;
    char[] separators = new char[] { ' ' };

    // using blocks guarantee that readers and writers are closed even if reading fails.
    using (StreamWriter swSrc = new StreamWriter(srcShuffledFilePath, false))
    using (StreamWriter swTgt = new StreamWriter(tgtShuffledFilePath, false))
    {
        // Shuffles the buffered pairs, appends them to both output files and empties the buffer.
        void FlushBlock()
        {
            Shuffle(sntPairs);
            foreach (SntPair item in sntPairs)
            {
                swSrc.WriteLine(String.Join(" ", item.SrcSnt));
                swTgt.WriteLine(String.Join(" ", item.TgtSnt));
            }

            sntPairs.Clear();
        }

        for (int i = 0; i < m_srcFileList.Count; i++)
        {
            using (StreamReader srSrc = new StreamReader(m_srcFileList[i]))
            using (StreamReader srTgt = new StreamReader(m_tgtFileList[i]))
            {
                string srcLine;
                while ((srcLine = srSrc.ReadLine()) != null)
                {
                    string tgtLine = srTgt.ReadLine();
                    if (tgtLine == null)
                    {
                        throw new InvalidDataException($"Target file '{m_tgtFileList[i]}' has fewer lines than source file '{m_srcFileList[i]}'.");
                    }

                    SntPair sntPair = new SntPair();
                    sntPair.SrcSnt = srcLine.Split(separators, StringSplitOptions.RemoveEmptyEntries);
                    sntPair.TgtSnt = tgtLine.Split(separators, StringSplitOptions.RemoveEmptyEntries);

                    // Skip pairs where either side reaches the maximum sentence length.
                    if (sntPair.SrcSnt.Length >= m_maxSentLength || sntPair.TgtSnt.Length >= m_maxSentLength)
                    {
                        tooLongSntCnt++;
                        continue;
                    }

                    sntPairs.Add(sntPair);
                    CorpusSize++;

                    if (m_blockSize > 0 && sntPairs.Count >= m_blockSize)
                    {
                        FlushBlock();
                    }
                }
            }
        }

        // Write out the final, possibly partial, block.
        if (sntPairs.Count > 0)
        {
            FlushBlock();
        }
    }

    Logger.WriteLine($"Shuffled '{CorpusSize}' sentence pairs to file '{srcShuffledFilePath}' and '{tgtShuffledFilePath}'.");

    if (tooLongSntCnt > 0)
    {
        Logger.WriteLine(Logger.Level.warn, ConsoleColor.Yellow, $"Found {tooLongSntCnt} sentences are longer than '{m_maxSentLength}' tokens, ignore them.");
    }

    return (srcShuffledFilePath, tgtShuffledFilePath);
}
/// <summary>
/// Shuffles the corpus into two temporary files via <c>ShuffleAll</c>, then reads them back
/// line by line and yields batches of at most <c>m_batchSize</c> sentence pairs.
/// Consecutive pairs are buffered until the token-group lengths selected by
/// <c>m_shuffleEnums</c> (target, source, or both) differ from the previous pair, or the
/// buffer exceeds <c>m_batchSize * 10000</c> pairs; the buffer is then split into batches.
/// </summary>
/// <returns>An enumerator over the generated batches.</returns>
/// <exception cref="InvalidDataException">
/// The shuffled target file has fewer lines than the shuffled source file.
/// </exception>
public IEnumerator<T> GetEnumerator()
{
    (string srcShuffledFilePath, string tgtShuffledFilePath) = ShuffleAll();

    // try/finally so the temporary files are removed even when the consumer stops
    // enumerating early (Dispose) or an exception is thrown while reading.
    try
    {
        using StreamReader srSrc = new StreamReader(srcShuffledFilePath);
        using StreamReader srTgt = new StreamReader(tgtShuffledFilePath);

        // Per-token-group lengths of the previously read pair; allocated on the first pair.
        int[] lastSrcSntLen = null;
        int[] lastTgtSntLen = null;
        int maxOutputsSize = m_batchSize * 10000;
        List<SntPair> outputs = new List<SntPair>();

        string line;
        while ((line = srSrc.ReadLine()) != null)
        {
            string rawTgtLine = srTgt.ReadLine();
            if (rawTgtLine == null)
            {
                throw new InvalidDataException($"Shuffled target file '{tgtShuffledFilePath}' has fewer lines than shuffled source file '{srcShuffledFilePath}'.");
            }

            string srcLine = line.Trim();
            string tgtLine = rawTgtLine.Trim();
            SntPair sntPair = new SntPair(srcLine, tgtLine);

            if (lastSrcSntLen == null)
            {
                // -1 marks "no previous pair yet" for every token group.
                lastSrcSntLen = new int[sntPair.SrcTokenGroups.Count];
                lastTgtSntLen = new int[sntPair.TgtTokenGroups.Count];

                for (int i = 0; i < lastSrcSntLen.Length; i++)
                {
                    lastSrcSntLen[i] = -1;
                }

                for (int i = 0; i < lastTgtSntLen.Length; i++)
                {
                    lastTgtSntLen[i] = -1;
                }
            }

            // Flush the buffer when the lengths relevant to the configured mode change,
            // or when the buffer has grown past its size limit.
            if ((lastTgtSntLen[0] > 0 && m_shuffleEnums == ShuffleEnums.NoPaddingInTgt && SameSntLen(sntPair.TgtTokenGroups, lastTgtSntLen) == false) ||
                (lastSrcSntLen[0] > 0 && m_shuffleEnums == ShuffleEnums.NoPaddingInSrc && SameSntLen(sntPair.SrcTokenGroups, lastSrcSntLen) == false) ||
                (lastSrcSntLen[0] > 0 && lastTgtSntLen[0] > 0 && m_shuffleEnums == ShuffleEnums.NoPadding && (SameSntLen(sntPair.TgtTokenGroups, lastTgtSntLen) == false || SameSntLen(sntPair.SrcTokenGroups, lastSrcSntLen) == false)) ||
                outputs.Count > maxOutputsSize)
            {
                // NOTE: no in-memory shuffle of the buffered group is performed here
                // (the InnerShuffle call is disabled), so pairs are batched in file order.
                for (int i = 0; i < outputs.Count; i += m_batchSize)
                {
                    int size = Math.Min(m_batchSize, outputs.Count - i);
                    var batch = new T();
                    batch.CreateBatch(outputs.GetRange(i, size));
                    yield return batch;
                }

                outputs.Clear();
            }

            outputs.Add(sntPair);

            // The length arrays are always allocated by this point.
            UpdateSntLen(sntPair.SrcTokenGroups, lastSrcSntLen);
            UpdateSntLen(sntPair.TgtTokenGroups, lastTgtSntLen);
        }

        // Emit whatever is left in the buffer after the last line (again in file order).
        for (int i = 0; i < outputs.Count; i += m_batchSize)
        {
            int size = Math.Min(m_batchSize, outputs.Count - i);
            var batch = new T();
            batch.CreateBatch(outputs.GetRange(i, size));
            yield return batch;
        }
    }
    finally
    {
        File.Delete(srcShuffledFilePath);
        File.Delete(tgtShuffledFilePath);
    }
}