Example #1
        private void Train(IList <Pair <string, IFileFilter> > trainTreebankPath, Pair <string, IFileFilter> devTreebankPath, string serializedPath)
        {
            log.Info("Training method: " + op.TrainOptions().trainingMethod);
            IList <Tree> binarizedTrees = Generics.NewArrayList();

            foreach (Pair <string, IFileFilter> treebank in trainTreebankPath)
            {
                Sharpen.Collections.AddAll(binarizedTrees, ReadBinarizedTreebank(treebank.First(), treebank.Second()));
            }
            int nThreads = op.trainOptions.trainingThreads;

            nThreads = nThreads <= 0 ? Runtime.GetRuntime().AvailableProcessors() : nThreads;
            Edu.Stanford.Nlp.Tagger.Common.Tagger tagger = null;
            if (op.testOptions.preTag)
            {
                Timing retagTimer = new Timing();
                tagger = Edu.Stanford.Nlp.Tagger.Common.Tagger.LoadModel(op.testOptions.taggerSerializedFile);
                RedoTags(binarizedTrees, tagger, nThreads);
                retagTimer.Done("Retagging");
            }
            ICollection <string> knownStates    = FindKnownStates(binarizedTrees);
            ICollection <string> rootStates     = FindRootStates(binarizedTrees);
            ICollection <string> rootOnlyStates = FindRootOnlyStates(binarizedTrees, rootStates);

            log.Info("Known states: " + knownStates);
            log.Info("States which occur at the root: " + rootStates);
            log.Info("States which only occur at the root: " + rootStates);
            Timing transitionTimer = new Timing();
            IList <IList <ITransition> > transitionLists = CreateTransitionSequence.CreateTransitionSequences(binarizedTrees, op.compoundUnaries, rootStates, rootOnlyStates);
            IIndex <ITransition>         transitionIndex = new HashIndex <ITransition>();

            foreach (IList <ITransition> transitions in transitionLists)
            {
                transitionIndex.AddAll(transitions);
            }
            transitionTimer.Done("Converting trees into transition lists");
            log.Info("Number of transitions: " + transitionIndex.Size());
            Random   random      = new Random(op.trainOptions.randomSeed);
            Treebank devTreebank = null;

            if (devTreebankPath != null)
            {
                devTreebank = ReadTreebank(devTreebankPath.First(), devTreebankPath.Second());
            }
            PerceptronModel newModel = new PerceptronModel(this.op, transitionIndex, knownStates, rootStates, rootOnlyStates);

            newModel.TrainModel(serializedPath, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads);
            this.model = newModel;
        }
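
One detail worth isolating from Train() above is the thread-count fallback: a non-positive trainingThreads setting means "use every available processor". A standalone sketch of the same idiom in plain C#, substituting System.Environment.ProcessorCount for the Sharpen Runtime shim:

        // `configured` stands in for op.trainOptions.trainingThreads.
        int configured = 0;
        int nThreads = configured <= 0 ? System.Environment.ProcessorCount : configured;
        System.Console.WriteLine(nThreads);  // e.g. 8 on an eight-core machine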
Example #2
        // TODO: factor out the retagging?
        public static void RedoTags(Tree tree, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger)
        {
            IList <Word>       words  = tree.YieldWords();
            IList <TaggedWord> tagged = tagger.Apply(words);
            IList <ILabel>     tags   = tree.PreTerminalYield();

            if (tags.Count != tagged.Count)
            {
                throw new AssertionError("Tags are not the same size");
            }
            for (int i = 0; i < tags.Count; ++i)
            {
                tags[i].SetValue(tagged[i].Tag());
            }
        }
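
RedoTags mutates the tree in place through SetValue rather than returning a new tree, so the caller's Tree reflects the new tags afterwards. A minimal usage sketch; Tree.ValueOf and the tagger path are illustrative assumptions, not taken from the source:

        // Hypothetical usage: build a tree from a bracketed string, retag it in place.
        Tree tree = Tree.ValueOf("(ROOT (NP (DT the) (NN dog)))");  // assumed helper
        Edu.Stanford.Nlp.Tagger.Common.Tagger tagger = Edu.Stanford.Nlp.Tagger.Common.Tagger.LoadModel("/path/to/tagger.ser.gz");  // placeholder path
        ShiftReduceParser.RedoTags(tree, tagger);
        System.Console.WriteLine(tree);  // pre-terminal labels now carry the tagger's tags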
Example #3
 /// <summary>
 /// Will train the model on the given treebank, using devTreebank as
 /// a dev set.
 /// </summary>
 /// <remarks>
 /// Will train the model on the given treebank, using devTreebank as
 /// a dev set.  If op.retrainAfterCutoff is set, will rerun training
 /// after the first time through on a limited set of features.
 /// </remarks>
 public override void TrainModel(string serializedPath, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, Random random, IList <Tree> binarizedTrees, IList <IList <ITransition> > transitionLists, Treebank devTreebank, int nThreads)
 {
     if (op.TrainOptions().retrainAfterCutoff && op.TrainOptions().featureFrequencyCutoff > 0)
     {
         string tempName = Sharpen.Runtime.Substring(serializedPath, 0, serializedPath.Length - 7) + "-" + "temp.ser.gz";
         TrainModel(tempName, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, null);
         ShiftReduceParser temp = new ShiftReduceParser(op, this);
         temp.SaveModel(tempName);
         ICollection <string> features = featureWeights.Keys;
         featureWeights = Generics.NewHashMap();
         TrainModel(serializedPath, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, features);
     }
     else
     {
         TrainModel(serializedPath, tagger, random, binarizedTrees, transitionLists, devTreebank, nThreads, null);
     }
 }
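
The temporary model name is derived by stripping the trailing seven characters (".ser.gz") from serializedPath and appending "-temp.ser.gz". A worked example of that string arithmetic:

     // "model.ser.gz".Length - 7 == 5, so the kept prefix is "model".
     string serializedPath = "model.ser.gz";
     string tempName = Sharpen.Runtime.Substring(serializedPath, 0, serializedPath.Length - 7) + "-" + "temp.ser.gz";
     System.Console.WriteLine(tempName);  // prints: model-temp.ser.gz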
Example #4
 public static void RedoTags(IList <Tree> trees, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, int nThreads)
 {
     if (nThreads == 1)
     {
         foreach (Tree tree in trees)
         {
             RedoTags(tree, tagger);
         }
     }
     else
     {
         MulticoreWrapper <Tree, Tree> wrapper = new MulticoreWrapper <Tree, Tree>(nThreads, new ShiftReduceParser.RetagProcessor(tagger));
         foreach (Tree tree in trees)
         {
             wrapper.Put(tree);
         }
         wrapper.Join();
     }
 }
Example #5
        private void TrainModel(string serializedPath, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, Random random, IList <Tree> binarizedTrees, IList <IList <ITransition> > transitionLists, Treebank devTreebank, int nThreads, ICollection <string> allowedFeatures)
        {
            double bestScore     = 0.0;
            int    bestIteration = 0;
            PriorityQueue <ScoredObject <PerceptronModel> > bestModels = null;

            if (op.TrainOptions().averagedModels > 0)
            {
                bestModels = new PriorityQueue <ScoredObject <PerceptronModel> >(op.TrainOptions().averagedModels + 1, ScoredComparator.AscendingComparator);
            }
            IList <int> indices = Generics.NewArrayList();

            for (int i = 0; i < binarizedTrees.Count; ++i)
            {
                indices.Add(i);
            }
            Oracle oracle = null;

            if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.Oracle)
            {
                oracle = new Oracle(binarizedTrees, op.compoundUnaries, rootStates);
            }
            IList <PerceptronModel.Update>           updates = Generics.NewArrayList();
            MulticoreWrapper <int, Pair <int, int> > wrapper = null;

            if (nThreads != 1)
            {
                updates = Java.Util.Collections.SynchronizedList(updates);
                wrapper = new MulticoreWrapper <int, Pair <int, int> >(op.trainOptions.trainingThreads, new PerceptronModel.TrainTreeProcessor(this, binarizedTrees, transitionLists, updates, oracle));
            }
            IntCounter <string> featureFrequencies = null;

            if (op.TrainOptions().featureFrequencyCutoff > 1)
            {
                featureFrequencies = new IntCounter <string>();
            }
            for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration)
            {
                Timing trainingTimer = new Timing();
                int    numCorrect    = 0;
                int    numWrong      = 0;
                Java.Util.Collections.Shuffle(indices, random);
                for (int start = 0; start < indices.Count; start += op.trainOptions.batchSize)
                {
                    int end = Math.Min(start + op.trainOptions.batchSize, indices.Count);
                    Triple <IList <PerceptronModel.Update>, int, int> result = TrainBatch(indices.SubList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
                    numCorrect += result.second;
                    numWrong   += result.third;
                    foreach (PerceptronModel.Update update in result.first)
                    {
                        foreach (string feature in update.features)
                        {
                            if (allowedFeatures != null && !allowedFeatures.Contains(feature))
                            {
                                continue;
                            }
                            Weight weights = featureWeights[feature];
                            if (weights == null)
                            {
                                weights = new Weight();
                                featureWeights[feature] = weights;
                            }
                            weights.UpdateWeight(update.goldTransition, update.delta);
                            weights.UpdateWeight(update.predictedTransition, -update.delta);
                            if (featureFrequencies != null)
                            {
                                featureFrequencies.IncrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
                            }
                        }
                    }
                    updates.Clear();
                }
                trainingTimer.Done("Iteration " + iteration);
                log.Info("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
                OutputStats();
                double labelF1 = 0.0;
                if (devTreebank != null)
                {
                    EvaluateTreebank evaluator = new EvaluateTreebank(op, null, new ShiftReduceParser(op, this), tagger);
                    evaluator.TestOnTreebank(devTreebank);
                    labelF1 = evaluator.GetLBScore();
                    log.Info("Label F1 after " + iteration + " iterations: " + labelF1);
                    if (labelF1 > bestScore)
                    {
                        log.Info("New best dev score (previous best " + bestScore + ")");
                        bestScore     = labelF1;
                        bestIteration = iteration;
                    }
                    else
                    {
                        log.Info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                        if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit))
                        {
                            log.Info("Failed to improve for too long, stopping training");
                            break;
                        }
                    }
                    log.Info();
                    if (bestModels != null)
                    {
                        bestModels.Add(new ScoredObject <PerceptronModel>(new PerceptronModel(this), labelF1));
                        if (bestModels.Count > op.TrainOptions().averagedModels)
                        {
                            bestModels.Poll();
                        }
                    }
                }
                if (op.TrainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0)
                {
                    string            tempName = Sharpen.Runtime.Substring(serializedPath, 0, serializedPath.Length - 7) + "-" + Filename.Format(iteration) + "-" + Nf.Format(labelF1) + ".ser.gz";
                    ShiftReduceParser temp     = new ShiftReduceParser(op, this);
                    temp.SaveModel(tempName);
                }
                // TODO: we could save a cutoff version of the model,
                // especially if we also get a dev set number for it, but that
                // might be overkill
                if (iteration % 10 == 0 && op.TrainOptions().decayLearningRate > 0.0)
                {
                    learningRate *= op.TrainOptions().decayLearningRate;
                }
            }
            // end for iterations
            if (wrapper != null)
            {
                wrapper.Join();
            }
            if (bestModels != null)
            {
                if (op.TrainOptions().cvAveragedModels && devTreebank != null)
                {
                    IList <ScoredObject <PerceptronModel> > models = Generics.NewArrayList();
                    while (bestModels.Count > 0)
                    {
                        models.Add(bestModels.Poll());
                    }
                    Java.Util.Collections.Reverse(models);
                    double bestF1   = 0.0;
                    int    bestSize = 0;
                    for (int i_1 = 1; i_1 <= models.Count; ++i_1)
                    {
                        log.Info("Testing with " + i_1 + " models averaged together");
                        // TODO: this is kind of ugly, would prefer a separate object
                        AverageScoredModels(models.SubList(0, i_1));
                        ShiftReduceParser temp      = new ShiftReduceParser(op, this);
                        EvaluateTreebank  evaluator = new EvaluateTreebank(temp.GetOp(), null, temp, tagger);
                        evaluator.TestOnTreebank(devTreebank);
                        double labelF1 = evaluator.GetLBScore();
                        log.Info("Label F1 for " + i_1 + " models: " + labelF1);
                        if (labelF1 > bestF1)
                        {
                            bestF1   = labelF1;
                            bestSize = i_1;
                        }
                    }
                    AverageScoredModels(models.SubList(0, bestSize));
                }
                else
                {
                    AverageScoredModels(bestModels);
                }
            }
            // TODO: perhaps we should filter the features and then get dev
            // set scores.  That way we can merge the models which are best
            // after filtering.
            if (featureFrequencies != null)
            {
                FilterFeatures(featureFrequencies.KeysAbove(op.TrainOptions().featureFrequencyCutoff));
            }
            CondenseFeatures();
        }
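
The inner loop above is a standard structured-perceptron step applied per feature: the gold transition's weight is raised by update.delta and the predicted transition's weight lowered by the same amount (the >= 0 guard on the frequency count suggests a negative index means "no transition on that side"). Here is the step in isolation with hypothetical values; new Weight() and UpdateWeight(int, float) follow their use in the code above, and delta's type is assumed to be float:

            Weight weights = new Weight();
            int goldTransition = 3;       // index of the correct transition
            int predictedTransition = 7;  // index of the model's wrong guess
            float delta = 1.0f;
            weights.UpdateWeight(goldTransition, delta);        // reward the gold transition
            weights.UpdateWeight(predictedTransition, -delta);  // penalize the prediction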
Example #6
 public RetagProcessor(Edu.Stanford.Nlp.Tagger.Common.Tagger tagger)
 {
     this.tagger = tagger;
 }
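
Only the constructor is shown here. For the MulticoreWrapper in Example #4 to drive it, RetagProcessor must also implement a Process method that performs the retagging. A plausible sketch, not the verified source, assuming the usual IThreadsafeProcessor<Tree, Tree> contract:

     public Tree Process(Tree tree)
     {
         // Retagging mutates the tree's labels in place (see Example #2),
         // so returning the same Tree instance is sufficient.
         ShiftReduceParser.RedoTags(tree, tagger);
         return tree;
     }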
Example #7
 /// <summary>Train a new model.</summary>
 /// <remarks>
 /// Train a new model.  This is the method to override for new models
 /// such that the ShiftReduceParser will fill in the model.  Given a
 /// collection of training trees and some other various information,
 /// this should train a new model.  The model is expected to already
 /// know about the possible transitions and which states are eligible
 /// to be root states via the BaseModel constructor.
 /// </remarks>
 /// <param name="serializedPath">Where serialized models go.  If the appropriate options are set, the method can use this to save intermediate models.</param>
 /// <param name="tagger">The tagger to use when evaluating devTreebank.  TODO: it would make more sense for ShiftReduceParser to retag the trees first</param>
 /// <param name="random">A random number generator to use for any random numbers.  Useful to make sure results can be reproduced.</param>
 /// <param name="binarizedTrainTrees">The treebank to train from.</param>
 /// <param name="transitionLists">binarizedTrainTrees converted into lists of transitions that will reproduce the same tree.</param>
 /// <param name="devTreebank">a set of trees which can be used for dev testing (assuming the user provided a dev treebank)</param>
 /// <param name="nThreads">how many threads the model can use for training</param>
 public abstract void TrainModel(string serializedPath, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, Random random, IList <Tree> binarizedTrainTrees, IList <IList <ITransition> > transitionLists, Treebank devTreebank, int nThreads);
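
Example #3 (PerceptronModel.TrainModel) is the concrete override of this contract shown above. For orientation, a bare-bones skeleton of what a new model would have to supply; the BaseModel name and its copy constructor are illustrative assumptions:

 public class MyModel : BaseModel
 {
     // Hypothetical: assumes BaseModel exposes a copy constructor that
     // carries over the transition index and the root/root-only state sets.
     public MyModel(BaseModel other) : base(other) { }

     public override void TrainModel(string serializedPath, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, Random random, IList <Tree> binarizedTrainTrees, IList <IList <ITransition> > transitionLists, Treebank devTreebank, int nThreads)
     {
         // A real model estimates its parameters from binarizedTrainTrees and
         // transitionLists here; per the remarks above, the possible transitions
         // and eligible root states are already known via the BaseModel constructor.
     }
 }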