public void RunTest()
{
    // Trains a small restricted Boltzmann machine with contrastive divergence
    // on a fixed data set, then checks that the reported error decreased and
    // that two probe inputs produce opposite hidden outputs.

    // Example from Edwin Chen, Introduction to Restricted Boltzmann Machines
    // http://blog.echen.me/2011/07/18/introduction-to-restricted-boltzmann-machines/

    const int visibleCount = 6;
    const int hiddenCount = 2;
    const int iterations = 5000;

    // Computed doubles are compared with a delta rather than bit-for-bit, so
    // harmless floating-point differences do not fail the test.
    const double tolerance = 1e-10;

    double[][] inputs =
    {
        new double[] { 1, 1, 1, 0, 0, 0 },
        new double[] { 1, 0, 1, 0, 0, 0 },
        new double[] { 1, 1, 1, 0, 0, 0 },
        new double[] { 0, 0, 1, 1, 1, 0 },
        new double[] { 0, 0, 1, 1, 0, 0 },
        new double[] { 0, 0, 1, 1, 1, 0 }
    };

    BernoulliFunction activation = new BernoulliFunction();

    // Seed the shared generator before building the network; the exact error
    // values asserted below presumably depend on this seed.
    BernoulliFunction.Random = new ThreadSafeRandom(2);

    RestrictedBoltzmannMachine network =
        new RestrictedBoltzmannMachine(activation, visibleCount, hiddenCount);

    // Explicit starting weights, one row per hidden neuron.
    double[][] initialWeights =
    {
        new double[] { 0.00461421, 0.04337112, -0.10839599, -0.06234004, -0.03017057, 0.09520391 },
        new double[] { 0.08263872, -0.118437, -0.21710971, 0.02332903, 0.00953116, 0.09870652 }
    };

    for (int h = 0; h < hiddenCount; h++)
    {
        for (int v = 0; v < visibleCount; v++)
        {
            network.Hidden.Neurons[h].Weights[v] = initialWeights[h][v];
        }

        network.Hidden.Neurons[h].Threshold = 0;
    }

    for (int v = 0; v < visibleCount; v++)
    {
        network.Visible.Neurons[v].Threshold = 0;
    }

    // Must run after the hidden weights are assigned, since the visible layer
    // takes its weights from the hidden layer here.
    network.Visible.CopyReversedWeightsFrom(network.Hidden);

    ContrastiveDivergenceLearning target = new ContrastiveDivergenceLearning(network);
    target.Momentum = 0;
    target.LearningRate = 0.1;
    target.Decay = 0;

    double[] errors = new double[iterations];
    for (int i = 0; i < iterations; i++)
    {
        errors[i] = target.RunEpoch(inputs);
    }

    double startError = errors[0];
    double lastError = errors[iterations - 1];

    Assert.IsTrue(startError > lastError);

    Assert.AreEqual(9.5400234262580224, startError, tolerance);
    Assert.AreEqual(1.3364496250348414, lastError, tolerance);

    {
        double[] output = network.GenerateOutput(new double[] { 0, 0, 0, 1, 1, 0 });
        Assert.AreEqual(2, output.Length);
        Assert.AreEqual(0, output[0]);
        Assert.AreEqual(1, output[1]);
    }

    {
        double[] output = network.GenerateOutput(new double[] { 1, 1, 1, 0, 0, 0 });
        Assert.AreEqual(2, output.Length);
        Assert.AreEqual(1, output[0]);
        Assert.AreEqual(0, output[1]);
    }
}