public override void OnExit(MethodExecutionArgs args)
        {
            // Trace the decorated method's inputs/outputs into `function`:
            //  - returns Tensors and the first argument is Tensors -> that argument is the input list;
            //  - returns Tensors otherwise                         -> each argument cast with `as Tensor`
            //                                                         (non-Tensor arguments become null);
            //  - any other return value                            -> each argument cast with `as Tensor`,
            //                                                         return value cast with `as Tensor`.
            if (args.ReturnValue is Tensors outputs)
            {
                // Check the count first: a parameterless method would otherwise throw
                // IndexOutOfRangeException on Arguments[0] instead of tracing with no inputs.
                if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors inputs)
                {
                    function.ToGraph(inputs, outputs);
                }
                else
                {
                    function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs);
                }
            }
            else
            {
                function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor);
            }

            // Record the runtime return type and store the traced function in the `functions` cache under func_name.
            // NOTE(review): GetType() throws NullReferenceException if the method returned null — confirm that cannot happen.
            function.ReturnType  = args.ReturnValue.GetType();
            functions[func_name] = function;

            // Invoke the traced function on `originalInputs` (presumably captured in OnEntry — confirm)
            // and replace the method's return value with the converted result.
            args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs));
        }
// Example no. 2
// 0
        /// <summary>
        /// Creates a dataset that applies <paramref name="map_func"/> to the elements of
        /// <paramref name="input_dataset"/>.
        /// </summary>
        /// <param name="input_dataset">Source dataset. One placeholder (named "arg") is created per entry of its <c>element_spec</c>, using that entry's dtype and shape.</param>
        /// <param name="map_func">Transformation that is called once on those placeholders and traced into a <see cref="ConcreteFunction"/>.</param>
        /// <param name="use_inter_op_parallelism">Forwarded to <c>ops.map_dataset</c>.</param>
        /// <param name="preserve_cardinality">Forwarded to <c>ops.map_dataset</c>.</param>
        /// <param name="use_legacy_function">Not used by this constructor.</param>
        public MapDataset(IDatasetV2 input_dataset,
                          Func <Tensors, Tensors> map_func,
                          bool use_inter_op_parallelism = true,
                          bool preserve_cardinality     = false,
                          bool use_legacy_function      = false) : base(input_dataset)
        {
            // The delegate's method name gets an id from ops.uid_function() appended (presumably unique per call).
            var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}");

            func.Enter();
            try
            {
                var inputs = new Tensors();
                foreach (var input in input_dataset.element_spec)
                {
                    inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
                }

                var outputs = map_func(inputs);
                func.ToGraph(inputs, outputs);
            }
            finally
            {
                // Keep Enter()/Exit() paired even if map_func or ToGraph throws, so the context
                // entered above is not left active for unrelated code.
                func.Exit();
            }

            structure = func.OutputStructure;

            // NOTE(review): output_types / output_shapes are read after `structure` is assigned;
            // presumably they are derived from it on the base class — confirm.
            variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
                                             func,
                                             output_types,
                                             output_shapes,
                                             use_inter_op_parallelism: use_inter_op_parallelism,
                                             preserve_cardinality: preserve_cardinality);
        }
        /// <summary>
        /// Creates a dataset that filters the elements of <paramref name="input_dataset"/> using
        /// <paramref name="predicate_func"/>.
        /// </summary>
        /// <param name="input_dataset">Source dataset. One placeholder (named "arg") is created per entry of its <c>element_spec</c>, using that entry's dtype and shape.</param>
        /// <param name="predicate_func">Predicate adapted to a Tensors-to-Tensors function and traced into a <see cref="ConcreteFunction"/>.</param>
        /// <remarks>
        /// NOTE(review): <paramref name="predicate_func"/> returns a C# <c>bool</c>. It is called once, during
        /// construction, on symbolic placeholders, and its result is wrapped with <c>constant_op.constant</c>.
        /// The traced predicate is therefore a constant and cannot vary per element. Per-element filtering
        /// would require a predicate that returns a boolean <see cref="Tensor"/>.
        /// </remarks>
        public FilterDataset(IDatasetV2 input_dataset,
                             Func <Tensor, bool> predicate_func) : base(input_dataset)
        {
            // Adapt the bool-returning predicate to the Tensors -> Tensors form that is traced below.
            Func <Tensors, Tensors> predicate_func_update = x =>
            {
                var result = predicate_func(x);
                return(constant_op.constant(result));
            };

            var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");

            func.Enter();
            try
            {
                var inputs = new Tensors();
                foreach (var input in input_dataset.element_spec)
                {
                    inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
                }

                var outputs = predicate_func_update(inputs);
                func.ToGraph(inputs, outputs);
            }
            finally
            {
                // Keep Enter()/Exit() paired even if the predicate or ToGraph throws.
                func.Exit();
            }

            structure = func.OutputStructure;

            // NOTE(review): output_types / output_shapes presumably derive from `structure` on the base class — confirm.
            variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
                                                func,
                                                output_types,
                                                output_shapes);
        }
        /// <summary>
        /// Runs when the decorated method exits. It traces the method's placeholder inputs and its outputs
        /// into <c>function</c> and leaves the function graph. It then stores the function in the
        /// <c>functions</c> cache under <c>func_name</c>, and replaces the return value with the converted
        /// result of <c>function.FilteredCall(originalInputs)</c>.
        /// </summary>
        public override void OnExit(MethodExecutionArgs args)
        {
            // Only Placeholder tensors whose op name starts with "inputs" are treated as graph inputs.
            // Null entries (arguments that were not Tensors) are skipped rather than dereferenced.
            bool IsInputPlaceholder(Tensor t) =>
                t != null &&
                t.op.OpType == "Placeholder" &&
                t.op.name.StartsWith("inputs", System.StringComparison.Ordinal);

            try
            {
                if (args.ReturnValue is Tensors outputs)
                {
                    outputs = mark_as_return(outputs);

                    // Check the count first so a parameterless method does not throw IndexOutOfRangeException.
                    Tensors inputs;
                    if (args.Arguments.Length > 0 && args.Arguments[0] is Tensors packed)
                    {
                        inputs = packed;
                    }
                    else
                    {
                        inputs = args.Arguments.Select(x => x as Tensor).ToArray();
                    }

                    inputs = inputs.Where(x => IsInputPlaceholder(x)).ToArray();

                    function.ToGraph(inputs, outputs);
                }
                else if (args.ReturnValue is Tensor output)
                {
                    var inputs = args.Arguments.Select(x => x as Tensor)
                                 .Where(x => IsInputPlaceholder(x))
                                 .ToArray();
                    var outputs2 = array_ops.identity(output);
                    function.ToGraph(inputs, outputs2);
                }
            }
            finally
            {
                // Always leave the function graph, even if tracing throws.
                function.Exit();
            }

            // NOTE(review): GetType() throws NullReferenceException if the method returned null — confirm that cannot happen.
            function.ReturnType  = args.ReturnValue.GetType();
            functions[func_name] = function;

            // Invoke the traced function on `originalInputs` (presumably captured in OnEntry — confirm).
            args.ReturnValue = ConvertReturnValue(function.FilteredCall(originalInputs));
        }
// Example no. 5
// 0
        /// <summary>
        /// Creates a dataset that applies a single-tensor <paramref name="map_func"/> to
        /// <paramref name="input_dataset"/>. The function is traced on one placeholder whose dtype comes
        /// from the first entry of <c>element_spec</c>.
        /// </summary>
        /// <param name="input_dataset">Source dataset. Only <c>element_spec[0].dtype</c> is used; the placeholder is created without a shape.</param>
        /// <param name="map_func">Transformation that is called once on the placeholder and traced into a <see cref="ConcreteFunction"/>.</param>
        /// <param name="use_inter_op_parallelism">Forwarded to <c>ops.map_dataset</c>.</param>
        /// <param name="preserve_cardinality">Forwarded to <c>ops.map_dataset</c>.</param>
        /// <param name="use_legacy_function">Not used by this constructor.</param>
        /// <remarks>
        /// NOTE(review): the ConcreteFunction is disposed when this constructor returns, even though
        /// <c>variant_tensor</c> was built from it. Confirm the dataset does not need the function afterwards.
        /// The name "autograph_" + method name is also not made unique per call.
        /// </remarks>
        public MapDataset(IDatasetV2 input_dataset,
                          Func<Tensor, Tensor> map_func,
                          bool use_inter_op_parallelism = true,
                          bool preserve_cardinality = false,
                          bool use_legacy_function = false) : base(input_dataset)
        {
            using (var traced = new ConcreteFunction($"autograph_{map_func.Method.Name}"))
            {
                var elementType = input_dataset.element_spec[0].dtype;
                var placeholder = tf.placeholder(elementType);

                // map_func runs before ToGraph, exactly as when its result was held in a local.
                traced.ToGraph(placeholder, map_func(placeholder));

                structure = traced.OutputStructure;
                variant_tensor = ops.map_dataset(
                    input_dataset.variant_tensor,
                    traced,
                    output_types,
                    output_shapes,
                    use_inter_op_parallelism: use_inter_op_parallelism,
                    preserve_cardinality: preserve_cardinality);
            }
        }