示例#1
0
        /// <summary>
        /// Validates the training arguments and forwards them to <c>NetworkTrainer.TrainNetwork</c>.
        /// </summary>
        /// <param name="network">The network to train; must be a <c>SequentialNetwork</c> instance.</param>
        /// <param name="dataset">The training dataset; must be a <c>BatchesCollection</c> instance.</param>
        /// <param name="algorithm">The training algorithm to use.</param>
        /// <param name="epochs">The number of training epochs, at least 1.</param>
        /// <param name="dropout">The dropout probability, in the [0, 1) range.</param>
        /// <param name="batchProgress">Optional progress reporter forwarded to the trainer.</param>
        /// <param name="trainingProgress">Optional progress reporter forwarded to the trainer.</param>
        /// <param name="validationDataset">Optional validation dataset; when set, it must be a <c>ValidationDataset</c> instance.</param>
        /// <param name="testDataset">Optional test dataset; when set, it must be a <c>TestDataset</c> instance.</param>
        /// <param name="token">Cancellation token forwarded to the trainer.</param>
        /// <returns>The result returned by <c>NetworkTrainer.TrainNetwork</c>.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="network"/>, <paramref name="dataset"/> or <paramref name="algorithm"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="epochs"/> or <paramref name="dropout"/> is out of range.</exception>
        /// <exception cref="ArgumentException">Thrown when an input instance is not of the expected concrete type.</exception>
        private static TrainingSessionResult TrainNetworkCore(
            [NotNull] INeuralNetwork network,
            [NotNull] ITrainingDataset dataset,
            [NotNull] ITrainingAlgorithmInfo algorithm,
            int epochs, float dropout,
            [CanBeNull] IProgress <BatchProgress> batchProgress,
            [CanBeNull] IProgress <TrainingProgressEventArgs> trainingProgress,
            [CanBeNull] IValidationDataset validationDataset,
            [CanBeNull] ITestDataset testDataset,
            CancellationToken token)
        {
            // Preliminary checks
            if (network == null)
            {
                throw new ArgumentNullException(nameof(network));
            }
            if (dataset == null)
            {
                throw new ArgumentNullException(nameof(dataset));
            }
            if (algorithm == null)
            {
                throw new ArgumentNullException(nameof(algorithm));
            }
            if (epochs < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(epochs), "The number of epochs must be at least equal to 1");
            }

            // NaN compares false against every bound, so it must be rejected explicitly
            if (float.IsNaN(dropout) || dropout < 0 || dropout >= 1)
            {
                throw new ArgumentOutOfRangeException(nameof(dropout), "The dropout probability is invalid");
            }

            // Reject foreign implementations instead of letting the 'as' casts below silently turn them into null
            if (validationDataset != null && !(validationDataset is ValidationDataset))
            {
                throw new ArgumentException("The validation dataset instance isn't valid", nameof(validationDataset));
            }
            if (testDataset != null && !(testDataset is TestDataset))
            {
                throw new ArgumentException("The test dataset instance isn't valid", nameof(testDataset));
            }

            // Start the training
            return NetworkTrainer.TrainNetwork(
                network as SequentialNetwork ?? throw new ArgumentException("The input network instance isn't valid", nameof(network)),
                dataset as BatchesCollection ?? throw new ArgumentException("The input dataset instance isn't valid", nameof(dataset)),
                epochs, dropout, algorithm, batchProgress, trainingProgress,
                validationDataset as ValidationDataset,
                testDataset as TestDataset,
                token);
        }
示例#2
0
        /// <summary>
        /// Raises the training-starting event, selects the weights updater matching <paramref name="algorithm"/>
        /// and runs <c>Optimize</c> with it.
        /// </summary>
        /// <param name="network">The network to train.</param>
        /// <param name="batches">The training batches.</param>
        /// <param name="epochs">The number of training epochs.</param>
        /// <param name="dropout">The dropout probability.</param>
        /// <param name="algorithm">The training algorithm info used to pick the optimizer.</param>
        /// <param name="batchProgress">Optional progress reporter forwarded to <c>Optimize</c>.</param>
        /// <param name="trainingProgress">Optional progress reporter forwarded to <c>Optimize</c>.</param>
        /// <param name="validationDataset">Optional validation dataset forwarded to <c>Optimize</c>.</param>
        /// <param name="testDataset">Optional test dataset forwarded to <c>Optimize</c>.</param>
        /// <param name="token">Cancellation token forwarded to <c>Optimize</c>.</param>
        /// <returns>The result returned by <c>Optimize</c>.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="algorithm"/> is null.</exception>
        /// <exception cref="ArgumentException">Thrown when the algorithm type is not supported.</exception>
        public static TrainingSessionResult TrainNetwork(
            [NotNull] SequentialNetwork network, [NotNull] BatchesCollection batches,
            int epochs, float dropout,
            [NotNull] ITrainingAlgorithmInfo algorithm,
            [CanBeNull] IProgress <BatchProgress> batchProgress,
            [CanBeNull] IProgress <TrainingProgressEventArgs> trainingProgress,
            [CanBeNull] ValidationDataset validationDataset,
            [CanBeNull] TestDataset testDataset,
            CancellationToken token)
        {
            // Check before raising the event, so that a null algorithm doesn't announce a training that never runs
            if (algorithm == null)
            {
                throw new ArgumentNullException(nameof(algorithm));
            }

            SharedEventsService.TrainingStarting.Raise();
            WeightsUpdater optimizer;

            switch (algorithm)
            {
            /* =================
             * Optimization
             * =================
             * The right optimizer is selected here, and the captured closure for each of them also contains local temporary data, if needed.
             * In this case, the temporary data is managed, so that it will automatically be disposed by the GC and there won't be the need to use
             * another callback when the training stops to handle the cleanup of unmanaged resources. */
            case MomentumInfo momentum:
                optimizer = WeightsUpdaters.Momentum(momentum, network);
                break;

            case StochasticGradientDescentInfo sgd:
                optimizer = WeightsUpdaters.StochasticGradientDescent(sgd);
                break;

            case AdaGradInfo adagrad:
                optimizer = WeightsUpdaters.AdaGrad(adagrad, network);
                break;

            case AdaDeltaInfo adadelta:
                optimizer = WeightsUpdaters.AdaDelta(adadelta, network);
                break;

            case AdamInfo adam:
                optimizer = WeightsUpdaters.Adam(adam, network);
                break;

            case AdaMaxInfo adamax:
                optimizer = WeightsUpdaters.AdaMax(adamax, network);
                break;

            case RMSPropInfo rms:
                optimizer = WeightsUpdaters.RMSProp(rms, network);
                break;

            default:
                throw new ArgumentException("The input training algorithm type is not supported", nameof(algorithm));
            }
            return Optimize(network, batches, epochs, dropout, optimizer, batchProgress, trainingProgress, validationDataset, testDataset, token);
        }
示例#3
0
 /// <summary>
 /// Trains a network synchronously on the calling thread.
 /// The optional callbacks are converted with <c>AsIProgress</c> before delegating to <c>TrainNetworkCore</c>.
 /// </summary>
 /// <param name="network">The network to train.</param>
 /// <param name="dataset">The training dataset.</param>
 /// <param name="algorithm">The training algorithm to use.</param>
 /// <param name="epochs">The number of training epochs.</param>
 /// <param name="dropout">The dropout probability (defaults to 0).</param>
 /// <param name="batchCallback">Optional callback invoked with batch progress.</param>
 /// <param name="trainingCallback">Optional callback invoked with training progress.</param>
 /// <param name="validationDataset">Optional validation dataset.</param>
 /// <param name="testDataset">Optional test dataset.</param>
 /// <param name="token">Cancellation token forwarded to the training.</param>
 /// <returns>The result returned by <c>TrainNetworkCore</c>.</returns>
 public static TrainingSessionResult TrainNetwork(
     [NotNull] INeuralNetwork network,
     [NotNull] ITrainingDataset dataset,
     [NotNull] ITrainingAlgorithmInfo algorithm,
     int epochs, float dropout = 0,
     [CanBeNull] Action <BatchProgress> batchCallback = null,
     [CanBeNull] Action <TrainingProgressEventArgs> trainingCallback = null,
     [CanBeNull] IValidationDataset validationDataset = null,
     [CanBeNull] ITestDataset testDataset             = null,
     CancellationToken token = default)
 {
     // Wrap the (possibly null) callbacks into progress reporters first
     IProgress<BatchProgress> batchReporter = batchCallback.AsIProgress();
     IProgress<TrainingProgressEventArgs> trainingReporter = trainingCallback.AsIProgress();

     return TrainNetworkCore(
         network, dataset, algorithm,
         epochs, dropout,
         batchReporter, trainingReporter,
         validationDataset, testDataset,
         token);
 }
示例#4
0
        /// <summary>
        /// Validates the training arguments and runs a training session through <c>NetworkTrainer.TrainNetwork</c>.
        /// Only one session is allowed at a time, tracked through <c>TrainingInProgress</c>.
        /// </summary>
        /// <param name="network">The network to train; must be a <c>NeuralNetworkBase</c> instance whose input/output sizes match <paramref name="dataset"/>.</param>
        /// <param name="dataset">The training dataset; must be a <c>BatchesCollection</c> instance.</param>
        /// <param name="algorithm">The training algorithm to use.</param>
        /// <param name="epochs">The number of training epochs, at least 1.</param>
        /// <param name="dropout">The dropout probability, in the [0, 1) range.</param>
        /// <param name="batchProgress">Optional progress reporter forwarded to the trainer.</param>
        /// <param name="trainingProgress">Optional progress reporter forwarded to the trainer.</param>
        /// <param name="validationDataset">Optional validation dataset; when set, it must be a <c>ValidationDataset</c> with the same features as <paramref name="dataset"/>.</param>
        /// <param name="testDataset">Optional test dataset; when set, it must be a <c>TestDataset</c> with the same features as <paramref name="dataset"/>.</param>
        /// <param name="token">Cancellation token forwarded to the trainer.</param>
        /// <returns>The result returned by <c>NetworkTrainer.TrainNetwork</c>.</returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="network"/>, <paramref name="dataset"/> or <paramref name="algorithm"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="epochs"/> or <paramref name="dropout"/> is out of range.</exception>
        /// <exception cref="ArgumentException">Thrown when the datasets and network don't match, or an instance is not of the expected concrete type.</exception>
        /// <exception cref="InvalidOperationException">Thrown when another training session is already in progress.</exception>
        private static TrainingSessionResult TrainNetworkCore(
            [NotNull] INeuralNetwork network,
            [NotNull] ITrainingDataset dataset,
            [NotNull] ITrainingAlgorithmInfo algorithm,
            int epochs, float dropout,
            [CanBeNull] IProgress <BatchProgress> batchProgress,
            [CanBeNull] IProgress <TrainingProgressEventArgs> trainingProgress,
            [CanBeNull] IValidationDataset validationDataset,
            [CanBeNull] ITestDataset testDataset,
            CancellationToken token)
        {
            // Preliminary checks
            if (network == null)
            {
                throw new ArgumentNullException(nameof(network));
            }
            if (dataset == null)
            {
                throw new ArgumentNullException(nameof(dataset));
            }
            if (algorithm == null)
            {
                throw new ArgumentNullException(nameof(algorithm));
            }
            if (epochs < 1)
            {
                throw new ArgumentOutOfRangeException(nameof(epochs), "The number of epochs must be at least equal to 1");
            }

            // NaN compares false against every bound, so it must be rejected explicitly
            if (float.IsNaN(dropout) || dropout < 0 || dropout >= 1)
            {
                throw new ArgumentOutOfRangeException(nameof(dropout), "The dropout probability is invalid");
            }
            if (validationDataset != null && (validationDataset.InputFeatures != dataset.InputFeatures || validationDataset.OutputFeatures != dataset.OutputFeatures))
            {
                throw new ArgumentException("The validation dataset doesn't match the training dataset", nameof(validationDataset));
            }
            if (testDataset != null && (testDataset.InputFeatures != dataset.InputFeatures || testDataset.OutputFeatures != dataset.OutputFeatures))
            {
                throw new ArgumentException("The test dataset doesn't match the training dataset", nameof(testDataset));
            }
            if (dataset.InputFeatures != network.InputInfo.Size || dataset.OutputFeatures != network.OutputInfo.Size)
            {
                throw new ArgumentException("The input dataset doesn't match the number of input and output features for the current network", nameof(dataset));
            }

            // Resolve the concrete instances before claiming the training flag, so an invalid argument can't leave it set
            NeuralNetworkBase networkBase = network as NeuralNetworkBase
                ?? throw new ArgumentException("The input network instance isn't valid", nameof(network));
            BatchesCollection batches = dataset as BatchesCollection
                ?? throw new ArgumentException("The input dataset instance isn't valid", nameof(dataset));

            // Reject foreign implementations instead of letting the 'as' casts below silently turn them into null
            if (validationDataset != null && !(validationDataset is ValidationDataset))
            {
                throw new ArgumentException("The validation dataset instance isn't valid", nameof(validationDataset));
            }
            if (testDataset != null && !(testDataset is TestDataset))
            {
                throw new ArgumentException("The test dataset instance isn't valid", nameof(testDataset));
            }

            // Start the training
            // NOTE(review): this check-then-set is not atomic; two threads entering at once could both pass it
            if (TrainingInProgress)
            {
                throw new InvalidOperationException("Can't train two networks at the same time"); // This would cause problems with cuDNN
            }
            TrainingInProgress = true;
            try
            {
                return NetworkTrainer.TrainNetwork(
                    networkBase,
                    batches,
                    epochs, dropout, algorithm, batchProgress, trainingProgress,
                    validationDataset as ValidationDataset,
                    testDataset as TestDataset,
                    token);
            }
            finally
            {
                // Always release the flag, even when the trainer throws, or every later session would be rejected
                TrainingInProgress = false;
            }
        }
示例#5
0
        /// <summary>
        /// Starts a training session on the thread pool and returns the task that completes with its result.
        /// </summary>
        /// <param name="network">The network to train.</param>
        /// <param name="dataset">The training dataset.</param>
        /// <param name="algorithm">The training algorithm to use.</param>
        /// <param name="epochs">The number of training epochs.</param>
        /// <param name="dropout">The dropout probability (defaults to 0).</param>
        /// <param name="batchCallback">Optional callback invoked with batch progress.</param>
        /// <param name="trainingCallback">Optional callback invoked with training progress.</param>
        /// <param name="validationDataset">Optional validation dataset.</param>
        /// <param name="testDataset">Optional test dataset.</param>
        /// <param name="token">Cancellation token passed both to <c>Task.Run</c> and to the training itself.</param>
        /// <returns>A task producing the result of <c>TrainNetworkCore</c>.</returns>
        public static Task <TrainingSessionResult> TrainNetworkAsync(
            [NotNull] INeuralNetwork network,
            [NotNull] ITrainingDataset dataset,
            [NotNull] ITrainingAlgorithmInfo algorithm,
            int epochs, float dropout = 0,
            [CanBeNull] Action <BatchProgress> batchCallback = null,
            [CanBeNull] Action <TrainingProgressEventArgs> trainingCallback = null,
            [CanBeNull] IValidationDataset validationDataset = null,
            [CanBeNull] ITestDataset testDataset             = null,
            CancellationToken token = default)
        {
            // The reporters are built here, before Task.Run, so that they capture the caller's synchronization context
            IProgress<BatchProgress> batchReporter = batchCallback.AsIProgress();
            IProgress<TrainingProgressEventArgs> trainingReporter = trainingCallback.AsIProgress();

            Func<TrainingSessionResult> session = () => TrainNetworkCore(
                network, dataset, algorithm,
                epochs, dropout,
                batchReporter, trainingReporter,
                validationDataset, testDataset,
                token);

            return Task.Run(session, token);
        }