Esempio n. 1
0
        /// <summary>
        /// Creates the transformer, loading the ONNX model either from <paramref name="modelBytes"/> or,
        /// when that is null, from the file named by <c>args.ModelFile</c>, and computes one output
        /// column type per requested output.
        /// </summary>
        /// <param name="env">Host environment used to register this component; must not be null.</param>
        /// <param name="args">
        /// Transformer arguments. Empty <c>InputColumns</c>/<c>OutputColumns</c> fall back to the
        /// model's own input/output names.
        /// </param>
        /// <param name="modelBytes">Serialized model; when null the model is read from <c>args.ModelFile</c>.</param>
        private OnnxTransformer(IHostEnvironment env, Arguments args, byte[] modelBytes = null) :
            base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
        {
            Host.CheckValue(args, nameof(args));
            // Report a named argument error instead of a bare NullReferenceException from the loops below.
            Host.CheckValue(args.InputColumns, nameof(args.InputColumns));
            Host.CheckValue(args.OutputColumns, nameof(args.OutputColumns));

            foreach (var col in args.InputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(args.InputColumns));
            }
            foreach (var col in args.OutputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(args.OutputColumns));
            }

            try
            {
                if (modelBytes == null)
                {
                    Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
                    Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
                    Model = new OnnxModel(args.ModelFile, args.GpuDeviceId, args.FallbackToCpu);
                }
                else
                {
                    Model = OnnxModel.CreateFromBytes(modelBytes, args.GpuDeviceId, args.FallbackToCpu);
                }
            }
            catch (OnnxRuntimeException e)
            {
                // Re-raise through the host so the failure carries this component's context.
                throw Host.Except(e, $"Error initializing model :{e}");
            }

            // No explicit columns means "use every input/output the model declares".
            Inputs      = args.InputColumns.Any() ? args.InputColumns : Model.InputNames.ToArray();
            Outputs     = args.OutputColumns.Any() ? args.OutputColumns : Model.OutputNames.ToArray();
            OutputTypes = new ColumnType[Outputs.Length];

            var outputsInfo = Model.ModelInfo.OutputsInfo;
            for (int i = 0; i < Outputs.Length; i++)
            {
                // Each requested output must name an output node of the model.
                var idx = Model.OutputNames.IndexOf(Outputs[i]);
                if (idx < 0)
                {
                    throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
                }

                var outputNodeInfo = outputsInfo[idx];
                var dims           = AdjustDimensions(outputNodeInfo.Shape);
                OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
            }
            _args = args;
        }
        /// <summary>
        /// Creates the transform, loading the ONNX model either from <paramref name="modelBytes"/> or,
        /// when that is null, from the file named by <c>args.ModelFile</c>, and computes one output
        /// column type per entry of <c>args.OutputColumns</c>.
        /// </summary>
        /// <param name="env">Host environment used to register this component; must not be null.</param>
        /// <param name="args">Transform arguments: input/output column names and the model file path.</param>
        /// <param name="modelBytes">Serialized model; when null the model is read from <c>args.ModelFile</c>.</param>
        private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = null) :
            base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransform)))
        {
            Host.CheckValue(args, nameof(args));
            // Report a named argument error instead of a bare NullReferenceException from the loops below.
            Host.CheckValue(args.InputColumns, nameof(args.InputColumns));
            Host.CheckValue(args.OutputColumns, nameof(args.OutputColumns));

            foreach (var col in args.InputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(args.InputColumns));
            }
            foreach (var col in args.OutputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(args.OutputColumns));
            }

            if (modelBytes == null)
            {
                Host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
                Host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
                Model = new OnnxModel(args.ModelFile);
            }
            else
            {
                Model = OnnxModel.CreateFromBytes(modelBytes);
            }

            Inputs      = args.InputColumns;
            Outputs     = args.OutputColumns;
            OutputTypes = new ColumnType[args.OutputColumns.Length];

            var outputsInfo = Model.ModelInfo.OutputsInfo;
            for (int i = 0; i < args.OutputColumns.Length; i++)
            {
                // Each requested output must name an output node of the model.
                var idx = Model.OutputNames.IndexOf(args.OutputColumns[i]);
                if (idx < 0)
                {
                    throw Host.Except($"Column {args.OutputColumns[i]} doesn't match output node names of model");
                }

                var outputNodeInfo = outputsInfo[idx];
                var dims           = AdjustDimensions(outputNodeInfo.Shape);
                OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims);
            }
            _args = args;
        }
        /// <summary>
        /// Creates the single-input, single-output transform, loading the ONNX model either from
        /// <paramref name="modelBytes"/> or, when that is null, from the file named by
        /// <c>args.ModelFile</c>.
        /// </summary>
        /// <param name="env">Host environment used to register this component; must not be null.</param>
        /// <param name="args">Transform arguments: input column, output column and the model file path.</param>
        /// <param name="modelBytes">Serialized model; when null the model is read from <c>args.ModelFile</c>.</param>
        /// <remarks>Throws when the loaded model does not have exactly one input and exactly one output.</remarks>
        private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes = null)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register(RegistrationName);
            _host.CheckValue(args, nameof(args));
            _host.CheckNonWhiteSpace(args.InputColumn, nameof(args.InputColumn));
            _host.CheckNonWhiteSpace(args.OutputColumn, nameof(args.OutputColumn));

            if (modelBytes == null)
            {
                _host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile));
                _host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile));
                Model = new OnnxModel(args.ModelFile);
            }
            else
            {
                Model = OnnxModel.CreateFromBytes(modelBytes);
            }

            var modelInfo = Model.ModelInfo;

            // C# interpolation holes are written {expr}; a leading '$' inside the literal would be
            // printed verbatim in front of the count.
            if (modelInfo.InputsInfo.Length != 1)
            {
                throw _host.Except($"OnnxTransform supports Onnx models with one input. The provided model has {modelInfo.InputsInfo.Length} input(s).");
            }
            if (modelInfo.OutputsInfo.Length != 1)
            {
                throw _host.Except($"OnnxTransform supports Onnx models with one output. The provided model has {modelInfo.OutputsInfo.Length} output(s).");
            }

            Input  = args.InputColumn;
            Output = args.OutputColumn;

            var outputNodeInfo = modelInfo.OutputsInfo[0];
            var type           = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
            var shape          = outputNodeInfo.Shape;

            // A negative leading dimension is dropped (presumably an unknown batch size — confirm);
            // an empty shape maps to a single dimension of 0.
            var dims = shape.Count > 0
                ? shape.Skip(shape[0] < 0 ? 1 : 0).Select(x => (int)x).ToArray()
                : new[] { 0 };

            OutputType = new VectorType(type, dims);
            _args      = args;
        }