/// <summary>
/// Verify that pausing RPROP training and resuming from the continuation
/// object produces the same weights as uninterrupted training.
/// </summary>
public void TestRPROPCont()
{
    IMLDataSet trainingSet = XOR.CreateXORDataSet();
    BasicNetwork net1 = XOR.CreateUnTrainedXOR();
    BasicNetwork net2 = XOR.CreateUnTrainedXOR();

    var trainerA = new ResilientPropagation(net1, trainingSet);
    var trainerB = new ResilientPropagation(net2, trainingSet);

    // Two iterations on each network before the pause point.
    trainerA.Iteration();
    trainerA.Iteration();
    trainerB.Iteration();
    trainerB.Iteration();

    // Pause the second trainer and resume in a brand-new trainer instance.
    TrainingContinuation cont = trainerB.Pause();
    var resumed = new ResilientPropagation(net2, trainingSet);
    resumed.Resume(cont);

    // One more iteration on each path; the weights should now agree.
    trainerA.Iteration();
    resumed.Iteration();

    for (int i = 0; i < net1.Flat.Weights.Length; i++)
    {
        Assert.AreEqual(net1.Flat.Weights[i], net2.Flat.Weights[i], 0.0001);
    }
}
/// <summary>
/// Train on the console until the minute budget is used up, training
/// completes, or a console key is pressed.  After every iteration the
/// network and a pause/resume continuation snapshot are written to disk,
/// so an interrupted run can be resumed later.
/// </summary>
/// <param name="train">The training method to iterate.</param>
/// <param name="network">The network being trained; saved each iteration.</param>
/// <param name="trainingSet">Training data (not read directly in this method).</param>
/// <param name="minutes">Wall-clock training budget, in minutes.</param>
/// <param name="networkFile">Destination file for the network.</param>
/// <param name="trainFile">Destination file for the continuation state.</param>
public static void MyTrainConsole(IMLTrain train, BasicNetwork network, IMLDataSet trainingSet, int minutes, FileInfo networkFile, FileInfo trainFile)
{
    int epoch = 1;
    long remaining;

    Console.WriteLine(@"Beginning training...");
    // NOTE(review): Environment.TickCount is a 32-bit millisecond counter that
    // wraps after ~24.9 days of system uptime; very long sessions could
    // misreport elapsed time — confirm acceptable for this tool.
    long start = Environment.TickCount;
    do
    {
        train.Iteration();

        long current = Environment.TickCount;
        long elapsed = (current - start) / 1000;  // seconds since start
        remaining = minutes - elapsed / 60;       // whole minutes left in the budget
        Console.WriteLine($@"Iteration #{Format.FormatInteger(epoch)} Error:{Format.FormatPercent(train.Error)} elapsed time = {Format.FormatTimeSpan((int)elapsed)} time left = {Format.FormatTimeSpan((int)remaining * 60)}");
        epoch++;

        // Persist the network and a continuation snapshot every iteration so
        // the session can be recovered after an interruption.
        EncogDirectoryPersistence.SaveObject(networkFile, network);
        TrainingContinuation cont = train.Pause();
        EncogDirectoryPersistence.SaveObject(trainFile, cont);
        train.Resume(cont);

        // Dump the mean of each continuation array as a rough progress signal.
        foreach (var x in cont.Contents)
        {
            Console.WriteLine($"{x.Key}: {((double[])x.Value).Average()}");
        }
    } while (remaining > 0 && !train.TrainingDone && !Console.KeyAvailable);
    Console.WriteLine("Finishing.");
    train.FinishTraining();
}
/// <summary>
/// Pause the training.
/// </summary>
/// <returns>A training continuation object to continue with.</returns>
public override TrainingContinuation Pause()
{
    TrainingContinuation result = new TrainingContinuation();

    if (this.FlatTraining is TrainFlatNetworkResilient)
    {
        result[ResilientPropagation.LAST_GRADIENTS] =
            ((TrainFlatNetworkResilient)this.FlatTraining)
            .LastGradient;
        result[ResilientPropagation.UPDATE_VALUES] =
            ((TrainFlatNetworkResilient)this.FlatTraining)
            .UpdateValues;
    }
#if !SILVERLIGHT
    // Guard the cast: the original unconditional 'else' would throw an
    // InvalidCastException if FlatTraining were some other implementation.
    // The matching Resume method already guards with 'else if', so this
    // makes Pause consistent with it.
    else if (this.FlatTraining is TrainFlatNetworkOpenCL)
    {
        result[ResilientPropagation.LAST_GRADIENTS] =
            ((TrainFlatNetworkOpenCL)this.FlatTraining)
            .LastGradient;
        result[ResilientPropagation.UPDATE_VALUES] =
            ((TrainFlatNetworkOpenCL)this.FlatTraining)
            .UpdateValues;
    }
#endif
    return (result);
}
/// <summary>
/// Verify that an RPROP continuation object survives a persistence
/// round-trip: the saved-then-reloaded state must resume to the same
/// weights as uninterrupted training.
/// </summary>
public void TestRPROPContPersistEG()
{
    IMLDataSet trainingSet = XOR.CreateXORDataSet();
    BasicNetwork net1 = XOR.CreateUnTrainedXOR();
    BasicNetwork net2 = XOR.CreateUnTrainedXOR();

    var trainerA = new ResilientPropagation(net1, trainingSet);
    var trainerB = new ResilientPropagation(net2, trainingSet);

    // Two iterations on each network before the pause point.
    trainerA.Iteration();
    trainerA.Iteration();
    trainerB.Iteration();
    trainerB.Iteration();

    // Pause, persist the continuation to disk, then reload it.
    TrainingContinuation paused = trainerB.Pause();
    EncogDirectoryPersistence.SaveObject(EG_FILENAME, paused);
    TrainingContinuation reloaded = (TrainingContinuation)EncogDirectoryPersistence.LoadObject(EG_FILENAME);

    var resumed = new ResilientPropagation(net2, trainingSet);
    resumed.Resume(reloaded);

    // One more iteration on each path; the weights should now agree.
    trainerA.Iteration();
    resumed.Iteration();

    for (int i = 0; i < net1.Flat.Weights.Length; i++)
    {
        Assert.AreEqual(net1.Flat.Weights[i], net2.Flat.Weights[i], 0.0001);
    }
}
/// <summary>
/// Resume training from a previously paused state.
/// </summary>
/// <param name="state">The training state to return to.</param>
public override void Resume(TrainingContinuation state)
{
    if (!IsValidResume(state))
    {
        throw new TrainingError("Invalid training resume data length");
    }

    var gradients = (double[])state[ResilientPropagation.LAST_GRADIENTS];
    var updates = (double[])state[ResilientPropagation.UPDATE_VALUES];

    // Copy the saved arrays into whichever flat-training implementation is active.
    if (this.FlatTraining is TrainFlatNetworkResilient)
    {
        var resilient = (TrainFlatNetworkResilient)this.FlatTraining;
        EngineArray.ArrayCopy(gradients, resilient.LastGradient);
        EngineArray.ArrayCopy(updates, resilient.UpdateValues);
    }
#if !SILVERLIGHT
    else if (this.FlatTraining is TrainFlatNetworkOpenCL)
    {
        var openCl = (TrainFlatNetworkOpenCL)this.FlatTraining;
        EngineArray.ArrayCopy(gradients, openCl.LastGradient);
        EngineArray.ArrayCopy(updates, openCl.UpdateValues);
    }
#endif
}
/// <summary>
/// Load the object.
/// </summary>
/// <param name="xmlin">The XML object to load from.</param>
/// <returns>The loaded object.</returns>
public IEncogPersistedObject Load(ReadXML xmlin)
{
    this.current = new TrainingContinuation();

    // The name/description attributes live on the tag the reader just consumed.
    String name = xmlin.LastTag.Attributes[
        EncogPersistedCollection.ATTRIBUTE_NAME];
    String description = xmlin.LastTag.Attributes[
        EncogPersistedCollection.ATTRIBUTE_DESCRIPTION];
    this.current.Name = name;
    this.current.Description = description;

    // Walk tags until the closing training-continuation element, delegating
    // each <items> section to HandleItems.
    while (xmlin.ReadToTag())
    {
        if (xmlin.IsIt(TrainingContinuationPersistor.TAG_ITEMS, true))
        {
            HandleItems(xmlin);
        }
        else if (xmlin.IsIt(
            EncogPersistedCollection.TYPE_TRAINING_CONTINUATION, false))
        {
            break;
        }
    }

    return (this.current);
}
/// <summary>
/// Entry point: loads (or extracts) the training data, loads or creates the
/// network, optionally resumes prior training state from disk, then trains
/// on the console and saves the result.
/// </summary>
static void Main(string[] args)
{
    // Run at idle priority so long training sessions do not starve the machine.
    using (var p = Process.GetCurrentProcess()) p.PriorityClass = ProcessPriorityClass.Idle;

    FileInfo dataSetFile = new FileInfo("dataset.egb");
    // NOTE(review): networkID and minutes are not defined in this method —
    // presumably fields/constants declared elsewhere in this class; verify.
    FileInfo networkFile = new FileInfo($"network{networkID}.nn");
    FileInfo trainFile = new FileInfo($"train{networkID}.tr");

    Console.WriteLine("Loading dataset.");
    if (!dataSetFile.Exists)
    {
        // First run: build the dataset file, then exit; a subsequent run trains on it.
        ExtractTrainData(dataSetFile);
        Console.WriteLine(@"Extracting dataset from database: " + dataSetFile);
        return;
    }
    var trainingSet = EncogUtility.LoadEGB2Memory(dataSetFile);
    Console.WriteLine($"Loaded {trainingSet.Count} samples. Input size: {trainingSet.InputSize}, Output size: {trainingSet.IdealSize}");

    // Load an existing network if present, otherwise build a fresh one.
    BasicNetwork network;
    if (networkFile.Exists)
    {
        Console.WriteLine($"Loading network {networkFile.FullName}");
        network = (BasicNetwork)EncogDirectoryPersistence.LoadObject(networkFile);
    }
    else
    {
        Console.WriteLine("Creating NN.");
        // Feed-forward net with two hidden layers (1000 and 200 neurons).
        // NOTE(review): meaning of the final 'true' flag inferred from
        // EncogUtility.SimpleFeedForward's signature — confirm against its docs.
        network = EncogUtility.SimpleFeedForward(trainingSet.InputSize, 1000, 200, trainingSet.IdealSize, true);
        network.Reset();
    }

    using (var p = Process.GetCurrentProcess()) Console.WriteLine($"RAM usage: {p.WorkingSet64 / 1024 / 1024} MB.");

    // ThreadCount = 0 lets the trainer choose its own degree of parallelism.
    ResilientPropagation train = new ResilientPropagation(network, trainingSet) { ThreadCount = 0 };

    // Resume from a previous session if a continuation file exists.
    if (trainFile.Exists)
    {
        TrainingContinuation cont = (TrainingContinuation)EncogDirectoryPersistence.LoadObject(trainFile);
        train.Resume(cont);
    }

    MyTrainConsole(train, network, trainingSet, minutes, networkFile, trainFile);
    Console.WriteLine(@"Final Error: " + train.Error);
    Console.WriteLine(@"Training complete, saving network.");
    EncogDirectoryPersistence.SaveObject(networkFile, network);
    Console.WriteLine(@"Network saved. Press s to stop.");

    // Keep the console window open until the user presses 's'.
    ConsoleKeyInfo key;
    do
    {
        key = Console.ReadKey();
    } while (key.KeyChar != 's');
}
/// <summary>
/// Resume training from the supplied continuation state.
/// </summary>
///
/// <param name="state">The training state to return to.</param>
public override sealed void Resume(TrainingContinuation state)
{
    // Reject incompatible or wrongly-sized state before touching anything.
    if (!IsValidResume(state))
    {
        throw new TrainingError("Invalid training resume data length");
    }

    var savedDelta = (double[])state.Get(PropertyLastDelta);
    _lastDelta = savedDelta;
}
/// <summary>
/// Pause the training, capturing the last gradients so training can be
/// resumed later.
/// </summary>
/// <returns>A training continuation object to continue with.</returns>
public override TrainingContinuation Pause()
{
    var continuation = new TrainingContinuation();
    continuation.TrainingType = GetType().Name;
    continuation.Contents[LastGradients] = LastGradient;
    return (continuation);
}
/// <summary>
/// Pause the training, capturing the last weight deltas so training can
/// be resumed later.
/// </summary>
///
/// <returns>A training continuation object to continue with.</returns>
public override sealed TrainingContinuation Pause()
{
    var snapshot = new TrainingContinuation();
    snapshot.TrainingType = GetType().Name;
    snapshot.Set(PropertyLastDelta, _lastDelta);
    return (snapshot);
}
/// <summary>
/// Resume training from a previously paused state.
/// </summary>
/// <param name="state">The training state to return to.</param>
public override void Resume(TrainingContinuation state)
{
    if (!IsValidResume(state))
    {
        throw new TrainingError("Invalid training resume data length");
    }

    // Hand the saved delta array back to the flat trainer.
    var backFlat = (TrainFlatNetworkBackPropagation)this.FlatTraining;
    backFlat.LastDelta = (double[])state[Backpropagation.LAST_DELTA];
}
/// <summary>
/// Pause the training, capturing the last weight deltas from the flat
/// trainer so training can be resumed later.
/// </summary>
/// <returns>A training continuation object to continue with.</returns>
public override TrainingContinuation Pause()
{
    var flat = (TrainFlatNetworkBackPropagation)FlatTraining;
    TrainingContinuation result = new TrainingContinuation();
    result[Backpropagation.LAST_DELTA] = flat.LastDelta;
    return (result);
}
/// <summary>
/// Pause the training, capturing the last gradients from the QPROP flat
/// trainer so training can be resumed later.
/// </summary>
/// <returns>A training continuation object to continue with.</returns>
public override TrainingContinuation Pause()
{
    var result = new TrainingContinuation();
    result.TrainingType = GetType().Name;
    result.Contents[LastGradients] = ((TrainFlatNetworkQPROP)FlatTraining).LastGradient;
    return (result);
}
/// <summary>
/// Pause the training, capturing the last gradients and update values so
/// training can be resumed later.
/// </summary>
///
/// <returns>A training continuation object to continue with.</returns>
public override sealed TrainingContinuation Pause()
{
    var snapshot = new TrainingContinuation { TrainingType = GetType().Name };
    snapshot.Set(LastGradientsConst, LastGradient);
    snapshot.Set(UpdateValuesConst, _updateValues);
    return (snapshot);
}
/// <summary>
/// Determine if the specified continuation object is valid to resume with.
/// </summary>
/// <param name="state">The continuation object to check.</param>
/// <returns>True if the specified continuation object is valid for this
/// training method and network.</returns>
public override bool IsValidResume(TrainingContinuation state)
{
    if (state.Contents.ContainsKey(Backpropagation.LAST_DELTA))
    {
        // The saved delta array must match the network's weight count.
        var deltas = (double[])state[Backpropagation.LAST_DELTA];
        return (deltas.Length == Network.Structure.CalculateSize());
    }
    return (false);
}
/// <summary>
/// Pause the training, capturing the last weight deltas from the flat
/// trainer so training can be resumed later.
/// </summary>
///
/// <returns>A training continuation object to continue with.</returns>
public override sealed TrainingContinuation Pause()
{
    var flat = (TrainFlatNetworkBackPropagation)FlatTraining;
    var snapshot = new TrainingContinuation();
    snapshot.TrainingType = GetType().Name;
    snapshot.Set(PropertyLastDelta, flat.LastDelta);
    return (snapshot);
}
/// <summary>
/// Resume training from a previously paused state.
/// </summary>
/// <param name="state">The training state to return to.</param>
public override void Resume(TrainingContinuation state)
{
    if (!IsValidResume(state))
    {
        throw new TrainingError("Invalid training resume data length");
    }

    // Copy the saved gradients back into the live gradient array.
    EngineArray.ArrayCopy((double[])state.Contents[LastGradients], LastGradient);
}
/// <summary>
/// Resume training from a previously paused state.
/// </summary>
///
/// <param name="state">The training state to return to.</param>
public override sealed void Resume(TrainingContinuation state)
{
    if (!IsValidResume(state))
    {
        throw new TrainingError("Invalid training resume data length");
    }

    var savedGradients = (double[])state.Get(LastGradientsConst);
    var savedUpdates = (double[])state.Get(UpdateValuesConst);

    // Restore both RPROP arrays in place.
    EngineArray.ArrayCopy(savedGradients, LastGradient);
    EngineArray.ArrayCopy(savedUpdates, _updateValues);
}
/// <summary>
/// Pause the training, capturing the RPROP gradients and update values so
/// training can be resumed later.
/// </summary>
///
/// <returns>A training continuation object to continue with.</returns>
public override sealed TrainingContinuation Pause()
{
    var rprop = (TrainFlatNetworkResilient)FlatTraining;
    var snapshot = new TrainingContinuation { TrainingType = GetType().Name };
    snapshot.Set(LastGradients, rprop.LastGradient);
    snapshot.Set(UpdateValues, rprop.UpdateValues);
    return (snapshot);
}
/// <summary>
/// Save the object.
/// </summary>
/// <param name="obj">The object to save.</param>
/// <param name="xmlout">The XML output object.</param>
public void Save(IEncogPersistedObject obj, WriteXML xmlout)
{
    // Open the enclosing training-continuation element.
    PersistorUtil.BeginEncogObject(
        EncogPersistedCollection.TYPE_TRAINING_CONTINUATION, xmlout, obj,
        true);
    this.current = (TrainingContinuation)obj;

    // Nested <items> element holding each continuation array.
    xmlout.BeginTag(TrainingContinuationPersistor.TAG_ITEMS);
    SaveItems(xmlout);
    xmlout.EndTag();

    // Close the enclosing object element opened by BeginEncogObject.
    xmlout.EndTag();
}
/// <summary>
/// Determine if the specified continuation object is valid to resume with.
/// </summary>
/// <param name="state">The continuation object to check.</param>
/// <returns>True if the specified continuation object is valid for this
/// training method and network.</returns>
public override bool IsValidResume(TrainingContinuation state)
{
    bool hasGradients = state.Contents.ContainsKey(
        ResilientPropagation.LAST_GRADIENTS);
    bool hasUpdates = state.Contents.ContainsKey(
        ResilientPropagation.UPDATE_VALUES);

    if (!hasGradients || !hasUpdates)
    {
        return (false);
    }

    // The saved gradient array must match the network's weight count.
    var gradients = (double[])state[ResilientPropagation.LAST_GRADIENTS];
    return (gradients.Length == Network.Structure.CalculateSize());
}
/// <summary>
/// Determine if the specified continuation object is valid to resume with.
/// </summary>
/// <param name="state">The continuation object to check.</param>
/// <returns>True if the specified continuation object is valid for this
/// training method and network.</returns>
public bool IsValidResume(TrainingContinuation state)
{
    // Require the gradient entry and a matching trainer type, in that order.
    if (!state.Contents.ContainsKey(LastGradients)
        || !state.TrainingType.Equals(GetType().Name))
    {
        return (false);
    }

    var gradients = (double[])state.Contents[LastGradients];
    return (gradients.Length == ((IContainsFlat)Method).Flat.Weights.Length);
}
/// <summary>
/// Determine if the specified continuation object is valid to resume with.
/// </summary>
///
/// <param name="state">The continuation object to check.</param>
/// <returns>True if the specified continuation object is valid for this
/// training method and network.</returns>
public bool IsValidResume(TrainingContinuation state)
{
    bool hasGradients = state.Contents.ContainsKey(LastGradientsConst);
    bool hasUpdates = state.Contents.ContainsKey(UpdateValuesConst);

    if (!hasGradients || !hasUpdates)
    {
        return (false);
    }

    if (!state.TrainingType.Equals(GetType().Name))
    {
        return (false);
    }

    // The saved gradient array must match the network's weight count.
    var gradients = (double[])state.Get(LastGradientsConst);
    return (gradients.Length == Network.Flat.Weights.Length);
}
/// <summary>
/// Determine if the specified continuation object is valid to resume with.
/// </summary>
/// <param name="state">The continuation object to check.</param>
/// <returns>True if the specified continuation object is valid for this
/// training method and network.</returns>
public bool IsValidResume(TrainingContinuation state)
{
    // Require the gradient entry and a matching trainer type, in that order.
    if (!state.Contents.ContainsKey(LastGradients)
        || !state.TrainingType.Equals(GetType().Name))
    {
        return false;
    }

    var gradients = (double[]) state.Contents[LastGradients];
    return gradients.Length == ((IContainsFlat) Method).Flat.Weights.Length;
}
/// <summary>
/// Pause the training, capturing the RPROP gradients and update values so
/// training can be resumed later.
/// </summary>
///
/// <returns>A training continuation object to continue with.</returns>
public override sealed TrainingContinuation Pause()
{
    var rprop = (TrainFlatNetworkResilient) FlatTraining;
    var snapshot = new TrainingContinuation {TrainingType = GetType().Name};
    snapshot.Set(LastGradients, rprop.LastGradient);
    snapshot.Set(UpdateValues, rprop.UpdateValues);
    return snapshot;
}
/// <summary>
/// Pause the training, capturing the last weight deltas from the flat
/// trainer so training can be resumed later.
/// </summary>
/// <returns>A training continuation object to continue with.</returns>
public override TrainingContinuation Pause()
{
    var flat = (TrainFlatNetworkBackPropagation)FlatTraining;
    TrainingContinuation snapshot = new TrainingContinuation();
    snapshot[Backpropagation.LAST_DELTA] = flat.LastDelta;
    return snapshot;
}
/// <summary>
/// Pause the training, capturing the last gradients and update values so
/// training can be resumed later.
/// </summary>
///
/// <returns>A training continuation object to continue with.</returns>
public override sealed TrainingContinuation Pause()
{
    var snapshot = new TrainingContinuation {TrainingType = GetType().Name};
    snapshot.Set(LastGradientsConst, LastGradient);
    snapshot.Set(UpdateValuesConst, _updateValues);
    return snapshot;
}
/// <inheritdoc />
public void Resume(TrainingContinuation state)
{
    // This training method does not support continuation; resuming is a no-op
    // and the supplied state is ignored.
}
/// <summary>
/// This training type does not support training continue.
/// </summary>
///
/// <param name="state">Not used.</param>
public override sealed void Resume(TrainingContinuation state)
{
    // Intentionally empty: continuation is unsupported, so resuming is a no-op.
}
/// <inheritdoc/>
public override void Resume(TrainingContinuation state)
{
    // Continuation is not implemented for this trainer; the supplied state
    // is ignored and resuming is a no-op.
}
/// <summary>
/// Pause the training, capturing the last weight deltas so training can
/// be resumed later.
/// </summary>
///
/// <returns>A training continuation object to continue with.</returns>
public override sealed TrainingContinuation Pause()
{
    var snapshot = new TrainingContinuation();
    snapshot.TrainingType = GetType().Name;
    snapshot.Set(PropertyLastDelta, _lastDelta);
    return snapshot;
}
/// <summary>
/// Pause the training, capturing the last gradients from the QPROP flat
/// trainer so training can be resumed later.
/// </summary>
/// <returns>A training continuation object to continue with.</returns>
public override TrainingContinuation Pause()
{
    var snapshot = new TrainingContinuation();
    snapshot.TrainingType = GetType().Name;
    snapshot.Contents[LastGradients] = ((TrainFlatNetworkQPROP) FlatTraining).LastGradient;
    return snapshot;
}
/// <summary>
/// Resume training from a previously paused state.
/// From Encog.ml.train.MLTrain.
/// </summary>
/// <param name="state">The training continuation state to resume from.</param>
public abstract void Resume(
    TrainingContinuation state);
/// <summary>
/// Resume training from a previously paused state.
/// </summary>
/// <param name="state">The training state to return to.</param>
public override void Resume(TrainingContinuation state)
{
    if (!IsValidResume(state))
    {
        throw new TrainingError("Invalid training resume data length");
    }

    // Copy the saved gradients into the QPROP flat trainer in place.
    var savedGradients = (double[]) state.Contents[LastGradients];
    var qprop = (TrainFlatNetworkQPROP) FlatTraining;
    EngineArray.ArrayCopy(savedGradients, qprop.LastGradient);
}
/// <summary>
/// Pause the training, capturing the last gradients so training can be
/// resumed later.
/// </summary>
/// <returns>A training continuation object to continue with.</returns>
public override TrainingContinuation Pause()
{
    var snapshot = new TrainingContinuation();
    snapshot.TrainingType = GetType().Name;
    snapshot.Contents[LastGradients] = LastGradient;
    return snapshot;
}
/// <summary>
/// Determine if the specified continuation object is valid to resume with.
/// </summary>
///
/// <param name="state">The continuation object to check.</param>
/// <returns>True if the specified continuation object is valid for this
/// training method and network.</returns>
public bool IsValidResume(TrainingContinuation state)
{
    bool hasGradients = state.Contents.ContainsKey(LastGradientsConst);
    bool hasUpdates = state.Contents.ContainsKey(UpdateValuesConst);

    if (!hasGradients || !hasUpdates)
    {
        return false;
    }

    if (!state.TrainingType.Equals(GetType().Name))
    {
        return false;
    }

    // The saved gradient array must match the network's weight count.
    var gradients = (double[]) state.Get(LastGradientsConst);
    return gradients.Length == Network.Flat.Weights.Length;
}
/// <summary>
/// Determine if the specified continuation object is valid to resume with.
/// </summary>
/// <param name="state">The continuation object to check.</param>
/// <returns>True if the specified continuation object is valid for this
/// training method and network.</returns>
public override bool IsValidResume(TrainingContinuation state)
{
    bool hasGradients = state.Contents.ContainsKey(
        ResilientPropagation.LAST_GRADIENTS);
    bool hasUpdates = state.Contents.ContainsKey(
        ResilientPropagation.UPDATE_VALUES);

    if (!hasGradients || !hasUpdates)
    {
        return false;
    }

    // The saved gradient array must match the network's weight count.
    var gradients = (double[])state[ResilientPropagation.LAST_GRADIENTS];
    return gradients.Length == Network.Structure.CalculateSize();
}
/// <summary>
/// Resume training from a previously paused state.
/// </summary>
///
/// <param name="state">The training state to return to.</param>
public override sealed void Resume(TrainingContinuation state)
{
    if (!IsValidResume(state))
    {
        throw new TrainingError("Invalid training resume data length");
    }

    var savedGradients = (double[]) state.Get(LastGradientsConst);
    var savedUpdates = (double[]) state.Get(UpdateValuesConst);

    // Restore both RPROP arrays in place.
    EngineArray.ArrayCopy(savedGradients, LastGradient);
    EngineArray.ArrayCopy(savedUpdates, _updateValues);
}
/// <summary>
/// Pause the training.
/// </summary>
/// <returns>A training continuation object to continue with.</returns>
public override TrainingContinuation Pause()
{
    TrainingContinuation result = new TrainingContinuation();

    if (this.FlatTraining is TrainFlatNetworkResilient)
    {
        result[ResilientPropagation.LAST_GRADIENTS] =
            ((TrainFlatNetworkResilient)this.FlatTraining)
            .LastGradient;
        result[ResilientPropagation.UPDATE_VALUES] =
            ((TrainFlatNetworkResilient)this.FlatTraining)
            .UpdateValues;
    }
#if !SILVERLIGHT
    // Guard the cast: the original unconditional 'else' would throw an
    // InvalidCastException if FlatTraining were some other implementation.
    // The matching Resume method already guards with 'else if', so this
    // makes Pause consistent with it.
    else if (this.FlatTraining is TrainFlatNetworkOpenCL)
    {
        result[ResilientPropagation.LAST_GRADIENTS] =
            ((TrainFlatNetworkOpenCL)this.FlatTraining)
            .LastGradient;
        result[ResilientPropagation.UPDATE_VALUES] =
            ((TrainFlatNetworkOpenCL)this.FlatTraining)
            .UpdateValues;
    }
#endif
    return result;
}
/// <summary>
/// Pause the training, capturing the last weight deltas from the flat
/// trainer so training can be resumed later.
/// </summary>
///
/// <returns>A training continuation object to continue with.</returns>
public override sealed TrainingContinuation Pause()
{
    var flat = (TrainFlatNetworkBackPropagation) FlatTraining;
    var snapshot = new TrainingContinuation();
    snapshot.TrainingType = GetType().Name;
    snapshot.Set(PropertyLastDelta, flat.LastDelta);
    return snapshot;
}
/// <summary>
/// Determine if the specified continuation object is valid to resume with.
/// </summary>
/// <param name="state">The continuation object to check.</param>
/// <returns>True if the specified continuation object is valid for this
/// training method and network.</returns>
public override bool IsValidResume(TrainingContinuation state)
{
    if (state.Contents.ContainsKey(Backpropagation.LAST_DELTA))
    {
        // The saved delta array must match the network's weight count.
        var deltas = (double[])state[Backpropagation.LAST_DELTA];
        return deltas.Length == Network.Structure.CalculateSize();
    }
    return false;
}
/// <inheritdoc/>
public override void Resume(TrainingContinuation state)
{
    // This training method does not support continuation; resuming is a no-op
    // and the supplied state is ignored.
}
/// <summary>
/// Resume is not supported by this training method.
/// </summary>
/// <param name="state">Not used.</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public override void Resume(TrainingContinuation state)
{
    throw new NotImplementedException();
}