Example #1
        /// <summary>
        /// Run the forward pass on a given single device
        /// </summary>
        /// <param name="g">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="srcSnts">A batch of input tokenized sentences in source side</param>
        /// <param name="tgtSnts">A batch of output tokenized sentences in target side. In training mode, it inputs target tokens, otherwise, it outputs target tokens generated by decoder</param>
        /// <param name="deviceIdIdx">The index of current device</param>
        /// <returns>The cost of forward part</returns>
        private float RunForwardOnSingleDevice(IComputeGraph g, List <List <string> > srcSnts, List <List <string> > tgtSnts, int deviceIdIdx, bool isTraining)
        {
            (IEncoder encoder, IWeightTensor srcEmbedding, IWeightTensor posEmbedding, FeedForwardLayer decoderFFLayer) = GetNetworksOnDeviceAt(deviceIdIdx);

            // Reset networks
            encoder.Reset(g.GetWeightFactory(), srcSnts.Count);


            List <int> originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
            int        seqLen             = srcSnts[0].Count;
            int        batchSize          = srcSnts.Count;

            // Encoding input source sentences
            IWeightTensor encOutput    = Encode(g, srcSnts, encoder, srcEmbedding, null, posEmbedding, originalSrcLengths);
            IWeightTensor ffLayer      = decoderFFLayer.Process(encOutput, batchSize, g);
            IWeightTensor ffLayerBatch = g.TransposeBatch(ffLayer, batchSize);

            float cost = 0.0f;

            using (IWeightTensor probs = g.Softmax(ffLayerBatch, runGradients: false, inPlace: true))
            {
                if (isTraining)
                {
                    //Calculate loss for each word in the batch
                    for (int k = 0; k < batchSize; k++)
                    {
                        for (int j = 0; j < seqLen; j++)
                        {
                            using (IWeightTensor probs_k_j = g.PeekRow(probs, k * seqLen + j, runGradients: false))
                            {
                                int   ix_targets_k_j = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSnts[k][j]);
                                float score_k        = probs_k_j.GetWeightAt(ix_targets_k_j);
                                cost += (float)-Math.Log(score_k);

                                probs_k_j.SetWeightAt(score_k - 1, ix_targets_k_j);
                            }
                        }

                        ////CRF part
                        //using (var probs_k = g.PeekRow(probs, k * seqLen, seqLen, runGradients: false))
                        //{
                        //    var weights_k = probs_k.ToWeightArray();
                        //    var crfOutput_k = m_crfDecoder.ForwardBackward(seqLen, weights_k);

                        //    int[] trueTags = new int[seqLen];
                        //    for (int j = 0; j < seqLen; j++)
                        //    {
                        //        trueTags[j] = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSnts[k][j]);
                        //    }
                        //    m_crfDecoder.UpdateBigramTransition(seqLen, crfOutput_k, trueTags);
                        //}
                    }

                    ffLayerBatch.CopyWeightsToGradients(probs);
                }
                else
                {
                    // CRF decoder
                    //for (int k = 0; k < batchSize; k++)
                    //{
                    //    //CRF part
                    //    using (var probs_k = g.PeekRow(probs, k * seqLen, seqLen, runGradients: false))
                    //    {
                    //        var weights_k = probs_k.ToWeightArray();

                    //        var crfOutput_k = m_crfDecoder.DecodeNBestCRF(weights_k, seqLen, 1);
                    //        var targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(crfOutput_k[0].ToList());
                    //        tgtSnts.Add(targetWords);
                    //    }
                    //}


                    // Output "i"th target word
                    int[]         targetIdx   = g.Argmax(probs, 1);
                    List <string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());

                    for (int k = 0; k < batchSize; k++)
                    {
                        tgtSnts[k] = targetWords.GetRange(k * seqLen, seqLen);
                    }
                }
            }

            return cost;
        }
        public void Valid(ParallelCorpus validCorpus, List <IMetric> metrics)
        {
            RunValid(validCorpus, RunForwardOnSingleDevice, metrics, true);
        }
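
The training branch above uses the standard softmax cross-entropy shortcut: the per-token loss is -log p(target), and the gradient with respect to the pre-softmax output equals the probability vector with 1 subtracted at the target index, which is why probs_k_j is written back with score_k - 1 before CopyWeightsToGradients. A minimal stand-alone sketch of that identity on plain arrays (illustration only, not using Seq2SeqSharp types):

        using System;
        using System.Linq;

        static class SoftmaxCrossEntropyDemo
        {
            // Loss and gradient w.r.t. the logits for a single position.
            // Mirrors the "write back score - 1 at the target index" step above.
            public static (float Loss, float[] Grad) LossAndGrad(float[] logits, int targetIdx)
            {
                float   max   = logits.Max();                                        // numerical stability
                float[] exp   = logits.Select(v => (float)Math.Exp(v - max)).ToArray();
                float   sum   = exp.Sum();
                float[] probs = exp.Select(v => v / sum).ToArray();

                float loss = (float)-Math.Log(probs[targetIdx]);

                float[] grad = (float[])probs.Clone();
                grad[targetIdx] -= 1.0f;                                             // dL/dlogit = p, minus 1 at the target

                return (loss, grad);
            }
        }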
Example #3
        /// <summary>
        /// Build vocabulary from training corpus
        /// </summary>
        /// <param name="trainCorpus"></param>
        /// <param name="minFreq"></param>
        public Vocab(ParallelCorpus trainCorpus, int minFreq = 1)
        {
            Logger.WriteLine($"Building vocabulary from given training corpus.");
            // count up all words
            Dictionary <string, int> s_d = new Dictionary <string, int>();
            Dictionary <string, int> t_d = new Dictionary <string, int>();

            CreateIndex();

            foreach (SntPairBatch sntPairBatch in trainCorpus)
            {
                foreach (SntPair sntPair in sntPairBatch.SntPairs)
                {
                    var item = sntPair.SrcSnt;
                    for (int i = 0, n = item.Length; i < n; i++)
                    {
                        var txti = item[i];
                        if (s_d.ContainsKey(txti))
                        {
                            s_d[txti] += 1;
                        }
                        else
                        {
                            s_d.Add(txti, 1);
                        }
                    }

                    var item2 = sntPair.TgtSnt;
                    for (int i = 0, n = item2.Length; i < n; i++)
                    {
                        var txti = item2[i];
                        if (t_d.ContainsKey(txti))
                        {
                            t_d[txti] += 1;
                        }
                        else
                        {
                            t_d.Add(txti, 1);
                        }
                    }
                }
            }


            var q = 3;

            foreach (var ch in s_d)
            {
                if (ch.Value >= minFreq)
                {
                    // add word to vocab
                    SrcWordToIndex[ch.Key] = q;
                    m_srcIndexToWord[q]    = ch.Key;
                    m_srcVocab.Add(ch.Key);
                    q++;
                }
            }
            Logger.WriteLine($"Source language Max term id = '{q}'");


            q = 3;
            foreach (var ch in t_d)
            {
                if (ch.Value >= minFreq)
                {
                    // add word to vocab
                    TgtWordToIndex[ch.Key] = q;
                    m_tgtIndexToWord[q]    = ch.Key;
                    m_tgtVocab.Add(ch.Key);
                    q++;
                }
            }

            Logger.WriteLine($"Target language Max term id = '{q}'");
        }
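
A minimal usage sketch for this constructor: it counts word frequencies over the whole training corpus and keeps only words with frequency >= minFreq. Word ids start at 3, which suggests ids 0-2 are reserved for special tokens (an assumption based on the initial value of q). The corpus path, language codes, and numeric settings below are placeholders; the positional arguments follow the ParallelCorpus construction shown in the console examples below.

        // Placeholder path/languages/settings; argument order mirrors the Main examples:
        // (corpus path, source language, target language, batch size, shuffle block size, max sentence length)
        ParallelCorpus trainCorpus = new ParallelCorpus("./corpus/train", "enu", "chs", 32, -1, 110);

        // Keep only words that occur at least twice in the training corpus
        Vocab vocab = new Vocab(trainCorpus, minFreq: 2);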
Example #4
        private static void Main(string[] args)
        {
            Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{GetTimeStamp(DateTime.Now)}.log";
            ShowOptions(args);

            //Parse command line
            Options   opts      = new Options();
            ArgParser argParser = new ArgParser(args, opts);

            if (string.IsNullOrEmpty(opts.ConfigFilePath) == false)
            {
                Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                opts = JsonConvert.DeserializeObject <Options>(File.ReadAllText(opts.ConfigFilePath));
            }

            AttentionSeq2Seq   ss            = null;
            ProcessorTypeEnums processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
            EncoderTypeEnums   encoderType   = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
            DecoderTypeEnums   decoderType   = (DecoderTypeEnums)Enum.Parse(typeof(DecoderTypeEnums), opts.DecoderType);
            ModeEnums          mode          = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

            //Parse device ids from options
            int[] deviceIds = opts.DeviceIds.Split(',').Select(x => int.Parse(x)).ToArray();
            if (mode == ModeEnums.Train)
            {
                // Load train corpus
                ParallelCorpus trainCorpus = new ParallelCorpus(opts.TrainCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);
                // Load valid corpus
                ParallelCorpus validCorpus = string.IsNullOrEmpty(opts.ValidCorpusPath) ? null : new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

                // Create learning rate
                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                // Create optimizer
                AdamOptimizer optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

                // Create metrics
                List <IMetric> metrics = new List <IMetric>
                {
                    new BleuMetric(),
                    new LengthRatioMetric()
                };


                if (!String.IsNullOrEmpty(opts.ModelFilePath) && File.Exists(opts.ModelFilePath))
                {
                    //Incremental training
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, dropoutRatio: opts.DropoutRatio, deviceIds: deviceIds,
                                              isSrcEmbTrainable: opts.IsSrcEmbeddingTrainable, isTgtEmbTrainable: opts.IsTgtEmbeddingTrainable, isEncoderTrainable: opts.IsEncoderTrainable, isDecoderTrainable: opts.IsDecoderTrainable,
                                              maxTgtSntSize: opts.MaxSentLength);
                }
                else
                {
                    // Load or build vocabulary
                    Vocab vocab = null;
                    if (!string.IsNullOrEmpty(opts.SrcVocab) && !string.IsNullOrEmpty(opts.TgtVocab))
                    {
                        // Vocabulary files are specified, so we load them
                        vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
                    }
                    else
                    {
                        // No vocabulary files were specified, so build the vocabulary from the training corpus
                        vocab = new Vocab(trainCorpus);
                    }

                    //New training
                    ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                                              srcEmbeddingFilePath: opts.SrcEmbeddingModelFilePath, tgtEmbeddingFilePath: opts.TgtEmbeddingModelFilePath, vocab: vocab, modelFilePath: opts.ModelFilePath,
                                              dropoutRatio: opts.DropoutRatio, processorType: processorType, deviceIds: deviceIds, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType, decoderType: decoderType,
                                              maxTgtSntSize: opts.MaxSentLength, enableCoverageModel: opts.EnableCoverageModel);
                }

                // Add event handler for monitoring
                ss.IterationDone += ss_IterationDone;

                // Kick off training
                ss.Train(maxTrainingEpoch: opts.MaxEpochNum, trainCorpus: trainCorpus, validCorpus: validCorpus, learningRate: learningRate, optimizer: optimizer, metrics: metrics);
            }
            else if (mode == ModeEnums.Valid)
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

                // Create metrics
                List <IMetric> metrics = new List <IMetric>
                {
                    new BleuMetric(),
                    new LengthRatioMetric()
                };

                // Load valid corpus
                ParallelCorpus validCorpus = new ParallelCorpus(opts.ValidCorpusPath, opts.SrcLang, opts.TgtLang, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

                ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
                ss.Valid(validCorpus: validCorpus, metrics: metrics);
            }
            else if (mode == ModeEnums.Test)
            {
                Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                //Test trained model
                ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);

                List <string> outputLines     = new List <string>();
                string[]      data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
                foreach (string line in data_sents_raw1)
                {
                    if (opts.BeamSearch > 1)
                    {
                        // The call below supports beam search
                        List <List <string> > outputWordsList = ss.Predict(line.ToLower().Trim().Split(' ').ToList(), opts.BeamSearch);
                        outputLines.AddRange(outputWordsList.Select(x => string.Join(" ", x)));
                    }
                    else
                    {
                        var outputTokensBatch = ss.Test(ParallelCorpus.ConstructInputTokens(line.ToLower().Trim().Split(' ').ToList()));
                        outputLines.AddRange(outputTokensBatch.Select(x => String.Join(" ", x)));
                    }
                }

                File.WriteAllLines(opts.OutputTestFile, outputLines);
            }
            else if (mode == ModeEnums.VisualizeNetwork)
            {
                ss = new AttentionSeq2Seq(embeddingDim: opts.WordVectorSize, hiddenDim: opts.HiddenSize, encoderLayerDepth: opts.EncoderLayerDepth, decoderLayerDepth: opts.DecoderLayerDepth,
                                          vocab: new Vocab(), srcEmbeddingFilePath: null, tgtEmbeddingFilePath: null, modelFilePath: opts.ModelFilePath, dropoutRatio: opts.DropoutRatio,
                                          processorType: processorType, deviceIds: new int[1] { 0 }, multiHeadNum: opts.MultiHeadNum, encoderType: encoderType, decoderType: decoderType, enableCoverageModel: opts.EnableCoverageModel);

                ss.VisualizeNeuralNetwork(opts.VisualizeNNFilePath);
            }
            else if (mode == ModeEnums.DumpVocab)
            {
                ss = new AttentionSeq2Seq(modelFilePath: opts.ModelFilePath, processorType: processorType, deviceIds: deviceIds);
                ss.DumpVocabToFiles(opts.SrcVocab, opts.TgtVocab);
            }
            else
            {
                argParser.Usage();
            }
        }
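
When ConfigFilePath is set, the parsed options object is replaced by a JSON-deserialized Options instance, so a config file can carry the same settings as the command line. A hedged sketch of such a file follows: the property names are taken from the usages of opts above, while all values, and the enum member spellings such as "GPU" and "Transformer", are placeholders/assumptions.

        {
            "TaskName": "Train",
            "ProcessorType": "GPU",
            "EncoderType": "Transformer",
            "DecoderType": "Transformer",
            "DeviceIds": "0,1",
            "TrainCorpusPath": "./corpus/train",
            "ValidCorpusPath": "./corpus/valid",
            "SrcLang": "enu",
            "TgtLang": "chs",
            "BatchSize": 128,
            "MaxSentLength": 110,
            "WordVectorSize": 512,
            "HiddenSize": 512,
            "EncoderLayerDepth": 6,
            "DecoderLayerDepth": 6,
            "MultiHeadNum": 8,
            "DropoutRatio": 0.1,
            "StartLearningRate": 0.0006,
            "WarmUpSteps": 8000,
            "ModelFilePath": "./models/seq2seq.model"
        }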
        /// <summary>
        /// Given an input sentence, generate the output sentence with the seq2seq model using beam search
        /// </summary>
        /// <param name="input"></param>
        /// <param name="beamSearchSize"></param>
        /// <param name="maxOutputLength"></param>
        /// <returns></returns>
        public List <List <string> > Predict(List <string> input, int beamSearchSize = 1, int maxOutputLength = 100)
        {
            (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding) = GetNetworksOnDeviceAt(-1);
            List <List <string> > inputSeqs = ParallelCorpus.ConstructInputTokens(input);
            int batchSize = 1; // Prediction with beam search currently supports only one sentence per call

            IComputeGraph    g          = CreateComputGraph(m_defaultDeviceId, needBack: false);
            AttentionDecoder rnnDecoder = decoder as AttentionDecoder;

            encoder.Reset(g.GetWeightFactory(), batchSize);
            rnnDecoder.Reset(g.GetWeightFactory(), batchSize);

            // Construct beam search status list
            List <BeamSearchStatus> bssList = new List <BeamSearchStatus>();
            BeamSearchStatus        bss     = new BeamSearchStatus();

            bss.OutputIds.Add((int)SENTTAGS.START);
            bss.CTs = rnnDecoder.GetCTs();
            bss.HTs = rnnDecoder.GetHTs();
            bssList.Add(bss);

            IWeightTensor             encodedWeightMatrix = Encode(g, inputSeqs, encoder, srcEmbedding, null, posEmbedding, null);
            AttentionPreProcessResult attPreProcessResult = rnnDecoder.PreProcess(encodedWeightMatrix, batchSize, g);

            List <BeamSearchStatus> newBSSList = new List <BeamSearchStatus>();
            bool finished     = false;
            int  outputLength = 0;

            while (finished == false && outputLength < maxOutputLength)
            {
                finished = true;
                for (int i = 0; i < bssList.Count; i++)
                {
                    bss = bssList[i];
                    if (bss.OutputIds[bss.OutputIds.Count - 1] == (int)SENTTAGS.END)
                    {
                        newBSSList.Add(bss);
                    }
                    else if (bss.OutputIds.Count > maxOutputLength)
                    {
                        newBSSList.Add(bss);
                    }
                    else
                    {
                        finished = false;
                        int ix_input = bss.OutputIds[bss.OutputIds.Count - 1];
                        rnnDecoder.SetCTs(bss.CTs);
                        rnnDecoder.SetHTs(bss.HTs);

                        IWeightTensor x       = g.PeekRow(tgtEmbedding, ix_input);
                        IWeightTensor eOutput = rnnDecoder.Decode(x, attPreProcessResult, batchSize, g);
                        using (IWeightTensor probs = g.Softmax(eOutput))
                        {
                            List <int> preds = probs.GetTopNMaxWeightIdx(beamSearchSize);
                            for (int j = 0; j < preds.Count; j++)
                            {
                                BeamSearchStatus newBSS = new BeamSearchStatus();
                                newBSS.OutputIds.AddRange(bss.OutputIds);
                                newBSS.OutputIds.Add(preds[j]);

                                newBSS.CTs = rnnDecoder.GetCTs();
                                newBSS.HTs = rnnDecoder.GetHTs();

                                float score = probs.GetWeightAt(preds[j]);
                                newBSS.Score  = bss.Score;
                                newBSS.Score += (float)(-Math.Log(score));

                                //var lengthPenalty = Math.Pow((5.0f + newBSS.OutputIds.Count) / 6, 0.6);
                                //newBSS.Score /= (float)lengthPenalty;

                                newBSSList.Add(newBSS);
                            }
                        }
                    }
                }

                bssList = BeamSearch.GetTopNBSS(newBSSList, beamSearchSize);
                newBSSList.Clear();

                outputLength++;
            }

            // Convert output target word ids to real string
            List <List <string> > results = new List <List <string> >();

            for (int i = 0; i < bssList.Count; i++)
            {
                results.Add(m_modelMetaData.Vocab.ConvertTargetIdsToString(bssList[i].OutputIds));
            }

            return results;
        }
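
A minimal usage sketch for Predict; the model path, processor enum member, and input sentence are placeholders, and the constructor arguments mirror the Test branch of Main above. Beam search here operates on a single tokenized sentence per call.

        // Placeholder model path and settings; ProcessorTypeEnums.GPU is an assumed enum member.
        AttentionSeq2Seq ss = new AttentionSeq2Seq(modelFilePath: "./models/seq2seq.model",
                                                   processorType: ProcessorTypeEnums.GPU,
                                                   deviceIds: new int[] { 0 });

        // Up to 5 candidate output sentences, scored by accumulated -log(probability)
        List<List<string>> hyps = ss.Predict("how are you ?".Split(' ').ToList(), beamSearchSize: 5);

        foreach (List<string> hyp in hyps)
        {
            Console.WriteLine(string.Join(" ", hyp));
        }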
        /// <summary>
        /// Decode output sentences with the attention LSTM decoder
        /// </summary>
        /// <param name="outputSnts">In training mode, the gold target sentences; otherwise, the target sentences generated by the decoder</param>
        /// <param name="g">The computing graph for the current device</param>
        /// <param name="encOutputs">Encoder output for the current batch</param>
        /// <param name="decoder">The attention LSTM decoder</param>
        /// <param name="tgtEmbedding">Embedding matrix of the target vocabulary</param>
        /// <param name="batchSize">Number of sentences in the batch</param>
        /// <param name="isTraining">Whether decoding runs in training mode</param>
        /// <returns>The cost of the decoding part</returns>
        private float DecodeAttentionLSTM(List <List <string> > outputSnts, IComputeGraph g, IWeightTensor encOutputs, AttentionDecoder decoder, IWeightTensor tgtEmbedding, int batchSize, bool isTraining = true)
        {
            float cost = 0.0f;

            int[] ix_inputs = new int[batchSize];
            for (int i = 0; i < ix_inputs.Length; i++)
            {
                ix_inputs[i] = m_modelMetaData.Vocab.GetTargetWordIndex(outputSnts[i][0]);
            }

            // Initialize variables according to the current mode
            List <int>    originalOutputLengths = isTraining ? ParallelCorpus.PadSentences(outputSnts) : null;
            int           seqLen       = isTraining ? outputSnts[0].Count : 64;
            float         dropoutRatio = isTraining ? m_dropoutRatio : 0.0f;
            HashSet <int> setEndSentId = isTraining ? null : new HashSet <int>();

            // Pre-process for attention model
            AttentionPreProcessResult attPreProcessResult = decoder.PreProcess(encOutputs, batchSize, g);

            for (int i = 1; i < seqLen; i++)
            {
                //Get embeddings for all sentences in the batch at position i
                List <IWeightTensor> inputs = new List <IWeightTensor>();
                for (int j = 0; j < batchSize; j++)
                {
                    inputs.Add(g.PeekRow(tgtEmbedding, ix_inputs[j]));
                }
                IWeightTensor inputsM = g.ConcatRows(inputs);

                //Decode output sentence at position i
                IWeightTensor eOutput = decoder.Decode(inputsM, attPreProcessResult, batchSize, g);

                //Softmax for output
                using (IWeightTensor probs = g.Softmax(eOutput, runGradients: false, inPlace: true))
                {
                    if (isTraining)
                    {
                        //Calculate loss for each word in the batch
                        for (int k = 0; k < batchSize; k++)
                        {
                            using (IWeightTensor probs_k = g.PeekRow(probs, k, runGradients: false))
                            {
                                int   ix_targets_k = m_modelMetaData.Vocab.GetTargetWordIndex(outputSnts[k][i]);
                                float score_k      = probs_k.GetWeightAt(ix_targets_k);
                                if (i < originalOutputLengths[k])
                                {
                                    var lcost = (float)-Math.Log(score_k);
                                    if (float.IsNaN(lcost))
                                    {
                                        throw new ArithmeticException($"Score = '{score_k}' Cost = Nan at index '{i}' word '{outputSnts[k][i]}', Output Sentence = '{String.Join(" ", outputSnts[k])}'");
                                    }
                                    else
                                    {
                                        cost += lcost;
                                    }
                                }

                                probs_k.SetWeightAt(score_k - 1, ix_targets_k);
                                ix_inputs[k] = ix_targets_k;
                            }
                        }
                        eOutput.CopyWeightsToGradients(probs);
                    }
                    //if (isTraining)
                    //{
                    //    //Calculate loss for each word in the batch
                    //    int[] targetIds = new int[batchSize];
                    //    int ids = 0;
                    //    for (int k = 0; k < batchSize; k++)
                    //    {
                    //        int targetsId_k = m_modelMetaData.Vocab.GetTargetWordIndex(outputSnts[k][i]);
                    //        targetIds[ids] = i < originalOutputLengths[k] ? targetsId_k : -1;
                    //        ix_inputs[k] = targetsId_k;

                    //        ids++;
                    //    }

                    //    cost += g.UpdateCost(probs, targetIds);
                    //    eOutput.CopyWeightsToGradients(probs);

                    //}
                    else
                    {
                        // Output "i"th target word
                        int[]         targetIdx   = g.Argmax(probs, 1);
                        List <string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());
                        for (int j = 0; j < targetWords.Count; j++)
                        {
                            if (setEndSentId.Contains(j) == false)
                            {
                                outputSnts[j].Add(targetWords[j]);

                                if (targetWords[j] == ParallelCorpus.EOS)
                                {
                                    setEndSentId.Add(j);
                                }
                            }
                        }

                        if (setEndSentId.Count == batchSize)
                        {
                            // All target sentences in current batch are finished, so we exit.
                            break;
                        }

                        ix_inputs = targetIdx;
                    }
                }
            }

            return cost;
        }
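
Padding handling above relies on ParallelCorpus.PadSentences, which is assumed to right-pad every sentence in the batch to the same length and to return the original lengths; the "i < originalOutputLengths[k]" check then keeps padded positions out of the cost. A stand-alone sketch of that masking on plain arrays (illustration only):

        using System.Collections.Generic;

        static class MaskedLossDemo
        {
            // negLogProbs[k][i] is -log p(target) for sentence k at position i (after padding).
            // Positions at or beyond the original sentence length are padding and are skipped,
            // mirroring the "i < originalOutputLengths[k]" check above.
            public static float MaskedBatchLoss(float[][] negLogProbs, List<int> originalLengths)
            {
                float cost = 0.0f;
                for (int k = 0; k < negLogProbs.Length; k++)
                {
                    for (int i = 0; i < originalLengths[k]; i++)
                    {
                        cost += negLogProbs[k][i];
                    }
                }
                return cost;
            }
        }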
        private float DecodeTransformer(List <List <string> > tgtSeqs, IComputeGraph g, IWeightTensor encOutputs, TransformerDecoder decoder,
                                        IWeightTensor tgtEmbedding, IWeightTensor posEmbedding, int batchSize, int deviceId, List <int> srcOriginalLenghts, bool isTraining = true)
        {
            float cost = 0.0f;

            var tgtOriginalLengths = ParallelCorpus.PadSentences(tgtSeqs);
            int tgtSeqLen          = tgtSeqs[0].Count;
            int srcSeqLen          = encOutputs.Rows / batchSize;

            using (IWeightTensor srcTgtMask = MaskUtils.BuildSrcTgtMask(g, srcSeqLen, tgtSeqLen, tgtOriginalLengths, srcOriginalLenghts, deviceId))
            {
                using (IWeightTensor tgtSelfTriMask = MaskUtils.BuildPadSelfTriMask(g, tgtSeqLen, tgtOriginalLengths, deviceId))
                {
                    List <IWeightTensor> inputs = new List <IWeightTensor>();
                    for (int i = 0; i < batchSize; i++)
                    {
                        for (int j = 0; j < tgtSeqLen; j++)
                        {
                            int ix_targets_k = m_modelMetaData.Vocab.GetTargetWordIndex(tgtSeqs[i][j], logUnk: true);

                            var emb = g.PeekRow(tgtEmbedding, ix_targets_k, runGradients: j < tgtOriginalLengths[i]);

                            inputs.Add(emb);
                        }
                    }

                    IWeightTensor inputEmbs = inputs.Count > 1 ? g.ConcatRows(inputs) : inputs[0];

                    inputEmbs = AddPositionEmbedding(g, posEmbedding, batchSize, tgtSeqLen, inputEmbs);

                    IWeightTensor decOutput = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g);

                    using (IWeightTensor probs = g.Softmax(decOutput, runGradients: false, inPlace: true))
                    {
                        if (isTraining)
                        {
                            var leftShiftInputSeqs = ParallelCorpus.LeftShiftSnts(tgtSeqs, ParallelCorpus.EOS);
                            for (int i = 0; i < batchSize; i++)
                            {
                                for (int j = 0; j < tgtSeqLen; j++)
                                {
                                    using (IWeightTensor probs_i_j = g.PeekRow(probs, i * tgtSeqLen + j, runGradients: false))
                                    {
                                        if (j < tgtOriginalLengths[i])
                                        {
                                            int   ix_targets_i_j = m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true);
                                            float score_i_j      = probs_i_j.GetWeightAt(ix_targets_i_j);

                                            cost += (float)-Math.Log(score_i_j);

                                            probs_i_j.SetWeightAt(score_i_j - 1, ix_targets_i_j);
                                        }
                                        else
                                        {
                                            probs_i_j.CleanWeight();
                                        }
                                    }
                                }
                            }

                            decOutput.CopyWeightsToGradients(probs);
                        }
                        //if (isTraining)
                        //{
                        //    var leftShiftInputSeqs = ParallelCorpus.LeftShiftSnts(tgtSeqs, ParallelCorpus.EOS);
                        //    int[] targetIds = new int[batchSize * tgtSeqLen];
                        //    int ids = 0;
                        //    for (int i = 0; i < batchSize; i++)
                        //    {
                        //        for (int j = 0; j < tgtSeqLen; j++)
                        //        {
                        //            targetIds[ids] = j < tgtOriginalLengths[i] ? m_modelMetaData.Vocab.GetTargetWordIndex(leftShiftInputSeqs[i][j], logUnk: true) : -1;
                        //            ids++;
                        //        }
                        //    }

                        //    cost += g.UpdateCost(probs, targetIds);
                        //    decOutput.CopyWeightsToGradients(probs);
                        //}
                        else
                        {
                            // Output "i"th target word
                            int[]         targetIdx   = g.Argmax(probs, 1);
                            List <string> targetWords = m_modelMetaData.Vocab.ConvertTargetIdsToString(targetIdx.ToList());

                            for (int i = 0; i < batchSize; i++)
                            {
                                tgtSeqs[i].Add(targetWords[i * tgtSeqLen + tgtSeqLen - 1]);
                            }
                        }
                    }
                }
            }



            return cost;
        }
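
During training, DecodeTransformer scores the decoder output at each position j against the token at position j+1 by left-shifting the target sequences with ParallelCorpus.LeftShiftSnts. A minimal sketch of what such a left shift does, assuming the vacated last position is filled with EOS (illustration only, not the library implementation; the literal token strings in the comment are also illustrative):

        using System.Collections.Generic;
        using System.Linq;

        static class LeftShiftDemo
        {
            // ["<s>", "hello", "world"]  ->  ["hello", "world", "</s>"]
            public static List<List<string>> LeftShift(List<List<string>> snts, string eos)
            {
                List<List<string>> shifted = new List<List<string>>();
                foreach (List<string> snt in snts)
                {
                    List<string> s = snt.Skip(1).ToList();
                    s.Add(eos);    // fill the vacated last position with the end-of-sentence token
                    shifted.Add(s);
                }
                return shifted;
            }
        }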
        /// <summary>
        /// Run the forward pass on a given single device
        /// </summary>
        /// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
        /// <param name="srcSnts">A batch of input tokenized sentences in source side</param>
        /// <param name="tgtSnts">A batch of output tokenized sentences in target side</param>
        /// <param name="deviceIdIdx">The index of current device</param>
        /// <returns>The cost of forward part</returns>
        private float RunForwardOnSingleDevice(IComputeGraph computeGraph, List <List <string> > srcSnts, List <List <string> > tgtSnts, int deviceIdIdx, bool isTraining)
        {
            (IEncoder encoder, IDecoder decoder, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor posEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

            // Reset networks
            encoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);
            decoder.Reset(computeGraph.GetWeightFactory(), srcSnts.Count);


            List <int>    originalSrcLengths = ParallelCorpus.PadSentences(srcSnts);
            int           srcSeqPaddedLen    = srcSnts[0].Count;
            int           batchSize          = srcSnts.Count;
            IWeightTensor srcSelfMask        = m_shuffleType == ShuffleEnums.NoPaddingInSrc ? null : MaskUtils.BuildPadSelfMask(computeGraph, srcSeqPaddedLen, originalSrcLengths, DeviceIds[deviceIdIdx]); // With NoPaddingInSrc shuffling, all source sentences in a mini-batch have the same length, so no source mask is needed.

            // Encoding input source sentences
            IWeightTensor encOutput = Encode(computeGraph, srcSnts, encoder, srcEmbedding, srcSelfMask, posEmbedding, originalSrcLengths);

            if (srcSelfMask != null)
            {
                srcSelfMask.Dispose();
            }

            // Generate output sentences with the decoder
            if (decoder is AttentionDecoder)
            {
                return DecodeAttentionLSTM(tgtSnts, computeGraph, encOutput, decoder as AttentionDecoder, tgtEmbedding, srcSnts.Count, isTraining);
            }
            else
            {
                if (isTraining)
                {
                    return DecodeTransformer(tgtSnts, computeGraph, encOutput, decoder as TransformerDecoder, tgtEmbedding, posEmbedding, batchSize, DeviceIds[deviceIdIdx], originalSrcLengths, isTraining);
                }
                else
                {
                    for (int i = 0; i < m_maxTgtSntSize; i++)
                    {
                        using (var g = computeGraph.CreateSubGraph($"TransformerDecoder_Step_{i}"))
                        {
                            DecodeTransformer(tgtSnts, g, encOutput, decoder as TransformerDecoder, tgtEmbedding, posEmbedding, batchSize, DeviceIds[deviceIdIdx], originalSrcLengths, isTraining);

                            bool allSntsEnd = true;
                            for (int j = 0; j < tgtSnts.Count; j++)
                            {
                                if (tgtSnts[j][tgtSnts[j].Count - 1] != ParallelCorpus.EOS)
                                {
                                    allSntsEnd = false;
                                    break;
                                }
                            }

                            if (allSntsEnd)
                            {
                                break;
                            }
                        }
                    }

                    RemoveDuplicatedEOS(tgtSnts);
                    return 0.0f;
                }
            }
        }
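
RemoveDuplicatedEOS is not shown in this excerpt. Because the incremental decoding loop above appends one token per step to every sentence in the batch, sentences that finished early can end with repeated EOS tokens; a plausible sketch of the cleanup, written here as an assumption rather than the actual library code:

        // Assumption: truncate each sentence right after its first EOS token.
        private static void RemoveDuplicatedEOS(List<List<string>> snts)
        {
            foreach (List<string> snt in snts)
            {
                int idx = snt.IndexOf(ParallelCorpus.EOS);
                if (idx >= 0 && idx < snt.Count - 1)
                {
                    snt.RemoveRange(idx + 1, snt.Count - idx - 1);
                }
            }
        }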
Example #9
        private static void Main(string[] args)
        {
            ShowOptions(args);

            Logger.LogFile = $"{nameof(SeqLabelConsole)}_{GetTimeStamp(DateTime.Now)}.log";

            //Parse command line
            var opts      = new Options();
            var argParser = new ArgParser(args, opts);

            if (string.IsNullOrEmpty(opts.ConfigFilePath) == false)
            {
                Logger.WriteLine($"Loading config file from '{opts.ConfigFilePath}'");
                opts = JsonConvert.DeserializeObject <Options>(File.ReadAllText(opts.ConfigFilePath));
            }


            SequenceLabel sl            = null;
            var           processorType = (ProcessorTypeEnums)Enum.Parse(typeof(ProcessorTypeEnums), opts.ProcessorType);
            var           encoderType   = (EncoderTypeEnums)Enum.Parse(typeof(EncoderTypeEnums), opts.EncoderType);
            var           mode          = (ModeEnums)Enum.Parse(typeof(ModeEnums), opts.TaskName);

            //Parse device ids from options
            var deviceIds = opts.DeviceIds.Split(',').Select(int.Parse).ToArray();

            if (mode == ModeEnums.Train)
            {
                // Load train corpus
                var trainCorpus = new SequenceLabelingCorpus(opts.TrainCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

                // Load valid corpus
                var validCorpus = string.IsNullOrEmpty(opts.ValidCorpusPath) ? null : new SequenceLabelingCorpus(opts.ValidCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

                // Load or build vocabulary
                Vocab vocab = null;
                if (!string.IsNullOrEmpty(opts.SrcVocab) && !string.IsNullOrEmpty(opts.TgtVocab))
                {
                    // Vocabulary files are specified, so we load them
                    vocab = new Vocab(opts.SrcVocab, opts.TgtVocab);
                }
                else
                {
                    // No vocabulary files were specified, so build the vocabulary from the training corpus
                    vocab = new Vocab(trainCorpus);
                }

                // Create learning rate
                ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount);

                // Create optimizer
                var optimizer = new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);

                // Create metrics
                var metrics = new List <IMetric>();
                foreach (var word in vocab.TgtVocab)
                {
                    metrics.Add(new SequenceLabelFscoreMetric(word));
                }

                if (File.Exists(opts.ModelFilePath) == false)
                {
                    //New training
                    sl = new SequenceLabel(opts.HiddenSize, opts.WordVectorSize, opts.EncoderLayerDepth, opts.MultiHeadNum,
                                           encoderType,
                                           opts.DropoutRatio, deviceIds: deviceIds, processorType: processorType, modelFilePath: opts.ModelFilePath, vocab: vocab, maxSntSize: opts.MaxSentLength);
                }
                else
                {
                    //Incremental training
                    Logger.WriteLine($"Loading model from '{opts.ModelFilePath}'...");
                    sl = new SequenceLabel(opts.ModelFilePath, processorType, deviceIds, opts.DropoutRatio, opts.MaxSentLength);
                }

                // Add event handler for monitoring
                sl.IterationDone += ss_IterationDone;

                // Kick off training
                sl.Train(opts.MaxEpochNum, trainCorpus, validCorpus, learningRate, optimizer: optimizer, metrics: metrics);
            }
            else if (mode == ModeEnums.Valid)
            {
                Logger.WriteLine($"Evaluate model '{opts.ModelFilePath}' by valid corpus '{opts.ValidCorpusPath}'");

                // Load valid corpus
                var validCorpus = new SequenceLabelingCorpus(opts.ValidCorpusPath, opts.BatchSize, opts.ShuffleBlockSize, opts.MaxSentLength);

                var vocab = new Vocab(validCorpus);
                // Create metrics
                var metrics = new List <IMetric>();
                foreach (var word in vocab.TgtVocab)
                {
                    metrics.Add(new SequenceLabelFscoreMetric(word));
                }

                sl = new SequenceLabel(opts.ModelFilePath, processorType, deviceIds, maxSntSize: opts.MaxSentLength);
                sl.Valid(validCorpus, metrics);
            }
            else if (mode == ModeEnums.Test)
            {
                Logger.WriteLine($"Test model '{opts.ModelFilePath}' by input corpus '{opts.InputTestFile}'");

                //Test trained model
                sl = new SequenceLabel(opts.ModelFilePath, processorType, deviceIds, maxSntSize: opts.MaxSentLength);

                var outputLines     = new List <string>();
                var data_sents_raw1 = File.ReadAllLines(opts.InputTestFile);
                foreach (var line in data_sents_raw1)
                {
                    var outputTokensBatch = sl.Test(ParallelCorpus.ConstructInputTokens(line.ToLower().Trim().Split(' ').ToList(), false));
                    outputLines.AddRange(outputTokensBatch.Select(x => string.Join(" ", x)));
                }

                File.WriteAllLines(opts.OutputTestFile, outputLines);
            }
            else
            {
                argParser.Usage();
            }
        }
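
As with Seq2SeqConsole above, a JSON config file can stand in for command-line arguments when ConfigFilePath is set. A shorter, hedged sketch for the sequence-labeling console; property names come from the usages of opts in this Main, while all values and the enum member spellings ("CPU", "BiLSTM") are placeholders/assumptions.

        {
            "TaskName": "Train",
            "ProcessorType": "CPU",
            "EncoderType": "BiLSTM",
            "DeviceIds": "0",
            "TrainCorpusPath": "./corpus/train.txt",
            "ValidCorpusPath": "./corpus/valid.txt",
            "BatchSize": 128,
            "MaxSentLength": 128,
            "WordVectorSize": 200,
            "HiddenSize": 300,
            "EncoderLayerDepth": 2,
            "MultiHeadNum": 8,
            "DropoutRatio": 0.1,
            "ModelFilePath": "./models/seqlabel.model"
        }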