Ejemplo n.º 1
0
 /// <summary>
 /// Creates an <c>X86Proc</c> for every procedure of <paramref name="assembly"/> and an
 /// <c>X86BasicBlock</c> for every basic block inside it, storing them in <c>procs</c> and
 /// <c>basicBlocks</c> keyed by the original procedure / basic block.
 /// </summary>
 /// <param name="assembly">The assembly whose procedures and basic blocks get declared.</param>
 private void ForwardDeclare(Assembly assembly)
 {
     foreach (var procedure in assembly.Procedures)
     {
         // The procedure is registered under its symbol name
         var declaredProc = new X86Proc(GetSymbolName(procedure));
         procs[procedure] = declaredProc;

         // Every block of the procedure is registered too, under its escaped name
         foreach (var block in procedure.BasicBlocks)
         {
             var escapedName = formatOptions.Escape(block.Name);
             basicBlocks[block] = new X86BasicBlock(declaredProc, escapedName);
         }
     }
 }
Ejemplo n.º 2
0
        /// <summary>
        /// Emits x86 code for a single instruction of <paramref name="proc"/>.
        /// Every register of the pool is released first and the instruction is echoed as a comment.
        /// </summary>
        /// <param name="proc">
        /// The procedure containing the instruction. Its calling convention is consulted when compiling a
        /// return, and it is handed to the epilogue writer.
        /// </param>
        /// <param name="instr">The instruction to compile.</param>
        /// <exception cref="InvalidOperationException">
        /// The called value is not of procedure type, a cast is not pointer-to-pointer, or the operands of an
        /// add/sub are neither pointer-based nor both integers.
        /// </exception>
        /// <exception cref="NotImplementedException">
        /// The calling convention is not cdecl, or the instruction kind is not handled.
        /// </exception>
        /// <exception cref="NotSupportedException">
        /// Multiplication, division, modulo or shifting is requested on operands wider than 4 bytes.
        /// </exception>
        private void CompileInstr(Proc proc, Instr instr)
        {
            registerPool.FreeAll();
            CommentInstr(instr);
            switch (instr)
            {
            case Instr.Call call:
            {
                if (!(call.Procedure.Type is Type.Proc procTy))
                {
                    throw new InvalidOperationException();
                }
                if (procTy.CallConv == CallConv.Cdecl)
                {
                    // Push arguments backwards, track the offset of esp
                    var espOffset = 0;
                    // Arguments are visited last-to-first, so the label counts down to keep each
                    // comment in sync with the argument's actual position
                    int argIndex = call.Arguments.Count() - 1;
                    foreach (var arg in call.Arguments.Reverse())
                    {
                        var argValue = CompileValue(arg);
                        CommentInstr($"argument {argIndex}");
                        WritePush(argValue);
                        espOffset += SizeOf(arg);
                        registerPool.FreeAll();
                        --argIndex;
                    }
                    // If the return method is by address, we pass the address
                    var returnMethod = GetReturnMethod(procTy.Return);
                    if (returnMethod == ReturnMethod.Address)
                    {
                        var resultAddr    = CompileToAddress(call.Result);
                        var resultAddrReg = registerPool.Allocate(DataWidth.dword);
                        CommentInstr("return target address");
                        WriteInstr(X86Op.Lea, resultAddrReg, resultAddr);
                        WriteInstr(X86Op.Push, resultAddrReg);
                        espOffset += DataWidth.dword.Size;
                        registerPool.FreeAll();
                    }
                    // Do the call
                    var procedure = CompileSingleValue(call.Procedure);
                    WriteInstr(X86Op.Call, procedure);
                    // Restore stack: the caller removes everything it pushed
                    WriteInstr(X86Op.Add, Register.esp, new Operand.Literal(DataWidth.dword, espOffset));
                    // Writing the result back
                    if (returnMethod == ReturnMethod.None || returnMethod == ReturnMethod.Address)
                    {
                        // No-op
                        // For size == 0 there's nothing to copy, for size > 4 the value is already copied
                    }
                    else
                    {
                        CommentInstr("copy return value");
                        // Just store what's written in eax or eax:edx
                        var resultSize = SizeOf(procTy.Return);
                        var result     = CompileValue(call.Result);
                        if (result.Length == 1)
                        {
                            var eax = Register.AtSlot(0, DataWidth.GetFromSize(resultSize));
                            WriteInstr(X86Op.Mov, result[0], eax);
                        }
                        else
                        {
                            Debug.Assert(result.Length == 2);
                            // The low dword comes from slot 0, the remaining bytes from slot 2
                            var eax = Register.AtSlot(0, DataWidth.dword);
                            var edx = Register.AtSlot(2, DataWidth.GetFromSize(resultSize - 4));
                            WriteInstr(X86Op.Mov, result[0], eax);
                            WriteInstr(X86Op.Mov, result[1], edx);
                        }
                    }
                }
                else
                {
                    throw new NotImplementedException();
                }
            }
            break;

            case Instr.Ret ret:
            {
                // Writing the return value
                if (proc.CallConv == CallConv.Cdecl)
                {
                    var returnMethod = GetReturnMethod(ret.Value.Type);
                    if (returnMethod == ReturnMethod.None)
                    {
                        // No-op
                    }
                    else if (returnMethod == ReturnMethod.Eax || returnMethod == ReturnMethod.EaxEdx)
                    {
                        // We write in eax or the eax:edx pair
                        // TODO: Floats are returned in a different register!
                        // NOTE: Might be unnecessary allocation for edx
                        // Both registers are reserved before the value is compiled
                        registerPool.Allocate(Register.eax, Register.edx);
                        var retValue = CompileValue(ret.Value);
                        if (retValue.Length == 1)
                        {
                            WriteInstr(X86Op.Mov, Register.eax, retValue[0]);
                        }
                        else
                        {
                            Debug.Assert(retValue.Length == 2);
                            WriteInstr(X86Op.Mov, Register.eax, retValue[0]);
                            WriteInstr(X86Op.Mov, Register.edx, retValue[1]);
                        }
                    }
                    else
                    {
                        // We receive a return address
                        // First we load that return address, it is read from [ebp + 8]
                        var retAddrAddr = new Operand.Address(Register.ebp, 8);
                        var retAddr     = registerPool.Allocate(DataWidth.dword);
                        CommentInstr("copy return value");
                        WriteInstr(X86Op.Mov, retAddr, retAddrAddr);
                        // Compile the value
                        var retValue = CompileValue(ret.Value);
                        // Copy
                        WriteCopy(retAddr, retValue);
                    }
                }
                else
                {
                    throw new NotImplementedException();
                }
                // Write the epilogue, return
                WriteProcEpilogue(proc);
                CommentInstr(instr);
                WriteInstr(X86Op.Ret);
            }
            break;

            case Instr.Jmp jmp:
                WriteInstr(X86Op.Jmp, basicBlocks[jmp.Target]);
                break;

            case Instr.JmpIf jmpIf:
            {
                var condParts = CompileValue(jmpIf.Condition);
                Debug.Assert(condParts.Length > 0);
                // Initial condition is the first bytes
                var condHolder = registerPool.Allocate(condParts[0].GetWidth(sizeContext));
                WriteInstr(X86Op.Mov, condHolder, condParts[0]);
                // Remaining are or-ed to the result, so any non-zero part makes the condition true
                foreach (var item in condParts.Skip(1))
                {
                    WriteInstr(X86Op.Or, condHolder, item);
                }
                WriteInstr(X86Op.Test, condHolder, condHolder);
                WriteInstr(X86Op.Jne, basicBlocks[jmpIf.Then]);
                WriteInstr(X86Op.Jmp, basicBlocks[jmpIf.Else]);
            }
            break;

            case Instr.Store store:
            {
                var target  = CompileToAddress(store.Target);
                var value   = CompileValue(store.Value);
                var address = registerPool.Allocate(DataWidth.dword);
                // The destination address is read out of the target, then the value is copied there
                WriteInstr(X86Op.Mov, address, target);
                WriteCopy(address, value);
            }
            break;

            case Instr.Load load:
            {
                var target     = CompileToAddress(load.Result);
                var source     = CompileToAddress(load.Source);
                var targetAddr = registerPool.Allocate(DataWidth.dword);
                var sourceAddr = registerPool.Allocate(DataWidth.dword);
                // lea takes the address of the result itself, mov reads the address stored in the source
                WriteInstr(X86Op.Lea, targetAddr, target);
                WriteInstr(X86Op.Mov, sourceAddr, source);
                WriteMemcopy(targetAddr, sourceAddr, SizeOf(load.Result));
            }
            break;

            case Instr.Alloc alloc:
            {
                // Space is reserved by lowering esp, the new esp is the allocated address
                var size = SizeOf(alloc.Allocated);
                WriteInstr(X86Op.Sub, Register.esp, new Operand.Literal(DataWidth.dword, size));
                var result = CompileToAddress(alloc.Result);
                WriteInstr(X86Op.Mov, result, Register.esp);
            }
            break;

            case Instr.ElementPtr elementPtr:
            {
                var target = CompileToAddress(elementPtr.Result);
                var value  = CompileSingleValue(elementPtr.Value);
                // We need it in a register
                value = LoadToRegister(value);
                var structTy = (Struct)((Type.Ptr)elementPtr.Value.Type).Subtype;

                var index = elementPtr.Index.Value;
                // Get offset, add it to the base address
                var offset = OffsetOf(structTy, index);
                WriteInstr(X86Op.Add, value, new Operand.Literal(DataWidth.dword, offset));
                WriteMov(target, value);
            }
            break;

            case Instr.Cast cast:
            {
                if (cast.Target is Type.Ptr && cast.Value.Type is Type.Ptr)
                {
                    // Should be a no-op, simply copy
                    var target = CompileToAddress(cast.Result);
                    var src    = CompileSingleValue(cast.Value);
                    WriteMov(target, src);
                }
                else
                {
                    throw new InvalidOperationException();
                }
            }
            break;

            // Size-dependent operations

            case Instr.Cmp cmp:
            {
                var targetParts = CompileValue(cmp.Result);
                var leftParts   = CompileValue(cmp.Left);
                var rightParts  = CompileValue(cmp.Right);
                Debug.Assert(leftParts.Length == rightParts.Length);
                var target      = targetParts[0];
                var targetWidth = target.GetWidth(sizeContext);

                var truthy = new Operand.Literal(targetWidth, 1);
                var falsy  = new Operand.Literal(targetWidth, 0);

                // We need to branch to write the result
                var labelNameBase = GetUniqueName("cmp_result");
                Debug.Assert(currentProcedure != null);
                var trueBB    = new X86BasicBlock(currentProcedure, $"{labelNameBase}_T");
                var falseBB   = new X86BasicBlock(currentProcedure, $"{labelNameBase}_F");
                var finallyBB = new X86BasicBlock(currentProcedure, $"{labelNameBase}_C");

                // Add all these basic blocks to the current procedure
                currentProcedure.BasicBlocks.Add(trueBB);
                currentProcedure.BasicBlocks.Add(falseBB);
                currentProcedure.BasicBlocks.Add(finallyBB);

                bool signed = ((Type.Int)cmp.Left.Type).Signed;
                // For each part pair we compare
                // NOTE: It's important that we start from the most significant bytes here for the relational operators
                int i = 0;
                foreach (var(l, r) in leftParts.Zip(rightParts).Reverse())
                {
                    // NOTE: first && signed means that we are looking at the MSB, so in case of a signed
                    // comparison, we need to use the signed compare jumps, unsigned in any other case
                    // TL;DR: first && signed means we need signed jump operation
                    bool first = i == 0;
                    bool last  = i == leftParts.Length - 1;
                    i += 1;

                    // We need to do the comparison, releasing the temporary the helper may have handed back
                    var tmp = WriteNonImmOrMemoryInstr(X86Op.Cmp, l, r);
                    if (tmp != null)
                    {
                        registerPool.Free(tmp);
                    }
                    if (last)
                    {
                        // If this was the last part to compare, we can just branch to the truthy part as the
                        // condition succeeded (or the falsy one on mismatch)
                        var inverseOp = ComparisonToJump(cmp.Comparison.Inverse, first && signed);
                        WriteInstr(inverseOp, falseBB);
                        WriteInstr(X86Op.Jmp, trueBB);
                    }
                    else
                    {
                        // These are not the last elements we compare
                        // Equal parts decide nothing, in that case we fall through to the next pair
                        if (cmp.Comparison == Comparison.eq)
                        {
                            // We can jump to false on inequality
                            WriteInstr(X86Op.Jne, falseBB);
                        }
                        else if (cmp.Comparison == Comparison.ne)
                        {
                            // We can jump to true on inequality
                            WriteInstr(X86Op.Jne, trueBB);
                        }
                        else if (cmp.Comparison == Comparison.le || cmp.Comparison == Comparison.le_eq)
                        {
                            // If the part is less, then we can jump to true, if greater, then to false
                            var lessOp    = ComparisonToJump(Comparison.le, first && signed);
                            var greaterOp = ComparisonToJump(Comparison.gr, first && signed);
                            WriteInstr(lessOp, trueBB);
                            WriteInstr(greaterOp, falseBB);
                        }
                        else if (cmp.Comparison == Comparison.gr || cmp.Comparison == Comparison.gr_eq)
                        {
                            // If the part is greater, then we can jump to true, if less, then to false
                            var greaterOp = ComparisonToJump(Comparison.gr, first && signed);
                            var lessOp    = ComparisonToJump(Comparison.le, first && signed);
                            WriteInstr(greaterOp, trueBB);
                            WriteInstr(lessOp, falseBB);
                        }
                    }
                }

                // On true block, we write the truthy value then jump to the continuation
                currentBasicBlock = trueBB;
                WriteInstr(X86Op.Mov, target, truthy);
                WriteInstr(X86Op.Jmp, finallyBB);
                // On false block, we write the falsy value then jump to the continuation
                currentBasicBlock = falseBB;
                WriteInstr(X86Op.Mov, target, falsy);
                WriteInstr(X86Op.Jmp, finallyBB);
                // We continue writing on the continuation
                currentBasicBlock = finallyBB;
            }
            break;

            case Instr.Add:
            case Instr.Sub:
            {
                var arith  = (ArithInstr)instr;
                var target = CompileToAddress(arith.Result);

                if (arith.Left.Type is Type.Ptr leftPtr)
                {
                    // Pointer arithmetic
                    var left  = CompileSingleValue(arith.Left);
                    var right = CompileSingleValue(arith.Right);
                    right = LoadToRegister(right);

                    WriteMov(target, left);
                    // We multiply the offset by the data size to get the expected behavior
                    WriteInstr(X86Op.Imul, right, new Operand.Literal(DataWidth.dword, SizeOf(leftPtr.Subtype)));
                    WriteInstr(arith is Instr.Add ? X86Op.Add : X86Op.Sub, target, right);
                }
                else if (arith.Left.Type is Type.Int && arith.Right.Type is Type.Int)
                {
                    // Integer arithmetic
                    var targetAddr = registerPool.Allocate(DataWidth.dword);
                    WriteInstr(X86Op.Lea, targetAddr, target);
                    var left  = CompileValue(arith.Left);
                    var right = CompileValue(arith.Right);
                    Debug.Assert(left.Length == right.Length);

                    // The first part uses the plain operation, later parts the carry/borrow-aware one
                    var firstOp = arith is Instr.Add ? X86Op.Add : X86Op.Sub;
                    var remOps  = arith is Instr.Add ? X86Op.Adc : X86Op.Sbb;

                    int  offset = 0;
                    bool first  = true;
                    foreach (var(l, r) in left.Zip(right))
                    {
                        // Each part is first copied to its place in the target, then combined in place
                        var width = l.GetWidth(sizeContext);
                        var addr  = new Operand.Address(targetAddr, offset);
                        var displ = new Operand.Indirect(width, addr);
                        WriteMov(displ, l);
                        if (first)
                        {
                            // We use add or sub
                            WriteNonImmOrMemoryInstr(firstOp, displ, r);
                        }
                        else
                        {
                            // We use adc or sbb
                            WriteNonImmOrMemoryInstr(remOps, displ, r);
                        }

                        offset += width.Size;
                        first   = false;
                    }
                }
                else
                {
                    throw new InvalidOperationException();
                }
            }
            break;

            case Instr.Mul mul:
            {
                if (SizeOf(mul.Left.Type) > 4)
                {
                    // For now we skip this
                    throw new NotSupportedException("Multiplication of > 4 byte operands is not supported!");
                }
                var target = CompileToAddress(mul.Result);
                var left   = CompileSingleValue(mul.Left);
                var right  = CompileSingleValue(mul.Right);
                // The product is computed in a scratch register, then stored to the target
                var tmp    = registerPool.Allocate(DataWidth.dword);
                WriteMov(tmp, left);
                WriteInstr(X86Op.Imul, tmp, right);
                WriteMov(target, tmp);
            }
            break;

            case Instr.Div:
            case Instr.Mod:
            {
                var arith = (ArithInstr)instr;
                if (SizeOf(arith.Left.Type) > 4)
                {
                    // For now we skip this
                    throw new NotSupportedException("Division of > 4 byte operands is not supported!");
                }
                // NOTE: This is different from multiplication
                // eax and edx are reserved up front, before any operand gets compiled
                registerPool.Allocate(Register.eax, Register.edx);

                var target = CompileToAddress(arith.Result);
                var left   = CompileSingleValue(arith.Left);
                var right  = CompileSingleValue(arith.Right);

                right = LoadToRegister(right);
                // NOTE(review): edx is zeroed rather than sign-extended from eax; idiv divides the signed
                // edx:eax pair, so this looks wrong for negative dividends - confirm whether cdq is needed
                WriteInstr(X86Op.Mov, Register.edx, new Operand.Literal(DataWidth.dword, 0));
                WriteInstr(X86Op.Mov, Register.eax, left);
                WriteInstr(X86Op.Idiv, right);
                // The quotient is taken from eax, the remainder from edx
                WriteInstr(X86Op.Mov, target, arith is Instr.Div ? Register.eax : Register.edx);
            }
            break;

            case Instr.BitAnd:
            case Instr.BitOr:
            case Instr.BitXor:
            {
                var bitw = (BitwiseInstr)instr;

                var target     = CompileToAddress(bitw.Result);
                var leftParts  = CompileValue(bitw.Left);
                var rightParts = CompileValue(bitw.Right);

                var op = bitw switch
                {
                    Instr.BitAnd => X86Op.And,
                    Instr.BitOr => X86Op.Or,
                    Instr.BitXor => X86Op.Xor,
                    _ => throw new NotImplementedException(),
                };

                // We can just apply the operation for each part
                Debug.Assert(leftParts.Length == rightParts.Length);
                var targetAddr = registerPool.Allocate(DataWidth.dword);
                WriteInstr(X86Op.Lea, targetAddr, target);
                int offset = 0;
                foreach (var(l, r) in leftParts.Zip(rightParts))
                {
                    var width = l.GetWidth(sizeContext);
                    var addr  = new Operand.Address(targetAddr, offset);
                    var displ = new Operand.Indirect(width, addr);
                    WriteMov(displ, l);
                    // The destination is a memory operand, so the operation goes through the same helper
                    // the add/sub path uses for this operand shape instead of being emitted raw
                    var tmp = WriteNonImmOrMemoryInstr(op, displ, r);
                    if (tmp != null)
                    {
                        // The temporary is only needed for this part
                        registerPool.Free(tmp);
                    }
                    offset += width.Size;
                }
            }
            break;

            case Instr.Shl:
            case Instr.Shr:
            {
                var bitsh = (BitShiftInstr)instr;
                if (SizeOf(bitsh.Shifted.Type) > 4 || SizeOf(bitsh.Amount.Type) > 4)
                {
                    // For now we skip this
                    throw new NotSupportedException("Shifting of > 4 byte operands is not supported!");
                }

                var target = CompileToAddress(bitsh.Result);
                var left   = CompileSingleValue(bitsh.Shifted);
                var right  = CompileSingleValue(bitsh.Amount);
                left = LoadToRegister(left);
                // NOTE(review): x86 shifts take their count as an immediate or in cl - confirm that
                // `right` always ends up in one of those forms
                WriteInstr(bitsh is Instr.Shl ? X86Op.Shl : X86Op.Shr, left, right);
                WriteMov(target, left);
            }
            break;

            default: throw new NotImplementedException();
            }
        }