/// <summary>
/// Create a resource variable handle via <c>var_handle_op</c> and attach shape/dtype
/// handle data to it.
/// </summary>
/// <param name="shape">Variable shape; recorded directly in the handle data in eager mode.</param>
/// <param name="dtype">Variable element dtype; recorded directly in the handle data in eager mode.</param>
/// <param name="shared_name">Shared name forwarded to <c>var_handle_op</c>.</param>
/// <param name="name">Op name forwarded to <c>var_handle_op</c>.</param>
/// <param name="graph_mode">
/// When true, handle data is combined from <paramref name="initial_value"/> (or from the new
/// handle itself when no initial value is given). When false, handle data is built from
/// <paramref name="shape"/> and <paramref name="dtype"/>.
/// </param>
/// <param name="initial_value">Optional tensor whose handle data is combined in graph mode.</param>
/// <returns>The newly created handle tensor.</returns>
public static Tensor variable_handle_from_shape_and_dtype(Shape shape, TF_DataType dtype, string shared_name, string name, bool graph_mode, Tensor initial_value = null)
{
    // The handle is created in the container configured on the current default graph.
    var graphContainer = ops.get_default_graph().Container;
    var resourceHandle = gen_resource_variable_ops.var_handle_op(
        shape: shape,
        dtype: dtype,
        shared_name: shared_name,
        name: name,
        container: graphContainer);

    if (graph_mode)
    {
        // Graph mode: uniqueness of the underlying resource is ensured by distinct variable
        // names. Combine handle data from the initial value, falling back to the handle itself.
        var combinedData = _combine_handle_data(resourceHandle, initial_value ?? resourceHandle);
        _set_handle_shapes_and_types(resourceHandle, combinedData, graph_mode);
        return resourceHandle;
    }

    // Eager mode: the intent is to prevent two distinct ResourceVariable objects from sharing
    // one runtime resource by asserting the variable is not yet initialized. That assertion
    // (an Op named "EagerVariableNameReuse", chosen for ASYNC-execution compatibility) is
    // currently disabled, so the result of var_is_initialized_op is discarded. The op itself
    // is still created.
    _ = gen_resource_variable_ops.var_is_initialized_op(resourceHandle);

    var eagerData = new HandleData { IsSet = true };
    eagerData.ShapeAndType.Add(new HandleShapeAndType
    {
        Dtype = dtype.as_datatype_enum(),
        Shape = shape.as_proto()
    });
    _set_handle_shapes_and_types(resourceHandle, eagerData, graph_mode);
    return resourceHandle;
}
/// <summary>
/// Create a TensorProto from a managed value, invoked in graph mode.
/// </summary>
/// <param name="values">
/// Value to encode. This may be an existing <c>TensorProto</c> (returned unchanged) or an <c>NDArray</c>.
/// When <paramref name="dtype"/> is <c>TF_STRING</c>, it may be a string, string[] or byte[].
/// Otherwise it may be a primitive <see cref="Array"/>, an <c>Axis</c>, a <c>Shape</c>, or a
/// bool/sbyte/int/long/float/double scalar.
/// </param>
/// <param name="dtype">
/// Target dtype. <c>DtInvalid</c> means "use the dtype inferred from <paramref name="values"/>".
/// When it differs from the inferred dtype, <paramref name="values"/> is converted first.
/// </param>
/// <param name="shape">Target shape; when null, the shape inferred from <paramref name="values"/> is used.</param>
/// <param name="verify_shape">Only checked for mutual exclusion with <paramref name="allow_broadcast"/>.</param>
/// <param name="allow_broadcast">Only checked for mutual exclusion with <paramref name="verify_shape"/>.</param>
/// <returns>The encoded <c>TensorProto</c>.</returns>
/// <exception cref="ValueError">Both <paramref name="verify_shape"/> and <paramref name="allow_broadcast"/> are true.</exception>
/// <exception cref="NotImplementedException">The scalar dtype or value type is not supported.</exception>
public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, Shape? shape = null, bool verify_shape = false, bool allow_broadcast = false)
{
    if (allow_broadcast && verify_shape)
    {
        throw new ValueError("allow_broadcast and verify_shape are not both allowed.");
    }

    if (values is TensorProto tp)
    {
        return tp;
    }

    var origin_dtype = values.GetDataType();
    if (dtype == TF_DataType.DtInvalid)
    {
        dtype = origin_dtype;
    }
    else if (origin_dtype != dtype)
    {
        var new_system_dtype = dtype.as_system_dtype();
        if (values is long[] long_values)
        {
            // long[] is converted element-wise. Only an int32 target is handled; for any other
            // target the array is left unchanged and dtype is re-derived from it below.
            if (dtype == TF_DataType.TF_INT32)
            {
                values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
            }
        }
        else
        {
            values = Convert.ChangeType(values, new_system_dtype);
        }

        // The proto's dtype always reflects the (possibly converted) value.
        dtype = values.GetDataType();
    }

    shape = shape ?? values.GetShape();
    var tensor_proto = new TensorProto
    {
        Dtype = dtype.as_datatype_enum(),
        TensorShape = shape.as_shape_proto()
    };

    if (values is NDArray nd)
    {
        if (nd.shape.IsScalar)
        {
            // Scalars are stored in the typed repeated *_val field selected by the array's dtype.
            switch (nd.dtype)
            {
                case TF_DataType.TF_BOOL:
                    tensor_proto.BoolVal.AddRange(nd.ToArray<bool>());
                    break;
                case TF_DataType.TF_UINT8:
                    // uint8 values are widened to int and stored in int_val.
                    tensor_proto.IntVal.AddRange(nd.ToArray<byte>().Select(x => (int)x));
                    break;
                case TF_DataType.TF_INT32:
                    tensor_proto.IntVal.AddRange(nd.ToArray<int>());
                    break;
                case TF_DataType.TF_INT64:
                    tensor_proto.Int64Val.AddRange(nd.ToArray<long>());
                    break;
                case TF_DataType.TF_FLOAT:
                    tensor_proto.FloatVal.AddRange(nd.ToArray<float>());
                    break;
                case TF_DataType.TF_DOUBLE:
                    tensor_proto.DoubleVal.AddRange(nd.ToArray<double>());
                    break;
                default:
                    throw new NotImplementedException("make_tensor_proto Not Implemented");
            }
        }
        else
        {
            // Non-scalar arrays are serialized as their raw bytes into tensor_content.
            tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(nd.ToByteArray());
        }
    }
    else if (dtype == TF_DataType.TF_STRING)
    {
        // NDArray inputs were handled above, so only managed string/byte inputs reach here.
        if (values is string str)
        {
            tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
        }
        else if (values is string[] str_values)
        {
            tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
        }
        else if (values is byte[] byte_values)
        {
            tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values);
        }
        // NOTE(review): any other value type with TF_STRING falls through and yields a proto
        // with no string payload. Confirm whether that should be an error instead.
    }
    else if (values is Array array)
    {
        // Primitive arrays are copied as raw bytes. The byte count comes from dtype and shape,
        // which are assumed to agree with the array's element type and length.
        var len = dtype.get_datatype_size() * (int)shape.size;
        var bytes = new byte[len];
        System.Buffer.BlockCopy(array, 0, bytes, 0, len);
        tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
    }
    else
    {
        // Remaining supported inputs are Axis, Shape and single primitive values.
        switch (values)
        {
            case Axis val:
                tensor_proto.IntVal.AddRange(val.axis);
                break;
            case Shape val:
                tensor_proto.Int64Val.AddRange(val.dims);
                break;
            case bool val:
                tensor_proto.BoolVal.Add(val);
                break;
            case sbyte val:
                tensor_proto.IntVal.Add((int)val);
                break;
            case int val:
                tensor_proto.IntVal.Add(val);
                break;
            case long val:
                tensor_proto.Int64Val.Add(val);
                break;
            case float val:
                tensor_proto.FloatVal.Add(val);
                break;
            case double val:
                tensor_proto.DoubleVal.Add(val);
                break;
            default:
                throw new NotImplementedException("make_tensor_proto Not Implemented");
        }
    }

    return tensor_proto;
}
/// <summary>
/// Returns true when both dtypes map to the same <c>DataType</c> enum value,
/// as produced by <c>as_datatype_enum</c>.
/// </summary>
/// <param name="self">The dtype being checked.</param>
/// <param name="other">The dtype to compare against.</param>
/// <returns>True if the two enum values are equal.</returns>
public static bool is_compatible_with(this TF_DataType self, TF_DataType other)
    => self.as_datatype_enum() == other.as_datatype_enum();
/// <summary>
/// Converts <paramref name="v"/> to its <c>DataType</c> enum value.
/// </summary>
/// <param name="v">The dtype to convert.</param>
/// <param name="attr_def">Not consulted by this implementation.</param>
/// <returns>The result of <c>v.as_datatype_enum()</c>.</returns>
public DataType _MakeType(TF_DataType v, AttrDef attr_def)
    => v.as_datatype_enum();