/// <summary>
/// Traces <paramref name="func"/> once into a new, uniquely named <see cref="FuncGraph"/> and
/// returns a delegate that executes that name via <c>tf.Runner.TFE_Execute</c>.
/// </summary>
/// <param name="func">
/// Function to trace. It is invoked exactly once, with an int32 placeholder as its argument.
/// </param>
/// <returns>
/// A delegate that runs the op named after the traced graph (presumably the function that
/// <c>ToGraph</c> registers — confirm) on its argument, and returns the first output.
/// </returns>
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func)
{
    // The GUID suffix keeps the name unique when the same method is traced more than once.
    string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";
    var graph = new FuncGraph(func_name);
    graph.as_default();
    try
    {
        // NOTE(review): the traced input dtype is hard-coded to int32.
        var input = tf.placeholder(tf.int32);
        var output = func(input);
        // NOTE(review): `as` yields null for any node that is not an Operation; presumably every
        // node here is an Operation — confirm.
        var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
        graph.ToGraph(opers, new[] { input }, new[] { output }, null);
    }
    finally
    {
        // Always leave the graph scope, even if tracing throws, so this graph does not
        // remain the default graph for unrelated code.
        graph.Exit();
    }

    return (Tensor tensor) =>
    {
        var result = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, func_name, new[] { tensor }, null, 1);
        return result[0];
    };
}
/// <summary>
/// Copies the ops that produce <paramref name="init_tensors"/> into <paramref name="graph"/>.
/// Captured sources are copied first, in the order they were captured in <paramref name="base_graph"/>.
/// </summary>
/// <param name="init_tensors">
/// Tensors whose producing ops are lifted. Must be non-empty when <paramref name="base_graph"/>
/// is null, because the source graph is then taken from the first tensor.
/// </param>
/// <param name="graph">Destination graph; it is the default graph while ops are being copied.</param>
/// <param name="sources">
/// Tensors treated as inputs of the lifted subgraph. The list is copied, so the caller's instance
/// is not modified; <c>null</c> is treated as an empty list.
/// </param>
/// <param name="add_sources">Forwarded to <c>map_subgraph</c>.</param>
/// <param name="handle_captures">Forwarded to <c>_copy_source</c>.</param>
/// <param name="base_graph">
/// Graph the tensors currently live in. Defaults to the graph of the first element of
/// <paramref name="init_tensors"/>.
/// </param>
/// <param name="op_map">
/// Mapping to extend with the copied ops. A new dictionary is created when null.
/// </param>
/// <returns>The op map, as populated by <c>_copy_source</c> and <c>_copy_non_source</c>.</returns>
public static Dictionary<ITensorOrOperation, Operation> lift_to_graph(Tensors init_tensors,
    FuncGraph graph,
    List<Tensor> sources,
    bool add_sources = false,
    bool handle_captures = false,
    Graph base_graph = null,
    Dictionary<ITensorOrOperation, Operation> op_map = null)
{
    base_graph = base_graph ?? init_tensors[0].graph;
    op_map = op_map ?? new Dictionary<ITensorOrOperation, Operation>();

    // Work on a private copy: the code below both appends to and removes from this list,
    // and those changes must not leak back into the caller's collection.
    sources = sources == null ? new List<Tensor>() : new List<Tensor>(sources);

    var visited_ops = sources.Select(x => x.op).ToList();
    foreach (var init_tensor in init_tensors)
    {
        var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
        sources.AddRange(src);
    }

    var ops_to_copy = new List<Operation>();
    var marked_ops = new List<Operation>();
    var ops_to_visit = new Stack<Operation>(init_tensors.Select(x => x.op));
    var unvisited_ops = new List<Operation>(ops_to_visit.ToList());
    while (unvisited_ops.Count > 0)
    {
        while (ops_to_visit.Count > 0)
        {
            var op = ops_to_visit.Pop();
            if (marked_ops.Contains(op))
                continue;
            marked_ops.Add(op);
            ops_to_copy.Add(op);
            // NOTE(review): the inputs of `op` are not traversed here (the loop over op.inputs was
            // empty), so only the ops of `init_tensors` end up in ops_to_copy. The upstream Python
            // lift_to_graph pushes each input op here; porting that needs per-op output
            // bookkeeping that map_subgraph does not expose — TODO.
        }

        // difference_update
        unvisited_ops.difference_update(marked_ops);
        if (unvisited_ops.Count > 0)
            ops_to_visit.Push(unvisited_ops.Last());
    }

    // When lifting from one FuncGraph to another, we will need to capture the
    // relevant tensors as well.
    var inverse_captures = new Dictionary<Tensor, Tensor>();
    // Empty unless base_graph is a FuncGraph. It must not be null, because it is
    // enumerated unconditionally below.
    Tensor[] internal_captures = new Tensor[0];
    if (base_graph is FuncGraph base_func_graph)
    {
        var captures = base_func_graph.captures;
        foreach (var (external_capture, internal_capture) in captures)
            inverse_captures[internal_capture] = external_capture;
        internal_captures = base_func_graph.internal_captures ?? new Tensor[0];
    }

    graph.as_default();
    try
    {
        var source_ops = new List<Operation>();
        // Add the sources in the same order as the original graph.
        foreach (var s in internal_captures)
        {
            // Remove returns false when `s` is not a source, so no separate Contains check is needed.
            if (!sources.Remove(s))
                continue;
            source_ops.Add(s.op);
            _copy_source(s: s,
                graph: graph,
                op_map: op_map,
                handle_captures: handle_captures,
                inverse_captures: inverse_captures,
                base_graph: base_graph);
        }
        // NOTE(review): sources that are not internal captures are not copied here; the upstream
        // Python implementation copies them too — confirm whether that step is needed.

        foreach (var op in reversed(ops_to_copy))
        {
            if (source_ops.Contains(op) || op_map.ContainsKey(op))
                continue;
            _copy_non_source(op, graph, op_map, base_graph);
        }
    }
    finally
    {
        // Leave the destination graph scope even if copying throws.
        graph.Exit();
    }

    return op_map;
}