/// <summary>
/// Initializes a training job from a net description and its current training parameters.
/// </summary>
/// <param name="net">Description of the net being trained; stored in <c>Net</c>.</param>
/// <param name="avgErr">Average error; stored in <c>AvgError</c>.</param>
/// <param name="targetErr">Error the training aims for; stored in <c>TargetError</c>.</param>
/// <param name="learningRate">Learning rate; stored in <c>CurrentLearningRate</c>.</param>
/// <param name="momentum">Momentum; stored in <c>CurrentMomentum</c>.</param>
/// <param name="itsLeft">Remaining iterations; stored in <c>IterationsLeft</c>.</param>
public TrainingJob(
    NetDescription net,
    float avgErr,
    float targetErr,
    float learningRate,
    float momentum,
    int itsLeft)
{
    this.Net = net;
    this.AvgError = avgErr;
    this.TargetError = targetErr;
    this.CurrentLearningRate = learningRate;
    this.CurrentMomentum = momentum;
    this.IterationsLeft = itsLeft;
}
public void TanhNodeCanFigureOutOr()
{
    // A single tanh node with near-zero starting weights should learn logical OR:
    // inputs are encoded as 0/1 and the expected output as -1 (false) / 1 (true).
    const float targetError = 0.01f;

    var firstInput = new NodeInputDescription { FromInputVector = true, InputId = 0, Weight = -0.001f };
    var secondInput = new NodeInputDescription { FromInputVector = true, InputId = 1, Weight = 0.001f };
    var orNode = new NodeDescription
    {
        NodeId = 0,
        Weight = 0.001f,
        Aggregator = "sum",
        Processor = "tanh",
        Inputs = new[] { firstInput, secondInput }
    };
    var description = new NetDescription
    {
        Nodes = new[] { orNode },
        Outputs = new[] { 0 }
    };
    var net = Net.FromDescription(description);

    // Case order is kept as-is because the trainer may visit cases in sequence.
    var orTruthTable = new[]
    {
        Tuple.Create(new[] { 1f, 0f }, new[] { 1f }),
        Tuple.Create(new[] { 1f, 1f }, new[] { 1f }),
        Tuple.Create(new[] { 0f, 1f }, new[] { 1f }),
        Tuple.Create(new[] { 0f, 0f }, new[] { -1f })
    };

    var trainer = new SimpleTrainer();
    var finalError = trainer.Train(
        net: net,
        tests: orTruthTable,
        desiredError: targetError,
        maxEpochs: 50000,
        learningRate: 0.5f);

    Assert.IsTrue(finalError < targetError);
}
public void TanhNetCanBeTrainedOnXOr()
{
    // Network under test: two tanh hidden nodes (ids 0 and 1) read the raw input vector,
    // and a tanh output node (id 2) reads both hidden nodes. XOr is encoded with -1/1.
    var xorCases = new[]
    {
        Tuple.Create(new[] { 1f, -1f }, new[] { 1f }),
        Tuple.Create(new[] { 1f, 1f }, new[] { -1f }),
        Tuple.Create(new[] { -1f, 1f }, new[] { 1f }),
        Tuple.Create(new[] { -1f, -1f }, new[] { -1f })
    };

    var hiddenA = new NodeDescription
    {
        NodeId = 0,
        Aggregator = "sum",
        Processor = "tanh",
        Weight = 0.21f,
        Inputs = new[]
        {
            new NodeInputDescription { FromInputVector = true, InputId = 0, Weight = -0.07f },
            new NodeInputDescription { FromInputVector = true, InputId = 1, Weight = -0.28f }
        }
    };
    var hiddenB = new NodeDescription
    {
        NodeId = 1,
        Aggregator = "sum",
        Processor = "tanh",
        Weight = -0.29f,
        Inputs = new[]
        {
            new NodeInputDescription { FromInputVector = true, InputId = 0, Weight = 0.41f },
            new NodeInputDescription { FromInputVector = true, InputId = 1, Weight = -0.05f }
        }
    };
    var outputNode = new NodeDescription
    {
        NodeId = 2,
        Aggregator = "sum",
        Processor = "tanh",
        Weight = 0.11f,
        Inputs = new[]
        {
            new NodeInputDescription { FromInputVector = false, InputId = 0, Weight = -0.1f },
            new NodeInputDescription { FromInputVector = false, InputId = 1, Weight = -0.21f }
        }
    };

    var initialDescription = new NetDescription
    {
        Nodes = new[] { hiddenA, hiddenB, outputNode },
        Outputs = new[] { 2 }
    };
    var net = Net.FromDescription(initialDescription);

    // NOTE(review): FillWeights presumably overwrites the hand-set weights above
    // (scale 0.05) — confirm against WeightFiller.
    WeightFiller.FillWeights(net, 0.05f);

    var trainer = new SimpleTrainer();
    var trainedError = trainer.Train(
        net: net,
        tests: xorCases,
        desiredError: 0.001f,
        maxEpochs: 100000,
        learningRate: 5f);

    Console.WriteLine(trainedError);

    // The assertion is deliberately looser than desiredError.
    Assert.IsTrue(trainedError < 0.1f);
}
/// <summary>
/// Builds a <see cref="Net"/> from a <see cref="NetDescription"/>.
/// </summary>
/// <remarks>
/// Weight ids are assigned sequentially in description order: each node's own weight first,
/// then one id per entry in that node's <c>Inputs</c>. The input-vector size passed to the net
/// is the largest <c>InputId</c> referenced with <c>FromInputVector = true</c>, plus one (so it
/// is at least 1). A synthetic output node id, one greater than the largest node id (or 1 if no
/// node id is positive), is registered as a downstream of every node listed in <c>Outputs</c>.
/// Nodes are then ordered so that each appears after every node it reads from
/// (its <c>InputNodeNodes</c>).
/// </remarks>
/// <param name="description">
/// The net description. Its <c>Nodes</c> and <c>Outputs</c> must be non-null. A node whose
/// <c>Inputs</c> is null is treated as having no inputs.
/// </param>
/// <returns>The constructed net.</returns>
/// <exception cref="ArgumentNullException"><paramref name="description"/> is null.</exception>
/// <exception cref="ArgumentException">
/// <c>Nodes</c> or <c>Outputs</c> is null, two nodes share an id, or the node inputs form a cycle.
/// </exception>
/// <exception cref="KeyNotFoundException">
/// An output, or a node input, references a node id that is not described.
/// </exception>
public static Net FromDescription(NetDescription description)
{
    if (description == null)
    {
        throw new ArgumentNullException("description");
    }
    if (description.Nodes == null || description.Outputs == null)
    {
        throw new ArgumentException("Net description must specify both Nodes and Outputs.", "description");
    }

    int maxNodeId = 0;
    int maxInputId = 0;
    int nextWeightId = 0;
    var nodes = new Dictionary<int, Node>();

    foreach (var nodeDescription in description.Nodes)
    {
        // The node's own weight takes the next weight id; its inputs follow.
        var node = new Node(
            nodeDescription.NodeId,
            nodeDescription.Aggregator,
            nodeDescription.Processor,
            nodeDescription.Weight,
            nextWeightId++);

        // Dictionary.Add throws ArgumentException on a duplicate node id.
        nodes.Add(node.Id, node);
        if (node.Id > maxNodeId)
        {
            maxNodeId = node.Id;
        }

        if (nodeDescription.Inputs == null)
        {
            continue;
        }

        foreach (var input in nodeDescription.Inputs)
        {
            node.AddInput(new NodeInput(input.FromInputVector, input.InputId, nextWeightId++, input.Weight));
            if (input.FromInputVector && input.InputId > maxInputId)
            {
                maxInputId = input.InputId;
            }
        }
    }

    // Output nodes feed a synthetic node id that no described node can have.
    var outputNodeId = maxNodeId + 1;
    foreach (var outputId in description.Outputs)
    {
        Node outputNode;
        if (!nodes.TryGetValue(outputId, out outputNode))
        {
            throw new KeyNotFoundException("Net output references unknown node id " + outputId + ".");
        }
        outputNode.AddDownstream(outputNodeId);
    }

    // Mirror every node-to-node input as a downstream link on the source node.
    foreach (var node in nodes.Values)
    {
        foreach (var inputNodeId in node.InputNodeNodes)
        {
            Node sourceNode;
            if (!nodes.TryGetValue(inputNodeId, out sourceNode))
            {
                throw new KeyNotFoundException(
                    "Node " + node.Id + " references unknown input node id " + inputNodeId + ".");
            }
            sourceNode.AddDownstream(node.Id);
        }
    }

    // Order nodes so that each comes after all of its input nodes. Each pass appends every
    // not-yet-ordered node whose inputs are all already ordered. A pass that adds nothing
    // means the remaining nodes depend on each other.
    var orderedNodeIds = new HashSet<int>();
    var orderedNodes = new List<Node>();
    var progress = true;
    while (progress && orderedNodes.Count < nodes.Count)
    {
        progress = false;
        foreach (var node in nodes.Values)
        {
            if (orderedNodeIds.Contains(node.Id))
            {
                continue;
            }
            if (node.InputNodeNodes.Any(inNodeId => !orderedNodeIds.Contains(inNodeId)))
            {
                continue;
            }
            orderedNodeIds.Add(node.Id);
            orderedNodes.Add(node);
            progress = true;
        }
    }

    if (orderedNodes.Count < nodes.Count)
    {
        throw new ArgumentException("Circular dependency in nodes");
    }

    return new Net(maxInputId + 1, nextWeightId, orderedNodes, description.Outputs, outputNodeId);
}