/// <summary>
        /// Zeroes the gradients on all devices, then runs the forward and backward passes
        /// for one batch per device in parallel. Each device's batch is processed in
        /// <paramref name="batchSplitFactor"/> consecutive segments, each with its own
        /// compute graph.
        /// </summary>
        /// <param name="ForwardOnSingleDevice">
        /// Callback that runs the forward pass for one segment. It receives the compute graph,
        /// the source token lists, the target token lists, the device index and a flag that is
        /// always passed as <c>true</c> here (presumably "is training" - confirm against the
        /// callback implementations), and returns the cost of that segment.
        /// </param>
        /// <param name="sntPairBatchs">
        /// Batches to process; element <c>i</c> is processed on device index <c>i</c>, so the
        /// list needs at least <c>m_deviceIds.Length</c> entries.
        /// </param>
        /// <param name="batchSplitFactor">Number of segments each batch is split into. Must be greater than zero.</param>
        /// <returns>
        /// A tuple of (summed cost, number of source tokens, number of target tokens,
        /// number of sentence pairs processed) accumulated over all devices and segments.
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="batchSplitFactor"/> is zero or negative.
        /// </exception>
        private (float, int, int, int) RunNetwork(Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice, List<SntPairBatch> sntPairBatchs, int batchSplitFactor)
        {
            // Fail fast with a clear message instead of a DivideByZeroException wrapped in an
            // AggregateException (factor == 0) or a silent no-op (factor < 0).
            if (batchSplitFactor <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(batchSplitFactor), batchSplitFactor, "The batch split factor must be greater than zero.");
            }

            float cost          = 0.0f;
            int   processedLine = 0;
            int   srcWordCnts   = 0;
            int   tgtWordCnts   = 0;

            // Clear gradient over all devices
            ZeroGradientOnAllDevices();

            // Run forward and backward on all available processors
            Parallel.For(0, m_deviceIds.Length, i =>
            {
                SntPairBatch sntPairBatch_i = sntPairBatchs[i];
                int batchSize   = sntPairBatch_i.BatchSize;
                int baseSegSize = batchSize / batchSplitFactor;
                // Sentences left over by the integer division. They are handed out one each to
                // the leading segments so that no sentence of the batch is silently dropped.
                int remainder = batchSize % batchSplitFactor;
                int segStart  = 0;

                for (int k = 0; k < batchSplitFactor; k++)
                {
                    int segSize = baseSegSize + (k < remainder ? 1 : 0);
                    if (segSize == 0)
                    {
                        // Only reachable when the split factor exceeds the batch size; all
                        // remaining segments are empty too, so there is nothing left to run.
                        break;
                    }

                    int segEnd = segStart + segSize;

                    // Construct sentences for encoding and decoding
                    List<List<string>> srcTkns = new List<List<string>>(segSize);
                    List<List<string>> tgtTkns = new List<List<string>>(segSize);
                    int sLenInBatch = 0;
                    int tLenInBatch = 0;
                    for (int j = segStart; j < segEnd; j++)
                    {
                        srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                        sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;

                        tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                        tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
                    }
                    segStart = segEnd;

                    float lcost = 0.0f;
                    // Create a new computing graph instance
                    using (IComputeGraph computeGraph_i = CreateComputGraph(i))
                    {
                        // Run forward part
                        lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);
                        // Run backward part and compute gradients
                        computeGraph_i.Backward();
                    }

                    // The totals are shared by all worker threads, so update them under the lock.
                    lock (locker)
                    {
                        cost          += lcost;
                        srcWordCnts   += sLenInBatch;
                        tgtWordCnts   += tLenInBatch;
                        processedLine += segSize;
                    }
                }
            });

            return (cost, srcWordCnts, tgtWordCnts, processedLine);
        }
        /// <summary>
        /// Trains the model for one pass over <c>TrainCorpus</c>. Sentence pairs are collected
        /// until <c>TrainCorpus.BatchSize</c> of them are available; each full batch is then
        /// split across the devices in <c>m_deviceIds</c>, run forward and backward in parallel,
        /// and followed by a gradient sync and a parameter update. The model is saved
        /// periodically and once more when the corpus is exhausted.
        /// </summary>
        /// <param name="ep">Epoch number; used in log messages and in the progress event.</param>
        /// <param name="learningRate">Base learning rate handed to <c>UpdateParameters</c>.</param>
        /// <remarks>
        /// Sentence pairs left over at the end of the corpus that do not fill a whole batch are
        /// not trained on. Each device receives <c>m_batchSize</c> sentences starting at
        /// <c>i * m_batchSize</c>, so <c>TrainCorpus.BatchSize</c> is assumed to be at least
        /// <c>m_batchSize * m_deviceIds.Length</c> (not checked here).
        /// </remarks>
        private void TrainEp(int ep, float learningRate)
        {
            int      processedLine = 0;
            DateTime startDateTime = DateTime.Now;

            double        costInTotal           = 0.0;
            long          srcWordCnts           = 0;
            long          tgtWordCnts           = 0;
            double        avgCostPerWordInTotal = 0.0;
            List<SntPair> sntPairs              = new List<SntPair>();

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Base learning rate is '{learningRate}' at epoch '{ep}'");

            // Clean caches of parameter optimization
            Logger.WriteLine($"Cleaning cache of weights optmiazation.'");
            CleanWeightCache();

            Logger.WriteLine($"Start to process training corpus.");
            foreach (var sntPair in TrainCorpus)
            {
                sntPairs.Add(sntPair);

                if (sntPairs.Count == TrainCorpus.BatchSize)
                {
                    List<List<string>> srcSnts = new List<List<string>>();
                    List<List<string>> tgtSnts = new List<List<string>>();

                    var slen = 0;
                    var tlen = 0;
                    for (int j = 0; j < TrainCorpus.BatchSize; j++)
                    {
                        List<string> srcSnt = new List<string>();

                        // Add BOS and EOS tags to source sentences
                        srcSnt.Add(m_START);
                        srcSnt.AddRange(sntPairs[j].SrcSnt);
                        srcSnt.Add(m_END);

                        srcSnts.Add(srcSnt);
                        tgtSnts.Add(sntPairs[j].TgtSnt.ToList());

                        slen += srcSnt.Count;
                        tlen += sntPairs[j].TgtSnt.Length;
                    }
                    srcWordCnts += slen;
                    tgtWordCnts += tlen;

                    Reset();

                    // Copy weights from weights kept in default device to all other devices
                    SyncWeights();

                    float cost = 0.0f;
                    Parallel.For(0, m_deviceIds.Length, i =>
                    {
                        IComputeGraph computeGraph = CreateComputGraph(i);

                        // Bi-directional encoding input source sentences
                        IWeightMatrix encodedWeightMatrix = Encode(computeGraph, srcSnts.GetRange(i * m_batchSize, m_batchSize), m_biEncoder[i], m_srcEmbedding[i]);

                        // Generate output decoder sentences
                        List<List<string>> predictSentence;
                        float lcost = Decode(tgtSnts.GetRange(i * m_batchSize, m_batchSize), computeGraph, encodedWeightMatrix, m_decoder[i], m_decoderFFLayer[i],
                                             m_tgtEmbedding[i], out predictSentence);

                        // cost is shared by all worker threads
                        lock (locker)
                        {
                            cost += lcost;
                        }
                        // Calculate gradients
                        computeGraph.Backward();
                    });

                    // Sum up gradients in all devices, and keep the result in the default device for parameter optimization
                    SyncGradient();

                    // A batch only counts towards the running totals when its cost is a finite number.
                    bool batchCounted = !float.IsInfinity(cost) && !float.IsNaN(cost);
                    if (batchCounted)
                    {
                        processedLine        += TrainCorpus.BatchSize;
                        costInTotal          += cost;
                        avgCostPerWordInTotal = costInTotal / tgtWordCnts;
                    }
                    else
                    {
                        Logger.WriteLine($"Invalid cost value.");
                    }

                    // Optimize parameters
                    float avgAllLR = UpdateParameters(learningRate, TrainCorpus.BatchSize);

                    // Clear gradient over all devices
                    ClearGradient();

                    // processedLine does not advance for an invalid batch, so without the
                    // batchCounted guard the modulo checks below would fire again on the stale
                    // counter (including at 0) and re-report / re-save after a bad batch.
                    if (batchCounted && IterationDone != null && processedLine % (100 * TrainCorpus.BatchSize) == 0)
                    {
                        IterationDone(this, new CostEventArg()
                        {
                            AvgLearningRate           = avgAllLR,
                            CostPerWord               = cost / tlen,
                            avgCostInTotal            = avgCostPerWordInTotal,
                            Epoch                     = ep,
                            ProcessedSentencesInTotal = processedLine,
                            // NOTE(review): source words are counted twice here - presumably because
                            // the encoder is bi-directional; confirm this is intended.
                            ProcessedWordsInTotal     = srcWordCnts * 2 + tgtWordCnts,
                            StartDateTime             = startDateTime
                        });
                    }

                    // Save the model every 1000 counted batches
                    if (batchCounted && processedLine % (TrainCorpus.BatchSize * 1000) == 0)
                    {
                        Save();
                        TensorAllocator.FreeMemoryAllDevices();
                    }

                    sntPairs.Clear();
                }
            }

            Logger.WriteLine($"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish.");

            Save();
        }
        /// <summary>
        /// Runs one training epoch over <paramref name="trainCorpus"/>. Batches are gathered
        /// until there is one per device in <c>m_deviceIds</c>; that group is then run forward
        /// and backward in parallel, the gradients are summed on the default device, the
        /// weights are updated by <paramref name="solver"/> and the gradients are zeroed.
        /// A checkpoint is created whenever more than an hour has passed since the last one,
        /// and once more at the end of the epoch.
        /// </summary>
        /// <param name="ep">Epoch number; used in the log message and in the progress event.</param>
        /// <param name="trainCorpus">Corpus enumerated for training batches.</param>
        /// <param name="validCorpus">Corpus forwarded to <c>CreateCheckPoint</c>.</param>
        /// <param name="learningRate">Supplies the learning rate for each weight update.</param>
        /// <param name="solver">Optimizer that applies the weight update.</param>
        /// <param name="metrics">Metrics forwarded to <c>CreateCheckPoint</c>.</param>
        /// <param name="modelMetaData">Model meta data forwarded to <c>CreateCheckPoint</c>.</param>
        /// <param name="ForwardOnSingleDevice">
        /// Callback that runs the forward pass of one batch on one device and returns its cost.
        /// </param>
        /// <remarks>
        /// Batches remaining at the end of the corpus that do not make up a full group
        /// (one per device) are not trained on.
        /// </remarks>
        internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List <IMetric> metrics, IModelMetaData modelMetaData,
                                    Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice)
        {
            // Epoch-wide running totals.
            int      sentencesInEpoch = 0;
            DateTime epochStart       = DateTime.Now;
            DateTime lastCheckPoint   = DateTime.Now;
            double   epochCost        = 0.0;
            long     srcTokensInEpoch = 0;
            long     tgtTokensInEpoch = 0;
            double   avgCostPerToken  = 0.0;

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Start to process training corpus.");
            var pendingBatches = new List <SntPairBatch>();

            foreach (SntPairBatch incomingBatch in trainCorpus)
            {
                pendingBatches.Add(incomingBatch);
                if (pendingBatches.Count != m_deviceIds.Length)
                {
                    // Keep collecting until every device has a batch to work on.
                    continue;
                }

                // Totals for this single weight update.
                float stepCost      = 0.0f;
                int   stepTgtTokens = 0;
                int   stepSentences = 0;

                // Push the weights held on the default device out to every other device.
                CopyWeightsFromDefaultDeviceToAllOtherDevices();

                // One worker per device: forward pass, then backward pass.
                Parallel.For(0, m_deviceIds.Length, deviceIdx =>
                {
                    SntPairBatch deviceBatch = pendingBatches[deviceIdx];

                    // Collect the token lists for encoding and decoding.
                    var srcTokens   = new List <List <string> >();
                    var tgtTokens   = new List <List <string> >();
                    int srcTokenCnt = 0;
                    int tgtTokenCnt = 0;
                    for (int pairIdx = 0; pairIdx < deviceBatch.BatchSize; pairIdx++)
                    {
                        var pair = deviceBatch.SntPairs[pairIdx];

                        srcTokens.Add(pair.SrcSnt.ToList());
                        srcTokenCnt += pair.SrcSnt.Length;

                        tgtTokens.Add(pair.TgtSnt.ToList());
                        tgtTokenCnt += pair.TgtSnt.Length;
                    }

                    float deviceCost = 0.0f;
                    // Each worker gets a fresh compute graph that is disposed once the gradients are computed.
                    using (IComputeGraph graph = CreateComputGraph(deviceIdx))
                    {
                        deviceCost = ForwardOnSingleDevice(graph, srcTokens, tgtTokens, deviceIdx, true);
                        graph.Backward();
                    }

                    // The totals are shared between workers, so fold the results in under the lock.
                    lock (locker)
                    {
                        stepCost         += deviceCost;
                        stepTgtTokens    += tgtTokenCnt;
                        stepSentences    += deviceBatch.BatchSize;
                        srcTokensInEpoch += srcTokenCnt;
                        tgtTokensInEpoch += tgtTokenCnt;
                        sentencesInEpoch += deviceBatch.BatchSize;
                    }
                });

                // Sum the gradients of all devices into the tensors on the default device.
                SumGradientsToTensorsInDefaultDevice();

                // Update the weights.
                float currentLR = learningRate.GetCurrentLearningRate();
                List <IWeightTensor> parameters = GetParametersFromDefaultDevice();
                solver.UpdateWeights(parameters, stepSentences, currentLR, m_regc, m_weightsUpdateCount + 1);

                // Reset the gradients on every device for the next step.
                ZeroGradientOnAllDevices();

                epochCost      += stepCost;
                avgCostPerToken = epochCost / tgtTokensInEpoch;
                m_weightsUpdateCount++;

                // Report progress every 100 weight updates.
                if (m_weightsUpdateCount % 100 == 0 && IterationDone != null)
                {
                    IterationDone(this, new CostEventArg()
                    {
                        LearningRate              = currentLR,
                        CostPerWord               = stepCost / stepTgtTokens,
                        AvgCostInTotal            = avgCostPerToken,
                        Epoch                     = ep,
                        Update                    = m_weightsUpdateCount,
                        ProcessedSentencesInTotal = sentencesInEpoch,
                        ProcessedWordsInTotal     = srcTokensInEpoch + tgtTokensInEpoch,
                        StartDateTime             = epochStart
                    });
                }

                // Create a checkpoint once more than an hour has passed since the previous one.
                if ((DateTime.Now - lastCheckPoint).TotalHours > 1.0)
                {
                    CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerToken);
                    lastCheckPoint = DateTime.Now;
                }

                pendingBatches.Clear();
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - epochStart}' time to finish. AvgCost = {avgCostPerToken.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerToken);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerToken;
        }