/// <summary>
/// Runs the end-to-end oracle check over the test treebank using an
/// Oracle built with compound unaries switched off and "ROOT" as the
/// only root state.
/// </summary>
public virtual void TestEndToEndSingleUnaries()
{
    // Second Oracle argument: false means unaries are not compounded.
    const bool compoundUnaries = false;

    IList<Tree> treebank = BuildTestTreebank();
    Oracle singleUnaryOracle = new Oracle(treebank, compoundUnaries, Java.Util.Collections.Singleton("ROOT"));

    RunEndToEndTest(treebank, singleUnaryOracle);
}
/// <summary>
/// Builds the parent map for a small fixed tree and checks that the
/// number of entries in the map matches the count reported by
/// RecursiveTestBuildParentMap for the same tree.
/// </summary>
public virtual void TestBuildParentMap()
{
    Tree root = Tree.ValueOf("(A (B foo) (C bar))");
    IDictionary<Tree, Tree> parentMap = Oracle.BuildParentMap(root);

    // The recursive helper walks the tree against the map and returns a count.
    int expectedEntries = RecursiveTestBuildParentMap(root, parentMap);

    NUnit.Framework.Assert.AreEqual(expectedEntries, parentMap.Count);
}
/// <summary>
/// Creates a processor bound to the enclosing model and to the shared
/// training data it will read from and write to.
/// </summary>
/// <param name="_enclosing">The PerceptronModel that owns this processor.</param>
/// <param name="binarizedTrees">The training trees.</param>
/// <param name="transitionLists">The transition sequence stored for each tree.</param>
/// <param name="updates">The list that collects the generated updates.</param>
/// <param name="oracle">The oracle handed to the processor; may be null.</param>
public TrainTreeProcessor(PerceptronModel _enclosing, IList<Tree> binarizedTrees, IList<IList<ITransition>> transitionLists, IList<PerceptronModel.Update> updates, Oracle oracle)
{
    this.oracle = oracle;

    // this needs to be a synchronized list
    this.updates = updates;

    this.transitionLists = transitionLists;
    this.binarizedTrees = binarizedTrees;
    this._enclosing = _enclosing;
}
/// <summary>
/// For every tree, starts from the initial gold-tag state and repeatedly
/// applies the oracle's gold transition until the state is finished, then
/// asserts that the top of the stack equals the original tree.
/// </summary>
/// <param name="binarizedTrees">The trees to replay.</param>
/// <param name="oracle">The oracle queried for the gold transition at each step.</param>
public static void RunEndToEndTest(IList<Tree> binarizedTrees, Oracle oracle)
{
    for (int treeIndex = 0; treeIndex < binarizedTrees.Count; treeIndex++)
    {
        Tree expectedTree = binarizedTrees[treeIndex];
        State current = ShiftReduceParser.InitialStateFromGoldTagTree(expectedTree);

        while (!current.IsFinished())
        {
            OracleTransition step = oracle.GoldTransition(treeIndex, current);

            // The oracle must always name a concrete transition here.
            bool hasTransition = step.transition != null;
            NUnit.Framework.Assert.IsTrue(hasTransition);

            current = step.transition.Apply(current);
        }

        NUnit.Framework.Assert.AreEqual(expectedTree, current.stack.Peek());
    }
}
/// <summary>
/// Runs the perceptron training loop: for each iteration the tree order
/// is shuffled, trees are trained in batches, the resulting updates are
/// applied to <c>featureWeights</c>, and (when a dev treebank is given)
/// the model is scored so that the best models can be tracked, averaged
/// and training can stop early when it stalls.
/// </summary>
/// <param name="serializedPath">Base path for intermediate models; may be null, in which case none are saved.</param>
/// <param name="tagger">Tagger passed to the dev-set evaluator.</param>
/// <param name="random">Source of randomness for shuffling the tree order.</param>
/// <param name="binarizedTrees">The training trees.</param>
/// <param name="transitionLists">The transition sequence stored for each tree.</param>
/// <param name="devTreebank">Optional dev set; when null no scoring, early stopping or model tracking happens.</param>
/// <param name="nThreads">When not 1, a MulticoreWrapper is created and the update list is synchronized.</param>
/// <param name="allowedFeatures">When non-null, updates for features outside this collection are ignored.</param>
private void TrainModel(string serializedPath, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, Random random, IList<Tree> binarizedTrees, IList<IList<ITransition>> transitionLists, Treebank devTreebank, int nThreads, ICollection<string> allowedFeatures)
{
    double bestScore = 0.0;
    int bestIteration = 0;

    // Queue of the best models seen so far; one slot larger than the limit
    // so a new model can be added before the lowest-scored one is dropped.
    PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
    if (op.TrainOptions().averagedModels > 0)
    {
        bestModels = new PriorityQueue<ScoredObject<PerceptronModel>>(op.TrainOptions().averagedModels + 1, ScoredComparator.AscendingComparator);
    }

    IList<int> indices = Generics.NewArrayList();
    for (int i = 0; i < binarizedTrees.Count; ++i)
    {
        indices.Add(i);
    }

    Oracle oracle = null;
    if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.Oracle)
    {
        oracle = new Oracle(binarizedTrees, op.compoundUnaries, rootStates);
    }

    IList<PerceptronModel.Update> updates = Generics.NewArrayList();
    MulticoreWrapper<int, Pair<int, int>> wrapper = null;
    if (nThreads != 1)
    {
        // Worker threads append to the same list, so it has to be synchronized.
        updates = Java.Util.Collections.SynchronizedList(updates);
        wrapper = new MulticoreWrapper<int, Pair<int, int>>(op.trainOptions.trainingThreads, new PerceptronModel.TrainTreeProcessor(this, binarizedTrees, transitionLists, updates, oracle));
    }

    IntCounter<string> featureFrequencies = null;
    if (op.TrainOptions().featureFrequencyCutoff > 1)
    {
        featureFrequencies = new IntCounter<string>();
    }

    for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration)
    {
        Timing trainingTimer = new Timing();
        int numCorrect = 0;
        int numWrong = 0;
        Java.Util.Collections.Shuffle(indices, random);

        for (int start = 0; start < indices.Count; start += op.trainOptions.batchSize)
        {
            int end = Math.Min(start + op.trainOptions.batchSize, indices.Count);
            Triple<IList<PerceptronModel.Update>, int, int> result = TrainBatch(indices.SubList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
            numCorrect += result.second;
            numWrong += result.third;

            foreach (PerceptronModel.Update update in result.first)
            {
                foreach (string feature in update.features)
                {
                    if (allowedFeatures != null && !allowedFeatures.Contains(feature))
                    {
                        continue;
                    }

                    // TryGetValue instead of the indexer: a standard .NET dictionary
                    // throws KeyNotFoundException for a feature not seen before,
                    // whereas the code needs "absent" to mean "create a new Weight".
                    Weight weights;
                    if (!featureWeights.TryGetValue(feature, out weights) || weights == null)
                    {
                        weights = new Weight();
                        featureWeights[feature] = weights;
                    }

                    weights.UpdateWeight(update.goldTransition, update.delta);
                    weights.UpdateWeight(update.predictedTransition, -update.delta);

                    if (featureFrequencies != null)
                    {
                        // An update that names both a gold and a predicted transition counts twice.
                        featureFrequencies.IncrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
                    }
                }
            }

            // The same list is reused for the next batch.
            updates.Clear();
        }

        trainingTimer.Done("Iteration " + iteration);
        log.Info("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
        OutputStats();

        double labelF1 = 0.0;
        if (devTreebank != null)
        {
            EvaluateTreebank evaluator = new EvaluateTreebank(op, null, new ShiftReduceParser(op, this), tagger);
            evaluator.TestOnTreebank(devTreebank);
            labelF1 = evaluator.GetLBScore();
            log.Info("Label F1 after " + iteration + " iterations: " + labelF1);

            if (labelF1 > bestScore)
            {
                log.Info("New best dev score (previous best " + bestScore + ")");
                bestScore = labelF1;
                bestIteration = iteration;
            }
            else
            {
                log.Info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit))
                {
                    log.Info("Failed to improve for too long, stopping training");
                    break;
                }
            }
            log.Info();

            if (bestModels != null)
            {
                bestModels.Add(new ScoredObject<PerceptronModel>(new PerceptronModel(this), labelF1));
                if (bestModels.Count > op.TrainOptions().averagedModels)
                {
                    bestModels.Poll();
                }
            }
        }

        if (op.TrainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0)
        {
            // Drops the last 7 characters of the path before appending the
            // iteration and score; presumably that is a ".ser.gz" suffix — TODO confirm.
            string tempName = Sharpen.Runtime.Substring(serializedPath, 0, serializedPath.Length - 7) + "-" + Filename.Format(iteration) + "-" + Nf.Format(labelF1) + ".ser.gz";
            ShiftReduceParser temp = new ShiftReduceParser(op, this);
            temp.SaveModel(tempName);
        }

        // TODO: we could save a cutoff version of the model,
        // especially if we also get a dev set number for it, but that
        // might be overkill
        if (iteration % 10 == 0 && op.TrainOptions().decayLearningRate > 0.0)
        {
            learningRate *= op.TrainOptions().decayLearningRate;
        }
    }
    // end for iterations

    if (wrapper != null)
    {
        wrapper.Join();
    }

    if (bestModels != null)
    {
        if (op.TrainOptions().cvAveragedModels && devTreebank != null)
        {
            // Drain the queue and reverse it so the list runs from best to worst.
            IList<ScoredObject<PerceptronModel>> models = Generics.NewArrayList();
            while (bestModels.Count > 0)
            {
                models.Add(bestModels.Poll());
            }
            Java.Util.Collections.Reverse(models);

            double bestF1 = 0.0;
            int bestSize = 0;
            for (int modelCount = 1; modelCount <= models.Count; ++modelCount)
            {
                log.Info("Testing with " + modelCount + " models averaged together");
                // TODO: this is kind of ugly, would prefer a separate object
                AverageScoredModels(models.SubList(0, modelCount));
                ShiftReduceParser temp = new ShiftReduceParser(op, this);
                EvaluateTreebank evaluator = new EvaluateTreebank(temp.GetOp(), null, temp, tagger);
                evaluator.TestOnTreebank(devTreebank);
                double labelF1 = evaluator.GetLBScore();
                log.Info("Label F1 for " + modelCount + " models: " + labelF1);
                if (labelF1 > bestF1)
                {
                    bestF1 = labelF1;
                    bestSize = modelCount;
                }
            }

            // NOTE(review): bestSize stays 0 when no prefix scores above 0.0, so an
            // empty list is averaged here — confirm AverageScoredModels accepts that.
            AverageScoredModels(models.SubList(0, bestSize));
        }
        else
        {
            AverageScoredModels(bestModels);
        }
    }

    // TODO: perhaps we should filter the features and then get dev
    // set scores. That way we can merge the models which are best
    // after filtering.
    if (featureFrequencies != null)
    {
        FilterFeatures(featureFrequencies.KeysAbove(op.TrainOptions().featureFrequencyCutoff));
    }

    CondenseFeatures();
}
/// <summary>
/// Trains a batch of trees and returns the following: a list of
/// Update objects, the number of transitions correct, and the number
/// of transitions wrong.
/// </summary>
/// <remarks>
/// If the model is trained with multiple threads, it is expected
/// that a valid MulticoreWrapper is passed in which does the
/// processing. In that case, the processing is done on all of the
/// trees without updating any weights, which allows the results for
/// multithreaded training to be reproduced.
/// If no wrapper is passed in, the trees are trained on the calling
/// thread regardless of the configured thread count, instead of
/// dereferencing a null wrapper.
/// </remarks>
/// <param name="indices">Indices into <paramref name="binarizedTrees"/> of the trees in this batch.</param>
/// <param name="binarizedTrees">The training trees.</param>
/// <param name="transitionLists">The transition sequence stored for each tree.</param>
/// <param name="updates">The list updates are collected in; the same instance is returned as the first element of the result.</param>
/// <param name="oracle">The oracle forwarded to TrainTree; may be null.</param>
/// <param name="wrapper">The multithreaded processor, or null for single-threaded training.</param>
/// <returns>The updates list, the number of correct transitions, and the number of wrong transitions.</returns>
private Triple<IList<PerceptronModel.Update>, int, int> TrainBatch(IList<int> indices, IList<Tree> binarizedTrees, IList<IList<ITransition>> transitionLists, IList<PerceptronModel.Update> updates, Oracle oracle, MulticoreWrapper<int, Pair<int, int>> wrapper)
{
    int numCorrect = 0;
    int numWrong = 0;

    // The caller decides whether to build a wrapper from its own thread-count
    // argument, which can disagree with op.trainOptions.trainingThreads; a
    // missing wrapper therefore also selects the in-thread path.
    bool trainInThread = wrapper == null || op.trainOptions.trainingThreads == 1;

    if (trainInThread)
    {
        foreach (int index in indices)
        {
            Pair<int, int> count = TrainTree(index, binarizedTrees, transitionLists, updates, oracle);
            numCorrect += count.first;
            numWrong += count.second;
        }
    }
    else
    {
        foreach (int index in indices)
        {
            wrapper.Put(index);
        }

        // Wait for the whole batch before collecting the per-tree counts.
        wrapper.Join(false);
        while (wrapper.Peek())
        {
            Pair<int, int> result = wrapper.Poll();
            numCorrect += result.first;
            numWrong += result.second;
        }
    }

    return new Triple<IList<PerceptronModel.Update>, int, int>(updates, numCorrect, numWrong);
}
/// <summary>
/// Trains on a single tree using the configured training method and
/// appends the resulting Update objects to <paramref name="updates"/>.
/// No weights are changed here; the caller applies the updates.
/// </summary>
/// <param name="index">Index of the tree in <paramref name="binarizedTrees"/> and <paramref name="transitionLists"/>.</param>
/// <param name="binarizedTrees">The training trees.</param>
/// <param name="transitionLists">The transition sequence stored for each tree; copied before being consumed.</param>
/// <param name="updates">Receives the updates generated for this tree.</param>
/// <param name="oracle">Queried for gold transitions when the training method is Oracle.</param>
/// <returns>A pair of (transitions correct, transitions wrong).</returns>
/// <exception cref="ArgumentException">
/// The beam size is not positive for a beam method, or the training
/// method is not one the gold-sequence branch knows how to handle.
/// </exception>
private Pair<int, int> TrainTree(int index, IList<Tree> binarizedTrees, IList<IList<ITransition>> transitionLists, IList<PerceptronModel.Update> updates, Oracle oracle)
{
    int numCorrect = 0;
    int numWrong = 0;
    Tree tree = binarizedTrees[index];
    ShiftReduceTrainOptions.TrainingMethod method = op.TrainOptions().trainingMethod;

    ReorderingOracle reorderer = null;
    if (method == ShiftReduceTrainOptions.TrainingMethod.ReorderOracle || method == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam)
    {
        reorderer = new ReorderingOracle(op);
    }

    // TODO.  This training method seems to be working in that it
    // trains models just like the gold and early termination methods do.
    // However, it causes the feature space to go crazy.  Presumably
    // leaving out features with low weights or low frequencies would
    // significantly help with that.  Otherwise, not sure how to keep
    // it under control.
    if (method == ShiftReduceTrainOptions.TrainingMethod.Oracle)
    {
        State state = ShiftReduceParser.InitialStateFromGoldTagTree(tree);
        while (!state.IsFinished())
        {
            IList<string> features = featureFactory.Featurize(state);
            ScoredObject<int> prediction = FindHighestScoringTransition(state, features, true);
            if (prediction == null)
            {
                throw new AssertionError("Did not find a legal transition");
            }
            int predictedNum = prediction.Object();
            ITransition predicted = transitionIndex.Get(predictedNum);
            OracleTransition gold = oracle.GoldTransition(index, state);

            if (gold.IsCorrect(predicted))
            {
                numCorrect++;
                if (gold.transition != null && !gold.transition.Equals(predicted))
                {
                    int transitionNum = transitionIndex.IndexOf(gold.transition);
                    // TODO: do we want to add unary transitions which are
                    // only possible when the parser has gone off the rails?
                    // An unknown gold transition only skips the update. The state
                    // must still advance below: nothing in this loop changes the
                    // weights, so retrying the same state would never terminate.
                    if (transitionNum >= 0)
                    {
                        updates.Add(new PerceptronModel.Update(features, transitionNum, -1, learningRate));
                    }
                }
            }
            else
            {
                numWrong++;
                int transitionNum = -1;
                if (gold.transition != null)
                {
                    transitionNum = transitionIndex.IndexOf(gold.transition);
                }
                // TODO: this can theoretically result in a -1 gold
                // transition if the transition exists, but is a
                // CompoundUnaryTransition which only exists because the
                // parser is wrong.  Do we want to add those transitions?
                updates.Add(new PerceptronModel.Update(features, transitionNum, predictedNum, learningRate));
            }

            state = predicted.Apply(state);
        }
    }
    else if (method == ShiftReduceTrainOptions.TrainingMethod.Beam || method == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam)
    {
        if (op.TrainOptions().beamSize <= 0)
        {
            throw new ArgumentException("Illegal beam size " + op.TrainOptions().beamSize);
        }

        // Work on a copy: the gold sequence is consumed (and possibly reordered) below.
        IList<ITransition> transitions = Generics.NewLinkedList(transitionLists[index]);
        PriorityQueue<State> agenda = new PriorityQueue<State>(op.TrainOptions().beamSize + 1, ScoredComparator.AscendingComparator);
        State goldState = ShiftReduceParser.InitialStateFromGoldTagTree(tree);
        agenda.Add(goldState);

        while (transitions.Count > 0)
        {
            ITransition goldTransition = transitions[0];
            ITransition highestScoringTransitionFromGoldState = null;
            double highestScoreFromGoldState = 0.0;
            PriorityQueue<State> newAgenda = new PriorityQueue<State>(op.TrainOptions().beamSize + 1, ScoredComparator.AscendingComparator);
            State highestScoringState = null;
            State highestCurrentState = null;

            // Expand every state on the agenda, keeping at most beamSize successors.
            foreach (State currentState in agenda)
            {
                bool isGoldState = (method == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam && goldState.AreTransitionsEqual(currentState));
                IList<string> features = featureFactory.Featurize(currentState);
                ICollection<ScoredObject<int>> stateTransitions = FindHighestScoringTransitions(currentState, features, true, op.TrainOptions().beamSize, null);
                foreach (ScoredObject<int> transition in stateTransitions)
                {
                    State newState = transitionIndex.Get(transition.Object()).Apply(currentState, transition.Score());
                    newAgenda.Add(newState);
                    if (newAgenda.Count > op.TrainOptions().beamSize)
                    {
                        newAgenda.Poll();
                    }
                    if (highestScoringState == null || highestScoringState.Score() < newState.Score())
                    {
                        highestScoringState = newState;
                        highestCurrentState = currentState;
                    }
                    if (isGoldState && (highestScoringTransitionFromGoldState == null || transition.Score() > highestScoreFromGoldState))
                    {
                        highestScoringTransitionFromGoldState = transitionIndex.Get(transition.Object());
                        highestScoreFromGoldState = transition.Score();
                    }
                }
            }

            // This can happen if the REORDER_BEAM method backs itself
            // into a corner, such as transitioning to something that
            // can't have a FinalizeTransition applied.  This doesn't
            // happen for the BEAM method because in that case the correct
            // state (eg one with ROOT) isn't on the agenda so it stops.
            if (method == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam && highestScoringTransitionFromGoldState == null)
            {
                break;
            }

            State newGoldState = goldTransition.Apply(goldState, 0.0);

            // if highest scoring state used the correct transition, no training
            // otherwise, down the last transition, up the correct
            if (!newGoldState.AreTransitionsEqual(highestScoringState))
            {
                ++numWrong;
                IList<string> goldFeatures = featureFactory.Featurize(goldState);
                int lastTransition = transitionIndex.IndexOf(highestScoringState.transitions.Peek());
                updates.Add(new PerceptronModel.Update(featureFactory.Featurize(highestCurrentState), -1, lastTransition, learningRate));
                updates.Add(new PerceptronModel.Update(goldFeatures, transitionIndex.IndexOf(goldTransition), -1, learningRate));

                if (method == ShiftReduceTrainOptions.TrainingMethod.Beam)
                {
                    // If the correct state has fallen off the agenda, break
                    if (!ShiftReduceUtils.FindStateOnAgenda(newAgenda, newGoldState))
                    {
                        break;
                    }
                    transitions.RemoveAt(0);
                }
                else if (method == ShiftReduceTrainOptions.TrainingMethod.ReorderBeam)
                {
                    if (!ShiftReduceUtils.FindStateOnAgenda(newAgenda, newGoldState))
                    {
                        // Try to continue from the best transition out of the gold
                        // state by rewriting the remaining gold sequence.
                        if (!reorderer.Reorder(goldState, highestScoringTransitionFromGoldState, transitions))
                        {
                            break;
                        }
                        newGoldState = highestScoringTransitionFromGoldState.Apply(goldState);
                        if (!ShiftReduceUtils.FindStateOnAgenda(newAgenda, newGoldState))
                        {
                            break;
                        }
                    }
                    else
                    {
                        transitions.RemoveAt(0);
                    }
                }
            }
            else
            {
                ++numCorrect;
                transitions.RemoveAt(0);
            }

            goldState = newGoldState;
            agenda = newAgenda;
        }
    }
    else if (method == ShiftReduceTrainOptions.TrainingMethod.ReorderOracle || method == ShiftReduceTrainOptions.TrainingMethod.EarlyTermination || method == ShiftReduceTrainOptions.TrainingMethod.Gold)
    {
        State state = ShiftReduceParser.InitialStateFromGoldTagTree(tree);
        // Work on a copy: the gold sequence is consumed (and possibly reordered) below.
        IList<ITransition> transitions = Generics.NewLinkedList(transitionLists[index]);
        bool keepGoing = true;

        while (transitions.Count > 0 && keepGoing)
        {
            ITransition transition = transitions[0];
            int transitionNum = transitionIndex.IndexOf(transition);
            IList<string> features = featureFactory.Featurize(state);
            int predictedNum = FindHighestScoringTransition(state, features, false).Object();
            ITransition predicted = transitionIndex.Get(predictedNum);

            if (transitionNum == predictedNum)
            {
                transitions.RemoveAt(0);
                state = transition.Apply(state);
                numCorrect++;
            }
            else
            {
                numWrong++;
                // TODO: allow weighted features, weighted training, etc
                updates.Add(new PerceptronModel.Update(features, transitionNum, predictedNum, learningRate));
                switch (method)
                {
                    case ShiftReduceTrainOptions.TrainingMethod.EarlyTermination:
                    {
                        // Stop at the first mistake.
                        keepGoing = false;
                        break;
                    }

                    case ShiftReduceTrainOptions.TrainingMethod.Gold:
                    {
                        // Carry on along the gold sequence despite the mistake.
                        transitions.RemoveAt(0);
                        state = transition.Apply(state);
                        break;
                    }

                    case ShiftReduceTrainOptions.TrainingMethod.ReorderOracle:
                    {
                        // Follow the prediction only if the remaining gold
                        // sequence could be reordered to accommodate it.
                        keepGoing = reorderer.Reorder(state, predicted, transitions);
                        if (keepGoing)
                        {
                            state = predicted.Apply(state);
                        }
                        break;
                    }

                    default:
                    {
                        throw new ArgumentException("Unexpected method " + method);
                    }
                }
            }
        }
    }

    return Pair.MakePair(numCorrect, numWrong);
}