Ejemplo n.º 1
0
        /// <summary>
        /// Builds a conditional in the graph. Each of <paramref name="true_fn"/> and
        /// <paramref name="false_fn"/> is built once inside its own <c>CondContext</c>,
        /// and the outputs of the two branches are merged pairwise.
        /// </summary>
        /// <typeparam name="T">Element type of the arrays returned by the branch callables.</typeparam>
        /// <param name="pred">Predicate tensor; it is passed to the switch as both arguments.</param>
        /// <param name="true_fn">Callable built under the <c>switch_t</c> pivot (branch 1).</param>
        /// <param name="false_fn">Callable built under the <c>switch_f</c> pivot (branch 0).</param>
        /// <param name="strict">Accepted for signature compatibility; this overload does not read it.</param>
        /// <param name="name">Name handed to <c>ops.name_scope</c> together with "cond".</param>
        /// <returns>
        /// The per-output merges of the false and true branch results, after they have been
        /// passed through <c>_convert_flows_to_tensorarrays</c>.
        /// </returns>
        public static Tensor[] cond<T>(Tensor pred,
                                       Func<T[]> true_fn  = null,
                                       Func<T[]> false_fn = null,
                                       bool strict        = false,
                                       string name        = null)
        {
            return with(ops.name_scope(name, "cond", new { pred }), delegate
            {
                // Add the Switch to the graph.
                var (p_2, p_1) = @switch(pred, pred);
                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
                var pivot_2 = array_ops.identity(p_2, name: "switch_f");
                pred = array_ops.identity(pred, name: "pred_id");

                // Disable the fetching of tensors that are only on one branch of cond.
                foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
                {
                    tensor.op.graph.prevent_fetching(tensor.op);
                }

                // Build the graph for the true branch in a new context. The helper
                // guarantees Exit() runs even when the branch callable throws.
                var context_t = new CondContext(pred, pivot_1, branch: 1);
                var (orig_res_t, res_t) = _build_branch_in_context(
                    context_t, () => context_t.BuildCondBranch(true_fn));

                // Build the graph for the false branch in a new context. Only the
                // flattened result is needed from this branch.
                var context_f = new CondContext(pred, pivot_2, branch: 0);
                var (_, res_f) = _build_branch_in_context(
                    context_f, () => context_f.BuildCondBranch(false_fn));

                // Merge order is (false, true), matching the switch outputs (p_2, p_1).
                var merges = zip(res_f, res_t)
                             .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
                             .ToArray();

                merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

                ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
                ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);

                return merges;
            });
        }

        /// <summary>
        /// Enters <paramref name="context"/>, runs <paramref name="build"/> and always
        /// exits the context afterwards, whether or not <paramref name="build"/> throws.
        /// </summary>
        /// <typeparam name="TResult">Whatever <paramref name="build"/> returns.</typeparam>
        /// <param name="context">The conditional context to enter and exit.</param>
        /// <param name="build">Callback that builds the branch while the context is entered.</param>
        /// <returns>The value produced by <paramref name="build"/>.</returns>
        private static TResult _build_branch_in_context<TResult>(CondContext context, Func<TResult> build)
        {
            try
            {
                context.Enter();
                return build();
            }
            finally
            {
                context.Exit();
            }
        }
        /// <summary>
        /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
        ///
        /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
        /// `false_fn` must have the same non-zero number and type of outputs.
        ///
        /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
        /// `false_fn` will be executed regardless of which branch is selected at runtime.
        ///
        /// Although this behavior is consistent with the dataflow model of TensorFlow,
        /// it has frequently surprised users who expected a lazier semantics.
        /// Consider the following simple program:
        ///
        /// z = tf.multiply(a, b)
        /// result = tf.cond(x &lt; y, () => tf.add(x, z), () => tf.square(y))
        ///
        /// If `x &lt; y`, the `tf.add` operation will be executed and `tf.square`
        /// operation will not be executed. Since `z` is needed for at least one
        /// branch of the `cond`, the `tf.multiply` operation is always executed,
        /// unconditionally.
        ///
        /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
        /// call to `cond`, and not at all during `Session.run()`). `cond`
        /// stitches together the graph fragments created during the `true_fn` and
        /// `false_fn` calls with some additional graph nodes to ensure that the right
        /// branch gets executed depending on the value of `pred`.
        ///
        /// `tf.cond` supports nested structures as implemented in
        /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
        /// same (possibly nested) value structure of lists, tuples, and/or named tuples.
        /// Singleton lists and tuples form the only exceptions to this: when returned by
        /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
        /// This behavior is disabled by passing `strict=True`.
        /// </summary>
        /// <remarks>
        /// The summary above describes the original Python API. This port is not yet
        /// complete: the eager-execution shortcut, the singleton unpacking controlled by
        /// <paramref name="strict"/>, and the structure check between the two branch
        /// results are still TODO (see the comments in the body). As a consequence
        /// <paramref name="strict"/> currently has no effect.
        /// </remarks>
        /// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
        /// `false_fn`.</param>
        /// <param name="true_fn">The callable to be performed if pred is true.</param>
        /// <param name="false_fn">The callable to be performed if pred is false.</param>
        /// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.
        /// Not read by this implementation.</param>
        /// <param name="name">Optional name prefix for the returned tensors.</param>
        /// <returns>The merge of the single output of the false branch and the single output
        /// of the true branch.</returns>
        public static Tensor cond(Tensor pred,
                                  Func<ITensorOrOperation> true_fn  = null,
                                  Func<ITensorOrOperation> false_fn = null,
                                  bool strict = false,
                                  string name = null)
        {
            return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
            {
                // TODO: here a chunk of original code is missing

                /*
                 * with ops.name_scope(name, "cond", [pred]):
                 *  if context.executing_eagerly():
                 *    if pred:
                 *      return _UnpackIfSingleton(true_fn())
                 *    return _UnpackIfSingleton(false_fn())
                 */

                // Add the Switch to the graph. Output 0 feeds the false branch and
                // output 1 feeds the true branch.
                var switch_result = @switch(pred, pred);
                var p_2 = switch_result[0];
                var p_1 = switch_result[1];
                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
                var pivot_2 = array_ops.identity(p_2, name: "switch_f");
                pred = array_ops.identity(pred, name: "pred_id");

                // Disable the fetching of tensors that are only on one branch of cond.
                foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
                {
                    tensor.op.graph.prevent_fetching(tensor.op);
                }

                // Build the graph for the true branch in a new context.
                var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
                ITensorOrOperation orig_res_t;
                Tensor res_t;
                try
                {
                    context_t.Enter();
                    (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
                }
                finally
                {
                    // Always leave the context, even if the branch callable throws.
                    context_t.Exit();
                }

                // Build the graph for the false branch in a new context.
                var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
                ITensorOrOperation orig_res_f;
                Tensor res_f;
                try
                {
                    context_f.Enter();
                    (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
                }
                finally
                {
                    context_f.Exit();
                }

                //TODO: missing original code
                //if not strict:
                //  orig_res_t = _UnpackIfSingleton(orig_res_t)
                //  orig_res_f = _UnpackIfSingleton(orig_res_f)

                /*
                 # Check that the return values of the two branches have the same structure.
                 # try:
                 #  nest.assert_same_structure(orig_res_t, orig_res_f)
                 # except TypeError as e:
                 #  raise TypeError(
                 #      "Incompatible return types of true_fn and false_fn: {}".format(e))
                 # except ValueError as e:
                 #  raise ValueError(
                 #      "Incompatible return values of true_fn and false_fn: {}".format(e))*/

                var res_t_flat = new Tensor[] { res_t };
                var res_f_flat = new Tensor[] { res_f };

                // NOTE(review): an empty loop over the (true, false) output pairs used to sit
                // here - presumably a placeholder for a per-pair compatibility check that
                // has not been ported yet. It did nothing, so it has been removed.

                // Merge order is (false, true), matching the switch outputs (p_2, p_1).
                var merges = zip(res_f_flat, res_t_flat)
                             .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
                             .ToArray();

                // Only a Tensor result can be handed to _convert_flows_to_tensorarrays. A hard
                // cast here threw InvalidCastException whenever the true branch produced an
                // ITensorOrOperation that is not a Tensor, so the conversion is skipped then.
                if (orig_res_t is Tensor orig_res_tensor)
                {
                    merges = _convert_flows_to_tensorarrays(new Tensor[] { orig_res_tensor }, merges);
                }

                ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
                ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);

                // Exactly one pair was merged above, so there is exactly one result.
                return merges[0];
            }));
Ejemplo n.º 3
0
        /// <summary>
        /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
        ///
        /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
        /// `false_fn` must have the same non-zero number and type of outputs.
        ///
        /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
        /// `false_fn` will be executed regardless of which branch is selected at runtime.
        ///
        /// Although this behavior is consistent with the dataflow model of TensorFlow,
        /// it has frequently surprised users who expected a lazier semantics.
        /// Consider the following simple program:
        ///
        /// z = tf.multiply(a, b)
        /// result = tf.cond(x &lt; y, ()=> tf.add(x, z), ()=> tf.square(y))
        ///
        /// If `x &lt; y`, the `tf.add` operation will be executed and `tf.square`
        /// operation will not be executed. Since `z` is needed for at least one
        /// branch of the `cond`, the `tf.multiply` operation is always executed,
        /// unconditionally.
        ///
        /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
        /// call to `cond`, and not at all during `Session.run()`). `cond`
        /// stitches together the graph fragments created during the `true_fn` and
        /// `false_fn` calls with some additional graph nodes to ensure that the right
        /// branch gets executed depending on the value of `pred`.
        ///
        /// `tf.cond` supports nested structures as implemented in
        /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
        /// same (possibly nested) value structure of lists, tuples, and/or named tuples.
        /// Singleton lists and tuples form the only exceptions to this: when returned by
        /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
        /// This behavior is disabled by passing `strict=True`.
        /// </summary>
        /// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
        /// `false_fn`.</param>
        /// <param name="true_fn">The callable to be performed if pred is true.</param>
        /// <param name="false_fn">The callable to be performed if pred is false.</param>
        /// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
        /// <param name="name">Optional name prefix for the returned tensors.</param>
        /// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
        /// callables return a singleton list, the element is extracted from the list.</returns>
        public static Tensor cond(Tensor pred,
                                  Func <ITensorOrOperation> true_fn  = null,
                                  Func <ITensorOrOperation> false_fn = null,
                                  bool strict = false,
                                  string name = null)
        {
            return(tf_with(ops.name_scope(name, "cond", new { pred }), delegate
            {
                // Add the Switch to the graph.
                var switch_result = @switch(pred, pred);
                var(p_2, p_1) = (switch_result[0], switch_result[1]);
                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
                var pivot_2 = array_ops.identity(p_2, name: "switch_f");
                pred = array_ops.identity(pred, name: "pred_id");

                // Disable the fetching of tensors that are only on one branch of cond.
                foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
                {
                    tensor.op.graph.prevent_fetching(tensor.op);
                }

                // Build the graph for the true branch in a new context.
                var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
                ITensorOrOperation orig_res_t;
                Tensor res_t;
                try
                {
                    context_t.Enter();
                    (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
                    context_t.ExitResult(new[] { res_t });
                }
                finally
                {
                    context_t.Exit();
                }
                // Build the graph for the false branch in a new context.
                var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
                ITensorOrOperation orig_res_f;
                Tensor res_f;
                try
                {
                    context_f.Enter();
                    (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
                    context_f.ExitResult(new[] { res_f });
                }
                finally
                {
                    context_f.Exit();
                }

                var res_t_flat = new Tensor[] { res_t };
                var res_f_flat = new Tensor[] { res_f };

                var merges = zip(res_f_flat, res_t_flat)
                             .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0])
                             .ToArray();

                if (orig_res_t is Tensor orig_res_tensor)
                {
                    merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges)
                             .Select(x => x as Tensor)
                             .ToArray();
                }
                else
                {
                }

                if (context_t.outer_context == null)
                {
                    ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
                    ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
                }

                return merges[0];
            }));