Ejemplo n.º 1
0
            /// <summary>
            /// Creates a mapper that binds every input column named by <paramref name="parent"/> to the
            /// ONNX model input node of the same name, records that node's shape and ONNX type, and
            /// validates the column against the node.
            /// </summary>
            /// <param name="parent">The transform that owns the ONNX model and the input column names.</param>
            /// <param name="inputSchema">The schema of the data that will be mapped.</param>
            /// <remarks>
            /// An exception is thrown when an input name is not an input node of the model, when the column
            /// is missing from <paramref name="inputSchema"/>, when the column is a variable-length vector,
            /// when its item type differs from the node's type, or when its length is not a multiple of the
            /// product of the node's known (positive) dimensions.
            /// </remarks>
            public Mapper(OnnxTransform parent, Schema inputSchema) :
                base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema)
            {
                _parent = parent;

                int inputCount = _parent.Inputs.Length;
                _inputColIndices = new int[inputCount];
                _isInputVector = new bool[inputCount];
                _inputTensorShapes = new OnnxShape[inputCount];
                _inputOnnxTypes = new DataType[inputCount];

                var model = _parent.Model;

                for (int i = 0; i < inputCount; i++)
                {
                    string inputName = _parent.Inputs[i];

                    var idx = model.InputNames.IndexOf(inputName);
                    if (idx < 0)
                    {
                        throw Host.Except($"Column {inputName} doesn't match input node names of model");
                    }

                    var inputNodeInfo = model.ModelInfo.InputsInfo[idx];
                    var inputShape = inputNodeInfo.Shape;
                    var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);

                    _inputTensorShapes[i] = inputShape;
                    _inputOnnxTypes[i] = inputNodeInfo.Type;

                    if (!inputSchema.TryGetColumnIndex(inputName, out _inputColIndices[i]))
                    {
                        throw Host.Except($"Column {inputName} doesn't exist");
                    }

                    var type = inputSchema.GetColumnType(_inputColIndices[i]);
                    _isInputVector[i] = type.IsVector;

                    if (type.IsVector && type.VectorSize == 0)
                    {
                        throw Host.Except("Variable length input columns not supported");
                    }

                    if (type.ItemType != inputType)
                    {
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, inputType.ToString(), type.ToString());
                    }

                    // Product of the known (positive) dimensions of the node shape; non-positive entries are skipped.
                    // Seeded with 1 so a shape with no known dimensions does not make Aggregate throw on an
                    // empty sequence, and accumulated as long so large shapes cannot overflow.
                    long valCount = inputShape.Select(x => (long)x).Where(x => x > 0).Aggregate(1L, (x, y) => x * y);

                    // The column length must be divisible by this, so that the other dimensions can be integral.
                    if (type.ValueCount % valCount != 0)
                    {
                        throw Host.Except($"Input shape mismatch: Input '{inputName}' has shape {string.Join(",", inputShape)}, but input data is of length {type.ValueCount}.");
                    }
                }
            }
            /// <summary>
            /// Creates a mapper that feeds the parent's input column to the model's first input node and
            /// exposes the model's first output node as a vector-typed output column.
            /// </summary>
            /// <param name="env">The host environment; a child host named after this mapper is registered on it.</param>
            /// <param name="parent">The transform that owns the ONNX model and the input/output column names.</param>
            /// <param name="inputSchema">The schema of the data that will be mapped.</param>
            /// <remarks>
            /// An exception is thrown when any argument is null, when the input column is missing from
            /// <paramref name="inputSchema"/>, when it is a variable-length vector, when its item type differs
            /// from the model input node's type, or when its length is not a multiple of the product of the
            /// input node's known (positive) dimensions.
            /// </remarks>
            public Mapper(IHostEnvironment env, OnnxTransform parent, Schema inputSchema)
            {
                Contracts.CheckValue(env, nameof(env));
                _host = env.Register(nameof(Mapper));
                _host.CheckValue(inputSchema, nameof(inputSchema));
                _host.CheckValue(parent, nameof(parent));

                _parent = parent;
                var model = _parent.Model;

                // TODO: Remove assumption below — only the first input and first output node are used.
                var inputNodeInfo = model.ModelInfo.InputsInfo[0];

                _idvToTensorAdapter = new IdvToTensorAdapter(inputSchema, parent._args.InputColumn, inputNodeInfo);

                // TODO: Remove assumption below
                // Assume first output dimension is 1; it is dropped from the output column's dimensions.
                var outputNodeInfo = model.ModelInfo.OutputsInfo[0];

                int[] dims = outputNodeInfo.Shape.Skip(1).Select(x => (int)x).ToArray();
                var outputItemType = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
                var inputShape = inputNodeInfo.Shape;

                _outputColType = new VectorType(outputItemType, dims);
                _outputColName = _parent.Output;
                _outputItemRawType = outputItemType.RawType;

                int inColIndex;

                if (!inputSchema.TryGetColumnIndex(_parent.Input, out inColIndex))
                {
                    throw _host.Except($"Column {_parent.Input} doesn't exist");
                }

                var type = inputSchema.GetColumnType(inColIndex);

                if (type.IsVector && type.VectorSize == 0)
                {
                    throw _host.Except("Variable length input columns not supported");
                }

                // The input column must match the model's INPUT node type (not the output node type).
                var inputItemType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
                if (type.ItemType != inputItemType)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Input, inputItemType.ToString(), type.ToString());
                }

                // If the column is one dimension we make sure that the total size of the ONNX shape matches.
                // Product of the known (positive) dimensions; seeded with 1 so a shape with no known dimensions
                // does not make Aggregate throw, and accumulated as long so large shapes cannot overflow.
                long valCount = inputShape.Select(x => (long)x).Where(x => x > 0).Aggregate(1L, (x, y) => x * y);

                // The column length should be divisible by this, so that the other dimensions can be integral.
                if (type.ValueCount % valCount != 0)
                {
                    throw _host.Except($"Input shape mismatch: Input '{_parent.Input}' has shape {string.Join(",", inputShape)}, but input data is of length {type.ValueCount}.");
                }

                _host.Assert(_outputItemRawType == _outputColType.ItemType.RawType);
            }