Ejemplo n.º 1
0
        /// <summary>
        /// Builds an <c>OpDef</c> for the "NearestNeighbors" op with three inputs
        /// (float <c>points</c>, float <c>centers</c>, int64 <c>k</c>) and two outputs
        /// (int64 <c>nearest_center_indices</c>, float <c>nearest_center_distances</c>).
        /// </summary>
        /// <remarks>
        /// Known not to work: the op cannot be found in the native binary.
        /// </remarks>
        /// <returns>The populated op definition.</returns>
        private static OpDef op_NearestNeighbors()
        {
            ArgDef Arg(string argName, DataType argType) => new ArgDef { Name = argName, Type = argType };

            var def = new OpDef { Name = "NearestNeighbors" };

            // Inputs, in declaration order.
            def.InputArg.Add(Arg("points", DataType.DtFloat));
            def.InputArg.Add(Arg("centers", DataType.DtFloat));
            def.InputArg.Add(Arg("k", DataType.DtInt64));

            // Outputs, in declaration order.
            def.OutputArg.Add(Arg("nearest_center_indices", DataType.DtInt64));
            def.OutputArg.Add(Arg("nearest_center_distances", DataType.DtFloat));

            return def;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Creates a TF_Operation.
        /// </summary>
        /// <param name="graph">a `Graph`.</param>
        /// <param name="node_def">`node_def_pb2.NodeDef` for the operation to create.</param>
        /// <param name="inputs">
        /// A list of `Tensor`s (corresponding to scalar inputs) and lists of
        /// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
        /// "list(int64)"). The length of the list should be equal to the number of
        /// inputs specified by this operation's op def.
        /// </param>
        /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param>
        /// <param name="op_def">
        /// Optional op definition; when null it is looked up from <paramref name="graph"/> by <c>node_def.Op</c>.
        /// </param>
        /// <returns>A wrapped TF_Operation* together with the description used to build it.</returns>
        public static (IntPtr, OperationDescription) _create_c_op(Graph graph, NodeDef node_def, Tensor[] inputs, Operation[] control_inputs,
                                                                  OpDef op_def = null)
        {
            if (op_def == null)
            {
                op_def = graph.GetOpDef(node_def.Op);
            }

            var input_tensors = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);

            // The whole native description/finish sequence runs under Locks.ProcessWide
            // (presumably a process-wide lock object — confirm at its declaration).
            lock (Locks.ProcessWide)
            {
                var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

                if (!string.IsNullOrEmpty(node_def.Device))
                {
                    c_api.TF_SetDevice(op_desc, node_def.Device);
                }

                // Add inputs: sequence groups go in as a list, single-element groups as a
                // plain input; groups that are neither are skipped.
                foreach (var op_input in input_tensors)
                {
                    if (op_input.IsList)
                    {
                        c_api.TF_AddInputList(op_desc, op_input.Select(x => x._as_tf_output()).ToArray(), op_input.Count());
                    }
                    else if (op_input.Count() == 1)
                    {
                        c_api.TF_AddInput(op_desc, op_input[0]._as_tf_output());
                    }
                }

                var status = tf.Status;

                // Add control inputs
                foreach (var control_input in control_inputs)
                {
                    c_api.TF_AddControlInput(op_desc, control_input);
                }

                // Add attrs. Each serialized AttrValue is copied into unmanaged memory for the
                // native call; the buffer is released in `finally` so a failing status check
                // does not leak it.
                foreach (var attr in node_def.Attr)
                {
                    var bytes       = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
                    var protoHandle = Marshal.AllocHGlobal(bytes.Length);
                    try
                    {
                        Marshal.Copy(bytes, 0, protoHandle, bytes.Length);
                        uint len = (uint)bytes.Length;
                        c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status.Handle);
                        status.Check(true);
                    }
                    finally
                    {
                        Marshal.FreeHGlobal(protoHandle);
                    }
                }

                var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);

                status.Check(true);

                return(c_op, op_desc);
            }
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Builds and finishes a native TF_Operation described by <paramref name="node_def"/>.
        /// </summary>
        /// <param name="graph">The `Graph` the operation is added to.</param>
        /// <param name="node_def">`NodeDef` supplying the op type, name, device and attributes.</param>
        /// <param name="inputs">
        /// Flat list of input `Tensor`s. It is regrouped per the op definition, so that
        /// scalar inputs and sequence inputs (e.g. "int64 * N", "list(int64)") line up
        /// with the op def's declared input arguments.
        /// </param>
        /// <param name="control_inputs">`Operation`s added as control dependencies.</param>
        /// <param name="op_def">Optional op definition; looked up from the graph when null.</param>
        /// <returns>The finished TF_Operation* and the description used to build it.</returns>
        public static (IntPtr, OperationDescription) _create_c_op(Graph graph, NodeDef node_def, Tensor[] inputs, Operation[] control_inputs,
                                                                  OpDef op_def = null)
        {
            var resolvedOpDef = op_def ?? graph.GetOpDef(node_def.Op);
            var groupedInputs = _reconstruct_sequence_inputs(resolvedOpDef, inputs, node_def.Attr);

            var description = graph.NewOperation(node_def.Op, node_def.Name);

            // An empty device string means "no explicit placement".
            if (!string.IsNullOrEmpty(node_def.Device))
            {
                c_api.TF_SetDevice(description, node_def.Device);
            }

            foreach (var inputGroup in groupedInputs)
            {
                if (inputGroup.IsList)
                {
                    var listOutputs = inputGroup.Select(t => t._as_tf_output()).ToArray();
                    c_api.TF_AddInputList(description, listOutputs, inputGroup.Count());
                    continue;
                }

                if (inputGroup.Count() == 1)
                {
                    c_api.TF_AddInput(description, inputGroup[0]._as_tf_output());
                }
            }

            var status = tf.Status;

            foreach (var dependency in control_inputs)
            {
                c_api.TF_AddControlInput(description, dependency);
            }

            // Each attribute is forwarded as its serialized AttrValue proto.
            foreach (var entry in node_def.Attr)
            {
                var serialized = entry.Value.ToByteArray();
                c_api.TF_SetAttrValueProto(description, entry.Key, serialized, proto_len: serialized.Length, status: status.Handle);
                status.Check(true);
            }

            var handle = description.FinishOperation(status);
            status.Check(true);

            return (handle, description);
        }
Ejemplo n.º 4
0
 /// <summary>
 /// Copies the default value of every attribute declared in <paramref name="op_def"/>
 /// into <paramref name="node_def"/> when the node does not already carry that attribute.
 /// </summary>
 /// <remarks>
 /// Protobuf's <c>MapField</c> indexer throws <c>KeyNotFoundException</c> for a missing key,
 /// so presence is tested with <c>TryGetValue</c> rather than by reading the indexer.
 /// </remarks>
 private static void _SetDefaultAttrValues(NodeDef node_def, OpDef op_def)
 {
     foreach (var attr_def in op_def.Attr)
     {
         if (attr_def.DefaultValue == null)
         {
             continue;
         }

         var key = attr_def.Name;
         if (!node_def.Attr.TryGetValue(key, out var value) || value == null)
         {
             node_def.Attr[key] = attr_def.DefaultValue;
         }
     }
 }
Ejemplo n.º 5
0
        /// <summary>
        /// Regroups the flat <paramref name="inputs"/> array to match the input arguments
        /// declared by <paramref name="op_def"/>.
        /// </summary>
        /// <remarks>
        /// An argument with a <c>NumberAttr</c> consumes as many tensors as that attr's integer
        /// value. An argument with a <c>TypeListAttr</c> consumes one tensor per entry in that
        /// attr's type list. Both of these produce a <c>Tensor[]</c> slot. Any other argument
        /// consumes exactly one tensor, stored directly.
        /// </remarks>
        /// <returns>One element per declared input argument: a <c>Tensor</c> or a <c>Tensor[]</c>.</returns>
        private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField <string, AttrValue> attrs)
        {
            var groups = new List<object>();
            var cursor = 0;

            foreach (var arg in op_def.InputArg)
            {
                int width;
                bool isSequence;

                if (!string.IsNullOrEmpty(arg.NumberAttr))
                {
                    width      = (int)attrs[arg.NumberAttr].I;
                    isSequence = true;
                }
                else if (!string.IsNullOrEmpty(arg.TypeListAttr))
                {
                    width      = attrs[arg.TypeListAttr].List.Type.Count;
                    isSequence = true;
                }
                else
                {
                    width      = 1;
                    isSequence = false;
                }

                groups.Add(isSequence
                    ? (object)inputs.Skip(cursor).Take(width).ToArray()
                    : inputs[cursor]);

                cursor += width;
            }

            return groups.ToArray();
        }
Ejemplo n.º 6
0
        /// <summary>
        /// Creates an `Operation`.
        /// </summary>
        /// <param name="node_def">`node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.</param>
        /// <param name="g">`Graph`. The parent graph.</param>
        /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
        /// <param name="output_types">
        /// list of `DType` objects. Not consulted: output types are read back from the created
        /// native operation.
        /// </param>
        /// <param name="control_inputs">
        /// list of operations or tensors from which to have a
        /// control dependency.
        /// </param>
        /// <param name="input_types">
        /// List of `DType` objects representing the
        /// types of the tensors accepted by the `Operation`. By default
        /// uses `[x.dtype.base_dtype for x in inputs]`.  Operations that expect
        /// reference-typed inputs must specify these explicitly.
        /// </param>
        /// <param name="original_op">Not used by this constructor.</param>
        /// <param name="op_def">Op definition; looked up from <paramref name="g"/> by <c>node_def.Op</c> when null.</param>
        public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            _graph = g;

            // Build the list of control inputs; a tensor contributes its producing operation.
            var control_input_ops = new List <Operation>();

            if (control_inputs != null)
            {
                foreach (var c in control_inputs)
                {
                    switch (c)
                    {
                    case Operation c1:
                        control_input_ops.Add(c1);
                        break;

                    case Tensor tensor:
                        control_input_ops.Add(tensor.op);
                        break;

                    // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
                    //case IndexedSlices islices:
                    //    control_input_ops.Add(islices.op);
                    //    break;
                    default:
                        throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
                    }
                }
            }

            // Record the graph's current control-flow context for this op.
            _control_flow_context = graph._get_control_flow_context();

            // Resolve the op definition from the graph when the caller did not supply one.
            if (op_def == null)
            {
                op_def = g.GetOpDef(node_def.Op);
            }

            var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);

            (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

            // Initialize self._outputs: query each output's dtype once and reuse it for the tensor.
            var num_outputs = NumOutputs;
            output_types = new TF_DataType[num_outputs];
            _outputs     = new Tensor[num_outputs];
            for (int i = 0; i < num_outputs; i++)
            {
                output_types[i] = OutputType(i);
                _outputs[i]     = new Tensor(this, i, output_types[i]);
            }

            graph._add_op(this);

            if (_handle != IntPtr.Zero)
            {
                _control_flow_post_processing();
            }
        }
Ejemplo n.º 7
0
        /// <summary>
        /// Loads the Inception graph and walks its operations, exercising the attribute,
        /// shape and op-def accessors. It then runs recognition on "grace_hopper.jpg" and
        /// traces the top result.
        /// </summary>
        /// <remarks>
        /// The work runs inside the <c>OnDownloadCompleted</c> callback. Completion is signalled
        /// through a <c>ManualResetEventSlim</c> rather than a polled <c>bool</c>, because a
        /// plain field read in a sleep loop is not guaranteed to observe a write from another
        /// thread. Any exception raised in the callback is captured and rethrown on the test
        /// thread instead of leaving the test waiting forever.
        /// </remarks>
        public void TestInception()
        {
            using (Tensor imageTensor = ImageIO.ReadTensorFromImageFile <float>("grace_hopper.jpg", 224, 224, 128.0f, 1.0f))
            using (Inception inceptionGraph = new Inception())
            using (ManualResetEventSlim processCompleted = new ManualResetEventSlim(false))
            {
                Exception callbackError = null;

                inceptionGraph.OnDownloadCompleted += (sender, e) =>
                {
                    try
                    {
                        HashSet <string> opNames        = new HashSet <string>();
                        HashSet <string> couldBeInputs  = new HashSet <string>();
                        HashSet <string> couldBeOutputs = new HashSet <string>();
                        foreach (Operation op in inceptionGraph.Graph)
                        {
                            String name = op.Name;
                            opNames.Add(name);

                            if (op.NumInputs == 0 && op.OpType.Equals("Placeholder"))
                            {
                                couldBeInputs.Add(op.Name);
                                AttrMetadata dtypeMeta   = op.GetAttrMetadata("dtype");
                                AttrMetadata shapeMeta   = op.GetAttrMetadata("shape");
                                DataType     type        = op.GetAttrType("dtype");
                                Int64[]      shape       = op.GetAttrShape("shape");
                                Buffer       valueBuffer = op.GetAttrValueProto("shape");
                                Buffer       shapeBuffer = op.GetAttrTensorShapeProto("shape");
                                Tensorflow.TensorShapeProto shapeProto =
                                    Tensorflow.TensorShapeProto.Parser.ParseFrom(shapeBuffer.Data);
                            }

                            if (op.OpType.Equals("Const"))
                            {
                                AttrMetadata dtypeMeta = op.GetAttrMetadata("dtype");
                                AttrMetadata valueMeta = op.GetAttrMetadata("value");
                                using (Tensor valueTensor = op.GetAttrTensor("value"))
                                {
                                    var dim = valueTensor.Dim;
                                }
                            }

                            if (op.OpType.Equals("Conv2D"))
                            {
                                AttrMetadata stridesMeta = op.GetAttrMetadata("strides");
                                AttrMetadata paddingMeta = op.GetAttrMetadata("padding");
                                AttrMetadata boolMeta    = op.GetAttrMetadata("use_cudnn_on_gpu");
                                Int64[]      strides     = op.GetAttrIntList("strides");
                                bool         useCudnn    = op.GetAttrBool("use_cudnn_on_gpu");
                                String       padding     = op.GetAttrString("padding");
                            }

                            // An operation with an unconsumed output is a candidate graph output.
                            foreach (Output output in op.Outputs)
                            {
                                int[] shape = inceptionGraph.Graph.GetTensorShape(output);
                                if (output.NumConsumers == 0)
                                {
                                    couldBeOutputs.Add(name);
                                }
                            }

                            Buffer           buffer = inceptionGraph.Graph.GetOpDef(op.OpType);
                            Tensorflow.OpDef opDef  = Tensorflow.OpDef.Parser.ParseFrom(buffer.Data);
                        }

                        using (Buffer versionDef = inceptionGraph.Graph.Versions())
                        {
                            int l = versionDef.Length;
                        }

                        Inception.RecognitionResult[] results = inceptionGraph.Recognize(imageTensor);

                        Trace.WriteLine(String.Format("Object is {0} with {1}% probability", results[0].Label, results[0].Probability * 100));
                    }
                    catch (Exception ex)
                    {
                        // Surface the failure on the test thread instead of hanging.
                        callbackError = ex;
                    }
                    finally
                    {
                        processCompleted.Set();
                    }
                };

                inceptionGraph.Init();
                processCompleted.Wait();

                if (callbackError != null)
                {
                    System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(callbackError).Throw();
                }
            }
        }
Ejemplo n.º 8
0
        /// <summary>
        /// Creates an `Operation`.
        /// </summary>
        /// <param name="node_def">`node_def_pb2.NodeDef`.  `NodeDef` for the `Operation`.</param>
        /// <param name="g">`Graph`. The parent graph.</param>
        /// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
        /// <param name="output_types">
        /// list of `DType` objects. Not consulted: output types are read back from the created
        /// native operation.
        /// </param>
        /// <param name="control_inputs">
        /// list of operations or tensors from which to have a
        /// control dependency.
        /// </param>
        /// <param name="input_types">
        /// List of `DType` objects representing the
        /// types of the tensors accepted by the `Operation`. By default
        /// uses `[x.dtype.base_dtype for x in inputs]`.  Operations that expect
        /// reference-typed inputs must specify these explicitly.
        /// </param>
        /// <param name="original_op">Not used by this constructor.</param>
        /// <param name="op_def">Op definition; looked up from <paramref name="g"/> by <c>node_def.Op</c> when null.</param>
        public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            _graph = g;

            // Build the list of control inputs. A tensor contributes the operation that
            // produces it, matching what the error message below promises callers.
            var control_input_ops = new List <Operation>();

            if (control_inputs != null)
            {
                foreach (var c in control_inputs)
                {
                    switch (c)
                    {
                    case Operation c1:
                        control_input_ops.Add(c1);
                        break;

                    case Tensor tensor:
                        control_input_ops.Add(tensor.op);
                        break;

                    default:
                        throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
                    }
                }
            }

            // Resolve the op definition from the graph when the caller did not supply one.
            if (op_def == null)
            {
                op_def = g.GetOpDef(node_def.Op);
            }

            var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);

            _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

            // Initialize self._outputs: query each output's dtype once and reuse it for the tensor.
            var num_outputs = NumOutputs;
            output_types = new TF_DataType[num_outputs];
            _outputs     = new Tensor[num_outputs];
            for (int i = 0; i < num_outputs; i++)
            {
                output_types[i] = OutputType(i);
                _outputs[i]     = new Tensor(this, i, output_types[i]);
            }

            graph._add_op(this);

            if (_handle != IntPtr.Zero)
            {
                _control_flow_post_processing();
            }
        }
Ejemplo n.º 9
0
 /// <summary>
 /// Finds the attribute definition called <paramref name="name"/> in <paramref name="op_def"/>.
 /// </summary>
 /// <returns>The first <c>AttrDef</c> whose <c>Name</c> equals <paramref name="name"/>, or <c>null</c> if none does.</returns>
 private static AttrDef _FindAttrInOpDef(string name, OpDef op_def)
 {
     foreach (var candidate in op_def.Attr)
     {
         if (candidate.Name == name)
         {
             return candidate;
         }
     }

     return null;
 }
Ejemplo n.º 10
0
        /// <summary>
        /// Loads the ResNet model and reads its "serving_default" signature. It then walks
        /// every graph operation, exercising the attribute, shape and op-def accessors, and
        /// finally runs recognition on "surfers.jpg".
        /// </summary>
        public async Task TestResnet()
        {
            using (Tensor imageTensor = ImageIO.ReadTensorFromImageFile <float>("surfers.jpg", 224, 224, 0, 1.0f / 255.0f))
            using (Resnet resnet = new Resnet())
            {
                await resnet.Init();

                // Serving signature from the exported meta graph.
                var metaGraph        = MetaGraphDef.Parser.ParseFrom(resnet.MetaGraphDefBuffer.Data);
                var servingSignature = metaGraph.SignatureDef["serving_default"];
                var signatureInputs  = servingSignature.Inputs;
                var signatureOutputs = servingSignature.Outputs;

                var opNames        = new HashSet<string>();
                var couldBeInputs  = new HashSet<string>();
                var couldBeOutputs = new HashSet<string>();

                foreach (Operation op in resnet.Graph)
                {
                    var opName = op.Name;
                    opNames.Add(opName);

                    // Input-less placeholders are candidate graph inputs.
                    bool isPlaceholder = op.NumInputs == 0 && op.OpType.Equals("Placeholder");
                    if (isPlaceholder)
                    {
                        couldBeInputs.Add(op.Name);
                        var placeholderDtypeMeta = op.GetAttrMetadata("dtype");
                        var placeholderShapeMeta = op.GetAttrMetadata("shape");
                        var placeholderType      = op.GetAttrType("dtype");
                        var placeholderShape     = op.GetAttrShape("shape");
                        var shapeValueProto      = op.GetAttrValueProto("shape");
                        var shapeProtoBuffer     = op.GetAttrTensorShapeProto("shape");
                        var parsedShape          = Tensorflow.TensorShapeProto.Parser.ParseFrom(shapeProtoBuffer.Data);
                    }

                    if (op.OpType.Equals("Const"))
                    {
                        var constDtypeMeta = op.GetAttrMetadata("dtype");
                        var constValueMeta = op.GetAttrMetadata("value");
                        using (var constValue = op.GetAttrTensor("value"))
                        {
                            var constDims = constValue.Dim;
                        }
                    }

                    if (op.OpType.Equals("Conv2D"))
                    {
                        var stridesMetadata = op.GetAttrMetadata("strides");
                        var paddingMetadata = op.GetAttrMetadata("padding");
                        var cudnnMetadata   = op.GetAttrMetadata("use_cudnn_on_gpu");
                        var convStrides     = op.GetAttrIntList("strides");
                        var cudnnEnabled    = op.GetAttrBool("use_cudnn_on_gpu");
                        var convPadding     = op.GetAttrString("padding");
                    }

                    // Operations with an unconsumed output are candidate graph outputs.
                    foreach (Output output in op.Outputs)
                    {
                        var outputShape = resnet.Graph.GetTensorShape(output);
                        if (output.NumConsumers == 0)
                        {
                            couldBeOutputs.Add(opName);
                        }
                    }

                    var opDefBuffer = resnet.Graph.GetOpDef(op.OpType);
                    var parsedOpDef = Tensorflow.OpDef.Parser.ParseFrom(opDefBuffer.Data);
                }

                using (var versions = resnet.Graph.Versions())
                {
                    var versionLength = versions.Length;
                }

                var results = resnet.Recognize(imageTensor);
            }
        }
Ejemplo n.º 11
0
        /// <summary>
        /// Creates an operation of type <paramref name="op_type"/> in this graph, then passes it
        /// to <c>_create_op_helper</c> and returns it.
        /// </summary>
        /// <param name="op_type">Registered op type name; also used as the default op name.</param>
        /// <param name="inputs">Input tensors of the new operation.</param>
        /// <param name="dtypes">Output types, forwarded to the <c>Operation</c> constructor as <c>output_types</c>.</param>
        /// <param name="input_types">Optional input types, forwarded as-is.</param>
        /// <param name="name">
        /// Requested op name. A name ending in '/' is converted with <c>ops._name_from_scope_name</c>;
        /// any other name is passed through <c>unique_name</c>.
        /// </param>
        /// <param name="attrs">Attribute values placed on the generated <c>NodeDef</c>.</param>
        /// <param name="op_def">Optional op definition forwarded to the <c>Operation</c> constructor.</param>
        /// <returns>The created operation.</returns>
        public Operation create_op(string op_type, List <Tensor> inputs, TF_DataType[] dtypes,
                                   TF_DataType[] input_types            = null, string name  = "",
                                   Dictionary <string, AttrValue> attrs = null, OpDef op_def = null)
        {
            if (string.IsNullOrEmpty(name))
            {
                name = op_type;
            }

            // Ordinal comparison: the '/' scope separator must not be matched culture-sensitively.
            name = name.EndsWith("/", StringComparison.Ordinal) ? ops._name_from_scope_name(name) : unique_name(name);
            var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);

            var op = new Operation(node_def,
                                   this,
                                   inputs: inputs,
                                   output_types: dtypes,
                                   control_inputs: new object[] { },
                                   input_types: input_types,
                                   original_op: null,
                                   op_def: op_def);

            _create_op_helper(op, true);
            return(op);
        }
 /// <summary>
 /// Stores <paramref name="op_def"/> in <c>_ops</c>, keyed by its <c>Name</c>.
 /// </summary>
 public void add_op(OpDef op_def) => _ops[op_def.Name] = op_def;
Ejemplo n.º 13
0
        /// <summary>
        /// Creates an operation in graph <paramref name="g"/> from <paramref name="node_def"/>,
        /// wraps each native output in a <c>Tensor</c>, and registers the op with the graph.
        /// </summary>
        /// <param name="node_def">Node definition for the native operation.</param>
        /// <param name="g">Parent graph.</param>
        /// <param name="inputs">Input tensors passed to <c>ops._create_c_op</c>.</param>
        /// <param name="output_types">
        /// Data types of the outputs. When an entry exists for an output it is used for that
        /// output's tensor; otherwise the tensor is created as <c>TF_FLOAT</c>.
        /// </param>
        /// <param name="control_inputs">Not used by this constructor.</param>
        /// <param name="input_types">Not used by this constructor.</param>
        /// <param name="original_op">Not used by this constructor.</param>
        /// <param name="op_def">Not used by this constructor.</param>
        public Operation(NodeDef node_def, Graph g, List <Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            _graph = g;

            _id_value = _graph._next_id();
            _c_op     = ops._create_c_op(g, node_def, inputs);
            var num_outputs = c_api.TF_OperationNumOutputs(_c_op);

            _outputs = new Tensor[num_outputs];
            for (int i = 0; i < num_outputs; i++)
            {
                // Honour the caller-supplied dtype instead of labelling every output as float.
                var dtype = output_types != null && i < output_types.Length
                    ? output_types[i]
                    : TF_DataType.TF_FLOAT;
                _outputs[i] = new Tensor(this, i, dtype);
            }

            _graph._add_op(this);
        }
Ejemplo n.º 14
0
        /// <summary>
        /// Creates an operation in graph <paramref name="g"/> from <paramref name="node_def"/>,
        /// builds one <c>Tensor</c> per output using the dtype reported by <c>OutputType</c>,
        /// and registers the op with the graph.
        /// </summary>
        /// <param name="node_def">Node definition for the native operation.</param>
        /// <param name="g">Parent graph.</param>
        /// <param name="inputs">Input tensors passed to <c>ops._create_c_op</c>.</param>
        /// <param name="output_types">Overwritten with the dtypes read back from the operation.</param>
        /// <param name="control_inputs">Not used by this constructor.</param>
        /// <param name="input_types">Not used by this constructor.</param>
        /// <param name="original_op">Not used by this constructor.</param>
        /// <param name="op_def">Looked up from the graph when null; not otherwise consumed here.</param>
        public Operation(NodeDef node_def, Graph g, List <Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            Graph     = g;
            _id_value = Graph._next_id();

            // The resolved definition is not used below; the lookup itself may act as
            // validation of node_def.Op (NOTE(review): confirm before removing).
            if (op_def is null)
            {
                op_def = g.GetOpDef(node_def.Op);
            }

            _handle = ops._create_c_op(g, node_def, inputs);

            int outputCount = NumOutputs;
            _outputs     = new Tensor[outputCount];
            output_types = new TF_DataType[outputCount];

            for (var index = 0; index < outputCount; index++)
            {
                var dtype = OutputType(index);
                output_types[index] = dtype;
                _outputs[index]     = new Tensor(this, index, dtype);
            }

            Graph._add_op(this);
        }
Ejemplo n.º 15
0
        /// <summary>
        /// Infers attribute values implied by one input argument and appends the argument's
        /// types to <paramref name="input_types"/>.
        /// </summary>
        /// <remarks>
        /// <list type="bullet">
        /// <item>
        /// NumberAttr argument: when the count attr is not yet in <paramref name="attrs"/>, it is
        /// set to the length of <paramref name="values"/> (viewed as <c>Tensor[]</c>) and checked
        /// against the declared minimum.
        /// </item>
        /// <item>
        /// NumberAttr argument with no fixed <c>Type</c>: the type attr is set to the first base type.
        /// </item>
        /// <item>
        /// TypeAttr / TypeListAttr arguments: the attr is set to the first base type / all base
        /// types unless already present.
        /// </item>
        /// </list>
        /// Reference arguments contribute <paramref name="types"/>; others contribute
        /// <paramref name="base_types"/>.
        /// </remarks>
        /// <exception cref="ValueError">The list argument is shorter than the declared minimum length.</exception>
        private void SetAttrs(string op_type_name,
                              ArgDef input_arg,
                              OpDef op_def,
                              Dictionary <string, object> attrs,
                              Dictionary <string, object> inferred_from,
                              List <TF_DataType> types,
                              List <TF_DataType> base_types,
                              List <TF_DataType> input_types,
                              dynamic values)
        {
            var argName = input_arg.Name;

            if (!string.IsNullOrEmpty(input_arg.NumberAttr))
            {
                var numberAttr = input_arg.NumberAttr;

                if (!attrs.ContainsKey(numberAttr))
                {
                    var tensors = values as Tensor[];

                    attrs[numberAttr]         = tensors.Length;
                    inferred_from[numberAttr] = argName;

                    var countDef = op_def.Attr.First(x => x.Name == input_arg.NumberAttr);
                    if (countDef.HasMinimum && tensors.Length < countDef.Minimum)
                    {
                        throw new ValueError($"List argument '{argName}' to '{op_type_name}' Op with length {tensors.Length} shorter " +
                                             $"than minimum length {countDef.Minimum}");
                    }
                }

                // Only the inference half of "all tensors share one base type" is done here:
                // the type attr takes base_types[0]; other entries are not compared.
                if (input_arg.Type == DataType.DtInvalid)
                {
                    attrs[input_arg.TypeAttr]         = base_types[0];
                    inferred_from[input_arg.TypeAttr] = argName;

                    // Result unused; First() still throws if the type attr is not declared.
                    _ = op_def.Attr.First(x => x.Name == input_arg.TypeAttr);
                }
            }
            else if (!string.IsNullOrEmpty(input_arg.TypeAttr))
            {
                // Read unconditionally, so an empty base_types throws even when the attr is set.
                var inferredType = base_types[0];
                if (!attrs.ContainsKey(input_arg.TypeAttr))
                {
                    attrs[input_arg.TypeAttr]         = inferredType;
                    inferred_from[input_arg.TypeAttr] = argName;
                }
            }
            else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
            {
                if (!attrs.ContainsKey(input_arg.TypeListAttr))
                {
                    attrs[input_arg.TypeListAttr]         = base_types;
                    inferred_from[input_arg.TypeListAttr] = argName;
                }
            }

            input_types.AddRange(input_arg.IsRef ? types : base_types);
        }
Ejemplo n.º 16
0
        /// <summary>
        /// Converts <paramref name="value"/> into an <c>AttrValue</c> proto according to the
        /// attribute type declared by <paramref name="attr_def"/>.
        /// </summary>
        /// <param name="op_def">Op definition owning the attribute; used in error messages.</param>
        /// <param name="attr_def">Attribute definition whose <c>Type</c> selects the conversion.</param>
        /// <param name="value">
        /// Value to convert. For "int", any value accepted by <c>Convert.ToInt64</c> is allowed.
        /// For "list(int)", <c>int[]</c> or <c>long[]</c> are accepted, and null leaves the list empty.
        /// </param>
        /// <returns>The populated <c>AttrValue</c>.</returns>
        /// <exception cref="ValueError">An "int" value is below the attribute's declared minimum.</exception>
        /// <exception cref="TypeError">The attribute type is not supported.</exception>
        private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value)
        {
            var attr_value = new AttrValue();

            // List-typed attrs receive an empty ListValue that the cases below append to.
            if (attr_def.Type.StartsWith("list(", StringComparison.Ordinal))
            {
                attr_value.List = new AttrValue.Types.ListValue();
            }

            switch (attr_def.Type)
            {
            case "string":
                attr_value.S = Google.Protobuf.ByteString.CopyFromUtf8((string)value);
                break;

            case "type":
                attr_value.Type = _MakeType((TF_DataType)value, attr_def);
                break;

            case "list(type)":
                attr_value.List.Type.AddRange((value as IList <TF_DataType>).Select(x => _MakeType(x, attr_def)));
                break;

            case "list(int)":
                // long[] is stored directly; int[] is widened; null leaves the list empty.
                if (value is long[] longs)
                {
                    attr_value.List.I.AddRange(longs);
                }
                else if (value != null)
                {
                    attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x)));
                }
                break;

            case "bool":
                attr_value.B = (bool)value;
                break;

            case "float":
                attr_value.F = (float)value;
                break;

            case "int":
                // A direct (int) unbox would throw InvalidCastException for a boxed long.
                if (value is long value_long)
                {
                    attr_value.I = value_long;
                }
                else
                {
                    attr_value.I = Convert.ToInt64(value);
                }
                if (attr_def.HasMinimum && attr_value.I < attr_def.Minimum)
                {
                    throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}.");
                }
                break;

            case "shape":
                if (value == null && attr_def.DefaultValue != null)
                {
                    attr_value.Shape = attr_def.DefaultValue.Shape;
                }

                if (value is TensorShape val1)
                {
                    attr_value.Shape = val1.as_proto();
                }
                else if (value is long[] val2)
                {
                    attr_value.Shape = tensor_util.as_shape(val2);
                }
                else if (value is int[] val3)
                {
                    attr_value.Shape = tensor_util.as_shape(val3);
                }

                break;

            default:
                throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
            }

            return(attr_value);
        }
Ejemplo n.º 17
0
        /// <summary>
        /// Converts <paramref name="value"/> into an <c>AttrValue</c> proto according to the
        /// attribute type declared by <paramref name="attr_def"/>.
        /// </summary>
        /// <param name="op_def">Op definition owning the attribute; used in error messages.</param>
        /// <param name="attr_def">Attribute definition whose <c>Type</c> selects the conversion.</param>
        /// <param name="value">
        /// Value to convert. For "int", any value accepted by <c>Convert.ToInt64</c> is allowed.
        /// For "list(int)", <c>int[]</c> or <c>long[]</c> are accepted, and null leaves the list empty.
        /// For "shape", a null value falls back to the attribute's default shape, if any.
        /// </param>
        /// <returns>The populated <c>AttrValue</c>.</returns>
        /// <exception cref="ValueError">An "int" value is below the attribute's declared minimum.</exception>
        /// <exception cref="TypeError">The attribute type is not supported.</exception>
        private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value)
        {
            var attr_value = new AttrValue();

            // List-typed attrs receive an empty ListValue that the cases below append to.
            if (attr_def.Type.StartsWith("list(", StringComparison.Ordinal))
            {
                attr_value.List = new AttrValue.Types.ListValue();
            }

            switch (attr_def.Type)
            {
            case "string":
                attr_value.S = _MakeStr((string)value, attr_def);
                break;

            case "type":
                attr_value.Type = _MakeType((TF_DataType)value, attr_def);
                break;

            case "list(type)":
                attr_value.List.Type.AddRange((value as IList <TF_DataType>).Select(x => _MakeType(x, attr_def)));
                break;

            case "list(int)":
                // long[] is stored directly; int[] is widened; null leaves the list empty.
                if (value is long[] longs)
                {
                    attr_value.List.I.AddRange(longs);
                }
                else if (value != null)
                {
                    attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x)));
                }
                break;

            case "bool":
                attr_value.B = (bool)value;
                break;

            case "float":
                attr_value.F = (float)value;
                break;

            case "int":
                if (value is long value_long)
                {
                    attr_value.I = value_long;
                }
                else
                {
                    attr_value.I = Convert.ToInt64(value);
                }
                if (attr_def.HasMinimum && attr_value.I < attr_def.Minimum)
                {
                    throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}.");
                }
                break;

            case "shape":
                if (value == null && attr_def.DefaultValue != null)
                {
                    attr_value.Shape = attr_def.DefaultValue.Shape;
                }

                if (value is Shape val1)
                {
                    attr_value.Shape = val1.as_proto();
                }
                else if (value is long[] val2)
                {
                    attr_value.Shape = tensor_util.as_shape(val2);
                }
                else if (value is int[] val3)
                {
                    attr_value.Shape = tensor_util.as_shape(val3);
                }

                break;

            case "list(shape)":
                attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def)));
                break;

            default:
                throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
            }

            return(attr_value);
        }