/// <summary>
/// Clears the recurrent state of the forward encoder, the reverse encoder,
/// and the attention decoder so the next sequence starts from scratch.
/// </summary>
/// <param name="encoder">Forward-direction encoder to reset.</param>
/// <param name="reversEncoder">Backward-direction encoder to reset.</param>
/// <param name="decoder">Attention decoder to reset.</param>
private void Reset(Encoder encoder, Encoder reversEncoder, AttentionDecoder decoder)
{
    encoder.Reset();
    reversEncoder.Reset();
    decoder.Reset();
}
/// <summary>
/// Encode a batch of source sentences with a bidirectional encoder and
/// return the concatenated forward/backward outputs for all time steps.
/// </summary>
/// <param name="g">Compute graph that records all operations for backprop.</param>
/// <param name="inputSentences">Batch of tokenized source sentences; padded in place to a common length.</param>
/// <param name="encoder">Forward-direction encoder.</param>
/// <param name="reversEncoder">Backward-direction encoder.</param>
/// <param name="Embedding">Source word embedding matrix (one row per vocabulary entry).</param>
/// <returns>Concatenation (rows x columns) of forward and backward encoder outputs.</returns>
private IWeightMatrix Encode(IComputeGraph g, List<List<string>> inputSentences, Encoder encoder, Encoder reversEncoder, IWeightMatrix Embedding)
{
    // Make every sentence in the batch the same length so time steps align.
    PadSentences(inputSentences);

    List<IWeightMatrix> forwardOutputs = new List<IWeightMatrix>();
    List<IWeightMatrix> backwardOutputs = new List<IWeightMatrix>();

    int seqLen = inputSentences[0].Count;
    int batchSize = inputSentences.Count;

    // Gather the embedding row for every (time step, sentence) pair,
    // time-step major so each step's batch rows are contiguous after concat.
    List<IWeightMatrix> forwardInput = new List<IWeightMatrix>();
    for (int i = 0; i < seqLen; i++)
    {
        for (int j = 0; j < batchSize; j++)
        {
            var inputSentence = inputSentences[j];

            // Single dictionary lookup; out-of-vocabulary words map to UNK.
            // (Original used ContainsKey + indexer, a double lookup.)
            if (m_srcWordToIndex.TryGetValue(inputSentence[i], out int ix_source) == false)
            {
                ix_source = (int)SENTTAGS.UNK;
            }

            forwardInput.Add(g.PeekRow(Embedding, ix_source));
        }
    }

    // Concatenate all embeddings into one matrix, then slice it back into
    // one (batchSize x embeddingDim) matrix per time step.
    var forwardInputsM = g.ConcatRows(forwardInput);
    List<IWeightMatrix> attResults = new List<IWeightMatrix>();
    for (int i = 0; i < seqLen; i++)
    {
        attResults.Add(g.PeekRow(forwardInputsM, i * batchSize, batchSize));
    }

    // Bidirectional pass: the forward encoder consumes steps left-to-right,
    // the reverse encoder right-to-left; reversing the backward outputs
    // afterwards aligns both lists by time step.
    for (int i = 0; i < seqLen; i++)
    {
        forwardOutputs.Add(encoder.Encode(attResults[i], g));
        backwardOutputs.Add(reversEncoder.Encode(attResults[seqLen - i - 1], g));
    }
    backwardOutputs.Reverse();

    var encodedOutput = g.ConcatRowColumn(forwardOutputs, backwardOutputs);
    return encodedOutput;
}
/// <summary>
/// Builds a compute graph that encodes a single sentence bidirectionally,
/// summing the forward and backward encoder outputs per token.
/// </summary>
/// <param name="inputSentence">Tokenized source sentence.</param>
/// <param name="cost">Always set to 0 here; presumably accumulated later by the caller — TODO confirm.</param>
/// <param name="sWM">Receives a 1 x Embedding.Columns sparse matrix with weight 1.0 at each seen word index (bag-of-words of the source).</param>
/// <param name="encoded">Output list; receives g.add(forward[i], backward[i]) for each token position.</param>
/// <param name="encoder">Forward-direction encoder.</param>
/// <param name="reversEncoder">Backward-direction encoder.</param>
/// <param name="Embedding">Source word embedding matrix.</param>
/// <returns>The compute graph containing all operations built during encoding.</returns>
private IComputeGraph Encode(List<string> inputSentence, out float cost, out SparseWeightMatrix sWM, List<WeightMatrix> encoded, Encoder encoder, Encoder reversEncoder, WeightMatrix Embedding)
{
    // Work on a reversed copy of the sentence for the backward encoder.
    var reversSentence = inputSentence.ToList();
    reversSentence.Reverse();
#if MKL
    IComputeGraph g = new ComputeGraphMKL();
#else
    IComputeGraph g = new ComputeGraph();
#endif
    cost = 0.0f;
    SparseWeightMatrix tmpSWM = new SparseWeightMatrix(1, Embedding.Columns);

    List<WeightMatrix> forwardOutputs = new List<WeightMatrix>();
    List<WeightMatrix> backwardOutputs = new List<WeightMatrix>();

    // NOTE(review): both lambdas below mutate the SAME graph `g` (g.PeekRow)
    // concurrently, and each appends to its own List<WeightMatrix> without
    // synchronization. This is only safe if ComputeGraph's op recording is
    // thread-safe, and the interleaved node-registration order is
    // nondeterministic — confirm this is intended before touching it.
    Parallel.Invoke(
        () =>
        {
            // Forward pass: left-to-right over the original sentence.
            for (int i = 0; i < inputSentence.Count; i++)
            {
                // Out-of-vocabulary words fall back to the UNK tag index.
                int ix_source = (int)SENTTAGS.UNK;
                if (s_wordToIndex.ContainsKey(inputSentence[i]))
                {
                    ix_source = s_wordToIndex[inputSentence[i]];
                }
                var x = g.PeekRow(Embedding, ix_source);
                var eOutput = encoder.Encode(x, g);
                forwardOutputs.Add(eOutput);

                // Record the word occurrence in the sparse source vector
                // (only done in the forward pass, so each token counts once).
                tmpSWM.AddWeight(0, ix_source, 1.0f);
            }
        },
        () =>
        {
            // Backward pass: left-to-right over the reversed sentence.
            for (int i = 0; i < inputSentence.Count; i++)
            {
                int ix_source2 = (int)SENTTAGS.UNK;
                if (s_wordToIndex.ContainsKey(reversSentence[i]))
                {
                    ix_source2 = s_wordToIndex[reversSentence[i]];
                }
                var x2 = g.PeekRow(Embedding, ix_source2);
                var eOutput2 = reversEncoder.Encode(x2, g);
                backwardOutputs.Add(eOutput2);
            }
        });

    // Re-align backward outputs with the original token order.
    backwardOutputs.Reverse();

    // Combine both directions per token by element-wise addition
    // (a column concat was apparently considered and left disabled).
    for (int i = 0; i < inputSentence.Count; i++)
    {
        //encoded.Add(g.concatColumns(forwardOutputs[i], backwardOutputs[i]));
        encoded.Add(g.add(forwardOutputs[i], backwardOutputs[i]));
    }

    sWM = tmpSWM;

    return (g);
}
/// <summary>
/// Resets both encoders and the attention decoder using the supplied
/// weight factory, preparing them for the next sequence.
/// </summary>
/// <param name="weightFactory">Factory passed through to each component's reset.</param>
/// <param name="encoder">Forward-direction encoder to reset.</param>
/// <param name="reversEncoder">Backward-direction encoder to reset.</param>
/// <param name="decoder">Attention decoder to reset.</param>
private void Reset(IWeightFactory weightFactory, Encoder encoder, Encoder reversEncoder, AttentionDecoder decoder)
{
    encoder.Reset(weightFactory);
    reversEncoder.Reset(weightFactory);
    decoder.Reset(weightFactory);
}