/// <summary>
/// Runs the forward part of the sentence-pair similarity network on a single device.
/// Source token groups 0 and 1 of the batch are encoded separately and then either compared
/// by cosine similarity (when the model's SimilarityType is "Continuous") or combined by
/// element-wise multiplication and classified through the feed-forward layer and softmax.
/// </summary>
/// <param name="computeGraph">The computing graph for current device. It gets created and passed by the framework</param>
/// <param name="sntPairBatch">A batch of tokenized sentence pairs. Source token groups 0 and 1 are the two inputs; target token group 0 is read only when training and holds the gold score or gold class label as its first token.</param>
/// <param name="deviceIdIdx">The index of current device</param>
/// <param name="isTraining">True to compute the cost and write gradients; false to produce outputs only.</param>
/// <param name="decodingOptions">Not used by this implementation.</param>
/// <returns>
/// A list holding a single NetworkResult. When training, its Cost is the batch-averaged cost;
/// otherwise its Output[0][k] holds one string (the score or the predicted label) for sample k.
/// </returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, int deviceIdIdx, bool isTraining, DecodingOptions decodingOptions)
{
    int batchSize = sntPairBatch.BatchSize;
    float cost = 0.0f;
    var nrs = new List<NetworkResult>();
    var nr = new NetworkResult { Output = new List<List<List<string>>>() };

    (IEncoder encoder, IWeightTensor srcEmbedding, IFeedForwardLayer encoderFFLayer, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(deviceIdIdx);

    // We only check the cache at inference time on CPU.
    // We always run the encoder network during training time or when using GPUs.
    bool useCache = !isTraining && (m_options.ProcessorType == ProcessorTypeEnums.CPU);

    // Encodes one source token group, going through the memory cache when it is enabled.
    IWeightTensor EncodeSourceGroup(int groupIdx)
    {
        if (!useCache)
        {
            return Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, groupIdx); // output shape: [batch_size, dim]
        }

        string cacheKey = GenerateCacheKey(sntPairBatch.GetSrcTokens(groupIdx));
        if (!m_memoryCache.TryGetValue(cacheKey, out IWeightTensor groupEncoding))
        {
            groupEncoding = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbedding, segmentEmbedding, groupIdx); // output shape: [batch_size, dim]

            // A copy (not the graph-owned tensor itself) is what gets stored in the cache.
            var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
            m_memoryCache.Set(cacheKey, groupEncoding.CopyWeightsRef($"cache_{groupEncoding.Name}", false), cacheEntryOptions);
        }

        return groupEncoding;
    }

    IWeightTensor encOutput1 = EncodeSourceGroup(0);
    IWeightTensor encOutput2 = EncodeSourceGroup(1);

    if (m_modelMetaData.SimilarityType.Equals("Continuous", StringComparison.InvariantCultureIgnoreCase))
    {
        // Cosine similarity: sum(a * b) * rsqrt(sum(a * a) * sum(b * b)), reduced along dimension 1.
        var w12 = computeGraph.EltMul(encOutput1, encOutput2);
        w12 = computeGraph.Sum(w12, 1);

        var w1 = computeGraph.EltMul(encOutput1, encOutput1);
        w1 = computeGraph.Sum(w1, 1);

        var w2 = computeGraph.EltMul(encOutput2, encOutput2);
        w2 = computeGraph.Sum(w2, 1);

        var n12 = computeGraph.EltMul(w1, w2);
        n12 = computeGraph.Rsqrt(n12);

        var probs = computeGraph.EltMul(w12, n12);

        if (isTraining)
        {
            var tgtSnts = sntPairBatch.GetTgtTokens(0);
            for (int k = 0; k < batchSize; k++)
            {
                // Get the golden similarity score from the target side. It is machine-written data,
                // so it is parsed with the invariant culture ("0.85" must not depend on the OS locale).
                float golden_score_k = float.Parse(tgtSnts[k][0], System.Globalization.CultureInfo.InvariantCulture);
                float score_k = probs.GetWeightAt(new long[] { k, 0 });

                // Overwrite the score with its signed error; it is copied into the gradients below.
                probs.SetWeightAt(score_k - golden_score_k, new long[] { k, 0 });
                cost += (float)Math.Abs(score_k - golden_score_k);
            }

            probs.CopyWeightsToGradients(probs);
            nr.Cost = cost / batchSize;
        }
        else
        {
            nr.Output.Add(new List<List<string>>());
            for (int k = 0; k < batchSize; k++)
            {
                float score_k = probs.GetWeightAt(new long[] { k, 0 });

                // Emit the score in a locale-independent format so downstream readers can parse it.
                nr.Output[0].Add(new List<string>());
                nr.Output[0][k].Add(score_k.ToString(System.Globalization.CultureInfo.InvariantCulture));
            }
        }
    }
    else
    {
        // Classification: combine both encodings element-wise, then project and normalize with softmax.
        IWeightTensor encOutput = computeGraph.EltMul(encOutput1, encOutput2);
        IWeightTensor ffLayer = encoderFFLayer.Process(encOutput, batchSize, computeGraph);

        using (IWeightTensor probs = computeGraph.Softmax(ffLayer, runGradients: false, inPlace: true))
        {
            if (isTraining)
            {
                var tgtSnts = sntPairBatch.GetTgtTokens(0);
                for (int k = 0; k < batchSize; k++)
                {
                    int ix_targets_k_j = m_modelMetaData.ClsVocab.GetWordIndex(tgtSnts[k][0]);
                    float score_k = probs.GetWeightAt(new long[] { k, ix_targets_k_j });

                    // Negative log-likelihood of the gold class.
                    cost += (float)-Math.Log(score_k);

                    // Subtract 1 at the gold class; the adjusted probabilities are copied into ffLayer's gradients below.
                    probs.SetWeightAt(score_k - 1, new long[] { k, ix_targets_k_j });
                }

                ffLayer.CopyWeightsToGradients(probs);
                nr.Cost = cost / batchSize;
            }
            else
            {
                // Output the highest-scoring class label for each sample.
                using var targetIdxTensor = computeGraph.Argmax(probs, 1);
                float[] targetIdx = targetIdxTensor.ToWeightArray();
                List<string> targetWords = m_modelMetaData.ClsVocab.ConvertIdsToString(targetIdx.ToList());

                nr.Output.Add(new List<List<string>>());
                for (int k = 0; k < batchSize; k++)
                {
                    nr.Output[0].Add(new List<string>());
                    nr.Output[0][k].Add(targetWords[k]);
                }
            }
        }
    }

    nrs.Add(nr);
    return nrs;
}