private IComputeGraph CreateComputGraph(int deviceIdIdx, bool needBack = true)
        {
            IComputeGraph g;

            if (m_archType == ArchTypeEnums.CPU_MKL)
            {
                g = new ComputeGraphMKL(m_weightFactory[deviceIdIdx], needBack);
            }
            else if (m_archType == ArchTypeEnums.GPU_CUDA)
            {
                g = new ComputeGraphTensor(m_weightFactory[deviceIdIdx], m_deviceIds[deviceIdIdx], needBack);
            }
            else
            {
                g = new ComputeGraph(m_weightFactory[deviceIdIdx], needBack);
            }

            return g;
        }
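This factory method selects a compute-graph backend per device: MKL on CPU, CUDA tensors on GPU, or the default managed implementation otherwise. A minimal call-site sketch follows, assuming the surrounding trainer loops over m_deviceIds and only enables needBack when gradients are required; the RunOnDevicesSketch name and the Parallel.For driver are illustrative assumptions, not part of the original code.

        // Hypothetical sketch: build one graph per device, with backward support only during training.
        private void RunOnDevicesSketch(bool isTraining)
        {
            Parallel.For(0, m_deviceIds.Length, deviceIdIdx =>
            {
                IComputeGraph g = CreateComputGraph(deviceIdIdx, needBack: isTraining);
                // ... run the forward pass on g here; run the backward pass only when isTraining is true ...
            });
        }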
Example #2
        public List <string> Predict(List <string> input)
        {
            reversEncoder.Reset();
            encoder.Reset();
            decoder.Reset();

            List <string> result = new List <string>();

#if MKL
            var G2 = new ComputeGraphMKL(false);
#else
            var G2 = new ComputeGraph(false);
#endif

            List <string> inputSeq = new List <string>();
            //    inputSeq.Add(m_START);
            inputSeq.AddRange(input);
            //    inputSeq.Add(m_END);

            List <string> revseq = inputSeq.ToList();
            revseq.Reverse();

            List <WeightMatrix> forwardEncoded  = new List <WeightMatrix>();
            List <WeightMatrix> backwardEncoded = new List <WeightMatrix>();
            List <WeightMatrix> encoded         = new List <WeightMatrix>();
            SparseWeightMatrix  sparseInput     = new SparseWeightMatrix(1, s_Embedding.Columns);

            Parallel.Invoke(
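                // Two encoder passes run concurrently: the first lambda encodes the tokens in
                // their original order and also accumulates the sparse bag-of-words input for
                // the decoder; the second encodes the reversed copy of the sequence.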
                () =>
            {
                for (int i = 0; i < inputSeq.Count; i++)
                {
                    int ix = (int)SENTTAGS.UNK;
                    if (s_wordToIndex.ContainsKey(inputSeq[i]) == false)
                    {
                        Logger.WriteLine($"Unknown input word: {inputSeq[i]}");
                    }
                    else
                    {
                        ix = s_wordToIndex[inputSeq[i]];
                    }

                    var x2 = G2.PeekRow(s_Embedding, ix);
                    var o  = encoder.Encode(x2, G2);
                    forwardEncoded.Add(o);

                    sparseInput.AddWeight(0, ix, 1.0f);
                }
            },
                () =>
            {
                for (int i = 0; i < inputSeq.Count; i++)
                {
                    int ix = (int)SENTTAGS.UNK;
                    if (s_wordToIndex.ContainsKey(revseq[i]) == false)
                    {
                        Logger.WriteLine($"Unknown input word: {revseq[i]}");
                    }
                    else
                    {
                        ix = s_wordToIndex[revseq[i]];
                    }

                    var x2 = G2.PeekRow(s_Embedding, ix);
                    var o  = reversEncoder.Encode(x2, G2);
                    backwardEncoded.Add(o);
                }
            });

            backwardEncoded.Reverse();
            for (int i = 0; i < inputSeq.Count; i++)
            {
                //encoded.Add(G2.concatColumns(forwardEncoded[i], backwardEncoded[i]));
                encoded.Add(G2.add(forwardEncoded[i], backwardEncoded[i]));
            }

            //if (UseDropout)
            //{
            //    for (int i = 0; i < encoded.Weight.Length; i++)
            //    {
            //        encoded.Weight[i] *= 0.2;
            //    }
            //}
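            // Greedy decoding: start from the START tag and keep feeding the previous
            // prediction back into the decoder until END is predicted or max_word is exceeded.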
            var ix_input = (int)SENTTAGS.START;
            while (true)
            {
                var x       = G2.PeekRow(t_Embedding, ix_input);
                var eOutput = decoder.Decode(sparseInput, x, encoded, G2);
                if (UseDropout)
                {
                    for (int i = 0; i < eOutput.Weight.Length; i++)
                    {
                        eOutput.Weight[i] *= 0.2f;
                    }
                }
                var o = G2.muladd(eOutput, this.Whd, this.bd);
                if (UseDropout)
                {
                    for (int i = 0; i < o.Weight.Length; i++)
                    {
                        o.Weight[i] *= 0.2f;
                    }
                }
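                // Pick the most probable target word index (greedy argmax over the softmax output).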
                var probs = G2.SoftmaxWithCrossEntropy(o);
                var maxv  = probs.Weight[0];
                var maxi  = 0;
                for (int i = 1; i < probs.Weight.Length; i++)
                {
                    if (probs.Weight[i] > maxv)
                    {
                        maxv = probs.Weight[i];
                        maxi = i;
                    }
                }
                var pred = maxi;

                if (pred == (int)SENTTAGS.END)
                {
                    break;                            // END token predicted, break out
                }
                if (result.Count > max_word)
                {
                    break;                              // something is wrong
                }

                var letter2 = m_UNK;
                if (t_indexToWord.ContainsKey(pred))
                {
                    letter2 = t_indexToWord[pred];
                }

                result.Add(letter2);
                ix_input = pred;
            }

            return result;
        }
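Predict performs greedy, token-by-token decoding for a single source sentence. A minimal usage sketch, assuming the model has already been loaded and the input is whitespace-tokenized; the Translate wrapper below is an illustrative assumption, not part of the original code.

        // Hypothetical wrapper: tokenize a source sentence, run greedy decoding, and join the output.
        public string Translate(string sentence)
        {
            List<string> tokens = sentence.Split(' ').ToList();
            List<string> outputTokens = Predict(tokens);
            return string.Join(" ", outputTokens);
        }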
Example #3
        private IComputeGraph Encode(List <string> inputSentence, out float cost, out SparseWeightMatrix sWM, List <WeightMatrix> encoded, Encoder encoder, Encoder reversEncoder, WeightMatrix Embedding)
        {
            var reversSentence = inputSentence.ToList();

            reversSentence.Reverse();

#if MKL
            IComputeGraph g = new ComputeGraphMKL();
#else
            IComputeGraph g = new ComputeGraph();
#endif


            cost = 0.0f;
            SparseWeightMatrix  tmpSWM          = new SparseWeightMatrix(1, Embedding.Columns);
            List <WeightMatrix> forwardOutputs  = new List <WeightMatrix>();
            List <WeightMatrix> backwardOutputs = new List <WeightMatrix>();

            Parallel.Invoke(
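                // As in Predict, the forward and reversed encoder passes run concurrently;
                // unknown words fall back to the UNK index, and the forward pass builds the
                // sparse source-word vector that is handed back through sWM.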
                () =>
            {
                for (int i = 0; i < inputSentence.Count; i++)
                {
                    int ix_source = (int)SENTTAGS.UNK;

                    if (s_wordToIndex.ContainsKey(inputSentence[i]))
                    {
                        ix_source = s_wordToIndex[inputSentence[i]];
                    }
                    var x       = g.PeekRow(Embedding, ix_source);
                    var eOutput = encoder.Encode(x, g);
                    forwardOutputs.Add(eOutput);

                    tmpSWM.AddWeight(0, ix_source, 1.0f);
                }
            },
                () =>
            {
                for (int i = 0; i < inputSentence.Count; i++)
                {
                    int ix_source2 = (int)SENTTAGS.UNK;

                    if (s_wordToIndex.ContainsKey(reversSentence[i]))
                    {
                        ix_source2 = s_wordToIndex[reversSentence[i]];
                    }

                    var x2       = g.PeekRow(Embedding, ix_source2);
                    var eOutput2 = reversEncoder.Encode(x2, g);
                    backwardOutputs.Add(eOutput2);
                }
            });

            backwardOutputs.Reverse();

            for (int i = 0; i < inputSentence.Count; i++)
            {
                //encoded.Add(g.concatColumns(forwardOutputs[i], backwardOutputs[i]));
                encoded.Add(g.add(forwardOutputs[i], backwardOutputs[i]));
            }

            sWM = tmpSWM;

            return g;
        }
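Encode runs the bi-directional pass over one training sentence and returns the graph so that decoding and backpropagation can continue on it. A rough call-site sketch under that assumption; the field names encoder, reversEncoder and s_Embedding are taken from the snippets above, and the method name is illustrative.

        // Hypothetical call-site sketch: encode a source sentence before decoding and backprop.
        private void EncodeSentenceSketch(List<string> srcTokens)
        {
            List<WeightMatrix> encodedStates = new List<WeightMatrix>();
            IComputeGraph g = Encode(srcTokens, out float cost, out SparseWeightMatrix sparseSrc,
                                     encodedStates, encoder, reversEncoder, s_Embedding);
            // ... decode against encodedStates on the same graph, then run the backward pass when training ...
        }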