コード例 #1
0
ファイル: importer.cs プロジェクト: sharwell/TensorFlow.NET
        /// <summary>
        /// Builds the array of return elements for an import, in the same order as
        /// <paramref name="requested_return_elements"/>.
        /// </summary>
        /// <param name="requested_return_elements">
        /// Names requested by the caller. A name containing ':' is treated as a tensor name.
        /// </param>
        /// <param name="graph">Graph used to resolve each returned output via <c>get_tensor_by_tf_output</c>.</param>
        /// <param name="results">Import results supplying <c>return_tensors</c> and <c>return_opers</c>.</param>
        /// <returns>One resolved element per requested name.</returns>
        /// <exception cref="NotImplementedException">
        /// Thrown for any requested name without ':' (operation names are not supported yet).
        /// </exception>
        private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements,
                                                                  Graph graph,
                                                                  TF_ImportGraphDefResults results)
        {
            var tensor_outputs = results.return_tensors;
            // Read but not consumed yet: the operation branch below is still unimplemented.
            var oper_outputs = results.return_opers;

            var gathered = new List<ITensorOrOperation>();
            int next_tensor = 0;
            int next_oper   = 0;

            for (int i = 0; i < requested_return_elements.Length; i++)
            {
                if (!requested_return_elements[i].Contains(":"))
                {
                    // TODO: gathered.append(graph._get_operation_by_tf_operation(oper_outputs[next_oper]));
                    throw new NotImplementedException("_GatherReturnElements");
                }

                // Tensor outputs are consumed sequentially, one per tensor-style name.
                var tensor = graph.get_tensor_by_tf_output(tensor_outputs[next_tensor]);
                gathered.append(tensor);
                next_tensor++;
            }

            return gathered.ToArray();
        }
コード例 #2
0
        /// <summary>
        /// Resolves the requested return elements of an import into tensors, keeping the
        /// order of <paramref name="requested_return_elements"/>.
        /// </summary>
        /// <param name="requested_return_elements">Requested names; only names containing ':' are handled.</param>
        /// <param name="graph">Graph whose <c>get_tensor_by_tf_output</c> resolves each returned output.</param>
        /// <param name="results">Import results exposing <c>return_tensors</c> and <c>return_opers</c>.</param>
        /// <returns>An array with one entry per requested name.</returns>
        /// <exception cref="NotImplementedException">Thrown when a requested name does not contain ':'.</exception>
        private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements,
            Graph graph,
            TF_ImportGraphDefResults results)
        {
            var returned_tensors = results.return_tensors;
            // Retrieved for the (not yet implemented) operation branch.
            var returned_opers = results.return_opers;

            var elements = new List<ITensorOrOperation>();
            int tensor_cursor = 0;
#pragma warning disable CS0219 // Variable is assigned but its value is never used
            int oper_cursor = 0;
#pragma warning restore CS0219 // Variable is assigned but its value is never used

            foreach (var element_name in requested_return_elements)
            {
                bool is_tensor_name = element_name.Contains(":");
                if (!is_tensor_name)
                {
                    // TODO: elements.append(graph._get_operation_by_tf_operation(returned_opers[oper_cursor]));
                    throw new NotImplementedException("_GatherReturnElements");
                }

                // Each tensor-style name takes the next entry of the returned outputs.
                elements.append(graph.get_tensor_by_tf_output(returned_tensors[tensor_cursor]));
                tensor_cursor += 1;
            }

            return elements.ToArray();
        }
コード例 #3
0
ファイル: importer.cs プロジェクト: sharwell/TensorFlow.NET
        /// <summary>
        /// Imports <paramref name="graph_def"/> into the default graph (<c>ops.get_default_graph()</c>)
        /// and gathers the requested return elements.
        /// </summary>
        /// <param name="graph_def">Graph definition to import; it is serialized to bytes for the native call.</param>
        /// <param name="input_map">Optional map of names to tensors, normalized by <c>_ProcessInputMapParam</c>.</param>
        /// <param name="return_elements">Optional names whose tensors should be returned.</param>
        /// <param name="name">Optional name passed to <c>ops.name_scope</c>; "import" is the default name.</param>
        /// <param name="producer_op_list">
        /// Optional op list; when not null, <c>_RemoveDefaultAttrs</c> is applied to the graph definition.
        /// </param>
        /// <returns>
        /// <c>null</c> when the processed <paramref name="return_elements"/> is null; otherwise the
        /// elements produced by <c>_GatherReturnElements</c>.
        /// </returns>
        public static ITensorOrOperation[] import_graph_def(GraphDef graph_def,
                                                            Dictionary<string, Tensor> input_map = null,
                                                            string[] return_elements = null,
                                                            string name = null,
                                                            OpList producer_op_list = null)
        {
            var registered_ops = op_def_registry.get_registered_ops();

            // Normalize every argument before anything touches the graph.
            graph_def       = _ProcessGraphDefParam(graph_def, registered_ops);
            input_map       = _ProcessInputMapParam(input_map);
            return_elements = _ProcessReturnElementsParam(return_elements);

            if (producer_op_list != null)
                _RemoveDefaultAttrs(registered_ops, producer_op_list, graph_def);

            string scope_prefix = "";
            var target_graph = ops.get_default_graph();

            tf_with(ops.name_scope(name, "import", input_map.Values), scope =>
            {
                // The scope value is used as-is; stripping its last character is intentionally disabled.
                scope_prefix = scope;

                // Generate any input map tensors inside the name scope.
                input_map = _ConvertInputMapValues(name, input_map);
            });

            TF_ImportGraphDefResults import_results = null;
            var serialized_graph = graph_def.ToByteString().ToArray();

            using (var buffer = c_api_util.tf_buffer(serialized_graph))
            using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
            using (var status = new Status())
            {
                _PopulateTFImportGraphDefOptions(scoped_options, scope_prefix, input_map, return_elements);

                // TODO: wrap the results in a disposable ImportGraphDefWithResults class.
                var raw_results = c_api.TF_GraphImportGraphDefWithResults(target_graph,
                                                                          buffer.Handle,
                                                                          scoped_options.Handle,
                                                                          status.Handle);
                import_results = new TF_ImportGraphDefResults(raw_results);

                // Presumably raises if the native import reported an error - confirm against Status.Check.
                status.Check(true);
            }

            _ProcessNewOps(target_graph);

            return return_elements == null
                ? null
                : _GatherReturnElements(return_elements, target_graph, import_results);
        }