Exemplo n.º 1
0
        /// <summary>
        /// Copies the contents of a native tensor handle into a managed <see cref="NDArray"/>.
        /// </summary>
        /// <param name="output">Native tensor handle (e.g. one filled in by TF_SessionRun).</param>
        /// <returns>
        /// For numeric dtypes, an NDArray of the tensor's elements reshaped to the tensor's shape;
        /// for TF_STRING, an NDArray wrapping the first decoded string.
        /// </returns>
        /// <exception cref="NotImplementedException">The tensor's dtype is not handled here.</exception>
        /// <exception cref="InvalidOperationException">A TF_STRING length prefix is malformed.</exception>
        private unsafe NDArray fetchValue(IntPtr output)
        {
            var     tensor = new Tensor(output);
            NDArray nd     = null;
            var     dims   = tensor.shape.Select(x => (int)x).ToArray();
            // Base address of the tensor buffer; loop-invariant, so fetched once.
            var     data   = c_api.TF_TensorData(output);

            switch (tensor.dtype)
            {
            case TF_DataType.TF_STRING:
            {
                // Per the TF C API TF_STRING encoding, the buffer starts with an 8-byte
                // offset table entry, followed by a varint-encoded byte length and then the
                // string bytes. Only the first string is decoded. The length must be read as
                // a varint: a single byte (bytes[8]) is only correct for strings < 128 bytes.
                var bytes  = tensor.Data();
                int pos    = 8;
                int length = 0;
                for (int shift = 0; ; shift += 7)
                {
                    if (shift > 28)
                    {
                        throw new InvalidOperationException("malformed TF_STRING length prefix");
                    }
                    byte b = bytes[pos++];
                    length |= (b & 0x7F) << shift;
                    if ((b & 0x80) == 0)
                    {
                        break;
                    }
                }
                // Encoding.UTF8 explicitly: UTF8Encoding.Default is really Encoding.Default,
                // which is the ANSI code page on .NET Framework.
                var str = Encoding.UTF8.GetString(bytes, pos, length);
                nd = np.array(str).reshape();
                break;
            }

            case TF_DataType.TF_INT16:
            {
                var shorts = new short[tensor.size];
                for (ulong i = 0; i < tensor.size; i++)
                {
                    shorts[i] = *(short *)(data + (int)(tensor.dataTypeSize * i));
                }
                nd = np.array(shorts).reshape(dims);
                break;
            }

            case TF_DataType.TF_INT32:
            {
                var ints = new int[tensor.size];
                for (ulong i = 0; i < tensor.size; i++)
                {
                    ints[i] = *(int *)(data + (int)(tensor.dataTypeSize * i));
                }
                nd = np.array(ints).reshape(dims);
                break;
            }

            case TF_DataType.TF_FLOAT:
            {
                var floats = new float[tensor.size];
                for (ulong i = 0; i < tensor.size; i++)
                {
                    floats[i] = *(float *)(data + (int)(tensor.dataTypeSize * i));
                }
                nd = np.array(floats).reshape(dims);
                break;
            }

            case TF_DataType.TF_DOUBLE:
            {
                var doubles = new double[tensor.size];
                for (ulong i = 0; i < tensor.size; i++)
                {
                    doubles[i] = *(double *)(data + (int)(tensor.dataTypeSize * i));
                }
                nd = np.array(doubles).reshape(dims);
                break;
            }

            default:
                throw new NotImplementedException("can't fetch output");
            }

            return nd;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Runs the session once: feeds the tensors in <paramref name="feed_dict"/>, executes
        /// <paramref name="target_list"/>, and converts each fetched output to an <see cref="NDArray"/>.
        /// </summary>
        /// <param name="feed_dict">Graph outputs to feed, paired with the tensors fed into them.</param>
        /// <param name="fetch_list">Graph outputs whose values are returned.</param>
        /// <param name="target_list">Operations to run without fetching their outputs.</param>
        /// <returns>One NDArray per entry of <paramref name="fetch_list"/>, in the same order.</returns>
        /// <exception cref="NotImplementedException">A fetched tensor has an unhandled dtype.</exception>
        /// <exception cref="InvalidOperationException">A TF_STRING length prefix is malformed.</exception>
        private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair <TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List <Operation> target_list)
        {
            // Ensure any changes to the graph are reflected in the runtime.
            _extend_graph();

            var status = new Status();

            // Placeholder handles; TF_SessionRun writes one output tensor handle per fetch.
            var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

            c_api.TF_SessionRun(_session,
                                run_options: null,
                                inputs: feed_dict.Select(f => f.Key).ToArray(),
                                input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
                                ninputs: feed_dict.Length,
                                outputs: fetch_list,
                                output_values: output_values,
                                noutputs: fetch_list.Length,
                                target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
                                ntargets: target_list.Count,
                                run_metadata: IntPtr.Zero,
                                status: status);

            status.Check(true);

            var result = new NDArray[fetch_list.Length];

            for (int i = 0; i < fetch_list.Length; i++)
            {
                var tensor = new Tensor(output_values[i]);

                switch (tensor.dtype)
                {
                case TF_DataType.TF_STRING:
                {
                    // Per the TF C API TF_STRING encoding: an 8-byte offset table entry, then
                    // a varint-encoded byte length, then the string bytes. A fixed start of 9
                    // is only correct when the length fits in one varint byte (< 128).
                    var bytes  = tensor.Data();
                    int pos    = 8;
                    int length = 0;
                    for (int shift = 0; ; shift += 7)
                    {
                        if (shift > 28)
                        {
                            throw new InvalidOperationException("malformed TF_STRING length prefix");
                        }
                        byte b = bytes[pos++];
                        length |= (b & 0x7F) << shift;
                        if ((b & 0x80) == 0)
                        {
                            break;
                        }
                    }
                    // Encoding.UTF8 explicitly: UTF8Encoding.Default is really Encoding.Default,
                    // which is the ANSI code page on .NET Framework.
                    var output = Encoding.UTF8.GetString(bytes, pos, length);
                    result[i] = fetchValue(tensor, output);
                }
                break;

                // The numeric cases dereference only the start of the data buffer, i.e. they
                // read just the first element (NOTE(review): presumably fetches are scalars —
                // confirm against the fetchValue(Tensor, T) overloads).
                case TF_DataType.TF_FLOAT:
                {
                    var output = *(float *)c_api.TF_TensorData(output_values[i]);
                    result[i] = fetchValue(tensor, output);
                }
                break;

                case TF_DataType.TF_INT16:
                {
                    var output = *(short *)c_api.TF_TensorData(output_values[i]);
                    result[i] = fetchValue(tensor, output);
                }
                break;

                case TF_DataType.TF_INT32:
                {
                    var output = *(int *)c_api.TF_TensorData(output_values[i]);
                    result[i] = fetchValue(tensor, output);
                }
                break;

                default:
                    throw new NotImplementedException("can't get output");
                }
            }

            return result;
        }