/// <summary>
/// Builds an <see cref="INamedOnnxValueGetter"/> that reads the input column at <paramref name="colIndex"/>.
/// </summary>
/// <param name="input">Row the getter reads from.</param>
/// <param name="onnxType">ONNX element type of the model input; mapped to an ML.NET raw type to choose the generic instantiation.</param>
/// <param name="isVector">Whether the source column is a vector column.</param>
/// <param name="colName">Name of the source column.</param>
/// <param name="colIndex">Index of the source column in <paramref name="input"/>.</param>
/// <param name="onnxShape">Tensor shape forwarded to the typed getter.</param>
/// <returns>The getter produced by the generic <c>CreateNameOnnxValueGetter</c> overload.</returns>
private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(DataViewRow input, System.Type onnxType, bool isVector, string colName, int colIndex, OnnxShape onnxShape)
{
    var rawType = OnnxUtils.OnnxToMlNetType(onnxType).RawType;
    Contracts.AssertValue(rawType);

    // NOTE(review): <int> looks like a placeholder; MarshalInvoke presumably re-instantiates the
    // generic method with rawType as T — confirm against Utils.MarshalInvoke.
    return Utils.MarshalInvoke(CreateNameOnnxValueGetter<int>, rawType, input, isVector, colName, colIndex, onnxShape);
}
/// <summary>
/// Loads the ONNX model and resolves the transformer's input and output column names and output types.
/// </summary>
/// <param name="env">Host environment; a child host named after <see cref="OnnxTransformer"/> is registered on it.</param>
/// <param name="options">
/// Column names (<c>InputColumns</c>, <c>OutputColumns</c>), model path (<c>ModelFile</c>) and device settings
/// (<c>GpuDeviceId</c>, <c>FallbackToCpu</c>). Empty column lists mean "use every input/output node of the model".
/// </param>
/// <param name="modelBytes">
/// Serialized model. When <see langword="null"/>, the model is loaded from <c>options.ModelFile</c>, which must exist.
/// </param>
private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null)
    : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
{
    Host.CheckValue(options, nameof(options));
    // Fail with a checked argument error rather than a NullReferenceException in the loops below.
    Host.CheckValue(options.InputColumns, nameof(options.InputColumns));
    Host.CheckValue(options.OutputColumns, nameof(options.OutputColumns));

    foreach (var col in options.InputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(options.InputColumns));
    }
    foreach (var col in options.OutputColumns)
    {
        Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns));
    }

    try
    {
        if (modelBytes == null)
        {
            Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
            Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exist.", options.ModelFile);
            Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu);
        }
        else
        {
            Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu);
        }
    }
    catch (OnnxRuntimeException e)
    {
        throw Host.Except(e, $"Error initializing model :{e.ToString()}");
    }

    // An empty column list selects every node the model declares.
    Inputs = options.InputColumns.Length == 0 ? Model.InputNames.ToArray() : options.InputColumns;
    Outputs = options.OutputColumns.Length == 0 ? Model.OutputNames.ToArray() : options.OutputColumns;
    OutputTypes = new DataViewType[Outputs.Length];

    var outputsInfo = Model.ModelInfo.OutputsInfo;
    for (int i = 0; i < Outputs.Length; i++)
    {
        // Outputs may be a subset/reordering of the model's outputs, so look the node up by name.
        var idx = Model.OutputNames.IndexOf(Outputs[i]);
        if (idx < 0)
        {
            throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
        }

        var outputNodeInfo = outputsInfo[idx];
        var dims = AdjustDimensions(outputNodeInfo.Shape);
        OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
    }

    _options = options;
}
/// <summary>
/// Creates the typed value getter for DataView output column <paramref name="iinfo"/>.
/// </summary>
/// <param name="input">Row providing the model's input columns.</param>
/// <param name="iinfo">Position of the output column in <c>_parent.Outputs</c> (not necessarily the ONNX output index).</param>
/// <param name="activeOutput">Predicate over output positions selecting which outputs are requested.</param>
/// <param name="disposer">Always set to <see langword="null"/>; no cleanup is registered here.</param>
/// <returns>The getter built by the generic <c>MakeGetter</c> overload for the output's item raw type.</returns>
protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
{
    disposer = null;
    Host.AssertValue(input);

    var outputCache = new OutputCache();
    var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();

    // iinfo indexes the transformer's Outputs, which may be a subset or reordering of the model's
    // outputs; map it to the ONNX output node by name (the constructor guarantees the name exists).
    var onnxOutputIndex = _parent.Model.OutputNames.IndexOf(_parent.Outputs[iinfo]);
    Host.Assert(onnxOutputIndex >= 0);
    var type = OnnxUtils.OnnxToMlNetType(_parent.Model.ModelInfo.OutputsInfo[onnxOutputIndex].Type).RawType;
    Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);

    var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);
    return Utils.MarshalInvoke(MakeGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
}
/// <summary>
/// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
/// Used for schema propagation and verification in a pipeline.
/// </summary>
/// <param name="inputSchema">
/// Incoming schema; every transformer input must be present as a (fixed or variable size) vector column whose
/// item type matches the corresponding ONNX input node.
/// </param>
/// <returns>
/// All columns of <paramref name="inputSchema"/>, plus one vector column per transformer output. An output
/// replaces any input column with the same name.
/// </returns>
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    Host.CheckValue(inputSchema, nameof(inputSchema));
    var resultDic = inputSchema.ToDictionary(x => x.Name);

    var model = Transformer.Model;
    var inputsInfo = model.ModelInfo.InputsInfo;
    foreach (var input in Transformer.Inputs)
    {
        if (!inputSchema.TryFindColumn(input, out var col))
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
        }
        if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.Kind != SchemaShape.Column.VectorKind.Vector)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
        }

        var idx = model.InputNames.IndexOf(input);
        if (idx < 0)
        {
            throw Host.Except($"Column {input} doesn't match input node names of model.");
        }

        var expectedType = OnnxUtils.OnnxToMlNetType(inputsInfo[idx].Type);
        if (col.ItemType != expectedType)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
        }
    }

    for (var i = 0; i < Transformer.Outputs.Length; i++)
    {
        var outputType = Transformer.OutputTypes[i];
        var vectorKind = outputType.IsKnownSizeVector()
            ? SchemaShape.Column.VectorKind.Vector
            : SchemaShape.Column.VectorKind.VariableVector;
        resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i], vectorKind, outputType.GetItemType(), false);
    }

    return new SchemaShape(resultDic.Values);
}
/// <summary>
/// Binds each transformer input to its column in <paramref name="inputSchema"/> and validates it against the
/// matching ONNX input node. The checks are: column present, not a variable-length vector, item type equal to the
/// node's type, and length divisible by the product of the node's known dimensions.
/// </summary>
/// <param name="parent">Owning transformer; supplies the model and input column names.</param>
/// <param name="inputSchema">Schema of the data the mapper will be applied to.</param>
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema)
    : base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
{
    _parent = parent;
    var inputCount = _parent.Inputs.Length;
    _inputColIndices = new int[inputCount];
    _isInputVector = new bool[inputCount];
    _inputTensorShapes = new OnnxShape[inputCount];
    _inputOnnxTypes = new System.Type[inputCount];

    var model = _parent.Model;
    for (int i = 0; i < inputCount; i++)
    {
        var inputName = _parent.Inputs[i];
        var idx = model.InputNames.IndexOf(inputName);
        if (idx < 0)
        {
            throw Host.Except($"Column {inputName} doesn't match input node names of model");
        }

        var inputNodeInfo = model.ModelInfo.InputsInfo[idx];
        var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
        // Materialize once: the shape is stored, reduced and formatted below.
        var inputShape = AdjustDimensions(inputNodeInfo.Shape).ToList();
        _inputTensorShapes[i] = inputShape;
        _inputOnnxTypes[i] = inputNodeInfo.Type;

        var col = inputSchema.GetColumnOrNull(inputName);
        if (!col.HasValue)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
        }
        _inputColIndices[i] = col.Value.Index;

        var type = inputSchema[_inputColIndices[i]].Type;
        var vectorType = type as VectorType;
        _isInputVector[i] = vectorType != null;
        if (vectorType != null && vectorType.Size == 0)
        {
            throw Host.Except($"Variable length input columns not supported");
        }

        if (type.GetItemType() != inputType)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, inputType.ToString(), type.ToString());
        }

        // If the column is one dimension we make sure that the total size of the Onnx shape matches.
        // Compute the total size of the known (positive) dimensions of the shape. The seed of 1 keeps this
        // defined when no dimension is known (empty or fully symbolic shape); without it Aggregate throws.
        int valCount = inputShape.Where(x => x > 0).Aggregate(1, (x, y) => x * y);
        // The column length should be divisible by this, so that the other dimensions can be integral.
        int typeValueCount = type.GetValueCount();
        if (typeValueCount % valCount != 0)
        {
            throw Host.Except($"Input shape mismatch: Input '{inputName}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
        }
    }
}