/// <summary>
/// Stores <paramref name="ctx"/> as this object's current control flow context.
/// </summary>
/// <param name="ctx">The context to store; it is assigned as-is, so null is accepted.</param>
public void _set_control_flow_context(ControlFlowContext ctx)
{
    _control_flow_context = ctx;
}
/// <summary>
/// Stores <paramref name="ctx"/> as this object's current control flow context.
/// </summary>
/// <param name="ctx">The context to store; it is assigned as-is, so null is accepted.</param>
public void _set_control_flow_context(ControlFlowContext ctx)
{
    _control_flow_context = ctx;
}
/// <summary>
/// Walks outward from <paramref name="ctxt"/> along the <c>outer_context</c> chain and
/// returns the first context that is a while-loop context.
/// </summary>
/// <param name="ctxt">Context to start the search from; null yields null.</param>
/// <param name="stop_ctxt">Optional context at which the walk halts.</param>
/// <returns>
/// The first context on the chain whose <c>IsWhileContext()</c> is true, cast to
/// <see cref="WhileContext"/>. If <paramref name="stop_ctxt"/> is encountered first, it is
/// returned via an <c>as</c> cast, which yields null when it is not a
/// <see cref="WhileContext"/>. Null when the chain is exhausted.
/// </returns>
public static WhileContext GetContainingWhileContext(ControlFlowContext ctxt, ControlFlowContext stop_ctxt = null)
{
    for (var current = ctxt; current != null; current = current.outer_context)
    {
        var reachedTarget = current.IsWhileContext() || current == stop_ctxt;
        if (reachedTarget)
        {
            return current as WhileContext;
        }
    }

    return null;
}
/// <summary>
/// Enters this control-dependencies scope.
/// When <c>_new_stack</c> is set (the constructor sets it when given a null input list),
/// the graph's current control-dependencies stack and control flow context are saved
/// into <c>_old_stack</c> and <c>_old_control_flow_context</c>. They are then replaced by an
/// empty stack and a null context. In every case this controller is pushed onto the graph.
/// </summary>
public void __enter__()
{
    var graph = _graph;

    if (_new_stack)
    {
        // Remember the existing control-dependencies stack, then start an empty one.
        _old_stack = graph._control_dependencies_stack;
        graph._control_dependencies_stack = new List<_ControlDependenciesController>();

        // A control flow context is another source of control dependencies,
        // so remember it and detach from it as well.
        _old_control_flow_context = graph._get_control_flow_context();
        graph._set_control_flow_context(null);
    }

    graph._push_control_dependencies_controller(this);
}
/// <summary>
/// Creates a new <c>_ControlDependenciesController</c>, the context manager behind
/// <c>with tf.control_dependencies()</c> blocks. Such blocks normally nest, as described
/// in the documentation for <c>control_dependencies()</c>.
///
/// <paramref name="control_inputs"/> lists the control dependencies to add to the
/// current set. Because of uniquification, that set can end up empty even when the
/// caller passed a list of ops. Passing null instead requests a fresh, empty set of
/// control dependencies rather than an extension of the current one. In that case the
/// current control flow context is cleared too, because it is an additional mechanism
/// for adding control dependencies.
/// </summary>
/// <param name="graph">The graph that this controller manages.</param>
/// <param name="control_inputs">Ops to use as control inputs in addition to the current
/// control dependencies, or null to indicate that the dependencies should be cleared.
/// A non-null list is stored by reference, not copied.</param>
public _ControlDependenciesController(Graph graph, List<ITensorOrOperation> control_inputs)
{
    _graph = graph;

    // A null list means "start a new, empty dependency stack".
    var startFresh = control_inputs == null;
    _new_stack = startFresh;
    _control_inputs_val = startFresh ? new List<ITensorOrOperation>() : control_inputs;

    // Nothing has been seen or saved yet; __enter__ fills the saved state in.
    _seen_nodes = new List<ITensorOrOperation>();
    _old_stack = null;
    _old_control_flow_context = null;
}
/// <summary>
/// Looks up the grad state of the forward while loop that <paramref name="op"/> belongs to.
/// </summary>
/// <param name="op">The forward operation.</param>
/// <param name="before">When true and <paramref name="op"/> is a loop exit, the lookup
/// starts from the outer context of the op's own control flow context and resolves the
/// while context from there. Otherwise the while context comes from
/// <c>util.GetWhileContext(op)</c>.</param>
/// <returns>Whatever <c>_map.get</c> yields for the resolved forward context, or null when
/// no forward context was resolved.</returns>
public GradLoopState GetGradState(Operation op, bool before)
{
    ControlFlowContext forward_ctxt;

    if (!before || !util.IsLoopExit(op))
    {
        forward_ctxt = util.GetWhileContext(op);
    }
    else
    {
        // A loop exit sits on the boundary of its loop, so step out one level first.
        var enclosing = op._get_control_flow_context().outer_context;
        forward_ctxt = enclosing?.GetWhileContext();
    }

    if (forward_ctxt == null)
    {
        return null;
    }

    return _map.get(forward_ctxt);
}
/// <summary>
/// Creates the gradient bookkeeping for the forward while loop <paramref name="forward_ctxt"/>.
/// It records the loop's exits, adds a forward loop counter to the forward loop, and builds
/// the backprop <see cref="WhileContext"/> together with its backprop loop counter.
/// </summary>
/// <param name="forward_ctxt">The forward while loop being differentiated.</param>
/// <param name="outer_grad_state_">Grad state of the enclosing differentiated while loop,
/// or null when the forward loop is not nested in one.</param>
/// <exception cref="NotImplementedException">Thrown when <paramref name="outer_grad_state_"/>
/// is non-null, because nested while loops are not supported yet.</exception>
public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
{
    // Information needed by backprop.
    _unused_exits = new List<Tensor>();
    _deferred_exits = new List<Tensor>();
    _forward_loop_exits = list(forward_ctxt.loop_exits);
    pending_exits_count = len(forward_ctxt.loop_exits);

    _outer_grad_state = outer_grad_state_;

    // The context that encloses the forward loop. For a nested loop this is the outer
    // loop's forward context. Otherwise it is whatever context directly encloses the
    // forward loop (null at top level), as in TensorFlow's reference implementation.
    // Both the forward counter and the backprop context must be built inside it.
    ControlFlowContext outer_forward_ctxt;
    if (outer_grad_state_ != null)
    {
        outer_forward_ctxt = outer_grad_state_.forward_context;
    }
    else
    {
        outer_forward_ctxt = forward_ctxt.outer_context;
    }

    // Add the forward loop counter.
    // with forward_ctxt._graph.as_default():
    Tensor cnt, forward_index;
    if (outer_forward_ctxt != null)
    {
        outer_forward_ctxt.Enter();
    }
    (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state_);
    if (outer_forward_ctxt != null)
    {
        outer_forward_ctxt.Exit();
    }

    _forward_context = forward_ctxt;
    _forward_index = forward_index;

    // Add the backprop WhileContext, and the backprop loop counter.
    if (outer_grad_state_ != null)
    {
        // This is a nested loop. Remember the iteration counts for each
        // execution of this inner loop.
        throw new NotImplementedException("GradLoopState");
    }

    if (outer_forward_ctxt != null)
    {
        outer_forward_ctxt.Enter();
    }
    _grad_context = new WhileContext(
        maximum_iterations: forward_ctxt.maximum_iterations,
        parallel_iterations: forward_ctxt.parallel_iterations,
        back_prop: forward_ctxt.back_prop,
        swap_memory: forward_ctxt.swap_memory,
        name: forward_ctxt.Name,
        grad_state: this);
    _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state_);
    if (outer_forward_ctxt != null)
    {
        outer_forward_ctxt.Exit();
    }
}
/// <summary>
/// Replaces the stored control flow context with <paramref name="ctx"/> (null accepted).
/// </summary>
/// <param name="ctx">The context to store.</param>
public void _set_control_flow_context(ControlFlowContext ctx) => _control_flow_context = ctx;