コード例 #1
0
        /// <summary>
        /// Computes the <see cref="SchemaShape"/> that fitting this calibrator produces from <paramref name="inputSchema"/>.
        /// The configured score, weight and label columns (when set) must exist in the input and be type-compatible.
        /// The result holds every input column plus a scalar Single column named "Probability"; an input column that
        /// already has that name is replaced by the new one.
        /// </summary>
        /// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
        /// <returns>The input columns plus the calibrated "Probability" column.</returns>
        SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(SchemaShape inputSchema)
        {
            // Validates one configured role; roles whose column is not set (IsValid == false) are skipped.
            void EnsurePresentAndCompatible(SchemaShape.Column expected, string role)
            {
                if (!expected.IsValid)
                    return;

                if (!inputSchema.TryFindColumn(expected.Name, out var actual))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name);

                if (!expected.IsCompatibleWith(actual))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name, expected.GetTypeString(), actual.GetTypeString());
            }

            EnsurePresentAndCompatible(ScoreColumn, "score");
            EnsurePresentAndCompatible(WeightColumn, "weight");
            EnsurePresentAndCompatible(LabelColumn, "label");

            // Keyed by name, so assigning the probability column overwrites any same-named input column.
            var columnsByName = inputSchema.ToDictionary(c => c.Name);

            var probabilityAnnotations = new SchemaShape(AnnotationUtils.GetTrainerOutputAnnotation(true));
            columnsByName[DefaultColumnNames.Probability] = new SchemaShape.Column(
                DefaultColumnNames.Probability,
                SchemaShape.Column.VectorKind.Scalar,
                NumberDataViewType.Single,
                false,
                probabilityAnnotations);

            return new SchemaShape(columnsByName.Values);
        }
コード例 #2
0
        /// <summary>
        /// Computes the <see cref="SchemaShape"/> that fitting this calibrator produces from <paramref name="inputSchema"/>.
        /// Each configured role column (score, weight, label, features, predicted label) must exist in the input and be
        /// type-compatible. The result holds every input column plus a scalar R4 column named "Probability"; an input
        /// column that already has that name is replaced by the new one.
        /// </summary>
        /// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
        /// <returns>The input columns plus the calibrated "Probability" column.</returns>
        SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(SchemaShape inputSchema)
        {
            // Validates one configured role; roles whose column is not set (IsValid == false) are skipped.
            void EnsurePresentAndCompatible(SchemaShape.Column expected, string role)
            {
                if (!expected.IsValid)
                    return;

                if (!inputSchema.TryFindColumn(expected.Name, out var actual))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name);

                if (!expected.IsCompatibleWith(actual))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name, expected.GetTypeString(), actual.GetTypeString());
            }

            EnsurePresentAndCompatible(ScoreColumn, DefaultColumnNames.Score);
            EnsurePresentAndCompatible(WeightColumn, DefaultColumnNames.Weight);
            EnsurePresentAndCompatible(LabelColumn, DefaultColumnNames.Label);
            EnsurePresentAndCompatible(FeatureColumn, DefaultColumnNames.Features);
            EnsurePresentAndCompatible(PredictedLabel, DefaultColumnNames.PredictedLabel);

            // Keyed by name, so assigning the probability column overwrites any same-named input column.
            var columnsByName = inputSchema.ToDictionary(c => c.Name);

            var probabilityMetadata = new SchemaShape(MetadataUtils.GetTrainerOutputMetadata(true));
            columnsByName[DefaultColumnNames.Probability] = new SchemaShape.Column(
                DefaultColumnNames.Probability,
                SchemaShape.Column.VectorKind.Scalar,
                NumberType.R4,
                false,
                probabilityMetadata);

            return new SchemaShape(columnsByName.Values);
        }
コード例 #3
0
        /// <summary>
        /// Gets the output <see cref="SchemaShape"/> of the <see cref="IDataView"/> after fitting the calibrator.
        /// The result holds every input column plus a scalar Single column named "Probability"; an input column that
        /// already has that name is replaced by the new one.
        /// The new column always declares an IsNormalized annotation. The score-column-set-id, score-column-kind and
        /// score-value-kind annotations (the ones <see cref="AnnotationUtils.GetTrainerOutputAnnotation(bool)"/> would
        /// produce) are carried over only when the input score column has all three with the expected shapes.
        /// </summary>
        /// <param name="inputSchema">The input <see cref="SchemaShape"/>.</param>
        /// <returns>The input columns plus the calibrated "Probability" column.</returns>
        SchemaShape IEstimator<CalibratorTransformer<TICalibrator>>.GetOutputSchema(SchemaShape inputSchema)
        {
            // Validates one configured role; roles whose column is not set (IsValid == false) are skipped.
            void EnsurePresentAndCompatible(SchemaShape.Column expected, string role)
            {
                if (!expected.IsValid)
                    return;

                if (!inputSchema.TryFindColumn(expected.Name, out var actual))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name);

                if (!expected.IsCompatibleWith(actual))
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), role, expected.Name, expected.GetTypeString(), actual.GetTypeString());
            }

            EnsurePresentAndCompatible(ScoreColumn, "score");
            EnsurePresentAndCompatible(WeightColumn, "weight");
            EnsurePresentAndCompatible(LabelColumn, "label");

            bool scoreFound = inputSchema.TryFindColumn(ScoreColumn.Name, out var scoreColumn);
            Host.Assert(scoreFound);

            const SchemaShape.Column.VectorKind scalar = SchemaShape.Column.VectorKind.Scalar;

            // Shape expected for the score-kind and value-kind annotations.
            bool IsScalarText(SchemaShape.Column annotation) => annotation.Kind == scalar && annotation.ItemType is TextDataViewType;

            var probabilityAnnotations = new List<SchemaShape.Column>
            {
                new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, scalar, BooleanDataViewType.Instance, false),
            };

            // All-or-nothing: the training annotations are propagated only if every one is present and well-formed.
            var scoreAnnotations = scoreColumn.Annotations;
            if (scoreAnnotations.TryFindColumn(AnnotationUtils.Kinds.ScoreColumnSetId, out var setIdCol) &&
                setIdCol.Kind == scalar && setIdCol.IsKey && setIdCol.ItemType == NumberDataViewType.UInt32 &&
                scoreAnnotations.TryFindColumn(AnnotationUtils.Kinds.ScoreColumnKind, out var kindCol) && IsScalarText(kindCol) &&
                scoreAnnotations.TryFindColumn(AnnotationUtils.Kinds.ScoreValueKind, out var valueKindCol) && IsScalarText(valueKindCol))
            {
                probabilityAnnotations.Add(setIdCol);
                probabilityAnnotations.Add(kindCol);
                probabilityAnnotations.Add(valueKindCol);
            }

            // Keyed by name, so assigning the probability column overwrites any same-named input column.
            var columnsByName = inputSchema.ToDictionary(c => c.Name);
            columnsByName[DefaultColumnNames.Probability] = new SchemaShape.Column(
                DefaultColumnNames.Probability,
                scalar,
                NumberDataViewType.Single,
                false,
                new SchemaShape(probabilityAnnotations));

            return new SchemaShape(columnsByName.Values);
        }
コード例 #4
0
        /// <summary>
        /// Creates a lazy in-memory cache of <paramref name="input"/>.
        /// </summary>
        /// <remarks>
        /// Caching happens per-column: a column is cached the first time it is accessed.
        /// Columns listed in <paramref name="columnsToPrefetch"/> are treated as always needed and are cached
        /// as soon as any data is requested.
        /// </remarks>
        /// <param name="input">The input data.</param>
        /// <param name="columnsToPrefetch">The columns that must be cached whenever anything is cached. An empty array or null
        /// value means that columns are cached upon their first access.</param>
        /// <returns>A cached view over <paramref name="input"/>.</returns>
        /// <example>
        /// <format type="text/markdown">
        /// <![CDATA[
        /// [!code-csharp[Cache](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/Cache.cs)]
        /// ]]>
        /// </format>
        /// </example>
        public IDataView Cache(IDataView input, params string[] columnsToPrefetch)
        {
            _env.CheckValue(input, nameof(input));
            _env.CheckValueOrNull(columnsToPrefetch);

            // Utils.Size is used instead of .Length so that a null array (allowed above) yields no prefetch columns.
            var prefetchIndices = new int[Utils.Size(columnsToPrefetch)];
            for (int i = 0; i < prefetchIndices.Length; i++)
            {
                string name = columnsToPrefetch[i];
                if (input.Schema.TryGetColumnIndex(name, out prefetchIndices[i]))
                    continue;

                throw _env.ExceptSchemaMismatch(nameof(columnsToPrefetch), "prefetch", name);
            }

            return new CacheDataView(_env, input, prefetchIndices);
        }
コード例 #5
0
        // Resolves each named input column in inputSchema, writes its item type into inputTypes (by position),
        // and returns the position of the single vector-valued source column, or -1 if every source is scalar.
        // Throws if a column is missing or if more than one source column is a vector.
        private static int FindVectorInputColumn(IHostEnvironment env, IReadOnlyList<string> inputColumnNames, DataViewSchema inputSchema, DataViewType[] inputTypes)
        {
            int vectorPosition = -1;

            for (int position = 0; position < inputColumnNames.Count; position++)
            {
                string name = inputColumnNames[position];
                var column = inputSchema.GetColumnOrNull(name);
                if (!column.HasValue)
                    throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", name);

                DataViewType columnType = column.Value.Type;
                bool isVector = columnType is VectorDataViewType;
                if (isVector && vectorPosition >= 0)
                    throw env.ExceptUserArg(nameof(inputColumnNames), "Can have at most one vector-valued source column");
                if (isVector)
                    vectorPosition = position;

                inputTypes[position] = columnType.GetItemType();
            }

            return vectorPosition;
        }
        /// <summary>
        /// Computes the output <see cref="SchemaShape"/> of <see cref="PretrainedTreeFeaturizationEstimator"/>.
        /// Given a feature vector column, up to three float-vector columns are added to <paramref name="inputSchema"/>:
        /// the prediction values of all trees, the leaf IDs the feature vector falls into, and the paths to those leaves.
        /// Each of these columns is produced only when its output column name is not <see langword="null"/>; an input
        /// column that already uses one of those names is replaced.
        /// </summary>
        /// <param name="inputSchema">A schema which contains a feature column. Note that feature column name can be specified
        /// by <see cref="OptionsBase.InputColumnName"/>.</param>
        /// <returns>Output <see cref="SchemaShape"/> produced by <see cref="PretrainedTreeFeaturizationEstimator"/>.</returns>
        /// <remarks>
        /// Throws (via <c>Env.ExceptSchemaMismatch</c>) when the feature column is absent from <paramref name="inputSchema"/>.
        /// Only the presence of the feature column is checked here, not its type.
        /// </remarks>
        public SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Env.CheckValue(inputSchema, nameof(inputSchema));

            if (!inputSchema.TryFindColumn(FeatureColumnName, out _))
            {
                throw Env.ExceptSchemaMismatch(nameof(inputSchema), "input", FeatureColumnName);
            }

            var result = inputSchema.ToDictionary(x => x.Name);

            // Trees, leaves and paths share the same shape (float vector). A null name means that output is not
            // requested. The fixed order matters only if two names coincide: the later assignment wins.
            foreach (var outputName in new[] { TreesColumnName, LeavesColumnName, PathsColumnName })
            {
                if (outputName != null)
                {
                    result[outputName] = new SchemaShape.Column(outputName,
                        SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false);
                }
            }

            return new SchemaShape(result.Values);
        }
コード例 #7
0
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing or
        /// normalizing it into a temporary column if necessary, or creates a new random column if
        /// <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        /// <param name="env">The host environment. When <paramref name="seed"/> is null, its seed is read through <c>ISeededEnvironment</c>.</param>
        /// <param name="data">The data view; replaced by a view that also contains the column named by <paramref name="samplingKeyColumn"/> on return.</param>
        /// <param name="samplingKeyColumn">On input, the user's sampling key column or null; on output, the name of the column to filter on.</param>
        /// <param name="seed">Optional seed for the random-number generation or the hashing.</param>
        internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
        {
            Contracts.CheckValue(env, nameof(env));

            if (samplingKeyColumn == null)
            {
                // No key supplied: add a column of random numbers, seeded by the argument or else by the environment.
                samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
                data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
                return;
            }

            if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
                throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

            var type = data.Schema[stratCol].Type;
            string sourceColumn = samplingKeyColumn;

            if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
            {
                // Hash the key into a temporary column.
                // REVIEW: this could currently crash, since Hash only accepts a limited set
                // of column types. It used to be HashJoin, but we should probably extend Hash
                // instead of having two hash transformations.
                samplingKeyColumn = data.Schema.GetTempColumnName(sourceColumn);
                int? hashSeed = seed ?? ((ISeededEnvironment)env).Seed;
                HashingEstimator.ColumnOptionsInternal columnOptions = hashSeed.HasValue
                    ? new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, sourceColumn, 30, (uint)hashSeed.Value)
                    : new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, sourceColumn, 30);
                data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
            }
            else if (!data.Schema[sourceColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
            {
                // Un-normalized float/double keys are min-max normalized into a temporary column
                // (presumably because RangeFilter expects a bounded range — confirm).
                samplingKeyColumn = data.Schema.GetTempColumnName(sourceColumn);
                data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, sourceColumn, ensureZeroUntouched: true)).Fit(data).Transform(data);
            }
        }
コード例 #8
0
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing or
        /// normalizing it into a temporary column if necessary, or creates a new random column if
        /// <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        /// <remarks>
        /// The user's original sampling key column is never modified: every derived value (type conversion, hash,
        /// normalization) is written to a temporary column whose name is returned through <paramref name="samplingKeyColumn"/>.
        /// </remarks>
        /// <param name="env">The host environment. When <paramref name="seed"/> is null, its seed is read through <c>ISeededEnvironment</c>.</param>
        /// <param name="data">The data view; replaced by a view that also contains the column named by <paramref name="samplingKeyColumn"/> on return.</param>
        /// <param name="samplingKeyColumn">On input, the user's sampling key column or null; on output, the name of the column to filter on.</param>
        /// <param name="seed">Optional seed for the random-number generation or the hashing.</param>
        internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
        {
            Contracts.CheckValue(env, nameof(env));
            // Two cases: if samplingKeyColumn is provided, it is hashed/normalized when RangeFilter cannot use it
            // directly; if it is not provided, a random number column is generated.
            if (samplingKeyColumn == null)
            {
                samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
                data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
            }
            else
            {
                if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
                {
                    throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);
                }

                var type = data.Schema[stratCol].Type;
                if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
                {
                    var origStratCol = samplingKeyColumn;
                    samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                    var hashInputColumnName = origStratCol;

                    // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
                    var itemType = type.GetItemType();
                    if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
                    {
                        // Convert into the temporary column rather than in place: an in-place conversion would hide the
                        // caller's original column behind an Int64 column of the same name in the returned view.
                        // The hash below then replaces the temporary column with its hashed value.
                        data = new TypeConvertingTransformer(env, samplingKeyColumn, DataKind.Int64, origStratCol).Transform(data);
                        hashInputColumnName = samplingKeyColumn;
                    }

                    // Explicit seed first, otherwise the environment's seed (if any).
                    int? localSeed = seed ?? ((ISeededEnvironment)env).Seed;
                    var columnOptions =
                        localSeed.HasValue ?
                        new HashingEstimator.ColumnOptions(samplingKeyColumn, hashInputColumnName, 30, (uint)localSeed.Value, combine: true) :
                        new HashingEstimator.ColumnOptions(samplingKeyColumn, hashInputColumnName, 30, combine: true);
                    data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
                }
                else
                {
                    if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
                    {
                        var origStratCol = samplingKeyColumn;
                        samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
                        data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
                    }
                }
            }
        }
コード例 #9
0
        /// <summary>
        /// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it
        /// into a new column if necessary, or creates a new random column if <paramref name="samplingKeyColumn"/> is null.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="data">The data view; replaced by a view that also contains the column named by <paramref name="samplingKeyColumn"/> on return.</param>
        /// <param name="samplingKeyColumn">On input, the user's sampling key column or null; on output, the name of the column to filter on.</param>
        /// <param name="seed">Optional seed for the random-number generation or the hashing.</param>
        internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
        {
            if (samplingKeyColumn == null)
            {
                // No key supplied: add a column of random numbers.
                samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
                data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)seed);
                return;
            }

            if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int keyIndex))
                throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

            // Keys RangeFilter can already consume are used as-is.
            if (RangeFilter.IsValidRangeFilterColumnType(env, data.Schema[keyIndex].Type))
                return;

            // Hash the samplingKeyColumn.
            // REVIEW: this could currently crash, since Hash only accepts a limited set
            // of column types. It used to be HashJoin, but we should probably extend Hash
            // instead of having two hash transformations.
            string sourceColumn = samplingKeyColumn;
            int suffix = 0;

            // Pick the first unused "<name>_NNN". The source name itself is known to exist (found above),
            // so at least one suffixed candidate is always generated.
            do
            {
                samplingKeyColumn = string.Format("{0}_{1:000}", sourceColumn, ++suffix);
            }
            while (data.Schema.TryGetColumnIndex(samplingKeyColumn, out _));

            HashingEstimator.ColumnOptions hashOptions = seed.HasValue
                ? new HashingEstimator.ColumnOptions(samplingKeyColumn, sourceColumn, 30, (uint)seed.Value)
                : new HashingEstimator.ColumnOptions(samplingKeyColumn, sourceColumn, 30);
            data = new HashingEstimator(env, hashOptions).Fit(data).Transform(data);
        }
コード例 #10
0
        // Looks up every input column of the whitening transform in inputData's schema, returning its index in
        // `cols` and its type in `srcTypes` (both parallel to `columns`). Throws a schema-mismatch exception for a
        // missing column and a parameter exception carrying TestColumn's message for a rejected column type.
        private static void GetColTypesAndIndex(IHostEnvironment env, IDataView inputData, ColumnInfo[] columns, out ColumnType[] srcTypes, out int[] cols)
        {
            int count = columns.Length;
            cols = new int[count];
            srcTypes = new ColumnType[count];
            var inputSchema = inputData.Schema;

            for (int i = 0; i < count; i++)
            {
                var inputName = columns[i].Input;
                if (!inputSchema.TryGetColumnIndex(inputName, out cols[i]))
                    throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);

                var columnType = inputSchema.GetColumnType(cols[i]);
                srcTypes[i] = columnType;

                var reason = TestColumn(columnType);
                if (reason != null)
                    throw env.ExceptParam(nameof(inputData.Schema), reason);
            }
        }
コード例 #11
0
        // Looks up every input column of the whitening transform in inputData's schema, returning its index in
        // `cols` and its type in `srcTypes` (both parallel to `columns`). Throws a schema-mismatch exception for a
        // missing column and a parameter exception carrying TestColumn's message for a rejected column type.
        private static void GetColTypesAndIndex(IHostEnvironment env, IDataView inputData, VectorWhiteningEstimator.ColumnOptions[] columns, out DataViewType[] srcTypes, out int[] cols)
        {
            int count = columns.Length;
            cols = new int[count];
            srcTypes = new DataViewType[count];
            var inputSchema = inputData.Schema;

            for (int i = 0; i < count; i++)
            {
                var inputName = columns[i].InputColumnName;
                var lookup = inputSchema.GetColumnOrNull(inputName);
                if (lookup == null)
                    throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);

                var found = lookup.Value;
                cols[i] = found.Index;
                srcTypes[i] = found.Type;

                var reason = TestColumn(found.Type);
                if (reason != null)
                    throw env.ExceptParam(nameof(inputData.Schema), reason);
            }
        }
コード例 #12
0
        // Resolves each named input column in the schema shape, writes its item type into inputTypes (by position),
        // and returns the position of the single non-scalar source column, or -1 if every source is scalar.
        // Throws if a column is missing or if more than one source column is non-scalar.
        private static int FindVectorInputColumn(IHostEnvironment env, IReadOnlyList<string> inputColumnNames, SchemaShape inputSchema, DataViewType[] inputTypes)
        {
            int vectorPosition = -1;

            for (int position = 0; position < inputColumnNames.Count; position++)
            {
                string name = inputColumnNames[position];
                if (!inputSchema.TryFindColumn(name, out var shapeColumn))
                    throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", name);

                // Both fixed-size and variable-size vectors count as "vector-valued" here.
                bool isVector = shapeColumn.Kind != SchemaShape.Column.VectorKind.Scalar;
                if (isVector && vectorPosition >= 0)
                    throw env.ExceptUserArg(nameof(inputColumnNames), "Can have at most one vector-valued source column");
                if (isVector)
                    vectorPosition = position;

                inputTypes[position] = shapeColumn.ItemType;
            }

            return vectorPosition;
        }
コード例 #13
0
        /// <summary>
        /// Extract all values of one column of the data view in a form of an <see cref="IEnumerable{T}"/>.
        /// </summary>
        /// <remarks>
        /// Supported mappings; any other combination throws:
        /// - <typeparamref name="T"/> equal to the column's raw type (direct mapping);
        /// - <typeparamref name="T"/> = <see cref="string"/> for a text column (each value converted with ToString());
        /// - <typeparamref name="T"/> an array type for a vector column, whose element type either equals the column's
        ///   item raw type or is <see cref="string"/> for text items.
        /// Throws (via <c>env.ExceptSchemaMismatch</c>) if the column does not exist or if an array type is requested for a
        /// non-vector column, and (via <c>env.Except</c>) if no supported mapping applies.
        /// </remarks>
        /// <typeparam name="T">The type of the values. Must equal the column's raw type, or be one of the supported
        /// conversions listed in the remarks.</typeparam>
        /// <param name="data">The data view to get the column from.</param>
        /// <param name="env">The current host environment.</param>
        /// <param name="columnName">The name of the column to extract.</param>
        /// <returns>The column values as produced by the matching helper (GetColumnDirect, GetColumnConvert,
        /// GetColumnArrayDirect or GetColumnArrayConvert). NOTE(review): whether enumeration is lazy depends on those
        /// helpers — confirm.</returns>
        public static IEnumerable <T> GetColumn <T>(this IDataView data, IHostEnvironment env, string columnName)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(data, nameof(data));
            env.CheckNonEmpty(columnName, nameof(columnName));

            if (!data.Schema.TryGetColumnIndex(columnName, out int col))
            {
                throw env.ExceptSchemaMismatch(nameof(columnName), "input", columnName);
            }

            // There are two decisions that we make here:
            // - Is the T an array type?
            //     - If yes, we need to map VBuffer to array and densify.
            //     - If no, this is not needed.
            // - Does T (or item type of T if it's an array) equal to the data view type?
            //     - If this is the same type, we can map directly.
            //     - Otherwise, we need a conversion delegate.

            var colType = data.Schema.GetColumnType(col);

            if (colType.RawType == typeof(T))
            {
                // Direct mapping is possible.
                return(GetColumnDirect <T>(data, col));
            }
            else if (typeof(T) == typeof(string) && colType.IsText)
            {
                // Special case of ROM<char> to string conversion.
                Delegate convert = (Func <ReadOnlyMemory <char>, string>)((ReadOnlyMemory <char> txt) => txt.ToString());
                // The delegate below exists only to obtain GetColumnConvert's generic method definition;
                // its type arguments are then rebound to (T, column raw type) via MakeGenericMethod.
                Func <IDataView, int, Func <int, T>, IEnumerable <T> > del = GetColumnConvert;
                var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(typeof(T), colType.RawType);
                return((IEnumerable <T>)(meth.Invoke(null, new object[] { data, col, convert })));
            }
            else if (typeof(T).IsArray)
            {
                // Output is an array type.
                if (!colType.IsVector)
                {
                    throw env.ExceptSchemaMismatch(nameof(columnName), "input", columnName, "vector", "scalar");
                }
                var elementType = typeof(T).GetElementType();
                if (elementType == colType.ItemType.RawType)
                {
                    // Direct mapping of items.
                    // The <int> instantiation is a placeholder used only to reach the generic method definition;
                    // it is rebound to the real element type below.
                    Func <IDataView, int, IEnumerable <int[]> > del = GetColumnArrayDirect <int>;
                    var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(elementType);
                    return((IEnumerable <T>)meth.Invoke(null, new object[] { data, col }));
                }
                else if (elementType == typeof(string) && colType.ItemType.IsText)
                {
                    // Conversion of text (ReadOnlyMemory<char>) items to string items.
                    Delegate convert = (Func <ReadOnlyMemory <char>, string>)((ReadOnlyMemory <char> txt) => txt.ToString());
                    // The int/long type arguments are placeholders used only to reach GetColumnArrayConvert's generic
                    // method definition; they are rebound to (element type, item raw type) below.
                    Func <IDataView, int, Func <int, long>, IEnumerable <long[]> > del = GetColumnArrayConvert;
                    var meth = del.Method.GetGenericMethodDefinition().MakeGenericMethod(elementType, colType.ItemType.RawType);
                    return((IEnumerable <T>)meth.Invoke(null, new object[] { data, col, convert }));
                }
                // Fall through to the failure.
            }
            throw env.Except($"Could not map a data view column '{columnName}' of type {colType} to {typeof(T)}.");
        }
コード例 #14
0
        /// <summary>
        /// Adds a temporary column to <paramref name="data"/> whose values <see cref="RangeFilter"/> can use to split rows,
        /// and returns that column's name.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="data">The data view; replaced by a view that also contains the returned column.</param>
        /// <param name="samplingKeyColumn">Optional column whose values determine the split (rows with equal keys stay
        /// together). When null, a random number column is generated instead.</param>
        /// <param name="seed">Optional seed for random-number generation or hashing.</param>
        /// <param name="fallbackInEnvSeed">When <paramref name="seed"/> is null, whether to use the seed of
        /// <paramref name="env"/> (read through <c>ISeededEnvironment</c>).</param>
        /// <returns>The name of the newly added (temporary) split column.</returns>
        internal static string CreateSplitColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false)
        {
            Contracts.CheckValue(env, nameof(env));
            Contracts.CheckValueOrNull(samplingKeyColumn);

            var splitColumnName = data.Schema.GetTempColumnName("SplitColumn");

            // An explicit seed wins; otherwise optionally fall back to the environment's seed.
            int? seedToUse = seed ?? (fallbackInEnvSeed ? ((ISeededEnvironment)env).Seed : null);

            if (samplingKeyColumn == null)
            {
                // No key column: split on a column of random numbers.
                data = new GenerateNumberTransform(env, data, splitColumnName, (uint?)seedToUse);
                return splitColumnName;
            }

            // The key column is never modified; its derived values go into the temporary split column,
            // because that column might be dropped elsewhere in the code.
            if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int samplingColIndex))
                throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

            var keyType = data.Schema[samplingColIndex].Type;

            if (RangeFilter.IsValidRangeFilterColumnType(env, keyType))
            {
                // Usable as-is: copy non-float keys; min-max normalize Single/Double keys.
                bool copyVerbatim = keyType != NumberDataViewType.Single && keyType != NumberDataViewType.Double;
                data = copyVerbatim
                    ? new ColumnCopyingEstimator(env, (splitColumnName, samplingKeyColumn)).Fit(data).Transform(data)
                    : new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(splitColumnName, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data);
                return splitColumnName;
            }

            // Not usable by RangeFilter: hash it into the split column.
            var hashSource = samplingKeyColumn;

            // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan,
            // so those are first converted to Int64 in the split column, which is then hashed in place.
            var itemType = keyType.GetItemType();
            if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
            {
                data = new TypeConvertingTransformer(env, splitColumnName, DataKind.Int64, samplingKeyColumn).Transform(data);
                hashSource = splitColumnName;
            }

            var hashOptions = seedToUse.HasValue
                ? new HashingEstimator.ColumnOptions(splitColumnName, hashSource, 30, (uint)seedToUse.Value, combine: true)
                : new HashingEstimator.ColumnOptions(splitColumnName, hashSource, 30, combine: true);
            data = new HashingEstimator(env, hashOptions).Fit(data).Transform(data);

            return splitColumnName;
        }