コード例 #1
0
        /// <summary>
        ///     Reads the MNIST training and validation sets from the application's base
        ///     directory, builds a five-layer network sized from the first training sample,
        ///     and hands a configured backpropagation trainer to
        ///     PerformIterationsClassifyEarlyStop together with the validation data.
        /// </summary>
        public void Process()
        {
            Console.WriteLine("Please wait, reading MNIST training data.");
            var baseDir = AppDomain.CurrentDomain.BaseDirectory;
            var training = LearnDigitsBackprop.LoadMNIST(baseDir, true, MNIST_DEPTH);
            var validation = LearnDigitsBackprop.LoadMNIST(baseDir, false, MNIST_DEPTH);

            Console.WriteLine("Training set size: " + training.NumImages);
            Console.WriteLine("Validation set size: " + validation.NumImages);

            // The first and last layer widths are taken from the first training sample.
            var firstSample = training.Data[0];
            var inputs = firstSample.Input.Length;
            var outputs = firstSample.Ideal.Length;

            var net = new BasicNetwork();
            net.AddLayer(new BasicLayer(null, true, inputs));
            net.AddLayer(new BasicLayer(new ActivationReLU(), true, 100));
            net.AddLayer(new DropoutLayer(new ActivationReLU(), true, 50, 0.5));
            net.AddLayer(new BasicLayer(new ActivationReLU(), true, 25));
            net.AddLayer(new BasicLayer(new ActivationSoftMax(), false, outputs));
            net.FinalizeStructure();
            net.Reset();

            // Configure the trainer and run it against the validation set.
            Console.WriteLine("Training neural network.");
            var trainer = new BackPropagation(net, training.Data, 1e-4, 0.9)
            {
                L1 = 0,
                L2 = 1e-11
            };

            PerformIterationsClassifyEarlyStop(trainer, net, validation.Data, 5);
        }
コード例 #2
0
ファイル: LearnXORBackprop.cs プロジェクト: legendvijay/aifh
        /// <summary>
        ///     Entry point for the XOR example: builds a 2-5-1 sigmoid network, trains it
        ///     with backpropagation until the reported error is no longer above 0.01, and
        ///     prints the network output next to the ideal value for each XOR input.
        ///     To run it stand-alone, move it into its own project and rename it to Main.
        /// </summary>
        /// <param name="args">Not used.</param>
        public static void ExampleMain(string[] args)
        {
            var net = new BasicNetwork();
            net.AddLayer(new BasicLayer(null, true, 2));
            net.AddLayer(new BasicLayer(new ActivationSigmoid(), true, 5));
            net.AddLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
            net.FinalizeStructure();
            net.Reset();

            var samples = BasicData.ConvertArrays(XOR_INPUT, XOR_IDEAL);
            var trainer = new BackPropagation(net, samples, 0.7, 0.9);

            // Always run at least one iteration; stop once the error is not above 0.01.
            for (var epoch = 1;; epoch++)
            {
                trainer.Iteration();
                Console.WriteLine("Epoch #" + epoch + " Error:" + trainer.LastError);

                if (!(trainer.LastError > 0.01))
                {
                    break;
                }
            }

            // Show what the trained network produces for every XOR input row.
            Console.WriteLine("Neural Network Results:");
            var row = 0;
            while (row < XOR_INPUT.Length)
            {
                var actual = net.ComputeRegression(XOR_INPUT[row]);
                var inputText = string.Join(",", XOR_INPUT[row]);
                var actualText = string.Join(",", actual);
                var idealText = string.Join(",", XOR_IDEAL[row]);

                Console.WriteLine(inputText + ", actual=" + actualText + ",ideal=" + idealText);
                row++;
            }
        }
コード例 #3
0
ファイル: LearnIrisBackprop.cs プロジェクト: legendvijay/aifh
        /// <summary>
        ///     Run the Iris example: load the embedded iris.csv resource, normalize
        ///     columns 0-3, one-of-n encode column 4, train a 4-20-3 network with
        ///     backpropagation and then query it against the training data.
        ///     Prints a message and returns early if the embedded resource is missing.
        /// </summary>
        public void Process()
        {
            // read the iris data from the resources
            var assembly = Assembly.GetExecutingAssembly();
            var res = assembly.GetManifestResourceStream("AIFH_Vol3.Resources.iris.csv");

            // did we fail to read the resource
            if (res == null)
            {
                Console.WriteLine("Can't read iris data from embedded resources.");
                return;
            }

            // Load the data. The using statement guarantees the reader (and the resource
            // stream it wraps) is released even if DataSet.Load throws.
            DataSet ds;
            using (var istream = new StreamReader(res))
            {
                ds = DataSet.Load(istream);
            }

            // The following ranges are set up for the Iris data set.  If you wish to normalize
            // other files you will need to modify the function calls below.
            ds.NormalizeRange(0, 0, 1);
            ds.NormalizeRange(1, 0, 1);
            ds.NormalizeRange(2, 0, 1);
            ds.NormalizeRange(3, 0, 1);
            var species = ds.EncodeOneOfN(4);

            var trainingData = ds.ExtractSupervised(0, 4, 4, 3);

            var network = new BasicNetwork();
            network.AddLayer(new BasicLayer(null, true, 4));
            network.AddLayer(new BasicLayer(new ActivationReLU(), true, 20));
            network.AddLayer(new BasicLayer(new ActivationSoftMax(), false, 3));
            network.FinalizeStructure();
            network.Reset();

            var train = new BackPropagation(network, trainingData, 0.001, 0.9);

            PerformIterations(train, 100000, 0.02, true);
            QueryOneOfN(network, trainingData, species);
        }
コード例 #4
0
ファイル: PredictSunspots.cs プロジェクト: legendvijay/aifh
        /// <summary>
        ///     Loads the sunspot samples, builds a network whose input layer is
        ///     INPUT_WINDOW wide with a 50-neuron ReLU hidden layer and a single linear
        ///     output, trains it with backpropagation and then queries it against the
        ///     same data.
        /// </summary>
        public void Process()
        {
            IList<BasicData> samples = LoadSunspots();

            var net = new BasicNetwork();
            net.AddLayer(new BasicLayer(null, true, INPUT_WINDOW));
            net.AddLayer(new BasicLayer(new ActivationReLU(), true, 50));
            net.AddLayer(new BasicLayer(new ActivationLinear(), false, 1));
            net.FinalizeStructure();
            net.Reset();

            // BatchSize is explicitly set to 0 before training starts.
            var trainer = new BackPropagation(net, samples, 1e-9, 0.5)
            {
                BatchSize = 0
            };

            PerformIterations(trainer, 100000, 650, true);
            Query(net, samples);
        }
コード例 #5
0
        /// <summary>
        ///     Run the Auto MPG example: load the embedded auto-mpg.data.csv resource,
        ///     drop the unused columns and unknown rows, z-score normalize columns 1-5,
        ///     split the supervised data 75/25 into training and validation sets, train a
        ///     ReLU network with a linear output and query it against the validation set.
        ///     Prints a message and returns early if the embedded resource is missing.
        /// </summary>
        public void Process()
        {
            // read the auto MPG data from the resources
            var assembly = Assembly.GetExecutingAssembly();
            var res = assembly.GetManifestResourceStream("AIFH_Vol3.Resources.auto-mpg.data.csv");

            // did we fail to read the resource
            if (res == null)
            {
                Console.WriteLine("Can't read auto MPG data from embedded resources.");
                return;
            }

            // Load the data. The using statement guarantees the reader (and the resource
            // stream it wraps) is released even if DataSet.Load throws.
            DataSet ds;
            using (var istream = new StreamReader(res))
            {
                ds = DataSet.Load(istream);
            }

            // The following ranges are set up for the Auto MPG data set.  If you wish to
            // normalize other files you will need to modify the function calls below.

            // First remove some columns that we will not use (highest index first so the
            // remaining indexes are not shifted by an earlier delete):
            ds.DeleteColumn(8); // Car name
            ds.DeleteColumn(7); // Car origin
            ds.DeleteColumn(6); // Year
            ds.DeleteUnknowns();

            ds.NormalizeZScore(1);
            ds.NormalizeZScore(2);
            ds.NormalizeZScore(3);
            ds.NormalizeZScore(4);
            ds.NormalizeZScore(5);

            // NOTE(review): if ExtractSupervised takes (inputBegin, inputCount, idealBegin,
            // idealCount), column 5 is normalized above but not used as an input - confirm.
            var trainingData = ds.ExtractSupervised(1, 4, 0, 1);

            var splitList = DataUtil.Split(trainingData, 0.75);
            trainingData = splitList[0];
            var validationData = splitList[1];

            Console.WriteLine("Size of dataset: " + ds.Count);
            Console.WriteLine("Size of training set: " + trainingData.Count);
            Console.WriteLine("Size of validation set: " + validationData.Count);

            // Size the input layer from the first training sample.
            var inputCount = trainingData[0].Input.Length;

            var network = new BasicNetwork();
            network.AddLayer(new BasicLayer(null, true, inputCount));
            network.AddLayer(new BasicLayer(new ActivationReLU(), true, 50));
            network.AddLayer(new BasicLayer(new ActivationReLU(), true, 25));
            network.AddLayer(new BasicLayer(new ActivationReLU(), true, 5));
            network.AddLayer(new BasicLayer(new ActivationLinear(), false, 1));
            network.FinalizeStructure();
            network.Reset();

            var train = new BackPropagation(network, trainingData, 0.000001, 0.9);

            PerformIterationsEarlyStop(train, network, validationData, 20, new ErrorCalculationMSE());
            Query(network, validationData);
        }