diff --git a/crates/codegen/src/compile.rs b/crates/codegen/src/compile.rs index 0131b3008e3..8d4d35c0fc4 100644 --- a/crates/codegen/src/compile.rs +++ b/crates/codegen/src/compile.rs @@ -4030,7 +4030,9 @@ impl Compiler { if let Some(ref guard) = m.guard { // Compile guard and jump to end if false self.compile_expression(guard)?; - emit!(self, Instruction::JumpIfFalseOrPop { target: end }); + emit!(self, Instruction::CopyItem { index: 1_u32 }); + emit!(self, Instruction::PopJumpIfFalse { target: end }); + emit!(self, Instruction::PopTop); } self.compile_statements(&m.body)?; } @@ -4044,92 +4046,97 @@ impl Compiler { Ok(()) } - fn compile_chained_comparison( + /// [CPython `compiler_addcompare`](https://github.com/python/cpython/blob/627894459a84be3488a1789919679c997056a03c/Python/compile.c#L2880-L2924) + fn compile_addcompare(&mut self, op: &CmpOp) { + use bytecode::ComparisonOperator::*; + match op { + CmpOp::Eq => emit!(self, Instruction::CompareOperation { op: Equal }), + CmpOp::NotEq => emit!(self, Instruction::CompareOperation { op: NotEqual }), + CmpOp::Lt => emit!(self, Instruction::CompareOperation { op: Less }), + CmpOp::LtE => emit!(self, Instruction::CompareOperation { op: LessOrEqual }), + CmpOp::Gt => emit!(self, Instruction::CompareOperation { op: Greater }), + CmpOp::GtE => { + emit!(self, Instruction::CompareOperation { op: GreaterOrEqual }) + } + CmpOp::In => emit!(self, Instruction::ContainsOp(Invert::No)), + CmpOp::NotIn => emit!(self, Instruction::ContainsOp(Invert::Yes)), + CmpOp::Is => emit!(self, Instruction::IsOp(Invert::No)), + CmpOp::IsNot => emit!(self, Instruction::IsOp(Invert::Yes)), + } + } + + /// Compile a chained comparison. + /// + /// ```py + /// a == b == c == d + /// ``` + /// + /// Will compile into (pseudo code): + /// + /// ```py + /// result = a == b + /// if result: + /// result = b == c + /// if result: + /// result = c == d + /// ``` + /// + /// # See Also + /// - [CPython `compiler_compare`](https://github.com/python/cpython/blob/627894459a84be3488a1789919679c997056a03c/Python/compile.c#L4678-L4717) + fn compile_compare( &mut self, left: &Expr, ops: &[CmpOp], - exprs: &[Expr], + comparators: &[Expr], ) -> CompileResult<()> { - assert!(!ops.is_empty()); - assert_eq!(exprs.len(), ops.len()); let (last_op, mid_ops) = ops.split_last().unwrap(); - let (last_val, mid_exprs) = exprs.split_last().unwrap(); - - use bytecode::ComparisonOperator::*; - let compile_cmpop = |c: &mut Self, op: &CmpOp| match op { - CmpOp::Eq => emit!(c, Instruction::CompareOperation { op: Equal }), - CmpOp::NotEq => emit!(c, Instruction::CompareOperation { op: NotEqual }), - CmpOp::Lt => emit!(c, Instruction::CompareOperation { op: Less }), - CmpOp::LtE => emit!(c, Instruction::CompareOperation { op: LessOrEqual }), - CmpOp::Gt => emit!(c, Instruction::CompareOperation { op: Greater }), - CmpOp::GtE => { - emit!(c, Instruction::CompareOperation { op: GreaterOrEqual }) - } - CmpOp::In => emit!(c, Instruction::ContainsOp(Invert::No)), - CmpOp::NotIn => emit!(c, Instruction::ContainsOp(Invert::Yes)), - CmpOp::Is => emit!(c, Instruction::IsOp(Invert::No)), - CmpOp::IsNot => emit!(c, Instruction::IsOp(Invert::Yes)), - }; - - // a == b == c == d - // compile into (pseudo code): - // result = a == b - // if result: - // result = b == c - // if result: - // result = c == d + let (last_comparator, mid_comparators) = comparators.split_last().unwrap(); // initialize lhs outside of loop self.compile_expression(left)?; - let end_blocks = if mid_exprs.is_empty() { - None - } else { - let break_block = self.new_block(); - let after_block = self.new_block(); - Some((break_block, after_block)) - }; + if mid_comparators.is_empty() { + self.compile_expression(last_comparator)?; + self.compile_addcompare(last_op); + + return Ok(()); + } + + let cleanup = self.new_block(); // for all comparisons except the last (as the last one doesn't need a conditional jump) - for (op, val) in mid_ops.iter().zip(mid_exprs) { - self.compile_expression(val)?; + for (op, comparator) in mid_ops.iter().zip(mid_comparators) { + self.compile_expression(comparator)?; + // store rhs for the next comparison in chain emit!(self, Instruction::Swap { index: 2 }); - emit!(self, Instruction::CopyItem { index: 2_u32 }); + emit!(self, Instruction::CopyItem { index: 2 }); - compile_cmpop(self, op); + self.compile_addcompare(op); // if comparison result is false, we break with this value; if true, try the next one. - if let Some((break_block, _)) = end_blocks { - emit!( - self, - Instruction::JumpIfFalseOrPop { - target: break_block, - } - ); - } + /* + emit!(self, Instruction::CopyItem { index: 1 }); + // emit!(self, Instruction::ToBool); // TODO: Uncomment this + emit!(self, Instruction::PopJumpIfFalse { target: cleanup }); + emit!(self, Instruction::PopTop); + */ + + emit!(self, Instruction::JumpIfFalseOrPop { target: cleanup }); } - // handle the last comparison - self.compile_expression(last_val)?; - compile_cmpop(self, last_op); + self.compile_expression(last_comparator)?; + self.compile_addcompare(last_op); - if let Some((break_block, after_block)) = end_blocks { - emit!( - self, - Instruction::Jump { - target: after_block, - } - ); - - // early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs. - self.switch_to_block(break_block); - emit!(self, Instruction::Swap { index: 2 }); - emit!(self, Instruction::PopTop); + let end = self.new_block(); + emit!(self, Instruction::Jump { target: end }); - self.switch_to_block(after_block); - } + // early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs. + self.switch_to_block(cleanup); + emit!(self, Instruction::Swap { index: 2 }); + emit!(self, Instruction::PopTop); + self.switch_to_block(end); Ok(()) } @@ -4457,14 +4464,16 @@ impl Compiler { let after_block = self.new_block(); let (last_value, values) = values.split_last().unwrap(); + for value in values { self.compile_expression(value)?; + emit!(self, Instruction::CopyItem { index: 1_u32 }); match op { BoolOp::And => { emit!( self, - Instruction::JumpIfFalseOrPop { + Instruction::PopJumpIfFalse { target: after_block, } ); @@ -4472,12 +4481,14 @@ impl Compiler { BoolOp::Or => { emit!( self, - Instruction::JumpIfTrueOrPop { + Instruction::PopJumpIfTrue { target: after_block, } ); } } + + emit!(self, Instruction::PopTop); } // If all values did not qualify, take the value of the last value: @@ -4554,7 +4565,7 @@ impl Compiler { comparators, .. }) => { - self.compile_chained_comparison(left, ops, comparators)?; + self.compile_compare(left, ops, comparators)?; } // Expr::Constant(ExprConstant { value, .. }) => { // self.emit_load_const(compile_constant(value));