summaryrefslogtreecommitdiff
path: root/emit.fun
diff options
context:
space:
mode:
authorVladimir Azarov <avm@intermediate-node.net>2025-08-08 19:07:58 +0200
committerVladimir Azarov <avm@intermediate-node.net>2025-08-08 19:07:58 +0200
commita417225089fd78d53d73ad63cd79f57d1a4a8ff1 (patch)
treed9da68b0414fdaf08ddccbae20bd0e2977cdca25 /emit.fun
parentb0cb85edf2b60f6f0909355db717376f435ab312 (diff)
Register allocation
Diffstat (limited to 'emit.fun')
-rw-r--r--emit.fun697
1 files changed, 667 insertions, 30 deletions
diff --git a/emit.fun b/emit.fun
index 6e21556..c794f26 100644
--- a/emit.fun
+++ b/emit.fun
@@ -5,8 +5,54 @@ functor Emit(I: IL) = struct
structure D = P.D
structure PP = P.P
+ val ip = I
+
val file = ref NONE
+ datatype reg =
+ Rax | Rbx | Rcx | Rdx | Rsi | Rdi | Rbp | Rsp | R8 |
+ R9 | R10 | R11 | R12 | R13 | R14 | R15
+
+ datatype vConst = VConst of word | VAddrConst of int * word
+ datatype vrType =
+ VtConst of vConst | VtReg of reg | VtStack of int | VtUnk
+
+ datatype affinity = AfHard of reg | AfSoft of int list | AfUnk
+
+ val regs = [
+ (Rax, 0),
+ (Rdx, 1),
+ (Rsp, 2),
+ (Rbp, 3),
+
+ (Rcx, 4),
+ (Rsi, 5),
+ (Rdi, 6),
+ (R8, 7),
+ (R9, 8),
+ (R10, 9),
+ (R11, 10),
+
+ (Rbx, 11),
+ (R12, 12),
+ (R13, 13),
+ (R14, 14),
+ (R15, 15)
+ ]
+
+ val firstUsedReg = 4
+ val usedRegNum = 12
+
+ fun reg2idx reg =
+ case List.find (fn (r, _) => r = reg) regs of
+ NONE => raise Unreachable
+ | SOME (_, idx) => idx
+
+ fun idx2reg idx =
+ case List.find (fn (_, i) => i = idx) regs of
+ NONE => raise Unreachable
+ | SOME (r, _) => r
+
local
fun output s =
let
@@ -120,24 +166,6 @@ functor Emit(I: IL) = struct
D.appi f P.iniLayouts
end
- fun getVarsForAlloc vregs =
- let
- fun loop idx acc =
- if idx = D.length vregs then
- rev acc
- else
- let
- val { t, ... } = D.get vregs idx
- in
- if t = I.RtReg then
- loop (idx + 1) (idx :: acc)
- else
- loop (idx + 1) acc
- end
- in
- loop 0 []
- end
-
fun extendEnd (iStart, iEnd) ops labels =
let
fun loop idx iEnd =
@@ -145,7 +173,7 @@ functor Emit(I: IL) = struct
iEnd
else
let
- val ins = D.get ops idx
+ val (ins, _) = D.get ops idx
in
case ins of
SOME (I.IrJmp lid) | SOME (I.IrJz (_, lid)) |
@@ -167,11 +195,24 @@ functor Emit(I: IL) = struct
loop iEnd iEnd
end
- fun computeIntLocal (s, e) ops labels =
+ fun computeIntLocal (s, e) firstDef ops labels =
let
val e = extendEnd (s, e) ops labels
+
+ val (_, li) = D.get ops firstDef
in
- (s, e)
+ case li of
+ SOME (startL, endL) =>
+ let
+ val (startLoop, endLoop) =
+ (valOf $ D.get labels startL, valOf $ D.get labels endL)
+
+ val s = if s < startLoop then s else startLoop
+ val e = if e > endLoop then e else endLoop
+ in
+ (s, e)
+ end
+ | _ => (s, e)
end
fun getBasicInt [] _ = raise Unreachable
@@ -194,33 +235,629 @@ functor Emit(I: IL) = struct
val (iStart, iEnd) =
if var < localBound then
- computeIntLocal (iStart, iEnd) ops labels
+ computeIntLocal (iStart, iEnd) (List.last defs) ops labels
else
(iStart, iEnd)
in
(var, iStart, iEnd)
end
- fun computeInts (F as I.Fi { vregs, ... }) vars =
- List.map (computeInt F) vars
+ fun computeInts F vars = List.map (computeInt F) vars
fun printInts ints =
let
- val () = printfn `"\nintervals:\n" %
+ val () = printfn `"\nsorted intervals:\n" %
fun p (id, s, e) = printfn `"id: %" I id `" {" I s `", " I e `"}" %
in
List.app p ints
end
+ fun updAff arr idx aff =
+ let
+ val (_, vt) = Array.sub (arr, idx)
+ in
+ Array.update (arr, idx, (aff, vt))
+ end
+
+ datatype insAff = IaNone | IaHard of (int * reg) list
+ | IaSoft of int * int list
+
+ fun parNum2reg pr =
+ case pr of
+ 0 => Rdi
+ | 1 => Rsi
+ | 2 => Rdx
+ | 3 => Rcx
+ | 5 => R8
+ | 6 => R9
+ | _ => raise Unreachable
+
+ fun getInsAff (SOME ins) =
+ let
+ fun tr (rd, rs1, rs2) = IaSoft (rd, [rs1, rs2])
+
+ fun setAff (rd, I.SaVReg rs) = IaSoft (rd, [rs])
+ | setAff _ = IaNone
+
+ fun fcallAff args =
+ let
+ fun collect idx (arg :: args) acc =
+ collect (idx + 1) args ((arg, parNum2reg idx) :: acc)
+ | collect _ [] acc = rev acc
+ in
+ IaHard $ collect 0 args []
+ end
+ in
+ case ins of
+ I.IrSet p => setAff p
+ | I.IrAdd t => tr t
+ | I.IrSub t => tr t
+ | I.IrMul t => tr t
+ | I.IrIMul t => tr t
+ | I.IrDiv t => tr t
+ | I.IrIDiv t => tr t
+ | I.IrMod t => tr t
+ | I.IrIMod t => tr t
+ | I.IrShl t => tr t
+ | I.IrShr t => tr t
+ | I.IrSar t => tr t
+ | I.IrAnd t => tr t
+ | I.IrOr t => tr t
+ | I.IrXor t => tr t
+ | I.IrEq t => tr t
+ | I.IrNeq t => tr t
+ | I.IrCmpul t => tr t
+ | I.IrCmpug t => tr t
+ | I.IrCmpule t => tr t
+ | I.IrCmpuge t => tr t
+ | I.IrCmpsl t => tr t
+ | I.IrCmpsg t => tr t
+ | I.IrCmpsle t => tr t
+ | I.IrCmpsge t => tr t
+
+ | I.IrExtZero _ | I.IrExtSign _
+ | I.IrLoad _ | I.IrStore _ | I.IrJmp _
+ | I.IrJz _ | I.IrJnz _ | I.IrNopLabel _
+ | I.IrNop _ | I.IrRet _ | I.IrAlloc _
+ | I.IrCopy _ => IaNone
+ | I.IrFcall (_, _, args) => fcallAff args
+ end
+ | getInsAff NONE = IaNone
+
+ fun updateSoftAff rinfo rd rss =
+ let
+ fun sort [r] = [r]
+ | sort [rs1, rs2] = if rs1 < rs2 then [rs1, rs2] else [rs2, rs1]
+ | sort _ = raise Unreachable
+
+ fun isNotConst rv =
+ let
+ val (_, vt) = Array.sub (rinfo, rv)
+ in
+ case vt of
+ VtConst _ => false
+ | _ => true
+ end
+
+ val (aff, vt) = Array.sub (rinfo, rd)
+ val rss = List.filter isNotConst $ sort rss
+
+ fun insertSorted ins [] = ins
+ | insertSorted [] acc = acc
+ | insertSorted (x :: xs) (y :: ys) =
+ if x < y then
+ x :: insertSorted xs (y :: ys)
+ else
+ y :: insertSorted (x :: xs) ys
+
+ val aff =
+ case aff of
+ AfUnk => AfSoft rss
+ | AfSoft affs => AfSoft $ insertSorted rss affs
+ | AfHard _ => aff
+ in
+ Array.update (rinfo, rd, (aff, vt))
+ end
+
+ fun updateHardAff rinfo hards =
+ let
+ fun f (rd, reg) =
+ let
+ val (aff, vt) = Array.sub (rinfo, rd)
+
+ val aff =
+ case aff of
+ AfUnk | AfSoft _ => AfHard reg
+ | AfHard _ => raise Unreachable
+ in
+ Array.update (rinfo, rd, (aff, vt))
+ end
+ in
+ List.app f hards
+ end
+
+ fun compAffinity rinfo ops paramNum =
+ let
+ fun compParams idx =
+ if idx = paramNum then
+ ()
+ else
+ let
+ val reg = parNum2reg idx
+ in
+ updAff rinfo idx (AfHard reg);
+ compParams (idx + 1)
+ end
+ val () = compParams 0
+
+ fun loop idx =
+ if idx = D.length ops then
+ ()
+ else
+ let
+ val (ins, _) = D.get ops idx
+ in
+ case getInsAff ins of
+ IaNone => ()
+ | IaSoft (rd, rss) => updateSoftAff rinfo rd rss
+ | IaHard hards => updateHardAff rinfo hards;
+ loop (idx + 1)
+ end
+ in
+ loop 0
+ end
+
+ fun prepareRegInfo paramNum ops vregs =
+ let
+ val rinfo = Array.array (D.length vregs, (AfUnk, VtUnk))
+
+ fun transfer idx acc =
+ if idx = D.length vregs then
+ rev acc
+ else
+ let
+ val (vt, cand) =
+ case #t $ D.get vregs idx of
+ I.RtRem => (VtUnk, NONE)
+ | I.RtConst w => (VtConst (VConst w), NONE)
+ | I.RtAddrConst (id, w) => (VtConst (VAddrConst (id, w)), NONE)
+ | I.RtReg => (VtUnk, SOME idx)
+ in
+ Array.update (rinfo, idx, (AfUnk, vt));
+ transfer (idx + 1) (if isSome cand then valOf cand :: acc else acc)
+ end
+
+ val toAlloc = transfer 0 []
+ in
+ compAffinity rinfo ops paramNum;
+ (toAlloc, rinfo)
+ end
+
+ fun preg reg out =
+ let
+ val s =
+ case reg of
+ Rax => "rax"
+ | Rbx => "rbx"
+ | Rcx => "rcx"
+ | Rdx => "rdx"
+ | Rsi => "rsi"
+ | Rdi => "rdi"
+ | Rbp => "rbp"
+ | Rsp => "rsp"
+ | R8 => "r8"
+ | R9 => "r9"
+ | R10 => "r10"
+ | R11 => "r11"
+ | R12 => "r12"
+ | R13 => "r13"
+ | R14 => "r14"
+ | R15 => "r15"
+ in
+ Printf out `s %
+ end
+
+ val Preg = fn z => bind A1 preg z
+
+ fun affPrint rinfo =
+ let
+ fun pv idx out = Printf out `"%" I idx %
+
+ fun p (idx, (aff, _)) =
+ let
+ val () = printf `"%" I idx %
+ in
+ case aff of
+ AfUnk => printfn `" = unk" %
+ | AfHard reg => printfn `" <- " Preg reg %
+ | AfSoft rss => printfn `" <- " Plist pv rss (", ", true, 1) %
+ end
+ in
+ Array.appi p rinfo
+ end
+
+ fun sort _ [] = []
+ | sort _ [x] = [x]
+ | sort le l =
+ let
+ fun divide [] accp = accp
+ | divide [x] (acc1, acc2) = (x :: acc1, acc2)
+ | divide (x :: y :: tail) (acc1, acc2) =
+ divide tail (x :: acc1, y :: acc2)
+ val (part1, part2) = divide l ([], [])
+ val part1 = sort le part1
+ val part2 = sort le part2
+
+ fun merge [] [] acc = acc
+ | merge [] ys acc = rev $ List.revAppend (ys, acc)
+ | merge xs [] acc = rev $ List.revAppend (xs, acc)
+ | merge (x :: xs) (y :: ys) acc =
+ if le (x, y) then
+ merge xs (y :: ys) (x :: acc)
+ else
+ merge (x :: xs) ys (y :: acc)
+ in
+ merge part1 part2 []
+ end
+
+ fun updateI i = fn z =>
+ let
+ fun from rinfo active pool intervals stackCand =
+ { rinfo, active, pool, intervals, stackCand }
+ fun to f { rinfo, active, pool, intervals, stackCand } =
+ f rinfo active pool intervals stackCand
+ in
+ FRU.makeUpdate5 (from, from, to) i
+ end z
+
+ fun returnToPool pool reg =
+ let
+ val idx = reg2idx reg - firstUsedReg
+ in
+ Array.update (pool, idx, NONE)
+ end
+
+ fun expireOne { rinfo, active, pool, ... } (_, start, _) =
+ case !active of
+ [] => false
+ | (j, startp, endp) :: acts =>
+ if endp > start then
+ false
+ else
+ let
+ val (_, vt) = Array.sub (rinfo, j)
+ val reg = case vt of VtReg reg => reg | _ => raise Unreachable
+
+ val () = printfn `"III!!! interval %"
+ ip j `"(" ip startp `", " ip endp `") "
+ `"with " Preg reg `" has expired" %
+ in
+ returnToPool pool reg;
+ active := acts;
+ true
+ end
+
+ fun expireOld (I as { active, ... }) int =
+ let
+ fun loop I =
+ case expireOne I int of
+ false => ()
+ | true => loop I
+ in
+ case !active of
+ [] => ()
+ | _ => loop I
+ end
+
+ fun addToActive int [] = [int]
+ | addToActive (I as (_, _, e1)) (act :: acts) =
+ if e1 < #3 act then
+ (I :: act :: acts)
+ else
+ act :: addToActive I acts
+
+ fun updReg arr idx reg =
+ let
+ val (aff, _) = Array.sub (arr, idx)
+ in
+ Array.update (arr, idx, (aff, reg))
+ end
+
+ fun assignFirstReg poff { rinfo, pool, ... } vr =
+ let
+ fun loop idx =
+ if idx = Array.length pool then
+ raise Unreachable
+ else
+ let
+ val user = Array.sub (pool, idx)
+ in
+ case user of
+ SOME _ => loop (idx + 1)
+ | NONE =>
+ let
+ val () = Array.update (pool, idx, SOME vr);
+ val reg = idx2reg (firstUsedReg + idx)
+
+ val () = printfn R poff
+ `"assigned (first) reg " Preg reg `" to %" ip vr %
+ in
+ updReg rinfo vr (VtReg reg)
+ end
+ end
+ in
+ loop 0
+ end
+
+ fun freeRegList pool =
+ let
+ fun loop idx acc =
+ if idx = Array.length pool then
+ rev acc
+ else
+ case Array.sub (pool, idx) of
+ NONE => loop (idx + 1) (idx2reg (idx + firstUsedReg) :: acc)
+ | SOME _ => loop (idx + 1) acc
+ in
+ loop 0 []
+ end
+
+ fun getAffRegList rinfo affs =
+ let
+ fun loop [] acc = rev acc
+ | loop (vr :: vrs) acc =
+ let
+ val (_, vt) = Array.sub (rinfo, vr)
+ in
+ case vt of
+ VtReg r => loop vrs (r :: acc)
+ | _ => loop vrs acc
+ end
+ in
+ loop affs []
+ end
+
+ fun findCommonRegs l1 l2 =
+ let
+ val l1 = sort (fn (r1, r2) => reg2idx r1 <= reg2idx r2) l1
+ val l2 = sort (fn (r1, r2) => reg2idx r1 <= reg2idx r2) l2
+
+ fun intersection [] _ = []
+ | intersection _ [] = []
+ | intersection (x :: xs) (y :: ys) =
+ case Int.compare (reg2idx x, reg2idx y) of
+ LESS => intersection xs (y :: ys)
+ | EQUAL => x :: intersection xs ys
+ | GREATER => intersection (x :: xs) ys
+ in
+ intersection l1 l2
+ end
+
+ fun assignSoftReg poff affs (I as { rinfo, pool, ... }) vr =
+ let
+ val () = printfn R poff
+ `"trying to assign register (by affinity) to %" ip vr %
+
+ val regs = freeRegList pool
+ val affRegs = getAffRegList rinfo affs
+ val common = findCommonRegs regs affRegs
+
+ val () = printfn R (poff + 1)
+ `"free registers: " Plist preg regs (", ", true, 0) %
+ val () = printfn R (poff + 1)
+ `"affinity registers: " Plist preg affRegs (", ", true, 0) %
+ in
+ case common of
+ [] =>
+ let
+ val () = printfn R (poff + 1) `"affinity was not satisfied" %
+ in
+ assignFirstReg (poff + 2) I vr
+ end
+ | (reg :: _) =>
+ let
+ in
+ updReg rinfo vr (VtReg reg);
+ Array.update (pool, reg2idx reg - firstUsedReg, SOME vr);
+ printfn R (poff + 1)
+ `"assigned (by affinity) reg " Preg reg `" to %" ip vr %;
+ printfn R (poff + 1)
+ `"free registers: " Plist preg (freeRegList pool) (", ", true, 0) %
+ end
+ end
+
+ fun putToStack poff { rinfo, stackCand, ... } vr =
+ let
+ val () = printfn R poff
+ `"puting %" ip vr `" to stack: " ip (!stackCand) %
+ in
+ updReg rinfo vr (VtStack (!stackCand));
+ stackCand := !stackCand - 8
+ end
+
+ fun assignHardReg poff (I as { rinfo, pool, ... }) vr reg =
+ let
+ val () = printfn R poff
+ `"trying to assign hard reg " A1 preg reg `" to %" ip vr %
+
+ val regIdx = reg2idx reg - firstUsedReg
+ val user = Array.sub (pool, regIdx)
+
+ fun setOurReg () =
+ let
+ val () = printfn R (poff + 1) `"reg assigned" %
+ in
+ Array.update (pool, regIdx, SOME vr);
+ updReg rinfo vr (VtReg reg)
+ end
+ in
+ case user of
+ NONE => setOurReg ()
+ | SOME u =>
+ let
+ val () = printfn R (poff + 1) `"reg is taken by %" ip u %
+ val (aff, _) = Array.sub (rinfo, u)
+ in
+ case aff of
+ AfHard _ => raise Unreachable
+ | AfSoft affs => assignSoftReg poff affs I u
+ | AfUnk => assignFirstReg poff I u
+ ;
+ setOurReg ()
+ end
+ end
+
+ fun assignReg (I as { rinfo, ... }) (vr, _, _) =
+ let
+ val (aff, _) = Array.sub (rinfo, vr)
+ in
+ case aff of
+ AfUnk => assignFirstReg 0 I vr
+ | AfSoft affs => assignSoftReg 0 affs I vr
+ | AfHard reg => assignHardReg 0 I vr reg
+ end
+
+ fun getPool () = Array.array (usedRegNum, NONE)
+
+ fun changeInActive active newInt oldVr =
+ let
+ val a = !active
+ val a = List.filter (fn (v, _, _) => v <> oldVr) a
+ in
+ active := addToActive newInt a
+ end
+
+ fun expropriateReg (I as { rinfo, pool, active, ... }) int reg =
+ let
+ val vr = #1 int
+
+ val regIdx = reg2idx reg - firstUsedReg
+ val u = valOf $ Array.sub (pool, regIdx)
+
+ val (uAff, _) = Array.sub (rinfo, u)
+ val () =
+ case uAff of
+ AfHard _ => raise Unreachable
+ | _ => ()
+
+ val () = putToStack 1 I u
+ val () = Array.update (pool, regIdx, SOME vr)
+ val () = updReg rinfo vr (VtReg reg)
+ in
+ changeInActive active int u
+ end
+
+ fun userIdx pool vr =
+ let
+ fun loop idx =
+ if idx = Array.length pool then
+ raise Unreachable
+ else
+ case Array.sub (pool, idx) of
+ SOME u =>
+ if u = vr then
+ idx
+ else
+ loop (idx + 1)
+ | NONE => loop (idx + 1)
+ in
+ loop 0
+ end
+
+ fun spillAtInterval (I as { rinfo, active, pool, ... }) int =
+ let
+ val spill = List.last (!active)
+
+ val vr = #1 int
+ val (ourAff, _) = Array.sub (rinfo, vr)
+
+ fun isNotHard vr =
+ case #1 $ Array.sub (rinfo, vr) of
+ AfHard _ => false
+ | _ => true
+
+ val () = printfn `"SpilAtInt" %
+ val () = printfn R 0
+ `"free registers: " Plist preg (freeRegList pool) (", ", true, 0) %
+ in
+ case ourAff of
+ AfHard reg => expropriateReg I int reg
+ | _ =>
+ if #3 spill > #3 int andalso isNotHard (#1 spill) then
+ let
+ val idx = userIdx pool (#1 spill)
+ val () = printfn `"spilling!!!" %
+ in
+ Array.update (pool, idx, SOME vr);
+ updReg rinfo vr (VtReg $ idx2reg (idx + firstUsedReg));
+ putToStack 1 I (#1 spill);
+ changeInActive active int (#1 spill)
+ end
+ else
+ putToStack 0 I vr
+ end
+
+ fun linearscan rinfo ints =
+ let
+ fun incStart ((_, start1, _), (_, start2, _)) = start1 <= start2
+ val ints = sort incStart ints
+
+ val () = printInts ints
+
+ fun loop _ [] = ()
+ | loop (I as { active, ... }) (int :: ints) =
+ let
+ val () = printfn `"\n\ninspectiing interval "
+ ip (#1 int) `": (" ip (#2 int) `", " ip (#3 int) `")" %
+
+ val () = expireOld I int
+
+ val () =
+ if length (!active) = usedRegNum then
+ spillAtInterval I int
+ else
+ let
+ val () = assignReg I int
+ in
+ active := addToActive int (!active)
+ end
+ in
+ loop I ints
+ end
+ in
+ loop { active = ref [], pool = getPool (), rinfo,
+ stackCand = ref (~8) } ints
+ end
- fun regAlloc (F as I.Fi { vregs, labels, ... }) =
+ fun printAllocVar rinfo v =
let
- val varsForAlloc = getVarsForAlloc vregs
- val () = printfn `"for alloc: " Plist i varsForAlloc (", ", true, 0) %
+ val () = printf `"%" I v `": " %
+ val (_, vt) = Array.sub (rinfo, v)
+ in
+ case vt of
+ VtStack off => printfn `"stack " I off %
+ | VtReg reg => printfn `"reg " A1 preg reg %
+ | VtConst _ | VtUnk => raise Unreachable
+ end
+
+ fun printAlloced rinfo toAlloc =
+ let
+ val () = printfn `"\nallocated:\n" %
+ in
+ List.app (printAllocVar rinfo) toAlloc
+ end
+
+ fun regAlloc (F as I.Fi { vregs, ops, paramNum, ... }) =
+ let
+ val (toAlloc, regInfo) = prepareRegInfo paramNum ops vregs
+ val () = printfn `"for alloc: " Plist i toAlloc (", ", true, 0) %
+
+ val () = affPrint regInfo
- val intervals = computeInts F varsForAlloc
+ val intervals = computeInts F toAlloc
- val () = printInts intervals
+ val () = linearscan regInfo intervals
+ val () = printAlloced regInfo toAlloc
in
raise Unimplemented
end