예제 #1
0
 /// <summary>
 /// Fetches the current value of the source column, materializes it as a dense buffer,
 /// and wraps it as a <see cref="NamedOnnxValue"/> named <c>_colName</c> with shape <c>_tensorShape</c>.
 /// </summary>
 /// <returns>The ONNX input value built from the dense copy of the current column value.</returns>
 public NamedOnnxValue GetNamedOnnxValue()
 {
     // Read the current value into the reusable buffer.
     _srcgetter(ref _vBuffer);

     // Copy into the reusable dense buffer so the ONNX value is built from the dense representation.
     _vBuffer.CopyToDense(ref _vBufferDense);

     var denseValues = _vBufferDense.GetValues();
     return OnnxUtils.CreateNamedOnnxValue(_colName, denseValues, _tensorShape);
 }
예제 #2
0
            /// <summary>
            /// Binds each input column named by <paramref name="parent"/> to the matching input node of the
            /// parent's ONNX model and validates that the column is compatible with that node.
            /// </summary>
            /// <param name="parent">The owning transformer; supplies the input column names and the loaded model.</param>
            /// <param name="inputSchema">Schema of the data the mapper will be applied to.</param>
            /// <remarks>
            /// For every input this records the column index, whether the column is a vector, the adjusted
            /// tensor shape and the ONNX element type. An exception is thrown (via <c>Host</c>) when the column
            /// name is not a model input node, when the column is missing from <paramref name="inputSchema"/>,
            /// or when the column is a variable-length vector. It is also thrown when the column's item type
            /// differs from the ML.NET type mapped from the node's ONNX type, or when the column length is not
            /// divisible by the product of the shape's known (positive) dimensions.
            /// </remarks>
            public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
                base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
            {
                _parent            = parent;
                _inputColIndices   = new int[_parent.Inputs.Length];
                _isInputVector     = new bool[_parent.Inputs.Length];
                _inputTensorShapes = new OnnxShape[_parent.Inputs.Length];
                _inputOnnxTypes    = new System.Type[_parent.Inputs.Length];

                var model = _parent.Model;

                for (int i = 0; i < _parent.Inputs.Length; i++)
                {
                    var inputName = _parent.Inputs[i];

                    // Every configured input column must correspond to an input node of the model.
                    var idx = model.InputNames.IndexOf(inputName);
                    if (idx < 0)
                    {
                        throw Host.Except($"Column {inputName} doesn't match input node names of model");
                    }

                    var inputNodeInfo = model.ModelInfo.InputsInfo[idx];
                    var inputType     = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);

                    // Materialize once: the shape is stored and then enumerated again below.
                    // NOTE(review): AdjustDimensions is defined elsewhere; presumably it normalizes
                    // unknown/negative dimensions — confirm.
                    var inputShape = AdjustDimensions(inputNodeInfo.Shape).ToList();
                    _inputTensorShapes[i] = inputShape;
                    _inputOnnxTypes[i]    = inputNodeInfo.Type;

                    var col = inputSchema.GetColumnOrNull(inputName);
                    if (!col.HasValue)
                    {
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
                    }
                    _inputColIndices[i] = col.Value.Index;

                    var type       = inputSchema[_inputColIndices[i]].Type;
                    var vectorType = type as VectorDataViewType;
                    _isInputVector[i] = vectorType != null;

                    // A vector Size of 0 denotes a variable-length column, which cannot be mapped to a fixed tensor shape.
                    if (vectorType != null && vectorType.Size == 0)
                    {
                        throw Host.Except($"Variable length input columns not supported");
                    }

                    if (type.GetItemType() != inputType)
                    {
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, inputType.ToString(), type.ToString());
                    }

                    // The column's total value count must be divisible by the product of the shape's known
                    // (positive) dimensions, so that any unknown dimensions can take integral sizes.
                    // The seed of 1 avoids Enumerable.Aggregate throwing on an empty sequence when no
                    // dimension is known (e.g. an empty shape); in that case any column length is accepted.
                    int valCount = inputShape.Where(x => x > 0).Aggregate(1, (x, y) => x * y);
                    int typeValueCount = type.GetValueCount();
                    if (typeValueCount % valCount != 0)
                    {
                        throw Host.Except($"Input shape mismatch: Input '{inputName}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
                    }
                }
            }