From 1d9727e480bcd72dcd10e9a3d5829523e900fbad Mon Sep 17 00:00:00 2001 From: static Date: Tue, 20 May 2025 01:43:59 +0000 Subject: [PATCH] HW4 (2) --- src/opt/mem2reg.rs | 234 +++++++++++++++++++++++---------------------- 1 file changed, 122 insertions(+), 112 deletions(-) diff --git a/src/opt/mem2reg.rs b/src/opt/mem2reg.rs index 70fa5de..671bab8 100644 --- a/src/opt/mem2reg.rs +++ b/src/opt/mem2reg.rs @@ -46,7 +46,7 @@ impl Optimize for Mem2regInner { mark_inpromotable(&mut inpromotable, ptr); } Instruction::Nop => (), - _ => todo!() + _ => todo!(), } } } @@ -80,7 +80,7 @@ impl Optimize for Mem2regInner { }) .collect::>(); - let mut inv_joins = HashMap::new(); + let mut inv_joins = HashMap::new(); for (aid, bids) in &joins { for bid in bids { inv_joins.entry(bid).or_insert_with(Vec::new).push(*aid); @@ -104,8 +104,6 @@ impl Optimize for Mem2regInner { .push((*aid, code.allocations[*aid].deref().clone())); } } - - // for (aid, bids) in &joins { // let alloc = code.allocations.get(*aid).unwrap(); @@ -127,8 +125,8 @@ impl Optimize for Mem2regInner { inv_domtree.entry(*idom).or_insert_with(Vec::new).push(*bid); } - println!("{:?}", domtree.idoms); - println!("{:?}", inv_domtree); + // println!("{:?}", domtree.idoms); + // println!("{:?}", inv_domtree); let mut stack = HashMap::new(); let mut replaces = HashMap::new(); @@ -136,8 +134,6 @@ impl Optimize for Mem2regInner { fn find_initial( aid: usize, dtype: Dtype, - code: &FunctionDefinition, - inpromotable: &HashSet, bid: &BlockId, stack: &HashMap>>, phinode_indexes: &HashMap<(usize, BlockId), usize>, @@ -145,77 +141,50 @@ impl Optimize for Mem2regInner { ) -> Operand { if let Some(block_stack) = stack.get(bid) { if let Some(block_stack) = block_stack.get(&aid) { - block_stack.last().unwrap().clone() - } else { - if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { - Operand::register(RegisterId::arg(*bid, *phinode_index), dtype) - } else if let Some(bid_idom) = domtree.idoms.get(bid) { - find_initial( - aid, - dtype, - code, - inpromotable, - bid_idom, - stack, - phinode_indexes, - domtree, - ) - } else { - Operand::constant(Constant::undef(dtype)) - } + return block_stack.last().unwrap().clone(); } - } else if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { + } + + if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { Operand::register(RegisterId::arg(*bid, *phinode_index), dtype) } else if let Some(bid_idom) = domtree.idoms.get(bid) { - find_initial( - aid, - dtype, - code, - inpromotable, - bid_idom, - stack, - phinode_indexes, - domtree, - ) + find_initial(aid, dtype, bid_idom, stack, phinode_indexes, domtree) } else { Operand::constant(Constant::undef(dtype)) } } - fn traverse_preorder( - code: &mut FunctionDefinition, - phinode_indexes: &HashMap<(usize, BlockId), usize>, - phinode_allocs: &HashMap>, - inv_domtree: &HashMap>, - bid: &BlockId, - stack: &mut HashMap>>, - replaces: &mut HashMap, - inpromotable: &HashSet, - domtree: &Domtree, - ) { - println!("bid: {}", bid); + struct Asdf<'a> { + code: &'a mut FunctionDefinition, + phinode_indexes: &'a HashMap<(usize, BlockId), usize>, + phinode_allocs: &'a HashMap>, + inv_domtree: &'a HashMap>, + stack: &'a mut HashMap>>, + replaces: &'a mut HashMap, + inpromotable: &'a HashSet, + domtree: &'a Domtree, + }; - let stack_org = stack.clone(); + fn traverse_preorder(asdf: &mut Asdf<'_>, bid: &BlockId) { + let stack_org: HashMap>> = asdf.stack.clone(); // let block_stack = stack.entry(*bid).or_insert_with(HashMap::new); - for (aid, dtype) in code.allocations.iter().enumerate() { - if !inpromotable.contains(&aid) { + for (aid, dtype) in asdf.code.allocations.iter().enumerate() { + if !asdf.inpromotable.contains(&aid) { let initial = find_initial( aid, dtype.deref().clone(), - code, - inpromotable, bid, - stack, - phinode_indexes, - domtree, + asdf.stack, + asdf.phinode_indexes, + asdf.domtree, ); - stack + asdf.stack .entry(*bid) - .or_insert_with(HashMap::new) + .or_default() .entry(aid) - .or_insert_with(Vec::new) + .or_default() .push(initial); // if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { // entry.push(Operand::register( @@ -228,21 +197,21 @@ impl Optimize for Mem2regInner { } } // find_initial(code, inpromotable, bid, stack, phinode_indexes, domtree); - let block_stack = stack.entry(*bid).or_insert_with(HashMap::new); + let block_stack = asdf.stack.entry(*bid).or_default(); - for (i, inst) in code.blocks[bid].instructions.iter().enumerate() { + for (i, inst) in asdf.code.blocks[bid].instructions.iter().enumerate() { match inst.deref() { Instruction::Store { ptr, value } => { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { - if !inpromotable.contains(aid) { + if !asdf.inpromotable.contains(aid) { block_stack.get_mut(aid).unwrap().push(value.clone()); } } } Instruction::Load { ptr } => { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { - if !inpromotable.contains(aid) { - let _unused = replaces.insert( + if !asdf.inpromotable.contains(aid) { + let _unused = asdf.replaces.insert( RegisterId::temp(*bid, i), block_stack[aid].last().unwrap().clone(), ); @@ -253,64 +222,61 @@ impl Optimize for Mem2regInner { } } - let block = code.blocks.get_mut(bid).unwrap(); + let block = asdf.code.blocks.get_mut(bid).unwrap(); match &mut block.exit { BlockExit::Jump { arg } => { - fill_jump_args(arg, phinode_allocs, &block_stack); + fill_jump_args(arg, asdf.phinode_allocs, block_stack); } BlockExit::ConditionalJump { arg_then, arg_else, .. } => { - fill_jump_args(arg_then, phinode_allocs, &block_stack); - fill_jump_args(arg_else, phinode_allocs, &block_stack); + fill_jump_args(arg_then, asdf.phinode_allocs, block_stack); + fill_jump_args(arg_else, asdf.phinode_allocs, block_stack); } BlockExit::Switch { default, cases, .. } => { - fill_jump_args(default, phinode_allocs, &block_stack); + fill_jump_args(default, asdf.phinode_allocs, block_stack); for (_, arg) in cases { - fill_jump_args(arg, phinode_allocs, &block_stack); + fill_jump_args(arg, asdf.phinode_allocs, block_stack); } } _ => (), } - if let Some(succs) = inv_domtree.get(bid) { + if let Some(succs) = asdf.inv_domtree.get(bid) { for succ in succs { - traverse_preorder( - code, - phinode_indexes, - phinode_allocs, - inv_domtree, - succ, - stack, - replaces, - inpromotable, - domtree, - ); + traverse_preorder(asdf, succ); } } - *stack = stack_org; + *asdf.stack = stack_org; } let bid_init = code.bid_init; - traverse_preorder( + let mut asdf = Asdf { code, - &phinode_indexes, - &phinode_allocs, - &inv_domtree, - &bid_init, - &mut stack, - &mut replaces, - &inpromotable, - &domtree, - ); + phinode_indexes: &phinode_indexes, + phinode_allocs: &phinode_allocs, + inv_domtree: &inv_domtree, + stack: &mut stack, + replaces: &mut replaces, + inpromotable: &inpromotable, + domtree: &domtree, + }; + + traverse_preorder(&mut asdf, &bid_init); for (bid, block) in &mut code.blocks { - for inst in block.instructions.iter_mut() { - replace_instruction_operands(inst, &replaces); + loop { + let mut changed = false; + for inst in block.instructions.iter_mut() { + changed = changed || replace_instruction_operands(inst, &replaces); + } + changed = changed || replace_exit_operands(&mut block.exit, &replaces); + if !changed { + break; + } } - replace_exit_operands(&mut block.exit, &replaces); } for block in code.blocks.values_mut() { @@ -335,7 +301,7 @@ impl Optimize for Mem2regInner { } } - println!("replaces: {:?}\n", replaces); + // println!("replaces: {:?}\n", replaces); true // TODO } @@ -539,46 +505,90 @@ fn intersect_idom( } } -fn replace_instruction_operands(inst: &mut Instruction, replaces: &HashMap) { +fn replace_instruction_operands( + inst: &mut Instruction, + replaces: &HashMap, +) -> bool { match inst { Instruction::BinOp { lhs, rhs, .. } => { - replace_operand(lhs, replaces); - replace_operand(rhs, replaces); + let a = replace_operand(lhs, replaces); + let b = replace_operand(rhs, replaces); + a || b } Instruction::UnaryOp { operand, .. } => replace_operand(operand, replaces), Instruction::Store { ptr, value } => { - replace_operand(ptr, replaces); - replace_operand(value, replaces); + let a = replace_operand(ptr, replaces); + let b = replace_operand(value, replaces); + a || b } Instruction::Load { ptr } => replace_operand(ptr, replaces), Instruction::Call { callee, args, .. } => { - replace_operand(callee, replaces); + let mut changed = replace_operand(callee, replaces); for arg in args.iter_mut() { - replace_operand(arg, replaces); + changed = changed || replace_operand(arg, replaces); } + changed } Instruction::TypeCast { value, .. } => replace_operand(value, replaces), Instruction::GetElementPtr { ptr, offset, .. } => { - replace_operand(ptr, replaces); - replace_operand(offset, replaces); + let a = replace_operand(ptr, replaces); + let b = replace_operand(offset, replaces); + a || b } _ => unreachable!(), } } -fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap) { +fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap) -> bool { match exit { - BlockExit::ConditionalJump { condition, .. } => replace_operand(condition, replaces), - BlockExit::Switch { value, .. } => replace_operand(value, replaces), + BlockExit::Jump { arg } => { + let mut changed = false; + for arg in &mut arg.args { + changed = changed || replace_operand(arg, replaces); + } + changed + } + BlockExit::ConditionalJump { + condition, + arg_then, + arg_else, + } => { + let mut changed = replace_operand(condition, replaces); + for arg in &mut arg_then.args { + changed = changed || replace_operand(arg, replaces); + } + for arg in &mut arg_else.args { + changed = changed || replace_operand(arg, replaces); + } + changed + } + BlockExit::Switch { + value, + default, + cases, + } => { + let mut changed = replace_operand(value, replaces); + for arg in &mut default.args { + changed = changed || replace_operand(arg, replaces); + } + for (_, arg) in cases { + for arg in &mut arg.args { + changed = changed || replace_operand(arg, replaces); + } + } + changed + } BlockExit::Return { value } => replace_operand(value, replaces), - _ => (), + _ => false, } } -fn replace_operand(operand: &mut Operand, replaces: &HashMap) { +fn replace_operand(operand: &mut Operand, replaces: &HashMap) -> bool { if let Operand::Register { rid, .. } = operand { if let Some(new_operand) = replaces.get(rid) { *operand = new_operand.clone(); + return true; } } + false }