/// <summary>
/// Advances this LSTM cell by one time step and stores the new hidden and cell states.
/// </summary>
/// <remarks>
/// With x = <paramref name="input"/>, h = previous hidden state (<c>ht</c>) and c = previous cell state (<c>ct</c>):
/// <code>
/// i  = sigmoid(x·Wix + h·Wih + bi)    // input gate
/// f  = sigmoid(x·Wfx + h·Wfh + bf)    // forget gate
/// o  = sigmoid(x·Wox + h·Woh + bo)    // output gate
/// c~ = tanh   (x·Wcx + h·Wch + bc)    // candidate cell write
/// c' = f ⊙ c + i ⊙ c~
/// h' = o ⊙ tanh(c')
/// </code>
/// Every operation is built through <paramref name="innerGraph"/>, in the order written below.
/// </remarks>
/// <param name="input">Input for this time step.</param>
/// <param name="innerGraph">Graph on which all operations of this step are built.</param>
/// <returns>The new hidden state h' (the value now held in <c>ht</c>); c' is stored in <c>ct</c>.</returns>
public WeightMatrix Step(WeightMatrix input, ComputeGraph innerGraph)
{
    var graph = innerGraph;
    var prevHidden = this.ht;
    var prevCell = this.ct;

    // Each gate: input projection, then recurrent projection, then their sum plus bias, then the nonlinearity.
    var inGate = graph.sigmoid(graph.add(graph.add(graph.mul(input, Wix), graph.mul(prevHidden, Wih)), bi));
    var forgetGate = graph.sigmoid(graph.add(graph.add(graph.mul(input, Wfx), graph.mul(prevHidden, Wfh)), bf));
    var outGate = graph.sigmoid(graph.add(graph.add(graph.mul(input, Wox), graph.mul(prevHidden, Woh)), bo));
    var candidate = graph.tanh(graph.add(graph.add(graph.mul(input, Wcx), graph.mul(prevHidden, Wch)), bc));

    // New cell contents: keep part of the old cell, write part of the candidate.
    var newCell = graph.add(graph.eltmul(forgetGate, prevCell), graph.eltmul(inGate, candidate));

    // New hidden state: gated, saturated cell activation.
    var newHidden = graph.eltmul(outGate, graph.tanh(newCell));

    this.ht = newHidden;
    this.ct = newCell;
    return this.ht;
}
/// <summary>
/// Computes an attention context for <paramref name="state"/> over all rows of <paramref name="input"/>.
/// </summary>
/// <remarks>
/// <paramref name="input"/> and the row-repeated <paramref name="state"/> are each extended with a trailing
/// column of ones before being multiplied by <c>Ua</c> / <c>Wa</c>, so the last row of each of those
/// matrices acts as a bias. Scores are <c>tanh(uh + wc)·V</c> and are normalized with <c>g.Softmax</c>.
/// The result is <c>g.sumColumns(g.weightRows(input, weights))</c>; presumably each input row is scaled
/// by its attention weight and the rows are then summed (confirm against <c>ComputeGraph</c>).
/// </remarks>
/// <param name="input">Encoder states, one per row.</param>
/// <param name="state">Decoder state to score the encoder rows against.</param>
/// <param name="g">Graph on which all operations are built.</param>
/// <returns>The attention context.</returns>
/// <exception cref="ArgumentNullException"><paramref name="input"/> or <paramref name="state"/> is null.</exception>
public WeightMatrix Perform(WeightMatrix input, WeightMatrix state, ComputeGraph g)
{
    if (input == null)
    {
        throw new ArgumentNullException(nameof(input));
    }
    if (state == null)
    {
        throw new ArgumentNullException(nameof(state));
    }

    // One copy of the decoder state per encoder row.
    var stateRows = g.RepeatRows(state, input.Rows);

    // Encoder rows extended with a bias column of ones, projected by Ua.
    var inputWithBias = g.concatColumns(input, new WeightMatrix(input.Rows, 1, 1));
    var uh = g.mul(inputWithBias, Ua);

    // Repeated decoder state extended with its own bias column of ones, projected by Wa.
    var stateWithBias = g.concatColumns(stateRows, new WeightMatrix(stateRows.Rows, 1, 1));
    var wc = g.mul(stateWithBias, Wa);

    var scores = g.mul(g.tanh(g.add(uh, wc)), V);
    var weights = g.Softmax(scores);

    var weighted = g.weightRows(input, weights);
    return g.sumColumns(weighted);
}
/// <summary>
/// Additive (Bahdanau-style) attention over a list of encoder states: returns the attention-weighted
/// sum of <paramref name="input"/> for the decoder state <paramref name="state"/>.
/// </summary>
/// <remarks>
/// <code>
/// score_j = tanh(h_j·Ua + bUa + state·Wa + bWa)·V
/// a       = g.Softmax(score_0 .. score_{n-1})
/// context = Σ_j scalemul(h_j, a_j)
/// </code>
/// The index of the largest attention weight is stored in <c>MaxIndex</c>.
/// </remarks>
/// <param name="input">Encoder states h_j; must contain at least one element.</param>
/// <param name="state">Decoder state the encoder states are scored against.</param>
/// <param name="g">Graph on which all operations are built.</param>
/// <returns>The attention context.</returns>
/// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="input"/> is empty.</exception>
public WeightMatrix Perform(List<WeightMatrix> input, WeightMatrix state, ComputeGraph g)
{
    if (input == null)
    {
        throw new ArgumentNullException(nameof(input));
    }
    if (input.Count == 0)
    {
        throw new ArgumentException("At least one encoder state is required.", nameof(input));
    }

    // state·Wa + bWa does not depend on j, so it is built once instead of once per encoder state.
    // The node is then an operand of every add(uh, wc) below. Gradient flow relies on the graph summing
    // gradients over all uses of a node, as it must for any reused operand. TODO confirm in ComputeGraph.add.
    var wc = g.add(g.mul(state, Wa), bWa);

    var scores = new List<WeightMatrix>(input.Count);
    foreach (var h_j in input)
    {
        var uh = g.add(g.mul(h_j, Ua), bUa);
        scores.Add(g.mul(g.tanh(g.add(uh, wc)), V));
    }

    var res = g.Softmax(scores);

    // Record the most-attended encoder position. On ties, the earliest position wins.
    int maxAtt = 0;
    var cmax = res[0].Weight[0];
    for (int i = 1; i < res.Count; i++)
    {
        if (res[i].Weight[0] > cmax)
        {
            cmax = res[i].Weight[0];
            maxAtt = i;
        }
    }
    this.MaxIndex = maxAtt;

    var context = g.scalemul(input[0], res[0]);
    for (int hj = 1; hj < input.Count; hj++)
    {
        context = g.add(context, g.scalemul(input[hj], res[hj]));
    }
    return context;
}
/// <summary>
/// Decodes <paramref name="OutputSentence"/> with teacher forcing, adds its negative log-likelihood to
/// <paramref name="cost"/> and seeds the gradient of the output logits for backpropagation.
/// </summary>
/// <remarks>
/// The first decoder input is embedding row 1. After each step the ground-truth target word, not the
/// model's own prediction, becomes the next input. One extra step after the last word targets index 0,
/// the index treated as the end token at prediction time. For each step the logit gradient is set to
/// <c>probs - onehot(target)</c>.
/// </remarks>
/// <param name="OutputSentence">Target words, looked up through <c>wordToIndex</c>.</param>
/// <param name="g">Graph on which the decoder and output-layer operations are built.</param>
/// <param name="cost">Loss accumulated so far.</param>
/// <param name="encoded">Encoder states the decoder attends over.</param>
/// <returns><paramref name="cost"/> plus the negative log-likelihood of every target index.</returns>
private double DecodeOutput(List<string> OutputSentence, ComputeGraph g, double cost, List<WeightMatrix> encoded)
{
    int prevWord = 1;
    int wordCount = OutputSentence.Count;

    for (int step = 0; step <= wordCount; step++)
    {
        // The extra final step (step == wordCount) targets the end token.
        int target = step == wordCount ? 0 : wordToIndex[OutputSentence[step]].i;

        var decoderOut = decoder.Decode(g.PeekRow(Embedding, prevWord), encoded, g);
        if (UseDropout)
        {
            decoderOut = g.Dropout(decoderOut, 0.2);
        }

        var logits = g.add(g.mul(decoderOut, this.Whd), this.bd);
        if (UseDropout)
        {
            logits = g.Dropout(logits, 0.2);
        }

        var probs = g.SoftmaxWithCrossEntropy(logits);
        cost += -Math.Log(probs.Weight[target]);

        // Softmax + cross-entropy gradient. Note that the gradient array IS probs.Weight (no copy), so the
        // loss above must be read before the target entry is decremented.
        logits.Gradient = probs.Weight;
        logits.Gradient[target] -= 1;

        prevWord = target;
    }

    return cost;
}
/// <summary>
/// Predicts a reply for a tokenized sentence using greedy (argmax) decoding.
/// </summary>
/// <remarks>
/// The sentence is run through <c>encoder</c> in reading order and through <c>ReversEncoder</c> in
/// reversed order. Step i of both is concatenated into one encoder state. Decoding starts from embedding
/// row 1 and feeds each predicted index back in. It stops when index 0 (the end token) is predicted or
/// when the reply already holds <c>max_word</c> words.
/// </remarks>
/// <param name="inputSeq">The input sentence, already split into words.</param>
/// <returns>
/// A successful result holding the predicted words (the end token is not included), or a failed result
/// whose message names the first word of <paramref name="inputSeq"/> that is not in the vocabulary.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="inputSeq"/> is null.</exception>
public ExecuteResult<List<string>> Predict(List<string> inputSeq)
{
    if (inputSeq == null)
    {
        throw new ArgumentNullException(nameof(inputSeq));
    }

    ExecuteResult<List<string>> eresult = new ExecuteResult<List<string>>();
    Reset();
    List<string> result = new List<string>();

    // Validate the whole sentence before any encoder step runs. The first unknown word in reading order is
    // reported, and no work is spent on a sentence that will be rejected. The reversed sequence contains the
    // same words, so it needs no separate check.
    // Message (English): Sorry, I could not understand the meaning of "<word>"; please retrain me!
    foreach (string word in inputSeq)
    {
        if (!wordToIndex.ContainsKey(word))
        {
            return eresult.SetFail($"抱歉,未能理解 \"{word}\" 的含义, 请重新训练我吧!");
        }
    }

    // Presumably "false" turns off gradient bookkeeping for inference. Confirm in ComputeGraph.
    var graph = new ComputeGraph(false);

    List<string> revseq = inputSeq.ToList();
    revseq.Reverse();

    List<WeightMatrix> encoded = new List<WeightMatrix>();
    for (int i = 0; i < inputSeq.Count; i++)
    {
        int fwdIndex = wordToIndex[inputSeq[i]].i;
        int revIndex = wordToIndex[revseq[i]].i;

        var forward = encoder.Encode(graph.PeekRow(Embedding, fwdIndex), graph);
        var backward = ReversEncoder.Encode(graph.PeekRow(Embedding, revIndex), graph);

        // NOTE(review): after step i the reverse encoder has read words n-1 .. n-1-i, so its output
        // corresponds to position n-1-i, yet it is paired with the forward output for position i.
        // This is kept as-is because inference must build encoder states exactly as training does.
        // Confirm against the training-side encoding before changing it.
        encoded.Add(graph.concatColumns(forward, backward));
    }

    int ix_input = 1;
    while (true)
    {
        var x = graph.PeekRow(Embedding, ix_input);
        var eOutput = decoder.Decode(x, encoded, graph);
        if (UseDropout)
        {
            // Inference-time counterpart of the training-time Dropout(…, 0.2).
            // NOTE(review): this scales eOutput IN PLACE. If decoder.Decode returns the decoder cell's
            // stored hidden state rather than a copy, the scaling also leaks into the next decoding step,
            // which training (where Dropout returns a new matrix) does not do. Also confirm that 0.2 is the
            // right factor for this Dropout's convention (keep probability vs. drop probability).
            for (int i = 0; i < eOutput.Weight.Length; i++)
            {
                eOutput.Weight[i] *= 0.2;
            }
        }

        var o = graph.add(graph.mul(eOutput, this.Whd), this.bd);
        if (UseDropout)
        {
            // Scaling logits by a positive constant does not change which index is most probable (for a
            // standard softmax), so this loop does not affect the prediction.
            for (int i = 0; i < o.Weight.Length; i++)
            {
                o.Weight[i] *= 0.2;
            }
        }

        var probs = graph.SoftmaxWithCrossEntropy(o);

        // Greedy choice: the index with the highest probability (earliest index wins ties).
        var maxv = probs.Weight[0];
        int pred = 0;
        for (int i = 1; i < probs.Weight.Length; i++)
        {
            if (probs.Weight[i] > maxv)
            {
                maxv = probs.Weight[i];
                pred = i;
            }
        }

        if (pred == 0)
        {
            break; // END token predicted
        }
        if (result.Count >= max_word)
        {
            break; // runaway guard: the reply already holds max_word words
        }

        result.Add(indexToWord[pred].w);
        ix_input = pred;
    }

    return eresult.SetData(result).SetOk();
}