예제 #1
0
        /// <summary>
        /// Creates the kernel and bias weights for this cell and marks the layer as built.
        /// </summary>
        /// <param name="input_shape">Shape of the incoming tensor; its last dimension is taken as the input depth.</param>
        protected override void build(TensorShape input_shape)
        {
            // Output axis of both weights: four blocks of _num_units
            // (presumably one block per gate — confirm against call()).
            var gate_width = 4 * _num_units;

            // Rows cover the input depth plus _num_units extra rows
            // (assumes call() concatenates inputs with the hidden state — TODO confirm).
            var kernel_rows = input_shape.dims.Last() + _num_units;

            _kernel = add_weight(_WEIGHTS_VARIABLE_NAME, shape: new[] { kernel_rows, gate_width });
            _bias = add_weight(_BIAS_VARIABLE_NAME, shape: new[] { gate_width }, initializer: tf.zeros_initializer);

            built = true;
        }
예제 #2
0
 /// <summary>
 /// Looks up <paramref name="ids"/> in the embedding variable <paramref name="params"/>
 /// by delegating to <c>_embedding_lookup_and_transform</c>.
 /// </summary>
 /// <param name="params">Embedding variable to gather rows from.</param>
 /// <param name="ids">Indices to look up.</param>
 /// <param name="partition_strategy">Forwarded unchanged.</param>
 /// <param name="name">Optional name; forwarded unchanged.</param>
 /// <param name="validate_indices">Accepted for API compatibility only; it is not forwarded.</param>
 /// <param name="max_norm">Forwarded unchanged.</param>
 /// <returns>The tensor produced by <c>_embedding_lookup_and_transform</c>.</returns>
 public static Tensor embedding_lookup(VariableV1 @params, Tensor ids,
                                       string partition_strategy = "mod",
                                       string name = null,
                                       bool validate_indices = true,
                                       string max_norm = null)
     => _embedding_lookup_and_transform(
            @params: @params,
            ids: ids,
            partition_strategy: partition_strategy,
            name: name,
            max_norm: max_norm);
예제 #3
0
        /// <summary>
        /// Creates the kernel and bias weights for this cell and marks the layer as built.
        /// </summary>
        /// <param name="inputs_shape">Shape of the incoming tensor; its last dimension is taken as the input depth.</param>
        protected override void build(TensorShape inputs_shape)
        {
            // Index of the innermost axis of the input shape.
            var last_axis = inputs_shape.ndim - 1;
            var input_depth = inputs_shape.dims[last_axis];

            // Kernel rows: input depth plus _num_units; columns: _num_units.
            var kernel_shape = new[] { input_depth + _num_units, _num_units };
            _kernel = add_weight(_WEIGHTS_VARIABLE_NAME, shape: kernel_shape);

            // One zero-initialized bias entry per unit.
            var bias_shape = new[] { _num_units };
            _bias = add_weight(_BIAS_VARIABLE_NAME, shape: bias_shape, initializer: tf.zeros_initializer);

            built = true;
        }
예제 #4
0
        /// <summary>
        /// Helper function for embedding_lookup and _compute_sampled_logits.
        /// Gathers the entries of <paramref name="params"/> selected by <paramref name="ids"/>,
        /// passes them through <c>_clip</c> with <paramref name="max_norm"/>, and returns an
        /// identity of the result.
        /// </summary>
        /// <param name="params">The embedding variable to gather from. Being a single
        /// <c>VariableV1</c>, it is always exactly one partition.</param>
        /// <param name="ids">Indices to gather; converted with <c>ops.convert_to_tensor</c> under the name "ids".</param>
        /// <param name="partition_strategy">Accepted for API compatibility; not used, since only
        /// the single-partition path exists here.</param>
        /// <param name="name">Name for the enclosing name scope; "embedding_lookup" is the default scope name.</param>
        /// <param name="max_norm">Forwarded to <c>_clip</c> (presumably null disables clipping — confirm against _clip).</param>
        /// <returns>The gathered, clipped embeddings.</returns>
        public static Tensor _embedding_lookup_and_transform(VariableV1 @params,
                                                             Tensor ids,
                                                             string partition_strategy = "mod",
                                                             string name     = null,
                                                             string max_norm = null)
        {
            return tf_with(ops.name_scope(name, "embedding_lookup", new { @params, ids }), scope =>
            {
                name = scope;
                ids = ops.convert_to_tensor(ids, name: "ids");

                // A single VariableV1 is one partition, so the partitioned lookup path
                // (previously an unreachable NotImplementedException) is not needed.
                var gather = array_ops.gather(@params, ids, name: name);
                var result = _clip(gather, ids, max_norm);

                // Return through a fresh identity op rather than the clip output itself.
                return array_ops.identity(result);
            });
        }
예제 #5
0
 /// <summary>
 /// Reports whether <paramref name="var"/> is a <see cref="ResourceVariable"/>.
 /// </summary>
 /// <param name="var">Variable to inspect; a null reference yields false.</param>
 /// <returns>True when the runtime type is ResourceVariable or derives from it.</returns>
 public static bool is_resource_variable(VariableV1 var)
     => var is ResourceVariable;
예제 #6
0
        /// <summary>
        /// Imports the graph contained in <paramref name="meta_graph_or_file"/> into the default
        /// graph, restores its collections, and returns the imported global variables together
        /// with the requested return elements.
        /// </summary>
        /// <param name="meta_graph_or_file">The MetaGraphDef to import.</param>
        /// <param name="clear_devices">When true, clears the device of every node before importing.</param>
        /// <param name="import_scope">Name scope for the imported nodes and restored collections;
        /// made unique in the graph via <c>graph.unique_name</c>. Null is treated as "".</param>
        /// <param name="input_map">Forwarded to <c>importer.import_graph_def</c>.</param>
        /// <param name="unbound_inputs_col_name">Name of the unbound-inputs collection. That
        /// collection is skipped when restoring, and its presence is currently unsupported.</param>
        /// <param name="return_elements">Forwarded to <c>importer.import_graph_def</c>.</param>
        /// <returns>A map from variable name (import prefix stripped) to variable, and the
        /// elements returned by the import.</returns>
        /// <exception cref="NotImplementedException">The meta graph has an unbound-inputs
        /// collection, or a collection of an unsupported kind.</exception>
        public static (Dictionary <string, VariableV1>, ITensorOrOperation[]) import_scoped_meta_graph_with_return_elements(MetaGraphDef meta_graph_or_file,
                                                                                                                            bool clear_devices  = false,
                                                                                                                            string import_scope = "",
                                                                                                                            Dictionary <string, Tensor> input_map = null,
                                                                                                                            string unbound_inputs_col_name        = "unbound_inputs",
                                                                                                                            string[] return_elements = null)
        {
            var meta_graph_def = meta_graph_or_file;

            // Graphs with unbound inputs need input_map validation that is not ported yet.
            if (!string.IsNullOrEmpty(unbound_inputs_col_name) &&
                meta_graph_def.CollectionDef.ContainsKey(unbound_inputs_col_name))
            {
                throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
            }

            // Sets graph to default graph if it's not passed in.
            var graph = ops.get_default_graph();

            // Gathers the list of nodes we are interested in. MetaInfoDef is a protobuf
            // message field, which is null when the MetaGraphDef leaves it unset.
            OpList producer_op_list = meta_graph_def.MetaInfoDef?.StrippedOpList;
            var input_graph_def = meta_graph_def.GraphDef;

            // Remove all the explicit device specifications for this node. This helps to
            // make the graph more portable.
            if (clear_devices)
            {
                foreach (var node in input_graph_def.Node)
                {
                    node.Device = "";
                }
            }

            // Derive the prefix from the caller's import_scope so a non-empty scope actually
            // scopes the import. The same prefix is used for the imported nodes and for
            // every object restored from the collections below.
            var scope_to_prepend_to_names = graph.unique_name(import_scope ?? "", mark_as_used: false);
            var imported_return_elements  = importer.import_graph_def(input_graph_def,
                                                                      name: scope_to_prepend_to_names,
                                                                      input_map: input_map,
                                                                      producer_op_list: producer_op_list,
                                                                      return_elements: return_elements);

            // Restores all the other collections. Variables are cached by their serialized
            // VariableDef so a variable listed in several collections is created only once.
            var variable_objects = new Dictionary <ByteString, VariableV1>();

            foreach (var col in meta_graph_def.CollectionDef.OrderBy(x => x.Key))
            {
                // Don't add unbound_inputs to the new graph.
                if (col.Key == unbound_inputs_col_name)
                {
                    continue;
                }

                switch (col.Value.KindCase)
                {
                case KindOneofCase.NodeList:
                    foreach (var value in col.Value.NodeList.Value)
                    {
                        var col_op = graph.as_graph_element(ops.prepend_name_scope(value, scope_to_prepend_to_names));
                        graph.add_to_collection(col.Key, col_op);
                    }
                    break;

                case KindOneofCase.BytesList:
                    if (tf.GraphKeys._VARIABLE_COLLECTIONS.Contains(col.Key))
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            if (!variable_objects.TryGetValue(value, out var variable))
                            {
                                var proto = VariableDef.Parser.ParseFrom(value);
                                if (proto.IsResource)
                                {
                                    variable = new ResourceVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                else
                                {
                                    variable = new RefVariable(variable_def: proto, import_scope: scope_to_prepend_to_names);
                                }
                                variable_objects[value] = variable;
                            }
                            graph.add_to_collection(col.Key, variable);
                        }
                    }
                    else
                    {
                        foreach (var value in col.Value.BytesList.Value)
                        {
                            switch (col.Key)
                            {
                            case "cond_context":
                            {
                                var proto       = CondContextDef.Parser.ParseFrom(value);
                                var condContext = new CondContext().from_proto(proto, scope_to_prepend_to_names);
                                graph.add_to_collection(col.Key, condContext);
                            }
                            break;

                            case "while_context":
                            {
                                var proto        = WhileContextDef.Parser.ParseFrom(value);
                                var whileContext = new WhileContext().from_proto(proto, scope_to_prepend_to_names);
                                graph.add_to_collection(col.Key, whileContext);
                            }
                            break;

                            default:
                                // Unsupported bytes-list collection: skip this entry.
                                Console.WriteLine("import_scoped_meta_graph_with_return_elements");
                                continue;
                            }
                        }
                    }

                    break;

                default:
                    throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
                }
            }

            var variables = graph.get_collection <VariableV1>(tf.GraphKeys.GLOBAL_VARIABLES,
                                                              scope: scope_to_prepend_to_names);
            var var_list = new Dictionary <string, VariableV1>();

            variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);

            return(var_list, imported_return_elements);
        }
예제 #7
0
        /// <summary>
        /// Creates a new variable named <paramref name="name"/> and records it in <c>_vars</c>.
        /// </summary>
        /// <param name="name">Name of the variable; also its key in <c>_vars</c>.</param>
        /// <param name="shape">Shape passed to the initializer.</param>
        /// <param name="dtype">Element type; its base dtype is used for the variable.</param>
        /// <param name="initializer">Initializer for the initial value. When null, the Glorot
        /// uniform initializer is used for floating-point dtypes.</param>
        /// <param name="reuse">Not supported yet: any existing variable with this name throws.</param>
        /// <param name="trainable">Forwarded to the variable creator.</param>
        /// <param name="collections">Forwarded to the variable creator.</param>
        /// <param name="validate_shape">Forwarded to the variable creator.</param>
        /// <param name="use_resource">Currently not consulted.</param>
        /// <param name="synchronization">Forwarded to the variable creator.</param>
        /// <param name="aggregation">Forwarded to the variable creator.</param>
        /// <returns>The newly created variable.</returns>
        /// <exception cref="NotImplementedException">A variable named <paramref name="name"/> already exists.</exception>
        /// <exception cref="ArgumentException"><paramref name="initializer"/> is null and
        /// <paramref name="dtype"/> is not floating-point.</exception>
        private VariableV1 _get_single_variable(string name,
                                                TensorShape shape        = null,
                                                TF_DataType dtype        = TF_DataType.DtInvalid,
                                                IInitializer initializer = null,
                                                bool reuse                = false,
                                                bool?trainable            = null,
                                                List <string> collections = null,
                                                bool validate_shape       = false,
                                                bool?use_resource         = null,
                                                VariableSynchronization synchronization = VariableSynchronization.Auto,
                                                VariableAggregation aggregation         = VariableAggregation.None)
        {
            if (_vars.ContainsKey(name))
            {
                // Returning (reuse) or rejecting (!reuse) an existing variable is not ported yet.
                throw new NotImplementedException("_get_single_variable");
            }

            // Pick a default initializer when none was supplied.
            if (initializer == null)
            {
                if (!dtype.is_floating())
                {
                    // Fail here with a clear message. Otherwise the null initializer would only
                    // surface later, when init_val runs, as a NullReferenceException.
                    throw new ArgumentException(
                        $"An initializer for variable {name} of {dtype} is required.",
                        nameof(initializer));
                }

                initializer = tf.glorot_uniform_initializer;
            }

            // Create the variable.
            ops.init_scope();

            Func <Tensor> init_val       = () => initializer.call(shape, dtype);
            var           variable_dtype = dtype.as_base_dtype();

            VariableV1 v = variable_scope.default_variable_creator(init_val,
                                                                   name: name,
                                                                   trainable: trainable,
                                                                   collections: collections,
                                                                   dtype: variable_dtype,
                                                                   validate_shape: validate_shape,
                                                                   synchronization: synchronization,
                                                                   aggregation: aggregation);

            _vars[name] = v;

            return v;
        }