From 9edb2b8dbd99636cf2d98d3253a0316f74720894 Mon Sep 17 00:00:00 2001 From: Vladimir Azarov Date: Sun, 10 Aug 2025 02:03:50 +0200 Subject: Instruction selection for multiplication and division --- emit.fun | 214 +++++++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 169 insertions(+), 45 deletions(-) (limited to 'emit.fun') diff --git a/emit.fun b/emit.fun index 4659acf..bbc6827 100644 --- a/emit.fun +++ b/emit.fun @@ -1014,25 +1014,26 @@ functor Emit(I: IL) = struct fun emitPrologue ({ stackOffset, regsToSave, ... }) name = let val () = fprint PP.? name `":\n" % - val () = fprinttn `"push rbp" % - val () = fprinttn `"mov rbp, rsp" % - val () = - if stackOffset <> 0 then - fprinttn `"sub rsp, " I (~ stackOffset) % - else - () in + if stackOffset <> 0 then ( + fprinttn `"push rbp" %; + fprinttn `"mov rbp, rsp" %; + fprinttn `"sub rsp, " I (~ stackOffset) % + ) else + (); List.app emitPushReg regsToSave end - fun emitEpilogue { regsToSave, ... } = + fun emitEpilogue { regsToSave, stackOffset, ... } = let val () = List.app emitPushReg (rev regsToSave) - val () = fprinttn `"mov rsp, rbp" % - val () = fprinttn `"pop rbp" % - val () = fprinttn `"ret" % in - () + if stackOffset <> 0 then ( + fprinttn `"mov rsp, rbp" %; + fprinttn `"pop rbp" % + ) else + (); + fprinttn `"ret" % end fun pr is8 reg out = @@ -1155,7 +1156,7 @@ functor Emit(I: IL) = struct opMV is8 "mov" off c end - fun getTripleTemplate I (rd, rs1, rs2) comm = + fun getTripleTemplate I (rd, rs1, rs2) comm fold = let val (is81, t1) = getType I rd val (is82, t2) = getType I rs1 @@ -1167,31 +1168,31 @@ functor Emit(I: IL) = struct val tmp = case (t1, t2, t3) of (VtReg r1, VtReg r2, VtReg r3) => - if r1 = r2 then + if r1 = r2 andalso fold then RR (r1, r3) - else if r1 = r3 andalso comm then + else if r1 = r3 andalso comm andalso fold then RR (r1, r2) else RRR (r1, r2, r3) | (VtReg r1, VtReg r2, VtStack off) => - if r1 = r2 then + if r1 = r2 andalso fold then RM (r1, off) else RRM (r1, r2, off) | (VtReg r1, VtStack off, VtReg r2) => - if r1 = r2 andalso comm then + if r1 = r2 andalso comm andalso fold then RM (r1, off) else RMR (r1, off, r2) | (VtReg r1, VtReg r2, VtConst c) => - if r1 = r2 then + if r1 = r2 andalso fold then RV (r1, c) else RRV (r1, r2, c) | (VtReg r1, VtConst c, VtReg r2) => - if r1 = r2 andalso comm then + if r1 = r2 andalso comm andalso fold then RV (r1, c) else RVR (r1, c, r2) @@ -1204,12 +1205,12 @@ functor Emit(I: IL) = struct | (VtStack off, VtReg r1, VtReg r2) => MRR (off, r1, r2) | (VtStack off1, VtReg r, VtStack off2) => - if off1 = off2 andalso comm then + if off1 = off2 andalso comm andalso fold then MR (off1, r) else MRM (off1, r, off2) | (VtStack off1, VtStack off2, VtReg r) => - if off1 = off2 then + if off1 = off2 andalso fold then MR (off1, r) else MMR (off1, off2, r) @@ -1218,20 +1219,20 @@ functor Emit(I: IL) = struct | (VtStack off, VtConst c, VtReg r) => MVR (off, c, r) | (VtStack off1, VtStack off2, VtStack off3) => - if off1 = off2 then + if off1 = off2 andalso fold then MM (off1, off3) - else if off1 = off3 andalso comm then + else if off1 = off3 andalso comm andalso fold then MM (off1, off2) else MMM (off1, off2, off3) | (VtStack off1, VtStack off2, VtConst c) => - if off1 = off2 then + if off1 = off2 andalso fold then MV (off1, c) else MMV (off1, off2, c) | (VtStack off1, VtConst c, VtStack off2) => - if off1 = off2 andalso comm then + if off1 = off2 andalso comm andalso fold then MV (off1, c) else MVM (off1, c, off2) @@ -1264,7 +1265,7 @@ functor Emit(I: IL) = struct fun emitGenComm I op' triple = let - val (is8, tmp) = getTripleTemplate I triple true + val (is8, tmp) = getTripleTemplate I triple true true val Pr = fn z => bind A1 (pr is8) z val { movRR, movRM, movMR, movRV } = getUtilMovs is8 val { opRR, opRM, opMR, opRV, opMV } = getUtilOps is8 op' @@ -1313,7 +1314,7 @@ functor Emit(I: IL) = struct fun emitShift I op' triple = let - val (is8, tmp) = getTripleTemplate I triple false + val (is8, tmp) = getTripleTemplate I triple false true val Pr = fn z => bind A1 (pr is8) z val { movRR, movRM, movMR, movRV } = getUtilMovs is8 val opRV = opRV is8 op' @@ -1360,9 +1361,26 @@ functor Emit(I: IL) = struct | MV (off, v) => [opMV off (t v)] end + fun wordIsZero w = + case Word.compare (w, 0w0) of + EQUAL => true + | _ => false + + datatype cbv = CbvTrue | CbvFalse | CbvUnsure of int * word + + fun constBoolVal (VConst w) = if wordIsZero w then CbvFalse else CbvTrue + | constBoolVal (VAddrConst (id, off)) = + if wordIsZero off then + CbvTrue + else + CbvUnsure (id, off) + + fun isZeroConst (VConst 0w0) = true + | isZeroConst _ = false + fun emitSub I triple = let - val (is8, tmp) = getTripleTemplate I triple false + val (is8, tmp) = getTripleTemplate I triple false true val { movRR, movRM, movRV, movMR } = getUtilMovs is8 val { opRR, opRM, opRV, opMR, opMV } = getUtilOps is8 "sub" in @@ -1382,7 +1400,11 @@ functor Emit(I: IL) = struct [movRM r1 off, opRV r1 v] else [movRM r1 off, movRV Rax v, opRR r1 Rax] - | RVR (r1, v, r2) => [movRV r1 v, opRR r1 r2] + | RVR (r1, v, r2) => + (printfn `"HERE" %; if r1 = r2 andalso isZeroConst v then + [sprintf `"neg " A2 pr is8 r1 %] + else + [movRV r1 v, opRR r1 r2]) | RVM (r, v, off) => [movRV r v, opRM r off] | MRR (off, r1, r2) => [movRR Rax r1, opRR Rax r2, movMR off Rax] | MRM (off1, r, off2) => [opRM r off2, movMR off1 r] @@ -1399,7 +1421,11 @@ functor Emit(I: IL) = struct [movRM Rax off2, opRV Rax v, movMR off1 Rax] else [movRM Rax off2, movRV Rdx v, opRR Rax Rdx, movMR off1 Rax] - | MVM (off1, v, off2) => [movRV Rax v, opRM Rax off2, movMR off1 Rax] + | MVM (off1, v, off2) => + if off1 = off2 andalso isZeroConst v then + [sprintf `"neg " A2 pm is8 off1 %] + else + [movRV Rax v, opRM Rax off2, movMR off1 Rax] | MVR (off1, v, r) => [movRV Rax v, opRR Rax r, movMR off1 Rax] | RR (r1, r2) => [opRR r1 r2] | RM (r, off) => [opRM r off] @@ -1417,6 +1443,111 @@ functor Emit(I: IL) = struct [movRV Rax v, opMR off Rax] end + fun mov is8 r1 r2 = + case (r1, r2) of + (VtReg r1, VtReg r2) => movRR is8 r1 r2 + | (VtReg r, VtStack off) => movRM is8 r off + | (VtReg r, VtConst c) => movRV is8 r c + | (VtStack off, VtReg r) => movMR is8 off r + | _ => raise Unreachable + + fun prm is8 vr out = + case vr of + VtReg r => Printf out A2 pr is8 r % + | VtStack off => Printf out A2 pm is8 off % + | _ => raise Unreachable + + fun assertSize is81 is82 is83 = + if is81 <> is82 orelse is82 <> is83 then + raise Unreachable + else + () + + fun emitGenConstraint I (rd, rs1, rs2) op' resInReg = + let + val (is81, t1) = getType I rd + val (is82, t2) = getType I rs1 + val (is83, t3) = getType I rs2 + + val () = assertSize is81 is82 is83 + + val (first, second) = + case (t2, t3) of + (VtReg _ | VtStack _, _) => (t3, t2) + | (_, VtReg _ | VtStack _) => (t2, t3) + | (_, _) => raise Unreachable + in + [ + mov is81 (VtReg Rax) first, + sprintf `op' `" " A2 prm is81 second %, + mov is81 t1 (VtReg resInReg) + ] + end + + fun emitIMul I (vd, vs1, vs2) = + let + val (is81, t1) = getType I vd + val (is82, t2) = getType I vs1 + val (is83, t3) = getType I vs2 + + val () = assertSize is81 is82 is83 + + datatype form = Reduced of vrType | Normal of vrType * vrType + + fun getReg vt = + case vt of + VtReg r => r + | _ => Rax + + fun moveBackIfNeeded dest = + if dest = Rax then + [mov is81 t1 (VtReg Rax)] + else + [] + + fun op2 () = + let + val form = + if t3 = t1 then + Reduced t2 + else if t2 = t1 then + Reduced t3 + else + Normal (t2, t3) + + val dest = getReg t1 + val main = + case form of + Reduced rs => + [ sprintf `"imul " A2 pr is81 dest `", " A2 prm is81 rs % ] + | Normal (rs1, rs2) => [ + mov is81 (VtReg dest) rs1, + sprintf `"imul " A2 pr is81 dest `", " A2 prm is81 rs2 % + ] + in + main @ moveBackIfNeeded dest + end + + fun op3 rs1 c = + if fitsInNsx 32 c then + let + val dest = getReg t1 + + val main = + [sprintf `"imul " A2 pr is81 dest `", " + A2 prm is81 rs1 `", " A2 pc is81 c %] + in + main @ moveBackIfNeeded dest + end + else + op2 () + in + case (t2, t3) of + (VtConst c, _) => op3 t3 c + | (_, VtConst c) => op3 t2 c + | _ => op2 () + end + fun emitSet I (vrd, I.SaVReg vrs) = let val (is81, t1) = getType I vrd @@ -1533,20 +1664,6 @@ functor Emit(I: IL) = struct loop 0w0 prolog end - fun wordIsZero w = - case Word.compare (w, 0w0) of - EQUAL => true - | _ => false - - datatype cbv = CbvTrue | CbvFalse | CbvUnsure of int * word - - fun constBoolVal (VConst w) = if wordIsZero w then CbvFalse else CbvTrue - | constBoolVal (VAddrConst (id, off)) = - if wordIsZero off then - CbvTrue - else - CbvUnsure (id, off) - fun emitJz E (vr, lid) isJz = let val (is8, vt) = getType E vr @@ -1583,6 +1700,13 @@ functor Emit(I: IL) = struct | I.IrShr t => emitShift E "shr" t | I.IrSar t => emitShift E "sar" t + | I.IrMul t => emitIMul E t + | I.IrIMul t => emitIMul E t + | I.IrDiv t => emitGenConstraint E t "div" Rax + | I.IrIDiv t => emitGenConstraint E t "idiv" Rax + | I.IrMod t => emitGenConstraint E t "div" Rdx + | I.IrIMod t => emitGenConstraint E t "idiv" Rdx + | I.IrAlloc t => emitAlloc E t | I.IrRet vr => emitRet E vr idx | I.IrCopy t => emitCopy E t @@ -1590,7 +1714,7 @@ functor Emit(I: IL) = struct | I.IrJmp lid => [ jmp lid ] | I.IrJz p => emitJz E p true | I.IrJnz p => emitJz E p false - | _ => [] + | I.IrNop comment => [sprintf `"; " `comment %] fun emitIns (I as { ops, ... }) = let -- cgit v1.2.3