예제 #1
0
        /// <summary>
        /// Returns a view of <paramref name="input"/> whose rows appear in a randomized order.
        /// </summary>
        /// <remarks>
        /// <see cref="ShuffleRows"/> shuffles any <see cref="IDataView"/> in a streaming fashion, so the whole dataset never has to be
        /// loaded into memory at once. A pool of <paramref name="shufflePoolSize"/> rows is first filled from the beginning of
        /// <paramref name="input"/>. Rows are then emitted from random positions in that pool, and each emitted row is replaced by the
        /// next row read from <paramref name="input"/>. This continues until every row has been emitted. The resulting
        /// <see cref="IDataView"/> therefore has the same size as <paramref name="input"/>, with its rows reordered.
        /// When the <see cref="IDataView.CanShuffle"/> property of <paramref name="input"/> is true, rows are also read into the pool
        /// in a random order, which provides a second source of randomness.
        /// </remarks>
        /// <param name="input">The data to shuffle.</param>
        /// <param name="seed">The random seed. When <see langword="null"/>, the random state is derived from the <see cref="MLContext"/> instead.</param>
        /// <param name="shufflePoolSize">How many rows the shuffle pool holds; must be positive. A value of 1 turns off pool shuffling, so that
        /// <see cref="ShuffleRows"/> only shuffles by reading <paramref name="input"/> in a random order.</param>
        /// <param name="shuffleSource">When <see langword="false"/>, the transform never tries to read <paramref name="input"/> in a random order
        /// and relies on pooling alone. Has no effect when the <see cref="IDataView.CanShuffle"/> property of <paramref name="input"/> is <see langword="false"/>.
        /// </param>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[ShuffleRows](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/ShuffleRows.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView ShuffleRows(IDataView input,
            int? seed = null,
            int shufflePoolSize = RowShufflingTransformer.Defaults.PoolRows,
            bool shuffleSource = !RowShufflingTransformer.Defaults.PoolOnly)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckUserArg(shufflePoolSize > 0, nameof(shufflePoolSize), "Must be positive");

            // Map the public parameters onto the transformer's options. PoolOnly is the inverse of shuffleSource,
            // ForceShuffle is set unconditionally, and the (possibly null) seed is forwarded as-is.
            var shuffleOptions = new RowShufflingTransformer.Options();
            shuffleOptions.PoolRows = shufflePoolSize;
            shuffleOptions.PoolOnly = !shuffleSource;
            shuffleOptions.ForceShuffle = true;
            shuffleOptions.ForceShuffleSeed = seed;

            IDataView shuffled = new RowShufflingTransformer(_env, shuffleOptions, input);
            return shuffled;
        }
        /// <summary>
        /// Checks that <paramref name="examples"/> meets the requirements of this trainer and its subclasses, and returns a copy
        /// of the training data backed by a shufflable view. If the incoming data cannot shuffle itself, it is wrapped in a
        /// <see cref="RowShufflingTransformer"/>. Throws if the requirements cannot be met.
        /// </summary>
        /// <param name="ch">The channel used for assertions and user-facing checks.</param>
        /// <param name="examples">The training examples.</param>
        /// <param name="weightSetCount">Receives the length of the weights and bias arrays. This is 1 for binary classification and
        /// regression, and the number of classes on the label for multi-class classification.</param>
        /// <returns>A new <see cref="RoleMappedData"/> with the same column roles as <paramref name="examples"/>, over a shufflable view.</returns>
        private RoleMappedData PrepareDataFromTrainingExamples(IChannel ch, RoleMappedData examples, out int weightSetCount)
        {
            ch.AssertValue(examples);
            CheckLabel(examples, out weightSetCount);
            examples.CheckFeatureFloatVector();

            IDataView source = examples.Data;

            // Training requires a view that can be shuffled. Reuse the source when it already supports shuffling;
            // otherwise put a shuffling transform on top of it, honoring the trainer's Shuffle option.
            IDataView trainingView = source.CanShuffle
                ? source
                : new RowShufflingTransformer(
                    Host,
                    new RowShufflingTransformer.Options { PoolOnly = false, ForceShuffle = _options.Shuffle },
                    source);

            ch.Assert(trainingView.CanShuffle);

            var trainingData = new RoleMappedData(trainingView, examples.Schema.GetColumnRoleNames());

            // Re-binding the same roles must preserve label, feature and (when present) weight columns.
            ch.Assert(trainingData.Schema.Label.HasValue);
            ch.Assert(trainingData.Schema.Feature.HasValue);
            ch.Assert(!examples.Schema.Weight.HasValue || trainingData.Schema.Weight.HasValue);

            ch.Check(trainingData.Schema.Feature.Value.Type is VectorType vecType && vecType.Size > 0, "Training set has no features, aborting training.");
            return trainingData;
        }
        IDataTransform AppendToPipeline(IDataView input)
        {
            IDataView current = input;

            if (_shuffleInput)
            {
                var args1 = new RowShufflingTransformer.Options()
                {
                    ForceShuffle     = false,
                    ForceShuffleSeed = _seedShuffle,
                    PoolRows         = _poolRows,
                    PoolOnly         = false,
                };
                current = new RowShufflingTransformer(Host, args1, current);
            }

            // We generate a random number.
            var columnName = current.Schema.GetTempColumnName();
            var args2      = new GenerateNumberTransform.Options()
            {
                Columns = new GenerateNumberTransform.Column[] { new GenerateNumberTransform.Column()
                                                                 {
                                                                     Name = columnName
                                                                 } },
                Seed = _seed ?? 42
            };
            IDataTransform currentTr = new GenerateNumberTransform(Host, args2, current);

            // We convert this random number into a part.
            var cRatios = new float[_ratios.Length];

            cRatios[0] = 0;
            for (int i = 1; i < _ratios.Length; ++i)
            {
                cRatios[i] = cRatios[i - 1] + _ratios[i - 1];
            }

            ValueMapper <float, int> mapper = (in float src, ref int dst) =>
            {
                for (int i = cRatios.Length - 1; i > 0; --i)
                {
                    if (src >= cRatios[i])
                    {
                        dst = i;
                        return;
                    }
                }
                dst = 0;
            };

            // Get location of columnName

            int index = SchemaHelper.GetColumnIndex(currentTr.Schema, columnName);
            var ct    = currentTr.Schema[index].Type;
            var view  = LambdaColumnMapper.Create(Host, "Key to part mapper", currentTr,
                                                  columnName, _newColumn, ct, NumberDataViewType.Int32, mapper);

            // We cache the result to avoid the pipeline to change the random number.
            var args3 = new ExtendedCacheTransform.Arguments()
            {
                inDataFrame = string.IsNullOrEmpty(_cacheFile),
                numTheads   = _numThreads,
                cacheFile   = _cacheFile,
                reuse       = _reuse,
            };

            currentTr = new ExtendedCacheTransform(Host, args3, view);

            // Removing the temporary column.
            var objtr   = ColumnSelectingTransformer.CreateDrop(Host, currentTr, new string[] { columnName });
            var finalTr = objtr as IDataTransform;

            if (finalTr == null)
            {
                throw Contracts.ExceptNotSupp("Desgin change.");
            }
            var taggedViews = new List <Tuple <string, ITaggedDataView> >();

            // filenames
            if (_filenames != null || _tags != null)
            {
                int nbf = _filenames == null ? 0 : _filenames.Length;
                if (nbf > 0 && nbf != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nbt = _tags == null ? 0 : _tags.Length;
                if (nbt > 0 && nbt != _ratios.Length)
                {
                    throw Host.Except("Differen number of filenames and ratios.");
                }
                int nb = Math.Max(nbf, nbt);

                using (var ch = Host.Start("Split the datasets and stores each part."))
                {
                    for (int i = 0; i < nb; ++i)
                    {
                        if (_filenames == null || !_filenames.Any())
                        {
                            ch.Info("Create part {0}: {1} (tag: {2})", i + 1, _ratios[i], _tags[i]);
                        }
                        else
                        {
                            ch.Info("Create part {0}: {1} (file: {2})", i + 1, _ratios[i], _filenames[i]);
                        }
                        var ar1 = new RangeFilter.Options()
                        {
                            Column = _newColumn, Min = i, Max = i, IncludeMax = true
                        };
                        int pardId   = i;
                        var filtView = LambdaFilter.Create <int>(Host, string.Format("Select part {0}", i), currentTr,
                                                                 _newColumn, NumberDataViewType.Int32,
                                                                 (in int part) => { return(part.Equals(pardId)); });