This commit is contained in:
static
2025-05-20 01:43:59 +00:00
parent 7b67eaa2ce
commit 1d9727e480

View File

@@ -46,7 +46,7 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
mark_inpromotable(&mut inpromotable, ptr); mark_inpromotable(&mut inpromotable, ptr);
} }
Instruction::Nop => (), Instruction::Nop => (),
_ => todo!() _ => todo!(),
} }
} }
} }
@@ -80,7 +80,7 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
}) })
.collect::<HashMap<_, _>>(); .collect::<HashMap<_, _>>();
let mut inv_joins = HashMap::new(); let mut inv_joins = HashMap::new();
for (aid, bids) in &joins { for (aid, bids) in &joins {
for bid in bids { for bid in bids {
inv_joins.entry(bid).or_insert_with(Vec::new).push(*aid); inv_joins.entry(bid).or_insert_with(Vec::new).push(*aid);
@@ -104,8 +104,6 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
.push((*aid, code.allocations[*aid].deref().clone())); .push((*aid, code.allocations[*aid].deref().clone()));
} }
} }
// for (aid, bids) in &joins { // for (aid, bids) in &joins {
// let alloc = code.allocations.get(*aid).unwrap(); // let alloc = code.allocations.get(*aid).unwrap();
@@ -127,8 +125,8 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
inv_domtree.entry(*idom).or_insert_with(Vec::new).push(*bid); inv_domtree.entry(*idom).or_insert_with(Vec::new).push(*bid);
} }
println!("{:?}", domtree.idoms); // println!("{:?}", domtree.idoms);
println!("{:?}", inv_domtree); // println!("{:?}", inv_domtree);
let mut stack = HashMap::new(); let mut stack = HashMap::new();
let mut replaces = HashMap::new(); let mut replaces = HashMap::new();
@@ -136,8 +134,6 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
fn find_initial( fn find_initial(
aid: usize, aid: usize,
dtype: Dtype, dtype: Dtype,
code: &FunctionDefinition,
inpromotable: &HashSet<usize>,
bid: &BlockId, bid: &BlockId,
stack: &HashMap<BlockId, HashMap<usize, Vec<Operand>>>, stack: &HashMap<BlockId, HashMap<usize, Vec<Operand>>>,
phinode_indexes: &HashMap<(usize, BlockId), usize>, phinode_indexes: &HashMap<(usize, BlockId), usize>,
@@ -145,77 +141,50 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
) -> Operand { ) -> Operand {
if let Some(block_stack) = stack.get(bid) { if let Some(block_stack) = stack.get(bid) {
if let Some(block_stack) = block_stack.get(&aid) { if let Some(block_stack) = block_stack.get(&aid) {
block_stack.last().unwrap().clone() return block_stack.last().unwrap().clone();
} else {
if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) {
Operand::register(RegisterId::arg(*bid, *phinode_index), dtype)
} else if let Some(bid_idom) = domtree.idoms.get(bid) {
find_initial(
aid,
dtype,
code,
inpromotable,
bid_idom,
stack,
phinode_indexes,
domtree,
)
} else {
Operand::constant(Constant::undef(dtype))
}
} }
} else if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { }
if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) {
Operand::register(RegisterId::arg(*bid, *phinode_index), dtype) Operand::register(RegisterId::arg(*bid, *phinode_index), dtype)
} else if let Some(bid_idom) = domtree.idoms.get(bid) { } else if let Some(bid_idom) = domtree.idoms.get(bid) {
find_initial( find_initial(aid, dtype, bid_idom, stack, phinode_indexes, domtree)
aid,
dtype,
code,
inpromotable,
bid_idom,
stack,
phinode_indexes,
domtree,
)
} else { } else {
Operand::constant(Constant::undef(dtype)) Operand::constant(Constant::undef(dtype))
} }
} }
fn traverse_preorder( struct Asdf<'a> {
code: &mut FunctionDefinition, code: &'a mut FunctionDefinition,
phinode_indexes: &HashMap<(usize, BlockId), usize>, phinode_indexes: &'a HashMap<(usize, BlockId), usize>,
phinode_allocs: &HashMap<BlockId, Vec<(usize, Dtype)>>, phinode_allocs: &'a HashMap<BlockId, Vec<(usize, Dtype)>>,
inv_domtree: &HashMap<BlockId, Vec<BlockId>>, inv_domtree: &'a HashMap<BlockId, Vec<BlockId>>,
bid: &BlockId, stack: &'a mut HashMap<BlockId, HashMap<usize, Vec<Operand>>>,
stack: &mut HashMap<BlockId, HashMap<usize, Vec<Operand>>>, replaces: &'a mut HashMap<RegisterId, Operand>,
replaces: &mut HashMap<RegisterId, Operand>, inpromotable: &'a HashSet<usize>,
inpromotable: &HashSet<usize>, domtree: &'a Domtree,
domtree: &Domtree, };
) {
println!("bid: {}", bid);
let stack_org = stack.clone(); fn traverse_preorder(asdf: &mut Asdf<'_>, bid: &BlockId) {
let stack_org: HashMap<BlockId, HashMap<usize, Vec<Operand>>> = asdf.stack.clone();
// let block_stack = stack.entry(*bid).or_insert_with(HashMap::new); // let block_stack = stack.entry(*bid).or_insert_with(HashMap::new);
for (aid, dtype) in code.allocations.iter().enumerate() { for (aid, dtype) in asdf.code.allocations.iter().enumerate() {
if !inpromotable.contains(&aid) { if !asdf.inpromotable.contains(&aid) {
let initial = find_initial( let initial = find_initial(
aid, aid,
dtype.deref().clone(), dtype.deref().clone(),
code,
inpromotable,
bid, bid,
stack, asdf.stack,
phinode_indexes, asdf.phinode_indexes,
domtree, asdf.domtree,
); );
stack asdf.stack
.entry(*bid) .entry(*bid)
.or_insert_with(HashMap::new) .or_default()
.entry(aid) .entry(aid)
.or_insert_with(Vec::new) .or_default()
.push(initial); .push(initial);
// if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) { // if let Some(phinode_index) = phinode_indexes.get(&(aid, *bid)) {
// entry.push(Operand::register( // entry.push(Operand::register(
@@ -228,21 +197,21 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
} }
} }
// find_initial(code, inpromotable, bid, stack, phinode_indexes, domtree); // find_initial(code, inpromotable, bid, stack, phinode_indexes, domtree);
let block_stack = stack.entry(*bid).or_insert_with(HashMap::new); let block_stack = asdf.stack.entry(*bid).or_default();
for (i, inst) in code.blocks[bid].instructions.iter().enumerate() { for (i, inst) in asdf.code.blocks[bid].instructions.iter().enumerate() {
match inst.deref() { match inst.deref() {
Instruction::Store { ptr, value } => { Instruction::Store { ptr, value } => {
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
if !inpromotable.contains(aid) { if !asdf.inpromotable.contains(aid) {
block_stack.get_mut(aid).unwrap().push(value.clone()); block_stack.get_mut(aid).unwrap().push(value.clone());
} }
} }
} }
Instruction::Load { ptr } => { Instruction::Load { ptr } => {
if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() { if let Some((RegisterId::Local { aid }, _)) = ptr.get_register() {
if !inpromotable.contains(aid) { if !asdf.inpromotable.contains(aid) {
let _unused = replaces.insert( let _unused = asdf.replaces.insert(
RegisterId::temp(*bid, i), RegisterId::temp(*bid, i),
block_stack[aid].last().unwrap().clone(), block_stack[aid].last().unwrap().clone(),
); );
@@ -253,64 +222,61 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
} }
} }
let block = code.blocks.get_mut(bid).unwrap(); let block = asdf.code.blocks.get_mut(bid).unwrap();
match &mut block.exit { match &mut block.exit {
BlockExit::Jump { arg } => { BlockExit::Jump { arg } => {
fill_jump_args(arg, phinode_allocs, &block_stack); fill_jump_args(arg, asdf.phinode_allocs, block_stack);
} }
BlockExit::ConditionalJump { BlockExit::ConditionalJump {
arg_then, arg_else, .. arg_then, arg_else, ..
} => { } => {
fill_jump_args(arg_then, phinode_allocs, &block_stack); fill_jump_args(arg_then, asdf.phinode_allocs, block_stack);
fill_jump_args(arg_else, phinode_allocs, &block_stack); fill_jump_args(arg_else, asdf.phinode_allocs, block_stack);
} }
BlockExit::Switch { default, cases, .. } => { BlockExit::Switch { default, cases, .. } => {
fill_jump_args(default, phinode_allocs, &block_stack); fill_jump_args(default, asdf.phinode_allocs, block_stack);
for (_, arg) in cases { for (_, arg) in cases {
fill_jump_args(arg, phinode_allocs, &block_stack); fill_jump_args(arg, asdf.phinode_allocs, block_stack);
} }
} }
_ => (), _ => (),
} }
if let Some(succs) = inv_domtree.get(bid) { if let Some(succs) = asdf.inv_domtree.get(bid) {
for succ in succs { for succ in succs {
traverse_preorder( traverse_preorder(asdf, succ);
code,
phinode_indexes,
phinode_allocs,
inv_domtree,
succ,
stack,
replaces,
inpromotable,
domtree,
);
} }
} }
*stack = stack_org; *asdf.stack = stack_org;
} }
let bid_init = code.bid_init; let bid_init = code.bid_init;
traverse_preorder( let mut asdf = Asdf {
code, code,
&phinode_indexes, phinode_indexes: &phinode_indexes,
&phinode_allocs, phinode_allocs: &phinode_allocs,
&inv_domtree, inv_domtree: &inv_domtree,
&bid_init, stack: &mut stack,
&mut stack, replaces: &mut replaces,
&mut replaces, inpromotable: &inpromotable,
&inpromotable, domtree: &domtree,
&domtree, };
);
traverse_preorder(&mut asdf, &bid_init);
for (bid, block) in &mut code.blocks { for (bid, block) in &mut code.blocks {
for inst in block.instructions.iter_mut() { loop {
replace_instruction_operands(inst, &replaces); let mut changed = false;
for inst in block.instructions.iter_mut() {
changed = changed || replace_instruction_operands(inst, &replaces);
}
changed = changed || replace_exit_operands(&mut block.exit, &replaces);
if !changed {
break;
}
} }
replace_exit_operands(&mut block.exit, &replaces);
} }
for block in code.blocks.values_mut() { for block in code.blocks.values_mut() {
@@ -335,7 +301,7 @@ impl Optimize<FunctionDefinition> for Mem2regInner {
} }
} }
println!("replaces: {:?}\n", replaces); // println!("replaces: {:?}\n", replaces);
true // TODO true // TODO
} }
@@ -539,46 +505,90 @@ fn intersect_idom(
} }
} }
fn replace_instruction_operands(inst: &mut Instruction, replaces: &HashMap<RegisterId, Operand>) { fn replace_instruction_operands(
inst: &mut Instruction,
replaces: &HashMap<RegisterId, Operand>,
) -> bool {
match inst { match inst {
Instruction::BinOp { lhs, rhs, .. } => { Instruction::BinOp { lhs, rhs, .. } => {
replace_operand(lhs, replaces); let a = replace_operand(lhs, replaces);
replace_operand(rhs, replaces); let b = replace_operand(rhs, replaces);
a || b
} }
Instruction::UnaryOp { operand, .. } => replace_operand(operand, replaces), Instruction::UnaryOp { operand, .. } => replace_operand(operand, replaces),
Instruction::Store { ptr, value } => { Instruction::Store { ptr, value } => {
replace_operand(ptr, replaces); let a = replace_operand(ptr, replaces);
replace_operand(value, replaces); let b = replace_operand(value, replaces);
a || b
} }
Instruction::Load { ptr } => replace_operand(ptr, replaces), Instruction::Load { ptr } => replace_operand(ptr, replaces),
Instruction::Call { callee, args, .. } => { Instruction::Call { callee, args, .. } => {
replace_operand(callee, replaces); let mut changed = replace_operand(callee, replaces);
for arg in args.iter_mut() { for arg in args.iter_mut() {
replace_operand(arg, replaces); changed = changed || replace_operand(arg, replaces);
} }
changed
} }
Instruction::TypeCast { value, .. } => replace_operand(value, replaces), Instruction::TypeCast { value, .. } => replace_operand(value, replaces),
Instruction::GetElementPtr { ptr, offset, .. } => { Instruction::GetElementPtr { ptr, offset, .. } => {
replace_operand(ptr, replaces); let a = replace_operand(ptr, replaces);
replace_operand(offset, replaces); let b = replace_operand(offset, replaces);
a || b
} }
_ => unreachable!(), _ => unreachable!(),
} }
} }
fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap<RegisterId, Operand>) { fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap<RegisterId, Operand>) -> bool {
match exit { match exit {
BlockExit::ConditionalJump { condition, .. } => replace_operand(condition, replaces), BlockExit::Jump { arg } => {
BlockExit::Switch { value, .. } => replace_operand(value, replaces), let mut changed = false;
for arg in &mut arg.args {
changed = changed || replace_operand(arg, replaces);
}
changed
}
BlockExit::ConditionalJump {
condition,
arg_then,
arg_else,
} => {
let mut changed = replace_operand(condition, replaces);
for arg in &mut arg_then.args {
changed = changed || replace_operand(arg, replaces);
}
for arg in &mut arg_else.args {
changed = changed || replace_operand(arg, replaces);
}
changed
}
BlockExit::Switch {
value,
default,
cases,
} => {
let mut changed = replace_operand(value, replaces);
for arg in &mut default.args {
changed = changed || replace_operand(arg, replaces);
}
for (_, arg) in cases {
for arg in &mut arg.args {
changed = changed || replace_operand(arg, replaces);
}
}
changed
}
BlockExit::Return { value } => replace_operand(value, replaces), BlockExit::Return { value } => replace_operand(value, replaces),
_ => (), _ => false,
} }
} }
fn replace_operand(operand: &mut Operand, replaces: &HashMap<RegisterId, Operand>) { fn replace_operand(operand: &mut Operand, replaces: &HashMap<RegisterId, Operand>) -> bool {
if let Operand::Register { rid, .. } = operand { if let Operand::Register { rid, .. } = operand {
if let Some(new_operand) = replaces.get(rid) { if let Some(new_operand) = replaces.get(rid) {
*operand = new_operand.clone(); *operand = new_operand.clone();
return true;
} }
} }
false
} }