/// <summary>
/// Console entry point: loads the iris data set, min-max normalizes its four
/// features, trains a network with backpropagation until at least 90% of the
/// training samples are classified successfully, then prints train,
/// validation and combined statistics.
/// </summary>
/// <param name="args">Command-line arguments (unused).</param>
private static void Main(string[] args)
{
    // Materialize once. The return type of ReadFromFile is not visible here; if it
    // is a lazy sequence, mutating entries of one enumeration would be lost on the
    // next one. (Optional shuffle, disabled as before: .OrderBy(i => random.Next()))
    var data = DataReader.ReadFromFile("data/iris.data").ToList();

    // Min/Max throw on an empty sequence; with no entries there is nothing to normalize.
    if (data.Count > 0)
    {
        // Capture every feature range BEFORE mutating any entry. Computing them inside
        // the loop (as was done previously) mixes already-normalized values with raw
        // ones, which skews the range used for later entries, and rescans the whole
        // data set eight times per entry.
        var petalLengthMin = data.Min(i => i.PetalLength);
        var petalLengthMax = data.Max(i => i.PetalLength);
        var petalWidthMin = data.Min(i => i.PetalWidth);
        var petalWidthMax = data.Max(i => i.PetalWidth);
        var sepalLengthMin = data.Min(i => i.SepalLength);
        var sepalLengthMax = data.Max(i => i.SepalLength);
        var sepalWidthMin = data.Min(i => i.SepalWidth);
        var sepalWidthMax = data.Max(i => i.SepalWidth);

        foreach (var entry in data)
        {
            entry.PetalLength = Normalize(entry.PetalLength, petalLengthMin, petalLengthMax);
            entry.PetalWidth = Normalize(entry.PetalWidth, petalWidthMin, petalWidthMax);
            entry.SepalLength = Normalize(entry.SepalLength, sepalLengthMin, sepalLengthMax);
            entry.SepalWidth = Normalize(entry.SepalWidth, sepalWidthMin, sepalWidthMax);
        }
    }

    var network = new Network(new NeuralNet.TransferFunctions.HyperbolicTangentFunction(), true);
    network.FillNetwork(4, 3, 6);

    // The first TrainSetSize entries are used for training, the remainder for validation.
    var train = data.Take(TrainSetSize).ToArray();
    var validation = data.Skip(TrainSetSize).ToArray();

    // Project the training entries once and reuse the array for every epoch and for
    // the final report instead of re-running the projection each time.
    var trainData = train.Select(entry => new InputExpectedResult(entry.AsInput, entry.AsOutput)).ToArray();

    var bp = new Backpropagate(network, 0.5, 3);

    int trains = 0;
    double score = 0;
    var start = Environment.TickCount;

    // Train on a freshly shuffled copy of the training set until the success
    // percentage measured on the training set reaches 90.
    while (score < 90)
    {
        trains++;
        bp.Train(trainData.OrderBy(i => random.Next()).ToArray());

        var stats = NetworkValidation.Validate(network, trainData, IrisEntry.IsOutputSuccess);
        score = stats.SuccessPercentage;
        Console.WriteLine($"{trains,-4}" + stats.ToString());
    }

    Console.WriteLine($"\nTime elapsed: {Environment.TickCount - start}Ms");
    Console.WriteLine("Done training");

    var trainStats = NetworkValidation.Validate(network, trainData, IrisEntry.IsOutputSuccess);
    var validateStats = NetworkValidation.Validate(
        network,
        validation.Select(entry => new InputExpectedResult(entry.AsInput, entry.AsOutput)),
        IrisEntry.IsOutputSuccess);

    Console.WriteLine($"{trainStats.ToString()} TRAIN");
    Console.WriteLine($"{validateStats.ToString()} VALIDATE");
    Console.WriteLine($"{(trainStats + validateStats).ToString()} TOTAL");

    // Keep the console window open until a key is pressed.
    Console.ReadKey();
}
/// <summary>
/// Verifies that a single backpropagation pass over one sample lowers the
/// network's average SSE for that sample.
/// </summary>
public void TestTraining()
{
    var net = new Network(new SigmoidFunction(), true);
    net.FillNetwork(2, 2, 2);

    // Fixed starting weights, indexed as [layer][node][outgoing connection], so the
    // run is deterministic. NOTE(review): the values look like a commonly used
    // worked backpropagation example - not confirmed from this file.
    double[][][] initialWeights =
    {
        new[] { new[] { .15, .2 }, new[] { .25, .3 } },
        new[] { new[] { .4, .45 }, new[] { .5, .55 } },
    };

    for (int layer = 0; layer < initialWeights.Length; layer++)
    {
        for (int node = 0; node < initialWeights[layer].Length; node++)
        {
            for (int connection = 0; connection < initialWeights[layer][node].Length; connection++)
            {
                net.Nodes[layer][node].GetOutgoingConnections()[connection].Weight =
                    initialWeights[layer][node][connection];
            }
        }
    }

    // One sample: inputs (.05, .1) with expected outputs (.01, .99).
    var samples = new InputExpectedResult[]
    {
        new InputExpectedResult(new double[] { .05, .1 }, new double[] { .01, .99 }),
    };

    // The success predicate always returns true; only the error measure matters here.
    var errorBeforeTraining = NetworkValidation.Validate(net, samples, (a, b) => true);

    var trainer = new Backpropagate(net, 0.5);
    trainer.Train(samples);

    var errorAfterTraining = NetworkValidation.Validate(net, samples, (a, b) => true);

    Assert.IsTrue(errorBeforeTraining.AvgSSE > errorAfterTraining.AvgSSE);
}