Esempio n. 1
0
        /// <summary>
        /// Creates a map dataset over <paramref name="input_dataset"/>. The mapping is captured by
        /// tracing <paramref name="map_func"/> into a <see cref="ConcreteFunction"/> using one
        /// placeholder per component of <c>input_dataset.element_spec</c>. The dataset op is then
        /// built from that function.
        /// </summary>
        /// <param name="input_dataset">Source dataset; its element_spec drives the placeholder dtypes and shapes.</param>
        /// <param name="map_func">Function traced once, over placeholders, to build the mapping graph.</param>
        /// <param name="use_inter_op_parallelism">Forwarded to <c>ops.map_dataset</c>.</param>
        /// <param name="preserve_cardinality">Forwarded to <c>ops.map_dataset</c>.</param>
        /// <param name="use_legacy_function">Not used by this constructor.</param>
        public MapDataset(IDatasetV2 input_dataset,
                          Func <Tensors, Tensors> map_func,
                          bool use_inter_op_parallelism = true,
                          bool preserve_cardinality     = false,
                          bool use_legacy_function      = false) : base(input_dataset)
        {
            var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}");

            func.Enter();
            try
            {
                var inputs = new Tensors();
                foreach (var input in input_dataset.element_spec)
                {
                    inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
                }

                var outputs = map_func(inputs);
                func.ToGraph(inputs, outputs);
            }
            finally
            {
                // Always leave the tracing mode entered above. Otherwise an exception from
                // map_func (or ToGraph) would leave later code running in the wrong mode.
                func.Exit();
            }

            structure = func.OutputStructure;

            // output_types / output_shapes are presumably derived from `structure` by the base class.
            variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
                                             func,
                                             output_types,
                                             output_shapes,
                                             use_inter_op_parallelism: use_inter_op_parallelism,
                                             preserve_cardinality: preserve_cardinality);
        }
        /// <summary>
        /// Creates a filter dataset over <paramref name="input_dataset"/>. It traces a wrapper
        /// around <paramref name="predicate_func"/> into a <see cref="ConcreteFunction"/>, using one
        /// placeholder per component of <c>input_dataset.element_spec</c>. It then builds
        /// <c>ops.filter_dataset</c> from that function.
        /// </summary>
        /// <param name="input_dataset">Source dataset; its element_spec drives the placeholder dtypes and shapes.</param>
        /// <param name="predicate_func">Predicate invoked during tracing; its bool result is wrapped in a constant tensor.</param>
        public FilterDataset(IDatasetV2 input_dataset,
                             Func <Tensor, bool> predicate_func) : base(input_dataset)
        {
            // NOTE(review): predicate_func returns a C# bool. It runs once, on the placeholders,
            // while tracing, and its result is baked into a constant. The traced function therefore
            // appears to return the same value for every element — confirm this is intended.
            Func <Tensors, Tensors> predicate_func_update = x =>
            {
                var result = predicate_func(x);
                return(constant_op.constant(result));
            };

            var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");

            func.Enter();
            try
            {
                var inputs = new Tensors();
                foreach (var input in input_dataset.element_spec)
                {
                    inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
                }

                var outputs = predicate_func_update(inputs);
                func.ToGraph(inputs, outputs);
            }
            finally
            {
                // Always leave the tracing mode entered above, even if the predicate or ToGraph throws.
                func.Exit();
            }

            structure = func.OutputStructure;

            variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
                                                func,
                                                output_types,
                                                output_shapes);
        }
        /// <summary>
        /// Interception hook that runs before the decorated method body. It does one of two things:
        /// <list type="bullet">
        /// <item>If a function is registered under the computed name, it calls that function and
        /// short-circuits the method (<c>FlowBehavior.Return</c>).</item>
        /// <item>Otherwise it creates a new <see cref="ConcreteFunction"/>, enters it, and replaces
        /// tensor arguments with placeholders so the method body is traced. The original inputs are
        /// kept in <c>originalInputs</c>.</item>
        /// </list>
        /// </summary>
        /// <param name="args">Interception context of the decorated method call.</param>
        public override void OnEntry(MethodExecutionArgs args)
        {
            // TODO: func_name could be cached from FullName + Args.
            // NOTE(review): the name embeds ops.uid_function(), which presumably yields a fresh id per
            // call. If so, the cache lookup below never hits — confirm before relying on the cache.
            func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}_{ops.uid_function()}";

            // A decorated method with no parameters has no first argument. Reading Arguments[0]
            // unguarded would throw IndexOutOfRangeException, so treat it as "not Tensors".
            var first_arg = args.Arguments.Length > 0 ? args.Arguments[0] : null;

            if (functions.TryGetValue(func_name, out var cached))
            {
                function = cached;
                if (first_arg is Tensors tensor_inputs)
                {
                    args.ReturnValue = ConvertReturnValue(function.FilteredCall(tensor_inputs));
                }
                else
                {
                    // Arguments that are not Tensor become null through the `as` cast.
                    args.ReturnValue = ConvertReturnValue(function.FilteredCall(args.Arguments.Select(x => x as Tensor).ToArray()));
                }
                args.FlowBehavior = FlowBehavior.Return;
                return;
            }

            // Turn the method into a graph function by tracing it (autograph).
            // The mode entered here must be restored when the method exits; that restore is not done here.
            function = new ConcreteFunction(func_name);
            function.Enter();

            if (first_arg is Tensors inputs)
            {
                // Replace the Tensors argument with same-dtype/same-shape placeholders.
                originalInputs = inputs;
                var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: "inputs")).ToArray();
                args.Arguments[0] = new Tensors(new_inputs);
            }
            else
            {
                originalInputs = new Tensors();
                // Replace each EagerTensor argument with a placeholder; other arguments stay as-is.
                for (var i = 0; i < args.Arguments.Length; i++)
                {
                    if (args.Arguments[i] is EagerTensor tensor)
                    {
                        originalInputs.Add(tensor);
                        args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.shape, name: "inputs");
                    }
                }
            }
        }
Esempio n. 4
0
        /// <summary>
        /// Creates a map dataset over <paramref name="input_dataset"/>. It traces
        /// <paramref name="map_func"/> into a <see cref="ConcreteFunction"/> using a single
        /// placeholder built from the first entry of <c>input_dataset.element_spec</c>.
        /// </summary>
        /// <param name="input_dataset">Source dataset; only <c>element_spec[0]</c> is used for the placeholder.</param>
        /// <param name="map_func">Function traced once, over the placeholder, to build the mapping graph.</param>
        /// <param name="use_inter_op_parallelism">Forwarded to <c>ops.map_dataset</c>.</param>
        /// <param name="preserve_cardinality">Forwarded to <c>ops.map_dataset</c>.</param>
        /// <param name="use_legacy_function">Not used by this constructor.</param>
        public MapDataset(IDatasetV2 input_dataset,
                          Func <Tensor, Tensor> map_func,
                          bool use_inter_op_parallelism = true,
                          bool preserve_cardinality     = false,
                          bool use_legacy_function      = false) : base(input_dataset)
        {
            using var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}");

            func.Enter();
            try
            {
                // NOTE(review): only the first component of element_spec is traced.
                var spec = input_dataset.element_spec[0];

                // Pass the spec's static shape, not just its dtype. Otherwise the shape is dropped,
                // and the traced function's output shapes cannot be inferred from the input.
                var input  = tf.placeholder(spec.dtype, shape: spec.shape);
                var output = map_func(input);

                func.ToGraph(input, output);
            }
            finally
            {
                // Always leave the tracing mode entered above, even if map_func or ToGraph throws.
                func.Exit();
            }

            structure = func.OutputStructure;

            variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
                                             func,
                                             output_types,
                                             output_shapes,
                                             use_inter_op_parallelism: use_inter_op_parallelism,
                                             preserve_cardinality: preserve_cardinality);
        }