/// <summary>
/// Runs one randomly chosen IO record through the RNN and displays the
/// predicted genre next to the expected genre.
/// </summary>
/// <remarks>
/// Returns without doing anything when no records are loaded or the network
/// has not been created yet.
/// </remarks>
private void TestRandomRecord()
{
    // Nothing to test without data or a network (Train uses the same kind of guard).
    if (IORecords == null || IORecords.Count == 0 || rnn == null)
    {
        return;
    }

    // Random.Next treats its upper bound as exclusive, so the bound must be
    // Count itself; Count - 1 would never select the last record.
    Random r = new Random();
    SentenceIORecord record = IORecords[r.Next(0, IORecords.Count)];

    // One row per entry of record.Inputs; in each row the column whose index
    // equals the entry's value is set to 1 and every other column to 0.
    Heuristics.Utilities.Matrices.Matrix input =
        new Heuristics.Utilities.Matrices.Matrix(record.Inputs.Count, NumberOfRecordsToKeep);
    for (int row = 0; row < input.Height; row++)
    {
        int hotColumn = record.Inputs[row];
        for (int col = 0; col < input.Width; col++)
        {
            input.SetValue(row, col, col == hotColumn ? 1 : 0);
        }
    }

    rnn.ForwardPropagation(input);

    // The prediction is the column with the largest value in the last output row.
    // NOTE(review): index 5 is kept as the starting value and is only reported when
    // no output exceeds 0 - presumably GetGenreType maps it to a fallback label; confirm.
    int lastRow = rnn.Outputs.Height - 1;
    int indexOfHighest = 5;
    double highestValue = 0;
    for (int col = 0; col < rnn.Outputs.Width; col++)
    {
        double candidate = rnn.Outputs.GetValue(lastRow, col);
        if (candidate > highestValue)
        {
            indexOfHighest = col;
            highestValue = candidate;
        }
    }

    txtBoxActual.Text = GetGenreType(indexOfHighest);
    txtBoxExpected.Text = GetGenreType(record.Output);
}
/// <summary>
/// Trains the RNN on the loaded IO records, one batch at a time, and then
/// saves the network state.
/// </summary>
/// <remarks>
/// Returns immediately when there are no records. The network is created on
/// first use with NumberOfRecordsToKeep inputs, 100 as the second constructor
/// argument and 5 outputs.
/// </remarks>
private void Train()
{
    // Nothing to train on.
    if (IORecords == null || !IORecords.Any())
    {
        return;
    }

    // Lazily create the network the first time training is requested.
    if (rnn == null)
    {
        rnn = new RecurrentNeuralNetwork(NumberOfRecordsToKeep, 100, 5);
    }

    IORecords = SplitIORecords(IORecords, PercentTrainingData * .01);

    // Records per batch, truncated toward zero; the batches together may
    // therefore cover slightly fewer records than the training share.
    int batchSize = (int)(IORecords.Count * (PercentTrainingData * .01) / NumberOfBatches);
    int recordsToTrainOn = NumberOfBatches * batchSize;

    for (int batchStart = 0; batchStart < recordsToTrainOn; batchStart += batchSize)
    {
        List<Heuristics.Utilities.Matrices.Matrix> batchInputs = new List<Heuristics.Utilities.Matrices.Matrix>();
        List<Heuristics.Utilities.Matrices.Matrix> batchTargets = new List<Heuristics.Utilities.Matrices.Matrix>();

        for (int offset = 0; offset < batchSize; offset++)
        {
            SentenceIORecord record = IORecords[offset + batchStart];

            // Both matrices get one row per entry of record.Inputs.
            Heuristics.Utilities.Matrices.Matrix inputMatrix =
                new Heuristics.Utilities.Matrices.Matrix(record.Inputs.Count, NumberOfRecordsToKeep);
            Heuristics.Utilities.Matrices.Matrix targetMatrix =
                new Heuristics.Utilities.Matrices.Matrix(record.Inputs.Count, 5);

            for (int row = 0; row < inputMatrix.Height; row++)
            {
                // Input row: 1 in the column equal to the entry's value, 0 elsewhere.
                int hotColumn = record.Inputs[row];
                for (int col = 0; col < inputMatrix.Width; col++)
                {
                    inputMatrix.SetValue(row, col, col == hotColumn ? 1 : 0);
                }

                // Target row: the record's output value is encoded the same way on every row.
                hotColumn = record.Output;
                for (int col = 0; col < targetMatrix.Width; col++)
                {
                    targetMatrix.SetValue(row, col, col == hotColumn ? 1 : 0);
                }
            }

            batchInputs.Add(inputMatrix);
            batchTargets.Add(targetMatrix);
        }

        rnn.Train(batchInputs, batchTargets, .001, NumberOfEpochs, 0);
    }

    SaveNNStates();
}
/// <summary>
/// Loads the test/training data: reads the sentences of every book, maps
/// their words to integer IDs and labels each sentence with its genre.
/// </summary>
/// <returns>
/// One record per mapped sentence, holding the sentence's IDs as inputs and
/// the genre ID (0 = Adventure, 1 = Crime Fiction, 2 = Horror,
/// 3 = Romantic Fiction, 4 = Science Fiction) as output.
/// </returns>
private List<SentenceIORecord> LoadIORecords()
{
    // Book file paths grouped by genre; a group's position in the array is its genre ID.
    string[][] booksByGenre = new string[][]
    {
        new string[]
        {
            CurrentDirectory + "\\Books\\Adventure\\Tarzan of the Apes.txt",
            CurrentDirectory + "\\Books\\Adventure\\The Lion of Petra.txt",
            CurrentDirectory + "\\Books\\Adventure\\The Scarlet Pimpernel.txt"
        },
        new string[]
        {
            CurrentDirectory + "\\Books\\Crime Fiction\\Dead Men Tell No Tales.txt",
            CurrentDirectory + "\\Books\\Crime Fiction\\Tales of Chinatown.txt",
            CurrentDirectory + "\\Books\\Crime Fiction\\The Extraordinary Adventures of Arsene Lupin.txt"
        },
        new string[]
        {
            CurrentDirectory + "\\Books\\Horror\\Ghost Stories of an Antiquary.txt",
            CurrentDirectory + "\\Books\\Horror\\Metamorphosis.txt",
            CurrentDirectory + "\\Books\\Horror\\The Wendigo.txt"
        },
        new string[]
        {
            CurrentDirectory + "\\Books\\Romantic Fiction\\Only a Girl's Love.txt",
            CurrentDirectory + "\\Books\\Romantic Fiction\\Star of India.txt",
            CurrentDirectory + "\\Books\\Romantic Fiction\\Wastralls.txt"
        },
        new string[]
        {
            CurrentDirectory + "\\Books\\Science Fiction\\Astounding Stories of Super_Science.txt",
            CurrentDirectory + "\\Books\\Science Fiction\\The Lost World.txt",
            CurrentDirectory + "\\Books\\Science Fiction\\The Sky Is Falling.txt"
        }
    };

    List<Sentence> sentences = new List<Sentence>();

    // genreEndIndex[g] is the number of sentences loaded once genre g is complete,
    // i.e. the exclusive end of genre g's range of zero-based sentence indices.
    // Storing cumulative ends avoids the per-genre subtraction that previously
    // counted earlier genres twice.
    int[] genreEndIndex = new int[booksByGenre.Length];
    for (int genre = 0; genre < booksByGenre.Length; genre++)
    {
        foreach (string bookPath in booksByGenre[genre])
        {
            sentences.AddRange(GetSentenceFromBook(bookPath));
        }
        genreEndIndex[genre] = sentences.Count;
    }

    // Remove infrequent words
    sentences = TextPreperation.RemoveInfrequentWords(sentences, NumberOfRecordsToKeep, "UNKNOWN_TOKEN");

    // Beginning and ending tokens are currently not added:
    //sentences = TextPreperation.AddBegginingAndEndTokens(sentences, "SENTENCE_START", "SENTENCE_END");

    // Map the word strings to integer values
    // NOTE(review): the genre labelling below assumes RemoveInfrequentWords and
    // MapSentences keep the number and order of sentences unchanged - confirm.
    List<MappedSentence> mappedSentences = TextPreperation.MapSentences(sentences);

    List<SentenceIORecord> records = new List<SentenceIORecord>();
    int sentenceIndex = 0;

    // For each mapped sentence create an input record
    foreach (MappedSentence mapped in mappedSentences)
    {
        // The genre is the first one whose range contains this zero-based index;
        // anything past the recorded ranges falls back to the last genre.
        int genreId = genreEndIndex.Length - 1;
        for (int genre = 0; genre < genreEndIndex.Length; genre++)
        {
            if (sentenceIndex < genreEndIndex[genre])
            {
                genreId = genre;
                break;
            }
        }

        SentenceIORecord record = new SentenceIORecord();
        record.Output = genreId;
        foreach (int id in mapped.IDs)
        {
            record.Inputs.Add(id);
        }
        records.Add(record);

        sentenceIndex++;
    }

    return records;
}