示例#1
0
        /// <summary>
        /// Reorders <paramref name="rawSntPairs"/> in place. Sentence pairs are grouped by source
        /// length, each group is shuffled, groups that contain the same number of pairs are merged,
        /// and the merged groups are written back in random order.
        /// </summary>
        /// <param name="rawSntPairs">Sentence pairs to shuffle; the list is cleared and refilled with the same items.</param>
        private void Shuffle(List <RawSntPair> rawSntPairs)
        {
            //Put sentence pairs with the same source length into the same bucket
            Dictionary <int, List <RawSntPair> > dict = new Dictionary <int, List <RawSntPair> >(); //<source sentence length, sentence pair set>

            foreach (RawSntPair item in rawSntPairs)
            {
                int length = item.SrcLength;

                List <RawSntPair> sameLengthPairs;
                if (dict.TryGetValue(length, out sameLengthPairs) == false)
                {
                    sameLengthPairs = new List <RawSntPair>();
                    dict.Add(length, sameLengthPairs);
                }
                sameLengthPairs.Add(item);
            }

            //Draw one seed per bucket up front, on the calling thread. 'rnd' is declared outside this
            //method and is shared, so it must not be touched from inside the parallel loop. Seeding
            //from it (instead of from the clock) also keeps the per-bucket generators uncorrelated.
            List <RawSntPair>[] buckets = dict.Values.ToArray();
            int[] seeds = new int[buckets.Length];
            for (int i = 0; i < seeds.Length; i++)
            {
                seeds[i] = rnd.Next(0, int.MaxValue);
            }

            //Randomize the order of sentence pairs with the same length in source side
            Parallel.For(0, buckets.Length, b =>
            {
                Random rnd2 = new Random(seeds[b]);

                //Fisher-Yates: the swap partner is drawn from [0, i] only, so that every
                //permutation is equally likely
                List <RawSntPair> sntPairList = buckets[b];
                for (int i = sntPairList.Count - 1; i > 0; i--)
                {
                    int idx          = rnd2.Next(0, i + 1);
                    RawSntPair tmp   = sntPairList[i];
                    sntPairList[i]   = sntPairList[idx];
                    sntPairList[idx] = tmp;
                }
            });

            //Merge buckets that hold the same number of sentence pairs
            SortedDictionary <int, List <RawSntPair> > sdict = new SortedDictionary <int, List <RawSntPair> >(); //<The bucket size, sentence pair set>

            foreach (List <RawSntPair> sizedBucket in buckets)
            {
                List <RawSntPair> merged;
                if (sdict.TryGetValue(sizedBucket.Count, out merged) == false)
                {
                    merged = new List <RawSntPair>();
                    sdict.Add(sizedBucket.Count, merged);
                }
                merged.AddRange(sizedBucket);
            }

            rawSntPairs.Clear();

            //Shuffle the order in which the merged buckets are emitted (Fisher-Yates again)
            int[] keys = sdict.Keys.ToArray();
            for (int i = keys.Length - 1; i > 0; i--)
            {
                int idx = rnd.Next(0, i + 1);
                int tmp = keys[i];
                keys[i]   = keys[idx];
                keys[idx] = tmp;
            }

            foreach (int key in keys)
            {
                rawSntPairs.AddRange(sdict[key]);
            }
        }
示例#2
0
        /// <summary>
        /// Reads every source/target file pair line by line, drops pairs whose source side is longer
        /// than <c>m_maxSentLength</c>, shuffles the remaining pairs block by block and writes them
        /// to two newly created temporary files in the current directory.
        /// </summary>
        /// <remarks>
        /// Also updates <c>CorpusSize</c> with the number of pairs written and, when
        /// <c>m_showTokenDist</c> is set, logs the source/target length distributions once and
        /// then clears that flag.
        /// </remarks>
        /// <returns>The paths of the shuffled source file and the shuffled target file, in that order.</returns>
        private (string, string) ShuffleAll()
        {
            //<length / 100, number of sentences in that length bin>
            SortedDictionary <int, int> dictSrcLenDist = new SortedDictionary <int, int>();
            SortedDictionary <int, int> dictTgtLenDist = new SortedDictionary <int, int>();

            string srcShuffledFilePath = Path.Combine(Directory.GetCurrentDirectory(), Path.GetRandomFileName() + ".tmp");
            string tgtShuffledFilePath = Path.Combine(Directory.GetCurrentDirectory(), Path.GetRandomFileName() + ".tmp");

            Logger.WriteLine($"Shuffling corpus for '{m_srcFileList.Count}' files.");

            List <RawSntPair> sntPairs = new List <RawSntPair>();

            CorpusSize = 0;
            int tooLongSrcSntCnt = 0;

            //'using' guarantees the writers and readers are flushed and closed even when reading,
            //shuffling or writing throws
            using (StreamWriter swSrc = new StreamWriter(srcShuffledFilePath, false))
            using (StreamWriter swTgt = new StreamWriter(tgtShuffledFilePath, false))
            {
                //Shuffles the pairs buffered so far, appends them to both output files and empties the buffer
                void FlushBlock()
                {
                    Shuffle(sntPairs);
                    foreach (RawSntPair item in sntPairs)
                    {
                        swSrc.WriteLine(item.SrcSnt);
                        swTgt.WriteLine(item.TgtSnt);
                    }
                    sntPairs.Clear();
                }

                for (int i = 0; i < m_srcFileList.Count; i++)
                {
                    if (m_showTokenDist)
                    {
                        Logger.WriteLine($"Process file '{m_srcFileList[i]}' and '{m_tgtFileList[i]}'");
                    }

                    using (StreamReader srSrc = new StreamReader(m_srcFileList[i]))
                    using (StreamReader srTgt = new StreamReader(m_tgtFileList[i]))
                    {
                        //Keep reading until both files are exhausted or an empty pair is met
                        while (!(srSrc.EndOfStream && srTgt.EndOfStream))
                        {
                            RawSntPair rawSntPair = new RawSntPair(srSrc.ReadLine(), srTgt.ReadLine());
                            if (rawSntPair.IsEmptyPair())
                            {
                                break;
                            }

                            //Length statistics cover every pair that was read, including the ones skipped below
                            int srcBin = rawSntPair.SrcLength / 100;
                            dictSrcLenDist.TryGetValue(srcBin, out int srcBinCount);
                            dictSrcLenDist[srcBin] = srcBinCount + 1;

                            int tgtBin = rawSntPair.TgtLength / 100;
                            dictTgtLenDist.TryGetValue(tgtBin, out int tgtBinCount);
                            dictTgtLenDist[tgtBin] = tgtBinCount + 1;

                            if (rawSntPair.SrcLength > m_maxSentLength)
                            {
                                tooLongSrcSntCnt++;
                                continue;
                            }

                            sntPairs.Add(rawSntPair);
                            CorpusSize++;
                            if (m_blockSize > 0 && sntPairs.Count >= m_blockSize)
                            {
                                FlushBlock();
                            }
                        }
                    }
                }

                //Write out whatever is left in the last, partially filled block
                if (sntPairs.Count > 0)
                {
                    FlushBlock();
                }
            }

            Logger.WriteLine($"Shuffled '{CorpusSize}' sentence pairs to file '{srcShuffledFilePath}' and '{tgtShuffledFilePath}'.");

            if (tooLongSrcSntCnt > 0)
            {
                Logger.WriteLine(Logger.Level.warn, ConsoleColor.Yellow, $"Found {tooLongSrcSntCnt} source sentences are longer than '{m_maxSentLength}' tokens, ignore them.");
            }

            if (m_showTokenDist)
            {
                Logger.WriteLine($"Src token length distribution");

                int srcTotalNum = 0;
                foreach (var pair in dictSrcLenDist)
                {
                    srcTotalNum += pair.Value;
                }

                //Bins are visited in ascending order, so the running sum is the cumulative share
                int srcAccNum = 0;
                foreach (var pair in dictSrcLenDist)
                {
                    srcAccNum += pair.Value;

                    Logger.WriteLine($"{pair.Key * 100} ~ {(pair.Key + 1) * 100}: {pair.Value} (acc: {(100.0f * (float)srcAccNum / (float)srcTotalNum).ToString("F")}%)");
                }

                Logger.WriteLine($"Tgt token length distribution");

                int tgtTotalNum = 0;
                foreach (var pair in dictTgtLenDist)
                {
                    tgtTotalNum += pair.Value;
                }

                int tgtAccNum = 0;
                foreach (var pair in dictTgtLenDist)
                {
                    tgtAccNum += pair.Value;

                    Logger.WriteLine($"{pair.Key * 100} ~ {(pair.Key + 1) * 100}: {pair.Value}  (acc: {(100.0f * (float)tgtAccNum / (float)tgtTotalNum).ToString("F")}%)");
                }

                //Only report the distribution once
                m_showTokenDist = false;
            }

            return(srcShuffledFilePath, tgtShuffledFilePath);
        }
示例#3
0
        /// <summary>
        /// Reorders <paramref name="rawSntPairs"/> in place according to <c>m_shuffleEnums</c>.
        /// </summary>
        /// <remarks>
        /// For <c>ShuffleEnums.Random</c> the whole list is shuffled uniformly. Otherwise pairs are
        /// grouped by length (source length for <c>ShuffleEnums.NoPaddingInSrc</c>, target length
        /// for any other value), each group is shuffled and cut into batches of at most
        /// <c>m_batchSize</c> pairs, and the batches are written back in random order. A
        /// non-positive <c>m_batchSize</c> leaves the groups unsplit.
        /// </remarks>
        /// <param name="rawSntPairs">Sentence pairs to shuffle; the list keeps the same items in a new order.</param>
        private void Shuffle(List <RawSntPair> rawSntPairs)
        {
            if (m_shuffleEnums == ShuffleEnums.Random)
            {
                //Fisher-Yates: the swap partner is drawn from [0, i] only, so that every
                //permutation is equally likely
                for (int i = rawSntPairs.Count - 1; i > 0; i--)
                {
                    int        idx = rnd.Next(0, i + 1);
                    RawSntPair tmp = rawSntPairs[i];
                    rawSntPairs[i]   = rawSntPairs[idx];
                    rawSntPairs[idx] = tmp;
                }

                return;
            }


            //Put sentence pairs with the same length into the same bucket
            Dictionary <int, List <RawSntPair> > dict = new Dictionary <int, List <RawSntPair> >(); //<sentence length, sentence pair set>

            foreach (RawSntPair item in rawSntPairs)
            {
                int length = m_shuffleEnums == ShuffleEnums.NoPaddingInSrc ? item.SrcLength : item.TgtLength;

                List <RawSntPair> sameLengthPairs;
                if (dict.TryGetValue(length, out sameLengthPairs) == false)
                {
                    sameLengthPairs = new List <RawSntPair>();
                    dict.Add(length, sameLengthPairs);
                }

                sameLengthPairs.Add(item);
            }

            //Draw one seed per bucket up front, on the calling thread. 'rnd' is declared outside this
            //method and is shared, so it must not be touched from inside the parallel loop. Seeding
            //from it (instead of from the clock) also keeps the per-bucket generators uncorrelated.
            List <RawSntPair>[] buckets = dict.Values.ToArray();
            int[] seeds = new int[buckets.Length];
            for (int i = 0; i < seeds.Length; i++)
            {
                seeds[i] = rnd.Next(0, int.MaxValue);
            }

            //Randomize the order of sentence pairs inside each bucket
            Parallel.For(0, buckets.Length, b =>
            {
                Random rnd2 = new Random(seeds[b]);

                List <RawSntPair> sntPairList = buckets[b];
                for (int i = sntPairList.Count - 1; i > 0; i--)
                {
                    int idx          = rnd2.Next(0, i + 1);
                    RawSntPair tmp   = sntPairList[i];
                    sntPairList[i]   = sntPairList[idx];
                    sntPairList[idx] = tmp;
                }
            });


            //Split large buckets into batches of at most m_batchSize pairs. The batches are kept in
            //a plain list, so no synthetic dictionary key is needed and keys can never collide.
            List <List <RawSntPair> > batches = new List <List <RawSntPair> >();

            foreach (List <RawSntPair> bucket in buckets)
            {
                if (m_batchSize <= 0 || bucket.Count <= m_batchSize)
                {
                    batches.Add(bucket);
                    continue;
                }

                for (int start = 0; start < bucket.Count; start += m_batchSize)
                {
                    //The last batch of a bucket may be shorter than m_batchSize
                    batches.Add(bucket.GetRange(start, Math.Min(m_batchSize, bucket.Count - start)));
                }
            }

            //Shuffle the order in which the batches are emitted (Fisher-Yates again)
            for (int i = batches.Count - 1; i > 0; i--)
            {
                int idx = rnd.Next(0, i + 1);
                List <RawSntPair> tmp = batches[i];
                batches[i]   = batches[idx];
                batches[idx] = tmp;
            }

            rawSntPairs.Clear();

            foreach (List <RawSntPair> batch in batches)
            {
                rawSntPairs.AddRange(batch);
            }
        }