Example #1
0
        /// <summary>
        /// Copies the contents of a native TF_Tensor into a new <see cref="NDArray"/>
        /// shaped like the tensor.
        /// </summary>
        /// <param name="output">Handle to a TF_Tensor. It is wrapped in a <see cref="Tensor"/>,
        /// and that wrapper is disposed before this method returns or throws.</param>
        /// <returns>An <see cref="NDArray"/> holding a managed copy of the tensor's elements.
        /// For TF_STRING, a scalar array holding the decoded string.</returns>
        /// <exception cref="NotImplementedException">The tensor's dtype is not supported.</exception>
        /// <exception cref="InvalidOperationException">A TF_STRING buffer is truncated or malformed.</exception>
        private unsafe NDArray fetchValue(IntPtr output)
        {
            var tensor = new Tensor(output);
            try
            {
                var ndims = tensor.shape.Select(x => (int)x).ToArray();
                var data = c_api.TF_TensorData(output);
                var count = tensor.size;

                // Typed-pointer indexing strides by sizeof(T), which equals the TF item size for
                // every dtype handled below. It also avoids the int overflow of computing byte
                // offsets as (int)(itemsize * i) on very large buffers.
                switch (tensor.dtype)
                {
                    case TF_DataType.TF_BOOL:
                    {
                        var src = (bool*)data;
                        var values = new bool[count];
                        for (ulong i = 0; i < count; i++)
                            values[i] = src[i];
                        return np.array(values).reshape(ndims);
                    }
                    case TF_DataType.TF_STRING:
                    {
                        // Per the TensorFlow C API, a TF_STRING buffer is a table of 8-byte
                        // offsets (one per element) followed by each element encoded as a
                        // varint length prefix plus its bytes. Like the previous code, this
                        // decodes a single string whose length prefix follows one 8-byte entry.
                        // NOTE(review): for non-scalar string tensors the offset table is longer
                        // than 8 bytes, so this path only supports scalars.
                        var bytes = tensor.Data();
                        int pos = 8;
                        ulong length = DecodeVarint64(bytes, ref pos);
                        if (length > (ulong)(bytes.Length - pos))
                            throw new InvalidOperationException("TF_STRING tensor data is shorter than its encoded length");
                        // Encoding.UTF8, not UTF8Encoding.Default: the latter is Encoding.Default,
                        // which is the ANSI code page on .NET Framework.
                        var str = Encoding.UTF8.GetString(bytes, pos, (int)length);
                        return np.array(str).reshape();
                    }
                    case TF_DataType.TF_UINT8:
                    {
                        var src = (byte*)data;
                        var values = new byte[count];
                        for (ulong i = 0; i < count; i++)
                            values[i] = src[i];
                        return np.array(values).reshape(ndims);
                    }
                    case TF_DataType.TF_INT16:
                    {
                        var src = (short*)data;
                        var values = new short[count];
                        for (ulong i = 0; i < count; i++)
                            values[i] = src[i];
                        return np.array(values).reshape(ndims);
                    }
                    case TF_DataType.TF_INT32:
                    {
                        var src = (int*)data;
                        var values = new int[count];
                        for (ulong i = 0; i < count; i++)
                            values[i] = src[i];
                        return np.array(values).reshape(ndims);
                    }
                    case TF_DataType.TF_INT64:
                    {
                        var src = (long*)data;
                        var values = new long[count];
                        for (ulong i = 0; i < count; i++)
                            values[i] = src[i];
                        return np.array(values).reshape(ndims);
                    }
                    case TF_DataType.TF_FLOAT:
                    {
                        var src = (float*)data;
                        var values = new float[count];
                        for (ulong i = 0; i < count; i++)
                            values[i] = src[i];
                        return np.array(values).reshape(ndims);
                    }
                    case TF_DataType.TF_DOUBLE:
                    {
                        var src = (double*)data;
                        var values = new double[count];
                        for (ulong i = 0; i < count; i++)
                            values[i] = src[i];
                        return np.array(values).reshape(ndims);
                    }
                    default:
                        throw new NotImplementedException("can't fetch output");
                }
            }
            finally
            {
                // Dispose on every path, including unsupported dtypes and decode failures.
                tensor.Dispose();
            }
        }

        /// <summary>
        /// Decodes an unsigned LEB128 varint (7 bits per byte, high bit = continuation)
        /// starting at <paramref name="pos"/>, and advances <paramref name="pos"/> past it.
        /// </summary>
        /// <exception cref="InvalidOperationException">The buffer ends mid-varint, or the varint exceeds 64 bits.</exception>
        private static ulong DecodeVarint64(byte[] buffer, ref int pos)
        {
            ulong result = 0;
            for (int shift = 0; shift < 64; shift += 7)
            {
                if (pos >= buffer.Length)
                    throw new InvalidOperationException("truncated varint length prefix in TF_STRING tensor");
                byte b = buffer[pos++];
                result |= (ulong)(b & 0x7F) << shift;
                if ((b & 0x80) == 0)
                    return result;
            }
            throw new InvalidOperationException("malformed varint length prefix in TF_STRING tensor");
        }