/// <summary>
/// Bundles everything the cost/gradient computation needs: the batch of
/// gold trees, the top-k candidate parses for each tree, the model, and
/// the parser options.
/// </summary>
public DVParserCostAndGradient(IList<Tree> trainingBatch, IdentityHashMap<Tree, IList<Tree>> topParses, DVModel dvModel, Options op)
{
    this.trainingBatch = trainingBatch;
    this.topParses = topParses;
    this.dvModel = dvModel;
    this.op = op;
}
/// <summary>
/// Creates a new DVParser on top of an existing LexicalizedParser,
/// building a fresh DVModel from the parser's grammar.
/// </summary>
public DVParser(LexicalizedParser parser)
{
    this.parser = parser;
    this.op = parser.GetOp();
    if (op.trainOptions.randomSeed == 0)
    {
        op.trainOptions.randomSeed = Runtime.NanoTime();
        log.Info("Random seed not set, using randomly chosen seed of " + op.trainOptions.randomSeed);
    }
    else
    {
        log.Info("Random seed set to " + op.trainOptions.randomSeed);
    }
    // Echo the full training configuration so a run can be reconstructed from its log.
    log.Info("Word vector file: " + op.lexOptions.wordVectorFile);
    log.Info("Size of word vectors: " + op.lexOptions.numHid);
    log.Info("Number of hypothesis trees to train against: " + op.trainOptions.dvKBest);
    log.Info("Number of trees in one batch: " + op.trainOptions.batchSize);
    log.Info("Number of iterations of trees: " + op.trainOptions.trainingIterations);
    log.Info("Number of qn iterations per batch: " + op.trainOptions.qnIterationsPerBatch);
    log.Info("Learning rate: " + op.trainOptions.learningRate);
    log.Info("Delta margin: " + op.trainOptions.deltaMargin);
    log.Info("regCost: " + op.trainOptions.regCost);
    log.Info("Using unknown word vector for numbers: " + op.trainOptions.unknownNumberVector);
    log.Info("Using unknown dashed word vector heuristics: " + op.trainOptions.unknownDashedWordVectors);
    log.Info("Using unknown word vector for capitalized words: " + op.trainOptions.unknownCapsVector);
    log.Info("Using unknown number vector for Chinese words: " + op.trainOptions.unknownChineseNumberVector);
    log.Info("Using unknown year vector for Chinese words: " + op.trainOptions.unknownChineseYearVector);
    log.Info("Using unknown percent vector for Chinese words: " + op.trainOptions.unknownChinesePercentVector);
    log.Info("Initial matrices scaled by: " + op.trainOptions.scalingForInit);
    log.Info("Training will use " + op.trainOptions.trainingThreads + " thread(s)");
    log.Info("Context words are " + (op.trainOptions.useContextWords ? "on" : "off"));
    log.Info("Model will " + (op.trainOptions.dvSimplifiedModel ? string.Empty : "not ") + "be simplified");
    this.dvModel = new DVModel(op, parser.stateIndex, parser.ug, parser.bg);
    // Every transform matrix must have a matching score matrix.
    if (dvModel.unaryTransform.Count != dvModel.unaryScore.Count)
    {
        throw new AssertionError("Unary transform and score size not the same");
    }
    if (dvModel.binaryTransform.Size() != dvModel.binaryScore.Size())
    {
        throw new AssertionError("Binary transform and score size not the same");
    }
}
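// Usage sketch (illustrative only; the model path below is hypothetical).
// Wrapping an already-trained LexicalizedParser is how a DVParser is
// normally seeded before training:
//
//     LexicalizedParser lp = (LexicalizedParser)
//         LexicalizedParser.LoadModel("wsjPCFG.nocompact.simple.ser.gz");
//     DVParser dvParser = new DVParser(lp);  // builds and logs a fresh DVModel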
/// <exception cref="System.IO.IOException"/> public static void Main(string[] args) { string modelPath = null; string outputPath = null; string inputPath = null; string testTreebankPath = null; IFileFilter testTreebankFilter = null; IList <string> unusedArgs = Generics.NewArrayList(); for (int argIndex = 0; argIndex < args.Length;) { if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-model")) { modelPath = args[argIndex + 1]; argIndex += 2; } else { if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-output")) { outputPath = args[argIndex + 1]; argIndex += 2; } else { if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-input")) { inputPath = args[argIndex + 1]; argIndex += 2; } else { if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-testTreebank")) { Pair <string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank"); argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1; testTreebankPath = treebankDescription.First(); testTreebankFilter = treebankDescription.Second(); } else { unusedArgs.Add(args[argIndex++]); } } } } } string[] newArgs = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]); LexicalizedParser parser = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath, newArgs)); DVModel model = DVParser.GetModelFromLexicalizedParser(parser); File outputFile = new File(outputPath); FileSystem.CheckNotExistsOrFail(outputFile); FileSystem.MkdirOrFail(outputFile); int count = 0; if (inputPath != null) { Reader input = new BufferedReader(new FileReader(inputPath)); DocumentPreprocessor processor = new DocumentPreprocessor(input); foreach (IList <IHasWord> sentence in processor) { count++; // index from 1 IParserQuery pq = parser.ParserQuery(); if (!(pq is RerankingParserQuery)) { throw new ArgumentException("Expected a RerankingParserQuery"); } RerankingParserQuery rpq = (RerankingParserQuery)pq; if (!rpq.Parse(sentence)) { throw new Exception("Unparsable sentence: " + sentence); } IRerankerQuery reranker = rpq.RerankerQuery(); if (!(reranker is DVModelReranker.Query)) { throw new ArgumentException("Expected a DVModelReranker"); } DeepTree deepTree = ((DVModelReranker.Query)reranker).GetDeepTrees()[0]; IdentityHashMap <Tree, SimpleMatrix> vectors = deepTree.GetVectors(); foreach (KeyValuePair <Tree, SimpleMatrix> entry in vectors) { log.Info(entry.Key + " " + entry.Value); } FileWriter fout = new FileWriter(outputPath + File.separator + "sentence" + count + ".txt"); BufferedWriter bout = new BufferedWriter(fout); bout.Write(SentenceUtils.ListToString(sentence)); bout.NewLine(); bout.Write(deepTree.GetTree().ToString()); bout.NewLine(); foreach (IHasWord word in sentence) { OutputMatrix(bout, model.GetWordVector(word.Word())); } Tree rootTree = FindRootTree(vectors); OutputTreeMatrices(bout, rootTree, vectors); bout.Flush(); fout.Close(); } } }
/// <summary>
/// An example command line for training a new parser:
/// <br />
/// nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories > /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2>&1 &
/// </summary>
/// <exception cref="System.IO.IOException"/>
/// <exception cref="System.TypeLoadException"/>
public static void Main(string[] args)
{
    if (args.Length == 0)
    {
        Help();
        System.Environment.Exit(2);
    }
    log.Info("Running DVParser with arguments:");
    foreach (string arg in args)
    {
        log.Info(" " + arg);
    }
    log.Info();

    string parserPath = null;
    string trainTreebankPath = null;
    IFileFilter trainTreebankFilter = null;
    string cachedTrainTreesPath = null;
    bool runGradientCheck = false;
    bool runTraining = false;
    string testTreebankPath = null;
    IFileFilter testTreebankFilter = null;
    string initialModelPath = null;
    string modelPath = null;
    bool filter = true;
    string resultsRecordPath = null;
    IList<string> unusedArgs = new List<string>();

    // These parameters can be null or 0 if the model was not serialized
    // with the new parameters.  Setting the options at the command line
    // will override these defaults.
    // TODO: if/when we integrate back into the main branch and rebuild
    // models, we can get rid of this
    IList<string> argsWithDefaults = new List<string>(Arrays.AsList(new string[] {
        "-wordVectorFile", Options.LexOptions.DefaultWordVectorFile,
        "-dvKBest", int.ToString(TrainOptions.DefaultKBest),
        "-batchSize", int.ToString(TrainOptions.DefaultBatchSize),
        "-trainingIterations", int.ToString(TrainOptions.DefaultTrainingIterations),
        "-qnIterationsPerBatch", int.ToString(TrainOptions.DefaultQnIterationsPerBatch),
        "-regCost", double.ToString(TrainOptions.DefaultRegcost),
        "-learningRate", double.ToString(TrainOptions.DefaultLearningRate),
        "-deltaMargin", double.ToString(TrainOptions.DefaultDeltaMargin),
        "-unknownNumberVector",
        "-unknownDashedWordVectors",
        "-unknownCapsVector",
        "-unknownchinesepercentvector",
        "-unknownchinesenumbervector",
        "-unknownchineseyearvector",
        "-unkWord", "*UNK*",
        "-transformMatrixType", "DIAGONAL",
        "-scalingForInit", double.ToString(TrainOptions.DefaultScalingForInit),
        "-trainWordVectors" }));
    Sharpen.Collections.AddAll(argsWithDefaults, Arrays.AsList(args));
    args = Sharpen.Collections.ToArray(argsWithDefaults, new string[argsWithDefaults.Count]);

    for (int argIndex = 0; argIndex < args.Length;)
    {
        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-parser"))
        {
            parserPath = args[argIndex + 1];
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-testTreebank"))
        {
            Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
            testTreebankPath = treebankDescription.First();
            testTreebankFilter = treebankDescription.Second();
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-treebank"))
        {
            Pair<string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-treebank");
            argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
            trainTreebankPath = treebankDescription.First();
            trainTreebankFilter = treebankDescription.Second();
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-cachedTrees"))
        {
            cachedTrainTreesPath = args[argIndex + 1];
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-runGradientCheck"))
        {
            runGradientCheck = true;
            argIndex++;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-train"))
        {
            runTraining = true;
            argIndex++;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-model"))
        {
            modelPath = args[argIndex + 1];
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-nofilter"))
        {
            filter = false;
            argIndex++;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-continueTraining"))
        {
            runTraining = true;
            filter = false;
            initialModelPath = args[argIndex + 1];
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-resultsRecord"))
        {
            resultsRecordPath = args[argIndex + 1];
            argIndex += 2;
        }
        else
        {
            unusedArgs.Add(args[argIndex++]);
        }
    }

    if (parserPath == null && modelPath == null)
    {
        throw new ArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
    }
    if (!runTraining && modelPath == null && !runGradientCheck)
    {
        throw new ArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
    }

    string[] newArgs = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]);
    Edu.Stanford.Nlp.Parser.Dvparser.DVParser dvparser = null;
    LexicalizedParser lexparser = null;
    if (initialModelPath != null)
    {
        // Continue training from a previously saved DVParser.
        lexparser = (LexicalizedParser)LexicalizedParser.LoadModel(initialModelPath, newArgs);
        DVModel model = GetModelFromLexicalizedParser(lexparser);
        dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(model, lexparser);
    }
    else if (runTraining || runGradientCheck)
    {
        // Start a new DVModel on top of the base parser.
        lexparser = (LexicalizedParser)LexicalizedParser.LoadModel(parserPath, newArgs);
        dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(lexparser);
    }
    else if (modelPath != null)
    {
        // Test-only mode: load an existing DVParser.
        lexparser = (LexicalizedParser)LexicalizedParser.LoadModel(modelPath, newArgs);
        DVModel model = GetModelFromLexicalizedParser(lexparser);
        dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(model, lexparser);
    }

    IList<Tree> trainSentences = new List<Tree>();
    IdentityHashMap<Tree, byte[]> trainCompressedParses = Generics.NewIdentityHashMap();
    if (cachedTrainTreesPath != null)
    {
        foreach (string path in cachedTrainTreesPath.Split(","))
        {
            IList<Pair<Tree, byte[]>> cache = IOUtils.ReadObjectFromFile(path);
            foreach (Pair<Tree, byte[]> pair in cache)
            {
                trainSentences.Add(pair.First());
                trainCompressedParses[pair.First()] = pair.Second();
            }
            log.Info("Read in " + cache.Count + " trees from " + path);
        }
    }
    if (trainTreebankPath != null)
    {
        // TODO: make the transformer a member of the model?
        ITreeTransformer transformer = BuildTrainTransformer(dvparser.GetOp());
        Treebank treebank = dvparser.GetOp().tlpParams.MemoryTreebank();
        treebank.LoadPath(trainTreebankPath, trainTreebankFilter);
        treebank = treebank.Transform(transformer);
        log.Info("Read in " + treebank.Count + " trees from " + trainTreebankPath);
        CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
        CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
        foreach (Tree tree in treebank)
        {
            trainSentences.Add(tree);
            trainCompressedParses[tree] = processor.Process(tree).second;
            //System.out.println(tree);
        }
        log.Info("Finished parsing " + treebank.Count + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
    }
    if ((runTraining || runGradientCheck) && filter)
    {
        log.Info("Filtering rules for the given training set");
        dvparser.dvModel.SetRulesForTrainingSet(trainSentences, trainCompressedParses);
        log.Info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.Count + " word vectors");
    }
    //dvparser.dvModel.printAllMatrices();

    Treebank testTreebank = null;
    if (testTreebankPath != null)
    {
        log.Info("Reading in trees from " + testTreebankPath);
        if (testTreebankFilter != null)
        {
            log.Info("Filtering on " + testTreebankFilter);
        }
        testTreebank = dvparser.GetOp().tlpParams.MemoryTreebank();
        testTreebank.LoadPath(testTreebankPath, testTreebankFilter);
        log.Info("Read in " + testTreebank.Count + " trees for testing");
    }

    // runGradientCheck = true;
    if (runGradientCheck)
    {
        log.Info("Running gradient check on " + trainSentences.Count + " trees");
        dvparser.RunGradientCheck(trainSentences, trainCompressedParses);
    }
    if (runTraining)
    {
        log.Info("Training the RNN parser");
        log.Info("Current train options: " + dvparser.GetOp().trainOptions);
        dvparser.Train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
        if (modelPath != null)
        {
            dvparser.SaveModel(modelPath);
        }
    }
    if (testTreebankPath != null)
    {
        EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.AttachModelToLexicalizedParser());
        evaluator.TestOnTreebank(testTreebank);
    }
    log.Info("Successfully ran DVParser");
}
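// A minimal test-only invocation, for contrast with the full training
// command in the summary above (paths are hypothetical). With -model but
// neither -train nor -runGradientCheck, the saved DVParser is loaded and
// evaluated on the given treebank section:
//
//     java edu.stanford.nlp.parser.dvparser.DVParser \
//         -model wsj.combine.v2.ser.gz \
//         -testTreebank /path/to/wsj/22 2200-2219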
/// <summary>Creates a DVParser that reuses an existing, already-trained DVModel.</summary>
public DVParser(DVModel model, LexicalizedParser parser)
{
    this.parser = parser;
    this.op = parser.GetOp();
    this.dvModel = model;
}
public UnknownWordPrinter(DVModel model)
{
    this.model = model;
    this.unk = model.GetUnknownWordVector();
}
public DVModelReranker(DVModel model)
{
    this.op = model.op;
    this.model = model;
}
/// <summary>
/// Command line arguments for this program:
/// <br />
/// -output: the model file to output
/// -input: a list of model files to input
/// </summary>
public static void Main(string[] args)
{
    string outputModelFilename = null;
    IList<string> inputModelFilenames = Generics.NewArrayList();
    for (int argIndex = 0; argIndex < args.Length;)
    {
        if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-output"))
        {
            outputModelFilename = args[argIndex + 1];
            argIndex += 2;
        }
        else if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-input"))
        {
            // -input may be followed by several filenames and/or
            // comma-separated lists of filenames.
            for (++argIndex; argIndex < args.Length && !args[argIndex].StartsWith("-"); ++argIndex)
            {
                Sharpen.Collections.AddAll(inputModelFilenames, Arrays.AsList(args[argIndex].Split(",")));
            }
        }
        else
        {
            throw new Exception("Unknown argument " + args[argIndex]);
        }
    }
    if (outputModelFilename == null)
    {
        log.Info("Need to specify output model name with -output");
        System.Environment.Exit(2);
    }
    if (inputModelFilenames.Count == 0)
    {
        log.Info("Need to specify input model names with -input");
        System.Environment.Exit(2);
    }
    log.Info("Averaging " + inputModelFilenames);
    log.Info("Outputting result to " + outputModelFilename);

    LexicalizedParser lexparser = null;
    IList<DVModel> models = Generics.NewArrayList();
    foreach (string filename in inputModelFilenames)
    {
        LexicalizedParser parser = (LexicalizedParser)LexicalizedParser.LoadModel(filename);
        if (lexparser == null)
        {
            lexparser = parser;
        }
        models.Add(DVParser.GetModelFromLexicalizedParser(parser));
    }

    // Pull each component map out of every model.  (The converter had
    // dropped these extractors, leaving nulls; the lambdas below are
    // reconstructed from the field names used when the averages are
    // assembled into the new DVModel.)
    IList<TwoDimensionalMap<string, string, SimpleMatrix>> binaryTransformMaps = CollectionUtils.TransformAsList(models, model => model.binaryTransform);
    IList<TwoDimensionalMap<string, string, SimpleMatrix>> binaryScoreMaps = CollectionUtils.TransformAsList(models, model => model.binaryScore);
    IList<IDictionary<string, SimpleMatrix>> unaryTransformMaps = CollectionUtils.TransformAsList(models, model => model.unaryTransform);
    IList<IDictionary<string, SimpleMatrix>> unaryScoreMaps = CollectionUtils.TransformAsList(models, model => model.unaryScore);
    IList<IDictionary<string, SimpleMatrix>> wordMaps = CollectionUtils.TransformAsList(models, model => model.wordVectors);

    TwoDimensionalMap<string, string, SimpleMatrix> binaryTransformAverages = AverageBinaryMatrices(binaryTransformMaps);
    TwoDimensionalMap<string, string, SimpleMatrix> binaryScoreAverages = AverageBinaryMatrices(binaryScoreMaps);
    IDictionary<string, SimpleMatrix> unaryTransformAverages = AverageUnaryMatrices(unaryTransformMaps);
    IDictionary<string, SimpleMatrix> unaryScoreAverages = AverageUnaryMatrices(unaryScoreMaps);
    IDictionary<string, SimpleMatrix> wordAverages = AverageUnaryMatrices(wordMaps);

    DVModel newModel = new DVModel(binaryTransformAverages, unaryTransformAverages, binaryScoreAverages, unaryScoreAverages, wordAverages, lexparser.GetOp());
    DVParser newParser = new DVParser(newModel, lexparser);
    newParser.SaveModel(outputModelFilename);
}
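// Illustrative invocation (the class name and file names are assumptions,
// not from the original source): average the matrices and word vectors of
// two trained models into a single combined model.
//
//     java edu.stanford.nlp.parser.dvparser.AverageDVModels \
//         -input wsj.fold1.ser.gz,wsj.fold2.ser.gz \
//         -output wsj.averaged.ser.gz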
/// <exception cref="System.IO.IOException"/> public static void Main(string[] args) { string modelPath = null; string outputDir = null; for (int argIndex = 0; argIndex < args.Length;) { if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-model")) { modelPath = args[argIndex + 1]; argIndex += 2; } else { if (Sharpen.Runtime.EqualsIgnoreCase(args[argIndex], "-output")) { outputDir = args[argIndex + 1]; argIndex += 2; } else { log.Info("Unknown argument " + args[argIndex]); Help(); } } } if (outputDir == null || modelPath == null) { Help(); } File outputFile = new File(outputDir); FileSystem.CheckNotExistsOrFail(outputFile); FileSystem.MkdirOrFail(outputFile); LexicalizedParser parser = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath)); DVModel model = DVParser.GetModelFromLexicalizedParser(parser); string binaryWDir = outputDir + File.separator + "binaryW"; FileSystem.MkdirOrFail(binaryWDir); foreach (TwoDimensionalMap.Entry <string, string, SimpleMatrix> entry in model.binaryTransform) { string filename = binaryWDir + File.separator + entry.GetFirstKey() + "_" + entry.GetSecondKey() + ".txt"; DumpMatrix(filename, entry.GetValue()); } string binaryScoreDir = outputDir + File.separator + "binaryScore"; FileSystem.MkdirOrFail(binaryScoreDir); foreach (TwoDimensionalMap.Entry <string, string, SimpleMatrix> entry_1 in model.binaryScore) { string filename = binaryScoreDir + File.separator + entry_1.GetFirstKey() + "_" + entry_1.GetSecondKey() + ".txt"; DumpMatrix(filename, entry_1.GetValue()); } string unaryWDir = outputDir + File.separator + "unaryW"; FileSystem.MkdirOrFail(unaryWDir); foreach (KeyValuePair <string, SimpleMatrix> entry_2 in model.unaryTransform) { string filename = unaryWDir + File.separator + entry_2.Key + ".txt"; DumpMatrix(filename, entry_2.Value); } string unaryScoreDir = outputDir + File.separator + "unaryScore"; FileSystem.MkdirOrFail(unaryScoreDir); foreach (KeyValuePair <string, SimpleMatrix> entry_3 in model.unaryScore) { string filename = unaryScoreDir + File.separator + entry_3.Key + ".txt"; DumpMatrix(filename, entry_3.Value); } string embeddingFile = outputDir + File.separator + "embeddings.txt"; FileWriter fout = new FileWriter(embeddingFile); BufferedWriter bout = new BufferedWriter(fout); foreach (KeyValuePair <string, SimpleMatrix> entry_4 in model.wordVectors) { bout.Write(entry_4.Key); SimpleMatrix vector = entry_4.Value; for (int i = 0; i < vector.NumRows(); ++i) { bout.Write(" " + vector.Get(i, 0)); } bout.Write("\n"); } bout.Close(); fout.Close(); }