Пример #1
0
 /// <summary>
 /// Converts an ONNX type, expressed as a <see cref="System.Type"/>,
 /// to the type system ML.NET recognizes (e.g. I4, I8, R4 etc.).
 /// </summary>
 /// <param name="type">The CLR type describing the ONNX value type; must be a key of <c>_typeToKindMap</c>.</param>
 /// <returns>The ML.NET primitive data view type built from the mapped data kind.</returns>
 /// <remarks>
 /// Throws the exception produced by <c>Contracts.ExceptNotSupp</c> when <paramref name="type"/>
 /// has no entry in <c>_typeToKindMap</c>.
 /// </remarks>
 public static PrimitiveDataViewType OnnxToMlNetType(Type type)
 {
     // A single TryGetValue avoids the ContainsKey + indexer double lookup.
     if (!_typeToKindMap.TryGetValue(type, out var kind))
     {
         throw Contracts.ExceptNotSupp("Onnx type not supported", type);
     }
     return ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
 }
Пример #2
0
        /// <summary>
        /// Rebuilds an <see cref="ExpressionTransformer"/> from a serialized model context.
        /// </summary>
        private static ExpressionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // Serialized layout, consumed strictly in this order:
            //   int          number of output columns (must be > 0)
            //   per output column:
            //     int        number of inputs n (must be >= 0)
            //     n x int    id of each input column name
            //     int        id of the expression text
            //     int        id of the output column name
            //     int        index of the vector input, or -1
            //     n x int    data-kind index of each input column
            var reader = ctx.Reader;

            int count = reader.ReadInt32();
            env.CheckDecode(count > 0);

            var infos = new ColumnInfo[count];
            for (int col = 0; col < infos.Length; col++)
            {
                int arity = reader.ReadInt32();
                env.CheckDecode(arity >= 0);

                var names = new string[arity];
                for (int k = 0; k < names.Length; k++)
                {
                    names[k] = ctx.LoadNonEmptyString();
                }

                string expr = ctx.LoadNonEmptyString();
                string outName = ctx.LoadNonEmptyString();
                int vecIndex = reader.ReadInt32();
                env.CheckDecode(vecIndex >= -1);

                var types = new DataViewType[arity];
                for (int k = 0; k < types.Length; k++)
                {
                    var storedKind = InternalDataKindExtensions.FromIndex(reader.ReadInt32());
                    types[k] = ColumnTypeExtensions.PrimitiveTypeFromKind(storedKind);
                }

                // Re-parse the stored expression against the recovered input types.
                var lambda = ExpressionEstimator.ParseAndBindLambda(env, expr, vecIndex, types, out var permutation);
                infos[col] = new ColumnInfo(env, names, types, expr, outName, vecIndex, lambda, permutation);
            }

            return new ExpressionTransformer(env, infos);
        }
Пример #3
0
        /// <summary>
        /// Helper function to retrieve the primitive type for a given <see cref="Type"/>.
        /// Array types are unwrapped to their element type before the lookup.
        /// </summary>
        /// <param name="rawType">The type to map; may be an array type.</param>
        /// <param name="isVectorType">Set to <c>true</c> when <paramref name="rawType"/> is an array, otherwise <c>false</c>.</param>
        /// <returns>The primitive type corresponding to the data kind of the (element) type.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rawType"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException">The (element) type has no associated data kind.</exception>
        internal static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorType)
        {
            // Fail with a descriptive argument error instead of a NullReferenceException on rawType.IsArray.
            if (rawType == null)
            {
                throw new ArgumentNullException(nameof(rawType));
            }

            // Array types are flagged as vectors; the data kind comes from their element type.
            isVectorType = rawType.IsArray;
            Type type = isVectorType ? rawType.GetElementType() : rawType;

            if (!type.TryGetDataKind(out DataKind kind))
            {
                throw new InvalidOperationException($"Unsupported type {type} used in mapping.");
            }

            return ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }
        /// <summary>
        /// Parses a comma-separated list of data kind names that starts at <paramref name="ichMin"/>
        /// and ends at the next ':' in <paramref name="text"/>. On return, <paramref name="ichMin"/>
        /// points just past that ':'.
        /// </summary>
        private DataViewType[] ParseTypes(string text, ref int ichMin, int ichLim)
        {
            int ichColon = text.IndexOf(':', ichMin);
            Contracts.Assert(ichMin < ichColon && ichColon < ichLim);

            string segment = text.Substring(ichMin, ichColon - ichMin);
            string[] names = segment.Split(',');
            var types = new DataViewType[names.Length];

            int index = 0;
            foreach (string name in names)
            {
                // NOTE(review): validity is only checked via Contracts.Assert; if that assert is
                // compiled out in release builds, an unknown name yields default(InternalDataKind) — confirm.
                bool parsed = Enum.TryParse(name, out InternalDataKind kind);
                Contracts.Assert(parsed);
                types[index++] = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
            }

            // Advance the caller's cursor past the ':' separator.
            ichMin = ichColon + 1;
            return types;
        }
Пример #5
0
        /// <summary>
        /// Logs, for every <see cref="DataKind"/>, the set of kinds it has a standard conversion to.
        /// Source kinds that share the same target set are reported together on one
        /// "sources | targets" line. Kinds without an identity conversion are reported in a warning.
        /// </summary>
        public void Run()
        {
            using (var ch = _host.Start("Run"))
            {
                var conversions = Conversions.Instance;
                // Groups of source kinds, keyed by the exact set of target kinds they reach.
                var sourcesByTargets = new Dictionary<HashSet<DataKind>, HashSet<DataKind>>(new SetOfKindsComparer());
                var targetsBySource = new Dictionary<DataKind, HashSet<DataKind>>();

                var kinds = Enum.GetValues(typeof(DataKind))
                    .Cast<DataKind>()
                    .Distinct()
                    .OrderBy(k => k)
                    .ToArray();
                var types = kinds
                    .Select(k => ColumnTypeExtensions.PrimitiveTypeFromKind(k))
                    .ToArray();

                HashSet<DataKind> nonIdentity = null;
                foreach (var srcType in types)
                {
                    ch.AssertValue(srcType);
                    // The result is not used here; the call is kept for whatever KindReport does
                    // (presumably reporting on the type — confirm).
                    _ = Utils.MarshalInvoke(KindReport<int>, srcType.RawType, ch, srcType);

                    var reachable = new HashSet<DataKind>();
                    foreach (var dstType in types)
                    {
                        if (conversions.TryGetStandardConversion(srcType, dstType, out Delegate _, out bool _))
                        {
                            reachable.Add(dstType.GetRawKind());
                        }
                    }

                    var srcKind = srcType.GetRawKind();
                    if (conversions.TryGetStandardConversion(srcType, srcType, out Delegate _, out bool identity))
                    {
                        ch.Assert(identity);
                    }
                    else
                    {
                        Utils.Add(ref nonIdentity, srcKind);
                    }

                    targetsBySource[srcKind] = reachable;
                    if (!sourcesByTargets.TryGetValue(reachable, out var group))
                    {
                        group = new HashSet<DataKind>();
                        sourcesByTargets[reachable] = group;
                    }
                    group.Add(srcKind);
                }

                // Emit one line per distinct target set; removing the entry ensures each group prints once.
                foreach (var kind in kinds)
                {
                    var dsts = targetsBySource[kind];
                    if (!sourcesByTargets.TryGetValue(dsts, out var srcs))
                    {
                        continue;
                    }
                    ch.Assert(Utils.Size(dsts) >= 1);
                    ch.Assert(Utils.Size(srcs) >= 1);
                    string line = FormatKinds(srcs) + " | " + FormatKinds(dsts);
                    sourcesByTargets.Remove(dsts);
                    ch.Info(line);
                }

                if (Utils.Size(nonIdentity) > 0)
                {
                    ch.Warning("The following kinds did not have an identity conversion: {0}",
                        string.Join(", ", nonIdentity.OrderBy(k => k).Select(DataKindExtensions.GetString)));
                }
            }

            // Sorted, back-quoted, comma-separated rendering of a set of kinds.
            string FormatKinds(IEnumerable<DataKind> set) =>
                string.Join(", ", set.OrderBy(k => k).Select(k => '`' + k.GetString() + '`'));
        }