        /// <summary>
        /// Create the state for all the while loops involved in one gradients().
        /// </summary>
        /// <param name="between_op_list">Ops between the ys and xs, in the order they were reached.</param>
        /// <param name="between_ops">The same between ops, used for membership tests.</param>
        /// <param name="colocate_gradients_with_ops">Whether to colocate the loop state with the corresponding Exit ops.</param>
        /// <returns>A ControlFlowState, or null if no while loops are involved.</returns>
        public static ControlFlowState MaybeCreateControlFlowState(List <Operation> between_op_list, List <Operation> between_ops, bool colocate_gradients_with_ops)
        {
            ControlFlowState loop_state = null;

            foreach (var op in between_op_list)
            {
                if (IsLoopExit(op))
                {
                    if (loop_state == null)
                    {
                        loop_state = new ControlFlowState();
                    }
                    if (colocate_gradients_with_ops)
                    {
                        ops.colocate_with(op);
                    }
                    loop_state.AddWhileContext(op, between_op_list, between_ops);
                }
            }

            return loop_state;
        }
        /// <summary>
        /// Create the state for all the while loops involved in one gradients().
        /// </summary>
        /// <param name="between_op_list">Ops between the ys and xs, in the order they were reached.</param>
        /// <param name="between_ops">The same between ops, used for membership tests.</param>
        /// <param name="colocate_gradients_with_ops">Whether to colocate the loop state with the corresponding Exit ops.</param>
        /// <returns>A ControlFlowState, or null if no while loops are involved.</returns>
        public static ControlFlowState MaybeCreateControlFlowState(List <Operation> between_op_list, List <Operation> between_ops, bool colocate_gradients_with_ops)
        {
            ControlFlowState loop_state = null;

            foreach (var op in between_op_list)
            {
                if (IsLoopExit(op))
                {
                    if (loop_state == null)
                    {
                        loop_state = new ControlFlowState();
                    }
                    if (colocate_gradients_with_ops)
                    {
                        ops.colocate_with(op);
                    }
                    loop_state.AddWhileContext(op, between_op_list, between_ops);
                }
            }

            return loop_state;
        }
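For orientation, below is a minimal, self-contained sketch (not TensorFlow.NET code) of the lazy-creation pattern both snippets above follow: the state object is only allocated once the scan meets a loop Exit op, so callers can treat a null result as meaning no while loops are involved. State, MaybeCreateState, and the string op types are illustrative stand-ins for ControlFlowState, MaybeCreateControlFlowState, and IsLoopExit.

// Minimal sketch of the lazy-creation pattern used above (stand-in types,
// not the TensorFlow.NET API): allocate the shared state only when the
// first matching element is seen, and return null otherwise.
using System;
using System.Collections.Generic;

class LazyStateSketch
{
    // Stand-in for ControlFlowState; created at most once per scan.
    class State { public int Registered; }

    static State MaybeCreateState(IEnumerable<string> opTypes)
    {
        State state = null;
        foreach (var type in opTypes)
        {
            if (type == "Exit")              // stand-in for IsLoopExit(op)
            {
                if (state == null)
                    state = new State();     // create lazily on the first Exit op
                state.Registered++;          // stand-in for AddWhileContext(...)
            }
        }
        return state;                        // null => no while loops involved
    }

    static void Main()
    {
        Console.WriteLine(MaybeCreateState(new[] { "MatMul", "Add" }) == null);    // True
        Console.WriteLine(MaybeCreateState(new[] { "Exit", "Exit" }).Registered);  // 2
    }
}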
        public static Tensor[] _GradientsHelper(Tensor[] ys,
                                                Tensor[] xs,
                                                Tensor[] grad_ys = null,
                                                string name      = "gradients",
                                                bool colocate_gradients_with_ops = false,
                                                bool gate_gradients     = false,
                                                int aggregation_method  = 0,
                                                Tensor[] stop_gradients = null,
                                                Graph src_graph         = null)
        {
            if (src_graph == null)
            {
                src_graph = ops.get_default_graph();
            }

            // If src_graph is a _FuncGraph (i.e. a function body), gather it and all
            // ancestor graphs. This is necessary for correctly handling captured values.
            var curr_graph = src_graph;
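            // NOTE: ancestor-graph gathering from the Python implementation is not
            // shown here; curr_graph is simply the source graph.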

            if (stop_gradients == null)
            {
                stop_gradients = new Tensor[0];
            }
            if (grad_ys == null)
            {
                grad_ys = new Tensor[ys.Length];
            }

            // Iterate over the collected ops.

            /*
             * grads: op => list of gradients received on each output endpoint of the
             * op.  The gradients for each endpoint are initially collected as a list.
             * When it is time to call the op's gradient function, for each endpoint we
             * aggregate the list of received gradients into an Add() Operation if there
             * is more than one.
             */
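            // Keyed by op name; the outer list is indexed by output endpoint and the
            // inner list collects the gradients received for that endpoint.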
            var grads = new Dictionary <string, List <List <Tensor> > >();

            Operation[]              reachable_to_ops = null;
            ControlFlowState         loop_state       = null;
            Dictionary <string, int> pending_count    = null;

            tf_with(ops.name_scope(name, "gradients",
                                   values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
            {
                string grad_scope = scope;
                // Get a uid for this call to gradients that can be used to help
                // cluster ops for compilation.
                var gradient_uid = curr_graph.unique_name("uid");
                ys      = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
                xs      = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
                grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);

                /*
                 * The approach we take here is as follows: Create a list of all ops in the
                 * subgraph between the ys and xs.  Visit these ops in reverse order of ids
                 * to ensure that when we visit an op the gradients w.r.t its outputs have
                 * been collected.  Then aggregate these gradients if needed, call the op's
                 * gradient function, and add the generated gradients to the gradients for
                 * its input.
                 */

                // Initialize the pending count for ops in the connected subgraph from ys
                // to the xs.
                var to_ops            = ys.Select(x => x.op).ToList();
                var from_ops          = xs.Select(x => x.op).ToList();
                var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
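                // _PendingCount also returns the while-loop gradient state
                // (cf. MaybeCreateControlFlowState above); it is null when the
                // subgraph contains no Exit ops.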
                (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List <object>(), xs);

                // Add the initial gradients for the ys.
                foreach (var(y, grad_y) in zip(ys, grad_ys))
                {
                    _SetGrad(grads, y, grad_y);
                }

                // Initialize queue with to_ops.
                var queue = new Queue <Operation>();
                // Add the ops in 'to_ops' into the queue.
                var to_ops_set = new List <Operation>();
                foreach (var op in to_ops)
                {
                    // 'ready' handles the case where one output gradient relies on
                    // another output's gradient.
                    if (!pending_count.ContainsKey(op.name))
                    {
                        pending_count[op.name] = 0;
                    }
                    bool ready = pending_count[op.name] == 0;
                    if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op))
                    {
                        to_ops_set.Add(op);
                        queue.Enqueue(op);
                    }
                }

                if (loop_state != null)
                {
                    var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
                    foreach (var y in loop_exits)
                    {
                        //if(IsTrainable(y))
                        throw new NotImplementedException("");
                    }
                }

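                // Ops in stop_ops get no gradient subgraph; nulls are propagated to
                // their inputs instead (see the has_out_grads branch below).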
                var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
                while (queue.Count > 0)
                {
                    // generate gradient subgraph for op.
                    var op = queue.Dequeue();

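                    // Optionally pin the generated gradient ops to the same device as
                    // the forward op, controlled by colocate_gradients_with_ops.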
                    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                    {
                        if (loop_state != null)
                        {
                            loop_state.EnterGradWhileContext(op, before: true);
                        }
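                        // Gather and, if needed, sum the gradients received on each
                        // output of op before its gradient function is called.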
                        var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
                        if (loop_state != null)
                        {
                            loop_state.ExitGradWhileContext(op, before: true);
                        }

                        Tensor[] in_grads = null;
                        Func <Operation, Tensor[], Tensor[]> grad_fn = null;
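                        // Function-call ops (e.g. PartitionedCall) would need SymbolicGradient
                        // handling; this port leaves is_func_call hard-coded to false.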
                        var is_partitioned_call = _IsPartitionedCall(op);
                        var is_func_call        = false;
                        var has_out_grads       = out_grads.Exists(x => x != null);
                        if (has_out_grads && !stop_ops.Contains(op))
                        {
                            // A grad_fn must be defined, either as a function or as None
                            // for ops that do not have gradients.
                            try
                            {
                                grad_fn = ops.get_gradient_function(op);
                            }
                            catch (LookupError)
                            {
                                if (is_func_call)
                                {
                                    if (is_partitioned_call)
                                    {
                                    }
                                    else
                                    {
                                    }
                                }
                                else
                                {
                                    throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                                }
                            }
                        }

                        if (loop_state != null)
                        {
                            loop_state.EnterGradWhileContext(op, before: false);
                        }

                        if ((is_func_call || grad_fn != null) && has_out_grads)
                        {
                            // NOTE: If _AggregatedGrads didn't compute a value for the i'th
                            // output, it means that the cost does not depend on output[i],
                            // therefore dC/doutput[i] is 0.
                            foreach (var(i, out_grad) in enumerate(out_grads))
                            {
                                if (out_grad == null &&
                                    (grad_fn == null || _IsTrainable(op.outputs[i])))
                                {
                                    // Only trainable outputs or outputs for a function call that
                                    // will use SymbolicGradient get a zero gradient. Gradient
                                    // functions should ignore the gradient for other outputs.
                                    if (loop_state != null)
                                    {
                                        out_grads[i] = new List <Tensor> {
                                            loop_state.ZerosLike(op, i)
                                        };
                                    }
                                    else
                                    {
                                        out_grads[i] = new List <Tensor> {
                                            control_flow_ops.ZerosLikeOutsideLoop(op, i)
                                        };
                                    }
                                }
                            }

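                            // Invoke the op's registered gradient function under an
                            // "<op name>_grad" name scope.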
                            tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                            {
                                if (grad_fn != null)
                                {
                                    in_grads = _MaybeCompile(grad_scope,
                                                             op,
                                                             out_grads.Where(x => x != null).Select(x => x[0]).ToArray(),
                                                             null,
                                                             grad_fn);
                                }
                                else
                                {
                                    throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
                                }
                                _VerifyGeneratedGradients(in_grads, op);
                                if (gate_gradients && in_grads.Count(x => x != null) > 1)
                                {
                                    ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
                                    in_grads = control_flow_ops.tuple(in_grads);
                                }
                            });
                        }
                        else
                        {
                            // If no grad_fn is defined or none of out_grads is available,
                            // just propagate a list of None backwards.
                            in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                        }

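                        // Record the computed input gradients against the tensors feeding
                        // op, so their producing ops can consume them later.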
                        var inputs = _NonEagerInputs(op, xs).ToList();
                        foreach (var(t_in, in_grad) in zip(inputs, in_grads))
                        {
                            if (in_grad != null)
                            {
                                if (!(in_grad is null) &&
                                    in_grad.Tag == null &&     // might be an IndexedSlices
                                    t_in.dtype != TF_DataType.TF_RESOURCE)
                                {
                                    in_grad.set_shape(t_in.TensorShape);
                                }

                                _SetGrad(grads, t_in, in_grad);
                            }
                        }

                        if (loop_state != null)
                        {
                            loop_state.ExitGradWhileContext(op, before: false);
                        }
                    }

                    // Update pending count for the inputs of op and enqueue ready ops.
                    _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
                }
            });