示例#1
0
        /// <summary>
        /// Creates a handle to a Variable resource by running or recording a "VarHandleOp" op.
        /// </summary>
        /// <param name="dtype">Element type, forwarded as the op's "dtype" attribute.</param>
        /// <param name="shape">
        /// Variable shape, forwarded as the op's "shape" attribute
        /// (eager mode sends <c>shape.dims</c>, graph mode sends the shape object itself).
        /// </param>
        /// <param name="container">Forwarded as the "container" attribute; defaults to "".</param>
        /// <param name="shared_name">Forwarded as the "shared_name" attribute; defaults to "".</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>The resource-handle tensor produced by the op.</returns>
        public static Tensor var_handle_op(TF_DataType dtype, TensorShape shape,
                                           string container = "", string shared_name = "", string name = null)
        {
            if (!tf.context.executing_eagerly())
            {
                // Graph mode: let the op-def library add the node and hand back its output.
                return _op_def_lib._apply_op_helper("VarHandleOp", name, new
                {
                    dtype,
                    shape,
                    container,
                    shared_name
                }).output;
            }

            // Eager mode: execute immediately through the fast path.
            var outputs = EagerTensorPass.Create();
            var opAttrs = new object[]
            {
                "container", container,
                "shared_name", shared_name,
                "dtype", dtype,
                "shape", shape.dims
            };

            // The op takes no tensor inputs, hence the null/0 input arguments.
            Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                      "VarHandleOp", name, null, 0,
                                                      wrap_tfe_src.SetOpAttrs2(opAttrs),
                                                      op => wrap_tfe_src.SetOpAttrs(op, opAttrs),
                                                      outputs.Points, outputs.Length);
            status.Check(true);

            return outputs[0].Resolve();
        }
        /// <summary>
        /// Builds a "OneHot" op from <paramref name="indices"/>, <paramref name="depth"/>,
        /// <paramref name="on_value"/> and <paramref name="off_value"/>.
        /// </summary>
        /// <param name="indices">Indices tensor passed as the op's first input.</param>
        /// <param name="depth">Depth tensor passed as the op's second input.</param>
        /// <param name="on_value">Value input for "on" positions; forwarded as-is (null is not substituted).</param>
        /// <param name="off_value">Value input for "off" positions; forwarded as-is (null is not substituted).</param>
        /// <param name="dtype">
        /// NOTE(review): currently not forwarded to the op in either execution mode;
        /// the output type is whatever the op derives from its inputs.
        /// </param>
        /// <param name="axis">Forwarded as the "axis" attribute; defaults to -1.</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>The first output of the OneHot op.</returns>
        public static Tensor one_hot(Tensor indices, Tensor depth,
                                     Tensor on_value   = null,
                                     Tensor off_value  = null,
                                     TF_DataType dtype = TF_DataType.DtInvalid,
                                     int axis          = -1,
                                     string name       = null)
        {
            if (!tf.context.executing_eagerly())
            {
                var graphOp = _op_def_lib._apply_op_helper("OneHot", name, new { indices, depth, on_value, off_value, axis });
                return graphOp.outputs[0];
            }

            var outputs  = EagerTensorPass.Create();
            var operands = EagerTensorPass.From(indices, depth, on_value, off_value);
            var opAttrs  = new object[] { "axis", axis };

            Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                      "OneHot", name,
                                                      operands.Points, operands.Length,
                                                      wrap_tfe_src.SetOpAttrs2(opAttrs),
                                                      op => wrap_tfe_src.SetOpAttrs(op, opAttrs),
                                                      outputs.Points, outputs.Length);
            status.Check(true);

            return outputs[0].Resolve();
        }
示例#3
0
        /// <summary>
        /// Eagerly executes the "Mul" op on two raw tensor handles.
        /// </summary>
        /// <param name="x">Handle of the first operand.</param>
        /// <param name="y">Handle of the second operand.</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>The op's first output, resolved to an eager tensor.</returns>
        public static EagerTensor mul(IntPtr x, IntPtr y, string name = null)
        {
            var outputs  = EagerTensorPass.Create();
            var operands = new IntPtr[] { x, y };

            // No attributes are set explicitly, hence the two null callbacks.
            Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                      "Mul", name,
                                                      operands, operands.Length,
                                                      null, null,
                                                      outputs.Points, outputs.Length);
            status.Check(true);

            return outputs[0].Resolve();
        }
示例#4
0
        /// <summary>
        /// Adds a value to the current value of a variable.
        /// </summary>
        /// <param name="resource">Handle tensor of the resource variable to update.</param>
        /// <param name="value">Tensor whose value is added to the variable.</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>
        /// <c>null</c> in eager mode, where the op runs immediately and no outputs are requested;
        /// in graph mode, the created "AssignAddVariableOp" <see cref="Operation"/>.
        /// </returns>
        public static Operation assign_add_variable_op(Tensor resource, Tensor value, string name = null)
        {
            if (tf.context.executing_eagerly())
            {
                var    inputs = EagerTensorPass.From(resource, value);
                Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                          "AssignAddVariableOp", name,
                                                          inputs.Points, inputs.Length,
                                                          null, null,
                                                          null, 0);
                status.Check(true);
                return null;
            }

            // Graph mode: record the node so the update actually happens when the graph runs.
            // Returning null here without building the op would silently drop the assign-add.
            var _op = _op_def_lib._apply_op_helper("AssignAddVariableOp", name, new { resource, value });

            return _op;
        }
        /// <summary>
        /// Evaluates the "Less" op on <paramref name="x"/> and <paramref name="y"/>
        /// and wraps the resulting tensor in a new variable via <c>tf.Variable</c>.
        /// </summary>
        /// <typeparam name="Tx">Type of the left operand.</typeparam>
        /// <typeparam name="Ty">Type of the right operand.</typeparam>
        /// <param name="x">Left operand of the x &lt; y comparison.</param>
        /// <param name="y">Right operand of the x &lt; y comparison.</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>A variable initialised from the op's first output.</returns>
        private static ResourceVariable less <Tx, Ty>(Tx x, Ty y, string name = null)
        {
            if (!tf.context.executing_eagerly())
            {
                var lessOp = _op_def_lib._apply_op_helper("Less", name: name, args: new { x, y });
                return tf.Variable(lessOp.outputs[0]);
            }

            var outputs  = EagerTensorPass.Create();
            var operands = EagerTensorPass.From(x, y);

            Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                      "Less", name,
                                                      operands.Points, operands.Length,
                                                      null, null,
                                                      outputs.Points, outputs.Length);
            status.Check(true);

            return tf.Variable(outputs[0].Resolve());
        }
        /// <summary>
        /// Broadcast an array for a compatible shape, using the "BroadcastTo" op.
        /// </summary>
        /// <typeparam name="T">Type of the <paramref name="shape"/> argument; converted to an op input.</typeparam>
        /// <param name="input">Tensor to broadcast.</param>
        /// <param name="shape">Target shape, passed as the op's second input.</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>The first output of the BroadcastTo op.</returns>
        public static Tensor broadcast_to <T>(Tensor input, T shape, string name = null)
        {
            if (tf.context.executing_eagerly())
            {
                var results = EagerTensorPass.Create();
                var inputs  = EagerTensorPass.From(input, shape);

                Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                          "BroadcastTo", name,
                                                          inputs.Points, inputs.Length,
                                                          null, null,
                                                          results.Points, results.Length);
                status.Check(true);
                return(results[0].Resolve());
            }

            // Only the op's inputs go in `args`. The operation name is already passed as the
            // second argument, and BroadcastTo has no "name" input or attribute.
            var _op = _op_def_lib._apply_op_helper("BroadcastTo", name, args: new { input, shape });

            return(_op.outputs[0]);
        }
        /// <summary>
        /// Chooses elements from <paramref name="t"/> or <paramref name="e"/> according to
        /// <paramref name="condition"/>, using the "Select" op in both eager and graph mode.
        /// </summary>
        /// <typeparam name="Tx">Type of the "then" operand.</typeparam>
        /// <typeparam name="Ty">Type of the "else" operand.</typeparam>
        /// <param name="condition">Condition tensor (the Select op expects a bool tensor).</param>
        /// <param name="t">Values taken where the condition holds.</param>
        /// <param name="e">Values taken where the condition does not hold.</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>The first output of the Select op.</returns>
        public static Tensor select <Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null)
        {
            // Both execution modes must dispatch the same op. "Select" and "SelectV2" broadcast
            // differently, so using SelectV2 in eager mode made results depend on the mode.
            const string opType = "Select";

            if (tf.context.executing_eagerly())
            {
                var results = EagerTensorPass.Create();
                var inputs  = EagerTensorPass.From(condition, t, e);

                Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                          opType, name,
                                                          inputs.Points, inputs.Length,
                                                          null, null,
                                                          results.Points, results.Length);
                status.Check(true);

                return(results[0].Resolve());
            }

            var _op = _op_def_lib._apply_op_helper(opType, name, new { condition, t, e });

            return(_op.outputs[0]);
        }
        /// <summary>
        /// Creates a tensor filled with a scalar value, using the "Fill" op.
        /// </summary>
        /// <typeparam name="T">Type of <paramref name="value"/>; converted to an op input.</typeparam>
        /// <param name="dims">A `Tensor` giving the dimensions of the output (the Fill op's first input).</param>
        /// <param name="value">A `Tensor`. 0-D (scalar). Value to fill the returned tensor.</param>
        /// <param name="name">A name for the operation (optional).</param>
        /// <returns>A `Tensor`. Has the same type as `value`.</returns>
        public static Tensor fill <T>(Tensor dims, T value, string name = null)
        {
            if (!tf.context.executing_eagerly())
            {
                return _op_def_lib._apply_op_helper("Fill", name, new { dims, value }).output;
            }

            var outputs  = EagerTensorPass.Create();
            var operands = EagerTensorPass.From(dims, value);

            Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                      "Fill", name,
                                                      operands.Points, operands.Length,
                                                      null, null,
                                                      outputs.Points, outputs.Length);
            status.Check(true);

            return outputs[0].Resolve();
        }
        /// <summary>
        /// Outputs random values from a truncated normal distribution, using the "TruncatedNormal" op.
        /// </summary>
        /// <param name="shape">Shape input of the op (presumably a 1-D integer tensor; TODO confirm at callers).</param>
        /// <param name="dtype">Forwarded as the op's "dtype" attribute.</param>
        /// <param name="seed">Forwarded as the "seed" attribute; null is treated as 0.</param>
        /// <param name="seed2">Forwarded as the "seed2" attribute; null is treated as 0.</param>
        /// <param name="name">Optional operation name.</param>
        /// <returns>The op's output tensor.</returns>
        public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int?seed = 0,
                                              int?seed2 = 0, string name = null)
        {
            // Normalise missing seeds to 0 so both execution paths always receive a concrete value.
            seed  = seed ?? 0;
            seed2 = seed2 ?? 0;

            if (!tf.context.executing_eagerly())
            {
                var graphOp = _op_def_lib._apply_op_helper("TruncatedNormal",
                                                           name: name,
                                                           args: new { shape, dtype, seed, seed2 });
                return graphOp.output;
            }

            var outputs  = EagerTensorPass.Create();
            var operands = EagerTensorPass.From(shape);
            var opAttrs  = new object[]
            {
                "seed", seed,
                "seed2", seed2,
                "dtype", dtype
            };

            Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
                                                      "TruncatedNormal", name,
                                                      operands.Points, operands.Length,
                                                      wrap_tfe_src.SetOpAttrs2(opAttrs),
                                                      op => wrap_tfe_src.SetOpAttrs(op, opAttrs),
                                                      outputs.Points, outputs.Length);
            status.Check(true);

            return outputs[0].Resolve();
        }