private (float, int, int, int) RunNetwork(Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice, List<SntPairBatch> sntPairBatchs, int batchSplitFactor)
{
    float cost = 0.0f;
    int processedLine = 0;
    int srcWordCnts = 0;
    int tgtWordCnts = 0;

    // Clear gradients over all devices
    ZeroGradientOnAllDevices();

    // Run forward and backward passes on all available devices in parallel
    Parallel.For(0, m_deviceIds.Length, i =>
    {
        SntPairBatch sntPairBatch_i = sntPairBatchs[i];
        // Note: BatchSize is assumed to be evenly divisible by batchSplitFactor;
        // integer division here would silently drop any remainder sentences.
        int batchSegSize = sntPairBatch_i.BatchSize / batchSplitFactor;

        for (int k = 0; k < batchSplitFactor; k++)
        {
            // Construct sentences for encoding and decoding
            List<List<string>> srcTkns = new List<List<string>>();
            List<List<string>> tgtTkns = new List<List<string>>();
            int sLenInBatch = 0;
            int tLenInBatch = 0;
            for (int j = k * batchSegSize; j < (k + 1) * batchSegSize; j++)
            {
                srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;

                tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
            }

            float lcost = 0.0f;
            // Create a new computing graph instance
            using (IComputeGraph computeGraph_i = CreateComputGraph(i))
            {
                // Run the forward pass
                lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);

                // Run the backward pass and compute gradients
                computeGraph_i.Backward();
            }

            lock (locker)
            {
                cost += lcost;
                srcWordCnts += sLenInBatch;
                tgtWordCnts += tLenInBatch;
                processedLine += batchSegSize;
            }
        }
    });

    return (cost, srcWordCnts, tgtWordCnts, processedLine);
}
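// A minimal, self-contained sketch of the batch-splitting arithmetic that RunNetwork
// relies on: a device batch of BatchSize sentences is cut into batchSplitFactor equal
// segments, and segment k covers indices [k * segSize, (k + 1) * segSize). The names
// below (DemoBatchSplit, segSize) are illustrative only, not part of the library.
using System;

public static class DemoBatchSplit
{
    public static void Main()
    {
        int batchSize = 8;
        int batchSplitFactor = 2;

        // As in RunNetwork, integer division drops any remainder,
        // so batchSize should be divisible by batchSplitFactor.
        int segSize = batchSize / batchSplitFactor;

        for (int k = 0; k < batchSplitFactor; k++)
        {
            Console.WriteLine($"Segment {k}: sentences {k * segSize}..{(k + 1) * segSize - 1}");
        }
        // Output:
        // Segment 0: sentences 0..3
        // Segment 1: sentences 4..7
    }
}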
private void TrainEp(int ep, float learningRate)
{
    int processedLine = 0;
    DateTime startDateTime = DateTime.Now;

    double costInTotal = 0.0;
    long srcWordCnts = 0;
    long tgtWordCnts = 0;
    double avgCostPerWordInTotal = 0.0;
    double lastAvgCostPerWordInTotal = 100000.0;
    List<SntPair> sntPairs = new List<SntPair>();

    TensorAllocator.FreeMemoryAllDevices();

    Logger.WriteLine($"Base learning rate is '{learningRate}' at epoch '{ep}'");

    // Clean the caches used by the weight optimizer
    Logger.WriteLine($"Cleaning cache of weights optimization.");
    CleanWeightCache();

    Logger.WriteLine($"Start to process training corpus.");
    foreach (var sntPair in TrainCorpus)
    {
        sntPairs.Add(sntPair);

        if (sntPairs.Count == TrainCorpus.BatchSize)
        {
            List<List<string>> srcSnts = new List<List<string>>();
            List<List<string>> tgtSnts = new List<List<string>>();
            var slen = 0;
            var tlen = 0;
            for (int j = 0; j < TrainCorpus.BatchSize; j++)
            {
                List<string> srcSnt = new List<string>();

                // Add BOS and EOS tags to source sentences
                srcSnt.Add(m_START);
                srcSnt.AddRange(sntPairs[j].SrcSnt);
                srcSnt.Add(m_END);

                srcSnts.Add(srcSnt);
                tgtSnts.Add(sntPairs[j].TgtSnt.ToList());

                slen += srcSnt.Count;
                tlen += sntPairs[j].TgtSnt.Length;
            }
            srcWordCnts += slen;
            tgtWordCnts += tlen;

            Reset();

            // Copy weights from the default device to all other devices
            SyncWeights();

            float cost = 0.0f;
            Parallel.For(0, m_deviceIds.Length, i =>
            {
                IComputeGraph computeGraph = CreateComputGraph(i);

                // Bi-directionally encode the input source sentences
                IWeightMatrix encodedWeightMatrix = Encode(computeGraph, srcSnts.GetRange(i * m_batchSize, m_batchSize), m_biEncoder[i], m_srcEmbedding[i]);

                // Generate output decoder sentences
                List<List<string>> predictSentence;
                float lcost = Decode(tgtSnts.GetRange(i * m_batchSize, m_batchSize), computeGraph, encodedWeightMatrix, m_decoder[i], m_decoderFFLayer[i], m_tgtEmbedding[i], out predictSentence);

                lock (locker)
                {
                    cost += lcost;
                }

                // Calculate gradients
                computeGraph.Backward();
            });

            // Sum up gradients from all devices and keep them in the default device for parameter optimization
            SyncGradient();

            if (float.IsInfinity(cost) == false && float.IsNaN(cost) == false)
            {
                processedLine += TrainCorpus.BatchSize;
                double costPerWord = (cost / tlen);
                costInTotal += cost;
                avgCostPerWordInTotal = costInTotal / tgtWordCnts;
                lastAvgCostPerWordInTotal = avgCostPerWordInTotal;
            }
            else
            {
                Logger.WriteLine($"Invalid cost value.");
            }

            // Optimize parameters
            float avgAllLR = UpdateParameters(learningRate, TrainCorpus.BatchSize);

            // Clear gradients over all devices
            ClearGradient();

            if (IterationDone != null && processedLine % (100 * TrainCorpus.BatchSize) == 0)
            {
                IterationDone(this, new CostEventArg()
                {
                    AvgLearningRate = avgAllLR,
                    CostPerWord = cost / tlen,
                    avgCostInTotal = avgCostPerWordInTotal,
                    Epoch = ep,
                    ProcessedSentencesInTotal = processedLine,
                    ProcessedWordsInTotal = srcWordCnts * 2 + tgtWordCnts,
                    StartDateTime = startDateTime
                });
            }

            // Save the model every 1,000 batches
            if (processedLine % (TrainCorpus.BatchSize * 1000) == 0)
            {
                Save();
                TensorAllocator.FreeMemoryAllDevices();
            }

            sntPairs.Clear();
        }
    }

    Logger.WriteLine($"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish.");

    Save();
}
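// A small, self-contained sketch of how TrainEp prepares and distributes a batch: each
// source sentence is wrapped with BOS/EOS markers, and the combined batch is then sliced
// per device with GetRange(i * m_batchSize, m_batchSize). The "<s>"/"</s>" strings stand
// in for m_START/m_END here; the actual marker strings come from the model configuration.
using System;
using System.Collections.Generic;

public static class DemoBatchPrep
{
    public static void Main()
    {
        const string START = "<s>", END = "</s>";
        var rawBatch = new List<List<string>>
        {
            new List<string> { "hello", "world" },
            new List<string> { "good", "morning" },
        };

        // Wrap each source sentence with BOS/EOS, mirroring the loop in TrainEp.
        var srcSnts = new List<List<string>>();
        foreach (var snt in rawBatch)
        {
            var srcSnt = new List<string> { START };
            srcSnt.AddRange(snt);
            srcSnt.Add(END);
            srcSnts.Add(srcSnt);
        }

        // Each device i then takes a contiguous slice of the combined batch.
        int deviceCount = 2, perDeviceBatchSize = 1;
        for (int i = 0; i < deviceCount; i++)
        {
            var slice = srcSnts.GetRange(i * perDeviceBatchSize, perDeviceBatchSize);
            Console.WriteLine($"Device {i}: {string.Join(" ", slice[0])}");
        }
        // Output:
        // Device 0: <s> hello world </s>
        // Device 1: <s> good morning </s>
    }
}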
internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List<IMetric> metrics, IModelMetaData modelMetaData,
    Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice)
{
    int processedLineInTotal = 0;
    DateTime startDateTime = DateTime.Now;
    DateTime lastCheckPointDateTime = DateTime.Now;
    double costInTotal = 0.0;
    long srcWordCnts = 0;
    long tgtWordCnts = 0;
    double avgCostPerWordInTotal = 0.0;

    TensorAllocator.FreeMemoryAllDevices();

    Logger.WriteLine($"Start to process training corpus.");
    List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();

    foreach (SntPairBatch sntPairBatch in trainCorpus)
    {
        sntPairBatchs.Add(sntPairBatch);
        if (sntPairBatchs.Count == m_deviceIds.Length)
        {
            float cost = 0.0f;
            int tlen = 0;
            int processedLine = 0;

            // Copy weights from the default device to all other devices
            CopyWeightsFromDefaultDeviceToAllOtherDevices();

            // Run forward and backward passes on all available devices in parallel
            Parallel.For(0, m_deviceIds.Length, i =>
            {
                SntPairBatch sntPairBatch_i = sntPairBatchs[i];

                // Construct sentences for encoding and decoding
                List<List<string>> srcTkns = new List<List<string>>();
                List<List<string>> tgtTkns = new List<List<string>>();
                int sLenInBatch = 0;
                int tLenInBatch = 0;
                for (int j = 0; j < sntPairBatch_i.BatchSize; j++)
                {
                    srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                    sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;

                    tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                    tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
                }

                float lcost = 0.0f;
                // Create a new computing graph instance
                using (IComputeGraph computeGraph_i = CreateComputGraph(i))
                {
                    // Run the forward pass
                    lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);

                    // Run the backward pass and compute gradients
                    computeGraph_i.Backward();
                }

                lock (locker)
                {
                    cost += lcost;
                    srcWordCnts += sLenInBatch;
                    tgtWordCnts += tLenInBatch;
                    tlen += tLenInBatch;
                    processedLineInTotal += sntPairBatch_i.BatchSize;
                    processedLine += sntPairBatch_i.BatchSize;
                }
            });

            // Sum up gradients from all devices and keep them in the default device for parameter optimization
            SumGradientsToTensorsInDefaultDevice();

            // Optimize parameters
            float lr = learningRate.GetCurrentLearningRate();
            List<IWeightTensor> models = GetParametersFromDefaultDevice();
            solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

            // Clear gradients over all devices
            ZeroGradientOnAllDevices();

            costInTotal += cost;
            avgCostPerWordInTotal = costInTotal / tgtWordCnts;
            m_weightsUpdateCount++;

            if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
            {
                IterationDone(this, new CostEventArg()
                {
                    LearningRate = lr,
                    CostPerWord = cost / tlen,
                    AvgCostInTotal = avgCostPerWordInTotal,
                    Epoch = ep,
                    Update = m_weightsUpdateCount,
                    ProcessedSentencesInTotal = processedLineInTotal,
                    ProcessedWordsInTotal = srcWordCnts + tgtWordCnts,
                    StartDateTime = startDateTime
                });
            }

            // Evaluate the model every hour and save it if we get a better one
            TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
            if (ts.TotalHours > 1.0)
            {
                CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                lastCheckPointDateTime = DateTime.Now;
            }

            sntPairBatchs.Clear();
        }
    }

    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

    CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
    m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
}
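// TrainOneEpoch only asks the schedule for learningRate.GetCurrentLearningRate(), called
// once per weight update, so a minimal ILearningRate can be sketched as below. The linear
// warmup plus inverse-square-root decay is an assumed schedule chosen for illustration;
// the schedule actually shipped with the library may differ.
using System;

public interface ILearningRate
{
    float GetCurrentLearningRate();
}

// Hypothetical schedule: ramp up linearly over the warmup steps,
// then decay proportionally to 1/sqrt(step).
public class SqrtDecayLearningRate : ILearningRate
{
    private readonly float m_startLR;
    private readonly int m_warmupSteps;
    private int m_step = 0;

    public SqrtDecayLearningRate(float startLR, int warmupSteps)
    {
        m_startLR = startLR;
        m_warmupSteps = Math.Max(1, warmupSteps);
    }

    public float GetCurrentLearningRate()
    {
        // Advance one step per call, matching the once-per-update call site above.
        m_step++;
        double warmup = Math.Min(1.0, (double)m_step / m_warmupSteps);
        double decay = Math.Sqrt((double)m_warmupSteps / Math.Max(m_step, m_warmupSteps));
        return (float)(m_startLR * warmup * decay);
    }
}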