Ejemplo n.º 1
0
        /// <summary>
        /// Runs the forward pass on a single device: encodes the first source token group of the batch,
        /// projects the encoder output through the feed-forward layer and applies softmax.
        /// In training mode the cross-entropy cost against the target tokens is computed; otherwise the
        /// arg-max class of every position is converted back to a string and written into the target sentences.
        /// </summary>
        /// <param name="g">The computing graph used for every tensor operation in this pass</param>
        /// <param name="sntPairBatch">The batch of sentence pairs. Only token group 0 of the source and target sides is read</param>
        /// <param name="deviceIdIdx">The index of the device whose networks are fetched via <c>GetNetworksOnDeviceAt</c></param>
        /// <param name="isTraining">When true, the cost is computed from the target tokens; when false, the target sentences are overwritten with predicted tokens</param>
        /// <param name="decodingOptions">Decoding options. Not read by this method</param>
        /// <returns>A list holding a single <c>NetworkResult</c> with the cost (0 when not training) and the target sentences as its only output entry</returns>
        public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph g, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
        {
            var sourceSentences = sntPairBatch.GetSrcTokens(0);
            var targetSentences = sntPairBatch.GetTgtTokens(0);

            (IEncoder deviceEncoder, IWeightTensor sourceEmbedding, IWeightTensor positionEmbedding, FeedForwardLayer outputLayer) = GetNetworksOnDeviceAt(deviceIdIdx);

            // The encoder is reset for the size of the incoming batch before anything is fed through it.
            deviceEncoder.Reset(g.GetWeightFactory(), sourceSentences.Count);

            // Pad first, then map to vocabulary indices, so the indices are built from the padded sentences.
            // PadSentences' return value is kept as the original (pre-padding) source lengths - presumably it pads in place; confirm in BuildInTokens.
            var unpaddedSourceLengths = BuildInTokens.PadSentences(sourceSentences);
            var sourceTokenIds = m_modelMetaData.SrcVocab.GetWordIndex(sourceSentences);

            BuildInTokens.PadSentences(targetSentences);
            var targetTokenIds = m_modelMetaData.ClsVocab.GetWordIndex(targetSentences);

            int batchSize = sourceSentences.Count;
            int seqLen = sourceSentences[0].Count;

            // Encode the source sentences, project the result and turn it into probabilities.
            IWeightTensor encoded = Encoder.Run(g, sntPairBatch, deviceEncoder, m_modelMetaData, m_shuffleType, sourceEmbedding, positionEmbedding, null, sourceTokenIds, unpaddedSourceLengths);
            IWeightTensor projected = outputLayer.Process(encoded, batchSize, g);
            IWeightTensor probs = g.Softmax(projected, inPlace: true);

            float cost = 0.0f;
            if (isTraining)
            {
                var targetTensor = g.CreateTokensTensor(targetTokenIds);
                cost = g.CrossEntropyLoss(probs, targetTensor);
            }
            else
            {
                // Pick the most probable class for every position and map the ids back to strings.
                using var predictedTensor = g.Argmax(probs, 1);
                float[] predictedIds = predictedTensor.ToWeightArray();
                List<string> predictedWords = m_modelMetaData.ClsVocab.ConvertIdsToString(predictedIds.ToList());

                // The flat prediction list is consumed as consecutive runs of seqLen words, one run per sentence.
                for (int sentenceIdx = 0; sentenceIdx < batchSize; sentenceIdx++)
                {
                    targetSentences[sentenceIdx] = predictedWords.GetRange(sentenceIdx * seqLen, seqLen);
                }
            }

            NetworkResult result = new NetworkResult
            {
                Cost = cost,
                Output = new List<List<List<string>>> { targetSentences }
            };

            return new List<NetworkResult> { result };
        }