Ejemplo n.º 1
0
        public void CanTransferWeightsInsideNetwrok()
        {
            // NOTE(review): "Netwrok" is a typo for "Network" in the test name;
            // kept as-is so existing test-filter references keep working.

            // Build a reference network with its own RMSProp optimizer.
            var sourceOptimizer = new RMSPropOptimizer <float>(1e-3f);
            var sourceNet = new LayeredNet <float>(1, 1,
                                                   new GruLayer <float>(2, 3),
                                                   new LinearLayer <float>(3, 2),
                                                   new SoftMaxLayer <float>(2))
            {
                Optimizer = sourceOptimizer
            };

            // Clone the network and give the clone an identically configured optimizer.
            var clonedNet = (LayeredNet <float>)sourceNet.Clone();
            clonedNet.Optimizer = new RMSPropOptimizer <float>(1e-3f);

            // Train the original on the GPU, then pull its state back to host memory.
            sourceNet.UseGpu();
            TrainXor(sourceNet);
            sourceNet.TransferStateToHost();

            // Train the clone on the CPU with the same routine.
            TrainXor(clonedNet);

            var sourceWeights = sourceNet.Weights.ToList();
            var clonedWeights = clonedNet.Weights.ToList();

            // Every weight tensor and every piece of optimizer state must match
            // between the GPU-trained original and the CPU-trained clone.
            for (int i = 0; i < sourceWeights.Count; i++)
            {
                sourceWeights[i].Weight.ShouldMatrixEqualWithinError(clonedWeights[i].Weight);
                sourceWeights[i].Gradient.ShouldMatrixEqualWithinError(clonedWeights[i].Gradient);
                sourceWeights[i].Cache1.ShouldMatrixEqualWithinError(clonedWeights[i].Cache1);
                sourceWeights[i].Cache2.ShouldMatrixEqualWithinError(clonedWeights[i].Cache2);
                sourceWeights[i].CacheM.ShouldMatrixEqualWithinError(clonedWeights[i].CacheM);
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Creates the optimizer selected by <c>opts.Optimizer</c>.
        /// "Adam" (case-insensitive) yields an <see cref="AdamOptimizer"/>;
        /// any other value falls back to <see cref="RMSPropOptimizer"/>,
        /// matching the original behavior for unknown names.
        /// </summary>
        /// <param name="opts">Options carrying the optimizer name and hyperparameters.</param>
        /// <returns>The configured optimizer instance; never null.</returns>
        public static IOptimizer CreateOptimizer(Options opts)
        {
            // Early returns instead of a redundant null-initialized local.
            if (string.Equals(opts.Optimizer, "Adam", StringComparison.InvariantCultureIgnoreCase))
            {
                return new AdamOptimizer(opts.GradClip, opts.Beta1, opts.Beta2);
            }

            return new RMSPropOptimizer(opts.GradClip, opts.Beta1);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Verifies five successive RMSProp update steps on a simple multivariate
        /// function against hand-computed expected values.
        /// </summary>
        public void RMSPropOptimizer_SimpleMultivar()
        {
            // arrange
            var func      = new Mocks.SimpleMultivar();
            var lr        = 0.1D;
            var eps       = 0.2D;
            var gamma     = 0.4D;
            var w         = new double[2][] { new[] { 1.0D, 1.0D }, new[] { 1.0D } };
            var optimizer = new RMSPropOptimizer(gamma, eps);

            // Expected state after each successive Push, in order:
            // w[0][0], w[0][1], w[1][0], optimizer.Step2.
            // Values are the original hand-computed reference sequence.
            var expected = new[]
            {
                new[] { 0.87111518, 0.87596527, 0.87111518, 0.04860721 },
                new[] { 0.76683636, 0.77441535, 0.76683636, 0.03206053 },
                new[] { 0.67050260, 0.68059867, 0.67050260, 0.02736195 },
                new[] { 0.57795893, 0.59072338, 0.57795893, 0.02520623 },
                new[] { 0.48793820, 0.50366403, 0.48793820, 0.02378680 }
            };

            // act & assert — table-driven instead of five copy-pasted stanzas;
            // the assert order within each step is unchanged.
            foreach (var step in expected)
            {
                optimizer.Push(w, func.Gradient(w), lr);
                Assert.AreEqual(step[0], w[0][0], EPS);
                Assert.AreEqual(step[1], w[0][1], EPS);
                Assert.AreEqual(step[2], w[1][0], EPS);
                Assert.AreEqual(step[3], optimizer.Step2, EPS);
            }
        }
Ejemplo n.º 4
0
        public void Xor()
        {
            MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

            // Two tanh affine layers; the output layer carries the MSE error function.
            var rmsProp     = new RMSPropOptimizer <float>(1e-3f);
            var outputLayer = new AffineLayer <float>(3, 1, AffineActivation.Tanh)
            {
                ErrorFunction = new MeanSquareError <float>()
            };
            var network = new LayeredNet <float>(1, 1, new AffineLayer <float>(2, 3, AffineActivation.Tanh), outputLayer)
            {
                Optimizer = rmsProp
            };

            // Report every iteration so convergence can be observed live.
            var options = new OptimizingTrainerOptions(1)
            {
                ErrorFilterSize    = 0,
                ReportProgress     = new EachIteration(1),
                ReportMesages      = true,
                ProgressWriter     = ConsoleProgressWriter.Instance,
                LearningRateScaler = new ProportionalLearningRateScaler(new EachIteration(1), 9e-5f)
            };
            var trainer = new OptimizingTrainer <float>(network, rmsProp, new XorDataset(true), options, new OptimizingSession("XOR"));

            var runner = ConsoleRunner.Create(trainer, network);

            // Stop training as soon as the raw error falls below the threshold.
            trainer.TrainReport += (sender, args) =>
            {
                if (args.Errors.Last().RawError < 1e-7f)
                {
                    runner.Stop();
                    Console.WriteLine("Finished training.");
                }
            };

            // Show the training window asynchronously while the runner blocks this thread.
            var trainingGui = new RetiaGui();

            trainingGui.RunAsync(() => new TrainingWindow(new TypedTrainingModel <float>(trainer)));

            runner.Run();
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Trains a recurrent network on a pre-generated batch file, optionally on the
        /// GPU and with a live GUI, periodically sampling generated text and error plots
        /// into the optimizing session report.
        /// </summary>
        /// <param name="batchesPath">Path to the serialized training batches (required).</param>
        /// <param name="configPath">Optional network config path; when empty a default 128-unit network is built from the data set sizes.</param>
        /// <param name="learningRate">Initial learning rate for the RMSProp optimizer.</param>
        /// <param name="gpu">When true, training runs on the GPU.</param>
        /// <param name="gui">When true, a training window is shown.</param>
        public void Learn(
            [Aliases("b"), Required]        string batchesPath,
            [Aliases("c")]                          string configPath,
            [Aliases("r"), DefaultValue(0.0002f)]  float learningRate,
            [DefaultValue(false)]       bool gpu,
            [DefaultValue(true)]       bool gui)
        {
            MklProvider.TryUseMkl(true, ConsoleProgressWriter.Instance);

            Console.WriteLine($"Loading test set from {batchesPath}");
            _dataProvider.Load(batchesPath);

            var optimizer = new RMSPropOptimizer <float>(learningRate, 0.95f, 0.0f, 0.9f);

            // Build the network from scratch or from a config file; the optimizer
            // wiring is shared (the original duplicated it in both branches).
            LayeredNet <float> network = string.IsNullOrEmpty(configPath)
                ? CreateNetwork(_dataProvider.TrainingSet.InputSize, 128, _dataProvider.TrainingSet.TargetSize)
                : CreateNetwork(configPath);

            network.Optimizer = optimizer;
            network.ResetOptimizer();

            if (gpu)
            {
                Console.WriteLine("Brace yourself for GPU!");
                network.UseGpu();
            }

            var trainOptions = new OptimizingTrainerOptions(SEQ_LEN)
            {
                ErrorFilterSize = 100,
                ReportMesages   = true,
                MaxEpoch        = 1000,
                ProgressWriter  = ConsoleProgressWriter.Instance,
                ReportProgress  = new EachIteration(10)
            };

            trainOptions.LearningRateScaler = new ProportionalLearningRateScaler(new ActionSchedule(1, PeriodType.Iteration), 9.9e-5f);

            var session = new OptimizingSession(Path.GetFileNameWithoutExtension(batchesPath));
            var trainer = new OptimizingTrainer <float>(network, optimizer, _dataProvider.TrainingSet, trainOptions, session);

            // Only populated when the GUI is shown; the epoch handler below checks for null.
            TypedTrainingModel <float> model = null;

            if (gui)
            {
                var retiaGui = new RetiaGui();
                retiaGui.RunAsync(() =>
                {
                    model = new TypedTrainingModel <float>(trainer);
                    return(new TrainingWindow(model));
                });
            }

            var epochWatch = new Stopwatch();

            trainer.EpochReached += sess =>
            {
                epochWatch.Stop();
                Console.WriteLine($"Trained epoch in {epochWatch.Elapsed.TotalSeconds} s.");

                // Showcasing plot export
                if (model != null)
                {
                    using (var stream = new MemoryStream())
                    {
                        model.ExportErrorPlot(stream, 600, 400);

                        stream.Seek(0, SeekOrigin.Begin);

                        session.AddFileToReport("ErrorPlots\\plot.png", stream);
                    }
                }

                epochWatch.Restart();
            };

            // Every 100 iterations: sample text from a 1-batch clone of the network
            // and attach it to the session report.
            trainer.PeriodicActions.Add(new UserAction(new ActionSchedule(100, PeriodType.Iteration), () =>
            {
                if (gpu)
                {
                    // Cloning reads host state, so sync it down from the GPU first.
                    network.TransferStateToHost();
                }

                string text = TestRNN(network.Clone(1, SEQ_LEN), 500, _dataProvider.Vocab);
                Console.WriteLine(text);
                session.AddFileToReport("Generated\\text.txt", text);

                trainOptions.ProgressWriter.ItemComplete();
            }));

            var runner = ConsoleRunner.Create(trainer, network);

            epochWatch.Start();
            runner.Run();
        }