コード例 #1
0
 /// <summary>
 /// Creates a perceptron model with a fresh feature-weight map and instantiates the feature
 /// factory (or factories) named by <c>op.featureFactoryClass</c>.
 /// </summary>
 /// <remarks>
 /// <c>featureFactoryClass</c> is a <c>;</c>-separated list of class names. An entry written as
 /// <c>ClassName(arg)</c> is loaded with <c>arg</c> passed as an extra argument to
 /// <c>ReflectionLoading.LoadByReflection</c>; this applies whether there is one entry or several.
 /// Empty entries (for example from a trailing <c>;</c>) are ignored. A single entry becomes the
 /// feature factory directly; several entries are wrapped in a <c>CombinationFeatureFactory</c>.
 /// </remarks>
 /// <param name="op">Parser options; <c>featureFactoryClass</c> is read from them. Also forwarded to the base constructor.</param>
 /// <param name="transitionIndex">Forwarded to the base constructor.</param>
 /// <param name="knownStates">Forwarded to the base constructor.</param>
 /// <param name="rootStates">Forwarded to the base constructor.</param>
 /// <param name="rootOnlyStates">Forwarded to the base constructor.</param>
 /// <exception cref="System.ArgumentException"><c>featureFactoryClass</c> contains no class name.</exception>
 public PerceptronModel(ShiftReduceOptions op, IIndex <ITransition> transitionIndex, ICollection <string> knownStates, ICollection <string> rootStates, ICollection <string> rootOnlyStates)
     : base(op, transitionIndex, knownStates, rootStates, rootOnlyStates)
 {
     this.featureWeights = Generics.NewHashMap();

     // Java's String.split discards trailing empty strings but .NET's Split does not, so drop
     // empty entries explicitly instead of trying to reflect-load a class named "".
     string[] classes = op.featureFactoryClass.Split(new[] { ';' }, System.StringSplitOptions.RemoveEmptyEntries);
     if (classes.Length == 0)
     {
         throw new System.ArgumentException("No feature factory class specified in featureFactoryClass");
     }

     // Parse every entry the same way, so "ClassName(arg)" also works when it is the only entry.
     FeatureFactory[] factories = new FeatureFactory[classes.Length];
     for (int i = 0; i < classes.Length; ++i)
     {
         string spec = classes[i];
         int paren = spec.IndexOf('(');
         if (paren >= 0)
         {
             // The argument runs from just after '(' up to, but excluding, the last character,
             // which is expected to be the closing ')'.
             string arg = Sharpen.Runtime.Substring(spec, paren + 1, spec.Length - 1);
             factories[i] = ReflectionLoading.LoadByReflection(Sharpen.Runtime.Substring(spec, 0, paren), arg);
         }
         else
         {
             factories[i] = ReflectionLoading.LoadByReflection(spec);
         }
     }

     if (factories.Length == 1)
     {
         this.featureFactory = factories[0];
     }
     else
     {
         this.featureFactory = new CombinationFeatureFactory(factories);
     }
 }
コード例 #2
0
 /// <summary>
 /// Copy constructor: the new model refers to the very same options, transition index and
 /// state collections as <paramref name="other"/>. The references are shared, not cloned.
 /// </summary>
 /// <param name="other">The model whose configuration is reused.</param>
 public BaseModel(Edu.Stanford.Nlp.Parser.Shiftreduce.BaseModel other)
 {
     // Parser configuration and transition vocabulary.
     op = other.op;
     transitionIndex = other.transitionIndex;

     // State-category collections.
     rootOnlyStates = other.rootOnlyStates;
     rootStates = other.rootStates;
     knownStates = other.knownStates;
 }
コード例 #3
0
 /// <summary>
 /// Stores the configuration that every shift-reduce model carries: the options, the
 /// transition index and the known / root / root-only state collections. All arguments are
 /// kept by reference.
 /// </summary>
 /// <param name="op">Parser options.</param>
 /// <param name="transitionIndex">Index of the transitions the model can predict.</param>
 /// <param name="knownStates">Known state categories.</param>
 /// <param name="rootStates">Root state categories.</param>
 /// <param name="rootOnlyStates">State categories that occur only at the root.</param>
 // NOTE(review): the two remarks below were attached to the original field declarations, and it
 // is unclear from here which field each described ("goal categories of a reduce" most likely
 // refers to knownStates) — confirm:
 //   * This is shared with the owning ShiftReduceParser (for now, at least).
 //   * The set of goal categories of a reduce = the set of phrasal categories in a grammar.
 public BaseModel(ShiftReduceOptions op, IIndex <ITransition> transitionIndex, ICollection <string> knownStates, ICollection <string> rootStates, ICollection <string> rootOnlyStates)
 {
     this.op = op;
     this.transitionIndex = transitionIndex;

     this.rootOnlyStates = rootOnlyStates;
     this.rootStates = rootStates;
     this.knownStates = knownStates;
 }
コード例 #4
0
        /// <summary>
        /// Builds the <see cref="ShiftReduceOptions"/> used when training a new parser.
        /// </summary>
        /// <remarks>
        /// The defaults <c>-forceTags -debugOutputFrequency 1 -quietEvaluation</c> are applied first.
        /// Next, the optional TLPP class is loaded by reflection. Finally <paramref name="args"/> are
        /// applied (presumably so they can override the defaults). If the resulting random seed is 0,
        /// it is replaced by <c>Runtime.NanoTime()</c> and the chosen value is logged.
        /// </remarks>
        /// <param name="tlppClass">Name of the tree-language-pack parameters class, or <c>null</c> to keep the default.</param>
        /// <param name="args">Additional option flags.</param>
        /// <returns>The configured options.</returns>
        public static ShiftReduceOptions BuildTrainingOptions(string tlppClass, string[] args)
        {
            var options = new ShiftReduceOptions();
            options.SetOptions("-forceTags", "-debugOutputFrequency", "1", "-quietEvaluation");

            bool hasTlpp = tlppClass != null;
            if (hasTlpp)
            {
                options.tlpParams = ReflectionLoading.LoadByReflection(tlppClass);
            }

            options.SetOptions(args);

            bool seedUnset = options.trainOptions.randomSeed == 0;
            if (seedUnset)
            {
                options.trainOptions.randomSeed = Runtime.NanoTime();
                log.Info("Random seed not set by options, using " + options.trainOptions.randomSeed);
            }

            return options;
        }
コード例 #5
0
 /// <summary>
 /// Creates a reordering oracle that keeps a reference to the given shift-reduce options.
 /// </summary>
 /// <param name="op">Options stored on the oracle.</param>
 public ReorderingOracle(ShiftReduceOptions op) => this.op = op;
コード例 #6
0
 /// <summary>
 /// Creates a parser from explicit options and a model.
 /// </summary>
 /// <param name="op">Parser options.</param>
 /// <param name="model">The model to parse with; may be <c>null</c> (the single-argument constructor passes <c>null</c>).</param>
 public ShiftReduceParser(ShiftReduceOptions op, BaseModel model)
 {
     this.model = model;
     this.op = op;
 }
コード例 #7
0
 /// <summary>
 /// Creates a parser with the given options and a <c>null</c> model. In <c>Main</c>, a parser
 /// built this way is trained before it is used.
 /// </summary>
 /// <param name="op">Parser options.</param>
 public ShiftReduceParser(ShiftReduceOptions op)
     : this(op, model: null)
 {
 }
コード例 #8
0
        /// <summary>
        /// Command-line entry point. It trains a shift-reduce parser (from scratch, or continuing
        /// from a saved model), or otherwise loads a serialized parser. It can then evaluate the
        /// parser on a test treebank.
        /// </summary>
        /// <remarks>
        /// The following flags are matched case-insensitively:
        /// <c>-trainTreebank</c> (repeatable), <c>-devTreebank</c>, <c>-testTreebank</c>,
        /// <c>-serializedPath</c> / <c>-model</c>, <c>-tlpp</c> and <c>-continueTraining</c>.
        /// Every other argument is passed through to the parser options.
        /// </remarks>
        /// <param name="args">Command-line arguments.</param>
        /// <exception cref="ArgumentException">
        /// Thrown when neither <c>-trainTreebank</c> nor <c>-serializedPath</c> was given, or when a
        /// flag that takes a value is the last argument.
        /// </exception>
        public static void Main(string[] args)
        {
            IList <string> remainingArgs = Generics.NewArrayList();
            IList <Pair <string, IFileFilter> > trainTreebankPath = null;
            Pair <string, IFileFilter>          testTreebankPath  = null;
            Pair <string, IFileFilter>          devTreebankPath   = null;
            string serializedPath   = null;
            string tlppClass        = null;
            string continueTraining = null;

            int argIndex = 0;
            while (argIndex < args.Length)
            {
                string flag = args[argIndex];
                if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-trainTreebank"))
                {
                    if (trainTreebankPath == null)
                    {
                        trainTreebankPath = Generics.NewArrayList();
                    }
                    trainTreebankPath.Add(ArgUtils.GetTreebankDescription(args, argIndex, "-trainTreebank"));
                    argIndex += ArgUtils.NumSubArgs(args, argIndex) + 1;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-testTreebank"))
                {
                    testTreebankPath = ArgUtils.GetTreebankDescription(args, argIndex, "-testTreebank");
                    argIndex += ArgUtils.NumSubArgs(args, argIndex) + 1;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-devTreebank"))
                {
                    devTreebankPath = ArgUtils.GetTreebankDescription(args, argIndex, "-devTreebank");
                    argIndex += ArgUtils.NumSubArgs(args, argIndex) + 1;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-serializedPath") || Sharpen.Runtime.EqualsIgnoreCase(flag, "-model"))
                {
                    serializedPath = RequireFlagValue(args, argIndex);
                    argIndex += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-tlpp"))
                {
                    tlppClass = RequireFlagValue(args, argIndex);
                    argIndex += 2;
                }
                else if (Sharpen.Runtime.EqualsIgnoreCase(flag, "-continueTraining"))
                {
                    continueTraining = RequireFlagValue(args, argIndex);
                    argIndex += 2;
                }
                else
                {
                    // Unrecognized here: forwarded to the parser options.
                    remainingArgs.Add(flag);
                    ++argIndex;
                }
            }

            string[] newArgs = new string[remainingArgs.Count];
            remainingArgs.CopyTo(newArgs, 0);

            if (trainTreebankPath == null && serializedPath == null)
            {
                throw new ArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
            }

            ShiftReduceParser parser;
            if (trainTreebankPath != null)
            {
                log.Info("Training ShiftReduceParser");
                log.Info("Initial arguments:");
                log.Info("   " + StringUtils.Join(args));
                if (continueTraining != null)
                {
                    parser = (ShiftReduceParser)ShiftReduceParser.LoadModel(continueTraining, ArrayUtils.Concatenate(ForceTags, newArgs));
                }
                else
                {
                    ShiftReduceOptions op = BuildTrainingOptions(tlppClass, newArgs);
                    parser = new ShiftReduceParser(op);
                }
                parser.Train(trainTreebankPath, devTreebankPath, serializedPath);

                // NOTE(review): the original passed a null path straight to SaveModel. Presumably that
                // fails only after the whole training run; skip saving instead so that train+test
                // without -serializedPath still works — confirm SaveModel has no null-path meaning.
                if (serializedPath != null)
                {
                    parser.SaveModel(serializedPath);
                }
                else
                {
                    log.Info("No -serializedPath given; the trained parser is not being saved");
                }
            }
            else
            {
                // The validation above guarantees serializedPath is set when there is nothing to train.
                parser = (ShiftReduceParser)ShiftReduceParser.LoadModel(serializedPath, ArrayUtils.Concatenate(ForceTags, newArgs));
            }

            if (testTreebankPath != null)
            {
                log.Info("Loading test trees from " + testTreebankPath.First());
                Treebank testTreebank = parser.op.tlpParams.MemoryTreebank();
                testTreebank.LoadPath(testTreebankPath.First(), testTreebankPath.Second());
                log.Info("Loaded " + testTreebank.Count + " trees");
                EvaluateTreebank evaluator = new EvaluateTreebank(parser.op, null, parser);
                evaluator.TestOnTreebank(testTreebank);
            }
        }

        /// <summary>
        /// Returns the argument that follows the flag at <paramref name="flagIndex"/>.
        /// </summary>
        /// <param name="args">Command-line arguments.</param>
        /// <param name="flagIndex">Index of a flag that requires a value.</param>
        /// <returns>The flag's value, <c>args[flagIndex + 1]</c>.</returns>
        /// <exception cref="ArgumentException">The flag is the last argument, so it has no value.</exception>
        private static string RequireFlagValue(string[] args, int flagIndex)
        {
            if (flagIndex + 1 >= args.Length)
            {
                throw new ArgumentException("Flag " + args[flagIndex] + " requires a value");
            }
            return args[flagIndex + 1];
        }