/// <summary>
/// Resolves each named input column against <paramref name="inputSchema"/>, records its item type,
/// and locates the (at most one) vector-valued column among them.
/// </summary>
/// <param name="env">Environment used to create the thrown exceptions.</param>
/// <param name="inputColumnNames">Names of the source columns, in order.</param>
/// <param name="inputSchema">Schema the names are looked up in.</param>
/// <param name="inputTypes">Filled in place: slot i receives the item type of column i.</param>
/// <returns>The position in <paramref name="inputColumnNames"/> of the vector-valued column, or -1 if all are scalar.</returns>
private static int FindVectorInputColumn(IHostEnvironment env, IReadOnlyList<string> inputColumnNames, DataViewSchema inputSchema, DataViewType[] inputTypes)
        {
            int vectorPosition = -1;

            for (int i = 0; i < inputColumnNames.Count; i++)
            {
                string name = inputColumnNames[i];
                var column = inputSchema.GetColumnOrNull(name)
                    ?? throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", name);

                DataViewType columnType = column.Type;
                if (columnType is VectorDataViewType)
                {
                    // Only a single vector-valued source is supported.
                    if (vectorPosition >= 0)
                    {
                        throw env.ExceptUserArg(nameof(inputColumnNames), "Can have at most one vector-valued source column");
                    }
                    vectorPosition = i;
                }

                inputTypes[i] = columnType.GetItemType();
            }

            return vectorPosition;
        }
Example #2
0
        /// <summary>
        /// Binds an array of values to a named input port of the entry-point graph.
        /// </summary>
        /// <typeparam name="TInput">Element type declared by the graph variable handle.</typeparam>
        /// <typeparam name="TInput2">Element type of the supplied array; must be assignable to <typeparamref name="TInput"/>.</typeparam>
        /// <param name="variable">Handle naming the graph variable to set; must not be null and must have a non-empty name.</param>
        /// <param name="input">The values to bind; must not be null.</param>
        /// <remarks>
        /// Throws if the port does not exist, produces outputs, already holds a value,
        /// or has a type that cannot accept a <typeparamref name="TInput"/> array.
        /// </remarks>
        public void SetInput<TInput, TInput2>(ArrayVar<TInput> variable, TInput2[] input)
            where TInput : class
        {
            _env.CheckValue(variable, nameof(variable));
            var portName = variable.VarName;

            _env.CheckNonEmpty(portName, nameof(variable.VarName));
            _env.CheckValue(input, nameof(input));

            bool elementCompatible = typeof(TInput).IsAssignableFrom(typeof(TInput2));
            if (!elementCompatible)
            {
                throw _env.ExceptUserArg(nameof(input), $"Type {typeof(TInput2)} not castable to type {typeof(TInput)}");
            }

            var port = _graph.GetVariableOrNull(portName);

            // The conditions are evaluated in order; the first one that fails selects the error message.
            string problem =
                port == null ? "Port '{0}' not found" :
                port.HasOutputs ? "Port '{0}' is not an input" :
                port.Value != null ? "Port '{0}' is already set" :
                !port.Type.IsAssignableFrom(typeof(TInput[])) ? "Port '{0}' is of incorrect type" :
                null;

            if (problem != null)
            {
                throw _env.Except(problem, portName);
            }

            port.SetValue(input);
        }
        /// <summary>
        /// Macro that min-max normalizes only those requested columns that are not already normalized.
        /// </summary>
        /// <param name="env">Host environment used for exceptions and node creation.</param>
        /// <param name="input">Normalizer arguments. When any column needs work, its <c>Column</c> array is
        /// overwritten with just the columns that still need normalization.</param>
        /// <param name="node">The macro node whose catalog, context and port maps are reused for the emitted node.</param>
        /// <returns>A macro output holding exactly one node: either a no-op or a MinMax normalizer.</returns>
        public static CommonOutputs.MacroOutput<CommonOutputs.TransformOutput> IfNeeded(
            IHostEnvironment env,
            NormalizeTransform.MinMaxArguments input,
            EntryPointNode node)
        {
            var dataSchema = input.Data.Schema;
            var pending = new List<NormalizeTransform.AffineColumn>();
            DvBool normalizedFlag = DvBool.False;

            foreach (var candidate in input.Column)
            {
                if (!dataSchema.TryGetColumnIndex(candidate.Source, out int colIndex))
                {
                    throw env.ExceptUserArg(nameof(input.Column), $"Column '{candidate.Source}' does not exist.");
                }

                // A column is normalized when it lacks IsNormalized metadata or the metadata reads false.
                // NOTE(review): if DvBool can hold an NA value, an NA flag is neither IsFalse nor missing,
                // so such a column is treated as already normalized — confirm this is intended.
                bool hasFlag = dataSchema.TryGetMetadata(BoolType.Instance, MetadataUtils.Kinds.IsNormalized, colIndex, ref normalizedFlag);
                if (!hasFlag || normalizedFlag.IsFalse)
                {
                    pending.Add(candidate);
                }
            }

            EntryPointNode macroNode;
            if (pending.Count == 0)
            {
                // Nothing to do: forward the data through a no-op transform.
                macroNode = EntryPointNode.Create(env, "Transforms.NoOperation", new NopTransform.NopInput(),
                                                  node.Catalog, node.Context, node.InputBindingMap, node.InputMap, node.OutputMap);
            }
            else
            {
                input.Column = pending.ToArray();
                macroNode = EntryPointNode.Create(env, "Transforms.MinMaxNormalizer", input,
                                                  node.Catalog, node.Context, node.InputBindingMap, node.InputMap, node.OutputMap);
            }

            return new CommonOutputs.MacroOutput<CommonOutputs.TransformOutput>
            {
                Nodes = new List<EntryPointNode> { macroNode }
            };
        }
        /// <summary>
        /// Schema-shape counterpart of the schema-based lookup: resolves each input column, records its
        /// item type, and finds the single non-scalar column, if any.
        /// </summary>
        /// <param name="env">Environment used to create the thrown exceptions.</param>
        /// <param name="inputColumnNames">Names of the source columns, in order.</param>
        /// <param name="inputSchema">Schema shape the names are looked up in.</param>
        /// <param name="inputTypes">Filled in place: slot i receives the item type of column i.</param>
        /// <returns>The position of the non-scalar (vector or variable-length vector) column, or -1 if none.</returns>
        private static int FindVectorInputColumn(IHostEnvironment env, IReadOnlyList<string> inputColumnNames, SchemaShape inputSchema, DataViewType[] inputTypes)
        {
            int vectorSlot = -1;

            for (int i = 0; i < inputColumnNames.Count; i++)
            {
                string name = inputColumnNames[i];
                if (!inputSchema.TryFindColumn(name, out var shape))
                {
                    throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", name);
                }

                // Anything other than Scalar counts as vector-valued here.
                bool isVector = shape.Kind != SchemaShape.Column.VectorKind.Scalar;
                if (isVector)
                {
                    if (vectorSlot >= 0)
                    {
                        throw env.ExceptUserArg(nameof(inputColumnNames), "Can have at most one vector-valued source column");
                    }
                    vectorSlot = i;
                }

                inputTypes[i] = shape.ItemType;
            }

            return vectorSlot;
        }
        /// <summary>
        /// Appends the key-to-vector conversion stage of a categorical transform to <paramref name="input"/>.
        /// </summary>
        /// <param name="argsOutputKind">Default output kind for columns without a per-column override.</param>
        /// <param name="columns">Columns to convert; each is converted in place (output name equals source name).</param>
        /// <param name="columnOutputKinds">Per-column output kind overrides, parallel to <paramref name="columns"/>; null entries use <paramref name="argsOutputKind"/>.</param>
        /// <param name="input">The upstream transform.</param>
        /// <param name="h">Host used for the channel, the created transforms and user-argument errors.</param>
        /// <param name="env">Host environment (kept for interface compatibility).</param>
        /// <param name="catHashArgs">Optional hashing arguments, used only to warn about invert hashing combined with binary encoding.</param>
        /// <returns>
        /// <paramref name="input"/> unchanged if every column keeps Key output. Otherwise, returns the last transform of a chain.
        /// The chain holds a KeyToVectorTransform for Ind/Bag columns, followed by a KeyToBinaryVectorTransform for Bin columns.
        /// </returns>
        /// <remarks>
        /// Columns are grouped by encoding so that each one honors its own output kind. The previous code routed
        /// every column through binary encoding as soon as any column, or the default, was Bin, which discarded
        /// Ind/Bag overrides.
        /// </remarks>
        public static IDataTransform CreateTransformCore(
            OutputKind argsOutputKind,
            OneToOneColumn[] columns,
            List<OutputKind?> columnOutputKinds,
            IDataTransform input,
            IHost h,
            IHostEnvironment env,
            CategoricalHashTransform.Arguments catHashArgs = null)
        {
            Contracts.CheckValue(columns, nameof(columns));
            Contracts.CheckValue(columnOutputKinds, nameof(columnOutputKinds));
            Contracts.CheckParam(columns.Length == columnOutputKinds.Count, nameof(columns));

            using (var ch = h.Start("Create Transform Core"))
            {
                var binaryCols = new List<KeyToVectorTransform.Column>();
                var vectorCols = new List<KeyToVectorTransform.Column>();

                for (int i = 0; i < columns.Length; i++)
                {
                    var column = columns[i];
                    if (!column.TrySanitize())
                    {
                        throw h.ExceptUserArg(nameof(Column.Name));
                    }

                    OutputKind kind = columnOutputKinds[i] ?? argsOutputKind;
                    List<KeyToVectorTransform.Column> target;
                    bool bag;
                    switch (kind)
                    {
                    default:
                        throw h.ExceptUserArg(nameof(Column.OutputKind));

                    case OutputKind.Key:
                        // Key output needs no conversion.
                        continue;

                    case OutputKind.Bin:
                        target = binaryCols;
                        bag    = false;
                        break;

                    case OutputKind.Ind:
                        target = vectorCols;
                        bag    = false;
                        break;

                    case OutputKind.Bag:
                        target = vectorCols;
                        bag    = true;
                        break;
                    }

                    target.Add(new KeyToVectorTransform.Column
                    {
                        Name   = column.Name,
                        Source = column.Name,
                        Bag    = bag
                    });
                }

                if (binaryCols.Count == 0 && vectorCols.Count == 0)
                {
                    return(input);
                }

                IDataTransform transform = input;

                if (vectorCols.Count > 0)
                {
                    var keyToVecArgs = new KeyToVectorTransform.Arguments
                    {
                        Bag    = argsOutputKind == OutputKind.Bag,
                        Column = vectorCols.ToArray()
                    };
                    transform = new KeyToVectorTransform(h, keyToVecArgs, transform);
                }

                if (binaryCols.Count > 0)
                {
                    if ((catHashArgs?.InvertHash ?? 0) != 0)
                    {
                        ch.Warning("Invert hashing is being used with binary encoding.");
                    }

                    var keyToBinaryArgs = new KeyToBinaryVectorTransform.Arguments();
                    keyToBinaryArgs.Column = binaryCols.ToArray();
                    transform = new KeyToBinaryVectorTransform(h, keyToBinaryArgs, transform);
                }

                ch.Done();
                return(transform);
            }
        }
        /// <summary>
        /// Scans <paramref name="input"/> once and collects, for every slot of every requested column,
        /// the count produced by that column's <c>CountAggregator</c>. These are the feature selection scores.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="input">The input dataview.</param>
        /// <param name="columns">Names of the columns to score; must be non-empty.</param>
        /// <param name="colSizes">Receives each column's value count (its vector size, or the scalar value count).</param>
        /// <returns>One array of counts per requested column, in the order of <paramref name="columns"/>.</returns>
        /// <remarks>Throws if a column is missing or is a vector of unknown length.</remarks>
        public static long[][] Train(IHostEnvironment env, IDataView input, string[] columns, out int[] colSizes)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));
            env.CheckParam(Utils.Size(columns) > 0, nameof(columns));

            var schema        = input.Schema;
            int columnCount   = columns.Length;
            var isActive      = new bool[schema.ColumnCount];
            var sourceIndices = new int[columnCount];
            var sourceTypes   = new ColumnType[columnCount];
            colSizes = new int[columnCount];

            // Resolve every requested column and mark it active for the cursor.
            for (int c = 0; c < columnCount; c++)
            {
                string name = columns[c];
                if (!schema.TryGetColumnIndex(name, out int srcIndex))
                {
                    throw env.ExceptUserArg(nameof(CountFeatureSelectionTransform.Arguments.Column), "Source column '{0}' not found", name);
                }

                var type = schema.GetColumnType(srcIndex);
                if (type.IsVector && !type.IsKnownSizeVector)
                {
                    throw env.ExceptUserArg(nameof(CountFeatureSelectionTransform.Arguments.Column), "Variable length column '{0}' is not allowed", name);
                }

                isActive[srcIndex] = true;
                sourceIndices[c]   = srcIndex;
                sourceTypes[c]     = type;
                colSizes[c]        = type.ValueCount;
            }

            var    counters  = new CountAggregator[columnCount];
            long   rowsSeen  = 0;
            double totalRows = input.GetRowCount(true) ?? double.NaN;

            using (var pch = env.StartProgressChannel("Aggregating counts"))
            using (var cursor = input.GetRowCursor(col => isActive[col]))
            {
                // The progress callback reads rowsSeen live through the closure.
                pch.SetHeader(new ProgressHeader(new[] { "rows" }), e => { e.SetProgress(0, rowsSeen, totalRows); });

                for (int c = 0; c < columnCount; c++)
                {
                    var type = sourceTypes[c];
                    if (type.IsVector)
                    {
                        counters[c] = GetVecAggregator(cursor, type, sourceIndices[c]);
                    }
                    else
                    {
                        counters[c] = GetOneAggregator(cursor, type, sourceIndices[c]);
                    }
                }

                while (cursor.MoveNext())
                {
                    foreach (var counter in counters)
                    {
                        counter.ProcessValue();
                    }
                    rowsSeen++;
                }

                pch.Checkpoint(rowsSeen);
            }

            return counters.Select(a => a.Count).ToArray();
        }