/// <summary>
/// Verifies that <paramref name="data"/> has a label column usable for binary classification.
/// Accepted label types are Boolean, Single, Double, or a key type with exactly two values.
/// </summary>
/// <param name="data">Role-mapped training data whose label column is inspected.</param>
/// <remarks>
/// Throws (via <c>Contracts.ExceptParam</c>) when no label column is mapped, when a key label has
/// one class or more than two classes, or when the label type is otherwise unsupported.
/// </remarks>
public static void CheckBinaryLabel(this RoleMappedData data)
{
    Contracts.CheckValue(data, nameof(data));

    if (!data.Schema.Label.HasValue)
    {
        throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column.");
    }

    var labelColumn = data.Schema.Label.Value;
    Contracts.Assert(!labelColumn.IsHidden);

    var labelType = labelColumn.Type;

    // Scalar Boolean / floating-point labels are accepted as-is.
    if (labelType == BooleanDataViewType.Instance
        || labelType == NumberDataViewType.Single
        || labelType == NumberDataViewType.Double)
    {
        return;
    }

    // Key labels are accepted only with exactly two classes; give a targeted message otherwise.
    if (labelType is KeyDataViewType keyLabelType)
    {
        ulong classCount = keyLabelType.Count;
        if (classCount == 2)
        {
            return;
        }
        if (classCount == 1)
        {
            throw Contracts.ExceptParam(nameof(data), "The label column '{0}' of the training data has only one class. Two classes are required for binary classification.", labelColumn.Name);
        }
        if (classCount > 2)
        {
            throw Contracts.ExceptParam(nameof(data), "The label column '{0}' of the training data has more than two classes. Only two classes are allowed for binary classification.", labelColumn.Name);
        }
    }

    throw Contracts.ExceptParam(nameof(data), "The label column '{0}' of the training data has a data type not suitable for binary classification: {1}. Type must be Boolean, Single, Double or Key with two classes.", labelColumn.Name, labelColumn.Type);
}
public void SequencePredictorSchemaTest()
{
    int keyCount = 10;
    var expectedScoreColumnType = new KeyDataViewType(typeof(uint), keyCount);
    VBuffer<ReadOnlyMemory<char>> keyNames = GenerateKeyNames(keyCount);
    var sequenceSchema = ScoreSchemaFactory.CreateSequencePredictionSchema(
        expectedScoreColumnType,
        AnnotationUtils.Const.ScoreColumnKind.SequenceClassification,
        keyNames);

    // The output schema holds exactly one column: the predicted label.
    Assert.Single(sequenceSchema);
    var predictedLabel = sequenceSchema[0];

    // Name and key type of the predicted-label column.
    Assert.Equal(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel.Name);
    var actualKeyType = predictedLabel.Type as KeyDataViewType;
    Assert.NotNull(actualKeyType);
    Assert.Equal(expectedScoreColumnType.Count, actualKeyType.Count);
    Assert.Equal(expectedScoreColumnType.RawType, actualKeyType.RawType);

    // keyNames is non-empty, so three annotations are expected, in this order:
    // key values, score column kind, score value kind.
    var annotations = predictedLabel.Annotations;
    var annotationSchema = annotations.Schema;
    Assert.Equal(3, annotationSchema.Count);
    Assert.Equal(AnnotationUtils.Kinds.KeyValues, annotationSchema[0].Name);
    Assert.Equal(AnnotationUtils.Kinds.ScoreColumnKind, annotationSchema[1].Name);
    Assert.Equal(AnnotationUtils.Kinds.ScoreValueKind, annotationSchema[2].Name);

    // Annotation types: a text vector sized like keyNames, followed by two text scalars.
    var keyValuesType = annotationSchema[0].Type as VectorDataViewType;
    Assert.NotNull(keyValuesType);
    Assert.Equal(keyNames.Length, keyValuesType.Size);
    Assert.Equal(TextDataViewType.Instance, keyValuesType.ItemType);
    Assert.Equal(TextDataViewType.Instance, annotationSchema[1].Type);
    Assert.Equal(TextDataViewType.Instance, annotationSchema[2].Type);

    // Annotation values: key names round-trip unchanged.
    var readKeyNames = annotations.GetGetter<VBuffer<ReadOnlyMemory<char>>>(annotationSchema[0]);
    var actualKeyNames = default(VBuffer<ReadOnlyMemory<char>>);
    readKeyNames(ref actualKeyNames);
    Assert.Equal(keyNames.Length, actualKeyNames.Length);
    Assert.Equal(keyNames.DenseValues(), actualKeyNames.DenseValues());

    // Annotation values: score column kind.
    var readColumnKind = annotations.GetGetter<ReadOnlyMemory<char>>(annotationSchema[1]);
    var columnKind = default(ReadOnlyMemory<char>);
    readColumnKind(ref columnKind);
    Assert.Equal(AnnotationUtils.Const.ScoreColumnKind.SequenceClassification, columnKind.ToString());

    // Annotation values: score value kind.
    var readValueKind = annotations.GetGetter<ReadOnlyMemory<char>>(annotationSchema[2]);
    var valueKind = default(ReadOnlyMemory<char>);
    readValueKind(ref valueKind);
    Assert.Equal(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, valueKind.ToString());
}
/// <summary>
/// Appends a mapper that replaces key column <paramref name="col"/> with a <see cref="Single"/> column
/// of the same name. Missing keys become NaN. With <paramref name="seed"/> == 0, key value k maps to k - 1;
/// otherwise k maps to entry k - 1 of a random permutation seeded by <paramref name="seed"/>.
/// </summary>
private static IDataView AppendFloatMapper<TInput>(IHostEnvironment env, IChannel ch, IDataView input, string col, KeyDataViewType type, int seed)
{
    // Every key type converts to ulong, so normalize to that single raw type up front instead of
    // special-casing each possible key raw type.
    var ulongKeyType = new KeyDataViewType(typeof(ulong), type.Count);
    var toUlong = Conversions.Instance.GetStandardConversion<TInput, ulong>(type, ulongKeyType, out _);
    var isMissing = Conversions.Instance.GetIsNAPredicate<TInput>(type);

    ValueMapper<TInput, Single> mapper;
    if (seed != 0)
    {
        ch.Check(type.Count > 0, "Label must be of known cardinality.");
        int[] permutation = Utils.GetRandomPermutation(RandomUtils.Create(seed), type.GetCountAsInt32(env));
        mapper = (in TInput src, ref Single dst) =>
        {
            if (isMissing(in src))
            {
                dst = Single.NaN;
                return;
            }
            // The mapper runs on multiple threads: this scratch value must stay local to each call,
            // hoisting it into the enclosing scope would share it across threads.
            ulong keyValue = 0;
            toUlong(in src, ref keyValue);
            dst = (Single)permutation[(int)(keyValue - 1)];
        };
    }
    else
    {
        mapper = (in TInput src, ref Single dst) =>
        {
            if (isMissing(in src))
            {
                dst = Single.NaN;
                return;
            }
            // The mapper runs on multiple threads: this scratch value must stay local to each call,
            // hoisting it into the enclosing scope would share it across threads.
            ulong keyValue = 0;
            toUlong(in src, ref keyValue);
            dst = (Single)keyValue - 1;
        };
    }

    return LambdaColumnMapper.Create(env, "Key to Float Mapper", input, col, col, type, NumberDataViewType.Single, mapper);
}
/// <summary>
/// Creates the row mapper for <paramref name="parent"/> over <paramref name="inputSchema"/>.
/// Every output column has type vector-of-ushort-key with <c>CharsCount</c> values. For each
/// column pair, the mapper records whether the input column is a vector.
/// </summary>
public Mapper(TokenizingByCharactersTransformer parent, DataViewSchema inputSchema)
    : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
{
    _parent = parent;
    _type = new VectorDataViewType(new KeyDataViewType(typeof(ushort), CharsCount));

    var columnPairs = _parent.ColumnPairs;
    _isSourceVector = new bool[columnPairs.Length];
    for (int pairIndex = 0; pairIndex < columnPairs.Length; pairIndex++)
    {
        var sourceType = inputSchema[columnPairs[pairIndex].inputColumnName].Type;
        _isSourceVector[pairIndex] = sourceType is VectorDataViewType;
    }
}
/// <summary>
/// Creates a per-instance clustering evaluator over <paramref name="schema"/>.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="schema">Input schema; its column types are validated.</param>
/// <param name="scoreCol">Name of the score column.</param>
/// <param name="numClusters">Number of clusters; must be positive.</param>
public ClusteringPerInstanceEvaluator(IHostEnvironment env, DataViewSchema schema, string scoreCol, int numClusters)
    : base(env, schema, scoreCol, null)
{
    CheckInputColumnTypes(schema);

    // The cluster id is exposed as a key type, which requires a positive cardinality. This matches
    // the check applied when the evaluator is loaded from a model stream, and reports the problem
    // against the right parameter instead of failing later inside the KeyDataViewType constructor.
    Host.CheckParam(numClusters > 0, nameof(numClusters), "Number of clusters must be positive");
    _numClusters = numClusters;

    _types = new DataViewType[3];
    var key = new KeyDataViewType(typeof(uint), _numClusters);
    _types[ClusterIdCol] = key;
    _types[SortedClusterCol] = new VectorDataViewType(key, _numClusters);
    _types[SortedClusterScoreCol] = new VectorDataViewType(NumberDataViewType.Single, _numClusters);
}
/// <summary>
/// Describes a column in a <c>SchemaShape</c>: its name, vector kind, item type and whether it is a key.
/// </summary>
/// <param name="name">Column name; must be non-empty.</param>
/// <param name="vecKind">Vector kind of the column.</param>
/// <param name="itemType">Item type; must be non-null and neither a key type nor a vector type.
/// When <paramref name="isKey"/> is true, its raw type must be valid for a key.</param>
/// <param name="isKey">Whether the column holds keys.</param>
/// <param name="annotations">Optional annotation shape; an empty shape is used when null.</param>
internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
{
    Contracts.CheckNonEmpty(name, nameof(name));
    // itemType is dereferenced below (RawType) and stored as ItemType; reject null up front rather
    // than failing with a NullReferenceException or storing an unusable column description.
    Contracts.CheckValue(itemType, nameof(itemType));
    Contracts.CheckValueOrNull(annotations);
    Contracts.CheckParam(!(itemType is KeyDataViewType), nameof(itemType), "Item type cannot be a key");
    Contracts.CheckParam(!(itemType is VectorDataViewType), nameof(itemType), "Item type cannot be a vector");
    Contracts.CheckParam(!isKey || KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");

    Name = name;
    Kind = vecKind;
    ItemType = itemType;
    IsKey = isKey;
    Annotations = annotations ?? _empty;
}
/// <summary>
/// Builds a concrete <see cref="DataViewType"/> for a <see cref="SchemaShape.Column"/>.
/// Key columns become key types with cardinality <c>AllKeySizes</c>. Fixed-size vector columns use
/// <c>AllVectorSizes</c> as their size, and variable vectors use size 0. Scalars keep the item type.
/// </summary>
private static DataViewType MakeColumnType(SchemaShape.Column column)
{
    DataViewType itemType = column.IsKey
        ? new KeyDataViewType(((PrimitiveDataViewType)column.ItemType).RawType, AllKeySizes)
        : column.ItemType;

    switch (column.Kind)
    {
        case SchemaShape.Column.VectorKind.VariableVector:
            return new VectorDataViewType((PrimitiveDataViewType)itemType, 0);
        case SchemaShape.Column.VectorKind.Vector:
            return new VectorDataViewType((PrimitiveDataViewType)itemType, AllVectorSizes);
        default:
            return itemType;
    }
}
/// <summary>
/// Deserializes the evaluator: base-class state first, then the cluster count, which must be positive.
/// </summary>
private ClusteringPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
    : base(env, ctx, schema)
{
    CheckInputColumnTypes(schema);

    // *** Binary format **
    // base
    // int: number of clusters
    int clusterCount = ctx.Reader.ReadInt32();
    Host.CheckDecode(clusterCount > 0);
    _numClusters = clusterCount;

    var clusterIdType = new KeyDataViewType(typeof(uint), clusterCount);
    var columnTypes = new DataViewType[3];
    columnTypes[ClusterIdCol] = clusterIdType;
    columnTypes[SortedClusterCol] = new VectorDataViewType(clusterIdType, clusterCount);
    columnTypes[SortedClusterScoreCol] = new VectorDataViewType(NumberDataViewType.Single, clusterCount);
    _types = columnTypes;
}
/// <summary>
/// Creates model parameters from the factor matrices held by a trained <paramref name="buffer"/>.
/// </summary>
/// <param name="env">Host environment.</param>
/// <param name="buffer">Trained buffer; row/column counts, rank and both factor matrices are read from it.</param>
/// <param name="matrixColumnIndexType">U4 key type of the column index; its count must equal the number of columns.</param>
/// <param name="matrixRowIndexType">U4 key type of the row index; its count must equal the number of rows.</param>
internal MatrixFactorizationModelParameters(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyDataViewType matrixColumnIndexType, KeyDataViewType matrixRowIndexType)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(RegistrationName);

    // Validate arguments before any member of them is accessed, so a null key type is reported
    // as an argument error rather than a NullReferenceException from the asserts below.
    _host.CheckValue(buffer, nameof(buffer));
    _host.CheckValue(matrixColumnIndexType, nameof(matrixColumnIndexType));
    _host.CheckValue(matrixRowIndexType, nameof(matrixRowIndexType));
    _host.Assert(matrixColumnIndexType.RawType == typeof(uint));
    _host.Assert(matrixRowIndexType.RawType == typeof(uint));

    buffer.Get(out NumberOfRows, out NumberOfColumns, out ApproximationRank, out var leftFactorMatrix, out var rightFactorMatrix);
    _leftFactorMatrix = leftFactorMatrix;
    _rightFactorMatrix = rightFactorMatrix;

    // The key cardinalities must agree with the matrix dimensions, and the factor buffers must be
    // (rows x rank) and (rank x columns) in size.
    _host.Assert(NumberOfColumns == matrixColumnIndexType.GetCountAsInt32(_host));
    _host.Assert(NumberOfRows == matrixRowIndexType.GetCountAsInt32(_host));
    _host.Assert(_leftFactorMatrix.Length == NumberOfRows * ApproximationRank);
    _host.Assert(_rightFactorMatrix.Length == ApproximationRank * NumberOfColumns);

    MatrixColumnIndexType = matrixColumnIndexType;
    MatrixRowIndexType = matrixRowIndexType;
}
/// <summary>
/// Deserializes model parameters from <paramref name="ctx"/>, accepting streams written both before
/// and after <c>VersionNoMinCount</c>.
/// </summary>
private MatrixFactorizationModelParameters(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    _host = env.Register(RegistrationName);

    // *** Binary format ***
    // int: number of rows (m), the limit on row
    // int: number of columns (n), the limit on column
    // int: rank of factor matrices (k)
    // float[m * k]: the left factor matrix
    // float[k * n]: the right factor matrix
    // Streams older than VersionNoMinCount also carry a ulong key minimum right after the row
    // count and right after the column count.
    bool hasLegacyKeyMin = ctx.Header.ModelVerWritten < VersionNoMinCount;

    NumberOfRows = ctx.Reader.ReadInt32();
    _host.CheckDecode(NumberOfRows > 0);
    if (hasLegacyKeyMin)
    {
        ReadLegacyKeyMin(NumberOfRows);
    }

    NumberOfColumns = ctx.Reader.ReadInt32();
    _host.CheckDecode(NumberOfColumns > 0);
    if (hasLegacyKeyMin)
    {
        ReadLegacyKeyMin(NumberOfColumns);
    }

    ApproximationRank = ctx.Reader.ReadInt32();
    _host.CheckDecode(ApproximationRank > 0);

    _leftFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(NumberOfRows * ApproximationRank));
    _rightFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(NumberOfColumns * ApproximationRank));

    MatrixColumnIndexType = new KeyDataViewType(typeof(uint), NumberOfColumns);
    MatrixRowIndexType = new KeyDataViewType(typeof(uint), NumberOfRows);

    // Consumes one legacy key-minimum field. Non-zero minimums for key types are no longer supported.
    void ReadLegacyKeyMin(int limit)
    {
        ulong keyMin = ctx.Reader.ReadUInt64();
        _host.CheckDecode(keyMin == 0);
        _host.CheckDecode((ulong)limit <= ulong.MaxValue - keyMin);
    }
}
/// <summary>
/// Tests whether <paramref name="type"/> is a key type with raw type <see cref="uint"/> and a positive count.
/// </summary>
/// <param name="type">Column type to inspect.</param>
/// <param name="keyType">
/// Receives <paramref name="type"/> as a <see cref="KeyDataViewType"/> whenever it is a key type, even when
/// the method returns false; null when it is not a key type.
/// </param>
/// <returns>True when <paramref name="type"/> is a U4 key of known cardinality.</returns>
private static bool TryMarshalGoodRowColumnType(DataViewType type, out KeyDataViewType keyType)
{
    keyType = type as KeyDataViewType;
    if (keyType is null)
    {
        return false;
    }
    return keyType.Count > 0 && type.RawType == typeof(uint);
}
public void TestEqualAndGetHashCode()
{
    var dict = new Dictionary<DataViewType, string>();

    // Records candidate's string form. Fails if the dictionary already holds a type it considers
    // equal (via Equals/GetHashCode) whose ToString() differs, i.e. two distinct types collide.
    void AddOrCheck(DataViewType candidate)
    {
        string text = candidate.ToString();
        if (dict.TryGetValue(candidate, out var existing) && existing != text)
        {
            Assert.True(false, existing + " and " + text + " are duplicates.");
        }
        dict[candidate] = text;
    }

    // Records itemType plus its 1-D vectors (sizes 0..4) and 2-D vectors (every size pair in 0..4).
    void AddWithVectors(PrimitiveDataViewType itemType)
    {
        AddOrCheck(itemType);
        for (int size = 0; size < 5; size++)
        {
            AddOrCheck(new VectorDataViewType(itemType, size));
            for (int size1 = 0; size1 < 5; size1++)
            {
                AddOrCheck(new VectorDataViewType(itemType, size, size1));
            }
        }
    }

    // add PrimitiveTypes, KeyType & corresponding VectorTypes
    var types = new PrimitiveDataViewType[]
    {
        NumberDataViewType.SByte,
        NumberDataViewType.Int16,
        NumberDataViewType.Int32,
        NumberDataViewType.Int64,
        NumberDataViewType.Byte,
        NumberDataViewType.UInt16,
        NumberDataViewType.UInt32,
        NumberDataViewType.UInt64,
        RowIdDataViewType.Instance,
        TextDataViewType.Instance,
        BooleanDataViewType.Instance,
        DateTimeDataViewType.Instance,
        DateTimeOffsetDataViewType.Instance,
        TimeSpanDataViewType.Instance
    };

    foreach (var primitive in types)
    {
        AddWithVectors(primitive);

        // KeyType & Vector
        var rawType = primitive.RawType;
        if (!KeyDataViewType.IsValidDataType(rawType))
        {
            continue;
        }

        // A key type is identified only by its raw type and count (there is no minimum), so each
        // small count is registered exactly once.
        for (var count = 1; count < 5; count++)
        {
            AddWithVectors(new KeyDataViewType(rawType, count));
        }

        // Also register the key whose count is ToMaxInt() of the raw type's data kind.
        Assert.True(rawType.TryGetDataKind(out var kind));
        AddWithVectors(new KeyDataViewType(rawType, kind.ToMaxInt()));
    }

    // add ImageTypes
    for (int height = 1; height < 5; height++)
    {
        for (int width = 1; width < 5; width++)
        {
            var imageType = new ImageDataViewType(height, width);
            // Every (height, width) pair is new, so any already-present equal entry is a failure.
            if (dict.TryGetValue(imageType, out var previous))
            {
                Assert.True(false, previous + " and " + imageType.ToString() + " are duplicates.");
            }
            dict[imageType] = imageType.ToString();
        }
    }
}
/// <summary>
/// Computes the output item type and column info for converting the source column described by
/// <paramref name="info"/> to data kind <paramref name="kind"/>. It first tries a standard conversion,
/// then falls back to a few hand-written mappers (UInt32->UInt32, Int64/Single->UInt64,
/// Int64/Single->UInt32, Single->String).
/// </summary>
/// <param name="ectx">Exception context used for assertions.</param>
/// <param name="info">Source column description; its TypeSrc and SlotTypeSrc are read.</param>
/// <param name="kind">Target data kind.</param>
/// <param name="range">When non-null, the output item type is a key type built from this key count.</param>
/// <param name="itemType">Receives the computed output item type; assigned on every return path.</param>
/// <param name="ex">Receives the output column info on success; null when the method returns false.</param>
/// <returns>
/// True if a conversion exists. False if the source is a key but <paramref name="kind"/> is not a valid
/// key kind, or if neither a standard nor a fallback conversion applies.
/// Throws when source and target share a raw kind other than UInt32 and no standard conversion exists.
/// </returns>
private static bool TryCreateEx(IExceptionContext ectx, ColInfo info, DataKind kind, KeyCount range, out PrimitiveDataViewType itemType, out ColInfoEx ex)
{
    ectx.AssertValue(info);
    ectx.Assert(Enum.IsDefined(typeof(DataKind), kind));
    ex = null;

    // Choose the output item type.
    var typeSrc = info.TypeSrc;
    if (range != null)
    {
        itemType = TypeParsingUtils.ConstructKeyType(SchemaHelper.DataKind2InternalDataKind(kind), range);
    }
    else if (!typeSrc.ItemType().IsKey())
    {
        itemType = ColumnTypeHelper.PrimitiveFromKind(kind);
    }
    else if (!ColumnTypeHelper.IsValidDataKind(kind))
    {
        itemType = ColumnTypeHelper.PrimitiveFromKind(kind);
        return(false);
    }
    else
    {
        // Source is a key and the target kind is a valid key kind: keep it a key of the target kind.
        var key = typeSrc.ItemType().AsKey();
        ectx.Assert(ColumnTypeHelper.IsValidDataKind(key.RawKind()));
        ulong count = key.Count;
        // Technically, it's an error for the counts not to match, but we'll let the Conversions
        // code return false below. There's a possibility we'll change the standard conversions to
        // map out of bounds values to zero, in which case, this is the right thing to do.
        // NOTE(review): (ulong)kind is the numeric value of the DataKind enum member, not the largest
        // value representable by that kind. This looks like it was meant to be the kind's max value
        // (cf. ToMaxInt()) — confirm before relying on this clamp.
        ulong max = (ulong)kind;
        if ((ulong)count > max)
        {
            count = max;
        }
        itemType = new KeyDataViewType(SchemaHelper.DataKind2ColumnType(kind).RawType, count);
    }

    // Ensure that the conversion is legal. We don't actually cache the delegate here. It will get
    // re-fetched by the utils code when needed.
    bool identity;
    Delegate del;
    if (!Conversions.DefaultInstance.TryGetStandardConversion(typeSrc.ItemType(), itemType, out del, out identity))
    {
        // No standard conversion: fall back to hand-written mappers. In each fallback, "plus" is
        // (target is key ? 1 : 0) - (source is key ? 1 : 0). With unsigned arithmetic (assuming an
        // unchecked context) 0 - 1 wraps, so adding it effectively subtracts one.
        // NOTE(review): "plus" tests typeSrc.IsKey() while the rest of the method tests
        // typeSrc.ItemType().IsKey(). For a vector-of-key source these may differ — confirm intent.
        // NOTE(review): the "del == null" checks below can never be true, since del was just assigned a lambda.
        if (typeSrc.ItemType().RawKind() == itemType.RawKind())
        {
            switch (typeSrc.ItemType().RawKind())
            {
                case DataKind.UInt32:
                    // Key starts at 1.
                    // Multiclass future issue
                    uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
                    identity = false;
                    ValueMapper<uint, uint> map_ = (in uint src, ref uint dst) => { dst = src + plus; };
                    del = (Delegate)map_;
                    if (del == null)
                    {
                        throw Contracts.ExceptNotSupp("Issue with casting");
                    }
                    break;
                default:
                    // NOTE(review): the runtime message below contains the typo "suppoted".
                    throw Contracts.Except("Not suppoted type {0}", typeSrc.ItemType().RawKind());
            }
        }
        else if (typeSrc.ItemType().RawKind() == DataKind.Int64 && kind == DataKind.UInt64)
        {
            ulong plus = (itemType.IsKey() ? (ulong)1 : (ulong)0) - (typeSrc.IsKey() ? (ulong)1 : (ulong)0);
            identity = false;
            ValueMapper<long, ulong> map_ = (in long src, ref ulong dst) =>
            {
                CheckRange(src, dst, ectx);
                dst = (ulong)src + plus;
            };
            del = (Delegate)map_;
            if (del == null)
            {
                throw Contracts.ExceptNotSupp("Issue with casting");
            }
        }
        else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.UInt64)
        {
            ulong plus = (itemType.IsKey() ? (ulong)1 : (ulong)0) - (typeSrc.IsKey() ? (ulong)1 : (ulong)0);
            identity = false;
            ValueMapper<float, ulong> map_ = (in float src, ref ulong dst) =>
            {
                CheckRange(src, dst, ectx);
                dst = (ulong)src + plus;
            };
            del = (Delegate)map_;
            if (del == null)
            {
                throw Contracts.ExceptNotSupp("Issue with casting");
            }
        }
        else if (typeSrc.ItemType().RawKind() == DataKind.Int64 && kind == DataKind.UInt32)
        {
            // Multiclass future issue
            uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
            identity = false;
            ValueMapper<long, uint> map_ = (in long src, ref uint dst) =>
            {
                CheckRange(src, dst, ectx);
                dst = (uint)src + plus;
            };
            del = (Delegate)map_;
            if (del == null)
            {
                throw Contracts.ExceptNotSupp("Issue with casting");
            }
        }
        else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.UInt32)
        {
            // Multiclass future issue
            uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
            identity = false;
            ValueMapper<float, uint> map_ = (in float src, ref uint dst) =>
            {
                CheckRange(src, dst, ectx);
                dst = (uint)src + plus;
            };
            del = (Delegate)map_;
            if (del == null)
            {
                throw Contracts.ExceptNotSupp("Issue with casting");
            }
        }
        else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.String)
        {
            // Multiclass future issue
            // Formats the float truncated to an int.
            identity = false;
            ValueMapper<float, DvText> map_ = (in float src, ref DvText dst) =>
            {
                dst = new DvText(string.Format("{0}", (int)src));
            };
            del = (Delegate)map_;
            if (del == null)
            {
                throw Contracts.ExceptNotSupp("Issue with casting");
            }
        }
        else
        {
            return(false);
        }
    }

    // Vector sources produce vector outputs with the same dimensions.
    DataViewType typeDst = itemType;
    if (typeSrc.IsVector())
    {
        typeDst = new VectorDataViewType(itemType, typeSrc.AsVector().Dimensions.ToArray());
    }

    // An output column is transposable iff the input column was transposable.
    VectorDataViewType slotType = null;
    if (info.SlotTypeSrc != null)
    {
        slotType = new VectorDataViewType(itemType, info.SlotTypeSrc.Dimensions.ToArray());
    }

    ex = new ColInfoEx(kind, range != null, typeDst, slotType);
    return(true);
}
// Checks that all the label columns of the model have the same key type as their label column - including the same
// cardinality and the same key values, and returns the cardinality of the label column key.
//
// The reference label column is schema[labelIndex]; models[1..] are compared against it (error messages refer
// to it as "model 0"). Throws (via env.Except) on the first mismatch, including when a model's label column
// lacks a key-values annotation.
private static int CheckKeyLabelColumnCore<T>(IHostEnvironment env, PredictorModel[] models, KeyDataViewType labelType, DataViewSchema schema, int labelIndex, VectorDataViewType keyValuesType)
    where T : IEquatable<T>
{
    env.Assert(keyValuesType.ItemType.RawType == typeof(T));
    env.AssertNonEmpty(models);

    var labelNames = default(VBuffer<T>);
    schema[labelIndex].GetKeyValues(ref labelNames);
    var classCount = labelNames.Length;

    var curLabelNames = default(VBuffer<T>);
    for (int i = 1; i < models.Length; i++)
    {
        var model = models[i];
        var edv = new EmptyDataView(env, model.TransformModel.InputSchema);
        model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor _);

        if (!rmd.Schema.Label.HasValue)
        {
            throw env.Except("Training schema for model {0} does not have a label column", i);
        }

        var labelCol = rmd.Schema.Label.Value;
        var curLabelType = labelCol.Type as KeyDataViewType;
        if (!labelType.Equals(curLabelType))
        {
            throw env.Except("Label column of model {0} has different type than model 0", i);
        }

        // The key-values annotation may be absent (GetColumnOrNull(...)?.Type yields null). Treat that
        // as a mismatch instead of dereferencing null.
        var mdType = labelCol.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;
        if (mdType == null || !mdType.Equals(keyValuesType))
        {
            throw env.Except("Label column of model {0} has different key value type than model 0", i);
        }

        labelCol.GetKeyValues(ref curLabelNames);
        if (!AreEqual(in labelNames, in curLabelNames))
        {
            throw env.Except("Label of model {0} has different values than model 0", i);
        }
    }
    return classCount;
}