示例#1
0
        /// <summary>
        /// Starts a training run on <c>solver</c>. When the shared <c>trainingData</c> list is
        /// empty it is first filled with 10000 synthetic samples, each a 5-element one-hot
        /// vector used both as input and as expected output. Afterwards the progress bar and
        /// label are reset, <c>trainingBegin</c> is stamped and <c>timer1</c> is started.
        /// </summary>
        private void button4_Click(object sender, EventArgs e)
        {
            if (trainingData.Count == 0)
            {
                const int vectorLength = 5;
                var random = new Random();

                for (int sample = 0; sample < 10000; ++sample)
                {
                    // Choose the hot slot in [0, 4]; the float product can round up to exactly
                    // 5.0f, which the clamp folds back into the last slot.
                    float scaled   = (float)random.NextDouble() * vectorLength;
                    int   hotIndex = Math.Min(vectorLength - 1, (int)Math.Floor(scaled));

                    var inputVector  = new float[vectorLength];
                    var targetVector = new float[vectorLength];
                    inputVector[hotIndex]  = 1.0f;
                    targetVector[hotIndex] = 1.0f;

                    trainingData.Add(new TrainingSuite.TrainingData(inputVector, targetVector));
                }
            }

            var suite = new TrainingSuite(trainingData);
            suite.config.miniBatchSize       = 6;
            suite.config.epochs              = (int)numericUpDown6.Value;
            suite.config.shuffleTrainingData = false;
            suite.config.regularization      = TrainingSuite.TrainingConfig.Regularization.None;
            suite.config.costFunction        = new CrossEntropyErrorFunction();

            trainingPromise = solver.Train(suite, calculator);

            // Reset the progress UI and start the timer.
            progressBar1.Value = 0;
            label2.Text        = "Training...";
            trainingBegin      = DateTime.Now;
            timer1.Start();
        }
        /// <summary>
        /// Builds the dialog, remembers the given training promise and starts a
        /// <c>DispatcherTimer</c> that raises <c>timer_Tick</c> every <c>updateInterval</c>
        /// milliseconds.
        /// </summary>
        /// <param name="_promise">Training promise stored in <c>trainingPromise</c>.</param>
        public TrainingDialog(TrainingPromise _promise)
        {
            InitializeComponent();
            trainingPromise = _promise;

            // NOTE(review): this timer is only held by a local, so nothing in the constructor can
            // stop it later; confirm that timer_Tick (or the dialog's close path) stops it.
            var refreshTimer = new DispatcherTimer
            {
                Interval = TimeSpan.FromMilliseconds(updateInterval)
            };
            refreshTimer.Tick += timer_Tick;
            refreshTimer.Start();
        }
示例#3
0
        /// <summary>
        /// Tick handler for <c>trainingtimer</c>. Pushes the progress of <c>trainingPromise</c>
        /// and the elapsed time since <c>trainingStart</c> to <c>progressDialog</c>. Once the
        /// promise reports ready, that final state is sent to the dialog, both references are
        /// cleared and the timer is stopped. Does nothing while either reference is null.
        /// </summary>
        private void Trainingtimer_Tick(object sender, EventArgs e)
        {
            if (progressDialog == null || trainingPromise == null)
            {
                return;
            }

            // Sample IsReady() exactly once and before the progress values: the promise may
            // complete between two calls, and the dialog must see the same "ready" flag that
            // decides below whether polling stops (otherwise the final update is never shown).
            bool ready      = trainingPromise.IsReady();
            var  progress   = trainingPromise.GetTotalProgress();
            var  epochsDone = trainingPromise.GetEpochsDone();

            // Truncate to whole seconds but keep the day component so runs longer than 24h
            // do not wrap around in the display.
            var    elapsed = DateTime.Now - trainingStart;
            string time    = new TimeSpan(elapsed.Days, elapsed.Hours, elapsed.Minutes, elapsed.Seconds).ToString();

            progressDialog.UpdateResult(progress, ready, "Training... Epochs done: " + epochsDone, time);

            if (ready)
            {
                trainingPromise = null;
                progressDialog  = null;
                trainingtimer.Stop();
            }
        }
示例#4
0
        /// <summary>
        /// Interactive train-and-evaluate run.
        /// <list type="number">
        /// <item>Prompts for four files: training images, training labels, verification images
        /// and verification labels.</item>
        /// <item>Loads them on a worker thread while a <c>LoadingWindow</c> is shown.</item>
        /// <item>Trains <c>network</c> for <c>numEpoch</c> epochs, one epoch per
        /// <c>Train</c> call.</item>
        /// </list>
        /// The initial network and the network after every epoch are scored on the verification
        /// set. Each is written as JSON plus a visualisation into <c>D:\nntmp\</c>, and progress is
        /// logged through <c>LogDebug</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <c>trainingPromise.Await()</c> is called from this click handler, so it
        /// presumably blocks the UI thread until every epoch has finished.
        /// </remarks>
        private void Button9_Click(object sender, EventArgs e)
        {
            // Every artefact of the run (log, JSON snapshots, visualisations) is written here.
            const string outputDir = "D:\\nntmp\\";

            // File.WriteAllText throws if the folder does not exist, so create it first.
            System.IO.Directory.CreateDirectory(outputDir);
            System.IO.File.WriteAllText(outputDir + "log.txt", "BEGIN\n");

            string imgFile = PromptForInputFile("Image Training data (Image)|*.*", "Open Training images file");
            if (imgFile == null)
            {
                return;
            }

            string labelFile = PromptForInputFile("Training data (Label)|*.*", "Open Training labels file");
            if (labelFile == null)
            {
                return;
            }

            string testImgFile = PromptForInputFile("Verification Image Training data (Image)|*.*", "Open Verification images file");
            if (testImgFile == null)
            {
                return;
            }

            string testLabelFile = PromptForInputFile("Verification Training data (Label)|*.*", "Open Verification labels file");
            if (testLabelFile == null)
            {
                return;
            }

            LoadingWindow wnd = new LoadingWindow();
            wnd.Text = "Loading training data";

            List<TrainingSuite.TrainingData> trainingData = new List<TrainingSuite.TrainingData>();
            List<TrainingSuite.TrainingData> testData     = new List<TrainingSuite.TrainingData>();

            System.Threading.Thread thread = new System.Threading.Thread(() => {
                LoadTestDataFromFiles(trainingData, labelFile, imgFile, (x) => { wnd.SetProgress(x); }, true);
                LoadTestDataFromFiles(testData, testLabelFile, testImgFile, (x) => { wnd.SetProgress(x); }, false);
                wnd.Finish();
            });
            // If the user cancels the loading window the load keeps running; a background
            // thread will not keep the process alive after the application exits.
            thread.IsBackground = true;
            thread.Start();

            if (wnd.ShowDialog() != DialogResult.OK)
            {
                return;
            }

            // Score and save the untrained network as snapshot 0.
            int success = CountCorrectTestClassifications(testData);
            network.AttachDescription("Network (" + string.Join(",", network.GetLayerConfig()) + ") epoch: 0 (initial)   Test success rate: [" + success + " of " + testData.Count + "]");
            SaveNetworkSnapshot(outputDir, 0, ToSuccessPercentage(success, testData.Count));

            var trainingSuite = new TrainingSuite(trainingData);

            trainingSuite.config.miniBatchSize        = (int)numMiniBatchSize.Value;
            trainingSuite.config.learningRate         = (float)numLearningRate.Value;
            trainingSuite.config.regularizationLambda = (float)numLambda.Value;
            trainingSuite.config.shuffleTrainingData  = true;

            // Combo order: 0 = None, 1 = L1, 2 = L2. Any other index keeps the config default.
            switch (comboRegularization.SelectedIndex)
            {
                case 0:
                    trainingSuite.config.regularization = TrainingSuite.TrainingConfig.Regularization.None;
                    break;
                case 1:
                    trainingSuite.config.regularization = TrainingSuite.TrainingConfig.Regularization.L1;
                    break;
                case 2:
                    trainingSuite.config.regularization = TrainingSuite.TrainingConfig.Regularization.L2;
                    break;
            }

            switch (comboCostFunction.SelectedIndex)
            {
                case 0:
                    trainingSuite.config.costFunction = new MeanSquaredErrorFunction();
                    break;
                case 1:
                    trainingSuite.config.costFunction = new CrossEntropyErrorFunction();
                    break;
            }

            // One epoch per Train() call so the network can be scored and saved between epochs.
            trainingSuite.config.epochs = 1;

            LogDebug("Initial network saved " + DateTime.Now.ToString());

            int epochCount = (int)numEpoch.Value;
            for (int epoch = 0; epoch < epochCount; epoch++)
            {
                int epochNumber = epoch + 1;
                LogDebug("Starting epoch #" + epochNumber + "  " + DateTime.Now.ToString());

                trainingPromise = network.Train(trainingSuite, calculator);
                trainingPromise.Await();
                LogDebug("  Training finished  " + DateTime.Now.ToString());

                success = CountCorrectTestClassifications(testData);
                LogDebug("  Verification finished Success rate: [" + success + " of " + testData.Count + "]  " + DateTime.Now.ToString());

                network.AttachDescription("Network (" + string.Join(",", network.GetLayerConfig()) + ") epoch: " + epochNumber + "    Test success rate: [" + success + " of " + testData.Count + "]");
                SaveNetworkSnapshot(outputDir, epochNumber, ToSuccessPercentage(success, testData.Count));
                LogDebug("  Saving finished  " + DateTime.Now.ToString());
            }
        }

        /// <summary>
        /// Shows <c>openFileDialog1</c> with the given filter and title.
        /// </summary>
        /// <returns>The selected path, or <c>null</c> if the dialog was not confirmed with OK.</returns>
        private string PromptForInputFile(string filter, string title)
        {
            openFileDialog1.Filter = filter;
            openFileDialog1.Title  = title;
            return openFileDialog1.ShowDialog() == DialogResult.OK ? openFileDialog1.FileName : null;
        }

        /// <summary>
        /// Counts the samples for which <c>ClassifyOutput</c> of the network's output equals
        /// <c>ClassifyOutput</c> of the sample's desired output.
        /// </summary>
        private int CountCorrectTestClassifications(List<TrainingSuite.TrainingData> samples)
        {
            int correct = 0;
            foreach (var sample in samples)
            {
                var output = network.Compute(sample.input, calculator);
                if (ClassifyOutput(output) == ClassifyOutput(sample.desiredOutput))
                {
                    ++correct;
                }
            }
            return correct;
        }

        /// <summary>
        /// Converts a hit count into a percentage, returning 0 for an empty set instead of NaN.
        /// </summary>
        private static float ToSuccessPercentage(int success, int total)
        {
            return total == 0 ? 0.0f : ((float)success / total) * 100.0f;
        }

        /// <summary>
        /// Writes <c>network</c> as <c>network_NNNNNN.json</c> into <paramref name="directory"/>
        /// and passes the <c>network_NNNNNN_vis_</c> prefix to <c>VisualizeNetworkSpecific</c>.
        /// NNNNNN is <paramref name="epochNumber"/> zero-padded to six digits.
        /// </summary>
        private void SaveNetworkSnapshot(string directory, int epochNumber, float successPercent)
        {
            string prefix = directory + "network_" + epochNumber.ToString().PadLeft(6, '0');
            System.IO.File.WriteAllText(prefix + ".json", network.ExportToJSON());
            VisualizeNetworkSpecific(network, prefix + "_vis_", epochNumber, successPercent);
        }
示例#5
0
        /// <summary>
        /// Prompts for a training image file and a training label file and loads them on a worker
        /// thread while a <c>LoadingWindow</c> is shown. Then configures a <c>TrainingSuite</c>
        /// from the form controls and starts <c>network.Train</c>. Finally it starts
        /// <c>trainingtimer</c> and shows a modal <c>TrainingWindow</c> for the returned promise.
        /// </summary>
        private void button4_Click(object sender, EventArgs e)
        {
            string imgFile   = "";
            string labelFile = "";

            openFileDialog1.Filter = "Image Training data (Image)|*.*";
            openFileDialog1.Title  = "Open Training images file";
            if (openFileDialog1.ShowDialog() == DialogResult.OK)
            {
                imgFile = openFileDialog1.FileName;
            }
            else
            {
                return;
            }

            openFileDialog1.Filter = "Training data (Label)|*.*";
            openFileDialog1.Title  = "Open Training labels file";
            if (openFileDialog1.ShowDialog() == DialogResult.OK)
            {
                labelFile = openFileDialog1.FileName;
            }
            else
            {
                return;
            }

            LoadingWindow wnd = new LoadingWindow();

            wnd.Text = "Loading training data";

            List<TrainingSuite.TrainingData> trainingData = new List<TrainingSuite.TrainingData>();

            System.Threading.Thread thread = new System.Threading.Thread(() => {
                LoadTestDataFromFiles(trainingData, labelFile, imgFile, (x) => { wnd.SetProgress(x); }, true);
                wnd.Finish();
            });
            // If the user cancels the loading window the load keeps running; a background
            // thread will not keep the process alive after the application exits.
            thread.IsBackground = true;
            thread.Start();

            if (wnd.ShowDialog() != DialogResult.OK)
            {
                return;
            }

            var trainingSuite = new TrainingSuite(trainingData);

            trainingSuite.config.miniBatchSize        = (int)numMiniBatchSize.Value;
            trainingSuite.config.learningRate         = (float)numLearningRate.Value;
            trainingSuite.config.regularizationLambda = (float)numLambda.Value;
            trainingSuite.config.shuffleTrainingData  = checkShuffle.Checked;

            // Combo order: 0 = None, 1 = L1, 2 = L2. Any other index keeps the config default.
            switch (comboRegularization.SelectedIndex)
            {
                case 0:
                    trainingSuite.config.regularization = TrainingSuite.TrainingConfig.Regularization.None;
                    break;
                case 1:
                    trainingSuite.config.regularization = TrainingSuite.TrainingConfig.Regularization.L1;
                    break;
                case 2:
                    trainingSuite.config.regularization = TrainingSuite.TrainingConfig.Regularization.L2;
                    break;
            }

            switch (comboCostFunction.SelectedIndex)
            {
                case 0:
                    trainingSuite.config.costFunction = new MeanSquaredErrorFunction();
                    break;
                case 1:
                    trainingSuite.config.costFunction = new CrossEntropyErrorFunction();
                    break;
            }

            trainingSuite.config.epochs = (int)numEpoch.Value;

            // Stamp the start time and start polling before the modal dialog, because
            // ShowDialog does not return until the dialog is closed.
            trainingStart   = DateTime.Now;
            trainingPromise = network.Train(trainingSuite, calculator);
            trainingtimer.Start();

            progressDialog = new TrainingWindow(trainingPromise);
            progressDialog.ShowDialog();
        }