/// <summary>
/// Create the state for all the while loops involved in one gradients().
/// </summary>
/// <param name="between_op_list">The ops on paths between the ys and the xs, as a list.</param>
/// <param name="between_ops">The ops on paths between the ys and the xs, for membership tests.</param>
/// <param name="colocate_gradients_with_ops">Whether to colocate gradient ops with the corresponding forward ops.</param>
public static ControlFlowState MaybeCreateControlFlowState(List<Operation> between_op_list,
    List<Operation> between_ops,
    bool colocate_gradients_with_ops)
{
    ControlFlowState loop_state = null;
    foreach (var op in between_op_list)
    {
        if (IsLoopExit(op))
        {
            // Lazily create the state the first time a while-loop Exit op is seen.
            if (loop_state == null)
                loop_state = new ControlFlowState();
            if (colocate_gradients_with_ops)
                ops.colocate_with(op);
            loop_state.AddWhileContext(op, between_op_list, between_ops);
        }
    }
    return loop_state;
}
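// A minimal sketch of how the returned state is consumed (hedged: the calls
// below mirror the Enter/ExitGradWhileContext usage in _GradientsHelper;
// `op` stands in for an op dequeued during backprop):
//
//   var loop_state = MaybeCreateControlFlowState(between_op_list, between_ops, false);
//   if (loop_state != null)
//       loop_state.EnterGradWhileContext(op, before: true);
//   var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
//   if (loop_state != null)
//       loop_state.ExitGradWhileContext(op, before: true);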
public static Tensor[] _GradientsHelper(Tensor[] ys,
    Tensor[] xs,
    Tensor[] grad_ys = null,
    string name = "gradients",
    bool colocate_gradients_with_ops = false,
    bool gate_gradients = false,
    int aggregation_method = 0,
    Tensor[] stop_gradients = null,
    Graph src_graph = null)
{
    if (src_graph == null)
        src_graph = ops.get_default_graph();

    // If src_graph is a _FuncGraph (i.e. a function body), gather it and all
    // ancestor graphs. This is necessary for correctly handling captured values.
    var curr_graph = src_graph;

    if (stop_gradients == null)
        stop_gradients = new Tensor[0];
    if (grad_ys == null)
        grad_ys = new Tensor[ys.Length];

    /*
     * grads: op => list of gradients received on each output endpoint of the
     * op. The gradients for each endpoint are initially collected as a list.
     * When it is time to call the op's gradient function, for each endpoint we
     * aggregate the list of received gradients into an Add() operation if there
     * is more than one.
     */
    var grads = new Dictionary<string, List<List<Tensor>>>();
    Operation[] reachable_to_ops = null;
    ControlFlowState loop_state = null;
    Dictionary<string, int> pending_count = null;

    tf_with(ops.name_scope(name, "gradients", values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
    {
        string grad_scope = scope;
        // Get a uid for this call to gradients that can be used to help
        // cluster ops for compilation.
        var gradient_uid = curr_graph.unique_name("uid");
        ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
        xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
        grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);

        /*
         * The approach we take here is as follows: Create a list of all ops in the
         * subgraph between the ys and xs. Visit these ops in reverse order of ids
         * to ensure that when we visit an op the gradients w.r.t. its outputs have
         * been collected. Then aggregate these gradients if needed, call the op's
         * gradient function, and add the generated gradients to the gradients for
         * its inputs.
         */

        // Initialize the pending count for ops in the connected subgraph from ys
        // to the xs.
        var to_ops = ys.Select(x => x.op).ToList();
        var from_ops = xs.Select(x => x.op).ToList();
        var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
        (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);

        // Add the initial gradients for the ys.
        foreach (var (y, grad_y) in zip(ys, grad_ys))
            _SetGrad(grads, y, grad_y);

        // Initialize the queue with the ops in 'to_ops' that are ready.
        var queue = new Queue<Operation>();
        var to_ops_set = new List<Operation>();
        foreach (var op in to_ops)
        {
            // 'ready' handles the case where one output gradient relies on
            // another output's gradient.
            if (!pending_count.ContainsKey(op.name))
                pending_count[op.name] = 0;
            bool ready = pending_count[op.name] == 0;
            if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op))
            {
                to_ops_set.Add(op);
                queue.Enqueue(op);
            }
        }

        if (loop_state != null)
        {
            var loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set);
            foreach (var y in loop_exits)
            {
                //if (IsTrainable(y))
                throw new NotImplementedException("unused loop exits are not handled yet");
            }
        }

        var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
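        // At this point the shape of `grads` (see the comment on its
        // declaration above) is: op name -> one slot per output endpoint ->
        // the list of gradients received on that endpoint. For example, for
        // a hypothetical op named "add", grads["add"][0] collects every
        // gradient Tensor flowing into output 0 of "add"; _AggregatedGrads
        // sums that list when "add" is dequeued below.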
        while (queue.Count > 0)
        {
            // Generate the gradient subgraph for op.
            var op = queue.Dequeue();
            _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
            {
                if (loop_state != null)
                    loop_state.EnterGradWhileContext(op, before: true);
                var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);
                if (loop_state != null)
                    loop_state.ExitGradWhileContext(op, before: true);

                Tensor[] in_grads = null;
                Func<Operation, Tensor[], Tensor[]> grad_fn = null;
                var is_partitioned_call = _IsPartitionedCall(op);
                var is_func_call = false;
                var has_out_grads = out_grads.Exists(x => x != null);
                if (has_out_grads && !stop_ops.Contains(op))
                {
                    // A grad_fn must be defined, either as a function or as None
                    // for ops that do not have gradients.
                    try
                    {
                        grad_fn = ops.get_gradient_function(op);
                    }
                    catch (LookupError)
                    {
                        if (is_func_call)
                        {
                            // TODO: function-call ops (partitioned or not) should
                            // fall back to SymbolicGradient; not implemented here.
                            if (is_partitioned_call)
                            {
                            }
                            else
                            {
                            }
                        }
                        else
                        {
                            throw new LookupError($"No gradient defined for operation '{op.name}' (op type: {op.type})");
                        }
                    }
                }

                if (loop_state != null)
                    loop_state.EnterGradWhileContext(op, before: false);

                if ((is_func_call || grad_fn != null) && has_out_grads)
                {
                    // NOTE: If _AggregatedGrads didn't compute a value for the i'th
                    // output, it means that the cost does not depend on output[i],
                    // therefore dC/doutput[i] is 0.
                    foreach (var (i, out_grad) in enumerate(out_grads))
                    {
                        if (out_grad == null && (grad_fn == null || _IsTrainable(op.outputs[i])))
                        {
                            // Only trainable outputs or outputs for a function call that
                            // will use SymbolicGradient get a zero gradient. Gradient
                            // functions should ignore the gradient for other outputs.
                            if (loop_state != null)
                                out_grads[i] = new List<Tensor> { loop_state.ZerosLike(op, i) };
                            else
                                out_grads[i] = new List<Tensor> { control_flow_ops.ZerosLikeOutsideLoop(op, i) };
                        }
                    }

                    tf_with(ops.name_scope(op.name + "_grad"), scope1 =>
                    {
                        if (grad_fn != null)
                        {
                            in_grads = _MaybeCompile(grad_scope, op, out_grads.Where(x => x != null).Select(x => x[0]).ToArray(), null, grad_fn);
                        }
                        else
                        {
                            throw new NotImplementedException("lambda: _SymGrad(op, out_grads)");
                        }
                        _VerifyGeneratedGradients(in_grads, op);
                        if (gate_gradients && in_grads.Count(x => x != null) > 1)
                        {
                            ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
                            in_grads = control_flow_ops.tuple(in_grads);
                        }
                    });
                }
                else
                {
                    // If no grad_fn is defined or none of out_grads is available,
                    // just propagate a list of None backwards.
                    in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                }

                var inputs = _NonEagerInputs(op, xs).ToList();
                foreach (var (t_in, in_grad) in zip(inputs, in_grads))
                {
                    if (in_grad != null)
                    {
                        if (in_grad.Tag == null && // maybe an IndexedSlices
                            t_in.dtype != TF_DataType.TF_RESOURCE)
                        {
                            in_grad.set_shape(t_in.TensorShape);
                        }
                        _SetGrad(grads, t_in, in_grad);
                    }
                }

                if (loop_state != null)
                    loop_state.ExitGradWhileContext(op, before: false);
            }

            // Update pending count for the inputs of op and enqueue ready ops.
            _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
        }
    });

    // Collect the accumulated gradient for each x. (Assumes a _GetGrad
    // helper, the read counterpart of _SetGrad used above.)
    return xs.Select(x => _GetGrad(grads, x)).ToArray();
}
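// Usage sketch: _GradientsHelper is the graph-mode worker behind the public
// gradients() entry point. Assuming the usual TensorFlow.NET operator
// overloads (the variable names below are illustrative, not from this file),
// computing dy/dx for y = x * x looks roughly like:
//
//   var x = tf.constant(3.0f);
//   var y = x * x;
//   var dy_dx = _GradientsHelper(new[] { y }, new[] { x });
//   // evaluating dy_dx[0] in a session yields 6.0f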