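/// <summary>
/// Runs the training loop for the input network: one backpropagation pass per mini-batch,
/// per epoch, with optional batch/training progress reporting, early stopping driven by
/// the validation dataset and per-epoch evaluation on the test dataset.
/// </summary>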
private static TrainingSessionResult Optimize(
    NeuralNetworkBase network,
    BatchesCollection miniBatches,
    int epochs, float dropout,
    [NotNull] WeightsUpdater updater,
    [CanBeNull] IProgress<BatchProgress> batchProgress,
    [CanBeNull] IProgress<TrainingProgressEventArgs> trainingProgress,
    [CanBeNull] ValidationDataset validationDataset,
    [CanBeNull] TestDataset testDataset,
    CancellationToken token)
{
    // Setup
    DateTime startTime = DateTime.Now;
    List<DatasetEvaluationResult>
        validationReports = new List<DatasetEvaluationResult>(),
        testReports = new List<DatasetEvaluationResult>();
    TrainingSessionResult PrepareResult(TrainingStopReason reason, int loops)
    {
        return new TrainingSessionResult(
            reason, loops,
            DateTime.Now.Subtract(startTime).RoundToSeconds(),
            validationReports, testReports);
    }

    // Convergence manager for the validation dataset
    RelativeConvergence convergence = validationDataset == null
        ? null
        : new RelativeConvergence(validationDataset.Tolerance, validationDataset.EpochsInterval);

    // Optional batch monitor
    BatchProgressMonitor batchMonitor = batchProgress == null
        ? null
        : new BatchProgressMonitor(miniBatches.Count, batchProgress);

    // Training loop over the requested number of epochs
    for (int i = 0; i < epochs; i++)
    {
        // Shuffle the training set
        miniBatches.CrossShuffle();

        // Gradient descent over the current batches
        BackpropagationInProgress = true; // Signal that a training pass is running
        for (int j = 0; j < miniBatches.BatchesCount; j++)
        {
            if (token.IsCancellationRequested)
            {
                BackpropagationInProgress = false;
                return PrepareResult(TrainingStopReason.TrainingCanceled, i);
            }
            network.Backpropagate(miniBatches.Batches[j], dropout, updater);
            batchMonitor?.NotifyCompletedBatch(miniBatches.Batches[j].X.GetLength(0));
        }
        BackpropagationInProgress = false;
        batchMonitor?.Reset();

        // Check for overflows
        if (network.IsInNumericOverflow)
        {
            return PrepareResult(TrainingStopReason.NumericOverflow, i);
        }

        // Check the training progress
        if (trainingProgress != null)
        {
            (float cost, _, float accuracy) = network.Evaluate(miniBatches);
            trainingProgress.Report(new TrainingProgressEventArgs(i + 1, cost, accuracy));
        }

        // Check the validation dataset for early stopping
        if (convergence != null)
        {
            (float cost, _, float accuracy) = network.Evaluate(validationDataset.Dataset);
            validationReports.Add(new DatasetEvaluationResult(cost, accuracy));
            convergence.Value = accuracy;
            if (convergence.HasConverged)
            {
                return PrepareResult(TrainingStopReason.EarlyStopping, i);
            }
        }

        // Report the test dataset progress if necessary
        if (testDataset != null)
        {
            (float cost, _, float accuracy) = network.Evaluate(testDataset.Dataset);
            testReports.Add(new DatasetEvaluationResult(cost, accuracy));
            testDataset.ThreadSafeProgressCallback?.Report(new TrainingProgressEventArgs(i + 1, cost, accuracy));
        }
    }
    return PrepareResult(TrainingStopReason.EpochsCompleted, epochs);
}
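// Usage sketch (hypothetical, not part of the source): a minimal call site for Optimize,
// assuming the caller already holds a network, a BatchesCollection and a WeightsUpdater.
// The method name, the epoch count (10) and the dropout rate (0.5f) are arbitrary
// illustration values, not library defaults.
private static TrainingSessionResult TrainWithDefaults(
    NeuralNetworkBase network,
    BatchesCollection miniBatches,
    [NotNull] WeightsUpdater updater,
    CancellationToken token)
{
    // With no validation dataset the convergence manager stays null, so training only
    // stops on cancellation, a numeric overflow, or once all the epochs have completed.
    return Optimize(network, miniBatches, 10, 0.5f, updater,
        batchProgress: null, trainingProgress: null,
        validationDataset: null, testDataset: null, token: token);
}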