This commit is contained in:
static
2025-06-18 05:26:20 +00:00
parent f51053ce61
commit 2b0e990acc

View File

@@ -1,5 +1,5 @@
use core::{f32, num};
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet, VecDeque};
use std::hash::Hash;
use std::ops::Deref;
@@ -38,7 +38,6 @@ impl Translate<ir::TranslationUnit> for Asmgen {
}
struct InferenceGraph {
edges: HashSet<(ir::RegisterId, ir::RegisterId)>,
vertices: HashMap<ir::RegisterId, (ir::Dtype, asm::Register)>,
analysis: Analysis,
lives: HashMap<ir::RegisterId, HashSet<ir::RegisterId>>,
@@ -53,6 +52,7 @@ impl InferenceGraph {
let cfg = make_cfg(code);
let reverse_cfg = reverse_cfg(&cfg);
let domtree = Domtree::new(code.bid_init, &cfg, &reverse_cfg);
let analysis = analyze_function(code, signature, structs);
let mut lives = HashMap::new();
let mut first_loop = true;
@@ -128,10 +128,6 @@ impl InferenceGraph {
let mut uses = HashSet::new();
// for (aid, _) in block.phinodes.iter().enumerate() {
// let _ = uses.insert(ir::RegisterId::arg(*bid, aid));
// }
match &block.exit {
ir::BlockExit::Jump { arg } => {
for arg in &arg.args {
@@ -210,58 +206,52 @@ impl InferenceGraph {
first_loop = false;
}
let mut edges = HashSet::new();
let mut adj = HashMap::new();
for lives in lives.values() {
for rid1 in lives {
for rid2 in lives {
if rid1 != rid2 {
let _ = edges.insert((*rid1, *rid2));
adj.entry(*rid1).or_insert_with(Vec::new).push(*rid2);
}
}
}
}
// for (bid, block) in &code.blocks {
// for (i, _) in block.phinodes.iter().enumerate() {
// let rid1 = ir::RegisterId::arg(*bid, i);
// for (j, _) in block.phinodes.iter().enumerate() {
// if i != j {
// let rid2 = ir::RegisterId::arg(*bid, j);
// let _ = edges.insert((rid1, rid2));
// }
// }
// for (j, _) in block.instructions.iter().enumerate() {
// let rid2 = ir::RegisterId::temp(*bid, j);
// let _ = edges.insert((rid1, rid2));
// let _ = edges.insert((rid2, rid1));
// }
// }
// }
let mut num_of_edges = HashMap::new();
for (rid1, _) in &edges {
*num_of_edges.entry(*rid1).or_insert(0) += 1;
let mut pq = BinaryHeap::new();
for vertex in vertices.keys() {
pq.push(SatNode {
rid: *vertex,
sat_deg: 0,
deg: adj.get(vertex).map(|negs| negs.len()).unwrap_or_default(),
});
}
let analysis = analyze_function(code, signature, structs);
let mut saturation = HashMap::<ir::RegisterId, HashSet<asm::Register>>::new();
if !analysis.has_memcpy_in_prologue {
for (aid, dtype) in code.blocks[&code.bid_init].phinodes.iter().enumerate() {
let rid = ir::RegisterId::arg(code.bid_init, aid);
if analysis.is_temporary2(&rid, &lives, true) {
if is_integer(dtype) {
let (_, asm_reg) = vertices.get_mut(&rid).unwrap();
*asm_reg = asm::Register::arg(
asm::RegisterType::Integer,
analysis.primitive_arg_reg_index[&aid] as usize,
);
} else if is_float(dtype) {
let (_, asm_reg) = vertices.get_mut(&rid).unwrap();
*asm_reg = asm::Register::arg(
asm::RegisterType::FloatingPoint,
analysis.primitive_arg_reg_index[&aid] as usize,
);
let asm_reg = asm::Register::arg(
if is_integer(dtype) {
asm::RegisterType::Integer
} else if is_float(dtype) {
asm::RegisterType::FloatingPoint
} else {
continue;
},
analysis.primitive_arg_reg_index[&aid] as usize,
);
vertices.get_mut(&rid).unwrap().1 = asm_reg;
if let Some(neighs) = adj.get(&rid) {
for &nbr in neighs {
let _ = saturation.entry(nbr).or_default().insert(asm_reg);
let sat_deg = saturation[&nbr].len();
pq.push(SatNode {
rid: nbr,
sat_deg,
deg: adj[&nbr].len(),
});
}
}
}
}
@@ -273,6 +263,17 @@ impl InferenceGraph {
if let Some(reg) = reg {
let (_, asm_reg) = vertices.get_mut(&rid).unwrap();
*asm_reg = *reg;
if let Some(neighs) = adj.get(&rid) {
for &nbr in neighs {
let _ = saturation.entry(nbr).or_default().insert(*asm_reg);
let sat_deg = saturation[&nbr].len();
pq.push(SatNode {
rid: nbr,
sat_deg,
deg: adj[&nbr].len(),
});
}
}
}
}
}
@@ -317,208 +318,58 @@ impl InferenceGraph {
let _ = not_spilled.insert(*reg);
}
}
println!("clique: {:?}", clique);
// println!("clique: {:?}", clique);
for reg in clique.difference(&not_spilled) {
println!("spilled! {:?}", reg);
// println!("spilled! {:?}", reg);
let _ = spilled.insert(*reg);
}
}
let mut vertices_order = vertices
.keys()
.map(|rid| (*rid, num_of_edges.get(rid).cloned().unwrap_or_default()))
.sorted_by(|(_, v1), (_, v2)| v2.cmp(v1));
for (rid, count) in vertices_order {
if count == 0 || vertices[&rid].1 != asm::Register::Zero || spilled.contains(&rid) {
continue;
}
let dtype = &vertices[&rid].0;
let neightbors = edges
.iter()
.filter_map(|(r1, r2)| {
if *r1 == rid {
Some(&vertices[r2])
} else {
None
}
})
.collect::<Vec<_>>();
let neighbor_registers = neightbors
.iter()
.filter_map(|(_, reg)| {
// TODO: Saved말고 다른 것도 쓰게?
if is_integer(dtype)
&& matches!(
reg,
asm::Register::Saved(asm::RegisterType::Integer, _)
| asm::Register::Arg(asm::RegisterType::Integer, _)
| asm::Register::Temp(asm::RegisterType::Integer, _)
)
{
return Some(*reg);
}
if is_float(dtype)
&& matches!(
reg,
asm::Register::Saved(asm::RegisterType::FloatingPoint, _)
| asm::Register::Arg(asm::RegisterType::FloatingPoint, _)
| asm::Register::Temp(asm::RegisterType::FloatingPoint, _)
)
{
return Some(*reg);
}
None
})
.collect::<HashSet<_>>();
if is_integer(dtype) {
let smallest_temp_reg = smallest_missing_integer(
&neighbor_registers
.iter()
.filter_map(|reg| {
if let asm::Register::Temp(_, i) = reg {
Some(*i)
} else {
None
}
})
.collect(),
3,
); // t0~2는 못 씀
let smallest_arg_reg = smallest_missing_integer(
&neighbor_registers
.iter()
.filter_map(|reg| {
if let asm::Register::Arg(_, i) = reg {
Some(*i)
} else {
None
}
})
.collect(),
0,
);
let smallest_saved_reg = smallest_missing_integer(
&neighbor_registers
.iter()
.filter_map(|reg| {
if let asm::Register::Saved(_, i) = reg {
Some(*i)
} else {
None
}
})
.collect(),
1,
); // s0는 못 씀
if smallest_temp_reg <= 6 && analysis.is_temporary2(&rid, &lives, false) {
let _unused = vertices.insert(
rid,
(
dtype.clone(),
asm::Register::temp(asm::RegisterType::Integer, smallest_temp_reg),
),
);
} else if smallest_arg_reg <= 7 && analysis.is_temporary2(&rid, &lives, true) {
let _unused = vertices.insert(
rid,
(
dtype.clone(),
asm::Register::arg(asm::RegisterType::Integer, smallest_arg_reg),
),
);
} else if smallest_saved_reg <= 11 {
let _unused = vertices.insert(
rid,
(
dtype.clone(),
asm::Register::saved(asm::RegisterType::Integer, smallest_saved_reg),
),
);
} else {
// Spilling
// println!("asdf");
while let Some(SatNode { rid, .. }) = pq.pop() {
if let Some(nbrs) = adj.get(&rid) {
if vertices[&rid].1 != asm::Register::Zero || spilled.contains(&rid) {
continue;
}
} else if is_float(dtype) {
let smallest_temp_reg = smallest_missing_integer(
&neighbor_registers
.iter()
.filter_map(|reg| {
if let asm::Register::Temp(_, i) = reg {
Some(*i)
} else {
None
}
})
.collect(),
2,
); // ft0~1은 못 씀
let smallest_arg_reg = smallest_missing_integer(
&neighbor_registers
.iter()
.filter_map(|reg| {
if let asm::Register::Arg(_, i) = reg {
Some(*i)
} else {
None
}
})
.collect(),
0,
);
let smallest_saved_reg = smallest_missing_integer(
&neighbor_registers
.iter()
.filter_map(|reg| {
if let asm::Register::Saved(_, i) = reg {
Some(*i)
} else {
None
}
})
.collect(),
0,
);
if smallest_temp_reg <= 11 && analysis.is_temporary2(&rid, &lives, false) {
let _unused = vertices.insert(
rid,
(
dtype.clone(),
asm::Register::temp(
asm::RegisterType::FloatingPoint,
smallest_temp_reg,
),
),
);
} else if smallest_arg_reg <= 7 && analysis.is_temporary2(&rid, &lives, true) {
let _unused = vertices.insert(
rid,
(
dtype.clone(),
asm::Register::arg(asm::RegisterType::FloatingPoint, smallest_arg_reg),
),
);
} else if smallest_saved_reg <= 11 {
let _unused = vertices.insert(
rid,
(
dtype.clone(),
asm::Register::saved(
asm::RegisterType::FloatingPoint,
smallest_saved_reg,
),
),
);
let neighbor_colors: HashSet<_> = adj
.get(&rid)
.into_iter()
.flatten()
.filter_map(|nbr| {
let (_, reg) = &vertices[nbr];
if *reg != asm::Register::Zero {
Some(*reg)
} else {
None
}
})
.collect();
let dtype = &vertices[&rid].0;
let chosen = choose_asm_register(dtype, &neighbor_colors, &analysis, &rid, &lives);
if let Some(reg) = chosen {
let _unused = vertices.insert(rid, (dtype.clone(), reg));
for &nbr in nbrs {
let _ = saturation.entry(nbr).or_default().insert(reg);
if vertices[&nbr].1 == asm::Register::Zero && !spilled.contains(&nbr) {
let sat_deg = saturation[&nbr].len();
pq.push(SatNode {
rid: nbr,
sat_deg,
deg: adj.get(&nbr).map(|negs| negs.len()).unwrap_or_default(),
});
}
}
} else {
// Spilling
// Spill
let _ = spilled.insert(rid);
}
} else {
// TODO: Spilling or 레지스터 쪼개기 필요
}
}
InferenceGraph {
edges,
vertices,
analysis,
lives,
@@ -535,6 +386,92 @@ impl InferenceGraph {
}
}
#[derive(Clone, Copy, Eq, PartialEq)]
struct SatNode {
rid: ir::RegisterId,
sat_deg: usize,
deg: usize,
}
impl Ord for SatNode {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
(self.sat_deg, self.deg).cmp(&(other.sat_deg, other.deg))
}
}
impl PartialOrd for SatNode {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
fn choose_asm_register(
dtype: &ir::Dtype,
neighbor_colors: &HashSet<asm::Register>,
analysis: &Analysis,
rid: &ir::RegisterId,
lives: &HashMap<ir::RegisterId, HashSet<ir::RegisterId>>,
) -> Option<asm::Register> {
let (reg_type, min_temp_reg, max_temp_reg, min_saved_reg) = if is_integer(dtype) {
(asm::RegisterType::Integer, 3, 6, 1)
} else if is_float(dtype) {
(asm::RegisterType::FloatingPoint, 2, 11, 0)
} else {
return None;
};
let smallest_temp_reg = smallest_missing_integer(
&neighbor_colors
.iter()
.filter_map(|reg| {
if let asm::Register::Temp(t, i) = reg {
if *t == reg_type {
return Some(*i);
}
}
None
})
.collect(),
min_temp_reg,
);
let smallest_arg_reg = smallest_missing_integer(
&neighbor_colors
.iter()
.filter_map(|reg| {
if let asm::Register::Arg(t, i) = reg {
if *t == reg_type {
return Some(*i);
}
}
None
})
.collect(),
0,
);
let smallest_saved_reg = smallest_missing_integer(
&neighbor_colors
.iter()
.filter_map(|reg| {
if let asm::Register::Saved(t, i) = reg {
if *t == reg_type {
return Some(*i);
}
}
None
})
.collect(),
min_saved_reg,
);
if smallest_temp_reg <= max_temp_reg && analysis.is_temporary2(rid, lives, false) {
Some(asm::Register::temp(reg_type, smallest_temp_reg))
} else if smallest_arg_reg <= 7 && analysis.is_temporary2(rid, lives, true) {
Some(asm::Register::arg(reg_type, smallest_arg_reg))
} else if smallest_saved_reg <= 11 {
Some(asm::Register::saved(reg_type, smallest_saved_reg))
} else {
None
}
}
fn mark_as_used(operand: &ir::Operand, uses: &mut HashSet<ir::RegisterId>) {
if let ir::Operand::Register { rid, .. } = operand {
if !matches!(rid, ir::RegisterId::Local { .. }) {