/// <summary>
/// Builds the gradient-based optimizer selected by <paramref name="trainSettings"/>.
/// </summary>
/// <param name="trainSettings">Settings supplying the algorithm choice and its hyper-parameters.</param>
/// <returns>
/// A steepest-descent optimizer (basic or Nesterov/Sutskever) or an Rprop variant. Any algorithm value
/// other than the three explicitly handled ones falls back to <c>ImprovedRpropMinusOptmizer</c>.
/// </returns>
private IGradientBasedOptimizer createOptimizer(TrainSettings trainSettings) {
    var algorithm = trainSettings.OptimizationAlgorithm;

    bool isSteepestDescent = algorithm == OptimizationAlgorithm.BasicGradientDescent
        || algorithm == OptimizationAlgorithm.NesterovSutskeverGradientDescent;

    if (isSteepestDescent) {
        // Momentum is configured in percent; the optimizers expect a fraction.
        double momentumStart = trainSettings.MomentumStartPerc / 100.0;
        double momentumEnd = trainSettings.MomentumEndPerc / 100.0;

        if (algorithm == OptimizationAlgorithm.BasicGradientDescent) {
            return new SteepestDescentBasicOptmizer(trainSettings.LearningRate, momentumStart,
                momentumEnd, trainSettings.MaxIteratinosPerBatch);
        }

        return new SteepestDescentAdvancedOptmizer(trainSettings.LearningRate, momentumStart,
            momentumEnd, trainSettings.MaxIteratinosPerBatch);
    }

    if (algorithm == OptimizationAlgorithm.RpropPlus) {
        return new RpropPlusOptmizer(trainSettings.RpropInitStep, trainSettings.RpropMaxStep,
            trainSettings.RpropStepUpMult, trainSettings.RpropStepDownMult, trainSettings.MaxIteratinosPerBatch);
    }

    // ImprovedRpropMinus, and any unrecognized value, use the improved Rprop- optimizer.
    return new ImprovedRpropMinusOptmizer(trainSettings.RpropInitStep, trainSettings.RpropMaxStep,
        trainSettings.RpropStepUpMult, trainSettings.RpropStepDownMult, trainSettings.MaxIteratinosPerBatch);
}
/// <summary>
/// Serializes <paramref name="settings"/> into the train-settings file inside the directory of the
/// network identified by <paramref name="netId"/>.
/// </summary>
/// <param name="netId">Identifier of the network whose directory receives the settings file.</param>
/// <param name="settings">Settings to serialize.</param>
/// <returns>
/// <c>true</c> when the file was written; <c>false</c> when the network directory could not be resolved
/// or when serialization or writing failed.
/// </returns>
public bool SaveTrainSettings(string netId, TrainSettings settings) {
    string baseDir = getSafeBaseDirPath(netId);
    if (baseDir == null) {
        return false;
    }

    try {
        // Serialize into memory first. File.Create would truncate the existing settings file
        // immediately, so a serializer exception used to leave a truncated/partial file behind.
        byte[] data;
        using (var buffer = new MemoryStream()) {
            trainSettingsSerializer.Serialize(buffer, settings);
            data = buffer.ToArray();
        }

        File.WriteAllBytes(Path.Combine(baseDir, TRAIN_SETTINGS_FILE_NAME), data);
        return true;
    }
    catch (Exception) {
        // Best-effort by contract: failure is reported to the caller only through the return value.
        return false;
    }
}
/// <summary>
/// Runs one training session for the network identified by <c>netId</c>.
/// </summary>
/// <remarks>
/// The session does the following:
/// loads the train settings;
/// builds a new network sized from those settings and the symbols in the database;
/// prepares inputs/outputs;
/// trains in batches.
/// Unless the job was canceled, it then saves the network, the training data and a results visualization.
/// Progress is reported via <c>sendMessage</c>, and the elapsed time is measured with <c>stopwatch</c>.
/// </remarks>
public void TrainNetwork() {
    stopwatch.Reset();
    stopwatch.Start();

    try {
        sendMessage("Loading training settings.");
        trainSettings = storageManager.LoadTrainSettings(netId);
        if (trainSettings == null) {
            sendMessage("Train settings file was not found.");
            return;
        }

        using (var db = MausrDb.Create()) {
            sendMessage("Initializing learning environment.");
            trainData = new TrainData();

            int inputSize = trainSettings.InputImgSizePx * trainSettings.InputImgSizePx;

            // Query the symbol ids once and derive the output size from that same result, so the
            // output layer size always matches the output map (previously two separate queries).
            var symbolIds = db.Symbols.Select(s => s.SymbolId).ToArray();
            int outputSize = symbolIds.Length;

            var layout = new NetLayout(inputSize, trainSettings.HiddenLayersSizes, outputSize);
            network = new Net(layout, new SigomidActivationFunc(), new NetCostFunction());
            network.SetOutputMap(symbolIds);
            netEvaluator = new NetEvaluator(network);

            var optimizer = createOptimizer(trainSettings);
            var trainer = new NetTrainer(network, optimizer, trainSettings.RegularizationLambda);
            unpackedCoefs = network.Layout.AllocateCoefMatrices();

            sendMessage("Preparing inputs and outputs.");
            prepareInOut(db);

            sendMessage(string.Format("Learning of {0} samples started.", trainInputs.RowCount));
            bool converged = trainer.TrainBatch(trainInputs, trainSettings.BatchSize, trainSettings.LearnRounds,
                trainOutIds, trainSettings.InitSeed, trainSettings.MinDerivCompMaxMagn,
                trainIterationCallback, job.CancellationToken);

            if (job.CancellationToken.IsCancellationRequested) {
                sendMessage("Training {0}.", job.Canceled ? "canceled" : "stopped");
            }
            else {
                sendMessage("Training done ({0}converged).", converged ? "" : "not ");
            }

            // Results are saved unless the job is Canceled; a "stopped" job still saves.
            if (!job.Canceled) {
                sendMessage("Saving trained data.");
                if (!storageManager.SaveNet(netId, network)) {
                    sendMessage("Failed to save the network.");
                }
                if (!storageManager.SaveTrainData(netId, trainData)) {
                    sendMessage("Failed to save training data.");
                }

                sendMessage("Saving results visualization.");
                if (!createAndSaveResultsVis(3, 0.01, 40)) {
                    sendMessage("Failed to visualize results.");
                }
            }
        }
    }
    finally {
        // Stop timing on every exit path (missing settings, exceptions), not only on success.
        stopwatch.Stop();
    }

    sendMessage("All done.");
}