/// <summary>
/// Reads the line at 1-based position <paramref name="rowNumber"/> from the training dataset
/// file and parses it into a <see cref="CntkDatasetRow"/>.
/// </summary>
/// <param name="datasetDefinition">Definition used to build the dataset (if not present) and to parse the row.</param>
/// <param name="rowNumber">1-based number of the line to read from the training dataset file.</param>
/// <returns>
/// The parsed row, or <c>null</c> when <paramref name="rowNumber"/> is not positive or lies past
/// the last line of the file.
/// </returns>
public CntkDatasetRow GetRowFromDefinition(ITrainingDatasetDefinition datasetDefinition, int rowNumber)
        {
            var dataset = datasetDefinition.BuildDatasetIfNotPresent();

            // Row numbers start at 1. Without this guard a rowNumber of 0 would skip the read loop
            // and hand a null line to the parser; negative values would scan the whole file for nothing.
            if (rowNumber <= 0)
            {
                return(null);
            }

            using (var stream = new FileStream(dataset.TrainingDatasetPath,
                                               FileMode.Open,
                                               FileAccess.Read,
                                               FileShare.Read))
                using (var reader = new StreamReader(stream))
                {
                    string line = null;
                    for (int rowsRead = 0; rowsRead < rowNumber; rowsRead++)
                    {
                        line = reader.ReadLine();
                        if (line == null)
                        {
                            // Ran out of lines before reaching the requested row.
                            return(null);
                        }
                    }

                    // NOTE(review): the line is passed to the parser as-is; an empty line in the
                    // dataset file is not filtered here - confirm callers never request one.
                    return(ParseLineAsCntkDatasetRow(line, datasetDefinition, dataset));
                }
        }
        /// <summary>
        /// Feeds minibatches to the trainer until the configured number of epochs has elapsed.
        /// An epoch is counted each time any stream of the current minibatch reports the end of a sweep.
        /// </summary>
        /// <param name="datasetDefinition">Dataset definition; used here to name the per-epoch model snapshots.</param>
        protected override void TrainNetwork(ITrainingDatasetDefinition datasetDefinition)
        {
            int    minibatchIndex  = 0;
            ushort epochsRemaining = Configuration.Epochs;

            while (epochsRemaining > 0)
            {
                var batch = _minibatchSource.GetNextMinibatch(Configuration.MinibatchConfig.MinibatchSize, Device);

                // Bind the feature and label streams of this batch to the network's input variables.
                var trainerArguments = new Dictionary <Variable, MinibatchData>();
                trainerArguments.Add(_input, batch[_featureStreamInfo]);
                trainerArguments.Add(_labels, batch[_labelStreamInfo]);

                _trainer.TrainMinibatch(trainerArguments, Device);

                PrintProgress(minibatchIndex);
                minibatchIndex++;

                bool sweepFinished = batch.Values.Any(streamData => streamData.sweepEnd);
                if (!sweepFinished)
                {
                    continue;
                }

                epochsRemaining--;

                if (!Configuration.DumpModelSnapshotPerEpoch)
                {
                    continue;
                }

                // Snapshot files are numbered by the count of epochs completed so far.
                ushort completedEpochs = Convert.ToUInt16(Configuration.Epochs - epochsRemaining);
                var    snapshotPath    = Configuration.PersistenceConfig.GetEpochFileNamePathFor(completedEpochs, datasetDefinition);

                _networkClassifier.Save(snapshotPath);
            }
        }
        /// <summary>
        /// Evaluate how good (or bad) training went on test data
        /// </summary>
        /// <param name="datasetDefinition">Definition used to derive the stream configuration of the test dataset file.</param>
        /// <param name="persistedTrainingModelPath">Path of the persisted model that is loaded and evaluated.</param>
        /// <param name="howManySamplesToUseFromTestDataset">
        /// Evaluation stops as soon as at least this many samples have been processed. Samples are
        /// consumed in whole minibatches, so at least one minibatch is always evaluated when data is available.
        /// </param>
        protected virtual void EvaluateModel(ITrainingDatasetDefinition datasetDefinition,
                                             string persistedTrainingModelPath,
                                             int howManySamplesToUseFromTestDataset)
        {
            const uint evaluationBatchSize = 50;

            using (var evaluationMinibatchSourceModel = MinibatchSource.TextFormatMinibatchSource
                                                            (TrainingDataset.TestingDatasetPath, GetStreamConfigFrom(datasetDefinition), MinibatchSource.FullDataSweep))
            {
                Function model       = Function.Load(persistedTrainingModelPath, Device);
                var      imageInput  = model.Arguments[0];
                var      labelOutput = model.Outputs.Single(o => o.Name == ClassifierName);

                var featureStreamInfo = evaluationMinibatchSourceModel.StreamInfo(FeatureStreamName);
                var labelStreamInfo   = evaluationMinibatchSourceModel.StreamInfo(LabelsStreamName);

                int miscountTotal = 0, totalCount = 0;

                while (true)
                {
                    var minibatchData = evaluationMinibatchSourceModel.GetNextMinibatch(evaluationBatchSize, Device);
                    if (minibatchData == null || minibatchData.Count == 0)
                    {
                        // The test data source has nothing more to deliver.
                        break;
                    }
                    totalCount += (int)minibatchData[featureStreamInfo].numberOfSamples;

                    // Expected labels are in the minibatch data; the index of the largest value of
                    // each label vector is taken as the class.
                    var labelData      = minibatchData[labelStreamInfo].data.GetDenseData <float>(labelOutput);
                    var expectedLabels = labelData.Select(l => l.IndexOf(l.Max())).ToList();

                    var inputDataMap = new Dictionary <Variable, Value>()
                    {
                        { imageInput, minibatchData[featureStreamInfo].data }
                    };

                    // The null value is replaced with the computed output by Evaluate.
                    var outputDataMap = new Dictionary <Variable, Value>()
                    {
                        { labelOutput, null }
                    };

                    model.Evaluate(inputDataMap, outputDataMap, Device);
                    var outputData   = outputDataMap[labelOutput].GetDenseData <float>(labelOutput);
                    var actualLabels = outputData.Select(l => l.IndexOf(l.Max())).ToList();

                    int misMatches = actualLabels.Zip(expectedLabels, (a, b) => a.Equals(b) ? 0 : 1).Sum();

                    miscountTotal += misMatches;
                    MessagePrinter.PrintMessage($"Validating Model: Total Samples = {totalCount}, Misclassify Count = {miscountTotal}");

                    // Stop once the requested amount is reached; a strict '>' comparison would
                    // evaluate one extra minibatch beyond the requested sample count.
                    if (totalCount >= howManySamplesToUseFromTestDataset)
                    {
                        break;
                    }
                }

                if (totalCount == 0)
                {
                    // Avoid reporting 0/0 (NaN) as an error rate when no test sample was read.
                    MessagePrinter.PrintMessage("Model Validation skipped: no test samples were read");
                    return;
                }

                float errorRate = 1.0F * miscountTotal / totalCount;
                MessagePrinter.PrintMessage($"Model Validation Error = {errorRate}");
            }
        }
 /// <summary>
 /// Describes the two streams read from the dataset text files: the feature stream, sized by
 /// <c>SingleElementSize</c>, followed by the labels stream, sized by <c>LabelsAmount</c>.
 /// </summary>
 /// <param name="definition">Dataset definition supplying both stream sizes.</param>
 /// <returns>A list holding the feature stream configuration first and the labels stream configuration second.</returns>
 protected IList <StreamConfiguration> GetStreamConfigFrom(ITrainingDatasetDefinition definition)
 {
     var featureStream = new StreamConfiguration(FeatureStreamName, definition.SingleElementSize);
     var labelsStream  = new StreamConfiguration(LabelsStreamName, definition.LabelsAmount);

     var streams = new List <StreamConfiguration>();
     streams.Add(featureStream);
     streams.Add(labelsStream);

     return(streams);
 }
        /// <summary>
        /// Opens the training dataset file as an infinitely repeating minibatch source and caches
        /// the stream descriptors of its feature and label streams.
        /// </summary>
        /// <param name="datasetDefinition">Dataset definition used to derive the stream configuration.</param>
        protected override void PrepareTrainingData(ITrainingDatasetDefinition datasetDefinition)
        {
            var trainingFilePath     = TrainingDataset.TrainingDatasetPath;
            var streamConfigurations = GetStreamConfigFrom(datasetDefinition);

            var trainingSource = MinibatchSource.TextFormatMinibatchSource(
                trainingFilePath, streamConfigurations, MinibatchSource.InfinitelyRepeat);

            _minibatchSource   = trainingSource;
            _featureStreamInfo = trainingSource.StreamInfo(FeatureStreamName);
            _labelStreamInfo   = trainingSource.StreamInfo(LabelsStreamName);
        }
        /// <summary>
        /// Builds the value-to-label map for the dataset and converts both the training and the
        /// testing data into CNTK-format files.
        /// </summary>
        /// <param name="datasetDefinition">Definition of the dataset to convert.</param>
        /// <returns>A dataset descriptor carrying the label map and the paths of the two generated files.</returns>
        public static PreparedLearningDataset ParseFromGZipedDefinitionsForCntk(ITrainingDatasetDefinition datasetDefinition)
        {
            var labelMap = BuildEMNISTValueToLabelMapFor(datasetDefinition);

            // Initializers run in the order written: label map, training file, then testing file.
            return(new PreparedLearningDataset
            {
                ValueToLabelMap     = labelMap,
                TrainingDatasetPath = ReadEMNISTInCntkFormat(datasetDefinition, labelMap, true),
                TestingDatasetPath  = ReadEMNISTInCntkFormat(datasetDefinition, labelMap, false)
            });
        }
        /// <summary>
        /// Converts a gzip-compressed image file and its matching label file into a text file in
        /// which every image becomes one line of the form <c>|labels ... |features p0 p1 ...</c>.
        /// </summary>
        /// <param name="definition">Dataset definition supplying the input file paths and the output file suffix.</param>
        /// <param name="valueToLabel">Map used to turn each raw label byte into its CNTK label definition.</param>
        /// <param name="isTrainData"><c>true</c> to convert the training files, <c>false</c> to convert the testing files.</param>
        /// <returns>Path of the generated text file.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="definition"/> or <paramref name="valueToLabel"/> is <c>null</c>.
        /// </exception>
        public static string ReadEMNISTInCntkFormat(ITrainingDatasetDefinition definition, SortedDictionary <byte, char> valueToLabel, bool isTrainData)
        {
            if (definition == null)
            {
                throw new ArgumentNullException(nameof(definition));
            }

            if (valueToLabel == null)
            {
                throw new ArgumentNullException(nameof(valueToLabel));
            }

            var outputFilePath = isTrainData ?
                                 string.Format(EMNISTTrainFilenameMask, definition.OutputFileSuffix) :
                                 string.Format(EMNISTTestFilenameMask, definition.OutputFileSuffix);

            var imagesPath = isTrainData ? definition.TrainImagesPath : definition.TestImagesPath;
            var labelsPath = isTrainData ? definition.TrainLabelsPath : definition.TestLabelsPath;

            // The inputs are only read, so open them read-only; this also works for files the
            // process has no write permission on.
            using (var imagesGzip = new GZipStream(File.OpenRead(imagesPath), CompressionMode.Decompress))
            using (var labelsGzip = new GZipStream(File.OpenRead(labelsPath), CompressionMode.Decompress))
            using (var imagesReader = new BinaryReader(imagesGzip, Encoding.ASCII))
            using (var labelsReader = new BinaryReader(labelsGzip, Encoding.ASCII))
            {
                // Image file header: magic number, image count, rows, columns (big-endian 32-bit each).
                imagesReader.ReadBigEndianInt32(); // magic number, consumed but not validated
                int numImages = imagesReader.ReadBigEndianInt32();
                int numRows   = imagesReader.ReadBigEndianInt32();
                int numCols   = imagesReader.ReadBigEndianInt32();

                // Label file header: magic number and label count; both are consumed only to
                // position the reader on the first label byte.
                labelsReader.ReadBigEndianInt32();
                labelsReader.ReadBigEndianInt32();

                // Take the image size from the header instead of assuming 28x28.
                int pixelsPerImage = numRows * numCols;

                using (var writer = new StreamWriter(outputFilePath))
                {
                    for (var i = 0; i < numImages; i++)
                    {
                        // Each row is preceded (not followed) by a line break, so the file starts
                        // with an empty line; kept as-is to preserve the existing file layout.
                        writer.Write(Environment.NewLine
                                     + $"|labels {labelsReader.ReadByte().AsCntkLabelDefinition(valueToLabel)} "
                                     + "|features");
                        for (var j = 0; j < pixelsPerImage; j++)
                        {
                            writer.Write(" ");
                            writer.Write(imagesReader.ReadByte());
                        }
                    }
                }
            }

            return(outputFilePath);
        }
Esempio n. 8
0
        /// <summary>
        /// Resolves the dataset matching <paramref name="choice"/>, prints the training configuration
        /// and runs a convolutional network training session on GPU device 0. An unrecognized choice
        /// is reported through <c>SharedConsoleCommands.InvalidCommand</c> and nothing is trained.
        /// </summary>
        /// <param name="choice">Console choice selecting the letters, digits or uppercase-letters dataset.</param>
        private static void RunEmnistTraining(string choice)
        {
            ITrainingDatasetDefinition selectedDataset;

            if (choice == LETTERS_CHOICE)
            {
                selectedDataset = new EMNISTLetterDataset();
            }
            else if (choice == DIGITS_CHOICE)
            {
                selectedDataset = new EMNISTDigitDataset();
            }
            else if (choice == UPPERCASE_LETTERS_CHOICE)
            {
                selectedDataset = new EMNISTUppercaseLetterDataset();
            }
            else
            {
                SharedConsoleCommands.InvalidCommand(choice);
                return;
            }

            TrainingSessionStart(choice);
            var printer = new ConsolePrinter();

            // Every session writes into its own timestamped directory below the working directory.
            var timestamp       = DateTime.Now.ToString("yyyyMMddHHmmss", CultureInfo.InvariantCulture);
            var outputDirectory = $"./{timestamp}/";
            var gpu             = DeviceDescriptor.GPUDevice(0);

            var minibatchSettings = new MinibatchConfiguration
            {
                MinibatchSize = 64,
                HowManyMinibatchesPerSnapshot      = (60000 / 32),
                HowManyMinibatchesPerProgressPrint = 500,
                DumpModelSnapshotPerMinibatch      = false,
                AsyncMinibatchSnapshot             = false
            };

            var sessionSettings = new TrainingSessionConfiguration
            {
                Epochs = 200,
                DumpModelSnapshotPerEpoch  = true,
                ProgressEvaluationSeverity = EvaluationSeverity.PerEpoch,
                MinibatchConfig            = minibatchSettings,
                PersistenceConfig          = TrainingModelPersistenceConfiguration.CreateWithAllLocationsSetTo(outputDirectory)
            };

            printer.PrintMessage("\n" + sessionSettings + "\n");

            using (var networkRunner = new ConvolutionalNeuralNetworkRunner(gpu, sessionSettings, printer))
            {
                networkRunner.RunUsing(selectedDataset);
            }

            EmnistTrainingDone(choice);
        }
        /// <summary>
        /// Splits one dataset line on '|' into its streams and converts them into a dataset row.
        /// The first stream is read as "labels" and the second one as "features".
        /// </summary>
        /// <param name="readLine">Raw line read from the dataset file.</param>
        /// <param name="datasetDefinition">Definition supplying the dataset name stored on the row.</param>
        /// <param name="dataset">Prepared dataset passed on to the label conversion.</param>
        /// <returns>The row built from the line.</returns>
        private CntkDatasetRow ParseLineAsCntkDatasetRow(string readLine,
                                                         ITrainingDatasetDefinition datasetDefinition, PreparedLearningDataset dataset)
        {
            // Whatever precedes the first '|' is dropped; the stream sections follow it.
            var streams    = readLine.Split('|').Skip(1).ToList();
            var labelItems = GetDatasetStreamItems(streams[0], "labels");

            var row = new CntkDatasetRow();
            row.DatasetName = datasetDefinition.DataSetName;
            row.ImagePixels = GetDatasetStreamItems(streams[1], "features");
            row.Label       = CntkLabelFromLabelStream(labelItems, dataset);

            return(row);
        }
        /// <summary>
        /// Runs a complete training session: clean-up, network construction, data preparation,
        /// trainer configuration, training, persisting the result and evaluating the persisted model.
        /// </summary>
        /// <param name="datasetDefinition">Dataset definition the session is run for.</param>
        /// <returns>Currently always <c>null</c>.</returns>
        public TrainingSessionResult RunUsing(ITrainingDatasetDefinition datasetDefinition)
        {
            //TODO: Make test dataset volume configurable!
            const int testSamplesToEvaluate = 1000;

            CleanUp();
            BuildNeuralNetwork(datasetDefinition);
            PrepareTrainingData(datasetDefinition);
            ConfigureTrainer();
            TrainNetwork(datasetDefinition);
            SaveResults(datasetDefinition);

            // Evaluate the model from the same path the final training result is stored under.
            var finalModelPath = Configuration.PersistenceConfig.GetTrainingResultFileNamePathFor(datasetDefinition);
            EvaluateModel(datasetDefinition, finalModelPath, testSamplesToEvaluate);

            return(null);
        }
        /// <summary>
        /// Reads the mapping file of the dataset, where every line holds two space-separated
        /// numbers, and builds a map from the first number to the character whose code is the second.
        /// Blank lines are ignored.
        /// </summary>
        /// <param name="definition">Dataset definition supplying the mapping file path.</param>
        /// <returns>The value-to-label map, sorted by value.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="definition"/> is <c>null</c>.</exception>
        public static SortedDictionary <byte, char> BuildEMNISTValueToLabelMapFor(ITrainingDatasetDefinition definition)
        {
            if (definition == null)
            {
                throw new ArgumentNullException(nameof(definition));
            }

            var output = new SortedDictionary <byte, char>();

            // The mapping file is only read, so it is opened read-only.
            using (var mappingReader = new StreamReader(File.OpenRead(definition.MappingPath)))
            {
                string line;
                while ((line = mappingReader.ReadLine()) != null)
                {
                    if (string.IsNullOrWhiteSpace(line))
                    {
                        // Tolerate empty lines, e.g. a trailing line break at the end of the file.
                        continue;
                    }

                    // Dropping empty entries tolerates repeated spaces between the two numbers.
                    var split = line.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);

                    // The file is machine-written, so parse culture-independently.
                    var value     = byte.Parse(split[0], System.Globalization.CultureInfo.InvariantCulture);
                    var labelCode = byte.Parse(split[1], System.Globalization.CultureInfo.InvariantCulture);

                    output.Add(value, (char)labelCode);
                }
            }
            return(output);
        }
 /// <summary>
 /// Builds the path of the final model file:
 /// training result location + optional prefix + "FINAL_" + dataset name + ".Model".
 /// </summary>
 /// <param name="dataset">Dataset whose name becomes part of the file name.</param>
 /// <param name="customPrefix">Optional text inserted in front of "FINAL_"; nothing is inserted when <c>null</c>.</param>
 /// <returns>The composed file path.</returns>
 public string GetTrainingResultFileNamePathFor(ITrainingDatasetDefinition dataset, string customPrefix = null)
     => $"{TrainingResultTargetLocation}{customPrefix}FINAL_{dataset.DataSetName}.Model";
 /// <summary>
 /// Lock and set up neural network
 /// </summary>
 /// <param name="datasetDefinition">Definition of the dataset the network is set up for.</param>
 protected abstract void BuildNeuralNetwork(ITrainingDatasetDefinition datasetDefinition);
 /// <summary>
 /// Saves the trained classifier to the final training result path of the given dataset.
 /// </summary>
 /// <param name="datasetDefinition">Dataset definition used to compose the target file name.</param>
 protected override void SaveResults(ITrainingDatasetDefinition datasetDefinition)
 {
     var finalModelPath = Configuration.PersistenceConfig.GetTrainingResultFileNamePathFor(datasetDefinition);

     _networkClassifier.Save(finalModelPath);
 }
 /// <summary>
 /// Builds the path of an epoch snapshot file: epoch snapshot location + optional prefix +
 /// "EP" + epoch number zero-padded to six digits + "_" + dataset name + ".Model".
 /// </summary>
 /// <param name="epochNumber">Epoch number embedded in the file name.</param>
 /// <param name="dataset">Dataset whose name becomes part of the file name.</param>
 /// <param name="customPrefix">Optional text inserted in front of "EP"; nothing is inserted when <c>null</c>.</param>
 /// <returns>The composed file path.</returns>
 public string GetEpochFileNamePathFor(
     ushort epochNumber, ITrainingDatasetDefinition dataset, string customPrefix = null)
 {
     var paddedEpoch = epochNumber.ToString("000000");

     return($"{EpochSnapshotTargetLocation}{customPrefix}EP{paddedEpoch}_{dataset.DataSetName}.Model");
 }
 /// <summary>
 /// Prepare datasets for usage in DNN
 /// </summary>
 /// <param name="datasetDefinition">Definition of the dataset whose data is prepared.</param>
 protected abstract void PrepareTrainingData(ITrainingDatasetDefinition datasetDefinition);
 /// <summary>
 /// Save training results
 /// </summary>
 /// <param name="datasetDefinition">Definition of the dataset the results were trained on.</param>
 protected abstract void SaveResults(ITrainingDatasetDefinition datasetDefinition);
 /// <summary>
 /// Train network with provided batches
 /// </summary>
 /// <param name="datasetDefinition">Definition of the dataset the network is trained on.</param>
 protected abstract void TrainNetwork(ITrainingDatasetDefinition datasetDefinition);
        /// <summary>
        /// Builds the convolutional classifier for the given dataset and wires up the label input,
        /// the training loss (cross entropy with softmax) and the evaluation function (classification error).
        /// Layer order: INPUT -> CONV -> RELU -> MAX POOL -> CONV -> TANH -> MAX POOL -> FC.
        /// </summary>
        /// <param name="datasetDefinition">Dataset definition supplying the input variable, the dataset and the label count.</param>
        protected override void BuildNeuralNetwork(ITrainingDatasetDefinition datasetDefinition)
        {
            _input          = datasetDefinition.AsInputFor(this);
            TrainingDataset = datasetDefinition.BuildDatasetIfNotPresent();

            var scaledInput = _input.ScaledForConvolutionalNetwork(this);
            NetworkConfiguration = new ConvolutionalNeuralNetworkConfiguration(
                scaledInput,
                (int)datasetDefinition.LabelsAmount,
                ClassifierName);

            // Stage 1 (noted by the original author as 28x28x1 -> 14x14x4):
            // 3x3 convolution, 1 channel in, 4 feature maps out, then ReLU and max pooling.
            var firstConvolution = new ConvolutionParams
            {
                FilterSize             = new Dimension2D(3, 3),
                Channels               = 1,
                OutputFeatureMapsCount = 4,
                Stride = new Stride3D(1, 1, 1)
            };
            NetworkConfiguration.AppendConvolutionLayer(firstConvolution, Device);
            NetworkConfiguration.AppendReluActivation();

            var firstPooling = new PoolingParams
            {
                Type          = PoolingType.Max,
                PoolingWindow = new Dimension2D(3, 3),
                Stride        = new Stride2D(2, 2)
            };
            NetworkConfiguration.AppendPoolingLayer(firstPooling);

            // Stage 2 (noted by the original author as 14x14x4 -> 7x7x8):
            // 3x3 convolution, 4 channels in, 8 feature maps out, then TanH and max pooling.
            var secondConvolution = new ConvolutionParams
            {
                FilterSize             = new Dimension2D(3, 3),
                Channels               = 4,
                OutputFeatureMapsCount = 8,
                Stride = new Stride3D(1, 1, 4)
            };
            NetworkConfiguration.AppendConvolutionLayer(secondConvolution, Device);
            NetworkConfiguration.AppendTanHActivation();

            var secondPooling = new PoolingParams
            {
                Type          = PoolingType.Max,
                PoolingWindow = new Dimension2D(3, 3),
                Stride        = new Stride2D(2, 2)
            };
            NetworkConfiguration.AppendPoolingLayer(secondPooling);

            // Final stage: fully connected linear layer.
            NetworkConfiguration.AppendFullyConnectedLinearLayer(Device);

            _networkClassifier = NetworkConfiguration.Evaluate();

            MessagePrinter.PrintMessage(NetworkConfiguration.ToString());

            // The label input has one dimension sized by the number of labels of the dataset.
            var labelShape = new int[] { (int)datasetDefinition.LabelsAmount };
            _labels = CNTKLib.InputVariable(labelShape, DataType.Float, LabelsStreamName);

            _trainingLossFunction = CNTKLib.CrossEntropyWithSoftmax(new Variable(_networkClassifier), _labels, LossFunctionName);
            _evaluationFunction   = CNTKLib.ClassificationError(new Variable(_networkClassifier), _labels, ClassificationErrorName);
        }
Esempio n. 20
0
 /// <summary>
 /// Creates a float input variable shaped by the dataset's single-element dimensions and named
 /// after the feature stream of the consuming runner.
 /// </summary>
 /// <param name="datasetDefinition">Dataset definition supplying the element dimensions.</param>
 /// <param name="inputConsumer">Runner whose feature stream name is given to the variable.</param>
 /// <returns>The created input variable.</returns>
 public static Variable AsInputFor(this ITrainingDatasetDefinition datasetDefinition, DeepLearningRunner inputConsumer)
 {
     var elementDimensions = datasetDefinition.SingleElementDimensions;
     var featureStreamName = inputConsumer.FeatureStreamName;

     return(CNTKLib.InputVariable(elementDimensions, DataType.Float, featureStreamName));
 }