/// <summary>
/// Sets a single (non-list) attribute on an eager operation identified by a raw native handle.
/// </summary>
/// <param name="ctx">Eager context; not used by this method.</param>
/// <param name="op">Raw handle of the operation being configured.</param>
/// <param name="key">Attribute name.</param>
/// <param name="value">Attribute value; the accepted CLR type depends on <paramref name="type"/>.</param>
/// <param name="type">Attribute type.</param>
/// <param name="attr_list_sizes">
/// Optional; when non-null, the value of a TF_ATTR_INT attribute is recorded here under <paramref name="key"/>.
/// </param>
/// <param name="status">Status handed to native calls that report errors (shape attributes) and then checked.</param>
/// <returns>Always <c>true</c>; unsupported inputs throw instead.</returns>
/// <exception cref="ArgumentException">A TF_ATTR_SHAPE value that is neither an <c>int[]</c> nor a <c>long[]</c>.</exception>
/// <exception cref="NotImplementedException">An attribute type this overload does not handle.</exception>
private static bool SetOpAttrScalar(Context ctx, IntPtr op, string key, object value, TF_AttrType type, Dictionary<string, long> attr_list_sizes, Status status)
{
    switch (type)
    {
        case TF_AttrType.TF_ATTR_STRING:
            var text = value.ToString();
            // NOTE(review): this passes the UTF-16 char count; the native side presumably
            // expects the byte length of the marshaled string. They agree only for ASCII
            // values — confirm against the P/Invoke declaration's string marshaling.
            c_api.TFE_OpSetAttrString(op, key, text, (uint)text.Length);
            break;
        case TF_AttrType.TF_ATTR_TYPE:
            c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value);
            break;
        case TF_AttrType.TF_ATTR_BOOL:
            c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value));
            break;
        case TF_AttrType.TF_ATTR_INT:
            var intValue = Convert.ToInt64(value);
            c_api.TFE_OpSetAttrInt(op, key, intValue);
            // Integer attributes are remembered so that list inputs sized by this
            // attribute can be checked/constructed later by the caller.
            if (attr_list_sizes != null)
            {
                attr_list_sizes[key] = intValue;
            }
            break;
        case TF_AttrType.TF_ATTR_SHAPE:
            long[] dims;
            if (value is long[] longDims)
            {
                dims = longDims;
            }
            else if (value is int[] intDims)
            {
                dims = intDims.Select(x => (long)x).ToArray();
            }
            else
            {
                // Previously a non-int[] value surfaced as an opaque ArgumentNullException from LINQ.
                throw new ArgumentException($"Attribute '{key}' of type {type} expects an int[] or long[] value.", nameof(value));
            }
            c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
            status.Check(true);
            break;
        default:
            throw new NotImplementedException($"SetOpAttrScalar for {type}");
    }
    return true;
}
/// <summary>
/// List-attribute setter for operations addressed by a raw native handle.
/// This overload is a placeholder: it sets nothing, ignores every argument,
/// and always reports <c>false</c>.
/// </summary>
/// <returns>Always <c>false</c>.</returns>
private static bool SetOpAttrList(Context ctx, IntPtr op, string key, object value, TF_AttrType type, Dictionary<string, long> attr_list_sizes, Status status) => false;
/// <summary>
/// Placeholder list-attribute setter for <see cref="SafeOpHandle"/> operations:
/// it sets nothing, ignores every argument, and always reports <c>false</c>.
/// </summary>
/// <remarks>
/// NOTE(review): this source also contains a fully implemented SetOpAttrList with
/// exactly this parameter list; both cannot live in the same type — confirm which one is live.
/// </remarks>
/// <returns>Always <c>false</c>.</returns>
bool SetOpAttrList(Context ctx, SafeOpHandle op, string key, object value, TF_AttrType type, Dictionary<string, long> attr_list_sizes, Status status) => false;
/// <summary>
/// Sets a single (non-list) attribute on an eager operation.
/// </summary>
/// <param name="ctx">Eager context; not used by this method.</param>
/// <param name="op">Handle of the operation being configured.</param>
/// <param name="key">Attribute name.</param>
/// <param name="value">Attribute value; the accepted CLR type depends on <paramref name="type"/>.</param>
/// <param name="type">Attribute type.</param>
/// <param name="attr_list_sizes">
/// Optional; when non-null, the value of a TF_ATTR_INT attribute is recorded here under <paramref name="key"/>.
/// </param>
/// <param name="status">Status whose handle is passed to native calls that report errors (shape attributes) and then checked.</param>
/// <returns>Always <c>true</c>; unsupported inputs throw instead.</returns>
/// <exception cref="ArgumentException">A TF_ATTR_SHAPE value that is neither an <c>int[]</c> nor a <c>long[]</c>.</exception>
/// <exception cref="NotImplementedException">
/// An unsupported attribute type, or a TF_ATTR_FUNC value that is neither a
/// <see cref="ConcreteFunction"/> nor a function-name string.
/// </exception>
bool SetOpAttrScalar(Context ctx, SafeOpHandle op, string key, object value, TF_AttrType type, Dictionary<string, long> attr_list_sizes, Status status)
{
    switch (type)
    {
        case TF_AttrType.TF_ATTR_STRING:
            var text = value.ToString();
            // NOTE(review): this passes the UTF-16 char count; the native side presumably
            // expects the byte length of the marshaled string. They agree only for ASCII
            // values — confirm against the P/Invoke declaration's string marshaling.
            c_api.TFE_OpSetAttrString(op, key, text, (uint)text.Length);
            break;
        case TF_AttrType.TF_ATTR_TYPE:
            c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value);
            break;
        case TF_AttrType.TF_ATTR_BOOL:
            c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value));
            break;
        case TF_AttrType.TF_ATTR_INT:
            var size = Convert.ToInt64(value);
            c_api.TFE_OpSetAttrInt(op, key, size);
            // Integer attributes are remembered so list inputs sized by this attribute
            // can be checked/constructed later by the caller.
            if (attr_list_sizes != null)
            {
                attr_list_sizes[key] = size;
            }
            break;
        case TF_AttrType.TF_ATTR_FLOAT:
            c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
            break;
        case TF_AttrType.TF_ATTR_SHAPE:
            long[] dims;
            if (value is long[] longDims)
            {
                dims = longDims;
            }
            else if (value is int[] intDims)
            {
                dims = intDims.Select(x => (long)x).ToArray();
            }
            else
            {
                // Previously a non-int[] value surfaced as an opaque ArgumentNullException from LINQ.
                throw new ArgumentException($"Attribute '{key}' of type {type} expects an int[] or long[] value.", nameof(value));
            }
            c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
            status.Check(true);
            break;
        case TF_AttrType.TF_ATTR_FUNC:
            if (value is ConcreteFunction func)
            {
                c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length);
            }
            else if (value is string funcName)
            {
                // The TF eager API also accepts a function referenced purely by name.
                c_api.TFE_OpSetAttrFunctionName(op, key, funcName, funcName.Length);
            }
            else
            {
                throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
            }
            break;
        default:
            throw new NotImplementedException($"SetOpAttrScalar for {type}");
    }
    return true;
}
/// <summary>
/// Sets a list-valued attribute on an eager operation.
/// </summary>
/// <param name="ctx">Eager context; not used by this method.</param>
/// <param name="op">Handle of the operation being configured.</param>
/// <param name="key">Attribute name.</param>
/// <param name="values">
/// List value: <c>string[]</c> (TF_ATTR_STRING), <c>TensorShape[]</c> (TF_ATTR_SHAPE),
/// <c>TF_DataType[]</c> (TF_ATTR_TYPE), or <c>int[]</c>/<c>long[]</c> (TF_ATTR_INT).
/// </param>
/// <param name="type">Element type of the list attribute.</param>
/// <param name="attr_list_sizes">Optional; when non-null, the list length is recorded here under <paramref name="key"/>.</param>
/// <param name="status">Status whose handle is passed to the shape-list native call and then checked.</param>
/// <returns>Always <c>true</c>; unsupported inputs throw instead.</returns>
/// <exception cref="NotImplementedException">An unsupported combination of <paramref name="type"/> and value type.</exception>
bool SetOpAttrList(Context ctx, SafeOpHandle op, string key, object values, TF_AttrType type, Dictionary<string, long> attr_list_sizes, Status status)
{
    // Mirrors the scalar setter: the size map is optional.
    void RecordListSize(long listSize)
    {
        if (attr_list_sizes != null)
        {
            attr_list_sizes[key] = listSize;
        }
    }

    if (type == TF_AttrType.TF_ATTR_STRING && values is string[] strings)
    {
        // The native call reads one pointer per value, so every string must live in
        // unmanaged memory. (Passing an empty IntPtr[] here made native code read past
        // the end of the array.) Lengths are UTF-8 byte counts, not char counts.
        var buffers = new IntPtr[strings.Length];
        var lengths = new int[strings.Length];
        try
        {
            for (int i = 0; i < strings.Length; ++i)
            {
                var bytes = System.Text.Encoding.UTF8.GetBytes(strings[i]);
                lengths[i] = bytes.Length;
                // Allocate at least one byte so empty strings still get a valid pointer.
                buffers[i] = Marshal.AllocHGlobal(Math.Max(bytes.Length, 1));
                Marshal.Copy(bytes, 0, buffers[i], bytes.Length);
            }
            // NOTE(review): the TF C API declares `lengths` as `const size_t*`; confirm the
            // binding marshals this int[] with a matching element width on 64-bit targets.
            c_api.TFE_OpSetAttrStringList(op, key, buffers, lengths, strings.Length);
        }
        finally
        {
            foreach (var buffer in buffers)
            {
                if (buffer != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(buffer);
                }
            }
        }
        RecordListSize(strings.Length);
    }
    else if (type == TF_AttrType.TF_ATTR_SHAPE && values is TensorShape[] shapes)
    {
        // Each shape's dimensions are copied to an unmanaged int64 buffer; the buffers
        // are released in `finally` so nothing leaks if the copy or native call throws.
        var num_values = shapes.Length;
        var dims = new IntPtr[num_values];
        var num_dims = shapes.Select(x => x.ndim).ToArray();
        try
        {
            for (int i = 0; i < num_values; ++i)
            {
                // A non-positive ndim (scalar, or presumably -1 for unknown rank) has no
                // dimension values to copy; its pointer stays null.
                if (num_dims[i] <= 0)
                {
                    continue;
                }
                var shapeDims = shapes[i].dims.Select(x => (long)x).ToArray();
                dims[i] = Marshal.AllocHGlobal(sizeof(long) * num_dims[i]);
                Marshal.Copy(shapeDims, 0, dims[i], num_dims[i]);
            }
            c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
        }
        finally
        {
            foreach (var dim in dims)
            {
                if (dim != IntPtr.Zero)
                {
                    Marshal.FreeHGlobal(dim);
                }
            }
        }
        status.Check(true);
        RecordListSize(num_values);
    }
    else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] dtypes)
    {
        c_api.TFE_OpSetAttrTypeList(op, key, dtypes, dtypes.Length);
        RecordListSize(dtypes.Length);
    }
    else if (type == TF_AttrType.TF_ATTR_INT && values is int[] ints)
    {
        c_api.TFE_OpSetAttrIntList(op, key, ints.Select(x => (long)x).ToArray(), ints.Length);
        RecordListSize(ints.Length);
    }
    else if (type == TF_AttrType.TF_ATTR_INT && values is long[] longs)
    {
        c_api.TFE_OpSetAttrIntList(op, key, longs, longs.Length);
        RecordListSize(longs.Length);
    }
    else
    {
        throw new NotImplementedException($"SetOpAttrList for {type} with value of type {values?.GetType().Name ?? "null"}");
    }
    return true;
}
/// <summary>
/// Test helper: queries the attribute metadata for <paramref name="attr_name"/> on
/// <paramref name="oper"/> and expects the query to succeed (status code TF_OK).
/// </summary>
/// <param name="oper">Operation whose attribute is inspected.</param>
/// <param name="attr_name">Attribute name.</param>
/// <param name="expected_list_size">Expected list size (negative means "not a list"); currently not checked.</param>
/// <param name="expected_type">Expected attribute type; currently not checked.</param>
/// <param name="expected_total_size">Expected total size; currently not checked.</param>
private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size)
{
    // The metadata itself is not inspected yet, only the status of the query.
    _ = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_);
    EXPECT_EQ(TF_Code.TF_OK, s_.Code);

    // TODO(review): also compare the returned metadata once its field types are confirmed:
    //   is_list    == (expected_list_size >= 0 ? 1 : 0)
    //   list_size  == expected_list_size
    //   type       == expected_type
    //   total_size == expected_total_size
}