/// <summary>
/// Trains the sentiment model with AdaGrad over minibatches of training trees.
/// </summary>
/// <remarks>
/// Runs <c>model.op.trainOptions.epochs</c> passes over the training set, optionally
/// shuffling before each epoch and periodically resetting the AdaGrad history.
/// After the last batch of qualifying epochs it evaluates on <paramref name="devTrees"/>
/// (if given) and writes an intermediate model checkpoint (if <paramref name="modelPath"/>
/// is given). Training stops early once the configured max training time is exceeded.
/// </remarks>
/// <param name="model">Model being trained; its options drive batch size, epochs, etc.</param>
/// <param name="modelPath">Destination for intermediate checkpoints, or null to skip saving.</param>
/// <param name="trainingTrees">Annotated trees to train on.</param>
/// <param name="devTrees">Dev-set trees for periodic evaluation, or null to skip evaluation.</param>
public static void Train(SentimentModel model, string modelPath, IList<Tree> trainingTrees, IList<Tree> devTrees)
{
    Timing timing = new Timing();
    long maxTrainTimeMillis = model.op.trainOptions.maxTrainTimeSeconds * 1000;
    int debugCycle = 0;
    // train using AdaGrad (seemed to work best during the dvparser project)
    double[] sumGradSquare = new double[model.TotalParamSize()];
    Arrays.Fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
    // Integer division plus one: the final batch holds any leftover trees.
    int numBatches = trainingTrees.Count / model.op.trainOptions.batchSize + 1;
    log.Info("Training on " + trainingTrees.Count + " trees in " + numBatches + " batches");
    log.Info("Times through each training batch: " + model.op.trainOptions.epochs);
    for (int epoch = 0; epoch < model.op.trainOptions.epochs; ++epoch)
    {
        log.Info("======================================");
        log.Info("Starting epoch " + epoch);
        // Periodically reset the accumulated squared gradients so AdaGrad's
        // effective learning rate does not decay forever.
        if (epoch > 0 && model.op.trainOptions.adagradResetFrequency > 0 && (epoch % model.op.trainOptions.adagradResetFrequency == 0))
        {
            log.Info("Resetting adagrad weights to " + model.op.trainOptions.initialAdagradWeight);
            Arrays.Fill(sumGradSquare, model.op.trainOptions.initialAdagradWeight);
        }
        // Copy so shuffling never disturbs the caller's list.
        IList<Tree> shuffledSentences = Generics.NewArrayList(trainingTrees);
        if (model.op.trainOptions.shuffleMatrices)
        {
            Java.Util.Collections.Shuffle(shuffledSentences, model.rand);
        }
        for (int batch = 0; batch < numBatches; ++batch)
        {
            log.Info("======================================");
            log.Info("Epoch " + epoch + " batch " + batch);
            // Each batch will be of the specified batch size, except the
            // last batch will include any leftover trees at the end of
            // the list
            int startTree = batch * model.op.trainOptions.batchSize;
            int endTree = (batch + 1) * model.op.trainOptions.batchSize;
            if (endTree > shuffledSentences.Count)
            {
                endTree = shuffledSentences.Count;
            }
            ExecuteOneTrainingBatch(model, shuffledSentences.SubList(startTree, endTree), sumGradSquare);
            long totalElapsed = timing.Report();
            log.Info("Finished epoch " + epoch + " batch " + batch + "; total training time " + totalElapsed + " ms");
            if (maxTrainTimeMillis > 0 && totalElapsed > maxTrainTimeMillis)
            {
                // no need to debug output, we're done now
                break;
            }
            // Debug output (dev eval + checkpoint) only after the final batch of
            // every debugOutputEpochs-th epoch.
            if (batch == (numBatches - 1) && model.op.trainOptions.debugOutputEpochs > 0 && (epoch + 1) % model.op.trainOptions.debugOutputEpochs == 0)
            {
                double score = 0.0;
                if (devTrees != null)
                {
                    Evaluate eval = new Evaluate(model);
                    eval.Eval(devTrees);
                    eval.PrintSummary();
                    score = eval.ExactNodeAccuracy() * 100.0;
                }
                // output an intermediate model
                if (modelPath != null)
                {
                    // Splice "-<cycle>-<score>" in front of a recognized extension.
                    // Ordinal comparison: file extensions are not culture-sensitive.
                    string tempPath;
                    if (modelPath.EndsWith(".ser.gz", StringComparison.Ordinal))
                    {
                        tempPath = Sharpen.Runtime.Substring(modelPath, 0, modelPath.Length - 7) + "-" + Filename.Format(debugCycle) + "-" + Nf.Format(score) + ".ser.gz";
                    }
                    else if (modelPath.EndsWith(".gz", StringComparison.Ordinal))
                    {
                        tempPath = Sharpen.Runtime.Substring(modelPath, 0, modelPath.Length - 3) + "-" + Filename.Format(debugCycle) + "-" + Nf.Format(score) + ".gz";
                    }
                    else
                    {
                        // Bug fix: the previous code stripped the last 3 characters from a
                        // path that has no recognized extension; simply append the suffix.
                        tempPath = modelPath + "-" + Filename.Format(debugCycle) + "-" + Nf.Format(score);
                    }
                    model.SaveSerialized(tempPath);
                }
                ++debugCycle;
            }
        }
        long epochElapsed = timing.Report();
        if (maxTrainTimeMillis > 0 && epochElapsed > maxTrainTimeMillis)
        {
            log.Info("Max training time exceeded, exiting");
            break;
        }
    }
}
/// <summary>
/// Dimension of the optimization domain: the model's total parameter count.
/// </summary>
/// <returns>The number of parameters in the model being optimized.</returns>
public override int DomainDimension()
{
    // TODO: cache this for speed?
    int dimension = model.TotalParamSize();
    return dimension;
}
/// <summary>
/// Runs a numerical gradient check of the sentiment cost function on the given trees.
/// </summary>
/// <param name="model">Model whose current parameters are checked.</param>
/// <param name="trees">Trees used to compute the cost and gradient.</param>
/// <returns>True if the analytic gradient matches the numerical estimate.</returns>
public static bool RunGradientCheck(SentimentModel model, IList<Tree> trees)
{
    SentimentCostAndGradient costFunction = new SentimentCostAndGradient(model, trees);
    double[] currentParams = model.ParamsToVector();
    return costFunction.GradientCheck(model.TotalParamSize(), 50, currentParams);
}