/// <summary>
        /// Runs one validation step on every device in parallel and folds the results into the metrics.
        /// </summary>
        /// <remarks>
        /// For each device index <c>i</c> the method builds the source, reference and hypothesis token lists from
        /// <c>sntPairBatchs[i]</c>, calls <paramref name="RunNetwork"/> on a compute graph created with
        /// <c>needBack: false</c>, and then, while holding the shared lock, passes every reference/hypothesis pair
        /// to every metric and optionally appends the space-joined sentences to the output lists.
        /// Exceptions thrown inside the parallel body are surfaced by <c>Parallel.For</c> wrapped in an
        /// <see cref="AggregateException"/>.
        /// </remarks>
        /// <param name="RunNetwork">Delegate invoked once per device with the compute graph, the source token lists, the hypothesis token lists, the device index and <c>false</c>. Its return value is ignored here.</param>
        /// <param name="metrics">Metrics that receive each (reference, hypothesis) pair.</param>
        /// <param name="outputToFile">When true, the sentences are appended to the three output lists.</param>
        /// <param name="srcSents">Receives the space-joined source sentences.</param>
        /// <param name="refSents">Receives the space-joined reference sentences.</param>
        /// <param name="hypSents">Receives the space-joined hypothesis sentences.</param>
        /// <param name="sntPairBatchs">Batches indexed by device index; it is indexed from 0 to <c>m_deviceIds.Length - 1</c>.</param>
        private void RunValidParallel(Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> RunNetwork, List<IMetric> metrics, bool outputToFile, List<string> srcSents, List<string> refSents, List<string> hypSents, List<SntPairBatch> sntPairBatchs)
        {
            // Run forward on all available processors
            Parallel.For(0, m_deviceIds.Length, i =>
            {
                SntPairBatch sntPairBatch = sntPairBatchs[i];

                // Construct sentences for encoding and decoding
                List<List<string>> srcTkns = new List<List<string>>();
                List<List<string>> refTkns = new List<List<string>>();
                List<List<string>> hypTkns = new List<List<string>>();
                for (int j = 0; j < sntPairBatch.BatchSize; j++)
                {
                    srcTkns.Add(sntPairBatch.SntPairs[j].SrcSnt.ToList());
                    refTkns.Add(sntPairBatch.SntPairs[j].TgtSnt.ToList());

                    // Every hypothesis starts with only the begin-of-sentence token.
                    // NOTE(review): RunNetwork presumably appends the decoded tokens to these lists - confirm against the delegate.
                    hypTkns.Add(new List<string>() { ParallelCorpus.BOS });
                }

                // Create a new computing graph instance
                using (IComputeGraph computeGraph = CreateComputGraph(i, needBack: false))
                {
                    // Run forward part
                    RunNetwork(computeGraph, srcTkns, hypTkns, i, false);
                }

                // Metrics and the output lists are shared between all device threads.
                lock (locker)
                {
                    for (int j = 0; j < hypTkns.Count; j++)
                    {
                        foreach (IMetric metric in metrics)
                        {
                            // j is already bounded by hypTkns.Count, so only the reference list can be too short here.
                            if (j >= refTkns.Count)
                            {
                                throw new InvalidDataException($"Ref token only has '{refTkns.Count}' batch, however, it try to access batch '{j}'. Hyp token has '{hypTkns.Count}' tokens, Batch Size = '{sntPairBatch.BatchSize}'");
                            }

                            metric.Evaluate(new List<List<string>>() { refTkns[j] }, hypTkns[j]);
                        }
                    }

                    if (outputToFile)
                    {
                        for (int j = 0; j < srcTkns.Count; j++)
                        {
                            srcSents.Add(string.Join(" ", srcTkns[j]));
                            refSents.Add(string.Join(" ", refTkns[j]));
                            hypSents.Add(string.Join(" ", hypTkns[j]));
                        }
                    }
                }
            });
        }
        /// <summary>
        /// Clears the gradients on all devices, then runs the forward and backward passes on every device in
        /// parallel and sums up the cost, word counts and sentence count.
        /// </summary>
        /// <remarks>
        /// Each device's batch is processed as <paramref name="batchSplitFactor"/> consecutive segments, each on its
        /// own compute graph. Segment boundaries are computed proportionally so that every sentence of the batch
        /// belongs to exactly one segment even when the batch size is not a multiple of the split factor; segments
        /// that would be empty are skipped.
        /// </remarks>
        /// <param name="ForwardOnSingleDevice">Delegate invoked once per segment with the compute graph, the source token lists, the target token lists, the device index and <c>true</c>. Its return value is added to the cost.</param>
        /// <param name="sntPairBatchs">Batches indexed by device index; it is indexed from 0 to <c>m_deviceIds.Length - 1</c>.</param>
        /// <param name="batchSplitFactor">Number of segments each batch is split into. Must be positive.</param>
        /// <returns>A tuple of (summed cost, source word count, target word count, number of processed sentences).</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSplitFactor"/> is zero or negative.</exception>
        private (float, int, int, int) RunNetwork(Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice, List<SntPairBatch> sntPairBatchs, int batchSplitFactor)
        {
            if (batchSplitFactor <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(batchSplitFactor), $"Batch split factor must be positive, but it is '{batchSplitFactor}'.");
            }

            float cost = 0.0f;
            int processedLine = 0;
            int srcWordCnts = 0;
            int tgtWordCnts = 0;

            //Clear gradient over all devices
            ZeroGradientOnAllDevices();

            // Run forward and backward on all available processors
            Parallel.For(0, m_deviceIds.Length, i =>
            {
                SntPairBatch sntPairBatch_i = sntPairBatchs[i];
                int batchSize = sntPairBatch_i.BatchSize;

                for (int k = 0; k < batchSplitFactor; k++)
                {
                    // Proportional boundaries: identical to k * (batchSize / batchSplitFactor) when the batch
                    // divides evenly, and otherwise spread the remainder instead of dropping it.
                    int segBegin = (int)((long)k * batchSize / batchSplitFactor);
                    int segEnd = (int)((long)(k + 1) * batchSize / batchSplitFactor);
                    if (segBegin == segEnd)
                    {
                        // More segments than sentences: nothing to run for this one.
                        continue;
                    }

                    // Construct sentences for encoding and decoding
                    List<List<string>> srcTkns = new List<List<string>>();
                    List<List<string>> tgtTkns = new List<List<string>>();
                    int sLenInBatch = 0;
                    int tLenInBatch = 0;
                    for (int j = segBegin; j < segEnd; j++)
                    {
                        srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                        sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;

                        tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                        tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
                    }

                    float lcost = 0.0f;
                    // Create a new computing graph instance
                    using (IComputeGraph computeGraph_i = CreateComputGraph(i))
                    {
                        // Run forward part
                        lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);
                        // Run backward part and compute gradients
                        computeGraph_i.Backward();
                    }

                    // The accumulators are shared between all device threads.
                    lock (locker)
                    {
                        cost += lcost;
                        srcWordCnts += sLenInBatch;
                        tgtWordCnts += tLenInBatch;
                        processedLine += segEnd - segBegin;
                    }
                }
            });

            return (cost, srcWordCnts, tgtWordCnts, processedLine);
        }
        /// <summary>
        /// Evaluate the quality of model on valid corpus, retrying with smaller batch segments when a run fails.
        /// </summary>
        /// <remarks>
        /// Batches are collected until there is one per device and are then evaluated in parallel; batches left
        /// over at the end of the corpus that do not fill a complete group are not evaluated.
        /// If any exception escapes an attempt, the batch split factor is doubled and the whole evaluation is
        /// restarted with cleared metrics and output lists. Once the factor reaches 512 the exception is rethrown.
        /// When the primary score improves, <c>m_bestPrimaryScore</c> is updated.
        /// </remarks>
        /// <param name="validCorpus">valid corpus to measure the quality of model</param>
        /// <param name="RunNetwork">The network to run on specific device. It is called with the compute graph, the source token lists, the hypothesis token lists, the device id and <c>false</c>; its return value is ignored.</param>
        /// <param name="metrics">A set of metrics. The first one is the primary metric</param>
        /// <param name="outputToFile">It indicates if valid corpus and results should be dumped to files (valid_src.txt, valid_ref.txt and valid_hyp.txt)</param>
        /// <returns>true if we get a better result on primary metric, otherwise, false</returns>
        internal bool RunValid(ParallelCorpus validCorpus, Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> RunNetwork, List<IMetric> metrics, bool outputToFile = false)
        {
            List<string> srcSents = null;
            List<string> refSents = null;
            List<string> hypSents = null;
            int batchSplitFactor = 1;
            bool evaluationSucceeded = false;

            while (evaluationSucceeded == false)
            {
                try
                {
                    if (batchSplitFactor == 1)
                    {
                        Logger.WriteLine(Logger.Level.info, ConsoleColor.Gray, $"Start to evaluate model...");
                    }
                    else
                    {
                        Logger.WriteLine(Logger.Level.info, ConsoleColor.Gray, $"Retry to evaluate model. Batch split factor = '{batchSplitFactor}'");
                    }

                    // Start every attempt from scratch so a failed attempt leaves no partial results behind.
                    srcSents = new List<string>();
                    refSents = new List<string>();
                    hypSents = new List<string>();

                    // Clear inner status of each metrics
                    foreach (IMetric metric in metrics)
                    {
                        metric.ClearStatus();
                    }

                    List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();
                    foreach (SntPairBatch item in validCorpus)
                    {
                        sntPairBatchs.Add(item);
                        if (sntPairBatchs.Count == DeviceIds.Length)
                        {
                            // Run forward on all available processors
                            Parallel.For(0, m_deviceIds.Length, i =>
                            {
                                SntPairBatch sntPairBatch = sntPairBatchs[i];
                                int batchSize = sntPairBatch.BatchSize;

                                for (int k = 0; k < batchSplitFactor; k++)
                                {
                                    // Proportional boundaries: identical to k * (batchSize / batchSplitFactor) when
                                    // the batch divides evenly, and otherwise spread the remainder instead of
                                    // silently leaving sentences out of the metrics.
                                    int segBegin = (int)((long)k * batchSize / batchSplitFactor);
                                    int segEnd = (int)((long)(k + 1) * batchSize / batchSplitFactor);
                                    if (segBegin == segEnd)
                                    {
                                        // More segments than sentences: nothing to run for this one.
                                        continue;
                                    }

                                    // Construct sentences for encoding and decoding
                                    List<List<string>> srcTkns = new List<List<string>>();
                                    List<List<string>> refTkns = new List<List<string>>();
                                    List<List<string>> hypTkns = new List<List<string>>();
                                    for (int j = segBegin; j < segEnd; j++)
                                    {
                                        srcTkns.Add(sntPairBatch.SntPairs[j].SrcSnt.ToList());
                                        refTkns.Add(sntPairBatch.SntPairs[j].TgtSnt.ToList());
                                        hypTkns.Add(new List<string>()
                                        {
                                            ParallelCorpus.BOS
                                        });
                                    }

                                    // NOTE(review): this overload passes DeviceIds[i] to CreateComputGraph and RunNetwork,
                                    // while the other methods in this file pass the index i - confirm which one is expected.
                                    // Create a new computing graph instance
                                    using (IComputeGraph computeGraph = CreateComputGraph(DeviceIds[i], needBack: false))
                                    {
                                        // Run forward part
                                        RunNetwork(computeGraph, srcTkns, hypTkns, DeviceIds[i], false);
                                    }

                                    // Metrics and the output lists are shared between all device threads.
                                    lock (locker)
                                    {
                                        for (int j = 0; j < hypTkns.Count; j++)
                                        {
                                            foreach (IMetric metric in metrics)
                                            {
                                                metric.Evaluate(new List<List<string>>()
                                                {
                                                    refTkns[j]
                                                }, hypTkns[j]);
                                            }
                                        }

                                        if (outputToFile)
                                        {
                                            for (int j = 0; j < srcTkns.Count; j++)
                                            {
                                                srcSents.Add(string.Join(" ", srcTkns[j]));
                                                refSents.Add(string.Join(" ", refTkns[j]));
                                                hypSents.Add(string.Join(" ", hypTkns[j]));
                                            }
                                        }
                                    }
                                }
                            });

                            sntPairBatchs.Clear();
                        }
                    }

                    evaluationSucceeded = true;
                }
                catch (Exception err)
                {
                    // Parallel.For wraps failures in an AggregateException, so report the innermost cause.
                    Logger.WriteLine(Logger.Level.info, ConsoleColor.Yellow, $"Failed to evaluate model with batch split factor '{batchSplitFactor}'. Error: '{err.GetBaseException().Message}'");

                    batchSplitFactor *= 2;
                    if (batchSplitFactor >= 512)
                    {
                        Logger.WriteLine($"Batch split factor is larger than batch size, give it up.");

                        // Rethrow without resetting the original stack trace.
                        throw;
                    }
                }
            }

            Logger.WriteLine($"Metrics result:");
            foreach (IMetric metric in metrics)
            {
                Logger.WriteLine(Logger.Level.info, ConsoleColor.DarkGreen, $"{metric.Name} = {metric.GetScoreStr()}");
            }

            if (outputToFile)
            {
                File.WriteAllLines("valid_src.txt", srcSents);
                File.WriteAllLines("valid_ref.txt", refSents);
                File.WriteAllLines("valid_hyp.txt", hypSents);
            }

            if (metrics.Count > 0)
            {
                if (metrics[0].GetPrimaryScore() > m_bestPrimaryScore)
                {
                    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"We got a better score '{metrics[0].GetPrimaryScore().ToString("F")}' on primary metric '{metrics[0].Name}'. The previous score is '{m_bestPrimaryScore.ToString("F")}'");
                    //We have a better primary score on valid set
                    m_bestPrimaryScore = metrics[0].GetPrimaryScore();
                    return true;
                }
            }

            return false;
        }
        /// <summary>
        /// Trains the model for one pass over the training corpus using all devices in parallel.
        /// </summary>
        /// <remarks>
        /// Batches are collected until there is one per device. For every complete group the weights are copied
        /// from the default device to the others, forward and backward passes run in parallel, the gradients are
        /// summed on the default device, the solver updates the weights and the gradients are cleared.
        /// The <c>IterationDone</c> event is raised on every 100th weight update, a check point is created when
        /// more than an hour has passed since the previous one, and once more at the end of the epoch.
        /// Batches left over at the end of the corpus that do not fill a complete group are not trained on.
        /// </remarks>
        /// <param name="ep">Epoch number; used for reporting only.</param>
        /// <param name="trainCorpus">Corpus that yields the training batches.</param>
        /// <param name="validCorpus">Corpus forwarded to <c>CreateCheckPoint</c>.</param>
        /// <param name="learningRate">Provides the learning rate for each weight update.</param>
        /// <param name="solver">Optimizer that updates the weights kept on the default device.</param>
        /// <param name="metrics">Metrics forwarded to <c>CreateCheckPoint</c>.</param>
        /// <param name="modelMetaData">Model meta data forwarded to <c>CreateCheckPoint</c>.</param>
        /// <param name="ForwardOnSingleDevice">Delegate invoked once per device with the compute graph, the source token lists, the target token lists, the device index and <c>true</c>. Its return value is added to the cost.</param>
        internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List<IMetric> metrics, IModelMetaData modelMetaData,
                                    Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice)
        {
            // Epoch-wide accumulators.
            int totalSentences = 0;
            DateTime epochStart = DateTime.Now;
            DateTime lastCheckPoint = DateTime.Now;
            double totalCost = 0.0;
            long totalSrcWords = 0;
            long totalTgtWords = 0;
            double avgCostPerWord = 0.0;

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Start to process training corpus.");
            List<SntPairBatch> pendingBatches = new List<SntPairBatch>();

            foreach (SntPairBatch batch in trainCorpus)
            {
                pendingBatches.Add(batch);
                if (pendingBatches.Count != m_deviceIds.Length)
                {
                    // Keep collecting until every device has a batch to work on.
                    continue;
                }

                // Accumulators for this single weight update.
                float stepCost = 0.0f;
                int stepTgtWords = 0;
                int stepSentences = 0;

                // Copy weights from weights kept in default device to all other devices
                CopyWeightsFromDefaultDeviceToAllOtherDevices();

                // Run forward and backward on all available processors
                Parallel.For(0, m_deviceIds.Length, i =>
                {
                    SntPairBatch deviceBatch = pendingBatches[i];

                    // Construct sentences for encoding and decoding
                    List<List<string>> srcTkns = new List<List<string>>();
                    List<List<string>> tgtTkns = new List<List<string>>();
                    int srcWords = 0;
                    int tgtWords = 0;
                    for (int j = 0; j < deviceBatch.BatchSize; j++)
                    {
                        var pair = deviceBatch.SntPairs[j];

                        srcTkns.Add(pair.SrcSnt.ToList());
                        srcWords += pair.SrcSnt.Length;

                        tgtTkns.Add(pair.TgtSnt.ToList());
                        tgtWords += pair.TgtSnt.Length;
                    }

                    float deviceCost = 0.0f;
                    // Create a new computing graph instance
                    using (IComputeGraph graph = CreateComputGraph(i))
                    {
                        // Run forward part
                        deviceCost = ForwardOnSingleDevice(graph, srcTkns, tgtTkns, i, true);
                        // Run backward part and compute gradients
                        graph.Backward();
                    }

                    // The accumulators are shared between all device threads.
                    lock (locker)
                    {
                        stepCost += deviceCost;
                        totalSrcWords += srcWords;
                        totalTgtWords += tgtWords;
                        stepTgtWords += tgtWords;
                        totalSentences += deviceBatch.BatchSize;
                        stepSentences += deviceBatch.BatchSize;
                    }
                });

                // Sum up gradients in all devices, and keep them in default device for parameters optimization
                SumGradientsToTensorsInDefaultDevice();

                // Optimize parameters
                float lr = learningRate.GetCurrentLearningRate();
                List<IWeightTensor> parameters = GetParametersFromDefaultDevice();
                solver.UpdateWeights(parameters, stepSentences, lr, m_regc, m_weightsUpdateCount + 1);

                // Clear gradient over all devices
                ZeroGradientOnAllDevices();

                totalCost += stepCost;
                avgCostPerWord = totalCost / totalTgtWords;
                m_weightsUpdateCount++;

                // Report progress on every 100th weight update.
                if (m_weightsUpdateCount % 100 == 0 && IterationDone != null)
                {
                    CostEventArg progress = new CostEventArg()
                    {
                        LearningRate = lr,
                        CostPerWord = stepCost / stepTgtWords,
                        AvgCostInTotal = avgCostPerWord,
                        Epoch = ep,
                        Update = m_weightsUpdateCount,
                        ProcessedSentencesInTotal = totalSentences,
                        ProcessedWordsInTotal = totalSrcWords + totalTgtWords,
                        StartDateTime = epochStart
                    };
                    IterationDone(this, progress);
                }

                // Create a check point when more than one hour has passed since the previous one.
                if ((DateTime.Now - lastCheckPoint).TotalHours > 1.0)
                {
                    CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWord);
                    lastCheckPoint = DateTime.Now;
                }

                pendingBatches.Clear();
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - epochStart}' time to finish. AvgCost = {avgCostPerWord.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWord);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWord;
        }
        /// <summary>
        /// Evaluate the quality of model on valid corpus.
        /// </summary>
        /// <remarks>
        /// Batches are collected until there is one per device and are then evaluated in parallel; batches left
        /// over at the end of the corpus that do not fill a complete group are not evaluated.
        /// Exceptions thrown inside the parallel body are surfaced by <c>Parallel.For</c> wrapped in an
        /// <see cref="AggregateException"/>. When the primary score improves, <c>m_bestPrimaryScore</c> is updated.
        /// </remarks>
        /// <param name="validCorpus">valid corpus to measure the quality of model</param>
        /// <param name="RunNetwork">The network to run on specific device. It is called with the compute graph, the source token lists, the hypothesis token lists, the device index and <c>false</c>; its return value is ignored.</param>
        /// <param name="metrics">A set of metrics. The first one is the primary metric</param>
        /// <param name="outputToFile">It indicates if valid corpus and results should be dumped to files (valid_src.txt, valid_ref.txt and valid_hyp.txt)</param>
        /// <returns>true if we get a better result on primary metric, otherwise, false</returns>
        internal bool RunValid(IEnumerable<SntPairBatch> validCorpus, Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> RunNetwork, List<IMetric> metrics, bool outputToFile = false)
        {
            List<string> srcSents = new List<string>();
            List<string> refSents = new List<string>();
            List<string> hypSents = new List<string>();

            // Clear inner status of each metrics
            foreach (IMetric metric in metrics)
            {
                metric.ClearStatus();
            }

            List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();

            foreach (SntPairBatch item in validCorpus)
            {
                sntPairBatchs.Add(item);
                if (sntPairBatchs.Count == DeviceIds.Length)
                {
                    // Run forward on all available processors
                    Parallel.For(0, m_deviceIds.Length, i =>
                    {
                        SntPairBatch sntPairBatch = sntPairBatchs[i];

                        // Construct sentences for encoding and decoding
                        List<List<string>> srcTkns = new List<List<string>>();
                        List<List<string>> refTkns = new List<List<string>>();
                        List<List<string>> hypTkns = new List<List<string>>();
                        for (int j = 0; j < sntPairBatch.BatchSize; j++)
                        {
                            srcTkns.Add(sntPairBatch.SntPairs[j].SrcSnt.ToList());
                            refTkns.Add(sntPairBatch.SntPairs[j].TgtSnt.ToList());

                            // Every hypothesis starts with only the begin-of-sentence token.
                            // NOTE(review): RunNetwork presumably appends the decoded tokens to these lists - confirm against the delegate.
                            hypTkns.Add(new List<string>()
                            {
                                ParallelCorpus.BOS
                            });
                        }

                        // Create a new computing graph instance
                        using (IComputeGraph computeGraph = CreateComputGraph(i, needBack: false))
                        {
                            // Run forward part
                            RunNetwork(computeGraph, srcTkns, hypTkns, i, false);
                        }

                        // Metrics and the output lists are shared between all device threads.
                        lock (locker)
                        {
                            for (int j = 0; j < hypTkns.Count; j++)
                            {
                                foreach (IMetric metric in metrics)
                                {
                                    // j is already bounded by hypTkns.Count, so only the reference list can be too short here.
                                    if (j >= refTkns.Count)
                                    {
                                        throw new InvalidDataException($"Ref token only has '{refTkns.Count}' batch, however, it try to access batch '{j}'. Hyp token has '{hypTkns.Count}' tokens, Batch Size = '{sntPairBatch.BatchSize}'");
                                    }

                                    metric.Evaluate(new List<List<string>>()
                                    {
                                        refTkns[j]
                                    }, hypTkns[j]);
                                }
                            }

                            if (outputToFile)
                            {
                                for (int j = 0; j < srcTkns.Count; j++)
                                {
                                    srcSents.Add(string.Join(" ", srcTkns[j]));
                                    refSents.Add(string.Join(" ", refTkns[j]));
                                    hypSents.Add(string.Join(" ", hypTkns[j]));
                                }
                            }
                        }
                    });

                    sntPairBatchs.Clear();
                }
            }

            Logger.WriteLine($"Metrics result:");
            foreach (IMetric metric in metrics)
            {
                Logger.WriteLine(Logger.Level.info, ConsoleColor.DarkGreen, $"{metric.Name} = {metric.GetScoreStr()}");
            }

            if (outputToFile)
            {
                File.WriteAllLines("valid_src.txt", srcSents);
                File.WriteAllLines("valid_ref.txt", refSents);
                File.WriteAllLines("valid_hyp.txt", hypSents);
            }

            if (metrics.Count > 0)
            {
                if (metrics[0].GetPrimaryScore() > m_bestPrimaryScore)
                {
                    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"We got a better score '{metrics[0].GetPrimaryScore().ToString("F")}' on primary metric '{metrics[0].Name}'. The previous score is '{m_bestPrimaryScore.ToString("F")}'");
                    //We have a better primary score on valid set
                    m_bestPrimaryScore = metrics[0].GetPrimaryScore();
                    return true;
                }
            }

            return false;
        }