This commit is contained in:
static
2025-05-26 06:35:02 +00:00
parent 1d9727e480
commit a4cde0207b

View File

@@ -13,45 +13,44 @@ pub struct Mem2regInner {}
impl Optimize<FunctionDefinition> for Mem2regInner { impl Optimize<FunctionDefinition> for Mem2regInner {
fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool {
let mut inpromotable = HashSet::new(); let mut inpromotables = HashSet::new();
let mut stores = HashMap::new(); let mut stores = HashMap::new();
for (bid, block) in &code.blocks { for (bid, block) in &code.blocks {
for inst in &block.instructions { for inst in &block.instructions {
match inst.deref() { match inst.deref() {
Instruction::Nop | Instruction::Load { .. } => (),
Instruction::BinOp { lhs, rhs, .. } => { Instruction::BinOp { lhs, rhs, .. } => {
mark_inpromotable(&mut inpromotable, lhs); mark_as_inpromotable(&mut inpromotables, lhs);
mark_inpromotable(&mut inpromotable, rhs); mark_as_inpromotable(&mut inpromotables, rhs);
} }
Instruction::UnaryOp { operand, .. } => { Instruction::UnaryOp { operand, .. } => {
mark_inpromotable(&mut inpromotable, operand); mark_as_inpromotable(&mut inpromotables, operand);
} }
Instruction::Store { ptr, value } => { Instruction::Store { ptr, value } => {
mark_inpromotable(&mut inpromotable, value); mark_as_inpromotable(&mut inpromotables, value);
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
stores.entry(*aid).or_insert_with(Vec::new).push(*bid); let _ = stores.entry(*aid).or_insert_with(HashSet::new).insert(*bid);
} }
} }
Instruction::Load { .. } => (),
Instruction::Call { callee, args, .. } => { Instruction::Call { callee, args, .. } => {
mark_inpromotable(&mut inpromotable, callee); mark_as_inpromotable(&mut inpromotables, callee);
for arg in args { for arg in args {
mark_inpromotable(&mut inpromotable, arg); mark_as_inpromotable(&mut inpromotables, arg);
} }
} }
Instruction::TypeCast { value, .. } => { Instruction::TypeCast { value, .. } => {
mark_inpromotable(&mut inpromotable, value); mark_as_inpromotable(&mut inpromotables, value);
} }
Instruction::GetElementPtr { ptr, .. } => { Instruction::GetElementPtr { ptr, .. } => {
mark_inpromotable(&mut inpromotable, ptr); mark_as_inpromotable(&mut inpromotables, ptr);
} }
Instruction::Nop => (), _ => unreachable!(),
_ => todo!(),
} }
} }
} }
if inpromotable.len() == code.allocations.len() { if inpromotables.len() == code.allocations.len() {
return false; return false;
} }
@@ -59,239 +58,75 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
let reverse_cfg = reverse_cfg(&cfg); let reverse_cfg = reverse_cfg(&cfg);
let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg); let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg);
let joins = stores let phinodes = stores
.iter() .into_iter()
.filter(|(aid, _)| !inpromotable.contains(*aid)) .filter(|(aid, _)| !inpromotables.contains(aid))
.map(|(aid, bids)| { .map(|(aid, bids)| {
(*aid, { let mut stack = bids.into_iter().collect::<Vec<_>>();
let mut stack = bids.clone(); let mut visited = HashSet::new();
let mut visited = HashSet::new(); while let Some(bid) = stack.pop() {
while let Some(bid) = stack.pop() { if let Some(bid_frontiers) = domtree.frontiers(&bid) {
if let Some(bid_frontiers) = domtree.frontiers(&bid) { for bid_frontier in bid_frontiers {
for bid_frontier in bid_frontiers { if visited.insert(*bid_frontier) {
if visited.insert(*bid_frontier) { stack.push(*bid_frontier);
stack.push(*bid_frontier);
}
} }
} }
} }
visited }
}) (aid, visited)
}) })
.collect::<HashMap<_, _>>(); .collect::<BTreeMap<_, _>>(); // aid -> [bid]
let mut phinode_indexes = HashMap::new(); // (aid, bid) -> phinode index
let mut phinode_allocs = HashMap::new(); // bid -> [(aid, dtype)]
let mut inv_joins = HashMap::new(); for (aid, bids) in phinodes {
for (aid, bids) in &joins {
for bid in bids { for bid in bids {
inv_joins.entry(bid).or_insert_with(Vec::new).push(*aid); let block = code.blocks.get_mut(&bid).unwrap();
} block.phinodes.push(code.allocations[aid].clone());
} let _ = phinode_indexes.insert((aid, bid), block.phinodes.len() - 1);
for (_, aids) in inv_joins.iter_mut() {
aids.sort();
}
let mut phinode_indexes = HashMap::new();
let mut phinode_allocs = HashMap::new();
for (bid, aids) in &inv_joins {
let block = code.blocks.get_mut(bid).unwrap();
let index = block.phinodes.len();
for (i, aid) in aids.iter().enumerate() {
block.phinodes.push(code.allocations[*aid].clone());
let _ = phinode_indexes.insert((*aid, **bid), index + i);
phinode_allocs phinode_allocs
.entry(**bid) .entry(bid)
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
.push((*aid, code.allocations[*aid].deref().clone())); .push((aid, code.allocations[aid].deref().clone()));
} }
} }
// for (aid, bids) in &joins {
// let alloc = code.allocations.get(*aid).unwrap();
// for bid in bids {
// // aid를 담는 Phinode를 넣을 곳
// let block = code.blocks.get_mut(bid).unwrap();
// let index = block.phinodes.len();
// block.phinodes.push(alloc.clone());
// let _ = phinode_indexes.insert((*aid, *bid), index);
// phinode_allocs
// .entry(*bid)
// .or_insert_with(Vec::new)
// .push((*aid, alloc.deref().clone()));
// }
// }
let mut inv_domtree = HashMap::new();
for (bid, idom) in &domtree.idoms {
inv_domtree.entry(*idom).or_insert_with(Vec::new).push(*bid);
}
// println!("{:?}", domtree.idoms);
// println!("{:?}", inv_domtree);
let mut stack = HashMap::new();
let mut replaces = HashMap::new(); let mut replaces = HashMap::new();
traverse_po(
fn find_initial( code.bid_init,
aid: usize,
dtype: Dtype,
bid: &BlockId,
stack: &HashMap<BlockId, HashMap<usize, Vec<Operand>>>,
phinode_indexes: &HashMap<(usize, BlockId), usize>,
domtree: &Domtree,
) -> Operand {
if let Some(block_stack) = stack.get(bid) {
if let Some(block_stack) = block_stack.get(&aid) {
return block_stack.last().unwrap().clone();
}
}
if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) {
Operand::register(RegisterId::arg(*bid, *phinode_index), dtype)
} else if let Some(bid_idom) = domtree.idoms.get(bid) {
find_initial(aid, dtype, bid_idom, stack, phinode_indexes, domtree)
} else {
Operand::constant(Constant::undef(dtype))
}
}
struct Asdf<'a> {
code: &'a mut FunctionDefinition,
phinode_indexes: &'a HashMap<(usize, BlockId), usize>,
phinode_allocs: &'a HashMap<BlockId, Vec<(usize, Dtype)>>,
inv_domtree: &'a HashMap<BlockId, Vec<BlockId>>,
stack: &'a mut HashMap<BlockId, HashMap<usize, Vec<Operand>>>,
replaces: &'a mut HashMap<RegisterId, Operand>,
inpromotable: &'a HashSet<usize>,
domtree: &'a Domtree,
};
fn traverse_preorder(asdf: &mut Asdf<'_>, bid: &BlockId) {
let stack_org: HashMap<BlockId, HashMap<usize, Vec<Operand>>> = asdf.stack.clone();
// let block_stack = stack.entry(*bid).or_insert_with(HashMap::new);
for (aid, dtype) in asdf.code.allocations.iter().enumerate() {
if !asdf.inpromotable.contains(&aid) {
let initial = find_initial(
aid,
dtype.deref().clone(),
bid,
asdf.stack,
asdf.phinode_indexes,
asdf.domtree,
);
asdf.stack
.entry(*bid)
.or_default()
.entry(aid)
.or_default()
.push(initial);
// if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) {
// entry.push(Operand::register(
// RegisterId::arg(*bid, *phinode_index),
// dtype.deref().clone(),
// ));
// } else {
// entry.push(Operand::constant(Constant::undef(dtype.deref().clone())));
// }
}
}
// find_initial(code, inpromotable, bid, stack, phinode_indexes, domtree);
let block_stack = asdf.stack.entry(*bid).or_default();
for (i, inst) in asdf.code.blocks[bid].instructions.iter().enumerate() {
match inst.deref() {
Instruction::Store { ptr, value } => {
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
if !asdf.inpromotable.contains(aid) {
block_stack.get_mut(aid).unwrap().push(value.clone());
}
}
}
Instruction::Load { ptr } => {
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
if !asdf.inpromotable.contains(aid) {
let _unused = asdf.replaces.insert(
RegisterId::temp(*bid, i),
block_stack[aid].last().unwrap().clone(),
);
}
}
}
_ => (),
}
}
let block = asdf.code.blocks.get_mut(bid).unwrap();
match &mut block.exit {
BlockExit::Jump { arg } => {
fill_jump_args(arg, asdf.phinode_allocs, block_stack);
}
BlockExit::ConditionalJump {
arg_then, arg_else, ..
} => {
fill_jump_args(arg_then, asdf.phinode_allocs, block_stack);
fill_jump_args(arg_else, asdf.phinode_allocs, block_stack);
}
BlockExit::Switch { default, cases, .. } => {
fill_jump_args(default, asdf.phinode_allocs, block_stack);
for (_, arg) in cases {
fill_jump_args(arg, asdf.phinode_allocs, block_stack);
}
}
_ => (),
}
if let Some(succs) = asdf.inv_domtree.get(bid) {
for succ in succs {
traverse_preorder(asdf, succ);
}
}
*asdf.stack = stack_org;
}
let bid_init = code.bid_init;
let mut asdf = Asdf {
code, code,
phinode_indexes: &phinode_indexes, &inpromotables,
phinode_allocs: &phinode_allocs, &domtree,
inv_domtree: &inv_domtree, (&phinode_indexes, &phinode_allocs),
stack: &mut stack, HashMap::new(),
replaces: &mut replaces, &mut replaces,
inpromotable: &inpromotable, );
domtree: &domtree,
};
traverse_preorder(&mut asdf, &bid_init); for (bid, block) in code.blocks.iter_mut() {
for (bid, block) in &mut code.blocks {
loop { loop {
let mut changed = false; let mut changed = false;
for inst in block.instructions.iter_mut() { for inst in block.instructions.iter_mut() {
changed = changed || replace_instruction_operands(inst, &replaces); changed = replace_instruction_operands(inst, &replaces) || changed;
} }
changed = changed || replace_exit_operands(&mut block.exit, &replaces); changed = replace_exit_operands(&mut block.exit, &replaces) || changed;
if !changed { if !changed {
break; break;
} }
} }
}
for block in code.blocks.values_mut() {
for inst in block.instructions.iter_mut() { for inst in block.instructions.iter_mut() {
match inst.deref().deref() { match inst.deref().deref() {
Instruction::Store { ptr, .. } => { Instruction::Store { ptr, .. } => {
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
if !inpromotable.contains(aid) { if !inpromotables.contains(aid) {
*inst.deref_mut() = Instruction::Nop; *inst.deref_mut() = Instruction::Nop;
} }
} }
} }
Instruction::Load { ptr } => { Instruction::Load { ptr } => {
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
if !inpromotable.contains(aid) { if !inpromotables.contains(aid) {
*inst.deref_mut() = Instruction::Nop; *inst.deref_mut() = Instruction::Nop;
} }
} }
@@ -301,28 +136,13 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
} }
} }
// println!("replaces: {:?}\n", replaces); true
true // TODO
} }
} }
fn fill_jump_args( fn mark_as_inpromotable(inpromotables: &mut HashSet<usize>, operand: &Operand) {
arg: &mut JumpArg,
phinode_allocs: &HashMap<BlockId, Vec<(usize, Dtype)>>,
block_stack: &HashMap<usize, Vec<Operand>>,
) {
if let Some(target_phinode_args) = phinode_allocs.get(&arg.bid) {
for (target_phinode_arg, dtype) in target_phinode_args {
arg.args
.push(block_stack[target_phinode_arg].last().unwrap().clone());
}
}
}
fn mark_inpromotable(inpromotable: &mut HashSet<usize>, operand: &Operand) {
if let Some((RegisterId::Local { aid }, _)) = operand.get_register() { if let Some((RegisterId::Local { aid }, _)) = operand.get_register() {
let _ = inpromotable.insert(*aid); let _ = inpromotables.insert(*aid);
} }
} }
@@ -391,8 +211,8 @@ fn traverse_rpo(bid_init: BlockId, cfg: &HashMap<BlockId, Vec<JumpArg>>) -> Vec<
struct Domtree { struct Domtree {
idoms: HashMap<BlockId, BlockId>, idoms: HashMap<BlockId, BlockId>,
inverse_idoms: HashMap<BlockId, HashSet<BlockId>>,
frontiers: HashMap<BlockId, Vec<BlockId>>, frontiers: HashMap<BlockId, Vec<BlockId>>,
rpo: Vec<BlockId>,
} }
impl Domtree { impl Domtree {
@@ -441,6 +261,14 @@ impl Domtree {
} }
} }
let mut inverse_idoms = HashMap::new();
for (bid, idom) in &idoms {
let _ = inverse_idoms
.entry(*idom)
.or_insert_with(HashSet::new)
.insert(*bid);
}
let mut frontiers = HashMap::new(); let mut frontiers = HashMap::new();
for (bid, preds) in reverse_cfg.iter().filter(|(_, preds)| preds.len() > 1) { for (bid, preds) in reverse_cfg.iter().filter(|(_, preds)| preds.len() > 1) {
let idom = if let Some(idom) = idoms.get(bid) { let idom = if let Some(idom) = idoms.get(bid) {
@@ -459,8 +287,8 @@ impl Domtree {
Self { Self {
idoms, idoms,
inverse_idoms,
frontiers, frontiers,
rpo,
} }
} }
@@ -477,10 +305,6 @@ impl Domtree {
} }
} }
fn idom(&self, bid: &BlockId) -> Option<&BlockId> {
self.idoms.get(bid)
}
fn frontiers(&self, bid: &BlockId) -> Option<&Vec<BlockId>> { fn frontiers(&self, bid: &BlockId) -> Option<&Vec<BlockId>> {
self.frontiers.get(bid) self.frontiers.get(bid)
} }
@@ -505,6 +329,139 @@ fn intersect_idom(
} }
} }
type PhinodeInfo<'a> = (
&'a HashMap<(usize, BlockId), usize>,
&'a HashMap<BlockId, Vec<(usize, Dtype)>>,
);
fn traverse_po(
bid: BlockId,
code: &mut FunctionDefinition,
inpromotables: &HashSet<usize>,
domtree: &Domtree,
(phinode_indexes, phinode_allocs): PhinodeInfo<'_>,
mut block_stacks: HashMap<BlockId, HashMap<usize, Vec<Operand>>>,
replaces: &mut HashMap<RegisterId, Operand>,
) {
let block = code.blocks.get_mut(&bid).unwrap();
let block_stack = code
.allocations
.iter()
.enumerate()
.filter(|(aid, _)| !inpromotables.contains(aid))
.map(|(aid, dtype)| {
let initial_value = find_end_value(
aid,
dtype.deref().clone(),
bid,
domtree,
phinode_indexes,
&block_stacks,
);
(aid, vec![initial_value])
})
.collect::<HashMap<_, _>>();
let block_stack = block_stacks.entry(bid).or_insert(block_stack);
for (i, inst) in block.instructions.iter().enumerate() {
match inst.deref() {
Instruction::Store { ptr, value } => {
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
if let Some(value_stack) = block_stack.get_mut(aid) {
value_stack.push(value.clone());
}
}
}
Instruction::Load { ptr } => {
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
if let Some(value_stack) = block_stack.get(aid) {
let _unused = replaces.insert(
RegisterId::temp(bid, i),
value_stack.last().unwrap().clone(),
);
}
}
}
_ => (),
}
}
match &mut block.exit {
BlockExit::Jump { arg } => {
fill_jump_args(arg, phinode_allocs, block_stack);
}
BlockExit::ConditionalJump {
arg_then, arg_else, ..
} => {
fill_jump_args(arg_then, phinode_allocs, block_stack);
fill_jump_args(arg_else, phinode_allocs, block_stack);
}
BlockExit::Switch { default, cases, .. } => {
fill_jump_args(default, phinode_allocs, block_stack);
for (_, arg) in cases {
fill_jump_args(arg, phinode_allocs, block_stack);
}
}
_ => (),
}
if let Some(bids_successors) = domtree.inverse_idoms.get(&bid) {
for bid_successor in bids_successors {
traverse_po(
*bid_successor,
code,
inpromotables,
domtree,
(phinode_indexes, phinode_allocs),
block_stacks.clone(),
replaces,
);
}
}
}
fn find_end_value(
aid: usize,
dtype: Dtype,
bid: BlockId,
domtree: &Domtree,
phinode_indexes: &HashMap<(usize, BlockId), usize>,
block_stacks: &HashMap<BlockId, HashMap<usize, Vec<Operand>>>,
) -> Operand {
if let Some(block_stack) = block_stacks.get(&bid) {
if let Some(value_stack) = block_stack.get(&aid) {
return value_stack.last().unwrap().clone();
}
}
if let Some(phinode_index) = phinode_indexes.get(&(aid, bid)) {
Operand::register(RegisterId::arg(bid, *phinode_index), dtype)
} else if let Some(bid_idom) = domtree.idoms.get(&bid) {
find_end_value(
aid,
dtype,
*bid_idom,
domtree,
phinode_indexes,
block_stacks,
)
} else {
Operand::constant(Constant::undef(dtype))
}
}
fn fill_jump_args(
arg: &mut JumpArg,
phinode_allocs: &HashMap<BlockId, Vec<(usize, Dtype)>>,
block_stack: &HashMap<usize, Vec<Operand>>,
) {
if let Some(phinode_allocs) = phinode_allocs.get(&arg.bid) {
for (aid, dtype) in phinode_allocs {
arg.args.push(block_stack[aid].last().unwrap().clone());
}
}
}
fn replace_instruction_operands( fn replace_instruction_operands(
inst: &mut Instruction, inst: &mut Instruction,
replaces: &HashMap<RegisterId, Operand>, replaces: &HashMap<RegisterId, Operand>,
@@ -525,7 +482,7 @@ fn replace_instruction_operands(
Instruction::Call { callee, args, .. } => { Instruction::Call { callee, args, .. } => {
let mut changed = replace_operand(callee, replaces); let mut changed = replace_operand(callee, replaces);
for arg in args.iter_mut() { for arg in args.iter_mut() {
changed = changed || replace_operand(arg, replaces); changed = replace_operand(arg, replaces) || changed;
} }
changed changed
} }
@@ -543,8 +500,8 @@ fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap<RegisterId, Op
match exit { match exit {
BlockExit::Jump { arg } => { BlockExit::Jump { arg } => {
let mut changed = false; let mut changed = false;
for arg in &mut arg.args { for arg in arg.args.iter_mut() {
changed = changed || replace_operand(arg, replaces); changed = replace_operand(arg, replaces) || changed;
} }
changed changed
} }
@@ -554,11 +511,11 @@ fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap<RegisterId, Op
arg_else, arg_else,
} => { } => {
let mut changed = replace_operand(condition, replaces); let mut changed = replace_operand(condition, replaces);
for arg in &mut arg_then.args { for arg in arg_then.args.iter_mut() {
changed = changed || replace_operand(arg, replaces); changed = replace_operand(arg, replaces) || changed;
} }
for arg in &mut arg_else.args { for arg in arg_else.args.iter_mut() {
changed = changed || replace_operand(arg, replaces); changed = replace_operand(arg, replaces) || changed;
} }
changed changed
} }
@@ -568,12 +525,12 @@ fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap<RegisterId, Op
cases, cases,
} => { } => {
let mut changed = replace_operand(value, replaces); let mut changed = replace_operand(value, replaces);
for arg in &mut default.args { for arg in default.args.iter_mut() {
changed = changed || replace_operand(arg, replaces); changed = replace_operand(arg, replaces) || changed;
} }
for (_, arg) in cases { for (_, arg) in cases {
for arg in &mut arg.args { for arg in arg.args.iter_mut() {
changed = changed || replace_operand(arg, replaces); changed = replace_operand(arg, replaces) || changed;
} }
} }
changed changed