示例#1
0
        /// <summary>
        /// Splits a column type into the vector kind, item type and key flag used to describe a schema-shape column.
        /// </summary>
        /// <param name="type">The column type to decompose.</param>
        /// <param name="vecKind">
        /// <c>Vector</c> for known-size vectors, <c>VariableVector</c> for any other vector, <c>Scalar</c> otherwise.
        /// </param>
        /// <param name="itemType">
        /// The item type of <paramref name="type"/>; when that item type is a key, the primitive type of its raw kind instead.
        /// </param>
        /// <param name="isKey">Whether the item type of <paramref name="type"/> is a key type.</param>
        internal static void GetColumnTypeShape(ColumnType type,
                                                out Column.VectorKind vecKind,
                                                out ColumnType itemType,
                                                out bool isKey)
        {
            vecKind = type.IsKnownSizeVector ? Column.VectorKind.Vector
                : type.IsVector ? Column.VectorKind.VariableVector
                : Column.VectorKind.Scalar;

            var rawItemType = type.ItemType;
            isKey = rawItemType.IsKey;

            // Key-ness is reported separately through isKey, so a key item is exposed as its underlying primitive type.
            itemType = isKey ? PrimitiveType.FromKind(rawItemType.RawKind) : rawItemType;
        }
示例#2
0
        /// <summary>
        /// Converts the pipelines evaluated so far by an AutoML sweep into a data view with
        /// "Graph" (text), "MetricValue" (R8) and "PipelineId" (text) columns.
        /// </summary>
        /// <param name="env">Host environment used to build the output view.</param>
        /// <param name="input">Input whose <c>State</c> must be an <c>AutoInference.AutoMlMlState</c>.</param>
        /// <returns>An <c>Output</c> carrying the results view (empty when nothing was evaluated) and the same state.</returns>
        public static Output ExtractSweepResult(IHostEnvironment env, ResultInput input)
        {
            if (!(input.State is AutoInference.AutoMlMlState autoMlState))
            {
                throw env.Except("The state must be a valid AutoMlState.");
            }

            var resultRows = autoMlState.GetAllEvaluatedPipelines().Select(p => p.ToResultRow()).ToList();
            var graphColumn = new KeyValuePair<string, ColumnType>("Graph", TextType.Instance);
            var metricColumn = new KeyValuePair<string, ColumnType>("MetricValue", PrimitiveType.FromKind(DataKind.R8));
            var pipelineIdColumn = new KeyValuePair<string, ColumnType>("PipelineId", TextType.Instance);

            IDataView resultsView;
            if (resultRows.Count > 0)
            {
                var builder = new ArrayDataViewBuilder(env);
                builder.AddColumn(graphColumn.Key, (PrimitiveType)graphColumn.Value, resultRows.Select(r => new DvText(r.GraphJson)).ToArray());
                builder.AddColumn(metricColumn.Key, (PrimitiveType)metricColumn.Value, resultRows.Select(r => r.MetricValue).ToArray());
                builder.AddColumn(pipelineIdColumn.Key, (PrimitiveType)pipelineIdColumn.Value, resultRows.Select(r => new DvText(r.PipelineId)).ToArray());
                resultsView = builder.GetDataView();
            }
            else
            {
                // With no rows, emit an empty view that still exposes the same three-column schema.
                var host = env.Register("ExtractSweepResult");
                resultsView = new EmptyDataView(host, new SimpleSchema(host, graphColumn, metricColumn, pipelineIdColumn));
            }

            return new Output { Results = resultsView, State = autoMlState };
        }
        /// <summary>
        /// Reads a column type from <paramref name="ctx"/>. The layout read is: a boolean vector flag; for a
        /// vector, an int dimension count (only 1 is supported), one int size per dimension and a byte item kind;
        /// for a scalar, just a byte kind.
        /// </summary>
        /// <param name="ctx">Model load context whose reader is positioned at the serialized type.</param>
        /// <returns>A one-dimensional <c>VectorType</c> or the scalar type produced by <c>FromKind</c>.</returns>
        public static ColumnType ReadType(ModelLoadContext ctx)
        {
            var reader = ctx.Reader;

            if (!reader.ReadBoolean())
            {
                var scalarKind = (DataKind)reader.ReadByte();
                return FromKind(scalarKind);
            }

            int dimensionCount = reader.ReadInt32();
            if (dimensionCount != 1)
            {
                throw Contracts.ExceptNotImpl("Number of dimensions should be 1.");
            }

            // Exactly one dimension follows, so a single size is read.
            int size = reader.ReadInt32();
            var itemKind = (DataKind)reader.ReadByte();
            return new VectorType(PrimitiveType.FromKind(itemKind), size);
        }
示例#4
0
 /// <summary>
 /// Converts an Onnx type, expressed as a System.Type,
 /// to the type system ML.NET recognizes (e.g. I4, I8, R4 etc.).
 /// </summary>
 /// <param name="type">The CLR type to convert; it must be a key of the type-to-kind map.</param>
 /// <returns>The ML.NET primitive type for the <see cref="DataKind"/> mapped from <paramref name="type"/>.</returns>
 /// <remarks>Throws the exception built by <c>Contracts.ExceptNotSupp</c> when <paramref name="type"/> is not mapped.</remarks>
 public static PrimitiveType OnnxToMlNetType(System.Type type)
 {
     // A single lookup replaces the former ContainsKey check followed by a second indexer lookup.
     if (!_typeToKindMap.TryGetValue(type, out var kind))
     {
         throw Contracts.ExceptNotSupp("Onnx type not supported", type);
     }
     return PrimitiveType.FromKind(kind);
 }
示例#5
0
        /// <summary>
        /// Maps an Onnx <c>DataType</c> element type to the corresponding ML.NET primitive type.
        /// </summary>
        /// <param name="type">The Onnx element type.</param>
        /// <returns>The primitive type for the matching <see cref="DataKind"/>.</returns>
        /// <remarks>
        /// <c>Type_Invalid</c> and every type not listed below are rejected via <c>Contracts.ExceptNotSupp</c>.
        /// </remarks>
        public static PrimitiveType OnnxToMlNetType(DataType type)
        {
            switch (type)
            {
                case DataType.Type_Float:
                    return PrimitiveType.FromKind(DataKind.R4);
                case DataType.Type_Double:
                    return PrimitiveType.FromKind(DataKind.R8);
                case DataType.Type_Int8:
                    return PrimitiveType.FromKind(DataKind.I1);
                case DataType.Type_Int16:
                    return PrimitiveType.FromKind(DataKind.I2);
                case DataType.Type_Int32:
                    return PrimitiveType.FromKind(DataKind.I4);
                case DataType.Type_Int64:
                    return PrimitiveType.FromKind(DataKind.I8);
                case DataType.Type_Uint8:
                    return PrimitiveType.FromKind(DataKind.U1);
                case DataType.Type_Uint16:
                    return PrimitiveType.FromKind(DataKind.U2);
                case DataType.Type_String:
                    return PrimitiveType.FromKind(DataKind.TX);
                case DataType.Type_Bool:
                    return PrimitiveType.FromKind(DataKind.BL);
                case DataType.Type_Invalid:
                default:
                    throw Contracts.ExceptNotSupp("Onnx type not supported", type);
            }
        }
        /// <summary>
        /// Helper function to retrieve the Primitive type given a Type.
        /// Array types are unwrapped to their element type, and that fact is reported through
        /// <paramref name="isVectorType"/>.
        /// </summary>
        /// <param name="rawType">The CLR type, or an array of it.</param>
        /// <param name="isVectorType">Set to true when <paramref name="rawType"/> is an array.</param>
        /// <returns>The primitive type for the (element) type's <see cref="DataKind"/>.</returns>
        /// <exception cref="InvalidOperationException">The (element) type has no matching <see cref="DataKind"/>.</exception>
        internal static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorType)
        {
            isVectorType = rawType.IsArray;
            Type itemType = isVectorType ? rawType.GetElementType() : rawType;

            if (itemType.TryGetDataKind(out DataKind kind))
            {
                return PrimitiveType.FromKind(kind);
            }

            throw new InvalidOperationException($"Unsupported type {itemType} used in mapping.");
        }
示例#7
0
        /// <summary>
        /// Create a schema shape out of the fully defined schema.
        /// Hidden columns are skipped. Key item types are recorded as their raw primitive kind with the key flag set,
        /// and each column keeps the names of its metadata kinds.
        /// </summary>
        /// <param name="schema">The schema to describe; must not be null.</param>
        public static SchemaShape Create(ISchema schema)
        {
            Contracts.CheckValue(schema, nameof(schema));
            var shapeColumns = new List<Column>();

            for (int col = 0; col < schema.ColumnCount; col++)
            {
                if (schema.IsHidden(col))
                {
                    continue;
                }

                var colType = schema.GetColumnType(col);
                Column.VectorKind vectorKind = colType.IsKnownSizeVector ? Column.VectorKind.Vector
                    : colType.IsVector ? Column.VectorKind.VariableVector
                    : Column.VectorKind.Scalar;

                var rawItemType = colType.ItemType;
                bool keyed = rawItemType.IsKey;
                ColumnType shapeItemType = keyed ? PrimitiveType.FromKind(rawItemType.RawKind) : rawItemType;

                string[] metadataKinds = schema.GetMetadataTypes(col)
                                               .Select(pair => pair.Key)
                                               .ToArray();
                shapeColumns.Add(new Column(schema.GetColumnName(col), vectorKind, shapeItemType, keyed, metadataKinds));
            }

            return new SchemaShape(shapeColumns.ToArray());
        }
示例#8
0
        /// <summary>
        /// Builds the getter for output column 0, the vector of scores predicted by the XGBoost booster.
        /// When a class mapping is present, the dense predictions are re-exposed as a sparse vector whose
        /// indices are taken from <c>_classMapping</c>.
        /// </summary>
        /// <returns>A one-element array that holds the getter, or null when column 0 is inactive.</returns>
        protected override Delegate[] CreatePredictionGetters(Booster xgboostModel, IRow input, Func<int, bool> predicate)
        {
            var active = Utils.BuildArray(OutputSchema.ColumnCount, predicate);
            xgboostModel.LazyInit();

            var result = new Delegate[1];
            if (!active[0])
            {
                return result;
            }

            var getFeatures = RowCursorUtils.GetVecGetterAs<Float>(PrimitiveType.FromKind(DataKind.R4), input, InputSchema.Feature.Index);
            VBuffer<Float> featureBuffer = default(VBuffer<Float>);
            var postProcess = Parent.GetOutputPostProcessor();
            int featureCount = input.Schema.GetColumnType(InputSchema.Feature.Index).VectorSize;
            var boosterScratch = Booster.CreateInternalBuffer();
            int mappedClassCount = _classMapping == null ? 0 : _numberOfClasses;

            if (mappedClassCount != 0)
            {
                ValueGetter<VBuffer<Float>> mappedScores = (ref VBuffer<Float> dst) =>
                {
                    getFeatures(ref featureBuffer);
                    Contracts.Assert(featureBuffer.Length == featureCount);
                    xgboostModel.Predict(ref featureBuffer, ref dst, ref boosterScratch);
                    Contracts.Assert(dst.IsDense);
                    postProcess(ref dst);

                    // The indices array of the previous buffer is reused when it is large enough.
                    int mappingLength = _classMapping.Length;
                    var slotIndices = dst.Indices;
                    if (slotIndices == null || slotIndices.Length < mappingLength)
                    {
                        slotIndices = new int[mappingLength];
                    }
                    Array.Copy(_classMapping, slotIndices, mappingLength);
                    dst = new VBuffer<float>(mappedClassCount, mappingLength, dst.Values, slotIndices);
                };
                result[0] = mappedScores;
            }
            else
            {
                ValueGetter<VBuffer<Float>> rawScores = (ref VBuffer<Float> dst) =>
                {
                    getFeatures(ref featureBuffer);
                    Contracts.Assert(featureBuffer.Length == featureCount);
                    xgboostModel.Predict(ref featureBuffer, ref dst, ref boosterScratch);
                    postProcess(ref dst);
                };
                result[0] = rawScores;
            }

            return result;
        }
        /// <summary>
        /// Logs, through a channel, which <see cref="DataKind"/> values have standard conversions to which others.
        /// Source kinds that convert to exactly the same set of destination kinds share one
        /// <c>sources | destinations</c> info line. Kinds without an identity conversion are reported in one warning.
        /// </summary>
        public void Run()
        {
            using (var ch = _host.Start("Run"))
            {
                var conv = Conversions.Instance;
                // Maps each distinct destination-kind set to the source kinds whose conversions reach exactly that set.
                var dstToSrcMap = new Dictionary<HashSet<DataKind>, HashSet<DataKind>>(new SetOfKindsComparer());
                var srcToDstMap = new Dictionary<DataKind, HashSet<DataKind>>();

                // Distinct() guards against enum members that share an underlying value.
                var kinds = Enum.GetValues(typeof(DataKind)).Cast<DataKind>().Distinct().OrderBy(k => k).ToArray();
                var types = kinds.Select(kind => PrimitiveType.FromKind(kind)).ToArray();

                HashSet<DataKind> nonIdentity = null;
                foreach (var srcType in types)
                {
                    ch.AssertValue(srcType);
                    // The return value is not needed here; the call is kept for whatever KindReport emits.
                    // NOTE(review): confirm KindReport is invoked only for its side effects.
                    Utils.MarshalInvoke(KindReport<int>, srcType.RawType, ch, srcType);

                    Delegate del;
                    bool isIdentity;
                    var dstKinds = new HashSet<DataKind>();
                    foreach (var dstType in types)
                    {
                        if (conv.TryGetStandardConversion(srcType, dstType, out del, out isIdentity))
                        {
                            dstKinds.Add(dstType.RawKind);
                        }
                    }

                    // A kind is expected to convert to itself through an identity conversion; collect any that do not.
                    if (!conv.TryGetStandardConversion(srcType, srcType, out del, out isIdentity))
                    {
                        Utils.Add(ref nonIdentity, srcType.RawKind);
                    }
                    else
                    {
                        ch.Assert(isIdentity);
                    }

                    srcToDstMap[srcType.RawKind] = dstKinds;
                    if (!dstToSrcMap.TryGetValue(dstKinds, out var srcKinds))
                    {
                        srcKinds = new HashSet<DataKind>();
                        dstToSrcMap[dstKinds] = srcKinds;
                    }
                    srcKinds.Add(srcType.RawKind);
                }

                // Each destination set is printed once. Its entry is removed after printing, so later source kinds
                // that share the same set skip it.
                foreach (var srcKind in kinds)
                {
                    var dsts = srcToDstMap[srcKind];
                    if (!dstToSrcMap.TryGetValue(dsts, out var srcs))
                    {
                        continue;
                    }
                    ch.Assert(Utils.Size(dsts) >= 1);
                    ch.Assert(Utils.Size(srcs) >= 1);
                    string srcStrings = string.Join(", ", srcs.OrderBy(k => k).Select(k => '`' + k.GetString() + '`'));
                    string dstStrings = string.Join(", ", dsts.OrderBy(k => k).Select(k => '`' + k.GetString() + '`'));
                    dstToSrcMap.Remove(dsts);
                    ch.Info(srcStrings + " | " + dstStrings);
                }

                if (Utils.Size(nonIdentity) > 0)
                {
                    ch.Warning("The following kinds did not have an identity conversion: {0}",
                               string.Join(", ", nonIdentity.OrderBy(k => k).Select(DataKindExtensions.GetString)));
                }
            }
        }
示例#10
0
        /// <summary>
        /// Builds the internal schema definition for <paramref name="userType"/>. It uses
        /// <paramref name="userSchemaDefinition"/> when one is supplied, and otherwise derives one with <c>SchemaDefinition.Create</c>.
        /// </summary>
        /// <param name="userType">The user type whose fields (or computed members) back the columns.</param>
        /// <param name="userSchemaDefinition">Optional explicit definition; null means derive it from <paramref name="userType"/>.</param>
        /// <returns>
        /// The definition, with one column per member. Fields of type <c>IChannel</c> produce no column, so the result
        /// can be shorter than <paramref name="userSchemaDefinition"/>.
        /// </returns>
        public static InternalSchemaDefinition Create(Type userType, SchemaDefinition userSchemaDefinition = null)
        {
            Contracts.AssertValue(userType);
            Contracts.AssertValueOrNull(userSchemaDefinition);

            if (userSchemaDefinition == null)
            {
                userSchemaDefinition = SchemaDefinition.Create(userType);
            }

            // The array is sized for the worst case and filled from the front. IChannel fields are skipped, so the
            // array is trimmed at the end rather than left with null entries where those fields would have been.
            Column[] dstCols = new Column[userSchemaDefinition.Count];
            int dstCount = 0;

            for (int i = 0; i < userSchemaDefinition.Count; ++i)
            {
                var col = userSchemaDefinition[i];
                if (col.MemberName == null)
                {
                    throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Null field name detected in schema definition");
                }

                bool      isVector;
                DataKind  kind;
                FieldInfo fieldInfo = null;

                if (!col.IsComputed)
                {
                    fieldInfo = userType.GetField(col.MemberName);

                    if (fieldInfo == null)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No field with name '{0}' found in type '{1}'",
                                                    col.MemberName,
                                                    userType.FullName);
                    }

                    // Clause to handle the field that may be used to expose the cursor channel.
                    // This field does not need a column.
                    if (fieldInfo.FieldType == typeof(IChannel))
                    {
                        continue;
                    }

                    GetVectorAndKind(fieldInfo, out isVector, out kind);
                }
                else
                {
                    var parameterType = col.ReturnType;
                    if (parameterType == null)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "No return parameter found in computed column.");
                    }
                    GetVectorAndKind(parameterType, "returnType", out isVector, out kind);
                }
                // Infer the column name.
                var colName = string.IsNullOrEmpty(col.ColumnName) ? col.MemberName : col.ColumnName;
                // REVIEW: Because order is defined, we allow duplicate column names, since producing an IDataView
                // with duplicate column names is completely legal. Possible objection is that we should make it less
                // convenient to produce "hidden" columns, since this may not be of practical use to users.

                ColumnType colType;
                if (col.ColumnType == null)
                {
                    // Infer a type as best we can.
                    PrimitiveType itemType = PrimitiveType.FromKind(kind);
                    colType = isVector ? new VectorType(itemType) : (ColumnType)itemType;
                }
                else
                {
                    // Make sure that the types are compatible with the declared type, including
                    // whether it is a vector type.
                    if (isVector != col.ColumnType.IsVector)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to be {1}, but type of associated field '{2}' is {3}",
                                                    colName, col.ColumnType.IsVector ? "vector" : "scalar", col.MemberName, isVector ? "vector" : "scalar");
                    }
                    if (kind != col.ColumnType.ItemType.RawKind)
                    {
                        throw Contracts.ExceptParam(nameof(userSchemaDefinition), "Column '{0}' is supposed to have item kind {1}, but associated field has kind {2}",
                                                    colName, col.ColumnType.ItemType.RawKind, kind);
                    }
                    colType = col.ColumnType;
                }

                dstCols[dstCount++] = col.IsComputed
                    ? new Column(colName, colType, col.Generator, col.Metadata)
                    : new Column(colName, colType, fieldInfo, col.Metadata);
            }

            if (dstCount < dstCols.Length)
            {
                Array.Resize(ref dstCols, dstCount);
            }
            return new InternalSchemaDefinition(dstCols);
        }
示例#11
0
        /// <summary>
        /// Checks that ColumnType equality and hashing agree with ToString. Types that compare equal must
        /// render identically. This is tested over primitive types, key types and vectors of both, and over
        /// image types, where any repeated key at all counts as a failure.
        /// </summary>
        public void TestEqualAndGetHashCode()
        {
            var dict = new Dictionary<ColumnType, string>();

            // Records `type`, failing if an equal key is already present under a different string form.
            void AddUnique(ColumnType type)
            {
                string text = type.ToString();
                if (dict.TryGetValue(type, out string previous) && previous != text)
                {
                    Assert.True(false, previous + " and " + text + " are duplicates.");
                }
                dict[type] = text;
            }

            // Records `itemType` plus vectors of it with shapes (size) and (size, size1) for sizes 0..4.
            void AddWithVectors(PrimitiveType itemType)
            {
                AddUnique(itemType);
                for (int size = 0; size < 5; size++)
                {
                    AddUnique(new VectorType(itemType, size));
                    for (int size1 = 0; size1 < 5; size1++)
                    {
                        AddUnique(new VectorType(itemType, size, size1));
                    }
                }
            }

            // Add PrimitiveTypes, KeyTypes and the corresponding VectorTypes.
            foreach (var kind in (DataKind[])Enum.GetValues(typeof(DataKind)))
            {
                AddWithVectors(PrimitiveType.FromKind(kind));

                if (!KeyType.IsValidDataKind(kind))
                {
                    continue;
                }
                for (ulong min = 0; min < 5; min++)
                {
                    for (var count = 0; count < 5; count++)
                    {
                        AddWithVectors(new KeyType(kind, min, count));
                    }
                    AddWithVectors(new KeyType(kind, min, 0, false));
                }
            }

            // Add ImageTypes. Unlike the cases above, any repeated key fails, even with an identical string form.
            for (int height = 1; height < 5; height++)
            {
                for (int width = 1; width < 5; width++)
                {
                    var imageType = new ImageType(height, width);
                    if (dict.TryGetValue(imageType, out string existing))
                    {
                        Assert.True(false, existing + " and " + imageType.ToString() + " are duplicates.");
                    }
                    dict[imageType] = imageType.ToString();
                }
            }
        }
        /// <summary>
        /// Builds the getter for output column 0: a single score, taken as the first value the XGBoost booster
        /// predicts for the row's features and then passed through the output post-processor.
        /// </summary>
        /// <returns>A one-element array that holds the getter, or null when column 0 is inactive.</returns>
        protected override Delegate[] CreatePredictionGetters(Booster xgboostModel, IRow input, Func<int, bool> predicate)
        {
            var active = Utils.BuildArray(OutputSchema.ColumnCount, predicate);
            xgboostModel.LazyInit();

            var result = new Delegate[1];
            if (!active[0])
            {
                return result;
            }

            var getFeatures = RowCursorUtils.GetVecGetterAs<Float>(PrimitiveType.FromKind(DataKind.R4), input, InputSchema.Feature.Index);
            VBuffer<Float> featureBuffer = default(VBuffer<Float>);
            var postProcess = Parent.GetOutputPostProcessor();
            // Scratch buffer for the raw booster output; it is reused across calls.
            VBuffer<Float> scores = default(VBuffer<Float>);
            int featureCount = input.Schema.GetColumnType(InputSchema.Feature.Index).VectorSize;
            var boosterScratch = Booster.CreateInternalBuffer();

            ValueGetter<Float> scoreGetter = (ref Float dst) =>
            {
                getFeatures(ref featureBuffer);
                Contracts.Assert(featureBuffer.Length == featureCount);
                xgboostModel.Predict(ref featureBuffer, ref scores, ref boosterScratch);
                dst = scores.Values[0];
                postProcess(ref dst);
            };

            result[0] = scoreGetter;
            return result;
        }
示例#13
0
        /// <summary>
        /// Determines the destination item type for converting the column described by <paramref name="info"/> to
        /// <paramref name="kind"/>, as a key type over <paramref name="range"/> when one is given, and checks that a
        /// standard conversion from the source item type exists.
        /// </summary>
        /// <param name="itemType">Always assigned, including on failure paths.</param>
        /// <param name="ex">The column info for the destination; null whenever false is returned.</param>
        /// <returns>True when the conversion is legal.</returns>
        private static bool TryCreateEx(IExceptionContext ectx, ColInfo info, DataKind kind, KeyRange range, out PrimitiveType itemType, out ColInfoEx ex)
        {
            ectx.AssertValue(info);
            ectx.Assert(Enum.IsDefined(typeof(DataKind), kind));

            ex = null;
            var srcType = info.TypeSrc;
            bool accepted;

            if (range != null)
            {
                // An explicit key range is only honored for key or text sources.
                itemType = TypeParsingUtils.ConstructKeyType(kind, range);
                accepted = srcType.ItemType.IsKey || srcType.ItemType.IsText;
            }
            else if (srcType.ItemType.IsKey && KeyType.IsValidDataKind(kind))
            {
                // Key to key: keep the source key's min and contiguity, clamping its count to the largest
                // value representable by the destination kind.
                var srcKey = srcType.ItemType.AsKey;
                ectx.Assert(KeyType.IsValidDataKind(srcKey.RawKind));
                // Mismatched counts are technically an error, but the Conversions lookup below is left to reject
                // them. If the standard conversions ever map out-of-bounds values to zero, this clamping is correct.
                int keyCount = srcKey.Count;
                ulong kindMax = kind.ToMaxInt();
                if ((ulong)keyCount > kindMax)
                {
                    keyCount = (int)kindMax;
                }
                itemType = new KeyType(kind, srcKey.Min, keyCount, srcKey.Contiguous);
                accepted = true;
            }
            else
            {
                // Plain primitive destination. A key source only reaches this branch when kind is not a valid key
                // kind, and that combination is rejected.
                itemType = PrimitiveType.FromKind(kind);
                accepted = !srcType.ItemType.IsKey;
            }

            if (!accepted)
            {
                return false;
            }

            // This only checks that the conversion is legal. The delegate is not cached here; the utils code
            // fetches it again when needed.
            Delegate conversion;
            bool isIdentity;
            if (!Runtime.Data.Conversion.Conversions.Instance.TryGetStandardConversion(srcType.ItemType, itemType, out conversion, out isIdentity))
            {
                return false;
            }

            // Vector shape and transposability (presence of a slot type) carry over from the source column.
            ColumnType dstType = srcType.IsVector ? new VectorType(itemType, srcType.AsVector) : (ColumnType)itemType;
            VectorType dstSlotType = info.SlotTypeSrc != null ? new VectorType(itemType, info.SlotTypeSrc) : null;

            ex = new ColInfoEx(kind, range != null, dstType, dstSlotType);
            return true;
        }