diff --git a/src/opt/mem2reg.rs b/src/opt/mem2reg.rs index eda5bb8..70fa5de 100644 --- a/src/opt/mem2reg.rs +++ b/src/opt/mem2reg.rs @@ -1,4 +1,5 @@ use core::ops::{Deref, DerefMut}; +use std::cmp::Ordering; use std::collections::{BTreeMap, HashMap, HashSet}; use crate::ir::*; @@ -12,6 +13,572 @@ pub struct Mem2regInner {} impl Optimize for Mem2regInner { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { - todo!() + let mut inpromotable = HashSet::new(); + let mut stores = HashMap::new(); + + for (bid, block) in &code.blocks { + for inst in &block.instructions { + match inst.deref() { + Instruction::BinOp { lhs, rhs, .. } => { + mark_inpromotable(&mut inpromotable, lhs); + mark_inpromotable(&mut inpromotable, rhs); + } + Instruction::UnaryOp { operand, .. } => { + mark_inpromotable(&mut inpromotable, operand); + } + Instruction::Store { ptr, value } => { + mark_inpromotable(&mut inpromotable, value); + if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { + stores.entry(*aid).or_insert_with(Vec::new).push(*bid); + } + } + Instruction::Load { .. } => (), + Instruction::Call { callee, args, .. } => { + mark_inpromotable(&mut inpromotable, callee); + for arg in args { + mark_inpromotable(&mut inpromotable, arg); + } + } + Instruction::TypeCast { value, .. } => { + mark_inpromotable(&mut inpromotable, value); + } + Instruction::GetElementPtr { ptr, .. } => { + mark_inpromotable(&mut inpromotable, ptr); + } + Instruction::Nop => (), + _ => todo!() + } + } + } + + if inpromotable.len() == code.allocations.len() { + return false; + } + + let cfg = make_cfg(code); + let reverse_cfg = reverse_cfg(&cfg); + let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg); + + let joins = stores + .iter() + .filter(|(aid, _)| !inpromotable.contains(*aid)) + .map(|(aid, bids)| { + (*aid, { + let mut stack = bids.clone(); + let mut visited = HashSet::new(); + while let Some(bid) = stack.pop() { + if let Some(bid_frontiers) = domtree.frontiers(&bid) { + for bid_frontier in bid_frontiers { + if visited.insert(*bid_frontier) { + stack.push(*bid_frontier); + } + } + } + } + visited + }) + }) + .collect::>(); + + let mut inv_joins = HashMap::new(); + for (aid, bids) in &joins { + for bid in bids { + inv_joins.entry(bid).or_insert_with(Vec::new).push(*aid); + } + } + for (_, aids) in inv_joins.iter_mut() { + aids.sort(); + } + + let mut phinode_indexes = HashMap::new(); + let mut phinode_allocs = HashMap::new(); + for (bid, aids) in &inv_joins { + let block = code.blocks.get_mut(bid).unwrap(); + let index = block.phinodes.len(); + for (i, aid) in aids.iter().enumerate() { + block.phinodes.push(code.allocations[*aid].clone()); + let _ = phinode_indexes.insert((*aid, **bid), index + i); + phinode_allocs + .entry(**bid) + .or_insert_with(Vec::new) + .push((*aid, code.allocations[*aid].deref().clone())); + } + } + + + + // for (aid, bids) in &joins { + // let alloc = code.allocations.get(*aid).unwrap(); + // for bid in bids { + // // aid를 담는 Phinode를 넣을 곳 + // let block = code.blocks.get_mut(bid).unwrap(); + // let index = block.phinodes.len(); + // block.phinodes.push(alloc.clone()); + // let _ = phinode_indexes.insert((*aid, *bid), index); + // phinode_allocs + // .entry(*bid) + // .or_insert_with(Vec::new) + // .push((*aid, alloc.deref().clone())); + // } + // } + + let mut inv_domtree = HashMap::new(); + for (bid, idom) in &domtree.idoms { + inv_domtree.entry(*idom).or_insert_with(Vec::new).push(*bid); + } + + println!("{:?}", domtree.idoms); + println!("{:?}", inv_domtree); + + let mut stack = HashMap::new(); + let mut replaces = HashMap::new(); + + fn find_initial( + aid: usize, + dtype: Dtype, + code: &FunctionDefinition, + inpromotable: &HashSet, + bid: &BlockId, + stack: &HashMap>>, + phinode_indexes: &HashMap<(usize, BlockId), usize>, + domtree: &Domtree, + ) -> Operand { + if let Some(block_stack) = stack.get(bid) { + if let Some(block_stack) = block_stack.get(&aid) { + block_stack.last().unwrap().clone() + } else { + if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { + Operand::register(RegisterId::arg(*bid, *phinode_index), dtype) + } else if let Some(bid_idom) = domtree.idoms.get(bid) { + find_initial( + aid, + dtype, + code, + inpromotable, + bid_idom, + stack, + phinode_indexes, + domtree, + ) + } else { + Operand::constant(Constant::undef(dtype)) + } + } + } else if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { + Operand::register(RegisterId::arg(*bid, *phinode_index), dtype) + } else if let Some(bid_idom) = domtree.idoms.get(bid) { + find_initial( + aid, + dtype, + code, + inpromotable, + bid_idom, + stack, + phinode_indexes, + domtree, + ) + } else { + Operand::constant(Constant::undef(dtype)) + } + } + + fn traverse_preorder( + code: &mut FunctionDefinition, + phinode_indexes: &HashMap<(usize, BlockId), usize>, + phinode_allocs: &HashMap>, + inv_domtree: &HashMap>, + bid: &BlockId, + stack: &mut HashMap>>, + replaces: &mut HashMap, + inpromotable: &HashSet, + domtree: &Domtree, + ) { + println!("bid: {}", bid); + + let stack_org = stack.clone(); + // let block_stack = stack.entry(*bid).or_insert_with(HashMap::new); + + for (aid, dtype) in code.allocations.iter().enumerate() { + if !inpromotable.contains(&aid) { + let initial = find_initial( + aid, + dtype.deref().clone(), + code, + inpromotable, + bid, + stack, + phinode_indexes, + domtree, + ); + + stack + .entry(*bid) + .or_insert_with(HashMap::new) + .entry(aid) + .or_insert_with(Vec::new) + .push(initial); + // if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { + // entry.push(Operand::register( + // RegisterId::arg(*bid, *phinode_index), + // dtype.deref().clone(), + // )); + // } else { + // entry.push(Operand::constant(Constant::undef(dtype.deref().clone()))); + // } + } + } + // find_initial(code, inpromotable, bid, stack, phinode_indexes, domtree); + let block_stack = stack.entry(*bid).or_insert_with(HashMap::new); + + for (i, inst) in code.blocks[bid].instructions.iter().enumerate() { + match inst.deref() { + Instruction::Store { ptr, value } => { + if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { + if !inpromotable.contains(aid) { + block_stack.get_mut(aid).unwrap().push(value.clone()); + } + } + } + Instruction::Load { ptr } => { + if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { + if !inpromotable.contains(aid) { + let _unused = replaces.insert( + RegisterId::temp(*bid, i), + block_stack[aid].last().unwrap().clone(), + ); + } + } + } + _ => (), + } + } + + let block = code.blocks.get_mut(bid).unwrap(); + match &mut block.exit { + BlockExit::Jump { arg } => { + fill_jump_args(arg, phinode_allocs, &block_stack); + } + BlockExit::ConditionalJump { + arg_then, arg_else, .. + } => { + fill_jump_args(arg_then, phinode_allocs, &block_stack); + fill_jump_args(arg_else, phinode_allocs, &block_stack); + } + BlockExit::Switch { default, cases, .. } => { + fill_jump_args(default, phinode_allocs, &block_stack); + for (_, arg) in cases { + fill_jump_args(arg, phinode_allocs, &block_stack); + } + } + _ => (), + } + + if let Some(succs) = inv_domtree.get(bid) { + for succ in succs { + traverse_preorder( + code, + phinode_indexes, + phinode_allocs, + inv_domtree, + succ, + stack, + replaces, + inpromotable, + domtree, + ); + } + } + + *stack = stack_org; + } + + let bid_init = code.bid_init; + + traverse_preorder( + code, + &phinode_indexes, + &phinode_allocs, + &inv_domtree, + &bid_init, + &mut stack, + &mut replaces, + &inpromotable, + &domtree, + ); + + for (bid, block) in &mut code.blocks { + for inst in block.instructions.iter_mut() { + replace_instruction_operands(inst, &replaces); + } + replace_exit_operands(&mut block.exit, &replaces); + } + + for block in code.blocks.values_mut() { + for inst in block.instructions.iter_mut() { + match inst.deref().deref() { + Instruction::Store { ptr, .. } => { + if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { + if !inpromotable.contains(aid) { + *inst.deref_mut() = Instruction::Nop; + } + } + } + Instruction::Load { ptr } => { + if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { + if !inpromotable.contains(aid) { + *inst.deref_mut() = Instruction::Nop; + } + } + } + _ => (), + } + } + } + + println!("replaces: {:?}\n", replaces); + + true // TODO + } +} + +fn fill_jump_args( + arg: &mut JumpArg, + phinode_allocs: &HashMap>, + block_stack: &HashMap>, +) { + if let Some(target_phinode_args) = phinode_allocs.get(&arg.bid) { + for (target_phinode_arg, dtype) in target_phinode_args { + arg.args + .push(block_stack[target_phinode_arg].last().unwrap().clone()); + } + } +} + +fn mark_inpromotable(inpromotable: &mut HashSet, operand: &Operand) { + if let Some((RegisterId::Local { aid }, _)) = operand.get_register() { + let _ = inpromotable.insert(*aid); + } +} + +fn make_cfg(fdef: &FunctionDefinition) -> HashMap> { + fdef.blocks + .iter() + .map(|(bid, block)| { + let mut args = Vec::new(); + match &block.exit { + BlockExit::Jump { arg } => args.push(arg.clone()), + BlockExit::ConditionalJump { + arg_then, arg_else, .. + } => { + args.push(arg_then.clone()); + args.push(arg_else.clone()); + } + BlockExit::Switch { default, cases, .. } => { + args.push(default.clone()); + for (_, arg) in cases { + args.push(arg.clone()); + } + } + _ => (), + } + (*bid, args) + }) + .collect() +} + +fn reverse_cfg(cfg: &HashMap>) -> HashMap> { + let mut result = HashMap::new(); + for (bid, jumps) in cfg { + for jump in jumps { + result + .entry(jump.bid) + .or_insert_with(Vec::new) + .push((*bid, jump.clone())); + } + } + result +} + +fn traverse_rpo(bid_init: BlockId, cfg: &HashMap>) -> Vec { + fn traverse_po( + bid: BlockId, + cfg: &HashMap>, + visited: &mut HashSet, + post_order: &mut Vec, + ) { + for jump in &cfg[&bid] { + if visited.insert(jump.bid) { + traverse_po(jump.bid, cfg, visited, post_order); + } + } + post_order.push(bid); + } + + let mut visited = HashSet::new(); + let _ = visited.insert(bid_init); + let mut order = Vec::new(); + + traverse_po(bid_init, cfg, &mut visited, &mut order); + order.reverse(); + order +} + +struct Domtree { + idoms: HashMap, + frontiers: HashMap>, + rpo: Vec, +} + +impl Domtree { + fn new( + bid_init: BlockId, + cfg: &HashMap>, + reverse_cfg: &HashMap>, + ) -> Self { + let rpo = traverse_rpo(bid_init, cfg); + let irpo = rpo.iter().enumerate().map(|(i, bid)| (*bid, i)).collect(); + let mut idoms: HashMap = HashMap::new(); + + loop { + let mut changed = false; + + for bid in &rpo { + if *bid == bid_init { + continue; + } + + let mut idom = None; + for (bid_pred, _) in &reverse_cfg[bid] { + if *bid_pred == bid_init || idoms.contains_key(bid_pred) { + idom = Some(intersect_idom(idom, *bid_pred, &irpo, &idoms)); + } + } + + if let Some(idom) = idom { + let _ = idoms + .entry(*bid) + .and_modify(|v| { + if *v != idom { + changed = true; + *v = idom; + } + }) + .or_insert_with(|| { + changed = true; + idom + }); + } + } + + if !changed { + break; + } + } + + let mut frontiers = HashMap::new(); + for (bid, preds) in reverse_cfg.iter().filter(|(_, preds)| preds.len() > 1) { + let idom = if let Some(idom) = idoms.get(bid) { + idom + } else { + continue; + }; + for (bid_pred, _) in preds { + let mut runner = *bid_pred; + while !Self::dominates(&idoms, runner, *bid) { + frontiers.entry(runner).or_insert_with(Vec::new).push(*bid); + runner = idoms[&runner]; + } + } + } + + Self { + idoms, + frontiers, + rpo, + } + } + + fn dominates(idoms: &HashMap, bid1: BlockId, mut bid2: BlockId) -> bool { + loop { + bid2 = if let Some(idom2) = idoms.get(&bid2) { + *idom2 + } else { + return false; + }; + if bid1 == bid2 { + return true; + } + } + } + + fn idom(&self, bid: &BlockId) -> Option<&BlockId> { + self.idoms.get(bid) + } + + fn frontiers(&self, bid: &BlockId) -> Option<&Vec> { + self.frontiers.get(bid) + } +} + +fn intersect_idom( + lhs: Option, + mut rhs: BlockId, + irpo: &HashMap, + idoms: &HashMap, +) -> BlockId { + let mut lhs = if let Some(lhs) = lhs { lhs } else { return rhs }; + loop { + if lhs == rhs { + return lhs; + } + match irpo[&lhs].cmp(&irpo[&rhs]) { + Ordering::Less => rhs = idoms[&rhs], + Ordering::Greater => lhs = idoms[&lhs], + Ordering::Equal => unreachable!(), + } + } +} + +fn replace_instruction_operands(inst: &mut Instruction, replaces: &HashMap) { + match inst { + Instruction::BinOp { lhs, rhs, .. } => { + replace_operand(lhs, replaces); + replace_operand(rhs, replaces); + } + Instruction::UnaryOp { operand, .. } => replace_operand(operand, replaces), + Instruction::Store { ptr, value } => { + replace_operand(ptr, replaces); + replace_operand(value, replaces); + } + Instruction::Load { ptr } => replace_operand(ptr, replaces), + Instruction::Call { callee, args, .. } => { + replace_operand(callee, replaces); + for arg in args.iter_mut() { + replace_operand(arg, replaces); + } + } + Instruction::TypeCast { value, .. } => replace_operand(value, replaces), + Instruction::GetElementPtr { ptr, offset, .. } => { + replace_operand(ptr, replaces); + replace_operand(offset, replaces); + } + _ => unreachable!(), + } +} + +fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap) { + match exit { + BlockExit::ConditionalJump { condition, .. } => replace_operand(condition, replaces), + BlockExit::Switch { value, .. } => replace_operand(value, replaces), + BlockExit::Return { value } => replace_operand(value, replaces), + _ => (), + } +} + +fn replace_operand(operand: &mut Operand, replaces: &HashMap) { + if let Operand::Register { rid, .. } = operand { + if let Some(new_operand) = replaces.get(rid) { + *operand = new_operand.clone(); + } } }