Exemplo n.º 1
0
        /// <summary>
        /// Builds a net made of a single "sum"/"tanh" node fed by both entries of the
        /// input vector, trains it on the four OR cases (the all-zero case targets -1)
        /// and asserts that the loss returned by the trainer is below 0.01.
        /// </summary>
        public void TanhNodeCanFigureOutOr()
        {
            // Shared by the training target and the final assertion.
            const float targetError = 0.01f;

            var firstInput  = new NodeInputDescription { FromInputVector = true, InputId = 0, Weight = -0.001f };
            var secondInput = new NodeInputDescription { FromInputVector = true, InputId = 1, Weight = 0.001f };

            var orNode = new NodeDescription
            {
                NodeId     = 0,
                Weight     = 0.001f,
                Aggregator = "sum",
                Processor  = "tanh",
                Inputs     = new[] { firstInput, secondInput }
            };

            var netDescription = new NetDescription
            {
                Nodes   = new[] { orNode },
                Outputs = new[] { 0 }
            };

            var network = Net.FromDescription(netDescription);

            // Each case pairs an input vector with the expected single output.
            var truthTable = new[]
            {
                Tuple.Create(new[] { 1f, 0f }, new[] { 1f }),
                Tuple.Create(new[] { 1f, 1f }, new[] { 1f }),
                Tuple.Create(new[] { 0f, 1f }, new[] { 1f }),
                Tuple.Create(new[] { 0f, 0f }, new[] { -1f })
            };

            var finalLoss = new SimpleTrainer().Train(
                net: network,
                tests: truthTable,
                desiredError: targetError,
                maxEpochs: 50000,
                learningRate: 0.5f);

            Assert.IsTrue(finalLoss < targetError);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Trains a net with 10 inputs and layer sizes 5, 5, 5, 3 on three samples with
        /// one-hot targets, asserts the returned training error is below 0.01 and then
        /// prints the net's output for every sample.
        /// </summary>
        public void TurtleZebraGiraffe()
        {
            // Three 10-value input patterns, each mapped to its own one-hot target.
            var samples = new[]
            {
                Tuple.Create(
                    new[] { 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f, 1f },
                    new[] { 0f, 0f, 1f }),
                Tuple.Create(
                    new[] { 1f, -1f, -1f, -1f, -1f, -1f, -1f, -1f, -1f, -1f },
                    new[] { 1f, 0f, 0f }),
                Tuple.Create(
                    new[] { 1f, 1f, 1f, 1f, -1f, -1f, -1f, -1f, -1f, -1f },
                    new[] { 0f, 1f, 0f })
            };

            var layerSizes     = new[] { 5, 5, 5, 3 };
            var netDescription = SimpleDescriptionBuilder.GetDescription(10, layerSizes);

            // Only the nodes listed as net outputs get the "sigmoid" processor.
            foreach (var candidate in netDescription.Nodes)
            {
                if (netDescription.Outputs.Contains(candidate.NodeId))
                {
                    candidate.Processor = "sigmoid";
                }
            }

            var network = Net.FromDescription(netDescription);
            WeightFiller.FillWeights(network, .005f);

            var trainingError = new SimpleTrainer().Train(
                net: network,
                tests: samples,
                desiredError: 0.01f,
                maxEpochs: 20000,
                learningRate: .5f);

            Assert.IsTrue(trainingError < .01f);

            // Print what the trained net produces for each training input.
            var evaluate = network.GetEvaluationFunction();
            foreach (var sample in samples)
            {
                Console.WriteLine(string.Join(",", evaluate(sample.Item1)));
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Builds a three-node "sum"/"tanh" net (nodes 0 and 1 read the input vector,
        /// node 2 reads nodes 0 and 1 and is the only output), trains it on XOR with
        /// -1/1 encoding, prints the resulting error and asserts it is below 0.1.
        /// </summary>
        public void TanhNetCanBeTrainedOnXOr()
        {
            // XOR over {-1, 1}: the target is 1 only when the two inputs differ.
            var xorCases = new[]
            {
                Tuple.Create(new[] { 1f, -1f }, new[] { 1f }),
                Tuple.Create(new[] { 1f, 1f }, new[] { -1f }),
                Tuple.Create(new[] { -1f, 1f }, new[] { 1f }),
                Tuple.Create(new[] { -1f, -1f }, new[] { -1f })
            };

            // First node fed directly from the input vector.
            var hiddenA = new NodeDescription
            {
                NodeId     = 0,
                Aggregator = "sum",
                Processor  = "tanh",
                Weight     = .21f,
                Inputs     = new[]
                {
                    new NodeInputDescription { FromInputVector = true, InputId = 0, Weight = -.07f },
                    new NodeInputDescription { FromInputVector = true, InputId = 1, Weight = -.28f }
                }
            };

            // Second node fed directly from the input vector.
            var hiddenB = new NodeDescription
            {
                NodeId     = 1,
                Aggregator = "sum",
                Processor  = "tanh",
                Weight     = -.29f,
                Inputs     = new[]
                {
                    new NodeInputDescription { FromInputVector = true, InputId = 0, Weight = .41f },
                    new NodeInputDescription { FromInputVector = true, InputId = 1, Weight = -.05f }
                }
            };

            // Output node: its inputs are not taken from the input vector and use ids 0 and 1.
            var outputNode = new NodeDescription
            {
                NodeId     = 2,
                Aggregator = "sum",
                Processor  = "tanh",
                Weight     = .11f,
                Inputs     = new[]
                {
                    new NodeInputDescription { FromInputVector = false, InputId = 0, Weight = -.1f },
                    new NodeInputDescription { FromInputVector = false, InputId = 1, Weight = -.21f }
                }
            };

            var startingDescription = new NetDescription
            {
                Nodes   = new[] { hiddenA, hiddenB, outputNode },
                Outputs = new[] { 2 }
            };

            var network = Net.FromDescription(startingDescription);
            WeightFiller.FillWeights(network, .05f);

            var xorTrainer    = new SimpleTrainer();
            var trainingError = xorTrainer.Train(
                net: network,
                tests: xorCases,
                desiredError: .001f,
                maxEpochs: 100000,
                learningRate: 5f);

            Console.WriteLine(trainingError);

            // The assertion bound (0.1) is looser than the training target (.001).
            Assert.IsTrue(trainingError < 0.1f);
        }
Exemplo n.º 4
0
        /// <summary>
        /// Entry point. Reads events from "trainingEvents.csv", pre-trains a layered net
        /// on the qualifying events using each input as its own target, removes five
        /// "tanh" nodes one at a time with retraining in between, reduces the net to its
        /// first output, trains that reduced net on the supervised input/output pairs and
        /// writes the trained net's description as JSON to "out3.json".
        /// </summary>
        /// <param name="args">Command-line arguments; not used.</param>
        static void Main(string[] args)
        {
            var events = ReadEventFile("trainingEvents.csv");

            // Materialize once: every Train call enumerates its tests, so a deferred
            // query would re-run the filter and GetInputArray() on each pass.
            var ins = events
                      .Where(e => (e.Open - e.NextLow) / e.Open > .1f)
                      .Select(evts => evts.GetInputArray())
                      .ToArray();

            Console.WriteLine($"Qualified Events: {ins.Length}");

            // Unsupervised pairs use the input itself as the target.
            var unsupervisedTests = ins
                                    .Select(i => Tuple.Create(i, i))
                                    .ToArray();
            var supervisedTests = events
                                  .Select(evt => Tuple.Create(evt.GetInputArray(), evt.GetOutputArray()))
                                  .ToArray();

            var builder     = new LayerBuilder();
            var description = builder.BuildDescription(5, new[]
            {
                new LayerBuilder.LayerSpec(10, "sum", "softplus"),
                new LayerBuilder.LayerSpec(10, "sum", "softplus"),
                new LayerBuilder.LayerSpec(10, "sum", "tanh"),
                new LayerBuilder.LayerSpec(10, "sum", "softplus"),
                new LayerBuilder.LayerSpec(10, "sum", "softplus"),
                new LayerBuilder.LayerSpec(5, "sum", null)
            });
            // Alternative (smaller) layer layout, currently disabled:
            //var description = builder.BuildDescription(5, new[]
            //{
            //    new LayerBuilder.LayerSpec(5, "sum", "softplus"),
            //    new LayerBuilder.LayerSpec(6, "sum", "tanh"),
            //    new LayerBuilder.LayerSpec(4, "sum", "softplus"),
            //    new LayerBuilder.LayerSpec(5, "sum", null)
            //});
            var net     = Net.FromDescription(description);
            var trainer = new SimpleTrainer();

            // Unsupervised pre-training with a progressively smaller learning rate.
            trainer.Train(
                net: net,
                tests: unsupervisedTests,
                desiredError: 5e-6f,
                maxEpochs: 200,
                learningRate: 0.75f);
            trainer.Train(
                net: net,
                tests: unsupervisedTests,
                desiredError: 5e-6f,
                maxEpochs: 9500,
                learningRate: 0.5f);
            trainer.Train(
                net: net,
                tests: unsupervisedTests,
                desiredError: 1e-8f,
                maxEpochs: 200,
                learningRate: .25f);
            trainer.Train(
                net: net,
                tests: unsupervisedTests,
                desiredError: 1e-8f,
                maxEpochs: 200,
                learningRate: .125f);
            trainer.Train(
                net: net,
                tests: unsupervisedTests,
                desiredError: 1e-8f,
                maxEpochs: 200,
                learningRate: .0625f);

            // Remove five "tanh" nodes, one per iteration, retraining after each removal.
            for (var i = 0; i < 5; i++)
            {
                var nextDescription = net.Description;
                var firstTanhId     = nextDescription.Nodes.First(n => n.Processor == "tanh").NodeId;

                nextDescription.Nodes = nextDescription.Nodes.Where(n => n.NodeId != firstTanhId).ToArray();

                // Drop node-to-node connections that referenced the removed node;
                // connections coming from the input vector are always kept.
                foreach (var node in nextDescription.Nodes)
                {
                    node.Inputs = node.Inputs
                                  .Where(inp => inp.FromInputVector || inp.InputId != firstTanhId)
                                  .ToArray();
                }
                net = Net.FromDescription(nextDescription);

                Console.WriteLine($"Removed {i + 1} sigmoids");

                trainer.Train(
                    net: net,
                    tests: unsupervisedTests,
                    desiredError: 1e-6f,
                    maxEpochs: 400,
                    learningRate: 0.5f);
            }

            trainer.Train(
                net: net,
                tests: unsupervisedTests,
                desiredError: 1e-6f,
                maxEpochs: 500,
                learningRate: 0.5f);
            trainer.Train(
                net: net,
                tests: unsupervisedTests,
                desiredError: 1e-6f,
                maxEpochs: 400,
                learningRate: 0.25f);

            // Keep only the first output: every other output node is removed from the
            // description and from the output list.
            var finalDescription = net.Description;
            var outsRemoved      = finalDescription.Outputs.Skip(1).ToArray();

            finalDescription.Nodes = finalDescription.Nodes
                                     .Where(n => !outsRemoved.Contains(n.NodeId))
                                     .ToArray();
            finalDescription.Outputs = new[] { finalDescription.Outputs[0] };
            var nextNet = Net.FromDescription(finalDescription);

            // Supervised training of the reduced net with a decreasing learning rate.
            trainer.Train(
                net: nextNet,
                tests: supervisedTests,
                desiredError: 1e-8f,
                maxEpochs: 20000,
                learningRate: .25f);
            trainer.Train(
                net: nextNet,
                tests: supervisedTests,
                desiredError: 1e-8f,
                maxEpochs: 20000,
                learningRate: .125f);
            trainer.Train(
                net: nextNet,
                tests: supervisedTests,
                desiredError: 1e-8f,
                maxEpochs: 2000,
                learningRate: .0625f);
            trainer.Train(
                net: nextNet,
                tests: supervisedTests,
                desiredError: 1e-8f,
                maxEpochs: 2000,
                learningRate: .03125f);

            // Persist the net that received the supervised training (nextNet); the
            // earlier `net` instance was not touched by the four passes above.
            var final     = nextNet.Description;
            var finalText = JsonConvert.SerializeObject(final);

            using (var writer = File.CreateText("out3.json"))
            {
                writer.Write(finalText);
            }
            Console.WriteLine();

            // Wait for Enter before the process exits.
            Console.ReadLine();
        }