public Func <Tensor, Tensor> to_graph(Func <Tensor, Tensor> func) { string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; var graph = new FuncGraph(func_name); graph.as_default(); var input = tf.placeholder(tf.int32); var output = func(input); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); graph.ToGraph(opers, new[] { input }, new[] { output }, null); graph.Exit(); return((Tensor input) => { var result = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, func_name, new[] { input }, null, 1); return result[0]; }); }
public Func <Tensor, Tensor, Tensor> to_graph(Func <Tensor, Tensor, Tensor> func) { string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; // IntPtr func_handle; using (var graph = new FuncGraph(func_name)) { var input1 = tf.placeholder(tf.int32); var input2 = tf.placeholder(tf.int32); var output = func(input1, input2); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var func_handle = graph.ToGraph(opers, new Operation[] { input1, input2 }, new Operation[] { output }, null); } return((Tensor a, Tensor b) => { var result = tf.Runner.TFE_Execute(tf.Context, tf.Context.DeviceName, func_name, new[] { a, b }, null, 1); return result[0]; }); }
static void _copy_non_source(Operation op, FuncGraph graph, Dictionary <ITensorOrOperation, Operation> op_map, Graph base_graph) { Operation copied_op = null; var copied_inputs = new Tensors(); tf_with(ops.control_dependencies(new object[] { op }), delegate { // Create a new op in the destination graph if it doesn't exist before. var attrs = new Dictionary <string, AttrValue>(); foreach (var attr_def in op.node_def.Attr) { attrs[attr_def.Key] = attr_def.Value; } copied_op = graph.create_op(op.type, copied_inputs, dtypes: op.outputs.Select(x => x.dtype).ToArray(), attrs: attrs, name: op.name); }); op_map[op] = copied_op; foreach (var(i, o) in enumerate(op.outputs)) { op_map[o] = copied_op.outputs[i]; } }
public override void OnEntry(MethodExecutionArgs args) { func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}"; if (functions.ContainsKey(func_name)) { if (args.Arguments[0] is Tensors tensor_inputs) { args.ReturnValue = functions[func_name](tensor_inputs.ToArray()); } else { args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray()); } args.FlowBehavior = FlowBehavior.Return; return; } // make function as an Operation by autograph graph = new FuncGraph(func_name); // convert to Tensors if (args.Arguments[0] is Tensors inputs) { originalInputs = inputs; var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray(); args.Arguments[0] = new Tensors(new_inputs); } else { originalInputs = new Tensors(args.Arguments.Length); // convert args to placeholder for (var i = 0; i < args.Arguments.Length; i++) { if (args.Arguments[i] is EagerTensor tensor) { originalInputs[i] = tensor; args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape); } } } }
static void _copy_source(Tensor s, FuncGraph graph, Dictionary <ITensorOrOperation, Operation> op_map, bool handle_captures, Dictionary <Tensor, Tensor> inverse_captures, Graph base_graph) { Tensor copied_placeholder = null; if (handle_captures && inverse_captures.ContainsKey(s)) { copied_placeholder = graph.capture(inverse_captures[s], name: s.op.name); } else { throw new NotImplementedException(""); } op_map[s] = copied_placeholder; // Add an entry for the op of the source tensor so that if there are any nodes // depending on that op via control dependencies it can work correctly. op_map[s.op] = copied_placeholder.op; }
/// <summary> /// Copies the tensor and all its inputs recursively to the outer graph. /// </summary> /// <param name="tensors"></param> /// <param name="graph"></param> /// <param name="add_sources"></param> /// <param name="handle_captures"></param> /// <param name="base_graph"></param> /// <returns></returns> public static Dictionary <ITensorOrOperation, Operation> lift_to_graph(Tensors init_tensors, FuncGraph graph, List <Tensor> sources, bool add_sources = false, bool handle_captures = false, Graph base_graph = null, Dictionary <ITensorOrOperation, Operation> op_map = null) { base_graph = base_graph ?? init_tensors[0].graph; op_map = op_map ?? new Dictionary <ITensorOrOperation, Operation>(); var visited_ops = sources.Select(x => x.op).ToList(); foreach (var init_tensor in init_tensors) { var src = map_subgraph(init_tensor, sources, visited_ops, add_sources); sources.AddRange(src); } var ops_to_copy = new List <Operation>(); var marked_ops = new List <Operation>(); var ops_to_visit = new Stack <Operation>(init_tensors.Select(x => x.op)); var unvisited_ops = new List <Operation>(ops_to_visit.ToList()); while (unvisited_ops.Count > 0) { while (ops_to_visit.Count > 0) { var op = ops_to_visit.Pop(); if (marked_ops.Contains(op)) { continue; } marked_ops.Add(op); ops_to_copy.append(op); foreach (var inp in op.inputs) { } } // difference_update unvisited_ops.difference_update(marked_ops); if (unvisited_ops.Count > 0) { ops_to_visit.Push(unvisited_ops.Last()); } } // When lifting from one FuncGraph to another, we will need to capture the // relevant tensors as well. var inverse_captures = new Dictionary <Tensor, Tensor>(); Tensor[] internal_captures = null; if (base_graph is FuncGraph base_func_graph) { var captures = base_func_graph.captures; foreach (var(external_capture, internal_capture) in captures) { inverse_captures[internal_capture] = external_capture; } internal_captures = base_func_graph.internal_captures; } graph.as_default(); var source_ops = new List <Operation>(); // Add the sources in the same order as the original graph. foreach (var s in internal_captures) { if (sources.Contains(s)) { sources.Remove(s); source_ops.Add(s.op); _copy_source(s: s, graph: graph, op_map: op_map, handle_captures: handle_captures, inverse_captures: inverse_captures, base_graph: base_graph); } } foreach (var op in reversed(ops_to_copy)) { if (source_ops.Contains(op) || op_map.ContainsKey(op)) { continue; } _copy_non_source(op, graph, op_map, base_graph); } graph.Exit(); return(op_map); }