/// <summary>
/// Create a <see cref="TensorProto"/> from arbitrary values; invoked in graph mode.
/// </summary>
/// <param name="values">
/// A scalar, CLR array, <see cref="NDArray"/>, string / string[] / byte[] (for TF_STRING),
/// <see cref="Axis"/>, <see cref="Shape"/>, or an existing <see cref="TensorProto"/>
/// (which is returned unchanged).
/// </param>
/// <param name="dtype">
/// Requested element type; <see cref="TF_DataType.DtInvalid"/> means "infer from <paramref name="values"/>".
/// </param>
/// <param name="shape">Tensor shape; <c>null</c> means "infer from <paramref name="values"/>".</param>
/// <param name="verify_shape">Only checked for mutual exclusion with <paramref name="allow_broadcast"/>; not otherwise used here.</param>
/// <param name="allow_broadcast">Must not be combined with <paramref name="verify_shape"/>; not otherwise used here.</param>
/// <returns>The populated <see cref="TensorProto"/>.</returns>
/// <exception cref="ValueError">
/// Thrown when both <paramref name="allow_broadcast"/> and <paramref name="verify_shape"/> are set.
/// </exception>
public static TensorProto make_tensor_proto(object values, TF_DataType dtype = TF_DataType.DtInvalid, Shape? shape = null, bool verify_shape = false, bool allow_broadcast = false)
{
    if (allow_broadcast && verify_shape)
    {
        throw new ValueError("allow_broadcast and verify_shape are not both allowed.");
    }

    // An already-built proto passes through untouched.
    if (values is TensorProto tp)
    {
        return tp;
    }

    var origin_dtype = values.GetDataType();
    if (dtype == TF_DataType.DtInvalid)
    {
        dtype = origin_dtype;
    }
    else if (origin_dtype != dtype)
    {
        // Coerce values toward the requested dtype.
        var new_system_dtype = dtype.as_system_dtype();
        if (values is long[] long_values)
        {
            // NOTE(review): for long[] only the TF_INT32 target is handled; any other
            // requested dtype leaves the array as-is, and dtype is re-derived from the
            // (unconverted) values below — confirm this silent fallback is intended.
            if (dtype == TF_DataType.TF_INT32)
            {
                values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
            }
        }
        else
        {
            values = Convert.ChangeType(values, new_system_dtype);
        }

        // Re-derive the dtype from whatever the conversion actually produced.
        dtype = values.GetDataType();
    }

    shape = shape ?? values.GetShape();

    var tensor_proto = new TensorProto
    {
        Dtype = dtype.as_datatype_enum(),
        TensorShape = shape.as_shape_proto()
    };

    if (values is NDArray nd)
    {
        if (nd.shape.IsScalar)
        {
            // Scalars are stored in the typed repeated fields of the proto.
            switch (nd.dtype)
            {
                case TF_DataType.TF_BOOL:
                    tensor_proto.BoolVal.AddRange(nd.ToArray<bool>());
                    break;
                case TF_DataType.TF_UINT8:
                    // Protobuf has no uint8 field; widen to IntVal.
                    tensor_proto.IntVal.AddRange(nd.ToArray<byte>().Select(x => (int)x).ToArray());
                    break;
                case TF_DataType.TF_INT32:
                    tensor_proto.IntVal.AddRange(nd.ToArray<int>());
                    break;
                case TF_DataType.TF_INT64:
                    tensor_proto.Int64Val.AddRange(nd.ToArray<long>());
                    break;
                case TF_DataType.TF_FLOAT:
                    tensor_proto.FloatVal.AddRange(nd.ToArray<float>());
                    break;
                case TF_DataType.TF_DOUBLE:
                    tensor_proto.DoubleVal.AddRange(nd.ToArray<double>());
                    break;
                default:
                    throw new NotImplementedException("make_tensor_proto Not Implemented");
            }
        }
        else
        {
            // Non-scalar NDArray: serialize the raw buffer into TensorContent.
            // (Fixed: removed an unused `len = nd.dtypesize * nd.size` local.)
            tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(nd.ToByteArray());
        }
    }
    else if (dtype == TF_DataType.TF_STRING)
    {
        // NDArray values were consumed by the branch above, so no extra
        // `!(values is NDArray)` guard is needed here.
        if (values is string str)
        {
            tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFromUtf8(str));
        }
        else if (values is string[] str_values)
        {
            tensor_proto.StringVal.AddRange(str_values.Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x)));
        }
        else if (values is byte[] byte_values)
        {
            tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(byte_values);
        }
        // NOTE(review): any other TF_STRING payload silently yields an empty proto —
        // preserved as-is, but confirm whether this should throw instead.
    }
    else if (values is Array array)
    {
        // Flat CLR array: block-copy the raw element bytes into TensorContent.
        var len = dtype.get_datatype_size() * (int)shape.size;
        byte[] bytes = new byte[len];
        System.Buffer.BlockCopy(array, 0, bytes, 0, len);
        tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
    }
    else
    {
        // Remaining scalar-ish values go into the typed repeated fields.
        switch (values)
        {
            case Axis val:
                tensor_proto.IntVal.AddRange(val.axis);
                break;
            case Shape val:
                tensor_proto.Int64Val.AddRange(val.dims);
                break;
            case bool val:
                tensor_proto.BoolVal.AddRange(new[] { val });
                break;
            case sbyte val:
                tensor_proto.IntVal.AddRange(new[] { (int)val });
                break;
            case int val:
                tensor_proto.IntVal.AddRange(new[] { val });
                break;
            case long val:
                tensor_proto.Int64Val.AddRange(new[] { val });
                break;
            case float val:
                tensor_proto.FloatVal.AddRange(new[] { val });
                break;
            case double val:
                tensor_proto.DoubleVal.AddRange(new[] { val });
                break;
            default:
                throw new NotImplementedException("make_tensor_proto Not Implemented");
        }
    }

    return tensor_proto;
}