Exemplo n.º 1
0
        /// <summary>
        /// Trains <paramref name="net"/> on <paramref name="inputs"/> and asserts that predicting
        /// the very same inputs afterwards yields exactly <paramref name="outputIndices"/>.
        /// </summary>
        /// <param name="net">Network to train and then evaluate.</param>
        /// <param name="inputs">Training inputs; reused unchanged as the prediction inputs.</param>
        /// <param name="outputIndices">Expected output ID for each input.</param>
        /// <param name="learningRate">First argument passed to the optimizer.</param>
        /// <param name="regularizationLambda">Regularization value passed to the trainer.</param>
        /// <param name="maxIters">Last argument passed to the optimizer.</param>
        private void performTrainTest(Net net, Matrix<double> inputs, int[] outputIndices,
                                      double learningRate, double regularizationLambda, int maxIters)
        {
            // NOTE(review): 0.6 and 0.99 are optimizer tuning constants whose meaning is not
            // visible from here - confirm against SteepestDescentAdvancedOptmizer.
            var optimizer = new SteepestDescentAdvancedOptmizer(learningRate, 0.6, 0.99, maxIters);
            var trainer   = new NetTrainer(net, optimizer, regularizationLambda);

            // The flag returned by Train is not checked; the test passes or fails solely
            // on the prediction comparison below.
            trainer.Train(inputs, outputIndices, 1e-4, 1, null, CancellationToken.None);

            var predictedIds = trainer.Predict(inputs)
                                      .Select(prediction => prediction.OutputId)
                                      .ToArray();

            CollectionAssert.AreEqual(outputIndices, predictedIds);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Runs one complete training job for the network identified by <c>netId</c>: loads the
        /// train settings, builds the network and its trainer, trains in batches and, unless the
        /// job was canceled, saves the network, the training data and a results visualization.
        /// Progress is reported through <c>sendMessage</c>.
        /// </summary>
        /// <remarks>
        /// Returns early, without sending "All done.", when no train settings are found.
        /// The stopwatch is stopped on every exit path, including that early return and
        /// any exception thrown while training or saving.
        /// </remarks>
        public void TrainNetwork()
        {
            stopwatch.Reset();
            stopwatch.Start();

            try
            {
                sendMessage("Loading training settings.");
                trainSettings = storageManager.LoadTrainSettings(netId);
                if (trainSettings == null)
                {
                    sendMessage("Train settings file was not found.");
                    return;
                }

                using (var db = MausrDb.Create())
                {
                    sendMessage("Initializing learning environment.");
                    trainData = new TrainData();

                    // Query the symbol IDs once and derive the output layer size from the same
                    // array, so the layout and the output map can never disagree.
                    // NOTE(review): the query has no explicit ordering, so the map order is
                    // whatever the data source returns - confirm that is acceptable.
                    var symbolIds = db.Symbols.Select(s => s.SymbolId).ToArray();

                    // Input size is the square of the configured input image side length.
                    int inputSize  = trainSettings.InputImgSizePx * trainSettings.InputImgSizePx;
                    int outputSize = symbolIds.Length;

                    var layout = new NetLayout(inputSize, trainSettings.HiddenLayersSizes, outputSize);
                    network = new Net(layout, new SigomidActivationFunc(), new NetCostFunction());
                    network.SetOutputMap(symbolIds);
                    netEvaluator = new NetEvaluator(network);

                    var optimizer = createOptimizer(trainSettings);
                    var trainer   = new NetTrainer(network, optimizer, trainSettings.RegularizationLambda);

                    unpackedCoefs = network.Layout.AllocateCoefMatrices();

                    sendMessage("Preparing inputs and outputs.");
                    prepareInOut(db);

                    sendMessage(string.Format("Learning of {0} samples started.", trainInputs.RowCount));
                    bool converged = trainer.TrainBatch(trainInputs, trainSettings.BatchSize, trainSettings.LearnRounds,
                                                        trainOutIds, trainSettings.InitSeed, trainSettings.MinDerivCompMaxMagn,
                                                        trainIterationCallback, job.CancellationToken);

                    if (job.CancellationToken.IsCancellationRequested)
                    {
                        // The token is signalled both for a cancel and for a plain stop;
                        // job.Canceled picks the wording.
                        sendMessage("Training {0}.", job.Canceled ? "canceled" : "stopped");
                    }
                    else
                    {
                        sendMessage("Training done ({0}converged).", converged ? "" : "not ");
                    }

                    // Results are saved after a stop as well; only a cancel skips saving.
                    if (!job.Canceled)
                    {
                        sendMessage("Saving trained data.");
                        if (!storageManager.SaveNet(netId, network))
                        {
                            sendMessage("Failed to save the network.");
                        }

                        if (!storageManager.SaveTrainData(netId, trainData))
                        {
                            sendMessage("Failed to save training data.");
                        }

                        sendMessage("Saving results visualization.");
                        if (!createAndSaveResultsVis(3, 0.01, 40))
                        {
                            sendMessage("Failed visualize results.");
                        }
                    }

                    // Stop here, before the database context is disposed, so the measured
                    // time on the success path is the same as it has always been.
                    stopwatch.Stop();
                }
            }
            finally
            {
                // Covers the early return and exception paths, which previously left the
                // stopwatch running.
                // NOTE(review): assumes Stop() on an already stopped stopwatch is a no-op, as it
                // is for System.Diagnostics.Stopwatch - confirm the field's type.
                stopwatch.Stop();
            }

            sendMessage("All done.");
        }