/// <summary>
/// Creates a building state instance initialized from the given values.
/// </summary>
/// <param name="networkName">Name of the network being built</param>
/// <param name="binBorder">Numeric border used to decide binary output when the whole network output is binary: an output GE the border is taken as a 1, an output LT the border as a 0.</param>
/// <param name="foldNum">Current fold number</param>
/// <param name="numOfFolds">Total number of folds</param>
/// <param name="foldNetworkNum">Current network number within the fold</param>
/// <param name="numOfFoldNetworks">Total number of networks within the fold</param>
/// <param name="regrAttemptNumber">Current regression attempt number</param>
/// <param name="regrMaxAttempts">Maximum number of regression attempts</param>
/// <param name="epoch">Current epoch number within the current regression attempt</param>
/// <param name="maxEpochs">Maximum number of epochs</param>
/// <param name="currNetwork">Current network together with its error statistics</param>
/// <param name="bestNetwork">Best network found so far together with its error statistics</param>
/// <param name="lastImprovementEpoch">Epoch in which the last improvement was recorded</param>
public BuildingState(string networkName,
                     double binBorder,
                     int foldNum,
                     int numOfFolds,
                     int foldNetworkNum,
                     int numOfFoldNetworks,
                     int regrAttemptNumber,
                     int regrMaxAttempts,
                     int epoch,
                     int maxEpochs,
                     TrainedNetwork currNetwork,
                     TrainedNetwork bestNetwork,
                     int lastImprovementEpoch
                     )
{
    //Identification
    this.NetworkName = networkName;
    this.BinBorder = binBorder;
    //Position within the folds
    this.FoldNum = foldNum;
    this.NumOfFolds = numOfFolds;
    this.FoldNetworkNum = foldNetworkNum;
    this.NumOfFoldNetworks = numOfFoldNetworks;
    //Training progress
    this.RegrAttemptNumber = regrAttemptNumber;
    this.RegrMaxAttempts = regrMaxAttempts;
    this.Epoch = epoch;
    this.MaxEpochs = maxEpochs;
    //Networks and improvement tracking
    this.CurrNetwork = currNetwork;
    this.BestNetwork = bestNetwork;
    this.LastImprovementEpoch = lastImprovementEpoch;
}
//Static methods
/// <summary>
/// Default decision whether the candidate network achieved a better result than the best network found so far.
/// </summary>
/// <remarks>
/// Non-binary output: only the combined precision error is compared, and the lower one wins.
/// Binary output: the criteria below are checked in order. At each level the lower value wins
/// immediately, and the higher value loses immediately. Otherwise the next level decides:
/// 1) combined binary error, 2) testing BinValErrStat[0].Sum, 3) training BinValErrStat[0].Sum,
/// 4) combined precision error.
/// </remarks>
/// <param name="binaryOutput">Indicates the whole network output is binary</param>
/// <param name="candidate">Network to be evaluated</param>
/// <param name="currentBest">The best network so far</param>
/// <returns>True when the candidate is strictly better than currentBest</returns>
public static bool IsBetter(bool binaryOutput, TrainedNetwork candidate, TrainedNetwork currentBest)
{
    //Non-binary output: precision error is the only criterion
    if (!binaryOutput)
    {
        return candidate.CombinedPrecisionError < currentBest.CombinedPrecisionError;
    }
    //Level 1: combined binary error
    if (candidate.CombinedBinaryError > currentBest.CombinedBinaryError)
    {
        return false;
    }
    if (candidate.CombinedBinaryError < currentBest.CombinedBinaryError)
    {
        return true;
    }
    //Level 2: testing binary error statistics of the 0 values
    if (candidate.TestingBinErrorStat.BinValErrStat[0].Sum > currentBest.TestingBinErrorStat.BinValErrStat[0].Sum)
    {
        return false;
    }
    if (candidate.TestingBinErrorStat.BinValErrStat[0].Sum < currentBest.TestingBinErrorStat.BinValErrStat[0].Sum)
    {
        return true;
    }
    //Level 3: training binary error statistics of the 0 values
    if (candidate.TrainingBinErrorStat.BinValErrStat[0].Sum > currentBest.TrainingBinErrorStat.BinValErrStat[0].Sum)
    {
        return false;
    }
    if (candidate.TrainingBinErrorStat.BinValErrStat[0].Sum < currentBest.TrainingBinErrorStat.BinValErrStat[0].Sum)
    {
        return true;
    }
    //Level 4: combined precision error decides the rest
    return candidate.CombinedPrecisionError < currentBest.CombinedPrecisionError;
}
/// <summary>
/// Deep copy constructor.
/// </summary>
/// <remarks>
/// The network and all error and weight statistics are cloned through their DeepClone methods.
/// A null member of the source stays null. The remaining members are copied as they are.
/// </remarks>
/// <param name="source">Instance to be copied</param>
public TrainedNetwork(TrainedNetwork source)
{
    //Identification
    NetworkName = source.NetworkName;
    BinBorder = source.BinBorder;
    //Network clone (null stays null)
    Network = (source.Network != null) ? source.Network.DeepClone() : null;
    TrainerInfoMessage = source.TrainerInfoMessage;
    //Error statistics clones
    TrainingErrorStat = source.TrainingErrorStat?.DeepClone();
    TrainingBinErrorStat = source.TrainingBinErrorStat?.DeepClone();
    TestingErrorStat = source.TestingErrorStat?.DeepClone();
    TestingBinErrorStat = source.TestingBinErrorStat?.DeepClone();
    //Output weights statistics clone
    OutputWeightsStat = source.OutputWeightsStat?.DeepClone();
    //Summary values
    CombinedPrecisionError = source.CombinedPrecisionError;
    CombinedBinaryError = source.CombinedBinaryError;
    ExpectedPrecisionAccuracy = source.ExpectedPrecisionAccuracy;
    ExpectedBinaryAccuracy = source.ExpectedBinaryAccuracy;
}
/// <summary>
/// Builds the trained network.
/// </summary>
/// <remarks>
/// After every trainer iteration:
/// - The network is evaluated on both the training and the testing data.
/// - The controller is consulted. When it reports the current network as better, a DeepClone()
///   copy of the current network becomes the new best network.
/// - The controller can also stop the current attempt (the trainer then moves to its next attempt,
///   if any) or stop the whole process.
/// </remarks>
/// <returns>The best network found, with its output weights statistics computed</returns>
/// <exception cref="InvalidOperationException">The trainer did not perform any training iteration, so no network was ever evaluated.</exception>
public TrainedNetwork Build()
{
    TrainedNetwork bestNetwork = null;
    int lastImprovementEpoch = 0;
    double lastImprovementCombinedPrecisionError = 0d;
    double lastImprovementCombinedBinaryError = 0d;
    //Create network and trainer
    NonRecurrentNetUtils.CreateNetworkAndTrainer(_networkSettings,
                                                 _trainingBundle.InputVectorCollection,
                                                 _trainingBundle.OutputVectorCollection,
                                                 _rand,
                                                 out INonRecurrentNetwork net,
                                                 out INonRecurrentNetworkTrainer trainer
                                                 );
    //Records the trainer's current attempt epoch and the given network's combined errors as the last improvement
    void RecordImprovement(TrainedNetwork network)
    {
        lastImprovementEpoch = trainer.AttemptEpoch;
        lastImprovementCombinedPrecisionError = network.CombinedPrecisionError;
        lastImprovementCombinedBinaryError = network.CombinedBinaryError;
    }
    //Iterate training cycles
    while (trainer.Iteration())
    {
        //Evaluate the network state reached by this iteration
        //(currNetwork references the network being trained, it is not a copy)
        TrainedNetwork currNetwork = EvaluateCurrentNetwork(net, trainer);
        //Restart the improvement tracking when a new trainer's attempt started
        if (trainer.AttemptEpoch == 1)
        {
            RecordImprovement(currNetwork);
        }
        //First initialization of the best network
        bestNetwork = bestNetwork ?? currNetwork.DeepClone();
        //Building state passed to the controller and to the notification event
        BuildingState regrState = new BuildingState(_networkName,
                                                    _binBorder,
                                                    _foldNum,
                                                    _numOfFolds,
                                                    _foldNetworkNum,
                                                    _numOfFoldNetworks,
                                                    trainer.Attempt,
                                                    trainer.MaxAttempt,
                                                    trainer.AttemptEpoch,
                                                    trainer.MaxAttemptEpoch,
                                                    currNetwork,
                                                    bestNetwork,
                                                    lastImprovementEpoch
                                                    );
        //Call controller
        BuildingInstr instructions = _controller(regrState);
        //Better?
        if (instructions.CurrentIsBetter)
        {
            //Adopt a copy of the current network as the best one
            //(a copy, because further training keeps changing the current network)
            bestNetwork = currNetwork.DeepClone();
            regrState.BestNetwork = bestNetwork;
            RecordImprovement(currNetwork);
        }
        //Any decrease of the combined binary or precision error since the last recorded
        //improvement also counts as an improvement, regardless of the controller's decision
        if (currNetwork.CombinedBinaryError < lastImprovementCombinedBinaryError || currNetwork.CombinedPrecisionError < lastImprovementCombinedPrecisionError)
        {
            RecordImprovement(currNetwork);
        }
        //Raise notification event
        //NOTE(review): if RegressionEpochDone is an event that can have no subscribers, this direct
        //invocation throws NullReferenceException. Confirm it is always initialized, or use ?.Invoke.
        RegressionEpochDone(regrState, instructions.CurrentIsBetter);
        //Process instructions
        if (instructions.StopProcess)
        {
            break;
        }
        if (instructions.StopCurrentAttempt && !trainer.NextAttempt())
        {
            //No further attempt is available
            break;
        }
    }//while (iteration)
    if (bestNetwork == null)
    {
        //The trainer finished without a single iteration, so there is no network to return
        throw new InvalidOperationException($"Network {_networkName}: the trainer did not perform any training iteration, so there is no trained network to return.");
    }
    //Create statistics of the best network weights
    bestNetwork.OutputWeightsStat = bestNetwork.Network.ComputeWeightsStat();
    return bestNetwork;
}

/// <summary>
/// Wraps the network in its current training state into a TrainedNetwork.
/// Computes its error statistics on the training and testing data, plus the combined errors
/// and the expected accuracies.
/// </summary>
/// <param name="net">Network being trained (referenced by the result, not copied)</param>
/// <param name="trainer">Trainer of the network (provides the info message)</param>
/// <returns>The evaluated current network</returns>
private TrainedNetwork EvaluateCurrentNetwork(INonRecurrentNetwork net, INonRecurrentNetworkTrainer trainer)
{
    //Training data part
    TrainedNetwork currNetwork = new TrainedNetwork
    {
        NetworkName = _networkName,
        BinBorder = _binBorder,
        Network = net,
        TrainerInfoMessage = trainer.InfoMessage,
        TrainingErrorStat = net.ComputeBatchErrorStat(_trainingBundle.InputVectorCollection, _trainingBundle.OutputVectorCollection, out List<double[]> trainingComputedOutputsCollection)
    };
    if (BinaryOutput)
    {
        currNetwork.TrainingBinErrorStat = new BinErrStat(_binBorder, trainingComputedOutputsCollection, _trainingBundle.OutputVectorCollection);
        currNetwork.CombinedBinaryError = currNetwork.TrainingBinErrorStat.TotalErrStat.Sum;
    }
    currNetwork.CombinedPrecisionError = currNetwork.TrainingErrorStat.ArithAvg;
    //Testing data part
    currNetwork.TestingErrorStat = net.ComputeBatchErrorStat(_testingBundle.InputVectorCollection, _testingBundle.OutputVectorCollection, out List<double[]> testingComputedOutputsCollection);
    //Combined errors take the worse (higher) of the training and testing values
    currNetwork.CombinedPrecisionError = Math.Max(currNetwork.CombinedPrecisionError, currNetwork.TestingErrorStat.ArithAvg);
    if (BinaryOutput)
    {
        currNetwork.TestingBinErrorStat = new BinErrStat(_binBorder, testingComputedOutputsCollection, _testingBundle.OutputVectorCollection);
        currNetwork.CombinedBinaryError = Math.Max(currNetwork.CombinedBinaryError, currNetwork.TestingBinErrorStat.TotalErrStat.Sum);
    }
    //Expected precision accuracy: the lower of the training and testing accuracies,
    //where the average error is taken relative to the span of the network's output range
    currNetwork.ExpectedPrecisionAccuracy = Math.Min(1d - (currNetwork.TrainingErrorStat.ArithAvg / currNetwork.Network.OutputRange.Span),
                                                     1d - (currNetwork.TestingErrorStat.ArithAvg / currNetwork.Network.OutputRange.Span)
                                                     );
    //Expected binary accuracy (NaN when the network output is not binary)
    if (BinaryOutput)
    {
        currNetwork.ExpectedBinaryAccuracy = Math.Min(1d - currNetwork.TrainingBinErrorStat.TotalErrStat.ArithAvg,
                                                      1d - currNetwork.TestingBinErrorStat.TotalErrStat.ArithAvg
                                                      );
    }
    else
    {
        currNetwork.ExpectedBinaryAccuracy = double.NaN;
    }
    return currNetwork;
}