/// <summary>
/// Read a shape-typed attribute of this operation as a serialized TensorShapeProto.
/// </summary>
/// <param name="attrName">Name of the attribute to read</param>
/// <param name="status">Optional status; handed to the StatusChecker guarding the native call and to GetAttrMetadata</param>
/// <returns>A newly created Buffer holding the serialized TensorShapeProto</returns>
/// <exception cref="ArgumentException">The attribute's metadata type is not AttrType.Shape</exception>
public Buffer GetAttrTensorShapeProto(String attrName, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        AttrMetadata metadata = GetAttrMetadata(attrName, status);
        bool isShape = metadata.Type == AttrType.Shape;
        if (!isShape)
        {
            throw new ArgumentException(String.Format("Attribute {0} ({1}) is not a shape", attrName, metadata.Type));
        }

        Buffer shapeProto = new Buffer();
        TfInvoke.tfeOperationGetAttrTensorShapeProto(_ptr, attrName, shapeProto, checker.Status);
        return shapeProto;
    }
}
/// <summary>
/// Convert a byte array to a String Tensor.
/// </summary>
/// <param name="value">The byte array; must not be null</param>
/// <param name="status">Optional status</param>
/// <returns>The tensor</returns>
/// <exception cref="ArgumentNullException"><paramref name="value"/> is null</exception>
public static Tensor FromString(byte[] value, Status status = null)
{
    if (value == null)
    {
        throw new ArgumentNullException(nameof(value));
    }

    using (StatusChecker checker = new StatusChecker(status))
    {
        int length = TfInvoke.tfeStringEncodedSize(value.Length);

        // Layout: an 8-byte Int64 header (written as 0) followed by `length` bytes of encoded
        // string data. Presumably this is the single-element offset table of the legacy
        // TF_STRING tensor layout — confirm against the native wrapper.
        Tensor tensor = new Tensor(DataType.String, length + 8);
        IntPtr ptr = tensor.DataPointer;
        Marshal.WriteInt64(ptr, 0);

        GCHandle handle = GCHandle.Alloc(value, GCHandleType.Pinned);
        try
        {
            TfInvoke.tfeStringEncode(handle.AddrOfPinnedObject(), value.Length, new IntPtr(ptr.ToInt64() + 8), length, checker.Status);
        }
        finally
        {
            // Always unpin, even if the native call throws.
            handle.Free();
        }
        return(tensor);
    }
}
/// <summary>
/// List the names of all physical devices reported by the native library.
/// </summary>
/// <param name="status">Optional status</param>
/// <returns>The device names; an empty array when the native call reports no names</returns>
public static String[] ListAllPhysicalDevices(Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        // The native side fills this buffer with newline-separated ASCII names.
        // NOTE(review): the capacity is not passed to the native call, so the wrapper
        // presumably assumes 2048 bytes — confirm.
        byte[] nameBuffer = new byte[2048];
        GCHandle nameHandle = GCHandle.Alloc(nameBuffer, GCHandleType.Pinned);
        try
        {
            TfInvoke.tfeListAllPhysicalDevices(nameHandle.AddrOfPinnedObject(), checker.Status);
        }
        finally
        {
            // Always unpin, even if the native call throws.
            nameHandle.Free();
        }

        String nameResult = System.Text.Encoding.ASCII.GetString(nameBuffer).TrimEnd('\0', '\n');
        if (nameResult.Length == 0)
        {
            // "".Split('\n') would yield a single empty name; report no devices instead.
            return(new String[0]);
        }
        return(nameResult.Split('\n'));
    }
}
/// <summary>
/// Returns the shape of the Tensor produced by the given output.
/// </summary>
/// <param name="output">The output</param>
/// <param name="status">The status</param>
/// <returns>
/// The dimensions of the Tensor; an empty array for a rank-0 Tensor; null when the native
/// call reports a negative number of dimensions (in the TensorFlow C API -1 denotes unknown rank).
/// </returns>
public int[] GetTensorShape(Output output, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        int numDim = TfInvoke.tfeGraphGetTensorNumDims(_ptr, output.Operation, output.Index, checker.Status);
        if (numDim < 0)
        {
            return(null);
        }
        if (numDim == 0)
        {
            return(new int[0]);
        }

        int[] dims = new int[numDim];
        GCHandle handle = GCHandle.Alloc(dims, GCHandleType.Pinned);
        try
        {
            TfInvoke.tfeGraphGetTensorShape(_ptr, output.Operation, output.Index, handle.AddrOfPinnedObject(), numDim, checker.Status);
        }
        finally
        {
            // Always unpin, even if the native call throws.
            handle.Free();
        }
        return(dims);
    }
}
/// <summary>
/// Lists all devices in a session.
/// </summary>
/// <param name="status">The status</param>
/// <returns>All devices in the current session; an empty array when none are reported</returns>
public Device[] ListDevices(Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        // Filled by the native wrapper: newline-separated ASCII names and types, plus one
        // Int64 memory size per device (indexed in step with the names below).
        // NOTE(review): capacities are not passed to the native call, so the wrapper
        // presumably assumes 2048 bytes / 128 entries — confirm.
        byte[] nameBuffer = new byte[2048];
        byte[] typeBuffer = new byte[2048];
        Int64[] memorySizeBuffer = new Int64[128];

        GCHandle nameHandle = new GCHandle();
        GCHandle typeHandle = new GCHandle();
        GCHandle memorySizeHandle = new GCHandle();
        try
        {
            nameHandle = GCHandle.Alloc(nameBuffer, GCHandleType.Pinned);
            typeHandle = GCHandle.Alloc(typeBuffer, GCHandleType.Pinned);
            memorySizeHandle = GCHandle.Alloc(memorySizeBuffer, GCHandleType.Pinned);
            TfInvoke.tfeSessionListDevices(
                _ptr,
                nameHandle.AddrOfPinnedObject(),
                typeHandle.AddrOfPinnedObject(),
                memorySizeHandle.AddrOfPinnedObject(),
                checker.Status);
        }
        finally
        {
            // Unpin whatever was pinned, even if an Alloc or the native call threw.
            if (nameHandle.IsAllocated)
            {
                nameHandle.Free();
            }
            if (typeHandle.IsAllocated)
            {
                typeHandle.Free();
            }
            if (memorySizeHandle.IsAllocated)
            {
                memorySizeHandle.Free();
            }
        }

        String nameResult = System.Text.Encoding.ASCII.GetString(nameBuffer).TrimEnd('\0', '\n');
        if (nameResult.Length == 0)
        {
            // "".Split('\n') would yield one Device with an empty name; report none instead.
            return(new Device[0]);
        }
        String[] names = nameResult.Split('\n');
        String typeResult = System.Text.Encoding.ASCII.GetString(typeBuffer);
        String[] types = typeResult.TrimEnd('\0', '\n').Split('\n');

        Device[] devices = new Device[names.Length];
        for (int i = 0; i < devices.Length; i++)
        {
            Device d = new Device();
            d.Name = names[i];
            d.Type = types[i];
            d.MemoryBytes = memorySizeBuffer[i];
            devices[i] = d;
        }
        return(devices);
    }
}
/// <summary>
/// Free the native session and release the managed objects this Session holds.
/// The graph is disposed only when this Session created it; the meta graph buffer is always disposed.
/// </summary>
protected override void DisposeObject()
{
    if (_ptr != IntPtr.Zero)
    {
        using (StatusChecker checker = new StatusChecker(null))
        {
            TfInvoke.tfeDeleteSession(ref _ptr, checker.Status);
        }
    }

    var ownedGraph = _graphNeedDispose ? _graph : null;
    if (ownedGraph != null)
    {
        ownedGraph.Dispose();
        _graph = null;
    }

    if (_metaGraphDef != null)
    {
        _metaGraphDef.Dispose();
        _metaGraphDef = null;
    }

    // A borrowed (non-owned) graph is not disposed, but the reference is still dropped.
    _graph = null;
}
/// <summary>
/// Get the value of the attribute that is a String.
/// </summary>
/// <param name="attrName">The name of the attribute</param>
/// <param name="status">The status</param>
/// <returns>The string value of the attribute</returns>
/// <exception cref="ArgumentException">The attribute's metadata type is not AttrType.String</exception>
public String GetAttrString(String attrName, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        AttrMetadata meta = GetAttrMetadata(attrName, status);
        if (meta.Type != AttrType.String)
        {
            throw new ArgumentException(String.Format("Attribute {0} ({1}) is not a String", attrName, meta.Type));
        }

        int size = (int)meta.TotalSize;
        IntPtr s = Marshal.AllocHGlobal(size);
        try
        {
            TfInvoke.tfeOperationGetAttrString(_ptr, attrName, s, size, checker.Status);
            // The buffer is exactly `size` bytes and has no room for a '\0' terminator, so a
            // terminator-scanning read would run past the allocation. Read a bounded length instead.
            return(Marshal.PtrToStringAnsi(s, size));
        }
        finally
        {
            Marshal.FreeHGlobal(s);
        }
    }
}
/// <summary>
/// Read an attribute whose value is a list of Int64.
/// </summary>
/// <param name="attrName">Name of the attribute to read</param>
/// <param name="status">Optional status</param>
/// <returns>A list of Int64 with ListSize elements, as reported by the attribute metadata</returns>
/// <exception cref="ArgumentException">The attribute is not a list of Int</exception>
public Int64[] GetAttrIntList(String attrName, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        AttrMetadata meta = GetAttrMetadata(attrName, status);
        if (meta.Type != AttrType.Int || !meta.IsList)
        {
            throw new ArgumentException(String.Format("Attribute {0} ({1}) is not a List of Int", attrName, meta.Type));
        }

        Int64[] values = new Int64[meta.ListSize];
        GCHandle pinned = GCHandle.Alloc(values, GCHandleType.Pinned);
        try
        {
            IntPtr destination = pinned.AddrOfPinnedObject();
            TfInvoke.tfeOperationGetAttrIntList(_ptr, attrName, destination, values.Length, checker.Status);
        }
        finally
        {
            pinned.Free();
        }
        return values;
    }
}
/// <summary>
/// Assign a Tensor value to the named attribute.
/// </summary>
/// <param name="attrName">Name of the attribute to set</param>
/// <param name="tensor">The Tensor to store in the attribute</param>
/// <param name="status">Optional status</param>
public void SetAttr(String attrName, Tensor tensor, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        TfInvoke.tfeSetAttrTensor(_ptr, attrName, tensor, checker.Status);
    }
}
/// <summary>
/// Run the graph associated with the session starting with the supplied inputs
/// (inputs[0,ninputs-1] with corresponding values in inputValues[0,ninputs-1]).
/// </summary>
/// <param name="inputs">The input nodes</param>
/// <param name="inputValues">The input values; must have the same length as <paramref name="inputs"/></param>
/// <param name="outputs">The output nodes</param>
/// <param name="targetOperations">Optional target operations</param>
/// <param name="runOptions">
/// May be NULL, in which case it will be ignored; or
/// non-NULL, in which case it must point to a `TF_Buffer` containing the
/// serialized representation of a `RunOptions` protocol buffer.
/// </param>
/// <param name="runMetadata">
/// May be NULL, in which case it will be ignored; or
/// non-NULL, in which case it must point to an empty, freshly allocated
/// `TF_Buffer` that may be updated to contain the serialized representation
/// of a `RunMetadata` protocol buffer.
/// </param>
/// <param name="status">The status</param>
/// <returns>On success, the tensors corresponding to outputs[0,noutputs-1] are placed in the returned Tensors.</returns>
/// <exception cref="ArgumentException"><paramref name="inputs"/> and <paramref name="inputValues"/> differ in length</exception>
public Tensor[] Run(Output[] inputs, Tensor[] inputValues, Output[] outputs, Operation[] targetOperations = null, Buffer runOptions = null, Buffer runMetadata = null, Status status = null)
{
    // Build all native argument arrays before pinning anything, so a failure here cannot leak pins.
    IntPtr[] inputOps = Array.ConvertAll(inputs, i => i.Operation.Ptr);
    int[] inputIdx = Array.ConvertAll(inputs, i => i.Index);
    IntPtr[] inputTensors = Array.ConvertAll(inputValues, i => i.Ptr);
    if (inputTensors.Length != inputOps.Length)
    {
        // The native call reads inputs.Length entries from each input array.
        throw new ArgumentException("The number of input values must match the number of inputs", nameof(inputValues));
    }

    IntPtr[] outputOps = Array.ConvertAll(outputs, o => o.Operation.Ptr);
    int[] outputIdx = Array.ConvertAll(outputs, o => o.Index);
    IntPtr[] outputTensors = new IntPtr[outputs.Length];
    IntPtr[] targetOpsPtrArray = targetOperations == null ? null : Array.ConvertAll(targetOperations, o => o.Ptr);

    GCHandle inputOpsHandle = new GCHandle();
    GCHandle inputIdxHandle = new GCHandle();
    GCHandle inputTensorsHandle = new GCHandle();
    GCHandle outputOpsHandle = new GCHandle();
    GCHandle outputIdxHandle = new GCHandle();
    GCHandle outputTensorsHandle = new GCHandle();
    GCHandle targetOpsHandle = new GCHandle();
    try
    {
        inputOpsHandle = GCHandle.Alloc(inputOps, GCHandleType.Pinned);
        inputIdxHandle = GCHandle.Alloc(inputIdx, GCHandleType.Pinned);
        inputTensorsHandle = GCHandle.Alloc(inputTensors, GCHandleType.Pinned);
        outputOpsHandle = GCHandle.Alloc(outputOps, GCHandleType.Pinned);
        outputIdxHandle = GCHandle.Alloc(outputIdx, GCHandleType.Pinned);
        outputTensorsHandle = GCHandle.Alloc(outputTensors, GCHandleType.Pinned);

        IntPtr targetOpsPtr = IntPtr.Zero;
        int ntargets = 0;
        if (targetOpsPtrArray != null)
        {
            targetOpsHandle = GCHandle.Alloc(targetOpsPtrArray, GCHandleType.Pinned);
            targetOpsPtr = targetOpsHandle.AddrOfPinnedObject();
            ntargets = targetOpsPtrArray.Length;
        }

        using (StatusChecker checker = new StatusChecker(status))
        {
            TfInvoke.tfeSessionRun(
                _ptr,
                runOptions,
                inputOpsHandle.AddrOfPinnedObject(), inputIdxHandle.AddrOfPinnedObject(), inputTensorsHandle.AddrOfPinnedObject(), inputs.Length,
                outputOpsHandle.AddrOfPinnedObject(), outputIdxHandle.AddrOfPinnedObject(), outputTensorsHandle.AddrOfPinnedObject(), outputs.Length,
                targetOpsPtr, ntargets,
                runMetadata,
                checker.Status);
        }
    }
    finally
    {
        // Unpin everything that was pinned, including when StatusChecker throws on a failed run.
        if (inputOpsHandle.IsAllocated)
        {
            inputOpsHandle.Free();
        }
        if (inputIdxHandle.IsAllocated)
        {
            inputIdxHandle.Free();
        }
        if (inputTensorsHandle.IsAllocated)
        {
            inputTensorsHandle.Free();
        }
        if (targetOpsHandle.IsAllocated)
        {
            targetOpsHandle.Free();
        }
        if (outputOpsHandle.IsAllocated)
        {
            outputOpsHandle.Free();
        }
        if (outputIdxHandle.IsAllocated)
        {
            outputIdxHandle.Free();
        }
        if (outputTensorsHandle.IsAllocated)
        {
            outputTensorsHandle.Free();
        }
    }

    return(Array.ConvertAll(outputTensors, t => new Tensor(t)));
}
/// <summary>
/// Close this session.
/// If applicable, this contacts any other processes associated with the session.
/// </summary>
/// <param name="status">Optional status</param>
public void Close(Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        TfInvoke.tfeCloseSession(_ptr, checker.Status);
    }
}
/// <summary>
/// Serialize this graph as a GraphDef protocol message.
/// </summary>
/// <param name="outputGraphDef">Buffer that receives the serialized GraphDef</param>
/// <param name="status">Optional status</param>
public void ToGraphDef(Buffer outputGraphDef, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        TfInvoke.tfeGraphToGraphDef(_ptr, outputGraphDef, checker.Status);
    }
}
/// <summary>
/// Create a Session from a SavedModel. If successful, populates the internal graph with the contents of the Graph and
/// with the MetaGraphDef of the loaded model.
/// </summary>
/// <param name="exportDir">Must be set to the path of the exported SavedModel.</param>
/// <param name="tags">Must include the set of tags used to identify one MetaGraphDef in the SavedModel. Could be "serve", "tpu", "gpu", "train" or other values.</param>
/// <param name="sessionOptions">Session options</param>
/// <param name="runOptions">Optional buffer passed through to the native loader; presumably a serialized RunOptions protocol buffer (may be null)</param>
/// <param name="status">The status</param>
public Session(
    String exportDir,
    String[] tags,
    SessionOptions sessionOptions = null,
    Buffer runOptions = null,
    Status status = null)
{
    // This Session owns the graph it creates here, so DisposeObject will dispose it.
    _graph = new Graph();
    _graphNeedDispose = true;
    _metaGraphDef = new Buffer();

    // Declared before the try so the finally can release whatever was allocated; all
    // native allocations happen inside the try so a marshaling failure cannot leak them.
    IntPtr exportDirPtr = IntPtr.Zero;
    IntPtr[] tagsNative = new IntPtr[tags == null ? 0 : tags.Length];
    GCHandle tagsNativeHandle = new GCHandle();
    IntPtr tagsNativePointer = IntPtr.Zero;
    try
    {
        exportDirPtr = Marshal.StringToHGlobalAnsi(exportDir);
        if (tags != null)
        {
            for (int i = 0; i < tags.Length; i++)
            {
                tagsNative[i] = Marshal.StringToHGlobalAnsi(tags[i]);
            }
            tagsNativeHandle = GCHandle.Alloc(tagsNative, GCHandleType.Pinned);
            tagsNativePointer = tagsNativeHandle.AddrOfPinnedObject();
        }

        using (StatusChecker checker = new StatusChecker(status))
        {
            _ptr = TfInvoke.tfeLoadSessionFromSavedModel(
                sessionOptions,
                runOptions,
                exportDirPtr,
                tagsNativePointer,
                tagsNative.Length,
                _graph,
                _metaGraphDef,
                checker.Status
            );
        }
    }
    catch (Exception excpt)
    {
        Trace.WriteLine(excpt.Message);
        throw;
    }
    finally
    {
        // Marshal.FreeHGlobal is a no-op for IntPtr.Zero, so unallocated slots are safe to free.
        Marshal.FreeHGlobal(exportDirPtr);
        if (tagsNativeHandle.IsAllocated)
        {
            tagsNativeHandle.Free();
        }
        for (int i = 0; i < tagsNative.Length; i++)
        {
            Marshal.FreeHGlobal(tagsNative[i]);
        }
    }
}
/// <summary>
/// Serialize this Function as a FunctionDef protocol message.
/// </summary>
/// <param name="outputFuncDef">Buffer that receives the serialized FunctionDef</param>
/// <param name="status">Optional status</param>
public void ToFunctionDef(Buffer outputFuncDef, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        TfInvoke.tfeFunctionToFunctionDef(_ptr, outputFuncDef, checker.Status);
    }
}
/// <summary>
/// Load the library at <paramref name="libraryFilename"/> and register the ops and
/// kernels it contains.
/// </summary>
/// <param name="libraryFilename">Path / file name of the library to load</param>
/// <param name="status">Optional status</param>
public Library(String libraryFilename, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        _ptr = TfInvoke.tfeLoadLibrary(libraryFilename, checker.Status);
    }
}
/// <summary>
/// Create a new execution session for the given graph.
/// </summary>
/// <param name="graph">
/// A valid graph (not deleted or null). The native API keeps the graph alive until the
/// session is deleted.
/// NOTE(review): this constructor does not assign the managed <c>_graph</c> field — confirm that is intended.
/// </param>
/// <param name="sessionOptions">The session options. The session does not take ownership of them.</param>
/// <param name="status">Optional status</param>
public Session(Graph graph, SessionOptions sessionOptions = null, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        _ptr = TfInvoke.tfeNewSession(graph, sessionOptions, checker.Status);
    }
}
/// <summary>
/// Import the serialized GraphDef in <paramref name="graphDef"/> into this graph.
/// Use this overload when no return outputs have been requested.
/// </summary>
/// <param name="graphDef">Buffer holding the serialized GraphDef to import</param>
/// <param name="options">Options controlling the import</param>
/// <param name="status">Optional status</param>
public void ImportGraphDef(Buffer graphDef, ImportGraphDefOptions options, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        TfInvoke.tfeGraphImportGraphDef(_ptr, graphDef, options, checker.Status);
    }
}
/// <summary>
/// Write this graph's serialized VersionDef proto into the supplied buffer.
/// </summary>
/// <param name="versionDef">Buffer that receives the serialized VersionDef</param>
/// <param name="status">Optional status</param>
public void Versions(Buffer versionDef, Status status = null)
{
    using (StatusChecker checker = new StatusChecker(status))
    {
        TfInvoke.tfeGraphVersions(_ptr, versionDef, checker.Status);
    }
}