/// <summary>
/// Evaluate the quality of model on valid corpus, running one batch per device in parallel.
/// If a device throws (assumed to be out-of-memory — TODO confirm the intended exception type),
/// the whole evaluation restarts from scratch with each batch split into smaller segments
/// (batchSplitFactor doubles on every retry).
/// </summary>
/// <param name="validCorpus">valid corpus to measure the quality of model</param>
/// <param name="RunNetwork">The network to run on specific device</param>
/// <param name="metrics">A set of metrics. The first one is the primary metric</param>
/// <param name="outputToFile">It indicates if valid corpus and results should be dumped to files</param>
/// <returns>true if we get a better result on primary metric, otherwise, false</returns>
internal bool RunValid(ParallelCorpus validCorpus, Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> RunNetwork, List<IMetric> metrics, bool outputToFile = false)
{
    List<string> srcSents = null;
    List<string> refSents = null;
    List<string> hypSents = null;
    int batchSplitFactor = 1;
    bool runNetwordSuccssed = false;
    while (runNetwordSuccssed == false)
    {
        try
        {
            if (batchSplitFactor == 1)
            {
                Logger.WriteLine(Logger.Level.info, ConsoleColor.Gray, $"Start to evaluate model...");
            }
            else
            {
                Logger.WriteLine(Logger.Level.info, ConsoleColor.Gray, $"Retry to evaluate model. Batch split factor = '{batchSplitFactor}'");
            }

            srcSents = new List<string>();
            refSents = new List<string>();
            hypSents = new List<string>();

            // Clear inner status of each metrics, since a retry re-evaluates everything.
            foreach (IMetric metric in metrics)
            {
                metric.ClearStatus();
            }

            List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();
            foreach (SntPairBatch item in validCorpus)
            {
                sntPairBatchs.Add(item);
                if (sntPairBatchs.Count == DeviceIds.Length)
                {
                    // Run forward on all available processors, one collected batch per device.
                    // NOTE(review): any trailing batches collected after the last full group
                    // (fewer than DeviceIds.Length) are never evaluated — confirm whether the
                    // valid corpus size is always a multiple of the device count.
                    Parallel.For(0, m_deviceIds.Length, i =>
                    {
                        SntPairBatch sntPairBatch = sntPairBatchs[i];
                        // NOTE(review): integer division drops the remainder, so up to
                        // (batchSplitFactor - 1) trailing sentence pairs per batch are skipped
                        // when BatchSize is not divisible by batchSplitFactor — confirm.
                        int batchSegSize = sntPairBatch.BatchSize / batchSplitFactor;
                        for (int k = 0; k < batchSplitFactor; k++)
                        {
                            // Construct sentences for encoding and decoding
                            List<List<string>> srcTkns = new List<List<string>>();
                            List<List<string>> refTkns = new List<List<string>>();
                            List<List<string>> hypTkns = new List<List<string>>();
                            for (int j = k * batchSegSize; j < (k + 1) * batchSegSize; j++)
                            {
                                srcTkns.Add(sntPairBatch.SntPairs[j].SrcSnt.ToList());
                                refTkns.Add(sntPairBatch.SntPairs[j].TgtSnt.ToList());
                                // Decoding starts from the begin-of-sentence token.
                                hypTkns.Add(new List<string>() { ParallelCorpus.BOS });
                            }

                            // Create a new computing graph instance; needBack = false since
                            // evaluation needs no gradients.
                            using (IComputeGraph computeGraph = CreateComputGraph(DeviceIds[i], needBack: false))
                            {
                                // Run forward part; RunNetwork decodes into hypTkns.
                                RunNetwork(computeGraph, srcTkns, hypTkns, DeviceIds[i], false);
                            }

                            // Metrics and output buffers are shared across devices, so guard them.
                            lock (locker)
                            {
                                for (int j = 0; j < hypTkns.Count; j++)
                                {
                                    foreach (IMetric metric in metrics)
                                    {
                                        metric.Evaluate(new List<List<string>>() { refTkns[j] }, hypTkns[j]);
                                    }
                                }

                                if (outputToFile)
                                {
                                    for (int j = 0; j < srcTkns.Count; j++)
                                    {
                                        srcSents.Add(string.Join(" ", srcTkns[j]));
                                        refSents.Add(string.Join(" ", refTkns[j]));
                                        hypSents.Add(string.Join(" ", hypTkns[j]));
                                    }
                                }
                            }
                        }
                    });

                    sntPairBatchs.Clear();
                }
            }

            runNetwordSuccssed = true;
        }
        catch (Exception err)
        {
            batchSplitFactor *= 2;
            // Log why we retry instead of swallowing the exception silently.
            Logger.WriteLine($"Exception: '{err.Message}'. Retry evaluation with batch split factor '{batchSplitFactor}'.");
            if (batchSplitFactor >= 512)
            {
                Logger.WriteLine($"Batch split factor is larger than batch size, give it up.");
                // "throw;" (not "throw err;") preserves the original stack trace (CA2200).
                throw;
            }
        }
    }

    Logger.WriteLine($"Metrics result:");
    foreach (IMetric metric in metrics)
    {
        Logger.WriteLine(Logger.Level.info, ConsoleColor.DarkGreen, $"{metric.Name} = {metric.GetScoreStr()}");
    }

    if (outputToFile)
    {
        File.WriteAllLines("valid_src.txt", srcSents);
        File.WriteAllLines("valid_ref.txt", refSents);
        File.WriteAllLines("valid_hyp.txt", hypSents);
    }

    if (metrics.Count > 0)
    {
        if (metrics[0].GetPrimaryScore() > m_bestPrimaryScore)
        {
            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"We got a better score '{metrics[0].GetPrimaryScore().ToString("F")}' on primary metric '{metrics[0].Name}'. The previous score is '{m_bestPrimaryScore.ToString("F")}'");
            //We have a better primary score on valid set
            m_bestPrimaryScore = metrics[0].GetPrimaryScore();
            return (true);
        }
    }

    return (false);
}
/// <summary>
/// Run one training epoch over <paramref name="trainCorpus"/>: collects one batch per device,
/// runs forward/backward in parallel, merges gradients into the default device, updates the
/// weights, and periodically (hourly) creates a checkpoint against <paramref name="validCorpus"/>.
/// </summary>
/// <param name="ep">Epoch index, used only for reporting.</param>
/// <param name="trainCorpus">Training corpus enumerated as SntPairBatch items.</param>
/// <param name="validCorpus">Validation corpus handed to CreateCheckPoint.</param>
/// <param name="learningRate">Provides the current learning rate for each weight update.</param>
/// <param name="solver">Adam optimizer applying the accumulated gradients.</param>
/// <param name="metrics">Metrics forwarded to CreateCheckPoint.</param>
/// <param name="modelMetaData">Model metadata forwarded to CreateCheckPoint.</param>
/// <param name="ForwardOnSingleDevice">Runs the network on one device; returns the batch cost.</param>
internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List<IMetric> metrics, IModelMetaData modelMetaData, Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice)
{
    // Epoch-level accumulators (across all weight updates in this epoch).
    int processedLineInTotal = 0;
    DateTime startDateTime = DateTime.Now;
    DateTime lastCheckPointDateTime = DateTime.Now;
    double costInTotal = 0.0;
    long srcWordCnts = 0;
    long tgtWordCnts = 0;
    double avgCostPerWordInTotal = 0.0;

    TensorAllocator.FreeMemoryAllDevices();
    Logger.WriteLine($"Start to process training corpus.");

    // Batches are buffered until there is exactly one per device, then trained in parallel.
    // NOTE(review): trailing batches that never fill a whole group of m_deviceIds.Length
    // are dropped at the end of the corpus — confirm this is intended.
    List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();
    foreach (SntPairBatch sntPairBatch in trainCorpus)
    {
        sntPairBatchs.Add(sntPairBatch);
        if (sntPairBatchs.Count == m_deviceIds.Length)
        {
            // Per-update accumulators, filled by all devices under the lock below.
            float cost = 0.0f;
            int tlen = 0;
            int processedLine = 0;

            // Copy weights from weights kept in default device to all other devices
            CopyWeightsFromDefaultDeviceToAllOtherDevices();

            // Run forward and backward on all available processors
            Parallel.For(0, m_deviceIds.Length, i =>
            {
                SntPairBatch sntPairBatch_i = sntPairBatchs[i];

                // Construct sentences for encoding and decoding
                List<List<string>> srcTkns = new List<List<string>>();
                List<List<string>> tgtTkns = new List<List<string>>();
                int sLenInBatch = 0;
                int tLenInBatch = 0;
                for (int j = 0; j < sntPairBatch_i.BatchSize; j++)
                {
                    srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                    sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;
                    tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                    tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
                }

                float lcost = 0.0f;
                // Create a new computing graph instance
                // NOTE(review): the loop index i is passed here (and to ForwardOnSingleDevice),
                // whereas RunValid passes DeviceIds[i] — confirm whether CreateComputGraph
                // expects a device index or a device id.
                using (IComputeGraph computeGraph_i = CreateComputGraph(i))
                {
                    // Run forward part
                    lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);
                    // Run backward part and compute gradients
                    computeGraph_i.Backward();
                }

                // Accumulators are shared between devices, so merge under the lock.
                lock (locker)
                {
                    cost += lcost;
                    srcWordCnts += sLenInBatch;
                    tgtWordCnts += tLenInBatch;
                    tlen += tLenInBatch;
                    processedLineInTotal += sntPairBatch_i.BatchSize;
                    processedLine += sntPairBatch_i.BatchSize;
                }
            });

            //Sum up gradients in all devices, and kept it in default device for parameters optmization
            SumGradientsToTensorsInDefaultDevice();

            //Optmize parameters
            float lr = learningRate.GetCurrentLearningRate();
            List<IWeightTensor> models = GetParametersFromDefaultDevice();
            solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

            //Clear gradient over all devices
            ZeroGradientOnAllDevices();

            costInTotal += cost;
            avgCostPerWordInTotal = costInTotal / tgtWordCnts;
            m_weightsUpdateCount++;
            // Report progress every 100 weight updates via the IterationDone event.
            if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
            {
                IterationDone(this, new CostEventArg()
                {
                    LearningRate = lr,
                    CostPerWord = cost / tlen,
                    AvgCostInTotal = avgCostPerWordInTotal,
                    Epoch = ep,
                    Update = m_weightsUpdateCount,
                    ProcessedSentencesInTotal = processedLineInTotal,
                    ProcessedWordsInTotal = srcWordCnts + tgtWordCnts,
                    StartDateTime = startDateTime
                });
            }

            // Evaluate model every hour and save it if we could get a better one.
            TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
            if (ts.TotalHours > 1.0)
            {
                CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                lastCheckPointDateTime = DateTime.Now;
            }

            sntPairBatchs.Clear();
        }
    }

    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

    // Final checkpoint at end of epoch, then remember this epoch's average cost for reporting.
    CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
    m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
}
/// <summary>
/// Run one training epoch over <paramref name="trainCorpus"/>: collects one batch per device,
/// delegates the parallel forward/backward pass to RunNetwork (retrying with a doubled batch
/// split factor on failure), applies the gradients, and periodically (hourly) creates a
/// checkpoint against <paramref name="validCorpus"/>.
/// </summary>
/// <param name="ep">Epoch index, used only for reporting.</param>
/// <param name="trainCorpus">Training corpus enumerated as SntPairBatch items.</param>
/// <param name="validCorpus">Validation corpus handed to CreateCheckPoint.</param>
/// <param name="learningRate">Provides the current learning rate for each weight update.</param>
/// <param name="solver">Adam optimizer applying the accumulated gradients.</param>
/// <param name="metrics">Metrics forwarded to CreateCheckPoint.</param>
/// <param name="modelMetaData">Model metadata forwarded to CreateCheckPoint.</param>
/// <param name="ForwardOnSingleDevice">Runs the network on one device; returns the batch cost.</param>
internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List<IMetric> metrics, IModelMetaData modelMetaData, Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice)
{
    // Epoch-level accumulators (across all weight updates in this epoch).
    int processedLineInTotal = 0;
    DateTime startDateTime = DateTime.Now;
    DateTime lastCheckPointDateTime = DateTime.Now;
    double costInTotal = 0.0;
    long srcWordCntsInTotal = 0;
    long tgtWordCntsInTotal = 0;
    double avgCostPerWordInTotal = 0.0;

    TensorAllocator.FreeMemoryAllDevices();
    Logger.WriteLine($"Start to process training corpus.");

    // Batches are buffered until there is exactly one per device, then trained together.
    List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();
    foreach (SntPairBatch sntPairBatch in trainCorpus)
    {
        sntPairBatchs.Add(sntPairBatch);
        if (sntPairBatchs.Count == m_deviceIds.Length)
        {
            // Copy weights from weights kept in default device to all other devices
            CopyWeightsFromDefaultDeviceToAllOtherDevices();

            int batchSplitFactor = 1;
            bool runNetwordSuccssed = false;
            while (runNetwordSuccssed == false)
            {
                try
                {
                    // Run forward/backward across all devices; on failure the batches are
                    // retried in smaller segments (batchSplitFactor doubles each attempt).
                    (float cost, int sWordCnt, int tWordCnt, int processedLine) = RunNetwork(ForwardOnSingleDevice, sntPairBatchs, batchSplitFactor);
                    processedLineInTotal += processedLine;
                    srcWordCntsInTotal += sWordCnt;
                    tgtWordCntsInTotal += tWordCnt;

                    //Sum up gradients in all devices, and kept it in default device for parameters optmization
                    SumGradientsToTensorsInDefaultDevice();

                    //Optmize parameters
                    float lr = learningRate.GetCurrentLearningRate();
                    List<IWeightTensor> models = GetParametersFromDefaultDevice();
                    solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

                    costInTotal += cost;
                    avgCostPerWordInTotal = costInTotal / tgtWordCntsInTotal;
                    m_weightsUpdateCount++;
                    // Report progress every 100 weight updates via the IterationDone event.
                    if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                    {
                        IterationDone(this, new CostEventArg()
                        {
                            LearningRate = lr,
                            CostPerWord = cost / tWordCnt,
                            AvgCostInTotal = avgCostPerWordInTotal,
                            Epoch = ep,
                            Update = m_weightsUpdateCount,
                            ProcessedSentencesInTotal = processedLineInTotal,
                            ProcessedWordsInTotal = srcWordCntsInTotal + tgtWordCntsInTotal,
                            StartDateTime = startDateTime
                        });
                    }

                    runNetwordSuccssed = true;
                }
                catch (Exception)
                {
                    batchSplitFactor *= 2;
                    Logger.WriteLine($"Increase batch split factor to {batchSplitFactor}, and retry it.");
                    if (batchSplitFactor >= sntPairBatchs[0].BatchSize)
                    {
                        Logger.WriteLine($"Batch split factor is larger than batch size, give it up.");
                        // "throw;" (not "throw err;") preserves the original stack trace (CA2200).
                        throw;
                    }
                }
            }

            // Evaluate model every hour and save it if we could get a better one.
            TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
            if (ts.TotalHours > 1.0)
            {
                CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                lastCheckPointDateTime = DateTime.Now;
            }

            sntPairBatchs.Clear();
        }
    }

    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

    // Final checkpoint at end of epoch, then remember this epoch's average cost for reporting.
    CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
    m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
}
/// <summary>
/// Evaluate the quality of model on valid corpus.
/// </summary>
/// <param name="validCorpus">valid corpus to measure the quality of model</param>
/// <param name="RunNetwork">The network to run on specific device</param>
/// <param name="metrics">A set of metrics. The first one is the primary metric</param>
/// <param name="outputToFile">It indicates if valid corpus and results should be dumped to files</param>
/// <returns>true if we get a better result on primary metric, otherwise, false</returns>
internal bool RunValid(ParallelCorpus validCorpus, Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> RunNetwork, List<IMetric> metrics, bool outputToFile = false)
{
    Logger.WriteLine(Logger.Level.info, ConsoleColor.Gray, $"Start to Evaluate model...");

    // Buffers for the optional dump of sources/references/hypotheses to disk.
    List<string> dumpedSrc = new List<string>();
    List<string> dumpedRef = new List<string>();
    List<string> dumpedHyp = new List<string>();

    // Reset every metric's accumulated state before scoring this run.
    foreach (IMetric m in metrics)
    {
        m.ClearStatus();
    }

    foreach (var batch in validCorpus)
    {
        int size = batch.BatchSize;

        // Build the encoder inputs and the reference translations for this batch.
        List<List<string>> srcBatchTokens = new List<List<string>>();
        List<List<string>> refBatchTokens = new List<List<string>>();
        for (int row = 0; row < size; row++)
        {
            srcBatchTokens.Add(batch.SntPairs[row].SrcSnt.ToList());
            refBatchTokens.Add(batch.SntPairs[row].TgtSnt.ToList());
        }

        // RunNetwork fills this with the decoded hypotheses.
        List<List<string>> hypBatchTokens = new List<List<string>>();

        // Forward pass only (needBack: false — no gradients during evaluation), on the first device.
        using (IComputeGraph g = CreateComputGraph(DeviceIds[0], needBack: false))
        {
            RunNetwork(g, srcBatchTokens, hypBatchTokens, DeviceIds[0], false);
        }

        // Score each hypothesis against its reference with every metric.
        for (int row = 0; row < hypBatchTokens.Count; row++)
        {
            foreach (IMetric m in metrics)
            {
                m.Evaluate(new List<List<string>>() { refBatchTokens[row] }, hypBatchTokens[row]);
            }
        }

        if (outputToFile)
        {
            for (int row = 0; row < size; row++)
            {
                dumpedSrc.Add(string.Join(" ", srcBatchTokens[row]));
                dumpedRef.Add(string.Join(" ", refBatchTokens[row]));
                dumpedHyp.Add(string.Join(" ", hypBatchTokens[row]));
            }
        }
    }

    Logger.WriteLine($"Metrics result:");
    foreach (IMetric m in metrics)
    {
        Logger.WriteLine(Logger.Level.info, ConsoleColor.DarkGreen, $"{m.Name} = {m.GetScoreStr()}");
    }

    if (outputToFile)
    {
        File.WriteAllLines("valid_src.txt", dumpedSrc);
        File.WriteAllLines("valid_ref.txt", dumpedRef);
        File.WriteAllLines("valid_hyp.txt", dumpedHyp);
    }

    // A better primary-metric score updates the best score and reports success.
    if (metrics.Count > 0 && metrics[0].GetPrimaryScore() > m_bestPrimaryScore)
    {
        Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"We got a better score '{metrics[0].GetPrimaryScore().ToString("F")}' on primary metric '{metrics[0].Name}'. The previous score is '{m_bestPrimaryScore.ToString("F")}'");
        //We have a better primary score on valid set
        m_bestPrimaryScore = metrics[0].GetPrimaryScore();
        return true;
    }

    return false;
}