Exemple #1
0
        /// <summary>
        /// Runs <paramref name="epochs"/> weight-update passes over the training data, replacing
        /// <c>hiddenLayer</c> and <c>outputLayer</c> after every pass, and samples error metrics
        /// every 10th epoch (starting at epoch 0).
        /// </summary>
        /// <param name="trainingData">Training inputs passed to <c>CalcNewWeights</c> and <c>CalcSquaredError</c>.</param>
        /// <param name="trainingClasses">Targets matching <paramref name="trainingData"/>.</param>
        /// <param name="validation">Optional validation inputs; ignored unless <paramref name="validationClasses"/> is also non-null.</param>
        /// <param name="validationClasses">Optional targets matching <paramref name="validation"/>.</param>
        /// <param name="epochs">Number of passes; zero or negative performs no training.</param>
        /// <returns>The sampled training (and, if available, validation) squared errors and error rates.</returns>
        public TrainingResults Train(Matrix<double> trainingData, Matrix<double> trainingClasses, Matrix<double> validation, Matrix<double> validationClasses, int epochs = 500)
        {
            const int SampleInterval = 10;
            const int ReportInterval = 100;

            var results = new TrainingResults();
            // Parameters are never reassigned, so the presence of a validation set is loop-invariant.
            bool hasValidation = validation != null && validationClasses != null;

            for (int current = 0; current < epochs; current++)
            {
                var weights = CalcNewWeights(trainingData, trainingClasses);
                hiddenLayer = weights.Item1;
                outputLayer = weights.Item2;

                if (current % SampleInterval != 0)
                {
                    continue;
                }

                double squared, fraction;

                CalcSquaredError(trainingData, trainingClasses, out squared, out fraction);
                results.TrainingSquaredError.Add(squared);
                results.TrainingError.Add(fraction);

                if (hasValidation)
                {
                    CalcSquaredError(validation, validationClasses, out squared, out fraction);
                    results.ValidationSquaredError.Add(squared);
                    results.ValidationError.Add(fraction);
                }

                // NOTE(review): when a validation set is present, the values printed here are the
                // validation metrics (they overwrite the training ones above) — confirm this is intended.
                if (current % ReportInterval == 0)
                {
                    Console.WriteLine("Epoch {0}, Error {1}, {2}", current, squared, fraction * 100);
                }
            }

            return results;
        }
        /// <summary>
        /// Executing the view model's <c>Perform</c> command publishes the results returned by the
        /// (mocked) training manager.
        /// </summary>
        public void Train()
        {
            // The block is deliberately synchronous. An async lambda handed to With(...) yields a Task
            // that nobody awaits: With restores the real schedulers as soon as the lambda hits its
            // first await, and the assertion after the await may never run or its failure is lost.
            new TestScheduler().With(
                scheduler =>
                {
                    viewModel      = new TrainingViewModel(dataHandler.Object, dataViewModel.Object, trainingManager.Object, manager);
                    var collection = new ObservableCollection<TrainedTreeData>();
                    var tree       = TrainedTreeData.ConstructFromDocuments(new DocumentSet {
                        Document = new[] { new DocumentDefinition {
                                               Labels = new[] { "test" }
                                           } }
                    });
                    collection.Add(tree);
                    dataViewModel.Setup(item => item.Result).Returns(documentSet);
                    dataViewModel.Setup(item => item.SelectedItems).Returns(collection);

                    trainingManager.Setup(item => item.Train(It.IsAny<DocumentSet>(), It.IsAny<TrainingHeader>(), It.IsAny<CancellationToken>()))
                                   .Returns(Task.FromResult(training));
                    TrainingResults result = null;
                    viewModel.Perform.Subscribe(
                        results =>
                        {
                            result = results;
                        });

                    scheduler.AdvanceByMs(500);

                    // Execute() is cold: subscribing starts the command. Advancing the virtual clock
                    // afterwards delivers its output while the test scheduler is still installed.
                    viewModel.Perform.Execute().Subscribe();
                    scheduler.AdvanceByMs(500);

                    Assert.AreEqual(training, result);
                });
        }
Exemple #3
0
        /// <summary>
        /// Reports the outcome of a training run to the user in a message box titled "Training result".
        /// </summary>
        /// <param name="result">Outcome returned by the training core.</param>
        /// <remarks>
        /// The outcomes that reached convergence (NOT_VALIDATED, TRAINING_SUCCESS, VALIDATION_FAIL) also
        /// show the epoch count from <c>core.TrainingEpoch()</c> and the last MSE from <c>core.LastMse()</c>.
        /// Values not listed in the switch produce no message.
        /// </remarks>
        protected void ProcessingTrainingResult(TrainingResults result)
        {
            const string Caption = "Training result";

            switch (result)
            {
            case TrainingResults.CONVERGENCE_NOT_REACHED:
                MessageBox.Show(this, "Convergence to precision target not reached.", Caption,
                                MessageBoxButtons.OK, MessageBoxIcon.Warning);
                break;

            case TrainingResults.INTERNAL_ERROR:
                MessageBox.Show(this, "Internal error occurred.", Caption, MessageBoxButtons.OK,
                                MessageBoxIcon.Error);
                break;

            case TrainingResults.MODEL_INVALID:
                MessageBox.Show(this, "The data set is not valid.", Caption, MessageBoxButtons.OK,
                                MessageBoxIcon.Error);
                break;

            case TrainingResults.NOT_VALIDATED:
                MessageBox.Show(this,
                                FormatTrainingSummary("Convergence has been reached. Network trained! But not validated, no validation data set present in the model."),
                                Caption, MessageBoxButtons.OK, MessageBoxIcon.Warning);
                break;

            case TrainingResults.TRAINING_SUCCESS:
                // Caption, buttons and icon are MessageBox.Show arguments. They were previously passed
                // to string.Format by mistake, which left this dialog with no caption and no icon.
                MessageBox.Show(this,
                                FormatTrainingSummary("Convergence has been reached. Network trained and validated!"),
                                Caption, MessageBoxButtons.OK, MessageBoxIcon.Exclamation);
                break;

            case TrainingResults.VALIDATION_FAIL:
                MessageBox.Show(this,
                                FormatTrainingSummary("Convergence has been reached but the network hasn't passed the validation test."),
                                Caption, MessageBoxButtons.OK, MessageBoxIcon.Warning);
                break;
            }
        }

        /// <summary>
        /// Builds "<paramref name="headline"/>", a newline, then "Trained in {epochs} - Mse: {mse}",
        /// where the epoch count and MSE (5 decimals) are read from <c>core</c>.
        /// </summary>
        private string FormatTrainingSummary(string headline)
        {
            return string.Format("{0}{1}Trained in {2} - Mse: {3}",
                                 headline, Environment.NewLine, core.TrainingEpoch(), core.LastMse().ToString("F5"));
        }
 /// <summary>
 /// Creates fresh mocks for every test and wires the data-selection mock to hand out the
 /// selectable view-model mock; also loads the shared training results fixture.
 /// </summary>
 public void Setup()
 {
     // View-model mocks.
     dataViewModel      = new Mock<IDataSelectViewModel>();
     trainingViewModel  = new Mock<ITrainingViewModel>();
     seletableViewModel = new Mock<ISelectableViewModel>();
     dataViewModel.Setup(item => item.Select).Returns(seletableViewModel.Object);

     // Infrastructure mocks.
     fileMonitorFactory = new Mock<IFileMonitorFactory>();
     fileFolder         = new Mock<IFolderBrowserDialogService>();

     // Shared fixture data.
     training = TestConstants.GetTrainingResults();
 }
Exemple #5
0
 /// <summary>
 /// Builds a file monitor for <paramref name="path"/> that classifies incoming documents with the
 /// SVM described by <paramref name="training"/> and creates PDF previews.
 /// </summary>
 /// <param name="path">Location to watch; must be non-null and non-empty.</param>
 /// <param name="training">Trained model and data set; must be non-null.</param>
 /// <returns>A new <see cref="IFileMonitor"/>.</returns>
 public IFileMonitor Create(string path, TrainingResults training)
 {
     Guard.NotNullOrEmpty(() => path, path);
     Guard.NotNull(() => training, training);
     logger.Debug("Create <{0}>", path);

     // Components are created in the same order as before: watcher, classifier, preview.
     var watcher    = new FileWatcher(path, documentParser.Supported);
     var svmClient  = new SvmTestClient(training.DataSet, training.Model);
     var classifier = new LearnedClassifier(documentParser, svmClient);
     var preview    = new PdfPreviewCreator();

     return new FileMonitor(watcher, classifier, preview);
 }
Exemple #6
0
        /// <summary>
        /// Runs one epoch: signals its start, then trains the network on each item in order,
        /// recording the per-item result and refreshing the snapshot after every item.
        /// </summary>
        /// <param name="trainItems">Items to train on; enumerated exactly once.</param>
        protected void DoEpoch(IEnumerable<DataSetItem> trainItems)
        {
            StartEpoch();

            foreach (var sample in trainItems)
            {
                ItemTrained(Network.Train(sample));
                UpdateSnapshot();
            }
        }
 /// <summary>
 /// Resets shared state before each test: a fresh state manager, the training results fixture,
 /// new mocks, and a document set holding a single document labelled "Test".
 /// </summary>
 public void Setup()
 {
     manager  = new StateManager();
     training = TestConstants.GetTrainingResults();

     dataViewModel   = new Mock<IDataSelectViewModel>();
     trainingManager = new Mock<ITrainingManager>();
     dataHandler     = new Mock<IDataHandler<TrainingResults>>();

     documentSet = new DocumentSet
     {
         Document = new[]
         {
             new DocumentDefinition { Labels = new[] { "Test" } },
         },
     };
 }
Exemple #8
0
        /// <summary>
        /// Records one trained item in the epoch counters and refreshes the status display after
        /// every 10th item.
        /// </summary>
        /// <param name="result">Outcome of training a single item.</param>
        private void ItemTrained(TrainingResults result)
        {
            bool correct = result.Correct;
            if (correct)
            {
                epochCorrect += 1;
            }

            epochTotal += 1;

            // Throttle status updates to one per 10 items.
            if (epochTotal % 10 != 0)
            {
                return;
            }

            UpdateStatus();
        }
Exemple #9
0
        /// <summary>
        /// The constructor rejects each null argument and, given valid arguments, exposes a
        /// non-null header and model.
        /// </summary>
        public void Construct()
        {
            var dataSet       = ArffDataSet.CreateSimple("Test");
            var defaultHeader = TrainingHeader.CreateDefault();
            var emptyModel    = new Model();

            // Each argument is mandatory.
            Assert.Throws<ArgumentNullException>(() => new TrainingResults(null, defaultHeader, dataSet));
            Assert.Throws<ArgumentNullException>(() => new TrainingResults(emptyModel, null, dataSet));
            Assert.Throws<ArgumentNullException>(() => new TrainingResults(emptyModel, defaultHeader, null));

            var created = new TrainingResults(emptyModel, defaultHeader, dataSet);

            Assert.IsNotNull(created.Header);
            Assert.IsNotNull(created.Model);
        }
Exemple #10
0
 /// <summary>
 /// Unity start-up hook: publishes this component as <c>mInstance</c>, keeps it alive across
 /// scene loads and caches references to the data components on the same GameObject.
 /// </summary>
 void Awake()
 {
     mInstance = this;
     DontDestroyOnLoad(this);

     // Account / authentication components.
     userData            = GetComponent<UserData>();
     usersData           = GetComponent<UsersData>();
     firebaseAuthManager = GetComponent<FirebaseAuthManager>();

     // Training and content components.
     trainingData    = GetComponent<TrainingData>();
     trainingResults = GetComponent<TrainingResults>();
     capitulosData   = GetComponent<CapitulosData>();
     resultsData     = GetComponent<ResultsData>();
     dateData        = GetComponent<DateData>();
 }
        /// <summary>
        /// Builds the instance under test: a two-class model with empty optional tables, a single
        /// zero rho value and default parameters, plus a default header and a simple "Test" data set.
        /// </summary>
        public void Setup()
        {
            var dataSet = ArffDataSet.CreateSimple("Test");
            var model   = new Model
            {
                NumberOfClasses           = 2,
                ClassLabels               = null,
                NumberOfSVPerClass        = null,
                PairwiseProbabilityA      = null,
                PairwiseProbabilityB      = null,
                SupportVectorCoefficients = new double[1][],
                Rho                       = new double[] { 0 },
                Parameter                 = new Parameter(),
            };

            instance = new TrainingResults(model, TrainingHeader.CreateDefault(), dataSet);
        }
Exemple #12
0
        /// <summary>
        /// Dispatches main-menu clicks to the matching action, selected by the clicked item's
        /// <see cref="ToolStripMenuItem.Name"/>. Exceptions raised by an action are logged and shown.
        /// </summary>
        /// <param name="sender">The clicked menu item; clicks from anything else are ignored.</param>
        /// <param name="e">Unused event data.</param>
        protected void toolStripMenuItemClickHandler(object sender, EventArgs e)
        {
            var menuItem = sender as ToolStripMenuItem;
            if (menuItem == null)
            {
                // Only menu items carry a Name to dispatch on.
                return;
            }

            try
            {
                switch (menuItem.Name)
                {
                case "generateExampleLayoutToolStripMenuItem":
                    HandleGenerateLayoutMenu();
                    break;

                case "cSVToolStripMenuItem":
                    HandleLoadCsvMenu();
                    break;

                case "setToolStripMenuItem":
                    HandleSetEditorMenu();
                    break;

                case "configurationToolStripMenuItem":
                    HandleConfigurationMenu();
                    break;

                case "startWorkBenchToolStripMenuItem":
                    HandleStartTrainingMenu();
                    break;

                case "exitToolStripMenuItem":
                    // Before exit release the application resources
                    this.Dispose();
                    Application.Exit();
                    break;

                case "exportToolStripMenuItem":
                    HandleExportMenu();
                    break;
                }
            }
            catch (Exception exception)
            {
                ExceptionManager.LogAndShowException(exception, "Error", logger);
            }
        }

        /// <summary>Asks for a target file and writes the example set layout to it as CSV.</summary>
        private void HandleGenerateLayoutMenu()
        {
            using (var dialog = new SaveFileDialog())
            {
                dialog.Filter = "csv|*.csv";
                dialog.Title  = "Generate a set layout";

                if (dialog.ShowDialog(this) != DialogResult.OK || string.IsNullOrEmpty(dialog.FileName))
                {
                    return;
                }

                // Build the text before opening the file so a failure does not truncate it.
                var text = Export.GenerateLayout();

                using (var w = new StreamWriter(dialog.FileName))
                {
                    w.Write(text);
                }
            }
        }

        /// <summary>Loads the data set from a user-chosen CSV file, keeping the GUI busy meanwhile.</summary>
        private void HandleLoadCsvMenu()
        {
            string path;
            using (var dialog = new OpenFileDialog())
            {
                dialog.Filter = "csv|*.csv";
                dialog.Title  = "Set file";

                if (dialog.ShowDialog(this) != DialogResult.OK)
                {
                    return;
                }

                path = dialog.FileName;
            }

            if (string.IsNullOrWhiteSpace(path))
            {
                return;
            }

            GuiBehavior(AppStatus.BUSY);
            int linesCount;
            try
            {
                linesCount = SetLoader.LoadFromCsv(path);
            }
            finally
            {
                // Restore the GUI even when loading fails.
                GuiBehavior(AppStatus.READY);
            }

            logger.InfoFormat("Set loaded: {0}, read {1} lines", path, linesCount);
        }

        /// <summary>Opens the set window to view/edit the patterns.</summary>
        private void HandleSetEditorMenu()
        {
            // Forms shown with ShowDialog are hidden, not disposed, when closed.
            using (var setUi = new SetUi())
            {
                setUi.ShowDialog(this);
            }
        }

        /// <summary>Opens the model options window, then rebuilds the network core.</summary>
        private void HandleConfigurationMenu()
        {
            using (var configGui = new ModelOptionsUi())
            {
                configGui.ShowDialog(this);
            }

            // Rebuild network
            AppCoreSetup();
        }

        /// <summary>
        /// Trains the network on a worker thread while keeping the UI responsive, then reports the
        /// outcome. The GUI is returned to READY whether or not training succeeds.
        /// </summary>
        private void HandleStartTrainingMenu()
        {
            const int PollIntervalMs = 50;

            GuiBehavior(AppStatus.BUSY);
            try
            {
                TrainingResults result = TrainingResults.INTERNAL_ERROR;
                Exception trainingError = null;

                ThreadStart worker = delegate
                {
                    try
                    {
                        // Run network training
                        result = core.Train(NinjectBinding.GetKernel.Get<SetModel>()).Result;
                    }
                    catch (Exception ex)
                    {
                        // An exception escaping a worker thread terminates the process; capture it
                        // so it can be reported on the UI thread instead.
                        trainingError = ex;
                    }
                };

                var trainingThread = new Thread(worker);
                trainingThread.Start();

                // Pump UI messages while waiting, without spinning a core at 100%.
                while (!trainingThread.Join(PollIntervalMs))
                {
                    Application.DoEvents();
                }

                if (trainingError != null)
                {
                    ExceptionManager.LogAndShowException(trainingError, "Error", logger);
                    return;
                }

                ProcessingTrainingResult(result);
            }
            finally
            {
                GuiBehavior(AppStatus.READY);
            }
        }

        /// <summary>
        /// Asks for a target file and exports the series shown in the selected tab
        /// (0: MSE, 1: weights, 2: target outputs).
        /// </summary>
        private void HandleExportMenu()
        {
            using (var dialog = new SaveFileDialog())
            {
                dialog.Filter = "csv|*.csv";
                dialog.Title  = "Export serie to file";

                if (dialog.ShowDialog(this) != DialogResult.OK || string.IsNullOrEmpty(dialog.FileName))
                {
                    return;
                }

                using (var w = new StreamWriter(dialog.FileName))
                {
                    switch (this.tabCtrl.SelectedIndex)
                    {
                    case 0:
                        Export.SeriesToFile(w, core.CurrentMse.ToArray());
                        break;

                    case 1:
                        Export.SeriesToFile(w, core.CurrentWeights.ToArray());
                        break;

                    case 2:
                        Export.TargetOutputsToFile(w, core.CurrentOutputs);
                        break;
                    }
                }
            }
        }
        /// <summary>
        /// Counts one trained item (and whether it was correct) for the current epoch; every 10th
        /// item triggers a status refresh.
        /// </summary>
        /// <param name="result">Outcome of training a single item.</param>
        private void ItemTrained(TrainingResults result)
        {
            if (result.Correct)
            {
                ++epochCorrect;
            }

            if (++epochTotal % 10 == 0)
            {
                UpdateStatus();
            }
        }