Ejemplo n.º 1
0
        /// <summary>
        /// Trains the model for one epoch. Mini-batches from <paramref name="trainCorpus"/> are grouped so that every
        /// device in <c>m_deviceIds</c> receives one batch, and each complete group produces one weight update.
        /// A trailing group with fewer batches than devices is not trained on.
        /// </summary>
        /// <param name="ep">Epoch index, reported through <c>IterationDone</c> and the end-of-epoch log line.</param>
        /// <param name="trainCorpus">Training mini-batches, enumerated once.</param>
        /// <param name="validCorpus">Validation corpus, passed to <c>CreateCheckPoint</c>.</param>
        /// <param name="learningRate">Supplies the learning rate used for each update.</param>
        /// <param name="solver">Optimizer that applies the gradients summed on the default device.</param>
        /// <param name="metrics">Metrics passed to <c>CreateCheckPoint</c>.</param>
        /// <param name="modelMetaData">Model metadata passed to <c>CreateCheckPoint</c>.</param>
        /// <param name="ForwardOnSingleDevice">Per-device forward function, forwarded to <c>RunNetwork</c> and <c>CreateCheckPoint</c>.</param>
        /// <remarks>
        /// Error handling per group:
        /// <list type="bullet">
        /// <item>An <see cref="OutOfMemoryException"/> (thrown directly or found inside a possibly nested
        /// <see cref="AggregateException"/>) doubles the batch split factor passed to <c>RunNetwork</c> and retries.
        /// Once the factor reaches the first batch's <c>BatchSize</c>, the group is skipped.</item>
        /// <item>An <see cref="ArithmeticException"/> skips the group.</item>
        /// <item>Any other exception is logged and rethrown with its original stack trace.</item>
        /// </list>
        /// A checkpoint is attempted when more than one hour has passed since <c>m_lastCheckPointDateTime</c>.
        /// </remarks>
        internal void TrainOneEpoch(int ep, IEnumerable<SntPairBatch> trainCorpus, IEnumerable<SntPairBatch> validCorpus, ILearningRate learningRate, AdamOptimizer solver, List<IMetric> metrics, IModelMetaData modelMetaData,
            Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice)
        {
            int processedLineInTotal = 0;
            DateTime startDateTime = DateTime.Now;
            double costInTotal = 0.0;
            long srcWordCntsInTotal = 0;
            long tgtWordCntsInTotal = 0;
            double avgCostPerWordInTotal = 0.0;

            Logger.WriteLine($"Start to process training corpus.");
            List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();

            foreach (SntPairBatch sntPairBatch in trainCorpus)
            {
                sntPairBatchs.Add(sntPairBatch);
                if (sntPairBatchs.Count == m_deviceIds.Length)
                {
                    // Copy weights from weights kept in default device to all other devices
                    CopyWeightsFromDefaultDeviceToAllOtherDevices();

                    int batchSplitFactor = 1;
                    bool runNetworkSucceeded = false;

                    while (!runNetworkSucceeded)
                    {
                        try
                        {
                            (float cost, int sWordCnt, int tWordCnt, int processedLine) = RunNetwork(ForwardOnSingleDevice, sntPairBatchs, batchSplitFactor);

                            // Sum up gradients in all devices, and keep them in the default device for parameter optimization
                            SumGradientsToTensorsInDefaultDevice();

                            // Optimize parameters
                            float lr = learningRate.GetCurrentLearningRate();
                            List<IWeightTensor> models = GetParametersFromDefaultDevice();
                            solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

                            // Statistics are accumulated only once the update has succeeded. Otherwise a failure
                            // in the gradient-sum or update step, followed by a retry or a skip, would count this
                            // group's lines/words twice, or without its cost.
                            processedLineInTotal += processedLine;
                            srcWordCntsInTotal += sWordCnt;
                            tgtWordCntsInTotal += tWordCnt;
                            costInTotal += cost;
                            avgCostPerWordInTotal = costInTotal / tgtWordCntsInTotal;
                            m_weightsUpdateCount++;

                            if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                            {
                                IterationDone(this, new CostEventArg()
                                {
                                    LearningRate = lr,
                                    CostPerWord = cost / tWordCnt,
                                    AvgCostInTotal = avgCostPerWordInTotal,
                                    Epoch = ep,
                                    Update = m_weightsUpdateCount,
                                    ProcessedSentencesInTotal = processedLineInTotal,
                                    ProcessedWordsInTotal = srcWordCntsInTotal + tgtWordCntsInTotal,
                                    StartDateTime = startDateTime
                                });
                            }

                            runNetworkSucceeded = true;
                        }
                        catch (Exception err)
                        {
                            Exception recoverable = FindRecoverableException(err);

                            if (recoverable is OutOfMemoryException)
                            {
                                batchSplitFactor *= 2;
                                Logger.WriteLine($"Got an exception ('{recoverable.Message}'), so we increase batch split factor to {batchSplitFactor}, and retry it.");

                                if (batchSplitFactor >= sntPairBatchs[0].BatchSize)
                                {
                                    Logger.WriteLine($"Batch split factor is larger than batch size, so ignore current mini-batch.");
                                    break;
                                }
                            }
                            else if (recoverable is ArithmeticException)
                            {
                                // Log the underlying exception's message rather than an AggregateException's generic one.
                                Logger.WriteLine($"Arithmetic exception: '{recoverable.Message}'");
                                break;
                            }
                            else
                            {
                                Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
                                // 'throw;' preserves the original stack trace ('throw err;' would reset it).
                                throw;
                            }
                        }
                    }

                    // Evaluate model every hour and save it if we could get a better one.
                    TimeSpan ts = DateTime.Now - m_lastCheckPointDateTime;
                    if (ts.TotalHours > 1.0)
                    {
                        CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                        m_lastCheckPointDateTime = DateTime.Now;
                    }

                    sntPairBatchs.Clear();
                }
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            // This overload creates no end-of-epoch checkpoint; checkpoints come only from the hourly check above.
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;

            // Returns the exception the retry logic knows how to handle (out-of-memory or arithmetic): either 'err'
            // itself, or the first such exception inside a (flattened, possibly nested) AggregateException.
            // Returns null for anything else.
            Exception FindRecoverableException(Exception err)
            {
                if (err is AggregateException aggregate)
                {
                    foreach (Exception inner in aggregate.Flatten().InnerExceptions)
                    {
                        if (inner is OutOfMemoryException || inner is ArithmeticException)
                        {
                            return inner;
                        }
                    }

                    return null;
                }

                return (err is OutOfMemoryException || err is ArithmeticException) ? err : null;
            }
        }
        /// <summary>
        /// Trains the model for one epoch over a <c>ParallelCorpus</c>. Mini-batches are collected until there is
        /// one per device in <c>m_deviceIds</c>. Forward and backward passes then run on all devices in parallel,
        /// gradients are summed onto the default device, and a single optimizer step is applied.
        /// A trailing group with fewer batches than devices is not trained on.
        /// </summary>
        /// <param name="ep">Epoch index, reported through <c>IterationDone</c> and the end-of-epoch log line.</param>
        /// <param name="trainCorpus">Corpus enumerated as <c>SntPairBatch</c> items.</param>
        /// <param name="validCorpus">Validation corpus, passed to <c>CreateCheckPoint</c>.</param>
        /// <param name="learningRate">Supplies the learning rate used for each update.</param>
        /// <param name="solver">Optimizer applying the summed gradients.</param>
        /// <param name="metrics">Metrics passed to <c>CreateCheckPoint</c>.</param>
        /// <param name="modelMetaData">Model metadata passed to <c>CreateCheckPoint</c>.</param>
        /// <param name="ForwardOnSingleDevice">Forward function run on each device's compute graph.</param>
        /// <remarks>
        /// Exceptions thrown inside the parallel section are not caught here. A checkpoint is attempted whenever
        /// more than an hour has passed since the epoch started or since the previous checkpoint in this epoch,
        /// and one more checkpoint is always created after the epoch.
        /// </remarks>
        internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List <IMetric> metrics, IModelMetaData modelMetaData,
                                    Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice)
        {
            int linesSeen = 0;
            DateTime epochStart = DateTime.Now;
            DateTime lastCheckPoint = DateTime.Now;
            double accumulatedCost = 0.0;
            long srcTokenTotal = 0;
            long tgtTokenTotal = 0;
            double avgCostPerWord = 0.0;

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Start to process training corpus.");
            var pending = new List<SntPairBatch>();

            foreach (SntPairBatch batch in trainCorpus)
            {
                pending.Add(batch);

                // Wait until every device has a mini-batch to work on.
                if (pending.Count != m_deviceIds.Length)
                {
                    continue;
                }

                float stepCost = 0.0f;
                int stepTgtTokens = 0;
                int stepLines = 0;

                // Broadcast the weights held on the default device to every other device.
                CopyWeightsFromDefaultDeviceToAllOtherDevices();

                // Forward + backward on each device concurrently; shared counters are updated under 'locker'.
                Parallel.For(0, m_deviceIds.Length, deviceIdx =>
                {
                    SntPairBatch deviceBatch = pending[deviceIdx];

                    var srcSentences = new List<List<string>>();
                    var tgtSentences = new List<List<string>>();
                    int srcTokens = 0;
                    int tgtTokens = 0;

                    for (int k = 0; k < deviceBatch.BatchSize; k++)
                    {
                        var pair = deviceBatch.SntPairs[k];

                        srcSentences.Add(pair.SrcSnt.ToList());
                        srcTokens += pair.SrcSnt.Length;

                        tgtSentences.Add(pair.TgtSnt.ToList());
                        tgtTokens += pair.TgtSnt.Length;
                    }

                    float deviceCost;
                    using (IComputeGraph graph = CreateComputGraph(deviceIdx))
                    {
                        deviceCost = ForwardOnSingleDevice(graph, srcSentences, tgtSentences, deviceIdx, true);
                        graph.Backward();
                    }

                    lock (locker)
                    {
                        stepCost += deviceCost;
                        srcTokenTotal += srcTokens;
                        tgtTokenTotal += tgtTokens;
                        stepTgtTokens += tgtTokens;
                        linesSeen += deviceBatch.BatchSize;
                        stepLines += deviceBatch.BatchSize;
                    }
                });

                // Reduce gradients onto the default device, then take one optimizer step.
                SumGradientsToTensorsInDefaultDevice();

                float lr = learningRate.GetCurrentLearningRate();
                List<IWeightTensor> parameters = GetParametersFromDefaultDevice();
                solver.UpdateWeights(parameters, stepLines, lr, m_regc, m_weightsUpdateCount + 1);

                // Reset gradients on every device before the next group.
                ZeroGradientOnAllDevices();

                accumulatedCost += stepCost;
                avgCostPerWord = accumulatedCost / tgtTokenTotal;
                m_weightsUpdateCount++;

                if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                {
                    IterationDone(this, new CostEventArg()
                    {
                        LearningRate = lr,
                        CostPerWord = stepCost / stepTgtTokens,
                        AvgCostInTotal = avgCostPerWord,
                        Epoch = ep,
                        Update = m_weightsUpdateCount,
                        ProcessedSentencesInTotal = linesSeen,
                        ProcessedWordsInTotal = srcTokenTotal + tgtTokenTotal,
                        StartDateTime = epochStart
                    });
                }

                // Hourly validation; CreateCheckPoint decides whether the model is worth saving.
                if ((DateTime.Now - lastCheckPoint).TotalHours > 1.0)
                {
                    CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWord);
                    lastCheckPoint = DateTime.Now;
                }

                pending.Clear();
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - epochStart}' time to finish. AvgCost = {avgCostPerWord.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWord);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWord;
        }