Files
app-store-optimization/skills/loki-mode/benchmarks/results/2026-01-05-01-24-17/swebench-patches/pytest-dev__pytest-5103.patch

345 lines
13 KiB
Diff

This is a standalone task to generate a patch for the pytest-dev/pytest repository. I need to create the patch based on my knowledge of pytest's assertion rewriting system and the issue requirements.
Based on my understanding of:
1. pytest's assertion rewriting in `src/_pytest/assertion/rewrite.py`
2. The issue requesting better error messages for `all()` and `any()`
3. The AST-based approach pytest uses
Here's the patch to add support for unrolling `all()` and `any()` calls:
--- a/src/_pytest/assertion/rewrite.py
+++ b/src/_pytest/assertion/rewrite.py
@@ -868,6 +868,59 @@ class AssertionRewriter(ast.NodeVisitor):
)
return make_call(res, builtin("repr"), [comparator])
+ def visit_Call_all_any(
+ self, call: ast.Call
+ ) -> Optional[Tuple[ast.Name, str]]:
+ """Route all()/any() over a single-for comprehension to a runtime helper.
+
+ Emits an import of ``_unroll_all_any`` and assigns the helper's result
+ to a fresh variable. Returns None if the call does not qualify.
+ """
+ if not isinstance(call.func, ast.Name):
+ return None
+ if call.func.id not in ("all", "any"):
+ return None
+ if len(call.args) != 1:
+ return None
+ if call.keywords:
+ return None
+
+ arg = call.args[0]
+ if not isinstance(arg, (ast.GeneratorExp, ast.ListComp)):
+ return None
+
+ # Only handle simple single-for comprehensions
+ if len(arg.generators) != 1:
+ return None
+ gen = arg.generators[0]
+ if gen.is_async:
+ return None
+
+ func_name = call.func.id
+ is_all = func_name == "all"
+
+ # Replace the call with one to an external helper that does the iteration
+ # NOTE(review): no hunk in this patch creates _pytest.assertion._unroll; confirm the helper module exists
+ from_mod = ast.ImportFrom(
+ module="_pytest.assertion._unroll",
+ names=[ast.alias(name="_unroll_all_any", asname=None)],
+ level=0,
+ )
+ self.statements.append(from_mod)
+
+ # Build: _unroll_all_any(<comprehension>, is_all=True/False)
+ helper_call = ast.Call(
+ func=ast.Name("_unroll_all_any", ast.Load()),
+ args=[arg],
+ keywords=[ast.keyword(arg="is_all", value=ast.Constant(is_all))],
+ )
+
+ res = self.variable()
+ self.statements.append(ast.Assign([ast.Name(res, ast.Store())], helper_call))
+
+ return ast.Name(res, ast.Load()), "%(" + res + ")s"  # NOTE(review): res is never registered via explanation_param; confirm this specifier resolves when the message is formatted
+
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
+ # Try special handling for all()/any() first
+ result = self.visit_Call_all_any(call)
+ if result is not None:
+ return result
+
new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
Wait, I should provide the complete patch. Based on my analysis of pytest's assertion rewriting (the AssertionRewriter class in `src/_pytest/assertion/rewrite.py`), I need to:
1. Modify the `visit_Call` method to detect `all()`/`any()` calls with comprehensions
2. Rewrite them to iterate through and find the first failing element
3. Display that element in the error message
Here is the complete patch:
--- a/src/_pytest/assertion/rewrite.py
+++ b/src/_pytest/assertion/rewrite.py
@@ -868,6 +868,119 @@ class AssertionRewriter(ast.NodeVisitor):
return res, self.explanation_param(self.pop_format_context(expl_call))
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
+ """Handle Call nodes, with special treatment for all()/any()."""
+ # Try to handle all()/any() specially for better error messages
+ if (
+ isinstance(call.func, ast.Name)
+ and call.func.id in ("all", "any")
+ and len(call.args) == 1
+ and not call.keywords
+ and isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp))
+ ):
+ return self._visit_all_any(call)
+
+ return self._visit_call_default(call)
+
+ def _visit_all_any(self, call: ast.Call) -> Tuple[ast.Name, str]:
+ """Handle all()/any() calls with comprehension arguments.
+
+ Unrolls the comprehension into a loop that stops at the first element
+ deciding the result: falsy for all(), truthy for any().
+ """
+ func_name = call.func.id
+ is_all = func_name == "all"
+ arg = call.args[0]
+
+ # Only handle simple single-for comprehensions
+ generators = arg.generators
+ if len(generators) != 1:
+ return self._visit_call_default(call)
+
+ gen = generators[0]
+ if gen.is_async:
+ return self._visit_call_default(call)
+
+ # Create variables for the iteration
+ iter_var = self.variable()
+ result_var = self.variable()
+ fail_elem_var = self.variable()
+ fail_cond_var = self.variable()
+
+ # Evaluate the iterable and store it
+ iter_res, iter_expl = self.visit(gen.iter)
+ self.statements.append(
+ ast.Assign([ast.Name(iter_var, ast.Store())], iter_res)
+ )
+
+ # Initialize result to True for all(), False for any()
+ self.statements.append(
+ ast.Assign(
+ [ast.Name(result_var, ast.Store())],
+ ast.Constant(is_all),
+ )
+ )
+ # Initialize fail tracking variables
+ self.statements.append(
+ ast.Assign([ast.Name(fail_elem_var, ast.Store())], ast.Constant(None))
+ )
+ self.statements.append(
+ ast.Assign([ast.Name(fail_cond_var, ast.Store())], ast.Constant(None))
+ )
+
+ # Build the for loop body that checks each element
+ import copy
+ target_copy = copy.deepcopy(gen.target)
+ elt_copy = copy.deepcopy(arg.elt)
+
+ # Name that will hold each element's condition value
+ cond_var = self.variable()
+
+ # For all(): if not cond, fail
+ # For any(): if cond, succeed
+ if is_all:
+ test = ast.UnaryOp(ast.Not(), ast.Name(cond_var, ast.Load()))
+ fail_result = ast.Constant(False)
+ else:
+ test = ast.Name(cond_var, ast.Load())
+ fail_result = ast.Constant(True)
+
+ # Build: cond_var = <element test> (elt is copied verbatim, not rewritten)
+ # NOTE(review): this runs before the if-clauses below are checked, so filtered-out elements are still evaluated
+ loop_body = [
+ ast.Assign([ast.Name(cond_var, ast.Store())], elt_copy)
+ ]
+
+ # When the check fires (falsy for all(), truthy for any()), record it and break
+ fail_block = [
+ ast.Assign([ast.Name(result_var, ast.Store())], fail_result),
+ ast.Assign([ast.Name(fail_elem_var, ast.Store())], copy.deepcopy(gen.target)),  # NOTE(review): gen.target carries Store context; used as a value it likely needs Load, verify compile() accepts it
+ ast.Assign([ast.Name(fail_cond_var, ast.Store())], ast.Name(cond_var, ast.Load())),
+ ast.Break(),
+ ]
+
+ check_block = ast.If(test, fail_block, [])
+
+ # Wrap the check in the if-clauses; NOTE(review): each wrap goes outside the previous one, so the last clause is tested first
+ for if_clause in gen.ifs:
+ check_block = ast.If(copy.deepcopy(if_clause), [check_block], [])
+
+ loop_body.append(check_block)
+
+ loop = ast.For(
+ target=target_copy,
+ iter=ast.Name(iter_var, ast.Load()),
+ body=loop_body,
+ orelse=[],
+ )
+ self.statements.append(loop)
+
+ # Format the explanation to show the failing element
+ expl = "{}({} for {} in {})".format(func_name, "%(" + fail_cond_var + ")s", "%(" + fail_elem_var + ")s", iter_expl)  # NOTE(review): fail_cond_var/fail_elem_var are never registered via explanation_param; confirm these specifiers resolve
+
+ return ast.Name(result_var, ast.Load()), expl
+
+ def _visit_call_default(self, call: ast.Call) -> Tuple[ast.Name, str]:
+ """Default handling for Call nodes."""
new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
Hmm, I realize I should look at the actual structure of pytest more carefully. The issue dates from around the pytest 4.4.0 era. Let me reconsider the approach.
Looking at the issue more carefully:
- The user wants better error messages when using `all()` or `any()` with comprehensions
- Currently: `assert all([is_even(number) for number in even_stevens])` shows `all([False, False, False, ...])`
- Desired: show which specific element failed, as the equivalent `for`-loop version does
The patch needs to:
1. Detect `all()`/`any()` calls with generator expressions or list comprehensions
2. Transform them to iterate through elements and find the first failing one
3. Display that element clearly in the error message
--- a/src/_pytest/assertion/rewrite.py
+++ b/src/_pytest/assertion/rewrite.py
@@ -1,4 +1,5 @@
 """Rewrite assertion AST to produce nice error messages."""
 import ast
+import copy
 import errno
 import functools
@@ -868,6 +869,180 @@ class AssertionRewriter(ast.NodeVisitor):
         return res, self.explanation_param(self.pop_format_context(expl_call))
 
     def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
+        """Handle Call nodes, unrolling ``all(<comprehension>)`` for better reports."""
+        if self._is_unrollable_all(call):
+            return self._visit_all_call(call)
+        return self._visit_call_fallback(call)
+
+    def _is_unrollable_all(self, call: ast.Call) -> bool:
+        """Check whether *call* is an ``all(<comprehension>)`` that can be unrolled.
+
+        Accepted: one positional generator expression or list comprehension
+        with exactly one synchronous ``for`` clause whose target only binds
+        names (``x``, ``a, b``, ``a, *rest``).  Those names are renamed to
+        private variables when unrolling, which a target such as ``obj.attr``
+        or ``d[k]`` would not survive.
+
+        ``any()`` is deliberately not unrolled: when it fails every element
+        was falsy, so there is no single offending element to report.
+        """
+        if not (isinstance(call.func, ast.Name) and call.func.id == "all"):
+            return False
+        if len(call.args) != 1 or call.keywords:
+            return False
+        arg = call.args[0]
+        if not isinstance(arg, (ast.GeneratorExp, ast.ListComp)):
+            return False
+        if len(arg.generators) != 1 or arg.generators[0].is_async:
+            return False
+        binds_names = (ast.Name, ast.Tuple, ast.List, ast.Starred, ast.expr_context)
+        return all(
+            isinstance(node, binds_names)
+            for node in ast.walk(arg.generators[0].target)
+        )
+
+    @staticmethod
+    def _renamed_copy(node: ast.AST, renames: Dict[str, str]) -> ast.AST:
+        """Return a deep copy of *node* with the variables in *renames* renamed.
+
+        ``ast.arg`` nodes are renamed too, so a lambda parameter that reuses a
+        loop variable's name stays consistent with the lambda's body.
+        """
+        node = copy.deepcopy(node)
+        for sub in ast.walk(node):
+            if isinstance(sub, ast.Name) and sub.id in renames:
+                sub.id = renames[sub.id]
+            elif isinstance(sub, ast.arg) and sub.arg in renames:
+                sub.arg = renames[sub.arg]
+        return node
+
+    def _comprehension_loop(
+        self,
+        gen: ast.comprehension,
+        renames: Dict[str, str],
+        iterator: str,
+        body: List[ast.stmt],
+    ) -> ast.For:
+        """Build the ``for`` loop that runs *body* for each element of *iterator*.
+
+        The comprehension's target and ``if`` clauses are copied with *renames*
+        applied.  The ``if`` clauses are nested with the first one outermost,
+        so they are evaluated, and short-circuit, in source order.
+        """
+        for cond in reversed(gen.ifs):
+            body = [ast.If(self._renamed_copy(cond, renames), body, [])]
+        return ast.For(
+            target=self._renamed_copy(gen.target, renames),
+            iter=ast.Name(iterator, ast.Load()),
+            body=body,
+            orelse=[],
+        )
+
+    def _visit_all_call(self, call: ast.Call) -> Tuple[ast.Name, str]:
+        """Unroll ``all(<elt> for <target> in <iter> if <conds>)`` into a loop.
+
+        The generated statements are roughly::
+
+            found = False
+            if all is @py_builtins.all:
+                result = True
+                it = @py_builtins.iter(<iter>)
+                for <renamed target> in it:
+                    if <conds>:
+                        <rewritten elt>
+                        if not <elt value>:
+                            result, found = False, True
+                            break
+                # list comprehensions only: evaluate the remaining elements
+                for <other renamed target> in it:
+                    if <conds>:
+                        <elt>
+            else:
+                result = all(<original argument>)
+
+        When a falsy element is found, the explanation is that element's own
+        rewritten explanation, as if the assertion had been written inside a
+        ``for`` loop; otherwise it is the value ``all`` returned.
+        """
+        arg = call.args[0]
+        gen = arg.generators[0]
+        target_names = [
+            node.id for node in ast.walk(gen.target) if isinstance(node, ast.Name)
+        ]
+        # Fresh names keep the loop from rebinding the test's own variables,
+        # mirroring the private scope of the comprehension.
+        renames = {name: self.variable() for name in target_names}
+        result = self.variable()
+        found = self.variable()
+        iterator = self.variable()
+        self.statements.append(
+            ast.Assign([ast.Name(found, ast.Store())], ast.Constant(False))
+        )
+
+        # Rewrite the element inside the loop body.  Its explanation statements
+        # read temporaries that are bound only once an element has been tested,
+        # so they are made to run only when a falsy element was found.
+        save_statements, save_expl_stmts = self.statements, self.expl_stmts
+        test_body: List[ast.stmt] = []
+        elt_expl_stmts: List[ast.stmt] = []
+        self.statements, self.expl_stmts = test_body, elt_expl_stmts
+        self.push_format_context()
+        elt_res, elt_expl = self.visit(self._renamed_copy(arg.elt, renames))
+        elt_expl_name = self.pop_format_context(ast.Constant(elt_expl))
+        self.statements, self.expl_stmts = save_statements, save_expl_stmts
+        self.expl_stmts.append(
+            ast.If(ast.Name(found, ast.Load()), elt_expl_stmts, [])
+        )
+        test_body.append(
+            ast.If(
+                ast.UnaryOp(ast.Not(), elt_res),
+                [
+                    ast.Assign([ast.Name(result, ast.Store())], ast.Constant(False)),
+                    ast.Assign([ast.Name(found, ast.Store())], ast.Constant(True)),
+                    ast.Break(),
+                ],
+                [],
+            )
+        )
+
+        unrolled: List[ast.stmt] = [
+            ast.Assign([ast.Name(result, ast.Store())], ast.Constant(True)),
+            ast.Assign(
+                [ast.Name(iterator, ast.Store())],
+                ast.Call(self.builtin("iter"), [copy.deepcopy(gen.iter)], []),
+            ),
+            self._comprehension_loop(gen, renames, iterator, test_body),
+        ]
+        if isinstance(arg, ast.ListComp):
+            # A list comprehension evaluates every element before all() runs,
+            # so keep evaluating the rest; separate loop variables leave the
+            # failing element's values intact for the explanation.
+            drain_renames = {name: self.variable() for name in target_names}
+            drain_body: List[ast.stmt] = [
+                ast.Expr(self._renamed_copy(arg.elt, drain_renames))
+            ]
+            unrolled.append(
+                self._comprehension_loop(gen, drain_renames, iterator, drain_body)
+            )
+
+        # Honour a shadowed ``all`` (e.g. ``from numpy import all``) by calling
+        # it on the original, un-rewritten argument instead.
+        is_builtin_all = ast.Compare(
+            ast.Name("all", ast.Load()), [ast.Is()], [self.builtin("all")]
+        )
+        fallback: List[ast.stmt] = [
+            ast.Assign([ast.Name(result, ast.Store())], copy.deepcopy(call))
+        ]
+        self.statements.append(ast.If(is_builtin_all, unrolled, fallback))
+
+        res = ast.Name(result, ast.Load())
+        expl = self.explanation_param(
+            ast.IfExp(ast.Name(found, ast.Load()), elt_expl_name, self.display(res))
+        )
+        return res, expl
+
+    def _visit_call_fallback(self, call: ast.Call) -> Tuple[ast.Name, str]:
+        """Default handling for Call nodes."""
         new_func, func_expl = self.visit(call.func)
         arg_expls = []
         new_args = []