/// <summary>
/// Initialize the training session using the OrtEnv.
/// </summary>
/// <remarks>
/// Passes the environment handle, the training parameters handle and the expected
/// input/output collection handles to the native OrtInitializeTraining call, then fills
/// m_param.ExpectedInputs / m_param.ExpectedOutputs from those collections via getTensorDefs.
/// NOTE(review): the collections are presumably populated by the native call — confirm.
/// </remarks>
/// <param name="env">Specifies the OrtEnv to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="env"/> is null.</exception>
public void Initialize(OrtEnv env)
{
    // Report a null environment explicitly instead of a NullReferenceException from DangerousGetHandle.
    if (env == null)
    {
        throw new ArgumentNullException(nameof(env));
    }

    NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtInitializeTraining(env.DangerousGetHandle(), m_param.DangerousGetHandle(), m_expectedInputs.DangerousGetHandle(), m_expectedOutputs.DangerousGetHandle()));

    m_param.ExpectedInputs = getTensorDefs(m_expectedInputs);
    m_param.ExpectedOutputs = getTensorDefs(m_expectedOutputs);
}
/// <summary>
/// Setup the training parameters and register the error and evaluation callbacks.
/// </summary>
/// <remarks>
/// A newly generated GUID string is serialized and handed to the native
/// OrtSetupTrainingParameters call together with m_fnErrorFunction and m_fnEvaluateFunction.
/// NOTE(review): what the native layer uses this key for is not visible here — confirm.
/// </remarks>
public void SetupTrainingParameters()
{
    byte[] rgKey = NativeMethods.GetPlatformSerializedString(System.Guid.NewGuid().ToString());
    var status = NativeMethodsTraining.OrtSetupTrainingParameters(_nativeHandle, m_fnErrorFunction, m_fnEvaluateFunction, rgKey);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Return the long based training parameter.
/// </summary>
/// <param name="key">Specifies the key of the value to get.</param>
/// <returns>The long based value is returned.</returns>
public long GetTrainingParameter(OrtTrainingLongParameter key)
{
    UIntPtr val = UIntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtGetParameter_long(_nativeHandle, key, out val));

    // Convert through ulong explicitly. On frameworks where UIntPtr is not nuint, a direct
    // (long) cast binds to the UIntPtr->uint operator, which throws OverflowException (64-bit)
    // or truncates for values outside the uint range. The ulong->long reinterpretation keeps
    // all 64 bits, so e.g. a negative long stored natively as an unsigned word round-trips.
    // NOTE(review): assumes the native side writes the value as a pointer-sized unsigned integer — confirm.
    return unchecked((long)val.ToUInt64());
}
/// <summary>
/// Setup the training data and connect the training/testing data batch callbacks.
/// </summary>
/// <param name="rgstrFeedNames">Specifies a list of the data feed names. Each name is sent to
/// the native layer followed by a ';' terminator (e.g. ["a", "b"] becomes "a;b;"; an empty list
/// becomes "").</param>
public void SetupTrainingData(List<string> rgstrFeedNames)
{
    // Terminate every feed name with ';' and join the pieces without any extra separator.
    List<string> rgstrTerminated = rgstrFeedNames.ConvertAll(strName => strName + ";");
    string strFeedNames = string.Concat(rgstrTerminated);

    byte[] rgFeedNames = NativeMethods.GetPlatformSerializedString(strFeedNames);
    var status = NativeMethodsTraining.OrtSetupTrainingData(_nativeHandle, m_fnGetTrainingData, m_fnGetTestingData, rgFeedNames);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Create an OrtValueCollection. This object is not a native collection itself; instead it
/// gives access to a group of native OrtValues through its GetAt and SetAt methods.
/// </summary>
/// <param name="h">Specifies the handle to the native OrtValueCollection to use, or IntPtr.Zero.
/// When IntPtr.Zero, a new native value collection is created and owned (and later released) by
/// this instance; otherwise the given collection is only borrowed and is never released here.</param>
/// <remarks>
/// For efficiency, the OrtValueCollection gives access to a set of OrtValues where
/// each OrtValue does not actually own the memory but instead points to one or
/// more pre-allocated OrtValues.
/// </remarks>
public OrtValueCollection(IntPtr h)
{
    if (h != IntPtr.Zero)
    {
        // Borrowed handle: the creator of the native collection stays responsible for it.
        _nativeHandle = h;
        _ownsHandle = false;
        return;
    }

    // No handle supplied: create a native collection; ownership is recorded only after creation succeeds.
    NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtCreateValueCollection(out _nativeHandle));
    _ownsHandle = true;
}
/// <summary>
/// Return the numeric (double) based training parameter.
/// </summary>
/// <param name="key">Specifies the key of the value to get.</param>
/// <returns>The double based value is returned.</returns>
/// <remarks>
/// The native layer returns the value as a UTF-8 string allocated through the default
/// allocator. The string is copied into managed memory and then parsed as a double using the
/// invariant culture.
/// </remarks>
public double GetTrainingParameter(OrtTrainingNumericParameter key)
{
    string str = null;
    var allocator = OrtAllocator.DefaultInstance;
    IntPtr valHandle = IntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtGetNumericParameter(_nativeHandle, key, allocator.Pointer, out valHandle));

    // The OrtMemoryAllocation wrapper is disposed right after the copy;
    // presumably this releases the natively allocated string — confirm.
    using (var ortAllocation = new OrtMemoryAllocation(allocator, valHandle, 0))
    {
        str = NativeOnnxValueHelper.StringFromNativeUtf8(valHandle);
    }

    // The text comes from the native layer, not from the user, so parse it culture-independently.
    // With the current culture (e.g. one using ',' as the decimal separator) "0.001" could be
    // misread or rejected.
    return double.Parse(str, System.Globalization.CultureInfo.InvariantCulture);
}
/// <summary>
/// Returns the optimizer used.
/// </summary>
/// <returns>The optimizer used is returned.</returns>
/// <exception cref="InvalidOperationException">Thrown when the native layer reports an optimizer
/// value this wrapper does not recognize.</exception>
public OrtTrainingOptimizer GetTrainingOptimizer()
{
    UIntPtr val = UIntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtGetTrainingOptimizer(_nativeHandle, out val));

    OrtTrainingOptimizer opt = (OrtTrainingOptimizer)(int)val;
    switch (opt)
    {
        case OrtTrainingOptimizer.ORT_TRAINING_OPTIMIZER_SGD:
            return opt;

        default:
            // A specific exception type (still an Exception, so existing catch (Exception) handlers keep working).
            throw new InvalidOperationException("Unknown optimizer '" + val.ToString() + "'!");
    }
}
/// <summary>
/// Returns the OrtValue stored at a given index together with its name.
/// </summary>
/// <param name="nIdx">Specifies the index to get.</param>
/// <param name="strName">Returns the name of the OrtValue.</param>
/// <returns>A new OrtValue wrapper around the native value at the index. It is constructed with
/// a second argument of false (presumably non-owning — confirm against the OrtValue constructor).</returns>
public OrtValue GetAt(int nIdx, out string strName)
{
    var allocator = OrtAllocator.DefaultInstance;
    IntPtr hData;
    IntPtr hName;
    var status = NativeMethodsTraining.OrtGetAt(_nativeHandle, nIdx, out hData, allocator.Pointer, out hName);
    NativeApiStatus.VerifySuccess(status);

    // The name is a natively allocated UTF-8 string; copy it while the allocation wrapper is alive.
    using (new OrtMemoryAllocation(allocator, hName, 0))
    {
        strName = NativeOnnxValueHelper.StringFromNativeUtf8(hName);
    }

    return new OrtValue(hData, false);
}
/// <summary>
/// Return the boolean based training parameter.
/// </summary>
/// <param name="key">Specifies the key of the value to get.</param>
/// <returns>true when the native value is non-zero, otherwise false.</returns>
public bool GetTrainingParameter(OrtTrainingBooleanParameter key)
{
    UIntPtr val;
    var status = NativeMethodsTraining.OrtGetParameter_bool(_nativeHandle, key, out val);
    NativeApiStatus.VerifySuccess(status);

    return val.ToUInt64() != 0;
}
/// <summary>
/// Returns the loss function used.
/// </summary>
/// <returns>The loss function used is returned.</returns>
/// <exception cref="InvalidOperationException">Thrown when the native layer reports a loss
/// function value this wrapper does not recognize.</exception>
public OrtTrainingLossFunction GetTrainingLossFunction()
{
    UIntPtr val = UIntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtGetTrainingLossFunction(_nativeHandle, out val));

    OrtTrainingLossFunction loss = (OrtTrainingLossFunction)(int)val;
    switch (loss)
    {
        case OrtTrainingLossFunction.ORT_TRAINING_LOSS_FUNCTION_SOFTMAXCROSSENTROPY:
            return loss;

        default:
            // A specific exception type (still an Exception, so existing catch (Exception) handlers keep working).
            throw new InvalidOperationException("Unknown loss function '" + val.ToString() + "'!");
    }
}
/// <summary>
/// IDisposable implementation: releases the clean-up list and the native training parameters.
/// </summary>
/// <param name="disposing">true if invoked from the Dispose() method; only then is the managed
/// clean-up list disposed. The native handle is released in either case.</param>
protected virtual void Dispose(bool disposing)
{
    if (!_disposed)
    {
        if (disposing)
        {
            // Managed state is only touched on the explicit Dispose() path.
            m_rgCleanUpList.Dispose();
        }

        // Release the native training parameters once and clear the handle so it is not reused.
        if (_nativeHandle != IntPtr.Zero)
        {
            NativeMethodsTraining.OrtReleaseTrainingParameters(_nativeHandle);
            _nativeHandle = IntPtr.Zero;
        }

        _disposed = true;
    }
}
/// <summary>
/// IDisposable implementation: releases the native value collection if this instance owns it.
/// </summary>
/// <param name="disposing">true if invoked from the Dispose() method. No managed resources are
/// disposed here, so the flag is not consulted.</param>
protected virtual void Dispose(bool disposing)
{
    if (_disposed)
    {
        return;
    }

    IntPtr hCollection = _nativeHandle;
    if (hCollection != IntPtr.Zero)
    {
        // Borrowed collections (passed into the constructor) are left for their creator to release.
        if (_ownsHandle)
        {
            NativeMethodsTraining.OrtReleaseValueCollection(hCollection);
        }

        _nativeHandle = IntPtr.Zero;
    }

    _disposed = true;
}
/// <summary>
/// Set a long based training parameter.
/// </summary>
/// <param name="key">Specifies the key of the value to set.</param>
/// <param name="lVal">Specifies the value to be set.</param>
public void SetTrainingParameter(OrtTrainingLongParameter key, long lVal)
{
    var status = NativeMethodsTraining.OrtSetParameter_long(_nativeHandle, key, lVal);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Set a string based training parameter.
/// </summary>
/// <param name="key">Specifies the key of the value to set.</param>
/// <param name="strVal">Specifies the value to be set; it is serialized with
/// NativeMethods.GetPlatformSerializedString before being passed to the native layer.</param>
public void SetTrainingParameter(OrtTrainingStringParameter key, string strVal)
{
    byte[] rgVal = NativeMethods.GetPlatformSerializedString(strVal);
    var status = NativeMethodsTraining.OrtSetParameter_string(_nativeHandle, key, rgVal);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Set a numeric (double) based training parameter.
/// </summary>
/// <param name="key">Specifies the key of the value to set.</param>
/// <param name="dfVal">Specifies the value to be set.</param>
public void SetTrainingParameter(OrtTrainingNumericParameter key, double dfVal)
{
    var status = NativeMethodsTraining.OrtSetNumericParameter(_nativeHandle, key, dfVal);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Set an OrtValue at a given index.
/// </summary>
/// <param name="nIdx">Specifies the index where the data is to be set.</param>
/// <param name="val">Specifies the value to set; its native handle is passed to the native layer.</param>
/// <param name="strName">Specifies the name of the value. When null or empty, a null name buffer
/// is passed instead of a serialized string.</param>
public void SetAt(int nIdx, OrtValue val, string strName = "")
{
    byte[] rgName = null;
    if (!string.IsNullOrEmpty(strName))
    {
        rgName = NativeMethods.GetPlatformSerializedString(strName);
    }

    var status = NativeMethodsTraining.OrtSetAt(_nativeHandle, nIdx, val.Handle, rgName);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Set the training optimizer to use.
/// </summary>
/// <param name="opt">Specifies the optimizer to use.</param>
public void SetTrainingOptimizer(OrtTrainingOptimizer opt)
{
    var status = NativeMethodsTraining.OrtSetTrainingOptimizer(_nativeHandle, opt);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// End the training session by calling the native OrtEndTraining with the training parameters handle.
/// </summary>
public void EndTraining()
{
    var hParam = m_param.DangerousGetHandle();
    var status = NativeMethodsTraining.OrtEndTraining(hParam);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Set the type of loss function to use.
/// </summary>
/// <param name="loss">Specifies the loss function type.</param>
public void SetTrainingLossFunction(OrtTrainingLossFunction loss)
{
    var status = NativeMethodsTraining.OrtSetTrainingLossFunction(_nativeHandle, loss);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Create the native training parameters object and store its handle in _nativeHandle.
/// </summary>
private void Init()
{
    var status = NativeMethodsTraining.OrtCreateTrainingParameters(out _nativeHandle);
    NativeApiStatus.VerifySuccess(status);
}
/// <summary>
/// Set a boolean based training parameter.
/// </summary>
/// <param name="key">Specifies the key of the value to set.</param>
/// <param name="bVal">Specifies the value to be set.</param>
public void SetTrainingParameter(OrtTrainingBooleanParameter key, bool bVal)
{
    var status = NativeMethodsTraining.OrtSetParameter_bool(_nativeHandle, key, bVal);
    NativeApiStatus.VerifySuccess(status);
}