/// <summary>
/// Trains one XOR network for three RPROP iterations without interruption, and a
/// second XOR network for two iterations, after which the trainer is paused and a
/// brand-new trainer is resumed from the returned continuation for a third
/// iteration. The flat weights of both networks must then agree to within 0.0001.
/// </summary>
/// <remarks>
/// NOTE(review): the comparison only makes sense if XOR.CreateUnTrainedXOR()
/// yields identically initialised networks on every call - confirm against the helper.
/// </remarks>
public void TestRPROPCont()
{
    const int warmUpIterations = 2;
    const double tolerance = 0.0001;

    IMLDataSet xorData = XOR.CreateXORDataSet();

    // Creation order is kept: reference network first, then the one that gets paused.
    BasicNetwork referenceNet = XOR.CreateUnTrainedXOR();
    BasicNetwork pausedNet = XOR.CreateUnTrainedXOR();

    var referenceTrainer = new ResilientPropagation(referenceNet, xorData);
    var pausedTrainer = new ResilientPropagation(pausedNet, xorData);

    for (int pass = 0; pass < warmUpIterations; pass++)
    {
        referenceTrainer.Iteration();
    }

    for (int pass = 0; pass < warmUpIterations; pass++)
    {
        pausedTrainer.Iteration();
    }

    // Capture the trainer state and hand it to a fresh trainer bound to the same network.
    TrainingContinuation snapshot = pausedTrainer.Pause();
    var resumedTrainer = new ResilientPropagation(pausedNet, xorData);
    resumedTrainer.Resume(snapshot);

    // One further step on each side: uninterrupted vs. resumed.
    referenceTrainer.Iteration();
    resumedTrainer.Iteration();

    int index = 0;
    while (index < referenceNet.Flat.Weights.Length)
    {
        Assert.AreEqual(referenceNet.Flat.Weights[index], pausedNet.Flat.Weights[index], tolerance);
        index++;
    }
}
/// <summary>
/// Same scenario as a plain pause/resume check, except that the continuation
/// returned by Pause() is first written to EG_FILENAME through
/// EncogDirectoryPersistence and read back, and the reloaded object is what the
/// new trainer resumes from. After one more iteration on each side, the flat
/// weights of the uninterrupted and the resumed network must agree to within 0.0001.
/// </summary>
/// <remarks>
/// NOTE(review): relies on XOR.CreateUnTrainedXOR() producing identically
/// initialised networks on every call - confirm against the helper.
/// The file written to EG_FILENAME is not removed by this method.
/// </remarks>
public void TestRPROPContPersistEG()
{
    const int warmUpIterations = 2;
    const double tolerance = 0.0001;

    IMLDataSet xorData = XOR.CreateXORDataSet();

    // Creation order is kept: reference network first, then the one that gets paused.
    BasicNetwork referenceNet = XOR.CreateUnTrainedXOR();
    BasicNetwork pausedNet = XOR.CreateUnTrainedXOR();

    var referenceTrainer = new ResilientPropagation(referenceNet, xorData);
    var pausedTrainer = new ResilientPropagation(pausedNet, xorData);

    for (int pass = 0; pass < warmUpIterations; pass++)
    {
        referenceTrainer.Iteration();
    }

    for (int pass = 0; pass < warmUpIterations; pass++)
    {
        pausedTrainer.Iteration();
    }

    // Round-trip the paused state through the persistence layer before resuming.
    TrainingContinuation snapshot = pausedTrainer.Pause();
    EncogDirectoryPersistence.SaveObject(EG_FILENAME, snapshot);
    object reloaded = EncogDirectoryPersistence.LoadObject(EG_FILENAME);
    TrainingContinuation restored = (TrainingContinuation)reloaded;

    var resumedTrainer = new ResilientPropagation(pausedNet, xorData);
    resumedTrainer.Resume(restored);

    // One further step on each side: uninterrupted vs. resumed-from-disk.
    referenceTrainer.Iteration();
    resumedTrainer.Iteration();

    int index = 0;
    while (index < referenceNet.Flat.Weights.Length)
    {
        Assert.AreEqual(referenceNet.Flat.Weights[index], pausedNet.Flat.Weights[index], tolerance);
        index++;
    }
}
/// <summary>
/// Console entry point: lowers the process priority, loads the training set from
/// "dataset.egb", loads or creates a feed-forward network, optionally resumes a
/// saved RPROP continuation, runs the training console, saves the network and
/// then waits until the user presses 's'.
/// </summary>
/// <param name="args">Command-line arguments; not read by this method.</param>
/// <remarks>
/// Uses <c>networkID</c>, <c>minutes</c>, <c>ExtractTrainData</c> and
/// <c>MyTrainConsole</c>, all of which are declared outside this method.
/// </remarks>
static void Main(string[] args)
{
    // Run at idle priority so training does not compete with other processes.
    using (Process self = Process.GetCurrentProcess())
    {
        self.PriorityClass = ProcessPriorityClass.Idle;
    }

    var datasetInfo = new FileInfo("dataset.egb");
    var networkInfo = new FileInfo($"network{networkID}.nn");
    var continuationInfo = new FileInfo($"train{networkID}.tr");

    Console.WriteLine("Loading dataset.");

    if (!datasetInfo.Exists)
    {
        // No dataset on disk yet: build it and exit without training.
        // The message below is written only after ExtractTrainData has returned.
        ExtractTrainData(datasetInfo);
        Console.WriteLine(@"Extracting dataset from database: " + datasetInfo);
        return;
    }

    var samples = EncogUtility.LoadEGB2Memory(datasetInfo);
    Console.WriteLine($"Loaded {samples.Count} samples. Input size: {samples.InputSize}, Output size: {samples.IdealSize}");

    BasicNetwork model;
    if (!networkInfo.Exists)
    {
        // Fresh network with hidden layers of 1000 and 200 units, then reset.
        Console.WriteLine("Creating NN.");
        model = EncogUtility.SimpleFeedForward(samples.InputSize, 1000, 200, samples.IdealSize, true);
        model.Reset();
    }
    else
    {
        Console.WriteLine($"Loading network {networkInfo.FullName}");
        model = (BasicNetwork)EncogDirectoryPersistence.LoadObject(networkInfo);
    }

    using (Process self = Process.GetCurrentProcess())
    {
        // Working set in bytes, reduced to whole megabytes by integer division.
        Console.WriteLine($"RAM usage: {self.WorkingSet64 / 1024 / 1024} MB.");
    }

    var trainer = new ResilientPropagation(model, samples);
    // NOTE(review): 0 presumably lets Encog choose the thread count - confirm.
    trainer.ThreadCount = 0;

    if (continuationInfo.Exists)
    {
        // Pick up the saved trainer state from a previous run.
        object saved = EncogDirectoryPersistence.LoadObject(continuationInfo);
        trainer.Resume((TrainingContinuation)saved);
    }

    MyTrainConsole(trainer, model, samples, minutes, networkInfo, continuationInfo);

    Console.WriteLine(@"Final Error: " + trainer.Error);
    Console.WriteLine(@"Training complete, saving network.");
    EncogDirectoryPersistence.SaveObject(networkInfo, model);
    Console.WriteLine(@"Network saved. Press s to stop.");

    // Block until a lowercase 's' is typed.
    while (Console.ReadKey().KeyChar != 's')
    {
    }
}