private static TrainingSessionResult Optimize( SequentialNetwork network, BatchesCollection miniBatches, int epochs, float dropout, [NotNull] WeightsUpdater updater, [CanBeNull] IProgress <BatchProgress> batchProgress, [CanBeNull] IProgress <TrainingProgressEventArgs> trainingProgress, [CanBeNull] ValidationDataset validationDataset, [CanBeNull] TestDataset testDataset, CancellationToken token) { // Setup DateTime startTime = DateTime.Now; List <DatasetEvaluationResult> validationReports = new List <DatasetEvaluationResult>(), testReports = new List <DatasetEvaluationResult>(); TrainingSessionResult PrepareResult(TrainingStopReason reason, int loops) { return(new TrainingSessionResult(reason, loops, DateTime.Now.Subtract(startTime).RoundToSeconds(), validationReports, testReports)); } // Convergence manager for the validation dataset RelativeConvergence convergence = validationDataset == null ? null : new RelativeConvergence(validationDataset.Tolerance, validationDataset.EpochsInterval); // Optional batch monitor BatchProgressMonitor batchMonitor = batchProgress == null ? null : new BatchProgressMonitor(miniBatches.Count, batchProgress); // Create the training batches for (int i = 0; i < epochs; i++) { // Shuffle the training set miniBatches.CrossShuffle(); // Gradient descent over the current batches for (int j = 0; j < miniBatches.BatchesCount; j++) { if (token.IsCancellationRequested) { return(PrepareResult(TrainingStopReason.TrainingCanceled, i)); } network.Backpropagate(miniBatches.Batches[j], dropout, updater); batchMonitor?.NotifyCompletedBatch(miniBatches.Batches[j].X.GetLength(0)); } batchMonitor?.Reset(); // Check for overflows if (!Parallel.For(0, network._Layers.Length, (j, state) => { if (network._Layers[j] is WeightedLayerBase layer && !layer.ValidateWeights()) { state.Break(); } }).IsCompleted)
/// <summary>
/// Verifies that the set of per-sample ids computed from the source matrices matches the set computed
/// from the batches, both right after the batches are created and after they are cross-shuffled
/// </summary>
public void BatchDivisionTest2()
{
    const int samples = 20000, inputSize = 784, outputSize = 10, batchSize = 1547;

    // Sequential
    float[,]
        x = Enumerable.Range(0, samples * inputSize).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(samples, inputSize),
        y = Enumerable.Range(0, samples * outputSize).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(samples, outputSize);
    BatchesCollection batches = BatchesCollection.From((x, y), batchSize);

    // Gathers one id per row (input id XOR output id) across all the current batches
    HashSet<int> CollectBatchIds()
    {
        HashSet<int> ids = new HashSet<int>();
        for (int b = 0; b < batches.BatchesCount; b++)
        {
            int rows = batches.Batches[b].X.GetLength(0);
            for (int r = 0; r < rows; r++)
            {
                ids.Add(GetUid(batches.Batches[b].X, r) ^ GetUid(batches.Batches[b].Y, r));
            }
        }
        return ids;
    }

    // Reference ids, computed row by row from the original matrices
    HashSet<int> expected = new HashSet<int>();
    for (int row = 0; row < samples; row++)
    {
        expected.Add(GetUid(x, row) ^ GetUid(y, row));
    }

    // The batches must contain the same ids as the source dataset
    HashSet<int> beforeShuffle = CollectBatchIds();
    Assert.IsTrue(expected.SetEquals(beforeShuffle));

    // Shuffling must not add, drop or alter any id
    batches.CrossShuffle();
    HashSet<int> afterShuffle = CollectBatchIds();
    Assert.IsTrue(expected.SetEquals(afterShuffle));
}
/// <summary>
/// Trains the input network over the given batches for up to <paramref name="epochs"/> epochs, stopping earlier
/// if the operation is canceled, the network reports a numeric overflow or the validation accuracy converges
/// </summary>
/// <param name="network">The network to train</param>
/// <param name="miniBatches">The training batches, cross-shuffled at the beginning of every epoch</param>
/// <param name="epochs">The maximum number of epochs to run</param>
/// <param name="dropout">The dropout value forwarded to every backpropagation call</param>
/// <param name="updater">The weights updater forwarded to every backpropagation call</param>
/// <param name="batchProgress">An optional progress callback, wrapped in a <see cref="BatchProgressMonitor"/> that is notified after each processed batch</param>
/// <param name="trainingProgress">An optional callback that receives the cost and accuracy evaluated on the training batches after each epoch</param>
/// <param name="validationDataset">An optional dataset evaluated after each epoch, whose accuracy drives the convergence check</param>
/// <param name="testDataset">An optional dataset evaluated after each epoch, with the results reported to its own progress callback, if present</param>
/// <param name="token">The token checked before each batch to cancel the training</param>
/// <returns>A <see cref="TrainingSessionResult"/> with the stop reason, the epoch count, the elapsed time and the collected validation and test reports</returns>
private static TrainingSessionResult Optimize(
    NeuralNetworkBase network,
    BatchesCollection miniBatches,
    int epochs, float dropout,
    [NotNull] WeightsUpdater updater,
    [CanBeNull] IProgress<BatchProgress> batchProgress,
    [CanBeNull] IProgress<TrainingProgressEventArgs> trainingProgress,
    [CanBeNull] ValidationDataset validationDataset,
    [CanBeNull] TestDataset testDataset,
    CancellationToken token)
{
    // Setup (UTC timestamps, so the elapsed time is not skewed by DST or time zone changes during the session)
    DateTime startTime = DateTime.UtcNow;
    List<DatasetEvaluationResult>
        validationReports = new List<DatasetEvaluationResult>(),
        testReports = new List<DatasetEvaluationResult>();

    // NOTE(review): the early exits below pass the zero-based epoch index `i` as the loops count, while
    // the progress callbacks report `i + 1` - confirm which value TrainingSessionResult is meant to hold
    TrainingSessionResult PrepareResult(TrainingStopReason reason, int loops)
    {
        return new TrainingSessionResult(reason, loops, DateTime.UtcNow.Subtract(startTime).RoundToSeconds(), validationReports, testReports);
    }

    // Convergence manager for the validation dataset
    RelativeConvergence convergence = validationDataset == null
        ? null
        : new RelativeConvergence(validationDataset.Tolerance, validationDataset.EpochsInterval);

    // Optional batch monitor
    BatchProgressMonitor batchMonitor = batchProgress == null
        ? null
        : new BatchProgressMonitor(miniBatches.Count, batchProgress);

    // Training loop
    for (int i = 0; i < epochs; i++)
    {
        // Shuffle the training set
        miniBatches.CrossShuffle();

        // Gradient descent over the current batches
        bool canceled = false;
        BackpropagationInProgress = true;
        try
        {
            for (int j = 0; j < miniBatches.BatchesCount; j++)
            {
                if (token.IsCancellationRequested)
                {
                    canceled = true;
                    break;
                }
                network.Backpropagate(miniBatches.Batches[j], dropout, updater);
                batchMonitor?.NotifyCompletedBatch(miniBatches.Batches[j].X.GetLength(0));
            }
        }
        finally
        {
            // Always clear the flag, even if the backpropagation or the progress callback throws
            BackpropagationInProgress = false;
        }

        // The result is prepared only after the flag has been cleared, and the monitor is not reset on cancellation
        if (canceled)
        {
            return PrepareResult(TrainingStopReason.TrainingCanceled, i);
        }
        batchMonitor?.Reset();

        // Check for overflows
        if (network.IsInNumericOverflow)
        {
            return PrepareResult(TrainingStopReason.NumericOverflow, i);
        }

        // Check the training progress
        if (trainingProgress != null)
        {
            (float cost, _, float accuracy) = network.Evaluate(miniBatches);
            trainingProgress.Report(new TrainingProgressEventArgs(i + 1, cost, accuracy));
        }

        // Check the validation dataset (the convergence manager is only created when the dataset is available)
        if (convergence != null)
        {
            (float cost, _, float accuracy) = network.Evaluate(validationDataset.Dataset);
            validationReports.Add(new DatasetEvaluationResult(cost, accuracy));
            convergence.Value = accuracy;
            if (convergence.HasConverged)
            {
                return PrepareResult(TrainingStopReason.EarlyStopping, i);
            }
        }

        // Report progress if necessary
        if (testDataset != null)
        {
            (float cost, _, float accuracy) = network.Evaluate(testDataset.Dataset);
            testReports.Add(new DatasetEvaluationResult(cost, accuracy));
            testDataset.ThreadSafeProgressCallback?.Report(new TrainingProgressEventArgs(i + 1, cost, accuracy));
        }
    }
    return PrepareResult(TrainingStopReason.EpochsCompleted, epochs);
}