diff --git a/src/opt/gvn.rs b/src/opt/gvn.rs index 9c8b832..f7a2485 100644 --- a/src/opt/gvn.rs +++ b/src/opt/gvn.rs @@ -17,7 +17,6 @@ pub struct GvnInner {} impl Optimize for GvnInner { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { - println!("hihi2352352"); let cfg = make_cfg(code); let mut reverse_cfg = reverse_cfg(&cfg); let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg); @@ -28,31 +27,21 @@ impl Optimize for GvnInner { let init_leader_table = leader_tables.entry(code.bid_init).or_default(); let mut replaces = HashMap::new(); - for (aid, _) in code.allocations.iter().enumerate() { - let number = get_register_number( - RegisterId::local(aid), - &mut register_table, - &expression_table, - init_leader_table, - ); - } - println!("hihi4"); - for bid in domtree.rpo() { - traverse_rpo( + visit_block( *bid, code, &mut reverse_cfg, &domtree, - &mut register_table, - &mut expression_table, - &mut leader_tables, + ( + &mut register_table, + &mut expression_table, + &mut leader_tables, + ), &mut replaces, ); } - println!("replaces: {replaces:?}"); - let mut result = false; for (bid, block) in code.blocks.iter_mut() { @@ -64,118 +53,105 @@ impl Optimize for GvnInner { changed = replace_exit_operands(&mut block.exit, &replaces) || changed; result = result || changed; - if !changed { break; } } } - println!("hihi2"); - result } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum Operand { +enum Number { Number(usize), - IrOperand(ir::Operand), + Constant(Constant), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] enum Expression { BinOp { op: ast::BinaryOperator, - lhs: Operand, - rhs: Operand, + lhs: Number, + rhs: Number, }, UnaryOp { op: ast::UnaryOperator, - operand: Operand, + operand: Number, }, TypeCast { - value: Operand, + value: Number, target_dtype: Dtype, }, GetElementPtr { - ptr: Operand, - offset: Operand, + ptr: Number, + offset: Number, dtype: Dtype, }, } -fn get_register_number( - rid: RegisterId, - register_table: &mut HashMap, - expression_table: &HashMap, - leader_table: &mut HashMap, -) -> NumberOrConstant { - let len = register_table.len() + expression_table.len(); - let number = register_table - .entry(rid) - .or_insert(NumberOrConstant::Number(len)) - .clone(); - if let NumberOrConstant::Number(number) = number { - let _unused = leader_table.insert(number, RegisterOrConstant::Register(rid)); - } - println!("[REG] {number:?}={rid:?}"); - number -} - -fn get_expression_number( - expr: Expression, - register_table: &HashMap, - expression_table: &mut HashMap, -) -> usize { - let len = register_table.len() + expression_table.len(); - let number = *expression_table.entry(expr.clone()).or_insert(len); - println!("[EXPR] {number}={expr:?}"); - number -} - -#[derive(Clone, Debug)] -enum RegisterOrConstant { +#[derive(Debug, Clone)] +enum LeaderValue { Register(RegisterId), Constant(Constant), } -#[derive(Clone, Debug, Hash, Eq, PartialEq)] -enum NumberOrConstant { - Number(usize), - Constant(Constant), +fn get_register_number( + rid: RegisterId, + register_table: &mut HashMap, + expression_table: &HashMap, + leader_table: &mut HashMap, +) -> Number { + let len = register_table.len() + expression_table.len(); + register_table + .entry(rid) + .or_insert_with(|| { + let _unused = leader_table.insert(len, LeaderValue::Register(rid)); + Number::Number(len) + }) + .clone() } -fn traverse_rpo( +fn get_expression_number( + expr: Expression, + register_table: &HashMap, + expression_table: &mut HashMap, +) -> usize { + let len = register_table.len() + expression_table.len(); + *expression_table.entry(expr).or_insert(len) +} + +type NumberingInfo<'a> = ( + &'a mut HashMap, + &'a mut HashMap, + &'a mut HashMap>, +); + +fn visit_block( bid: BlockId, code: &mut FunctionDefinition, reverse_cfg: &mut HashMap>, domtree: &Domtree, - register_table: &mut HashMap, - expression_table: &mut HashMap, - leader_tables: &mut HashMap>, - replaces: &mut HashMap, + (register_table, expression_table, leader_tables): NumberingInfo<'_>, + replaces: &mut HashMap, ) { - println!("bid: {bid}"); - let block = code.blocks.get_mut(&bid).unwrap(); - let idom_leader_table = domtree + let mut predecessors = reverse_cfg.get_mut(&bid); + let mut leader_table = domtree .idom(&bid) .and_then(|bid_idom| leader_tables.get(bid_idom)) .cloned() .unwrap_or_default(); - // let leader_table = leader_tables.entry(bid).or_insert(idom_leader_table); - let mut leader_table = idom_leader_table; - let mut predecessors = reverse_cfg.get_mut(&bid); if let Some(ref predecessors) = predecessors { - for (aid, _) in block.phinodes.iter().enumerate() { + for aid in 0..block.phinodes.len() { let rid = RegisterId::arg(bid, aid); let numbers = predecessors .iter() .map(|(_, arg)| match &arg.args[aid] { - ir::Operand::Constant(constant) => NumberOrConstant::Constant(constant.clone()), - ir::Operand::Register { rid, .. } => get_register_number( + Operand::Constant(constant) => Number::Constant(constant.clone()), + Operand::Register { rid, .. } => get_register_number( *rid, register_table, expression_table, @@ -184,17 +160,15 @@ fn traverse_rpo( }) .collect::>(); if numbers.len() == 1 { - let number = numbers.iter().next().unwrap().clone(); - match number { - NumberOrConstant::Number(number) => { - let _unused = register_table.insert(rid, NumberOrConstant::Number(number)); - let _unused = - leader_table.insert(number, RegisterOrConstant::Register(rid)); + match numbers.iter().next().unwrap().clone() { + Number::Number(number) => { + let _unused = register_table.insert(rid, Number::Number(number)); + let _unused = leader_table.insert(number, LeaderValue::Register(rid)); } - NumberOrConstant::Constant(constant) => { + Number::Constant(constant) => { let _unused = - register_table.insert(rid, NumberOrConstant::Constant(constant.clone())); - let _unused = replaces.insert(rid, ir::Operand::constant(constant.clone())); + register_table.insert(rid, Number::Constant(constant.clone())); + let _unused = replaces.insert(rid, Operand::constant(constant.clone())); } } } else { @@ -207,129 +181,71 @@ fn traverse_rpo( let mut phinode_allocs = HashMap::new(); for (i, inst) in block.instructions.iter_mut().enumerate() { - let expr = match inst.deref().deref() { - Instruction::BinOp { op, lhs, rhs, .. } => Expression::BinOp { - op: op.clone(), - lhs: ir_operand_to_gvn_operand( - lhs, - register_table, - expression_table, - &mut leader_table, - ), - rhs: ir_operand_to_gvn_operand( - rhs, - register_table, - expression_table, - &mut leader_table, - ), - }, - Instruction::UnaryOp { op, operand, .. } => Expression::UnaryOp { - op: op.clone(), - operand: ir_operand_to_gvn_operand( - operand, - register_table, - expression_table, - &mut leader_table, - ), - }, - Instruction::TypeCast { - value, - target_dtype, - } => Expression::TypeCast { - value: ir_operand_to_gvn_operand( - value, - register_table, - expression_table, - &mut leader_table, - ), - target_dtype: target_dtype.clone(), - }, - Instruction::GetElementPtr { ptr, offset, dtype } => Expression::GetElementPtr { - ptr: ir_operand_to_gvn_operand( - ptr, - register_table, - expression_table, - &mut leader_table, - ), - offset: ir_operand_to_gvn_operand( - offset, - register_table, - expression_table, - &mut leader_table, - ), - dtype: dtype.clone(), - }, - _ => continue, - }; - let number = get_expression_number(expr, register_table, expression_table); let rid = RegisterId::temp(bid, i); - - println!("{i} {number} {expression_table:?} {leader_table:?}"); + let number = get_expression_number( + if let Some(expr) = to_expression( + inst.deref(), + register_table, + expression_table, + &mut leader_table, + ) { + expr + } else { + continue; + }, + register_table, + expression_table, + ); if let Some(leader_value) = leader_table.get(&number) { match leader_value { - RegisterOrConstant::Register(leader_value) => { - let _unused = register_table.insert(rid, NumberOrConstant::Number(number)); + LeaderValue::Register(rid_leader_value) => { + let _unused = register_table.insert(rid, Number::Number(number)); let _unused = - replaces.insert(rid, ir::Operand::register(*leader_value, inst.dtype())); + replaces.insert(rid, Operand::register(*rid_leader_value, inst.dtype())); } - RegisterOrConstant::Constant(constant) => { - let _unused = replaces.insert(rid, ir::Operand::constant(constant.clone())); + LeaderValue::Constant(constant) => { + let _unused = replaces.insert(rid, Operand::constant(constant.clone())); } } } else { if let Some(ref mut predecessors) = predecessors { - println!("하위 {predecessors:?} {leader_tables:?}"); if predecessors.iter().all(|(bid_predecessor, _)| { leader_tables .get(bid_predecessor) - .and_then(|leader_table| Some(leader_table.contains_key(&number))) - .unwrap_or(false) + .is_some_and(|leader_table| leader_table.contains_key(&number)) }) { - println!("상위"); - - let phinode_index = block.phinodes.len(); + let rid_phinode = RegisterId::arg(bid, block.phinodes.len()); block.phinodes.push(Named::new(None, inst.dtype())); - let _unused = register_table.insert( - RegisterId::arg(bid, phinode_index), - NumberOrConstant::Number(number), - ); - let _unused = leader_table.insert( - number, - RegisterOrConstant::Register(RegisterId::arg(bid, phinode_index)), - ); - let _unused = replaces.insert( - rid, - ir::Operand::register(RegisterId::arg(bid, phinode_index), inst.dtype()), - ); + + let _unused = register_table.insert(rid_phinode, Number::Number(number)); + let _unused = leader_table.insert(number, LeaderValue::Register(rid_phinode)); + let _unused = + replaces.insert(rid, Operand::register(rid_phinode, inst.dtype())); for (bid_predecessor, jump_arg) in predecessors.iter_mut() { - let new_arg = match &leader_tables[bid_predecessor][&number] { - RegisterOrConstant::Register(rid) => { - ir::Operand::register(*rid, inst.dtype()) - } - RegisterOrConstant::Constant(constant) => { - ir::Operand::constant(constant.clone()) - } - }; - jump_arg.args.push(new_arg.clone()); - let _ = phinode_allocs + phinode_allocs .entry(*bid_predecessor) .or_insert_with(Vec::new) - .push(new_arg); + .push(match &leader_tables[bid_predecessor][&number] { + LeaderValue::Register(rid) => Operand::register(*rid, inst.dtype()), + LeaderValue::Constant(constant) => { + Operand::constant(constant.clone()) + } + }); } continue; } } - let _unused = leader_table.insert(number, RegisterOrConstant::Register(rid)); - let _unused = register_table.insert(rid, NumberOrConstant::Number(number)); + let _unused = leader_table.insert(number, LeaderValue::Register(rid)); + let _unused = register_table.insert(rid, Number::Number(number)); } } - for (p_bid, new_args) in phinode_allocs { - match &mut code.blocks.get_mut(&p_bid).unwrap().exit { + for (bid_target, new_args) in phinode_allocs { + match &mut code.blocks.get_mut(&bid_target).unwrap().exit { BlockExit::Jump { arg } => { fill_jump_args(bid, arg, new_args); } @@ -340,10 +256,10 @@ fn traverse_rpo( fill_jump_args(bid, arg_else, new_args); } BlockExit::Switch { default, cases, .. } => { - fill_jump_args(bid, default, new_args.clone()); for (_, arg) in cases { fill_jump_args(bid, arg, new_args.clone()); } + fill_jump_args(bid, default, new_args); } _ => (), } @@ -352,26 +268,53 @@ fn traverse_rpo( let _unused = leader_tables.insert(bid, leader_table); } -fn ir_operand_to_gvn_operand( - operand: &ir::Operand, - register_table: &mut HashMap, +fn to_expression( + inst: &Instruction, + register_table: &mut HashMap, expression_table: &HashMap, - leader_table: &mut HashMap, -) -> Operand { - match operand { - ir::Operand::Register { rid, .. } => { - match get_register_number(*rid, register_table, expression_table, leader_table) { - NumberOrConstant::Number(number) => Operand::Number(number), - NumberOrConstant::Constant(constant) => { - Operand::IrOperand(ir::Operand::Constant(constant)) - } - } - } - _ => Operand::IrOperand(operand.clone()), + leader_table: &mut HashMap, +) -> Option { + match inst { + Instruction::BinOp { op, lhs, rhs, .. } => Some(Expression::BinOp { + op: op.clone(), + lhs: get_operand_number(lhs, register_table, expression_table, leader_table), + rhs: get_operand_number(rhs, register_table, expression_table, leader_table), + }), + Instruction::UnaryOp { op, operand, .. } => Some(Expression::UnaryOp { + op: op.clone(), + operand: get_operand_number(operand, register_table, expression_table, leader_table), + }), + Instruction::TypeCast { + value, + target_dtype, + } => Some(Expression::TypeCast { + value: get_operand_number(value, register_table, expression_table, leader_table), + target_dtype: target_dtype.clone(), + }), + Instruction::GetElementPtr { ptr, offset, dtype } => Some(Expression::GetElementPtr { + ptr: get_operand_number(ptr, register_table, expression_table, leader_table), + offset: get_operand_number(offset, register_table, expression_table, leader_table), + dtype: dtype.clone(), + }), + _ => None, } } -fn fill_jump_args(bid: BlockId, arg: &mut JumpArg, mut new_args: Vec) { +fn get_operand_number( + operand: &Operand, + register_table: &mut HashMap, + expression_table: &HashMap, + leader_table: &mut HashMap, +) -> Number { + match operand { + Operand::Constant(constant) => Number::Constant(constant.clone()), + Operand::Register { rid, .. } => { + get_register_number(*rid, register_table, expression_table, leader_table) + } + } +} + +fn fill_jump_args(bid: BlockId, arg: &mut JumpArg, mut new_args: Vec) { if bid == arg.bid { arg.args.append(&mut new_args); } diff --git a/src/opt/opt_utils.rs b/src/opt/opt_utils.rs index 80d6a7b..ff69108 100644 --- a/src/opt/opt_utils.rs +++ b/src/opt/opt_utils.rs @@ -52,30 +52,6 @@ pub(crate) fn reverse_cfg( result } -fn traverse_rpo(bid_init: BlockId, cfg: &HashMap>) -> Vec { - fn traverse_po( - bid: BlockId, - cfg: &HashMap>, - visited: &mut HashSet, - post_order: &mut Vec, - ) { - for jump in &cfg[&bid] { - if visited.insert(jump.bid) { - traverse_po(jump.bid, cfg, visited, post_order); - } - } - post_order.push(bid); - } - - let mut visited = HashSet::new(); - let _ = visited.insert(bid_init); - let mut order = Vec::new(); - - traverse_po(bid_init, cfg, &mut visited, &mut order); - order.reverse(); - order -} - pub(crate) struct Domtree { idoms: HashMap, inverse_idoms: HashMap>, @@ -191,6 +167,30 @@ impl Domtree { } } +fn traverse_rpo(bid_init: BlockId, cfg: &HashMap>) -> Vec { + fn traverse_po( + bid: BlockId, + cfg: &HashMap>, + visited: &mut HashSet, + post_order: &mut Vec, + ) { + for jump in &cfg[&bid] { + if visited.insert(jump.bid) { + traverse_po(jump.bid, cfg, visited, post_order); + } + } + post_order.push(bid); + } + + let mut visited = HashSet::new(); + let _ = visited.insert(bid_init); + let mut order = Vec::new(); + + traverse_po(bid_init, cfg, &mut visited, &mut order); + order.reverse(); + order +} + fn intersect_idom( lhs: Option, mut rhs: BlockId, @@ -292,10 +292,7 @@ pub(crate) fn replace_exit_operands( } } -pub(crate) fn replace_operand( - operand: &mut Operand, - replaces: &HashMap, -) -> bool { +fn replace_operand(operand: &mut Operand, replaces: &HashMap) -> bool { if let Operand::Register { rid, .. } = operand { if let Some(new_operand) = replaces.get(rid) { if (operand != new_operand) {