/// <summary>
/// Add the loop termination condition and body to the graph.
/// </summary>
/// <typeparam name="TItem">
/// Payload type carried by the loop variables. It must satisfy the same constraint as
/// <c>_BuildLoop</c>, which rebuilds it from the merged tensors via <c>FromMergeVars</c>.
/// </typeparam>
/// <param name="pred">Builds the loop-condition tensor from the current loop variables.</param>
/// <param name="body">Builds the next iteration's loop variables from the current ones.</param>
/// <param name="loop_vars">
/// Initial loop variables. They are flattened and each element is passed through
/// <c>_convert_tensorarray_to_flow</c> before the loop is built.
/// </param>
/// <param name="shape_invariants">
/// Shape invariants for the flattened loop variables. When <c>null</c>, one invariant per
/// flattened variable is derived with <c>_get_shape_invariant</c>.
/// </param>
/// <param name="return_same_structure">Currently ignored by this implementation.</param>
/// <returns>
/// The loop's exit values packed into the structure of the body result, cast with
/// <c>as Tensor[]</c>. The result is therefore <c>null</c> if the packed value is not a <c>Tensor[]</c>.
/// </returns>
internal Tensor[] BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
    Func<LoopVar<TItem>, LoopVar<TItem>> body,
    LoopVar<TItem> loop_vars,
    TensorShape[] shape_invariants,
    bool return_same_structure) where TItem : IFromMergeVars<TItem>, new()
{
    // Keep original_loop_vars to identify which are TensorArrays
    var original_loop_vars = loop_vars;
    // Convert TensorArrays to their flow variables
    var loop_vars_tensors = nest.flatten2(loop_vars)
        .Select(x => _convert_tensorarray_to_flow(x))
        .ToArray();

    if (shape_invariants == null)
    {
        shape_invariants = loop_vars_tensors
            .Select(x => _get_shape_invariant(x as Tensor))
            .ToArray();
    }

    LoopVar<TItem> original_body_result;
    Tensor[] exit_vars;

    Enter();
    try
    {
        (original_body_result, exit_vars) = _BuildLoop(
            pred, body, original_loop_vars, loop_vars_tensors, shape_invariants);
    }
    finally
    {
        // Always leave this context, even if building the loop throws.
        // Otherwise the context would stay entered for all later graph construction.
        Exit();
    }

    // Flatten the body result exactly as _BuildLoop flattens the loop vars before calling
    // _convert_flows_to_tensorarrays, so exit tensors are matched element-wise
    // against the flat body outputs.
    var flat_result = nest.flatten2(original_body_result)
        .Select(x => (ITensorOrTensorArray)x)
        .ToArray();

    var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars);
    var packed_exit_vars = nest.pack_sequence_as(
        structure: original_body_result,
        flat_sequence: exit_vars_with_tensor_arrays);

    // NOTE(review): `as` yields null if pack_sequence_as does not return a Tensor[] here;
    // confirm against callers whether a hard failure would be preferable.
    return packed_exit_vars as Tensor[];
}
/// <summary>
/// Add the loop termination condition and body to the graph.
/// </summary>
/// <typeparam name="TItem">
/// Payload type of the loop variables. A fresh instance is used to rebuild the payload
/// from the merged tensors via <c>FromMergeVars</c>.
/// </typeparam>
/// <param name="pred">Builds the loop-condition tensor from the merged loop variables.</param>
/// <param name="body">Builds the next-iteration loop variables from the variables passed into the body.</param>
/// <param name="original_loop_vars">
/// The structured loop variables as supplied by the caller. They are used to restore
/// TensorArrays from flow tensors and to repack flat tensors for <paramref name="body"/>.
/// </param>
/// <param name="loop_vars">Flat loop-variable tensors; one Enter/Merge/Switch/Exit chain is built per element.</param>
/// <param name="shape_invariants">
/// Shape invariants applied to the Enter tensors. When <c>null</c>, the Enter ops are
/// created with <c>use_input_shape: true</c>.
/// </param>
/// <returns>
/// A tuple of the body's structured result and the loop's Exit tensors,
/// with one Exit tensor per element of <paramref name="loop_vars"/>.
/// </returns>
private (LoopVar<TItem>, Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
    Func<LoopVar<TItem>, LoopVar<TItem>> body,
    LoopVar<TItem> original_loop_vars,
    Tensor[] loop_vars,
    TensorShape[] shape_invariants) where TItem : IFromMergeVars<TItem>, new()
{
    var flat_loop_vars = nest.flatten2(original_loop_vars)
        .Select(x => (ITensorOrTensorArray)x)
        .ToArray();

    // Let the context know the loop variables so the loop variables
    // would be added in the outer contexts properly.
    _InitializeValues(loop_vars);
    var real_vars = loop_vars;
    Tensor[] enter_vars = null;
    tf_with(ops.control_dependencies(null), delegate
    {
        enter_vars = real_vars.Select(x => control_flow_ops._Enter(x,
                _name,
                is_constant: false,
                parallel_iterations: _parallel_iterations,
                use_input_shape: shape_invariants == null))
            .ToArray();

        foreach (var x in enter_vars)
        {
            x.graph.prevent_feeding(x);
            if (_outer_context != null)
            {
                _outer_context.AddInnerOp(x.op);
            }
        }
    });

    // NOTE(review): The reference (Python) implementation walks the outer contexts here to
    // find the closest enclosing non-null control pivot and adds it as a control input to
    // loop-constant Enter ops. That lookup is not ported yet.
    // The previous placeholder `while (outer_context != null && control_pivot == null) { }`
    // never advanced `outer_context`, so it looped forever whenever `_outer_context` was
    // non-null (e.g. a nested loop). It has been removed until the lookup is implemented.

    _SetShapeInvariants(real_vars, enter_vars, shape_invariants);

    // Fix the control inputs and control flow context of these enter ops.
    _FixControlInputsAndContext(enter_vars);
    _InitializeValues(enter_vars);
    _loop_enters = enter_vars.ToList();

    // Each Merge initially takes the Enter tensor for both inputs.
    // The second input is later replaced by the NextIteration back edge in _AddNextAndBackEdge.
    var merge_vars = enter_vars
        .Select(x => merge(new[] { x, x }))
        .Select(m => (Tensor)m)
        .ToArray();

    _pivot_for_pred = merge_vars[0];

    // Build the graph for pred.
    var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
    var packed_vars = new LoopVar<TItem>(
        (Tensor)merge_vars_with_tensor_arrays[0],
        new TItem().FromMergeVars(merge_vars_with_tensor_arrays));
    var pp = pred(packed_vars);
    var c = ops.convert_to_tensor(pp);
    _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
    var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
        .ToArray();

    // Build the graph for body.
    // switch output [1] feeds the body; output [0] feeds the Exit ops below.
    var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();
    _pivot_for_body = vars_for_body[0];
    // Convert TensorArray flow variables inside the context back into
    // their associated TensorArrays for calling the body.
    var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
    var packed_vars_for_body = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays);
    // NOTE(review): pre/post summaries are captured but not yet used. They presumably
    // mirror the reference implementation's handling of summaries created in the body.
    var pre_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
    var body_result = body(packed_vars_for_body);
    var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

    // Store body_result to keep track of TensorArrays returned by body
    var original_body_result = body_result;
    // Convert TensorArrays returned by body into their flow variables
    var result = nest.flatten2(body_result)
        .Select(x => _convert_tensorarray_to_flow(x))
        .ToArray();
    // result = ops.convert_n_to_tensor_or_composite(result);

    // Wire each body output back into its Merge op.
    var next_vars = new List<Tensor>();
    foreach (var (m, v) in zip(merge_vars, result))
    {
        next_vars.Add(_AddNextAndBackEdge(m, v));
    }

    // Add the exit ops.
    var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
    _loop_exits = exit_vars;

    // Exit the loop.
    // ExitResult(exit_vars);

    return (original_body_result, exit_vars.ToArray());
}