This commit is contained in:
static
2025-06-17 06:39:16 +00:00
parent 5fa024ceb4
commit 4c91663320

View File

@@ -40,6 +40,8 @@ impl Translate<ir::TranslationUnit> for Asmgen {
struct InferenceGraph {
edges: HashSet<(ir::RegisterId, ir::RegisterId)>,
vertices: HashMap<ir::RegisterId, (ir::Dtype, asm::Register)>,
has_call: bool,
is_a0_return_pointer: bool,
}
impl InferenceGraph {
@@ -234,6 +236,8 @@ impl InferenceGraph {
let (mut num_int_args, mut num_float_args) =
get_number_of_register_arguments(&signature.ret, &signature.params, structs);
let mut has_call = false;
let mut is_a0_return_pointer = false;
'outer: for block in code.blocks.values() {
for inst in &block.instructions {
@@ -253,17 +257,20 @@ impl InferenceGraph {
num_int_args = 8;
num_float_args = 8; // call이 있는 경우 a는 사용 x
has_call = true;
break 'outer;
} else if let ir::Instruction::Store { value, .. } = inst.deref() {
if is_struct(&value.dtype(), structs).is_some() {
num_int_args = 8;
num_float_args = 8;
has_call = true; // memcpy
break 'outer;
}
} else if let ir::Instruction::Load { ptr } = inst.deref() {
if is_struct(ptr.dtype().get_pointer_inner().unwrap(), structs).is_some() {
num_int_args = 8;
num_float_args = 8;
has_call = true; // memcpy
break 'outer;
}
}
@@ -274,6 +281,8 @@ impl InferenceGraph {
if size > 16 {
num_int_args = 8;
num_float_args = 8;
has_call = true; // memcpy
is_a0_return_pointer = true;
}
}
@@ -423,7 +432,12 @@ impl InferenceGraph {
}
}
InferenceGraph { edges, vertices }
InferenceGraph {
edges,
vertices,
has_call,
is_a0_return_pointer,
}
}
fn get_register(&self, rid: &ir::RegisterId) -> Option<asm::Register> {
@@ -474,6 +488,7 @@ struct Context {
stack_allocation: u64,
new_blocks: Vec<(asm::Label, Vec<asm::Instruction>)>,
inference_graph: InferenceGraph,
saved_reg_offset: usize,
}
impl Asmgen {
@@ -612,7 +627,17 @@ impl Asmgen {
stack_allocation += ceil_to_multiple_of_16(get_dtype_size(dtype, structs));
}
stack_allocation += 24; // s0, ra, a0
stack_allocation += 8; // s0
let mut saved_reg_offset = 8;
if inference_graph.has_call {
stack_allocation += 8; // ra
saved_reg_offset += 8;
}
if inference_graph.is_a0_return_pointer {
stack_allocation += 8; // a0
saved_reg_offset += 8;
}
let num_int_saved_regs = inference_graph
.vertices
@@ -640,7 +665,10 @@ impl Asmgen {
.len();
stack_allocation += ((num_int_saved_regs + num_float_saved_regs) * 8) as u64;
if stack_allocation % 16 != 0 {
if stack_allocation == 8 {
// Only s0: No Spill
stack_allocation = 0;
} else if stack_allocation % 16 != 0 {
// 스택은 16바이트 정렬
stack_allocation += 16 - (stack_allocation % 16);
}
@@ -650,57 +678,66 @@ impl Asmgen {
}
let mut insts = vec![];
self.translate_addi(
asm::Register::Sp,
asm::Register::Sp,
!stack_allocation + 1,
&mut insts,
);
self.translate_store(
asm::SType::SD,
asm::Register::Sp,
asm::Register::S0,
stack_allocation - 8,
&mut insts,
);
self.translate_addi(
asm::Register::S0,
asm::Register::Sp,
stack_allocation,
&mut insts,
);
insts.push(asm::Instruction::SType {
instr: asm::SType::SD,
rs1: asm::Register::S0,
rs2: asm::Register::Ra,
imm: asm::Immediate::Value(!16 + 1),
});
insts.push(asm::Instruction::SType {
instr: asm::SType::SD,
rs1: asm::Register::S0,
rs2: asm::Register::A0,
imm: asm::Immediate::Value(!24 + 1),
});
// S1~레지스터 및 FS0~ 레지스터 백업
for i in 0..num_int_saved_regs {
insts.push(asm::Instruction::SType {
instr: asm::SType::SD,
rs1: asm::Register::S0,
rs2: asm::Register::saved(
asm::RegisterType::Integer,
i + 1, // S0는 이 용도로 쓰지 않음
),
imm: asm::Immediate::Value((!(32 + i * 8) + 1) as u64),
});
}
for i in 0..num_float_saved_regs {
insts.push(asm::Instruction::SType {
instr: asm::SType::store(ir::Dtype::DOUBLE),
rs1: asm::Register::S0,
rs2: asm::Register::saved(asm::RegisterType::FloatingPoint, i),
imm: asm::Immediate::Value((!(32 + num_int_saved_regs * 8 + i * 8) + 1) as u64),
});
if stack_allocation != 0 {
self.translate_addi(
asm::Register::Sp,
asm::Register::Sp,
!stack_allocation + 1,
&mut insts,
);
self.translate_store(
asm::SType::SD,
asm::Register::Sp,
asm::Register::S0,
stack_allocation - 8,
&mut insts,
);
self.translate_addi(
asm::Register::S0,
asm::Register::Sp,
stack_allocation,
&mut insts,
);
if inference_graph.has_call {
insts.push(asm::Instruction::SType {
instr: asm::SType::SD,
rs1: asm::Register::S0,
rs2: asm::Register::Ra,
imm: asm::Immediate::Value(!16 + 1),
});
}
if inference_graph.is_a0_return_pointer {
insts.push(asm::Instruction::SType {
instr: asm::SType::SD,
rs1: asm::Register::S0,
rs2: asm::Register::A0,
imm: asm::Immediate::Value(!24 + 1),
});
}
// S1~레지스터 및 FS0~ 레지스터 백업
for i in 0..num_int_saved_regs {
insts.push(asm::Instruction::SType {
instr: asm::SType::SD,
rs1: asm::Register::S0,
rs2: asm::Register::saved(
asm::RegisterType::Integer,
i + 1, // S0는 이 용도로 쓰지 않음
),
imm: asm::Immediate::Value((!(saved_reg_offset + (i + 1) * 8) + 1) as u64),
});
}
for i in 0..num_float_saved_regs {
insts.push(asm::Instruction::SType {
instr: asm::SType::store(ir::Dtype::DOUBLE),
rs1: asm::Register::S0,
rs2: asm::Register::saved(asm::RegisterType::FloatingPoint, i),
imm: asm::Immediate::Value(
(!(saved_reg_offset + num_int_saved_regs * 8 + (i + 1) * 8) + 1) as u64,
),
});
}
}
let mut num_int_args = 0;
@@ -939,73 +976,83 @@ impl Asmgen {
stack_allocation,
new_blocks: Vec::new(),
inference_graph,
saved_reg_offset,
}
}
fn translate_epilogue(&mut self, context: &mut Context) {
// S1~레지스터 및 FS0~ 레지스터 복원
let num_int_regs = context
.inference_graph
.vertices
.values()
.filter_map(|(_, reg)| {
if let asm::Register::Saved(asm::RegisterType::Integer, i) = reg {
Some(*i)
} else {
None
}
})
.collect::<HashSet<_>>()
.len();
let num_float_regs = context
.inference_graph
.vertices
.values()
.filter_map(|(_, reg)| {
if let asm::Register::Saved(asm::RegisterType::FloatingPoint, i) = reg {
Some(*i)
} else {
None
}
})
.collect::<HashSet<_>>()
.len();
for i in 0..num_float_regs {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::load(ir::Dtype::DOUBLE),
rd: asm::Register::saved(asm::RegisterType::FloatingPoint, i),
rs1: asm::Register::S0,
imm: asm::Immediate::Value((!(32 + num_int_regs * 8 + i * 8) + 1) as u64),
});
}
for i in 0..num_int_regs {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::LD,
rs1: asm::Register::S0,
rd: asm::Register::saved(asm::RegisterType::Integer, i + 1), // S0는 이 용도로 쓰지 않음
imm: asm::Immediate::Value((!(32 + i * 8) + 1) as u64),
});
if context.stack_allocation != 0 {
// S1~레지스터 및 FS0~ 레지스터 복원
let num_int_regs = context
.inference_graph
.vertices
.values()
.filter_map(|(_, reg)| {
if let asm::Register::Saved(asm::RegisterType::Integer, i) = reg {
Some(*i)
} else {
None
}
})
.collect::<HashSet<_>>()
.len();
let num_float_regs = context
.inference_graph
.vertices
.values()
.filter_map(|(_, reg)| {
if let asm::Register::Saved(asm::RegisterType::FloatingPoint, i) = reg {
Some(*i)
} else {
None
}
})
.collect::<HashSet<_>>()
.len();
for i in 0..num_float_regs {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::load(ir::Dtype::DOUBLE),
rd: asm::Register::saved(asm::RegisterType::FloatingPoint, i),
rs1: asm::Register::S0,
imm: asm::Immediate::Value(
(!(context.saved_reg_offset + num_int_regs * 8 + (i + 1) * 8) + 1) as u64,
),
});
}
for i in 0..num_int_regs {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::LD,
rs1: asm::Register::S0,
rd: asm::Register::saved(asm::RegisterType::Integer, i + 1), // S0는 이 용도로 쓰지 않음
imm: asm::Immediate::Value(
(!(context.saved_reg_offset + (i + 1) * 8) + 1) as u64,
),
});
}
if context.inference_graph.has_call {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::LD,
rd: asm::Register::Ra,
rs1: asm::Register::S0,
imm: asm::Immediate::Value(!16 + 1),
});
}
self.translate_load(
asm::IType::LD,
asm::Register::S0,
asm::Register::Sp,
context.stack_allocation - 8,
&mut context.insts,
);
self.translate_addi(
asm::Register::Sp,
asm::Register::Sp,
context.stack_allocation,
&mut context.insts,
);
}
context.insts.push(asm::Instruction::IType {
instr: asm::IType::LD,
rd: asm::Register::Ra,
rs1: asm::Register::S0,
imm: asm::Immediate::Value(!16 + 1),
});
self.translate_load(
asm::IType::LD,
asm::Register::S0,
asm::Register::Sp,
context.stack_allocation - 8,
&mut context.insts,
);
self.translate_addi(
asm::Register::Sp,
asm::Register::Sp,
context.stack_allocation,
&mut context.insts,
);
context
.insts
.push(asm::Instruction::Pseudo(asm::Pseudo::Ret));
@@ -1036,10 +1083,140 @@ impl Asmgen {
let operand_dtype = upgrade_dtype(&org_operand_dtype);
let rs1 =
self.translate_load_operand(lhs, get_lhs_register(&operand_dtype), context);
let rs2 =
self.translate_load_operand(rhs, get_rhs_register(&operand_dtype), context);
let rd = rd.unwrap_or(get_res_register(dtype));
if let Some(ir::Constant::Int { value, .. }) = rhs.get_constant() {
let mut imm_mode = false;
let data_size = asm::DataSize::try_from(operand_dtype.clone()).unwrap();
match op {
ast::BinaryOperator::Plus
| ast::BinaryOperator::BitwiseAnd
| ast::BinaryOperator::BitwiseOr
| ast::BinaryOperator::BitwiseXor => {
if (-2048..=2047).contains(&(*value as i128)) {
context.insts.push(asm::Instruction::IType {
instr: match op {
ast::BinaryOperator::Plus => {
asm::IType::Addi(data_size)
}
ast::BinaryOperator::BitwiseAnd => asm::IType::Andi,
ast::BinaryOperator::BitwiseOr => asm::IType::Ori,
ast::BinaryOperator::BitwiseXor => asm::IType::Xori,
_ => unreachable!(),
},
rd,
rs1,
imm: asm::Immediate::Value(*value as u64),
});
imm_mode = true;
}
}
ast::BinaryOperator::Minus => {
if (-2047..=2048).contains(&(*value as i128)) {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::Addi(data_size),
rd,
rs1,
imm: asm::Immediate::Value((!value + 1) as u64),
});
imm_mode = true;
}
}
ast::BinaryOperator::ShiftLeft => {
if (-2048..=2047).contains(&(*value as i128)) {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::Slli(data_size),
rd,
rs1,
imm: asm::Immediate::Value(*value as u64),
});
imm_mode = true;
}
}
ast::BinaryOperator::ShiftRight => {
if (-2048..=2047).contains(&(*value as i128)) {
context.insts.push(asm::Instruction::IType {
instr: if operand_dtype.is_int_signed() {
asm::IType::Srai(data_size)
} else {
asm::IType::Srli(data_size)
},
rd,
rs1,
imm: asm::Immediate::Value(*value as u64),
});
imm_mode = true;
}
}
ast::BinaryOperator::Less => {
if (-2048..=2047).contains(&(*value as i128)) {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::Slti {
is_signed: operand_dtype.is_int_signed(),
},
rd,
rs1,
imm: asm::Immediate::Value(*value as u64),
});
imm_mode = true;
}
}
ast::BinaryOperator::GreaterOrEqual => {
if (-2048..=2047).contains(&(*value as i128)) {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::Slti {
is_signed: operand_dtype.is_int_signed(),
},
rd,
rs1,
imm: asm::Immediate::Value(*value as u64),
});
context.insts.push(asm::Instruction::Pseudo(
asm::Pseudo::Seqz { rd, rs: rd },
));
imm_mode = true;
}
}
ast::BinaryOperator::Equals => {
if (-2048..=2047).contains(&(*value as i128)) {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::Xori,
rd,
rs1,
imm: asm::Immediate::Value(*value as u64),
});
context.insts.push(asm::Instruction::Pseudo(
asm::Pseudo::Seqz { rd, rs: rd },
));
imm_mode = true;
}
}
ast::BinaryOperator::NotEquals => {
if (-2048..=2047).contains(&(*value as i128)) {
context.insts.push(asm::Instruction::IType {
instr: asm::IType::Xori,
rd,
rs1,
imm: asm::Immediate::Value(*value as u64),
});
context.insts.push(asm::Instruction::Pseudo(
asm::Pseudo::Snez { rd, rs: rd },
));
imm_mode = true;
}
}
_ => (),
}
if imm_mode {
if is_spilled {
self.translate_store_result(&rid, dtype.clone(), rd, context);
}
continue;
}
}
let rs2 =
self.translate_load_operand(rhs, get_rhs_register(&operand_dtype), context);
match op {
ast::BinaryOperator::Multiply => {
context.insts.push(asm::Instruction::RType {
@@ -1224,30 +1401,6 @@ impl Asmgen {
_ => unreachable!(),
}
// if !matches!(
// op,
// ast::BinaryOperator::Equals
// | ast::BinaryOperator::NotEquals
// | ast::BinaryOperator::Less
// | ast::BinaryOperator::Greater
// | ast::BinaryOperator::LessOrEqual
// | ast::BinaryOperator::GreaterOrEqual
// ) {
// // 하위 바이트만
// if is_integer(&org_operand_dtype)
// && get_dtype_size(&org_operand_dtype, structs) < 4
// {
// context.insts.push(asm::Instruction::IType {
// instr: asm::IType::Andi,
// rd,
// rs1: rd,
// imm: asm::Immediate::Value(
// (1 << (get_dtype_size(&org_operand_dtype, structs) * 8)) - 1,
// ),
// });
// }
// }
if is_spilled {
self.translate_store_result(&rid, dtype.clone(), rd, context);
}