/// <summary>
/// Run one LSTM time step: consume <paramref name="input"/> together with the
/// previous hidden/cell state and produce the new hidden state.
/// </summary>
/// <param name="input">Input weights for the current time step.</param>
/// <param name="g">Parent compute graph used to build the workflow.</param>
/// <returns>The updated hidden state (also stored in m_hidden).</returns>
public IWeightTensor Step(IWeightTensor input, IComputeGraph g)
{
    using (IComputeGraph innerGraph = g.CreateSubGraph(m_name))
    {
        IWeightTensor prevHidden = m_hidden;
        IWeightTensor prevCell = m_cell;

        // Single fused affine over [x_t ; h_{t-1}] computes all gate pre-activations at once.
        IWeightTensor concatenated = innerGraph.Concate(1, input, prevHidden);
        IWeightTensor projected = innerGraph.Affine(concatenated, m_Wxh, m_b);
        IWeightTensor normalized = m_layerNorm1.Norm(projected, innerGraph);

        // First 3*m_hdim columns are the three gates; the remaining m_hdim columns
        // are the candidate cell value.
        (IWeightTensor rawGates, IWeightTensor rawCandidate) = innerGraph.SplitColumns(normalized, m_hdim * 3, m_hdim);
        IWeightTensor activeGates = innerGraph.Sigmoid(rawGates);
        IWeightTensor candidate = innerGraph.Tanh(rawCandidate);
        (IWeightTensor inGate, IWeightTensor forgetGate, IWeightTensor outGate) = innerGraph.SplitColumns(activeGates, m_hdim, m_hdim, m_hdim);

        // c_t = forgetGate * c_{t-1} + inGate * candidate.
        // NOTE(review): built on the parent graph g rather than innerGraph, presumably so the
        // persisted state outlives this sub-graph's disposal — confirm against graph semantics.
        m_cell = g.EltMulMulAdd(forgetGate, prevCell, inGate, candidate);
        IWeightTensor normalizedCell = m_layerNorm2.Norm(m_cell, innerGraph);

        // h_t = outGate * tanh(LN(c_t)) — also created on the parent graph.
        m_hidden = g.EltMul(outGate, innerGraph.Tanh(normalizedCell));
        return m_hidden;
    }
}
/// <summary>
/// Update LSTM-Attention cells according to given weights.
/// </summary>
/// <param name="context">The context weights for attention.</param>
/// <param name="input">The input weights.</param>
/// <param name="g">Parent compute graph used to build the workflow.</param>
/// <returns>Updated hidden weights (also stored in Hidden).</returns>
public IWeightTensor Step(IWeightTensor context, IWeightTensor input, IComputeGraph g)
{
    using (IComputeGraph computeGraph = g.CreateSubGraph(m_name))
    {
        IWeightTensor prevCell = Cell;
        IWeightTensor prevHidden = Hidden;

        // Single fused affine over [x_t ; h_{t-1} ; context] yields all gate pre-activations.
        IWeightTensor combined = computeGraph.Concate(1, input, prevHidden, context);
        IWeightTensor projected = computeGraph.Affine(combined, m_Wxhc, m_b);
        IWeightTensor normalized = m_layerNorm1.Norm(projected, computeGraph);

        // First 3*m_hiddenDim columns are the three gates; the rest is the candidate cell value.
        (IWeightTensor rawGates, IWeightTensor rawCandidate) = computeGraph.SplitColumns(normalized, m_hiddenDim * 3, m_hiddenDim);
        IWeightTensor activeGates = computeGraph.Sigmoid(rawGates);
        IWeightTensor candidate = computeGraph.Tanh(rawCandidate);
        (IWeightTensor inGate, IWeightTensor forgetGate, IWeightTensor outGate) = computeGraph.SplitColumns(activeGates, m_hiddenDim, m_hiddenDim, m_hiddenDim);

        // c_t = forgetGate * c_{t-1} + inGate * candidate.
        // NOTE(review): created on the parent graph g (not the sub-graph), presumably so the
        // persisted state survives sub-graph disposal — confirm against graph semantics.
        Cell = g.EltMulMulAdd(forgetGate, prevCell, inGate, candidate);
        IWeightTensor normalizedCell = m_layerNorm2.Norm(Cell, computeGraph);

        // h_t = outGate * tanh(LN(c_t)) — also on the parent graph.
        Hidden = g.EltMul(outGate, computeGraph.Tanh(normalizedCell));
        return Hidden;
    }
}
/// <summary>
/// Encode a batched input sequence with a stack of bi-directional LSTM layers.
/// Each layer runs a forward and a backward LSTM over the sequence and feeds the
/// per-step concatenation of both directions to the next layer.
/// </summary>
/// <param name="rawInputs">Input embeddings, laid out as (seqLen * batchSize) rows.</param>
/// <param name="batchSize">Number of sequences in the batch.</param>
/// <param name="g">Compute graph to build the workflow.</param>
/// <param name="srcSelfMask">Unused by this encoder; kept for interface compatibility.</param>
/// <returns>Encoded outputs, transposed back to the caller's batch layout.</returns>
public IWeightTensor Encode(IWeightTensor rawInputs, int batchSize, IComputeGraph g, IWeightTensor srcSelfMask)
{
    int seqLen = rawInputs.Rows / batchSize;
    rawInputs = g.TransposeBatch(rawInputs, seqLen);

    // Slice the transposed tensor into one (batchSize x dim) view per time step.
    List<IWeightTensor> inputs = new List<IWeightTensor>();
    for (int i = 0; i < seqLen; i++)
    {
        inputs.Add(g.Peek(rawInputs, 0, i * batchSize, batchSize));
    }

    List<IWeightTensor> layerOutputs = inputs.ToList();
    for (int i = 0; i < m_depth; i++)
    {
        // BUG FIX: these lists were previously declared once outside the depth loop and never
        // cleared, so for m_depth > 1 the indices forwardOutputs[j]/backwardOutputs[j] kept
        // reading the FIRST layer's outputs, and Reverse() reversed the whole accumulated
        // multi-layer list. Scoping them per layer makes each layer consume its own outputs.
        List<IWeightTensor> forwardOutputs = new List<IWeightTensor>();
        List<IWeightTensor> backwardOutputs = new List<IWeightTensor>();

        for (int j = 0; j < seqLen; j++)
        {
            // Forward direction walks the sequence left-to-right; backward walks right-to-left.
            forwardOutputs.Add(m_forwardEncoders[i].Step(layerOutputs[j], g));
            backwardOutputs.Add(m_backwardEncoders[i].Step(layerOutputs[seqLen - j - 1], g));
        }

        // Backward outputs were produced in reverse time order; flip them so index j
        // corresponds to time step j in both lists.
        backwardOutputs.Reverse();

        layerOutputs.Clear();
        for (int j = 0; j < seqLen; j++)
        {
            layerOutputs.Add(g.Concate(1, forwardOutputs[j], backwardOutputs[j]));
        }
    }

    IWeightTensor result = g.Concate(layerOutputs, 0);
    return g.TransposeBatch(result, batchSize);
}
/// <summary>
/// Compute an additive-attention context vector for the given decoder state over the
/// pre-processed encoder outputs, optionally updating a coverage model.
/// </summary>
/// <param name="state">Decoder state, one row per batch item.</param>
/// <param name="attnPre">Pre-processed attention inputs (encoder output and its projection Uhs).</param>
/// <param name="batchSize">Number of sequences in the batch.</param>
/// <param name="graph">Parent compute graph; a named sub-graph is created for the local ops.</param>
/// <returns>Context tensor of shape [batchSize, encoder-output dim].</returns>
public IWeightTensor Perform(IWeightTensor state, AttentionPreProcessResult attnPre, int batchSize, IComputeGraph graph)
{
    int srcSeqLen = attnPre.encOutput.Rows / batchSize;
    using (IComputeGraph g = graph.CreateSubGraph(m_name))
    {
        // Affine decoder state
        IWeightTensor wc = g.Affine(state, m_Wa, m_bWa);

        // Expand dims from [batchSize x decoder_dim] to [batchSize x srcSeqLen x decoder_dim]
        IWeightTensor wc1 = g.View(wc, dims: new long[] { batchSize, 1, wc.Columns });
        IWeightTensor wcExp = g.Expand(wc1, dims: new long[] { batchSize, srcSeqLen, wc.Columns });

        IWeightTensor ggs = null;
        if (m_enableCoverageModel)
        {
            // Get coverage model status at {t-1} and fold it into the additive score.
            IWeightTensor wCoverage = g.Affine(m_coverage.Hidden, m_Wc, m_bWc);
            IWeightTensor wCoverage1 = g.View(wCoverage, dims: new long[] { batchSize, srcSeqLen, -1 });
            ggs = g.AddTanh(attnPre.Uhs, wcExp, wCoverage1);
        }
        else
        {
            // Standard additive attention: tanh(Uh_s + W_a s_t).
            ggs = g.AddTanh(attnPre.Uhs, wcExp);
        }

        // Project the tanh-activated scores down to one scalar per (batch, source position).
        IWeightTensor ggss = g.View(ggs, dims: new long[] { batchSize *srcSeqLen, -1 });
        IWeightTensor atten = g.Mul(ggss, m_V);
        IWeightTensor attenT = g.Transpose(atten);
        IWeightTensor attenT2 = g.View(attenT, dims: new long[] { batchSize, srcSeqLen });

        // Normalize scores into attention weights over the source positions.
        IWeightTensor attenSoftmax1 = g.Softmax(attenT2, inPlace: true);
        IWeightTensor attenSoftmax = g.View(attenSoftmax1, dims: new long[] { batchSize, 1, srcSeqLen });

        // Weighted sum of encoder outputs: [B,1,S] x [B,S,D] -> [B,1,D] -> [B,D].
        // NOTE(review): contexts is built on the parent graph, presumably so it outlives
        // this sub-graph's disposal — confirm against the compute-graph lifetime rules.
        IWeightTensor inputs2 = g.View(attnPre.encOutput, dims: new long[] { batchSize, srcSeqLen, attnPre.encOutput.Columns });
        IWeightTensor contexts = graph.MulBatch(attenSoftmax, inputs2);
        contexts = graph.View(contexts, dims: new long[] { batchSize, attnPre.encOutput.Columns });

        if (m_enableCoverageModel)
        {
            // Concatenate tensor as input for coverage model:
            // per-position attention weight, encoder output, and the (expanded) decoder state.
            IWeightTensor aCoverage = g.View(attenSoftmax1, dims: new long[] { attnPre.encOutput.Rows, 1 });
            IWeightTensor state2 = g.View(state, dims: new long[] { batchSize, 1, state.Columns });
            IWeightTensor state3 = g.Expand(state2, dims: new long[] { batchSize, srcSeqLen, state.Columns });
            IWeightTensor state4 = g.View(state3, dims: new long[] { batchSize *srcSeqLen, -1 });
            IWeightTensor concate = g.Concate(1, aCoverage, attnPre.encOutput, state4);

            // NOTE(review): the coverage step is driven on the parent graph, not the sub-graph —
            // presumably intentional so the updated coverage state persists; verify.
            m_coverage.Step(concate, graph);
        }

        return(contexts);
    }
}