예제 #1
0
        /// <summary>
        /// Entry point of the demo. Creates the sigma environment and the trainer for the
        /// configured <c>DemoMode</c>, registers the metric hooks that publish the
        /// <c>shared.*</c> values, builds the WPF monitor (legends, tabs and panels) and
        /// finally prepares and runs the environment without auto-starting the operators.
        /// </summary>
        private static void Main()
        {
            // tab titles, declared once so that creation, sizing and population agree
            const string overviewTab     = "Overview";
            const string metricsTab      = "Metrics";
            const string validationTab   = "Validation";
            const string maximisationTab = "Maximisation";
            const string reproductionTab = "Reproduction";
            const string updateTab       = "Update";

            // registry identifiers that are referenced more than once
            const string weightsKey = "network.layers.*.weights";
            const string biasesKey  = "network.layers.*.biases";
            const string updatesKey = "optimiser.updates";
            const string costKey    = "optimiser.cost_total";

            SigmaEnvironment.EnableLogging();
            SigmaEnvironment sigma = SigmaEnvironment.Create("sigma_demo");

            // create the trainer for the selected demo
            string demoName = DemoMode.Name;
            ITrainer trainer = DemoMode.CreateTrainer(sigma);

            // mean (sum / length) and standard deviation of weights, biases and updates,
            // each published under its own "shared.*" identifier
            trainer.AddLocalHook(new MetricProcessorHook<INDArray>(
                weightsKey,
                (array, handler) => handler.Divide(handler.Sum(array), array.Length),
                "shared.network_weights_average"));
            trainer.AddLocalHook(new MetricProcessorHook<INDArray>(
                weightsKey,
                (array, handler) => handler.StandardDeviation(array),
                "shared.network_weights_stddev"));
            trainer.AddLocalHook(new MetricProcessorHook<INDArray>(
                biasesKey,
                (array, handler) => handler.Divide(handler.Sum(array), array.Length),
                "shared.network_biases_average"));
            trainer.AddLocalHook(new MetricProcessorHook<INDArray>(
                biasesKey,
                (array, handler) => handler.StandardDeviation(array),
                "shared.network_biases_stddev"));
            trainer.AddLocalHook(new MetricProcessorHook<INDArray>(
                updatesKey,
                (array, handler) => handler.Divide(handler.Sum(array), array.Length),
                "shared.optimiser_updates_average"));
            trainer.AddLocalHook(new MetricProcessorHook<INDArray>(
                updatesKey,
                (array, handler) => handler.StandardDeviation(array),
                "shared.optimiser_updates_stddev"));
            trainer.AddLocalHook(new MetricProcessorHook<INDArray>(
                "network.layers.*<external_output>._outputs.default.activations",
                (array, handler) => handler.Divide(handler.Sum(array), array.Length),
                "shared.network_activations_mean"));

            // create the UI monitor and attach it to the environment
            WPFMonitor monitor = sigma.AddMonitor(new WPFMonitor(demoName, DemoMode.Language));

            monitor.ColourManager.Dark         = DemoMode.Dark;
            monitor.ColourManager.PrimaryColor = DemoMode.PrimarySwatch;

            // one legend named after the demo, one general-purpose legend
            StatusBarLegendInfo demoLegend    = new StatusBarLegendInfo(demoName, MaterialColour.Blue);
            StatusBarLegendInfo generalLegend = new StatusBarLegendInfo("General", MaterialColour.Grey);

            monitor.AddLegend(demoLegend);
            monitor.AddLegend(generalLegend);

            // create the tabs
            monitor.AddTabs(overviewTab, metricsTab, validationTab, maximisationTab, reproductionTab, updateTab);

            // everything that touches the window is handed to the monitor's window dispatcher
            monitor.WindowDispatcher(window =>
            {
                // flag the window as initialising; the flag is cleared at the end of Main
                window.IsInitializing = true;

                window.TabControl[metricsTab].GridSize      = new GridSize(2, 4);
                window.TabControl[validationTab].GridSize   = new GridSize(2, 5);
                window.TabControl[maximisationTab].GridSize = new GridSize(2, 5);
                window.TabControl[reproductionTab].GridSize = new GridSize(2, 5);
                window.TabControl[updateTab].GridSize       = new GridSize(1, 1);

                // the overview tab keeps its existing grid, reduced by one row and one column
                window.TabControl[overviewTab].GridSize.Rows    -= 1;
                window.TabControl[overviewTab].GridSize.Columns -= 1;

                // panel that controls the learning process
                window.TabControl[overviewTab].AddCumulativePanel(new ControlPanel("Control", trainer), legend: demoLegend);

                // slow demos report every iteration, all others every tenth epoch
                ITimeStep reportInterval = DemoMode.Slow
                    ? TimeStep.Every(1, TimeScale.Iteration)
                    : TimeStep.Every(10, TimeScale.Epoch);

                // cost charts; metricsCost is built but not attached to any tab in this method
                var overviewCost = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Cost / Epoch", trainer, costKey, TimeStep.Every(1, TimeScale.Epoch)).Linearify();
                var metricsCost = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Cost / Epoch", trainer, costKey, reportInterval);

                // charts over the "shared.*" values published by the hooks registered above
                var weightsMeanChart = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Mean of Weights / Epoch", trainer, "shared.network_weights_average", reportInterval, averageMode: true).Linearify();
                var weightsStddevChart = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Standard Deviation of Weights / Epoch", trainer, "shared.network_weights_stddev", reportInterval, averageMode: true).Linearify();
                var biasesMeanChart = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Mean of Biases / Epoch", trainer, "shared.network_biases_average", reportInterval, averageMode: true).Linearify();
                var biasesStddevChart = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Standard Deviation of Biases / Epoch", trainer, "shared.network_biases_stddev", reportInterval, averageMode: true).Linearify();
                var updatesMeanChart = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Mean of Parameter Updates / Epoch", trainer, "shared.optimiser_updates_average", reportInterval, averageMode: true).Linearify();
                var updatesStddevChart = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Standard Deviation of Parameter Updates / Epoch", trainer, "shared.optimiser_updates_stddev", reportInterval, averageMode: true).Linearify();

                var activationsMeanChart = CreateChartPanel<CartesianChart, GLineSeries, GearedValues<double>, double>(
                    "Mean of Output Activations", trainer, "shared.network_activations_mean", reportInterval, averageMode: true).Linearify();

                // accuracy panels are built for every demo except Wdbc and Parkinsons;
                // overviewAccuracy is built but not attached to any tab in this method
                AccuracyPanel overviewAccuracy = null;
                AccuracyPanel metricsAccuracy  = null;
                bool buildAccuracyPanels = DemoMode != DemoType.Wdbc && DemoMode != DemoType.Parkinsons;
                if (buildAccuracyPanels)
                {
                    overviewAccuracy = new AccuracyPanel(
                        "Validation Accuracy", trainer,
                        DemoMode.Slow ? TimeStep.Every(1, TimeScale.Epoch) : reportInterval,
                        null, 1, 2);
                    overviewAccuracy.Fast().Linearify();

                    metricsAccuracy = new AccuracyPanel(
                        "Validation Accuracy", trainer,
                        DemoMode.Slow ? TimeStep.Every(1, TimeScale.Epoch) : reportInterval,
                        null, 1, 2);
                    metricsAccuracy.Fast().Linearify();
                }

                // stand-alone registry holding a single timestamp for the "Time" parameter row
                IRegistry timeRegistry = new Registry();
                timeRegistry.Add("test", DateTime.Now);

                var parameterPanel = new ParameterPanel("Parameters", sigma, window);
                parameterPanel.Add("Time", typeof(DateTime), timeRegistry, "test");

                // report the total cost once per epoch and register the reporter as a synchronisation source
                ValueSourceReporter costReporter = new ValueSourceReporter(TimeStep.Every(1, TimeScale.Epoch), costKey);
                trainer.AddGlobalHook(costReporter);
                sigma.SynchronisationHandler.AddSynchronisationSource(costReporter);

                // parameter rows backed by the operator registry, each polling on its own time step
                var costRow = (UserControlParameterVisualiser)parameterPanel.Content.Add(
                    "Cost", typeof(double), trainer.Operator.Registry, costKey);
                costRow.AutoPollValues(trainer, TimeStep.Every(1, TimeScale.Epoch));

                var learningRateRow = (UserControlParameterVisualiser)parameterPanel.Content.Add(
                    "Learning rate", typeof(double), trainer.Operator.Registry, "optimiser.learning_rate");
                learningRateRow.AutoPollValues(trainer, TimeStep.Every(1, TimeScale.Epoch));

                var parameterCountRow = (UserControlParameterVisualiser)parameterPanel.Content.Add(
                    "Parameter count", typeof(long), trainer.Operator.Registry, "network.parameter_count");
                parameterCountRow.AutoPollValues(trainer, TimeStep.Every(1, TimeScale.Start));

                // overview tab: cost chart and parameter panel (the control panel was added above)
                window.TabControl[overviewTab].AddCumulativePanel(overviewCost, 1, 2, legend: demoLegend);
                window.TabControl[overviewTab].AddCumulativePanel(parameterPanel);

                // update tab: mean of the parameter updates
                window.TabControl[updateTab].AddCumulativePanel(updatesMeanChart, legend: demoLegend);

                // metrics tab: accuracy (when built), the standard deviation charts and the activation mean
                if (metricsAccuracy != null)
                {
                    window.TabControl[metricsTab].AddCumulativePanel(metricsAccuracy, legend: demoLegend);
                }

                window.TabControl[metricsTab].AddCumulativePanel(weightsStddevChart, legend: demoLegend);
                window.TabControl[metricsTab].AddCumulativePanel(biasesStddevChart, legend: demoLegend);
                window.TabControl[metricsTab].AddCumulativePanel(updatesStddevChart, legend: demoLegend);
                window.TabControl[metricsTab].AddCumulativePanel(activationsMeanChart, legend: demoLegend);

                // MNIST only: drawing and number panels, the weight / bias mean charts,
                // and one bitmap panel per target 0..9 on the maximisation tab
                if (DemoMode == DemoType.Mnist)
                {
                    NumberPanel numberPanel = new NumberPanel("Numbers", trainer);
                    DrawPanel drawPanel = new DrawPanel("Draw", trainer, 560, 560, 20, numberPanel);

                    window.TabControl[validationTab].AddCumulativePanel(drawPanel, 2, 3);
                    window.TabControl[validationTab].AddCumulativePanel(numberPanel, 2);

                    window.TabControl[validationTab].AddCumulativePanel(weightsMeanChart);
                    window.TabControl[validationTab].AddCumulativePanel(biasesMeanChart);

                    for (int target = 0; target < 10; target++)
                    {
                        window.TabControl[maximisationTab].AddCumulativePanel(
                            new MnistBitmapHookPanel($"Target Maximisation {target}", target, trainer, TimeStep.Every(1, TimeScale.Epoch)));
                    }
                }

                // TicTacToe only: playable board on the overview tab
                if (DemoMode == DemoType.TicTacToe)
                {
                    window.TabControl[overviewTab].AddCumulativePanel(new TicTacToePanel("Play TicTacToe!", trainer));
                }

                // note: the "Reproduction" tab is created and sized but receives no panels in this method
            });

            // MNIST additionally gets an HTTP monitor
            if (DemoMode == DemoType.Mnist)
            {
                sigma.AddMonitor(new HttpMonitor("http://+:8080/sigma/"));
            }

            // operators are not started automatically on run; per the original note they
            // are meant to start when the user clicks play
            sigma.StartOperatorsOnRun = false;

            sigma.Prepare();

            sigma.RunAsync();

            // window set-up is finished, clear the initialisation flag again
            monitor.WindowDispatcher(window => window.IsInitializing = false);
        }