示例#1
0
 /// <summary>
 /// Builds the getter delegate for column <paramref name="col"/> by looking up the pipe
 /// registered for that column and forwarding it to the pipe-based overload.
 /// </summary>
 /// <typeparam name="TValue">Raw type of the column; asserted to match the schema's column type.</typeparam>
 /// <param name="col">Column index; asserted to be in range and to map to a non-negative active index.</param>
 /// <returns>The delegate produced by the pipe-based overload.</returns>
 private Delegate CreateGetterDelegate<TValue>(int col)
 {
     Ch.Assert(0 <= col && col < _colToActivesIndex.Length);
     Ch.Assert(_colToActivesIndex[col] >= 0);
     Ch.Assert(Schema.GetColumnType(col).RawType == typeof(TValue));

     // Read the mapping only after the range assertion above.
     var activeIndex = _colToActivesIndex[col];
     var pipe = _pipes[activeIndex];
     return CreateGetterDelegate<TValue>(pipe);
 }
        /// <summary>
        /// Builds a ColumnInfo describing the column found at the given index of the schema.
        /// Use sparingly and with care: the name stored in the result might actually map to a
        /// different column when it is later resolved by name.
        /// </summary>
        /// <param name="schema">Schema to read the column name and type from; must not be null.</param>
        /// <param name="index">Column index; must be non-negative and less than the schema's column count.</param>
        /// <returns>A ColumnInfo carrying the column's name, the given index, and the column's type.</returns>
        public static ColumnInfo CreateFromIndex(Schema schema, int index)
        {
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.CheckParam(0 <= index && index < schema.ColumnCount, nameof(index));

            var columnName = schema.GetColumnName(index);
            var columnType = schema.GetColumnType(index);
            return new ColumnInfo(columnName, index, columnType);
        }
示例#3
0
        /// <summary>
        /// Reads and validates the categorical slot ranges metadata of a vector column.
        /// The categoricalFeatures is a vector of the indices of categorical features slots.
        /// This vector should always have an even number of elements, and the elements should be parsed in groups of two consecutive numbers.
        /// So if its value is the range of numbers: 0,2,3,4,8,9
        /// look at it as [0,2],[3,4],[8,9].
        /// The way to interpret that is: features with indices 0, 1, and 2 are one categorical.
        /// Features with indices 3 and 4 are another categorical. Features 5 and 6 don't appear there, so they are not categoricals.
        /// </summary>
        /// <param name="schema">Schema holding the column; must not be null.</param>
        /// <param name="colIndex">Index of the column to inspect; must be non-negative.</param>
        /// <param name="categoricalFeatures">
        /// On success, a fresh array with the validated range boundaries; otherwise null.
        /// </param>
        /// <returns>
        /// True when the column is a known-size vector whose categorical slot ranges metadata is a
        /// non-empty, even-length list of ordered, non-overlapping, in-bounds ranges; false otherwise.
        /// </returns>
        public static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex, out int[] categoricalFeatures)
        {
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.Check(colIndex >= 0, nameof(colIndex));

            categoricalFeatures = null;

            var columnType = schema.GetColumnType(colIndex);
            if (!columnType.IsKnownSizeVector)
            {
                return false;
            }

            var metadataType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex);
            if (metadataType?.RawType != typeof(VBuffer<int>))
            {
                return false;
            }

            VBuffer<int> catIndices = default(VBuffer<int>);
            schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices);
            VBufferUtils.Densify(ref catIndices);

            // Use the logical length of the buffer, not the length of its backing array: the
            // backing array may be larger than the logical length, and any slots past the
            // logical length must neither be validated nor copied to the caller.
            int boundaryCount = catIndices.Length;
            int columnSlotsCount = columnType.AsVector.VectorSizeCore;

            // boundaryCount is even past the second test, so halving it is exact; comparing the
            // halved value avoids overflowing columnSlotsCount * 2 for very large vectors.
            if (boundaryCount <= 0 || boundaryCount % 2 != 0 || boundaryCount / 2 > columnSlotsCount)
            {
                return false;
            }

            int previousEndIndex = -1;
            for (int i = 0; i < boundaryCount; i += 2)
            {
                int rangeStart = catIndices.Values[i];
                int rangeEnd = catIndices.Values[i + 1];

                // Each range must be ordered, start after the previous range ended, and stay
                // within the slots of the column.
                if (rangeStart > rangeEnd ||
                    rangeStart <= previousEndIndex ||
                    rangeStart >= columnSlotsCount ||
                    rangeEnd >= columnSlotsCount)
                {
                    return false;
                }

                previousEndIndex = rangeEnd;
            }

            categoricalFeatures = catIndices.Values.Take(boundaryCount).ToArray();
            return true;
        }
示例#4
0
            /// <summary>
            /// Builds the getter delegate for column <paramref name="col"/> when the column's raw
            /// type is only known at run time, dispatching to the generic overload through
            /// <c>Utils.MarshalInvoke</c>.
            /// </summary>
            /// <param name="col">Column index; asserted to be in range and to map to a non-negative active index.</param>
            /// <returns>The delegate returned by the marshalled call.</returns>
            private Delegate CreateGetterDelegate(int col)
            {
                Ch.Assert(0 <= col && col < _colToActivesIndex.Length);
                Ch.Assert(_colToActivesIndex[col] >= 0);

                var rawType = Schema.GetColumnType(col).RawType;

                // NOTE(review): the <int> type argument looks like a placeholder that MarshalInvoke
                // replaces with rawType - confirm against Utils.MarshalInvoke.
                Func<int, Delegate> genericFactory = CreateGetterDelegate<int>;
                return Utils.MarshalInvoke(genericFactory, rawType, col);
            }
            /// <summary>
            /// Deserializes many-to-one column bindings from a model and resolves every source
            /// column name against <paramref name="input"/>.
            /// </summary>
            /// <param name="ctx">Model load context to read from; must not be null.</param>
            /// <param name="input">Schema the source column names are looked up in; must not be null.</param>
            /// <param name="testTypes">
            /// Optional validator for the source types of one added column; a non-null return value
            /// is treated as the reason the types are invalid.
            /// </param>
            public Contents(ModelLoadContext ctx, Schema input, Func<ColumnType[], string> testTypes)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                Contracts.CheckValue(input, nameof(input));
                Contracts.CheckValueOrNull(testTypes);

                Input = input;

                // *** Binary format ***
                // int: number of added columns
                // for each added column
                //   int: id of output column name
                //   int: number of input column names
                //   int[]: ids of input column names
                int addedCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(addedCount > 0);

                Infos = new ColInfo[addedCount];
                Names = new string[addedCount];
                for (int iinfo = 0; iinfo < addedCount; iinfo++)
                {
                    Names[iinfo] = ctx.LoadNonEmptyString();

                    int sourceCount = ctx.Reader.ReadInt32();
                    Contracts.CheckDecode(sourceCount > 0);

                    var sourceIndices = new int[sourceCount];
                    var sourceTypes = new ColumnType[sourceCount];

                    // Running total of the value counts. It turns null as soon as one source
                    // reports a value count of zero, and nullable addition keeps it null from
                    // then on.
                    int? totalSize = 0;
                    for (int isrc = 0; isrc < sourceCount; isrc++)
                    {
                        string sourceName = ctx.LoadNonEmptyString();
                        if (!input.TryGetColumnIndex(sourceName, out sourceIndices[isrc]))
                        {
                            throw Contracts.Except("Source column '{0}' is required but not found", sourceName);
                        }

                        sourceTypes[isrc] = input.GetColumnType(sourceIndices[isrc]);
                        var valueCount = sourceTypes[isrc].ValueCount;
                        if (valueCount == 0)
                        {
                            totalSize = null;
                        }
                        else
                        {
                            totalSize = checked(totalSize + valueCount);
                        }
                    }

                    string reason = testTypes?.Invoke(sourceTypes);
                    if (reason != null)
                    {
                        throw Contracts.Except("Source columns '{0}' have invalid types: {1}. Source types: '{2}'.",
                                               string.Join(", ", sourceIndices.Select(index => input.GetColumnName(index))),
                                               reason,
                                               string.Join(", ", sourceTypes.Select(sourceType => sourceType.ToString())));
                    }

                    // A null total is stored as zero.
                    Infos[iinfo] = new ColInfo(totalSize.GetValueOrDefault(), sourceIndices, sourceTypes);
                }
            }
示例#6
0
 /// <summary>
 /// Returns the type of column <paramref name="col"/>: the item type recorded for the column
 /// when it is a pivot column, otherwise the type reported by the input schema.
 /// </summary>
 /// <param name="col">Column index; checked to be within the column count.</param>
 public ColumnType GetColumnType(int col)
 {
     _ectx.Check(0 <= col && col < ColumnCount);

     if (IsPivot(col))
     {
         var infoIndex = _pivotIndex[col];
         _ectx.Assert(0 <= infoIndex && infoIndex < _infos.Length);
         return _infos[infoIndex].ItemType;
     }

     return _inputSchema.GetColumnType(col);
 }
示例#7
0
            /// <summary>
            /// Rebuilds the one-to-one column bindings of <paramref name="parent"/> from a
            /// serialized model, resolving each source column against <paramref name="input"/>.
            /// </summary>
            /// <param name="parent">Owning transform; its host is used for checks and exceptions.</param>
            /// <param name="ctx">Model load context to read from; must not be null.</param>
            /// <param name="input">Schema the source column names are looked up in.</param>
            /// <param name="transInput">Optional transpose schema used to fetch slot types.</param>
            /// <param name="testType">
            /// Optional validator for a source column type; a non-null return value is treated as
            /// the reason the type is invalid.
            /// </param>
            /// <returns>The bindings built from the loaded column pairs.</returns>
            public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx, Schema input,
                                          ITransposeSchema transInput, Func<ColumnType, string> testType)
            {
                Contracts.AssertValue(parent);
                var host = parent.Host;

                host.CheckValue(ctx, nameof(ctx));
                host.AssertValue(input);
                host.AssertValueOrNull(transInput);
                host.AssertValueOrNull(testType);

                // *** Binary format ***
                // int: number of added columns
                // for each added column
                //   int: id of output column name
                //   int: id of input column name
                int addedCount = ctx.Reader.ReadInt32();
                host.CheckDecode(addedCount > 0);

                var names = new string[addedCount];
                var infos = new ColInfo[addedCount];

                for (int iinfo = 0; iinfo < addedCount; iinfo++)
                {
                    string outputName = ctx.LoadNonEmptyString();
                    names[iinfo] = outputName;

                    // Note that in old files, the source name may be null indicating that
                    // the source column has the same name as the added column.
                    string sourceName = ctx.LoadStringOrNull() ?? outputName;
                    host.CheckDecode(!string.IsNullOrEmpty(sourceName));

                    if (!input.TryGetColumnIndex(sourceName, out int sourceIndex))
                    {
                        throw host.Except("Source column '{0}' is required but not found", sourceName);
                    }

                    var sourceType = input.GetColumnType(sourceIndex);
                    string reason = testType?.Invoke(sourceType);
                    if (reason != null)
                    {
                        throw host.Except(InvalidTypeErrorFormat, sourceName, sourceType, reason);
                    }

                    var slotType = transInput?.GetSlotType(sourceIndex);
                    infos[iinfo] = new ColInfo(outputName, sourceIndex, sourceType, slotType);
                }

                return new Bindings(parent, infos, input, false, names);
            }
        /// <summary>
        /// Builds many-to-one column bindings from user-supplied column arguments, resolving
        /// every source column name against <paramref name="input"/>.
        /// </summary>
        /// <param name="column">Column arguments; asserted to be non-empty.</param>
        /// <param name="input">Schema the source column names are looked up in.</param>
        /// <param name="testTypes">
        /// Optional validator for the source types of one column; a non-null return value is
        /// treated as the reason the types are invalid.
        /// </param>
        protected ManyToOneColumnBindingsBase(ManyToOneColumn[] column, Schema input, Func<ColumnType[], string> testTypes)
            : base(input, true, GetNamesAndSanitize(column))
        {
            Contracts.AssertNonEmpty(column);
            Contracts.Assert(column.Length == InfoCount);

            // In lieu of actual protections, I have the following silly asserts, so we can have some
            // warning if we decide to rename this argument, and so know to change the below hard-coded
            // standard column name.
            const string standardColumnArgName = "Column";

            Contracts.Assert(nameof(ValueToKeyMappingTransformer.Arguments.Column) == standardColumnArgName);
            Contracts.Assert(nameof(ColumnConcatenatingTransformer.Arguments.Column) == standardColumnArgName);

            Infos = new ColInfo[InfoCount];
            for (int iinfo = 0; iinfo < Infos.Length; iinfo++)
            {
                var current = column[iinfo];
                Contracts.AssertNonEmpty(current.Name);
                Contracts.AssertNonEmpty(current.Source);

                var sourceNames = current.Source;
                var sourceIndices = new int[sourceNames.Length];
                var sourceTypes = new ColumnType[sourceNames.Length];

                // Running total of the value counts. It turns null as soon as one source reports
                // a value count of zero, and nullable addition keeps it null from then on.
                int? totalSize = 0;
                for (int isrc = 0; isrc < sourceNames.Length; isrc++)
                {
                    string sourceName = sourceNames[isrc];
                    Contracts.CheckUserArg(!string.IsNullOrWhiteSpace(sourceName), nameof(ManyToOneColumn.Source));
#pragma warning disable MSML_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
                    if (!input.TryGetColumnIndex(sourceName, out sourceIndices[isrc]))
                    {
                        throw Contracts.ExceptUserArg(standardColumnArgName, "Source column '{0}' not found", sourceName);
                    }
#pragma warning restore MSML_ContractsNameUsesNameof
                    sourceTypes[isrc] = input.GetColumnType(sourceIndices[isrc]);
                    var valueCount = sourceTypes[isrc].ValueCount;
                    if (valueCount == 0)
                    {
                        totalSize = null;
                    }
                    else
                    {
                        totalSize = checked(totalSize + valueCount);
                    }
                }

                string reason = testTypes?.Invoke(sourceTypes);
                if (reason != null)
                {
#pragma warning disable MSML_ContractsNameUsesNameof // Unfortunately, there is no base class for the columns bindings.
                    throw Contracts.ExceptUserArg(standardColumnArgName, "Column '{0}' has invalid source types: {1}. Source types: '{2}'.",
                                                  current.Name, reason, string.Join(", ", sourceTypes.Select(sourceType => sourceType.ToString())));
#pragma warning restore MSML_ContractsNameUsesNameof
                }

                // A null total is stored as zero.
                Infos[iinfo] = new ColInfo(totalSize.GetValueOrDefault(), sourceIndices, sourceTypes);
            }
        }
示例#9
0
        /// <summary>
        /// Attempts to build a ColumnInfo for the column called <paramref name="name"/> in the
        /// given schema.
        /// </summary>
        /// <param name="schema">Schema to search; must not be null.</param>
        /// <param name="name">Column name to look up; must be non-empty.</param>
        /// <param name="colInfo">The created ColumnInfo when the name resolves; otherwise null.</param>
        /// <returns>True when the name maps to a column; false otherwise.</returns>
        public static bool TryCreateFromName(Schema schema, string name, out ColumnInfo colInfo)
        {
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.CheckNonEmpty(name, nameof(name));

            colInfo = null;

            bool found = schema.TryGetColumnIndex(name, out int columnIndex);
            if (found)
            {
                colInfo = new ColumnInfo(name, columnIndex, schema.GetColumnType(columnIndex));
            }

            return found;
        }
示例#10
0
        /// <summary>
        /// Validates the label and score columns of <paramref name="schema"/> and produces the
        /// output types and metadata for both.
        /// </summary>
        /// <param name="schema">Schema holding the label and score columns.</param>
        /// <param name="labelType">Vector type built from the label column's item type and size.</param>
        /// <param name="scoreType">Vector type built from the score column's item type and size.</param>
        /// <param name="labelMetadata">Metadata for the label output: slot names only.</param>
        /// <param name="scoreMetadata">
        /// Metadata for the score output: slot names, score column kind, score value kind, and
        /// score column set id.
        /// </param>
        private void CheckInputColumnTypes(Schema schema, out ColumnType labelType, out ColumnType scoreType,
                                           out Schema.Metadata labelMetadata, out Schema.Metadata scoreMetadata)
        {
            Host.AssertNonEmpty(ScoreCol);
            Host.AssertNonEmpty(LabelCol);

            // Label: must be a known-size vector of R4 or R8.
            var labelColumnType = schema.GetColumnType(LabelIndex);
            if (!labelColumnType.IsKnownSizeVector ||
                (labelColumnType.ItemType != NumberType.R4 && labelColumnType.ItemType != NumberType.R8))
            {
                throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", LabelCol, labelColumnType);
            }
            labelType = new VectorType(labelColumnType.ItemType.AsPrimitive, labelColumnType.VectorSize);

            var labelBuilder = new Schema.Metadata.Builder();
            labelBuilder.AddSlotNames(labelColumnType.VectorSize,
                                      CreateSlotNamesGetter(schema, LabelIndex, labelType.VectorSize, "True"));
            labelMetadata = labelBuilder.GetMetadata();

            // Score: must have a non-zero vector size and the Float item type.
            var scoreColumnType = schema.GetColumnType(ScoreIndex);
            if (scoreColumnType.VectorSize == 0 || scoreColumnType.ItemType != NumberType.Float)
            {
                throw Host.Except("Score column '{0}' has type '{1}' but must be a known length vector of type R4", ScoreCol, scoreColumnType);
            }
            scoreType = new VectorType(scoreColumnType.ItemType.AsPrimitive, scoreColumnType.VectorSize);

            var scoreBuilder = new Schema.Metadata.Builder();
            scoreBuilder.AddSlotNames(scoreColumnType.VectorSize,
                                      CreateSlotNamesGetter(schema, ScoreIndex, scoreType.VectorSize, "Predicted"));

            ValueGetter<ReadOnlyMemory<char>> columnKindGetter = GetScoreColumnKind;
            scoreBuilder.Add(new Schema.Column(MetadataUtils.Kinds.ScoreColumnKind, TextType.Instance, null), columnKindGetter);

            ValueGetter<ReadOnlyMemory<char>> valueKindGetter = GetScoreValueKind;
            scoreBuilder.Add(new Schema.Column(MetadataUtils.Kinds.ScoreValueKind, TextType.Instance, null), valueKindGetter);

            ValueGetter<uint> setIdGetter = GetScoreColumnSetId(schema);
            scoreBuilder.Add(new Schema.Column(MetadataUtils.Kinds.ScoreColumnSetId, MetadataUtils.ScoreColumnSetIdType, null), setIdGetter);

            scoreMetadata = scoreBuilder.GetMetadata();
        }
        /// <summary>
        /// Produces a mapper that renders values of type T from the given column into a
        /// StringBuilder, picking the representation heuristically. The rendered text becomes a
        /// component of the composed KeyValues for the hash outputs.
        /// </summary>
        /// <typeparam name="T">Raw type of the column's item type; asserted to match.</typeparam>
        /// <param name="schema">Schema holding the column; asserted non-null.</param>
        /// <param name="col">Column index; asserted to be within the schema's column count.</param>
        /// <returns>
        /// The standard string conversion for non-key types; for key types with key names, a
        /// mapper that appends the key's name; otherwise the key string conversion.
        /// </returns>
        public static ValueMapper<T, StringBuilder> GetSimpleMapper<T>(Schema schema, int col)
        {
            Contracts.AssertValue(schema);
            Contracts.Assert(0 <= col && col < schema.ColumnCount);

            var itemType = schema.GetColumnType(col).ItemType;
            Contracts.Assert(itemType.RawType == typeof(T));

            var conversions = Conversion.Conversions.Instance;

            // First: if not key, then get the standard string conversion.
            if (!itemType.IsKey)
            {
                return conversions.GetStringConversion<T>(itemType);
            }

            // Third choice: just use the key value itself, subject to offsetting by the min.
            if (!schema.HasKeyNames(col, itemType.KeyCount))
            {
                return conversions.GetKeyStringConversion<T>(itemType.AsKey);
            }

            // Second choice: the key has KeyValues metadata, so utilize it.
            // REVIEW: Non-textual KeyValues are certainly possible. Should we handle them?
            // Get the key names.
            VBuffer<ReadOnlyMemory<char>> keyNames = default;
            schema.GetMetadata(MetadataUtils.Kinds.KeyValues, col, ref keyNames);

            // Scratch buffer captured by the mapper below and reused on every call.
            ReadOnlyMemory<char> keyName = default;

            // REVIEW: We could optimize for identity, but it's probably not worthwhile.
            var toUInt32 = conversions.GetStandardConversion<T, uint>(itemType, NumberType.U4, out bool identity);

            ValueMapper<T, StringBuilder> mapper =
                (ref T src, ref StringBuilder dst) =>
                {
                    ClearDst(ref dst);

                    uint key = 0;
                    toUInt32(ref src, ref key);

                    // A converted value of zero appends nothing; non-zero values are one-based.
                    if (key != 0)
                    {
                        keyNames.GetItemOrDefault((int)(key - 1), ref keyName);
                        dst.AppendMemory(keyName);
                    }
                };
            return mapper;
        }
            /// <summary>
            /// Creates a single-column schema named <paramref name="columnName"/> of type
            /// <paramref name="columnType"/>, remembering the parent schema and feature column
            /// it was derived from.
            /// </summary>
            /// <param name="ectx">Exception context used for checks; may be null.</param>
            /// <param name="columnName">Name of the only column; must be non-empty.</param>
            /// <param name="columnType">Type of the only column.</param>
            /// <param name="parentSchema">Schema containing the feature column; must not be null.</param>
            /// <param name="featureCol">Index of the feature column in <paramref name="parentSchema"/>.</param>
            public FeatureContributionSchema(IExceptionContext ectx, string columnName, ColumnType columnType, Schema parentSchema, int featureCol)
            {
                Contracts.CheckValueOrNull(ectx);
                Contracts.CheckValue(parentSchema, nameof(parentSchema));
                _ectx = ectx;
                _ectx.CheckNonEmpty(columnName, nameof(columnName));

                _parentSchema = parentSchema;
                _featureCol = featureCol;

                // Cache the vector size first: the slot-names lookup needs it.
                _featureVectorSize = _parentSchema.GetColumnType(_featureCol).VectorSize;
                _hasSlotNames = _parentSchema.HasSlotNames(_featureCol, _featureVectorSize);

                _names = new[] { columnName };
                _types = new[] { columnType };

                var nameToIndex = new Dictionary<string, int>();
                nameToIndex.Add(columnName, 0);
                _columnNameMap = nameToIndex;
            }
示例#13
0
        /// <summary>
        /// Computes the output schema for <paramref name="inputSchema"/> by transforming an
        /// empty data view that carries that schema.
        /// </summary>
        /// <param name="inputSchema">Schema of the incoming data; must not be null.</param>
        /// <returns>The schema of the transformed empty data view.</returns>
        /// <remarks>
        /// When a feature column is configured, the input schema must contain it with exactly
        /// the expected feature column type; otherwise a schema-mismatch exception is thrown.
        /// </remarks>
        public override Schema GetOutputSchema(Schema inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));

            if (FeatureColumn != null)
            {
                if (!inputSchema.TryGetColumnIndex(FeatureColumn, out int featureIndex))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), null);
                }

                var actualType = inputSchema.GetColumnType(featureIndex);
                if (!actualType.Equals(FeatureColumnType))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn, FeatureColumnType.ToString(), actualType.ToString());
                }
            }

            var emptyInput = new EmptyDataView(Host, inputSchema);
            var transformed = Transform(emptyInput);
            return transformed.Schema;
        }
            /// <summary>
            /// Initializes the row and creates one getter per column selected by
            /// <paramref name="predicate"/>.
            /// </summary>
            /// <param name="env">Host environment; a host named "Row" is registered on it.</param>
            /// <param name="schema">Schema of the row; its column count must match the other inputs.</param>
            /// <param name="schemaDef">Schema definition supplying one column description per schema column.</param>
            /// <param name="peeks">One delegate per schema column, handed to getter creation.</param>
            /// <param name="predicate">Returns true for the column indices that need a getter.</param>
            public InputRowBase(IHostEnvironment env, Schema schema, InternalSchemaDefinition schemaDef, Delegate[] peeks, Func<int, bool> predicate)
            {
                Contracts.AssertValue(env);
                Host = env.Register("Row");
                Host.AssertValue(schema);
                Host.AssertValue(schemaDef);
                Host.AssertValue(peeks);
                Host.AssertValue(predicate);
                Host.Assert(schema.ColumnCount == schemaDef.Columns.Length);
                Host.Assert(schema.ColumnCount == peeks.Length);

                Schema = schema;
                _colCount = schema.ColumnCount;
                _getters = new Delegate[_colCount];

                for (int col = 0; col < _colCount; col++)
                {
                    // Columns the predicate rejects keep the array's default null entry.
                    if (!predicate(col))
                    {
                        continue;
                    }

                    _getters[col] = CreateGetter(schema.GetColumnType(col), schemaDef.Columns[col], peeks[col]);
                }
            }
示例#15
0
            /// <summary>
            /// Builds the one-to-one column bindings of <paramref name="parent"/> from
            /// user-supplied column arguments, resolving each source column against
            /// <paramref name="input"/>.
            /// </summary>
            /// <param name="parent">Owning transform; its host is used for checks and exceptions.</param>
            /// <param name="column">Column arguments; must contain at least one entry.</param>
            /// <param name="input">Schema the source column names are looked up in.</param>
            /// <param name="transInput">Optional transpose schema used to fetch slot types.</param>
            /// <param name="testType">
            /// Optional validator for a source column type; a non-null return value is treated as
            /// the reason the type is invalid.
            /// </param>
            /// <returns>The bindings built from the given columns.</returns>
            public static Bindings Create(OneToOneTransformBase parent, OneToOneColumn[] column, Schema input,
                                          ITransposeSchema transInput, Func<ColumnType, string> testType)
            {
                Contracts.AssertValue(parent);
                var host = parent.Host;

                host.CheckUserArg(Utils.Size(column) > 0, nameof(column));
                host.AssertValue(input);
                host.AssertValueOrNull(transInput);
                host.AssertValueOrNull(testType);

                var names = new string[column.Length];
                var infos = new ColInfo[column.Length];

                for (int iinfo = 0; iinfo < names.Length; iinfo++)
                {
                    var current = column[iinfo];
                    host.CheckUserArg(current.TrySanitize(), nameof(OneToOneColumn.Name), "Invalid new column name");
                    names[iinfo] = current.Name;

                    if (!input.TryGetColumnIndex(current.Source, out int sourceIndex))
                    {
                        throw host.ExceptUserArg(nameof(OneToOneColumn.Source), "Source column '{0}' not found", current.Source);
                    }

                    var sourceType = input.GetColumnType(sourceIndex);
                    string reason = testType?.Invoke(sourceType);
                    if (reason != null)
                    {
                        throw host.ExceptUserArg(nameof(OneToOneColumn.Source), InvalidTypeErrorFormat, current.Source, sourceType, reason);
                    }

                    var slotType = transInput?.GetSlotType(sourceIndex);
                    infos[iinfo] = new ColInfo(names[iinfo], sourceIndex, sourceType, slotType);
                }

                return new Bindings(parent, infos, input, true, names);
            }
示例#16
0
        /// <summary>
        /// Initializes a new instance of <see cref="SingleFeaturePredictionTransformerBase{TModel, TScorer}"/>.
        /// </summary>
        /// <param name="host">The local instance of <see cref="IHost"/>.</param>
        /// <param name="model">The model used for scoring.</param>
        /// <param name="trainSchema">The schema of the training data.</param>
        /// <param name="featureColumn">
        /// The feature column name. May be null, in which case no feature column type is recorded;
        /// when non-null it must exist in <paramref name="trainSchema"/>.
        /// </param>
        public SingleFeaturePredictionTransformerBase(IHost host, TModel model, Schema trainSchema, string featureColumn)
            : base(host, model, trainSchema)
        {
            FeatureColumn = featureColumn;

            if (featureColumn == null)
            {
                FeatureColumnType = null;
            }
            else if (!trainSchema.TryGetColumnIndex(featureColumn, out int col))
            {
                // A named feature column that is absent from the training schema is an error.
                throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
            }
            else
            {
                FeatureColumnType = trainSchema.GetColumnType(col);
            }

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
        }
            /// <summary>
            /// Determines whether two schemas have the same number of columns and equal column
            /// types at every index. Column names are not compared.
            /// </summary>
            /// <param name="schema1">First schema.</param>
            /// <param name="schema2">Second schema.</param>
            /// <returns>True when the column counts and all column types match; false otherwise.</returns>
            private bool SchemasMatch(Schema schema1, Schema schema2)
            {
                int columnCount = schema1.ColumnCount;
                if (columnCount != schema2.ColumnCount)
                {
                    return false;
                }

                for (int index = 0; index < columnCount; index++)
                {
                    if (!schema1.GetColumnType(index).Equals(schema2.GetColumnType(index)))
                    {
                        return false;
                    }
                }

                return true;
            }
示例#18
0
        /// <summary>
        /// Exports every column produced by this transform to the given ONNX context.
        /// </summary>
        /// <param name="ctx">ONNX context to write into; must not be null.</param>
        /// <remarks>
        /// A column whose source column is not present in the context is removed from the
        /// context and skipped. A column for which <c>SaveAsOnnxCore</c> returns false is
        /// removed as well.
        /// </remarks>
        public void SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(CanSaveOnnx);

            for (int iinfo = 0; iinfo < Infos.Length; ++iinfo)
            {
                ColInfo info = Infos[iinfo];
                string sourceName = Source.Schema.GetColumnName(info.Source);

                if (!ctx.ContainsColumn(sourceName))
                {
                    ctx.RemoveColumn(info.Name, false);
                    continue;
                }

                // Resolve the source variable before registering the intermediate one.
                var sourceVariable = ctx.GetVariableName(sourceName);
                var outputType = Schema.GetColumnType(_bindings.MapIinfoToCol(iinfo));
                var outputVariable = ctx.AddIntermediateVariable(outputType, info.Name);

                bool saved = SaveAsOnnxCore(ctx, iinfo, info, sourceVariable, outputVariable);
                if (!saved)
                {
                    ctx.RemoveColumn(info.Name, true);
                }
            }
        }
示例#19
0
        /// <summary>
        /// Writes a textual listing of the schema's columns to <paramref name="writer"/>:
        /// one line per column with its name and type, optionally followed by its metadata or
        /// its slot names, depending on <paramref name="args"/>.
        /// </summary>
        /// <param name="writer">Destination for the listing.</param>
        /// <param name="args">Options selecting whether metadata or slot names are shown.</param>
        /// <param name="schema">Schema to describe.</param>
        /// <param name="tschema">Optional transpose schema; columns with a slot type are marked " (T)".</param>
        /// <remarks>
        /// When compiled without CORECLR defined, a request for JSON output is answered with a
        /// not-supported message and nothing else is written.
        /// </remarks>
        private static void PrintSchema(TextWriter writer, Arguments args, Schema schema, ITransposeSchema tschema)
        {
            Contracts.AssertValue(writer);
            Contracts.AssertValue(args);
            Contracts.AssertValue(schema);
            Contracts.AssertValueOrNull(tschema);
#if !CORECLR
            if (args.ShowJson)
            {
                writer.WriteLine("Json Schema not supported.");
                return;
            }
#endif
            int columnCount = schema.ColumnCount;

            var indented = new IndentedTextWriter(writer, "  ");
            indented.WriteLine("{0} columns:", columnCount);
            using (indented.Nest())
            {
                // Buffer reused across columns when fetching slot names.
                var slotNames = default(VBuffer<ReadOnlyMemory<char>>);
                for (int col = 0; col < columnCount; col++)
                {
                    var columnName = schema.GetColumnName(col);
                    var columnType = schema.GetColumnType(col);
                    var slotType = tschema?.GetSlotType(col);
                    indented.WriteLine("{0}: {1}{2}", columnName, columnType, slotType == null ? "" : " (T)");

                    // Metadata output takes precedence over the slot-name listing.
                    bool showValues = args.ShowMetadataValues;
                    if (showValues || args.ShowMetadataTypes)
                    {
                        ShowMetadata(indented, schema, col, showValues);
                        continue;
                    }

                    if (!args.ShowSlots || !columnType.IsKnownSizeVector)
                    {
                        continue;
                    }

                    ColumnType slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, col);
                    if (slotNamesType == null)
                    {
                        continue;
                    }
                    if (slotNamesType.VectorSize != columnType.VectorSize || !slotNamesType.ItemType.IsText)
                    {
                        Contracts.Assert(false, "Unexpected slot names type");
                        continue;
                    }

                    schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref slotNames);
                    if (slotNames.Length != columnType.VectorSize)
                    {
                        Contracts.Assert(false, "Unexpected length of slot names vector");
                        continue;
                    }

                    using (indented.Nest())
                    {
                        // Verbose mode lists every slot; otherwise empty names are skipped.
                        bool verbose = args.Verbose ?? false;
                        foreach (var slot in slotNames.Items(all: verbose))
                        {
                            if (!verbose && slot.Value.IsEmpty)
                            {
                                continue;
                            }

                            indented.WriteLine("{0}:{1}", slot.Key, slot.Value);
                        }
                    }
                }
            }
        }
示例#20
0
 /// <summary>
 /// Returns the type of column <paramref name="col"/>, taken from the input schema column
 /// that <c>Sources</c> maps it to.
 /// </summary>
 /// <param name="col">Column index; must be non-negative and less than the column count.</param>
 public ColumnType GetColumnType(int col)
 {
     Contracts.CheckParam(0 <= col && col < ColumnCount, nameof(col));

     var sourceColumn = Sources[col];
     return _input.GetColumnType(sourceColumn);
 }