Patch for pytest-dev/pytest: show the deciding element of all()/any() failures

Problem: ``assert all([is_even(number) for number in even_stevens])`` only
reports ``where False = all([False, False, False, ...])``, and the generator
form only reports ``all(<generator object <genexpr> at 0x...>)``.  The
equivalent ``for`` loop with an ``assert`` in its body shows which element
failed.  The issue dates from the pytest 4.4 era; this patch targets the
type-annotated ``AssertionRewriter`` in ``src/_pytest/assertion/rewrite.py``
whose ``visit_Call`` returns ``Tuple[ast.Name, str]``.

Approach: ``visit_Call`` recognises ``all(<comp>)`` and ``any(<comp>)`` whose
only argument is a generator expression or list comprehension with exactly one
synchronous ``for`` clause (any number of ``if`` clauses) and no keyword
arguments.  Such a call is rewritten into a loop that stops at the first
element deciding the result and records that element and its value, so a
failure reads roughly:

    assert False
     +  where False = all(False for 5 in [2, 4, 5, 6])

The rewrite keeps these properties (1-4 are covered by the new tests):

1. The user's loop variables stay inside a comprehension scope; the emitted
   loop binds only the rewriter's private ``@py_assert*`` names.  Inlining a
   plain ``for number in even_stevens:`` would instead make ``number`` a local
   of the test function, so an earlier read of a global ``number`` would raise
   ``UnboundLocalError``.
2. ``if`` clauses run in source order, and the element expression is evaluated
   only for elements that pass them.
3. A list comprehension is still evaluated eagerly before the scan, so its
   side effects and exceptions are unchanged.
4. If ``all``/``any`` is shadowed at run time, the shadowing object is called
   exactly as before.
5. Nothing is imported at run time: the emitted code only uses the
   ``@py_builtins`` alias and the display helper the rewriter already emits.

All other calls (several ``for`` clauses, ``async for``, keyword arguments, a
non-comprehension argument) keep the previous ``visit_Call`` logic, which is
moved unchanged into ``_visit_call_default``.

Hunk line numbers in ``rewrite.py`` are approximate; ``patch -p1`` and
``git apply`` locate hunks by their context lines.

--- a/src/_pytest/assertion/rewrite.py
+++ b/src/_pytest/assertion/rewrite.py
@@ -1,4 +1,5 @@
 """Rewrite assertion AST to produce nice error messages."""
 import ast
+import copy
 import errno
 import functools
@@ -868,2 +869,158 @@ class AssertionRewriter(ast.NodeVisitor):
     def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
+        """Visit a call, unrolling ``all()``/``any()`` over a comprehension.
+
+        ``assert all(pred(x) for x in xs)`` otherwise only reports
+        ``where False = all(<generator object ...>)``.  Calls accepted by
+        :meth:`_is_unrollable_all_any` are rewritten by
+        :meth:`_visit_all_any` so the message shows the element that decided
+        the result; every other call takes the default path.
+        """
+        if self._is_unrollable_all_any(call):
+            return self._visit_all_any(call)
+        return self._visit_call_default(call)
+
+    @staticmethod
+    def _is_unrollable_all_any(call: ast.Call) -> bool:
+        """Return True for ``all(<comp>)``/``any(<comp>)`` with one sync ``for``."""
+        if not (isinstance(call.func, ast.Name) and call.func.id in ("all", "any")):
+            return False
+        if len(call.args) != 1 or call.keywords:
+            return False
+        comp = call.args[0]
+        if not isinstance(comp, (ast.GeneratorExp, ast.ListComp)):
+            return False
+        # Nested ``for`` clauses and ``async for`` keep the default handling.
+        return len(comp.generators) == 1 and not comp.generators[0].is_async
+
+    @staticmethod
+    def _as_load(target: ast.expr) -> ast.expr:
+        """Return a copy of a ``for`` target that can be read as an expression."""
+        node = copy.deepcopy(target)
+        with_ctx = (
+            ast.Name, ast.Attribute, ast.Subscript, ast.Starred, ast.List, ast.Tuple
+        )
+        for sub in ast.walk(node):
+            if isinstance(sub, with_ctx):
+                sub.ctx = ast.Load()
+        return node
+
+    def _visit_all_any(self, call: ast.Call) -> Tuple[ast.Name, str]:
+        """Rewrite ``all(<comp>)``/``any(<comp>)`` to record the deciding element.
+
+        Emitted code for ``all`` (``any`` flips the test and the constants)::
+
+            if all is @py_builtins.all:
+                @py_result = True
+                @py_found = False
+                @py_elem = None
+                @py_value = None
+                for @py_elem, @py_value in <pairs>:
+                    if not @py_value:
+                        @py_result = False
+                        @py_found = True
+                        break
+            else:
+                @py_result = all(<elt> for <target> in <iter> if <ifs>)
+                @py_found = False
+
+        ``<pairs>`` is ``((<target>, <elt>) for <target> in <iter> if <ifs>)``,
+        built as a comprehension of the same kind as the original.  Hence the
+        user's loop variables stay in the comprehension scope (they never leak
+        into, or become locals of, the test function), ``if`` clauses run in
+        source order before ``<elt>`` is evaluated, and a list comprehension is
+        still evaluated eagerly.  The ``is`` check keeps the original call when
+        ``all``/``any`` is shadowed, e.g. by a module-level ``def all(...)``.
+        """
+        assert isinstance(call.func, ast.Name)
+        comp = call.args[0]
+        assert isinstance(comp, (ast.GeneratorExp, ast.ListComp))
+        func_name = call.func.id
+        is_all = func_name == "all"
+        gen = comp.generators[0]
+
+        # The outermost iterable of a comprehension is evaluated in the
+        # enclosing scope, so it is visited (and explained) like any operand.
+        iter_res, iter_expl = self.visit(gen.iter)
+
+        result_var = self.variable()
+        found_var = self.variable()
+        elem_var = self.variable()
+        value_var = self.variable()
+
+        def assign(name: str, value: ast.expr) -> ast.stmt:
+            return ast.Assign([ast.Name(name, ast.Store())], value)
+
+        def rebuild(elt: ast.expr) -> ast.expr:
+            # Fresh copies: the comprehension is emitted in both branches.
+            generator = ast.comprehension(
+                target=copy.deepcopy(gen.target),
+                iter=copy.deepcopy(iter_res),
+                ifs=copy.deepcopy(gen.ifs),
+                is_async=0,
+            )
+            return type(comp)(elt=elt, generators=[generator])
+
+        pair = ast.Tuple([self._as_load(gen.target), comp.elt], ast.Load())
+        value_load = ast.Name(value_var, ast.Load())
+        deciding = ast.UnaryOp(ast.Not(), value_load) if is_all else value_load
+        loop = ast.For(
+            target=ast.Tuple(
+                [ast.Name(elem_var, ast.Store()), ast.Name(value_var, ast.Store())],
+                ast.Store(),
+            ),
+            iter=rebuild(pair),
+            body=[
+                ast.If(
+                    deciding,
+                    [
+                        assign(result_var, ast.Constant(not is_all)),
+                        assign(found_var, ast.Constant(True)),
+                        ast.Break(),
+                    ],
+                    [],
+                )
+            ],
+            orelse=[],
+        )
+        unrolled = [
+            assign(result_var, ast.Constant(is_all)),
+            assign(found_var, ast.Constant(False)),
+            assign(elem_var, ast.Constant(None)),
+            assign(value_var, ast.Constant(None)),
+            loop,
+        ]
+        fallback = [
+            assign(
+                result_var,
+                ast.Call(
+                    ast.Name(func_name, ast.Load()),
+                    [rebuild(copy.deepcopy(comp.elt))],
+                    [],
+                ),
+            ),
+            assign(found_var, ast.Constant(False)),
+        ]
+        is_builtin = ast.Compare(
+            ast.Name(func_name, ast.Load()), [ast.Is()], [self.builtin(func_name)]
+        )
+        self.statements.append(ast.If(is_builtin, unrolled, fallback))
+
+        def witness(name: str) -> str:
+            # Only meaningful when the scan stopped on a deciding element.
+            shown = ast.IfExp(
+                ast.Name(found_var, ast.Load()),
+                self.display(ast.Name(name, ast.Load())),
+                ast.Constant("..."),
+            )
+            return self.explanation_param(shown)
+
+        expl = "{}({} for {} in {})".format(
+            func_name, witness(value_var), witness(elem_var), iter_expl
+        )
+        result = ast.Name(result_var, ast.Load())
+        res_expl = self.explanation_param(self.display(result))
+        return result, f"{res_expl}\n{{{res_expl} = {expl}\n}}"
+
+    def _visit_call_default(self, call: ast.Call) -> Tuple[ast.Name, str]:
+        """Rewrite a call, explaining the callee and each argument."""
         new_func, func_expl = self.visit(call.func)
--- /dev/null
+++ b/testing/test_assertrewrite_all_any.py
@@ -0,0 +1,87 @@
+"""Assertion rewriting of ``all()``/``any()`` over a single comprehension."""
+from _pytest.pytester import Pytester
+
+
+def test_all_reports_deciding_element(pytester: Pytester) -> None:
+    pytester.makepyfile(
+        """
+        def is_even(n):
+            return n % 2 == 0
+
+        def test_it():
+            assert all(is_even(n) for n in [2, 4, 5, 6])
+        """
+    )
+    result = pytester.runpytest()
+    result.stdout.fnmatch_lines(["*where False = all(False for 5 in *"])
+
+
+def test_not_any_reports_deciding_element(pytester: Pytester) -> None:
+    pytester.makepyfile(
+        """
+        def test_it():
+            assert not any(n > 3 for n in [1, 4, 2])
+        """
+    )
+    result = pytester.runpytest()
+    result.stdout.fnmatch_lines(["*where True = any(True for 4 in *"])
+
+
+def test_loop_variable_does_not_leak(pytester: Pytester) -> None:
+    pytester.makepyfile(
+        """
+        n = "global"
+
+        def test_it():
+            assert n == "global"
+            assert all(n > 0 for n in [1, 2])
+            assert n == "global"
+        """
+    )
+    pytester.runpytest().assert_outcomes(passed=1)
+
+
+def test_nested_comprehension_sees_loop_variable(pytester: Pytester) -> None:
+    pytester.makepyfile(
+        """
+        def test_it():
+            assert all(all(y < x for y in range(x)) for x in [1, 2, 3])
+        """
+    )
+    pytester.runpytest().assert_outcomes(passed=1)
+
+
+def test_if_clauses_run_in_order_before_element(pytester: Pytester) -> None:
+    pytester.makepyfile(
+        """
+        def test_it():
+            assert not any(1 / x > 1 for x in [0, 2, 4] if x)
+            assert all(x for x in [0, 2, 4] if x if 1 / x)
+        """
+    )
+    pytester.runpytest().assert_outcomes(passed=1)
+
+
+def test_list_comprehension_is_still_eager(pytester: Pytester) -> None:
+    pytester.makepyfile(
+        """
+        def test_it():
+            seen = []
+            assert not all([seen.append(x) or x for x in [0, 1, 2]])
+            assert seen == [0, 1, 2]
+        """
+    )
+    pytester.runpytest().assert_outcomes(passed=1)
+
+
+def test_shadowed_builtin_is_called_unchanged(pytester: Pytester) -> None:
+    pytester.makepyfile(
+        """
+        def all(iterable):
+            return "shadowed"
+
+        def test_it():
+            assert all(x for x in []) == "shadowed"
+        """
+    )
+    pytester.runpytest().assert_outcomes(passed=1)