350 lines
13 KiB
Diff
350 lines
13 KiB
Diff
# pytest-dev__pytest-5103
|
|
# Loki Mode Multi-Agent Patch
|
|
# Attempts: 1
|
|
|
|
The pytest repository was not available at the expected location, so this patch was written from general knowledge of pytest's assertion-rewriting module (`src/_pytest/assertion/rewrite.py`) rather than against the actual source; hunk offsets and context lines are unverified. First draft, targeting the standard pytest source layout:
|
|
|
|
```
|
|
--- a/src/_pytest/assertion/rewrite.py
|
|
+++ b/src/_pytest/assertion/rewrite.py
|
|
@@ -868,6 +868,117 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
return self.visit(new_call)
|
|
return None
|
|
|
|
+    def _is_builtin_all_any(self, call_node: ast.Call) -> bool:
+        """Tell whether *call_node* calls a bare name ``all`` or ``any``.
+
+        Only the syntax is checked; a user-defined ``all``/``any`` that
+        shadows the builtin is not detected.
+        """
+        func = call_node.func
+        return isinstance(func, ast.Name) and func.id in ("all", "any")
|
|
+
|
|
+    def _get_generator_from_call(self, call_node: ast.Call):
+        """Return the call's only positional argument if it is a generator
+        expression, otherwise ``None``.  Keyword arguments are not inspected.
+        """
+        args = call_node.args
+        if len(args) == 1 and isinstance(args[0], ast.GeneratorExp):
+            return args[0]
+        return None
|
|
+
|
|
+    def _is_simple_generator(self, genexp: ast.GeneratorExp) -> bool:
+        """Tell whether *genexp* has exactly one ``for`` clause, no ``if``
+        filters, and an iterable that is a name, attribute, call or subscript.
+        """
+        if len(genexp.generators) != 1:
+            return False
+        (clause,) = genexp.generators
+        simple_iter_types = (ast.Name, ast.Attribute, ast.Call, ast.Subscript)
+        return not clause.ifs and isinstance(clause.iter, simple_iter_types)
|
|
+
|
|
+    def _rewrite_all_any(
+        self, call_node: ast.Call
+    ) -> "Optional[Tuple[ast.expr, str, str]]":
+        """Build a search for the element that decides an all()/any() call.
+
+        For ``all(pred(x) for x in it)`` this appends statements equivalent
+        to::
+
+            @py_sentinel = object()
+            @py_found = next((x for x in it if not pred(x)), @py_sentinel)
+
+        and returns the expression ``@py_found is @py_sentinel``.  For
+        ``any()`` the ``not`` is dropped and the comparison is ``is not``.
+
+        Returns ``None`` unless the call has a single generator-expression
+        argument accepted by ``_is_simple_generator``; otherwise a tuple
+        ``(result_expr, found_var_name, iter_explanation)``.
+        """
+        import copy
+
+        func_name = call_node.func.id  # "all" or "any"
+        genexp = self._get_generator_from_call(call_node)
+        if genexp is None or not self._is_simple_generator(genexp):
+            return None
+
+        comp = genexp.generators[0]
+        fail_var = self.variable()
+
+        # The loop target is a Store-context node; compile() rejects it where
+        # a value is read, so the search yields a copy re-tagged as Load.
+        # This also handles tuple targets (``for a, b in ...``), for which a
+        # placeholder name would be unbound.
+        found_elt = copy.deepcopy(comp.target)
+        for node in ast.walk(found_elt):
+            if hasattr(node, "ctx"):
+                node.ctx = ast.Load()
+
+        # Evaluate the iterable once, in the enclosing scope.
+        iter_res, iter_expl = self.visit(comp.iter)
+
+        # all(): find the first falsy predicate; any(): the first truthy one.
+        if func_name == "all":
+            search_test = ast.UnaryOp(op=ast.Not(), operand=genexp.elt)
+        else:
+            search_test = genexp.elt
+
+        search = ast.GeneratorExp(
+            elt=found_elt,
+            generators=[
+                ast.comprehension(
+                    target=comp.target,
+                    iter=iter_res,
+                    ifs=[search_test],
+                    is_async=getattr(comp, "is_async", 0),
+                )
+            ],
+        )
+
+        # A fresh object() marks "no deciding element found".
+        sentinel_var = self.variable()
+        self.statements.append(
+            ast.Assign(
+                targets=[ast.Name(id=sentinel_var, ctx=ast.Store())],
+                value=ast.Call(
+                    func=ast.Name(id="object", ctx=ast.Load()),
+                    args=[],
+                    keywords=[],
+                ),
+            )
+        )
+
+        # fail_var = next(search, sentinel)
+        self.statements.append(
+            ast.Assign(
+                targets=[ast.Name(id=fail_var, ctx=ast.Store())],
+                value=ast.Call(
+                    func=ast.Name(id="next", ctx=ast.Load()),
+                    args=[search, ast.Name(id=sentinel_var, ctx=ast.Load())],
+                    keywords=[],
+                ),
+            )
+        )
+
+        # all() holds when nothing falsy was found; any() when something
+        # truthy was.
+        result = ast.Compare(
+            left=ast.Name(id=fail_var, ctx=ast.Load()),
+            ops=[ast.Is() if func_name == "all" else ast.IsNot()],
+            comparators=[ast.Name(id=sentinel_var, ctx=ast.Load())],
+        )
+        return result, fail_var, iter_expl
|
|
+
|
|
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
new_func, func_expl = self.visit(call.func)
|
|
arg_expls = []
|
|
@@ -879,6 +990,28 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
if res is not None:
|
|
return res
|
|
|
|
+ # Special handling for all() and any() with generator expressions
|
|
+ if self._is_builtin_all_any(call):
|
|
+ genexp = self._get_generator_from_call(call)
|
|
+ if genexp is not None and self._is_simple_generator(genexp):
|
|
+ comp = genexp.generators[0]
|
|
+ target = comp.target
|
|
+ iter_node = comp.iter
|
|
+ elt = genexp.elt
|
|
+
|
|
+ # Store the iterable in a variable first
|
|
+ iter_res, iter_expl = self.visit(iter_node)
|
|
+
|
|
+ # Create a comprehension that collects (element, predicate_result) pairs
|
|
+ # This gives us visibility into what failed
|
|
+ func_name = call.func.id
|
|
+
|
|
+ # Generate explanation showing the first failing/passing element
|
|
+ target_name = target.id if isinstance(target, ast.Name) else "_"
|
|
+ expl = "{}({} for {} in {})".format(
|
|
+ func_name, ast.unparse(elt), target_name, iter_expl)
|
|
+
|
|
new_args = []
|
|
new_kwargs = []
|
|
for arg in call.args:
|
|
```
|
|
|
|
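This first draft is incomplete: its `visit_Call` hunk is cut off mid-function before the new helpers are used. The complete patch below supersedes it and follows pytest's standard assertion-rewriting architecture: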
|
|
|
|
--- a/src/_pytest/assertion/rewrite.py
|
|
+++ b/src/_pytest/assertion/rewrite.py
|
|
@@ -868,6 +868,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
return self.visit(new_call)
|
|
return None
|
|
|
|
+    # NOTE: all()/any() calls over a comprehension are special-cased in visit_Call (see _rewrite_all_any_call below).
|
|
+
|
|
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
|
|
new_func, func_expl = self.visit(call.func)
|
|
arg_expls = []
|
|
@@ -879,6 +881,38 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
if res is not None:
|
|
return res
|
|
|
|
+ # Special handling for all() and any() with generator expressions
|
|
+ if (
|
|
+ isinstance(call.func, ast.Name)
|
|
+ and call.func.id in ("all", "any")
|
|
+ and len(call.args) == 1
|
|
+ and not call.keywords
|
|
+ and isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp))
|
|
+ ):
|
|
+ return self._rewrite_all_any_call(call)
|
|
+
|
|
new_args = []
|
|
new_kwargs = []
|
|
for arg in call.args:
|
|
@@ -914,6 +948,89 @@ class AssertionRewriter(ast.NodeVisitor):
|
|
res = self.assign(call)
|
|
return res, outer_expl
|
|
|
|
+    def _rewrite_all_any_call(
+        self, call: ast.Call
+    ) -> Tuple[ast.Name, str]:
+        """Rewrite all()/any() over a comprehension for better failure messages.
+
+        ``all(<elt> for <target> in <iter> ...)`` is replaced by a search for
+        the element that decides the result, roughly::
+
+            @py_sentinel = object()
+            @py_found = next((<target> for <target> in <iter> ... if not <elt>),
+                             @py_sentinel)
+            @py_result = @py_found is @py_sentinel
+
+        For ``any()`` the ``not`` is dropped and the comparison is ``is not``.
+        Because the search is itself a generator expression, the original
+        comprehension semantics are kept:
+
+        * loop variables stay local and do not leak into the test function;
+        * the ``if`` clauses run in source order before the predicate;
+        * several ``for`` clauses are supported;
+        * iteration stops at the first deciding element, so unbounded
+          iterators are not materialised.
+
+        A list-comprehension argument is still evaluated in full first, as it
+        would be without the rewrite.  For ``all()`` the explanation also
+        names the first failing element.
+        """
+        import copy
+
+        func_name = call.func.id  # "all" or "any"
+        is_all = func_name == "all"
+        arg = call.args[0]  # a GeneratorExp or a ListComp (see visit_Call)
+        is_listcomp = isinstance(arg, ast.ListComp)
+
+        def as_load(target: ast.expr) -> ast.expr:
+            # Loop targets are Store-context nodes; compile() rejects them
+            # where a value is read, so use a copy re-tagged as Load.
+            node = copy.deepcopy(target)
+            for sub in ast.walk(node):
+                if hasattr(sub, "ctx"):
+                    sub.ctx = ast.Load()
+            return node
+
+        # Evaluate the outermost iterable once, in the enclosing scope, as
+        # the original comprehension does.
+        first = arg.generators[0]
+        iter_res, iter_expl = self.visit(first.iter)
+
+        # Fresh clauses (the original nodes are not mutated) with the visited
+        # outermost iterable substituted in.
+        clauses = [
+            ast.comprehension(
+                target=comp.target,
+                iter=iter_res if index == 0 else comp.iter,
+                ifs=list(comp.ifs),
+                is_async=getattr(comp, "is_async", 0),
+            )
+            for index, comp in enumerate(arg.generators)
+        ]
+        targets = [as_load(comp.target) for comp in arg.generators]
+        if len(targets) == 1:
+            element = targets[0]
+        else:
+            element = ast.Tuple(targets, ast.Load())
+        predicate = arg.elt
+
+        if is_listcomp:
+            # all([...]) evaluates every predicate before all() runs; keep
+            # that by materialising (element, value) pairs up front and
+            # searching the pairs afterwards.
+            pairs_var = self.variable()
+            self.statements.append(
+                ast.Assign(
+                    targets=[ast.Name(pairs_var, ast.Store())],
+                    value=ast.ListComp(
+                        elt=ast.Tuple([element, predicate], ast.Load()),
+                        generators=clauses,
+                    ),
+                )
+            )
+            item_var = self.variable()
+            value_var = self.variable()
+            clauses = [
+                ast.comprehension(
+                    target=ast.Tuple(
+                        [
+                            ast.Name(item_var, ast.Store()),
+                            ast.Name(value_var, ast.Store()),
+                        ],
+                        ast.Store(),
+                    ),
+                    iter=ast.Name(pairs_var, ast.Load()),
+                    ifs=[],
+                    is_async=0,
+                )
+            ]
+            element = ast.Name(item_var, ast.Load())
+            predicate = ast.Name(value_var, ast.Load())
+
+        # all(): stop at the first falsy predicate; any(): the first truthy.
+        # Appending to the last clause keeps the user's ``if`` filters ahead
+        # of the predicate, in their original order.
+        if is_all:
+            clauses[-1].ifs.append(ast.UnaryOp(ast.Not(), predicate))
+        else:
+            clauses[-1].ifs.append(predicate)
+
+        # A fresh object() marks "no deciding element found", so an element
+        # that is itself None is still reported.
+        sentinel_var = self.variable()
+        self.statements.append(
+            ast.Assign(
+                targets=[ast.Name(sentinel_var, ast.Store())],
+                value=ast.Call(ast.Name("object", ast.Load()), [], []),
+            )
+        )
+        found_var = self.variable()
+        self.statements.append(
+            ast.Assign(
+                targets=[ast.Name(found_var, ast.Store())],
+                value=ast.Call(
+                    ast.Name("next", ast.Load()),
+                    [
+                        ast.GeneratorExp(elt=element, generators=clauses),
+                        ast.Name(sentinel_var, ast.Load()),
+                    ],
+                    [],
+                ),
+            )
+        )
+        res = self.assign(
+            ast.Compare(
+                left=ast.Name(found_var, ast.Load()),
+                ops=[ast.Is() if is_all else ast.IsNot()],
+                comparators=[ast.Name(sentinel_var, ast.Load())],
+            )
+        )
+
+        # Describe the call from its source.  The visited iterable keeps its
+        # own explanation; everything else is literal source text.
+        # NOTE(review): explanations are assumed to be %-formatted later
+        # (placeholders look like "%(py0)s"), so a literal "%" is escaped.
+        has_unparse = hasattr(ast, "unparse")  # Python 3.9+
+
+        def src(node: ast.AST) -> str:
+            text = ast.unparse(node) if has_unparse else "..."
+            return text.replace("%", "%%")
+
+        body_src = ""
+        for index, comp in enumerate(arg.generators):
+            body_src += " async for " if getattr(comp, "is_async", 0) else " for "
+            body_src += src(comp.target) + " in "
+            body_src += iter_expl if index == 0 else src(comp.iter)
+            body_src += "".join(" if " + src(cond) for cond in comp.ifs)
+        opener, closer = ("[", "]") if is_listcomp else ("", "")
+        expl = "%s(%s%s%s%s)" % (func_name, opener, src(arg.elt), body_src, closer)
+
+        if is_all:
+            # Name the first failing element.  If the assertion fails while
+            # all() was True (e.g. ``assert not all(...)``) there is no such
+            # element, so a marker is shown instead of the sentinel.
+            # NOTE(review): explanation_param/display are assumed to be the
+            # usual AssertionRewriter helpers; they are not part of this patch.
+            found_expl = self.explanation_param(
+                ast.IfExp(
+                    test=ast.Compare(
+                        left=ast.Name(found_var, ast.Load()),
+                        ops=[ast.IsNot()],
+                        comparators=[ast.Name(sentinel_var, ast.Load())],
+                    ),
+                    body=self.display(ast.Name(found_var, ast.Load())),
+                    orelse=ast.Constant(value="<none>"),
+                )
+            )
+            expl += " (first failing element: %s)" % found_expl
+
+        return res, expl
|