//Methods
/// <summary>
/// Adds a new member network and updates the cluster error statistics.
/// </summary>
/// <param name="newMemberNet">The new member network.</param>
/// <param name="scopeID">The ID of a network's scope.</param>
/// <param name="testData">The testing data bundle (unseen by the network to be added).</param>
/// <param name="filters">The filters to be used to denormalize outputs. Can be null - outputs are then used as-is.</param>
/// <exception cref="ArgumentNullException">When newMemberNet is null.</exception>
/// <exception cref="ArgumentException">When the network's output type or number of outputs is inconsistent with the cluster.</exception>
public void AddMember(TNRNet newMemberNet, int scopeID, VectorBundle testData, FeatureFilterBase[] filters)
{
    if (newMemberNet == null)
    {
        throw new ArgumentNullException(nameof(newMemberNet));
    }
    //Check the network output
    if (Output != newMemberNet.Output)
    {
        throw new ArgumentException("Inconsistent output type of the network to be added.", nameof(newMemberNet));
    }
    //Check number of outputs consistency against already clustered networks
    if (_memberNetCollection.Count > 0)
    {
        if (newMemberNet.Network.NumOfOutputValues != NumOfOutputs)
        {
            throw new ArgumentException("Number of outputs of the network differs from already clustered networks.", nameof(newMemberNet));
        }
    }
    //Add member to inner collection
    _memberNetCollection.Add(newMemberNet);
    _memberNetScopeIDCollection.Add(scopeID);
    //Update cluster error statistics using the unseen testing data
    for (int sampleIdx = 0; sampleIdx < testData.OutputVectorCollection.Count; sampleIdx++)
    {
        double[] nrmComputedValues = newMemberNet.Network.Compute(testData.InputVectorCollection[sampleIdx]);
        //Hoist the ideal vector - it is indexed repeatedly below
        double[] nrmIdealValues = testData.OutputVectorCollection[sampleIdx];
        for (int outIdx = 0; outIdx < nrmComputedValues.Length; outIdx++)
        {
            //Denormalize to natural values when filters are available
            double naturalComputedValue = filters != null ? filters[outIdx].ApplyReverse(nrmComputedValues[outIdx]) : nrmComputedValues[outIdx];
            double naturalIdealValue = filters != null ? filters[outIdx].ApplyReverse(nrmIdealValues[outIdx]) : nrmIdealValues[outIdx];
            ErrorStats.Update(nrmComputedValues[outIdx],
                              nrmIdealValues[outIdx],
                              naturalComputedValue,
                              naturalIdealValue
                              );
        } //outIdx
    } //sampleIdx
    return;
}
/// <summary>
/// The deep copy constructor.
/// </summary>
/// <param name="source">The source instance to be deeply copied.</param>
public TNRNet(TNRNet source)
{
    //Plain value members - direct copy
    Name = source.Name;
    Output = source.Output;
    TrainerInfoMessage = source.TrainerInfoMessage;
    CombinedPrecisionError = source.CombinedPrecisionError;
    CombinedBinaryError = source.CombinedBinaryError;
    //Reference members - deep clone, null-safe (they may not be set yet)
    Network = source.Network?.DeepClone();
    TrainingErrorStat = source.TrainingErrorStat?.DeepClone();
    TrainingBinErrorStat = source.TrainingBinErrorStat?.DeepClone();
    TestingErrorStat = source.TestingErrorStat?.DeepClone();
    TestingBinErrorStat = source.TestingBinErrorStat?.DeepClone();
    NetworkWeightsStat = source.NetworkWeightsStat?.DeepClone();
    return;
}
/// <summary>
/// Creates an uninitialized instance.
/// </summary>
/// <param name="clusterName">The name of the cluster.</param>
/// <param name="outputType">The type of output.</param>
public ClusterErrStatistics(string clusterName, TNRNet.OutputType outputType)
{
    ClusterName = clusterName;
    NatPrecissionErrStat = new BasicStat();
    NrmPrecissionErrStat = new BasicStat();
    //Binary error statistics are tracked only for output types supporting them
    BinaryErrStat = TNRNet.IsBinErrorStatsOutputType(outputType)
                    ? new BinErrStat(TNRNet.GetOutputDataRange(outputType).Mid)
                    : null;
    return;
}
//Static methods
/// <summary>
/// Evaluates whether the "candidate" network achieved a better result than the best network so far.
/// </summary>
/// <remarks>
/// The default implementation.
/// </remarks>
/// <param name="candidate">The candidate network to be evaluated.</param>
/// <param name="currentBest">The best network so far.</param>
public static bool IsBetter(TNRNet candidate, TNRNet currentBest)
{
    //Binary decisions comparison (applies only when binary error statistics are available)
    if (candidate.HasBinErrorStats)
    {
        //Primary criterion: combined binary error
        if (candidate.CombinedBinaryError > currentBest.CombinedBinaryError)
        {
            return false;
        }
        if (candidate.CombinedBinaryError < currentBest.CombinedBinaryError)
        {
            return true;
        }
        //Tie-breaker 1: testing false-negatives (bin value 0 errors)
        if (candidate.TestingBinErrorStat.BinValErrStat[0].Sum > currentBest.TestingBinErrorStat.BinValErrStat[0].Sum)
        {
            return false;
        }
        if (candidate.TestingBinErrorStat.BinValErrStat[0].Sum < currentBest.TestingBinErrorStat.BinValErrStat[0].Sum)
        {
            return true;
        }
        //Tie-breaker 2: training false-negatives (bin value 0 errors)
        if (candidate.TrainingBinErrorStat.BinValErrStat[0].Sum > currentBest.TrainingBinErrorStat.BinValErrStat[0].Sum)
        {
            return false;
        }
        if (candidate.TrainingBinErrorStat.BinValErrStat[0].Sum < currentBest.TrainingBinErrorStat.BinValErrStat[0].Sum)
        {
            return true;
        }
    }
    //Final criterion: numerical precision
    return candidate.CombinedPrecisionError < currentBest.CombinedPrecisionError;
}
/// <summary>
/// Gets textual information about the specified trained network instance.
/// </summary>
/// <param name="network">An instance of trained network.</param>
/// <param name="shortVersion">Specifies whether to build short version of the informative text.</param>
public string GetNetworkInfoText(TNRNet network, bool shortVersion = true)
{
    StringBuilder text = new StringBuilder();
    //Averages are always reported in invariant "E3" format in both versions
    string trainAvgText = network.TrainingErrorStat.ArithAvg.ToString("E3", CultureInfo.InvariantCulture);
    string testAvgText = network.TestingErrorStat.ArithAvg.ToString("E3", CultureInfo.InvariantCulture);
    if (shortVersion)
    {
        text.Append("TrainErr ").Append(trainAvgText);
        if (network.HasBinErrorStats)
        {
            text.Append('/').Append(network.TrainingBinErrorStat.TotalErrStat.Sum.ToString(CultureInfo.InvariantCulture));
            text.Append('/').Append(network.TrainingBinErrorStat.BinValErrStat[1].Sum.ToString(CultureInfo.InvariantCulture));
        }
        text.Append(", TestErr ").Append(testAvgText);
        if (network.HasBinErrorStats)
        {
            text.Append('/').Append(network.TestingBinErrorStat.TotalErrStat.Sum.ToString(CultureInfo.InvariantCulture));
            text.Append('/').Append(network.TestingBinErrorStat.BinValErrStat[1].Sum.ToString(CultureInfo.InvariantCulture));
        }
    }
    else
    {
        text.Append("Training numerical error ").Append(trainAvgText);
        if (network.HasBinErrorStats)
        {
            text.Append(", total bad classifications ").Append(network.TrainingBinErrorStat.TotalErrStat.Sum.ToString(CultureInfo.InvariantCulture));
            text.Append(", false positive classifications ").Append(network.TrainingBinErrorStat.BinValErrStat[1].Sum.ToString(CultureInfo.InvariantCulture));
        }
        text.Append(", Testing numerical error ").Append(testAvgText);
        if (network.HasBinErrorStats)
        {
            text.Append(", total incorrect classifications ").Append(network.TestingBinErrorStat.TotalErrStat.Sum.ToString(CultureInfo.InvariantCulture));
            text.Append(", false positive classifications ").Append(network.TestingBinErrorStat.BinValErrStat[1].Sum.ToString(CultureInfo.InvariantCulture));
        }
    }
    return text.ToString();
}
/// <summary>
/// Creates an initialized instance.
/// </summary>
/// <param name="networkName">Name of the network.</param>
/// <param name="attemptNum">The current attempt number.</param>
/// <param name="maxNumOfAttempts">The maximum number of attempts.</param>
/// <param name="attemptEpochNum">The current epoch number within the current attempt.</param>
/// <param name="maxNumOfAttemptEpochs">The maximum number of epochs.</param>
/// <param name="currNetwork">The current network and its error statistics.</param>
/// <param name="currNetworkLastImprovementEpochNum">An epoch number within the current build attempt when was found an improvement of the current network.</param>
/// <param name="bestNetwork">The best network so far and its error statistics.</param>
/// <param name="bestNetworkAttemptNum">The attempt number in which was found the best network so far.</param>
/// <param name="bestNetworkAttemptEpochNum">The epoch number within the bestNetworkAttemptNum in which was found the best network so far.</param>
public BuildProgress(string networkName,
                     int attemptNum,
                     int maxNumOfAttempts,
                     int attemptEpochNum,
                     int maxNumOfAttemptEpochs,
                     TNRNet currNetwork,
                     int currNetworkLastImprovementEpochNum,
                     TNRNet bestNetwork,
                     int bestNetworkAttemptNum,
                     int bestNetworkAttemptEpochNum
                     )
{
    NetworkName = networkName;
    //Progress trackers wrap the attempt/epoch counters
    AttemptsTracker = new ProgressTracker((uint)maxNumOfAttempts, (uint)attemptNum);
    AttemptEpochsTracker = new ProgressTracker((uint)maxNumOfAttemptEpochs, (uint)attemptEpochNum);
    //Current network state
    CurrNetwork = currNetwork;
    CurrNetworkLastImprovementAttemptEpochNum = currNetworkLastImprovementEpochNum;
    //Best-so-far network state
    BestNetwork = bestNetwork;
    BestNetworkAttemptNum = bestNetworkAttemptNum;
    BestNetworkAttemptEpochNum = bestNetworkAttemptEpochNum;
    return;
}
/// <summary>
/// Builds the cluster chain.
/// </summary>
/// <param name="dataBundle">The data bundle for training.</param>
/// <param name="filters">The filters to be used to denormalize outputs.</param>
/// <returns>The built and operable chain of clusters.</returns>
public TNRNetClusterChain Build(VectorBundle dataBundle, FeatureFilterBase[] filters)
{
    //The chain to be built
    TNRNetClusterChain chain = new TNRNetClusterChain(_chainName, _clusterChainCfg.Output);
    //Instantiate chained clusters, one per cluster configuration
    List<TNRNetCluster> chainClusters = new List<TNRNetCluster>(_clusterChainCfg.ClusterCfgCollection.Count);
    for (int clusterIdx = 0; clusterIdx < _clusterChainCfg.ClusterCfgCollection.Count; clusterIdx++)
    {
        //Cluster
        chainClusters.Add(new TNRNetCluster(_chainName,
                                            _clusterChainCfg.ClusterCfgCollection[clusterIdx].Output,
                                            _clusterChainCfg.ClusterCfgCollection[clusterIdx].TrainingGroupWeight,
                                            _clusterChainCfg.ClusterCfgCollection[clusterIdx].TestingGroupWeight,
                                            _clusterChainCfg.ClusterCfgCollection[clusterIdx].SamplesWeight,
                                            _clusterChainCfg.ClusterCfgCollection[clusterIdx].NumericalPrecisionWeight,
                                            _clusterChainCfg.ClusterCfgCollection[clusterIdx].MisrecognizedFalseWeight,
                                            _clusterChainCfg.ClusterCfgCollection[clusterIdx].UnrecognizedTrueWeight
                                            )
                          );
    }
    //Common crossvalidation configuration.
    //NaN border means "no binary balancing" for real-valued output; otherwise folds are balanced around the output range midpoint.
    double boolBorder = _clusterChainCfg.Output == TNRNet.OutputType.Real ? double.NaN : chain.OutputDataRange.Mid;
    //Shallow copy so the reshuffling between repetitions does not touch the caller's bundle ordering
    VectorBundle localDataBundle = dataBundle.CreateShallowCopy();
    //Members' training
    ResetProgressTracking();
    //Note: _repetitionIdx, _clusterIdx, _testingFoldIdx and _netCfgIdx are instance fields
    //because progress notification handlers read them during the build.
    for (_repetitionIdx = 0; _repetitionIdx < _clusterChainCfg.CrossvalidationCfg.Repetitions; _repetitionIdx++)
    {
        //Split data to folds
        List<VectorBundle> foldCollection = localDataBundle.Folderize(_clusterChainCfg.CrossvalidationCfg.FoldDataRatio, boolBorder);
        //Folds <= 0 means "use all available folds"
        _numOfFoldsPerRepetition = Math.Min(_clusterChainCfg.CrossvalidationCfg.Folds <= 0 ? foldCollection.Count : _clusterChainCfg.CrossvalidationCfg.Folds, foldCollection.Count);
        List<VectorBundle> currentClusterFoldCollection = CopyFolds(foldCollection);
        List<VectorBundle> nextClusterFoldCollection = new List<VectorBundle>(foldCollection.Count);
        //For each cluster
        for (_clusterIdx = 0; _clusterIdx < chainClusters.Count; _clusterIdx++)
        {
            //Train networks for each testing fold.
            for (_testingFoldIdx = 0; _testingFoldIdx < _numOfFoldsPerRepetition; _testingFoldIdx++)
            {
                //Prepare training data bundle = all folds except the testing one
                VectorBundle trainingData = new VectorBundle();
                for (int foldIdx = 0; foldIdx < currentClusterFoldCollection.Count; foldIdx++)
                {
                    if (foldIdx != _testingFoldIdx)
                    {
                        trainingData.Add(currentClusterFoldCollection[foldIdx]);
                    }
                }
                //NOTE(review): the next-cluster fold starts from the ORIGINAL fold (foldCollection), not the
                //current cluster's augmented fold - so each cluster's input is original data plus the
                //immediately preceding cluster's outputs only. Presumably intentional; confirm with design.
                VectorBundle nextClusterUpdatedDataFold = foldCollection[_testingFoldIdx].CreateShallowCopy();
                for (_netCfgIdx = 0; _netCfgIdx < _clusterChainCfg.ClusterCfgCollection[_clusterIdx].ClusterNetConfigurations.Count; _netCfgIdx++)
                {
                    TNRNetBuilder netBuilder = new TNRNetBuilder(_chainName,
                                                                 _clusterChainCfg.ClusterCfgCollection[_clusterIdx].ClusterNetConfigurations[_netCfgIdx],
                                                                 _clusterChainCfg.ClusterCfgCollection[_clusterIdx].Output,
                                                                 trainingData,
                                                                 currentClusterFoldCollection[_testingFoldIdx],
                                                                 _rand,
                                                                 _controller
                                                                 );
                    //Register notification
                    netBuilder.NetworkBuildProgressChanged += OnNetworkBuildProgressChanged;
                    //Build the trained network; it becomes a member of the cluster
                    TNRNet tn = netBuilder.Build();
                    //Unique scope identifier combining repetition and fold
                    int netScopeID = _repetitionIdx * NetScopeDelimiterCoeff + _testingFoldIdx;
                    chainClusters[_clusterIdx].AddMember(tn, netScopeID, currentClusterFoldCollection[_testingFoldIdx], filters);
                    //Update input data in the data fold for the next cluster: append this network's outputs to each input vector
                    for (int sampleIdx = 0; sampleIdx < currentClusterFoldCollection[_testingFoldIdx].InputVectorCollection.Count; sampleIdx++)
                    {
                        double[] computedNetData = tn.Network.Compute(currentClusterFoldCollection[_testingFoldIdx].InputVectorCollection[sampleIdx]);
                        nextClusterUpdatedDataFold.InputVectorCollection[sampleIdx] = nextClusterUpdatedDataFold.InputVectorCollection[sampleIdx].Concat(computedNetData);
                    }
                } //netCfgIdx
                //Add updated data fold for the next cluster
                nextClusterFoldCollection.Add(nextClusterUpdatedDataFold);
            } //testingFoldIdx
            //Switch fold collection - the next cluster trains on the augmented folds
            currentClusterFoldCollection = nextClusterFoldCollection;
            nextClusterFoldCollection = new List<VectorBundle>(currentClusterFoldCollection.Count);
        } //clusterIdx
        if (_repetitionIdx < _clusterChainCfg.CrossvalidationCfg.Repetitions - 1)
        {
            //Reshuffle the data so the next repetition folds differently
            localDataBundle.Shuffle(_rand);
        }
    } //repetitionIdx
    //Make the clusters operable and add them into the chain
    for (int clusterIdx = 0; clusterIdx < chainClusters.Count; clusterIdx++)
    {
        chainClusters[clusterIdx].FinalizeCluster();
        chain.AddCluster(chainClusters[clusterIdx]);
    }
    //Return the built chain
    return (chain);
}
/// <summary>
/// Builds the cluster.
/// </summary>
/// <param name="dataBundle">The data bundle for training.</param>
/// <param name="filters">The filters to be used to denormalize outputs.</param>
/// <returns>The built and operable cluster of trained networks.</returns>
public TNRNetCluster Build(VectorBundle dataBundle, FeatureFilterBase[] filters)
{
    //Shallow copy so the reshuffling between repetitions does not touch the caller's bundle ordering
    VectorBundle localDataBundle = dataBundle.CreateShallowCopy();
    //Cluster of trained networks
    TNRNetCluster cluster = new TNRNetCluster(_clusterName,
                                              _clusterCfg.Output,
                                              _clusterCfg.TrainingGroupWeight,
                                              _clusterCfg.TestingGroupWeight,
                                              _clusterCfg.SamplesWeight,
                                              _clusterCfg.NumericalPrecisionWeight,
                                              _clusterCfg.MisrecognizedFalseWeight,
                                              _clusterCfg.UnrecognizedTrueWeight
                                              );
    //Members' training
    ResetProgressTracking();
    //Note: _repetitionIdx, _testingFoldIdx and _netCfgIdx are instance fields
    //because progress notification handlers read them during the build.
    for (_repetitionIdx = 0; _repetitionIdx < _crossvalidationCfg.Repetitions; _repetitionIdx++)
    {
        //Data split to folds.
        //NaN border means "no binary balancing" for real-valued output; otherwise folds are balanced around the output range midpoint.
        List<VectorBundle> foldCollection = localDataBundle.Folderize(_crossvalidationCfg.FoldDataRatio, _clusterCfg.Output == TNRNet.OutputType.Real ? double.NaN : cluster.OutputDataRange.Mid);
        //Folds <= 0 means "use all available folds"
        _numOfFoldsPerRepetition = Math.Min(_crossvalidationCfg.Folds <= 0 ? foldCollection.Count : _crossvalidationCfg.Folds, foldCollection.Count);
        //Train the collection of networks for each processing fold.
        for (_testingFoldIdx = 0; _testingFoldIdx < _numOfFoldsPerRepetition; _testingFoldIdx++)
        {
            //Prepare training data bundle = all folds except the testing one
            VectorBundle trainingData = new VectorBundle();
            for (int foldIdx = 0; foldIdx < foldCollection.Count; foldIdx++)
            {
                if (foldIdx != _testingFoldIdx)
                {
                    trainingData.Add(foldCollection[foldIdx]);
                }
            }
            for (_netCfgIdx = 0; _netCfgIdx < _clusterCfg.ClusterNetConfigurations.Count; _netCfgIdx++)
            {
                TNRNetBuilder netBuilder = new TNRNetBuilder(_clusterName,
                                                             _clusterCfg.ClusterNetConfigurations[_netCfgIdx],
                                                             _clusterCfg.Output,
                                                             trainingData,
                                                             foldCollection[_testingFoldIdx],
                                                             _rand,
                                                             _controller
                                                             );
                //Register notification
                netBuilder.NetworkBuildProgressChanged += OnNetworkBuildProgressChanged;
                //Build the trained network; it becomes a member of the cluster
                TNRNet tn = netBuilder.Build();
                //Build an unique network scope identifier combining repetition and fold
                int netScopeID = _repetitionIdx * NetScopeDelimiterCoeff + _testingFoldIdx;
                //Add trained network to the cluster; its error statistics are computed on the testing fold
                cluster.AddMember(tn, netScopeID, foldCollection[_testingFoldIdx], filters);
            } //netCfgIdx
        } //testingFoldIdx
        if (_repetitionIdx < _crossvalidationCfg.Repetitions - 1)
        {
            //Reshuffle the data so the next repetition folds differently
            localDataBundle.Shuffle(_rand);
        }
    } //repetitionIdx
    //Make the cluster operable
    cluster.FinalizeCluster();
    //Return the built cluster
    return (cluster);
}
/// <summary>
/// Builds the trained network.
/// </summary>
/// <returns>The trained network.</returns>
public TNRNet Build()
{
    TNRNet bestNetwork = null;
    int bestNetworkAttempt = 0;
    //NOTE(review): bestNetworkAttemptEpoch stays 0 until the controller first reports
    //CurrentIsBetter - the initial best network's epoch is therefore reported as 0.
    int bestNetworkAttemptEpoch = 0;
    int currNetworkLastImprovementEpoch = 0;
    double currNetworkLastImprovementCombinedPrecisionError = 0d;
    double currNetworkLastImprovementCombinedBinaryError = 0d;
    //Create network and trainer
    NonRecurrentNetUtils.CreateNetworkAndTrainer(_networkCfg,
                                                 _trainingBundle.InputVectorCollection,
                                                 _trainingBundle.OutputVectorCollection,
                                                 _rand,
                                                 out INonRecurrentNetwork net,
                                                 out INonRecurrentNetworkTrainer trainer
                                                 );
    //Iterate training cycles
    while (trainer.Iteration())
    {
        //Compute current error statistics after training iteration
        //Training data part
        TNRNet currNetwork = new TNRNet(_networkName, _networkOutput)
        {
            Network = net,
            TrainerInfoMessage = trainer.InfoMessage,
            TrainingErrorStat = net.ComputeBatchErrorStat(_trainingBundle.InputVectorCollection, _trainingBundle.OutputVectorCollection, out List<double[]> trainingComputedOutputsCollection)
        };
        if (TNRNet.IsBinErrorStatsOutputType(_networkOutput))
        {
            currNetwork.TrainingBinErrorStat = new BinErrStat(BoolBorder, trainingComputedOutputsCollection, _trainingBundle.OutputVectorCollection);
            currNetwork.CombinedBinaryError = currNetwork.TrainingBinErrorStat.TotalErrStat.Sum;
        }
        currNetwork.CombinedPrecisionError = currNetwork.TrainingErrorStat.ArithAvg;
        //Testing data part - combined errors take the WORSE of training and testing
        currNetwork.TestingErrorStat = net.ComputeBatchErrorStat(_testingBundle.InputVectorCollection, _testingBundle.OutputVectorCollection, out List<double[]> testingComputedOutputsCollection);
        currNetwork.CombinedPrecisionError = Math.Max(currNetwork.CombinedPrecisionError, currNetwork.TestingErrorStat.ArithAvg);
        if (TNRNet.IsBinErrorStatsOutputType(_networkOutput))
        {
            currNetwork.TestingBinErrorStat = new BinErrStat(BoolBorder, testingComputedOutputsCollection, _testingBundle.OutputVectorCollection);
            currNetwork.CombinedBinaryError = Math.Max(currNetwork.CombinedBinaryError, currNetwork.TestingBinErrorStat.TotalErrStat.Sum);
        }
        //Restart lastImprovement tracking when a new trainer's attempt started
        if (trainer.AttemptEpoch == 1)
        {
            currNetworkLastImprovementEpoch = trainer.AttemptEpoch;
            currNetworkLastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
            if (TNRNet.IsBinErrorStatsOutputType(_networkOutput))
            {
                currNetworkLastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
            }
        }
        //First initialization of the best network
        if (bestNetwork == null)
        {
            bestNetwork = currNetwork.DeepClone();
            bestNetworkAttempt = trainer.Attempt;
        }
        //Track the epoch of the last improvement within the current attempt
        //(binary error improvement OR precision improvement)
        if ((TNRNet.IsBinErrorStatsOutputType(_networkOutput) && currNetwork.CombinedBinaryError < currNetworkLastImprovementCombinedBinaryError) ||
            currNetwork.CombinedPrecisionError < currNetworkLastImprovementCombinedPrecisionError
            )
        {
            currNetworkLastImprovementCombinedPrecisionError = currNetwork.CombinedPrecisionError;
            if (TNRNet.IsBinErrorStatsOutputType(_networkOutput))
            {
                currNetworkLastImprovementCombinedBinaryError = currNetwork.CombinedBinaryError;
            }
            currNetworkLastImprovementEpoch = trainer.AttemptEpoch;
        }
        //BuildProgress instance describing the current state of the build
        BuildProgress buildProgress = new BuildProgress(_networkName,
                                                        trainer.Attempt,
                                                        trainer.MaxAttempt,
                                                        trainer.AttemptEpoch,
                                                        trainer.MaxAttemptEpoch,
                                                        currNetwork,
                                                        currNetworkLastImprovementEpoch,
                                                        bestNetwork,
                                                        bestNetworkAttempt,
                                                        bestNetworkAttemptEpoch
                                                        );
        //Call controller - it decides whether the current network is better and whether to continue
        BuildInstr instructions = _controller(buildProgress);
        //Better?
        if (instructions.CurrentIsBetter)
        {
            //Adopt current network as the best one and reflect it in the progress instance
            bestNetwork = currNetwork.DeepClone();
            bestNetworkAttempt = trainer.Attempt;
            bestNetworkAttemptEpoch = trainer.AttemptEpoch;
            //Update build progress
            buildProgress.BestNetwork = bestNetwork;
            buildProgress.BestNetworkAttemptNum = bestNetworkAttempt;
            buildProgress.BestNetworkAttemptEpochNum = bestNetworkAttemptEpoch;
        }
        //Raise notification event (after the possible best-network update above)
        NetworkBuildProgressChanged?.Invoke(buildProgress);
        //Process instructions
        if (instructions.StopProcess)
        {
            break;
        }
        else if (instructions.StopCurrentAttempt)
        {
            //No more attempts available -> stop the whole build
            if (!trainer.NextAttempt())
            {
                break;
            }
        }
    } //while (iteration)
    //Create statistics of the best network weights.
    //NOTE(review): assumes the trainer performed at least one iteration; otherwise bestNetwork is null here - confirm trainer contract.
    bestNetwork.NetworkWeightsStat = bestNetwork.Network.ComputeWeightsStat();
    return (bestNetwork);
}