/// <summary>
/// Produces a textual description of this column's type.
/// The item type's string form is optionally wrapped in <c>Key&lt;...&gt;</c>
/// (when <c>IsKey</c> is set) and then in <c>Vector&lt;...&gt;</c> or
/// <c>VarVector&lt;...&gt;</c> depending on <c>Kind</c>.
/// </summary>
/// <returns>The composed type description string.</returns>
internal string GetTypeString()
{
    string itemText = ItemType.ToString();

    // The key wrapper is applied first so that it ends up innermost.
    string elementText = IsKey ? $"Key<{itemText}>" : itemText;

    switch (Kind)
    {
        case VectorKind.Vector:
            return $"Vector<{elementText}>";
        case VectorKind.VariableVector:
            return $"VarVector<{elementText}>";
        default:
            // Any other kind gets no vector wrapper.
            return elementText;
    }
}
private BoundColumn MakeColumn(DataViewSchema inputSchema, int iinfo) { Contracts.AssertValue(inputSchema); Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); DataViewType itemType = null; int[] sources = new int[_parent._columns[iinfo].Sources.Count]; // Go through the columns, and establish the following: // - indices of input columns in the input schema. Throw if they are not there. // - output type. Throw if the types of inputs are not the same. // - how many slots are there in the output vector (or variable). Denoted by totalSize. // - total size of CategoricalSlotRanges metadata, if present. Denoted by catCount. // - whether the column is normalized. // It is true when ALL inputs are normalized (and of numeric type). // - whether the column has slot names. // It is true if ANY input is a scalar, or has slot names. // - whether the column has categorical slot ranges. // It is true if ANY input has this metadata. int totalSize = 0; int catCount = 0; bool isNormalized = true; bool hasSlotNames = false; bool hasCategoricals = false; for (int i = 0; i < _parent._columns[iinfo].Sources.Count; i++) { var(srcName, srcAlias) = _parent._columns[iinfo].Sources[i]; if (!inputSchema.TryGetColumnIndex(srcName, out int srcCol)) { throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName); } sources[i] = srcCol; var curType = inputSchema[srcCol].Type; VectorType curVectorType = curType as VectorType; DataViewType currentItemType = curVectorType?.ItemType ?? curType; int currentValueCount = curVectorType?.Size ?? 1; if (itemType == null) { itemType = currentItemType; totalSize = currentValueCount; } else if (currentItemType.Equals(itemType)) { // If any one input is variable length, then the output is variable length. 
if (totalSize == 0 || currentValueCount == 0) { totalSize = 0; } else { totalSize += currentValueCount; } } else { throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, itemType.ToString(), curType.ToString()); } if (isNormalized && !inputSchema[srcCol].IsNormalized()) { isNormalized = false; } if (AnnotationUtils.TryGetCategoricalFeatureIndices(inputSchema, srcCol, out int[] typeCat))
/// <summary>
/// Deserialization constructor: restores a <c>RangeFilter</c> from a model load context
/// and binds it to the given input data.
/// </summary>
/// <param name="host">Host passed through to the base constructor.</param>
/// <param name="ctx">Model load context positioned at this filter's serialized state.</param>
/// <param name="input">Input data whose schema must contain the serialized column name.</param>
/// <remarks>
/// Throws a schema-mismatch exception when the column is missing or is not
/// Single, Double or a key type, and a decode failure when the stored float size
/// does not match <c>sizeof(float)</c>.
/// </remarks>
private RangeFilter(IHost host, ModelLoadContext ctx, IDataView input)
    : base(host, input)
{
    Host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: sizeof(Float)
    // int: id of column name
    // double: min
    // double: max
    // byte: complement
    // byte: includeMin
    // byte: includeMax

    int cbFloat = ctx.Reader.ReadInt32();
    Host.CheckDecode(cbFloat == sizeof(float));

    var column = ctx.LoadNonEmptyString();
    var schema = Source.Schema;
    if (!schema.TryGetColumnIndex(column, out _index))
    {
        throw Host.ExceptSchemaMismatch(nameof(schema), "source", column);
    }

    _type = schema[_index].Type;
    if (_type != NumberDataViewType.Single && _type != NumberDataViewType.Double && _type.GetKeyCount() == 0)
    {
        throw Host.ExceptSchemaMismatch(nameof(schema), "source", column, "Single, Double or Key", _type.ToString());
    }

    _min = ctx.Reader.ReadDouble();
    _max = ctx.Reader.ReadDouble();

    // Written as a negated comparison so that a NaN bound is rejected as well
    // (any comparison involving NaN evaluates to false).
    if (!(_min <= _max))
    {
        // The descriptive text must be the message itself. Passing "min" first made
        // "min" the message/format string and silently dropped the real description.
        throw Host.Except("min must be less than or equal to max");
    }

    _complement = ctx.Reader.ReadBoolByte();
    _includeMin = ctx.Reader.ReadBoolByte();
    _includeMax = ctx.Reader.ReadBoolByte();
}
/// <summary>
/// Builds the <see cref="ModelArgs"/> describing a column: its name, the tensor
/// data type corresponding to the column's raw item type, and its dimensions
/// with a leading batch dimension of 1.
/// </summary>
/// <param name="type">Column type; for vector types the item type determines the data type.</param>
/// <param name="colName">Non-empty column name, used as the argument name.</param>
/// <param name="dims">
/// Optional explicit dimensions (without the batch dimension). When null, the
/// dimensions are derived from <paramref name="type"/>. The list is copied, never modified.
/// </param>
/// <param name="dimsParams">
/// Optional per-dimension flags, passed through unchanged; only used when
/// <paramref name="dims"/> is supplied.
/// </param>
/// <returns>The assembled <see cref="ModelArgs"/>.</returns>
public static ModelArgs GetModelArgs(DataViewType type, string colName, List<long> dims = null, List<bool> dimsParams = null)
{
    Contracts.CheckValue(type, nameof(type));
    Contracts.CheckNonEmpty(colName, nameof(colName));

    // For vectors, the element type decides the tensor data type.
    Type rawType = type is VectorType vectorType ? vectorType.ItemType.RawType : type.RawType;

    TensorProto.Types.DataType dataType = TensorProto.Types.DataType.Undefined;
    if (rawType == typeof(bool))
    {
        // Booleans are represented as floats.
        dataType = TensorProto.Types.DataType.Float;
    }
    else if (rawType == typeof(ReadOnlyMemory<char>))
    {
        dataType = TensorProto.Types.DataType.String;
    }
    else if (rawType == typeof(sbyte))
    {
        dataType = TensorProto.Types.DataType.Int8;
    }
    else if (rawType == typeof(byte))
    {
        dataType = TensorProto.Types.DataType.Uint8;
    }
    else if (rawType == typeof(short))
    {
        dataType = TensorProto.Types.DataType.Int16;
    }
    else if (rawType == typeof(ushort))
    {
        dataType = TensorProto.Types.DataType.Uint16;
    }
    else if (rawType == typeof(int))
    {
        dataType = TensorProto.Types.DataType.Int32;
    }
    else if (rawType == typeof(uint))
    {
        // NOTE(review): uint is mapped to Int64 rather than an unsigned 32-bit type;
        // kept as-is to preserve existing output - confirm this widening is intentional.
        dataType = TensorProto.Types.DataType.Int64;
    }
    else if (rawType == typeof(long))
    {
        dataType = TensorProto.Types.DataType.Int64;
    }
    else if (rawType == typeof(ulong))
    {
        dataType = TensorProto.Types.DataType.Uint64;
    }
    else if (rawType == typeof(float))
    {
        dataType = TensorProto.Types.DataType.Float;
    }
    else if (rawType == typeof(double))
    {
        dataType = TensorProto.Types.DataType.Double;
    }
    else
    {
        string msg = "Unsupported type: " + type.ToString();
        Contracts.Check(false, msg);
    }

    List<long> dimsLocal;
    List<bool> dimsParamLocal = null;
    if (dims != null)
    {
        // Copy the caller's list: the batch dimension is inserted below, and doing
        // that on the caller's instance would corrupt it (and accumulate extra
        // leading dimensions if the same list were reused across calls).
        dimsLocal = new List<long>(dims);
        dimsParamLocal = dimsParams;
    }
    else
    {
        dimsLocal = new List<long>();
        int valueCount = type.GetValueCount();
        if (valueCount == 0)
        {
            // Unknown size.
            dimsLocal.Add(1);
            // false for batch size, true for dims.
            dimsParamLocal = new List<bool>() { false, true };
        }
        else if (valueCount == 1)
        {
            dimsLocal.Add(1);
        }
        else if (valueCount > 1)
        {
            var vec = (VectorType)type;
            for (int i = 0; i < vec.Dimensions.Length; i++)
            {
                dimsLocal.Add(vec.Dimensions[i]);
            }
        }
    }

    // Batch size goes first.
    dimsLocal.Insert(0, 1);

    return new ModelArgs(colName, dataType, dimsLocal, dimsParamLocal);
}