コード例 #1
0
        /// <summary>
        /// Extracts dense and sparse features from the source sequence and stores them in the
        /// <c>srcHiddenAvgOutput</c> and <c>srcSparseFeatures</c> members.
        /// </summary>
        /// <remarks>
        /// <c>srcHiddenAvgOutput</c> receives the first top-hidden-layer output followed by the last one,
        /// so it holds twice the dense feature size. <c>srcSparseFeatures</c> is cleared (or created) and
        /// then filled with the sum of every state's sparse feature values, keyed by the original feature
        /// index shifted by <paramref name="targetSparseFeatureSize"/>.
        /// NOTE(review): both members are declared outside this method; their declarations are not
        /// visible here.
        /// </remarks>
        /// <param name="decoder">Decoder whose <c>ComputeTopHiddenLayerOutput</c> yields the per-position outputs.</param>
        /// <param name="srcSequence">Source sequence to extract features from.</param>
        /// <param name="targetSparseFeatureSize">Offset added to every source sparse feature index.</param>
        protected virtual void ExtractSourceSentenceFeature(RNNDecoder decoder, Sequence srcSequence, int targetSparseFeatureSize)
        {
            //Extract dense features from source sequence
            var srcOutputs = decoder.ComputeTopHiddenLayerOutput(srcSequence);
            int srcSequenceDenseFeatureSize = srcOutputs[0].Length;
            int srcSequenceLength           = srcOutputs.Length - 1;

            if (srcHiddenAvgOutput == null)
            {
                srcHiddenAvgOutput = new float[srcSequenceDenseFeatureSize * 2];
            }

            float[] srcOutputForward  = srcOutputs[0];
            float[] srcOutputBackward = srcOutputs[srcSequenceLength];

            // Plain block copies instead of a Vector<float> loop: the Vector<float>(array, index)
            // constructor throws when fewer than Vector<float>.Count elements remain, which made the
            // previous loop fail for any feature size that is not a multiple of the SIMD width.
            System.Array.Copy(srcOutputForward, 0, srcHiddenAvgOutput, 0, srcSequenceDenseFeatureSize);
            System.Array.Copy(srcOutputBackward, 0, srcHiddenAvgOutput, srcSequenceDenseFeatureSize, srcSequenceDenseFeatureSize);

            //Extract sparse features from source sequence
            if (srcSparseFeatures == null)
            {
                srcSparseFeatures = new Dictionary <int, float>();
            }
            else
            {
                srcSparseFeatures.Clear();
            }

            for (var i = 0; i < srcSequence.States.Length; i++)
            {
                foreach (var kv in srcSequence.States[i].SparseFeature)
                {
                    // Shift source indices so they do not collide with the target feature range.
                    var srcSparseFeatureIndex = kv.Key + targetSparseFeatureSize;

                    // Single lookup: accumulate when the index was already seen, otherwise insert.
                    float accumulated;
                    if (srcSparseFeatures.TryGetValue(srcSparseFeatureIndex, out accumulated))
                    {
                        srcSparseFeatures[srcSparseFeatureIndex] = accumulated + kv.Value;
                    }
                    else
                    {
                        srcSparseFeatures.Add(srcSparseFeatureIndex, kv.Value);
                    }
                }
            }
        }
コード例 #2
0
        /// <summary>
        /// Creates a new <c>RNNDecoder</c>, stores it in <c>lmDecoder_rnn</c>, applies the fixed
        /// lambda / regularization / dynamic settings and then loads the language model file.
        /// </summary>
        /// <param name="strLMFileName">File name handed to <c>LoadLM</c>.</param>
        private void LoadLanguageModelRNN(string strLMFileName)
        {
            // Fixed settings applied to every decoder created here.
            const float RegularizationValue = 0.0000001f;
            const float DynamicValue        = 0;

            // The field is assigned before configuration, so it refers to the new decoder
            // even if one of the following calls throws.
            lmDecoder_rnn = new RNNDecoder();

            lmDecoder_rnn.setLambda(0.5);
            lmDecoder_rnn.setRegularization(RegularizationValue);
            lmDecoder_rnn.setDynamic(DynamicValue);

            lmDecoder_rnn.LoadLM(strLMFileName);
        }
コード例 #3
0
        /// <summary>
        /// Collects dense feature groups and merged sparse features from a source sequence.
        /// </summary>
        /// <remarks>
        /// Dense groups are appended in this order: the last top-hidden-layer output, then (when more
        /// than one group is configured) the first output, then (when more than two are configured)
        /// outputs sampled at a fixed stride until <c>numSrcDenseFeatureGroups</c> groups were added or
        /// the stride runs past the last position.
        /// </remarks>
        /// <param name="decoder">Decoder whose <c>ComputeTopHiddenLayerOutput</c> yields the per-position outputs.</param>
        /// <param name="srcSequence">Source sequence to extract features from.</param>
        /// <param name="targetSparseFeatureSize">Offset added to every source sparse feature index.</param>
        /// <param name="srcDenseFeatureGroups">List that the selected dense outputs are appended to.</param>
        /// <param name="srcSparseFeatures">Sparse vector that is resized and filled with the merged sparse features.</param>
        private void ExtractSourceSentenceFeature(RNNDecoder decoder, Sequence srcSequence, int targetSparseFeatureSize, List <float[]> srcDenseFeatureGroups, SparseVector srcSparseFeatures)
        {
            // ---- Dense features ----
            var hiddenOutputs = decoder.ComputeTopHiddenLayerOutput(srcSequence);
            int lastPosition  = hiddenOutputs.Length - 1;

            // The final output is always the first group.
            srcDenseFeatureGroups.Add(hiddenOutputs[lastPosition]);

            if (numSrcDenseFeatureGroups > 1)
            {
                srcDenseFeatureGroups.Add(hiddenOutputs[0]);
            }

            if (numSrcDenseFeatureGroups > 2)
            {
                // Stride between sampled positions, never smaller than one position.
                float stride = (float)lastPosition / (float)(numSrcDenseFeatureGroups - 1);
                if (stride < 1.0)
                {
                    stride = 1.0f;
                }

                for (float position = stride;
                     srcDenseFeatureGroups.Count < numSrcDenseFeatureGroups && position < lastPosition;
                     position += stride)
                {
                    srcDenseFeatureGroups.Add(hiddenOutputs[(int)position]);
                }
            }

            // ---- Sparse features ----
            var mergedSparse = new Dictionary <int, float>();

            for (var stateIndex = 0; stateIndex < srcSequence.States.Length; stateIndex++)
            {
                foreach (var feature in srcSequence.States[stateIndex].SparseFeature)
                {
                    var shiftedIndex = feature.Key + targetSparseFeatureSize;

                    // Sum values of features that share the same shifted index.
                    float current;
                    if (mergedSparse.TryGetValue(shiftedIndex, out current))
                    {
                        mergedSparse[shiftedIndex] = current + feature.Value;
                    }
                    else
                    {
                        mergedSparse.Add(shiftedIndex, feature.Value);
                    }
                }
            }

            srcSparseFeatures.SetLength(srcSequence.SparseFeatureSize + targetSparseFeatureSize);
            srcSparseFeatures.AddKeyValuePairData(mergedSparse);
        }
コード例 #4
0
        /// <summary>
        /// Decodes every record of the test corpus and writes the results to the output file.
        /// </summary>
        /// <remarks>
        /// Validates the tag, configuration, output and test file paths (logging an error and printing
        /// the usage text on failure), builds the decoder from the configuration, and then reads records
        /// until one with two or fewer token rows is encountered. Depending on <c>nBest</c> and the
        /// decoder's model type, each record is written as tagged tokens, as a source/result pair, or as
        /// <c>nBest</c> tagged hypotheses.
        /// </remarks>
        private static void Test()
        {
            if (string.IsNullOrEmpty(tagFilePath))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The tag mapping file {0} isn't specified.", tagFilePath);
                UsageTest();
                return;
            }

            //Load tag name
            Logger.WriteLine($"Loading tag file '{tagFilePath}'");
            var tagSet = new TagSet(tagFilePath);

            if (string.IsNullOrEmpty(configFilePath))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The configuration file {0} isn't specified.", configFilePath);
                UsageTest();
                return;
            }

            // IsNullOrEmpty also covers a null path, which previously raised NullReferenceException.
            if (string.IsNullOrEmpty(outputFilePath))
            {
                Logger.WriteLine(Logger.Level.err, "FAILED: The output file name should not be empty.");
                UsageTest();
                return;
            }

            //Create feature extractors and load word embedding data from file
            Logger.WriteLine($"Initializing config file = '{configFilePath}'");
            var config = new Config(configFilePath, tagSet);

            config.ShowFeatureSize();

            //Create instance for decoder
            Logger.WriteLine($"Loading model from {config.ModelFilePath} and creating decoder instance...");
            var decoder = new RNNDecoder(config);

            if (File.Exists(testFilePath) == false)
            {
                Logger.WriteLine(Logger.Level.err, $"FAILED: The test corpus {testFilePath} doesn't exist.");
                UsageTest();
                return;
            }

            // using blocks guarantee both streams are flushed and closed even if decoding throws.
            using (var sr = new StreamReader(testFilePath))
            using (var sw = new StreamWriter(outputFilePath))
            {
                while (true)
                {
                    var sent = new Sentence(ReadRecord(sr));
                    if (sent.TokensList.Count <= 2)
                    {
                        //No more record, it only contains <s> and </s>
                        break;
                    }

                    var sb = new StringBuilder();

                    if (nBest == 1)
                    {
                        //Output decoded result
                        if (decoder.ModelType == MODELTYPE.SeqLabel)
                        {
                            //Append the decoded result into the end of feature set of each token
                            var output = decoder.Process(sent);
                            for (var i = 0; i < sent.TokensList.Count; i++)
                            {
                                var tokens = string.Join("\t", sent.TokensList[i]);
                                sb.Append(tokens);
                                sb.Append("\t");
                                sb.Append(tagSet.GetTagName(output[i]));
                                sb.AppendLine();
                            }

                            sw.WriteLine(sb.ToString());
                        }
                        else
                        {
                            //Print out source sentence at first, and then generated result sentence
                            var output = decoder.ProcessSeq2Seq(sent);
                            for (var i = 0; i < sent.TokensList.Count; i++)
                            {
                                var tokens = string.Join("\t", sent.TokensList[i]);
                                sb.AppendLine(tokens);
                            }
                            sw.WriteLine(sb.ToString());
                            sw.WriteLine();

                            sb.Clear();
                            for (var i = 0; i < output.Length; i++)
                            {
                                var token = tagSet.GetTagName(output[i]);
                                sb.AppendLine(token);
                            }
                            sw.WriteLine(sb.ToString());
                            sw.WriteLine();
                        }
                    }
                    else
                    {
                        //Output every hypothesis: i selects the hypothesis, j selects the token
                        var output = decoder.ProcessNBest(sent, nBest);
                        for (var i = 0; i < nBest; i++)
                        {
                            for (var j = 0; j < sent.TokensList.Count; j++)
                            {
                                // Index the token row with j; indexing with the hypothesis index i
                                // printed the wrong row and could run past the end of TokensList.
                                var tokens = string.Join("\t", sent.TokensList[j]);
                                sb.Append(tokens);
                                sb.Append("\t");
                                sb.Append(tagSet.GetTagName(output[i][j]));
                                sb.AppendLine();
                            }
                            sb.AppendLine();
                        }

                        sw.WriteLine(sb.ToString());
                    }
                }
            }
        }