コード例 #1
0
        /// <summary>
        /// Add the backprop loop that controls the iterations.
        /// </summary>
        /// <remarks>
        /// Builds a countdown loop inside this context. The counter enters the loop and is merged
        /// with its NextIteration back edge. The loop keeps running while the counter is &gt;= 1,
        /// with the counter decremented by one on each pass. The false branch of the switch is
        /// routed to the loop exit.
        /// </remarks>
        /// <param name="count">The number of iterations for backprop.</param>
        /// <param name="outer_grad_state">The outer grad state. Null if not nested.</param>
        /// <returns>The loop index: the decremented counter as fed through NextIteration.</returns>
        /// <exception cref="NotImplementedException">
        /// Thrown when <paramref name="outer_grad_state"/> is not null; nested gradient loops are not supported yet.
        /// </exception>
        public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state)
        {
            // Reject the unsupported nested case before touching the graph or this context.
            // Previously the check ran after Enter(), so the exception skipped the matching
            // Exit() and left the context entered with partially built loop ops.
            if (outer_grad_state != null)
            {
                // TODO: once supported, force the stack pops of the i-th execution of an inner loop
                // to be ordered before the pops of the (i+1)-th execution of the same inner loop:
                //   outer_grad_state.grad_sync._add_control_input(final_zero.op);
                throw new NotImplementedException("outer_grad_state");
            }

            Tensor one = null;
            var    in_separate_functions = count.graph != ops.get_default_graph();

            if (in_separate_functions)
            {
                // Brings the count into this graph
                count = array_ops.identity(count);
            }
            else
            {
                // Same graph: create the constant before Enter(), i.e. outside the loop context.
                one = constant_op.constant(1, name: "b_count");
            }

            Enter();
            AddName(count.name);
            var enter_count = _Enter(
                count,
                _name,
                is_constant: false,
                parallel_iterations: _parallel_iterations,
                name: "b_count");

            loop_enters.append(enter_count);

            // The second merge input is a placeholder; it is replaced below by the back edge.
            var merge_count = merge(new[] { enter_count, enter_count })[0];

            _pivot_for_pred = merge_count;
            if (in_separate_functions)
            {
                // Different graph: the constant is only created now, inside the loop context.
                one = constant_op.constant(1, name: "b_count");
            }
            // Continue looping while the counter is still >= 1.
            var pred = math_ops.greater_equal(merge_count, one);

            _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count");
            var switch_count = @switch(merge_count, _pivot);

            // True branch (index 1): decrement and feed back into the merge.
            var index = math_ops.subtract(switch_count[1], one);

            _pivot_for_body = index;
            var next_count = _NextIteration(index);

            merge_count.op._update_input(1, next_count);

            // False branch (index 0): leave the loop.
            var final_zero = exit(switch_count[0], name: "b_count");

            loop_exits.append(final_zero);
            ExitResult(new[] { final_zero });
            Exit();
            return(next_count);
        }
コード例 #2
0
        /// <summary>
        /// Creates a while-loop control-flow context, restoring it from a serialized
        /// <see cref="WhileContextDef"/> when one is supplied.
        /// </summary>
        /// <remarks>
        /// NOTE(review): when <paramref name="context_def"/> is null this constructor performs no
        /// initialization from <paramref name="parallel_iterations"/>, <paramref name="back_prop"/>,
        /// <paramref name="swap_memory"/> or <paramref name="name"/>; those arguments are currently
        /// unused. TODO confirm whether an argument-based initializer should be invoked here.
        /// </remarks>
        /// <param name="parallel_iterations">Currently unused (see remarks).</param>
        /// <param name="back_prop">Currently unused (see remarks).</param>
        /// <param name="swap_memory">Currently unused (see remarks).</param>
        /// <param name="name">Currently unused (see remarks).</param>
        /// <param name="grad_state">Stored on the context after any proto-based initialization.</param>
        /// <param name="context_def">Serialized context to restore from; null to skip restoration.</param>
        /// <param name="import_scope">Forwarded to the proto-based initializer.</param>
        public WhileContext(int parallel_iterations     = 10,
                            bool back_prop              = true,
                            bool swap_memory            = false,
                            string name                 = "while_context",
                            GradLoopState grad_state    = null,
                            WhileContextDef context_def = null,
                            string import_scope         = null)
        {
            var restore_from_proto = context_def != null;
            if (restore_from_proto)
                _init_from_proto(context_def, import_scope: import_scope);

            // Assigned last so it is set regardless of which path was taken.
            _grad_state = grad_state;
        }
コード例 #3
0
        /// <summary>
        /// Creates a while-loop control-flow context, either restored from a serialized
        /// <see cref="WhileContextDef"/> or configured from the supplied loop arguments.
        /// </summary>
        /// <param name="maximum_iterations">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
        /// <param name="parallel_iterations">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
        /// <param name="back_prop">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
        /// <param name="swap_memory">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
        /// <param name="name">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
        /// <param name="grad_state">Stored on the context after initialization, on either path.</param>
        /// <param name="context_def">Serialized context to restore from; null to build from the arguments.</param>
        /// <param name="import_scope">Forwarded to <c>_init_from_proto</c> when restoring.</param>
        public WhileContext(Tensor maximum_iterations = null,
                            int parallel_iterations   = 10,
                            bool back_prop            = true,
                            bool swap_memory          = false,
                            string name = "while_context",
                            GradLoopState grad_state    = null,
                            WhileContextDef context_def = null,
                            string import_scope         = null)
        {
            if (context_def == null)
            {
                // Fresh context: run __init__() first, then apply the loop arguments.
                __init__();
                _init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name);
            }
            else
            {
                // Restored context: the loop arguments (maximum_iterations .. name) are not used here.
                _init_from_proto(context_def, import_scope: import_scope);
            }

            _grad_state = grad_state;
        }
コード例 #4
0
        /// <summary>
        /// Builds the forward-pass loop counter: a tensor that starts at zero and is
        /// incremented by one each time the loop body runs.
        /// </summary>
        /// <param name="outer_grad_state">
        /// Grad state of an enclosing loop, or null when not nested. Nested loops are not supported yet.
        /// </param>
        /// <returns>
        /// A pair of (total number of iterations, available after the loop exits;
        /// next-iteration value of the counter inside the loop).
        /// </returns>
        /// <exception cref="NotImplementedException">
        /// Thrown when <paramref name="outer_grad_state"/> is not null.
        /// </exception>
        public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state)
        {
            var initial = constant_op.constant(0, name: "f_count");

            // Ordering against an outer loop's counter has not been ported yet.
            if (outer_grad_state != null)
                throw new NotImplementedException("AddForwardLoopCounter");

            Enter();
            AddName(initial.name);

            var counter_enter = _Enter(initial,
                                       _name,
                                       is_constant: false,
                                       parallel_iterations: _parallel_iterations,
                                       name: "f_count");
            _loop_enters.Add(counter_enter);

            // The second merge input is a placeholder that is swapped for the back edge below.
            var counter  = merge(new[] { counter_enter, counter_enter })[0];
            var branches = @switch(counter, _pivot);

            // True branch (index 1): bump the counter and feed it back into the merge.
            var bumped       = math_ops.add(branches[1], 1);
            var next_counter = _NextIteration(bumped);
            counter.op._update_input(1, next_counter);

            // False branch (index 0): the final count leaves the loop.
            var iteration_total = exit(branches[0], name: "f_count");
            loop_exits.append(iteration_total);

            ExitResult(new[] { iteration_total });
            Exit();

            return (iteration_total, next_counter);
        }