        /// <summary>
        /// Create a new variable handle, optionally copying in the handle data of `initial_value`.
        /// </summary>
        /// <param name="shape">Shape of the variable backing the handle.</param>
        /// <param name="dtype">Data type of the variable.</param>
        /// <param name="shared_name">Shared name identifying the resource in the runtime.</param>
        /// <param name="name">Name for the created handle op.</param>
        /// <param name="graph_mode">Whether the handle is being created in graph mode.</param>
        /// <param name="initial_value">Optional tensor whose handle data is combined into the new handle (graph mode only).</param>
        /// <returns>The resource handle tensor.</returns>
        public static Tensor variable_handle_from_shape_and_dtype(Shape shape, TF_DataType dtype,
                                                                  string shared_name, string name, bool graph_mode, Tensor initial_value = null)
        {
            var container = ops.get_default_graph().Container;
            var handle    = gen_resource_variable_ops.var_handle_op(shape: shape,
                                                                    dtype: dtype,
                                                                    shared_name: shared_name,
                                                                    name: name,
                                                                    container: container);

            if (initial_value == null)
            {
                initial_value = handle;
            }

            if (graph_mode)
            {
                var full_handle_data = _combine_handle_data(handle, initial_value);
                _set_handle_shapes_and_types(handle, full_handle_data, graph_mode);
                return handle;
            }
            else
            {
                // We do not want two distinct ResourceVariable objects for the same
                // underlying resource in the runtime.
                // When in eager mode, explicitly ensure so here. When in graph mode, it's
                // ensured by always generating different variable names.
                var exists = gen_resource_variable_ops.var_is_initialized_op(handle);

                // We create an assert Op instead of checking right away in order to be
                // compatible with ASYNC execution mode. Further, since not all devices
                // support string tensors, we encode the assertion string in the Op name

                /*gen_logging_ops.assert(gen_math_ops.logical_not(exists),
                 *  new[] { exists },
                 *  name: "EagerVariableNameReuse");*/

                var handle_data = new HandleData();
                handle_data.IsSet = true;
                handle_data.ShapeAndType.Add(new HandleShapeAndType
                {
                    Dtype = dtype.as_datatype_enum(),
                    Shape = shape.as_proto()
                });
                _set_handle_shapes_and_types(handle, handle_data, graph_mode);
                return handle;
            }
        }
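A rough usage sketch, for orientation only; the enclosing resource_variable_ops class and the Shape constructor shown here are assumptions, not part of the snippet above:

        // Hypothetical call in graph mode: create a 2x3 float variable handle.
        // Without an initial_value, the handle itself supplies the handle data.
        var handle = resource_variable_ops.variable_handle_from_shape_and_dtype(
            shape: new Shape(2, 3),
            dtype: TF_DataType.TF_FLOAT,
            shared_name: "my_var",
            name: "my_var",
            graph_mode: true);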
Example #2
        /// <summary>
        /// Create a TensorProto from the given values; invoked in graph mode.
        /// </summary>
        /// <param name="values">Source values: a TensorProto, NDArray, string(s), array or scalar.</param>
        /// <param name="dtype">Target data type; inferred from <paramref name="values"/> when left as DtInvalid.</param>
        /// <param name="shape">Target shape; inferred from <paramref name="values"/> when null.</param>
        /// <param name="verify_shape">Whether the shape of <paramref name="values"/> must match <paramref name="shape"/> exactly.</param>
        /// <param name="allow_broadcast">Whether <paramref name="values"/> may be broadcast to <paramref name="shape"/>; mutually exclusive with <paramref name="verify_shape"/>.</param>
        /// <returns>The populated TensorProto.</returns>
        public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, Shape? shape = null, bool verify_shape = false, bool allow_broadcast = false)
        {
            if (allow_broadcast && verify_shape)
            {
                throw new ValueError("allow_broadcast and verify_shape are not both allowed.");
            }
            if (values is TensorProto tp)
            {
                return tp;
            }

            var origin_dtype = values.GetDataType();

            if (dtype == TF_DataType.DtInvalid)
            {
                dtype = origin_dtype;
            }
            else if (origin_dtype != dtype)
            {
                // Convert the incoming values to the requested dtype before serializing.
                var new_system_dtype = dtype.as_system_dtype();
                if (values is long[] long_values)
                {
                    if (dtype == TF_DataType.TF_INT32)
                    {
                        values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
                    }
                }
                else
                {
                    values = Convert.ChangeType(values, new_system_dtype);
                }

                dtype = values.GetDataType();
            }

            shape = shape ?? values.GetShape();
            var tensor_proto = new TensorProto
            {
                Dtype       = dtype.as_datatype_enum(),
                TensorShape = shape.as_shape_proto()
            };

            if (values is NDArray nd)
            {
                // Scalar NDArray: store the value in the typed repeated field for its dtype.
                if (nd.shape.IsScalar)
                {
                    switch (nd.dtype)
                    {
                    case TF_DataType.TF_BOOL:
                        tensor_proto.BoolVal.AddRange(nd.ToArray<bool>());
                        break;

                    case TF_DataType.TF_UINT8:
                        tensor_proto.IntVal.AddRange(nd.ToArray<byte>().Select(x => (int)x).ToArray());
                        break;

                    case TF_DataType.TF_INT32:
                        tensor_proto.IntVal.AddRange(nd.ToArray<int>());
                        break;

                    case TF_DataType.TF_INT64:
                        tensor_proto.Int64Val.AddRange(nd.ToArray<long>());
                        break;

                    case TF_DataType.TF_FLOAT:
                        tensor_proto.FloatVal.AddRange(nd.ToArray<float>());
                        break;

                    case TF_DataType.TF_DOUBLE:
                        tensor_proto.DoubleVal.AddRange(nd.ToArray<double>());
                        break;

                    default:
                        throw new Exception("make_tensor_proto Not Implemented");
                    }
                }
                else
                {
                    // Non-scalar NDArray: serialize the raw buffer directly into TensorContent.
                    byte[] bytes = nd.ToByteArray();
                    tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
                }
            }
            else if (dtype == TF_DataType.TF_STRING)
            {
                if (values is string str)
                {
                    tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
                }
                else if (values is string[] str_values)
                {
                    tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
                }
                else if (values is byte[] byte_values)
                {
                    tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values);
                }
            }
            else if (values is Array array)
            {
                // Primitive array: copy its raw bytes into TensorContent.
                var    len   = dtype.get_datatype_size() * (int)shape.size;
                byte[] bytes = new byte[len];
                System.Buffer.BlockCopy(array, 0, bytes, 0, len);
                tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
            }
            else
            {
                switch (values)
                {
                case Axis val:
                    tensor_proto.IntVal.AddRange(val.axis);
                    break;

                case Shape val:
                    tensor_proto.Int64Val.AddRange(val.dims);
                    break;

                case bool val:
                    tensor_proto.BoolVal.AddRange(new[] { val });
                    break;

                case sbyte val:
                    tensor_proto.IntVal.AddRange(new[] { (int)val });
                    break;

                case int val:
                    tensor_proto.IntVal.AddRange(new[] { val });
                    break;

                case long val:
                    tensor_proto.Int64Val.AddRange(new[] { val });
                    break;

                case float val:
                    tensor_proto.FloatVal.AddRange(new[] { val });
                    break;

                case double val:
                    tensor_proto.DoubleVal.AddRange(new[] { val });
                    break;

                default:
                    throw new Exception("make_tensor_proto Not Implemented");
                }
            }

            return tensor_proto;
        }
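A minimal sketch of how this helper behaves for a plain array, assuming it is called as shown; dtype and shape are inferred when omitted:

        // Float array input takes the `values is Array` path above:
        // dtype is inferred as TF_FLOAT, shape as (3), and the 12 raw bytes
        // of the three floats land in TensorContent rather than FloatVal.
        var proto = make_tensor_proto(new float[] { 1f, 2f, 3f });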
Example #3
 /// <summary>
 /// Two data types are compatible when they map to the same underlying DataType enum value.
 /// </summary>
 public static bool is_compatible_with(this TF_DataType self, TF_DataType other)
 {
     return self.as_datatype_enum() == other.as_datatype_enum();
 }
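An illustrative check of what the comparison implies:

     // Same underlying enum value -> compatible; different element types -> not.
     bool same = TF_DataType.TF_FLOAT.is_compatible_with(TF_DataType.TF_FLOAT);   // true
     bool diff = TF_DataType.TF_FLOAT.is_compatible_with(TF_DataType.TF_DOUBLE);  // false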
Example #4
 /// <summary>
 /// Convert a TF_DataType attribute value to its DataType proto enum for the given attr definition.
 /// </summary>
 public DataType _MakeType(TF_DataType v, AttrDef attr_def)
 {
     return v.as_datatype_enum();
 }