/// <summary>
        /// Validates that <paramref name="data"/> carries a label column usable for binary classification.
        /// Accepted label types are Boolean, Single, Double, or a key type with exactly two values.
        /// </summary>
        /// <param name="data">The role-mapped training data to inspect.</param>
        /// <remarks>
        /// Throws via <c>Contracts.ExceptParam</c> when there is no label column, when a key label has
        /// one or more than two values, or when the label type is otherwise unsuitable.
        /// </remarks>
        public static void CheckBinaryLabel(this RoleMappedData data)
        {
            Contracts.CheckValue(data, nameof(data));

            if (!data.Schema.Label.HasValue)
            {
                throw Contracts.ExceptParam(nameof(data), "Training data must specify a label column.");
            }

            var col = data.Schema.Label.Value;
            Contracts.Assert(!col.IsHidden);

            var labelType = col.Type;
            bool isBinaryCompatible = labelType == BooleanDataViewType.Instance
                || labelType == NumberDataViewType.Single
                || labelType == NumberDataViewType.Double
                || (labelType is KeyDataViewType twoValuedKey && twoValuedKey.Count == 2);
            if (isBinaryCompatible)
            {
                return;
            }

            // A key label with the wrong number of values gets a more specific message than the generic type error.
            if (labelType is KeyDataViewType keyLabel)
            {
                if (keyLabel.Count == 1)
                {
                    throw Contracts.ExceptParam(nameof(data),
                                                "The label column '{0}' of the training data has only one class. Two classes are required for binary classification.",
                                                col.Name);
                }
                if (keyLabel.Count > 2)
                {
                    throw Contracts.ExceptParam(nameof(data),
                                                "The label column '{0}' of the training data has more than two classes. Only two classes are allowed for binary classification.",
                                                col.Name);
                }
            }

            throw Contracts.ExceptParam(nameof(data),
                                        "The label column '{0}' of the training data has a data type not suitable for binary classification: {1}. Type must be Boolean, Single, Double or Key with two classes.",
                                        col.Name, col.Type);
        }
Exemple #2
0
        public void SequencePredictorSchemaTest()
        {
            const int keyCount = 10;
            var expectedScoreColumnType = new KeyDataViewType(typeof(uint), keyCount);
            VBuffer<ReadOnlyMemory<char>> keyNames = GenerateKeyNames(keyCount);

            var sequenceSchema = ScoreSchemaFactory.CreateSequencePredictionSchema(
                expectedScoreColumnType,
                AnnotationUtils.Const.ScoreColumnKind.SequenceClassification,
                keyNames);

            // The schema must expose exactly one column: the predicted label.
            Assert.Single(sequenceSchema);
            var scoreColumn = sequenceSchema[0];
            Assert.Equal(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, scoreColumn.Name);

            // The predicted label is a key whose cardinality and raw type match what was requested.
            var actualScoreColumnType = scoreColumn.Type as KeyDataViewType;
            Assert.NotNull(actualScoreColumnType);
            Assert.Equal(expectedScoreColumnType.Count, actualScoreColumnType.Count);
            Assert.Equal(expectedScoreColumnType.RawType, actualScoreColumnType.RawType);

            // With non-empty key names there are three annotations, in this order.
            var annotations = scoreColumn.Annotations;
            var annotationSchema = annotations.Schema;
            Assert.Equal(3, annotationSchema.Count);
            Assert.Equal(AnnotationUtils.Kinds.KeyValues, annotationSchema[0].Name);
            Assert.Equal(AnnotationUtils.Kinds.ScoreColumnKind, annotationSchema[1].Name);
            Assert.Equal(AnnotationUtils.Kinds.ScoreValueKind, annotationSchema[2].Name);

            // KeyValues is a text vector sized like the key names; the other two are scalar text.
            var keyValuesType = annotationSchema[0].Type as VectorDataViewType;
            Assert.NotNull(keyValuesType);
            Assert.Equal(keyNames.Length, keyValuesType.Size);
            Assert.Equal(TextDataViewType.Instance, keyValuesType.ItemType);
            Assert.Equal(TextDataViewType.Instance, annotationSchema[1].Type);
            Assert.Equal(TextDataViewType.Instance, annotationSchema[2].Type);

            // The stored key names round-trip unchanged.
            var actualKeyNames = default(VBuffer<ReadOnlyMemory<char>>);
            annotations.GetGetter<VBuffer<ReadOnlyMemory<char>>>(annotationSchema[0])(ref actualKeyNames);
            Assert.Equal(keyNames.Length, actualKeyNames.Length);
            Assert.Equal(keyNames.DenseValues(), actualKeyNames.DenseValues());

            // Reads the scalar text annotation at the given schema position.
            string ReadTextAnnotation(int position)
            {
                ReadOnlyMemory<char> text = default;
                annotations.GetGetter<ReadOnlyMemory<char>>(annotationSchema[position])(ref text);
                return text.ToString();
            }

            Assert.Equal(AnnotationUtils.Const.ScoreColumnKind.SequenceClassification, ReadTextAnnotation(1));
            Assert.Equal(AnnotationUtils.Const.ScoreValueKind.PredictedLabel, ReadTextAnnotation(2));
        }
        /// <summary>
        /// Appends a lambda mapper that replaces key column <paramref name="col"/> with a Single column.
        /// With <paramref name="seed"/> == 0 a key k maps to k - 1; otherwise it maps to a seeded random
        /// permutation of the class indices. Missing values (per the NA predicate) become NaN.
        /// </summary>
        private static IDataView AppendFloatMapper<TInput>(IHostEnvironment env, IChannel ch, IDataView input,
                                                            string col, KeyDataViewType type, int seed)
        {
            // Every key type converts to ulong, so normalize the source to that widest key representation
            // once instead of special-casing each possible raw key type.
            var ulongKeyType = new KeyDataViewType(typeof(ulong), type.Count);
            var toUlong = Conversions.Instance.GetStandardConversion<TInput, ulong>(type, ulongKeyType, out bool _);
            var isMissing = Conversions.Instance.GetIsNAPredicate<TInput>(type);

            ValueMapper<TInput, float> mapper;
            if (seed != 0)
            {
                // A seeded run relabels classes through a random permutation, which requires a finite key count.
                ch.Check(type.Count > 0, "Label must be of known cardinality.");
                var rng = RandomUtils.Create(seed);
                int[] permutation = Utils.GetRandomPermutation(rng, type.GetCountAsInt32(env));
                mapper = (in TInput src, ref float dst) =>
                {
                    // Invoked from multiple threads: the scratch value must stay local to each invocation,
                    // otherwise concurrent calls would race on it.
                    if (isMissing(in src))
                    {
                        dst = float.NaN;
                        return;
                    }
                    ulong key = 0;
                    toUlong(in src, ref key);
                    // Subtract 1 to turn the 1-based key into a 0-based permutation index.
                    dst = permutation[(int)(key - 1)];
                };
            }
            else
            {
                mapper = (in TInput src, ref float dst) =>
                {
                    // Invoked from multiple threads: the scratch value must stay local to each invocation,
                    // otherwise concurrent calls would race on it.
                    if (isMissing(in src))
                    {
                        dst = float.NaN;
                        return;
                    }
                    ulong key = 0;
                    toUlong(in src, ref key);
                    dst = (float)key - 1;
                };
            }

            return LambdaColumnMapper.Create(env, "Key to Float Mapper", input, col, col, type, NumberDataViewType.Single, mapper);
        }
Exemple #4
0
            /// <summary>
            /// Builds the row mapper for <paramref name="parent"/> against <paramref name="inputSchema"/>.
            /// The output type is a vector of ushort keys with <c>CharsCount</c> values. For each column pair,
            /// the constructor records whether the input column is itself a vector.
            /// </summary>
            public Mapper(TokenizingByCharactersTransformer parent, DataViewSchema inputSchema)
                : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema)
            {
                _parent = parent;
                _type = new VectorDataViewType(new KeyDataViewType(typeof(ushort), CharsCount));

                var columnPairs = _parent.ColumnPairs;
                _isSourceVector = new bool[columnPairs.Length];
                for (int pairIndex = 0; pairIndex < columnPairs.Length; pairIndex++)
                {
                    var inputName = columnPairs[pairIndex].inputColumnName;
                    _isSourceVector[pairIndex] = inputSchema[inputName].Type is VectorDataViewType;
                }
            }
Exemple #5
0
        /// <summary>
        /// Creates a per-instance clustering evaluator for <paramref name="numClusters"/> clusters.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="schema">Input schema; its column types are validated by <c>CheckInputColumnTypes</c>.</param>
        /// <param name="scoreCol">Name of the score column.</param>
        /// <param name="numClusters">Number of clusters; must be strictly positive.</param>
        public ClusteringPerInstanceEvaluator(IHostEnvironment env, DataViewSchema schema, string scoreCol, int numClusters)
            : base(env, schema, scoreCol, null)
        {
            CheckInputColumnTypes(schema);
            // Reject invalid counts here so the error names numClusters. This is the same "> 0" rule
            // the deserializing constructor enforces with CheckDecode.
            Host.CheckParam(numClusters > 0, nameof(numClusters), "Number of clusters must be positive");
            _numClusters = numClusters;

            _types = new DataViewType[3];
            var key = new KeyDataViewType(typeof(uint), _numClusters);

            _types[ClusterIdCol]          = key;
            _types[SortedClusterCol]      = new VectorDataViewType(key, _numClusters);
            _types[SortedClusterScoreCol] = new VectorDataViewType(NumberDataViewType.Single, _numClusters);
        }
Exemple #6
0
            /// <summary>
            /// Describes a schema-shape column. Vector-ness and key-ness are conveyed by <paramref name="vecKind"/>
            /// and <paramref name="isKey"/>, so <paramref name="itemType"/> must be neither a key nor a vector type.
            /// When <paramref name="isKey"/> is set, the item's raw type must be a valid key raw type.
            /// A null <paramref name="annotations"/> is replaced by the shared empty shape.
            /// </summary>
            internal Column(string name, VectorKind vecKind, DataViewType itemType, bool isKey, SchemaShape annotations = null)
            {
                Contracts.CheckNonEmpty(name, nameof(name));
                Contracts.CheckValueOrNull(annotations);

                bool itemIsKey = itemType is KeyDataViewType;
                bool itemIsVector = itemType is VectorDataViewType;
                Contracts.CheckParam(!itemIsKey, nameof(itemType), "Item type cannot be a key");
                Contracts.CheckParam(!itemIsVector, nameof(itemType), "Item type cannot be a vector");
                if (isKey)
                {
                    Contracts.CheckParam(KeyDataViewType.IsValidDataType(itemType.RawType), nameof(itemType), "The item type must be valid for a key");
                }

                Name = name;
                Kind = vecKind;
                ItemType = itemType;
                IsKey = isKey;
                Annotations = annotations ?? _empty;
            }
Exemple #7
0
        /// <summary>
        /// Builds a concrete <see cref="DataViewType"/> that satisfies <paramref name="column"/>.
        /// Key columns get a key over the item's raw type with <c>AllKeySizes</c>. Variable vectors get
        /// size 0, and fixed vectors get <c>AllVectorSizes</c>.
        /// </summary>
        private static DataViewType MakeColumnType(SchemaShape.Column column)
        {
            DataViewType itemType = column.IsKey
                ? new KeyDataViewType(((PrimitiveDataViewType)column.ItemType).RawType, AllKeySizes)
                : column.ItemType;

            switch (column.Kind)
            {
                case SchemaShape.Column.VectorKind.VariableVector:
                    return new VectorDataViewType((PrimitiveDataViewType)itemType, 0);
                case SchemaShape.Column.VectorKind.Vector:
                    return new VectorDataViewType((PrimitiveDataViewType)itemType, AllVectorSizes);
                default:
                    return itemType;
            }
        }
Exemple #8
0
        /// <summary>
        /// Deserializes the evaluator. After the base payload, reads the cluster count, which must be positive,
        /// and rebuilds the output column types from it.
        /// </summary>
        private ClusteringPerInstanceEvaluator(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema schema)
            : base(env, ctx, schema)
        {
            CheckInputColumnTypes(schema);

            // *** Binary format **
            // base
            // int: number of clusters
            int clusterCount = ctx.Reader.ReadInt32();
            _numClusters = clusterCount;
            Host.CheckDecode(clusterCount > 0);

            var clusterIdType = new KeyDataViewType(typeof(uint), clusterCount);
            var outputTypes = new DataViewType[3];
            outputTypes[ClusterIdCol] = clusterIdType;
            outputTypes[SortedClusterCol] = new VectorDataViewType(clusterIdType, clusterCount);
            outputTypes[SortedClusterScoreCol] = new VectorDataViewType(NumberDataViewType.Single, clusterCount);
            _types = outputTypes;
        }
Exemple #9
0
        /// <summary>
        /// Builds model parameters from a trained native buffer.
        /// </summary>
        /// <param name="env">The host environment.</param>
        /// <param name="buffer">Trained buffer providing dimensions, rank, and the two factor matrices.</param>
        /// <param name="matrixColumnIndexType">Key type of the column index; expected to have raw type uint.</param>
        /// <param name="matrixRowIndexType">Key type of the row index; expected to have raw type uint.</param>
        internal MatrixFactorizationModelParameters(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyDataViewType matrixColumnIndexType, KeyDataViewType matrixRowIndexType)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
            // Validate arguments before the debug assertions below dereference them, so a null argument
            // produces the intended argument error rather than a NullReferenceException.
            _host.CheckValue(buffer, nameof(buffer));
            _host.CheckValue(matrixColumnIndexType, nameof(matrixColumnIndexType));
            _host.CheckValue(matrixRowIndexType, nameof(matrixRowIndexType));
            _host.Assert(matrixColumnIndexType.RawType == typeof(uint));
            _host.Assert(matrixRowIndexType.RawType == typeof(uint));

            buffer.Get(out NumberOfRows, out NumberOfColumns, out ApproximationRank, out var leftFactorMatrix, out var rightFactorMatrix);
            _leftFactorMatrix  = leftFactorMatrix;
            _rightFactorMatrix = rightFactorMatrix;
            _host.Assert(NumberOfColumns == matrixColumnIndexType.GetCountAsInt32(_host));
            _host.Assert(NumberOfRows == matrixRowIndexType.GetCountAsInt32(_host));
            _host.Assert(_leftFactorMatrix.Length == NumberOfRows * ApproximationRank);
            _host.Assert(_rightFactorMatrix.Length == ApproximationRank * NumberOfColumns);

            MatrixColumnIndexType = matrixColumnIndexType;
            MatrixRowIndexType    = matrixRowIndexType;
        }
Exemple #10
0
        /// <summary>
        /// Deserializes model parameters. Row count, column count, and rank must all be positive.
        /// Models written before <c>VersionNoMinCount</c> also carry a key minimum after each count,
        /// and that minimum must be zero.
        /// </summary>
        private MatrixFactorizationModelParameters(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);

            // *** Binary format ***
            // int: number of rows (m), the limit on row
            // [before VersionNoMinCount] ulong: row key minimum (must be 0)
            // int: number of columns (n), the limit on column
            // [before VersionNoMinCount] ulong: column key minimum (must be 0)
            // int: rank of factor matrices (k)
            // float[m * k]: the left factor matrix
            // float[k * n]: the right factor matrix
            bool hasLegacyKeyMin = ctx.Header.ModelVerWritten < VersionNoMinCount;

            // Consumes and validates the legacy key minimum that older formats store after a count.
            void ReadLegacyKeyMin(int count)
            {
                if (!hasLegacyKeyMin)
                {
                    return;
                }
                ulong min = ctx.Reader.ReadUInt64();
                // Key types with a non-zero minimum are no longer supported.
                _host.CheckDecode(min == 0);
                _host.CheckDecode((ulong)count <= ulong.MaxValue - min);
            }

            NumberOfRows = ctx.Reader.ReadInt32();
            _host.CheckDecode(NumberOfRows > 0);
            ReadLegacyKeyMin(NumberOfRows);

            NumberOfColumns = ctx.Reader.ReadInt32();
            _host.CheckDecode(NumberOfColumns > 0);
            ReadLegacyKeyMin(NumberOfColumns);

            ApproximationRank = ctx.Reader.ReadInt32();
            _host.CheckDecode(ApproximationRank > 0);

            _leftFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(NumberOfRows * ApproximationRank));
            _rightFactorMatrix = Utils.ReadSingleArray(ctx.Reader, checked(NumberOfColumns * ApproximationRank));

            MatrixColumnIndexType = new KeyDataViewType(typeof(uint), NumberOfColumns);
            MatrixRowIndexType = new KeyDataViewType(typeof(uint), NumberOfRows);
        }
 /// <summary>
 /// Returns whether <paramref name="type"/> is a key of known (non-zero) cardinality whose raw type is uint.
 /// </summary>
 /// <param name="type">The column type to inspect.</param>
 /// <param name="keyType">
 /// Set to <paramref name="type"/> cast to <see cref="KeyDataViewType"/>. This is non-null whenever
 /// <paramref name="type"/> is a key type, even if the method returns false.
 /// </param>
 private static bool TryMarshalGoodRowColumnType(DataViewType type, out KeyDataViewType keyType)
 {
     keyType = type as KeyDataViewType;
     if (keyType == null)
     {
         return false;
     }
     return keyType.Count > 0 && keyType.RawType == typeof(uint);
 }
        /// <summary>
        /// Verifies that <see cref="DataViewType"/> equality and hashing agree with <c>ToString</c>.
        /// Two types that compare equal (same dictionary slot) must render to the same string.
        /// The test covers primitive types, key types over valid raw types, vectors of both with one
        /// or two dimensions, and image types. Image types must all be distinct.
        /// </summary>
        public void TestEqualAndGetHashCode()
        {
            var dict = new Dictionary<DataViewType, string>();

            // Registers a type under its string form. Fails if an Equals-equal type was registered
            // earlier with a different string.
            void AddChecked(DataViewType candidate)
            {
                string text = candidate.ToString();
                if (dict.TryGetValue(candidate, out var existing) && existing != text)
                {
                    Assert.True(false, existing + " and " + text + " are duplicates.");
                }
                dict[candidate] = text;
            }

            // Registers an item type, plus vectors of it with shapes [s] and [s, s1] for s, s1 in 0..4.
            void AddWithVectors(PrimitiveDataViewType itemType)
            {
                AddChecked(itemType);
                for (int size = 0; size < 5; size++)
                {
                    AddChecked(new VectorDataViewType(itemType, size));
                    for (int size1 = 0; size1 < 5; size1++)
                    {
                        AddChecked(new VectorDataViewType(itemType, size, size1));
                    }
                }
            }

            // add PrimitiveTypes, KeyType & corresponding VectorTypes
            var types = new PrimitiveDataViewType[] { NumberDataViewType.SByte, NumberDataViewType.Int16, NumberDataViewType.Int32, NumberDataViewType.Int64,
                                                      NumberDataViewType.Byte, NumberDataViewType.UInt16, NumberDataViewType.UInt32, NumberDataViewType.UInt64, RowIdDataViewType.Instance,
                                                      TextDataViewType.Instance, BooleanDataViewType.Instance, DateTimeDataViewType.Instance, DateTimeOffsetDataViewType.Instance, TimeSpanDataViewType.Instance };

            foreach (var primitive in types)
            {
                AddWithVectors(primitive);

                // KeyType & Vector
                var rawType = primitive.RawType;
                if (!KeyDataViewType.IsValidDataType(rawType))
                {
                    continue;
                }
                // Key types no longer have a minimum, so each key shape only needs to be registered once.
                // (The former outer "min" loop repeated identical work.)
                for (var count = 1; count < 5; count++)
                {
                    AddWithVectors(new KeyDataViewType(rawType, count));
                }
                Assert.True(rawType.TryGetDataKind(out var kind));
                AddWithVectors(new KeyDataViewType(rawType, kind.ToMaxInt()));
            }

            // add ImageTypes: every (height, width) pair must produce a distinct type.
            for (int height = 1; height < 5; height++)
            {
                for (int width = 1; width < 5; width++)
                {
                    var image = new ImageDataViewType(height, width);
                    if (dict.ContainsKey(image))
                    {
                        Assert.True(false, dict[image] + " and " + image.ToString() + " are duplicates.");
                    }
                    dict[image] = image.ToString();
                }
            }
        }
Exemple #13
0
        /// <summary>
        /// Determines the output item type for converting column <paramref name="info"/> to <paramref name="kind"/>
        /// and checks that a conversion exists, or installs a fallback one.
        /// </summary>
        /// <param name="ectx">Exception context used for assertions and range checks.</param>
        /// <param name="info">Source column info.</param>
        /// <param name="kind">Target data kind.</param>
        /// <param name="range">Optional explicit key range; when non-null the output is a key type built from it.</param>
        /// <param name="itemType">Receives the computed output item type.</param>
        /// <param name="ex">Receives the extended column info on success; null on failure.</param>
        /// <returns>True when the conversion is supported.</returns>
        private static bool TryCreateEx(IExceptionContext ectx, ColInfo info, DataKind kind, KeyCount range,
                                        out PrimitiveDataViewType itemType, out ColInfoEx ex)
        {
            ectx.AssertValue(info);
            ectx.Assert(Enum.IsDefined(typeof(DataKind), kind));

            ex = null;

            var typeSrc = info.TypeSrc;

            if (range != null)
            {
                itemType = TypeParsingUtils.ConstructKeyType(SchemaHelper.DataKind2InternalDataKind(kind), range);
            }
            else if (!typeSrc.ItemType().IsKey())
            {
                itemType = ColumnTypeHelper.PrimitiveFromKind(kind);
            }
            else if (!ColumnTypeHelper.IsValidDataKind(kind))
            {
                itemType = ColumnTypeHelper.PrimitiveFromKind(kind);
                return(false);
            }
            else
            {
                var key = typeSrc.ItemType().AsKey();
                ectx.Assert(ColumnTypeHelper.IsValidDataKind(key.RawKind()));
                ulong count = key.Count;
                // Technically, it's an error for the counts not to match, but we'll let the Conversions
                // code return false below. There's a possibility we'll change the standard conversions to
                // map out of bounds values to zero, in which case, this is the right thing to do.
                // Clamp to the largest value representable by the target kind. Casting the enum itself
                // to ulong would yield its ordinal, not its maximum value.
                ulong max = MaxKeyCountForKind(kind);
                if (count > max)
                {
                    count = max;
                }
                itemType = new KeyDataViewType(SchemaHelper.DataKind2ColumnType(kind).RawType, count);
            }

            // Ensure that the conversion is legal. We don't actually cache the delegate here. It will get
            // re-fetched by the utils code when needed.
            bool     identity;
            Delegate del;

            if (!Conversions.DefaultInstance.TryGetStandardConversion(typeSrc.ItemType(), itemType, out del, out identity))
            {
                // NOTE(review): the key offsets below test typeSrc.IsKey() rather than typeSrc.ItemType().IsKey();
                // confirm whether vector-of-key sources are meant to get the same offset as scalar keys.
                if (typeSrc.ItemType().RawKind() == itemType.RawKind())
                {
                    switch (typeSrc.ItemType().RawKind())
                    {
                    case DataKind.UInt32:
                        // Key starts at 1.
                        // Multiclass future issue
                        uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
                        identity = false;
                        ValueMapper <uint, uint> map_ = (in uint src, ref uint dst) => { dst = src + plus; };
                        del = (Delegate)map_;
                        break;

                    default:
                        throw Contracts.Except("Not supported type {0}", typeSrc.ItemType().RawKind());
                    }
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Int64 && kind == DataKind.UInt64)
                {
                    ulong plus = (itemType.IsKey() ? (ulong)1 : (ulong)0) - (typeSrc.IsKey() ? (ulong)1 : (ulong)0);
                    identity = false;
                    ValueMapper <long, ulong> map_ = (in long src, ref ulong dst) =>
                    {
                        CheckRange(src, dst, ectx); dst = (ulong)src + plus;
                    };
                    del = (Delegate)map_;
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.UInt64)
                {
                    ulong plus = (itemType.IsKey() ? (ulong)1 : (ulong)0) - (typeSrc.IsKey() ? (ulong)1 : (ulong)0);
                    identity = false;
                    ValueMapper <float, ulong> map_ = (in float src, ref ulong dst) =>
                    {
                        CheckRange(src, dst, ectx); dst = (ulong)src + plus;
                    };
                    del = (Delegate)map_;
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Int64 && kind == DataKind.UInt32)
                {
                    // Multiclass future issue
                    uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
                    identity = false;
                    ValueMapper <long, uint> map_ = (in long src, ref uint dst) =>
                    {
                        CheckRange(src, dst, ectx); dst = (uint)src + plus;
                    };
                    del = (Delegate)map_;
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.UInt32)
                {
                    // Multiclass future issue
                    uint plus = (itemType.IsKey() ? (uint)1 : (uint)0) - (typeSrc.IsKey() ? (uint)1 : (uint)0);
                    identity = false;
                    ValueMapper <float, uint> map_ = (in float src, ref uint dst) =>
                    {
                        CheckRange(src, dst, ectx); dst = (uint)src + plus;
                    };
                    del = (Delegate)map_;
                }
                else if (typeSrc.ItemType().RawKind() == DataKind.Single && kind == DataKind.String)
                {
                    // Multiclass future issue
                    identity = false;
                    ValueMapper <float, DvText> map_ = (in float src, ref DvText dst) =>
                    {
                        dst = new DvText(string.Format("{0}", (int)src));
                    };
                    del = (Delegate)map_;
                }
                else
                {
                    return(false);
                }
            }

            DataViewType typeDst = itemType;

            if (typeSrc.IsVector())
            {
                typeDst = new VectorDataViewType(itemType, typeSrc.AsVector().Dimensions.ToArray());
            }

            // An output column is transposable iff the input column was transposable.
            VectorDataViewType slotType = null;

            if (info.SlotTypeSrc != null)
            {
                slotType = new VectorDataViewType(itemType, info.SlotTypeSrc.Dimensions.ToArray());
            }

            ex = new ColInfoEx(kind, range != null, typeDst, slotType);
            return(true);
        }

        // Largest key count representable by an unsigned key kind. Other kinds are not clamped here;
        // the KeyDataViewType constructor performs its own validation.
        private static ulong MaxKeyCountForKind(DataKind kind)
        {
            switch (kind)
            {
            case DataKind.Byte:
                return byte.MaxValue;
            case DataKind.UInt16:
                return ushort.MaxValue;
            case DataKind.UInt32:
                return uint.MaxValue;
            default:
                return ulong.MaxValue;
            }
        }
Exemple #14
0
        // Checks that the label column of every model after the first has the same key type as
        // <paramref name="labelType"/>, the same key-value annotation type as <paramref name="keyValuesType"/>,
        // and the same key values as schema[labelIndex]. Returns the number of key values
        // (the label cardinality). The error messages describe mismatches relative to "model 0".
        // Throws if a model has no label column, or if any of those three properties differs.
        private static int CheckKeyLabelColumnCore<T>(IHostEnvironment env, PredictorModel[] models, KeyDataViewType labelType, DataViewSchema schema, int labelIndex, VectorDataViewType keyValuesType)
            where T : IEquatable<T>
        {
            env.Assert(keyValuesType.ItemType.RawType == typeof(T));
            env.AssertNonEmpty(models);

            var labelNames = default(VBuffer<T>);
            schema[labelIndex].GetKeyValues(ref labelNames);
            var classCount = labelNames.Length;

            var curLabelNames = default(VBuffer<T>);

            for (int i = 1; i < models.Length; i++)
            {
                var model = models[i];
                var edv   = new EmptyDataView(env, model.TransformModel.InputSchema);
                model.PrepareData(env, edv, out RoleMappedData rmd, out IPredictor _);
                if (!rmd.Schema.Label.HasValue)
                {
                    throw env.Except("Training schema for model {0} does not have a label column", i);
                }
                var labelCol = rmd.Schema.Label.Value;

                // A non-key label yields null here, and Equals(null) is false, so it is reported as a type mismatch.
                var curLabelType = labelCol.Type as KeyDataViewType;
                if (!labelType.Equals(curLabelType))
                {
                    throw env.Except("Label column of model {0} has different type than model 0", i);
                }

                // The KeyValues annotation may be absent. Treat that as a mismatch instead of dereferencing null.
                var mdType = labelCol.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type;
                if (mdType == null || !mdType.Equals(keyValuesType))
                {
                    throw env.Except("Label column of model {0} has different key value type than model 0", i);
                }

                labelCol.GetKeyValues(ref curLabelNames);
                if (!AreEqual(in labelNames, in curLabelNames))
                {
                    throw env.Except("Label of model {0} has different values than model 0", i);
                }
            }
            return(classCount);
        }