コード例 #1
0
        /// <summary>
        /// Traces <paramref name="func"/> into a new <c>FuncGraph</c> and returns a delegate that
        /// executes the traced function by name through <c>tf.Runner.TFE_Execute</c>.
        /// </summary>
        /// <param name="func">
        /// Function to trace. It is invoked once, during tracing, with an <c>int32</c> placeholder.
        /// </param>
        /// <returns>
        /// A delegate that runs the traced function on <c>tf.Context.DeviceName</c> with its single
        /// argument and returns the first element of the execution result.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is null.</exception>
        public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func)
        {
            if (func == null)
            {
                throw new ArgumentNullException(nameof(func));
            }

            // The GUID suffix keeps the function name unique when the same method is traced repeatedly.
            string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

            var graph = new FuncGraph(func_name);
            graph.as_default();
            try
            {
                // NOTE(review): the placeholder dtype is fixed to int32, so func is only traced for
                // int32 inputs — confirm this is intended.
                var input = tf.placeholder(tf.int32);
                var output = func(input);

                // OfType skips any node that is not an Operation instead of passing it on as null.
                var opers = graph._nodes_by_name.Values.OfType<Operation>().ToArray();

                graph.ToGraph(opers,
                              new[] { input },
                              new[] { output },
                              null);
            }
            finally
            {
                // Leave the graph entered by as_default() even when tracing throws, so the
                // FuncGraph does not presumably stay active for unrelated code.
                graph.Exit();
            }

            // The parameter is named `x` so it does not shadow the traced `input` placeholder.
            return (Tensor x) =>
            {
                var result = tf.Runner.TFE_Execute(tf.Context,
                                                   tf.Context.DeviceName,
                                                   func_name,
                                                   new[] { x },
                                                   null,
                                                   1);
                return result[0];
            };
        }
コード例 #2
0
        /// <summary>
        /// Copies the operations that produce <paramref name="init_tensors"/>, together with the
        /// captured <paramref name="sources"/>, into <paramref name="graph"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the traversal over each op's inputs is not ported (the original loop over
        /// <c>op.inputs</c> was empty). Only the producing ops of <paramref name="init_tensors"/>
        /// are collected for copying. TODO: port the recursive input walk.
        /// </remarks>
        /// <param name="init_tensors">Tensors whose producing operations are copied.</param>
        /// <param name="graph">Destination graph; it is entered for the duration of the copy.</param>
        /// <param name="sources">
        /// Tensors treated as sources. The caller's list is not modified; a private copy is extended
        /// with the tensors returned by <c>map_subgraph</c>. <c>null</c> is treated as empty.
        /// </param>
        /// <param name="add_sources">Forwarded to <c>map_subgraph</c>.</param>
        /// <param name="handle_captures">Forwarded to <c>_copy_source</c>.</param>
        /// <param name="base_graph">
        /// Graph to copy from; defaults to the graph of the first element of
        /// <paramref name="init_tensors"/>. When it is a <c>FuncGraph</c>, its captures are used to
        /// map internal captures back to their external tensors.
        /// </param>
        /// <param name="op_map">Existing map to extend; a new one is created when <c>null</c>.</param>
        /// <returns>
        /// <paramref name="op_map"/> after <c>_copy_source</c> / <c>_copy_non_source</c> have run —
        /// presumably holding original-to-copied op mappings (confirm against those helpers).
        /// </returns>
        public static Dictionary<ITensorOrOperation, Operation> lift_to_graph(Tensors init_tensors,
                                                                              FuncGraph graph,
                                                                              List<Tensor> sources,
                                                                              bool add_sources = false,
                                                                              bool handle_captures = false,
                                                                              Graph base_graph = null,
                                                                              Dictionary<ITensorOrOperation, Operation> op_map = null)
        {
            base_graph = base_graph ?? init_tensors[0].graph;
            op_map = op_map ?? new Dictionary<ITensorOrOperation, Operation>();

            // Work on a private copy: this method both extends and shrinks the list below.
            sources = sources == null ? new List<Tensor>() : new List<Tensor>(sources);
            var visited_ops = sources.Select(x => x.op).ToList();

            foreach (var init_tensor in init_tensors)
            {
                var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
                sources.AddRange(src);
            }

            var ops_to_copy = new List<Operation>();
            var marked_ops = new List<Operation>();
            var ops_to_visit = new Stack<Operation>(init_tensors.Select(x => x.op));
            var unvisited_ops = new List<Operation>(ops_to_visit);

            while (unvisited_ops.Count > 0)
            {
                while (ops_to_visit.Count > 0)
                {
                    var op = ops_to_visit.Pop();
                    if (marked_ops.Contains(op))
                    {
                        continue;
                    }
                    marked_ops.Add(op);
                    ops_to_copy.Add(op);
                    // TODO: push qualifying input ops of `op` onto ops_to_visit / unvisited_ops.
                    // This is not ported yet, so only the ops of init_tensors are collected.
                }
                unvisited_ops.difference_update(marked_ops);
                if (unvisited_ops.Count > 0)
                {
                    ops_to_visit.Push(unvisited_ops.Last());
                }
            }

            // When lifting from one FuncGraph to another, we will need to capture the
            // relevant tensors as well.
            var inverse_captures = new Dictionary<Tensor, Tensor>();

            // Without a FuncGraph base there are no captures to replay. Starting from an empty
            // array (rather than null) keeps the capture loop below from throwing.
            Tensor[] internal_captures = Array.Empty<Tensor>();
            if (base_graph is FuncGraph base_func_graph)
            {
                var captures = base_func_graph.captures;
                foreach (var (external_capture, internal_capture) in captures)
                {
                    inverse_captures[internal_capture] = external_capture;
                }
                internal_captures = base_func_graph.internal_captures;
            }

            graph.as_default();
            try
            {
                var source_ops = new List<Operation>();

                // Add the sources in the same order as the original graph.
                foreach (var s in internal_captures)
                {
                    // Remove returns false when s is not a source, so it both tests and consumes it.
                    if (sources.Remove(s))
                    {
                        source_ops.Add(s.op);
                        _copy_source(s: s,
                                     graph: graph,
                                     op_map: op_map,
                                     handle_captures: handle_captures,
                                     inverse_captures: inverse_captures,
                                     base_graph: base_graph);
                    }
                }

                foreach (var op in reversed(ops_to_copy))
                {
                    if (source_ops.Contains(op) || op_map.ContainsKey(op))
                    {
                        continue;
                    }
                    _copy_non_source(op, graph, op_map, base_graph);
                }
            }
            finally
            {
                // Always leave the destination graph, even when a copy step throws.
                graph.Exit();
            }

            return op_map;
        }