コード例 #1
0
ファイル: importer.cs プロジェクト: sharwell/TensorFlow.NET
        /// <summary>
        /// Transfers the import arguments onto the native import options referenced by
        /// <paramref name="options"/>: the name prefix, the uniquify-names flag, and the
        /// requested return elements.
        /// </summary>
        /// <param name="options">Wrapper whose <c>Handle</c> is passed to the C API calls.</param>
        /// <param name="prefix">Value forwarded to <c>TF_ImportGraphDefOptionsSetPrefix</c>.</param>
        /// <param name="input_map">
        /// Input remapping. Not implemented in this version: it must be <c>null</c> or empty.
        /// </param>
        /// <param name="return_elements">
        /// Names to return from the import; <c>null</c> is treated as an empty list. A name
        /// containing ':' is split by <c>_ParseTensorName</c> and registered as a return output;
        /// any other name is registered as a return operation.
        /// </param>
        /// <exception cref="NotImplementedException">
        /// Thrown when <paramref name="input_map"/> contains at least one entry.
        /// </exception>
        public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options,
                                                            string prefix,
                                                            Dictionary<string, Tensor> input_map,
                                                            string[] return_elements)
        {
            c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix);
            // Passes 1 as the flag value (presumably "enabled" -- confirm against the C API header).
            c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1);

            // Input remapping is unsupported here. A null map is treated like an empty one
            // (nothing to remap) instead of failing with a NullReferenceException.
            if (input_map != null && input_map.Count > 0)
            {
                throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
            }

            foreach (var name in return_elements ?? new string[0])
            {
                if (name.Contains(":"))
                {
                    // "op:index" form: request a specific output of the operation.
                    var (op_name, index) = _ParseTensorName(name);
                    c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index);
                }
                else
                {
                    // Bare name: request the operation itself.
                    c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name);
                }
            }

            // TODO: forward validate_colocation_constraints via
            // c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints once it is a parameter.
        }
コード例 #2
0
        /// <summary>
        /// Transfers the import arguments onto the native import options referenced by
        /// <paramref name="options"/>: the name prefix, the uniquify-names flag, the input
        /// mappings, and the requested return elements.
        /// </summary>
        /// <param name="options">Wrapper whose <c>Handle</c> is passed to the C API calls.</param>
        /// <param name="prefix">Value forwarded to <c>TF_ImportGraphDefOptionsSetPrefix</c>.</param>
        /// <param name="input_map">
        /// Maps a tensor name (parsed by <c>_ParseTensorName</c> into a source name and index)
        /// to the tensor that should replace it; <c>null</c> is treated as "no mappings".
        /// </param>
        /// <param name="return_elements">
        /// Names to return from the import; <c>null</c> is treated as an empty list. A name
        /// containing ':' is registered as a return output, any other name as a return operation.
        /// </param>
        public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options, 
            string prefix, 
            Dictionary<string, Tensor> input_map,
            string[] return_elements)
        {
            c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix);
            // Passes 1 as the flag value (presumably "enabled" -- confirm against the C API header).
            c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1);

            // A null map means there is nothing to remap; handled the same way as a null
            // return_elements rather than failing with a NullReferenceException.
            if (input_map != null)
            {
                foreach (var input in input_map)
                {
                    var (src_name, src_index) = _ParseTensorName(input.Key);
                    c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Handle, src_name, src_index, input.Value._as_tf_output());
                }
            }

            foreach (var name in return_elements ?? new string[0])
            {
                if (name.Contains(":"))
                {
                    // "op:index" form: request a specific output of the operation.
                    var (op_name, index) = _ParseTensorName(name);
                    c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index);
                }
                else
                {
                    // Bare name: request the operation itself.
                    c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name);
                }
            }

            // TODO: forward validate_colocation_constraints via
            // c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints once it is a parameter.
        }
コード例 #3
0
 /// <summary>
 /// Imports a serialized graph definition into this graph via
 /// <c>TF_GraphImportGraphDef</c>.
 /// </summary>
 /// <param name="bytes">Serialized graph definition, wrapped in a <c>Buffer</c> for the call.</param>
 /// <param name="prefix">Value forwarded to <c>TF_ImportGraphDefOptionsSetPrefix</c>.</param>
 /// <returns><c>true</c> when the status code after the import is <c>TF_OK</c>.</returns>
 public bool Import(byte[] bytes, string prefix = "")
 {
     using (var importOptions = new ImportGraphDefOptions())
     {
         using (var importStatus = new Status())
         {
             using (var serializedGraph = new Buffer(bytes))
             {
                 c_api.TF_ImportGraphDefOptionsSetPrefix(importOptions.Handle, prefix);
                 c_api.TF_GraphImportGraphDef(_handle, serializedGraph.Handle, importOptions.Handle, importStatus.Handle);

                 // Check(true) presumably throws on a failed status -- confirm in Status.
                 importStatus.Check(true);

                 bool succeeded = importStatus.Code == TF_Code.TF_OK;
                 return succeeded;
             }
         }
     }
 }
コード例 #4
0
        /// <summary>
        /// Imports <paramref name="graph_def"/> into this graph and returns the outputs that
        /// were requested on <paramref name="opts"/>.
        /// </summary>
        /// <param name="graph_def">Buffer passed to the native import call.</param>
        /// <param name="opts">Import options; <c>NumReturnOutputs</c> determines the result length.</param>
        /// <param name="s">
        /// Status object passed to the native call. This method does not inspect it, so
        /// callers should check it before relying on the returned values.
        /// </param>
        /// <returns>An array with <c>opts.NumReturnOutputs</c> entries copied from the native buffer.</returns>
        public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s)
        {
            var num_return_outputs = opts.NumReturnOutputs;
            var return_outputs     = new TF_Output[num_return_outputs];
            int size = Marshal.SizeOf<TF_Output>();

            // Unmanaged scratch buffer that the native call fills with the return outputs.
            var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);

            try
            {
                c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);

                // Copy each struct out of unmanaged memory into the managed result array.
                for (int i = 0; i < num_return_outputs; i++)
                {
                    var handle = return_output_handle + i * size;
                    return_outputs[i] = Marshal.PtrToStructure<TF_Output>(handle);
                }
            }
            finally
            {
                // The buffer was previously never released; free it on every path.
                Marshal.FreeHGlobal(return_output_handle);
            }

            return return_outputs;
        }
コード例 #5
0
        /// <summary>
        /// Imports <paramref name="graph_def"/> into this graph and returns the outputs that
        /// were requested on <paramref name="opts"/>. Calls <c>as_default()</c> first.
        /// </summary>
        /// <param name="graph_def">Buffer passed to the native import call.</param>
        /// <param name="opts">Import options; <c>NumReturnOutputs</c> determines the result length.</param>
        /// <param name="s">
        /// Status object passed to the native call. This method does not inspect it, so
        /// callers should check it before relying on the returned values.
        /// </param>
        /// <returns>An array with <c>opts.NumReturnOutputs</c> entries copied from the native buffer.</returns>
        public unsafe TF_Output[] ImportGraphDefWithReturnOutputs(Buffer graph_def, ImportGraphDefOptions opts, Status s)
        {
            // NOTE(review): presumably makes this graph the default one -- confirm in as_default().
            as_default();
            var num_return_outputs = opts.NumReturnOutputs;
            var return_outputs     = new TF_Output[num_return_outputs];
            int size = Marshal.SizeOf<TF_Output>();

            // Unmanaged scratch buffer that the native call fills with the return outputs.
            var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);

            try
            {
                c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);

                // Read the structs directly through a typed pointer into the buffer.
                var tf_output_ptr = (TF_Output*)return_output_handle;

                for (int i = 0; i < num_return_outputs; i++)
                {
                    return_outputs[i] = *(tf_output_ptr + i);
                }
            }
            finally
            {
                // Release the buffer even when the native call throws.
                Marshal.FreeHGlobal(return_output_handle);
            }

            return return_outputs;
        }
コード例 #6
0
ファイル: importer.py.cs プロジェクト: aheak/TensorFlow.NET
        /// <summary>
        /// Sets the name prefix and the uniquify-names flag on <paramref name="options"/>.
        /// Input mappings and return elements are not implemented in this version.
        /// </summary>
        /// <param name="options">Import options passed to the C API calls.</param>
        /// <param name="prefix">Value forwarded to <c>TF_ImportGraphDefOptionsSetPrefix</c>.</param>
        /// <param name="input_map">Must be <c>null</c> or empty (not implemented).</param>
        /// <param name="return_elements">Must be <c>null</c> or empty (not implemented).</param>
        /// <exception cref="NotImplementedException">
        /// Thrown when <paramref name="input_map"/> or <paramref name="return_elements"/>
        /// contains at least one entry.
        /// </exception>
        public static void _PopulateTFImportGraphDefOptions(ImportGraphDefOptions options,
                                                            string prefix,
                                                            Dictionary<string, Tensor> input_map,
                                                            string[] return_elements)
        {
            c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix);
            // Passes 1 as the flag value (presumably "enabled" -- confirm against the C API header).
            c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1);

            // Neither feature is supported yet. Null collections are treated as empty so
            // that a null input_map no longer fails with a NullReferenceException.
            bool hasInputMappings = input_map != null && input_map.Count > 0;
            if (hasInputMappings)
            {
                throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
            }

            bool hasReturnElements = return_elements != null && return_elements.Length > 0;
            if (hasReturnElements)
            {
                throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
            }
        }