/// <summary>
/// Parses the <c>cond</c> rule. Adaptive prediction selects one of two alternatives:
/// alternative 1 matches <c>IF LP expr RP braceblock</c>; alternative 2 matches the same
/// sequence followed by <c>ELSE braceblock</c>.
/// </summary>
/// <returns>
/// The rule context: a <c>CondNoElseContext</c> for alternative 1, a <c>CondElseContext</c>
/// for alternative 2, or the plain <c>CondContext</c> if neither alternative was predicted.
/// When a <c>RecognitionException</c> occurs it is stored on the context, reported and
/// recovered from.
/// </returns>
public CondContext cond()
{
    CondContext localContext = new CondContext(Context, State);
    EnterRule(localContext, 10, RULE_cond);
    try
    {
        State = 56;
        ErrorHandler.Sync(this);
        int alternative = Interpreter.AdaptivePredict(TokenStream, 2, Context);
        if (alternative == 1)
        {
            // if ( expr ) { ... }
            localContext = new CondNoElseContext(localContext);
            EnterOuterAlt(localContext, 1);
            State = 42;
            Match(IF);
            State = 43;
            Match(LP);
            State = 44;
            expr();
            State = 45;
            Match(RP);
            State = 46;
            braceblock();
        }
        else if (alternative == 2)
        {
            // if ( expr ) { ... } else { ... }
            localContext = new CondElseContext(localContext);
            EnterOuterAlt(localContext, 2);
            State = 48;
            Match(IF);
            State = 49;
            Match(LP);
            State = 50;
            expr();
            State = 51;
            Match(RP);
            State = 52;
            braceblock();
            State = 53;
            Match(ELSE);
            State = 54;
            braceblock();
        }
    }
    catch (RecognitionException error)
    {
        localContext.exception = error;
        ErrorHandler.ReportError(this, error);
        ErrorHandler.Recover(this, error);
    }
    finally
    {
        ExitRule();
    }

    return localContext;
}
/// <summary>
/// Builds a conditional in the graph: a switch on <paramref name="pred"/>, one sub-graph per
/// branch (each built inside its own <c>CondContext</c>), and a merge of the corresponding
/// branch outputs.
/// </summary>
/// <typeparam name="T">Element type returned by the branch callables.</typeparam>
/// <param name="pred">The predicate tensor fed to the switch.</param>
/// <param name="true_fn">Callable that builds the branch created with <c>branch: 1</c>.</param>
/// <param name="false_fn">Callable that builds the branch created with <c>branch: 0</c>.</param>
/// <param name="strict">Not used by this implementation.</param>
/// <param name="name">Optional name for the enclosing name scope (default name "cond").</param>
/// <returns>The merged outputs, one per pair of branch results.</returns>
public static Tensor[] cond<T>(Tensor pred, Func<T[]> true_fn = null, Func<T[]> false_fn = null, bool strict = false, string name = null)
{
    return with(ops.name_scope(name, "cond", new { pred }), delegate
    {
        // Add the Switch to the graph.
        var (p_2, p_1) = @switch(pred, pred);
        var pivot_1 = array_ops.identity(p_1, name: "switch_t");
        var pivot_2 = array_ops.identity(p_2, name: "switch_f");
        pred = array_ops.identity(pred, name: "pred_id");

        // Disable the fetching of tensors that are only on one branch of cond.
        foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
        {
            tensor.op.graph.prevent_fetching(tensor.op);
        }

        // Build the graph for the true branch in a new context.
        var context_t = new CondContext(pred, pivot_1, branch: 1);
        var (orig_res_t, res_t) = BuildInContext(context_t, () => context_t.BuildCondBranch(true_fn));

        // Build the graph for the false branch in a new context.
        var context_f = new CondContext(pred, pivot_2, branch: 0);
        var (_, res_f) = BuildInContext(context_f, () => context_f.BuildCondBranch(false_fn));

        var merges = zip(res_f, res_t)
            .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
            .ToArray();

        merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

        ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
        ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);

        return merges;
    });

    // Runs `build` between context.Enter() and context.Exit(). Exit() is in a finally block so
    // the context is left even when the branch callable throws, mirroring the try/finally used
    // by the non-generic cond overload.
    TResult BuildInContext<TResult>(CondContext context, Func<TResult> build)
    {
        try
        {
            context.Enter();
            return build();
        }
        finally
        {
            context.Exit();
        }
    }
}
/// <summary>
/// Parses the <c>cond</c> rule: <c>LP</c>, the <c>T__2</c> token, one or more
/// <c>condGroup</c> sub-rules (repeated while the next token is <c>LP</c>), then <c>RP</c>.
/// </summary>
/// <returns>
/// The rule context. When a <c>RecognitionException</c> occurs it is stored on the context,
/// reported and recovered from.
/// </returns>
public CondContext cond()
{
    CondContext localContext = new CondContext(Context, State);
    EnterRule(localContext, 8, RULE_cond);
    int lookahead;
    try
    {
        EnterOuterAlt(localContext, 1);

        State = 66;
        Match(LP);
        State = 67;
        Match(T__2);

        State = 69;
        ErrorHandler.Sync(this);
        lookahead = TokenStream.LA(1);

        // At least one condGroup is required; keep going while another group opens with LP.
        do
        {
            State = 68;
            condGroup();

            State = 71;
            ErrorHandler.Sync(this);
            lookahead = TokenStream.LA(1);
        }
        while (lookahead == LP);

        State = 73;
        Match(RP);
    }
    catch (RecognitionException error)
    {
        localContext.exception = error;
        ErrorHandler.ReportError(this, error);
        ErrorHandler.Recover(this, error);
    }
    finally
    {
        ExitRule();
    }

    return localContext;
}
/// <summary>
/// Add the getter for an accumulated value in the grad context.
///
/// This is added to the backprop loop. Called in the grad context to
/// get the value of an accumulated value. The stack pop op must be guarded
/// by the pred of the controlling cond.
/// </summary>
/// <param name="history_value">The history (a stack) of a value.</param>
/// <param name="value">The value that is pushed onto the stack.</param>
/// <param name="dead_branch">True iff the tensor is on a dead branch of a cond. Not used by this implementation.</param>
/// <returns>The current value (the top of the stack).</returns>
/// <exception cref="NotImplementedException">
/// Thrown when a <c>CondContext</c> is found between the context of <paramref name="value"/>
/// and the context of <paramref name="history_value"/>; guarding the pop by a cond is not
/// implemented.
/// </exception>
public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
{
    var history_ctxt = history_value.op._get_control_flow_context();

    // Find the cond context that controls history_value if any.
    CondContext cond_ctxt = null;
    Tensor pop = null;
    var value_ctxt = value.op._get_control_flow_context();
    while (value_ctxt != null && value_ctxt != history_ctxt)
    {
        if (value_ctxt is CondContext cc)
        {
            cond_ctxt = cc;
        }

        value_ctxt = value_ctxt.outer_context;
    }

    tf_with(ops.control_dependencies(null), delegate
    {
        // Reject the unsupported case before entering the grad context, so that a throw
        // cannot leave the context entered.
        if (cond_ctxt != null)
        {
            throw new NotImplementedException("AddBackpropAccumulatedValue");
        }

        grad_context.Enter();
        try
        {
            pop = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
            pop.set_shape(value.TensorShape);
        }
        finally
        {
            // Always leave the grad context, even if building the pop op fails.
            grad_context.Exit();
        }
    });

    var parallel_iterations = grad_context.parallel_iterations;
    if (parallel_iterations > 1)
    {
        // All pops are ordered after pivot_for_body and before grad_sync.
        grad_sync._add_control_input(pop.op);
    }

    return pop;
}
/// <summary>
/// Imports the graph held by a <c>MetaGraphDef</c> into the default graph and restores its
/// collections (node lists, variables, cond/while contexts).
/// </summary>
/// <param name="meta_graph_or_file">The meta graph definition to import.</param>
/// <param name="clear_devices">When true, the device of every node is reset to an empty string before import.</param>
/// <param name="import_scope">
/// Requested import scope. NOTE(review): the name prefix actually applied is obtained from
/// <c>graph.unique_name("")</c>, not from this argument — confirm whether that is intended.
/// </param>
/// <param name="input_map">Optional input map forwarded to <c>importer.import_graph_def</c>.</param>
/// <param name="unbound_inputs_col_name">
/// Name of the collection listing unbound inputs. That collection is never restored; if it is
/// present in the meta graph the import is rejected.
/// </param>
/// <param name="return_elements">Optional names forwarded to <c>importer.import_graph_def</c>.</param>
/// <returns>
/// A dictionary of the imported global variables keyed by their scope-stripped name, and the
/// elements returned by <c>importer.import_graph_def</c>.
/// </returns>
/// <exception cref="NotImplementedException">
/// Thrown when the meta graph contains the <paramref name="unbound_inputs_col_name"/> collection.
/// </exception>
public static (Dictionary<string, IVariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
    bool clear_devices = false,
    string import_scope = "",
    Dictionary<string, Tensor> input_map = null,
    string unbound_inputs_col_name = "unbound_inputs",
    string[] return_elements = null)
{
    var meta_graph_def = meta_graph_or_file;

    // Graphs that declare unbound inputs are not supported yet.
    if (!string.IsNullOrEmpty(unbound_inputs_col_name)
        && meta_graph_def.CollectionDef.Any(c => c.Key == unbound_inputs_col_name))
    {
        throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
    }

    // Sets graph to default graph if it's not passed in.
    var graph = ops.get_default_graph();

    // Gathers the list of nodes we are interested in (stays null when no stripped op list exists).
    OpList producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList;
    var input_graph_def = meta_graph_def.GraphDef;

    // Remove all the explicit device specifications for this node. This helps to
    // make the graph more portable.
    if (clear_devices)
    {
        foreach (var node in input_graph_def.Node)
        {
            node.Device = "";
        }
    }

    var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false);
    var imported_return_elements = importer.import_graph_def(input_graph_def,
        name: scope_to_prepend_to_names,
        input_map: input_map,
        producer_op_list: producer_op_list,
        return_elements: return_elements);

    // Restores all the other collections.
    var variable_objects = new Dictionary<ByteString, IVariableV1>();
    foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
    {
        // Don't add unbound_inputs to the new graph.
        if (col.Key == unbound_inputs_col_name)
        {
            continue;
        }

        switch (col.Value.KindCase)
        {
            case KindOneofCase.NodeList:
                foreach (var value in col.Value.NodeList.Value)
                {
                    var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
                    graph.add_to_collection(col.Key, col_op);
                }
                break;

            case KindOneofCase.BytesList:
                if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                {
                    foreach (var value in col.Value.BytesList.Value)
                    {
                        // Identical serialized protos share one variable object across collections.
                        if (!variable_objects.TryGetValue(value, out var variable))
                        {
                            var proto = VariableDef.Parser.ParseFrom(value);
                            if (proto.IsResource)
                            {
                                variable = new ResourceVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                            }
                            else
                            {
                                variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                            }

                            variable_objects[value] = variable;
                        }

                        graph.add_to_collection(col.Key, variable);
                    }
                }
                else
                {
                    foreach (var value in col.Value.BytesList.Value)
                    {
                        switch (col.Key)
                        {
                            case "cond_context":
                                {
                                    // Restore with the prefix the nodes were imported under, as
                                    // done for node lists and variables above.
                                    var proto = CondContextDef.Parser.ParseFrom(value);
                                    var condContext = new CondContext().from_proto(proto, scope_to_prepend_to_names);
                                    graph.add_to_collection(col.Key, condContext);
                                }
                                break;

                            case "while_context":
                                {
                                    var proto = WhileContextDef.Parser.ParseFrom(value);
                                    var whileContext = new WhileContext().from_proto(proto, scope_to_prepend_to_names);
                                    graph.add_to_collection(col.Key, whileContext);
                                }
                                break;

                            default:
                                // Unknown proto-backed collection: log it and move on to the next value.
                                Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
                                continue;
                        }
                    }
                }
                break;

            default:
                Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
                break;
        }
    }

    var variables = graph.get_collection<IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
        scope: scope_to_prepend_to_names);
    var var_list = new Dictionary<string, IVariableV1>();
    foreach (var v in variables)
    {
        var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v;
    }

    return (var_list, imported_return_elements);
}
/// <summary>
/// Assigns <paramref name="ctx"/> to the <c>_control_flow_context</c> member.
/// </summary>
/// <param name="ctx">The cond context to store; the value is stored as given, without validation.</param>
public void _set_control_flow_context(CondContext ctx)
    => _control_flow_context = ctx;
/// <summary>
/// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
///
/// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
/// `false_fn` must have the same non-zero number and type of outputs.
///
/// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
/// `false_fn` will be executed regardless of which branch is selected at runtime.
///
/// Although this behavior is consistent with the dataflow model of TensorFlow,
/// it has frequently surprised users who expected a lazier semantics.
/// Consider the following simple program:
///
///     z = tf.multiply(a, b)
///     result = tf.cond(x &lt; y, () => tf.add(x, z), () => tf.square(y))
///
/// If `x &lt; y`, the `tf.add` operation will be executed and `tf.square`
/// operation will not be executed. Since `z` is needed for at least one
/// branch of the `cond`, the `tf.multiply` operation is always executed,
/// unconditionally.
///
/// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
/// call to `cond`, and not at all during `Session.run()`). `cond`
/// stitches together the graph fragments created during the `true_fn` and
/// `false_fn` calls with some additional graph nodes to ensure that the right
/// branch gets executed depending on the value of `pred`.
///
/// `tf.cond` supports nested structures as implemented in
/// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
/// same (possibly nested) value structure of lists, tuples, and/or named tuples.
/// Singleton lists and tuples form the only exceptions to this: when returned by
/// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
/// This behavior is disabled by passing `strict=True`.
/// </summary>
/// <param name="pred">A scalar determining whether to return the result of `true_fn` or
/// `false_fn`.</param>
/// <param name="true_fn">The callable to be performed if pred is true.</param>
/// <param name="false_fn">The callable to be performed if pred is false.</param>
/// <param name="strict">A boolean that enables/disables 'strict' mode; see above.
/// Not yet honoured by this port (see the TODO in the body).</param>
/// <param name="name">Optional name prefix for the returned tensors.</param>
/// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
/// callables return a singleton list, the element is extracted from the list.</returns>
public static Tensor cond(Tensor pred, Func<ITensorOrOperation> true_fn = null, Func<ITensorOrOperation> false_fn = null, bool strict = false, string name = null)
{
    return tf_with(ops.name_scope(name, "cond", new { pred }), delegate
    {
        // TODO: here a chunk of original code is missing
        //   with ops.name_scope(name, "cond", [pred]):
        //     if context.executing_eagerly():
        //       if pred:
        //         return _UnpackIfSingleton(true_fn())
        //       return _UnpackIfSingleton(false_fn())

        // Add the Switch to the graph.
        var switch_result = @switch(pred, pred);
        var p_2 = switch_result[0];
        var p_1 = switch_result[1];
        var pivot_1 = array_ops.identity(p_1, name: "switch_t");
        var pivot_2 = array_ops.identity(p_2, name: "switch_f");
        pred = array_ops.identity(pred, name: "pred_id");

        // Disable the fetching of tensors that are only on one branch of cond.
        foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
        {
            tensor.op.graph.prevent_fetching(tensor.op);
        }

        // Build the graph for the true branch in a new context.
        var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
        ITensorOrOperation orig_res_t;
        Tensor res_t;
        try
        {
            context_t.Enter();
            (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
        }
        finally
        {
            context_t.Exit();
        }

        // Build the graph for the false branch in a new context.
        // Only the branch output is needed; the original (unconverted) result is discarded.
        var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
        Tensor res_f;
        try
        {
            context_f.Enter();
            (_, res_f) = context_f.BuildCondBranch(false_fn);
        }
        finally
        {
            context_f.Exit();
        }

        // TODO: missing original code
        //   if not strict:
        //     orig_res_t = _UnpackIfSingleton(orig_res_t)
        //     orig_res_f = _UnpackIfSingleton(orig_res_f)
        //
        //   # Check that the return values of the two branches have the same structure.
        //   try:
        //     nest.assert_same_structure(orig_res_t, orig_res_f)
        //   except TypeError as e:
        //     raise TypeError(
        //       "Incompatible return types of true_fn and false_fn: {}".format(e))
        //   except ValueError as e:
        //     raise ValueError(
        //       "Incompatible return values of true_fn and false_fn: {}".format(e))

        var res_t_flat = new Tensor[] { res_t };
        var res_f_flat = new Tensor[] { res_f };

        // TODO: validate the paired branch outputs here; the former loop over
        // zip(res_t_flat, res_f_flat) had an empty body and performed no check.

        var merges = zip(res_f_flat, res_t_flat)
            .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
            .ToArray();

        merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges);

        ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
        ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);

        return merges[0];
    }));
/// <summary>
/// Creates a <c>CondNoElseContext</c> initialised by calling <c>CopyFrom</c> with an
/// existing <c>CondContext</c>.
/// </summary>
/// <param name="context">The context passed to <c>CopyFrom</c>.</param>
public CondNoElseContext(CondContext context) => CopyFrom(context);
/// <summary>
/// Copies state from <paramref name="context"/> by delegating to the base class
/// <c>CopyFrom</c>; nothing else is copied at this level.
/// </summary>
/// <param name="context">The context to copy from.</param>
public virtual void CopyFrom(CondContext context) => base.CopyFrom(context);
/// <summary> /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`. /// /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and /// `false_fn` must have the same non-zero number and type of outputs. /// /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and /// `false_fn` will be executed regardless of which branch is selected at runtime. /// /// Although this behavior is consistent with the dataflow model of TensorFlow, /// it has frequently surprised users who expected a lazier semantics. /// Consider the following simple program: /// /// z = tf.multiply(a, b) /// result = tf.cond(x < y, ()=> tf.add(x, z), ()=> tf.square(y)) /// /// If `x<y`, the `tf.add` operation will be executed and `tf.square` /// operation will not be executed.Since `z` is needed for at least one /// branch of the `cond`, the `tf.multiply` operation is always executed, /// unconditionally. /// /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the /// call to `cond`, and not at all during `Session.run()`). `cond` /// stitches together the graph fragments created during the `true_fn` and /// `false_fn` calls with some additional graph nodes to ensure that the right /// branch gets executed depending on the value of `pred`. /// /// `tf.cond` supports nested structures as implemented in /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the /// same(possibly nested) value structure of lists, tuples, and/or named tuples. /// Singleton lists and tuples form the only exceptions to this: when returned by /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. /// This behavior is disabled by passing `strict= True`. 
/// </summary> /// <param name="pred"> A scalar determining whether to return the result of `true_fn` or /// `false_fn`.</param> /// <param name="true_fn">The callable to be performed if pred is true.</param> /// <param name="false_fn">The callable to be performed if pred is false.</param> /// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param> /// <param name="name">Optional name prefix for the returned tensors.</param> /// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the /// callables return a singleton list, the element is extracted from the list.</returns> public static Tensor cond(Tensor pred, Func <ITensorOrOperation> true_fn = null, Func <ITensorOrOperation> false_fn = null, bool strict = false, string name = null) { return(tf_with(ops.name_scope(name, "cond", new { pred }), delegate { // Add the Switch to the graph. var switch_result = @switch(pred, pred); var(p_2, p_1) = (switch_result[0], switch_result[1]); var pivot_1 = array_ops.identity(p_1, name: "switch_t"); var pivot_2 = array_ops.identity(p_2, name: "switch_f"); pred = array_ops.identity(pred, name: "pred_id"); // Disable the fetching of tensors that are only on one branch of cond. foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) { tensor.op.graph.prevent_fetching(tensor.op); } // Build the graph for the true branch in a new context. var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); ITensorOrOperation orig_res_t; Tensor res_t; try { context_t.Enter(); (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); context_t.ExitResult(new[] { res_t }); } finally { context_t.Exit(); } // Build the graph for the false branch in a new context. 
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); ITensorOrOperation orig_res_f; Tensor res_f; try { context_f.Enter(); (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); context_f.ExitResult(new[] { res_f }); } finally { context_f.Exit(); } var res_t_flat = new Tensor[] { res_t }; var res_f_flat = new Tensor[] { res_f }; var merges = zip(res_f_flat, res_t_flat) .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0]) .ToArray(); if (orig_res_t is Tensor orig_res_tensor) { merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges) .Select(x => x as Tensor) .ToArray(); } else { } if (context_t.outer_context == null) { ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t); ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f); } return merges[0]; }));