예제 #1
0
    /// <summary>
    /// Parses the `cond` grammar rule (ANTLR-generated parser method).
    /// Uses adaptive prediction to choose between two alternatives:
    /// an if-without-else form (CondNoElseContext) and an if-with-else
    /// form (CondElseContext).
    /// </summary>
    /// <returns>The parse-tree context for the matched `cond` rule.</returns>
    public CondContext cond()
    {
        CondContext _localctx = new CondContext(Context, State);

        EnterRule(_localctx, 10, RULE_cond);
        try {
            State = 56;
            ErrorHandler.Sync(this);
            // ATN decision 2: look ahead to pick which alternative to take.
            switch (Interpreter.AdaptivePredict(TokenStream, 2, Context))
            {
            case 1:
                // Alternative 1: IF LP expr RP braceblock  (no else clause)
                _localctx = new CondNoElseContext(_localctx);
                EnterOuterAlt(_localctx, 1);
                {
                    State = 42;
                    Match(IF);
                    State = 43;
                    Match(LP);
                    State = 44;
                    expr();
                    State = 45;
                    Match(RP);
                    State = 46;
                    braceblock();
                }
                break;

            case 2:
                // Alternative 2: IF LP expr RP braceblock ELSE braceblock
                _localctx = new CondElseContext(_localctx);
                EnterOuterAlt(_localctx, 2);
                {
                    State = 48;
                    Match(IF);
                    State = 49;
                    Match(LP);
                    State = 50;
                    expr();
                    State = 51;
                    Match(RP);
                    State = 52;
                    braceblock();
                    State = 53;
                    Match(ELSE);
                    State = 54;
                    braceblock();
                }
                break;
            }
        }
        catch (RecognitionException re) {
            // Standard ANTLR recovery: record the exception on the context,
            // report it, and resynchronize the token stream.
            _localctx.exception = re;
            ErrorHandler.ReportError(this, re);
            ErrorHandler.Recover(this, re);
        }
        finally {
            ExitRule();
        }
        return(_localctx);
    }
예제 #2
0
        /// <summary>
        /// Returns the tensors produced by `true_fn` if `pred` is true, otherwise
        /// those produced by `false_fn`, by building a Switch/Merge conditional
        /// subgraph inside a "cond" name scope.
        /// </summary>
        /// <param name="pred">Scalar predicate tensor selecting the branch.</param>
        /// <param name="true_fn">Callable that builds the true branch.</param>
        /// <param name="false_fn">Callable that builds the false branch.</param>
        /// <param name="strict">Enables 'strict' mode (not consumed in this implementation).</param>
        /// <param name="name">Optional name-scope prefix for the created ops.</param>
        /// <returns>The merged output tensors of the two branches.</returns>
        public static Tensor[] cond <T>(Tensor pred,
                                        Func <T[]> true_fn  = null,
                                        Func <T[]> false_fn = null,
                                        bool strict         = false,
                                        string name         = null)
        {
            return(with(ops.name_scope(name, "cond", new { pred }), delegate
            {
                // Add the Switch to the graph.
                var(p_2, p_1) = @switch(pred, pred);
                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
                var pivot_2 = array_ops.identity(p_2, name: "switch_f");
                pred = array_ops.identity(pred, name: "pred_id");

                // Disable the fetching of tensors that are only on one branch of cond.
                foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
                {
                    tensor.op.graph.prevent_fetching(tensor.op);
                }

                // Build the graph for the true branch in a new context.
                var context_t = new CondContext(pred, pivot_1, branch: 1);
                // NOTE(review): BuildCondBranch(Func<T[]>) is assumed to return
                // (T[], Tensor[]) — verify against CondContext.
                T[] orig_res_t;
                Tensor[] res_t;
                try
                {
                    context_t.Enter();
                    (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
                }
                finally
                {
                    // BUGFIX: Exit() now also runs when BuildCondBranch throws, so the
                    // cond context is never left active (matches the non-generic cond).
                    context_t.Exit();
                }

                // Build the graph for the false branch in a new context.
                var context_f = new CondContext(pred, pivot_2, branch: 0);
                T[] orig_res_f;
                Tensor[] res_f;
                try
                {
                    context_f.Enter();
                    (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
                }
                finally
                {
                    context_f.Exit();
                }

                var res_t_flat = res_t;
                var res_f_flat = res_f;

                // Merge the corresponding outputs of the two branches pairwise.
                var merges = zip(res_f_flat, res_t_flat)
                             .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
                             .ToArray();

                merges = _convert_flows_to_tensorarrays(orig_res_t, merges);

                // Record both contexts so they can be serialized with the graph.
                ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t);
                ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f);

                return merges;
            }));
        }
예제 #3
0
    /// <summary>
    /// Parses the `cond` grammar rule (ANTLR-generated parser method):
    /// a parenthesized form `LP T__2 condGroup+ RP` containing one or more
    /// `condGroup` sub-rules.
    /// </summary>
    /// <returns>The parse-tree context for the matched `cond` rule.</returns>
    public CondContext cond()
    {
        CondContext _localctx = new CondContext(Context, State);

        EnterRule(_localctx, 8, RULE_cond);
        int _la;

        try {
            EnterOuterAlt(_localctx, 1);
            {
                State = 66; Match(LP);
                State = 67; Match(T__2);
                State = 69;
                ErrorHandler.Sync(this);
                _la = TokenStream.LA(1);
                // One-or-more loop: keep matching condGroup while the next
                // token starts another group (an LP).
                do
                {
                    {
                        {
                            State = 68; condGroup();
                        }
                    }
                    State = 71;
                    ErrorHandler.Sync(this);
                    _la = TokenStream.LA(1);
                } while (_la == LP);
                State = 73; Match(RP);
            }
        }
        catch (RecognitionException re) {
            // Standard ANTLR recovery: record, report, resynchronize.
            _localctx.exception = re;
            ErrorHandler.ReportError(this, re);
            ErrorHandler.Recover(this, re);
        }
        finally {
            ExitRule();
        }
        return(_localctx);
    }
예제 #4
0
        //    """Add the getter for an accumulated value in the grad context.
        //
        //    This is added to the backprop loop. Called in the grad context to
        //    get the value of an accumulated value. The stack pop op must be guarded
        //    by the pred of the controlling cond.
        //
        //    Args:
        //      history_value: The history (a stack) of a value.
        //      value: The value that is pushed onto the stack.
        //      dead_branch: True iff the tensor is on a dead branch of a cond.
        //
        //    Returns:
        //      The current value (the top of the stack).
        //    """

        /// <summary>
        /// Adds the getter for an accumulated value in the grad context: pops the
        /// top of the accumulation stack <paramref name="history_value"/> inside
        /// the grad context and returns it.
        /// </summary>
        /// <param name="history_value">The history (a stack) of a value.</param>
        /// <param name="value">The value that was pushed onto the stack.</param>
        /// <param name="dead_branch">True iff the tensor is on a dead branch of a cond (not used in the implemented path).</param>
        /// <returns>The current value (the top of the stack).</returns>
        public Tensor AddBackpropAccumulatedValue(Tensor history_value, Tensor value, bool dead_branch = false)
        {
            var history_context = history_value.op._get_control_flow_context();
            // Walk outward from value's context, remembering the innermost cond
            // context that controls `value` but not `history_value` (if any).
            CondContext guarding_cond = null;
            Tensor      popped        = null;

            for (var ctxt = value.op._get_control_flow_context();
                 ctxt != null && ctxt != history_context;
                 ctxt = ctxt.outer_context)
            {
                if (ctxt is CondContext cc)
                {
                    guarding_cond = cc;
                }
            }

            tf_with(ops.control_dependencies(null), delegate
            {
                grad_context.Enter();
                if (guarding_cond != null)
                {
                    // The guarded-pop path is not ported yet.
                    throw new NotImplementedException("AddBackpropAccumulatedValue");
                }
                popped = gen_data_flow_ops.stack_pop_v2(history_value, value.dtype.as_base_dtype());
                popped.set_shape(value.TensorShape);
                grad_context.Exit();
            });

            if (grad_context.parallel_iterations > 1)
            {
                // All pops are ordered after pivot_for_body and before grad_sync.
                grad_sync._add_control_input(popped.op);
            }
            return(popped);
        }
예제 #5
0
        /// <summary>
        /// Imports a `MetaGraphDef` into the default graph, restoring its
        /// collections (variables, cond/while contexts, node lists) and returning
        /// both the restored variables — keyed by name with the import scope
        /// stripped — and the graph elements requested via <paramref name="return_elements"/>.
        /// </summary>
        /// <param name="meta_graph_or_file">The MetaGraphDef to import.</param>
        /// <param name="clear_devices">If true, clears the device field of every node so the graph is portable.</param>
        /// <param name="import_scope">Scope passed when deserializing cond/while context protos.</param>
        /// <param name="input_map">Optional input remapping forwarded to the graph importer.</param>
        /// <param name="unbound_inputs_col_name">Name of the collection holding unbound inputs; such collections are skipped on restore.</param>
        /// <param name="return_elements">Names of ops/tensors to return from the imported graph.</param>
        /// <returns>A tuple of (variable name → variable map, imported return elements).</returns>
        /// <exception cref="NotImplementedException">If the meta graph contains an unbound-inputs collection.</exception>
        public static (Dictionary <string, IVariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
                                                                                                                             bool clear_devices  = false,
                                                                                                                             string import_scope = "",
                                                                                                                             Dictionary <string, Tensor> input_map = null,
                                                                                                                             string unbound_inputs_col_name        = "unbound_inputs",
                                                                                                                             string[] return_elements = null)
        {
            var meta_graph_def = meta_graph_or_file;

            // Importing a graph that declares unbound inputs is not supported yet.
            if (!string.IsNullOrEmpty(unbound_inputs_col_name))
            {
                foreach (var col in meta_graph_def.CollectionDef)
                {
                    if (col.Key == unbound_inputs_col_name)
                    {
                        throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
                    }
                }
            }

            // Sets graph to default graph if it's not passed in.
            var graph = ops.get_default_graph();

            // Gathers the list of nodes we are interested in.
            OpList producer_op_list = null;

            if (meta_graph_def.MetaInfoDef.StrippedOpList != null)
            {
                producer_op_list = meta_graph_def.MetaInfoDef.StrippedOpList;
            }
            var input_graph_def = meta_graph_def.GraphDef;

            // Remove all the explicit device specifications for this node. This helps to
            // make the graph more portable.
            if (clear_devices)
            {
                foreach (var node in input_graph_def.Node)
                {
                    node.Device = "";
                }
            }

            var scope_to_prepend_to_names = graph.unique_name("", mark_as_used: false);
            var imported_return_elements  = importer.import_graph_def(input_graph_def,
                                                                      name: scope_to_prepend_to_names,
                                                                      input_map: input_map,
                                                                      producer_op_list: producer_op_list,
                                                                      return_elements: return_elements);

            // Restores all the other collections.
            var variable_objects = new Dictionary <ByteString, IVariableV1>();

            foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
            {
                // Don't add unbound_inputs to the new graph.
                if (col.Key == unbound_inputs_col_name)
                {
                    continue;
                }

                switch (col.Value.KindCase)
                {
                case KindOneofCase.NodeList:
                    foreach (var value in col.Value.NodeList.Value)
                    {
                        var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
                        graph.add_to_collection(col.Key, col_op);
                    }
                    break;

                case KindOneofCase.BytesList:
                    //var proto_type = ops.get_collection_proto_type(key)
                    if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            // CA1854: single TryGetValue lookup instead of
                            // ContainsKey followed by two indexer accesses.
                            if (!variable_objects.TryGetValue(value, out IVariableV1 variable))
                            {
                                var proto = VariableDef.Parser.ParseFrom(value);
                                if (proto.IsResource)
                                {
                                    variable = new ResourceVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                else
                                {
                                    variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                variable_objects[value] = variable;
                            }
                            graph.add_to_collection(col.Key, variable);
                        }
                    }
                    else
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            switch (col.Key)
                            {
                            case "cond_context":
                            {
                                var proto       = CondContextDef.Parser.ParseFrom(value);
                                var condContext = new CondContext().from_proto(proto, import_scope);
                                graph.add_to_collection(col.Key, condContext);
                            }
                            break;

                            case "while_context":
                            {
                                var proto        = WhileContextDef.Parser.ParseFrom(value);
                                var whileContext = new WhileContext().from_proto(proto, import_scope);
                                graph.add_to_collection(col.Key, whileContext);
                            }
                            break;

                            default:
                                // Unknown byte-list collections are skipped.
                                Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
                                continue;
                            }
                        }
                    }

                    break;

                default:
                    Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
                    break;
                }
            }

            // Build the name → variable map, stripping the import scope prefix.
            var variables = graph.get_collection <IVariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
                                                               scope: scope_to_prepend_to_names);
            var var_list = new Dictionary <string, IVariableV1>();

            variables.ForEach(v => var_list[ops.strip_name_scope(v.Name, scope_to_prepend_to_names)] = v);

            return(var_list, imported_return_elements);
        }
예제 #6
0
 /// <summary>
 /// Installs <paramref name="ctx"/> as the current control-flow context.
 /// </summary>
 public void _set_control_flow_context(CondContext ctx) => _control_flow_context = ctx;
        /// <summary>
        /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
        ///
        /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
        /// `false_fn` must have the same non-zero number and type of outputs.
        ///
        /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
        /// `false_fn` will be executed regardless of which branch is selected at runtime.
        ///
        /// Although this behavior is consistent with the dataflow model of TensorFlow,
        /// it has frequently surprised users who expected a lazier semantics.
        /// Consider the following simple program:
        ///
        /// z = tf.multiply(a, b)
        /// result = tf.cond(x &lt; y, ()=> tf.add(x, z), ()=> tf.square(y))
        ///
        /// If `x &lt; y`, the `tf.add` operation will be executed and the `tf.square`
        /// operation will not be executed. Since `z` is needed for at least one
        /// branch of the `cond`, the `tf.multiply` operation is always executed,
        /// unconditionally.
        ///
        /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
        /// call to `cond`, and not at all during `Session.run()`). `cond`
        /// stitches together the graph fragments created during the `true_fn` and
        /// `false_fn` calls with some additional graph nodes to ensure that the right
        /// branch gets executed depending on the value of `pred`.
        ///
        /// `tf.cond` supports nested structures as implemented in
        /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
        /// same (possibly nested) value structure of lists, tuples, and/or named tuples.
        /// Singleton lists and tuples form the only exceptions to this: when returned by
        /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
        /// This behavior is disabled by passing `strict = True`.
        /// </summary>
        /// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
        /// `false_fn`.</param>
        /// <param name="true_fn">The callable to be performed if pred is true.</param>
        /// <param name="false_fn">The callable to be performed if pred is false.</param>
        /// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
        /// <param name="name">Optional name prefix for the returned tensors.</param>
        /// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
        /// callables return a singleton list, the element is extracted from the list.</returns>
        public static Tensor cond(Tensor pred,
                                  Func <ITensorOrOperation> true_fn  = null,
                                  Func <ITensorOrOperation> false_fn = null,
                                  bool strict = false,
                                  string name = null)
        {
            return(tf_with(ops.name_scope(name, "cond", new { pred }), delegate
            {
                // TODO: here a chunk of original code is missing

                /*
                 * with ops.name_scope(name, "cond", [pred]):
                 *  if context.executing_eagerly():
                 *    if pred:
                 *      return _UnpackIfSingleton(true_fn())
                 *    return _UnpackIfSingleton(false_fn())
                 */

                // Add the Switch to the graph.
                var switch_result = @switch(pred, pred);
                var p_2 = switch_result[0];
                var p_1 = switch_result[1];
                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
                var pivot_2 = array_ops.identity(p_2, name: "switch_f");
                pred = array_ops.identity(pred, name: "pred_id");

                // Disable the fetching of tensors that are only on one branch of cond.
                foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
                {
                    tensor.op.graph.prevent_fetching(tensor.op);
                }

                // Build the graph for the true branch in a new context.
                var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
                ITensorOrOperation orig_res_t;
                Tensor res_t;
                try
                {
                    context_t.Enter();
                    (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
                }
                finally
                {
                    context_t.Exit();
                }
                // Build the graph for the false branch in a new context.
                var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
                ITensorOrOperation orig_res_f;
                Tensor res_f;
                try
                {
                    context_f.Enter();
                    (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
                }
                finally
                {
                    context_f.Exit();
                }

                //TODO: missing original code
                //if not strict:
                //  orig_res_t = _UnpackIfSingleton(orig_res_t)
                //  orig_res_f = _UnpackIfSingleton(orig_res_f)

                /*
                 # Check that the return values of the two branches have the same structure.
                 # try:
                 #  nest.assert_same_structure(orig_res_t, orig_res_f)
                 # except TypeError as e:
                 #  raise TypeError(
                 #      "Incompatible return types of true_fn and false_fn: {}".format(e))
                 # except ValueError as e:
                 #  raise ValueError(
                 #      "Incompatible return values of true_fn and false_fn: {}".format(e))*/

                var res_t_flat = new Tensor[] { res_t };
                var res_f_flat = new Tensor[] { res_f };

                foreach (var(val_x, val_y) in zip(res_t_flat, res_f_flat))
                {
                }

                var merges = zip(res_f_flat, res_t_flat)
                             .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 }))
                             .ToArray();

                merges = _convert_flows_to_tensorarrays(new Tensor[] { (Tensor)orig_res_t }, merges);

                ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
                ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);

                return merges[0];
            }));
예제 #8
0
 /// <summary>
 /// Builds a no-else cond context by copying all state from <paramref name="context"/>.
 /// </summary>
 public CondNoElseContext(CondContext context) => CopyFrom(context);
예제 #9
0
 /// <summary>
 /// Copies all state from <paramref name="context"/> into this instance by
 /// delegating to the base implementation.
 /// </summary>
 public virtual void CopyFrom(CondContext context) => base.CopyFrom(context);
예제 #10
0
        /// <summary>
        /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
        ///
        /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
        /// `false_fn` must have the same non-zero number and type of outputs.
        ///
        /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
        /// `false_fn` will be executed regardless of which branch is selected at runtime.
        ///
        /// Although this behavior is consistent with the dataflow model of TensorFlow,
        /// it has frequently surprised users who expected a lazier semantics.
        /// Consider the following simple program:
        ///
        /// z = tf.multiply(a, b)
        /// result = tf.cond(x &lt; y, ()=> tf.add(x, z), ()=> tf.square(y))
        ///
        /// If `x &lt; y`, the `tf.add` operation will be executed and the `tf.square`
        /// operation will not be executed. Since `z` is needed for at least one
        /// branch of the `cond`, the `tf.multiply` operation is always executed,
        /// unconditionally.
        ///
        /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
        /// call to `cond`, and not at all during `Session.run()`). `cond`
        /// stitches together the graph fragments created during the `true_fn` and
        /// `false_fn` calls with some additional graph nodes to ensure that the right
        /// branch gets executed depending on the value of `pred`.
        ///
        /// `tf.cond` supports nested structures as implemented in
        /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
        /// same (possibly nested) value structure of lists, tuples, and/or named tuples.
        /// Singleton lists and tuples form the only exceptions to this: when returned by
        /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
        /// This behavior is disabled by passing `strict = True`.
        /// </summary>
        /// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
        /// `false_fn`.</param>
        /// <param name="true_fn">The callable to be performed if pred is true.</param>
        /// <param name="false_fn">The callable to be performed if pred is false.</param>
        /// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
        /// <param name="name">Optional name prefix for the returned tensors.</param>
        /// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
        /// callables return a singleton list, the element is extracted from the list.</returns>
        public static Tensor cond(Tensor pred,
                                  Func <ITensorOrOperation> true_fn  = null,
                                  Func <ITensorOrOperation> false_fn = null,
                                  bool strict = false,
                                  string name = null)
        {
            return(tf_with(ops.name_scope(name, "cond", new { pred }), delegate
            {
                // Add the Switch to the graph.
                var switch_result = @switch(pred, pred);
                var(p_2, p_1) = (switch_result[0], switch_result[1]);
                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
                var pivot_2 = array_ops.identity(p_2, name: "switch_f");
                pred = array_ops.identity(pred, name: "pred_id");

                // Disable the fetching of tensors that are only on one branch of cond.
                foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
                {
                    tensor.op.graph.prevent_fetching(tensor.op);
                }

                // Build the graph for the true branch in a new context.
                var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
                ITensorOrOperation orig_res_t;
                Tensor res_t;
                try
                {
                    context_t.Enter();
                    (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
                    context_t.ExitResult(new[] { res_t });
                }
                finally
                {
                    context_t.Exit();
                }
                // Build the graph for the false branch in a new context.
                var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
                ITensorOrOperation orig_res_f;
                Tensor res_f;
                try
                {
                    context_f.Enter();
                    (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
                    context_f.ExitResult(new[] { res_f });
                }
                finally
                {
                    context_f.Exit();
                }

                var res_t_flat = new Tensor[] { res_t };
                var res_f_flat = new Tensor[] { res_f };

                var merges = zip(res_f_flat, res_t_flat)
                             .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0])
                             .ToArray();

                if (orig_res_t is Tensor orig_res_tensor)
                {
                    merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges)
                             .Select(x => x as Tensor)
                             .ToArray();
                }
                else
                {
                }

                if (context_t.outer_context == null)
                {
                    ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
                    ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
                }

                return merges[0];
            }));