Example #1
        /// <summary>
        /// Initialize the training session using the OrtEnv.
        /// </summary>
        /// <param name="env">Specifies the OrtEnv to use.</param>
        public void Initialize(OrtEnv env)
        {
            NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtInitializeTraining(
                env.DangerousGetHandle(),
                m_param.DangerousGetHandle(),
                m_expectedInputs.DangerousGetHandle(),
                m_expectedOutputs.DangerousGetHandle()));

            m_param.ExpectedInputs = getTensorDefs(m_expectedInputs);
            m_param.ExpectedOutputs = getTensorDefs(m_expectedOutputs);
        }
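
For context, a minimal call sequence around Initialize might look like the sketch below. It is based on the TrainingSession usage shown in Example #2 and assumes the session's parameters and callbacks have already been configured; it is not a complete program.

        // Minimal usage sketch (assumes session.Parameters has already been set up,
        // as shown in Example #2).
        OrtEnv.SetLogLevel(LogLevel.Warning);
        OrtEnv env = OrtEnv.Instance();

        TrainingSession session = new TrainingSession();
        // ... configure session.Parameters (callbacks, training parameters, data) ...
        session.Initialize(env);   // calls OrtInitializeTraining and fills ExpectedInputs/ExpectedOutputs
        session.RunTraining();
        session.EndTraining();
        session.Dispose();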
Example #2
        static void Main(string[] args)
        {
            string strModel;
            int    nBatch;
            int    nSteps;

            if (!getArgs(args, out strModel, out nBatch, out nSteps))
            {
                return;
            }

            // Create the training session.
            TrainingSession session = new TrainingSession();

            // Set the error callback function called at the end of the forward pass.
            session.Parameters.OnErrorFunction += OnErrorFunction;
            // Set the evaluation callback function, called after the error function at each 'ORT_TRAINING_DISPLAY_LOSS_STEPS' interval.
            session.Parameters.OnEvaluationFunction += OnEvaluationFunction;
            // Set the training and testing data batch callbacks called to get a new batch of data.
            session.Parameters.OnGetTrainingDataBatch += OnGetTrainingDataBatch;
            session.Parameters.OnGetTestingDataBatch  += OnGetTestingDataBatch;

            // Setup the training parameters.
            session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_MODEL_PATH, strModel);
            session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_INPUT_LABELS, "labels");
            session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_OUTPUT_LOSS, "loss");
            session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_OUTPUT_PREDICTIONS, "predictions");
            session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_LOG_PATH, "c:\\temp");
            session.Parameters.SetTrainingParameter(OrtTrainingBooleanParameter.ORT_TRAINING_USE_CUDA, true);
            session.Parameters.SetTrainingParameter(OrtTrainingBooleanParameter.ORT_TRAINING_SHUFFLE_DATA, false);
            session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_EVAL_BATCH_SIZE, nBatch);
            session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_TRAIN_BATCH_SIZE, nBatch);
            session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_NUM_TRAIN_STEPS, nSteps);
            session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_EVAL_PERIOD, 1);
            session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_DISPLAY_LOSS_STEPS, 400);
            session.Parameters.SetTrainingParameter(OrtTrainingNumericParameter.ORT_TRAINING_LEARNING_RATE, 0.01);
            session.Parameters.SetTrainingOptimizer(OrtTrainingOptimizer.ORT_TRAINING_OPTIMIZER_SGD);
            session.Parameters.SetTrainingLossFunction(OrtTrainingLossFunction.ORT_TRAINING_LOSS_FUNCTION_SOFTMAXCROSSENTROPY);
            session.Parameters.SetupTrainingParameters();

            // Setup the training data information.
            session.Parameters.SetupTrainingData(new List<string>() { "X", "labels" });

            // Load the MNIST dataset from file. See http://yann.lecun.com/exdb/mnist/ to get data files.
            MnistDataLoaderLite dataLoader = new MnistDataLoaderLite("c:\\temp\\data");

            dataLoader.ExtractImages(out m_rgTrainingData, out m_rgTestingData);
            m_nTrainingDataIdx = 0;
            m_nTestingDataIdx  = 0;

            // Setup the OnnxRuntime instance.
            OrtEnv.SetLogLevel(LogLevel.Warning);
            OrtEnv env = OrtEnv.Instance();

            // Initialize the training session.
            session.Initialize(env);

            // Run the training session.
            session.RunTraining();
            session.EndTraining();

            // Cleanup.
            session.Dispose();

            Console.WriteLine("Done!");
            Console.WriteLine("press any key to exit...");
            Console.ReadKey();
        }
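
The getArgs helper used at the top of Main is not part of this listing, but its signature follows directly from the call site. A minimal sketch follows; the usage text and the default values for nBatch and nSteps are assumptions, not taken from the original program.

        // Hypothetical command-line parser matching the call in Main above.
        // Expected usage: <model.onnx> [batchSize] [numSteps]
        static bool getArgs(string[] args, out string strModel, out int nBatch, out int nSteps)
        {
            strModel = null;
            nBatch   = 64;    // assumed default batch size
            nSteps   = 2000;  // assumed default number of training steps

            if (args.Length < 1)
            {
                Console.WriteLine("Usage: <model.onnx> [batchSize] [numSteps]");
                return false;
            }

            strModel = args[0];

            if (args.Length > 1 && !int.TryParse(args[1], out nBatch))
                return false;

            if (args.Length > 2 && !int.TryParse(args[2], out nSteps))
                return false;

            return true;
        }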
Example #3
        private void TestRegisterCustomOpLibrary()
        {
            using (var option = new SessionOptions())
            {
                string libName   = "custom_op_library.dll";
                string modelPath = "custom_op_test.onnx";
                if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
                {
                    libName = "custom_op_library.dll";
                }
                else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
                {
                    libName = "libcustom_op_library.so";
                }
                else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
                {
                    libName = "libcustom_op_library.dylib";
                }

                string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName);
                Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist.");

                var      ortEnvInstance = OrtEnv.Instance();
                string[] providers      = ortEnvInstance.GetAvailableProviders();
                if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider"))
                {
                    option.AppendExecutionProvider_CUDA(0);
                }

                IntPtr libraryHandle = IntPtr.Zero;
                try
                {
                    option.RegisterCustomOpLibraryV2(libFullPath, out libraryHandle);
                }
                catch (Exception ex)
                {
                    var msg = $"Failed to load custom op library {libFullPath}, error = {ex.Message}";
                    throw new Exception(msg + "\n" + ex.StackTrace);
                }

                using (var session = new InferenceSession(modelPath, option))
                {
                    var inputContainer = new List<NamedOnnxValue>();
                    inputContainer.Add(NamedOnnxValue.CreateFromTensor<float>("input_1",
                        new DenseTensor<float>(
                            new float[]
                            {
                                1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
                                6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
                                11.1f, 12.2f, 13.3f, 14.4f, 15.5f
                            },
                            new int[] { 3, 5 })));

                    inputContainer.Add(NamedOnnxValue.CreateFromTensor<float>("input_2",
                        new DenseTensor<float>(
                            new float[]
                            {
                                15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
                                10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
                                5.5f, 4.4f, 3.3f, 2.2f, 1.1f
                            },
                            new int[] { 3, 5 })));

                    using (var result = session.Run(inputContainer))
                    {
                        Assert.Equal("output", result.First().Name);
                        var tensorOut = result.First().AsTensor<int>();

                        var expectedOut = new DenseTensor<int>(
                            new int[]
                            {
                                17, 17, 17, 17, 17,
                                17, 18, 18, 18, 17,
                                17, 17, 17, 17, 17
                            },
                            new int[] { 3, 5 });
                        Assert.True(tensorOut.SequenceEqual(expectedOut));
                    }
                }

                // Safe to unload the custom op shared library now
                UnloadLibrary(libraryHandle);
            }
        }
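
The UnloadLibrary helper called at the end of the test is not shown in this listing. A cross-platform sketch is given below; it assumes .NET Core 3.0 or later, where System.Runtime.InteropServices.NativeLibrary is available, and is not necessarily how the original test implements it.

        // Sketch of a cross-platform unload helper for the handle returned by
        // RegisterCustomOpLibraryV2 (assumes .NET Core 3.0+ / NativeLibrary support).
        static void UnloadLibrary(IntPtr libraryHandle)
        {
            if (libraryHandle != IntPtr.Zero)
            {
                // Releases the native custom-op module once no session is using it.
                System.Runtime.InteropServices.NativeLibrary.Free(libraryHandle);
            }
        }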