예제 #1
0
        /// <summary>
        /// Validates the training examples against the requirements of this trainer and its
        /// subclasses and returns data that can be fed to training. When the underlying view
        /// cannot shuffle on its own, it is wrapped in a <see cref="RowShufflingTransformer"/>.
        /// Throws if the requirements cannot be met.
        /// </summary>
        /// <param name="ch">The channel used for assertions and checks</param>
        /// <param name="examples">The training examples</param>
        /// <param name="weightSetCount">Receives the length of the weights and bias array. For binary classification
        /// and regression, this is 1. For multi-class classification, this equals the number of classes on the label.</param>
        /// <returns>A potentially modified version of <paramref name="examples"/></returns>
        private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount)
        {
            ch.AssertValue(examples);
            CheckLabel(examples, out weightSetCount);
            examples.CheckFeatureFloatVector();

            // Start from the caller's view; only wrap it when it cannot shuffle by itself.
            IDataView trainView = examples.Data;
            if (!trainView.CanShuffle)
            {
                var shuffleOptions = new RowShufflingTransformer.Arguments
                {
                    PoolOnly = false,
                    ForceShuffle = _args.Shuffle
                };
                trainView = new RowShufflingTransformer(Host, shuffleOptions, trainView);
            }
            ch.Assert(trainView.CanShuffle);

            // Re-attach the original column roles to the (possibly wrapped) view.
            var trainExamples = new RoleMappedData(trainView, examples.Schema.GetColumnRoleNames());

            ch.AssertValue(trainExamples.Schema.Label);
            ch.AssertValue(trainExamples.Schema.Feature);
            if (examples.Schema.Weight != null)
            {
                // A weight column present on the input must survive the re-mapping.
                ch.AssertValue(trainExamples.Schema.Weight);
            }

            int featureCount = trainExamples.Schema.Feature.Type.VectorSize;
            ch.Check(featureCount > 0, "Training set has no features, aborting training.");

            return trainExamples;
        }
예제 #2
0
        /// <summary>
        /// Validates the training examples against the requirements of this trainer and its
        /// subclasses and returns data that can be fed to training. When the underlying view
        /// cannot shuffle on its own, it is wrapped in a <see cref="RowShufflingTransformer"/>.
        /// Throws if the requirements cannot be met.
        /// </summary>
        /// <param name="ch">The channel used for assertions and checks</param>
        /// <param name="examples">The training examples</param>
        /// <param name="weightSetCount">Receives the length of the weights and bias array. For binary classification
        /// and regression, this is 1. For multi-class classification, this equals the number of classes on the label.</param>
        /// <returns>A potentially modified version of <paramref name="examples"/></returns>
        private protected RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount)
        {
            ch.AssertValue(examples);
            CheckLabel(examples, out weightSetCount);
            examples.CheckFeatureFloatVector();

            // Start from the caller's view; only wrap it when it cannot shuffle by itself.
            IDataView trainView = examples.Data;
            if (!trainView.CanShuffle)
            {
                var shuffleOptions = new RowShufflingTransformer.Options
                {
                    PoolOnly = false,
                    ForceShuffle = ShuffleData
                };
                trainView = new RowShufflingTransformer(Host, shuffleOptions, trainView);
            }
            ch.Assert(trainView.CanShuffle);

            // Re-attach the original column roles to the (possibly wrapped) view.
            var trainExamples = new RoleMappedData(trainView, examples.Schema.GetColumnRoleNames());

            ch.Assert(trainExamples.Schema.Label.HasValue);
            ch.Assert(trainExamples.Schema.Feature.HasValue);
            if (examples.Schema.Weight.HasValue)
            {
                // A weight column present on the input must survive the re-mapping.
                ch.Assert(trainExamples.Schema.Weight.HasValue);
            }

            // The feature column has to be a vector type with a positive size.
            ch.Check(
                trainExamples.Schema.Feature.Value.Type is VectorType featureVector && featureVector.Size > 0,
                "Training set has no features, aborting training.");

            return trainExamples;
        }
예제 #3
0
        /// <summary>
        /// Retrains the "mnist_conv_model" TensorFlow model inside a pipeline, feeds its predictions to a
        /// LightGBM multiclass trainer, and verifies both the evaluation metrics and a single prediction.
        /// The model folder is cleaned up afterwards, whether or not the checks succeed.
        /// </summary>
        /// <param name="shuffle">When true, both the training and the test data are wrapped in a
        /// <see cref="RowShufflingTransformer"/> before use; otherwise they are used as read.</param>
        /// <param name="shuffleSeed">Value passed as <c>ForceShuffleSeed</c> to the shuffling transformers
        /// (only used when <paramref name="shuffle"/> is true).</param>
        /// <param name="expectedMicroAccuracy">Expected micro accuracy; the measured value must lie within ±0.01 of it.</param>
        /// <param name="expectedMacroAccruacy">Expected macro accuracy; the measured value must lie within ±0.01 of it.</param>
        private void ExecuteTFTransformMNISTConvTrainingTest(bool shuffle, int?shuffleSeed, double expectedMicroAccuracy, double expectedMacroAccruacy)
        {
            const string modelLocation = "mnist_conv_model";

            try
            {
                var mlContext = new MLContext(seed: 1, conc: 1);

                var reader = mlContext.Data.CreateTextReader(new[]
                {
                    new TextLoader.Column("Label", DataKind.U4, new [] { new TextLoader.Range(0) }, new KeyRange(0, 9)),
                    new TextLoader.Column("TfLabel", DataKind.I8, 0),
                    new TextLoader.Column("Placeholder", DataKind.R4, new [] { new TextLoader.Range(1, 784) })
                });

                var trainData = reader.Read(GetDataPath(TestDatasets.mnistTiny28.trainFilename));
                var testData  = reader.Read(GetDataPath(TestDatasets.mnistOneClass.testFilename));

                IDataView preprocessedTrainData = null;
                IDataView preprocessedTestData  = null;
                if (shuffle)
                {
                    // Shuffle training data set
                    preprocessedTrainData = new RowShufflingTransformer(mlContext, new RowShufflingTransformer.Arguments()
                    {
                        ForceShuffle     = shuffle,
                        ForceShuffleSeed = shuffleSeed
                    }, trainData);

                    // Shuffle test data set
                    preprocessedTestData = new RowShufflingTransformer(mlContext, new RowShufflingTransformer.Arguments()
                    {
                        ForceShuffle     = shuffle,
                        ForceShuffleSeed = shuffleSeed
                    }, testData);
                }
                else
                {
                    preprocessedTrainData = trainData;
                    preprocessedTestData  = testData;
                }

                // Pipeline: copy the raw pixels to "Features", retrain the TensorFlow model on them,
                // then use its "Prediction" output as the features of a LightGBM multiclass trainer.
                var pipe = mlContext.Transforms.CopyColumns(("Placeholder", "Features"))
                           .Append(new TensorFlowEstimator(mlContext, new TensorFlowTransformer.Arguments()
                           {
                               ModelLocation         = modelLocation,
                               InputColumns          = new[] { "Features" },
                               OutputColumns         = new[] { "Prediction" },
                               LabelColumn           = "TfLabel",
                               TensorFlowLabel       = "Label",
                               OptimizationOperation = "MomentumOp",
                               LossOperation         = "Loss",
                               MetricOperation       = "Accuracy",
                               Epoch                 = 10,
                               LearningRateOperation = "learning_rate",
                               LearningRate          = 0.01f,
                               BatchSize             = 20,
                               ReTrain               = true
                           }))
                           .Append(mlContext.Transforms.Concatenate("Features", "Prediction"))
                           .AppendCacheCheckpoint(mlContext)
                           .Append(mlContext.MulticlassClassification.Trainers.LightGbm("Label", "Features"));

                var trainedModel = pipe.Fit(preprocessedTrainData);
                var predicted    = trainedModel.Transform(preprocessedTestData);
                var metrics      = mlContext.MulticlassClassification.Evaluate(predicted);

                // First group of checks. They check if the overall prediction quality is ok using a test set.
                // Each metric is compared against its own expected value: the macro range must be centred on
                // the expected macro accuracy, not on the micro one.
                Assert.InRange(metrics.AccuracyMicro, expectedMicroAccuracy - .01, expectedMicroAccuracy + .01);
                Assert.InRange(metrics.AccuracyMacro, expectedMacroAccruacy - .01, expectedMacroAccruacy + .01);

                // Create prediction function and test prediction
                var predictFunction = trainedModel.CreatePredictionEngine <MNISTData, MNISTPrediction>(mlContext);

                var oneSample = GetOneMNISTExample();

                var prediction = predictFunction.Predict(oneSample);

                Assert.Equal(5, GetMaxIndexForOnePrediction(prediction));
            }
            finally
            {
                // This test changes the state of the model.
                // Cleanup folder so that other test can also use the same model.
                CleanUp(modelLocation);
            }
        }
예제 #4
0
        IDataTransform AppendToPipeline(IDataView input)
        {
            IDataView current = input;

            if (_shuffleInput)
            {
                var args1 = new RowShufflingTransformer.Arguments()
                {
                    ForceShuffle     = false,
                    ForceShuffleSeed = _seedShuffle,
                    PoolRows         = _poolRows,
                    PoolOnly         = false,
                };
                current = new RowShufflingTransformer(Host, args1, current);
            }

            // We generate a random number.
            var columnName = current.Schema.GetTempColumnName();
            var args2      = new GenerateNumberTransform.Arguments()
            {
                Column = new GenerateNumberTransform.Column[] { new GenerateNumberTransform.Column()
                                                                {
                                                                    Name = columnName
                                                                } },
                Seed = _seed ?? 42
            };
            IDataTransform currentTr = new GenerateNumberTransform(Host, args2, current);

            // We convert this random number into a part.
            var cRatios = new float[_ratios.Length];

            cRatios[0] = 0;
            for (int i = 1; i < _ratios.Length; ++i)
            {
                cRatios[i] = cRatios[i - 1] + _ratios[i - 1];
            }

            ValueMapper <float, int> mapper = (in float src, ref int dst) =>
            {
                for (int i = cRatios.Length - 1; i > 0; --i)
                {
                    if (src >= cRatios[i])
                    {
                        dst = i;
                        return;
                    }
                }
                dst = 0;
            };

            // Get location of columnName

            int index;

            currentTr.Schema.TryGetColumnIndex(columnName, out index);
            var ct   = currentTr.Schema.GetColumnType(index);
            var view = LambdaColumnMapper.Create(Host, "Key to part mapper", currentTr,
                                                 columnName, _newColumn, ct, NumberType.I4, mapper);

            // We cache the result to avoid the pipeline to change the random number.
            var args3 = new ExtendedCacheTransform.Arguments()
            {
                inDataFrame = string.IsNullOrEmpty(_cacheFile),
                numTheads   = _numThreads,
                cacheFile   = _cacheFile,
                reuse       = _reuse,
            };

            currentTr = new ExtendedCacheTransform(Host, args3, view);

            // Removing the temporary column.
            var finalTr     = ColumnSelectingTransformer.CreateDrop(Host, currentTr, new string[] { columnName });
            var taggedViews = new List <Tuple <string, ITaggedDataView> >();

            // filenames
            if (_filenames != null || _tags != null)
            {
                int nbf = _filenames == null ? 0 : _filenames.Length;
                if (nbf > 0 && nbf != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nbt = _tags == null ? 0 : _tags.Length;
                if (nbt > 0 && nbt != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nb = Math.Max(nbf, nbt);

                using (var ch = Host.Start("Split the datasets and stores each part."))
                {
                    for (int i = 0; i < nb; ++i)
                    {
                        if (_filenames == null || !_filenames.Any())
                        {
                            ch.Info("Create part {0}: {1} (tag: {2})", i + 1, _ratios[i], _tags[i]);
                        }
                        else
                        {
                            ch.Info("Create part {0}: {1} (file: {2})", i + 1, _ratios[i], _filenames[i]);
                        }
                        var ar1 = new RangeFilter.Arguments()
                        {
                            Column = _newColumn, Min = i, Max = i, IncludeMax = true
                        };
                        int pardId   = i;
                        var filtView = LambdaFilter.Create <int>(Host, string.Format("Select part {0}", i), currentTr,
                                                                 _newColumn, NumberType.I4,
                                                                 (in int part) => { return(part.Equals(pardId)); });