diff options
author | Vladimir Azarov <avm@intermediate-node.net> | 2025-08-08 19:07:58 +0200 |
---|---|---|
committer | Vladimir Azarov <avm@intermediate-node.net> | 2025-08-08 19:07:58 +0200 |
commit | a417225089fd78d53d73ad63cd79f57d1a4a8ff1 (patch) | |
tree | d9da68b0414fdaf08ddccbae20bd0e2977cdca25 /emit.fun | |
parent | b0cb85edf2b60f6f0909355db717376f435ab312 (diff) |
Register allocation
Diffstat (limited to 'emit.fun')
-rw-r--r-- | emit.fun | 697 |
1 files changed, 667 insertions, 30 deletions
@@ -5,8 +5,54 @@ functor Emit(I: IL) = struct structure D = P.D structure PP = P.P + val ip = I + val file = ref NONE + datatype reg = + Rax | Rbx | Rcx | Rdx | Rsi | Rdi | Rbp | Rsp | R8 | + R9 | R10 | R11 | R12 | R13 | R14 | R15 + + datatype vConst = VConst of word | VAddrConst of int * word + datatype vrType = + VtConst of vConst | VtReg of reg | VtStack of int | VtUnk + + datatype affinity = AfHard of reg | AfSoft of int list | AfUnk + + val regs = [ + (Rax, 0), + (Rdx, 1), + (Rsp, 2), + (Rbp, 3), + + (Rcx, 4), + (Rsi, 5), + (Rdi, 6), + (R8, 7), + (R9, 8), + (R10, 9), + (R11, 10), + + (Rbx, 11), + (R12, 12), + (R13, 13), + (R14, 14), + (R15, 15) + ] + + val firstUsedReg = 4 + val usedRegNum = 12 + + fun reg2idx reg = + case List.find (fn (r, _) => r = reg) regs of + NONE => raise Unreachable + | SOME (_, idx) => idx + + fun idx2reg idx = + case List.find (fn (_, i) => i = idx) regs of + NONE => raise Unreachable + | SOME (r, _) => r + local fun output s = let @@ -120,24 +166,6 @@ functor Emit(I: IL) = struct D.appi f P.iniLayouts end - fun getVarsForAlloc vregs = - let - fun loop idx acc = - if idx = D.length vregs then - rev acc - else - let - val { t, ... } = D.get vregs idx - in - if t = I.RtReg then - loop (idx + 1) (idx :: acc) - else - loop (idx + 1) acc - end - in - loop 0 [] - end - fun extendEnd (iStart, iEnd) ops labels = let fun loop idx iEnd = @@ -145,7 +173,7 @@ functor Emit(I: IL) = struct iEnd else let - val ins = D.get ops idx + val (ins, _) = D.get ops idx in case ins of SOME (I.IrJmp lid) | SOME (I.IrJz (_, lid)) | @@ -167,11 +195,24 @@ functor Emit(I: IL) = struct loop iEnd iEnd end - fun computeIntLocal (s, e) ops labels = + fun computeIntLocal (s, e) firstDef ops labels = let val e = extendEnd (s, e) ops labels + + val (_, li) = D.get ops firstDef in - (s, e) + case li of + SOME (startL, endL) => + let + val (startLoop, endLoop) = + (valOf $ D.get labels startL, valOf $ D.get labels endL) + + val s = if s < startLoop then s else startLoop + val e = if e > endLoop then e else endLoop + in + (s, e) + end + | _ => (s, e) end fun getBasicInt [] _ = raise Unreachable @@ -194,33 +235,629 @@ functor Emit(I: IL) = struct val (iStart, iEnd) = if var < localBound then - computeIntLocal (iStart, iEnd) ops labels + computeIntLocal (iStart, iEnd) (List.last defs) ops labels else (iStart, iEnd) in (var, iStart, iEnd) end - fun computeInts (F as I.Fi { vregs, ... }) vars = - List.map (computeInt F) vars + fun computeInts F vars = List.map (computeInt F) vars fun printInts ints = let - val () = printfn `"\nintervals:\n" % + val () = printfn `"\nsorted intervals:\n" % fun p (id, s, e) = printfn `"id: %" I id `" {" I s `", " I e `"}" % in List.app p ints end + fun updAff arr idx aff = + let + val (_, vt) = Array.sub (arr, idx) + in + Array.update (arr, idx, (aff, vt)) + end + + datatype insAff = IaNone | IaHard of (int * reg) list + | IaSoft of int * int list + + fun parNum2reg pr = + case pr of + 0 => Rdi + | 1 => Rsi + | 2 => Rdx + | 3 => Rcx + | 5 => R8 + | 6 => R9 + | _ => raise Unreachable + + fun getInsAff (SOME ins) = + let + fun tr (rd, rs1, rs2) = IaSoft (rd, [rs1, rs2]) + + fun setAff (rd, I.SaVReg rs) = IaSoft (rd, [rs]) + | setAff _ = IaNone + + fun fcallAff args = + let + fun collect idx (arg :: args) acc = + collect (idx + 1) args ((arg, parNum2reg idx) :: acc) + | collect _ [] acc = rev acc + in + IaHard $ collect 0 args [] + end + in + case ins of + I.IrSet p => setAff p + | I.IrAdd t => tr t + | I.IrSub t => tr t + | I.IrMul t => tr t + | I.IrIMul t => tr t + | I.IrDiv t => tr t + | I.IrIDiv t => tr t + | I.IrMod t => tr t + | I.IrIMod t => tr t + | I.IrShl t => tr t + | I.IrShr t => tr t + | I.IrSar t => tr t + | I.IrAnd t => tr t + | I.IrOr t => tr t + | I.IrXor t => tr t + | I.IrEq t => tr t + | I.IrNeq t => tr t + | I.IrCmpul t => tr t + | I.IrCmpug t => tr t + | I.IrCmpule t => tr t + | I.IrCmpuge t => tr t + | I.IrCmpsl t => tr t + | I.IrCmpsg t => tr t + | I.IrCmpsle t => tr t + | I.IrCmpsge t => tr t + + | I.IrExtZero _ | I.IrExtSign _ + | I.IrLoad _ | I.IrStore _ | I.IrJmp _ + | I.IrJz _ | I.IrJnz _ | I.IrNopLabel _ + | I.IrNop _ | I.IrRet _ | I.IrAlloc _ + | I.IrCopy _ => IaNone + | I.IrFcall (_, _, args) => fcallAff args + end + | getInsAff NONE = IaNone + + fun updateSoftAff rinfo rd rss = + let + fun sort [r] = [r] + | sort [rs1, rs2] = if rs1 < rs2 then [rs1, rs2] else [rs2, rs1] + | sort _ = raise Unreachable + + fun isNotConst rv = + let + val (_, vt) = Array.sub (rinfo, rv) + in + case vt of + VtConst _ => false + | _ => true + end + + val (aff, vt) = Array.sub (rinfo, rd) + val rss = List.filter isNotConst $ sort rss + + fun insertSorted ins [] = ins + | insertSorted [] acc = acc + | insertSorted (x :: xs) (y :: ys) = + if x < y then + x :: insertSorted xs (y :: ys) + else + y :: insertSorted (x :: xs) ys + + val aff = + case aff of + AfUnk => AfSoft rss + | AfSoft affs => AfSoft $ insertSorted rss affs + | AfHard _ => aff + in + Array.update (rinfo, rd, (aff, vt)) + end + + fun updateHardAff rinfo hards = + let + fun f (rd, reg) = + let + val (aff, vt) = Array.sub (rinfo, rd) + + val aff = + case aff of + AfUnk | AfSoft _ => AfHard reg + | AfHard _ => raise Unreachable + in + Array.update (rinfo, rd, (aff, vt)) + end + in + List.app f hards + end + + fun compAffinity rinfo ops paramNum = + let + fun compParams idx = + if idx = paramNum then + () + else + let + val reg = parNum2reg idx + in + updAff rinfo idx (AfHard reg); + compParams (idx + 1) + end + val () = compParams 0 + + fun loop idx = + if idx = D.length ops then + () + else + let + val (ins, _) = D.get ops idx + in + case getInsAff ins of + IaNone => () + | IaSoft (rd, rss) => updateSoftAff rinfo rd rss + | IaHard hards => updateHardAff rinfo hards; + loop (idx + 1) + end + in + loop 0 + end + + fun prepareRegInfo paramNum ops vregs = + let + val rinfo = Array.array (D.length vregs, (AfUnk, VtUnk)) + + fun transfer idx acc = + if idx = D.length vregs then + rev acc + else + let + val (vt, cand) = + case #t $ D.get vregs idx of + I.RtRem => (VtUnk, NONE) + | I.RtConst w => (VtConst (VConst w), NONE) + | I.RtAddrConst (id, w) => (VtConst (VAddrConst (id, w)), NONE) + | I.RtReg => (VtUnk, SOME idx) + in + Array.update (rinfo, idx, (AfUnk, vt)); + transfer (idx + 1) (if isSome cand then valOf cand :: acc else acc) + end + + val toAlloc = transfer 0 [] + in + compAffinity rinfo ops paramNum; + (toAlloc, rinfo) + end + + fun preg reg out = + let + val s = + case reg of + Rax => "rax" + | Rbx => "rbx" + | Rcx => "rcx" + | Rdx => "rdx" + | Rsi => "rsi" + | Rdi => "rdi" + | Rbp => "rbp" + | Rsp => "rsp" + | R8 => "r8" + | R9 => "r9" + | R10 => "r10" + | R11 => "r11" + | R12 => "r12" + | R13 => "r13" + | R14 => "r14" + | R15 => "r15" + in + Printf out `s % + end + + val Preg = fn z => bind A1 preg z + + fun affPrint rinfo = + let + fun pv idx out = Printf out `"%" I idx % + + fun p (idx, (aff, _)) = + let + val () = printf `"%" I idx % + in + case aff of + AfUnk => printfn `" = unk" % + | AfHard reg => printfn `" <- " Preg reg % + | AfSoft rss => printfn `" <- " Plist pv rss (", ", true, 1) % + end + in + Array.appi p rinfo + end + + fun sort _ [] = [] + | sort _ [x] = [x] + | sort le l = + let + fun divide [] accp = accp + | divide [x] (acc1, acc2) = (x :: acc1, acc2) + | divide (x :: y :: tail) (acc1, acc2) = + divide tail (x :: acc1, y :: acc2) + val (part1, part2) = divide l ([], []) + val part1 = sort le part1 + val part2 = sort le part2 + + fun merge [] [] acc = acc + | merge [] ys acc = rev $ List.revAppend (ys, acc) + | merge xs [] acc = rev $ List.revAppend (xs, acc) + | merge (x :: xs) (y :: ys) acc = + if le (x, y) then + merge xs (y :: ys) (x :: acc) + else + merge (x :: xs) ys (y :: acc) + in + merge part1 part2 [] + end + + fun updateI i = fn z => + let + fun from rinfo active pool intervals stackCand = + { rinfo, active, pool, intervals, stackCand } + fun to f { rinfo, active, pool, intervals, stackCand } = + f rinfo active pool intervals stackCand + in + FRU.makeUpdate5 (from, from, to) i + end z + + fun returnToPool pool reg = + let + val idx = reg2idx reg - firstUsedReg + in + Array.update (pool, idx, NONE) + end + + fun expireOne { rinfo, active, pool, ... } (_, start, _) = + case !active of + [] => false + | (j, startp, endp) :: acts => + if endp > start then + false + else + let + val (_, vt) = Array.sub (rinfo, j) + val reg = case vt of VtReg reg => reg | _ => raise Unreachable + + val () = printfn `"III!!! interval %" + ip j `"(" ip startp `", " ip endp `") " + `"with " Preg reg `" has expired" % + in + returnToPool pool reg; + active := acts; + true + end + + fun expireOld (I as { active, ... }) int = + let + fun loop I = + case expireOne I int of + false => () + | true => loop I + in + case !active of + [] => () + | _ => loop I + end + + fun addToActive int [] = [int] + | addToActive (I as (_, _, e1)) (act :: acts) = + if e1 < #3 act then + (I :: act :: acts) + else + act :: addToActive I acts + + fun updReg arr idx reg = + let + val (aff, _) = Array.sub (arr, idx) + in + Array.update (arr, idx, (aff, reg)) + end + + fun assignFirstReg poff { rinfo, pool, ... } vr = + let + fun loop idx = + if idx = Array.length pool then + raise Unreachable + else + let + val user = Array.sub (pool, idx) + in + case user of + SOME _ => loop (idx + 1) + | NONE => + let + val () = Array.update (pool, idx, SOME vr); + val reg = idx2reg (firstUsedReg + idx) + + val () = printfn R poff + `"assigned (first) reg " Preg reg `" to %" ip vr % + in + updReg rinfo vr (VtReg reg) + end + end + in + loop 0 + end + + fun freeRegList pool = + let + fun loop idx acc = + if idx = Array.length pool then + rev acc + else + case Array.sub (pool, idx) of + NONE => loop (idx + 1) (idx2reg (idx + firstUsedReg) :: acc) + | SOME _ => loop (idx + 1) acc + in + loop 0 [] + end + + fun getAffRegList rinfo affs = + let + fun loop [] acc = rev acc + | loop (vr :: vrs) acc = + let + val (_, vt) = Array.sub (rinfo, vr) + in + case vt of + VtReg r => loop vrs (r :: acc) + | _ => loop vrs acc + end + in + loop affs [] + end + + fun findCommonRegs l1 l2 = + let + val l1 = sort (fn (r1, r2) => reg2idx r1 <= reg2idx r2) l1 + val l2 = sort (fn (r1, r2) => reg2idx r1 <= reg2idx r2) l2 + + fun intersection [] _ = [] + | intersection _ [] = [] + | intersection (x :: xs) (y :: ys) = + case Int.compare (reg2idx x, reg2idx y) of + LESS => intersection xs (y :: ys) + | EQUAL => x :: intersection xs ys + | GREATER => intersection (x :: xs) ys + in + intersection l1 l2 + end + + fun assignSoftReg poff affs (I as { rinfo, pool, ... }) vr = + let + val () = printfn R poff + `"trying to assign register (by affinity) to %" ip vr % + + val regs = freeRegList pool + val affRegs = getAffRegList rinfo affs + val common = findCommonRegs regs affRegs + + val () = printfn R (poff + 1) + `"free registers: " Plist preg regs (", ", true, 0) % + val () = printfn R (poff + 1) + `"affinity registers: " Plist preg affRegs (", ", true, 0) % + in + case common of + [] => + let + val () = printfn R (poff + 1) `"affinity was not satisfied" % + in + assignFirstReg (poff + 2) I vr + end + | (reg :: _) => + let + in + updReg rinfo vr (VtReg reg); + Array.update (pool, reg2idx reg - firstUsedReg, SOME vr); + printfn R (poff + 1) + `"assigned (by affinity) reg " Preg reg `" to %" ip vr %; + printfn R (poff + 1) + `"free registers: " Plist preg (freeRegList pool) (", ", true, 0) % + end + end + + fun putToStack poff { rinfo, stackCand, ... } vr = + let + val () = printfn R poff + `"puting %" ip vr `" to stack: " ip (!stackCand) % + in + updReg rinfo vr (VtStack (!stackCand)); + stackCand := !stackCand - 8 + end + + fun assignHardReg poff (I as { rinfo, pool, ... }) vr reg = + let + val () = printfn R poff + `"trying to assign hard reg " A1 preg reg `" to %" ip vr % + + val regIdx = reg2idx reg - firstUsedReg + val user = Array.sub (pool, regIdx) + + fun setOurReg () = + let + val () = printfn R (poff + 1) `"reg assigned" % + in + Array.update (pool, regIdx, SOME vr); + updReg rinfo vr (VtReg reg) + end + in + case user of + NONE => setOurReg () + | SOME u => + let + val () = printfn R (poff + 1) `"reg is taken by %" ip u % + val (aff, _) = Array.sub (rinfo, u) + in + case aff of + AfHard _ => raise Unreachable + | AfSoft affs => assignSoftReg poff affs I u + | AfUnk => assignFirstReg poff I u + ; + setOurReg () + end + end + + fun assignReg (I as { rinfo, ... }) (vr, _, _) = + let + val (aff, _) = Array.sub (rinfo, vr) + in + case aff of + AfUnk => assignFirstReg 0 I vr + | AfSoft affs => assignSoftReg 0 affs I vr + | AfHard reg => assignHardReg 0 I vr reg + end + + fun getPool () = Array.array (usedRegNum, NONE) + + fun changeInActive active newInt oldVr = + let + val a = !active + val a = List.filter (fn (v, _, _) => v <> oldVr) a + in + active := addToActive newInt a + end + + fun expropriateReg (I as { rinfo, pool, active, ... }) int reg = + let + val vr = #1 int + + val regIdx = reg2idx reg - firstUsedReg + val u = valOf $ Array.sub (pool, regIdx) + + val (uAff, _) = Array.sub (rinfo, u) + val () = + case uAff of + AfHard _ => raise Unreachable + | _ => () + + val () = putToStack 1 I u + val () = Array.update (pool, regIdx, SOME vr) + val () = updReg rinfo vr (VtReg reg) + in + changeInActive active int u + end + + fun userIdx pool vr = + let + fun loop idx = + if idx = Array.length pool then + raise Unreachable + else + case Array.sub (pool, idx) of + SOME u => + if u = vr then + idx + else + loop (idx + 1) + | NONE => loop (idx + 1) + in + loop 0 + end + + fun spillAtInterval (I as { rinfo, active, pool, ... }) int = + let + val spill = List.last (!active) + + val vr = #1 int + val (ourAff, _) = Array.sub (rinfo, vr) + + fun isNotHard vr = + case #1 $ Array.sub (rinfo, vr) of + AfHard _ => false + | _ => true + + val () = printfn `"SpilAtInt" % + val () = printfn R 0 + `"free registers: " Plist preg (freeRegList pool) (", ", true, 0) % + in + case ourAff of + AfHard reg => expropriateReg I int reg + | _ => + if #3 spill > #3 int andalso isNotHard (#1 spill) then + let + val idx = userIdx pool (#1 spill) + val () = printfn `"spilling!!!" % + in + Array.update (pool, idx, SOME vr); + updReg rinfo vr (VtReg $ idx2reg (idx + firstUsedReg)); + putToStack 1 I (#1 spill); + changeInActive active int (#1 spill) + end + else + putToStack 0 I vr + end + + fun linearscan rinfo ints = + let + fun incStart ((_, start1, _), (_, start2, _)) = start1 <= start2 + val ints = sort incStart ints + + val () = printInts ints + + fun loop _ [] = () + | loop (I as { active, ... }) (int :: ints) = + let + val () = printfn `"\n\ninspectiing interval " + ip (#1 int) `": (" ip (#2 int) `", " ip (#3 int) `")" % + + val () = expireOld I int + + val () = + if length (!active) = usedRegNum then + spillAtInterval I int + else + let + val () = assignReg I int + in + active := addToActive int (!active) + end + in + loop I ints + end + in + loop { active = ref [], pool = getPool (), rinfo, + stackCand = ref (~8) } ints + end - fun regAlloc (F as I.Fi { vregs, labels, ... }) = + fun printAllocVar rinfo v = let - val varsForAlloc = getVarsForAlloc vregs - val () = printfn `"for alloc: " Plist i varsForAlloc (", ", true, 0) % + val () = printf `"%" I v `": " % + val (_, vt) = Array.sub (rinfo, v) + in + case vt of + VtStack off => printfn `"stack " I off % + | VtReg reg => printfn `"reg " A1 preg reg % + | VtConst _ | VtUnk => raise Unreachable + end + + fun printAlloced rinfo toAlloc = + let + val () = printfn `"\nallocated:\n" % + in + List.app (printAllocVar rinfo) toAlloc + end + + fun regAlloc (F as I.Fi { vregs, ops, paramNum, ... }) = + let + val (toAlloc, regInfo) = prepareRegInfo paramNum ops vregs + val () = printfn `"for alloc: " Plist i toAlloc (", ", true, 0) % + + val () = affPrint regInfo - val intervals = computeInts F varsForAlloc + val intervals = computeInts F toAlloc - val () = printInts intervals + val () = linearscan regInfo intervals + val () = printAlloced regInfo toAlloc in raise Unimplemented end |