public IWeightTensor Perform(IWeightTensor state, AttentionPreProcessResult attenPreProcessResult, int batchSize, IComputeGraph graph)
{
    int srcSeqLen = attenPreProcessResult.inputsBatchFirst.Rows / batchSize;

    using (IComputeGraph g = graph.CreateSubGraph(m_name))
    {
        // Affine decoder state
        IWeightTensor wc = g.Affine(state, m_Wa, m_bWa);

        // Expand dims from [batchSize x decoder_dim] to [batchSize x srcSeqLen x decoder_dim]
        IWeightTensor wc1 = g.View(wc, batchSize, 1, wc.Columns);
        IWeightTensor wcExp = g.Expand(wc1, batchSize, srcSeqLen, wc.Columns);

        IWeightTensor ggs = null;
        if (m_enableCoverageModel)
        {
            // Get coverage model status at {t-1}
            IWeightTensor wCoverage = g.Affine(m_coverage.Hidden, m_Wc, m_bWc);
            IWeightTensor wCoverage1 = g.View(wCoverage, batchSize, srcSeqLen, -1);

            ggs = g.AddTanh(attenPreProcessResult.uhs, wcExp, wCoverage1);
        }
        else
        {
            ggs = g.AddTanh(attenPreProcessResult.uhs, wcExp);
        }

        IWeightTensor ggss = g.View(ggs, batchSize * srcSeqLen, -1);
        IWeightTensor atten = g.Mul(ggss, m_V);

        IWeightTensor attenT = g.Transpose(atten);
        IWeightTensor attenT2 = g.View(attenT, batchSize, srcSeqLen);

        IWeightTensor attenSoftmax1 = g.Softmax(attenT2, inPlace: true);

        IWeightTensor attenSoftmax = g.View(attenSoftmax1, batchSize, 1, srcSeqLen);
        IWeightTensor inputs2 = g.View(attenPreProcessResult.inputsBatchFirst, batchSize, srcSeqLen, attenPreProcessResult.inputsBatchFirst.Columns);

        IWeightTensor contexts = graph.MulBatch(attenSoftmax, inputs2, batchSize);

        if (m_enableCoverageModel)
        {
            // Concatenate tensor as input for coverage model
            IWeightTensor aCoverage = g.View(attenSoftmax1, attenPreProcessResult.inputsBatchFirst.Rows, 1);

            IWeightTensor state2 = g.View(state, batchSize, 1, state.Columns);
            IWeightTensor state3 = g.Expand(state2, batchSize, srcSeqLen, state.Columns);
            IWeightTensor state4 = g.View(state3, batchSize * srcSeqLen, -1);

            IWeightTensor concate = g.ConcatColumns(aCoverage, attenPreProcessResult.inputsBatchFirst, state4);
            m_coverage.Step(concate, graph);
        }

        return contexts;
    }
}
public IWeightMatrix Perform(IWeightMatrix state, AttentionPreProcessResult attenPreProcessResult, IComputeGraph g)
{
    // Project the decoder state, with the bias repeated once per state row
    var bWas = g.RepeatRows(bWa, state.Rows);
    var wc = g.MulAdd(state, Wa, bWas);

    // Repeat the projected state once per source position before adding it to uhs
    var wcs = g.RepeatRows(wc, attenPreProcessResult.inputsUnfolder[0].Rows);
    var ggs = g.AddTanh(attenPreProcessResult.uhs, wcs);
    var atten = g.Mul(ggs, V);

    // Split the scores per batch item and transpose each piece
    List<IWeightMatrix> attens = g.UnFolderRow(atten, m_batchSize);
    List<IWeightMatrix> contexts = new List<IWeightMatrix>();
    List<IWeightMatrix> attensT = new List<IWeightMatrix>();
    for (int i = 0; i < m_batchSize; i++)
    {
        attensT.Add(g.Transpose2(attens[i]));
    }

    var attenT = g.ConcatRows(attensT);
    var attenSoftmax = g.SoftmaxM(attenT);

    // Weight each batch item's encoder outputs by its own attention row
    for (int i = 0; i < m_batchSize; i++)
    {
        IWeightMatrix context = g.Mul(g.PeekRow(attenSoftmax, i), attenPreProcessResult.inputsUnfolder[i]);
        contexts.Add(context);
    }

    return g.ConcatRows(contexts);
}
public IWeightMatrix Perform(IWeightMatrix state, AttentionPreProcessResult attenPreProcessResult, IComputeGraph g)
{
    // Project the decoder state, with the bias repeated once per state row
    var bWas = g.RepeatRows(bWa, state.Rows);
    var wc = g.MulAdd(state, Wa, bWas);

    // Repeat the projected state once per source position before adding it to uhs
    var wcs = g.RepeatRows(wc, attenPreProcessResult.inputs.Rows / m_batchSize);
    var ggs = g.AddTanh(attenPreProcessResult.uhs, wcs);
    var atten = g.Mul(ggs, V);

    // Reshape the scores to [m_batchSize x srcSeqLen] so softmax runs per batch item
    var atten2 = g.PermuteBatch(atten, m_batchSize);
    var attenT = g.Transpose2(atten2);
    var attenT2 = g.View(attenT, m_batchSize, attenPreProcessResult.inputs.Rows / m_batchSize);
    var attenSoftmax = g.Softmax(attenT2);

    // Batched product of attention weights and encoder outputs
    IWeightMatrix contexts = g.MulBatch(attenSoftmax, attenPreProcessResult.inputs, m_batchSize);

    return contexts;
}
public IWeightTensor Perform(IWeightTensor state, AttentionPreProcessResult attenPreProcessResult, int batchSize, IComputeGraph graph)
{
    IComputeGraph g = graph.CreateSubGraph(m_name);

    // Affine decoder state, repeated once per source position
    var wc = g.Affine(state, m_Wa, m_bWa);
    var wcs = g.RepeatRows(wc, attenPreProcessResult.inputs.Rows / batchSize);
    var ggs = g.AddTanh(attenPreProcessResult.uhs, wcs);
    var atten = g.Mul(ggs, m_V);

    // Reshape the scores to [batchSize x srcSeqLen] so softmax runs per batch item
    var atten2 = g.TransposeBatch(atten, batchSize);
    var attenT = g.Transpose(atten2);
    var attenT2 = g.View(attenT, batchSize, attenPreProcessResult.inputs.Rows / batchSize);
    var attenSoftmax1 = g.Softmax(attenT2, inPlace: true);

    // View weights and encoder outputs as 3-D tensors for the batched product
    var attenSoftmax = g.View(attenSoftmax1, batchSize, attenSoftmax1.Rows / batchSize, attenSoftmax1.Columns);
    var inputs2 = g.View(attenPreProcessResult.inputs, batchSize, attenPreProcessResult.inputs.Rows / batchSize, attenPreProcessResult.inputs.Columns);

    IWeightTensor contexts = g.MulBatch(attenSoftmax, inputs2, batchSize);

    return contexts;
}