/// <summary>
/// Replaces the current network (<c>_ann</c>) with a <c>NeuralNetwork</c> constructed from
/// <paramref name="fileName"/>.
/// </summary>
/// <param name="fileName">
/// Value handed to the <c>NeuralNetwork(string)</c> constructor; presumably the path of a
/// previously saved network (confirm against that constructor).
/// </param>
public void newNetworkFromFile(string fileName)
{
    // Build the replacement first; _ann is only reassigned if construction succeeds.
    var loadedNetwork = new NeuralNetwork(fileName);
    _ann = loadedNetwork;
}
示例#2
0
        /// <summary>
        /// Loads the shared <c>network</c> from "NetworkSave01.json" next to the entry assembly (or
        /// creates a new one when no save file exists), trains it on randomly drawn batches from the
        /// JSON sample database, and writes the trained network back to the same save file.
        /// </summary>
        /// <param name="inputSize">
        /// Input size for a new network; when a saved network is loaded it must equal the saved
        /// network's <c>inputSize</c>.
        /// </param>
        /// <param name="width">
        /// Second constructor argument for a new network; must equal <c>hiddenDimension</c> of a
        /// loaded network.
        /// </param>
        /// <param name="height">
        /// Third constructor argument for a new network; must equal <c>hiddenSize</c> of a loaded
        /// network.
        /// </param>
        /// <param name="itterations">Number of batches passed to <c>TrainNetwork</c>.</param>
        /// <exception cref="InvalidOperationException">
        /// The saved network's dimensions do not match the arguments, or the database holds no samples.
        /// </exception>
        /// <exception cref="FileNotFoundException">The sample database file does not exist.</exception>
        static void NetworkTraining(int inputSize, int width, int height, int itterations)
        {
            // Number of samples handed to TrainNetwork per iteration.
            const int batchSize = 10;

            int outputSize = DataSample.labelSize;

            string databasePath1 = @"C:/Users/Mikkel/Documents/MTA18434_SemesterProject/ML_Sound_Samples/Assets/Resources/SampleDatabase/Database.json";

            string networkPath1 = Path.GetDirectoryName(Assembly.GetEntryAssembly().Location) + "/NetworkSave01.json";

            if (File.Exists(networkPath1))
            {
                network = NeuralNetwork.LoadNetwork(networkPath1);

                // A saved network can only be trained further if it has the requested shape.
                if (network.inputSize != inputSize)
                {
                    throw new InvalidOperationException(
                        "Saved network inputSize (" + network.inputSize + ") does not match requested inputSize (" + inputSize + ").");
                }

                if (network.hiddenDimension != width)
                {
                    throw new InvalidOperationException(
                        "Saved network hiddenDimension (" + network.hiddenDimension + ") does not match requested width (" + width + ").");
                }

                if (network.hiddenSize != height)
                {
                    throw new InvalidOperationException(
                        "Saved network hiddenSize (" + network.hiddenSize + ") does not match requested height (" + height + ").");
                }
            }
            else
            {
                Console.WriteLine("Initializing NN");
                network = new NeuralNetwork(inputSize, width, height, outputSize);
                Console.WriteLine("NN Initialized");
            }

            if (!File.Exists(databasePath1))
            {
                Console.WriteLine("File does not exist.");
                throw new FileNotFoundException("Sample database not found.", databasePath1);
            }

            Console.WriteLine("Database exists.");

            SampleDatabase temp = null;

            // Read from the same path that was just checked rather than a second copy of the literal.
            using (StreamReader r = new StreamReader(databasePath1))
            using (JsonReader reader = new JsonTextReader(r))
            {
                JsonSerializer serializer = new JsonSerializer();
                Console.WriteLine("Deserializing");
                temp = serializer.Deserialize<SampleDatabase>(reader);
            }

            // Fail with a clear message instead of a NullReference/IndexOutOfRange further down.
            if (temp == null || temp.database == null || temp.database.Length == 0)
            {
                throw new InvalidOperationException("Sample database '" + databasePath1 + "' contains no samples.");
            }

            Console.WriteLine(temp.database[0].data.Length);

            // The batch array is reused across iterations; every slot is overwritten each time.
            DataSample[] trainingSamples = new DataSample[batchSize];
            Random       rand            = new Random();

            for (int i = 0; i < itterations; i++)
            {
                Console.WriteLine("Progress: " + i + " / " + itterations);

                // Draw the batch uniformly at random, with replacement.
                for (int j = 0; j < batchSize; j++)
                {
                    int num = rand.Next(0, temp.database.Length);
                    trainingSamples[j] = new DataSample(temp.database[num].data, temp.database[num].label);
                }

                network.TrainNetwork(trainingSamples);
            }

            Console.WriteLine("Saving network");
            network.SaveNetwork(networkPath1);
            Console.WriteLine("Network saved");
        }
 /// <summary>
 /// Overwrites the first and last entries of <paramref name="nodes"/> with <c>_numInputs</c> and
 /// <c>_numOutputs</c>, then replaces <c>_ann</c> with a new <c>NeuralNetwork</c> built from the
 /// resulting array.
 /// </summary>
 /// <param name="nodes">
 /// Array passed to the network constructor; modified in place, so the caller's array changes too.
 /// </param>
 /// <param name="eta">Learning rate passed to the network constructor.</param>
 /// <param name="alpha">Momentum passed to the network constructor.</param>
 protected void newNetwork(int[] nodes, double eta, double alpha)
 {
     int lastIndex = nodes.Length - 1;

     // Input is written before output: for a one-element array the output count wins.
     nodes[0] = _numInputs;
     nodes[lastIndex] = _numOutputs;

     _ann = new NeuralNetwork(nodes, eta, alpha);
 }
 /// <summary>
 /// Allocates the input, hidden and output weight arrays with dimensions taken from
 /// <paramref name="network"/>. All elements start at the default value 0.
 /// </summary>
 /// <param name="network">
 /// Network whose <c>inputSize</c>, <c>hiddenSize</c>, <c>hiddenDimension</c> and
 /// <c>outputSize</c> determine the array dimensions.
 /// </param>
 public GradientWeightVector(NeuralNetwork network)
 {
     var inputs  = network.inputSize;
     var hidden  = network.hiddenSize;
     var outputs = network.outputSize;

     // The hidden array's first dimension is one less than hiddenDimension.
     var hiddenDepth = network.hiddenDimension - 1;

     inputLayerWeights  = new float[inputs, hidden];
     hiddenLayerWeights = new float[hiddenDepth, hidden, hidden];
     outputLayerWeights = new float[outputs, hidden];
 }