//Train the network many times, with different initial values, evaluate them on the cross-validation data and select the best one.
//Returns the network (of those trained) that misclassified the fewest cross-validation samples.
private static ActivationNetwork trainNetworksCompeteOnCrossValidation(ActivationNetwork neuralNet, ISupervisedLearning teacher,
    double[][] input, double[][] output, double[][] crossValidationInput, char[] crossValidationDataLabels)
{
    DefaultLog.Info("Training {0} neural networks & picking the one that performs best on the cross-validation data . . .",
        NUM_NETWORKS_TO_TRAIN_FOR_CROSS_VALIDATION_COMPETITION);

    //Keep the best network seen so far in serialized form, so each candidate can reuse the
    //same ActivationNetwork object for training. using ensures the stream is disposed.
    using (MemoryStream bestNetworkStream = new MemoryStream())
    {
        uint bestNetworkNumMisclassified = uint.MaxValue; //MaxValue => no network stored yet

        for (int i = 0; i < NUM_NETWORKS_TO_TRAIN_FOR_CROSS_VALIDATION_COMPETITION; i++)
        {
            DefaultLog.Info("Training network {0}/{1}", (i + 1), NUM_NETWORKS_TO_TRAIN_FOR_CROSS_VALIDATION_COMPETITION);

            //Train a new network
            neuralNet.Randomize(); //Reset the weights to random values
            trainNetwork(neuralNet, teacher, input, output, crossValidationInput, crossValidationDataLabels);

            //Compare this new network's performance to our current best network
            NeuralNetworkEvaluator evaluator = new NeuralNetworkEvaluator(neuralNet);
            evaluator.Evaluate(crossValidationInput, crossValidationDataLabels);

            uint numMisclassified = evaluator.ConfusionMatrix.NumMisclassifications;

            if (numMisclassified < bestNetworkNumMisclassified)
            {
                //This network performed better than our current best network, make this the new best

                //Clear the Memory Stream storing the current best network
                bestNetworkStream.SetLength(0);

                //Save the network & update the best numMisclassified
                neuralNet.Save(bestNetworkStream);
                bestNetworkNumMisclassified = numMisclassified;
            }
        }

        DefaultLog.Info("Trained all networks and selected the best one");

        //Load up the network that performed best
        bestNetworkStream.Position = 0; //Read from the start of the stream
        ActivationNetwork bestNetwork = ActivationNetwork.Load(bestNetworkStream) as ActivationNetwork;

        return bestNetwork;
    }
}
//Train neuralNet on the training data, stopping when the error target is met, when performance on
//the cross-validation data plateaus or drops off (over-learning), or at MAX_LEARNING_ITERATIONS.
//Returns the network that should actually be used: neuralNet itself, or — when over-learning was
//detected — the previous network reloaded from before the performance drop.
//NOTE(fix): this previously returned void and reassigned the neuralNet parameter locally on
//over-learning; since C# passes references by value, the reinstated network never reached the
//caller. Returning it fixes that, and existing statement-style calls still compile.
private static ActivationNetwork trainNetwork(ActivationNetwork neuralNet, ISupervisedLearning teacher, double[][] input, double[][] output,
    double[][] crossValidationInput, char[] crossValidationDataLabels)
{
    //Make the network learn the data
    DefaultLog.Info("Training the neural network . . .");

    double error;

    //TODO: Store the previous NUM_ITERATIONS_EQUAL_IMPLIES_PLATEAU networks so in the event of over-learning, we can return to the best one

    //Use the cross-validation data to notice if the network starts to over-learn the data.
    //Store the previous network (before training) and check if the performance drops on the cross-validation data
    //The network to hand back to the caller; replaced by the reloaded previous network on over-learning
    ActivationNetwork trainedNetwork = neuralNet;

    using (MemoryStream prevNetworkStream = new MemoryStream())
    {
        uint prevNetworkNumMisclassified = uint.MaxValue;

        //Sliding window of the last NUM_ITERATIONS_EQUAL_IMPLIES_PLATEAU misclassification counts.
        //Initialise the queue to be full of uint.MaxValue so early iterations never look like a drop
        Queue<uint> prevNetworksNumMisclassified = new Queue<uint>(NUM_ITERATIONS_EQUAL_IMPLIES_PLATEAU);
        for (int i = 0; i < NUM_ITERATIONS_EQUAL_IMPLIES_PLATEAU; i++)
        {
            prevNetworksNumMisclassified.Enqueue(prevNetworkNumMisclassified);
        }

        int iterNum = 1;
        do
        {
            //Perform an iteration of training (calls teacher.Run() for each item in the array of inputs/outputs provided)
            error = teacher.RunEpoch(input, output);

            //Progress update
            if (iterNum % ITERATIONS_PER_PROGRESS_UPDATE == 0)
            {
                DefaultLog.Debug(String.Format("Learned for {0} iterations. \nError: {1}", iterNum, error));
            }

            //Evaluate this network on the cross-validation data
            NeuralNetworkEvaluator crossValidationEvaluator = new NeuralNetworkEvaluator(neuralNet);
            crossValidationEvaluator.Evaluate(crossValidationInput, crossValidationDataLabels);
            uint networkNumMisclassified = crossValidationEvaluator.ConfusionMatrix.NumMisclassifications;
            DefaultLog.Debug(String.Format("Network misclassified {0} / {1} on the cross-validation data set",
                networkNumMisclassified, crossValidationEvaluator.ConfusionMatrix.TotalClassifications));

            //Check if we've over-learned the data and performance on the cross-validation data has dropped off
            if (networkNumMisclassified > Stats.Mean(prevNetworksNumMisclassified)) //Use the mean of the number of misclassifications, as the actual number can move around a bit
            {
                //Cross-Validation performance has dropped, reinstate the previous network & break
                DefaultLog.Debug(String.Format("Network has started to overlearn the training data on iteration {0}. Using previous classifier.", iterNum));
                prevNetworkStream.Position = 0; //Set head to start of stream
                trainedNetwork = ActivationNetwork.Load(prevNetworkStream) as ActivationNetwork; //Read in the network
                break;
            }

            //Clear the Memory Stream storing the previous network
            prevNetworkStream.SetLength(0);

            //Store this network & the number of characters it misclassified on the cross-validation data
            neuralNet.Save(prevNetworkStream);

            //This is now the previous network, update the number it misclassified
            prevNetworkNumMisclassified = networkNumMisclassified;
            prevNetworksNumMisclassified.Dequeue();
            prevNetworksNumMisclassified.Enqueue(prevNetworkNumMisclassified);

            //Check if the performance has plateaued
            if (prevNetworksNumMisclassified.Distinct().Count() == 1) //Allow for slight movement in performance here??
            {
                //Cross-Validation performance has plateaued, use this network as the final one & break
                DefaultLog.Debug(String.Format("Network performance on cross-validation data has plateaued on iteration {0}.", iterNum));
                break;
            }

            //Check if we've performed the max number of iterations
            if (iterNum > MAX_LEARNING_ITERATIONS)
            {
                DefaultLog.Debug(String.Format("Reached the maximum number of learning iterations ({0}), with error {1}",
                    MAX_LEARNING_ITERATIONS, error));
                break;
            }
            iterNum++;
        }
        while (error > LEARNED_AT_ERROR); //Loop until we've learned the data to an acceptable error
    }

    DefaultLog.Info("Data learned to an error of {0}", error);

    return trainedNetwork;
}
//Create, train & evaluate a single-layer activation network (bipolar sigmoid activation,
//Back Propagation learning). The trained network competes against any previously saved best
//network of the same name on the cross-validation data; the winner is (re-)saved to disk and
//evaluated on the evaluation data.
//Returns the evaluator holding the best network's performance on the evaluation data.
private static NeuralNetworkEvaluator evaluateSingleLayerActivationNetworkWithSigmoidFunctionBackPropagationLearning(
    double[][] input, double[][] output, double[][] crossValidationInput, char[] crossValidationDataLabels,
    double[][] evaluationInput, char[] evaluationDataLabels, double learningRate, string networkName)
{
    //Create the neural Network
    BipolarSigmoidFunction sigmoidFunction = new BipolarSigmoidFunction(2.0f);
    ActivationNetwork neuralNet = new ActivationNetwork(sigmoidFunction, input[0].Length, ClassifierHelpers.NUM_CHAR_CLASSES);

    //Randomise the network's initial weights
    neuralNet.Randomize();

    //Create teacher that the network will use to learn the data (Back Propagation Learning technique used here)
    BackPropagationLearning teacher = new BackPropagationLearning(neuralNet);
    //Fix: use the learningRate parameter — previously the LEARNING_RATE constant was always
    //used here, silently ignoring the rate supplied by the caller
    teacher.LearningRate = learningRate;

    //Train the Network
    //trainNetwork(neuralNet, teacher, input, output, crossValidationInput, crossValidationDataLabels);
    //Train multiple networks, pick the one that performs best on the Cross-Validation data
    neuralNet = trainNetworksCompeteOnCrossValidation(neuralNet, teacher, input, output, crossValidationInput, crossValidationDataLabels);

    //Evaluate the network returned on the cross-validation data so it can be compared to the current best
    NeuralNetworkEvaluator crossValEvaluator = new NeuralNetworkEvaluator(neuralNet);
    crossValEvaluator.Evaluate(crossValidationInput, crossValidationDataLabels);

    //See if this network is better than the current best network of its type
    //Try and load a previous network of this type
    string previousNetworkPath = Program.NEURAL_NETWORKS_PATH + networkName + Program.NEURAL_NETWORK_FILE_EXTENSION;
    string networkCMPath = Program.NEURAL_NETWORKS_PATH + networkName + ".csv";
    bool newBest = false;
    ActivationNetwork bestNetwork = neuralNet;
    if (File.Exists(previousNetworkPath))
    {
        //Load the previous network & evaluate it
        ActivationNetwork previous = ActivationNetwork.Load(previousNetworkPath) as ActivationNetwork;
        NeuralNetworkEvaluator prevCrossValEval = new NeuralNetworkEvaluator(previous);
        prevCrossValEval.Evaluate(crossValidationInput, crossValidationDataLabels);

        //If this network is better than the previous best, write it out as the new best
        if (prevCrossValEval.ConfusionMatrix.NumMisclassifications > crossValEvaluator.ConfusionMatrix.NumMisclassifications)
        {
            DefaultLog.Info("New best cross-validation score for network \"{0}\". Previous was {1}/{2}, new best is {3}/{2}",
                networkName, prevCrossValEval.ConfusionMatrix.NumMisclassifications,
                prevCrossValEval.ConfusionMatrix.TotalClassifications, crossValEvaluator.ConfusionMatrix.NumMisclassifications);

            //Delete the old files
            File.Delete(previousNetworkPath);
            File.Delete(networkCMPath);

            newBest = true;
        }
        else //The previous network is still the best
        {
            DefaultLog.Info("Existing \"{0}\" network performed better than new one. New network scored {1}/{2}, existing scored {3}/{2}",
                networkName, crossValEvaluator.ConfusionMatrix.NumMisclassifications,
                crossValEvaluator.ConfusionMatrix.TotalClassifications, prevCrossValEval.ConfusionMatrix.NumMisclassifications);
            bestNetwork = previous;
        }
    }
    else //Otherwise there isn't a previous best
    {
        DefaultLog.Info("No previous best record for network \"{0}\" . . .", networkName);
        newBest = true;
    }

    //Evaluate the best system on the evaluation data
    NeuralNetworkEvaluator evaluator = new NeuralNetworkEvaluator(bestNetwork);
    evaluator.Evaluate(evaluationInput, evaluationDataLabels);

    //If there is a new best to write out
    if (newBest)
    {
        //Fix: log message previously read "net best" — typo for "new best"
        DefaultLog.Info("Writing out new best network of type\"{0}\"", networkName);
        //When newBest is set, bestNetwork is the freshly trained network — save it
        bestNetwork.Save(previousNetworkPath);

        //Write out the Confusion Matrix for the evaluation data, not cross-validation
        evaluator.ConfusionMatrix.WriteToCsv(networkCMPath);
        DefaultLog.Info("Finished writing out network \"{0}\"", networkName);
    }

    return evaluator;
}