/// <summary>
/// Checks whether <paramref name="type"/> can serve as the label for binary remapping.
/// </summary>
/// <param name="type">The label column type to validate.</param>
/// <returns>
/// <c>null</c> when the type is a key type with a non-zero count, <c>Single</c>, or <c>Double</c>;
/// otherwise a human-readable error message describing the unsupported type.
/// </returns>
private static string TestIsMulticlassLabel(DataViewType type)
{
    bool isSupported = type.GetKeyCount() > 0
        || type == NumberDataViewType.Single
        || type == NumberDataViewType.Double;

    // A null result signals "no error" to the caller.
    return isSupported
        ? null
        : $"Label column type is not supported for binary remapping: {type}. Supported types: key, float, double.";
}
예제 #2
0
        /// <summary>
        /// Reports whether <paramref name="type"/> is a supported column type: Boolean, Single, Double,
        /// Int32, or a key type whose count is non-zero and below <c>Utils.ArrayMaxSize</c>.
        /// </summary>
        /// <param name="type">The column type to test.</param>
        /// <returns><c>true</c> if the type is supported; otherwise <c>false</c>.</returns>
        internal static bool IsValidColumnType(DataViewType type)
        {
            // REVIEW: Consider supporting all integer and unsigned types.
            if (type is BooleanDataViewType)
            {
                return true;
            }

            if (type == NumberDataViewType.Single || type == NumberDataViewType.Double || type == NumberDataViewType.Int32)
            {
                return true;
            }

            // Remaining candidates must be keys with a bounded, non-zero count.
            ulong count = type.GetKeyCount();
            return 0 < count && count < Utils.ArrayMaxSize;
        }
예제 #3
0
        /// <summary>
        /// Reports whether a column of type <paramref name="type"/> can be filtered by range:
        /// Single, Double, or a key type with a non-zero count.
        /// </summary>
        /// <param name="ectx">Exception context used to validate the arguments.</param>
        /// <param name="type">The column type to test; must not be null.</param>
        /// <returns><c>true</c> if the type is supported; otherwise <c>false</c>.</returns>
        public static bool IsValidRangeFilterColumnType(IExceptionContext ectx, DataViewType type)
        {
            ectx.CheckValue(type, nameof(type));

            bool isFloatingPoint = type == NumberDataViewType.Single || type == NumberDataViewType.Double;
            if (isFloatingPoint)
            {
                return true;
            }

            // Key columns qualify when their key count is non-zero.
            return type.GetKeyCount() > 0;
        }
예제 #4
0
        /// <summary>
        /// Loads a range filter from the model context <paramref name="ctx"/>, binding it to the column of
        /// <paramref name="input"/> whose name is stored in the model.
        /// </summary>
        /// <param name="host">Host environment passed to the base class.</param>
        /// <param name="ctx">Model load context to read the serialized state from; must not be null.</param>
        /// <param name="input">The data view being filtered.</param>
        /// <remarks>
        /// Throws a schema-mismatch exception if the stored column is absent from the input, or if it is not
        /// Single, Double or a key type. Throws if the stored min is not less than or equal to the stored max.
        /// </remarks>
        private RangeFilter(IHost host, ModelLoadContext ctx, IDataView input)
            : base(host, input)
        {
            Host.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: sizeof(Float)
            // int: id of column name
            // double: min
            // double: max
            // byte: complement
            // byte: includeMin
            // byte: includeMax
            int cbFloat = ctx.Reader.ReadInt32();

            Host.CheckDecode(cbFloat == sizeof(float));

            var column = ctx.LoadNonEmptyString();
            var schema = Source.Schema;

            if (!schema.TryGetColumnIndex(column, out _index))
            {
                throw Host.ExceptSchemaMismatch(nameof(schema), "source", column);
            }

            _type = schema[_index].Type;
            if (_type != NumberDataViewType.Single && _type != NumberDataViewType.Double && _type.GetKeyCount() == 0)
            {
                throw Host.ExceptSchemaMismatch(nameof(schema), "source", column, "Single, Double or Key", _type.ToString());
            }

            _min = ctx.Reader.ReadDouble();
            _max = ctx.Reader.ReadDouble();
            // Negated comparison so that a NaN bound also fails the check.
            if (!(_min <= _max))
            {
                // Pass the message as the sole argument. The previous call passed "min" as the format
                // string, so the real message was an ignored format argument and the exception text
                // was only "min".
                throw Host.Except("min must be less than or equal to max");
            }
            _complement = ctx.Reader.ReadBoolByte();
            _includeMin = ctx.Reader.ReadBoolByte();
            _includeMax = ctx.Reader.ReadBoolByte();
        }
예제 #5
0
            /// <summary>
            /// Reads the single slot of column <paramref name="labelCol"/> from <paramref name="trans"/>,
            /// maps every label value to a non-negative integer bin, and stores the binned labels in
            /// <c>_labels</c> and the number of bins in <c>_numLabels</c>.
            /// </summary>
            /// <param name="trans">Transposer used to fetch the whole label column as one slot vector.</param>
            /// <param name="labelType">
            /// Type of the label column. Int32, Single, Double and Boolean are binned here. Any other type
            /// goes through the key-label path (presumably a key type, per the assert there).
            /// </param>
            /// <param name="labelCol">Index of the label column within <paramref name="trans"/>.</param>
            private void GetLabels(Transposer trans, DataViewType labelType, int labelCol)
            {
                int min;
                int lim;
                var labels = default(VBuffer <int>);

                // Note: NAs have their own separate bin.
                // The numeric and Boolean branches bin values into integers below lim, starting at min
                // (so _numLabels = lim - min). The shift after the if/else chain moves them to start at 0.
                if (labelType == NumberDataViewType.Int32)
                {
                    var tmp = default(VBuffer <int>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinInts(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType == NumberDataViewType.Single)
                {
                    var tmp = default(VBuffer <Single>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinSingles(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType == NumberDataViewType.Double)
                {
                    var tmp = default(VBuffer <Double>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinDoubles(in tmp, ref labels, _numBins, out min, out lim);
                    _numLabels = lim - min;
                }
                else if (labelType is BooleanDataViewType)
                {
                    var tmp = default(VBuffer <bool>);
                    trans.GetSingleSlotValue(labelCol, ref tmp);
                    BinBools(in tmp, ref labels);
                    // Three fixed bins: -1, 0 and 1.
                    // NOTE(review): presumably NA, false and true; confirm against BinBools.
                    _numLabels = 3;
                    min        = -1;
                    lim        = 2;
                }
                else
                {
                    // Fallback: the label is expected to be a key type whose count fits in an array.
                    ulong labelKeyCount = labelType.GetKeyCount();
                    Contracts.Assert(labelKeyCount < Utils.ArrayMaxSize);
                    // GetKeyLabels is generic over the raw representation of the label column. That
                    // representation is only known at run time (labelType.RawType), so the generic method
                    // is closed over it via reflection and invoked on this instance.
                    KeyLabelGetter <int> del = GetKeyLabels <int>;
                    var methodInfo           = del.GetMethodInfo().GetGenericMethodDefinition().MakeGenericMethod(labelType.RawType);
                    var parameters           = new object[] { trans, labelCol, labelType };
                    _labels    = (VBuffer <int>)methodInfo.Invoke(this, parameters);
                    // One bin per key value plus one extra.
                    // NOTE(review): the extra bin is presumably for missing keys; confirm in GetKeyLabels.
                    _numLabels = labelType.GetKeyCountAsInt32(_host) + 1;

                    // No need to densify or shift in this case.
                    return;
                }

                // Densify and shift labels.
                // After subtracting min, every bin index must be below _numLabels (asserted per element).
                VBufferUtils.Densify(ref labels);
                Contracts.Assert(labels.IsDense);
                var labelsEditor = VBufferEditor.CreateFromBuffer(ref labels);

                for (int i = 0; i < labels.Length; i++)
                {
                    labelsEditor.Values[i] -= min;
                    Contracts.Assert(labelsEditor.Values[i] < _numLabels);
                }
                _labels = labelsEditor.Commit();
            }