MTerryJack commited on
Commit
8e459aa
·
verified ·
1 Parent(s): 11fc647

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.13/site-packages/sympy/algebras/tests/__init__.py +0 -0
  2. .venv/lib/python3.13/site-packages/sympy/algebras/tests/test_quaternion.py +437 -0
  3. .venv/lib/python3.13/site-packages/sympy/logic/algorithms/__init__.py +0 -0
  4. .venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll.py +308 -0
  5. .venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll2.py +688 -0
  6. .venv/lib/python3.13/site-packages/sympy/logic/algorithms/lra_theory.py +912 -0
  7. .venv/lib/python3.13/site-packages/sympy/logic/algorithms/minisat22_wrapper.py +46 -0
  8. .venv/lib/python3.13/site-packages/sympy/logic/algorithms/pycosat_wrapper.py +41 -0
  9. .venv/lib/python3.13/site-packages/sympy/logic/algorithms/z3_wrapper.py +115 -0
  10. .venv/lib/python3.13/site-packages/sympy/logic/tests/__init__.py +0 -0
  11. .venv/lib/python3.13/site-packages/sympy/logic/tests/test_boolalg.py +1367 -0
  12. .venv/lib/python3.13/site-packages/sympy/logic/tests/test_dimacs.py +234 -0
  13. .venv/lib/python3.13/site-packages/sympy/logic/tests/test_inference.py +396 -0
  14. .venv/lib/python3.13/site-packages/sympy/logic/tests/test_lra_theory.py +440 -0
  15. .venv/lib/python3.13/site-packages/sympy/logic/utilities/__init__.py +3 -0
  16. .venv/lib/python3.13/site-packages/sympy/logic/utilities/dimacs.py +69 -0
  17. .venv/lib/python3.13/site-packages/sympy/printing/tests/__init__.py +0 -0
  18. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_aesaracode.py +633 -0
  19. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_c.py +888 -0
  20. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_codeprinter.py +77 -0
  21. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_conventions.py +116 -0
  22. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_cupy.py +56 -0
  23. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_cxx.py +86 -0
  24. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_dot.py +134 -0
  25. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_glsl.py +998 -0
  26. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_gtk.py +18 -0
  27. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_jax.py +370 -0
  28. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_jscode.py +396 -0
  29. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_julia.py +390 -0
  30. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_lambdarepr.py +246 -0
  31. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_latex.py +0 -0
  32. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_maple.py +381 -0
  33. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathematica.py +287 -0
  34. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathml.py +0 -0
  35. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_numpy.py +381 -0
  36. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_octave.py +515 -0
  37. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_precedence.py +128 -0
  38. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_preview.py +38 -0
  39. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_pycode.py +493 -0
  40. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_python.py +203 -0
  41. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_repr.py +382 -0
  42. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_rust.py +363 -0
  43. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_smtlib.py +553 -0
  44. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_str.py +1206 -0
  45. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_tableform.py +182 -0
  46. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_tensorflow.py +493 -0
  47. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_theanocode.py +639 -0
  48. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_torch.py +531 -0
  49. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_tree.py +196 -0
  50. .venv/lib/python3.13/site-packages/sympy/solvers/benchmarks/__init__.py +0 -0
.venv/lib/python3.13/site-packages/sympy/algebras/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/algebras/tests/test_quaternion.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.testing.pytest import slow
2
+ from sympy.core.function import diff
3
+ from sympy.core.function import expand
4
+ from sympy.core.numbers import (E, I, Rational, pi)
5
+ from sympy.core.singleton import S
6
+ from sympy.core.symbol import (Symbol, symbols)
7
+ from sympy.functions.elementary.complexes import (Abs, conjugate, im, re, sign)
8
+ from sympy.functions.elementary.exponential import log
9
+ from sympy.functions.elementary.miscellaneous import sqrt
10
+ from sympy.functions.elementary.trigonometric import (acos, asin, cos, sin, atan2, atan)
11
+ from sympy.integrals.integrals import integrate
12
+ from sympy.matrices.dense import Matrix
13
+ from sympy.simplify import simplify
14
+ from sympy.simplify.trigsimp import trigsimp
15
+ from sympy.algebras.quaternion import Quaternion
16
+ from sympy.testing.pytest import raises
17
+ import math
18
+ from itertools import permutations, product
19
+
20
+ w, x, y, z = symbols('w:z')
21
+ phi = symbols('phi')
22
+
23
+ def test_quaternion_construction():
24
+ q = Quaternion(w, x, y, z)
25
+ assert q + q == Quaternion(2*w, 2*x, 2*y, 2*z)
26
+
27
+ q2 = Quaternion.from_axis_angle((sqrt(3)/3, sqrt(3)/3, sqrt(3)/3),
28
+ pi*Rational(2, 3))
29
+ assert q2 == Quaternion(S.Half, S.Half,
30
+ S.Half, S.Half)
31
+
32
+ M = Matrix([[cos(phi), -sin(phi), 0], [sin(phi), cos(phi), 0], [0, 0, 1]])
33
+ q3 = trigsimp(Quaternion.from_rotation_matrix(M))
34
+ assert q3 == Quaternion(
35
+ sqrt(2)*sqrt(cos(phi) + 1)/2, 0, 0, sqrt(2 - 2*cos(phi))*sign(sin(phi))/2)
36
+
37
+ nc = Symbol('nc', commutative=False)
38
+ raises(ValueError, lambda: Quaternion(w, x, nc, z))
39
+
40
+
41
+ def test_quaternion_construction_norm():
42
+ q1 = Quaternion(*symbols('a:d'))
43
+
44
+ q2 = Quaternion(w, x, y, z)
45
+ assert expand((q1*q2).norm()**2 - (q1.norm()**2 * q2.norm()**2)) == 0
46
+
47
+ q3 = Quaternion(w, x, y, z, norm=1)
48
+ assert (q1 * q3).norm() == q1.norm()
49
+
50
+
51
+ def test_issue_25254():
52
+ # calculating the inverse cached the norm which caused problems
53
+ # when multiplying
54
+ p = Quaternion(1, 0, 0, 0)
55
+ q = Quaternion.from_axis_angle((1, 1, 1), 3 * math.pi/4)
56
+ qi = q.inverse() # this operation cached the norm
57
+ test = q * p * qi
58
+ assert ((test - p).norm() < 1E-10)
59
+
60
+
61
+ def test_to_and_from_Matrix():
62
+ q = Quaternion(w, x, y, z)
63
+ q_full = Quaternion.from_Matrix(q.to_Matrix())
64
+ q_vect = Quaternion.from_Matrix(q.to_Matrix(True))
65
+ assert (q - q_full).is_zero_quaternion()
66
+ assert (q.vector_part() - q_vect).is_zero_quaternion()
67
+
68
+
69
+ def test_product_matrices():
70
+ q1 = Quaternion(w, x, y, z)
71
+ q2 = Quaternion(*(symbols("a:d")))
72
+ assert (q1 * q2).to_Matrix() == q1.product_matrix_left * q2.to_Matrix()
73
+ assert (q1 * q2).to_Matrix() == q2.product_matrix_right * q1.to_Matrix()
74
+
75
+ R1 = (q1.product_matrix_left * q1.product_matrix_right.T)[1:, 1:]
76
+ R2 = simplify(q1.to_rotation_matrix()*q1.norm()**2)
77
+ assert R1 == R2
78
+
79
+
80
+ def test_quaternion_axis_angle():
81
+
82
+ test_data = [ # axis, angle, expected_quaternion
83
+ ((1, 0, 0), 0, (1, 0, 0, 0)),
84
+ ((1, 0, 0), pi/2, (sqrt(2)/2, sqrt(2)/2, 0, 0)),
85
+ ((0, 1, 0), pi/2, (sqrt(2)/2, 0, sqrt(2)/2, 0)),
86
+ ((0, 0, 1), pi/2, (sqrt(2)/2, 0, 0, sqrt(2)/2)),
87
+ ((1, 0, 0), pi, (0, 1, 0, 0)),
88
+ ((0, 1, 0), pi, (0, 0, 1, 0)),
89
+ ((0, 0, 1), pi, (0, 0, 0, 1)),
90
+ ((1, 1, 1), pi, (0, 1/sqrt(3),1/sqrt(3),1/sqrt(3))),
91
+ ((sqrt(3)/3, sqrt(3)/3, sqrt(3)/3), pi*2/3, (S.Half, S.Half, S.Half, S.Half))
92
+ ]
93
+
94
+ for axis, angle, expected in test_data:
95
+ assert Quaternion.from_axis_angle(axis, angle) == Quaternion(*expected)
96
+
97
+
98
+ def test_quaternion_axis_angle_simplification():
99
+ result = Quaternion.from_axis_angle((1, 2, 3), asin(4))
100
+ assert result.a == cos(asin(4)/2)
101
+ assert result.b == sqrt(14)*sin(asin(4)/2)/14
102
+ assert result.c == sqrt(14)*sin(asin(4)/2)/7
103
+ assert result.d == 3*sqrt(14)*sin(asin(4)/2)/14
104
+
105
+ def test_quaternion_complex_real_addition():
106
+ a = symbols("a", complex=True)
107
+ b = symbols("b", real=True)
108
+ # This symbol is not complex:
109
+ c = symbols("c", commutative=False)
110
+
111
+ q = Quaternion(w, x, y, z)
112
+ assert a + q == Quaternion(w + re(a), x + im(a), y, z)
113
+ assert 1 + q == Quaternion(1 + w, x, y, z)
114
+ assert I + q == Quaternion(w, 1 + x, y, z)
115
+ assert b + q == Quaternion(w + b, x, y, z)
116
+ raises(ValueError, lambda: c + q)
117
+ raises(ValueError, lambda: q * c)
118
+ raises(ValueError, lambda: c * q)
119
+
120
+ assert -q == Quaternion(-w, -x, -y, -z)
121
+
122
+ q1 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False)
123
+ q2 = Quaternion(1, 4, 7, 8)
124
+
125
+ assert q1 + (2 + 3*I) == Quaternion(5 + 7*I, 2 + 5*I, 0, 7 + 8*I)
126
+ assert q2 + (2 + 3*I) == Quaternion(3, 7, 7, 8)
127
+ assert q1 * (2 + 3*I) == \
128
+ Quaternion((2 + 3*I)*(3 + 4*I), (2 + 3*I)*(2 + 5*I), 0, (2 + 3*I)*(7 + 8*I))
129
+ assert q2 * (2 + 3*I) == Quaternion(-10, 11, 38, -5)
130
+
131
+ q1 = Quaternion(1, 2, 3, 4)
132
+ q0 = Quaternion(0, 0, 0, 0)
133
+ assert q1 + q0 == q1
134
+ assert q1 - q0 == q1
135
+ assert q1 - q1 == q0
136
+
137
+
138
+ def test_quaternion_subs():
139
+ q = Quaternion.from_axis_angle((0, 0, 1), phi)
140
+ assert q.subs(phi, 0) == Quaternion(1, 0, 0, 0)
141
+
142
+
143
+ def test_quaternion_evalf():
144
+ assert (Quaternion(sqrt(2), 0, 0, sqrt(3)).evalf() ==
145
+ Quaternion(sqrt(2).evalf(), 0, 0, sqrt(3).evalf()))
146
+ assert (Quaternion(1/sqrt(2), 0, 0, 1/sqrt(2)).evalf() ==
147
+ Quaternion((1/sqrt(2)).evalf(), 0, 0, (1/sqrt(2)).evalf()))
148
+
149
+
150
+ def test_quaternion_functions():
151
+ q = Quaternion(w, x, y, z)
152
+ q1 = Quaternion(1, 2, 3, 4)
153
+ q0 = Quaternion(0, 0, 0, 0)
154
+
155
+ assert conjugate(q) == Quaternion(w, -x, -y, -z)
156
+ assert q.norm() == sqrt(w**2 + x**2 + y**2 + z**2)
157
+ assert q.normalize() == Quaternion(w, x, y, z) / sqrt(w**2 + x**2 + y**2 + z**2)
158
+ assert q.inverse() == Quaternion(w, -x, -y, -z) / (w**2 + x**2 + y**2 + z**2)
159
+ assert q.inverse() == q.pow(-1)
160
+ raises(ValueError, lambda: q0.inverse())
161
+ assert q.pow(2) == Quaternion(w**2 - x**2 - y**2 - z**2, 2*w*x, 2*w*y, 2*w*z)
162
+ assert q**(2) == Quaternion(w**2 - x**2 - y**2 - z**2, 2*w*x, 2*w*y, 2*w*z)
163
+ assert q1.pow(-2) == Quaternion(
164
+ Rational(-7, 225), Rational(-1, 225), Rational(-1, 150), Rational(-2, 225))
165
+ assert q1**(-2) == Quaternion(
166
+ Rational(-7, 225), Rational(-1, 225), Rational(-1, 150), Rational(-2, 225))
167
+ assert q1.pow(-0.5) == NotImplemented
168
+ raises(TypeError, lambda: q1**(-0.5))
169
+
170
+ assert q1.exp() == \
171
+ Quaternion(E * cos(sqrt(29)),
172
+ 2 * sqrt(29) * E * sin(sqrt(29)) / 29,
173
+ 3 * sqrt(29) * E * sin(sqrt(29)) / 29,
174
+ 4 * sqrt(29) * E * sin(sqrt(29)) / 29)
175
+ assert q1.log() == \
176
+ Quaternion(log(sqrt(30)),
177
+ 2 * sqrt(29) * acos(sqrt(30)/30) / 29,
178
+ 3 * sqrt(29) * acos(sqrt(30)/30) / 29,
179
+ 4 * sqrt(29) * acos(sqrt(30)/30) / 29)
180
+
181
+ assert q1.pow_cos_sin(2) == \
182
+ Quaternion(30 * cos(2 * acos(sqrt(30)/30)),
183
+ 60 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29,
184
+ 90 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29,
185
+ 120 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29)
186
+
187
+ assert diff(Quaternion(x, x, x, x), x) == Quaternion(1, 1, 1, 1)
188
+
189
+ assert integrate(Quaternion(x, x, x, x), x) == \
190
+ Quaternion(x**2 / 2, x**2 / 2, x**2 / 2, x**2 / 2)
191
+
192
+ assert Quaternion(1, x, x**2, x**3).integrate(x) == \
193
+ Quaternion(x, x**2/2, x**3/3, x**4/4)
194
+
195
+ assert Quaternion(sin(x), cos(x), sin(2*x), cos(2*x)).integrate(x) == \
196
+ Quaternion(-cos(x), sin(x), -cos(2*x)/2, sin(2*x)/2)
197
+
198
+ assert Quaternion(x**2, y**2, z**2, x*y*z).integrate(x, y) == \
199
+ Quaternion(x**3*y/3, x*y**3/3, x*y*z**2, x**2*y**2*z/4)
200
+
201
+ assert Quaternion.rotate_point((1, 1, 1), q1) == (S.One / 5, 1, S(7) / 5)
202
+ n = Symbol('n')
203
+ raises(TypeError, lambda: q1**n)
204
+ n = Symbol('n', integer=True)
205
+ raises(TypeError, lambda: q1**n)
206
+
207
+ assert Quaternion(22, 23, 55, 8).scalar_part() == 22
208
+ assert Quaternion(w, x, y, z).scalar_part() == w
209
+
210
+ assert Quaternion(22, 23, 55, 8).vector_part() == Quaternion(0, 23, 55, 8)
211
+ assert Quaternion(w, x, y, z).vector_part() == Quaternion(0, x, y, z)
212
+
213
+ assert q1.axis() == Quaternion(0, 2*sqrt(29)/29, 3*sqrt(29)/29, 4*sqrt(29)/29)
214
+ assert q1.axis().pow(2) == Quaternion(-1, 0, 0, 0)
215
+ assert q0.axis().scalar_part() == 0
216
+ assert (q.axis() == Quaternion(0,
217
+ x/sqrt(x**2 + y**2 + z**2),
218
+ y/sqrt(x**2 + y**2 + z**2),
219
+ z/sqrt(x**2 + y**2 + z**2)))
220
+
221
+ assert q0.is_pure() is True
222
+ assert q1.is_pure() is False
223
+ assert Quaternion(0, 0, 0, 3).is_pure() is True
224
+ assert Quaternion(0, 2, 10, 3).is_pure() is True
225
+ assert Quaternion(w, 2, 10, 3).is_pure() is None
226
+
227
+ assert q1.angle() == 2*atan(sqrt(29))
228
+ assert q.angle() == 2*atan2(sqrt(x**2 + y**2 + z**2), w)
229
+
230
+ assert Quaternion.arc_coplanar(q1, Quaternion(2, 4, 6, 8)) is True
231
+ assert Quaternion.arc_coplanar(q1, Quaternion(1, -2, -3, -4)) is True
232
+ assert Quaternion.arc_coplanar(q1, Quaternion(1, 8, 12, 16)) is True
233
+ assert Quaternion.arc_coplanar(q1, Quaternion(1, 2, 3, 4)) is True
234
+ assert Quaternion.arc_coplanar(q1, Quaternion(w, 4, 6, 8)) is True
235
+ assert Quaternion.arc_coplanar(q1, Quaternion(2, 7, 4, 1)) is False
236
+ assert Quaternion.arc_coplanar(q1, Quaternion(w, x, y, z)) is None
237
+ raises(ValueError, lambda: Quaternion.arc_coplanar(q1, q0))
238
+
239
+ assert Quaternion.vector_coplanar(
240
+ Quaternion(0, 8, 12, 16),
241
+ Quaternion(0, 4, 6, 8),
242
+ Quaternion(0, 2, 3, 4)) is True
243
+ assert Quaternion.vector_coplanar(
244
+ Quaternion(0, 0, 0, 0), Quaternion(0, 4, 6, 8), Quaternion(0, 2, 3, 4)) is True
245
+ assert Quaternion.vector_coplanar(
246
+ Quaternion(0, 8, 2, 6), Quaternion(0, 1, 6, 6), Quaternion(0, 0, 3, 4)) is False
247
+ assert Quaternion.vector_coplanar(
248
+ Quaternion(0, 1, 3, 4),
249
+ Quaternion(0, 4, w, 6),
250
+ Quaternion(0, 6, 8, 1)) is None
251
+ raises(ValueError, lambda:
252
+ Quaternion.vector_coplanar(q0, Quaternion(0, 4, 6, 8), q1))
253
+
254
+ assert Quaternion(0, 1, 2, 3).parallel(Quaternion(0, 2, 4, 6)) is True
255
+ assert Quaternion(0, 1, 2, 3).parallel(Quaternion(0, 2, 2, 6)) is False
256
+ assert Quaternion(0, 1, 2, 3).parallel(Quaternion(w, x, y, 6)) is None
257
+ raises(ValueError, lambda: q0.parallel(q1))
258
+
259
+ assert Quaternion(0, 1, 2, 3).orthogonal(Quaternion(0, -2, 1, 0)) is True
260
+ assert Quaternion(0, 2, 4, 7).orthogonal(Quaternion(0, 2, 2, 6)) is False
261
+ assert Quaternion(0, 2, 4, 7).orthogonal(Quaternion(w, x, y, 6)) is None
262
+ raises(ValueError, lambda: q0.orthogonal(q1))
263
+
264
+ assert q1.index_vector() == Quaternion(
265
+ 0, 2*sqrt(870)/29,
266
+ 3*sqrt(870)/29,
267
+ 4*sqrt(870)/29)
268
+ assert Quaternion(0, 3, 9, 4).index_vector() == Quaternion(0, 3, 9, 4)
269
+
270
+ assert Quaternion(4, 3, 9, 4).mensor() == log(sqrt(122))
271
+ assert Quaternion(3, 3, 0, 2).mensor() == log(sqrt(22))
272
+
273
+ assert q0.is_zero_quaternion() is True
274
+ assert q1.is_zero_quaternion() is False
275
+ assert Quaternion(w, 0, 0, 0).is_zero_quaternion() is None
276
+
277
+ def test_quaternion_conversions():
278
+ q1 = Quaternion(1, 2, 3, 4)
279
+
280
+ assert q1.to_axis_angle() == ((2 * sqrt(29)/29,
281
+ 3 * sqrt(29)/29,
282
+ 4 * sqrt(29)/29),
283
+ 2 * acos(sqrt(30)/30))
284
+
285
+ assert (q1.to_rotation_matrix() ==
286
+ Matrix([[Rational(-2, 3), Rational(2, 15), Rational(11, 15)],
287
+ [Rational(2, 3), Rational(-1, 3), Rational(2, 3)],
288
+ [Rational(1, 3), Rational(14, 15), Rational(2, 15)]]))
289
+
290
+ assert (q1.to_rotation_matrix((1, 1, 1)) ==
291
+ Matrix([
292
+ [Rational(-2, 3), Rational(2, 15), Rational(11, 15), Rational(4, 5)],
293
+ [Rational(2, 3), Rational(-1, 3), Rational(2, 3), S.Zero],
294
+ [Rational(1, 3), Rational(14, 15), Rational(2, 15), Rational(-2, 5)],
295
+ [S.Zero, S.Zero, S.Zero, S.One]]))
296
+
297
+ theta = symbols("theta", real=True)
298
+ q2 = Quaternion(cos(theta/2), 0, 0, sin(theta/2))
299
+
300
+ assert trigsimp(q2.to_rotation_matrix()) == Matrix([
301
+ [cos(theta), -sin(theta), 0],
302
+ [sin(theta), cos(theta), 0],
303
+ [0, 0, 1]])
304
+
305
+ assert q2.to_axis_angle() == ((0, 0, sin(theta/2)/Abs(sin(theta/2))),
306
+ 2*acos(cos(theta/2)))
307
+
308
+ assert trigsimp(q2.to_rotation_matrix((1, 1, 1))) == Matrix([
309
+ [cos(theta), -sin(theta), 0, sin(theta) - cos(theta) + 1],
310
+ [sin(theta), cos(theta), 0, -sin(theta) - cos(theta) + 1],
311
+ [0, 0, 1, 0],
312
+ [0, 0, 0, 1]])
313
+
314
+
315
+ def test_rotation_matrix_homogeneous():
316
+ q = Quaternion(w, x, y, z)
317
+ R1 = q.to_rotation_matrix(homogeneous=True) * q.norm()**2
318
+ R2 = simplify(q.to_rotation_matrix(homogeneous=False) * q.norm()**2)
319
+ assert R1 == R2
320
+
321
+
322
+ def test_quaternion_rotation_iss1593():
323
+ """
324
+ There was a sign mistake in the definition,
325
+ of the rotation matrix. This tests that particular sign mistake.
326
+ See issue 1593 for reference.
327
+ See wikipedia
328
+ https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Quaternion-derived_rotation_matrix
329
+ for the correct definition
330
+ """
331
+ q = Quaternion(cos(phi/2), sin(phi/2), 0, 0)
332
+ assert(trigsimp(q.to_rotation_matrix()) == Matrix([
333
+ [1, 0, 0],
334
+ [0, cos(phi), -sin(phi)],
335
+ [0, sin(phi), cos(phi)]]))
336
+
337
+
338
+ def test_quaternion_multiplication():
339
+ q1 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False)
340
+ q2 = Quaternion(1, 2, 3, 5)
341
+ q3 = Quaternion(1, 1, 1, y)
342
+
343
+ assert Quaternion._generic_mul(S(4), S.One) == 4
344
+ assert (Quaternion._generic_mul(S(4), q1) ==
345
+ Quaternion(12 + 16*I, 8 + 20*I, 0, 28 + 32*I))
346
+ assert q2.mul(2) == Quaternion(2, 4, 6, 10)
347
+ assert q2.mul(q3) == Quaternion(-5*y - 4, 3*y - 2, 9 - 2*y, y + 4)
348
+ assert q2.mul(q3) == q2*q3
349
+
350
+ z = symbols('z', complex=True)
351
+ z_quat = Quaternion(re(z), im(z), 0, 0)
352
+ q = Quaternion(*symbols('q:4', real=True))
353
+
354
+ assert z * q == z_quat * q
355
+ assert q * z == q * z_quat
356
+
357
+
358
+ def test_issue_16318():
359
+ #for rtruediv
360
+ q0 = Quaternion(0, 0, 0, 0)
361
+ raises(ValueError, lambda: 1/q0)
362
+ #for rotate_point
363
+ q = Quaternion(1, 2, 3, 4)
364
+ (axis, angle) = q.to_axis_angle()
365
+ assert Quaternion.rotate_point((1, 1, 1), (axis, angle)) == (S.One / 5, 1, S(7) / 5)
366
+ #test for to_axis_angle
367
+ q = Quaternion(-1, 1, 1, 1)
368
+ axis = (-sqrt(3)/3, -sqrt(3)/3, -sqrt(3)/3)
369
+ angle = 2*pi/3
370
+ assert (axis, angle) == q.to_axis_angle()
371
+
372
+
373
+ @slow
374
+ def test_to_euler():
375
+ q = Quaternion(w, x, y, z)
376
+ q_normalized = q.normalize()
377
+
378
+ seqs = ['zxy', 'zyx', 'zyz', 'zxz']
379
+ seqs += [seq.upper() for seq in seqs]
380
+
381
+ for seq in seqs:
382
+ euler_from_q = q.to_euler(seq)
383
+ q_back = simplify(Quaternion.from_euler(euler_from_q, seq))
384
+ assert q_back == q_normalized
385
+
386
+
387
+ def test_to_euler_iss24504():
388
+ """
389
+ There was a mistake in the degenerate case testing
390
+ See issue 24504 for reference.
391
+ """
392
+ q = Quaternion.from_euler((phi, 0, 0), 'zyz')
393
+ assert trigsimp(q.to_euler('zyz'), inverse=True) == (phi, 0, 0)
394
+
395
+
396
+ def test_to_euler_numerical_singilarities():
397
+
398
+ def test_one_case(angles, seq):
399
+ q = Quaternion.from_euler(angles, seq)
400
+ assert q.to_euler(seq) == angles
401
+
402
+ # symmetric
403
+ test_one_case((pi/2, 0, 0), 'zyz')
404
+ test_one_case((pi/2, 0, 0), 'ZYZ')
405
+ test_one_case((pi/2, pi, 0), 'zyz')
406
+ test_one_case((pi/2, pi, 0), 'ZYZ')
407
+
408
+ # asymmetric
409
+ test_one_case((pi/2, pi/2, 0), 'zyx')
410
+ test_one_case((pi/2, -pi/2, 0), 'zyx')
411
+ test_one_case((pi/2, pi/2, 0), 'ZYX')
412
+ test_one_case((pi/2, -pi/2, 0), 'ZYX')
413
+
414
+
415
+ @slow
416
+ def test_to_euler_options():
417
+ def test_one_case(q):
418
+ angles1 = Matrix(q.to_euler(seq, True, True))
419
+ angles2 = Matrix(q.to_euler(seq, False, False))
420
+ angle_errors = simplify(angles1-angles2).evalf()
421
+ for angle_error in angle_errors:
422
+ # forcing angles to set {-pi, pi}
423
+ angle_error = (angle_error + pi) % (2 * pi) - pi
424
+ assert angle_error < 10e-7
425
+
426
+ for xyz in ('xyz', 'XYZ'):
427
+ for seq_tuple in permutations(xyz):
428
+ for symmetric in (True, False):
429
+ if symmetric:
430
+ seq = ''.join([seq_tuple[0], seq_tuple[1], seq_tuple[0]])
431
+ else:
432
+ seq = ''.join(seq_tuple)
433
+
434
+ for elements in product([-1, 0, 1], repeat=4):
435
+ q = Quaternion(*elements)
436
+ if not q.is_zero_quaternion():
437
+ test_one_case(q)
.venv/lib/python3.13/site-packages/sympy/logic/algorithms/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementation of DPLL algorithm
2
+
3
+ Further improvements: eliminate calls to pl_true, implement branching rules,
4
+ efficient unit propagation.
5
+
6
+ References:
7
+ - https://en.wikipedia.org/wiki/DPLL_algorithm
8
+ - https://www.researchgate.net/publication/242384772_Implementations_of_the_DPLL_Algorithm
9
+ """
10
+
11
+ from sympy.core.sorting import default_sort_key
12
+ from sympy.logic.boolalg import Or, Not, conjuncts, disjuncts, to_cnf, \
13
+ to_int_repr, _find_predicates
14
+ from sympy.assumptions.cnf import CNF
15
+ from sympy.logic.inference import pl_true, literal_symbol
16
+
17
+
18
+ def dpll_satisfiable(expr):
19
+ """
20
+ Check satisfiability of a propositional sentence.
21
+ It returns a model rather than True when it succeeds
22
+
23
+ >>> from sympy.abc import A, B
24
+ >>> from sympy.logic.algorithms.dpll import dpll_satisfiable
25
+ >>> dpll_satisfiable(A & ~B)
26
+ {A: True, B: False}
27
+ >>> dpll_satisfiable(A & ~A)
28
+ False
29
+
30
+ """
31
+ if not isinstance(expr, CNF):
32
+ clauses = conjuncts(to_cnf(expr))
33
+ else:
34
+ clauses = expr.clauses
35
+ if False in clauses:
36
+ return False
37
+ symbols = sorted(_find_predicates(expr), key=default_sort_key)
38
+ symbols_int_repr = set(range(1, len(symbols) + 1))
39
+ clauses_int_repr = to_int_repr(clauses, symbols)
40
+ result = dpll_int_repr(clauses_int_repr, symbols_int_repr, {})
41
+ if not result:
42
+ return result
43
+ output = {}
44
+ for key in result:
45
+ output.update({symbols[key - 1]: result[key]})
46
+ return output
47
+
48
+
49
+ def dpll(clauses, symbols, model):
50
+ """
51
+ Compute satisfiability in a partial model.
52
+ Clauses is an array of conjuncts.
53
+
54
+ >>> from sympy.abc import A, B, D
55
+ >>> from sympy.logic.algorithms.dpll import dpll
56
+ >>> dpll([A, B, D], [A, B], {D: False})
57
+ False
58
+
59
+ """
60
+ # compute DP kernel
61
+ P, value = find_unit_clause(clauses, model)
62
+ while P:
63
+ model.update({P: value})
64
+ symbols.remove(P)
65
+ if not value:
66
+ P = ~P
67
+ clauses = unit_propagate(clauses, P)
68
+ P, value = find_unit_clause(clauses, model)
69
+ P, value = find_pure_symbol(symbols, clauses)
70
+ while P:
71
+ model.update({P: value})
72
+ symbols.remove(P)
73
+ if not value:
74
+ P = ~P
75
+ clauses = unit_propagate(clauses, P)
76
+ P, value = find_pure_symbol(symbols, clauses)
77
+ # end DP kernel
78
+ unknown_clauses = []
79
+ for c in clauses:
80
+ val = pl_true(c, model)
81
+ if val is False:
82
+ return False
83
+ if val is not True:
84
+ unknown_clauses.append(c)
85
+ if not unknown_clauses:
86
+ return model
87
+ if not clauses:
88
+ return model
89
+ P = symbols.pop()
90
+ model_copy = model.copy()
91
+ model.update({P: True})
92
+ model_copy.update({P: False})
93
+ symbols_copy = symbols[:]
94
+ return (dpll(unit_propagate(unknown_clauses, P), symbols, model) or
95
+ dpll(unit_propagate(unknown_clauses, Not(P)), symbols_copy, model_copy))
96
+
97
+
98
+ def dpll_int_repr(clauses, symbols, model):
99
+ """
100
+ Compute satisfiability in a partial model.
101
+ Arguments are expected to be in integer representation
102
+
103
+ >>> from sympy.logic.algorithms.dpll import dpll_int_repr
104
+ >>> dpll_int_repr([{1}, {2}, {3}], {1, 2}, {3: False})
105
+ False
106
+
107
+ """
108
+ # compute DP kernel
109
+ P, value = find_unit_clause_int_repr(clauses, model)
110
+ while P:
111
+ model.update({P: value})
112
+ symbols.remove(P)
113
+ if not value:
114
+ P = -P
115
+ clauses = unit_propagate_int_repr(clauses, P)
116
+ P, value = find_unit_clause_int_repr(clauses, model)
117
+ P, value = find_pure_symbol_int_repr(symbols, clauses)
118
+ while P:
119
+ model.update({P: value})
120
+ symbols.remove(P)
121
+ if not value:
122
+ P = -P
123
+ clauses = unit_propagate_int_repr(clauses, P)
124
+ P, value = find_pure_symbol_int_repr(symbols, clauses)
125
+ # end DP kernel
126
+ unknown_clauses = []
127
+ for c in clauses:
128
+ val = pl_true_int_repr(c, model)
129
+ if val is False:
130
+ return False
131
+ if val is not True:
132
+ unknown_clauses.append(c)
133
+ if not unknown_clauses:
134
+ return model
135
+ P = symbols.pop()
136
+ model_copy = model.copy()
137
+ model.update({P: True})
138
+ model_copy.update({P: False})
139
+ symbols_copy = symbols.copy()
140
+ return (dpll_int_repr(unit_propagate_int_repr(unknown_clauses, P), symbols, model) or
141
+ dpll_int_repr(unit_propagate_int_repr(unknown_clauses, -P), symbols_copy, model_copy))
142
+
143
+ ### helper methods for DPLL
144
+
145
+
146
+ def pl_true_int_repr(clause, model={}):
147
+ """
148
+ Lightweight version of pl_true.
149
+ Argument clause represents the set of args of an Or clause. This is used
150
+ inside dpll_int_repr, it is not meant to be used directly.
151
+
152
+ >>> from sympy.logic.algorithms.dpll import pl_true_int_repr
153
+ >>> pl_true_int_repr({1, 2}, {1: False})
154
+ >>> pl_true_int_repr({1, 2}, {1: False, 2: False})
155
+ False
156
+
157
+ """
158
+ result = False
159
+ for lit in clause:
160
+ if lit < 0:
161
+ p = model.get(-lit)
162
+ if p is not None:
163
+ p = not p
164
+ else:
165
+ p = model.get(lit)
166
+ if p is True:
167
+ return True
168
+ elif p is None:
169
+ result = None
170
+ return result
171
+
172
+
173
+ def unit_propagate(clauses, symbol):
174
+ """
175
+ Returns an equivalent set of clauses
176
+ If a set of clauses contains the unit clause l, the other clauses are
177
+ simplified by the application of the two following rules:
178
+
179
+ 1. every clause containing l is removed
180
+ 2. in every clause that contains ~l this literal is deleted
181
+
182
+ Arguments are expected to be in CNF.
183
+
184
+ >>> from sympy.abc import A, B, D
185
+ >>> from sympy.logic.algorithms.dpll import unit_propagate
186
+ >>> unit_propagate([A | B, D | ~B, B], B)
187
+ [D, B]
188
+
189
+ """
190
+ output = []
191
+ for c in clauses:
192
+ if c.func != Or:
193
+ output.append(c)
194
+ continue
195
+ for arg in c.args:
196
+ if arg == ~symbol:
197
+ output.append(Or(*[x for x in c.args if x != ~symbol]))
198
+ break
199
+ if arg == symbol:
200
+ break
201
+ else:
202
+ output.append(c)
203
+ return output
204
+
205
+
206
+ def unit_propagate_int_repr(clauses, s):
207
+ """
208
+ Same as unit_propagate, but arguments are expected to be in integer
209
+ representation
210
+
211
+ >>> from sympy.logic.algorithms.dpll import unit_propagate_int_repr
212
+ >>> unit_propagate_int_repr([{1, 2}, {3, -2}, {2}], 2)
213
+ [{3}]
214
+
215
+ """
216
+ negated = {-s}
217
+ return [clause - negated for clause in clauses if s not in clause]
218
+
219
+
220
+ def find_pure_symbol(symbols, unknown_clauses):
221
+ """
222
+ Find a symbol and its value if it appears only as a positive literal
223
+ (or only as a negative) in clauses.
224
+
225
+ >>> from sympy.abc import A, B, D
226
+ >>> from sympy.logic.algorithms.dpll import find_pure_symbol
227
+ >>> find_pure_symbol([A, B, D], [A|~B,~B|~D,D|A])
228
+ (A, True)
229
+
230
+ """
231
+ for sym in symbols:
232
+ found_pos, found_neg = False, False
233
+ for c in unknown_clauses:
234
+ if not found_pos and sym in disjuncts(c):
235
+ found_pos = True
236
+ if not found_neg and Not(sym) in disjuncts(c):
237
+ found_neg = True
238
+ if found_pos != found_neg:
239
+ return sym, found_pos
240
+ return None, None
241
+
242
+
243
+ def find_pure_symbol_int_repr(symbols, unknown_clauses):
244
+ """
245
+ Same as find_pure_symbol, but arguments are expected
246
+ to be in integer representation
247
+
248
+ >>> from sympy.logic.algorithms.dpll import find_pure_symbol_int_repr
249
+ >>> find_pure_symbol_int_repr({1,2,3},
250
+ ... [{1, -2}, {-2, -3}, {3, 1}])
251
+ (1, True)
252
+
253
+ """
254
+ all_symbols = set().union(*unknown_clauses)
255
+ found_pos = all_symbols.intersection(symbols)
256
+ found_neg = all_symbols.intersection([-s for s in symbols])
257
+ for p in found_pos:
258
+ if -p not in found_neg:
259
+ return p, True
260
+ for p in found_neg:
261
+ if -p not in found_pos:
262
+ return -p, False
263
+ return None, None
264
+
265
+
266
+ def find_unit_clause(clauses, model):
267
+ """
268
+ A unit clause has only 1 variable that is not bound in the model.
269
+
270
+ >>> from sympy.abc import A, B, D
271
+ >>> from sympy.logic.algorithms.dpll import find_unit_clause
272
+ >>> find_unit_clause([A | B | D, B | ~D, A | ~B], {A:True})
273
+ (B, False)
274
+
275
+ """
276
+ for clause in clauses:
277
+ num_not_in_model = 0
278
+ for literal in disjuncts(clause):
279
+ sym = literal_symbol(literal)
280
+ if sym not in model:
281
+ num_not_in_model += 1
282
+ P, value = sym, not isinstance(literal, Not)
283
+ if num_not_in_model == 1:
284
+ return P, value
285
+ return None, None
286
+
287
+
288
+ def find_unit_clause_int_repr(clauses, model):
289
+ """
290
+ Same as find_unit_clause, but arguments are expected to be in
291
+ integer representation.
292
+
293
+ >>> from sympy.logic.algorithms.dpll import find_unit_clause_int_repr
294
+ >>> find_unit_clause_int_repr([{1, 2, 3},
295
+ ... {2, -3}, {1, -2}], {1: True})
296
+ (2, False)
297
+
298
+ """
299
+ bound = set(model) | {-sym for sym in model}
300
+ for clause in clauses:
301
+ unbound = clause - bound
302
+ if len(unbound) == 1:
303
+ p = unbound.pop()
304
+ if p < 0:
305
+ return -p, False
306
+ else:
307
+ return p, True
308
+ return None, None
.venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll2.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implementation of DPLL algorithm
2
+
3
+ Features:
4
+ - Clause learning
5
+ - Watch literal scheme
6
+ - VSIDS heuristic
7
+
8
+ References:
9
+ - https://en.wikipedia.org/wiki/DPLL_algorithm
10
+ """
11
+
12
+ from collections import defaultdict
13
+ from heapq import heappush, heappop
14
+
15
+ from sympy.core.sorting import ordered
16
+ from sympy.assumptions.cnf import EncodedCNF
17
+
18
+ from sympy.logic.algorithms.lra_theory import LRASolver
19
+
20
+
21
+ def dpll_satisfiable(expr, all_models=False, use_lra_theory=False):
22
+ """
23
+ Check satisfiability of a propositional sentence.
24
+ It returns a model rather than True when it succeeds.
25
+ Returns a generator of all models if all_models is True.
26
+
27
+ Examples
28
+ ========
29
+
30
+ >>> from sympy.abc import A, B
31
+ >>> from sympy.logic.algorithms.dpll2 import dpll_satisfiable
32
+ >>> dpll_satisfiable(A & ~B)
33
+ {A: True, B: False}
34
+ >>> dpll_satisfiable(A & ~A)
35
+ False
36
+
37
+ """
38
+ if not isinstance(expr, EncodedCNF):
39
+ exprs = EncodedCNF()
40
+ exprs.add_prop(expr)
41
+ expr = exprs
42
+
43
+ # Return UNSAT when False (encoded as 0) is present in the CNF
44
+ if {0} in expr.data:
45
+ if all_models:
46
+ return (f for f in [False])
47
+ return False
48
+
49
+ if use_lra_theory:
50
+ lra, immediate_conflicts = LRASolver.from_encoded_cnf(expr)
51
+ else:
52
+ lra = None
53
+ immediate_conflicts = []
54
+ solver = SATSolver(expr.data + immediate_conflicts, expr.variables, set(), expr.symbols, lra_theory=lra)
55
+ models = solver._find_model()
56
+
57
+ if all_models:
58
+ return _all_models(models)
59
+
60
+ try:
61
+ return next(models)
62
+ except StopIteration:
63
+ return False
64
+
65
+ # Uncomment to confirm the solution is valid (hitting set for the clauses)
66
+ #else:
67
+ #for cls in clauses_int_repr:
68
+ #assert solver.var_settings.intersection(cls)
69
+
70
+
71
+ def _all_models(models):
72
+ satisfiable = False
73
+ try:
74
+ while True:
75
+ yield next(models)
76
+ satisfiable = True
77
+ except StopIteration:
78
+ if not satisfiable:
79
+ yield False
80
+
81
+
82
+ class SATSolver:
83
+ """
84
+ Class for representing a SAT solver capable of
85
+ finding a model to a boolean theory in conjunctive
86
+ normal form.
87
+ """
88
+
89
+ def __init__(self, clauses, variables, var_settings, symbols=None,
90
+ heuristic='vsids', clause_learning='none', INTERVAL=500,
91
+ lra_theory = None):
92
+
93
+ self.var_settings = var_settings
94
+ self.heuristic = heuristic
95
+ self.is_unsatisfied = False
96
+ self._unit_prop_queue = []
97
+ self.update_functions = []
98
+ self.INTERVAL = INTERVAL
99
+
100
+ if symbols is None:
101
+ self.symbols = list(ordered(variables))
102
+ else:
103
+ self.symbols = symbols
104
+
105
+ self._initialize_variables(variables)
106
+ self._initialize_clauses(clauses)
107
+
108
+ if 'vsids' == heuristic:
109
+ self._vsids_init()
110
+ self.heur_calculate = self._vsids_calculate
111
+ self.heur_lit_assigned = self._vsids_lit_assigned
112
+ self.heur_lit_unset = self._vsids_lit_unset
113
+ self.heur_clause_added = self._vsids_clause_added
114
+
115
+ # Note: Uncomment this if/when clause learning is enabled
116
+ #self.update_functions.append(self._vsids_decay)
117
+
118
+ else:
119
+ raise NotImplementedError
120
+
121
+ if 'simple' == clause_learning:
122
+ self.add_learned_clause = self._simple_add_learned_clause
123
+ self.compute_conflict = self._simple_compute_conflict
124
+ self.update_functions.append(self._simple_clean_clauses)
125
+ elif 'none' == clause_learning:
126
+ self.add_learned_clause = lambda x: None
127
+ self.compute_conflict = lambda: None
128
+ else:
129
+ raise NotImplementedError
130
+
131
+ # Create the base level
132
+ self.levels = [Level(0)]
133
+ self._current_level.varsettings = var_settings
134
+
135
+ # Keep stats
136
+ self.num_decisions = 0
137
+ self.num_learned_clauses = 0
138
+ self.original_num_clauses = len(self.clauses)
139
+
140
+ self.lra = lra_theory
141
+
142
+ def _initialize_variables(self, variables):
143
+ """Set up the variable data structures needed."""
144
+ self.sentinels = defaultdict(set)
145
+ self.occurrence_count = defaultdict(int)
146
+ self.variable_set = [False] * (len(variables) + 1)
147
+
148
+ def _initialize_clauses(self, clauses):
149
+ """Set up the clause data structures needed.
150
+
151
+ For each clause, the following changes are made:
152
+ - Unit clauses are queued for propagation right away.
153
+ - Non-unit clauses have their first and last literals set as sentinels.
154
+ - The number of clauses a literal appears in is computed.
155
+ """
156
+ self.clauses = [list(clause) for clause in clauses]
157
+
158
+ for i, clause in enumerate(self.clauses):
159
+
160
+ # Handle the unit clauses
161
+ if 1 == len(clause):
162
+ self._unit_prop_queue.append(clause[0])
163
+ continue
164
+
165
+ self.sentinels[clause[0]].add(i)
166
+ self.sentinels[clause[-1]].add(i)
167
+
168
+ for lit in clause:
169
+ self.occurrence_count[lit] += 1
170
+
171
+ def _find_model(self):
172
+ """
173
+ Main DPLL loop. Returns a generator of models.
174
+
175
+ Variables are chosen successively, and assigned to be either
176
+ True or False. If a solution is not found with this setting,
177
+ the opposite is chosen and the search continues. The solver
178
+ halts when every variable has a setting.
179
+
180
+ Examples
181
+ ========
182
+
183
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
184
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
185
+ ... {3, -2}], {1, 2, 3}, set())
186
+ >>> list(l._find_model())
187
+ [{1: True, 2: False, 3: False}, {1: True, 2: True, 3: True}]
188
+
189
+ >>> from sympy.abc import A, B, C
190
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
191
+ ... {3, -2}], {1, 2, 3}, set(), [A, B, C])
192
+ >>> list(l._find_model())
193
+ [{A: True, B: False, C: False}, {A: True, B: True, C: True}]
194
+
195
+ """
196
+
197
+ # We use this variable to keep track of if we should flip a
198
+ # variable setting in successive rounds
199
+ flip_var = False
200
+
201
+ # Check if unit prop says the theory is unsat right off the bat
202
+ self._simplify()
203
+ if self.is_unsatisfied:
204
+ return
205
+
206
+ # While the theory still has clauses remaining
207
+ while True:
208
+ # Perform cleanup / fixup at regular intervals
209
+ if self.num_decisions % self.INTERVAL == 0:
210
+ for func in self.update_functions:
211
+ func()
212
+
213
+ if flip_var:
214
+ # We have just backtracked and we are trying to opposite literal
215
+ flip_var = False
216
+ lit = self._current_level.decision
217
+
218
+ else:
219
+ # Pick a literal to set
220
+ lit = self.heur_calculate()
221
+ self.num_decisions += 1
222
+
223
+ # Stopping condition for a satisfying theory
224
+ if 0 == lit:
225
+
226
+ # check if assignment satisfies lra theory
227
+ if self.lra:
228
+ for enc_var in self.var_settings:
229
+ res = self.lra.assert_lit(enc_var)
230
+ if res is not None:
231
+ break
232
+ res = self.lra.check()
233
+ self.lra.reset_bounds()
234
+ else:
235
+ res = None
236
+ if res is None or res[0]:
237
+ yield {self.symbols[abs(lit) - 1]:
238
+ lit > 0 for lit in self.var_settings}
239
+ else:
240
+ self._simple_add_learned_clause(res[1])
241
+
242
+ # backtrack until we unassign one of the literals causing the conflict
243
+ while not any(-lit in res[1] for lit in self._current_level.var_settings):
244
+ self._undo()
245
+
246
+ while self._current_level.flipped:
247
+ self._undo()
248
+ if len(self.levels) == 1:
249
+ return
250
+ flip_lit = -self._current_level.decision
251
+ self._undo()
252
+ self.levels.append(Level(flip_lit, flipped=True))
253
+ flip_var = True
254
+ continue
255
+
256
+ # Start the new decision level
257
+ self.levels.append(Level(lit))
258
+
259
+ # Assign the literal, updating the clauses it satisfies
260
+ self._assign_literal(lit)
261
+
262
+ # _simplify the theory
263
+ self._simplify()
264
+
265
+ # Check if we've made the theory unsat
266
+ if self.is_unsatisfied:
267
+
268
+ self.is_unsatisfied = False
269
+
270
+ # We unroll all of the decisions until we can flip a literal
271
+ while self._current_level.flipped:
272
+ self._undo()
273
+
274
+ # If we've unrolled all the way, the theory is unsat
275
+ if 1 == len(self.levels):
276
+ return
277
+
278
+ # Detect and add a learned clause
279
+ self.add_learned_clause(self.compute_conflict())
280
+
281
+ # Try the opposite setting of the most recent decision
282
+ flip_lit = -self._current_level.decision
283
+ self._undo()
284
+ self.levels.append(Level(flip_lit, flipped=True))
285
+ flip_var = True
286
+
287
+ ########################
288
+ # Helper Methods #
289
+ ########################
290
+ @property
291
+ def _current_level(self):
292
+ """The current decision level data structure
293
+
294
+ Examples
295
+ ========
296
+
297
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
298
+ >>> l = SATSolver([{1}, {2}], {1, 2}, set())
299
+ >>> next(l._find_model())
300
+ {1: True, 2: True}
301
+ >>> l._current_level.decision
302
+ 0
303
+ >>> l._current_level.flipped
304
+ False
305
+ >>> l._current_level.var_settings
306
+ {1, 2}
307
+
308
+ """
309
+ return self.levels[-1]
310
+
311
+ def _clause_sat(self, cls):
312
+ """Check if a clause is satisfied by the current variable setting.
313
+
314
+ Examples
315
+ ========
316
+
317
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
318
+ >>> l = SATSolver([{1}, {-1}], {1}, set())
319
+ >>> try:
320
+ ... next(l._find_model())
321
+ ... except StopIteration:
322
+ ... pass
323
+ >>> l._clause_sat(0)
324
+ False
325
+ >>> l._clause_sat(1)
326
+ True
327
+
328
+ """
329
+ for lit in self.clauses[cls]:
330
+ if lit in self.var_settings:
331
+ return True
332
+ return False
333
+
334
+ def _is_sentinel(self, lit, cls):
335
+ """Check if a literal is a sentinel of a given clause.
336
+
337
+ Examples
338
+ ========
339
+
340
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
341
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
342
+ ... {3, -2}], {1, 2, 3}, set())
343
+ >>> next(l._find_model())
344
+ {1: True, 2: False, 3: False}
345
+ >>> l._is_sentinel(2, 3)
346
+ True
347
+ >>> l._is_sentinel(-3, 1)
348
+ False
349
+
350
+ """
351
+ return cls in self.sentinels[lit]
352
+
353
+ def _assign_literal(self, lit):
354
+ """Make a literal assignment.
355
+
356
+ The literal assignment must be recorded as part of the current
357
+ decision level. Additionally, if the literal is marked as a
358
+ sentinel of any clause, then a new sentinel must be chosen. If
359
+ this is not possible, then unit propagation is triggered and
360
+ another literal is added to the queue to be set in the future.
361
+
362
+ Examples
363
+ ========
364
+
365
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
366
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
367
+ ... {3, -2}], {1, 2, 3}, set())
368
+ >>> next(l._find_model())
369
+ {1: True, 2: False, 3: False}
370
+ >>> l.var_settings
371
+ {-3, -2, 1}
372
+
373
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
374
+ ... {3, -2}], {1, 2, 3}, set())
375
+ >>> l._assign_literal(-1)
376
+ >>> try:
377
+ ... next(l._find_model())
378
+ ... except StopIteration:
379
+ ... pass
380
+ >>> l.var_settings
381
+ {-1}
382
+
383
+ """
384
+ self.var_settings.add(lit)
385
+ self._current_level.var_settings.add(lit)
386
+ self.variable_set[abs(lit)] = True
387
+ self.heur_lit_assigned(lit)
388
+
389
+ sentinel_list = list(self.sentinels[-lit])
390
+
391
+ for cls in sentinel_list:
392
+ if not self._clause_sat(cls):
393
+ other_sentinel = None
394
+ for newlit in self.clauses[cls]:
395
+ if newlit != -lit:
396
+ if self._is_sentinel(newlit, cls):
397
+ other_sentinel = newlit
398
+ elif not self.variable_set[abs(newlit)]:
399
+ self.sentinels[-lit].remove(cls)
400
+ self.sentinels[newlit].add(cls)
401
+ other_sentinel = None
402
+ break
403
+
404
+ # Check if no sentinel update exists
405
+ if other_sentinel:
406
+ self._unit_prop_queue.append(other_sentinel)
407
+
408
+ def _undo(self):
409
+ """
410
+ _undo the changes of the most recent decision level.
411
+
412
+ Examples
413
+ ========
414
+
415
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
416
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
417
+ ... {3, -2}], {1, 2, 3}, set())
418
+ >>> next(l._find_model())
419
+ {1: True, 2: False, 3: False}
420
+ >>> level = l._current_level
421
+ >>> level.decision, level.var_settings, level.flipped
422
+ (-3, {-3, -2}, False)
423
+ >>> l._undo()
424
+ >>> level = l._current_level
425
+ >>> level.decision, level.var_settings, level.flipped
426
+ (0, {1}, False)
427
+
428
+ """
429
+ # Undo the variable settings
430
+ for lit in self._current_level.var_settings:
431
+ self.var_settings.remove(lit)
432
+ self.heur_lit_unset(lit)
433
+ self.variable_set[abs(lit)] = False
434
+
435
+ # Pop the level off the stack
436
+ self.levels.pop()
437
+
438
+ #########################
439
+ # Propagation #
440
+ #########################
441
+ """
442
+ Propagation methods should attempt to soundly simplify the boolean
443
+ theory, and return True if any simplification occurred and False
444
+ otherwise.
445
+ """
446
+ def _simplify(self):
447
+ """Iterate over the various forms of propagation to simplify the theory.
448
+
449
+ Examples
450
+ ========
451
+
452
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
453
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
454
+ ... {3, -2}], {1, 2, 3}, set())
455
+ >>> l.variable_set
456
+ [False, False, False, False]
457
+ >>> l.sentinels
458
+ {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}}
459
+
460
+ >>> l._simplify()
461
+
462
+ >>> l.variable_set
463
+ [False, True, False, False]
464
+ >>> l.sentinels
465
+ {-3: {0, 2}, -2: {3, 4}, -1: set(), 2: {0, 3},
466
+ ...3: {2, 4}}
467
+
468
+ """
469
+ changed = True
470
+ while changed:
471
+ changed = False
472
+ changed |= self._unit_prop()
473
+ changed |= self._pure_literal()
474
+
475
+ def _unit_prop(self):
476
+ """Perform unit propagation on the current theory."""
477
+ result = len(self._unit_prop_queue) > 0
478
+ while self._unit_prop_queue:
479
+ next_lit = self._unit_prop_queue.pop()
480
+ if -next_lit in self.var_settings:
481
+ self.is_unsatisfied = True
482
+ self._unit_prop_queue = []
483
+ return False
484
+ else:
485
+ self._assign_literal(next_lit)
486
+
487
+ return result
488
+
489
+ def _pure_literal(self):
490
+ """Look for pure literals and assign them when found."""
491
+ return False
492
+
493
+ #########################
494
+ # Heuristics #
495
+ #########################
496
+ def _vsids_init(self):
497
+ """Initialize the data structures needed for the VSIDS heuristic."""
498
+ self.lit_heap = []
499
+ self.lit_scores = {}
500
+
501
+ for var in range(1, len(self.variable_set)):
502
+ self.lit_scores[var] = float(-self.occurrence_count[var])
503
+ self.lit_scores[-var] = float(-self.occurrence_count[-var])
504
+ heappush(self.lit_heap, (self.lit_scores[var], var))
505
+ heappush(self.lit_heap, (self.lit_scores[-var], -var))
506
+
507
+ def _vsids_decay(self):
508
+ """Decay the VSIDS scores for every literal.
509
+
510
+ Examples
511
+ ========
512
+
513
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
514
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
515
+ ... {3, -2}], {1, 2, 3}, set())
516
+
517
+ >>> l.lit_scores
518
+ {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0}
519
+
520
+ >>> l._vsids_decay()
521
+
522
+ >>> l.lit_scores
523
+ {-3: -1.0, -2: -1.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -1.0}
524
+
525
+ """
526
+ # We divide every literal score by 2 for a decay factor
527
+ # Note: This doesn't change the heap property
528
+ for lit in self.lit_scores.keys():
529
+ self.lit_scores[lit] /= 2.0
530
+
531
+ def _vsids_calculate(self):
532
+ """
533
+ VSIDS Heuristic Calculation
534
+
535
+ Examples
536
+ ========
537
+
538
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
539
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
540
+ ... {3, -2}], {1, 2, 3}, set())
541
+
542
+ >>> l.lit_heap
543
+ [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)]
544
+
545
+ >>> l._vsids_calculate()
546
+ -3
547
+
548
+ >>> l.lit_heap
549
+ [(-2.0, -2), (-2.0, 2), (0.0, -1), (0.0, 1), (-2.0, 3)]
550
+
551
+ """
552
+ if len(self.lit_heap) == 0:
553
+ return 0
554
+
555
+ # Clean out the front of the heap as long the variables are set
556
+ while self.variable_set[abs(self.lit_heap[0][1])]:
557
+ heappop(self.lit_heap)
558
+ if len(self.lit_heap) == 0:
559
+ return 0
560
+
561
+ return heappop(self.lit_heap)[1]
562
+
563
+ def _vsids_lit_assigned(self, lit):
564
+ """Handle the assignment of a literal for the VSIDS heuristic."""
565
+ pass
566
+
567
+ def _vsids_lit_unset(self, lit):
568
+ """Handle the unsetting of a literal for the VSIDS heuristic.
569
+
570
+ Examples
571
+ ========
572
+
573
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
574
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
575
+ ... {3, -2}], {1, 2, 3}, set())
576
+ >>> l.lit_heap
577
+ [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)]
578
+
579
+ >>> l._vsids_lit_unset(2)
580
+
581
+ >>> l.lit_heap
582
+ [(-2.0, -3), (-2.0, -2), (-2.0, -2), (-2.0, 2), (-2.0, 3), (0.0, -1),
583
+ ...(-2.0, 2), (0.0, 1)]
584
+
585
+ """
586
+ var = abs(lit)
587
+ heappush(self.lit_heap, (self.lit_scores[var], var))
588
+ heappush(self.lit_heap, (self.lit_scores[-var], -var))
589
+
590
+ def _vsids_clause_added(self, cls):
591
+ """Handle the addition of a new clause for the VSIDS heuristic.
592
+
593
+ Examples
594
+ ========
595
+
596
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
597
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
598
+ ... {3, -2}], {1, 2, 3}, set())
599
+
600
+ >>> l.num_learned_clauses
601
+ 0
602
+ >>> l.lit_scores
603
+ {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0}
604
+
605
+ >>> l._vsids_clause_added({2, -3})
606
+
607
+ >>> l.num_learned_clauses
608
+ 1
609
+ >>> l.lit_scores
610
+ {-3: -1.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -2.0}
611
+
612
+ """
613
+ self.num_learned_clauses += 1
614
+ for lit in cls:
615
+ self.lit_scores[lit] += 1
616
+
617
+ ########################
618
+ # Clause Learning #
619
+ ########################
620
+ def _simple_add_learned_clause(self, cls):
621
+ """Add a new clause to the theory.
622
+
623
+ Examples
624
+ ========
625
+
626
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
627
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
628
+ ... {3, -2}], {1, 2, 3}, set())
629
+
630
+ >>> l.num_learned_clauses
631
+ 0
632
+ >>> l.clauses
633
+ [[2, -3], [1], [3, -3], [2, -2], [3, -2]]
634
+ >>> l.sentinels
635
+ {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}}
636
+
637
+ >>> l._simple_add_learned_clause([3])
638
+
639
+ >>> l.clauses
640
+ [[2, -3], [1], [3, -3], [2, -2], [3, -2], [3]]
641
+ >>> l.sentinels
642
+ {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4, 5}}
643
+
644
+ """
645
+ cls_num = len(self.clauses)
646
+ self.clauses.append(cls)
647
+
648
+ for lit in cls:
649
+ self.occurrence_count[lit] += 1
650
+
651
+ self.sentinels[cls[0]].add(cls_num)
652
+ self.sentinels[cls[-1]].add(cls_num)
653
+
654
+ self.heur_clause_added(cls)
655
+
656
+ def _simple_compute_conflict(self):
657
+ """ Build a clause representing the fact that at least one decision made
658
+ so far is wrong.
659
+
660
+ Examples
661
+ ========
662
+
663
+ >>> from sympy.logic.algorithms.dpll2 import SATSolver
664
+ >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
665
+ ... {3, -2}], {1, 2, 3}, set())
666
+ >>> next(l._find_model())
667
+ {1: True, 2: False, 3: False}
668
+ >>> l._simple_compute_conflict()
669
+ [3]
670
+
671
+ """
672
+ return [-(level.decision) for level in self.levels[1:]]
673
+
674
+ def _simple_clean_clauses(self):
675
+ """Clean up learned clauses."""
676
+ pass
677
+
678
+
679
+ class Level:
680
+ """
681
+ Represents a single level in the DPLL algorithm, and contains
682
+ enough information for a sound backtracking procedure.
683
+ """
684
+
685
+ def __init__(self, decision, flipped=False):
686
+ self.decision = decision
687
+ self.var_settings = set()
688
+ self.flipped = flipped
.venv/lib/python3.13/site-packages/sympy/logic/algorithms/lra_theory.py ADDED
@@ -0,0 +1,912 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implements "A Fast Linear-Arithmetic Solver for DPLL(T)"
2
+
3
+ The LRASolver class defined in this file can be used
4
+ in conjunction with a SAT solver to check the
5
+ satisfiability of formulas involving inequalities.
6
+
7
+ Here's an example of how that would work:
8
+
9
+ Suppose you want to check the satisfiability of
10
+ the following formula:
11
+
12
+ >>> from sympy.core.relational import Eq
13
+ >>> from sympy.abc import x, y
14
+ >>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & (~Eq(y, 1) | Eq(1, 2))
15
+
16
+ First a preprocessing step should be done on f. During preprocessing,
17
+ f should be checked for any predicates such as `Q.prime` that can't be
18
+ handled. Also unequality like `~Eq(y, 1)` should be split.
19
+
20
+ I should mention that the paper says to split both equalities and
21
+ unequality, but this implementation only requires that unequality
22
+ be split.
23
+
24
+ >>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & ((y < 1) | (y > 1) | Eq(1, 2))
25
+
26
+ Then an LRASolver instance needs to be initialized with this formula.
27
+
28
+ >>> from sympy.assumptions.cnf import CNF, EncodedCNF
29
+ >>> from sympy.assumptions.ask import Q
30
+ >>> from sympy.logic.algorithms.lra_theory import LRASolver
31
+ >>> cnf = CNF.from_prop(f)
32
+ >>> enc = EncodedCNF()
33
+ >>> enc.add_from_cnf(cnf)
34
+ >>> lra, conflicts = LRASolver.from_encoded_cnf(enc)
35
+
36
+ Any immediate one-lital conflicts clauses will be detected here.
37
+ In this example, `~Eq(1, 2)` is one such conflict clause. We'll
38
+ want to add it to `f` so that the SAT solver is forced to
39
+ assign Eq(1, 2) to False.
40
+
41
+ >>> f = f & ~Eq(1, 2)
42
+
43
+ Now that the one-literal conflict clauses have been added
44
+ and an lra object has been initialized, we can pass `f`
45
+ to a SAT solver. The SAT solver will give us a satisfying
46
+ assignment such as:
47
+
48
+ (1 = 2): False
49
+ (y = 1): True
50
+ (y < 1): True
51
+ (y > 1): True
52
+ (x = 0): True
53
+ (x < 0): True
54
+ (x > 0): True
55
+
56
+ Next you would pass this assignment to the LRASolver
57
+ which will be able to determine that this particular
58
+ assignment is satisfiable or not.
59
+
60
+ Note that since EncodedCNF is inherently non-deterministic,
61
+ the int each predicate is encoded as is not consistent. As a
62
+ result, the code below likely does not reflect the assignment
63
+ given above.
64
+
65
+ >>> lra.assert_lit(-1) #doctest: +SKIP
66
+ >>> lra.assert_lit(2) #doctest: +SKIP
67
+ >>> lra.assert_lit(3) #doctest: +SKIP
68
+ >>> lra.assert_lit(4) #doctest: +SKIP
69
+ >>> lra.assert_lit(5) #doctest: +SKIP
70
+ >>> lra.assert_lit(6) #doctest: +SKIP
71
+ >>> lra.assert_lit(7) #doctest: +SKIP
72
+ >>> is_sat, conflict_or_assignment = lra.check()
73
+
74
+ As the particular assignment suggested is not satisfiable,
75
+ the LRASolver will return unsat and a conflict clause when
76
+ given that assignment. The conflict clause will always be
77
+ minimal, but there can be multiple minimal conflict clauses.
78
+ One possible conflict clause could be `~(x < 0) | ~(x > 0)`.
79
+
80
+ We would then add whatever conflict clause is given to
81
+ `f` to prevent the SAT solver from coming up with an
82
+ assignment with the same conflicting literals. In this case,
83
+ the conflict clause `~(x < 0) | ~(x > 0)` would prevent
84
+ any assignment where both (x < 0) and (x > 0) were both
85
+ true.
86
+
87
+ The SAT solver would then find another assignment
88
+ and we would check that assignment with the LRASolver
89
+ and so on. Eventually either a satisfying assignment
90
+ that the SAT solver and LRASolver agreed on would be found
91
+ or enough conflict clauses would be added so that the
92
+ boolean formula was unsatisfiable.
93
+
94
+
95
+ This implementation is based on [1]_, which includes a
96
+ detailed explanation of the algorithm and pseudocode
97
+ for the most important functions.
98
+
99
+ [1]_ also explains how backtracking and theory propagation
100
+ could be implemented to speed up the current implementation,
101
+ but these are not currently implemented.
102
+
103
+ TODO:
104
+ - Handle non-rational real numbers
105
+ - Handle positive and negative infinity
106
+ - Implement backtracking and theory proposition
107
+ - Simplify matrix by removing unused variables using Gaussian elimination
108
+
109
+ References
110
+ ==========
111
+
112
+ .. [1] Dutertre, B., de Moura, L.:
113
+ A Fast Linear-Arithmetic Solver for DPLL(T)
114
+ https://link.springer.com/chapter/10.1007/11817963_11
115
+ """
116
+ from sympy.solvers.solveset import linear_eq_to_matrix
117
+ from sympy.matrices.dense import eye
118
+ from sympy.assumptions import Predicate
119
+ from sympy.assumptions.assume import AppliedPredicate
120
+ from sympy.assumptions.ask import Q
121
+ from sympy.core import Dummy
122
+ from sympy.core.mul import Mul
123
+ from sympy.core.add import Add
124
+ from sympy.core.relational import Eq, Ne
125
+ from sympy.core.sympify import sympify
126
+ from sympy.core.singleton import S
127
+ from sympy.core.numbers import Rational, oo
128
+ from sympy.matrices.dense import Matrix
129
+
130
+ class UnhandledInput(Exception):
131
+ """
132
+ Raised while creating an LRASolver if non-linearity
133
+ or non-rational numbers are present.
134
+ """
135
+
136
+ # predicates that LRASolver understands and makes use of
137
+ ALLOWED_PRED = {Q.eq, Q.gt, Q.lt, Q.le, Q.ge}
138
+
139
+ # if true ~Q.gt(x, y) implies Q.le(x, y)
140
+ HANDLE_NEGATION = True
141
+
142
+ class LRASolver():
143
+ """
144
+ Linear Arithmetic Solver for DPLL(T) implemented with an algorithm based on
145
+ the Dual Simplex method. Uses Bland's pivoting rule to avoid cycling.
146
+
147
+ References
148
+ ==========
149
+
150
+ .. [1] Dutertre, B., de Moura, L.:
151
+ A Fast Linear-Arithmetic Solver for DPLL(T)
152
+ https://link.springer.com/chapter/10.1007/11817963_11
153
+ """
154
+
155
+ def __init__(self, A, slack_variables, nonslack_variables, enc_to_boundary, s_subs, testing_mode):
156
+ """
157
+ Use the "from_encoded_cnf" method to create a new LRASolver.
158
+ """
159
+ self.run_checks = testing_mode
160
+ self.s_subs = s_subs # used only for test_lra_theory.test_random_problems
161
+
162
+ if any(not isinstance(a, Rational) for a in A):
163
+ raise UnhandledInput("Non-rational numbers are not handled")
164
+ if any(not isinstance(b.bound, Rational) for b in enc_to_boundary.values()):
165
+ raise UnhandledInput("Non-rational numbers are not handled")
166
+ m, n = len(slack_variables), len(slack_variables)+len(nonslack_variables)
167
+ if m != 0:
168
+ assert A.shape == (m, n)
169
+ if self.run_checks:
170
+ assert A[:, n-m:] == -eye(m)
171
+
172
+ self.enc_to_boundary = enc_to_boundary # mapping of int to Boundary objects
173
+ self.boundary_to_enc = {value: key for key, value in enc_to_boundary.items()}
174
+ self.A = A
175
+ self.slack = slack_variables
176
+ self.nonslack = nonslack_variables
177
+ self.all_var = nonslack_variables + slack_variables
178
+
179
+ self.slack_set = set(slack_variables)
180
+
181
+ self.is_sat = True # While True, all constraints asserted so far are satisfiable
182
+ self.result = None # always one of: (True, assignment), (False, conflict clause), None
183
+
184
+ @staticmethod
185
+ def from_encoded_cnf(encoded_cnf, testing_mode=False):
186
+ """
187
+ Creates an LRASolver from an EncodedCNF object
188
+ and a list of conflict clauses for propositions
189
+ that can be simplified to True or False.
190
+
191
+ Parameters
192
+ ==========
193
+
194
+ encoded_cnf : EncodedCNF
195
+
196
+ testing_mode : bool
197
+ Setting testing_mode to True enables some slow assert statements
198
+ and sorting to reduce nonterministic behavior.
199
+
200
+ Returns
201
+ =======
202
+
203
+ (lra, conflicts)
204
+
205
+ lra : LRASolver
206
+
207
+ conflicts : list
208
+ Contains a one-literal conflict clause for each proposition
209
+ that can be simplified to True or False.
210
+
211
+ Example
212
+ =======
213
+
214
+ >>> from sympy.core.relational import Eq
215
+ >>> from sympy.assumptions.cnf import CNF, EncodedCNF
216
+ >>> from sympy.assumptions.ask import Q
217
+ >>> from sympy.logic.algorithms.lra_theory import LRASolver
218
+ >>> from sympy.abc import x, y, z
219
+ >>> phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6))
220
+ >>> phi = phi & (Eq(x + y, 2) | (x + 2 * y - z > 4))
221
+ >>> phi = phi & Q.gt(2, 1)
222
+ >>> cnf = CNF.from_prop(phi)
223
+ >>> enc = EncodedCNF()
224
+ >>> enc.from_cnf(cnf)
225
+ >>> lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True)
226
+ >>> lra #doctest: +SKIP
227
+ <sympy.logic.algorithms.lra_theory.LRASolver object at 0x7fdcb0e15b70>
228
+ >>> conflicts #doctest: +SKIP
229
+ [[4]]
230
+ """
231
+ # This function has three main jobs:
232
+ # - raise errors if the input formula is not handled
233
+ # - preprocesses the formula into a matrix and single variable constraints
234
+ # - create one-literal conflict clauses from predicates that are always True
235
+ # or always False such as Q.gt(3, 2)
236
+ #
237
+ # See the preprocessing section of "A Fast Linear-Arithmetic Solver for DPLL(T)"
238
+ # for an explanation of how the formula is converted into a matrix
239
+ # and a set of single variable constraints.
240
+
241
+ encoding = {} # maps int to boundary
242
+ A = []
243
+
244
+ basic = []
245
+ s_count = 0
246
+ s_subs = {}
247
+ nonbasic = []
248
+
249
+ if testing_mode:
250
+ # sort to reduce nondeterminism
251
+ encoded_cnf_items = sorted(encoded_cnf.encoding.items(), key=lambda x: str(x))
252
+ else:
253
+ encoded_cnf_items = encoded_cnf.encoding.items()
254
+
255
+ empty_var = Dummy()
256
+ var_to_lra_var = {}
257
+ conflicts = []
258
+
259
+ for prop, enc in encoded_cnf_items:
260
+ if isinstance(prop, Predicate):
261
+ prop = prop(empty_var)
262
+ if not isinstance(prop, AppliedPredicate):
263
+ if prop == True:
264
+ conflicts.append([enc])
265
+ continue
266
+ if prop == False:
267
+ conflicts.append([-enc])
268
+ continue
269
+
270
+ raise ValueError(f"Unhandled Predicate: {prop}")
271
+
272
+ assert prop.function in ALLOWED_PRED
273
+ if prop.lhs == S.NaN or prop.rhs == S.NaN:
274
+ raise ValueError(f"{prop} contains nan")
275
+ if prop.lhs.is_imaginary or prop.rhs.is_imaginary:
276
+ raise UnhandledInput(f"{prop} contains an imaginary component")
277
+ if prop.lhs == oo or prop.rhs == oo:
278
+ raise UnhandledInput(f"{prop} contains infinity")
279
+
280
+ prop = _eval_binrel(prop) # simplify variable-less quantities to True / False if possible
281
+ if prop == True:
282
+ conflicts.append([enc])
283
+ continue
284
+ elif prop == False:
285
+ conflicts.append([-enc])
286
+ continue
287
+ elif prop is None:
288
+ raise UnhandledInput(f"{prop} could not be simplified")
289
+
290
+ expr = prop.lhs - prop.rhs
291
+ if prop.function in [Q.ge, Q.gt]:
292
+ expr = -expr
293
+
294
+ # expr should be less than (or equal to) 0
295
+ # otherwise prop is False
296
+ if prop.function in [Q.le, Q.ge]:
297
+ bool = (expr <= 0)
298
+ elif prop.function in [Q.lt, Q.gt]:
299
+ bool = (expr < 0)
300
+ else:
301
+ assert prop.function == Q.eq
302
+ bool = Eq(expr, 0)
303
+
304
+ if bool == True:
305
+ conflicts.append([enc])
306
+ continue
307
+ elif bool == False:
308
+ conflicts.append([-enc])
309
+ continue
310
+
311
+
312
+ vars, const = _sep_const_terms(expr) # example: (2x + 3y + 2) --> (2x + 3y), (2)
313
+ vars, var_coeff = _sep_const_coeff(vars) # examples: (2x) --> (x, 2); (2x + 3y) --> (2x + 3y), (1)
314
+ const = const / var_coeff
315
+
316
+ terms = _list_terms(vars) # example: (2x + 3y) --> [2x, 3y]
317
+ for term in terms:
318
+ term, _ = _sep_const_coeff(term)
319
+ assert len(term.free_symbols) > 0
320
+ if term not in var_to_lra_var:
321
+ var_to_lra_var[term] = LRAVariable(term)
322
+ nonbasic.append(term)
323
+
324
+ if len(terms) > 1:
325
+ if vars not in s_subs:
326
+ s_count += 1
327
+ d = Dummy(f"s{s_count}")
328
+ var_to_lra_var[d] = LRAVariable(d)
329
+ basic.append(d)
330
+ s_subs[vars] = d
331
+ A.append(vars - d)
332
+ var = s_subs[vars]
333
+ else:
334
+ var = terms[0]
335
+
336
+ assert var_coeff != 0
337
+
338
+ equality = prop.function == Q.eq
339
+ upper = var_coeff > 0 if not equality else None
340
+ strict = prop.function in [Q.gt, Q.lt]
341
+ b = Boundary(var_to_lra_var[var], -const, upper, equality, strict)
342
+ encoding[enc] = b
343
+
344
+ fs = [v.free_symbols for v in nonbasic + basic]
345
+ assert all(len(syms) > 0 for syms in fs)
346
+ fs_count = sum(len(syms) for syms in fs)
347
+ if len(fs) > 0 and len(set.union(*fs)) < fs_count:
348
+ raise UnhandledInput("Nonlinearity is not handled")
349
+
350
+ A, _ = linear_eq_to_matrix(A, nonbasic + basic)
351
+ nonbasic = [var_to_lra_var[nb] for nb in nonbasic]
352
+ basic = [var_to_lra_var[b] for b in basic]
353
+ for idx, var in enumerate(nonbasic + basic):
354
+ var.col_idx = idx
355
+
356
+ return LRASolver(A, basic, nonbasic, encoding, s_subs, testing_mode), conflicts
357
+
358
+ def reset_bounds(self):
359
+ """
360
+ Resets the state of the LRASolver to before
361
+ anything was asserted.
362
+ """
363
+ self.result = None
364
+ for var in self.all_var:
365
+ var.lower = LRARational(-float("inf"), 0)
366
+ var.lower_from_eq = False
367
+ var.lower_from_neg = False
368
+ var.upper = LRARational(float("inf"), 0)
369
+ var.upper_from_eq= False
370
+ var.lower_from_neg = False
371
+ var.assign = LRARational(0, 0)
372
+
373
+ def assert_lit(self, enc_constraint):
374
+ """
375
+ Assert a literal representing a constraint
376
+ and update the internal state accordingly.
377
+
378
+ Note that due to peculiarities of this implementation
379
+ asserting ~(x > 0) will assert (x <= 0) but asserting
380
+ ~Eq(x, 0) will not do anything.
381
+
382
+ Parameters
383
+ ==========
384
+
385
+ enc_constraint : int
386
+ A mapping of encodings to constraints
387
+ can be found in `self.enc_to_boundary`.
388
+
389
+ Returns
390
+ =======
391
+
392
+ None or (False, explanation)
393
+
394
+ explanation : set of ints
395
+ A conflict clause that "explains" why
396
+ the literals asserted so far are unsatisfiable.
397
+ """
398
+ if abs(enc_constraint) not in self.enc_to_boundary:
399
+ return None
400
+
401
+ if not HANDLE_NEGATION and enc_constraint < 0:
402
+ return None
403
+
404
+ boundary = self.enc_to_boundary[abs(enc_constraint)]
405
+ sym, c, negated = boundary.var, boundary.bound, enc_constraint < 0
406
+
407
+ if boundary.equality and negated:
408
+ return None # negated equality is not handled and should only appear in conflict clauses
409
+
410
+ upper = boundary.upper != negated
411
+ if boundary.strict != negated:
412
+ delta = -1 if upper else 1
413
+ c = LRARational(c, delta)
414
+ else:
415
+ c = LRARational(c, 0)
416
+
417
+ if boundary.equality:
418
+ res1 = self._assert_lower(sym, c, from_equality=True, from_neg=negated)
419
+ if res1 and res1[0] == False:
420
+ res = res1
421
+ else:
422
+ res2 = self._assert_upper(sym, c, from_equality=True, from_neg=negated)
423
+ res = res2
424
+ elif upper:
425
+ res = self._assert_upper(sym, c, from_neg=negated)
426
+ else:
427
+ res = self._assert_lower(sym, c, from_neg=negated)
428
+
429
+ if self.is_sat and sym not in self.slack_set:
430
+ self.is_sat = res is None
431
+ else:
432
+ self.is_sat = False
433
+
434
+ return res
435
+
436
+ def _assert_upper(self, xi, ci, from_equality=False, from_neg=False):
437
+ """
438
+ Adjusts the upper bound on variable xi if the new upper bound is
439
+ more limiting. The assignment of variable xi is adjusted to be
440
+ within the new bound if needed.
441
+
442
+ Also calls `self._update` to update the assignment for slack variables
443
+ to keep all equalities satisfied.
444
+ """
445
+ if self.result:
446
+ assert self.result[0] != False
447
+ self.result = None
448
+ if ci >= xi.upper:
449
+ return None
450
+ if ci < xi.lower:
451
+ assert (xi.lower[1] >= 0) is True
452
+ assert (ci[1] <= 0) is True
453
+
454
+ lit1, neg1 = Boundary.from_lower(xi)
455
+
456
+ lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=True, equality=from_equality)
457
+ if from_neg:
458
+ lit2 = lit2.get_negated()
459
+ neg2 = -1 if from_neg else 1
460
+
461
+ conflict = [-neg1*self.boundary_to_enc[lit1], -neg2*self.boundary_to_enc[lit2]]
462
+ self.result = False, conflict
463
+ return self.result
464
+ xi.upper = ci
465
+ xi.upper_from_eq = from_equality
466
+ xi.upper_from_neg = from_neg
467
+ if xi in self.nonslack and xi.assign > ci:
468
+ self._update(xi, ci)
469
+
470
+ if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
471
+ for v in self.all_var):
472
+ M = self.A
473
+ X = Matrix([v.assign[0] for v in self.all_var])
474
+ assert all(abs(val) < 10 ** (-10) for val in M * X)
475
+
476
+ return None
477
+
478
+ def _assert_lower(self, xi, ci, from_equality=False, from_neg=False):
479
+ """
480
+ Adjusts the lower bound on variable xi if the new lower bound is
481
+ more limiting. The assignment of variable xi is adjusted to be
482
+ within the new bound if needed.
483
+
484
+ Also calls `self._update` to update the assignment for slack variables
485
+ to keep all equalities satisfied.
486
+ """
487
+ if self.result:
488
+ assert self.result[0] != False
489
+ self.result = None
490
+ if ci <= xi.lower:
491
+ return None
492
+ if ci > xi.upper:
493
+ assert (xi.upper[1] <= 0) is True
494
+ assert (ci[1] >= 0) is True
495
+
496
+ lit1, neg1 = Boundary.from_upper(xi)
497
+
498
+ lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=False, equality=from_equality)
499
+ if from_neg:
500
+ lit2 = lit2.get_negated()
501
+ neg2 = -1 if from_neg else 1
502
+
503
+ conflict = [-neg1*self.boundary_to_enc[lit1],-neg2*self.boundary_to_enc[lit2]]
504
+ self.result = False, conflict
505
+ return self.result
506
+ xi.lower = ci
507
+ xi.lower_from_eq = from_equality
508
+ xi.lower_from_neg = from_neg
509
+ if xi in self.nonslack and xi.assign < ci:
510
+ self._update(xi, ci)
511
+
512
+ if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
513
+ for v in self.all_var):
514
+ M = self.A
515
+ X = Matrix([v.assign[0] for v in self.all_var])
516
+ assert all(abs(val) < 10 ** (-10) for val in M * X)
517
+
518
+ return None
519
+
520
+ def _update(self, xi, v):
521
+ """
522
+ Updates all slack variables that have equations that contain
523
+ variable xi so that they stay satisfied given xi is equal to v.
524
+ """
525
+ i = xi.col_idx
526
+ for j, b in enumerate(self.slack):
527
+ aji = self.A[j, i]
528
+ b.assign = b.assign + (v - xi.assign)*aji
529
+ xi.assign = v
530
+
531
+ def check(self):
532
+ """
533
+ Searches for an assignment that satisfies all constraints
534
+ or determines that no such assignment exists and gives
535
+ a minimal conflict clause that "explains" why the
536
+ constraints are unsatisfiable.
537
+
538
+ Returns
539
+ =======
540
+
541
+ (True, assignment) or (False, explanation)
542
+
543
+ assignment : dict of LRAVariables to values
544
+ Assigned values are tuples that represent a rational number
545
+ plus some infinatesimal delta.
546
+
547
+ explanation : set of ints
548
+ """
549
+ if self.is_sat:
550
+ return True, {var: var.assign for var in self.all_var}
551
+ if self.result:
552
+ return self.result
553
+
554
+ from sympy.matrices.dense import Matrix
555
+ M = self.A.copy()
556
+ basic = {s: i for i, s in enumerate(self.slack)} # contains the row index associated with each basic variable
557
+ nonbasic = set(self.nonslack)
558
+ while True:
559
+ if self.run_checks:
560
+ # nonbasic variables must always be within bounds
561
+ assert all(((nb.assign >= nb.lower) == True) and ((nb.assign <= nb.upper) == True) for nb in nonbasic)
562
+
563
+ # assignments for x must always satisfy Ax = 0
564
+ # probably have to turn this off when dealing with strict ineq
565
+ if all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
566
+ for v in self.all_var):
567
+ X = Matrix([v.assign[0] for v in self.all_var])
568
+ assert all(abs(val) < 10**(-10) for val in M*X)
569
+
570
+ # check upper and lower match this format:
571
+ # x <= rat + delta iff x < rat
572
+ # x >= rat - delta iff x > rat
573
+ # this wouldn't make sense:
574
+ # x <= rat - delta
575
+ # x >= rat + delta
576
+ assert all(x.upper[1] <= 0 for x in self.all_var)
577
+ assert all(x.lower[1] >= 0 for x in self.all_var)
578
+
579
+ cand = [b for b in basic if b.assign < b.lower or b.assign > b.upper]
580
+
581
+ if len(cand) == 0:
582
+ return True, {var: var.assign for var in self.all_var}
583
+
584
+ xi = min(cand, key=lambda v: v.col_idx) # Bland's rule
585
+ i = basic[xi]
586
+
587
+ if xi.assign < xi.lower:
588
+ cand = [nb for nb in nonbasic
589
+ if (M[i, nb.col_idx] > 0 and nb.assign < nb.upper)
590
+ or (M[i, nb.col_idx] < 0 and nb.assign > nb.lower)]
591
+ if len(cand) == 0:
592
+ N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0]
593
+ N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0]
594
+
595
+ conflict = []
596
+ conflict += [Boundary.from_upper(nb) for nb in N_plus]
597
+ conflict += [Boundary.from_lower(nb) for nb in N_minus]
598
+ conflict.append(Boundary.from_lower(xi))
599
+ conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict]
600
+ return False, conflict
601
+ xj = min(cand, key=str)
602
+ M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.lower)
603
+
604
+ if xi.assign > xi.upper:
605
+ cand = [nb for nb in nonbasic
606
+ if (M[i, nb.col_idx] < 0 and nb.assign < nb.upper)
607
+ or (M[i, nb.col_idx] > 0 and nb.assign > nb.lower)]
608
+
609
+ if len(cand) == 0:
610
+ N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0]
611
+ N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0]
612
+
613
+ conflict = []
614
+ conflict += [Boundary.from_upper(nb) for nb in N_minus]
615
+ conflict += [Boundary.from_lower(nb) for nb in N_plus]
616
+ conflict.append(Boundary.from_upper(xi))
617
+
618
+ conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict]
619
+ return False, conflict
620
+ xj = min(cand, key=lambda v: v.col_idx)
621
+ M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.upper)
622
+
623
+ def _pivot_and_update(self, M, basic, nonbasic, xi, xj, v):
624
+ """
625
+ Pivots basic variable xi with nonbasic variable xj,
626
+ and sets value of xi to v and adjusts the values of all basic variables
627
+ to keep equations satisfied.
628
+ """
629
+ i, j = basic[xi], xj.col_idx
630
+ assert M[i, j] != 0
631
+ theta = (v - xi.assign)*(1/M[i, j])
632
+ xi.assign = v
633
+ xj.assign = xj.assign + theta
634
+ for xk in basic:
635
+ if xk != xi:
636
+ k = basic[xk]
637
+ akj = M[k, j]
638
+ xk.assign = xk.assign + theta*akj
639
+ # pivot
640
+ basic[xj] = basic[xi]
641
+ del basic[xi]
642
+ nonbasic.add(xi)
643
+ nonbasic.remove(xj)
644
+ return self._pivot(M, i, j)
645
+
646
+ @staticmethod
647
+ def _pivot(M, i, j):
648
+ """
649
+ Performs a pivot operation about entry i, j of M by performing
650
+ a series of row operations on a copy of M and returning the result.
651
+ The original M is left unmodified.
652
+
653
+ Conceptually, M represents a system of equations and pivoting
654
+ can be thought of as rearranging equation i to be in terms of
655
+ variable j and then substituting in the rest of the equations
656
+ to get rid of other occurances of variable j.
657
+
658
+ Example
659
+ =======
660
+
661
+ >>> from sympy.matrices.dense import Matrix
662
+ >>> from sympy.logic.algorithms.lra_theory import LRASolver
663
+ >>> from sympy import var
664
+ >>> Matrix(3, 3, var('a:i'))
665
+ Matrix([
666
+ [a, b, c],
667
+ [d, e, f],
668
+ [g, h, i]])
669
+
670
+ This matrix is equivalent to:
671
+ 0 = a*x + b*y + c*z
672
+ 0 = d*x + e*y + f*z
673
+ 0 = g*x + h*y + i*z
674
+
675
+ >>> LRASolver._pivot(_, 1, 0)
676
+ Matrix([
677
+ [ 0, -a*e/d + b, -a*f/d + c],
678
+ [-1, -e/d, -f/d],
679
+ [ 0, h - e*g/d, i - f*g/d]])
680
+
681
+ We rearrange equation 1 in terms of variable 0 (x)
682
+ and substitute to remove x from the other equations.
683
+
684
+ 0 = 0 + (-a*e/d + b)*y + (-a*f/d + c)*z
685
+ 0 = -x + (-e/d)*y + (-f/d)*z
686
+ 0 = 0 + (h - e*g/d)*y + (i - f*g/d)*z
687
+ """
688
+ _, _, Mij = M[i, :], M[:, j], M[i, j]
689
+ if Mij == 0:
690
+ raise ZeroDivisionError("Tried to pivot about zero-valued entry.")
691
+ A = M.copy()
692
+ A[i, :] = -A[i, :]/Mij
693
+ for row in range(M.shape[0]):
694
+ if row != i:
695
+ A[row, :] = A[row, :] + A[row, j] * A[i, :]
696
+
697
+ return A
698
+
699
+
700
+ def _sep_const_coeff(expr):
701
+ """
702
+ Example
703
+ =======
704
+
705
+ >>> from sympy.logic.algorithms.lra_theory import _sep_const_coeff
706
+ >>> from sympy.abc import x, y
707
+ >>> _sep_const_coeff(2*x)
708
+ (x, 2)
709
+ >>> _sep_const_coeff(2*x + 3*y)
710
+ (2*x + 3*y, 1)
711
+ """
712
+ if isinstance(expr, Add):
713
+ return expr, sympify(1)
714
+
715
+ if isinstance(expr, Mul):
716
+ coeffs = expr.args
717
+ else:
718
+ coeffs = [expr]
719
+
720
+ var, const = [], []
721
+ for c in coeffs:
722
+ c = sympify(c)
723
+ if len(c.free_symbols)==0:
724
+ const.append(c)
725
+ else:
726
+ var.append(c)
727
+ return Mul(*var), Mul(*const)
728
+
729
+
730
+ def _list_terms(expr):
731
+ if not isinstance(expr, Add):
732
+ return [expr]
733
+
734
+ return expr.args
735
+
736
+
737
+ def _sep_const_terms(expr):
738
+ """
739
+ Example
740
+ =======
741
+
742
+ >>> from sympy.logic.algorithms.lra_theory import _sep_const_terms
743
+ >>> from sympy.abc import x, y
744
+ >>> _sep_const_terms(2*x + 3*y + 2)
745
+ (2*x + 3*y, 2)
746
+ """
747
+ if isinstance(expr, Add):
748
+ terms = expr.args
749
+ else:
750
+ terms = [expr]
751
+
752
+ var, const = [], []
753
+ for t in terms:
754
+ if len(t.free_symbols) == 0:
755
+ const.append(t)
756
+ else:
757
+ var.append(t)
758
+ return sum(var), sum(const)
759
+
760
+
761
+ def _eval_binrel(binrel):
762
+ """
763
+ Simplify binary relation to True / False if possible.
764
+ """
765
+ if not (len(binrel.lhs.free_symbols) == 0 and len(binrel.rhs.free_symbols) == 0):
766
+ return binrel
767
+ if binrel.function == Q.lt:
768
+ res = binrel.lhs < binrel.rhs
769
+ elif binrel.function == Q.gt:
770
+ res = binrel.lhs > binrel.rhs
771
+ elif binrel.function == Q.le:
772
+ res = binrel.lhs <= binrel.rhs
773
+ elif binrel.function == Q.ge:
774
+ res = binrel.lhs >= binrel.rhs
775
+ elif binrel.function == Q.eq:
776
+ res = Eq(binrel.lhs, binrel.rhs)
777
+ elif binrel.function == Q.ne:
778
+ res = Ne(binrel.lhs, binrel.rhs)
779
+
780
+ if res == True or res == False:
781
+ return res
782
+ else:
783
+ return None
784
+
785
+
786
+ class Boundary:
787
+ """
788
+ Represents an upper or lower bound or an equality between a symbol
789
+ and some constant.
790
+ """
791
+ def __init__(self, var, const, upper, equality, strict=None):
792
+ if not equality in [True, False]:
793
+ assert equality in [True, False]
794
+
795
+
796
+ self.var = var
797
+ if isinstance(const, tuple):
798
+ s = const[1] != 0
799
+ if strict:
800
+ assert s == strict
801
+ self.bound = const[0]
802
+ self.strict = s
803
+ else:
804
+ self.bound = const
805
+ self.strict = strict
806
+ self.upper = upper if not equality else None
807
+ self.equality = equality
808
+ self.strict = strict
809
+ assert self.strict is not None
810
+
811
+ @staticmethod
812
+ def from_upper(var):
813
+ neg = -1 if var.upper_from_neg else 1
814
+ b = Boundary(var, var.upper[0], True, var.upper_from_eq, var.upper[1] != 0)
815
+ if neg < 0:
816
+ b = b.get_negated()
817
+ return b, neg
818
+
819
+ @staticmethod
820
+ def from_lower(var):
821
+ neg = -1 if var.lower_from_neg else 1
822
+ b = Boundary(var, var.lower[0], False, var.lower_from_eq, var.lower[1] != 0)
823
+ if neg < 0:
824
+ b = b.get_negated()
825
+ return b, neg
826
+
827
+ def get_negated(self):
828
+ return Boundary(self.var, self.bound, not self.upper, self.equality, not self.strict)
829
+
830
+ def get_inequality(self):
831
+ if self.equality:
832
+ return Eq(self.var.var, self.bound)
833
+ elif self.upper and self.strict:
834
+ return self.var.var < self.bound
835
+ elif not self.upper and self.strict:
836
+ return self.var.var > self.bound
837
+ elif self.upper:
838
+ return self.var.var <= self.bound
839
+ else:
840
+ return self.var.var >= self.bound
841
+
842
+ def __repr__(self):
843
+ return repr("Boundary(" + repr(self.get_inequality()) + ")")
844
+
845
+ def __eq__(self, other):
846
+ other = (other.var, other.bound, other.strict, other.upper, other.equality)
847
+ return (self.var, self.bound, self.strict, self.upper, self.equality) == other
848
+
849
+ def __hash__(self):
850
+ return hash((self.var, self.bound, self.strict, self.upper, self.equality))
851
+
852
+
853
+ class LRARational():
854
+ """
855
+ Represents a rational plus or minus some amount
856
+ of arbitrary small deltas.
857
+ """
858
+ def __init__(self, rational, delta):
859
+ self.value = (rational, delta)
860
+
861
+ def __lt__(self, other):
862
+ return self.value < other.value
863
+
864
+ def __le__(self, other):
865
+ return self.value <= other.value
866
+
867
+ def __eq__(self, other):
868
+ return self.value == other.value
869
+
870
+ def __add__(self, other):
871
+ return LRARational(self.value[0] + other.value[0], self.value[1] + other.value[1])
872
+
873
+ def __sub__(self, other):
874
+ return LRARational(self.value[0] - other.value[0], self.value[1] - other.value[1])
875
+
876
+ def __mul__(self, other):
877
+ assert not isinstance(other, LRARational)
878
+ return LRARational(self.value[0] * other, self.value[1] * other)
879
+
880
+ def __getitem__(self, index):
881
+ return self.value[index]
882
+
883
+ def __repr__(self):
884
+ return repr(self.value)
885
+
886
+
887
+ class LRAVariable():
888
+ """
889
+ Object to keep track of upper and lower bounds
890
+ on `self.var`.
891
+ """
892
+ def __init__(self, var):
893
+ self.upper = LRARational(float("inf"), 0)
894
+ self.upper_from_eq = False
895
+ self.upper_from_neg = False
896
+ self.lower = LRARational(-float("inf"), 0)
897
+ self.lower_from_eq = False
898
+ self.lower_from_neg = False
899
+ self.assign = LRARational(0,0)
900
+ self.var = var
901
+ self.col_idx = None
902
+
903
+ def __repr__(self):
904
+ return repr(self.var)
905
+
906
+ def __eq__(self, other):
907
+ if not isinstance(other, LRAVariable):
908
+ return False
909
+ return other.var == self.var
910
+
911
+ def __hash__(self):
912
+ return hash(self.var)
.venv/lib/python3.13/site-packages/sympy/logic/algorithms/minisat22_wrapper.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.assumptions.cnf import EncodedCNF
2
+
3
+ def minisat22_satisfiable(expr, all_models=False, minimal=False):
4
+
5
+ if not isinstance(expr, EncodedCNF):
6
+ exprs = EncodedCNF()
7
+ exprs.add_prop(expr)
8
+ expr = exprs
9
+
10
+ from pysat.solvers import Minisat22
11
+
12
+ # Return UNSAT when False (encoded as 0) is present in the CNF
13
+ if {0} in expr.data:
14
+ if all_models:
15
+ return (f for f in [False])
16
+ return False
17
+
18
+ r = Minisat22(expr.data)
19
+
20
+ if minimal:
21
+ r.set_phases([-(i+1) for i in range(r.nof_vars())])
22
+
23
+ if not r.solve():
24
+ return False
25
+
26
+ if not all_models:
27
+ return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r.get_model()}
28
+
29
+ else:
30
+ # Make solutions SymPy compatible by creating a generator
31
+ def _gen(results):
32
+ satisfiable = False
33
+ while results.solve():
34
+ sol = results.get_model()
35
+ yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol}
36
+ if minimal:
37
+ results.add_clause([-i for i in sol if i>0])
38
+ else:
39
+ results.add_clause([-i for i in sol])
40
+ satisfiable = True
41
+ if not satisfiable:
42
+ yield False
43
+ raise StopIteration
44
+
45
+
46
+ return _gen(r)
.venv/lib/python3.13/site-packages/sympy/logic/algorithms/pycosat_wrapper.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.assumptions.cnf import EncodedCNF
2
+
3
+
4
+ def pycosat_satisfiable(expr, all_models=False):
5
+ import pycosat
6
+ if not isinstance(expr, EncodedCNF):
7
+ exprs = EncodedCNF()
8
+ exprs.add_prop(expr)
9
+ expr = exprs
10
+
11
+ # Return UNSAT when False (encoded as 0) is present in the CNF
12
+ if {0} in expr.data:
13
+ if all_models:
14
+ return (f for f in [False])
15
+ return False
16
+
17
+ if not all_models:
18
+ r = pycosat.solve(expr.data)
19
+ result = (r != "UNSAT")
20
+ if not result:
21
+ return result
22
+ return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r}
23
+ else:
24
+ r = pycosat.itersolve(expr.data)
25
+ result = (r != "UNSAT")
26
+ if not result:
27
+ return result
28
+
29
+ # Make solutions SymPy compatible by creating a generator
30
+ def _gen(results):
31
+ satisfiable = False
32
+ try:
33
+ while True:
34
+ sol = next(results)
35
+ yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol}
36
+ satisfiable = True
37
+ except StopIteration:
38
+ if not satisfiable:
39
+ yield False
40
+
41
+ return _gen(r)
.venv/lib/python3.13/site-packages/sympy/logic/algorithms/z3_wrapper.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.printing.smtlib import smtlib_code
2
+ from sympy.assumptions.assume import AppliedPredicate
3
+ from sympy.assumptions.cnf import EncodedCNF
4
+ from sympy.assumptions.ask import Q
5
+
6
+ from sympy.core import Add, Mul
7
+ from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan
8
+ from sympy.functions.elementary.complexes import Abs
9
+ from sympy.functions.elementary.exponential import Pow
10
+ from sympy.functions.elementary.miscellaneous import Min, Max
11
+ from sympy.logic.boolalg import And, Or, Xor, Implies
12
+ from sympy.logic.boolalg import Not, ITE
13
+ from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate
14
+ from sympy.external import import_module
15
+
16
+ def z3_satisfiable(expr, all_models=False):
17
+ if not isinstance(expr, EncodedCNF):
18
+ exprs = EncodedCNF()
19
+ exprs.add_prop(expr)
20
+ expr = exprs
21
+
22
+ z3 = import_module("z3")
23
+ if z3 is None:
24
+ raise ImportError("z3 is not installed")
25
+
26
+ s = encoded_cnf_to_z3_solver(expr, z3)
27
+
28
+ res = str(s.check())
29
+ if res == "unsat":
30
+ return False
31
+ elif res == "sat":
32
+ return z3_model_to_sympy_model(s.model(), expr)
33
+ else:
34
+ return None
35
+
36
+
37
+ def z3_model_to_sympy_model(z3_model, enc_cnf):
38
+ rev_enc = {value : key for key, value in enc_cnf.encoding.items()}
39
+ return {rev_enc[int(var.name()[1:])] : bool(z3_model[var]) for var in z3_model}
40
+
41
+
42
+ def clause_to_assertion(clause):
43
+ clause_strings = [f"d{abs(lit)}" if lit > 0 else f"(not d{abs(lit)})" for lit in clause]
44
+ return "(assert (or " + " ".join(clause_strings) + "))"
45
+
46
+
47
+ def encoded_cnf_to_z3_solver(enc_cnf, z3):
48
+ def dummify_bool(pred):
49
+ return False
50
+ assert isinstance(pred, AppliedPredicate)
51
+
52
+ if pred.function in [Q.positive, Q.negative, Q.zero]:
53
+ return pred
54
+ else:
55
+ return False
56
+
57
+ s = z3.Solver()
58
+
59
+ declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables]
60
+ assertions = [clause_to_assertion(clause) for clause in enc_cnf.data]
61
+
62
+ symbols = set()
63
+ for pred, enc in enc_cnf.encoding.items():
64
+ if not isinstance(pred, AppliedPredicate):
65
+ continue
66
+ if pred.function not in (Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq, Q.positive, Q.negative, Q.extended_negative, Q.extended_positive, Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive, Q.extended_nonzero, Q.extended_nonnegative, Q.extended_nonpositive):
67
+ continue
68
+
69
+ pred_str = smtlib_code(pred, auto_declare=False, auto_assert=False, known_functions=known_functions)
70
+
71
+ symbols |= pred.free_symbols
72
+ pred = pred_str
73
+ clause = f"(implies d{enc} {pred})"
74
+ assertion = "(assert " + clause + ")"
75
+ assertions.append(assertion)
76
+
77
+ for sym in symbols:
78
+ declarations.append(f"(declare-const {sym} Real)")
79
+
80
+ declarations = "\n".join(declarations)
81
+ assertions = "\n".join(assertions)
82
+ s.from_string(declarations)
83
+ s.from_string(assertions)
84
+
85
+ return s
86
+
87
+
88
+ known_functions = {
89
+ Add: '+',
90
+ Mul: '*',
91
+
92
+ Equality: '=',
93
+ LessThan: '<=',
94
+ GreaterThan: '>=',
95
+ StrictLessThan: '<',
96
+ StrictGreaterThan: '>',
97
+
98
+ EqualityPredicate(): '=',
99
+ LessThanPredicate(): '<=',
100
+ GreaterThanPredicate(): '>=',
101
+ StrictLessThanPredicate(): '<',
102
+ StrictGreaterThanPredicate(): '>',
103
+
104
+ Abs: 'abs',
105
+ Min: 'min',
106
+ Max: 'max',
107
+ Pow: '^',
108
+
109
+ And: 'and',
110
+ Or: 'or',
111
+ Xor: 'xor',
112
+ Not: 'not',
113
+ ITE: 'ite',
114
+ Implies: '=>',
115
+ }
.venv/lib/python3.13/site-packages/sympy/logic/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/logic/tests/test_boolalg.py ADDED
@@ -0,0 +1,1367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.assumptions.ask import Q
2
+ from sympy.assumptions.refine import refine
3
+ from sympy.core.numbers import oo
4
+ from sympy.core.relational import Equality, Eq, Ne
5
+ from sympy.core.singleton import S
6
+ from sympy.core.symbol import (Dummy, symbols)
7
+ from sympy.functions import Piecewise
8
+ from sympy.functions.elementary.trigonometric import cos, sin
9
+ from sympy.sets.sets import Interval, Union
10
+ from sympy.sets.contains import Contains
11
+ from sympy.simplify.simplify import simplify
12
+ from sympy.logic.boolalg import (
13
+ And, Boolean, Equivalent, ITE, Implies, Nand, Nor, Not, Or,
14
+ POSform, SOPform, Xor, Xnor, conjuncts, disjuncts,
15
+ distribute_or_over_and, distribute_and_over_or,
16
+ eliminate_implications, is_nnf, is_cnf, is_dnf, simplify_logic,
17
+ to_nnf, to_cnf, to_dnf, to_int_repr, bool_map, true, false,
18
+ BooleanAtom, is_literal, term_to_integer,
19
+ truth_table, as_Boolean, to_anf, is_anf, distribute_xor_over_and,
20
+ anf_coeffs, ANFform, bool_minterm, bool_maxterm, bool_monomial,
21
+ _check_pair, _convert_to_varsSOP, _convert_to_varsPOS, Exclusive,
22
+ gateinputcount)
23
+ from sympy.assumptions.cnf import CNF
24
+
25
+ from sympy.testing.pytest import raises, XFAIL, slow
26
+
27
+ from itertools import combinations, permutations, product
28
+
29
+ A, B, C, D = symbols('A:D')
30
+ a, b, c, d, e, w, x, y, z = symbols('a:e w:z')
31
+
32
+
33
+ def test_overloading():
34
+ """Test that |, & are overloaded as expected"""
35
+
36
+ assert A & B == And(A, B)
37
+ assert A | B == Or(A, B)
38
+ assert (A & B) | C == Or(And(A, B), C)
39
+ assert A >> B == Implies(A, B)
40
+ assert A << B == Implies(B, A)
41
+ assert ~A == Not(A)
42
+ assert A ^ B == Xor(A, B)
43
+
44
+
45
+ def test_And():
46
+ assert And() is true
47
+ assert And(A) == A
48
+ assert And(True) is true
49
+ assert And(False) is false
50
+ assert And(True, True) is true
51
+ assert And(True, False) is false
52
+ assert And(False, False) is false
53
+ assert And(True, A) == A
54
+ assert And(False, A) is false
55
+ assert And(True, True, True) is true
56
+ assert And(True, True, A) == A
57
+ assert And(True, False, A) is false
58
+ assert And(1, A) == A
59
+ raises(TypeError, lambda: And(2, A))
60
+ assert And(A < 1, A >= 1) is false
61
+ e = A > 1
62
+ assert And(e, e.canonical) == e.canonical
63
+ g, l, ge, le = A > B, B < A, A >= B, B <= A
64
+ assert And(g, l, ge, le) == And(ge, g)
65
+ assert {And(*i) for i in permutations((l, g, le, ge))} == {And(ge, g)}
66
+ assert And(And(Eq(a, 0), Eq(b, 0)), And(Ne(a, 0), Eq(c, 0))) is false
67
+
68
+
69
+ def test_Or():
70
+ assert Or() is false
71
+ assert Or(A) == A
72
+ assert Or(True) is true
73
+ assert Or(False) is false
74
+ assert Or(True, True) is true
75
+ assert Or(True, False) is true
76
+ assert Or(False, False) is false
77
+ assert Or(True, A) is true
78
+ assert Or(False, A) == A
79
+ assert Or(True, False, False) is true
80
+ assert Or(True, False, A) is true
81
+ assert Or(False, False, A) == A
82
+ assert Or(1, A) is true
83
+ raises(TypeError, lambda: Or(2, A))
84
+ assert Or(A < 1, A >= 1) is true
85
+ e = A > 1
86
+ assert Or(e, e.canonical) == e
87
+ g, l, ge, le = A > B, B < A, A >= B, B <= A
88
+ assert Or(g, l, ge, le) == Or(g, ge)
89
+
90
+
91
+ def test_Xor():
92
+ assert Xor() is false
93
+ assert Xor(A) == A
94
+ assert Xor(A, A) is false
95
+ assert Xor(True, A, A) is true
96
+ assert Xor(A, A, A, A, A) == A
97
+ assert Xor(True, False, False, A, B) == ~Xor(A, B)
98
+ assert Xor(True) is true
99
+ assert Xor(False) is false
100
+ assert Xor(True, True) is false
101
+ assert Xor(True, False) is true
102
+ assert Xor(False, False) is false
103
+ assert Xor(True, A) == ~A
104
+ assert Xor(False, A) == A
105
+ assert Xor(True, False, False) is true
106
+ assert Xor(True, False, A) == ~A
107
+ assert Xor(False, False, A) == A
108
+ assert isinstance(Xor(A, B), Xor)
109
+ assert Xor(A, B, Xor(C, D)) == Xor(A, B, C, D)
110
+ assert Xor(A, B, Xor(B, C)) == Xor(A, C)
111
+ assert Xor(A < 1, A >= 1, B) == Xor(0, 1, B) == Xor(1, 0, B)
112
+ e = A > 1
113
+ assert Xor(e, e.canonical) == Xor(0, 0) == Xor(1, 1)
114
+
115
+
116
+ def test_rewrite_as_And():
117
+ expr = x ^ y
118
+ assert expr.rewrite(And) == (x | y) & (~x | ~y)
119
+
120
+
121
+ def test_rewrite_as_Or():
122
+ expr = x ^ y
123
+ assert expr.rewrite(Or) == (x & ~y) | (y & ~x)
124
+
125
+
126
+ def test_rewrite_as_Nand():
127
+ expr = (y & z) | (z & ~w)
128
+ assert expr.rewrite(Nand) == ~(~(y & z) & ~(z & ~w))
129
+
130
+
131
+ def test_rewrite_as_Nor():
132
+ expr = z & (y | ~w)
133
+ assert expr.rewrite(Nor) == ~(~z | ~(y | ~w))
134
+
135
+
136
+ def test_Not():
137
+ raises(TypeError, lambda: Not(True, False))
138
+ assert Not(True) is false
139
+ assert Not(False) is true
140
+ assert Not(0) is true
141
+ assert Not(1) is false
142
+ assert Not(2) is false
143
+
144
+
145
+ def test_Nand():
146
+ assert Nand() is false
147
+ assert Nand(A) == ~A
148
+ assert Nand(True) is false
149
+ assert Nand(False) is true
150
+ assert Nand(True, True) is false
151
+ assert Nand(True, False) is true
152
+ assert Nand(False, False) is true
153
+ assert Nand(True, A) == ~A
154
+ assert Nand(False, A) is true
155
+ assert Nand(True, True, True) is false
156
+ assert Nand(True, True, A) == ~A
157
+ assert Nand(True, False, A) is true
158
+
159
+
160
+ def test_Nor():
161
+ assert Nor() is true
162
+ assert Nor(A) == ~A
163
+ assert Nor(True) is false
164
+ assert Nor(False) is true
165
+ assert Nor(True, True) is false
166
+ assert Nor(True, False) is false
167
+ assert Nor(False, False) is true
168
+ assert Nor(True, A) is false
169
+ assert Nor(False, A) == ~A
170
+ assert Nor(True, True, True) is false
171
+ assert Nor(True, True, A) is false
172
+ assert Nor(True, False, A) is false
173
+
174
+
175
+ def test_Xnor():
176
+ assert Xnor() is true
177
+ assert Xnor(A) == ~A
178
+ assert Xnor(A, A) is true
179
+ assert Xnor(True, A, A) is false
180
+ assert Xnor(A, A, A, A, A) == ~A
181
+ assert Xnor(True) is false
182
+ assert Xnor(False) is true
183
+ assert Xnor(True, True) is true
184
+ assert Xnor(True, False) is false
185
+ assert Xnor(False, False) is true
186
+ assert Xnor(True, A) == A
187
+ assert Xnor(False, A) == ~A
188
+ assert Xnor(True, False, False) is false
189
+ assert Xnor(True, False, A) == A
190
+ assert Xnor(False, False, A) == ~A
191
+
192
+
193
+ def test_Implies():
194
+ raises(ValueError, lambda: Implies(A, B, C))
195
+ assert Implies(True, True) is true
196
+ assert Implies(True, False) is false
197
+ assert Implies(False, True) is true
198
+ assert Implies(False, False) is true
199
+ assert Implies(0, A) is true
200
+ assert Implies(1, 1) is true
201
+ assert Implies(1, 0) is false
202
+ assert A >> B == B << A
203
+ assert (A < 1) >> (A >= 1) == (A >= 1)
204
+ assert (A < 1) >> (S.One > A) is true
205
+ assert A >> A is true
206
+
207
+
208
+ def test_Equivalent():
209
+ assert Equivalent(A, B) == Equivalent(B, A) == Equivalent(A, B, A)
210
+ assert Equivalent() is true
211
+ assert Equivalent(A, A) == Equivalent(A) is true
212
+ assert Equivalent(True, True) == Equivalent(False, False) is true
213
+ assert Equivalent(True, False) == Equivalent(False, True) is false
214
+ assert Equivalent(A, True) == A
215
+ assert Equivalent(A, False) == Not(A)
216
+ assert Equivalent(A, B, True) == A & B
217
+ assert Equivalent(A, B, False) == ~A & ~B
218
+ assert Equivalent(1, A) == A
219
+ assert Equivalent(0, A) == Not(A)
220
+ assert Equivalent(A, Equivalent(B, C)) != Equivalent(Equivalent(A, B), C)
221
+ assert Equivalent(A < 1, A >= 1) is false
222
+ assert Equivalent(A < 1, A >= 1, 0) is false
223
+ assert Equivalent(A < 1, A >= 1, 1) is false
224
+ assert Equivalent(A < 1, S.One > A) == Equivalent(1, 1) == Equivalent(0, 0)
225
+ assert Equivalent(Equality(A, B), Equality(B, A)) is true
226
+
227
+
228
+ def test_Exclusive():
229
+ assert Exclusive(False, False, False) is true
230
+ assert Exclusive(True, False, False) is true
231
+ assert Exclusive(True, True, False) is false
232
+ assert Exclusive(True, True, True) is false
233
+
234
+
235
+ def test_equals():
236
+ assert Not(Or(A, B)).equals(And(Not(A), Not(B))) is True
237
+ assert Equivalent(A, B).equals((A >> B) & (B >> A)) is True
238
+ assert ((A | ~B) & (~A | B)).equals((~A & ~B) | (A & B)) is True
239
+ assert (A >> B).equals(~A >> ~B) is False
240
+ assert (A >> (B >> A)).equals(A >> (C >> A)) is False
241
+ raises(NotImplementedError, lambda: (A & B).equals(A > B))
242
+
243
+
244
+ def test_simplification_boolalg():
245
+ """
246
+ Test working of simplification methods.
247
+ """
248
+ set1 = [[0, 0, 1], [0, 1, 1], [1, 0, 0], [1, 1, 0]]
249
+ set2 = [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]]
250
+ assert SOPform([x, y, z], set1) == Or(And(Not(x), z), And(Not(z), x))
251
+ assert Not(SOPform([x, y, z], set2)) == \
252
+ Not(Or(And(Not(x), Not(z)), And(x, z)))
253
+ assert POSform([x, y, z], set1 + set2) is true
254
+ assert SOPform([x, y, z], set1 + set2) is true
255
+ assert SOPform([Dummy(), Dummy(), Dummy()], set1 + set2) is true
256
+
257
+ minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1],
258
+ [1, 1, 1, 1]]
259
+ dontcares = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]]
260
+ assert (
261
+ SOPform([w, x, y, z], minterms, dontcares) ==
262
+ Or(And(y, z), And(Not(w), Not(x))))
263
+ assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z)
264
+
265
+ minterms = [1, 3, 7, 11, 15]
266
+ dontcares = [0, 2, 5]
267
+ assert (
268
+ SOPform([w, x, y, z], minterms, dontcares) ==
269
+ Or(And(y, z), And(Not(w), Not(x))))
270
+ assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z)
271
+
272
+ minterms = [1, [0, 0, 1, 1], 7, [1, 0, 1, 1],
273
+ [1, 1, 1, 1]]
274
+ dontcares = [0, [0, 0, 1, 0], 5]
275
+ assert (
276
+ SOPform([w, x, y, z], minterms, dontcares) ==
277
+ Or(And(y, z), And(Not(w), Not(x))))
278
+ assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z)
279
+
280
+ minterms = [1, {y: 1, z: 1}]
281
+ dontcares = [0, [0, 0, 1, 0], 5]
282
+ assert (
283
+ SOPform([w, x, y, z], minterms, dontcares) ==
284
+ Or(And(y, z), And(Not(w), Not(x))))
285
+ assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z)
286
+
287
+ minterms = [{y: 1, z: 1}, 1]
288
+ dontcares = [[0, 0, 0, 0]]
289
+
290
+ minterms = [[0, 0, 0]]
291
+ raises(ValueError, lambda: SOPform([w, x, y, z], minterms))
292
+ raises(ValueError, lambda: POSform([w, x, y, z], minterms))
293
+
294
+ raises(TypeError, lambda: POSform([w, x, y, z], ["abcdefg"]))
295
+
296
+ # test simplification
297
+ ans = And(A, Or(B, C))
298
+ assert simplify_logic(A & (B | C)) == ans
299
+ assert simplify_logic((A & B) | (A & C)) == ans
300
+ assert simplify_logic(Implies(A, B)) == Or(Not(A), B)
301
+ assert simplify_logic(Equivalent(A, B)) == \
302
+ Or(And(A, B), And(Not(A), Not(B)))
303
+ assert simplify_logic(And(Equality(A, 2), C)) == And(Equality(A, 2), C)
304
+ assert simplify_logic(And(Equality(A, 2), A)) == And(Equality(A, 2), A)
305
+ assert simplify_logic(And(Equality(A, B), C)) == And(Equality(A, B), C)
306
+ assert simplify_logic(Or(And(Equality(A, 3), B), And(Equality(A, 3), C))) \
307
+ == And(Equality(A, 3), Or(B, C))
308
+ b = (~x & ~y & ~z) | (~x & ~y & z)
309
+ e = And(A, b)
310
+ assert simplify_logic(e) == A & ~x & ~y
311
+ raises(ValueError, lambda: simplify_logic(A & (B | C), form='blabla'))
312
+ assert simplify(Or(x <= y, And(x < y, z))) == (x <= y)
313
+ assert simplify(Or(x <= y, And(y > x, z))) == (x <= y)
314
+ assert simplify(Or(x >= y, And(y < x, z))) == (x >= y)
315
+
316
+ # Check that expressions with nine variables or more are not simplified
317
+ # (without the force-flag)
318
+ a, b, c, d, e, f, g, h, j = symbols('a b c d e f g h j')
319
+ expr = a & b & c & d & e & f & g & h & j | \
320
+ a & b & c & d & e & f & g & h & ~j
321
+ # This expression can be simplified to get rid of the j variables
322
+ assert simplify_logic(expr) == expr
323
+
324
+ # Test dontcare
325
+ assert simplify_logic((a & b) | c | d, dontcare=(a & b)) == c | d
326
+
327
+ # check input
328
+ ans = SOPform([x, y], [[1, 0]])
329
+ assert SOPform([x, y], [[1, 0]]) == ans
330
+ assert POSform([x, y], [[1, 0]]) == ans
331
+
332
+ raises(ValueError, lambda: SOPform([x], [[1]], [[1]]))
333
+ assert SOPform([x], [[1]], [[0]]) is true
334
+ assert SOPform([x], [[0]], [[1]]) is true
335
+ assert SOPform([x], [], []) is false
336
+
337
+ raises(ValueError, lambda: POSform([x], [[1]], [[1]]))
338
+ assert POSform([x], [[1]], [[0]]) is true
339
+ assert POSform([x], [[0]], [[1]]) is true
340
+ assert POSform([x], [], []) is false
341
+
342
+ # check working of simplify
343
+ assert simplify((A & B) | (A & C)) == And(A, Or(B, C))
344
+ assert simplify(And(x, Not(x))) == False
345
+ assert simplify(Or(x, Not(x))) == True
346
+ assert simplify(And(Eq(x, 0), Eq(x, y))) == And(Eq(x, 0), Eq(y, 0))
347
+ assert And(Eq(x - 1, 0), Eq(x, y)).simplify() == And(Eq(x, 1), Eq(y, 1))
348
+ assert And(Ne(x - 1, 0), Ne(x, y)).simplify() == And(Ne(x, 1), Ne(x, y))
349
+ assert And(Eq(x - 1, 0), Ne(x, y)).simplify() == And(Eq(x, 1), Ne(y, 1))
350
+ assert And(Eq(x - 1, 0), Eq(x, z + y), Eq(y + x, 0)).simplify(
351
+ ) == And(Eq(x, 1), Eq(y, -1), Eq(z, 2))
352
+ assert And(Eq(x - 1, 0), Eq(x + 2, 3)).simplify() == Eq(x, 1)
353
+ assert And(Ne(x - 1, 0), Ne(x + 2, 3)).simplify() == Ne(x, 1)
354
+ assert And(Eq(x - 1, 0), Eq(x + 2, 2)).simplify() == False
355
+ assert And(Ne(x - 1, 0), Ne(x + 2, 2)).simplify(
356
+ ) == And(Ne(x, 1), Ne(x, 0))
357
+ assert simplify(Xor(x, ~x)) == True
358
+
359
+
360
+ def test_bool_map():
361
+ """
362
+ Test working of bool_map function.
363
+ """
364
+
365
+ minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1],
366
+ [1, 1, 1, 1]]
367
+ assert bool_map(Not(Not(a)), a) == (a, {a: a})
368
+ assert bool_map(SOPform([w, x, y, z], minterms),
369
+ POSform([w, x, y, z], minterms)) == \
370
+ (And(Or(Not(w), y), Or(Not(x), y), z), {x: x, w: w, z: z, y: y})
371
+ assert bool_map(SOPform([x, z, y], [[1, 0, 1]]),
372
+ SOPform([a, b, c], [[1, 0, 1]])) != False
373
+ function1 = SOPform([x, z, y], [[1, 0, 1], [0, 0, 1]])
374
+ function2 = SOPform([a, b, c], [[1, 0, 1], [1, 0, 0]])
375
+ assert bool_map(function1, function2) == \
376
+ (function1, {y: a, z: b})
377
+ assert bool_map(Xor(x, y), ~Xor(x, y)) == False
378
+ assert bool_map(And(x, y), Or(x, y)) is None
379
+ assert bool_map(And(x, y), And(x, y, z)) is None
380
+ # issue 16179
381
+ assert bool_map(Xor(x, y, z), ~Xor(x, y, z)) == False
382
+ assert bool_map(Xor(a, x, y, z), ~Xor(a, x, y, z)) == False
383
+
384
+
385
+ def test_bool_symbol():
386
+ """Test that mixing symbols with boolean values
387
+ works as expected"""
388
+
389
+ assert And(A, True) == A
390
+ assert And(A, True, True) == A
391
+ assert And(A, False) is false
392
+ assert And(A, True, False) is false
393
+ assert Or(A, True) is true
394
+ assert Or(A, False) == A
395
+
396
+
397
+ def test_is_boolean():
398
+ assert isinstance(True, Boolean) is False
399
+ assert isinstance(true, Boolean) is True
400
+ assert 1 == True
401
+ assert 1 != true
402
+ assert (1 == true) is False
403
+ assert 0 == False
404
+ assert 0 != false
405
+ assert (0 == false) is False
406
+ assert true.is_Boolean is True
407
+ assert (A & B).is_Boolean
408
+ assert (A | B).is_Boolean
409
+ assert (~A).is_Boolean
410
+ assert (A ^ B).is_Boolean
411
+ assert A.is_Boolean != isinstance(A, Boolean)
412
+ assert isinstance(A, Boolean)
413
+
414
+
415
+ def test_subs():
416
+ assert (A & B).subs(A, True) == B
417
+ assert (A & B).subs(A, False) is false
418
+ assert (A & B).subs(B, True) == A
419
+ assert (A & B).subs(B, False) is false
420
+ assert (A & B).subs({A: True, B: True}) is true
421
+ assert (A | B).subs(A, True) is true
422
+ assert (A | B).subs(A, False) == B
423
+ assert (A | B).subs(B, True) is true
424
+ assert (A | B).subs(B, False) == A
425
+ assert (A | B).subs({A: True, B: True}) is true
426
+
427
+
428
+ """
429
+ we test for axioms of boolean algebra
430
+ see https://en.wikipedia.org/wiki/Boolean_algebra_(structure)
431
+ """
432
+
433
+
434
+ def test_commutative():
435
+ """Test for commutativity of And and Or"""
436
+ A, B = map(Boolean, symbols('A,B'))
437
+
438
+ assert A & B == B & A
439
+ assert A | B == B | A
440
+
441
+
442
+ def test_and_associativity():
443
+ """Test for associativity of And"""
444
+
445
+ assert (A & B) & C == A & (B & C)
446
+
447
+
448
+ def test_or_assicativity():
449
+ assert ((A | B) | C) == (A | (B | C))
450
+
451
+
452
+ def test_double_negation():
453
+ a = Boolean()
454
+ assert ~(~a) == a
455
+
456
+
457
+ # test methods
458
+
459
+ def test_eliminate_implications():
460
+ assert eliminate_implications(Implies(A, B, evaluate=False)) == (~A) | B
461
+ assert eliminate_implications(
462
+ A >> (C >> Not(B))) == Or(Or(Not(B), Not(C)), Not(A))
463
+ assert eliminate_implications(Equivalent(A, B, C, D)) == \
464
+ (~A | B) & (~B | C) & (~C | D) & (~D | A)
465
+
466
+
467
+ def test_conjuncts():
468
+ assert conjuncts(A & B & C) == {A, B, C}
469
+ assert conjuncts((A | B) & C) == {A | B, C}
470
+ assert conjuncts(A) == {A}
471
+ assert conjuncts(True) == {True}
472
+ assert conjuncts(False) == {False}
473
+
474
+
475
+ def test_disjuncts():
476
+ assert disjuncts(A | B | C) == {A, B, C}
477
+ assert disjuncts((A | B) & C) == {(A | B) & C}
478
+ assert disjuncts(A) == {A}
479
+ assert disjuncts(True) == {True}
480
+ assert disjuncts(False) == {False}
481
+
482
+
483
+ def test_distribute():
484
+ assert distribute_and_over_or(Or(And(A, B), C)) == And(Or(A, C), Or(B, C))
485
+ assert distribute_or_over_and(And(A, Or(B, C))) == Or(And(A, B), And(A, C))
486
+ assert distribute_xor_over_and(And(A, Xor(B, C))) == Xor(And(A, B), And(A, C))
487
+
488
+
489
+ def test_to_anf():
490
+ x, y, z = symbols('x,y,z')
491
+ assert to_anf(And(x, y)) == And(x, y)
492
+ assert to_anf(Or(x, y)) == Xor(x, y, And(x, y))
493
+ assert to_anf(Or(Implies(x, y), And(x, y), y)) == \
494
+ Xor(x, True, x & y, remove_true=False)
495
+ assert to_anf(Or(Nand(x, y), Nor(x, y), Xnor(x, y), Implies(x, y))) == True
496
+ assert to_anf(Or(x, Not(y), Nor(x, z), And(x, y), Nand(y, z))) == \
497
+ Xor(True, And(y, z), And(x, y, z), remove_true=False)
498
+ assert to_anf(Xor(x, y)) == Xor(x, y)
499
+ assert to_anf(Not(x)) == Xor(x, True, remove_true=False)
500
+ assert to_anf(Nand(x, y)) == Xor(True, And(x, y), remove_true=False)
501
+ assert to_anf(Nor(x, y)) == Xor(x, y, True, And(x, y), remove_true=False)
502
+ assert to_anf(Implies(x, y)) == Xor(x, True, And(x, y), remove_true=False)
503
+ assert to_anf(Equivalent(x, y)) == Xor(x, y, True, remove_true=False)
504
+ assert to_anf(Nand(x | y, x >> y), deep=False) == \
505
+ Xor(True, And(Or(x, y), Implies(x, y)), remove_true=False)
506
+ assert to_anf(Nor(x ^ y, x & y), deep=False) == \
507
+ Xor(True, Or(Xor(x, y), And(x, y)), remove_true=False)
508
+ # issue 25218
509
+ assert to_anf(x ^ ~(x ^ y ^ ~y)) == False
510
+
511
+
512
+ def test_to_nnf():
513
+ assert to_nnf(true) is true
514
+ assert to_nnf(false) is false
515
+ assert to_nnf(A) == A
516
+ assert to_nnf(A | ~A | B) is true
517
+ assert to_nnf(A & ~A & B) is false
518
+ assert to_nnf(A >> B) == ~A | B
519
+ assert to_nnf(Equivalent(A, B, C)) == (~A | B) & (~B | C) & (~C | A)
520
+ assert to_nnf(A ^ B ^ C) == \
521
+ (A | B | C) & (~A | ~B | C) & (A | ~B | ~C) & (~A | B | ~C)
522
+ assert to_nnf(ITE(A, B, C)) == (~A | B) & (A | C)
523
+ assert to_nnf(Not(A | B | C)) == ~A & ~B & ~C
524
+ assert to_nnf(Not(A & B & C)) == ~A | ~B | ~C
525
+ assert to_nnf(Not(A >> B)) == A & ~B
526
+ assert to_nnf(Not(Equivalent(A, B, C))) == And(Or(A, B, C), Or(~A, ~B, ~C))
527
+ assert to_nnf(Not(A ^ B ^ C)) == \
528
+ (~A | B | C) & (A | ~B | C) & (A | B | ~C) & (~A | ~B | ~C)
529
+ assert to_nnf(Not(ITE(A, B, C))) == (~A | ~B) & (A | ~C)
530
+ assert to_nnf((A >> B) ^ (B >> A)) == (A & ~B) | (~A & B)
531
+ assert to_nnf((A >> B) ^ (B >> A), False) == \
532
+ (~A | ~B | A | B) & ((A & ~B) | (~A & B))
533
+ assert ITE(A, 1, 0).to_nnf() == A
534
+ assert ITE(A, 0, 1).to_nnf() == ~A
535
+ # although ITE can hold non-Boolean, it will complain if
536
+ # an attempt is made to convert the ITE to Boolean nnf
537
+ raises(TypeError, lambda: ITE(A < 1, [1], B).to_nnf())
538
+
539
+
540
+ def test_to_cnf():
541
+ assert to_cnf(~(B | C)) == And(Not(B), Not(C))
542
+ assert to_cnf((A & B) | C) == And(Or(A, C), Or(B, C))
543
+ assert to_cnf(A >> B) == (~A) | B
544
+ assert to_cnf(A >> (B & C)) == (~A | B) & (~A | C)
545
+ assert to_cnf(A & (B | C) | ~A & (B | C), True) == B | C
546
+ assert to_cnf(A & B) == And(A, B)
547
+
548
+ assert to_cnf(Equivalent(A, B)) == And(Or(A, Not(B)), Or(B, Not(A)))
549
+ assert to_cnf(Equivalent(A, B & C)) == \
550
+ (~A | B) & (~A | C) & (~B | ~C | A)
551
+ assert to_cnf(Equivalent(A, B | C), True) == \
552
+ And(Or(Not(B), A), Or(Not(C), A), Or(B, C, Not(A)))
553
+ assert to_cnf(A + 1) == A + 1
554
+
555
+
556
+ def test_issue_18904():
557
+ x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 = symbols('x1:16')
558
+ eq = ((x1 & x2 & x3 & x4 & x5 & x6 & x7 & x8 & x9) |
559
+ (x1 & x2 & x3 & x4 & x5 & x6 & x7 & x10 & x9) |
560
+ (x1 & x11 & x3 & x12 & x5 & x13 & x14 & x15 & x9))
561
+ assert is_cnf(to_cnf(eq))
562
+ raises(ValueError, lambda: to_cnf(eq, simplify=True))
563
+ for f, t in zip((And, Or), (to_cnf, to_dnf)):
564
+ eq = f(x1, x2, x3, x4, x5, x6, x7, x8, x9)
565
+ raises(ValueError, lambda: to_cnf(eq, simplify=True))
566
+ assert t(eq, simplify=True, force=True) == eq
567
+
568
+
569
+ def test_issue_9949():
570
+ assert is_cnf(to_cnf((b > -5) | (a > 2) & (a < 4)))
571
+
572
+
573
+ def test_to_CNF():
574
+ assert CNF.CNF_to_cnf(CNF.to_CNF(~(B | C))) == to_cnf(~(B | C))
575
+ assert CNF.CNF_to_cnf(CNF.to_CNF((A & B) | C)) == to_cnf((A & B) | C)
576
+ assert CNF.CNF_to_cnf(CNF.to_CNF(A >> B)) == to_cnf(A >> B)
577
+ assert CNF.CNF_to_cnf(CNF.to_CNF(A >> (B & C))) == to_cnf(A >> (B & C))
578
+ assert CNF.CNF_to_cnf(CNF.to_CNF(A & (B | C) | ~A & (B | C))) == to_cnf(A & (B | C) | ~A & (B | C))
579
+ assert CNF.CNF_to_cnf(CNF.to_CNF(A & B)) == to_cnf(A & B)
580
+
581
+
582
+ def test_to_dnf():
583
+ assert to_dnf(~(B | C)) == And(Not(B), Not(C))
584
+ assert to_dnf(A & (B | C)) == Or(And(A, B), And(A, C))
585
+ assert to_dnf(A >> B) == (~A) | B
586
+ assert to_dnf(A >> (B & C)) == (~A) | (B & C)
587
+ assert to_dnf(A | B) == A | B
588
+
589
+ assert to_dnf(Equivalent(A, B), True) == \
590
+ Or(And(A, B), And(Not(A), Not(B)))
591
+ assert to_dnf(Equivalent(A, B & C), True) == \
592
+ Or(And(A, B, C), And(Not(A), Not(B)), And(Not(A), Not(C)))
593
+ assert to_dnf(A + 1) == A + 1
594
+
595
+
596
+ def test_to_int_repr():
597
+ x, y, z = map(Boolean, symbols('x,y,z'))
598
+
599
+ def sorted_recursive(arg):
600
+ try:
601
+ return sorted(sorted_recursive(x) for x in arg)
602
+ except TypeError: # arg is not a sequence
603
+ return arg
604
+
605
+ assert sorted_recursive(to_int_repr([x | y, z | x], [x, y, z])) == \
606
+ sorted_recursive([[1, 2], [1, 3]])
607
+ assert sorted_recursive(to_int_repr([x | y, z | ~x], [x, y, z])) == \
608
+ sorted_recursive([[1, 2], [3, -1]])
609
+
610
+
611
+ def test_is_anf():
612
+ x, y = symbols('x,y')
613
+ assert is_anf(true) is True
614
+ assert is_anf(false) is True
615
+ assert is_anf(x) is True
616
+ assert is_anf(And(x, y)) is True
617
+ assert is_anf(Xor(x, y, And(x, y))) is True
618
+ assert is_anf(Xor(x, y, Or(x, y))) is False
619
+ assert is_anf(Xor(Not(x), y)) is False
620
+
621
+
622
+ def test_is_nnf():
623
+ assert is_nnf(true) is True
624
+ assert is_nnf(A) is True
625
+ assert is_nnf(~A) is True
626
+ assert is_nnf(A & B) is True
627
+ assert is_nnf((A & B) | (~A & A) | (~B & B) | (~A & ~B), False) is True
628
+ assert is_nnf((A | B) & (~A | ~B)) is True
629
+ assert is_nnf(Not(Or(A, B))) is False
630
+ assert is_nnf(A ^ B) is False
631
+ assert is_nnf((A & B) | (~A & A) | (~B & B) | (~A & ~B), True) is False
632
+
633
+
634
+ def test_is_cnf():
635
+ assert is_cnf(x) is True
636
+ assert is_cnf(x | y | z) is True
637
+ assert is_cnf(x & y & z) is True
638
+ assert is_cnf((x | y) & z) is True
639
+ assert is_cnf((x & y) | z) is False
640
+ assert is_cnf(~(x & y) | z) is False
641
+
642
+
643
+ def test_is_dnf():
644
+ assert is_dnf(x) is True
645
+ assert is_dnf(x | y | z) is True
646
+ assert is_dnf(x & y & z) is True
647
+ assert is_dnf((x & y) | z) is True
648
+ assert is_dnf((x | y) & z) is False
649
+ assert is_dnf(~(x | y) & z) is False
650
+
651
+
652
+ def test_ITE():
653
+ A, B, C = symbols('A:C')
654
+ assert ITE(True, False, True) is false
655
+ assert ITE(True, True, False) is true
656
+ assert ITE(False, True, False) is false
657
+ assert ITE(False, False, True) is true
658
+ assert isinstance(ITE(A, B, C), ITE)
659
+
660
+ A = True
661
+ assert ITE(A, B, C) == B
662
+ A = False
663
+ assert ITE(A, B, C) == C
664
+ B = True
665
+ assert ITE(And(A, B), B, C) == C
666
+ assert ITE(Or(A, False), And(B, True), False) is false
667
+ assert ITE(x, A, B) == Not(x)
668
+ assert ITE(x, B, A) == x
669
+ assert ITE(1, x, y) == x
670
+ assert ITE(0, x, y) == y
671
+ raises(TypeError, lambda: ITE(2, x, y))
672
+ raises(TypeError, lambda: ITE(1, [], y))
673
+ raises(TypeError, lambda: ITE(1, (), y))
674
+ raises(TypeError, lambda: ITE(1, y, []))
675
+ assert ITE(1, 1, 1) is S.true
676
+ assert isinstance(ITE(1, 1, 1, evaluate=False), ITE)
677
+
678
+ assert ITE(Eq(x, True), y, x) == ITE(x, y, x)
679
+ assert ITE(Eq(x, False), y, x) == ITE(~x, y, x)
680
+ assert ITE(Ne(x, True), y, x) == ITE(~x, y, x)
681
+ assert ITE(Ne(x, False), y, x) == ITE(x, y, x)
682
+ assert ITE(Eq(S.true, x), y, x) == ITE(x, y, x)
683
+ assert ITE(Eq(S.false, x), y, x) == ITE(~x, y, x)
684
+ assert ITE(Ne(S.true, x), y, x) == ITE(~x, y, x)
685
+ assert ITE(Ne(S.false, x), y, x) == ITE(x, y, x)
686
+ # 0 and 1 in the context are not treated as True/False
687
+ # so the equality must always be False since dissimilar
688
+ # objects cannot be equal
689
+ assert ITE(Eq(x, 0), y, x) == x
690
+ assert ITE(Eq(x, 1), y, x) == x
691
+ assert ITE(Ne(x, 0), y, x) == y
692
+ assert ITE(Ne(x, 1), y, x) == y
693
+ assert ITE(Eq(x, 0), y, z).subs(x, 0) == y
694
+ assert ITE(Eq(x, 0), y, z).subs(x, 1) == z
695
+ raises(ValueError, lambda: ITE(x > 1, y, x, z))
696
+
697
+
698
+ def test_is_literal():
699
+ assert is_literal(True) is True
700
+ assert is_literal(False) is True
701
+ assert is_literal(A) is True
702
+ assert is_literal(~A) is True
703
+ assert is_literal(Or(A, B)) is False
704
+ assert is_literal(Q.zero(A)) is True
705
+ assert is_literal(Not(Q.zero(A))) is True
706
+ assert is_literal(Or(A, B)) is False
707
+ assert is_literal(And(Q.zero(A), Q.zero(B))) is False
708
+ assert is_literal(x < 3)
709
+ assert not is_literal(x + y < 3)
710
+
711
+
712
+ def test_operators():
713
+ # Mostly test __and__, __rand__, and so on
714
+ assert True & A == A & True == A
715
+ assert False & A == A & False == False
716
+ assert A & B == And(A, B)
717
+ assert True | A == A | True == True
718
+ assert False | A == A | False == A
719
+ assert A | B == Or(A, B)
720
+ assert ~A == Not(A)
721
+ assert True >> A == A << True == A
722
+ assert False >> A == A << False == True
723
+ assert A >> True == True << A == True
724
+ assert A >> False == False << A == ~A
725
+ assert A >> B == B << A == Implies(A, B)
726
+ assert True ^ A == A ^ True == ~A
727
+ assert False ^ A == A ^ False == A
728
+ assert A ^ B == Xor(A, B)
729
+
730
+
731
+ def test_true_false():
732
+ assert true is S.true
733
+ assert false is S.false
734
+ assert true is not True
735
+ assert false is not False
736
+ assert true
737
+ assert not false
738
+ assert true == True
739
+ assert false == False
740
+ assert not (true == False)
741
+ assert not (false == True)
742
+ assert not (true == false)
743
+
744
+ assert hash(true) == hash(True)
745
+ assert hash(false) == hash(False)
746
+ assert len({true, True}) == len({false, False}) == 1
747
+
748
+ assert isinstance(true, BooleanAtom)
749
+ assert isinstance(false, BooleanAtom)
750
+ # We don't want to subclass from bool, because bool subclasses from
751
+ # int. But operators like &, |, ^, <<, >>, and ~ act differently on 0 and
752
+ # 1 then we want them to on true and false. See the docstrings of the
753
+ # various And, Or, etc. functions for examples.
754
+ assert not isinstance(true, bool)
755
+ assert not isinstance(false, bool)
756
+
757
+ # Note: using 'is' comparison is important here. We want these to return
758
+ # true and false, not True and False
759
+
760
+ assert Not(true) is false
761
+ assert Not(True) is false
762
+ assert Not(false) is true
763
+ assert Not(False) is true
764
+ assert ~true is false
765
+ assert ~false is true
766
+
767
+ for T, F in product((True, true), (False, false)):
768
+ assert And(T, F) is false
769
+ assert And(F, T) is false
770
+ assert And(F, F) is false
771
+ assert And(T, T) is true
772
+ assert And(T, x) == x
773
+ assert And(F, x) is false
774
+ if not (T is True and F is False):
775
+ assert T & F is false
776
+ assert F & T is false
777
+ if F is not False:
778
+ assert F & F is false
779
+ if T is not True:
780
+ assert T & T is true
781
+
782
+ assert Or(T, F) is true
783
+ assert Or(F, T) is true
784
+ assert Or(F, F) is false
785
+ assert Or(T, T) is true
786
+ assert Or(T, x) is true
787
+ assert Or(F, x) == x
788
+ if not (T is True and F is False):
789
+ assert T | F is true
790
+ assert F | T is true
791
+ if F is not False:
792
+ assert F | F is false
793
+ if T is not True:
794
+ assert T | T is true
795
+
796
+ assert Xor(T, F) is true
797
+ assert Xor(F, T) is true
798
+ assert Xor(F, F) is false
799
+ assert Xor(T, T) is false
800
+ assert Xor(T, x) == ~x
801
+ assert Xor(F, x) == x
802
+ if not (T is True and F is False):
803
+ assert T ^ F is true
804
+ assert F ^ T is true
805
+ if F is not False:
806
+ assert F ^ F is false
807
+ if T is not True:
808
+ assert T ^ T is false
809
+
810
+ assert Nand(T, F) is true
811
+ assert Nand(F, T) is true
812
+ assert Nand(F, F) is true
813
+ assert Nand(T, T) is false
814
+ assert Nand(T, x) == ~x
815
+ assert Nand(F, x) is true
816
+
817
+ assert Nor(T, F) is false
818
+ assert Nor(F, T) is false
819
+ assert Nor(F, F) is true
820
+ assert Nor(T, T) is false
821
+ assert Nor(T, x) is false
822
+ assert Nor(F, x) == ~x
823
+
824
+ assert Implies(T, F) is false
825
+ assert Implies(F, T) is true
826
+ assert Implies(F, F) is true
827
+ assert Implies(T, T) is true
828
+ assert Implies(T, x) == x
829
+ assert Implies(F, x) is true
830
+ assert Implies(x, T) is true
831
+ assert Implies(x, F) == ~x
832
+ if not (T is True and F is False):
833
+ assert T >> F is false
834
+ assert F << T is false
835
+ assert F >> T is true
836
+ assert T << F is true
837
+ if F is not False:
838
+ assert F >> F is true
839
+ assert F << F is true
840
+ if T is not True:
841
+ assert T >> T is true
842
+ assert T << T is true
843
+
844
+ assert Equivalent(T, F) is false
845
+ assert Equivalent(F, T) is false
846
+ assert Equivalent(F, F) is true
847
+ assert Equivalent(T, T) is true
848
+ assert Equivalent(T, x) == x
849
+ assert Equivalent(F, x) == ~x
850
+ assert Equivalent(x, T) == x
851
+ assert Equivalent(x, F) == ~x
852
+
853
+ assert ITE(T, T, T) is true
854
+ assert ITE(T, T, F) is true
855
+ assert ITE(T, F, T) is false
856
+ assert ITE(T, F, F) is false
857
+ assert ITE(F, T, T) is true
858
+ assert ITE(F, T, F) is false
859
+ assert ITE(F, F, T) is true
860
+ assert ITE(F, F, F) is false
861
+
862
+ assert all(i.simplify(1, 2) is i for i in (S.true, S.false))
863
+
864
+
865
+ def test_bool_as_set():
866
+ assert ITE(y <= 0, False, y >= 1).as_set() == Interval(1, oo)
867
+ assert And(x <= 2, x >= -2).as_set() == Interval(-2, 2)
868
+ assert Or(x >= 2, x <= -2).as_set() == Interval(-oo, -2) + Interval(2, oo)
869
+ assert Not(x > 2).as_set() == Interval(-oo, 2)
870
+ # issue 10240
871
+ assert Not(And(x > 2, x < 3)).as_set() == \
872
+ Union(Interval(-oo, 2), Interval(3, oo))
873
+ assert true.as_set() == S.UniversalSet
874
+ assert false.as_set() is S.EmptySet
875
+ assert x.as_set() == S.UniversalSet
876
+ assert And(Or(x < 1, x > 3), x < 2).as_set() == Interval.open(-oo, 1)
877
+ assert And(x < 1, sin(x) < 3).as_set() == (x < 1).as_set()
878
+ raises(NotImplementedError, lambda: (sin(x) < 1).as_set())
879
+ # watch for object morph in as_set
880
+ assert Eq(-1, cos(2 * x) ** 2 / sin(2 * x) ** 2).as_set() is S.EmptySet
881
+
882
+
883
+ @XFAIL
884
+ def test_multivariate_bool_as_set():
885
+ x, y = symbols('x,y')
886
+
887
+ assert And(x >= 0, y >= 0).as_set() == Interval(0, oo) * Interval(0, oo)
888
+ assert Or(x >= 0, y >= 0).as_set() == S.Reals * S.Reals - \
889
+ Interval(-oo, 0, True, True) * Interval(-oo, 0, True, True)
890
+
891
+
892
+ def test_all_or_nothing():
893
+ x = symbols('x', extended_real=True)
894
+ args = x >= -oo, x <= oo
895
+ v = And(*args)
896
+ if v.func is And:
897
+ assert len(v.args) == len(args) - args.count(S.true)
898
+ else:
899
+ assert v == True
900
+ v = Or(*args)
901
+ if v.func is Or:
902
+ assert len(v.args) == 2
903
+ else:
904
+ assert v == True
905
+
906
+
907
+ def test_canonical_atoms():
908
+ assert true.canonical == true
909
+ assert false.canonical == false
910
+
911
+
912
+ def test_negated_atoms():
913
+ assert true.negated == false
914
+ assert false.negated == true
915
+
916
+
917
+ def test_issue_8777():
918
+ assert And(x > 2, x < oo).as_set() == Interval(2, oo, left_open=True)
919
+ assert And(x >= 1, x < oo).as_set() == Interval(1, oo)
920
+ assert (x < oo).as_set() == Interval(-oo, oo)
921
+ assert (x > -oo).as_set() == Interval(-oo, oo)
922
+
923
+
924
+ def test_issue_8975():
925
+ assert Or(And(-oo < x, x <= -2), And(2 <= x, x < oo)).as_set() == \
926
+ Interval(-oo, -2) + Interval(2, oo)
927
+
928
+
929
+ def test_term_to_integer():
930
+ assert term_to_integer([1, 0, 1, 0, 0, 1, 0]) == 82
931
+ assert term_to_integer('0010101000111001') == 10809
932
+
933
+
934
+ def test_issue_21971():
935
+ a, b, c, d = symbols('a b c d')
936
+ f = a & b & c | a & c
937
+ assert f.subs(a & c, d) == b & d | d
938
+ assert f.subs(a & b & c, d) == a & c | d
939
+
940
+ f = (a | b | c) & (a | c)
941
+ assert f.subs(a | c, d) == (b | d) & d
942
+ assert f.subs(a | b | c, d) == (a | c) & d
943
+
944
+ f = (a ^ b ^ c) & (a ^ c)
945
+ assert f.subs(a ^ c, d) == (b ^ d) & d
946
+ assert f.subs(a ^ b ^ c, d) == (a ^ c) & d
947
+
948
+
949
+ def test_truth_table():
950
+ assert list(truth_table(And(x, y), [x, y], input=False)) == \
951
+ [False, False, False, True]
952
+ assert list(truth_table(x | y, [x, y], input=False)) == \
953
+ [False, True, True, True]
954
+ assert list(truth_table(x >> y, [x, y], input=False)) == \
955
+ [True, True, False, True]
956
+ assert list(truth_table(And(x, y), [x, y])) == \
957
+ [([0, 0], False), ([0, 1], False), ([1, 0], False), ([1, 1], True)]
958
+
959
+
960
+ def test_issue_8571():
961
+ for t in (S.true, S.false):
962
+ raises(TypeError, lambda: +t)
963
+ raises(TypeError, lambda: -t)
964
+ raises(TypeError, lambda: abs(t))
965
+ # use int(bool(t)) to get 0 or 1
966
+ raises(TypeError, lambda: int(t))
967
+
968
+ for o in [S.Zero, S.One, x]:
969
+ for _ in range(2):
970
+ raises(TypeError, lambda: o + t)
971
+ raises(TypeError, lambda: o - t)
972
+ raises(TypeError, lambda: o % t)
973
+ raises(TypeError, lambda: o * t)
974
+ raises(TypeError, lambda: o / t)
975
+ raises(TypeError, lambda: o ** t)
976
+ o, t = t, o # do again in reversed order
977
+
978
+
979
+ def test_expand_relational():
980
+ n = symbols('n', negative=True)
981
+ p, q = symbols('p q', positive=True)
982
+ r = ((n + q * (-n / q + 1)) / (q * (-n / q + 1)) < 0)
983
+ assert r is not S.false
984
+ assert r.expand() is S.false
985
+ assert (q > 0).expand() is S.true
986
+
987
+
988
+ def test_issue_12717():
989
+ assert S.true.is_Atom == True
990
+ assert S.false.is_Atom == True
991
+
992
+
993
+ def test_as_Boolean():
994
+ nz = symbols('nz', nonzero=True)
995
+ assert all(as_Boolean(i) is S.true for i in (True, S.true, 1, nz))
996
+ z = symbols('z', zero=True)
997
+ assert all(as_Boolean(i) is S.false for i in (False, S.false, 0, z))
998
+ assert all(as_Boolean(i) == i for i in (x, x < 0))
999
+ for i in (2, S(2), x + 1, []):
1000
+ raises(TypeError, lambda: as_Boolean(i))
1001
+
1002
+
1003
+ def test_binary_symbols():
1004
+ assert ITE(x < 1, y, z).binary_symbols == {y, z}
1005
+ for f in (Eq, Ne):
1006
+ assert f(x, 1).binary_symbols == set()
1007
+ assert f(x, True).binary_symbols == {x}
1008
+ assert f(x, False).binary_symbols == {x}
1009
+ assert S.true.binary_symbols == set()
1010
+ assert S.false.binary_symbols == set()
1011
+ assert x.binary_symbols == {x}
1012
+ assert And(x, Eq(y, False), Eq(z, 1)).binary_symbols == {x, y}
1013
+ assert Q.prime(x).binary_symbols == set()
1014
+ assert Q.lt(x, 1).binary_symbols == set()
1015
+ assert Q.is_true(x).binary_symbols == {x}
1016
+ assert Q.eq(x, True).binary_symbols == {x}
1017
+ assert Q.prime(x).binary_symbols == set()
1018
+
1019
+
1020
+ def test_BooleanFunction_diff():
1021
+ assert And(x, y).diff(x) == Piecewise((0, Eq(y, False)), (1, True))
1022
+
1023
+
1024
+ def test_issue_14700():
1025
+ A, B, C, D, E, F, G, H = symbols('A B C D E F G H')
1026
+ q = ((B & D & H & ~F) | (B & H & ~C & ~D) | (B & H & ~C & ~F) |
1027
+ (B & H & ~D & ~G) | (B & H & ~F & ~G) | (C & G & ~B & ~D) |
1028
+ (C & G & ~D & ~H) | (C & G & ~F & ~H) | (D & F & H & ~B) |
1029
+ (D & F & ~G & ~H) | (B & D & F & ~C & ~H) | (D & E & F & ~B & ~C) |
1030
+ (D & F & ~A & ~B & ~C) | (D & F & ~A & ~C & ~H) |
1031
+ (A & B & D & F & ~E & ~H))
1032
+ soldnf = ((B & D & H & ~F) | (D & F & H & ~B) | (B & H & ~C & ~D) |
1033
+ (B & H & ~D & ~G) | (C & G & ~B & ~D) | (C & G & ~D & ~H) |
1034
+ (C & G & ~F & ~H) | (D & F & ~G & ~H) | (D & E & F & ~C & ~H) |
1035
+ (D & F & ~A & ~C & ~H) | (A & B & D & F & ~E & ~H))
1036
+ solcnf = ((B | C | D) & (B | D | G) & (C | D | H) & (C | F | H) &
1037
+ (D | G | H) & (F | G | H) & (B | F | ~D | ~H) &
1038
+ (~B | ~D | ~F | ~H) & (D | ~B | ~C | ~G | ~H) &
1039
+ (A | H | ~C | ~D | ~F | ~G) & (H | ~C | ~D | ~E | ~F | ~G) &
1040
+ (B | E | H | ~A | ~D | ~F | ~G))
1041
+ assert simplify_logic(q, "dnf") == soldnf
1042
+ assert simplify_logic(q, "cnf") == solcnf
1043
+
1044
+ minterms = [[0, 1, 0, 0], [0, 1, 0, 1], [0, 1, 1, 0], [0, 1, 1, 1],
1045
+ [0, 0, 1, 1], [1, 0, 1, 1]]
1046
+ dontcares = [[1, 0, 0, 0], [1, 0, 0, 1], [1, 1, 0, 0], [1, 1, 0, 1]]
1047
+ assert SOPform([w, x, y, z], minterms) == (x & ~w) | (y & z & ~x)
1048
+ # Should not be more complicated with don't cares
1049
+ assert SOPform([w, x, y, z], minterms, dontcares) == \
1050
+ (x & ~w) | (y & z & ~x)
1051
+
1052
+
1053
+ def test_issue_25115():
1054
+ cond = Contains(x, S.Integers)
1055
+ # Previously this raised an exception:
1056
+ assert simplify_logic(cond) == cond
1057
+
1058
+
1059
+ def test_relational_simplification():
1060
+ w, x, y, z = symbols('w x y z', real=True)
1061
+ d, e = symbols('d e', real=False)
1062
+ # Test all combinations or sign and order
1063
+ assert Or(x >= y, x < y).simplify() == S.true
1064
+ assert Or(x >= y, y > x).simplify() == S.true
1065
+ assert Or(x >= y, -x > -y).simplify() == S.true
1066
+ assert Or(x >= y, -y < -x).simplify() == S.true
1067
+ assert Or(-x <= -y, x < y).simplify() == S.true
1068
+ assert Or(-x <= -y, -x > -y).simplify() == S.true
1069
+ assert Or(-x <= -y, y > x).simplify() == S.true
1070
+ assert Or(-x <= -y, -y < -x).simplify() == S.true
1071
+ assert Or(y <= x, x < y).simplify() == S.true
1072
+ assert Or(y <= x, y > x).simplify() == S.true
1073
+ assert Or(y <= x, -x > -y).simplify() == S.true
1074
+ assert Or(y <= x, -y < -x).simplify() == S.true
1075
+ assert Or(-y >= -x, x < y).simplify() == S.true
1076
+ assert Or(-y >= -x, y > x).simplify() == S.true
1077
+ assert Or(-y >= -x, -x > -y).simplify() == S.true
1078
+ assert Or(-y >= -x, -y < -x).simplify() == S.true
1079
+
1080
+ assert Or(x < y, x >= y).simplify() == S.true
1081
+ assert Or(y > x, x >= y).simplify() == S.true
1082
+ assert Or(-x > -y, x >= y).simplify() == S.true
1083
+ assert Or(-y < -x, x >= y).simplify() == S.true
1084
+ assert Or(x < y, -x <= -y).simplify() == S.true
1085
+ assert Or(-x > -y, -x <= -y).simplify() == S.true
1086
+ assert Or(y > x, -x <= -y).simplify() == S.true
1087
+ assert Or(-y < -x, -x <= -y).simplify() == S.true
1088
+ assert Or(x < y, y <= x).simplify() == S.true
1089
+ assert Or(y > x, y <= x).simplify() == S.true
1090
+ assert Or(-x > -y, y <= x).simplify() == S.true
1091
+ assert Or(-y < -x, y <= x).simplify() == S.true
1092
+ assert Or(x < y, -y >= -x).simplify() == S.true
1093
+ assert Or(y > x, -y >= -x).simplify() == S.true
1094
+ assert Or(-x > -y, -y >= -x).simplify() == S.true
1095
+ assert Or(-y < -x, -y >= -x).simplify() == S.true
1096
+
1097
+ # Some other tests
1098
+ assert Or(x >= y, w < z, x <= y).simplify() == S.true
1099
+ assert And(x >= y, x < y).simplify() == S.false
1100
+ assert Or(x >= y, Eq(y, x)).simplify() == (x >= y)
1101
+ assert And(x >= y, Eq(y, x)).simplify() == Eq(x, y)
1102
+ assert And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify() == \
1103
+ (Eq(x, y) & (x >= 1) & (y >= 5) & (y > z))
1104
+ assert Or(Eq(x, y), x >= y, w < y, z < y).simplify() == \
1105
+ (x >= y) | (y > z) | (w < y)
1106
+ assert And(Eq(x, y), x >= y, w < y, y >= z, z < y).simplify() == \
1107
+ Eq(x, y) & (y > z) & (w < y)
1108
+ # assert And(Eq(x, y), x >= y, w < y, y >= z, z < y).simplify(relational_minmax=True) == \
1109
+ # And(Eq(x, y), y > Max(w, z))
1110
+ # assert Or(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify(relational_minmax=True) == \
1111
+ # (Eq(x, y) | (x >= 1) | (y > Min(2, z)))
1112
+ assert And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify() == \
1113
+ (Eq(x, y) & (x >= 1) & (y >= 5) & (y > z))
1114
+ assert (Eq(x, y) & Eq(d, e) & (x >= y) & (d >= e)).simplify() == \
1115
+ (Eq(x, y) & Eq(d, e) & (d >= e))
1116
+ assert And(Eq(x, y), Eq(x, -y)).simplify() == And(Eq(x, 0), Eq(y, 0))
1117
+ assert Xor(x >= y, x <= y).simplify() == Ne(x, y)
1118
+ assert And(x > 1, x < -1, Eq(x, y)).simplify() == S.false
1119
+ # From #16690
1120
+ assert And(x >= y, Eq(y, 0)).simplify() == And(x >= 0, Eq(y, 0))
1121
+ assert Or(Ne(x, 1), Ne(x, 2)).simplify() == S.true
1122
+ assert And(Eq(x, 1), Ne(2, x)).simplify() == Eq(x, 1)
1123
+ assert Or(Eq(x, 1), Ne(2, x)).simplify() == Ne(x, 2)
1124
+
1125
+
1126
+ def test_issue_8373():
1127
+ x = symbols('x', real=True)
1128
+ assert Or(x < 1, x > -1).simplify() == S.true
1129
+ assert Or(x < 1, x >= 1).simplify() == S.true
1130
+ assert And(x < 1, x >= 1).simplify() == S.false
1131
+ assert Or(x <= 1, x >= 1).simplify() == S.true
1132
+
1133
+
1134
+ def test_issue_7950():
1135
+ x = symbols('x', real=True)
1136
+ assert And(Eq(x, 1), Eq(x, 2)).simplify() == S.false
1137
+
1138
+
1139
+ @slow
1140
+ def test_relational_simplification_numerically():
1141
+ def test_simplification_numerically_function(original, simplified):
1142
+ symb = original.free_symbols
1143
+ n = len(symb)
1144
+ valuelist = list(set(combinations(list(range(-(n - 1), n)) * n, n)))
1145
+ for values in valuelist:
1146
+ sublist = dict(zip(symb, values))
1147
+ originalvalue = original.subs(sublist)
1148
+ simplifiedvalue = simplified.subs(sublist)
1149
+ assert originalvalue == simplifiedvalue, "Original: {}\nand" \
1150
+ " simplified: {}\ndo not evaluate to the same value for {}" \
1151
+ "".format(original, simplified, sublist)
1152
+
1153
+ w, x, y, z = symbols('w x y z', real=True)
1154
+ d, e = symbols('d e', real=False)
1155
+
1156
+ expressions = (And(Eq(x, y), x >= y, w < y, y >= z, z < y),
1157
+ And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y),
1158
+ Or(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y),
1159
+ And(x >= y, Eq(y, x)),
1160
+ Or(And(Eq(x, y), x >= y, w < y, Or(y >= z, z < y)),
1161
+ And(Eq(x, y), x >= 1, 2 < y, y >= -1, z < y)),
1162
+ (Eq(x, y) & Eq(d, e) & (x >= y) & (d >= e)),
1163
+ )
1164
+
1165
+ for expression in expressions:
1166
+ test_simplification_numerically_function(expression,
1167
+ expression.simplify())
1168
+
1169
+
1170
+ def test_relational_simplification_patterns_numerically():
1171
+ from sympy.core import Wild
1172
+ from sympy.logic.boolalg import _simplify_patterns_and, \
1173
+ _simplify_patterns_or, _simplify_patterns_xor
1174
+ a = Wild('a')
1175
+ b = Wild('b')
1176
+ c = Wild('c')
1177
+ symb = [a, b, c]
1178
+ patternlists = [[And, _simplify_patterns_and()],
1179
+ [Or, _simplify_patterns_or()],
1180
+ [Xor, _simplify_patterns_xor()]]
1181
+ valuelist = list(set(combinations(list(range(-2, 3)) * 3, 3)))
1182
+ # Skip combinations of +/-2 and 0, except for all 0
1183
+ valuelist = [v for v in valuelist if any(w % 2 for w in v) or not any(v)]
1184
+ for func, patternlist in patternlists:
1185
+ for pattern in patternlist:
1186
+ original = func(*pattern[0].args)
1187
+ simplified = pattern[1]
1188
+ for values in valuelist:
1189
+ sublist = dict(zip(symb, values))
1190
+ originalvalue = original.xreplace(sublist)
1191
+ simplifiedvalue = simplified.xreplace(sublist)
1192
+ assert originalvalue == simplifiedvalue, "Original: {}\nand" \
1193
+ " simplified: {}\ndo not evaluate to the same value for" \
1194
+ "{}".format(pattern[0], simplified, sublist)
1195
+
1196
+
1197
+ def test_issue_16803():
1198
+ n = symbols('n')
1199
+ # No simplification done, but should not raise an exception
1200
+ assert ((n > 3) | (n < 0) | ((n > 0) & (n < 3))).simplify() == \
1201
+ (n > 3) | (n < 0) | ((n > 0) & (n < 3))
1202
+
1203
+
1204
+ def test_issue_17530():
1205
+ r = {x: oo, y: oo}
1206
+ assert Or(x + y > 0, x - y < 0).subs(r)
1207
+ assert not And(x + y < 0, x - y < 0).subs(r)
1208
+ raises(TypeError, lambda: Or(x + y < 0, x - y < 0).subs(r))
1209
+ raises(TypeError, lambda: And(x + y > 0, x - y < 0).subs(r))
1210
+ raises(TypeError, lambda: And(x + y > 0, x - y < 0).subs(r))
1211
+
1212
+
1213
+ def test_anf_coeffs():
1214
+ assert anf_coeffs([1, 0]) == [1, 1]
1215
+ assert anf_coeffs([0, 0, 0, 1]) == [0, 0, 0, 1]
1216
+ assert anf_coeffs([0, 1, 1, 1]) == [0, 1, 1, 1]
1217
+ assert anf_coeffs([1, 1, 1, 0]) == [1, 0, 0, 1]
1218
+ assert anf_coeffs([1, 0, 0, 0]) == [1, 1, 1, 1]
1219
+ assert anf_coeffs([1, 0, 0, 1]) == [1, 1, 1, 0]
1220
+ assert anf_coeffs([1, 1, 0, 1]) == [1, 0, 1, 1]
1221
+
1222
+
1223
+ def test_ANFform():
1224
+ x, y = symbols('x,y')
1225
+ assert ANFform([x], [1, 1]) == True
1226
+ assert ANFform([x], [0, 0]) == False
1227
+ assert ANFform([x], [1, 0]) == Xor(x, True, remove_true=False)
1228
+ assert ANFform([x, y], [1, 1, 1, 0]) == \
1229
+ Xor(True, And(x, y), remove_true=False)
1230
+
1231
+
1232
+ def test_bool_minterm():
1233
+ x, y = symbols('x,y')
1234
+ assert bool_minterm(3, [x, y]) == And(x, y)
1235
+ assert bool_minterm([1, 0], [x, y]) == And(Not(y), x)
1236
+
1237
+
1238
+ def test_bool_maxterm():
1239
+ x, y = symbols('x,y')
1240
+ assert bool_maxterm(2, [x, y]) == Or(Not(x), y)
1241
+ assert bool_maxterm([0, 1], [x, y]) == Or(Not(y), x)
1242
+
1243
+
1244
+ def test_bool_monomial():
1245
+ x, y = symbols('x,y')
1246
+ assert bool_monomial(1, [x, y]) == y
1247
+ assert bool_monomial([1, 1], [x, y]) == And(x, y)
1248
+
1249
+
1250
+ def test_check_pair():
1251
+ assert _check_pair([0, 1, 0], [0, 1, 1]) == 2
1252
+ assert _check_pair([0, 1, 0], [1, 1, 1]) == -1
1253
+
1254
+
1255
+ def test_issue_19114():
1256
+ expr = (B & C) | (A & ~C) | (~A & ~B)
1257
+ # Expression is minimal, but there are multiple minimal forms possible
1258
+ res1 = (A & B) | (C & ~A) | (~B & ~C)
1259
+ result = to_dnf(expr, simplify=True)
1260
+ assert result in (expr, res1)
1261
+
1262
+
1263
+ def test_issue_20870():
1264
+ result = SOPform([a, b, c, d], [1, 2, 3, 4, 5, 6, 8, 9, 11, 12, 14, 15])
1265
+ expected = ((d & ~b) | (a & b & c) | (a & ~c & ~d) |
1266
+ (b & ~a & ~c) | (c & ~a & ~d))
1267
+ assert result == expected
1268
+
1269
+
1270
+ def test_convert_to_varsSOP():
1271
+ assert _convert_to_varsSOP([0, 1, 0], [x, y, z]) == And(Not(x), y, Not(z))
1272
+ assert _convert_to_varsSOP([3, 1, 0], [x, y, z]) == And(y, Not(z))
1273
+
1274
+
1275
+ def test_convert_to_varsPOS():
1276
+ assert _convert_to_varsPOS([0, 1, 0], [x, y, z]) == Or(x, Not(y), z)
1277
+ assert _convert_to_varsPOS([3, 1, 0], [x, y, z]) == Or(Not(y), z)
1278
+
1279
+
1280
+ def test_gateinputcount():
1281
+ a, b, c, d, e = symbols('a:e')
1282
+ assert gateinputcount(And(a, b)) == 2
1283
+ assert gateinputcount(a | b & c & d ^ (e | a)) == 9
1284
+ assert gateinputcount(And(a, True)) == 0
1285
+ raises(TypeError, lambda: gateinputcount(a * b))
1286
+
1287
+
1288
+ def test_refine():
1289
+ # relational
1290
+ assert not refine(x < 0, ~(x < 0))
1291
+ assert refine(x < 0, (x < 0))
1292
+ assert refine(x < 0, (0 > x)) is S.true
1293
+ assert refine(x < 0, (y < 0)) == (x < 0)
1294
+ assert not refine(x <= 0, ~(x <= 0))
1295
+ assert refine(x <= 0, (x <= 0))
1296
+ assert refine(x <= 0, (0 >= x)) is S.true
1297
+ assert refine(x <= 0, (y <= 0)) == (x <= 0)
1298
+ assert not refine(x > 0, ~(x > 0))
1299
+ assert refine(x > 0, (x > 0))
1300
+ assert refine(x > 0, (0 < x)) is S.true
1301
+ assert refine(x > 0, (y > 0)) == (x > 0)
1302
+ assert not refine(x >= 0, ~(x >= 0))
1303
+ assert refine(x >= 0, (x >= 0))
1304
+ assert refine(x >= 0, (0 <= x)) is S.true
1305
+ assert refine(x >= 0, (y >= 0)) == (x >= 0)
1306
+ assert not refine(Eq(x, 0), ~(Eq(x, 0)))
1307
+ assert refine(Eq(x, 0), (Eq(x, 0)))
1308
+ assert refine(Eq(x, 0), (Eq(0, x))) is S.true
1309
+ assert refine(Eq(x, 0), (Eq(y, 0))) == Eq(x, 0)
1310
+ assert not refine(Ne(x, 0), ~(Ne(x, 0)))
1311
+ assert refine(Ne(x, 0), (Ne(0, x))) is S.true
1312
+ assert refine(Ne(x, 0), (Ne(x, 0)))
1313
+ assert refine(Ne(x, 0), (Ne(y, 0))) == (Ne(x, 0))
1314
+
1315
+ # boolean functions
1316
+ assert refine(And(x > 0, y > 0), (x > 0)) == (y > 0)
1317
+ assert refine(And(x > 0, y > 0), (x > 0) & (y > 0)) is S.true
1318
+
1319
+ # predicates
1320
+ assert refine(Q.positive(x), Q.positive(x)) is S.true
1321
+ assert refine(Q.positive(x), Q.negative(x)) is S.false
1322
+ assert refine(Q.positive(x), Q.real(x)) == Q.positive(x)
1323
+
1324
+
1325
+ def test_relational_threeterm_simplification_patterns_numerically():
1326
+ from sympy.core import Wild
1327
+ from sympy.logic.boolalg import _simplify_patterns_and3
1328
+ a = Wild('a')
1329
+ b = Wild('b')
1330
+ c = Wild('c')
1331
+ symb = [a, b, c]
1332
+ patternlists = [[And, _simplify_patterns_and3()]]
1333
+ valuelist = list(set(combinations(list(range(-2, 3)) * 3, 3)))
1334
+ # Skip combinations of +/-2 and 0, except for all 0
1335
+ valuelist = [v for v in valuelist if any(w % 2 for w in v) or not any(v)]
1336
+ for func, patternlist in patternlists:
1337
+ for pattern in patternlist:
1338
+ original = func(*pattern[0].args)
1339
+ simplified = pattern[1]
1340
+ for values in valuelist:
1341
+ sublist = dict(zip(symb, values))
1342
+ originalvalue = original.xreplace(sublist)
1343
+ simplifiedvalue = simplified.xreplace(sublist)
1344
+ assert originalvalue == simplifiedvalue, "Original: {}\nand" \
1345
+ " simplified: {}\ndo not evaluate to the same value for" \
1346
+ "{}".format(pattern[0], simplified, sublist)
1347
+
1348
+
1349
+ def test_issue_25451():
1350
+ x = Or(And(a, c), Eq(a, b))
1351
+ assert isinstance(x, Or)
1352
+ assert set(x.args) == {And(a, c), Eq(a, b)}
1353
+
1354
+
1355
+ def test_issue_26985():
1356
+ a, b, c, d = symbols('a b c d')
1357
+
1358
+ # Expression before applying to_anf
1359
+ x = Xor(c, And(a, b), And(a, c))
1360
+ y = Xor(a, b, And(a, c))
1361
+
1362
+ # Applying to_anf
1363
+ result = Xor(Xor(d, And(x, y)), And(x, y))
1364
+ result_anf = to_anf(Xor(to_anf(Xor(d, And(x, y))), And(x, y)))
1365
+
1366
+ assert result_anf == d
1367
+ assert result == d
.venv/lib/python3.13/site-packages/sympy/logic/tests/test_dimacs.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Various tests on satisfiability using dimacs cnf file syntax
2
+ You can find lots of cnf files in
3
+ ftp://dimacs.rutgers.edu/pub/challenge/satisfiability/benchmarks/cnf/
4
+ """
5
+
6
+ from sympy.logic.utilities.dimacs import load
7
+ from sympy.logic.algorithms.dpll import dpll_satisfiable
8
+
9
+
10
+ def test_f1():
11
+ assert bool(dpll_satisfiable(load(f1)))
12
+
13
+
14
+ def test_f2():
15
+ assert bool(dpll_satisfiable(load(f2)))
16
+
17
+
18
+ def test_f3():
19
+ assert bool(dpll_satisfiable(load(f3)))
20
+
21
+
22
+ def test_f4():
23
+ assert not bool(dpll_satisfiable(load(f4)))
24
+
25
+
26
+ def test_f5():
27
+ assert bool(dpll_satisfiable(load(f5)))
28
+
29
+ f1 = """c simple example
30
+ c Resolution: SATISFIABLE
31
+ c
32
+ p cnf 3 2
33
+ 1 -3 0
34
+ 2 3 -1 0
35
+ """
36
+
37
+
38
+ f2 = """c an example from Quinn's text, 16 variables and 18 clauses.
39
+ c Resolution: SATISFIABLE
40
+ c
41
+ p cnf 16 18
42
+ 1 2 0
43
+ -2 -4 0
44
+ 3 4 0
45
+ -4 -5 0
46
+ 5 -6 0
47
+ 6 -7 0
48
+ 6 7 0
49
+ 7 -16 0
50
+ 8 -9 0
51
+ -8 -14 0
52
+ 9 10 0
53
+ 9 -10 0
54
+ -10 -11 0
55
+ 10 12 0
56
+ 11 12 0
57
+ 13 14 0
58
+ 14 -15 0
59
+ 15 16 0
60
+ """
61
+
62
+ f3 = """c
63
+ p cnf 6 9
64
+ -1 0
65
+ -3 0
66
+ 2 -1 0
67
+ 2 -4 0
68
+ 5 -4 0
69
+ -1 -3 0
70
+ -4 -6 0
71
+ 1 3 -2 0
72
+ 4 6 -2 -5 0
73
+ """
74
+
75
+ f4 = """c
76
+ c file: hole6.cnf [http://people.sc.fsu.edu/~jburkardt/data/cnf/hole6.cnf]
77
+ c
78
+ c SOURCE: John Hooker (jh38+@andrew.cmu.edu)
79
+ c
80
+ c DESCRIPTION: Pigeon hole problem of placing n (for file 'holen.cnf') pigeons
81
+ c in n+1 holes without placing 2 pigeons in the same hole
82
+ c
83
+ c NOTE: Part of the collection at the Forschungsinstitut fuer
84
+ c anwendungsorientierte Wissensverarbeitung in Ulm Germany.
85
+ c
86
+ c NOTE: Not satisfiable
87
+ c
88
+ p cnf 42 133
89
+ -1 -7 0
90
+ -1 -13 0
91
+ -1 -19 0
92
+ -1 -25 0
93
+ -1 -31 0
94
+ -1 -37 0
95
+ -7 -13 0
96
+ -7 -19 0
97
+ -7 -25 0
98
+ -7 -31 0
99
+ -7 -37 0
100
+ -13 -19 0
101
+ -13 -25 0
102
+ -13 -31 0
103
+ -13 -37 0
104
+ -19 -25 0
105
+ -19 -31 0
106
+ -19 -37 0
107
+ -25 -31 0
108
+ -25 -37 0
109
+ -31 -37 0
110
+ -2 -8 0
111
+ -2 -14 0
112
+ -2 -20 0
113
+ -2 -26 0
114
+ -2 -32 0
115
+ -2 -38 0
116
+ -8 -14 0
117
+ -8 -20 0
118
+ -8 -26 0
119
+ -8 -32 0
120
+ -8 -38 0
121
+ -14 -20 0
122
+ -14 -26 0
123
+ -14 -32 0
124
+ -14 -38 0
125
+ -20 -26 0
126
+ -20 -32 0
127
+ -20 -38 0
128
+ -26 -32 0
129
+ -26 -38 0
130
+ -32 -38 0
131
+ -3 -9 0
132
+ -3 -15 0
133
+ -3 -21 0
134
+ -3 -27 0
135
+ -3 -33 0
136
+ -3 -39 0
137
+ -9 -15 0
138
+ -9 -21 0
139
+ -9 -27 0
140
+ -9 -33 0
141
+ -9 -39 0
142
+ -15 -21 0
143
+ -15 -27 0
144
+ -15 -33 0
145
+ -15 -39 0
146
+ -21 -27 0
147
+ -21 -33 0
148
+ -21 -39 0
149
+ -27 -33 0
150
+ -27 -39 0
151
+ -33 -39 0
152
+ -4 -10 0
153
+ -4 -16 0
154
+ -4 -22 0
155
+ -4 -28 0
156
+ -4 -34 0
157
+ -4 -40 0
158
+ -10 -16 0
159
+ -10 -22 0
160
+ -10 -28 0
161
+ -10 -34 0
162
+ -10 -40 0
163
+ -16 -22 0
164
+ -16 -28 0
165
+ -16 -34 0
166
+ -16 -40 0
167
+ -22 -28 0
168
+ -22 -34 0
169
+ -22 -40 0
170
+ -28 -34 0
171
+ -28 -40 0
172
+ -34 -40 0
173
+ -5 -11 0
174
+ -5 -17 0
175
+ -5 -23 0
176
+ -5 -29 0
177
+ -5 -35 0
178
+ -5 -41 0
179
+ -11 -17 0
180
+ -11 -23 0
181
+ -11 -29 0
182
+ -11 -35 0
183
+ -11 -41 0
184
+ -17 -23 0
185
+ -17 -29 0
186
+ -17 -35 0
187
+ -17 -41 0
188
+ -23 -29 0
189
+ -23 -35 0
190
+ -23 -41 0
191
+ -29 -35 0
192
+ -29 -41 0
193
+ -35 -41 0
194
+ -6 -12 0
195
+ -6 -18 0
196
+ -6 -24 0
197
+ -6 -30 0
198
+ -6 -36 0
199
+ -6 -42 0
200
+ -12 -18 0
201
+ -12 -24 0
202
+ -12 -30 0
203
+ -12 -36 0
204
+ -12 -42 0
205
+ -18 -24 0
206
+ -18 -30 0
207
+ -18 -36 0
208
+ -18 -42 0
209
+ -24 -30 0
210
+ -24 -36 0
211
+ -24 -42 0
212
+ -30 -36 0
213
+ -30 -42 0
214
+ -36 -42 0
215
+ 6 5 4 3 2 1 0
216
+ 12 11 10 9 8 7 0
217
+ 18 17 16 15 14 13 0
218
+ 24 23 22 21 20 19 0
219
+ 30 29 28 27 26 25 0
220
+ 36 35 34 33 32 31 0
221
+ 42 41 40 39 38 37 0
222
+ """
223
+
224
+ f5 = """c simple example requiring variable selection
225
+ c
226
+ c NOTE: Satisfiable
227
+ c
228
+ p cnf 5 5
229
+ 1 2 3 0
230
+ 1 -2 3 0
231
+ 4 5 -3 0
232
+ 1 -4 -3 0
233
+ -1 -5 0
234
+ """
.venv/lib/python3.13/site-packages/sympy/logic/tests/test_inference.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """For more tests on satisfiability, see test_dimacs"""
2
+
3
+ from sympy.assumptions.ask import Q
4
+ from sympy.core.symbol import symbols
5
+ from sympy.core.relational import Unequality
6
+ from sympy.logic.boolalg import And, Or, Implies, Equivalent, true, false
7
+ from sympy.logic.inference import literal_symbol, \
8
+ pl_true, satisfiable, valid, entails, PropKB
9
+ from sympy.logic.algorithms.dpll import dpll, dpll_satisfiable, \
10
+ find_pure_symbol, find_unit_clause, unit_propagate, \
11
+ find_pure_symbol_int_repr, find_unit_clause_int_repr, \
12
+ unit_propagate_int_repr
13
+ from sympy.logic.algorithms.dpll2 import dpll_satisfiable as dpll2_satisfiable
14
+
15
+ from sympy.logic.algorithms.z3_wrapper import z3_satisfiable
16
+ from sympy.assumptions.cnf import CNF, EncodedCNF
17
+ from sympy.logic.tests.test_lra_theory import make_random_problem
18
+ from sympy.core.random import randint
19
+
20
+ from sympy.testing.pytest import raises, skip
21
+ from sympy.external import import_module
22
+
23
+
24
+ def test_literal():
25
+ A, B = symbols('A,B')
26
+ assert literal_symbol(True) is True
27
+ assert literal_symbol(False) is False
28
+ assert literal_symbol(A) is A
29
+ assert literal_symbol(~A) is A
30
+
31
+
32
+ def test_find_pure_symbol():
33
+ A, B, C = symbols('A,B,C')
34
+ assert find_pure_symbol([A], [A]) == (A, True)
35
+ assert find_pure_symbol([A, B], [~A | B, ~B | A]) == (None, None)
36
+ assert find_pure_symbol([A, B, C], [ A | ~B, ~B | ~C, C | A]) == (A, True)
37
+ assert find_pure_symbol([A, B, C], [~A | B, B | ~C, C | A]) == (B, True)
38
+ assert find_pure_symbol([A, B, C], [~A | ~B, ~B | ~C, C | A]) == (B, False)
39
+ assert find_pure_symbol(
40
+ [A, B, C], [~A | B, ~B | ~C, C | A]) == (None, None)
41
+
42
+
43
+ def test_find_pure_symbol_int_repr():
44
+ assert find_pure_symbol_int_repr([1], [{1}]) == (1, True)
45
+ assert find_pure_symbol_int_repr([1, 2],
46
+ [{-1, 2}, {-2, 1}]) == (None, None)
47
+ assert find_pure_symbol_int_repr([1, 2, 3],
48
+ [{1, -2}, {-2, -3}, {3, 1}]) == (1, True)
49
+ assert find_pure_symbol_int_repr([1, 2, 3],
50
+ [{-1, 2}, {2, -3}, {3, 1}]) == (2, True)
51
+ assert find_pure_symbol_int_repr([1, 2, 3],
52
+ [{-1, -2}, {-2, -3}, {3, 1}]) == (2, False)
53
+ assert find_pure_symbol_int_repr([1, 2, 3],
54
+ [{-1, 2}, {-2, -3}, {3, 1}]) == (None, None)
55
+
56
+
57
+ def test_unit_clause():
58
+ A, B, C = symbols('A,B,C')
59
+ assert find_unit_clause([A], {}) == (A, True)
60
+ assert find_unit_clause([A, ~A], {}) == (A, True) # Wrong ??
61
+ assert find_unit_clause([A | B], {A: True}) == (B, True)
62
+ assert find_unit_clause([A | B], {B: True}) == (A, True)
63
+ assert find_unit_clause(
64
+ [A | B | C, B | ~C, A | ~B], {A: True}) == (B, False)
65
+ assert find_unit_clause([A | B | C, B | ~C, A | B], {A: True}) == (B, True)
66
+ assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True)
67
+
68
+
69
+ def test_unit_clause_int_repr():
70
+ assert find_unit_clause_int_repr(map(set, [[1]]), {}) == (1, True)
71
+ assert find_unit_clause_int_repr(map(set, [[1], [-1]]), {}) == (1, True)
72
+ assert find_unit_clause_int_repr([{1, 2}], {1: True}) == (2, True)
73
+ assert find_unit_clause_int_repr([{1, 2}], {2: True}) == (1, True)
74
+ assert find_unit_clause_int_repr(map(set,
75
+ [[1, 2, 3], [2, -3], [1, -2]]), {1: True}) == (2, False)
76
+ assert find_unit_clause_int_repr(map(set,
77
+ [[1, 2, 3], [3, -3], [1, 2]]), {1: True}) == (2, True)
78
+
79
+ A, B, C = symbols('A,B,C')
80
+ assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True)
81
+
82
+
83
+ def test_unit_propagate():
84
+ A, B, C = symbols('A,B,C')
85
+ assert unit_propagate([A | B], A) == []
86
+ assert unit_propagate([A | B, ~A | C, ~C | B, A], A) == [C, ~C | B, A]
87
+
88
+
89
+ def test_unit_propagate_int_repr():
90
+ assert unit_propagate_int_repr([{1, 2}], 1) == []
91
+ assert unit_propagate_int_repr(map(set,
92
+ [[1, 2], [-1, 3], [-3, 2], [1]]), 1) == [{3}, {-3, 2}]
93
+
94
+
95
+ def test_dpll():
96
+ """This is also tested in test_dimacs"""
97
+ A, B, C = symbols('A,B,C')
98
+ assert dpll([A | B], [A, B], {A: True, B: True}) == {A: True, B: True}
99
+
100
+
101
+ def test_dpll_satisfiable():
102
+ A, B, C = symbols('A,B,C')
103
+ assert dpll_satisfiable( A & ~A ) is False
104
+ assert dpll_satisfiable( A & ~B ) == {A: True, B: False}
105
+ assert dpll_satisfiable(
106
+ A | B ) in ({A: True}, {B: True}, {A: True, B: True})
107
+ assert dpll_satisfiable(
108
+ (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
109
+ assert dpll_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False},
110
+ {A: True, C: True}, {B: True, C: True})
111
+ assert dpll_satisfiable( A & B & C ) == {A: True, B: True, C: True}
112
+ assert dpll_satisfiable( (A | B) & (A >> B) ) == {B: True}
113
+ assert dpll_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
114
+ assert dpll_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
115
+
116
+
117
+ def test_dpll2_satisfiable():
118
+ A, B, C = symbols('A,B,C')
119
+ assert dpll2_satisfiable( A & ~A ) is False
120
+ assert dpll2_satisfiable( A & ~B ) == {A: True, B: False}
121
+ assert dpll2_satisfiable(
122
+ A | B ) in ({A: True}, {B: True}, {A: True, B: True})
123
+ assert dpll2_satisfiable(
124
+ (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
125
+ assert dpll2_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
126
+ {A: True, B: True, C: True})
127
+ assert dpll2_satisfiable( A & B & C ) == {A: True, B: True, C: True}
128
+ assert dpll2_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
129
+ {B: True, A: True})
130
+ assert dpll2_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
131
+ assert dpll2_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
132
+
133
+
134
+ def test_minisat22_satisfiable():
135
+ A, B, C = symbols('A,B,C')
136
+ minisat22_satisfiable = lambda expr: satisfiable(expr, algorithm="minisat22")
137
+ assert minisat22_satisfiable( A & ~A ) is False
138
+ assert minisat22_satisfiable( A & ~B ) == {A: True, B: False}
139
+ assert minisat22_satisfiable(
140
+ A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False})
141
+ assert minisat22_satisfiable(
142
+ (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
143
+ assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
144
+ {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False})
145
+ assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True}
146
+ assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
147
+ {B: True, A: True})
148
+ assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
149
+ assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
150
+
151
+ def test_minisat22_minimal_satisfiable():
152
+ A, B, C = symbols('A,B,C')
153
+ minisat22_satisfiable = lambda expr, minimal=True: satisfiable(expr, algorithm="minisat22", minimal=True)
154
+ assert minisat22_satisfiable( A & ~A ) is False
155
+ assert minisat22_satisfiable( A & ~B ) == {A: True, B: False}
156
+ assert minisat22_satisfiable(
157
+ A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False})
158
+ assert minisat22_satisfiable(
159
+ (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
160
+ assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
161
+ {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False})
162
+ assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True}
163
+ assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
164
+ {B: True, A: True})
165
+ assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
166
+ assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
167
+ g = satisfiable((A | B | C),algorithm="minisat22",minimal=True,all_models=True)
168
+ sol = next(g)
169
+ first_solution = {key for key, value in sol.items() if value}
170
+ sol=next(g)
171
+ second_solution = {key for key, value in sol.items() if value}
172
+ sol=next(g)
173
+ third_solution = {key for key, value in sol.items() if value}
174
+ assert not first_solution <= second_solution
175
+ assert not second_solution <= third_solution
176
+ assert not first_solution <= third_solution
177
+
178
+ def test_satisfiable():
179
+ A, B, C = symbols('A,B,C')
180
+ assert satisfiable(A & (A >> B) & ~B) is False
181
+
182
+
183
+ def test_valid():
184
+ A, B, C = symbols('A,B,C')
185
+ assert valid(A >> (B >> A)) is True
186
+ assert valid((A >> (B >> C)) >> ((A >> B) >> (A >> C))) is True
187
+ assert valid((~B >> ~A) >> (A >> B)) is True
188
+ assert valid(A | B | C) is False
189
+ assert valid(A >> B) is False
190
+
191
+
192
+ def test_pl_true():
193
+ A, B, C = symbols('A,B,C')
194
+ assert pl_true(True) is True
195
+ assert pl_true( A & B, {A: True, B: True}) is True
196
+ assert pl_true( A | B, {A: True}) is True
197
+ assert pl_true( A | B, {B: True}) is True
198
+ assert pl_true( A | B, {A: None, B: True}) is True
199
+ assert pl_true( A >> B, {A: False}) is True
200
+ assert pl_true( A | B | ~C, {A: False, B: True, C: True}) is True
201
+ assert pl_true(Equivalent(A, B), {A: False, B: False}) is True
202
+
203
+ # test for false
204
+ assert pl_true(False) is False
205
+ assert pl_true( A & B, {A: False, B: False}) is False
206
+ assert pl_true( A & B, {A: False}) is False
207
+ assert pl_true( A & B, {B: False}) is False
208
+ assert pl_true( A | B, {A: False, B: False}) is False
209
+
210
+ #test for None
211
+ assert pl_true(B, {B: None}) is None
212
+ assert pl_true( A & B, {A: True, B: None}) is None
213
+ assert pl_true( A >> B, {A: True, B: None}) is None
214
+ assert pl_true(Equivalent(A, B), {A: None}) is None
215
+ assert pl_true(Equivalent(A, B), {A: True, B: None}) is None
216
+
217
+ # Test for deep
218
+ assert pl_true(A | B, {A: False}, deep=True) is None
219
+ assert pl_true(~A & ~B, {A: False}, deep=True) is None
220
+ assert pl_true(A | B, {A: False, B: False}, deep=True) is False
221
+ assert pl_true(A & B & (~A | ~B), {A: True}, deep=True) is False
222
+ assert pl_true((C >> A) >> (B >> A), {C: True}, deep=True) is True
223
+
224
+
225
+ def test_pl_true_wrong_input():
226
+ from sympy.core.numbers import pi
227
+ raises(ValueError, lambda: pl_true('John Cleese'))
228
+ raises(ValueError, lambda: pl_true(42 + pi + pi ** 2))
229
+ raises(ValueError, lambda: pl_true(42))
230
+
231
+
232
+ def test_entails():
233
+ A, B, C = symbols('A, B, C')
234
+ assert entails(A, [A >> B, ~B]) is False
235
+ assert entails(B, [Equivalent(A, B), A]) is True
236
+ assert entails((A >> B) >> (~A >> ~B)) is False
237
+ assert entails((A >> B) >> (~B >> ~A)) is True
238
+
239
+
240
+ def test_PropKB():
241
+ A, B, C = symbols('A,B,C')
242
+ kb = PropKB()
243
+ assert kb.ask(A >> B) is False
244
+ assert kb.ask(A >> (B >> A)) is True
245
+ kb.tell(A >> B)
246
+ kb.tell(B >> C)
247
+ assert kb.ask(A) is False
248
+ assert kb.ask(B) is False
249
+ assert kb.ask(C) is False
250
+ assert kb.ask(~A) is False
251
+ assert kb.ask(~B) is False
252
+ assert kb.ask(~C) is False
253
+ assert kb.ask(A >> C) is True
254
+ kb.tell(A)
255
+ assert kb.ask(A) is True
256
+ assert kb.ask(B) is True
257
+ assert kb.ask(C) is True
258
+ assert kb.ask(~C) is False
259
+ kb.retract(A)
260
+ assert kb.ask(C) is False
261
+
262
+
263
+ def test_propKB_tolerant():
264
+ """"tolerant to bad input"""
265
+ kb = PropKB()
266
+ A, B, C = symbols('A,B,C')
267
+ assert kb.ask(B) is False
268
+
269
+ def test_satisfiable_non_symbols():
270
+ x, y = symbols('x y')
271
+ assumptions = Q.zero(x*y)
272
+ facts = Implies(Q.zero(x*y), Q.zero(x) | Q.zero(y))
273
+ query = ~Q.zero(x) & ~Q.zero(y)
274
+ refutations = [
275
+ {Q.zero(x): True, Q.zero(x*y): True},
276
+ {Q.zero(y): True, Q.zero(x*y): True},
277
+ {Q.zero(x): True, Q.zero(y): True, Q.zero(x*y): True},
278
+ {Q.zero(x): True, Q.zero(y): False, Q.zero(x*y): True},
279
+ {Q.zero(x): False, Q.zero(y): True, Q.zero(x*y): True}]
280
+ assert not satisfiable(And(assumptions, facts, query), algorithm='dpll')
281
+ assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll') in refutations
282
+ assert not satisfiable(And(assumptions, facts, query), algorithm='dpll2')
283
+ assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll2') in refutations
284
+
285
+ def test_satisfiable_bool():
286
+ from sympy.core.singleton import S
287
+ assert satisfiable(true) == {true: true}
288
+ assert satisfiable(S.true) == {true: true}
289
+ assert satisfiable(false) is False
290
+ assert satisfiable(S.false) is False
291
+
292
+
293
+ def test_satisfiable_all_models():
294
+ from sympy.abc import A, B
295
+ assert next(satisfiable(False, all_models=True)) is False
296
+ assert list(satisfiable((A >> ~A) & A, all_models=True)) == [False]
297
+ assert list(satisfiable(True, all_models=True)) == [{true: true}]
298
+
299
+ models = [{A: True, B: False}, {A: False, B: True}]
300
+ result = satisfiable(A ^ B, all_models=True)
301
+ models.remove(next(result))
302
+ models.remove(next(result))
303
+ raises(StopIteration, lambda: next(result))
304
+ assert not models
305
+
306
+ assert list(satisfiable(Equivalent(A, B), all_models=True)) == \
307
+ [{A: False, B: False}, {A: True, B: True}]
308
+
309
+ models = [{A: False, B: False}, {A: False, B: True}, {A: True, B: True}]
310
+ for model in satisfiable(A >> B, all_models=True):
311
+ models.remove(model)
312
+ assert not models
313
+
314
+ # This is a santiy test to check that only the required number
315
+ # of solutions are generated. The expr below has 2**100 - 1 models
316
+ # which would time out the test if all are generated at once.
317
+ from sympy.utilities.iterables import numbered_symbols
318
+ from sympy.logic.boolalg import Or
319
+ sym = numbered_symbols()
320
+ X = [next(sym) for i in range(100)]
321
+ result = satisfiable(Or(*X), all_models=True)
322
+ for i in range(10):
323
+ assert next(result)
324
+
325
+
326
+ def test_z3():
327
+ z3 = import_module("z3")
328
+
329
+ if not z3:
330
+ skip("z3 not installed.")
331
+ A, B, C = symbols('A,B,C')
332
+ x, y, z = symbols('x,y,z')
333
+ assert z3_satisfiable((x >= 2) & (x < 1)) is False
334
+ assert z3_satisfiable( A & ~A ) is False
335
+
336
+ model = z3_satisfiable(A & (~A | B | C))
337
+ assert bool(model) is True
338
+ assert model[A] is True
339
+
340
+ # test nonlinear function
341
+ assert z3_satisfiable((x ** 2 >= 2) & (x < 1) & (x > -1)) is False
342
+
343
+
344
+ def test_z3_vs_lra_dpll2():
345
+ z3 = import_module("z3")
346
+ if z3 is None:
347
+ skip("z3 not installed.")
348
+
349
+ def boolean_formula_to_encoded_cnf(bf):
350
+ cnf = CNF.from_prop(bf)
351
+ enc = EncodedCNF()
352
+ enc.from_cnf(cnf)
353
+ return enc
354
+
355
+ def make_random_cnf(num_clauses=5, num_constraints=10, num_var=2):
356
+ assert num_clauses <= num_constraints
357
+ constraints = make_random_problem(num_variables=num_var, num_constraints=num_constraints, rational=False)
358
+ clauses = [[cons] for cons in constraints[:num_clauses]]
359
+ for cons in constraints[num_clauses:]:
360
+ if isinstance(cons, Unequality):
361
+ cons = ~cons
362
+ i = randint(0, num_clauses-1)
363
+ clauses[i].append(cons)
364
+
365
+ clauses = [Or(*clause) for clause in clauses]
366
+ cnf = And(*clauses)
367
+ return boolean_formula_to_encoded_cnf(cnf)
368
+
369
+ lra_dpll2_satisfiable = lambda x: dpll2_satisfiable(x, use_lra_theory=True)
370
+
371
+ for _ in range(50):
372
+ cnf = make_random_cnf(num_clauses=10, num_constraints=15, num_var=2)
373
+
374
+ try:
375
+ z3_sat = z3_satisfiable(cnf)
376
+ except z3.z3types.Z3Exception:
377
+ continue
378
+
379
+ lra_dpll2_sat = lra_dpll2_satisfiable(cnf) is not False
380
+
381
+ assert z3_sat == lra_dpll2_sat
382
+
383
+ def test_issue_27733():
384
+ x, y = symbols('x,y')
385
+ clauses = [[1, -3, -2], [5, 7, -8, -6, -4], [-10, -9, 10, 11, -4], [-12, 13, 14], [-10, 9, -6, 11, -4],
386
+ [16, -15, 18, -19, -17], [11, -6, 10, -9], [9, 11, -10, -9], [2, -3, -1], [-13, 12], [-15, 3, -17],
387
+ [-16, -15, 19, -17], [-6, -9, 10, 11, -4], [20, -1, -2], [-23, -22, -21], [10, 11, -10, -9],
388
+ [9, 11, -4, -10], [24, -6, -4], [-14, 12], [-10, -9, 9, -6, 11], [25, -27, -26], [-15, 19, -18, -17],
389
+ [5, 8, -7, -6, -4], [-30, -29, 28], [12], [14]]
390
+
391
+ encoding = {Q.gt(y, i): i for i in range(1, 31) if i != 11 and i != 12}
392
+ encoding[Q.gt(x, 0)] = 11
393
+ encoding[Q.lt(x, 0)] = 12
394
+
395
+ cnf = EncodedCNF(clauses, encoding)
396
+ assert satisfiable(cnf, use_lra_theory=True) is False
.venv/lib/python3.13/site-packages/sympy/logic/tests/test_lra_theory.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.numbers import Rational, I, oo
2
+ from sympy.core.relational import Eq
3
+ from sympy.core.symbol import symbols
4
+ from sympy.core.singleton import S
5
+ from sympy.matrices.dense import Matrix
6
+ from sympy.matrices.dense import randMatrix
7
+ from sympy.assumptions.ask import Q
8
+ from sympy.logic.boolalg import And
9
+ from sympy.abc import x, y, z
10
+ from sympy.assumptions.cnf import CNF, EncodedCNF
11
+ from sympy.functions.elementary.trigonometric import cos
12
+ from sympy.external import import_module
13
+
14
+ from sympy.logic.algorithms.lra_theory import LRASolver, UnhandledInput, LRARational, HANDLE_NEGATION
15
+ from sympy.core.random import random, choice, randint
16
+ from sympy.core.sympify import sympify
17
+ from sympy.ntheory.generate import randprime
18
+ from sympy.core.relational import StrictLessThan, StrictGreaterThan
19
+ import itertools
20
+
21
+ from sympy.testing.pytest import raises, XFAIL, skip
22
+
23
+ def make_random_problem(num_variables=2, num_constraints=2, sparsity=.1, rational=True,
24
+ disable_strict = False, disable_nonstrict=False, disable_equality=False):
25
+ def rand(sparsity=sparsity):
26
+ if random() < sparsity:
27
+ return sympify(0)
28
+ if rational:
29
+ int1, int2 = [randprime(0, 50) for _ in range(2)]
30
+ return Rational(int1, int2) * choice([-1, 1])
31
+ else:
32
+ return randint(1, 10) * choice([-1, 1])
33
+
34
+ variables = symbols('x1:%s' % (num_variables + 1))
35
+ constraints = []
36
+ for _ in range(num_constraints):
37
+ lhs, rhs = sum(rand() * x for x in variables), rand(sparsity=0) # sparsity=0 bc of bug with smtlib_code
38
+ options = []
39
+ if not disable_equality:
40
+ options += [Eq(lhs, rhs)]
41
+ if not disable_nonstrict:
42
+ options += [lhs <= rhs, lhs >= rhs]
43
+ if not disable_strict:
44
+ options += [lhs < rhs, lhs > rhs]
45
+
46
+ constraints.append(choice(options))
47
+
48
+ return constraints
49
+
50
+ def check_if_satisfiable_with_z3(constraints):
51
+ from sympy.external.importtools import import_module
52
+ from sympy.printing.smtlib import smtlib_code
53
+ from sympy.logic.boolalg import And
54
+ boolean_formula = And(*constraints)
55
+ z3 = import_module("z3")
56
+ if z3:
57
+ smtlib_string = smtlib_code(boolean_formula)
58
+ s = z3.Solver()
59
+ s.from_string(smtlib_string)
60
+ res = str(s.check())
61
+ if res == 'sat':
62
+ return True
63
+ elif res == 'unsat':
64
+ return False
65
+ else:
66
+ raise ValueError(f"z3 was not able to check the satisfiability of {boolean_formula}")
67
+
68
+ def find_rational_assignment(constr, assignment, iter=20):
69
+ eps = sympify(1)
70
+
71
+ for _ in range(iter):
72
+ assign = {key: val[0] + val[1]*eps for key, val in assignment.items()}
73
+ try:
74
+ for cons in constr:
75
+ assert cons.subs(assign) == True
76
+ return assign
77
+ except AssertionError:
78
+ eps = eps/2
79
+
80
+ return None
81
+
82
+ def boolean_formula_to_encoded_cnf(bf):
83
+ cnf = CNF.from_prop(bf)
84
+ enc = EncodedCNF()
85
+ enc.from_cnf(cnf)
86
+ return enc
87
+
88
+
89
+ def test_from_encoded_cnf():
90
+ s1, s2 = symbols("s1 s2")
91
+
92
+ # Test preprocessing
93
+ # Example is from section 3 of paper.
94
+ phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6)) & (Eq(x + y, 2) | (x + 2 * y - z > 4))
95
+ enc = boolean_formula_to_encoded_cnf(phi)
96
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
97
+ assert lra.A.shape == (2, 5)
98
+ assert str(lra.slack) == '[_s1, _s2]'
99
+ assert str(lra.nonslack) == '[x, y, z]'
100
+ assert lra.A == Matrix([[ 1, 1, 0, -1, 0],
101
+ [-1, -2, 1, 0, -1]])
102
+ assert {(str(b.var), b.bound, b.upper, b.equality, b.strict) for b in lra.enc_to_boundary.values()} == {('_s1', 2, None, True, False),
103
+ ('_s1', 2, True, False, False),
104
+ ('_s2', -4, True, False, True),
105
+ ('_s2', -6, True, False, False),
106
+ ('x', 0, False, False, False)}
107
+
108
+
109
+ def test_problem():
110
+ from sympy.logic.algorithms.lra_theory import LRASolver
111
+ from sympy.assumptions.cnf import CNF, EncodedCNF
112
+ cons = [-2 * x - 2 * y >= 7, -9 * y >= 7, -6 * y >= 5]
113
+ cnf = CNF().from_prop(And(*cons))
114
+ enc = EncodedCNF()
115
+ enc.from_cnf(cnf)
116
+ lra, _ = LRASolver.from_encoded_cnf(enc)
117
+ lra.assert_lit(1)
118
+ lra.assert_lit(2)
119
+ lra.assert_lit(3)
120
+ is_sat, assignment = lra.check()
121
+ assert is_sat is True
122
+
123
+
124
+ def test_random_problems():
125
+ z3 = import_module("z3")
126
+ if z3 is None:
127
+ skip("z3 is not installed")
128
+
129
+ special_cases = []; x1, x2, x3 = symbols("x1 x2 x3")
130
+ special_cases.append([x1 - 3 * x2 <= -5, 6 * x1 + 4 * x2 <= 0, -7 * x1 + 3 * x2 <= 3])
131
+ special_cases.append([-3 * x1 >= 3, Eq(4 * x1, -1)])
132
+ special_cases.append([-4 * x1 < 4, 6 * x1 <= -6])
133
+ special_cases.append([-3 * x2 >= 7, 6 * x1 <= -5, -3 * x2 <= -4])
134
+ special_cases.append([x + y >= 2, x + y <= 1])
135
+ special_cases.append([x >= 0, x + y <= 2, x + 2 * y - z >= 6]) # from paper example
136
+ special_cases.append([-2 * x1 - 2 * x2 >= 7, -9 * x1 >= 7, -6 * x1 >= 5])
137
+ special_cases.append([2 * x1 > -3, -9 * x1 < -6, 9 * x1 <= 6])
138
+ special_cases.append([-2*x1 < -4, 9*x1 > -9])
139
+ special_cases.append([-6*x1 >= -1, -8*x1 + x2 >= 5, -8*x1 + 7*x2 < 4, x1 > 7])
140
+ special_cases.append([Eq(x1, 2), Eq(5*x1, -2), Eq(-7*x2, -6), Eq(9*x1 + 10*x2, 9)])
141
+ special_cases.append([Eq(3*x1, 6), Eq(x1 - 8*x2, -9), Eq(-7*x1 + 5*x2, 3), Eq(3*x2, 7)])
142
+ special_cases.append([-4*x1 < 4, 6*x1 <= -6])
143
+ special_cases.append([-3*x1 + 8*x2 >= -8, -10*x2 > 9, 8*x1 - 4*x2 < 8, 10*x1 - 9*x2 >= -9])
144
+ special_cases.append([x1 + 5*x2 >= -6, 9*x1 - 3*x2 >= -9, 6*x1 + 6*x2 < -10, -3*x1 + 3*x2 < -7])
145
+ special_cases.append([-9*x1 < 7, -5*x1 - 7*x2 < -1, 3*x1 + 7*x2 > 1, -6*x1 - 6*x2 > 9])
146
+ special_cases.append([9*x1 - 6*x2 >= -7, 9*x1 + 4*x2 < -8, -7*x2 <= 1, 10*x2 <= -7])
147
+
148
+ feasible_count = 0
149
+ for i in range(50):
150
+ if i % 8 == 0:
151
+ constraints = make_random_problem(num_variables=1, num_constraints=2, rational=False)
152
+ elif i % 8 == 1:
153
+ constraints = make_random_problem(num_variables=2, num_constraints=4, rational=False, disable_equality=True,
154
+ disable_nonstrict=True)
155
+ elif i % 8 == 2:
156
+ constraints = make_random_problem(num_variables=2, num_constraints=4, rational=False, disable_strict=True)
157
+ elif i % 8 == 3:
158
+ constraints = make_random_problem(num_variables=3, num_constraints=12, rational=False)
159
+ else:
160
+ constraints = make_random_problem(num_variables=3, num_constraints=6, rational=False)
161
+
162
+ if i < len(special_cases):
163
+ constraints = special_cases[i]
164
+
165
+ if False in constraints or True in constraints:
166
+ continue
167
+
168
+ phi = And(*constraints)
169
+ if phi == False:
170
+ continue
171
+ cnf = CNF.from_prop(phi); enc = EncodedCNF()
172
+ enc.from_cnf(cnf)
173
+ assert all(0 not in clause for clause in enc.data)
174
+
175
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
176
+ s_subs = lra.s_subs
177
+
178
+ lra.run_checks = True
179
+ s_subs_rev = {value: key for key, value in s_subs.items()}
180
+ lits = {lit for clause in enc.data for lit in clause}
181
+
182
+ bounds = [(lra.enc_to_boundary[l], l) for l in lits if l in lra.enc_to_boundary]
183
+ bounds = sorted(bounds, key=lambda x: (str(x[0].var), x[0].bound, str(x[0].upper))) # to remove nondeterminism
184
+
185
+ for b, l in bounds:
186
+ if lra.result and lra.result[0] == False:
187
+ break
188
+ lra.assert_lit(l)
189
+
190
+ feasible = lra.check()
191
+
192
+ if feasible[0] == True:
193
+ feasible_count += 1
194
+ assert check_if_satisfiable_with_z3(constraints) is True
195
+ cons_funcs = [cons.func for cons in constraints]
196
+ assignment = feasible[1]
197
+ assignment = {key.var : value for key, value in assignment.items()}
198
+ if not (StrictLessThan in cons_funcs or StrictGreaterThan in cons_funcs):
199
+ assignment = {key: value[0] for key, value in assignment.items()}
200
+ for cons in constraints:
201
+ assert cons.subs(assignment) == True
202
+
203
+ else:
204
+ rat_assignment = find_rational_assignment(constraints, assignment)
205
+ assert rat_assignment is not None
206
+ else:
207
+ assert check_if_satisfiable_with_z3(constraints) is False
208
+
209
+ conflict = feasible[1]
210
+ assert len(conflict) >= 2
211
+ conflict = {lra.enc_to_boundary[-l].get_inequality() for l in conflict}
212
+ conflict = {clause.subs(s_subs_rev) for clause in conflict}
213
+ assert check_if_satisfiable_with_z3(conflict) is False
214
+
215
+ # check that conflict clause is probably minimal
216
+ for subset in itertools.combinations(conflict, len(conflict)-1):
217
+ assert check_if_satisfiable_with_z3(subset) is True
218
+
219
+
220
+ @XFAIL
221
+ def test_pos_neg_zero():
222
+ bf = Q.positive(x) & Q.negative(x) & Q.zero(y)
223
+ enc = boolean_formula_to_encoded_cnf(bf)
224
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
225
+ for lit in enc.encoding.values():
226
+ if lra.assert_lit(lit) is not None:
227
+ break
228
+ assert len(lra.enc_to_boundary) == 3
229
+ assert lra.check()[0] == False
230
+
231
+ bf = Q.positive(x) & Q.lt(x, -1)
232
+ enc = boolean_formula_to_encoded_cnf(bf)
233
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
234
+ for lit in enc.encoding.values():
235
+ if lra.assert_lit(lit) is not None:
236
+ break
237
+ assert len(lra.enc_to_boundary) == 2
238
+ assert lra.check()[0] == False
239
+
240
+ bf = Q.positive(x) & Q.zero(x)
241
+ enc = boolean_formula_to_encoded_cnf(bf)
242
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
243
+ for lit in enc.encoding.values():
244
+ if lra.assert_lit(lit) is not None:
245
+ break
246
+ assert len(lra.enc_to_boundary) == 2
247
+ assert lra.check()[0] == False
248
+
249
+ bf = Q.positive(x) & Q.zero(y)
250
+ enc = boolean_formula_to_encoded_cnf(bf)
251
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
252
+ for lit in enc.encoding.values():
253
+ if lra.assert_lit(lit) is not None:
254
+ break
255
+ assert len(lra.enc_to_boundary) == 2
256
+ assert lra.check()[0] == True
257
+
258
+
259
+ @XFAIL
260
+ def test_pos_neg_infinite():
261
+ bf = Q.positive_infinite(x) & Q.lt(x, 10000000) & Q.positive_infinite(y)
262
+ enc = boolean_formula_to_encoded_cnf(bf)
263
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
264
+ for lit in enc.encoding.values():
265
+ if lra.assert_lit(lit) is not None:
266
+ break
267
+ assert len(lra.enc_to_boundary) == 3
268
+ assert lra.check()[0] == False
269
+
270
+ bf = Q.positive_infinite(x) & Q.gt(x, 10000000) & Q.positive_infinite(y)
271
+ enc = boolean_formula_to_encoded_cnf(bf)
272
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
273
+ for lit in enc.encoding.values():
274
+ if lra.assert_lit(lit) is not None:
275
+ break
276
+ assert len(lra.enc_to_boundary) == 3
277
+ assert lra.check()[0] == True
278
+
279
+ bf = Q.positive_infinite(x) & Q.negative_infinite(x)
280
+ enc = boolean_formula_to_encoded_cnf(bf)
281
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
282
+ for lit in enc.encoding.values():
283
+ if lra.assert_lit(lit) is not None:
284
+ break
285
+ assert len(lra.enc_to_boundary) == 2
286
+ assert lra.check()[0] == False
287
+
288
+
289
+ def test_binrel_evaluation():
290
+ bf = Q.gt(3, 2)
291
+ enc = boolean_formula_to_encoded_cnf(bf)
292
+ lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True)
293
+ assert len(lra.enc_to_boundary) == 0
294
+ assert conflicts == [[1]]
295
+
296
+ bf = Q.lt(3, 2)
297
+ enc = boolean_formula_to_encoded_cnf(bf)
298
+ lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True)
299
+ assert len(lra.enc_to_boundary) == 0
300
+ assert conflicts == [[-1]]
301
+
302
+
303
+ def test_negation():
304
+ assert HANDLE_NEGATION is True
305
+ bf = Q.gt(x, 1) & ~Q.gt(x, 0)
306
+ enc = boolean_formula_to_encoded_cnf(bf)
307
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
308
+ for clause in enc.data:
309
+ for lit in clause:
310
+ lra.assert_lit(lit)
311
+ assert len(lra.enc_to_boundary) == 2
312
+ assert lra.check()[0] == False
313
+ assert sorted(lra.check()[1]) in [[-1, 2], [-2, 1]]
314
+
315
+ bf = ~Q.gt(x, 1) & ~Q.lt(x, 0)
316
+ enc = boolean_formula_to_encoded_cnf(bf)
317
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
318
+ for clause in enc.data:
319
+ for lit in clause:
320
+ lra.assert_lit(lit)
321
+ assert len(lra.enc_to_boundary) == 2
322
+ assert lra.check()[0] == True
323
+
324
+ bf = ~Q.gt(x, 0) & ~Q.lt(x, 1)
325
+ enc = boolean_formula_to_encoded_cnf(bf)
326
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
327
+ for clause in enc.data:
328
+ for lit in clause:
329
+ lra.assert_lit(lit)
330
+ assert len(lra.enc_to_boundary) == 2
331
+ assert lra.check()[0] == False
332
+
333
+ bf = ~Q.gt(x, 0) & ~Q.le(x, 0)
334
+ enc = boolean_formula_to_encoded_cnf(bf)
335
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
336
+ for clause in enc.data:
337
+ for lit in clause:
338
+ lra.assert_lit(lit)
339
+ assert len(lra.enc_to_boundary) == 2
340
+ assert lra.check()[0] == False
341
+
342
+ bf = ~Q.le(x+y, 2) & ~Q.ge(x-y, 2) & ~Q.ge(y, 0)
343
+ enc = boolean_formula_to_encoded_cnf(bf)
344
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
345
+ for clause in enc.data:
346
+ for lit in clause:
347
+ lra.assert_lit(lit)
348
+ assert len(lra.enc_to_boundary) == 3
349
+ assert lra.check()[0] == False
350
+ assert len(lra.check()[1]) == 3
351
+ assert all(i > 0 for i in lra.check()[1])
352
+
353
+
354
+ def test_unhandled_input():
355
+ nan = S.NaN
356
+ bf = Q.gt(3, nan) & Q.gt(x, nan)
357
+ enc = boolean_formula_to_encoded_cnf(bf)
358
+ raises(ValueError, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True))
359
+
360
+ bf = Q.gt(3, I) & Q.gt(x, I)
361
+ enc = boolean_formula_to_encoded_cnf(bf)
362
+ raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True))
363
+
364
+ bf = Q.gt(3, float("inf")) & Q.gt(x, float("inf"))
365
+ enc = boolean_formula_to_encoded_cnf(bf)
366
+ raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True))
367
+
368
+ bf = Q.gt(3, oo) & Q.gt(x, oo)
369
+ enc = boolean_formula_to_encoded_cnf(bf)
370
+ raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True))
371
+
372
+ # test non-linearity
373
+ bf = Q.gt(x**2 + x, 2)
374
+ enc = boolean_formula_to_encoded_cnf(bf)
375
+ raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True))
376
+
377
+ bf = Q.gt(cos(x) + x, 2)
378
+ enc = boolean_formula_to_encoded_cnf(bf)
379
+ raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True))
380
+
381
+ @XFAIL
382
+ def test_infinite_strict_inequalities():
383
+ # Extensive testing of the interaction between strict inequalities
384
+ # and constraints containing infinity is needed because
385
+ # the paper's rule for strict inequalities don't work when
386
+ # infinite numbers are allowed. Using the paper's rules you
387
+ # can end up with situations where oo + delta > oo is considered
388
+ # True when oo + delta should be equal to oo.
389
+ # See https://math.stackexchange.com/questions/4757069/can-this-method-of-converting-strict-inequalities-to-equisatisfiable-nonstrict-i
390
+ bf = (-x - y >= -float("inf")) & (x > 0) & (y >= float("inf"))
391
+ enc = boolean_formula_to_encoded_cnf(bf)
392
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
393
+ for lit in sorted(enc.encoding.values()):
394
+ if lra.assert_lit(lit) is not None:
395
+ break
396
+ assert len(lra.enc_to_boundary) == 3
397
+ assert lra.check()[0] == True
398
+
399
+
400
+ def test_pivot():
401
+ for _ in range(10):
402
+ m = randMatrix(5)
403
+ rref = m.rref()
404
+ for _ in range(5):
405
+ i, j = randint(0, 4), randint(0, 4)
406
+ if m[i, j] != 0:
407
+ assert LRASolver._pivot(m, i, j).rref() == rref
408
+
409
+
410
+ def test_reset_bounds():
411
+ bf = Q.ge(x, 1) & Q.lt(x, 1)
412
+ enc = boolean_formula_to_encoded_cnf(bf)
413
+ lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True)
414
+ for clause in enc.data:
415
+ for lit in clause:
416
+ lra.assert_lit(lit)
417
+ assert len(lra.enc_to_boundary) == 2
418
+ assert lra.check()[0] == False
419
+
420
+ lra.reset_bounds()
421
+ assert lra.check()[0] == True
422
+ for var in lra.all_var:
423
+ assert var.upper == LRARational(float("inf"), 0)
424
+ assert var.upper_from_eq == False
425
+ assert var.upper_from_neg == False
426
+ assert var.lower == LRARational(-float("inf"), 0)
427
+ assert var.lower_from_eq == False
428
+ assert var.lower_from_neg == False
429
+ assert var.assign == LRARational(0, 0)
430
+ assert var.var is not None
431
+ assert var.col_idx is not None
432
+
433
+
434
+ def test_empty_cnf():
435
+ cnf = CNF()
436
+ enc = EncodedCNF()
437
+ enc.from_cnf(cnf)
438
+ lra, conflict = LRASolver.from_encoded_cnf(enc)
439
+ assert len(conflict) == 0
440
+ assert lra.check() == (True, {})
.venv/lib/python3.13/site-packages/sympy/logic/utilities/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .dimacs import load_file
2
+
3
+ __all__ = ['load_file']
.venv/lib/python3.13/site-packages/sympy/logic/utilities/dimacs.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """For reading in DIMACS file format
2
+
3
+ www.cs.ubc.ca/~hoos/SATLIB/Benchmarks/SAT/satformat.ps
4
+
5
+ """
6
+
7
+ from sympy.core import Symbol
8
+ from sympy.logic.boolalg import And, Or
9
+ import re
10
+ from pathlib import Path
11
+
12
+
13
+ def load(s):
14
+ """Loads a boolean expression from a string.
15
+
16
+ Examples
17
+ ========
18
+
19
+ >>> from sympy.logic.utilities.dimacs import load
20
+ >>> load('1')
21
+ cnf_1
22
+ >>> load('1 2')
23
+ cnf_1 | cnf_2
24
+ >>> load('1 \\n 2')
25
+ cnf_1 & cnf_2
26
+ >>> load('1 2 \\n 3')
27
+ cnf_3 & (cnf_1 | cnf_2)
28
+ """
29
+ clauses = []
30
+
31
+ lines = s.split('\n')
32
+
33
+ pComment = re.compile(r'c.*')
34
+ pStats = re.compile(r'p\s*cnf\s*(\d*)\s*(\d*)')
35
+
36
+ while len(lines) > 0:
37
+ line = lines.pop(0)
38
+
39
+ # Only deal with lines that aren't comments
40
+ if not pComment.match(line):
41
+ m = pStats.match(line)
42
+
43
+ if not m:
44
+ nums = line.rstrip('\n').split(' ')
45
+ list = []
46
+ for lit in nums:
47
+ if lit != '':
48
+ if int(lit) == 0:
49
+ continue
50
+ num = abs(int(lit))
51
+ sign = True
52
+ if int(lit) < 0:
53
+ sign = False
54
+
55
+ if sign:
56
+ list.append(Symbol("cnf_%s" % num))
57
+ else:
58
+ list.append(~Symbol("cnf_%s" % num))
59
+
60
+ if len(list) > 0:
61
+ clauses.append(Or(*list))
62
+
63
+ return And(*clauses)
64
+
65
+
66
+ def load_file(location):
67
+ """Loads a boolean expression from a file."""
68
+ s = Path(location).read_text()
69
+ return load(s)
.venv/lib/python3.13/site-packages/sympy/printing/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_aesaracode.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Important note on tests in this module - the Aesara printing functions use a
3
+ global cache by default, which means that tests using it will modify global
4
+ state and thus not be independent from each other. Instead of using the "cache"
5
+ keyword argument each time, this module uses the aesara_code_ and
6
+ aesara_function_ functions defined below which default to using a new, empty
7
+ cache instead.
8
+ """
9
+
10
+ import logging
11
+
12
+ from sympy.external import import_module
13
+ from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy
14
+
15
+ from sympy.utilities.exceptions import ignore_warnings
16
+
17
+
18
+ aesaralogger = logging.getLogger('aesara.configdefaults')
19
+ aesaralogger.setLevel(logging.CRITICAL)
20
+ aesara = import_module('aesara')
21
+ aesaralogger.setLevel(logging.WARNING)
22
+
23
+
24
+ if aesara:
25
+ import numpy as np
26
+ aet = aesara.tensor
27
+ from aesara.scalar.basic import ScalarType
28
+ from aesara.graph.basic import Variable
29
+ from aesara.tensor.var import TensorVariable
30
+ from aesara.tensor.elemwise import Elemwise, DimShuffle
31
+ from aesara.tensor.math import Dot
32
+
33
+ from sympy.printing.aesaracode import true_divide
34
+
35
+ xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz']
36
+ Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ']
37
+ else:
38
+ #bin/test will not execute any tests now
39
+ disabled = True
40
+
41
+ import sympy as sy
42
+ from sympy.core.singleton import S
43
+ from sympy.abc import x, y, z, t
44
+ from sympy.printing.aesaracode import (aesara_code, dim_handling,
45
+ aesara_function)
46
+
47
+
48
+ # Default set of matrix symbols for testing - make square so we can both
49
+ # multiply and perform elementwise operations between them.
50
+ X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ']
51
+
52
+ # For testing AppliedUndef
53
+ f_t = sy.Function('f')(t)
54
+
55
+
56
+ def aesara_code_(expr, **kwargs):
57
+ """ Wrapper for aesara_code that uses a new, empty cache by default. """
58
+ kwargs.setdefault('cache', {})
59
+ with warns_deprecated_sympy():
60
+ return aesara_code(expr, **kwargs)
61
+
62
+ def aesara_function_(inputs, outputs, **kwargs):
63
+ """ Wrapper for aesara_function that uses a new, empty cache by default. """
64
+ kwargs.setdefault('cache', {})
65
+ with warns_deprecated_sympy():
66
+ return aesara_function(inputs, outputs, **kwargs)
67
+
68
+
69
+ def fgraph_of(*exprs):
70
+ """ Transform SymPy expressions into Aesara Computation.
71
+
72
+ Parameters
73
+ ==========
74
+ exprs
75
+ SymPy expressions
76
+
77
+ Returns
78
+ =======
79
+ aesara.graph.fg.FunctionGraph
80
+ """
81
+ outs = list(map(aesara_code_, exprs))
82
+ ins = list(aesara.graph.basic.graph_inputs(outs))
83
+ ins, outs = aesara.graph.basic.clone(ins, outs)
84
+ return aesara.graph.fg.FunctionGraph(ins, outs)
85
+
86
+
87
+ def aesara_simplify(fgraph):
88
+ """ Simplify a Aesara Computation.
89
+
90
+ Parameters
91
+ ==========
92
+ fgraph : aesara.graph.fg.FunctionGraph
93
+
94
+ Returns
95
+ =======
96
+ aesara.graph.fg.FunctionGraph
97
+ """
98
+ mode = aesara.compile.get_default_mode().excluding("fusion")
99
+ fgraph = fgraph.clone()
100
+ mode.optimizer.rewrite(fgraph)
101
+ return fgraph
102
+
103
+
104
+ def theq(a, b):
105
+ """ Test two Aesara objects for equality.
106
+
107
+ Also accepts numeric types and lists/tuples of supported types.
108
+
109
+ Note - debugprint() has a bug where it will accept numeric types but does
110
+ not respect the "file" argument and in this case and instead prints the number
111
+ to stdout and returns an empty string. This can lead to tests passing where
112
+ they should fail because any two numbers will always compare as equal. To
113
+ prevent this we treat numbers as a separate case.
114
+ """
115
+ numeric_types = (int, float, np.number)
116
+ a_is_num = isinstance(a, numeric_types)
117
+ b_is_num = isinstance(b, numeric_types)
118
+
119
+ # Compare numeric types using regular equality
120
+ if a_is_num or b_is_num:
121
+ if not (a_is_num and b_is_num):
122
+ return False
123
+
124
+ return a == b
125
+
126
+ # Compare sequences element-wise
127
+ a_is_seq = isinstance(a, (tuple, list))
128
+ b_is_seq = isinstance(b, (tuple, list))
129
+
130
+ if a_is_seq or b_is_seq:
131
+ if not (a_is_seq and b_is_seq) or type(a) != type(b):
132
+ return False
133
+
134
+ return list(map(theq, a)) == list(map(theq, b))
135
+
136
+ # Otherwise, assume debugprint() can handle it
137
+ astr = aesara.printing.debugprint(a, file='str')
138
+ bstr = aesara.printing.debugprint(b, file='str')
139
+
140
+ # Check for bug mentioned above
141
+ for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
142
+ if argstr == '':
143
+ raise TypeError(
144
+ 'aesara.printing.debugprint(%s) returned empty string '
145
+ '(%s is instance of %r)'
146
+ % (argname, argname, type(argval))
147
+ )
148
+
149
+ return astr == bstr
150
+
151
+
152
+ def test_example_symbols():
153
+ """
154
+ Check that the example symbols in this module print to their Aesara
155
+ equivalents, as many of the other tests depend on this.
156
+ """
157
+ assert theq(xt, aesara_code_(x))
158
+ assert theq(yt, aesara_code_(y))
159
+ assert theq(zt, aesara_code_(z))
160
+ assert theq(Xt, aesara_code_(X))
161
+ assert theq(Yt, aesara_code_(Y))
162
+ assert theq(Zt, aesara_code_(Z))
163
+
164
+
165
+ def test_Symbol():
166
+ """ Test printing a Symbol to a aesara variable. """
167
+ xx = aesara_code_(x)
168
+ assert isinstance(xx, Variable)
169
+ assert xx.broadcastable == ()
170
+ assert xx.name == x.name
171
+
172
+ xx2 = aesara_code_(x, broadcastables={x: (False,)})
173
+ assert xx2.broadcastable == (False,)
174
+ assert xx2.name == x.name
175
+
176
+ def test_MatrixSymbol():
177
+ """ Test printing a MatrixSymbol to a aesara variable. """
178
+ XX = aesara_code_(X)
179
+ assert isinstance(XX, TensorVariable)
180
+ assert XX.broadcastable == (False, False)
181
+
182
+ @SKIP # TODO - this is currently not checked but should be implemented
183
+ def test_MatrixSymbol_wrong_dims():
184
+ """ Test MatrixSymbol with invalid broadcastable. """
185
+ bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)]
186
+ for bc in bcs:
187
+ with raises(ValueError):
188
+ aesara_code_(X, broadcastables={X: bc})
189
+
190
+ def test_AppliedUndef():
191
+ """ Test printing AppliedUndef instance, which works similarly to Symbol. """
192
+ ftt = aesara_code_(f_t)
193
+ assert isinstance(ftt, TensorVariable)
194
+ assert ftt.broadcastable == ()
195
+ assert ftt.name == 'f_t'
196
+
197
+
198
+ def test_add():
199
+ expr = x + y
200
+ comp = aesara_code_(expr)
201
+ assert comp.owner.op == aesara.tensor.add
202
+
203
+ def test_trig():
204
+ assert theq(aesara_code_(sy.sin(x)), aet.sin(xt))
205
+ assert theq(aesara_code_(sy.tan(x)), aet.tan(xt))
206
+
207
+ def test_many():
208
+ """ Test printing a complex expression with multiple symbols. """
209
+ expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
210
+ comp = aesara_code_(expr)
211
+ expected = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt)
212
+ assert theq(comp, expected)
213
+
214
+
215
+ def test_dtype():
216
+ """ Test specifying specific data types through the dtype argument. """
217
+ for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']:
218
+ assert aesara_code_(x, dtypes={x: dtype}).type.dtype == dtype
219
+
220
+ # "floatX" type
221
+ assert aesara_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64')
222
+
223
+ # Type promotion
224
+ assert aesara_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32'
225
+ assert aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64'
226
+
227
+
228
+ def test_broadcastables():
229
+ """ Test the "broadcastables" argument when printing symbol-like objects. """
230
+
231
+ # No restrictions on shape
232
+ for s in [x, f_t]:
233
+ for bc in [(), (False,), (True,), (False, False), (True, False)]:
234
+ assert aesara_code_(s, broadcastables={s: bc}).broadcastable == bc
235
+
236
+ # TODO - matrix broadcasting?
237
+
238
+ def test_broadcasting():
239
+ """ Test "broadcastable" attribute after applying element-wise binary op. """
240
+
241
+ expr = x + y
242
+
243
+ cases = [
244
+ [(), (), ()],
245
+ [(False,), (False,), (False,)],
246
+ [(True,), (False,), (False,)],
247
+ [(False, True), (False, False), (False, False)],
248
+ [(True, False), (False, False), (False, False)],
249
+ ]
250
+
251
+ for bc1, bc2, bc3 in cases:
252
+ comp = aesara_code_(expr, broadcastables={x: bc1, y: bc2})
253
+ assert comp.broadcastable == bc3
254
+
255
+
256
+ def test_MatMul():
257
+ expr = X*Y*Z
258
+ expr_t = aesara_code_(expr)
259
+ assert isinstance(expr_t.owner.op, Dot)
260
+ assert theq(expr_t, Xt.dot(Yt).dot(Zt))
261
+
262
+ def test_Transpose():
263
+ assert isinstance(aesara_code_(X.T).owner.op, DimShuffle)
264
+
265
+ def test_MatAdd():
266
+ expr = X+Y+Z
267
+ assert isinstance(aesara_code_(expr).owner.op, Elemwise)
268
+
269
+
270
+ def test_Rationals():
271
+ assert theq(aesara_code_(sy.Integer(2) / 3), true_divide(2, 3))
272
+ assert theq(aesara_code_(S.Half), true_divide(1, 2))
273
+
274
+ def test_Integers():
275
+ assert aesara_code_(sy.Integer(3)) == 3
276
+
277
+ def test_factorial():
278
+ n = sy.Symbol('n')
279
+ assert aesara_code_(sy.factorial(n))
280
+
281
+ def test_Derivative():
282
+ with ignore_warnings(UserWarning):
283
+ simp = lambda expr: aesara_simplify(fgraph_of(expr))
284
+ assert theq(simp(aesara_code_(sy.Derivative(sy.sin(x), x, evaluate=False))),
285
+ simp(aesara.grad(aet.sin(xt), xt)))
286
+
287
+
288
+ def test_aesara_function_simple():
289
+ """ Test aesara_function() with single output. """
290
+ f = aesara_function_([x, y], [x+y])
291
+ assert f(2, 3) == 5
292
+
293
+ def test_aesara_function_multi():
294
+ """ Test aesara_function() with multiple outputs. """
295
+ f = aesara_function_([x, y], [x+y, x-y])
296
+ o1, o2 = f(2, 3)
297
+ assert o1 == 5
298
+ assert o2 == -1
299
+
300
+ def test_aesara_function_numpy():
301
+ """ Test aesara_function() vs Numpy implementation. """
302
+ f = aesara_function_([x, y], [x+y], dim=1,
303
+ dtypes={x: 'float64', y: 'float64'})
304
+ assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9
305
+
306
+ f = aesara_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
307
+ dim=1)
308
+ xx = np.arange(3).astype('float64')
309
+ yy = 2*np.arange(3).astype('float64')
310
+ assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9
311
+
312
+
313
+ def test_aesara_function_matrix():
314
+ m = sy.Matrix([[x, y], [z, x + y + z]])
315
+ expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
316
+ f = aesara_function_([x, y, z], [m])
317
+ np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
318
+ f = aesara_function_([x, y, z], [m], scalar=True)
319
+ np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
320
+ f = aesara_function_([x, y, z], [m, m])
321
+ assert isinstance(f(1.0, 2.0, 3.0), type([]))
322
+ np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected)
323
+ np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected)
324
+
325
+ def test_dim_handling():
326
+ assert dim_handling([x], dim=2) == {x: (False, False)}
327
+ assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True),
328
+ y: (False, False)}
329
+ assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)}
330
+
331
+ def test_aesara_function_kwargs():
332
+ """
333
+ Test passing additional kwargs from aesara_function() to aesara.function().
334
+ """
335
+ import numpy as np
336
+ f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
337
+ dtypes={x: 'float64', y: 'float64', z: 'float64'})
338
+ assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9
339
+
340
+ f = aesara_function_([x, y, z], [x+y],
341
+ dtypes={x: 'float64', y: 'float64', z: 'float64'},
342
+ dim=1, on_unused_input='ignore')
343
+ xx = np.arange(3).astype('float64')
344
+ yy = 2*np.arange(3).astype('float64')
345
+ zz = 2*np.arange(3).astype('float64')
346
+ assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9
347
+
348
+ def test_aesara_function_scalar():
349
+ """ Test the "scalar" argument to aesara_function(). """
350
+ from aesara.compile.function.types import Function
351
+
352
+ args = [
353
+ ([x, y], [x + y], None, [0]), # Single 0d output
354
+ ([X, Y], [X + Y], None, [2]), # Single 2d output
355
+ ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output
356
+ ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs
357
+ ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d
358
+ ]
359
+
360
+ # Create and test functions with and without the scalar setting
361
+ for inputs, outputs, in_dims, out_dims in args:
362
+ for scalar in [False, True]:
363
+
364
+ f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar)
365
+
366
+ # Check the aesara_function attribute is set whether wrapped or not
367
+ assert isinstance(f.aesara_function, Function)
368
+
369
+ # Feed in inputs of the appropriate size and get outputs
370
+ in_values = [
371
+ np.ones([1 if bc else 5 for bc in i.type.broadcastable])
372
+ for i in f.aesara_function.input_storage
373
+ ]
374
+ out_values = f(*in_values)
375
+ if not isinstance(out_values, list):
376
+ out_values = [out_values]
377
+
378
+ # Check output types and shapes
379
+ assert len(out_dims) == len(out_values)
380
+ for d, value in zip(out_dims, out_values):
381
+
382
+ if scalar and d == 0:
383
+ # Should have been converted to a scalar value
384
+ assert isinstance(value, np.number)
385
+
386
+ else:
387
+ # Otherwise should be an array
388
+ assert isinstance(value, np.ndarray)
389
+ assert value.ndim == d
390
+
391
+ def test_aesara_function_bad_kwarg():
392
+ """
393
+ Passing an unknown keyword argument to aesara_function() should raise an
394
+ exception.
395
+ """
396
+ raises(Exception, lambda : aesara_function_([x], [x+1], foobar=3))
397
+
398
+
399
+ def test_slice():
400
+ assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3)
401
+
402
+ def theq_slice(s1, s2):
403
+ for attr in ['start', 'stop', 'step']:
404
+ a1 = getattr(s1, attr)
405
+ a2 = getattr(s2, attr)
406
+ if a1 is None or a2 is None:
407
+ if not (a1 is None or a2 is None):
408
+ return False
409
+ elif not theq(a1, a2):
410
+ return False
411
+ return True
412
+
413
+ dtypes = {x: 'int32', y: 'int32'}
414
+ assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
415
+ assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))
416
+
417
+ def test_MatrixSlice():
418
+ cache = {}
419
+
420
+ n = sy.Symbol('n', integer=True)
421
+ X = sy.MatrixSymbol('X', n, n)
422
+
423
+ Y = X[1:2:3, 4:5:6]
424
+ Yt = aesara_code_(Y, cache=cache)
425
+
426
+ s = ScalarType('int64')
427
+ assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s))
428
+ assert Yt.owner.inputs[0] == aesara_code_(X, cache=cache)
429
+ # == doesn't work in Aesara like it does in SymPy. You have to use
430
+ # equals.
431
+ assert all(Yt.owner.inputs[i].data == i for i in range(1, 7))
432
+
433
+ k = sy.Symbol('k')
434
+ aesara_code_(k, dtypes={k: 'int32'})
435
+ start, stop, step = 4, k, 2
436
+ Y = X[start:stop:step]
437
+ Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'})
438
+ # assert Yt.owner.op.idx_list[0].stop == kt
439
+
440
+ def test_BlockMatrix():
441
+ n = sy.Symbol('n', integer=True)
442
+ A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD']
443
+ At, Bt, Ct, Dt = map(aesara_code_, (A, B, C, D))
444
+ Block = sy.BlockMatrix([[A, B], [C, D]])
445
+ Blockt = aesara_code_(Block)
446
+ solutions = [aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt)),
447
+ aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))]
448
+ assert any(theq(Blockt, solution) for solution in solutions)
449
+
450
+ @SKIP
451
+ def test_BlockMatrix_Inverse_execution():
452
+ k, n = 2, 4
453
+ dtype = 'float32'
454
+ A = sy.MatrixSymbol('A', n, k)
455
+ B = sy.MatrixSymbol('B', n, n)
456
+ inputs = A, B
457
+ output = B.I*A
458
+
459
+ cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
460
+ B: [(n//2, n//2), (n//2, n//2)]}
461
+ cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
462
+ cutoutput = output.subs(dict(zip(inputs, cutinputs)))
463
+
464
+ dtypes = dict(zip(inputs, [dtype]*len(inputs)))
465
+ f = aesara_function_(inputs, [output], dtypes=dtypes, cache={})
466
+ fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)],
467
+ dtypes=dtypes, cache={})
468
+
469
+ ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs]
470
+ ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
471
+ np.eye(n).astype(dtype)]
472
+ ninputs[1] += np.ones(B.shape)*1e-5
473
+
474
+ assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)
475
+
476
+ def test_DenseMatrix():
477
+ from aesara.tensor.basic import Join
478
+
479
+ t = sy.Symbol('theta')
480
+ for MatrixType in [sy.Matrix, sy.ImmutableMatrix]:
481
+ X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]])
482
+ tX = aesara_code_(X)
483
+ assert isinstance(tX, TensorVariable)
484
+ assert isinstance(tX.owner.op, Join)
485
+
486
+
487
+ def test_cache_basic():
488
+ """ Test single symbol-like objects are cached when printed by themselves. """
489
+
490
+ # Pairs of objects which should be considered equivalent with respect to caching
491
+ pairs = [
492
+ (x, sy.Symbol('x')),
493
+ (X, sy.MatrixSymbol('X', *X.shape)),
494
+ (f_t, sy.Function('f')(sy.Symbol('t'))),
495
+ ]
496
+
497
+ for s1, s2 in pairs:
498
+ cache = {}
499
+ st = aesara_code_(s1, cache=cache)
500
+
501
+ # Test hit with same instance
502
+ assert aesara_code_(s1, cache=cache) is st
503
+
504
+ # Test miss with same instance but new cache
505
+ assert aesara_code_(s1, cache={}) is not st
506
+
507
+ # Test hit with different but equivalent instance
508
+ assert aesara_code_(s2, cache=cache) is st
509
+
510
+ def test_global_cache():
511
+ """ Test use of the global cache. """
512
+ from sympy.printing.aesaracode import global_cache
513
+
514
+ backup = dict(global_cache)
515
+ try:
516
+ # Temporarily empty global cache
517
+ global_cache.clear()
518
+
519
+ for s in [x, X, f_t]:
520
+ with warns_deprecated_sympy():
521
+ st = aesara_code(s)
522
+ assert aesara_code(s) is st
523
+
524
+ finally:
525
+ # Restore global cache
526
+ global_cache.update(backup)
527
+
528
+ def test_cache_types_distinct():
529
+ """
530
+ Test that symbol-like objects of different types (Symbol, MatrixSymbol,
531
+ AppliedUndef) are distinguished by the cache even if they have the same
532
+ name.
533
+ """
534
+ symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]
535
+
536
+ cache = {} # Single shared cache
537
+ printed = {}
538
+
539
+ for s in symbols:
540
+ st = aesara_code_(s, cache=cache)
541
+ assert st not in printed.values()
542
+ printed[s] = st
543
+
544
+ # Check all printed objects are distinct
545
+ assert len(set(map(id, printed.values()))) == len(symbols)
546
+
547
+ # Check retrieving
548
+ for s, st in printed.items():
549
+ with warns_deprecated_sympy():
550
+ assert aesara_code(s, cache=cache) is st
551
+
552
+ def test_symbols_are_created_once():
553
+ """
554
+ Test that a symbol is cached and reused when it appears in an expression
555
+ more than once.
556
+ """
557
+ expr = sy.Add(x, x, evaluate=False)
558
+ comp = aesara_code_(expr)
559
+
560
+ assert theq(comp, xt + xt)
561
+ assert not theq(comp, xt + aesara_code_(x))
562
+
563
+ def test_cache_complex():
564
+ """
565
+ Test caching on a complicated expression with multiple symbols appearing
566
+ multiple times.
567
+ """
568
+ expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
569
+ symbol_names = {s.name for s in expr.free_symbols}
570
+ expr_t = aesara_code_(expr)
571
+
572
+ # Iterate through variables in the Aesara computational graph that the
573
+ # printed expression depends on
574
+ seen = set()
575
+ for v in aesara.graph.basic.ancestors([expr_t]):
576
+ # Owner-less, non-constant variables should be our symbols
577
+ if v.owner is None and not isinstance(v, aesara.graph.basic.Constant):
578
+ # Check it corresponds to a symbol and appears only once
579
+ assert v.name in symbol_names
580
+ assert v.name not in seen
581
+ seen.add(v.name)
582
+
583
+ # Check all were present
584
+ assert seen == symbol_names
585
+
586
+
587
+ def test_Piecewise():
588
+ # A piecewise linear
589
+ expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III
590
+ result = aesara_code_(expr)
591
+ assert result.owner.op == aet.switch
592
+
593
+ expected = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1))
594
+ assert theq(result, expected)
595
+
596
+ expr = sy.Piecewise((x, x < 0))
597
+ result = aesara_code_(expr)
598
+ expected = aet.switch(xt < 0, xt, np.nan)
599
+ assert theq(result, expected)
600
+
601
+ expr = sy.Piecewise((0, sy.And(x>0, x<2)), \
602
+ (x, sy.Or(x>2, x<0)))
603
+ result = aesara_code_(expr)
604
+ expected = aet.switch(aet.and_(xt>0,xt<2), 0, \
605
+ aet.switch(aet.or_(xt>2, xt<0), xt, np.nan))
606
+ assert theq(result, expected)
607
+
608
+
609
+ def test_Relationals():
610
+ assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt))
611
+ # assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt)) # TODO - implement
612
+ assert theq(aesara_code_(x > y), xt > yt)
613
+ assert theq(aesara_code_(x < y), xt < yt)
614
+ assert theq(aesara_code_(x >= y), xt >= yt)
615
+ assert theq(aesara_code_(x <= y), xt <= yt)
616
+
617
+
618
+ def test_complexfunctions():
619
+ dtypes = {x:'complex128', y:'complex128'}
620
+ with warns_deprecated_sympy():
621
+ xt, yt = aesara_code(x, dtypes=dtypes), aesara_code(y, dtypes=dtypes)
622
+ from sympy.functions.elementary.complexes import conjugate
623
+ from aesara.tensor import as_tensor_variable as atv
624
+ from aesara.tensor import complex as cplx
625
+ with warns_deprecated_sympy():
626
+ assert theq(aesara_code(y*conjugate(x), dtypes=dtypes), yt*(xt.conj()))
627
+ assert theq(aesara_code((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1)))
628
+
629
+
630
+ def test_constantfunctions():
631
+ with warns_deprecated_sympy():
632
+ tf = aesara_function([],[1+1j])
633
+ assert(tf()==1+1j)
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_c.py ADDED
@@ -0,0 +1,888 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (
2
+ S, pi, oo, Symbol, symbols, Rational, Integer, Float, Function, Mod, GoldenRatio, EulerGamma, Catalan,
3
+ Lambda, Dummy, nan, Mul, Pow, UnevaluatedExpr
4
+ )
5
+ from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
6
+ from sympy.functions import (
7
+ Abs, acos, acosh, asin, asinh, atan, atanh, atan2, ceiling, cos, cosh, erf,
8
+ erfc, exp, floor, gamma, log, loggamma, Max, Min, Piecewise, sign, sin, sinh,
9
+ sqrt, tan, tanh, fibonacci, lucas
10
+ )
11
+ from sympy.sets import Range
12
+ from sympy.logic import ITE, Implies, Equivalent
13
+ from sympy.codegen import For, aug_assign, Assignment
14
+ from sympy.testing.pytest import raises, XFAIL
15
+ from sympy.printing.codeprinter import PrintMethodNotImplementedError
16
+ from sympy.printing.c import C89CodePrinter, C99CodePrinter, get_math_macros
17
+ from sympy.codegen.ast import (
18
+ AddAugmentedAssignment, Element, Type, FloatType, Declaration, Pointer, Variable, value_const, pointer_const,
19
+ While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall, Return,
20
+ real, float32, float64, float80, float128, intc, Comment, CodeBlock, stderr, QuotedString
21
+ )
22
+ from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, fma, log10, Cbrt, hypot, Sqrt, isnan, isinf
23
+ from sympy.codegen.cnodes import restrict
24
+ from sympy.utilities.lambdify import implemented_function
25
+ from sympy.tensor import IndexedBase, Idx
26
+ from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix
27
+
28
+ from sympy.printing.codeprinter import ccode
29
+
30
+ x, y, z = symbols('x,y,z')
31
+
32
+
33
+ def test_printmethod():
34
+ class fabs(Abs):
35
+ def _ccode(self, printer):
36
+ return "fabs(%s)" % printer._print(self.args[0])
37
+
38
+ assert ccode(fabs(x)) == "fabs(x)"
39
+
40
+
41
+ def test_ccode_sqrt():
42
+ assert ccode(sqrt(x)) == "sqrt(x)"
43
+ assert ccode(x**0.5) == "sqrt(x)"
44
+ assert ccode(sqrt(x)) == "sqrt(x)"
45
+
46
+
47
+ def test_ccode_Pow():
48
+ assert ccode(x**3) == "pow(x, 3)"
49
+ assert ccode(x**(y**3)) == "pow(x, pow(y, 3))"
50
+ g = implemented_function('g', Lambda(x, 2*x))
51
+ assert ccode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
52
+ "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2) + y)"
53
+ assert ccode(x**-1.0) == '1.0/x'
54
+ assert ccode(x**Rational(2, 3)) == 'pow(x, 2.0/3.0)'
55
+ assert ccode(x**Rational(2, 3), type_aliases={real: float80}) == 'powl(x, 2.0L/3.0L)'
56
+ _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"),
57
+ (lambda base, exp: not exp.is_integer, "pow")]
58
+ assert ccode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)'
59
+ assert ccode(x**0.5, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 0.5)'
60
+ assert ccode(x**Rational(16, 5), user_functions={'Pow': _cond_cfunc}) == 'pow(x, 16.0/5.0)'
61
+ _cond_cfunc2 = [(lambda base, exp: base == 2, lambda base, exp: 'exp2(%s)' % exp),
62
+ (lambda base, exp: base != 2, 'pow')]
63
+ # Related to gh-11353
64
+ assert ccode(2**x, user_functions={'Pow': _cond_cfunc2}) == 'exp2(x)'
65
+ assert ccode(x**2, user_functions={'Pow': _cond_cfunc2}) == 'pow(x, 2)'
66
+ # For issue 14160
67
+ assert ccode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
68
+ evaluate=False)) == '-2*x/(y*y)'
69
+
70
+
71
+ def test_ccode_Max():
72
+ # Test for gh-11926
73
+ assert ccode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))'
74
+
75
+
76
+ def test_ccode_Min_performance():
77
+ #Shouldn't take more than a few seconds
78
+ big_min = Min(*symbols('a[0:50]'))
79
+ for curr_standard in ('c89', 'c99', 'c11'):
80
+ output = ccode(big_min, standard=curr_standard)
81
+ assert output.count('(') == output.count(')')
82
+
83
+
84
+ def test_ccode_constants_mathh():
85
+ assert ccode(exp(1)) == "M_E"
86
+ assert ccode(pi) == "M_PI"
87
+ assert ccode(oo, standard='c89') == "HUGE_VAL"
88
+ assert ccode(-oo, standard='c89') == "-HUGE_VAL"
89
+ assert ccode(oo) == "INFINITY"
90
+ assert ccode(-oo, standard='c99') == "-INFINITY"
91
+ assert ccode(pi, type_aliases={real: float80}) == "M_PIl"
92
+
93
+
94
+ def test_ccode_constants_other():
95
+ assert ccode(2*GoldenRatio) == "const double GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
96
+ assert ccode(
97
+ 2*Catalan) == "const double Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
98
+ assert ccode(2*EulerGamma) == "const double EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
99
+
100
+
101
+ def test_ccode_Rational():
102
+ assert ccode(Rational(3, 7)) == "3.0/7.0"
103
+ assert ccode(Rational(3, 7), type_aliases={real: float80}) == "3.0L/7.0L"
104
+ assert ccode(Rational(18, 9)) == "2"
105
+ assert ccode(Rational(3, -7)) == "-3.0/7.0"
106
+ assert ccode(Rational(3, -7), type_aliases={real: float80}) == "-3.0L/7.0L"
107
+ assert ccode(Rational(-3, -7)) == "3.0/7.0"
108
+ assert ccode(Rational(-3, -7), type_aliases={real: float80}) == "3.0L/7.0L"
109
+ assert ccode(x + Rational(3, 7)) == "x + 3.0/7.0"
110
+ assert ccode(x + Rational(3, 7), type_aliases={real: float80}) == "x + 3.0L/7.0L"
111
+ assert ccode(Rational(3, 7)*x) == "(3.0/7.0)*x"
112
+ assert ccode(Rational(3, 7)*x, type_aliases={real: float80}) == "(3.0L/7.0L)*x"
113
+
114
+
115
+ def test_ccode_Integer():
116
+ assert ccode(Integer(67)) == "67"
117
+ assert ccode(Integer(-1)) == "-1"
118
+
119
+
120
+ def test_ccode_functions():
121
+ assert ccode(sin(x) ** cos(x)) == "pow(sin(x), cos(x))"
122
+
123
+
124
+ def test_ccode_inline_function():
125
+ x = symbols('x')
126
+ g = implemented_function('g', Lambda(x, 2*x))
127
+ assert ccode(g(x)) == "2*x"
128
+ g = implemented_function('g', Lambda(x, 2*x/Catalan))
129
+ assert ccode(
130
+ g(x)) == "const double Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
131
+ A = IndexedBase('A')
132
+ i = Idx('i', symbols('n', integer=True))
133
+ g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
134
+ assert ccode(g(A[i]), assign_to=A[i]) == (
135
+ "for (int i=0; i<n; i++){\n"
136
+ " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
137
+ "}"
138
+ )
139
+
140
+
141
+ def test_ccode_exceptions():
142
+ assert ccode(gamma(x), standard='C99') == "tgamma(x)"
143
+ with raises(PrintMethodNotImplementedError):
144
+ ccode(gamma(x), standard='C89')
145
+ with raises(PrintMethodNotImplementedError):
146
+ ccode(gamma(x), standard='C89', allow_unknown_functions=False)
147
+
148
+ ccode(gamma(x), standard='C89', allow_unknown_functions=True)
149
+
150
+
151
+
152
+ def test_ccode_functions2():
153
+ assert ccode(ceiling(x)) == "ceil(x)"
154
+ assert ccode(Abs(x)) == "fabs(x)"
155
+ assert ccode(gamma(x)) == "tgamma(x)"
156
+ r, s = symbols('r,s', real=True)
157
+ assert ccode(Mod(ceiling(r), ceiling(s))) == '((ceil(r) % ceil(s)) + '\
158
+ 'ceil(s)) % ceil(s)'
159
+ assert ccode(Mod(r, s)) == "fmod(r, s)"
160
+ p1, p2 = symbols('p1 p2', integer=True, positive=True)
161
+ assert ccode(Mod(p1, p2)) == 'p1 % p2'
162
+ assert ccode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)'
163
+ assert ccode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)'
164
+ assert ccode(-Mod(3, 7, evaluate=False)) == '-(3 % 7)'
165
+ assert ccode(r*Mod(p1, p2)) == 'r*(p1 % p2)'
166
+ assert ccode(Mod(p1, p2)**s) == 'pow(p1 % p2, s)'
167
+ n = symbols('n', integer=True, negative=True)
168
+ assert ccode(Mod(-n, p2)) == '(-n) % p2'
169
+ assert ccode(fibonacci(n)) == '((1.0/5.0)*pow(2, -n)*sqrt(5)*(-pow(1 - sqrt(5), n) + pow(1 + sqrt(5), n)))'
170
+ assert ccode(lucas(n)) == '(pow(2, -n)*(pow(1 - sqrt(5), n) + pow(1 + sqrt(5), n)))'
171
+
172
+
173
+ def test_ccode_user_functions():
174
+ x = symbols('x', integer=False)
175
+ n = symbols('n', integer=True)
176
+ custom_functions = {
177
+ "ceiling": "ceil",
178
+ "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
179
+ }
180
+ assert ccode(ceiling(x), user_functions=custom_functions) == "ceil(x)"
181
+ assert ccode(Abs(x), user_functions=custom_functions) == "fabs(x)"
182
+ assert ccode(Abs(n), user_functions=custom_functions) == "abs(n)"
183
+
184
+ expr = Symbol('a')
185
+ muladd = Function('muladd')
186
+ for i in range(0, 100):
187
+ # the large number of terms acts as a regression test for gh-23839
188
+ expr = muladd(Rational(1, 2), Symbol(f'a{i}'), expr)
189
+ out = ccode(expr, user_functions={'muladd':'muladd'})
190
+ assert 'a99' in out
191
+ assert out.count('muladd') == 100
192
+
193
+
194
+ def test_ccode_boolean():
195
+ assert ccode(True) == "true"
196
+ assert ccode(S.true) == "true"
197
+ assert ccode(False) == "false"
198
+ assert ccode(S.false) == "false"
199
+ assert ccode(x & y) == "x && y"
200
+ assert ccode(x | y) == "x || y"
201
+ assert ccode(~x) == "!x"
202
+ assert ccode(x & y & z) == "x && y && z"
203
+ assert ccode(x | y | z) == "x || y || z"
204
+ assert ccode((x & y) | z) == "z || x && y"
205
+ assert ccode((x | y) & z) == "z && (x || y)"
206
+ # Automatic rewrites
207
+ assert ccode(x ^ y) == '(x || y) && (!x || !y)'
208
+ assert ccode((x ^ y) ^ z) == '(x || y || z) && (x || !y || !z) && (y || !x || !z) && (z || !x || !y)'
209
+ assert ccode(Implies(x, y)) == 'y || !x'
210
+ assert ccode(Equivalent(x, z ^ y, Implies(z, x))) == '(x || (y || !z) && (z || !y)) && (z && !x || (y || z) && (!y || !z))'
211
+
212
+
213
+ def test_ccode_Relational():
214
+ assert ccode(Eq(x, y)) == "x == y"
215
+ assert ccode(Ne(x, y)) == "x != y"
216
+ assert ccode(Le(x, y)) == "x <= y"
217
+ assert ccode(Lt(x, y)) == "x < y"
218
+ assert ccode(Gt(x, y)) == "x > y"
219
+ assert ccode(Ge(x, y)) == "x >= y"
220
+
221
+
222
+ def test_ccode_Piecewise():
223
+ expr = Piecewise((x, x < 1), (x**2, True))
224
+ assert ccode(expr) == (
225
+ "((x < 1) ? (\n"
226
+ " x\n"
227
+ ")\n"
228
+ ": (\n"
229
+ " pow(x, 2)\n"
230
+ "))")
231
+ assert ccode(expr, assign_to="c") == (
232
+ "if (x < 1) {\n"
233
+ " c = x;\n"
234
+ "}\n"
235
+ "else {\n"
236
+ " c = pow(x, 2);\n"
237
+ "}")
238
+ expr = Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))
239
+ assert ccode(expr) == (
240
+ "((x < 1) ? (\n"
241
+ " x\n"
242
+ ")\n"
243
+ ": ((x < 2) ? (\n"
244
+ " x + 1\n"
245
+ ")\n"
246
+ ": (\n"
247
+ " pow(x, 2)\n"
248
+ ")))")
249
+ assert ccode(expr, assign_to='c') == (
250
+ "if (x < 1) {\n"
251
+ " c = x;\n"
252
+ "}\n"
253
+ "else if (x < 2) {\n"
254
+ " c = x + 1;\n"
255
+ "}\n"
256
+ "else {\n"
257
+ " c = pow(x, 2);\n"
258
+ "}")
259
+ # Check that Piecewise without a True (default) condition error
260
+ expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
261
+ raises(ValueError, lambda: ccode(expr))
262
+
263
+
264
+ def test_ccode_sinc():
265
+ from sympy.functions.elementary.trigonometric import sinc
266
+ expr = sinc(x)
267
+ assert ccode(expr) == (
268
+ "(((x != 0) ? (\n"
269
+ " sin(x)/x\n"
270
+ ")\n"
271
+ ": (\n"
272
+ " 1\n"
273
+ ")))")
274
+
275
+
276
+ def test_ccode_Piecewise_deep():
277
+ p = ccode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)))
278
+ assert p == (
279
+ "2*((x < 1) ? (\n"
280
+ " x\n"
281
+ ")\n"
282
+ ": ((x < 2) ? (\n"
283
+ " x + 1\n"
284
+ ")\n"
285
+ ": (\n"
286
+ " pow(x, 2)\n"
287
+ ")))")
288
+ expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
289
+ assert ccode(expr) == (
290
+ "pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n"
291
+ " 0\n"
292
+ ")\n"
293
+ ": (\n"
294
+ " 1\n"
295
+ ")) + cos(z) - 1")
296
+ assert ccode(expr, assign_to='c') == (
297
+ "c = pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n"
298
+ " 0\n"
299
+ ")\n"
300
+ ": (\n"
301
+ " 1\n"
302
+ ")) + cos(z) - 1;")
303
+
304
+
305
+ def test_ccode_ITE():
306
+ expr = ITE(x < 1, y, z)
307
+ assert ccode(expr) == (
308
+ "((x < 1) ? (\n"
309
+ " y\n"
310
+ ")\n"
311
+ ": (\n"
312
+ " z\n"
313
+ "))")
314
+
315
+
316
+ def test_ccode_settings():
317
+ raises(TypeError, lambda: ccode(sin(x), method="garbage"))
318
+
319
+
320
+ def test_ccode_Indexed():
321
+ s, n, m, o = symbols('s n m o', integer=True)
322
+ i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
323
+
324
+ x = IndexedBase('x')[j]
325
+ A = IndexedBase('A')[i, j]
326
+ B = IndexedBase('B')[i, j, k]
327
+
328
+ p = C99CodePrinter()
329
+
330
+ assert p._print_Indexed(x) == 'x[j]'
331
+ assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
332
+ assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
333
+
334
+ A = IndexedBase('A', shape=(5,3))[i, j]
335
+ assert p._print_Indexed(A) == 'A[%s]' % (3*i + j)
336
+
337
+ A = IndexedBase('A', shape=(5,3), strides='F')[i, j]
338
+ assert ccode(A) == 'A[%s]' % (i + 5*j)
339
+
340
+ A = IndexedBase('A', shape=(29,29), strides=(1, s), offset=o)[i, j]
341
+ assert ccode(A) == 'A[o + s*j + i]'
342
+
343
+ Abase = IndexedBase('A', strides=(s, m, n), offset=o)
344
+ assert ccode(Abase[i, j, k]) == 'A[m*j + n*k + o + s*i]'
345
+ assert ccode(Abase[2, 3, k]) == 'A[3*m + n*k + o + 2*s]'
346
+
347
+
348
+ def test_Element():
349
+ assert ccode(Element('x', 'ij')) == 'x[i][j]'
350
+ assert ccode(Element('x', 'ij', strides='kl', offset='o')) == 'x[i*k + j*l + o]'
351
+ assert ccode(Element('x', (3,))) == 'x[3]'
352
+ assert ccode(Element('x', (3,4,5))) == 'x[3][4][5]'
353
+
354
+
355
+ def test_ccode_Indexed_without_looking_for_contraction():
356
+ len_y = 5
357
+ y = IndexedBase('y', shape=(len_y,))
358
+ x = IndexedBase('x', shape=(len_y,))
359
+ Dy = IndexedBase('Dy', shape=(len_y-1,))
360
+ i = Idx('i', len_y-1)
361
+ e = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
362
+ code0 = ccode(e.rhs, assign_to=e.lhs, contract=False)
363
+ assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1)
364
+
365
+
366
+ def test_ccode_loops_matrix_vector():
367
+ n, m = symbols('n m', integer=True)
368
+ A = IndexedBase('A')
369
+ x = IndexedBase('x')
370
+ y = IndexedBase('y')
371
+ i = Idx('i', m)
372
+ j = Idx('j', n)
373
+
374
+ s = (
375
+ 'for (int i=0; i<m; i++){\n'
376
+ ' y[i] = 0;\n'
377
+ '}\n'
378
+ 'for (int i=0; i<m; i++){\n'
379
+ ' for (int j=0; j<n; j++){\n'
380
+ ' y[i] = A[%s]*x[j] + y[i];\n' % (i*n + j) +\
381
+ ' }\n'
382
+ '}'
383
+ )
384
+ assert ccode(A[i, j]*x[j], assign_to=y[i]) == s
385
+
386
+
387
+ def test_dummy_loops():
388
+ i, m = symbols('i m', integer=True, cls=Dummy)
389
+ x = IndexedBase('x')
390
+ y = IndexedBase('y')
391
+ i = Idx(i, m)
392
+
393
+ expected = (
394
+ 'for (int i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
395
+ ' y[i_%(icount)i] = x[i_%(icount)i];\n'
396
+ '}'
397
+ ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
398
+
399
+ assert ccode(x[i], assign_to=y[i]) == expected
400
+
401
+
402
+ def test_ccode_loops_add():
403
+ n, m = symbols('n m', integer=True)
404
+ A = IndexedBase('A')
405
+ x = IndexedBase('x')
406
+ y = IndexedBase('y')
407
+ z = IndexedBase('z')
408
+ i = Idx('i', m)
409
+ j = Idx('j', n)
410
+
411
+ s = (
412
+ 'for (int i=0; i<m; i++){\n'
413
+ ' y[i] = x[i] + z[i];\n'
414
+ '}\n'
415
+ 'for (int i=0; i<m; i++){\n'
416
+ ' for (int j=0; j<n; j++){\n'
417
+ ' y[i] = A[%s]*x[j] + y[i];\n' % (i*n + j) +\
418
+ ' }\n'
419
+ '}'
420
+ )
421
+ assert ccode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == s
422
+
423
+
424
+ def test_ccode_loops_multiple_contractions():
425
+ n, m, o, p = symbols('n m o p', integer=True)
426
+ a = IndexedBase('a')
427
+ b = IndexedBase('b')
428
+ y = IndexedBase('y')
429
+ i = Idx('i', m)
430
+ j = Idx('j', n)
431
+ k = Idx('k', o)
432
+ l = Idx('l', p)
433
+
434
+ s = (
435
+ 'for (int i=0; i<m; i++){\n'
436
+ ' y[i] = 0;\n'
437
+ '}\n'
438
+ 'for (int i=0; i<m; i++){\n'
439
+ ' for (int j=0; j<n; j++){\n'
440
+ ' for (int k=0; k<o; k++){\n'
441
+ ' for (int l=0; l<p; l++){\n'
442
+ ' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
443
+ ' }\n'
444
+ ' }\n'
445
+ ' }\n'
446
+ '}'
447
+ )
448
+ assert ccode(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == s
449
+
450
+
451
+ def test_ccode_loops_addfactor():
452
+ n, m, o, p = symbols('n m o p', integer=True)
453
+ a = IndexedBase('a')
454
+ b = IndexedBase('b')
455
+ c = IndexedBase('c')
456
+ y = IndexedBase('y')
457
+ i = Idx('i', m)
458
+ j = Idx('j', n)
459
+ k = Idx('k', o)
460
+ l = Idx('l', p)
461
+
462
+ s = (
463
+ 'for (int i=0; i<m; i++){\n'
464
+ ' y[i] = 0;\n'
465
+ '}\n'
466
+ 'for (int i=0; i<m; i++){\n'
467
+ ' for (int j=0; j<n; j++){\n'
468
+ ' for (int k=0; k<o; k++){\n'
469
+ ' for (int l=0; l<p; l++){\n'
470
+ ' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
471
+ ' }\n'
472
+ ' }\n'
473
+ ' }\n'
474
+ '}'
475
+ )
476
+ assert ccode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) == s
477
+
478
+
479
+ def test_ccode_loops_multiple_terms():
480
+ n, m, o, p = symbols('n m o p', integer=True)
481
+ a = IndexedBase('a')
482
+ b = IndexedBase('b')
483
+ c = IndexedBase('c')
484
+ y = IndexedBase('y')
485
+ i = Idx('i', m)
486
+ j = Idx('j', n)
487
+ k = Idx('k', o)
488
+
489
+ s0 = (
490
+ 'for (int i=0; i<m; i++){\n'
491
+ ' y[i] = 0;\n'
492
+ '}\n'
493
+ )
494
+ s1 = (
495
+ 'for (int i=0; i<m; i++){\n'
496
+ ' for (int j=0; j<n; j++){\n'
497
+ ' for (int k=0; k<o; k++){\n'
498
+ ' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
499
+ ' }\n'
500
+ ' }\n'
501
+ '}\n'
502
+ )
503
+ s2 = (
504
+ 'for (int i=0; i<m; i++){\n'
505
+ ' for (int k=0; k<o; k++){\n'
506
+ ' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
507
+ ' }\n'
508
+ '}\n'
509
+ )
510
+ s3 = (
511
+ 'for (int i=0; i<m; i++){\n'
512
+ ' for (int j=0; j<n; j++){\n'
513
+ ' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
514
+ ' }\n'
515
+ '}\n'
516
+ )
517
+ c = ccode(b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
518
+ assert (c == s0 + s1 + s2 + s3[:-1] or
519
+ c == s0 + s1 + s3 + s2[:-1] or
520
+ c == s0 + s2 + s1 + s3[:-1] or
521
+ c == s0 + s2 + s3 + s1[:-1] or
522
+ c == s0 + s3 + s1 + s2[:-1] or
523
+ c == s0 + s3 + s2 + s1[:-1])
524
+
525
+
526
+ def test_dereference_printing():
527
+ expr = x + y + sin(z) + z
528
+ assert ccode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))"
529
+
530
+
531
+ def test_Matrix_printing():
532
+ # Test returning a Matrix
533
+ mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
534
+ A = MatrixSymbol('A', 3, 1)
535
+ assert ccode(mat, A) == (
536
+ "A[0] = x*y;\n"
537
+ "if (y > 0) {\n"
538
+ " A[1] = x + 2;\n"
539
+ "}\n"
540
+ "else {\n"
541
+ " A[1] = y;\n"
542
+ "}\n"
543
+ "A[2] = sin(z);")
544
+ # Test using MatrixElements in expressions
545
+ expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
546
+ assert ccode(expr) == (
547
+ "((x > 0) ? (\n"
548
+ " 2*A[2]\n"
549
+ ")\n"
550
+ ": (\n"
551
+ " A[2]\n"
552
+ ")) + sin(A[1]) + A[0]")
553
+ # Test using MatrixElements in a Matrix
554
+ q = MatrixSymbol('q', 5, 1)
555
+ M = MatrixSymbol('M', 3, 3)
556
+ m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
557
+ [q[1,0] + q[2,0], q[3, 0], 5],
558
+ [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
559
+ assert ccode(m, M) == (
560
+ "M[0] = sin(q[1]);\n"
561
+ "M[1] = 0;\n"
562
+ "M[2] = cos(q[2]);\n"
563
+ "M[3] = q[1] + q[2];\n"
564
+ "M[4] = q[3];\n"
565
+ "M[5] = 5;\n"
566
+ "M[6] = 2*q[4]/q[1];\n"
567
+ "M[7] = sqrt(q[0]) + 4;\n"
568
+ "M[8] = 0;")
569
+
570
+
571
+ def test_sparse_matrix():
572
+ # gh-15791
573
+ with raises(PrintMethodNotImplementedError):
574
+ ccode(SparseMatrix([[1, 2, 3]]))
575
+
576
+ assert 'Not supported in C' in C89CodePrinter({'strict': False}).doprint(SparseMatrix([[1, 2, 3]]))
577
+
578
+
579
+
580
+ def test_ccode_reserved_words():
581
+ x, y = symbols('x, if')
582
+ with raises(ValueError):
583
+ ccode(y**2, error_on_reserved=True, standard='C99')
584
+ assert ccode(y**2) == 'pow(if_, 2)'
585
+ assert ccode(x * y**2, dereference=[y]) == 'pow((*if_), 2)*x'
586
+ assert ccode(y**2, reserved_word_suffix='_unreserved') == 'pow(if_unreserved, 2)'
587
+
588
+
589
+ def test_ccode_sign():
590
+ expr1, ref1 = sign(x) * y, 'y*(((x) > 0) - ((x) < 0))'
591
+ expr2, ref2 = sign(cos(x)), '(((cos(x)) > 0) - ((cos(x)) < 0))'
592
+ expr3, ref3 = sign(2 * x + x**2) * x + x**2, 'pow(x, 2) + x*(((pow(x, 2) + 2*x) > 0) - ((pow(x, 2) + 2*x) < 0))'
593
+ assert ccode(expr1) == ref1
594
+ assert ccode(expr1, 'z') == 'z = %s;' % ref1
595
+ assert ccode(expr2) == ref2
596
+ assert ccode(expr3) == ref3
597
+
598
+ def test_ccode_Assignment():
599
+ assert ccode(Assignment(x, y + z)) == 'x = y + z;'
600
+ assert ccode(aug_assign(x, '+', y + z)) == 'x += y + z;'
601
+
602
+
603
+ def test_ccode_For():
604
+ f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)])
605
+ assert ccode(f) == ("for (x = 0; x < 10; x += 2) {\n"
606
+ " y *= x;\n"
607
+ "}")
608
+
609
+ def test_ccode_Max_Min():
610
+ assert ccode(Max(x, 0), standard='C89') == '((0 > x) ? 0 : x)'
611
+ assert ccode(Max(x, 0), standard='C99') == 'fmax(0, x)'
612
+ assert ccode(Min(x, 0, sqrt(x)), standard='c89') == (
613
+ '((0 < ((x < sqrt(x)) ? x : sqrt(x))) ? 0 : ((x < sqrt(x)) ? x : sqrt(x)))'
614
+ )
615
+
616
+ def test_ccode_standard():
617
+ assert ccode(expm1(x), standard='c99') == 'expm1(x)'
618
+ assert ccode(nan, standard='c99') == 'NAN'
619
+ assert ccode(float('nan'), standard='c99') == 'NAN'
620
+
621
+
622
+ def test_C89CodePrinter():
623
+ c89printer = C89CodePrinter()
624
+ assert c89printer.language == 'C'
625
+ assert c89printer.standard == 'C89'
626
+ assert 'void' in c89printer.reserved_words
627
+ assert 'template' not in c89printer.reserved_words
628
+ assert c89printer.doprint(log10(x)) == 'log10(x)'
629
+
630
+
631
+ def test_C99CodePrinter():
632
+ assert C99CodePrinter().doprint(expm1(x)) == 'expm1(x)'
633
+ assert C99CodePrinter().doprint(log1p(x)) == 'log1p(x)'
634
+ assert C99CodePrinter().doprint(exp2(x)) == 'exp2(x)'
635
+ assert C99CodePrinter().doprint(log2(x)) == 'log2(x)'
636
+ assert C99CodePrinter().doprint(fma(x, y, -z)) == 'fma(x, y, -z)'
637
+ assert C99CodePrinter().doprint(log10(x)) == 'log10(x)'
638
+ assert C99CodePrinter().doprint(Cbrt(x)) == 'cbrt(x)' # note Cbrt due to cbrt already taken.
639
+ assert C99CodePrinter().doprint(hypot(x, y)) == 'hypot(x, y)'
640
+ assert C99CodePrinter().doprint(loggamma(x)) == 'lgamma(x)'
641
+ assert C99CodePrinter().doprint(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))'
642
+ assert C99CodePrinter().doprint(Min(x, 3)) == 'fmin(3, x)'
643
+ c99printer = C99CodePrinter()
644
+ assert c99printer.language == 'C'
645
+ assert c99printer.standard == 'C99'
646
+ assert 'restrict' in c99printer.reserved_words
647
+ assert 'using' not in c99printer.reserved_words
648
+
649
+
650
+ @XFAIL
651
+ def test_C99CodePrinter__precision_f80():
652
+ f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
653
+ assert f80_printer.doprint(sin(x + Float('2.1'))) == 'sinl(x + 2.1L)'
654
+
655
+
656
+ def test_C99CodePrinter__precision():
657
+ n = symbols('n', integer=True)
658
+ p = symbols('p', integer=True, positive=True)
659
+ f32_printer = C99CodePrinter({"type_aliases": {real: float32}})
660
+ f64_printer = C99CodePrinter({"type_aliases": {real: float64}})
661
+ f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
662
+ assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
663
+ assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
664
+ assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'
665
+
666
+ for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']):
667
+ def check(expr, ref):
668
+ assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper())
669
+ check(Abs(n), 'abs(n)')
670
+ check(Abs(x + 2.0), 'fabs{s}(x + 2.0{S})')
671
+ check(sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))')
672
+ check(exp(x*8.0), 'exp{s}(8.0{S}*x)')
673
+ check(exp2(x), 'exp2{s}(x)')
674
+ check(expm1(x*4.0), 'expm1{s}(4.0{S}*x)')
675
+ check(Mod(p, 2), 'p % 2')
676
+ check(Mod(2*p + 3, 3*p + 5, evaluate=False), '(2*p + 3) % (3*p + 5)')
677
+ check(Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})')
678
+ check(Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})')
679
+ check(log(x/2), 'log{s}((1.0{S}/2.0{S})*x)')
680
+ check(log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)')
681
+ check(log2(x*8.0), 'log2{s}(8.0{S}*x)')
682
+ check(log1p(x), 'log1p{s}(x)')
683
+ check(2**x, 'pow{s}(2, x)')
684
+ check(2.0**x, 'pow{s}(2.0{S}, x)')
685
+ check(x**3, 'pow{s}(x, 3)')
686
+ check(x**4.0, 'pow{s}(x, 4.0{S})')
687
+ check(sqrt(3+x), 'sqrt{s}(x + 3)')
688
+ check(Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})')
689
+ check(hypot(x, y), 'hypot{s}(x, y)')
690
+ check(sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})')
691
+ check(cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})')
692
+ check(tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})')
693
+ check(asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})')
694
+ check(acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})')
695
+ check(atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})')
696
+ check(atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)')
697
+
698
+ check(sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})')
699
+ check(cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})')
700
+ check(tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})')
701
+ check(asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})')
702
+ check(acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})')
703
+ check(atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})')
704
+ check(erf(42.*x), 'erf{s}(42.0{S}*x)')
705
+ check(erfc(42.*x), 'erfc{s}(42.0{S}*x)')
706
+ check(gamma(x), 'tgamma{s}(x)')
707
+ check(loggamma(x), 'lgamma{s}(x)')
708
+
709
+ check(ceiling(x + 2.), "ceil{s}(x) + 2")
710
+ check(floor(x + 2.), "floor{s}(x) + 2")
711
+ check(fma(x, y, -z), 'fma{s}(x, y, -z)')
712
+ check(Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))')
713
+ check(Min(x, 2.0), 'fmin{s}(2.0{S}, x)')
714
+
715
+
716
+ def test_get_math_macros():
717
+ macros = get_math_macros()
718
+ assert macros[exp(1)] == 'M_E'
719
+ assert macros[1/Sqrt(2)] == 'M_SQRT1_2'
720
+
721
+
722
+ def test_ccode_Declaration():
723
+ i = symbols('i', integer=True)
724
+ var1 = Variable(i, type=Type.from_expr(i))
725
+ dcl1 = Declaration(var1)
726
+ assert ccode(dcl1) == 'int i'
727
+
728
+ var2 = Variable(x, type=float32, attrs={value_const})
729
+ dcl2a = Declaration(var2)
730
+ assert ccode(dcl2a) == 'const float x'
731
+ dcl2b = var2.as_Declaration(value=pi)
732
+ assert ccode(dcl2b) == 'const float x = M_PI'
733
+
734
+ var3 = Variable(y, type=Type('bool'))
735
+ dcl3 = Declaration(var3)
736
+ printer = C89CodePrinter()
737
+ assert 'stdbool.h' not in printer.headers
738
+ assert printer.doprint(dcl3) == 'bool y'
739
+ assert 'stdbool.h' in printer.headers
740
+
741
+ u = symbols('u', real=True)
742
+ ptr4 = Pointer.deduced(u, attrs={pointer_const, restrict})
743
+ dcl4 = Declaration(ptr4)
744
+ assert ccode(dcl4) == 'double * const restrict u'
745
+
746
+ var5 = Variable(x, Type('__float128'), attrs={value_const})
747
+ dcl5a = Declaration(var5)
748
+ assert ccode(dcl5a) == 'const __float128 x'
749
+ var5b = Variable(var5.symbol, var5.type, pi, attrs=var5.attrs)
750
+ dcl5b = Declaration(var5b)
751
+ assert ccode(dcl5b) == 'const __float128 x = M_PI'
752
+
753
+
754
+ def test_C99CodePrinter_custom_type():
755
+ # We will look at __float128 (new in glibc 2.26)
756
+ f128 = FloatType('_Float128', float128.nbits, float128.nmant, float128.nexp)
757
+ p128 = C99CodePrinter({
758
+ "type_aliases": {real: f128},
759
+ "type_literal_suffixes": {f128: 'Q'},
760
+ "type_func_suffixes": {f128: 'f128'},
761
+ "type_math_macro_suffixes": {
762
+ real: 'f128',
763
+ f128: 'f128'
764
+ },
765
+ "type_macros": {
766
+ f128: ('__STDC_WANT_IEC_60559_TYPES_EXT__',)
767
+ }
768
+ })
769
+ assert p128.doprint(x) == 'x'
770
+ assert not p128.headers
771
+ assert not p128.libraries
772
+ assert not p128.macros
773
+ assert p128.doprint(2.0) == '2.0Q'
774
+ assert not p128.headers
775
+ assert not p128.libraries
776
+ assert p128.macros == {'__STDC_WANT_IEC_60559_TYPES_EXT__'}
777
+
778
+ assert p128.doprint(Rational(1, 2)) == '1.0Q/2.0Q'
779
+ assert p128.doprint(sin(x)) == 'sinf128(x)'
780
+ assert p128.doprint(cos(2., evaluate=False)) == 'cosf128(2.0Q)'
781
+ assert p128.doprint(x**-1.0) == '1.0Q/x'
782
+
783
+ var5 = Variable(x, f128, attrs={value_const})
784
+
785
+ dcl5a = Declaration(var5)
786
+ assert ccode(dcl5a) == 'const _Float128 x'
787
+ var5b = Variable(x, f128, pi, attrs={value_const})
788
+ dcl5b = Declaration(var5b)
789
+ assert p128.doprint(dcl5b) == 'const _Float128 x = M_PIf128'
790
+ var5b = Variable(x, f128, value=Catalan.evalf(38), attrs={value_const})
791
+ dcl5c = Declaration(var5b)
792
+ assert p128.doprint(dcl5c) == 'const _Float128 x = %sQ' % Catalan.evalf(f128.decimal_dig)
793
+
794
+
795
+ def test_MatrixElement_printing():
796
+ # test cases for issue #11821
797
+ A = MatrixSymbol("A", 1, 3)
798
+ B = MatrixSymbol("B", 1, 3)
799
+ C = MatrixSymbol("C", 1, 3)
800
+
801
+ assert(ccode(A[0, 0]) == "A[0]")
802
+ assert(ccode(3 * A[0, 0]) == "3*A[0]")
803
+
804
+ F = C[0, 0].subs(C, A - B)
805
+ assert(ccode(F) == "(A - B)[0]")
806
+
807
+ def test_ccode_math_macros():
808
+ assert ccode(z + exp(1)) == 'z + M_E'
809
+ assert ccode(z + log2(exp(1))) == 'z + M_LOG2E'
810
+ assert ccode(z + 1/log(2)) == 'z + M_LOG2E'
811
+ assert ccode(z + log(2)) == 'z + M_LN2'
812
+ assert ccode(z + log(10)) == 'z + M_LN10'
813
+ assert ccode(z + pi) == 'z + M_PI'
814
+ assert ccode(z + pi/2) == 'z + M_PI_2'
815
+ assert ccode(z + pi/4) == 'z + M_PI_4'
816
+ assert ccode(z + 1/pi) == 'z + M_1_PI'
817
+ assert ccode(z + 2/pi) == 'z + M_2_PI'
818
+ assert ccode(z + 2/sqrt(pi)) == 'z + M_2_SQRTPI'
819
+ assert ccode(z + 2/Sqrt(pi)) == 'z + M_2_SQRTPI'
820
+ assert ccode(z + sqrt(2)) == 'z + M_SQRT2'
821
+ assert ccode(z + Sqrt(2)) == 'z + M_SQRT2'
822
+ assert ccode(z + 1/sqrt(2)) == 'z + M_SQRT1_2'
823
+ assert ccode(z + 1/Sqrt(2)) == 'z + M_SQRT1_2'
824
+
825
+
826
+ def test_ccode_Type():
827
+ assert ccode(Type('float')) == 'float'
828
+ assert ccode(intc) == 'int'
829
+
830
+
831
+ def test_ccode_codegen_ast():
832
+ # Note that C only allows comments of the form /* ... */, double forward
833
+ # slash is not standard C, and some C compilers will grind to a halt upon
834
+ # encountering them.
835
+ assert ccode(Comment("this is a comment")) == "/* this is a comment */" # not //
836
+ assert ccode(While(abs(x) > 1, [aug_assign(x, '-', 1)])) == (
837
+ 'while (fabs(x) > 1) {\n'
838
+ ' x -= 1;\n'
839
+ '}'
840
+ )
841
+ assert ccode(Scope([AddAugmentedAssignment(x, 1)])) == (
842
+ '{\n'
843
+ ' x += 1;\n'
844
+ '}'
845
+ )
846
+ inp_x = Declaration(Variable(x, type=real))
847
+ assert ccode(FunctionPrototype(real, 'pwer', [inp_x])) == 'double pwer(double x)'
848
+ assert ccode(FunctionDefinition(real, 'pwer', [inp_x], [Assignment(x, x**2)])) == (
849
+ 'double pwer(double x){\n'
850
+ ' x = pow(x, 2);\n'
851
+ '}'
852
+ )
853
+
854
+ # Elements of CodeBlock are formatted as statements:
855
+ block = CodeBlock(
856
+ x,
857
+ Print([x, y], "%d %d"),
858
+ Print([QuotedString('hello'), y], "%s %d", file=stderr),
859
+ FunctionCall('pwer', [x]),
860
+ Return(x),
861
+ )
862
+ assert ccode(block) == '\n'.join([
863
+ 'x;',
864
+ 'printf("%d %d", x, y);',
865
+ 'fprintf(stderr, "%s %d", "hello", y);',
866
+ 'pwer(x);',
867
+ 'return x;',
868
+ ])
869
+
870
+ def test_ccode_UnevaluatedExpr():
871
+ assert ccode(UnevaluatedExpr(y * x) + z) == "z + x*y"
872
+ assert ccode(UnevaluatedExpr(y + x) + z) == "z + (x + y)" # gh-21955
873
+ w = symbols('w')
874
+ assert ccode(UnevaluatedExpr(y + x) + UnevaluatedExpr(z + w)) == "(w + z) + (x + y)"
875
+
876
+ p, q, r = symbols("p q r", real=True)
877
+ q_r = UnevaluatedExpr(q + r)
878
+ expr = abs(exp(p+q_r))
879
+ assert ccode(expr) == "exp(p + (q + r))"
880
+
881
+
882
+ def test_ccode_array_like_containers():
883
+ assert ccode([2,3,4]) == "{2, 3, 4}"
884
+ assert ccode((2,3,4)) == "{2, 3, 4}"
885
+
886
+ def test_ccode__isinf_isnan():
887
+ assert ccode(isinf(x)) == 'isinf(x)'
888
+ assert ccode(isnan(x)) == 'isnan(x)'
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_codeprinter.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.printing.codeprinter import CodePrinter, PrintMethodNotImplementedError
2
+ from sympy.core import symbols
3
+ from sympy.core.symbol import Dummy
4
+ from sympy.testing.pytest import raises
5
+ from sympy import cos
6
+ from sympy.utilities.lambdify import lambdify
7
+ from math import cos as math_cos
8
+ from sympy.printing.lambdarepr import LambdaPrinter
9
+
10
+
11
+ def setup_test_printer(**kwargs):
12
+ p = CodePrinter(settings=kwargs)
13
+ p._not_supported = set()
14
+ p._number_symbols = set()
15
+ return p
16
+
17
+
18
+ def test_print_Dummy():
19
+ d = Dummy('d')
20
+ p = setup_test_printer()
21
+ assert p._print_Dummy(d) == "d_%i" % d.dummy_index
22
+
23
+ def test_print_Symbol():
24
+
25
+ x, y = symbols('x, if')
26
+
27
+ p = setup_test_printer()
28
+ assert p._print(x) == 'x'
29
+ assert p._print(y) == 'if'
30
+
31
+ p.reserved_words.update(['if'])
32
+ assert p._print(y) == 'if_'
33
+
34
+ p = setup_test_printer(error_on_reserved=True)
35
+ p.reserved_words.update(['if'])
36
+ with raises(ValueError):
37
+ p._print(y)
38
+
39
+ p = setup_test_printer(reserved_word_suffix='_He_Man')
40
+ p.reserved_words.update(['if'])
41
+ assert p._print(y) == 'if_He_Man'
42
+
43
+
44
+ def test_lambdify_LaTeX_symbols_issue_23374():
45
+ # Create symbols with Latex style names
46
+ x1, x2 = symbols("x_{1} x_2")
47
+
48
+ # Lambdify the function
49
+ f1 = lambdify([x1, x2], cos(x1 ** 2 + x2 ** 2))
50
+
51
+ # Test that the function works correctly (numerically)
52
+ assert f1(1, 2) == math_cos(1 ** 2 + 2 ** 2)
53
+
54
+ # Explicitly generate a custom printer to verify the naming convention
55
+ p = LambdaPrinter()
56
+ expr_str = p.doprint(cos(x1 ** 2 + x2 ** 2))
57
+ assert 'x_1' in expr_str
58
+ assert 'x_2' in expr_str
59
+
60
+
61
+ def test_issue_15791():
62
+ class CrashingCodePrinter(CodePrinter):
63
+ def emptyPrinter(self, obj):
64
+ raise NotImplementedError
65
+
66
+ from sympy.matrices import (
67
+ MutableSparseMatrix,
68
+ ImmutableSparseMatrix,
69
+ )
70
+
71
+ c = CrashingCodePrinter()
72
+
73
+ # these should not silently succeed
74
+ with raises(PrintMethodNotImplementedError):
75
+ c.doprint(ImmutableSparseMatrix(2, 2, {}))
76
+ with raises(PrintMethodNotImplementedError):
77
+ c.doprint(MutableSparseMatrix(2, 2, {}))
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_conventions.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from sympy.core.function import (Derivative, Function)
4
+ from sympy.core.numbers import oo
5
+ from sympy.core.symbol import symbols
6
+ from sympy.functions.elementary.exponential import exp
7
+ from sympy.functions.elementary.trigonometric import cos
8
+ from sympy.integrals.integrals import Integral
9
+ from sympy.functions.special.bessel import besselj
10
+ from sympy.functions.special.polynomials import legendre
11
+ from sympy.functions.combinatorial.numbers import bell
12
+ from sympy.printing.conventions import split_super_sub, requires_partial
13
+ from sympy.testing.pytest import XFAIL
14
+
15
+ def test_super_sub():
16
+ assert split_super_sub("beta_13_2") == ("beta", [], ["13", "2"])
17
+ assert split_super_sub("beta_132_20") == ("beta", [], ["132", "20"])
18
+ assert split_super_sub("beta_13") == ("beta", [], ["13"])
19
+ assert split_super_sub("x_a_b") == ("x", [], ["a", "b"])
20
+ assert split_super_sub("x_1_2_3") == ("x", [], ["1", "2", "3"])
21
+ assert split_super_sub("x_a_b1") == ("x", [], ["a", "b1"])
22
+ assert split_super_sub("x_a_1") == ("x", [], ["a", "1"])
23
+ assert split_super_sub("x_1_a") == ("x", [], ["1", "a"])
24
+ assert split_super_sub("x_1^aa") == ("x", ["aa"], ["1"])
25
+ assert split_super_sub("x_1__aa") == ("x", ["aa"], ["1"])
26
+ assert split_super_sub("x_11^a") == ("x", ["a"], ["11"])
27
+ assert split_super_sub("x_11__a") == ("x", ["a"], ["11"])
28
+ assert split_super_sub("x_a_b_c_d") == ("x", [], ["a", "b", "c", "d"])
29
+ assert split_super_sub("x_a_b^c^d") == ("x", ["c", "d"], ["a", "b"])
30
+ assert split_super_sub("x_a_b__c__d") == ("x", ["c", "d"], ["a", "b"])
31
+ assert split_super_sub("x_a^b_c^d") == ("x", ["b", "d"], ["a", "c"])
32
+ assert split_super_sub("x_a__b_c__d") == ("x", ["b", "d"], ["a", "c"])
33
+ assert split_super_sub("x^a^b_c_d") == ("x", ["a", "b"], ["c", "d"])
34
+ assert split_super_sub("x__a__b_c_d") == ("x", ["a", "b"], ["c", "d"])
35
+ assert split_super_sub("x^a^b^c^d") == ("x", ["a", "b", "c", "d"], [])
36
+ assert split_super_sub("x__a__b__c__d") == ("x", ["a", "b", "c", "d"], [])
37
+ assert split_super_sub("alpha_11") == ("alpha", [], ["11"])
38
+ assert split_super_sub("alpha_11_11") == ("alpha", [], ["11", "11"])
39
+ assert split_super_sub("w1") == ("w", [], ["1"])
40
+ assert split_super_sub("w𝟙") == ("w", [], ["𝟙"])
41
+ assert split_super_sub("w11") == ("w", [], ["11"])
42
+ assert split_super_sub("w𝟙𝟙") == ("w", [], ["𝟙𝟙"])
43
+ assert split_super_sub("w𝟙2𝟙") == ("w", [], ["𝟙2𝟙"])
44
+ assert split_super_sub("w1^a") == ("w", ["a"], ["1"])
45
+ assert split_super_sub("ω1") == ("ω", [], ["1"])
46
+ assert split_super_sub("ω11") == ("ω", [], ["11"])
47
+ assert split_super_sub("ω1^a") == ("ω", ["a"], ["1"])
48
+ assert split_super_sub("ω𝟙^α") == ("ω", ["α"], ["𝟙"])
49
+ assert split_super_sub("ω𝟙2^3α") == ("ω", ["3α"], ["𝟙2"])
50
+ assert split_super_sub("") == ("", [], [])
51
+
52
+
53
+ def test_requires_partial():
54
+ x, y, z, t, nu = symbols('x y z t nu')
55
+ n = symbols('n', integer=True)
56
+
57
+ f = x * y
58
+ assert requires_partial(Derivative(f, x)) is True
59
+ assert requires_partial(Derivative(f, y)) is True
60
+
61
+ ## integrating out one of the variables
62
+ assert requires_partial(Derivative(Integral(exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False
63
+
64
+ ## bessel function with smooth parameter
65
+ f = besselj(nu, x)
66
+ assert requires_partial(Derivative(f, x)) is True
67
+ assert requires_partial(Derivative(f, nu)) is True
68
+
69
+ ## bessel function with integer parameter
70
+ f = besselj(n, x)
71
+ assert requires_partial(Derivative(f, x)) is False
72
+ # this is not really valid (differentiating with respect to an integer)
73
+ # but there's no reason to use the partial derivative symbol there. make
74
+ # sure we don't throw an exception here, though
75
+ assert requires_partial(Derivative(f, n)) is False
76
+
77
+ ## bell polynomial
78
+ f = bell(n, x)
79
+ assert requires_partial(Derivative(f, x)) is False
80
+ # again, invalid
81
+ assert requires_partial(Derivative(f, n)) is False
82
+
83
+ ## legendre polynomial
84
+ f = legendre(0, x)
85
+ assert requires_partial(Derivative(f, x)) is False
86
+
87
+ f = legendre(n, x)
88
+ assert requires_partial(Derivative(f, x)) is False
89
+ # again, invalid
90
+ assert requires_partial(Derivative(f, n)) is False
91
+
92
+ f = x ** n
93
+ assert requires_partial(Derivative(f, x)) is False
94
+
95
+ assert requires_partial(Derivative(Integral((x*y) ** n * exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False
96
+
97
+ # parametric equation
98
+ f = (exp(t), cos(t))
99
+ g = sum(f)
100
+ assert requires_partial(Derivative(g, t)) is False
101
+
102
+ f = symbols('f', cls=Function)
103
+ assert requires_partial(Derivative(f(x), x)) is False
104
+ assert requires_partial(Derivative(f(x), y)) is False
105
+ assert requires_partial(Derivative(f(x, y), x)) is True
106
+ assert requires_partial(Derivative(f(x, y), y)) is True
107
+ assert requires_partial(Derivative(f(x, y), z)) is True
108
+ assert requires_partial(Derivative(f(x, y), x, y)) is True
109
+
110
+ @XFAIL
111
+ def test_requires_partial_unspecified_variables():
112
+ x, y = symbols('x y')
113
+ # function of unspecified variables
114
+ f = symbols('f', cls=Function)
115
+ assert requires_partial(Derivative(f, x)) is False
116
+ assert requires_partial(Derivative(f, x, y)) is True
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_cupy.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete.summations import Sum
2
+ from sympy.functions.elementary.exponential import log
3
+ from sympy.functions.elementary.miscellaneous import sqrt
4
+ from sympy.utilities.lambdify import lambdify
5
+ from sympy.abc import x, i, a, b
6
+ from sympy.codegen.numpy_nodes import logaddexp
7
+ from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions
8
+
9
+ from sympy.testing.pytest import skip, raises
10
+ from sympy.external import import_module
11
+
12
+ cp = import_module('cupy')
13
+
14
+ def test_cupy_print():
15
+ prntr = CuPyPrinter()
16
+ assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)'
17
+ assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)'
18
+ assert prntr.doprint(log(x)) == 'cupy.log(x)'
19
+ assert prntr.doprint("acos(x)") == 'cupy.arccos(x)'
20
+ assert prntr.doprint("exp(x)") == 'cupy.exp(x)'
21
+ assert prntr.doprint("Abs(x)") == 'abs(x)'
22
+
23
+ def test_not_cupy_print():
24
+ prntr = CuPyPrinter()
25
+ with raises(NotImplementedError):
26
+ prntr.doprint("abcd(x)")
27
+
28
+ def test_cupy_sum():
29
+ if not cp:
30
+ skip("CuPy not installed")
31
+
32
+ s = Sum(x ** i, (i, a, b))
33
+ f = lambdify((a, b, x), s, 'cupy')
34
+
35
+ a_, b_ = 0, 10
36
+ x_ = cp.linspace(-1, +1, 10)
37
+ assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
38
+
39
+ s = Sum(i * x, (i, a, b))
40
+ f = lambdify((a, b, x), s, 'numpy')
41
+
42
+ a_, b_ = 0, 10
43
+ x_ = cp.linspace(-1, +1, 10)
44
+ assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
45
+
46
+ def test_cupy_known_funcs_consts():
47
+ assert _cupy_known_constants['NaN'] == 'cupy.nan'
48
+ assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma'
49
+
50
+ assert _cupy_known_functions['acos'] == 'cupy.arccos'
51
+ assert _cupy_known_functions['log'] == 'cupy.log'
52
+
53
+ def test_cupy_print_methods():
54
+ prntr = CuPyPrinter()
55
+ assert hasattr(prntr, '_print_acos')
56
+ assert hasattr(prntr, '_print_log')
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_cxx.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.numbers import Float, Integer, Rational
2
+ from sympy.core.symbol import symbols
3
+ from sympy.functions import beta, Ei, zeta, Max, Min, sqrt, riemann_xi, frac
4
+ from sympy.printing.cxx import CXX98CodePrinter, CXX11CodePrinter, CXX17CodePrinter, cxxcode
5
+ from sympy.codegen.cfunctions import log1p
6
+
7
+
8
+ x, y, u, v = symbols('x y u v')
9
+
10
+
11
+ def test_CXX98CodePrinter():
12
+ assert CXX98CodePrinter().doprint(Max(x, 3)) in ('std::max(x, 3)', 'std::max(3, x)')
13
+ assert CXX98CodePrinter().doprint(Min(x, 3, sqrt(x))) == 'std::min(3, std::min(x, std::sqrt(x)))'
14
+ cxx98printer = CXX98CodePrinter()
15
+ assert cxx98printer.language == 'C++'
16
+ assert cxx98printer.standard == 'C++98'
17
+ assert 'template' in cxx98printer.reserved_words
18
+ assert 'alignas' not in cxx98printer.reserved_words
19
+
20
+
21
+ def test_CXX11CodePrinter():
22
+ assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)'
23
+
24
+ cxx11printer = CXX11CodePrinter()
25
+ assert cxx11printer.language == 'C++'
26
+ assert cxx11printer.standard == 'C++11'
27
+ assert 'operator' in cxx11printer.reserved_words
28
+ assert 'noexcept' in cxx11printer.reserved_words
29
+ assert 'concept' not in cxx11printer.reserved_words
30
+
31
+
32
+ def test_subclass_print_method():
33
+ class MyPrinter(CXX11CodePrinter):
34
+ def _print_log1p(self, expr):
35
+ return 'my_library::log1p(%s)' % ', '.join(map(self._print, expr.args))
36
+
37
+ assert MyPrinter().doprint(log1p(x)) == 'my_library::log1p(x)'
38
+
39
+
40
+ def test_subclass_print_method__ns():
41
+ class MyPrinter(CXX11CodePrinter):
42
+ _ns = 'my_library::'
43
+
44
+ p = CXX11CodePrinter()
45
+ myp = MyPrinter()
46
+
47
+ assert p.doprint(log1p(x)) == 'std::log1p(x)'
48
+ assert myp.doprint(log1p(x)) == 'my_library::log1p(x)'
49
+
50
+
51
+ def test_CXX17CodePrinter():
52
+ assert CXX17CodePrinter().doprint(beta(x, y)) == 'std::beta(x, y)'
53
+ assert CXX17CodePrinter().doprint(Ei(x)) == 'std::expint(x)'
54
+ assert CXX17CodePrinter().doprint(zeta(x)) == 'std::riemann_zeta(x)'
55
+
56
+ # Automatic rewrite
57
+ assert CXX17CodePrinter().doprint(frac(x)) == '(x - std::floor(x))'
58
+ assert CXX17CodePrinter().doprint(riemann_xi(x)) == '((1.0/2.0)*std::pow(M_PI, -1.0/2.0*x)*x*(x - 1)*std::tgamma((1.0/2.0)*x)*std::riemann_zeta(x))'
59
+
60
+
61
+ def test_cxxcode():
62
+ assert sorted(cxxcode(sqrt(x)*.5).split('*')) == sorted(['0.5', 'std::sqrt(x)'])
63
+
64
+ def test_cxxcode_nested_minmax():
65
+ assert cxxcode(Max(Min(x, y), Min(u, v))) \
66
+ == 'std::max(std::min(u, v), std::min(x, y))'
67
+ assert cxxcode(Min(Max(x, y), Max(u, v))) \
68
+ == 'std::min(std::max(u, v), std::max(x, y))'
69
+
70
+ def test_subclass_Integer_Float():
71
+ class MyPrinter(CXX17CodePrinter):
72
+ def _print_Integer(self, arg):
73
+ return 'bigInt("%s")' % super()._print_Integer(arg)
74
+
75
+ def _print_Float(self, arg):
76
+ rat = Rational(arg)
77
+ return 'bigFloat(%s, %s)' % (
78
+ self._print(Integer(rat.p)),
79
+ self._print(Integer(rat.q))
80
+ )
81
+
82
+ p = MyPrinter()
83
+ for i in range(13):
84
+ assert p.doprint(i) == 'bigInt("%d")' % i
85
+ assert p.doprint(Float(0.5)) == 'bigFloat(bigInt("1"), bigInt("2"))'
86
+ assert p.doprint(x**-1.0) == 'bigFloat(bigInt("1"), bigInt("1"))/x'
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_dot.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.printing.dot import (purestr, styleof, attrprint, dotnode,
2
+ dotedges, dotprint)
3
+ from sympy.core.basic import Basic
4
+ from sympy.core.expr import Expr
5
+ from sympy.core.numbers import (Float, Integer)
6
+ from sympy.core.singleton import S
7
+ from sympy.core.symbol import (Symbol, symbols)
8
+ from sympy.printing.repr import srepr
9
+ from sympy.abc import x
10
+
11
+
12
+ def test_purestr():
13
+ assert purestr(Symbol('x')) == "Symbol('x')"
14
+ assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))"
15
+ assert purestr(Float(2)) == "Float('2.0', precision=53)"
16
+
17
+ assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ())
18
+ assert purestr(Basic(S(1), S(2)), with_args=True) == \
19
+ ('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)'))
20
+ assert purestr(Float(2), with_args=True) == \
21
+ ("Float('2.0', precision=53)", ())
22
+
23
+
24
+ def test_styleof():
25
+ styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
26
+ (Expr, {'color': 'black'})]
27
+ assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'}
28
+
29
+ assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'}
30
+
31
+
32
+ def test_attrprint():
33
+ assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \
34
+ '"color"="blue", "shape"="ellipse"'
35
+
36
+ def test_dotnode():
37
+
38
+ assert dotnode(x, repeat=False) == \
39
+ '"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];'
40
+ assert dotnode(x+2, repeat=False) == \
41
+ '"Add(Integer(2), Symbol(\'x\'))" ' \
42
+ '["color"="black", "label"="Add", "shape"="ellipse"];', \
43
+ dotnode(x+2,repeat=0)
44
+
45
+ assert dotnode(x + x**2, repeat=False) == \
46
+ '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \
47
+ '["color"="black", "label"="Add", "shape"="ellipse"];'
48
+ assert dotnode(x + x**2, repeat=True) == \
49
+ '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \
50
+ '["color"="black", "label"="Add", "shape"="ellipse"];'
51
+
52
+ def test_dotedges():
53
+ assert sorted(dotedges(x+2, repeat=False)) == [
54
+ '"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";',
55
+ '"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";'
56
+ ]
57
+ assert sorted(dotedges(x + 2, repeat=True)) == [
58
+ '"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";',
59
+ '"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";'
60
+ ]
61
+
62
+ def test_dotprint():
63
+ text = dotprint(x+2, repeat=False)
64
+ assert all(e in text for e in dotedges(x+2, repeat=False))
65
+ assert all(
66
+ n in text for n in [dotnode(expr, repeat=False)
67
+ for expr in (x, Integer(2), x+2)])
68
+ assert 'digraph' in text
69
+
70
+ text = dotprint(x+x**2, repeat=False)
71
+ assert all(e in text for e in dotedges(x+x**2, repeat=False))
72
+ assert all(
73
+ n in text for n in [dotnode(expr, repeat=False)
74
+ for expr in (x, Integer(2), x**2)])
75
+ assert 'digraph' in text
76
+
77
+ text = dotprint(x+x**2, repeat=True)
78
+ assert all(e in text for e in dotedges(x+x**2, repeat=True))
79
+ assert all(
80
+ n in text for n in [dotnode(expr, pos=())
81
+ for expr in [x + x**2]])
82
+
83
+ text = dotprint(x**x, repeat=True)
84
+ assert all(e in text for e in dotedges(x**x, repeat=True))
85
+ assert all(
86
+ n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))])
87
+ assert 'digraph' in text
88
+
89
+ def test_dotprint_depth():
90
+ text = dotprint(3*x+2, depth=1)
91
+ assert dotnode(3*x+2) in text
92
+ assert dotnode(x) not in text
93
+ text = dotprint(3*x+2)
94
+ assert "depth" not in text
95
+
96
+ def test_Matrix_and_non_basics():
97
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
98
+ n = Symbol('n')
99
+ assert dotprint(MatrixSymbol('X', n, n)) == \
100
+ """digraph{
101
+
102
+ # Graph style
103
+ "ordering"="out"
104
+ "rankdir"="TD"
105
+
106
+ #########
107
+ # Nodes #
108
+ #########
109
+
110
+ "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"];
111
+ "Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"];
112
+ "Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"];
113
+ "Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"];
114
+
115
+ #########
116
+ # Edges #
117
+ #########
118
+
119
+ "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)";
120
+ "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)";
121
+ "MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)";
122
+ }"""
123
+
124
+
125
+ def test_labelfunc():
126
+ text = dotprint(x + 2, labelfunc=srepr)
127
+ assert "Symbol('x')" in text
128
+ assert "Integer(2)" in text
129
+
130
+
131
+ def test_commutative():
132
+ x, y = symbols('x y', commutative=False)
133
+ assert dotprint(x + y) == dotprint(y + x)
134
+ assert dotprint(x*y) != dotprint(y*x)
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_glsl.py ADDED
@@ -0,0 +1,998 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (pi, symbols, Rational, Integer, GoldenRatio, EulerGamma,
2
+ Catalan, Lambda, Dummy, Eq, Ne, Le, Lt, Gt, Ge)
3
+ from sympy.functions import Piecewise, sin, cos, Abs, exp, ceiling, sqrt
4
+ from sympy.testing.pytest import raises, warns_deprecated_sympy
5
+ from sympy.printing.glsl import GLSLPrinter
6
+ from sympy.printing.str import StrPrinter
7
+ from sympy.utilities.lambdify import implemented_function
8
+ from sympy.tensor import IndexedBase, Idx
9
+ from sympy.matrices import Matrix, MatrixSymbol
10
+ from sympy.core import Tuple
11
+ from sympy.printing.glsl import glsl_code
12
+ import textwrap
13
+
14
+ x, y, z = symbols('x,y,z')
15
+
16
+
17
+ def test_printmethod():
18
+ assert glsl_code(Abs(x)) == "abs(x)"
19
+
20
+ def test_print_without_operators():
21
+ assert glsl_code(x*y,use_operators = False) == 'mul(x, y)'
22
+ assert glsl_code(x**y+z,use_operators = False) == 'add(pow(x, y), z)'
23
+ assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))'
24
+ assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))'
25
+ assert glsl_code(x*(y+z**y**0.5),use_operators = False) == 'mul(x, add(y, pow(z, sqrt(y))))'
26
+ assert glsl_code(-x-y, use_operators=False, zero='zero()') == 'sub(zero(), add(x, y))'
27
+ assert glsl_code(-x-y, use_operators=False) == 'sub(0.0, add(x, y))'
28
+
29
+ def test_glsl_code_sqrt():
30
+ assert glsl_code(sqrt(x)) == "sqrt(x)"
31
+ assert glsl_code(x**0.5) == "sqrt(x)"
32
+ assert glsl_code(sqrt(x)) == "sqrt(x)"
33
+
34
+
35
+ def test_glsl_code_Pow():
36
+ g = implemented_function('g', Lambda(x, 2*x))
37
+ assert glsl_code(x**3) == "pow(x, 3.0)"
38
+ assert glsl_code(x**(y**3)) == "pow(x, pow(y, 3.0))"
39
+ assert glsl_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
40
+ "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2.0) + y)"
41
+ assert glsl_code(x**-1.0) == '1.0/x'
42
+
43
+
44
+ def test_glsl_code_Relational():
45
+ assert glsl_code(Eq(x, y)) == "x == y"
46
+ assert glsl_code(Ne(x, y)) == "x != y"
47
+ assert glsl_code(Le(x, y)) == "x <= y"
48
+ assert glsl_code(Lt(x, y)) == "x < y"
49
+ assert glsl_code(Gt(x, y)) == "x > y"
50
+ assert glsl_code(Ge(x, y)) == "x >= y"
51
+
52
+
53
+ def test_glsl_code_constants_mathh():
54
+ assert glsl_code(exp(1)) == "float E = 2.71828183;\nE"
55
+ assert glsl_code(pi) == "float pi = 3.14159265;\npi"
56
+ # assert glsl_code(oo) == "Number.POSITIVE_INFINITY"
57
+ # assert glsl_code(-oo) == "Number.NEGATIVE_INFINITY"
58
+
59
+
60
+ def test_glsl_code_constants_other():
61
+ assert glsl_code(2*GoldenRatio) == "float GoldenRatio = 1.61803399;\n2*GoldenRatio"
62
+ assert glsl_code(2*Catalan) == "float Catalan = 0.915965594;\n2*Catalan"
63
+ assert glsl_code(2*EulerGamma) == "float EulerGamma = 0.577215665;\n2*EulerGamma"
64
+
65
+
66
+ def test_glsl_code_Rational():
67
+ assert glsl_code(Rational(3, 7)) == "3.0/7.0"
68
+ assert glsl_code(Rational(18, 9)) == "2"
69
+ assert glsl_code(Rational(3, -7)) == "-3.0/7.0"
70
+ assert glsl_code(Rational(-3, -7)) == "3.0/7.0"
71
+
72
+
73
+ def test_glsl_code_Integer():
74
+ assert glsl_code(Integer(67)) == "67"
75
+ assert glsl_code(Integer(-1)) == "-1"
76
+
77
+
78
+ def test_glsl_code_functions():
79
+ assert glsl_code(sin(x) ** cos(x)) == "pow(sin(x), cos(x))"
80
+
81
+
82
+ def test_glsl_code_inline_function():
83
+ x = symbols('x')
84
+ g = implemented_function('g', Lambda(x, 2*x))
85
+ assert glsl_code(g(x)) == "2*x"
86
+ g = implemented_function('g', Lambda(x, 2*x/Catalan))
87
+ assert glsl_code(g(x)) == "float Catalan = 0.915965594;\n2*x/Catalan"
88
+ A = IndexedBase('A')
89
+ i = Idx('i', symbols('n', integer=True))
90
+ g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
91
+ assert glsl_code(g(A[i]), assign_to=A[i]) == (
92
+ "for (int i=0; i<n; i++){\n"
93
+ " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
94
+ "}"
95
+ )
96
+
97
+
98
+ def test_glsl_code_exceptions():
99
+ assert glsl_code(ceiling(x)) == "ceil(x)"
100
+ assert glsl_code(Abs(x)) == "abs(x)"
101
+
102
+
103
+ def test_glsl_code_boolean():
104
+ assert glsl_code(x & y) == "x && y"
105
+ assert glsl_code(x | y) == "x || y"
106
+ assert glsl_code(~x) == "!x"
107
+ assert glsl_code(x & y & z) == "x && y && z"
108
+ assert glsl_code(x | y | z) == "x || y || z"
109
+ assert glsl_code((x & y) | z) == "z || x && y"
110
+ assert glsl_code((x | y) & z) == "z && (x || y)"
111
+
112
+
113
+ def test_glsl_code_Piecewise():
114
+ expr = Piecewise((x, x < 1), (x**2, True))
115
+ p = glsl_code(expr)
116
+ s = \
117
+ """\
118
+ ((x < 1) ? (
119
+ x
120
+ )
121
+ : (
122
+ pow(x, 2.0)
123
+ ))\
124
+ """
125
+ assert p == s
126
+ assert glsl_code(expr, assign_to="c") == (
127
+ "if (x < 1) {\n"
128
+ " c = x;\n"
129
+ "}\n"
130
+ "else {\n"
131
+ " c = pow(x, 2.0);\n"
132
+ "}")
133
+ # Check that Piecewise without a True (default) condition error
134
+ expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
135
+ raises(ValueError, lambda: glsl_code(expr))
136
+
137
+
138
+ def test_glsl_code_Piecewise_deep():
139
+ p = glsl_code(2*Piecewise((x, x < 1), (x**2, True)))
140
+ s = \
141
+ """\
142
+ 2*((x < 1) ? (
143
+ x
144
+ )
145
+ : (
146
+ pow(x, 2.0)
147
+ ))\
148
+ """
149
+ assert p == s
150
+
151
+
152
+ def test_glsl_code_settings():
153
+ raises(TypeError, lambda: glsl_code(sin(x), method="garbage"))
154
+
155
+
156
+ def test_glsl_code_Indexed():
157
+ n, m, o = symbols('n m o', integer=True)
158
+ i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
159
+ p = GLSLPrinter()
160
+ p._not_c = set()
161
+
162
+ x = IndexedBase('x')[j]
163
+ assert p._print_Indexed(x) == 'x[j]'
164
+ A = IndexedBase('A')[i, j]
165
+ assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
166
+ B = IndexedBase('B')[i, j, k]
167
+ assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
168
+
169
+ assert p._not_c == set()
170
+
171
+ def test_glsl_code_list_tuple_Tuple():
172
+ assert glsl_code([1,2,3,4]) == 'vec4(1, 2, 3, 4)'
173
+ assert glsl_code([1,2,3],glsl_types=False) == 'float[3](1, 2, 3)'
174
+ assert glsl_code([1,2,3]) == glsl_code((1,2,3))
175
+ assert glsl_code([1,2,3]) == glsl_code(Tuple(1,2,3))
176
+
177
+ m = MatrixSymbol('A',3,4)
178
+ assert glsl_code([m[0],m[1]])
179
+
180
+ def test_glsl_code_loops_matrix_vector():
181
+ n, m = symbols('n m', integer=True)
182
+ A = IndexedBase('A')
183
+ x = IndexedBase('x')
184
+ y = IndexedBase('y')
185
+ i = Idx('i', m)
186
+ j = Idx('j', n)
187
+
188
+ s = (
189
+ 'for (int i=0; i<m; i++){\n'
190
+ ' y[i] = 0.0;\n'
191
+ '}\n'
192
+ 'for (int i=0; i<m; i++){\n'
193
+ ' for (int j=0; j<n; j++){\n'
194
+ ' y[i] = A[n*i + j]*x[j] + y[i];\n'
195
+ ' }\n'
196
+ '}'
197
+ )
198
+
199
+ c = glsl_code(A[i, j]*x[j], assign_to=y[i])
200
+ assert c == s
201
+
202
+
203
+ def test_dummy_loops():
204
+ i, m = symbols('i m', integer=True, cls=Dummy)
205
+ x = IndexedBase('x')
206
+ y = IndexedBase('y')
207
+ i = Idx(i, m)
208
+
209
+ expected = (
210
+ 'for (int i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
211
+ ' y[i_%(icount)i] = x[i_%(icount)i];\n'
212
+ '}'
213
+ ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
214
+ code = glsl_code(x[i], assign_to=y[i])
215
+ assert code == expected
216
+
217
+
218
+ def test_glsl_code_loops_add():
219
+ n, m = symbols('n m', integer=True)
220
+ A = IndexedBase('A')
221
+ x = IndexedBase('x')
222
+ y = IndexedBase('y')
223
+ z = IndexedBase('z')
224
+ i = Idx('i', m)
225
+ j = Idx('j', n)
226
+
227
+ s = (
228
+ 'for (int i=0; i<m; i++){\n'
229
+ ' y[i] = x[i] + z[i];\n'
230
+ '}\n'
231
+ 'for (int i=0; i<m; i++){\n'
232
+ ' for (int j=0; j<n; j++){\n'
233
+ ' y[i] = A[n*i + j]*x[j] + y[i];\n'
234
+ ' }\n'
235
+ '}'
236
+ )
237
+ c = glsl_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
238
+ assert c == s
239
+
240
+
241
+ def test_glsl_code_loops_multiple_contractions():
242
+ n, m, o, p = symbols('n m o p', integer=True)
243
+ a = IndexedBase('a')
244
+ b = IndexedBase('b')
245
+ y = IndexedBase('y')
246
+ i = Idx('i', m)
247
+ j = Idx('j', n)
248
+ k = Idx('k', o)
249
+ l = Idx('l', p)
250
+
251
+ s = (
252
+ 'for (int i=0; i<m; i++){\n'
253
+ ' y[i] = 0.0;\n'
254
+ '}\n'
255
+ 'for (int i=0; i<m; i++){\n'
256
+ ' for (int j=0; j<n; j++){\n'
257
+ ' for (int k=0; k<o; k++){\n'
258
+ ' for (int l=0; l<p; l++){\n'
259
+ ' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
260
+ ' }\n'
261
+ ' }\n'
262
+ ' }\n'
263
+ '}'
264
+ )
265
+ c = glsl_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
266
+ assert c == s
267
+
268
+
269
+ def test_glsl_code_loops_addfactor():
270
+ n, m, o, p = symbols('n m o p', integer=True)
271
+ a = IndexedBase('a')
272
+ b = IndexedBase('b')
273
+ c = IndexedBase('c')
274
+ y = IndexedBase('y')
275
+ i = Idx('i', m)
276
+ j = Idx('j', n)
277
+ k = Idx('k', o)
278
+ l = Idx('l', p)
279
+
280
+ s = (
281
+ 'for (int i=0; i<m; i++){\n'
282
+ ' y[i] = 0.0;\n'
283
+ '}\n'
284
+ 'for (int i=0; i<m; i++){\n'
285
+ ' for (int j=0; j<n; j++){\n'
286
+ ' for (int k=0; k<o; k++){\n'
287
+ ' for (int l=0; l<p; l++){\n'
288
+ ' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
289
+ ' }\n'
290
+ ' }\n'
291
+ ' }\n'
292
+ '}'
293
+ )
294
+ c = glsl_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
295
+ assert c == s
296
+
297
+
298
+ def test_glsl_code_loops_multiple_terms():
299
+ n, m, o, p = symbols('n m o p', integer=True)
300
+ a = IndexedBase('a')
301
+ b = IndexedBase('b')
302
+ c = IndexedBase('c')
303
+ y = IndexedBase('y')
304
+ i = Idx('i', m)
305
+ j = Idx('j', n)
306
+ k = Idx('k', o)
307
+
308
+ s0 = (
309
+ 'for (int i=0; i<m; i++){\n'
310
+ ' y[i] = 0.0;\n'
311
+ '}\n'
312
+ )
313
+ s1 = (
314
+ 'for (int i=0; i<m; i++){\n'
315
+ ' for (int j=0; j<n; j++){\n'
316
+ ' for (int k=0; k<o; k++){\n'
317
+ ' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
318
+ ' }\n'
319
+ ' }\n'
320
+ '}\n'
321
+ )
322
+ s2 = (
323
+ 'for (int i=0; i<m; i++){\n'
324
+ ' for (int k=0; k<o; k++){\n'
325
+ ' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
326
+ ' }\n'
327
+ '}\n'
328
+ )
329
+ s3 = (
330
+ 'for (int i=0; i<m; i++){\n'
331
+ ' for (int j=0; j<n; j++){\n'
332
+ ' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
333
+ ' }\n'
334
+ '}\n'
335
+ )
336
+ c = glsl_code(
337
+ b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
338
+ assert (c == s0 + s1 + s2 + s3[:-1] or
339
+ c == s0 + s1 + s3 + s2[:-1] or
340
+ c == s0 + s2 + s1 + s3[:-1] or
341
+ c == s0 + s2 + s3 + s1[:-1] or
342
+ c == s0 + s3 + s1 + s2[:-1] or
343
+ c == s0 + s3 + s2 + s1[:-1])
344
+
345
+
346
+ def test_Matrix_printing():
347
+ # Test returning a Matrix
348
+
349
+ mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
350
+ A = MatrixSymbol('A', 3, 1)
351
+ assert glsl_code(mat, assign_to=A) == (
352
+ '''A[0][0] = x*y;
353
+ if (y > 0) {
354
+ A[1][0] = x + 2;
355
+ }
356
+ else {
357
+ A[1][0] = y;
358
+ }
359
+ A[2][0] = sin(z);''' )
360
+ assert glsl_code(Matrix([A[0],A[1]]))
361
+ # Test using MatrixElements in expressions
362
+ expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
363
+ assert glsl_code(expr) == (
364
+ '''((x > 0) ? (
365
+ 2*A[2][0]
366
+ )
367
+ : (
368
+ A[2][0]
369
+ )) + sin(A[1][0]) + A[0][0]''' )
370
+
371
+ # Test using MatrixElements in a Matrix
372
+ q = MatrixSymbol('q', 5, 1)
373
+ M = MatrixSymbol('M', 3, 3)
374
+ m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
375
+ [q[1,0] + q[2,0], q[3, 0], 5],
376
+ [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
377
+ assert glsl_code(m,M) == (
378
+ '''M[0][0] = sin(q[1]);
379
+ M[0][1] = 0;
380
+ M[0][2] = cos(q[2]);
381
+ M[1][0] = q[1] + q[2];
382
+ M[1][1] = q[3];
383
+ M[1][2] = 5;
384
+ M[2][0] = 2*q[4]/q[1];
385
+ M[2][1] = sqrt(q[0]) + 4;
386
+ M[2][2] = 0;'''
387
+ )
388
+
389
+ def test_Matrices_1x7():
390
+ gl = glsl_code
391
+ A = Matrix([1,2,3,4,5,6,7])
392
+ assert gl(A) == 'float[7](1, 2, 3, 4, 5, 6, 7)'
393
+ assert gl(A.transpose()) == 'float[7](1, 2, 3, 4, 5, 6, 7)'
394
+
395
+ def test_Matrices_1x7_array_type_int():
396
+ gl = glsl_code
397
+ A = Matrix([1,2,3,4,5,6,7])
398
+ assert gl(A, array_type='int') == 'int[7](1, 2, 3, 4, 5, 6, 7)'
399
+
400
+ def test_Tuple_array_type_custom():
401
+ gl = glsl_code
402
+ A = symbols('a b c')
403
+ assert gl(A, array_type='AbcType', glsl_types=False) == 'AbcType[3](a, b, c)'
404
+
405
+ def test_Matrices_1x7_spread_assign_to_symbols():
406
+ gl = glsl_code
407
+ A = Matrix([1,2,3,4,5,6,7])
408
+ assign_to = symbols('x.a x.b x.c x.d x.e x.f x.g')
409
+ assert gl(A, assign_to=assign_to) == textwrap.dedent('''\
410
+ x.a = 1;
411
+ x.b = 2;
412
+ x.c = 3;
413
+ x.d = 4;
414
+ x.e = 5;
415
+ x.f = 6;
416
+ x.g = 7;'''
417
+ )
418
+
419
+ def test_spread_assign_to_nested_symbols():
420
+ gl = glsl_code
421
+ expr = ((1,2,3), (1,2,3))
422
+ assign_to = (symbols('a b c'), symbols('x y z'))
423
+ assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
424
+ a = 1;
425
+ b = 2;
426
+ c = 3;
427
+ x = 1;
428
+ y = 2;
429
+ z = 3;'''
430
+ )
431
+
432
+ def test_spread_assign_to_deeply_nested_symbols():
433
+ gl = glsl_code
434
+ a, b, c, x, y, z = symbols('a b c x y z')
435
+ expr = (((1,2),3), ((1,2),3))
436
+ assign_to = (((a, b), c), ((x, y), z))
437
+ assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
438
+ a = 1;
439
+ b = 2;
440
+ c = 3;
441
+ x = 1;
442
+ y = 2;
443
+ z = 3;'''
444
+ )
445
+
446
+ def test_matrix_of_tuples_spread_assign_to_symbols():
447
+ gl = glsl_code
448
+ with warns_deprecated_sympy():
449
+ expr = Matrix([[(1,2),(3,4)],[(5,6),(7,8)]])
450
+ assign_to = (symbols('a b'), symbols('c d'), symbols('e f'), symbols('g h'))
451
+ assert gl(expr, assign_to) == textwrap.dedent('''\
452
+ a = 1;
453
+ b = 2;
454
+ c = 3;
455
+ d = 4;
456
+ e = 5;
457
+ f = 6;
458
+ g = 7;
459
+ h = 8;'''
460
+ )
461
+
462
+ def test_cannot_assign_to_cause_mismatched_length():
463
+ expr = (1, 2)
464
+ assign_to = symbols('x y z')
465
+ raises(ValueError, lambda: glsl_code(expr, assign_to))
466
+
467
+ def test_matrix_4x4_assign():
468
+ gl = glsl_code
469
+ expr = MatrixSymbol('A',4,4) * MatrixSymbol('B',4,4) + MatrixSymbol('C',4,4)
470
+ assign_to = MatrixSymbol('X',4,4)
471
+ assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\
472
+ X[0][0] = A[0][0]*B[0][0] + A[0][1]*B[1][0] + A[0][2]*B[2][0] + A[0][3]*B[3][0] + C[0][0];
473
+ X[0][1] = A[0][0]*B[0][1] + A[0][1]*B[1][1] + A[0][2]*B[2][1] + A[0][3]*B[3][1] + C[0][1];
474
+ X[0][2] = A[0][0]*B[0][2] + A[0][1]*B[1][2] + A[0][2]*B[2][2] + A[0][3]*B[3][2] + C[0][2];
475
+ X[0][3] = A[0][0]*B[0][3] + A[0][1]*B[1][3] + A[0][2]*B[2][3] + A[0][3]*B[3][3] + C[0][3];
476
+ X[1][0] = A[1][0]*B[0][0] + A[1][1]*B[1][0] + A[1][2]*B[2][0] + A[1][3]*B[3][0] + C[1][0];
477
+ X[1][1] = A[1][0]*B[0][1] + A[1][1]*B[1][1] + A[1][2]*B[2][1] + A[1][3]*B[3][1] + C[1][1];
478
+ X[1][2] = A[1][0]*B[0][2] + A[1][1]*B[1][2] + A[1][2]*B[2][2] + A[1][3]*B[3][2] + C[1][2];
479
+ X[1][3] = A[1][0]*B[0][3] + A[1][1]*B[1][3] + A[1][2]*B[2][3] + A[1][3]*B[3][3] + C[1][3];
480
+ X[2][0] = A[2][0]*B[0][0] + A[2][1]*B[1][0] + A[2][2]*B[2][0] + A[2][3]*B[3][0] + C[2][0];
481
+ X[2][1] = A[2][0]*B[0][1] + A[2][1]*B[1][1] + A[2][2]*B[2][1] + A[2][3]*B[3][1] + C[2][1];
482
+ X[2][2] = A[2][0]*B[0][2] + A[2][1]*B[1][2] + A[2][2]*B[2][2] + A[2][3]*B[3][2] + C[2][2];
483
+ X[2][3] = A[2][0]*B[0][3] + A[2][1]*B[1][3] + A[2][2]*B[2][3] + A[2][3]*B[3][3] + C[2][3];
484
+ X[3][0] = A[3][0]*B[0][0] + A[3][1]*B[1][0] + A[3][2]*B[2][0] + A[3][3]*B[3][0] + C[3][0];
485
+ X[3][1] = A[3][0]*B[0][1] + A[3][1]*B[1][1] + A[3][2]*B[2][1] + A[3][3]*B[3][1] + C[3][1];
486
+ X[3][2] = A[3][0]*B[0][2] + A[3][1]*B[1][2] + A[3][2]*B[2][2] + A[3][3]*B[3][2] + C[3][2];
487
+ X[3][3] = A[3][0]*B[0][3] + A[3][1]*B[1][3] + A[3][2]*B[2][3] + A[3][3]*B[3][3] + C[3][3];'''
488
+ )
489
+
490
+ def test_1xN_vecs():
491
+ gl = glsl_code
492
+ for i in range(1,10):
493
+ A = Matrix(range(i))
494
+ assert gl(A.transpose()) == gl(A)
495
+ assert gl(A,mat_transpose=True) == gl(A)
496
+ if i > 1:
497
+ if i <= 4:
498
+ assert gl(A) == 'vec%s(%s)' % (i,', '.join(str(s) for s in range(i)))
499
+ else:
500
+ assert gl(A) == 'float[%s](%s)' % (i,', '.join(str(s) for s in range(i)))
501
+
502
+ def test_MxN_mats():
503
+ generatedAssertions='def test_misc_mats():\n'
504
+ for i in range(1,6):
505
+ for j in range(1,6):
506
+ A = Matrix([[x + y*j for x in range(j)] for y in range(i)])
507
+ gl = glsl_code(A)
508
+ glTransposed = glsl_code(A,mat_transpose=True)
509
+ generatedAssertions+=' mat = '+StrPrinter()._print(A)+'\n\n'
510
+ generatedAssertions+=' gl = \'\'\''+gl+'\'\'\'\n'
511
+ generatedAssertions+=' glTransposed = \'\'\''+glTransposed+'\'\'\'\n\n'
512
+ generatedAssertions+=' assert glsl_code(mat) == gl\n'
513
+ generatedAssertions+=' assert glsl_code(mat,mat_transpose=True) == glTransposed\n'
514
+ if i == 1 and j == 1:
515
+ assert gl == '0'
516
+ elif i <= 4 and j <= 4 and i>1 and j>1:
517
+ assert gl.startswith('mat%s' % j)
518
+ assert glTransposed.startswith('mat%s' % i)
519
+ elif i == 1 and j <= 4:
520
+ assert gl.startswith('vec')
521
+ elif j == 1 and i <= 4:
522
+ assert gl.startswith('vec')
523
+ elif i == 1:
524
+ assert gl.startswith('float[%s]('% j*i)
525
+ assert glTransposed.startswith('float[%s]('% j*i)
526
+ elif j == 1:
527
+ assert gl.startswith('float[%s]('% i*j)
528
+ assert glTransposed.startswith('float[%s]('% i*j)
529
+ else:
530
+ assert gl.startswith('float[%s](' % (i*j))
531
+ assert glTransposed.startswith('float[%s](' % (i*j))
532
+ glNested = glsl_code(A,mat_nested=True)
533
+ glNestedTransposed = glsl_code(A,mat_transpose=True,mat_nested=True)
534
+ assert glNested.startswith('float[%s][%s]' % (i,j))
535
+ assert glNestedTransposed.startswith('float[%s][%s]' % (j,i))
536
+ generatedAssertions+=' glNested = \'\'\''+glNested+'\'\'\'\n'
537
+ generatedAssertions+=' glNestedTransposed = \'\'\''+glNestedTransposed+'\'\'\'\n\n'
538
+ generatedAssertions+=' assert glsl_code(mat,mat_nested=True) == glNested\n'
539
+ generatedAssertions+=' assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed\n\n'
540
+ generateAssertions = False # set this to true to write bake these generated tests to a file
541
+ if generateAssertions:
542
+ gen = open('test_glsl_generated_matrices.py','w')
543
+ gen.write(generatedAssertions)
544
+ gen.close()
545
+
546
+
547
+ # these assertions were generated from the previous function
548
+ # glsl has complicated rules and this makes it easier to look over all the cases
549
+ def test_misc_mats():
550
+
551
+ mat = Matrix([[0]])
552
+
553
+ gl = '''0'''
554
+ glTransposed = '''0'''
555
+
556
+ assert glsl_code(mat) == gl
557
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
558
+
559
+ mat = Matrix([[0, 1]])
560
+
561
+ gl = '''vec2(0, 1)'''
562
+ glTransposed = '''vec2(0, 1)'''
563
+
564
+ assert glsl_code(mat) == gl
565
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
566
+
567
+ mat = Matrix([[0, 1, 2]])
568
+
569
+ gl = '''vec3(0, 1, 2)'''
570
+ glTransposed = '''vec3(0, 1, 2)'''
571
+
572
+ assert glsl_code(mat) == gl
573
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
574
+
575
+ mat = Matrix([[0, 1, 2, 3]])
576
+
577
+ gl = '''vec4(0, 1, 2, 3)'''
578
+ glTransposed = '''vec4(0, 1, 2, 3)'''
579
+
580
+ assert glsl_code(mat) == gl
581
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
582
+
583
+ mat = Matrix([[0, 1, 2, 3, 4]])
584
+
585
+ gl = '''float[5](0, 1, 2, 3, 4)'''
586
+ glTransposed = '''float[5](0, 1, 2, 3, 4)'''
587
+
588
+ assert glsl_code(mat) == gl
589
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
590
+
591
+ mat = Matrix([
592
+ [0],
593
+ [1]])
594
+
595
+ gl = '''vec2(0, 1)'''
596
+ glTransposed = '''vec2(0, 1)'''
597
+
598
+ assert glsl_code(mat) == gl
599
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
600
+
601
+ mat = Matrix([
602
+ [0, 1],
603
+ [2, 3]])
604
+
605
+ gl = '''mat2(0, 1, 2, 3)'''
606
+ glTransposed = '''mat2(0, 2, 1, 3)'''
607
+
608
+ assert glsl_code(mat) == gl
609
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
610
+
611
+ mat = Matrix([
612
+ [0, 1, 2],
613
+ [3, 4, 5]])
614
+
615
+ gl = '''mat3x2(0, 1, 2, 3, 4, 5)'''
616
+ glTransposed = '''mat2x3(0, 3, 1, 4, 2, 5)'''
617
+
618
+ assert glsl_code(mat) == gl
619
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
620
+
621
+ mat = Matrix([
622
+ [0, 1, 2, 3],
623
+ [4, 5, 6, 7]])
624
+
625
+ gl = '''mat4x2(0, 1, 2, 3, 4, 5, 6, 7)'''
626
+ glTransposed = '''mat2x4(0, 4, 1, 5, 2, 6, 3, 7)'''
627
+
628
+ assert glsl_code(mat) == gl
629
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
630
+
631
+ mat = Matrix([
632
+ [0, 1, 2, 3, 4],
633
+ [5, 6, 7, 8, 9]])
634
+
635
+ gl = '''float[10](
636
+ 0, 1, 2, 3, 4,
637
+ 5, 6, 7, 8, 9
638
+ ) /* a 2x5 matrix */'''
639
+ glTransposed = '''float[10](
640
+ 0, 5,
641
+ 1, 6,
642
+ 2, 7,
643
+ 3, 8,
644
+ 4, 9
645
+ ) /* a 5x2 matrix */'''
646
+
647
+ assert glsl_code(mat) == gl
648
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
649
+ glNested = '''float[2][5](
650
+ float[](0, 1, 2, 3, 4),
651
+ float[](5, 6, 7, 8, 9)
652
+ )'''
653
+ glNestedTransposed = '''float[5][2](
654
+ float[](0, 5),
655
+ float[](1, 6),
656
+ float[](2, 7),
657
+ float[](3, 8),
658
+ float[](4, 9)
659
+ )'''
660
+
661
+ assert glsl_code(mat,mat_nested=True) == glNested
662
+ assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
663
+
664
+ mat = Matrix([
665
+ [0],
666
+ [1],
667
+ [2]])
668
+
669
+ gl = '''vec3(0, 1, 2)'''
670
+ glTransposed = '''vec3(0, 1, 2)'''
671
+
672
+ assert glsl_code(mat) == gl
673
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
674
+
675
+ mat = Matrix([
676
+ [0, 1],
677
+ [2, 3],
678
+ [4, 5]])
679
+
680
+ gl = '''mat2x3(0, 1, 2, 3, 4, 5)'''
681
+ glTransposed = '''mat3x2(0, 2, 4, 1, 3, 5)'''
682
+
683
+ assert glsl_code(mat) == gl
684
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
685
+
686
+ mat = Matrix([
687
+ [0, 1, 2],
688
+ [3, 4, 5],
689
+ [6, 7, 8]])
690
+
691
+ gl = '''mat3(0, 1, 2, 3, 4, 5, 6, 7, 8)'''
692
+ glTransposed = '''mat3(0, 3, 6, 1, 4, 7, 2, 5, 8)'''
693
+
694
+ assert glsl_code(mat) == gl
695
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
696
+
697
+ mat = Matrix([
698
+ [0, 1, 2, 3],
699
+ [4, 5, 6, 7],
700
+ [8, 9, 10, 11]])
701
+
702
+ gl = '''mat4x3(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)'''
703
+ glTransposed = '''mat3x4(0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11)'''
704
+
705
+ assert glsl_code(mat) == gl
706
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
707
+
708
+ mat = Matrix([
709
+ [ 0, 1, 2, 3, 4],
710
+ [ 5, 6, 7, 8, 9],
711
+ [10, 11, 12, 13, 14]])
712
+
713
+ gl = '''float[15](
714
+ 0, 1, 2, 3, 4,
715
+ 5, 6, 7, 8, 9,
716
+ 10, 11, 12, 13, 14
717
+ ) /* a 3x5 matrix */'''
718
+ glTransposed = '''float[15](
719
+ 0, 5, 10,
720
+ 1, 6, 11,
721
+ 2, 7, 12,
722
+ 3, 8, 13,
723
+ 4, 9, 14
724
+ ) /* a 5x3 matrix */'''
725
+
726
+ assert glsl_code(mat) == gl
727
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
728
+ glNested = '''float[3][5](
729
+ float[]( 0, 1, 2, 3, 4),
730
+ float[]( 5, 6, 7, 8, 9),
731
+ float[](10, 11, 12, 13, 14)
732
+ )'''
733
+ glNestedTransposed = '''float[5][3](
734
+ float[](0, 5, 10),
735
+ float[](1, 6, 11),
736
+ float[](2, 7, 12),
737
+ float[](3, 8, 13),
738
+ float[](4, 9, 14)
739
+ )'''
740
+
741
+ assert glsl_code(mat,mat_nested=True) == glNested
742
+ assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
743
+
744
+ mat = Matrix([
745
+ [0],
746
+ [1],
747
+ [2],
748
+ [3]])
749
+
750
+ gl = '''vec4(0, 1, 2, 3)'''
751
+ glTransposed = '''vec4(0, 1, 2, 3)'''
752
+
753
+ assert glsl_code(mat) == gl
754
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
755
+
756
+ mat = Matrix([
757
+ [0, 1],
758
+ [2, 3],
759
+ [4, 5],
760
+ [6, 7]])
761
+
762
+ gl = '''mat2x4(0, 1, 2, 3, 4, 5, 6, 7)'''
763
+ glTransposed = '''mat4x2(0, 2, 4, 6, 1, 3, 5, 7)'''
764
+
765
+ assert glsl_code(mat) == gl
766
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
767
+
768
+ mat = Matrix([
769
+ [0, 1, 2],
770
+ [3, 4, 5],
771
+ [6, 7, 8],
772
+ [9, 10, 11]])
773
+
774
+ gl = '''mat3x4(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)'''
775
+ glTransposed = '''mat4x3(0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11)'''
776
+
777
+ assert glsl_code(mat) == gl
778
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
779
+
780
+ mat = Matrix([
781
+ [ 0, 1, 2, 3],
782
+ [ 4, 5, 6, 7],
783
+ [ 8, 9, 10, 11],
784
+ [12, 13, 14, 15]])
785
+
786
+ gl = '''mat4( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)'''
787
+ glTransposed = '''mat4(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)'''
788
+
789
+ assert glsl_code(mat) == gl
790
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
791
+
792
+ mat = Matrix([
793
+ [ 0, 1, 2, 3, 4],
794
+ [ 5, 6, 7, 8, 9],
795
+ [10, 11, 12, 13, 14],
796
+ [15, 16, 17, 18, 19]])
797
+
798
+ gl = '''float[20](
799
+ 0, 1, 2, 3, 4,
800
+ 5, 6, 7, 8, 9,
801
+ 10, 11, 12, 13, 14,
802
+ 15, 16, 17, 18, 19
803
+ ) /* a 4x5 matrix */'''
804
+ glTransposed = '''float[20](
805
+ 0, 5, 10, 15,
806
+ 1, 6, 11, 16,
807
+ 2, 7, 12, 17,
808
+ 3, 8, 13, 18,
809
+ 4, 9, 14, 19
810
+ ) /* a 5x4 matrix */'''
811
+
812
+ assert glsl_code(mat) == gl
813
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
814
+ glNested = '''float[4][5](
815
+ float[]( 0, 1, 2, 3, 4),
816
+ float[]( 5, 6, 7, 8, 9),
817
+ float[](10, 11, 12, 13, 14),
818
+ float[](15, 16, 17, 18, 19)
819
+ )'''
820
+ glNestedTransposed = '''float[5][4](
821
+ float[](0, 5, 10, 15),
822
+ float[](1, 6, 11, 16),
823
+ float[](2, 7, 12, 17),
824
+ float[](3, 8, 13, 18),
825
+ float[](4, 9, 14, 19)
826
+ )'''
827
+
828
+ assert glsl_code(mat,mat_nested=True) == glNested
829
+ assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
830
+
831
+ mat = Matrix([
832
+ [0],
833
+ [1],
834
+ [2],
835
+ [3],
836
+ [4]])
837
+
838
+ gl = '''float[5](0, 1, 2, 3, 4)'''
839
+ glTransposed = '''float[5](0, 1, 2, 3, 4)'''
840
+
841
+ assert glsl_code(mat) == gl
842
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
843
+
844
+ mat = Matrix([
845
+ [0, 1],
846
+ [2, 3],
847
+ [4, 5],
848
+ [6, 7],
849
+ [8, 9]])
850
+
851
+ gl = '''float[10](
852
+ 0, 1,
853
+ 2, 3,
854
+ 4, 5,
855
+ 6, 7,
856
+ 8, 9
857
+ ) /* a 5x2 matrix */'''
858
+ glTransposed = '''float[10](
859
+ 0, 2, 4, 6, 8,
860
+ 1, 3, 5, 7, 9
861
+ ) /* a 2x5 matrix */'''
862
+
863
+ assert glsl_code(mat) == gl
864
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
865
+ glNested = '''float[5][2](
866
+ float[](0, 1),
867
+ float[](2, 3),
868
+ float[](4, 5),
869
+ float[](6, 7),
870
+ float[](8, 9)
871
+ )'''
872
+ glNestedTransposed = '''float[2][5](
873
+ float[](0, 2, 4, 6, 8),
874
+ float[](1, 3, 5, 7, 9)
875
+ )'''
876
+
877
+ assert glsl_code(mat,mat_nested=True) == glNested
878
+ assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
879
+
880
+ mat = Matrix([
881
+ [ 0, 1, 2],
882
+ [ 3, 4, 5],
883
+ [ 6, 7, 8],
884
+ [ 9, 10, 11],
885
+ [12, 13, 14]])
886
+
887
+ gl = '''float[15](
888
+ 0, 1, 2,
889
+ 3, 4, 5,
890
+ 6, 7, 8,
891
+ 9, 10, 11,
892
+ 12, 13, 14
893
+ ) /* a 5x3 matrix */'''
894
+ glTransposed = '''float[15](
895
+ 0, 3, 6, 9, 12,
896
+ 1, 4, 7, 10, 13,
897
+ 2, 5, 8, 11, 14
898
+ ) /* a 3x5 matrix */'''
899
+
900
+ assert glsl_code(mat) == gl
901
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
902
+ glNested = '''float[5][3](
903
+ float[]( 0, 1, 2),
904
+ float[]( 3, 4, 5),
905
+ float[]( 6, 7, 8),
906
+ float[]( 9, 10, 11),
907
+ float[](12, 13, 14)
908
+ )'''
909
+ glNestedTransposed = '''float[3][5](
910
+ float[](0, 3, 6, 9, 12),
911
+ float[](1, 4, 7, 10, 13),
912
+ float[](2, 5, 8, 11, 14)
913
+ )'''
914
+
915
+ assert glsl_code(mat,mat_nested=True) == glNested
916
+ assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
917
+
918
+ mat = Matrix([
919
+ [ 0, 1, 2, 3],
920
+ [ 4, 5, 6, 7],
921
+ [ 8, 9, 10, 11],
922
+ [12, 13, 14, 15],
923
+ [16, 17, 18, 19]])
924
+
925
+ gl = '''float[20](
926
+ 0, 1, 2, 3,
927
+ 4, 5, 6, 7,
928
+ 8, 9, 10, 11,
929
+ 12, 13, 14, 15,
930
+ 16, 17, 18, 19
931
+ ) /* a 5x4 matrix */'''
932
+ glTransposed = '''float[20](
933
+ 0, 4, 8, 12, 16,
934
+ 1, 5, 9, 13, 17,
935
+ 2, 6, 10, 14, 18,
936
+ 3, 7, 11, 15, 19
937
+ ) /* a 4x5 matrix */'''
938
+
939
+ assert glsl_code(mat) == gl
940
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
941
+ glNested = '''float[5][4](
942
+ float[]( 0, 1, 2, 3),
943
+ float[]( 4, 5, 6, 7),
944
+ float[]( 8, 9, 10, 11),
945
+ float[](12, 13, 14, 15),
946
+ float[](16, 17, 18, 19)
947
+ )'''
948
+ glNestedTransposed = '''float[4][5](
949
+ float[](0, 4, 8, 12, 16),
950
+ float[](1, 5, 9, 13, 17),
951
+ float[](2, 6, 10, 14, 18),
952
+ float[](3, 7, 11, 15, 19)
953
+ )'''
954
+
955
+ assert glsl_code(mat,mat_nested=True) == glNested
956
+ assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
957
+
958
+ mat = Matrix([
959
+ [ 0, 1, 2, 3, 4],
960
+ [ 5, 6, 7, 8, 9],
961
+ [10, 11, 12, 13, 14],
962
+ [15, 16, 17, 18, 19],
963
+ [20, 21, 22, 23, 24]])
964
+
965
+ gl = '''float[25](
966
+ 0, 1, 2, 3, 4,
967
+ 5, 6, 7, 8, 9,
968
+ 10, 11, 12, 13, 14,
969
+ 15, 16, 17, 18, 19,
970
+ 20, 21, 22, 23, 24
971
+ ) /* a 5x5 matrix */'''
972
+ glTransposed = '''float[25](
973
+ 0, 5, 10, 15, 20,
974
+ 1, 6, 11, 16, 21,
975
+ 2, 7, 12, 17, 22,
976
+ 3, 8, 13, 18, 23,
977
+ 4, 9, 14, 19, 24
978
+ ) /* a 5x5 matrix */'''
979
+
980
+ assert glsl_code(mat) == gl
981
+ assert glsl_code(mat,mat_transpose=True) == glTransposed
982
+ glNested = '''float[5][5](
983
+ float[]( 0, 1, 2, 3, 4),
984
+ float[]( 5, 6, 7, 8, 9),
985
+ float[](10, 11, 12, 13, 14),
986
+ float[](15, 16, 17, 18, 19),
987
+ float[](20, 21, 22, 23, 24)
988
+ )'''
989
+ glNestedTransposed = '''float[5][5](
990
+ float[](0, 5, 10, 15, 20),
991
+ float[](1, 6, 11, 16, 21),
992
+ float[](2, 7, 12, 17, 22),
993
+ float[](3, 8, 13, 18, 23),
994
+ float[](4, 9, 14, 19, 24)
995
+ )'''
996
+
997
+ assert glsl_code(mat,mat_nested=True) == glNested
998
+ assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_gtk.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.functions.elementary.trigonometric import sin
2
+ from sympy.printing.gtk import print_gtk
3
+ from sympy.testing.pytest import XFAIL, raises
4
+
5
+ # this test fails if python-lxml isn't installed. We don't want to depend on
6
+ # anything with SymPy
7
+
8
+
9
+ @XFAIL
10
+ def test_1():
11
+ from sympy.abc import x
12
+ print_gtk(x**2, start_viewer=False)
13
+ print_gtk(x**2 + sin(x)/4, start_viewer=False)
14
+
15
+
16
+ def test_settings():
17
+ from sympy.abc import x
18
+ raises(TypeError, lambda: print_gtk(x, method="garbage"))
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_jax.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete.summations import Sum
2
+ from sympy.core.mod import Mod
3
+ from sympy.core.relational import (Equality, Unequality)
4
+ from sympy.functions.elementary.miscellaneous import sqrt
5
+ from sympy.functions.elementary.piecewise import Piecewise
6
+ from sympy.matrices.expressions.blockmatrix import BlockMatrix
7
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
8
+ from sympy.matrices.expressions.special import Identity
9
+ from sympy.utilities.lambdify import lambdify
10
+
11
+ from sympy.abc import x, i, j, a, b, c, d
12
+ from sympy.core import Function, Pow, Symbol
13
+ from sympy.codegen.matrix_nodes import MatrixSolve
14
+ from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
15
+ from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt
16
+ from sympy.tensor.array import Array
17
+ from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
18
+ PermuteDims, ArrayDiagonal
19
+ from sympy.printing.numpy import JaxPrinter, _jax_known_constants, _jax_known_functions
20
+ from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
21
+
22
+ from sympy.testing.pytest import skip, raises
23
+ from sympy.external import import_module
24
+
25
+ # Unlike NumPy which will aggressively promote operands to double precision,
26
+ # jax always uses single precision. Double precision in jax can be
27
+ # configured before the call to `import jax`, however this must be explicitly
28
+ # configured and is not fully supported. Thus, the tests here have been modified
29
+ # from the tests in test_numpy.py, only in the fact that they assert lambdify
30
+ # function accuracy to only single precision accuracy.
31
+ # https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
32
+
33
+ jax = import_module('jax')
34
+
35
+ if jax:
36
+ deafult_float_info = jax.numpy.finfo(jax.numpy.array([]).dtype)
37
+ JAX_DEFAULT_EPSILON = deafult_float_info.eps
38
+
39
+
40
+ def test_jax_piecewise_regression():
41
+ """
42
+ NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid
43
+ breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+.
44
+ See gh-9747 and gh-9749 for details.
45
+ """
46
+ printer = JaxPrinter()
47
+ p = Piecewise((1, x < 0), (0, True))
48
+ assert printer.doprint(p) == \
49
+ 'jax.numpy.select([jax.numpy.less(x, 0),True], [1,0], default=jax.numpy.nan)'
50
+ assert printer.module_imports == {'jax.numpy': {'select', 'less', 'nan'}}
51
+
52
+
53
+ def test_jax_logaddexp():
54
+ lae = logaddexp(a, b)
55
+ assert JaxPrinter().doprint(lae) == 'jax.numpy.logaddexp(a, b)'
56
+ lae2 = logaddexp2(a, b)
57
+ assert JaxPrinter().doprint(lae2) == 'jax.numpy.logaddexp2(a, b)'
58
+
59
+
60
+ def test_jax_sum():
61
+ if not jax:
62
+ skip("JAX not installed")
63
+
64
+ s = Sum(x ** i, (i, a, b))
65
+ f = lambdify((a, b, x), s, 'jax')
66
+
67
+ a_, b_ = 0, 10
68
+ x_ = jax.numpy.linspace(-1, +1, 10)
69
+ assert jax.numpy.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
70
+
71
+ s = Sum(i * x, (i, a, b))
72
+ f = lambdify((a, b, x), s, 'jax')
73
+
74
+ a_, b_ = 0, 10
75
+ x_ = jax.numpy.linspace(-1, +1, 10)
76
+ assert jax.numpy.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
77
+
78
+
79
+ def test_jax_multiple_sums():
80
+ if not jax:
81
+ skip("JAX not installed")
82
+
83
+ s = Sum((x + j) * i, (i, a, b), (j, c, d))
84
+ f = lambdify((a, b, c, d, x), s, 'jax')
85
+
86
+ a_, b_ = 0, 10
87
+ c_, d_ = 11, 21
88
+ x_ = jax.numpy.linspace(-1, +1, 10)
89
+ assert jax.numpy.allclose(f(a_, b_, c_, d_, x_),
90
+ sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1)))
91
+
92
+
93
+ def test_jax_codegen_einsum():
94
+ if not jax:
95
+ skip("JAX not installed")
96
+
97
+ M = MatrixSymbol("M", 2, 2)
98
+ N = MatrixSymbol("N", 2, 2)
99
+
100
+ cg = convert_matrix_to_array(M * N)
101
+ f = lambdify((M, N), cg, 'jax')
102
+
103
+ ma = jax.numpy.array([[1, 2], [3, 4]])
104
+ mb = jax.numpy.array([[1,-2], [-1, 3]])
105
+ assert (f(ma, mb) == jax.numpy.matmul(ma, mb)).all()
106
+
107
+
108
+ def test_jax_codegen_extra():
109
+ if not jax:
110
+ skip("JAX not installed")
111
+
112
+ M = MatrixSymbol("M", 2, 2)
113
+ N = MatrixSymbol("N", 2, 2)
114
+ P = MatrixSymbol("P", 2, 2)
115
+ Q = MatrixSymbol("Q", 2, 2)
116
+ ma = jax.numpy.array([[1, 2], [3, 4]])
117
+ mb = jax.numpy.array([[1,-2], [-1, 3]])
118
+ mc = jax.numpy.array([[2, 0], [1, 2]])
119
+ md = jax.numpy.array([[1,-1], [4, 7]])
120
+
121
+ cg = ArrayTensorProduct(M, N)
122
+ f = lambdify((M, N), cg, 'jax')
123
+ assert (f(ma, mb) == jax.numpy.einsum(ma, [0, 1], mb, [2, 3])).all()
124
+
125
+ cg = ArrayAdd(M, N)
126
+ f = lambdify((M, N), cg, 'jax')
127
+ assert (f(ma, mb) == ma+mb).all()
128
+
129
+ cg = ArrayAdd(M, N, P)
130
+ f = lambdify((M, N, P), cg, 'jax')
131
+ assert (f(ma, mb, mc) == ma+mb+mc).all()
132
+
133
+ cg = ArrayAdd(M, N, P, Q)
134
+ f = lambdify((M, N, P, Q), cg, 'jax')
135
+ assert (f(ma, mb, mc, md) == ma+mb+mc+md).all()
136
+
137
+ cg = PermuteDims(M, [1, 0])
138
+ f = lambdify((M,), cg, 'jax')
139
+ assert (f(ma) == ma.T).all()
140
+
141
+ cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
142
+ f = lambdify((M, N), cg, 'jax')
143
+ assert (f(ma, mb) == jax.numpy.transpose(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all()
144
+
145
+ cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
146
+ f = lambdify((M, N), cg, 'jax')
147
+ assert (f(ma, mb) == jax.numpy.diagonal(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all()
148
+
149
+
150
+ def test_jax_relational():
151
+ if not jax:
152
+ skip("JAX not installed")
153
+
154
+ e = Equality(x, 1)
155
+
156
+ f = lambdify((x,), e, 'jax')
157
+ x_ = jax.numpy.array([0, 1, 2])
158
+ assert jax.numpy.array_equal(f(x_), [False, True, False])
159
+
160
+ e = Unequality(x, 1)
161
+
162
+ f = lambdify((x,), e, 'jax')
163
+ x_ = jax.numpy.array([0, 1, 2])
164
+ assert jax.numpy.array_equal(f(x_), [True, False, True])
165
+
166
+ e = (x < 1)
167
+
168
+ f = lambdify((x,), e, 'jax')
169
+ x_ = jax.numpy.array([0, 1, 2])
170
+ assert jax.numpy.array_equal(f(x_), [True, False, False])
171
+
172
+ e = (x <= 1)
173
+
174
+ f = lambdify((x,), e, 'jax')
175
+ x_ = jax.numpy.array([0, 1, 2])
176
+ assert jax.numpy.array_equal(f(x_), [True, True, False])
177
+
178
+ e = (x > 1)
179
+
180
+ f = lambdify((x,), e, 'jax')
181
+ x_ = jax.numpy.array([0, 1, 2])
182
+ assert jax.numpy.array_equal(f(x_), [False, False, True])
183
+
184
+ e = (x >= 1)
185
+
186
+ f = lambdify((x,), e, 'jax')
187
+ x_ = jax.numpy.array([0, 1, 2])
188
+ assert jax.numpy.array_equal(f(x_), [False, True, True])
189
+
190
+ # Multi-condition expressions
191
+ e = (x >= 1) & (x < 2)
192
+ f = lambdify((x,), e, 'jax')
193
+ x_ = jax.numpy.array([0, 1, 2])
194
+ assert jax.numpy.array_equal(f(x_), [False, True, False])
195
+
196
+ e = (x >= 1) | (x < 2)
197
+ f = lambdify((x,), e, 'jax')
198
+ x_ = jax.numpy.array([0, 1, 2])
199
+ assert jax.numpy.array_equal(f(x_), [True, True, True])
200
+
201
+ def test_jax_mod():
202
+ if not jax:
203
+ skip("JAX not installed")
204
+
205
+ e = Mod(a, b)
206
+ f = lambdify((a, b), e, 'jax')
207
+
208
+ a_ = jax.numpy.array([0, 1, 2, 3])
209
+ b_ = 2
210
+ assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1])
211
+
212
+ a_ = jax.numpy.array([0, 1, 2, 3])
213
+ b_ = jax.numpy.array([2, 2, 2, 2])
214
+ assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1])
215
+
216
+ a_ = jax.numpy.array([2, 3, 4, 5])
217
+ b_ = jax.numpy.array([2, 3, 4, 5])
218
+ assert jax.numpy.array_equal(f(a_, b_), [0, 0, 0, 0])
219
+
220
+
221
+ def test_jax_pow():
222
+ if not jax:
223
+ skip('JAX not installed')
224
+
225
+ expr = Pow(2, -1, evaluate=False)
226
+ f = lambdify([], expr, 'jax')
227
+ assert f() == 0.5
228
+
229
+
230
+ def test_jax_expm1():
231
+ if not jax:
232
+ skip("JAX not installed")
233
+
234
+ f = lambdify((a,), expm1(a), 'jax')
235
+ assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * JAX_DEFAULT_EPSILON
236
+
237
+
238
+ def test_jax_log1p():
239
+ if not jax:
240
+ skip("JAX not installed")
241
+
242
+ f = lambdify((a,), log1p(a), 'jax')
243
+ assert abs(f(1e-99) - 1e-99) <= 1e-99 * JAX_DEFAULT_EPSILON
244
+
245
+ def test_jax_hypot():
246
+ if not jax:
247
+ skip("JAX not installed")
248
+ assert abs(lambdify((a, b), hypot(a, b), 'jax')(3, 4) - 5) <= JAX_DEFAULT_EPSILON
249
+
250
+ def test_jax_log10():
251
+ if not jax:
252
+ skip("JAX not installed")
253
+
254
+ assert abs(lambdify((a,), log10(a), 'jax')(100) - 2) <= JAX_DEFAULT_EPSILON
255
+
256
+
257
+ def test_jax_exp2():
258
+ if not jax:
259
+ skip("JAX not installed")
260
+ assert abs(lambdify((a,), exp2(a), 'jax')(5) - 32) <= JAX_DEFAULT_EPSILON
261
+
262
+
263
+ def test_jax_log2():
264
+ if not jax:
265
+ skip("JAX not installed")
266
+ assert abs(lambdify((a,), log2(a), 'jax')(256) - 8) <= JAX_DEFAULT_EPSILON
267
+
268
+
269
+ def test_jax_Sqrt():
270
+ if not jax:
271
+ skip("JAX not installed")
272
+ assert abs(lambdify((a,), Sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON
273
+
274
+
275
+ def test_jax_sqrt():
276
+ if not jax:
277
+ skip("JAX not installed")
278
+ assert abs(lambdify((a,), sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON
279
+
280
+
281
+ def test_jax_matsolve():
282
+ if not jax:
283
+ skip("JAX not installed")
284
+
285
+ M = MatrixSymbol("M", 3, 3)
286
+ x = MatrixSymbol("x", 3, 1)
287
+
288
+ expr = M**(-1) * x + x
289
+ matsolve_expr = MatrixSolve(M, x) + x
290
+
291
+ f = lambdify((M, x), expr, 'jax')
292
+ f_matsolve = lambdify((M, x), matsolve_expr, 'jax')
293
+
294
+ m0 = jax.numpy.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]])
295
+ assert jax.numpy.linalg.matrix_rank(m0) == 3
296
+
297
+ x0 = jax.numpy.array([3, 4, 5])
298
+
299
+ assert jax.numpy.allclose(f_matsolve(m0, x0), f(m0, x0))
300
+
301
+
302
+ def test_16857():
303
+ if not jax:
304
+ skip("JAX not installed")
305
+
306
+ a_1 = MatrixSymbol('a_1', 10, 3)
307
+ a_2 = MatrixSymbol('a_2', 10, 3)
308
+ a_3 = MatrixSymbol('a_3', 10, 3)
309
+ a_4 = MatrixSymbol('a_4', 10, 3)
310
+ A = BlockMatrix([[a_1, a_2], [a_3, a_4]])
311
+ assert A.shape == (20, 6)
312
+
313
+ printer = JaxPrinter()
314
+ assert printer.doprint(A) == 'jax.numpy.block([[a_1, a_2], [a_3, a_4]])'
315
+
316
+
317
+ def test_issue_17006():
318
+ if not jax:
319
+ skip("JAX not installed")
320
+
321
+ M = MatrixSymbol("M", 2, 2)
322
+
323
+ f = lambdify(M, M + Identity(2), 'jax')
324
+ ma = jax.numpy.array([[1, 2], [3, 4]])
325
+ mr = jax.numpy.array([[2, 2], [3, 5]])
326
+
327
+ assert (f(ma) == mr).all()
328
+
329
+ from sympy.core.symbol import symbols
330
+ n = symbols('n', integer=True)
331
+ N = MatrixSymbol("M", n, n)
332
+ raises(NotImplementedError, lambda: lambdify(N, N + Identity(n), 'jax'))
333
+
334
+
335
+ def test_jax_array():
336
+ assert JaxPrinter().doprint(Array(((1, 2), (3, 5)))) == 'jax.numpy.array([[1, 2], [3, 5]])'
337
+ assert JaxPrinter().doprint(Array((1, 2))) == 'jax.numpy.array([1, 2])'
338
+
339
+
340
+ def test_jax_known_funcs_consts():
341
+ assert _jax_known_constants['NaN'] == 'jax.numpy.nan'
342
+ assert _jax_known_constants['EulerGamma'] == 'jax.numpy.euler_gamma'
343
+
344
+ assert _jax_known_functions['acos'] == 'jax.numpy.arccos'
345
+ assert _jax_known_functions['log'] == 'jax.numpy.log'
346
+
347
+
348
+ def test_jax_print_methods():
349
+ prntr = JaxPrinter()
350
+ assert hasattr(prntr, '_print_acos')
351
+ assert hasattr(prntr, '_print_log')
352
+
353
+
354
+ def test_jax_printmethod():
355
+ printer = JaxPrinter()
356
+ assert hasattr(printer, 'printmethod')
357
+ assert printer.printmethod == '_jaxcode'
358
+
359
+
360
+ def test_jax_custom_print_method():
361
+
362
+ class expm1(Function):
363
+
364
+ def _jaxcode(self, printer):
365
+ x, = self.args
366
+ function = f'expm1({printer._print(x)})'
367
+ return printer._module_format(printer._module + '.' + function)
368
+
369
+ printer = JaxPrinter()
370
+ assert printer.doprint(expm1(Symbol('x'))) == 'jax.numpy.expm1(x)'
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_jscode.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio,
2
+ EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le,
3
+ Lt, Gt, Ge, Mod)
4
+ from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
5
+ sinh, cosh, tanh, asin, acos, acosh, Max, Min)
6
+ from sympy.testing.pytest import raises
7
+ from sympy.printing.jscode import JavascriptCodePrinter
8
+ from sympy.utilities.lambdify import implemented_function
9
+ from sympy.tensor import IndexedBase, Idx
10
+ from sympy.matrices import Matrix, MatrixSymbol
11
+
12
+ from sympy.printing.jscode import jscode
13
+
14
+ x, y, z = symbols('x,y,z')
15
+
16
+
17
+ def test_printmethod():
18
+ assert jscode(Abs(x)) == "Math.abs(x)"
19
+
20
+
21
+ def test_jscode_sqrt():
22
+ assert jscode(sqrt(x)) == "Math.sqrt(x)"
23
+ assert jscode(x**0.5) == "Math.sqrt(x)"
24
+ assert jscode(x**(S.One/3)) == "Math.cbrt(x)"
25
+
26
+
27
+ def test_jscode_Pow():
28
+ g = implemented_function('g', Lambda(x, 2*x))
29
+ assert jscode(x**3) == "Math.pow(x, 3)"
30
+ assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))"
31
+ assert jscode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
32
+ "Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)"
33
+ assert jscode(x**-1.0) == '1/x'
34
+
35
+
36
+ def test_jscode_constants_mathh():
37
+ assert jscode(exp(1)) == "Math.E"
38
+ assert jscode(pi) == "Math.PI"
39
+ assert jscode(oo) == "Number.POSITIVE_INFINITY"
40
+ assert jscode(-oo) == "Number.NEGATIVE_INFINITY"
41
+
42
+
43
+ def test_jscode_constants_other():
44
+ assert jscode(
45
+ 2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
46
+ assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
47
+ assert jscode(
48
+ 2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
49
+
50
+
51
+ def test_jscode_Rational():
52
+ assert jscode(Rational(3, 7)) == "3/7"
53
+ assert jscode(Rational(18, 9)) == "2"
54
+ assert jscode(Rational(3, -7)) == "-3/7"
55
+ assert jscode(Rational(-3, -7)) == "3/7"
56
+
57
+
58
+ def test_Relational():
59
+ assert jscode(Eq(x, y)) == "x == y"
60
+ assert jscode(Ne(x, y)) == "x != y"
61
+ assert jscode(Le(x, y)) == "x <= y"
62
+ assert jscode(Lt(x, y)) == "x < y"
63
+ assert jscode(Gt(x, y)) == "x > y"
64
+ assert jscode(Ge(x, y)) == "x >= y"
65
+
66
+
67
+ def test_Mod():
68
+ assert jscode(Mod(x, y)) == '((x % y) + y) % y'
69
+ assert jscode(Mod(x, x + y)) == '((x % (x + y)) + (x + y)) % (x + y)'
70
+ p1, p2 = symbols('p1 p2', positive=True)
71
+ assert jscode(Mod(p1, p2)) == 'p1 % p2'
72
+ assert jscode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)'
73
+ assert jscode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)'
74
+ assert jscode(-Mod(p1, p2)) == '-(p1 % p2)'
75
+ assert jscode(x*Mod(p1, p2)) == 'x*(p1 % p2)'
76
+
77
+
78
+ def test_jscode_Integer():
79
+ assert jscode(Integer(67)) == "67"
80
+ assert jscode(Integer(-1)) == "-1"
81
+
82
+
83
+ def test_jscode_functions():
84
+ assert jscode(sin(x) ** cos(x)) == "Math.pow(Math.sin(x), Math.cos(x))"
85
+ assert jscode(sinh(x) * cosh(x)) == "Math.sinh(x)*Math.cosh(x)"
86
+ assert jscode(Max(x, y) + Min(x, y)) == "Math.max(x, y) + Math.min(x, y)"
87
+ assert jscode(tanh(x)*acosh(y)) == "Math.tanh(x)*Math.acosh(y)"
88
+ assert jscode(asin(x)-acos(y)) == "-Math.acos(y) + Math.asin(x)"
89
+
90
+
91
+ def test_jscode_inline_function():
92
+ x = symbols('x')
93
+ g = implemented_function('g', Lambda(x, 2*x))
94
+ assert jscode(g(x)) == "2*x"
95
+ g = implemented_function('g', Lambda(x, 2*x/Catalan))
96
+ assert jscode(g(x)) == "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
97
+ A = IndexedBase('A')
98
+ i = Idx('i', symbols('n', integer=True))
99
+ g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
100
+ assert jscode(g(A[i]), assign_to=A[i]) == (
101
+ "for (var i=0; i<n; i++){\n"
102
+ " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
103
+ "}"
104
+ )
105
+
106
+
107
+ def test_jscode_exceptions():
108
+ assert jscode(ceiling(x)) == "Math.ceil(x)"
109
+ assert jscode(Abs(x)) == "Math.abs(x)"
110
+
111
+
112
+ def test_jscode_boolean():
113
+ assert jscode(x & y) == "x && y"
114
+ assert jscode(x | y) == "x || y"
115
+ assert jscode(~x) == "!x"
116
+ assert jscode(x & y & z) == "x && y && z"
117
+ assert jscode(x | y | z) == "x || y || z"
118
+ assert jscode((x & y) | z) == "z || x && y"
119
+ assert jscode((x | y) & z) == "z && (x || y)"
120
+
121
+
122
+ def test_jscode_Piecewise():
123
+ expr = Piecewise((x, x < 1), (x**2, True))
124
+ p = jscode(expr)
125
+ s = \
126
+ """\
127
+ ((x < 1) ? (
128
+ x
129
+ )
130
+ : (
131
+ Math.pow(x, 2)
132
+ ))\
133
+ """
134
+ assert p == s
135
+ assert jscode(expr, assign_to="c") == (
136
+ "if (x < 1) {\n"
137
+ " c = x;\n"
138
+ "}\n"
139
+ "else {\n"
140
+ " c = Math.pow(x, 2);\n"
141
+ "}")
142
+ # Check that Piecewise without a True (default) condition error
143
+ expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
144
+ raises(ValueError, lambda: jscode(expr))
145
+
146
+
147
+ def test_jscode_Piecewise_deep():
148
+ p = jscode(2*Piecewise((x, x < 1), (x**2, True)))
149
+ s = \
150
+ """\
151
+ 2*((x < 1) ? (
152
+ x
153
+ )
154
+ : (
155
+ Math.pow(x, 2)
156
+ ))\
157
+ """
158
+ assert p == s
159
+
160
+
161
+ def test_jscode_settings():
162
+ raises(TypeError, lambda: jscode(sin(x), method="garbage"))
163
+
164
+
165
+ def test_jscode_Indexed():
166
+ n, m, o = symbols('n m o', integer=True)
167
+ i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
168
+ p = JavascriptCodePrinter()
169
+ p._not_c = set()
170
+
171
+ x = IndexedBase('x')[j]
172
+ assert p._print_Indexed(x) == 'x[j]'
173
+ A = IndexedBase('A')[i, j]
174
+ assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
175
+ B = IndexedBase('B')[i, j, k]
176
+ assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
177
+
178
+ assert p._not_c == set()
179
+
180
+
181
+ def test_jscode_loops_matrix_vector():
182
+ n, m = symbols('n m', integer=True)
183
+ A = IndexedBase('A')
184
+ x = IndexedBase('x')
185
+ y = IndexedBase('y')
186
+ i = Idx('i', m)
187
+ j = Idx('j', n)
188
+
189
+ s = (
190
+ 'for (var i=0; i<m; i++){\n'
191
+ ' y[i] = 0;\n'
192
+ '}\n'
193
+ 'for (var i=0; i<m; i++){\n'
194
+ ' for (var j=0; j<n; j++){\n'
195
+ ' y[i] = A[n*i + j]*x[j] + y[i];\n'
196
+ ' }\n'
197
+ '}'
198
+ )
199
+ c = jscode(A[i, j]*x[j], assign_to=y[i])
200
+ assert c == s
201
+
202
+
203
+ def test_dummy_loops():
204
+ i, m = symbols('i m', integer=True, cls=Dummy)
205
+ x = IndexedBase('x')
206
+ y = IndexedBase('y')
207
+ i = Idx(i, m)
208
+
209
+ expected = (
210
+ 'for (var i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
211
+ ' y[i_%(icount)i] = x[i_%(icount)i];\n'
212
+ '}'
213
+ ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
214
+ code = jscode(x[i], assign_to=y[i])
215
+ assert code == expected
216
+
217
+
218
+ def test_jscode_loops_add():
219
+ n, m = symbols('n m', integer=True)
220
+ A = IndexedBase('A')
221
+ x = IndexedBase('x')
222
+ y = IndexedBase('y')
223
+ z = IndexedBase('z')
224
+ i = Idx('i', m)
225
+ j = Idx('j', n)
226
+
227
+ s = (
228
+ 'for (var i=0; i<m; i++){\n'
229
+ ' y[i] = x[i] + z[i];\n'
230
+ '}\n'
231
+ 'for (var i=0; i<m; i++){\n'
232
+ ' for (var j=0; j<n; j++){\n'
233
+ ' y[i] = A[n*i + j]*x[j] + y[i];\n'
234
+ ' }\n'
235
+ '}'
236
+ )
237
+ c = jscode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
238
+ assert c == s
239
+
240
+
241
+ def test_jscode_loops_multiple_contractions():
242
+ n, m, o, p = symbols('n m o p', integer=True)
243
+ a = IndexedBase('a')
244
+ b = IndexedBase('b')
245
+ y = IndexedBase('y')
246
+ i = Idx('i', m)
247
+ j = Idx('j', n)
248
+ k = Idx('k', o)
249
+ l = Idx('l', p)
250
+
251
+ s = (
252
+ 'for (var i=0; i<m; i++){\n'
253
+ ' y[i] = 0;\n'
254
+ '}\n'
255
+ 'for (var i=0; i<m; i++){\n'
256
+ ' for (var j=0; j<n; j++){\n'
257
+ ' for (var k=0; k<o; k++){\n'
258
+ ' for (var l=0; l<p; l++){\n'
259
+ ' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
260
+ ' }\n'
261
+ ' }\n'
262
+ ' }\n'
263
+ '}'
264
+ )
265
+ c = jscode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
266
+ assert c == s
267
+
268
+
269
+ def test_jscode_loops_addfactor():
270
+ n, m, o, p = symbols('n m o p', integer=True)
271
+ a = IndexedBase('a')
272
+ b = IndexedBase('b')
273
+ c = IndexedBase('c')
274
+ y = IndexedBase('y')
275
+ i = Idx('i', m)
276
+ j = Idx('j', n)
277
+ k = Idx('k', o)
278
+ l = Idx('l', p)
279
+
280
+ s = (
281
+ 'for (var i=0; i<m; i++){\n'
282
+ ' y[i] = 0;\n'
283
+ '}\n'
284
+ 'for (var i=0; i<m; i++){\n'
285
+ ' for (var j=0; j<n; j++){\n'
286
+ ' for (var k=0; k<o; k++){\n'
287
+ ' for (var l=0; l<p; l++){\n'
288
+ ' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
289
+ ' }\n'
290
+ ' }\n'
291
+ ' }\n'
292
+ '}'
293
+ )
294
+ c = jscode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
295
+ assert c == s
296
+
297
+
298
+ def test_jscode_loops_multiple_terms():
299
+ n, m, o, p = symbols('n m o p', integer=True)
300
+ a = IndexedBase('a')
301
+ b = IndexedBase('b')
302
+ c = IndexedBase('c')
303
+ y = IndexedBase('y')
304
+ i = Idx('i', m)
305
+ j = Idx('j', n)
306
+ k = Idx('k', o)
307
+
308
+ s0 = (
309
+ 'for (var i=0; i<m; i++){\n'
310
+ ' y[i] = 0;\n'
311
+ '}\n'
312
+ )
313
+ s1 = (
314
+ 'for (var i=0; i<m; i++){\n'
315
+ ' for (var j=0; j<n; j++){\n'
316
+ ' for (var k=0; k<o; k++){\n'
317
+ ' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
318
+ ' }\n'
319
+ ' }\n'
320
+ '}\n'
321
+ )
322
+ s2 = (
323
+ 'for (var i=0; i<m; i++){\n'
324
+ ' for (var k=0; k<o; k++){\n'
325
+ ' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
326
+ ' }\n'
327
+ '}\n'
328
+ )
329
+ s3 = (
330
+ 'for (var i=0; i<m; i++){\n'
331
+ ' for (var j=0; j<n; j++){\n'
332
+ ' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
333
+ ' }\n'
334
+ '}\n'
335
+ )
336
+ c = jscode(
337
+ b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
338
+ assert (c == s0 + s1 + s2 + s3[:-1] or
339
+ c == s0 + s1 + s3 + s2[:-1] or
340
+ c == s0 + s2 + s1 + s3[:-1] or
341
+ c == s0 + s2 + s3 + s1[:-1] or
342
+ c == s0 + s3 + s1 + s2[:-1] or
343
+ c == s0 + s3 + s2 + s1[:-1])
344
+
345
+
346
+ def test_Matrix_printing():
347
+ # Test returning a Matrix
348
+ mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
349
+ A = MatrixSymbol('A', 3, 1)
350
+ assert jscode(mat, A) == (
351
+ "A[0] = x*y;\n"
352
+ "if (y > 0) {\n"
353
+ " A[1] = x + 2;\n"
354
+ "}\n"
355
+ "else {\n"
356
+ " A[1] = y;\n"
357
+ "}\n"
358
+ "A[2] = Math.sin(z);")
359
+ # Test using MatrixElements in expressions
360
+ expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
361
+ assert jscode(expr) == (
362
+ "((x > 0) ? (\n"
363
+ " 2*A[2]\n"
364
+ ")\n"
365
+ ": (\n"
366
+ " A[2]\n"
367
+ ")) + Math.sin(A[1]) + A[0]")
368
+ # Test using MatrixElements in a Matrix
369
+ q = MatrixSymbol('q', 5, 1)
370
+ M = MatrixSymbol('M', 3, 3)
371
+ m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
372
+ [q[1,0] + q[2,0], q[3, 0], 5],
373
+ [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
374
+ assert jscode(m, M) == (
375
+ "M[0] = Math.sin(q[1]);\n"
376
+ "M[1] = 0;\n"
377
+ "M[2] = Math.cos(q[2]);\n"
378
+ "M[3] = q[1] + q[2];\n"
379
+ "M[4] = q[3];\n"
380
+ "M[5] = 5;\n"
381
+ "M[6] = 2*q[4]/q[1];\n"
382
+ "M[7] = Math.sqrt(q[0]) + 4;\n"
383
+ "M[8] = 0;")
384
+
385
+
386
+ def test_MatrixElement_printing():
387
+ # test cases for issue #11821
388
+ A = MatrixSymbol("A", 1, 3)
389
+ B = MatrixSymbol("B", 1, 3)
390
+ C = MatrixSymbol("C", 1, 3)
391
+
392
+ assert(jscode(A[0, 0]) == "A[0]")
393
+ assert(jscode(3 * A[0, 0]) == "3*A[0]")
394
+
395
+ F = C[0, 0].subs(C, A - B)
396
+ assert(jscode(F) == "(A - B)[0]")
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_julia.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
2
+ Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge)
3
+ from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow
4
+ from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc
5
+ from sympy.testing.pytest import raises
6
+ from sympy.utilities.lambdify import implemented_function
7
+ from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
8
+ HadamardProduct, SparseMatrix)
9
+ from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli,
10
+ besselk, hankel1, hankel2, airyai,
11
+ airybi, airyaiprime, airybiprime)
12
+ from sympy.testing.pytest import XFAIL
13
+
14
+ from sympy.printing.julia import julia_code
15
+
16
+ x, y, z = symbols('x,y,z')
17
+
18
+
19
+ def test_Integer():
20
+ assert julia_code(Integer(67)) == "67"
21
+ assert julia_code(Integer(-1)) == "-1"
22
+
23
+
24
+ def test_Rational():
25
+ assert julia_code(Rational(3, 7)) == "3 // 7"
26
+ assert julia_code(Rational(18, 9)) == "2"
27
+ assert julia_code(Rational(3, -7)) == "-3 // 7"
28
+ assert julia_code(Rational(-3, -7)) == "3 // 7"
29
+ assert julia_code(x + Rational(3, 7)) == "x + 3 // 7"
30
+ assert julia_code(Rational(3, 7)*x) == "(3 // 7) * x"
31
+
32
+
33
+ def test_Relational():
34
+ assert julia_code(Eq(x, y)) == "x == y"
35
+ assert julia_code(Ne(x, y)) == "x != y"
36
+ assert julia_code(Le(x, y)) == "x <= y"
37
+ assert julia_code(Lt(x, y)) == "x < y"
38
+ assert julia_code(Gt(x, y)) == "x > y"
39
+ assert julia_code(Ge(x, y)) == "x >= y"
40
+
41
+
42
+ def test_Function():
43
+ assert julia_code(sin(x) ** cos(x)) == "sin(x) .^ cos(x)"
44
+ assert julia_code(abs(x)) == "abs(x)"
45
+ assert julia_code(ceiling(x)) == "ceil(x)"
46
+
47
+
48
+ def test_Pow():
49
+ assert julia_code(x**3) == "x .^ 3"
50
+ assert julia_code(x**(y**3)) == "x .^ (y .^ 3)"
51
+ assert julia_code(x**Rational(2, 3)) == 'x .^ (2 // 3)'
52
+ g = implemented_function('g', Lambda(x, 2*x))
53
+ assert julia_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
54
+ "(3.5 * 2 * x) .^ (-x + y .^ x) ./ (x .^ 2 + y)"
55
+ # For issue 14160
56
+ assert julia_code(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
57
+ evaluate=False)) == '-2 * x ./ (y .* y)'
58
+
59
+
60
+ def test_basic_ops():
61
+ assert julia_code(x*y) == "x .* y"
62
+ assert julia_code(x + y) == "x + y"
63
+ assert julia_code(x - y) == "x - y"
64
+ assert julia_code(-x) == "-x"
65
+
66
+
67
+ def test_1_over_x_and_sqrt():
68
+ # 1.0 and 0.5 would do something different in regular StrPrinter,
69
+ # but these are exact in IEEE floating point so no different here.
70
+ assert julia_code(1/x) == '1 ./ x'
71
+ assert julia_code(x**-1) == julia_code(x**-1.0) == '1 ./ x'
72
+ assert julia_code(1/sqrt(x)) == '1 ./ sqrt(x)'
73
+ assert julia_code(x**-S.Half) == julia_code(x**-0.5) == '1 ./ sqrt(x)'
74
+ assert julia_code(sqrt(x)) == 'sqrt(x)'
75
+ assert julia_code(x**S.Half) == julia_code(x**0.5) == 'sqrt(x)'
76
+ assert julia_code(1/pi) == '1 / pi'
77
+ assert julia_code(pi**-1) == julia_code(pi**-1.0) == '1 / pi'
78
+ assert julia_code(pi**-0.5) == '1 / sqrt(pi)'
79
+
80
+
81
+ def test_mix_number_mult_symbols():
82
+ assert julia_code(3*x) == "3 * x"
83
+ assert julia_code(pi*x) == "pi * x"
84
+ assert julia_code(3/x) == "3 ./ x"
85
+ assert julia_code(pi/x) == "pi ./ x"
86
+ assert julia_code(x/3) == "x / 3"
87
+ assert julia_code(x/pi) == "x / pi"
88
+ assert julia_code(x*y) == "x .* y"
89
+ assert julia_code(3*x*y) == "3 * x .* y"
90
+ assert julia_code(3*pi*x*y) == "3 * pi * x .* y"
91
+ assert julia_code(x/y) == "x ./ y"
92
+ assert julia_code(3*x/y) == "3 * x ./ y"
93
+ assert julia_code(x*y/z) == "x .* y ./ z"
94
+ assert julia_code(x/y*z) == "x .* z ./ y"
95
+ assert julia_code(1/x/y) == "1 ./ (x .* y)"
96
+ assert julia_code(2*pi*x/y/z) == "2 * pi * x ./ (y .* z)"
97
+ assert julia_code(3*pi/x) == "3 * pi ./ x"
98
+ assert julia_code(S(3)/5) == "3 // 5"
99
+ assert julia_code(S(3)/5*x) == "(3 // 5) * x"
100
+ assert julia_code(x/y/z) == "x ./ (y .* z)"
101
+ assert julia_code((x+y)/z) == "(x + y) ./ z"
102
+ assert julia_code((x+y)/(z+x)) == "(x + y) ./ (x + z)"
103
+ assert julia_code((x+y)/EulerGamma) == "(x + y) / eulergamma"
104
+ assert julia_code(x/3/pi) == "x / (3 * pi)"
105
+ assert julia_code(S(3)/5*x*y/pi) == "(3 // 5) * x .* y / pi"
106
+
107
+
108
+ def test_mix_number_pow_symbols():
109
+ assert julia_code(pi**3) == 'pi ^ 3'
110
+ assert julia_code(x**2) == 'x .^ 2'
111
+ assert julia_code(x**(pi**3)) == 'x .^ (pi ^ 3)'
112
+ assert julia_code(x**y) == 'x .^ y'
113
+ assert julia_code(x**(y**z)) == 'x .^ (y .^ z)'
114
+ assert julia_code((x**y)**z) == '(x .^ y) .^ z'
115
+
116
+
117
+ def test_imag():
118
+ I = S('I')
119
+ assert julia_code(I) == "im"
120
+ assert julia_code(5*I) == "5im"
121
+ assert julia_code((S(3)/2)*I) == "(3 // 2) * im"
122
+ assert julia_code(3+4*I) == "3 + 4im"
123
+
124
+
125
+ def test_constants():
126
+ assert julia_code(pi) == "pi"
127
+ assert julia_code(oo) == "Inf"
128
+ assert julia_code(-oo) == "-Inf"
129
+ assert julia_code(S.NegativeInfinity) == "-Inf"
130
+ assert julia_code(S.NaN) == "NaN"
131
+ assert julia_code(S.Exp1) == "e"
132
+ assert julia_code(exp(1)) == "e"
133
+
134
+
135
+ def test_constants_other():
136
+ assert julia_code(2*GoldenRatio) == "2 * golden"
137
+ assert julia_code(2*Catalan) == "2 * catalan"
138
+ assert julia_code(2*EulerGamma) == "2 * eulergamma"
139
+
140
+
141
+ def test_boolean():
142
+ assert julia_code(x & y) == "x && y"
143
+ assert julia_code(x | y) == "x || y"
144
+ assert julia_code(~x) == "!x"
145
+ assert julia_code(x & y & z) == "x && y && z"
146
+ assert julia_code(x | y | z) == "x || y || z"
147
+ assert julia_code((x & y) | z) == "z || x && y"
148
+ assert julia_code((x | y) & z) == "z && (x || y)"
149
+
150
+ def test_sinc():
151
+ assert julia_code(sinc(x)) == 'sinc(x / pi)'
152
+ assert julia_code(sinc(x + 3)) == 'sinc((x + 3) / pi)'
153
+ assert julia_code(sinc(pi * (x + 3))) == 'sinc(x + 3)'
154
+
155
+ def test_Matrices():
156
+ assert julia_code(Matrix(1, 1, [10])) == "[10]"
157
+ A = Matrix([[1, sin(x/2), abs(x)],
158
+ [0, 1, pi],
159
+ [0, exp(1), ceiling(x)]])
160
+ expected = ("[1 sin(x / 2) abs(x);\n"
161
+ "0 1 pi;\n"
162
+ "0 e ceil(x)]")
163
+ assert julia_code(A) == expected
164
+ # row and columns
165
+ assert julia_code(A[:,0]) == "[1, 0, 0]"
166
+ assert julia_code(A[0,:]) == "[1 sin(x / 2) abs(x)]"
167
+ # empty matrices
168
+ assert julia_code(Matrix(0, 0, [])) == 'zeros(0, 0)'
169
+ assert julia_code(Matrix(0, 3, [])) == 'zeros(0, 3)'
170
+ # annoying to read but correct
171
+ assert julia_code(Matrix([[x, x - y, -y]])) == "[x x - y -y]"
172
+
173
+
174
+ def test_vector_entries_hadamard():
175
+ # For a row or column, user might to use the other dimension
176
+ A = Matrix([[1, sin(2/x), 3*pi/x/5]])
177
+ assert julia_code(A) == "[1 sin(2 ./ x) (3 // 5) * pi ./ x]"
178
+ assert julia_code(A.T) == "[1, sin(2 ./ x), (3 // 5) * pi ./ x]"
179
+
180
+
181
+ @XFAIL
182
+ def test_Matrices_entries_not_hadamard():
183
+ # For Matrix with col >= 2, row >= 2, they need to be scalars
184
+ # FIXME: is it worth worrying about this? Its not wrong, just
185
+ # leave it user's responsibility to put scalar data for x.
186
+ A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]])
187
+ expected = ("[1 sin(2/x) 3*pi/(5*x);\n"
188
+ "1 2 x*y]") # <- we give x.*y
189
+ assert julia_code(A) == expected
190
+
191
+
192
+ def test_MatrixSymbol():
193
+ n = Symbol('n', integer=True)
194
+ A = MatrixSymbol('A', n, n)
195
+ B = MatrixSymbol('B', n, n)
196
+ assert julia_code(A*B) == "A * B"
197
+ assert julia_code(B*A) == "B * A"
198
+ assert julia_code(2*A*B) == "2 * A * B"
199
+ assert julia_code(B*2*A) == "2 * B * A"
200
+ assert julia_code(A*(B + 3*Identity(n))) == "A * (3 * eye(n) + B)"
201
+ assert julia_code(A**(x**2)) == "A ^ (x .^ 2)"
202
+ assert julia_code(A**3) == "A ^ 3"
203
+ assert julia_code(A**S.Half) == "A ^ (1 // 2)"
204
+
205
+
206
+ def test_special_matrices():
207
+ assert julia_code(6*Identity(3)) == "6 * eye(3)"
208
+
209
+
210
+ def test_containers():
211
+ assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
212
+ "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
213
+ assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
214
+ assert julia_code([1]) == "Any[1]"
215
+ assert julia_code((1,)) == "(1,)"
216
+ assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
217
+ assert julia_code((1, x*y, (3, x**2))) == "(1, x .* y, (3, x .^ 2))"
218
+ # scalar, matrix, empty matrix and empty list
219
+ assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
220
+
221
+
222
+ def test_julia_noninline():
223
+ source = julia_code((x+y)/Catalan, assign_to='me', inline=False)
224
+ expected = (
225
+ "const Catalan = %s\n"
226
+ "me = (x + y) / Catalan"
227
+ ) % Catalan.evalf(17)
228
+ assert source == expected
229
+
230
+
231
+ def test_julia_piecewise():
232
+ expr = Piecewise((x, x < 1), (x**2, True))
233
+ assert julia_code(expr) == "((x < 1) ? (x) : (x .^ 2))"
234
+ assert julia_code(expr, assign_to="r") == (
235
+ "r = ((x < 1) ? (x) : (x .^ 2))")
236
+ assert julia_code(expr, assign_to="r", inline=False) == (
237
+ "if (x < 1)\n"
238
+ " r = x\n"
239
+ "else\n"
240
+ " r = x .^ 2\n"
241
+ "end")
242
+ expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True))
243
+ expected = ("((x < 1) ? (x .^ 2) :\n"
244
+ "(x < 2) ? (x .^ 3) :\n"
245
+ "(x < 3) ? (x .^ 4) : (x .^ 5))")
246
+ assert julia_code(expr) == expected
247
+ assert julia_code(expr, assign_to="r") == "r = " + expected
248
+ assert julia_code(expr, assign_to="r", inline=False) == (
249
+ "if (x < 1)\n"
250
+ " r = x .^ 2\n"
251
+ "elseif (x < 2)\n"
252
+ " r = x .^ 3\n"
253
+ "elseif (x < 3)\n"
254
+ " r = x .^ 4\n"
255
+ "else\n"
256
+ " r = x .^ 5\n"
257
+ "end")
258
+ # Check that Piecewise without a True (default) condition error
259
+ expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
260
+ raises(ValueError, lambda: julia_code(expr))
261
+
262
+
263
+ def test_julia_piecewise_times_const():
264
+ pw = Piecewise((x, x < 1), (x**2, True))
265
+ assert julia_code(2*pw) == "2 * ((x < 1) ? (x) : (x .^ 2))"
266
+ assert julia_code(pw/x) == "((x < 1) ? (x) : (x .^ 2)) ./ x"
267
+ assert julia_code(pw/(x*y)) == "((x < 1) ? (x) : (x .^ 2)) ./ (x .* y)"
268
+ assert julia_code(pw/3) == "((x < 1) ? (x) : (x .^ 2)) / 3"
269
+
270
+
271
+ def test_julia_matrix_assign_to():
272
+ A = Matrix([[1, 2, 3]])
273
+ assert julia_code(A, assign_to='a') == "a = [1 2 3]"
274
+ A = Matrix([[1, 2], [3, 4]])
275
+ assert julia_code(A, assign_to='A') == "A = [1 2;\n3 4]"
276
+
277
+
278
+ def test_julia_matrix_assign_to_more():
279
+ # assigning to Symbol or MatrixSymbol requires lhs/rhs match
280
+ A = Matrix([[1, 2, 3]])
281
+ B = MatrixSymbol('B', 1, 3)
282
+ C = MatrixSymbol('C', 2, 3)
283
+ assert julia_code(A, assign_to=B) == "B = [1 2 3]"
284
+ raises(ValueError, lambda: julia_code(A, assign_to=x))
285
+ raises(ValueError, lambda: julia_code(A, assign_to=C))
286
+
287
+
288
+ def test_julia_matrix_1x1():
289
+ A = Matrix([[3]])
290
+ B = MatrixSymbol('B', 1, 1)
291
+ C = MatrixSymbol('C', 1, 2)
292
+ assert julia_code(A, assign_to=B) == "B = [3]"
293
+ # FIXME?
294
+ #assert julia_code(A, assign_to=x) == "x = [3]"
295
+ raises(ValueError, lambda: julia_code(A, assign_to=C))
296
+
297
+
298
+ def test_julia_matrix_elements():
299
+ A = Matrix([[x, 2, x*y]])
300
+ assert julia_code(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
301
+ A = MatrixSymbol('AA', 1, 3)
302
+ assert julia_code(A) == "AA"
303
+ assert julia_code(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \
304
+ "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
305
+ assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
306
+
307
+
308
+ def test_julia_boolean():
309
+ assert julia_code(True) == "true"
310
+ assert julia_code(S.true) == "true"
311
+ assert julia_code(False) == "false"
312
+ assert julia_code(S.false) == "false"
313
+
314
+
315
+ def test_julia_not_supported():
316
+ with raises(NotImplementedError):
317
+ julia_code(S.ComplexInfinity)
318
+
319
+ f = Function('f')
320
+ assert julia_code(f(x).diff(x), strict=False) == (
321
+ "# Not supported in Julia:\n"
322
+ "# Derivative\n"
323
+ "Derivative(f(x), x)"
324
+ )
325
+
326
+
327
+ def test_trick_indent_with_end_else_words():
328
+ # words starting with "end" or "else" do not confuse the indenter
329
+ t1 = S('endless')
330
+ t2 = S('elsewhere')
331
+ pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True))
332
+ assert julia_code(pw, inline=False) == (
333
+ "if (x < 0)\n"
334
+ " endless\n"
335
+ "elseif (x <= 1)\n"
336
+ " elsewhere\n"
337
+ "else\n"
338
+ " 1\n"
339
+ "end")
340
+
341
+
342
+ def test_haramard():
343
+ A = MatrixSymbol('A', 3, 3)
344
+ B = MatrixSymbol('B', 3, 3)
345
+ v = MatrixSymbol('v', 3, 1)
346
+ h = MatrixSymbol('h', 1, 3)
347
+ C = HadamardProduct(A, B)
348
+ assert julia_code(C) == "A .* B"
349
+ assert julia_code(C*v) == "(A .* B) * v"
350
+ assert julia_code(h*C*v) == "h * (A .* B) * v"
351
+ assert julia_code(C*A) == "(A .* B) * A"
352
+ # mixing Hadamard and scalar strange b/c we vectorize scalars
353
+ assert julia_code(C*x*y) == "(x .* y) * (A .* B)"
354
+
355
+
356
+ def test_sparse():
357
+ M = SparseMatrix(5, 6, {})
358
+ M[2, 2] = 10
359
+ M[1, 2] = 20
360
+ M[1, 3] = 22
361
+ M[0, 3] = 30
362
+ M[3, 0] = x*y
363
+ assert julia_code(M) == (
364
+ "sparse([4, 2, 3, 1, 2], [1, 3, 3, 4, 4], [x .* y, 20, 10, 30, 22], 5, 6)"
365
+ )
366
+
367
+
368
+ def test_specfun():
369
+ n = Symbol('n')
370
+ for f in [besselj, bessely, besseli, besselk]:
371
+ assert julia_code(f(n, x)) == f.__name__ + '(n, x)'
372
+ for f in [airyai, airyaiprime, airybi, airybiprime]:
373
+ assert julia_code(f(x)) == f.__name__ + '(x)'
374
+ assert julia_code(hankel1(n, x)) == 'hankelh1(n, x)'
375
+ assert julia_code(hankel2(n, x)) == 'hankelh2(n, x)'
376
+ assert julia_code(jn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* besselj(n + 1 // 2, x) / 2'
377
+ assert julia_code(yn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* bessely(n + 1 // 2, x) / 2'
378
+
379
+
380
+ def test_MatrixElement_printing():
381
+ # test cases for issue #11821
382
+ A = MatrixSymbol("A", 1, 3)
383
+ B = MatrixSymbol("B", 1, 3)
384
+ C = MatrixSymbol("C", 1, 3)
385
+
386
+ assert(julia_code(A[0, 0]) == "A[1,1]")
387
+ assert(julia_code(3 * A[0, 0]) == "3 * A[1,1]")
388
+
389
+ F = C[0, 0].subs(C, A - B)
390
+ assert(julia_code(F) == "(A - B)[1,1]")
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_lambdarepr.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete.summations import Sum
2
+ from sympy.core.expr import Expr
3
+ from sympy.core.symbol import symbols
4
+ from sympy.functions.elementary.miscellaneous import sqrt
5
+ from sympy.functions.elementary.piecewise import Piecewise
6
+ from sympy.functions.elementary.trigonometric import sin
7
+ from sympy.matrices.dense import MutableDenseMatrix as Matrix
8
+ from sympy.sets.sets import Interval
9
+ from sympy.utilities.lambdify import lambdify
10
+ from sympy.testing.pytest import raises
11
+
12
+ from sympy.printing.tensorflow import TensorflowPrinter
13
+ from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter
14
+
15
+
16
+ x, y, z = symbols("x,y,z")
17
+ i, a, b = symbols("i,a,b")
18
+ j, c, d = symbols("j,c,d")
19
+
20
+
21
+ def test_basic():
22
+ assert lambdarepr(x*y) == "x*y"
23
+ assert lambdarepr(x + y) in ["y + x", "x + y"]
24
+ assert lambdarepr(x**y) == "x**y"
25
+
26
+
27
+ def test_matrix():
28
+ # Test printing a Matrix that has an element that is printed differently
29
+ # with the LambdaPrinter than with the StrPrinter.
30
+ e = x % 2
31
+ assert lambdarepr(e) != str(e)
32
+ assert lambdarepr(Matrix([e])) == 'ImmutableDenseMatrix([[x % 2]])'
33
+
34
+
35
+ def test_piecewise():
36
+ # In each case, test eval() the lambdarepr() to make sure there are a
37
+ # correct number of parentheses. It will give a SyntaxError if there aren't.
38
+
39
+ h = "lambda x: "
40
+
41
+ p = Piecewise((x, x < 0))
42
+ l = lambdarepr(p)
43
+ eval(h + l)
44
+ assert l == "((x) if (x < 0) else None)"
45
+
46
+ p = Piecewise(
47
+ (1, x < 1),
48
+ (2, x < 2),
49
+ (0, True)
50
+ )
51
+ l = lambdarepr(p)
52
+ eval(h + l)
53
+ assert l == "((1) if (x < 1) else (2) if (x < 2) else (0))"
54
+
55
+ p = Piecewise(
56
+ (1, x < 1),
57
+ (2, x < 2),
58
+ )
59
+ l = lambdarepr(p)
60
+ eval(h + l)
61
+ assert l == "((1) if (x < 1) else (2) if (x < 2) else None)"
62
+
63
+ p = Piecewise(
64
+ (x, x < 1),
65
+ (x**2, Interval(3, 4, True, False).contains(x)),
66
+ (0, True),
67
+ )
68
+ l = lambdarepr(p)
69
+ eval(h + l)
70
+ assert l == "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))"
71
+
72
+ p = Piecewise(
73
+ (x**2, x < 0),
74
+ (x, x < 1),
75
+ (2 - x, x >= 1),
76
+ (0, True), evaluate=False
77
+ )
78
+ l = lambdarepr(p)
79
+ eval(h + l)
80
+ assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\
81
+ " else (2 - x) if (x >= 1) else (0))"
82
+
83
+ p = Piecewise(
84
+ (x**2, x < 0),
85
+ (x, x < 1),
86
+ (2 - x, x >= 1), evaluate=False
87
+ )
88
+ l = lambdarepr(p)
89
+ eval(h + l)
90
+ assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\
91
+ " else (2 - x) if (x >= 1) else None)"
92
+
93
+ p = Piecewise(
94
+ (1, x >= 1),
95
+ (2, x >= 2),
96
+ (3, x >= 3),
97
+ (4, x >= 4),
98
+ (5, x >= 5),
99
+ (6, True)
100
+ )
101
+ l = lambdarepr(p)
102
+ eval(h + l)
103
+ assert l == "((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"\
104
+ " else (4) if (x >= 4) else (5) if (x >= 5) else (6))"
105
+
106
+ p = Piecewise(
107
+ (1, x <= 1),
108
+ (2, x <= 2),
109
+ (3, x <= 3),
110
+ (4, x <= 4),
111
+ (5, x <= 5),
112
+ (6, True)
113
+ )
114
+ l = lambdarepr(p)
115
+ eval(h + l)
116
+ assert l == "((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"\
117
+ " else (4) if (x <= 4) else (5) if (x <= 5) else (6))"
118
+
119
+ p = Piecewise(
120
+ (1, x > 1),
121
+ (2, x > 2),
122
+ (3, x > 3),
123
+ (4, x > 4),
124
+ (5, x > 5),
125
+ (6, True)
126
+ )
127
+ l = lambdarepr(p)
128
+ eval(h + l)
129
+ assert l =="((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"\
130
+ " else (4) if (x > 4) else (5) if (x > 5) else (6))"
131
+
132
+ p = Piecewise(
133
+ (1, x < 1),
134
+ (2, x < 2),
135
+ (3, x < 3),
136
+ (4, x < 4),
137
+ (5, x < 5),
138
+ (6, True)
139
+ )
140
+ l = lambdarepr(p)
141
+ eval(h + l)
142
+ assert l == "((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"\
143
+ " else (4) if (x < 4) else (5) if (x < 5) else (6))"
144
+
145
+ p = Piecewise(
146
+ (Piecewise(
147
+ (1, x > 0),
148
+ (2, True)
149
+ ), y > 0),
150
+ (3, True)
151
+ )
152
+ l = lambdarepr(p)
153
+ eval(h + l)
154
+ assert l == "((((1) if (x > 0) else (2))) if (y > 0) else (3))"
155
+
156
+
157
+ def test_sum__1():
158
+ # In each case, test eval() the lambdarepr() to make sure that
159
+ # it evaluates to the same results as the symbolic expression
160
+ s = Sum(x ** i, (i, a, b))
161
+ l = lambdarepr(s)
162
+ assert l == "(builtins.sum(x**i for i in range(a, b+1)))"
163
+
164
+ args = x, a, b
165
+ f = lambdify(args, s)
166
+ v = 2, 3, 8
167
+ assert f(*v) == s.subs(zip(args, v)).doit()
168
+
169
+ def test_sum__2():
170
+ s = Sum(i * x, (i, a, b))
171
+ l = lambdarepr(s)
172
+ assert l == "(builtins.sum(i*x for i in range(a, b+1)))"
173
+
174
+ args = x, a, b
175
+ f = lambdify(args, s)
176
+ v = 2, 3, 8
177
+ assert f(*v) == s.subs(zip(args, v)).doit()
178
+
179
+
180
+ def test_multiple_sums():
181
+ s = Sum(i * x + j, (i, a, b), (j, c, d))
182
+
183
+ l = lambdarepr(s)
184
+ assert l == "(builtins.sum(i*x + j for j in range(c, d+1) for i in range(a, b+1)))"
185
+
186
+ args = x, a, b, c, d
187
+ f = lambdify(args, s)
188
+ vals = 2, 3, 4, 5, 6
189
+ f_ref = s.subs(zip(args, vals)).doit()
190
+ f_res = f(*vals)
191
+ assert f_res == f_ref
192
+
193
+
194
+ def test_sqrt():
195
+ prntr = LambdaPrinter({'standard' : 'python3'})
196
+ assert prntr._print_Pow(sqrt(x), rational=False) == 'sqrt(x)'
197
+ assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
198
+
199
+
200
+ def test_settings():
201
+ raises(TypeError, lambda: lambdarepr(sin(x), method="garbage"))
202
+
203
+
204
+ def test_numexpr():
205
+ # test ITE rewrite as Piecewise
206
+ from sympy.logic.boolalg import ITE
207
+ expr = ITE(x > 0, True, False, evaluate=False)
208
+ assert NumExprPrinter().doprint(expr) == \
209
+ "numexpr.evaluate('where((x > 0), True, False)', truediv=True)"
210
+
211
+ from sympy.codegen.ast import Return, FunctionDefinition, Variable, Assignment
212
+ func_def = FunctionDefinition(None, 'foo', [Variable(x)], [Assignment(y,x), Return(y**2)])
213
+ expected = "def foo(x):\n"\
214
+ " y = numexpr.evaluate('x', truediv=True)\n"\
215
+ " return numexpr.evaluate('y**2', truediv=True)"
216
+ assert NumExprPrinter().doprint(func_def) == expected
217
+
218
+
219
+ class CustomPrintedObject(Expr):
220
+ def _lambdacode(self, printer):
221
+ return 'lambda'
222
+
223
+ def _tensorflowcode(self, printer):
224
+ return 'tensorflow'
225
+
226
+ def _numpycode(self, printer):
227
+ return 'numpy'
228
+
229
+ def _numexprcode(self, printer):
230
+ return 'numexpr'
231
+
232
+ def _mpmathcode(self, printer):
233
+ return 'mpmath'
234
+
235
+
236
+ def test_printmethod():
237
+ # In each case, printmethod is called to test
238
+ # its working
239
+
240
+ obj = CustomPrintedObject()
241
+ assert LambdaPrinter().doprint(obj) == 'lambda'
242
+ assert TensorflowPrinter().doprint(obj) == 'tensorflow'
243
+ assert NumExprPrinter().doprint(obj) == "numexpr.evaluate('numexpr', truediv=True)"
244
+
245
+ assert NumExprPrinter().doprint(Piecewise((y, x >= 0), (z, x < 0))) == \
246
+ "numexpr.evaluate('where((x >= 0), y, z)', truediv=True)"
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_latex.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_maple.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
2
+ Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge)
3
+ from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow
4
+ from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc, lucas
5
+ from sympy.testing.pytest import raises
6
+ from sympy.utilities.lambdify import implemented_function
7
+ from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
8
+ HadamardProduct, SparseMatrix)
9
+ from sympy.functions.special.bessel import besseli
10
+
11
+ from sympy.printing.maple import maple_code
12
+
13
+ x, y, z = symbols('x,y,z')
14
+
15
+
16
+ def test_Integer():
17
+ assert maple_code(Integer(67)) == "67"
18
+ assert maple_code(Integer(-1)) == "-1"
19
+
20
+
21
+ def test_Rational():
22
+ assert maple_code(Rational(3, 7)) == "3/7"
23
+ assert maple_code(Rational(18, 9)) == "2"
24
+ assert maple_code(Rational(3, -7)) == "-3/7"
25
+ assert maple_code(Rational(-3, -7)) == "3/7"
26
+ assert maple_code(x + Rational(3, 7)) == "x + 3/7"
27
+ assert maple_code(Rational(3, 7) * x) == '(3/7)*x'
28
+
29
+
30
+ def test_Relational():
31
+ assert maple_code(Eq(x, y)) == "x = y"
32
+ assert maple_code(Ne(x, y)) == "x <> y"
33
+ assert maple_code(Le(x, y)) == "x <= y"
34
+ assert maple_code(Lt(x, y)) == "x < y"
35
+ assert maple_code(Gt(x, y)) == "x > y"
36
+ assert maple_code(Ge(x, y)) == "x >= y"
37
+
38
+
39
+ def test_Function():
40
+ assert maple_code(sin(x) ** cos(x)) == "sin(x)^cos(x)"
41
+ assert maple_code(abs(x)) == "abs(x)"
42
+ assert maple_code(ceiling(x)) == "ceil(x)"
43
+
44
+
45
+ def test_Pow():
46
+ assert maple_code(x ** 3) == "x^3"
47
+ assert maple_code(x ** (y ** 3)) == "x^(y^3)"
48
+
49
+ assert maple_code((x ** 3) ** y) == "(x^3)^y"
50
+ assert maple_code(x ** Rational(2, 3)) == 'x^(2/3)'
51
+
52
+ g = implemented_function('g', Lambda(x, 2 * x))
53
+ assert maple_code(1 / (g(x) * 3.5) ** (x - y ** x) / (x ** 2 + y)) == \
54
+ "(3.5*2*x)^(-x + y^x)/(x^2 + y)"
55
+ # For issue 14160
56
+ assert maple_code(Mul(-2, x, Pow(Mul(y, y, evaluate=False), -1, evaluate=False),
57
+ evaluate=False)) == '-2*x/(y*y)'
58
+
59
+
60
+ def test_basic_ops():
61
+ assert maple_code(x * y) == "x*y"
62
+ assert maple_code(x + y) == "x + y"
63
+ assert maple_code(x - y) == "x - y"
64
+ assert maple_code(-x) == "-x"
65
+
66
+
67
+ def test_1_over_x_and_sqrt():
68
+ # 1.0 and 0.5 would do something different in regular StrPrinter,
69
+ # but these are exact in IEEE floating point so no different here.
70
+ assert maple_code(1 / x) == '1/x'
71
+ assert maple_code(x ** -1) == maple_code(x ** -1.0) == '1/x'
72
+ assert maple_code(1 / sqrt(x)) == '1/sqrt(x)'
73
+ assert maple_code(x ** -S.Half) == maple_code(x ** -0.5) == '1/sqrt(x)'
74
+ assert maple_code(sqrt(x)) == 'sqrt(x)'
75
+ assert maple_code(x ** S.Half) == maple_code(x ** 0.5) == 'sqrt(x)'
76
+ assert maple_code(1 / pi) == '1/Pi'
77
+ assert maple_code(pi ** -1) == maple_code(pi ** -1.0) == '1/Pi'
78
+ assert maple_code(pi ** -0.5) == '1/sqrt(Pi)'
79
+
80
+
81
+ def test_mix_number_mult_symbols():
82
+ assert maple_code(3 * x) == "3*x"
83
+ assert maple_code(pi * x) == "Pi*x"
84
+ assert maple_code(3 / x) == "3/x"
85
+ assert maple_code(pi / x) == "Pi/x"
86
+ assert maple_code(x / 3) == '(1/3)*x'
87
+ assert maple_code(x / pi) == "x/Pi"
88
+ assert maple_code(x * y) == "x*y"
89
+ assert maple_code(3 * x * y) == "3*x*y"
90
+ assert maple_code(3 * pi * x * y) == "3*Pi*x*y"
91
+ assert maple_code(x / y) == "x/y"
92
+ assert maple_code(3 * x / y) == "3*x/y"
93
+ assert maple_code(x * y / z) == "x*y/z"
94
+ assert maple_code(x / y * z) == "x*z/y"
95
+ assert maple_code(1 / x / y) == "1/(x*y)"
96
+ assert maple_code(2 * pi * x / y / z) == "2*Pi*x/(y*z)"
97
+ assert maple_code(3 * pi / x) == "3*Pi/x"
98
+ assert maple_code(S(3) / 5) == "3/5"
99
+ assert maple_code(S(3) / 5 * x) == '(3/5)*x'
100
+ assert maple_code(x / y / z) == "x/(y*z)"
101
+ assert maple_code((x + y) / z) == "(x + y)/z"
102
+ assert maple_code((x + y) / (z + x)) == "(x + y)/(x + z)"
103
+ assert maple_code((x + y) / EulerGamma) == '(x + y)/gamma'
104
+ assert maple_code(x / 3 / pi) == '(1/3)*x/Pi'
105
+ assert maple_code(S(3) / 5 * x * y / pi) == '(3/5)*x*y/Pi'
106
+
107
+
108
+ def test_mix_number_pow_symbols():
109
+ assert maple_code(pi ** 3) == 'Pi^3'
110
+ assert maple_code(x ** 2) == 'x^2'
111
+
112
+ assert maple_code(x ** (pi ** 3)) == 'x^(Pi^3)'
113
+ assert maple_code(x ** y) == 'x^y'
114
+
115
+ assert maple_code(x ** (y ** z)) == 'x^(y^z)'
116
+ assert maple_code((x ** y) ** z) == '(x^y)^z'
117
+
118
+
119
+ def test_imag():
120
+ I = S('I')
121
+ assert maple_code(I) == "I"
122
+ assert maple_code(5 * I) == "5*I"
123
+
124
+ assert maple_code((S(3) / 2) * I) == "(3/2)*I"
125
+ assert maple_code(3 + 4 * I) == "3 + 4*I"
126
+
127
+
128
+ def test_constants():
129
+ assert maple_code(pi) == "Pi"
130
+ assert maple_code(oo) == "infinity"
131
+ assert maple_code(-oo) == "-infinity"
132
+ assert maple_code(S.NegativeInfinity) == "-infinity"
133
+ assert maple_code(S.NaN) == "undefined"
134
+ assert maple_code(S.Exp1) == "exp(1)"
135
+ assert maple_code(exp(1)) == "exp(1)"
136
+
137
+
138
+ def test_constants_other():
139
+ assert maple_code(2 * GoldenRatio) == '2*(1/2 + (1/2)*sqrt(5))'
140
+ assert maple_code(2 * Catalan) == '2*Catalan'
141
+ assert maple_code(2 * EulerGamma) == "2*gamma"
142
+
143
+
144
+ def test_boolean():
145
+ assert maple_code(x & y) == "x and y"
146
+ assert maple_code(x | y) == "x or y"
147
+ assert maple_code(~x) == "not x"
148
+ assert maple_code(x & y & z) == "x and y and z"
149
+ assert maple_code(x | y | z) == "x or y or z"
150
+ assert maple_code((x & y) | z) == "z or x and y"
151
+ assert maple_code((x | y) & z) == "z and (x or y)"
152
+
153
+
154
+ def test_Matrices():
155
+ assert maple_code(Matrix(1, 1, [10])) == \
156
+ 'Matrix([[10]], storage = rectangular)'
157
+
158
+ A = Matrix([[1, sin(x / 2), abs(x)],
159
+ [0, 1, pi],
160
+ [0, exp(1), ceiling(x)]])
161
+ expected = \
162
+ 'Matrix(' \
163
+ '[[1, sin((1/2)*x), abs(x)],' \
164
+ ' [0, 1, Pi],' \
165
+ ' [0, exp(1), ceil(x)]], ' \
166
+ 'storage = rectangular)'
167
+ assert maple_code(A) == expected
168
+
169
+ # row and columns
170
+ assert maple_code(A[:, 0]) == \
171
+ 'Matrix([[1], [0], [0]], storage = rectangular)'
172
+ assert maple_code(A[0, :]) == \
173
+ 'Matrix([[1, sin((1/2)*x), abs(x)]], storage = rectangular)'
174
+ assert maple_code(Matrix([[x, x - y, -y]])) == \
175
+ 'Matrix([[x, x - y, -y]], storage = rectangular)'
176
+
177
+ # empty matrices
178
+ assert maple_code(Matrix(0, 0, [])) == \
179
+ 'Matrix([], storage = rectangular)'
180
+ assert maple_code(Matrix(0, 3, [])) == \
181
+ 'Matrix([], storage = rectangular)'
182
+
183
+ def test_SparseMatrices():
184
+ assert maple_code(SparseMatrix(Identity(2))) == 'Matrix([[1, 0], [0, 1]], storage = sparse)'
185
+
186
+
187
+ def test_vector_entries_hadamard():
188
+ # For a row or column, user might to use the other dimension
189
+ A = Matrix([[1, sin(2 / x), 3 * pi / x / 5]])
190
+ assert maple_code(A) == \
191
+ 'Matrix([[1, sin(2/x), (3/5)*Pi/x]], storage = rectangular)'
192
+ assert maple_code(A.T) == \
193
+ 'Matrix([[1], [sin(2/x)], [(3/5)*Pi/x]], storage = rectangular)'
194
+
195
+
196
+ def test_Matrices_entries_not_hadamard():
197
+ A = Matrix([[1, sin(2 / x), 3 * pi / x / 5], [1, 2, x * y]])
198
+ expected = \
199
+ 'Matrix([[1, sin(2/x), (3/5)*Pi/x], [1, 2, x*y]], ' \
200
+ 'storage = rectangular)'
201
+ assert maple_code(A) == expected
202
+
203
+
204
+ def test_MatrixSymbol():
205
+ n = Symbol('n', integer=True)
206
+ A = MatrixSymbol('A', n, n)
207
+ B = MatrixSymbol('B', n, n)
208
+ assert maple_code(A * B) == "A.B"
209
+ assert maple_code(B * A) == "B.A"
210
+ assert maple_code(2 * A * B) == "2*A.B"
211
+ assert maple_code(B * 2 * A) == "2*B.A"
212
+
213
+ assert maple_code(
214
+ A * (B + 3 * Identity(n))) == "A.(3*Matrix(n, shape = identity) + B)"
215
+
216
+ assert maple_code(A ** (x ** 2)) == "MatrixPower(A, x^2)"
217
+ assert maple_code(A ** 3) == "MatrixPower(A, 3)"
218
+ assert maple_code(A ** (S.Half)) == "MatrixPower(A, 1/2)"
219
+
220
+
221
+ def test_special_matrices():
222
+ assert maple_code(6 * Identity(3)) == "6*Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = sparse)"
223
+ assert maple_code(Identity(x)) == 'Matrix(x, shape = identity)'
224
+
225
+
226
+ def test_containers():
227
+ assert maple_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
228
+ "[1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]"
229
+
230
+ assert maple_code((1, 2, (3, 4))) == "[1, 2, [3, 4]]"
231
+ assert maple_code([1]) == "[1]"
232
+ assert maple_code((1,)) == "[1]"
233
+ assert maple_code(Tuple(*[1, 2, 3])) == "[1, 2, 3]"
234
+ assert maple_code((1, x * y, (3, x ** 2))) == "[1, x*y, [3, x^2]]"
235
+ # scalar, matrix, empty matrix and empty list
236
+
237
+ assert maple_code((1, eye(3), Matrix(0, 0, []), [])) == \
238
+ "[1, Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = rectangular), Matrix([], storage = rectangular), []]"
239
+
240
+
241
+ def test_maple_noninline():
242
+ source = maple_code((x + y)/Catalan, assign_to='me', inline=False)
243
+ expected = "me := (x + y)/Catalan"
244
+
245
+ assert source == expected
246
+
247
+
248
+ def test_maple_matrix_assign_to():
249
+ A = Matrix([[1, 2, 3]])
250
+ assert maple_code(A, assign_to='a') == "a := Matrix([[1, 2, 3]], storage = rectangular)"
251
+ A = Matrix([[1, 2], [3, 4]])
252
+ assert maple_code(A, assign_to='A') == "A := Matrix([[1, 2], [3, 4]], storage = rectangular)"
253
+
254
+
255
+ def test_maple_matrix_assign_to_more():
256
+ # assigning to Symbol or MatrixSymbol requires lhs/rhs match
257
+ A = Matrix([[1, 2, 3]])
258
+ B = MatrixSymbol('B', 1, 3)
259
+ C = MatrixSymbol('C', 2, 3)
260
+ assert maple_code(A, assign_to=B) == "B := Matrix([[1, 2, 3]], storage = rectangular)"
261
+ raises(ValueError, lambda: maple_code(A, assign_to=x))
262
+ raises(ValueError, lambda: maple_code(A, assign_to=C))
263
+
264
+
265
+ def test_maple_matrix_1x1():
266
+ A = Matrix([[3]])
267
+ assert maple_code(A, assign_to='B') == "B := Matrix([[3]], storage = rectangular)"
268
+
269
+
270
+ def test_maple_matrix_elements():
271
+ A = Matrix([[x, 2, x * y]])
272
+
273
+ assert maple_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x^2 + x*y + 2"
274
+ AA = MatrixSymbol('AA', 1, 3)
275
+ assert maple_code(AA) == "AA"
276
+
277
+ assert maple_code(AA[0, 0] ** 2 + sin(AA[0, 1]) + AA[0, 2]) == \
278
+ "sin(AA[1, 2]) + AA[1, 1]^2 + AA[1, 3]"
279
+ assert maple_code(sum(AA)) == "AA[1, 1] + AA[1, 2] + AA[1, 3]"
280
+
281
+
282
+ def test_maple_boolean():
283
+ assert maple_code(True) == "true"
284
+ assert maple_code(S.true) == "true"
285
+ assert maple_code(False) == "false"
286
+ assert maple_code(S.false) == "false"
287
+
288
+
289
+ def test_sparse():
290
+ M = SparseMatrix(5, 6, {})
291
+ M[2, 2] = 10
292
+ M[1, 2] = 20
293
+ M[1, 3] = 22
294
+ M[0, 3] = 30
295
+ M[3, 0] = x * y
296
+ assert maple_code(M) == \
297
+ 'Matrix([[0, 0, 0, 30, 0, 0],' \
298
+ ' [0, 0, 20, 22, 0, 0],' \
299
+ ' [0, 0, 10, 0, 0, 0],' \
300
+ ' [x*y, 0, 0, 0, 0, 0],' \
301
+ ' [0, 0, 0, 0, 0, 0]], ' \
302
+ 'storage = sparse)'
303
+
304
+ # Not an important point.
305
+ def test_maple_not_supported():
306
+ with raises(NotImplementedError):
307
+ maple_code(S.ComplexInfinity)
308
+
309
+
310
+ def test_MatrixElement_printing():
311
+ # test cases for issue #11821
312
+ A = MatrixSymbol("A", 1, 3)
313
+ B = MatrixSymbol("B", 1, 3)
314
+
315
+ assert (maple_code(A[0, 0]) == "A[1, 1]")
316
+ assert (maple_code(3 * A[0, 0]) == "3*A[1, 1]")
317
+
318
+ F = A-B
319
+
320
+ assert (maple_code(F[0,0]) == "A[1, 1] - B[1, 1]")
321
+
322
+
323
+ def test_hadamard():
324
+ A = MatrixSymbol('A', 3, 3)
325
+ B = MatrixSymbol('B', 3, 3)
326
+ v = MatrixSymbol('v', 3, 1)
327
+ h = MatrixSymbol('h', 1, 3)
328
+ C = HadamardProduct(A, B)
329
+ assert maple_code(C) == "A*B"
330
+
331
+ assert maple_code(C * v) == "(A*B).v"
332
+ # HadamardProduct is higher than dot product.
333
+
334
+ assert maple_code(h * C * v) == "h.(A*B).v"
335
+
336
+ assert maple_code(C * A) == "(A*B).A"
337
+ # mixing Hadamard and scalar strange b/c we vectorize scalars
338
+
339
+ assert maple_code(C * x * y) == "x*y*(A*B)"
340
+
341
+
342
+ def test_maple_piecewise():
343
+ expr = Piecewise((x, x < 1), (x ** 2, True))
344
+
345
+ assert maple_code(expr) == "piecewise(x < 1, x, x^2)"
346
+ assert maple_code(expr, assign_to="r") == (
347
+ "r := piecewise(x < 1, x, x^2)")
348
+
349
+ expr = Piecewise((x ** 2, x < 1), (x ** 3, x < 2), (x ** 4, x < 3), (x ** 5, True))
350
+ expected = "piecewise(x < 1, x^2, x < 2, x^3, x < 3, x^4, x^5)"
351
+ assert maple_code(expr) == expected
352
+ assert maple_code(expr, assign_to="r") == "r := " + expected
353
+
354
+ # Check that Piecewise without a True (default) condition error
355
+ expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
356
+ raises(ValueError, lambda: maple_code(expr))
357
+
358
+
359
+ def test_maple_piecewise_times_const():
360
+ pw = Piecewise((x, x < 1), (x ** 2, True))
361
+
362
+ assert maple_code(2 * pw) == "2*piecewise(x < 1, x, x^2)"
363
+ assert maple_code(pw / x) == "piecewise(x < 1, x, x^2)/x"
364
+ assert maple_code(pw / (x * y)) == "piecewise(x < 1, x, x^2)/(x*y)"
365
+ assert maple_code(pw / 3) == "(1/3)*piecewise(x < 1, x, x^2)"
366
+
367
+
368
+ def test_maple_derivatives():
369
+ f = Function('f')
370
+ assert maple_code(f(x).diff(x)) == 'diff(f(x), x)'
371
+ assert maple_code(f(x).diff(x, 2)) == 'diff(f(x), x$2)'
372
+
373
+
374
+ def test_automatic_rewrites():
375
+ assert maple_code(lucas(x)) == '(2^(-x)*((1 - sqrt(5))^x + (1 + sqrt(5))^x))'
376
+ assert maple_code(sinc(x)) == '(piecewise(x <> 0, sin(x)/x, 1))'
377
+
378
+
379
+ def test_specfun():
380
+ assert maple_code('asin(x)') == 'arcsin(x)'
381
+ assert maple_code(besseli(x, y)) == 'BesselI(x, y)'
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathematica.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, Tuple,
2
+ Derivative, Eq, Ne, Le, Lt, Gt, Ge)
3
+ from sympy.integrals import Integral
4
+ from sympy.concrete import Sum
5
+ from sympy.functions import (exp, sin, cos, fresnelc, fresnels, conjugate, Max,
6
+ Min, gamma, polygamma, loggamma, erf, erfi, erfc,
7
+ erf2, expint, erfinv, erfcinv, Ei, Si, Ci, li,
8
+ Shi, Chi, uppergamma, beta, subfactorial, erf2inv,
9
+ factorial, factorial2, catalan, RisingFactorial,
10
+ FallingFactorial, harmonic, atan2, sec, acsc,
11
+ hermite, laguerre, assoc_laguerre, jacobi,
12
+ gegenbauer, chebyshevt, chebyshevu, legendre,
13
+ assoc_legendre, Li, LambertW)
14
+
15
+ from sympy.printing.mathematica import mathematica_code as mcode
16
+
17
+ x, y, z, w = symbols('x,y,z,w')
18
+ f = Function('f')
19
+
20
+
21
+ def test_Integer():
22
+ assert mcode(Integer(67)) == "67"
23
+ assert mcode(Integer(-1)) == "-1"
24
+
25
+
26
+ def test_Rational():
27
+ assert mcode(Rational(3, 7)) == "3/7"
28
+ assert mcode(Rational(18, 9)) == "2"
29
+ assert mcode(Rational(3, -7)) == "-3/7"
30
+ assert mcode(Rational(-3, -7)) == "3/7"
31
+ assert mcode(x + Rational(3, 7)) == "x + 3/7"
32
+ assert mcode(Rational(3, 7)*x) == "(3/7)*x"
33
+
34
+
35
+ def test_Relational():
36
+ assert mcode(Eq(x, y)) == "x == y"
37
+ assert mcode(Ne(x, y)) == "x != y"
38
+ assert mcode(Le(x, y)) == "x <= y"
39
+ assert mcode(Lt(x, y)) == "x < y"
40
+ assert mcode(Gt(x, y)) == "x > y"
41
+ assert mcode(Ge(x, y)) == "x >= y"
42
+
43
+
44
+ def test_Function():
45
+ assert mcode(f(x, y, z)) == "f[x, y, z]"
46
+ assert mcode(sin(x) ** cos(x)) == "Sin[x]^Cos[x]"
47
+ assert mcode(sec(x) * acsc(x)) == "ArcCsc[x]*Sec[x]"
48
+ assert mcode(atan2(y, x)) == "ArcTan[x, y]"
49
+ assert mcode(conjugate(x)) == "Conjugate[x]"
50
+ assert mcode(Max(x, y, z)*Min(y, z)) == "Max[x, y, z]*Min[y, z]"
51
+ assert mcode(fresnelc(x)) == "FresnelC[x]"
52
+ assert mcode(fresnels(x)) == "FresnelS[x]"
53
+ assert mcode(gamma(x)) == "Gamma[x]"
54
+ assert mcode(uppergamma(x, y)) == "Gamma[x, y]"
55
+ assert mcode(polygamma(x, y)) == "PolyGamma[x, y]"
56
+ assert mcode(loggamma(x)) == "LogGamma[x]"
57
+ assert mcode(erf(x)) == "Erf[x]"
58
+ assert mcode(erfc(x)) == "Erfc[x]"
59
+ assert mcode(erfi(x)) == "Erfi[x]"
60
+ assert mcode(erf2(x, y)) == "Erf[x, y]"
61
+ assert mcode(expint(x, y)) == "ExpIntegralE[x, y]"
62
+ assert mcode(erfcinv(x)) == "InverseErfc[x]"
63
+ assert mcode(erfinv(x)) == "InverseErf[x]"
64
+ assert mcode(erf2inv(x, y)) == "InverseErf[x, y]"
65
+ assert mcode(Ei(x)) == "ExpIntegralEi[x]"
66
+ assert mcode(Ci(x)) == "CosIntegral[x]"
67
+ assert mcode(li(x)) == "LogIntegral[x]"
68
+ assert mcode(Si(x)) == "SinIntegral[x]"
69
+ assert mcode(Shi(x)) == "SinhIntegral[x]"
70
+ assert mcode(Chi(x)) == "CoshIntegral[x]"
71
+ assert mcode(beta(x, y)) == "Beta[x, y]"
72
+ assert mcode(factorial(x)) == "Factorial[x]"
73
+ assert mcode(factorial2(x)) == "Factorial2[x]"
74
+ assert mcode(subfactorial(x)) == "Subfactorial[x]"
75
+ assert mcode(FallingFactorial(x, y)) == "FactorialPower[x, y]"
76
+ assert mcode(RisingFactorial(x, y)) == "Pochhammer[x, y]"
77
+ assert mcode(catalan(x)) == "CatalanNumber[x]"
78
+ assert mcode(harmonic(x)) == "HarmonicNumber[x]"
79
+ assert mcode(harmonic(x, y)) == "HarmonicNumber[x, y]"
80
+ assert mcode(Li(x)) == "LogIntegral[x] - LogIntegral[2]"
81
+ assert mcode(LambertW(x)) == "ProductLog[x]"
82
+ assert mcode(LambertW(x, -1)) == "ProductLog[-1, x]"
83
+ assert mcode(LambertW(x, y)) == "ProductLog[y, x]"
84
+
85
+
86
+ def test_special_polynomials():
87
+ assert mcode(hermite(x, y)) == "HermiteH[x, y]"
88
+ assert mcode(laguerre(x, y)) == "LaguerreL[x, y]"
89
+ assert mcode(assoc_laguerre(x, y, z)) == "LaguerreL[x, y, z]"
90
+ assert mcode(jacobi(x, y, z, w)) == "JacobiP[x, y, z, w]"
91
+ assert mcode(gegenbauer(x, y, z)) == "GegenbauerC[x, y, z]"
92
+ assert mcode(chebyshevt(x, y)) == "ChebyshevT[x, y]"
93
+ assert mcode(chebyshevu(x, y)) == "ChebyshevU[x, y]"
94
+ assert mcode(legendre(x, y)) == "LegendreP[x, y]"
95
+ assert mcode(assoc_legendre(x, y, z)) == "LegendreP[x, y, z]"
96
+
97
+
98
+ def test_Pow():
99
+ assert mcode(x**3) == "x^3"
100
+ assert mcode(x**(y**3)) == "x^(y^3)"
101
+ assert mcode(1/(f(x)*3.5)**(x - y**x)/(x**2 + y)) == \
102
+ "(3.5*f[x])^(-x + y^x)/(x^2 + y)"
103
+ assert mcode(x**-1.0) == 'x^(-1.0)'
104
+ assert mcode(x**Rational(2, 3)) == 'x^(2/3)'
105
+
106
+
107
+ def test_Mul():
108
+ A, B, C, D = symbols('A B C D', commutative=False)
109
+ assert mcode(x*y*z) == "x*y*z"
110
+ assert mcode(x*y*A) == "x*y*A"
111
+ assert mcode(x*y*A*B) == "x*y*A**B"
112
+ assert mcode(x*y*A*B*C) == "x*y*A**B**C"
113
+ assert mcode(x*A*B*(C + D)*A*y) == "x*y*A**B**(C + D)**A"
114
+
115
+
116
+ def test_constants():
117
+ assert mcode(S.Zero) == "0"
118
+ assert mcode(S.One) == "1"
119
+ assert mcode(S.NegativeOne) == "-1"
120
+ assert mcode(S.Half) == "1/2"
121
+ assert mcode(S.ImaginaryUnit) == "I"
122
+
123
+ assert mcode(oo) == "Infinity"
124
+ assert mcode(S.NegativeInfinity) == "-Infinity"
125
+ assert mcode(S.ComplexInfinity) == "ComplexInfinity"
126
+ assert mcode(S.NaN) == "Indeterminate"
127
+
128
+ assert mcode(S.Exp1) == "E"
129
+ assert mcode(pi) == "Pi"
130
+ assert mcode(S.GoldenRatio) == "GoldenRatio"
131
+ assert mcode(S.TribonacciConstant) == \
132
+ "(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \
133
+ "(1/3)*(3*33^(1/2) + 19)^(1/3))"
134
+ assert mcode(2*S.TribonacciConstant) == \
135
+ "2*(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \
136
+ "(1/3)*(3*33^(1/2) + 19)^(1/3))"
137
+ assert mcode(S.EulerGamma) == "EulerGamma"
138
+ assert mcode(S.Catalan) == "Catalan"
139
+
140
+
141
+ def test_containers():
142
+ assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
143
+ "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}"
144
+ assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}"
145
+ assert mcode([1]) == "{1}"
146
+ assert mcode((1,)) == "{1}"
147
+ assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}"
148
+
149
+
150
+ def test_matrices():
151
+ from sympy.matrices import MutableDenseMatrix, MutableSparseMatrix, \
152
+ ImmutableDenseMatrix, ImmutableSparseMatrix
153
+ A = MutableDenseMatrix(
154
+ [[1, -1, 0, 0],
155
+ [0, 1, -1, 0],
156
+ [0, 0, 1, -1],
157
+ [0, 0, 0, 1]]
158
+ )
159
+ B = MutableSparseMatrix(A)
160
+ C = ImmutableDenseMatrix(A)
161
+ D = ImmutableSparseMatrix(A)
162
+
163
+ assert mcode(C) == mcode(A) == \
164
+ "{{1, -1, 0, 0}, " \
165
+ "{0, 1, -1, 0}, " \
166
+ "{0, 0, 1, -1}, " \
167
+ "{0, 0, 0, 1}}"
168
+
169
+ assert mcode(D) == mcode(B) == \
170
+ "SparseArray[{" \
171
+ "{1, 1} -> 1, {1, 2} -> -1, {2, 2} -> 1, {2, 3} -> -1, " \
172
+ "{3, 3} -> 1, {3, 4} -> -1, {4, 4} -> 1" \
173
+ "}, {4, 4}]"
174
+
175
+ # Trivial cases of matrices
176
+ assert mcode(MutableDenseMatrix(0, 0, [])) == '{}'
177
+ assert mcode(MutableSparseMatrix(0, 0, [])) == 'SparseArray[{}, {0, 0}]'
178
+ assert mcode(MutableDenseMatrix(0, 3, [])) == '{}'
179
+ assert mcode(MutableSparseMatrix(0, 3, [])) == 'SparseArray[{}, {0, 3}]'
180
+ assert mcode(MutableDenseMatrix(3, 0, [])) == '{{}, {}, {}}'
181
+ assert mcode(MutableSparseMatrix(3, 0, [])) == 'SparseArray[{}, {3, 0}]'
182
+
183
+ def test_NDArray():
184
+ from sympy.tensor.array import (
185
+ MutableDenseNDimArray, ImmutableDenseNDimArray,
186
+ MutableSparseNDimArray, ImmutableSparseNDimArray)
187
+
188
+ example = MutableDenseNDimArray(
189
+ [[[1, 2, 3, 4],
190
+ [5, 6, 7, 8],
191
+ [9, 10, 11, 12]],
192
+ [[13, 14, 15, 16],
193
+ [17, 18, 19, 20],
194
+ [21, 22, 23, 24]]]
195
+ )
196
+
197
+ assert mcode(example) == \
198
+ "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \
199
+ "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}"
200
+
201
+ example = ImmutableDenseNDimArray(example)
202
+
203
+ assert mcode(example) == \
204
+ "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \
205
+ "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}"
206
+
207
+ example = MutableSparseNDimArray(example)
208
+
209
+ assert mcode(example) == \
210
+ "SparseArray[{" \
211
+ "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \
212
+ "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \
213
+ "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \
214
+ "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \
215
+ "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \
216
+ "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \
217
+ "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \
218
+ "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \
219
+ "}, {2, 3, 4}]"
220
+
221
+ example = ImmutableSparseNDimArray(example)
222
+
223
+ assert mcode(example) == \
224
+ "SparseArray[{" \
225
+ "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \
226
+ "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \
227
+ "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \
228
+ "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \
229
+ "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \
230
+ "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \
231
+ "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \
232
+ "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \
233
+ "}, {2, 3, 4}]"
234
+
235
+
236
+ def test_Integral():
237
+ assert mcode(Integral(sin(sin(x)), x)) == "Hold[Integrate[Sin[Sin[x]], x]]"
238
+ assert mcode(Integral(exp(-x**2 - y**2),
239
+ (x, -oo, oo),
240
+ (y, -oo, oo))) == \
241
+ "Hold[Integrate[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \
242
+ "{y, -Infinity, Infinity}]]"
243
+
244
+
245
+ def test_Derivative():
246
+ assert mcode(Derivative(sin(x), x)) == "Hold[D[Sin[x], x]]"
247
+ assert mcode(Derivative(x, x)) == "Hold[D[x, x]]"
248
+ assert mcode(Derivative(sin(x)*y**4, x, 2)) == "Hold[D[y^4*Sin[x], {x, 2}]]"
249
+ assert mcode(Derivative(sin(x)*y**4, x, y, x)) == "Hold[D[y^4*Sin[x], x, y, x]]"
250
+ assert mcode(Derivative(sin(x)*y**4, x, y, 3, x)) == "Hold[D[y^4*Sin[x], x, {y, 3}, x]]"
251
+
252
+
253
+ def test_Sum():
254
+ assert mcode(Sum(sin(x), (x, 0, 10))) == "Hold[Sum[Sin[x], {x, 0, 10}]]"
255
+ assert mcode(Sum(exp(-x**2 - y**2),
256
+ (x, -oo, oo),
257
+ (y, -oo, oo))) == \
258
+ "Hold[Sum[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \
259
+ "{y, -Infinity, Infinity}]]"
260
+
261
+
262
+ def test_comment():
263
+ from sympy.printing.mathematica import MCodePrinter
264
+ assert MCodePrinter()._get_comment("Hello World") == \
265
+ "(* Hello World *)"
266
+
267
+
268
+ def test_userfuncs():
269
+ # Dictionary mutation test
270
+ some_function = symbols("some_function", cls=Function)
271
+ my_user_functions = {"some_function": "SomeFunction"}
272
+ assert mcode(
273
+ some_function(z),
274
+ user_functions=my_user_functions) == \
275
+ 'SomeFunction[z]'
276
+ assert mcode(
277
+ some_function(z),
278
+ user_functions=my_user_functions) == \
279
+ 'SomeFunction[z]'
280
+
281
+ # List argument test
282
+ my_user_functions = \
283
+ {"some_function": [(lambda x: True, "SomeOtherFunction")]}
284
+ assert mcode(
285
+ some_function(z),
286
+ user_functions=my_user_functions) == \
287
+ 'SomeOtherFunction[z]'
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathml.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_numpy.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete.summations import Sum
2
+ from sympy.core.mod import Mod
3
+ from sympy.core.relational import (Equality, Unequality)
4
+ from sympy.core.symbol import Symbol
5
+ from sympy.functions.elementary.miscellaneous import sqrt
6
+ from sympy.functions.elementary.piecewise import Piecewise
7
+ from sympy.functions.special.gamma_functions import polygamma
8
+ from sympy.functions.special.error_functions import (Si, Ci)
9
+ from sympy.matrices import Matrix
10
+ from sympy.matrices.expressions.blockmatrix import BlockMatrix
11
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
12
+ from sympy.matrices.expressions.special import Identity
13
+ from sympy.utilities.lambdify import lambdify
14
+ from sympy import symbols, Min, Max
15
+
16
+ from sympy.abc import x, i, j, a, b, c, d
17
+ from sympy.core import Pow
18
+ from sympy.codegen.matrix_nodes import MatrixSolve
19
+ from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
20
+ from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt
21
+ from sympy.tensor.array import Array
22
+ from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
23
+ PermuteDims, ArrayDiagonal
24
+ from sympy.printing.numpy import NumPyPrinter, SciPyPrinter, _numpy_known_constants, \
25
+ _numpy_known_functions, _scipy_known_constants, _scipy_known_functions
26
+ from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
27
+
28
+ from sympy.testing.pytest import skip, raises
29
+ from sympy.external import import_module
30
+
31
+ np = import_module('numpy')
32
+ jax = import_module('jax')
33
+
34
+ if np:
35
+ deafult_float_info = np.finfo(np.array([]).dtype)
36
+ NUMPY_DEFAULT_EPSILON = deafult_float_info.eps
37
+
38
+ def test_numpy_piecewise_regression():
39
+ """
40
+ NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid
41
+ breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+.
42
+ See gh-9747 and gh-9749 for details.
43
+ """
44
+ printer = NumPyPrinter()
45
+ p = Piecewise((1, x < 0), (0, True))
46
+ assert printer.doprint(p) == \
47
+ 'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)'
48
+ assert printer.module_imports == {'numpy': {'select', 'less', 'nan'}}
49
+
50
+ def test_numpy_logaddexp():
51
+ lae = logaddexp(a, b)
52
+ assert NumPyPrinter().doprint(lae) == 'numpy.logaddexp(a, b)'
53
+ lae2 = logaddexp2(a, b)
54
+ assert NumPyPrinter().doprint(lae2) == 'numpy.logaddexp2(a, b)'
55
+
56
+
57
+ def test_sum():
58
+ if not np:
59
+ skip("NumPy not installed")
60
+
61
+ s = Sum(x ** i, (i, a, b))
62
+ f = lambdify((a, b, x), s, 'numpy')
63
+
64
+ a_, b_ = 0, 10
65
+ x_ = np.linspace(-1, +1, 10)
66
+ assert np.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
67
+
68
+ s = Sum(i * x, (i, a, b))
69
+ f = lambdify((a, b, x), s, 'numpy')
70
+
71
+ a_, b_ = 0, 10
72
+ x_ = np.linspace(-1, +1, 10)
73
+ assert np.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
74
+
75
+
76
+ def test_multiple_sums():
77
+ if not np:
78
+ skip("NumPy not installed")
79
+
80
+ s = Sum((x + j) * i, (i, a, b), (j, c, d))
81
+ f = lambdify((a, b, c, d, x), s, 'numpy')
82
+
83
+ a_, b_ = 0, 10
84
+ c_, d_ = 11, 21
85
+ x_ = np.linspace(-1, +1, 10)
86
+ assert np.allclose(f(a_, b_, c_, d_, x_),
87
+ sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1)))
88
+
89
+
90
+ def test_codegen_einsum():
91
+ if not np:
92
+ skip("NumPy not installed")
93
+
94
+ M = MatrixSymbol("M", 2, 2)
95
+ N = MatrixSymbol("N", 2, 2)
96
+
97
+ cg = convert_matrix_to_array(M * N)
98
+ f = lambdify((M, N), cg, 'numpy')
99
+
100
+ ma = np.array([[1, 2], [3, 4]])
101
+ mb = np.array([[1,-2], [-1, 3]])
102
+ assert (f(ma, mb) == np.matmul(ma, mb)).all()
103
+
104
+
105
+ def test_codegen_extra():
106
+ if not np:
107
+ skip("NumPy not installed")
108
+
109
+ M = MatrixSymbol("M", 2, 2)
110
+ N = MatrixSymbol("N", 2, 2)
111
+ P = MatrixSymbol("P", 2, 2)
112
+ Q = MatrixSymbol("Q", 2, 2)
113
+ ma = np.array([[1, 2], [3, 4]])
114
+ mb = np.array([[1,-2], [-1, 3]])
115
+ mc = np.array([[2, 0], [1, 2]])
116
+ md = np.array([[1,-1], [4, 7]])
117
+
118
+ cg = ArrayTensorProduct(M, N)
119
+ f = lambdify((M, N), cg, 'numpy')
120
+ assert (f(ma, mb) == np.einsum(ma, [0, 1], mb, [2, 3])).all()
121
+
122
+ cg = ArrayAdd(M, N)
123
+ f = lambdify((M, N), cg, 'numpy')
124
+ assert (f(ma, mb) == ma+mb).all()
125
+
126
+ cg = ArrayAdd(M, N, P)
127
+ f = lambdify((M, N, P), cg, 'numpy')
128
+ assert (f(ma, mb, mc) == ma+mb+mc).all()
129
+
130
+ cg = ArrayAdd(M, N, P, Q)
131
+ f = lambdify((M, N, P, Q), cg, 'numpy')
132
+ assert (f(ma, mb, mc, md) == ma+mb+mc+md).all()
133
+
134
+ cg = PermuteDims(M, [1, 0])
135
+ f = lambdify((M,), cg, 'numpy')
136
+ assert (f(ma) == ma.T).all()
137
+
138
+ cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
139
+ f = lambdify((M, N), cg, 'numpy')
140
+ assert (f(ma, mb) == np.transpose(np.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all()
141
+
142
+ cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
143
+ f = lambdify((M, N), cg, 'numpy')
144
+ assert (f(ma, mb) == np.diagonal(np.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all()
145
+
146
+
147
+ def test_relational():
148
+ if not np:
149
+ skip("NumPy not installed")
150
+
151
+ e = Equality(x, 1)
152
+
153
+ f = lambdify((x,), e)
154
+ x_ = np.array([0, 1, 2])
155
+ assert np.array_equal(f(x_), [False, True, False])
156
+
157
+ e = Unequality(x, 1)
158
+
159
+ f = lambdify((x,), e)
160
+ x_ = np.array([0, 1, 2])
161
+ assert np.array_equal(f(x_), [True, False, True])
162
+
163
+ e = (x < 1)
164
+
165
+ f = lambdify((x,), e)
166
+ x_ = np.array([0, 1, 2])
167
+ assert np.array_equal(f(x_), [True, False, False])
168
+
169
+ e = (x <= 1)
170
+
171
+ f = lambdify((x,), e)
172
+ x_ = np.array([0, 1, 2])
173
+ assert np.array_equal(f(x_), [True, True, False])
174
+
175
+ e = (x > 1)
176
+
177
+ f = lambdify((x,), e)
178
+ x_ = np.array([0, 1, 2])
179
+ assert np.array_equal(f(x_), [False, False, True])
180
+
181
+ e = (x >= 1)
182
+
183
+ f = lambdify((x,), e)
184
+ x_ = np.array([0, 1, 2])
185
+ assert np.array_equal(f(x_), [False, True, True])
186
+
187
+
188
+ def test_mod():
189
+ if not np:
190
+ skip("NumPy not installed")
191
+
192
+ e = Mod(a, b)
193
+ f = lambdify((a, b), e)
194
+
195
+ a_ = np.array([0, 1, 2, 3])
196
+ b_ = 2
197
+ assert np.array_equal(f(a_, b_), [0, 1, 0, 1])
198
+
199
+ a_ = np.array([0, 1, 2, 3])
200
+ b_ = np.array([2, 2, 2, 2])
201
+ assert np.array_equal(f(a_, b_), [0, 1, 0, 1])
202
+
203
+ a_ = np.array([2, 3, 4, 5])
204
+ b_ = np.array([2, 3, 4, 5])
205
+ assert np.array_equal(f(a_, b_), [0, 0, 0, 0])
206
+
207
+
208
+ def test_pow():
209
+ if not np:
210
+ skip('NumPy not installed')
211
+
212
+ expr = Pow(2, -1, evaluate=False)
213
+ f = lambdify([], expr, 'numpy')
214
+ assert f() == 0.5
215
+
216
+
217
+ def test_expm1():
218
+ if not np:
219
+ skip("NumPy not installed")
220
+
221
+ f = lambdify((a,), expm1(a), 'numpy')
222
+ assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * NUMPY_DEFAULT_EPSILON
223
+
224
+
225
+ def test_log1p():
226
+ if not np:
227
+ skip("NumPy not installed")
228
+
229
+ f = lambdify((a,), log1p(a), 'numpy')
230
+ assert abs(f(1e-99) - 1e-99) <= 1e-99 * NUMPY_DEFAULT_EPSILON
231
+
232
+ def test_hypot():
233
+ if not np:
234
+ skip("NumPy not installed")
235
+ assert abs(lambdify((a, b), hypot(a, b), 'numpy')(3, 4) - 5) <= NUMPY_DEFAULT_EPSILON
236
+
237
+ def test_log10():
238
+ if not np:
239
+ skip("NumPy not installed")
240
+ assert abs(lambdify((a,), log10(a), 'numpy')(100) - 2) <= NUMPY_DEFAULT_EPSILON
241
+
242
+
243
+ def test_exp2():
244
+ if not np:
245
+ skip("NumPy not installed")
246
+ assert abs(lambdify((a,), exp2(a), 'numpy')(5) - 32) <= NUMPY_DEFAULT_EPSILON
247
+
248
+
249
+ def test_log2():
250
+ if not np:
251
+ skip("NumPy not installed")
252
+ assert abs(lambdify((a,), log2(a), 'numpy')(256) - 8) <= NUMPY_DEFAULT_EPSILON
253
+
254
+
255
+ def test_Sqrt():
256
+ if not np:
257
+ skip("NumPy not installed")
258
+ assert abs(lambdify((a,), Sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON
259
+
260
+
261
+ def test_sqrt():
262
+ if not np:
263
+ skip("NumPy not installed")
264
+ assert abs(lambdify((a,), sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON
265
+
266
+
267
+ def test_matsolve():
268
+ if not np:
269
+ skip("NumPy not installed")
270
+
271
+ M = MatrixSymbol("M", 3, 3)
272
+ x = MatrixSymbol("x", 3, 1)
273
+
274
+ expr = M**(-1) * x + x
275
+ matsolve_expr = MatrixSolve(M, x) + x
276
+
277
+ f = lambdify((M, x), expr)
278
+ f_matsolve = lambdify((M, x), matsolve_expr)
279
+
280
+ m0 = np.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]])
281
+ assert np.linalg.matrix_rank(m0) == 3
282
+
283
+ x0 = np.array([3, 4, 5])
284
+
285
+ assert np.allclose(f_matsolve(m0, x0), f(m0, x0))
286
+
287
+
288
+ def test_16857():
289
+ if not np:
290
+ skip("NumPy not installed")
291
+
292
+ a_1 = MatrixSymbol('a_1', 10, 3)
293
+ a_2 = MatrixSymbol('a_2', 10, 3)
294
+ a_3 = MatrixSymbol('a_3', 10, 3)
295
+ a_4 = MatrixSymbol('a_4', 10, 3)
296
+ A = BlockMatrix([[a_1, a_2], [a_3, a_4]])
297
+ assert A.shape == (20, 6)
298
+
299
+ printer = NumPyPrinter()
300
+ assert printer.doprint(A) == 'numpy.block([[a_1, a_2], [a_3, a_4]])'
301
+
302
+
303
+ def test_issue_17006():
304
+ if not np:
305
+ skip("NumPy not installed")
306
+
307
+ M = MatrixSymbol("M", 2, 2)
308
+
309
+ f = lambdify(M, M + Identity(2))
310
+ ma = np.array([[1, 2], [3, 4]])
311
+ mr = np.array([[2, 2], [3, 5]])
312
+
313
+ assert (f(ma) == mr).all()
314
+
315
+ from sympy.core.symbol import symbols
316
+ n = symbols('n', integer=True)
317
+ N = MatrixSymbol("M", n, n)
318
+ raises(NotImplementedError, lambda: lambdify(N, N + Identity(n)))
319
+
320
+ def test_jax_tuple_compatibility():
321
+ if not jax:
322
+ skip("Jax not installed")
323
+
324
+ x, y, z = symbols('x y z')
325
+ expr = Max(x, y, z) + Min(x, y, z)
326
+ func = lambdify((x, y, z), expr, 'jax')
327
+ input_tuple1, input_tuple2 = (1, 2, 3), (4, 5, 6)
328
+ input_array1, input_array2 = jax.numpy.asarray(input_tuple1), jax.numpy.asarray(input_tuple2)
329
+ assert np.allclose(func(*input_tuple1), func(*input_array1))
330
+ assert np.allclose(func(*input_tuple2), func(*input_array2))
331
+
332
+ def test_numpy_array():
333
+ p = NumPyPrinter()
334
+ assert p.doprint(Array([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])'
335
+ assert p.doprint(Array([1, 2])) == 'numpy.array([1, 2])'
336
+ assert p.doprint(Array([[[1, 2, 3]]])) == 'numpy.array([[[1, 2, 3]]])'
337
+ assert p.doprint(Array([], (0,))) == 'numpy.zeros((0,))'
338
+ assert p.doprint(Array([], (0, 0))) == 'numpy.zeros((0, 0))'
339
+ assert p.doprint(Array([], (0, 1))) == 'numpy.zeros((0, 1))'
340
+ assert p.doprint(Array([], (1, 0))) == 'numpy.zeros((1, 0))'
341
+ assert p.doprint(Array([1], ())) == 'numpy.array(1)'
342
+
343
+ def test_numpy_matrix():
344
+ p = NumPyPrinter()
345
+ assert p.doprint(Matrix([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])'
346
+ assert p.doprint(Matrix([1, 2])) == 'numpy.array([[1], [2]])'
347
+ assert p.doprint(Matrix(0, 0, [])) == 'numpy.zeros((0, 0))'
348
+ assert p.doprint(Matrix(0, 1, [])) == 'numpy.zeros((0, 1))'
349
+ assert p.doprint(Matrix(1, 0, [])) == 'numpy.zeros((1, 0))'
350
+
351
+ def test_numpy_known_funcs_consts():
352
+ assert _numpy_known_constants['NaN'] == 'numpy.nan'
353
+ assert _numpy_known_constants['EulerGamma'] == 'numpy.euler_gamma'
354
+
355
+ assert _numpy_known_functions['acos'] == 'numpy.arccos'
356
+ assert _numpy_known_functions['log'] == 'numpy.log'
357
+
358
+ def test_scipy_known_funcs_consts():
359
+ assert _scipy_known_constants['GoldenRatio'] == 'scipy.constants.golden_ratio'
360
+ assert _scipy_known_constants['Pi'] == 'scipy.constants.pi'
361
+
362
+ assert _scipy_known_functions['erf'] == 'scipy.special.erf'
363
+ assert _scipy_known_functions['factorial'] == 'scipy.special.factorial'
364
+
365
+ def test_numpy_print_methods():
366
+ prntr = NumPyPrinter()
367
+ assert hasattr(prntr, '_print_acos')
368
+ assert hasattr(prntr, '_print_log')
369
+
370
+ def test_scipy_print_methods():
371
+ prntr = SciPyPrinter()
372
+ assert hasattr(prntr, '_print_acos')
373
+ assert hasattr(prntr, '_print_log')
374
+ assert hasattr(prntr, '_print_erf')
375
+ assert hasattr(prntr, '_print_factorial')
376
+ assert hasattr(prntr, '_print_chebyshevt')
377
+ k = Symbol('k', integer=True, nonnegative=True)
378
+ x = Symbol('x', real=True)
379
+ assert prntr.doprint(polygamma(k, x)) == "scipy.special.polygamma(k, x)"
380
+ assert prntr.doprint(Si(x)) == "scipy.special.sici(x)[0]"
381
+ assert prntr.doprint(Ci(x)) == "scipy.special.sici(x)[1]"
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_octave.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer,
2
+ Tuple, Symbol, EulerGamma, GoldenRatio, Catalan,
3
+ Lambda, Mul, Pow, Mod, Eq, Ne, Le, Lt, Gt, Ge)
4
+ from sympy.codegen.matrix_nodes import MatrixSolve
5
+ from sympy.functions import (arg, atan2, bernoulli, beta, ceiling, chebyshevu,
6
+ chebyshevt, conjugate, DiracDelta, exp, expint,
7
+ factorial, floor, harmonic, Heaviside, im,
8
+ laguerre, LambertW, log, Max, Min, Piecewise,
9
+ polylog, re, RisingFactorial, sign, sinc, sqrt,
10
+ zeta, binomial, legendre, dirichlet_eta,
11
+ riemann_xi)
12
+ from sympy.functions import (sin, cos, tan, cot, sec, csc, asin, acos, acot,
13
+ atan, asec, acsc, sinh, cosh, tanh, coth, csch,
14
+ sech, asinh, acosh, atanh, acoth, asech, acsch)
15
+ from sympy.testing.pytest import raises, XFAIL
16
+ from sympy.utilities.lambdify import implemented_function
17
+ from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity,
18
+ HadamardProduct, SparseMatrix, HadamardPower)
19
+ from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli,
20
+ besselk, hankel1, hankel2, airyai,
21
+ airybi, airyaiprime, airybiprime)
22
+ from sympy.functions.special.gamma_functions import (gamma, lowergamma,
23
+ uppergamma, loggamma,
24
+ polygamma)
25
+ from sympy.functions.special.error_functions import (Chi, Ci, erf, erfc, erfi,
26
+ erfcinv, erfinv, fresnelc,
27
+ fresnels, li, Shi, Si, Li,
28
+ erf2, Ei)
29
+ from sympy.printing.octave import octave_code, octave_code as mcode
30
+
31
+ x, y, z = symbols('x,y,z')
32
+
33
+
34
+ def test_Integer():
35
+ assert mcode(Integer(67)) == "67"
36
+ assert mcode(Integer(-1)) == "-1"
37
+
38
+
39
+ def test_Rational():
40
+ assert mcode(Rational(3, 7)) == "3/7"
41
+ assert mcode(Rational(18, 9)) == "2"
42
+ assert mcode(Rational(3, -7)) == "-3/7"
43
+ assert mcode(Rational(-3, -7)) == "3/7"
44
+ assert mcode(x + Rational(3, 7)) == "x + 3/7"
45
+ assert mcode(Rational(3, 7)*x) == "3*x/7"
46
+
47
+
48
+ def test_Relational():
49
+ assert mcode(Eq(x, y)) == "x == y"
50
+ assert mcode(Ne(x, y)) == "x != y"
51
+ assert mcode(Le(x, y)) == "x <= y"
52
+ assert mcode(Lt(x, y)) == "x < y"
53
+ assert mcode(Gt(x, y)) == "x > y"
54
+ assert mcode(Ge(x, y)) == "x >= y"
55
+
56
+
57
+ def test_Function():
58
+ assert mcode(sin(x) ** cos(x)) == "sin(x).^cos(x)"
59
+ assert mcode(sign(x)) == "sign(x)"
60
+ assert mcode(exp(x)) == "exp(x)"
61
+ assert mcode(log(x)) == "log(x)"
62
+ assert mcode(factorial(x)) == "factorial(x)"
63
+ assert mcode(floor(x)) == "floor(x)"
64
+ assert mcode(atan2(y, x)) == "atan2(y, x)"
65
+ assert mcode(beta(x, y)) == 'beta(x, y)'
66
+ assert mcode(polylog(x, y)) == 'polylog(x, y)'
67
+ assert mcode(harmonic(x)) == 'harmonic(x)'
68
+ assert mcode(bernoulli(x)) == "bernoulli(x)"
69
+ assert mcode(bernoulli(x, y)) == "bernoulli(x, y)"
70
+ assert mcode(legendre(x, y)) == "legendre(x, y)"
71
+
72
+
73
+ def test_Function_change_name():
74
+ assert mcode(abs(x)) == "abs(x)"
75
+ assert mcode(ceiling(x)) == "ceil(x)"
76
+ assert mcode(arg(x)) == "angle(x)"
77
+ assert mcode(im(x)) == "imag(x)"
78
+ assert mcode(re(x)) == "real(x)"
79
+ assert mcode(conjugate(x)) == "conj(x)"
80
+ assert mcode(chebyshevt(y, x)) == "chebyshevT(y, x)"
81
+ assert mcode(chebyshevu(y, x)) == "chebyshevU(y, x)"
82
+ assert mcode(laguerre(x, y)) == "laguerreL(x, y)"
83
+ assert mcode(Chi(x)) == "coshint(x)"
84
+ assert mcode(Shi(x)) == "sinhint(x)"
85
+ assert mcode(Ci(x)) == "cosint(x)"
86
+ assert mcode(Si(x)) == "sinint(x)"
87
+ assert mcode(li(x)) == "logint(x)"
88
+ assert mcode(loggamma(x)) == "gammaln(x)"
89
+ assert mcode(polygamma(x, y)) == "psi(x, y)"
90
+ assert mcode(RisingFactorial(x, y)) == "pochhammer(x, y)"
91
+ assert mcode(DiracDelta(x)) == "dirac(x)"
92
+ assert mcode(DiracDelta(x, 3)) == "dirac(3, x)"
93
+ assert mcode(Heaviside(x)) == "heaviside(x, 1/2)"
94
+ assert mcode(Heaviside(x, y)) == "heaviside(x, y)"
95
+ assert mcode(binomial(x, y)) == "bincoeff(x, y)"
96
+ assert mcode(Mod(x, y)) == "mod(x, y)"
97
+
98
+
99
+ def test_minmax():
100
+ assert mcode(Max(x, y) + Min(x, y)) == "max(x, y) + min(x, y)"
101
+ assert mcode(Max(x, y, z)) == "max(x, max(y, z))"
102
+ assert mcode(Min(x, y, z)) == "min(x, min(y, z))"
103
+
104
+
105
+ def test_Pow():
106
+ assert mcode(x**3) == "x.^3"
107
+ assert mcode(x**(y**3)) == "x.^(y.^3)"
108
+ assert mcode(x**Rational(2, 3)) == 'x.^(2/3)'
109
+ g = implemented_function('g', Lambda(x, 2*x))
110
+ assert mcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
111
+ "(3.5*2*x).^(-x + y.^x)./(x.^2 + y)"
112
+ # For issue 14160
113
+ assert mcode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
114
+ evaluate=False)) == '-2*x./(y.*y)'
115
+
116
+
117
+ def test_basic_ops():
118
+ assert mcode(x*y) == "x.*y"
119
+ assert mcode(x + y) == "x + y"
120
+ assert mcode(x - y) == "x - y"
121
+ assert mcode(-x) == "-x"
122
+
123
+
124
+ def test_1_over_x_and_sqrt():
125
+ # 1.0 and 0.5 would do something different in regular StrPrinter,
126
+ # but these are exact in IEEE floating point so no different here.
127
+ assert mcode(1/x) == '1./x'
128
+ assert mcode(x**-1) == mcode(x**-1.0) == '1./x'
129
+ assert mcode(1/sqrt(x)) == '1./sqrt(x)'
130
+ assert mcode(x**-S.Half) == mcode(x**-0.5) == '1./sqrt(x)'
131
+ assert mcode(sqrt(x)) == 'sqrt(x)'
132
+ assert mcode(x**S.Half) == mcode(x**0.5) == 'sqrt(x)'
133
+ assert mcode(1/pi) == '1/pi'
134
+ assert mcode(pi**-1) == mcode(pi**-1.0) == '1/pi'
135
+ assert mcode(pi**-0.5) == '1/sqrt(pi)'
136
+
137
+
138
+ def test_mix_number_mult_symbols():
139
+ assert mcode(3*x) == "3*x"
140
+ assert mcode(pi*x) == "pi*x"
141
+ assert mcode(3/x) == "3./x"
142
+ assert mcode(pi/x) == "pi./x"
143
+ assert mcode(x/3) == "x/3"
144
+ assert mcode(x/pi) == "x/pi"
145
+ assert mcode(x*y) == "x.*y"
146
+ assert mcode(3*x*y) == "3*x.*y"
147
+ assert mcode(3*pi*x*y) == "3*pi*x.*y"
148
+ assert mcode(x/y) == "x./y"
149
+ assert mcode(3*x/y) == "3*x./y"
150
+ assert mcode(x*y/z) == "x.*y./z"
151
+ assert mcode(x/y*z) == "x.*z./y"
152
+ assert mcode(1/x/y) == "1./(x.*y)"
153
+ assert mcode(2*pi*x/y/z) == "2*pi*x./(y.*z)"
154
+ assert mcode(3*pi/x) == "3*pi./x"
155
+ assert mcode(S(3)/5) == "3/5"
156
+ assert mcode(S(3)/5*x) == "3*x/5"
157
+ assert mcode(x/y/z) == "x./(y.*z)"
158
+ assert mcode((x+y)/z) == "(x + y)./z"
159
+ assert mcode((x+y)/(z+x)) == "(x + y)./(x + z)"
160
+ assert mcode((x+y)/EulerGamma) == "(x + y)/%s" % EulerGamma.evalf(17)
161
+ assert mcode(x/3/pi) == "x/(3*pi)"
162
+ assert mcode(S(3)/5*x*y/pi) == "3*x.*y/(5*pi)"
163
+
164
+
165
+ def test_mix_number_pow_symbols():
166
+ assert mcode(pi**3) == 'pi^3'
167
+ assert mcode(x**2) == 'x.^2'
168
+ assert mcode(x**(pi**3)) == 'x.^(pi^3)'
169
+ assert mcode(x**y) == 'x.^y'
170
+ assert mcode(x**(y**z)) == 'x.^(y.^z)'
171
+ assert mcode((x**y)**z) == '(x.^y).^z'
172
+
173
+
174
+ def test_imag():
175
+ I = S('I')
176
+ assert mcode(I) == "1i"
177
+ assert mcode(5*I) == "5i"
178
+ assert mcode((S(3)/2)*I) == "3*1i/2"
179
+ assert mcode(3+4*I) == "3 + 4i"
180
+ assert mcode(sqrt(3)*I) == "sqrt(3)*1i"
181
+
182
+
183
+ def test_constants():
184
+ assert mcode(pi) == "pi"
185
+ assert mcode(oo) == "inf"
186
+ assert mcode(-oo) == "-inf"
187
+ assert mcode(S.NegativeInfinity) == "-inf"
188
+ assert mcode(S.NaN) == "NaN"
189
+ assert mcode(S.Exp1) == "exp(1)"
190
+ assert mcode(exp(1)) == "exp(1)"
191
+
192
+
193
+ def test_constants_other():
194
+ assert mcode(2*GoldenRatio) == "2*(1+sqrt(5))/2"
195
+ assert mcode(2*Catalan) == "2*%s" % Catalan.evalf(17)
196
+ assert mcode(2*EulerGamma) == "2*%s" % EulerGamma.evalf(17)
197
+
198
+
199
+ def test_boolean():
200
+ assert mcode(x & y) == "x & y"
201
+ assert mcode(x | y) == "x | y"
202
+ assert mcode(~x) == "~x"
203
+ assert mcode(x & y & z) == "x & y & z"
204
+ assert mcode(x | y | z) == "x | y | z"
205
+ assert mcode((x & y) | z) == "z | x & y"
206
+ assert mcode((x | y) & z) == "z & (x | y)"
207
+
208
+
209
+ def test_KroneckerDelta():
210
+ from sympy.functions import KroneckerDelta
211
+ assert mcode(KroneckerDelta(x, y)) == "double(x == y)"
212
+ assert mcode(KroneckerDelta(x, y + 1)) == "double(x == (y + 1))"
213
+ assert mcode(KroneckerDelta(2**x, y)) == "double((2.^x) == y)"
214
+
215
+
216
+ def test_Matrices():
217
+ assert mcode(Matrix(1, 1, [10])) == "10"
218
+ A = Matrix([[1, sin(x/2), abs(x)],
219
+ [0, 1, pi],
220
+ [0, exp(1), ceiling(x)]])
221
+ expected = "[1 sin(x/2) abs(x); 0 1 pi; 0 exp(1) ceil(x)]"
222
+ assert mcode(A) == expected
223
+ # row and columns
224
+ assert mcode(A[:,0]) == "[1; 0; 0]"
225
+ assert mcode(A[0,:]) == "[1 sin(x/2) abs(x)]"
226
+ # empty matrices
227
+ assert mcode(Matrix(0, 0, [])) == '[]'
228
+ assert mcode(Matrix(0, 3, [])) == 'zeros(0, 3)'
229
+ # annoying to read but correct
230
+ assert mcode(Matrix([[x, x - y, -y]])) == "[x x - y -y]"
231
+
232
+
233
+ def test_vector_entries_hadamard():
234
+ # For a row or column, user might to use the other dimension
235
+ A = Matrix([[1, sin(2/x), 3*pi/x/5]])
236
+ assert mcode(A) == "[1 sin(2./x) 3*pi./(5*x)]"
237
+ assert mcode(A.T) == "[1; sin(2./x); 3*pi./(5*x)]"
238
+
239
+
240
+ @XFAIL
241
+ def test_Matrices_entries_not_hadamard():
242
+ # For Matrix with col >= 2, row >= 2, they need to be scalars
243
+ # FIXME: is it worth worrying about this? Its not wrong, just
244
+ # leave it user's responsibility to put scalar data for x.
245
+ A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]])
246
+ expected = ("[1 sin(2/x) 3*pi/(5*x);\n"
247
+ "1 2 x*y]") # <- we give x.*y
248
+ assert mcode(A) == expected
249
+
250
+
251
+ def test_MatrixSymbol():
252
+ n = Symbol('n', integer=True)
253
+ A = MatrixSymbol('A', n, n)
254
+ B = MatrixSymbol('B', n, n)
255
+ assert mcode(A*B) == "A*B"
256
+ assert mcode(B*A) == "B*A"
257
+ assert mcode(2*A*B) == "2*A*B"
258
+ assert mcode(B*2*A) == "2*B*A"
259
+ assert mcode(A*(B + 3*Identity(n))) == "A*(3*eye(n) + B)"
260
+ assert mcode(A**(x**2)) == "A^(x.^2)"
261
+ assert mcode(A**3) == "A^3"
262
+ assert mcode(A**S.Half) == "A^(1/2)"
263
+
264
+
265
+ def test_MatrixSolve():
266
+ n = Symbol('n', integer=True)
267
+ A = MatrixSymbol('A', n, n)
268
+ x = MatrixSymbol('x', n, 1)
269
+ assert mcode(MatrixSolve(A, x)) == "A \\ x"
270
+
271
+ def test_special_matrices():
272
+ assert mcode(6*Identity(3)) == "6*eye(3)"
273
+
274
+
275
+ def test_containers():
276
+ assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
277
+ "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}"
278
+ assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}"
279
+ assert mcode([1]) == "{1}"
280
+ assert mcode((1,)) == "{1}"
281
+ assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}"
282
+ assert mcode((1, x*y, (3, x**2))) == "{1, x.*y, {3, x.^2}}"
283
+ # scalar, matrix, empty matrix and empty list
284
+ assert mcode((1, eye(3), Matrix(0, 0, []), [])) == "{1, [1 0 0; 0 1 0; 0 0 1], [], {}}"
285
+
286
+
287
+ def test_octave_noninline():
288
+ source = mcode((x+y)/Catalan, assign_to='me', inline=False)
289
+ expected = (
290
+ "Catalan = %s;\n"
291
+ "me = (x + y)/Catalan;"
292
+ ) % Catalan.evalf(17)
293
+ assert source == expected
294
+
295
+
296
+ def test_octave_piecewise():
297
+ expr = Piecewise((x, x < 1), (x**2, True))
298
+ assert mcode(expr) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))"
299
+ assert mcode(expr, assign_to="r") == (
300
+ "r = ((x < 1).*(x) + (~(x < 1)).*(x.^2));")
301
+ assert mcode(expr, assign_to="r", inline=False) == (
302
+ "if (x < 1)\n"
303
+ " r = x;\n"
304
+ "else\n"
305
+ " r = x.^2;\n"
306
+ "end")
307
+ expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True))
308
+ expected = ("((x < 1).*(x.^2) + (~(x < 1)).*( ...\n"
309
+ "(x < 2).*(x.^3) + (~(x < 2)).*( ...\n"
310
+ "(x < 3).*(x.^4) + (~(x < 3)).*(x.^5))))")
311
+ assert mcode(expr) == expected
312
+ assert mcode(expr, assign_to="r") == "r = " + expected + ";"
313
+ assert mcode(expr, assign_to="r", inline=False) == (
314
+ "if (x < 1)\n"
315
+ " r = x.^2;\n"
316
+ "elseif (x < 2)\n"
317
+ " r = x.^3;\n"
318
+ "elseif (x < 3)\n"
319
+ " r = x.^4;\n"
320
+ "else\n"
321
+ " r = x.^5;\n"
322
+ "end")
323
+ # Check that Piecewise without a True (default) condition error
324
+ expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
325
+ raises(ValueError, lambda: mcode(expr))
326
+
327
+
328
+ def test_octave_piecewise_times_const():
329
+ pw = Piecewise((x, x < 1), (x**2, True))
330
+ assert mcode(2*pw) == "2*((x < 1).*(x) + (~(x < 1)).*(x.^2))"
331
+ assert mcode(pw/x) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./x"
332
+ assert mcode(pw/(x*y)) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./(x.*y)"
333
+ assert mcode(pw/3) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))/3"
334
+
335
+
336
+ def test_octave_matrix_assign_to():
337
+ A = Matrix([[1, 2, 3]])
338
+ assert mcode(A, assign_to='a') == "a = [1 2 3];"
339
+ A = Matrix([[1, 2], [3, 4]])
340
+ assert mcode(A, assign_to='A') == "A = [1 2; 3 4];"
341
+
342
+
343
+ def test_octave_matrix_assign_to_more():
344
+ # assigning to Symbol or MatrixSymbol requires lhs/rhs match
345
+ A = Matrix([[1, 2, 3]])
346
+ B = MatrixSymbol('B', 1, 3)
347
+ C = MatrixSymbol('C', 2, 3)
348
+ assert mcode(A, assign_to=B) == "B = [1 2 3];"
349
+ raises(ValueError, lambda: mcode(A, assign_to=x))
350
+ raises(ValueError, lambda: mcode(A, assign_to=C))
351
+
352
+
353
+ def test_octave_matrix_1x1():
354
+ A = Matrix([[3]])
355
+ B = MatrixSymbol('B', 1, 1)
356
+ C = MatrixSymbol('C', 1, 2)
357
+ assert mcode(A, assign_to=B) == "B = 3;"
358
+ # FIXME?
359
+ #assert mcode(A, assign_to=x) == "x = 3;"
360
+ raises(ValueError, lambda: mcode(A, assign_to=C))
361
+
362
+
363
+ def test_octave_matrix_elements():
364
+ A = Matrix([[x, 2, x*y]])
365
+ assert mcode(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x.^2 + x.*y + 2"
366
+ A = MatrixSymbol('AA', 1, 3)
367
+ assert mcode(A) == "AA"
368
+ assert mcode(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \
369
+ "sin(AA(1, 2)) + AA(1, 1).^2 + AA(1, 3)"
370
+ assert mcode(sum(A)) == "AA(1, 1) + AA(1, 2) + AA(1, 3)"
371
+
372
+
373
+ def test_octave_boolean():
374
+ assert mcode(True) == "true"
375
+ assert mcode(S.true) == "true"
376
+ assert mcode(False) == "false"
377
+ assert mcode(S.false) == "false"
378
+
379
+
380
+ def test_octave_not_supported():
381
+ with raises(NotImplementedError):
382
+ mcode(S.ComplexInfinity)
383
+ f = Function('f')
384
+ assert mcode(f(x).diff(x), strict=False) == (
385
+ "% Not supported in Octave:\n"
386
+ "% Derivative\n"
387
+ "Derivative(f(x), x)"
388
+ )
389
+
390
+
391
+ def test_octave_not_supported_not_on_whitelist():
392
+ from sympy.functions.special.polynomials import assoc_laguerre
393
+ with raises(NotImplementedError):
394
+ mcode(assoc_laguerre(x, y, z))
395
+
396
+
397
+ def test_octave_expint():
398
+ assert mcode(expint(1, x)) == "expint(x)"
399
+ with raises(NotImplementedError):
400
+ mcode(expint(2, x))
401
+ assert mcode(expint(y, x), strict=False) == (
402
+ "% Not supported in Octave:\n"
403
+ "% expint\n"
404
+ "expint(y, x)"
405
+ )
406
+
407
+
408
+ def test_trick_indent_with_end_else_words():
409
+ # words starting with "end" or "else" do not confuse the indenter
410
+ t1 = S('endless')
411
+ t2 = S('elsewhere')
412
+ pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True))
413
+ assert mcode(pw, inline=False) == (
414
+ "if (x < 0)\n"
415
+ " endless\n"
416
+ "elseif (x <= 1)\n"
417
+ " elsewhere\n"
418
+ "else\n"
419
+ " 1\n"
420
+ "end")
421
+
422
+
423
+ def test_hadamard():
424
+ A = MatrixSymbol('A', 3, 3)
425
+ B = MatrixSymbol('B', 3, 3)
426
+ v = MatrixSymbol('v', 3, 1)
427
+ h = MatrixSymbol('h', 1, 3)
428
+ C = HadamardProduct(A, B)
429
+ n = Symbol('n')
430
+ assert mcode(C) == "A.*B"
431
+ assert mcode(C*v) == "(A.*B)*v"
432
+ assert mcode(h*C*v) == "h*(A.*B)*v"
433
+ assert mcode(C*A) == "(A.*B)*A"
434
+ # mixing Hadamard and scalar strange b/c we vectorize scalars
435
+ assert mcode(C*x*y) == "(x.*y)*(A.*B)"
436
+
437
+ # Testing HadamardPower:
438
+ assert mcode(HadamardPower(A, n)) == "A.**n"
439
+ assert mcode(HadamardPower(A, 1+n)) == "A.**(n + 1)"
440
+ assert mcode(HadamardPower(A*B.T, 1+n)) == "(A*B.T).**(n + 1)"
441
+
442
+
443
+ def test_sparse():
444
+ M = SparseMatrix(5, 6, {})
445
+ M[2, 2] = 10
446
+ M[1, 2] = 20
447
+ M[1, 3] = 22
448
+ M[0, 3] = 30
449
+ M[3, 0] = x*y
450
+ assert mcode(M) == (
451
+ "sparse([4 2 3 1 2], [1 3 3 4 4], [x.*y 20 10 30 22], 5, 6)"
452
+ )
453
+
454
+
455
+ def test_sinc():
456
+ assert mcode(sinc(x)) == 'sinc(x/pi)'
457
+ assert mcode(sinc(x + 3)) == 'sinc((x + 3)/pi)'
458
+ assert mcode(sinc(pi*(x + 3))) == 'sinc(x + 3)'
459
+
460
+
461
+ def test_trigfun():
462
+ for f in (sin, cos, tan, cot, sec, csc, asin, acos, acot, atan, asec, acsc,
463
+ sinh, cosh, tanh, coth, csch, sech, asinh, acosh, atanh, acoth,
464
+ asech, acsch):
465
+ assert octave_code(f(x) == f.__name__ + '(x)')
466
+
467
+
468
+ def test_specfun():
469
+ n = Symbol('n')
470
+ for f in [besselj, bessely, besseli, besselk]:
471
+ assert octave_code(f(n, x)) == f.__name__ + '(n, x)'
472
+ for f in (erfc, erfi, erf, erfinv, erfcinv, fresnelc, fresnels, gamma):
473
+ assert octave_code(f(x)) == f.__name__ + '(x)'
474
+ assert octave_code(hankel1(n, x)) == 'besselh(n, 1, x)'
475
+ assert octave_code(hankel2(n, x)) == 'besselh(n, 2, x)'
476
+ assert octave_code(airyai(x)) == 'airy(0, x)'
477
+ assert octave_code(airyaiprime(x)) == 'airy(1, x)'
478
+ assert octave_code(airybi(x)) == 'airy(2, x)'
479
+ assert octave_code(airybiprime(x)) == 'airy(3, x)'
480
+ assert octave_code(uppergamma(n, x)) == '(gammainc(x, n, \'upper\').*gamma(n))'
481
+ assert octave_code(lowergamma(n, x)) == '(gammainc(x, n).*gamma(n))'
482
+ assert octave_code(z**lowergamma(n, x)) == 'z.^(gammainc(x, n).*gamma(n))'
483
+ assert octave_code(jn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*besselj(n + 1/2, x)/2'
484
+ assert octave_code(yn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*bessely(n + 1/2, x)/2'
485
+ assert octave_code(LambertW(x)) == 'lambertw(x)'
486
+ assert octave_code(LambertW(x, n)) == 'lambertw(n, x)'
487
+
488
+ # Automatic rewrite
489
+ assert octave_code(Ei(x)) == '(logint(exp(x)))'
490
+ assert octave_code(dirichlet_eta(x)) == '(((x == 1).*(log(2)) + (~(x == 1)).*((1 - 2.^(1 - x)).*zeta(x))))'
491
+ assert octave_code(riemann_xi(x)) == '(pi.^(-x/2).*x.*(x - 1).*gamma(x/2).*zeta(x)/2)'
492
+
493
+
494
+ def test_MatrixElement_printing():
495
+ # test cases for issue #11821
496
+ A = MatrixSymbol("A", 1, 3)
497
+ B = MatrixSymbol("B", 1, 3)
498
+ C = MatrixSymbol("C", 1, 3)
499
+
500
+ assert mcode(A[0, 0]) == "A(1, 1)"
501
+ assert mcode(3 * A[0, 0]) == "3*A(1, 1)"
502
+
503
+ F = C[0, 0].subs(C, A - B)
504
+ assert mcode(F) == "(A - B)(1, 1)"
505
+
506
+
507
+ def test_zeta_printing_issue_14820():
508
+ assert octave_code(zeta(x)) == 'zeta(x)'
509
+ with raises(NotImplementedError):
510
+ octave_code(zeta(x, y))
511
+
512
+
513
+ def test_automatic_rewrite():
514
+ assert octave_code(Li(x)) == '(logint(x) - logint(2))'
515
+ assert octave_code(erf2(x, y)) == '(-erf(x) + erf(y))'
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_precedence.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.concrete.products import Product
2
+ from sympy.concrete.summations import Sum
3
+ from sympy.core.function import Derivative, Function
4
+ from sympy.core.numbers import Integer, Rational, Float, oo
5
+ from sympy.core.relational import Rel
6
+ from sympy.core.symbol import symbols
7
+ from sympy.functions import sin
8
+ from sympy.integrals.integrals import Integral
9
+ from sympy.series.order import Order
10
+
11
+ from sympy.printing.precedence import precedence, PRECEDENCE
12
+
13
+ x, y = symbols("x,y")
14
+
15
+
16
+ def test_Add():
17
+ assert precedence(x + y) == PRECEDENCE["Add"]
18
+ assert precedence(x*y + 1) == PRECEDENCE["Add"]
19
+
20
+
21
+ def test_Function():
22
+ assert precedence(sin(x)) == PRECEDENCE["Func"]
23
+
24
+ def test_Derivative():
25
+ assert precedence(Derivative(x, y)) == PRECEDENCE["Atom"]
26
+
27
+ def test_Integral():
28
+ assert precedence(Integral(x, y)) == PRECEDENCE["Atom"]
29
+
30
+
31
+ def test_Mul():
32
+ assert precedence(x*y) == PRECEDENCE["Mul"]
33
+ assert precedence(-x*y) == PRECEDENCE["Add"]
34
+
35
+
36
+ def test_Number():
37
+ assert precedence(Integer(0)) == PRECEDENCE["Atom"]
38
+ assert precedence(Integer(1)) == PRECEDENCE["Atom"]
39
+ assert precedence(Integer(-1)) == PRECEDENCE["Add"]
40
+ assert precedence(Integer(10)) == PRECEDENCE["Atom"]
41
+ assert precedence(Rational(5, 2)) == PRECEDENCE["Mul"]
42
+ assert precedence(Rational(-5, 2)) == PRECEDENCE["Add"]
43
+ assert precedence(Float(5)) == PRECEDENCE["Atom"]
44
+ assert precedence(Float(-5)) == PRECEDENCE["Add"]
45
+ assert precedence(oo) == PRECEDENCE["Atom"]
46
+ assert precedence(-oo) == PRECEDENCE["Add"]
47
+
48
+
49
+ def test_Order():
50
+ assert precedence(Order(x)) == PRECEDENCE["Atom"]
51
+
52
+
53
+ def test_Pow():
54
+ assert precedence(x**y) == PRECEDENCE["Pow"]
55
+ assert precedence(-x**y) == PRECEDENCE["Add"]
56
+ assert precedence(x**-y) == PRECEDENCE["Pow"]
57
+
58
+
59
+ def test_Product():
60
+ assert precedence(Product(x, (x, y, y + 1))) == PRECEDENCE["Atom"]
61
+
62
+
63
+ def test_Relational():
64
+ assert precedence(Rel(x + y, y, "<")) == PRECEDENCE["Relational"]
65
+
66
+
67
+ def test_Sum():
68
+ assert precedence(Sum(x, (x, y, y + 1))) == PRECEDENCE["Atom"]
69
+
70
+
71
+ def test_Symbol():
72
+ assert precedence(x) == PRECEDENCE["Atom"]
73
+
74
+
75
+ def test_And_Or():
76
+ # precedence relations between logical operators, ...
77
+ assert precedence(x & y) > precedence(x | y)
78
+ assert precedence(~y) > precedence(x & y)
79
+ # ... and with other operators (cfr. other programming languages)
80
+ assert precedence(x + y) > precedence(x | y)
81
+ assert precedence(x + y) > precedence(x & y)
82
+ assert precedence(x*y) > precedence(x | y)
83
+ assert precedence(x*y) > precedence(x & y)
84
+ assert precedence(~y) > precedence(x*y)
85
+ assert precedence(~y) > precedence(x - y)
86
+ # double checks
87
+ assert precedence(x & y) == PRECEDENCE["And"]
88
+ assert precedence(x | y) == PRECEDENCE["Or"]
89
+ assert precedence(~y) == PRECEDENCE["Not"]
90
+
91
+
92
+ def test_custom_function_precedence_comparison():
93
+ """
94
+ Test cases for custom functions with different precedence values,
95
+ specifically handling:
96
+ 1. Functions with precedence < PRECEDENCE["Mul"] (50)
97
+ 2. Functions with precedence = Func (70)
98
+
99
+ Key distinction:
100
+ 1. Lower precedence functions (45) need parentheses: -2*(x F y)
101
+ 2. Higher precedence functions (70) don't: -2*x F y
102
+ """
103
+ class LowPrecedenceF(Function):
104
+ precedence = PRECEDENCE["Mul"] - 5
105
+ def _sympystr(self, printer):
106
+ return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}"
107
+
108
+ class HighPrecedenceF(Function):
109
+ precedence = PRECEDENCE["Func"]
110
+ def _sympystr(self, printer):
111
+ return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}"
112
+
113
+ def test_low_precedence():
114
+ expr1 = 2 * LowPrecedenceF(x, y)
115
+ assert str(expr1) == "2*(x F y)"
116
+
117
+ expr2 = -2 * LowPrecedenceF(x, y)
118
+ assert str(expr2) == "-2*(x F y)"
119
+
120
+ def test_high_precedence():
121
+ expr1 = 2 * HighPrecedenceF(x, y)
122
+ assert str(expr1) == "2*x F y"
123
+
124
+ expr2 = -2 * HighPrecedenceF(x, y)
125
+ assert str(expr2) == "-2*x F y"
126
+
127
+ test_low_precedence()
128
+ test_high_precedence()
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_preview.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from sympy.core.relational import Eq
4
+ from sympy.core.symbol import Symbol
5
+ from sympy.functions.elementary.piecewise import Piecewise
6
+ from sympy.printing.preview import preview
7
+
8
+ from io import BytesIO
9
+
10
+
11
+ def test_preview():
12
+ x = Symbol('x')
13
+ obj = BytesIO()
14
+ try:
15
+ preview(x, output='png', viewer='BytesIO', outputbuffer=obj)
16
+ except RuntimeError:
17
+ pass # latex not installed on CI server
18
+
19
+
20
+ def test_preview_unicode_symbol():
21
+ # issue 9107
22
+ a = Symbol('α')
23
+ obj = BytesIO()
24
+ try:
25
+ preview(a, output='png', viewer='BytesIO', outputbuffer=obj)
26
+ except RuntimeError:
27
+ pass # latex not installed on CI server
28
+
29
+
30
+ def test_preview_latex_construct_in_expr():
31
+ # see PR 9801
32
+ x = Symbol('x')
33
+ pw = Piecewise((1, Eq(x, 0)), (0, True))
34
+ obj = BytesIO()
35
+ try:
36
+ preview(pw, output='png', viewer='BytesIO', outputbuffer=obj)
37
+ except RuntimeError:
38
+ pass # latex not installed on CI server
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_pycode.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy import Not
2
+ from sympy.codegen import Assignment
3
+ from sympy.codegen.ast import none
4
+ from sympy.codegen.cfunctions import expm1, log1p
5
+ from sympy.codegen.scipy_nodes import cosm1
6
+ from sympy.codegen.matrix_nodes import MatrixSolve
7
+ from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow
8
+ from sympy.core.function import Derivative
9
+ from sympy.core.numbers import pi
10
+ from sympy.core.singleton import S
11
+ from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth, sec, log, sin, cos, tan, asin, atan, sinh, cosh, tanh, asinh, acosh, atanh
12
+ from sympy.functions.elementary.trigonometric import atan2
13
+ from sympy.logic import And, Or
14
+ from sympy.matrices import SparseMatrix, MatrixSymbol, Identity
15
+ from sympy.printing.codeprinter import PrintMethodNotImplementedError
16
+ from sympy.printing.pycode import (
17
+ MpmathPrinter, CmathPrinter, PythonCodePrinter, pycode, SymPyPrinter
18
+ )
19
+ from sympy.printing.tensorflow import TensorflowPrinter
20
+ from sympy.printing.numpy import NumPyPrinter, SciPyPrinter
21
+ from sympy.testing.pytest import raises, skip
22
+ from sympy.tensor import IndexedBase, Idx
23
+ from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray
24
+ from sympy.external import import_module
25
+ from sympy.functions.special.gamma_functions import loggamma
26
+
27
+
28
+
29
+ x, y, z = symbols('x y z')
30
+ p = IndexedBase("p")
31
+
32
+
33
+ def test_PythonCodePrinter():
34
+ prntr = PythonCodePrinter()
35
+
36
+ assert not prntr.module_imports
37
+
38
+ assert prntr.doprint(x**y) == 'x**y'
39
+ assert prntr.doprint(Mod(x, 2)) == 'x % 2'
40
+ assert prntr.doprint(-Mod(x, y)) == '-(x % y)'
41
+ assert prntr.doprint(Mod(-x, y)) == '(-x) % y'
42
+ assert prntr.doprint(And(x, y)) == 'x and y'
43
+ assert prntr.doprint(Or(x, y)) == 'x or y'
44
+ assert prntr.doprint(1/(x+y)) == '1/(x + y)'
45
+ assert prntr.doprint(Not(x)) == 'not x'
46
+ assert not prntr.module_imports
47
+
48
+ assert prntr.doprint(pi) == 'math.pi'
49
+ assert prntr.module_imports == {'math': {'pi'}}
50
+
51
+ assert prntr.doprint(x**Rational(1, 2)) == 'math.sqrt(x)'
52
+ assert prntr.doprint(sqrt(x)) == 'math.sqrt(x)'
53
+ assert prntr.module_imports == {'math': {'pi', 'sqrt'}}
54
+
55
+ assert prntr.doprint(acos(x)) == 'math.acos(x)'
56
+ assert prntr.doprint(cot(x)) == '(1/math.tan(x))'
57
+ assert prntr.doprint(coth(x)) == '((math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x)))'
58
+ assert prntr.doprint(asec(x)) == '(math.acos(1/x))'
59
+ assert prntr.doprint(acsch(x)) == '(math.log(math.sqrt(1 + x**(-2)) + 1/x))'
60
+
61
+ assert prntr.doprint(Assignment(x, 2)) == 'x = 2'
62
+ assert prntr.doprint(Piecewise((1, Eq(x, 0)),
63
+ (2, x>6))) == '((1) if (x == 0) else (2) if (x > 6) else None)'
64
+ assert prntr.doprint(Piecewise((2, Le(x, 0)),
65
+ (3, Gt(x, 0)), evaluate=False)) == '((2) if (x <= 0) else'\
66
+ ' (3) if (x > 0) else None)'
67
+ assert prntr.doprint(sign(x)) == '(0.0 if x == 0 else math.copysign(1, x))'
68
+ assert prntr.doprint(p[0, 1]) == 'p[0, 1]'
69
+ assert prntr.doprint(KroneckerDelta(x,y)) == '(1 if x == y else 0)'
70
+
71
+ assert prntr.doprint((2,3)) == "(2, 3)"
72
+ assert prntr.doprint([2,3]) == "[2, 3]"
73
+
74
+ assert prntr.doprint(Min(x, y)) == "min(x, y)"
75
+ assert prntr.doprint(Max(x, y)) == "max(x, y)"
76
+
77
+
78
+ def test_PythonCodePrinter_standard():
79
+ prntr = PythonCodePrinter()
80
+
81
+ assert prntr.standard == 'python3'
82
+
83
+ raises(ValueError, lambda: PythonCodePrinter({'standard':'python4'}))
84
+
85
+
86
+ def test_CmathPrinter():
87
+ p = CmathPrinter()
88
+
89
+ assert p.doprint(sqrt(x)) == 'cmath.sqrt(x)'
90
+ assert p.doprint(log(x)) == 'cmath.log(x)'
91
+
92
+ assert p.doprint(sin(x)) == 'cmath.sin(x)'
93
+ assert p.doprint(cos(x)) == 'cmath.cos(x)'
94
+ assert p.doprint(tan(x)) == 'cmath.tan(x)'
95
+
96
+ assert p.doprint(asin(x)) == 'cmath.asin(x)'
97
+ assert p.doprint(acos(x)) == 'cmath.acos(x)'
98
+ assert p.doprint(atan(x)) == 'cmath.atan(x)'
99
+
100
+ assert p.doprint(sinh(x)) == 'cmath.sinh(x)'
101
+ assert p.doprint(cosh(x)) == 'cmath.cosh(x)'
102
+ assert p.doprint(tanh(x)) == 'cmath.tanh(x)'
103
+
104
+ assert p.doprint(asinh(x)) == 'cmath.asinh(x)'
105
+ assert p.doprint(acosh(x)) == 'cmath.acosh(x)'
106
+ assert p.doprint(atanh(x)) == 'cmath.atanh(x)'
107
+
108
+
109
+ def test_MpmathPrinter():
110
+ p = MpmathPrinter()
111
+ assert p.doprint(sign(x)) == 'mpmath.sign(x)'
112
+ assert p.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)'
113
+
114
+ assert p.doprint(S.Exp1) == 'mpmath.e'
115
+ assert p.doprint(S.Pi) == 'mpmath.pi'
116
+ assert p.doprint(S.GoldenRatio) == 'mpmath.phi'
117
+ assert p.doprint(S.EulerGamma) == 'mpmath.euler'
118
+ assert p.doprint(S.NaN) == 'mpmath.nan'
119
+ assert p.doprint(S.Infinity) == 'mpmath.inf'
120
+ assert p.doprint(S.NegativeInfinity) == 'mpmath.ninf'
121
+ assert p.doprint(loggamma(x)) == 'mpmath.loggamma(x)'
122
+
123
+
124
+ def test_NumPyPrinter():
125
+ from sympy.core.function import Lambda
126
+ from sympy.matrices.expressions.adjoint import Adjoint
127
+ from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf)
128
+ from sympy.matrices.expressions.funcmatrix import FunctionMatrix
129
+ from sympy.matrices.expressions.hadamard import HadamardProduct
130
+ from sympy.matrices.expressions.kronecker import KroneckerProduct
131
+ from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix)
132
+ from sympy.abc import a, b
133
+ p = NumPyPrinter()
134
+ assert p.doprint(sign(x)) == 'numpy.sign(x)'
135
+ A = MatrixSymbol("A", 2, 2)
136
+ B = MatrixSymbol("B", 2, 2)
137
+ C = MatrixSymbol("C", 1, 5)
138
+ D = MatrixSymbol("D", 3, 4)
139
+ assert p.doprint(A**(-1)) == "numpy.linalg.inv(A)"
140
+ assert p.doprint(A**5) == "numpy.linalg.matrix_power(A, 5)"
141
+ assert p.doprint(Identity(3)) == "numpy.eye(3)"
142
+
143
+ u = MatrixSymbol('x', 2, 1)
144
+ v = MatrixSymbol('y', 2, 1)
145
+ assert p.doprint(MatrixSolve(A, u)) == 'numpy.linalg.solve(A, x)'
146
+ assert p.doprint(MatrixSolve(A, u) + v) == 'numpy.linalg.solve(A, x) + y'
147
+
148
+ assert p.doprint(ZeroMatrix(2, 3)) == "numpy.zeros((2, 3))"
149
+ assert p.doprint(OneMatrix(2, 3)) == "numpy.ones((2, 3))"
150
+ assert p.doprint(FunctionMatrix(4, 5, Lambda((a, b), a + b))) == \
151
+ "numpy.fromfunction(lambda a, b: a + b, (4, 5))"
152
+ assert p.doprint(HadamardProduct(A, B)) == "numpy.multiply(A, B)"
153
+ assert p.doprint(KroneckerProduct(A, B)) == "numpy.kron(A, B)"
154
+ assert p.doprint(Adjoint(A)) == "numpy.conjugate(numpy.transpose(A))"
155
+ assert p.doprint(DiagonalOf(A)) == "numpy.reshape(numpy.diag(A), (-1, 1))"
156
+ assert p.doprint(DiagMatrix(C)) == "numpy.diagflat(C)"
157
+ assert p.doprint(DiagonalMatrix(D)) == "numpy.multiply(D, numpy.eye(3, 4))"
158
+
159
+ # Workaround for numpy negative integer power errors
160
+ assert p.doprint(x**-1) == 'x**(-1.0)'
161
+ assert p.doprint(x**-2) == 'x**(-2.0)'
162
+
163
+ expr = Pow(2, -1, evaluate=False)
164
+ assert p.doprint(expr) == "2**(-1.0)"
165
+
166
+ assert p.doprint(S.Exp1) == 'numpy.e'
167
+ assert p.doprint(S.Pi) == 'numpy.pi'
168
+ assert p.doprint(S.EulerGamma) == 'numpy.euler_gamma'
169
+ assert p.doprint(S.NaN) == 'numpy.nan'
170
+ assert p.doprint(S.Infinity) == 'numpy.inf'
171
+ assert p.doprint(S.NegativeInfinity) == '-numpy.inf'
172
+
173
+ # Function rewriting operator precedence fix
174
+ assert p.doprint(sec(x)**2) == '(numpy.cos(x)**(-1.0))**2'
175
+
176
+
177
+ def test_issue_18770():
178
+ numpy = import_module('numpy')
179
+ if not numpy:
180
+ skip("numpy not installed.")
181
+
182
+ from sympy.functions.elementary.miscellaneous import (Max, Min)
183
+ from sympy.utilities.lambdify import lambdify
184
+
185
+ expr1 = Min(0.1*x + 3, x + 1, 0.5*x + 1)
186
+ func = lambdify(x, expr1, "numpy")
187
+ assert (func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all()
188
+ assert func(4) == 3
189
+
190
+ expr1 = Max(x**2, x**3)
191
+ func = lambdify(x,expr1, "numpy")
192
+ assert (func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all()
193
+ assert func(4) == 64
194
+
195
+
196
+ def test_SciPyPrinter():
197
+ p = SciPyPrinter()
198
+ expr = acos(x)
199
+ assert 'numpy' not in p.module_imports
200
+ assert p.doprint(expr) == 'numpy.arccos(x)'
201
+ assert 'numpy' in p.module_imports
202
+ assert not any(m.startswith('scipy') for m in p.module_imports)
203
+ smat = SparseMatrix(2, 5, {(0, 1): 3})
204
+ assert p.doprint(smat) == \
205
+ 'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))'
206
+ assert 'scipy.sparse' in p.module_imports
207
+
208
+ assert p.doprint(S.GoldenRatio) == 'scipy.constants.golden_ratio'
209
+ assert p.doprint(S.Pi) == 'scipy.constants.pi'
210
+ assert p.doprint(S.Exp1) == 'numpy.e'
211
+
212
+
213
+ def test_pycode_reserved_words():
214
+ s1, s2 = symbols('if else')
215
+ raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True))
216
+ py_str = pycode(s1 + s2)
217
+ assert py_str in ('else_ + if_', 'if_ + else_')
218
+
219
+
220
+ def test_issue_20762():
221
+ # Make sure pycode removes curly braces from subscripted variables
222
+ a_b, b, a_11 = symbols('a_{b} b a_{11}')
223
+ expr = a_b*b
224
+ assert pycode(expr) == 'a_b*b'
225
+ expr = a_11*b
226
+ assert pycode(expr) == 'a_11*b'
227
+
228
+
229
+ def test_sqrt():
230
+ prntr = PythonCodePrinter()
231
+ assert prntr._print_Pow(sqrt(x), rational=False) == 'math.sqrt(x)'
232
+ assert prntr._print_Pow(1/sqrt(x), rational=False) == '1/math.sqrt(x)'
233
+
234
+ prntr = PythonCodePrinter({'standard' : 'python3'})
235
+ assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
236
+ assert prntr._print_Pow(1/sqrt(x), rational=True) == 'x**(-1/2)'
237
+
238
+ prntr = MpmathPrinter()
239
+ assert prntr._print_Pow(sqrt(x), rational=False) == 'mpmath.sqrt(x)'
240
+ assert prntr._print_Pow(sqrt(x), rational=True) == \
241
+ "x**(mpmath.mpf(1)/mpmath.mpf(2))"
242
+
243
+ prntr = NumPyPrinter()
244
+ assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)'
245
+ assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
246
+
247
+ prntr = SciPyPrinter()
248
+ assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)'
249
+ assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
250
+
251
+ prntr = SymPyPrinter()
252
+ assert prntr._print_Pow(sqrt(x), rational=False) == 'sympy.sqrt(x)'
253
+ assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
254
+
255
+
256
+ def test_frac():
257
+ from sympy.functions.elementary.integers import frac
258
+
259
+ expr = frac(x)
260
+ prntr = NumPyPrinter()
261
+ assert prntr.doprint(expr) == 'numpy.mod(x, 1)'
262
+
263
+ prntr = SciPyPrinter()
264
+ assert prntr.doprint(expr) == 'numpy.mod(x, 1)'
265
+
266
+ prntr = PythonCodePrinter()
267
+ assert prntr.doprint(expr) == 'x % 1'
268
+
269
+ prntr = MpmathPrinter()
270
+ assert prntr.doprint(expr) == 'mpmath.frac(x)'
271
+
272
+ prntr = SymPyPrinter()
273
+ assert prntr.doprint(expr) == 'sympy.functions.elementary.integers.frac(x)'
274
+
275
+
276
+ class CustomPrintedObject(Expr):
277
+ def _numpycode(self, printer):
278
+ return 'numpy'
279
+
280
+ def _mpmathcode(self, printer):
281
+ return 'mpmath'
282
+
283
+
284
+ def test_printmethod():
285
+ obj = CustomPrintedObject()
286
+ assert NumPyPrinter().doprint(obj) == 'numpy'
287
+ assert MpmathPrinter().doprint(obj) == 'mpmath'
288
+
289
+
290
+ def test_codegen_ast_nodes():
291
+ assert pycode(none) == 'None'
292
+
293
+
294
+ def test_issue_14283():
295
+ prntr = PythonCodePrinter()
296
+
297
+ assert prntr.doprint(zoo) == "math.nan"
298
+ assert prntr.doprint(-oo) == "float('-inf')"
299
+
300
+
301
+ def test_NumPyPrinter_print_seq():
302
+ n = NumPyPrinter()
303
+
304
+ assert n._print_seq(range(2)) == '(0, 1,)'
305
+
306
+
307
+ def test_issue_16535_16536():
308
+ from sympy.functions.special.gamma_functions import (lowergamma, uppergamma)
309
+
310
+ a = symbols('a')
311
+ expr1 = lowergamma(a, x)
312
+ expr2 = uppergamma(a, x)
313
+
314
+ prntr = SciPyPrinter()
315
+ assert prntr.doprint(expr1) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)'
316
+ assert prntr.doprint(expr2) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)'
317
+
318
+ p_numpy = NumPyPrinter()
319
+ p_pycode = PythonCodePrinter({'strict': False})
320
+
321
+ for expr in [expr1, expr2]:
322
+ with raises(NotImplementedError):
323
+ p_numpy.doprint(expr1)
324
+ assert "Not supported" in p_pycode.doprint(expr)
325
+
326
+
327
+ def test_Integral():
328
+ from sympy.functions.elementary.exponential import exp
329
+ from sympy.integrals.integrals import Integral
330
+
331
+ single = Integral(exp(-x), (x, 0, oo))
332
+ double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z))
333
+ indefinite = Integral(x**2, x)
334
+ evaluateat = Integral(x**2, (x, 1))
335
+
336
+ prntr = SciPyPrinter()
337
+ assert prntr.doprint(single) == 'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.inf)[0]'
338
+ assert prntr.doprint(double) == 'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]'
339
+ raises(NotImplementedError, lambda: prntr.doprint(indefinite))
340
+ raises(NotImplementedError, lambda: prntr.doprint(evaluateat))
341
+
342
+ prntr = MpmathPrinter()
343
+ assert prntr.doprint(single) == 'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))'
344
+ assert prntr.doprint(double) == 'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))'
345
+ raises(NotImplementedError, lambda: prntr.doprint(indefinite))
346
+ raises(NotImplementedError, lambda: prntr.doprint(evaluateat))
347
+
348
+
349
+ def test_fresnel_integrals():
350
+ from sympy.functions.special.error_functions import (fresnelc, fresnels)
351
+
352
+ expr1 = fresnelc(x)
353
+ expr2 = fresnels(x)
354
+
355
+ prntr = SciPyPrinter()
356
+ assert prntr.doprint(expr1) == 'scipy.special.fresnel(x)[1]'
357
+ assert prntr.doprint(expr2) == 'scipy.special.fresnel(x)[0]'
358
+
359
+ p_numpy = NumPyPrinter()
360
+ p_pycode = PythonCodePrinter()
361
+ p_mpmath = MpmathPrinter()
362
+ for expr in [expr1, expr2]:
363
+ with raises(NotImplementedError):
364
+ p_numpy.doprint(expr)
365
+ with raises(NotImplementedError):
366
+ p_pycode.doprint(expr)
367
+
368
+ assert p_mpmath.doprint(expr1) == 'mpmath.fresnelc(x)'
369
+ assert p_mpmath.doprint(expr2) == 'mpmath.fresnels(x)'
370
+
371
+
372
+ def test_beta():
373
+ from sympy.functions.special.beta_functions import beta
374
+
375
+ expr = beta(x, y)
376
+
377
+ prntr = SciPyPrinter()
378
+ assert prntr.doprint(expr) == 'scipy.special.beta(x, y)'
379
+
380
+ prntr = NumPyPrinter()
381
+ assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
382
+
383
+ prntr = PythonCodePrinter()
384
+ assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
385
+
386
+ prntr = PythonCodePrinter({'allow_unknown_functions': True})
387
+ assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
388
+
389
+ prntr = MpmathPrinter()
390
+ assert prntr.doprint(expr) == 'mpmath.beta(x, y)'
391
+
392
+ def test_airy():
393
+ from sympy.functions.special.bessel import (airyai, airybi)
394
+
395
+ expr1 = airyai(x)
396
+ expr2 = airybi(x)
397
+
398
+ prntr = SciPyPrinter()
399
+ assert prntr.doprint(expr1) == 'scipy.special.airy(x)[0]'
400
+ assert prntr.doprint(expr2) == 'scipy.special.airy(x)[2]'
401
+
402
+ prntr = NumPyPrinter({'strict': False})
403
+ assert "Not supported" in prntr.doprint(expr1)
404
+ assert "Not supported" in prntr.doprint(expr2)
405
+
406
+ prntr = PythonCodePrinter({'strict': False})
407
+ assert "Not supported" in prntr.doprint(expr1)
408
+ assert "Not supported" in prntr.doprint(expr2)
409
+
410
+ def test_airy_prime():
411
+ from sympy.functions.special.bessel import (airyaiprime, airybiprime)
412
+
413
+ expr1 = airyaiprime(x)
414
+ expr2 = airybiprime(x)
415
+
416
+ prntr = SciPyPrinter()
417
+ assert prntr.doprint(expr1) == 'scipy.special.airy(x)[1]'
418
+ assert prntr.doprint(expr2) == 'scipy.special.airy(x)[3]'
419
+
420
+ prntr = NumPyPrinter({'strict': False})
421
+ assert "Not supported" in prntr.doprint(expr1)
422
+ assert "Not supported" in prntr.doprint(expr2)
423
+
424
+ prntr = PythonCodePrinter({'strict': False})
425
+ assert "Not supported" in prntr.doprint(expr1)
426
+ assert "Not supported" in prntr.doprint(expr2)
427
+
428
+
429
+ def test_numerical_accuracy_functions():
430
+ prntr = SciPyPrinter()
431
+ assert prntr.doprint(expm1(x)) == 'numpy.expm1(x)'
432
+ assert prntr.doprint(log1p(x)) == 'numpy.log1p(x)'
433
+ assert prntr.doprint(cosm1(x)) == 'scipy.special.cosm1(x)'
434
+
435
+ def test_array_printer():
436
+ A = ArraySymbol('A', (4,4,6,6,6))
437
+ I = IndexedBase('I')
438
+ i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5))
439
+
440
+ prntr = NumPyPrinter()
441
+ assert prntr.doprint(ZeroArray(5)) == 'numpy.zeros((5,))'
442
+ assert prntr.doprint(OneArray(5)) == 'numpy.ones((5,))'
443
+ assert prntr.doprint(ArrayContraction(A, [2,3])) == 'numpy.einsum("abccd->abd", A)'
444
+ assert prntr.doprint(I) == 'I'
445
+ assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'numpy.einsum("abccc->abc", A)'
446
+ assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'numpy.einsum("aabbc->cab", A)'
447
+ assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'numpy.einsum("abcde->abe", A)'
448
+ assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I'
449
+
450
+ prntr = TensorflowPrinter()
451
+ assert prntr.doprint(ZeroArray(5)) == 'tensorflow.zeros((5,))'
452
+ assert prntr.doprint(OneArray(5)) == 'tensorflow.ones((5,))'
453
+ assert prntr.doprint(ArrayContraction(A, [2,3])) == 'tensorflow.linalg.einsum("abccd->abd", A)'
454
+ assert prntr.doprint(I) == 'I'
455
+ assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'tensorflow.linalg.einsum("abccc->abc", A)'
456
+ assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'tensorflow.linalg.einsum("aabbc->cab", A)'
457
+ assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'tensorflow.linalg.einsum("abcde->abe", A)'
458
+ assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I'
459
+
460
+
461
+ def test_custom_Derivative_methods():
462
+ class MyPrinter(SciPyPrinter):
463
+ def _print_Derivative_cosm1(self, args, seq_orders):
464
+ arg, = args
465
+ order, = seq_orders
466
+ return 'my_custom_cosm1(%s, deriv_order=%d)' % (self._print(arg), order)
467
+
468
+ def _print_Derivative_atan2(self, args, seq_orders):
469
+ arg1, arg2 = args
470
+ ord1, ord2 = seq_orders
471
+ return 'my_custom_atan2(%s, %s, deriv1=%d, deriv2=%d)' % (
472
+ self._print(arg1), self._print(arg2), ord1, ord2
473
+ )
474
+
475
+ p = MyPrinter()
476
+ cosm1_1 = cosm1(x).diff(x, evaluate=False)
477
+ assert p.doprint(cosm1_1) == 'my_custom_cosm1(x, deriv_order=1)'
478
+ atan2_2_3 = atan2(x, y).diff(x, 2, y, 3, evaluate=False)
479
+ assert p.doprint(atan2_2_3) == 'my_custom_atan2(x, y, deriv1=2, deriv2=3)'
480
+
481
+ try:
482
+ p.doprint(expm1(x).diff(x, evaluate=False))
483
+ except PrintMethodNotImplementedError as e:
484
+ assert '_print_Derivative_expm1' in repr(e)
485
+ else:
486
+ assert False # should have thrown
487
+
488
+ try:
489
+ p.doprint(Derivative(cosm1(x**2),x))
490
+ except ValueError as e:
491
+ assert '_print_Derivative(' in repr(e)
492
+ else:
493
+ assert False # should have thrown
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_python.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.function import (Derivative, Function)
2
+ from sympy.core.numbers import (I, Rational, oo, pi)
3
+ from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
4
+ from sympy.core.symbol import (Symbol, symbols)
5
+ from sympy.functions.elementary.complexes import (Abs, conjugate)
6
+ from sympy.functions.elementary.exponential import (exp, log)
7
+ from sympy.functions.elementary.miscellaneous import sqrt
8
+ from sympy.functions.elementary.trigonometric import sin
9
+ from sympy.integrals.integrals import Integral
10
+ from sympy.matrices.dense import Matrix
11
+ from sympy.series.limits import limit
12
+
13
+ from sympy.printing.python import python
14
+
15
+ from sympy.testing.pytest import raises, XFAIL
16
+
17
+ x, y = symbols('x,y')
18
+ th = Symbol('theta')
19
+ ph = Symbol('phi')
20
+
21
+
22
+ def test_python_basic():
23
+ # Simple numbers/symbols
24
+ assert python(-Rational(1)/2) == "e = Rational(-1, 2)"
25
+ assert python(-Rational(13)/22) == "e = Rational(-13, 22)"
26
+ assert python(oo) == "e = oo"
27
+
28
+ # Powers
29
+ assert python(x**2) == "x = Symbol(\'x\')\ne = x**2"
30
+ assert python(1/x) == "x = Symbol('x')\ne = 1/x"
31
+ assert python(y*x**-2) == "y = Symbol('y')\nx = Symbol('x')\ne = y/x**2"
32
+ assert python(
33
+ x**Rational(-5, 2)) == "x = Symbol('x')\ne = x**Rational(-5, 2)"
34
+
35
+ # Sums of terms
36
+ assert python(x**2 + x + 1) in [
37
+ "x = Symbol('x')\ne = 1 + x + x**2",
38
+ "x = Symbol('x')\ne = x + x**2 + 1",
39
+ "x = Symbol('x')\ne = x**2 + x + 1", ]
40
+ assert python(1 - x) in [
41
+ "x = Symbol('x')\ne = 1 - x",
42
+ "x = Symbol('x')\ne = -x + 1"]
43
+ assert python(1 - 2*x) in [
44
+ "x = Symbol('x')\ne = 1 - 2*x",
45
+ "x = Symbol('x')\ne = -2*x + 1"]
46
+ assert python(1 - Rational(3, 2)*y/x) in [
47
+ "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3/2*y/x",
48
+ "y = Symbol('y')\nx = Symbol('x')\ne = -3/2*y/x + 1",
49
+ "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3*y/(2*x)"]
50
+
51
+ # Multiplication
52
+ assert python(x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = x/y"
53
+ assert python(-x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = -x/y"
54
+ assert python((x + 2)/y) in [
55
+ "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(2 + x)",
56
+ "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(x + 2)",
57
+ "x = Symbol('x')\ny = Symbol('y')\ne = 1/y*(2 + x)",
58
+ "x = Symbol('x')\ny = Symbol('y')\ne = (2 + x)/y",
59
+ "x = Symbol('x')\ny = Symbol('y')\ne = (x + 2)/y"]
60
+ assert python((1 + x)*y) in [
61
+ "y = Symbol('y')\nx = Symbol('x')\ne = y*(1 + x)",
62
+ "y = Symbol('y')\nx = Symbol('x')\ne = y*(x + 1)", ]
63
+
64
+ # Check for proper placement of negative sign
65
+ assert python(-5*x/(x + 10)) == "x = Symbol('x')\ne = -5*x/(x + 10)"
66
+ assert python(1 - Rational(3, 2)*(x + 1)) in [
67
+ "x = Symbol('x')\ne = Rational(-3, 2)*x + Rational(-1, 2)",
68
+ "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)",
69
+ "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)"
70
+ ]
71
+
72
+
73
+ def test_python_keyword_symbol_name_escaping():
74
+ # Check for escaping of keywords
75
+ assert python(
76
+ 5*Symbol("lambda")) == "lambda_ = Symbol('lambda')\ne = 5*lambda_"
77
+ assert (python(5*Symbol("lambda") + 7*Symbol("lambda_")) ==
78
+ "lambda__ = Symbol('lambda')\nlambda_ = Symbol('lambda_')\ne = 7*lambda_ + 5*lambda__")
79
+ assert (python(5*Symbol("for") + Function("for_")(8)) ==
80
+ "for__ = Symbol('for')\nfor_ = Function('for_')\ne = 5*for__ + for_(8)")
81
+
82
+
83
+ def test_python_keyword_function_name_escaping():
84
+ assert python(
85
+ 5*Function("for")(8)) == "for_ = Function('for')\ne = 5*for_(8)"
86
+
87
+
88
+ def test_python_relational():
89
+ assert python(Eq(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = Eq(x, y)"
90
+ assert python(Ge(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x >= y"
91
+ assert python(Le(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x <= y"
92
+ assert python(Gt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x > y"
93
+ assert python(Lt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x < y"
94
+ assert python(Ne(x/(y + 1), y**2)) in [
95
+ "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(1 + y), y**2)",
96
+ "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(y + 1), y**2)"]
97
+
98
+
99
+ def test_python_functions():
100
+ # Simple
101
+ assert python(2*x + exp(x)) in "x = Symbol('x')\ne = 2*x + exp(x)"
102
+ assert python(sqrt(2)) == 'e = sqrt(2)'
103
+ assert python(2**Rational(1, 3)) == 'e = 2**Rational(1, 3)'
104
+ assert python(sqrt(2 + pi)) == 'e = sqrt(2 + pi)'
105
+ assert python((2 + pi)**Rational(1, 3)) == 'e = (2 + pi)**Rational(1, 3)'
106
+ assert python(2**Rational(1, 4)) == 'e = 2**Rational(1, 4)'
107
+ assert python(Abs(x)) == "x = Symbol('x')\ne = Abs(x)"
108
+ assert python(
109
+ Abs(x/(x**2 + 1))) in ["x = Symbol('x')\ne = Abs(x/(1 + x**2))",
110
+ "x = Symbol('x')\ne = Abs(x/(x**2 + 1))"]
111
+
112
+ # Univariate/Multivariate functions
113
+ f = Function('f')
114
+ assert python(f(x)) == "x = Symbol('x')\nf = Function('f')\ne = f(x)"
115
+ assert python(f(x, y)) == "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x, y)"
116
+ assert python(f(x/(y + 1), y)) in [
117
+ "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(1 + y), y)",
118
+ "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(y + 1), y)"]
119
+
120
+ # Nesting of square roots
121
+ assert python(sqrt((sqrt(x + 1)) + 1)) in [
122
+ "x = Symbol('x')\ne = sqrt(1 + sqrt(1 + x))",
123
+ "x = Symbol('x')\ne = sqrt(sqrt(x + 1) + 1)"]
124
+
125
+ # Nesting of powers
126
+ assert python((((x + 1)**Rational(1, 3)) + 1)**Rational(1, 3)) in [
127
+ "x = Symbol('x')\ne = (1 + (1 + x)**Rational(1, 3))**Rational(1, 3)",
128
+ "x = Symbol('x')\ne = ((x + 1)**Rational(1, 3) + 1)**Rational(1, 3)"]
129
+
130
+ # Function powers
131
+ assert python(sin(x)**2) == "x = Symbol('x')\ne = sin(x)**2"
132
+
133
+
134
+ @XFAIL
135
+ def test_python_functions_conjugates():
136
+ a, b = map(Symbol, 'ab')
137
+ assert python( conjugate(a + b*I) ) == '_ _\na - I*b'
138
+ assert python( conjugate(exp(a + b*I)) ) == ' _ _\n a - I*b\ne '
139
+
140
+
141
+ def test_python_derivatives():
142
+ # Simple
143
+ f_1 = Derivative(log(x), x, evaluate=False)
144
+ assert python(f_1) == "x = Symbol('x')\ne = Derivative(log(x), x)"
145
+
146
+ f_2 = Derivative(log(x), x, evaluate=False) + x
147
+ assert python(f_2) == "x = Symbol('x')\ne = x + Derivative(log(x), x)"
148
+
149
+ # Multiple symbols
150
+ f_3 = Derivative(log(x) + x**2, x, y, evaluate=False)
151
+ assert python(f_3) == \
152
+ "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(x**2 + log(x), x, y)"
153
+
154
+ f_4 = Derivative(2*x*y, y, x, evaluate=False) + x**2
155
+ assert python(f_4) in [
156
+ "x = Symbol('x')\ny = Symbol('y')\ne = x**2 + Derivative(2*x*y, y, x)",
157
+ "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(2*x*y, y, x) + x**2"]
158
+
159
+
160
+ def test_python_integrals():
161
+ # Simple
162
+ f_1 = Integral(log(x), x)
163
+ assert python(f_1) == "x = Symbol('x')\ne = Integral(log(x), x)"
164
+
165
+ f_2 = Integral(x**2, x)
166
+ assert python(f_2) == "x = Symbol('x')\ne = Integral(x**2, x)"
167
+
168
+ # Double nesting of pow
169
+ f_3 = Integral(x**(2**x), x)
170
+ assert python(f_3) == "x = Symbol('x')\ne = Integral(x**(2**x), x)"
171
+
172
+ # Definite integrals
173
+ f_4 = Integral(x**2, (x, 1, 2))
174
+ assert python(f_4) == "x = Symbol('x')\ne = Integral(x**2, (x, 1, 2))"
175
+
176
+ f_5 = Integral(x**2, (x, Rational(1, 2), 10))
177
+ assert python(
178
+ f_5) == "x = Symbol('x')\ne = Integral(x**2, (x, Rational(1, 2), 10))"
179
+
180
+ # Nested integrals
181
+ f_6 = Integral(x**2*y**2, x, y)
182
+ assert python(f_6) == "x = Symbol('x')\ny = Symbol('y')\ne = Integral(x**2*y**2, x, y)"
183
+
184
+
185
+ def test_python_matrix():
186
+ p = python(Matrix([[x**2+1, 1], [y, x+y]]))
187
+ s = "x = Symbol('x')\ny = Symbol('y')\ne = MutableDenseMatrix([[x**2 + 1, 1], [y, x + y]])"
188
+ assert p == s
189
+
190
+ def test_python_limits():
191
+ assert python(limit(x, x, oo)) == 'e = oo'
192
+ assert python(limit(x**2, x, 0)) == 'e = 0'
193
+
194
+ def test_issue_20762():
195
+ # Make sure Python removes curly braces from subscripted variables
196
+ a_b = Symbol('a_{b}')
197
+ b = Symbol('b')
198
+ expr = a_b*b
199
+ assert python(expr) == "a_b = Symbol('a_{b}')\nb = Symbol('b')\ne = a_b*b"
200
+
201
+
202
+ def test_settings():
203
+ raises(TypeError, lambda: python(x, method="garbage"))
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_repr.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Any
3
+
4
+ from sympy.external.gmpy import GROUND_TYPES
5
+ from sympy.testing.pytest import raises, warns_deprecated_sympy
6
+ from sympy.assumptions.ask import Q
7
+ from sympy.core.function import (Function, WildFunction)
8
+ from sympy.core.numbers import (AlgebraicNumber, Float, Integer, Rational)
9
+ from sympy.core.singleton import S
10
+ from sympy.core.symbol import (Dummy, Symbol, Wild, symbols)
11
+ from sympy.core.sympify import sympify
12
+ from sympy.functions.elementary.complexes import Abs
13
+ from sympy.functions.elementary.miscellaneous import (root, sqrt)
14
+ from sympy.functions.elementary.trigonometric import sin
15
+ from sympy.functions.special.delta_functions import Heaviside
16
+ from sympy.logic.boolalg import (false, true)
17
+ from sympy.matrices.dense import (Matrix, ones)
18
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
19
+ from sympy.matrices.immutable import ImmutableDenseMatrix
20
+ from sympy.combinatorics import Cycle, Permutation
21
+ from sympy.core.symbol import Str
22
+ from sympy.geometry import Point, Ellipse
23
+ from sympy.printing import srepr
24
+ from sympy.polys import ring, field, ZZ, QQ, lex, grlex, Poly
25
+ from sympy.polys.polyclasses import DMP
26
+ from sympy.polys.agca.extensions import FiniteExtension
27
+
28
+ x, y = symbols('x,y')
29
+
30
+ # eval(srepr(expr)) == expr has to succeed in the right environment. The right
31
+ # environment is the scope of "from sympy import *" for most cases.
32
+ ENV: dict[str, Any] = {"Str": Str}
33
+ exec("from sympy import *", ENV)
34
+
35
+
36
+ def sT(expr, string, import_stmt=None, **kwargs):
37
+ """
38
+ sT := sreprTest
39
+
40
+ Tests that srepr delivers the expected string and that
41
+ the condition eval(srepr(expr))==expr holds.
42
+ """
43
+ if import_stmt is None:
44
+ ENV2 = ENV
45
+ else:
46
+ ENV2 = ENV.copy()
47
+ exec(import_stmt, ENV2)
48
+
49
+ assert srepr(expr, **kwargs) == string
50
+ assert eval(string, ENV2) == expr
51
+
52
+
53
+ def test_printmethod():
54
+ class R(Abs):
55
+ def _sympyrepr(self, printer):
56
+ return "foo(%s)" % printer._print(self.args[0])
57
+ assert srepr(R(x)) == "foo(Symbol('x'))"
58
+
59
+
60
+ def test_Add():
61
+ sT(x + y, "Add(Symbol('x'), Symbol('y'))")
62
+ assert srepr(x**2 + 1, order='lex') == "Add(Pow(Symbol('x'), Integer(2)), Integer(1))"
63
+ assert srepr(x**2 + 1, order='old') == "Add(Integer(1), Pow(Symbol('x'), Integer(2)))"
64
+ assert srepr(sympify('x + 3 - 2', evaluate=False), order='none') == "Add(Symbol('x'), Integer(3), Mul(Integer(-1), Integer(2)))"
65
+
66
+
67
+ def test_more_than_255_args_issue_10259():
68
+ from sympy.core.add import Add
69
+ from sympy.core.mul import Mul
70
+ for op in (Add, Mul):
71
+ expr = op(*symbols('x:256'))
72
+ assert eval(srepr(expr)) == expr
73
+
74
+
75
+ def test_Function():
76
+ sT(Function("f")(x), "Function('f')(Symbol('x'))")
77
+ # test unapplied Function
78
+ sT(Function('f'), "Function('f')")
79
+
80
+ sT(sin(x), "sin(Symbol('x'))")
81
+ sT(sin, "sin")
82
+
83
+
84
+ def test_Heaviside():
85
+ sT(Heaviside(x), "Heaviside(Symbol('x'))")
86
+ sT(Heaviside(x, 1), "Heaviside(Symbol('x'), Integer(1))")
87
+
88
+
89
+ def test_Geometry():
90
+ sT(Point(0, 0), "Point2D(Integer(0), Integer(0))")
91
+ sT(Ellipse(Point(0, 0), 5, 1),
92
+ "Ellipse(Point2D(Integer(0), Integer(0)), Integer(5), Integer(1))")
93
+ # TODO more tests
94
+
95
+
96
+ def test_Singletons():
97
+ sT(S.Catalan, 'Catalan')
98
+ sT(S.ComplexInfinity, 'zoo')
99
+ sT(S.EulerGamma, 'EulerGamma')
100
+ sT(S.Exp1, 'E')
101
+ sT(S.GoldenRatio, 'GoldenRatio')
102
+ sT(S.TribonacciConstant, 'TribonacciConstant')
103
+ sT(S.Half, 'Rational(1, 2)')
104
+ sT(S.ImaginaryUnit, 'I')
105
+ sT(S.Infinity, 'oo')
106
+ sT(S.NaN, 'nan')
107
+ sT(S.NegativeInfinity, '-oo')
108
+ sT(S.NegativeOne, 'Integer(-1)')
109
+ sT(S.One, 'Integer(1)')
110
+ sT(S.Pi, 'pi')
111
+ sT(S.Zero, 'Integer(0)')
112
+ sT(S.Complexes, 'Complexes')
113
+ sT(S.EmptySequence, 'EmptySequence')
114
+ sT(S.EmptySet, 'EmptySet')
115
+ # sT(S.IdentityFunction, 'Lambda(_x, _x)')
116
+ sT(S.Naturals, 'Naturals')
117
+ sT(S.Naturals0, 'Naturals0')
118
+ sT(S.Rationals, 'Rationals')
119
+ sT(S.Reals, 'Reals')
120
+ sT(S.UniversalSet, 'UniversalSet')
121
+
122
+
123
+ def test_Integer():
124
+ sT(Integer(4), "Integer(4)")
125
+
126
+
127
+ def test_list():
128
+ sT([x, Integer(4)], "[Symbol('x'), Integer(4)]")
129
+
130
+
131
+ def test_Matrix():
132
+ for cls, name in [(Matrix, "MutableDenseMatrix"), (ImmutableDenseMatrix, "ImmutableDenseMatrix")]:
133
+ sT(cls([[x**+1, 1], [y, x + y]]),
134
+ "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name)
135
+
136
+ sT(cls(), "%s([])" % name)
137
+
138
+ sT(cls([[x**+1, 1], [y, x + y]]), "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name)
139
+
140
+
141
+ def test_empty_Matrix():
142
+ sT(ones(0, 3), "MutableDenseMatrix(0, 3, [])")
143
+ sT(ones(4, 0), "MutableDenseMatrix(4, 0, [])")
144
+ sT(ones(0, 0), "MutableDenseMatrix([])")
145
+
146
+
147
+ def test_Rational():
148
+ sT(Rational(1, 3), "Rational(1, 3)")
149
+ sT(Rational(-1, 3), "Rational(-1, 3)")
150
+
151
+
152
+ def test_Float():
153
+ sT(Float('1.23', dps=3), "Float('1.22998', precision=13)")
154
+ sT(Float('1.23456789', dps=9), "Float('1.23456788994', precision=33)")
155
+ sT(Float('1.234567890123456789', dps=19),
156
+ "Float('1.234567890123456789013', precision=66)")
157
+ sT(Float('0.60038617995049726', dps=15),
158
+ "Float('0.60038617995049726', precision=53)")
159
+
160
+ sT(Float('1.23', precision=13), "Float('1.22998', precision=13)")
161
+ sT(Float('1.23456789', precision=33),
162
+ "Float('1.23456788994', precision=33)")
163
+ sT(Float('1.234567890123456789', precision=66),
164
+ "Float('1.234567890123456789013', precision=66)")
165
+ sT(Float('0.60038617995049726', precision=53),
166
+ "Float('0.60038617995049726', precision=53)")
167
+
168
+ sT(Float('0.60038617995049726', 15),
169
+ "Float('0.60038617995049726', precision=53)")
170
+
171
+
172
+ def test_Symbol():
173
+ sT(x, "Symbol('x')")
174
+ sT(y, "Symbol('y')")
175
+ sT(Symbol('x', negative=True), "Symbol('x', negative=True)")
176
+
177
+
178
+ def test_Symbol_two_assumptions():
179
+ x = Symbol('x', negative=0, integer=1)
180
+ # order could vary
181
+ s1 = "Symbol('x', integer=True, negative=False)"
182
+ s2 = "Symbol('x', negative=False, integer=True)"
183
+ assert srepr(x) in (s1, s2)
184
+ assert eval(srepr(x), ENV) == x
185
+
186
+
187
+ def test_Symbol_no_special_commutative_treatment():
188
+ sT(Symbol('x'), "Symbol('x')")
189
+ sT(Symbol('x', commutative=False), "Symbol('x', commutative=False)")
190
+ sT(Symbol('x', commutative=0), "Symbol('x', commutative=False)")
191
+ sT(Symbol('x', commutative=True), "Symbol('x', commutative=True)")
192
+ sT(Symbol('x', commutative=1), "Symbol('x', commutative=True)")
193
+
194
+
195
+ def test_Wild():
196
+ sT(Wild('x', even=True), "Wild('x', even=True)")
197
+
198
+
199
+ def test_Dummy():
200
+ d = Dummy('d')
201
+ sT(d, "Dummy('d', dummy_index=%s)" % str(d.dummy_index))
202
+
203
+
204
+ def test_Dummy_assumption():
205
+ d = Dummy('d', nonzero=True)
206
+ assert d == eval(srepr(d))
207
+ s1 = "Dummy('d', dummy_index=%s, nonzero=True)" % str(d.dummy_index)
208
+ s2 = "Dummy('d', nonzero=True, dummy_index=%s)" % str(d.dummy_index)
209
+ assert srepr(d) in (s1, s2)
210
+
211
+
212
+ def test_Dummy_from_Symbol():
213
+ # should not get the full dictionary of assumptions
214
+ n = Symbol('n', integer=True)
215
+ d = n.as_dummy()
216
+ assert srepr(d
217
+ ) == "Dummy('n', dummy_index=%s)" % str(d.dummy_index)
218
+
219
+
220
+ def test_tuple():
221
+ sT((x,), "(Symbol('x'),)")
222
+ sT((x, y), "(Symbol('x'), Symbol('y'))")
223
+
224
+
225
+ def test_WildFunction():
226
+ sT(WildFunction('w'), "WildFunction('w')")
227
+
228
+
229
+ def test_settins():
230
+ raises(TypeError, lambda: srepr(x, method="garbage"))
231
+
232
+
233
+ def test_Mul():
234
+ sT(3*x**3*y, "Mul(Integer(3), Pow(Symbol('x'), Integer(3)), Symbol('y'))")
235
+ assert srepr(3*x**3*y, order='old') == "Mul(Integer(3), Symbol('y'), Pow(Symbol('x'), Integer(3)))"
236
+ assert srepr(sympify('(x+4)*2*x*7', evaluate=False), order='none') == "Mul(Add(Symbol('x'), Integer(4)), Integer(2), Symbol('x'), Integer(7))"
237
+
238
+
239
+ def test_AlgebraicNumber():
240
+ a = AlgebraicNumber(sqrt(2))
241
+ sT(a, "AlgebraicNumber(Pow(Integer(2), Rational(1, 2)), [Integer(1), Integer(0)])")
242
+ a = AlgebraicNumber(root(-2, 3))
243
+ sT(a, "AlgebraicNumber(Pow(Integer(-2), Rational(1, 3)), [Integer(1), Integer(0)])")
244
+
245
+
246
+ def test_PolyRing():
247
+ assert srepr(ring("x", ZZ, lex)[0]) == "PolyRing((Symbol('x'),), ZZ, lex)"
248
+ assert srepr(ring("x,y", QQ, grlex)[0]) == "PolyRing((Symbol('x'), Symbol('y')), QQ, grlex)"
249
+ assert srepr(ring("x,y,z", ZZ["t"], lex)[0]) == "PolyRing((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)"
250
+
251
+
252
+ def test_FracField():
253
+ assert srepr(field("x", ZZ, lex)[0]) == "FracField((Symbol('x'),), ZZ, lex)"
254
+ assert srepr(field("x,y", QQ, grlex)[0]) == "FracField((Symbol('x'), Symbol('y')), QQ, grlex)"
255
+ assert srepr(field("x,y,z", ZZ["t"], lex)[0]) == "FracField((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)"
256
+
257
+
258
+ def test_PolyElement():
259
+ R, x, y = ring("x,y", ZZ)
260
+ assert srepr(3*x**2*y + 1) == "PolyElement(PolyRing((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)])"
261
+
262
+
263
+ def test_FracElement():
264
+ F, x, y = field("x,y", ZZ)
265
+ assert srepr((3*x**2*y + 1)/(x - y**2)) == "FracElement(FracField((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)], [((1, 0), 1), ((0, 2), -1)])"
266
+
267
+
268
+ def test_FractionField():
269
+ assert srepr(QQ.frac_field(x)) == \
270
+ "FractionField(FracField((Symbol('x'),), QQ, lex))"
271
+ assert srepr(QQ.frac_field(x, y, order=grlex)) == \
272
+ "FractionField(FracField((Symbol('x'), Symbol('y')), QQ, grlex))"
273
+
274
+
275
+ def test_PolynomialRingBase():
276
+ assert srepr(ZZ.old_poly_ring(x)) == \
277
+ "GlobalPolynomialRing(ZZ, Symbol('x'))"
278
+ assert srepr(ZZ[x].old_poly_ring(y)) == \
279
+ "GlobalPolynomialRing(ZZ[x], Symbol('y'))"
280
+ assert srepr(QQ.frac_field(x).old_poly_ring(y)) == \
281
+ "GlobalPolynomialRing(FractionField(FracField((Symbol('x'),), QQ, lex)), Symbol('y'))"
282
+
283
+
284
+ def test_DMP():
285
+ p1 = DMP([1, 2], ZZ)
286
+ p2 = ZZ.old_poly_ring(x)([1, 2])
287
+ if GROUND_TYPES != 'flint':
288
+ assert srepr(p1) == "DMP_Python([1, 2], ZZ)"
289
+ assert srepr(p2) == "DMP_Python([1, 2], ZZ)"
290
+ else:
291
+ assert srepr(p1) == "DUP_Flint([1, 2], ZZ)"
292
+ assert srepr(p2) == "DUP_Flint([1, 2], ZZ)"
293
+
294
+
295
+ def test_FiniteExtension():
296
+ assert srepr(FiniteExtension(Poly(x**2 + 1, x))) == \
297
+ "FiniteExtension(Poly(x**2 + 1, x, domain='ZZ'))"
298
+
299
+
300
+ def test_ExtensionElement():
301
+ A = FiniteExtension(Poly(x**2 + 1, x))
302
+ if GROUND_TYPES != 'flint':
303
+ ans = "ExtElem(DMP_Python([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))"
304
+ else:
305
+ ans = "ExtElem(DUP_Flint([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))"
306
+ assert srepr(A.generator) == ans
307
+
308
+ def test_BooleanAtom():
309
+ assert srepr(true) == "true"
310
+ assert srepr(false) == "false"
311
+
312
+
313
+ def test_Integers():
314
+ sT(S.Integers, "Integers")
315
+
316
+
317
+ def test_Naturals():
318
+ sT(S.Naturals, "Naturals")
319
+
320
+
321
+ def test_Naturals0():
322
+ sT(S.Naturals0, "Naturals0")
323
+
324
+
325
+ def test_Reals():
326
+ sT(S.Reals, "Reals")
327
+
328
+
329
+ def test_matrix_expressions():
330
+ n = symbols('n', integer=True)
331
+ A = MatrixSymbol("A", n, n)
332
+ B = MatrixSymbol("B", n, n)
333
+ sT(A, "MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True))")
334
+ sT(A*B, "MatMul(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))")
335
+ sT(A + B, "MatAdd(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))")
336
+
337
+
338
+ def test_Cycle():
339
+ # FIXME: sT fails because Cycle is not immutable and calling srepr(Cycle(1, 2))
340
+ # adds keys to the Cycle dict (GH-17661)
341
+ #import_stmt = "from sympy.combinatorics import Cycle"
342
+ #sT(Cycle(1, 2), "Cycle(1, 2)", import_stmt)
343
+ assert srepr(Cycle(1, 2)) == "Cycle(1, 2)"
344
+
345
+
346
+ def test_Permutation():
347
+ import_stmt = "from sympy.combinatorics import Permutation"
348
+ sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt, perm_cyclic=False)
349
+ sT(Permutation(1, 2)(3, 4), "Permutation(1, 2)(3, 4)", import_stmt, perm_cyclic=True)
350
+
351
+ with warns_deprecated_sympy():
352
+ old_print_cyclic = Permutation.print_cyclic
353
+ Permutation.print_cyclic = False
354
+ sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt)
355
+ Permutation.print_cyclic = old_print_cyclic
356
+
357
+ def test_dict():
358
+ from sympy.abc import x, y, z
359
+ d = {}
360
+ assert srepr(d) == "{}"
361
+ d = {x: y}
362
+ assert srepr(d) == "{Symbol('x'): Symbol('y')}"
363
+ d = {x: y, y: z}
364
+ assert srepr(d) in (
365
+ "{Symbol('x'): Symbol('y'), Symbol('y'): Symbol('z')}",
366
+ "{Symbol('y'): Symbol('z'), Symbol('x'): Symbol('y')}",
367
+ )
368
+ d = {x: {y: z}}
369
+ assert srepr(d) == "{Symbol('x'): {Symbol('y'): Symbol('z')}}"
370
+
371
+ def test_set():
372
+ from sympy.abc import x, y
373
+ s = set()
374
+ assert srepr(s) == "set()"
375
+ s = {x, y}
376
+ assert srepr(s) in ("{Symbol('x'), Symbol('y')}", "{Symbol('y'), Symbol('x')}")
377
+
378
+ def test_Predicate():
379
+ sT(Q.even, "Q.even")
380
+
381
+ def test_AppliedPredicate():
382
+ sT(Q.even(Symbol('z')), "AppliedPredicate(Q.even, Symbol('z'))")
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_rust.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (S, pi, oo, symbols, Rational, Integer,
2
+ GoldenRatio, EulerGamma, Catalan, Lambda, Dummy,
3
+ Eq, Ne, Le, Lt, Gt, Ge, Mod)
4
+ from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
5
+ sign, floor)
6
+ from sympy.logic import ITE
7
+ from sympy.testing.pytest import raises
8
+ from sympy.utilities.lambdify import implemented_function
9
+ from sympy.tensor import IndexedBase, Idx
10
+ from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix
11
+
12
+ from sympy.printing.codeprinter import rust_code
13
+
14
+ x, y, z = symbols('x,y,z', integer=False, real=True)
15
+ k, m, n = symbols('k,m,n', integer=True)
16
+
17
+
18
+ def test_Integer():
19
+ assert rust_code(Integer(42)) == "42"
20
+ assert rust_code(Integer(-56)) == "-56"
21
+
22
+
23
+ def test_Relational():
24
+ assert rust_code(Eq(x, y)) == "x == y"
25
+ assert rust_code(Ne(x, y)) == "x != y"
26
+ assert rust_code(Le(x, y)) == "x <= y"
27
+ assert rust_code(Lt(x, y)) == "x < y"
28
+ assert rust_code(Gt(x, y)) == "x > y"
29
+ assert rust_code(Ge(x, y)) == "x >= y"
30
+
31
+
32
+ def test_Rational():
33
+ assert rust_code(Rational(3, 7)) == "3_f64/7.0"
34
+ assert rust_code(Rational(18, 9)) == "2"
35
+ assert rust_code(Rational(3, -7)) == "-3_f64/7.0"
36
+ assert rust_code(Rational(-3, -7)) == "3_f64/7.0"
37
+ assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0"
38
+ assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x"
39
+
40
+
41
+ def test_basic_ops():
42
+ assert rust_code(x + y) == "x + y"
43
+ assert rust_code(x - y) == "x - y"
44
+ assert rust_code(x * y) == "x*y"
45
+ assert rust_code(x / y) == "x*y.recip()"
46
+ assert rust_code(-x) == "-x"
47
+ assert rust_code(2 * x) == "2.0*x"
48
+ assert rust_code(y + 2) == "y + 2.0"
49
+ assert rust_code(x + n) == "n as f64 + x"
50
+
51
+ def test_printmethod():
52
+ class fabs(Abs):
53
+ def _rust_code(self, printer):
54
+ return "%s.fabs()" % printer._print(self.args[0])
55
+ assert rust_code(fabs(x)) == "x.fabs()"
56
+ a = MatrixSymbol("a", 1, 3)
57
+ assert rust_code(a[0,0]) == 'a[0]'
58
+
59
+
60
+ def test_Functions():
61
+ assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())"
62
+ assert rust_code(abs(x)) == "x.abs()"
63
+ assert rust_code(ceiling(x)) == "x.ceil()"
64
+ assert rust_code(floor(x)) == "x.floor()"
65
+
66
+ # Automatic rewrite
67
+ assert rust_code(Mod(x, 3)) == 'x - 3.0*((1_f64/3.0)*x).floor()'
68
+
69
+
70
+ def test_Pow():
71
+ assert rust_code(1/x) == "x.recip()"
72
+ assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()"
73
+ assert rust_code(sqrt(x)) == "x.sqrt()"
74
+ assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()"
75
+
76
+ assert rust_code(1/sqrt(x)) == "x.sqrt().recip()"
77
+ assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()"
78
+
79
+ assert rust_code(1/pi) == "PI.recip()"
80
+ assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()"
81
+ assert rust_code(pi**-0.5) == "PI.sqrt().recip()"
82
+
83
+ assert rust_code(x**Rational(1, 3)) == "x.cbrt()"
84
+ assert rust_code(2**x) == "x.exp2()"
85
+ assert rust_code(exp(x)) == "x.exp()"
86
+ assert rust_code(x**3) == "x.powi(3)"
87
+ assert rust_code(x**(y**3)) == "x.powf(y.powi(3))"
88
+ assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)"
89
+
90
+ g = implemented_function('g', Lambda(x, 2*x))
91
+ assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
92
+ "(3.5*2.0*x).powf(-x + y.powf(x))/(x.powi(2) + y)"
93
+ _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1),
94
+ (lambda base, exp: not exp.is_integer, "pow", 1)]
95
+ assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)'
96
+ assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)'
97
+
98
+
99
+ def test_constants():
100
+ assert rust_code(pi) == "PI"
101
+ assert rust_code(oo) == "INFINITY"
102
+ assert rust_code(S.Infinity) == "INFINITY"
103
+ assert rust_code(-oo) == "NEG_INFINITY"
104
+ assert rust_code(S.NegativeInfinity) == "NEG_INFINITY"
105
+ assert rust_code(S.NaN) == "NAN"
106
+ assert rust_code(exp(1)) == "E"
107
+ assert rust_code(S.Exp1) == "E"
108
+
109
+
110
+ def test_constants_other():
111
+ assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2.0*GoldenRatio" % GoldenRatio.evalf(17)
112
+ assert rust_code(
113
+ 2*Catalan) == "const Catalan: f64 = %s;\n2.0*Catalan" % Catalan.evalf(17)
114
+ assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2.0*EulerGamma" % EulerGamma.evalf(17)
115
+
116
+
117
+ def test_boolean():
118
+ assert rust_code(True) == "true"
119
+ assert rust_code(S.true) == "true"
120
+ assert rust_code(False) == "false"
121
+ assert rust_code(S.false) == "false"
122
+ assert rust_code(k & m) == "k && m"
123
+ assert rust_code(k | m) == "k || m"
124
+ assert rust_code(~k) == "!k"
125
+ assert rust_code(k & m & n) == "k && m && n"
126
+ assert rust_code(k | m | n) == "k || m || n"
127
+ assert rust_code((k & m) | n) == "n || k && m"
128
+ assert rust_code((k | m) & n) == "n && (k || m)"
129
+
130
+
131
+ def test_Piecewise():
132
+ expr = Piecewise((x, x < 1), (x + 2, True))
133
+ assert rust_code(expr) == (
134
+ "if (x < 1.0) {\n"
135
+ " x\n"
136
+ "} else {\n"
137
+ " x + 2.0\n"
138
+ "}")
139
+ assert rust_code(expr, assign_to="r") == (
140
+ "r = if (x < 1.0) {\n"
141
+ " x\n"
142
+ "} else {\n"
143
+ " x + 2.0\n"
144
+ "};")
145
+ assert rust_code(expr, assign_to="r", inline=True) == (
146
+ "r = if (x < 1.0) { x } else { x + 2.0 };")
147
+ expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
148
+ assert rust_code(expr, inline=True) == (
149
+ "if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }")
150
+ assert rust_code(expr, assign_to="r", inline=True) == (
151
+ "r = if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 };")
152
+ assert rust_code(expr, assign_to="r") == (
153
+ "r = if (x < 1.0) {\n"
154
+ " x\n"
155
+ "} else if (x < 5.0) {\n"
156
+ " x + 1.0\n"
157
+ "} else {\n"
158
+ " x + 2.0\n"
159
+ "};")
160
+ expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
161
+ assert rust_code(expr, inline=True) == (
162
+ "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }")
163
+ expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42
164
+ assert rust_code(expr, inline=True) == (
165
+ "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 } - 42.0")
166
+ # Check that Piecewise without a True (default) condition error
167
+ expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
168
+ raises(ValueError, lambda: rust_code(expr))
169
+
170
+
171
+ def test_dereference_printing():
172
+ expr = x + y + sin(z) + z
173
+ assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()"
174
+
175
+
176
+ def test_sign():
177
+ expr = sign(x) * y
178
+ assert rust_code(expr) == "y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64"
179
+ assert rust_code(expr, assign_to='r') == "r = y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64;"
180
+
181
+ expr = sign(x + y) + 42
182
+ assert rust_code(expr) == "(if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42"
183
+ assert rust_code(expr, assign_to='r') == "r = (if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42;"
184
+
185
+ expr = sign(cos(x))
186
+ assert rust_code(expr) == "(if (x.cos() == 0.0) { 0.0 } else { (x.cos()).signum() })"
187
+
188
+
189
+ def test_reserved_words():
190
+
191
+ x, y = symbols("x if")
192
+
193
+ expr = sin(y)
194
+ assert rust_code(expr) == "if_.sin()"
195
+ assert rust_code(expr, dereference=[y]) == "(*if_).sin()"
196
+ assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()"
197
+
198
+ with raises(ValueError):
199
+ rust_code(expr, error_on_reserved=True)
200
+
201
+
202
+ def test_ITE():
203
+ ekpr = ITE(k < 1, m, n)
204
+ assert rust_code(ekpr) == (
205
+ "if (k < 1) {\n"
206
+ " m\n"
207
+ "} else {\n"
208
+ " n\n"
209
+ "}")
210
+
211
+
212
+ def test_Indexed():
213
+ n, m, o = symbols('n m o', integer=True)
214
+ i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
215
+
216
+ x = IndexedBase('x')[j]
217
+ assert rust_code(x) == "x[j]"
218
+
219
+ A = IndexedBase('A')[i, j]
220
+ assert rust_code(A) == "A[m*i + j]"
221
+
222
+ B = IndexedBase('B')[i, j, k]
223
+ assert rust_code(B) == "B[m*o*i + o*j + k]"
224
+
225
+
226
+ def test_dummy_loops():
227
+ i, m = symbols('i m', integer=True, cls=Dummy)
228
+ x = IndexedBase('x')
229
+ y = IndexedBase('y')
230
+ i = Idx(i, m)
231
+
232
+ assert rust_code(x[i], assign_to=y[i]) == (
233
+ "for i in 0..m {\n"
234
+ " y[i] = x[i];\n"
235
+ "}")
236
+
237
+
238
+ def test_loops():
239
+ m, n = symbols('m n', integer=True)
240
+ A = IndexedBase('A')
241
+ x = IndexedBase('x')
242
+ y = IndexedBase('y')
243
+ z = IndexedBase('z')
244
+ i = Idx('i', m)
245
+ j = Idx('j', n)
246
+
247
+ assert rust_code(A[i, j]*x[j], assign_to=y[i]) == (
248
+ "for i in 0..m {\n"
249
+ " y[i] = 0;\n"
250
+ "}\n"
251
+ "for i in 0..m {\n"
252
+ " for j in 0..n {\n"
253
+ " y[i] = A[n*i + j]*x[j] + y[i];\n"
254
+ " }\n"
255
+ "}")
256
+
257
+ assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == (
258
+ "for i in 0..m {\n"
259
+ " y[i] = x[i] + z[i];\n"
260
+ "}\n"
261
+ "for i in 0..m {\n"
262
+ " for j in 0..n {\n"
263
+ " y[i] = A[n*i + j]*x[j] + y[i];\n"
264
+ " }\n"
265
+ "}")
266
+
267
+
268
+ def test_loops_multiple_contractions():
269
+ n, m, o, p = symbols('n m o p', integer=True)
270
+ a = IndexedBase('a')
271
+ b = IndexedBase('b')
272
+ y = IndexedBase('y')
273
+ i = Idx('i', m)
274
+ j = Idx('j', n)
275
+ k = Idx('k', o)
276
+ l = Idx('l', p)
277
+
278
+ assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == (
279
+ "for i in 0..m {\n"
280
+ " y[i] = 0;\n"
281
+ "}\n"
282
+ "for i in 0..m {\n"
283
+ " for j in 0..n {\n"
284
+ " for k in 0..o {\n"
285
+ " for l in 0..p {\n"
286
+ " y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
287
+ " }\n"
288
+ " }\n"
289
+ " }\n"
290
+ "}")
291
+
292
+
293
+ def test_loops_addfactor():
294
+ m, n, o, p = symbols('m n o p', integer=True)
295
+ a = IndexedBase('a')
296
+ b = IndexedBase('b')
297
+ c = IndexedBase('c')
298
+ y = IndexedBase('y')
299
+ i = Idx('i', m)
300
+ j = Idx('j', n)
301
+ k = Idx('k', o)
302
+ l = Idx('l', p)
303
+
304
+ code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
305
+ assert code == (
306
+ "for i in 0..m {\n"
307
+ " y[i] = 0;\n"
308
+ "}\n"
309
+ "for i in 0..m {\n"
310
+ " for j in 0..n {\n"
311
+ " for k in 0..o {\n"
312
+ " for l in 0..p {\n"
313
+ " y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
314
+ " }\n"
315
+ " }\n"
316
+ " }\n"
317
+ "}")
318
+
319
+
320
+ def test_settings():
321
+ raises(TypeError, lambda: rust_code(sin(x), method="garbage"))
322
+
323
+
324
+ def test_inline_function():
325
+ x = symbols('x')
326
+ g = implemented_function('g', Lambda(x, 2*x))
327
+ assert rust_code(g(x)) == "2*x"
328
+
329
+ g = implemented_function('g', Lambda(x, 2*x/Catalan))
330
+ assert rust_code(g(x)) == (
331
+ "const Catalan: f64 = %s;\n2.0*x/Catalan" % Catalan.evalf(17))
332
+
333
+ A = IndexedBase('A')
334
+ i = Idx('i', symbols('n', integer=True))
335
+ g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
336
+ assert rust_code(g(A[i]), assign_to=A[i]) == (
337
+ "for i in 0..n {\n"
338
+ " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
339
+ "}")
340
+
341
+
342
+ def test_user_functions():
343
+ x = symbols('x', integer=False)
344
+ n = symbols('n', integer=True)
345
+ custom_functions = {
346
+ "ceiling": "ceil",
347
+ "Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)],
348
+ }
349
+ assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()"
350
+ assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)"
351
+ assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)"
352
+
353
+
354
+ def test_matrix():
355
+ assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]'
356
+ with raises(ValueError):
357
+ rust_code(Matrix([[1, 2, 3]]))
358
+
359
+
360
+ def test_sparse_matrix():
361
+ # gh-15791
362
+ with raises(NotImplementedError):
363
+ rust_code(SparseMatrix([[1, 2, 3]]))
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_smtlib.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import itertools
3
+ import re
4
+ import typing
5
+ from enum import Enum
6
+ from typing import Callable
7
+
8
+ import sympy
9
+ from sympy import Add, Implies, sqrt
10
+ from sympy.core import Mul, Pow
11
+ from sympy.core import (S, pi, symbols, Function, Rational, Integer,
12
+ Symbol, Eq, Ne, Le, Lt, Gt, Ge)
13
+ from sympy.functions import Piecewise, exp, sin, cos
14
+ from sympy.assumptions.ask import Q
15
+ from sympy.printing.smtlib import smtlib_code
16
+ from sympy.testing.pytest import raises, Failed
17
+
18
+ x, y, z = symbols('x,y,z')
19
+
20
+
21
+ class _W(Enum):
22
+ DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.IGNORECASE)
23
+ WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.IGNORECASE)
24
+ WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.IGNORECASE)
25
+
26
+
27
+ @contextlib.contextmanager
28
+ def _check_warns(expected: typing.Iterable[_W]):
29
+ warns: typing.List[str] = []
30
+ log_warn = warns.append
31
+ yield log_warn
32
+
33
+ errors = []
34
+ for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)):
35
+ if not e:
36
+ errors += [f"[{i}] Received unexpected warning `{w}`."]
37
+ elif not w:
38
+ errors += [f"[{i}] Did not receive expected warning `{e.name}`."]
39
+ elif not e.value.match(w):
40
+ errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."]
41
+
42
+ if errors: raise Failed('\n'.join(errors))
43
+
44
+
45
+ def test_Integer():
46
+ with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w:
47
+ assert smtlib_code(Integer(67), log_warn=w) == "67"
48
+ assert smtlib_code(Integer(-1), log_warn=w) == "-1"
49
+ with _check_warns([]) as w:
50
+ assert smtlib_code(Integer(67)) == "67"
51
+ assert smtlib_code(Integer(-1)) == "-1"
52
+
53
+
54
+ def test_Rational():
55
+ with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w:
56
+ assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)"
57
+ assert smtlib_code(Rational(18, 9), log_warn=w) == "2"
58
+ assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)"
59
+ assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)"
60
+
61
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w:
62
+ assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)"
63
+ assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \
64
+ "(* (/ 3 7) x)"
65
+
66
+
67
+ def test_Relational():
68
+ with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
69
+ assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
70
+ assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
71
+ assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
72
+ assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
73
+ assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
74
+ assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
75
+
76
+
77
+ def test_AppliedBinaryRelation():
78
+ with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
79
+ assert smtlib_code(Q.eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
80
+ assert smtlib_code(Q.ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
81
+ assert smtlib_code(Q.lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
82
+ assert smtlib_code(Q.le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
83
+ assert smtlib_code(Q.gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
84
+ assert smtlib_code(Q.ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
85
+
86
+ raises(ValueError, lambda: smtlib_code(Q.complex(x), log_warn=w))
87
+
88
+
89
+ def test_AppliedPredicate():
90
+ with _check_warns([_W.DEFAULTING_TO_FLOAT] * 6) as w:
91
+ assert smtlib_code(Q.positive(x), auto_declare=False, log_warn=w) == "(assert (> x 0))"
92
+ assert smtlib_code(Q.negative(x), auto_declare=False, log_warn=w) == "(assert (< x 0))"
93
+ assert smtlib_code(Q.zero(x), auto_declare=False, log_warn=w) == "(assert (= x 0))"
94
+ assert smtlib_code(Q.nonpositive(x), auto_declare=False, log_warn=w) == "(assert (<= x 0))"
95
+ assert smtlib_code(Q.nonnegative(x), auto_declare=False, log_warn=w) == "(assert (>= x 0))"
96
+ assert smtlib_code(Q.nonzero(x), auto_declare=False, log_warn=w) == "(assert (not (= x 0)))"
97
+
98
+ def test_Function():
99
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
100
+ assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))"
101
+
102
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
103
+ assert smtlib_code(
104
+ abs(x),
105
+ symbol_table={x: int, y: bool},
106
+ known_types={int: "INTEGER_TYPE"},
107
+ known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"},
108
+ log_warn=w
109
+ ) == "(declare-const x INTEGER_TYPE)\n" \
110
+ "(ABSOLUTE_VALUE_OF x)"
111
+
112
+ my_fun1 = Function('f1')
113
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
114
+ assert smtlib_code(
115
+ my_fun1(x),
116
+ symbol_table={my_fun1: Callable[[bool], float]},
117
+ log_warn=w
118
+ ) == "(declare-const x Bool)\n" \
119
+ "(declare-fun f1 (Bool) Real)\n" \
120
+ "(f1 x)"
121
+
122
+ with _check_warns([]) as w:
123
+ assert smtlib_code(
124
+ my_fun1(x),
125
+ symbol_table={my_fun1: Callable[[bool], bool]},
126
+ log_warn=w
127
+ ) == "(declare-const x Bool)\n" \
128
+ "(declare-fun f1 (Bool) Bool)\n" \
129
+ "(assert (f1 x))"
130
+
131
+ assert smtlib_code(
132
+ Eq(my_fun1(x, z), y),
133
+ symbol_table={my_fun1: Callable[[int, bool], bool]},
134
+ log_warn=w
135
+ ) == "(declare-const x Int)\n" \
136
+ "(declare-const y Bool)\n" \
137
+ "(declare-const z Bool)\n" \
138
+ "(declare-fun f1 (Int Bool) Bool)\n" \
139
+ "(assert (= (f1 x z) y))"
140
+
141
+ assert smtlib_code(
142
+ Eq(my_fun1(x, z), y),
143
+ symbol_table={my_fun1: Callable[[int, bool], bool]},
144
+ known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
145
+ log_warn=w
146
+ ) == "(declare-const x Int)\n" \
147
+ "(declare-const y Bool)\n" \
148
+ "(declare-const z Bool)\n" \
149
+ "(assert (== (MY_KNOWN_FUN x z) y))"
150
+
151
+ with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w:
152
+ assert smtlib_code(
153
+ Eq(my_fun1(x, z), y),
154
+ known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
155
+ log_warn=w
156
+ ) == "(declare-const x Real)\n" \
157
+ "(declare-const y Real)\n" \
158
+ "(declare-const z Real)\n" \
159
+ "(assert (== (MY_KNOWN_FUN x z) y))"
160
+
161
+
162
+ def test_Pow():
163
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
164
+ assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)"
165
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
166
+ assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))"
167
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
168
+ assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))'
169
+
170
+ a = Symbol('a', integer=True)
171
+ b = Symbol('b', real=True)
172
+ c = Symbol('c')
173
+
174
+ def g(x): return 2 * x
175
+
176
+ # if x=1, y=2, then expr=2.333...
177
+ expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b)
178
+
179
+ with _check_warns([]) as w:
180
+ assert smtlib_code(
181
+ [
182
+ Eq(a < 2, c),
183
+ Eq(b > a, c),
184
+ c & True,
185
+ Eq(expr, 2 + Rational(1, 3))
186
+ ],
187
+ log_warn=w
188
+ ) == '(declare-const a Int)\n' \
189
+ '(declare-const b Real)\n' \
190
+ '(declare-const c Bool)\n' \
191
+ '(assert (= (< a 2) c))\n' \
192
+ '(assert (= (> b a) c))\n' \
193
+ '(assert c)\n' \
194
+ '(assert (= ' \
195
+ '(* (pow (* 7.0 a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \
196
+ '(/ 7 3)' \
197
+ '))'
198
+
199
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
200
+ assert smtlib_code(
201
+ Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False),
202
+ log_warn=w
203
+ ) == '(declare-const b Real)\n' \
204
+ '(declare-const c Real)\n' \
205
+ '(* -2 c (pow (* b b) -1))'
206
+
207
+
208
+ def test_basic_ops():
209
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
210
+ assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)"
211
+
212
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
213
+ assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)"
214
+
215
+ # with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w:
216
+ # todo: implement re-write, currently does '(+ x (* -1 y))' instead
217
+ # assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)"
218
+
219
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
220
+ assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)"
221
+
222
+
223
+ def test_quantifier_extensions():
224
+ from sympy.logic.boolalg import Boolean
225
+ from sympy import Interval, Tuple, sympify
226
+
227
+ # start For-all quantifier class example
228
+ class ForAll(Boolean):
229
+ def _smtlib(self, printer):
230
+ bound_symbol_declarations = [
231
+ printer._s_expr(sym.name, [
232
+ printer._known_types[printer.symbol_table[sym]],
233
+ Interval(start, end)
234
+ ]) for sym, start, end in self.limits
235
+ ]
236
+ return printer._s_expr('forall', [
237
+ printer._s_expr('', bound_symbol_declarations),
238
+ self.function
239
+ ])
240
+
241
+ @property
242
+ def bound_symbols(self):
243
+ return {s for s, _, _ in self.limits}
244
+
245
+ @property
246
+ def free_symbols(self):
247
+ bound_symbol_names = {s.name for s in self.bound_symbols}
248
+ return {
249
+ s for s in self.function.free_symbols
250
+ if s.name not in bound_symbol_names
251
+ }
252
+
253
+ def __new__(cls, *args):
254
+ limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))]
255
+ function = [sympify(a) for a in args if isinstance(a, Boolean)]
256
+ assert len(limits) + len(function) == len(args)
257
+ assert len(function) == 1
258
+ function = function[0]
259
+
260
+ if isinstance(function, ForAll): return ForAll.__new__(
261
+ ForAll, *(limits + function.limits), function.function
262
+ )
263
+ inst = Boolean.__new__(cls)
264
+ inst._args = tuple(limits + [function])
265
+ inst.limits = limits
266
+ inst.function = function
267
+ return inst
268
+
269
+ # end For-All Quantifier class example
270
+
271
+ f = Function('f')
272
+ with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
273
+ assert smtlib_code(
274
+ ForAll((x, -42, +21), Eq(f(x), f(x))),
275
+ symbol_table={f: Callable[[float], float]},
276
+ log_warn=w
277
+ ) == '(assert (forall ( (x Real [-42, 21])) true))'
278
+
279
+ with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w:
280
+ assert smtlib_code(
281
+ ForAll(
282
+ (x, -42, +21), (y, -100, 3),
283
+ Implies(Eq(x, y), Eq(f(x), f(y)))
284
+ ),
285
+ symbol_table={f: Callable[[float], float]},
286
+ log_warn=w
287
+ ) == '(declare-fun f (Real) Real)\n' \
288
+ '(assert (' \
289
+ 'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \
290
+ '(=> (= x y) (= (f x) (f y)))' \
291
+ '))'
292
+
293
+ a = Symbol('a', integer=True)
294
+ b = Symbol('b', real=True)
295
+ c = Symbol('c')
296
+
297
+ with _check_warns([]) as w:
298
+ assert smtlib_code(
299
+ ForAll(
300
+ (a, 2, 100), ForAll(
301
+ (b, 2, 100),
302
+ Implies(a < b, sqrt(a) < b) | c
303
+ )),
304
+ log_warn=w
305
+ ) == '(declare-const c Bool)\n' \
306
+ '(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \
307
+ '(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \
308
+ '))'
309
+
310
+
311
+ def test_mix_number_mult_symbols():
312
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
313
+ assert smtlib_code(
314
+ 1 / pi,
315
+ known_constants={pi: "MY_PI"},
316
+ log_warn=w
317
+ ) == '(pow MY_PI -1)'
318
+
319
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
320
+ assert smtlib_code(
321
+ [
322
+ Eq(pi, 3.14, evaluate=False),
323
+ 1 / pi,
324
+ ],
325
+ known_constants={pi: "MY_PI"},
326
+ log_warn=w
327
+ ) == '(assert (= MY_PI 3.14))\n' \
328
+ '(pow MY_PI -1)'
329
+
330
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
331
+ assert smtlib_code(
332
+ Add(S.Zero, S.One, S.NegativeOne, S.Half,
333
+ S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
334
+ known_constants={
335
+ S.Pi: 'p', S.GoldenRatio: 'g',
336
+ S.Exp1: 'e'
337
+ },
338
+ known_functions={
339
+ Add: 'plus',
340
+ exp: 'exp'
341
+ },
342
+ precision=3,
343
+ log_warn=w
344
+ ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)'
345
+
346
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
347
+ assert smtlib_code(
348
+ Add(S.Zero, S.One, S.NegativeOne, S.Half,
349
+ S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
350
+ known_constants={
351
+ S.Pi: 'p'
352
+ },
353
+ known_functions={
354
+ Add: 'plus',
355
+ exp: 'exp'
356
+ },
357
+ precision=3,
358
+ log_warn=w
359
+ ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)'
360
+
361
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
362
+ assert smtlib_code(
363
+ Add(S.Zero, S.One, S.NegativeOne, S.Half,
364
+ S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
365
+ known_functions={Add: 'plus'},
366
+ precision=3,
367
+ log_warn=w
368
+ ) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)'
369
+
370
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
371
+ assert smtlib_code(
372
+ Add(S.Zero, S.One, S.NegativeOne, S.Half,
373
+ S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
374
+ known_constants={S.Exp1: 'e'},
375
+ known_functions={Add: 'plus'},
376
+ precision=3,
377
+ log_warn=w
378
+ ) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)'
379
+
380
+
381
+ def test_boolean():
382
+ with _check_warns([]) as w:
383
+ assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \
384
+ '(declare-const y Bool)\n' \
385
+ '(assert (and x y))'
386
+ assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \
387
+ '(declare-const y Bool)\n' \
388
+ '(assert (or x y))'
389
+ assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \
390
+ '(assert (not x))'
391
+ assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \
392
+ '(declare-const y Bool)\n' \
393
+ '(declare-const z Bool)\n' \
394
+ '(assert (and x y z))'
395
+
396
+ with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
397
+ assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \
398
+ '(declare-const y Bool)\n' \
399
+ '(declare-const z Real)\n' \
400
+ '(assert (or (> z 3) (and x (not y))))'
401
+
402
+ f = Function('f')
403
+ g = Function('g')
404
+ h = Function('h')
405
+ with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
406
+ assert smtlib_code(
407
+ [Gt(f(x), y),
408
+ Lt(y, g(z))],
409
+ symbol_table={
410
+ f: Callable[[bool], int], g: Callable[[bool], int],
411
+ }, log_warn=w
412
+ ) == '(declare-const x Bool)\n' \
413
+ '(declare-const y Real)\n' \
414
+ '(declare-const z Bool)\n' \
415
+ '(declare-fun f (Bool) Int)\n' \
416
+ '(declare-fun g (Bool) Int)\n' \
417
+ '(assert (> (f x) y))\n' \
418
+ '(assert (< y (g z)))'
419
+
420
+ with _check_warns([]) as w:
421
+ assert smtlib_code(
422
+ [Eq(f(x), y),
423
+ Lt(y, g(z))],
424
+ symbol_table={
425
+ f: Callable[[bool], int], g: Callable[[bool], int],
426
+ }, log_warn=w
427
+ ) == '(declare-const x Bool)\n' \
428
+ '(declare-const y Int)\n' \
429
+ '(declare-const z Bool)\n' \
430
+ '(declare-fun f (Bool) Int)\n' \
431
+ '(declare-fun g (Bool) Int)\n' \
432
+ '(assert (= (f x) y))\n' \
433
+ '(assert (< y (g z)))'
434
+
435
+ with _check_warns([]) as w:
436
+ assert smtlib_code(
437
+ [Eq(f(x), y),
438
+ Eq(g(f(x)), z),
439
+ Eq(h(g(f(x))), x)],
440
+ symbol_table={
441
+ f: Callable[[float], int],
442
+ g: Callable[[int], bool],
443
+ h: Callable[[bool], float]
444
+ },
445
+ log_warn=w
446
+ ) == '(declare-const x Real)\n' \
447
+ '(declare-const y Int)\n' \
448
+ '(declare-const z Bool)\n' \
449
+ '(declare-fun f (Real) Int)\n' \
450
+ '(declare-fun g (Int) Bool)\n' \
451
+ '(declare-fun h (Bool) Real)\n' \
452
+ '(assert (= (f x) y))\n' \
453
+ '(assert (= (g (f x)) z))\n' \
454
+ '(assert (= (h (g (f x))) x))'
455
+
456
+
457
+ # todo: make smtlib_code support arrays
458
+ # def test_containers():
459
+ # assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
460
+ # "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
461
+ # assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
462
+ # assert julia_code([1]) == "Any[1]"
463
+ # assert julia_code((1,)) == "(1,)"
464
+ # assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
465
+ # assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))"
466
+ # # scalar, matrix, empty matrix and empty list
467
+ # assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
468
+
469
+ def test_smtlib_piecewise():
470
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
471
+ assert smtlib_code(
472
+ Piecewise((x, x < 1),
473
+ (x ** 2, True)),
474
+ auto_declare=False,
475
+ log_warn=w
476
+ ) == '(ite (< x 1) x (pow x 2))'
477
+
478
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
479
+ assert smtlib_code(
480
+ Piecewise((x ** 2, x < 1),
481
+ (x ** 3, x < 2),
482
+ (x ** 4, x < 3),
483
+ (x ** 5, True)),
484
+ auto_declare=False,
485
+ log_warn=w
486
+ ) == '(ite (< x 1) (pow x 2) ' \
487
+ '(ite (< x 2) (pow x 3) ' \
488
+ '(ite (< x 3) (pow x 4) ' \
489
+ '(pow x 5))))'
490
+
491
+ # Check that Piecewise without a True (default) condition error
492
+ expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
493
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
494
+ raises(AssertionError, lambda: smtlib_code(expr, log_warn=w))
495
+
496
+
497
+ def test_smtlib_piecewise_times_const():
498
+ pw = Piecewise((x, x < 1), (x ** 2, True))
499
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
500
+ assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))'
501
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
502
+ assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))'
503
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
504
+ assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))'
505
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
506
+ assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))'
507
+
508
+
509
+ # todo: make smtlib_code support arrays / matrices ?
510
+ # def test_smtlib_matrix_assign_to():
511
+ # A = Matrix([[1, 2, 3]])
512
+ # assert smtlib_code(A, assign_to='a') == "a = [1 2 3]"
513
+ # A = Matrix([[1, 2], [3, 4]])
514
+ # assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]"
515
+
516
+ # def test_julia_matrix_1x1():
517
+ # A = Matrix([[3]])
518
+ # B = MatrixSymbol('B', 1, 1)
519
+ # C = MatrixSymbol('C', 1, 2)
520
+ # assert julia_code(A, assign_to=B) == "B = [3]"
521
+ # raises(ValueError, lambda: julia_code(A, assign_to=C))
522
+
523
+ # def test_julia_matrix_elements():
524
+ # A = Matrix([[x, 2, x * y]])
525
+ # assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
526
+ # A = MatrixSymbol('AA', 1, 3)
527
+ # assert julia_code(A) == "AA"
528
+ # assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \
529
+ # "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
530
+ # assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
531
+
532
+ def test_smtlib_boolean():
533
+ with _check_warns([]) as w:
534
+ assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true'
535
+ assert smtlib_code(True, log_warn=w) == '(assert true)'
536
+ assert smtlib_code(S.true, log_warn=w) == '(assert true)'
537
+ assert smtlib_code(S.false, log_warn=w) == '(assert false)'
538
+ assert smtlib_code(False, log_warn=w) == '(assert false)'
539
+ assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false'
540
+
541
+
542
+ def test_not_supported():
543
+ f = Function('f')
544
+ with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
545
+ raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w))
546
+ with _check_warns([_W.WILL_NOT_ASSERT]) as w:
547
+ raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w))
548
+
549
+
550
+ def test_Float():
551
+ assert smtlib_code(0.0) == "0.0"
552
+ assert smtlib_code(0.000000000000000003) == '(* 3.0 (pow 10 -18))'
553
+ assert smtlib_code(5.3) == "5.3"
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_str.py ADDED
@@ -0,0 +1,1206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy import MatAdd
2
+ from sympy.algebras.quaternion import Quaternion
3
+ from sympy.assumptions.ask import Q
4
+ from sympy.calculus.accumulationbounds import AccumBounds
5
+ from sympy.combinatorics.partitions import Partition
6
+ from sympy.concrete.summations import (Sum, summation)
7
+ from sympy.core.add import Add
8
+ from sympy.core.containers import (Dict, Tuple)
9
+ from sympy.core.expr import UnevaluatedExpr, Expr
10
+ from sympy.core.function import (Derivative, Function, Lambda, Subs, WildFunction)
11
+ from sympy.core.mul import Mul
12
+ from sympy.core import (Catalan, EulerGamma, GoldenRatio, TribonacciConstant)
13
+ from sympy.core.numbers import (E, Float, I, Integer, Rational, nan, oo, pi, zoo)
14
+ from sympy.core.parameters import _exp_is_pow
15
+ from sympy.core.power import Pow
16
+ from sympy.core.relational import (Eq, Rel, Ne)
17
+ from sympy.core.singleton import S
18
+ from sympy.core.symbol import (Dummy, Symbol, Wild, symbols)
19
+ from sympy.functions.combinatorial.factorials import (factorial, factorial2, subfactorial)
20
+ from sympy.functions.elementary.complexes import Abs
21
+ from sympy.functions.elementary.exponential import exp
22
+ from sympy.functions.elementary.miscellaneous import sqrt
23
+ from sympy.functions.elementary.trigonometric import (cos, sin)
24
+ from sympy.functions.special.delta_functions import Heaviside
25
+ from sympy.functions.special.zeta_functions import zeta
26
+ from sympy.integrals.integrals import Integral
27
+ from sympy.logic.boolalg import (Equivalent, false, true, Xor)
28
+ from sympy.matrices.dense import Matrix
29
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
30
+ from sympy.matrices.expressions import Identity
31
+ from sympy.matrices.expressions.slice import MatrixSlice
32
+ from sympy.matrices import SparseMatrix
33
+ from sympy.polys.polytools import factor
34
+ from sympy.series.limits import Limit
35
+ from sympy.series.order import O
36
+ from sympy.sets.sets import (Complement, FiniteSet, Interval, SymmetricDifference)
37
+ from sympy.stats import (Covariance, Expectation, Probability, Variance)
38
+ from sympy.stats.rv import RandomSymbol
39
+ from sympy.external import import_module
40
+ from sympy.physics.control.lti import TransferFunction, Series, Parallel, \
41
+ Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback
42
+ from sympy.physics.units import second, joule
43
+ from sympy.polys import (Poly, rootof, RootSum, groebner, ring, field, ZZ, QQ,
44
+ ZZ_I, QQ_I, lex, grlex)
45
+ from sympy.geometry import Point, Circle, Polygon, Ellipse, Triangle
46
+ from sympy.tensor import NDimArray
47
+ from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement
48
+
49
+ from sympy.testing.pytest import raises, warns_deprecated_sympy
50
+
51
+ from sympy.printing import sstr, sstrrepr, StrPrinter
52
+ from sympy.physics.quantum.trace import Tr
53
+
54
+ x, y, z, w, t = symbols('x,y,z,w,t')
55
+ d = Dummy('d')
56
+
57
+
58
+ def test_printmethod():
59
+ class R(Abs):
60
+ def _sympystr(self, printer):
61
+ return "foo(%s)" % printer._print(self.args[0])
62
+ assert sstr(R(x)) == "foo(x)"
63
+
64
+ class R(Abs):
65
+ def _sympystr(self, printer):
66
+ return "foo"
67
+ assert sstr(R(x)) == "foo"
68
+
69
+
70
+ def test_Abs():
71
+ assert str(Abs(x)) == "Abs(x)"
72
+ assert str(Abs(Rational(1, 6))) == "1/6"
73
+ assert str(Abs(Rational(-1, 6))) == "1/6"
74
+
75
+
76
+ def test_Add():
77
+ assert str(x + y) == "x + y"
78
+ assert str(x + 1) == "x + 1"
79
+ assert str(x + x**2) == "x**2 + x"
80
+ assert str(Add(0, 1, evaluate=False)) == "0 + 1"
81
+ assert str(Add(0, 0, 1, evaluate=False)) == "0 + 0 + 1"
82
+ assert str(1.0*x) == "1.0*x"
83
+ assert str(5 + x + y + x*y + x**2 + y**2) == "x**2 + x*y + x + y**2 + y + 5"
84
+ assert str(1 + x + x**2/2 + x**3/3) == "x**3/3 + x**2/2 + x + 1"
85
+ assert str(2*x - 7*x**2 + 2 + 3*y) == "-7*x**2 + 2*x + 3*y + 2"
86
+ assert str(x - y) == "x - y"
87
+ assert str(2 - x) == "2 - x"
88
+ assert str(x - 2) == "x - 2"
89
+ assert str(x - y - z - w) == "-w + x - y - z"
90
+ assert str(x - z*y**2*z*w) == "-w*y**2*z**2 + x"
91
+ assert str(x - 1*y*x*y) == "-x*y**2 + x"
92
+ assert str(sin(x).series(x, 0, 15)) == "x - x**3/6 + x**5/120 - x**7/5040 + x**9/362880 - x**11/39916800 + x**13/6227020800 + O(x**15)"
93
+ assert str(Add(Add(-w, x, evaluate=False), Add(-y, z, evaluate=False), evaluate=False)) == "(-w + x) + (-y + z)"
94
+ assert str(Add(Add(-x, -y, evaluate=False), -z, evaluate=False)) == "-z + (-x - y)"
95
+ assert str(Add(Add(Add(-x, -y, evaluate=False), -z, evaluate=False), -t, evaluate=False)) == "-t + (-z + (-x - y))"
96
+
97
+
98
+ def test_Catalan():
99
+ assert str(Catalan) == "Catalan"
100
+
101
+
102
+ def test_ComplexInfinity():
103
+ assert str(zoo) == "zoo"
104
+
105
+
106
+ def test_Derivative():
107
+ assert str(Derivative(x, y)) == "Derivative(x, y)"
108
+ assert str(Derivative(x**2, x, evaluate=False)) == "Derivative(x**2, x)"
109
+ assert str(Derivative(
110
+ x**2/y, x, y, evaluate=False)) == "Derivative(x**2/y, x, y)"
111
+
112
+
113
+ def test_dict():
114
+ assert str({1: 1 + x}) == sstr({1: 1 + x}) == "{1: x + 1}"
115
+ assert str({1: x**2, 2: y*x}) in ("{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}")
116
+ assert sstr({1: x**2, 2: y*x}) == "{1: x**2, 2: x*y}"
117
+
118
+
119
+ def test_Dict():
120
+ assert str(Dict({1: 1 + x})) == sstr({1: 1 + x}) == "{1: x + 1}"
121
+ assert str(Dict({1: x**2, 2: y*x})) in (
122
+ "{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}")
123
+ assert sstr(Dict({1: x**2, 2: y*x})) == "{1: x**2, 2: x*y}"
124
+
125
+
126
+ def test_Dummy():
127
+ assert str(d) == "_d"
128
+ assert str(d + x) == "_d + x"
129
+
130
+
131
+ def test_EulerGamma():
132
+ assert str(EulerGamma) == "EulerGamma"
133
+
134
+
135
+ def test_Exp():
136
+ assert str(E) == "E"
137
+ with _exp_is_pow(True):
138
+ assert str(exp(x)) == "E**x"
139
+
140
+
141
+ def test_factorial():
142
+ n = Symbol('n', integer=True)
143
+ assert str(factorial(-2)) == "zoo"
144
+ assert str(factorial(0)) == "1"
145
+ assert str(factorial(7)) == "5040"
146
+ assert str(factorial(n)) == "factorial(n)"
147
+ assert str(factorial(2*n)) == "factorial(2*n)"
148
+ assert str(factorial(factorial(n))) == 'factorial(factorial(n))'
149
+ assert str(factorial(factorial2(n))) == 'factorial(factorial2(n))'
150
+ assert str(factorial2(factorial(n))) == 'factorial2(factorial(n))'
151
+ assert str(factorial2(factorial2(n))) == 'factorial2(factorial2(n))'
152
+ assert str(subfactorial(3)) == "2"
153
+ assert str(subfactorial(n)) == "subfactorial(n)"
154
+ assert str(subfactorial(2*n)) == "subfactorial(2*n)"
155
+
156
+
157
+ def test_Function():
158
+ f = Function('f')
159
+ fx = f(x)
160
+ w = WildFunction('w')
161
+ assert str(f) == "f"
162
+ assert str(fx) == "f(x)"
163
+ assert str(w) == "w_"
164
+
165
+
166
+ def test_Geometry():
167
+ assert sstr(Point(0, 0)) == 'Point2D(0, 0)'
168
+ assert sstr(Circle(Point(0, 0), 3)) == 'Circle(Point2D(0, 0), 3)'
169
+ assert sstr(Ellipse(Point(1, 2), 3, 4)) == 'Ellipse(Point2D(1, 2), 3, 4)'
170
+ assert sstr(Triangle(Point(1, 1), Point(7, 8), Point(0, -1))) == \
171
+ 'Triangle(Point2D(1, 1), Point2D(7, 8), Point2D(0, -1))'
172
+ assert sstr(Polygon(Point(5, 6), Point(-2, -3), Point(0, 0), Point(4, 7))) == \
173
+ 'Polygon(Point2D(5, 6), Point2D(-2, -3), Point2D(0, 0), Point2D(4, 7))'
174
+ assert sstr(Triangle(Point(0, 0), Point(1, 0), Point(0, 1)), sympy_integers=True) == \
175
+ 'Triangle(Point2D(S(0), S(0)), Point2D(S(1), S(0)), Point2D(S(0), S(1)))'
176
+ assert sstr(Ellipse(Point(1, 2), 3, 4), sympy_integers=True) == \
177
+ 'Ellipse(Point2D(S(1), S(2)), S(3), S(4))'
178
+
179
+
180
+ def test_GoldenRatio():
181
+ assert str(GoldenRatio) == "GoldenRatio"
182
+
183
+
184
+ def test_Heaviside():
185
+ assert str(Heaviside(x)) == str(Heaviside(x, S.Half)) == "Heaviside(x)"
186
+ assert str(Heaviside(x, 1)) == "Heaviside(x, 1)"
187
+
188
+
189
+ def test_TribonacciConstant():
190
+ assert str(TribonacciConstant) == "TribonacciConstant"
191
+
192
+
193
+ def test_ImaginaryUnit():
194
+ assert str(I) == "I"
195
+
196
+
197
+ def test_Infinity():
198
+ assert str(oo) == "oo"
199
+ assert str(oo*I) == "oo*I"
200
+
201
+
202
+ def test_Integer():
203
+ assert str(Integer(-1)) == "-1"
204
+ assert str(Integer(1)) == "1"
205
+ assert str(Integer(-3)) == "-3"
206
+ assert str(Integer(0)) == "0"
207
+ assert str(Integer(25)) == "25"
208
+
209
+
210
+ def test_Integral():
211
+ assert str(Integral(sin(x), y)) == "Integral(sin(x), y)"
212
+ assert str(Integral(sin(x), (y, 0, 1))) == "Integral(sin(x), (y, 0, 1))"
213
+
214
+
215
+ def test_Interval():
216
+ n = (S.NegativeInfinity, 1, 2, S.Infinity)
217
+ for i in range(len(n)):
218
+ for j in range(i + 1, len(n)):
219
+ for l in (True, False):
220
+ for r in (True, False):
221
+ ival = Interval(n[i], n[j], l, r)
222
+ assert S(str(ival)) == ival
223
+
224
+
225
+ def test_AccumBounds():
226
+ a = Symbol('a', real=True)
227
+ assert str(AccumBounds(0, a)) == "AccumBounds(0, a)"
228
+ assert str(AccumBounds(0, 1)) == "AccumBounds(0, 1)"
229
+
230
+
231
+ def test_Lambda():
232
+ assert str(Lambda(d, d**2)) == "Lambda(_d, _d**2)"
233
+ # issue 2908
234
+ assert str(Lambda((), 1)) == "Lambda((), 1)"
235
+ assert str(Lambda((), x)) == "Lambda((), x)"
236
+ assert str(Lambda((x, y), x+y)) == "Lambda((x, y), x + y)"
237
+ assert str(Lambda(((x, y),), x+y)) == "Lambda(((x, y),), x + y)"
238
+
239
+
240
+ def test_Limit():
241
+ assert str(Limit(sin(x)/x, x, y)) == "Limit(sin(x)/x, x, y, dir='+')"
242
+ assert str(Limit(1/x, x, 0)) == "Limit(1/x, x, 0, dir='+')"
243
+ assert str(
244
+ Limit(sin(x)/x, x, y, dir="-")) == "Limit(sin(x)/x, x, y, dir='-')"
245
+
246
+
247
+ def test_list():
248
+ assert str([x]) == sstr([x]) == "[x]"
249
+ assert str([x**2, x*y + 1]) == sstr([x**2, x*y + 1]) == "[x**2, x*y + 1]"
250
+ assert str([x**2, [y + x]]) == sstr([x**2, [y + x]]) == "[x**2, [x + y]]"
251
+
252
+
253
+ def test_Matrix_str():
254
+ M = Matrix([[x**+1, 1], [y, x + y]])
255
+ assert str(M) == "Matrix([[x, 1], [y, x + y]])"
256
+ assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])"
257
+ M = Matrix([[1]])
258
+ assert str(M) == sstr(M) == "Matrix([[1]])"
259
+ M = Matrix([[1, 2]])
260
+ assert str(M) == sstr(M) == "Matrix([[1, 2]])"
261
+ M = Matrix()
262
+ assert str(M) == sstr(M) == "Matrix(0, 0, [])"
263
+ M = Matrix(0, 1, lambda i, j: 0)
264
+ assert str(M) == sstr(M) == "Matrix(0, 1, [])"
265
+
266
+
267
+ def test_Mul():
268
+ assert str(x/y) == "x/y"
269
+ assert str(y/x) == "y/x"
270
+ assert str(x/y/z) == "x/(y*z)"
271
+ assert str((x + 1)/(y + 2)) == "(x + 1)/(y + 2)"
272
+ assert str(2*x/3) == '2*x/3'
273
+ assert str(-2*x/3) == '-2*x/3'
274
+ assert str(-1.0*x) == '-1.0*x'
275
+ assert str(1.0*x) == '1.0*x'
276
+ assert str(Mul(0, 1, evaluate=False)) == '0*1'
277
+ assert str(Mul(1, 0, evaluate=False)) == '1*0'
278
+ assert str(Mul(1, 1, evaluate=False)) == '1*1'
279
+ assert str(Mul(1, 1, 1, evaluate=False)) == '1*1*1'
280
+ assert str(Mul(1, 2, evaluate=False)) == '1*2'
281
+ assert str(Mul(1, S.Half, evaluate=False)) == '1*(1/2)'
282
+ assert str(Mul(1, 1, S.Half, evaluate=False)) == '1*1*(1/2)'
283
+ assert str(Mul(1, 1, 2, 3, x, evaluate=False)) == '1*1*2*3*x'
284
+ assert str(Mul(1, -1, evaluate=False)) == '1*(-1)'
285
+ assert str(Mul(-1, 1, evaluate=False)) == '-1*1'
286
+ assert str(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == '4*3*2*1*0*y*x'
287
+ assert str(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == '4*3*2*(z + 1)*0*y*x'
288
+ assert str(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == '(2/3)*(5/7)'
289
+ # For issue 14160
290
+ assert str(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
291
+ evaluate=False)) == '-2*x/(y*y)'
292
+ # issue 21537
293
+ assert str(Mul(x, Pow(1/y, -1, evaluate=False), evaluate=False)) == 'x/(1/y)'
294
+
295
+ # Issue 24108
296
+ from sympy.core.parameters import evaluate
297
+ with evaluate(False):
298
+ assert str(Mul(Pow(Integer(2), Integer(-1)), Add(Integer(-1), Mul(Integer(-1), Integer(1))))) == "(-1 - 1*1)/2"
299
+
300
+ class CustomClass1(Expr):
301
+ is_commutative = True
302
+
303
+ class CustomClass2(Expr):
304
+ is_commutative = True
305
+ cc1 = CustomClass1()
306
+ cc2 = CustomClass2()
307
+ assert str(Rational(2)*cc1) == '2*CustomClass1()'
308
+ assert str(cc1*Rational(2)) == '2*CustomClass1()'
309
+ assert str(cc1*Float("1.5")) == '1.5*CustomClass1()'
310
+ assert str(cc2*Rational(2)) == '2*CustomClass2()'
311
+ assert str(cc2*Rational(2)*cc1) == '2*CustomClass1()*CustomClass2()'
312
+ assert str(cc1*Rational(2)*cc2) == '2*CustomClass1()*CustomClass2()'
313
+
314
+
315
+ def test_NaN():
316
+ assert str(nan) == "nan"
317
+
318
+
319
+ def test_NegativeInfinity():
320
+ assert str(-oo) == "-oo"
321
+
322
+ def test_Order():
323
+ assert str(O(x)) == "O(x)"
324
+ assert str(O(x**2)) == "O(x**2)"
325
+ assert str(O(x*y)) == "O(x*y, x, y)"
326
+ assert str(O(x, x)) == "O(x)"
327
+ assert str(O(x, (x, 0))) == "O(x)"
328
+ assert str(O(x, (x, oo))) == "O(x, (x, oo))"
329
+ assert str(O(x, x, y)) == "O(x, x, y)"
330
+ assert str(O(x, x, y)) == "O(x, x, y)"
331
+ assert str(O(x, (x, oo), (y, oo))) == "O(x, (x, oo), (y, oo))"
332
+
333
+
334
+ def test_Permutation_Cycle():
335
+ from sympy.combinatorics import Permutation, Cycle
336
+
337
+ # general principle: economically, canonically show all moved elements
338
+ # and the size of the permutation.
339
+
340
+ for p, s in [
341
+ (Cycle(),
342
+ '()'),
343
+ (Cycle(2),
344
+ '(2)'),
345
+ (Cycle(2, 1),
346
+ '(1 2)'),
347
+ (Cycle(1, 2)(5)(6, 7)(10),
348
+ '(1 2)(6 7)(10)'),
349
+ (Cycle(3, 4)(1, 2)(3, 4),
350
+ '(1 2)(4)'),
351
+ ]:
352
+ assert sstr(p) == s
353
+
354
+ for p, s in [
355
+ (Permutation([]),
356
+ 'Permutation([])'),
357
+ (Permutation([], size=1),
358
+ 'Permutation([0])'),
359
+ (Permutation([], size=2),
360
+ 'Permutation([0, 1])'),
361
+ (Permutation([], size=10),
362
+ 'Permutation([], size=10)'),
363
+ (Permutation([1, 0, 2]),
364
+ 'Permutation([1, 0, 2])'),
365
+ (Permutation([1, 0, 2, 3, 4, 5]),
366
+ 'Permutation([1, 0], size=6)'),
367
+ (Permutation([1, 0, 2, 3, 4, 5], size=10),
368
+ 'Permutation([1, 0], size=10)'),
369
+ ]:
370
+ assert sstr(p, perm_cyclic=False) == s
371
+
372
+ for p, s in [
373
+ (Permutation([]),
374
+ '()'),
375
+ (Permutation([], size=1),
376
+ '(0)'),
377
+ (Permutation([], size=2),
378
+ '(1)'),
379
+ (Permutation([], size=10),
380
+ '(9)'),
381
+ (Permutation([1, 0, 2]),
382
+ '(2)(0 1)'),
383
+ (Permutation([1, 0, 2, 3, 4, 5]),
384
+ '(5)(0 1)'),
385
+ (Permutation([1, 0, 2, 3, 4, 5], size=10),
386
+ '(9)(0 1)'),
387
+ (Permutation([0, 1, 3, 2, 4, 5], size=10),
388
+ '(9)(2 3)'),
389
+ ]:
390
+ assert sstr(p) == s
391
+
392
+
393
+ with warns_deprecated_sympy():
394
+ old_print_cyclic = Permutation.print_cyclic
395
+ Permutation.print_cyclic = False
396
+ assert sstr(Permutation([1, 0, 2])) == 'Permutation([1, 0, 2])'
397
+ Permutation.print_cyclic = old_print_cyclic
398
+
399
+ def test_Pi():
400
+ assert str(pi) == "pi"
401
+
402
+
403
+ def test_Poly():
404
+ assert str(Poly(0, x)) == "Poly(0, x, domain='ZZ')"
405
+ assert str(Poly(1, x)) == "Poly(1, x, domain='ZZ')"
406
+ assert str(Poly(x, x)) == "Poly(x, x, domain='ZZ')"
407
+
408
+ assert str(Poly(2*x + 1, x)) == "Poly(2*x + 1, x, domain='ZZ')"
409
+ assert str(Poly(2*x - 1, x)) == "Poly(2*x - 1, x, domain='ZZ')"
410
+
411
+ assert str(Poly(-1, x)) == "Poly(-1, x, domain='ZZ')"
412
+ assert str(Poly(-x, x)) == "Poly(-x, x, domain='ZZ')"
413
+
414
+ assert str(Poly(-2*x + 1, x)) == "Poly(-2*x + 1, x, domain='ZZ')"
415
+ assert str(Poly(-2*x - 1, x)) == "Poly(-2*x - 1, x, domain='ZZ')"
416
+
417
+ assert str(Poly(x - 1, x)) == "Poly(x - 1, x, domain='ZZ')"
418
+ assert str(Poly(2*x + x**5, x)) == "Poly(x**5 + 2*x, x, domain='ZZ')"
419
+
420
+ assert str(Poly(3**(2*x), 3**x)) == "Poly((3**x)**2, 3**x, domain='ZZ')"
421
+ assert str(Poly((x**2)**x)) == "Poly(((x**2)**x), (x**2)**x, domain='ZZ')"
422
+
423
+ assert str(Poly((x + y)**3, (x + y), expand=False)
424
+ ) == "Poly((x + y)**3, x + y, domain='ZZ')"
425
+ assert str(Poly((x - 1)**2, (x - 1), expand=False)
426
+ ) == "Poly((x - 1)**2, x - 1, domain='ZZ')"
427
+
428
+ assert str(
429
+ Poly(x**2 + 1 + y, x)) == "Poly(x**2 + y + 1, x, domain='ZZ[y]')"
430
+ assert str(
431
+ Poly(x**2 - 1 + y, x)) == "Poly(x**2 + y - 1, x, domain='ZZ[y]')"
432
+
433
+ assert str(Poly(x**2 + I*x, x)) == "Poly(x**2 + I*x, x, domain='ZZ_I')"
434
+ assert str(Poly(x**2 - I*x, x)) == "Poly(x**2 - I*x, x, domain='ZZ_I')"
435
+
436
+ assert str(Poly(-x*y*z + x*y - 1, x, y, z)
437
+ ) == "Poly(-x*y*z + x*y - 1, x, y, z, domain='ZZ')"
438
+ assert str(Poly(-w*x**21*y**7*z + (1 + w)*z**3 - 2*x*z + 1, x, y, z)) == \
439
+ "Poly(-w*x**21*y**7*z - 2*x*z + (w + 1)*z**3 + 1, x, y, z, domain='ZZ[w]')"
440
+
441
+ assert str(Poly(x**2 + 1, x, modulus=2)) == "Poly(x**2 + 1, x, modulus=2)"
442
+ assert str(Poly(2*x**2 + 3*x + 4, x, modulus=17)) == "Poly(2*x**2 + 3*x + 4, x, modulus=17)"
443
+
444
+
445
+ def test_PolyRing():
446
+ assert str(ring("x", ZZ, lex)[0]) == "Polynomial ring in x over ZZ with lex order"
447
+ assert str(ring("x,y", QQ, grlex)[0]) == "Polynomial ring in x, y over QQ with grlex order"
448
+ assert str(ring("x,y,z", ZZ["t"], lex)[0]) == "Polynomial ring in x, y, z over ZZ[t] with lex order"
449
+
450
+
451
+ def test_FracField():
452
+ assert str(field("x", ZZ, lex)[0]) == "Rational function field in x over ZZ with lex order"
453
+ assert str(field("x,y", QQ, grlex)[0]) == "Rational function field in x, y over QQ with grlex order"
454
+ assert str(field("x,y,z", ZZ["t"], lex)[0]) == "Rational function field in x, y, z over ZZ[t] with lex order"
455
+
456
+
457
+ def test_PolyElement():
458
+ Ruv, u,v = ring("u,v", ZZ)
459
+ Rxyz, x,y,z = ring("x,y,z", Ruv)
460
+ Rx_zzi, xz = ring("x", ZZ_I)
461
+
462
+ assert str(x - x) == "0"
463
+ assert str(x - 1) == "x - 1"
464
+ assert str(x + 1) == "x + 1"
465
+ assert str(x**2) == "x**2"
466
+
467
+ assert str((u**2 + 3*u*v + 1)*x**2*y + u + 1) == "(u**2 + 3*u*v + 1)*x**2*y + u + 1"
468
+ assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x"
469
+ assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1"
470
+ assert str((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == "-(u**2 - 3*u*v + 1)*x**2*y - (u + 1)*x - 1"
471
+
472
+ assert str(-(v**2 + v + 1)*x + 3*u*v + 1) == "-(v**2 + v + 1)*x + 3*u*v + 1"
473
+ assert str(-(v**2 + v + 1)*x - 3*u*v + 1) == "-(v**2 + v + 1)*x - 3*u*v + 1"
474
+
475
+ assert str((1+I)*xz + 2) == "(1 + 1*I)*x + (2 + 0*I)"
476
+
477
+
478
+ def test_FracElement():
479
+ Fuv, u,v = field("u,v", ZZ)
480
+ Fxyzt, x,y,z,t = field("x,y,z,t", Fuv)
481
+ Rx_zzi, xz = field("x", QQ_I)
482
+ i = QQ_I(0, 1)
483
+
484
+ assert str(x - x) == "0"
485
+ assert str(x - 1) == "x - 1"
486
+ assert str(x + 1) == "x + 1"
487
+
488
+ assert str(x/3) == "x/3"
489
+ assert str(x/z) == "x/z"
490
+ assert str(x*y/z) == "x*y/z"
491
+ assert str(x/(z*t)) == "x/(z*t)"
492
+ assert str(x*y/(z*t)) == "x*y/(z*t)"
493
+
494
+ assert str((x - 1)/y) == "(x - 1)/y"
495
+ assert str((x + 1)/y) == "(x + 1)/y"
496
+ assert str((-x - 1)/y) == "(-x - 1)/y"
497
+ assert str((x + 1)/(y*z)) == "(x + 1)/(y*z)"
498
+ assert str(-y/(x + 1)) == "-y/(x + 1)"
499
+ assert str(y*z/(x + 1)) == "y*z/(x + 1)"
500
+
501
+ assert str(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - 1)"
502
+ assert str(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - u*v*t - 1)"
503
+
504
+ assert str((1+i)/xz) == "(1 + 1*I)/x"
505
+ assert str(((1+i)*xz - i)/xz) == "((1 + 1*I)*x + (0 + -1*I))/x"
506
+
507
+
508
+ def test_GaussianInteger():
509
+ assert str(ZZ_I(1, 0)) == "1"
510
+ assert str(ZZ_I(-1, 0)) == "-1"
511
+ assert str(ZZ_I(0, 1)) == "I"
512
+ assert str(ZZ_I(0, -1)) == "-I"
513
+ assert str(ZZ_I(0, 2)) == "2*I"
514
+ assert str(ZZ_I(0, -2)) == "-2*I"
515
+ assert str(ZZ_I(1, 1)) == "1 + I"
516
+ assert str(ZZ_I(-1, -1)) == "-1 - I"
517
+ assert str(ZZ_I(-1, -2)) == "-1 - 2*I"
518
+
519
+
520
+ def test_GaussianRational():
521
+ assert str(QQ_I(1, 0)) == "1"
522
+ assert str(QQ_I(QQ(2, 3), 0)) == "2/3"
523
+ assert str(QQ_I(0, QQ(2, 3))) == "2*I/3"
524
+ assert str(QQ_I(QQ(1, 2), QQ(-2, 3))) == "1/2 - 2*I/3"
525
+
526
+
527
+ def test_Pow():
528
+ assert str(x**-1) == "1/x"
529
+ assert str(x**-2) == "x**(-2)"
530
+ assert str(x**2) == "x**2"
531
+ assert str((x + y)**-1) == "1/(x + y)"
532
+ assert str((x + y)**-2) == "(x + y)**(-2)"
533
+ assert str((x + y)**2) == "(x + y)**2"
534
+ assert str((x + y)**(1 + x)) == "(x + y)**(x + 1)"
535
+ assert str(x**Rational(1, 3)) == "x**(1/3)"
536
+ assert str(1/x**Rational(1, 3)) == "x**(-1/3)"
537
+ assert str(sqrt(sqrt(x))) == "x**(1/4)"
538
+ # not the same as x**-1
539
+ assert str(x**-1.0) == 'x**(-1.0)'
540
+ # see issue #2860
541
+ assert str(Pow(S(2), -1.0, evaluate=False)) == '2**(-1.0)'
542
+
543
+
544
+ def test_sqrt():
545
+ assert str(sqrt(x)) == "sqrt(x)"
546
+ assert str(sqrt(x**2)) == "sqrt(x**2)"
547
+ assert str(1/sqrt(x)) == "1/sqrt(x)"
548
+ assert str(1/sqrt(x**2)) == "1/sqrt(x**2)"
549
+ assert str(y/sqrt(x)) == "y/sqrt(x)"
550
+ assert str(x**0.5) == "x**0.5"
551
+ assert str(1/x**0.5) == "x**(-0.5)"
552
+
553
+
554
+ def test_Rational():
555
+ n1 = Rational(1, 4)
556
+ n2 = Rational(1, 3)
557
+ n3 = Rational(2, 4)
558
+ n4 = Rational(2, -4)
559
+ n5 = Rational(0)
560
+ n7 = Rational(3)
561
+ n8 = Rational(-3)
562
+ assert str(n1*n2) == "1/12"
563
+ assert str(n1*n2) == "1/12"
564
+ assert str(n3) == "1/2"
565
+ assert str(n1*n3) == "1/8"
566
+ assert str(n1 + n3) == "3/4"
567
+ assert str(n1 + n2) == "7/12"
568
+ assert str(n1 + n4) == "-1/4"
569
+ assert str(n4*n4) == "1/4"
570
+ assert str(n4 + n2) == "-1/6"
571
+ assert str(n4 + n5) == "-1/2"
572
+ assert str(n4*n5) == "0"
573
+ assert str(n3 + n4) == "0"
574
+ assert str(n1**n7) == "1/64"
575
+ assert str(n2**n7) == "1/27"
576
+ assert str(n2**n8) == "27"
577
+ assert str(n7**n8) == "1/27"
578
+ assert str(Rational("-25")) == "-25"
579
+ assert str(Rational("1.25")) == "5/4"
580
+ assert str(Rational("-2.6e-2")) == "-13/500"
581
+ assert str(S("25/7")) == "25/7"
582
+ assert str(S("-123/569")) == "-123/569"
583
+ assert str(S("0.1[23]", rational=1)) == "61/495"
584
+ assert str(S("5.1[666]", rational=1)) == "31/6"
585
+ assert str(S("-5.1[666]", rational=1)) == "-31/6"
586
+ assert str(S("0.[9]", rational=1)) == "1"
587
+ assert str(S("-0.[9]", rational=1)) == "-1"
588
+
589
+ assert str(sqrt(Rational(1, 4))) == "1/2"
590
+ assert str(sqrt(Rational(1, 36))) == "1/6"
591
+
592
+ assert str((123**25) ** Rational(1, 25)) == "123"
593
+ assert str((123**25 + 1)**Rational(1, 25)) != "123"
594
+ assert str((123**25 - 1)**Rational(1, 25)) != "123"
595
+ assert str((123**25 - 1)**Rational(1, 25)) != "122"
596
+
597
+ assert str(sqrt(Rational(81, 36))**3) == "27/8"
598
+ assert str(1/sqrt(Rational(81, 36))**3) == "8/27"
599
+
600
+ assert str(sqrt(-4)) == str(2*I)
601
+ assert str(2**Rational(1, 10**10)) == "2**(1/10000000000)"
602
+
603
+ assert sstr(Rational(2, 3), sympy_integers=True) == "S(2)/3"
604
+ x = Symbol("x")
605
+ assert sstr(x**Rational(2, 3), sympy_integers=True) == "x**(S(2)/3)"
606
+ assert sstr(Eq(x, Rational(2, 3)), sympy_integers=True) == "Eq(x, S(2)/3)"
607
+ assert sstr(Limit(x, x, Rational(7, 2)), sympy_integers=True) == \
608
+ "Limit(x, x, S(7)/2, dir='+')"
609
+
610
+
611
+ def test_Float():
612
+ # NOTE dps is the whole number of decimal digits
613
+ assert str(Float('1.23', dps=1 + 2)) == '1.23'
614
+ assert str(Float('1.23456789', dps=1 + 8)) == '1.23456789'
615
+ assert str(
616
+ Float('1.234567890123456789', dps=1 + 18)) == '1.234567890123456789'
617
+ assert str(pi.evalf(1 + 2)) == '3.14'
618
+ assert str(pi.evalf(1 + 14)) == '3.14159265358979'
619
+ assert str(pi.evalf(1 + 64)) == ('3.141592653589793238462643383279'
620
+ '5028841971693993751058209749445923')
621
+ assert str(pi.round(-1)) == '0.0'
622
+ assert str((pi**400 - (pi**400).round(1)).n(2)) == '-0.e+88'
623
+ assert sstr(Float("100"), full_prec=False, min=-2, max=2) == '1.0e+2'
624
+ assert sstr(Float("100"), full_prec=False, min=-2, max=3) == '100.0'
625
+ assert sstr(Float("0.1"), full_prec=False, min=-2, max=3) == '0.1'
626
+ assert sstr(Float("0.099"), min=-2, max=3) == '9.90000000000000e-2'
627
+
628
+
629
+ def test_Relational():
630
+ assert str(Rel(x, y, "<")) == "x < y"
631
+ assert str(Rel(x + y, y, "==")) == "Eq(x + y, y)"
632
+ assert str(Rel(x, y, "!=")) == "Ne(x, y)"
633
+ assert str(Eq(x, 1) | Eq(x, 2)) == "Eq(x, 1) | Eq(x, 2)"
634
+ assert str(Ne(x, 1) & Ne(x, 2)) == "Ne(x, 1) & Ne(x, 2)"
635
+
636
+
637
+ def test_AppliedBinaryRelation():
638
+ assert str(Q.eq(x, y)) == "Q.eq(x, y)"
639
+ assert str(Q.ne(x, y)) == "Q.ne(x, y)"
640
+
641
+
642
+ def test_CRootOf():
643
+ assert str(rootof(x**5 + 2*x - 1, 0)) == "CRootOf(x**5 + 2*x - 1, 0)"
644
+
645
+
646
+ def test_RootSum():
647
+ f = x**5 + 2*x - 1
648
+
649
+ assert str(
650
+ RootSum(f, Lambda(z, z), auto=False)) == "RootSum(x**5 + 2*x - 1)"
651
+ assert str(RootSum(f, Lambda(
652
+ z, z**2), auto=False)) == "RootSum(x**5 + 2*x - 1, Lambda(z, z**2))"
653
+
654
+
655
+ def test_GroebnerBasis():
656
+ assert str(groebner(
657
+ [], x, y)) == "GroebnerBasis([], x, y, domain='ZZ', order='lex')"
658
+
659
+ F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1]
660
+
661
+ assert str(groebner(F, order='grlex')) == \
662
+ "GroebnerBasis([x**2 - x - 3*y + 1, y**2 - 2*x + y - 1], x, y, domain='ZZ', order='grlex')"
663
+ assert str(groebner(F, order='lex')) == \
664
+ "GroebnerBasis([2*x - y**2 - y + 1, y**4 + 2*y**3 - 3*y**2 - 16*y + 7], x, y, domain='ZZ', order='lex')"
665
+
666
+ def test_set():
667
+ assert sstr(set()) == 'set()'
668
+ assert sstr(frozenset()) == 'frozenset()'
669
+
670
+ assert sstr({1}) == '{1}'
671
+ assert sstr(frozenset([1])) == 'frozenset({1})'
672
+ assert sstr({1, 2, 3}) == '{1, 2, 3}'
673
+ assert sstr(frozenset([1, 2, 3])) == 'frozenset({1, 2, 3})'
674
+
675
+ assert sstr(
676
+ {1, x, x**2, x**3, x**4}) == '{1, x, x**2, x**3, x**4}'
677
+ assert sstr(
678
+ frozenset([1, x, x**2, x**3, x**4])) == 'frozenset({1, x, x**2, x**3, x**4})'
679
+
680
+
681
+ def test_SparseMatrix():
682
+ M = SparseMatrix([[x**+1, 1], [y, x + y]])
683
+ assert str(M) == "Matrix([[x, 1], [y, x + y]])"
684
+ assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])"
685
+
686
+
687
+ def test_Sum():
688
+ assert str(summation(cos(3*z), (z, x, y))) == "Sum(cos(3*z), (z, x, y))"
689
+ assert str(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \
690
+ "Sum(x*y**2, (x, -2, 2), (y, -5, 5))"
691
+
692
+
693
+ def test_Symbol():
694
+ assert str(y) == "y"
695
+ assert str(x) == "x"
696
+ e = x
697
+ assert str(e) == "x"
698
+
699
+
700
+ def test_tuple():
701
+ assert str((x,)) == sstr((x,)) == "(x,)"
702
+ assert str((x + y, 1 + x)) == sstr((x + y, 1 + x)) == "(x + y, x + 1)"
703
+ assert str((x + y, (
704
+ 1 + x, x**2))) == sstr((x + y, (1 + x, x**2))) == "(x + y, (x + 1, x**2))"
705
+
706
+
707
+ def test_Series_str():
708
+ tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y)
709
+ tf2 = TransferFunction(x - y, x + y, y)
710
+ tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y)
711
+ assert str(Series(tf1, tf2)) == \
712
+ "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))"
713
+ assert str(Series(tf1, tf2, tf3)) == \
714
+ "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))"
715
+ assert str(Series(-tf2, tf1)) == \
716
+ "Series(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))"
717
+
718
+
719
+ def test_MIMOSeries_str():
720
+ tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y)
721
+ tf2 = TransferFunction(x - y, x + y, y)
722
+ tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]])
723
+ tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]])
724
+ assert str(MIMOSeries(tfm_1, tfm_2)) == \
725
+ "MIMOSeries(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\
726
+ "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\
727
+ "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\
728
+ "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))"
729
+
730
+
731
+ def test_TransferFunction_str():
732
+ tf1 = TransferFunction(x - 1, x + 1, x)
733
+ assert str(tf1) == "TransferFunction(x - 1, x + 1, x)"
734
+ tf2 = TransferFunction(x + 1, 2 - y, x)
735
+ assert str(tf2) == "TransferFunction(x + 1, 2 - y, x)"
736
+ tf3 = TransferFunction(y, y**2 + 2*y + 3, y)
737
+ assert str(tf3) == "TransferFunction(y, y**2 + 2*y + 3, y)"
738
+
739
+
740
+ def test_Parallel_str():
741
+ tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y)
742
+ tf2 = TransferFunction(x - y, x + y, y)
743
+ tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y)
744
+ assert str(Parallel(tf1, tf2)) == \
745
+ "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))"
746
+ assert str(Parallel(tf1, tf2, tf3)) == \
747
+ "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))"
748
+ assert str(Parallel(-tf2, tf1)) == \
749
+ "Parallel(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))"
750
+
751
+
752
+ def test_MIMOParallel_str():
753
+ tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y)
754
+ tf2 = TransferFunction(x - y, x + y, y)
755
+ tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]])
756
+ tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]])
757
+ assert str(MIMOParallel(tfm_1, tfm_2)) == \
758
+ "MIMOParallel(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\
759
+ "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\
760
+ "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\
761
+ "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))"
762
+
763
+
764
+ def test_Feedback_str():
765
+ tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y)
766
+ tf2 = TransferFunction(x - y, x + y, y)
767
+ tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y)
768
+ assert str(Feedback(tf1*tf2, tf3)) == \
769
+ "Feedback(Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), " \
770
+ "TransferFunction(t*x**2 - t**w*x + w, t - y, y), -1)"
771
+ assert str(Feedback(tf1, TransferFunction(1, 1, y), 1)) == \
772
+ "Feedback(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(1, 1, y), 1)"
773
+
774
+
775
+ def test_MIMOFeedback_str():
776
+ tf1 = TransferFunction(x**2 - y**3, y - z, x)
777
+ tf2 = TransferFunction(y - x, z + y, x)
778
+ tfm_1 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]])
779
+ tfm_2 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]])
780
+ assert (str(MIMOFeedback(tfm_1, tfm_2)) \
781
+ == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x))," \
782
+ " (TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \
783
+ "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), " \
784
+ "TransferFunction(-x + y, y + z, x)), (TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), -1)")
785
+ assert (str(MIMOFeedback(tfm_1, tfm_2, 1)) \
786
+ == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)), " \
787
+ "(TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \
788
+ "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)), "\
789
+ "(TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), 1)")
790
+
791
+
792
+ def test_TransferFunctionMatrix_str():
793
+ tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y)
794
+ tf2 = TransferFunction(x - y, x + y, y)
795
+ tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y)
796
+ assert str(TransferFunctionMatrix([[tf1], [tf2]])) == \
797
+ "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y),), (TransferFunction(x - y, x + y, y),)))"
798
+ assert str(TransferFunctionMatrix([[tf1, tf2], [tf3, tf2]])) == \
799
+ "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), (TransferFunction(t*x**2 - t**w*x + w, t - y, y), TransferFunction(x - y, x + y, y))))"
800
+
801
+
802
+ def test_Quaternion_str_printer():
803
+ q = Quaternion(x, y, z, t)
804
+ assert str(q) == "x + y*i + z*j + t*k"
805
+ q = Quaternion(x,y,z,x*t)
806
+ assert str(q) == "x + y*i + z*j + t*x*k"
807
+ q = Quaternion(x,y,z,x+t)
808
+ assert str(q) == "x + y*i + z*j + (t + x)*k"
809
+
810
+
811
+ def test_Quantity_str():
812
+ assert sstr(second, abbrev=True) == "s"
813
+ assert sstr(joule, abbrev=True) == "J"
814
+ assert str(second) == "second"
815
+ assert str(joule) == "joule"
816
+
817
+
818
+ def test_wild_str():
819
+ # Check expressions containing Wild not causing infinite recursion
820
+ w = Wild('x')
821
+ assert str(w + 1) == 'x_ + 1'
822
+ assert str(exp(2**w) + 5) == 'exp(2**x_) + 5'
823
+ assert str(3*w + 1) == '3*x_ + 1'
824
+ assert str(1/w + 1) == '1 + 1/x_'
825
+ assert str(w**2 + 1) == 'x_**2 + 1'
826
+ assert str(1/(1 - w)) == '1/(1 - x_)'
827
+
828
+
829
+ def test_wild_matchpy():
830
+ from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar
831
+
832
+ matchpy = import_module("matchpy")
833
+
834
+ if matchpy is None:
835
+ return
836
+
837
+ wd = WildDot('w_')
838
+ wp = WildPlus('w__')
839
+ ws = WildStar('w___')
840
+
841
+ assert str(wd) == 'w_'
842
+ assert str(wp) == 'w__'
843
+ assert str(ws) == 'w___'
844
+
845
+ assert str(wp/ws + 2**wd) == '2**w_ + w__/w___'
846
+ assert str(sin(wd)*cos(wp)*sqrt(ws)) == 'sqrt(w___)*sin(w_)*cos(w__)'
847
+
848
+
849
+ def test_zeta():
850
+ assert str(zeta(3)) == "zeta(3)"
851
+
852
+
853
+ def test_issue_3101():
854
+ e = x - y
855
+ a = str(e)
856
+ b = str(e)
857
+ assert a == b
858
+
859
+
860
+ def test_issue_3103():
861
+ e = -2*sqrt(x) - y/sqrt(x)/2
862
+ assert str(e) not in ["(-2)*x**1/2(-1/2)*x**(-1/2)*y",
863
+ "-2*x**1/2(-1/2)*x**(-1/2)*y", "-2*x**1/2-1/2*x**-1/2*w"]
864
+ assert str(e) == "-2*sqrt(x) - y/(2*sqrt(x))"
865
+
866
+
867
+ def test_issue_4021():
868
+ e = Integral(x, x) + 1
869
+ assert str(e) == 'Integral(x, x) + 1'
870
+
871
+
872
+ def test_sstrrepr():
873
+ assert sstr('abc') == 'abc'
874
+ assert sstrrepr('abc') == "'abc'"
875
+
876
+ e = ['a', 'b', 'c', x]
877
+ assert sstr(e) == "[a, b, c, x]"
878
+ assert sstrrepr(e) == "['a', 'b', 'c', x]"
879
+
880
+
881
+ def test_infinity():
882
+ assert sstr(oo*I) == "oo*I"
883
+
884
+
885
+ def test_full_prec():
886
+ assert sstr(S("0.3"), full_prec=True) == "0.300000000000000"
887
+ assert sstr(S("0.3"), full_prec="auto") == "0.300000000000000"
888
+ assert sstr(S("0.3"), full_prec=False) == "0.3"
889
+ assert sstr(S("0.3")*x, full_prec=True) in [
890
+ "0.300000000000000*x",
891
+ "x*0.300000000000000"
892
+ ]
893
+ assert sstr(S("0.3")*x, full_prec="auto") in [
894
+ "0.3*x",
895
+ "x*0.3"
896
+ ]
897
+ assert sstr(S("0.3")*x, full_prec=False) in [
898
+ "0.3*x",
899
+ "x*0.3"
900
+ ]
901
+
902
+
903
+ def test_noncommutative():
904
+ A, B, C = symbols('A,B,C', commutative=False)
905
+
906
+ assert sstr(A*B*C**-1) == "A*B*C**(-1)"
907
+ assert sstr(C**-1*A*B) == "C**(-1)*A*B"
908
+ assert sstr(A*C**-1*B) == "A*C**(-1)*B"
909
+ assert sstr(sqrt(A)) == "sqrt(A)"
910
+ assert sstr(1/sqrt(A)) == "A**(-1/2)"
911
+
912
+
913
+ def test_empty_printer():
914
+ str_printer = StrPrinter()
915
+ assert str_printer.emptyPrinter("foo") == "foo"
916
+ assert str_printer.emptyPrinter(x*y) == "x*y"
917
+ assert str_printer.emptyPrinter(32) == "32"
918
+
919
+ def test_decimal_printer():
920
+ dec_printer = StrPrinter(settings={"dps":3})
921
+ f = Function('f')
922
+ assert dec_printer.doprint(f(1.329294)) == "f(1.33)"
923
+
924
+
925
+ def test_settings():
926
+ raises(TypeError, lambda: sstr(S(4), method="garbage"))
927
+
928
+
929
+ def test_RandomDomain():
930
+ from sympy.stats import Normal, Die, Exponential, pspace, where
931
+ X = Normal('x1', 0, 1)
932
+ assert str(where(X > 0)) == "Domain: (0 < x1) & (x1 < oo)"
933
+
934
+ D = Die('d1', 6)
935
+ assert str(where(D > 4)) == "Domain: Eq(d1, 5) | Eq(d1, 6)"
936
+
937
+ A = Exponential('a', 1)
938
+ B = Exponential('b', 1)
939
+ assert str(pspace(Tuple(A, B)).domain) == "Domain: (0 <= a) & (0 <= b) & (a < oo) & (b < oo)"
940
+
941
+
942
+ def test_FiniteSet():
943
+ assert str(FiniteSet(*range(1, 51))) == (
944
+ '{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,'
945
+ ' 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,'
946
+ ' 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50}'
947
+ )
948
+ assert str(FiniteSet(*range(1, 6))) == '{1, 2, 3, 4, 5}'
949
+ assert str(FiniteSet(*[x*y, x**2])) == '{x**2, x*y}'
950
+ assert str(FiniteSet(FiniteSet(FiniteSet(x, y), 5), FiniteSet(x,y), 5)
951
+ ) == 'FiniteSet(5, FiniteSet(5, {x, y}), {x, y})'
952
+
953
+
954
+ def test_Partition():
955
+ assert str(Partition(FiniteSet(x, y), {z})) == 'Partition({z}, {x, y})'
956
+
957
+ def test_UniversalSet():
958
+ assert str(S.UniversalSet) == 'UniversalSet'
959
+
960
+
961
+ def test_PrettyPoly():
962
+ F = QQ.frac_field(x, y)
963
+ R = QQ[x, y]
964
+ assert sstr(F.convert(x/(x + y))) == sstr(x/(x + y))
965
+ assert sstr(R.convert(x + y)) == sstr(x + y)
966
+
967
+
968
+ def test_categories():
969
+ from sympy.categories import (Object, NamedMorphism,
970
+ IdentityMorphism, Category)
971
+
972
+ A = Object("A")
973
+ B = Object("B")
974
+
975
+ f = NamedMorphism(A, B, "f")
976
+ id_A = IdentityMorphism(A)
977
+
978
+ K = Category("K")
979
+
980
+ assert str(A) == 'Object("A")'
981
+ assert str(f) == 'NamedMorphism(Object("A"), Object("B"), "f")'
982
+ assert str(id_A) == 'IdentityMorphism(Object("A"))'
983
+
984
+ assert str(K) == 'Category("K")'
985
+
986
+
987
+ def test_Tr():
988
+ A, B = symbols('A B', commutative=False)
989
+ t = Tr(A*B)
990
+ assert str(t) == 'Tr(A*B)'
991
+
992
+
993
+ def test_issue_6387():
994
+ assert str(factor(-3.0*z + 3)) == '-3.0*(1.0*z - 1.0)'
995
+
996
+
997
+ def test_MatMul_MatAdd():
998
+ X, Y = MatrixSymbol("X", 2, 2), MatrixSymbol("Y", 2, 2)
999
+ assert str(2*(X + Y)) == "2*X + 2*Y"
1000
+
1001
+ assert str(I*X) == "I*X"
1002
+ assert str(-I*X) == "-I*X"
1003
+ assert str((1 + I)*X) == '(1 + I)*X'
1004
+ assert str(-(1 + I)*X) == '(-1 - I)*X'
1005
+ assert str(MatAdd(MatAdd(X, Y), MatAdd(X, Y))) == '(X + Y) + (X + Y)'
1006
+
1007
+
1008
+ def test_MatrixSlice():
1009
+ n = Symbol('n', integer=True)
1010
+ X = MatrixSymbol('X', n, n)
1011
+ Y = MatrixSymbol('Y', 10, 10)
1012
+ Z = MatrixSymbol('Z', 10, 10)
1013
+
1014
+ assert str(MatrixSlice(X, (None, None, None), (None, None, None))) == 'X[:, :]'
1015
+ assert str(X[x:x + 1, y:y + 1]) == 'X[x:x + 1, y:y + 1]'
1016
+ assert str(X[x:x + 1:2, y:y + 1:2]) == 'X[x:x + 1:2, y:y + 1:2]'
1017
+ assert str(X[:x, y:]) == 'X[:x, y:]'
1018
+ assert str(X[:x, y:]) == 'X[:x, y:]'
1019
+ assert str(X[x:, :y]) == 'X[x:, :y]'
1020
+ assert str(X[x:y, z:w]) == 'X[x:y, z:w]'
1021
+ assert str(X[x:y:t, w:t:x]) == 'X[x:y:t, w:t:x]'
1022
+ assert str(X[x::y, t::w]) == 'X[x::y, t::w]'
1023
+ assert str(X[:x:y, :t:w]) == 'X[:x:y, :t:w]'
1024
+ assert str(X[::x, ::y]) == 'X[::x, ::y]'
1025
+ assert str(MatrixSlice(X, (0, None, None), (0, None, None))) == 'X[:, :]'
1026
+ assert str(MatrixSlice(X, (None, n, None), (None, n, None))) == 'X[:, :]'
1027
+ assert str(MatrixSlice(X, (0, n, None), (0, n, None))) == 'X[:, :]'
1028
+ assert str(MatrixSlice(X, (0, n, 2), (0, n, 2))) == 'X[::2, ::2]'
1029
+ assert str(X[1:2:3, 4:5:6]) == 'X[1:2:3, 4:5:6]'
1030
+ assert str(X[1:3:5, 4:6:8]) == 'X[1:3:5, 4:6:8]'
1031
+ assert str(X[1:10:2]) == 'X[1:10:2, :]'
1032
+ assert str(Y[:5, 1:9:2]) == 'Y[:5, 1:9:2]'
1033
+ assert str(Y[:5, 1:10:2]) == 'Y[:5, 1::2]'
1034
+ assert str(Y[5, :5:2]) == 'Y[5:6, :5:2]'
1035
+ assert str(X[0:1, 0:1]) == 'X[:1, :1]'
1036
+ assert str(X[0:1:2, 0:1:2]) == 'X[:1:2, :1:2]'
1037
+ assert str((Y + Z)[2:, 2:]) == '(Y + Z)[2:, 2:]'
1038
+
1039
+ def test_true_false():
1040
+ assert str(true) == repr(true) == sstr(true) == "True"
1041
+ assert str(false) == repr(false) == sstr(false) == "False"
1042
+
1043
+ def test_Equivalent():
1044
+ assert str(Equivalent(y, x)) == "Equivalent(x, y)"
1045
+
1046
+ def test_Xor():
1047
+ assert str(Xor(y, x, evaluate=False)) == "x ^ y"
1048
+
1049
+ def test_Complement():
1050
+ assert str(Complement(S.Reals, S.Naturals)) == 'Complement(Reals, Naturals)'
1051
+
1052
+ def test_SymmetricDifference():
1053
+ assert str(SymmetricDifference(Interval(2, 3), Interval(3, 4),evaluate=False)) == \
1054
+ 'SymmetricDifference(Interval(2, 3), Interval(3, 4))'
1055
+
1056
+
1057
+ def test_UnevaluatedExpr():
1058
+ a, b = symbols("a b")
1059
+ expr1 = 2*UnevaluatedExpr(a+b)
1060
+ assert str(expr1) == "2*(a + b)"
1061
+
1062
+
1063
+ def test_MatrixElement_printing():
1064
+ # test cases for issue #11821
1065
+ A = MatrixSymbol("A", 1, 3)
1066
+ B = MatrixSymbol("B", 1, 3)
1067
+ C = MatrixSymbol("C", 1, 3)
1068
+
1069
+ assert(str(A[0, 0]) == "A[0, 0]")
1070
+ assert(str(3 * A[0, 0]) == "3*A[0, 0]")
1071
+
1072
+ F = C[0, 0].subs(C, A - B)
1073
+ assert str(F) == "(A - B)[0, 0]"
1074
+
1075
+
1076
+ def test_MatrixSymbol_printing():
1077
+ A = MatrixSymbol("A", 3, 3)
1078
+ B = MatrixSymbol("B", 3, 3)
1079
+
1080
+ assert str(A - A*B - B) == "A - A*B - B"
1081
+ assert str(A*B - (A+B)) == "-A + A*B - B"
1082
+ assert str(A**(-1)) == "A**(-1)"
1083
+ assert str(A**3) == "A**3"
1084
+
1085
+
1086
+ def test_MatrixExpressions():
1087
+ n = Symbol('n', integer=True)
1088
+ X = MatrixSymbol('X', n, n)
1089
+
1090
+ assert str(X) == "X"
1091
+
1092
+ # Apply function elementwise (`ElementwiseApplyFunc`):
1093
+
1094
+ expr = (X.T*X).applyfunc(sin)
1095
+ assert str(expr) == 'Lambda(_d, sin(_d)).(X.T*X)'
1096
+
1097
+ lamda = Lambda(x, 1/x)
1098
+ expr = (n*X).applyfunc(lamda)
1099
+ assert str(expr) == 'Lambda(x, 1/x).(n*X)'
1100
+
1101
+
1102
+ def test_Subs_printing():
1103
+ assert str(Subs(x, (x,), (1,))) == 'Subs(x, x, 1)'
1104
+ assert str(Subs(x + y, (x, y), (1, 2))) == 'Subs(x + y, (x, y), (1, 2))'
1105
+
1106
+
1107
+ def test_issue_15716():
1108
+ e = Integral(factorial(x), (x, -oo, oo))
1109
+ assert e.as_terms() == ([(e, ((1.0, 0.0), (1,), ()))], [e])
1110
+
1111
+
1112
+ def test_str_special_matrices():
1113
+ from sympy.matrices import Identity, ZeroMatrix, OneMatrix
1114
+ assert str(Identity(4)) == 'I'
1115
+ assert str(ZeroMatrix(2, 2)) == '0'
1116
+ assert str(OneMatrix(2, 2)) == '1'
1117
+
1118
+
1119
+ def test_issue_14567():
1120
+ assert factorial(Sum(-1, (x, 0, 0))) + y # doesn't raise an error
1121
+
1122
+
1123
+ def test_issue_21823():
1124
+ assert str(Partition([1, 2])) == 'Partition({1, 2})'
1125
+ assert str(Partition({1, 2})) == 'Partition({1, 2})'
1126
+
1127
+
1128
+ def test_issue_22689():
1129
+ assert str(Mul(Pow(x,-2, evaluate=False), Pow(3,-1,evaluate=False), evaluate=False)) == "1/(x**2*3)"
1130
+
1131
+
1132
+ def test_issue_21119_21460():
1133
+ ss = lambda x: str(S(x, evaluate=False))
1134
+ assert ss('4/2') == '4/2'
1135
+ assert ss('4/-2') == '4/(-2)'
1136
+ assert ss('-4/2') == '-4/2'
1137
+ assert ss('-4/-2') == '-4/(-2)'
1138
+ assert ss('-2*3/-1') == '-2*3/(-1)'
1139
+ assert ss('-2*3/-1/2') == '-2*3/(-1*2)'
1140
+ assert ss('4/2/1') == '4/(2*1)'
1141
+ assert ss('-2/-1/2') == '-2/(-1*2)'
1142
+ assert ss('2*3*4**(-2*3)') == '2*3/4**(2*3)'
1143
+ assert ss('2*3*1*4**(-2*3)') == '2*3*1/4**(2*3)'
1144
+
1145
+
1146
+ def test_Str():
1147
+ from sympy.core.symbol import Str
1148
+ assert str(Str('x')) == 'x'
1149
+ assert sstrrepr(Str('x')) == "Str('x')"
1150
+
1151
+
1152
+ def test_diffgeom():
1153
+ from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField
1154
+ x,y = symbols('x y', real=True)
1155
+ m = Manifold('M', 2)
1156
+ assert str(m) == "M"
1157
+ p = Patch('P', m)
1158
+ assert str(p) == "P"
1159
+ rect = CoordSystem('rect', p, [x, y])
1160
+ assert str(rect) == "rect"
1161
+ b = BaseScalarField(rect, 0)
1162
+ assert str(b) == "x"
1163
+
1164
+ def test_NDimArray():
1165
+ assert sstr(NDimArray(1.0), full_prec=True) == '1.00000000000000'
1166
+ assert sstr(NDimArray(1.0), full_prec=False) == '1.0'
1167
+ assert sstr(NDimArray([1.0, 2.0]), full_prec=True) == '[1.00000000000000, 2.00000000000000]'
1168
+ assert sstr(NDimArray([1.0, 2.0]), full_prec=False) == '[1.0, 2.0]'
1169
+ assert sstr(NDimArray([], (0,))) == 'ImmutableDenseNDimArray([], (0,))'
1170
+ assert sstr(NDimArray([], (0, 0))) == 'ImmutableDenseNDimArray([], (0, 0))'
1171
+ assert sstr(NDimArray([], (0, 1))) == 'ImmutableDenseNDimArray([], (0, 1))'
1172
+ assert sstr(NDimArray([], (1, 0))) == 'ImmutableDenseNDimArray([], (1, 0))'
1173
+
1174
+ def test_Predicate():
1175
+ assert sstr(Q.even) == 'Q.even'
1176
+
1177
+ def test_AppliedPredicate():
1178
+ assert sstr(Q.even(x)) == 'Q.even(x)'
1179
+
1180
+ def test_printing_str_array_expressions():
1181
+ assert sstr(ArraySymbol("A", (2, 3, 4))) == "A"
1182
+ assert sstr(ArrayElement("A", (2, 1/(1-x), 0))) == "A[2, 1/(1 - x), 0]"
1183
+ M = MatrixSymbol("M", 3, 3)
1184
+ N = MatrixSymbol("N", 3, 3)
1185
+ assert sstr(ArrayElement(M*N, [x, 0])) == "(M*N)[x, 0]"
1186
+
1187
+ def test_printing_stats():
1188
+ # issue 24132
1189
+ x = RandomSymbol("x")
1190
+ y = RandomSymbol("y")
1191
+ z1 = Probability(x > 0)*Identity(2)
1192
+ z2 = Expectation(x)*Identity(2)
1193
+ z3 = Variance(x)*Identity(2)
1194
+ z4 = Covariance(x, y) * Identity(2)
1195
+
1196
+ assert str(z1) == "Probability(x > 0)*I"
1197
+ assert str(z2) == "Expectation(x)*I"
1198
+ assert str(z3) == "Variance(x)*I"
1199
+ assert str(z4) == "Covariance(x, y)*I"
1200
+ assert z1.is_commutative == False
1201
+ assert z2.is_commutative == False
1202
+ assert z3.is_commutative == False
1203
+ assert z4.is_commutative == False
1204
+ assert z2._eval_is_commutative() == False
1205
+ assert z3._eval_is_commutative() == False
1206
+ assert z4._eval_is_commutative() == False
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tableform.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.singleton import S
2
+ from sympy.printing.tableform import TableForm
3
+ from sympy.printing.latex import latex
4
+ from sympy.abc import x
5
+ from sympy.functions.elementary.miscellaneous import sqrt
6
+ from sympy.functions.elementary.trigonometric import sin
7
+ from sympy.testing.pytest import raises
8
+
9
+ from textwrap import dedent
10
+
11
+
12
+ def test_TableForm():
13
+ s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]],
14
+ headings="automatic"))
15
+ assert s == (
16
+ ' | 1 2\n'
17
+ '-------\n'
18
+ '1 | a b\n'
19
+ '2 | c d\n'
20
+ '3 | e '
21
+ )
22
+ s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]],
23
+ headings="automatic", wipe_zeros=False))
24
+ assert s == dedent('''\
25
+ | 1 2
26
+ -------
27
+ 1 | a b
28
+ 2 | c d
29
+ 3 | e 0''')
30
+ s = str(TableForm([[x**2, "b"], ["c", x**2], ["e", "f"]],
31
+ headings=("automatic", None)))
32
+ assert s == (
33
+ '1 | x**2 b \n'
34
+ '2 | c x**2\n'
35
+ '3 | e f '
36
+ )
37
+ s = str(TableForm([["a", "b"], ["c", "d"], ["e", "f"]],
38
+ headings=(None, "automatic")))
39
+ assert s == dedent('''\
40
+ 1 2
41
+ ---
42
+ a b
43
+ c d
44
+ e f''')
45
+ s = str(TableForm([[5, 7], [4, 2], [10, 3]],
46
+ headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]]))
47
+ assert s == (
48
+ ' | y1 y2\n'
49
+ '---------------\n'
50
+ 'Group A | 5 7 \n'
51
+ 'Group B | 4 2 \n'
52
+ 'Group C | 10 3 '
53
+ )
54
+ raises(
55
+ ValueError,
56
+ lambda:
57
+ TableForm(
58
+ [[5, 7], [4, 2], [10, 3]],
59
+ headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]],
60
+ alignments="middle")
61
+ )
62
+ s = str(TableForm([[5, 7], [4, 2], [10, 3]],
63
+ headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]],
64
+ alignments="right"))
65
+ assert s == dedent('''\
66
+ | y1 y2
67
+ ---------------
68
+ Group A | 5 7
69
+ Group B | 4 2
70
+ Group C | 10 3''')
71
+
72
+ # other alignment permutations
73
+ d = [[1, 100], [100, 1]]
74
+ s = TableForm(d, headings=(('xxx', 'x'), None), alignments='l')
75
+ assert str(s) == (
76
+ 'xxx | 1 100\n'
77
+ ' x | 100 1 '
78
+ )
79
+ s = TableForm(d, headings=(('xxx', 'x'), None), alignments='lr')
80
+ assert str(s) == dedent('''\
81
+ xxx | 1 100
82
+ x | 100 1''')
83
+ s = TableForm(d, headings=(('xxx', 'x'), None), alignments='clr')
84
+ assert str(s) == dedent('''\
85
+ xxx | 1 100
86
+ x | 100 1''')
87
+
88
+ s = TableForm(d, headings=(('xxx', 'x'), None))
89
+ assert str(s) == (
90
+ 'xxx | 1 100\n'
91
+ ' x | 100 1 '
92
+ )
93
+
94
+ raises(ValueError, lambda: TableForm(d, alignments='clr'))
95
+
96
+ #pad
97
+ s = str(TableForm([[None, "-", 2], [1]], pad='?'))
98
+ assert s == dedent('''\
99
+ ? - 2
100
+ 1 ? ?''')
101
+
102
+
103
+ def test_TableForm_latex():
104
+ s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
105
+ wipe_zeros=True, headings=("automatic", "automatic")))
106
+ assert s == (
107
+ '\\begin{tabular}{r l l}\n'
108
+ ' & 1 & 2 \\\\\n'
109
+ '\\hline\n'
110
+ '1 & & $x^{3}$ \\\\\n'
111
+ '2 & $c$ & $\\frac{1}{4}$ \\\\\n'
112
+ '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
113
+ '\\end{tabular}'
114
+ )
115
+ s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
116
+ wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'))
117
+ assert s == (
118
+ '\\begin{tabular}{r l l}\n'
119
+ ' & 1 & 2 \\\\\n'
120
+ '\\hline\n'
121
+ '1 & & $x^{3}$ \\\\\n'
122
+ '2 & $c$ & $\\frac{1}{4}$ \\\\\n'
123
+ '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
124
+ '\\end{tabular}'
125
+ )
126
+ s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
127
+ wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'*3))
128
+ assert s == (
129
+ '\\begin{tabular}{l l l}\n'
130
+ ' & 1 & 2 \\\\\n'
131
+ '\\hline\n'
132
+ '1 & & $x^{3}$ \\\\\n'
133
+ '2 & $c$ & $\\frac{1}{4}$ \\\\\n'
134
+ '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
135
+ '\\end{tabular}'
136
+ )
137
+ s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
138
+ headings=("automatic", "automatic")))
139
+ assert s == (
140
+ '\\begin{tabular}{r l l}\n'
141
+ ' & 1 & 2 \\\\\n'
142
+ '\\hline\n'
143
+ '1 & $a$ & $x^{3}$ \\\\\n'
144
+ '2 & $c$ & $\\frac{1}{4}$ \\\\\n'
145
+ '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
146
+ '\\end{tabular}'
147
+ )
148
+ s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]],
149
+ formats=['(%s)', None], headings=("automatic", "automatic")))
150
+ assert s == (
151
+ '\\begin{tabular}{r l l}\n'
152
+ ' & 1 & 2 \\\\\n'
153
+ '\\hline\n'
154
+ '1 & (a) & $x^{3}$ \\\\\n'
155
+ '2 & (c) & $\\frac{1}{4}$ \\\\\n'
156
+ '3 & (sqrt(x)) & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
157
+ '\\end{tabular}'
158
+ )
159
+
160
+ def neg_in_paren(x, i, j):
161
+ if i % 2:
162
+ return ('(%s)' if x < 0 else '%s') % x
163
+ else:
164
+ pass # use default print
165
+ s = latex(TableForm([[-1, 2], [-3, 4]],
166
+ formats=[neg_in_paren]*2, headings=("automatic", "automatic")))
167
+ assert s == (
168
+ '\\begin{tabular}{r l l}\n'
169
+ ' & 1 & 2 \\\\\n'
170
+ '\\hline\n'
171
+ '1 & -1 & 2 \\\\\n'
172
+ '2 & (-3) & 4 \\\\\n'
173
+ '\\end{tabular}'
174
+ )
175
+ s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]]))
176
+ assert s == (
177
+ '\\begin{tabular}{l l}\n'
178
+ '$a$ & $x^{3}$ \\\\\n'
179
+ '$c$ & $\\frac{1}{4}$ \\\\\n'
180
+ '$\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n'
181
+ '\\end{tabular}'
182
+ )
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tensorflow.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from sympy.core.function import Derivative
3
+ from sympy.core.symbol import symbols
4
+ from sympy import Piecewise
5
+ from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \
6
+ PermuteDims, ArrayDiagonal
7
+ from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt
8
+ from sympy.external import import_module
9
+ from sympy.functions import \
10
+ Abs, ceiling, exp, floor, sign, sin, asin, sqrt, cos, \
11
+ acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \
12
+ re, im, arg, erf, loggamma, log
13
+ from sympy.codegen.cfunctions import isnan, isinf
14
+ from sympy.matrices import Matrix, MatrixBase, eye, randMatrix
15
+ from sympy.matrices.expressions import \
16
+ Determinant, HadamardProduct, Inverse, MatrixSymbol, Trace
17
+ from sympy.printing.tensorflow import tensorflow_code
18
+ from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
19
+ from sympy.utilities.lambdify import lambdify
20
+ from sympy.testing.pytest import skip
21
+ from sympy.testing.pytest import XFAIL
22
+
23
+
24
+ tf = tensorflow = import_module("tensorflow")
25
+
26
+ if tensorflow:
27
+ # Hide Tensorflow warnings
28
+ import os
29
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
30
+
31
+
32
+ M = MatrixSymbol("M", 3, 3)
33
+ N = MatrixSymbol("N", 3, 3)
34
+ P = MatrixSymbol("P", 3, 3)
35
+ Q = MatrixSymbol("Q", 3, 3)
36
+
37
+ x, y, z, t = symbols("x y z t")
38
+
39
+ if tf is not None:
40
+ llo = [list(range(i, i+3)) for i in range(0, 9, 3)]
41
+ m3x3 = tf.constant(llo)
42
+ m3x3sympy = Matrix(llo)
43
+
44
+
45
+ def _compare_tensorflow_matrix(variables, expr, use_float=False):
46
+ f = lambdify(variables, expr, 'tensorflow')
47
+ if not use_float:
48
+ random_matrices = [randMatrix(v.rows, v.cols) for v in variables]
49
+ else:
50
+ random_matrices = [randMatrix(v.rows, v.cols)/100. for v in variables]
51
+
52
+ graph = tf.Graph()
53
+ r = None
54
+ with graph.as_default():
55
+ random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
56
+ session = tf.compat.v1.Session(graph=graph)
57
+ r = session.run(f(*random_variables))
58
+
59
+ e = expr.subs(dict(zip(variables, random_matrices)))
60
+ e = e.doit()
61
+ if e.is_Matrix:
62
+ if not isinstance(e, MatrixBase):
63
+ e = e.as_explicit()
64
+ e = e.tolist()
65
+
66
+ if not use_float:
67
+ assert (r == e).all()
68
+ else:
69
+ r = [i for row in r for i in row]
70
+ e = [i for row in e for i in row]
71
+ assert all(
72
+ abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e))
73
+
74
+
75
+ # Creating a custom inverse test.
76
+ # See https://github.com/sympy/sympy/issues/18469
77
+ def _compare_tensorflow_matrix_inverse(variables, expr, use_float=False):
78
+ f = lambdify(variables, expr, 'tensorflow')
79
+ if not use_float:
80
+ random_matrices = [eye(v.rows, v.cols)*4 for v in variables]
81
+ else:
82
+ random_matrices = [eye(v.rows, v.cols)*3.14 for v in variables]
83
+
84
+ graph = tf.Graph()
85
+ r = None
86
+ with graph.as_default():
87
+ random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
88
+ session = tf.compat.v1.Session(graph=graph)
89
+ r = session.run(f(*random_variables))
90
+
91
+ e = expr.subs(dict(zip(variables, random_matrices)))
92
+ e = e.doit()
93
+ if e.is_Matrix:
94
+ if not isinstance(e, MatrixBase):
95
+ e = e.as_explicit()
96
+ e = e.tolist()
97
+
98
+ if not use_float:
99
+ assert (r == e).all()
100
+ else:
101
+ r = [i for row in r for i in row]
102
+ e = [i for row in e for i in row]
103
+ assert all(
104
+ abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e))
105
+
106
+
107
+ def _compare_tensorflow_matrix_scalar(variables, expr):
108
+ f = lambdify(variables, expr, 'tensorflow')
109
+ random_matrices = [
110
+ randMatrix(v.rows, v.cols).evalf() / 100 for v in variables]
111
+
112
+ graph = tf.Graph()
113
+ r = None
114
+ with graph.as_default():
115
+ random_variables = [eval(tensorflow_code(i)) for i in random_matrices]
116
+ session = tf.compat.v1.Session(graph=graph)
117
+ r = session.run(f(*random_variables))
118
+
119
+ e = expr.subs(dict(zip(variables, random_matrices)))
120
+ e = e.doit()
121
+ assert abs(r-e) < 10**-6
122
+
123
+
124
+ def _compare_tensorflow_scalar(
125
+ variables, expr, rng=lambda: random.randint(0, 10)):
126
+ f = lambdify(variables, expr, 'tensorflow')
127
+ rvs = [rng() for v in variables]
128
+
129
+ graph = tf.Graph()
130
+ r = None
131
+ with graph.as_default():
132
+ tf_rvs = [eval(tensorflow_code(i)) for i in rvs]
133
+ session = tf.compat.v1.Session(graph=graph)
134
+ r = session.run(f(*tf_rvs))
135
+
136
+ e = expr.subs(dict(zip(variables, rvs))).evalf().doit()
137
+ assert abs(r-e) < 10**-6
138
+
139
+
140
+ def _compare_tensorflow_relational(
141
+ variables, expr, rng=lambda: random.randint(0, 10)):
142
+ f = lambdify(variables, expr, 'tensorflow')
143
+ rvs = [rng() for v in variables]
144
+
145
+ graph = tf.Graph()
146
+ r = None
147
+ with graph.as_default():
148
+ tf_rvs = [eval(tensorflow_code(i)) for i in rvs]
149
+ session = tf.compat.v1.Session(graph=graph)
150
+ r = session.run(f(*tf_rvs))
151
+
152
+ e = expr.subs(dict(zip(variables, rvs))).doit()
153
+ assert r == e
154
+
155
+
156
+ def test_tensorflow_printing():
157
+ assert tensorflow_code(eye(3)) == \
158
+ "tensorflow.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])"
159
+
160
+ expr = Matrix([[x, sin(y)], [exp(z), -t]])
161
+ assert tensorflow_code(expr) == \
162
+ "tensorflow.Variable(" \
163
+ "[[x, tensorflow.math.sin(y)]," \
164
+ " [tensorflow.math.exp(z), -t]])"
165
+
166
+
167
+ # This (random) test is XFAIL because it fails occasionally
168
+ # See https://github.com/sympy/sympy/issues/18469
169
+ @XFAIL
170
+ def test_tensorflow_math():
171
+ if not tf:
172
+ skip("TensorFlow not installed")
173
+
174
+ expr = Abs(x)
175
+ assert tensorflow_code(expr) == "tensorflow.math.abs(x)"
176
+ _compare_tensorflow_scalar((x,), expr)
177
+
178
+ expr = sign(x)
179
+ assert tensorflow_code(expr) == "tensorflow.math.sign(x)"
180
+ _compare_tensorflow_scalar((x,), expr)
181
+
182
+ expr = ceiling(x)
183
+ assert tensorflow_code(expr) == "tensorflow.math.ceil(x)"
184
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
185
+
186
+ expr = floor(x)
187
+ assert tensorflow_code(expr) == "tensorflow.math.floor(x)"
188
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
189
+
190
+ expr = exp(x)
191
+ assert tensorflow_code(expr) == "tensorflow.math.exp(x)"
192
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
193
+
194
+ expr = sqrt(x)
195
+ assert tensorflow_code(expr) == "tensorflow.math.sqrt(x)"
196
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
197
+
198
+ expr = x ** 4
199
+ assert tensorflow_code(expr) == "tensorflow.math.pow(x, 4)"
200
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
201
+
202
+ expr = cos(x)
203
+ assert tensorflow_code(expr) == "tensorflow.math.cos(x)"
204
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
205
+
206
+ expr = acos(x)
207
+ assert tensorflow_code(expr) == "tensorflow.math.acos(x)"
208
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(0, 0.95))
209
+
210
+ expr = sin(x)
211
+ assert tensorflow_code(expr) == "tensorflow.math.sin(x)"
212
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
213
+
214
+ expr = asin(x)
215
+ assert tensorflow_code(expr) == "tensorflow.math.asin(x)"
216
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
217
+
218
+ expr = tan(x)
219
+ assert tensorflow_code(expr) == "tensorflow.math.tan(x)"
220
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
221
+
222
+ expr = atan(x)
223
+ assert tensorflow_code(expr) == "tensorflow.math.atan(x)"
224
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
225
+
226
+ expr = atan2(y, x)
227
+ assert tensorflow_code(expr) == "tensorflow.math.atan2(y, x)"
228
+ _compare_tensorflow_scalar((y, x), expr, rng=lambda: random.random())
229
+
230
+ expr = cosh(x)
231
+ assert tensorflow_code(expr) == "tensorflow.math.cosh(x)"
232
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random())
233
+
234
+ expr = acosh(x)
235
+ assert tensorflow_code(expr) == "tensorflow.math.acosh(x)"
236
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
237
+
238
+ expr = sinh(x)
239
+ assert tensorflow_code(expr) == "tensorflow.math.sinh(x)"
240
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
241
+
242
+ expr = asinh(x)
243
+ assert tensorflow_code(expr) == "tensorflow.math.asinh(x)"
244
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
245
+
246
+ expr = tanh(x)
247
+ assert tensorflow_code(expr) == "tensorflow.math.tanh(x)"
248
+ _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2))
249
+
250
+ expr = atanh(x)
251
+ assert tensorflow_code(expr) == "tensorflow.math.atanh(x)"
252
+ _compare_tensorflow_scalar(
253
+ (x,), expr, rng=lambda: random.uniform(-.5, .5))
254
+
255
+ expr = erf(x)
256
+ assert tensorflow_code(expr) == "tensorflow.math.erf(x)"
257
+ _compare_tensorflow_scalar(
258
+ (x,), expr, rng=lambda: random.random())
259
+
260
+ expr = loggamma(x)
261
+ assert tensorflow_code(expr) == "tensorflow.math.lgamma(x)"
262
+ _compare_tensorflow_scalar(
263
+ (x,), expr, rng=lambda: random.random())
264
+
265
+
266
+ def test_tensorflow_complexes():
267
+ assert tensorflow_code(re(x)) == "tensorflow.math.real(x)"
268
+ assert tensorflow_code(im(x)) == "tensorflow.math.imag(x)"
269
+ assert tensorflow_code(arg(x)) == "tensorflow.math.angle(x)"
270
+
271
+
272
+ def test_tensorflow_relational():
273
+ if not tf:
274
+ skip("TensorFlow not installed")
275
+
276
+ expr = Eq(x, y)
277
+ assert tensorflow_code(expr) == "tensorflow.math.equal(x, y)"
278
+ _compare_tensorflow_relational((x, y), expr)
279
+
280
+ expr = Ne(x, y)
281
+ assert tensorflow_code(expr) == "tensorflow.math.not_equal(x, y)"
282
+ _compare_tensorflow_relational((x, y), expr)
283
+
284
+ expr = Ge(x, y)
285
+ assert tensorflow_code(expr) == "tensorflow.math.greater_equal(x, y)"
286
+ _compare_tensorflow_relational((x, y), expr)
287
+
288
+ expr = Gt(x, y)
289
+ assert tensorflow_code(expr) == "tensorflow.math.greater(x, y)"
290
+ _compare_tensorflow_relational((x, y), expr)
291
+
292
+ expr = Le(x, y)
293
+ assert tensorflow_code(expr) == "tensorflow.math.less_equal(x, y)"
294
+ _compare_tensorflow_relational((x, y), expr)
295
+
296
+ expr = Lt(x, y)
297
+ assert tensorflow_code(expr) == "tensorflow.math.less(x, y)"
298
+ _compare_tensorflow_relational((x, y), expr)
299
+
300
+
301
+ # This (random) test is XFAIL because it fails occasionally
302
+ # See https://github.com/sympy/sympy/issues/18469
303
+ @XFAIL
304
+ def test_tensorflow_matrices():
305
+ if not tf:
306
+ skip("TensorFlow not installed")
307
+
308
+ expr = M
309
+ assert tensorflow_code(expr) == "M"
310
+ _compare_tensorflow_matrix((M,), expr)
311
+
312
+ expr = M + N
313
+ assert tensorflow_code(expr) == "tensorflow.math.add(M, N)"
314
+ _compare_tensorflow_matrix((M, N), expr)
315
+
316
+ expr = M * N
317
+ assert tensorflow_code(expr) == "tensorflow.linalg.matmul(M, N)"
318
+ _compare_tensorflow_matrix((M, N), expr)
319
+
320
+ expr = HadamardProduct(M, N)
321
+ assert tensorflow_code(expr) == "tensorflow.math.multiply(M, N)"
322
+ _compare_tensorflow_matrix((M, N), expr)
323
+
324
+ expr = M*N*P*Q
325
+ assert tensorflow_code(expr) == \
326
+ "tensorflow.linalg.matmul(" \
327
+ "tensorflow.linalg.matmul(" \
328
+ "tensorflow.linalg.matmul(M, N), P), Q)"
329
+ _compare_tensorflow_matrix((M, N, P, Q), expr)
330
+
331
+ expr = M**3
332
+ assert tensorflow_code(expr) == \
333
+ "tensorflow.linalg.matmul(tensorflow.linalg.matmul(M, M), M)"
334
+ _compare_tensorflow_matrix((M,), expr)
335
+
336
+ expr = Trace(M)
337
+ assert tensorflow_code(expr) == "tensorflow.linalg.trace(M)"
338
+ _compare_tensorflow_matrix((M,), expr)
339
+
340
+ expr = Determinant(M)
341
+ assert tensorflow_code(expr) == "tensorflow.linalg.det(M)"
342
+ _compare_tensorflow_matrix_scalar((M,), expr)
343
+
344
+ expr = Inverse(M)
345
+ assert tensorflow_code(expr) == "tensorflow.linalg.inv(M)"
346
+ _compare_tensorflow_matrix_inverse((M,), expr, use_float=True)
347
+
348
+ expr = M.T
349
+ assert tensorflow_code(expr, tensorflow_version='1.14') == \
350
+ "tensorflow.linalg.matrix_transpose(M)"
351
+ assert tensorflow_code(expr, tensorflow_version='1.13') == \
352
+ "tensorflow.matrix_transpose(M)"
353
+
354
+ _compare_tensorflow_matrix((M,), expr)
355
+
356
+
357
+ def test_codegen_einsum():
358
+ if not tf:
359
+ skip("TensorFlow not installed")
360
+
361
+ graph = tf.Graph()
362
+ with graph.as_default():
363
+ session = tf.compat.v1.Session(graph=graph)
364
+
365
+ M = MatrixSymbol("M", 2, 2)
366
+ N = MatrixSymbol("N", 2, 2)
367
+
368
+ cg = convert_matrix_to_array(M * N)
369
+ f = lambdify((M, N), cg, 'tensorflow')
370
+
371
+ ma = tf.constant([[1, 2], [3, 4]])
372
+ mb = tf.constant([[1,-2], [-1, 3]])
373
+ y = session.run(f(ma, mb))
374
+ c = session.run(tf.matmul(ma, mb))
375
+ assert (y == c).all()
376
+
377
+
378
+ def test_codegen_extra():
379
+ if not tf:
380
+ skip("TensorFlow not installed")
381
+
382
+ graph = tf.Graph()
383
+ with graph.as_default():
384
+ session = tf.compat.v1.Session()
385
+
386
+ M = MatrixSymbol("M", 2, 2)
387
+ N = MatrixSymbol("N", 2, 2)
388
+ P = MatrixSymbol("P", 2, 2)
389
+ Q = MatrixSymbol("Q", 2, 2)
390
+ ma = tf.constant([[1, 2], [3, 4]])
391
+ mb = tf.constant([[1,-2], [-1, 3]])
392
+ mc = tf.constant([[2, 0], [1, 2]])
393
+ md = tf.constant([[1,-1], [4, 7]])
394
+
395
+ cg = ArrayTensorProduct(M, N)
396
+ assert tensorflow_code(cg) == \
397
+ 'tensorflow.linalg.einsum("ab,cd", M, N)'
398
+ f = lambdify((M, N), cg, 'tensorflow')
399
+ y = session.run(f(ma, mb))
400
+ c = session.run(tf.einsum("ij,kl", ma, mb))
401
+ assert (y == c).all()
402
+
403
+ cg = ArrayAdd(M, N)
404
+ assert tensorflow_code(cg) == 'tensorflow.math.add(M, N)'
405
+ f = lambdify((M, N), cg, 'tensorflow')
406
+ y = session.run(f(ma, mb))
407
+ c = session.run(ma + mb)
408
+ assert (y == c).all()
409
+
410
+ cg = ArrayAdd(M, N, P)
411
+ assert tensorflow_code(cg) == \
412
+ 'tensorflow.math.add(tensorflow.math.add(M, N), P)'
413
+ f = lambdify((M, N, P), cg, 'tensorflow')
414
+ y = session.run(f(ma, mb, mc))
415
+ c = session.run(ma + mb + mc)
416
+ assert (y == c).all()
417
+
418
+ cg = ArrayAdd(M, N, P, Q)
419
+ assert tensorflow_code(cg) == \
420
+ 'tensorflow.math.add(' \
421
+ 'tensorflow.math.add(tensorflow.math.add(M, N), P), Q)'
422
+ f = lambdify((M, N, P, Q), cg, 'tensorflow')
423
+ y = session.run(f(ma, mb, mc, md))
424
+ c = session.run(ma + mb + mc + md)
425
+ assert (y == c).all()
426
+
427
+ cg = PermuteDims(M, [1, 0])
428
+ assert tensorflow_code(cg) == 'tensorflow.transpose(M, [1, 0])'
429
+ f = lambdify((M,), cg, 'tensorflow')
430
+ y = session.run(f(ma))
431
+ c = session.run(tf.transpose(ma))
432
+ assert (y == c).all()
433
+
434
+ cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
435
+ assert tensorflow_code(cg) == \
436
+ 'tensorflow.transpose(' \
437
+ 'tensorflow.linalg.einsum("ab,cd", M, N), [1, 2, 3, 0])'
438
+ f = lambdify((M, N), cg, 'tensorflow')
439
+ y = session.run(f(ma, mb))
440
+ c = session.run(tf.transpose(tf.einsum("ab,cd", ma, mb), [1, 2, 3, 0]))
441
+ assert (y == c).all()
442
+
443
+ cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
444
+ assert tensorflow_code(cg) == \
445
+ 'tensorflow.linalg.einsum("ab,bc->acb", M, N)'
446
+ f = lambdify((M, N), cg, 'tensorflow')
447
+ y = session.run(f(ma, mb))
448
+ c = session.run(tf.einsum("ab,bc->acb", ma, mb))
449
+ assert (y == c).all()
450
+
451
+
452
+ def test_MatrixElement_printing():
453
+ A = MatrixSymbol("A", 1, 3)
454
+ B = MatrixSymbol("B", 1, 3)
455
+ C = MatrixSymbol("C", 1, 3)
456
+
457
+ assert tensorflow_code(A[0, 0]) == "A[0, 0]"
458
+ assert tensorflow_code(3 * A[0, 0]) == "3*A[0, 0]"
459
+
460
+ F = C[0, 0].subs(C, A - B)
461
+ assert tensorflow_code(F) == "(tensorflow.math.add((-1)*B, A))[0, 0]"
462
+
463
+
464
+ def test_tensorflow_Derivative():
465
+ expr = Derivative(sin(x), x)
466
+ assert tensorflow_code(expr) == \
467
+ "tensorflow.gradients(tensorflow.math.sin(x), x)[0]"
468
+
469
+ def test_tensorflow_isnan_isinf():
470
+ if not tf:
471
+ skip("TensorFlow not installed")
472
+
473
+ # Test for isnan
474
+ x = symbols("x")
475
+ # Return 0 if x is of nan value, and 1 otherwise
476
+ expression = Piecewise((0.0, isnan(x)), (1.0, True))
477
+ printed_code = tensorflow_code(expression)
478
+ expected_printed_code = "tensorflow.where(tensorflow.math.is_nan(x), 0.0, 1.0)"
479
+ assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}"
480
+ for _input, _expected in [(float('nan'), 0.0), (float('inf'), 1.0), (float('-inf'), 1.0), (1.0, 1.0)]:
481
+ _output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input]))
482
+ assert (_output == _expected).numpy().all()
483
+
484
+ # Test for isinf
485
+ x = symbols("x")
486
+ # Return 0 if x is of nan value, and 1 otherwise
487
+ expression = Piecewise((0.0, isinf(x)), (1.0, True))
488
+ printed_code = tensorflow_code(expression)
489
+ expected_printed_code = "tensorflow.where(tensorflow.math.is_inf(x), 0.0, 1.0)"
490
+ assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}"
491
+ for _input, _expected in [(float('inf'), 0.0), (float('-inf'), 0.0), (float('nan'), 1.0), (1.0, 1.0)]:
492
+ _output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input]))
493
+ assert (_output == _expected).numpy().all()
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_theanocode.py ADDED
@@ -0,0 +1,639 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Important note on tests in this module - the Theano printing functions use a
3
+ global cache by default, which means that tests using it will modify global
4
+ state and thus not be independent from each other. Instead of using the "cache"
5
+ keyword argument each time, this module uses the theano_code_ and
6
+ theano_function_ functions defined below which default to using a new, empty
7
+ cache instead.
8
+ """
9
+
10
+ import logging
11
+
12
+ from sympy.external import import_module
13
+ from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy
14
+
15
+ theanologger = logging.getLogger('theano.configdefaults')
16
+ theanologger.setLevel(logging.CRITICAL)
17
+ theano = import_module('theano')
18
+ theanologger.setLevel(logging.WARNING)
19
+
20
+
21
+ if theano:
22
+ import numpy as np
23
+ ts = theano.scalar
24
+ tt = theano.tensor
25
+ xt, yt, zt = [tt.scalar(name, 'floatX') for name in 'xyz']
26
+ Xt, Yt, Zt = [tt.tensor('floatX', (False, False), name=n) for n in 'XYZ']
27
+ else:
28
+ #bin/test will not execute any tests now
29
+ disabled = True
30
+
31
+ import sympy as sy
32
+ from sympy.core.singleton import S
33
+ from sympy.abc import x, y, z, t
34
+ from sympy.printing.theanocode import (theano_code, dim_handling,
35
+ theano_function)
36
+
37
+
38
+ # Default set of matrix symbols for testing - make square so we can both
39
+ # multiply and perform elementwise operations between them.
40
+ X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ']
41
+
42
+ # For testing AppliedUndef
43
+ f_t = sy.Function('f')(t)
44
+
45
+
46
+ def theano_code_(expr, **kwargs):
47
+ """ Wrapper for theano_code that uses a new, empty cache by default. """
48
+ kwargs.setdefault('cache', {})
49
+ with warns_deprecated_sympy():
50
+ return theano_code(expr, **kwargs)
51
+
52
+ def theano_function_(inputs, outputs, **kwargs):
53
+ """ Wrapper for theano_function that uses a new, empty cache by default. """
54
+ kwargs.setdefault('cache', {})
55
+ with warns_deprecated_sympy():
56
+ return theano_function(inputs, outputs, **kwargs)
57
+
58
+
59
+ def fgraph_of(*exprs):
60
+ """ Transform SymPy expressions into Theano Computation.
61
+
62
+ Parameters
63
+ ==========
64
+ exprs
65
+ SymPy expressions
66
+
67
+ Returns
68
+ =======
69
+ theano.gof.FunctionGraph
70
+ """
71
+ outs = list(map(theano_code_, exprs))
72
+ ins = theano.gof.graph.inputs(outs)
73
+ ins, outs = theano.gof.graph.clone(ins, outs)
74
+ return theano.gof.FunctionGraph(ins, outs)
75
+
76
+
77
+ def theano_simplify(fgraph):
78
+ """ Simplify a Theano Computation.
79
+
80
+ Parameters
81
+ ==========
82
+ fgraph : theano.gof.FunctionGraph
83
+
84
+ Returns
85
+ =======
86
+ theano.gof.FunctionGraph
87
+ """
88
+ mode = theano.compile.get_default_mode().excluding("fusion")
89
+ fgraph = fgraph.clone()
90
+ mode.optimizer.optimize(fgraph)
91
+ return fgraph
92
+
93
+
94
+ def theq(a, b):
95
+ """ Test two Theano objects for equality.
96
+
97
+ Also accepts numeric types and lists/tuples of supported types.
98
+
99
+ Note - debugprint() has a bug where it will accept numeric types but does
100
+ not respect the "file" argument and in this case and instead prints the number
101
+ to stdout and returns an empty string. This can lead to tests passing where
102
+ they should fail because any two numbers will always compare as equal. To
103
+ prevent this we treat numbers as a separate case.
104
+ """
105
+ numeric_types = (int, float, np.number)
106
+ a_is_num = isinstance(a, numeric_types)
107
+ b_is_num = isinstance(b, numeric_types)
108
+
109
+ # Compare numeric types using regular equality
110
+ if a_is_num or b_is_num:
111
+ if not (a_is_num and b_is_num):
112
+ return False
113
+
114
+ return a == b
115
+
116
+ # Compare sequences element-wise
117
+ a_is_seq = isinstance(a, (tuple, list))
118
+ b_is_seq = isinstance(b, (tuple, list))
119
+
120
+ if a_is_seq or b_is_seq:
121
+ if not (a_is_seq and b_is_seq) or type(a) != type(b):
122
+ return False
123
+
124
+ return list(map(theq, a)) == list(map(theq, b))
125
+
126
+ # Otherwise, assume debugprint() can handle it
127
+ astr = theano.printing.debugprint(a, file='str')
128
+ bstr = theano.printing.debugprint(b, file='str')
129
+
130
+ # Check for bug mentioned above
131
+ for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]:
132
+ if argstr == '':
133
+ raise TypeError(
134
+ 'theano.printing.debugprint(%s) returned empty string '
135
+ '(%s is instance of %r)'
136
+ % (argname, argname, type(argval))
137
+ )
138
+
139
+ return astr == bstr
140
+
141
+
142
+ def test_example_symbols():
143
+ """
144
+ Check that the example symbols in this module print to their Theano
145
+ equivalents, as many of the other tests depend on this.
146
+ """
147
+ assert theq(xt, theano_code_(x))
148
+ assert theq(yt, theano_code_(y))
149
+ assert theq(zt, theano_code_(z))
150
+ assert theq(Xt, theano_code_(X))
151
+ assert theq(Yt, theano_code_(Y))
152
+ assert theq(Zt, theano_code_(Z))
153
+
154
+
155
+ def test_Symbol():
156
+ """ Test printing a Symbol to a theano variable. """
157
+ xx = theano_code_(x)
158
+ assert isinstance(xx, (tt.TensorVariable, ts.ScalarVariable))
159
+ assert xx.broadcastable == ()
160
+ assert xx.name == x.name
161
+
162
+ xx2 = theano_code_(x, broadcastables={x: (False,)})
163
+ assert xx2.broadcastable == (False,)
164
+ assert xx2.name == x.name
165
+
166
+ def test_MatrixSymbol():
167
+ """ Test printing a MatrixSymbol to a theano variable. """
168
+ XX = theano_code_(X)
169
+ assert isinstance(XX, tt.TensorVariable)
170
+ assert XX.broadcastable == (False, False)
171
+
172
+ @SKIP # TODO - this is currently not checked but should be implemented
173
+ def test_MatrixSymbol_wrong_dims():
174
+ """ Test MatrixSymbol with invalid broadcastable. """
175
+ bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)]
176
+ for bc in bcs:
177
+ with raises(ValueError):
178
+ theano_code_(X, broadcastables={X: bc})
179
+
180
+ def test_AppliedUndef():
181
+ """ Test printing AppliedUndef instance, which works similarly to Symbol. """
182
+ ftt = theano_code_(f_t)
183
+ assert isinstance(ftt, tt.TensorVariable)
184
+ assert ftt.broadcastable == ()
185
+ assert ftt.name == 'f_t'
186
+
187
+
188
+ def test_add():
189
+ expr = x + y
190
+ comp = theano_code_(expr)
191
+ assert comp.owner.op == theano.tensor.add
192
+
193
+ def test_trig():
194
+ assert theq(theano_code_(sy.sin(x)), tt.sin(xt))
195
+ assert theq(theano_code_(sy.tan(x)), tt.tan(xt))
196
+
197
+ def test_many():
198
+ """ Test printing a complex expression with multiple symbols. """
199
+ expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z)
200
+ comp = theano_code_(expr)
201
+ expected = tt.exp(xt**2 + tt.cos(yt)) * tt.log(2*zt)
202
+ assert theq(comp, expected)
203
+
204
+
205
+ def test_dtype():
206
+ """ Test specifying specific data types through the dtype argument. """
207
+ for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']:
208
+ assert theano_code_(x, dtypes={x: dtype}).type.dtype == dtype
209
+
210
+ # "floatX" type
211
+ assert theano_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64')
212
+
213
+ # Type promotion
214
+ assert theano_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32'
215
+ assert theano_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64'
216
+
217
+
218
+ def test_broadcastables():
219
+ """ Test the "broadcastables" argument when printing symbol-like objects. """
220
+
221
+ # No restrictions on shape
222
+ for s in [x, f_t]:
223
+ for bc in [(), (False,), (True,), (False, False), (True, False)]:
224
+ assert theano_code_(s, broadcastables={s: bc}).broadcastable == bc
225
+
226
+ # TODO - matrix broadcasting?
227
+
228
+ def test_broadcasting():
229
+ """ Test "broadcastable" attribute after applying element-wise binary op. """
230
+
231
+ expr = x + y
232
+
233
+ cases = [
234
+ [(), (), ()],
235
+ [(False,), (False,), (False,)],
236
+ [(True,), (False,), (False,)],
237
+ [(False, True), (False, False), (False, False)],
238
+ [(True, False), (False, False), (False, False)],
239
+ ]
240
+
241
+ for bc1, bc2, bc3 in cases:
242
+ comp = theano_code_(expr, broadcastables={x: bc1, y: bc2})
243
+ assert comp.broadcastable == bc3
244
+
245
+
246
+ def test_MatMul():
247
+ expr = X*Y*Z
248
+ expr_t = theano_code_(expr)
249
+ assert isinstance(expr_t.owner.op, tt.Dot)
250
+ assert theq(expr_t, Xt.dot(Yt).dot(Zt))
251
+
252
+ def test_Transpose():
253
+ assert isinstance(theano_code_(X.T).owner.op, tt.DimShuffle)
254
+
255
+ def test_MatAdd():
256
+ expr = X+Y+Z
257
+ assert isinstance(theano_code_(expr).owner.op, tt.Elemwise)
258
+
259
+
260
+ def test_Rationals():
261
+ assert theq(theano_code_(sy.Integer(2) / 3), tt.true_div(2, 3))
262
+ assert theq(theano_code_(S.Half), tt.true_div(1, 2))
263
+
264
+ def test_Integers():
265
+ assert theano_code_(sy.Integer(3)) == 3
266
+
267
+ def test_factorial():
268
+ n = sy.Symbol('n')
269
+ assert theano_code_(sy.factorial(n))
270
+
271
+ def test_Derivative():
272
+ simp = lambda expr: theano_simplify(fgraph_of(expr))
273
+ assert theq(simp(theano_code_(sy.Derivative(sy.sin(x), x, evaluate=False))),
274
+ simp(theano.grad(tt.sin(xt), xt)))
275
+
276
+
277
+ def test_theano_function_simple():
278
+ """ Test theano_function() with single output. """
279
+ f = theano_function_([x, y], [x+y])
280
+ assert f(2, 3) == 5
281
+
282
+ def test_theano_function_multi():
283
+ """ Test theano_function() with multiple outputs. """
284
+ f = theano_function_([x, y], [x+y, x-y])
285
+ o1, o2 = f(2, 3)
286
+ assert o1 == 5
287
+ assert o2 == -1
288
+
289
+ def test_theano_function_numpy():
290
+ """ Test theano_function() vs Numpy implementation. """
291
+ f = theano_function_([x, y], [x+y], dim=1,
292
+ dtypes={x: 'float64', y: 'float64'})
293
+ assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9
294
+
295
+ f = theano_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'},
296
+ dim=1)
297
+ xx = np.arange(3).astype('float64')
298
+ yy = 2*np.arange(3).astype('float64')
299
+ assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9
300
+
301
+
302
+ def test_theano_function_matrix():
303
+ m = sy.Matrix([[x, y], [z, x + y + z]])
304
+ expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]])
305
+ f = theano_function_([x, y, z], [m])
306
+ np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
307
+ f = theano_function_([x, y, z], [m], scalar=True)
308
+ np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected)
309
+ f = theano_function_([x, y, z], [m, m])
310
+ assert isinstance(f(1.0, 2.0, 3.0), type([]))
311
+ np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected)
312
+ np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected)
313
+
314
+ def test_dim_handling():
315
+ assert dim_handling([x], dim=2) == {x: (False, False)}
316
+ assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True),
317
+ y: (False, False)}
318
+ assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)}
319
+
320
+ def test_theano_function_kwargs():
321
+ """
322
+ Test passing additional kwargs from theano_function() to theano.function().
323
+ """
324
+ import numpy as np
325
+ f = theano_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore',
326
+ dtypes={x: 'float64', y: 'float64', z: 'float64'})
327
+ assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9
328
+
329
+ f = theano_function_([x, y, z], [x+y],
330
+ dtypes={x: 'float64', y: 'float64', z: 'float64'},
331
+ dim=1, on_unused_input='ignore')
332
+ xx = np.arange(3).astype('float64')
333
+ yy = 2*np.arange(3).astype('float64')
334
+ zz = 2*np.arange(3).astype('float64')
335
+ assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9
336
+
337
+ def test_theano_function_scalar():
338
+ """ Test the "scalar" argument to theano_function(). """
339
+
340
+ args = [
341
+ ([x, y], [x + y], None, [0]), # Single 0d output
342
+ ([X, Y], [X + Y], None, [2]), # Single 2d output
343
+ ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output
344
+ ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs
345
+ ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d
346
+ ]
347
+
348
+ # Create and test functions with and without the scalar setting
349
+ for inputs, outputs, in_dims, out_dims in args:
350
+ for scalar in [False, True]:
351
+
352
+ f = theano_function_(inputs, outputs, dims=in_dims, scalar=scalar)
353
+
354
+ # Check the theano_function attribute is set whether wrapped or not
355
+ assert isinstance(f.theano_function, theano.compile.function_module.Function)
356
+
357
+ # Feed in inputs of the appropriate size and get outputs
358
+ in_values = [
359
+ np.ones([1 if bc else 5 for bc in i.type.broadcastable])
360
+ for i in f.theano_function.input_storage
361
+ ]
362
+ out_values = f(*in_values)
363
+ if not isinstance(out_values, list):
364
+ out_values = [out_values]
365
+
366
+ # Check output types and shapes
367
+ assert len(out_dims) == len(out_values)
368
+ for d, value in zip(out_dims, out_values):
369
+
370
+ if scalar and d == 0:
371
+ # Should have been converted to a scalar value
372
+ assert isinstance(value, np.number)
373
+
374
+ else:
375
+ # Otherwise should be an array
376
+ assert isinstance(value, np.ndarray)
377
+ assert value.ndim == d
378
+
379
+ def test_theano_function_bad_kwarg():
380
+ """
381
+ Passing an unknown keyword argument to theano_function() should raise an
382
+ exception.
383
+ """
384
+ raises(Exception, lambda : theano_function_([x], [x+1], foobar=3))
385
+
386
+
387
+ def test_slice():
388
+ assert theano_code_(slice(1, 2, 3)) == slice(1, 2, 3)
389
+
390
+ def theq_slice(s1, s2):
391
+ for attr in ['start', 'stop', 'step']:
392
+ a1 = getattr(s1, attr)
393
+ a2 = getattr(s2, attr)
394
+ if a1 is None or a2 is None:
395
+ if not (a1 is None or a2 is None):
396
+ return False
397
+ elif not theq(a1, a2):
398
+ return False
399
+ return True
400
+
401
+ dtypes = {x: 'int32', y: 'int32'}
402
+ assert theq_slice(theano_code_(slice(x, y), dtypes=dtypes), slice(xt, yt))
403
+ assert theq_slice(theano_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3))
404
+
405
+ def test_MatrixSlice():
406
+ from theano import Constant
407
+
408
+ cache = {}
409
+
410
+ n = sy.Symbol('n', integer=True)
411
+ X = sy.MatrixSymbol('X', n, n)
412
+
413
+ Y = X[1:2:3, 4:5:6]
414
+ Yt = theano_code_(Y, cache=cache)
415
+
416
+ s = ts.Scalar('int64')
417
+ assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s))
418
+ assert Yt.owner.inputs[0] == theano_code_(X, cache=cache)
419
+ # == doesn't work in theano like it does in SymPy. You have to use
420
+ # equals.
421
+ assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7))
422
+
423
+ k = sy.Symbol('k')
424
+ theano_code_(k, dtypes={k: 'int32'})
425
+ start, stop, step = 4, k, 2
426
+ Y = X[start:stop:step]
427
+ Yt = theano_code_(Y, dtypes={n: 'int32', k: 'int32'})
428
+ # assert Yt.owner.op.idx_list[0].stop == kt
429
+
430
+ def test_BlockMatrix():
431
+ n = sy.Symbol('n', integer=True)
432
+ A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD']
433
+ At, Bt, Ct, Dt = map(theano_code_, (A, B, C, D))
434
+ Block = sy.BlockMatrix([[A, B], [C, D]])
435
+ Blockt = theano_code_(Block)
436
+ solutions = [tt.join(0, tt.join(1, At, Bt), tt.join(1, Ct, Dt)),
437
+ tt.join(1, tt.join(0, At, Ct), tt.join(0, Bt, Dt))]
438
+ assert any(theq(Blockt, solution) for solution in solutions)
439
+
440
+ @SKIP
441
+ def test_BlockMatrix_Inverse_execution():
442
+ k, n = 2, 4
443
+ dtype = 'float32'
444
+ A = sy.MatrixSymbol('A', n, k)
445
+ B = sy.MatrixSymbol('B', n, n)
446
+ inputs = A, B
447
+ output = B.I*A
448
+
449
+ cutsizes = {A: [(n//2, n//2), (k//2, k//2)],
450
+ B: [(n//2, n//2), (n//2, n//2)]}
451
+ cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs]
452
+ cutoutput = output.subs(dict(zip(inputs, cutinputs)))
453
+
454
+ dtypes = dict(zip(inputs, [dtype]*len(inputs)))
455
+ f = theano_function_(inputs, [output], dtypes=dtypes, cache={})
456
+ fblocked = theano_function_(inputs, [sy.block_collapse(cutoutput)],
457
+ dtypes=dtypes, cache={})
458
+
459
+ ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs]
460
+ ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype),
461
+ np.eye(n).astype(dtype)]
462
+ ninputs[1] += np.ones(B.shape)*1e-5
463
+
464
+ assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5)
465
+
466
+ def test_DenseMatrix():
467
+ t = sy.Symbol('theta')
468
+ for MatrixType in [sy.Matrix, sy.ImmutableMatrix]:
469
+ X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]])
470
+ tX = theano_code_(X)
471
+ assert isinstance(tX, tt.TensorVariable)
472
+ assert tX.owner.op == tt.join_
473
+
474
+
475
+ def test_cache_basic():
476
+ """ Test single symbol-like objects are cached when printed by themselves. """
477
+
478
+ # Pairs of objects which should be considered equivalent with respect to caching
479
+ pairs = [
480
+ (x, sy.Symbol('x')),
481
+ (X, sy.MatrixSymbol('X', *X.shape)),
482
+ (f_t, sy.Function('f')(sy.Symbol('t'))),
483
+ ]
484
+
485
+ for s1, s2 in pairs:
486
+ cache = {}
487
+ st = theano_code_(s1, cache=cache)
488
+
489
+ # Test hit with same instance
490
+ assert theano_code_(s1, cache=cache) is st
491
+
492
+ # Test miss with same instance but new cache
493
+ assert theano_code_(s1, cache={}) is not st
494
+
495
+ # Test hit with different but equivalent instance
496
+ assert theano_code_(s2, cache=cache) is st
497
+
498
+ def test_global_cache():
499
+ """ Test use of the global cache. """
500
+ from sympy.printing.theanocode import global_cache
501
+
502
+ backup = dict(global_cache)
503
+ try:
504
+ # Temporarily empty global cache
505
+ global_cache.clear()
506
+
507
+ for s in [x, X, f_t]:
508
+ with warns_deprecated_sympy():
509
+ st = theano_code(s)
510
+ assert theano_code(s) is st
511
+
512
+ finally:
513
+ # Restore global cache
514
+ global_cache.update(backup)
515
+
516
+ def test_cache_types_distinct():
517
+ """
518
+ Test that symbol-like objects of different types (Symbol, MatrixSymbol,
519
+ AppliedUndef) are distinguished by the cache even if they have the same
520
+ name.
521
+ """
522
+ symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]
523
+
524
+ cache = {} # Single shared cache
525
+ printed = {}
526
+
527
+ for s in symbols:
528
+ st = theano_code_(s, cache=cache)
529
+ assert st not in printed.values()
530
+ printed[s] = st
531
+
532
+ # Check all printed objects are distinct
533
+ assert len(set(map(id, printed.values()))) == len(symbols)
534
+
535
+ # Check retrieving
536
+ for s, st in printed.items():
537
+ with warns_deprecated_sympy():
538
+ assert theano_code(s, cache=cache) is st
539
+
540
+ def test_symbols_are_created_once():
541
+ """
542
+ Test that a symbol is cached and reused when it appears in an expression
543
+ more than once.
544
+ """
545
+ expr = sy.Add(x, x, evaluate=False)
546
+ comp = theano_code_(expr)
547
+
548
+ assert theq(comp, xt + xt)
549
+ assert not theq(comp, xt + theano_code_(x))
550
+
551
+ def test_cache_complex():
552
+ """
553
+ Test caching on a complicated expression with multiple symbols appearing
554
+ multiple times.
555
+ """
556
+ expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y)
557
+ symbol_names = {s.name for s in expr.free_symbols}
558
+ expr_t = theano_code_(expr)
559
+
560
+ # Iterate through variables in the Theano computational graph that the
561
+ # printed expression depends on
562
+ seen = set()
563
+ for v in theano.gof.graph.ancestors([expr_t]):
564
+ # Owner-less, non-constant variables should be our symbols
565
+ if v.owner is None and not isinstance(v, theano.gof.graph.Constant):
566
+ # Check it corresponds to a symbol and appears only once
567
+ assert v.name in symbol_names
568
+ assert v.name not in seen
569
+ seen.add(v.name)
570
+
571
+ # Check all were present
572
+ assert seen == symbol_names
573
+
574
+
575
+ def test_Piecewise():
576
+ # A piecewise linear
577
+ expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III
578
+ result = theano_code_(expr)
579
+ assert result.owner.op == tt.switch
580
+
581
+ expected = tt.switch(xt<0, 0, tt.switch(xt<2, xt, 1))
582
+ assert theq(result, expected)
583
+
584
+ expr = sy.Piecewise((x, x < 0))
585
+ result = theano_code_(expr)
586
+ expected = tt.switch(xt < 0, xt, np.nan)
587
+ assert theq(result, expected)
588
+
589
+ expr = sy.Piecewise((0, sy.And(x>0, x<2)), \
590
+ (x, sy.Or(x>2, x<0)))
591
+ result = theano_code_(expr)
592
+ expected = tt.switch(tt.and_(xt>0,xt<2), 0, \
593
+ tt.switch(tt.or_(xt>2, xt<0), xt, np.nan))
594
+ assert theq(result, expected)
595
+
596
+
597
+ def test_Relationals():
598
+ assert theq(theano_code_(sy.Eq(x, y)), tt.eq(xt, yt))
599
+ # assert theq(theano_code_(sy.Ne(x, y)), tt.neq(xt, yt)) # TODO - implement
600
+ assert theq(theano_code_(x > y), xt > yt)
601
+ assert theq(theano_code_(x < y), xt < yt)
602
+ assert theq(theano_code_(x >= y), xt >= yt)
603
+ assert theq(theano_code_(x <= y), xt <= yt)
604
+
605
+
606
+ def test_complexfunctions():
607
+ with warns_deprecated_sympy():
608
+ xt, yt = theano_code_(x, dtypes={x:'complex128'}), theano_code_(y, dtypes={y: 'complex128'})
609
+ from sympy.functions.elementary.complexes import conjugate
610
+ from theano.tensor import as_tensor_variable as atv
611
+ from theano.tensor import complex as cplx
612
+ with warns_deprecated_sympy():
613
+ assert theq(theano_code_(y*conjugate(x)), yt*(xt.conj()))
614
+ assert theq(theano_code_((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1)))
615
+
616
+
617
+ def test_constantfunctions():
618
+ with warns_deprecated_sympy():
619
+ tf = theano_function_([],[1+1j])
620
+ assert(tf()==1+1j)
621
+
622
+
623
+ def test_Exp1():
624
+ """
625
+ Test that exp(1) prints without error and evaluates close to SymPy's E
626
+ """
627
+ # sy.exp(1) should yield same instance of E as sy.E (singleton), but extra
628
+ # check added for sanity
629
+ e_a = sy.exp(1)
630
+ e_b = sy.E
631
+
632
+ np.testing.assert_allclose(float(e_a), np.e)
633
+ np.testing.assert_allclose(float(e_b), np.e)
634
+
635
+ e = theano_code_(e_a)
636
+ np.testing.assert_allclose(float(e_a), e.eval())
637
+
638
+ e = theano_code_(e_b)
639
+ np.testing.assert_allclose(float(e_b), e.eval())
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_torch.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import math
3
+
4
+ from sympy import symbols, Derivative
5
+ from sympy.printing.pytorch import torch_code
6
+ from sympy import (eye, MatrixSymbol, Matrix)
7
+ from sympy.tensor.array import NDimArray
8
+ from sympy.tensor.array.expressions.array_expressions import (
9
+ ArrayTensorProduct, ArrayAdd,
10
+ PermuteDims, ArrayDiagonal, _CodegenArrayAbstract)
11
+ from sympy.utilities.lambdify import lambdify
12
+ from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt
13
+ from sympy.functions import \
14
+ Abs, ceiling, exp, floor, sign, sin, asin, cos, \
15
+ acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \
16
+ re, im, arg, erf, loggamma, sqrt
17
+ from sympy.testing.pytest import skip
18
+ from sympy.external import import_module
19
+ from sympy.matrices.expressions import \
20
+ Determinant, HadamardProduct, Inverse, Trace
21
+ from sympy.matrices import randMatrix
22
+ from sympy.matrices import Identity, ZeroMatrix, OneMatrix
23
+ from sympy import conjugate, I
24
+ from sympy import Heaviside, gamma, polygamma
25
+
26
+
27
+
28
+ torch = import_module("torch")
29
+
30
+ M = MatrixSymbol("M", 3, 3)
31
+ N = MatrixSymbol("N", 3, 3)
32
+ P = MatrixSymbol("P", 3, 3)
33
+ Q = MatrixSymbol("Q", 3, 3)
34
+
35
+ x, y, z, t = symbols("x y z t")
36
+
37
+ if torch is not None:
38
+ llo = [list(range(i, i + 3)) for i in range(0, 9, 3)]
39
+ m3x3 = torch.tensor(llo, dtype=torch.float64)
40
+ m3x3sympy = Matrix(llo)
41
+
42
+
43
+ def _compare_torch_matrix(variables, expr):
44
+ f = lambdify(variables, expr, 'torch')
45
+
46
+ random_matrices = [randMatrix(i.shape[0], i.shape[1]) for i in variables]
47
+ random_variables = [torch.tensor(i.tolist(), dtype=torch.float64) for i in random_matrices]
48
+ r = f(*random_variables)
49
+ e = expr.subs(dict(zip(variables, random_matrices))).doit()
50
+
51
+ if isinstance(e, _CodegenArrayAbstract):
52
+ e = e.doit()
53
+
54
+ if hasattr(e, 'is_number') and e.is_number:
55
+ if isinstance(r, torch.Tensor) and r.dim() == 0:
56
+ r = r.item()
57
+ e = float(e)
58
+ assert abs(r - e) < 1e-6
59
+ return
60
+
61
+ if e.is_Matrix or isinstance(e, NDimArray):
62
+ e = torch.tensor(e.tolist(), dtype=torch.float64)
63
+ assert torch.allclose(r, e, atol=1e-6)
64
+ else:
65
+ raise TypeError(f"Cannot compare {type(r)} with {type(e)}")
66
+
67
+
68
+ def _compare_torch_scalar(variables, expr, rng=lambda: random.uniform(-5, 5)):
69
+ f = lambdify(variables, expr, 'torch')
70
+ rvs = [rng() for v in variables]
71
+ t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs]
72
+ r = f(*t_rvs)
73
+ if isinstance(r, torch.Tensor):
74
+ r = r.item()
75
+ e = expr.subs(dict(zip(variables, rvs))).doit()
76
+ assert abs(r - e) < 1e-6
77
+
78
+
79
+ def _compare_torch_relational(variables, expr, rng=lambda: random.randint(0, 10)):
80
+ f = lambdify(variables, expr, 'torch')
81
+ rvs = [rng() for v in variables]
82
+ t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs]
83
+ r = f(*t_rvs)
84
+ e = bool(expr.subs(dict(zip(variables, rvs))).doit())
85
+ assert r.item() == e
86
+
87
+
88
+ def test_torch_math():
89
+ if not torch:
90
+ skip("PyTorch not installed")
91
+
92
+ expr = Abs(x)
93
+ assert torch_code(expr) == "torch.abs(x)"
94
+ f = lambdify(x, expr, 'torch')
95
+ ma = torch.tensor([[-1, 2, -3, -4]], dtype=torch.float64)
96
+ y_abs = f(ma)
97
+ c = torch.abs(ma)
98
+ assert torch.all(y_abs == c)
99
+
100
+ expr = sign(x)
101
+ assert torch_code(expr) == "torch.sign(x)"
102
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-10, 10))
103
+
104
+ expr = ceiling(x)
105
+ assert torch_code(expr) == "torch.ceil(x)"
106
+ _compare_torch_scalar((x,), expr, rng=lambda: random.random())
107
+
108
+ expr = floor(x)
109
+ assert torch_code(expr) == "torch.floor(x)"
110
+ _compare_torch_scalar((x,), expr, rng=lambda: random.random())
111
+
112
+ expr = exp(x)
113
+ assert torch_code(expr) == "torch.exp(x)"
114
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
115
+
116
+ expr = sqrt(x)
117
+ assert torch_code(expr) == "torch.sqrt(x)"
118
+ _compare_torch_scalar((x,), expr, rng=lambda: random.random())
119
+
120
+ expr = x ** 4
121
+ assert torch_code(expr) == "torch.pow(x, 4)"
122
+ _compare_torch_scalar((x,), expr, rng=lambda: random.random())
123
+
124
+ expr = cos(x)
125
+ assert torch_code(expr) == "torch.cos(x)"
126
+ _compare_torch_scalar((x,), expr, rng=lambda: random.random())
127
+
128
+ expr = acos(x)
129
+ assert torch_code(expr) == "torch.acos(x)"
130
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99))
131
+
132
+ expr = sin(x)
133
+ assert torch_code(expr) == "torch.sin(x)"
134
+ _compare_torch_scalar((x,), expr, rng=lambda: random.random())
135
+
136
+ expr = asin(x)
137
+ assert torch_code(expr) == "torch.asin(x)"
138
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99))
139
+
140
+ expr = tan(x)
141
+ assert torch_code(expr) == "torch.tan(x)"
142
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-1.5, 1.5))
143
+
144
+ expr = atan(x)
145
+ assert torch_code(expr) == "torch.atan(x)"
146
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5))
147
+
148
+ expr = atan2(y, x)
149
+ assert torch_code(expr) == "torch.atan2(y, x)"
150
+ _compare_torch_scalar((y, x), expr, rng=lambda: random.uniform(-5, 5))
151
+
152
+ expr = cosh(x)
153
+ assert torch_code(expr) == "torch.cosh(x)"
154
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
155
+
156
+ expr = acosh(x)
157
+ assert torch_code(expr) == "torch.acosh(x)"
158
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(1.1, 5))
159
+
160
+ expr = sinh(x)
161
+ assert torch_code(expr) == "torch.sinh(x)"
162
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
163
+
164
+ expr = asinh(x)
165
+ assert torch_code(expr) == "torch.asinh(x)"
166
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5))
167
+
168
+ expr = tanh(x)
169
+ assert torch_code(expr) == "torch.tanh(x)"
170
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
171
+
172
+ expr = atanh(x)
173
+ assert torch_code(expr) == "torch.atanh(x)"
174
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.9, 0.9))
175
+
176
+ expr = erf(x)
177
+ assert torch_code(expr) == "torch.erf(x)"
178
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2))
179
+
180
+ expr = loggamma(x)
181
+ assert torch_code(expr) == "torch.lgamma(x)"
182
+ _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(0.5, 5))
183
+
184
+
185
+ def test_torch_complexes():
186
+ assert torch_code(re(x)) == "torch.real(x)"
187
+ assert torch_code(im(x)) == "torch.imag(x)"
188
+ assert torch_code(arg(x)) == "torch.angle(x)"
189
+
190
+
191
+ def test_torch_relational():
192
+ if not torch:
193
+ skip("PyTorch not installed")
194
+
195
+ expr = Eq(x, y)
196
+ assert torch_code(expr) == "torch.eq(x, y)"
197
+ _compare_torch_relational((x, y), expr)
198
+
199
+ expr = Ne(x, y)
200
+ assert torch_code(expr) == "torch.ne(x, y)"
201
+ _compare_torch_relational((x, y), expr)
202
+
203
+ expr = Ge(x, y)
204
+ assert torch_code(expr) == "torch.ge(x, y)"
205
+ _compare_torch_relational((x, y), expr)
206
+
207
+ expr = Gt(x, y)
208
+ assert torch_code(expr) == "torch.gt(x, y)"
209
+ _compare_torch_relational((x, y), expr)
210
+
211
+ expr = Le(x, y)
212
+ assert torch_code(expr) == "torch.le(x, y)"
213
+ _compare_torch_relational((x, y), expr)
214
+
215
+ expr = Lt(x, y)
216
+ assert torch_code(expr) == "torch.lt(x, y)"
217
+ _compare_torch_relational((x, y), expr)
218
+
219
+
220
+ def test_torch_matrix():
221
+ if torch is None:
222
+ skip("PyTorch not installed")
223
+
224
+ expr = M
225
+ assert torch_code(expr) == "M"
226
+ f = lambdify((M,), expr, "torch")
227
+ eye_mat = eye(3)
228
+ eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64)
229
+ assert torch.allclose(f(eye_tensor), eye_tensor)
230
+
231
+ expr = M * N
232
+ assert torch_code(expr) == "torch.matmul(M, N)"
233
+ _compare_torch_matrix((M, N), expr)
234
+
235
+ expr = M ** 3
236
+ assert torch_code(expr) == "torch.mm(torch.mm(M, M), M)"
237
+ _compare_torch_matrix((M,), expr)
238
+
239
+ expr = M * N * P * Q
240
+ assert torch_code(expr) == "torch.matmul(torch.matmul(torch.matmul(M, N), P), Q)"
241
+ _compare_torch_matrix((M, N, P, Q), expr)
242
+
243
+ expr = Trace(M)
244
+ assert torch_code(expr) == "torch.trace(M)"
245
+ _compare_torch_matrix((M,), expr)
246
+
247
+ expr = Determinant(M)
248
+ assert torch_code(expr) == "torch.det(M)"
249
+ _compare_torch_matrix((M,), expr)
250
+
251
+ expr = HadamardProduct(M, N)
252
+ assert torch_code(expr) == "torch.mul(M, N)"
253
+ _compare_torch_matrix((M, N), expr)
254
+
255
+ expr = Inverse(M)
256
+ assert torch_code(expr) == "torch.linalg.inv(M)"
257
+
258
+ # For inverse, use a matrix that's guaranteed to be invertible
259
+ eye_mat = eye(3)
260
+ eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64)
261
+ f = lambdify((M,), expr, "torch")
262
+ result = f(eye_tensor)
263
+ expected = torch.linalg.inv(eye_tensor)
264
+ assert torch.allclose(result, expected)
265
+
266
+
267
+ def test_torch_array_operations():
268
+ if not torch:
269
+ skip("PyTorch not installed")
270
+
271
+ M = MatrixSymbol("M", 2, 2)
272
+ N = MatrixSymbol("N", 2, 2)
273
+ P = MatrixSymbol("P", 2, 2)
274
+ Q = MatrixSymbol("Q", 2, 2)
275
+
276
+ ma = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float64)
277
+ mb = torch.tensor([[1., -2.], [-1., 3.]], dtype=torch.float64)
278
+ mc = torch.tensor([[2., 0.], [1., 2.]], dtype=torch.float64)
279
+ md = torch.tensor([[1., -1.], [4., 7.]], dtype=torch.float64)
280
+
281
+ cg = ArrayTensorProduct(M, N)
282
+ assert torch_code(cg) == 'torch.einsum("ab,cd", M, N)'
283
+ f = lambdify((M, N), cg, 'torch')
284
+ y = f(ma, mb)
285
+ c = torch.einsum("ij,kl", ma, mb)
286
+ assert torch.allclose(y, c)
287
+
288
+ cg = ArrayAdd(M, N)
289
+ assert torch_code(cg) == 'torch.add(M, N)'
290
+ f = lambdify((M, N), cg, 'torch')
291
+ y = f(ma, mb)
292
+ c = ma + mb
293
+ assert torch.allclose(y, c)
294
+
295
+ cg = ArrayAdd(M, N, P)
296
+ assert torch_code(cg) == 'torch.add(torch.add(M, N), P)'
297
+ f = lambdify((M, N, P), cg, 'torch')
298
+ y = f(ma, mb, mc)
299
+ c = ma + mb + mc
300
+ assert torch.allclose(y, c)
301
+
302
+ cg = ArrayAdd(M, N, P, Q)
303
+ assert torch_code(cg) == 'torch.add(torch.add(torch.add(M, N), P), Q)'
304
+ f = lambdify((M, N, P, Q), cg, 'torch')
305
+ y = f(ma, mb, mc, md)
306
+ c = ma + mb + mc + md
307
+ assert torch.allclose(y, c)
308
+
309
+ cg = PermuteDims(M, [1, 0])
310
+ assert torch_code(cg) == 'M.permute(1, 0)'
311
+ f = lambdify((M,), cg, 'torch')
312
+ y = f(ma)
313
+ c = ma.T
314
+ assert torch.allclose(y, c)
315
+
316
+ cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0])
317
+ assert torch_code(cg) == 'torch.einsum("ab,cd", M, N).permute(1, 2, 3, 0)'
318
+ f = lambdify((M, N), cg, 'torch')
319
+ y = f(ma, mb)
320
+ c = torch.einsum("ab,cd", ma, mb).permute(1, 2, 3, 0)
321
+ assert torch.allclose(y, c)
322
+
323
+ cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2))
324
+ assert torch_code(cg) == 'torch.einsum("ab,bc->acb", M, N)'
325
+ f = lambdify((M, N), cg, 'torch')
326
+ y = f(ma, mb)
327
+ c = torch.einsum("ab,bc->acb", ma, mb)
328
+ assert torch.allclose(y, c)
329
+
330
+
331
+ def test_torch_derivative():
332
+ """Test derivative handling."""
333
+ expr = Derivative(sin(x), x)
334
+ assert torch_code(expr) == 'torch.autograd.grad(torch.sin(x), x)[0]'
335
+
336
+
337
+ def test_torch_printing_dtype():
338
+ if not torch:
339
+ skip("PyTorch not installed")
340
+
341
+ # matrix printing with default dtype
342
+ expr = Matrix([[x, sin(y)], [exp(z), -t]])
343
+ assert "dtype=torch.float64" in torch_code(expr)
344
+
345
+ # explicit dtype
346
+ assert "dtype=torch.float32" in torch_code(expr, dtype="torch.float32")
347
+
348
+ # with requires_grad
349
+ result = torch_code(expr, requires_grad=True)
350
+ assert "requires_grad=True" in result
351
+ assert "dtype=torch.float64" in result
352
+
353
+ # both
354
+ result = torch_code(expr, requires_grad=True, dtype="torch.float32")
355
+ assert "requires_grad=True" in result
356
+ assert "dtype=torch.float32" in result
357
+
358
+
359
+ def test_requires_grad():
360
+ if not torch:
361
+ skip("PyTorch not installed")
362
+
363
+ expr = sin(x) + cos(y)
364
+ f = lambdify([x, y], expr, 'torch')
365
+
366
+ # make sure the gradients flow
367
+ x_val = torch.tensor(1.0, requires_grad=True)
368
+ y_val = torch.tensor(2.0, requires_grad=True)
369
+ result = f(x_val, y_val)
370
+ assert result.requires_grad
371
+ result.backward()
372
+
373
+ # x_val.grad should be cos(x_val) which is close to cos(1.0)
374
+ assert abs(x_val.grad.item() - float(cos(1.0).evalf())) < 1e-6
375
+
376
+ # y_val.grad should be -sin(y_val) which is close to -sin(2.0)
377
+ assert abs(y_val.grad.item() - float(-sin(2.0).evalf())) < 1e-6
378
+
379
+
380
+ def test_torch_multi_variable_derivatives():
381
+ if not torch:
382
+ skip("PyTorch not installed")
383
+
384
+ x, y, z = symbols("x y z")
385
+
386
+ expr = Derivative(sin(x), x)
387
+ assert torch_code(expr) == "torch.autograd.grad(torch.sin(x), x)[0]"
388
+
389
+ expr = Derivative(sin(x), (x, 2))
390
+ assert torch_code(
391
+ expr) == "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]"
392
+
393
+ expr = Derivative(sin(x * y), x, y)
394
+ result = torch_code(expr)
395
+ expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x*y), x, create_graph=True)[0], y, create_graph=True)[0]"
396
+ normalized_result = result.replace(" ", "")
397
+ normalized_expected = expected.replace(" ", "")
398
+ assert normalized_result == normalized_expected
399
+
400
+ expr = Derivative(sin(x), x, x)
401
+ result = torch_code(expr)
402
+ expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]"
403
+ assert result == expected
404
+
405
+ expr = Derivative(sin(x * y * z), x, (y, 2), z)
406
+ result = torch_code(expr)
407
+ expected = "torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.sin(x*y*z), x, create_graph=True)[0], y, create_graph=True)[0], y, create_graph=True)[0], z, create_graph=True)[0]"
408
+ normalized_result = result.replace(" ", "")
409
+ normalized_expected = expected.replace(" ", "")
410
+ assert normalized_result == normalized_expected
411
+
412
+
413
+ def test_torch_derivative_lambdify():
414
+ if not torch:
415
+ skip("PyTorch not installed")
416
+
417
+ x = symbols("x")
418
+ y = symbols("y")
419
+
420
+ expr = Derivative(x ** 2, x)
421
+ f = lambdify(x, expr, 'torch')
422
+ x_val = torch.tensor(2.0, requires_grad=True)
423
+ result = f(x_val)
424
+ assert torch.isclose(result, torch.tensor(4.0))
425
+
426
+ expr = Derivative(sin(x), (x, 2))
427
+ f = lambdify(x, expr, 'torch')
428
+ # Second derivative of sin(x) at x=0 is 0, not -1
429
+ x_val = torch.tensor(0.0, requires_grad=True)
430
+ result = f(x_val)
431
+ assert torch.isclose(result, torch.tensor(0.0), atol=1e-5)
432
+
433
+ x_val = torch.tensor(math.pi / 2, requires_grad=True)
434
+ result = f(x_val)
435
+ assert torch.isclose(result, torch.tensor(-1.0), atol=1e-5)
436
+
437
+ expr = Derivative(x * y ** 2, x, y)
438
+ f = lambdify((x, y), expr, 'torch')
439
+ x_val = torch.tensor(2.0, requires_grad=True)
440
+ y_val = torch.tensor(3.0, requires_grad=True)
441
+ result = f(x_val, y_val)
442
+ assert torch.isclose(result, torch.tensor(6.0))
443
+
444
+
445
+ def test_torch_special_matrices():
446
+ if not torch:
447
+ skip("PyTorch not installed")
448
+
449
+ expr = Identity(3)
450
+ assert torch_code(expr) == "torch.eye(3)"
451
+
452
+ n = symbols("n")
453
+ expr = Identity(n)
454
+ assert torch_code(expr) == "torch.eye(n, n)"
455
+
456
+ expr = ZeroMatrix(2, 3)
457
+ assert torch_code(expr) == "torch.zeros((2, 3))"
458
+
459
+ m, n = symbols("m n")
460
+ expr = ZeroMatrix(m, n)
461
+ assert torch_code(expr) == "torch.zeros((m, n))"
462
+
463
+ expr = OneMatrix(2, 3)
464
+ assert torch_code(expr) == "torch.ones((2, 3))"
465
+
466
+ expr = OneMatrix(m, n)
467
+ assert torch_code(expr) == "torch.ones((m, n))"
468
+
469
+
470
+ def test_torch_special_matrices_lambdify():
471
+ if not torch:
472
+ skip("PyTorch not installed")
473
+
474
+ expr = Identity(3)
475
+ f = lambdify([], expr, 'torch')
476
+ result = f()
477
+ expected = torch.eye(3)
478
+ assert torch.allclose(result, expected)
479
+
480
+ expr = ZeroMatrix(2, 3)
481
+ f = lambdify([], expr, 'torch')
482
+ result = f()
483
+ expected = torch.zeros((2, 3))
484
+ assert torch.allclose(result, expected)
485
+
486
+ expr = OneMatrix(2, 3)
487
+ f = lambdify([], expr, 'torch')
488
+ result = f()
489
+ expected = torch.ones((2, 3))
490
+ assert torch.allclose(result, expected)
491
+
492
+
493
+ def test_torch_complex_operations():
494
+ if not torch:
495
+ skip("PyTorch not installed")
496
+
497
+ expr = conjugate(x)
498
+ assert torch_code(expr) == "torch.conj(x)"
499
+
500
+ # SymPy distributes conjugate over addition and applies specific rules for each term
501
+ expr = conjugate(sin(x) + I * cos(y))
502
+ assert torch_code(expr) == "torch.sin(torch.conj(x)) - 1j*torch.cos(torch.conj(y))"
503
+
504
+ expr = I
505
+ assert torch_code(expr) == "1j"
506
+
507
+ expr = 2 * I + x
508
+ assert torch_code(expr) == "x + 2*1j"
509
+
510
+ expr = exp(I * x)
511
+ assert torch_code(expr) == "torch.exp(1j*x)"
512
+
513
+
514
+ def test_torch_special_functions():
515
+ if not torch:
516
+ skip("PyTorch not installed")
517
+
518
+ expr = Heaviside(x)
519
+ assert torch_code(expr) == "torch.heaviside(x, 1/2)"
520
+
521
+ expr = Heaviside(x, 0)
522
+ assert torch_code(expr) == "torch.heaviside(x, 0)"
523
+
524
+ expr = gamma(x)
525
+ assert torch_code(expr) == "torch.special.gamma(x)"
526
+
527
+ expr = polygamma(0, x) # Use polygamma instead of digamma because sympy will default to that anyway
528
+ assert torch_code(expr) == "torch.special.digamma(x)"
529
+
530
+ expr = gamma(sin(x))
531
+ assert torch_code(expr) == "torch.special.gamma(torch.sin(x))"
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tree.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.printing.tree import tree
2
+ from sympy.testing.pytest import XFAIL
3
+
4
+
5
+ # Remove this flag after making _assumptions cache deterministic.
6
+ @XFAIL
7
+ def test_print_tree_MatAdd():
8
+ from sympy.matrices.expressions import MatrixSymbol
9
+ A = MatrixSymbol('A', 3, 3)
10
+ B = MatrixSymbol('B', 3, 3)
11
+
12
+ test_str = [
13
+ 'MatAdd: A + B\n',
14
+ 'algebraic: False\n',
15
+ 'commutative: False\n',
16
+ 'complex: False\n',
17
+ 'composite: False\n',
18
+ 'even: False\n',
19
+ 'extended_negative: False\n',
20
+ 'extended_nonnegative: False\n',
21
+ 'extended_nonpositive: False\n',
22
+ 'extended_nonzero: False\n',
23
+ 'extended_positive: False\n',
24
+ 'extended_real: False\n',
25
+ 'imaginary: False\n',
26
+ 'integer: False\n',
27
+ 'irrational: False\n',
28
+ 'negative: False\n',
29
+ 'noninteger: False\n',
30
+ 'nonnegative: False\n',
31
+ 'nonpositive: False\n',
32
+ 'nonzero: False\n',
33
+ 'odd: False\n',
34
+ 'positive: False\n',
35
+ 'prime: False\n',
36
+ 'rational: False\n',
37
+ 'real: False\n',
38
+ 'transcendental: False\n',
39
+ 'zero: False\n',
40
+ '+-MatrixSymbol: A\n',
41
+ '| algebraic: False\n',
42
+ '| commutative: False\n',
43
+ '| complex: False\n',
44
+ '| composite: False\n',
45
+ '| even: False\n',
46
+ '| extended_negative: False\n',
47
+ '| extended_nonnegative: False\n',
48
+ '| extended_nonpositive: False\n',
49
+ '| extended_nonzero: False\n',
50
+ '| extended_positive: False\n',
51
+ '| extended_real: False\n',
52
+ '| imaginary: False\n',
53
+ '| integer: False\n',
54
+ '| irrational: False\n',
55
+ '| negative: False\n',
56
+ '| noninteger: False\n',
57
+ '| nonnegative: False\n',
58
+ '| nonpositive: False\n',
59
+ '| nonzero: False\n',
60
+ '| odd: False\n',
61
+ '| positive: False\n',
62
+ '| prime: False\n',
63
+ '| rational: False\n',
64
+ '| real: False\n',
65
+ '| transcendental: False\n',
66
+ '| zero: False\n',
67
+ '| +-Symbol: A\n',
68
+ '| | commutative: True\n',
69
+ '| +-Integer: 3\n',
70
+ '| | algebraic: True\n',
71
+ '| | commutative: True\n',
72
+ '| | complex: True\n',
73
+ '| | extended_negative: False\n',
74
+ '| | extended_nonnegative: True\n',
75
+ '| | extended_real: True\n',
76
+ '| | finite: True\n',
77
+ '| | hermitian: True\n',
78
+ '| | imaginary: False\n',
79
+ '| | infinite: False\n',
80
+ '| | integer: True\n',
81
+ '| | irrational: False\n',
82
+ '| | negative: False\n',
83
+ '| | noninteger: False\n',
84
+ '| | nonnegative: True\n',
85
+ '| | rational: True\n',
86
+ '| | real: True\n',
87
+ '| | transcendental: False\n',
88
+ '| +-Integer: 3\n',
89
+ '| algebraic: True\n',
90
+ '| commutative: True\n',
91
+ '| complex: True\n',
92
+ '| extended_negative: False\n',
93
+ '| extended_nonnegative: True\n',
94
+ '| extended_real: True\n',
95
+ '| finite: True\n',
96
+ '| hermitian: True\n',
97
+ '| imaginary: False\n',
98
+ '| infinite: False\n',
99
+ '| integer: True\n',
100
+ '| irrational: False\n',
101
+ '| negative: False\n',
102
+ '| noninteger: False\n',
103
+ '| nonnegative: True\n',
104
+ '| rational: True\n',
105
+ '| real: True\n',
106
+ '| transcendental: False\n',
107
+ '+-MatrixSymbol: B\n',
108
+ ' algebraic: False\n',
109
+ ' commutative: False\n',
110
+ ' complex: False\n',
111
+ ' composite: False\n',
112
+ ' even: False\n',
113
+ ' extended_negative: False\n',
114
+ ' extended_nonnegative: False\n',
115
+ ' extended_nonpositive: False\n',
116
+ ' extended_nonzero: False\n',
117
+ ' extended_positive: False\n',
118
+ ' extended_real: False\n',
119
+ ' imaginary: False\n',
120
+ ' integer: False\n',
121
+ ' irrational: False\n',
122
+ ' negative: False\n',
123
+ ' noninteger: False\n',
124
+ ' nonnegative: False\n',
125
+ ' nonpositive: False\n',
126
+ ' nonzero: False\n',
127
+ ' odd: False\n',
128
+ ' positive: False\n',
129
+ ' prime: False\n',
130
+ ' rational: False\n',
131
+ ' real: False\n',
132
+ ' transcendental: False\n',
133
+ ' zero: False\n',
134
+ ' +-Symbol: B\n',
135
+ ' | commutative: True\n',
136
+ ' +-Integer: 3\n',
137
+ ' | algebraic: True\n',
138
+ ' | commutative: True\n',
139
+ ' | complex: True\n',
140
+ ' | extended_negative: False\n',
141
+ ' | extended_nonnegative: True\n',
142
+ ' | extended_real: True\n',
143
+ ' | finite: True\n',
144
+ ' | hermitian: True\n',
145
+ ' | imaginary: False\n',
146
+ ' | infinite: False\n',
147
+ ' | integer: True\n',
148
+ ' | irrational: False\n',
149
+ ' | negative: False\n',
150
+ ' | noninteger: False\n',
151
+ ' | nonnegative: True\n',
152
+ ' | rational: True\n',
153
+ ' | real: True\n',
154
+ ' | transcendental: False\n',
155
+ ' +-Integer: 3\n',
156
+ ' algebraic: True\n',
157
+ ' commutative: True\n',
158
+ ' complex: True\n',
159
+ ' extended_negative: False\n',
160
+ ' extended_nonnegative: True\n',
161
+ ' extended_real: True\n',
162
+ ' finite: True\n',
163
+ ' hermitian: True\n',
164
+ ' imaginary: False\n',
165
+ ' infinite: False\n',
166
+ ' integer: True\n',
167
+ ' irrational: False\n',
168
+ ' negative: False\n',
169
+ ' noninteger: False\n',
170
+ ' nonnegative: True\n',
171
+ ' rational: True\n',
172
+ ' real: True\n',
173
+ ' transcendental: False\n'
174
+ ]
175
+
176
+ assert tree(A + B) == "".join(test_str)
177
+
178
+
179
+ def test_print_tree_MatAdd_noassumptions():
180
+ from sympy.matrices.expressions import MatrixSymbol
181
+ A = MatrixSymbol('A', 3, 3)
182
+ B = MatrixSymbol('B', 3, 3)
183
+
184
+ test_str = \
185
+ """MatAdd: A + B
186
+ +-MatrixSymbol: A
187
+ | +-Str: A
188
+ | +-Integer: 3
189
+ | +-Integer: 3
190
+ +-MatrixSymbol: B
191
+ +-Str: B
192
+ +-Integer: 3
193
+ +-Integer: 3
194
+ """
195
+
196
+ assert tree(A + B, assumptions=False) == test_str
.venv/lib/python3.13/site-packages/sympy/solvers/benchmarks/__init__.py ADDED
File without changes