SetCallback() public method

public void SetCallback(TrainingCallback callback, Object userData)
Parameters:
  callback: TrainingCallback
  userData: Object
Returns: void
コード例 #1
0
        /// <summary>
        /// Builds a training set and a test set from callbacks, dumps every training
        /// sample to the console, then trains a fresh network for each momentum value
        /// (starting at 0.0, stepping by 0.1 while below 0.7) and prints the value
        /// returned by <c>TestData</c> for both sets.
        /// </summary>
        static void Main()
        {
            const uint layerCount = 3;
            const uint hiddenNeuronCount = 96;
            const float targetError = 0.00007F;

            using (TrainingData trainingSet = new TrainingData())
            {
                using (TrainingData testSet = new TrainingData())
                {
                    trainingSet.CreateTrainFromCallback(374, 48, 3, TrainingDataCallback);
                    testSet.CreateTrainFromCallback(594, 48, 3, TestDataCallback);

                    // Walk the training set through the accessor indexers, printing
                    // one "Input" line and one "Output" line per sample.
                    for (int sample = 0; sample < trainingSet.TrainDataLength; sample++)
                    {
                        Console.Write("Input {0}: ", sample);
                        for (int column = 0; column < trainingSet.InputCount; column++)
                        {
                            Console.Write("{0}, ", trainingSet.InputAccessor[sample][column]);
                        }

                        Console.Write("\nOutput {0}: ", sample);
                        for (int column = 0; column < trainingSet.OutputCount; column++)
                        {
                            Console.Write("{0}, ", trainingSet.OutputAccessor[sample][column]);
                        }

                        Console.WriteLine();
                    }

                    // Momentum sweep: a new network is created (and disposed) per value.
                    float momentum = 0.0F;
                    while (momentum < 0.7F)
                    {
                        Console.WriteLine("============= momentum = {0} =============\n", momentum);

                        using (NeuralNet network = new NeuralNet(
                            NetworkType.LAYER,
                            layerCount,
                            trainingSet.InputCount,
                            hiddenNeuronCount,
                            trainingSet.OutputCount))
                        {
                            // The string is handed to SetCallback as the user-data argument.
                            network.SetCallback(TrainingCallback, "Hello!");
                            network.TrainingAlgorithm = TrainingAlgorithm.TRAIN_INCREMENTAL;
                            network.LearningMomentum = momentum;

                            network.TrainOnData(trainingSet, 20000, 500, targetError);

                            Console.WriteLine("MSE error on train data: {0}", network.TestData(trainingSet));
                            Console.WriteLine("MSE error on test data: {0}", network.TestData(testSet));
                        }

                        momentum += 0.1F;
                    }
                }
            }

            // Keep the console window open until a key is pressed.
            Console.ReadKey();
        }
コード例 #2
0
ファイル: XorSample.cs プロジェクト: joelself/FannCSharp
        /// <summary>
        /// Creates a network, trains it on the XOR training file, prints the network
        /// output for every sample next to the expected value, and finally saves the
        /// network and the training data.
        /// </summary>
        /// <remarks>
        /// All file paths are relative to the current working directory. If the
        /// training file cannot be read, an error is written to standard error and
        /// the method returns without training or saving anything.
        /// </remarks>
        static void XorTest()
        {
            Console.WriteLine("\nXOR test started.");

            const float learning_rate = 0.7f;
            const uint num_layers = 3;
            const uint num_input = 2;
            const uint num_hidden = 3;
            const uint num_output = 1;
            const float desired_error = 0.001f;
            const uint max_iterations = 300000;
            const uint iterations_between_reports = 1000;
            const string training_file = "..\\..\\..\\examples\\xor.data";

            Console.WriteLine("\nCreating network.");

            using (NeuralNet net = new NeuralNet(NetworkType.LAYER, num_layers, num_input, num_hidden, num_output))
            {
                net.LearningRate = learning_rate;

                net.ActivationSteepnessHidden = 1.0F;
                net.ActivationSteepnessOutput = 1.0F;

                net.ActivationFunctionHidden = ActivationFunction.SIGMOID_SYMMETRIC_STEPWISE;
                net.ActivationFunctionOutput = ActivationFunction.SIGMOID_SYMMETRIC_STEPWISE;

                // Output network type and parameters
                Console.Write("\nNetworkType                         :  ");
                switch (net.NetworkType)
                {
                    case NetworkType.LAYER:
                        Console.WriteLine("LAYER");
                        break;
                    case NetworkType.SHORTCUT:
                        Console.WriteLine("SHORTCUT");
                        break;
                    default:
                        Console.WriteLine("UNKNOWN");
                        break;
                }
                net.PrintParameters();

                Console.WriteLine("\nTraining network.");

                using (TrainingData data = new TrainingData())
                {
                    // Report a failed read instead of ending silently; returning here
                    // still disposes both the training data and the network.
                    if (!data.ReadTrainFromFile(training_file))
                    {
                        Console.Error.WriteLine("\nError reading training data from {0}", training_file);
                        return;
                    }

                    // Initialize and train the network with the data
                    net.InitWeights(data);

                    // "F6" rather than plain "F": the bare specifier uses the culture's
                    // default of two decimal digits, which renders 0.001 as 0.00.
                    Console.WriteLine("Max Epochs " + String.Format("{0:D}", max_iterations).PadLeft(8) + ". Desired Error: " + String.Format("{0:F6}", desired_error).PadRight(8));
                    net.SetCallback(PrintCallback, null);
                    net.TrainOnData(data, max_iterations, iterations_between_reports, desired_error);

                    Console.WriteLine("\nTesting network.");

                    for (uint i = 0; i < data.TrainDataLength; i++)
                    {
                        // Run the network on the test data
                        DataType[] calc_out = net.Run(data.Input[i]);

                        // An exact zero is printed as "0" because the signed custom
                        // format below has no section for zero.
                        Console.WriteLine("XOR test ({0}, {1}) -> {2}, should be {3}, difference = {4}",
                            data.InputAccessor[(int)i][0].ToString("+#;-#"),
                            data.InputAccessor[(int)i][1].ToString("+#;-#"),
                            calc_out[0] == 0 ? 0.ToString() : calc_out[0].ToString("+#.#####;-#.#####"),
                            data.OutputAccessor[(int)i][0].ToString("+#;-#"),
                            FannAbs(calc_out[0] - data.Output[i][0]));
                    }

                    Console.WriteLine("\nSaving network.");

                    // Save the network in floating point and fixed point; the value
                    // returned by SaveToFixed is forwarded as the decimal point used
                    // when saving the training data in fixed-point form.
                    net.Save("..\\..\\..\\examples\\xor_float.net");
                    uint decimal_point = (uint)net.SaveToFixed("..\\..\\..\\examples\\xor_fixed.net");
                    data.SaveTrainToFixed("..\\..\\..\\examples\\xor_fixed.data", decimal_point);

                    Console.WriteLine("\nXOR test completed.");
                }
            }
        }