Code example #1
        public static void Train(SentimentModel model, string modelPath, IList <Tree> trainingTrees, IList <Tree> devTrees)
        {
            Timing timing             = new Timing();
            long   maxTrainTimeMillis = model.op.trainOptions.maxTrainTimeSeconds * 1000;
            int    debugCycle         = 0;

            // double bestAccuracy = 0.0;
            // train using AdaGrad (seemed to work best during the dvparser project)
            double[] sumGradSquare = new double[model.TotalParamSize()];
            Arrays.Fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
            int numBatches = trainingTrees.Count / model.op.trainOptions.batchSize + 1;

            log.Info("Training on " + trainingTrees.Count + " trees in " + numBatches + " batches");
            log.Info("Times through each training batch: " + model.op.trainOptions.epochs);
            for (int epoch = 0; epoch < model.op.trainOptions.epochs; ++epoch)
            {
                log.Info("======================================");
                log.Info("Starting epoch " + epoch);
                if (epoch > 0 && model.op.trainOptions.adagradResetFrequency > 0 && (epoch % model.op.trainOptions.adagradResetFrequency == 0))
                {
                    log.Info("Resetting adagrad weights to " + model.op.trainOptions.initialAdagradWeight);
                    Arrays.Fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
                }
                IList <Tree> shuffledSentences = Generics.NewArrayList(trainingTrees);
                if (model.op.trainOptions.shuffleMatrices)
                {
                    Java.Util.Collections.Shuffle(shuffledSentences, model.rand);
                }
                for (int batch = 0; batch < numBatches; ++batch)
                {
                    log.Info("======================================");
                    log.Info("Epoch " + epoch + " batch " + batch);
                    // Each batch will be of the specified batch size, except the
                    // last batch will include any leftover trees at the end of
                    // the list
                    int startTree = batch * model.op.trainOptions.batchSize;
                    int endTree   = (batch + 1) * model.op.trainOptions.batchSize;
                    if (endTree > shuffledSentences.Count)
                    {
                        endTree = shuffledSentences.Count;
                    }
                    ExecuteOneTrainingBatch(model, shuffledSentences.SubList(startTree, endTree), sumGradSquare);
                    long totalElapsed = timing.Report();
                    log.Info("Finished epoch " + epoch + " batch " + batch + "; total training time " + totalElapsed + " ms");
                    if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis)
                    {
                        // no need to debug output, we're done now
                        break;
                    }
                    if (batch == (numBatches - 1) && model.op.trainOptions.debugOutputEpochs > 0 && (epoch + 1) % model.op.trainOptions.debugOutputEpochs == 0)
                    {
                        double score = 0.0;
                        if (devTrees != null)
                        {
                            Evaluate eval = new Evaluate(model);
                            eval.Eval(devTrees);
                            eval.PrintSummary();
                            score = eval.ExactNodeAccuracy() * 100.0;
                        }
                        // output an intermediate model
                        if (modelPath != null)
                        {
                            string tempPath;
                            if (modelPath.EndsWith(".ser.gz"))
                            {
                                tempPath = Sharpen.Runtime.Substring(modelPath, 0, modelPath.Length - 7) + "-" + Filename.Format(debugCycle) + "-" + Nf.Format(score) + ".ser.gz";
                            }
                            else
                            {
                                if (modelPath.EndsWith(".gz"))
                                {
                                    tempPath = Sharpen.Runtime.Substring(modelPath, 0, modelPath.Length - 3) + "-" + Filename.Format(debugCycle) + "-" + Nf.Format(score) + ".gz";
                                }
                                else
                                {
                                    tempPath = Sharpen.Runtime.Substring(modelPath, 0, modelPath.Length - 3) + "-" + Filename.Format(debugCycle) + "-" + Nf.Format(score);
                                }
                            }
                            model.SaveSerialized(tempPath);
                        }
                        ++debugCycle;
                    }
                }
                long totalElapsed_1 = timing.Report();
                if (maxTrainTimeMillis > 0 && totalElapsed_1 > maxTrainTimeMillis)
                {
                    log.Info("Max training time exceeded, exiting");
                    break;
                }
            }
        }
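The loop above slices the shuffled tree list with plain integer arithmetic. Below is a minimal, self-contained sketch of that same start/end computation (the tree count and batch size are made-up numbers, not values from the example); it shows how the final batch absorbs the leftover trees, and why the "+ 1" in the numBatches formula can yield one empty trailing batch when the count divides evenly.

using System;

class BatchSlicingSketch
{
    static void Main()
    {
        int treeCount = 2503;   // made-up stand-in for trainingTrees.Count
        int batchSize = 27;     // made-up stand-in for trainOptions.batchSize

        // Same formula as in Train(): integer division plus one extra batch
        // so leftover trees at the end of the list are not dropped.
        int numBatches = treeCount / batchSize + 1;

        for (int batch = 0; batch < numBatches; ++batch)
        {
            int startTree = batch * batchSize;
            int endTree   = (batch + 1) * batchSize;
            if (endTree > treeCount)
            {
                endTree = treeCount;   // final batch holds the remainder (19 trees here)
            }
            Console.WriteLine("batch " + batch + ": trees [" + startTree + ", " + endTree + ")");
        }
    }
}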
Code example #2
 public override int DomainDimension()
 {
     // Size of the parameter space, as seen by the numeric optimizer;
     // simply delegates to the model's total parameter count.
     // TODO: cache this for speed?
     return model.TotalParamSize();
 }
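DomainDimension reports how many parameters the objective exposes to a numeric optimizer. The sketch below is a hypothetical illustration of how such a caller typically uses that value to size its parameter vector; the IToyObjective and SumOfSquares types are invented for this example and are not part of the library.

using System;

// Hypothetical stand-ins: the interface mirrors the role DomainDimension()
// plays for an optimizer, which must know the parameter vector's length
// before it can allocate or iterate over it.
interface IToyObjective
{
    int DomainDimension();
    double ValueAt(double[] x);
}

class SumOfSquares : IToyObjective
{
    public int DomainDimension() { return 3; }

    public double ValueAt(double[] x)
    {
        double sum = 0.0;
        foreach (double xi in x) { sum += xi * xi; }
        return sum;
    }
}

class DomainDimensionSketch
{
    static void Main()
    {
        IToyObjective objective = new SumOfSquares();

        // The caller sizes its working array from DomainDimension()
        // before ever evaluating the objective.
        double[] x = new double[objective.DomainDimension()];
        Console.WriteLine("Value at the origin: " + objective.ValueAt(x));
    }
}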
Code example #3
        public static bool RunGradientCheck(SentimentModel model, IList <Tree> trees)
        {
            // Wrap the model and trees in the cost function, then compare its
            // analytic gradient against a numeric estimate over the current
            // parameter vector.
            SentimentCostAndGradient gcFunc = new SentimentCostAndGradient(model, trees);

            return gcFunc.GradientCheck(model.TotalParamSize(), 50, model.ParamsToVector());
        }
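RunGradientCheck asks the cost function to compare its analytic gradient with a numeric estimate. What follows is a minimal, self-contained illustration of that general technique on a toy quadratic function, using central finite differences; it is a sketch of the idea, not the library's implementation.

using System;

class GradientCheckSketch
{
    // Toy objective: f(x) = sum of x_i squared, with analytic gradient 2 * x_i.
    static double Value(double[] x)
    {
        double sum = 0.0;
        foreach (double xi in x) { sum += xi * xi; }
        return sum;
    }

    static double[] Gradient(double[] x)
    {
        double[] grad = new double[x.Length];
        for (int i = 0; i < x.Length; ++i) { grad[i] = 2.0 * x[i]; }
        return grad;
    }

    static void Main()
    {
        double[] x = { 0.3, -1.2, 0.7 };
        double[] analytic = Gradient(x);
        double eps = 1e-5;

        // Central finite difference per coordinate: (f(x + h) - f(x - h)) / (2h).
        for (int i = 0; i < x.Length; ++i)
        {
            double saved = x[i];
            x[i] = saved + eps;
            double plus = Value(x);
            x[i] = saved - eps;
            double minus = Value(x);
            x[i] = saved;
            double numeric = (plus - minus) / (2.0 * eps);
            Console.WriteLine("param " + i + ": analytic " + analytic[i] + ", numeric " + numeric);
        }
    }
}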