Example #1
0
 protected bool relaxed_agree_or_null(Type wannabe, Type t)
 {
     // Thin wrapper: delegate to the Type-level relaxed_agree_or_null check against t.
     bool result = wannabe.relaxed_agree_or_null(t);
     return result;
 }
Example #2
0
 protected bool agree_or_null(Expression expr, Type t)
 {
     // Thin wrapper: delegate to the expression-level agree_or_null check against t.
     bool result = expr.agree_or_null(t);
     return result;
 }
Example #3
0
 protected bool relaxed_agree(Type wannabe, Type t)
 {
     // Thin wrapper: ask the candidate type whether it relaxed-agrees with t.
     bool agrees = wannabe.relaxed_agree(t);
     return agrees;
 }
Example #4
0
 protected bool agree(Expression expr, Type t)
 {
     // Thin wrapper: ask the expression whether it agrees with t.
     bool agrees = expr.agree(t);
     return agrees;
 }
Example #5
0
 protected bool agree_or_null(Type wannabe, Type t)
 {
     // Thin wrapper: delegate to the Type-level agree_or_null check against t.
     bool result = wannabe.agree_or_null(t);
     return result;
 }
Example #6
0
 protected bool is_alu_or_sreg_or_null(Expression expr, Type t)
 {
     // Convenience overload: forward to the three-argument version, passing the
     // literal 0 as the Mod argument (presumably "no modifier required" — confirm).
     bool ok = is_alu_or_sreg_or_null(expr, t, 0);
     return ok;
 }
Example #7
0
 protected bool agree(Type wannabe, Type t)
 {
     // Thin wrapper: ask the candidate type whether it agrees with t.
     bool agrees = wannabe.agree(t);
     return agrees;
 }
Example #8
0
        // Builds a PTX module (PTX 1.4 / sm_13) containing a single entry "MatMulKernel"
        // in which each thread computes one element C[row, col] of C = A * B.
        // Each section of emitted PTX is preceded by a PTX comment carrying the
        // equivalent C-like statement.
        //
        // The parameters A, B and C are 12-byte, 4-aligned param blobs. The loads below read
        // offset 0 as width, offset 4 as height and offset 8 as the elements pointer.
        public static Module matmul()
        {
            // Only the types actually referenced below are declared.
            Type u16 = new Type { Name = TypeName.U16 }, u32 = new Type { Name = TypeName.U32 };
            Type f32 = new Type { Name = TypeName.F32 }, pred = new Type { Name = TypeName.Pred };

            // Numbered registers (%rh1, %r3, %f2, %p6, ...) are looked up by name, so repeated
            // calls like r(1) presumably yield the same Reg (GetOrCreate is a project extension).
            var regs = new Dictionary<String, Reg>();
            Func<String, Reg> reg_u32 = name => new Reg{Name = name, Type = u32};
            Func<String, Reg> reg_f32 = name => new Reg{Name = name, Type = f32};
            Func<int, Reg> rh = i => regs.GetOrCreate(String.Format("%rh{0}", i), name => new Reg{Name = name, Type = u16});
            Func<int, Reg> r = i => regs.GetOrCreate(String.Format("%r{0}", i), name => new Reg{Name = name, Type = u32});
            Func<int, Reg> f = i => regs.GetOrCreate(String.Format("%f{0}", i), name => new Reg{Name = name, Type = f32});
            Func<int, Reg> p = i => regs.GetOrCreate(String.Format("%p{0}", i), name => new Reg{Name = name, Type = pred});

            var module = new Module(SoftwareIsa.PTX_14, HardwareIsa.SM_13);
            Func<String, Var> param_align4_b8_12 = name => new Var{Name = name, Space = space.param, Alignment = 4, Type = new Type{Name = TypeName.B8, Dims = new []{12}}};
            Var a = param_align4_b8_12("A"), b = param_align4_b8_12("B"), c = param_align4_b8_12("C");
            var kernel = module.AddEntry("MatMulKernel", a, b, c);
            var ptx = kernel.Stmts;

            // Named registers; only those actually referenced by the emitted code are declared.
            Func<String, Label> label = name => new Label{Name = name};
            Label loop_body = label("$LoopBody"), after_loop = label("$AfterLoop"), exit = label("$Exit");
            Reg a_width = reg_u32("a_width"), a_height = reg_u32("a_height"), a_raw = reg_u32("a_raw");
            Reg b_width = reg_u32("b_width"), b_raw = reg_u32("b_raw");
            Reg c_width = reg_u32("c_width"), c_raw = reg_u32("c_raw");
            Reg row = reg_u32("row"), col = reg_u32("col"), cvalue = reg_f32("cvalue"), dim = reg_u32("dim");
            Reg a_offset = reg_u32("a_offset"), a_offset_lo = reg_u32("a_offset_lo"), a_offset_hi = reg_u32("a_offset_hi");
            Reg b_offset = reg_u32("b_offset"), b_offset_lo = reg_u32("b_offset_lo"), b_offset_stride = reg_u32("b_offset_stride");

            // Thread coordinates: 16-bit multiply-wide of ctaid * ntid, then add tid.
            ptx.Add(new Comment{Text = Environment.NewLine});
            ptx.Add(new Comment{Text = "int row = blockIdx.y * blockDim.y + threadIdx.y;"});
            ptx.Add(new Comment{Text = "int col = blockIdx.x * blockDim.x + threadIdx.x;"});
            ptx.Add(new mov{type = u16, d = rh(1), a = new ctaid().mod(Mod.X)});
            ptx.Add(new mov{type = u16, d = rh(2), a = new ntid().mod(Mod.X)});
            ptx.Add(new mul{type = u16, mode = mulm.wide, d = r(1), a = rh(1), b = rh(2)});
            ptx.Add(new mov{type = u16, d = rh(3), a = new ctaid().mod(Mod.Y)});
            ptx.Add(new mov{type = u16, d = rh(4), a = new ntid().mod(Mod.Y)});
            ptx.Add(new mul{type = u16, mode = mulm.wide, d = r(2), a = rh(3), b = rh(4)});
            ptx.Add(new cvt{dtype = u32, atype = u16, d = r(3), a = new tid().mod(Mod.X)});
            ptx.Add(new add{type = u32, d = col, a = r(3), b = r(1)});
            ptx.Add(new cvt{dtype = u32, atype = u16, d = r(5), a = new tid().mod(Mod.Y)});
            ptx.Add(new add{type = u32, d = row, a = r(5), b = r(2)});

            // Bounds check: p1 = (A.height <= row) | (B.width <= col); branch to exit if set.
            ptx.Add(new Comment{Text = Environment.NewLine});
            ptx.Add(new Comment{Text = "if (A.height <= row || B.width <= col) return;"});
            ptx.Add(new ld{ss = space.param, type = u32, d = b_width, a = b + 0});
            ptx.Add(new ld{ss = space.param, type = u32, d = a_height, a = a + 4});
            ptx.Add(new setp{cmpop = cmp.le, type = u32, p = p(6), a = a_height, b = row});
            ptx.Add(new setp{cmpop = cmp.le, type = u32, p = p(7), a = b_width, b = col});
            ptx.Add(new or{type = pred, d = p(1), a = p(6), b = p(7)});
            ptx.Add(new bra{Guard = p(1), tgt = exit});

            ptx.Add(new Comment{Text = Environment.NewLine});
            ptx.Add(new Comment{Text = "float Cvalue = 0;"});
            ptx.Add(new mov{type = f32, d = cvalue, a = (Const)0f});

            // Loop entry test: setp with a predicate couple writes p2 = (A.width <= dim) and
            // p8 = its negation; @!p8 skips the loop entirely when A.width <= 0.
            ptx.Add(new Comment{Text = Environment.NewLine});
            ptx.Add(new Comment{Text = "for (int dim = 0; dim < A.width; ++dim)"});
            ptx.Add(new ld{ss = space.param, type = u32, d = a_width, a = a + 0});
            ptx.Add(new mov{type = u32, d = dim, a = (Const)0});
            ptx.Add(new setp{cmpop = cmp.le, type = u32, p = new Modded{Mod = Mod.Couple, Embedded = {p(2), p(8)}}, a = a_width, b = dim});
            ptx.Add(new bra{Guard = p(8).mod(Mod.Not), tgt = after_loop});

            // Pointer setup (the emitted PTX comment repeats the loop statement):
            // a_offset = &A.elements[row * A.width], a_offset_hi = &A.elements[(row + 1) * A.width],
            // b_offset = &B.elements[col], b_offset_stride = B.width * 4 bytes.
            ptx.Add(new Comment{Text = Environment.NewLine});
            ptx.Add(new Comment{Text = "Cvalue += A.elements[row * A.width + dim] * B.elements[dim * B.width + col];"});
            ptx.Add(new ld{ss = space.param, type = u32, d = a_raw, a = a + 8});
            ptx.Add(new mul{mode = mulm.lo, type = u32, d = r(18), a = a_width, b = row});
            ptx.Add(new mul{mode = mulm.lo, type = u32, d = a_offset_lo, a = r(18), b = (Const)4});
            ptx.Add(new add{type = u32, d = a_offset, a = a_offset_lo, b = a_raw});
            ptx.Add(new add{type = u32, d = r(21), a = r(18), b = a_width});
            ptx.Add(new mul{mode = mulm.lo, type = u32, d = r(25), a = r(21), b = (Const)4});
            ptx.Add(new add{type = u32, d = a_offset_hi, a = r(25), b = a_raw});
            ptx.Add(new ld{ss = space.param, type = u32, d = b_raw, a = b + 8});
            ptx.Add(new mul{mode = mulm.lo, type = u32, d = b_offset_lo, a = col, b = (Const)4});
            ptx.Add(new add{type = u32, d = b_offset, a = b_offset_lo, b = b_raw});
            ptx.Add(new mul{mode = mulm.lo, type = u32, d = b_offset_stride, a = b_width, b = (Const)4});

            // Loop body: walk A's row by 4 bytes and B's column by one row stride until
            // a_offset reaches a_offset_hi; the loop counter is implicit in a_offset.
            ptx.Add(new Comment{Text = Environment.NewLine});
            ptx.Add(new Comment{Text = "Cvalue += A.elements[row * A.width + dim] * B.elements[dim * B.width + col];"});
            ptx.Add(loop_body);
            ptx.Add(new ld{ss = space.global, type = f32, d = f(2), a = a_offset});
            ptx.Add(new ld{ss = space.global, type = f32, d = f(3), a = b_offset});
            ptx.Add(new mad{type = f32, d = cvalue, a = f(3), b = f(2), c = cvalue});
            ptx.Add(new add{type = u32, d = a_offset, a = a_offset, b = (Const)4});
            ptx.Add(new add{type = u32, d = b_offset, a = b_offset, b = b_offset_stride});
            ptx.Add(new setp{cmpop = cmp.ne, type = u32, p = p(3), a = a_offset, b = a_offset_hi});
            ptx.Add(new bra{Guard = p(3), tgt = loop_body});
            ptx.Add(new bra{uni = true, tgt = after_loop});

            // Store: &C.elements[row * C.width + col] = c_raw + (row * C.width + col) * 4.
            ptx.Add(new Comment{Text = Environment.NewLine});
            ptx.Add(new Comment{Text = "C.elements[row * C.width + col] = Cvalue;"});
            ptx.Add(after_loop);
            ptx.Add(new ld{ss = space.param, type = u32, d = c_raw, a = c + 8});
            ptx.Add(new ld{ss = space.param, type = u32, d = c_width, a = c + 0});
            ptx.Add(new mul{mode = mulm.lo, type = u32, d = r(32), a = c_width, b = row});
            ptx.Add(new add{type = u32, d = r(33), a = col, b = r(32)});
            ptx.Add(new mul{mode = mulm.lo, type = u32, d = r(34), a = r(33), b = (Const)4});
            ptx.Add(new add{type = u32, d = r(35), a = c_raw, b = r(34)});
            ptx.Add(new st{ss = space.global, type = f32, a = r(35), b = cvalue});

            ptx.Add(new Comment{Text = Environment.NewLine});
            ptx.Add(exit);
            ptx.Add(new exit());

            return module;
        }
Example #9
0
 public static void EmitTypeSpec(this IndentedWriter w, String code_type, Type assumed_type, Space assumed_space)
 {
     // Convenience overload: emit conversions with unqualified target type names
     // (the five-argument overload is called with full_name = false).
     const bool qualify_names = false;
     w.EmitTypeSpec(code_type, assumed_type, assumed_space, qualify_names);
 }
Example #10
0
 protected bool is_relaxed_reg_or_null(Expression expr, Type t)
 {
     // Convenience overload: forward with the literal 0 as the Mod argument
     // (presumably "no modifier required" — confirm against the Mod definition).
     bool ok = is_relaxed_reg_or_null(expr, t, 0);
     return ok;
 }
Example #11
0
 protected bool is_relaxed_reg_or_null(Expression expr, Type t, Mod mod)
 {
     // An absent operand is accepted outright; a present one must pass is_relaxed_reg.
     if (expr == null) return true;
     return is_relaxed_reg(expr, t, mod);
 }
Example #12
0
 protected bool is_relaxed_reg(Expression expr, Type t, Mod mod)
 {
     // True only when expr is non-null, is a register, relaxed-agrees with t and
     // carries the requested modifier. The && chain short-circuits left to right,
     // so the later checks are never evaluated on a null expression.
     return expr != null
         && is_reg(expr)
         && expr.relaxed_agree(t)
         && expr.has_mod(mod);
 }
Example #13
0
 protected bool is_relaxed_alu_or_sreg(Expression expr, Type t)
 {
     // Convenience overload: forward with the literal 0 as the Mod argument
     // (presumably "no modifier required" — confirm against the Mod definition).
     bool ok = is_relaxed_alu_or_sreg(expr, t, 0);
     return ok;
 }
Example #14
0
 protected bool is_reg(Expression expr, Type t)
 {
     // Convenience overload: forward with the literal 0 as the Mod argument
     // (presumably "no modifier required" — confirm against the Mod definition).
     bool ok = is_reg(expr, t, 0);
     return ok;
 }
Example #15
0
 protected bool relaxed_agree_or_null(Expression expr, Type t)
 {
     // Thin wrapper: delegate to the expression-level relaxed_agree_or_null check against t.
     bool result = expr.relaxed_agree_or_null(t);
     return result;
 }
Example #16
0
        // Writes C# implicit conversion operators from `code_type` to every PTX type
        // wrapper that `assumed_type` is compatible with.
        //
        // Strict conversions target exactly the element's bit width. Relaxed conversions
        // target every width in {8, 16, 32, 64} up to the element's width, widest first.
        // Relaxed conversions are emitted only for Space.Reg, as relaxed_reg_* targets.
        // Array types produce no output.
        //
        // w             - writer receiving one generated line per operator.
        // code_type     - source type of the conversions (also used as the parameter name).
        // assumed_type  - must be non-null and an int, float or bit type (asserted).
        // assumed_space - when Space.Reg, reg_* and relaxed_reg_* targets are also emitted.
        // full_name     - when true, targets are prefixed with
        //                 "Libptx.Edsl.Common.Types.{Opaque|Vector|Scalar}.".
        public static void EmitTypeSpec(this IndentedWriter w, String code_type, Type assumed_type, Space assumed_space, bool full_name)
        {
            assumed_type.AssertNotNull();
            if (assumed_type.is_arr()) return;

            // 1. Infer strict and relaxed conversions for element type

            var strict = new List<String>();
            var relaxed = new List<String>();

            var el = assumed_type;
            el = el.Unfold(t => t.arr_el(), t => t != null).Last();
            el = el.Unfold(t => t.vec_el(), t => t != null).Last();
            var s_sizes = (el.SizeInMemory * 8).MkArray();
            var r_sizes = (new[] { 8, 16, 32, 64 }).Where(sz1 => sz1 <= el.SizeInMemory * 8).Reverse().ToReadOnly();

            // Type-name prefixes compatible with the assumed type, in emission order.
            String[] prefixes;
            if (assumed_type.is_int())
            {
                prefixes = new[] { "u", "s", "b" };
            }
            else if (assumed_type.is_float())
            {
                prefixes = new[] { "f", "b" };
            }
            else
            {
                assumed_type.is_bit().AssertTrue();
                prefixes = new[] { "u", "s", "f", "b" };
            }

            foreach (var prefix in prefixes)
            {
                foreach (var size in s_sizes) strict.Add(prefix + size);
                foreach (var size in r_sizes) relaxed.Add(prefix + size);
            }

            // Drop the 8-bit float candidate generated above; presumably no f8 wrapper exists.
            // Each list contains at most one "f8", so a single Remove suffices.
            strict.Remove("f8");
            relaxed.Remove("f8");

            // 2. Hack around conversions for usage with vectors
            // Rewrite entries in place with index-based loops. Assigning through the List<T>
            // indexer invalidates any live enumerator of that list, so it must not be done
            // from inside an enumeration of the same list.

            if (assumed_type.is_vec())
            {
                var rank = assumed_type.vec_rank();
                for (var i = 0; i < strict.Count; i++) strict[i] = String.Format("v{0}_{1}", rank, strict[i]);
                for (var i = 0; i < relaxed.Count; i++) relaxed[i] = String.Format("v{0}_{1}", rank, relaxed[i]);
            }

            // 3. Emit appropriate implicit conversions
            // A null namespace prefix formats as an empty string, yielding unqualified names.
            var ns = "Libptx.Edsl.Common.Types.";
            if (assumed_type.is_opaque()) ns += "Opaque.";
            else if (assumed_type.is_vec()) ns += "Vector.";
            else ns += "Scalar.";
            if (!full_name) ns = null;

            if (assumed_space == Space.Reg)
            {
                foreach (var spec in strict) w.WriteLine("public static implicit operator {2}reg_{0}({1} {1}) {{ return new {2}reg_{0}({1}); }}", spec, code_type, ns);
                foreach (var spec in relaxed) w.WriteLine("public static implicit operator {2}relaxed_reg_{0}({1} {1}) {{ return new {2}relaxed_reg_{0}({1}); }}", spec, code_type, ns);
            }

            foreach (var spec in strict) w.WriteLine("public static implicit operator {2}{0}({1} {1}) {{ return new {2}{0}({1}); }}", spec, code_type, ns);
        }
Example #17
0
        // Validates the d, a, b, c operands against the instruction type.
        // In wide mode, d and the addend c use the double-width type (e.g. U16 -> U32),
        // while the multiplicands a and b use the base type. Otherwise all four use `type`.
        // In both modes the destination d must be a register.
        protected override void custom_validate_operands()
        {
            if (mode == wide)
            {
                // Derive the double-width type by doubling the numeric suffix of the
                // type name, e.g. "U16" -> "U32"; Mod and Dims are carried over unchanged.
                var s_typename = type.Name.ToString();
                var s_twice_typename = s_typename.Slice(0, 1) + int.Parse(s_typename.Slice(1)) * 2;
                var twice_typename = (TypeName)Enum.Parse(typeof(TypeName), s_twice_typename);
                var twice_type = new Type { Name = twice_typename, Mod = type.Mod, Dims = type.Dims };

                // The destination must be a register, exactly as in the non-wide branch;
                // a plain ALU-operand check here would admit non-register destinations.
                is_reg(d, twice_type).AssertTrue();
                is_alu(a, type).AssertTrue();
                is_alu(b, type).AssertTrue();
                is_alu(c, twice_type).AssertTrue();
            }
            else
            {
                is_reg(d, type).AssertTrue();
                is_alu(a, type).AssertTrue();
                is_alu(b, type).AssertTrue();
                is_alu(c, type).AssertTrue();
            }
        }
Example #18
0
 protected bool is_alu_or_null(Expression expr, Type t, Mod mod)
 {
     // An absent operand is accepted outright; a present one must pass is_alu.
     if (expr == null) return true;
     return is_alu(expr, t, mod);
 }