public static Operation group(List <Operation> inputs, string name = "")
        {
            using (var namescope = new ops.name_scope <Operation>(name, "group_deps", inputs))
            {
                name = namescope;

                // Sorts *inputs according to their devices, accumulating the ops that
                // share a device rather than overwriting the previous entry.
                var ops_on_device = new Dictionary <string, Operation[]>();
                foreach (var inp in inputs)
                {
                    ops_on_device[inp.Device] = ops_on_device.ContainsKey(inp.Device)
                        ? ops_on_device[inp.Device].Concat(new[] { inp }).ToArray()
                        : new Operation[] { inp };
                }

                // 1-level tree. The root node is the returned NoOp node.
                if (ops_on_device.Count == 1)
                {
                    var dev  = ops_on_device.Keys.First();
                    var deps = ops_on_device.Values.First();
                    return(_GroupControlDeps(dev, deps, name));
                }

                // 2-level tree. The root node is the returned NoOp node.
                // deps contains 1 NoOp node for each device.
                return(null);
            }
        }
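
A minimal usage sketch for the helper above; the enclosing class name `control_flow_ops` and the availability of `tf.constant` in this binding are assumptions:

// Sketch only: control_flow_ops is an assumed name for the class holding group() above.
var a = tf.constant(1);
var b = tf.constant(2);
// A single NoOp that carries control dependencies on both producing ops.
Operation grouped = control_flow_ops.group(new List <Operation> { a.op, b.op });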
Example #2
        /// <summary>
        /// Computes the sum of elements across dimensions of a tensor.
        /// </summary>
        /// <param name="input"></param>
        /// <param name="axis"></param>
        /// <returns></returns>
        public static Tensor reduce_sum(Tensor input, int[] axis = null)
        {
            Tensor rank;
            string name;

            using (var namescop = new ops.name_scope("", "Rank", new List <Tensor> {
                input
            }))
            {
                name = namescop;
                rank = gen_array_ops.rank(input, namescop);
            }

            using (var namescope = new ops.name_scope("range", "Range", new List <Tensor> {
                0D, input, 1D
            }))
            {
                name = namescope;
                var start = ops.convert_to_tensor(0D);
                var limit = ops.convert_to_tensor(input);
                var delta = ops.convert_to_tensor(1D);

                var t = gen_math_ops.range(start, limit, delta, name);
            }

            var s = gen_math_ops.sum(input, rank);

            return(s);
        }
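
A hedged usage sketch: with no axis argument the helper reduces over every dimension, so a 2x2 tensor collapses to a scalar. The enclosing class name `math_ops`, and `tf.constant` accepting a 2-D array, are assumptions:

// Sketch only: math_ops is an assumed class name for the reduce_sum helper above.
var x = tf.constant(new[, ] { { 1.0, 2.0 }, { 3.0, 4.0 } });
var total = math_ops.reduce_sum(x);   // sums over dims [0, rank): evaluates to 10.0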
Example #3
        public static void _GradientsHelper(object ys,
                                            object xs,
                                            object grad_ys = null,
                                            string name    = "gradients",
                                            bool colocate_gradients_with_ops = false,
                                            bool gate_gradients   = false,
                                            object stop_gradients = null,
                                            Graph src_graph       = null)
        {
            if (src_graph == null)
            {
                src_graph = ops.get_default_graph();
            }

            // If src_graph is a _FuncGraph (i.e. a function body), gather it and all
            // ancestor graphs. This is necessary for correctly handling captured values.
            var curr_graph = src_graph;

            var           ys1             = _AsList(ys);
            var           xs1             = _AsList(xs);
            List <Tensor> grad_ys1        = null;
            List <Tensor> stop_gradients1 = stop_gradients == null ? new List <Tensor>() : _AsList(stop_gradients);

            if (grad_ys == null)
            {
                // Placeholder default gradients (a full implementation would seed ones_like(y)).
                grad_ys1 = ys1.Select(x => new Tensor(IntPtr.Zero)).ToList();
            }
            else
            {
                grad_ys1 = _AsList(grad_ys);
            }

            var all = new List <Tensor>();

            all.AddRange(ys1);
            all.AddRange(xs1);
            all.AddRange(stop_gradients1);
            all.AddRange(grad_ys1);

            string grad_scope = "";

            using (var namescope = new ops.name_scope <Tensor>(name, "gradients", values: all))
            {
                grad_scope = namescope;
                // Get a uid for this call to gradients that can be used to help
                // cluster ops for compilation.
                var gradient_uid = ops.get_default_graph().unique_name("uid");

                // Initialize the pending count for ops in the connected subgraph from ys
                // to the xs.
                var to_ops            = ys1.Select(x => x.op).ToList();
                var from_ops          = xs1.Select(x => x.op).ToList();
                var stop_gradient_ops = stop_gradients1.Select(x => x.op).ToList();
                _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List <object>(), xs1);
            }
        }
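
A hedged sketch of how this (still incomplete) helper might be driven; it only seeds the pending counts and returns nothing, so the call is illustrative:

// Sketch only: tf.constant and gen_math_ops.mul are used as elsewhere on this page.
var x = tf.constant(3.0);
var y = gen_math_ops.mul(x, x);   // y = x * x
// Seeds the pending counts for the subgraph between y and x under the "gradients" scope.
_GradientsHelper(y, x);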
Example #4
        private static Tensor BinaryOpWrapper <Tx, Ty>(string name, Tx x, Ty y)
        {
            TF_DataType dtype = TF_DataType.DtInvalid;

            if (x is Tensor tl)
            {
                dtype = tl.dtype.as_base_dtype();
            }
            if (y is Tensor tr)
            {
                dtype = tr.dtype.as_base_dtype();
            }

            var namescope = new ops.name_scope(null, name, new { x, y });

            return(Python.with <ops.name_scope, Tensor>(namescope, scope =>
            {
                Tensor result = null;
                var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
                var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y");

                switch (name)
                {
                case "add":
                    result = gen_math_ops.add(x1, y1, name: scope);
                    break;

                case "truediv":
                    result = gen_math_ops.real_div(x1, y1, name: scope);
                    break;

                case "mul":
                    result = gen_math_ops.mul(x1, y1, name: scope);
                    break;

                case "sub":
                    result = gen_math_ops.sub(x1, y1, name: scope);
                    break;

                case "mod":
                    result = gen_math_ops.floor_mod(x1, y1, name: scope);
                    break;

                default:
                    throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}");
                }

                return result;
            }));
        }
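
A hedged sketch of the call sites this wrapper is meant to back, e.g. arithmetic operator overloads on `Tensor`; whether the binding wires them up exactly this way is an assumption:

// Sketch only: illustrative operator overloads (they would have to live where
// BinaryOpWrapper is accessible, e.g. on the Tensor type itself).
public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs);
public static Tensor operator -(Tensor lhs, Tensor rhs) => BinaryOpWrapper("sub", lhs, rhs);
public static Tensor operator *(Tensor lhs, Tensor rhs) => BinaryOpWrapper("mul", lhs, rhs);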
Example #5
        private VariableScope _enter_scope_uncached()
        {
            ops.name_scope current_name_scope;
            if (_auxiliary_name_scope)
            {
                // Create a new name scope later
                current_name_scope = null;
            }
            else
            {
                // Reenter the current name scope
                string name_scope = ops.get_name_scope();
                if (!string.IsNullOrEmpty(name_scope))
                {
                    // Hack to reenter
                    name_scope += "/";
                }
                current_name_scope = new ops.name_scope(name_scope);
            }

            if (_name != null || _scope != null)
            {
                var name_scope = _name == null ? _scope._name.Split('/').Last() : _name;

                if (name_scope != null || current_name_scope != null)
                {
                    current_name_scope = new ops.name_scope(name_scope);
                }
                current_name_scope.__enter__();
                var current_name_scope_name = current_name_scope;
                _current_name_scope = current_name_scope;
                string            old_name_scope      = current_name_scope_name;
                PureVariableScope pure_variable_scope = null;
                if (_scope == null)
                {
                    pure_variable_scope = new PureVariableScope(_name, old_name_scope: old_name_scope);
                }
                else
                {
                    pure_variable_scope = new PureVariableScope(_scope, old_name_scope: old_name_scope);
                }
                pure_variable_scope.__enter__();
                VariableScope entered_pure_variable_scope = pure_variable_scope;
                _cached_pure_variable_scope = pure_variable_scope;
                return(entered_pure_variable_scope);
            }

            throw new NotImplementedException("_enter_scope_uncached");
        }
Example #6
        public variable_scope(VariableScope scope,
                              string default_name       = "",
                              object values             = null,
                              bool auxiliary_name_scope = true)
        {
            _scope              = scope;
            _default_name       = default_name;
            _values             = values;
            _current_name_scope = null;

            _use_resource = false;
            if (_default_name == null && _scope == null)
            {
                throw new TypeError("If default_name is None then scope is required");
            }

            _auxiliary_name_scope = auxiliary_name_scope;
        }
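
A minimal usage sketch tying this constructor to `_enter_scope_uncached` from the previous example; the `Python.with` overload used to enter the scope is an assumption about this binding:

// Sketch only: construction alone does nothing; entering the scope is what pushes
// the matching name scope and PureVariableScope via _enter_scope_uncached.
var vs = new variable_scope(scope: null, default_name: "layer1");
Python.with(vs, scope =>
{
    // Variables created here are namespaced under "layer1/...".
});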
Example #7
        private void _init_from_args(object initial_value,
                                     bool trainable            = true,
                                     List <string> collections = null,
                                     bool validate_shape       = true,
                                     string caching_device     = "",
                                     string name       = "",
                                     TF_DataType dtype = TF_DataType.DtInvalid)
        {
            if (initial_value is null)
            {
                throw new ValueError("initial_value must be specified.");
            }

            var init_from_fn = false;

            if (collections == null)
            {
                collections = new List <string> {
                    ops.GraphKeys.GLOBAL_VARIABLES
                };
            }

            // Store the graph key so optimizers know how to only retrieve variables from
            // this graph.
            _graph_key = ops.get_default_graph()._graph_key;

            _trainable = trainable;
            if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
            {
                collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
            }

            ops.init_scope();
            var values = init_from_fn ? new List <object>() : new List <object> {
                initial_value
            };

            using (var namescope = new ops.name_scope <object>(name, "Variable", values))
            {
                name = namescope;

                if (init_from_fn)
                {
                }
                // Or get the initial value from a Tensor or Python object.
                else
                {
                    _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");

                    var shape = _initial_value.shape;
                    dtype     = _initial_value.dtype;
                    _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), name);
                }

                // Manually overrides the variable's shape with the initial value's.
                if (validate_shape)
                {
                    var initial_value_shape = _initial_value.shape;
                }

                // If 'initial_value' makes use of other variables, make sure we don't
                // have an issue if these other variables aren't initialized first by
                // using their initialized_value() method.
                var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);

                _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

                if (!String.IsNullOrEmpty(caching_device))
                {
                }
                else
                {
                    ops.colocate_with(_initializer_op);

                    _snapshot = gen_array_ops.identity(_variable, name: "read");
                }

                ops.add_to_collections(collections, this);
            }
        }
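
A hedged sketch of where this initializer fits; `RefVariable` as the public entry point that forwards here is an assumption:

// Sketch only: an assumed constructor path; _init_from_args builds the initial_value
// tensor, the VariableV2 op, the Assign (initializer) op and the Identity ("read")
// snapshot set up above.
var w = new RefVariable(tf.constant(0.5f), trainable: true, name: "w");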
Example #8
        public Operation _apply_op_helper(string op_type_name, string name = "", Dictionary <string, object> keywords = null)
        {
            var g      = ops.get_default_graph();
            var op_def = g.GetOpDef(op_type_name);

            // Default name if not specified.
            if (String.IsNullOrEmpty(name))
            {
                name = op_type_name;
            }

            // Check for deprecation
            if (op_def.Deprecation != null && op_def.Deprecation.Version > 0)
            {
            }

            var default_type_attr_map = new Dictionary <string, object>();

            foreach (var attr_def in op_def.Attr)
            {
                if (attr_def.Type != "type")
                {
                    continue;
                }
                var key = attr_def.Name;
                if (attr_def.DefaultValue != null)
                {
                    default_type_attr_map[key] = attr_def.DefaultValue.Type;
                }
            }

            var attrs       = new Dictionary <string, object>();
            var inputs      = new List <Tensor>();
            var input_types = new List <TF_DataType>();

            string scope = "";

            using (var namescope = new ops.name_scope <object>(name))
            {
                scope = namescope;

                // Perform input type inference
                foreach (var input_arg in op_def.InputArg)
                {
                    var input_name = input_arg.Name;
                    if (!keywords.ContainsKey(input_name))
                    {
                        continue;
                    }

                    // Wrap plain numeric inputs in a constant op so they become graph tensors.
                    if (keywords[input_name] is double double_value)
                    {
                        keywords[input_name] = constant_op.Constant(double_value, input_name);
                    }

                    if (keywords[input_name] is Tensor value)
                    {
                        inputs.Add(value);

                        if (!String.IsNullOrEmpty(input_arg.TypeAttr))
                        {
                            attrs[input_arg.TypeAttr] = value.dtype;
                        }

                        if (input_arg.IsRef)
                        {
                        }
                        else
                        {
                            var base_type = value.dtype.as_base_dtype();

                            input_types.Add(base_type);
                        }
                    }
                }

                // Process remaining attrs
                foreach (var attr in op_def.Attr)
                {
                    if (keywords.ContainsKey(attr.Name))
                    {
                        attrs[attr.Name] = keywords[attr.Name];
                    }
                }

                // Convert attr values to AttrValue protos.
                var attr_protos = new Dictionary <string, AttrValue>();
                foreach (var attr_def in op_def.Attr)
                {
                    var key        = attr_def.Name;
                    var value      = attrs.ContainsKey(key) ? attrs[key] : null;
                    var attr_value = new AttrValue();

                    switch (attr_def.Type)
                    {
                    case "string":
                        attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
                        break;

                    case "type":
                        attr_value.Type = _MakeType((TF_DataType)value, attr_def);
                        break;

                    case "bool":
                        attr_value.B = (bool)value;
                        break;

                    case "shape":
                        attr_value.Shape = value == null ?
                                           attr_def.DefaultValue.Shape :
                                           tensor_util.as_shape((long[])value);
                        break;

                    default:
                        throw new InvalidDataException($"attr_def.Type {attr_def.Type}");
                    }

                    attr_protos[key] = attr_value;
                }

                // Determine output types (possibly using attrs)
                var output_types = new List <TF_DataType>();

                foreach (var arg in op_def.OutputArg)
                {
                    if (!String.IsNullOrEmpty(arg.NumberAttr))
                    {
                    }
                    else if (!String.IsNullOrEmpty(arg.TypeAttr))
                    {
                        output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
                    }
                }

                // Add Op to graph
                var op = g.create_op(op_type_name, inputs, output_types.ToArray(),
                                     name: scope,
                                     input_types: input_types.ToArray(),
                                     attrs: attr_protos,
                                     op_def: op_def);

                return(op);
            }
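
A hedged sketch of how a generated op wrapper would call this helper; the `_op_def_lib` field, the wrapper shape, and the `outputs` property on `Operation` are assumptions beyond what the snippet shows:

// Sketch only: the call pattern an assumed generated wrapper (e.g. gen_math_ops.add)
// would use; the "x"/"y" keyword names must match the Add OpDef's input arg names.
public static Tensor add(Tensor x, Tensor y, string name = "")
{
    var keywords = new Dictionary <string, object> { { "x", x }, { "y", y } };
    var op = _op_def_lib._apply_op_helper("Add", name: name, keywords: keywords);
    return op.outputs[0];   // assumes Operation exposes its result tensors via outputs
}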
        }