Exemplo n.º 1
0
        /// <summary>
        /// Add the loop termination condition and body to the graph.
        /// </summary>
        /// <typeparam name="TItem">Item type carried by the loop variables.</typeparam>
        /// <param name="pred">Callback that builds the loop-condition tensor from the current loop variables.</param>
        /// <param name="body">Callback that builds the next loop variables from the current ones.</param>
        /// <param name="loop_vars">
        /// Initial loop variables. They are flattened and any TensorArrays are replaced by
        /// their flow tensors (via <c>_convert_tensorarray_to_flow</c>) before the loop is built.
        /// </param>
        /// <param name="shape_invariants">
        /// Shape invariants for the flattened loop variables. When null, one invariant per
        /// flattened variable is derived with <c>_get_shape_invariant</c>.
        /// </param>
        /// <param name="return_same_structure">
        /// NOTE(review): currently not read by this implementation.
        /// </param>
        /// <returns>
        /// The loop's exit tensors, packed into the structure of the body result with
        /// TensorArrays restored, then cast with <c>as Tensor[]</c>. This is null if the
        /// packed result is not a <see cref="Tensor"/> array.
        /// </returns>
        internal Tensor[] BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
                                           Func<LoopVar<TItem>, LoopVar<TItem>> body,
                                           LoopVar<TItem> loop_vars,
                                           TensorShape[] shape_invariants,
                                           bool return_same_structure)
        {
            // Keep original_loop_vars to identify which are TensorArrays.
            var original_loop_vars = loop_vars;

            // Convert TensorArrays to their flow variables.
            var loop_vars_tensors = nest.flatten2(loop_vars)
                                    .Select(x => _convert_tensorarray_to_flow(x))
                                    .ToArray();

            if (shape_invariants == null)
            {
                shape_invariants = loop_vars_tensors
                                   .Select(x => _get_shape_invariant(x as Tensor))
                                   .ToArray();
            }

            // Build the loop inside this context. Exit() is guaranteed even if _BuildLoop
            // (or the user's pred/body callbacks it invokes) throws, so the context is never
            // left entered. The Python WhileContext.BuildLoop uses try/finally for the same reason.
            var (original_body_result, exit_vars) = InLoopContext(
                () => _BuildLoop(pred, body, original_loop_vars, loop_vars_tensors, shape_invariants));

            // NOTE(review): the Python implementation flattens the body result here;
            // this port passes the structured result through unchanged.
            var flat_result = original_body_result;

            // Restore TensorArrays from the exit flows, then repack into the body's structure.
            var exit_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_result, exit_vars);
            var packed_exit_vars = nest.pack_sequence_as(
                structure: original_body_result,
                flat_sequence: exit_vars_with_tensor_arrays);

            return packed_exit_vars as Tensor[];

            // Runs `build` between Enter() and Exit(), always calling Exit() afterwards.
            T InLoopContext<T>(Func<T> build)
            {
                Enter();
                try
                {
                    return build();
                }
                finally
                {
                    Exit();
                }
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Add the loop termination condition and body to the graph.
        /// </summary>
        /// <typeparam name="TItem">
        /// Item type carried by <see cref="LoopVar{TItem}"/>; instances are rebuilt from the
        /// merge outputs via <c>FromMergeVars</c>.
        /// </typeparam>
        /// <param name="pred">Builds the loop-condition tensor from the current loop variables.</param>
        /// <param name="body">Builds the next iteration's loop variables from the current ones.</param>
        /// <param name="original_loop_vars">
        /// Loop variables in their original structure (possibly containing TensorArrays). They are used
        /// to convert flow tensors back into TensorArrays before calling <paramref name="pred"/> and
        /// <paramref name="body"/>.
        /// </param>
        /// <param name="loop_vars">
        /// Flattened loop-variable tensors that feed the Enter ops (presumably with TensorArrays
        /// already replaced by their flows by the caller — confirm).
        /// </param>
        /// <param name="shape_invariants">
        /// Passed to <c>_SetShapeInvariants</c>. When null, the Enter ops are created with
        /// <c>use_input_shape: true</c>.
        /// </param>
        /// <returns>
        /// The structured result returned by <paramref name="body"/>, and the Exit tensors
        /// (one per switch output), which are also stored in <c>_loop_exits</c>.
        /// </returns>
        private (LoopVar<TItem>, Tensor[]) _BuildLoop<TItem>(Func<LoopVar<TItem>, Tensor> pred,
                                                            Func<LoopVar<TItem>, LoopVar<TItem>> body,
                                                            LoopVar<TItem> original_loop_vars,
                                                            Tensor[] loop_vars,
                                                            TensorShape[] shape_invariants) where TItem : IFromMergeVars<TItem>, new()
        {
            var flat_loop_vars = nest.flatten2(original_loop_vars)
                                 .Select(x => (ITensorOrTensorArray)x)
                                 .ToArray();

            // Let the context know the loop variables so the loop variables
            // would be added in the outer contexts properly.
            _InitializeValues(loop_vars);
            var real_vars = loop_vars;

            // Create one Enter op per loop variable. control_dependencies(null) presumably
            // clears any pending control dependencies for these ops, as in TF Python.
            Tensor[] enter_vars = null;
            tf_with(ops.control_dependencies(null), delegate
            {
                enter_vars = real_vars.Select(x => control_flow_ops._Enter(x,
                                                                           _name,
                                                                           is_constant: false,
                                                                           parallel_iterations: _parallel_iterations,
                                                                           use_input_shape: shape_invariants == null))
                             .ToArray();

                foreach (var x in enter_vars)
                {
                    x.graph.prevent_feeding(x);
                    if (_outer_context != null)
                    {
                        _outer_context.AddInnerOp(x.op);
                    }
                }
            });

            // TODO: port the control-pivot lookup. TensorFlow's Python WhileContext walks the
            // enclosing contexts for the closest non-null control pivot and adds it as a control
            // input to loop-constant Enter ops. Until that is ported, no control pivot is applied.
            // Any loop over the outer contexts must advance to the next context each iteration;
            // otherwise it never terminates for nested loops.

            _SetShapeInvariants(real_vars, enter_vars, shape_invariants);

            // Fix the control inputs and control flow context of these enter ops.
            _FixControlInputsAndContext(enter_vars);
            _InitializeValues(enter_vars);
            _loop_enters = enter_vars.ToList();

            // Each Merge initially takes the Enter output twice; the second input is
            // expected to be replaced by the back edge in _AddNextAndBackEdge.
            var merge_vars = enter_vars
                             .Select(x => merge(new[] { x, x }))
                             .Select(m => (Tensor)m)
                             .ToArray();

            _pivot_for_pred = merge_vars[0];

            // Build the graph for pred.
            var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars);
            var packed_vars = new LoopVar<TItem>(
                (Tensor)merge_vars_with_tensor_arrays[0],
                new TItem().FromMergeVars(merge_vars_with_tensor_arrays));
            var pp = pred(packed_vars);
            var c  = ops.convert_to_tensor(pp);

            _pivot = gen_control_flow_ops.loop_cond(c, name: "LoopCond");
            var switch_vars = merge_vars.Select(x => _SwitchRefOrTensor(x, _pivot))
                              .ToArray();

            // Build the graph for body from the second output of each switch.
            var vars_for_body = switch_vars.Select(x => _Identity(x[1])).ToArray();

            _pivot_for_body = vars_for_body[0];
            // Convert TensorArray flow variables inside the context back into
            // their associated TensorArrays for calling the body.
            var vars_for_body_with_tensor_arrays = _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body);
            var packed_vars_for_body             = nest.pack_sequence_as2(original_loop_vars, vars_for_body_with_tensor_arrays);
            // NOTE(review): pre/post summary snapshots are taken but not used yet; the summary
            // handling from the Python implementation has not been ported.
            var pre_summaries  = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
            var body_result    = body(packed_vars_for_body);
            var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);

            // Store body_result to keep track of TensorArrays returned by body.
            var original_body_result = body_result;
            // Convert TensorArrays returned by body into their flow variables.
            var result = nest.flatten2(body_result)
                         .Select(x => _convert_tensorarray_to_flow(x))
                         .ToArray();
            // result = ops.convert_n_to_tensor_or_composite(result);
            var next_vars = new List<Tensor>();

            foreach (var (m, v) in zip(merge_vars, result))
            {
                next_vars.Add(_AddNextAndBackEdge(m, v));
            }

            // Add the exit ops from the first output of each switch.
            var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();

            _loop_exits = exit_vars;

            // Exit the loop.
            // ExitResult(exit_vars);
            return (original_body_result, exit_vars.ToArray());
        }