/// <summary>
/// Sets a single (non-list) attribute named <paramref name="key"/> on the raw eager op handle
/// <paramref name="op"/>, dispatching on the declared attribute <paramref name="type"/>.
/// </summary>
/// <param name="ctx">Eager context; not used by this method.</param>
/// <param name="op">Native handle of the op being configured.</param>
/// <param name="key">Attribute name.</param>
/// <param name="value">Attribute value; accepted CLR types depend on <paramref name="type"/>
/// (shape: <c>int[]</c>, <c>long[]</c>, or <c>null</c> for a shape of unknown rank).</param>
/// <param name="type">Declared TensorFlow attribute type.</param>
/// <param name="attr_list_sizes">Optional map (may be <c>null</c>) that receives the value of every
/// integer attribute, keyed by attribute name.</param>
/// <param name="status">Status passed to the native shape call and checked right after it.</param>
/// <returns>Always <c>true</c>.</returns>
/// <exception cref="NotImplementedException">The attribute type is not supported.</exception>
/// <exception cref="ArgumentException">A shape value has an unsupported CLR type.</exception>
private static bool SetOpAttrScalar(Context ctx, IntPtr op,
                                            string key, object value, TF_AttrType type,
                                            Dictionary <string, long> attr_list_sizes,
                                            Status status)
        {
            switch (type)
            {
            case TF_AttrType.TF_ATTR_STRING:
            {
                var text = value.ToString();
                // NOTE(review): the length passed is the UTF-16 char count; if the native side expects
                // a byte count, non-ASCII values get truncated — confirm against the P/Invoke marshalling.
                c_api.TFE_OpSetAttrString(op, key, text, (uint)text.Length);
                break;
            }

            case TF_AttrType.TF_ATTR_TYPE:
                c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value);
                break;

            case TF_AttrType.TF_ATTR_BOOL:
                c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value));
                break;

            case TF_AttrType.TF_ATTR_INT:
            {
                var intValue = Convert.ToInt64(value);
                c_api.TFE_OpSetAttrInt(op, key, intValue);
                // Record integer attrs by name, presumably so callers can size list-typed inputs
                // from attributes such as "N".
                if (attr_list_sizes != null)
                {
                    attr_list_sizes[key] = intValue;
                }
                break;
            }

            case TF_AttrType.TF_ATTR_SHAPE:
            {
                long[] dims;
                if (value == null)
                {
                    // Unknown rank: the TF C API takes a null dims pointer together with num_dims = -1.
                    dims = null;
                }
                else if (value is long[] longDims)
                {
                    dims = longDims;
                }
                else if (value is int[] intDims)
                {
                    dims = intDims.Select(x => (long)x).ToArray();
                }
                else
                {
                    throw new ArgumentException(
                        $"Unsupported value of type {value.GetType()} for shape attribute '{key}'.",
                        nameof(value));
                }

                c_api.TFE_OpSetAttrShape(op, key, dims, dims == null ? -1 : dims.Length, status);
                status.Check(true);
                break;
            }

            default:
                throw new NotImplementedException($"SetOpAttrScalar for {type}");
            }

            return(true);
        }
 /// <summary>
 /// Placeholder for setting a list-valued attribute on a raw op handle. No list types are
 /// handled: every argument is ignored and <paramref name="op"/> is left untouched.
 /// </summary>
 /// <returns>Always <c>false</c> (presumably read by callers as "not set" — confirm).</returns>
 private static bool SetOpAttrList(Context ctx, IntPtr op,
                                   string key, object value, TF_AttrType type,
                                   Dictionary <string, long> attr_list_sizes,
                                   Status status) => false;
Exemple #3
0
 /// <summary>
 /// Stub for setting a list-valued attribute on <paramref name="op"/>: nothing is sent to the
 /// native op and none of the arguments are read.
 /// </summary>
 /// <returns>Always <c>false</c>.</returns>
 bool SetOpAttrList(Context ctx, SafeOpHandle op,
                    string key, object value, TF_AttrType type,
                    Dictionary <string, long> attr_list_sizes,
                    Status status) => false;
        /// <summary>
        /// Sets a single (non-list) attribute named <paramref name="key"/> on the eager op
        /// <paramref name="op"/>, dispatching on the declared attribute <paramref name="type"/>.
        /// </summary>
        /// <param name="ctx">Eager context; not used by this method.</param>
        /// <param name="op">Handle of the op being configured.</param>
        /// <param name="key">Attribute name.</param>
        /// <param name="value">Attribute value; accepted CLR types depend on <paramref name="type"/>
        /// (shape: <c>int[]</c>, <c>long[]</c>, or <c>null</c> for unknown rank; func: a
        /// <see cref="ConcreteFunction"/> or a function-name <c>string</c>).</param>
        /// <param name="type">Declared TensorFlow attribute type.</param>
        /// <param name="attr_list_sizes">Optional map (may be <c>null</c>) that receives the value of
        /// every integer attribute, keyed by attribute name.</param>
        /// <param name="status">Status whose native handle is passed to the shape call; it is checked
        /// right after that call.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="NotImplementedException">Unsupported attribute type, or an unsupported func value.</exception>
        /// <exception cref="ArgumentException">A shape value has an unsupported CLR type.</exception>
        bool SetOpAttrScalar(Context ctx, SafeOpHandle op,
                             string key, object value, TF_AttrType type,
                             Dictionary <string, long> attr_list_sizes,
                             Status status)
        {
            switch (type)
            {
            case TF_AttrType.TF_ATTR_STRING:
            {
                var text = value.ToString();
                // NOTE(review): the length passed is the UTF-16 char count; if the native side expects
                // a byte count, non-ASCII values get truncated — confirm against the P/Invoke marshalling.
                c_api.TFE_OpSetAttrString(op, key, text, (uint)text.Length);
                break;
            }

            case TF_AttrType.TF_ATTR_TYPE:
                c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value);
                break;

            case TF_AttrType.TF_ATTR_BOOL:
                c_api.TFE_OpSetAttrBool(op, key, Convert.ToBoolean(value));
                break;

            case TF_AttrType.TF_ATTR_INT:
            {
                var size = Convert.ToInt64(value);
                c_api.TFE_OpSetAttrInt(op, key, size);
                // Record integer attrs by name, presumably so callers can size list-typed inputs
                // from attributes such as "N".
                if (attr_list_sizes != null)
                {
                    attr_list_sizes[key] = size;
                }
                break;
            }

            case TF_AttrType.TF_ATTR_FLOAT:
                c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
                break;

            case TF_AttrType.TF_ATTR_SHAPE:
            {
                long[] dims;
                if (value == null)
                {
                    // Unknown rank: the TF C API takes a null dims pointer together with num_dims = -1.
                    dims = null;
                }
                else if (value is long[] longDims)
                {
                    dims = longDims;
                }
                else if (value is int[] intDims)
                {
                    dims = intDims.Select(x => (long)x).ToArray();
                }
                else
                {
                    throw new ArgumentException(
                        $"Unsupported value of type {value.GetType()} for shape attribute '{key}'.",
                        nameof(value));
                }

                c_api.TFE_OpSetAttrShape(op, key, dims, dims == null ? -1 : dims.Length, status.Handle);
                status.Check(true);
                break;
            }

            case TF_AttrType.TF_ATTR_FUNC:
                if (value is ConcreteFunction func)
                {
                    c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length);
                }
                else if (value is string funcName)
                {
                    // A bare function name is also accepted, as in TensorFlow's Python eager runtime.
                    c_api.TFE_OpSetAttrFunctionName(op, key, funcName, funcName.Length);
                }
                else
                {
                    throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
                }
                break;

            default:
                throw new NotImplementedException($"SetOpAttrScalar for {type}");
            }

            return(true);
        }
        /// <summary>
        /// Sets a list-valued attribute named <paramref name="key"/> on the eager op
        /// <paramref name="op"/> and records the list length in <paramref name="attr_list_sizes"/>.
        /// </summary>
        /// <param name="ctx">Eager context; not used by this method.</param>
        /// <param name="op">Handle of the op being configured.</param>
        /// <param name="key">Attribute name.</param>
        /// <param name="values">List value: <c>string[]</c>, <c>TensorShape[]</c>, <c>TF_DataType[]</c>,
        /// or <c>int[]</c>/<c>long[]</c>, matching <paramref name="type"/>.</param>
        /// <param name="type">Declared element type of the list attribute.</param>
        /// <param name="attr_list_sizes">Optional map (may be <c>null</c>) that receives the list length,
        /// keyed by attribute name.</param>
        /// <param name="status">Status whose native handle is passed to the shape-list call; it is
        /// checked after that call.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="NotImplementedException">The type/value combination is not supported.</exception>
        bool SetOpAttrList(Context ctx, SafeOpHandle op,
                           string key, object values, TF_AttrType type,
                           Dictionary <string, long> attr_list_sizes,
                           Status status)
        {
            // attr_list_sizes is optional; only record sizes when the caller supplied a map.
            void RecordListSize(long size)
            {
                if (attr_list_sizes != null)
                {
                    attr_list_sizes[key] = size;
                }
            }

            if (type == TF_AttrType.TF_ATTR_STRING && values is string[] strings)
            {
                // The native call reads strings.Length pointers from this array, so each string must
                // be copied into unmanaged memory (as UTF-8) before the call and released afterwards.
                // NOTE(review): lengths are passed as int[]; confirm this matches the native size_t*.
                var buffers = new IntPtr[strings.Length];
                var lengths = new int[strings.Length];
                try
                {
                    for (int i = 0; i < strings.Length; ++i)
                    {
                        var bytes = System.Text.Encoding.UTF8.GetBytes(strings[i]);
                        // Allocate at least one byte so empty strings still get a valid pointer.
                        buffers[i] = Marshal.AllocHGlobal(Math.Max(bytes.Length, 1));
                        Marshal.Copy(bytes, 0, buffers[i], bytes.Length);
                        lengths[i] = bytes.Length;
                    }

                    c_api.TFE_OpSetAttrStringList(op, key, buffers, lengths, strings.Length);
                }
                finally
                {
                    foreach (var buffer in buffers)
                    {
                        if (buffer != IntPtr.Zero)
                        {
                            Marshal.FreeHGlobal(buffer);
                        }
                    }
                }
                RecordListSize(strings.Length);
            }
            else if (type == TF_AttrType.TF_ATTR_SHAPE && values is TensorShape[] shapes)
            {
                var num_values = shapes.Length;
                var dims     = new IntPtr[num_values];
                var num_dims = new int[num_values];
                try
                {
                    for (int i = 0; i < num_values; ++i)
                    {
                        var rank = shapes[i].ndim;
                        num_dims[i] = rank;
                        // Non-positive ranks (scalar, or negative for unknown rank) get no dimension
                        // buffer; the TF C API is assumed not to read dims[i] for them — confirm.
                        if (rank > 0)
                        {
                            var shapeDims = shapes[i].dims.Select(x => (long)x).ToArray();
                            dims[i] = Marshal.AllocHGlobal(sizeof(long) * rank);
                            Marshal.Copy(shapeDims, 0, dims[i], rank);
                        }
                    }

                    c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
                }
                finally
                {
                    // Release every per-shape buffer even if copying or the native call failed.
                    foreach (var buffer in dims)
                    {
                        if (buffer != IntPtr.Zero)
                        {
                            Marshal.FreeHGlobal(buffer);
                        }
                    }
                }
                status.Check(true);
                RecordListSize(num_values);
            }
            else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] dataTypes)
            {
                c_api.TFE_OpSetAttrTypeList(op, key, dataTypes, dataTypes.Length);
                RecordListSize(dataTypes.Length);
            }
            else if (type == TF_AttrType.TF_ATTR_INT && values is int[] ints)
            {
                c_api.TFE_OpSetAttrIntList(op, key, ints.Select(x => Convert.ToInt64(x)).ToArray(), ints.Length);
                RecordListSize(ints.Length);
            }
            else if (type == TF_AttrType.TF_ATTR_INT && values is long[] longs)
            {
                c_api.TFE_OpSetAttrIntList(op, key, longs, longs.Length);
                RecordListSize(longs.Length);
            }
            else
            {
                throw new NotImplementedException(
                    $"SetOpAttrList for {type} with value of type {(values == null ? "null" : values.GetType().Name)}");
            }

            return(true);
        }
Exemple #6
0
        /// <summary>
        /// Test helper: fetches the metadata of attribute <paramref name="attr_name"/> on
        /// <paramref name="oper"/> and expects the lookup to report <c>TF_OK</c> in <c>s_</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <paramref name="expected_list_size"/>, <paramref name="expected_type"/> and
        /// <paramref name="expected_total_size"/> are accepted but currently not compared against
        /// the returned metadata; the field-level checks are disabled below. TODO: re-enable.
        /// </remarks>
        private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size)
        {
            var metadata = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_);
            EXPECT_EQ(TF_Code.TF_OK, s_.Code);

            // Disabled field checks; is_list is expected to be 1 exactly when expected_list_size >= 0:
            // EXPECT_EQ(expected_list_size >= 0 ? (char)1 : (char)0, metadata.is_list);
            // EXPECT_EQ(expected_list_size, metadata.list_size);
            // EXPECT_EQ(expected_type, metadata.type);
            // EXPECT_EQ(expected_total_size, metadata.total_size);
        }