/// <summary>
        /// Traces <paramref name="func"/> once against an int32 placeholder inside a new
        /// <c>FuncGraph</c> and returns a delegate that runs the result through
        /// <c>tf.Runner.TFE_Execute</c> under the generated function name.
        /// </summary>
        /// <param name="func">Single-input function to trace.</param>
        /// <returns>
        /// A delegate that forwards its argument to <c>tf.Runner.TFE_Execute</c> and returns the
        /// first element of the result.
        /// </returns>
        public Func <Tensor, Tensor> to_graph(Func <Tensor, Tensor> func)
        {
            // The Guid suffix keeps the generated name unique per call.
            string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

            var graph = new FuncGraph(func_name);

            graph.as_default();
            try
            {
                var input  = tf.placeholder(tf.int32);
                var output = func(input);

                var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();

                graph.ToGraph(opers,
                              new[] { input },
                              new[] { output },
                              null);
            }
            finally
            {
                // Leave the function graph even when tracing throws, so a failing
                // `func` cannot leave this graph entered for the caller.
                graph.Exit();
            }

            // The parameter is deliberately not called `input`, to avoid reusing the
            // name of the placeholder local used during tracing above.
            return((Tensor eager_input) =>
            {
                var result = tf.Runner.TFE_Execute(tf.Context,
                                                   tf.Context.DeviceName,
                                                   func_name,
                                                   new[] { eager_input },
                                                   null,
                                                   1);
                return result[0];
            });
        }
        /// <summary>
        /// Two-input overload: traces <paramref name="func"/> against two int32 placeholders
        /// inside a <c>FuncGraph</c> and returns a delegate that runs the result through
        /// <c>tf.Runner.TFE_Execute</c> under the generated function name.
        /// </summary>
        /// <param name="func">Two-input function to trace.</param>
        /// <returns>
        /// A delegate that forwards both arguments to <c>tf.Runner.TFE_Execute</c> and returns
        /// the first element of the result.
        /// </returns>
        public Func <Tensor, Tensor, Tensor> to_graph(Func <Tensor, Tensor, Tensor> func)
        {
            var func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

            using (var func_graph = new FuncGraph(func_name))
            {
                var lhs    = tf.placeholder(tf.int32);
                var rhs    = tf.placeholder(tf.int32);
                var traced = func(lhs, rhs);

                // Every node recorded in the function graph takes part in the conversion.
                var graph_ops = func_graph._nodes_by_name.Values
                                .Select(node => node as Operation)
                                .ToArray();

                func_graph.ToGraph(graph_ops,
                                   new Operation[] { lhs, rhs },
                                   new Operation[] { traced },
                                   null);
            }

            // Executes the traced function for a pair of tensors.
            Tensor run(Tensor a, Tensor b)
            {
                var outputs = tf.Runner.TFE_Execute(tf.Context,
                                                    tf.Context.DeviceName,
                                                    func_name,
                                                    new[] { a, b },
                                                    null,
                                                    1);
                return outputs[0];
            }

            return run;
        }
        /// <summary>
        /// Creates a copy of <paramref name="op"/> in <paramref name="graph"/> (same type, attrs,
        /// output dtypes and name, with an empty input list) under a control dependency on
        /// <paramref name="op"/>, then records the original-to-copy mapping for the op and for
        /// each of its outputs in <paramref name="op_map"/>.
        /// </summary>
        /// <param name="op">Operation to copy.</param>
        /// <param name="graph">Destination graph in which the copy is created.</param>
        /// <param name="op_map">Mapping updated with the op and each of its outputs.</param>
        /// <param name="base_graph">Not used by this method.</param>
        static void _copy_non_source(Operation op, FuncGraph graph, Dictionary <ITensorOrOperation, Operation> op_map, Graph base_graph)
        {
            var       no_inputs = new Tensors();
            Operation clone     = null;

            tf_with(ops.control_dependencies(new object[] { op }), delegate
            {
                // Snapshot the attrs of the original node so the copy carries the same ones.
                var attr_values = op.node_def.Attr.ToDictionary(entry => entry.Key,
                                                                entry => entry.Value);

                clone = graph.create_op(op.type,
                                        no_inputs,
                                        dtypes: op.outputs.Select(t => t.dtype).ToArray(),
                                        attrs: attr_values,
                                        name: op.name);
            });

            op_map[op] = clone;

            // Pair each original output with the copy's output at the same position.
            var slot = 0;
            foreach (var original_output in op.outputs)
            {
                op_map[original_output] = clone.outputs[slot];
                slot++;
            }
        }
Exemple #4
0
        /// <summary>
        /// Runs on entry to the intercepted method. When <c>functions</c> already holds an entry
        /// for this call's generated name, that entry is invoked with the call's arguments, its
        /// result is stored as the return value and <c>FlowBehavior.Return</c> is requested.
        /// Otherwise a new <c>FuncGraph</c> is created, the original tensor arguments are kept
        /// in <c>originalInputs</c>, and they are replaced in <paramref name="args"/> by
        /// placeholders of the same dtype and shape.
        /// </summary>
        /// <param name="args">
        /// Execution context of the intercepted call; its arguments, return value and flow
        /// behavior may be modified.
        /// </param>
        public override void OnEntry(MethodExecutionArgs args)
        {
            func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}";

            // Parameterless methods have no first argument; treat that as "not a Tensors"
            // instead of indexing past the end of the argument array.
            var first_argument = args.Arguments.Length > 0 ? args.Arguments[0] : null;

            // Single lookup instead of ContainsKey followed by repeated indexer reads.
            if (functions.TryGetValue(func_name, out var cached_function))
            {
                if (first_argument is Tensors tensor_inputs)
                {
                    args.ReturnValue = cached_function(tensor_inputs.ToArray());
                }
                else
                {
                    args.ReturnValue = cached_function(args.Arguments.Select(x => x as Tensor).ToArray());
                }
                args.FlowBehavior = FlowBehavior.Return;
                return;
            }

            // No stored function for this name yet: create the graph it will be built in.
            graph = new FuncGraph(func_name);

            if (first_argument is Tensors inputs)
            {
                // A single Tensors argument: replace it with placeholders of matching dtype/shape.
                originalInputs = inputs;
                var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray();
                args.Arguments[0] = new Tensors(new_inputs);
            }
            else
            {
                originalInputs = new Tensors(args.Arguments.Length);
                // Replace each eager tensor argument with a placeholder; other arguments are left as-is.
                for (var i = 0; i < args.Arguments.Length; i++)
                {
                    if (args.Arguments[i] is EagerTensor tensor)
                    {
                        originalInputs[i] = tensor;
                        args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape);
                    }
                }
            }
        }
        /// <summary>
        /// Captures the external tensor that corresponds to source tensor <paramref name="s"/>
        /// into <paramref name="graph"/> and records the mapping for both the tensor and its op
        /// in <paramref name="op_map"/>.
        /// </summary>
        /// <param name="s">Source tensor to copy.</param>
        /// <param name="graph">Destination graph that captures the tensor.</param>
        /// <param name="op_map">Mapping updated with entries for <paramref name="s"/> and its op.</param>
        /// <param name="handle_captures">Must be true; any other case is not implemented.</param>
        /// <param name="inverse_captures">
        /// Lookup from a tensor to the tensor passed to <c>graph.capture</c>; must contain
        /// <paramref name="s"/>.
        /// </param>
        /// <param name="base_graph">Not used by this method.</param>
        /// <exception cref="NotImplementedException">
        /// <paramref name="handle_captures"/> is false, or <paramref name="s"/> has no entry in
        /// <paramref name="inverse_captures"/>.
        /// </exception>
        static void _copy_source(Tensor s,
                                 FuncGraph graph,
                                 Dictionary <ITensorOrOperation, Operation> op_map,
                                 bool handle_captures,
                                 Dictionary <Tensor, Tensor> inverse_captures,
                                 Graph base_graph)
        {
            // Only captured sources are supported; one TryGetValue replaces ContainsKey + indexer.
            if (!handle_captures || !inverse_captures.TryGetValue(s, out var external_capture))
            {
                throw new NotImplementedException(
                    $"{nameof(_copy_source)} only supports captured sources: " +
                    $"{nameof(handle_captures)} must be true and the source tensor must have an entry in {nameof(inverse_captures)}.");
            }

            Tensor copied_placeholder = graph.capture(external_capture, name: s.op.name);

            op_map[s] = copied_placeholder;
            // Add an entry for the op of the source tensor so that if there are any nodes
            // depending on that op via control dependencies it can work correctly.
            op_map[s.op] = copied_placeholder.op;
        }
        /// <summary>
        /// Copies the tensor and all its inputs recursively to the outer graph.
        /// </summary>
        /// <param name="init_tensors">Tensors whose ops are the starting points of the copy.</param>
        /// <param name="graph">Destination graph; entered for the duration of the copy and exited afterwards.</param>
        /// <param name="sources">
        /// Source tensors. This list is modified in place: tensors returned by
        /// <c>map_subgraph</c> are appended, and sources that are copied as captures are removed.
        /// </param>
        /// <param name="add_sources">Forwarded to <c>map_subgraph</c>.</param>
        /// <param name="handle_captures">Forwarded to <c>_copy_source</c> when copying captured sources.</param>
        /// <param name="base_graph">Graph being lifted from; defaults to the graph of the first tensor in <paramref name="init_tensors"/>.</param>
        /// <param name="op_map">Existing original-to-copy mapping to extend; a new one is created when null.</param>
        /// <returns>The mapping from original tensors/operations to their copies in <paramref name="graph"/>.</returns>
        public static Dictionary <ITensorOrOperation, Operation> lift_to_graph(Tensors init_tensors,
                                                                               FuncGraph graph,
                                                                               List <Tensor> sources,
                                                                               bool add_sources     = false,
                                                                               bool handle_captures = false,
                                                                               Graph base_graph     = null,
                                                                               Dictionary <ITensorOrOperation, Operation> op_map = null)
        {
            base_graph = base_graph ?? init_tensors[0].graph;
            op_map     = op_map ?? new Dictionary <ITensorOrOperation, Operation>();
            var visited_ops = sources.Select(x => x.op).ToList();

            foreach (var init_tensor in init_tensors)
            {
                var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
                sources.AddRange(src);
            }

            var ops_to_copy   = new List <Operation>();
            var marked_ops    = new List <Operation>();
            var ops_to_visit  = new Stack <Operation>(init_tensors.Select(x => x.op));
            var unvisited_ops = new List <Operation>(ops_to_visit.ToList());

            while (unvisited_ops.Count > 0)
            {
                while (ops_to_visit.Count > 0)
                {
                    var op = ops_to_visit.Pop();
                    if (marked_ops.Contains(op))
                    {
                        continue;
                    }
                    marked_ops.Add(op);
                    ops_to_copy.append(op);
                    // NOTE(review): this loop body is empty, so inputs are never queued and only
                    // the ops of init_tensors get marked. Confirm whether the input traversal
                    // was left out intentionally.
                    foreach (var inp in op.inputs)
                    {
                    }
                }
                // Drop everything that has been marked from the unvisited set.
                unvisited_ops.difference_update(marked_ops);
                if (unvisited_ops.Count > 0)
                {
                    ops_to_visit.Push(unvisited_ops.Last());
                }
            }

            // When lifting from one FuncGraph to another, we will need to capture the
            // relevant tensors as well.
            var inverse_captures = new Dictionary <Tensor, Tensor>();

            // Stays empty when base_graph is not a FuncGraph, so the capture loop below
            // is a no-op instead of dereferencing null.
            Tensor[] internal_captures = new Tensor[0];
            if (base_graph is FuncGraph base_func_graph)
            {
                var captures = base_func_graph.captures;
                foreach (var(external_capture, internal_capture) in captures)
                {
                    inverse_captures[internal_capture] = external_capture;
                }
                internal_captures = base_func_graph.internal_captures;
            }

            graph.as_default();
            try
            {
                var source_ops = new List <Operation>();

                // Add the sources in the same order as the original graph.
                foreach (var s in internal_captures)
                {
                    if (sources.Contains(s))
                    {
                        sources.Remove(s);
                        source_ops.Add(s.op);
                        _copy_source(s: s,
                                     graph: graph,
                                     op_map: op_map,
                                     handle_captures: handle_captures,
                                     inverse_captures: inverse_captures,
                                     base_graph: base_graph);
                    }
                }

                foreach (var op in reversed(ops_to_copy))
                {
                    if (source_ops.Contains(op) || op_map.ContainsKey(op))
                    {
                        continue;
                    }
                    _copy_non_source(op, graph, op_map, base_graph);
                }
            }
            finally
            {
                // Exit the destination graph even when copying throws (for example the
                // NotImplementedException raised by _copy_source).
                graph.Exit();
            }

            return(op_map);
        }