/// <summary>
                /// Constructor used for the initial wrapping of a row mapper. It validates the arguments,
                /// looks up the score column in the wrapped mapper's output schema, and publishes an
                /// <see cref="OutputSchema"/> built by <c>DecorateOutputSchema</c> for that column.
                /// </summary>
                /// <param name="env">Host environment; a host is registered on it under <c>LoaderSignature</c>.</param>
                /// <param name="mapper">The row mapper being wrapped. Its output schema must contain a score column.</param>
                /// <param name="type">Type of the label-name annotation value.</param>
                /// <param name="getter">Getter that produces the label names.</param>
                /// <param name="metadataKind">Non-empty annotation kind under which the label names are stored.</param>
                /// <param name="canWrap">Optional predicate stored for later use; may be null.</param>
                public Bound(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorDataViewType type, ValueGetter<VBuffer<T>> getter,
                             string metadataKind, Func<ISchemaBoundMapper, DataViewType, bool> canWrap)
                {
                    Contracts.CheckValue(env, nameof(env));
                    _host = env.Register(LoaderSignature);
                    _host.CheckValue(mapper, nameof(mapper));
                    _host.CheckValue(type, nameof(type));
                    _host.CheckValue(getter, nameof(getter));
                    _host.CheckNonEmpty(metadataKind, nameof(metadataKind));
                    _host.CheckValueOrNull(canWrap);

                    _mapper = mapper;

                    // The score column is the one whose annotations get decorated; it must exist.
                    var scoreName = AnnotationUtils.Const.ScoreValueKind.Score;
                    if (!mapper.OutputSchema.TryGetColumnIndex(scoreName, out int scoreColumn))
                    {
                        throw env.ExceptParam(nameof(mapper), "Mapper did not have a '{0}' column", scoreName);
                    }

                    _canWrap         = canWrap;
                    _metadataKind    = metadataKind;
                    _labelNameGetter = getter;
                    _labelNameType   = type;

                    OutputSchema = DecorateOutputSchema(mapper.OutputSchema, scoreColumn, type, getter, metadataKind);
                }
            /// <summary>
            /// Creates a <see cref="ColInfo"/> for a column made of the given segments.
            /// With no segments (<paramref name="segs"/> is null) the column has size 1 and keeps the item type.
            /// Otherwise the size is the sum of the segment contributions (<c>Lim - Min</c> for an unnamed
            /// segment, 1 for a named one), and the column becomes a vector when the size exceeds 1 or the
            /// first segment has <c>ForceVector</c> set.
            /// </summary>
            /// <param name="name">Non-empty column name, also used in the overlap error message.</param>
            /// <param name="itemType">Item type of the column.</param>
            /// <param name="segs">Segments of the column, or null.</param>
            /// <param name="user">Selects the exception kind raised on overlap: user-argument when true, decode otherwise.</param>
            public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segment[] segs, bool user)
            {
                Contracts.AssertNonEmpty(name);
                Contracts.AssertValue(itemType);
                Contracts.AssertValueOrNull(segs);

                DataViewType type = itemType;

                // No segments at all: a single slot of the item type.
                if (segs == null)
                {
                    return new ColInfo(name, type, segs, 1);
                }

                // The disjointness check only runs when the first segment is unnamed.
                if (segs.Length != 0 && segs[0].Name is null)
                {
                    var byMin = Utils.GetIdentityPermutation(segs.Length);
                    Array.Sort(byMin, (left, right) => segs[left].Min.CompareTo(segs[right].Min));

                    // After sorting by Min, overlap can only occur between neighbours.
                    for (int k = 0; k + 1 < byMin.Length; k++)
                    {
                        int lower = byMin[k];
                        int upper = byMin[k + 1];
                        Contracts.Assert(segs[lower].Min <= segs[upper].Min);
                        if (segs[lower].Lim > segs[upper].Min)
                        {
                            if (user)
                            {
                                throw Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name);
                            }
                            throw Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name);
                        }
                    }
                }

                // Note: since we know that the segments don't overlap, we're guaranteed that
                // the sum of their sizes doesn't overflow.
                int size = 0;
                foreach (var seg in segs)
                {
                    if (seg.Name is null)
                    {
                        size += seg.Lim - seg.Min;
                    }
                    else
                    {
                        size += 1;
                    }
                }
                Contracts.Assert(size >= segs.Length);

                // NOTE(review): segs[0] is indexed here even if segs is empty (size would then be 0);
                // presumably callers never pass an empty array - confirm.
                if (size > 1 || segs[0].ForceVector)
                {
                    type = new VectorDataViewType(itemType, size);
                }

                return new ColInfo(name, type, segs, size);
            }
            /// <summary>
            /// Records the name, source column index, source type and (optional) source slot type of a column.
            /// </summary>
            /// <param name="name">Non-empty column name.</param>
            /// <param name="colSrc">Non-negative index of the source column.</param>
            /// <param name="typeSrc">Type of the source column.</param>
            /// <param name="slotTypeSrc">Slot type of the source column, or null. When given, its item type
            /// is asserted to equal the item type of <paramref name="typeSrc"/>.</param>
            public ColInfo(string name, int colSrc, DataViewType typeSrc, VectorDataViewType slotTypeSrc)
            {
                Contracts.AssertNonEmpty(name);
                Contracts.Assert(colSrc >= 0);
                Contracts.AssertValue(typeSrc);
                Contracts.AssertValueOrNull(slotTypeSrc);
                if (slotTypeSrc != null)
                {
                    // The slot type, when present, must agree with the source item type.
                    Contracts.Assert(typeSrc.GetItemType().Equals(slotTypeSrc.ItemType));
                }

                SlotTypeSrc = slotTypeSrc;
                TypeSrc     = typeSrc;
                Source      = colSrc;
                Name        = name;
            }
        /// <summary>
        /// Produces a getter for <c>VBuffer&lt;TDst&gt;</c> on top of a source getter for <c>VBuffer&lt;TSrc&gt;</c>,
        /// converting every explicitly stored value with the standard conversion from
        /// <paramref name="typeSrc"/>'s item type to <paramref name="typeDst"/>.
        /// When the conversion is the identity, the source getter itself is returned (re-cast).
        /// </summary>
        /// <param name="typeSrc">Source vector type; its item raw type must be <typeparamref name="TSrc"/>.</param>
        /// <param name="typeDst">Destination item type; its raw type must be <typeparamref name="TDst"/>.</param>
        /// <param name="getterFact">Factory that supplies the source getter.</param>
        private static ValueGetter<VBuffer<TDst>> GetVecGetterAsCore<TSrc, TDst>(VectorDataViewType typeSrc, PrimitiveDataViewType typeDst, GetterFactory getterFact)
        {
            Contracts.Assert(typeof(TSrc) == typeSrc.ItemType.RawType);
            Contracts.Assert(typeof(TDst) == typeDst.RawType);
            Contracts.AssertValue(getterFact);

            var srcGetter = getterFact.GetGetter<VBuffer<TSrc>>();
            var convert   = Conversions.Instance.GetStandardConversion<TSrc, TDst>(typeSrc.ItemType, typeDst, out bool isIdentity);

            if (isIdentity)
            {
                // Same raw type on both sides: no per-value work needed, hand back the source getter.
                Contracts.Assert(typeof(TSrc) == typeof(TDst));
                return (ValueGetter<VBuffer<TDst>>)(Delegate)srcGetter;
            }

            int expectedLength = typeSrc.Size;
            // Scratch buffer captured by the returned delegate and reused across calls.
            var buffer = default(VBuffer<TSrc>);

            ValueGetter<VBuffer<TDst>> converting = (ref VBuffer<TDst> dst) =>
            {
                srcGetter(ref buffer);
                // A positive declared size must match the actual length of what was read.
                if (expectedLength > 0)
                {
                    Contracts.Check(buffer.Length == expectedLength);
                }

                var values      = buffer.GetValues();
                int storedCount = values.Length;
                var editor      = VBufferEditor.Create(ref dst, buffer.Length, storedCount);
                if (storedCount > 0)
                {
                    // REVIEW: This would be faster if there were loops for each std conversion.
                    // Consider adding those to the Conversions class.
                    for (int slot = 0; slot < storedCount; slot++)
                    {
                        convert(in values[slot], ref editor.Values[slot]);
                    }

                    // Sparse input: carry the indices over unchanged.
                    if (!buffer.IsDense)
                    {
                        buffer.GetIndices().CopyTo(editor.Indices);
                    }
                }
                dst = editor.Commit();
            };

            return converting;
        }
            /// <summary>
            /// Core constructor: registers a host under <c>LoaderSignature</c>, asserts the arguments
            /// and stores them in the corresponding fields.
            /// </summary>
            /// <param name="env">Host environment.</param>
            /// <param name="bindable">The bindable mapper being wrapped.</param>
            /// <param name="type">Type of the label-name value.</param>
            /// <param name="getter">Getter delegate for the label names.</param>
            /// <param name="metadataKind">Non-empty annotation kind.</param>
            /// <param name="canWrap">Optional predicate; may be null.</param>
            private LabelNameBindableMapper(IHostEnvironment env, ISchemaBindableMapper bindable, VectorDataViewType type, Delegate getter,
                                            string metadataKind, Func<ISchemaBoundMapper, DataViewType, bool> canWrap)
            {
                Contracts.AssertValue(env);
                _host = env.Register(LoaderSignature);

                _host.AssertValue(bindable);
                _bindable = bindable;

                _host.AssertValue(type);
                _type = type;

                _host.AssertValue(getter);
                _getter = getter;

                _host.AssertNonEmpty(metadataKind);
                _metadataKind = metadataKind;

                _host.AssertValueOrNull(canWrap);
                _canWrap = canWrap;
            }
            /// <summary>
            /// Binds the parent's generic mapper to <paramref name="schema"/> and builds the output schema:
            /// the generic mapper's output columns zipped with a single feature-contributions column.
            /// </summary>
            /// <param name="env">Host environment, also used as the exception context.</param>
            /// <param name="parent">The bindable mapper that owns this row mapper.</param>
            /// <param name="schema">Role-mapped input schema; it must have a feature column.</param>
            /// <exception cref="System.Exception">Raised through <c>_env.Except</c> when the bound generic mapper
            /// is not an <c>ISchemaBoundRowMapper</c>.</exception>
            public RowMapper(IHostEnvironment env, BindableMapper parent, RoleMappedSchema schema)
            {
                Contracts.AssertValue(env);
                _env = env;
                _env.AssertValue(schema);
                _env.AssertValue(parent);
                _env.Assert(schema.Feature.HasValue);
                _parent = parent;
                InputRoleMappedSchema = schema;
                var genericMapper = parent.GenericMapper.Bind(_env, schema);

                // The bound mapper is dereferenced unconditionally below, so fail here with a clear
                // message instead of a NullReferenceException if the cast does not succeed.
                _genericRowMapper = genericMapper as ISchemaBoundRowMapper;
                if (_genericRowMapper == null)
                {
                    throw _env.Except("The bound generic mapper must be an ISchemaBoundRowMapper");
                }

                var featureSize = FeatureColumn.Type.GetVectorSize();

                if (parent.Stringify)
                {
                    // Stringified output: one text column; slot names are kept in a field.
                    var builder = new DataViewSchema.Builder();
                    builder.AddColumn(DefaultColumnNames.FeatureContributions, TextDataViewType.Instance, null);
                    _outputSchema = builder.ToSchema();
                    if (FeatureColumn.HasSlotNames(featureSize))
                    {
                        FeatureColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref _slotNames);
                    }
                    else
                    {
                        _slotNames = VBufferUtils.CreateEmpty<ReadOnlyMemory<char>>(featureSize);
                    }
                }
                else
                {
                    // Numeric output: a Single vector with the feature column's dimensions, which
                    // forwards the feature column's slot names when it has them.
                    var metadataBuilder = new DataViewSchema.Annotations.Builder();
                    if (InputSchema[FeatureColumn.Index].HasSlotNames(featureSize))
                    {
                        metadataBuilder.AddSlotNames(featureSize, (ref VBuffer<ReadOnlyMemory<char>> value) =>
                                                     FeatureColumn.Annotations.GetValue(AnnotationUtils.Kinds.SlotNames, ref value));
                    }

                    var schemaBuilder           = new DataViewSchema.Builder();
                    var featureContributionType = new VectorDataViewType(NumberDataViewType.Single, ((VectorDataViewType)FeatureColumn.Type).Dimensions);
                    schemaBuilder.AddColumn(DefaultColumnNames.FeatureContributions, featureContributionType, metadataBuilder.ToAnnotations());
                    _outputSchema = schemaBuilder.ToSchema();
                }

                _outputGenericSchema = _genericRowMapper.OutputSchema;
                OutputSchema         = new ZipBinding(new DataViewSchema[] { _outputGenericSchema, _outputSchema, }).OutputSchema;
            }
            /// <summary>
            /// Deserializing constructor. Reads, in this order: the wrapped bindable mapper (from the
            /// <c>_innerDir</c> sub-model), the label-name type and value from the context's reader stream,
            /// and - for model versions at or above <c>VersionAddedMetadataKind</c> - the annotation kind.
            /// Older models fall back to the slot-names kind.
            /// </summary>
            /// <param name="host">Host used for checks and passed on to the loaders.</param>
            /// <param name="ctx">Model load context to read from.</param>
            private LabelNameBindableMapper(IHost host, ModelLoadContext ctx)
            {
                Contracts.AssertValue(host);
                _host = host;
                _host.AssertValue(ctx);

                ctx.LoadModel<ISchemaBindableMapper, SignatureLoadModel>(_host, out _bindable, _innerDir);

                var saver = new BinarySaver(_host, new BinarySaver.Arguments());
                bool loaded = saver.TryLoadTypeAndValue(ctx.Reader.BaseStream, out DataViewType loadedType, out object loadedValue);
                _host.CheckDecode(loaded);

                // The stored value must be a non-null vector.
                _type = loadedType as VectorDataViewType;
                _host.CheckDecode(_type != null);
                _host.CheckDecode(loadedValue != null);
                _getter = Utils.MarshalInvoke(DecodeInit<int>, _type.ItemType.RawType, loadedValue);

                if (ctx.Header.ModelVerReadable >= VersionAddedMetadataKind)
                {
                    _metadataKind = ctx.LoadNonEmptyString();
                }
                else
                {
                    _metadataKind = AnnotationUtils.Kinds.SlotNames;
                }
            }
        /// <summary>
        /// Binds this mapper to <paramref name="schema"/>. When the schema has a feature column, its type
        /// is first checked against the expected input type (the value mapper's input type, or a
        /// vector of Single when there is no value mapper): item types must be equal, both must agree on
        /// being vector or not, and a known expected size must match. Binding itself is delegated to
        /// <c>BindCore</c>.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="schema">Role-mapped schema to bind to; must not be null.</param>
        ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
        {
            Contracts.CheckValue(env, nameof(env));

            using (var ch = env.Register("SchemaBindableWrapper").Start("Bind"))
            {
                ch.CheckValue(schema, nameof(schema));

                if (schema.Feature?.Type is DataViewType featureType)
                {
                    // Ensure that the feature column type is compatible with the needed input type.
                    DataViewType expectedType;
                    if (ValueMapper != null)
                    {
                        expectedType = ValueMapper.InputType;
                    }
                    else
                    {
                        expectedType = new VectorDataViewType(NumberDataViewType.Single);
                    }

                    if (featureType != expectedType)
                    {
                        var featureVector  = featureType as VectorDataViewType;
                        var expectedVector = expectedType as VectorDataViewType;

                        DataViewType featureItem  = featureVector?.ItemType ?? featureType;
                        DataViewType expectedItem = expectedVector?.ItemType ?? expectedType;

                        if (!featureItem.Equals(expectedItem))
                        {
                            throw ch.Except("Incompatible features column type item type: '{0}' vs '{1}'", featureItem, expectedItem);
                        }

                        bool featureIsVector  = featureVector != null;
                        bool expectedIsVector = expectedVector != null;
                        if (featureIsVector != expectedIsVector)
                        {
                            throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", featureType, expectedType);
                        }

                        // The expected type can legally have unknown size, in which case any size is accepted.
                        int featureSize  = featureVector?.Size ?? 0;
                        int expectedSize = expectedVector?.Size ?? 0;
                        if (expectedSize > 0 && featureSize != expectedSize)
                        {
                            throw ch.Except("Incompatible features column type: '{0}' vs '{1}'", featureType, expectedType);
                        }
                    }
                }

                return BindCore(ch, schema);
            }
        }
                /// <summary>
                /// Rebuilds <paramref name="partialSchema"/> column by column, in the same order, attaching
                /// the label names to the score column as an annotation of kind <paramref name="labelNameKind"/>.
                /// Any existing annotation of that kind on the score column is replaced; all other
                /// annotations on all columns are copied over.
                /// </summary>
                /// <param name="partialSchema">Schema produced by the mapper that computes the score column.</param>
                /// <param name="scoreColumnIndex">Index of the score column within <paramref name="partialSchema"/>.</param>
                /// <param name="labelNameType">Type of the label-name annotation.</param>
                /// <param name="labelNameGetter">Getter for the label-name annotation value.</param>
                /// <param name="labelNameKind">Annotation kind under which the label names are stored.</param>
                private DataViewSchema DecorateOutputSchema(DataViewSchema partialSchema, int scoreColumnIndex, VectorDataViewType labelNameType,
                                                            ValueGetter<VBuffer<T>> labelNameGetter, string labelNameKind)
                {
                    var schemaBuilder = new DataViewSchema.Builder();

                    // Columns are added one after another so the result keeps the order of the input schema.
                    for (int col = 0; col < partialSchema.Count; ++col)
                    {
                        var column      = partialSchema[col];
                        var annotations = new DataViewSchema.Annotations.Builder();

                        if (col != scoreColumnIndex)
                        {
                            // Only the score column is affected; everything else is copied verbatim.
                            annotations.Add(column.Annotations, selector: kind => true);
                        }
                        else
                        {
                            // Score column: keep everything except a previous entry of the same kind,
                            // then add the label names.
                            annotations.Add(column.Annotations, selector: kind => kind != labelNameKind);
                            annotations.Add(labelNameKind, labelNameType, labelNameGetter);
                        }

                        // Annotations are read-only, so a fresh column is created rather than
                        // appending to the existing one.
                        schemaBuilder.AddColumn(column.Name, column.Type, annotations.ToAnnotations());
                    }

                    return schemaBuilder.ToSchema();
                }
            /// <summary>
            /// Creates a <c>Bound&lt;T&gt;</c> around <paramref name="mapper"/>, casting the untyped
            /// <paramref name="getter"/> to <c>ValueGetter&lt;VBuffer&lt;T&gt;&gt;</c>.
            /// </summary>
            /// <param name="env">Host environment.</param>
            /// <param name="mapper">Row mapper to wrap.</param>
            /// <param name="type">Type of the label-name value.</param>
            /// <param name="getter">Getter delegate; must be a <c>ValueGetter&lt;VBuffer&lt;T&gt;&gt;</c>.</param>
            /// <param name="metadataKind">Non-empty annotation kind.</param>
            /// <param name="canWrap">Optional predicate; may be null.</param>
            internal static ISchemaBoundMapper CreateBound<T>(IHostEnvironment env, ISchemaBoundRowMapper mapper, VectorDataViewType type, Delegate getter,
                                                              string metadataKind, Func<ISchemaBoundMapper, DataViewType, bool> canWrap)
            {
                Contracts.AssertValue(env);
                env.AssertValue(mapper);
                env.AssertValue(type);
                env.AssertValue(getter);
                env.Assert(getter is ValueGetter<VBuffer<T>>);
                env.AssertNonEmpty(metadataKind);
                env.AssertValueOrNull(canWrap);

                var typedGetter = (ValueGetter<VBuffer<T>>)getter;
                return new Bound<T>(env, mapper, type, typedGetter, metadataKind, canWrap);
            }
 /// <summary>
 /// Convenience constructor taking a bound mapper: it forwards <c>mapper.Bindable</c> together with
 /// the remaining arguments, unchanged, to the constructor that accepts an <c>ISchemaBindableMapper</c>.
 /// </summary>
 /// <remarks>
 /// <paramref name="mapper"/> is dereferenced in the constructor initializer, before any validation runs.
 /// </remarks>
 private LabelNameBindableMapper(IHostEnvironment env, ISchemaBoundMapper mapper, VectorDataViewType type, Delegate getter,
                                 string metadataKind, Func <ISchemaBoundMapper, DataViewType, bool> canWrap)
     : this(env, mapper.Bindable, type, getter, metadataKind, canWrap)
 {
 }
            /// <summary>
            /// Stores the owner, schema and output column names, and builds the output schema. Up to three
            /// Single-vector columns are added, in this order and each only when its name is non-null:
            /// per-tree output values, a leaf indicator vector, and a path (internal node) indicator vector.
            /// </summary>
            /// <param name="ectx">Exception context used for assertions.</param>
            /// <param name="owner">Bindable mapper that owns the ensemble and the slot-name getters.</param>
            /// <param name="schema">Role-mapped input schema; it must have a feature column.</param>
            /// <param name="treesColumnName">Name of the tree-values column, or null to omit it.</param>
            /// <param name="leavesColumnName">Name of the leaf-indicator column, or null to omit it.</param>
            /// <param name="pathsColumnName">Name of the path-indicator column, or null to omit it.</param>
            public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper owner, RoleMappedSchema schema,
                               string treesColumnName, string leavesColumnName, string pathsColumnName)
            {
                Contracts.AssertValue(ectx);
                ectx.AssertValue(owner);
                ectx.AssertValue(schema);
                ectx.Assert(schema.Feature.HasValue);

                _ectx                 = ectx;
                _owner                = owner;
                InputRoleMappedSchema = schema;
                _treesColumnName      = treesColumnName;
                _leavesColumnName     = leavesColumnName;
                _pathsColumnName      = pathsColumnName;

                int treeCount = owner._ensemble.TrainedEnsemble.NumTrees;
                int leafCount = owner._totalLeafCount;

                // One slot per tree: the output of each tree on a given example.
                var treeValueType = new VectorDataViewType(NumberDataViewType.Single, treeCount);
                // One slot per leaf in the whole ensemble: marks the leaf the example lands in for every tree.
                var leafIdType = new VectorDataViewType(NumberDataViewType.Single, leafCount);
                // One slot per internal node in the whole ensemble: marks the nodes on the example's paths.
                // In a binary tree #internal + #leaf = 2 * #internal + 1 (every node but the root is a child of
                // an internal node), hence #internal = #leaf - 1 per tree, and #leaf - #trees over the ensemble.
                var pathIdType = new VectorDataViewType(NumberDataViewType.Single, leafCount - treeCount);

                var outputBuilder = new DataViewSchema.Builder();

                if (treesColumnName != null)
                {
                    // Tree values carry slot names only.
                    var treeAnnotations = new DataViewSchema.Annotations.Builder();
                    treeAnnotations.Add(AnnotationUtils.Kinds.SlotNames, AnnotationUtils.GetNamesType(treeValueType.Size),
                                        (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetTreeSlotNames);

                    outputBuilder.AddColumn(treesColumnName, treeValueType, treeAnnotations.ToAnnotations());
                }

                if (leavesColumnName != null)
                {
                    // Leaf indicators carry slot names and are flagged as normalized.
                    var leafAnnotations = new DataViewSchema.Annotations.Builder();
                    leafAnnotations.Add(AnnotationUtils.Kinds.SlotNames, AnnotationUtils.GetNamesType(leafIdType.Size),
                                        (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetLeafSlotNames);
                    leafAnnotations.Add(AnnotationUtils.Kinds.IsNormalized, BooleanDataViewType.Instance, (ref bool value) => { value = true; });

                    outputBuilder.AddColumn(leavesColumnName, leafIdType, leafAnnotations.ToAnnotations());
                }

                if (pathsColumnName != null)
                {
                    // Path indicators carry slot names and are flagged as normalized.
                    var pathAnnotations = new DataViewSchema.Annotations.Builder();
                    pathAnnotations.Add(AnnotationUtils.Kinds.SlotNames, AnnotationUtils.GetNamesType(pathIdType.Size),
                                        (ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetPathSlotNames);
                    pathAnnotations.Add(AnnotationUtils.Kinds.IsNormalized, BooleanDataViewType.Instance, (ref bool value) => { value = true; });

                    outputBuilder.AddColumn(pathsColumnName, pathIdType, pathAnnotations.ToAnnotations());
                }

                OutputSchema = outputBuilder.ToSchema();
            }
// Example #13
// 0
        /// <summary>
        /// Builds an <see cref="InternalSchemaDefinition"/> for <paramref name="userType"/> from a user
        /// schema definition, resolving each column either to a field/property of the type or to its
        /// generator, and inferring or validating the column type.
        /// </summary>
        /// <param name="userType">The user-defined row type.</param>
        /// <param name="userSchemaDefinition">Schema definition to use; when null, one is created from
        /// <paramref name="userType"/>.</param>
        /// <returns>The internal schema definition, with one slot per entry of the user definition.</returns>
        /// <remarks>
        /// A member of type <c>IChannel</c> gets no column: the loop skips it, so the corresponding slot
        /// in the resulting column array is left null.
        /// </remarks>
        public static InternalSchemaDefinition Create(Type userType, SchemaDefinition userSchemaDefinition)
        {
            Contracts.AssertValue(userType);
            Contracts.AssertValueOrNull(userSchemaDefinition);

            if (userSchemaDefinition == null)
            {
                userSchemaDefinition = SchemaDefinition.Create(userType);
            }

            var result = new Column[userSchemaDefinition.Count];

            for (int idx = 0; idx < userSchemaDefinition.Count; ++idx)
            {
                var userCol = userSchemaDefinition[idx];
                if (userCol.MemberName == null)
                {
                    throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Null field name detected in schema definition");
                }

                bool       hasGenerator = userCol.Generator != null;
                MemberInfo member       = null;
                bool       isVector;
                Type       rawItemType;

                if (hasGenerator)
                {
                    // Computed column: the type comes from the generator's declared return type.
                    var returnType = userCol.ReturnType;
                    if (returnType == null)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No return parameter found in computed column.");
                    }
                    GetVectorAndItemType("returnType", returnType, null, out isVector, out rawItemType);
                }
                else
                {
                    // Member-backed column: look for a field first, then a property.
                    member = userType.GetField(userCol.MemberName);
                    if (member == null)
                    {
                        member = userType.GetProperty(userCol.MemberName);
                    }
                    if (member == null)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field or property with name '{0}' found in type '{1}'",
                                                    userCol.MemberName,
                                                    userType.FullName);
                    }

                    // A member that may be used to expose the cursor channel does not need a column.
                    bool isChannelField    = member is FieldInfo field && field.FieldType == typeof(IChannel);
                    bool isChannelProperty = member is PropertyInfo property && property.PropertyType == typeof(IChannel);
                    if (isChannelField || isChannelProperty)
                    {
                        continue;
                    }

                    GetVectorAndItemType(member, out isVector, out rawItemType);
                }

                // The column name falls back to the member name when none was given.
                // REVIEW: Because order is defined, we allow duplicate column names, since producing an IDataView
                // with duplicate column names is completely legal. Possible objection is that we should make it less
                // convenient to produce "hidden" columns, since this may not be of practical use to users.
                string colName = string.IsNullOrEmpty(userCol.ColumnName) ? userCol.MemberName : userCol.ColumnName;

                DataViewType colType;
                if (userCol.ColumnType != null)
                {
                    // A type was declared: it must agree with the member on vector-ness and raw item type.
                    var  declaredVector   = userCol.ColumnType as VectorDataViewType;
                    bool declaredIsVector = declaredVector != null;
                    if (isVector != declaredIsVector)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to be {1}, but type of associated field '{2}' is {3}",
                                                    colName, declaredIsVector ? "vector" : "scalar", userCol.MemberName, isVector ? "vector" : "scalar");
                    }

                    DataViewType declaredItem = declaredVector?.ItemType ?? userCol.ColumnType;
                    if (declaredItem.RawType != rawItemType)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to have item type {1}, but associated field has type {2}",
                                                    colName, declaredItem.RawType, rawItemType);
                    }
                    colType = userCol.ColumnType;
                }
                else
                {
                    // No declared type: infer one from the raw item type.
                    PrimitiveDataViewType inferredItem = ColumnTypeExtensions.PrimitiveTypeFromType(rawItemType);
                    if (isVector)
                    {
                        colType = new VectorDataViewType(inferredItem);
                    }
                    else
                    {
                        colType = inferredItem;
                    }
                }

                if (hasGenerator)
                {
                    result[idx] = new Column(colName, colType, userCol.Generator, userCol.AnnotationInfos);
                }
                else
                {
                    result[idx] = new Column(colName, colType, member, userCol.AnnotationInfos);
                }
            }

            return new InternalSchemaDefinition(result);
        }
// Example #14
// 0
            //private Delegate CreateGetter(SchemaProxy schema, int index, Delegate peek)
            private Delegate CreateGetter(DataViewType colType, InternalSchemaDefinition.Column column, Delegate peek)
            {
                var outputType  = column.OutputType;
                var genericType = outputType;
                FuncInstanceMethodInfo1 <InputRowBase <TRow>, Delegate, Delegate> del;

                if (outputType.IsArray)
                {
                    VectorDataViewType vectorType = colType as VectorDataViewType;
                    Host.Assert(vectorType != null);

                    // String[] -> ReadOnlyMemory<char>
                    if (outputType.GetElementType() == typeof(string))
                    {
                        Host.Assert(vectorType.ItemType is TextDataViewType);
                        return(CreateConvertingArrayGetterDelegate <string, ReadOnlyMemory <char> >(peek, x => x != null ? x.AsMemory() : ReadOnlyMemory <char> .Empty));
                    }

                    // T[] -> VBuffer<T>
                    if (outputType.GetElementType().IsGenericType&& outputType.GetElementType().GetGenericTypeDefinition() == typeof(Nullable <>))
                    {
                        Host.Assert(Nullable.GetUnderlyingType(outputType.GetElementType()) == vectorType.ItemType.RawType);
                    }
                    else
                    {
                        Host.Assert(outputType.GetElementType() == vectorType.ItemType.RawType);
                    }
                    del         = _createDirectArrayGetterDelegateMethodInfo;
                    genericType = outputType.GetElementType();
                }
                else if (colType is VectorDataViewType vectorType)
                {
                    // VBuffer<T> -> VBuffer<T>
                    // REVIEW: Do we care about accommodating VBuffer<string> -> ReadOnlyMemory<char>?
                    Host.Assert(outputType.IsGenericType);
                    Host.Assert(outputType.GetGenericTypeDefinition() == typeof(VBuffer <>));
                    Host.Assert(outputType.GetGenericArguments()[0] == vectorType.ItemType.RawType);
                    del         = _createDirectVBufferGetterDelegateMethodInfo;
                    genericType = vectorType.ItemType.RawType;
                }
                else if (colType is PrimitiveDataViewType)
                {
                    if (outputType == typeof(string))
                    {
                        // String -> ReadOnlyMemory<char>
                        Host.Assert(colType is TextDataViewType);
                        return(CreateConvertingGetterDelegate <String, ReadOnlyMemory <char> >(peek, x => x != null ? x.AsMemory() : ReadOnlyMemory <char> .Empty));
                    }

                    // T -> T
                    if (outputType.IsGenericType && outputType.GetGenericTypeDefinition() == typeof(Nullable <>))
                    {
                        Host.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType));
                    }
                    else
                    {
                        Host.Assert(colType.RawType == outputType);
                    }

                    if (!(colType is KeyDataViewType keyType))
                    {
                        del = _createDirectGetterDelegateMethodInfo;
                    }
                    else
                    {
                        var keyRawType = colType.RawType;
                        return(Utils.MarshalInvoke(_createKeyGetterDelegateMethodInfo, this, keyRawType, peek, colType));
                    }
                }