(* Emit: writes NASM-style assembly (global/extern, section .bss/.data,
   align, resb, db/dw/dd/dq directives) for an IL program context, and hosts
   a work-in-progress linear-scan register allocator. Function bodies are
   not emitted yet: [emitFunc] runs the allocator (which prints debug
   output) and then raises [Unimplemented]. *)
functor Emit(I: IL) = struct
  structure I = I
  structure P = I.P
  structure D = P.D
  structure PP = P.P

  (* The integer printf directive, captured here because several functions
     below rebind the value identifier [I] to an allocator-state record. *)
  val ip = I

  (* Output stream used by [fprint]; set by [openFile]. *)
  val file = ref NONE

  datatype reg = Rax | Rbx | Rcx | Rdx | Rsi | Rdi | Rbp | Rsp
               | R8 | R9 | R10 | R11 | R12 | R13 | R14 | R15

  datatype vConst = VConst of word | VAddrConst of int * word

  (* Where a virtual register lives after allocation. *)
  datatype vrType = VtConst of vConst | VtReg of reg | VtStack of int | VtUnk

  (* Allocation preference: a required register, a list of vregs whose
     register it would like to share, or no preference. *)
  datatype affinity = AfHard of reg | AfSoft of int list | AfUnk

  (* Register <-> index table. Indices 0 .. usedRegNum-1 form the allocation
     pool. Of those, indices below callerSavedRegs are excluded by
     getRegsToSave. Rax, Rdx, Rsp and Rbp (12..15) are outside the pool.
     NOTE(review): parNum2reg maps argument 2 to Rdx, which is outside the
     pool, so a hard affinity to Rdx would make Array.sub (pool, 13) raise
     Subscript in assignHardReg / expropriateReg. Confirm intended handling. *)
  val regs = [
    (Rcx, 0), (Rsi, 1), (Rdi, 2), (R8, 3), (R9, 4), (R10, 5), (R11, 6),
    (Rbx, 7), (R12, 8), (R13, 9), (R14, 10), (R15, 11),
    (Rax, 12), (Rdx, 13), (Rsp, 14), (Rbp, 15)
  ]

  val callerSavedRegs = 7
  val usedRegNum = 12

  (* Index of [reg] in [regs]. *)
  fun reg2idx reg =
    case List.find (fn (r, _) => r = reg) regs of
      NONE => raise Unreachable
    | SOME (_, idx) => idx

  (* Register with index [idx] in [regs]; raises Unreachable if absent. *)
  fun idx2reg idx =
    case List.find (fn (_, i) => i = idx) regs of
      NONE => raise Unreachable
    | SOME (r, _) => r

  local
    fun output s =
      let
        val outstream = valOf $ !file
      in
        TextIO.output (outstream, s)
      end
    val ctx = ((false, makePrintfBase output),
               fn (_: bool * ((string -> unit) * (unit -> unit))) => ())
  in
    (* printf-style output to the stream held in [file]. *)
    fun fprint g = Fold.fold ctx g
  end

  (* Like fprint, prefixed with a tab. *)
  fun fprintt g = fprint `"\t" g
  (* Like fprintt, followed by a newline. *)
  fun fprinttn g = fprintt (fn (a, _) => g (a, fn (_, out) => (Printf out `"\n" %)))

  (* Emits the .bss section: each zero-initialized object gets an align
     directive and a resb reservation of its type size. *)
  fun handleBSS objsZI =
    let
      val () = fprint `"section .bss\n" %
      fun handleObj (id, _, t, _, _) =
        let
          val align = P.alignOfType t
          val size = P.sizeOfType t
        in
          fprinttn `"align\t" W align %;
          fprint PP.? id `":\tresb " W size `"\n" %
        end
    in
      List.app handleObj objsZI
    end

  (* Emits one data directive for [w] sized 1, 2, 4 or 8 bytes. *)
  fun dd size w =
    let
      val cmd =
        case size of
          0w1 => "db"
        | 0w2 => "dw"
        | 0w4 => "dd"
        | 0w8 => "dq"
        | _ => raise Unreachable
    in
      fprint `cmd `" " W w %
    end

  (* Emits the scalar layout of initializer [id] from P.iniLayouts. A
     resb filler is emitted between scalars, and after the last one up to
     the aggregate size, whenever a gap is positive. *)
  fun emitAggrLayout id =
    let
      val (_, size, layout) = D.get P.iniLayouts id
      val () = fprint `"\n" %
      fun getPadding offset t [] = size - (offset + P.sizeOfType t)
        | getPadding offset t ({ offset = offset', ... } :: _) =
            offset' - (offset + P.sizeOfType t)
      fun emitScalars ({ offset, t, value } :: tail) =
            let
              val () = fprint `"\t" %
              val () = dd (P.sizeOfType t) value
              val padding = getPadding offset t tail
            in
              (if padding > 0w0 then
                 fprint `"\n\tresb " W padding `"\n" %
               else
                 fprint `"\n" %);
              emitScalars tail
            end
        | emitScalars [] = ()
    in
      emitScalars layout
    end

  (* Emits the .data section. Every object is expected to carry a layout
     initializer; an expression initializer raises Unreachable. *)
  fun handleData objs =
    let
      val () = fprint `"section .data\n" %
      fun emitLayout (id, _, t, ini, _) =
        let
          val align = P.alignOfType t
          val () = fprinttn `"align\t" W align %
          val () = fprint PP.? id `":" %
        in
          case ini of
            P.CiniLayout id => emitAggrLayout id
          | P.CiniExpr _ => raise Unreachable
        end
    in
      List.app emitLayout objs
    end

  (* Emits each string literal as a zero-terminated db under label .S<id>. *)
  fun handleStrlits strlits =
    let
      fun f id = fprint `".S" I id `":\tdb " `(PP.?? id) `", 0\n" %
    in
      fprint `"\n" %;
      List.app f strlits
    end

  (* Emits every initializer layout whose first flag is false under label
     .I<n>, aligned to 16. *)
  fun handleLocalIniLayouts () =
    let
      fun f (_, (true, _, _)) = ()
        | f (n, (false, _, _)) =
            ( fprint `"\talign 16\n" %;
              fprint `".I" I n `":" %;
              emitAggrLayout n )
    in
      D.appi f P.iniLayouts
    end

  (* Scans ops from iEnd onward. Each jump whose target label lies strictly
     inside (iStart, current end) moves the end to that jump's index. *)
  fun extendEnd (iStart, iEnd) ops labels =
    let
      fun loop idx iEnd =
        if idx = D.length ops then
          iEnd
        else
          let
            val (ins, _) = D.get ops idx
          in
            case ins of
              SOME (I.IrJmp lid)
            | SOME (I.IrJz (_, lid))
            | SOME (I.IrJnz (_, lid)) =>
                let
                  val ldest = valOf $ D.get labels lid
                  val iEnd = if ldest > iStart andalso ldest < iEnd then idx else iEnd
                in
                  loop (idx + 1) iEnd
                end
            | _ => loop (idx + 1) iEnd
          end
    in
      loop iEnd iEnd
    end

  (* Interval of a local vreg. It applies extendEnd and then, if the op of
     the first definition carries a (start label, end label) pair, widens
     the interval to cover that label range. *)
  fun computeIntLocal (s, e) firstDef ops labels =
    let
      val e = extendEnd (s, e) ops labels
      val (_, li) = D.get ops firstDef
    in
      case li of
        SOME (startL, endL) =>
          let
            val (startLoop, endLoop) =
              (valOf $ D.get labels startL, valOf $ D.get labels endL)
            val s = if s < startLoop then s else startLoop
            val e = if e > endLoop then e else endLoop
          in
            (s, e)
          end
      | _ => (s, e)
    end

  (* Basic live interval from def and use positions. The lists are treated
     as newest-first (hd = last, List.last = first); presumably that is how
     IL builds them. A vreg with no defs raises Unreachable. *)
  fun getBasicInt [] _ = raise Unreachable
    | getBasicInt defs [] = (List.last defs, hd defs + 1)
    | getBasicInt defs use =
        let
          val (firstDef, lastDef) = (List.last defs, hd defs)
          val (firstUse, lastUse) = (List.last use, hd use)
          val first = if firstDef < firstUse then firstDef else firstUse - 1
          val last = if lastDef < lastUse then lastUse else lastDef + 1
        in
          (first, last)
        end

  (* (var, start, end) interval for vreg [var]. Vregs below localBound get
     the local-variable treatment from computeIntLocal. *)
  fun computeInt (I.Fi { vregs, ops, localBound, labels, ... }) var =
    let
      val { defs, use, ... } = D.get vregs var
      val (iStart, iEnd) = getBasicInt defs use
      val (iStart, iEnd) =
        if var < localBound then
          computeIntLocal (iStart, iEnd) (List.last defs) ops labels
        else
          (iStart, iEnd)
    in
      (var, iStart, iEnd)
    end

  fun computeInts F vars = List.map (computeInt F) vars

  fun printInts ints =
    let
      val () = printfn `"\nsorted intervals:\n" %
      fun p (id, s, e) = printfn `"id: %" I id `" {" I s `", " I e `"}" %
    in
      List.app p ints
    end

  (* Replaces the affinity of entry [idx], keeping its vrType. *)
  fun updAff arr idx aff =
    let
      val (_, vt) = Array.sub (arr, idx)
    in
      Array.update (arr, idx, (aff, vt))
    end

  datatype insAff = IaNone | IaHard of (int * reg) list | IaSoft of int * int list

  (* Register holding the [pr]-th (0-based) integer argument in the System V
     AMD64 order rdi, rsi, rdx, rcx, r8, r9. Indices 6 and above raise
     Unreachable. *)
  fun parNum2reg pr =
    case pr of
      0 => Rdi
    | 1 => Rsi
    | 2 => Rdx
    | 3 => Rcx
    | 4 => R8
    | 5 => R9
    | _ => raise Unreachable

  (* Affinity implied by a single instruction:
     - three-operand ops and register sets: the destination prefers the
       source registers;
     - calls: each argument vreg is pinned to its argument register;
     - everything else: none. *)
  fun getInsAff (SOME ins) =
        let
          fun tr (rd, rs1, rs2) = IaSoft (rd, [rs1, rs2])
          fun setAff (rd, I.SaVReg rs) = IaSoft (rd, [rs])
            | setAff _ = IaNone
          fun fcallAff args =
            let
              fun collect idx (arg :: args) acc =
                    collect (idx + 1) args ((arg, parNum2reg idx) :: acc)
                | collect _ [] acc = rev acc
            in
              IaHard $ collect 0 args []
            end
        in
          case ins of
            I.IrSet p => setAff p
          | I.IrAdd t => tr t
          | I.IrSub t => tr t
          | I.IrMul t => tr t
          | I.IrIMul t => tr t
          | I.IrDiv t => tr t
          | I.IrIDiv t => tr t
          | I.IrMod t => tr t
          | I.IrIMod t => tr t
          | I.IrShl t => tr t
          | I.IrShr t => tr t
          | I.IrSar t => tr t
          | I.IrAnd t => tr t
          | I.IrOr t => tr t
          | I.IrXor t => tr t
          | I.IrEq t => tr t
          | I.IrNeq t => tr t
          | I.IrCmpul t => tr t
          | I.IrCmpug t => tr t
          | I.IrCmpule t => tr t
          | I.IrCmpuge t => tr t
          | I.IrCmpsl t => tr t
          | I.IrCmpsg t => tr t
          | I.IrCmpsle t => tr t
          | I.IrCmpsge t => tr t
          | I.IrExtZero _ | I.IrExtSign _ | I.IrLoad _ | I.IrStore _
          | I.IrJmp _ | I.IrJz _ | I.IrJnz _ | I.IrNopLabel _ | I.IrNop _
          | I.IrRet _ | I.IrAlloc _ | I.IrCopy _ => IaNone
          | I.IrFcall (_, _, args) => fcallAff args
        end
    | getInsAff NONE = IaNone

  (* Merges the non-constant vregs of [rss] (one or two entries) into the
     soft affinity list of [rd], keeping the list sorted. A hard affinity is
     left untouched. *)
  fun updateSoftAff rinfo rd rss =
    let
      fun sort [r] = [r]
        | sort [rs1, rs2] = if rs1 < rs2 then [rs1, rs2] else [rs2, rs1]
        | sort _ = raise Unreachable
      fun isNotConst rv =
        let
          val (_, vt) = Array.sub (rinfo, rv)
        in
          case vt of
            VtConst _ => false
          | _ => true
        end
      val (aff, vt) = Array.sub (rinfo, rd)
      val rss = List.filter isNotConst $ sort rss
      fun insertSorted ins [] = ins
        | insertSorted [] acc = acc
        | insertSorted (x :: xs) (y :: ys) =
            if x < y then x :: insertSorted xs (y :: ys)
            else y :: insertSorted (x :: xs) ys
      val aff =
        case aff of
          AfUnk => AfSoft rss
        | AfSoft affs => AfSoft $ insertSorted rss affs
        | AfHard _ => aff
    in
      Array.update (rinfo, rd, (aff, vt))
    end

  (* Pins each (vreg, reg) pair to its register.
     NOTE(review): a vreg that already has a hard affinity raises
     Unreachable. That includes a parameter (pinned by compAffinity) passed
     directly as a call argument, or a vreg passed to two calls. Confirm the
     IL guarantees this cannot happen. *)
  fun updateHardAff rinfo hards =
    let
      fun f (rd, reg) =
        let
          val (aff, vt) = Array.sub (rinfo, rd)
          val aff =
            case aff of
              AfUnk | AfSoft _ => AfHard reg
            | AfHard _ => raise Unreachable
        in
          Array.update (rinfo, rd, (aff, vt))
        end
    in
      List.app f hards
    end

  (* Pins vregs 0 .. paramNum-1 to their argument registers, then
     accumulates per-instruction affinities over all ops. *)
  fun compAffinity rinfo ops paramNum =
    let
      fun compParams idx =
        if idx = paramNum then ()
        else
          ( updAff rinfo idx (AfHard (parNum2reg idx));
            compParams (idx + 1) )
      val () = compParams 0
      fun loop idx =
        if idx = D.length ops then ()
        else
          let
            val (ins, _) = D.get ops idx
          in
            (case getInsAff ins of
               IaNone => ()
             | IaSoft (rd, rss) => updateSoftAff rinfo rd rss
             | IaHard hards => updateHardAff rinfo hards);
            loop (idx + 1)
          end
    in
      loop 0
    end

  (* Builds the per-vreg (affinity, vrType) array. Returns the indices of
     RtReg vregs in ascending order (these need allocation) together with
     that array. *)
  fun prepareRegInfo paramNum ops vregs =
    let
      val rinfo = Array.array (D.length vregs, (AfUnk, VtUnk))
      fun transfer idx acc =
        if idx = D.length vregs then rev acc
        else
          let
            val (vt, cand) =
              case #t $ D.get vregs idx of
                I.RtRem => (VtUnk, NONE)
              | I.RtConst w => (VtConst (VConst w), NONE)
              | I.RtAddrConst (id, w) => (VtConst (VAddrConst (id, w)), NONE)
              | I.RtReg => (VtUnk, SOME idx)
          in
            Array.update (rinfo, idx, (AfUnk, vt));
            transfer (idx + 1) (if isSome cand then valOf cand :: acc else acc)
          end
      val toAlloc = transfer 0 []
    in
      compAffinity rinfo ops paramNum;
      (toAlloc, rinfo)
    end

  fun preg reg out =
    let
      val s =
        case reg of
          Rax => "rax"
        | Rbx => "rbx"
        | Rcx => "rcx"
        | Rdx => "rdx"
        | Rsi => "rsi"
        | Rdi => "rdi"
        | Rbp => "rbp"
        | Rsp => "rsp"
        | R8 => "r8"
        | R9 => "r9"
        | R10 => "r10"
        | R11 => "r11"
        | R12 => "r12"
        | R13 => "r13"
        | R14 => "r14"
        | R15 => "r15"
    in
      Printf out `s %
    end

  val Preg = fn z => bind A1 preg z

  fun affPrint rinfo =
    let
      fun pv idx out = Printf out `"%" I idx %
      fun p (idx, (aff, _)) =
        let
          val () = printf `"%" I idx %
        in
          case aff of
            AfUnk => printfn `" = unk" %
          | AfHard reg => printfn `" <- " Preg reg %
          | AfSoft rss => printfn `" <- " Plist pv rss (", ", true, 1) %
        end
    in
      Array.appi p rinfo
    end

  (* Merge sort: [le] is a less-or-equal predicate on pairs. *)
  fun sort _ [] = []
    | sort _ [x] = [x]
    | sort le l =
        let
          fun divide [] accp = accp
            | divide [x] (acc1, acc2) = (x :: acc1, acc2)
            | divide (x :: y :: tail) (acc1, acc2) = divide tail (x :: acc1, y :: acc2)
          val (part1, part2) = divide l ([], [])
          val part1 = sort le part1
          val part2 = sort le part2
          (* acc holds the merged prefix in reverse; revAppend restores the
             order and appends the remaining tail. *)
          fun merge [] ys acc = List.revAppend (acc, ys)
            | merge xs [] acc = List.revAppend (acc, xs)
            | merge (x :: xs) (y :: ys) acc =
                if le (x, y) then merge xs (y :: ys) (x :: acc)
                else merge (x :: xs) ys (y :: acc)
        in
          merge part1 part2 []
        end

  fun updateI i = fn z =>
    let
      fun from rinfo active pool intervals stackCand =
        { rinfo, active, pool, intervals, stackCand }
      fun to f { rinfo, active, pool, intervals, stackCand } =
        f rinfo active pool intervals stackCand
    in
      FRU.makeUpdate5 (from, from, to) i
    end z

  fun returnToPool pool reg =
    Array.update (pool, reg2idx reg, NONE)

  (* Expires the head of the active list (smallest end) if it ends no later
     than the start of the current interval. Returns true if it expired. *)
  fun expireOne { rinfo, active, pool, ... } (_, start, _) =
    case !active of
      [] => false
    | (j, startp, endp) :: acts =>
        if endp > start then false
        else
          let
            val (_, vt) = Array.sub (rinfo, j)
            val reg =
              case vt of
                VtReg reg => reg
              | _ => raise Unreachable
            val () = printfn `"III!!! interval %" ip j `"(" ip startp `", " ip endp `") " `"with " Preg reg `" has expired" %
          in
            returnToPool pool reg;
            active := acts;
            true
          end

  fun expireOld (I as { active, ... }) int =
    let
      fun loop I =
        case expireOne I int of
          false => ()
        | true => loop I
    in
      case !active of
        [] => ()
      | _ => loop I
    end

  (* Inserts an interval into a list kept sorted by increasing end point. *)
  fun addToActive int [] = [int]
    | addToActive (I as (_, _, e1)) (act :: acts) =
        if e1 < #3 act then I :: act :: acts
        else act :: addToActive I acts

  (* Replaces the vrType of entry [idx], keeping its affinity. *)
  fun updReg arr idx reg =
    let
      val (aff, _) = Array.sub (arr, idx)
    in
      Array.update (arr, idx, (aff, reg))
    end

  (* Gives [vr] the lowest-index free pool register; raises Unreachable if
     the pool is full. *)
  fun assignFirstReg poff { rinfo, pool, ... } vr =
    let
      fun loop idx =
        if idx = Array.length pool then raise Unreachable
        else
          case Array.sub (pool, idx) of
            SOME _ => loop (idx + 1)
          | NONE =>
              let
                val () = Array.update (pool, idx, SOME vr)
                val reg = idx2reg idx
                val () = printfn R poff `"assigned (first) reg " Preg reg `" to %" ip vr %
              in
                updReg rinfo vr (VtReg reg)
              end
    in
      loop 0
    end

  (* Free pool registers in index order. *)
  fun freeRegList pool =
    let
      fun loop idx acc =
        if idx = Array.length pool then rev acc
        else
          case Array.sub (pool, idx) of
            NONE => loop (idx + 1) (idx2reg idx :: acc)
          | SOME _ => loop (idx + 1) acc
    in
      loop 0 []
    end

  (* Registers currently recorded for the vregs in [affs]. *)
  fun getAffRegList rinfo affs =
    let
      fun loop [] acc = rev acc
        | loop (vr :: vrs) acc =
            let
              val (_, vt) = Array.sub (rinfo, vr)
            in
              case vt of
                VtReg r => loop vrs (r :: acc)
              | _ => loop vrs acc
            end
    in
      loop affs []
    end

  (* Registers present in both lists, ordered by pool index. *)
  fun findCommonRegs l1 l2 =
    let
      val l1 = sort (fn (r1, r2) => reg2idx r1 <= reg2idx r2) l1
      val l2 = sort (fn (r1, r2) => reg2idx r1 <= reg2idx r2) l2
      fun intersection [] _ = []
        | intersection _ [] = []
        | intersection (x :: xs) (y :: ys) =
            case Int.compare (reg2idx x, reg2idx y) of
              LESS => intersection xs (y :: ys)
            | EQUAL => x :: intersection xs ys
            | GREATER => intersection (x :: xs) ys
    in
      intersection l1 l2
    end

  (* Prefers a free register already used by one of the affinity vregs,
     falling back to assignFirstReg. *)
  fun assignSoftReg poff affs (I as { rinfo, pool, ... }) vr =
    let
      val () = printfn R poff `"trying to assign register (by affinity) to %" ip vr %
      val regs = freeRegList pool
      val affRegs = getAffRegList rinfo affs
      val common = findCommonRegs regs affRegs
      val () = printfn R (poff + 1) `"free registers: " Plist preg regs (", ", true, 0) %
      val () = printfn R (poff + 1) `"affinity registers: " Plist preg affRegs (", ", true, 0) %
    in
      case common of
        [] =>
          let
            val () = printfn R (poff + 1) `"affinity was not satisfied" %
          in
            assignFirstReg (poff + 2) I vr
          end
      | reg :: _ =>
          ( updReg rinfo vr (VtReg reg);
            Array.update (pool, reg2idx reg, SOME vr);
            printfn R (poff + 1) `"assigned (by affinity) reg " Preg reg `" to %" ip vr %;
            printfn R (poff + 1) `"free registers: " Plist preg (freeRegList pool) (", ", true, 0) % )
    end

  (* Assigns [vr] the next stack slot; slots descend by 8 from stackCand. *)
  fun putToStack poff { rinfo, stackCand, ... } vr =
    let
      val () = printfn R poff `"puting %" ip vr `" to stack: " ip (!stackCand) %
    in
      updReg rinfo vr (VtStack (!stackCand));
      stackCand := !stackCand - 8
    end

  (* Gives [vr] register [reg]. If another vreg holds it, that vreg is first
     moved to another register according to its own affinity. *)
  fun assignHardReg poff (I as { rinfo, pool, ... }) vr reg =
    let
      val () = printfn R poff `"trying to assign hard reg " A1 preg reg `" to %" ip vr %
      val regIdx = reg2idx reg
      val user = Array.sub (pool, regIdx)
      fun setOurReg () =
        let
          val () = printfn R (poff + 1) `"reg assigned" %
        in
          Array.update (pool, regIdx, SOME vr);
          updReg rinfo vr (VtReg reg)
        end
    in
      case user of
        NONE => setOurReg ()
      | SOME u =>
          let
            val () = printfn R (poff + 1) `"reg is taken by %" ip u %
            val (aff, _) = Array.sub (rinfo, u)
          in
            (case aff of
               AfHard _ => raise Unreachable
             | AfSoft affs => assignSoftReg poff affs I u
             | AfUnk => assignFirstReg poff I u);
            setOurReg ()
          end
    end

  fun assignReg (I as { rinfo, ... }) (vr, _, _) =
    let
      val (aff, _) = Array.sub (rinfo, vr)
    in
      case aff of
        AfUnk => assignFirstReg 0 I vr
      | AfSoft affs => assignSoftReg 0 affs I vr
      | AfHard reg => assignHardReg 0 I vr reg
    end

  fun getPool () = Array.array (usedRegNum, NONE)

  (* Removes oldVr's interval from the active list and inserts newInt. *)
  fun changeInActive active newInt oldVr =
    let
      val a = List.filter (fn (v, _, _) => v <> oldVr) (!active)
    in
      active := addToActive newInt a
    end

  (* Takes [reg] from its current holder, which goes to the stack, and gives
     it to the vreg of [int]. *)
  fun expropriateReg (I as { rinfo, pool, active, ... }) int reg =
    let
      val vr = #1 int
      val regIdx = reg2idx reg
      val u = valOf $ Array.sub (pool, regIdx)
      val (uAff, _) = Array.sub (rinfo, u)
      val () =
        case uAff of
          AfHard _ => raise Unreachable
        | _ => ()
      val () = putToStack 1 I u
      val () = Array.update (pool, regIdx, SOME vr)
      val () = updReg rinfo vr (VtReg reg)
    in
      changeInActive active int u
    end

  (* Pool index currently held by [vr]; raises Unreachable if none. *)
  fun userIdx pool vr =
    let
      fun loop idx =
        if idx = Array.length pool then raise Unreachable
        else
          case Array.sub (pool, idx) of
            SOME u => if u = vr then idx else loop (idx + 1)
          | NONE => loop (idx + 1)
    in
      loop 0
    end

  (* Called when every pool register is taken.
     - A hard-affinity interval takes its register from the current holder.
     - Otherwise the active interval with the largest end point is spilled
       if it outlives [int] and is not hard-pinned.
     - Failing both, [int] itself goes to the stack. *)
  fun spillAtInterval (I as { rinfo, active, pool, ... }) int =
    let
      val spill = List.last (!active)
      val vr = #1 int
      val (ourAff, _) = Array.sub (rinfo, vr)
      fun isNotHard vr =
        case #1 $ Array.sub (rinfo, vr) of
          AfHard _ => false
        | _ => true
      val () = printfn `"SpilAtInt" %
      val () = printfn R 0 `"free registers: " Plist preg (freeRegList pool) (", ", true, 0) %
    in
      case ourAff of
        AfHard reg => expropriateReg I int reg
      | _ =>
          if #3 spill > #3 int andalso isNotHard (#1 spill) then
            let
              val idx = userIdx pool (#1 spill)
              val () = printfn `"spilling!!!" %
            in
              Array.update (pool, idx, SOME vr);
              updReg rinfo vr (VtReg $ idx2reg idx);
              putToStack 1 I (#1 spill);
              changeInActive active int (#1 spill)
            end
          else
            putToStack 0 I vr
    end

  (* Linear-scan allocation over intervals sorted by start point. The active
     list is kept sorted by end point. Stack slots start at ~8. *)
  fun linearscan rinfo ints =
    let
      fun incStart ((_, start1, _), (_, start2, _)) = start1 <= start2
      val ints = sort incStart ints
      val () = printInts ints
      fun loop _ [] = ()
        | loop (I as { active, ... }) (int :: ints) =
            let
              val () = printfn `"\n\ninspectiing interval " ip (#1 int) `": (" ip (#2 int) `", " ip (#3 int) `")" %
              val () = expireOld I int
              val () =
                if length (!active) = usedRegNum then spillAtInterval I int
                else
                  ( assignReg I int;
                    active := addToActive int (!active) )
            in
              loop I ints
            end
    in
      loop { active = ref [], pool = getPool (), rinfo, stackCand = ref (~8) } ints
    end

  fun printAllocVar rinfo v =
    let
      val () = printf `"%" I v `": " %
      val (_, vt) = Array.sub (rinfo, v)
    in
      case vt of
        VtStack off => printfn `"stack " I off %
      | VtReg reg => printfn `"reg " A1 preg reg %
      | VtConst _ | VtUnk => raise Unreachable
    end

  fun printAlloced rinfo toAlloc =
    let
      val () = printfn `"\nallocated:\n" %
    in
      List.app (printAllocVar rinfo) toAlloc
    end

  (* Pool registers assigned to any vreg, in descending index order. *)
  fun getUsedRegs rinfo =
    let
      val regs = Array.array (usedRegNum, false)
      fun loop idx =
        if idx = Array.length rinfo then ()
        else
          let
            val (_, vt) = Array.sub (rinfo, idx)
          in
            (case vt of
               VtReg reg => Array.update (regs, reg2idx reg, true)
             | _ => ());
            loop (idx + 1)
          end
      val () = loop 0
      fun collect idx acc =
        if idx = usedRegNum then acc
        else if Array.sub (regs, idx) then collect (idx + 1) (idx2reg idx :: acc)
        else collect (idx + 1) acc
    in
      collect 0 []
    end

  (* Used registers whose pool index is at least callerSavedRegs. *)
  fun getRegsToSave rinfo =
    List.filter (fn r => reg2idx r >= callerSavedRegs) (getUsedRegs rinfo)

  (* [len] rows, each a distinct callerSavedRegs-wide array of NONE. *)
  fun initMap len =
    Array.tabulate (len, fn _ => Array.array (callerSavedRegs, NONE))

  (* Per-instruction map of which vreg occupies each pool register with index
     below callerSavedRegs. An interval (vr, startp, endp) fills rows
     startp+1 .. endp-1. *)
  fun computeMap len intervals rinfo =
    let
      val map = initMap len
      fun addInt (vr, startp, endp) =
        case #2 $ Array.sub (rinfo, vr) of
          VtReg reg =>
            let
              val regIdx = reg2idx reg
              (* >= rather than = so an empty range cannot run past endp. *)
              fun fill idx =
                if idx >= endp then ()
                else
                  ( Array.update (Array.sub (map, idx), regIdx, SOME vr);
                    fill (idx + 1) )
            in
              if regIdx >= callerSavedRegs then () else fill (startp + 1)
            end
        | _ => ()
    in
      List.app addInt intervals;
      map
    end

  fun printMap map =
    let
      val () = printfn `"Register map\n" %
      fun printHeader idx =
        if idx = callerSavedRegs then printf `"\n" %
        else
          ( printfp 5 `" " Preg (idx2reg idx) `" " %;
            printHeader (idx + 1) )
      val () = printf `" " %
      val () = printHeader 0
      fun printRow (idx, row) =
        let
          val () = printf Ip 4 idx `": " %
          fun loop idx =
            if idx = callerSavedRegs then printf `"\n" %
            else
              ( (case Array.sub (row, idx) of
                   NONE => printf `" " %
                 | SOME vr => printfp 5 `"%" I vr `" " %);
                loop (idx + 1) )
        in
          loop 0
        end
    in
      Array.appi printRow map
    end

  (* Runs the allocator on one function and prints its results; code
     generation after allocation is not implemented yet. *)
  fun regAlloc (F as I.Fi { vregs, ops, paramNum, ... }) =
    let
      val (toAlloc, regInfo) = prepareRegInfo paramNum ops vregs
      val () = printfn `"for alloc: " Plist i toAlloc (", ", true, 0) %
      val () = affPrint regInfo
      val intervals = computeInts F toAlloc
      val () = linearscan regInfo intervals
      val () = printAlloced regInfo toAlloc
      val regsToSave = getRegsToSave regInfo
      val () = printfn `"registers to save: " Plist preg regsToSave (", ", true, 0) %
      val regMap = computeMap (D.length ops) intervals regInfo
      val () = printMap regMap
    in
      raise Unimplemented
    end

  (* regAlloc takes only the function info; the stray extra argument that
     was previously passed only type-checked because regAlloc raises. *)
  fun emitFunc (F as I.Fi _) =
    let
      val () = regAlloc F
    in
      raise Unimplemented
    end

  fun openFile fname = file := SOME (TextIO.openOut fname)

  (* Writes the whole program to [fname]. The output stream is closed (and
     so flushed) on both normal and exceptional exit; exceptions are
     re-raised. *)
  fun emit fname (I.Ctx { globSyms, extSyms, objsZI, objs, strlits, funcInfos, ... }) =
    let
      fun closeFile () =
        case !file of
          NONE => ()
        | SOME out => (TextIO.closeOut out; file := NONE)
      fun emitAll () =
        ( List.app (fn gs => fprint `"global " PP.? gs `"\n" %) globSyms;
          List.app (fn es => fprint `"extern " PP.? es `"\n" %) extSyms;
          handleBSS objsZI;
          handleData objs;
          handleStrlits strlits;
          handleLocalIniLayouts ();
          List.app emitFunc funcInfos )
      val () = openFile fname
    in
      (emitAll () handle e => (closeFile (); raise e));
      closeFile ()
    end
end