예제 #1
0
        /// <summary>
        /// Returns the shared scratch graph, creating a <see cref="FuncGraph"/> named
        /// <c>keras_scratch_graph</c> the first time it is requested (or after it has been reset to null).
        /// </summary>
        FuncGraph _scratch_graph()
        {
            if (_CURRENT_SCRATCH_GRAPH != null)
            {
                return _CURRENT_SCRATCH_GRAPH;
            }

            _CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph");
            return _CURRENT_SCRATCH_GRAPH;
        }
        /// <summary>
        /// Builds a function graph named <paramref name="name"/> from the operations of
        /// <paramref name="graph"/>, excluding the operations that produce <paramref name="inputs"/>.
        /// </summary>
        /// <param name="name">Name of the new function graph.</param>
        /// <param name="graph">Source graph whose operations form the function body.</param>
        /// <param name="inputs">Tensors used as the function's inputs.</param>
        /// <param name="outputs">Tensors used as the function's outputs; their count is stored in <c>_num_outputs</c>.</param>
        /// <param name="attrs">Attributes forwarded to the <see cref="FuncGraph"/> constructor.</param>
        public EagerDefinedFunction(string name, FuncGraph graph,
                                    Tensors inputs, Tensors outputs,
                                    Dictionary <string, string> attrs)
        {
            _num_outputs = outputs.Length;

            // Ops that produce the inputs are not part of the body.
            var producerOps = inputs.Select(tensor => tensor.op).ToArray();
            var bodyOps = graph.get_operations()
                               .Where(node => !producerOps.Contains(node.op))
                               .Select(node => node as Operation)
                               .ToArray();

            _func_graph = new FuncGraph(graph, name, attrs);
            _func_graph.ToGraph(bodyOps, inputs, outputs, new string[0]);
        }
예제 #3
0
        /// <summary>
        /// Creates a graph function from <paramref name="func"/> by applying it to a single
        /// placeholder of type <paramref name="dtype"/> and passing the graph's operations to
        /// <c>FuncGraph.ToGraph</c>; the returned handle is stored in <c>_handle</c>.
        /// </summary>
        public ConcreteFunction(Func <Tensor, Tensor> func, TF_DataType dtype)
        {
            var funcName = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

            using (var funcGraph = new FuncGraph(funcName))
            {
                // NOTE(review): unlike the sibling overloads in this file, this one never calls
                // funcGraph.as_default(). Confirm the FuncGraph constructor makes it the default
                // graph; otherwise the placeholder below is not created inside funcGraph.
                var argument = tf.placeholder(dtype);
                var result = func(argument);

                var graphOps = funcGraph._nodes_by_name.Values
                                        .Select(node => node as Operation)
                                        .ToArray();

                _handle = funcGraph.ToGraph(graphOps, new[] { argument }, new[] { result }, null);
            }
        }
예제 #4
0
        /// <summary>
        /// Builds <c>func_graph</c> by applying <paramref name="func"/> to a single placeholder of
        /// type <paramref name="dtype"/>, then passes the graph's operations to <c>FuncGraph.ToGraph</c>.
        /// </summary>
        /// <param name="func">Single-tensor function whose body is recorded into the graph.</param>
        /// <param name="dtype">Data type of the input placeholder.</param>
        public ConcreteFunction(Func <Tensor, Tensor> func, TF_DataType dtype)
        {
            string func_name = $"{func.Method.Name}_{ops.uid_function()}";

            func_graph = new FuncGraph(func_name);
            func_graph.as_default();
            // Exit in finally: if func/placeholder/ToGraph throws, the as_default() call above
            // must still be undone, otherwise func_graph stays the default graph.
            try
            {
                var input  = tf.placeholder(dtype);
                var output = func(input);

                var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();

                func_graph.ToGraph(opers,
                                   new[] { input },
                                   new[] { output },
                                   null);
            }
            finally
            {
                func_graph.Exit();
            }
        }
예제 #5
0
        /// <summary>
        /// Builds <c>func_graph</c> by applying the dataset-producing <paramref name="func"/> to a
        /// placeholder of type <paramref name="dtype"/>. The dataset's <c>structure</c> is stored in
        /// <c>OutputStructure</c> and its <c>variant_tensor</c> is used as the function output.
        /// </summary>
        /// <param name="func">Function mapping a tensor to a dataset.</param>
        /// <param name="dtype">Data type of the input placeholder.</param>
        public ConcreteFunction(Func <Tensor, IDatasetV2> func, TF_DataType dtype)
        {
            string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

            func_graph = new FuncGraph(func_name);
            func_graph.as_default();
            // Exit in finally so a failure while building the graph cannot leave
            // func_graph installed as the default graph.
            try
            {
                var input  = tf.placeholder(dtype);
                var output = func(input);

                OutputStructure = output.structure;

                var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();

                func_graph.ToGraph(opers,
                                   new[] { input },
                                   new[] { output.variant_tensor },
                                   null);
            }
            finally
            {
                func_graph.Exit();
            }
        }
예제 #6
0
        /// <summary>
        /// Resets session-level state: resets the eager context and UIDs, clears the
        /// learning-phase and per-graph layer-name caches, drops the scratch and cached graphs,
        /// installs a new default session on the default graph, and re-enables eager execution.
        /// </summary>
        public void clear_session()
        {
            tf.Context.reset_context();
            reset_uids();

            // The learning-phase cache may not have been created yet, hence the null check.
            // (The original cleared it twice in a copy-pasted duplicate block.)
            _GRAPH_LEARNING_PHASES?.Clear();
            PER_GRAPH_LAYER_NAME_UIDS.Clear();
            _CURRENT_SCRATCH_GRAPH = null;
            _GRAPH = null;

            ops.set_default_session(tf.Session(ops.get_default_graph()));
            tf.enable_eager_execution();

            // Presumably forces finalization of the discarded graphs/sessions so their
            // resources are released immediately — confirm this is still required.
            GC.Collect();
            GC.WaitForPendingFinalizers();
        }
예제 #7
0
        /// <summary>
        /// Builds a function graph by applying <paramref name="func"/> to one placeholder per
        /// entry of <paramref name="dtypes"/> (shaped by the matching entry of <paramref name="shapes"/>).
        /// Sets <c>Outputs</c>, <c>OutputStructure</c> and <c>_handle</c>.
        /// </summary>
        /// <param name="func">Function over the placeholder tensors.</param>
        /// <param name="dtypes">Data type of each input placeholder.</param>
        /// <param name="shapes">Shape of each input placeholder; indexed in parallel with <paramref name="dtypes"/>.</param>
        public ConcreteFunction(Func <Tensors, Tensors> func,
                                TF_DataType[] dtypes, TensorShape[] shapes)
        {
            string func_name = $"{func.Method.Name}_{Guid.NewGuid()}";

            using var graph = new FuncGraph(func_name);
            graph.as_default();
            // Exit in finally (still before the using-disposal) so an exception while
            // building cannot leave this graph installed as the default.
            try
            {
                var inputs = new Tensors();

                foreach (var (i, dtype) in enumerate(dtypes))
                {
                    inputs.Add(tf.placeholder(dtype, shape: shapes[i], name: "args"));
                }
                Outputs         = func(inputs);
                OutputStructure = Outputs.Select(x => x.ToTensorSpec()).ToArray();

                var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();

                _handle = graph.ToGraph(opers, inputs, Outputs, null);
            }
            finally
            {
                graph.Exit();
            }
        }
예제 #8
0
        /// <summary>
        /// Builds a graph function from <paramref name="func"/> with eager execution disabled,
        /// registers it with the eager context via <c>TFE_ContextAddFunction</c>, then re-enables
        /// eager execution.
        /// </summary>
        /// <param name="func">Single-tensor function applied to the placeholder.</param>
        /// <param name="dtype">Data type of the input placeholder.</param>
        public ConcreteFunction(Func <Tensor, Tensor> func, TF_DataType dtype)
        {
            string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

            tf.compat.v1.disable_eager_execution();
            // Re-enable eager execution in finally: previously an exception while building
            // the graph left eager execution permanently disabled.
            try
            {
                using (var graph = new FuncGraph(func_name))
                {
                    graph.as_default();
                    var input  = tf.placeholder(dtype);
                    var output = func(input);

                    var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
                    _handle = graph.ToGraph(opers,
                                            new Operation[] { input },
                                            new Operation[] { output },
                                            null);

                    // NOTE(review): the status written to tf.Status is not checked here.
                    c_api.TFE_ContextAddFunction(tf.Context.Handle, _handle, tf.Status.Handle);
                }
            }
            finally
            {
                tf.enable_eager_execution();
            }
        }
예제 #9
0
 /// <summary>
 /// Wraps an existing <see cref="FuncGraph"/>. <paramref name="attrs"/> is not used by this constructor.
 /// </summary>
 public ConcreteFunction(FuncGraph graph, Dictionary <string, string> attrs)
     => func_graph = graph;
예제 #10
0
 /// <summary>Creates a concrete function backed by a new <see cref="FuncGraph"/> named <paramref name="name"/>.</summary>
 public ConcreteFunction(string name)
     => func_graph = new FuncGraph(name);
예제 #11
0
        /// <summary>
        /// Wraps <paramref name="graph"/> and calls <c>ToGraph</c> with its inputs and its non-null
        /// outputs. <paramref name="attrs"/> is not used by this constructor.
        /// </summary>
        public ConcreteFunction(FuncGraph graph, Dictionary <string, string> attrs)
        {
            func_graph = graph;

            var graphInputs = graph.Inputs;
            // Null entries in graph.Outputs are skipped.
            var presentOutputs = graph.Outputs
                                      .Where(tensor => tensor != null)
                                      .ToArray();

            ToGraph(graphInputs, presentOutputs);
        }
 /// <summary>
 /// Stores <paramref name="func_graph"/> in <c>_func_graph</c>.
 /// <paramref name="need_gradients_for_jvps"/> is accepted but not used by this constructor.
 /// </summary>
 public TapeGradientFunctions(FuncGraph func_graph, bool need_gradients_for_jvps)
     => _func_graph = func_graph;
 /// <summary>
 /// Forwards both arguments unchanged to the base-class constructor; adds no state of its own.
 /// </summary>
 public FirstOrderTapeGradientFunctions(FuncGraph func_graph, bool need_gradients_for_jvps)
     : base(func_graph, need_gradients_for_jvps)
 {
 }