/// <summary>
/// Invoked when the intercepted method returns. Records the traced inputs and
/// outputs on <c>function</c>, then caches the function under <c>func_name</c>
/// along with the runtime type of the traced return value. Finally, it replaces
/// the method's return value with <c>ConvertReturnValue</c> applied to
/// <c>function.Invoke(originalInputs)</c>.
/// </summary>
/// <param name="args">Interception arguments carrying the method's arguments and return value.</param>
public override void OnExit(MethodExecutionArgs args)
{
    // Every method argument reinterpreted as a Tensor; non-Tensor arguments become null.
    Tensor[] ArgumentsAsTensors() => args.Arguments.Select(a => a as Tensor).ToArray();

    switch (args.ReturnValue)
    {
        case Tensors traced when args.Arguments[0] is Tensors tracedInputs:
            // Multi-output method whose first argument already is a Tensors bundle.
            function.ToGraph(tracedInputs, traced);
            break;
        case Tensors traced:
            function.ToGraph(ArgumentsAsTensors(), traced);
            break;
        default:
            // Single-Tensor return (anything else, including null, is passed as null via `as`).
            function.ToGraph(ArgumentsAsTensors(), args.ReturnValue as Tensor);
            break;
    }

    // Cache the traced function for later calls.
    function.ReturnType = args.ReturnValue.GetType();
    functions[func_name] = function;

    // Run the cached function on originalInputs and hand the result back to the caller.
    var result = function.Invoke(originalInputs);
    args.ReturnValue = ConvertReturnValue(result);
}
/// <summary>
/// Creates a dataset that applies <paramref name="map_func"/> to the elements of
/// <paramref name="input_dataset"/>. The function is traced once into a uniquely
/// named <c>ConcreteFunction</c>, using one placeholder per component of the
/// input's <c>element_spec</c>.
/// </summary>
/// <param name="input_dataset">Source dataset.</param>
/// <param name="map_func">Graph-building function applied to the traced element tensors.</param>
/// <param name="use_inter_op_parallelism">Forwarded to the map_dataset op.</param>
/// <param name="preserve_cardinality">Forwarded to the map_dataset op.</param>
/// <param name="use_legacy_function">Accepted for API compatibility; not referenced by this constructor.</param>
public MapDataset(IDatasetV2 input_dataset,
    Func<Tensors, Tensors> map_func,
    bool use_inter_op_parallelism = true,
    bool preserve_cardinality = false,
    bool use_legacy_function = false) : base(input_dataset)
{
    // The uid suffix keeps names distinct when the same method is mapped more than once.
    var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}");
    func.Enter();
    try
    {
        // One placeholder per element component, matching its dtype and shape.
        var inputs = new Tensors();
        foreach (var input in input_dataset.element_spec)
        {
            inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
        }

        var outputs = map_func(inputs);
        func.ToGraph(inputs, outputs);
    }
    finally
    {
        // Always leave the function's graph context, even if map_func or tracing throws;
        // otherwise later graph construction would keep targeting this function.
        func.Exit();
    }

    structure = func.OutputStructure;

    variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
        func,
        output_types,
        output_shapes,
        use_inter_op_parallelism: use_inter_op_parallelism,
        preserve_cardinality: preserve_cardinality);
}
/// <summary>
/// Creates a dataset that filters <paramref name="input_dataset"/> using
/// <paramref name="predicate_func"/>, traced into a uniquely named <c>ConcreteFunction</c>.
/// </summary>
/// <remarks>
/// <paramref name="predicate_func"/> is called once, at trace time, on the placeholder
/// inputs. Its <c>bool</c> result is embedded in the graph via <c>constant_op.constant</c>,
/// so the filter does not re-evaluate the C# predicate per element.
/// </remarks>
/// <param name="input_dataset">Source dataset.</param>
/// <param name="predicate_func">Predicate evaluated during tracing.</param>
public FilterDataset(IDatasetV2 input_dataset,
    Func<Tensor, bool> predicate_func) : base(input_dataset)
{
    // Adapt the bool predicate to the Tensors -> Tensors shape that tracing expects.
    // NOTE(review): passing Tensors to a Func<Tensor, bool> relies on an implicit
    // Tensors -> Tensor conversion defined elsewhere.
    Func<Tensors, Tensors> predicate_func_update = x =>
    {
        var result = predicate_func(x);
        return constant_op.constant(result);
    };

    var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");
    func.Enter();
    try
    {
        // One placeholder per element component, matching its dtype and shape.
        var inputs = new Tensors();
        foreach (var input in input_dataset.element_spec)
        {
            inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
        }

        var outputs = predicate_func_update(inputs);
        func.ToGraph(inputs, outputs);
    }
    finally
    {
        // Always leave the function's graph context, even if the predicate or tracing throws.
        func.Exit();
    }

    structure = func.OutputStructure;

    variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
        func,
        output_types,
        output_shapes);
}
/// <summary>
/// Invoked when the intercepted method returns.
/// <list type="bullet">
/// <item>Selects the traced graph inputs: placeholders whose op name starts with
/// <c>"inputs"</c>.</item>
/// <item>Registers the inputs and the returned output(s) on <c>function</c>, then
/// exits the function graph.</item>
/// <item>Caches the function under <c>func_name</c>.</item>
/// <item>Replaces the return value with <c>ConvertReturnValue</c> applied to
/// <c>function.FilteredCall(originalInputs)</c>.</item>
/// </list>
/// </summary>
/// <param name="args">
/// Interception arguments. <c>ReturnValue</c> is handled when it is a <c>Tensors</c> or a
/// <c>Tensor</c>; any other value reaches <c>GetType()</c> below without a graph being built.
/// </param>
public override void OnExit(MethodExecutionArgs args)
{
    // Graph inputs are the "inputs*" placeholders. Non-Tensor arguments (null after an
    // `as Tensor` cast) are skipped instead of dereferenced. The prefix test is ordinal
    // because op names are identifiers, not user-facing text.
    bool IsGraphInput(Tensor x) =>
        !(x is null)
        && x.op.OpType == "Placeholder"
        && x.op.name.StartsWith("inputs", System.StringComparison.Ordinal);

    if (args.ReturnValue is Tensors outputs)
    {
        Tensors inputs = null;
        outputs = mark_as_return(outputs);
        if (args.Arguments[0] is Tensors inputs1)
        {
            inputs = inputs1;
        }
        else
        {
            inputs = args.Arguments.Select(x => x as Tensor).ToArray();
        }
        inputs = inputs.Where(IsGraphInput).ToArray();
        function.ToGraph(inputs, outputs);
    }
    else if (args.ReturnValue is Tensor output)
    {
        var inputs = args.Arguments.Select(x => x as Tensor)
            .Where(IsGraphInput)
            .ToArray();
        // Route the single output through identity so the function has its own output op.
        var outputs2 = array_ops.identity(output);
        function.ToGraph(inputs, outputs2);
    }

    function.Exit();

    // Cache the traced function for later calls.
    function.ReturnType = args.ReturnValue.GetType();
    functions[func_name] = function;

    // Run the cached function on originalInputs and hand the result back to the caller.
    args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs));
}
/// <summary>
/// Creates a dataset that applies <paramref name="map_func"/> to
/// <paramref name="input_dataset"/>. The function is traced on a single placeholder
/// whose dtype matches the first component of the input's <c>element_spec</c>.
/// </summary>
/// <param name="input_dataset">Source dataset.</param>
/// <param name="map_func">Graph-building function applied to the traced placeholder.</param>
/// <param name="use_inter_op_parallelism">Forwarded to the map_dataset op.</param>
/// <param name="preserve_cardinality">Forwarded to the map_dataset op.</param>
/// <param name="use_legacy_function">Accepted for API compatibility; not referenced by this constructor.</param>
public MapDataset(IDatasetV2 input_dataset,
    Func<Tensor, Tensor> map_func,
    bool use_inter_op_parallelism = true,
    bool preserve_cardinality = false,
    bool use_legacy_function = false) : base(input_dataset)
{
    // NOTE(review): `func` is disposed when this constructor returns, i.e. after the
    // map_dataset op has been created from it. Confirm that the op does not need the
    // function to remain registered afterwards.
    using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");

    // Trace map_func on an unshaped placeholder typed like the first element component.
    var elementDtype = input_dataset.element_spec[0].dtype;
    var placeholder = tf.placeholder(elementDtype);
    func.ToGraph(placeholder, map_func(placeholder));

    structure = func.OutputStructure;

    variant_tensor = ops.map_dataset(
        input_dataset.variant_tensor,
        func,
        output_types,
        output_shapes,
        use_inter_op_parallelism: use_inter_op_parallelism,
        preserve_cardinality: preserve_cardinality);
}