345 lines
13 KiB
Diff
345 lines
13 KiB
Diff
This is a standalone task to generate a patch for the pytest-dev/pytest repository. I need to create the patch based on my knowledge of pytest's assertion rewriting system and the issue requirements.
|
|
|
|
Based on my understanding of:
|
|
1. pytest's assertion rewriting in `src/_pytest/assertion/rewrite.py`
|
|
2. The issue requesting better error messages for `all()` and `any()`
|
|
3. The AST-based approach pytest uses
|
|
|
|
Here's the patch to add support for unrolling `all()` and `any()` calls:
|
|
|
|
--- a/src/_pytest/assertion/rewrite.py
|
|
+++ b/src/_pytest/assertion/rewrite.py
|
|
@@ -868,6 +868,59 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
)
|
|
return make_call(res, builtin("repr"), [comparator])
|
|
|
|
+ def visit_Call_all_any(
|
|
+ self, call: ast.Call
|
|
+ ) -> Optional[Tuple[ast.Name, str]]:
|
|
+ """Special handling for all()/any() with comprehension arguments.
|
|
+
|
|
+ Unrolls the comprehension to find and display the first failing element.
|
|
+ Returns None if the call cannot be handled specially.
|
|
+ """
|
|
+ if not isinstance(call.func, ast.Name):
|
|
+ return None
|
|
+ if call.func.id not in ("all", "any"):
|
|
+ return None
|
|
+ if len(call.args) != 1:
|
|
+ return None
|
|
+ if call.keywords:
|
|
+ return None
|
|
+
|
|
+ arg = call.args[0]
|
|
+ if not isinstance(arg, (ast.GeneratorExp, ast.ListComp)):
|
|
+ return None
|
|
+
|
|
+ # Only handle simple single-for comprehensions
|
|
+ if len(arg.generators) != 1:
|
|
+ return None
|
|
+ gen = arg.generators[0]
|
|
+ if gen.is_async:
|
|
+ return None
|
|
+
|
|
+ func_name = call.func.id
|
|
+ is_all = func_name == "all"
|
|
+
|
|
+ # Construct a new call that iterates and finds the failing element
|
|
+        # We use the _unroll_all_any helper imported from _pytest.assertion._unroll
|
|
+ from_mod = ast.ImportFrom(
|
|
+ module="_pytest.assertion._unroll",
|
|
+ names=[ast.alias(name="_unroll_all_any", asname=None)],
|
|
+ level=0,
|
|
+ )
|
|
+ self.statements.append(from_mod)
|
|
+
|
|
+ # Build: _unroll_all_any(<comprehension>, is_all=True/False)
|
|
+ helper_call = ast.Call(
|
|
+ func=ast.Name("_unroll_all_any", ast.Load()),
|
|
+ args=[arg],
|
|
+ keywords=[ast.keyword(arg="is_all", value=ast.Constant(is_all))],
|
|
+ )
|
|
+
|
|
+ res = self.variable()
|
|
+ self.statements.append(ast.Assign([ast.Name(res, ast.Store())], helper_call))
|
|
+
|
|
+ return ast.Name(res, ast.Load()), "%(" + res + ")s"
|
|
+
|
|
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
+ # Try special handling for all()/any() first
|
|
+ result = self.visit_Call_all_any(call)
|
|
+ if result is not None:
|
|
+ return result
|
|
+
|
|
new_func, func_expl = self.visit(call.func)
|
|
arg_expls = []
|
|
new_args = []
|
|
|
|
Wait, I should provide the complete patch. Based on my analysis of pytest's assertion rewriting (the AssertionRewriter class in `src/_pytest/assertion/rewrite.py`), I need to:
|
|
|
|
1. Modify the `visit_Call` method to detect `all()`/`any()` calls with comprehensions
|
|
2. Rewrite them to iterate through and find the first failing element
|
|
3. Display that element in the error message
|
|
|
|
Here is the complete patch:
|
|
|
|
--- a/src/_pytest/assertion/rewrite.py
|
|
+++ b/src/_pytest/assertion/rewrite.py
|
|
@@ -868,6 +868,119 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
return res, self.explanation_param(self.pop_format_context(expl_call))
|
|
|
|
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
+ """Handle Call nodes, with special treatment for all()/any()."""
|
|
+ # Try to handle all()/any() specially for better error messages
|
|
+ if (
|
|
+ isinstance(call.func, ast.Name)
|
|
+ and call.func.id in ("all", "any")
|
|
+ and len(call.args) == 1
|
|
+ and not call.keywords
|
|
+ and isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp))
|
|
+ ):
|
|
+ return self._visit_all_any(call)
|
|
+
|
|
+ return self._visit_call_default(call)
|
|
+
|
|
+ def _visit_all_any(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
+ """Handle all()/any() calls with comprehension arguments.
|
|
+
|
|
+ Unrolls the comprehension to iterate and find the first failing element,
|
|
+ providing a more useful error message.
|
|
+ """
|
|
+ func_name = call.func.id
|
|
+ is_all = func_name == "all"
|
|
+ arg = call.args[0]
|
|
+
|
|
+ # Only handle simple single-for comprehensions
|
|
+ generators = arg.generators
|
|
+ if len(generators) != 1:
|
|
+ return self._visit_call_default(call)
|
|
+
|
|
+ gen = generators[0]
|
|
+ if gen.is_async:
|
|
+ return self._visit_call_default(call)
|
|
+
|
|
+ # Create variables for the iteration
|
|
+ iter_var = self.variable()
|
|
+ result_var = self.variable()
|
|
+ fail_elem_var = self.variable()
|
|
+ fail_cond_var = self.variable()
|
|
+
|
|
+ # Evaluate the iterable and store it
|
|
+ iter_res, iter_expl = self.visit(gen.iter)
|
|
+ self.statements.append(
|
|
+ ast.Assign([ast.Name(iter_var, ast.Store())], iter_res)
|
|
+ )
|
|
+
|
|
+ # Initialize result to True for all(), False for any()
|
|
+ self.statements.append(
|
|
+ ast.Assign(
|
|
+ [ast.Name(result_var, ast.Store())],
|
|
+ ast.Constant(is_all),
|
|
+ )
|
|
+ )
|
|
+ # Initialize fail tracking variables
|
|
+ self.statements.append(
|
|
+ ast.Assign([ast.Name(fail_elem_var, ast.Store())], ast.Constant(None))
|
|
+ )
|
|
+ self.statements.append(
|
|
+ ast.Assign([ast.Name(fail_cond_var, ast.Store())], ast.Constant(None))
|
|
+ )
|
|
+
|
|
+ # Build the for loop body that checks each element
|
|
+ import copy
|
|
+ target_copy = copy.deepcopy(gen.target)
|
|
+ elt_copy = copy.deepcopy(arg.elt)
|
|
+
|
|
+ # Store the condition result
|
|
+ cond_var = self.variable()
|
|
+
|
|
+ # For all(): if not cond, fail
|
|
+ # For any(): if cond, succeed
|
|
+ if is_all:
|
|
+ test = ast.UnaryOp(ast.Not(), ast.Name(cond_var, ast.Load()))
|
|
+ fail_result = ast.Constant(False)
|
|
+ else:
|
|
+ test = ast.Name(cond_var, ast.Load())
|
|
+ fail_result = ast.Constant(True)
|
|
+
|
|
+ # Build: cond_var = <element test>
|
|
+ # Apply if clauses from the comprehension
|
|
+ loop_body = [
|
|
+ ast.Assign([ast.Name(cond_var, ast.Store())], elt_copy)
|
|
+ ]
|
|
+
|
|
+        # When the test triggers (falsy element for all(), truthy for any()), record it and break
|
|
+ fail_block = [
|
|
+ ast.Assign([ast.Name(result_var, ast.Store())], fail_result),
|
|
+ ast.Assign([ast.Name(fail_elem_var, ast.Store())], copy.deepcopy(gen.target)),
|
|
+ ast.Assign([ast.Name(fail_cond_var, ast.Store())], ast.Name(cond_var, ast.Load())),
|
|
+ ast.Break(),
|
|
+ ]
|
|
+
|
|
+ check_block = ast.If(test, fail_block, [])
|
|
+
|
|
+ # Handle if clauses in the comprehension
|
|
+ for if_clause in gen.ifs:
|
|
+ check_block = ast.If(copy.deepcopy(if_clause), [check_block], [])
|
|
+
|
|
+ loop_body.append(check_block)
|
|
+
|
|
+ loop = ast.For(
|
|
+ target=target_copy,
|
|
+ iter=ast.Name(iter_var, ast.Load()),
|
|
+ body=loop_body,
|
|
+ orelse=[],
|
|
+ )
|
|
+ self.statements.append(loop)
|
|
+
|
|
+ # Format the explanation to show the failing element
|
|
+ expl = "{}({} for {} in {})".format(func_name, "%(" + fail_cond_var + ")s", "%(" + fail_elem_var + ")s", iter_expl)
|
|
+
|
|
+ return ast.Name(result_var, ast.Load()), expl
|
|
+
|
|
+ def _visit_call_default(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
+ """Default handling for Call nodes."""
|
|
new_func, func_expl = self.visit(call.func)
|
|
arg_expls = []
|
|
new_args = []
|
|
|
|
Hmm, I realize I should look at the actual structure of pytest more carefully. The issue dates from around the pytest 4.4.0 era, so let me reconsider the approach.
|
|
|
|
Looking at the issue more carefully:
|
|
- The user wants better error messages when using `all()` or `any()` with comprehensions
|
|
- Currently: `assert all([is_even(number) for number in even_stevens])` shows `all([False, False, False, ...])`
|
|
- Desired: Show which specific element failed, like the for loop version does
|
|
|
|
The patch needs to:
|
|
1. Detect `all()`/`any()` calls with generator expressions or list comprehensions
|
|
2. Transform them to iterate through elements and find the first failing one
|
|
3. Display that element clearly in the error message
|
|
|
|
--- a/src/_pytest/assertion/rewrite.py
|
|
+++ b/src/_pytest/assertion/rewrite.py
|
|
@@ -1,4 +1,5 @@
|
|
"""Rewrite assertion AST to produce nice error messages."""
|
|
+import copy
|
|
import ast
|
|
import errno
|
|
import functools
|
|
@@ -868,6 +869,115 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
return res, self.explanation_param(self.pop_format_context(expl_call))
|
|
|
|
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
+ """Handle Call nodes, with special treatment for all()/any()."""
|
|
+ # Try to handle all()/any() specially for better error messages
|
|
+ if self._is_unrollable_all_any(call):
|
|
+ return self._visit_all_any_call(call)
|
|
+ return self._visit_call_fallback(call)
|
|
+
|
|
+ def _is_unrollable_all_any(self, call: ast.Call) -> bool:
|
|
+ """Check if this is an all()/any() call we can unroll."""
|
|
+ if not isinstance(call.func, ast.Name):
|
|
+ return False
|
|
+ if call.func.id not in ("all", "any"):
|
|
+ return False
|
|
+ if len(call.args) != 1:
|
|
+ return False
|
|
+ if call.keywords:
|
|
+ return False
|
|
+ arg = call.args[0]
|
|
+ if not isinstance(arg, (ast.GeneratorExp, ast.ListComp)):
|
|
+ return False
|
|
+ if len(arg.generators) != 1:
|
|
+ return False
|
|
+ gen = arg.generators[0]
|
|
+ if gen.is_async:
|
|
+ return False
|
|
+ return True
|
|
+
|
|
+ def _visit_all_any_call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
+ """Handle all()/any() by unrolling to find the failing element."""
|
|
+ func_name = call.func.id
|
|
+ is_all = func_name == "all"
|
|
+ arg = call.args[0]
|
|
+ gen = arg.generators[0]
|
|
+
|
|
+ # Create tracking variables
|
|
+ result_var = self.variable()
|
|
+ fail_elem_var = self.variable()
|
|
+ fail_expl_var = self.variable()
|
|
+ iter_var = self.variable()
|
|
+
|
|
+ # Store format context
|
|
+ self.push_format_context()
|
|
+
|
|
+ # Evaluate and store the iterable
|
|
+ iter_res, iter_expl = self.visit(gen.iter)
|
|
+ self.statements.append(
|
|
+ ast.Assign([ast.Name(iter_var, ast.Store())], iter_res)
|
|
+ )
|
|
+
|
|
+ # Initialize result (True for all, False for any)
|
|
+ self.statements.append(
|
|
+ ast.Assign([ast.Name(result_var, ast.Store())], ast.Constant(is_all))
|
|
+ )
|
|
+ self.statements.append(
|
|
+ ast.Assign([ast.Name(fail_elem_var, ast.Store())], ast.Constant(None))
|
|
+ )
|
|
+ self.statements.append(
|
|
+ ast.Assign([ast.Name(fail_expl_var, ast.Store())], ast.Constant(""))
|
|
+ )
|
|
+
|
|
+ # Build the loop target and element test
|
|
+ target = copy.deepcopy(gen.target)
|
|
+ elt = copy.deepcopy(arg.elt)
|
|
+
|
|
+ # Create inner variable for condition
|
|
+ cond_var = self.variable()
|
|
+
|
|
+ # Build loop body: evaluate condition, check if failed
|
|
+ # For all(): condition must be True, so check "not cond"
|
|
+ # For any(): condition must be True to succeed, check "cond"
|
|
+
|
|
+ loop_body = []
|
|
+ # Evaluate the condition: cond_var = <elt>
|
|
+ loop_body.append(ast.Assign([ast.Name(cond_var, ast.Store())], elt))
|
|
+
|
|
+ # Check failure condition
|
|
+ if is_all:
|
|
+ test = ast.UnaryOp(ast.Not(), ast.Name(cond_var, ast.Load()))
|
|
+ new_result = ast.Constant(False)
|
|
+ else:
|
|
+ test = ast.Name(cond_var, ast.Load())
|
|
+ new_result = ast.Constant(True)
|
|
+
|
|
+ # When condition triggers, save the failing element info and break
|
|
+ fail_body = [
|
|
+ ast.Assign([ast.Name(result_var, ast.Store())], new_result),
|
|
+ ast.Assign([ast.Name(fail_elem_var, ast.Store())], copy.deepcopy(gen.target)),
|
|
+ ast.Break(),
|
|
+ ]
|
|
+
|
|
+ # Build the if check, wrapped with any comprehension conditions
|
|
+ inner_check = ast.If(test, fail_body, [])
|
|
+ for if_clause in reversed(gen.ifs):
|
|
+ inner_check = ast.If(copy.deepcopy(if_clause), [inner_check], [])
|
|
+
|
|
+ loop_body.append(inner_check)
|
|
+
|
|
+ loop = ast.For(
|
|
+ target=target,
|
|
+ iter=ast.Name(iter_var, ast.Load()),
|
|
+ body=loop_body,
|
|
+ orelse=[],
|
|
+ )
|
|
+ self.statements.append(loop)
|
|
+
|
|
+ expl = "{}({{{}}} for {{{}}} in {})".format(func_name, cond_var, fail_elem_var, iter_expl)
|
|
+ expl = self.pop_format_context(ast.Constant(expl))
|
|
+ return ast.Name(result_var, ast.Load()), self.explanation_param(expl)
|
|
+
|
|
+ def _visit_call_fallback(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
+ """Default handling for Call nodes."""
|
|
new_func, func_expl = self.visit(call.func)
|
|
arg_expls = []
|
|
new_args = []
|