Exemplo n.º 1
0
 /// <summary>
 /// Compares two schemas column by column: same column count, same names and compatible types.
 /// Two vector types that are not the same instance are still considered compatible when their
 /// dimension count, key count, raw kind, item type and known-size flag all agree.
 /// </summary>
 /// <param name="sch1">First schema.</param>
 /// <param name="sch2">Second schema.</param>
 /// <param name="raise">If true, the first mismatch throws an exception describing it;
 /// otherwise the first mismatch makes the method return false.</param>
 /// <returns>True when no mismatch was found.</returns>
 public static bool CompareSchema(ISchema sch1, ISchema sch2, bool raise = false)
 {
     int count1 = sch1.ColumnCount;
     int count2 = sch2.ColumnCount;
     if (count1 != count2)
     {
         if (!raise)
         {
             return false;
         }
         throw Contracts.Except("Different number of columns {0} != {1}\nS1: {2}\nS2: {3}",
                                count1, count2, ToString(sch1), ToString(sch2));
     }

     for (int col = 0; col < count1; ++col)
     {
         string name1 = sch1.GetColumnName(col);
         string name2 = sch2.GetColumnName(col);
         if (name1 != name2)
         {
             if (!raise)
             {
                 return false;
             }
             throw Contracts.Except("Column name {0} is different {1} != {2}\nS1: {3}\nS2: {4}",
                                    col, name1, name2, ToString(sch1), ToString(sch2));
         }

         var type1 = sch1.GetColumnType(col);
         var type2 = sch2.GetColumnType(col);
         if (type1 == type2)
         {
             continue;
         }

         // Distinct type objects may still describe the same vector type; compare their parts.
         bool differ = true;
         if (type1.IsVector() && type2.IsVector())
         {
             var vec1 = type1.AsVector();
             var vec2 = type2.AsVector();
             // The non-short-circuiting '|' evaluates every comparison.
             differ = (vec1.DimCount() != vec2.DimCount() || vec1.GetKeyCount() != vec2.GetKeyCount())
                    | (vec1.RawKind() != vec2.RawKind())
                    | (vec1.ItemType() != vec2.ItemType())
                    | (vec1.IsKnownSizeVector() != vec2.IsKnownSizeVector());
         }
         if (!differ)
         {
             continue;
         }
         if (!raise)
         {
             return false;
         }
         throw Contracts.Except("Column type {0} is different {1} != {2}\nS1: {3}\nS2: {4}",
                                col, type1, type2, ToString(sch1), ToString(sch2));
     }
     return true;
 }
Exemplo n.º 2
0
        /// <summary>
        /// Builds a row mapper over the Feature-role columns of <paramref name="schema"/>.
        /// </summary>
        /// <param name="env">Host environment.</param>
        /// <param name="schema">Role-mapped input schema; its Feature-role columns become the mapper inputs.</param>
        /// <param name="outputSchema">Output schema; must have exactly two numeric columns.</param>
        /// <param name="pred">The field-aware factorization machine predictor.</param>
        public FieldAwareFactorizationMachineScalarRowMapper(IHostEnvironment env, RoleMappedSchema schema,
                                                             ISchema outputSchema, FieldAwareFactorizationMachinePredictor pred)
        {
            Contracts.AssertValue(env);
            Contracts.AssertValue(schema);
            Contracts.CheckParam(outputSchema.ColumnCount == 2, nameof(outputSchema));
            Contracts.CheckParam(outputSchema.GetColumnType(0).IsNumber, nameof(outputSchema));
            Contracts.CheckParam(outputSchema.GetColumnType(1).IsNumber, nameof(outputSchema));
            Contracts.AssertValue(pred);

            _env = env;
            _pred = pred;
            _columns = schema.GetColumns(RoleMappedSchema.ColumnRole.Feature).ToArray();

            // Re-declare every feature column under the Feature role for the mapper's input schema.
            var featureRole = RoleMappedSchema.ColumnRole.Feature;
            var inputFeatureColumns = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
            foreach (var column in _columns)
            {
                inputFeatureColumns.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(featureRole, column.Name));
            }

            InputSchema = RoleMappedSchema.Create(schema.Schema, inputFeatureColumns);
            OutputSchema = outputSchema;

            // Collect the schema indices of the feature columns that can be resolved by name.
            _inputColumnIndexes = new List<int>();
            foreach (var pair in inputFeatureColumns)
            {
                if (!schema.Schema.TryGetColumnIndex(pair.Value, out int index))
                {
                    continue;
                }
                _inputColumnIndexes.Add(index);
            }
        }
        /// <summary>
        /// Validates the label and score column types, then builds their output types and metadata.
        /// </summary>
        /// <param name="schema">Input schema containing the label and score columns.</param>
        /// <param name="labelType">Receives a vector type with the label's item type (R4 or R8) and size.</param>
        /// <param name="scoreType">Receives a vector type with the score's item type (R4) and size.</param>
        /// <param name="labelMetadata">Receives slot-name metadata for the label column.</param>
        /// <param name="scoreMetadata">Receives slot-name, score-kind, score-value-kind and score-column-set-id
        /// metadata for the score column.</param>
        /// <remarks>Throws (via <c>Host.Except</c>) when the label is not a known-size vector of R4/R8,
        /// or when the score is not a non-empty vector of R4.</remarks>
        private void CheckInputColumnTypes(ISchema schema, out ColumnType labelType, out ColumnType scoreType,
                                           out ColumnMetadataInfo labelMetadata, out ColumnMetadataInfo scoreMetadata)
        {
            Host.AssertNonEmpty(ScoreCol);
            Host.AssertNonEmpty(LabelCol);

            var t = schema.GetColumnType(LabelIndex);

            if (!t.IsKnownSizeVector || (t.ItemType != NumberType.R4 && t.ItemType != NumberType.R8))
            {
                throw Host.Except("Label column '{0}' has type '{1}' but must be a known-size vector of R4 or R8", LabelCol, t);
            }
            labelType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize);
            var labelSlotNamesType = new VectorType(TextType.Instance, t.VectorSize);

            labelMetadata = new ColumnMetadataInfo(LabelCol);
            labelMetadata.Add(MetadataUtils.Kinds.SlotNames, new MetadataInfo <VBuffer <DvText> >(labelSlotNamesType,
                                                                                                  CreateSlotNamesGetter(schema, LabelIndex, labelType.VectorSize, "True")));

            t = schema.GetColumnType(ScoreIndex);
            if (t.VectorSize == 0 || t.ItemType != NumberType.Float)
            {
                throw Host.Except("Score column '{0}' has type '{1}' but must be a known length vector of type R4", ScoreCol, t);
            }
            scoreType = new VectorType(t.ItemType.AsPrimitive, t.VectorSize);

            // The score's slot-name type must be sized by the score vector (the size its getter produces).
            // Score and label lengths are not checked to agree above, so the label's type cannot be reused.
            var scoreSlotNamesType = new VectorType(TextType.Instance, t.VectorSize);

            scoreMetadata = new ColumnMetadataInfo(ScoreCol);
            scoreMetadata.Add(MetadataUtils.Kinds.SlotNames, new MetadataInfo <VBuffer <DvText> >(scoreSlotNamesType,
                                                                                                  CreateSlotNamesGetter(schema, ScoreIndex, scoreType.VectorSize, "Predicted")));
            scoreMetadata.Add(MetadataUtils.Kinds.ScoreColumnKind, new MetadataInfo <DvText>(TextType.Instance, GetScoreColumnKind));
            scoreMetadata.Add(MetadataUtils.Kinds.ScoreValueKind, new MetadataInfo <DvText>(TextType.Instance, GetScoreValueKind));
            scoreMetadata.Add(MetadataUtils.Kinds.ScoreColumnSetId,
                              new MetadataInfo <uint>(MetadataUtils.ScoreColumnSetIdType, GetScoreColumnSetId(schema)));
        }
Exemplo n.º 4
0
 /// <summary>
 /// Row mapper for a multiclass XGBoost predictor producing a single numeric vector output column.
 /// </summary>
 /// <param name="schema">Role-mapped input schema, forwarded to the base mapper.</param>
 /// <param name="parent">The predictor being applied.</param>
 /// <param name="env">Host environment.</param>
 /// <param name="outputSchema">Output schema; expected to hold one vector column of numbers.</param>
 /// <param name="classMapping">Optional class mapping; when non-null it is validated (debug assert)
 /// with <c>Utils.IsIncreasing(0, classMapping, int.MaxValue)</c>.</param>
 /// <param name="numberOfClasses">Number of classes.</param>
 public XGBoostMulticlassRowMapper(RoleMappedSchema schema, XGBoostMulticlassPredictor parent, IHostEnvironment env,
                                   ISchema outputSchema, int[] classMapping, int numberOfClasses)
     : base(schema, parent, env, outputSchema)
 {
     env.Assert(outputSchema.ColumnCount == 1, "outputSchema");
     env.Assert(outputSchema.GetColumnType(0).IsVector, "outputSchema");
     env.Assert(outputSchema.GetColumnType(0).ItemType.IsNumber, "outputSchema");
     // Validate the argument: the field _classMapping has not been assigned yet at this point.
     env.Assert(classMapping == null || Utils.IsIncreasing(0, classMapping, int.MaxValue), "classMapping");
     _classMapping    = classMapping;
     _numberOfClasses = numberOfClasses;
 }
        /// <summary>
        /// Formats a schema as a single string made of <c>name:type:index</c> entries, one per column.
        /// </summary>
        /// <param name="schema">The schema to display.</param>
        /// <param name="sep">Separator inserted between column entries.</param>
        /// <param name="vectorVec">If true, every column type is shown with its own <c>ToString</c>
        /// (for instance <c>Vec&lt;R4, 2&gt;</c>). If false, a vector column is shown as its item kind
        /// followed by its slot range (for instance <c>:R4:5-6</c>), and a contiguous key column as
        /// <c>Kind[min-max]</c>. Multi-dimensional vectors and other key types are rejected.</param>
        /// <param name="keepHidden">If true, hidden columns are included in the output.</param>
        /// <returns>The schema as a string.</returns>
        public static string ToString(ISchema schema, string sep = "; ", bool vectorVec = true, bool keepHidden = false)
        {
            var builder = new StringBuilder();
            // Extra slots taken by the vectors already written as slot ranges. This is added to the
            // column index so that every entry reports the position of its first slot.
            int lag = 0;

            for (int i = 0; i < schema.ColumnCount; ++i)
            {
                if (!keepHidden && schema.IsHidden(i))
                {
                    continue;
                }
                if (builder.Length > 0)
                {
                    builder.Append(sep);
                }

                string name = schema.GetColumnName(i);
                var t = schema.GetColumnType(i);
                string type;
                string si;
                if (vectorVec || (!t.IsVector && !t.IsKey))
                {
                    type = t.ToString().Replace(" ", "");
                    si = (i + lag).ToString();
                }
                else if (t.IsVector)
                {
                    if (t.AsVector.DimCount != 1)
                    {
                        throw Contracts.ExceptNotSupp("Only vector with one dimension are supported.");
                    }
                    int size = t.AsVector.GetDim(0);
                    type = t.ItemType.RawKind.ToString();
                    si = string.Format("{0}-{1}", i + lag, i + lag + size - 1);
                    lag += size - 1;
                }
                else if (t.IsKey && t.AsKey.Contiguous)
                {
                    var k = t.AsKey;
                    type = k.Count > 0
                                ? string.Format("{0}[{1}-{2}]", k.RawKind, k.Min, k.Min + (ulong)k.Count - 1)
                                : string.Format("{0}[{1}-{2}]", k.RawKind, k.Min, "*");
                    // Keys are offset by the same vector lag as every other entry; the bare index
                    // would point at the wrong slot for a key that follows an expanded vector.
                    si = (i + lag).ToString();
                }
                else
                {
                    throw Contracts.ExceptNotImpl(string.Format("Unable to process type '{0}'.", t));
                }

                builder.Append(string.Format("{0}:{1}:{2}", name, type, si));
            }
            return builder.ToString();
        }
            /// <summary>
            /// Checks that the types of the matrix column-index and row-index columns of
            /// <paramref name="schema"/> equal the index types the predictor was trained with.
            /// </summary>
            /// <param name="schema">Input schema.</param>
            /// <param name="matrixColumnIndexCol">Index of the matrix-column-index column.</param>
            /// <param name="matrixRowIndexCol">Index of the matrix-row-index column.</param>
            private void CheckInputSchema(ISchema schema, int matrixColumnIndexCol, int matrixRowIndexCol)
            {
                // See if matrix-column-index role's type matches the one expected in the trained predictor
                var    type = schema.GetColumnType(matrixColumnIndexCol);
                string msg  = string.Format("Input column index type '{0}' incompatible with predictor's column index type '{1}'", type, _parent.MatrixColumnIndexType);

                _env.CheckParam(type.Equals(_parent.MatrixColumnIndexType), nameof(schema), msg);

                // See if matrix-row-index role's type matches the one expected in the trained predictor
                type = schema.GetColumnType(matrixRowIndexCol);
                msg  = string.Format("Input row index type '{0}' incompatible with predictor's row index type '{1}'", type, _parent.MatrixRowIndexType);
                _env.CheckParam(type.Equals(_parent.MatrixRowIndexType), nameof(schema), msg);
            }
Exemplo n.º 7
0
        /// <summary>
        /// Builds a <see cref="SchemaShape"/> describing every non-hidden column of a fully defined
        /// schema: name, vector kind, item raw kind, whether the item type is a key, and the names
        /// of its metadata kinds.
        /// </summary>
        /// <param name="schema">The schema to describe. Must not be null.</param>
        /// <returns>The shape of the visible columns, in schema order.</returns>
        public static SchemaShape Create(ISchema schema)
        {
            Contracts.CheckValue(schema, nameof(schema));
            var shapeColumns = new List<Column>();

            for (int col = 0; col < schema.ColumnCount; col++)
            {
                if (schema.IsHidden(col))
                {
                    continue;
                }

                var colType = schema.GetColumnType(col);
                var vectorKind = colType.IsKnownSizeVector
                    ? Column.VectorKind.Vector
                    : colType.IsVector
                        ? Column.VectorKind.VariableVector
                        : Column.VectorKind.Scalar;
                var itemType = colType.ItemType;
                var metadataKinds = schema.GetMetadataTypes(col)
                                          .Select(pair => pair.Key)
                                          .ToArray();

                shapeColumns.Add(new Column(schema.GetColumnName(col), vectorKind, itemType.RawKind, itemType.IsKey, metadataKinds));
            }
            return new SchemaShape(shapeColumns.ToArray());
        }
Exemplo n.º 8
0
            /// <summary>
            /// Computes the output type of every column pair and builds its key-to-value map from the
            /// KeyValues metadata of the source column.
            /// </summary>
            /// <param name="schema">Input schema holding the source key columns.</param>
            /// <param name="types">Receives, per column pair, the KeyValues item type, wrapped in a vector
            /// type shaped like the source when the source column is a vector.</param>
            /// <param name="kvMaps">Receives, per column pair, the map returned by GetKeyMetadata.</param>
            private void ComputeKvMaps(ISchema schema, out ColumnType[] types, out KeyToValueMap[] kvMaps)
            {
                types = new ColumnType[_parent.ColumnPairs.Length];
                kvMaps = new KeyToValueMap[_parent.ColumnPairs.Length];

                for (int i = 0; i < types.Length; i++)
                {
                    Contracts.Assert(types[i] == null);

                    var srcCol = ColMapNewToOld[i];
                    var srcType = schema.GetColumnType(srcCol);
                    var valuesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, srcCol);
                    Host.Check(valuesType != null, "Metadata KeyValues does not exist");
                    Host.Check(valuesType.VectorSize == srcType.ItemType.KeyCount, "KeyValues metadata size does not match column type key count");

                    types[i] = srcType.IsVector
                        ? new VectorType(valuesType.ItemType.AsPrimitive, srcType.AsVector)
                        : valuesType.ItemType;

                    // Close GetKeyMetadata<TKey, TValue> over the key raw type and the value raw type, then invoke it.
                    Func<int, ColumnType, ColumnType, KeyToValueMap> openTemplate = GetKeyMetadata<int, int>;
                    var closedMethod = openTemplate.GetMethodInfo()
                                                   .GetGenericMethodDefinition()
                                                   .MakeGenericMethod(srcType.ItemType.RawType, types[i].ItemType.RawType);
                    kvMaps[i] = (KeyToValueMap)closedMethod.Invoke(this, new object[] { i, srcType, valuesType });
                }
            }
Exemplo n.º 9
0
        /// <summary>
        /// Creates the transformer around a trained model and its training schema.
        /// </summary>
        /// <param name="host">Host used for checks and exceptions. Must not be null.</param>
        /// <param name="model">The trained model.</param>
        /// <param name="trainSchema">Schema of the training data. Must not be null.</param>
        /// <param name="featureColumn">Feature column name, or null. When non-null it must exist in
        /// <paramref name="trainSchema"/>; otherwise a schema-mismatch exception is thrown.</param>
        public PredictionTransformerBase(IHost host, TModel model, ISchema trainSchema, string featureColumn)
        {
            Contracts.CheckValue(host, nameof(host));
            Contracts.CheckValueOrNull(featureColumn);
            Host = host;
            Host.CheckValue(trainSchema, nameof(trainSchema));

            Model = model;
            FeatureColumn = featureColumn;

            ColumnType featureType = null;
            if (featureColumn != null)
            {
                if (!trainSchema.TryGetColumnIndex(featureColumn, out int featureCol))
                {
                    throw Host.ExceptSchemaMismatch(nameof(featureColumn), RoleMappedSchema.ColumnRole.Feature.Value, featureColumn);
                }
                featureType = trainSchema.GetColumnType(featureCol);
            }
            FeatureColumnType = featureType;

            TrainSchema = trainSchema;
            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
        }
            /// <summary>
            /// Looks up a column by name in <paramref name="input"/>, checks its type with
            /// <paramref name="isValidType"/>, and returns the column index and type.
            /// Throws if the column is missing or its type is rejected.
            /// </summary>
            /// <param name="input">The input schema.</param>
            /// <param name="name">The column name. The exact name "text" is looked up as "Text".</param>
            /// <param name="isValidType">Predicate deciding whether the column type is acceptable.</param>
            /// <param name="expectedType">Description of the expected type, used in the error message.</param>
            /// <param name="index">Receives the column index.</param>
            /// <param name="type">Receives the column type.</param>
            /// <param name="exceptUser">If true, failures throw a user-argument exception for the Column
            /// argument; otherwise they throw a decode exception.</param>
            private static void Bind(ISchema input, string name, Predicate<ColumnType> isValidType, string expectedType,
                out int index, out ColumnType type, bool exceptUser = true)
            {
                Contracts.AssertValue(input);
                Contracts.AssertValue(name);

                // NOTE(review): hard-coded alias of "text" to "Text"; the reason is not visible here.
                string lookupName = name == "text" ? "Text" : name;

                if (!input.TryGetColumnIndex(lookupName, out index))
                {
                    if (exceptUser)
                    {
                        throw Contracts.ExceptUserArg(nameof(Arguments.Column), "Source column '{0}' not found", lookupName);
                    }
                    throw Contracts.ExceptDecode("Source column '{0}' not found", lookupName);
                }

                type = input.GetColumnType(index);
                if (isValidType(type))
                {
                    return;
                }
                if (exceptUser)
                {
                    throw Contracts.ExceptUserArg(nameof(Arguments.Column), "Source column '{0}' has type '{1}' but must be '{2}'", lookupName, type, expectedType);
                }
                throw Contracts.ExceptDecode("Source column '{0}' has type '{1}' but must be '{2}'", lookupName, type, expectedType);
            }
            /// <summary>
            /// Resolves each name in <paramref name="names"/> to a column index in <paramref name="schema"/>.
            /// Every column must exist and have a primitive type.
            /// </summary>
            /// <param name="schema">The schema to search.</param>
            /// <param name="names">Column names to resolve.</param>
            /// <param name="except">Factory producing the exception to throw; it receives the error message.</param>
            /// <returns>The column indices, in the order of <paramref name="names"/>.</returns>
            private int[] GetColumnIds(ISchema schema, string[] names, Func <string, Exception> except)
            {
                Contracts.AssertValue(schema);
                Contracts.AssertValue(names);

                var ids = new int[names.Length];
                int position = 0;
                foreach (string columnName in names)
                {
                    if (!schema.TryGetColumnIndex(columnName, out int col))
                    {
                        throw except(string.Format("Could not find column '{0}'", columnName));
                    }

                    var colType = schema.GetColumnType(col);
                    if (!colType.IsPrimitive)
                    {
                        throw except(string.Format("Column '{0}' has type '{1}', but must have a primitive type", columnName, colType));
                    }

                    ids[position++] = col;
                }
                return ids;
            }
Exemplo n.º 12
0
        /// <summary>
        /// Restores a prediction transformer from a model load context. It reads the model, the training
        /// schema (stored as an empty binary data view) and the feature column name, in that order.
        /// </summary>
        /// <param name="host">Host used for loading and for exceptions.</param>
        /// <param name="ctx">Load context to read from.</param>
        internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
        {
            Host = host;

            ctx.LoadModel <TModel, SignatureLoadModel>(host, out TModel model, DirModel);
            Model = model;

            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // id of string: feature column.

            // Clone the stream with the schema into memory.
            // NOTE(review): the outcome of TryLoadBinaryStream is not checked. If the schema stream is absent,
            // ms presumably stays empty and the failure would only surface in BinaryLoader; confirm this is intended.
            var ms = new MemoryStream();

            ctx.TryLoadBinaryStream(DirTransSchema, reader =>
            {
                reader.BaseStream.CopyTo(ms);
            });

            // Rewind so the loader reads the copied bytes from the beginning.
            ms.Position = 0;
            var loader = new BinaryLoader(host, new BinaryLoader.Arguments(), ms);

            TrainSchema = loader.Schema;

            // The feature column name must resolve in the restored training schema.
            // NOTE(review): a null feature column name is not special-cased here; verify the save side always writes one.
            FeatureColumn = ctx.LoadString();
            if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
            {
                throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
            }
            FeatureColumnType = TrainSchema.GetColumnType(col);

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
        }
Exemplo n.º 13
0
        /// <summary>
        /// Decides whether <paramref name="mapper"/> can and should be wrapped by
        /// <see cref="LabelNameBindableMapper"/>. This never throws for an unsuitable mapper. When it
        /// returns false, the caller is expected to fall back to the original, unwrapped bound mapper.
        /// </summary>
        /// <param name="mapper">The bound mapper considered for wrapping.</param>
        /// <param name="labelNameType">Type of the label names, taken either from the key-value metadata
        /// of the training label column or deserialized from the model of a bindable mapper.</param>
        /// <returns>True when <see cref="LabelNameBindableMapper.CreateBound{T}"/> can be called with this
        /// mapper and is expected to succeed.</returns>
        public static bool CanWrap(ISchemaBoundMapper mapper, ColumnType labelNameType)
        {
            Contracts.AssertValue(mapper);
            Contracts.AssertValue(labelNameType);

            // Only row mappers are handled. Other mappers could be supported, but there is no practical need.
            if (!(mapper is ISchemaBoundRowMapper))
            {
                return false;
            }

            ISchema outSchema = mapper.Schema;

            // Without a score column there is nothing to attach the label names to.
            if (!outSchema.TryGetColumnIndex(MetadataUtils.Const.ScoreValueKind.Score, out int scoreCol))
            {
                return false;
            }

            // A score column that already carries slot names needs no wrapping.
            if (outSchema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, scoreCol) != null)
            {
                return false;
            }

            // The label names must be a vector whose size matches the score column.
            var scoreType = outSchema.GetColumnType(scoreCol);
            return labelNameType.IsVector && labelNameType.VectorSize == scoreType.VectorSize;
        }
Exemplo n.º 14
0
        /// <summary>
        /// Builds a ColumnInfo from the column at <paramref name="index"/>, using that column's name
        /// and type. The name may resolve to a different column when looked up, so use this
        /// sparingly and with care.
        /// </summary>
        /// <param name="schema">The schema. Must not be null.</param>
        /// <param name="index">Column index; must be in [0, ColumnCount).</param>
        /// <returns>The column info for that index.</returns>
        public static ColumnInfo CreateFromIndex(ISchema schema, int index)
        {
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.CheckParam(0 <= index && index < schema.ColumnCount, nameof(index));

            string columnName = schema.GetColumnName(index);
            var columnType = schema.GetColumnType(index);
            return new ColumnInfo(columnName, index, columnType);
        }
Exemplo n.º 15
0
            /// <summary>
            /// Resolves each pivot column name against <paramref name="inputSchema"/> and builds its
            /// PivotColumnInfo. Every name must be non-empty, and every column must exist and be a
            /// vector of a primitive item type.
            /// </summary>
            /// <param name="ectx">Exception context; may be null.</param>
            /// <param name="inputSchema">Schema to bind against.</param>
            /// <param name="pivotColumns">Names of the pivot columns; must be non-empty.</param>
            /// <param name="infos">Receives one entry per pivot column, in the same order.</param>
            private static void CheckAndBind(IExceptionContext ectx, ISchema inputSchema,
                                             string[] pivotColumns, out PivotColumnInfo[] infos)
            {
                Contracts.AssertValueOrNull(ectx);
                ectx.AssertValue(inputSchema);
                ectx.AssertNonEmpty(pivotColumns);

                infos = new PivotColumnInfo[pivotColumns.Length];
                int slot = 0;
                foreach (string name in pivotColumns)
                {
                    // REVIEW: replace Check with CheckUser, once existing CheckUser is renamed to CheckUserArg or something.
                    ectx.CheckUserArg(!string.IsNullOrEmpty(name), nameof(Arguments.Column), "Column name cannot be empty");
                    if (!inputSchema.TryGetColumnIndex(name, out int col))
                    {
                        throw ectx.ExceptUserArg(nameof(Arguments.Column), "Pivot column '{0}' is not found", name);
                    }

                    var colType = inputSchema.GetColumnType(col);
                    bool isPrimitiveVector = colType.IsVector && colType.ItemType.IsPrimitive;
                    if (!isPrimitiveVector)
                    {
                        throw ectx.ExceptUserArg(nameof(Arguments.Column),
                                                 "Pivot column '{0}' has type '{1}', but must be a vector of primitive types", name, colType);
                    }
                    infos[slot++] = new PivotColumnInfo(name, col, colType.VectorSize, colType.ItemType.AsPrimitive);
                }
            }
            /// <summary>
            /// Binds the anomaly-detection mapper to <paramref name="inputSchema"/>. It locates the parent's
            /// input column, requires it to be R4, and prepares the four output slot names.
            /// </summary>
            /// <param name="env">Host environment. Must not be null.</param>
            /// <param name="parent">The owning transform. Must not be null.</param>
            /// <param name="inputSchema">Schema of the input data. Must not be null.</param>
            public Mapper(IHostEnvironment env, SequentialAnomalyDetectionTransformBase <TInput, TState> parent, ISchema inputSchema)
            {
                Contracts.CheckValue(env, nameof(env));
                _host = env.Register(nameof(Mapper));
                _host.CheckValue(inputSchema, nameof(inputSchema));
                _host.CheckValue(parent, nameof(parent));

                string inputName = parent.InputColumnName;
                if (!inputSchema.TryGetColumnIndex(inputName, out _inputColumnIndex))
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
                }

                var inputType = inputSchema.GetColumnType(_inputColumnIndex);
                if (inputType != NumberType.R4)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, NumberType.R4.ToString(), inputType.ToString());
                }

                _parent = parent;
                _parentSchema = inputSchema;

                var slotNames = new[]
                {
                    "Alert".AsMemory(),
                    "Raw Score".AsMemory(),
                    "P-Value Score".AsMemory(),
                    "Martingale Score".AsMemory(),
                };
                _slotNames = new VBuffer<ReadOnlyMemory<char>>(slotNames.Length, slotNames);

                State = _parent.StateRef;
            }
Exemplo n.º 17
0
        /// <summary>
        /// Reads the categorical slot ranges of a column from its CategoricalSlotRanges metadata.
        /// The metadata is a vector with an even number of elements, read as consecutive pairs of
        /// inclusive [start, end] slot indices. For example, the values 0,2,3,4,8,9 are read as
        /// [0,2], [3,4], [8,9]: slots 0, 1 and 2 form one categorical feature, slots 3 and 4 form
        /// another, and slots 8 and 9 a third. Slots 5, 6 and 7 do not appear, so they are not categorical.
        /// </summary>
        /// <param name="schema">The schema containing the column. Must not be null.</param>
        /// <param name="colIndex">Index of the column; must be non-negative.</param>
        /// <param name="categoricalFeatures">Receives the flattened ranges when valid; null otherwise.</param>
        /// <returns>
        /// True when all of the following hold:
        /// the column is a known-size vector;
        /// its metadata is a non-empty, even-length VBuffer of int with at most twice as many elements as slots;
        /// every pair has start &lt;= end;
        /// every start is greater than the previous end;
        /// every index is below the column's slot count.
        /// </returns>
        public static bool TryGetCategoricalFeatureIndices(ISchema schema, int colIndex, out int[] categoricalFeatures)
        {
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.Check(colIndex >= 0, nameof(colIndex));

            bool isValid = false;

            categoricalFeatures = null;
            if (!schema.GetColumnType(colIndex).IsKnownSizeVector)
            {
                return isValid;
            }

            var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex);

            if (type?.RawType == typeof(VBuffer <int>))
            {
                VBuffer <int> catIndices = default(VBuffer <int>);
                schema.GetMetadata(MetadataUtils.Kinds.CategoricalSlotRanges, colIndex, ref catIndices);
                VBufferUtils.Densify(ref catIndices);

                // Once dense, only the first catIndices.Length entries of Values are meaningful; the
                // backing array may be longer, so Values.Length must not be used as the logical length.
                int count = catIndices.Length;
                int columnSlotsCount = schema.GetColumnType(colIndex).AsVector.VectorSizeCore;
                if (count > 0 && count % 2 == 0 && count <= columnSlotsCount * 2)
                {
                    int previousEndIndex = -1;
                    isValid = true;
                    for (int i = 0; i < count; i += 2)
                    {
                        int start = catIndices.Values[i];
                        int end = catIndices.Values[i + 1];
                        if (start > end ||
                            start <= previousEndIndex ||
                            start >= columnSlotsCount ||
                            end >= columnSlotsCount)
                        {
                            isValid = false;
                            break;
                        }

                        previousEndIndex = end;
                    }
                    if (isValid)
                    {
                        categoricalFeatures = catIndices.Values.Take(count).ToArray();
                    }
                }
            }

            return isValid;
        }
Exemplo n.º 18
0
        /// <summary>
        /// Wraps column <paramref name="col"/> of <paramref name="schema"/> in a SchemaWrap of
        /// <typeparamref name="T"/>. <typeparamref name="T"/> must be the column's raw type (debug-asserted).
        /// </summary>
        /// <param name="schema">The schema holding the column.</param>
        /// <param name="col">Index of the column.</param>
        /// <returns>The wrapping column.</returns>
        private static IColumn GetColumnCore <T>(ISchema schema, int col)
        {
            Contracts.AssertValue(schema);
            Contracts.Assert(0 <= col && col < schema.ColumnCount);
            Contracts.Assert(schema.GetColumnType(col).RawType == typeof(T));

            IColumn wrapped = new SchemaWrap<T>(schema, col);
            return wrapped;
        }
        /// <summary>
        /// Returns <c>T.PfaTypeOrNullForColumnType</c> applied to the type of column
        /// <paramref name="col"/>. Judging by its name, the result is presumably null when the type
        /// has no PFA equivalent.
        /// </summary>
        /// <param name="schema">The schema holding the column.</param>
        /// <param name="col">Index of the column; must be in range.</param>
        private JToken PfaTypeOrNullForColumn(ISchema schema, int col)
        {
            _host.AssertValue(schema);
            _host.Assert(0 <= col && col < schema.ColumnCount);

            return T.PfaTypeOrNullForColumnType(schema.GetColumnType(col));
        }
Exemplo n.º 20
0
        /// <summary>
        /// Validates that the label column is R4 and the score column is a non-vector column with
        /// item type Float. Throws via <c>Host.Except</c> otherwise.
        /// </summary>
        /// <param name="schema">Schema containing the label and score columns.</param>
        private void CheckInputColumnTypes(ISchema schema)
        {
            Host.AssertNonEmpty(ScoreCol);
            Host.AssertNonEmpty(LabelCol);

            var labelType = schema.GetColumnType(LabelIndex);
            if (labelType != NumberType.R4)
            {
                throw Host.Except("Label column '{0}' has type '{1}' but must be R4", LabelCol, labelType);
            }

            var scoreType = schema.GetColumnType(ScoreIndex);
            if (scoreType.IsVector || scoreType.ItemType != NumberType.Float)
            {
                throw Host.Except("Score column '{0}' has type '{1}' but must be R4", ScoreCol, scoreType);
            }
        }
Exemplo n.º 21
0
        /// <summary>
        /// Exposes one column of a schema as an inactive column object.
        /// </summary>
        /// <param name="schema">The schema to read the column from. Must not be null.</param>
        /// <param name="col">Index of the column; must be in [0, ColumnCount).</param>
        /// <returns>A column with <see cref="IColumn.IsActive"/> false.</returns>
        public static IColumn GetColumn(ISchema schema, int col)
        {
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.CheckParam(0 <= col && col < schema.ColumnCount, nameof(col));

            // Invoke GetColumnCore<T> with T bound at runtime to the column's raw type.
            var rawType = schema.GetColumnType(col).RawType;
            Func<ISchema, int, IColumn> typedFactory = GetColumnCore<int>;
            return Utils.MarshalInvoke(typedFactory, rawType, schema, col);
        }
Exemplo n.º 22
0
            /// <summary>
            /// Binds the ONNX transform to <paramref name="inputSchema"/>. It creates the tensor adapter
            /// for the model's first input, derives the output column type from the model's first output,
            /// and validates the input column's existence, length and item type.
            /// </summary>
            /// <param name="env">Host environment. Must not be null.</param>
            /// <param name="parent">The transform being applied. Must not be null.</param>
            /// <param name="inputSchema">Schema of the data being transformed. Must not be null.</param>
            public Mapper(IHostEnvironment env, OnnxTransform parent, ISchema inputSchema)
            {
                Contracts.CheckValue(env, nameof(env));
                _host = env.Register(nameof(Mapper));
                _host.CheckValue(inputSchema, nameof(inputSchema));
                _host.CheckValue(parent, nameof(parent));

                _parent = parent;
                var model = _parent.Model;

                _idvToTensorAdapter = new IdvToTensorAdapter(inputSchema, parent._args.InputColumn,
                                                             model.ModelInfo.InputsInfo[0]);

                // TODO: Remove assumption below
                // Assume first output dimension is 1; it is skipped when building the output shape.
                var outputNodeInfo = model.ModelInfo.OutputsInfo[0];
                var inputNodeInfo  = model.ModelInfo.InputsInfo[0];

                int[] dims           = outputNodeInfo.Shape.Skip(1).Select(x => (int)x).ToArray();
                var   outputItemType = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
                var   inputShape     = inputNodeInfo.Shape;

                _outputColType     = new VectorType(outputItemType, dims);
                _outputColName     = _parent.Output;
                _outputItemRawType = outputItemType.RawType;

                int inColIndex;

                if (!inputSchema.TryGetColumnIndex(_parent.Input, out inColIndex))
                {
                    throw _host.Except($"Column {_parent.Input} doesn't exist");
                }

                var type = inputSchema.GetColumnType(inColIndex);

                if (type.IsVector && type.VectorSize == 0)
                {
                    throw _host.Except($"Variable length input columns not supported");
                }

                // NOTE(review): the input item type is compared with the model's *output* item type, which
                // assumes input and output share an item type; confirm against the input node's type.
                if (type.ItemType != outputItemType)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Input, outputItemType.ToString(), type.ToString());
                }

                // If the column is one dimension we make sure that the total size of the TF shape matches.
                // Compute the total size of the known (positive) dimensions of the shape. The product is
                // seeded with 1 so it stays defined when no dimension is known; Aggregate without a seed
                // throws on an empty sequence.
                int valCount = inputShape.Select(x => (int)x).Where(x => x > 0).Aggregate(1, (x, y) => x * y);

                // The column length should be divisible by this, so that the other dimensions can be integral.
                if (type.ValueCount % valCount != 0)
                {
                    throw _host.Except($"Input shape mismatch: Input '{_parent.Input}' has shape {String.Join(",", inputShape)}, but input data is of length {type.ValueCount}.");
                }

                _host.Assert(_outputItemRawType == _outputColType.ItemType.RawType);
            }
Exemplo n.º 23
0
            /// <summary>
            /// Wraps a single column of <paramref name="schema"/>. The caller is expected to pass a valid
            /// in-range column whose raw type is <typeparamref name="T"/>; this is only checked via asserts.
            /// </summary>
            public SchemaWrap(ISchema schema, int col)
            {
                // Assert-level sanity checks (not argument validation): schema present, index in range,
                // and the column's raw representation matches the generic parameter.
                Contracts.AssertValue(schema);
                Contracts.Assert(col >= 0 && schema.ColumnCount > col);
                Contracts.Assert(typeof(T) == schema.GetColumnType(col).RawType);

                _col    = col;
                _schema = schema;
            }
 /// <summary>
 /// Returns the type of output column <paramref name="col"/>. After the index is validated with
 /// CheckColumnInRange, the first <c>_groupCount</c> columns are forwarded to the input schema
 /// through <c>GroupIds</c>. All later columns are looked up in <c>_columnTypes</c>.
 /// </summary>
 public ColumnType GetColumnType(int col)
 {
     CheckColumnInRange(col);

     int ownIndex = col - _groupCount;
     if (ownIndex >= 0)
     {
         return(_columnTypes[ownIndex]);
     }

     // Group columns pass straight through to the corresponding input column.
     return(_input.GetColumnType(GroupIds[col]));
 }
            /// <summary>
            /// Deserializes the added-column descriptions from <paramref name="ctx"/> and resolves each
            /// listed source column against <paramref name="input"/>.
            /// </summary>
            /// <param name="ctx">Model load context to read from; must not be null.</param>
            /// <param name="input">Schema that the source column names are resolved against; must not be null.</param>
            /// <param name="testTypes">Optional validator. It returns a non-null reason string when the
            /// source types of an added column are unacceptable.</param>
            public Contents(ModelLoadContext ctx, ISchema input, Func <ColumnType[], string> testTypes)
            {
                Contracts.CheckValue(ctx, nameof(ctx));
                Contracts.CheckValue(input, nameof(input));
                Contracts.CheckValueOrNull(testTypes);

                Input = input;

                // *** Binary format ***
                // int: number of added columns
                // for each added column
                //   int: id of output column name
                //   int: number of input column names
                //   int[]: ids of input column names
                int addedCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(addedCount > 0);

                Infos = new ColInfo[addedCount];
                Names = new string[addedCount];

                for (int iinfo = 0; iinfo < addedCount; iinfo++)
                {
                    Names[iinfo] = ctx.LoadNonEmptyString();

                    int sourceCount = ctx.Reader.ReadInt32();
                    Contracts.CheckDecode(sourceCount > 0);

                    var sourceIndices = new int[sourceCount];
                    var sourceTypes   = new ColumnType[sourceCount];

                    // Running total of the sources' value counts. When any source reports a value count
                    // of 0, the total becomes unknown and is recorded as 0. While it is still known, the
                    // addition is overflow-checked.
                    int  totalSize    = 0;
                    bool variableSize = false;

                    for (int isrc = 0; isrc < sourceCount; isrc++)
                    {
                        string sourceName = ctx.LoadNonEmptyString();
                        if (!input.TryGetColumnIndex(sourceName, out sourceIndices[isrc]))
                        {
                            throw Contracts.Except("Source column '{0}' is required but not found", sourceName);
                        }

                        var sourceType = input.GetColumnType(sourceIndices[isrc]);
                        sourceTypes[isrc] = sourceType;

                        int valueCount = sourceType.ValueCount;
                        if (valueCount == 0)
                        {
                            variableSize = true;
                        }
                        else if (!variableSize)
                        {
                            totalSize = checked (totalSize + valueCount);
                        }
                    }

                    string reason = testTypes == null ? null : testTypes(sourceTypes);
                    if (reason != null)
                    {
                        throw Contracts.Except("Source columns '{0}' have invalid types: {1}. Source types: '{2}'.",
                                               string.Join(", ", sourceIndices.Select(idx => input.GetColumnName(idx))),
                                               reason,
                                               string.Join(", ", sourceTypes.Select(st => st.ToString())));
                    }

                    Infos[iinfo] = new ColInfo(variableSize ? 0 : totalSize, sourceIndices, sourceTypes);
                }
            }
        /// <summary>
        /// Validates the label and score columns of <paramref name="schema"/>. The label column must be R4.
        /// The score column must have a nonzero VectorSize and items of type R4 or R8.
        /// Otherwise a Host exception is thrown.
        /// </summary>
        private void CheckInputColumnTypes(ISchema schema)
        {
            Host.AssertNonEmpty(ScoreCol);
            Host.AssertNonEmpty(LabelCol);

            var labelType = schema.GetColumnType(LabelIndex);
            if (labelType != NumberType.R4)
            {
                throw Host.Except("Label column '{0}' has type '{1}' but must be R4", LabelCol, labelType);
            }

            var scoreType = schema.GetColumnType(ScoreIndex);
            // Short-circuit order matters: ItemType is only inspected once VectorSize is known to be nonzero.
            bool scoreOk = scoreType.VectorSize != 0
                           && (scoreType.ItemType == NumberType.R4 || scoreType.ItemType == NumberType.R8);
            if (!scoreOk)
            {
                throw Host.Except(
                          "Score column '{0}' has type '{1}' but must be a known length vector of type R4 or R8", ScoreCol, scoreType);
            }
        }
Exemplo n.º 27
0
        /// <summary>
        /// Rejects source column <paramref name="srcCol"/> when TestColumn returns a non-null reason for
        /// its type. <paramref name="col"/> is part of the override signature and is not used here.
        /// </summary>
        protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol)
        {
            var failure = TestColumn(inputSchema.GetColumnType(srcCol));
            if (failure == null)
            {
                return;
            }

            throw Host.ExceptParam(nameof(inputSchema), failure);
        }
Exemplo n.º 28
0
 /// <summary>
 /// Returns the type of column <paramref name="col"/>. For pivot columns this is the item type of the
 /// matching <c>_infos</c> entry. For any other column it comes from the input schema.
 /// </summary>
 public ColumnType GetColumnType(int col)
 {
     _ectx.Check(0 <= col && col < ColumnCount);
     if (IsPivot(col))
     {
         var pivot = _pivotIndex[col];
         _ectx.Assert(0 <= pivot && pivot < _infos.Length);
         return(_infos[pivot].ItemType);
     }
     return(_inputSchema.GetColumnType(col));
 }
Exemplo n.º 29
0
        /// <summary>
        /// Looks up column <paramref name="name"/> in <paramref name="schema"/> and returns its type.
        /// If the column is missing, throws with a message that includes a dump of the schema.
        /// </summary>
        public static DataViewType GetColumnType(ISchema schema, string name)
        {
            int columnIndex;
            if (schema.TryGetColumnIndex(name, out columnIndex))
            {
                return(schema.GetColumnType(columnIndex));
            }

            throw Contracts.Except($"Unable to find column '{name}' in schema\n{ToString(schema)}.");
        }
Exemplo n.º 30
0
            /// <summary>
            /// Reads the one-to-one column mapping from <paramref name="ctx"/>, resolves every source
            /// column in <paramref name="input"/>, optionally validates its type with
            /// <paramref name="testType"/>, and builds the resulting <c>Bindings</c>.
            /// </summary>
            /// <param name="transInput">Optional transposed view. When present, it supplies the slot type
            /// of each source column.</param>
            /// <param name="testType">Optional validator. It returns a non-null reason when a source
            /// column type is rejected.</param>
            public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx, ISchema input,
                                          ITransposeSchema transInput, Func <ColumnType, string> testType)
            {
                Contracts.AssertValue(parent);
                var host = parent.Host;

                host.CheckValue(ctx, nameof(ctx));
                host.AssertValue(input);
                host.AssertValueOrNull(transInput);
                host.AssertValueOrNull(testType);

                // *** Binary format ***
                // int: number of added columns
                // for each added column
                //   int: id of output column name
                //   int: id of input column name
                int count = ctx.Reader.ReadInt32();
                host.CheckDecode(count > 0);

                var names = new string[count];
                var infos = new ColInfo[count];

                for (int iinfo = 0; iinfo < count; iinfo++)
                {
                    string dstName = ctx.LoadNonEmptyString();
                    names[iinfo] = dstName;

                    // Older files may store a null source name, which means the source column has the
                    // same name as the added column.
                    string srcName = ctx.LoadStringOrNull() ?? dstName;
                    host.CheckDecode(!string.IsNullOrEmpty(srcName));

                    int srcIndex;
                    if (!input.TryGetColumnIndex(srcName, out srcIndex))
                    {
                        throw host.Except("Source column '{0}' is required but not found", srcName);
                    }

                    var srcType = input.GetColumnType(srcIndex);
                    string reason = testType?.Invoke(srcType);
                    if (reason != null)
                    {
                        throw host.Except(InvalidTypeErrorFormat, srcName, srcType, reason);
                    }

                    infos[iinfo] = new ColInfo(dstName, srcIndex, srcType, transInput?.GetSlotType(srcIndex));
                }

                return(new Bindings(parent, infos, input, false, names));
            }