Files
cs420/src/opt/mem2reg.rs
2025-06-06 13:01:18 +00:00

276 lines
9.4 KiB
Rust

use core::ops::{Deref, DerefMut};
use std::cmp::Ordering;
use std::collections::{BTreeMap, HashMap, HashSet};
use crate::ir::*;
use crate::opt::opt_utils::*;
use crate::opt::*;
/// Memory-to-register promotion pass; presumably applied to each function definition through
/// the [`FunctionPass`] wrapper.
pub type Mem2reg = FunctionPass<Mem2regInner>;

/// Stateless worker behind [`Mem2reg`]: promotes local allocations whose pointer is only ever
/// used as the `ptr` operand of loads and stores into SSA values connected by phinodes.
#[derive(Clone, Copy, Debug, Default)]
pub struct Mem2regInner {}
impl Optimize<FunctionDefinition> for Mem2regInner {
    /// Promotes every promotable local allocation of `code` to SSA form.
    ///
    /// An allocation `%l<aid>` is *promotable* when its pointer only appears as the `ptr`
    /// operand of `load`/`store` instructions. Any other appearance (binary/unary operand,
    /// stored value, call callee or argument, type-cast source, GEP base) marks it
    /// inpromotable.
    ///
    /// For the promotable allocations this:
    /// 1. places phinodes on the iterated dominance frontier of the blocks that store to them;
    /// 2. walks the dominator tree (`traverse_po`) to record the value each load would read and
    ///    to append the jump arguments feeding the new phinodes;
    /// 3. substitutes those values for the load results and turns the promoted loads/stores
    ///    into `nop`s.
    ///
    /// Returns `true` iff at least one load or store was removed. That is also the only case
    /// in which phinodes or operand replacements are introduced, because phinodes come from
    /// promoted stores and replacements come from promoted loads.
    ///
    /// # Panics
    /// Panics (`unreachable!`) on an instruction kind not handled by the initial scan.
    fn optimize(&mut self, code: &mut FunctionDefinition) -> bool {
        // Scan: find the allocations whose pointer escapes, and remember which blocks store to
        // each allocation (aid -> {bid}).
        let mut inpromotables = HashSet::new();
        let mut stores = HashMap::new();
        for (bid, block) in &code.blocks {
            for inst in &block.instructions {
                match inst.deref() {
                    Instruction::Nop | Instruction::Load { .. } => (),
                    Instruction::BinOp { lhs, rhs, .. } => {
                        mark_as_inpromotable(&mut inpromotables, lhs);
                        mark_as_inpromotable(&mut inpromotables, rhs);
                    }
                    Instruction::UnaryOp { operand, .. } => {
                        mark_as_inpromotable(&mut inpromotables, operand);
                    }
                    Instruction::Store { ptr, value } => {
                        // Storing the pointer itself into memory lets it escape.
                        mark_as_inpromotable(&mut inpromotables, value);
                        if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
                            let _ = stores.entry(*aid).or_insert_with(HashSet::new).insert(*bid);
                        }
                    }
                    Instruction::Call { callee, args, .. } => {
                        mark_as_inpromotable(&mut inpromotables, callee);
                        for arg in args {
                            mark_as_inpromotable(&mut inpromotables, arg);
                        }
                    }
                    Instruction::TypeCast { value, .. } => {
                        mark_as_inpromotable(&mut inpromotables, value);
                    }
                    Instruction::GetElementPtr { ptr, .. } => {
                        mark_as_inpromotable(&mut inpromotables, ptr);
                    }
                    _ => unreachable!(),
                }
            }
        }
        if inpromotables.len() == code.allocations.len() {
            // Every allocation escapes: nothing to promote.
            return false;
        }

        let cfg = make_cfg(code);
        let reverse_cfg = reverse_cfg(&cfg);
        let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg);

        // Phinode placement: the iterated dominance frontier of each promotable allocation's
        // store blocks. A frontier block gains a phinode, which is itself a definition, so it
        // is pushed back on the worklist.
        let phinodes = stores
            .into_iter()
            .filter(|(aid, _)| !inpromotables.contains(aid))
            .map(|(aid, bids)| {
                let mut stack = bids.into_iter().collect::<Vec<_>>();
                let mut visited = HashSet::new();
                while let Some(bid) = stack.pop() {
                    if let Some(bid_frontiers) = domtree.frontiers(&bid) {
                        for bid_frontier in bid_frontiers {
                            if visited.insert(*bid_frontier) {
                                stack.push(*bid_frontier);
                            }
                        }
                    }
                }
                (aid, visited)
            })
            .collect::<BTreeMap<_, _>>(); // aid -> [bid]

        // Append the phinodes after any existing ones. Because the outer loop runs in ascending
        // aid order, `phinode_allocs[bid]` lists the aids in the same order as the phinodes
        // added to `bid`.
        let mut phinode_indexes = HashMap::new(); // (aid, bid) -> phinode index
        let mut phinode_allocs = HashMap::new(); // bid -> [aid]
        for (aid, bids) in phinodes {
            for bid in bids {
                let block = code.blocks.get_mut(&bid).unwrap();
                block.phinodes.push(code.allocations[aid].clone());
                let _ = phinode_indexes.insert((aid, bid), block.phinodes.len() - 1);
                phinode_allocs.entry(bid).or_insert_with(Vec::new).push(aid);
            }
        }

        // Renaming: compute the replacement value of each promoted load and fill jump args.
        let mut replaces = HashMap::new();
        traverse_po(
            code.bid_init,
            code,
            &inpromotables,
            &domtree,
            (&phinode_indexes, &phinode_allocs),
            HashMap::new(),
            &mut replaces,
        );

        // Rewrite: apply the replacements, then drop the promoted loads and stores.
        let mut changed = false;
        for block in code.blocks.values_mut() {
            // A replacement value may itself be a promoted load's result (e.g. a stored load),
            // so substitute until nothing changes.
            loop {
                let mut replaced = false;
                for inst in block.instructions.iter_mut() {
                    replaced = replace_instruction_operands(inst, &replaces) || replaced;
                }
                replaced = replace_exit_operands(&mut block.exit, &replaces) || replaced;
                if !replaced {
                    break;
                }
            }
            for inst in block.instructions.iter_mut() {
                let ptr = match inst.deref().deref() {
                    Instruction::Store { ptr, .. } | Instruction::Load { ptr } => ptr,
                    _ => continue,
                };
                if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
                    if !inpromotables.contains(aid) {
                        *inst.deref_mut() = Instruction::Nop;
                        changed = true;
                    }
                }
            }
        }
        changed
    }
}
/// Adds the allocation id of `operand` to `inpromotables` when `operand` is a local allocation
/// pointer (`RegisterId::Local`). Any other operand is ignored.
fn mark_as_inpromotable(inpromotables: &mut HashSet<usize>, operand: &Operand) {
    let local_aid = match operand.get_register() {
        Some((RegisterId::Local { aid }, _)) => *aid,
        _ => return,
    };
    let _ = inpromotables.insert(local_aid);
}
/// Phinode bookkeeping handed to the renaming walk.
///
/// `.0` maps `(aid, bid)` to the index, within `bid`'s phinodes, of the phinode added for
/// allocation `aid`. `.1` maps `bid` to the allocation ids whose phinodes were appended to that
/// block, in the order they were appended.
type PhinodeInfo<'a> = (&'a HashMap<(usize, BlockId), usize>, &'a HashMap<BlockId, Vec<usize>>);
/// Renames the promoted allocations in the dominator subtree rooted at `bid`.
///
/// Despite the name, blocks are visited in dominator-tree *pre*-order: a block is processed
/// before the children returned by `domtree.successors`. For each block this:
/// - computes the value of every promotable allocation on entry (`find_latest_value`: the
///   block's own phinode if one was added, else the exit value of the nearest processed
///   dominator, else `undef`);
/// - replays the block's stores to track the current value, recording in `replaces` the value
///   that each promoted load's result `RegisterId::temp(bid, i)` should be replaced with;
/// - appends to each outgoing jump argument the values for the phinodes added to its target.
///
/// `block_stacks` maps each processed block to the per-allocation values at its exit.
/// A single map is shared across the whole walk instead of being cloned for every child.
/// This is equivalent, assuming `domtree.successors` describes a tree: each block is processed
/// once, and `find_latest_value` only consults the block itself (not yet processed) and its
/// dominators, whose entries are final before any child runs.
fn traverse_po(
    bid: BlockId,
    code: &mut FunctionDefinition,
    inpromotables: &HashSet<usize>,
    domtree: &Domtree,
    (phinode_indexes, phinode_allocs): PhinodeInfo<'_>,
    mut block_stacks: HashMap<BlockId, HashMap<usize, Operand>>,
    replaces: &mut HashMap<RegisterId, Operand>,
) {
    rename_dom_subtree(
        bid,
        code,
        inpromotables,
        domtree,
        (phinode_indexes, phinode_allocs),
        &mut block_stacks,
        replaces,
    );
}

/// Processes block `bid`, then recursively its dominator-tree children, all sharing the
/// single `block_stacks` map owned by `traverse_po`.
fn rename_dom_subtree(
    bid: BlockId,
    code: &mut FunctionDefinition,
    inpromotables: &HashSet<usize>,
    domtree: &Domtree,
    (phinode_indexes, phinode_allocs): PhinodeInfo<'_>,
    block_stacks: &mut HashMap<BlockId, HashMap<usize, Operand>>,
    replaces: &mut HashMap<RegisterId, Operand>,
) {
    // Value of every promotable allocation on entry to `bid`.
    let processed_exits: &HashMap<BlockId, HashMap<usize, Operand>> = block_stacks;
    let entry_values = code
        .allocations
        .iter()
        .enumerate()
        .filter(|(aid, _)| !inpromotables.contains(aid))
        .map(|(aid, dtype)| {
            let initial_value = find_latest_value(
                aid,
                dtype.deref().clone(),
                bid,
                domtree,
                phinode_indexes,
                processed_exits,
            );
            (aid, initial_value)
        })
        .collect::<HashMap<_, _>>();

    // After the loop below, this entry holds the values at the exit of `bid`.
    let block_stack = block_stacks.entry(bid).or_insert(entry_values);
    let block = code
        .blocks
        .get_mut(&bid)
        .expect("dominator tree node must be a block of the function");

    for (i, inst) in block.instructions.iter().enumerate() {
        match inst.deref() {
            Instruction::Store { ptr, value } => {
                if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
                    if let Some(current) = block_stack.get_mut(aid) {
                        *current = value.clone();
                    }
                }
            }
            Instruction::Load { ptr } => {
                if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
                    if let Some(current) = block_stack.get(aid) {
                        let _ = replaces.insert(RegisterId::temp(bid, i), current.clone());
                    }
                }
            }
            _ => (),
        }
    }

    // Feed the phinodes added to each successor with this block's exit values.
    match &mut block.exit {
        BlockExit::Jump { arg } => {
            fill_jump_args(arg, phinode_allocs, block_stack);
        }
        BlockExit::ConditionalJump {
            arg_then, arg_else, ..
        } => {
            fill_jump_args(arg_then, phinode_allocs, block_stack);
            fill_jump_args(arg_else, phinode_allocs, block_stack);
        }
        BlockExit::Switch { default, cases, .. } => {
            fill_jump_args(default, phinode_allocs, block_stack);
            for (_, arg) in cases {
                fill_jump_args(arg, phinode_allocs, block_stack);
            }
        }
        _ => (),
    }

    if let Some(children) = domtree.successors(&bid) {
        for child in children {
            rename_dom_subtree(
                *child,
                code,
                inpromotables,
                domtree,
                (phinode_indexes, phinode_allocs),
                block_stacks,
                replaces,
            );
        }
    }
}
/// Returns the value allocation `aid` holds on entry to block `bid`.
///
/// Starting at `bid` and climbing the immediate-dominator chain, the first match wins:
/// 1. a value already recorded for the block in `block_stacks`;
/// 2. the phinode added for `aid` in that block, as `RegisterId::arg(block, index)`;
/// 3. otherwise, move on to the block's immediate dominator.
///
/// When the chain ends without a match, the result is `undef` of type `dtype`.
fn find_latest_value(
    aid: usize,
    dtype: Dtype,
    bid: BlockId,
    domtree: &Domtree,
    phinode_indexes: &HashMap<(usize, BlockId), usize>,
    block_stacks: &HashMap<BlockId, HashMap<usize, Operand>>,
) -> Operand {
    let mut current = bid;
    loop {
        let recorded = block_stacks
            .get(&current)
            .and_then(|stack| stack.get(&aid));
        if let Some(value) = recorded {
            return value.clone();
        }
        if let Some(&index) = phinode_indexes.get(&(aid, current)) {
            return Operand::register(RegisterId::arg(current, index), dtype);
        }
        match domtree.idom(&current) {
            Some(&dominator) => current = dominator,
            None => return Operand::constant(Constant::undef(dtype)),
        }
    }
}
/// Appends to `arg` one operand per phinode that this pass added to the jump target `arg.bid`.
/// Each operand is the allocation's current value in `block_stack`, in `phinode_allocs` order.
///
/// A target with no added phinodes is left untouched.
///
/// # Panics
/// Panics if an allocation listed for the target has no entry in `block_stack`.
fn fill_jump_args(
    arg: &mut JumpArg,
    phinode_allocs: &HashMap<BlockId, Vec<usize>>,
    block_stack: &HashMap<usize, Operand>,
) {
    let incoming = phinode_allocs
        .get(&arg.bid)
        .into_iter()
        .flatten()
        .map(|aid| block_stack[aid].clone());
    arg.args.extend(incoming);
}