/// <summary>
/// Graph-mode conditional for branch callables that return an array of <typeparamref name="T"/>.
/// Each branch is built inside its own <see cref="CondContext"/> (true branch with
/// <c>branch: 1</c>, false branch with <c>branch: 0</c>). The corresponding outputs of the
/// two branches are then combined pairwise with <c>merge</c>.
/// </summary>
/// <typeparam name="T">Element type of the arrays returned by the branch callables.</typeparam>
/// <param name="pred">Predicate tensor fed to the Switch that selects the branch.</param>
/// <param name="true_fn">Callable that builds the outputs of the true branch.</param>
/// <param name="false_fn">Callable that builds the outputs of the false branch.</param>
/// <param name="strict">Accepted for API compatibility; not read by this implementation.</param>
/// <param name="name">Optional name for the scope; <c>"cond"</c> is passed to
/// <c>ops.name_scope</c> as the default.</param>
/// <returns>One merged tensor per (false, true) output pair, after
/// <c>_convert_flows_to_tensorarrays</c> has been applied against the true branch's
/// original results.</returns>
public static Tensor[] cond<T>(Tensor pred,
    Func<T[]> true_fn = null,
    Func<T[]> false_fn = null,
    bool strict = false,
    string name = null)
{
    return with(ops.name_scope(name, "cond", new { pred }), delegate
    {
        // Add the Switch to the graph.
        var (p_2, p_1) = @switch(pred, pred);
        var pivot_1 = array_ops.identity(p_1, name: "switch_t");
        var pivot_2 = array_ops.identity(p_2, name: "switch_f");
        pred = array_ops.identity(pred, name: "pred_id");

        // Disable the fetching of tensors that are only on one branch of cond.
        foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
        {
            tensor.op.graph.prevent_fetching(tensor.op);
        }

        // Build the graph for the true branch in a new context. The helper exits the
        // context even if true_fn throws, so a failing branch cannot leave it entered.
        var context_t = new CondContext(pred, pivot_1, branch: 1);
        var (orig_res_t, res_t) = _build_in_cond_context(context_t, () => context_t.BuildCondBranch(true_fn));

        // Build the graph for the false branch in a new context. Its original
        // (pre-flattening) result is not needed below, hence the discard.
        var context_f = new CondContext(pred, pivot_2, branch: 0);
        var (_, res_f) = _build_in_cond_context(context_f, () => context_f.BuildCondBranch(false_fn));

        var res_t_flat = res_t;
        var res_f_flat = res_f;

        // Merge each (false, true) output pair.
        var merges = zip(res_f_flat, res_t_flat)
            .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
            .ToArray();

        merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

        ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
        ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);

        return merges;
    });
}

/// <summary>
/// Enters <paramref name="context"/>, runs <paramref name="build"/>, and always exits the
/// context again, even when <paramref name="build"/> (or <c>Enter</c>) throws.
/// </summary>
/// <typeparam name="TResult">Result type of <paramref name="build"/>, inferred at the call site.</typeparam>
/// <param name="context">Branch context to enter for the duration of the build.</param>
/// <param name="build">Callback that builds the branch graph while the context is active.</param>
/// <returns>Whatever <paramref name="build"/> returned.</returns>
private static TResult _build_in_cond_context<TResult>(CondContext context, Func<TResult> build)
{
    try
    {
        context.Enter();
        return build();
    }
    finally
    {
        context.Exit();
    }
}
/// <summary>
/// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
///
/// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
/// `false_fn` must have the same non-zero number and type of outputs.
///
/// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
/// `false_fn` will be executed regardless of which branch is selected at runtime.
///
/// Although this behavior is consistent with the dataflow model of TensorFlow,
/// it has frequently surprised users who expected lazier semantics.
/// Consider the following simple program:
///
/// z = tf.multiply(a, b)
/// result = tf.cond(x &lt; y, () => tf.add(x, z), () => tf.square(y))
///
/// If `x &lt; y`, the `tf.add` operation will be executed and the `tf.square`
/// operation will not be executed. Since `z` is needed for at least one
/// branch of the `cond`, the `tf.multiply` operation is always executed,
/// unconditionally.
///
/// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
/// call to `cond`, and not at all during `Session.run()`). `cond`
/// stitches together the graph fragments created during the `true_fn` and
/// `false_fn` calls with some additional graph nodes to ensure that the right
/// branch gets executed depending on the value of `pred`.
///
/// `tf.cond` supports nested structures as implemented in
/// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
/// same (possibly nested) value structure of lists, tuples, and/or named tuples.
/// Singleton lists and tuples form the only exceptions to this: when returned by
/// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
/// This behavior is disabled by passing `strict: true`.
/// </summary>
/// <remarks>
/// In this port each branch yields a single <see cref="ITensorOrOperation"/>. The eager-mode
/// path, singleton unpacking, and the structure check from the Python implementation are
/// not implemented yet (see the TODOs in the body), and <paramref name="strict"/> is not read.
/// </remarks>
/// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
/// `false_fn`.</param>
/// <param name="true_fn">The callable to be performed if pred is true.</param>
/// <param name="false_fn">The callable to be performed if pred is false.</param>
/// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
/// <param name="name">Optional name prefix for the returned tensors.</param>
/// <returns>The merged tensor built from the single output of `true_fn` and `false_fn`.</returns>
public static Tensor cond(Tensor pred,
    Func<ITensorOrOperation> true_fn = null,
    Func<ITensorOrOperation> false_fn = null,
    bool strict = false,
    string name = null)
{
    return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
    {
        // TODO: here a chunk of original code is missing
        /*
         * with ops.name_scope(name, "cond", [pred]):
         *   if context.executing_eagerly():
         *     if pred:
         *       return _UnpackIfSingleton(true_fn())
         *     return _UnpackIfSingleton(false_fn())
         */

        // Add the Switch to the graph: output 0 feeds the false branch, output 1 the true branch.
        var switch_result = @switch(pred, pred);
        var p_2 = switch_result[0];
        var p_1 = switch_result[1];
        var pivot_1 = array_ops.identity(p_1, name: "switch_t");
        var pivot_2 = array_ops.identity(p_2, name: "switch_f");
        pred = array_ops.identity(pred, name: "pred_id");

        // Disable the fetching of tensors that are only on one branch of cond.
        foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
        {
            tensor.op.graph.prevent_fetching(tensor.op);
        }

        // Build the graph for the true branch in a new context; always exit it, even on failure.
        var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
        ITensorOrOperation orig_res_t;
        Tensor res_t;
        try
        {
            context_t.Enter();
            (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
        }
        finally
        {
            context_t.Exit();
        }

        // Build the graph for the false branch in a new context; always exit it, even on failure.
        var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
        ITensorOrOperation orig_res_f;
        Tensor res_f;
        try
        {
            context_f.Enter();
            (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
        }
        finally
        {
            context_f.Exit();
        }

        //TODO: missing original code
        //if not strict:
        //  orig_res_t = _UnpackIfSingleton(orig_res_t)
        //  orig_res_f = _UnpackIfSingleton(orig_res_f)

        /*
        # Check that the return values of the two branches have the same structure.
        # try:
        #   nest.assert_same_structure(orig_res_t, orig_res_f)
        # except TypeError as e:
        #   raise TypeError(
        #       "Incompatible return types of true_fn and false_fn: {}".format(e))
        # except ValueError as e:
        #   raise ValueError(
        #       "Incompatible return values of true_fn and false_fn: {}".format(e))*/

        var res_t_flat = new Tensor[] { res_t };
        var res_f_flat = new Tensor[] { res_f };

        // TODO: port the Python per-output check over zip(res_t_flat, res_f_flat) that
        // raises when x.dtype.base_dtype != y.dtype.base_dtype (previously an empty loop here).

        // Merge each (false, true) output pair.
        var merges = zip(res_f_flat, res_t_flat)
            .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
            .ToArray();

        // orig_res_t is typed ITensorOrOperation and may be an Operation; a direct (Tensor)
        // cast would throw for that case, so the flow conversion is only applied to Tensors.
        if (orig_res_t is Tensor orig_res_tensor)
        {
            merges = _convert_flows_to_tensorarrays(new Tensor[] { orig_res_tensor }, merges);
        }

        ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
        ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);

        return merges[0];
    });
/// <summary>
/// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
///
/// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
/// `false_fn` must have the same non-zero number and type of outputs.
///
/// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
/// `false_fn` will be executed regardless of which branch is selected at runtime.
///
/// Although this behavior is consistent with the dataflow model of TensorFlow,
/// it has frequently surprised users who expected lazier semantics.
/// Consider the following simple program:
///
/// z = tf.multiply(a, b)
/// result = tf.cond(x &lt; y, () => tf.add(x, z), () => tf.square(y))
///
/// If `x &lt; y`, the `tf.add` operation will be executed and the `tf.square`
/// operation will not be executed. Since `z` is needed for at least one
/// branch of the `cond`, the `tf.multiply` operation is always executed,
/// unconditionally.
///
/// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
/// call to `cond`, and not at all during `Session.run()`). `cond`
/// stitches together the graph fragments created during the `true_fn` and
/// `false_fn` calls with some additional graph nodes to ensure that the right
/// branch gets executed depending on the value of `pred`.
///
/// `tf.cond` supports nested structures as implemented in
/// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
/// same (possibly nested) value structure of lists, tuples, and/or named tuples.
/// Singleton lists and tuples form the only exceptions to this: when returned by
/// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
/// This behavior is disabled by passing `strict: true`.
/// </summary>
/// <remarks>
/// In this port each branch yields a single <see cref="ITensorOrOperation"/>, so exactly one
/// merged tensor is produced, and <paramref name="strict"/> is not read. The two
/// <see cref="CondContext"/>s are added to the COND_CONTEXT collection only when the true
/// branch's context has no outer context.
/// </remarks>
/// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
/// `false_fn`.</param>
/// <param name="true_fn">The callable to be performed if pred is true.</param>
/// <param name="false_fn">The callable to be performed if pred is false.</param>
/// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
/// <param name="name">Optional name prefix for the returned tensors.</param>
/// <returns>The merged tensor built from the outputs of `true_fn` and `false_fn`.</returns>
public static Tensor cond(Tensor pred,
    Func<ITensorOrOperation> true_fn = null,
    Func<ITensorOrOperation> false_fn = null,
    bool strict = false,
    string name = null)
{
    return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
    {
        // Enter the branch context, build the branch, hand its result to the context,
        // and leave the context again no matter how building ends.
        (ITensorOrOperation, Tensor) build_branch(CondContext branch_context, Func<ITensorOrOperation> branch_fn)
        {
            try
            {
                branch_context.Enter();
                ITensorOrOperation original;
                Tensor built;
                (original, built) = branch_context.BuildCondBranch(branch_fn);
                branch_context.ExitResult(new[] { built });
                return (original, built);
            }
            finally
            {
                branch_context.Exit();
            }
        }

        // Add the Switch to the graph: output 0 feeds the false branch, output 1 the true branch.
        var switched = @switch(pred, pred);
        var false_out = switched[0];
        var true_out = switched[1];
        var true_pivot = array_ops.identity(true_out, name: "switch_t");
        var false_pivot = array_ops.identity(false_out, name: "switch_f");
        pred = array_ops.identity(pred, name: "pred_id");

        // Tensors that live on only one side of the cond must not be fetchable.
        var branch_only = new Tensor[] { true_out, false_out, true_pivot, false_pivot, pred };
        foreach (var t in branch_only)
        {
            t.op.graph.prevent_fetching(t.op);
        }

        // True branch first, then false branch, each in its own context.
        var true_context = new CondContext(pred: pred, pivot: true_pivot, branch: 1);
        var (orig_res_t, res_t) = build_branch(true_context, true_fn);

        var false_context = new CondContext(pred: pred, pivot: false_pivot, branch: 0);
        var (_, res_f) = build_branch(false_context, false_fn);

        // Merge the (false, true) output pair.
        var merges = zip(new Tensor[] { res_f }, new Tensor[] { res_t })
            .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0])
            .ToArray();

        // Flow conversion is only attempted when the true branch produced a Tensor.
        if (orig_res_t is Tensor true_tensor)
        {
            var converted = _convert_flows_to_tensorarrays(new[] { true_tensor }, merges);
            merges = converted.Select(x => x as Tensor).ToArray();
        }

        // A context nested inside another one is not registered in the collection.
        if (true_context.outer_context != null)
        {
            return merges[0];
        }

        ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, true_context);
        ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, false_context);
        return merges[0];
    });