diff --git a/src/opt/mem2reg.rs b/src/opt/mem2reg.rs index 671bab8..3e0e8ba 100644 --- a/src/opt/mem2reg.rs +++ b/src/opt/mem2reg.rs @@ -13,45 +13,44 @@ pub struct Mem2regInner {} impl Optimize for Mem2regInner { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { - let mut inpromotable = HashSet::new(); + let mut inpromotables = HashSet::new(); let mut stores = HashMap::new(); for (bid, block) in &code.blocks { for inst in &block.instructions { match inst.deref() { + Instruction::Nop | Instruction::Load { .. } => (), Instruction::BinOp { lhs, rhs, .. } => { - mark_inpromotable(&mut inpromotable, lhs); - mark_inpromotable(&mut inpromotable, rhs); + mark_as_inpromotable(&mut inpromotables, lhs); + mark_as_inpromotable(&mut inpromotables, rhs); } Instruction::UnaryOp { operand, .. } => { - mark_inpromotable(&mut inpromotable, operand); + mark_as_inpromotable(&mut inpromotables, operand); } Instruction::Store { ptr, value } => { - mark_inpromotable(&mut inpromotable, value); + mark_as_inpromotable(&mut inpromotables, value); if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { - stores.entry(*aid).or_insert_with(Vec::new).push(*bid); + let _ = stores.entry(*aid).or_insert_with(HashSet::new).insert(*bid); } } - Instruction::Load { .. } => (), Instruction::Call { callee, args, .. } => { - mark_inpromotable(&mut inpromotable, callee); + mark_as_inpromotable(&mut inpromotables, callee); for arg in args { - mark_inpromotable(&mut inpromotable, arg); + mark_as_inpromotable(&mut inpromotables, arg); } } Instruction::TypeCast { value, .. } => { - mark_inpromotable(&mut inpromotable, value); + mark_as_inpromotable(&mut inpromotables, value); } Instruction::GetElementPtr { ptr, .. } => { - mark_inpromotable(&mut inpromotable, ptr); + mark_as_inpromotable(&mut inpromotables, ptr); } - Instruction::Nop => (), - _ => todo!(), + _ => unreachable!(), } } } - if inpromotable.len() == code.allocations.len() { + if inpromotables.len() == code.allocations.len() { return false; } @@ -59,239 +58,75 @@ impl Optimize for Mem2regInner { let reverse_cfg = reverse_cfg(&cfg); let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg); - let joins = stores - .iter() - .filter(|(aid, _)| !inpromotable.contains(*aid)) + let phinodes = stores + .into_iter() + .filter(|(aid, _)| !inpromotables.contains(aid)) .map(|(aid, bids)| { - (*aid, { - let mut stack = bids.clone(); - let mut visited = HashSet::new(); - while let Some(bid) = stack.pop() { - if let Some(bid_frontiers) = domtree.frontiers(&bid) { - for bid_frontier in bid_frontiers { - if visited.insert(*bid_frontier) { - stack.push(*bid_frontier); - } + let mut stack = bids.into_iter().collect::>(); + let mut visited = HashSet::new(); + while let Some(bid) = stack.pop() { + if let Some(bid_frontiers) = domtree.frontiers(&bid) { + for bid_frontier in bid_frontiers { + if visited.insert(*bid_frontier) { + stack.push(*bid_frontier); } } } - visited - }) + } + (aid, visited) }) - .collect::>(); + .collect::>(); // aid -> [bid] + let mut phinode_indexes = HashMap::new(); // (aid, bid) -> phinode index + let mut phinode_allocs = HashMap::new(); // bid -> [(aid, dtype)] - let mut inv_joins = HashMap::new(); - for (aid, bids) in &joins { + for (aid, bids) in phinodes { for bid in bids { - inv_joins.entry(bid).or_insert_with(Vec::new).push(*aid); - } - } - for (_, aids) in inv_joins.iter_mut() { - aids.sort(); - } - - let mut phinode_indexes = HashMap::new(); - let mut phinode_allocs = HashMap::new(); - for (bid, aids) in &inv_joins { - let block = code.blocks.get_mut(bid).unwrap(); - let index = block.phinodes.len(); - for (i, aid) in aids.iter().enumerate() { - block.phinodes.push(code.allocations[*aid].clone()); - let _ = phinode_indexes.insert((*aid, **bid), index + i); + let block = code.blocks.get_mut(&bid).unwrap(); + block.phinodes.push(code.allocations[aid].clone()); + let _ = phinode_indexes.insert((aid, bid), block.phinodes.len() - 1); phinode_allocs - .entry(**bid) + .entry(bid) .or_insert_with(Vec::new) - .push((*aid, code.allocations[*aid].deref().clone())); + .push((aid, code.allocations[aid].deref().clone())); } } - // for (aid, bids) in &joins { - // let alloc = code.allocations.get(*aid).unwrap(); - // for bid in bids { - // // aid를 담는 Phinode를 넣을 곳 - // let block = code.blocks.get_mut(bid).unwrap(); - // let index = block.phinodes.len(); - // block.phinodes.push(alloc.clone()); - // let _ = phinode_indexes.insert((*aid, *bid), index); - // phinode_allocs - // .entry(*bid) - // .or_insert_with(Vec::new) - // .push((*aid, alloc.deref().clone())); - // } - // } - - let mut inv_domtree = HashMap::new(); - for (bid, idom) in &domtree.idoms { - inv_domtree.entry(*idom).or_insert_with(Vec::new).push(*bid); - } - - // println!("{:?}", domtree.idoms); - // println!("{:?}", inv_domtree); - - let mut stack = HashMap::new(); let mut replaces = HashMap::new(); - - fn find_initial( - aid: usize, - dtype: Dtype, - bid: &BlockId, - stack: &HashMap>>, - phinode_indexes: &HashMap<(usize, BlockId), usize>, - domtree: &Domtree, - ) -> Operand { - if let Some(block_stack) = stack.get(bid) { - if let Some(block_stack) = block_stack.get(&aid) { - return block_stack.last().unwrap().clone(); - } - } - - if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { - Operand::register(RegisterId::arg(*bid, *phinode_index), dtype) - } else if let Some(bid_idom) = domtree.idoms.get(bid) { - find_initial(aid, dtype, bid_idom, stack, phinode_indexes, domtree) - } else { - Operand::constant(Constant::undef(dtype)) - } - } - - struct Asdf<'a> { - code: &'a mut FunctionDefinition, - phinode_indexes: &'a HashMap<(usize, BlockId), usize>, - phinode_allocs: &'a HashMap>, - inv_domtree: &'a HashMap>, - stack: &'a mut HashMap>>, - replaces: &'a mut HashMap, - inpromotable: &'a HashSet, - domtree: &'a Domtree, - }; - - fn traverse_preorder(asdf: &mut Asdf<'_>, bid: &BlockId) { - let stack_org: HashMap>> = asdf.stack.clone(); - // let block_stack = stack.entry(*bid).or_insert_with(HashMap::new); - - for (aid, dtype) in asdf.code.allocations.iter().enumerate() { - if !asdf.inpromotable.contains(&aid) { - let initial = find_initial( - aid, - dtype.deref().clone(), - bid, - asdf.stack, - asdf.phinode_indexes, - asdf.domtree, - ); - - asdf.stack - .entry(*bid) - .or_default() - .entry(aid) - .or_default() - .push(initial); - // if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { - // entry.push(Operand::register( - // RegisterId::arg(*bid, *phinode_index), - // dtype.deref().clone(), - // )); - // } else { - // entry.push(Operand::constant(Constant::undef(dtype.deref().clone()))); - // } - } - } - // find_initial(code, inpromotable, bid, stack, phinode_indexes, domtree); - let block_stack = asdf.stack.entry(*bid).or_default(); - - for (i, inst) in asdf.code.blocks[bid].instructions.iter().enumerate() { - match inst.deref() { - Instruction::Store { ptr, value } => { - if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { - if !asdf.inpromotable.contains(aid) { - block_stack.get_mut(aid).unwrap().push(value.clone()); - } - } - } - Instruction::Load { ptr } => { - if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { - if !asdf.inpromotable.contains(aid) { - let _unused = asdf.replaces.insert( - RegisterId::temp(*bid, i), - block_stack[aid].last().unwrap().clone(), - ); - } - } - } - _ => (), - } - } - - let block = asdf.code.blocks.get_mut(bid).unwrap(); - match &mut block.exit { - BlockExit::Jump { arg } => { - fill_jump_args(arg, asdf.phinode_allocs, block_stack); - } - BlockExit::ConditionalJump { - arg_then, arg_else, .. - } => { - fill_jump_args(arg_then, asdf.phinode_allocs, block_stack); - fill_jump_args(arg_else, asdf.phinode_allocs, block_stack); - } - BlockExit::Switch { default, cases, .. } => { - fill_jump_args(default, asdf.phinode_allocs, block_stack); - for (_, arg) in cases { - fill_jump_args(arg, asdf.phinode_allocs, block_stack); - } - } - _ => (), - } - - if let Some(succs) = asdf.inv_domtree.get(bid) { - for succ in succs { - traverse_preorder(asdf, succ); - } - } - - *asdf.stack = stack_org; - } - - let bid_init = code.bid_init; - - let mut asdf = Asdf { + traverse_po( + code.bid_init, code, - phinode_indexes: &phinode_indexes, - phinode_allocs: &phinode_allocs, - inv_domtree: &inv_domtree, - stack: &mut stack, - replaces: &mut replaces, - inpromotable: &inpromotable, - domtree: &domtree, - }; + &inpromotables, + &domtree, + (&phinode_indexes, &phinode_allocs), + HashMap::new(), + &mut replaces, + ); - traverse_preorder(&mut asdf, &bid_init); - - for (bid, block) in &mut code.blocks { + for (bid, block) in code.blocks.iter_mut() { loop { let mut changed = false; for inst in block.instructions.iter_mut() { - changed = changed || replace_instruction_operands(inst, &replaces); + changed = replace_instruction_operands(inst, &replaces) || changed; } - changed = changed || replace_exit_operands(&mut block.exit, &replaces); + changed = replace_exit_operands(&mut block.exit, &replaces) || changed; + if !changed { break; } } - } - for block in code.blocks.values_mut() { for inst in block.instructions.iter_mut() { match inst.deref().deref() { Instruction::Store { ptr, .. } => { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { - if !inpromotable.contains(aid) { + if !inpromotables.contains(aid) { *inst.deref_mut() = Instruction::Nop; } } } Instruction::Load { ptr } => { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { - if !inpromotable.contains(aid) { + if !inpromotables.contains(aid) { *inst.deref_mut() = Instruction::Nop; } } @@ -301,28 +136,13 @@ impl Optimize for Mem2regInner { } } - // println!("replaces: {:?}\n", replaces); - - true // TODO + true } } -fn fill_jump_args( - arg: &mut JumpArg, - phinode_allocs: &HashMap>, - block_stack: &HashMap>, -) { - if let Some(target_phinode_args) = phinode_allocs.get(&arg.bid) { - for (target_phinode_arg, dtype) in target_phinode_args { - arg.args - .push(block_stack[target_phinode_arg].last().unwrap().clone()); - } - } -} - -fn mark_inpromotable(inpromotable: &mut HashSet, operand: &Operand) { +fn mark_as_inpromotable(inpromotables: &mut HashSet, operand: &Operand) { if let Some((RegisterId::Local { aid }, _)) = operand.get_register() { - let _ = inpromotable.insert(*aid); + let _ = inpromotables.insert(*aid); } } @@ -391,8 +211,8 @@ fn traverse_rpo(bid_init: BlockId, cfg: &HashMap>) -> Vec< struct Domtree { idoms: HashMap, + inverse_idoms: HashMap>, frontiers: HashMap>, - rpo: Vec, } impl Domtree { @@ -441,6 +261,14 @@ impl Domtree { } } + let mut inverse_idoms = HashMap::new(); + for (bid, idom) in &idoms { + let _ = inverse_idoms + .entry(*idom) + .or_insert_with(HashSet::new) + .insert(*bid); + } + let mut frontiers = HashMap::new(); for (bid, preds) in reverse_cfg.iter().filter(|(_, preds)| preds.len() > 1) { let idom = if let Some(idom) = idoms.get(bid) { @@ -459,8 +287,8 @@ impl Domtree { Self { idoms, + inverse_idoms, frontiers, - rpo, } } @@ -477,10 +305,6 @@ impl Domtree { } } - fn idom(&self, bid: &BlockId) -> Option<&BlockId> { - self.idoms.get(bid) - } - fn frontiers(&self, bid: &BlockId) -> Option<&Vec> { self.frontiers.get(bid) } @@ -505,6 +329,139 @@ fn intersect_idom( } } +type PhinodeInfo<'a> = ( + &'a HashMap<(usize, BlockId), usize>, + &'a HashMap>, +); + +fn traverse_po( + bid: BlockId, + code: &mut FunctionDefinition, + inpromotables: &HashSet, + domtree: &Domtree, + (phinode_indexes, phinode_allocs): PhinodeInfo<'_>, + mut block_stacks: HashMap>>, + replaces: &mut HashMap, +) { + let block = code.blocks.get_mut(&bid).unwrap(); + let block_stack = code + .allocations + .iter() + .enumerate() + .filter(|(aid, _)| !inpromotables.contains(aid)) + .map(|(aid, dtype)| { + let initial_value = find_end_value( + aid, + dtype.deref().clone(), + bid, + domtree, + phinode_indexes, + &block_stacks, + ); + (aid, vec![initial_value]) + }) + .collect::>(); + let block_stack = block_stacks.entry(bid).or_insert(block_stack); + + for (i, inst) in block.instructions.iter().enumerate() { + match inst.deref() { + Instruction::Store { ptr, value } => { + if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { + if let Some(value_stack) = block_stack.get_mut(aid) { + value_stack.push(value.clone()); + } + } + } + Instruction::Load { ptr } => { + if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { + if let Some(value_stack) = block_stack.get(aid) { + let _unused = replaces.insert( + RegisterId::temp(bid, i), + value_stack.last().unwrap().clone(), + ); + } + } + } + _ => (), + } + } + + match &mut block.exit { + BlockExit::Jump { arg } => { + fill_jump_args(arg, phinode_allocs, block_stack); + } + BlockExit::ConditionalJump { + arg_then, arg_else, .. + } => { + fill_jump_args(arg_then, phinode_allocs, block_stack); + fill_jump_args(arg_else, phinode_allocs, block_stack); + } + BlockExit::Switch { default, cases, .. } => { + fill_jump_args(default, phinode_allocs, block_stack); + for (_, arg) in cases { + fill_jump_args(arg, phinode_allocs, block_stack); + } + } + _ => (), + } + + if let Some(bids_successors) = domtree.inverse_idoms.get(&bid) { + for bid_successor in bids_successors { + traverse_po( + *bid_successor, + code, + inpromotables, + domtree, + (phinode_indexes, phinode_allocs), + block_stacks.clone(), + replaces, + ); + } + } +} + +fn find_end_value( + aid: usize, + dtype: Dtype, + bid: BlockId, + domtree: &Domtree, + phinode_indexes: &HashMap<(usize, BlockId), usize>, + block_stacks: &HashMap>>, +) -> Operand { + if let Some(block_stack) = block_stacks.get(&bid) { + if let Some(value_stack) = block_stack.get(&aid) { + return value_stack.last().unwrap().clone(); + } + } + + if let Some(phinode_index) = phinode_indexes.get(&(aid, bid)) { + Operand::register(RegisterId::arg(bid, *phinode_index), dtype) + } else if let Some(bid_idom) = domtree.idoms.get(&bid) { + find_end_value( + aid, + dtype, + *bid_idom, + domtree, + phinode_indexes, + block_stacks, + ) + } else { + Operand::constant(Constant::undef(dtype)) + } +} + +fn fill_jump_args( + arg: &mut JumpArg, + phinode_allocs: &HashMap>, + block_stack: &HashMap>, +) { + if let Some(phinode_allocs) = phinode_allocs.get(&arg.bid) { + for (aid, dtype) in phinode_allocs { + arg.args.push(block_stack[aid].last().unwrap().clone()); + } + } +} + fn replace_instruction_operands( inst: &mut Instruction, replaces: &HashMap, @@ -525,7 +482,7 @@ fn replace_instruction_operands( Instruction::Call { callee, args, .. } => { let mut changed = replace_operand(callee, replaces); for arg in args.iter_mut() { - changed = changed || replace_operand(arg, replaces); + changed = replace_operand(arg, replaces) || changed; } changed } @@ -543,8 +500,8 @@ fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap { let mut changed = false; - for arg in &mut arg.args { - changed = changed || replace_operand(arg, replaces); + for arg in arg.args.iter_mut() { + changed = replace_operand(arg, replaces) || changed; } changed } @@ -554,11 +511,11 @@ fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap { let mut changed = replace_operand(condition, replaces); - for arg in &mut arg_then.args { - changed = changed || replace_operand(arg, replaces); + for arg in arg_then.args.iter_mut() { + changed = replace_operand(arg, replaces) || changed; } - for arg in &mut arg_else.args { - changed = changed || replace_operand(arg, replaces); + for arg in arg_else.args.iter_mut() { + changed = replace_operand(arg, replaces) || changed; } changed } @@ -568,12 +525,12 @@ fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap { let mut changed = replace_operand(value, replaces); - for arg in &mut default.args { - changed = changed || replace_operand(arg, replaces); + for arg in default.args.iter_mut() { + changed = replace_operand(arg, replaces) || changed; } for (_, arg) in cases { - for arg in &mut arg.args { - changed = changed || replace_operand(arg, replaces); + for arg in arg.args.iter_mut() { + changed = replace_operand(arg, replaces) || changed; } } changed