This commit is contained in:
static
2025-06-01 13:56:45 +00:00
parent 40ee759c23
commit f70b84f84f
2 changed files with 166 additions and 226 deletions

View File

@@ -17,7 +17,6 @@ pub struct GvnInner {}
impl Optimize<FunctionDefinition> for GvnInner { impl Optimize<FunctionDefinition> for GvnInner {
fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool {
println!("hihi2352352");
let cfg = make_cfg(code); let cfg = make_cfg(code);
let mut reverse_cfg = reverse_cfg(&cfg); let mut reverse_cfg = reverse_cfg(&cfg);
let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg); let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg);
@@ -28,31 +27,21 @@ impl Optimize<FunctionDefinition> for GvnInner {
let init_leader_table = leader_tables.entry(code.bid_init).or_default(); let init_leader_table = leader_tables.entry(code.bid_init).or_default();
let mut replaces = HashMap::new(); let mut replaces = HashMap::new();
for (aid, _) in code.allocations.iter().enumerate() {
let number = get_register_number(
RegisterId::local(aid),
&mut register_table,
&expression_table,
init_leader_table,
);
}
println!("hihi4");
for bid in domtree.rpo() { for bid in domtree.rpo() {
traverse_rpo( visit_block(
*bid, *bid,
code, code,
&mut reverse_cfg, &mut reverse_cfg,
&domtree, &domtree,
&mut register_table, (
&mut expression_table, &mut register_table,
&mut leader_tables, &mut expression_table,
&mut leader_tables,
),
&mut replaces, &mut replaces,
); );
} }
println!("replaces: {replaces:?}");
let mut result = false; let mut result = false;
for (bid, block) in code.blocks.iter_mut() { for (bid, block) in code.blocks.iter_mut() {
@@ -64,118 +53,105 @@ impl Optimize<FunctionDefinition> for GvnInner {
changed = replace_exit_operands(&mut block.exit, &replaces) || changed; changed = replace_exit_operands(&mut block.exit, &replaces) || changed;
result = result || changed; result = result || changed;
if !changed { if !changed {
break; break;
} }
} }
} }
println!("hihi2");
result result
} }
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Operand { enum Number {
Number(usize), Number(usize),
IrOperand(ir::Operand), Constant(Constant),
} }
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Expression { enum Expression {
BinOp { BinOp {
op: ast::BinaryOperator, op: ast::BinaryOperator,
lhs: Operand, lhs: Number,
rhs: Operand, rhs: Number,
}, },
UnaryOp { UnaryOp {
op: ast::UnaryOperator, op: ast::UnaryOperator,
operand: Operand, operand: Number,
}, },
TypeCast { TypeCast {
value: Operand, value: Number,
target_dtype: Dtype, target_dtype: Dtype,
}, },
GetElementPtr { GetElementPtr {
ptr: Operand, ptr: Number,
offset: Operand, offset: Number,
dtype: Dtype, dtype: Dtype,
}, },
} }
fn get_register_number( #[derive(Debug, Clone)]
rid: RegisterId, enum LeaderValue {
register_table: &mut HashMap<RegisterId, NumberOrConstant>,
expression_table: &HashMap<Expression, usize>,
leader_table: &mut HashMap<usize, RegisterOrConstant>,
) -> NumberOrConstant {
let len = register_table.len() + expression_table.len();
let number = register_table
.entry(rid)
.or_insert(NumberOrConstant::Number(len))
.clone();
if let NumberOrConstant::Number(number) = number {
let _unused = leader_table.insert(number, RegisterOrConstant::Register(rid));
}
println!("[REG] {number:?}={rid:?}");
number
}
fn get_expression_number(
expr: Expression,
register_table: &HashMap<RegisterId, NumberOrConstant>,
expression_table: &mut HashMap<Expression, usize>,
) -> usize {
let len = register_table.len() + expression_table.len();
let number = *expression_table.entry(expr.clone()).or_insert(len);
println!("[EXPR] {number}={expr:?}");
number
}
#[derive(Clone, Debug)]
enum RegisterOrConstant {
Register(RegisterId), Register(RegisterId),
Constant(Constant), Constant(Constant),
} }
#[derive(Clone, Debug, Hash, Eq, PartialEq)] fn get_register_number(
enum NumberOrConstant { rid: RegisterId,
Number(usize), register_table: &mut HashMap<RegisterId, Number>,
Constant(Constant), expression_table: &HashMap<Expression, usize>,
leader_table: &mut HashMap<usize, LeaderValue>,
) -> Number {
let len = register_table.len() + expression_table.len();
register_table
.entry(rid)
.or_insert_with(|| {
let _unused = leader_table.insert(len, LeaderValue::Register(rid));
Number::Number(len)
})
.clone()
} }
fn traverse_rpo( fn get_expression_number(
expr: Expression,
register_table: &HashMap<RegisterId, Number>,
expression_table: &mut HashMap<Expression, usize>,
) -> usize {
let len = register_table.len() + expression_table.len();
*expression_table.entry(expr).or_insert(len)
}
type NumberingInfo<'a> = (
&'a mut HashMap<RegisterId, Number>,
&'a mut HashMap<Expression, usize>,
&'a mut HashMap<BlockId, HashMap<usize, LeaderValue>>,
);
fn visit_block(
bid: BlockId, bid: BlockId,
code: &mut FunctionDefinition, code: &mut FunctionDefinition,
reverse_cfg: &mut HashMap<BlockId, Vec<(BlockId, JumpArg)>>, reverse_cfg: &mut HashMap<BlockId, Vec<(BlockId, JumpArg)>>,
domtree: &Domtree, domtree: &Domtree,
register_table: &mut HashMap<RegisterId, NumberOrConstant>, (register_table, expression_table, leader_tables): NumberingInfo<'_>,
expression_table: &mut HashMap<Expression, usize>, replaces: &mut HashMap<RegisterId, Operand>,
leader_tables: &mut HashMap<BlockId, HashMap<usize, RegisterOrConstant>>,
replaces: &mut HashMap<RegisterId, ir::Operand>,
) { ) {
println!("bid: {bid}");
let block = code.blocks.get_mut(&bid).unwrap(); let block = code.blocks.get_mut(&bid).unwrap();
let idom_leader_table = domtree let mut predecessors = reverse_cfg.get_mut(&bid);
let mut leader_table = domtree
.idom(&bid) .idom(&bid)
.and_then(|bid_idom| leader_tables.get(bid_idom)) .and_then(|bid_idom| leader_tables.get(bid_idom))
.cloned() .cloned()
.unwrap_or_default(); .unwrap_or_default();
// let leader_table = leader_tables.entry(bid).or_insert(idom_leader_table);
let mut leader_table = idom_leader_table;
let mut predecessors = reverse_cfg.get_mut(&bid);
if let Some(ref predecessors) = predecessors { if let Some(ref predecessors) = predecessors {
for (aid, _) in block.phinodes.iter().enumerate() { for aid in 0..block.phinodes.len() {
let rid = RegisterId::arg(bid, aid); let rid = RegisterId::arg(bid, aid);
let numbers = predecessors let numbers = predecessors
.iter() .iter()
.map(|(_, arg)| match &arg.args[aid] { .map(|(_, arg)| match &arg.args[aid] {
ir::Operand::Constant(constant) => NumberOrConstant::Constant(constant.clone()), Operand::Constant(constant) => Number::Constant(constant.clone()),
ir::Operand::Register { rid, .. } => get_register_number( Operand::Register { rid, .. } => get_register_number(
*rid, *rid,
register_table, register_table,
expression_table, expression_table,
@@ -184,17 +160,15 @@ fn traverse_rpo(
}) })
.collect::<HashSet<_>>(); .collect::<HashSet<_>>();
if numbers.len() == 1 { if numbers.len() == 1 {
let number = numbers.iter().next().unwrap().clone(); match numbers.iter().next().unwrap().clone() {
match number { Number::Number(number) => {
NumberOrConstant::Number(number) => { let _unused = register_table.insert(rid, Number::Number(number));
let _unused = register_table.insert(rid, NumberOrConstant::Number(number)); let _unused = leader_table.insert(number, LeaderValue::Register(rid));
let _unused =
leader_table.insert(number, RegisterOrConstant::Register(rid));
} }
NumberOrConstant::Constant(constant) => { Number::Constant(constant) => {
let _unused = let _unused =
register_table.insert(rid, NumberOrConstant::Constant(constant.clone())); register_table.insert(rid, Number::Constant(constant.clone()));
let _unused = replaces.insert(rid, ir::Operand::constant(constant.clone())); let _unused = replaces.insert(rid, Operand::constant(constant.clone()));
} }
} }
} else { } else {
@@ -207,129 +181,71 @@ fn traverse_rpo(
let mut phinode_allocs = HashMap::new(); let mut phinode_allocs = HashMap::new();
for (i, inst) in block.instructions.iter_mut().enumerate() { for (i, inst) in block.instructions.iter_mut().enumerate() {
let expr = match inst.deref().deref() {
Instruction::BinOp { op, lhs, rhs, .. } => Expression::BinOp {
op: op.clone(),
lhs: ir_operand_to_gvn_operand(
lhs,
register_table,
expression_table,
&mut leader_table,
),
rhs: ir_operand_to_gvn_operand(
rhs,
register_table,
expression_table,
&mut leader_table,
),
},
Instruction::UnaryOp { op, operand, .. } => Expression::UnaryOp {
op: op.clone(),
operand: ir_operand_to_gvn_operand(
operand,
register_table,
expression_table,
&mut leader_table,
),
},
Instruction::TypeCast {
value,
target_dtype,
} => Expression::TypeCast {
value: ir_operand_to_gvn_operand(
value,
register_table,
expression_table,
&mut leader_table,
),
target_dtype: target_dtype.clone(),
},
Instruction::GetElementPtr { ptr, offset, dtype } => Expression::GetElementPtr {
ptr: ir_operand_to_gvn_operand(
ptr,
register_table,
expression_table,
&mut leader_table,
),
offset: ir_operand_to_gvn_operand(
offset,
register_table,
expression_table,
&mut leader_table,
),
dtype: dtype.clone(),
},
_ => continue,
};
let number = get_expression_number(expr, register_table, expression_table);
let rid = RegisterId::temp(bid, i); let rid = RegisterId::temp(bid, i);
let number = get_expression_number(
println!("{i} {number} {expression_table:?} {leader_table:?}"); if let Some(expr) = to_expression(
inst.deref(),
register_table,
expression_table,
&mut leader_table,
) {
expr
} else {
continue;
},
register_table,
expression_table,
);
if let Some(leader_value) = leader_table.get(&number) { if let Some(leader_value) = leader_table.get(&number) {
match leader_value { match leader_value {
RegisterOrConstant::Register(leader_value) => { LeaderValue::Register(rid_leader_value) => {
let _unused = register_table.insert(rid, NumberOrConstant::Number(number)); let _unused = register_table.insert(rid, Number::Number(number));
let _unused = let _unused =
replaces.insert(rid, ir::Operand::register(*leader_value, inst.dtype())); replaces.insert(rid, Operand::register(*rid_leader_value, inst.dtype()));
} }
RegisterOrConstant::Constant(constant) => { LeaderValue::Constant(constant) => {
let _unused = replaces.insert(rid, ir::Operand::constant(constant.clone())); let _unused = replaces.insert(rid, Operand::constant(constant.clone()));
} }
} }
} else { } else {
if let Some(ref mut predecessors) = predecessors { if let Some(ref mut predecessors) = predecessors {
println!("하위 {predecessors:?} {leader_tables:?}");
if predecessors.iter().all(|(bid_predecessor, _)| { if predecessors.iter().all(|(bid_predecessor, _)| {
leader_tables leader_tables
.get(bid_predecessor) .get(bid_predecessor)
.and_then(|leader_table| Some(leader_table.contains_key(&number))) .is_some_and(|leader_table| leader_table.contains_key(&number))
.unwrap_or(false)
}) { }) {
println!("상위"); let rid_phinode = RegisterId::arg(bid, block.phinodes.len());
let phinode_index = block.phinodes.len();
block.phinodes.push(Named::new(None, inst.dtype())); block.phinodes.push(Named::new(None, inst.dtype()));
let _unused = register_table.insert(
RegisterId::arg(bid, phinode_index), let _unused = register_table.insert(rid_phinode, Number::Number(number));
NumberOrConstant::Number(number), let _unused = leader_table.insert(number, LeaderValue::Register(rid_phinode));
); let _unused =
let _unused = leader_table.insert( replaces.insert(rid, Operand::register(rid_phinode, inst.dtype()));
number,
RegisterOrConstant::Register(RegisterId::arg(bid, phinode_index)),
);
let _unused = replaces.insert(
rid,
ir::Operand::register(RegisterId::arg(bid, phinode_index), inst.dtype()),
);
for (bid_predecessor, jump_arg) in predecessors.iter_mut() { for (bid_predecessor, jump_arg) in predecessors.iter_mut() {
let new_arg = match &leader_tables[bid_predecessor][&number] { phinode_allocs
RegisterOrConstant::Register(rid) => {
ir::Operand::register(*rid, inst.dtype())
}
RegisterOrConstant::Constant(constant) => {
ir::Operand::constant(constant.clone())
}
};
jump_arg.args.push(new_arg.clone());
let _ = phinode_allocs
.entry(*bid_predecessor) .entry(*bid_predecessor)
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
.push(new_arg); .push(match &leader_tables[bid_predecessor][&number] {
LeaderValue::Register(rid) => Operand::register(*rid, inst.dtype()),
LeaderValue::Constant(constant) => {
Operand::constant(constant.clone())
}
});
} }
continue; continue;
} }
} }
let _unused = leader_table.insert(number, RegisterOrConstant::Register(rid)); let _unused = leader_table.insert(number, LeaderValue::Register(rid));
let _unused = register_table.insert(rid, NumberOrConstant::Number(number)); let _unused = register_table.insert(rid, Number::Number(number));
} }
} }
for (p_bid, new_args) in phinode_allocs { for (bid_target, new_args) in phinode_allocs {
match &mut code.blocks.get_mut(&p_bid).unwrap().exit { match &mut code.blocks.get_mut(&bid_target).unwrap().exit {
BlockExit::Jump { arg } => { BlockExit::Jump { arg } => {
fill_jump_args(bid, arg, new_args); fill_jump_args(bid, arg, new_args);
} }
@@ -340,10 +256,10 @@ fn traverse_rpo(
fill_jump_args(bid, arg_else, new_args); fill_jump_args(bid, arg_else, new_args);
} }
BlockExit::Switch { default, cases, .. } => { BlockExit::Switch { default, cases, .. } => {
fill_jump_args(bid, default, new_args.clone());
for (_, arg) in cases { for (_, arg) in cases {
fill_jump_args(bid, arg, new_args.clone()); fill_jump_args(bid, arg, new_args.clone());
} }
fill_jump_args(bid, default, new_args);
} }
_ => (), _ => (),
} }
@@ -352,26 +268,53 @@ fn traverse_rpo(
let _unused = leader_tables.insert(bid, leader_table); let _unused = leader_tables.insert(bid, leader_table);
} }
fn ir_operand_to_gvn_operand( fn to_expression(
operand: &ir::Operand, inst: &Instruction,
register_table: &mut HashMap<RegisterId, NumberOrConstant>, register_table: &mut HashMap<RegisterId, Number>,
expression_table: &HashMap<Expression, usize>, expression_table: &HashMap<Expression, usize>,
leader_table: &mut HashMap<usize, RegisterOrConstant>, leader_table: &mut HashMap<usize, LeaderValue>,
) -> Operand { ) -> Option<Expression> {
match operand { match inst {
ir::Operand::Register { rid, .. } => { Instruction::BinOp { op, lhs, rhs, .. } => Some(Expression::BinOp {
match get_register_number(*rid, register_table, expression_table, leader_table) { op: op.clone(),
NumberOrConstant::Number(number) => Operand::Number(number), lhs: get_operand_number(lhs, register_table, expression_table, leader_table),
NumberOrConstant::Constant(constant) => { rhs: get_operand_number(rhs, register_table, expression_table, leader_table),
Operand::IrOperand(ir::Operand::Constant(constant)) }),
} Instruction::UnaryOp { op, operand, .. } => Some(Expression::UnaryOp {
} op: op.clone(),
} operand: get_operand_number(operand, register_table, expression_table, leader_table),
_ => Operand::IrOperand(operand.clone()), }),
Instruction::TypeCast {
value,
target_dtype,
} => Some(Expression::TypeCast {
value: get_operand_number(value, register_table, expression_table, leader_table),
target_dtype: target_dtype.clone(),
}),
Instruction::GetElementPtr { ptr, offset, dtype } => Some(Expression::GetElementPtr {
ptr: get_operand_number(ptr, register_table, expression_table, leader_table),
offset: get_operand_number(offset, register_table, expression_table, leader_table),
dtype: dtype.clone(),
}),
_ => None,
} }
} }
fn fill_jump_args(bid: BlockId, arg: &mut JumpArg, mut new_args: Vec<ir::Operand>) { fn get_operand_number(
operand: &Operand,
register_table: &mut HashMap<RegisterId, Number>,
expression_table: &HashMap<Expression, usize>,
leader_table: &mut HashMap<usize, LeaderValue>,
) -> Number {
match operand {
Operand::Constant(constant) => Number::Constant(constant.clone()),
Operand::Register { rid, .. } => {
get_register_number(*rid, register_table, expression_table, leader_table)
}
}
}
fn fill_jump_args(bid: BlockId, arg: &mut JumpArg, mut new_args: Vec<Operand>) {
if bid == arg.bid { if bid == arg.bid {
arg.args.append(&mut new_args); arg.args.append(&mut new_args);
} }

View File

@@ -52,30 +52,6 @@ pub(crate) fn reverse_cfg(
result result
} }
fn traverse_rpo(bid_init: BlockId, cfg: &HashMap<BlockId, Vec<JumpArg>>) -> Vec<BlockId> {
fn traverse_po(
bid: BlockId,
cfg: &HashMap<BlockId, Vec<JumpArg>>,
visited: &mut HashSet<BlockId>,
post_order: &mut Vec<BlockId>,
) {
for jump in &cfg[&bid] {
if visited.insert(jump.bid) {
traverse_po(jump.bid, cfg, visited, post_order);
}
}
post_order.push(bid);
}
let mut visited = HashSet::new();
let _ = visited.insert(bid_init);
let mut order = Vec::new();
traverse_po(bid_init, cfg, &mut visited, &mut order);
order.reverse();
order
}
pub(crate) struct Domtree { pub(crate) struct Domtree {
idoms: HashMap<BlockId, BlockId>, idoms: HashMap<BlockId, BlockId>,
inverse_idoms: HashMap<BlockId, HashSet<BlockId>>, inverse_idoms: HashMap<BlockId, HashSet<BlockId>>,
@@ -191,6 +167,30 @@ impl Domtree {
} }
} }
fn traverse_rpo(bid_init: BlockId, cfg: &HashMap<BlockId, Vec<JumpArg>>) -> Vec<BlockId> {
fn traverse_po(
bid: BlockId,
cfg: &HashMap<BlockId, Vec<JumpArg>>,
visited: &mut HashSet<BlockId>,
post_order: &mut Vec<BlockId>,
) {
for jump in &cfg[&bid] {
if visited.insert(jump.bid) {
traverse_po(jump.bid, cfg, visited, post_order);
}
}
post_order.push(bid);
}
let mut visited = HashSet::new();
let _ = visited.insert(bid_init);
let mut order = Vec::new();
traverse_po(bid_init, cfg, &mut visited, &mut order);
order.reverse();
order
}
fn intersect_idom( fn intersect_idom(
lhs: Option<BlockId>, lhs: Option<BlockId>,
mut rhs: BlockId, mut rhs: BlockId,
@@ -292,10 +292,7 @@ pub(crate) fn replace_exit_operands(
} }
} }
pub(crate) fn replace_operand( fn replace_operand(operand: &mut Operand, replaces: &HashMap<RegisterId, Operand>) -> bool {
operand: &mut Operand,
replaces: &HashMap<RegisterId, Operand>,
) -> bool {
if let Operand::Register { rid, .. } = operand { if let Operand::Register { rid, .. } = operand {
if let Some(new_operand) = replaces.get(rid) { if let Some(new_operand) = replaces.get(rid) {
if (operand != new_operand) { if (operand != new_operand) {