use core::ops::{Deref, DerefMut}; use std::cmp::Ordering; use std::collections::{BTreeMap, HashMap, HashSet}; use crate::ir::*; use crate::opt::opt_utils::*; use crate::opt::*; pub type Mem2reg = FunctionPass; #[derive(Default, Clone, Copy, Debug)] pub struct Mem2regInner {} impl Optimize for Mem2regInner { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { let mut inpromotables = HashSet::new(); let mut stores = HashMap::new(); for (bid, block) in &code.blocks { for inst in &block.instructions { match inst.deref() { Instruction::Nop | Instruction::Load { .. } => (), Instruction::BinOp { lhs, rhs, .. } => { mark_as_inpromotable(&mut inpromotables, lhs); mark_as_inpromotable(&mut inpromotables, rhs); } Instruction::UnaryOp { operand, .. } => { mark_as_inpromotable(&mut inpromotables, operand); } Instruction::Store { ptr, value } => { mark_as_inpromotable(&mut inpromotables, value); if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { let _ = stores.entry(*aid).or_insert_with(HashSet::new).insert(*bid); } } Instruction::Call { callee, args, .. } => { mark_as_inpromotable(&mut inpromotables, callee); for arg in args { mark_as_inpromotable(&mut inpromotables, arg); } } Instruction::TypeCast { value, .. } => { mark_as_inpromotable(&mut inpromotables, value); } Instruction::GetElementPtr { ptr, .. } => { mark_as_inpromotable(&mut inpromotables, ptr); } _ => unreachable!(), } } } if inpromotables.len() == code.allocations.len() { return false; } let cfg = make_cfg(code); let reverse_cfg = reverse_cfg(&cfg); let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg); let phinodes = stores .into_iter() .filter(|(aid, _)| !inpromotables.contains(aid)) .map(|(aid, bids)| { let mut stack = bids.into_iter().collect::>(); let mut visited = HashSet::new(); while let Some(bid) = stack.pop() { if let Some(bid_frontiers) = domtree.frontiers(&bid) { for bid_frontier in bid_frontiers { if visited.insert(*bid_frontier) { stack.push(*bid_frontier); } } } } (aid, visited) }) .collect::>(); // aid -> [bid] let mut phinode_indexes = HashMap::new(); // (aid, bid) -> phinode index let mut phinode_allocs = HashMap::new(); // bid -> [(aid, dtype)] for (aid, bids) in phinodes { for bid in bids { let block = code.blocks.get_mut(&bid).unwrap(); block.phinodes.push(code.allocations[aid].clone()); let _ = phinode_indexes.insert((aid, bid), block.phinodes.len() - 1); phinode_allocs.entry(bid).or_insert_with(Vec::new).push(aid); } } let mut replaces = HashMap::new(); traverse_po( code.bid_init, code, &inpromotables, &domtree, (&phinode_indexes, &phinode_allocs), HashMap::new(), &mut replaces, ); for (bid, block) in code.blocks.iter_mut() { loop { let mut changed = false; for inst in block.instructions.iter_mut() { changed = replace_instruction_operands(inst, &replaces) || changed; } changed = replace_exit_operands(&mut block.exit, &replaces) || changed; if !changed { break; } } for inst in block.instructions.iter_mut() { match inst.deref().deref() { Instruction::Store { ptr, .. } => { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if !inpromotables.contains(aid) { *inst.deref_mut() = Instruction::Nop; } } } Instruction::Load { ptr } => { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if !inpromotables.contains(aid) { *inst.deref_mut() = Instruction::Nop; } } } _ => (), } } } true } } fn mark_as_inpromotable(inpromotables: &mut HashSet, operand: &Operand) { if let Some((RegisterId::Local { aid }, _)) = operand.get_register() { let _ = inpromotables.insert(*aid); } } type PhinodeInfo<'a> = ( &'a HashMap<(usize, BlockId), usize>, &'a HashMap>, ); fn traverse_po( bid: BlockId, code: &mut FunctionDefinition, inpromotables: &HashSet, domtree: &Domtree, (phinode_indexes, phinode_allocs): PhinodeInfo<'_>, mut block_stacks: HashMap>, replaces: &mut HashMap, ) { let block = code.blocks.get_mut(&bid).unwrap(); let block_stack = code .allocations .iter() .enumerate() .filter(|(aid, _)| !inpromotables.contains(aid)) .map(|(aid, dtype)| { let initial_value = find_latest_value( aid, dtype.deref().clone(), bid, domtree, phinode_indexes, &block_stacks, ); (aid, initial_value) }) .collect::>(); let block_stack = block_stacks.entry(bid).or_insert(block_stack); for (i, inst) in block.instructions.iter().enumerate() { match inst.deref() { Instruction::Store { ptr, value } => { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if let Some(value_stack) = block_stack.get_mut(aid) { *value_stack = value.clone(); } } } Instruction::Load { ptr } => { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if let Some(value_stack) = block_stack.get(aid) { let _unused = replaces.insert(RegisterId::temp(bid, i), value_stack.clone()); } } } _ => (), } } match &mut block.exit { BlockExit::Jump { arg } => { fill_jump_args(arg, phinode_allocs, block_stack); } BlockExit::ConditionalJump { arg_then, arg_else, .. } => { fill_jump_args(arg_then, phinode_allocs, block_stack); fill_jump_args(arg_else, phinode_allocs, block_stack); } BlockExit::Switch { default, cases, .. } => { fill_jump_args(default, phinode_allocs, block_stack); for (_, arg) in cases { fill_jump_args(arg, phinode_allocs, block_stack); } } _ => (), } if let Some(bids_successors) = domtree.successors(&bid) { for bid_successor in bids_successors { traverse_po( *bid_successor, code, inpromotables, domtree, (phinode_indexes, phinode_allocs), block_stacks.clone(), replaces, ); } } } fn find_latest_value( aid: usize, dtype: Dtype, bid: BlockId, domtree: &Domtree, phinode_indexes: &HashMap<(usize, BlockId), usize>, block_stacks: &HashMap>, ) -> Operand { if let Some(block_stack) = block_stacks.get(&bid) { if let Some(value_stack) = block_stack.get(&aid) { return value_stack.clone(); } } if let Some(phinode_index) = phinode_indexes.get(&(aid, bid)) { Operand::register(RegisterId::arg(bid, *phinode_index), dtype) } else if let Some(bid_idom) = domtree.idom(&bid) { find_latest_value( aid, dtype, *bid_idom, domtree, phinode_indexes, block_stacks, ) } else { Operand::constant(Constant::undef(dtype)) } } fn fill_jump_args( arg: &mut JumpArg, phinode_allocs: &HashMap>, block_stack: &HashMap, ) { if let Some(phinode_allocs) = phinode_allocs.get(&arg.bid) { for aid in phinode_allocs { arg.args.push(block_stack[aid].clone()); } } }