Ejemplo n.º 1
0
        /// <summary>
        /// Builds a "Const" operation in the default graph that holds <paramref name="value"/>
        /// and returns the operation's output tensor.
        ///
        /// There is no eager-execution path in this overload: the constant is always added to the
        /// default graph, whether or not eager execution is enabled.
        /// </summary>
        /// <param name="value">The constant value; converted by <c>tensor_util.make_tensor_proto</c>.</param>
        /// <param name="dtype">Requested element type, forwarded to <c>make_tensor_proto</c>.</param>
        /// <param name="shape">Optional target dimensions, forwarded to <c>make_tensor_proto</c>.</param>
        /// <param name="name">Name for the created "Const" operation.</param>
        /// <param name="verify_shape">Forwarded to <c>make_tensor_proto</c>.</param>
        /// <param name="allow_broadcast">Forwarded to <c>make_tensor_proto</c>.</param>
        /// <returns>The first output of the created "Const" operation.</returns>
        public static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast)
        {
            Graph g = ops.get_default_graph();

            // The "value" attr carries the serialized constant itself.
            var tensor_value = new AttrValue
            {
                Tensor = tensor_util.make_tensor_proto(value,
                                                       dtype: dtype,
                                                       shape: shape,
                                                       verify_shape: verify_shape,
                                                       allow_broadcast: allow_broadcast)
            };

            // The "dtype" attr mirrors whatever dtype make_tensor_proto actually produced,
            // which is not necessarily the requested one.
            var dtype_value = new AttrValue
            {
                Type = tensor_value.Tensor.Dtype,
            };

            var attrs = new Dictionary<string, AttrValue>
            {
                ["value"] = tensor_value,
                ["dtype"] = dtype_value,
            };

            // Const takes no inputs and declares a single output of the proto's dtype.
            var op = g.create_op("Const",
                                 new Tensor[0],
                                 new TF_DataType[] { dtype_value.Type.as_tf_dtype() },
                                 attrs: attrs,
                                 name: name);

            return(op.outputs[0]);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Creates a graph-mode constant tensor from an <see cref="NDArray"/>.
        ///
        /// The values are serialized with <c>tensor_util.make_tensor_proto</c>. The element type
        /// of the resulting tensor is the dtype recorded in that tensor proto.
        /// </summary>
        /// <param name="nd">The array holding the constant's values.</param>
        /// <param name="name">Name for the created "Const" operation.</param>
        /// <param name="verify_shape">Boolean that enables verification of a shape of values;
        /// forwarded to <c>tensor_util.make_tensor_proto</c>.</param>
        /// <returns>The first output of the created "Const" operation.</returns>
        public static Tensor Constant(NDArray nd, string name = "Const", bool verify_shape = false)
        {
            Graph g         = ops.get_default_graph();
            var   tensor_pb = tensor_util.make_tensor_proto(nd, verify_shape);

            // AttrValue's payload is a protobuf oneof. Assigning Tensor replaces any previously
            // set Type, so only the tensor proto belongs in the "value" attr.
            var tensor_value = new AttrValue
            {
                Tensor = tensor_pb
            };

            // The "dtype" attr mirrors the dtype stored in the tensor proto.
            var dtype_value = new AttrValue
            {
                Type = tensor_value.Tensor.Dtype,
            };

            var attrs = new Dictionary<string, AttrValue>();

            attrs["value"] = tensor_value;
            attrs["dtype"] = dtype_value;

            // Const takes no inputs and declares a single output of the proto's dtype.
            var op = g.create_op("Const",
                                 null,
                                 new TF_DataType[] { (TF_DataType)dtype_value.Type },
                                 attrs: attrs,
                                 name: name);

            return(op.outputs[0]);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Creates a constant tensor holding <paramref name="value"/>.
        ///
        /// When eager execution is enabled, the value is converted to an eager tensor. If a target
        /// <paramref name="shape"/> is given, the tensor is then reshaped (same element count) or
        /// filled (single element) to match it. Otherwise a "Const" operation is added to the
        /// default graph.
        /// </summary>
        /// <param name="value">The constant value.</param>
        /// <param name="dtype">Requested element type of the result.</param>
        /// <param name="shape">Optional target shape; <c>null</c> keeps the value's own shape in eager mode.</param>
        /// <param name="name">Name for the created "Const" operation (graph mode only).</param>
        /// <param name="verify_shape">Boolean that enables verification of a shape of values.</param>
        /// <param name="allow_broadcast">Forwarded to <c>tensor_util.make_tensor_proto</c> in graph mode.</param>
        /// <returns>An eager tensor in eager mode, otherwise the first output of the new "Const" op.</returns>
        /// <exception cref="TypeError">
        /// In eager mode: when <paramref name="verify_shape"/> is set and the shapes differ, or when
        /// the value can be neither reshaped nor filled to <paramref name="shape"/>.
        /// </exception>
        /// <exception cref="NotImplementedException">
        /// In eager mode, when a single bool element would have to be filled out to <paramref name="shape"/>.
        /// </exception>
        public static Tensor _constant_impl(object value,
                                            TF_DataType dtype,
                                            TensorShape shape,
                                            string name,
                                            bool verify_shape,
                                            bool allow_broadcast)
        {
            if (tf.Context.executing_eagerly())
            {
                var t = convert_to_eager_tensor(value, tf.Context, dtype: dtype);
                if (shape == null)
                {
                    return(t);
                }

                if (t.shape.SequenceEqual(shape.dims))
                {
                    return(t);
                }

                if (verify_shape)
                {
                    // t.shape is an int[]; interpolating it directly would print "System.Int32[]".
                    throw new TypeError($"Expected Tensor's shape: {shape}, got ({string.Join(", ", t.shape)}).");
                }

                var num_t = t.TensorShape.num_elements();
                if (num_t == shape.num_elements())
                {
                    // Same element count: a reshape is sufficient.
                    return(_eager_reshape(t, shape, tf.Context));
                }

                if (num_t == 1)
                {
                    // A single element is expanded to the target shape with Fill.
                    if (t.dtype == dtypes.@bool)
                    {
                        throw new NotImplementedException("Eagerly filling a bool tensor to a target shape is not supported.");
                    }

                    return(_eager_fill(shape, t, tf.Context));
                }

                // Falling through would build a graph op while executing eagerly; reject instead,
                // matching TensorFlow's reference constant_op behaviour.
                throw new TypeError($"Eagerly reshaping tensor of shape ({string.Join(", ", t.shape)}) to shape {shape} is not supported.");
            }

            Graph g = ops.get_default_graph();

            // The "value" attr carries the serialized constant itself.
            var tensor_value = new AttrValue();

            tensor_value.Tensor = tensor_util.make_tensor_proto(value,
                                                                dtype: dtype,
                                                                shape: shape,
                                                                verify_shape: verify_shape,
                                                                allow_broadcast: allow_broadcast);

            // The "dtype" attr mirrors whatever dtype make_tensor_proto actually produced.
            var dtype_value = new AttrValue
            {
                Type = tensor_value.Tensor.Dtype,
            };

            var attrs = new Dictionary<string, AttrValue>();

            attrs["value"] = tensor_value;
            attrs["dtype"] = dtype_value;

            // Const takes no inputs and declares a single output of the proto's dtype.
            var op = g.create_op("Const",
                                 new Tensor[0],
                                 new TF_DataType[] { dtype_value.Type.as_tf_dtype() },
                                 attrs: attrs,
                                 name: name);

            return(op.outputs[0]);
        }