diff --git a/src/asmgen/mod.rs b/src/asmgen/mod.rs index d784b28..a4416b0 100644 --- a/src/asmgen/mod.rs +++ b/src/asmgen/mod.rs @@ -40,6 +40,8 @@ impl Translate for Asmgen { struct InferenceGraph { edges: HashSet<(ir::RegisterId, ir::RegisterId)>, vertices: HashMap, + has_call: bool, + is_a0_return_pointer: bool, } impl InferenceGraph { @@ -234,6 +236,8 @@ impl InferenceGraph { let (mut num_int_args, mut num_float_args) = get_number_of_register_arguments(&signature.ret, &signature.params, structs); + let mut has_call = false; + let mut is_a0_return_pointer = false; 'outer: for block in code.blocks.values() { for inst in &block.instructions { @@ -253,17 +257,20 @@ impl InferenceGraph { num_int_args = 8; num_float_args = 8; // call이 있는 경우 a는 사용 x + has_call = true; break 'outer; } else if let ir::Instruction::Store { value, .. } = inst.deref() { if is_struct(&value.dtype(), structs).is_some() { num_int_args = 8; num_float_args = 8; + has_call = true; // memcpy break 'outer; } } else if let ir::Instruction::Load { ptr } = inst.deref() { if is_struct(ptr.dtype().get_pointer_inner().unwrap(), structs).is_some() { num_int_args = 8; num_float_args = 8; + has_call = true; // memcpy break 'outer; } } @@ -274,6 +281,8 @@ impl InferenceGraph { if size > 16 { num_int_args = 8; num_float_args = 8; + has_call = true; // memcpy + is_a0_return_pointer = true; } } @@ -423,7 +432,12 @@ impl InferenceGraph { } } - InferenceGraph { edges, vertices } + InferenceGraph { + edges, + vertices, + has_call, + is_a0_return_pointer, + } } fn get_register(&self, rid: &ir::RegisterId) -> Option { @@ -474,6 +488,7 @@ struct Context { stack_allocation: u64, new_blocks: Vec<(asm::Label, Vec)>, inference_graph: InferenceGraph, + saved_reg_offset: usize, } impl Asmgen { @@ -612,7 +627,17 @@ impl Asmgen { stack_allocation += ceil_to_multiple_of_16(get_dtype_size(dtype, structs)); } - stack_allocation += 24; // s0, ra, a0 + stack_allocation += 8; // s0 + let mut saved_reg_offset = 8; + + if inference_graph.has_call { + stack_allocation += 8; // ra + saved_reg_offset += 8; + } + if inference_graph.is_a0_return_pointer { + stack_allocation += 8; // a0 + saved_reg_offset += 8; + } let num_int_saved_regs = inference_graph .vertices @@ -640,7 +665,10 @@ impl Asmgen { .len(); stack_allocation += ((num_int_saved_regs + num_float_saved_regs) * 8) as u64; - if stack_allocation % 16 != 0 { + if stack_allocation == 8 { + // Only s0: No Spill + stack_allocation = 0; + } else if stack_allocation % 16 != 0 { // 스택은 16바이트 정렬 stack_allocation += 16 - (stack_allocation % 16); } @@ -650,57 +678,66 @@ impl Asmgen { } let mut insts = vec![]; - self.translate_addi( - asm::Register::Sp, - asm::Register::Sp, - !stack_allocation + 1, - &mut insts, - ); - self.translate_store( - asm::SType::SD, - asm::Register::Sp, - asm::Register::S0, - stack_allocation - 8, - &mut insts, - ); - self.translate_addi( - asm::Register::S0, - asm::Register::Sp, - stack_allocation, - &mut insts, - ); - insts.push(asm::Instruction::SType { - instr: asm::SType::SD, - rs1: asm::Register::S0, - rs2: asm::Register::Ra, - imm: asm::Immediate::Value(!16 + 1), - }); - insts.push(asm::Instruction::SType { - instr: asm::SType::SD, - rs1: asm::Register::S0, - rs2: asm::Register::A0, - imm: asm::Immediate::Value(!24 + 1), - }); - // S1~레지스터 및 FS0~ 레지스터 백업 - for i in 0..num_int_saved_regs { - insts.push(asm::Instruction::SType { - instr: asm::SType::SD, - rs1: asm::Register::S0, - rs2: asm::Register::saved( - asm::RegisterType::Integer, - i + 1, // S0는 이 용도로 쓰지 않음 - ), - imm: asm::Immediate::Value((!(32 + i * 8) + 1) as u64), - }); - } - for i in 0..num_float_saved_regs { - insts.push(asm::Instruction::SType { - instr: asm::SType::store(ir::Dtype::DOUBLE), - rs1: asm::Register::S0, - rs2: asm::Register::saved(asm::RegisterType::FloatingPoint, i), - imm: asm::Immediate::Value((!(32 + num_int_saved_regs * 8 + i * 8) + 1) as u64), - }); + if stack_allocation != 0 { + self.translate_addi( + asm::Register::Sp, + asm::Register::Sp, + !stack_allocation + 1, + &mut insts, + ); + self.translate_store( + asm::SType::SD, + asm::Register::Sp, + asm::Register::S0, + stack_allocation - 8, + &mut insts, + ); + self.translate_addi( + asm::Register::S0, + asm::Register::Sp, + stack_allocation, + &mut insts, + ); + if inference_graph.has_call { + insts.push(asm::Instruction::SType { + instr: asm::SType::SD, + rs1: asm::Register::S0, + rs2: asm::Register::Ra, + imm: asm::Immediate::Value(!16 + 1), + }); + } + if inference_graph.is_a0_return_pointer { + insts.push(asm::Instruction::SType { + instr: asm::SType::SD, + rs1: asm::Register::S0, + rs2: asm::Register::A0, + imm: asm::Immediate::Value(!24 + 1), + }); + } + + // S1~레지스터 및 FS0~ 레지스터 백업 + for i in 0..num_int_saved_regs { + insts.push(asm::Instruction::SType { + instr: asm::SType::SD, + rs1: asm::Register::S0, + rs2: asm::Register::saved( + asm::RegisterType::Integer, + i + 1, // S0는 이 용도로 쓰지 않음 + ), + imm: asm::Immediate::Value((!(saved_reg_offset + (i + 1) * 8) + 1) as u64), + }); + } + for i in 0..num_float_saved_regs { + insts.push(asm::Instruction::SType { + instr: asm::SType::store(ir::Dtype::DOUBLE), + rs1: asm::Register::S0, + rs2: asm::Register::saved(asm::RegisterType::FloatingPoint, i), + imm: asm::Immediate::Value( + (!(saved_reg_offset + num_int_saved_regs * 8 + (i + 1) * 8) + 1) as u64, + ), + }); + } } let mut num_int_args = 0; @@ -939,73 +976,83 @@ impl Asmgen { stack_allocation, new_blocks: Vec::new(), inference_graph, + saved_reg_offset, } } fn translate_epilogue(&mut self, context: &mut Context) { - // S1~레지스터 및 FS0~ 레지스터 복원 - let num_int_regs = context - .inference_graph - .vertices - .values() - .filter_map(|(_, reg)| { - if let asm::Register::Saved(asm::RegisterType::Integer, i) = reg { - Some(*i) - } else { - None - } - }) - .collect::>() - .len(); - let num_float_regs = context - .inference_graph - .vertices - .values() - .filter_map(|(_, reg)| { - if let asm::Register::Saved(asm::RegisterType::FloatingPoint, i) = reg { - Some(*i) - } else { - None - } - }) - .collect::>() - .len(); - for i in 0..num_float_regs { - context.insts.push(asm::Instruction::IType { - instr: asm::IType::load(ir::Dtype::DOUBLE), - rd: asm::Register::saved(asm::RegisterType::FloatingPoint, i), - rs1: asm::Register::S0, - imm: asm::Immediate::Value((!(32 + num_int_regs * 8 + i * 8) + 1) as u64), - }); - } - for i in 0..num_int_regs { - context.insts.push(asm::Instruction::IType { - instr: asm::IType::LD, - rs1: asm::Register::S0, - rd: asm::Register::saved(asm::RegisterType::Integer, i + 1), // S0는 이 용도로 쓰지 않음 - imm: asm::Immediate::Value((!(32 + i * 8) + 1) as u64), - }); + if context.stack_allocation != 0 { + // S1~레지스터 및 FS0~ 레지스터 복원 + let num_int_regs = context + .inference_graph + .vertices + .values() + .filter_map(|(_, reg)| { + if let asm::Register::Saved(asm::RegisterType::Integer, i) = reg { + Some(*i) + } else { + None + } + }) + .collect::>() + .len(); + let num_float_regs = context + .inference_graph + .vertices + .values() + .filter_map(|(_, reg)| { + if let asm::Register::Saved(asm::RegisterType::FloatingPoint, i) = reg { + Some(*i) + } else { + None + } + }) + .collect::>() + .len(); + for i in 0..num_float_regs { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::load(ir::Dtype::DOUBLE), + rd: asm::Register::saved(asm::RegisterType::FloatingPoint, i), + rs1: asm::Register::S0, + imm: asm::Immediate::Value( + (!(context.saved_reg_offset + num_int_regs * 8 + (i + 1) * 8) + 1) as u64, + ), + }); + } + for i in 0..num_int_regs { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::LD, + rs1: asm::Register::S0, + rd: asm::Register::saved(asm::RegisterType::Integer, i + 1), // S0는 이 용도로 쓰지 않음 + imm: asm::Immediate::Value( + (!(context.saved_reg_offset + (i + 1) * 8) + 1) as u64, + ), + }); + } + + if context.inference_graph.has_call { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::LD, + rd: asm::Register::Ra, + rs1: asm::Register::S0, + imm: asm::Immediate::Value(!16 + 1), + }); + } + self.translate_load( + asm::IType::LD, + asm::Register::S0, + asm::Register::Sp, + context.stack_allocation - 8, + &mut context.insts, + ); + self.translate_addi( + asm::Register::Sp, + asm::Register::Sp, + context.stack_allocation, + &mut context.insts, + ); } - context.insts.push(asm::Instruction::IType { - instr: asm::IType::LD, - rd: asm::Register::Ra, - rs1: asm::Register::S0, - imm: asm::Immediate::Value(!16 + 1), - }); - self.translate_load( - asm::IType::LD, - asm::Register::S0, - asm::Register::Sp, - context.stack_allocation - 8, - &mut context.insts, - ); - self.translate_addi( - asm::Register::Sp, - asm::Register::Sp, - context.stack_allocation, - &mut context.insts, - ); context .insts .push(asm::Instruction::Pseudo(asm::Pseudo::Ret)); @@ -1036,10 +1083,140 @@ impl Asmgen { let operand_dtype = upgrade_dtype(&org_operand_dtype); let rs1 = self.translate_load_operand(lhs, get_lhs_register(&operand_dtype), context); - let rs2 = - self.translate_load_operand(rhs, get_rhs_register(&operand_dtype), context); let rd = rd.unwrap_or(get_res_register(dtype)); + if let Some(ir::Constant::Int { value, .. }) = rhs.get_constant() { + let mut imm_mode = false; + let data_size = asm::DataSize::try_from(operand_dtype.clone()).unwrap(); + match op { + ast::BinaryOperator::Plus + | ast::BinaryOperator::BitwiseAnd + | ast::BinaryOperator::BitwiseOr + | ast::BinaryOperator::BitwiseXor => { + if (-2048..=2047).contains(&(*value as i128)) { + context.insts.push(asm::Instruction::IType { + instr: match op { + ast::BinaryOperator::Plus => { + asm::IType::Addi(data_size) + } + ast::BinaryOperator::BitwiseAnd => asm::IType::Andi, + ast::BinaryOperator::BitwiseOr => asm::IType::Ori, + ast::BinaryOperator::BitwiseXor => asm::IType::Xori, + _ => unreachable!(), + }, + rd, + rs1, + imm: asm::Immediate::Value(*value as u64), + }); + imm_mode = true; + } + } + ast::BinaryOperator::Minus => { + if (-2047..=2048).contains(&(*value as i128)) { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::Addi(data_size), + rd, + rs1, + imm: asm::Immediate::Value((!value + 1) as u64), + }); + imm_mode = true; + } + } + ast::BinaryOperator::ShiftLeft => { + if (-2048..=2047).contains(&(*value as i128)) { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::Slli(data_size), + rd, + rs1, + imm: asm::Immediate::Value(*value as u64), + }); + imm_mode = true; + } + } + ast::BinaryOperator::ShiftRight => { + if (-2048..=2047).contains(&(*value as i128)) { + context.insts.push(asm::Instruction::IType { + instr: if operand_dtype.is_int_signed() { + asm::IType::Srai(data_size) + } else { + asm::IType::Srli(data_size) + }, + rd, + rs1, + imm: asm::Immediate::Value(*value as u64), + }); + imm_mode = true; + } + } + ast::BinaryOperator::Less => { + if (-2048..=2047).contains(&(*value as i128)) { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::Slti { + is_signed: operand_dtype.is_int_signed(), + }, + rd, + rs1, + imm: asm::Immediate::Value(*value as u64), + }); + imm_mode = true; + } + } + ast::BinaryOperator::GreaterOrEqual => { + if (-2048..=2047).contains(&(*value as i128)) { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::Slti { + is_signed: operand_dtype.is_int_signed(), + }, + rd, + rs1, + imm: asm::Immediate::Value(*value as u64), + }); + context.insts.push(asm::Instruction::Pseudo( + asm::Pseudo::Seqz { rd, rs: rd }, + )); + imm_mode = true; + } + } + ast::BinaryOperator::Equals => { + if (-2048..=2047).contains(&(*value as i128)) { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::Xori, + rd, + rs1, + imm: asm::Immediate::Value(*value as u64), + }); + context.insts.push(asm::Instruction::Pseudo( + asm::Pseudo::Seqz { rd, rs: rd }, + )); + imm_mode = true; + } + } + ast::BinaryOperator::NotEquals => { + if (-2048..=2047).contains(&(*value as i128)) { + context.insts.push(asm::Instruction::IType { + instr: asm::IType::Xori, + rd, + rs1, + imm: asm::Immediate::Value(*value as u64), + }); + context.insts.push(asm::Instruction::Pseudo( + asm::Pseudo::Snez { rd, rs: rd }, + )); + imm_mode = true; + } + } + _ => (), + } + if imm_mode { + if is_spilled { + self.translate_store_result(&rid, dtype.clone(), rd, context); + } + continue; + } + } + + let rs2 = + self.translate_load_operand(rhs, get_rhs_register(&operand_dtype), context); match op { ast::BinaryOperator::Multiply => { context.insts.push(asm::Instruction::RType { @@ -1224,30 +1401,6 @@ impl Asmgen { _ => unreachable!(), } - // if !matches!( - // op, - // ast::BinaryOperator::Equals - // | ast::BinaryOperator::NotEquals - // | ast::BinaryOperator::Less - // | ast::BinaryOperator::Greater - // | ast::BinaryOperator::LessOrEqual - // | ast::BinaryOperator::GreaterOrEqual - // ) { - // // 하위 바이트만 - // if is_integer(&org_operand_dtype) - // && get_dtype_size(&org_operand_dtype, structs) < 4 - // { - // context.insts.push(asm::Instruction::IType { - // instr: asm::IType::Andi, - // rd, - // rs1: rd, - // imm: asm::Immediate::Value( - // (1 << (get_dtype_size(&org_operand_dtype, structs) * 8)) - 1, - // ), - // }); - // } - // } - if is_spilled { self.translate_store_result(&rid, dtype.clone(), rd, context); }