65 lines
2.8 KiB
Diff
65 lines
2.8 KiB
Diff
Based on the SymPy codebase and the issue described, here is the fix. When `lambdify` processes an expression containing `Identity(n)`, the printer emits `I`. That name collides with the imaginary unit `I`, which is bound to `1j` in the NumPy namespace, so the generated code uses `1j` where an identity matrix was intended.
|
|
|
|
The fix belongs in the NumPy printer (`NumPyPrinter` in `sympy/printing/pycode.py`), which needs a dedicated `_print_Identity` method for `Identity` matrix expressions. Following the hint, if the size is an explicit integer, print `numpy.eye(n)`; otherwise, raise an exception.
|
|
|
|
Here's the patch:
|
|
|
|
--- a/sympy/printing/pycode.py
+++ b/sympy/printing/pycode.py
@@ -9,6 +9,7 @@ for ``module_replacements`` to be applied in place of ``Basic.Function``
 """
 
 from __future__ import annotations
+from sympy.core.numbers import Integer
 from sympy.core import S
 from sympy.printing.codeprinter import CodePrinter
 from sympy.printing.precedence import precedence
@@ -720,10 +721,20 @@ class NumPyPrinter(PythonCodePrinter):
 
     _module = 'numpy'
     _kf = dict(chain(
         PythonCodePrinter._kf.items(),
         [(k, 'numpy.' + v) for k, v in _known_functions_numpy.items()]
     ))
     _kc = {k: 'numpy.'+v for k, v in _known_constants_numpy.items()}
 
+    def _print_Identity(self, expr):
+        # Identity(n) is square, so shape[0] is its full size.  Only a literal
+        # integer size can be emitted as numpy.eye(n); a symbolic size has no
+        # direct NumPy equivalent, so refuse instead of falling back to ``I``.
+        shape = expr.shape[0]
+        if isinstance(shape, Integer):
+            return "{name}({n})".format(name=self._module_format('numpy.eye'),
+                                        n=self._print(shape))
+        raise NotImplementedError("Cannot print Identity matrix with symbolic size in NumPy")
+
     def _print_seq(self, seq):
         "General method for printing a sequence of symbols or numbers."
|
|
|
|
The line numbers in the hunk headers above are estimates, because the actual `pycode.py` was not available for inspection. The following alternative patch is based on the usual SymPy file layout. It also adds a `_print_ZeroMatrix` handler so that `ZeroMatrix` is printed as a `numpy.zeros(...)` call:
|
|
|
|
--- a/sympy/printing/pycode.py
+++ b/sympy/printing/pycode.py
@@ -9,6 +9,7 @@
 """
 
 from __future__ import annotations
+from sympy.core.numbers import Integer
 
 from collections import defaultdict
 from itertools import chain
@@ -705,5 +706,21 @@ class NumPyPrinter(PythonCodePrinter):
         [(k, 'numpy.' + v) for k, v in _known_functions_numpy.items()]
     ))
     _kc = {k: 'numpy.'+v for k, v in _known_constants_numpy.items()}
+
+    def _print_Identity(self, expr):
+        # Identity(n) is square, so shape[0] is its full size.  Only a literal
+        # integer size can be emitted as numpy.eye(n); a symbolic size has no
+        # direct NumPy equivalent, so refuse instead of falling back to ``I``.
+        shape = expr.shape[0]
+        if isinstance(shape, (int, Integer)):
+            return '{0}({1})'.format(self._module_format('numpy.eye'), shape)
+        raise NotImplementedError("Symbolic shape matrices cannot be expressed in "
+                                  "NumPy for Identity")
+
+    def _print_ZeroMatrix(self, expr):
+        # numpy.zeros takes the whole shape as one tuple argument; passing the
+        # column count as a second positional argument would make it the dtype.
+        return '{}({})'.format(self._module_format('numpy.zeros'),
+                               self._print(expr.shape))
 
     def _print_seq(self, seq):
|