예제 #1
0
            /// <summary>
            /// Builds an <see cref="ITensorValueGetter"/> for column <paramref name="colIndex"/> of
            /// <paramref name="input"/>, dispatching on the CLR type that ML.NET associates with
            /// <paramref name="onnxType"/>.
            /// </summary>
            private static ITensorValueGetter CreateTensorValueGetter(IRow input, DataType onnxType, bool isVector, int colIndex, OnnxShape onnxShape)
            {
                // ONNX element type -> ML.NET column type -> underlying CLR type.
                var mlNetType = OnnxUtils.OnnxToMlNetType(onnxType);
                var rawType = mlNetType.RawType;
                Contracts.AssertValue(rawType);

                // NOTE(review): the <int> instantiation appears to be a placeholder that MarshalInvoke
                // re-binds to rawType at run time (ML.NET Utils.MarshalInvoke convention) — confirm.
                return Utils.MarshalInvoke(CreateTensorValueGetter<int>, rawType, input, isVector, colIndex, onnxShape);
            }
예제 #2
0
            /// <summary>
            /// Builds an <see cref="INamedOnnxValueGetter"/> for column <paramref name="colName"/>
            /// (index <paramref name="colIndex"/>) of <paramref name="input"/>, dispatching on the CLR type
            /// that ML.NET associates with <paramref name="onnxType"/>.
            /// </summary>
            private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(Row input, System.Type onnxType, bool isVector, string colName, int colIndex, OnnxShape onnxShape)
            {
                // ONNX element type -> ML.NET column type -> underlying CLR type.
                var mlNetType = OnnxUtils.OnnxToMlNetType(onnxType);
                var rawType = mlNetType.RawType;
                Contracts.AssertValue(rawType);

                // NOTE(review): the <int> instantiation appears to be a placeholder that MarshalInvoke
                // re-binds to rawType at run time — confirm.
                return Utils.MarshalInvoke(CreateNameOnnxValueGetter<int>, rawType, input, isVector, colName, colIndex, onnxShape);
            }
        /// <summary>
        /// Computes the schema shape this estimator produces for <paramref name="inputSchema"/>.
        /// The transformer's input column must exist, must be a fixed- or variable-size vector, and must have
        /// the item type of the model's first input node. The output column is then added (or replaces a
        /// column of the same name).
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>The input columns plus the transformer's output column.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            var resultDic = inputSchema.Columns.ToDictionary(x => x.Name);

            var input = Transformer.Input;

            if (!inputSchema.TryFindColumn(input, out var col))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
            }
            if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
            }

            // Only the model's first input node is consulted here.
            var inputNodeInfo = Transformer.Model.ModelInfo.InputsInfo[0];
            var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);

            if (col.ItemType != expectedType)
            {
                throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
            }

            // Report the transformer's real output item type instead of assuming float (R4).
            // The model's output node is not necessarily float.
            var outputType = Transformer.OutputType;
            var outputKind = outputType.IsKnownSizeVector
                ? SchemaShape.Column.VectorKind.Vector
                : SchemaShape.Column.VectorKind.VariableVector;
            resultDic[Transformer.Output] = new SchemaShape.Column(Transformer.Output, outputKind, outputType.ItemType, false);

            return new SchemaShape(resultDic.Values);
        }
예제 #4
0
            /// <summary>
            /// Binds each of the parent transformer's input column names to an ONNX model input node and to a
            /// column of <paramref name="inputSchema"/>. For each input it caches the column index, whether the
            /// column is a vector, the ONNX element type, and the adjusted tensor shape.
            /// </summary>
            /// <param name="parent">The owning transformer; must not be null.</param>
            /// <param name="inputSchema">Schema of the data being mapped.</param>
            /// <remarks>
            /// Throws (via <c>Host</c>) when:
            /// <list type="bullet">
            /// <item>an input name is not a model input node;</item>
            /// <item>a column is missing or is a variable-length vector;</item>
            /// <item>a column's item type differs from the node type;</item>
            /// <item>a column's length is not a multiple of the node's known dimensions.</item>
            /// </list>
            /// </remarks>
            public Mapper(OnnxTransformer parent, Schema inputSchema) :
                base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema)
            {
                _parent = parent;
                var inputCount = _parent.Inputs.Length;
                _inputColIndices = new int[inputCount];
                _isInputVector = new bool[inputCount];
                _inputTensorShapes = new OnnxShape[inputCount];
                _inputOnnxTypes = new System.Type[inputCount];

                var model = _parent.Model;

                for (int i = 0; i < inputCount; i++)
                {
                    var inputName = _parent.Inputs[i];
                    var idx = model.InputNames.IndexOf(inputName);
                    if (idx < 0)
                    {
                        throw Host.Except($"Column {inputName} doesn't match input node names of model");
                    }

                    var inputNodeInfo = model.ModelInfo.InputsInfo[idx];
                    var inputType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);

                    var inputShape = AdjustDimensions(inputNodeInfo.Shape);
                    _inputTensorShapes[i] = inputShape.ToList();
                    _inputOnnxTypes[i] = inputNodeInfo.Type;

                    var col = inputSchema.GetColumnOrNull(inputName);
                    if (!col.HasValue)
                    {
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName);
                    }
                    _inputColIndices[i] = col.Value.Index;

                    var type = inputSchema[_inputColIndices[i]].Type;
                    _isInputVector[i] = type.IsVector;

                    if (type.IsVector && type.VectorSize == 0)
                    {
                        throw Host.Except($"Variable length input columns not supported");
                    }

                    if (type.ItemType != inputType)
                    {
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputName, inputType.ToString(), type.ToString());
                    }

                    // If the column is one dimension we make sure that the total size of the ONNX shape matches.
                    // Compute the total size of the known (positive) dimensions of the shape. Seeding with 1 keeps a
                    // shape with no known dimensions from throwing on an empty sequence; such a shape accepts any length.
                    int valCount = inputShape.Select(x => (int)x).Where(x => x > 0).Aggregate(1, (x, y) => x * y);
                    // The column length should be divisible by this, so that the other dimensions can be integral.
                    if (type.ValueCount % valCount != 0)
                    {
                        throw Host.Except($"Input shape mismatch: Input '{inputName}' has shape {String.Join(",", inputShape)}, but input data is of length {type.ValueCount}.");
                    }
                }
            }
예제 #5
0
        /// <summary>
        /// Computes the schema shape this estimator produces for <paramref name="inputSchema"/>.
        /// Every transformer input column must exist, must be a fixed- or variable-size vector, must name a
        /// model input node, and must match that node's item type. Each transformer output column is then
        /// added (or replaces a column of the same name) using <c>Transformer.OutputTypes</c>.
        /// </summary>
        /// <param name="inputSchema">Shape of the incoming data.</param>
        /// <returns>The input columns plus one column per transformer output.</returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            var resultDic = inputSchema.ToDictionary(x => x.Name);

            // Model metadata does not change between iterations.
            var model = Transformer.Model;
            var inputsInfo = model.ModelInfo.InputsInfo;

            for (var i = 0; i < Transformer.Inputs.Length; i++)
            {
                var input = Transformer.Inputs[i];
                if (!inputSchema.TryFindColumn(input, out var col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
                }
                if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
                }

                var idx = model.InputNames.IndexOf(input);
                if (idx < 0)
                {
                    throw Host.Except($"Column {input} doesn't match input node names of model.");
                }

                var inputNodeInfo = inputsInfo[idx];
                var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
                if (col.ItemType != expectedType)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
                }
            }

            for (var i = 0; i < Transformer.Outputs.Length; i++)
            {
                var outputType = Transformer.OutputTypes[i];
                var outputKind = outputType.IsKnownSizeVector
                    ? SchemaShape.Column.VectorKind.Vector
                    : SchemaShape.Column.VectorKind.VariableVector;
                resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i], outputKind, outputType.ItemType, false);
            }
            return new SchemaShape(resultDic.Values);
        }
            /// <summary>
            /// Creates the row mapper for an <see cref="OnnxTransform"/>. It uses the model's first input and
            /// first output node, and validates the transform's input column against the input node's item type
            /// and known dimensions.
            /// </summary>
            /// <param name="env">Host environment; must not be null.</param>
            /// <param name="parent">The owning transform; must not be null.</param>
            /// <param name="inputSchema">Schema of the data being mapped; must not be null.</param>
            public Mapper(IHostEnvironment env, OnnxTransform parent, Schema inputSchema)
            {
                Contracts.CheckValue(env, nameof(env));
                _host = env.Register(nameof(Mapper));
                _host.CheckValue(inputSchema, nameof(inputSchema));
                _host.CheckValue(parent, nameof(parent));

                _parent = parent;
                var model = _parent.Model;

                _idvToTensorAdapter = new IdvToTensorAdapter(inputSchema, parent._args.InputColumn,
                                                             model.ModelInfo.InputsInfo[0]);

                // TODO: Remove assumption below
                // Assume first output dimension is 1
                var outputNodeInfo = model.ModelInfo.OutputsInfo[0];
                var inputNodeInfo = model.ModelInfo.InputsInfo[0];

                int[] dims = outputNodeInfo.Shape.Skip(1).Select(x => (int)x).ToArray();
                var outputItemType = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
                // The input column must be validated against the model's INPUT node type, not the output type.
                var inputItemType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
                var inputShape = inputNodeInfo.Shape;

                _outputColType = new VectorType(outputItemType, dims);
                _outputColName = _parent.Output;
                _outputItemRawType = outputItemType.RawType;

                if (!inputSchema.TryGetColumnIndex(_parent.Input, out int inColIndex))
                {
                    throw _host.Except($"Column {_parent.Input} doesn't exist");
                }

                var type = inputSchema.GetColumnType(inColIndex);

                if (type.IsVector && type.VectorSize == 0)
                {
                    throw _host.Except($"Variable length input columns not supported");
                }

                if (type.ItemType != inputItemType)
                {
                    throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Input, inputItemType.ToString(), type.ToString());
                }

                // If the column is one dimension we make sure that the total size of the ONNX shape matches.
                // Compute the total size of the known (positive) dimensions of the shape. Seeding with 1 keeps a
                // shape with no known dimensions from throwing on an empty sequence.
                int valCount = inputShape.Select(x => (int)x).Where(x => x > 0).Aggregate(1, (x, y) => x * y);

                // The column length should be divisible by this, so that the other dimensions can be integral.
                if (type.ValueCount % valCount != 0)
                {
                    throw _host.Except($"Input shape mismatch: Input '{_parent.Input}' has shape {String.Join(",", inputShape)}, but input data is of length {type.ValueCount}.");
                }

                _host.Assert(_outputItemRawType == _outputColType.ItemType.RawType);
            }
예제 #7
0
        /// <summary>
        /// Validates <paramref name="args"/> and loads the ONNX model. The model is read from
        /// <c>args.ModelFile</c>, or from <paramref name="modelBytes"/> when that is non-null. The constructor
        /// then resolves the input/output column lists and computes one output column type per output.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transformer arguments; column names must be non-whitespace.</param>
        /// <param name="modelBytes">Serialized model; when null, <c>args.ModelFile</c> must name an existing file.</param>
        private OnnxTransformer(IHostEnvironment env, Arguments args, byte[] modelBytes = null) :
            base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
        {
            Host.CheckValue(args, nameof(args));

            foreach (var col in args.InputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(args.InputColumns));
            }
            foreach (var col in args.OutputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(args.OutputColumns));
            }

            try
            {
                if (modelBytes == null)
                {
                    Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
                    Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
                    Model = new OnnxModel(args.ModelFile, args.GpuDeviceId, args.FallbackToCpu);
                }
                else
                {
                    Model = OnnxModel.CreateFromBytes(modelBytes, args.GpuDeviceId, args.FallbackToCpu);
                }
            }
            catch (OnnxRuntimeException e)
            {
                // Surface runtime load failures through the host so they carry this component's context.
                throw Host.Except(e, $"Error initializing model :{e.ToString()}");
            }

            // An empty column list means "use every node the model declares".
            Inputs = args.InputColumns.Length == 0 ? Model.InputNames.ToArray() : args.InputColumns;
            Outputs = args.OutputColumns.Length == 0 ? Model.OutputNames.ToArray() : args.OutputColumns;
            OutputTypes = new ColumnType[Outputs.Length];

            for (int i = 0; i < Outputs.Length; i++)
            {
                // Outputs may be a subset or a reordering of the model's outputs, so look the node up by name.
                var idx = Model.OutputNames.IndexOf(Outputs[i]);
                if (idx < 0)
                {
                    throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
                }

                var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
                var dims = AdjustDimensions(outputNodeInfo.Shape);
                OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
            }
            _args = args;
        }
예제 #8
0
            /// <summary>
            /// Creates the value getter for output column <paramref name="iinfo"/>. It dispatches to the generic
            /// <c>MakeGetter&lt;T&gt;</c> using that column's raw item type.
            /// </summary>
            /// <param name="input">Source row; must not be null.</param>
            /// <param name="iinfo">Index into the transformer's output columns (<c>_parent.Outputs</c>).</param>
            /// <param name="activeOutput">Predicate over output-column indices; selects which outputs are requested.</param>
            /// <param name="disposer">Always set to null; this getter owns nothing that needs disposal.</param>
            protected override Delegate MakeGetter(Row input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                disposer = null;
                Host.AssertValue(input);
                Host.Assert(0 <= iinfo && iinfo < _parent.Outputs.Length);

                var outputCache = new OutputCache();
                var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();
                // iinfo indexes the transformer's output columns, which can be a subset or a reordering of the model's
                // outputs. Take the item type from OutputTypes[iinfo] rather than ModelInfo.OutputsInfo[iinfo].
                var type = _parent.OutputTypes[iinfo].ItemType.RawType;
                var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);
                return Utils.MarshalInvoke(MakeGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
            }
예제 #9
0
        /// <summary>
        /// Validates <paramref name="args"/> and loads the ONNX model, either from <c>args.ModelFile</c> or
        /// from <paramref name="modelBytes"/>. Input/output column names are taken directly from
        /// <paramref name="args"/>, and one output column type is computed per requested output.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; column names must be non-whitespace.</param>
        /// <param name="modelBytes">Serialized model; when null, <c>args.ModelFile</c> must name an existing file.</param>
        private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = null) :
            base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransform)))
        {
            Host.CheckValue(args, nameof(args));

            foreach (var col in args.InputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(args.InputColumns));
            }
            foreach (var col in args.OutputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(args.OutputColumns));
            }

            if (modelBytes == null)
            {
                Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
                Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
                Model = new OnnxModel(args.ModelFile);
            }
            else
            {
                Model = OnnxModel.CreateFromBytes(modelBytes);
            }

            Inputs = args.InputColumns;
            Outputs = args.OutputColumns;
            OutputTypes = new ColumnType[args.OutputColumns.Length];

            for (int i = 0; i < args.OutputColumns.Length; i++)
            {
                // Requested outputs are matched to model output nodes by name, not by position.
                var idx = Model.OutputNames.IndexOf(args.OutputColumns[i]);
                if (idx < 0)
                {
                    throw Host.Except($"Column {args.OutputColumns[i]} doesn't match output node names of model");
                }

                var outputNodeInfo = Model.ModelInfo.OutputsInfo[idx];
                var dims = AdjustDimensions(outputNodeInfo.Shape);
                OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims);
            }
            _args = args;
        }
        /// <summary>
        /// Validates <paramref name="args"/>, loads the ONNX model (from <c>args.ModelFile</c> or from
        /// <paramref name="modelBytes"/>), requires exactly one model input and one model output, and derives
        /// the output column type from the output node.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; input/output column names must be non-whitespace.</param>
        /// <param name="modelBytes">Serialized model; when null, <c>args.ModelFile</c> must name an existing file.</param>
        private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = null)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
            _host.CheckValue(args, nameof(args));
            _host.CheckNonWhiteSpace(args.InputColumn, nameof(args.InputColumn));
            _host.CheckNonWhiteSpace(args.OutputColumn, nameof(args.OutputColumn));

            if (modelBytes == null)
            {
                _host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
                _host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
                Model = new OnnxModel(args.ModelFile);
            }
            else
            {
                Model = OnnxModel.CreateFromBytes(modelBytes);
            }

            var modelInfo = Model.ModelInfo;

            // Note: "{...}" inside an interpolated string is the placeholder; a preceding '$' would be printed literally.
            if (modelInfo.InputsInfo.Length != 1)
            {
                throw _host.Except($"OnnxTransform supports Onnx models with one input. The provided model has {modelInfo.InputsInfo.Length} input(s).");
            }
            if (modelInfo.OutputsInfo.Length != 1)
            {
                throw _host.Except($"OnnxTransform supports Onnx models with one output. The provided model has {modelInfo.OutputsInfo.Length} output(s).");
            }

            Input = args.InputColumn;
            Output = args.OutputColumn;

            var outputNodeInfo = modelInfo.OutputsInfo[0];
            var type = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
            var shape = outputNodeInfo.Shape;
            // Drop a leading negative (unknown — typically batch) dimension. An empty shape becomes a single
            // dimension of 0 (presumably treated as variable size by VectorType — confirm).
            var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray() : new[] { 0 };

            OutputType = new VectorType(type, dims);
            _args = args;
        }