Пример #1
0
        /// <summary>
        /// Allocates a new tensor with the given shape and dtype and, when <paramref name="data"/>
        /// is non-null, copies the raw element bytes from it into the tensor's buffer.
        /// </summary>
        /// <param name="shape">
        /// Tensor shape. The byte length requested from <c>TF_AllocateTensor</c> is
        /// <c>shape.size * dtype.get_datatype_size()</c>.
        /// </param>
        /// <param name="dtype">Element data type of the tensor.</param>
        /// <param name="data">
        /// Optional pointer to source bytes; exactly <c>shape.size * dtype.get_datatype_size()</c> bytes
        /// are copied from it. When null, nothing is copied into the allocated buffer.
        /// </param>
        /// <returns>The handle returned by <c>TF_AllocateTensor</c>.</returns>
        /// <exception cref="TensorflowException">The allocated tensor has no data buffer.</exception>
        public static unsafe SafeTensorHandle TF_NewTensor(Shape shape, TF_DataType dtype, void *data)
        {
            var length = shape.size * dtype.get_datatype_size();
            var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, (ulong)length);
            var tensor = TF_TensorData(handle);

            if (tensor == IntPtr.Zero)
            {
                // Don't leave the freshly allocated handle undisposed on the failure path.
                handle.Dispose();
                throw new TensorflowException("AllocateTensor failed.");
            }

            if (data != null)
            {
                System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length);
            }

            return handle;
        }
Пример #2
0
        /// <summary>
        /// Create a TensorProto, invoked in graph mode
        /// </summary>
        /// <param name="values">
        /// Value(s) to serialize: an existing <c>TensorProto</c> (returned unchanged), an <c>NDArray</c>,
        /// a string or byte[] (for TF_STRING), a primitive array, an <c>Axis</c>, a <c>Shape</c>,
        /// or a primitive scalar.
        /// </param>
        /// <param name="dtype">
        /// Requested element type. <c>DtInvalid</c> keeps the type inferred from <paramref name="values"/>;
        /// otherwise the values are converted to this type first (element by element for arrays when the
        /// target system type is primitive, via <c>Convert.ChangeType</c> otherwise).
        /// </param>
        /// <param name="shape">Shape recorded in the proto; defaults to the shape of <paramref name="values"/>.</param>
        /// <param name="verify_shape">Only checked for mutual exclusion with <paramref name="allow_broadcast"/>; not otherwise used here.</param>
        /// <param name="allow_broadcast">Only checked for mutual exclusion with <paramref name="verify_shape"/>; not otherwise used here.</param>
        /// <returns>The populated <c>TensorProto</c>.</returns>
        /// <exception cref="ValueError">Both <paramref name="allow_broadcast"/> and <paramref name="verify_shape"/> are true.</exception>
        /// <exception cref="NotImplementedException">The type or dtype of <paramref name="values"/> is not supported.</exception>
        /// <exception cref="InvalidCastException">The values cannot be converted to <paramref name="dtype"/>.</exception>
        public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, Shape?shape = null, bool verify_shape = false, bool allow_broadcast = false)
        {
            if (allow_broadcast && verify_shape)
            {
                throw new ValueError("allow_broadcast and verify_shape are not both allowed.");
            }
            if (values is TensorProto tp)
            {
                return tp;
            }

            var origin_dtype = values.GetDataType();

            if (dtype == TF_DataType.DtInvalid)
            {
                dtype = origin_dtype;
            }
            else if (origin_dtype != dtype)
            {
                var new_system_dtype = dtype.as_system_dtype();
                // Arrays are not IConvertible, so Convert.ChangeType cannot cast them as a whole;
                // convert them element by element (any rank) when the target is a primitive type.
                values = values is Array source_array && new_system_dtype.IsPrimitive
                    ? convert_array_elements(source_array, new_system_dtype)
                    : Convert.ChangeType(values, new_system_dtype);

                dtype = values.GetDataType();
            }

            shape ??= values.GetShape();
            var tensor_proto = new TensorProto
            {
                Dtype       = dtype.as_datatype_enum(),
                TensorShape = shape.as_shape_proto()
            };

            if (values is NDArray nd)
            {
                if (nd.shape.IsScalar)
                {
                    // Scalars go into the typed *_val fields; the narrow integer types share IntVal.
                    switch (nd.dtype)
                    {
                    case TF_DataType.TF_BOOL:
                        tensor_proto.BoolVal.AddRange(nd.ToArray<bool>());
                        break;

                    case TF_DataType.TF_INT8:
                        tensor_proto.IntVal.AddRange(nd.ToArray<sbyte>().Select(x => (int)x));
                        break;

                    case TF_DataType.TF_UINT8:
                        tensor_proto.IntVal.AddRange(nd.ToArray<byte>().Select(x => (int)x));
                        break;

                    case TF_DataType.TF_INT16:
                        tensor_proto.IntVal.AddRange(nd.ToArray<short>().Select(x => (int)x));
                        break;

                    case TF_DataType.TF_UINT16:
                        tensor_proto.IntVal.AddRange(nd.ToArray<ushort>().Select(x => (int)x));
                        break;

                    case TF_DataType.TF_INT32:
                        tensor_proto.IntVal.AddRange(nd.ToArray<int>());
                        break;

                    case TF_DataType.TF_INT64:
                        tensor_proto.Int64Val.AddRange(nd.ToArray<long>());
                        break;

                    case TF_DataType.TF_FLOAT:
                        tensor_proto.FloatVal.AddRange(nd.ToArray<float>());
                        break;

                    case TF_DataType.TF_DOUBLE:
                        tensor_proto.DoubleVal.AddRange(nd.ToArray<double>());
                        break;

                    default:
                        throw new NotImplementedException("make_tensor_proto Not Implemented");
                    }
                }
                else
                {
                    // Non-scalar NDArrays are serialized as their raw byte buffer.
                    tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(nd.ToByteArray());
                }
            }
            else if (dtype == TF_DataType.TF_STRING)
            {
                switch (values)
                {
                case string str:
                    tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
                    break;

                case byte[] byte_values:
                    // Raw bytes are stored verbatim rather than encoded as UTF-8 strings.
                    tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values);
                    break;

                case Array str_values when str_values.GetType().GetElementType() == typeof(string):
                    // Covers string[] as well as multi-dimensional string arrays (row-major order).
                    tensor_proto.StringVal.AddRange(str_values.Cast<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
                    break;

                default:
                    throw new NotImplementedException("make_tensor_proto Not Implemented");
                }
            }
            else if (values is Array array)
            {
                // Primitive arrays are copied as raw bytes; the byte count comes from dtype and the target shape.
                var len = dtype.get_datatype_size() * (int)shape.size;
                var bytes = new byte[len];
                System.Buffer.BlockCopy(array, 0, bytes, 0, len);
                tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
            }
            else
            {
                switch (values)
                {
                case Axis val:
                    tensor_proto.IntVal.AddRange(val.axis);
                    break;

                case Shape val:
                    tensor_proto.Int64Val.AddRange(val.dims);
                    break;

                case bool val:
                    tensor_proto.BoolVal.Add(val);
                    break;

                case sbyte val:
                    tensor_proto.IntVal.Add(val);
                    break;

                case byte val:
                    tensor_proto.IntVal.Add(val);
                    break;

                case short val:
                    tensor_proto.IntVal.Add(val);
                    break;

                case ushort val:
                    tensor_proto.IntVal.Add(val);
                    break;

                case int val:
                    tensor_proto.IntVal.Add(val);
                    break;

                case long val:
                    tensor_proto.Int64Val.Add(val);
                    break;

                case float val:
                    tensor_proto.FloatVal.Add(val);
                    break;

                case double val:
                    tensor_proto.DoubleVal.Add(val);
                    break;

                default:
                    throw new NotImplementedException("make_tensor_proto Not Implemented");
                }
            }

            return tensor_proto;
        }

        /// <summary>
        /// Returns a new array with the same dimensions as <paramref name="source"/> whose elements are the
        /// source elements converted to <paramref name="target_type"/> with <c>Convert.ChangeType</c>.
        /// </summary>
        /// <exception cref="InvalidCastException">An element cannot be converted to <paramref name="target_type"/>.</exception>
        /// <exception cref="OverflowException">An element is out of range for <paramref name="target_type"/>.</exception>
        private static Array convert_array_elements(Array source, Type target_type)
        {
            var lengths = new int[source.Rank];
            for (var dim = 0; dim < source.Rank; dim++)
            {
                lengths[dim] = source.GetLength(dim);
            }

            var result = Array.CreateInstance(target_type, lengths);
            var index = new int[source.Rank];
            // Array enumeration is row-major, so advance the index with the last dimension fastest.
            foreach (var item in source)
            {
                result.SetValue(Convert.ChangeType(item, target_type), index);
                for (var dim = source.Rank - 1; dim >= 0; dim--)
                {
                    if (++index[dim] < lengths[dim])
                    {
                        break;
                    }
                    index[dim] = 0;
                }
            }

            return result;
        }