/// <summary>
/// Applies a flat list of alternating attribute (name, value) pairs to an eager op.
/// Each name's registered kind is queried from the runtime to decide whether the
/// value is set as a list attr or a scalar attr.
/// </summary>
/// <param name="op">The eager op handle to receive the attributes.</param>
/// <param name="attrs">Alternating name/value pairs: [name0, value0, name1, value1, ...].</param>
public void SetOpAttrs(SafeEagerOpHandle op, params object[] attrs)
{
    var status = tf.Status;

    for (int idx = 0; idx < attrs.Length; idx += 2)
    {
        var attrName = attrs[idx].ToString();
        var attrValue = attrs[idx + 1];

        byte is_list = 0;
        var attrType = c_api.TFE_OpGetAttrType(op, attrName, ref is_list, status.Handle);
        // An unknown attr name leaves an error in status; bail out silently as before.
        if (!status.ok())
            return;

        // No attr-size bookkeeping is needed on this path, hence the null dictionary.
        if (is_list != 0)
            SetOpAttrList(tf.Context, op, attrName, attrValue as object[], attrType, null, status);
        else
            SetOpAttrScalar(tf.Context, op, attrName, attrValue, attrType, null, status);

        status.Check(true);
    }
}
/// <summary>
/// This function will set the op attrs required. If an attr has the value of
/// None, then it will read the AttrDef to get the default value and set that
/// instead. Any failure in this function will simply fall back to the slow
/// path.
/// </summary>
/// <param name="ctx">Eager execution context.</param>
/// <param name="op">The eager op handle to receive the attribute.</param>
/// <param name="attr">Attribute definition from the op registry.</param>
/// <param name="attr_name">Name of the attribute to set.</param>
/// <param name="attr_value">Value to set, or null to fall back to the default.</param>
/// <param name="attr_list_sizes">Receives list-attr lengths, keyed by attr name.</param>
/// <param name="status">Receives any error from the native runtime.</param>
void SetOpAttrWithDefaults(Context ctx, SafeEagerOpHandle op, AttrDef attr,
    string attr_name, object attr_value,
    Dictionary<string, long> attr_list_sizes, Status status)
{
    byte is_list = 0;
    var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status.Handle);
    if (status.Code != TF_Code.TF_OK)
        return;

    // NOTE(review): the summary above promises that a null attr_value is replaced
    // by the AttrDef default, but that branch is currently a no-op — nothing is
    // read from `attr`. Confirm whether default substitution was ever implemented.
    if (attr_value == null)
        return;

    if (is_list != 0)
        SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes, status);
    else
        SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, status);
}
/// <summary>
/// Adds input and type attr to the op, and to the list of flattened
/// inputs/attrs.
/// </summary>
/// <param name="inputs">Value to convert to a tensor and attach as an input.</param>
/// <param name="add_type_attr">Whether to also record the tensor's dtype as a type attr.</param>
/// <param name="input_arg">Op-def argument describing the input (supplies the type-attr name).</param>
/// <param name="flattened_attrs">Accumulates (attr name, dtype) pairs for the fast path.</param>
/// <param name="flattened_inputs">Accumulates the converted input tensors.</param>
/// <param name="op">The eager op handle being built.</param>
/// <param name="status">Receives any error from the native runtime.</param>
/// <returns>Always true; failures surface via <paramref name="status"/>.Check.</returns>
bool AddInputToOp(object inputs, bool add_type_attr, ArgDef input_arg,
    List<object> flattened_attrs, List<Tensor> flattened_inputs,
    SafeEagerOpHandle op, Status status)
{
    var inputTensor = tf.convert_to_tensor(inputs);
    flattened_inputs.Add(inputTensor);

    if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
    {
        var tensorDtype = inputTensor.dtype;
        c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, tensorDtype);
        flattened_attrs.Add(input_arg.TypeAttr);
        flattened_attrs.Add(tensorDtype);
    }

    c_api.TFE_OpAddInput(op, inputTensor.EagerTensorHandle, status.Handle);
    status.Check(true);

    return true;
}
/// <summary>
/// Sets a single scalar (non-list) attribute on an eager op, converting
/// <paramref name="value"/> to the native representation required by
/// <paramref name="type"/>.
/// </summary>
/// <param name="ctx">Eager execution context (unused here but kept for signature parity).</param>
/// <param name="op">The eager op handle to receive the attribute.</param>
/// <param name="key">Attribute name.</param>
/// <param name="value">Attribute value; its expected CLR type depends on <paramref name="type"/>.</param>
/// <param name="type">The attribute kind reported by the runtime.</param>
/// <param name="attr_list_sizes">Optional map receiving int attr values, keyed by attr name.</param>
/// <param name="status">Receives any error from the native runtime (shape case only).</param>
/// <returns>Always true; unsupported types throw instead.</returns>
/// <exception cref="NotImplementedException">The attr type is not supported.</exception>
bool SetOpAttrScalar(Context ctx, SafeEagerOpHandle op, string key, object value,
    TF_AttrType type, Dictionary<string, long> attr_list_sizes, Status status)
{
    switch (type)
    {
        case TF_AttrType.TF_ATTR_STRING:
            // Materialize the string once instead of calling ToString() twice.
            // NOTE(review): length is in UTF-16 chars; confirm the native binding
            // expects a byte count for non-ASCII strings.
            var str = value.ToString();
            c_api.TFE_OpSetAttrString(op, key, str, (ulong)str.Length);
            break;
        case TF_AttrType.TF_ATTR_TYPE:
            c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value);
            break;
        case TF_AttrType.TF_ATTR_BOOL:
            c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value));
            break;
        case TF_AttrType.TF_ATTR_INT:
            var size = Convert.ToInt64(value);
            c_api.TFE_OpSetAttrInt(op, key, size);
            // Record the value so list attrs whose length is tied to this int
            // can be resolved later.
            if (attr_list_sizes != null)
                attr_list_sizes[key] = size;
            break;
        case TF_AttrType.TF_ATTR_FLOAT:
            c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
            break;
        case TF_AttrType.TF_ATTR_SHAPE:
            // Direct cast instead of `as` + ToArray(): a non-long[] value is a
            // caller bug, and InvalidCastException is clearer than the
            // NullReferenceException the old `as` cast produced. The native call
            // only reads the array, so no defensive copy is needed.
            var dims = (long[])value;
            c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
            status.Check(true);
            break;
        case TF_AttrType.TF_ATTR_FUNC:
            if (value is ConcreteFunction func)
                c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length);
            else
                throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
            break;
        default:
            throw new NotImplementedException($"SetOpAttrScalar for {type}");
    }

    return true;
}
/// <summary>
/// Sets a list-valued attribute on an eager op, converting the CLR array in
/// <paramref name="values"/> to the native representation required by
/// <paramref name="type"/>.
/// </summary>
/// <param name="ctx">Eager execution context (unused here but kept for signature parity).</param>
/// <param name="op">The eager op handle to receive the attribute.</param>
/// <param name="key">Attribute name.</param>
/// <param name="values">The attribute values; expected to be string[], Shape[], TF_DataType[] or int[].</param>
/// <param name="type">The attribute kind reported by the runtime.</param>
/// <param name="attr_list_sizes">Optional map receiving each list attr's length, keyed by attr name.
/// May be null (e.g. when called from SetOpAttrs).</param>
/// <param name="status">Receives any error from the native runtime (shape case only).</param>
/// <returns>Always true; unsupported type/value combinations throw instead.</returns>
/// <exception cref="NotImplementedException">The type/value combination is not supported.</exception>
bool SetOpAttrList(Context ctx, SafeEagerOpHandle op, string key, object values,
    TF_AttrType type, Dictionary<string, long> attr_list_sizes, Status status)
{
    if (type == TF_AttrType.TF_ATTR_STRING && values is string[] strings)
    {
        c_api.TFE_OpSetAttrStringList(op, key, strings,
            strings.Select(x => Convert.ToUInt64(x.Length)).ToArray(), strings.Length);
        // Guard: SetOpAttrs passes a null dictionary; the old unconditional
        // indexer assignment threw NullReferenceException on that path.
        if (attr_list_sizes != null)
            attr_list_sizes[key] = strings.Length;
    }
    else if (type == TF_AttrType.TF_ATTR_SHAPE && values is Shape[] shapes)
    {
        // Make one pass through the input counting the total number of
        // dims across all the input lists.
        var num_values = shapes.Length;
        if (attr_list_sizes != null)
            attr_list_sizes[key] = num_values;

        var dims = new IntPtr[num_values];
        var num_dims = shapes.Select(x => x.ndim).ToArray();
        try
        {
            for (int i = 0; i < num_values; ++i)
            {
                dims[i] = Marshal.AllocHGlobal(sizeof(long) * shapes[i].ndim);
                tf.memcpy(dims[i], shapes[i].dims, shapes[i].ndim * sizeof(long));
            }
            c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
        }
        finally
        {
            // Free even if memcpy or the native call throws; the old code leaked
            // the HGlobal buffers on any exception. FreeHGlobal(IntPtr.Zero) is a
            // no-op, so partially-filled arrays are safe.
            foreach (var ptr in dims)
                Marshal.FreeHGlobal(ptr);
        }
    }
    else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] dtypes)
    {
        c_api.TFE_OpSetAttrTypeList(op, key, dtypes, dtypes.Length);
        if (attr_list_sizes != null)
            attr_list_sizes[key] = dtypes.Length;
    }
    else if (type == TF_AttrType.TF_ATTR_INT && values is int[] ints)
    {
        c_api.TFE_OpSetAttrIntList(op, key,
            ints.Select(x => Convert.ToInt64(x)).ToArray(), ints.Length);
        if (attr_list_sizes != null)
            attr_list_sizes[key] = ints.Length;
    }
    else
    {
        // Include the actual type/value in the message — the old empty-string
        // message gave no diagnostic at all.
        throw new NotImplementedException(
            $"SetOpAttrList for {type} with values of type {values?.GetType().ToString() ?? "null"}");
    }

    return true;
}