/// <summary>
/// Creates a dataset that applies <paramref name="map_func"/> to each element of <paramref name="input_dataset"/>.
/// The delegate is traced once into a <see cref="ConcreteFunction"/> against placeholders built from the
/// input dataset's element_spec, and the traced function is passed to the map_dataset op.
/// </summary>
/// <param name="input_dataset">Source dataset; one placeholder (dtype + shape) is created per element_spec entry.</param>
/// <param name="map_func">Transformation traced in graph mode; it receives placeholders, not concrete values.</param>
/// <param name="use_inter_op_parallelism">Forwarded to the map_dataset op.</param>
/// <param name="preserve_cardinality">Forwarded to the map_dataset op.</param>
/// <param name="use_legacy_function">Not used by this constructor.</param>
public MapDataset(IDatasetV2 input_dataset,
    Func<Tensors, Tensors> map_func,
    bool use_inter_op_parallelism = true,
    bool preserve_cardinality = false,
    bool use_legacy_function = false) : base(input_dataset)
{
    // A uid suffix keeps function names distinct when the same delegate method is mapped more than once.
    var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}");
    func.Enter();
    try
    {
        var inputs = new Tensors();
        foreach (var input in input_dataset.element_spec)
        {
            inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
        }

        var outputs = map_func(inputs);
        func.ToGraph(inputs, outputs);
    }
    finally
    {
        // Enter/Exit must stay paired: without this, an exception thrown by the user's map_func
        // would skip Exit and leave the state established by Enter in place.
        func.Exit();
    }

    structure = func.OutputStructure;

    variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
        func,
        output_types,
        output_shapes,
        use_inter_op_parallelism: use_inter_op_parallelism,
        preserve_cardinality: preserve_cardinality);
}
/// <summary>
/// Creates a dataset that keeps only the elements of <paramref name="input_dataset"/> accepted by
/// <paramref name="predicate_func"/>. The predicate is traced into a <see cref="ConcreteFunction"/>
/// against placeholders built from the input dataset's element_spec and passed to the filter_dataset op.
/// </summary>
/// <param name="input_dataset">Source dataset; one placeholder (dtype + shape) is created per element_spec entry.</param>
/// <param name="predicate_func">Predicate invoked during tracing with placeholder tensors.</param>
public FilterDataset(IDatasetV2 input_dataset,
    Func<Tensor, bool> predicate_func) : base(input_dataset)
{
    // Adapt the bool-returning predicate to the Tensors -> Tensors shape that tracing expects.
    // NOTE(review): the predicate runs once here, at trace time, on placeholders, and its C# bool
    // result is baked in as a constant — so the traced filter appears to return the same value for
    // every element. Confirm this is intended.
    Func<Tensors, Tensors> predicate_func_update = x =>
    {
        var result = predicate_func(x);
        return constant_op.constant(result);
    };

    var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");
    func.Enter();
    try
    {
        var inputs = new Tensors();
        foreach (var input in input_dataset.element_spec)
        {
            inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
        }

        var outputs = predicate_func_update(inputs);
        func.ToGraph(inputs, outputs);
    }
    finally
    {
        // Keep Enter/Exit paired even if the user predicate or ToGraph throws.
        func.Exit();
    }

    structure = func.OutputStructure;

    variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
        func,
        output_types,
        output_shapes);
}
/// <summary>
/// Method-boundary entry hook. If a traced <see cref="ConcreteFunction"/> is already registered under
/// the computed name, it is invoked with the call's arguments and the method returns early with that
/// result. Otherwise a new function is created and entered, and tensor arguments are replaced with
/// placeholders so the method body runs as a trace.
/// </summary>
/// <param name="args">Interception context for the call (target method, arguments, return value, flow control).</param>
public override void OnEntry(MethodExecutionArgs args)
{
    // TODO: func_name can be cache in FullName + Args
    // NOTE(review): if ops.uid_function() yields a fresh id per call (as its name suggests), this key
    // is never repeated and the cache lookup below can never hit — confirm intended behavior.
    func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}";

    // Parameterless methods have an empty Arguments array; guard before touching Arguments[0].
    var hasArguments = args.Arguments.Length > 0;

    if (functions.TryGetValue(func_name, out var cached))
    {
        function = cached;
        if (hasArguments && args.Arguments[0] is Tensors tensor_inputs)
        {
            args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs));
        }
        else
        {
            // NOTE(review): arguments that are not Tensor become null through `as Tensor`.
            args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray()));
        }

        // Return the cached function's result instead of running the method body.
        args.FlowBehavior = FlowBehavior.Return;
        return;
    }

    // Make the function an Operation via autograph.
    // The mode set by Enter must be restored when the method exits (presumably in OnExit, not shown here).
    function = new ConcreteFunction(func_name);
    function.Enter();

    // Convert the inputs to graph placeholders.
    if (hasArguments && args.Arguments[0] is Tensors inputs)
    {
        originalInputs = inputs;
        var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: "inputs")).ToArray();
        args.Arguments[0] = new Tensors(new_inputs);
    }
    else
    {
        originalInputs = new Tensors();
        // Replace each EagerTensor argument with a placeholder, remembering the original value.
        for (var i = 0; i < args.Arguments.Length; i++)
        {
            if (args.Arguments[i] is EagerTensor tensor)
            {
                originalInputs.Add(tensor);
                args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.shape, name: "inputs");
            }
        }
    }
}
/// <summary>
/// Creates a dataset that applies a single-tensor <paramref name="map_func"/> to each element of
/// <paramref name="input_dataset"/>. The delegate is traced once into a <see cref="ConcreteFunction"/>
/// against a placeholder built from the first element_spec entry and passed to the map_dataset op.
/// </summary>
/// <param name="input_dataset">Source dataset; only element_spec[0] is used to build the placeholder.</param>
/// <param name="map_func">Transformation traced in graph mode; it receives a placeholder, not a concrete value.</param>
/// <param name="use_inter_op_parallelism">Forwarded to the map_dataset op.</param>
/// <param name="preserve_cardinality">Forwarded to the map_dataset op.</param>
/// <param name="use_legacy_function">Not used by this constructor.</param>
public MapDataset(IDatasetV2 input_dataset,
    Func<Tensor, Tensor> map_func,
    bool use_inter_op_parallelism = true,
    bool preserve_cardinality = false,
    bool use_legacy_function = false) : base(input_dataset)
{
    // NOTE(review): `using` disposes func when this constructor returns, even though the map_dataset
    // op created below references it. Confirm that disposing here is safe before the dataset is iterated.
    using var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}");
    func.Enter();
    try
    {
        // The delegate takes a single Tensor, so only the first element_spec entry is used.
        var spec = input_dataset.element_spec[0];
        // Carry the spec's static shape (not just its dtype) so traced shapes are not needlessly unknown.
        var input = tf.placeholder(spec.dtype, shape: spec.shape, name: "arg");
        var output = map_func(input);
        func.ToGraph(input, output);
    }
    finally
    {
        // Keep Enter/Exit paired even if the user's map_func or ToGraph throws.
        func.Exit();
    }

    structure = func.OutputStructure;

    variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
        func,
        output_types,
        output_shapes,
        use_inter_op_parallelism: use_inter_op_parallelism,
        preserve_cardinality: preserve_cardinality);
}