/// <summary>
/// Runs before the decorated method body.
/// If a <c>ConcreteFunction</c> is already cached under the computed name, it is invoked directly
/// on the current arguments. Its converted result becomes the return value, and
/// <c>FlowBehavior.Return</c> is set so the original body is not executed.
/// Otherwise a new <c>ConcreteFunction</c> is created and entered. Every tensor argument is swapped
/// for a placeholder named "inputs" with the same dtype and shape, and the real tensors are kept in
/// <c>originalInputs</c> so they can be fed to the traced function after the body runs.
/// </summary>
/// <param name="args">Interception context for the decorated call; its arguments may be replaced.</param>
public override void OnEntry(MethodExecutionArgs args)
{
    // TODO: func_name can be cached from FullName + Args.
    // NOTE(review): if ops.uid_function() yields a new id on every call, this key never repeats
    // and the cache lookup below can never hit — confirm.
    func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}";

    // Single lookup instead of ContainsKey + indexer.
    if (functions.TryGetValue(func_name, out var cached))
    {
        function = cached;

        Tensors call_inputs;
        // Guard the index: a parameterless method has no Arguments[0].
        if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors tensor_inputs)
        {
            call_inputs = tensor_inputs;
        }
        else
        {
            // Collect the eager tensor arguments exactly as the tracing path below does, so the
            // call lines up with the traced placeholders. Casting every argument with `as Tensor`
            // would inject nulls for non-tensor arguments.
            call_inputs = new Tensors();
            foreach (var argument in args.Arguments)
            {
                if (argument is EagerTensor eager)
                {
                    call_inputs.Add(eager);
                }
            }
        }

        args.ReturnValue = ConvertReturnValue(function.FilteredCall(call_inputs));
        args.FlowBehavior = FlowBehavior.Return;
        return;
    }

    // Make the function an Operation via autograph. The mode switched by Enter() must be
    // restored when the method exits (function.Exit() in OnExit).
    function = new ConcreteFunction(func_name);
    function.Enter();

    // Replace the real tensors with placeholders so the method body traces a graph.
    if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors inputs)
    {
        originalInputs = inputs;
        var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: "inputs")).ToArray();
        args.Arguments[0] = new Tensors(new_inputs);
    }
    else
    {
        originalInputs = new Tensors();
        for (var i = 0; i < args.Arguments.Length; i++)
        {
            if (args.Arguments[i] is EagerTensor tensor)
            {
                originalInputs.Add(tensor);
                args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.shape, name: "inputs");
            }
        }
    }
}
/// <summary>
/// Runs after the decorated method body has executed in tracing mode.
/// If the result is <c>Tensors</c>, the outputs go through <c>mark_as_return</c>. If it is a single
/// <c>Tensor</c>, the output is wrapped with <c>array_ops.identity</c>. In both cases the graph is
/// recorded via <c>function.ToGraph</c>, with the "inputs" placeholders as inputs.
/// The mode switched by <c>function.Enter()</c> is then restored, even if tracing throws.
/// The function is cached under <c>func_name</c>, and the return value is replaced with the result
/// of calling the traced function on the real inputs saved in <c>originalInputs</c>.
/// </summary>
/// <param name="args">Interception context for the decorated call; its return value is replaced.</param>
public override void OnExit(MethodExecutionArgs args)
{
    try
    {
        if (args.ReturnValue is Tensors outputs)
        {
            outputs = mark_as_return(outputs);

            Tensors inputs;
            // Guard the index: a parameterless method has no Arguments[0].
            if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors tensors_argument)
            {
                inputs = tensors_argument;
            }
            else
            {
                // OfType drops non-tensor arguments. `x as Tensor` kept them as nulls, which then
                // threw a NullReferenceException inside the placeholder filter.
                inputs = args.Arguments.OfType<Tensor>().ToArray();
            }

            inputs = SelectTracedInputs(inputs);
            function.ToGraph(inputs, outputs);
        }
        else if (args.ReturnValue is Tensor output)
        {
            var inputs = SelectTracedInputs(args.Arguments.OfType<Tensor>());
            var outputs2 = array_ops.identity(output);
            function.ToGraph(inputs, outputs2);
        }
    }
    finally
    {
        // Always restore the mode entered in OnEntry, even if graph construction fails.
        function.Exit();
    }

    // Cache the traced function for later calls.
    // NOTE(review): a null ReturnValue still throws NullReferenceException here (unchanged). A
    // return value that is neither Tensors nor Tensor records no graph, yet the function is still
    // cached and invoked below — confirm this is intended.
    function.ReturnType = args.ReturnValue.GetType();
    functions[func_name] = function;

    // Run the traced function on the real (pre-placeholder) inputs captured in OnEntry.
    args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs));
}

/// <summary>
/// Keeps only tensors whose op type is "Placeholder" and whose op name starts with "inputs".
/// This matches the placeholder name OnEntry gives to substituted arguments. Order is preserved.
/// </summary>
/// <param name="candidates">Tensors to filter; must not contain nulls.</param>
/// <returns>The matching tensors, in their original order.</returns>
private static Tensor[] SelectTracedInputs(System.Collections.Generic.IEnumerable<Tensor> candidates)
{
    return candidates
        .Where(t => t.op.OpType == "Placeholder"
                    && t.op.name.StartsWith("inputs", System.StringComparison.Ordinal))
        .ToArray();
}