예제 #1
0
        /// <summary>
        /// Create a Session from a SavedModel. If successful, populates the internal graph with the contents of the Graph and
        /// with the MetaGraphDef of the loaded model.
        /// </summary>
        /// <param name="exportDir">Must be set to the path of the exported SavedModel. Must not be null.</param>
        /// <param name="tags">Must include the set of tags used to identify one MetaGraphDef in the SavedModel. Could be "serve", "tpu", "gpu", "train" or other values. May be null, in which case a tag count of zero is passed to the native loader.</param>
        /// <param name="sessionOptions">Session options; forwarded as-is to the native loader (may be null).</param>
        /// <param name="runOptions">Optional run options buffer, forwarded as-is to the native loader (may be null).</param>
        /// <param name="status">The status. It is wrapped in a StatusChecker for the native call; NOTE(review): presumably a null status makes load failures surface as exceptions — confirm against StatusChecker.</param>
        /// <exception cref="ArgumentNullException"><paramref name="exportDir"/> is null.</exception>
        public Session(
            String exportDir,
            String[] tags,
            SessionOptions sessionOptions = null,
            Buffer runOptions             = null,
            Status status = null)
        {
            // Marshal.StringToHGlobalAnsi(null) yields IntPtr.Zero, which would be handed straight to native code.
            if (exportDir == null)
            {
                throw new ArgumentNullException("exportDir");
            }

            _graph            = new Graph();
            _graphNeedDispose = true;
            _metaGraphDef     = new Buffer();

            // All unmanaged allocations happen inside the try below, so the finally block releases whatever was
            // allocated even if marshalling fails part-way. Marshal.FreeHGlobal is a no-op for IntPtr.Zero, so
            // slots that were never filled are safe to free.
            IntPtr   exportDirPtr      = IntPtr.Zero;
            IntPtr[] tagsNative        = new IntPtr[tags == null ? 0 : tags.Length];
            GCHandle tagsNativeHandle  = new GCHandle();
            IntPtr   tagsNativePointer = IntPtr.Zero;

            try
            {
                exportDirPtr = Marshal.StringToHGlobalAnsi(exportDir);

                for (int i = 0; i < tagsNative.Length; i++)
                {
                    tagsNative[i] = Marshal.StringToHGlobalAnsi(tags[i]);
                }

                if (tags != null)
                {
                    // Pin the pointer array so its address stays valid for the duration of the native call.
                    tagsNativeHandle  = GCHandle.Alloc(tagsNative, GCHandleType.Pinned);
                    tagsNativePointer = tagsNativeHandle.AddrOfPinnedObject();
                }

                using (StatusChecker checker = new StatusChecker(status))
                {
                    _ptr = TfInvoke.tfeLoadSessionFromSavedModel(
                        sessionOptions,
                        runOptions,
                        exportDirPtr,
                        tagsNativePointer,
                        tagsNative.Length,
                        _graph,
                        _metaGraphDef,
                        checker.Status
                        );
                }
            }
            catch (Exception excpt)
            {
                // NOTE(review): _graph and _metaGraphDef are not released here. Because the constructor throws, the
                // caller never receives this instance to Dispose; consider disposing them if Graph/Buffer own
                // native resources.
                Trace.WriteLine(excpt.Message);
                throw;
            }
            finally
            {
                Marshal.FreeHGlobal(exportDirPtr);

                if (tagsNativeHandle.IsAllocated)
                {
                    tagsNativeHandle.Free();
                }

                for (int i = 0; i < tagsNative.Length; i++)
                {
                    Marshal.FreeHGlobal(tagsNative[i]);
                }
            }
        }