Skip to content
149 changes: 80 additions & 69 deletions crates/codegen/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4030,7 +4030,9 @@ impl Compiler {
if let Some(ref guard) = m.guard {
// Compile guard and jump to end if false
self.compile_expression(guard)?;
emit!(self, Instruction::JumpIfFalseOrPop { target: end });
emit!(self, Instruction::CopyItem { index: 1_u32 });
emit!(self, Instruction::PopJumpIfFalse { target: end });
emit!(self, Instruction::PopTop);
}
self.compile_statements(&m.body)?;
}
Expand All @@ -4044,92 +4046,97 @@ impl Compiler {
Ok(())
}

fn compile_chained_comparison(
/// [CPython `compiler_addcompare`](https://github.com/python/cpython/blob/627894459a84be3488a1789919679c997056a03c/Python/compile.c#L2880-L2924)
fn compile_addcompare(&mut self, op: &CmpOp) {
use bytecode::ComparisonOperator::*;
match op {
CmpOp::Eq => emit!(self, Instruction::CompareOperation { op: Equal }),
CmpOp::NotEq => emit!(self, Instruction::CompareOperation { op: NotEqual }),
CmpOp::Lt => emit!(self, Instruction::CompareOperation { op: Less }),
CmpOp::LtE => emit!(self, Instruction::CompareOperation { op: LessOrEqual }),
CmpOp::Gt => emit!(self, Instruction::CompareOperation { op: Greater }),
CmpOp::GtE => {
emit!(self, Instruction::CompareOperation { op: GreaterOrEqual })
}
CmpOp::In => emit!(self, Instruction::ContainsOp(Invert::No)),
CmpOp::NotIn => emit!(self, Instruction::ContainsOp(Invert::Yes)),
CmpOp::Is => emit!(self, Instruction::IsOp(Invert::No)),
CmpOp::IsNot => emit!(self, Instruction::IsOp(Invert::Yes)),
}
}

/// Compile a chained comparison.
///
/// ```py
/// a == b == c == d
/// ```
///
/// Will compile into (pseudo code):
///
/// ```py
/// result = a == b
/// if result:
/// result = b == c
/// if result:
/// result = c == d
/// ```
///
/// # See Also
/// - [CPython `compiler_compare`](https://github.com/python/cpython/blob/627894459a84be3488a1789919679c997056a03c/Python/compile.c#L4678-L4717)
fn compile_compare(
&mut self,
left: &Expr,
ops: &[CmpOp],
exprs: &[Expr],
comparators: &[Expr],
) -> CompileResult<()> {
assert!(!ops.is_empty());
assert_eq!(exprs.len(), ops.len());
let (last_op, mid_ops) = ops.split_last().unwrap();
let (last_val, mid_exprs) = exprs.split_last().unwrap();

use bytecode::ComparisonOperator::*;
let compile_cmpop = |c: &mut Self, op: &CmpOp| match op {
CmpOp::Eq => emit!(c, Instruction::CompareOperation { op: Equal }),
CmpOp::NotEq => emit!(c, Instruction::CompareOperation { op: NotEqual }),
CmpOp::Lt => emit!(c, Instruction::CompareOperation { op: Less }),
CmpOp::LtE => emit!(c, Instruction::CompareOperation { op: LessOrEqual }),
CmpOp::Gt => emit!(c, Instruction::CompareOperation { op: Greater }),
CmpOp::GtE => {
emit!(c, Instruction::CompareOperation { op: GreaterOrEqual })
}
CmpOp::In => emit!(c, Instruction::ContainsOp(Invert::No)),
CmpOp::NotIn => emit!(c, Instruction::ContainsOp(Invert::Yes)),
CmpOp::Is => emit!(c, Instruction::IsOp(Invert::No)),
CmpOp::IsNot => emit!(c, Instruction::IsOp(Invert::Yes)),
};

// a == b == c == d
// compile into (pseudo code):
// result = a == b
// if result:
// result = b == c
// if result:
// result = c == d
let (last_comparator, mid_comparators) = comparators.split_last().unwrap();

// initialize lhs outside of loop
self.compile_expression(left)?;

let end_blocks = if mid_exprs.is_empty() {
None
} else {
let break_block = self.new_block();
let after_block = self.new_block();
Some((break_block, after_block))
};
if mid_comparators.is_empty() {
self.compile_expression(last_comparator)?;
self.compile_addcompare(last_op);

return Ok(());
}

let cleanup = self.new_block();

// for all comparisons except the last (as the last one doesn't need a conditional jump)
for (op, val) in mid_ops.iter().zip(mid_exprs) {
self.compile_expression(val)?;
for (op, comparator) in mid_ops.iter().zip(mid_comparators) {
self.compile_expression(comparator)?;

// store rhs for the next comparison in chain
emit!(self, Instruction::Swap { index: 2 });
emit!(self, Instruction::CopyItem { index: 2_u32 });
emit!(self, Instruction::CopyItem { index: 2 });

compile_cmpop(self, op);
self.compile_addcompare(op);

// if comparison result is false, we break with this value; if true, try the next one.
if let Some((break_block, _)) = end_blocks {
emit!(
self,
Instruction::JumpIfFalseOrPop {
target: break_block,
}
);
}
/*
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what was the blocker of this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the only reason why both opcodes stayed:/

This change caused regression for

def test_no_wraparound_jump(self):
# See https://bugs.python.org/issue46724
def while_not_chained(a, b, c):
while not (a < b < c):
pass
for instr in dis.Bytecode(while_not_chained):
self.assertNotEqual(instr.opname, "EXTENDED_ARG")

Not because it found an EXTENDED_ARG, but because iterating over dis.Bytecode(...) gave an error

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if that's because our test_compile is an old one, replacing it will be also good.

emit!(self, Instruction::CopyItem { index: 1 });
// emit!(self, Instruction::ToBool); // TODO: Uncomment this
emit!(self, Instruction::PopJumpIfFalse { target: cleanup });
emit!(self, Instruction::PopTop);
*/

emit!(self, Instruction::JumpIfFalseOrPop { target: cleanup });
}

// handle the last comparison
self.compile_expression(last_val)?;
compile_cmpop(self, last_op);
self.compile_expression(last_comparator)?;
self.compile_addcompare(last_op);

if let Some((break_block, after_block)) = end_blocks {
emit!(
self,
Instruction::Jump {
target: after_block,
}
);

// early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs.
self.switch_to_block(break_block);
emit!(self, Instruction::Swap { index: 2 });
emit!(self, Instruction::PopTop);
let end = self.new_block();
emit!(self, Instruction::Jump { target: end });

self.switch_to_block(after_block);
}
// early exit left us with stack: `rhs, comparison_result`. We need to clean up rhs.
self.switch_to_block(cleanup);
emit!(self, Instruction::Swap { index: 2 });
emit!(self, Instruction::PopTop);

self.switch_to_block(end);
Ok(())
}

Expand Down Expand Up @@ -4457,27 +4464,31 @@ impl Compiler {
let after_block = self.new_block();

let (last_value, values) = values.split_last().unwrap();

for value in values {
self.compile_expression(value)?;

emit!(self, Instruction::CopyItem { index: 1_u32 });
match op {
BoolOp::And => {
emit!(
self,
Instruction::JumpIfFalseOrPop {
Instruction::PopJumpIfFalse {
target: after_block,
}
);
}
BoolOp::Or => {
emit!(
self,
Instruction::JumpIfTrueOrPop {
Instruction::PopJumpIfTrue {
target: after_block,
}
);
}
}

emit!(self, Instruction::PopTop);
}

// If all values did not qualify, take the value of the last value:
Expand Down Expand Up @@ -4554,7 +4565,7 @@ impl Compiler {
comparators,
..
}) => {
self.compile_chained_comparison(left, ops, comparators)?;
self.compile_compare(left, ops, comparators)?;
}
// Expr::Constant(ExprConstant { value, .. }) => {
// self.emit_load_const(compile_constant(value));
Expand Down
Loading