Example #1
        private static TrainingSessionResult Optimize(
            SequentialNetwork network,
            BatchesCollection miniBatches,
            int epochs, float dropout,
            [NotNull] WeightsUpdater updater,
            [CanBeNull] IProgress <BatchProgress> batchProgress,
            [CanBeNull] IProgress <TrainingProgressEventArgs> trainingProgress,
            [CanBeNull] ValidationDataset validationDataset,
            [CanBeNull] TestDataset testDataset,
            CancellationToken token)
        {
            // Setup
            DateTime startTime = DateTime.Now;
            List <DatasetEvaluationResult>
                validationReports = new List <DatasetEvaluationResult>(),
                testReports = new List <DatasetEvaluationResult>();

            TrainingSessionResult PrepareResult(TrainingStopReason reason, int loops)
            {
                return(new TrainingSessionResult(reason, loops, DateTime.Now.Subtract(startTime).RoundToSeconds(), validationReports, testReports));
            }

            // Convergence manager for the validation dataset
            RelativeConvergence convergence = validationDataset == null
                ? null
                : new RelativeConvergence(validationDataset.Tolerance, validationDataset.EpochsInterval);

            // Optional batch monitor
            BatchProgressMonitor batchMonitor = batchProgress == null ? null : new BatchProgressMonitor(miniBatches.Count, batchProgress);

            // Training loop over the requested number of epochs
            for (int i = 0; i < epochs; i++)
            {
                // Shuffle the training set
                miniBatches.CrossShuffle();

                // Gradient descent over the current batches
                for (int j = 0; j < miniBatches.BatchesCount; j++)
                {
                    if (token.IsCancellationRequested)
                    {
                        return(PrepareResult(TrainingStopReason.TrainingCanceled, i));
                    }
                    network.Backpropagate(miniBatches.Batches[j], dropout, updater);
                    batchMonitor?.NotifyCompletedBatch(miniBatches.Batches[j].X.GetLength(0));
                }
                batchMonitor?.Reset();

                // Check for overflows
                if (!Parallel.For(0, network._Layers.Length, (j, state) =>
                {
                    if (network._Layers[j] is WeightedLayerBase layer && !layer.ValidateWeights())
                    {
                        state.Break();
                    }
                }).IsCompleted)
                {
                    return(PrepareResult(TrainingStopReason.NumericOverflow, i));
                }
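
The overflow check works because Parallel.For returns a ParallelLoopResult whose IsCompleted flag is false whenever any iteration called state.Break(). Below is a minimal, self-contained sketch of the same pattern; the AllWeightsAreFinite helper and the jagged float[][] layout are assumptions for illustration, not part of the library.

    using System.Linq;
    using System.Threading.Tasks;

    public static class WeightsValidationSketch
    {
        // Returns false as soon as any weights array contains a non-finite value.
        // Parallel.For with state.Break() mirrors the overflow check above:
        // a broken loop leaves ParallelLoopResult.IsCompleted set to false.
        public static bool AllWeightsAreFinite(float[][] layerWeights)
        {
            ParallelLoopResult result = Parallel.For(0, layerWeights.Length, (i, state) =>
            {
                if (layerWeights[i].Any(w => float.IsNaN(w) || float.IsInfinity(w)))
                {
                    state.Break();
                }
            });
            return result.IsCompleted;
        }
    }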
Example #2
        public static TrainingSessionResult TrainNetwork(
            [NotNull] SequentialNetwork network, [NotNull] BatchesCollection batches,
            int epochs, float dropout,
            [NotNull] ITrainingAlgorithmInfo algorithm,
            [CanBeNull] IProgress <BatchProgress> batchProgress,
            [CanBeNull] IProgress <TrainingProgressEventArgs> trainingProgress,
            [CanBeNull] ValidationDataset validationDataset,
            [CanBeNull] TestDataset testDataset,
            CancellationToken token)
        {
            SharedEventsService.TrainingStarting.Raise();
            WeightsUpdater optimizer;

            switch (algorithm)
            {
            /* =================
             * Optimization
             * =================
             * The right optimizer is selected here, and the captured closure for each of them also holds local temporary data, if needed.
             * Since that temporary data is managed, it is automatically collected by the GC, so there is no need for an extra
             * callback to clean up unmanaged resources when the training stops. */
            case MomentumInfo momentum:
                optimizer = WeightsUpdaters.Momentum(momentum, network);
                break;

            case StochasticGradientDescentInfo sgd:
                optimizer = WeightsUpdaters.StochasticGradientDescent(sgd);
                break;

            case AdaGradInfo adagrad:
                optimizer = WeightsUpdaters.AdaGrad(adagrad, network);
                break;

            case AdaDeltaInfo adadelta:
                optimizer = WeightsUpdaters.AdaDelta(adadelta, network);
                break;

            case AdamInfo adam:
                optimizer = WeightsUpdaters.Adam(adam, network);
                break;

            case AdaMaxInfo adamax:
                optimizer = WeightsUpdaters.AdaMax(adamax, network);
                break;

            case RMSPropInfo rms:
                optimizer = WeightsUpdaters.RMSProp(rms, network);
                break;

            default:
                throw new ArgumentException("The input training algorithm type is not supported");
            }
            return(Optimize(network, batches, epochs, dropout, optimizer, batchProgress, trainingProgress, validationDataset, testDataset, token));
        }
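
The switch above only maps each ITrainingAlgorithmInfo subtype to a closure-based WeightsUpdater. The sketch below reproduces that selection pattern with hypothetical stand-in types (AlgorithmInfoSketch, SgdInfoSketch, MomentumInfoSketch, UpdateRule) rather than the library's real API, to show why the captured state needs no explicit cleanup:

    using System;

    // Hypothetical stand-ins for the algorithm info types used above
    public abstract class AlgorithmInfoSketch { }
    public sealed class SgdInfoSketch : AlgorithmInfoSketch { public float LearningRate = 0.01f; }
    public sealed class MomentumInfoSketch : AlgorithmInfoSketch { public float LearningRate = 0.01f, Momentum = 0.9f; }

    public static class OptimizerSelectionSketch
    {
        // Maps a gradient and the current weight to the updated weight,
        // standing in for the closure-based WeightsUpdater delegates
        public delegate float UpdateRule(float gradient, float weight);

        public static UpdateRule Create(AlgorithmInfoSketch info)
        {
            switch (info)
            {
            // Plain SGD keeps no extra state in the closure
            case SgdInfoSketch sgd:
                return (g, w) => w - sgd.LearningRate * g;

            // Momentum: the captured 'velocity' variable is the managed temporary
            // data mentioned in the comment above; the GC reclaims it once the
            // returned delegate is no longer referenced
            case MomentumInfoSketch momentum:
            {
                float velocity = 0;
                return (g, w) =>
                {
                    velocity = momentum.Momentum * velocity - momentum.LearningRate * g;
                    return w + velocity;
                };
            }

            default:
                throw new ArgumentException("The input training algorithm type is not supported");
            }
        }
    }

In the real method most cases also receive the network instance, so the updater can size its per-layer buffers, but the lifetime argument is the same.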
Example #3
        public void BatchDivisionTest2()
        {
            // Sequential
            float[,]
                x = Enumerable.Range(0, 20000 * 784).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(20000, 784),
                y = Enumerable.Range(0, 20000 * 10).Select(_ => ThreadSafeRandom.NextUniform(100)).ToArray().AsSpan().AsMatrix(20000, 10);
            BatchesCollection batches = BatchesCollection.From((x, y), 1547);
            HashSet <int>     set1    = new HashSet <int>();

            for (int i = 0; i < 20000; i++)
            {
                set1.Add(GetUid(x, i) ^ GetUid(y, i));
            }
            HashSet <int> set2 = new HashSet <int>();

            for (int i = 0; i < batches.BatchesCount; i++)
            {
                int h = batches.Batches[i].X.GetLength(0);
                for (int j = 0; j < h; j++)
                {
                    set2.Add(GetUid(batches.Batches[i].X, j) ^ GetUid(batches.Batches[i].Y, j));
                }
            }
            Assert.IsTrue(set1.OrderBy(h => h).SequenceEqual(set2.OrderBy(h => h)));
            batches.CrossShuffle();
            HashSet <int> set3 = new HashSet <int>();

            for (int i = 0; i < batches.BatchesCount; i++)
            {
                int h = batches.Batches[i].X.GetLength(0);
                for (int j = 0; j < h; j++)
                {
                    set3.Add(GetUid(batches.Batches[i].X, j) ^ GetUid(batches.Batches[i].Y, j));
                }
            }
            Assert.IsTrue(set1.OrderBy(h => h).SequenceEqual(set3.OrderBy(h => h)));
        }
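
The test identifies every (x, y) sample pair with a single integer so that the original matrices and their batch partition can be compared as sets, both before and after CrossShuffle. The GetUid helper is not shown in the snippet; the sketch below uses a hypothetical row hash purely to illustrate the idea.

    using System.Collections.Generic;

    public static class RowHashSketch
    {
        // Hypothetical row hash, standing in for the test's GetUid helper:
        // folds every value of row 'i' of a matrix into a single int.
        public static int RowHash(float[,] m, int i)
        {
            int hash = 17, w = m.GetLength(1);
            for (int j = 0; j < w; j++)
            {
                hash = unchecked(hash * 31 + m[i, j].GetHashCode());
            }
            return hash;
        }

        // One identifier per sample: XOR of the input row hash and the label row
        // hash, matching the set entries built in the test above.
        public static List<int> SampleIds(float[,] x, float[,] y)
        {
            int n = x.GetLength(0);
            List<int> ids = new List<int>(n);
            for (int i = 0; i < n; i++)
            {
                ids.Add(RowHash(x, i) ^ RowHash(y, i));
            }
            ids.Sort();
            return ids;
        }
    }

Since the test stores the identifiers in a HashSet, hash collisions collapse into one entry; ordering both sets before SequenceEqual then makes the comparison independent of row order.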
Example #4
        private static TrainingSessionResult Optimize(
            NeuralNetworkBase network,
            BatchesCollection miniBatches,
            int epochs, float dropout,
            [NotNull] WeightsUpdater updater,
            [CanBeNull] IProgress <BatchProgress> batchProgress,
            [CanBeNull] IProgress <TrainingProgressEventArgs> trainingProgress,
            [CanBeNull] ValidationDataset validationDataset,
            [CanBeNull] TestDataset testDataset,
            CancellationToken token)
        {
            // Setup
            DateTime startTime = DateTime.Now;
            List <DatasetEvaluationResult>
                validationReports = new List <DatasetEvaluationResult>(),
                testReports = new List <DatasetEvaluationResult>();

            TrainingSessionResult PrepareResult(TrainingStopReason reason, int loops)
            {
                return(new TrainingSessionResult(reason, loops, DateTime.Now.Subtract(startTime).RoundToSeconds(), validationReports, testReports));
            }

            // Convergence manager for the validation dataset
            RelativeConvergence convergence = validationDataset == null
                ? null
                : new RelativeConvergence(validationDataset.Tolerance, validationDataset.EpochsInterval);

            // Optional batch monitor
            BatchProgressMonitor batchMonitor = batchProgress == null ? null : new BatchProgressMonitor(miniBatches.Count, batchProgress);

            // Training loop over the requested number of epochs
            for (int i = 0; i < epochs; i++)
            {
                // Shuffle the training set
                miniBatches.CrossShuffle();

                // Gradient descent over the current batches
                BackpropagationInProgress = true;
                for (int j = 0; j < miniBatches.BatchesCount; j++)
                {
                    if (token.IsCancellationRequested)
                    {
                        BackpropagationInProgress = false;
                        return(PrepareResult(TrainingStopReason.TrainingCanceled, i));
                    }
                    network.Backpropagate(miniBatches.Batches[j], dropout, updater);
                    batchMonitor?.NotifyCompletedBatch(miniBatches.Batches[j].X.GetLength(0));
                }
                BackpropagationInProgress = false;
                batchMonitor?.Reset();
                if (network.IsInNumericOverflow)
                {
                    return(PrepareResult(TrainingStopReason.NumericOverflow, i));
                }

                // Evaluate on the training data and report the progress
                if (trainingProgress != null)
                {
                    (float cost, _, float accuracy) = network.Evaluate(miniBatches);
                    trainingProgress.Report(new TrainingProgressEventArgs(i + 1, cost, accuracy));
                }

                // Check the validation dataset
                if (convergence != null)
                {
                    (float cost, _, float accuracy) = network.Evaluate(validationDataset.Dataset);
                    validationReports.Add(new DatasetEvaluationResult(cost, accuracy));
                    convergence.Value = accuracy;
                    if (convergence.HasConverged)
                    {
                        return(PrepareResult(TrainingStopReason.EarlyStopping, i));
                    }
                }

                // Evaluate the test dataset and report progress, if available
                if (testDataset != null)
                {
                    (float cost, _, float accuracy) = network.Evaluate(testDataset.Dataset);
                    testReports.Add(new DatasetEvaluationResult(cost, accuracy));
                    testDataset.ThreadSafeProgressCallback?.Report(new TrainingProgressEventArgs(i + 1, cost, accuracy));
                }
            }
            return(PrepareResult(TrainingStopReason.EpochsCompleted, epochs));
        }
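
Early stopping above is driven by RelativeConvergence, built from the validation dataset's Tolerance and EpochsInterval and fed one accuracy value per epoch. The library's actual implementation is not shown here; the following is only a rough sketch of what such a relative-convergence check over a sliding window of accuracies could look like:

    using System;
    using System.Collections.Generic;

    // Rough sketch of an early-stopping check driven by a tolerance and an
    // epochs interval; hypothetical, not the library's RelativeConvergence type.
    public sealed class EarlyStoppingSketch
    {
        private readonly float _tolerance;
        private readonly int _interval;
        private readonly Queue<float> _window = new Queue<float>();

        public EarlyStoppingSketch(float tolerance, int epochsInterval)
        {
            _tolerance = tolerance;
            _interval = epochsInterval;
        }

        // Feed one validation accuracy per epoch; returns true once the relative
        // change over the last 'epochsInterval' epochs falls below the tolerance.
        public bool HasConverged(float accuracy)
        {
            _window.Enqueue(accuracy);
            if (_window.Count <= _interval) return false;
            float oldest = _window.Dequeue();
            float relativeChange = Math.Abs(accuracy - oldest) / Math.Max(Math.Abs(oldest), 1e-8f);
            return relativeChange < _tolerance;
        }
    }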