diff --git a/src/asmgen/mod.rs b/src/asmgen/mod.rs index 8a5e22a..13c74d0 100644 --- a/src/asmgen/mod.rs +++ b/src/asmgen/mod.rs @@ -1,5 +1,5 @@ use core::{f32, num}; -use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; +use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet, VecDeque}; use std::hash::Hash; use std::ops::Deref; @@ -38,7 +38,6 @@ impl Translate for Asmgen { } struct InferenceGraph { - edges: HashSet<(ir::RegisterId, ir::RegisterId)>, vertices: HashMap, analysis: Analysis, lives: HashMap>, @@ -53,6 +52,7 @@ impl InferenceGraph { let cfg = make_cfg(code); let reverse_cfg = reverse_cfg(&cfg); let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg); + let analysis = analyze_function(code, signature, structs); let mut lives = HashMap::new(); let mut first_loop = true; @@ -128,10 +128,6 @@ impl InferenceGraph { let mut uses = HashSet::new(); - // for (aid, _) in block.phinodes.iter().enumerate() { - // let _ = uses.insert(ir::RegisterId::arg(*bid, aid)); - // } - match &block.exit { ir::BlockExit::Jump { arg } => { for arg in &arg.args { @@ -210,58 +206,52 @@ impl InferenceGraph { first_loop = false; } - let mut edges = HashSet::new(); - + let mut adj = HashMap::new(); for lives in lives.values() { for rid1 in lives { for rid2 in lives { if rid1 != rid2 { - let _ = edges.insert((*rid1, *rid2)); + adj.entry(*rid1).or_insert_with(Vec::new).push(*rid2); } } } } - // for (bid, block) in &code.blocks { - // for (i, _) in block.phinodes.iter().enumerate() { - // let rid1 = ir::RegisterId::arg(*bid, i); - // for (j, _) in block.phinodes.iter().enumerate() { - // if i != j { - // let rid2 = ir::RegisterId::arg(*bid, j); - // let _ = edges.insert((rid1, rid2)); - // } - // } - // for (j, _) in block.instructions.iter().enumerate() { - // let rid2 = ir::RegisterId::temp(*bid, j); - // let _ = edges.insert((rid1, rid2)); - // let _ = edges.insert((rid2, rid1)); - // } - // } - // } - - let mut num_of_edges = HashMap::new(); - for (rid1, _) in &edges { - *num_of_edges.entry(*rid1).or_insert(0) += 1; + let mut pq = BinaryHeap::new(); + for vertex in vertices.keys() { + pq.push(SatNode { + rid: *vertex, + sat_deg: 0, + deg: adj.get(vertex).map(|negs| negs.len()).unwrap_or_default(), + }); } - - let analysis = analyze_function(code, signature, structs); + let mut saturation = HashMap::>::new(); if !analysis.has_memcpy_in_prologue { for (aid, dtype) in code.blocks[&code.bid_init].phinodes.iter().enumerate() { let rid = ir::RegisterId::arg(code.bid_init, aid); if analysis.is_temporary2(&rid, &lives, true) { - if is_integer(dtype) { - let (_, asm_reg) = vertices.get_mut(&rid).unwrap(); - *asm_reg = asm::Register::arg( - asm::RegisterType::Integer, - analysis.primitive_arg_reg_index[&aid] as usize, - ); - } else if is_float(dtype) { - let (_, asm_reg) = vertices.get_mut(&rid).unwrap(); - *asm_reg = asm::Register::arg( - asm::RegisterType::FloatingPoint, - analysis.primitive_arg_reg_index[&aid] as usize, - ); + let asm_reg = asm::Register::arg( + if is_integer(dtype) { + asm::RegisterType::Integer + } else if is_float(dtype) { + asm::RegisterType::FloatingPoint + } else { + continue; + }, + analysis.primitive_arg_reg_index[&aid] as usize, + ); + vertices.get_mut(&rid).unwrap().1 = asm_reg; + if let Some(neighs) = adj.get(&rid) { + for &nbr in neighs { + let _ = saturation.entry(nbr).or_default().insert(asm_reg); + let sat_deg = saturation[&nbr].len(); + pq.push(SatNode { + rid: nbr, + sat_deg, + deg: adj[&nbr].len(), + }); + } } } } @@ -273,6 +263,17 @@ impl InferenceGraph { if let Some(reg) = reg { let (_, asm_reg) = vertices.get_mut(&rid).unwrap(); *asm_reg = *reg; + if let Some(neighs) = adj.get(&rid) { + for &nbr in neighs { + let _ = saturation.entry(nbr).or_default().insert(*asm_reg); + let sat_deg = saturation[&nbr].len(); + pq.push(SatNode { + rid: nbr, + sat_deg, + deg: adj[&nbr].len(), + }); + } + } } } } @@ -317,208 +318,58 @@ impl InferenceGraph { let _ = not_spilled.insert(*reg); } } - println!("clique: {:?}", clique); + // println!("clique: {:?}", clique); for reg in clique.difference(¬_spilled) { - println!("spilled! {:?}", reg); + // println!("spilled! {:?}", reg); let _ = spilled.insert(*reg); } } - let mut vertices_order = vertices - .keys() - .map(|rid| (*rid, num_of_edges.get(rid).cloned().unwrap_or_default())) - .sorted_by(|(_, v1), (_, v2)| v2.cmp(v1)); - - for (rid, count) in vertices_order { - if count == 0 || vertices[&rid].1 != asm::Register::Zero || spilled.contains(&rid) { - continue; - } - - let dtype = &vertices[&rid].0; - let neightbors = edges - .iter() - .filter_map(|(r1, r2)| { - if *r1 == rid { - Some(&vertices[r2]) - } else { - None - } - }) - .collect::>(); - let neighbor_registers = neightbors - .iter() - .filter_map(|(_, reg)| { - // TODO: Saved말고 다른 것도 쓰게? - if is_integer(dtype) - && matches!( - reg, - asm::Register::Saved(asm::RegisterType::Integer, _) - | asm::Register::Arg(asm::RegisterType::Integer, _) - | asm::Register::Temp(asm::RegisterType::Integer, _) - ) - { - return Some(*reg); - } - if is_float(dtype) - && matches!( - reg, - asm::Register::Saved(asm::RegisterType::FloatingPoint, _) - | asm::Register::Arg(asm::RegisterType::FloatingPoint, _) - | asm::Register::Temp(asm::RegisterType::FloatingPoint, _) - ) - { - return Some(*reg); - } - None - }) - .collect::>(); - if is_integer(dtype) { - let smallest_temp_reg = smallest_missing_integer( - &neighbor_registers - .iter() - .filter_map(|reg| { - if let asm::Register::Temp(_, i) = reg { - Some(*i) - } else { - None - } - }) - .collect(), - 3, - ); // t0~2는 못 씀 - let smallest_arg_reg = smallest_missing_integer( - &neighbor_registers - .iter() - .filter_map(|reg| { - if let asm::Register::Arg(_, i) = reg { - Some(*i) - } else { - None - } - }) - .collect(), - 0, - ); - let smallest_saved_reg = smallest_missing_integer( - &neighbor_registers - .iter() - .filter_map(|reg| { - if let asm::Register::Saved(_, i) = reg { - Some(*i) - } else { - None - } - }) - .collect(), - 1, - ); // s0는 못 씀 - if smallest_temp_reg <= 6 && analysis.is_temporary2(&rid, &lives, false) { - let _unused = vertices.insert( - rid, - ( - dtype.clone(), - asm::Register::temp(asm::RegisterType::Integer, smallest_temp_reg), - ), - ); - } else if smallest_arg_reg <= 7 && analysis.is_temporary2(&rid, &lives, true) { - let _unused = vertices.insert( - rid, - ( - dtype.clone(), - asm::Register::arg(asm::RegisterType::Integer, smallest_arg_reg), - ), - ); - } else if smallest_saved_reg <= 11 { - let _unused = vertices.insert( - rid, - ( - dtype.clone(), - asm::Register::saved(asm::RegisterType::Integer, smallest_saved_reg), - ), - ); - } else { - // Spilling + // println!("asdf"); + while let Some(SatNode { rid, .. }) = pq.pop() { + if let Some(nbrs) = adj.get(&rid) { + if vertices[&rid].1 != asm::Register::Zero || spilled.contains(&rid) { + continue; } - } else if is_float(dtype) { - let smallest_temp_reg = smallest_missing_integer( - &neighbor_registers - .iter() - .filter_map(|reg| { - if let asm::Register::Temp(_, i) = reg { - Some(*i) - } else { - None - } - }) - .collect(), - 2, - ); // ft0~1은 못 씀 - let smallest_arg_reg = smallest_missing_integer( - &neighbor_registers - .iter() - .filter_map(|reg| { - if let asm::Register::Arg(_, i) = reg { - Some(*i) - } else { - None - } - }) - .collect(), - 0, - ); - let smallest_saved_reg = smallest_missing_integer( - &neighbor_registers - .iter() - .filter_map(|reg| { - if let asm::Register::Saved(_, i) = reg { - Some(*i) - } else { - None - } - }) - .collect(), - 0, - ); - if smallest_temp_reg <= 11 && analysis.is_temporary2(&rid, &lives, false) { - let _unused = vertices.insert( - rid, - ( - dtype.clone(), - asm::Register::temp( - asm::RegisterType::FloatingPoint, - smallest_temp_reg, - ), - ), - ); - } else if smallest_arg_reg <= 7 && analysis.is_temporary2(&rid, &lives, true) { - let _unused = vertices.insert( - rid, - ( - dtype.clone(), - asm::Register::arg(asm::RegisterType::FloatingPoint, smallest_arg_reg), - ), - ); - } else if smallest_saved_reg <= 11 { - let _unused = vertices.insert( - rid, - ( - dtype.clone(), - asm::Register::saved( - asm::RegisterType::FloatingPoint, - smallest_saved_reg, - ), - ), - ); + + let neighbor_colors: HashSet<_> = adj + .get(&rid) + .into_iter() + .flatten() + .filter_map(|nbr| { + let (_, reg) = &vertices[nbr]; + if *reg != asm::Register::Zero { + Some(*reg) + } else { + None + } + }) + .collect(); + let dtype = &vertices[&rid].0; + let chosen = choose_asm_register(dtype, &neighbor_colors, &analysis, &rid, &lives); + + if let Some(reg) = chosen { + let _unused = vertices.insert(rid, (dtype.clone(), reg)); + for &nbr in nbrs { + let _ = saturation.entry(nbr).or_default().insert(reg); + + if vertices[&nbr].1 == asm::Register::Zero && !spilled.contains(&nbr) { + let sat_deg = saturation[&nbr].len(); + pq.push(SatNode { + rid: nbr, + sat_deg, + deg: adj.get(&nbr).map(|negs| negs.len()).unwrap_or_default(), + }); + } + } } else { - // Spilling + // Spill + let _ = spilled.insert(rid); } - } else { - // TODO: Spilling or 레지스터 쪼개기 필요 } } InferenceGraph { - edges, vertices, analysis, lives, @@ -535,6 +386,92 @@ impl InferenceGraph { } } +#[derive(Clone, Copy, Eq, PartialEq)] +struct SatNode { + rid: ir::RegisterId, + sat_deg: usize, + deg: usize, +} + +impl Ord for SatNode { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + (self.sat_deg, self.deg).cmp(&(other.sat_deg, other.deg)) + } +} + +impl PartialOrd for SatNode { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +fn choose_asm_register( + dtype: &ir::Dtype, + neighbor_colors: &HashSet, + analysis: &Analysis, + rid: &ir::RegisterId, + lives: &HashMap>, +) -> Option { + let (reg_type, min_temp_reg, max_temp_reg, min_saved_reg) = if is_integer(dtype) { + (asm::RegisterType::Integer, 3, 6, 1) + } else if is_float(dtype) { + (asm::RegisterType::FloatingPoint, 2, 11, 0) + } else { + return None; + }; + let smallest_temp_reg = smallest_missing_integer( + &neighbor_colors + .iter() + .filter_map(|reg| { + if let asm::Register::Temp(t, i) = reg { + if *t == reg_type { + return Some(*i); + } + } + None + }) + .collect(), + min_temp_reg, + ); + let smallest_arg_reg = smallest_missing_integer( + &neighbor_colors + .iter() + .filter_map(|reg| { + if let asm::Register::Arg(t, i) = reg { + if *t == reg_type { + return Some(*i); + } + } + None + }) + .collect(), + 0, + ); + let smallest_saved_reg = smallest_missing_integer( + &neighbor_colors + .iter() + .filter_map(|reg| { + if let asm::Register::Saved(t, i) = reg { + if *t == reg_type { + return Some(*i); + } + } + None + }) + .collect(), + min_saved_reg, + ); + if smallest_temp_reg <= max_temp_reg && analysis.is_temporary2(rid, lives, false) { + Some(asm::Register::temp(reg_type, smallest_temp_reg)) + } else if smallest_arg_reg <= 7 && analysis.is_temporary2(rid, lives, true) { + Some(asm::Register::arg(reg_type, smallest_arg_reg)) + } else if smallest_saved_reg <= 11 { + Some(asm::Register::saved(reg_type, smallest_saved_reg)) + } else { + None + } +} + fn mark_as_used(operand: &ir::Operand, uses: &mut HashSet) { if let ir::Operand::Register { rid, .. } = operand { if !matches!(rid, ir::RegisterId::Local { .. }) {