Beispiel #1
0
        /// <summary>
        /// Makes the network stored under <paramref name="netId"/> the active one and records
        /// <paramref name="netId"/> as the default network name.
        /// </summary>
        /// <param name="netId">Identifier of the stored network to activate.</param>
        /// <returns>
        /// <c>false</c> if either the network or its training settings could not be loaded
        /// (the storage manager returned <c>null</c>); otherwise the result of
        /// <c>trainStorageManager.SaveDefaultNetName</c>.
        /// </returns>
        /// <remarks>
        /// <c>NetId</c>, <c>Evaluator</c>, <c>TrainSettings</c> and <c>trainData</c> are updated
        /// before the default name is persisted. They therefore change even when that final save
        /// returns <c>false</c>.
        /// </remarks>
        public bool SetDefaultNetwork(string netId)
        {
            var loadedNet = trainStorageManager.LoadNet(netId);
            var loadedSettings = trainStorageManager.LoadTrainSettings(netId);

            // Both pieces must be available before anything is switched.
            bool canActivate = loadedNet != null && loadedSettings != null;
            if (!canActivate)
            {
                return false;
            }

            // Switch the active network and drop any previously held training data.
            NetId = netId;
            Evaluator = new NetEvaluator(loadedNet);
            TrainSettings = loadedSettings;
            trainData = null;

            bool defaultSaved = trainStorageManager.SaveDefaultNetName(netId);
            return defaultSaved;
        }
Beispiel #2
0
 /// <summary>
 /// Persists <paramref name="trainData"/> for the network <paramref name="netId"/> under the
 /// <c>TRAIN_DATA_FILE_NAME</c> file name, using the shared <c>save</c> helper.
 /// </summary>
 /// <param name="netId">Identifier of the network the training data belongs to.</param>
 /// <param name="trainData">Training data to store.</param>
 /// <returns>The result of <c>save</c> (presumably <c>true</c> on success — confirm against the helper).</returns>
 public bool SaveTrainData(string netId, TrainData trainData)
 {
     bool saved = save(netId, trainData, TRAIN_DATA_FILE_NAME);
     return saved;
 }
Beispiel #3
0
        /// <summary>
        /// Builds a fresh network from the stored training settings of <c>netId</c> and trains it
        /// on the symbols in the database. Unless the job was canceled, it then saves the network,
        /// the training data and a results visualization. Progress is reported through
        /// <c>sendMessage</c>.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Returns early, after reporting it, when the training settings cannot be loaded.
        /// </para>
        /// <para>
        /// A cancellation request on <c>job.CancellationToken</c> is passed to the trainer. If the
        /// job was only stopped (<c>job.Canceled</c> is <c>false</c>), the partially trained
        /// network is still saved. If it was canceled, nothing is saved.
        /// </para>
        /// <para>
        /// <c>stopwatch</c> is restarted on entry and stopped on every exit path, including the
        /// early return and exceptions.
        /// </para>
        /// </remarks>
        public void TrainNetwork()
        {
            stopwatch.Reset();
            stopwatch.Start();

            try
            {
                sendMessage("Loading training settings.");
                trainSettings = storageManager.LoadTrainSettings(netId);
                if (trainSettings == null)
                {
                    sendMessage("Train settings file was not found.");
                    return;
                }

                using (var db = MausrDb.Create())
                {
                    sendMessage("Initializing learning environment.");
                    trainData = new TrainData();

                    // Read the symbol ids once so the output layer size and the output map come
                    // from the same snapshot and cannot disagree (previously two separate queries).
                    var symbolIds = db.Symbols.Select(s => s.SymbolId).ToArray();

                    // Input is a square image of InputImgSizePx x InputImgSizePx pixels.
                    int inputSize  = trainSettings.InputImgSizePx * trainSettings.InputImgSizePx;
                    int outputSize = symbolIds.Length;

                    var layout = new NetLayout(inputSize, trainSettings.HiddenLayersSizes, outputSize);
                    network = new Net(layout, new SigomidActivationFunc(), new NetCostFunction());
                    network.SetOutputMap(symbolIds);
                    netEvaluator = new NetEvaluator(network);

                    var optimizer = createOptimizer(trainSettings);
                    var trainer   = new NetTrainer(network, optimizer, trainSettings.RegularizationLambda);

                    unpackedCoefs = network.Layout.AllocateCoefMatrices();

                    sendMessage("Preparing inputs and outputs.");
                    prepareInOut(db);

                    sendMessage(string.Format("Learning of {0} samples started.", trainInputs.RowCount));
                    bool converged = trainer.TrainBatch(trainInputs, trainSettings.BatchSize, trainSettings.LearnRounds,
                                                        trainOutIds, trainSettings.InitSeed, trainSettings.MinDerivCompMaxMagn,
                                                        trainIterationCallback, job.CancellationToken);

                    if (job.CancellationToken.IsCancellationRequested)
                    {
                        sendMessage("Training {0}.", job.Canceled ? "canceled" : "stopped");
                    }
                    else
                    {
                        sendMessage("Training done ({0}converged).", converged ? "" : "not ");
                    }

                    // Only an explicit cancel discards the results; a stopped job keeps its partial training.
                    if (!job.Canceled)
                    {
                        sendMessage("Saving trained data.");
                        if (!storageManager.SaveNet(netId, network))
                        {
                            sendMessage("Failed to save the network.");
                        }

                        if (!storageManager.SaveTrainData(netId, trainData))
                        {
                            sendMessage("Failed to save training data.");
                        }

                        sendMessage("Saving results visualization.");
                        // NOTE(review): the meaning of the arguments (3, 0.01, 40) is defined by
                        // createAndSaveResultsVis, which is not visible here — confirm and name them.
                        if (!createAndSaveResultsVis(3, 0.01, 40))
                        {
                            sendMessage("Failed to visualize results.");
                        }
                    }
                }
            }
            finally
            {
                // Stop on every exit path so the elapsed time does not keep growing after an
                // early return or an exception.
                stopwatch.Stop();
            }

            sendMessage("All done.");
        }