// Python reference (ControlFlowState.AddWhileContext):
// def AddWhileContext(self, op, between_op_list, between_ops):
//   """Add the grad state for the while loop that op belongs to.
//   Note that op is an Exit, and this method must be called in
//   the control flow context where gradients() is called.
//   Note that this method modifies `between_op_list` and `between_ops`.
//   """
//   forward_ctxt = _GetWhileContext(op)
//   grad_state = self._map.get(forward_ctxt)
//   if grad_state is None:
//     # This is a new while loop so create a grad state for it.
//     outer_forward_ctxt = forward_ctxt.outer_context
//     if outer_forward_ctxt:
//       outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
//     outer_grad_state = None
//     if outer_forward_ctxt:
//       outer_grad_state = self._map.get(outer_forward_ctxt)
//     grad_state = GradLoopState(forward_ctxt, outer_grad_state)
//     self._map[forward_ctxt] = grad_state
//     # We need to include all exits of a loop for backprop.
//     for loop_exit in grad_state.forward_loop_exits:
//       if loop_exit.op not in between_ops:
//         between_ops.add(loop_exit.op)
//         between_op_list.append(loop_exit.op)
/// <summary>
/// Adds the gradient loop state for the while loop that <paramref name="op"/> belongs to.
/// </summary>
/// <remarks>
/// Per the Python original, <paramref name="op"/> is an Exit op and this method must be
/// called in the control flow context where gradients() is called.
/// If the op's while context already has a non-null <see cref="GradLoopState"/> in
/// <c>_map</c>, nothing happens. Otherwise a new state is created and registered.
/// Every exit of that loop not already in <paramref name="between_ops"/> is then appended
/// to both collections.
/// </remarks>
/// <param name="op">The op whose while context needs a grad state.</param>
/// <param name="between_op_list">Ordered list of "between" ops; modified in place.</param>
/// <param name="between_ops">Membership collection of "between" ops; modified in place.</param>
public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops)
{
    var forward_ctxt = op.GetWhileContext();
    if (_map.TryGetValue(forward_ctxt, out var existing_state) && existing_state != null)
    {
        // This loop already has a grad state; nothing to add.
        return;
    }

    // This is a new while loop so create a grad state for it.
    // Like Python's dict.get, a missing entry for the enclosing while loop yields null
    // instead of throwing KeyNotFoundException.
    GradLoopState outer_grad_state = null;
    var outer_forward_ctxt = forward_ctxt.outer_context?.GetWhileContext();
    if (outer_forward_ctxt != null)
    {
        _map.TryGetValue(outer_forward_ctxt, out outer_grad_state);
    }

    var grad_state = new GradLoopState(forward_ctxt, outer_grad_state);
    _map[forward_ctxt] = grad_state;

    // We need to include all exits of a loop for backprop.
    foreach (var loop_exit in grad_state.forward_loop_exits)
    {
        if (!between_ops.Contains(loop_exit.op))
        {
            between_ops.Add(loop_exit.op);
            between_op_list.Add(loop_exit.op);
        }
    }
}
/// <summary>
/// Builds the zero gradient used for a while-loop exit tensor.
/// </summary>
/// <param name="val">The loop exit tensor that needs a zero gradient.</param>
/// <returns>
/// <c>array_ops.zeros(shape.dims, val.dtype)</c> when <paramref name="val"/>'s static shape
/// is fully defined; otherwise <c>array_ops.zeros_like(val, optimize: false)</c>.
/// </returns>
/// <exception cref="NotImplementedException">
/// Thrown when the enclosing while loop already has a grad state in <c>_map</c>
/// (the nested-loop case is not ported yet).
/// </exception>
public Tensor ZerosLikeForExit(Tensor val)
{
    var static_shape = val.TensorShape;
    var exit_ctxt = val.op._get_control_flow_context();

    // Locate the while loop (if any) enclosing the loop this exit belongs to.
    var enclosing_while = exit_ctxt.outer_context?.GetWhileContext();
    GradLoopState enclosing_state = enclosing_while == null ? null : _map.get(enclosing_while);

    // This is a nested loop.
    if (enclosing_state != null)
        throw new NotImplementedException("ZerosLikeForExit");

    // With a statically known shape, a plain zero tensor of that shape suffices.
    if (static_shape.is_fully_defined())
        return array_ops.zeros(static_shape.dims, val.dtype);

    return array_ops.zeros_like(val, optimize: false);
}
/// <summary>
/// Creates the gradient state for a forward while loop.
/// </summary>
/// <remarks>
/// The constructor does the following:
/// <list type="bullet">
/// <item>Records the forward loop's exits.</item>
/// <item>Adds a forward loop counter via <c>forward_ctxt.AddForwardLoopCounter</c>.</item>
/// <item>Creates the backprop <see cref="WhileContext"/> and its loop counter.</item>
/// </list>
/// Both the counter and the backprop context are built inside the "outer forward context".
/// For a nested loop that is the enclosing loop's forward context; otherwise it is
/// <c>forward_ctxt.outer_context</c>, as in the Python original.
/// </remarks>
/// <param name="forward_ctxt">The forward while-loop context.</param>
/// <param name="outer_grad_state_">Grad state of the enclosing while loop, or null.</param>
/// <exception cref="NotImplementedException">
/// Thrown for nested loops (non-null <paramref name="outer_grad_state_"/>), which are not ported yet.
/// </exception>
public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
{
    // Information needed by backprop.
    _unused_exits = new List<Tensor>();
    _deferred_exits = new List<Tensor>();
    _forward_loop_exits = list(forward_ctxt.loop_exits);
    pending_exits_count = len(forward_ctxt.loop_exits);

    _outer_grad_state = outer_grad_state_;

    // Without the outer_context fallback, a loop inside e.g. a cond would have its counters
    // built outside that cond context.
    ControlFlowContext outer_forward_ctxt;
    if (outer_grad_state_ != null)
        outer_forward_ctxt = outer_grad_state_.forward_context;
    else
        outer_forward_ctxt = forward_ctxt.outer_context;

    // Add the forward loop counter.
    // Python: with forward_ctxt._graph.as_default():
    Tensor cnt, forward_index;
    outer_forward_ctxt?.Enter();
    (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state_);
    outer_forward_ctxt?.Exit();

    _forward_context = forward_ctxt;
    _forward_index = forward_index;

    // Add the backprop WhileContext, and the backprop loop counter.
    if (outer_grad_state_ != null)
    {
        // This is a nested loop. Remember the iteration counts for each
        // execution of this inner loop.
        throw new NotImplementedException("GradLoopState");
    }

    outer_forward_ctxt?.Enter();
    _grad_context = new WhileContext(
        maximum_iterations: forward_ctxt.maximum_iterations,
        parallel_iterations: forward_ctxt.parallel_iterations,
        back_prop: forward_ctxt.back_prop,
        swap_memory: forward_ctxt.swap_memory,
        name: forward_ctxt.Name,
        grad_state: this);
    _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state_);
    outer_forward_ctxt?.Exit();
}