예제 #1
0
        /// <summary>
        /// Loads the ONNX model (from <c>options.ModelFile</c>, or from <paramref name="modelBytes"/> when supplied)
        /// and resolves the transformer's input column names, output column names and output column types.
        /// </summary>
        /// <param name="env">Host environment used to register this transformer's host.</param>
        /// <param name="options">Model location, input/output column names, GPU device id and CPU-fallback setting.</param>
        /// <param name="modelBytes">In-memory model; when <see langword="null"/> the model is read from <c>options.ModelFile</c>.</param>
        /// <remarks>
        /// When <c>options.InputColumns</c> / <c>options.OutputColumns</c> are empty, all of the model's input / output
        /// node names are used instead. Errors raised by ONNX Runtime while loading the model, and requested output
        /// columns that are not output nodes of the model, are reported through <c>Host.Except</c>.
        /// </remarks>
        private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null) :
            base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
        {
            Host.CheckValue(options, nameof(options));

            foreach (var col in options.InputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(options.InputColumns));
            }
            foreach (var col in options.OutputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns));
            }

            try
            {
                if (modelBytes == null)
                {
                    // The model is given as a path on disk: validate the path before handing it to ONNX Runtime.
                    Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
                    Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
                    Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu);
                }
                else
                {
                    // The model is given in memory.
                    Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu);
                }
            }
            catch (OnnxRuntimeException e)
            {
                // Re-raise ONNX Runtime failures through the host so they carry the host's exception context.
                throw Host.Except(e, $"Error initializing model :{e.ToString()}");
            }

            // An empty column list means "use every input/output node the model declares".
            Inputs = options.InputColumns.Length == 0 ? Model.InputNames.ToArray() : options.InputColumns;
            Outputs = options.OutputColumns.Length == 0 ? Model.OutputNames.ToArray() : options.OutputColumns;
            OutputTypes = new DataViewType[Outputs.Length];

            var outputsInfo = Model.ModelInfo.OutputsInfo;
            for (int i = 0; i < Outputs.Length; i++)
            {
                // Locate the model output node that has the requested column's name.
                var idx = Model.OutputNames.IndexOf(Outputs[i]);
                if (idx < 0)
                {
                    throw Host.Except($"Column {Outputs[i]} doesn't match output node names of model");
                }

                // The column type is a vector of the node's element type, with dimensions derived from the node shape.
                var outputNodeInfo = outputsInfo[idx];
                var dims = AdjustDimensions(outputNodeInfo.Shape);
                OutputTypes[i] = new VectorType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
            }
            _options = options;
        }
예제 #2
0
        /// <summary>
        /// Loads the ONNX model (from <c>options.ModelFile</c>, or from <paramref name="modelBytes"/> when supplied)
        /// and resolves the transformer's input column names, output column names and output column types.
        /// </summary>
        /// <param name="env">Host environment used to register this transformer's host.</param>
        /// <param name="options">Model location, input/output column names, GPU device id and CPU-fallback setting.</param>
        /// <param name="modelBytes">In-memory model; when <see langword="null"/> the model is read from <c>options.ModelFile</c>.</param>
        /// <remarks>
        /// When <c>options.InputColumns</c> / <c>options.OutputColumns</c> are empty, all of the model's input / output
        /// names are used instead. Errors raised by ONNX Runtime while loading the model are reported through
        /// <c>Host.Except</c>. Output column types are taken from <c>Model.ModelInfo.GetOutput</c>.
        /// </remarks>
        private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null) :
            base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
        {
            Host.CheckValue(options, nameof(options));

            foreach (var col in options.InputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(options.InputColumns));
            }
            foreach (var col in options.OutputColumns)
            {
                Host.CheckNonWhiteSpace(col, nameof(options.OutputColumns));
            }

            // Use ONNX Runtime to determine the model's input and output configuration. ONNX Runtime does not
            // provide strongly-typed accessors for the produced variables, so type information is obtained by
            // inspecting the ONNX model itself.
            try
            {
                if (modelBytes == null)
                {
                    // The model file path was supplied by the user.
                    Host.CheckNonWhiteSpace(options.ModelFile, nameof(options.ModelFile));
                    Host.CheckIO(File.Exists(options.ModelFile), "Model file {0} does not exists.", options.ModelFile);
                    // The file belongs to the user and must not be deleted, so the model must not own it.
                    Model = new OnnxModel(options.ModelFile, options.GpuDeviceId, options.FallbackToCpu, ownModelFile: false);
                }
                else
                {
                    // The model was supplied as a byte[]. To feed it to ONNX Runtime, the bytes are written to a
                    // temporary file that ONNX Runtime then loads.
                    Model = OnnxModel.CreateFromBytes(modelBytes, options.GpuDeviceId, options.FallbackToCpu);
                }
            }
            catch (OnnxRuntimeException e)
            {
                // Re-raise ONNX Runtime failures through the host so they carry the host's exception context.
                throw Host.Except(e, $"Error initializing model :{e.ToString()}");
            }

            // An empty column list means "use every input/output the model declares".
            Inputs = options.InputColumns.Length == 0 ? Model.ModelInfo.InputNames.ToArray() : options.InputColumns;
            Outputs = options.OutputColumns.Length == 0 ? Model.ModelInfo.OutputNames.ToArray() : options.OutputColumns;
            OutputTypes = new DataViewType[Outputs.Length];

            for (int i = 0; i < Outputs.Length; i++)
            {
                // Each output column's ML.NET type comes from the matching model output's metadata.
                var outputInfo = Model.ModelInfo.GetOutput(Outputs[i]);
                OutputTypes[i] = outputInfo.DataViewType;
            }
            _options = options;
        }