/// <summary>
/// Prepares the attention keys once per source batch: projects <paramref name="encOutput"/>
/// through the affine transform (weight <c>m_Ua</c>, bias <c>m_bUa</c>), reshapes the projection
/// into a <c>{batchSize, srcSeqLen, -1}</c> view, and, when the coverage model is enabled,
/// resets the coverage state.
/// </summary>
/// <param name="encOutput">Encoder output tensor; stored unchanged on the result.</param>
/// <param name="batchSize">Number of sequences in the batch; divides <c>encOutput.Rows</c> to get the per-sequence length.</param>
/// <param name="g">Compute graph used to build the projection and view.</param>
/// <returns>A result holding the original encoder output and the reshaped projection in <c>Uhs</c>.</returns>
public AttentionPreProcessResult PreProcess(IWeightTensor encOutput, int batchSize, IComputeGraph g)
{
    // NOTE(review): integer division — assumes Rows is an exact multiple of batchSize; confirm with callers.
    int seqLength = encOutput.Rows / batchSize;

    var result = new AttentionPreProcessResult { encOutput = encOutput };

    var projected = g.Affine(result.encOutput, m_Ua, m_bUa);

    // -1 presumably lets View infer the trailing dimension — TODO confirm against IComputeGraph.View.
    long[] viewShape = new long[] { batchSize, seqLength, -1 };
    result.Uhs = g.View(projected, dims: viewShape);

    if (!m_enableCoverageModel)
    {
        return result;
    }

    m_coverage.Reset(g.GetWeightFactory(), result.encOutput.Rows);
    return result;
}
/// <summary>
/// Prepares the attention keys once per source batch:
/// <list type="number">
/// <item>Keeps the raw <paramref name="inputs"/>.</item>
/// <item>Stores the output of <c>TransposeBatch</c> as <c>inputsBatchFirst</c>.</item>
/// <item>Projects that tensor through the affine transform (weight <c>m_Ua</c>, bias <c>m_bUa</c>).</item>
/// <item>Reshapes the projection to <c>(batchSize, srcSeqLen, -1)</c>.</item>
/// <item>When the coverage model is enabled, resets the coverage state.</item>
/// </list>
/// </summary>
/// <param name="inputs">Source-side input tensor; stored unchanged as <c>rawInputs</c>.</param>
/// <param name="batchSize">Number of sequences in the batch; used for the transpose and to derive the per-sequence length.</param>
/// <param name="g">Compute graph used to build the transpose, projection and view.</param>
/// <returns>A result holding the raw inputs, the transposed inputs, and the reshaped projection in <c>uhs</c>.</returns>
public AttentionPreProcessResult PreProcess(IWeightTensor inputs, int batchSize, IComputeGraph g)
{
    // NOTE(review): integer division — assumes Rows is an exact multiple of batchSize; confirm with callers.
    int seqLength = inputs.Rows / batchSize;

    var result = new AttentionPreProcessResult
    {
        rawInputs = inputs,
        // Presumably reorders rows into batch-first layout (per the field name) — TODO confirm TransposeBatch semantics.
        inputsBatchFirst = g.TransposeBatch(inputs, batchSize)
    };

    var projected = g.Affine(result.inputsBatchFirst, m_Ua, m_bUa);

    // -1 presumably lets View infer the trailing dimension — TODO confirm against IComputeGraph.View.
    result.uhs = g.View(projected, batchSize, seqLength, -1);

    if (!m_enableCoverageModel)
    {
        return result;
    }

    m_coverage.Reset(g.GetWeightFactory(), result.inputsBatchFirst.Rows);
    return result;
}