/// <summary>
/// Reorders <paramref name="sntPairs"/> in place so that pairs sharing the same source
/// sentence length stay contiguous, while the order inside each length bucket and the
/// order of the bucket groups are randomized.
/// </summary>
/// <param name="sntPairs">Sentence pairs to reorder; the list is cleared and refilled.</param>
void Shuffle(List <SntPair> sntPairs)
        {
            // Bucket the sentence pairs by source sentence length: <source length, pairs>.
            Dictionary <int, List <SntPair> > dict = new Dictionary <int, List <SntPair> >();

            foreach (SntPair item in sntPairs)
            {
                // Single lookup instead of ContainsKey followed by the indexer.
                if (dict.TryGetValue(item.SrcSnt.Length, out List <SntPair> bucket) == false)
                {
                    bucket = new List <SntPair>();
                    dict.Add(item.SrcSnt.Length, bucket);
                }
                bucket.Add(item);
            }

            // Randomize the order of the pairs inside every bucket with a Fisher-Yates shuffle.
            // Drawing the swap index from the whole list on every step (as the previous code did)
            // does not produce every permutation with equal probability.
            foreach (KeyValuePair <int, List <SntPair> > pair in dict)
            {
                var sntPairList = pair.Value;
                for (int i = sntPairList.Count - 1; i > 0; i--)
                {
                    int     idx = rnd.Next(0, i + 1);
                    SntPair tmp = sntPairList[i];
                    sntPairList[i]   = sntPairList[idx];
                    sntPairList[idx] = tmp;
                }
            }

            // Merge buckets that hold the same number of pairs: <bucket size, pairs>.
            SortedDictionary <int, List <SntPair> > sdict = new SortedDictionary <int, List <SntPair> >();

            foreach (KeyValuePair <int, List <SntPair> > pair in dict)
            {
                if (sdict.TryGetValue(pair.Value.Count, out List <SntPair> merged) == false)
                {
                    merged = new List <SntPair>();
                    sdict.Add(pair.Value.Count, merged);
                }
                merged.AddRange(pair.Value);
            }

            sntPairs.Clear();

            // Randomize the order in which the merged groups are written back (Fisher-Yates again).
            int[] keys = sdict.Keys.ToArray();
            for (int i = keys.Length - 1; i > 0; i--)
            {
                int idx = rnd.Next(0, i + 1);
                int tmp = keys[i];
                keys[i]   = keys[idx];
                keys[idx] = tmp;
            }

            foreach (var key in keys)
            {
                sntPairs.AddRange(sdict[key]);
            }
        }
        /// <summary>
        /// Shuffles the given sentence pairs in place and returns the length of the longest source sentence.
        /// </summary>
        /// <param name="sntPairs">Sentence pairs to shuffle; the list itself is reordered.</param>
        /// <returns>The largest <c>SrcSnt.Length</c> found in the list, or 0 when the list is empty.</returns>
        private int InnerShuffle(List <SntPair> sntPairs)
        {
            // Compute the maximum in its own pass. Measuring element i while also swapping it
            // (as the previous code did) skips any element that gets swapped into an
            // already-visited position, so the reported maximum could be too small.
            int maxSrcLen = 0;
            foreach (SntPair pair in sntPairs)
            {
                if (pair.SrcSnt.Length > maxSrcLen)
                {
                    maxSrcLen = pair.SrcSnt.Length;
                }
            }

            // Fisher-Yates shuffle: the swap partner is drawn only from the not-yet-fixed prefix.
            for (int i = sntPairs.Count - 1; i > 0; i--)
            {
                int     idx = rnd.Next(0, i + 1);
                SntPair tmp = sntPairs[i];
                sntPairs[i]   = sntPairs[idx];
                sntPairs[idx] = tmp;
            }

            return(maxSrcLen);
        }
        /// <summary>
        /// Shuffles the corpus into two temporary files via <c>ShuffleAll</c>, then streams them back
        /// as batches of at most <c>m_batchSize</c> sentence pairs. A run of consecutive pairs is
        /// flushed into batches whenever the source token count changes or the buffered run grows
        /// beyond <c>m_batchSize * 10000</c> pairs.
        /// </summary>
        /// <returns>An enumerator over the generated batches.</returns>
        /// <exception cref="InvalidDataException">
        /// The shuffled target file has fewer lines than the shuffled source file.
        /// </exception>
        public IEnumerator <SntPairBatch> GetEnumerator()
        {
            (string srcShuffledFilePath, string tgtShuffledFilePath) = ShuffleAll();

            try
            {
                using (StreamReader srSrc = new StreamReader(srcShuffledFilePath))
                using (StreamReader srTgt = new StreamReader(tgtShuffledFilePath))
                {
                    int            lastSrcSntLen  = -1;
                    int            maxOutputsSize = m_batchSize * 10000;
                    List <SntPair> outputs        = new List <SntPair>();

                    string srcLine;
                    while ((srcLine = srSrc.ReadLine()) != null)
                    {
                        string tgtLine = srTgt.ReadLine();
                        if (tgtLine == null)
                        {
                            // Previously this surfaced as a NullReferenceException on ReadLine().ToLower().
                            throw new InvalidDataException($"Shuffled target file '{tgtShuffledFilePath}' has fewer lines than shuffled source file '{srcShuffledFilePath}'.");
                        }

                        srcLine = srcLine.ToLower().Trim();
                        if (m_addBOSEOS)
                        {
                            srcLine = $"{BOS} {srcLine} {EOS}";
                        }

                        tgtLine = tgtLine.ToLower().Trim();
                        if (m_addBOSEOS)
                        {
                            // Only EOS is appended on the target side.
                            tgtLine = $"{tgtLine} {EOS}";
                        }

                        SntPair sntPair = new SntPair();
                        sntPair.SrcSnt = srcLine.Split(' ');
                        sntPair.TgtSnt = tgtLine.Split(' ');

                        // Flush the buffered run before starting a run with a different source length,
                        // or when the buffer has grown past its cap.
                        if ((lastSrcSntLen > 0 && lastSrcSntLen != sntPair.SrcSnt.Length) || outputs.Count > maxOutputsSize)
                        {
                            InnerShuffle(outputs);
                            for (int i = 0; i < outputs.Count; i += m_batchSize)
                            {
                                int size = Math.Min(m_batchSize, outputs.Count - i);
                                yield return(new SntPairBatch(outputs.GetRange(i, size)));
                            }

                            outputs.Clear();
                        }

                        outputs.Add(sntPair);
                        lastSrcSntLen = sntPair.SrcSnt.Length;
                    }

                    // Flush whatever is left after the last line.
                    InnerShuffle(outputs);
                    for (int i = 0; i < outputs.Count; i += m_batchSize)
                    {
                        int size = Math.Min(m_batchSize, outputs.Count - i);
                        yield return(new SntPairBatch(outputs.GetRange(i, size)));
                    }
                }
            }
            finally
            {
                // Runs on normal completion, on an exception, and when the consumer disposes the
                // enumerator early (e.g. breaking out of foreach), so the temp files never leak.
                File.Delete(srcShuffledFilePath);
                File.Delete(tgtShuffledFilePath);
            }
        }
        /// <summary>
        /// Reads every source/target file pair, drops pairs where either side has
        /// <c>m_maxSentLength</c> or more tokens, shuffles the remaining pairs block by block with
        /// <c>Shuffle</c>, and writes them to two newly created files in the current directory.
        /// Also resets and recounts <c>CorpusSize</c>.
        /// </summary>
        /// <returns>The paths of the shuffled source file and the shuffled target file, in that order.</returns>
        /// <exception cref="InvalidDataException">
        /// A target file has fewer lines than its corresponding source file.
        /// </exception>
        private (string, string) ShuffleAll()
        {
            string srcShuffledFilePath = Path.Combine(Directory.GetCurrentDirectory(), Path.GetRandomFileName());
            string tgtShuffledFilePath = Path.Combine(Directory.GetCurrentDirectory(), Path.GetRandomFileName());

            Logger.WriteLine("Shuffling corpus...");

            CorpusSize = 0;
            var tooLongSntCnt = 0;

            // using guarantees the writers (and the readers below) are closed even when reading,
            // shuffling or writing throws part-way through.
            using (StreamWriter swSrc = new StreamWriter(srcShuffledFilePath, false))
            using (StreamWriter swTgt = new StreamWriter(tgtShuffledFilePath, false))
            {
                List <SntPair> sntPairs = new List <SntPair>();

                // Shuffles the buffered block, appends it to both output files and empties the buffer.
                void FlushBlock()
                {
                    Shuffle(sntPairs);
                    foreach (var item in sntPairs)
                    {
                        swSrc.WriteLine(String.Join(" ", item.SrcSnt));
                        swTgt.WriteLine(String.Join(" ", item.TgtSnt));
                    }
                    sntPairs.Clear();
                }

                char[] separators = new char[] { ' ' };

                for (int i = 0; i < m_srcFileList.Count; i++)
                {
                    using (StreamReader srSrc = new StreamReader(m_srcFileList[i]))
                    using (StreamReader srTgt = new StreamReader(m_tgtFileList[i]))
                    {
                        string srcLine;
                        while ((srcLine = srSrc.ReadLine()) != null)
                        {
                            string tgtLine = srTgt.ReadLine();
                            if (tgtLine == null)
                            {
                                // Previously this surfaced as a NullReferenceException on line.Split().
                                throw new InvalidDataException($"Target file '{m_tgtFileList[i]}' has fewer lines than source file '{m_srcFileList[i]}'.");
                            }

                            SntPair sntPair = new SntPair();
                            sntPair.SrcSnt = srcLine.Split(separators, StringSplitOptions.RemoveEmptyEntries);
                            sntPair.TgtSnt = tgtLine.Split(separators, StringSplitOptions.RemoveEmptyEntries);

                            if (sntPair.SrcSnt.Length >= m_maxSentLength || sntPair.TgtSnt.Length >= m_maxSentLength)
                            {
                                tooLongSntCnt++;
                                continue;
                            }

                            sntPairs.Add(sntPair);
                            CorpusSize++;

                            // A non-positive block size disables block-wise flushing: everything is
                            // shuffled as one block at the end.
                            if (m_blockSize > 0 && sntPairs.Count >= m_blockSize)
                            {
                                FlushBlock();
                            }
                        }
                    }
                }

                if (sntPairs.Count > 0)
                {
                    FlushBlock();
                }
            }

            Logger.WriteLine($"Shuffled '{CorpusSize}' sentence pairs to file '{srcShuffledFilePath}' and '{tgtShuffledFilePath}'.");

            if (tooLongSntCnt > 0)
            {
                Logger.WriteLine(Logger.Level.warn, ConsoleColor.Yellow, $"Found {tooLongSntCnt} sentences are longer than '{m_maxSentLength}' tokens, ignore them.");
            }

            return(srcShuffledFilePath, tgtShuffledFilePath);
        }
// Example #5
// 0
        /// <summary>
        /// Shuffles the corpus into two temporary files via <c>ShuffleAll</c>, then streams them back
        /// as batches of at most <c>m_batchSize</c> sentence pairs. The buffered run of pairs is
        /// flushed into batches when <c>SameSntLen</c> reports a length change on the side(s) selected
        /// by <c>m_shuffleEnums</c>, or when the buffer grows beyond <c>m_batchSize * 10000</c> pairs.
        /// </summary>
        /// <returns>An enumerator over the generated batches.</returns>
        /// <exception cref="InvalidDataException">
        /// The shuffled target file has fewer lines than the shuffled source file.
        /// </exception>
        public IEnumerator <T> GetEnumerator()
        {
            (string srcShuffledFilePath, string tgtShuffledFilePath) = ShuffleAll();

            try
            {
                using (StreamReader srSrc = new StreamReader(srcShuffledFilePath))
                using (StreamReader srTgt = new StreamReader(tgtShuffledFilePath))
                {
                    int[]          lastSrcSntLen  = null;
                    int[]          lastTgtSntLen  = null;
                    int            maxOutputsSize = m_batchSize * 10000;
                    List <SntPair> outputs        = new List <SntPair>();

                    string line;
                    while ((line = srSrc.ReadLine()) != null)
                    {
                        string rawTgtLine = srTgt.ReadLine();
                        if (rawTgtLine == null)
                        {
                            // Previously this surfaced as a NullReferenceException on ReadLine().Trim().
                            throw new InvalidDataException($"Shuffled target file '{tgtShuffledFilePath}' has fewer lines than shuffled source file '{srcShuffledFilePath}'.");
                        }

                        var     srcLine = line.Trim();
                        var     tgtLine = rawTgtLine.Trim();
                        SntPair sntPair = new SntPair(srcLine, tgtLine);

                        // On the first pair, size the length trackers from its token groups and mark
                        // every slot as "nothing seen yet" (-1).
                        if (lastSrcSntLen == null)
                        {
                            lastSrcSntLen = new int[sntPair.SrcTokenGroups.Count];
                            lastTgtSntLen = new int[sntPair.TgtTokenGroups.Count];

                            for (int i = 0; i < lastSrcSntLen.Length; i++)
                            {
                                lastSrcSntLen[i] = -1;
                            }

                            for (int i = 0; i < lastTgtSntLen.Length; i++)
                            {
                                lastTgtSntLen[i] = -1;
                            }
                        }

                        // NOTE(review): indexing [0] assumes each pair has at least one source and
                        // one target token group - confirm against SntPair.
                        bool flushNeeded =
                            (lastTgtSntLen[0] > 0 && m_shuffleEnums == ShuffleEnums.NoPaddingInTgt && SameSntLen(sntPair.TgtTokenGroups, lastTgtSntLen) == false) ||
                            (lastSrcSntLen[0] > 0 && m_shuffleEnums == ShuffleEnums.NoPaddingInSrc && SameSntLen(sntPair.SrcTokenGroups, lastSrcSntLen) == false) ||
                            (lastSrcSntLen[0] > 0 && lastTgtSntLen[0] > 0 && m_shuffleEnums == ShuffleEnums.NoPadding && (SameSntLen(sntPair.TgtTokenGroups, lastTgtSntLen) == false || SameSntLen(sntPair.SrcTokenGroups, lastSrcSntLen) == false)) ||
                            outputs.Count > maxOutputsSize;

                        if (flushNeeded)
                        {
                            // The buffered run is emitted in file order; no InnerShuffle is applied here.
                            for (int i = 0; i < outputs.Count; i += m_batchSize)
                            {
                                int size  = Math.Min(m_batchSize, outputs.Count - i);
                                var batch = new T();
                                batch.CreateBatch(outputs.GetRange(i, size));
                                yield return(batch);
                            }

                            outputs.Clear();
                        }

                        outputs.Add(sntPair);

                        // The trackers are always allocated by this point, so no null check is needed.
                        // Presumably UpdateSntLen stores the current pair's lengths for the next
                        // comparison - verify against its definition.
                        UpdateSntLen(sntPair.SrcTokenGroups, lastSrcSntLen);
                        UpdateSntLen(sntPair.TgtTokenGroups, lastTgtSntLen);
                    }

                    // Flush whatever is left after the last line (again in file order).
                    for (int i = 0; i < outputs.Count; i += m_batchSize)
                    {
                        int size  = Math.Min(m_batchSize, outputs.Count - i);
                        var batch = new T();
                        batch.CreateBatch(outputs.GetRange(i, size));
                        yield return(batch);
                    }
                }
            }
            finally
            {
                // Runs on normal completion, on an exception, and when the consumer disposes the
                // enumerator early (e.g. breaking out of foreach), so the temp files never leak.
                File.Delete(srcShuffledFilePath);
                File.Delete(tgtShuffledFilePath);
            }
        }