/// <summary>
/// Runs the forward pass on a single device.
/// The source sentences are encoded, projected through the decoder feed-forward layer and
/// normalized with softmax. In training mode the cross-entropy cost against the target labels
/// is computed; otherwise the arg-max label at every position is written back into the batch's
/// target token lists.
/// </summary>
/// <param name="g">The computing graph for the current device. It gets created and passed by the framework.</param>
/// <param name="sntPairBatch">The batch of sentence pairs. Source tokens are read via <c>GetSrcTokens(0)</c> and
/// target tokens via <c>GetTgtTokens(0)</c>. In inference mode the target token lists are overwritten with the
/// predicted labels.</param>
/// <param name="deviceIdIdx">The index of the current device; used to fetch that device's networks.</param>
/// <param name="isTraining">True to compute the cross-entropy cost; false to decode predicted labels.</param>
/// <param name="decodingOptions">Not used by this implementation.</param>
/// <returns>A single-element list whose <see cref="NetworkResult"/> carries the cost (0 when not training) and,
/// in <c>Output</c>, the target token lists.</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph g, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    var srcSnts = sntPairBatch.GetSrcTokens(0);
    var tgtSnts = sntPairBatch.GetTgtTokens(0);

    (IEncoder encoder, IWeightTensor srcEmbedding, IWeightTensor posEmbedding, FeedForwardLayer decoderFFLayer) = GetNetworksOnDeviceAt(deviceIdIdx);

    // Reset networks for the current batch size
    encoder.Reset(g.GetWeightFactory(), srcSnts.Count);

    var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
    var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts);

    // Called for its effect on tgtSnts (the return value is discarded). It runs in both modes because
    // tgtSnts is used as the labels in training and is always the list returned in Output.
    BuildInTokens.PadSentences(tgtSnts);

    // After padding, the first source sentence's length is used as the sequence length of the whole batch.
    int seqLen = srcSnts[0].Count;
    int batchSize = srcSnts.Count;

    // Encoding input source sentences
    IWeightTensor encOutput = Encoder.Run(g, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbedding, null, srcTokensList, originalSrcLengths);
    IWeightTensor ffLayer = decoderFFLayer.Process(encOutput, batchSize, g);

    // ffLayer is not read after this point, so the softmax may overwrite it in place.
    IWeightTensor probs = g.Softmax(ffLayer, inPlace: true);

    float cost = 0.0f;
    if (isTraining)
    {
        // Label indices are only consumed by the loss, so the vocabulary lookup is skipped at inference time.
        var tgtTokensLists = m_modelMetaData.ClsVocab.GetWordIndex(tgtSnts);
        var tgtTokensTensor = g.CreateTokensTensor(tgtTokensLists);
        cost = g.CrossEntropyLoss(probs, tgtTokensTensor);
    }
    else
    {
        using var targetIdxTensor = g.Argmax(probs, 1);
        float[] targetIdx = targetIdxTensor.ToWeightArray();
        List<string> targetWords = m_modelMetaData.ClsVocab.ConvertIdsToString(targetIdx.ToList());

        // Predictions come back as one flat list; this slicing assumes batch-major order
        // (batchSize * seqLen entries, sentence k occupying [k * seqLen, (k + 1) * seqLen)).
        for (int k = 0; k < batchSize; k++)
        {
            tgtSnts[k] = targetWords.GetRange(k * seqLen, seqLen);
        }
    }

    NetworkResult nr = new NetworkResult
    {
        Cost = cost,
        Output = new List<List<List<string>>> { tgtSnts }
    };

    return new List<NetworkResult> { nr };
}