// Smoke test for MemoryMinibatchSource.GetNextMinibatch: wraps the fixture data in
// two named in-memory streams, then repeatedly requests a minibatch on the CPU device
// and reads the dense float data back for both streams.
// No assertions are made; the test only verifies that these calls complete.
public void MemoryMinibatchSource_GetNextMinibatch()
        {
            const string observationsName = "observations";
            const string targetsName      = "targets";
            const int sampleCount         = 9;
            const int minibatchSize       = 3;
            const int requestCount        = 30;

            // Wrap the fixture arrays as minibatch data (5 values per observation, 1 per target).
            var observationsShape = new int[] { 5 };
            var observationsData  = new MemoryMinibatchData(m_observationsData, observationsShape, sampleCount);
            var targetsShape      = new int[] { 1 };
            var targetsData       = new MemoryMinibatchData(m_targetData, targetsShape, sampleCount);

            // Input variables are kept in locals so the loop below can use them directly as keys.
            var observationsVariable = Variable.InputVariable(observationsShape, DataType.Float);
            var targetsVariable      = Variable.InputVariable(targetsShape, DataType.Float);

            // Stream name -> data.
            var dataByName = new Dictionary<string, MemoryMinibatchData>
            {
                [observationsName] = observationsData,
                [targetsName]      = targetsData,
            };

            // Stream name -> variable.
            var variableByName = new Dictionary<string, Variable>
            {
                [observationsName] = observationsVariable,
                [targetsName]      = targetsVariable,
            };

            var sut       = new MemoryMinibatchSource(variableByName, dataByName, 5, false);
            var cpuDevice = DeviceDescriptor.CPUDevice;

            for (var request = 0; request < requestCount; request++)
            {
                // The sweep-end flag is not inspected by this test.
                var (minibatch, _) = sut.GetNextMinibatch(minibatchSize, cpuDevice);

                // Results are intentionally discarded; only the calls themselves are exercised.
                _ = minibatch[observationsVariable].GetDenseData<float>(observationsVariable);
                _ = minibatch[targetsVariable].GetDenseData<float>(targetsVariable);
            }
        }
Exemple #2
0
        // Creates a deterministic (fixed seed 32) artificial data set:
        //  - observations: observationCount * product(inputShape) values drawn with Random.NextDouble,
        //  - targets: one Random.Next(outputShape.Single()) value per observation,
        //    passed through EncodeOneHot before being wrapped.
        // Returns both wrapped as MemoryMinibatchData.
        static (MemoryMinibatchData observations, MemoryMinibatchData targets) CreateArtificialData(int[] inputShape, int[] outputShape, int observationCount)
        {
            // Values per observation is the product of all input dimensions.
            var valuesPerObservation = inputShape.Aggregate((left, right) => left * right);
            var generator            = new Random(32);

            // Features are drawn first so the random sequence is consumed in a fixed order.
            var featureValues = new float[observationCount * valuesPerObservation];
            for (var index = 0; index < featureValues.Length; index++)
            {
                featureValues[index] = (float)generator.NextDouble();
            }

            // The shape array is copied before being handed to the minibatch data.
            var observationBatch = new MemoryMinibatchData(featureValues, inputShape.ToArray(), observationCount);

            var labelValues = new float[observationCount];
            for (var index = 0; index < labelValues.Length; index++)
            {
                // Single() is evaluated per element on purpose: with zero observations
                // the output shape is never inspected.
                labelValues[index] = (float)generator.Next(outputShape.Single());
            }

            var oneHotLabels = EncodeOneHot(labelValues);
            var targetBatch  = new MemoryMinibatchData(oneHotLabels, outputShape, observationCount);

            return (observationBatch, targetBatch);
        }
        // End-to-end example combining SharpLearning utilities with a CNTK-backed model:
        //  1. load the white wine quality data set and min-max scale the observations in place,
        //  2. convert observations and targets to float minibatch data,
        //  3. run 10-fold cross validation, fitting a fresh model per fold,
        //  4. collect the out-of-fold predictions and trace the formatted error.
        public void SharpLearning_With_Cntk_Example()
        {
            const string observationsName = "observations";
            const string targetsName      = "targets";

            // Load data.
            var (features, labels) = DataSetUtilities.LoadWinequalityWhite();

            // Scale the features for the neural net; input and output matrix are the same instance.
            var scaler = new MinMaxTransformer(0.0, 1.0);
            scaler.Transform(features, features);

            var featureCount     = features.ColumnCount;
            var observationCount = features.RowCount;
            var targetCount      = 1;

            var inputShape  = new int[] { featureCount, 1 };
            var outputShape = new int[] { targetCount };

            // Convert data to float and wrap as minibatch data.
            var featureValues = features.Data().Select(value => (float)value).ToArray();
            var allFeatures   = new MemoryMinibatchData(featureValues, inputShape, observationCount);
            var labelValues   = labels.Select(value => (float)value).ToArray();
            var allLabels     = new MemoryMinibatchData(labelValues, outputShape, observationCount);

            var dataType = DataType.Float;
            var device   = DeviceDescriptor.CPUDevice;

            // Input and target variables.
            var inputVariable  = Layers.Input(inputShape, dataType);
            var targetVariable = Variable.InputVariable(outputShape, dataType);

            // Stream name -> variable; shared by the training and validation sources.
            var variableByName = new Dictionary<string, Variable>
            {
                [observationsName] = inputVariable,
                [targetsName]      = targetVariable,
            };

            // Cross validation folds.
            var indexSampler = new RandomIndexSampler<double>(seed: 24);
            var folds = CrossValidationUtilities.GetKFoldCrossValidationIndexSets(
                indexSampler, foldCount: 10, targets: labels);

            // Filled fold by fold, indexed by the original observation position.
            var predictions = new double[observationCount];

            foreach (var fold in folds)
            {
                // Split the data for this fold.
                var trainingData = new Dictionary<string, MemoryMinibatchData>
                {
                    [observationsName] = allFeatures.GetSamples(fold.trainingIndices),
                    [targetsName]      = allLabels.GetSamples(fold.trainingIndices),
                };

                var validationData = new Dictionary<string, MemoryMinibatchData>
                {
                    [observationsName] = allFeatures.GetSamples(fold.validationIndices),
                    [targetsName]      = allLabels.GetSamples(fold.validationIndices),
                };

                // Training source is randomized; the validation source is not.
                var trainingSource   = new MemoryMinibatchSource(variableByName, trainingData, seed: 232, randomize: true);
                var validationSource = new MemoryMinibatchSource(variableByName, validationData, seed: 232, randomize: false);

                // Create a fresh model for the fold and fit it.
                var model = CreateModel(inputVariable, targetVariable, targetCount, dataType, device);
                model.Fit(trainingSource, batchSize: 128, epochs: 10);

                // Predict on the held-out part; each raw prediction holds exactly one value.
                var foldPredictions = model.Predict(validationSource)
                                           .Select(raw => (double)raw.Single())
                                           .ToArray();

                // Write the fold's predictions at their original observation positions.
                var heldOutIndices = fold.validationIndices;
                for (var position = 0; position < heldOutIndices.Length; position++)
                {
                    var observationIndex = heldOutIndices[position];
                    predictions[observationIndex] = foldPredictions[position];
                }
            }

            Trace.WriteLine(FormatErrorString(labels, predictions));
        }