Esempio n. 1
0
        /// <summary>
        /// Registers the platform-specific custom op shared library on a <see cref="SessionOptions"/>,
        /// runs "custom_op_test.onnx" on two fixed 3x5 float inputs ("input_1", "input_2") and checks
        /// that the first output is named "output" and holds the expected 3x5 int tensor.
        /// The CUDA execution provider is appended when it is reported as available.
        /// </summary>
        /// <remarks>
        /// The library path is built from the current working directory, and the relative model path
        /// is likewise resolved against it. The native library handle is released via
        /// <c>UnloadLibrary</c> in a <c>finally</c>, after the session has been disposed, so it is
        /// unloaded even when session creation, inference, or an assertion fails.
        /// </remarks>
        private void TestRegisterCustomOpLibrary()
        {
            using (var option = new SessionOptions())
            {
                // Windows naming is the default and also the fallback for unrecognised platforms.
                string libName = "custom_op_library.dll";
                if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
                {
                    libName = "libcustom_op_library.so";
                }
                else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
                {
                    libName = "libcustom_op_library.dylib";
                }

                string modelPath = "custom_op_test.onnx";
                string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName);
                Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist.");

                string[] providers = OrtEnv.Instance().GetAvailableProviders();
                if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider"))
                {
                    option.AppendExecutionProvider_CUDA(0);
                }

                IntPtr libraryHandle = IntPtr.Zero;
                try
                {
                    option.RegisterCustomOpLibraryV2(libFullPath, out libraryHandle);
                }
                catch (Exception ex)
                {
                    // Keep the original exception as InnerException so its type and stack trace survive.
                    throw new InvalidOperationException(
                        $"Failed to load custom op library {libFullPath}, error = {ex.Message}", ex);
                }

                try
                {
                    using (var session = new InferenceSession(modelPath, option))
                    {
                        var inputContainer = new List<NamedOnnxValue>
                        {
                            NamedOnnxValue.CreateFromTensor<float>(
                                "input_1",
                                new DenseTensor<float>(
                                    new float[]
                                    {
                                        1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
                                        6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
                                        11.1f, 12.2f, 13.3f, 14.4f, 15.5f
                                    },
                                    new int[] { 3, 5 })),
                            NamedOnnxValue.CreateFromTensor<float>(
                                "input_2",
                                new DenseTensor<float>(
                                    new float[]
                                    {
                                        15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
                                        10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
                                        5.5f, 4.4f, 3.3f, 2.2f, 1.1f
                                    },
                                    new int[] { 3, 5 }))
                        };

                        using (var result = session.Run(inputContainer))
                        {
                            var output = result.First();
                            Assert.Equal("output", output.Name);
                            var tensorOut = output.AsTensor<int>();

                            // Each expected value equals the element-wise sum of the two inputs
                            // (16.6 or 17.6) rounded to the nearest integer; the op's exact
                            // semantics are defined in the native library, not here.
                            var expectedOut = new DenseTensor<int>(
                                new int[]
                                {
                                    17, 17, 17, 17, 17,
                                    17, 18, 18, 18, 17,
                                    17, 17, 17, 17, 17
                                },
                                new int[] { 3, 5 });
                            Assert.True(
                                tensorOut.SequenceEqual(expectedOut),
                                $"Unexpected output: [{string.Join(", ", tensorOut)}]");
                        }
                    }
                }
                finally
                {
                    // The session's using block has exited, so the session is disposed and the
                    // custom op shared library can be unloaded, even if anything above threw.
                    UnloadLibrary(libraryHandle);
                }
            }
        }