Exemplo n.º 1
0
 /// <summary>
 /// Applies one weight-decay adjustment to a single connection by adding
 /// <c>w^4 / (w^4 + rule.Cutoff4) * rule.Factor * w</c> to its weight <c>w</c>.
 /// </summary>
 /// <param name="conn">Connection whose <c>Weight</c> is read and updated in place.</param>
 /// <param name="rule">Supplies <c>Factor</c> (scale of the adjustment) and <c>Cutoff4</c> (added to w^4 in the denominator).</param>
 /// <remarks>
 /// NOTE(review): assuming Cutoff4 &gt;= 0, the ratio w^4/(w^4 + Cutoff4) lies in [0, 1], so a negative
 /// Factor pulls the weight toward zero. Small weights are barely affected, and large ones are
 /// affected in full proportion. Confirm that Cutoff4 is never negative.
 /// </remarks>
 private static void Update(INeuralConnection conn, WeightDecayRule rule)
 {
     double weightPow4 = Math.Pow(conn.Weight, 4.0);
     double denominator = weightPow4 + rule.Cutoff4;

     // Skip the update when the denominator is exactly zero (e.g. weight and Cutoff4 both zero),
     // which would otherwise produce 0/0 = NaN.
     if (denominator == 0.0)
     {
         return;
     }

     double saturation = weightPow4 / denominator;
     conn.Weight += saturation * rule.Factor * conn.Weight;
 }
Exemplo n.º 2
0
        /// <summary>
        /// Interactive console training session. It builds the training and validation batchers,
        /// the learning rules and the network, then repeatedly steps a <c>LearningEpoch</c> until
        /// Escape is pressed.
        /// </summary>
        /// <remarks>
        /// Keys are polled (non-blocking) after every epoch step:
        /// <list type="bullet">
        /// <item>Escape: stop training.</item>
        /// <item>S: save a clone of the current network.</item>
        /// <item>V: save the best-validation snapshot, if one has been taken.</item>
        /// <item>NumPad1: test a clone of the current network.</item>
        /// <item>NumPad2: test the best-validation snapshot, if one has been taken.</item>
        /// </list>
        /// The best-validation snapshot (<c>vbestNet</c>) is replaced with a fresh clone of the
        /// network whenever <c>epoch.BestValidationResult.Updated</c> fires.
        /// NOTE(review): <c>Console.ReadKey()</c> is called without intercept, so pressed keys are
        /// echoed into the progress output.
        /// </remarks>
        private static void Begin()
        {
            // ---- Data -------------------------------------------------------------
            var trainingProvider = CreateProvider(true, 10000);
            // Alternative strategy that has been tried: new MonteCarloBatchingStrategy()
            var trainingBatcher = new ScriptCollectionBatcher(
                new GaussianBatchingStrategy(2.0),
                trainingProvider,
                250,
                250);

            var validationProvider = CreateProvider(false, 1000);
            var validationBatcher = new ScriptCollectionBatcher(
                new MonteCarloBatchingStrategy(),
                validationProvider,
                50,
                1000);

            trainingBatcher.Initialize();
            validationBatcher.Initialize();

            Console.WriteLine("Training samples: " + trainingProvider.Count);
            Console.WriteLine("Validation samples: " + validationProvider.Count);

            // ---- Learning rules ---------------------------------------------------
            Console.WriteLine("Creating learning rules ...");
            var initRule = new NoisedWeightInitializationRule
            {
                Noise = 0.5,
                IsEnabled = true
            };
            var decay = new WeightDecayRule
            {
                Factor = -0.000001,
                IsEnabled = false
            };

            // Other learning rules that have been tried:
            //   new QuickpropRule { StepSize = 0.01 }
            //   new SCGRule()
            //   new LMRule()
            //   new MetaQSARule { Mode = LearningMode.Stochastic, Momentum = 0.1, StepSizeRange = new DoubleRange(0.0, 0.01), StepSize = 0.005, StochasticAdaptiveStateUpdate = false }
            //   new SuperSABRule { Mode = LearningMode.Batch, Momentum = 0.8, StepSizeRange = new DoubleRange(0.0, 0.05), StepSize = 0.01, StochasticAdaptiveStateUpdate = false }
            //   new GradientDescentRule { Mode = LearningMode.Stochastic, Momentum = 0.1, StepSize = 0.0001 }
            //   new QSARule()
            //   new MAQRule()
            //   new AdaptiveAnnealingRule { WeightGenMul = 0.01, AcceptProbMul = 0.01 }
            //   new RpropRule { Momentum = 0.01, StepSize = 0.0001 }
            //   new CrossEntropyRule { PopulationSize = 50, NumberOfElites = 10, MutationChance = 0.001, MutationStrength = 0.1, DistributionType = DistributionType.Gaussian }
            //   new GARule { }
            // Property order below is kept as-is: StepSize is assigned after StepSizeRange.
            var learningRule = new SignChangesRule
            {
                Mode = LearningMode.Stochastic,
                Momentum = 0.2,
                StepSizeRange = new DoubleRange(0.0, 0.001),
                StepSize = 0.001,
                StochasticAdaptiveStateUpdate = false
            };

            // Upcast to the interface first so the 'as' conversion to IWeightDecayedLearningRule
            // is permitted whichever concrete rule is selected above.
            var decayedLearningRule = (ILearningRule)learningRule as IWeightDecayedLearningRule;
            if (decayedLearningRule != null)
            {
                decayedLearningRule.WeightDecay = new WeightDecay
                {
                    Factor = -0.0001,
                    IsEnabled = false
                };
            }

            var repeatPars = new IterationRepeatPars(1, 5);

            // ---- Network ----------------------------------------------------------
            Console.WriteLine("Creating Neural Network ...");
            var network = CreateNetwork(
                trainingProvider.InputSize,
                trainingProvider.OutputSize,
                initRule,
                decay,
                learningRule);
            var execution = new LearningExecution(network, repeatPars);
            var networkItems = network.GetItems();
            int neuronCount = networkItems.NodeEntries
                .Select(entry => entry.Node)
                .OfType<ActivationNeuron>()
                .Count();
            int synapseCount = networkItems.ConnectionEntries
                .Select(entry => entry.Connection)
                .OfType<Synapse>()
                .Count();
            Console.WriteLine("Neurons: {0} Synapses: {1}", neuronCount, synapseCount);

            // ---- Epoch ------------------------------------------------------------
            Console.WriteLine("Initializing epoch ...");
            var epoch = new LearningEpoch(execution, trainingBatcher, validationBatcher, 1);
            epoch.Initialize();
            epoch.CurrentResult.Updated += (source, args) => WriteResult(epoch);
            epoch.BestValidationResult.Updated += (source, args) => vbestNet = network.Clone();

            // ---- Training loop ----------------------------------------------------
            Console.WriteLine("Starting ...");

            while (true)
            {
                epoch.Step();

                // Non-blocking poll: only read a key when one is already waiting.
                if (!Console.KeyAvailable)
                {
                    continue;
                }

                ConsoleKey pressed = Console.ReadKey().Key;
                if (pressed == ConsoleKey.Escape)
                {
                    break;
                }

                if (pressed == ConsoleKey.S)
                {
                    Save(network.Clone());
                }
                else if (pressed == ConsoleKey.V)
                {
                    if (vbestNet != null)
                    {
                        Save(vbestNet);
                    }
                }
                else if (pressed == ConsoleKey.NumPad1)
                {
                    Test(network.Clone());
                }
                else if (pressed == ConsoleKey.NumPad2)
                {
                    if (vbestNet != null)
                    {
                        Test(vbestNet);
                    }
                }
            }
        }