예제 #1
0
            /// <summary>
            /// Builds the <see cref="INamedOnnxValueGetter"/> for one input column. The ONNX element type is
            /// translated with <c>OnnxUtils.OnnxToMlNetType</c>, and the remaining arguments are forwarded to the
            /// generic <c>CreateNameOnnxValueGetter</c> through <c>Utils.MarshalInvoke</c>.
            /// </summary>
            /// <param name="input">Row handed to the generic getter factory.</param>
            /// <param name="onnxType">ONNX-side element type of the column.</param>
            /// <param name="isVector">Forwarded unchanged to the generic getter factory.</param>
            /// <param name="colName">Forwarded unchanged to the generic getter factory.</param>
            /// <param name="colIndex">Forwarded unchanged to the generic getter factory.</param>
            /// <param name="onnxShape">Forwarded unchanged to the generic getter factory.</param>
            /// <returns>Whatever the generic <c>CreateNameOnnxValueGetter</c> produces for the resolved type.</returns>
            private static INamedOnnxValueGetter CreateNamedOnnxValueGetter(DataViewRow input, System.Type onnxType, bool isVector, string colName, int colIndex, OnnxShape onnxShape)
            {
                var rawType = OnnxUtils.OnnxToMlNetType(onnxType).RawType;
                Contracts.AssertValue(rawType);

                // NOTE(review): the <int> type argument looks like a placeholder that MarshalInvoke swaps for
                // rawType at run time - confirm against Utils.MarshalInvoke.
                var getter = Utils.MarshalInvoke(CreateNameOnnxValueGetter<int>, rawType, input, isVector, colName, colIndex, onnxShape);
                return getter;
            }
예제 #2
0
        /// <summary>
        /// Validates <paramref name="options"/>, loads the ONNX model and resolves the input/output column
        /// names together with the type of every output column.
        /// </summary>
        /// <param name="env">Host environment; checked for null before the host is registered.</param>
        /// <param name="options">
        /// Transformer options. An empty <c>InputColumns</c> / <c>OutputColumns</c> array selects the model's
        /// own input / output names.
        /// </param>
        /// <param name="modelBytes">
        /// Serialized model. When null, the model is loaded from <c>options.ModelFile</c>, which must then
        /// name an existing file.
        /// </param>
        private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null) :
            base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
        {
            Host.CheckValue(options, nameof(options));
            // Validate the column arrays up front: a null array would otherwise only surface as a
            // NullReferenceException inside the loops below, without naming the offending option.
            Host.CheckValue(options.InputColumns, nameof(options.InputColumns));
            Host.CheckValue(options.OutputColumns, nameof(options.OutputColumns));

            foreach (var col in options.InputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(options.InputColumns));
            }
            foreach (var col in options.OutputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns));
            }

            try
            {
                if (modelBytes == null)
                {
                    Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
                    Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
                    Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu);
                }
                else
                {
                    Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu);
                }
            }
            catch (OnnxRuntimeException e)
            {
                // Only ONNX runtime failures are wrapped; the original exception is kept as the inner exception.
                throw Host.Except(e, $"Error initializing model :{e.ToString()}");
            }

            // No explicit column list means "use every input/output the model declares".
            Inputs      = (options.InputColumns.Length == 0) ? Model.InputNames.ToArray() : options.InputColumns;
            Outputs     = (options.OutputColumns.Length == 0) ? Model.OutputNames.ToArray() : options.OutputColumns;
            OutputTypes = new DataViewType[Outputs.Length];

            var outputsInfo = Model.ModelInfo.OutputsInfo;
            for (int i = 0; i < Outputs.Length; i++)
            {
                // Output columns are matched to model output nodes by name, not by position, because the
                // requested columns may be a subset or reordering of the model's outputs.
                var idx = Model.OutputNames.IndexOf(Outputs[i]);
                if (idx < 0)
                {
                    throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
                }

                var outputNodeInfo = outputsInfo[idx];
                var dims           = AdjustDimensions(outputNodeInfo.Shape);
                OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
            }
            _options = options;
        }
예제 #3
0
            /// <summary>
            /// Creates the getter delegate for the output column at position <paramref name="iinfo"/> of the
            /// parent transformer's <c>Outputs</c>.
            /// </summary>
            /// <param name="input">Source row; asserted to be non-null.</param>
            /// <param name="iinfo">Index into the parent's <c>Outputs</c> / <c>OutputTypes</c> arrays.</param>
            /// <param name="activeOutput">
            /// Predicate over output indices; the names of the outputs for which it returns true are handed to
            /// the typed getter.
            /// </param>
            /// <param name="disposer">Always set to null by this method.</param>
            /// <returns>The delegate produced by the generic <c>MakeGetter</c> for the output's raw item type.</returns>
            protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
            {
                disposer = null;
                Host.AssertValue(input);

                var outputCache          = new OutputCache();
                var activeOutputColNames = _parent.Outputs.Where((x, i) => activeOutput(i)).ToArray();

                // iinfo is a position in _parent.Outputs, which can be a user-chosen subset or reordering of
                // the model's outputs. Resolve the model's output node by name (the same lookup the transformer
                // constructor uses to fill OutputTypes) instead of indexing OutputsInfo with iinfo directly.
                var modelOutputIndex = _parent.Model.OutputNames.IndexOf(_parent.Outputs[iinfo]);
                Host.Assert(modelOutputIndex >= 0);
                var type = OnnxUtils.OnnxToMlNetType(_parent.Model.ModelInfo.OutputsInfo[modelOutputIndex].Type).RawType;

                Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);
                var srcNamedValueGetters = GetNamedOnnxValueGetters(input, _parent.Inputs, _inputColIndices, _isInputVector, _inputOnnxTypes, _inputTensorShapes);

                return Utils.MarshalInvoke(MakeGetter<int>, type, input, iinfo, srcNamedValueGetters, activeOutputColNames, outputCache);
            }
예제 #4
0
        /// <summary>
        /// Returns the <see cref="SchemaShape"/> of the schema which will be produced by the transformer.
        /// Used for schema propagation and verification in a pipeline.
        /// </summary>
        /// <param name="inputSchema">
        /// Incoming schema shape. Every transformer input must be present as a vector (fixed or variable
        /// size) column whose item type equals the type mapped from the model's input node.
        /// </param>
        /// <returns>
        /// The input columns plus one column per transformer output; an output whose name equals an existing
        /// column replaces that column.
        /// </returns>
        public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
        {
            Host.CheckValue(inputSchema, nameof(inputSchema));
            var resultDic = inputSchema.ToDictionary(x => x.Name);

            foreach (var input in Transformer.Inputs)
            {
                if (!inputSchema.TryFindColumn(input, out var col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
                }

                var isVector = col.Kind == SchemaShape.Column.VectorKind.VariableVector
                               || col.Kind == SchemaShape.Column.VectorKind.Vector;
                if (!isVector)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
                }

                // Inputs are matched to model input nodes by name.
                var idx = Transformer.Model.InputNames.IndexOf(input);
                if (idx < 0)
                {
                    throw Host.Except($"Column {input} doesn't match input node names of model.");
                }

                var inputNodeInfo = Transformer.Model.ModelInfo.InputsInfo[idx];
                var expectedType  = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
                if (col.ItemType != expectedType)
                {
                    throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
                }
            }

            for (var i = 0; i < Transformer.Outputs.Length; i++)
            {
                var outputName = Transformer.Outputs[i];
                var outputType = Transformer.OutputTypes[i];
                var vectorKind = outputType.IsKnownSizeVector()
                    ? SchemaShape.Column.VectorKind.Vector
                    : SchemaShape.Column.VectorKind.VariableVector;

                resultDic[outputName] = new SchemaShape.Column(outputName, vectorKind, outputType.GetItemType(), false);
            }
            return new SchemaShape(resultDic.Values);
        }
예제 #5
0
            /// <summary>
            /// Binds every input of <paramref name="parent"/> to a column of <paramref name="inputSchema"/>.
            /// For each input it records the column index, whether the column is a vector, the ONNX type and
            /// the adjusted tensor shape, and it validates the column against the model's input node.
            /// </summary>
            /// <param name="parent">Owning transformer; checked for null before its host is used.</param>
            /// <param name="inputSchema">Schema in which the input columns are looked up by name.</param>
            /// <remarks>
            /// Throws when an input name is not a model input node, when the column is missing from the schema,
            /// when a vector column has <c>Size == 0</c>, when the item type differs from the model's, or when
            /// the column's value count is not a multiple of the product of the known shape dimensions.
            /// </remarks>
            public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
                base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
            {
                _parent            = parent;
                _inputColIndices   = new int[_parent.Inputs.Length];
                _isInputVector     = new bool[_parent.Inputs.Length];
                _inputTensorShapes = new OnnxShape[_parent.Inputs.Length];
                _inputOnnxTypes    = new System.Type[_parent.Inputs.Length];

                var model = _parent.Model;

                for (int i = 0; i < _parent.Inputs.Length; i++)
                {
                    // Inputs are matched to model input nodes by name.
                    var idx = model.InputNames.IndexOf(_parent.Inputs[i]);
                    if (idx < 0)
                    {
                        throw Host.Except($"Column {_parent.Inputs[i]} doesn't match input node names of model");
                    }

                    var inputNodeInfo = model.ModelInfo.InputsInfo[idx];
                    var inputType     = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);

                    var inputShape = AdjustDimensions(inputNodeInfo.Shape);
                    _inputTensorShapes[i] = inputShape.ToList();
                    _inputOnnxTypes[i]    = inputNodeInfo.Type;

                    var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
                    if (!col.HasValue)
                    {
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i]);
                    }
                    _inputColIndices[i] = col.Value.Index;

                    var type       = inputSchema[_inputColIndices[i]].Type;
                    var vectorType = type as VectorType;
                    _isInputVector[i] = vectorType != null;

                    if (vectorType != null && vectorType.Size == 0)
                    {
                        throw Host.Except("Variable length input columns not supported");
                    }

                    if (type.GetItemType() != inputType)
                    {
                        throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputType.ToString(), type.ToString());
                    }

                    // Compute the total size of the known (positive) dimensions of the shape. The seed of 1
                    // keeps the product well-defined when no dimension is known; a seedless Aggregate would
                    // throw InvalidOperationException on an empty sequence.
                    int valCount = inputShape.Where(x => x > 0).Aggregate(1, (x, y) => x * y);
                    // The column length should be divisible by this, so that the other dimensions can be integral.
                    int typeValueCount = type.GetValueCount();
                    if (typeValueCount % valCount != 0)
                    {
                        throw Host.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
                    }
                }
            }