/// <summary>
/// Binds every input of the parent's ONNX model to a column of <paramref name="inputSchema"/> and
/// validates that the column exists, is not a variable-length vector, has a compatible item type,
/// and has a value count compatible with the model input's shape.
/// </summary>
/// <param name="parent">The owning transformer; passed through <c>Contracts.CheckRef</c> before use.</param>
/// <param name="inputSchema">Schema in which the model's input columns are looked up by name.</param>
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
    base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
{
    _parent = parent;
    _inputColIndices = new int[_parent.Inputs.Length];
    _inputTensorShapes = new OnnxShape[_parent.Inputs.Length];
    _inputOnnxTypes = new Type[_parent.Inputs.Length];

    var model = _parent.Model;
    for (int i = 0; i < _parent.Inputs.Length; i++)
    {
        var inputNodeInfo = model.ModelInfo.GetInput(_parent.Inputs[i]);

        // Materialize the adjusted shape once; it is stored, multiplied out and printed below.
        var inputShape = AdjustDimensions(inputNodeInfo.Shape).ToList();
        _inputTensorShapes[i] = inputShape;
        _inputOnnxTypes[i] = inputNodeInfo.TypeInOnnxRuntime;

        var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
        if (!col.HasValue)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i]);
        }

        _inputColIndices[i] = col.Value.Index;

        var type = inputSchema[_inputColIndices[i]].Type;
        var vectorType = type as VectorDataViewType;
        if (vectorType != null && vectorType.Size == 0)
        {
            throw Host.Except("Variable length input columns not supported");
        }

        var actualItemType = type.GetItemType();
        var expectedItemType = inputNodeInfo.DataViewType.GetItemType();
        if (actualItemType != expectedItemType)
        {
            // If the ONNX model input node expects a type that does not match the type of the provided
            // input IDataView column, throw an exception.
            // The one exception is when the ONNX model input node expects a UInt32 but the input column
            // is a KeyDataViewType. This supports a corner case that originated in NimbusML.
            // For more info, see: https://github.com/microsoft/NimbusML/issues/426
            bool isKeyFeedingUInt32 = actualItemType is KeyDataViewType && expectedItemType.RawType == typeof(UInt32);
            if (!isKeyFeedingUInt32)
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], expectedItemType.ToString(), type.ToString());
            }
        }

        // If the column is one dimension we make sure that the total size of the Onnx shape matches.
        // Compute the total size of the known (positive) dimensions of the shape.
        // The seed of 1 makes a shape without any positive dimension yield 1 instead of letting
        // Aggregate throw InvalidOperationException on an empty sequence.
        int valCount = inputShape.Where(x => x > 0).Aggregate(1, (x, y) => x * y);

        // The column length should be divisible by this, so that the other dimensions can be integral.
        int typeValueCount = type.GetValueCount();
        if (typeValueCount % valCount != 0)
        {
            throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
        }
    }
}
/// <summary>
/// Resolves each ONNX model input of <paramref name="parent"/> to a column in
/// <paramref name="inputSchema"/>, recording the column index, the adjusted tensor shape and the
/// ONNX runtime type, and throws if a column is missing, is a variable-length vector, has a
/// different item type than the model input, or has an incompatible value count.
/// </summary>
/// <param name="parent">The owning transformer; passed through <c>Contracts.CheckRef</c> before use.</param>
/// <param name="inputSchema">Schema in which the model's input columns are looked up by name.</param>
public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
    base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
{
    _parent = parent;

    int inputCount = _parent.Inputs.Length;
    _inputColIndices = new int[inputCount];
    _inputTensorShapes = new OnnxShape[inputCount];
    _inputOnnxTypes = new Type[inputCount];

    var model = _parent.Model;
    for (int i = 0; i < _parent.Inputs.Length; i++)
    {
        var inputNodeInfo = model.ModelInfo.GetInput(_parent.Inputs[i]);

        // Evaluate the adjusted shape a single time and reuse the resulting list everywhere below.
        var inputShape = AdjustDimensions(inputNodeInfo.Shape).ToList();
        _inputTensorShapes[i] = inputShape;
        _inputOnnxTypes[i] = inputNodeInfo.TypeInOnnxRuntime;

        var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
        if (!col.HasValue)
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i]);
        }

        _inputColIndices[i] = col.Value.Index;

        var type = inputSchema[_inputColIndices[i]].Type;
        var vectorType = type as VectorDataViewType;
        if (vectorType != null && vectorType.Size == 0)
        {
            throw Host.Except("Variable length input columns not supported");
        }

        // The column's item type must be exactly the item type the ONNX input node expects.
        if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType())
        {
            throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
        }

        // If the column is one dimension we make sure that the total size of the Onnx shape matches.
        // Compute the total size of the known (positive) dimensions of the shape.
        // Seeding with 1 keeps Aggregate from throwing InvalidOperationException when no dimension
        // is positive; the product is then 1 and the divisibility check below trivially passes.
        int valCount = inputShape.Where(x => x > 0).Aggregate(1, (x, y) => x * y);

        // The column length should be divisible by this, so that the other dimensions can be integral.
        int typeValueCount = type.GetValueCount();
        if (typeValueCount % valCount != 0)
        {
            throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
        }
    }
}