public virtual bool RunGradientCheck(IList <Tree> sentences, IdentityHashMap <Tree, byte[]> compressedParses)
        {
            // Decompress the cached parse hypotheses and run a numeric-vs-analytic
            // gradient check on the DV cost function at the current parameters.
            log.Info("Gradient check: converting " + sentences.Count + " compressed trees");
            IdentityHashMap <Tree, IList <Tree> > decompressed = CacheParseHypotheses.ConvertToTrees(sentences, compressedParses, op.trainOptions.trainingThreads);
            log.Info("Done converting trees");

            DVParserCostAndGradient costFunction = new DVParserCostAndGradient(sentences, decompressed, dvModel, op);
            return costFunction.GradientCheck(1000, 50, dvModel.ParamsToVector());
        }
        public virtual void FilterRulesForBatch(IDictionary <Tree, byte[]> compressedTrees)
        {
            TwoDimensionalSet <string, string> binaryRules = TwoDimensionalSet.TreeSet();
            ICollection <string> unaryRules = new HashSet <string>();
            ICollection <string> words      = new HashSet <string>();

            foreach (KeyValuePair <Tree, byte[]> entry in compressedTrees)
            {
                SearchRulesForBatch(binaryRules, unaryRules, words, entry.Key);
                foreach (Tree hypothesis in CacheParseHypotheses.ConvertToTrees(entry.Value))
                {
                    SearchRulesForBatch(binaryRules, unaryRules, words, hypothesis);
                }
            }
            FilterRulesForBatch(binaryRules, unaryRules, words);
        }
// (removed: stray "Ejemplo n.º 3" / "0" snippet-site extraction artifact — not part of the code)
            public virtual Pair <Tree, byte[]> Process(Tree tree)
            {
                // Parse one tree to its k-best hypotheses and return them compressed,
                // after verifying that the byte round trip reproduces the
                // simplified/filtered form of those hypotheses exactly.
                IList <Tree> kBest = DVParser.GetTopParsesForOneTree(parser, dvKBest, tree, transformer);
                IList <Tree> roundTripped = CacheParseHypotheses.ConvertToTrees(cacher.ConvertToBytes(kBest));
                IList <Tree> expected     = CollectionUtils.TransformAsList(kBest, cacher.treeBasicCategories);

                expected = CollectionUtils.FilterAsList(expected, cacher.treeFilter);
                if (expected.Count != kBest.Count)
                {
                    log.Info("Filtered " + (kBest.Count - expected.Count) + " trees");
                    if (expected.Count == 0)
                    {
                        log.Info(" WARNING: filtered all trees for " + tree);
                    }
                }
                if (!expected.Equals(roundTripped))
                {
                    if (roundTripped.Count != expected.Count)
                    {
                        throw new AssertionError("horrible error: tree sizes not equal, " + roundTripped.Count + " vs " + expected.Count);
                    }
                    // Sizes match but content differs: dump the first mismatching pair.
                    for (int index = 0; index < roundTripped.Count; ++index)
                    {
                        if (expected[index].Equals(roundTripped[index]))
                        {
                            continue;
                        }
                        System.Console.Out.WriteLine("=============================");
                        System.Console.Out.WriteLine(expected[index]);
                        System.Console.Out.WriteLine("=============================");
                        System.Console.Out.WriteLine(roundTripped[index]);
                        System.Console.Out.WriteLine("=============================");
                        throw new AssertionError("horrible error: tree " + index + " not equal for base tree " + tree);
                    }
                }
                return Pair.MakePair(tree, cacher.ConvertToBytes(kBest));
            }
        public virtual void SetRulesForTrainingSet(IList <Tree> sentences, IDictionary <Tree, byte[]> compressedTrees)
        {
            TwoDimensionalSet <string, string> binaryRules = TwoDimensionalSet.TreeSet();
            ICollection <string> unaryRules = new HashSet <string>();
            ICollection <string> words      = new HashSet <string>();

            foreach (Tree sentence in sentences)
            {
                SearchRulesForBatch(binaryRules, unaryRules, words, sentence);
                foreach (Tree hypothesis in CacheParseHypotheses.ConvertToTrees(compressedTrees[sentence]))
                {
                    SearchRulesForBatch(binaryRules, unaryRules, words, hypothesis);
                }
            }
            foreach (Pair <string, string> binary in binaryRules)
            {
                AddRandomBinaryMatrix(binary.first, binary.second);
            }
            foreach (string unary in unaryRules)
            {
                AddRandomUnaryMatrix(unary);
            }
            FilterRulesForBatch(binaryRules, unaryRules, words);
        }
        public virtual void ExecuteOneTrainingBatch(IList <Tree> trainingBatch, IdentityHashMap <Tree, byte[]> compressedParses, double[] sumGradSquare)
        {
            Timing convertTiming = new Timing();

            convertTiming.Doing("Converting trees");
            IdentityHashMap <Tree, IList <Tree> > topParses = CacheParseHypotheses.ConvertToTrees(trainingBatch, compressedParses, op.trainOptions.trainingThreads);

            convertTiming.Done();
            DVParserCostAndGradient gcFunc = new DVParserCostAndGradient(trainingBatch, topParses, dvModel, op);

            double[] theta = dvModel.ParamsToVector();
            switch (Minimizer)
            {
            case (1):
            {
                //maxFuncIter = 10;
                // 1: QNMinimizer, 2: SGD
                QNMinimizer qn = new QNMinimizer(op.trainOptions.qnEstimates, true);
                qn.UseMinPackSearch();
                qn.UseDiagonalScaling();
                qn.TerminateOnAverageImprovement(true);
                qn.TerminateOnNumericalZero(true);
                qn.TerminateOnRelativeNorm(true);
                theta = qn.Minimize(gcFunc, op.trainOptions.qnTolerance, theta, op.trainOptions.qnIterationsPerBatch);
                break;
            }

            case 2:
            {
                //Minimizer smd = new SGDMinimizer();       double tol = 1e-4;      theta = smd.minimize(gcFunc,tol,theta,op.trainOptions.qnIterationsPerBatch);
                double lastCost  = 0;
                double currCost  = 0;
                bool   firstTime = true;
                for (int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++)
                {
                    //gcFunc.calculate(theta);
                    double[] grad = gcFunc.DerivativeAt(theta);
                    currCost = gcFunc.ValueAt(theta);
                    log.Info("batch cost: " + currCost);
                    //          if(!firstTime){
                    //              if(currCost > lastCost){
                    //                  System.out.println("HOW IS FUNCTION VALUE INCREASING????!!! ... still updating theta");
                    //              }
                    //              if(Math.abs(currCost - lastCost) < 0.0001){
                    //                  System.out.println("function value is not decreasing. stop");
                    //              }
                    //          }else{
                    //              firstTime = false;
                    //          }
                    lastCost = currCost;
                    ArrayMath.AddMultInPlace(theta, grad, -1 * op.trainOptions.learningRate);
                }
                break;
            }

            case 3:
            {
                // AdaGrad
                double eps      = 1e-3;
                double currCost = 0;
                for (int i = 0; i < op.trainOptions.qnIterationsPerBatch; i++)
                {
                    double[] gradf = gcFunc.DerivativeAt(theta);
                    currCost = gcFunc.ValueAt(theta);
                    log.Info("batch cost: " + currCost);
                    for (int feature = 0; feature < gradf.Length; feature++)
                    {
                        sumGradSquare[feature] = sumGradSquare[feature] + gradf[feature] * gradf[feature];
                        theta[feature]         = theta[feature] - (op.trainOptions.learningRate * gradf[feature] / (System.Math.Sqrt(sumGradSquare[feature]) + eps));
                    }
                }
                break;
            }

            default:
            {
                throw new ArgumentException("Unsupported minimizer " + Minimizer);
            }
            }
            dvModel.VectorToParams(theta);
        }