/// <summary>
/// Creates the trainable networks of a sequence labeling model described by <paramref name="mmd"/>:
/// an encoder (BiLSTM or Transformer), a feed-forward layer projecting the encoder output onto the
/// target vocabulary, the source embedding table and, for a Transformer encoder, a position embedding.
/// </summary>
/// <param name="mmd">Model meta data; must be a <see cref="SeqLabelModelMetaData"/>.</param>
/// <returns>Always <c>true</c>.</returns>
/// <exception cref="ArgumentException">
/// <paramref name="mmd"/> is null or not a <see cref="SeqLabelModelMetaData"/>.
/// </exception>
private bool CreateTrainableParameters(IModelMetaData mmd)
{
    Logger.WriteLine($"Creating encoders and decoders...");

    // Fail fast with a clear error instead of a NullReferenceException further down.
    var modelMetaData = mmd as SeqLabelModelMetaData
        ?? throw new ArgumentException($"Model meta data must be of type '{nameof(SeqLabelModelMetaData)}'.", nameof(mmd));

    // Every network below takes the next id from this array, so the construction order
    // determines which device id each network receives. Do not reorder.
    var raDeviceIds = new RoundArray<int>(this.DeviceIds);

    // Width of the encoder output consumed by the feed-forward output layer.
    int encoderOutputDim;
    if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
    {
        this.m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
            new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem(), true),
            this.DeviceIds);
        // The BiLSTM output is 2 * HiddenDim wide (presumably forward and backward states concatenated).
        encoderOutputDim = modelMetaData.HiddenDim * 2;
    }
    else
    {
        this.m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
            new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, this.m_dropoutRatio, raDeviceIds.GetNextItem(), true),
            this.DeviceIds);
        encoderOutputDim = modelMetaData.HiddenDim;
    }

    this.m_decoderFFLayer = new MultiProcessorNetworkWrapper<FeedForwardLayer>(
        new FeedForwardLayer("FeedForward", encoderOutputDim, modelMetaData.Vocab.TargetWordSize, 0.0f, raDeviceIds.GetNextItem(), true),
        this.DeviceIds);

    this.m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(
        new WeightTensor(new long[2] { modelMetaData.Vocab.SourceWordSize, modelMetaData.EmbeddingDim }, raDeviceIds.GetNextItem(), normal: NormType.Normal, name: "SrcEmbeddings", isTrainable: true),
        this.DeviceIds);

    // m_crfDecoder = new CRFDecoder(modelMetaData.Vocab.TargetWordSize);

    // A position embedding is only built for the Transformer encoder.
    // It has m_maxSntSize + 2 rows; the former Math.Max(m_maxSntSize, m_maxSntSize) was the same value.
    this.m_posEmbedding = modelMetaData.EncoderType == EncoderTypeEnums.Transformer
        ? new MultiProcessorNetworkWrapper<IWeightTensor>(
            this.BuildPositionWeightTensor(this.m_maxSntSize + 2, modelMetaData.EmbeddingDim, raDeviceIds.GetNextItem(), "PosEmbedding", false),
            this.DeviceIds, true)
        : null;

    return true;
}
/// <summary>
/// Creates the encoder (BiLSTM or Transformer), the feed-forward output layer projecting the
/// encoder output onto the target vocabulary, and the source embedding table described by
/// <paramref name="mmd"/>.
/// </summary>
/// <param name="mmd">Model meta data; must be a <see cref="Seq2SeqModelMetaData"/>.</param>
/// <returns>Always <c>true</c>.</returns>
/// <exception cref="ArgumentException">
/// <paramref name="mmd"/> is null or not a <see cref="Seq2SeqModelMetaData"/>.
/// </exception>
private bool CreateTrainableParameters(IModelMetaData mmd)
{
    Logger.WriteLine($"Creating encoders and decoders...");

    // Fail fast with a clear error instead of a NullReferenceException further down.
    Seq2SeqModelMetaData modelMetaData = mmd as Seq2SeqModelMetaData
        ?? throw new ArgumentException($"Model meta data must be of type '{nameof(Seq2SeqModelMetaData)}'.", nameof(mmd));

    // Each network takes the next device id in construction order; keep the order stable.
    RoundArray<int> raDeviceIds = new RoundArray<int>(DeviceIds);

    // Width of the encoder output consumed by the feed-forward output layer.
    int encoderOutputDim;
    if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
    {
        m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
            new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem()),
            DeviceIds);
        encoderOutputDim = modelMetaData.HiddenDim * 2;
    }
    else
    {
        m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
            new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, m_dropoutRatio, raDeviceIds.GetNextItem()),
            DeviceIds);
        encoderOutputDim = modelMetaData.HiddenDim;
    }

    m_decoderFFLayer = new MultiProcessorNetworkWrapper<FeedForwardLayer>(
        new FeedForwardLayer("FeedForward", encoderOutputDim, modelMetaData.Vocab.TargetWordSize, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem()),
        DeviceIds);

    m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(
        new WeightTensor(new long[2] { modelMetaData.Vocab.SourceWordSize, modelMetaData.EmbeddingDim }, raDeviceIds.GetNextItem(), normal: true, name: "SrcEmbeddings", isTrainable: true),
        DeviceIds);

    // m_crfDecoder = new CRFDecoder(modelMetaData.Vocab.TargetWordSize);
    return true;
}
/// <summary>
/// Creates the encoder (BiLSTM or Transformer), the attention LSTM decoder, and the source and
/// target embedding tables described by <paramref name="mmd"/>. Trainability of each part is taken
/// from m_isEncoderTrainable, m_isDecoderTrainable, m_isSrcEmbTrainable and m_isTgtEmbTrainable.
/// </summary>
/// <param name="mmd">Model meta data; must be a <see cref="Seq2SeqModelMetaData"/>.</param>
/// <returns>Always <c>true</c>.</returns>
/// <exception cref="ArgumentException">
/// <paramref name="mmd"/> is null or not a <see cref="Seq2SeqModelMetaData"/>.
/// </exception>
private bool CreateTrainableParameters(IModelMetaData mmd)
{
    Logger.WriteLine($"Creating encoders and decoders...");

    // Fail fast with a clear error instead of a NullReferenceException further down.
    Seq2SeqModelMetaData modelMetaData = mmd as Seq2SeqModelMetaData
        ?? throw new ArgumentException($"Model meta data must be of type '{nameof(Seq2SeqModelMetaData)}'.", nameof(mmd));

    // Each network takes the next device id in construction order; keep the order stable.
    RoundArray<int> raDeviceIds = new RoundArray<int>(DeviceIds);

    // Width of the encoder output the decoder attends over.
    int contextDim;
    if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
    {
        m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
            new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem(), isTrainable: m_isEncoderTrainable),
            DeviceIds);
        contextDim = modelMetaData.HiddenDim * 2;
    }
    else
    {
        m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
            new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, modelMetaData.EncoderLayerDepth, m_dropoutRatio, raDeviceIds.GetNextItem(), isTrainable: m_isEncoderTrainable),
            DeviceIds);
        contextDim = modelMetaData.HiddenDim;
    }

    m_decoder = new MultiProcessorNetworkWrapper<AttentionDecoder>(
        new AttentionDecoder("AttnLSTMDecoder", modelMetaData.HiddenDim, modelMetaData.EmbeddingDim, contextDim, modelMetaData.Vocab.TargetWordSize, m_dropoutRatio, modelMetaData.DecoderLayerDepth, raDeviceIds.GetNextItem(), modelMetaData.EnableCoverageModel, isTrainable: m_isDecoderTrainable),
        DeviceIds);

    m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(
        new WeightTensor(new long[2] { modelMetaData.Vocab.SourceWordSize, modelMetaData.EmbeddingDim }, raDeviceIds.GetNextItem(), normal: true, name: "SrcEmbeddings", isTrainable: m_isSrcEmbTrainable),
        DeviceIds);
    m_tgtEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(
        new WeightTensor(new long[2] { modelMetaData.Vocab.TargetWordSize, modelMetaData.EmbeddingDim }, raDeviceIds.GetNextItem(), normal: true, name: "TgtEmbeddings", isTrainable: m_isTgtEmbTrainable),
        DeviceIds);

    return true;
}
/// <summary>
/// Writes the model to <c>m_modelFilePath</c>: first the serialized <paramref name="modelMetaData"/>,
/// then the parameters written by SaveParameters. Any existing file at that path is first copied to
/// "&lt;path&gt;.bak", overwriting an older backup.
/// </summary>
/// <param name="modelMetaData">Meta data serialized at the head of the model file.</param>
/// <returns>
/// <c>true</c> if the model was written; <c>false</c> if any exception occurred. The exception
/// message is logged.
/// </returns>
/// <remarks>
/// NOTE(review): BinaryFormatter is obsolete and unsafe on untrusted data. It is kept here because
/// LoadModel reads the same format.
/// </remarks>
public bool SaveModel(IModelMetaData modelMetaData)
{
    bool saved;
    try
    {
        Logger.WriteLine($"Saving model to '{m_modelFilePath}'");

        string backupPath = $"{m_modelFilePath}.bak";
        if (File.Exists(m_modelFilePath))
        {
            // Preserve the previous model before FileMode.Create truncates it.
            File.Copy(m_modelFilePath, backupPath, true);
        }

        var formatter = new BinaryFormatter();
        using (var stream = new FileStream(m_modelFilePath, FileMode.Create, FileAccess.Write))
        {
            // Header: model meta data.
            formatter.Serialize(stream, modelMetaData);

            // Body: every network/tensor wrapped in MultiProcessorNetworkWrapper<T>.
            SaveParameters(stream);
        }

        saved = true;
    }
    catch (Exception err)
    {
        Logger.WriteLine($"Failed to save model to file. Exception = '{err.Message}'");
        saved = false;
    }

    return saved;
}
/// <summary>
/// Creates all trainable parameters of a sequence-to-sequence model described by <paramref name="mmd"/>:
/// <list type="bullet">
/// <item>the encoder (BiLSTM or Transformer);</item>
/// <item>the decoder (attention LSTM or Transformer);</item>
/// <item>a position embedding, when either side is a Transformer;</item>
/// <item>the source and target embedding tables.</item>
/// </list>
/// </summary>
/// <param name="mmd">Model meta data; must be a <see cref="Seq2SeqModelMetaData"/>.</param>
/// <returns>Always <c>true</c>.</returns>
/// <exception cref="ArgumentException">
/// <paramref name="mmd"/> is null or not a <see cref="Seq2SeqModelMetaData"/>.
/// </exception>
private bool CreateTrainableParameters(IModelMetaData mmd)
{
    Logger.WriteLine($"Creating encoders and decoders...");

    // Fail fast with a clear error instead of a NullReferenceException further down.
    Seq2SeqModelMetaData modelMetaData = mmd as Seq2SeqModelMetaData
        ?? throw new ArgumentException($"Model meta data must be of type '{nameof(Seq2SeqModelMetaData)}'.", nameof(mmd));

    // Each network takes the next device id in construction order; keep the order stable.
    RoundArray<int> raDeviceIds = new RoundArray<int>(DeviceIds);

    // Width of the encoder output the decoder attends over.
    int contextDim;
    if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
    {
        m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
            new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.SrcEmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem(), isTrainable: m_isEncoderTrainable),
            DeviceIds);
        contextDim = modelMetaData.HiddenDim * 2;
    }
    else
    {
        m_encoder = new MultiProcessorNetworkWrapper<IEncoder>(
            new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.SrcEmbeddingDim, modelMetaData.EncoderLayerDepth, m_dropoutRatio, raDeviceIds.GetNextItem(), isTrainable: m_isEncoderTrainable),
            DeviceIds);
        contextDim = modelMetaData.HiddenDim;
    }

    if (modelMetaData.DecoderType == DecoderTypeEnums.AttentionLSTM)
    {
        m_decoder = new MultiProcessorNetworkWrapper<IDecoder>(
            new AttentionDecoder("AttnLSTMDecoder", modelMetaData.HiddenDim, modelMetaData.TgtEmbeddingDim, contextDim, modelMetaData.Vocab.TargetWordSize, m_dropoutRatio, modelMetaData.DecoderLayerDepth, raDeviceIds.GetNextItem(), modelMetaData.EnableCoverageModel, isTrainable: m_isDecoderTrainable),
            DeviceIds);
    }
    else
    {
        // Fix: the decoder depth must come from DecoderLayerDepth, as in the AttentionLSTM branch;
        // it previously used EncoderLayerDepth.
        // NOTE(review): models saved with different encoder/decoder depths under the old code will
        // have a different decoder layer count than this builds — confirm before loading them.
        m_decoder = new MultiProcessorNetworkWrapper<IDecoder>(
            new TransformerDecoder("TransformerDecoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.TgtEmbeddingDim, modelMetaData.Vocab.TargetWordSize, modelMetaData.DecoderLayerDepth, m_dropoutRatio, raDeviceIds.GetNextItem(), isTrainable: m_isDecoderTrainable),
            DeviceIds);
    }

    if (modelMetaData.EncoderType == EncoderTypeEnums.Transformer || modelMetaData.DecoderType == DecoderTypeEnums.Transformer)
    {
        // One row per position up to the longest allowed sentence on either side, plus 2.
        // NOTE(review): the width is contextDim, which is 2 * HiddenDim with a BiLSTM encoder —
        // confirm it matches the embeddings this table is combined with.
        m_posEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(
            BuildPositionWeightTensor(Math.Max(m_maxSrcSntSize, m_maxTgtSntSize) + 2, contextDim, raDeviceIds.GetNextItem(), "PosEmbedding", false),
            DeviceIds, true);
    }
    else
    {
        m_posEmbedding = null;
    }

    Logger.WriteLine($"Creating embeddings for source side. Shape = '({modelMetaData.Vocab.SourceWordSize} ,{modelMetaData.SrcEmbeddingDim})'");
    m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(
        new WeightTensor(new long[2] { modelMetaData.Vocab.SourceWordSize, modelMetaData.SrcEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: m_isSrcEmbTrainable),
        DeviceIds);

    Logger.WriteLine($"Creating embeddings for target side. Shape = '({modelMetaData.Vocab.TargetWordSize} ,{modelMetaData.TgtEmbeddingDim})'");
    m_tgtEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(
        new WeightTensor(new long[2] { modelMetaData.Vocab.TargetWordSize, modelMetaData.TgtEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: m_isTgtEmbTrainable),
        DeviceIds);

    return true;
}
/// <summary>
/// Loads a model from <c>m_modelFilePath</c>. The file starts with the serialized model meta data,
/// followed by the parameters of every network/tensor wrapped in MultiProcessorNetworkWrapper&lt;T&gt;.
/// </summary>
/// <param name="InitializeParameters">
/// Callback invoked with the deserialized meta data to create the networks on the devices. It runs
/// before their weights are read from the stream and must return <c>true</c> on success.
/// </param>
/// <returns>The deserialized model meta data.</returns>
/// <exception cref="InvalidDataException">
/// The file does not start with an <see cref="IModelMetaData"/> object.
/// </exception>
/// <exception cref="InvalidOperationException">
/// <paramref name="InitializeParameters"/> returned <c>false</c>.
/// </exception>
/// <remarks>
/// SECURITY: BinaryFormatter deserialization can execute arbitrary code for crafted input.
/// Only load model files from trusted sources.
/// </remarks>
public IModelMetaData LoadModel(Func<IModelMetaData, bool> InitializeParameters)
{
    Logger.WriteLine($"Loading model from '{m_modelFilePath}'...");

    BinaryFormatter bf = new BinaryFormatter();
    using (FileStream fs = new FileStream(m_modelFilePath, FileMode.Open, FileAccess.Read))
    {
        // Fail fast with a clear error instead of passing null to InitializeParameters.
        if (!(bf.Deserialize(fs) is IModelMetaData modelMetaData))
        {
            throw new InvalidDataException($"Model file '{m_modelFilePath}' does not start with model meta data.");
        }

        // Initialize parameters on devices. The weights read below go into the networks created
        // here, so a failed initialization must not be ignored.
        if (!InitializeParameters(modelMetaData))
        {
            throw new InvalidOperationException($"Failed to initialize parameters for model '{m_modelFilePath}'.");
        }

        // Load embeddings and weights: every MultiProcessorNetworkWrapper<T> reads its data from the stream.
        LoadParameters(fs);
        return modelMetaData;
    }
}
/// <summary>
/// Saves the model via SaveModel when it is judged to have improved.
/// </summary>
/// <remarks>
/// With a validation corpus, the decision is RunValid's result. Without one, the model is saved
/// when <paramref name="avgCostPerWordInTotal"/> is lower than the average cost per word recorded
/// for the last epoch.
/// </remarks>
/// <param name="validCorpus">Validation batches, or <c>null</c> when no validation set is available.</param>
/// <param name="metrics">Metrics passed through to RunValid.</param>
/// <param name="modelMetaData">Meta data passed to SaveModel.</param>
/// <param name="ForwardOnSingleDevice">Forward-pass callback passed through to RunValid.</param>
/// <param name="avgCostPerWordInTotal">Current average training cost per target word.</param>
private void CreateCheckPoint(IEnumerable<SntPairBatch> validCorpus, List<IMetric> metrics, IModelMetaData modelMetaData,
    Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice, double avgCostPerWordInTotal)
{
    bool shouldSave;
    if (validCorpus == null)
    {
        // No validation corpus: fall back to comparing training cost against the last epoch.
        shouldSave = m_avgCostPerWordInTotalInLastEpoch > avgCostPerWordInTotal;
    }
    else
    {
        // Validation corpus provided: evaluate the model and let RunValid decide.
        shouldSave = RunValid(validCorpus, ForwardOnSingleDevice, metrics, true) == true;
    }

    if (shouldSave)
    {
        SaveModel(modelMetaData);
    }
}
/// <summary>
/// Runs one training pass over <paramref name="trainCorpus"/>. Mini-batches are grouped so that there
/// is one batch per device, and each group produces one optimizer update.
/// </summary>
/// <remarks>
/// Exception handling per group:
/// <list type="bullet">
/// <item>
/// OutOfMemoryException (thrown directly or inside an AggregateException): the group is retried with
/// the batch split factor doubled. Once the factor reaches the batch size of the group's first batch,
/// the group is skipped.
/// </item>
/// <item>ArithmeticException: the group is skipped.</item>
/// <item>Any other exception: logged and rethrown with its original stack trace.</item>
/// </list>
/// A trailing group with fewer batches than devices is not processed. A checkpoint is attempted
/// whenever more than an hour has passed since m_lastCheckPointDateTime.
/// </remarks>
/// <param name="ep">Epoch index, reported through IterationDone and the end-of-epoch log.</param>
/// <param name="trainCorpus">Training mini-batches.</param>
/// <param name="validCorpus">Validation mini-batches passed to CreateCheckPoint (may be null).</param>
/// <param name="learningRate">Provides the learning rate for each update.</param>
/// <param name="solver">Optimizer applied to the parameters on the default device.</param>
/// <param name="metrics">Metrics passed to CreateCheckPoint.</param>
/// <param name="modelMetaData">Meta data passed to CreateCheckPoint.</param>
/// <param name="ForwardOnSingleDevice">Forward-pass callback passed to RunNetwork and CreateCheckPoint.</param>
internal void TrainOneEpoch(int ep, IEnumerable<SntPairBatch> trainCorpus, IEnumerable<SntPairBatch> validCorpus, ILearningRate learningRate, AdamOptimizer solver, List<IMetric> metrics, IModelMetaData modelMetaData,
    Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice)
{
    int processedLineInTotal = 0;
    DateTime startDateTime = DateTime.Now;
    double costInTotal = 0.0;
    long srcWordCntsInTotal = 0;
    long tgtWordCntsInTotal = 0;
    double avgCostPerWordInTotal = 0.0;

    Logger.WriteLine($"Start to process training corpus.");
    List<SntPairBatch> sntPairBatchs = new List<SntPairBatch>();

    foreach (SntPairBatch sntPairBatch in trainCorpus)
    {
        sntPairBatchs.Add(sntPairBatch);
        if (sntPairBatchs.Count == m_deviceIds.Length)
        {
            // Copy weights from weights kept in default device to all other devices
            CopyWeightsFromDefaultDeviceToAllOtherDevices();

            int batchSplitFactor = 1;
            bool runNetworkSucceeded = false;
            while (!runNetworkSucceeded)
            {
                try
                {
                    (float cost, int sWordCnt, int tWordCnt, int processedLine) = RunNetwork(ForwardOnSingleDevice, sntPairBatchs, batchSplitFactor);
                    processedLineInTotal += processedLine;
                    srcWordCntsInTotal += sWordCnt;
                    tgtWordCntsInTotal += tWordCnt;

                    // Sum up gradients in all devices, and keep them in the default device for parameter optimization
                    SumGradientsToTensorsInDefaultDevice();

                    // Optimize parameters
                    float lr = learningRate.GetCurrentLearningRate();
                    List<IWeightTensor> models = GetParametersFromDefaultDevice();
                    solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

                    costInTotal += cost;
                    avgCostPerWordInTotal = costInTotal / tgtWordCntsInTotal;
                    m_weightsUpdateCount++;
                    if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                    {
                        IterationDone(this, new CostEventArg()
                        {
                            LearningRate = lr,
                            CostPerWord = cost / tWordCnt,
                            AvgCostInTotal = avgCostPerWordInTotal,
                            Epoch = ep,
                            Update = m_weightsUpdateCount,
                            ProcessedSentencesInTotal = processedLineInTotal,
                            ProcessedWordsInTotal = srcWordCntsInTotal + tgtWordCntsInTotal,
                            StartDateTime = startDateTime
                        });
                    }

                    runNetworkSucceeded = true;
                }
                catch (Exception err)
                {
                    // Exceptions from per-device work may arrive wrapped in an AggregateException.
                    Exception recoverable = FindRecoverableException(err);
                    if (recoverable is OutOfMemoryException)
                    {
                        batchSplitFactor *= 2;
                        Logger.WriteLine($"Got an exception ('{recoverable.Message}'), so we increase batch split factor to {batchSplitFactor}, and retry it.");
                        if (batchSplitFactor >= sntPairBatchs[0].BatchSize)
                        {
                            Logger.WriteLine($"Batch split factor is larger than batch size, so ignore current mini-batch.");
                            break;
                        }
                    }
                    else if (recoverable is ArithmeticException)
                    {
                        // Log the inner exception's message, not the generic AggregateException one.
                        Logger.WriteLine($"Arithmetic exception: '{recoverable.Message}'");
                        break;
                    }
                    else
                    {
                        Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
                        // 'throw;' preserves the original stack trace ('throw err;' reset it).
                        throw;
                    }
                }
            }

            // Evaluate model every hour and save it if we could get a better one.
            TimeSpan ts = DateTime.Now - m_lastCheckPointDateTime;
            if (ts.TotalHours > 1.0)
            {
                CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                m_lastCheckPointDateTime = DateTime.Now;
            }

            sntPairBatchs.Clear();
        }
    }

    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

    // The end-of-epoch checkpoint is disabled in this version; only the hourly checkpoint above runs.
    // CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
    m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
}

/// <summary>
/// Returns the first <see cref="OutOfMemoryException"/> or <see cref="ArithmeticException"/> that
/// TrainOneEpoch can recover from. The search covers <paramref name="err"/> itself or, for an
/// <see cref="AggregateException"/>, its direct inner exceptions. Returns <c>null</c> if none is found.
/// </summary>
private static Exception FindRecoverableException(Exception err)
{
    if (err is AggregateException aggregate)
    {
        foreach (Exception inner in aggregate.InnerExceptions)
        {
            if (inner is OutOfMemoryException || inner is ArithmeticException)
            {
                return inner;
            }
        }

        return null;
    }

    return (err is OutOfMemoryException || err is ArithmeticException) ? err : null;
}
/// <summary>
/// Trains for one epoch over <paramref name="trainCorpus"/>. Mini-batches are collected until there is
/// one per device. Each device then runs forward and backward on its own batch in parallel, gradients
/// are summed onto the default device, one optimizer step is applied, and gradients are zeroed on all
/// devices.
/// </summary>
/// <remarks>
/// A trailing group with fewer batches than devices is not processed. A checkpoint is attempted after
/// every hour of training and once more at the end of the epoch. Exceptions thrown by the parallel
/// per-device work are not caught here.
/// </remarks>
/// <param name="ep">Epoch index, reported through IterationDone and the end-of-epoch log.</param>
/// <param name="trainCorpus">Training corpus, enumerated as mini-batches.</param>
/// <param name="validCorpus">Validation corpus passed to CreateCheckPoint.</param>
/// <param name="learningRate">Provides the learning rate for each update.</param>
/// <param name="solver">Optimizer applied to the parameters on the default device.</param>
/// <param name="metrics">Metrics passed to CreateCheckPoint.</param>
/// <param name="modelMetaData">Meta data passed to CreateCheckPoint.</param>
/// <param name="ForwardOnSingleDevice">Forward pass for one device; returns that device's cost.</param>
internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List<IMetric> metrics, IModelMetaData modelMetaData,
    Func<IComputeGraph, List<List<string>>, List<List<string>>, int, bool, float> ForwardOnSingleDevice)
{
    int linesSeen = 0;
    DateTime startDateTime = DateTime.Now;
    DateTime lastCheckPointDateTime = DateTime.Now;
    double costInTotal = 0.0;
    long srcWordsSeen = 0;
    long tgtWordsSeen = 0;
    double avgCostPerWordInTotal = 0.0;

    TensorAllocator.FreeMemoryAllDevices();
    Logger.WriteLine($"Start to process training corpus.");

    var pendingBatches = new List<SntPairBatch>();
    foreach (SntPairBatch incoming in trainCorpus)
    {
        pendingBatches.Add(incoming);
        if (pendingBatches.Count != m_deviceIds.Length)
        {
            // Keep collecting until every device has a batch.
            continue;
        }

        float stepCost = 0.0f;
        int stepTgtWords = 0;
        int stepLines = 0;

        // Broadcast the default device's weights to every other device.
        CopyWeightsFromDefaultDeviceToAllOtherDevices();

        // Forward and backward pass, one batch per device, in parallel.
        Parallel.For(0, m_deviceIds.Length, deviceIdx =>
        {
            SntPairBatch deviceBatch = pendingBatches[deviceIdx];

            var srcTokens = new List<List<string>>();
            var tgtTokens = new List<List<string>>();
            int srcLen = 0;
            int tgtLen = 0;
            for (int k = 0; k < deviceBatch.BatchSize; k++)
            {
                var pair = deviceBatch.SntPairs[k];
                srcTokens.Add(pair.SrcSnt.ToList());
                srcLen += pair.SrcSnt.Length;
                tgtTokens.Add(pair.TgtSnt.ToList());
                tgtLen += pair.TgtSnt.Length;
            }

            float deviceCost;
            using (IComputeGraph graph = CreateComputGraph(deviceIdx))
            {
                deviceCost = ForwardOnSingleDevice(graph, srcTokens, tgtTokens, deviceIdx, true);
                // Gradients are accumulated on this device's graph.
                graph.Backward();
            }

            // The counters are shared by all device workers.
            lock (locker)
            {
                stepCost += deviceCost;
                srcWordsSeen += srcLen;
                tgtWordsSeen += tgtLen;
                stepTgtWords += tgtLen;
                linesSeen += deviceBatch.BatchSize;
                stepLines += deviceBatch.BatchSize;
            }
        });

        // Gather gradients onto the default device and take one optimizer step there.
        SumGradientsToTensorsInDefaultDevice();
        float lr = learningRate.GetCurrentLearningRate();
        List<IWeightTensor> parameters = GetParametersFromDefaultDevice();
        solver.UpdateWeights(parameters, stepLines, lr, m_regc, m_weightsUpdateCount + 1);

        // Reset gradients everywhere before the next group.
        ZeroGradientOnAllDevices();

        costInTotal += stepCost;
        avgCostPerWordInTotal = costInTotal / tgtWordsSeen;
        m_weightsUpdateCount++;
        if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
        {
            var progress = new CostEventArg()
            {
                LearningRate = lr,
                CostPerWord = stepCost / stepTgtWords,
                AvgCostInTotal = avgCostPerWordInTotal,
                Epoch = ep,
                Update = m_weightsUpdateCount,
                ProcessedSentencesInTotal = linesSeen,
                ProcessedWordsInTotal = srcWordsSeen + tgtWordsSeen,
                StartDateTime = startDateTime
            };
            IterationDone(this, progress);
        }

        // Hourly checkpoint: evaluate and save the model if it improved.
        if ((DateTime.Now - lastCheckPointDateTime).TotalHours > 1.0)
        {
            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
            lastCheckPointDateTime = DateTime.Now;
        }

        pendingBatches.Clear();
    }

    Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");
    CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
    m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
}