/// <summary>
/// Creates a compute graph for the device slot <paramref name="deviceIdIdx"/>,
/// picking the concrete graph class from <c>m_archType</c>.
/// </summary>
/// <param name="deviceIdIdx">
/// Index into <c>m_weightFactory</c>; for the CUDA graph it is also used to index <c>m_deviceIds</c>.
/// </param>
/// <param name="needBack">Passed through unchanged to the graph constructor.</param>
/// <returns>
/// A <c>ComputeGraphMKL</c> for <c>CPU_MKL</c>, a <c>ComputeGraphTensor</c> for <c>GPU_CUDA</c>,
/// otherwise a plain <c>ComputeGraph</c>.
/// </returns>
private IComputeGraph CreateComputGraph(int deviceIdIdx, bool needBack = true)
{
    if (m_archType == ArchTypeEnums.CPU_MKL)
    {
        return new ComputeGraphMKL(m_weightFactory[deviceIdIdx], needBack);
    }

    if (m_archType == ArchTypeEnums.GPU_CUDA)
    {
        return new ComputeGraphTensor(m_weightFactory[deviceIdIdx], m_deviceIds[deviceIdIdx], needBack);
    }

    // Any other architecture value falls back to the default graph.
    return new ComputeGraph(m_weightFactory[deviceIdIdx], needBack);
}
/// <summary>
/// Produces an output word sequence for <paramref name="input"/> by encoding it in both
/// directions and then decoding one token at a time, always taking the highest-scoring
/// token and feeding it back as the next decoder input.
/// </summary>
/// <param name="input">
/// Source words. Words missing from <c>s_wordToIndex</c> are logged and mapped to
/// <c>SENTTAGS.UNK</c>.
/// </param>
/// <returns>
/// The predicted words. Decoding stops when <c>SENTTAGS.END</c> is predicted or once the
/// result holds more than <c>max_word</c> entries. Predicted indices missing from
/// <c>t_indexToWord</c> are emitted as <c>m_UNK</c>.
/// </returns>
public List<string> Predict(List<string> input)
{
    reversEncoder.Reset();
    encoder.Reset();
    decoder.Reset();

    List<string> result = new List<string>();

#if MKL
    var G2 = new ComputeGraphMKL(false);
#else
    var G2 = new ComputeGraph(false);
#endif

    // Work on a copy so the caller's list is never touched; no START/END markers are added.
    List<string> inputSeq = new List<string>();
    inputSeq.AddRange(input);

    List<string> revseq = inputSeq.ToList();
    revseq.Reverse();

    List<WeightMatrix> forwardEncoded = new List<WeightMatrix>();
    List<WeightMatrix> backwardEncoded = new List<WeightMatrix>();
    List<WeightMatrix> encoded = new List<WeightMatrix>();
    SparseWeightMatrix sparseInput = new SparseWeightMatrix(1, s_Embedding.Columns);

    // The forward and backward encoders run concurrently; each lambda writes only to its own list.
    // NOTE(review): both lambdas share G2 and Logger - confirm they are safe for concurrent use.
    Parallel.Invoke(
        () =>
        {
            for (int i = 0; i < inputSeq.Count; i++)
            {
                // Single lookup instead of ContainsKey followed by the indexer.
                int ix;
                if (!s_wordToIndex.TryGetValue(inputSeq[i], out ix))
                {
                    Logger.WriteLine($"Unknow input word: {inputSeq[i]}");
                    ix = (int)SENTTAGS.UNK;
                }

                var x2 = G2.PeekRow(s_Embedding, ix);
                var o = encoder.Encode(x2, G2);
                forwardEncoded.Add(o);

                // Only the forward pass records the source word in the sparse input.
                sparseInput.AddWeight(0, ix, 1.0f);
            }
        },
        () =>
        {
            for (int i = 0; i < inputSeq.Count; i++)
            {
                int ix;
                if (!s_wordToIndex.TryGetValue(revseq[i], out ix))
                {
                    Logger.WriteLine($"Unknow input word: {revseq[i]}");
                    ix = (int)SENTTAGS.UNK;
                }

                var x2 = G2.PeekRow(s_Embedding, ix);
                var o = reversEncoder.Encode(x2, G2);
                backwardEncoded.Add(o);
            }
        });

    // Put the backward outputs back into source order so index i lines up with forwardEncoded[i].
    backwardEncoded.Reverse();
    for (int i = 0; i < inputSeq.Count; i++)
    {
        // Both directions are merged with G2.add (not by concatenating columns).
        encoded.Add(G2.add(forwardEncoded[i], backwardEncoded[i]));
    }

    int ix_input = (int)SENTTAGS.START;
    while (true)
    {
        var x = G2.PeekRow(t_Embedding, ix_input);
        var eOutput = decoder.Decode(sparseInput, x, encoded, G2);
        if (UseDropout)
        {
            // NOTE(review): presumably rescales for dropout used in training - confirm the 0.2 factor.
            for (int i = 0; i < eOutput.Weight.Length; i++)
            {
                eOutput.Weight[i] *= 0.2f;
            }
        }

        var o = G2.muladd(eOutput, this.Whd, this.bd);
        if (UseDropout)
        {
            for (int i = 0; i < o.Weight.Length; i++)
            {
                o.Weight[i] *= 0.2f;
            }
        }

        var probs = G2.SoftmaxWithCrossEntropy(o);

        // Arg-max over the output scores; on ties the first (lowest) index wins.
        var maxv = probs.Weight[0];
        var maxi = 0;
        for (int i = 1; i < probs.Weight.Length; i++)
        {
            if (probs.Weight[i] > maxv)
            {
                maxv = probs.Weight[i];
                maxi = i;
            }
        }

        var pred = maxi;
        if (pred == (int)SENTTAGS.END)
        {
            break; // END token predicted, break out
        }

        if (result.Count > max_word)
        {
            break; // something is wrong
        }

        // Single lookup instead of ContainsKey followed by the indexer.
        string letter2;
        if (!t_indexToWord.TryGetValue(pred, out letter2))
        {
            letter2 = m_UNK;
        }

        result.Add(letter2);
        ix_input = pred;
    }

    return result;
}
/// <summary>
/// Runs <paramref name="inputSentence"/> through <paramref name="encoder"/> (left to right) and
/// <paramref name="reversEncoder"/> (right to left) on a newly created compute graph, and appends
/// one combined state per source word to <paramref name="encoded"/>.
/// </summary>
/// <param name="inputSentence">
/// Source words. Words missing from <c>s_wordToIndex</c> are mapped to <c>SENTTAGS.UNK</c>.
/// </param>
/// <param name="cost">Always set to <c>0.0f</c> by this method.</param>
/// <param name="sWM">
/// Receives a 1 x <c>Embedding.Columns</c> sparse matrix to which a weight of 1.0 was added
/// (via <c>AddWeight</c>) at each source word index.
/// </param>
/// <param name="encoded">Output list; one entry is appended per word of the sentence, in source order.</param>
/// <param name="encoder">Encoder fed with the sentence in its original order.</param>
/// <param name="reversEncoder">Encoder fed with the sentence in reversed order.</param>
/// <param name="Embedding">Embedding matrix from which word rows are taken with <c>PeekRow</c>.</param>
/// <returns>The compute graph on which all of the operations above were performed.</returns>
private IComputeGraph Encode(List<string> inputSentence, out float cost, out SparseWeightMatrix sWM, List<WeightMatrix> encoded, Encoder encoder, Encoder reversEncoder, WeightMatrix Embedding)
{
    var reversSentence = inputSentence.ToList();
    reversSentence.Reverse();

#if MKL
    IComputeGraph g = new ComputeGraphMKL();
#else
    IComputeGraph g = new ComputeGraph();
#endif

    cost = 0.0f;

    // A local is required here: out parameters cannot be captured by the lambdas below.
    SparseWeightMatrix tmpSWM = new SparseWeightMatrix(1, Embedding.Columns);
    List<WeightMatrix> forwardOutputs = new List<WeightMatrix>();
    List<WeightMatrix> backwardOutputs = new List<WeightMatrix>();

    // The forward and backward encoders run concurrently; each lambda writes only to its own list.
    // NOTE(review): both lambdas share the graph g - confirm it is safe for concurrent use.
    Parallel.Invoke(
        () =>
        {
            for (int i = 0; i < inputSentence.Count; i++)
            {
                // Single lookup instead of ContainsKey followed by the indexer.
                int ix_source;
                if (!s_wordToIndex.TryGetValue(inputSentence[i], out ix_source))
                {
                    ix_source = (int)SENTTAGS.UNK;
                }

                var x = g.PeekRow(Embedding, ix_source);
                var eOutput = encoder.Encode(x, g);
                forwardOutputs.Add(eOutput);

                // Only the forward pass records the source word in the sparse matrix.
                tmpSWM.AddWeight(0, ix_source, 1.0f);
            }
        },
        () =>
        {
            for (int i = 0; i < inputSentence.Count; i++)
            {
                int ix_source2;
                if (!s_wordToIndex.TryGetValue(reversSentence[i], out ix_source2))
                {
                    ix_source2 = (int)SENTTAGS.UNK;
                }

                var x2 = g.PeekRow(Embedding, ix_source2);
                var eOutput2 = reversEncoder.Encode(x2, g);
                backwardOutputs.Add(eOutput2);
            }
        });

    // Put the backward outputs back into source order so index i lines up with forwardOutputs[i].
    backwardOutputs.Reverse();
    for (int i = 0; i < inputSentence.Count; i++)
    {
        // Both directions are merged with g.add (not by concatenating columns).
        encoded.Add(g.add(forwardOutputs[i], backwardOutputs[i]));
    }

    sWM = tmpSWM;
    return g;
}