public static Tensor with_dependencies(Operation[] dependencies, Tensor output_tensor, string name = "")
        {
            var values = new List <object>();

            values.AddRange(dependencies);
            values.Add(output_tensor);

            return(Python.with <ops.name_scope, Tensor>(new ops.name_scope(name, "control_dependency", values), scope =>
            {
                name = scope;

                return Python.with <_ControlDependenciesController, Tensor>(ops.control_dependencies(dependencies), ctl =>
                {
                    output_tensor = ops.convert_to_tensor_or_composite(output_tensor);
                    return _Identity(output_tensor, name: name);
                });
            }));
        }
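A quick usage sketch (hedged: tf.constant and an in-scope with_dependencies are assumptions about this port's surface API; gen_control_flow_ops.no_op appears in the next method):

            // `gated` yields the value of `x`, but only after `barrier` has run.
            var x = tf.constant(42);
            var barrier = gen_control_flow_ops.no_op("barrier");
            var gated = with_dependencies(new[] { barrier }, x);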
        private static Operation _GroupControlDeps(string dev, Operation[] deps, string name = "")
        {
            Operation result = null;

            Python.with(ops.control_dependencies(deps), delegate
            {
                if (string.IsNullOrEmpty(dev))
                {
                    result = gen_control_flow_ops.no_op(name);
                }
                else
                {
                    // Upstream TensorFlow pins this NoOp to `dev` via ops.device(dev);
                    // device placement is not wired up in this port yet, so the branch
                    // is currently identical to the one above.
                    result = gen_control_flow_ops.no_op(name);
                }
            });

            return(result);
        }
Example #3
        public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
        {
            var ret = new List <Tensor>();

            foreach (var(i, value) in Python.enumerate(values))
            {
                if (value == null)
                {
                    ret.Add(value);
                }
                else
                {
                    var n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
                    ret.Add(internal_convert_to_tensor_or_indexed_slices(value, dtype: dtype, name: n, as_ref: as_ref));
                }
            }

            return(ret.ToArray());
        }
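The helper suffixes each converted tensor's name with its index; a self-contained plain-C# sketch of just that naming rule (no TensorFlow needed, requires using System):

        string name = "x";
        for (int i = 0; i < 3; i++)
            Console.WriteLine(string.IsNullOrEmpty(name) ? "" : $"{name}_{i}");
        // prints x_0, x_1, x_2 — the names given to the converted tensors above.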
 public static Tensor random_normal(int[] shape,
                                    float mean        = 0.0f,
                                    float stddev      = 1.0f,
                                    TF_DataType dtype = TF_DataType.TF_FLOAT,
                                    int? seed         = null,
                                    string name       = "")
 {
     return(Python.with <ops.name_scope, Tensor>(new ops.name_scope(name, "random_normal", new object[] { shape, mean, stddev }), scope =>
     {
         var shape_tensor = _ShapeTensor(shape);
         var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean");
         var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev");
         var(seed1, seed2) = random_seed.get_seed(seed);
         var rnd = gen_random_ops.random_standard_normal(shape_tensor, dtype: dtype, seed: seed1, seed2: seed2);
         var mul = rnd * stddev_tensor;
         var value = math_ops.add(mul, mean_tensor, name: name);
         return value;
     }));
 }
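A usage sketch (hedged: tf.Session and sess.run are assumptions about this port's session API):

     var samples = random_normal(new[] { 2, 3 }, mean: 0.0f, stddev: 1.0f);
     using (var sess = tf.Session())
     {
         var result = sess.run(samples); // one 2x3 draw from N(0, 1)
     }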
Example #5
 public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true)
 {
     return(Python.with <ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List <Tensor> {
         input
     }), scope =>
     {
         name = scope;
         var input_tensor = ops.convert_to_tensor(input);
         var input_shape = tensor_util.to_shape(input_tensor.shape);
          if (optimize && input_shape.NDim != null)
          {
              // Static rank is known: fold it into a graph constant instead of
              // emitting a runtime Rank op.
              return constant_op.constant(input_shape.NDim);
          }
         else
         {
             return gen_array_ops.rank(input, name);
         }
     }));
 }
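A sketch of the constant-folding branch (hedged: tf.constant over a float array is an assumption about this port):

     var t = tf.constant(new[] { 1.0f, 2.0f, 3.0f }); // static shape (3,), so NDim == 1
     var r = rank_internal(t, optimize: true);        // folds to constant_op.constant(1)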
Example #6
        private static Tensor shape_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
        {
            return(Python.with <ops.name_scope, Tensor>(new ops.name_scope(name, "Shape", new Tensor[] { input }), scope =>
            {
                name = scope;

                if (!tf.context.executing_eagerly())
                {
                    var input_tensor = ops.convert_to_tensor(input);
                    var input_shape = tensor_util.to_shape(input_tensor.shape);
                    if (optimize && input_shape.is_fully_defined())
                    {
                        var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
                        return constant_op.constant(nd, name: name);
                    }
                }

                // NOTE: `name` and `out_type` are not forwarded to the Shape op here.
                return gen_array_ops.shape(input);
            }));
        }
Example #7
        /// <summary>
        /// A context manager that lifts ops out of control-flow scopes and function-building graphs.
        /// </summary>
        public static void init_scope()
        {
            // Retrieve the active name scope: entering an `init_scope` preserves
            // the name scope of the current context.
            var default_graph = get_default_graph();
            var scope         = default_graph.get_name_scope();

            if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
            {
                // Names that end with trailing slashes are treated by `name_scope` as
                // absolute.
                scope += "/";
            }
            // inner_device_stack = default_graph._device_function_stack
            // var outer_context = default_graph.as_default;

            Python.with(ops.control_dependencies(null), delegate
            {
                var outer_graph = get_default_graph();
                // outer_device_stack = None
            });
        }
Example #8
        public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool optimize = true)
        {
            return(Python.with <ops.name_scope, Tensor>(new ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope =>
            {
                name = scope;
                tensor = ops.convert_to_tensor(tensor, name: "tensor");

                // NOTE: is_fully_defined() currently returns an unexpected value,
                // so the constant-folded zeros path below is left empty for now.
                if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
                {
                }

                if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT)
                {
                    throw new NotImplementedException("zeros_like");
                    // return zeros(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name);
                }
                else
                {
                    return gen_array_ops.zeros_like(tensor, name: name);
                }
            }));
        }
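A usage sketch (hedged: tf.constant is assumed from this port):

            var t = tf.constant(new[] { 1.0f, 2.0f, 3.0f });
            var z = zeros_like(t); // same shape and dtype as t; evaluates to [0, 0, 0]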
Example #9
        public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
        {
            if (limit == null)
            {
                limit = start;
                start = 0;
            }

            if (delta == null)
            {
                delta = 1;
            }

            return(Python.with <ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
            {
                name = scope;
                var start1 = ops.convert_to_tensor(start, name: "start");
                var limit1 = ops.convert_to_tensor(limit, name: "limit");
                var delta1 = ops.convert_to_tensor(delta, name: "delta");

                return gen_math_ops.range(start1, limit1, delta1, name);
            }));
        }
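Usage mirrors Python's tf.range; a short sketch:

            var r1 = range(5);        // limit only: [0, 1, 2, 3, 4]
            var r2 = range(2, 10, 2); // start, limit, delta: [2, 4, 6, 8]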
        public virtual SaverDef _build_internal(RefVariable[] names_to_saveables,
                                                bool reshape    = false,
                                                bool sharded    = false,
                                                int max_to_keep = 5,
                                                float keep_checkpoint_every_n_hours = 10000,
                                                string name = "",
                                                bool restore_sequentially = false,
                                                string filename           = "model",
                                                bool build_save           = true,
                                                bool build_restore        = true)
        {
            if (!build_save || !build_restore)
            {
                throw new ValueError("save and restore operations need to be built together " +
                                     " when eager execution is not enabled.");
            }

            var saveables = saveable_object_util.validate_and_slice_inputs(names_to_saveables);

            if (max_to_keep < 0)
            {
                max_to_keep = 0;
            }

            Tensor    save_tensor = null;
            Operation restore_op  = null;

            return(Python.with <ops.name_scope, SaverDef>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope =>
            {
                name = scope;

                // Add a placeholder string tensor for the filename.
                var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename");
                // Keep the name "Const" for backwards compatibility.
                filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const");

                // Add the save ops.
                if (sharded)
                {
                    // Sharded saving is not implemented here yet.
                }
                else
                {
                    if (build_save)
                    {
                        save_tensor = _AddSaveOps(filename_tensor, saveables);
                    }

                    if (build_restore)
                    {
                        restore_op = _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape);
                    }
                }

                var graph = ops.get_default_graph();
                var check_collection_list = graph.get_all_collection_keys();
                foreach (var collection_type in check_collection_list)
                {
                    var cols = graph.get_collection(collection_type);
                    switch (cols)
                    {
                    case List <RefVariable> values:
                        // Variable collections: nothing to record beyond the type check.
                        break;

                    case List <ITensorOrOperation> values:
                        // Tensor/operation collections: nothing to record beyond the type check.
                        break;

                    default:
                        throw new NotImplementedException("_build_internal.check_collection_list");
                    }
                }

                return new SaverDef()
                {
                    FilenameTensorName = filename_tensor.name,
                    SaveTensorName = save_tensor.name,
                    RestoreOpName = restore_op.name,
                    MaxToKeep = max_to_keep,
                    Sharded = sharded,
                    KeepCheckpointEveryNHours = keep_checkpoint_every_n_hours,
                    Version = _write_version
                };
            }));
        }
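The returned SaverDef ties back to graph nodes purely by name; a hedged sketch of inspecting it (the saver instance and variable array are hypothetical):

        var saver_def = saver._build_internal(variables); // hypothetical saver/variables
        Console.WriteLine(saver_def.FilenameTensorName);  // the "Const" placeholder built above
        Console.WriteLine(saver_def.SaveTensorName);      // from _AddSaveOps
        Console.WriteLine(saver_def.RestoreOpName);       // from _AddRestoreOps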
        public static Tensor[] _GradientsHelper(Tensor[] ys,
                                                Tensor[] xs,
                                                Tensor[] grad_ys = null,
                                                string name      = "gradients",
                                                bool colocate_gradients_with_ops = false,
                                                bool gate_gradients     = false,
                                                int aggregation_method  = 0,
                                                Tensor[] stop_gradients = null,
                                                Graph src_graph         = null)
        {
            if (src_graph == null)
            {
                src_graph = ops.get_default_graph();
            }

            // If src_graph is a _FuncGraph (i.e. a function body), gather it and all
            // ancestor graphs. This is necessary for correctly handling captured values.
            var curr_graph = src_graph;

            if (stop_gradients == null)
            {
                stop_gradients = new Tensor[0];
            }
            if (grad_ys == null)
            {
                grad_ys = new Tensor[ys.Length];
            }

            var all = new List <Tensor>();

            all.AddRange(ys);
            all.AddRange(xs);
            all.AddRange(stop_gradients);
            all.AddRange(grad_ys);

            // Iterate over the collected ops.

            /**
             * grads: op => list of gradients received on each output endpoint of the
             * op.  The gradients for each endpoint are initially collected as a list.
             * When it is time to call the op's gradient function, for each endpoint we
             * aggregate the list of received gradients into a Add() Operation if there
             * is more than one.
             **/
            var grads = new Dictionary <string, Tensor[][]>();

            with(ops.name_scope(name, "gradients", values: all), scope =>
            {
                string grad_scope = scope;
                // Get a uid for this call to gradients that can be used to help
                // cluster ops for compilation.
                var gradient_uid = ops.get_default_graph().unique_name("uid");
                ys      = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
                xs      = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
                grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);

                /**
                 * The approach we take here is as follows: Create a list of all ops in the
                 * subgraph between the ys and xs.  Visit these ops in reverse order of ids
                 * to ensure that when we visit an op the gradients w.r.t its outputs have
                 * been collected.  Then aggregate these gradients if needed, call the op's
                 * gradient function, and add the generated gradients to the gradients for
                 * its input.
                 **/

                // Initialize the pending count for ops in the connected subgraph from ys
                // to the xs.
                var to_ops            = ys.Select(x => x.op).ToList();
                var from_ops          = xs.Select(x => x.op).ToList();
                var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
                (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List <object>(), xs);

                foreach (var(y, grad_y) in Python.zip(ys, grad_ys))
                {
                    _SetGrad(grads, y, grad_y);
                }

                // Initialize queue with to_ops.
                var queue = new Queue <Operation>();
                // Add the ops in 'to_ops' into the queue.
                var to_ops_set = new List <Operation>();
                foreach (var op in to_ops)
                {
                    // 'ready' handles the case where one output gradient relies on
                    // another output's gradient.
                    if (!pending_count.ContainsKey(op.name))
                    {
                        pending_count[op.name] = 0;
                    }
                    bool ready = pending_count[op.name] == 0;
                    if (ready && !to_ops_set.Contains(op) && reachable_to_ops.Contains(op))
                    {
                        to_ops_set.Add(op);
                        queue.Enqueue(op);
                    }
                }

                var stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs);
                while (queue.Count > 0)
                {
                    // generate gradient subgraph for op.
                    var op = queue.Dequeue();
                    _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
                    //if (loop_state != null)
                    //loop_state.EnterGradWhileContext(op, before: true);
                    var out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, aggregation_method);

                    Tensor[] in_grads       = null;
                    var is_partitioned_call = _IsPartitionedCall(op);
                    // Differentiating through function-call ops is not supported here yet.
                    var is_func_call        = false;
                    var has_out_grads       = true;
                    if (has_out_grads && !stop_ops.Contains(op))
                    {
                        if (is_func_call)
                        {
                        }
                        else
                        {
                            // A grad_fn must be defined, either as a function or as null
                            // for ops that do not have gradients.
                            var grad_fn = ops.get_gradient_function(op);

                            foreach (var(i, out_grad) in enumerate(out_grads))
                            {
                                if (out_grad == null)
                                {
                                    if (loop_state != null)
                                    {
                                        // While-loop gradient state is not handled here yet.
                                    }
                                    else
                                    {
                                        out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i);
                                    }
                                }
                            }

                            with(ops.name_scope(op.name + "_grad"), scope1 =>
                            {
                                string name1 = scope1;
                                if (grad_fn != null)
                                {
                                    in_grads = _MaybeCompile(grad_scope, op, out_grads, null, grad_fn);
                                    _VerifyGeneratedGradients(in_grads, op);
                                }

                                if (gate_gradients && in_grads != null && in_grads.Count(x => x != null) > 1)
                                {
                                    ops._colocate_with_for_gradient(null, gradient_uid, ignore_existing: true);
                                    in_grads = control_flow_ops.tuple(in_grads);
                                }
                            });
                        }
                    }
                    else
                    {
                        in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
                    }

                    var inputs = _NonEagerInputs(op, xs).ToList();
                    foreach (var(t_in, in_grad) in Python.zip(inputs, in_grads))
                    {
                        if (in_grad != null)
                        {
                            if (in_grad is Tensor && t_in.dtype != TF_DataType.TF_RESOURCE)
                            {
                                in_grad.shape = t_in.shape;
                            }

                            _SetGrad(grads, t_in, in_grad);
                        }
                    }

                    // Update pending count for the inputs of op and enqueue ready ops.
                    _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state, xs);
                }
            });

            // Mirrors upstream TensorFlow's ending for _GradientsHelper: return the
            // gradient accumulated for each x (assumes a _GetGrad helper alongside
            // the _SetGrad used above).
            return xs.Select(x => _GetGrad(grads, x)).ToArray();
        }
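To make the grads bookkeeping concrete, a self-contained sketch with plain doubles standing in for tensors (requires using System.Collections.Generic and System.Linq): each op output endpoint collects a list of incoming gradients, and lists longer than one are summed before the op's gradient function runs.

        // op name -> per-endpoint lists of received gradients
        var grads = new Dictionary<string, List<double>[]>();

        void SetGrad(string op, int endpoint, double g)
        {
            if (!grads.ContainsKey(op))
                grads[op] = new[] { new List<double>(), new List<double>() };
            grads[op][endpoint].Add(g);
        }

        SetGrad("mul", 0, 0.5); // gradient arriving from one consumer
        SetGrad("mul", 0, 1.5); // gradient arriving from another consumer

        // Aggregation: multiple incoming gradients are added together.
        double aggregated = grads["mul"][0].Sum(); // 2.0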
Example #12
        public Operation apply_gradients(Tuple <Tensor, RefVariable>[] grads_and_vars, Tensor global_step = null, string name = null)
        {
            // No DistributionStrategy case.
            var converted_grads_and_vars = new List <Tuple <Tensor, RefVariable, _OptimizableVariable> >();

            foreach (var(g, v) in grads_and_vars)
            {
                if (g != null)
                {
                    // Convert the grad to Tensor or IndexedSlices if necessary.
                    var gR = ops.convert_to_tensor_or_indexed_slices(g);
                    var p  = _get_processor(v);
                    converted_grads_and_vars.Add(new Tuple <Tensor, RefVariable, _OptimizableVariable>(gR, v, p));
                }
            }

            var var_list = converted_grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray();

            if (var_list.Length == 0)
            {
                throw new ValueError($"No gradients provided for any variable");
            }

            ops.init_scope();
            _create_slots(var_list);

            var update_ops = new List <Operation>();

            return(Python.with <ops.name_scope, Operation>(new ops.name_scope(name, Name), scope =>
            {
                name = scope;
                _prepare();

                foreach (var(grad, var, processor) in converted_grads_and_vars)
                {
                    if (grad == null)
                    {
                        continue;
                    }

                    var scope_name = var.op.name;
                    Python.with <ops.name_scope>(new ops.name_scope("update_" + scope_name), scope2 =>
                    {
                        update_ops.Add(processor.update_op(this, grad));
                    });
                }

                Operation apply_updates = null;
                if (global_step == null)
                {
                    apply_updates = _finish(update_ops.ToArray(), name);
                }
                else
                {
                    // Incrementing global_step alongside the update is not implemented here yet.
                }

                if (!tf.context.executing_eagerly())
                {
                    var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List <object>;
                    if (!train_op.Contains(apply_updates))
                    {
                        train_op.Add(apply_updates);
                    }
                }

                return apply_updates;
            }));
        }
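A usage sketch (hedged: GradientDescentOptimizer and minimize are assumptions about this port's train API; minimize computes gradients and then calls apply_gradients):

        var optimizer = tf.train.GradientDescentOptimizer(0.01f);
        var train_op = optimizer.minimize(loss); // `loss` is an existing scalar tensor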
Example #13
        private void _init_from_args(object initial_value,
                                     bool trainable            = true,
                                     List <string> collections = null,
                                     bool validate_shape       = true,
                                     string caching_device     = "",
                                     string name       = null,
                                     TF_DataType dtype = TF_DataType.DtInvalid)
        {
            if (initial_value is null)
            {
                throw new ValueError("initial_value must be specified.");
            }

            var init_from_fn = false;

            if (collections == null)
            {
                collections = new List <string> {
                    ops.GraphKeys.GLOBAL_VARIABLES
                };
            }

            // Store the graph key so optimizers know how to only retrieve variables from
            // this graph.
            _graph_key = ops.get_default_graph()._graph_key;

            _trainable = trainable;
            // Upstream TensorFlow only registers the variable as trainable when
            // trainable is true.
            if (trainable && !collections.Contains(ops.GraphKeys.TRAINABLE_VARIABLES))
            {
                collections.Add(ops.GraphKeys.TRAINABLE_VARIABLES);
            }

            ops.init_scope();
            var values = init_from_fn ? new List <object>() : new List <object> {
                initial_value
            };

            Python.with <ops.name_scope>(new ops.name_scope(name, "Variable", values), scope =>
            {
                if (init_from_fn)
                {
                }
                // Or get the initial value from a Tensor or Python object.
                else
                {
                    _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");

                    var shape = _initial_value.shape;
                    dtype     = _initial_value.dtype;
                    _variable = gen_state_ops.variable_v2(shape, dtype.as_base_dtype(), scope);
                }

                // Manually overrides the variable's shape with the initial value's.
                if (validate_shape)
                {
                    var initial_value_shape = _initial_value.shape;
                    // TODO: the shape override/validation itself is not performed yet.
                }

                // If 'initial_value' makes use of other variables, make sure we don't
                // have an issue if these other variables aren't initialized first by
                // using their initialized_value() method.
                var _initial_value2 = _try_guard_against_uninitialized_dependencies(_initial_value);

                _initializer_op = gen_state_ops.assign(_variable, _initial_value2, validate_shape).op;

                if (!String.IsNullOrEmpty(caching_device))
                {
                    // Caching devices are not supported here yet.
                }
                else
                {
                    ops.colocate_with(_initializer_op);

                    _snapshot = gen_array_ops.identity(_variable, name: "read");
                }

                ops.add_to_collections(collections, this);
            });
        }
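A hedged sketch of the path into _init_from_args (tf.Variable, tf.Session, and tf.global_variables_initializer are assumptions about this port's surface API):

        var w = tf.Variable(0.5f, name: "w"); // runs _init_from_args above
        using (var sess = tf.Session())
        {
            sess.run(tf.global_variables_initializer()); // executes the Assign initializer
        }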
        public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null)
        {
            Dictionary <string, object> keywords = ConvertToDict(args);
            var g      = ops.get_default_graph();
            var op_def = g.GetOpDef(op_type_name);

            // Default name if not specified.
            if (String.IsNullOrEmpty(name))
            {
                name = op_type_name;
            }

            // Check for deprecation
            if (op_def.Deprecation != null && op_def.Deprecation.Version > 0)
            {
                // Deprecation warnings are not surfaced here yet.
            }

            var default_type_attr_map = new Dictionary <string, object>();

            foreach (var attr_def in op_def.Attr)
            {
                if (attr_def.Type != "type")
                {
                    continue;
                }
                var key = attr_def.Name;
                if (attr_def.DefaultValue != null)
                {
                    default_type_attr_map[key] = attr_def.DefaultValue.Type;
                }
            }

            var     attrs       = new Dictionary <string, object>();
            var     inputs      = new List <Tensor>();
            var     input_types = new List <TF_DataType>();
            dynamic values      = null;

            return(Python.with <ops.name_scope, Operation>(new ops.name_scope(name), scope =>
            {
                var inferred_from = new Dictionary <string, object>();
                var base_types = new List <TF_DataType>();
                var types = new List <TF_DataType>();

                // Perform input type inference
                foreach (var input_arg in op_def.InputArg)
                {
                    var input_name = input_arg.Name;

                    if (keywords.ContainsKey(input_name))
                    {
                        values = keywords[input_name];
                    }
                    else if (keywords.ContainsKey(input_name + "_"))
                    {
                        input_name += "_";
                        values = keywords[input_name];
                    }
                    else
                    {
                        throw new TypeError("No argument for input " + input_name);
                    }

                    // Goals:
                    // * Convert values to Tensors if it contains constants.
                    // * Verify that values is a list if that matches the input_arg's
                    // type.
                    // * If the input_arg's type is determined by attrs, either set
                    // those attrs and validate those attr values are legal (if
                    // they have not yet been set) or validate the input matches
                    // the type indicated by the attrs (if they have already been
                    // inferred via an earlier input).
                    // * If the input_arg has an explicit type, make sure the input
                    // conforms.

                    DataType dtype = DataType.DtInvalid;
                    DataType default_dtype = DataType.DtInvalid;

                    if (_IsListParameter(input_arg))
                    {
                        if (!_IsListValue(values))
                        {
                            throw new TypeError($"Expected list for '{input_name}' argument to '{op_type_name}' Op, not {values}.");
                        }
                        if (input_arg.Type != DataType.DtInvalid)
                        {
                            dtype = input_arg.Type;
                        }
                        else if (!String.IsNullOrEmpty(input_arg.NumberAttr))
                        {
                            if (attrs.ContainsKey(input_arg.TypeAttr))
                            {
                                dtype = (DataType)attrs[input_arg.TypeAttr];
                            }
                            else if (values is Tensor[] values1)
                            {
                                dtype = values1[0].dtype.as_datatype_enum();
                            }

                            if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr))
                            {
                                default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
                            }
                        }

                        if (input_arg.IsRef && dtype != DataType.DtInvalid)
                        {
                            dtype = dtype.as_base_dtype();
                        }

                        values = ops.internal_convert_n_to_tensor(values,
                                                                  name: input_arg.Name,
                                                                  dtype: dtype.as_tf_dtype(),
                                                                  preferred_dtype: default_dtype.as_tf_dtype(),
                                                                  as_ref: input_arg.IsRef);
                    }
                    else
                    {
                        if (input_arg.Type != DataType.DtInvalid)
                        {
                            dtype = input_arg.Type;
                        }
                        else if (attrs.ContainsKey(input_arg.TypeAttr))
                        {
                            dtype = (DataType)attrs[input_arg.TypeAttr];
                        }
                        else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr))
                        {
                            default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
                        }

                        values = ops.internal_convert_to_tensor(values,
                                                                name: input_name,
                                                                dtype: dtype.as_tf_dtype(),
                                                                as_ref: input_arg.IsRef,
                                                                preferred_dtype: default_dtype.as_tf_dtype());

                        //if (!String.IsNullOrEmpty(input_arg.TypeAttr))
                        //attrs[input_arg.TypeAttr] = values.dtype;

                        values = new Tensor[] { values };
                    }

                    if (values is Tensor[] values2)
                    {
                        types = values2.Select(x => x.dtype).ToList();
                        inputs.AddRange(values2);
                        base_types = values2.Select(x => x.dtype.as_base_dtype()).ToList();
                    }
                    else
                    {
                        throw new NotImplementedException("_IsListParameter");
                    }

                    SetAttrs(op_type_name,
                             input_arg,
                             op_def,
                             attrs,
                             inferred_from,
                             types,
                             base_types,
                             input_types,
                             values);
                }

                // Process remaining attrs
                foreach (var attr in op_def.Attr)
                {
                    if (keywords.ContainsKey(attr.Name))
                    {
                        attrs[attr.Name] = keywords[attr.Name];
                    }
                }

                // Convert attr values to AttrValue protos.
                var attr_protos = new Dictionary <string, AttrValue>();
                foreach (var attr_def in op_def.Attr)
                {
                    var key = attr_def.Name;

                    if (!attrs.ContainsKey(key))
                    {
                        // Warn and skip missing attrs instead of throwing on the
                        // dictionary lookup below.
                        Console.WriteLine($"_apply_op_helper: key '{key}' is not found in '{op_def.Name}' operation's attr_def.");
                        continue;
                    }

                    var value = attrs[key];

                    attr_protos[key] = SetAttrValue(op_def, attr_def, value);
                }

                attrs.Clear();

                // Determine output types (possibly using attrs)
                var output_types = new List <TF_DataType>();

                foreach (var arg in op_def.OutputArg)
                {
                    types = new List <TF_DataType>();
                    if (!string.IsNullOrEmpty(arg.NumberAttr))
                    {
                        // N-repeated outputs are not handled here yet.
                    }
                    else if (!string.IsNullOrEmpty(arg.TypeAttr))
                    {
                        types = new List <TF_DataType>()
                        {
                            (TF_DataType)attr_protos[arg.TypeAttr].Type
                        };
                    }

                    if (arg.IsRef)
                    {
                        types = types.Select(x => x.as_ref()).ToList();
                    }

                    output_types.AddRange(types);
                }

                // Add Op to graph
                var op = g.create_op(op_type_name, inputs.ToArray(), output_types.ToArray(),
                                     name: scope,
                                     input_types: input_types.ToArray(),
                                     attrs: attr_protos,
                                     op_def: op_def);

                return op;
            }));
        }
Example #15
        public Operation _apply_op_helper(string op_type_name, string name = "", dynamic args = null)
        {
            var keywords = ConvertToDict(args);
            var g        = ops.get_default_graph();
            var op_def   = g.GetOpDef(op_type_name);

            // Default name if not specified.
            if (String.IsNullOrEmpty(name))
            {
                name = op_type_name;
            }

            // Check for deprecation
            if (op_def.Deprecation != null && op_def.Deprecation.Version > 0)
            {
                // Deprecation warnings are not surfaced here yet.
            }

            var default_type_attr_map = new Dictionary <string, object>();

            foreach (var attr_def in op_def.Attr)
            {
                if (attr_def.Type != "type")
                {
                    continue;
                }
                var key = attr_def.Name;
                if (attr_def.DefaultValue != null)
                {
                    default_type_attr_map[key] = attr_def.DefaultValue.Type;
                }
            }

            var attrs       = new Dictionary <string, object>();
            var inputs      = new List <Tensor>();
            var input_types = new List <TF_DataType>();

            Operation op = null;

            Python.with <ops.name_scope>(new ops.name_scope(name), scope =>
            {
                // Perform input type inference
                foreach (var input_arg in op_def.InputArg)
                {
                    var input_name = input_arg.Name;
                    // Wrap raw numeric arguments in a constant op so they become tensors.
                    if (keywords[input_name] is double double_value)
                    {
                        keywords[input_name] = constant_op.constant(double_value, input_name);
                    }

                    if (keywords[input_name] is Tensor value)
                    {
                        // The key is known to exist at this point, so add the input directly.
                        inputs.Add(value);

                        if (!String.IsNullOrEmpty(input_arg.TypeAttr))
                        {
                            attrs[input_arg.TypeAttr] = value.dtype;
                        }

                        if (input_arg.IsRef)
                        {
                            // Ref inputs keep their ref dtype; no base type is recorded.
                        }
                        else
                        {
                            var base_type = value.dtype.as_base_dtype();

                            input_types.Add(base_type);
                        }
                    }
                }

                // Process remaining attrs
                foreach (var attr in op_def.Attr)
                {
                    if (keywords.ContainsKey(attr.Name))
                    {
                        attrs[attr.Name] = keywords[attr.Name];
                    }
                }

                // Convert attr values to AttrValue protos.
                var attr_protos = new Dictionary <string, AttrValue>();
                foreach (var attr_def in op_def.Attr)
                {
                    var key        = attr_def.Name;
                    var value      = attrs[key];
                    var attr_value = new AttrValue();

                    switch (attr_def.Type)
                    {
                    case "string":
                        attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
                        break;

                    case "type":
                        attr_value.Type = _MakeType((TF_DataType)value, attr_def);
                        break;

                    case "bool":
                        attr_value.B = (bool)value;
                        break;

                    case "shape":
                        attr_value.Shape = value == null ?
                                           attr_def.DefaultValue.Shape :
                                           tensor_util.as_shape((long[])value);
                        break;

                    default:
                        throw new InvalidDataException($"attr_def.Type {attr_def.Type}");
                    }

                    attr_protos[key] = attr_value;
                }

                // Determine output types (possibly using attrs)
                var output_types = new List <TF_DataType>();

                foreach (var arg in op_def.OutputArg)
                {
                    if (!String.IsNullOrEmpty(arg.NumberAttr))
                    {
                        // N-repeated outputs are not handled here yet.
                    }
                    else if (!String.IsNullOrEmpty(arg.TypeAttr))
                    {
                        output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type);
                    }
                }

                // Add Op to graph
                op = g.create_op(op_type_name, inputs, output_types.ToArray(),
                                 name: scope,
                                 input_types: input_types.ToArray(),
                                 attrs: attr_protos,
                                 op_def: op_def);
            });

            return(op);
        }
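Generated op wrappers funnel through _apply_op_helper; a hedged sketch of what one looks like (the _op_def_lib instance name is an assumption):

        public static Tensor add(Tensor x, Tensor y, string name = "")
        {
            var op = _op_def_lib._apply_op_helper("Add", name, new { x, y });
            return op.outputs[0]; // Add has a single output
        }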