/// <summary>
/// Splits <paramref name="type"/> into the pieces used to describe a schema-shape column.
/// </summary>
/// <param name="type">The column type to inspect.</param>
/// <param name="vecKind">
/// <c>Vector</c> for a vector type of known size, <c>VariableVector</c> for a vector type of unknown size,
/// and <c>Scalar</c> for anything that is not a <see cref="VectorType"/>.
/// </param>
/// <param name="itemType">
/// The vector's item type, or <paramref name="type"/> itself for scalars. When that item type is a
/// <see cref="KeyType"/>, it is replaced by the primitive type of the key's raw type.
/// </param>
/// <param name="isKey">True when the (item) type is a <see cref="KeyType"/>.</param>
internal static void GetColumnTypeShape(DataViewType type, out Column.VectorKind vecKind, out DataViewType itemType, out bool isKey)
{
    var asVector = type as VectorType;
    if (asVector is null)
    {
        vecKind = Column.VectorKind.Scalar;
        itemType = type;
    }
    else
    {
        vecKind = asVector.IsKnownSize ? Column.VectorKind.Vector : Column.VectorKind.VariableVector;
        itemType = asVector.ItemType;
    }

    isKey = itemType is KeyType;
    if (!isKey)
    {
        return;
    }

    // Key columns are described by the primitive type of the key's raw representation.
    itemType = ColumnTypeExtensions.PrimitiveTypeFromType(itemType.RawType);
}
/// <summary>
/// Converts an ONNX element type, expressed using the <see cref="System.Type"/> convention,
/// to the primitive type system ML.NET recognizes (e.g. I4, I8, R4 etc.).
/// </summary>
/// <param name="type">The CLR type that stands for the ONNX element type.</param>
/// <returns>The ML.NET primitive type that <paramref name="type"/> maps to.</returns>
/// <exception cref="NotSupportedException">
/// Produced via <c>Contracts.ExceptNotSupp</c> when <paramref name="type"/> has no entry in the
/// ONNX-to-ML.NET type map.
/// </exception>
public static PrimitiveDataViewType OnnxToMlNetType(Type type)
{
    // One lookup instead of ContainsKey followed by the indexer.
    if (!_typeToKindMap.TryGetValue(type, out var kind))
    {
        // The {0} placeholder is required for the offending type to show up in the message;
        // without it the format argument is silently dropped.
        throw Contracts.ExceptNotSupp("Onnx type not supported: {0}", type);
    }
    return ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
}
/// <summary>
/// Produces one output column for every <see cref="DateTimeEstimator.ColumnsProduced"/> value.
/// Each column is named with the configured prefix followed by the enum member name, and is typed
/// with the primitive type of that member's raw column type.
/// </summary>
/// <returns>The detached output columns, in the order returned by <see cref="Enum.GetValues(Type)"/>.</returns>
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
    var produced = Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced));
    var result = new DataViewSchema.DetachedColumn[produced.Length];
    int slot = 0;
    foreach (DateTimeEstimator.ColumnsProduced part in produced)
    {
        string columnName = _parent._column.Prefix + part.ToString();
        var columnType = ColumnTypeExtensions.PrimitiveTypeFromType(part.GetRawColumnType());
        result[slot++] = new DataViewSchema.DetachedColumn(columnName, columnType);
    }
    return result;
}
/// <summary>
/// Deserializes an <see cref="ExpressionTransformer"/> from <paramref name="ctx"/>. For each stored
/// output column it re-parses and re-binds the saved expression against the saved input types.
/// </summary>
/// <param name="env">The host environment; must not be null.</param>
/// <param name="ctx">The model load context to read from; must not be null.</param>
/// <returns>A transformer holding one <c>ColumnInfo</c> per serialized output column.</returns>
private static ExpressionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: number of output columns
    // for each output column:
    //   int: number of inputs
    //   foreach input
    //     int: Id of the input column name
    //   int: Id of the expression
    //   int: Id of the output column name
    //   int: the index of the vector input (or -1)
    //   int[]: The data kinds of the input columns
    int count = ctx.Reader.ReadInt32();
    env.CheckDecode(count > 0);

    var infos = new ColumnInfo[count];
    for (int col = 0; col < infos.Length; col++)
    {
        infos[col] = ReadColumn();
    }

    return new ExpressionTransformer(env, infos);

    // Reads one serialized output column, in exactly the field order documented above.
    ColumnInfo ReadColumn()
    {
        int numInputs = ctx.Reader.ReadInt32();
        env.CheckDecode(numInputs >= 0);

        var inputNames = new string[numInputs];
        for (int k = 0; k < inputNames.Length; k++)
        {
            inputNames[k] = ctx.LoadNonEmptyString();
        }

        string expr = ctx.LoadNonEmptyString();
        string outputName = ctx.LoadNonEmptyString();

        int vectorInput = ctx.Reader.ReadInt32();
        env.CheckDecode(vectorInput >= -1);

        var types = new DataViewType[numInputs];
        for (int k = 0; k < types.Length; k++)
        {
            types[k] = ColumnTypeExtensions.PrimitiveTypeFromKind(InternalDataKindExtensions.FromIndex(ctx.Reader.ReadInt32()));
        }

        var lambda = ExpressionEstimator.ParseAndBindLambda(env, expr, vectorInput, types, out var perm);
        return new ColumnInfo(env, inputNames, types, expr, outputName, vectorInput, lambda, perm);
    }
}
/// <summary>
/// Helper function to retrieve the primitive type that corresponds to a given <see cref="Type"/>.
/// An array type is unwrapped to its element type and flagged as a vector.
/// </summary>
/// <param name="rawType">The CLR type to map; for an array type, its element type is mapped.</param>
/// <param name="isVectorType">Set to true when <paramref name="rawType"/> is an array type, otherwise false.</param>
/// <returns>The primitive type for the data kind of the (element) type.</returns>
/// <exception cref="InvalidOperationException">The (element) type has no associated data kind.</exception>
internal static PrimitiveType GetPrimitiveType(Type rawType, out bool isVectorType)
{
    isVectorType = rawType.IsArray;
    Type elementType = isVectorType ? rawType.GetElementType() : rawType;

    if (elementType.TryGetDataKind(out DataKind kind))
    {
        return ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
    }

    throw new InvalidOperationException($"Unsupported type {elementType} used in mapping.");
}
/// <summary>
/// Parses a comma-separated list of <c>InternalDataKind</c> names starting at <paramref name="ichMin"/>
/// and ending just before the next ':' in <paramref name="text"/>. Each name is turned into its primitive type.
/// </summary>
/// <param name="text">The text being parsed.</param>
/// <param name="ichMin">Start index of the list; on return, the index just past the ':' terminator.</param>
/// <param name="ichLim">Exclusive upper bound the ':' is asserted to lie before.</param>
/// <returns>One primitive type per listed kind name, in order.</returns>
/// <remarks>
/// The colon position and every kind name are validated only through <c>Contracts.Assert</c>.
/// NOTE(review): if those asserts are compiled out in release builds, an unparseable name would
/// map to <c>default(InternalDataKind)</c> — confirm this is acceptable.
/// </remarks>
private DataViewType[] ParseTypes(string text, ref int ichMin, int ichLim)
{
    int ichColon = text.IndexOf(':', ichMin);
    Contracts.Assert(ichMin < ichColon && ichColon < ichLim);

    string kindList = text.Substring(ichMin, ichColon - ichMin);
    string[] kindNames = kindList.Split(',');
    var result = new DataViewType[kindNames.Length];

    int pos = 0;
    foreach (string name in kindNames)
    {
        bool parsed = Enum.TryParse(name, out InternalDataKind kind);
        Contracts.Assert(parsed);
        result[pos++] = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
    }

    ichMin = ichColon + 1;
    return result;
}
/// <summary>
/// Returns the input schema extended with one scalar, non-key column per <c>ColumnsProduced</c> value.
/// Each column is named with the configured prefix followed by the enum member name. A produced column
/// replaces any input column of the same name.
/// </summary>
/// <param name="inputSchema">The schema shape of the input data.</param>
/// <returns>The resulting output schema shape.</returns>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    var byName = inputSchema.ToDictionary(x => x.Name);

    foreach (ColumnsProduced produced in Enum.GetValues(typeof(ColumnsProduced)))
    {
        string name = _options.Prefix + produced.ToString();
        var itemType = ColumnTypeExtensions.PrimitiveTypeFromType(produced.GetRawColumnType());
        byName[name] = new SchemaShape.Column(name, SchemaShape.Column.VectorKind.Scalar, itemType, false, null);
    }

    return new SchemaShape(byName.Values);
}
/// <summary>
/// Produces one output column for each configured column of the parent. The column keeps the
/// configured name and is typed with the primitive type of its configured CLR type.
/// </summary>
/// <returns>The detached output columns, in configuration order.</returns>
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
    var outputs =
        from column in _parent._columns
        select new DataViewSchema.DetachedColumn(column.Name, ColumnTypeExtensions.PrimitiveTypeFromType(column.Type));

    return outputs.ToArray();
}
/// <summary>
/// Reports which <see cref="DataKind"/> values have a standard conversion to which other kinds.
/// For every kind it collects the set of destination kinds reachable through
/// <c>Conversions.TryGetStandardConversion</c>. Source kinds whose destination sets compare equal under
/// <c>SetOfKindsComparer</c> are grouped and logged together as one "sources | destinations" info line.
/// Kinds with no standard conversion to themselves are listed in a single warning at the end.
/// </summary>
public void Run()
{
    using (var ch = _host.Start("Run"))
    {
        var conv = Conversions.Instance;
        // Presumably compares HashSet<DataKind> instances by contents, so equal destination sets
        // share one dictionary entry. TODO confirm against SetOfKindsComparer.
        var comp = new SetOfKindsComparer();
        // Destination-kind set -> the source kinds whose destination set equals it.
        var dstToSrcMap = new Dictionary<HashSet<DataKind>, HashSet<DataKind>>(comp);
        // Source kind -> the set of kinds it has a standard conversion to.
        var srcToDstMap = new Dictionary<DataKind, HashSet<DataKind>>();
        // Distinct() drops repeated enum values (e.g. if DataKind declares aliases); OrderBy fixes the report order.
        var kinds = Enum.GetValues(typeof(DataKind)).Cast<DataKind>().Distinct().OrderBy(k => k).ToArray();
        var types = kinds.Select(kind => ColumnTypeExtensions.PrimitiveTypeFromKind(kind)).ToArray();
        // Filled through Utils.Add below; stays null when every kind converts to itself
        // (assuming Utils.Add allocates the set on first use — confirm).
        HashSet<DataKind> nonIdentity = null;
        // For each kind and its associated type.
        for (int i = 0; i < types.Length; ++i)
        {
            ch.AssertValue(types[i]);
            // NOTE(review): `info` is never read below; the call is presumably kept for whatever
            // KindReport does as a side effect (not visible here) — confirm before removing.
            var info = Utils.MarshalInvoke(KindReport<int>, types[i].RawType, ch, types[i]);
            var dstKinds = new HashSet<DataKind>();
            Delegate del;
            bool isIdentity;
            // Collect every kind (including possibly this one) that types[i] has a standard conversion to.
            for (int j = 0; j < types.Length; ++j)
            {
                if (conv.TryGetStandardConversion(types[i], types[j], out del, out isIdentity))
                {
                    dstKinds.Add(types[j].GetRawKind());
                }
            }
            // A kind lacking a self-conversion is recorded for the final warning; when a
            // self-conversion exists, it is asserted to be flagged as the identity conversion.
            if (!conv.TryGetStandardConversion(types[i], types[i], out del, out isIdentity))
            {
                Utils.Add(ref nonIdentity, types[i].GetRawKind());
            }
            else
            {
                ch.Assert(isIdentity);
            }
            srcToDstMap[types[i].GetRawKind()] = dstKinds;
            // Add this source kind to the group keyed by its destination set, creating the group if needed.
            HashSet<DataKind> srcKinds;
            if (!dstToSrcMap.TryGetValue(dstKinds, out srcKinds))
            {
                dstToSrcMap[dstKinds] = srcKinds = new HashSet<DataKind>();
            }
            srcKinds.Add(types[i].GetRawKind());
        }
        // Now perform the final outputs.
        for (int i = 0; i < kinds.Length; ++i)
        {
            // NOTE(review): srcToDstMap was keyed by types[i].GetRawKind() but is read back with kinds[i];
            // this assumes PrimitiveTypeFromKind(k).GetRawKind() == k for every kind.
            var dsts = srcToDstMap[kinds[i]];
            HashSet<DataKind> srcs;
            // The group is missing when it was already reported (and removed) for an earlier source kind.
            if (!dstToSrcMap.TryGetValue(dsts, out srcs))
            {
                continue;
            }
            ch.Assert(Utils.Size(dsts) >= 1);
            ch.Assert(Utils.Size(srcs) >= 1);
            // Both sides are sorted and back-quoted so each report line is deterministic.
            string srcStrings = string.Join(", ", srcs.OrderBy(k => k).Select(k => '`' + k.GetString() + '`'));
            string dstStrings = string.Join(", ", dsts.OrderBy(k => k).Select(k => '`' + k.GetString() + '`'));
            // Remove the group so each distinct destination set is logged exactly once.
            dstToSrcMap.Remove(dsts);
            ch.Info(srcStrings + " | " + dstStrings);
        }
        if (Utils.Size(nonIdentity) > 0)
        {
            ch.Warning("The following kinds did not have an identity conversion: {0}", string.Join(", ", nonIdentity.OrderBy(k => k).Select(DataKindExtensions.GetString)));
        }
    }
}