コード例 #1
0
        /// <summary>
        /// Starts the native training session for the given <see cref="OrtEnv"/>, then reads the
        /// expected input and output tensor definitions back into the training parameters.
        /// </summary>
        /// <param name="env">The OrtEnv whose native handle is handed to the native training initializer.</param>
        public void Initialize(OrtEnv env)
        {
            var envHandle     = env.DangerousGetHandle();
            var paramHandle   = m_param.DangerousGetHandle();
            var inputsHandle  = m_expectedInputs.DangerousGetHandle();
            var outputsHandle = m_expectedOutputs.DangerousGetHandle();

            var status = NativeMethodsTraining.OrtInitializeTraining(envHandle, paramHandle, inputsHandle, outputsHandle);
            NativeApiStatus.VerifySuccess(status);

            // NOTE(review): presumably the native initializer fills the expected input/output
            // collections; their tensor definitions are copied onto m_param here.
            m_param.ExpectedInputs  = getTensorDefs(m_expectedInputs);
            m_param.ExpectedOutputs = getTensorDefs(m_expectedOutputs);
        }
コード例 #2
0
        /// <summary>
        /// Registers the error and evaluation callbacks with the native training parameters.
        /// A newly generated GUID string is passed along as the registration key.
        /// </summary>
        public void SetupTrainingParameters()
        {
            string strKey = System.Guid.NewGuid().ToString();
            byte[] rgKey  = NativeMethods.GetPlatformSerializedString(strKey);

            var status = NativeMethodsTraining.OrtSetupTrainingParameters(_nativeHandle, m_fnErrorFunction, m_fnEvaluateFunction, rgKey);
            NativeApiStatus.VerifySuccess(status);
        }
コード例 #3
0
        /// <summary>
        /// Return the long based training parameter.
        /// </summary>
        /// <param name="key">Specifies the key of the value to get.</param>
        /// <returns>The long based value is returned.</returns>
        public long GetTrainingParameter(OrtTrainingLongParameter key)
        {
            UIntPtr val;

            NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtGetParameter_long(_nativeHandle, key, out val));

            // Convert through ulong. When UIntPtr is not treated as a native integer (pre-C# 11,
            // e.g. .NET Framework), a direct (long) cast binds to UIntPtr's explicit *uint*
            // operator. On 64-bit processes that operator is checked, so values above
            // uint.MaxValue would throw OverflowException. Such values include a negative long,
            // if the native side reports it as its 64-bit pattern (TODO confirm).
            // UIntPtr -> ulong is exact; ulong -> long is a plain bit reinterpretation.
            return unchecked((long)(ulong)val);
        }
コード例 #4
0
        /// <summary>
        /// Setup the training data and connect the data batch callbacks.
        /// </summary>
        /// <param name="rgstrFeedNames">Specifies a list of the data feed names. Every name is sent
        /// to the native side followed by a ';' terminator (e.g. "data;label;"); an empty list is
        /// sent as an empty string.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="rgstrFeedNames"/> is null.</exception>
        public void SetupTrainingData(List <string> rgstrFeedNames)
        {
            if (rgstrFeedNames == null)
            {
                throw new ArgumentNullException(nameof(rgstrFeedNames));
            }

            // Each name, including the last, is terminated by ';'. string.Join builds the result
            // in one pass instead of the quadratic repeated concatenation; null entries contribute
            // nothing, exactly as '+=' did.
            string strFeedNames = (rgstrFeedNames.Count == 0) ? "" : string.Join(";", rgstrFeedNames) + ";";

            NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtSetupTrainingData(_nativeHandle, m_fnGetTrainingData, m_fnGetTestingData, NativeMethods.GetPlatformSerializedString(strFeedNames)));
        }
コード例 #5
0
 /// <summary>
 /// Creates a wrapper that exposes a group of native OrtValues through its GetAt and SetAt
 /// methods; the wrapper itself is not a native collection.
 /// </summary>
 /// <param name="h">Handle of an existing native OrtValueCollection, or IntPtr.Zero.
 /// With IntPtr.Zero a new native value collection is created, and this object owns and disposes it.
 /// Any other handle is only borrowed, so this object never disposes it.</param>
 /// <remarks>
 /// For efficiency, the OrtValues reachable through this collection do not own their memory;
 /// instead they point at one or more pre-allocated OrtValues.
 /// </remarks>
 public OrtValueCollection(IntPtr h)
 {
     if (h != IntPtr.Zero)
     {
         // Borrowed handle: ownership stays with whoever created it.
         _nativeHandle = h;
         _ownsHandle   = false;
         return;
     }

     // No handle supplied: create a native collection that this instance owns.
     NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtCreateValueCollection(out _nativeHandle));
     _ownsHandle = true;
 }
コード例 #6
0
        /// <summary>
        /// Return the numeric (double) based training parameter.
        /// </summary>
        /// <param name="key">Specifies the key of the value to get.</param>
        /// <returns>The double based value is returned.</returns>
        public double GetTrainingParameter(OrtTrainingNumericParameter key)
        {
            var    allocator = OrtAllocator.DefaultInstance;
            IntPtr valHandle;
            string str;

            // The native side returns the value as a UTF-8 string buffer obtained through 'allocator'.
            NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtGetNumericParameter(_nativeHandle, key, allocator.Pointer, out valHandle));

            // Free the native buffer once its text has been copied into a managed string.
            using (new OrtMemoryAllocation(allocator, valHandle, 0))
            {
                str = NativeOnnxValueHelper.StringFromNativeUtf8(valHandle);
            }

            // Parse culture-invariantly. The native text is machine-formatted (presumably with '.'
            // as the decimal separator — confirm against the native implementation). Parsing with
            // the current culture (e.g. de-DE, where ',' is the separator) would misread or reject it.
            return double.Parse(str, System.Globalization.CultureInfo.InvariantCulture);
        }
コード例 #7
0
        /// <summary>
        /// Returns the optimizer used.
        /// </summary>
        /// <returns>The optimizer used is returned.</returns>
        /// <exception cref="InvalidOperationException">Thrown when the native side reports an optimizer
        /// value this wrapper does not recognize.</exception>
        public OrtTrainingOptimizer GetTrainingOptimizer()
        {
            UIntPtr val;

            NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtGetTrainingOptimizer(_nativeHandle, out val));

            OrtTrainingOptimizer opt = (OrtTrainingOptimizer)(int)val;
            switch (opt)
            {
            case OrtTrainingOptimizer.ORT_TRAINING_OPTIMIZER_SGD:
                return opt;

            default:
                // InvalidOperationException instead of the base Exception type (CA2201). It derives
                // from Exception, so existing 'catch (Exception)' handlers still catch it.
                throw new InvalidOperationException("Unknown optimizer '" + val.ToString() + "'!");
            }
        }
コード例 #8
0
        /// <summary>
        /// Fetches the OrtValue stored at a position in the collection, together with its name.
        /// </summary>
        /// <param name="nIdx">Position of the value to fetch.</param>
        /// <param name="strName">Receives the name associated with the value.</param>
        /// <returns>A new OrtValue wrapping the native value at <paramref name="nIdx"/>.</returns>
        public OrtValue GetAt(int nIdx, out string strName)
        {
            var    allocator = OrtAllocator.DefaultInstance;
            IntPtr dataHandle;
            IntPtr nameHandle;

            var status = NativeMethodsTraining.OrtGetAt(_nativeHandle, nIdx, out dataHandle, allocator.Pointer, out nameHandle);
            NativeApiStatus.VerifySuccess(status);

            // The name comes back as a UTF-8 buffer tied to 'allocator'; free it after copying.
            using (new OrtMemoryAllocation(allocator, nameHandle, 0))
            {
                strName = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
            }

            // NOTE(review): the 'false' flag presumably marks the OrtValue as non-owning — confirm.
            return new OrtValue(dataHandle, false);
        }
コード例 #9
0
        /// <summary>
        /// Reads a boolean training parameter from the native training parameters.
        /// </summary>
        /// <param name="key">Identifies which boolean parameter to read.</param>
        /// <returns>False when the native value is zero, true otherwise.</returns>
        public bool GetTrainingParameter(OrtTrainingBooleanParameter key)
        {
            UIntPtr raw;

            var status = NativeMethodsTraining.OrtGetParameter_bool(_nativeHandle, key, out raw);
            NativeApiStatus.VerifySuccess(status);

            // Any non-zero native value is reported as true.
            return (ulong)raw != 0;
        }
コード例 #10
0
        /// <summary>
        /// Returns the loss function used.
        /// </summary>
        /// <returns>The loss function used is returned.</returns>
        /// <exception cref="InvalidOperationException">Thrown when the native side reports a loss function
        /// value this wrapper does not recognize.</exception>
        public OrtTrainingLossFunction GetTrainingLossFunction()
        {
            UIntPtr val;

            NativeApiStatus.VerifySuccess(NativeMethodsTraining.OrtGetTrainingLossFunction(_nativeHandle, out val));

            OrtTrainingLossFunction loss = (OrtTrainingLossFunction)(int)val;
            switch (loss)
            {
            case OrtTrainingLossFunction.ORT_TRAINING_LOSS_FUNCTION_SOFTMAXCROSSENTROPY:
                return loss;

            default:
                // InvalidOperationException instead of the base Exception type (CA2201). It derives
                // from Exception, so existing 'catch (Exception)' handlers still catch it.
                throw new InvalidOperationException("Unknown loss function '" + val.ToString() + "'!");
            }
        }
コード例 #11
0
        /// <summary>
        /// Core of the dispose pattern: releases the clean-up list (only when called from Dispose())
        /// and the native training parameters handle. Repeated calls do nothing.
        /// </summary>
        /// <param name="disposing">true when called from Dispose(); false when called from a finalizer.</param>
        protected virtual void Dispose(bool disposing)
        {
            if (!_disposed)
            {
                // Managed objects are only touched on the explicit Dispose() path.
                if (disposing)
                {
                    m_rgCleanUpList.Dispose();
                }

                // Release the native training parameters and forget the handle.
                if (_nativeHandle != IntPtr.Zero)
                {
                    NativeMethodsTraining.OrtReleaseTrainingParameters(_nativeHandle);
                    _nativeHandle = IntPtr.Zero;
                }

                _disposed = true;
            }
        }
コード例 #12
0
        /// <summary>
        /// Core of the dispose pattern: releases the native value collection if this instance owns it.
        /// Repeated calls do nothing.
        /// </summary>
        /// <param name="disposing">true when called from Dispose(); false when called from a finalizer.
        /// There are no managed resources, so the value does not affect the outcome.</param>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed)
            {
                return;
            }

            // Only collections created by this instance are released; borrowed handles are just dropped.
            if (_ownsHandle && _nativeHandle != IntPtr.Zero)
            {
                NativeMethodsTraining.OrtReleaseValueCollection(_nativeHandle);
            }

            _nativeHandle = IntPtr.Zero;
            _disposed     = true;
        }
コード例 #13
0
 /// <summary>
 /// Writes a long valued training parameter to the native training parameters.
 /// </summary>
 /// <param name="key">Identifies which long parameter to write.</param>
 /// <param name="lVal">The new value for that parameter.</param>
 public void SetTrainingParameter(OrtTrainingLongParameter key, long lVal)
 {
     var status = NativeMethodsTraining.OrtSetParameter_long(_nativeHandle, key, lVal);
     NativeApiStatus.VerifySuccess(status);
 }
コード例 #14
0
 /// <summary>
 /// Writes a string valued training parameter to the native training parameters.
 /// The string is first encoded with NativeMethods.GetPlatformSerializedString.
 /// </summary>
 /// <param name="key">Identifies which string parameter to write.</param>
 /// <param name="strVal">The new value for that parameter.</param>
 public void SetTrainingParameter(OrtTrainingStringParameter key, string strVal)
 {
     byte[] rgVal = NativeMethods.GetPlatformSerializedString(strVal);
     var status   = NativeMethodsTraining.OrtSetParameter_string(_nativeHandle, key, rgVal);
     NativeApiStatus.VerifySuccess(status);
 }
コード例 #15
0
 /// <summary>
 /// Writes a numeric (double) training parameter to the native training parameters.
 /// </summary>
 /// <param name="key">Identifies which numeric parameter to write.</param>
 /// <param name="dfVal">The new value for that parameter.</param>
 public void SetTrainingParameter(OrtTrainingNumericParameter key, double dfVal)
 {
     var status = NativeMethodsTraining.OrtSetNumericParameter(_nativeHandle, key, dfVal);
     NativeApiStatus.VerifySuccess(status);
 }
コード例 #16
0
 /// <summary>
 /// Stores an OrtValue at a given position in the native collection, optionally naming it.
 /// </summary>
 /// <param name="nIdx">Position at which the value is stored.</param>
 /// <param name="val">The OrtValue whose handle is stored.</param>
 /// <param name="strName">Optional name for the value; null or empty sends no name to the native side.</param>
 public void SetAt(int nIdx, OrtValue val, string strName = "")
 {
     byte[] rgName = null;
     if (!string.IsNullOrEmpty(strName))
     {
         rgName = NativeMethods.GetPlatformSerializedString(strName);
     }

     var status = NativeMethodsTraining.OrtSetAt(_nativeHandle, nIdx, val.Handle, rgName);
     NativeApiStatus.VerifySuccess(status);
 }
コード例 #17
0
 /// <summary>
 /// Selects the optimizer the native training session should use.
 /// </summary>
 /// <param name="opt">The optimizer to select.</param>
 public void SetTrainingOptimizer(OrtTrainingOptimizer opt)
 {
     var status = NativeMethodsTraining.OrtSetTrainingOptimizer(_nativeHandle, opt);
     NativeApiStatus.VerifySuccess(status);
 }
コード例 #18
0
 /// <summary>
 /// Ends the training session identified by the training parameters handle.
 /// </summary>
 public void EndTraining()
 {
     var paramHandle = m_param.DangerousGetHandle();
     var status      = NativeMethodsTraining.OrtEndTraining(paramHandle);
     NativeApiStatus.VerifySuccess(status);
 }
コード例 #19
0
 /// <summary>
 /// Selects the loss function the native training session should use.
 /// </summary>
 /// <param name="loss">The loss function type to select.</param>
 public void SetTrainingLossFunction(OrtTrainingLossFunction loss)
 {
     var status = NativeMethodsTraining.OrtSetTrainingLossFunction(_nativeHandle, loss);
     NativeApiStatus.VerifySuccess(status);
 }
コード例 #20
0
 /// <summary>
 /// Creates the native training parameters object and stores its handle in _nativeHandle.
 /// </summary>
 private void Init()
 {
     var status = NativeMethodsTraining.OrtCreateTrainingParameters(out _nativeHandle);
     NativeApiStatus.VerifySuccess(status);
 }
コード例 #21
0
 /// <summary>
 /// Writes a boolean training parameter to the native training parameters.
 /// </summary>
 /// <param name="key">Identifies which boolean parameter to write.</param>
 /// <param name="bVal">The new value for that parameter.</param>
 public void SetTrainingParameter(OrtTrainingBooleanParameter key, bool bVal)
 {
     var status = NativeMethodsTraining.OrtSetParameter_bool(_nativeHandle, key, bVal);
     NativeApiStatus.VerifySuccess(status);
 }