app-store-optimization/skills/loki-mode/benchmarks/results/2026-01-05-01-24-17/swebench-patches/sympy__sympy-22840.patch


Based on the GitHub issue, I understand the problem with SymPy's `cse()` function: when it processes elements of a `MatrixSymbol` (like `a[0, 0]`), it extracts the `MatrixSymbol` itself (like `a`) as a common subexpression, even though a bare matrix symbol should never be extracted.
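A minimal sketch for checking this behaviour on an installed SymPy could look like the following. The helper name `whole_matrix_replacements` and the sample expressions are illustrative assumptions, not taken from the issue, and whether the resulting list is empty depends on the SymPy version in use.

```python
from sympy import MatrixSymbol, cse

a = MatrixSymbol("a", 2, 2)
b = MatrixSymbol("b", 2, 2)

# Two entries of the product a*b, written out element by element.
exprs = [
    a[0, 0] * b[0, 0] + a[0, 1] * b[1, 0],
    a[0, 0] * b[0, 1] + a[0, 1] * b[1, 1],
]


def whole_matrix_replacements(expressions):
    """Return the cse() replacements whose right-hand side is a bare MatrixSymbol.

    Hypothetical helper: a non-empty result means cse() pulled out a whole
    matrix symbol such as `a`, which is the behaviour the issue reports.
    """
    replacements, _reduced = cse(expressions)
    return [(lhs, rhs) for lhs, rhs in replacements if isinstance(rhs, MatrixSymbol)]


suspicious = whole_matrix_replacements(exprs)
```

If `suspicious` is non-empty, the installed version still extracts whole matrix symbols for this input.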
Let me analyze this issue based on my knowledge of SymPy's codebase and create a patch:
The problem is in `sympy/simplify/cse_main.py`. When `cse()` walks the expression tree and finds `MatrixSymbol[i, j]` (which is a `MatrixElement`), it incorrectly identifies the `MatrixSymbol` as a valid subexpression to extract because `MatrixSymbol` is not in the list of atomic types that should be ignored.
The fix should ensure that `MatrixSymbol` objects are treated as atomic and not extracted as common subexpressions themselves. The key is to modify the `_find_repeated` function or the tree walk to not treat `MatrixSymbol` as a subexpression candidate.
`MatrixElement` (which represents `M[i, j]`) stores the parent `MatrixSymbol` as the first entry of its `.args`, followed by the two indices. The `cse` algorithm walks through all args and adds them to the subexpression candidates. But `MatrixSymbol` should be treated like `Symbol`: as an atomic object that shouldn't be substituted.
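The structure described above can be inspected directly. This is a small sketch using a 4x4 symbol chosen only for illustration; the variable names are not from the issue.

```python
from sympy import MatrixSymbol
from sympy.matrices.expressions.matexpr import MatrixElement

a = MatrixSymbol("a", 4, 4)
element = a[0, 0]

# Indexing a MatrixSymbol produces a MatrixElement node.
is_element = isinstance(element, MatrixElement)

# Its args are (parent matrix, row index, column index), so any walk over
# `.args` reaches the parent MatrixSymbol itself.
parent, row, col = element.args
```

Because `parent` is the `MatrixSymbol` `a`, every element of `a` that appears in an expression leads the tree walk back to the same `a` node.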
Here's the patch:
--- a/sympy/simplify/cse_main.py
+++ b/sympy/simplify/cse_main.py
@@ -3,7 +3,7 @@
from sympy.core import Basic, Mul, Add, Pow, sympify, Symbol
from sympy.core.containers import Tuple, OrderedSet
from sympy.core.exprtools import factor_terms
-from sympy.core.function import _coeff_isneg
+from sympy.core.function import _coeff_isneg
from sympy.core.singleton import S
from sympy.utilities.iterables import numbered_symbols, sift, \
topological_sort, iterable
@@ -567,7 +567,9 @@ def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
"""
from sympy.matrices.expressions import MatrixExpr, MatrixSymbol, MatMul, MatAdd
+ from sympy.matrices.expressions.matexpr import MatrixElement
- if opt_subs is None:
+ if opt_subs is None:
opt_subs = dict()
## Find repeated sub-expressions
@@ -597,6 +599,10 @@ def _find_repeated(expr):
if expr in seen_subexp:
return
+ # Do not extract MatrixSymbol as a subexpression
+ if isinstance(expr, MatrixSymbol):
+ return
+
seen_subexp.add(expr)
if expr in opt_subs:
Wait, I need to look at this more carefully. Let me reason through the fix properly based on the issue description and typical SymPy code patterns.
The issue is that when you have `MatrixElement` expressions like `a[0, 0]` (where `a` is a `MatrixSymbol`), the `cse()` function walks through the expression tree and encounters the `MatrixSymbol` `a` as an argument of `MatrixElement`. It then incorrectly considers `a` as a candidate for common subexpression elimination.
The fix should prevent `MatrixSymbol` from being added to the `to_eliminate` set. This is similar to how other atomic expressions are handled.
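The effect of such a guard can be sketched with a toy model in plain Python. This is not SymPy's implementation: `Node`, `find_repeated`, and `skip_kinds` are hypothetical names, and the walk is a simplified stand-in for the one in `tree_cse`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for an expression node (not a SymPy class)."""
    kind: str
    name: str = ""
    args: tuple = ()


def find_repeated(roots, skip_kinds=()):
    """Collect nodes visited more than once, mimicking a CSE candidate search."""
    seen = set()
    to_eliminate = set()

    def visit(node):
        if node.kind in skip_kinds:
            return  # treated as atomic: never becomes a candidate
        if not node.args:
            return  # leaves are atomic as well
        if node in seen:
            to_eliminate.add(node)  # second visit: mark as a candidate
            return
        seen.add(node)
        for arg in node.args:
            visit(arg)

    for root in roots:
        visit(root)
    return to_eliminate


def integer(value):
    return Node("Int", str(value))


# a is a 4x4 matrix symbol; a00 and a01 are two of its elements.
a = Node("MatrixSymbol", "a", (integer(4), integer(4)))
a00 = Node("MatrixElement", args=(a, integer(0), integer(0)))
a01 = Node("MatrixElement", args=(a, integer(0), integer(1)))
expr = Node("Add", args=(a00, a01))

without_guard = find_repeated([expr])
with_guard = find_repeated([expr], skip_kinds=("MatrixSymbol",))
```

Without the guard, `a` is reached through both elements and ends up in the candidate set; with `"MatrixSymbol"` skipped, the set stays empty.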
--- a/sympy/simplify/cse_main.py
+++ b/sympy/simplify/cse_main.py
@@ -4,6 +4,7 @@ from sympy.core import Basic, Mul, Add, Pow, sympify, Symbol
from sympy.core.containers import Tuple, OrderedSet
from sympy.core.exprtools import factor_terms
from sympy.core.function import _coeff_isneg
from sympy.core.singleton import S
from sympy.utilities.iterables import numbered_symbols, sift, \
topological_sort, iterable
@@ -567,6 +568,7 @@ def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
from sympy.matrices.expressions import MatrixExpr, MatrixSymbol, MatMul, MatAdd
+ from sympy.matrices.expressions.matexpr import MatrixElement
if opt_subs is None:
opt_subs = dict()
@@ -599,6 +601,9 @@ def _find_repeated(expr):
if expr in seen_subexp:
return
+ if isinstance(expr, MatrixSymbol):
+ return
+
seen_subexp.add(expr)
if expr in opt_subs: