Example #1
        //  def AddWhileContext(self, op, between_op_list, between_ops):
        //    """Add the grad state for the while loop that op belongs to.

        //    Note that op is an Exit, and this method must be called in
        //    the control flow context where gradients() is called.

        //    Note that this method modifies `between_op_list` and `between_ops`.
        //    """
        //    forward_ctxt = _GetWhileContext(op)
        //    grad_state = self._map.get(forward_ctxt)
        //    if grad_state is None:
        //      # This is a new while loop so create a grad state for it.
        //      outer_forward_ctxt = forward_ctxt.outer_context
        //      if outer_forward_ctxt:
        //        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
        //      outer_grad_state = None
        //      if outer_forward_ctxt:
        //        outer_grad_state = self._map.get(outer_forward_ctxt)
        //      grad_state = GradLoopState(forward_ctxt, outer_grad_state)
        //      self._map[forward_ctxt] = grad_state

        //      # We need to include all exits of a loop for backprop.
        //      for loop_exit in grad_state.forward_loop_exits:
        //        if loop_exit.op not in between_ops:
        //          between_ops.add(loop_exit.op)
        //          between_op_list.append(loop_exit.op)
        public void AddWhileContext(Operation op, List<Operation> between_op_list, List<Operation> between_ops)
        {
            var forward_ctxt = op.GetWhileContext();
            var grad_state   = _map.ContainsKey(forward_ctxt) ? _map[forward_ctxt] : null;

            if (grad_state == null)
            {
                GradLoopState outer_grad_state   = null;
                var           outer_forward_ctxt = forward_ctxt.outer_context;
                if (outer_forward_ctxt != null)
                {
                    outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
                }
                if (outer_forward_ctxt != null)
                {
                    // Mirror Python's dict.get: a plain indexer would throw
                    // if the outer context has no grad state yet.
                    outer_grad_state = _map.ContainsKey(outer_forward_ctxt) ? _map[outer_forward_ctxt] : null;
                }
                grad_state         = new GradLoopState(forward_ctxt, outer_grad_state);
                _map[forward_ctxt] = grad_state;

                // We need to include all exits of a loop for backprop.
                foreach (var loop_exit in grad_state.forward_loop_exits)
                {
                    if (!between_ops.Contains(loop_exit.op))
                    {
                        between_ops.Add(loop_exit.op);
                        between_op_list.Add(loop_exit.op);
                    }
                }
            }
        }
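
For context, a hedged sketch of how a caller inside gradients() might drive AddWhileContext. The loop_state instance, the reached_ops traversal, and the "Exit" type check are illustrative assumptions, not confirmed TensorFlow.NET API:

        // Hypothetical call site: while walking the ops reachable between the
        // gradient targets and sources, register a GradLoopState for every
        // while loop whose Exit op is encountered.
        var between_op_list = new List<Operation>();
        var between_ops     = new List<Operation>();
        foreach (var op in reached_ops)
        {
            if (op.type == "Exit")
                loop_state.AddWhileContext(op, between_op_list, between_ops);
        }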
Example #2
        /// <summary>
        /// Create zeros_like gradient for a loop exit.
        /// </summary>
        /// <param name="val">The loop exit tensor to build a zero gradient for.</param>
        /// <returns>A zero tensor with the same shape and dtype as val.</returns>
        public Tensor ZerosLikeForExit(Tensor val)
        {
            Tensor result             = null;
            var    val_shape          = val.TensorShape;
            var    forward_ctxt       = val.op._get_control_flow_context();
            var    outer_forward_ctxt = forward_ctxt.outer_context;

            if (outer_forward_ctxt != null)
            {
                outer_forward_ctxt = outer_forward_ctxt.GetWhileContext();
            }
            GradLoopState outer_grad_state = null;

            if (outer_forward_ctxt != null)
            {
                // Mirror Python's dict.get rather than the throwing indexer.
                outer_grad_state = _map.ContainsKey(outer_forward_ctxt) ? _map[outer_forward_ctxt] : null;
            }
            // This is a nested loop.
            if (outer_grad_state != null)
            {
                throw new NotImplementedException("ZerosLikeForExit");
            }
            else
            {
                // If the shape is known statically, just create a zero tensor
                // with the right shape.
                if (val_shape.is_fully_defined())
                {
                    result = array_ops.zeros(val_shape.dims, val.dtype);
                }
                else
                {
                    result = array_ops.zeros_like(val, optimize: false);
                }
            }
            return result;
        }
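
A minimal usage sketch, assuming loop_state is the object that owns _map and exit_value is an exit tensor of a non-nested while loop (the nested case above still throws NotImplementedException); both names are illustrative:

        // Hypothetical backprop seeding: a loop exit that received no incoming
        // gradient is given an all-zeros gradient of matching shape and dtype.
        Tensor zero_grad = loop_state.ZerosLikeForExit(exit_value);
        // Statically known shape -> array_ops.zeros(shape, dtype)
        // Unknown static shape   -> array_ops.zeros_like(val, optimize: false)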
Example #3
        public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
        {
            // Information needed by backprop.
            _unused_exits       = new List<Tensor>();
            _deferred_exits     = new List<Tensor>();
            _forward_loop_exits = new List<Tensor>(forward_ctxt.loop_exits);
            pending_exits_count = forward_ctxt.loop_exits.Count;

            _outer_grad_state = outer_grad_state_;

            ControlFlowContext outer_forward_ctxt = null;

            if (outer_grad_state_ != null)
            {
                outer_forward_ctxt = outer_grad_state_.forward_context;
            }

            // Add the forward loop counter.
            // Python original: with forward_ctxt._graph.as_default():
            // the bare block below stands in for that with-scope.
            Tensor cnt, forward_index;

            {
                if (outer_forward_ctxt != null)
                {
                    outer_forward_ctxt.Enter();
                }
                (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state_);
                if (outer_forward_ctxt != null)
                {
                    outer_forward_ctxt.Exit();
                }
            }
            _forward_context = forward_ctxt;
            _forward_index   = forward_index;

            // Add the backprop WhileContext, and the backprop loop counter.
            if (outer_grad_state_ != null)
            {
                // This is a nested loop. Remember the iteration counts for each
                // execution of this inner loop.
                throw new NotImplementedException("GradLoopState");
            }
            else
            {
                if (outer_forward_ctxt != null)
                {
                    outer_forward_ctxt.Enter();
                }
                _grad_context = new WhileContext(
                    maximum_iterations: forward_ctxt.maximum_iterations,
                    parallel_iterations: forward_ctxt.parallel_iterations,
                    back_prop: forward_ctxt.back_prop,
                    swap_memory: forward_ctxt.swap_memory,
                    name: forward_ctxt.Name,
                    grad_state: this);
                _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state);
                if (outer_forward_ctxt != null)
                {
                    outer_forward_ctxt.Exit();
                }
            }
        }
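
The paired Enter()/Exit() calls above translate Python's with-statement scoping of outer_forward_ctxt. A hedged alternative for the forward-counter section, a sketch rather than the library's actual pattern, would use try/finally so Exit() runs even if counter construction throws:

        // Exception-safe variant of the bracketing above; illustrative only.
        if (outer_forward_ctxt != null)
            outer_forward_ctxt.Enter();
        try
        {
            (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state_);
        }
        finally
        {
            if (outer_forward_ctxt != null)
                outer_forward_ctxt.Exit();
        }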