/// <summary>
/// Copies the contents of a raw TF_Tensor handle into a managed <see cref="NDArray"/>.
/// </summary>
/// <param name="output">Raw TF_Tensor handle, wrapped here in a <see cref="Tensor"/>.</param>
/// <returns>
/// For numeric dtypes, an NDArray holding every element and reshaped to the tensor's shape.
/// For TF_STRING, a scalar NDArray holding the first string of the buffer.
/// </returns>
/// <exception cref="NotImplementedException">
/// The dtype is not TF_STRING, TF_INT16, TF_INT32, TF_FLOAT or TF_DOUBLE.
/// </exception>
private unsafe NDArray fetchValue(IntPtr output)
{
    var tensor = new Tensor(output);
    var ndims = tensor.shape.Select(x => (int)x).ToArray();
    // Resolve the raw buffer once instead of once per element.
    var data = c_api.TF_TensorData(output);

    switch (tensor.dtype)
    {
        case TF_DataType.TF_STRING:
        {
            var bytes = tensor.Data();
            // Buffer layout (see TF_STRING in c_api.h): an 8-byte offset table entry, then the
            // string encoded as a base-128 varint byte length followed by the bytes.
            // Only the first string is decoded; this assumes a scalar string tensor (TODO confirm
            // for non-scalar outputs, whose offset table is longer than 8 bytes).
            int pos = 8;
            ulong length = 0;
            int shift = 0;
            byte b;
            do
            {
                b = bytes[pos++];
                length |= (ulong)(b & 0x7F) << shift;
                shift += 7;
            } while ((b & 0x80) != 0);
            // Encoding.UTF8, not UTF8Encoding.Default: the latter is the inherited
            // Encoding.Default, which is the ANSI code page on .NET Framework.
            var str = Encoding.UTF8.GetString(bytes, pos, checked((int)length));
            return np.array(str).reshape();
        }
        case TF_DataType.TF_INT16:
        {
            var src = (short*)data;
            var shorts = new short[tensor.size];
            for (ulong i = 0; i < tensor.size; i++)
            {
                shorts[i] = src[i];
            }
            return np.array(shorts).reshape(ndims);
        }
        case TF_DataType.TF_INT32:
        {
            var src = (int*)data;
            var ints = new int[tensor.size];
            for (ulong i = 0; i < tensor.size; i++)
            {
                ints[i] = src[i];
            }
            return np.array(ints).reshape(ndims);
        }
        case TF_DataType.TF_FLOAT:
        {
            var src = (float*)data;
            var floats = new float[tensor.size];
            for (ulong i = 0; i < tensor.size; i++)
            {
                floats[i] = src[i];
            }
            return np.array(floats).reshape(ndims);
        }
        case TF_DataType.TF_DOUBLE:
        {
            var src = (double*)data;
            var doubles = new double[tensor.size];
            for (ulong i = 0; i < tensor.size; i++)
            {
                doubles[i] = src[i];
            }
            return np.array(doubles).reshape(ndims);
        }
        default:
            throw new NotImplementedException("can't fetch output");
    }
}
/// <summary>
/// Runs the session through <c>TF_SessionRun</c> and converts each fetched tensor to an NDArray.
/// </summary>
/// <param name="feed_dict">Graph outputs to feed, paired with the tensors fed into them.</param>
/// <param name="fetch_list">Graph outputs to fetch; one result is produced per entry.</param>
/// <param name="target_list">Operations to run without fetching their outputs.</param>
/// <returns>One NDArray per entry of <paramref name="fetch_list"/>, in the same order.</returns>
/// <exception cref="NotImplementedException">A fetched tensor has an unsupported dtype.</exception>
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
    // Ensure any changes to the graph are reflected in the runtime.
    _extend_graph();

    var status = new Status();
    // Filled by TF_SessionRun with one TF_Tensor handle per fetch; starts zeroed.
    var output_values = new IntPtr[fetch_list.Length];

    c_api.TF_SessionRun(_session,
        run_options: null,
        inputs: feed_dict.Select(f => f.Key).ToArray(),
        input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
        ninputs: feed_dict.Length,
        outputs: fetch_list,
        output_values: output_values,
        noutputs: fetch_list.Length,
        target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
        ntargets: target_list.Count,
        run_metadata: IntPtr.Zero,
        status: status);

    status.Check(true);

    var result = new NDArray[fetch_list.Length];
    for (int i = 0; i < fetch_list.Length; i++)
    {
        var tensor = new Tensor(output_values[i]);
        switch (tensor.dtype)
        {
            case TF_DataType.TF_STRING:
            {
                var bytes = tensor.Data();
                // Buffer layout (see TF_STRING in c_api.h): an 8-byte offset table entry, then
                // a base-128 varint byte length, then the string bytes. Assumes a scalar string
                // tensor (TODO confirm for non-scalar fetches).
                int pos = 8;
                ulong length = 0;
                int shift = 0;
                byte b;
                do
                {
                    b = bytes[pos++];
                    length |= (ulong)(b & 0x7F) << shift;
                    shift += 7;
                } while ((b & 0x80) != 0);
                // Encoding.UTF8, not UTF8Encoding.Default (= Encoding.Default, ANSI on .NET Framework).
                var output = Encoding.UTF8.GetString(bytes, pos, checked((int)length));
                result[i] = fetchValue(tensor, output);
                break;
            }
            // NOTE(review): the numeric branches dereference only the first element of the
            // buffer; whether multi-element outputs are handled depends on the
            // fetchValue(Tensor, T) overloads defined elsewhere — confirm.
            case TF_DataType.TF_FLOAT:
            {
                var output = *(float*)c_api.TF_TensorData(output_values[i]);
                result[i] = fetchValue(tensor, output);
                break;
            }
            case TF_DataType.TF_INT16:
            {
                var output = *(short*)c_api.TF_TensorData(output_values[i]);
                result[i] = fetchValue(tensor, output);
                break;
            }
            case TF_DataType.TF_INT32:
            {
                var output = *(int*)c_api.TF_TensorData(output_values[i]);
                result[i] = fetchValue(tensor, output);
                break;
            }
            default:
                throw new NotImplementedException("can't get output");
        }
    }

    return result;
}