Beispiel #1
0
 /// <summary>
 /// Creates a processor by storing each argument in the instance field of the same name.
 /// </summary>
 /// <param name="cacher">Stored in <c>this.cacher</c>.</param>
 /// <param name="parser">Stored in <c>this.parser</c>.</param>
 /// <param name="dvKBest">Stored in <c>this.dvKBest</c>.</param>
 /// <param name="transformer">Stored in <c>this.transformer</c>.</param>
 public CacheProcessor(CacheParseHypotheses cacher, LexicalizedParser parser, int dvKBest, ITreeTransformer transformer)
 {
     this.transformer = transformer;
     this.dvKBest = dvKBest;
     this.parser = parser;
     this.cacher = cacher;
 }
        /// <summary>
        /// Expands the compressed parses of the given sentences into trees and runs a
        /// gradient check of the cost function at the model's current parameter vector.
        /// </summary>
        /// <param name="sentences">Trees to run the check on.</param>
        /// <param name="compressedParses">Compressed parses handed to <c>CacheParseHypotheses.ConvertToTrees</c> together with the sentences.</param>
        /// <returns>Whatever <c>DVParserCostAndGradient.GradientCheck</c> returns.</returns>
        public virtual bool RunGradientCheck(IList <Tree> sentences, IdentityHashMap <Tree, byte[]> compressedParses)
        {
            log.Info("Gradient check: converting " + sentences.Count + " compressed trees");
            IdentityHashMap <Tree, IList <Tree> > expandedParses =
                CacheParseHypotheses.ConvertToTrees(sentences, compressedParses, op.trainOptions.trainingThreads);
            log.Info("Done converting trees");

            DVParserCostAndGradient costFunction = new DVParserCostAndGradient(sentences, expandedParses, dvModel, op);
            double[] currentParams = dvModel.ParamsToVector();

            return costFunction.GradientCheck(1000, 50, currentParams);
        }
        /// <summary>
        /// Passes every key tree of <paramref name="compressedTrees"/>, and every tree obtained by
        /// decompressing its value, to <c>SearchRulesForBatch</c> with three shared accumulator
        /// collections, then hands those collections to the three-argument
        /// <c>FilterRulesForBatch</c> overload.
        /// </summary>
        /// <param name="compressedTrees">Map from a tree to the compressed bytes of its hypothesis trees.</param>
        public virtual void FilterRulesForBatch(IDictionary <Tree, byte[]> compressedTrees)
        {
            ICollection <string> seenWords = new HashSet <string>();
            ICollection <string> seenUnaryRules = new HashSet <string>();
            TwoDimensionalSet <string, string> seenBinaryRules = TwoDimensionalSet.TreeSet();

            foreach (KeyValuePair <Tree, byte[]> treeAndBytes in compressedTrees)
            {
                Tree keyTree = treeAndBytes.Key;
                byte[] compressedHypotheses = treeAndBytes.Value;

                // The key tree is searched first, then each of its decompressed hypotheses.
                SearchRulesForBatch(seenBinaryRules, seenUnaryRules, seenWords, keyTree);
                foreach (Tree candidate in CacheParseHypotheses.ConvertToTrees(compressedHypotheses))
                {
                    SearchRulesForBatch(seenBinaryRules, seenUnaryRules, seenWords, candidate);
                }
            }

            FilterRulesForBatch(seenBinaryRules, seenUnaryRules, seenWords);
        }
Beispiel #4
0
            /// <summary>
            /// Computes the top parse hypotheses for <paramref name="tree"/>, serializes them with the
            /// cacher, and checks that the serialized bytes decode back to the simplified and filtered
            /// hypotheses before returning them.
            /// </summary>
            /// <param name="tree">The base tree to produce hypotheses for.</param>
            /// <returns>A pair of the input tree and the serialized hypothesis trees.</returns>
            /// <exception cref="AssertionError">
            /// If the decoded trees differ in number or content from the simplified, filtered hypotheses.
            /// </exception>
            public virtual Pair <Tree, byte[]> Process(Tree tree)
            {
                IList <Tree> topParses = DVParser.GetTopParsesForOneTree(parser, dvKBest, tree, transformer);

                // Serialize exactly once: the bytes that are verified below are the same bytes that
                // are returned, and the hypothesis list is not compressed a second time.
                byte[] serialized = cacher.ConvertToBytes(topParses);

                // this block is a test to make sure the conversion code is working...
                IList <Tree> converted  = CacheParseHypotheses.ConvertToTrees(serialized);
                IList <Tree> simplified = CollectionUtils.TransformAsList(topParses, cacher.treeBasicCategories);

                simplified = CollectionUtils.FilterAsList(simplified, cacher.treeFilter);
                if (simplified.Count != topParses.Count)
                {
                    log.Info("Filtered " + (topParses.Count - simplified.Count) + " trees");
                    if (simplified.Count == 0)
                    {
                        log.Info(" WARNING: filtered all trees for " + tree);
                    }
                }
                if (!simplified.Equals(converted))
                {
                    // The lists are not equal as objects, so compare them by size and then
                    // element by element to find the first real difference (if any).
                    if (converted.Count != simplified.Count)
                    {
                        throw new AssertionError("horrible error: tree sizes not equal, " + converted.Count + " vs " + simplified.Count);
                    }
                    for (int i = 0; i < converted.Count; ++i)
                    {
                        if (!simplified[i].Equals(converted[i]))
                        {
                            System.Console.Out.WriteLine("=============================");
                            System.Console.Out.WriteLine(simplified[i]);
                            System.Console.Out.WriteLine("=============================");
                            System.Console.Out.WriteLine(converted[i]);
                            System.Console.Out.WriteLine("=============================");
                            throw new AssertionError("horrible error: tree " + i + " not equal for base tree " + tree);
                        }
                    }
                }
                return Pair.MakePair(tree, serialized);
            }
        /// <summary>
        /// Passes every training sentence and each of its decompressed hypothesis trees to
        /// <c>SearchRulesForBatch</c>, adds a random binary matrix for every collected binary pair
        /// and a random unary matrix for every collected unary entry, and finally calls the
        /// three-argument <c>FilterRulesForBatch</c> with the collected sets.
        /// </summary>
        /// <param name="sentences">The training sentences.</param>
        /// <param name="compressedTrees">
        /// Map from each sentence to the compressed bytes of its hypothesis trees; it is indexed
        /// with every element of <paramref name="sentences"/>.
        /// </param>
        public virtual void SetRulesForTrainingSet(IList <Tree> sentences, IDictionary <Tree, byte[]> compressedTrees)
        {
            ICollection <string> seenWords = new HashSet <string>();
            ICollection <string> seenUnaryRules = new HashSet <string>();
            TwoDimensionalSet <string, string> seenBinaryRules = TwoDimensionalSet.TreeSet();

            foreach (Tree goldTree in sentences)
            {
                SearchRulesForBatch(seenBinaryRules, seenUnaryRules, seenWords, goldTree);

                IList <Tree> hypotheses = CacheParseHypotheses.ConvertToTrees(compressedTrees[goldTree]);
                foreach (Tree candidate in hypotheses)
                {
                    SearchRulesForBatch(seenBinaryRules, seenUnaryRules, seenWords, candidate);
                }
            }

            foreach (Pair <string, string> binaryRule in seenBinaryRules)
            {
                AddRandomBinaryMatrix(binaryRule.first, binaryRule.second);
            }

            foreach (string unaryRule in seenUnaryRules)
            {
                AddRandomUnaryMatrix(unaryRule);
            }

            FilterRulesForBatch(seenBinaryRules, seenUnaryRules, seenWords);
        }
        /// <summary>
        /// Command-line entry point: parses the arguments, loads or creates a DVParser, optionally
        /// reads cached and/or freshly parsed training trees, and then runs whichever of the
        /// gradient check, training and treebank evaluation were requested.
        /// <br />
        /// An example command line for training a new parser:
        /// <br />
        /// nohup java -mx6g edu.stanford.nlp.parser.dvparser.DVParser -cachedTrees /scr/nlp/data/dvparser/wsj/cached.wsj.train.simple.ser.gz -train -testTreebank  /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj/22 2200-2219 -debugOutputFrequency 400 -nofilter -trainingThreads 5 -parser /u/nlp/data/lexparser/wsjPCFG.nocompact.simple.ser.gz -trainingIterations 40 -batchSize 25 -model /scr/nlp/data/dvparser/wsj/wsj.combine.v2.ser.gz -unkWord "*UNK*" -dvCombineCategories &gt; /scr/nlp/data/dvparser/wsj/wsj.combine.v2.out 2&gt;&amp;1 &amp;
        /// </summary>
        /// <param name="args">
        /// Command-line arguments. Flags handled here are -parser, -testTreebank, -treebank,
        /// -cachedTrees, -runGradientCheck, -train, -model, -nofilter, -continueTraining and
        /// -resultsRecord; everything else is forwarded to <c>LexicalizedParser.LoadModel</c>.
        /// </param>
        /// <exception cref="System.ArgumentException">
        /// If a flag that needs a value is the last argument, if neither -parser nor -model is given,
        /// or if none of training, gradient check or -model is requested.
        /// </exception>
        /// <exception cref="System.IO.IOException"/>
        /// <exception cref="System.TypeLoadException"/>
        public static void Main(string[] args)
        {
            if (args.Length == 0)
            {
                Help();
                System.Environment.Exit(2);
            }
            log.Info("Running DVParser with arguments:");
            foreach (string arg in args)
            {
                log.Info("  " + arg);
            }
            log.Info();

            string         parserPath           = null;
            string         trainTreebankPath    = null;
            IFileFilter    trainTreebankFilter  = null;
            string         cachedTrainTreesPath = null;
            bool           runGradientCheck     = false;
            bool           runTraining          = false;
            string         testTreebankPath     = null;
            IFileFilter    testTreebankFilter   = null;
            string         initialModelPath     = null;
            string         modelPath            = null;
            bool           filter               = true;
            string         resultsRecordPath    = null;
            IList <string> unusedArgs           = new List <string>();

            // These parameters can be null or 0 if the model was not
            // serialized with the new parameters.  Setting the options at the
            // command line will override these defaults.
            // TODO: if/when we integrate back into the main branch and
            // rebuild models, we can get rid of this
            //
            // Numbers are formatted with the invariant culture so the generated option values
            // do not depend on the machine's regional settings (e.g. "0,1" vs "0.1").
            System.Globalization.CultureInfo invariant = System.Globalization.CultureInfo.InvariantCulture;
            IList <string> argsWithDefaults = new List <string>(Arrays.AsList(new string[]
            {
                "-wordVectorFile", Options.LexOptions.DefaultWordVectorFile,
                "-dvKBest", TrainOptions.DefaultKBest.ToString(invariant),
                "-batchSize", TrainOptions.DefaultBatchSize.ToString(invariant),
                "-trainingIterations", TrainOptions.DefaultTrainingIterations.ToString(invariant),
                "-qnIterationsPerBatch", TrainOptions.DefaultQnIterationsPerBatch.ToString(invariant),
                "-regCost", TrainOptions.DefaultRegcost.ToString("R", invariant),
                "-learningRate", TrainOptions.DefaultLearningRate.ToString("R", invariant),
                "-deltaMargin", TrainOptions.DefaultDeltaMargin.ToString("R", invariant),
                "-unknownNumberVector",
                "-unknownDashedWordVectors",
                "-unknownCapsVector",
                "-unknownchinesepercentvector",
                "-unknownchinesenumbervector",
                "-unknownchineseyearvector",
                "-unkWord", "*UNK*",
                "-transformMatrixType", "DIAGONAL",
                "-scalingForInit", TrainOptions.DefaultScalingForInit.ToString("R", invariant),
                "-trainWordVectors"
            }));

            // User-supplied arguments come after the defaults.
            Sharpen.Collections.AddAll(argsWithDefaults, Arrays.AsList(args));
            args = Sharpen.Collections.ToArray(argsWithDefaults, new string[argsWithDefaults.Count]);

            // Each branch advances argIndex by the number of elements it consumed.
            for (int argIndex = 0; argIndex < args.Length;)
            {
                string flag = args[argIndex];
                if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-parser"))
                {
                    parserPath = RequireOptionValue(args, argIndex);
                    argIndex  += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-testTreebank"))
                {
                    Pair <string, IFileFilter> testDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
                    argIndex           = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                    testTreebankPath   = testDescription.First();
                    testTreebankFilter = testDescription.Second();
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-treebank"))
                {
                    Pair <string, IFileFilter> trainDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-treebank");
                    argIndex            = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                    trainTreebankPath   = trainDescription.First();
                    trainTreebankFilter = trainDescription.Second();
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-cachedTrees"))
                {
                    cachedTrainTreesPath = RequireOptionValue(args, argIndex);
                    argIndex            += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-runGradientCheck"))
                {
                    runGradientCheck = true;
                    argIndex++;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-train"))
                {
                    runTraining = true;
                    argIndex++;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-model"))
                {
                    modelPath = RequireOptionValue(args, argIndex);
                    argIndex += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-nofilter"))
                {
                    filter = false;
                    argIndex++;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-continueTraining"))
                {
                    // Continuing from an existing model implies training and disables rule filtering.
                    runTraining      = true;
                    filter           = false;
                    initialModelPath = RequireOptionValue(args, argIndex);
                    argIndex        += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-resultsRecord"))
                {
                    resultsRecordPath = RequireOptionValue(args, argIndex);
                    argIndex         += 2;
                }
                else
                {
                    // Not handled here: keep it for LexicalizedParser.LoadModel.
                    unusedArgs.Add(flag);
                    argIndex++;
                }
            }

            if (parserPath == null && modelPath == null)
            {
                throw new ArgumentException("Must supply either a base parser model with -parser or a serialized DVParser with -model");
            }
            if (!runTraining && modelPath == null && !runGradientCheck)
            {
                throw new ArgumentException("Need to either train a new model, run the gradient check or specify a model to load with -model");
            }

            string[] newArgs = Sharpen.Collections.ToArray(unusedArgs, new string[unusedArgs.Count]);
            Edu.Stanford.Nlp.Parser.Dvparser.DVParser dvparser = null;
            LexicalizedParser lexparser = null;

            if (initialModelPath != null)
            {
                // -continueTraining: start from the DVModel stored in an existing serialized parser.
                lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(initialModelPath, newArgs));
                DVModel model = GetModelFromLexicalizedParser(lexparser);
                dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(model, lexparser);
            }
            else if (runTraining || runGradientCheck)
            {
                // Fresh training or gradient check: build a new DVParser on top of the base parser.
                lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(parserPath, newArgs));
                dvparser  = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(lexparser);
            }
            else if (modelPath != null)
            {
                // Neither training nor gradient check: load the serialized model from -model.
                lexparser = ((LexicalizedParser)LexicalizedParser.LoadModel(modelPath, newArgs));
                DVModel model = GetModelFromLexicalizedParser(lexparser);
                dvparser = new Edu.Stanford.Nlp.Parser.Dvparser.DVParser(model, lexparser);
            }

            IList <Tree> trainSentences = new List <Tree>();
            IdentityHashMap <Tree, byte[]> trainCompressedParses = Generics.NewIdentityHashMap();

            if (cachedTrainTreesPath != null)
            {
                // -cachedTrees accepts a comma-separated list of cache files.
                foreach (string path in cachedTrainTreesPath.Split(","))
                {
                    IList <Pair <Tree, byte[]> > cache = IOUtils.ReadObjectFromFile(path);
                    foreach (Pair <Tree, byte[]> pair in cache)
                    {
                        trainSentences.Add(pair.First());
                        trainCompressedParses[pair.First()] = pair.Second();
                    }
                    log.Info("Read in " + cache.Count + " trees from " + path);
                }
            }
            if (trainTreebankPath != null)
            {
                // TODO: make the transformer a member of the model?
                ITreeTransformer transformer = BuildTrainTransformer(dvparser.GetOp());
                Treebank         treebank    = dvparser.GetOp().tlpParams.MemoryTreebank();
                treebank.LoadPath(trainTreebankPath, trainTreebankFilter);
                treebank = treebank.Transform(transformer);
                log.Info("Read in " + treebank.Count + " trees from " + trainTreebankPath);
                CacheParseHypotheses cacher = new CacheParseHypotheses(dvparser.parser);
                CacheParseHypotheses.CacheProcessor processor = new CacheParseHypotheses.CacheProcessor(cacher, lexparser, dvparser.op.trainOptions.dvKBest, transformer);
                foreach (Tree tree in treebank)
                {
                    trainSentences.Add(tree);
                    trainCompressedParses[tree] = processor.Process(tree).second;
                }
                log.Info("Finished parsing " + treebank.Count + " trees, getting " + dvparser.op.trainOptions.dvKBest + " hypotheses each");
            }
            if ((runTraining || runGradientCheck) && filter)
            {
                log.Info("Filtering rules for the given training set");
                dvparser.dvModel.SetRulesForTrainingSet(trainSentences, trainCompressedParses);
                log.Info("Done filtering rules; " + dvparser.dvModel.numBinaryMatrices + " binary matrices, " + dvparser.dvModel.numUnaryMatrices + " unary matrices, " + dvparser.dvModel.wordVectors.Count + " word vectors");
            }

            Treebank testTreebank = null;

            if (testTreebankPath != null)
            {
                log.Info("Reading in trees from " + testTreebankPath);
                if (testTreebankFilter != null)
                {
                    log.Info("Filtering on " + testTreebankFilter);
                }
                testTreebank = dvparser.GetOp().tlpParams.MemoryTreebank();
                testTreebank.LoadPath(testTreebankPath, testTreebankFilter);
                log.Info("Read in " + testTreebank.Count + " trees for testing");
            }
            if (runGradientCheck)
            {
                log.Info("Running gradient check on " + trainSentences.Count + " trees");
                dvparser.RunGradientCheck(trainSentences, trainCompressedParses);
            }
            if (runTraining)
            {
                log.Info("Training the RNN parser");
                log.Info("Current train options: " + dvparser.GetOp().trainOptions);
                dvparser.Train(trainSentences, trainCompressedParses, testTreebank, modelPath, resultsRecordPath);
                if (modelPath != null)
                {
                    dvparser.SaveModel(modelPath);
                }
            }
            if (testTreebankPath != null)
            {
                EvaluateTreebank evaluator = new EvaluateTreebank(dvparser.AttachModelToLexicalizedParser());
                evaluator.TestOnTreebank(testTreebank);
            }
            log.Info("Successfully ran DVParser");
        }

        /// <summary>
        /// Returns the element that follows the flag at <paramref name="argIndex"/>.
        /// </summary>
        /// <param name="args">The full argument array.</param>
        /// <param name="argIndex">Index of the flag whose value is wanted.</param>
        /// <returns><c>args[argIndex + 1]</c>.</returns>
        /// <exception cref="System.ArgumentException">If the flag is the last element of <paramref name="args"/>.</exception>
        private static string RequireOptionValue(string[] args, int argIndex)
        {
            if (argIndex + 1 >= args.Length)
            {
                throw new ArgumentException("Missing value for argument " + args[argIndex]);
            }
            return args[argIndex + 1];
        }
        /// <summary>
        /// Runs one optimization pass over a training batch and writes the resulting parameter
        /// vector back into the model.
        /// </summary>
        /// <remarks>
        /// The optimizer is selected by <c>Minimizer</c>: 1 uses <c>QNMinimizer</c>, 2 takes plain
        /// gradient steps scaled by the learning rate, and 3 uses AdaGrad-style per-parameter steps.
        /// Modes 2 and 3 run <c>op.trainOptions.qnIterationsPerBatch</c> iterations.
        /// </remarks>
        /// <param name="trainingBatch">Trees in this batch.</param>
        /// <param name="compressedParses">Compressed parses handed to <c>CacheParseHypotheses.ConvertToTrees</c> together with the batch.</param>
        /// <param name="sumGradSquare">
        /// Running per-parameter sum of squared gradients. It is read and updated in place only when
        /// <c>Minimizer</c> is 3, and is indexed by every position of the gradient vector.
        /// </param>
        /// <exception cref="System.ArgumentException">If <c>Minimizer</c> is not 1, 2 or 3.</exception>
        public virtual void ExecuteOneTrainingBatch(IList <Tree> trainingBatch, IdentityHashMap <Tree, byte[]> compressedParses, double[] sumGradSquare)
        {
            Timing convertTiming = new Timing();

            convertTiming.Doing("Converting trees");
            IdentityHashMap <Tree, IList <Tree> > topParses = CacheParseHypotheses.ConvertToTrees(trainingBatch, compressedParses, op.trainOptions.trainingThreads);

            convertTiming.Done();
            DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(trainingBatch, topParses, dvModel, op);

            double[] theta = dvModel.ParamsToVector();
            switch (Minimizer)
            {
            case 1:
            {
                // QNMinimizer
                QNMinimizer qn = new QNMinimizer(op.trainOptions.qnEstimates, true);
                qn.UseMinPackSearch();
                qn.UseDiagonalScaling();
                qn.TerminateOnAverageImprovement(true);
                qn.TerminateOnNumericalZero(true);
                qn.TerminateOnRelativeNorm(true);
                theta = qn.Minimize(gcFunc, op.trainOptions.qnTolerance, theta, op.trainOptions.qnIterationsPerBatch);
                break;
            }

            case 2:
            {
                // Plain gradient steps: theta <- theta - learningRate * grad, updated in place.
                for (int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++)
                {
                    double[] grad = gcFunc.DerivativeAt(theta);
                    double currCost = gcFunc.ValueAt(theta);
                    log.Info("batch cost: " + currCost);
                    ArrayMath.AddMultInPlace(theta, grad, -1 * op.trainOptions.learningRate);
                }
                break;
            }

            case 3:
            {
                // AdaGrad: each parameter's step is divided by the root of its accumulated
                // squared gradient; eps keeps the denominator away from zero.
                double eps = 1e-3;
                for (int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++)
                {
                    double[] gradf = gcFunc.DerivativeAt(theta);
                    double currCost = gcFunc.ValueAt(theta);
                    log.Info("batch cost: " + currCost);
                    for (int feature = 0; feature < gradf.Length; feature++)
                    {
                        sumGradSquare[feature] = sumGradSquare[feature] + gradf[feature] * gradf[feature];
                        theta[feature]         = theta[feature] - (op.trainOptions.learningRate * gradf[feature] / (System.Math.Sqrt(sumGradSquare[feature]) + eps));
                    }
                }
                break;
            }

            default:
            {
                throw new ArgumentException("Unsupported minimizer " + Minimizer);
            }
            }
            dvModel.VectorToParams(theta);
        }
Beispiel #8
0
        /// <summary>
        /// Command-line entry point: parses every tree of the given treebanks with the given parser
        /// model, collects the processor's (tree, bytes) result for each, and writes the list of
        /// results to the output file.
        /// <br />
        /// An example of a command line is
        /// <br />
        /// java -mx1g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model /scr/horatio/dvparser/wsjPCFG.nocompact.simple.ser.gz -output cached9.simple.ser.gz  -treebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-202
        /// <br />
        /// java -mx4g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model ~/scr/dvparser/wsjPCFG.nocompact.simple.ser.gz -output cached.train.simple.ser.gz -treebank /afs/ir/data/linguistic-data/Treebank/3/parsed/mrg/wsj 200-2199 -numThreads 6
        /// <br />
        /// java -mx4g edu.stanford.nlp.parser.dvparser.CacheParseHypotheses -model ~/scr/dvparser/chinese/xinhuaPCFG.ser.gz -output cached.xinhua.train.ser.gz -treebank /afs/ir/data/linguistic-data/Chinese-Treebank/6/data/utf8/bracketed  026-270,301-499,600-999
        /// </summary>
        /// <param name="args">
        /// Command-line arguments. Recognized flags (all case-insensitive) are -dvKBest, -parser,
        /// -model, -output, -treebank and -numThreads.
        /// </param>
        /// <exception cref="System.ArgumentException">
        /// If an argument is unknown, a flag is missing its value, or any of the parser model,
        /// the output file name or a treebank is not supplied.
        /// </exception>
        /// <exception cref="System.IO.IOException"/>
        public static void Main(string[] args)
        {
            string parserModel = null;
            string output      = null;
            IList <Pair <string, IFileFilter> > treebanks = Generics.NewArrayList();
            int dvKBest    = 200;
            int numThreads = 1;

            // Each branch advances argIndex by the number of elements it consumed.
            for (int argIndex = 0; argIndex < args.Length;)
            {
                string flag = args[argIndex];
                if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-dvKBest"))
                {
                    dvKBest   = System.Convert.ToInt32(NextArgValue(args, argIndex));
                    argIndex += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-parser") || Sharpen.Runtime.EqualsIgnoreCase(flag, "-model"))
                {
                    // -model is matched case-insensitively, like every other flag.
                    parserModel = NextArgValue(args, argIndex);
                    argIndex   += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-output"))
                {
                    output    = NextArgValue(args, argIndex);
                    argIndex += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-treebank"))
                {
                    Pair <string, IFileFilter> treebankDescription = ArgUtils.GetTreebankDescription(args, argIndex, "-treebank");
                    argIndex = argIndex + ArgUtils.NumSubArgs(args, argIndex) + 1;
                    treebanks.Add(treebankDescription);
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-numThreads"))
                {
                    numThreads = System.Convert.ToInt32(NextArgValue(args, argIndex));
                    argIndex  += 2;
                }
                else
                {
                    throw new ArgumentException("Unknown argument " + flag);
                }
            }

            if (parserModel == null)
            {
                throw new ArgumentException("Need to supply a parser model with -model");
            }
            if (output == null)
            {
                throw new ArgumentException("Need to supply an output filename with -output");
            }
            if (treebanks.Count == 0)
            {
                throw new ArgumentException("Need to supply a treebank with -treebank");
            }

            log.Info("Writing output to " + output);
            log.Info("Loading parser model " + parserModel);
            log.Info("Writing " + dvKBest + " hypothesis trees for each tree");

            // Format the integer with the invariant culture so the option value is locale-independent.
            string dvKBestText = dvKBest.ToString(System.Globalization.CultureInfo.InvariantCulture);
            LexicalizedParser    parser      = ((LexicalizedParser)LexicalizedParser.LoadModel(parserModel, "-dvKBest", dvKBestText));
            CacheParseHypotheses cacher      = new CacheParseHypotheses(parser);
            ITreeTransformer     transformer = DVParser.BuildTrainTransformer(parser.GetOp());
            IList <Tree>         sentences   = new List <Tree>();

            foreach (Pair <string, IFileFilter> description in treebanks)
            {
                log.Info("Reading trees from " + description.first);
                Treebank treebank = parser.GetOp().tlpParams.MemoryTreebank();
                treebank.LoadPath(description.first, description.second);
                treebank = treebank.Transform(transformer);
                Sharpen.Collections.AddAll(sentences, treebank);
            }
            log.Info("Processing " + sentences.Count + " trees");
            IList <Pair <Tree, byte[]> > cache = Generics.NewArrayList();

            // The transformer is wrapped before being handed to the multi-threaded wrapper.
            transformer = new SynchronizedTreeTransformer(transformer);
            MulticoreWrapper <Tree, Pair <Tree, byte[]> > wrapper = new MulticoreWrapper <Tree, Pair <Tree, byte[]> >(numThreads, new CacheParseHypotheses.CacheProcessor(cacher, parser, dvKBest, transformer));

            foreach (Tree tree in sentences)
            {
                wrapper.Put(tree);
                // Pick up whatever is already available after each submission.
                CollectAvailableResults(wrapper, cache);
            }
            wrapper.Join();
            // Pick up the results that became available up to and including Join().
            CollectAvailableResults(wrapper, cache);

            System.Console.Out.WriteLine("Finished processing " + cache.Count + " trees");
            IOUtils.WriteObjectToFile(cache, output);
        }

        /// <summary>
        /// Returns the element that follows the flag at <paramref name="argIndex"/>.
        /// </summary>
        /// <param name="args">The full argument array.</param>
        /// <param name="argIndex">Index of the flag whose value is wanted.</param>
        /// <returns><c>args[argIndex + 1]</c>.</returns>
        /// <exception cref="System.ArgumentException">If the flag is the last element of <paramref name="args"/>.</exception>
        private static string NextArgValue(string[] args, int argIndex)
        {
            if (argIndex + 1 >= args.Length)
            {
                throw new ArgumentException("Missing value for argument " + args[argIndex]);
            }
            return args[argIndex + 1];
        }

        /// <summary>
        /// Moves every result for which <c>wrapper.Peek()</c> reports true into
        /// <paramref name="cache"/>, printing a progress line each time the cache size reaches a
        /// multiple of ten.
        /// </summary>
        /// <param name="wrapper">The wrapper to poll results from.</param>
        /// <param name="cache">The list the polled results are appended to.</param>
        private static void CollectAvailableResults(MulticoreWrapper <Tree, Pair <Tree, byte[]> > wrapper, IList <Pair <Tree, byte[]> > cache)
        {
            while (wrapper.Peek())
            {
                cache.Add(wrapper.Poll());
                if (cache.Count % 10 == 0)
                {
                    System.Console.Out.WriteLine("Processed " + cache.Count + " trees");
                }
            }
        }