/// <summary>
/// Add the backprop loop that controls the iterations.
/// </summary>
/// <remarks>
/// Builds an Enter -> Merge -> Switch -> NextIteration -> Exit cycle inside this context.
/// The counter starts at <paramref name="count"/>, the loop predicate is
/// <c>counter &gt;= 1</c>, and each iteration subtracts one from the counter.
/// </remarks>
/// <param name="count">The number of iterations for backprop.</param>
/// <param name="outer_grad_state">The outer grad state. None if not nested.</param>
/// <returns>The loop index (the NextIteration output fed back into the Merge).</returns>
/// <exception cref="NotImplementedException">
/// Thrown when <paramref name="outer_grad_state"/> is non-null; nested gradient loops are not supported yet.
/// </exception>
public Tensor AddBackpropLoopCounter(Tensor count, GradLoopState outer_grad_state)
{
    // Nested loops are unsupported. Fail before Enter() and before any op is created so a
    // failure does not leave this context entered (no matching Exit()) or a half-built loop.
    if (outer_grad_state != null)
    {
        throw new NotImplementedException("outer_grad_state");
    }

    Tensor one = null;
    var in_separate_functions = count.graph != ops.get_default_graph();
    if (in_separate_functions)
    {
        // Brings the count into this graph
        count = array_ops.identity(count);
    }
    else
    {
        // Same-graph case: the constant is created before Enter(), i.e. outside this loop context.
        one = constant_op.constant(1, name: "b_count");
    }

    Enter();
    AddName(count.name);
    var enter_count = _Enter(
        count,
        _name,
        is_constant: false,
        parallel_iterations: _parallel_iterations,
        name: "b_count");
    loop_enters.append(enter_count);

    // Both Merge inputs start as the Enter output; input 1 is rewired to the
    // NextIteration output further down to close the loop.
    var merge_count = merge(new[] { enter_count, enter_count })[0];
    _pivot_for_pred = merge_count;

    if (in_separate_functions)
    {
        // Separate-graph case: the constant is created inside the entered context.
        one = constant_op.constant(1, name: "b_count");
    }

    // Keep looping while the counter is >= 1.
    var pred = math_ops.greater_equal(merge_count, one);
    _pivot = gen_control_flow_ops.loop_cond(pred, name: "b_count");
    var switch_count = @switch(merge_count, _pivot);

    // switch_count[1] feeds the loop body (decremented and fed back);
    // switch_count[0] feeds the loop exit.
    var index = math_ops.subtract(switch_count[1], one);
    _pivot_for_body = index;
    var next_count = _NextIteration(index);
    merge_count.op._update_input(1, next_count);

    var final_zero = exit(switch_count[0], name: "b_count");
    loop_exits.append(final_zero);

    // TODO: once nested loops are supported, force the stack pops of the i-th execution of an
    // inner loop to be ordered before the pops of the (i+1)-th execution of the same inner loop:
    // outer_grad_state.grad_sync._add_control_input(final_zero.op);

    ExitResult(new[] { final_zero });
    Exit();
    return next_count;
}
/// <summary>
/// Creates a <see cref="WhileContext"/>, either restored from <paramref name="context_def"/>
/// or initialized from the remaining arguments.
/// </summary>
/// <param name="parallel_iterations">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="back_prop">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="swap_memory">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="name">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="grad_state">Stored on the context as its gradient loop state.</param>
/// <param name="context_def">When non-null, the context is initialized from it via <c>_init_from_proto</c>
/// and the other construction arguments are not used.</param>
/// <param name="import_scope">Passed to <c>_init_from_proto</c>; only used when <paramref name="context_def"/> is non-null.</param>
public WhileContext(int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, string name = "while_context", GradLoopState grad_state = null, WhileContextDef context_def = null, string import_scope = null)
{
    if (context_def != null)
    {
        _init_from_proto(context_def, import_scope: import_scope);
    }
    else
    {
        // This branch used to be empty, so every construction argument was silently ignored.
        // Initialize the same way as the overload that accepts maximum_iterations, passing no
        // iteration limit (null is assumed to mean "unbounded" — confirm against _init_from_args).
        Tensor maximum_iterations = null;
        __init__();
        _init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name);
    }

    _grad_state = grad_state;
}
/// <summary>
/// Creates a <see cref="WhileContext"/>. When <paramref name="context_def"/> is supplied the context
/// is restored from it; otherwise it is built from the remaining arguments.
/// </summary>
/// <param name="maximum_iterations">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="parallel_iterations">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="back_prop">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="swap_memory">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="name">Forwarded to <c>_init_from_args</c> when <paramref name="context_def"/> is null.</param>
/// <param name="grad_state">Stored on the context as its gradient loop state.</param>
/// <param name="context_def">When non-null, the context is initialized from it via <c>_init_from_proto</c>.</param>
/// <param name="import_scope">Passed to <c>_init_from_proto</c>; only used when <paramref name="context_def"/> is non-null.</param>
public WhileContext(Tensor maximum_iterations = null, int parallel_iterations = 10, bool back_prop = true, bool swap_memory = false, string name = "while_context", GradLoopState grad_state = null, WhileContextDef context_def = null, string import_scope = null)
{
    var restore_from_proto = context_def != null;

    if (!restore_from_proto)
    {
        // Fresh context: base initialization first, then apply the construction arguments.
        __init__();
        _init_from_args(maximum_iterations, parallel_iterations, back_prop, swap_memory, name);
    }
    else
    {
        _init_from_proto(context_def, import_scope: import_scope);
    }

    // The gradient state is attached regardless of how the context was initialized.
    _grad_state = grad_state;
}
/// <summary>
/// Adds a loop that counts the number of iterations.
/// </summary>
/// <remarks>
/// The counter starts at 0. Switch output 1 is incremented by one and fed back through
/// NextIteration; switch output 0 leaves the loop through Exit as the iteration total.
/// The switch is driven by this context's existing <c>_pivot</c>.
/// </remarks>
/// <param name="outer_grad_state">The outer grad state. None if not nested.</param>
/// <returns>
/// A tuple of (number of iterations taken by the forward loop, loop index).
/// </returns>
/// <exception cref="NotImplementedException">
/// Thrown when <paramref name="outer_grad_state"/> is non-null; nested loops are not supported yet.
/// </exception>
public (Tensor, Tensor) AddForwardLoopCounter(GradLoopState outer_grad_state)
{
    var counter_init = constant_op.constant(0, name: "f_count");
    if (outer_grad_state != null)
    {
        throw new NotImplementedException("AddForwardLoopCounter");
    }

    Enter();
    AddName(counter_init.name);

    var entered = _Enter(counter_init, _name,
        is_constant: false,
        parallel_iterations: _parallel_iterations,
        name: "f_count");
    _loop_enters.Add(entered);

    // Both Merge inputs start as the Enter output; input 1 is replaced by the
    // NextIteration output below to close the cycle.
    var merged = merge(new[] { entered, entered })[0];
    var branches = @switch(merged, _pivot);

    var incremented = math_ops.add(branches[1], 1);
    var fed_back = _NextIteration(incremented);
    merged.op._update_input(1, fed_back);

    var total = exit(branches[0], name: "f_count");
    loop_exits.append(total);

    ExitResult(new[] { total });
    Exit();

    return (total, fed_back);
}