/// <summary>
/// Creates the readout unit's non-recurrent network together with the trainer bound to it,
/// according to the readout unit settings. The network weights are randomized before returning.
/// </summary>
/// <param name="settings">Readout unit configuration (network type and network settings)</param>
/// <param name="trainingPredictorsCollection">Collection of the predictors for training. The length of the first predictors vector determines the number of network inputs.</param>
/// <param name="trainingIdealOutputsCollection">Collection of ideal outputs for training</param>
/// <param name="rand">Random object used to randomize the network weights (also passed to the trainers accepting it)</param>
/// <param name="net">Created network</param>
/// <param name="trainer">Created trainer operating on <paramref name="net"/></param>
/// <exception cref="ArgumentNullException"><paramref name="trainingPredictorsCollection"/> is null</exception>
/// <exception cref="ArgumentException">The training predictors collection is empty, or the feed forward trainer configuration is missing or of an unsupported type</exception>
private static void CreateNetAndTreainer(ReadoutLayerSettings.ReadoutUnitSettings settings,
                                         List<double[]> trainingPredictorsCollection,
                                         List<double[]> trainingIdealOutputsCollection,
                                         Random rand,
                                         out INonRecurrentNetwork net,
                                         out INonRecurrentNetworkTrainer trainer
                                         )
{
    //The number of network inputs is taken from the first predictors vector, so at least one is required
    if (trainingPredictorsCollection == null)
    {
        throw new ArgumentNullException(nameof(trainingPredictorsCollection));
    }
    if (trainingPredictorsCollection.Count == 0)
    {
        throw new ArgumentException("Training predictors collection is empty. Unable to determine the number of network inputs.", nameof(trainingPredictorsCollection));
    }
    int numOfInputs = trainingPredictorsCollection[0].Length;
    if (settings.NetType == ReadoutLayerSettings.ReadoutUnitSettings.ReadoutUnitNetworkType.FF)
    {
        FeedForwardNetworkSettings netCfg = (FeedForwardNetworkSettings)settings.NetSettings;
        FeedForwardNetwork ffn = new FeedForwardNetwork(numOfInputs, 1, netCfg);
        net = ffn;
        //The trainer is selected by the runtime type of the trainer configuration
        switch (netCfg.TrainerCfg)
        {
            case QRDRegrTrainerSettings qrdCfg:
                trainer = new QRDRegrTrainer(ffn, trainingPredictorsCollection, trainingIdealOutputsCollection, qrdCfg, rand);
                break;
            case RidgeRegrTrainerSettings ridgeCfg:
                trainer = new RidgeRegrTrainer(ffn, trainingPredictorsCollection, trainingIdealOutputsCollection, ridgeCfg, rand);
                break;
            case ElasticRegrTrainerSettings elasticCfg:
                trainer = new ElasticRegrTrainer(ffn, trainingPredictorsCollection, trainingIdealOutputsCollection, elasticCfg);
                break;
            case RPropTrainerSettings rpropCfg:
                trainer = new RPropTrainer(ffn, trainingPredictorsCollection, trainingIdealOutputsCollection, rpropCfg, rand);
                break;
            case null:
                //Previously this case crashed with NullReferenceException on GetType()
                throw new ArgumentException("Trainer configuration is not specified");
            default:
                throw new ArgumentException($"Unknown trainer {netCfg.TrainerCfg}");
        }
    }
    else
    {
        //Any other network type is handled as a parallel perceptron trained by the p-delta rule
        ParallelPerceptronSettings netCfg = (ParallelPerceptronSettings)settings.NetSettings;
        ParallelPerceptron ppn = new ParallelPerceptron(numOfInputs, netCfg);
        net = ppn;
        trainer = new PDeltaRuleTrainer(ppn, trainingPredictorsCollection, trainingIdealOutputsCollection, netCfg.PDeltaRuleTrainerCfg, rand);
    }
    //Weights are randomized after the trainer has been created, preserving the original order
    net.RandomizeWeights(rand);
}
/// <summary>
/// Builds the network of a readout unit and the trainer that will train it, then randomizes the network weights.
/// </summary>
/// <param name="settings">Readout unit settings selecting the network type and holding its configuration</param>
/// <param name="trainingPredictorsCollection">Collection of the predictors for training. The length of the first vector is used as the number of network inputs.</param>
/// <param name="trainingIdealOutputsCollection">Collection of ideal outputs for training</param>
/// <param name="rand">Random object used for weight randomization (also handed to the linear regression trainer)</param>
/// <param name="net">The created network</param>
/// <param name="trainer">The created trainer operating on <paramref name="net"/></param>
/// <exception cref="ArgumentException">The regression method of the feed forward network is not supported</exception>
private static void CreateNetAndTreainer(ReadoutLayerSettings.ReadoutUnitSettings settings,
                                         List<double[]> trainingPredictorsCollection,
                                         List<double[]> trainingIdealOutputsCollection,
                                         Random rand,
                                         out INonRecurrentNetwork net,
                                         out INonRecurrentNetworkTrainer trainer
                                         )
{
    bool isFeedForward = settings.NetType == ReadoutLayerSettings.ReadoutUnitSettings.ReadoutUnitNetworkType.FF;
    if (!isFeedForward)
    {
        //Non feed forward readout: parallel perceptron with the p-delta rule trainer
        ParallelPerceptronSettings perceptronCfg = (ParallelPerceptronSettings)settings.NetSettings;
        ParallelPerceptron perceptron = new ParallelPerceptron(trainingPredictorsCollection[0].Length, perceptronCfg);
        net = perceptron;
        trainer = new PDeltaRuleTrainer(perceptron,
                                        trainingPredictorsCollection,
                                        trainingIdealOutputsCollection,
                                        perceptronCfg.PDeltaRuleTrainerCfg
                                        );
    }
    else
    {
        //Feed forward readout with a single output; the trainer follows the configured regression method
        FeedForwardNetworkSettings ffCfg = (FeedForwardNetworkSettings)settings.NetSettings;
        FeedForwardNetwork feedForward = new FeedForwardNetwork(trainingPredictorsCollection[0].Length, 1, ffCfg);
        net = feedForward;
        switch (ffCfg.RegressionMethod)
        {
            case FeedForwardNetworkSettings.TrainingMethodType.Linear:
                trainer = new LinRegrTrainer(feedForward,
                                             trainingPredictorsCollection,
                                             trainingIdealOutputsCollection,
                                             settings.RegressionAttemptEpochs,
                                             rand,
                                             ffCfg.LinRegrTrainerCfg
                                             );
                break;
            case FeedForwardNetworkSettings.TrainingMethodType.Resilient:
                trainer = new RPropTrainer(feedForward,
                                           trainingPredictorsCollection,
                                           trainingIdealOutputsCollection,
                                           ffCfg.RPropTrainerCfg
                                           );
                break;
            default:
                throw new ArgumentException($"Not supported regression method {ffCfg.RegressionMethod}");
        }
    }
    //Randomization happens only after both the network and its trainer exist
    net.RandomizeWeights(rand);
}
/// <summary>
/// Prepares trained readout unit for specified output field and task.
/// </summary>
/// <remarks>
/// Performs up to RegressionAttempts regression attempts. Each attempt creates a new network and trainer and runs
/// up to RegressionAttemptEpochs training iterations. After every iteration the training (and, when testing data
/// is available, testing) error statistics are computed, and the current unit is compared against the best unit
/// found so far. The comparison is made either by the external controller (when specified) or by IsBetter.
/// </remarks>
/// <param name="taskType">Type of the task</param>
/// <param name="readoutUnitIdx">Index of the readout unit (informative only)</param>
/// <param name="foldNum">Current fold number</param>
/// <param name="numOfFolds">Total number of the folds</param>
/// <param name="refBinDistr">Reference bin distribution (if task type is Classification)</param>
/// <param name="trainingPredictorsCollection">Collection of the predictors for training</param>
/// <param name="trainingIdealOutputsCollection">Collection of ideal outputs for training. Note that the double array always has only one member.</param>
/// <param name="testingPredictorsCollection">Collection of the predictors for testing (may be null or empty)</param>
/// <param name="testingIdealOutputsCollection">Collection of ideal outputs for testing. Note that the double array always has only one member.</param>
/// <param name="rand">Random object to be used</param>
/// <param name="readoutUnitSettings">Readout unit configuration parameters</param>
/// <param name="controller">Optional regression controller. When specified it decides whether the current unit is better and whether to stop the current attempt or the whole regression.</param>
/// <param name="controllerUserObject">An user object to be passed to controller</param>
/// <returns>Prepared readout unit (the best one found) with computed output weights statistics</returns>
/// <exception cref="ArgumentNullException"><paramref name="readoutUnitSettings"/> is null</exception>
/// <exception cref="ArgumentException"><paramref name="readoutUnitSettings"/> specifies fewer than one regression attempt or fewer than one epoch per attempt</exception>
public static ReadoutUnit CreateTrained(CommonEnums.TaskType taskType,
                                        int readoutUnitIdx,
                                        int foldNum,
                                        int numOfFolds,
                                        BinDistribution refBinDistr,
                                        List<double[]> trainingPredictorsCollection,
                                        List<double[]> trainingIdealOutputsCollection,
                                        List<double[]> testingPredictorsCollection,
                                        List<double[]> testingIdealOutputsCollection,
                                        Random rand,
                                        ReadoutLayerSettings.ReadoutUnitSettings readoutUnitSettings,
                                        RegressionCallbackDelegate controller = null,
                                        Object controllerUserObject = null
                                        )
{
    if (readoutUnitSettings == null)
    {
        throw new ArgumentNullException(nameof(readoutUnitSettings));
    }
    //Without at least one attempt and one epoch no readout unit would ever be evaluated, bestReadoutUnit would
    //stay null and the final weights statistics computation would fail with NullReferenceException
    if (readoutUnitSettings.RegressionAttempts < 1 || readoutUnitSettings.RegressionAttemptEpochs < 1)
    {
        throw new ArgumentException($"Readout unit settings must specify at least one regression attempt and one epoch per attempt (attempts = {readoutUnitSettings.RegressionAttempts}, epochs = {readoutUnitSettings.RegressionAttemptEpochs}).", nameof(readoutUnitSettings));
    }
    ReadoutUnit bestReadoutUnit = null;
    //Regression attempts
    bool stopRegression = false;
    for (int regrAttemptNumber = 1; regrAttemptNumber <= readoutUnitSettings.RegressionAttempts; regrAttemptNumber++)
    {
        //Create network and trainer (a new network for every attempt)
        CreateNetAndTreainer(readoutUnitSettings,
                             trainingPredictorsCollection,
                             trainingIdealOutputsCollection,
                             rand,
                             out INonRecurrentNetwork net,
                             out INonRecurrentNetworkTrainer trainer
                             );
        //Iterate training cycles
        for (int epoch = 1; epoch <= readoutUnitSettings.RegressionAttemptEpochs; epoch++)
        {
            trainer.Iteration();
            List<double[]> testingComputedOutputsCollection = null;
            //Compute current error statistics after training iteration
            ReadoutUnit currReadoutUnit = new ReadoutUnit();
            currReadoutUnit.Network = net;
            currReadoutUnit.TrainingErrorStat = net.ComputeBatchErrorStat(trainingPredictorsCollection, trainingIdealOutputsCollection, out List<double[]> trainingComputedOutputsCollection);
            if (taskType == CommonEnums.TaskType.Classification)
            {
                currReadoutUnit.TrainingBinErrorStat = new BinErrStat(refBinDistr, trainingComputedOutputsCollection, trainingIdealOutputsCollection);
                currReadoutUnit.CombinedBinaryError = currReadoutUnit.TrainingBinErrorStat.TotalErrStat.Sum;
            }
            currReadoutUnit.CombinedPrecisionError = currReadoutUnit.TrainingErrorStat.ArithAvg;
            if (testingPredictorsCollection != null && testingPredictorsCollection.Count > 0)
            {
                //Combined errors take the larger (worse) of the training and testing values
                currReadoutUnit.TestingErrorStat = net.ComputeBatchErrorStat(testingPredictorsCollection, testingIdealOutputsCollection, out testingComputedOutputsCollection);
                currReadoutUnit.CombinedPrecisionError = Math.Max(currReadoutUnit.CombinedPrecisionError, currReadoutUnit.TestingErrorStat.ArithAvg);
                if (taskType == CommonEnums.TaskType.Classification)
                {
                    currReadoutUnit.TestingBinErrorStat = new BinErrStat(refBinDistr, testingComputedOutputsCollection, testingIdealOutputsCollection);
                    currReadoutUnit.CombinedBinaryError = Math.Max(currReadoutUnit.CombinedBinaryError, currReadoutUnit.TestingBinErrorStat.TotalErrStat.Sum);
                }
            }
            //Current results processing
            bool better = false, stopTrainingCycle = false;
            //Result first initialization
            if (bestReadoutUnit == null)
            {
                //Adopt current regression results
                bestReadoutUnit = currReadoutUnit.DeepClone();
            }
            //Perform call back if it is defined
            if (controller != null)
            {
                //Evaluation of the improvement is driven externally
                RegressionControlInArgs cbIn = new RegressionControlInArgs
                {
                    TaskType = taskType,
                    ReadoutUnitIdx = readoutUnitIdx,
                    OutputFieldName = readoutUnitSettings.Name,
                    FoldNum = foldNum,
                    NumOfFolds = numOfFolds,
                    RegrAttemptNumber = regrAttemptNumber,
                    RegrMaxAttempts = readoutUnitSettings.RegressionAttempts,
                    Epoch = epoch,
                    MaxEpochs = readoutUnitSettings.RegressionAttemptEpochs,
                    TrainingPredictorsCollection = trainingPredictorsCollection,
                    TrainingIdealOutputsCollection = trainingIdealOutputsCollection,
                    TrainingComputedOutputsCollection = trainingComputedOutputsCollection,
                    TestingPredictorsCollection = testingPredictorsCollection,
                    TestingIdealOutputsCollection = testingIdealOutputsCollection,
                    TestingComputedOutputsCollection = testingComputedOutputsCollection,
                    CurrReadoutUnit = currReadoutUnit,
                    BestReadoutUnit = bestReadoutUnit,
                    UserObject = controllerUserObject
                };
                //Call external controller
                RegressionControlOutArgs cbOut = controller(cbIn);
                //Pick up results
                better = cbOut.CurrentIsBetter;
                stopTrainingCycle = cbOut.StopCurrentAttempt;
                stopRegression = cbOut.StopRegression;
            }
            else
            {
                //Default implementation
                better = IsBetter(taskType, currReadoutUnit, bestReadoutUnit);
            }
            //Best?
            if (better)
            {
                //Adopt current regression results
                bestReadoutUnit = currReadoutUnit.DeepClone();
            }
            //Training stop conditions
            if (stopTrainingCycle || stopRegression)
            {
                break;
            }
        }//epoch
        //Regression stop conditions
        if (stopRegression)
        {
            break;
        }
    }//regrAttemptNumber
    //Create statistics of the best network weights
    bestReadoutUnit.OutputWeightsStat = bestReadoutUnit.Network.ComputeWeightsStat();
    return bestReadoutUnit;
}