/// <summary>
/// Builds the trained network.
/// </summary>
/// <remarks>
/// The trainer is iterated until it reports no further iteration or the controller requests a stop.
/// After each iteration the current network is evaluated on both the training and the testing bundle.
/// The combined errors and the expected accuracies use the worse of the two values. The controller then
/// decides whether the current network becomes the best one and whether to stop the current attempt or
/// the whole process.
/// </remarks>
/// <returns>The best trained network found, with its output weights statistics computed.</returns>
/// <exception cref="InvalidOperationException">The trainer performed no iteration, so no network was produced.</exception>
public TrainedNetwork Build()
{
    TrainedNetwork bestNetwork = null;
    int lastImprovementEpoch = 0;
    double lastImprovementCombinedPrecisionError = 0d;
    double lastImprovementCombinedBinaryError = 0d;
    //Create network and trainer
    NonRecurrentNetUtils.CreateNetworkAndTrainer(_networkSettings,
                                                 _trainingBundle.InputVectorCollection,
                                                 _trainingBundle.OutputVectorCollection,
                                                 _rand,
                                                 out INonRecurrentNetwork net,
                                                 out INonRecurrentNetworkTrainer trainer
                                                 );
    //Iterate training cycles
    while (trainer.Iteration())
    {
        //Compute current error statistics after training iteration
        //Training data part
        TrainedNetwork currNetwork = new TrainedNetwork
        {
            NetworkName = _networkName,
            BinBorder = _binBorder,
            Network = net,
            TrainerInfoMessage = trainer.InfoMessage,
            TrainingErrorStat = net.ComputeBatchErrorStat(_trainingBundle.InputVectorCollection,
                                                          _trainingBundle.OutputVectorCollection,
                                                          out List<double[]> trainingComputedOutputsCollection
                                                          )
        };
        if (BinaryOutput)
        {
            currNetwork.TrainingBinErrorStat = new BinErrStat(_binBorder, trainingComputedOutputsCollection, _trainingBundle.OutputVectorCollection);
            currNetwork.CombinedBinaryError = currNetwork.TrainingBinErrorStat.TotalErrStat.Sum;
        }
        currNetwork.CombinedPrecisionError = currNetwork.TrainingErrorStat.ArithAvg;
        //Testing data part
        currNetwork.TestingErrorStat = net.ComputeBatchErrorStat(_testingBundle.InputVectorCollection,
                                                                 _testingBundle.OutputVectorCollection,
                                                                 out List<double[]> testingComputedOutputsCollection
                                                                 );
        //Combined errors keep the worse of the training and testing values
        currNetwork.CombinedPrecisionError = Math.Max(currNetwork.CombinedPrecisionError, currNetwork.TestingErrorStat.ArithAvg);
        if (BinaryOutput)
        {
            currNetwork.TestingBinErrorStat = new BinErrStat(_binBorder, testingComputedOutputsCollection, _testingBundle.OutputVectorCollection);
            currNetwork.CombinedBinaryError = Math.Max(currNetwork.CombinedBinaryError, currNetwork.TestingBinErrorStat.TotalErrStat.Sum);
        }
        //Expected precision accuracy (worse of training and testing, relative to the network's output range)
        currNetwork.ExpectedPrecisionAccuracy = Math.Min((1d - (currNetwork.TrainingErrorStat.ArithAvg / currNetwork.Network.OutputRange.Span)),
                                                         (1d - (currNetwork.TestingErrorStat.ArithAvg / currNetwork.Network.OutputRange.Span))
                                                         );
        //Expected binary accuracy
        if (BinaryOutput)
        {
            currNetwork.ExpectedBinaryAccuracy = Math.Min((1d - currNetwork.TrainingBinErrorStat.TotalErrStat.ArithAvg),
                                                          (1d - currNetwork.TestingBinErrorStat.TotalErrStat.ArithAvg)
                                                          );
        }
        else
        {
            currNetwork.ExpectedBinaryAccuracy = double.NaN;
        }
        //Restart lastImprovementEpoch when new trainer's attempt started
        if (trainer.AttemptEpoch == 1)
        {
            lastImprovementEpoch = trainer.AttemptEpoch;
            lastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
            lastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
        }
        //First initialization of the best network
        bestNetwork = bestNetwork ?? currNetwork.DeepClone();
        //RegrState instance
        BuildingState regrState = new BuildingState(_networkName,
                                                    _binBorder,
                                                    _foldNum,
                                                    _numOfFolds,
                                                    _foldNetworkNum,
                                                    _numOfFoldNetworks,
                                                    trainer.Attempt,
                                                    trainer.MaxAttempt,
                                                    trainer.AttemptEpoch,
                                                    trainer.MaxAttemptEpoch,
                                                    currNetwork,
                                                    bestNetwork,
                                                    lastImprovementEpoch
                                                    );
        //Call controller
        BuildingInstr instructions = _controller(regrState);
        //Better?
        if (instructions.CurrentIsBetter)
        {
            //Adopt current regression unit as a best one
            bestNetwork = currNetwork.DeepClone();
            regrState.BestNetwork = bestNetwork;
            lastImprovementEpoch = trainer.AttemptEpoch;
            lastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
            lastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
        }
        //Track the epoch of the last error improvement within the current attempt
        //(the binary error is only meaningful when the output is binary)
        if ((BinaryOutput && currNetwork.CombinedBinaryError < lastImprovementCombinedBinaryError)
            || currNetwork.CombinedPrecisionError < lastImprovementCombinedPrecisionError
            )
        {
            lastImprovementEpoch = trainer.AttemptEpoch;
            lastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
            lastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
        }
        //Raise notification event (null-conditional so that no subscriber is required)
        RegressionEpochDone?.Invoke(regrState, instructions.CurrentIsBetter);
        //Process instructions
        if (instructions.StopProcess)
        {
            break;
        }
        else if (instructions.StopCurrentAttempt)
        {
            if (!trainer.NextAttempt())
            {
                break;
            }
        }
    }//while (iteration)
    //No iteration means no network was ever evaluated
    if (bestNetwork == null)
    {
        throw new InvalidOperationException("The trainer did not perform any iteration, so the network could not be built.");
    }
    //Create statistics of the best network weights
    bestNetwork.OutputWeightsStat = bestNetwork.Network.ComputeWeightsStat();
    return bestNetwork;
}
/// <summary>
/// Builds the trained network.
/// </summary>
/// <remarks>
/// The trainer is iterated until it reports no further iteration or the controller requests a stop.
/// After each iteration the current network is evaluated on both the training and the testing bundle.
/// The combined errors keep the worse of the two values. The last-improvement tracking of the current
/// attempt is updated, and the controller then decides whether the current network becomes the best one
/// and whether to stop the current attempt or the whole process.
/// </remarks>
/// <returns>The best trained network found, with its weights statistics computed.</returns>
/// <exception cref="InvalidOperationException">The trainer performed no iteration, so no network was produced.</exception>
public TNRNet Build()
{
    TNRNet bestNetwork = null;
    int bestNetworkAttempt = 0;
    int bestNetworkAttemptEpoch = 0;
    int currNetworkLastImprovementEpoch = 0;
    double currNetworkLastImprovementCombinedPrecisionError = 0d;
    double currNetworkLastImprovementCombinedBinaryError = 0d;
    //Whether binary error statistics apply to this network output type (loop invariant, evaluated once)
    bool binErrStatsRequired = TNRNet.IsBinErrorStatsOutputType(_networkOutput);
    //Create network and trainer
    NonRecurrentNetUtils.CreateNetworkAndTrainer(_networkCfg,
                                                 _trainingBundle.InputVectorCollection,
                                                 _trainingBundle.OutputVectorCollection,
                                                 _rand,
                                                 out INonRecurrentNetwork net,
                                                 out INonRecurrentNetworkTrainer trainer
                                                 );
    //Iterate training cycles
    while (trainer.Iteration())
    {
        //Compute current error statistics after training iteration
        //Training data part
        TNRNet currNetwork = new TNRNet(_networkName, _networkOutput)
        {
            Network = net,
            TrainerInfoMessage = trainer.InfoMessage,
            TrainingErrorStat = net.ComputeBatchErrorStat(_trainingBundle.InputVectorCollection,
                                                          _trainingBundle.OutputVectorCollection,
                                                          out List<double[]> trainingComputedOutputsCollection
                                                          )
        };
        if (binErrStatsRequired)
        {
            currNetwork.TrainingBinErrorStat = new BinErrStat(BoolBorder, trainingComputedOutputsCollection, _trainingBundle.OutputVectorCollection);
            currNetwork.CombinedBinaryError = currNetwork.TrainingBinErrorStat.TotalErrStat.Sum;
        }
        currNetwork.CombinedPrecisionError = currNetwork.TrainingErrorStat.ArithAvg;
        //Testing data part
        currNetwork.TestingErrorStat = net.ComputeBatchErrorStat(_testingBundle.InputVectorCollection,
                                                                 _testingBundle.OutputVectorCollection,
                                                                 out List<double[]> testingComputedOutputsCollection
                                                                 );
        //Combined errors keep the worse of the training and testing values
        currNetwork.CombinedPrecisionError = Math.Max(currNetwork.CombinedPrecisionError, currNetwork.TestingErrorStat.ArithAvg);
        if (binErrStatsRequired)
        {
            currNetwork.TestingBinErrorStat = new BinErrStat(BoolBorder, testingComputedOutputsCollection, _testingBundle.OutputVectorCollection);
            currNetwork.CombinedBinaryError = Math.Max(currNetwork.CombinedBinaryError, currNetwork.TestingBinErrorStat.TotalErrStat.Sum);
        }
        //Restart lastImprovementEpoch when new trainer's attempt started
        if (trainer.AttemptEpoch == 1)
        {
            currNetworkLastImprovementEpoch = trainer.AttemptEpoch;
            currNetworkLastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
            if (binErrStatsRequired)
            {
                currNetworkLastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
            }
        }
        //First initialization of the best network
        if (bestNetwork == null)
        {
            bestNetwork = currNetwork.DeepClone();
            bestNetworkAttempt = trainer.Attempt;
            //Keep the reported epoch consistent with the adopted network
            bestNetworkAttemptEpoch = trainer.AttemptEpoch;
        }
        //Track the epoch of the last error improvement within the current attempt
        if ((binErrStatsRequired && currNetwork.CombinedBinaryError < currNetworkLastImprovementCombinedBinaryError)
            || currNetwork.CombinedPrecisionError < currNetworkLastImprovementCombinedPrecisionError
            )
        {
            currNetworkLastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
            if (binErrStatsRequired)
            {
                currNetworkLastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
            }
            currNetworkLastImprovementEpoch = trainer.AttemptEpoch;
        }
        //BuildProgress instance
        BuildProgress buildProgress = new BuildProgress(_networkName,
                                                        trainer.Attempt,
                                                        trainer.MaxAttempt,
                                                        trainer.AttemptEpoch,
                                                        trainer.MaxAttemptEpoch,
                                                        currNetwork,
                                                        currNetworkLastImprovementEpoch,
                                                        bestNetwork,
                                                        bestNetworkAttempt,
                                                        bestNetworkAttemptEpoch
                                                        );
        //Call controller
        BuildInstr instructions = _controller(buildProgress);
        //Better?
        if (instructions.CurrentIsBetter)
        {
            //Adopt current regression unit as a best one
            bestNetwork = currNetwork.DeepClone();
            bestNetworkAttempt = trainer.Attempt;
            bestNetworkAttemptEpoch = trainer.AttemptEpoch;
            //Update build progress
            buildProgress.BestNetwork = bestNetwork;
            buildProgress.BestNetworkAttemptNum = bestNetworkAttempt;
            buildProgress.BestNetworkAttemptEpochNum = bestNetworkAttemptEpoch;
        }
        //Raise notification event
        NetworkBuildProgressChanged?.Invoke(buildProgress);
        //Process instructions
        if (instructions.StopProcess)
        {
            break;
        }
        else if (instructions.StopCurrentAttempt)
        {
            if (!trainer.NextAttempt())
            {
                break;
            }
        }
    }//while (iteration)
    //No iteration means no network was ever evaluated
    if (bestNetwork == null)
    {
        throw new InvalidOperationException("The trainer did not perform any iteration, so the network could not be built.");
    }
    //Create statistics of the best network weights
    bestNetwork.NetworkWeightsStat = bestNetwork.Network.ComputeWeightsStat();
    return bestNetwork;
}