76 lines
3.9 KiB
Diff
Based on the GitHub issue, I understand the problem with SymPy's `cse()` function: when it processes elements of a `MatrixSymbol` (such as `a[0, 0]`), it incorrectly extracts the `MatrixSymbol` itself (`a`) as a common subexpression.

Let me analyze the issue from what I know of SymPy's codebase and then write a patch.

The problem is in `sympy/simplify/cse_main.py`. When `cse()` walks the expression tree and reaches an indexed entry such as `a[i, j]` (a `MatrixElement`), it also visits the parent `MatrixSymbol` and wrongly treats it as a subexpression that can be extracted, because `MatrixSymbol` is not among the atomic types that the walk skips.

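The snippet below is a quick way to confirm the two facts this explanation relies on: indexing a `MatrixSymbol` with integers produces a `MatrixElement` whose first argument is the matrix itself, and a `MatrixSymbol`, unlike a `Symbol`, is not an `Atom`. It assumes a SymPy version in which `MatrixElement` is importable from `sympy.matrices.expressions.matexpr`, as in the patch.

```python
from sympy import MatrixSymbol, Symbol
from sympy.matrices.expressions.matexpr import MatrixElement

a = MatrixSymbol('a', 2, 2)
elem = a[0, 0]

# Indexing a MatrixSymbol with integers gives a MatrixElement ...
print(type(elem).__name__)             # MatrixElement
# ... whose first argument is the parent matrix itself.
print(elem.args[0] == a)               # True
# A Symbol is an Atom, so a tree walk stops there; a MatrixSymbol is not.
print(Symbol('x').is_Atom, a.is_Atom)  # True False
```

So any walk that recurses through `.args` and stops only at atoms will visit `a` once for every element of `a` it meets.
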
The fix is to treat `MatrixSymbol` objects as atomic so that they are never extracted as common subexpressions themselves. Concretely, `_find_repeated` (the helper that performs the tree walk inside `tree_cse`) should stop at a `MatrixSymbol` instead of recording it as a candidate.

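To illustrate the idea in isolation, here is a minimal sketch of a repeated-subexpression finder in the spirit of `_find_repeated`. It is not SymPy's implementation: `find_repeated` and its `matrix_symbols_are_atomic` flag are hypothetical names, and the walk is simplified (no `opt_subs` and no `ignore` handling). It records any sub-expression reached a second time and, when the flag is set, stops at a `MatrixSymbol` just as it stops at an atom.

```python
from sympy import Basic, MatrixSymbol, cos, sin


def find_repeated(exprs, matrix_symbols_are_atomic=True):
    """Hypothetical sketch: return sub-expressions reached more than once."""
    seen = set()
    repeated = set()

    def walk(expr):
        # Atoms (Symbol, Integer, ...) are never candidates.
        if not isinstance(expr, Basic) or expr.is_Atom:
            return
        # The idea behind the patch: treat a MatrixSymbol as a leaf too.
        if matrix_symbols_are_atomic and isinstance(expr, MatrixSymbol):
            return
        if expr in seen:
            repeated.add(expr)  # second visit: a CSE candidate
            return
        seen.add(expr)
        for arg in expr.args:
            walk(arg)

    for e in exprs:
        walk(e)
    return repeated


a = MatrixSymbol('a', 2, 2)
exprs = [sin(a[0, 0] + a[1, 1]), cos(a[0, 0] + a[1, 1])]
print(find_repeated(exprs))  # only the shared sum a[0, 0] + a[1, 1]
print(find_repeated(exprs, matrix_symbols_are_atomic=False))  # the sum plus a
```

With the flag off, `a` is reported as repeated simply because both `a[0, 0]` and `a[1, 1]` reference it; with the flag on, only the genuinely shared sum remains.
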
A `MatrixElement`, which represents `M[i, j]`, stores the parent `MatrixSymbol` as the first entry of its `.args`. Since the `cse` algorithm recurses into every argument and adds each one to the subexpression candidates, it reaches the `MatrixSymbol` as well. A `MatrixSymbol` should instead be handled like a `Symbol`: as an atomic object that is never substituted.

Here's the patch:

```diff
--- a/sympy/simplify/cse_main.py
+++ b/sympy/simplify/cse_main.py
@@ -567,7 +567,8 @@ def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
     """
     from sympy.matrices.expressions import MatrixExpr, MatrixSymbol, MatMul, MatAdd
+    from sympy.matrices.expressions.matexpr import MatrixElement
 
     if opt_subs is None:
         opt_subs = dict()
 
     ## Find repeated sub-expressions
@@ -597,6 +598,10 @@ def _find_repeated(expr):
         if expr in seen_subexp:
             return
 
+        # Do not extract MatrixSymbol as a subexpression
+        if isinstance(expr, MatrixSymbol):
+            return
+
         seen_subexp.add(expr)
 
         if expr in opt_subs:
```

Wait, I need to look at this more carefully. Let me reason through the fix properly based on the issue description and typical SymPy code patterns.

The issue is that for `MatrixElement` expressions like `a[0, 0]` (where `a` is a `MatrixSymbol`), `cse()` walks the expression tree, encounters `a` as an argument of the `MatrixElement`, and incorrectly considers `a` a candidate for common subexpression elimination.

The fix should prevent a `MatrixSymbol` from ever being added to the `to_eliminate` set, in the same way that atomic expressions are already skipped.

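A quick way to check a given SymPy build is to run `cse()` on a few expressions built from matrix elements and look for replacements whose right-hand side is a bare `MatrixSymbol`. The helper name `extracted_matrix_symbols` is hypothetical; `cse()` itself returns a list of `(symbol, subexpression)` replacement pairs together with the list of reduced expressions.

```python
from sympy import MatrixSymbol, cse


def extracted_matrix_symbols(replacements):
    """Hypothetical helper: MatrixSymbols that cse() pulled out on their own."""
    return [rhs for _, rhs in replacements if isinstance(rhs, MatrixSymbol)]


a = MatrixSymbol('a', 2, 2)
b = MatrixSymbol('b', 2, 2)
# Two entries of the product a*b, written out element by element.
exprs = [a[0, 0]*b[0, 0] + a[0, 1]*b[1, 0],
         a[0, 0]*b[0, 1] + a[0, 1]*b[1, 1]]

replacements, reduced = cse(exprs)
# Expected to be empty once MatrixSymbol is skipped; an affected build
# may list a and b here instead.
print(extracted_matrix_symbols(replacements))
```

This makes a convenient regression check for the patch, since it inspects only the shape of the `cse()` result and not the exact names of the generated symbols.
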
```diff
--- a/sympy/simplify/cse_main.py
+++ b/sympy/simplify/cse_main.py
@@ -567,4 +567,5 @@ def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
     from sympy.matrices.expressions import MatrixExpr, MatrixSymbol, MatMul, MatAdd
+    from sympy.matrices.expressions.matexpr import MatrixElement
 
     if opt_subs is None:
         opt_subs = dict()
@@ -599,6 +600,9 @@ def _find_repeated(expr):
         if expr in seen_subexp:
             return
 
+        if isinstance(expr, MatrixSymbol):
+            return
+
         seen_subexp.add(expr)
 
         if expr in opt_subs:
```