/// <summary>
/// Builds a processor that turns a single tree into its top-k parse
/// hypotheses and the compressed byte form of those hypotheses.
/// </summary>
/// <param name="cacher">Handles conversion of hypothesis lists to/from bytes.</param>
/// <param name="parser">Parser used to produce the k-best hypothesis list.</param>
/// <param name="dvKBest">Number of top parses to keep for each tree.</param>
/// <param name="transformer">Transformer applied when parsing each tree.</param>
public CacheProcessor(CacheParseHypotheses cacher, LexicalizedParser parser, int dvKBest, ITreeTransformer transformer)
{
    this.parser = parser;
    this.cacher = cacher;
    this.transformer = transformer;
    this.dvKBest = dvKBest;
}
/// <summary>
/// Decompresses the cached parse hypotheses for the given sentences and
/// runs a numeric gradient check of the cost function at the model's
/// current parameter vector.
/// </summary>
/// <param name="sentences">Gold trees to check against.</param>
/// <param name="compressedParses">Compressed hypothesis lists keyed by identity of each tree.</param>
/// <returns>Whatever <c>GradientCheck</c> reports (true when the check passes).</returns>
public virtual bool RunGradientCheck(IList<Tree> sentences, IdentityHashMap<Tree, byte[]> compressedParses)
{
    log.Info("Gradient check: converting " + sentences.Count + " compressed trees");
    IdentityHashMap<Tree, IList<Tree>> hypotheses = CacheParseHypotheses.ConvertToTrees(sentences, compressedParses, op.trainOptions.trainingThreads);
    log.Info("Done converting trees");
    DVParserCostAndGradient costFunction = new DVParserCostAndGradient(sentences, hypotheses, dvModel, op);
    return costFunction.GradientCheck(1000, 50, dvModel.ParamsToVector());
}
/// <summary>
/// Gathers every binary rule, unary rule, and word that occurs in the
/// given trees or in any of their decompressed parse hypotheses, then
/// filters the model's rule set down to exactly those items.
/// </summary>
/// <param name="compressedTrees">Gold trees mapped to the compressed bytes of their hypotheses.</param>
public virtual void FilterRulesForBatch(IDictionary<Tree, byte[]> compressedTrees)
{
    TwoDimensionalSet<string, string> seenBinary = TwoDimensionalSet.TreeSet();
    ICollection<string> seenUnary = new HashSet<string>();
    ICollection<string> seenWords = new HashSet<string>();
    foreach (KeyValuePair<Tree, byte[]> entry in compressedTrees)
    {
        // Rules from the gold tree itself...
        SearchRulesForBatch(seenBinary, seenUnary, seenWords, entry.Key);
        // ...and from every cached hypothesis for that tree.
        foreach (Tree hypothesis in CacheParseHypotheses.ConvertToTrees(entry.Value))
        {
            SearchRulesForBatch(seenBinary, seenUnary, seenWords, hypothesis);
        }
    }
    FilterRulesForBatch(seenBinary, seenUnary, seenWords);
}
/// <summary>
/// Parses one tree into its top <c>dvKBest</c> hypotheses, round-trips the
/// hypotheses through the cache's byte encoding, and sanity-checks that the
/// decoded trees match the simplified/filtered originals.
/// </summary>
/// <param name="tree">The gold tree to parse.</param>
/// <returns>The tree paired with the compressed bytes of its hypotheses.</returns>
public virtual Pair<Tree, byte[]> Process(Tree tree)
{
    IList<Tree> topParses = DVParser.GetTopParsesForOneTree(parser, dvKBest, tree, transformer);
    // this block is a test to make sure the conversion code is working...
    IList<Tree> converted = CacheParseHypotheses.ConvertToTrees(cacher.ConvertToBytes(topParses));
    IList<Tree> simplified = CollectionUtils.TransformAsList(topParses, cacher.treeBasicCategories);
    simplified = CollectionUtils.FilterAsList(simplified, cacher.treeFilter);
    if (simplified.Count != topParses.Count)
    {
        log.Info("Filtered " + (topParses.Count - simplified.Count) + " trees");
        if (simplified.Count == 0)
        {
            log.Info(" WARNING: filtered all trees for " + tree);
        }
    }
    // BUG FIX: the Java-translated code guarded this block with
    // !simplified.Equals(converted), but IList.Equals in C# is reference
    // equality (Java's List.equals is element-wise), so the guard was
    // always true.  Compare structurally instead: sizes first, then
    // tree by tree.  Observable behavior (throws/prints) is unchanged.
    if (converted.Count != simplified.Count)
    {
        throw new AssertionError("horrible error: tree sizes not equal, " + converted.Count + " vs " + simplified.Count);
    }
    for (int i = 0; i < converted.Count; ++i)
    {
        if (!simplified[i].Equals(converted[i]))
        {
            System.Console.Out.WriteLine("=============================");
            System.Console.Out.WriteLine(simplified[i]);
            System.Console.Out.WriteLine("=============================");
            System.Console.Out.WriteLine(converted[i]);
            System.Console.Out.WriteLine("=============================");
            throw new AssertionError("horrible error: tree " + i + " not equal for base tree " + tree);
        }
    }
    return Pair.MakePair(tree, cacher.ConvertToBytes(topParses));
}
/// <summary>
/// Scans the training sentences plus all of their cached hypotheses for
/// binary rules, unary rules, and words; creates a random transform
/// matrix for every rule found; then filters the model to that rule set.
/// </summary>
/// <param name="sentences">The gold training trees.</param>
/// <param name="compressedTrees">Compressed hypothesis bytes keyed by tree.</param>
public virtual void SetRulesForTrainingSet(IList<Tree> sentences, IDictionary<Tree, byte[]> compressedTrees)
{
    TwoDimensionalSet<string, string> seenBinary = TwoDimensionalSet.TreeSet();
    ICollection<string> seenUnary = new HashSet<string>();
    ICollection<string> seenWords = new HashSet<string>();
    foreach (Tree sentence in sentences)
    {
        SearchRulesForBatch(seenBinary, seenUnary, seenWords, sentence);
        foreach (Tree hypothesis in CacheParseHypotheses.ConvertToTrees(compressedTrees[sentence]))
        {
            SearchRulesForBatch(seenBinary, seenUnary, seenWords, hypothesis);
        }
    }
    // Make sure every rule we just observed has a matrix to train.
    foreach (Pair<string, string> binaryRule in seenBinary)
    {
        AddRandomBinaryMatrix(binaryRule.first, binaryRule.second);
    }
    foreach (string unaryRule in seenUnary)
    {
        AddRandomUnaryMatrix(unaryRule);
    }
    FilterRulesForBatch(seenBinary, seenUnary, seenWords);
}
/// <summary>
/// Command-line entry point for training/evaluating a DVParser.
/// An example command line for training a new parser:
/// <br />
/// nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories &gt; /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2&gt;&amp;1 &amp;
/// </summary>
/// <exception cref="System.IO.IOException"/>
/// <exception cref="System.TypeLoadException"/>
public static void Main(string[] args)
{
    if (args.Length == 0)
    {
        Help();
        System.Environment.Exit(2);
    }
    log.Info("Running DVParser with arguments:");
    foreach (string arg in args)
    {
        log.Info(" " + arg);
    }
    log.Info();
    // ----- option state filled in by argument parsing below -----
    string parserPath = null;
    string trainTreebankPath = null;
    IFileFilter trainTreebankFilter = null;
    string cachedTrainTreesPath = null;
    bool runGradientCheck = false;
    bool runTraining = false;
    string testTreebankPath = null;
    IFileFilter testTreebankFilter = null;
    string initialModelPath = null;
    string modelPath = null;
    bool filter = true;
    string resultsRecordPath = null;
    // Args we do not recognize here are forwarded to LexicalizedParser.LoadModel.
    IList<string> unusedArgs = new List<string>();
    // These parameters can be null or 0 if the model was not
    // serialized with the new parameters.  Setting the options at the
    // command line will override these defaults.
    // TODO: if/when we integrate back into the main branch and
    // rebuild models, we can get rid of this
    // Defaults are prepended so later (user-supplied) args win.
    IList<string> argsWithDefaults = new List<string>(Arrays.AsList(new string[] {
        "-wordVectorFile", Options.LexOptions.DefaultWordVectorFile,
        "-dvKBest", int.ToString(TrainOptions.DefaultKBest),
        "-batchSize", int.ToString(TrainOptions.DefaultBatchSize),
        "-trainingIterations", int.ToString(TrainOptions.DefaultTrainingIterations),
        "-qnIterationsPerBatch", int.ToString(TrainOptions.DefaultQnIterationsPerBatch),
        "-regCost", double.ToString(TrainOptions.DefaultRegcost),
        "-learningRate", double.ToString(TrainOptions.DefaultLearningRate),
        "-deltaMargin", double.ToString(TrainOptions.DefaultDeltaMargin),
        "-unknownNumberVector", "-unknownDashedWordVectors", "-unknownCapsVector",
        "-unknownchinesepercentvector", "-unknownchinesenumbervector", "-unknownchineseyearvector",
        "-unkWord", "*UNK*",
        "-transformMatrixType", "DIAGONAL",
        "-scalingForInit", double.ToString(TrainOptions.DefaultScalingForInit),
        "-trainWordVectors" }));
    Sharpen.Collections.AddAll(argsWithDefaults, Arrays.AsList(args));
    args = Sharpen.Collections.ToArray(argsWithDefaults, new string[argsWithDefaults.Count]);
    // ----- argument parsing (nested else-if chain from the Java source) -----
    for (int argIndex = 0; argIndex < args.Length;)
    {
        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-parser"))
        {
            parserPath = args[argIndex + 1];
            argIndex += 2;
        }
        else
        {
            if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-testTreebank"))
            {
                // Treebank flags may take a variable number of sub-arguments (path + ranges).
                Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
                argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                testTreebankPath = treebankDescription.First();
                testTreebankFilter = treebankDescription.Second();
            }
            else
            {
                if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-treebank"))
                {
                    Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-treebank");
                    argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                    trainTreebankPath = treebankDescription.First();
                    trainTreebankFilter = treebankDescription.Second();
                }
                else
                {
                    if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-cachedTrees"))
                    {
                        cachedTrainTreesPath = args[argIndex + 1];
                        argIndex += 2;
                    }
                    else
                    {
                        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-runGradientCheck"))
                        {
                            runGradientCheck = true;
                            argIndex++;
                        }
                        else
                        {
                            if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-train"))
                            {
                                runTraining = true;
                                argIndex++;
                            }
                            else
                            {
                                if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-model"))
                                {
                                    modelPath = args[argIndex + 1];
                                    argIndex += 2;
                                }
                                else
                                {
                                    if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-nofilter"))
                                    {
                                        filter = false;
                                        argIndex++;
                                    }
                                    else
                                    {
                                        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-continueTraining"))
                                        {
                                            // Continue training an existing model: train, but do not re-filter rules.
                                            runTraining = true;
                                            filter = false;
                                            initialModelPath = args[argIndex + 1];
                                            argIndex += 2;
                                        }
                                        else
                                        {
                                            if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-resultsRecord"))
                                            {
                                                resultsRecordPath = args[argIndex + 1];
                                                argIndex += 2;
                                            }
                                            else
                                            {
                                                // Unrecognized: forward to the parser loader later.
                                                unusedArgs.Add(args[argIndex++]);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    // ----- validation of the parsed options -----
    if (parserPath == null && modelPath == null)
    {
        throw new ArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
    }
    if (!runTraining && modelPath == null && !runGradientCheck)
    {
        throw new ArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
    }
    string[] newArgs = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]);
    // ----- load (or create) the parser and DV model -----
    Edu.Stanford.Nlp.Parser.Dvparser.DVParser dvparser = null;
    LexicalizedParser lexparser = null;
    if (initialModelPath != null)
    {
        // -continueTraining: reuse the DVModel embedded in the serialized parser.
        lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(initialModelPath, newArgs));
        DVModel model = GetModelFromLexicalizedParser(lexparser);
        dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(model, lexparser);
    }
    else
    {
        if (runTraining || runGradientCheck)
        {
            // Fresh DVParser built on top of a plain lexicalized parser.
            lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(parserPath, newArgs));
            dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(lexparser);
        }
        else
        {
            if (modelPath != null)
            {
                // Evaluation only: load an existing DVParser model.
                lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath, newArgs));
                DVModel model = GetModelFromLexicalizedParser(lexparser);
                dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(model, lexparser);
            }
        }
    }
    // ----- collect training trees and their compressed hypotheses -----
    IList<Tree> trainSentences = new List<Tree>();
    IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.NewIdentityHashMap();
    if (cachedTrainTreesPath != null)
    {
        // Pre-parsed hypotheses serialized by CacheParseHypotheses; may be several files.
        foreach (string path in cachedTrainTreesPath.Split(","))
        {
            IList<Pair<Tree, byte[]>> cache = IOUtils.ReadObjectFromFile(path);
            foreach (Pair<Tree, byte[]> pair in cache)
            {
                trainSentences.Add(pair.First());
                trainCompressedParses[pair.First()] = pair.Second();
            }
            log.Info("Read in " + cache.Count + " trees from " + path);
        }
    }
    if (trainTreebankPath != null)
    {
        // Raw treebank supplied: parse each tree now to build the hypothesis cache.
        // TODO: make the transformer a member of the model?
        ITreeTransformer transformer = BuildTrainTransformer(dvparser.GetOp());
        Treebank treebank = dvparser.GetOp().tlpParams.MemoryTreebank();
        treebank.LoadPath(trainTreebankPath, trainTreebankFilter);
        treebank = treebank.Transform(transformer);
        log.Info("Read in " + treebank.Count + " trees from " + trainTreebankPath);
        CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
        CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
        foreach (Tree tree in treebank)
        {
            trainSentences.Add(tree);
            trainCompressedParses[tree] = processor.Process(tree).second;
        }
        //System.out.println(tree);
        log.Info("Finished parsing " + treebank.Count + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
    }
    if ((runTraining || runGradientCheck) && filter)
    {
        // Restrict the model's rules/words to what actually occurs in the training data.
        log.Info("Filtering rules for the given training set");
        dvparser.dvModel.SetRulesForTrainingSet(trainSentences, trainCompressedParses);
        log.Info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.Count + " word vectors");
    }
    //dvparser.dvModel.printAllMatrices();
    // ----- optional test treebank -----
    Treebank testTreebank = null;
    if (testTreebankPath != null)
    {
        log.Info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null)
        {
            log.Info("Filtering on " + testTreebankFilter);
        }
        testTreebank = dvparser.GetOp().tlpParams.MemoryTreebank();
        testTreebank.LoadPath(testTreebankPath, testTreebankFilter);
        log.Info("Read in " + testTreebank.Count + " trees for testing");
    }
    // runGradientCheck= true;
    if (runGradientCheck)
    {
        log.Info("Running gradient check on " + trainSentences.Count + " trees");
        dvparser.RunGradientCheck(trainSentences, trainCompressedParses);
    }
    if (runTraining)
    {
        log.Info("Training the RNN parser");
        log.Info("Current train options: " + dvparser.GetOp().trainOptions);
        dvparser.Train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
        if (modelPath != null)
        {
            dvparser.SaveModel(modelPath);
        }
    }
    if (testTreebankPath != null)
    {
        EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.AttachModelToLexicalizedParser());
        evaluator.TestOnTreebank(testTreebank);
    }
    log.Info("Successfully ran DVParser");
}
/// <summary>
/// Runs one optimization pass over a single training batch: decompresses
/// the batch's cached parse hypotheses, builds the cost/gradient function,
/// minimizes with the configured algorithm (selected by the Minimizer
/// constant: 1 = quasi-Newton, 2 = plain SGD, 3 = AdaGrad), and writes the
/// updated parameter vector back into the model.
/// </summary>
/// <param name="trainingBatch">Gold trees for this batch.</param>
/// <param name="compressedParses">Compressed hypothesis bytes keyed by tree identity.</param>
/// <param name="sumGradSquare">Running per-feature sum of squared gradients (AdaGrad state, mutated in place by case 3).</param>
public virtual void ExecuteOneTrainingBatch(IList<Tree> trainingBatch, IdentityHashMap<Tree, byte[]> compressedParses, double[] sumGradSquare)
{
    Timing convertTiming = new Timing();
    convertTiming.Doing("Converting trees");
    IdentityHashMap<Tree, IList<Tree>> topParses = CacheParseHypotheses.ConvertToTrees(trainingBatch, compressedParses, op.trainOptions.trainingThreads);
    convertTiming.Done();
    DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(trainingBatch, topParses, dvModel, op);
    double[] theta = dvModel.ParamsToVector();
    switch (Minimizer)
    {
        case (1):
        {
            //maxFuncIter = 10;
            // 1: QNMinimizer, 2: SGD
            QNMinimizer qn = new QNMinimizer(op.trainOptions.qnEstimates, true);
            qn.UseMinPackSearch();
            qn.UseDiagonalScaling();
            qn.TerminateOnAverageImprovement(true);
            qn.TerminateOnNumericalZero(true);
            qn.TerminateOnRelativeNorm(true);
            theta = qn.Minimize(gcFunc, op.trainOptions.qnTolerance, theta, op.trainOptions.qnIterationsPerBatch);
            break;
        }
        case 2:
        {
            // Plain SGD: fixed learning rate, qnIterationsPerBatch steps.
            //Minimizer smd = new SGDMinimizer(); double tol = 1e-4; theta = smd.minimize(gcFunc,tol,theta,op.trainOptions.qnIterationsPerBatch);
            double lastCost = 0;
            double currCost = 0;
            // NOTE(review): firstTime is only referenced in the commented-out
            // monitoring code below; it is unused by the live code.
            bool firstTime = true;
            for (int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++)
            {
                //gcFunc.calculate(theta);
                double[] grad = gcFunc.DerivativeAt(theta);
                currCost = gcFunc.ValueAt(theta);
                log.Info("batch cost: " + currCost);
                // if(!firstTime){
                //   if(currCost > lastCost){
                //     System.out.println("HOW IS FUNCTION VALUE INCREASING????!!! ... still updating theta");
                //   }
                //   if(Math.abs(currCost - lastCost) < 0.0001){
                //     System.out.println("function value is not decreasing. stop");
                //   }
                // }else{
                //   firstTime = false;
                // }
                lastCost = currCost;
                // Gradient descent step: theta -= learningRate * grad.
                ArrayMath.AddMultInPlace(theta, grad, -1 * op.trainOptions.learningRate);
            }
            break;
        }
        case 3:
        {
            // AdaGrad
            double eps = 1e-3;
            double currCost = 0;
            for (int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++)
            {
                double[] gradf = gcFunc.DerivativeAt(theta);
                currCost = gcFunc.ValueAt(theta);
                log.Info("batch cost: " + currCost);
                for (int feature = 0; feature < gradf.Length; feature++)
                {
                    // Accumulate squared gradient, then scale this feature's
                    // step by 1 / (sqrt(accumulated) + eps).
                    sumGradSquare[feature] = sumGradSquare[feature] + gradf[feature] * gradf[feature];
                    theta[feature] = theta[feature] - (op.trainOptions.learningRate * gradf[feature] / (System.Math.Sqrt(sumGradSquare[feature]) + eps));
                }
            }
            break;
        }
        default:
        {
            throw new ArgumentException("Unsupported minimizer " + Minimizer);
        }
    }
    // Push the optimized parameters back into the model.
    dvModel.VectorToParams(theta);
}
/// <summary>
/// Command-line entry point: parses every tree in the given treebank(s)
/// into its top dvKBest hypotheses (in parallel) and serializes the
/// compressed results to a file for later DVParser training.
/// An example of a command line is
/// <br />
/// java -mx1g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model /scr/horatio/dvparser/wsjPCFG.nocompact.simple.ser.gz -output cached9.simple.ser.gz -treebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-202
/// <br />
/// java -mx4g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model ~/scr/dvparser/wsjPCFG.nocompact.simple.ser.gz -output cached.train.simple.ser.gz -treebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-2199 -numThreads 6
/// <br />
/// java -mx4g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model ~/scr/dvparser/chinese/xinhuaPCFG.ser.gz -output cached.xinhua.train.ser.gz -treebank /afs/ir/data/linguistic-data/Chinese-Treebank/6/data/utf8/bracketed 026-270,301-499,600-999
/// </summary>
/// <exception cref="System.IO.IOException"/>
public static void Main(string[] args)
{
    string parserModel = null;
    string output = null;
    IList<Pair<string, IFileFilter>> treebanks = Generics.NewArrayList();
    int dvKBest = 200;
    int numThreads = 1;
    // ----- argument parsing: flat if/continue chain -----
    for (int argIndex = 0; argIndex < args.Length;)
    {
        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-dvKBest"))
        {
            dvKBest = System.Convert.ToInt32(args[argIndex + 1]);
            argIndex += 2;
            continue;
        }
        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-parser") || args[argIndex].Equals("-model"))
        {
            parserModel = args[argIndex + 1];
            argIndex += 2;
            continue;
        }
        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-output"))
        {
            output = args[argIndex + 1];
            argIndex += 2;
            continue;
        }
        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-treebank"))
        {
            // -treebank takes a variable number of sub-args (path + file ranges).
            Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-treebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
            treebanks.Add(treebankDescription);
            continue;
        }
        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-numThreads"))
        {
            numThreads = System.Convert.ToInt32(args[argIndex + 1]);
            argIndex += 2;
            continue;
        }
        throw new ArgumentException("Unknown argument " + args[argIndex]);
    }
    // ----- validate required options -----
    if (parserModel == null)
    {
        throw new ArgumentException("Need to supply a parser model with -model");
    }
    if (output == null)
    {
        throw new ArgumentException("Need to supply an output filename with -output");
    }
    if (treebanks.IsEmpty())
    {
        throw new ArgumentException("Need to supply a treebank with -treebank");
    }
    log.Info("Writing output to " + output);
    log.Info("Loading parser model " + parserModel);
    log.Info("Writing " + dvKBest + " hypothesis trees for each tree");
    LexicalizedParser parser = ((LexicalizedParser)LexicalizedParser.LoadModel(parserModel, "-dvKBest", int.ToString(dvKBest)));
    CacheParseHypotheses cacher = new CacheParseHypotheses(parser);
    ITreeTransformer transformer = DVParser.BuildTrainTransformer(parser.GetOp());
    // ----- load and transform every requested treebank -----
    IList<Tree> sentences = new List<Tree>();
    foreach (Pair<string, IFileFilter> description in treebanks)
    {
        log.Info("Reading trees from " + description.first);
        Treebank treebank = parser.GetOp().tlpParams.MemoryTreebank();
        treebank.LoadPath(description.first, description.second);
        treebank = treebank.Transform(transformer);
        Sharpen.Collections.AddAll(sentences, treebank);
    }
    log.Info("Processing " + sentences.Count + " trees");
    IList<Pair<Tree, byte[]>> cache = Generics.NewArrayList();
    // Wrap the transformer so worker threads can share it safely.
    transformer = new SynchronizedTreeTransformer(transformer);
    MulticoreWrapper<Tree, Pair<Tree, byte[]>> wrapper = new MulticoreWrapper<Tree, Pair<Tree, byte[]>>(numThreads, new CacheParseHypotheses.CacheProcessor(cacher, parser, dvKBest, transformer));
    foreach (Tree tree in sentences)
    {
        wrapper.Put(tree);
        // Drain any finished results while feeding new work.
        while (wrapper.Peek())
        {
            cache.Add(wrapper.Poll());
            if (cache.Count % 10 == 0)
            {
                System.Console.Out.WriteLine("Processed " + cache.Count + " trees");
            }
        }
    }
    wrapper.Join();
    // Drain everything remaining after the workers finish.
    while (wrapper.Peek())
    {
        cache.Add(wrapper.Poll());
        if (cache.Count % 10 == 0)
        {
            System.Console.Out.WriteLine("Processed " + cache.Count + " trees");
        }
    }
    System.Console.Out.WriteLine("Finished processing " + cache.Count + " trees");
    IOUtils.WriteObjectToFile(cache, output);
}