/// <summary>
/// Command-line entry point for the DVParser: depending on the flags it
/// trains a new model, continues training an existing one, runs the
/// gradient check and/or evaluates on a test treebank.
/// <br />
/// An example command line for training a new parser:
/// <br />
/// nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories &gt; /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2&gt;&amp;1 &amp;
/// </summary>
/// <param name="args">
/// Command-line arguments. Flags not recognized here are collected and
/// handed to <c>LexicalizedParser.LoadModel</c> as extra options.
/// </param>
/// <exception cref="System.ArgumentException">
/// If neither -parser nor -model is given, if there is nothing to do
/// (no training, no gradient check and no model to load), or if an
/// option that needs a value is the last argument.
/// </exception>
/// <exception cref="System.IO.IOException"/>
/// <exception cref="System.TypeLoadException"/>
public static void Main(string[] args)
{
    if (args.Length == 0)
    {
        Help();
        System.Environment.Exit(2);
    }

    log.Info("Running DVParser with arguments:");
    foreach (string arg in args)
    {
        log.Info(" " + arg);
    }
    log.Info();

    string parserPath = null;
    string trainTreebankPath = null;
    IFileFilter trainTreebankFilter = null;
    string cachedTrainTreesPath = null;
    bool runGradientCheck = false;
    bool runTraining = false;
    string testTreebankPath = null;
    IFileFilter testTreebankFilter = null;
    string initialModelPath = null;
    string modelPath = null;
    bool filter = true;
    string resultsRecordPath = null;
    IList<string> unusedArgs = new List<string>();

    // These parameters can be null or 0 if the model was not
    // serialized with the new parameters.  Setting the options at the
    // command line will override these defaults.
    // TODO: if/when we integrate back into the main branch and
    // rebuild models, we can get rid of this
    IList<string> argsWithDefaults = new List<string>(Arrays.AsList(new string[]
    {
        "-wordVectorFile", Options.LexOptions.DefaultWordVectorFile,
        "-dvKBest", int.ToString(TrainOptions.DefaultKBest),
        "-batchSize", int.ToString(TrainOptions.DefaultBatchSize),
        "-trainingIterations", int.ToString(TrainOptions.DefaultTrainingIterations),
        "-qnIterationsPerBatch", int.ToString(TrainOptions.DefaultQnIterationsPerBatch),
        "-regCost", double.ToString(TrainOptions.DefaultRegcost),
        "-learningRate", double.ToString(TrainOptions.DefaultLearningRate),
        "-deltaMargin", double.ToString(TrainOptions.DefaultDeltaMargin),
        "-unknownNumberVector",
        "-unknownDashedWordVectors",
        "-unknownCapsVector",
        "-unknownchinesepercentvector",
        "-unknownchinesenumbervector",
        "-unknownchineseyearvector",
        "-unkWord", "*UNK*",
        "-transformMatrixType", "DIAGONAL",
        "-scalingForInit", double.ToString(TrainOptions.DefaultScalingForInit),
        "-trainWordVectors"
    }));
    // The user's arguments come after the defaults so that they win.
    Sharpen.Collections.AddAll(argsWithDefaults, Arrays.AsList(args));
    args = Sharpen.Collections.ToArray(argsWithDefaults, new string[argsWithDefaults.Count]);

    for (int argIndex = 0; argIndex < args.Length;)
    {
        string flag = args[argIndex];
        if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-parser"))
        {
            parserPath = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-testTreebank"))
        {
            Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.First();
            testTreebankFilter = treebankDescription.Second();
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-treebank"))
        {
            Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-treebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
            trainTreebankPath = treebankDescription.First();
            trainTreebankFilter = treebankDescription.Second();
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-cachedTrees"))
        {
            cachedTrainTreesPath = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-runGradientCheck"))
        {
            runGradientCheck = true;
            argIndex++;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-train"))
        {
            runTraining = true;
            argIndex++;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-model"))
        {
            modelPath = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-nofilter"))
        {
            filter = false;
            argIndex++;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-continueTraining"))
        {
            // Continuing from an existing model implies training and
            // turns off the rule filtering step.
            runTraining = true;
            filter = false;
            initialModelPath = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-resultsRecord"))
        {
            resultsRecordPath = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else
        {
            unusedArgs.Add(args[argIndex++]);
        }
    }

    if (parserPath == null && modelPath == null)
    {
        throw new ArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
    }
    if (!runTraining && modelPath == null && !runGradientCheck)
    {
        throw new ArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
    }

    string[] newArgs = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]);
    Edu.Stanford.Nlp.Parser.Dvparser.DVParser dvparser = null;
    LexicalizedParser lexparser = null;
    if (initialModelPath != null)
    {
        // Continue training: start from the DVModel embedded in a saved parser.
        lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(initialModelPath, newArgs));
        DVModel model = GetModelFromLexicalizedParser(lexparser);
        dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(model, lexparser);
    }
    else if (runTraining || runGradientCheck)
    {
        // Fresh model built on top of the base parser.
        lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(parserPath, newArgs));
        dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(lexparser);
    }
    else if (modelPath != null)
    {
        // Evaluation only: load a previously saved DVParser.
        lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath, newArgs));
        DVModel model = GetModelFromLexicalizedParser(lexparser);
        dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(model, lexparser);
    }

    IList<Tree> trainSentences = new List<Tree>();
    IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.NewIdentityHashMap();

    if (cachedTrainTreesPath != null)
    {
        // -cachedTrees takes a comma separated list of cache files.
        foreach (string path in cachedTrainTreesPath.Split(","))
        {
            IList<Pair<Tree, byte[]>> cache = IOUtils.ReadObjectFromFile(path);
            foreach (Pair<Tree, byte[]> pair in cache)
            {
                trainSentences.Add(pair.First());
                trainCompressedParses[pair.First()] = pair.Second();
            }
            log.Info("Read in " + cache.Count + " trees from " + path);
        }
    }

    if (trainTreebankPath != null)
    {
        // TODO: make the transformer a member of the model?
        ITreeTransformer transformer = BuildTrainTransformer(dvparser.GetOp());
        Treebank treebank = dvparser.GetOp().tlpParams.MemoryTreebank();
        treebank.LoadPath(trainTreebankPath, trainTreebankFilter);
        treebank = treebank.Transform(transformer);
        log.Info("Read in " + treebank.Count + " trees from " + trainTreebankPath);

        CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
        CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
        foreach (Tree tree in treebank)
        {
            trainSentences.Add(tree);
            trainCompressedParses[tree] = processor.Process(tree).second;
        }
        log.Info("Finished parsing " + treebank.Count + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
    }

    if ((runTraining || runGradientCheck) && filter)
    {
        log.Info("Filtering rules for the given training set");
        dvparser.dvModel.SetRulesForTrainingSet(trainSentences, trainCompressedParses);
        log.Info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.Count + " word vectors");
    }

    Treebank testTreebank = null;
    if (testTreebankPath != null)
    {
        log.Info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null)
        {
            log.Info("Filtering on " + testTreebankFilter);
        }
        testTreebank = dvparser.GetOp().tlpParams.MemoryTreebank();
        testTreebank.LoadPath(testTreebankPath, testTreebankFilter);
        log.Info("Read in " + testTreebank.Count + " trees for testing");
    }

    if (runGradientCheck)
    {
        log.Info("Running gradient check on " + trainSentences.Count + " trees");
        dvparser.RunGradientCheck(trainSentences, trainCompressedParses);
    }

    if (runTraining)
    {
        log.Info("Training the RNN parser");
        log.Info("Current train options: " + dvparser.GetOp().trainOptions);
        dvparser.Train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
        if (modelPath != null)
        {
            dvparser.SaveModel(modelPath);
        }
    }

    if (testTreebankPath != null)
    {
        EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.AttachModelToLexicalizedParser());
        evaluator.TestOnTreebank(testTreebank);
    }

    log.Info("Successfully ran DVParser");
}

/// <summary>
/// Returns the value that follows the option at <paramref name="argIndex"/>.
/// </summary>
/// <exception cref="System.ArgumentException">
/// If the option is the last argument and therefore has no value.
/// </exception>
private static string RequireOptionValue(string[] args, int argIndex)
{
    if (argIndex + 1 >= args.Length)
    {
        throw new ArgumentException("Option " + args[argIndex] + " requires an argument");
    }
    return args[argIndex + 1];
}
/// <summary>
/// Trains the model on the given trees, batch by batch, for
/// <c>op.trainOptions.trainingIterations</c> passes over the data.
/// </summary>
/// <remarks>
/// Process: we come up with a cost and a derivative for the model.
/// We always use the gold tree as the example to train towards.
/// Every time through, we will look at the top N trees from the
/// LexicalizedParser and pick the best one according to our model
/// (at the start, this is essentially random). To do the
/// minimization, all of the matrices in the DVModel are turned into
/// one big Theta, which is the set of variables to be optimized.
/// <br />
/// Every <c>debugOutputFrequency</c> batches (when positive) the
/// model is evaluated on <paramref name="testTreebank"/> if one is
/// given, a checkpoint is saved if <paramref name="modelPath"/> is
/// given, and a status line is logged and optionally appended to
/// <paramref name="resultsRecordPath"/>.
/// </remarks>
/// <param name="sentences">Gold training trees.</param>
/// <param name="compressedParses">Passed through to each training batch along with its trees.</param>
/// <param name="testTreebank">Optional treebank for the periodic evaluation; may be null.</param>
/// <param name="modelPath">Optional path used to name checkpoint models; may be null.</param>
/// <param name="resultsRecordPath">Optional file that checkpoint status lines are appended to; may be null.</param>
/// <exception cref="System.IO.IOException"/>
public virtual void Train(IList<Tree> sentences, IdentityHashMap<Tree, byte[]> compressedParses, Treebank testTreebank, string modelPath, string resultsRecordPath)
{
    Timing timing = new Timing();
    // 1000L forces the multiplication into 64 bits so that a large
    // time limit cannot overflow before being stored in the long.
    long maxTrainTimeMillis = op.trainOptions.maxTrainTimeSeconds * 1000L;
    int batchCount = 0;
    int debugCycle = 0;
    double bestLabelF1 = 0.0;

    if (op.trainOptions.useContextWords)
    {
        foreach (Tree tree in sentences)
        {
            Edu.Stanford.Nlp.Trees.Trees.ConvertToCoreLabels(tree);
            tree.SetSpans();
        }
    }

    // for AdaGrad
    double[] sumGradSquare = new double[dvModel.TotalParamSize()];
    Arrays.Fill(sumGradSquare, 1.0);

    // Ceiling division: full batches plus one shorter batch for any
    // leftover trees.  Unlike "Count / batchSize + 1" this never
    // schedules an empty trailing batch when the number of trees is an
    // exact multiple of the batch size (or when there are no trees).
    int numBatches = (sentences.Count + op.trainOptions.batchSize - 1) / op.trainOptions.batchSize;
    log.Info("Training on " + sentences.Count + " trees in " + numBatches + " batches");
    log.Info("Times through each training batch: " + op.trainOptions.trainingIterations);
    log.Info("QN iterations per batch: " + op.trainOptions.qnIterationsPerBatch);

    for (int iter = 0; iter < op.trainOptions.trainingIterations; ++iter)
    {
        IList<Tree> shuffledSentences = new List<Tree>(sentences);
        Java.Util.Collections.Shuffle(shuffledSentences, dvModel.rand);

        for (int batch = 0; batch < numBatches; ++batch)
        {
            ++batchCount;
            // Resetting AdaGrad's sum of squares to 1 for every batch
            // was tried here and did not help performance.
            log.Info("======================================");
            log.Info("Iteration " + iter + " batch " + batch);

            // Each batch will be of the specified batch size, except the
            // last batch, which holds any leftover trees at the end of
            // the list.
            int batchSize = op.trainOptions.batchSize;
            int startTree = batch * batchSize;
            int endTree = (batch + 1) * batchSize;
            if (endTree > shuffledSentences.Count)
            {
                endTree = shuffledSentences.Count;
            }
            ExecuteOneTrainingBatch(shuffledSentences.SubList(startTree, endTree), compressedParses, sumGradSquare);

            long batchElapsed = timing.Report();
            log.Info("Finished iteration " + iter + " batch " + batch + "; total training time " + batchElapsed + " ms");
            if (maxTrainTimeMillis > 0 && batchElapsed > maxTrainTimeMillis)
            {
                // no need to debug output, we're done now
                break;
            }

            if (op.trainOptions.debugOutputFrequency > 0 && batchCount % op.trainOptions.debugOutputFrequency == 0)
            {
                log.Info("Finished " + batchCount + " total batches, running evaluation cycle");
                // Time for debugging output!
                double tagF1 = 0.0;
                double labelF1 = 0.0;
                if (testTreebank != null)
                {
                    EvaluateTreebank evaluator = new EvaluateTreebank(AttachModelToLexicalizedParser());
                    evaluator.TestOnTreebank(testTreebank);
                    labelF1 = evaluator.GetLBScore();
                    tagF1 = evaluator.GetTagScore();
                    if (labelF1 > bestLabelF1)
                    {
                        bestLabelF1 = labelF1;
                    }
                    log.Info("Best label f1 on dev set so far: " + Nf.Format(bestLabelF1));
                }

                string tempName = null;
                if (modelPath != null)
                {
                    // For *.ser.gz paths the cycle number and score are
                    // spliced in before the extension; any other path is
                    // overwritten in place.
                    tempName = modelPath;
                    if (modelPath.EndsWith(".ser.gz"))
                    {
                        tempName = Sharpen.Runtime.Substring(modelPath, 0, modelPath.Length - 7) + "-" + Filename.Format(debugCycle) + "-" + Nf.Format(labelF1) + ".ser.gz";
                    }
                    SaveModel(tempName);
                }

                string statusLine = ("CHECKPOINT:" + " iteration " + iter + " batch " + batch + " labelF1 " + Nf.Format(labelF1) + " tagF1 " + Nf.Format(tagF1) + " bestLabelF1 " + Nf.Format(bestLabelF1) + " model " + tempName + op.trainOptions + " word vectors: " + op.lexOptions.wordVectorFile + " numHid: " + op.lexOptions.numHid);
                log.Info(statusLine);
                if (resultsRecordPath != null)
                {
                    FileWriter fout = new FileWriter(resultsRecordPath, true); // append
                    try
                    {
                        fout.Write(statusLine);
                        fout.Write("\n");
                    }
                    finally
                    {
                        // Release the file handle even if a write fails.
                        fout.Close();
                    }
                }
                ++debugCycle;
            }
        }

        long iterationElapsed = timing.Report();
        if (maxTrainTimeMillis > 0 && iterationElapsed > maxTrainTimeMillis)
        {
            // no need to debug output, we're done now
            log.Info("Max training time exceeded, exiting");
            break;
        }
    }
}
/// <summary>
/// Command-line entry point: loads a LexicalizedParser and evaluates it
/// on a test treebank once per entry of <c>weights</c>, setting
/// <c>baseParserWeight</c> to that entry each time, then logs the
/// labeled and tag scores obtained for each weight.
/// </summary>
/// <param name="args">
/// Requires <c>-lexparser &lt;path&gt;</c> and <c>-testTreebank &lt;path&gt; [filter]</c>.
/// Any other arguments are passed on to <c>LexicalizedParser.LoadModel</c>.
/// </param>
/// <exception cref="System.ArgumentException">
/// If -lexparser or -testTreebank is missing, or -lexparser has no value.
/// </exception>
/// <exception cref="System.IO.IOException"/>
/// <exception cref="System.TypeLoadException"/>
public static void Main(string[] args)
{
    string dvmodelFile = null;
    string lexparserFile = null;
    string testTreebankPath = null;
    IFileFilter testTreebankFilter = null;
    IList<string> unusedArgs = new List<string>();

    for (int argIndex = 0; argIndex < args.Length;)
    {
        string flag = args[argIndex];
        if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-lexparser"))
        {
            lexparserFile = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-testTreebank"))
        {
            Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.First();
            testTreebankFilter = treebankDescription.Second();
        }
        else
        {
            unusedArgs.Add(args[argIndex++]);
        }
    }

    // Fail fast with a usable message instead of passing null on to the
    // model loader or the evaluator.
    if (lexparserFile == null)
    {
        throw new ArgumentException("Must specify a parser to load with -lexparser");
    }
    if (testTreebankPath == null)
    {
        throw new ArgumentException("Must specify a treebank to test on with -testTreebank");
    }

    log.Info("Loading lexparser from: " + lexparserFile);
    string[] newArgs = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]);
    LexicalizedParser lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(lexparserFile, newArgs));
    log.Info("... done");

    log.Info("Reading in trees from " + testTreebankPath);
    if (testTreebankFilter != null)
    {
        log.Info("Filtering on " + testTreebankFilter);
    }
    Treebank testTreebank = lexparser.GetOp().tlpParams.MemoryTreebank();
    testTreebank.LoadPath(testTreebankPath, testTreebankFilter);
    log.Info("Read in " + testTreebank.Count + " trees for testing");

    double[] labelResults = new double[weights.Length];
    double[] tagResults = new double[weights.Length];
    for (int i = 0; i < weights.Length; ++i)
    {
        lexparser.GetOp().baseParserWeight = weights[i];
        EvaluateTreebank evaluator = new EvaluateTreebank(lexparser);
        evaluator.TestOnTreebank(testTreebank);
        labelResults[i] = evaluator.GetLBScore();
        tagResults[i] = evaluator.GetTagScore();
    }

    // Report everything together at the end so the summary is not
    // interleaved with the output of the evaluation runs.
    for (int w = 0; w < weights.Length; ++w)
    {
        log.Info("LexicalizedParser weight " + weights[w] + ": labeled " + labelResults[w] + " tag " + tagResults[w]);
    }
}

/// <summary>
/// Returns the value that follows the option at <paramref name="argIndex"/>.
/// </summary>
/// <exception cref="System.ArgumentException">
/// If the option is the last argument and therefore has no value.
/// </exception>
private static string RequireOptionValue(string[] args, int argIndex)
{
    if (argIndex + 1 >= args.Length)
    {
        throw new ArgumentException("Option " + args[argIndex] + " requires an argument");
    }
    return args[argIndex + 1];
}
/// <summary>
/// Runs the perceptron training loop over the binarized trees and
/// updates this model's feature weights in place.
/// </summary>
/// <remarks>
/// Each iteration shuffles the tree indices, trains on them in batches
/// of <c>op.trainOptions.batchSize</c> and applies the resulting
/// updates to <c>featureWeights</c>. If a dev treebank is given, the
/// model is scored on it after every iteration; that score drives
/// early stopping (<c>stalledIterationLimit</c>) and the selection of
/// models to average at the end (<c>averagedModels</c>,
/// <c>cvAveragedModels</c>). Afterwards rare features are filtered if
/// <c>featureFrequencyCutoff</c> is above 1 and the features are
/// condensed.
/// </remarks>
/// <param name="serializedPath">Base path for intermediate models; may be null.</param>
/// <param name="tagger">Tagger handed to the dev set evaluator.</param>
/// <param name="random">Source of randomness for shuffling the training order.</param>
/// <param name="binarizedTrees">Training trees.</param>
/// <param name="transitionLists">Passed to the batch trainer together with the trees.</param>
/// <param name="devTreebank">Optional dev set; may be null.</param>
/// <param name="nThreads">Any value other than 1 enables the multithreaded wrapper.</param>
/// <param name="allowedFeatures">If not null, updates for features outside this collection are ignored.</param>
private void TrainModel(string serializedPath, Edu.Stanford.Nlp.Tagger.Common.Tagger tagger, Random random, IList<Tree> binarizedTrees, IList<IList<ITransition>> transitionLists, Treebank devTreebank, int nThreads, ICollection<string> allowedFeatures)
{
    double bestScore = 0.0;
    int bestIteration = 0;

    PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
    if (op.TrainOptions().averagedModels > 0)
    {
        bestModels = new PriorityQueue<ScoredObject<PerceptronModel>>(op.TrainOptions().averagedModels + 1, ScoredComparator.AscendingComparator);
    }

    IList<int> indices = Generics.NewArrayList();
    for (int i = 0; i < binarizedTrees.Count; ++i)
    {
        indices.Add(i);
    }

    Oracle oracle = null;
    if (op.TrainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.Oracle)
    {
        oracle = new Oracle(binarizedTrees, op.compoundUnaries, rootStates);
    }

    IList<PerceptronModel.Update> updates = Generics.NewArrayList();
    MulticoreWrapper<int, Pair<int, int>> wrapper = null;
    if (nThreads != 1)
    {
        // Worker threads append to the shared update list.
        updates = Java.Util.Collections.SynchronizedList(updates);
        wrapper = new MulticoreWrapper<int, Pair<int, int>>(op.trainOptions.trainingThreads, new PerceptronModel.TrainTreeProcessor(this, binarizedTrees, transitionLists, updates, oracle));
    }

    IntCounter<string> featureFrequencies = null;
    if (op.TrainOptions().featureFrequencyCutoff > 1)
    {
        featureFrequencies = new IntCounter<string>();
    }

    for (int iteration = 1; iteration <= op.trainOptions.trainingIterations; ++iteration)
    {
        Timing trainingTimer = new Timing();
        int numCorrect = 0;
        int numWrong = 0;
        Java.Util.Collections.Shuffle(indices, random);

        for (int start = 0; start < indices.Count; start += op.trainOptions.batchSize)
        {
            int end = Math.Min(start + op.trainOptions.batchSize, indices.Count);
            Triple<IList<PerceptronModel.Update>, int, int> result = TrainBatch(indices.SubList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
            numCorrect += result.second;
            numWrong += result.third;

            foreach (PerceptronModel.Update update in result.first)
            {
                foreach (string feature in update.features)
                {
                    if (allowedFeatures != null && !allowedFeatures.Contains(feature))
                    {
                        continue;
                    }
                    Weight weights = featureWeights[feature];
                    if (weights == null)
                    {
                        weights = new Weight();
                        featureWeights[feature] = weights;
                    }
                    // Reward the gold transition, penalize the predicted one.
                    weights.UpdateWeight(update.goldTransition, update.delta);
                    weights.UpdateWeight(update.predictedTransition, -update.delta);
                    if (featureFrequencies != null)
                    {
                        // Counts 2 when both transition ids are non-negative, otherwise 1.
                        featureFrequencies.IncrementCount(feature, (update.goldTransition >= 0 && update.predictedTransition >= 0) ? 2 : 1);
                    }
                }
            }
            updates.Clear();
        }

        trainingTimer.Done("Iteration " + iteration);
        log.Info("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
        OutputStats();

        double labelF1 = 0.0;
        if (devTreebank != null)
        {
            EvaluateTreebank evaluator = new EvaluateTreebank(op, null, new ShiftReduceParser(op, this), tagger);
            evaluator.TestOnTreebank(devTreebank);
            labelF1 = evaluator.GetLBScore();
            log.Info("Label F1 after " + iteration + " iterations: " + labelF1);

            if (labelF1 > bestScore)
            {
                log.Info("New best dev score (previous best " + bestScore + ")");
                bestScore = labelF1;
                bestIteration = iteration;
            }
            else
            {
                log.Info("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                if (op.trainOptions.stalledIterationLimit > 0 && (iteration - bestIteration >= op.trainOptions.stalledIterationLimit))
                {
                    log.Info("Failed to improve for too long, stopping training");
                    break;
                }
            }
            log.Info();

            if (bestModels != null)
            {
                // Keep only the averagedModels highest-scoring snapshots:
                // once over capacity, Poll() removes one entry.
                bestModels.Add(new ScoredObject<PerceptronModel>(new PerceptronModel(this), labelF1));
                if (bestModels.Count > op.TrainOptions().averagedModels)
                {
                    bestModels.Poll();
                }
            }
        }

        if (op.TrainOptions().saveIntermediateModels && serializedPath != null && op.trainOptions.debugOutputFrequency > 0)
        {
            // Splice the iteration and score in before ".ser.gz".  Only
            // strip the extension when it is really there; a path with a
            // different (or shorter) name just gets the suffix appended.
            string basePath = serializedPath;
            if (serializedPath.EndsWith(".ser.gz"))
            {
                basePath = Sharpen.Runtime.Substring(serializedPath, 0, serializedPath.Length - 7);
            }
            string tempName = basePath + "-" + Filename.Format(iteration) + "-" + Nf.Format(labelF1) + ".ser.gz";
            ShiftReduceParser intermediate = new ShiftReduceParser(op, this);
            intermediate.SaveModel(tempName);
        }

        // TODO: we could save a cutoff version of the model,
        // especially if we also get a dev set number for it, but that
        // might be overkill
        if (iteration % 10 == 0 && op.TrainOptions().decayLearningRate > 0.0)
        {
            learningRate *= op.TrainOptions().decayLearningRate;
        }
    }
    // end for iterations

    if (wrapper != null)
    {
        wrapper.Join();
    }

    if (bestModels != null)
    {
        if (op.TrainOptions().cvAveragedModels && devTreebank != null)
        {
            // Drain the queue, then reverse so that the best model is first.
            IList<ScoredObject<PerceptronModel>> models = Generics.NewArrayList();
            while (bestModels.Count > 0)
            {
                models.Add(bestModels.Poll());
            }
            Java.Util.Collections.Reverse(models);

            double bestF1 = 0.0;
            int bestSize = 0;
            for (int size = 1; size <= models.Count; ++size)
            {
                log.Info("Testing with " + size + " models averaged together");
                // TODO: this is kind of ugly, would prefer a separate object
                AverageScoredModels(models.SubList(0, size));
                ShiftReduceParser averaged = new ShiftReduceParser(op, this);
                EvaluateTreebank averagedEvaluator = new EvaluateTreebank(averaged.GetOp(), null, averaged, tagger);
                averagedEvaluator.TestOnTreebank(devTreebank);
                double averagedF1 = averagedEvaluator.GetLBScore();
                log.Info("Label F1 for " + size + " models: " + averagedF1);
                if (averagedF1 > bestF1)
                {
                    bestF1 = averagedF1;
                    bestSize = size;
                }
            }
            // bestSize stays 0 when there were no models or none scored
            // above 0; averaging zero models is skipped.
            // NOTE(review): AverageScoredModels is not visible here --
            // confirm that skipping the empty case is the desired outcome.
            if (bestSize > 0)
            {
                AverageScoredModels(models.SubList(0, bestSize));
            }
        }
        else if (bestModels.Count > 0)
        {
            // Snapshots are only queued when a dev treebank is given, so
            // the queue can be empty here; in that case the trained
            // weights are left as they are.
            AverageScoredModels(bestModels);
        }
    }

    // TODO: perhaps we should filter the features and then get dev
    // set scores.  That way we can merge the models which are best
    // after filtering.
    if (featureFrequencies != null)
    {
        FilterFeatures(featureFrequencies.KeysAbove(op.TrainOptions().featureFrequencyCutoff));
    }
    CondenseFeatures();
}
/// <summary>
/// Command-line entry point for the ShiftReduceParser: trains a parser
/// (optionally continuing from an existing model) when
/// <c>-trainTreebank</c> is given, otherwise loads the one named by
/// <c>-serializedPath</c> / <c>-model</c>, and evaluates it on
/// <c>-testTreebank</c> if one is given.
/// </summary>
/// <param name="args">
/// Recognized options: <c>-trainTreebank</c> (may be repeated),
/// <c>-testTreebank</c>, <c>-devTreebank</c>, <c>-serializedPath</c> or
/// <c>-model</c>, <c>-tlpp</c>, <c>-continueTraining</c>. Everything
/// else is passed on as parser options.
/// </param>
/// <exception cref="System.ArgumentException">
/// If neither a training treebank nor a model path is given, or if an
/// option that needs a value is the last argument.
/// </exception>
public static void Main(string[] args)
{
    IList<string> remainingArgs = Generics.NewArrayList();
    IList<Pair<string, IFileFilter>> trainTreebankPath = null;
    Pair<string, IFileFilter> testTreebankPath = null;
    Pair<string, IFileFilter> devTreebankPath = null;
    string serializedPath = null;
    string tlppClass = null;
    string continueTraining = null;

    for (int argIndex = 0; argIndex < args.Length;)
    {
        string flag = args[argIndex];
        if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-trainTreebank"))
        {
            if (trainTreebankPath == null)
            {
                trainTreebankPath = Generics.NewArrayList();
            }
            trainTreebankPath.Add(ArgUtils.GetTreebankDescription(args, argIndex, "-trainTreebank"));
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-testTreebank"))
        {
            testTreebankPath = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-devTreebank"))
        {
            devTreebankPath = ArgUtils.GetTreebankDescription(args, argIndex, "-devTreebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-serializedPath") || Sharpen.Runtime.EqualsIgnoreCase(flag, "-model"))
        {
            serializedPath = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-tlpp"))
        {
            tlppClass = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-continueTraining"))
        {
            continueTraining = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else
        {
            remainingArgs.Add(args[argIndex]);
            ++argIndex;
        }
    }

    string[] newArgs = Sharpen.Collections.ToArray(remainingArgs, new string[remainingArgs.Count]);

    if (trainTreebankPath == null && serializedPath == null)
    {
        throw new ArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
    }

    ShiftReduceParser parser = null;
    if (trainTreebankPath != null)
    {
        log.Info("Training ShiftReduceParser");
        log.Info("Initial arguments:");
        log.Info(" " + StringUtils.Join(args));
        if (continueTraining != null)
        {
            parser = ((ShiftReduceParser)ShiftReduceParser.LoadModel(continueTraining, ArrayUtils.Concatenate(ForceTags, newArgs)));
        }
        else
        {
            ShiftReduceOptions op = BuildTrainingOptions(tlppClass, newArgs);
            parser = new ShiftReduceParser(op);
        }
        parser.Train(trainTreebankPath, devTreebankPath, serializedPath);
        // Training without a model path is allowed (the parser can still
        // be evaluated below), so only save when there is somewhere to
        // save to.
        if (serializedPath != null)
        {
            parser.SaveModel(serializedPath);
        }
    }

    if (serializedPath != null && parser == null)
    {
        parser = ((ShiftReduceParser)ShiftReduceParser.LoadModel(serializedPath, ArrayUtils.Concatenate(ForceTags, newArgs)));
    }

    if (testTreebankPath != null)
    {
        log.Info("Loading test trees from " + testTreebankPath.First());
        Treebank testTreebank = parser.op.tlpParams.MemoryTreebank();
        testTreebank.LoadPath(testTreebankPath.First(), testTreebankPath.Second());
        log.Info("Loaded " + testTreebank.Count + " trees");
        EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser);
        evaluator.TestOnTreebank(testTreebank);
    }
}

/// <summary>
/// Returns the value that follows the option at <paramref name="argIndex"/>.
/// </summary>
/// <exception cref="System.ArgumentException">
/// If the option is the last argument and therefore has no value.
/// </exception>
private static string RequireOptionValue(string[] args, int argIndex)
{
    if (argIndex + 1 >= args.Length)
    {
        throw new ArgumentException("Option " + args[argIndex] + " requires an argument");
    }
    return args[argIndex + 1];
}
/// <summary>
/// Command-line entry point: combines the DVModels embedded in several
/// serialized parsers into one parser with a
/// <c>CombinedDVModelReranker</c>, optionally saves it, and optionally
/// evaluates it on a test treebank.
/// </summary>
/// <param name="args">
/// <c>-baseModels &lt;path&gt;...</c> names the parsers to combine (all
/// following arguments up to the next one starting with '-');
/// <c>-model &lt;path&gt;</c> is where the combined parser is saved, or,
/// when no base models are given, the already combined parser to load;
/// <c>-testTreebank &lt;path&gt; [filter]</c> requests an evaluation.
/// Other arguments are applied as parser options.
/// </param>
/// <exception cref="System.ArgumentException">
/// If neither -baseModels nor -model is given, if -baseModels names no
/// models, if a base parser has no DVModel embedded, or if -model is
/// the last argument.
/// </exception>
/// <exception cref="System.IO.IOException"/>
/// <exception cref="System.TypeLoadException"/>
public static void Main(string[] args)
{
    string modelPath = null;
    IList<string> baseModelPaths = null;
    string testTreebankPath = null;
    IFileFilter testTreebankFilter = null;
    IList<string> unusedArgs = new List<string>();

    for (int argIndex = 0; argIndex < args.Length;)
    {
        string flag = args[argIndex];
        if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-model"))
        {
            modelPath = RequireOptionValue(args, argIndex);
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-testTreebank"))
        {
            Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.First();
            testTreebankFilter = treebankDescription.Second();
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-baseModels"))
        {
            argIndex++;
            baseModelPaths = new List<string>();
            // Everything up to the next "-option" is a model path.  The
            // length test keeps an empty argument from causing an index
            // error on [0].
            while (argIndex < args.Length && (args[argIndex].Length == 0 || args[argIndex][0] != '-'))
            {
                baseModelPaths.Add(args[argIndex++]);
            }
            if (baseModelPaths.Count == 0)
            {
                throw new ArgumentException("Found an argument -baseModels with no actual models named");
            }
        }
        else
        {
            unusedArgs.Add(args[argIndex++]);
        }
    }

    string[] newArgs = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]);
    LexicalizedParser underlyingParser = null;
    Options options = null;
    LexicalizedParser combinedParser = null;

    if (baseModelPaths != null)
    {
        IList<DVModel> dvparsers = new List<DVModel>();
        foreach (string baseModelPath in baseModelPaths)
        {
            log.Info("Loading serialized DVParser from " + baseModelPath);
            LexicalizedParser dvparser = ((LexicalizedParser)LexicalizedParser.LoadModel(baseModelPath));
            IReranker reranker = dvparser.reranker;
            if (!(reranker is DVModelReranker))
            {
                throw new ArgumentException("Expected parsers with DVModel embedded");
            }
            dvparsers.Add(((DVModelReranker)reranker).GetModel());
            if (underlyingParser == null)
            {
                // The first parser listed supplies the underlying parser
                // and the options for the combination.
                underlyingParser = dvparser;
                options = underlyingParser.GetOp();
                // TODO: other parser's options?
                options.SetOptions(newArgs);
            }
            log.Info("... done");
        }

        combinedParser = LexicalizedParser.CopyLexicalizedParser(underlyingParser);
        CombinedDVModelReranker combinedReranker = new CombinedDVModelReranker(options, dvparsers);
        combinedParser.reranker = combinedReranker;
        if (modelPath != null)
        {
            combinedParser.SaveParserToSerialized(modelPath);
        }
        else
        {
            log.Info("No -model path given, the combined parser will not be saved");
        }
    }
    else if (modelPath != null)
    {
        // No base models to combine: load an already prepared combined
        // parser, which is what the error message below tells the user
        // -model is for.
        log.Info("Loading combined parser from " + modelPath);
        combinedParser = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath, newArgs));
        log.Info("... done");
    }
    else
    {
        throw new ArgumentException("Need to specify -model to load an already prepared CombinedParser");
    }

    if (testTreebankPath != null)
    {
        log.Info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null)
        {
            log.Info("Filtering on " + testTreebankFilter);
        }
        Treebank testTreebank = combinedParser.GetOp().tlpParams.MemoryTreebank();
        testTreebank.LoadPath(testTreebankPath, testTreebankFilter);
        log.Info("Read in " + testTreebank.Count + " trees for testing");
        EvaluateTreebank evaluator = new EvaluateTreebank(combinedParser.GetOp(), null, combinedParser);
        evaluator.TestOnTreebank(testTreebank);
    }
}

/// <summary>
/// Returns the value that follows the option at <paramref name="argIndex"/>.
/// </summary>
/// <exception cref="System.ArgumentException">
/// If the option is the last argument and therefore has no value.
/// </exception>
private static string RequireOptionValue(string[] args, int argIndex)
{
    if (argIndex + 1 >= args.Length)
    {
        throw new ArgumentException("Option " + args[argIndex] + " requires an argument");
    }
    return args[argIndex + 1];
}