summaryrefslogtreecommitdiff
path: root/emit.fun
diff options
context:
space:
mode:
authorVladimir Azarov <avm@intermediate-node.net>2025-08-10 02:03:50 +0200
committerVladimir Azarov <avm@intermediate-node.net>2025-08-10 02:03:50 +0200
commit9edb2b8dbd99636cf2d98d3253a0316f74720894 (patch)
tree2618d33ef9b846fdd4340f8aacb41c9e8dea0a34 /emit.fun
parent89cbdbe9e4cb6f142154292cac462e2d130d912a (diff)
Instruction selection for multiplication and division
Diffstat (limited to 'emit.fun')
-rw-r--r--emit.fun214
1 files changed, 169 insertions, 45 deletions
diff --git a/emit.fun b/emit.fun
index 4659acf..bbc6827 100644
--- a/emit.fun
+++ b/emit.fun
@@ -1014,25 +1014,26 @@ functor Emit(I: IL) = struct
fun emitPrologue ({ stackOffset, regsToSave, ... }) name =
let
val () = fprint PP.? name `":\n" %
- val () = fprinttn `"push rbp" %
- val () = fprinttn `"mov rbp, rsp" %
- val () =
- if stackOffset <> 0 then
- fprinttn `"sub rsp, " I (~ stackOffset) %
- else
- ()
in
+ if stackOffset <> 0 then (
+ fprinttn `"push rbp" %;
+ fprinttn `"mov rbp, rsp" %;
+ fprinttn `"sub rsp, " I (~ stackOffset) %
+ ) else
+ ();
List.app emitPushReg regsToSave
end
- fun emitEpilogue { regsToSave, ... } =
+ fun emitEpilogue { regsToSave, stackOffset, ... } =
let
val () = List.app emitPushReg (rev regsToSave)
- val () = fprinttn `"mov rsp, rbp" %
- val () = fprinttn `"pop rbp" %
- val () = fprinttn `"ret" %
in
- ()
+ if stackOffset <> 0 then (
+ fprinttn `"mov rsp, rbp" %;
+ fprinttn `"pop rbp" %
+ ) else
+ ();
+ fprinttn `"ret" %
end
fun pr is8 reg out =
@@ -1155,7 +1156,7 @@ functor Emit(I: IL) = struct
opMV is8 "mov" off c
end
- fun getTripleTemplate I (rd, rs1, rs2) comm =
+ fun getTripleTemplate I (rd, rs1, rs2) comm fold =
let
val (is81, t1) = getType I rd
val (is82, t2) = getType I rs1
@@ -1167,31 +1168,31 @@ functor Emit(I: IL) = struct
val tmp =
case (t1, t2, t3) of
(VtReg r1, VtReg r2, VtReg r3) =>
- if r1 = r2 then
+ if r1 = r2 andalso fold then
RR (r1, r3)
- else if r1 = r3 andalso comm then
+ else if r1 = r3 andalso comm andalso fold then
RR (r1, r2)
else
RRR (r1, r2, r3)
| (VtReg r1, VtReg r2, VtStack off) =>
- if r1 = r2 then
+ if r1 = r2 andalso fold then
RM (r1, off)
else
RRM (r1, r2, off)
| (VtReg r1, VtStack off, VtReg r2) =>
- if r1 = r2 andalso comm then
+ if r1 = r2 andalso comm andalso fold then
RM (r1, off)
else
RMR (r1, off, r2)
| (VtReg r1, VtReg r2, VtConst c) =>
- if r1 = r2 then
+ if r1 = r2 andalso fold then
RV (r1, c)
else
RRV (r1, r2, c)
| (VtReg r1, VtConst c, VtReg r2) =>
- if r1 = r2 andalso comm then
+ if r1 = r2 andalso comm andalso fold then
RV (r1, c)
else
RVR (r1, c, r2)
@@ -1204,12 +1205,12 @@ functor Emit(I: IL) = struct
| (VtStack off, VtReg r1, VtReg r2) => MRR (off, r1, r2)
| (VtStack off1, VtReg r, VtStack off2) =>
- if off1 = off2 andalso comm then
+ if off1 = off2 andalso comm andalso fold then
MR (off1, r)
else
MRM (off1, r, off2)
| (VtStack off1, VtStack off2, VtReg r) =>
- if off1 = off2 then
+ if off1 = off2 andalso fold then
MR (off1, r)
else
MMR (off1, off2, r)
@@ -1218,20 +1219,20 @@ functor Emit(I: IL) = struct
| (VtStack off, VtConst c, VtReg r) => MVR (off, c, r)
| (VtStack off1, VtStack off2, VtStack off3) =>
- if off1 = off2 then
+ if off1 = off2 andalso fold then
MM (off1, off3)
- else if off1 = off3 andalso comm then
+ else if off1 = off3 andalso comm andalso fold then
MM (off1, off2)
else
MMM (off1, off2, off3)
| (VtStack off1, VtStack off2, VtConst c) =>
- if off1 = off2 then
+ if off1 = off2 andalso fold then
MV (off1, c)
else
MMV (off1, off2, c)
| (VtStack off1, VtConst c, VtStack off2) =>
- if off1 = off2 andalso comm then
+ if off1 = off2 andalso comm andalso fold then
MV (off1, c)
else
MVM (off1, c, off2)
@@ -1264,7 +1265,7 @@ functor Emit(I: IL) = struct
fun emitGenComm I op' triple =
let
- val (is8, tmp) = getTripleTemplate I triple true
+ val (is8, tmp) = getTripleTemplate I triple true true
val Pr = fn z => bind A1 (pr is8) z
val { movRR, movRM, movMR, movRV } = getUtilMovs is8
val { opRR, opRM, opMR, opRV, opMV } = getUtilOps is8 op'
@@ -1313,7 +1314,7 @@ functor Emit(I: IL) = struct
fun emitShift I op' triple =
let
- val (is8, tmp) = getTripleTemplate I triple false
+ val (is8, tmp) = getTripleTemplate I triple false true
val Pr = fn z => bind A1 (pr is8) z
val { movRR, movRM, movMR, movRV } = getUtilMovs is8
val opRV = opRV is8 op'
@@ -1360,9 +1361,26 @@ functor Emit(I: IL) = struct
| MV (off, v) => [opMV off (t v)]
end
+ fun wordIsZero w =
+ case Word.compare (w, 0w0) of
+ EQUAL => true
+ | _ => false
+
+ datatype cbv = CbvTrue | CbvFalse | CbvUnsure of int * word
+
+ fun constBoolVal (VConst w) = if wordIsZero w then CbvFalse else CbvTrue
+ | constBoolVal (VAddrConst (id, off)) =
+ if wordIsZero off then
+ CbvTrue
+ else
+ CbvUnsure (id, off)
+
+ fun isZeroConst (VConst 0w0) = true
+ | isZeroConst _ = false
+
fun emitSub I triple =
let
- val (is8, tmp) = getTripleTemplate I triple false
+ val (is8, tmp) = getTripleTemplate I triple false true
val { movRR, movRM, movRV, movMR } = getUtilMovs is8
val { opRR, opRM, opRV, opMR, opMV } = getUtilOps is8 "sub"
in
@@ -1382,7 +1400,11 @@ functor Emit(I: IL) = struct
[movRM r1 off, opRV r1 v]
else
[movRM r1 off, movRV Rax v, opRR r1 Rax]
- | RVR (r1, v, r2) => [movRV r1 v, opRR r1 r2]
+ | RVR (r1, v, r2) =>
+ (printfn `"HERE" %; if r1 = r2 andalso isZeroConst v then
+ [sprintf `"neg " A2 pr is8 r1 %]
+ else
+ [movRV r1 v, opRR r1 r2])
| RVM (r, v, off) => [movRV r v, opRM r off]
| MRR (off, r1, r2) => [movRR Rax r1, opRR Rax r2, movMR off Rax]
| MRM (off1, r, off2) => [opRM r off2, movMR off1 r]
@@ -1399,7 +1421,11 @@ functor Emit(I: IL) = struct
[movRM Rax off2, opRV Rax v, movMR off1 Rax]
else
[movRM Rax off2, movRV Rdx v, opRR Rax Rdx, movMR off1 Rax]
- | MVM (off1, v, off2) => [movRV Rax v, opRM Rax off2, movMR off1 Rax]
+ | MVM (off1, v, off2) =>
+ if off1 = off2 andalso isZeroConst v then
+ [sprintf `"neg " A2 pm is8 off1 %]
+ else
+ [movRV Rax v, opRM Rax off2, movMR off1 Rax]
| MVR (off1, v, r) => [movRV Rax v, opRR Rax r, movMR off1 Rax]
| RR (r1, r2) => [opRR r1 r2]
| RM (r, off) => [opRM r off]
@@ -1417,6 +1443,111 @@ functor Emit(I: IL) = struct
[movRV Rax v, opMR off Rax]
end
+ fun mov is8 r1 r2 =
+ case (r1, r2) of
+ (VtReg r1, VtReg r2) => movRR is8 r1 r2
+ | (VtReg r, VtStack off) => movRM is8 r off
+ | (VtReg r, VtConst c) => movRV is8 r c
+ | (VtStack off, VtReg r) => movMR is8 off r
+ | _ => raise Unreachable
+
+ fun prm is8 vr out =
+ case vr of
+ VtReg r => Printf out A2 pr is8 r %
+ | VtStack off => Printf out A2 pm is8 off %
+ | _ => raise Unreachable
+
+ fun assertSize is81 is82 is83 =
+ if is81 <> is82 orelse is82 <> is83 then
+ raise Unreachable
+ else
+ ()
+
+ fun emitGenConstraint I (rd, rs1, rs2) op' resInReg =
+ let
+ val (is81, t1) = getType I rd
+ val (is82, t2) = getType I rs1
+ val (is83, t3) = getType I rs2
+
+ val () = assertSize is81 is82 is83
+
+ val (first, second) =
+ case (t2, t3) of
+ (VtReg _ | VtStack _, _) => (t3, t2)
+ | (_, VtReg _ | VtStack _) => (t2, t3)
+ | (_, _) => raise Unreachable
+ in
+ [
+ mov is81 (VtReg Rax) first,
+ sprintf `op' `" " A2 prm is81 second %,
+ mov is81 t1 (VtReg resInReg)
+ ]
+ end
+
+ fun emitIMul I (vd, vs1, vs2) =
+ let
+ val (is81, t1) = getType I vd
+ val (is82, t2) = getType I vs1
+ val (is83, t3) = getType I vs2
+
+ val () = assertSize is81 is82 is83
+
+ datatype form = Reduced of vrType | Normal of vrType * vrType
+
+ fun getReg vt =
+ case vt of
+ VtReg r => r
+ | _ => Rax
+
+ fun moveBackIfNeeded dest =
+ if dest = Rax then
+ [mov is81 t1 (VtReg Rax)]
+ else
+ []
+
+ fun op2 () =
+ let
+ val form =
+ if t3 = t1 then
+ Reduced t2
+ else if t2 = t1 then
+ Reduced t3
+ else
+ Normal (t2, t3)
+
+ val dest = getReg t1
+ val main =
+ case form of
+ Reduced rs =>
+ [ sprintf `"imul " A2 pr is81 dest `", " A2 prm is81 rs % ]
+ | Normal (rs1, rs2) => [
+ mov is81 (VtReg dest) rs1,
+ sprintf `"imul " A2 pr is81 dest `", " A2 prm is81 rs2 %
+ ]
+ in
+ main @ moveBackIfNeeded dest
+ end
+
+ fun op3 rs1 c =
+ if fitsInNsx 32 c then
+ let
+ val dest = getReg t1
+
+ val main =
+ [sprintf `"imul " A2 pr is81 dest `", "
+ A2 prm is81 rs1 `", " A2 pc is81 c %]
+ in
+ main @ moveBackIfNeeded dest
+ end
+ else
+ op2 ()
+ in
+ case (t2, t3) of
+ (VtConst c, _) => op3 t3 c
+ | (_, VtConst c) => op3 t2 c
+ | _ => op2 ()
+ end
+
fun emitSet I (vrd, I.SaVReg vrs) =
let
val (is81, t1) = getType I vrd
@@ -1533,20 +1664,6 @@ functor Emit(I: IL) = struct
loop 0w0 prolog
end
- fun wordIsZero w =
- case Word.compare (w, 0w0) of
- EQUAL => true
- | _ => false
-
- datatype cbv = CbvTrue | CbvFalse | CbvUnsure of int * word
-
- fun constBoolVal (VConst w) = if wordIsZero w then CbvFalse else CbvTrue
- | constBoolVal (VAddrConst (id, off)) =
- if wordIsZero off then
- CbvTrue
- else
- CbvUnsure (id, off)
-
fun emitJz E (vr, lid) isJz =
let
val (is8, vt) = getType E vr
@@ -1583,6 +1700,13 @@ functor Emit(I: IL) = struct
| I.IrShr t => emitShift E "shr" t
| I.IrSar t => emitShift E "sar" t
+ | I.IrMul t => emitIMul E t
+ | I.IrIMul t => emitIMul E t
+ | I.IrDiv t => emitGenConstraint E t "div" Rax
+ | I.IrIDiv t => emitGenConstraint E t "idiv" Rax
+ | I.IrMod t => emitGenConstraint E t "div" Rdx
+ | I.IrIMod t => emitGenConstraint E t "idiv" Rdx
+
| I.IrAlloc t => emitAlloc E t
| I.IrRet vr => emitRet E vr idx
| I.IrCopy t => emitCopy E t
@@ -1590,7 +1714,7 @@ functor Emit(I: IL) = struct
| I.IrJmp lid => [ jmp lid ]
| I.IrJz p => emitJz E p true
| I.IrJnz p => emitJz E p false
- | _ => []
+ | I.IrNop comment => [sprintf `"; " `comment %]
fun emitIns (I as { ops, ... }) =
let