/// <summary>
        /// Evaluate the quality of model on valid corpus.
        /// Batches are evaluated in groups of up to DeviceIds.Length; batch i of a group runs on device index i.
        /// If any exception is thrown during evaluation, the metrics and collected sentences are reset, the number of
        /// segments each batch is split into (batch split factor) is doubled, and the whole corpus is evaluated again.
        /// </summary>
        /// <param name="validCorpus">valid corpus to measure the quality of model</param>
        /// <param name="RunNetwork">The network to run on specific device. It receives the compute graph, the source tokens,
        /// the hypothesis token lists (each seeded with BOS) which the network is expected to fill in, the device index and
        /// a training flag (always false here)</param>
        /// <param name="metrics">A set of metrics. The first one is the primary metric</param>
        /// <param name="outputToFile">It indicates if valid corpus and results should be dumped to files</param>
        /// <returns>true if we get a better result on primary metric, otherwise, false</returns>
        /// <exception cref="Exception">The last evaluation error is rethrown once the batch split factor reaches its limit.</exception>
        internal bool RunValid(ParallelCorpus validCorpus, Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> RunNetwork, List <IMetric> metrics, bool outputToFile = false)
        {
            // Once the split factor reaches this value, evaluation is abandoned and the last error is rethrown.
            const int maxBatchSplitFactor = 512;

            List <string> srcSents            = null;
            List <string> refSents            = null;
            List <string> hypSents            = null;
            int           batchSplitFactor    = 1;
            bool          evaluationSucceeded = false;

            // Evaluates the given batches in parallel (batch i on device index i). Every batch is cut into at most
            // splitFactor segments that are run one after another, and the results are accumulated into the metrics
            // and, when requested, into the output sentence lists under the shared lock.
            void EvaluateBatchGroup(List <SntPairBatch> batchGroup, int splitFactor)
            {
                Parallel.For(0, batchGroup.Count, i =>
                {
                    SntPairBatch sntPairBatch = batchGroup[i];
                    int          batchSize    = sntPairBatch.BatchSize;

                    // Ceiling division, so the segments together cover every sentence of the batch even when batchSize
                    // is not a multiple of splitFactor (a floor division skipped the trailing sentences, and evaluated
                    // nothing at all once splitFactor exceeded batchSize).
                    int segSize = Math.Max(1, (batchSize + splitFactor - 1) / splitFactor);

                    for (int segStart = 0; segStart < batchSize; segStart += segSize)
                    {
                        int segEnd = Math.Min(segStart + segSize, batchSize);

                        // Construct sentences for encoding and decoding
                        List <List <string> > srcTkns = new List <List <string> >();
                        List <List <string> > refTkns = new List <List <string> >();
                        List <List <string> > hypTkns = new List <List <string> >();
                        for (int j = segStart; j < segEnd; j++)
                        {
                            srcTkns.Add(sntPairBatch.SntPairs[j].SrcSnt.ToList());
                            refTkns.Add(sntPairBatch.SntPairs[j].TgtSnt.ToList());
                            hypTkns.Add(new List <string>()
                            {
                                ParallelCorpus.BOS
                            });
                        }

                        // Create a new computing graph instance and run the forward part.
                        // NOTE(review): the device *index* i is passed, the same way training calls CreateComputGraph(i)
                        // and the forward delegate; DeviceIds[i] only matched it when the configured ids were 0..n-1.
                        // Confirm against CreateComputGraph.
                        using (IComputeGraph computeGraph = CreateComputGraph(i, needBack: false))
                        {
                            RunNetwork(computeGraph, srcTkns, hypTkns, i, false);
                        }

                        // Metrics and output lists are shared between the parallel workers.
                        lock (locker)
                        {
                            for (int j = 0; j < hypTkns.Count; j++)
                            {
                                foreach (IMetric evalMetric in metrics)
                                {
                                    evalMetric.Evaluate(new List <List <string> >()
                                    {
                                        refTkns[j]
                                    }, hypTkns[j]);
                                }
                            }

                            if (outputToFile)
                            {
                                for (int j = 0; j < srcTkns.Count; j++)
                                {
                                    srcSents.Add(string.Join(" ", srcTkns[j]));
                                    refSents.Add(string.Join(" ", refTkns[j]));
                                    hypSents.Add(string.Join(" ", hypTkns[j]));
                                }
                            }
                        }
                    }
                });
            }

            while (evaluationSucceeded == false)
            {
                try
                {
                    if (batchSplitFactor == 1)
                    {
                        Logger.WriteLine(Logger.Level.info, ConsoleColor.Gray, $"Start to evaluate model...");
                    }
                    else
                    {
                        Logger.WriteLine(Logger.Level.info, ConsoleColor.Gray, $"Retry to evaluate model. Batch split factor = '{batchSplitFactor}'");
                    }

                    // Start every attempt from scratch so results of a failed attempt are discarded.
                    srcSents = new List <string>();
                    refSents = new List <string>();
                    hypSents = new List <string>();

                    // Clear inner status of each metrics
                    foreach (IMetric metric in metrics)
                    {
                        metric.ClearStatus();
                    }

                    List <SntPairBatch> sntPairBatchs = new List <SntPairBatch>();
                    foreach (SntPairBatch item in validCorpus)
                    {
                        sntPairBatchs.Add(item);
                        if (sntPairBatchs.Count == DeviceIds.Length)
                        {
                            EvaluateBatchGroup(sntPairBatchs, batchSplitFactor);
                            sntPairBatchs.Clear();
                        }
                    }

                    // The corpus may not hold a multiple of DeviceIds.Length batches; evaluate the leftovers as well,
                    // otherwise they would silently be excluded from the metrics.
                    if (sntPairBatchs.Count > 0)
                    {
                        EvaluateBatchGroup(sntPairBatchs, batchSplitFactor);
                        sntPairBatchs.Clear();
                    }

                    evaluationSucceeded = true;
                }
                catch (Exception err)
                {
                    Logger.WriteLine($"Failed to evaluate model with batch split factor '{batchSplitFactor}'. Error = '{err.Message}'");

                    batchSplitFactor *= 2;
                    if (batchSplitFactor >= maxBatchSplitFactor)
                    {
                        Logger.WriteLine($"Batch split factor reached the limit '{maxBatchSplitFactor}', give it up.");
                        // Rethrow without resetting the original stack trace.
                        throw;
                    }
                }
            }

            Logger.WriteLine($"Metrics result:");
            foreach (IMetric metric in metrics)
            {
                Logger.WriteLine(Logger.Level.info, ConsoleColor.DarkGreen, $"{metric.Name} = {metric.GetScoreStr()}");
            }

            if (outputToFile)
            {
                File.WriteAllLines("valid_src.txt", srcSents);
                File.WriteAllLines("valid_ref.txt", refSents);
                File.WriteAllLines("valid_hyp.txt", hypSents);
            }

            if (metrics.Count > 0)
            {
                if (metrics[0].GetPrimaryScore() > m_bestPrimaryScore)
                {
                    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"We got a better score '{metrics[0].GetPrimaryScore().ToString("F")}' on primary metric '{metrics[0].Name}'. The previous score is '{m_bestPrimaryScore.ToString("F")}'");
                    //We have a better primary score on valid set
                    m_bestPrimaryScore = metrics[0].GetPrimaryScore();
                    return(true);
                }
            }

            return(false);
        }
        /// <summary>
        /// Trains the model for one pass over <paramref name="trainCorpus"/>.
        /// Batches are grouped so that batch i of a group runs forward and backward on device index i. The gradients of all
        /// devices are then summed into the default device and applied by <paramref name="solver"/> as one weight update.
        /// A checkpoint is created when more than an hour has passed since the last one, and once more at the end of the epoch.
        /// </summary>
        /// <param name="ep">Index of the current epoch, reported through IterationDone and the epoch summary log.</param>
        /// <param name="trainCorpus">Corpus providing the training batches.</param>
        /// <param name="validCorpus">Corpus passed to checkpoint creation.</param>
        /// <param name="learningRate">Provides the learning rate used for every weight update.</param>
        /// <param name="solver">Optimizer updating the weights kept in the default device.</param>
        /// <param name="metrics">Metrics passed to checkpoint creation.</param>
        /// <param name="modelMetaData">Model metadata passed to checkpoint creation.</param>
        /// <param name="ForwardOnSingleDevice">Runs the forward part on one device and returns its cost.</param>
        internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List <IMetric> metrics, IModelMetaData modelMetaData,
                                    Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice)
        {
            int      processedLineInTotal   = 0;
            DateTime startDateTime          = DateTime.Now;
            DateTime lastCheckPointDateTime = DateTime.Now;
            double   costInTotal            = 0.0;
            long     srcWordCnts            = 0;
            long     tgtWordCnts            = 0;
            double   avgCostPerWordInTotal  = 0.0;

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Start to process training corpus.");
            List <SntPairBatch> sntPairBatchs = new List <SntPairBatch>();

            // Runs one training step on the batches currently held in sntPairBatchs (batch i on device index i) and
            // updates the epoch statistics. It works for full groups and for a final group with fewer batches than devices.
            void TrainOnBatchGroup()
            {
                float cost          = 0.0f;
                int   tlen          = 0;
                int   processedLine = 0;

                // Copy weights from weights kept in default device to all other devices
                CopyWeightsFromDefaultDeviceToAllOtherDevices();

                // Run forward and backward on the processors that received a batch
                Parallel.For(0, sntPairBatchs.Count, i =>
                {
                    SntPairBatch sntPairBatch_i = sntPairBatchs[i];
                    // Construct sentences for encoding and decoding
                    List <List <string> > srcTkns = new List <List <string> >();
                    List <List <string> > tgtTkns = new List <List <string> >();
                    int sLenInBatch = 0;
                    int tLenInBatch = 0;
                    for (int j = 0; j < sntPairBatch_i.BatchSize; j++)
                    {
                        srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                        sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;

                        tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                        tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
                    }

                    float lcost = 0.0f;
                    // Create a new computing graph instance
                    using (IComputeGraph computeGraph_i = CreateComputGraph(i))
                    {
                        // Run forward part
                        lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);
                        // Run backward part and compute gradients
                        computeGraph_i.Backward();
                    }

                    // Statistics are shared between the parallel workers.
                    lock (locker)
                    {
                        cost                 += lcost;
                        srcWordCnts          += sLenInBatch;
                        tgtWordCnts          += tLenInBatch;
                        tlen                 += tLenInBatch;
                        processedLineInTotal += sntPairBatch_i.BatchSize;
                        processedLine        += sntPairBatch_i.BatchSize;
                    }
                });

                //Sum up gradients in all devices, and kept it in default device for parameters optmization
                SumGradientsToTensorsInDefaultDevice();

                //Optmize parameters
                float lr = learningRate.GetCurrentLearningRate();
                List <IWeightTensor> models = GetParametersFromDefaultDevice();
                solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

                //Clear gradient over all devices
                ZeroGradientOnAllDevices();

                costInTotal += cost;
                // Guard the per-word averages so an all-empty target side reports 0 instead of NaN/Infinity.
                avgCostPerWordInTotal = tgtWordCnts > 0 ? costInTotal / tgtWordCnts : 0.0;
                m_weightsUpdateCount++;
                if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                {
                    IterationDone(this, new CostEventArg()
                    {
                        LearningRate              = lr,
                        CostPerWord               = tlen > 0 ? cost / tlen : 0.0f,
                        AvgCostInTotal            = avgCostPerWordInTotal,
                        Epoch                     = ep,
                        Update                    = m_weightsUpdateCount,
                        ProcessedSentencesInTotal = processedLineInTotal,
                        ProcessedWordsInTotal     = srcWordCnts + tgtWordCnts,
                        StartDateTime             = startDateTime
                    });
                }
            }

            foreach (SntPairBatch sntPairBatch in trainCorpus)
            {
                sntPairBatchs.Add(sntPairBatch);
                if (sntPairBatchs.Count == m_deviceIds.Length)
                {
                    TrainOnBatchGroup();

                    // Evaluate model every hour and save it if we could get a better one.
                    TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
                    if (ts.TotalHours > 1.0)
                    {
                        CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                        lastCheckPointDateTime = DateTime.Now;
                    }

                    sntPairBatchs.Clear();
                }
            }

            // Train on the last, incomplete group as well instead of silently dropping those batches every epoch.
            if (sntPairBatchs.Count > 0)
            {
                TrainOnBatchGroup();
                sntPairBatchs.Clear();
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
        }
        /// <summary>
        /// Trains the model for one pass over <paramref name="trainCorpus"/>.
        /// Batches are grouped by the number of devices. Each group is run through RunNetwork, the gradients are summed
        /// into the default device, and <paramref name="solver"/> applies them as one weight update. If that step throws,
        /// it is retried with a doubled batch split factor until the factor reaches the size of the group's first batch.
        /// A checkpoint is created when more than an hour has passed since the last one, and once more at the end of the epoch.
        /// </summary>
        /// <param name="ep">Index of the current epoch, reported through IterationDone and the epoch summary log.</param>
        /// <param name="trainCorpus">Corpus providing the training batches.</param>
        /// <param name="validCorpus">Corpus passed to checkpoint creation.</param>
        /// <param name="learningRate">Provides the learning rate used for every weight update.</param>
        /// <param name="solver">Optimizer updating the weights kept in the default device.</param>
        /// <param name="metrics">Metrics passed to checkpoint creation.</param>
        /// <param name="modelMetaData">Model metadata passed to checkpoint creation.</param>
        /// <param name="ForwardOnSingleDevice">Forward function handed to RunNetwork and to checkpoint creation.</param>
        internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List <IMetric> metrics, IModelMetaData modelMetaData,
                                    Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice)
        {
            int      processedLineInTotal   = 0;
            DateTime startDateTime          = DateTime.Now;
            DateTime lastCheckPointDateTime = DateTime.Now;
            double   costInTotal            = 0.0;
            long     srcWordCntsInTotal     = 0;
            long     tgtWordCntsInTotal     = 0;
            double   avgCostPerWordInTotal  = 0.0;

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Start to process training corpus.");
            List <SntPairBatch> sntPairBatchs = new List <SntPairBatch>();

            foreach (SntPairBatch sntPairBatch in trainCorpus)
            {
                sntPairBatchs.Add(sntPairBatch);
                if (sntPairBatchs.Count == m_deviceIds.Length)
                {
                    // Copy weights from weights kept in default device to all other devices
                    CopyWeightsFromDefaultDeviceToAllOtherDevices();

                    int   batchSplitFactor = 1;
                    bool  updateSucceeded  = false;
                    float cost             = 0.0f;
                    int   sWordCnt         = 0;
                    int   tWordCnt         = 0;
                    int   processedLine    = 0;
                    float lr               = 0.0f;

                    // Only the network run and the weight update are retried. The statistics and the IterationDone
                    // callback are handled once, after the update succeeded. Previously a failure there (e.g. a throwing
                    // event handler) re-ran the whole group, which applied the weight update and counted the statistics twice.
                    while (updateSucceeded == false)
                    {
                        try
                        {
                            (cost, sWordCnt, tWordCnt, processedLine) = RunNetwork(ForwardOnSingleDevice, sntPairBatchs, batchSplitFactor);

                            //Sum up gradients in all devices, and kept it in default device for parameters optmization
                            SumGradientsToTensorsInDefaultDevice();

                            //Optmize parameters
                            lr = learningRate.GetCurrentLearningRate();
                            List <IWeightTensor> models = GetParametersFromDefaultDevice();
                            solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

                            updateSucceeded = true;
                        }
                        catch (Exception err)
                        {
                            Logger.WriteLine($"Exception: '{err.Message}'");
                            batchSplitFactor *= 2;
                            Logger.WriteLine($"Increase batch split factor to {batchSplitFactor}, and retry it.");

                            if (batchSplitFactor >= sntPairBatchs[0].BatchSize)
                            {
                                Logger.WriteLine($"Batch split factor is larger than batch size, give it up.");
                                // Rethrow without resetting the original stack trace.
                                throw;
                            }
                        }
                    }

                    processedLineInTotal += processedLine;
                    srcWordCntsInTotal   += sWordCnt;
                    tgtWordCntsInTotal   += tWordCnt;
                    costInTotal          += cost;
                    // Guard the per-word averages so an all-empty target side reports 0 instead of NaN/Infinity.
                    avgCostPerWordInTotal = tgtWordCntsInTotal > 0 ? costInTotal / tgtWordCntsInTotal : 0.0;
                    m_weightsUpdateCount++;
                    if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                    {
                        IterationDone(this, new CostEventArg()
                        {
                            LearningRate              = lr,
                            CostPerWord               = tWordCnt > 0 ? cost / tWordCnt : 0.0f,
                            AvgCostInTotal            = avgCostPerWordInTotal,
                            Epoch                     = ep,
                            Update                    = m_weightsUpdateCount,
                            ProcessedSentencesInTotal = processedLineInTotal,
                            ProcessedWordsInTotal     = srcWordCntsInTotal + tgtWordCntsInTotal,
                            StartDateTime             = startDateTime
                        });
                    }

                    // Evaluate model every hour and save it if we could get a better one.
                    TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
                    if (ts.TotalHours > 1.0)
                    {
                        CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                        lastCheckPointDateTime = DateTime.Now;
                    }

                    sntPairBatchs.Clear();
                }
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
        }
Example #4
0
        /// <summary>
        /// Evaluate the quality of model on valid corpus.
        /// All batches are run one after another on the first device (device index 0) with needBack disabled.
        /// </summary>
        /// <param name="validCorpus">valid corpus to measure the quality of model</param>
        /// <param name="RunNetwork">The network to run on specific device. It receives the compute graph, the source tokens,
        /// an empty hypothesis list which the network is expected to fill in, the device index and a training flag (always false here)</param>
        /// <param name="metrics">A set of metrics. The first one is the primary metric</param>
        /// <param name="outputToFile">It indicates if valid corpus and results should be dumped to files</param>
        /// <returns>true if we get a better result on primary metric, otherwise, false</returns>
        internal bool RunValid(ParallelCorpus validCorpus, Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> RunNetwork, List <IMetric> metrics, bool outputToFile = false)
        {
            // Position of the evaluation device in DeviceIds.
            const int evalDeviceIdx = 0;

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Gray, $"Start to Evaluate model...");

            List <string> srcSents = new List <string>();
            List <string> refSents = new List <string>();
            List <string> hypSents = new List <string>();

            // Clear inner status of each metrics
            foreach (var metric in metrics)
            {
                metric.ClearStatus();
            }

            foreach (var sntPairBatch in validCorpus)
            {
                // Construct sentences for encoding and decoding
                List <List <string> > srcTkns = new List <List <string> >();
                List <List <string> > refTkns = new List <List <string> >();
                for (int j = 0; j < sntPairBatch.BatchSize; j++)
                {
                    srcTkns.Add(sntPairBatch.SntPairs[j].SrcSnt.ToList());
                    refTkns.Add(sntPairBatch.SntPairs[j].TgtSnt.ToList());
                }

                List <List <string> > hypTkns = new List <List <string> >();

                // Create a new computing graph instance and run the forward part.
                // NOTE(review): a device *index* is passed here, the same way training calls CreateComputGraph(i) and the
                // forward delegate with the index i. DeviceIds[0] only matched it when the first configured id was 0.
                // Confirm against CreateComputGraph.
                using (IComputeGraph computeGraph = CreateComputGraph(evalDeviceIdx, needBack: false))
                {
                    RunNetwork(computeGraph, srcTkns, hypTkns, evalDeviceIdx, false);
                }

                for (int i = 0; i < hypTkns.Count; i++)
                {
                    foreach (var metric in metrics)
                    {
                        metric.Evaluate(new List <List <string> >()
                        {
                            refTkns[i]
                        }, hypTkns[i]);
                    }
                }

                if (outputToFile)
                {
                    for (int j = 0; j < sntPairBatch.BatchSize; j++)
                    {
                        srcSents.Add(String.Join(" ", srcTkns[j]));
                        refSents.Add(String.Join(" ", refTkns[j]));
                        hypSents.Add(String.Join(" ", hypTkns[j]));
                    }
                }
            }

            Logger.WriteLine($"Metrics result:");
            foreach (IMetric metric in metrics)
            {
                Logger.WriteLine(Logger.Level.info, ConsoleColor.DarkGreen, $"{metric.Name} = {metric.GetScoreStr()}");
            }

            if (outputToFile)
            {
                File.WriteAllLines("valid_src.txt", srcSents);
                File.WriteAllLines("valid_ref.txt", refSents);
                File.WriteAllLines("valid_hyp.txt", hypSents);
            }

            if (metrics.Count > 0)
            {
                if (metrics[0].GetPrimaryScore() > m_bestPrimaryScore)
                {
                    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"We got a better score '{metrics[0].GetPrimaryScore().ToString("F")}' on primary metric '{metrics[0].Name}'. The previous score is '{m_bestPrimaryScore.ToString("F")}'");
                    //We have a better primary score on valid set
                    m_bestPrimaryScore = metrics[0].GetPrimaryScore();
                    return(true);
                }
            }

            return(false);
        }