Code Example #1
 public virtual bool Parse <_T0>(IList <_T0> sentence)
     where _T0 : IHasWord
 {
     this.originalSentence = sentence;
     initialState          = ShiftReduceParser.InitialStateFromTaggedSentence(sentence);
     return(ParseInternal());
 }
Code Example #2
 public virtual void TestReorderIncorrectShiftResultingTree()
 {
     for (int testcase = 0; testcase < correctTrees.Length; ++testcase)
     {
         State state = ShiftReduceParser.InitialStateFromGoldTagTree(correctTrees[testcase]);
         IList <ITransition> gold = CreateTransitionSequence.CreateTransitionSequence(binarizedTrees[testcase]);
         // System.err.println(correctTrees[testcase]);
         // System.err.println(gold);
         int tnum = 0;
         for (; tnum < gold.Count; ++tnum)
         {
             if (gold[tnum] is BinaryTransition)
             {
                 break;
             }
             state = gold[tnum].Apply(state);
         }
         state = shift.Apply(state);
         IList <ITransition> reordered = Generics.NewLinkedList(gold.SubList(tnum, gold.Count));
         NUnit.Framework.Assert.IsTrue(ReorderingOracle.ReorderIncorrectShiftTransition(reordered));
         // System.err.println(reordered);
         foreach (ITransition transition in reordered)
         {
             state = transition.Apply(state);
         }
         Tree debinarized = debinarizer.TransformTree(state.stack.Peek());
         // System.err.println(debinarized);
         NUnit.Framework.Assert.AreEqual(incorrectShiftTrees[testcase].ToString(), debinarized.ToString());
     }
 }
Code Example #3
        protected virtual void SetUp()
        {
            Options  op       = new Options();
            Treebank treebank = op.tlpParams.MemoryTreebank();

            Sharpen.Collections.AddAll(treebank, Arrays.AsList(correctTrees));
            binarizedTrees = ShiftReduceParser.BinarizeTreebank(treebank, op);
        }
Code Example #4
        public static ParserGrammar LoadModel(string path, params string[] extraFlags)
        {
            ShiftReduceParser parser = IOUtils.ReadObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(log, "Loading parser from serialized file", path);

            if (extraFlags.Length > 0)
            {
                parser.SetOptionFlags(extraFlags);
            }
            return(parser);
        }
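A hedged usage sketch for the loader above (not part of the original example set): the model path is a hypothetical placeholder, and the static call through ShiftReduceParser mirrors the call sites in Code Example #14.

 // Hedged usage sketch.  The path is a placeholder; any extra option flags
 // would be passed as additional string arguments and forwarded to SetOptionFlags.
 ParserGrammar model = ShiftReduceParser.LoadModel("path/to/model.ser.gz");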
Code Example #5
        // A small variety of trees to test on, especially with different depths of unary transitions
        public virtual IList <Tree> BuildTestTreebank()
        {
            MemoryTreebank treebank = new MemoryTreebank();

            foreach (string text in TestTrees)
            {
                Tree tree = Tree.ValueOf(text);
                treebank.Add(tree);
            }
            IList <Tree> binarizedTrees = ShiftReduceParser.BinarizeTreebank(treebank, new Options());

            return(binarizedTrees);
        }
Code Example #6
 public virtual void TestCompoundUnaryTransitions()
 {
     foreach (string treeText in treeStrings)
     {
         Tree tree = ConvertTree(treeText);
         IList <ITransition> transitions = CreateTransitionSequence.CreateTransitionSequence(tree, true, Java.Util.Collections.Singleton("ROOT"), Java.Util.Collections.Singleton("ROOT"));
         State state = ShiftReduceParser.InitialStateFromGoldTagTree(tree);
         foreach (ITransition transition in transitions)
         {
             state = transition.Apply(state);
         }
         NUnit.Framework.Assert.AreEqual(tree, state.stack.Peek());
     }
 }
Code Example #7
        public virtual void TestSeparators()
        {
            Tree tree = ConvertTree(commaTreeString);
            IList <ITransition> transitions         = CreateTransitionSequence.CreateTransitionSequence(tree, true, Java.Util.Collections.Singleton("ROOT"), Java.Util.Collections.Singleton("ROOT"));
            IList <string>      expectedTransitions = Arrays.AsList(new string[] { "Shift", "Shift", "Shift", "Shift", "RightBinary(@ADJP)", "RightBinary(ADJP)", "Shift", "RightBinary(@NP)", "RightBinary(NP)", "CompoundUnary*([ROOT, FRAG])", "Finalize", "Idle" });

            // NOTE: the transform argument came out as "null" in the automatic
            // Java-to-C# conversion; a ToString mapping is assumed here, since the
            // expected values are the string forms of the transitions.
            NUnit.Framework.Assert.AreEqual(expectedTransitions, CollectionUtils.TransformAsList(transitions, x => x.ToString()));
            string expectedSeparators = "[{2=,}]";
            State  state = ShiftReduceParser.InitialStateFromGoldTagTree(tree);

            NUnit.Framework.Assert.AreEqual(1, state.separators.Count);
            NUnit.Framework.Assert.AreEqual(2, state.separators.FirstKey());
            NUnit.Framework.Assert.AreEqual(",", state.separators[2]);
        }
Code Example #8
 public static void RunEndToEndTest(IList <Tree> binarizedTrees, Oracle oracle)
 {
     for (int index = 0; index < binarizedTrees.Count; ++index)
     {
         State state = ShiftReduceParser.InitialStateFromGoldTagTree(binarizedTrees[index]);
         while (!state.IsFinished())
         {
             OracleTransition gold = oracle.GoldTransition(index, state);
             NUnit.Framework.Assert.IsTrue(gold.transition != null);
             state = gold.transition.Apply(state);
         }
         NUnit.Framework.Assert.AreEqual(binarizedTrees[index], state.stack.Peek());
     }
 }
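A hedged wiring sketch (not from the original page) showing how the end-to-end helper above could be driven by the binarized treebank from Code Example #5; the Oracle arguments mirror the constructor call in Code Example #13, with the compound-unary flag and the singleton "ROOT" root-state set assumed from the other examples.

 // Hedged wiring sketch, assuming this runs inside the same test fixture as
 // BuildTestTreebank (Code Example #5).  The Oracle arguments follow the
 // constructor call in Code Example #13; "true" for compound unaries and the
 // singleton "ROOT" set are assumptions taken from Code Examples #6 and #7.
 IList <Tree> testTrees = BuildTestTreebank();
 Oracle oracle = new Oracle(testTrees, true, Java.Util.Collections.Singleton("ROOT"));
 RunEndToEndTest(testTrees, oracle);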
Code Example #9
        public virtual void TestTransition()
        {
            string[] words = new string[] { "This", "is", "a", "short", "test", "." };
            string[] tags  = new string[] { "DT", "VBZ", "DT", "JJ", "NN", "." };
            NUnit.Framework.Assert.AreEqual(words.Length, tags.Length);
            IList <TaggedWord> sentence = SentenceUtils.ToTaggedList(Arrays.AsList(words), Arrays.AsList(tags));
            State           state       = ShiftReduceParser.InitialStateFromTaggedSentence(sentence);
            ShiftTransition shift       = new ShiftTransition();

            for (int i = 0; i < 3; ++i)
            {
                state = shift.Apply(state);
            }
            NUnit.Framework.Assert.AreEqual(3, state.tokenPosition);
        }
Code Example #10
        public virtual void TestInitialStateFromTagged()
        {
            string[] words = new string[] { "This", "is", "a", "short", "test", "." };
            string[] tags  = new string[] { "DT", "VBZ", "DT", "JJ", "NN", "." };
            NUnit.Framework.Assert.AreEqual(words.Length, tags.Length);
            IList <TaggedWord> sentence = SentenceUtils.ToTaggedList(Arrays.AsList(words), Arrays.AsList(tags));
            State state = ShiftReduceParser.InitialStateFromTaggedSentence(sentence);

            for (int i = 0; i < words.Length; ++i)
            {
                NUnit.Framework.Assert.AreEqual(tags[i], state.sentence[i].Value());
                NUnit.Framework.Assert.AreEqual(1, state.sentence[i].Children().Length);
                NUnit.Framework.Assert.AreEqual(words[i], state.sentence[i].Children()[0].Value());
            }
        }
Code Example #11
 /// <summary>
 /// Will train the model on the given treebank, using devTreebank as
 /// a dev set.
 /// </summary>
 /// <remarks>
 /// Will train the model on the given treebank, using devTreebank as
 /// a dev set.  If op.retrainAfterCutoff is set, will rerun training
 /// after the first time through on a limited set of features.
 /// </remarks>
 public override void TrainModel(string serializedPath, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, Random random, IList <Tree> binarizedTrees, IList <IList <ITransition> > transitionLists, Treebank devTreebank, int nThreads)
 {
     if (op.TrainOptions().retrainAfterCutoff && op.TrainOptions().featureFrequencyCutoff > 0)
     {
         string tempName = Sharpen.Runtime.Substring(serializedPath, 0, serializedPath.Length - 7) + "-" + "temp.ser.gz";
         TrainModel(tempName, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, null);
         ShiftReduceParser temp = new ShiftReduceParser(op, this);
         temp.SaveModel(tempName);
         ICollection <string> features = featureWeights.Keys;
         featureWeights = Generics.NewHashMap();
         TrainModel(serializedPath, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, features);
     }
     else
     {
         TrainModel(serializedPath, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, null);
     }
 }
Code Example #12
        public virtual void TestBinarySide()
        {
            string[] words = new string[] { "This", "is", "a", "short", "test", "." };
            string[] tags  = new string[] { "DT", "VBZ", "DT", "JJ", "NN", "." };
            NUnit.Framework.Assert.AreEqual(words.Length, tags.Length);
            IList <TaggedWord> sentence = SentenceUtils.ToTaggedList(Arrays.AsList(words), Arrays.AsList(tags));
            State           state       = ShiftReduceParser.InitialStateFromTaggedSentence(sentence);
            ShiftTransition shift       = new ShiftTransition();

            state = shift.Apply(shift.Apply(state));
            BinaryTransition transition = new BinaryTransition("NP", BinaryTransition.Side.Right);
            State            next       = transition.Apply(state);

            NUnit.Framework.Assert.AreEqual(BinaryTransition.Side.Right, ShiftReduceUtils.GetBinarySide(next.stack.Peek()));
            transition = new BinaryTransition("NP", BinaryTransition.Side.Left);
            next       = transition.Apply(state);
            NUnit.Framework.Assert.AreEqual(BinaryTransition.Side.Left, ShiftReduceUtils.GetBinarySide(next.stack.Peek()));
        }
Code Example #13
        private void TrainModel(string serializedPath, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, Random random, IList <Tree> binarizedTrees, IList <IList <ITransition> > transitionLists, Treebank devTreebank, int nThreads, ICollection <string> allowedFeatures)
        {
            double bestScore     = 0.0;
            int    bestIteration = 0;
            PriorityQueue <ScoredObject <PerceptronModel> > bestModels = null;

            if (op.TrainOptions().averagedModels > 0)
            {
                bestModels = new PriorityQueue <ScoredObject <PerceptronModel> >(op.TrainOptions().averagedModels + 1, ScoredComparator.AscendingComparator);
            }
            IList <int> indices = Generics.NewArrayList();

            for (int i = 0; i < binarizedTrees.Count; ++i)
            {
                indices.Add(i);
            }
            Oracle oracle = null;

            if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.Oracle)
            {
                oracle = new Oracle(binarizedTrees, op.compoundUnaries, rootStates);
            }
            IList <PerceptronModel.Update>           updates = Generics.NewArrayList();
            MulticoreWrapper <int, Pair <int, int> > wrapper = null;

            if (nThreads != 1)
            {
                updates = Java.Util.Collections.SynchronizedList(updates);
                wrapper = new MulticoreWrapper <int, Pair <int, int> >(op.trainOptions.trainingThreads, new PerceptronModel.TrainTreeProcessor(this, binarizedTrees, transitionLists, updates, oracle));
            }
            IntCounter <string> featureFrequencies = null;

            if (op.TrainOptions().featureFrequencyCutoff > 1)
            {
                featureFrequencies = new IntCounter <string>();
            }
            for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration)
            {
                Timing trainingTimer = new Timing();
                int    numCorrect    = 0;
                int    numWrong      = 0;
                Java.Util.Collections.Shuffle(indices, random);
                for (int start = 0; start < indices.Count; start += op.trainOptions.batchSize)
                {
                    int end = Math.Min(start + op.trainOptions.batchSize, indices.Count);
                    Triple <IList <PerceptronModel.Update>, int, int> result = TrainBatch(indices.SubList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
                    numCorrect += result.second;
                    numWrong   += result.third;
                    foreach (PerceptronModel.Update update in result.first)
                    {
                        foreach (string feature in update.features)
                        {
                            if (allowedFeatures != null && !allowedFeatures.Contains(feature))
                            {
                                continue;
                            }
                            Weight weights = featureWeights[feature];
                            if (weights == null)
                            {
                                weights = new Weight();
                                featureWeights[feature] = weights;
                            }
                            weights.UpdateWeight(update.goldTransition, update.delta);
                            weights.UpdateWeight(update.predictedTransition, -update.delta);
                            if (featureFrequencies != null)
                            {
                                featureFrequencies.IncrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
                            }
                        }
                    }
                    updates.Clear();
                }
                trainingTimer.Done("Iteration " + iteration);
                log.Info("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
                OutputStats();
                double labelF1 = 0.0;
                if (devTreebank != null)
                {
                    EvaluateTreebank evaluator = new EvaluateTreebank(op, null, new ShiftReduceParser(op, this), tagger);
                    evaluator.TestOnTreebank(devTreebank);
                    labelF1 = evaluator.GetLBScore();
                    log.Info("Label F1 after " + iteration + " iterations: " + labelF1);
                    if (labelF1 > bestScore)
                    {
                        log.Info("New best dev score (previous best " + bestScore + ")");
                        bestScore     = labelF1;
                        bestIteration = iteration;
                    }
                    else
                    {
                        log.Info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                        if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit))
                        {
                            log.Info("Failed to improve for too long, stopping training");
                            break;
                        }
                    }
                    log.Info();
                    if (bestModels != null)
                    {
                        bestModels.Add(new ScoredObject <PerceptronModel>(new PerceptronModel(this), labelF1));
                        if (bestModels.Count > op.TrainOptions().averagedModels)
                        {
                            bestModels.Poll();
                        }
                    }
                }
                if (op.TrainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0)
                {
                    string            tempName = Sharpen.Runtime.Substring(serializedPath, 0, serializedPath.Length - 7) + "-" + Filename.Format(iteration) + "-" + Nf.Format(labelF1) + ".ser.gz";
                    ShiftReduceParser temp     = new ShiftReduceParser(op, this);
                    temp.SaveModel(tempName);
                }
                // TODO: we could save a cutoff version of the model,
                // especially if we also get a dev set number for it, but that
                // might be overkill
                if (iteration % 10 == 0 && op.TrainOptions().decayLearningRate > 0.0)
                {
                    learningRate *= op.TrainOptions().decayLearningRate;
                }
            }
            // end for iterations
            if (wrapper != null)
            {
                wrapper.Join();
            }
            if (bestModels != null)
            {
                if (op.TrainOptions().cvAveragedModels && devTreebank != null)
                {
                    IList <ScoredObject <PerceptronModel> > models = Generics.NewArrayList();
                    while (bestModels.Count > 0)
                    {
                        models.Add(bestModels.Poll());
                    }
                    Java.Util.Collections.Reverse(models);
                    double bestF1   = 0.0;
                    int    bestSize = 0;
                    for (int i_1 = 1; i_1 <= models.Count; ++i_1)
                    {
                        log.Info("Testing with " + i_1 + " models averaged together");
                        // TODO: this is kind of ugly, would prefer a separate object
                        AverageScoredModels(models.SubList(0, i_1));
                        ShiftReduceParser temp      = new ShiftReduceParser(op, this);
                        EvaluateTreebank  evaluator = new EvaluateTreebank(temp.GetOp(), null, temp, tagger);
                        evaluator.TestOnTreebank(devTreebank);
                        double labelF1 = evaluator.GetLBScore();
                        log.Info("Label F1 for " + i_1 + " models: " + labelF1);
                        if (labelF1 > bestF1)
                        {
                            bestF1   = labelF1;
                            bestSize = i_1;
                        }
                    }
                    AverageScoredModels(models.SubList(0, bestSize));
                }
                else
                {
                    AverageScoredModels(bestModels);
                }
            }
            // TODO: perhaps we should filter the features and then get dev
            // set scores.  That way we can merge the models which are best
            // after filtering.
            if (featureFrequencies != null)
            {
                FilterFeatures(featureFrequencies.KeysAbove(op.TrainOptions().featureFrequencyCutoff));
            }
            CondenseFeatures();
        }
Code Example #14
        public static void Main(string[] args)
        {
            IList <string> remainingArgs = Generics.NewArrayList();
            IList <Pair <string, IFileFilter> > trainTreebankPath = null;
            Pair <string, IFileFilter>          testTreebankPath  = null;
            Pair <string, IFileFilter>          devTreebankPath   = null;
            string serializedPath   = null;
            string tlppClass        = null;
            string continueTraining = null;

            for (int argIndex = 0; argIndex < args.Length;)
            {
                if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-trainTreebank"))
                {
                    if (trainTreebankPath == null)
                    {
                        trainTreebankPath = Generics.NewArrayList();
                    }
                    trainTreebankPath.Add(ArgUtils.GetTreebankDescription(args, argIndex, "-trainTreebank"));
                    argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                }
                else
                {
                    if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-testTreebank"))
                    {
                        testTreebankPath = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
                        argIndex         = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                    }
                    else
                    {
                        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-devTreebank"))
                        {
                            devTreebankPath = ArgUtils.GetTreebankDescription(args, argIndex, "-devTreebank");
                            argIndex        = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                        }
                        else
                        {
                            if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-serializedPath") || Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-model"))
                            {
                                serializedPath = args[argIndex + 1];
                                argIndex      += 2;
                            }
                            else
                            {
                                if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-tlpp"))
                                {
                                    tlppClass = args[argIndex + 1];
                                    argIndex += 2;
                                }
                                else
                                {
                                    if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-continueTraining"))
                                    {
                                        continueTraining = args[argIndex + 1];
                                        argIndex        += 2;
                                    }
                                    else
                                    {
                                        remainingArgs.Add(args[argIndex]);
                                        ++argIndex;
                                    }
                                }
                            }
                        }
                    }
                }
            }
            string[] newArgs = new string[remainingArgs.Count];
            newArgs = Sharpen.Collections.ToArray(remainingArgs, newArgs);
            if (trainTreebankPath == null && serializedPath == null)
            {
                throw new ArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
            }
            ShiftReduceParser parser = null;

            if (trainTreebankPath != null)
            {
                log.Info("Training ShiftReduceParser");
                log.Info("Initial arguments:");
                log.Info("   " + StringUtils.Join(args));
                if (continueTraining != null)
                {
                    parser = ((ShiftReduceParser)ShiftReduceParser.LoadModel(continueTraining, ArrayUtils.Concatenate(ForceTags, newArgs)));
                }
                else
                {
                    ShiftReduceOptions op = BuildTrainingOptions(tlppClass, newArgs);
                    parser = new ShiftReduceParser(op);
                }
                parser.Train(trainTreebankPath, devTreebankPath, serializedPath);
                parser.SaveModel(serializedPath);
            }
            if (serializedPath != null && parser == null)
            {
                parser = ((ShiftReduceParser)ShiftReduceParser.LoadModel(serializedPath, ArrayUtils.Concatenate(ForceTags, newArgs)));
            }
            //parser.outputStats();
            if (testTreebankPath != null)
            {
                log.Info("Loading test trees from " + testTreebankPath.First());
                Treebank testTreebank = parser.op.tlpParams.MemoryTreebank();
                testTreebank.LoadPath(testTreebankPath.First(), testTreebankPath.Second());
                log.Info("Loaded " + testTreebank.Count + " trees");
                EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser);
                evaluator.TestOnTreebank(testTreebank);
            }
        }
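To make the flag handling above concrete, a hedged invocation sketch (not from the original page); every path below is a hypothetical placeholder, and the flags are exactly the ones parsed by the loop in Main.

 // Hedged invocation sketch.  Paths are placeholders; the flags are the ones
 // recognized by Main above: training, dev and test treebanks plus the output
 // location for the serialized model.
 Main(new string[] {
     "-trainTreebank", "path/to/train-treebank",
     "-devTreebank", "path/to/dev-treebank",
     "-testTreebank", "path/to/test-treebank",
     "-serializedPath", "path/to/model.ser.gz"
 });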
Code Example #15
        private Pair <int, int> TrainTree(int index, IList <Tree> binarizedTrees, IList <IList <ITransition> > transitionLists, IList <PerceptronModel.Update> updates, Oracle oracle)
        {
            int              numCorrect = 0;
            int              numWrong   = 0;
            Tree             tree       = binarizedTrees[index];
            ReorderingOracle reorderer  = null;

            if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ReorderOracle || op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam)
            {
                reorderer = new ReorderingOracle(op);
            }
            // TODO.  This training method seems to be working in that it
            // trains models just like the gold and early termination methods do.
            // However, it causes the feature space to go crazy.  Presumably
            // leaving out features with low weights or low frequencies would
            // significantly help with that.  Otherwise, not sure how to keep
            // it under control.
            if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.Oracle)
            {
                State state = ShiftReduceParser.InitialStateFromGoldTagTree(tree);
                while (!state.IsFinished())
                {
                    IList <string>     features   = featureFactory.Featurize(state);
                    ScoredObject <int> prediction = FindHighestScoringTransition(state, features, true);
                    if (prediction == null)
                    {
                        throw new AssertionError("Did not find a legal transition");
                    }
                    int              predictedNum = prediction.Object();
                    ITransition      predicted    = transitionIndex.Get(predictedNum);
                    OracleTransition gold         = oracle.GoldTransition(index, state);
                    if (gold.IsCorrect(predicted))
                    {
                        numCorrect++;
                        if (gold.transition != null && !gold.transition.Equals(predicted))
                        {
                            int transitionNum = transitionIndex.IndexOf(gold.transition);
                            if (transitionNum < 0)
                            {
                                // TODO: do we want to add unary transitions which are
                                // only possible when the parser has gone off the rails?
                                continue;
                            }
                            updates.Add(new PerceptronModel.Update(features, transitionNum, -1, learningRate));
                        }
                    }
                    else
                    {
                        numWrong++;
                        int transitionNum = -1;
                        if (gold.transition != null)
                        {
                            transitionNum = transitionIndex.IndexOf(gold.transition);
                        }
                        // TODO: this can theoretically result in a -1 gold
                        // transition if the transition exists, but is a
                        // CompoundUnaryTransition which only exists because the
                        // parser is wrong.  Do we want to add those transitions?
                        updates.Add(new PerceptronModel.Update(features, transitionNum, predictedNum, learningRate));
                    }
                    state = predicted.Apply(state);
                }
            }
            else
            {
                if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.Beam || op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam)
                {
                    if (op.TrainOptions().beamSize <= 0)
                    {
                        throw new ArgumentException("Illegal beam size " + op.TrainOptions().beamSize);
                    }
                    IList <ITransition>   transitions = Generics.NewLinkedList(transitionLists[index]);
                    PriorityQueue <State> agenda      = new PriorityQueue <State>(op.TrainOptions().beamSize + 1, ScoredComparator.AscendingComparator);
                    State goldState = ShiftReduceParser.InitialStateFromGoldTagTree(tree);
                    agenda.Add(goldState);
                    // int transitionCount = 0;
                    while (transitions.Count > 0)
                    {
                        ITransition           goldTransition = transitions[0];
                        ITransition           highestScoringTransitionFromGoldState = null;
                        double                highestScoreFromGoldState             = 0.0;
                        PriorityQueue <State> newAgenda = new PriorityQueue <State>(op.TrainOptions().beamSize + 1, ScoredComparator.AscendingComparator);
                        State highestScoringState       = null;
                        State highestCurrentState       = null;
                        foreach (State currentState in agenda)
                        {
                            bool           isGoldState = (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam && goldState.AreTransitionsEqual(currentState));
                            IList <string> features    = featureFactory.Featurize(currentState);
                            ICollection <ScoredObject <int> > stateTransitions = FindHighestScoringTransitions(currentState, features, true, op.TrainOptions().beamSize, null);
                            foreach (ScoredObject <int> transition in stateTransitions)
                            {
                                State newState = transitionIndex.Get(transition.Object()).Apply(currentState, transition.Score());
                                newAgenda.Add(newState);
                                if (newAgenda.Count > op.TrainOptions().beamSize)
                                {
                                    newAgenda.Poll();
                                }
                                if (highestScoringState == null || highestScoringState.Score() < newState.Score())
                                {
                                    highestScoringState = newState;
                                    highestCurrentState = currentState;
                                }
                                if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.Score() > highestScoreFromGoldState))
                                {
                                    highestScoringTransitionFromGoldState = transitionIndex.Get(transition.Object());
                                    highestScoreFromGoldState             = transition.Score();
                                }
                            }
                        }
                        // This can happen if the REORDER_BEAM method backs itself
                        // into a corner, such as transitioning to something that
                        // can't have a FinalizeTransition applied.  This doesn't
                        // happen for the BEAM method because in that case the correct
                        // state (eg one with ROOT) isn't on the agenda so it stops.
                        if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam && highestScoringTransitionFromGoldState == null)
                        {
                            break;
                        }
                        State newGoldState = goldTransition.Apply(goldState, 0.0);
                        // if highest scoring state used the correct transition, no training
                        // otherwise, down the last transition, up the correct
                        if (!newGoldState.AreTransitionsEqual(highestScoringState))
                        {
                            ++numWrong;
                            IList <string> goldFeatures   = featureFactory.Featurize(goldState);
                            int            lastTransition = transitionIndex.IndexOf(highestScoringState.transitions.Peek());
                            updates.Add(new PerceptronModel.Update(featureFactory.Featurize(highestCurrentState), -1, lastTransition, learningRate));
                            updates.Add(new PerceptronModel.Update(goldFeatures, transitionIndex.IndexOf(goldTransition), -1, learningRate));
                            if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.Beam)
                            {
                                // If the correct state has fallen off the agenda, break
                                if (!ShiftReduceUtils.FindStateOnAgenda(newAgenda, newGoldState))
                                {
                                    break;
                                }
                                else
                                {
                                    transitions.Remove(0);
                                }
                            }
                            else
                            {
                                if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam)
                                {
                                    if (!ShiftReduceUtils.FindStateOnAgenda(newAgenda, newGoldState))
                                    {
                                        if (!reorderer.Reorder(goldState, highestScoringTransitionFromGoldState, transitions))
                                        {
                                            break;
                                        }
                                        newGoldState = highestScoringTransitionFromGoldState.Apply(goldState);
                                        if (!ShiftReduceUtils.FindStateOnAgenda(newAgenda, newGoldState))
                                        {
                                            break;
                                        }
                                    }
                                    else
                                    {
                                        transitions.Remove(0);
                                    }
                                }
                            }
                        }
                        else
                        {
                            ++numCorrect;
                            transitions.Remove(0);
                        }
                        goldState = newGoldState;
                        agenda    = newAgenda;
                    }
                }
                else
                {
                    if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ReorderOracle || op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EarlyTermination || op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.Gold)
                    {
                        State state = ShiftReduceParser.InitialStateFromGoldTagTree(tree);
                        IList <ITransition> transitions = transitionLists[index];
                        transitions = Generics.NewLinkedList(transitions);
                        bool keepGoing = true;
                        while (transitions.Count > 0 && keepGoing)
                        {
                            ITransition    transition    = transitions[0];
                            int            transitionNum = transitionIndex.IndexOf(transition);
                            IList <string> features      = featureFactory.Featurize(state);
                            int            predictedNum  = FindHighestScoringTransition(state, features, false).Object();
                            ITransition    predicted     = transitionIndex.Get(predictedNum);
                            if (transitionNum == predictedNum)
                            {
                                transitions.Remove(0);
                                state = transition.Apply(state);
                                numCorrect++;
                            }
                            else
                            {
                                numWrong++;
                                // TODO: allow weighted features, weighted training, etc
                                updates.Add(new PerceptronModel.Update(features, transitionNum, predictedNum, learningRate));
                                switch (op.TrainOptions().trainingMethod)
                                {
                                case ShiftReduceTrainOptions.TrainingMethod.EarlyTermination:
                                {
                                    keepGoing = false;
                                    break;
                                }

                                case ShiftReduceTrainOptions.TrainingMethod.Gold:
                                {
                                    transitions.Remove(0);
                                    state = transition.Apply(state);
                                    break;
                                }

                                case ShiftReduceTrainOptions.TrainingMethod.ReorderOracle:
                                {
                                    keepGoing = reorderer.Reorder(state, predicted, transitions);
                                    if (keepGoing)
                                    {
                                        state = predicted.Apply(state);
                                    }
                                    break;
                                }

                                default:
                                {
                                    throw new ArgumentException("Unexpected method " + op.TrainOptions().trainingMethod);
                                }
                                }
                            }
                        }
                    }
                }
            }
            return(Pair.MakePair(numCorrect, numWrong));
        }
Code Example #16
 public virtual bool Parse(Tree tree)
 {
     this.originalSentence = tree.YieldHasWord();
     initialState          = ShiftReduceParser.InitialStateFromGoldTagTree(tree);
     return(ParseInternal());
 }
Code Example #17
 public ShiftReduceParserQuery(ShiftReduceParser parser)
 {
     this.parser = parser;
 }
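Putting the pieces together, a hedged end-to-end sketch (not from the original page). It assumes the Parse overloads in Code Examples #1 and #16 belong to ShiftReduceParserQuery, which their use of initialState and ParseInternal suggests; the cast on LoadModel follows Code Example #14, the tagged-sentence construction follows Code Example #10, and the model path is a placeholder.

 // Hedged end-to-end sketch combining Code Examples #1, #4, #10 and #17.
 ShiftReduceParser parser = (ShiftReduceParser)ShiftReduceParser.LoadModel("path/to/model.ser.gz");
 string[] words = new string[] { "This", "is", "a", "short", "test", "." };
 string[] tags  = new string[] { "DT", "VBZ", "DT", "JJ", "NN", "." };
 IList <TaggedWord> sentence = SentenceUtils.ToTaggedList(Arrays.AsList(words), Arrays.AsList(tags));
 ShiftReduceParserQuery query = new ShiftReduceParserQuery(parser);
 bool parsed = query.Parse(sentence);
 // Retrieving the best parse tree from the query is omitted: the accessor for
 // the finished parse is not shown in this excerpt.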