Пример #1
0
        /// <summary>
        /// Builds the gradient-loop bookkeeping for the forward while loop <paramref name="forward_ctxt"/>.
        /// It records the forward loop's exits, adds a forward loop counter, and creates the
        /// backprop <see cref="WhileContext"/> together with its backprop loop counter.
        /// </summary>
        /// <param name="forward_ctxt">The forward while-loop context being differentiated.</param>
        /// <param name="outer_grad_state_">
        /// The grad state of the enclosing while loop, or <c>null</c> when this loop is not nested.
        /// </param>
        /// <exception cref="NotImplementedException">
        /// Thrown for nested loops (non-null <c>outer_grad_state</c>), which are not implemented yet.
        /// </exception>
        public GradLoopState(WhileContext forward_ctxt, GradLoopState outer_grad_state_)
        {
            // Information needed by backprop.
            _unused_exits       = new List<Tensor>();
            _deferred_exits     = new List<Tensor>();
            _forward_loop_exits = list(forward_ctxt.loop_exits);
            pending_exits_count = len(forward_ctxt.loop_exits);

            _outer_grad_state = outer_grad_state_;

            // The outer forward context is only known through the outer grad state, so it is
            // null whenever this loop is not nested.
            // NOTE(review): upstream TensorFlow appears to fall back to forward_ctxt.outer_context
            // for non-nested loops; confirm whether dropping that fallback is intended.
            ControlFlowContext outer_forward_ctxt = outer_grad_state_?.forward_context;

            // Add the forward loop counter.
            // (The Python original does this under `with forward_ctxt._graph.as_default():`;
            // that scope is not reproduced here.)
            Tensor cnt, forward_index;
            outer_forward_ctxt?.Enter();
            try
            {
                (cnt, forward_index) = forward_ctxt.AddForwardLoopCounter(outer_grad_state);
            }
            finally
            {
                // Exit even if counter creation throws, so Enter/Exit always stay paired.
                outer_forward_ctxt?.Exit();
            }

            _forward_context = forward_ctxt;
            _forward_index   = forward_index;

            // Add the backprop WhileContext, and the backprop loop counter.
            if (outer_grad_state != null)
            {
                // This is a nested loop. Remember the iteration counts for each
                // execution of this inner loop.
                throw new NotImplementedException("GradLoopState");
            }

            outer_forward_ctxt?.Enter();
            try
            {
                // The backprop loop mirrors the forward loop's configuration and name.
                _grad_context = new WhileContext(
                    maximum_iterations: forward_ctxt.maximum_iterations,
                    parallel_iterations: forward_ctxt.parallel_iterations,
                    back_prop: forward_ctxt.back_prop,
                    swap_memory: forward_ctxt.swap_memory,
                    name: forward_ctxt.Name,
                    grad_state: this);
                _grad_index = _grad_context.AddBackpropLoopCounter(cnt, outer_grad_state);
            }
            finally
            {
                // Keep Enter/Exit paired even if building the backprop context throws.
                outer_forward_ctxt?.Exit();
            }
        }