/// <summary>
/// Entry point: configures an OnnxRuntime <c>TrainingSession</c> for the model named on the
/// command line, loads the MNIST data set from <c>c:\temp\data</c>, and runs training.
/// </summary>
/// <param name="args">
/// Command line arguments; parsed by <c>getArgs</c> into the model path, the batch size
/// (used for both training and evaluation) and the number of training steps.
/// </param>
static void Main(string[] args)
{
    string strModel;
    int nBatch;
    int nSteps;

    // getArgs signals invalid/missing arguments by returning false.
    if (!getArgs(args, out strModel, out nBatch, out nSteps))
    {
        return;
    }

    // Create the training session.
    TrainingSession session = new TrainingSession();

    // Dispose the session even if configuration, data loading or training throws;
    // previously Dispose() was only reached on the success path.
    try
    {
        // Set the error callback function called at the end of the forward pass.
        session.Parameters.OnErrorFunction += OnErrorFunction;
        // Set the evaluation callback function called after the error function on 'DISPLAY_LOSS_STEPS'
        session.Parameters.OnEvaluationFunction += OnEvaluationFunction;
        // Set the training and testing data batch callbacks called to get a new batch of data.
        session.Parameters.OnGetTrainingDataBatch += OnGetTrainingDataBatch;
        session.Parameters.OnGetTestingDataBatch += OnGetTestingDataBatch;

        // Setup the training parameters.
        session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_MODEL_PATH, strModel);
        session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_INPUT_LABELS, "labels");
        session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_OUTPUT_LOSS, "loss");
        session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_OUTPUT_PREDICTIONS, "predictions");
        session.Parameters.SetTrainingParameter(OrtTrainingStringParameter.ORT_TRAINING_LOG_PATH, "c:\\temp");
        session.Parameters.SetTrainingParameter(OrtTrainingBooleanParameter.ORT_TRAINING_USE_CUDA, true);
        session.Parameters.SetTrainingParameter(OrtTrainingBooleanParameter.ORT_TRAINING_SHUFFLE_DATA, false);
        session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_EVAL_BATCH_SIZE, nBatch);
        session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_TRAIN_BATCH_SIZE, nBatch);
        session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_NUM_TRAIN_STEPS, nSteps);
        session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_EVAL_PERIOD, 1);
        session.Parameters.SetTrainingParameter(OrtTrainingLongParameter.ORT_TRAINING_DISPLAY_LOSS_STEPS, 400);
        session.Parameters.SetTrainingParameter(OrtTrainingNumericParameter.ORT_TRAINING_LEARNING_RATE, 0.01);
        session.Parameters.SetTrainingOptimizer(OrtTrainingOptimizer.ORT_TRAINING_OPTIMIZER_SGD);
        session.Parameters.SetTrainingLossFunction(OrtTrainingLossFunction.ORT_TRAINING_LOSS_FUNCTION_SOFTMAXCROSSENTROPY);
        session.Parameters.SetupTrainingParameters();

        // Setup the training data information (model input name followed by the label name).
        session.Parameters.SetupTrainingData(new List<string>() { "X", "labels" });

        // Load the MNIST dataset from file. See http://yann.lecun.com/exdb/mnist/ to get data files.
        // The batch callbacks above read from these fields, so they are filled (and the read
        // cursors reset) before the session is initialized and run.
        MnistDataLoaderLite dataLoader = new MnistDataLoaderLite("c:\\temp\\data");
        dataLoader.ExtractImages(out m_rgTrainingData, out m_rgTestingData);
        m_nTrainingDataIdx = 0;
        m_nTestingDataIdx = 0;

        // Setup the OnnxRuntime instance.
        OrtEnv.SetLogLevel(LogLevel.Warning);
        OrtEnv env = OrtEnv.Instance();

        // Initialize the training session.
        session.Initialize(env);

        // Run the training session.
        session.RunTraining();
        session.EndTraining();
    }
    finally
    {
        // Cleanup.
        session.Dispose();
    }

    Console.WriteLine("Done!");

    // Console.ReadKey throws InvalidOperationException when stdin is redirected
    // (e.g. when launched from a script or CI), so only pause on an interactive console.
    if (!Console.IsInputRedirected)
    {
        Console.WriteLine("press any key to exit...");
        Console.ReadKey();
    }
}
/// <summary>
/// Registers the platform-specific custom-op shared library on a <c>SessionOptions</c> via
/// <c>RegisterCustomOpLibraryV2</c>, runs <c>custom_op_test.onnx</c> with two 3x5 float inputs,
/// and checks that the single output named "output" equals the expected 3x5 int tensor.
/// Uses the CUDA execution provider when the runtime reports it as available.
/// </summary>
private void TestRegisterCustomOpLibrary()
{
    using (var option = new SessionOptions())
    {
        string modelPath = "custom_op_test.onnx";

        // Pick the shared-library file name for the current OS; the Windows name is also
        // the fallback for any platform not matched below.
        string libName;
        if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
        {
            libName = "libcustom_op_library.so";
        }
        else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
        {
            libName = "libcustom_op_library.dylib";
        }
        else
        {
            libName = "custom_op_library.dll";
        }

        string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), libName);
        Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist.");

        var ortEnvInstance = OrtEnv.Instance();
        string[] providers = ortEnvInstance.GetAvailableProviders();
        if (Array.Exists(providers, provider => provider == "CUDAExecutionProvider"))
        {
            option.AppendExecutionProvider_CUDA(0);
        }

        IntPtr libraryHandle = IntPtr.Zero;
        try
        {
            option.RegisterCustomOpLibraryV2(libFullPath, out libraryHandle);
        }
        catch (Exception ex)
        {
            var msg = $"Failed to load custom op library {libFullPath}, error = {ex.Message}";
            // Keep the original exception as InnerException instead of flattening only
            // its stack trace into the message text.
            throw new Exception(msg, ex);
        }

        try
        {
            using (var session = new InferenceSession(modelPath, option))
            {
                var inputContainer = new List<NamedOnnxValue>
                {
                    NamedOnnxValue.CreateFromTensor<float>("input_1", new DenseTensor<float>(
                        new float[] { 1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
                                      6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
                                      11.1f, 12.2f, 13.3f, 14.4f, 15.5f },
                        new int[] { 3, 5 })),
                    NamedOnnxValue.CreateFromTensor<float>("input_2", new DenseTensor<float>(
                        new float[] { 15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
                                      10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
                                      5.5f, 4.4f, 3.3f, 2.2f, 1.1f },
                        new int[] { 3, 5 })),
                };

                using (var result = session.Run(inputContainer))
                {
                    var firstOutput = result.First();
                    Assert.Equal("output", firstOutput.Name);

                    var tensorOut = firstOutput.AsTensor<int>();
                    var expectedOut = new DenseTensor<int>(
                        new int[] { 17, 17, 17, 17, 17,
                                    17, 18, 18, 18, 17,
                                    17, 17, 17, 17, 17 },
                        new int[] { 3, 5 });
                    Assert.True(tensorOut.SequenceEqual(expectedOut));
                }
            }
        }
        finally
        {
            // The session is disposed by this point, so it is safe to unload the custom op
            // shared library. Doing it in finally means a failing Run/assertion no longer
            // leaves the library loaded.
            UnloadLibrary(libraryHandle);
        }
    }
}