コード例 #1
0
        /// <summary>
        /// Instantiates the optimizer selected by <c>trainSettings.OptimizationAlgorithm</c>.
        /// </summary>
        /// <param name="trainSettings">Source of the algorithm choice and of its hyper-parameters.</param>
        /// <returns>
        /// A freshly constructed optimizer.
        /// The two steepest-descent variants receive the learning rate and the momentum start/end settings,
        /// each divided by 100 (percent to fraction).
        /// The two Rprop variants receive the Rprop init/max step and the step up/down multipliers.
        /// Every variant also receives <c>MaxIteratinosPerBatch</c>.
        /// Any algorithm value not listed explicitly falls back to <c>ImprovedRpropMinusOptmizer</c>.
        /// </returns>
        private IGradientBasedOptimizer createOptimizer(TrainSettings trainSettings)
        {
            var algorithm = trainSettings.OptimizationAlgorithm;

            if (algorithm == OptimizationAlgorithm.BasicGradientDescent)
            {
                return new SteepestDescentBasicOptmizer(
                    trainSettings.LearningRate,
                    trainSettings.MomentumStartPerc / 100.0,
                    trainSettings.MomentumEndPerc / 100.0,
                    trainSettings.MaxIteratinosPerBatch);
            }

            if (algorithm == OptimizationAlgorithm.NesterovSutskeverGradientDescent)
            {
                return new SteepestDescentAdvancedOptmizer(
                    trainSettings.LearningRate,
                    trainSettings.MomentumStartPerc / 100.0,
                    trainSettings.MomentumEndPerc / 100.0,
                    trainSettings.MaxIteratinosPerBatch);
            }

            if (algorithm == OptimizationAlgorithm.RpropPlus)
            {
                return new RpropPlusOptmizer(
                    trainSettings.RpropInitStep,
                    trainSettings.RpropMaxStep,
                    trainSettings.RpropStepUpMult,
                    trainSettings.RpropStepDownMult,
                    trainSettings.MaxIteratinosPerBatch);
            }

            // OptimizationAlgorithm.ImprovedRpropMinus, and any value not matched above, ends up here.
            return new ImprovedRpropMinusOptmizer(
                trainSettings.RpropInitStep,
                trainSettings.RpropMaxStep,
                trainSettings.RpropStepUpMult,
                trainSettings.RpropStepDownMult,
                trainSettings.MaxIteratinosPerBatch);
        }
コード例 #2
0
        /// <summary>
        /// Serializes <paramref name="settings"/> into the train-settings file inside the base directory
        /// resolved for <paramref name="netId"/>.
        /// </summary>
        /// <param name="netId">Identifier of the network whose storage directory receives the file.</param>
        /// <param name="settings">Settings to persist.</param>
        /// <returns>
        /// <c>true</c> when the file was written.
        /// <c>false</c> when no safe base directory could be resolved for <paramref name="netId"/>,
        /// or when building the path, serializing or writing threw.
        /// </returns>
        public bool SaveTrainSettings(string netId, TrainSettings settings)
        {
            string baseDir = getSafeBaseDirPath(netId);

            if (baseDir == null)
            {
                return(false);
            }

            try {
                // Kept inside the try so path errors are reported as 'false', as before.
                string filePath = Path.Combine(baseDir, TRAIN_SETTINGS_FILE_NAME);

                // Serialize into memory first. File.Create truncates an existing file immediately.
                // Serializing straight into it would leave a truncated/partial settings file on disk
                // whenever the serializer throws, destroying the previously saved settings.
                byte[] payload;
                using (var buffer = new MemoryStream()) {
                    trainSettingsSerializer.Serialize(buffer, settings);
                    payload = buffer.ToArray();
                }

                File.WriteAllBytes(filePath, payload);
                return(true);
            }
            catch (Exception) {
                // Deliberate best-effort contract: failures are reported through the return value
                // instead of being propagated to the caller.
                return(false);
            }
        }
コード例 #3
0
        /// <summary>
        /// Runs one training session for the network identified by <c>netId</c>:
        /// <list type="bullet">
        /// <item>loads its train settings;</item>
        /// <item>builds a fresh network sized from those settings and from the symbol table;</item>
        /// <item>trains it;</item>
        /// <item>unless <c>job.Canceled</c> is set, saves the network, the training data and a results visualization.</item>
        /// </list>
        /// Progress is reported through <c>sendMessage</c>.
        /// </summary>
        /// <remarks>
        /// Returns early (after reporting it) when <c>LoadTrainSettings</c> yields <c>null</c>.
        /// <c>stopwatch</c> is reset and started on entry. It is stopped on every exit path,
        /// including that early return and any exception.
        /// </remarks>
        public void TrainNetwork()
        {
            stopwatch.Reset();
            stopwatch.Start();

            try {
                sendMessage("Loading training settings.");
                trainSettings = storageManager.LoadTrainSettings(netId);
                if (trainSettings == null)
                {
                    sendMessage("Train settings file was not found.");
                    return;
                }

                using (var db = MausrDb.Create()) {
                    sendMessage("Initializing learning environment.");
                    trainData = new TrainData();

                    // Query the symbol ids once. The output layer size and the output map then come from
                    // the same snapshot; two separate queries could disagree if symbols change in between.
                    // This also saves one database round trip.
                    var symbolIds = db.Symbols.Select(s => s.SymbolId).ToArray();

                    int inputSize  = trainSettings.InputImgSizePx * trainSettings.InputImgSizePx;
                    int outputSize = symbolIds.Length;

                    var layout = new NetLayout(inputSize, trainSettings.HiddenLayersSizes, outputSize);
                    network = new Net(layout, new SigomidActivationFunc(), new NetCostFunction());
                    network.SetOutputMap(symbolIds);
                    netEvaluator = new NetEvaluator(network);

                    var optimizer = createOptimizer(trainSettings);
                    var trainer   = new NetTrainer(network, optimizer, trainSettings.RegularizationLambda);

                    unpackedCoefs = network.Layout.AllocateCoefMatrices();

                    sendMessage("Preparing inputs and outputs.");
                    prepareInOut(db);

                    sendMessage(string.Format("Learning of {0} samples started.", trainInputs.RowCount));
                    bool converged = trainer.TrainBatch(trainInputs, trainSettings.BatchSize, trainSettings.LearnRounds,
                        trainOutIds, trainSettings.InitSeed, trainSettings.MinDerivCompMaxMagn, trainIterationCallback,
                        job.CancellationToken);

                    // A cancellation request is reported as "canceled" when job.Canceled is set,
                    // otherwise as "stopped". Only a canceled job skips the save step below;
                    // a stopped job still saves.
                    if (job.CancellationToken.IsCancellationRequested)
                    {
                        sendMessage("Training {0}.", job.Canceled ? "canceled" : "stopped");
                    }
                    else
                    {
                        sendMessage("Training done ({0}converged).", converged ? "" : "not ");
                    }

                    if (!job.Canceled)
                    {
                        sendMessage("Saving trained data.");
                        if (!storageManager.SaveNet(netId, network))
                        {
                            sendMessage("Failed to save the network.");
                        }

                        if (!storageManager.SaveTrainData(netId, trainData))
                        {
                            sendMessage("Failed to save training data.");
                        }

                        sendMessage("Saving results visualization.");
                        // NOTE(review): the meaning of 3, 0.01 and 40 is not visible here.
                        // Consider named constants or settings.
                        if (!createAndSaveResultsVis(3, 0.01, 40))
                        {
                            sendMessage("Failed to visualize results.");
                        }
                    }
                }
            }
            finally {
                // Previously the stopwatch kept running after the early return or an exception.
                stopwatch.Stop();
            }

            sendMessage("All done.");
        }