Exemplo n.º 1
0
        /// <summary>
        /// Read the TensorShapeProto value held by a shape-typed attribute of this operation.
        /// </summary>
        /// <param name="attrName">The name of the attribute; its metadata type must be AttrType.Shape</param>
        /// <param name="status">The status</param>
        /// <returns>A newly created buffer that receives the serialized TensorShapeProto</returns>
        /// <exception cref="ArgumentException">Thrown when the attribute is not a shape</exception>
        public Buffer GetAttrTensorShapeProto(String attrName, Status status = null)
        {
            using (StatusChecker statusChecker = new StatusChecker(status))
            {
                AttrMetadata metadata = GetAttrMetadata(attrName, status);
                bool isShape = metadata.Type == AttrType.Shape;
                if (!isShape)
                {
                    String message = String.Format("Attribute {0} ({1}) is not a shape", attrName, metadata.Type);
                    throw new ArgumentException(message);
                }

                // The native call fills the buffer in place.
                Buffer shapeProto = new Buffer();
                TfInvoke.tfeOperationGetAttrTensorShapeProto(_ptr, attrName, shapeProto, statusChecker.Status);
                return shapeProto;
            }
        }
Exemplo n.º 2
0
 /// <summary>
 /// Convert a byte array to a Tensor of DataType.String
 /// </summary>
 /// <param name="value">The byte array to encode; must not be null</param>
 /// <param name="status">Optional status</param>
 /// <returns>The tensor holding the encoded bytes</returns>
 /// <exception cref="ArgumentNullException">Thrown when <paramref name="value"/> is null</exception>
 public static Tensor FromString(byte[] value, Status status = null)
 {
     if (value == null)
     {
         throw new ArgumentNullException(nameof(value));
     }

     using (StatusChecker checker = new StatusChecker(status))
     {
         int length = TfInvoke.tfeStringEncodedSize(value.Length);

         // The tensor data is an 8-byte header (written as 0 below; presumably the
         // single-element offset table of a legacy TF_STRING tensor) followed by
         // the encoded string bytes.
         Tensor tensor = new Tensor(DataType.String, length + 8);
         IntPtr ptr = tensor.DataPointer;
         Marshal.WriteInt64(ptr, 0);

         // Pin the source array for the duration of the native call only; the finally
         // guarantees the pin is released even if the call throws.
         GCHandle handle = GCHandle.Alloc(value, GCHandleType.Pinned);
         try
         {
             TfInvoke.tfeStringEncode(
                 handle.AddrOfPinnedObject(),
                 value.Length,
                 new IntPtr(ptr.ToInt64() + 8),
                 length,
                 checker.Status);
         }
         finally
         {
             handle.Free();
         }

         return tensor;
     }
 }
Exemplo n.º 3
0
        /// <summary>
        /// List the names of all physical devices reported by the native tfeListAllPhysicalDevices call.
        /// </summary>
        /// <param name="status">The status</param>
        /// <returns>The device names, decoded as ASCII from a newline-separated list</returns>
        public static String[] ListAllPhysicalDevices(Status status = null)
        {
            using (StatusChecker checker = new StatusChecker(status))
            {
                // Fixed-size output buffer filled by native code.
                // NOTE(review): behaviour when the list exceeds 2048 bytes depends on the native wrapper — confirm.
                byte[] nameBuffer = new byte[2048];

                GCHandle nameHandle = GCHandle.Alloc(nameBuffer, GCHandleType.Pinned);
                try
                {
                    TfInvoke.tfeListAllPhysicalDevices(
                        nameHandle.AddrOfPinnedObject(),
                        checker.Status);
                }
                finally
                {
                    // Always release the pin, even if the native call throws.
                    nameHandle.Free();
                }

                // The buffer is zero-initialized, so unused tail bytes decode to '\0';
                // strip those and any trailing newline before splitting into names.
                String nameResult = System.Text.Encoding.ASCII.GetString(nameBuffer);
                String[] names = nameResult.TrimEnd('\0', '\n').Split('\n');
                return names;
            }
        }
Exemplo n.º 4
0
        /// <summary>
        /// Returns the shape of the Tensor produced by <paramref name="output"/> in this graph
        /// </summary>
        /// <param name="output">The output</param>
        /// <param name="status">The status</param>
        /// <returns>
        /// The dimensions of the Tensor. An empty array is returned when it has zero dimensions.
        /// Null is returned when the native call reports a negative number of dimensions
        /// (presumably an unknown rank).
        /// </returns>
        public int[] GetTensorShape(Output output, Status status = null)
        {
            using (StatusChecker checker = new StatusChecker(status))
            {
                int numDim = TfInvoke.tfeGraphGetTensorNumDims(_ptr, output.Operation, output.Index, checker.Status);
                if (numDim < 0)
                {
                    return null;
                }
                if (numDim == 0)
                {
                    return new int[0];
                }

                int[] dims = new int[numDim];
                GCHandle handle = GCHandle.Alloc(dims, GCHandleType.Pinned);
                try
                {
                    TfInvoke.tfeGraphGetTensorShape(_ptr, output.Operation, output.Index, handle.AddrOfPinnedObject(), numDim, checker.Status);
                }
                finally
                {
                    // Release the pin even if the native call throws.
                    handle.Free();
                }
                return dims;
            }
        }
Exemplo n.º 5
0
        /// <summary>
        /// Lists all devices in a session
        /// </summary>
        /// <param name="status">The status</param>
        /// <returns>
        /// All devices in the current session. Names and types are decoded as ASCII from
        /// newline-separated lists, and memory sizes are read from a parallel Int64 array.
        /// </returns>
        public Device[] ListDevices(Status status = null)
        {
            using (StatusChecker checker = new StatusChecker(status))
            {
                // Fixed-size output buffers filled by native code.
                // NOTE(review): more than 128 devices would overflow memorySizeBuffer indexing below — confirm the native limit.
                byte[] nameBuffer = new byte[2048];
                byte[] typeBuffer = new byte[2048];
                Int64[] memorySizeBuffer = new Int64[128];

                GCHandle nameHandle = new GCHandle();
                GCHandle typeHandle = new GCHandle();
                GCHandle memorySizeHandle = new GCHandle();
                try
                {
                    nameHandle = GCHandle.Alloc(nameBuffer, GCHandleType.Pinned);
                    typeHandle = GCHandle.Alloc(typeBuffer, GCHandleType.Pinned);
                    memorySizeHandle = GCHandle.Alloc(memorySizeBuffer, GCHandleType.Pinned);

                    TfInvoke.tfeSessionListDevices(
                        _ptr,
                        nameHandle.AddrOfPinnedObject(),
                        typeHandle.AddrOfPinnedObject(),
                        memorySizeHandle.AddrOfPinnedObject(),
                        checker.Status);
                }
                finally
                {
                    // Release whichever pins were taken, even if an allocation or the native call threw.
                    if (nameHandle.IsAllocated)
                    {
                        nameHandle.Free();
                    }
                    if (typeHandle.IsAllocated)
                    {
                        typeHandle.Free();
                    }
                    if (memorySizeHandle.IsAllocated)
                    {
                        memorySizeHandle.Free();
                    }
                }

                String nameResult = System.Text.Encoding.ASCII.GetString(nameBuffer);
                String[] names = nameResult.TrimEnd('\0', '\n').Split('\n');

                String typeResult = System.Text.Encoding.ASCII.GetString(typeBuffer);
                String[] types = typeResult.TrimEnd('\0', '\n').Split('\n');

                Device[] devices = new Device[names.Length];
                for (int i = 0; i < devices.Length; i++)
                {
                    Device d = new Device();
                    d.Name = names[i];
                    d.Type = types[i];
                    d.MemoryBytes = memorySizeBuffer[i];
                    devices[i] = d;
                }
                return devices;
            }
        }
Exemplo n.º 6
0
        /// <summary>
        /// Release the unmanaged memory associated with this Session. The native session is deleted
        /// first. The graph is then disposed only when this Session owns it, and the MetaGraphDef
        /// buffer is disposed if present.
        /// </summary>
        protected override void DisposeObject()
        {
            bool hasNativeSession = _ptr != IntPtr.Zero;
            if (hasNativeSession)
            {
                // Passed by ref so the native side can update the handle.
                using (StatusChecker deleteChecker = new StatusChecker(null))
                {
                    TfInvoke.tfeDeleteSession(ref _ptr, deleteChecker.Status);
                }
            }

            // Only graphs created by this Session (flagged via _graphNeedDispose) are disposed here.
            bool ownsGraph = _graphNeedDispose && _graph != null;
            if (ownsGraph)
            {
                _graph.Dispose();
                _graph = null;
            }

            if (_metaGraphDef != null)
            {
                _metaGraphDef.Dispose();
                _metaGraphDef = null;
            }

            // Drop the reference to a borrowed graph as well.
            _graph = null;
        }
Exemplo n.º 7
0
        /// <summary>
        /// Get the value of the attribute that is a String
        /// </summary>
        /// <param name="attrName">The name of the attribute</param>
        /// <param name="status">The status</param>
        /// <returns>The string value of the attribute</returns>
        /// <exception cref="ArgumentException">Thrown when the attribute is not a String</exception>
        public String GetAttrString(String attrName, Status status = null)
        {
            using (StatusChecker checker = new StatusChecker(status))
            {
                AttrMetadata meta = GetAttrMetadata(attrName, status);
                if (meta.Type != AttrType.String)
                {
                    throw new ArgumentException(String.Format("Attribute {0} ({1}) is not a String", attrName, meta.Type));
                }

                int size = (int)meta.TotalSize;

                // Allocate at least one byte so a zero-length attribute still yields a valid pointer.
                IntPtr s = Marshal.AllocHGlobal(Math.Max(size, 1));
                try
                {
                    TfInvoke.tfeOperationGetAttrString(_ptr, attrName, s, size, checker.Status);

                    // The buffer holds exactly TotalSize bytes, so there is no room for a NUL
                    // terminator. Read exactly that many bytes instead of scanning for one,
                    // which could run past the allocation.
                    return size > 0 ? Marshal.PtrToStringAnsi(s, size) : String.Empty;
                }
                finally
                {
                    Marshal.FreeHGlobal(s);
                }
            }
        }
Exemplo n.º 8
0
        /// <summary>
        /// Read an attribute whose value is a list of 64-bit integers
        /// </summary>
        /// <param name="attrName">The name of the attribute</param>
        /// <param name="status">The status</param>
        /// <returns>The list of Int64 values stored in the attribute</returns>
        /// <exception cref="ArgumentException">Thrown when the attribute is not a list of Int</exception>
        public Int64[] GetAttrIntList(String attrName, Status status = null)
        {
            using (StatusChecker statusChecker = new StatusChecker(status))
            {
                AttrMetadata metadata = GetAttrMetadata(attrName, status);
                bool isIntList = metadata.IsList && metadata.Type == AttrType.Int;
                if (!isIntList)
                {
                    throw new ArgumentException(String.Format("Attribute {0} ({1}) is not a List of Int", attrName, metadata.Type));
                }

                // The native call writes directly into the pinned managed array.
                Int64[] values = new Int64[metadata.ListSize];
                GCHandle pinnedValues = GCHandle.Alloc(values, GCHandleType.Pinned);
                try
                {
                    IntPtr destination = pinnedValues.AddrOfPinnedObject();
                    TfInvoke.tfeOperationGetAttrIntList(_ptr, attrName, destination, values.Length, statusChecker.Status);
                }
                finally
                {
                    pinnedValues.Free();
                }

                return values;
            }
        }
Exemplo n.º 9
0
 /// <summary>
 /// Assign a Tensor value to the named attribute
 /// </summary>
 /// <param name="attrName">The name of the attribute to set</param>
 /// <param name="tensor">The Tensor value</param>
 /// <param name="status">The status</param>
 public void SetAttr(String attrName, Tensor tensor, Status status = null)
 {
     using (StatusChecker statusChecker = new StatusChecker(status))
     {
         TfInvoke.tfeSetAttrTensor(_ptr, attrName, tensor, statusChecker.Status);
     }
 }
Exemplo n.º 10
0
        /// <summary>
        /// Run the graph associated with the session, starting with the supplied inputs
        /// (inputs[0,ninputs-1] with corresponding values in inputValues[0,ninputs-1]).
        /// </summary>
        /// <param name="inputs">The input nodes</param>
        /// <param name="inputValues">The input values; must contain at least inputs.Length entries</param>
        /// <param name="outputs">The output nodes</param>
        /// <param name="targetOperations">Optional target operations; may be null</param>
        /// <param name="runOptions">
        /// May be NULL, in which case it will be ignored; or
        /// non-NULL, in which case it must point to a `TF_Buffer` containing the
        /// serialized representation of a `RunOptions` protocol buffer.
        /// </param>
        /// <param name="runMetadata">
        /// May be NULL, in which case it will be ignored; or
        /// non-NULL, in which case it must point to an empty, freshly allocated
        /// `TF_Buffer` that may be updated to contain the serialized representation
        /// of a `RunMetadata` protocol buffer.
        /// </param>
        /// <param name="status">The status</param>
        /// <returns>On success, the tensors corresponding to outputs[0,noutputs-1] are placed in the returned Tensors.</returns>
        /// <exception cref="ArgumentNullException">Thrown when inputs, inputValues or outputs is null</exception>
        /// <exception cref="ArgumentException">Thrown when inputValues has fewer entries than inputs</exception>
        public Tensor[] Run(Output[] inputs, Tensor[] inputValues, Output[] outputs, Operation[] targetOperations = null, Buffer runOptions = null, Buffer runMetadata = null, Status status = null)
        {
            if (inputs == null)
            {
                throw new ArgumentNullException(nameof(inputs));
            }
            if (inputValues == null)
            {
                throw new ArgumentNullException(nameof(inputValues));
            }
            if (outputs == null)
            {
                throw new ArgumentNullException(nameof(outputs));
            }
            // The native call reads inputs.Length tensor pointers, so a shorter array would be
            // read past its end.
            if (inputValues.Length < inputs.Length)
            {
                throw new ArgumentException(
                    String.Format("Expected at least {0} input values but got {1}", inputs.Length, inputValues.Length),
                    nameof(inputValues));
            }

            // Build every managed array before pinning anything, so a failure here cannot leak a pin.
            IntPtr[] inputOps = Array.ConvertAll(inputs, i => i.Operation.Ptr);
            int[] inputIdx = Array.ConvertAll(inputs, i => i.Index);
            IntPtr[] inputTensors = Array.ConvertAll(inputValues, i => i.Ptr);

            IntPtr[] outputOps = Array.ConvertAll(outputs, o => o.Operation.Ptr);
            int[] outputIdx = Array.ConvertAll(outputs, o => o.Index);
            IntPtr[] outputTensors = new IntPtr[outputs.Length];

            IntPtr[] targetOpsPtrArray = targetOperations == null ? null : Array.ConvertAll(targetOperations, o => o.Ptr);
            int ntargets = targetOperations == null ? 0 : targetOperations.Length;

            GCHandle inputOpsHandle = new GCHandle();
            GCHandle inputIdxHandle = new GCHandle();
            GCHandle inputTensorsHandle = new GCHandle();
            GCHandle outputOpsHandle = new GCHandle();
            GCHandle outputIdxHandle = new GCHandle();
            GCHandle outputTensorsHandle = new GCHandle();
            GCHandle targetOpsHandle = new GCHandle();

            try
            {
                inputOpsHandle = GCHandle.Alloc(inputOps, GCHandleType.Pinned);
                inputIdxHandle = GCHandle.Alloc(inputIdx, GCHandleType.Pinned);
                inputTensorsHandle = GCHandle.Alloc(inputTensors, GCHandleType.Pinned);
                outputOpsHandle = GCHandle.Alloc(outputOps, GCHandleType.Pinned);
                outputIdxHandle = GCHandle.Alloc(outputIdx, GCHandleType.Pinned);
                outputTensorsHandle = GCHandle.Alloc(outputTensors, GCHandleType.Pinned);

                IntPtr targetOpsPtr = IntPtr.Zero;
                if (targetOpsPtrArray != null)
                {
                    targetOpsHandle = GCHandle.Alloc(targetOpsPtrArray, GCHandleType.Pinned);
                    targetOpsPtr = targetOpsHandle.AddrOfPinnedObject();
                }

                // Disposing the checker may throw on a failed run. The handles are released in
                // the finally below, so they are freed in that case as well.
                using (StatusChecker checker = new StatusChecker(status))
                {
                    TfInvoke.tfeSessionRun(
                        _ptr,
                        runOptions,
                        inputOpsHandle.AddrOfPinnedObject(),
                        inputIdxHandle.AddrOfPinnedObject(),
                        inputTensorsHandle.AddrOfPinnedObject(),
                        inputs.Length,
                        outputOpsHandle.AddrOfPinnedObject(),
                        outputIdxHandle.AddrOfPinnedObject(),
                        outputTensorsHandle.AddrOfPinnedObject(),
                        outputs.Length,
                        targetOpsPtr,
                        ntargets,
                        runMetadata,
                        checker.Status
                        );
                }
            }
            finally
            {
                if (inputOpsHandle.IsAllocated)
                {
                    inputOpsHandle.Free();
                }
                if (inputIdxHandle.IsAllocated)
                {
                    inputIdxHandle.Free();
                }
                if (inputTensorsHandle.IsAllocated)
                {
                    inputTensorsHandle.Free();
                }
                if (targetOpsHandle.IsAllocated)
                {
                    targetOpsHandle.Free();
                }
                if (outputOpsHandle.IsAllocated)
                {
                    outputOpsHandle.Free();
                }
                if (outputIdxHandle.IsAllocated)
                {
                    outputIdxHandle.Free();
                }
                if (outputTensorsHandle.IsAllocated)
                {
                    outputTensorsHandle.Free();
                }
            }

            return Array.ConvertAll(outputTensors, t => new Tensor(t));
        }
Exemplo n.º 11
0
 /// <summary>
 /// Close this session.
 /// Contacts any other processes associated with the session, if applicable.
 /// </summary>
 /// <param name="status">The status</param>
 public void Close(Status status = null)
 {
     using (StatusChecker statusChecker = new StatusChecker(status))
     {
         TfInvoke.tfeCloseSession(_ptr, statusChecker.Status);
     }
 }
Exemplo n.º 12
0
 /// <summary>
 /// Serialize this graph as a GraphDef protocol message into the supplied buffer.
 /// </summary>
 /// <param name="outputGraphDef">The buffer that receives the serialized GraphDef</param>
 /// <param name="status">The status</param>
 public void ToGraphDef(Buffer outputGraphDef, Status status = null)
 {
     using (StatusChecker statusChecker = new StatusChecker(status))
     {
         TfInvoke.tfeGraphToGraphDef(_ptr, outputGraphDef, statusChecker.Status);
     }
 }
Exemplo n.º 13
0
        /// <summary>
        /// Create a Session from a SavedModel. If successful, populates the internal graph with the contents of the Graph and
        /// with the MetaGraphDef of the loaded model.
        /// </summary>
        /// <param name="exportDir">Must be set to the path of the exported SavedModel.</param>
        /// <param name="tags">Must include the set of tags used to identify one MetaGraphDef in the SavedModel. Could be "serve", "tpu", "gpu", "train" or other values. If null, no tags are passed to the loader.</param>
        /// <param name="sessionOptions">Session options</param>
        /// <param name="runOptions">Optional buffer forwarded to the native loader as run options; may be null</param>
        /// <param name="status">The status</param>
        public Session(
            String exportDir,
            String[] tags,
            SessionOptions sessionOptions = null,
            Buffer runOptions             = null,
            Status status = null)
        {
            _graph            = new Graph();
            _graphNeedDispose = true;
            _metaGraphDef     = new Buffer();

            // All unmanaged allocations happen inside the try, so a failure part-way through
            // (e.g. while marshalling one of the tags) still releases what was already allocated.
            // Unallocated slots stay IntPtr.Zero, for which Marshal.FreeHGlobal does nothing.
            IntPtr   exportDirPtr     = IntPtr.Zero;
            IntPtr[] tagsNative       = new IntPtr[tags == null ? 0 : tags.Length];
            GCHandle tagsNativeHandle = new GCHandle();

            try
            {
                exportDirPtr = Marshal.StringToHGlobalAnsi(exportDir);

                for (int i = 0; i < tagsNative.Length; i++)
                {
                    tagsNative[i] = Marshal.StringToHGlobalAnsi(tags[i]);
                }

                IntPtr tagsNativePointer = IntPtr.Zero;
                if (tags != null)
                {
                    tagsNativeHandle  = GCHandle.Alloc(tagsNative, GCHandleType.Pinned);
                    tagsNativePointer = tagsNativeHandle.AddrOfPinnedObject();
                }

                using (StatusChecker checker = new StatusChecker(status))
                {
                    _ptr = TfInvoke.tfeLoadSessionFromSavedModel(
                        sessionOptions,
                        runOptions,
                        exportDirPtr,
                        tagsNativePointer,
                        tagsNative.Length,
                        _graph,
                        _metaGraphDef,
                        checker.Status
                        );
                }
            }
            catch (Exception excpt)
            {
                Trace.WriteLine(excpt.Message);
                throw;
            }
            finally
            {
                Marshal.FreeHGlobal(exportDirPtr);

                if (tagsNativeHandle.IsAllocated)
                {
                    tagsNativeHandle.Free();
                }

                for (int i = 0; i < tagsNative.Length; i++)
                {
                    Marshal.FreeHGlobal(tagsNative[i]);
                }
            }
        }
Exemplo n.º 14
0
 /// <summary>
 /// Serialize this Function as a FunctionDef protocol message into the supplied buffer.
 /// </summary>
 /// <param name="outputFuncDef">The buffer that receives the serialized FunctionDef</param>
 /// <param name="status">The status</param>
 public void ToFunctionDef(Buffer outputFuncDef, Status status = null)
 {
     using (StatusChecker statusChecker = new StatusChecker(status))
     {
         TfInvoke.tfeFunctionToFunctionDef(_ptr, outputFuncDef, statusChecker.Status);
     }
 }
Exemplo n.º 15
0
 /// <summary>
 /// Load the library named by <paramref name="libraryFilename"/> and register the ops and
 /// kernels it contains.
 /// </summary>
 /// <param name="libraryFilename">The library file name</param>
 /// <param name="status">The status</param>
 public Library(String libraryFilename, Status status = null)
 {
     using (StatusChecker statusChecker = new StatusChecker(status))
     {
         _ptr = TfInvoke.tfeLoadLibrary(libraryFilename, statusChecker.Status);
     }
 }
Exemplo n.º 16
0
 /// <summary>
 /// Create a new execution session for the given graph.
 /// </summary>
 /// <param name="graph">Graph must be a valid graph (not deleted or null).  This function will
 /// prevent the graph from being deleted until Session is deleted.
 /// Does not take ownership of opts.
 /// </param>
 /// <param name="sessionOptions">The session options</param>
 /// <param name="status">The status</param>
 public Session(Graph graph, SessionOptions sessionOptions = null, Status status = null)
 {
     using (StatusChecker statusChecker = new StatusChecker(status))
     {
         _ptr = TfInvoke.tfeNewSession(graph, sessionOptions, statusChecker.Status);
     }
 }
Exemplo n.º 17
0
 /// <summary>
 /// Import the graph serialized in <paramref name="graphDef"/> into this graph.
 /// Convenience overload for when no return outputs have been requested.
 /// </summary>
 /// <param name="graphDef">The serialized GraphDef to import</param>
 /// <param name="options">The import options</param>
 /// <param name="status">The status</param>
 public void ImportGraphDef(Buffer graphDef, ImportGraphDefOptions options, Status status = null)
 {
     using (StatusChecker statusChecker = new StatusChecker(status))
     {
         TfInvoke.tfeGraphImportGraphDef(_ptr, graphDef, options, statusChecker.Status);
     }
 }
Exemplo n.º 18
0
 /// <summary>
 /// Write the serialized VersionDef proto of this graph into the supplied buffer.
 /// </summary>
 /// <param name="versionDef">The buffer that receives the serialized VersionDef</param>
 /// <param name="status">The status</param>
 public void Versions(Buffer versionDef, Status status = null)
 {
     using (StatusChecker statusChecker = new StatusChecker(status))
     {
         TfInvoke.tfeGraphVersions(_ptr, versionDef, statusChecker.Status);
     }
 }