示例#1
0
        /// <summary>
        /// Creates an `Operation`, builds its underlying C operation in graph <paramref name="g"/>,
        /// creates one `Tensor` per output and registers the operation with the graph.
        /// </summary>
        /// <param name="node_def">`node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.</param>
        /// <param name="g">`Graph`. The parent graph.</param>
        /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
        /// <param name="output_types">
        /// list of `DType` objects. The value passed in is not consulted: the output types are
        /// always read back from the created C operation via `OutputType(i)`.
        /// </param>
        /// <param name="control_inputs">
        /// list of operations or tensors from which to have a control dependency.
        /// A `Tensor` contributes a dependency on the operation that produces it.
        /// </param>
        /// <param name="input_types">
        /// List of `DType` objects representing the types of the tensors accepted by the
        /// `Operation`. Not currently used by this constructor.
        /// </param>
        /// <param name="original_op">Not currently used by this constructor.</param>
        /// <param name="op_def">
        /// `OpDef` for the operation. When null it is looked up with `g.GetOpDef(node_def.Op)`.
        /// It is used to regroup <paramref name="inputs"/> before the C operation is created.
        /// </param>
        /// <exception cref="NotImplementedException">
        /// A control input is neither an `Operation` nor a `Tensor`.
        /// </exception>
        public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            graph = g;

            // Build the list of control inputs. A tensor dependency is expressed as a
            // dependency on the operation that produces that tensor.
            var control_input_ops = new List<Operation>();

            if (control_inputs != null)
            {
                foreach (var c in control_inputs)
                {
                    switch (c)
                    {
                        case Operation c1:
                            control_input_ops.Add(c1);
                            break;

                        case Tensor tensor:
                            control_input_ops.Add(tensor.op);
                            break;

                        default:
                            throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
                    }
                }
            }

            _id_value = graph._next_id();
            if (op_def == null)
            {
                op_def = g.GetOpDef(node_def.Op);
            }

            // Arrange `inputs` in the form `ops._create_c_op` expects, driven by op_def
            // and the node's attributes.
            var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);

            _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

            // Initialize self._outputs. NumOutputs is read once and reused for both arrays,
            // and each output's dtype is queried once and reused for its Tensor.
            int num_outputs = NumOutputs;
            output_types = new TF_DataType[num_outputs];
            _outputs = new Tensor[num_outputs];
            for (int i = 0; i < num_outputs; i++)
            {
                output_types[i] = OutputType(i);
                _outputs[i] = new Tensor(this, i, output_types[i]);
            }

            graph._add_op(this);

            if (_handle != IntPtr.Zero)
            {
                _control_flow_post_processing();
            }
        }
        /// <summary>
        /// Creates an `Operation` from <paramref name="node_def"/> in graph <paramref name="g"/>,
        /// creates one `Tensor` per output of the resulting C operation and registers the
        /// operation with the graph.
        /// </summary>
        /// <param name="node_def">`NodeDef` for the `Operation`.</param>
        /// <param name="g">`Graph`. The parent graph.</param>
        /// <param name="inputs">Tensors forwarded to `ops._create_c_op`.</param>
        /// <param name="output_types">
        /// Optional dtypes of the outputs. Entry i becomes the dtype of output i; outputs with
        /// no corresponding entry (or all outputs, when null) fall back to `TF_FLOAT`.
        /// </param>
        /// <param name="control_inputs">Not currently used by this constructor.</param>
        /// <param name="input_types">Not currently used by this constructor.</param>
        /// <param name="original_op">Not currently used by this constructor.</param>
        /// <param name="op_def">Not currently used by this constructor.</param>
        public Operation(NodeDef node_def, Graph g, List <Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            _graph = g;

            _id_value = _graph._next_id();
            _c_op     = ops._create_c_op(g, node_def, inputs);
            var num_outputs = c_api.TF_OperationNumOutputs(_c_op);

            _outputs = new Tensor[num_outputs];
            for (int i = 0; i < num_outputs; i++)
            {
                // Use the caller-declared dtype when one is available for this output;
                // TF_FLOAT is only a fallback.
                // NOTE(review): consider reading the real dtype from the C operation instead.
                var dtype = output_types != null && i < output_types.Length
                    ? output_types[i]
                    : TF_DataType.TF_FLOAT;
                _outputs[i] = new Tensor(this, i, dtype);
            }

            _graph._add_op(this);
        }
示例#3
0
        /// <summary>
        /// Builds an operation from <paramref name="node_def"/> inside graph <paramref name="g"/>:
        /// assigns it an id, creates the backing C operation and registers it with the graph.
        /// </summary>
        /// <param name="node_def">Definition of the node to create.</param>
        /// <param name="g">Graph that will own the operation.</param>
        /// <param name="inputs">Tensors forwarded to `ops._create_c_op`.</param>
        /// <param name="output_types">
        /// The incoming value is ignored; the parameter is overwritten with the types reported
        /// by `OutputType` for each output.
        /// </param>
        /// <param name="control_inputs">Unused by this constructor.</param>
        /// <param name="input_types">Unused by this constructor.</param>
        /// <param name="original_op">Unused by this constructor.</param>
        /// <param name="op_def">
        /// Optional op definition; when null it is fetched with `g.GetOpDef(node_def.Op)`.
        /// NOTE(review): the resolved value is not used afterwards in this constructor.
        /// </param>
        public Operation(NodeDef node_def, Graph g, List <Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            Graph = g;
            _id_value = Graph._next_id();

            // Resolve the definition from the graph only when the caller did not supply one.
            op_def = op_def ?? g.GetOpDef(node_def.Op);

            _handle = ops._create_c_op(g, node_def, inputs);

            // Collect the dtype of every output of the new C operation.
            // NOTE(review): the result lives only in the local parameter and never reaches
            // any output tensor — confirm this is intended.
            output_types = new TF_DataType[NumOutputs];
            var index = 0;
            while (index < NumOutputs)
            {
                output_types[index] = OutputType(index);
                index++;
            }

            Graph._add_op(this);
        }