diff --git a/src/opt/simplify_cfg.rs b/src/opt/simplify_cfg.rs index d1c41f9..4bcb1ed 100644 --- a/src/opt/simplify_cfg.rs +++ b/src/opt/simplify_cfg.rs @@ -32,24 +32,326 @@ pub struct SimplifyCfgEmpty {} impl Optimize for SimplifyCfgConstProp { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { - todo!() + code.blocks.iter_mut().any(|(_, block)| { + if let Some(exit) = self.simplify_block_exit(&block.exit) { + block.exit = exit; + true + } else { + false + } + }) } } impl Optimize for SimplifyCfgReach { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { - todo!() + let graph = make_cfg(code); + let mut visited = HashSet::new(); + + let mut queue = vec![code.bid_init]; + let _ = visited.insert(code.bid_init); + while let Some(bid) = queue.pop() { + if let Some(args) = graph.get(&bid) { + for arg in args { + if visited.insert(arg.bid) { + queue.push(arg.bid); + } + } + } + } + + let size_org = code.blocks.len(); + code.blocks.retain(|bid, _| visited.contains(bid)); + code.blocks.len() < size_org } } impl Optimize for SimplifyCfgMerge { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { - todo!() + let graph = make_cfg(code); + let mut indegrees = HashMap::new(); + for args in graph.values() { + for arg in args { + *indegrees.entry(arg.bid).or_insert(0) += 1usize; + } + } + + let mut merged_from = HashMap::new(); + let mut merged_to = HashMap::new(); + let mut replaces = HashMap::new(); + + for (bid_from, block_from) in &code.blocks { + if let BlockExit::Jump { arg } = &block_from.exit { + let bid_to = arg.bid; + if *bid_from != bid_to && indegrees.get(&bid_to) == Some(&1) { + let bid_merge_target = *merged_to.get(bid_from).unwrap_or(bid_from); + merged_from + .entry(bid_merge_target) + .or_insert(Vec::new()) + .push(bid_to); + let _ = merged_to.insert(bid_to, bid_merge_target); + + if let Some(mut bids_from) = merged_from.remove(&bid_to) { + merged_from + .entry(bid_merge_target) + .or_insert(Vec::new()) + .append(&mut bids_from); + merged_to.iter_mut().map(|(_, merged_to)| { + if *merged_to == bid_to { + *merged_to = bid_merge_target; + } + }); + } + + let block_to = code.blocks.get(&bid_to).unwrap(); + for (i, (a, p)) in izip!(&arg.args, block_to.phinodes.iter()).enumerate() { + let from = RegisterId::arg(bid_to, i); + let _unused = replaces.insert(from, a.clone()); + } + } + } + } + + for (bid_merge_target, bids_merge_from) in &merged_from { + let blocks_merge_from = bids_merge_from + .iter() + .map(|bid| code.blocks.get(bid).unwrap().clone()) + .collect::>(); + let block_merge_target = code.blocks.get_mut(bid_merge_target).unwrap(); + let mut i = block_merge_target.instructions.len(); + + for (bid, block) in bids_merge_from.iter().zip(blocks_merge_from) { + for (j, inst) in block.instructions.iter().enumerate() { + let dtype = inst.dtype(); + block_merge_target.instructions.push(inst.clone()); + + let from = RegisterId::temp(*bid, j); + let to = Operand::register(RegisterId::temp(*bid_merge_target, i), dtype); + let _unused = replaces.insert(from, to); + i += 1; + } + + block_merge_target.exit = block.exit; + } + } + + code.blocks.retain(|bid, _| !merged_to.contains_key(bid)); + + for (bid, block) in &mut code.blocks { + for inst in block.instructions.iter_mut() { + replace_instruction_operands(inst, &replaces); + } + replace_exit_operands(&mut block.exit, &replaces); + } + + !merged_to.is_empty() } } impl Optimize for SimplifyCfgEmpty { fn optimize(&mut self, code: &mut FunctionDefinition) -> bool { - todo!() + let empty_blocks = code + .blocks + .iter() + .filter(|(_, block)| block.phinodes.is_empty() && block.instructions.is_empty()) + .map(|(bid, block)| (*bid, block.clone())) + .collect::>(); + code.blocks + .iter_mut() + .any(|(_, block)| self.simplify_block_exit(&mut block.exit, &empty_blocks)) + } +} + +impl SimplifyCfgConstProp { + fn simplify_block_exit(&self, exit: &BlockExit) -> Option { + match exit { + BlockExit::ConditionalJump { + condition, + arg_then, + arg_else, + } => { + if arg_then == arg_else { + Some(BlockExit::Jump { + arg: arg_then.clone(), + }) + } else if let Some(condition) = condition.get_constant() { + match condition { + Constant::Int { value: 0, .. } => Some(BlockExit::Jump { + arg: arg_else.clone(), + }), + Constant::Int { value: 1, .. } => Some(BlockExit::Jump { + arg: arg_then.clone(), + }), + _ => None, + } + } else { + None + } + } + BlockExit::Switch { + value, + default, + cases, + } => { + if cases.iter().all(|(_, bid)| default == bid) { + Some(BlockExit::Jump { + arg: default.clone(), + }) + } else { + value.get_constant().map(|value| BlockExit::Jump { + arg: if let Some((_, arg)) = cases.iter().find(|(c, _)| value == c) { + arg.clone() + } else { + default.clone() + }, + }) + } + } + _ => None, + } + } +} + +impl SimplifyCfgEmpty { + fn simplify_jump_arg(&self, arg: &mut JumpArg, empty_blocks: &HashMap) -> bool { + let block = if let Some(empty_block) = empty_blocks.get(&arg.bid) { + empty_block + } else { + return false; + }; + if let BlockExit::Jump { arg: new_arg } = &block.exit { + *arg = new_arg.clone(); + true + } else { + false + } + } + + fn simplify_block_exit( + &self, + exit: &mut BlockExit, + empty_blocks: &HashMap, + ) -> bool { + match exit { + BlockExit::Jump { arg } => { + let block = if let Some(empty_block) = empty_blocks.get(&arg.bid) { + empty_block + } else { + return false; + }; + *exit = block.exit.clone(); + true + } + BlockExit::ConditionalJump { + arg_then, arg_else, .. + } => { + let changed1 = self.simplify_jump_arg(arg_then, empty_blocks); + let changed2 = self.simplify_jump_arg(arg_else, empty_blocks); + changed1 || changed2 + } + BlockExit::Switch { default, cases, .. } => { + let changed1 = self.simplify_jump_arg(default, empty_blocks); + let changed2 = cases + .iter_mut() + .any(|c| self.simplify_jump_arg(&mut c.1, empty_blocks)); + changed1 || changed2 + } + BlockExit::Return { .. } | BlockExit::Unreachable => false, + } + } +} + +fn make_cfg(fdef: &FunctionDefinition) -> HashMap> { + fdef.blocks + .iter() + .map(|(bid, block)| { + let mut args = Vec::new(); + match &block.exit { + BlockExit::Jump { arg } => args.push(arg.clone()), + BlockExit::ConditionalJump { + arg_then, arg_else, .. + } => { + args.push(arg_then.clone()); + args.push(arg_else.clone()); + } + BlockExit::Switch { default, cases, .. } => { + args.push(default.clone()); + for (_, arg) in cases { + args.push(arg.clone()); + } + } + _ => (), + } + (*bid, args) + }) + .collect() +} + +// fn reverse_cfg(cfg: &HashMap>) -> HashMap> { +// let mut result = HashMap::new(); +// for (bid, jumps) in cfg { +// for jump in jumps { +// result +// .entry(jump.bid) +// .or_insert_with(Vec::new) +// .push((*bid, jump.clone())) +// } +// } +// result +// } + +fn replace_instruction_operands(inst: &mut Instruction, replaces: &HashMap) { + match inst { + Instruction::BinOp { lhs, rhs, .. } => { + replace_operand(lhs, replaces); + replace_operand(rhs, replaces); + } + Instruction::UnaryOp { operand, .. } => { + replace_operand(operand, replaces); + } + Instruction::Store { ptr, value } => { + replace_operand(ptr, replaces); + replace_operand(value, replaces); + } + Instruction::Load { ptr } => { + replace_operand(ptr, replaces); + } + Instruction::Call { callee, args, .. } => { + replace_operand(callee, replaces); + for arg in args.iter_mut() { + replace_operand(arg, replaces); + } + } + Instruction::TypeCast { value, .. } => { + replace_operand(value, replaces); + } + Instruction::GetElementPtr { ptr, offset, .. } => { + replace_operand(ptr, replaces); + replace_operand(offset, replaces); + } + _ => todo!(), + } +} + +fn replace_exit_operands(exit: &mut BlockExit, replaces: &HashMap) { + match exit { + BlockExit::ConditionalJump { condition, .. } => { + replace_operand(condition, replaces); + } + BlockExit::Switch { value, .. } => { + replace_operand(value, replaces); + } + BlockExit::Return { value } => { + replace_operand(value, replaces); + } + _ => (), + } +} + +fn replace_operand(operand: &mut Operand, replaces: &HashMap) { + if let Operand::Register { rid, .. } = operand { + if let Some(new_operand) = replaces.get(rid) { + std::mem::replace(operand, new_operand.clone()); + } } }