Inheritance: IDisposable
Example #1
0
 // Adapter with a pointer-based signature: wraps the raw network and
 // training-data pointers in managed objects, resolves user_data back to the
 // object behind its GCHandle, and forwards everything to Callback.
 // Returns whatever Callback returns.
 private int InternalCallback(global::System.IntPtr netPtr, global::System.IntPtr dataPtr, uint max_epochs, uint epochs_between_reports, float desired_error, uint epochs, global::System.IntPtr user_data)
 {
     // NOTE(review): the 'false' argument presumably marks the wrapper as
     // not owning the underlying native memory - confirm against neural_net.
     neural_net rawNet = new neural_net(netPtr, false);
     NeuralNet wrappedNet = new NeuralNet(rawNet);

     // Same flag for the training data wrapper.
     training_data rawData = new training_data(dataPtr, false);
     TrainingData wrappedData = new TrainingData(rawData);

     // user_data carries a GCHandle in pointer form; Target is the managed
     // object the handle refers to.
     GCHandle userDataHandle = (GCHandle)user_data;
     object userState = userDataHandle.Target;

     return Callback(wrappedNet, wrappedData, max_epochs, epochs_between_reports, desired_error, epochs, userState);
 }
Example #2
0
        /* Constructor: NeuralNet

            Creates a copy of the other NeuralNet.

            Parameters:
                other - The NeuralNet to copy. Must not be null.

            Throws:
                ArgumentNullException - if other is null.
        */
        public NeuralNet(NeuralNet other)
        {
            // Fail fast with a descriptive exception; without this guard a null
            // argument surfaces as a bare NullReferenceException on other.Net.
            if (other == null)
            {
                throw new global::System.ArgumentNullException("other");
            }

            // Build the new underlying network from the other network's
            // to_fann() result, then carry over its Outputs value.
            net = new neural_net(other.Net.to_fann());
            Outputs = other.Outputs;
        }
Example #3
0
        // Cascade-training sample: loads the parity8 train/test sets, configures
        // a SHORTCUT network, runs cascade training, reports the error and
        // bit-fail figures, lists every training sample whose output lands on the
        // wrong side of zero, and saves the network before and after training.
        static void Main()
        {
            const float targetError = 0.0F;
            uint neuronLimit = 30;
            uint reportInterval = 1;
            TrainingAlgorithm algorithm = TrainingAlgorithm.TRAIN_RPROP;

            // Toggle kept from the original sample: when false, a single
            // cascade activation function/steepness pair is configured below.
            bool useLibraryCascadeDefaults = false;

            Console.WriteLine("Reading data.");

            using (TrainingData trainData = new TrainingData("..\\..\\..\\datasets\\parity8.train"))
            using (TrainingData testData = new TrainingData("..\\..\\..\\datasets\\parity8.test"))
            {
                trainData.ScaleTrainData(-1, 1);
                testData.ScaleTrainData(-1, 1);

                Console.WriteLine("Creating network.");

                using (NeuralNet net = new NeuralNet(NetworkType.SHORTCUT, 2, trainData.InputCount, trainData.OutputCount))
                {
                    net.TrainingAlgorithm = algorithm;
                    net.ActivationFunctionHidden = ActivationFunction.SIGMOID_SYMMETRIC;
                    net.ActivationFunctionOutput = ActivationFunction.LINEAR;
                    net.TrainErrorFunction = ErrorFunction.ERRORFUNC_LINEAR;

                    if (!useLibraryCascadeDefaults)
                    {
                        // One steepness value and one activation function for
                        // the cascade candidates, with 8 candidate groups.
                        DataType[] candidateSteepnesses = { 1 };
                        net.CascadeActivationSteepnesses = candidateSteepnesses;

                        ActivationFunction[] candidateFunctions = { ActivationFunction.SIGMOID_SYMMETRIC };
                        net.CascadeActivationFunctions = candidateFunctions;

                        net.CascadeCandidateGroupsCount = 8;
                    }

                    // Extra setup applied only when quickprop is the selected algorithm.
                    if (algorithm == TrainingAlgorithm.TRAIN_QUICKPROP)
                    {
                        net.LearningRate = 0.35F;
                        net.RandomizeWeights(-2.0F, 2.0F);
                    }

                    net.BitFailLimit = (DataType)0.9;
                    net.TrainStopFunction = StopFunction.STOPFUNC_BIT;
                    net.PrintParameters();

                    // Snapshot of the configured network, written before training starts.
                    net.Save("..\\..\\..\\examples\\cascade_train2.net");

                    Console.WriteLine("Training network.");

                    net.CascadetrainOnData(trainData, neuronLimit, reportInterval, targetError);

                    net.PrintConnections();

                    // BitFail is read right after each TestData call, so each
                    // reading belongs to the data set that was just tested.
                    float trainError = net.TestData(trainData);
                    uint trainBitFail = net.BitFail;
                    float testError = net.TestData(testData);
                    uint testBitFail = net.BitFail;

                    Console.WriteLine("\nTrain error: {0}, Train bit-fail: {1}, Test error: {2}, Test bit-fail: {3}\n",
                                      trainError, trainBitFail, testError, testBitFail);

                    for (uint sample = 0; sample < trainData.TrainDataLength; sample++)
                    {
                        DataType[] actual = net.Run(trainData.GetTrainInput(sample));

                        // A sample is flagged when the expected and actual first
                        // outputs sit on opposite sides of zero; a value of
                        // exactly zero counts on either side.
                        bool signMismatch =
                            (trainData.GetTrainOutput(sample)[0] >= 0 && actual[0] <= 0) ||
                            (trainData.GetTrainOutput(sample)[0] <= 0 && actual[0] >= 0);

                        if (signMismatch)
                        {
                            Console.WriteLine("ERROR: {0} does not match {1}", trainData.GetTrainOutput(sample)[0], actual[0]);
                        }
                    }

                    Console.WriteLine("Saving network.");
                    net.Save("..\\..\\..\\examples\\cascade_train.net");

                    // Keep the console window open until a key is pressed.
                    Console.ReadKey();
                }
            }
        }