/// <summary>
/// Runs one decoding step through the stacked attention-decoder cells:
/// computes an attention context over the encoder output using the first
/// layer's cell state, then feeds the input embedding through each cell.
/// </summary>
/// <param name="sparseInput">Sparse source-side features forwarded to every cell.</param>
/// <param name="input">Embedding of the current target-side token.</param>
/// <param name="encoderOutput">Encoded source sequence the attention attends over.</param>
/// <param name="g">Compute graph recording operations for backpropagation.</param>
/// <returns>The hidden output of the top-most decoder cell.</returns>
public WeightMatrix Decode(SparseWeightMatrix sparseInput, WeightMatrix input, WeightMatrix encoderOutput, IComputeGraph g)
{
    var V = input;

    // First() instead of FirstOrDefault(): dereferencing FirstOrDefault()'s
    // result would throw an opaque NullReferenceException when the decoder
    // stack is empty; First() fails fast with a clear InvalidOperationException.
    var lastStatus = this.decoders.First().ct;

    // NOTE(review): the attention query is the FIRST layer's cell state, not
    // the last layer's — confirm this is intended.
    var context = Attention.Perform(encoderOutput, lastStatus, g);

    foreach (var decoder in decoders)
    {
        V = decoder.Step(sparseInput, context, V, g);
    }

    return V;
}
/// <summary>
/// Greedy inference: encodes the input sequence with a bidirectional encoder
/// (forward + reversed), sums the two encodings per position, then decodes
/// target tokens one at a time — always taking the arg-max token — until the
/// END tag is predicted or the output length exceeds max_word.
/// </summary>
/// <param name="input">Source-side token sequence.</param>
/// <returns>The predicted target-side token sequence (no START/END tags).</returns>
public List <string> Predict(List <string> input)
{
    // Clear recurrent state left over from any previous sentence.
    reversEncoder.Reset();
    encoder.Reset();
    decoder.Reset();

    List <string> result = new List <string>();

    // "false" presumably disables gradient bookkeeping for inference — TODO confirm
    // against the ComputeGraph constructor.
#if MKL
    var G2 = new ComputeGraphMKL(false);
#else
    var G2 = new ComputeGraph(false);
#endif

    List <string> inputSeq = new List <string>();
    // inputSeq.Add(m_START);
    inputSeq.AddRange(input);
    // inputSeq.Add(m_END);

    // Reversed copy for the backward (right-to-left) encoder.
    List <string> revseq = inputSeq.ToList();
    revseq.Reverse();

    List <WeightMatrix> forwardEncoded = new List <WeightMatrix>();
    List <WeightMatrix> backwardEncoded = new List <WeightMatrix>();
    List <WeightMatrix> encoded = new List <WeightMatrix>();
    SparseWeightMatrix sparseInput = new SparseWeightMatrix(1, s_Embedding.Columns);

    // NOTE(review): both lambdas record operations on the shared graph G2
    // concurrently — this is only safe if the compute graph implementation is
    // thread-safe; verify against ComputeGraph/ComputeGraphMKL.
    Parallel.Invoke(
        () =>
        {
            // Forward encoder pass; also accumulates the sparse
            // bag-of-words feature vector over source-word indices.
            for (int i = 0; i < inputSeq.Count; i++)
            {
                int ix = (int)SENTTAGS.UNK;
                if (s_wordToIndex.ContainsKey(inputSeq[i]) == false)
                {
                    Logger.WriteLine($"Unknow input word: {inputSeq[i]}");
                }
                else
                {
                    ix = s_wordToIndex[inputSeq[i]];
                }
                var x2 = G2.PeekRow(s_Embedding, ix);
                var o = encoder.Encode(x2, G2);
                forwardEncoded.Add(o);
                sparseInput.AddWeight(0, ix, 1.0f);
            }
        },
        () =>
        {
            // Backward encoder pass over the reversed sequence.
            for (int i = 0; i < inputSeq.Count; i++)
            {
                int ix = (int)SENTTAGS.UNK;
                if (s_wordToIndex.ContainsKey(revseq[i]) == false)
                {
                    Logger.WriteLine($"Unknow input word: {revseq[i]}");
                }
                else
                {
                    ix = s_wordToIndex[revseq[i]];
                }
                var x2 = G2.PeekRow(s_Embedding, ix);
                var o = reversEncoder.Encode(x2, G2);
                backwardEncoded.Add(o);
            }
        });

    // Re-align backward outputs with original (left-to-right) positions,
    // then combine forward and backward states by element-wise addition.
    backwardEncoded.Reverse();
    for (int i = 0; i < inputSeq.Count; i++)
    {
        //encoded.Add(G2.concatColumns(forwardEncoded[i], backwardEncoded[i]));
        encoded.Add(G2.add(forwardEncoded[i], backwardEncoded[i]));
    }

    //if (UseDropout)
    //{
    //    for (int i = 0; i < encoded.Weight.Length; i++)
    //    {
    //        encoded.Weight[i] *= 0.2;
    //    }
    //}

    // Greedy decoding loop, seeded with the START tag.
    var ix_input = (int)SENTTAGS.START;
    while (true)
    {
        var x = G2.PeekRow(t_Embedding, ix_input);
        var eOutput = decoder.Decode(sparseInput, x, encoded, G2);

        // NOTE(review): training (DecodeOutput) uses Dropout with rate 0.2f,
        // so the conventional inference-time rescale factor would be
        // 1 - 0.2 = 0.8, not 0.2 — verify this scaling is intended.
        if (UseDropout)
        {
            for (int i = 0; i < eOutput.Weight.Length; i++)
            {
                eOutput.Weight[i] *= 0.2f;
            }
        }

        var o = G2.muladd(eOutput, this.Whd, this.bd);
        if (UseDropout)
        {
            for (int i = 0; i < o.Weight.Length; i++)
            {
                o.Weight[i] *= 0.2f;
            }
        }

        var probs = G2.SoftmaxWithCrossEntropy(o);

        // Arg-max over the output distribution (greedy pick, no beam).
        var maxv = probs.Weight[0];
        var maxi = 0;
        for (int i = 1; i < probs.Weight.Length; i++)
        {
            if (probs.Weight[i] > maxv)
            {
                maxv = probs.Weight[i];
                maxi = i;
            }
        }
        var pred = maxi;

        if (pred == (int)SENTTAGS.END)
        {
            break; // END token predicted, break out
        }
        if (result.Count > max_word)
        {
            break;
        } // something is wrong

        // Map the predicted index back to a word (UNK when out of vocabulary).
        var letter2 = m_UNK;
        if (t_indexToWord.ContainsKey(pred))
        {
            letter2 = t_indexToWord[pred];
        }
        result.Add(letter2);

        // Feed the prediction back in as the next decoder input.
        ix_input = pred;
    }
    return(result);
}
/// <summary>
/// Bidirectional encoding for training: runs the forward encoder over the
/// sentence and the reverse encoder over its mirror image in parallel, sums
/// the two hidden states per position into <paramref name="encoded"/>, and
/// builds a sparse bag-of-words feature over the source-word indices.
/// </summary>
/// <param name="inputSentence">Source-side token sequence.</param>
/// <param name="cost">Initialized to zero; accumulated later by the decoder pass.</param>
/// <param name="sWM">Receives the sparse bag-of-words feature matrix.</param>
/// <param name="encoded">Receives one combined encoder state per input token.</param>
/// <param name="encoder">Left-to-right encoder.</param>
/// <param name="reversEncoder">Right-to-left encoder.</param>
/// <param name="Embedding">Source-side embedding matrix.</param>
/// <returns>The compute graph that recorded all encoding operations.</returns>
private IComputeGraph Encode(List <string> inputSentence, out float cost, out SparseWeightMatrix sWM, List <WeightMatrix> encoded, Encoder encoder, Encoder reversEncoder, WeightMatrix Embedding)
{
    // Mirror image of the sentence for the backward encoder.
    var mirrored = new List<string>(inputSentence);
    mirrored.Reverse();

#if MKL
    IComputeGraph g = new ComputeGraphMKL();
#else
    IComputeGraph g = new ComputeGraph();
#endif

    cost = 0.0f;

    var bagOfWords = new SparseWeightMatrix(1, Embedding.Columns);
    var leftToRight = new List<WeightMatrix>();
    var rightToLeft = new List<WeightMatrix>();

    Parallel.Invoke(
        () =>
        {
            // Forward pass; also accumulates the sparse word-index feature.
            foreach (var word in inputSentence)
            {
                int index;
                if (!s_wordToIndex.TryGetValue(word, out index))
                {
                    index = (int)SENTTAGS.UNK;
                }

                var embedding = g.PeekRow(Embedding, index);
                leftToRight.Add(encoder.Encode(embedding, g));
                bagOfWords.AddWeight(0, index, 1.0f);
            }
        },
        () =>
        {
            // Backward pass over the mirrored sentence.
            foreach (var word in mirrored)
            {
                int index;
                if (!s_wordToIndex.TryGetValue(word, out index))
                {
                    index = (int)SENTTAGS.UNK;
                }

                var embedding = g.PeekRow(Embedding, index);
                rightToLeft.Add(reversEncoder.Encode(embedding, g));
            }
        });

    // Re-align backward states with original positions, then combine by
    // element-wise addition (column concatenation is the alternative design).
    rightToLeft.Reverse();
    for (int i = 0; i < inputSentence.Count; i++)
    {
        encoded.Add(g.add(leftToRight[i], rightToLeft[i]));
    }

    sWM = bagOfWords;
    return g;
}
/// <summary>
/// Teacher-forced decoder pass for training: steps through the gold output
/// sentence (plus one extra step for the END tag), accumulates the negative
/// log-likelihood of each gold token, and plants the softmax-cross-entropy
/// gradient on each output layer node.
/// </summary>
/// <param name="OutputSentence">Gold target-side tokens.</param>
/// <param name="g">Compute graph recording operations for backpropagation.</param>
/// <param name="cost">Running cost to accumulate into.</param>
/// <param name="sparseInput">Sparse source-side features for the decoder.</param>
/// <param name="encoded">Encoder states the attention attends over.</param>
/// <param name="decoder">Attention decoder.</param>
/// <param name="Whd">Hidden-to-output projection weights.</param>
/// <param name="bd">Output bias.</param>
/// <param name="Embedding">Target-side embedding matrix.</param>
/// <returns>The accumulated cost.</returns>
private float DecodeOutput(string[] OutputSentence, IComputeGraph g, float cost, SparseWeightMatrix sparseInput, List <WeightMatrix> encoded, AttentionDecoder decoder, WeightMatrix Whd, WeightMatrix bd, WeightMatrix Embedding)
{
    int previousToken = (int)SENTTAGS.START;
    int stepCount = OutputSentence.Length + 1;

    for (int step = 0; step < stepCount; step++)
    {
        // Gold target: END on the extra final step; otherwise the word's
        // vocabulary index, falling back to UNK when out of vocabulary.
        int target;
        if (step == OutputSentence.Length)
        {
            target = (int)SENTTAGS.END;
        }
        else if (!t_wordToIndex.TryGetValue(OutputSentence[step], out target))
        {
            target = (int)SENTTAGS.UNK;
        }

        // Teacher forcing: the input is the previous GOLD token, not the prediction.
        var embedding = g.PeekRow(Embedding, previousToken);
        var hidden = decoder.Decode(sparseInput, embedding, encoded, g);
        if (UseDropout)
        {
            hidden = g.Dropout(hidden, 0.2f);
        }

        var output = g.muladd(hidden, Whd, bd);
        if (UseDropout)
        {
            output = g.Dropout(output, 0.2f);
        }

        var probs = g.SoftmaxWithCrossEntropy(output);

        // Negative log-likelihood of the gold token.
        cost += (float)-Math.Log(probs.Weight[target]);

        // Softmax + cross-entropy gradient: probabilities with 1 subtracted at
        // the gold index. This aliases probs.Weight, matching the original.
        output.Gradient = probs.Weight;
        output.Gradient[target] -= 1;

        previousToken = target;
    }

    return cost;
}
/// <summary>
/// One LSTM step for this attention-decoder cell. The four gate
/// pre-activations each combine the token input, the previous hidden state,
/// the attention context, and (when sdim > 0) the sparse source features.
/// Updates and returns the cell's hidden state; also updates its cell state.
/// </summary>
/// <param name="sparseInput">Sparse source-side features (used only when sdim > 0).</param>
/// <param name="context">Attention context vector for this step.</param>
/// <param name="input">Input for this step (token embedding or lower layer's output).</param>
/// <param name="innerGraph">Compute graph recording operations for backpropagation.</param>
/// <returns>The new hidden state (also stored in this.ht).</returns>
public WeightMatrix Step(SparseWeightMatrix sparseInput, WeightMatrix context, WeightMatrix input, IComputeGraph innerGraph)
{
    var hidden_prev = ht;
    var cell_prev = ct;
    var cell = this;

    WeightMatrix input_gate = null;
    WeightMatrix forget_gate = null;
    WeightMatrix output_gate = null;
    WeightMatrix cell_write = null;

    // The four gates are computed in parallel; each branch writes only its own
    // locals and its own gate variable, which is visible after Invoke returns.
    // NOTE(review): all branches record ops on the shared innerGraph
    // concurrently — safe only if the graph implementation is thread-safe;
    // confirm against ComputeGraph/ComputeGraphMKL.
    Parallel.Invoke(
        () =>
        {
            // Input gate: sigmoid(x·Wix + h·Wih + ctx·WiC [+ sparse·WiS] + bi)
            var h0 = innerGraph.mul(input, cell.Wix);
            var h1 = innerGraph.mul(hidden_prev, cell.Wih);
            var h11 = innerGraph.mul(context, cell.WiC);
            if (sdim > 0)
            {
                var h111 = innerGraph.mul(sparseInput, cell.WiS);
                input_gate = innerGraph.addsigmoid(h0, h1, h11, h111, cell.bi);
            }
            else
            {
                input_gate = innerGraph.addsigmoid(h0, h1, h11, cell.bi);
            }
        },
        () =>
        {
            // Forget gate: sigmoid(x·Wfx + h·Wfh + ctx·WfC [+ sparse·WfS] + bf)
            var h2 = innerGraph.mul(input, cell.Wfx);
            var h3 = innerGraph.mul(hidden_prev, cell.Wfh);
            var h33 = innerGraph.mul(context, cell.WfC);
            if (sdim > 0)
            {
                var h333 = innerGraph.mul(sparseInput, cell.WfS);
                forget_gate = innerGraph.addsigmoid(h3, h2, h33, h333, cell.bf);
            }
            else
            {
                forget_gate = innerGraph.addsigmoid(h3, h2, h33, cell.bf);
            }
        },
        () =>
        {
            // Output gate: sigmoid(x·Wox + h·Woh + ctx·WoC [+ sparse·WoS] + bo)
            var h4 = innerGraph.mul(input, cell.Wox);
            var h5 = innerGraph.mul(hidden_prev, cell.Woh);
            var h55 = innerGraph.mul(context, cell.WoC);
            if (sdim > 0)
            {
                var h555 = innerGraph.mul(sparseInput, cell.WoS);
                output_gate = innerGraph.addsigmoid(h5, h4, h55, h555, cell.bo);
            }
            else
            {
                output_gate = innerGraph.addsigmoid(h5, h4, h55, cell.bo);
            }
        },
        () =>
        {
            // Candidate cell: tanh(x·Wcx + h·Wch + ctx·WcC [+ sparse·WcS] + bc)
            var h6 = innerGraph.mul(input, cell.Wcx);
            var h7 = innerGraph.mul(hidden_prev, cell.Wch);
            var h77 = innerGraph.mul(context, cell.WcC);
            if (sdim > 0)
            {
                var h777 = innerGraph.mul(sparseInput, cell.WcS);
                cell_write = innerGraph.addtanh(h7, h6, h77, h777, cell.bc);
            }
            else
            {
                cell_write = innerGraph.addtanh(h7, h6, h77, cell.bc);
            }
        });

    // compute new cell activation
    var retain_cell = innerGraph.eltmul(forget_gate, cell_prev); // what do we keep from cell
    var write_cell = innerGraph.eltmul(input_gate, cell_write); // what do we write to cell
    var cell_d = innerGraph.add(retain_cell, write_cell); // new cell contents

    // compute hidden state as gated, saturated cell activations
    var hidden_d = innerGraph.eltmul(output_gate, innerGraph.tanh(cell_d));

    this.ht = hidden_d;
    this.ct = cell_d;
    return(ht);
}