MTerryJack commited on
Commit
4274159
·
verified ·
1 Parent(s): 67448fa

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. .venv/lib/python3.13/site-packages/sympy/discrete/tests/__init__.py +0 -0
  2. .venv/lib/python3.13/site-packages/sympy/discrete/tests/test_convolutions.py +392 -0
  3. .venv/lib/python3.13/site-packages/sympy/discrete/tests/test_recurrences.py +59 -0
  4. .venv/lib/python3.13/site-packages/sympy/discrete/tests/test_transforms.py +154 -0
  5. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/__init__.py +0 -0
  6. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_cartan_type.py +12 -0
  7. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_dynkin_diagram.py +9 -0
  8. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_root_system.py +18 -0
  9. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_A.py +17 -0
  10. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_D.py +19 -0
  11. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_E.py +22 -0
  12. .venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_weyl_group.py +35 -0
  13. .venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/test_conflict.py +62 -0
  14. .venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/test_core.py +213 -0
  15. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/__init__.py +0 -0
  16. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_bbp_pi.py +134 -0
  17. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_continued_fraction.py +77 -0
  18. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_digits.py +55 -0
  19. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_ecm.py +63 -0
  20. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_egyptian_fraction.py +49 -0
  21. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_elliptic_curve.py +20 -0
  22. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_factor_.py +702 -0
  23. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_generate.py +285 -0
  24. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_hypothesis.py +24 -0
  25. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_modular.py +34 -0
  26. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_multinomial.py +48 -0
  27. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_partitions.py +28 -0
  28. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_primetest.py +235 -0
  29. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_qs.py +110 -0
  30. .venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_residue.py +349 -0
  31. .venv/lib/python3.13/site-packages/sympy/printing/pretty/__init__.py +12 -0
  32. .venv/lib/python3.13/site-packages/sympy/printing/pretty/pretty.py +0 -0
  33. .venv/lib/python3.13/site-packages/sympy/printing/pretty/pretty_symbology.py +731 -0
  34. .venv/lib/python3.13/site-packages/sympy/printing/pretty/stringpict.py +537 -0
  35. .venv/lib/python3.13/site-packages/sympy/printing/pretty/tests/__init__.py +0 -0
  36. .venv/lib/python3.13/site-packages/sympy/printing/pretty/tests/test_pretty.py +0 -0
  37. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_fortran.py +854 -0
  38. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_llvmjit.py +224 -0
  39. .venv/lib/python3.13/site-packages/sympy/printing/tests/test_rcode.py +476 -0
  40. .venv/lib/python3.13/site-packages/sympy/strategies/branch/__init__.py +14 -0
  41. .venv/lib/python3.13/site-packages/sympy/strategies/branch/core.py +116 -0
  42. .venv/lib/python3.13/site-packages/sympy/strategies/branch/tests/test_traverse.py +53 -0
  43. .venv/lib/python3.13/site-packages/sympy/strategies/branch/tools.py +12 -0
  44. .venv/lib/python3.13/site-packages/sympy/strategies/branch/traverse.py +25 -0
  45. .venv/lib/python3.13/site-packages/sympy/strategies/tests/__init__.py +0 -0
  46. .venv/lib/python3.13/site-packages/sympy/strategies/tests/test_core.py +118 -0
  47. .venv/lib/python3.13/site-packages/sympy/strategies/tests/test_rl.py +78 -0
  48. .venv/lib/python3.13/site-packages/sympy/strategies/tests/test_tools.py +32 -0
  49. .venv/lib/python3.13/site-packages/sympy/strategies/tests/test_traverse.py +84 -0
  50. .venv/lib/python3.13/site-packages/sympy/strategies/tests/test_tree.py +92 -0
.venv/lib/python3.13/site-packages/sympy/discrete/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/discrete/tests/test_convolutions.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.numbers import (E, Rational, pi)
2
+ from sympy.functions.elementary.exponential import exp
3
+ from sympy.functions.elementary.miscellaneous import sqrt
4
+ from sympy.core import S, symbols, I
5
+ from sympy.discrete.convolutions import (
6
+ convolution, convolution_fft, convolution_ntt, convolution_fwht,
7
+ convolution_subset, covering_product, intersecting_product,
8
+ convolution_int)
9
+ from sympy.testing.pytest import raises
10
+ from sympy.abc import x, y
11
+
12
+ def test_convolution():
13
+ # fft
14
+ a = [1, Rational(5, 3), sqrt(3), Rational(7, 5)]
15
+ b = [9, 5, 5, 4, 3, 2]
16
+ c = [3, 5, 3, 7, 8]
17
+ d = [1422, 6572, 3213, 5552]
18
+ e = [-1, Rational(5, 3), Rational(7, 5)]
19
+
20
+ assert convolution(a, b) == convolution_fft(a, b)
21
+ assert convolution(a, b, dps=9) == convolution_fft(a, b, dps=9)
22
+ assert convolution(a, d, dps=7) == convolution_fft(d, a, dps=7)
23
+ assert convolution(a, d[1:], dps=3) == convolution_fft(d[1:], a, dps=3)
24
+
25
+ # prime moduli of the form (m*2**k + 1), sequence length
26
+ # should be a divisor of 2**k
27
+ p = 7*17*2**23 + 1
28
+ q = 19*2**10 + 1
29
+
30
+ # ntt
31
+ assert convolution(d, b, prime=q) == convolution_ntt(b, d, prime=q)
32
+ assert convolution(c, b, prime=p) == convolution_ntt(b, c, prime=p)
33
+ assert convolution(d, c, prime=p) == convolution_ntt(c, d, prime=p)
34
+ raises(TypeError, lambda: convolution(b, d, dps=5, prime=q))
35
+ raises(TypeError, lambda: convolution(b, d, dps=6, prime=q))
36
+
37
+ # fwht
38
+ assert convolution(a, b, dyadic=True) == convolution_fwht(a, b)
39
+ assert convolution(a, b, dyadic=False) == convolution(a, b)
40
+ raises(TypeError, lambda: convolution(b, d, dps=2, dyadic=True))
41
+ raises(TypeError, lambda: convolution(b, d, prime=p, dyadic=True))
42
+ raises(TypeError, lambda: convolution(a, b, dps=2, dyadic=True))
43
+ raises(TypeError, lambda: convolution(b, c, prime=p, dyadic=True))
44
+
45
+ # subset
46
+ assert convolution(a, b, subset=True) == convolution_subset(a, b) == \
47
+ convolution(a, b, subset=True, dyadic=False) == \
48
+ convolution(a, b, subset=True)
49
+ assert convolution(a, b, subset=False) == convolution(a, b)
50
+ raises(TypeError, lambda: convolution(a, b, subset=True, dyadic=True))
51
+ raises(TypeError, lambda: convolution(c, d, subset=True, dps=6))
52
+ raises(TypeError, lambda: convolution(a, c, subset=True, prime=q))
53
+
54
+ # integer
55
+ assert convolution([0], [0]) == convolution_int([0], [0])
56
+ assert convolution(b, c) == convolution_int(b, c)
57
+
58
+ # rational
59
+ assert convolution([Rational(1,2)], [Rational(1,2)]) == [Rational(1, 4)]
60
+ assert convolution(b, e) == [-9, 10, Rational(239, 15), Rational(34, 3),
61
+ Rational(32, 3), Rational(43, 5), Rational(113, 15),
62
+ Rational(14, 5)]
63
+
64
+
65
+ def test_cyclic_convolution():
66
+ # fft
67
+ a = [1, Rational(5, 3), sqrt(3), Rational(7, 5)]
68
+ b = [9, 5, 5, 4, 3, 2]
69
+
70
+ assert convolution([1, 2, 3], [4, 5, 6], cycle=0) == \
71
+ convolution([1, 2, 3], [4, 5, 6], cycle=5) == \
72
+ convolution([1, 2, 3], [4, 5, 6])
73
+
74
+ assert convolution([1, 2, 3], [4, 5, 6], cycle=3) == [31, 31, 28]
75
+
76
+ a = [Rational(1, 3), Rational(7, 3), Rational(5, 9), Rational(2, 7), Rational(5, 8)]
77
+ b = [Rational(3, 5), Rational(4, 7), Rational(7, 8), Rational(8, 9)]
78
+
79
+ assert convolution(a, b, cycle=0) == \
80
+ convolution(a, b, cycle=len(a) + len(b) - 1)
81
+
82
+ assert convolution(a, b, cycle=4) == [Rational(87277, 26460), Rational(30521, 11340),
83
+ Rational(11125, 4032), Rational(3653, 1080)]
84
+
85
+ assert convolution(a, b, cycle=6) == [Rational(20177, 20160), Rational(676, 315), Rational(47, 24),
86
+ Rational(3053, 1080), Rational(16397, 5292), Rational(2497, 2268)]
87
+
88
+ assert convolution(a, b, cycle=9) == \
89
+ convolution(a, b, cycle=0) + [S.Zero]
90
+
91
+ # ntt
92
+ a = [2313, 5323532, S(3232), 42142, 42242421]
93
+ b = [S(33456), 56757, 45754, 432423]
94
+
95
+ assert convolution(a, b, prime=19*2**10 + 1, cycle=0) == \
96
+ convolution(a, b, prime=19*2**10 + 1, cycle=8) == \
97
+ convolution(a, b, prime=19*2**10 + 1)
98
+
99
+ assert convolution(a, b, prime=19*2**10 + 1, cycle=5) == [96, 17146, 2664,
100
+ 15534, 3517]
101
+
102
+ assert convolution(a, b, prime=19*2**10 + 1, cycle=7) == [4643, 3458, 1260,
103
+ 15534, 3517, 16314, 13688]
104
+
105
+ assert convolution(a, b, prime=19*2**10 + 1, cycle=9) == \
106
+ convolution(a, b, prime=19*2**10 + 1) + [0]
107
+
108
+ # fwht
109
+ u, v, w, x, y = symbols('u v w x y')
110
+ p, q, r, s, t = symbols('p q r s t')
111
+ c = [u, v, w, x, y]
112
+ d = [p, q, r, s, t]
113
+
114
+ assert convolution(a, b, dyadic=True, cycle=3) == \
115
+ [2499522285783, 19861417974796, 4702176579021]
116
+
117
+ assert convolution(a, b, dyadic=True, cycle=5) == [2718149225143,
118
+ 2114320852171, 20571217906407, 246166418903, 1413262436976]
119
+
120
+ assert convolution(c, d, dyadic=True, cycle=4) == \
121
+ [p*u + p*y + q*v + r*w + s*x + t*u + t*y,
122
+ p*v + q*u + q*y + r*x + s*w + t*v,
123
+ p*w + q*x + r*u + r*y + s*v + t*w,
124
+ p*x + q*w + r*v + s*u + s*y + t*x]
125
+
126
+ assert convolution(c, d, dyadic=True, cycle=6) == \
127
+ [p*u + q*v + r*w + r*y + s*x + t*w + t*y,
128
+ p*v + q*u + r*x + s*w + s*y + t*x,
129
+ p*w + q*x + r*u + s*v,
130
+ p*x + q*w + r*v + s*u,
131
+ p*y + t*u,
132
+ q*y + t*v]
133
+
134
+ # subset
135
+ assert convolution(a, b, subset=True, cycle=7) == [18266671799811,
136
+ 178235365533, 213958794, 246166418903, 1413262436976,
137
+ 2397553088697, 1932759730434]
138
+
139
+ assert convolution(a[1:], b, subset=True, cycle=4) == \
140
+ [178104086592, 302255835516, 244982785880, 3717819845434]
141
+
142
+ assert convolution(a, b[:-1], subset=True, cycle=6) == [1932837114162,
143
+ 178235365533, 213958794, 245166224504, 1413262436976, 2397553088697]
144
+
145
+ assert convolution(c, d, subset=True, cycle=3) == \
146
+ [p*u + p*x + q*w + r*v + r*y + s*u + t*w,
147
+ p*v + p*y + q*u + s*y + t*u + t*x,
148
+ p*w + q*y + r*u + t*v]
149
+
150
+ assert convolution(c, d, subset=True, cycle=5) == \
151
+ [p*u + q*y + t*v,
152
+ p*v + q*u + r*y + t*w,
153
+ p*w + r*u + s*y + t*x,
154
+ p*x + q*w + r*v + s*u,
155
+ p*y + t*u]
156
+
157
+ raises(ValueError, lambda: convolution([1, 2, 3], [4, 5, 6], cycle=-1))
158
+
159
+
160
+ def test_convolution_fft():
161
+ assert all(convolution_fft([], x, dps=y) == [] for x in ([], [1]) for y in (None, 3))
162
+ assert convolution_fft([1, 2, 3], [4, 5, 6]) == [4, 13, 28, 27, 18]
163
+ assert convolution_fft([1], [5, 6, 7]) == [5, 6, 7]
164
+ assert convolution_fft([1, 3], [5, 6, 7]) == [5, 21, 25, 21]
165
+
166
+ assert convolution_fft([1 + 2*I], [2 + 3*I]) == [-4 + 7*I]
167
+ assert convolution_fft([1 + 2*I, 3 + 4*I, 5 + 3*I/5], [Rational(2, 5) + 4*I/7]) == \
168
+ [Rational(-26, 35) + I*48/35, Rational(-38, 35) + I*116/35, Rational(58, 35) + I*542/175]
169
+
170
+ assert convolution_fft([Rational(3, 4), Rational(5, 6)], [Rational(7, 8), Rational(1, 3), Rational(2, 5)]) == \
171
+ [Rational(21, 32), Rational(47, 48), Rational(26, 45), Rational(1, 3)]
172
+
173
+ assert convolution_fft([Rational(1, 9), Rational(2, 3), Rational(3, 5)], [Rational(2, 5), Rational(3, 7), Rational(4, 9)]) == \
174
+ [Rational(2, 45), Rational(11, 35), Rational(8152, 14175), Rational(523, 945), Rational(4, 15)]
175
+
176
+ assert convolution_fft([pi, E, sqrt(2)], [sqrt(3), 1/pi, 1/E]) == \
177
+ [sqrt(3)*pi, 1 + sqrt(3)*E, E/pi + pi*exp(-1) + sqrt(6),
178
+ sqrt(2)/pi + 1, sqrt(2)*exp(-1)]
179
+
180
+ assert convolution_fft([2321, 33123], [5321, 6321, 71323]) == \
181
+ [12350041, 190918524, 374911166, 2362431729]
182
+
183
+ assert convolution_fft([312313, 31278232], [32139631, 319631]) == \
184
+ [10037624576503, 1005370659728895, 9997492572392]
185
+
186
+ raises(TypeError, lambda: convolution_fft(x, y))
187
+ raises(ValueError, lambda: convolution_fft([x, y], [y, x]))
188
+
189
+
190
+ def test_convolution_ntt():
191
+ # prime moduli of the form (m*2**k + 1), sequence length
192
+ # should be a divisor of 2**k
193
+ p = 7*17*2**23 + 1
194
+ q = 19*2**10 + 1
195
+ r = 2*500000003 + 1 # only for sequences of length 1 or 2
196
+ # s = 2*3*5*7 # composite modulus
197
+
198
+ assert all(convolution_ntt([], x, prime=y) == [] for x in ([], [1]) for y in (p, q, r))
199
+ assert convolution_ntt([2], [3], r) == [6]
200
+ assert convolution_ntt([2, 3], [4], r) == [8, 12]
201
+
202
+ assert convolution_ntt([32121, 42144, 4214, 4241], [32132, 3232, 87242], p) == [33867619,
203
+ 459741727, 79180879, 831885249, 381344700, 369993322]
204
+ assert convolution_ntt([121913, 3171831, 31888131, 12], [17882, 21292, 29921, 312], q) == \
205
+ [8158, 3065, 3682, 7090, 1239, 2232, 3744]
206
+
207
+ assert convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], p) == \
208
+ convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], q)
209
+ assert convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], p) == \
210
+ convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], q)
211
+
212
+ raises(ValueError, lambda: convolution_ntt([2, 3], [4, 5], r))
213
+ raises(ValueError, lambda: convolution_ntt([x, y], [y, x], q))
214
+ raises(TypeError, lambda: convolution_ntt(x, y, p))
215
+
216
+
217
+ def test_convolution_fwht():
218
+ assert convolution_fwht([], []) == []
219
+ assert convolution_fwht([], [1]) == []
220
+ assert convolution_fwht([1, 2, 3], [4, 5, 6]) == [32, 13, 18, 27]
221
+
222
+ assert convolution_fwht([Rational(5, 7), Rational(6, 8), Rational(7, 3)], [2, 4, Rational(6, 7)]) == \
223
+ [Rational(45, 7), Rational(61, 14), Rational(776, 147), Rational(419, 42)]
224
+
225
+ a = [1, Rational(5, 3), sqrt(3), Rational(7, 5), 4 + 5*I]
226
+ b = [94, 51, 53, 45, 31, 27, 13]
227
+ c = [3 + 4*I, 5 + 7*I, 3, Rational(7, 6), 8]
228
+
229
+ assert convolution_fwht(a, b) == [53*sqrt(3) + 366 + 155*I,
230
+ 45*sqrt(3) + Rational(5848, 15) + 135*I,
231
+ 94*sqrt(3) + Rational(1257, 5) + 65*I,
232
+ 51*sqrt(3) + Rational(3974, 15),
233
+ 13*sqrt(3) + 452 + 470*I,
234
+ Rational(4513, 15) + 255*I,
235
+ 31*sqrt(3) + Rational(1314, 5) + 265*I,
236
+ 27*sqrt(3) + Rational(3676, 15) + 225*I]
237
+
238
+ assert convolution_fwht(b, c) == [Rational(1993, 2) + 733*I, Rational(6215, 6) + 862*I,
239
+ Rational(1659, 2) + 527*I, Rational(1988, 3) + 551*I, 1019 + 313*I, Rational(3955, 6) + 325*I,
240
+ Rational(1175, 2) + 52*I, Rational(3253, 6) + 91*I]
241
+
242
+ assert convolution_fwht(a[3:], c) == [Rational(-54, 5) + I*293/5, -1 + I*204/5,
243
+ Rational(133, 15) + I*35/6, Rational(409, 30) + 15*I, Rational(56, 5), 32 + 40*I, 0, 0]
244
+
245
+ u, v, w, x, y, z = symbols('u v w x y z')
246
+
247
+ assert convolution_fwht([u, v], [x, y]) == [u*x + v*y, u*y + v*x]
248
+
249
+ assert convolution_fwht([u, v, w], [x, y]) == \
250
+ [u*x + v*y, u*y + v*x, w*x, w*y]
251
+
252
+ assert convolution_fwht([u, v, w], [x, y, z]) == \
253
+ [u*x + v*y + w*z, u*y + v*x, u*z + w*x, v*z + w*y]
254
+
255
+ raises(TypeError, lambda: convolution_fwht(x, y))
256
+ raises(TypeError, lambda: convolution_fwht(x*y, u + v))
257
+
258
+
259
+ def test_convolution_subset():
260
+ assert convolution_subset([], []) == []
261
+ assert convolution_subset([], [Rational(1, 3)]) == []
262
+ assert convolution_subset([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
263
+
264
+ a = [1, Rational(5, 3), sqrt(3), 4 + 5*I]
265
+ b = [64, 71, 55, 47, 33, 29, 15]
266
+ c = [3 + I*2/3, 5 + 7*I, 7, Rational(7, 5), 9]
267
+
268
+ assert convolution_subset(a, b) == [64, Rational(533, 3), 55 + 64*sqrt(3),
269
+ 71*sqrt(3) + Rational(1184, 3) + 320*I, 33, 84,
270
+ 15 + 33*sqrt(3), 29*sqrt(3) + 157 + 165*I]
271
+
272
+ assert convolution_subset(b, c) == [192 + I*128/3, 533 + I*1486/3,
273
+ 613 + I*110/3, Rational(5013, 5) + I*1249/3,
274
+ 675 + 22*I, 891 + I*751/3,
275
+ 771 + 10*I, Rational(3736, 5) + 105*I]
276
+
277
+ assert convolution_subset(a, c) == convolution_subset(c, a)
278
+ assert convolution_subset(a[:2], b) == \
279
+ [64, Rational(533, 3), 55, Rational(416, 3), 33, 84, 15, 25]
280
+
281
+ assert convolution_subset(a[:2], c) == \
282
+ [3 + I*2/3, 10 + I*73/9, 7, Rational(196, 15), 9, 15, 0, 0]
283
+
284
+ u, v, w, x, y, z = symbols('u v w x y z')
285
+
286
+ assert convolution_subset([u, v, w], [x, y]) == [u*x, u*y + v*x, w*x, w*y]
287
+ assert convolution_subset([u, v, w, x], [y, z]) == \
288
+ [u*y, u*z + v*y, w*y, w*z + x*y]
289
+
290
+ assert convolution_subset([u, v], [x, y, z]) == \
291
+ convolution_subset([x, y, z], [u, v])
292
+
293
+ raises(TypeError, lambda: convolution_subset(x, z))
294
+ raises(TypeError, lambda: convolution_subset(Rational(7, 3), u))
295
+
296
+
297
+ def test_covering_product():
298
+ assert covering_product([], []) == []
299
+ assert covering_product([], [Rational(1, 3)]) == []
300
+ assert covering_product([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
301
+
302
+ a = [1, Rational(5, 8), sqrt(7), 4 + 9*I]
303
+ b = [66, 81, 95, 49, 37, 89, 17]
304
+ c = [3 + I*2/3, 51 + 72*I, 7, Rational(7, 15), 91]
305
+
306
+ assert covering_product(a, b) == [66, Rational(1383, 8), 95 + 161*sqrt(7),
307
+ 130*sqrt(7) + 1303 + 2619*I, 37,
308
+ Rational(671, 4), 17 + 54*sqrt(7),
309
+ 89*sqrt(7) + Rational(4661, 8) + 1287*I]
310
+
311
+ assert covering_product(b, c) == [198 + 44*I, 7740 + 10638*I,
312
+ 1412 + I*190/3, Rational(42684, 5) + I*31202/3,
313
+ 9484 + I*74/3, 22163 + I*27394/3,
314
+ 10621 + I*34/3, Rational(90236, 15) + 1224*I]
315
+
316
+ assert covering_product(a, c) == covering_product(c, a)
317
+ assert covering_product(b, c[:-1]) == [198 + 44*I, 7740 + 10638*I,
318
+ 1412 + I*190/3, Rational(42684, 5) + I*31202/3,
319
+ 111 + I*74/3, 6693 + I*27394/3,
320
+ 429 + I*34/3, Rational(23351, 15) + 1224*I]
321
+
322
+ assert covering_product(a, c[:-1]) == [3 + I*2/3,
323
+ Rational(339, 4) + I*1409/12, 7 + 10*sqrt(7) + 2*sqrt(7)*I/3,
324
+ -403 + 772*sqrt(7)/15 + 72*sqrt(7)*I + I*12658/15]
325
+
326
+ u, v, w, x, y, z = symbols('u v w x y z')
327
+
328
+ assert covering_product([u, v, w], [x, y]) == \
329
+ [u*x, u*y + v*x + v*y, w*x, w*y]
330
+
331
+ assert covering_product([u, v, w, x], [y, z]) == \
332
+ [u*y, u*z + v*y + v*z, w*y, w*z + x*y + x*z]
333
+
334
+ assert covering_product([u, v], [x, y, z]) == \
335
+ covering_product([x, y, z], [u, v])
336
+
337
+ raises(TypeError, lambda: covering_product(x, z))
338
+ raises(TypeError, lambda: covering_product(Rational(7, 3), u))
339
+
340
+
341
+ def test_intersecting_product():
342
+ assert intersecting_product([], []) == []
343
+ assert intersecting_product([], [Rational(1, 3)]) == []
344
+ assert intersecting_product([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
345
+
346
+ a = [1, sqrt(5), Rational(3, 8) + 5*I, 4 + 7*I]
347
+ b = [67, 51, 65, 48, 36, 79, 27]
348
+ c = [3 + I*2/5, 5 + 9*I, 7, Rational(7, 19), 13]
349
+
350
+ assert intersecting_product(a, b) == [195*sqrt(5) + Rational(6979, 8) + 1886*I,
351
+ 178*sqrt(5) + 520 + 910*I, Rational(841, 2) + 1344*I,
352
+ 192 + 336*I, 0, 0, 0, 0]
353
+
354
+ assert intersecting_product(b, c) == [Rational(128553, 19) + I*9521/5,
355
+ Rational(17820, 19) + 1602*I, Rational(19264, 19), Rational(336, 19), 1846, 0, 0, 0]
356
+
357
+ assert intersecting_product(a, c) == intersecting_product(c, a)
358
+ assert intersecting_product(b[1:], c[:-1]) == [Rational(64788, 19) + I*8622/5,
359
+ Rational(12804, 19) + 1152*I, Rational(11508, 19), Rational(252, 19), 0, 0, 0, 0]
360
+
361
+ assert intersecting_product(a, c[:-2]) == \
362
+ [Rational(-99, 5) + 10*sqrt(5) + 2*sqrt(5)*I/5 + I*3021/40,
363
+ -43 + 5*sqrt(5) + 9*sqrt(5)*I + 71*I, Rational(245, 8) + 84*I, 0]
364
+
365
+ u, v, w, x, y, z = symbols('u v w x y z')
366
+
367
+ assert intersecting_product([u, v, w], [x, y]) == \
368
+ [u*x + u*y + v*x + w*x + w*y, v*y, 0, 0]
369
+
370
+ assert intersecting_product([u, v, w, x], [y, z]) == \
371
+ [u*y + u*z + v*y + w*y + w*z + x*y, v*z + x*z, 0, 0]
372
+
373
+ assert intersecting_product([u, v], [x, y, z]) == \
374
+ intersecting_product([x, y, z], [u, v])
375
+
376
+ raises(TypeError, lambda: intersecting_product(x, z))
377
+ raises(TypeError, lambda: intersecting_product(u, Rational(8, 3)))
378
+
379
+
380
+ def test_convolution_int():
381
+ assert convolution_int([1], [1]) == [1]
382
+ assert convolution_int([1, 1], [0]) == [0]
383
+ assert convolution_int([1, 2, 3], [4, 5, 6]) == [4, 13, 28, 27, 18]
384
+ assert convolution_int([1], [5, 6, 7]) == [5, 6, 7]
385
+ assert convolution_int([1, 3], [5, 6, 7]) == [5, 21, 25, 21]
386
+ assert convolution_int([10, -5, 1, 3], [-5, 6, 7]) == [-50, 85, 35, -44, 25, 21]
387
+ assert convolution_int([0, 1, 0, -1], [1, 0, -1, 0]) == [0, 1, 0, -2, 0, 1]
388
+ assert convolution_int(
389
+ [-341, -5, 1, 3, -71, -99, 43, 87],
390
+ [5, 6, 7, 12, 345, 21, -78, -7, -89]
391
+ ) == [-1705, -2071, -2412, -4106, -118035, -9774, 25998, 2981, 5509,
392
+ -34317, 19228, 38870, 5485, 1724, -4436, -7743]
.venv/lib/python3.13/site-packages/sympy/discrete/tests/test_recurrences.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.numbers import Rational
2
+ from sympy.functions.combinatorial.numbers import fibonacci
3
+ from sympy.core import S, symbols
4
+ from sympy.testing.pytest import raises
5
+ from sympy.discrete.recurrences import linrec
6
+
7
+ def test_linrec():
8
+ assert linrec(coeffs=[1, 1], init=[1, 1], n=20) == 10946
9
+ assert linrec(coeffs=[1, 2, 3, 4, 5], init=[1, 1, 0, 2], n=10) == 1040
10
+ assert linrec(coeffs=[0, 0, 11, 13], init=[23, 27], n=25) == 59628567384
11
+ assert linrec(coeffs=[0, 0, 1, 1, 2], init=[1, 5, 3], n=15) == 165
12
+ assert linrec(coeffs=[11, 13, 15, 17], init=[1, 2, 3, 4], n=70) == \
13
+ 56889923441670659718376223533331214868804815612050381493741233489928913241
14
+ assert linrec(coeffs=[0]*55 + [1, 1, 2, 3], init=[0]*50 + [1, 2, 3], n=4000) == \
15
+ 702633573874937994980598979769135096432444135301118916539
16
+
17
+ assert linrec(coeffs=[11, 13, 15, 17], init=[1, 2, 3, 4], n=10**4)
18
+ assert linrec(coeffs=[11, 13, 15, 17], init=[1, 2, 3, 4], n=10**5)
19
+
20
+ assert all(linrec(coeffs=[1, 1], init=[0, 1], n=n) == fibonacci(n)
21
+ for n in range(95, 115))
22
+
23
+ assert all(linrec(coeffs=[1, 1], init=[1, 1], n=n) == fibonacci(n + 1)
24
+ for n in range(595, 615))
25
+
26
+ a = [S.Half, Rational(3, 4), Rational(5, 6), 7, Rational(8, 9), Rational(3, 5)]
27
+ b = [1, 2, 8, Rational(5, 7), Rational(3, 7), Rational(2, 9), 6]
28
+ x, y, z = symbols('x y z')
29
+
30
+ assert linrec(coeffs=a[:5], init=b[:4], n=80) == \
31
+ Rational(1726244235456268979436592226626304376013002142588105090705187189,
32
+ 1960143456748895967474334873705475211264)
33
+
34
+ assert linrec(coeffs=a[:4], init=b[:4], n=50) == \
35
+ Rational(368949940033050147080268092104304441, 504857282956046106624)
36
+
37
+ assert linrec(coeffs=a[3:], init=b[:3], n=35) == \
38
+ Rational(97409272177295731943657945116791049305244422833125109,
39
+ 814315512679031689453125)
40
+
41
+ assert linrec(coeffs=[0]*60 + [Rational(2, 3), Rational(4, 5)], init=b, n=3000) == \
42
+ Rational(26777668739896791448594650497024, 48084516708184142230517578125)
43
+
44
+ raises(TypeError, lambda: linrec(coeffs=[11, 13, 15, 17], init=[1, 2, 3, 4, 5], n=1))
45
+ raises(TypeError, lambda: linrec(coeffs=a[:4], init=b[:5], n=10000))
46
+ raises(ValueError, lambda: linrec(coeffs=a[:4], init=b[:4], n=-10000))
47
+ raises(TypeError, lambda: linrec(x, b, n=10000))
48
+ raises(TypeError, lambda: linrec(a, y, n=10000))
49
+
50
+ assert linrec(coeffs=[x, y, z], init=[1, 1, 1], n=4) == \
51
+ x**2 + x*y + x*z + y + z
52
+ assert linrec(coeffs=[1, 2, 1], init=[x, y, z], n=20) == \
53
+ 269542*x + 664575*y + 578949*z
54
+ assert linrec(coeffs=[0, 3, 1, 2], init=[x, y], n=30) == \
55
+ 58516436*x + 56372788*y
56
+ assert linrec(coeffs=[0]*50 + [1, 2, 3], init=[x, y, z], n=1000) == \
57
+ 11477135884896*x + 25999077948732*y + 41975630244216*z
58
+ assert linrec(coeffs=[], init=[1, 1], n=20) == 0
59
+ assert linrec(coeffs=[x, y, z], init=[1, 2, 3], n=2) == 3
.venv/lib/python3.13/site-packages/sympy/discrete/tests/test_transforms.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.functions.elementary.miscellaneous import sqrt
2
+ from sympy.core import S, Symbol, symbols, I, Rational
3
+ from sympy.discrete import (fft, ifft, ntt, intt, fwht, ifwht,
4
+ mobius_transform, inverse_mobius_transform)
5
+ from sympy.testing.pytest import raises
6
+
7
+
8
+ def test_fft_ifft():
9
+ assert all(tf(ls) == ls for tf in (fft, ifft)
10
+ for ls in ([], [Rational(5, 3)]))
11
+
12
+ ls = list(range(6))
13
+ fls = [15, -7*sqrt(2)/2 - 4 - sqrt(2)*I/2 + 2*I, 2 + 3*I,
14
+ -4 + 7*sqrt(2)/2 - 2*I - sqrt(2)*I/2, -3,
15
+ -4 + 7*sqrt(2)/2 + sqrt(2)*I/2 + 2*I,
16
+ 2 - 3*I, -7*sqrt(2)/2 - 4 - 2*I + sqrt(2)*I/2]
17
+
18
+ assert fft(ls) == fls
19
+ assert ifft(fls) == ls + [S.Zero]*2
20
+
21
+ ls = [1 + 2*I, 3 + 4*I, 5 + 6*I]
22
+ ifls = [Rational(9, 4) + 3*I, I*Rational(-7, 4), Rational(3, 4) + I, -2 - I/4]
23
+
24
+ assert ifft(ls) == ifls
25
+ assert fft(ifls) == ls + [S.Zero]
26
+
27
+ x = Symbol('x', real=True)
28
+ raises(TypeError, lambda: fft(x))
29
+ raises(ValueError, lambda: ifft([x, 2*x, 3*x**2, 4*x**3]))
30
+
31
+
32
+ def test_ntt_intt():
33
+ # prime moduli of the form (m*2**k + 1), sequence length
34
+ # should be a divisor of 2**k
35
+ p = 7*17*2**23 + 1
36
+ q = 2*500000003 + 1 # only for sequences of length 1 or 2
37
+ r = 2*3*5*7 # composite modulus
38
+
39
+ assert all(tf(ls, p) == ls for tf in (ntt, intt)
40
+ for ls in ([], [5]))
41
+
42
+ ls = list(range(6))
43
+ nls = [15, 801133602, 738493201, 334102277, 998244350, 849020224,
44
+ 259751156, 12232587]
45
+
46
+ assert ntt(ls, p) == nls
47
+ assert intt(nls, p) == ls + [0]*2
48
+
49
+ ls = [1 + 2*I, 3 + 4*I, 5 + 6*I]
50
+ x = Symbol('x', integer=True)
51
+
52
+ raises(TypeError, lambda: ntt(x, p))
53
+ raises(ValueError, lambda: intt([x, 2*x, 3*x**2, 4*x**3], p))
54
+ raises(ValueError, lambda: intt(ls, p))
55
+ raises(ValueError, lambda: ntt([1.2, 2.1, 3.5], p))
56
+ raises(ValueError, lambda: ntt([3, 5, 6], q))
57
+ raises(ValueError, lambda: ntt([4, 5, 7], r))
58
+ raises(ValueError, lambda: ntt([1.0, 2.0, 3.0], p))
59
+
60
+
61
+ def test_fwht_ifwht():
62
+ assert all(tf(ls) == ls for tf in (fwht, ifwht) \
63
+ for ls in ([], [Rational(7, 4)]))
64
+
65
+ ls = [213, 321, 43235, 5325, 312, 53]
66
+ fls = [49459, 38061, -47661, -37759, 48729, 37543, -48391, -38277]
67
+
68
+ assert fwht(ls) == fls
69
+ assert ifwht(fls) == ls + [S.Zero]*2
70
+
71
+ ls = [S.Half + 2*I, Rational(3, 7) + 4*I, Rational(5, 6) + 6*I, Rational(7, 3), Rational(9, 4)]
72
+ ifls = [Rational(533, 672) + I*3/2, Rational(23, 224) + I/2, Rational(1, 672), Rational(107, 224) - I,
73
+ Rational(155, 672) + I*3/2, Rational(-103, 224) + I/2, Rational(-377, 672), Rational(-19, 224) - I]
74
+
75
+ assert ifwht(ls) == ifls
76
+ assert fwht(ifls) == ls + [S.Zero]*3
77
+
78
+ x, y = symbols('x y')
79
+
80
+ raises(TypeError, lambda: fwht(x))
81
+
82
+ ls = [x, 2*x, 3*x**2, 4*x**3]
83
+ ifls = [x**3 + 3*x**2/4 + x*Rational(3, 4),
84
+ -x**3 + 3*x**2/4 - x/4,
85
+ -x**3 - 3*x**2/4 + x*Rational(3, 4),
86
+ x**3 - 3*x**2/4 - x/4]
87
+
88
+ assert ifwht(ls) == ifls
89
+ assert fwht(ifls) == ls
90
+
91
+ ls = [x, y, x**2, y**2, x*y]
92
+ fls = [x**2 + x*y + x + y**2 + y,
93
+ x**2 + x*y + x - y**2 - y,
94
+ -x**2 + x*y + x - y**2 + y,
95
+ -x**2 + x*y + x + y**2 - y,
96
+ x**2 - x*y + x + y**2 + y,
97
+ x**2 - x*y + x - y**2 - y,
98
+ -x**2 - x*y + x - y**2 + y,
99
+ -x**2 - x*y + x + y**2 - y]
100
+
101
+ assert fwht(ls) == fls
102
+ assert ifwht(fls) == ls + [S.Zero]*3
103
+
104
+ ls = list(range(6))
105
+
106
+ assert fwht(ls) == [x*8 for x in ifwht(ls)]
107
+
108
+
109
+ def test_mobius_transform():
110
+ assert all(tf(ls, subset=subset) == ls
111
+ for ls in ([], [Rational(7, 4)]) for subset in (True, False)
112
+ for tf in (mobius_transform, inverse_mobius_transform))
113
+
114
+ w, x, y, z = symbols('w x y z')
115
+
116
+ assert mobius_transform([x, y]) == [x, x + y]
117
+ assert inverse_mobius_transform([x, x + y]) == [x, y]
118
+ assert mobius_transform([x, y], subset=False) == [x + y, y]
119
+ assert inverse_mobius_transform([x + y, y], subset=False) == [x, y]
120
+
121
+ assert mobius_transform([w, x, y, z]) == [w, w + x, w + y, w + x + y + z]
122
+ assert inverse_mobius_transform([w, w + x, w + y, w + x + y + z]) == \
123
+ [w, x, y, z]
124
+ assert mobius_transform([w, x, y, z], subset=False) == \
125
+ [w + x + y + z, x + z, y + z, z]
126
+ assert inverse_mobius_transform([w + x + y + z, x + z, y + z, z], subset=False) == \
127
+ [w, x, y, z]
128
+
129
+ ls = [Rational(2, 3), Rational(6, 7), Rational(5, 8), 9, Rational(5, 3) + 7*I]
130
+ mls = [Rational(2, 3), Rational(32, 21), Rational(31, 24), Rational(1873, 168),
131
+ Rational(7, 3) + 7*I, Rational(67, 21) + 7*I, Rational(71, 24) + 7*I,
132
+ Rational(2153, 168) + 7*I]
133
+
134
+ assert mobius_transform(ls) == mls
135
+ assert inverse_mobius_transform(mls) == ls + [S.Zero]*3
136
+
137
+ mls = [Rational(2153, 168) + 7*I, Rational(69, 7), Rational(77, 8), 9, Rational(5, 3) + 7*I, 0, 0, 0]
138
+
139
+ assert mobius_transform(ls, subset=False) == mls
140
+ assert inverse_mobius_transform(mls, subset=False) == ls + [S.Zero]*3
141
+
142
+ ls = ls[:-1]
143
+ mls = [Rational(2, 3), Rational(32, 21), Rational(31, 24), Rational(1873, 168)]
144
+
145
+ assert mobius_transform(ls) == mls
146
+ assert inverse_mobius_transform(mls) == ls
147
+
148
+ mls = [Rational(1873, 168), Rational(69, 7), Rational(77, 8), 9]
149
+
150
+ assert mobius_transform(ls, subset=False) == mls
151
+ assert inverse_mobius_transform(mls, subset=False) == ls
152
+
153
+ raises(TypeError, lambda: mobius_transform(x, subset=True))
154
+ raises(TypeError, lambda: inverse_mobius_transform(y, subset=False))
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_cartan_type.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.cartan_type import CartanType, Standard_Cartan
2
+
3
+ def test_Standard_Cartan():
4
+ c = CartanType("A4")
5
+ assert c.rank() == 4
6
+ assert c.series == "A"
7
+ m = Standard_Cartan("A", 2)
8
+ assert m.rank() == 2
9
+ assert m.series == "A"
10
+ b = CartanType("B12")
11
+ assert b.rank() == 12
12
+ assert b.series == "B"
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_dynkin_diagram.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.dynkin_diagram import DynkinDiagram
2
+
3
+ def test_DynkinDiagram():
4
+ c = DynkinDiagram("A3")
5
+ diag = "0---0---0\n1 2 3"
6
+ assert c == diag
7
+ ct = DynkinDiagram(["B", 3])
8
+ diag2 = "0---0=>=0\n1 2 3"
9
+ assert ct == diag2
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_root_system.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.root_system import RootSystem
2
+ from sympy.liealgebras.type_a import TypeA
3
+ from sympy.matrices import Matrix
4
+
5
+ def test_root_system():
6
+ c = RootSystem("A3")
7
+ assert c.cartan_type == TypeA(3)
8
+ assert c.simple_roots() == {1: [1, -1, 0, 0], 2: [0, 1, -1, 0], 3: [0, 0, 1, -1]}
9
+ assert c.root_space() == "alpha[1] + alpha[2] + alpha[3]"
10
+ assert c.cartan_matrix() == Matrix([[ 2, -1, 0], [-1, 2, -1], [ 0, -1, 2]])
11
+ assert c.dynkin_diagram() == "0---0---0\n1 2 3"
12
+ assert c.add_simple_roots(1, 2) == [1, 0, -1, 0]
13
+ assert c.all_roots() == {1: [1, -1, 0, 0], 2: [1, 0, -1, 0],
14
+ 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], 5: [0, 1, 0, -1],
15
+ 6: [0, 0, 1, -1], 7: [-1, 1, 0, 0], 8: [-1, 0, 1, 0],
16
+ 9: [-1, 0, 0, 1], 10: [0, -1, 1, 0],
17
+ 11: [0, -1, 0, 1], 12: [0, 0, -1, 1]}
18
+ assert c.add_as_roots([1, 0, -1, 0], [0, 0, 1, -1]) == [1, 0, 0, -1]
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_A.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.cartan_type import CartanType
2
+ from sympy.matrices import Matrix
3
+
4
+ def test_type_A():
5
+ c = CartanType("A3")
6
+ m = Matrix(3, 3, [2, -1, 0, -1, 2, -1, 0, -1, 2])
7
+ assert m == c.cartan_matrix()
8
+ assert c.basis() == 8
9
+ assert c.roots() == 12
10
+ assert c.dimension() == 4
11
+ assert c.simple_root(1) == [1, -1, 0, 0]
12
+ assert c.highest_root() == [1, 0, 0, -1]
13
+ assert c.lie_algebra() == "su(4)"
14
+ diag = "0---0---0\n1 2 3"
15
+ assert c.dynkin_diagram() == diag
16
+ assert c.positive_roots() == {1: [1, -1, 0, 0], 2: [1, 0, -1, 0],
17
+ 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]}
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_D.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.cartan_type import CartanType
2
+ from sympy.matrices import Matrix
3
+
4
+
5
+
6
+ def test_type_D():
7
+ c = CartanType("D4")
8
+ m = Matrix(4, 4, [2, -1, 0, 0, -1, 2, -1, -1, 0, -1, 2, 0, 0, -1, 0, 2])
9
+ assert c.cartan_matrix() == m
10
+ assert c.basis() == 6
11
+ assert c.lie_algebra() == "so(8)"
12
+ assert c.roots() == 24
13
+ assert c.simple_root(3) == [0, 0, 1, -1]
14
+ diag = " 3\n 0\n |\n |\n0---0---0\n1 2 4"
15
+ assert diag == c.dynkin_diagram()
16
+ assert c.positive_roots() == {1: [1, -1, 0, 0], 2: [1, 1, 0, 0],
17
+ 3: [1, 0, -1, 0], 4: [1, 0, 1, 0], 5: [1, 0, 0, -1], 6: [1, 0, 0, 1],
18
+ 7: [0, 1, -1, 0], 8: [0, 1, 1, 0], 9: [0, 1, 0, -1], 10: [0, 1, 0, 1],
19
+ 11: [0, 0, 1, -1], 12: [0, 0, 1, 1]}
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_type_E.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.cartan_type import CartanType
2
+ from sympy.matrices import Matrix
3
+ from sympy.core.backend import Rational
4
+
5
+ def test_type_E():
6
+ c = CartanType("E6")
7
+ m = Matrix(6, 6, [2, 0, -1, 0, 0, 0, 0, 2, 0, -1, 0, 0,
8
+ -1, 0, 2, -1, 0, 0, 0, -1, -1, 2, -1, 0, 0, 0, 0,
9
+ -1, 2, -1, 0, 0, 0, 0, -1, 2])
10
+ assert c.cartan_matrix() == m
11
+ assert c.dimension() == 8
12
+ assert c.simple_root(6) == [0, 0, 0, -1, 1, 0, 0, 0]
13
+ assert c.roots() == 72
14
+ assert c.basis() == 78
15
+ diag = " "*8 + "2\n" + " "*8 + "0\n" + " "*8 + "|\n" + " "*8 + "|\n"
16
+ diag += "---".join("0" for i in range(1, 6))+"\n"
17
+ diag += "1 " + " ".join(str(i) for i in range(3, 7))
18
+ assert c.dynkin_diagram() == diag
19
+ posroots = c.positive_roots()
20
+ assert posroots[8] == [1, 0, 0, 0, 1, 0, 0, 0]
21
+ assert posroots[21] == [Rational(1,2),Rational(1,2),Rational(1,2),Rational(1,2),
22
+ Rational(1,2),Rational(-1,2),Rational(-1,2),Rational(1,2)]
.venv/lib/python3.13/site-packages/sympy/liealgebras/tests/test_weyl_group.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.liealgebras.weyl_group import WeylGroup
2
+ from sympy.matrices import Matrix
3
+
4
+ def test_weyl_group():
5
+ c = WeylGroup("A3")
6
+ assert c.matrix_form('r1*r2') == Matrix([[0, 0, 1, 0], [1, 0, 0, 0],
7
+ [0, 1, 0, 0], [0, 0, 0, 1]])
8
+ assert c.generators() == ['r1', 'r2', 'r3']
9
+ assert c.group_order() == 24.0
10
+ assert c.group_name() == "S4: the symmetric group acting on 4 elements."
11
+ assert c.coxeter_diagram() == "0---0---0\n1 2 3"
12
+ assert c.element_order('r1*r2*r3') == 4
13
+ assert c.element_order('r1*r3*r2*r3') == 3
14
+ d = WeylGroup("B5")
15
+ assert d.group_order() == 3840
16
+ assert d.element_order('r1*r2*r4*r5') == 12
17
+ assert d.matrix_form('r2*r3') == Matrix([[0, 0, 1, 0, 0], [1, 0, 0, 0, 0],
18
+ [0, 1, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]])
19
+ assert d.element_order('r1*r2*r1*r3*r5') == 6
20
+ e = WeylGroup("D5")
21
+ assert e.element_order('r2*r3*r5') == 4
22
+ assert e.matrix_form('r2*r3*r5') == Matrix([[1, 0, 0, 0, 0], [0, 0, 0, 0, -1],
23
+ [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, -1, 0]])
24
+ f = WeylGroup("G2")
25
+ assert f.element_order('r1*r2*r1*r2') == 3
26
+ assert f.element_order('r2*r1*r1*r2') == 1
27
+
28
+ assert f.matrix_form('r1*r2*r1*r2') == Matrix([[0, 1, 0], [0, 0, 1], [1, 0, 0]])
29
+ g = WeylGroup("F4")
30
+ assert g.matrix_form('r2*r3') == Matrix([[1, 0, 0, 0], [0, 1, 0, 0],
31
+ [0, 0, 0, -1], [0, 0, 1, 0]])
32
+
33
+ assert g.element_order('r2*r3') == 4
34
+ h = WeylGroup("E6")
35
+ assert h.group_order() == 51840
.venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/test_conflict.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.multipledispatch.conflict import (supercedes, ordering, ambiguities,
2
+ ambiguous, super_signature, consistent)
3
+
4
+
5
+ class A: pass
6
+ class B(A): pass
7
+ class C: pass
8
+
9
+
10
+ def test_supercedes():
11
+ assert supercedes([B], [A])
12
+ assert supercedes([B, A], [A, A])
13
+ assert not supercedes([B, A], [A, B])
14
+ assert not supercedes([A], [B])
15
+
16
+
17
+ def test_consistent():
18
+ assert consistent([A], [A])
19
+ assert consistent([B], [B])
20
+ assert not consistent([A], [C])
21
+ assert consistent([A, B], [A, B])
22
+ assert consistent([B, A], [A, B])
23
+ assert not consistent([B, A], [B])
24
+ assert not consistent([B, A], [B, C])
25
+
26
+
27
+ def test_super_signature():
28
+ assert super_signature([[A]]) == [A]
29
+ assert super_signature([[A], [B]]) == [B]
30
+ assert super_signature([[A, B], [B, A]]) == [B, B]
31
+ assert super_signature([[A, A, B], [A, B, A], [B, A, A]]) == [B, B, B]
32
+
33
+
34
+ def test_ambiguous():
35
+ assert not ambiguous([A], [A])
36
+ assert not ambiguous([A], [B])
37
+ assert not ambiguous([B], [B])
38
+ assert not ambiguous([A, B], [B, B])
39
+ assert ambiguous([A, B], [B, A])
40
+
41
+
42
+ def test_ambiguities():
43
+ signatures = [[A], [B], [A, B], [B, A], [A, C]]
44
+ expected = {((A, B), (B, A))}
45
+ result = ambiguities(signatures)
46
+ assert set(map(frozenset, expected)) == set(map(frozenset, result))
47
+
48
+ signatures = [[A], [B], [A, B], [B, A], [A, C], [B, B]]
49
+ expected = set()
50
+ result = ambiguities(signatures)
51
+ assert set(map(frozenset, expected)) == set(map(frozenset, result))
52
+
53
+
54
+ def test_ordering():
55
+ signatures = [[A, A], [A, B], [B, A], [B, B], [A, C]]
56
+ ord = ordering(signatures)
57
+ assert ord[0] == (B, B) or ord[0] == (A, C)
58
+ assert ord[-1] == (A, A) or ord[-1] == (A, C)
59
+
60
+
61
+ def test_type_mro():
62
+ assert super_signature([[object], [type]]) == [type]
.venv/lib/python3.13/site-packages/sympy/multipledispatch/tests/test_core.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Any
3
+
4
+ from sympy.multipledispatch import dispatch
5
+ from sympy.multipledispatch.conflict import AmbiguityWarning
6
+ from sympy.testing.pytest import raises, warns
7
+ from functools import partial
8
+
9
+ test_namespace: dict[str, Any] = {}
10
+
11
+ orig_dispatch = dispatch
12
+ dispatch = partial(dispatch, namespace=test_namespace)
13
+
14
+
15
+ def test_singledispatch():
16
+ @dispatch(int)
17
+ def f(x): # noqa:F811
18
+ return x + 1
19
+
20
+ @dispatch(int)
21
+ def g(x): # noqa:F811
22
+ return x + 2
23
+
24
+ @dispatch(float) # noqa:F811
25
+ def f(x): # noqa:F811
26
+ return x - 1
27
+
28
+ assert f(1) == 2
29
+ assert g(1) == 3
30
+ assert f(1.0) == 0
31
+
32
+ assert raises(NotImplementedError, lambda: f('hello'))
33
+
34
+
35
+ def test_multipledispatch():
36
+ @dispatch(int, int)
37
+ def f(x, y): # noqa:F811
38
+ return x + y
39
+
40
+ @dispatch(float, float) # noqa:F811
41
+ def f(x, y): # noqa:F811
42
+ return x - y
43
+
44
+ assert f(1, 2) == 3
45
+ assert f(1.0, 2.0) == -1.0
46
+
47
+
48
+ class A: pass
49
+ class B: pass
50
+ class C(A): pass
51
+ class D(C): pass
52
+ class E(C): pass
53
+
54
+
55
+ def test_inheritance():
56
+ @dispatch(A)
57
+ def f(x): # noqa:F811
58
+ return 'a'
59
+
60
+ @dispatch(B) # noqa:F811
61
+ def f(x): # noqa:F811
62
+ return 'b'
63
+
64
+ assert f(A()) == 'a'
65
+ assert f(B()) == 'b'
66
+ assert f(C()) == 'a'
67
+
68
+
69
+ def test_inheritance_and_multiple_dispatch():
70
+ @dispatch(A, A)
71
+ def f(x, y): # noqa:F811
72
+ return type(x), type(y)
73
+
74
+ @dispatch(A, B) # noqa:F811
75
+ def f(x, y): # noqa:F811
76
+ return 0
77
+
78
+ assert f(A(), A()) == (A, A)
79
+ assert f(A(), C()) == (A, C)
80
+ assert f(A(), B()) == 0
81
+ assert f(C(), B()) == 0
82
+ assert raises(NotImplementedError, lambda: f(B(), B()))
83
+
84
+
85
+ def test_competing_solutions():
86
+ @dispatch(A)
87
+ def h(x): # noqa:F811
88
+ return 1
89
+
90
+ @dispatch(C) # noqa:F811
91
+ def h(x): # noqa:F811
92
+ return 2
93
+
94
+ assert h(D()) == 2
95
+
96
+
97
+ def test_competing_multiple():
98
+ @dispatch(A, B)
99
+ def h(x, y): # noqa:F811
100
+ return 1
101
+
102
+ @dispatch(C, B) # noqa:F811
103
+ def h(x, y): # noqa:F811
104
+ return 2
105
+
106
+ assert h(D(), B()) == 2
107
+
108
+
109
+ def test_competing_ambiguous():
110
+ test_namespace = {}
111
+ dispatch = partial(orig_dispatch, namespace=test_namespace)
112
+
113
+ @dispatch(A, C)
114
+ def f(x, y): # noqa:F811
115
+ return 2
116
+
117
+ with warns(AmbiguityWarning, test_stacklevel=False):
118
+ @dispatch(C, A) # noqa:F811
119
+ def f(x, y): # noqa:F811
120
+ return 2
121
+
122
+ assert f(A(), C()) == f(C(), A()) == 2
123
+ # assert raises(Warning, lambda : f(C(), C()))
124
+
125
+
126
+ def test_caching_correct_behavior():
127
+ @dispatch(A)
128
+ def f(x): # noqa:F811
129
+ return 1
130
+
131
+ assert f(C()) == 1
132
+
133
+ @dispatch(C)
134
+ def f(x): # noqa:F811
135
+ return 2
136
+
137
+ assert f(C()) == 2
138
+
139
+
140
+ def test_union_types():
141
+ @dispatch((A, C))
142
+ def f(x): # noqa:F811
143
+ return 1
144
+
145
+ assert f(A()) == 1
146
+ assert f(C()) == 1
147
+
148
+
149
+ def test_namespaces():
150
+ ns1 = {}
151
+ ns2 = {}
152
+
153
+ def foo(x):
154
+ return 1
155
+ foo1 = orig_dispatch(int, namespace=ns1)(foo)
156
+
157
+ def foo(x):
158
+ return 2
159
+ foo2 = orig_dispatch(int, namespace=ns2)(foo)
160
+
161
+ assert foo1(0) == 1
162
+ assert foo2(0) == 2
163
+
164
+
165
+ """
166
+ Fails
167
+ def test_dispatch_on_dispatch():
168
+ @dispatch(A)
169
+ @dispatch(C)
170
+ def q(x): # noqa:F811
171
+ return 1
172
+
173
+ assert q(A()) == 1
174
+ assert q(C()) == 1
175
+ """
176
+
177
+
178
+ def test_methods():
179
+ class Foo:
180
+ @dispatch(float)
181
+ def f(self, x): # noqa:F811
182
+ return x - 1
183
+
184
+ @dispatch(int) # noqa:F811
185
+ def f(self, x): # noqa:F811
186
+ return x + 1
187
+
188
+ @dispatch(int)
189
+ def g(self, x): # noqa:F811
190
+ return x + 3
191
+
192
+
193
+ foo = Foo()
194
+ assert foo.f(1) == 2
195
+ assert foo.f(1.0) == 0.0
196
+ assert foo.g(1) == 4
197
+
198
+
199
+ def test_methods_multiple_dispatch():
200
+ class Foo:
201
+ @dispatch(A, A)
202
+ def f(x, y): # noqa:F811
203
+ return 1
204
+
205
+ @dispatch(A, C) # noqa:F811
206
+ def f(x, y): # noqa:F811
207
+ return 2
208
+
209
+
210
+ foo = Foo()
211
+ assert foo.f(A(), A()) == 1
212
+ assert foo.f(A(), C()) == 2
213
+ assert foo.f(C(), C()) == 2
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_bbp_pi.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.random import randint
2
+
3
+ from sympy.ntheory.bbp_pi import pi_hex_digits
4
+ from sympy.testing.pytest import raises
5
+
6
+
7
+ # http://www.herongyang.com/Cryptography/Blowfish-First-8366-Hex-Digits-of-PI.html
8
+ # There are actually 8336 listed there; with the prepended 3 there are 8337
9
+ # below
10
+ dig=''.join('''
11
+ 3243f6a8885a308d313198a2e03707344a4093822299f31d0082efa98ec4e6c89452821e638d013
12
+ 77be5466cf34e90c6cc0ac29b7c97c50dd3f84d5b5b54709179216d5d98979fb1bd1310ba698dfb5
13
+ ac2ffd72dbd01adfb7b8e1afed6a267e96ba7c9045f12c7f9924a19947b3916cf70801f2e2858efc
14
+ 16636920d871574e69a458fea3f4933d7e0d95748f728eb658718bcd5882154aee7b54a41dc25a59
15
+ b59c30d5392af26013c5d1b023286085f0ca417918b8db38ef8e79dcb0603a180e6c9e0e8bb01e8a
16
+ 3ed71577c1bd314b2778af2fda55605c60e65525f3aa55ab945748986263e8144055ca396a2aab10
17
+ b6b4cc5c341141e8cea15486af7c72e993b3ee1411636fbc2a2ba9c55d741831f6ce5c3e169b8793
18
+ 1eafd6ba336c24cf5c7a325381289586773b8f48986b4bb9afc4bfe81b6628219361d809ccfb21a9
19
+ 91487cac605dec8032ef845d5de98575b1dc262302eb651b8823893e81d396acc50f6d6ff383f442
20
+ 392e0b4482a484200469c8f04a9e1f9b5e21c66842f6e96c9a670c9c61abd388f06a51a0d2d8542f
21
+ 68960fa728ab5133a36eef0b6c137a3be4ba3bf0507efb2a98a1f1651d39af017666ca593e82430e
22
+ 888cee8619456f9fb47d84a5c33b8b5ebee06f75d885c12073401a449f56c16aa64ed3aa62363f77
23
+ 061bfedf72429b023d37d0d724d00a1248db0fead349f1c09b075372c980991b7b25d479d8f6e8de
24
+ f7e3fe501ab6794c3b976ce0bd04c006bac1a94fb6409f60c45e5c9ec2196a246368fb6faf3e6c53
25
+ b51339b2eb3b52ec6f6dfc511f9b30952ccc814544af5ebd09bee3d004de334afd660f2807192e4b
26
+ b3c0cba85745c8740fd20b5f39b9d3fbdb5579c0bd1a60320ad6a100c6402c7279679f25fefb1fa3
27
+ cc8ea5e9f8db3222f83c7516dffd616b152f501ec8ad0552ab323db5fafd23876053317b483e00df
28
+ 829e5c57bbca6f8ca01a87562edf1769dbd542a8f6287effc3ac6732c68c4f5573695b27b0bbca58
29
+ c8e1ffa35db8f011a010fa3d98fd2183b84afcb56c2dd1d35b9a53e479b6f84565d28e49bc4bfb97
30
+ 90e1ddf2daa4cb7e3362fb1341cee4c6e8ef20cada36774c01d07e9efe2bf11fb495dbda4dae9091
31
+ 98eaad8e716b93d5a0d08ed1d0afc725e08e3c5b2f8e7594b78ff6e2fbf2122b648888b812900df0
32
+ 1c4fad5ea0688fc31cd1cff191b3a8c1ad2f2f2218be0e1777ea752dfe8b021fa1e5a0cc0fb56f74
33
+ e818acf3d6ce89e299b4a84fe0fd13e0b77cc43b81d2ada8d9165fa2668095770593cc7314211a14
34
+ 77e6ad206577b5fa86c75442f5fb9d35cfebcdaf0c7b3e89a0d6411bd3ae1e7e4900250e2d2071b3
35
+ 5e226800bb57b8e0af2464369bf009b91e5563911d59dfa6aa78c14389d95a537f207d5ba202e5b9
36
+ c5832603766295cfa911c819684e734a41b3472dca7b14a94a1b5100529a532915d60f573fbc9bc6
37
+ e42b60a47681e6740008ba6fb5571be91ff296ec6b2a0dd915b6636521e7b9f9b6ff34052ec58556
38
+ 6453b02d5da99f8fa108ba47996e85076a4b7a70e9b5b32944db75092ec4192623ad6ea6b049a7df
39
+ 7d9cee60b88fedb266ecaa8c71699a17ff5664526cc2b19ee1193602a575094c29a0591340e4183a
40
+ 3e3f54989a5b429d656b8fe4d699f73fd6a1d29c07efe830f54d2d38e6f0255dc14cdd20868470eb
41
+ 266382e9c6021ecc5e09686b3f3ebaefc93c9718146b6a70a1687f358452a0e286b79c5305aa5007
42
+ 373e07841c7fdeae5c8e7d44ec5716f2b8b03ada37f0500c0df01c1f040200b3ffae0cf51a3cb574
43
+ b225837a58dc0921bdd19113f97ca92ff69432477322f547013ae5e58137c2dadcc8b576349af3dd
44
+ a7a94461460fd0030eecc8c73ea4751e41e238cd993bea0e2f3280bba1183eb3314e548b384f6db9
45
+ 086f420d03f60a04bf2cb8129024977c795679b072bcaf89afde9a771fd9930810b38bae12dccf3f
46
+ 2e5512721f2e6b7124501adde69f84cd877a5847187408da17bc9f9abce94b7d8cec7aec3adb851d
47
+ fa63094366c464c3d2ef1c18473215d908dd433b3724c2ba1612a14d432a65c45150940002133ae4
48
+ dd71dff89e10314e5581ac77d65f11199b043556f1d7a3c76b3c11183b5924a509f28fe6ed97f1fb
49
+ fa9ebabf2c1e153c6e86e34570eae96fb1860e5e0a5a3e2ab3771fe71c4e3d06fa2965dcb999e71d
50
+ 0f803e89d65266c8252e4cc9789c10b36ac6150eba94e2ea78a5fc3c531e0a2df4f2f74ea7361d2b
51
+ 3d1939260f19c279605223a708f71312b6ebadfe6eeac31f66e3bc4595a67bc883b17f37d1018cff
52
+ 28c332ddefbe6c5aa56558218568ab9802eecea50fdb2f953b2aef7dad5b6e2f841521b628290761
53
+ 70ecdd4775619f151013cca830eb61bd960334fe1eaa0363cfb5735c904c70a239d59e9e0bcbaade
54
+ 14eecc86bc60622ca79cab5cabb2f3846e648b1eaf19bdf0caa02369b9655abb5040685a323c2ab4
55
+ b3319ee9d5c021b8f79b540b19875fa09995f7997e623d7da8f837889a97e32d7711ed935f166812
56
+ 810e358829c7e61fd696dedfa17858ba9957f584a51b2272639b83c3ff1ac24696cdb30aeb532e30
57
+ 548fd948e46dbc312858ebf2ef34c6ffeafe28ed61ee7c3c735d4a14d9e864b7e342105d14203e13
58
+ e045eee2b6a3aaabeadb6c4f15facb4fd0c742f442ef6abbb5654f3b1d41cd2105d81e799e86854d
59
+ c7e44b476a3d816250cf62a1f25b8d2646fc8883a0c1c7b6a37f1524c369cb749247848a0b5692b2
60
+ 85095bbf00ad19489d1462b17423820e0058428d2a0c55f5ea1dadf43e233f70613372f0928d937e
61
+ 41d65fecf16c223bdb7cde3759cbee74604085f2a7ce77326ea607808419f8509ee8efd85561d997
62
+ 35a969a7aac50c06c25a04abfc800bcadc9e447a2ec3453484fdd567050e1e9ec9db73dbd3105588
63
+ cd675fda79e3674340c5c43465713e38d83d28f89ef16dff20153e21e78fb03d4ae6e39f2bdb83ad
64
+ f7e93d5a68948140f7f64c261c94692934411520f77602d4f7bcf46b2ed4a20068d40824713320f4
65
+ 6a43b7d4b7500061af1e39f62e9724454614214f74bf8b88404d95fc1d96b591af70f4ddd366a02f
66
+ 45bfbc09ec03bd97857fac6dd031cb850496eb27b355fd3941da2547e6abca0a9a28507825530429
67
+ f40a2c86dae9b66dfb68dc1462d7486900680ec0a427a18dee4f3ffea2e887ad8cb58ce0067af4d6
68
+ b6aace1e7cd3375fecce78a399406b2a4220fe9e35d9f385b9ee39d7ab3b124e8b1dc9faf74b6d18
69
+ 5626a36631eae397b23a6efa74dd5b43326841e7f7ca7820fbfb0af54ed8feb397454056acba4895
70
+ 2755533a3a20838d87fe6ba9b7d096954b55a867bca1159a58cca9296399e1db33a62a4a563f3125
71
+ f95ef47e1c9029317cfdf8e80204272f7080bb155c05282ce395c11548e4c66d2248c1133fc70f86
72
+ dc07f9c9ee41041f0f404779a45d886e17325f51ebd59bc0d1f2bcc18f41113564257b7834602a9c
73
+ 60dff8e8a31f636c1b0e12b4c202e1329eaf664fd1cad181156b2395e0333e92e13b240b62eebeb9
74
+ 2285b2a20ee6ba0d99de720c8c2da2f728d012784595b794fd647d0862e7ccf5f05449a36f877d48
75
+ fac39dfd27f33e8d1e0a476341992eff743a6f6eabf4f8fd37a812dc60a1ebddf8991be14cdb6e6b
76
+ 0dc67b55106d672c372765d43bdcd0e804f1290dc7cc00ffa3b5390f92690fed0b667b9ffbcedb7d
77
+ 9ca091cf0bd9155ea3bb132f88515bad247b9479bf763bd6eb37392eb3cc1159798026e297f42e31
78
+ 2d6842ada7c66a2b3b12754ccc782ef11c6a124237b79251e706a1bbe64bfb63501a6b101811caed
79
+ fa3d25bdd8e2e1c3c9444216590a121386d90cec6ed5abea2a64af674eda86a85fbebfe98864e4c3
80
+ fe9dbc8057f0f7c08660787bf86003604dd1fd8346f6381fb07745ae04d736fccc83426b33f01eab
81
+ 71b08041873c005e5f77a057bebde8ae2455464299bf582e614e58f48ff2ddfda2f474ef388789bd
82
+ c25366f9c3c8b38e74b475f25546fcd9b97aeb26618b1ddf84846a0e79915f95e2466e598e20b457
83
+ 708cd55591c902de4cb90bace1bb8205d011a862487574a99eb77f19b6e0a9dc09662d09a1c43246
84
+ 33e85a1f0209f0be8c4a99a0251d6efe101ab93d1d0ba5a4dfa186f20f2868f169dcb7da83573906
85
+ fea1e2ce9b4fcd7f5250115e01a70683faa002b5c40de6d0279af88c27773f8641c3604c0661a806
86
+ b5f0177a28c0f586e0006058aa30dc7d6211e69ed72338ea6353c2dd94c2c21634bbcbee5690bcb6
87
+ deebfc7da1ce591d766f05e4094b7c018839720a3d7c927c2486e3725f724d9db91ac15bb4d39eb8
88
+ fced54557808fca5b5d83d7cd34dad0fc41e50ef5eb161e6f8a28514d96c51133c6fd5c7e756e14e
89
+ c4362abfceddc6c837d79a323492638212670efa8e406000e03a39ce37d3faf5cfabc277375ac52d
90
+ 1b5cb0679e4fa33742d382274099bc9bbed5118e9dbf0f7315d62d1c7ec700c47bb78c1b6b21a190
91
+ 45b26eb1be6a366eb45748ab2fbc946e79c6a376d26549c2c8530ff8ee468dde7dd5730a1d4cd04d
92
+ c62939bbdba9ba4650ac9526e8be5ee304a1fad5f06a2d519a63ef8ce29a86ee22c089c2b843242e
93
+ f6a51e03aa9cf2d0a483c061ba9be96a4d8fe51550ba645bd62826a2f9a73a3ae14ba99586ef5562
94
+ e9c72fefd3f752f7da3f046f6977fa0a5980e4a91587b086019b09e6ad3b3ee593e990fd5a9e34d7
95
+ 972cf0b7d9022b8b5196d5ac3a017da67dd1cf3ed67c7d2d281f9f25cfadf2b89b5ad6b4725a88f5
96
+ 4ce029ac71e019a5e647b0acfded93fa9be8d3c48d283b57ccf8d5662979132e28785f0191ed7560
97
+ 55f7960e44e3d35e8c15056dd488f46dba03a161250564f0bdc3eb9e153c9057a297271aeca93a07
98
+ 2a1b3f6d9b1e6321f5f59c66fb26dcf3197533d928b155fdf5035634828aba3cbb28517711c20ad9
99
+ f8abcc5167ccad925f4de817513830dc8e379d58629320f991ea7a90c2fb3e7bce5121ce64774fbe
100
+ 32a8b6e37ec3293d4648de53696413e680a2ae0810dd6db22469852dfd09072166b39a460a6445c0
101
+ dd586cdecf1c20c8ae5bbef7dd1b588d40ccd2017f6bb4e3bbdda26a7e3a59ff453e350a44bcb4cd
102
+ d572eacea8fa6484bb8d6612aebf3c6f47d29be463542f5d9eaec2771bf64e6370740e0d8de75b13
103
+ 57f8721671af537d5d4040cb084eb4e2cc34d2466a0115af84e1b0042895983a1d06b89fb4ce6ea0
104
+ 486f3f3b823520ab82011a1d4b277227f8611560b1e7933fdcbb3a792b344525bda08839e151ce79
105
+ 4b2f32c9b7a01fbac9e01cc87ebcc7d1f6cf0111c3a1e8aac71a908749d44fbd9ad0dadecbd50ada
106
+ 380339c32ac69136678df9317ce0b12b4ff79e59b743f5bb3af2d519ff27d9459cbf97222c15e6fc
107
+ 2a0f91fc719b941525fae59361ceb69cebc2a8645912baa8d1b6c1075ee3056a0c10d25065cb03a4
108
+ 42e0ec6e0e1698db3b4c98a0be3278e9649f1f9532e0d392dfd3a0342b8971f21e1b0a74414ba334
109
+ 8cc5be7120c37632d8df359f8d9b992f2ee60b6f470fe3f11de54cda541edad891ce6279cfcd3e7e
110
+ 6f1618b166fd2c1d05848fd2c5f6fb2299f523f357a632762393a8353156cccd02acf081625a75eb
111
+ b56e16369788d273ccde96629281b949d04c50901b71c65614e6c6c7bd327a140a45e1d006c3f27b
112
+ 9ac9aa53fd62a80f00bb25bfe235bdd2f671126905b2040222b6cbcf7ccd769c2b53113ec01640e3
113
+ d338abbd602547adf0ba38209cf746ce7677afa1c52075606085cbfe4e8ae88dd87aaaf9b04cf9aa
114
+ 7e1948c25c02fb8a8c01c36ae4d6ebe1f990d4f869a65cdea03f09252dc208e69fb74e6132ce77e2
115
+ 5b578fdfe33ac372e6'''.split())
116
+
117
+
118
+ def test_hex_pi_nth_digits():
119
+ assert pi_hex_digits(0) == '3243f6a8885a30'
120
+ assert pi_hex_digits(1) == '243f6a8885a308'
121
+ assert pi_hex_digits(10000) == '68ac8fcfb8016c'
122
+ assert pi_hex_digits(13) == '08d313198a2e03'
123
+ assert pi_hex_digits(0, 3) == '324'
124
+ assert pi_hex_digits(0, 0) == ''
125
+ raises(ValueError, lambda: pi_hex_digits(-1))
126
+ raises(ValueError, lambda: pi_hex_digits(0, -1))
127
+ raises(ValueError, lambda: pi_hex_digits(3.14))
128
+
129
+ # this will pick a random segment to compute every time
130
+ # it is run. If it ever fails, there is an error in the
131
+ # computation.
132
+ n = randint(0, len(dig))
133
+ prec = randint(0, len(dig) - n)
134
+ assert pi_hex_digits(n, prec) == dig[n: n + prec]
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_continued_fraction.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ from sympy.core import GoldenRatio as phi
3
+ from sympy.core.numbers import (Rational, pi)
4
+ from sympy.core.singleton import S
5
+ from sympy.functions.elementary.miscellaneous import sqrt
6
+ from sympy.ntheory.continued_fraction import \
7
+ (continued_fraction_periodic as cf_p,
8
+ continued_fraction_iterator as cf_i,
9
+ continued_fraction_convergents as cf_c,
10
+ continued_fraction_reduce as cf_r,
11
+ continued_fraction as cf)
12
+ from sympy.testing.pytest import raises
13
+
14
+
15
+ def test_continued_fraction():
16
+ assert cf_p(1, 1, 10, 0) == cf_p(1, 1, 0, 1)
17
+ assert cf_p(1, -1, 10, 1) == cf_p(-1, 1, 10, -1)
18
+ t = sqrt(2)
19
+ assert cf((1 + t)*(1 - t)) == cf(-1)
20
+ for n in [0, 2, Rational(2, 3), sqrt(2), 3*sqrt(2), 1 + 2*sqrt(3)/5,
21
+ (2 - 3*sqrt(5))/7, 1 + sqrt(2), (-5 + sqrt(17))/4]:
22
+ assert (cf_r(cf(n)) - n).expand() == 0
23
+ assert (cf_r(cf(-n)) + n).expand() == 0
24
+ raises(ValueError, lambda: cf(sqrt(2 + sqrt(3))))
25
+ raises(ValueError, lambda: cf(sqrt(2) + sqrt(3)))
26
+ raises(ValueError, lambda: cf(pi))
27
+ raises(ValueError, lambda: cf(.1))
28
+
29
+ raises(ValueError, lambda: cf_p(1, 0, 0))
30
+ raises(ValueError, lambda: cf_p(1, 1, -1))
31
+ assert cf_p(4, 3, 0) == [1, 3]
32
+ assert cf_p(0, 3, 5) == [0, 1, [2, 1, 12, 1, 2, 2]]
33
+ assert cf_p(1, 1, 0) == [1]
34
+ assert cf_p(3, 4, 0) == [0, 1, 3]
35
+ assert cf_p(4, 5, 0) == [0, 1, 4]
36
+ assert cf_p(5, 6, 0) == [0, 1, 5]
37
+ assert cf_p(11, 13, 0) == [0, 1, 5, 2]
38
+ assert cf_p(16, 19, 0) == [0, 1, 5, 3]
39
+ assert cf_p(27, 32, 0) == [0, 1, 5, 2, 2]
40
+ assert cf_p(1, 2, 5) == [[1]]
41
+ assert cf_p(0, 1, 2) == [1, [2]]
42
+ assert cf_p(6, 7, 49) == [1, 1, 6]
43
+ assert cf_p(3796, 1387, 0) == [2, 1, 2, 1, 4]
44
+ assert cf_p(3245, 10000) == [0, 3, 12, 4, 13]
45
+ assert cf_p(1932, 2568) == [0, 1, 3, 26, 2]
46
+ assert cf_p(6589, 2569) == [2, 1, 1, 3, 2, 1, 3, 1, 23]
47
+
48
+ def take(iterator, n=7):
49
+ return list(itertools.islice(iterator, n))
50
+
51
+ assert take(cf_i(phi)) == [1, 1, 1, 1, 1, 1, 1]
52
+ assert take(cf_i(pi)) == [3, 7, 15, 1, 292, 1, 1]
53
+
54
+ assert list(cf_i(Rational(17, 12))) == [1, 2, 2, 2]
55
+ assert list(cf_i(Rational(-17, 12))) == [-2, 1, 1, 2, 2]
56
+
57
+ assert list(cf_c([1, 6, 1, 8])) == [S.One, Rational(7, 6), Rational(8, 7), Rational(71, 62)]
58
+ assert list(cf_c([2])) == [S(2)]
59
+ assert list(cf_c([1, 1, 1, 1, 1, 1, 1])) == [S.One, S(2), Rational(3, 2), Rational(5, 3),
60
+ Rational(8, 5), Rational(13, 8), Rational(21, 13)]
61
+ assert list(cf_c([1, 6, Rational(-1, 2), 4])) == [S.One, Rational(7, 6), Rational(5, 4), Rational(3, 2)]
62
+ assert take(cf_c([[1]])) == [S.One, S(2), Rational(3, 2), Rational(5, 3), Rational(8, 5),
63
+ Rational(13, 8), Rational(21, 13)]
64
+ assert take(cf_c([1, [1, 2]])) == [S.One, S(2), Rational(5, 3), Rational(7, 4), Rational(19, 11),
65
+ Rational(26, 15), Rational(71, 41)]
66
+
67
+ cf_iter_e = (2 if i == 1 else i // 3 * 2 if i % 3 == 0 else 1 for i in itertools.count(1))
68
+ assert take(cf_c(cf_iter_e)) == [S(2), S(3), Rational(8, 3), Rational(11, 4), Rational(19, 7),
69
+ Rational(87, 32), Rational(106, 39)]
70
+
71
+ assert cf_r([1, 6, 1, 8]) == Rational(71, 62)
72
+ assert cf_r([3]) == S(3)
73
+ assert cf_r([-1, 5, 1, 4]) == Rational(-24, 29)
74
+ assert (cf_r([0, 1, 1, 7, [24, 8]]) - (sqrt(3) + 2)/7).expand() == 0
75
+ assert cf_r([1, 5, 9]) == Rational(55, 46)
76
+ assert (cf_r([[1]]) - (sqrt(5) + 1)/2).expand() == 0
77
+ assert cf_r([-3, 1, 1, [2]]) == -1 - sqrt(2)
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_digits.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.ntheory import count_digits, digits, is_palindromic
2
+ from sympy.core.intfunc import num_digits
3
+
4
+ from sympy.testing.pytest import raises
5
+
6
+
7
+ def test_num_digits():
8
+ # depending on whether one rounds up or down or uses log or log10,
9
+ # one or more of these will fail if you don't check for the off-by
10
+ # one condition
11
+ assert num_digits(2, 2) == 2
12
+ assert num_digits(2**48 - 1, 2) == 48
13
+ assert num_digits(1000, 10) == 4
14
+ assert num_digits(125, 5) == 4
15
+ assert num_digits(100, 16) == 2
16
+ assert num_digits(-1000, 10) == 4
17
+ # if changes are made to the function, this structured test over
18
+ # this range will expose problems
19
+ for base in range(2, 100):
20
+ for e in range(1, 100):
21
+ n = base**e
22
+ assert num_digits(n, base) == e + 1
23
+ assert num_digits(n + 1, base) == e + 1
24
+ assert num_digits(n - 1, base) == e
25
+
26
+
27
+ def test_digits():
28
+ assert all(digits(n, 2)[1:] == [int(d) for d in format(n, 'b')]
29
+ for n in range(20))
30
+ assert all(digits(n, 8)[1:] == [int(d) for d in format(n, 'o')]
31
+ for n in range(20))
32
+ assert all(digits(n, 16)[1:] == [int(d, 16) for d in format(n, 'x')]
33
+ for n in range(20))
34
+ assert digits(2345, 34) == [34, 2, 0, 33]
35
+ assert digits(384753, 71) == [71, 1, 5, 23, 4]
36
+ assert digits(93409, 10) == [10, 9, 3, 4, 0, 9]
37
+ assert digits(-92838, 11) == [-11, 6, 3, 8, 2, 9]
38
+ assert digits(35, 10) == [10, 3, 5]
39
+ assert digits(35, 10, 3) == [10, 0, 3, 5]
40
+ assert digits(-35, 10, 4) == [-10, 0, 0, 3, 5]
41
+ raises(ValueError, lambda: digits(2, 2, 1))
42
+
43
+
44
+ def test_count_digits():
45
+ assert count_digits(55, 2) == {1: 5, 0: 1}
46
+ assert count_digits(55, 10) == {5: 2}
47
+ n = count_digits(123)
48
+ assert n[4] == 0 and type(n[4]) is int
49
+
50
+
51
+ def test_is_palindromic():
52
+ assert is_palindromic(-11)
53
+ assert is_palindromic(11)
54
+ assert is_palindromic(0o121, 8)
55
+ assert not is_palindromic(123)
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_ecm.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.external.gmpy import invert
2
+ from sympy.ntheory.ecm import ecm, Point
3
+ from sympy.testing.pytest import slow
4
+
5
+ @slow
6
+ def test_ecm():
7
+ assert ecm(3146531246531241245132451321) == {3, 100327907731, 10454157497791297}
8
+ assert ecm(46167045131415113) == {43, 2634823, 407485517}
9
+ assert ecm(631211032315670776841) == {9312934919, 67777885039}
10
+ assert ecm(398883434337287) == {99476569, 4009823}
11
+ assert ecm(64211816600515193) == {281719, 359641, 633767}
12
+ assert ecm(4269021180054189416198169786894227) == {184039, 241603, 333331, 477973, 618619, 974123}
13
+ assert ecm(4516511326451341281684513) == {3, 39869, 131743543, 95542348571}
14
+ assert ecm(4132846513818654136451) == {47, 160343, 2802377, 195692803}
15
+ assert ecm(168541512131094651323) == {79, 113, 11011069, 1714635721}
16
+ #This takes ~10secs while factorint is not able to factorize this even in ~10mins
17
+ assert ecm(7060005655815754299976961394452809, B1=100000, B2=1000000) == {6988699669998001, 1010203040506070809}
18
+
19
+
20
+ def test_Point():
21
+ #The curve is of the form y**2 = x**3 + a*x**2 + x
22
+ mod = 101
23
+ a = 10
24
+ a_24 = (a + 2)*invert(4, mod)
25
+ p1 = Point(10, 17, a_24, mod)
26
+ p2 = p1.double()
27
+ assert p2 == Point(68, 56, a_24, mod)
28
+ p4 = p2.double()
29
+ assert p4 == Point(22, 64, a_24, mod)
30
+ p8 = p4.double()
31
+ assert p8 == Point(71, 95, a_24, mod)
32
+ p16 = p8.double()
33
+ assert p16 == Point(5, 16, a_24, mod)
34
+ p32 = p16.double()
35
+ assert p32 == Point(33, 96, a_24, mod)
36
+
37
+ # p3 = p2 + p1
38
+ p3 = p2.add(p1, p1)
39
+ assert p3 == Point(1, 61, a_24, mod)
40
+ # p5 = p3 + p2 or p4 + p1
41
+ p5 = p3.add(p2, p1)
42
+ assert p5 == Point(49, 90, a_24, mod)
43
+ assert p5 == p4.add(p1, p3)
44
+ # p6 = 2*p3
45
+ p6 = p3.double()
46
+ assert p6 == Point(87, 43, a_24, mod)
47
+ assert p6 == p4.add(p2, p2)
48
+ # p7 = p5 + p2
49
+ p7 = p5.add(p2, p3)
50
+ assert p7 == Point(69, 23, a_24, mod)
51
+ assert p7 == p4.add(p3, p1)
52
+ assert p7 == p6.add(p1, p5)
53
+ # p9 = p5 + p4
54
+ p9 = p5.add(p4, p1)
55
+ assert p9 == Point(56, 99, a_24, mod)
56
+ assert p9 == p6.add(p3, p3)
57
+ assert p9 == p7.add(p2, p5)
58
+ assert p9 == p8.add(p1, p7)
59
+
60
+ assert p5 == p1.mont_ladder(5)
61
+ assert p9 == p1.mont_ladder(9)
62
+ assert p16 == p1.mont_ladder(16)
63
+ assert p9 == p3.mont_ladder(3)
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_egyptian_fraction.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.numbers import Rational
2
+ from sympy.ntheory.egyptian_fraction import egyptian_fraction
3
+ from sympy.core.add import Add
4
+ from sympy.testing.pytest import raises
5
+ from sympy.core.random import random_complex_number
6
+
7
+
8
+ def test_egyptian_fraction():
9
+ def test_equality(r, alg="Greedy"):
10
+ return r == Add(*[Rational(1, i) for i in egyptian_fraction(r, alg)])
11
+
12
+ r = random_complex_number(a=0, c=1, b=0, d=0, rational=True)
13
+ assert test_equality(r)
14
+
15
+ assert egyptian_fraction(Rational(4, 17)) == [5, 29, 1233, 3039345]
16
+ assert egyptian_fraction(Rational(7, 13), "Greedy") == [2, 26]
17
+ assert egyptian_fraction(Rational(23, 101), "Greedy") == \
18
+ [5, 37, 1438, 2985448, 40108045937720]
19
+ assert egyptian_fraction(Rational(18, 23), "Takenouchi") == \
20
+ [2, 6, 12, 35, 276, 2415]
21
+ assert egyptian_fraction(Rational(5, 6), "Graham Jewett") == \
22
+ [6, 7, 8, 9, 10, 42, 43, 44, 45, 56, 57, 58, 72, 73, 90, 1806, 1807,
23
+ 1808, 1892, 1893, 1980, 3192, 3193, 3306, 5256, 3263442, 3263443,
24
+ 3267056, 3581556, 10192056, 10650056950806]
25
+ assert egyptian_fraction(Rational(5, 6), "Golomb") == [2, 6, 12, 20, 30]
26
+ assert egyptian_fraction(Rational(5, 121), "Golomb") == [25, 1225, 3577, 7081, 11737]
27
+ raises(ValueError, lambda: egyptian_fraction(Rational(-4, 9)))
28
+ assert egyptian_fraction(Rational(8, 3), "Golomb") == [1, 2, 3, 4, 5, 6, 7,
29
+ 14, 574, 2788, 6460,
30
+ 11590, 33062, 113820]
31
+ assert egyptian_fraction(Rational(355, 113)) == [1, 2, 3, 4, 5, 6, 7, 8, 9,
32
+ 10, 11, 12, 27, 744, 893588,
33
+ 1251493536607,
34
+ 20361068938197002344405230]
35
+
36
+
37
+ def test_input():
38
+ r = (2,3), Rational(2, 3), (Rational(2), Rational(3))
39
+ for m in ["Greedy", "Graham Jewett", "Takenouchi", "Golomb"]:
40
+ for i in r:
41
+ d = egyptian_fraction(i, m)
42
+ assert all(i.is_Integer for i in d)
43
+ if m == "Graham Jewett":
44
+ assert d == [3, 4, 12]
45
+ else:
46
+ assert d == [2, 6]
47
+ # check prefix
48
+ d = egyptian_fraction(Rational(5, 3))
49
+ assert d == [1, 2, 6] and all(i.is_Integer for i in d)
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_elliptic_curve.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.ntheory.elliptic_curve import EllipticCurve
2
+
3
+
4
+ def test_elliptic_curve():
5
+ # Point addition and multiplication
6
+ e3 = EllipticCurve(-1, 9)
7
+ p = e3(0, 3)
8
+ q = e3(-1, 3)
9
+ r = p + q
10
+ assert r.x == 1 and r.y == -3
11
+ r = 2*p + q
12
+ assert r.x == 35 and r.y == 207
13
+ r = -p + q
14
+ assert r.x == 37 and r.y == 225
15
+ # Verify result in http://www.lmfdb.org/EllipticCurve/Q
16
+ # Discriminant
17
+ assert EllipticCurve(-1, 9).discriminant == -34928
18
+ assert EllipticCurve(-2731, -55146, 1, 0, 1).discriminant == 25088
19
+ # Torsion points
20
+ assert len(EllipticCurve(0, 1).torsion_points()) == 6
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_factor_.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.containers import Dict
2
+ from sympy.core.mul import Mul
3
+ from sympy.core.power import Pow
4
+ from sympy.core.singleton import S
5
+ from sympy.functions.combinatorial.factorials import factorial as fac
6
+ from sympy.core.numbers import Integer, Rational
7
+ from sympy.external.gmpy import gcd
8
+
9
+ from sympy.ntheory import (totient,
10
+ factorint, primefactors, divisors, nextprime,
11
+ pollard_rho, perfect_power, multiplicity, multiplicity_in_factorial,
12
+ divisor_count, primorial, pollard_pm1, divisor_sigma,
13
+ factorrat, reduced_totient)
14
+ from sympy.ntheory.factor_ import (smoothness, smoothness_p, proper_divisors,
15
+ antidivisors, antidivisor_count, _divisor_sigma, core, udivisors, udivisor_sigma,
16
+ udivisor_count, proper_divisor_count, primenu, primeomega,
17
+ mersenne_prime_exponent, is_perfect, is_abundant,
18
+ is_deficient, is_amicable, is_carmichael, find_carmichael_numbers_in_range,
19
+ find_first_n_carmichaels, dra, drm, _perfect_power, factor_cache)
20
+
21
+ from sympy.testing.pytest import raises, slow
22
+
23
+ from sympy.utilities.iterables import capture
24
+
25
+
26
+ def fac_multiplicity(n, p):
27
+ """Return the power of the prime number p in the
28
+ factorization of n!"""
29
+ if p > n:
30
+ return 0
31
+ if p > n//2:
32
+ return 1
33
+ q, m = n, 0
34
+ while q >= p:
35
+ q //= p
36
+ m += q
37
+ return m
38
+
39
+
40
+ def multiproduct(seq=(), start=1):
41
+ """
42
+ Return the product of a sequence of factors with multiplicities,
43
+ times the value of the parameter ``start``. The input may be a
44
+ sequence of (factor, exponent) pairs or a dict of such pairs.
45
+
46
+ >>> multiproduct({3:7, 2:5}, 4) # = 3**7 * 2**5 * 4
47
+ 279936
48
+
49
+ """
50
+ if not seq:
51
+ return start
52
+ if isinstance(seq, dict):
53
+ seq = iter(seq.items())
54
+ units = start
55
+ multi = []
56
+ for base, exp in seq:
57
+ if not exp:
58
+ continue
59
+ elif exp == 1:
60
+ units *= base
61
+ else:
62
+ if exp % 2:
63
+ units *= base
64
+ multi.append((base, exp//2))
65
+ return units * multiproduct(multi)**2
66
+
67
+
68
+ def test_multiplicity():
69
+ for b in range(2, 20):
70
+ for i in range(100):
71
+ assert multiplicity(b, b**i) == i
72
+ assert multiplicity(b, (b**i) * 23) == i
73
+ assert multiplicity(b, (b**i) * 1000249) == i
74
+ # Should be fast
75
+ assert multiplicity(10, 10**10023) == 10023
76
+ # Should exit quickly
77
+ assert multiplicity(10**10, 10**10) == 1
78
+ # Should raise errors for bad input
79
+ raises(ValueError, lambda: multiplicity(1, 1))
80
+ raises(ValueError, lambda: multiplicity(1, 2))
81
+ raises(ValueError, lambda: multiplicity(1.3, 2))
82
+ raises(ValueError, lambda: multiplicity(2, 0))
83
+ raises(ValueError, lambda: multiplicity(1.3, 0))
84
+
85
+ # handles Rationals
86
+ assert multiplicity(10, Rational(30, 7)) == 1
87
+ assert multiplicity(Rational(2, 7), Rational(4, 7)) == 1
88
+ assert multiplicity(Rational(1, 7), Rational(3, 49)) == 2
89
+ assert multiplicity(Rational(2, 7), Rational(7, 2)) == -1
90
+ assert multiplicity(3, Rational(1, 9)) == -2
91
+
92
+
93
+ def test_multiplicity_in_factorial():
94
+ n = fac(1000)
95
+ for i in (2, 4, 6, 12, 30, 36, 48, 60, 72, 96):
96
+ assert multiplicity(i, n) == multiplicity_in_factorial(i, 1000)
97
+
98
+
99
+ def test_private_perfect_power():
100
+ assert _perfect_power(0) is False
101
+ assert _perfect_power(1) is False
102
+ assert _perfect_power(2) is False
103
+ assert _perfect_power(3) is False
104
+ for x in [2, 3, 5, 6, 7, 12, 15, 105, 100003]:
105
+ for y in range(2, 100):
106
+ assert _perfect_power(x**y) == (x, y)
107
+ if x & 1:
108
+ assert _perfect_power(x**y, next_p=3) == (x, y)
109
+ if x == 100003:
110
+ assert _perfect_power(x**y, next_p=100003) == (x, y)
111
+ assert _perfect_power(101*x**y) == False
112
+ # Catalan's conjecture
113
+ if x**y not in [8, 9]:
114
+ assert _perfect_power(x**y + 1) == False
115
+ assert _perfect_power(x**y - 1) == False
116
+ for x in range(1, 10):
117
+ for y in range(1, 10):
118
+ g = gcd(x, y)
119
+ if g == 1:
120
+ assert _perfect_power(5**x * 101**y) == False
121
+ else:
122
+ assert _perfect_power(5**x * 101**y) == (5**(x//g) * 101**(y//g), g)
123
+
124
+
125
+ def test_perfect_power():
126
+ raises(ValueError, lambda: perfect_power(0.1))
127
+ assert perfect_power(0) is False
128
+ assert perfect_power(1) is False
129
+ assert perfect_power(2) is False
130
+ assert perfect_power(3) is False
131
+ assert perfect_power(4) == (2, 2)
132
+ assert perfect_power(14) is False
133
+ assert perfect_power(25) == (5, 2)
134
+ assert perfect_power(22) is False
135
+ assert perfect_power(22, [2]) is False
136
+ assert perfect_power(137**(3*5*13)) == (137, 3*5*13)
137
+ assert perfect_power(137**(3*5*13) + 1) is False
138
+ assert perfect_power(137**(3*5*13) - 1) is False
139
+ assert perfect_power(103005006004**7) == (103005006004, 7)
140
+ assert perfect_power(103005006004**7 + 1) is False
141
+ assert perfect_power(103005006004**7 - 1) is False
142
+ assert perfect_power(103005006004**12) == (103005006004, 12)
143
+ assert perfect_power(103005006004**12 + 1) is False
144
+ assert perfect_power(103005006004**12 - 1) is False
145
+ assert perfect_power(2**10007) == (2, 10007)
146
+ assert perfect_power(2**10007 + 1) is False
147
+ assert perfect_power(2**10007 - 1) is False
148
+ assert perfect_power((9**99 + 1)**60) == (9**99 + 1, 60)
149
+ assert perfect_power((9**99 + 1)**60 + 1) is False
150
+ assert perfect_power((9**99 + 1)**60 - 1) is False
151
+ assert perfect_power((10**40000)**2, big=False) == (10**40000, 2)
152
+ assert perfect_power(10**100000) == (10, 100000)
153
+ assert perfect_power(10**100001) == (10, 100001)
154
+ assert perfect_power(13**4, [3, 5]) is False
155
+ assert perfect_power(3**4, [3, 10], factor=0) is False
156
+ assert perfect_power(3**3*5**3) == (15, 3)
157
+ assert perfect_power(2**3*5**5) is False
158
+ assert perfect_power(2*13**4) is False
159
+ assert perfect_power(2**5*3**3) is False
160
+ t = 2**24
161
+ for d in divisors(24):
162
+ m = perfect_power(t*3**d)
163
+ assert m and m[1] == d or d == 1
164
+ m = perfect_power(t*3**d, big=False)
165
+ assert m and m[1] == 2 or d == 1 or d == 3, (d, m)
166
+
167
+ # negatives and non-integer rationals
168
+ assert perfect_power(-4) is False
169
+ assert perfect_power(-8) == (-2, 3)
170
+ assert perfect_power(-S(1)/8) == (-S(1)/2, 3)
171
+ assert perfect_power(S(1)/3) == False
172
+ assert perfect_power(-5**15) == (-5, 15)
173
+ assert perfect_power(-5**15, big=False) == (-3125, 3)
174
+ assert perfect_power(-5**15, [15]) == (-5, 15)
175
+
176
+ n = -3 ** 60
177
+ assert perfect_power(n) == (-81, 15)
178
+ assert perfect_power(n, big=False) == (-3486784401, 3)
179
+ assert perfect_power(n, [3, 5], big=True) == (-531441, 5)
180
+ assert perfect_power(n, [3, 5], big=False) == (-3486784401, 3)
181
+ assert perfect_power(n, [2]) == False
182
+ assert perfect_power(n, [2, 15]) == (-81, 15)
183
+ assert perfect_power(n, [2, 13]) == False
184
+ assert perfect_power(n, [17]) == False
185
+ assert perfect_power(n, [3]) == (-3486784401, 3)
186
+ assert perfect_power(n + 1) == False
187
+
188
+ r = S(2) ** (2 * 5 * 7) / S(3) ** (2 * 7)
189
+ assert perfect_power(r) == (S(32) / 3, 14)
190
+ assert perfect_power(-r) == (-S(1024) / 9, 7)
191
+ assert perfect_power(r, big=False) == (S(34359738368) / 2187, 2)
192
+ assert perfect_power(r, [2, 5]) == (S(34359738368) / 2187, 2)
193
+ assert perfect_power(r, [5, 7]) == (S(1024) / 9, 7)
194
+ assert perfect_power(r, [5, 7], big=False) == (S(1024) / 9, 7)
195
+ assert perfect_power(r, [2, 5, 7], big=False) == (S(34359738368) / 2187, 2)
196
+ assert perfect_power(-r, [5, 7], big=False) == (-S(1024) / 9, 7)
197
+
198
+ assert perfect_power(-S(1) / 8) == (-S(1) / 2, 3)
199
+
200
+ assert perfect_power((-3)**60) == (3, 60)
201
+ assert perfect_power((-3)**61) == (-3, 61)
202
+
203
+ assert perfect_power(S(2 ** 9) / 3 ** 12) == (S(8)/81, 3)
204
+ assert perfect_power(Rational(1, 2)**3) == (S.Half, 3)
205
+ assert perfect_power(Rational(-3, 2)**3) == (-3*S.Half, 3)
206
+
207
+
208
+ def test_factor_cache():
209
+ factor_cache.cache_clear()
210
+ raises(ValueError, lambda: factor_cache.__setitem__(1, 5))
211
+ raises(ValueError, lambda: factor_cache.__setitem__(10, 1))
212
+ raises(ValueError, lambda: factor_cache.__setitem__(10, 10))
213
+ raises(ValueError, lambda: factor_cache.__setitem__(10, 3))
214
+ raises(ValueError, lambda: factor_cache.__setitem__(20, 4))
215
+ factor_cache.maxsize = 3
216
+ for i in range(2, 10):
217
+ factor_cache[5*i] = 5
218
+ assert len(factor_cache) == 3
219
+ factor_cache.maxsize = 5
220
+ for i in range(2, 10):
221
+ factor_cache[5*i] = 5
222
+ assert len(factor_cache) == 5
223
+ factor_cache.maxsize = 2
224
+ assert len(factor_cache) == 2
225
+ factor_cache.maxsize =1000
226
+
227
+ factor_cache.cache_clear()
228
+ factor_cache[40] = 5
229
+ assert factor_cache.get(40) == 5
230
+ assert factor_cache.get(20) is None
231
+ assert factor_cache[40] == 5
232
+ raises(KeyError, lambda: factor_cache[10])
233
+ del factor_cache[40]
234
+ assert len(factor_cache) == 0
235
+ raises(KeyError, lambda: factor_cache.__delitem__(40))
236
+ factor_cache.add(100, [5, 2])
237
+ assert len(factor_cache) == 2
238
+ assert factor_cache[100] == 5
239
+
240
+ for n in [1000000007, 10000019*20000003]:
241
+ factorint(n)
242
+ assert n in factor_cache
243
+
244
+ # Restore the initial state
245
+ factor_cache.cache_clear()
246
+ factor_cache.maxsize = 1000
247
+
248
+
249
+ @slow
250
+ def test_factorint():
251
+ assert primefactors(123456) == [2, 3, 643]
252
+ assert factorint(0) == {0: 1}
253
+ assert factorint(1) == {}
254
+ assert factorint(-1) == {-1: 1}
255
+ assert factorint(-2) == {-1: 1, 2: 1}
256
+ assert factorint(-16) == {-1: 1, 2: 4}
257
+ assert factorint(2) == {2: 1}
258
+ assert factorint(126) == {2: 1, 3: 2, 7: 1}
259
+ assert factorint(123456) == {2: 6, 3: 1, 643: 1}
260
+ assert factorint(5951757) == {3: 1, 7: 1, 29: 2, 337: 1}
261
+ assert factorint(64015937) == {7993: 1, 8009: 1}
262
+ assert factorint(2**(2**6) + 1) == {274177: 1, 67280421310721: 1}
263
+ #issue 19683
264
+ assert factorint(10**38 - 1) == {3: 2, 11: 1, 909090909090909091: 1, 1111111111111111111: 1}
265
+ #issue 17676
266
+ assert factorint(28300421052393658575) == {3: 1, 5: 2, 11: 2, 43: 1, 2063: 2, 4127: 1, 4129: 1}
267
+ assert factorint(2063**2 * 4127**1 * 4129**1) == {2063: 2, 4127: 1, 4129: 1}
268
+ assert factorint(2347**2 * 7039**1 * 7043**1) == {2347: 2, 7039: 1, 7043: 1}
269
+
270
+ assert factorint(0, multiple=True) == [0]
271
+ assert factorint(1, multiple=True) == []
272
+ assert factorint(-1, multiple=True) == [-1]
273
+ assert factorint(-2, multiple=True) == [-1, 2]
274
+ assert factorint(-16, multiple=True) == [-1, 2, 2, 2, 2]
275
+ assert factorint(2, multiple=True) == [2]
276
+ assert factorint(24, multiple=True) == [2, 2, 2, 3]
277
+ assert factorint(126, multiple=True) == [2, 3, 3, 7]
278
+ assert factorint(123456, multiple=True) == [2, 2, 2, 2, 2, 2, 3, 643]
279
+ assert factorint(5951757, multiple=True) == [3, 7, 29, 29, 337]
280
+ assert factorint(64015937, multiple=True) == [7993, 8009]
281
+ assert factorint(2**(2**6) + 1, multiple=True) == [274177, 67280421310721]
282
+
283
+ assert factorint(fac(1, evaluate=False)) == {}
284
+ assert factorint(fac(7, evaluate=False)) == {2: 4, 3: 2, 5: 1, 7: 1}
285
+ assert factorint(fac(15, evaluate=False)) == \
286
+ {2: 11, 3: 6, 5: 3, 7: 2, 11: 1, 13: 1}
287
+ assert factorint(fac(20, evaluate=False)) == \
288
+ {2: 18, 3: 8, 5: 4, 7: 2, 11: 1, 13: 1, 17: 1, 19: 1}
289
+ assert factorint(fac(23, evaluate=False)) == \
290
+ {2: 19, 3: 9, 5: 4, 7: 3, 11: 2, 13: 1, 17: 1, 19: 1, 23: 1}
291
+
292
+ assert multiproduct(factorint(fac(200))) == fac(200)
293
+ assert multiproduct(factorint(fac(200, evaluate=False))) == fac(200)
294
+ for b, e in factorint(fac(150)).items():
295
+ assert e == fac_multiplicity(150, b)
296
+ for b, e in factorint(fac(150, evaluate=False)).items():
297
+ assert e == fac_multiplicity(150, b)
298
+ assert factorint(103005006059**7) == {103005006059: 7}
299
+ assert factorint(31337**191) == {31337: 191}
300
+ assert factorint(2**1000 * 3**500 * 257**127 * 383**60) == \
301
+ {2: 1000, 3: 500, 257: 127, 383: 60}
302
+ assert len(factorint(fac(10000))) == 1229
303
+ assert len(factorint(fac(10000, evaluate=False))) == 1229
304
+ assert factorint(12932983746293756928584532764589230) == \
305
+ {2: 1, 5: 1, 73: 1, 727719592270351: 1, 63564265087747: 1, 383: 1}
306
+ assert factorint(727719592270351) == {727719592270351: 1}
307
+ assert factorint(2**64 + 1, use_trial=False) == factorint(2**64 + 1)
308
+ for n in range(60000):
309
+ assert multiproduct(factorint(n)) == n
310
+ assert pollard_rho(2**64 + 1, seed=1) == 274177
311
+ assert pollard_rho(19, seed=1) is None
312
+ assert factorint(3, limit=2) == {3: 1}
313
+ assert factorint(12345) == {3: 1, 5: 1, 823: 1}
314
+ assert factorint(
315
+ 12345, limit=3) == {4115: 1, 3: 1} # the 5 is greater than the limit
316
+ assert factorint(1, limit=1) == {}
317
+ assert factorint(0, 3) == {0: 1}
318
+ assert factorint(12, limit=1) == {12: 1}
319
+ assert factorint(30, limit=2) == {2: 1, 15: 1}
320
+ assert factorint(16, limit=2) == {2: 4}
321
+ assert factorint(124, limit=3) == {2: 2, 31: 1}
322
+ assert factorint(4*31**2, limit=3) == {2: 2, 31: 2}
323
+ p1 = nextprime(2**32)
324
+ p2 = nextprime(2**16)
325
+ p3 = nextprime(p2)
326
+ assert factorint(p1*p2*p3) == {p1: 1, p2: 1, p3: 1}
327
+ assert factorint(13*17*19, limit=15) == {13: 1, 17*19: 1}
328
+ assert factorint(1951*15013*15053, limit=2000) == {225990689: 1, 1951: 1}
329
+ assert factorint(primorial(17) + 1, use_pm1=0) == \
330
+ {int(19026377261): 1, 3467: 1, 277: 1, 105229: 1}
331
+ # when prime b is closer than approx sqrt(8*p) to prime p then they are
332
+ # "close" and have a trivial factorization
333
+ a = nextprime(2**2**8) # 78 digits
334
+ b = nextprime(a + 2**2**4)
335
+ assert 'Fermat' in capture(lambda: factorint(a*b, verbose=1))
336
+
337
+ raises(ValueError, lambda: pollard_rho(4))
338
+ raises(ValueError, lambda: pollard_pm1(3))
339
+ raises(ValueError, lambda: pollard_pm1(10, B=2))
340
+ # verbose coverage
341
+ n = nextprime(2**16)*nextprime(2**17)*nextprime(1901)
342
+ assert 'with primes' in capture(lambda: factorint(n, verbose=1))
343
+ capture(lambda: factorint(nextprime(2**16)*1012, verbose=1))
344
+
345
+ n = nextprime(2**17)
346
+ capture(lambda: factorint(n**3, verbose=1)) # perfect power termination
347
+ capture(lambda: factorint(2*n, verbose=1)) # factoring complete msg
348
+
349
+ # exceed 1st
350
+ n = nextprime(2**17)
351
+ n *= nextprime(n)
352
+ assert '1000' in capture(lambda: factorint(n, limit=1000, verbose=1))
353
+ n *= nextprime(n)
354
+ assert len(factorint(n)) == 3
355
+ assert len(factorint(n, limit=p1)) == 3
356
+ n *= nextprime(2*n)
357
+ # exceed 2nd
358
+ assert '2001' in capture(lambda: factorint(n, limit=2000, verbose=1))
359
+ assert capture(
360
+ lambda: factorint(n, limit=4000, verbose=1)).count('Pollard') == 2
361
+ # non-prime pm1 result
362
+ n = nextprime(8069)
363
+ n *= nextprime(2*n)*nextprime(2*n, 2)
364
+ capture(lambda: factorint(n, verbose=1)) # non-prime pm1 result
365
+ # factor fermat composite
366
+ p1 = nextprime(2**17)
367
+ p2 = nextprime(2*p1)
368
+ assert factorint((p1*p2**2)**3) == {p1: 3, p2: 6}
369
+ # Test for non integer input
370
+ raises(ValueError, lambda: factorint(4.5))
371
+ # test dict/Dict input
372
+ sans = '2**10*3**3'
373
+ n = {4: 2, 12: 3}
374
+ assert str(factorint(n)) == sans
375
+ assert str(factorint(Dict(n))) == sans
376
+
377
+
378
+ def test_divisors_and_divisor_count():
379
+ assert divisors(-1) == [1]
380
+ assert divisors(0) == []
381
+ assert divisors(1) == [1]
382
+ assert divisors(2) == [1, 2]
383
+ assert divisors(3) == [1, 3]
384
+ assert divisors(17) == [1, 17]
385
+ assert divisors(10) == [1, 2, 5, 10]
386
+ assert divisors(100) == [1, 2, 4, 5, 10, 20, 25, 50, 100]
387
+ assert divisors(101) == [1, 101]
388
+ assert type(divisors(2, generator=True)) is not list
389
+
390
+ assert divisor_count(0) == 0
391
+ assert divisor_count(-1) == 1
392
+ assert divisor_count(1) == 1
393
+ assert divisor_count(6) == 4
394
+ assert divisor_count(12) == 6
395
+
396
+ assert divisor_count(180, 3) == divisor_count(180//3)
397
+ assert divisor_count(2*3*5, 7) == 0
398
+
399
+
400
+ def test_proper_divisors_and_proper_divisor_count():
401
+ assert proper_divisors(-1) == []
402
+ assert proper_divisors(0) == []
403
+ assert proper_divisors(1) == []
404
+ assert proper_divisors(2) == [1]
405
+ assert proper_divisors(3) == [1]
406
+ assert proper_divisors(17) == [1]
407
+ assert proper_divisors(10) == [1, 2, 5]
408
+ assert proper_divisors(100) == [1, 2, 4, 5, 10, 20, 25, 50]
409
+ assert proper_divisors(1000000007) == [1]
410
+ assert type(proper_divisors(2, generator=True)) is not list
411
+
412
+ assert proper_divisor_count(0) == 0
413
+ assert proper_divisor_count(-1) == 0
414
+ assert proper_divisor_count(1) == 0
415
+ assert proper_divisor_count(36) == 8
416
+ assert proper_divisor_count(2*3*5) == 7
417
+
418
+
419
+ def test_udivisors_and_udivisor_count():
420
+ assert udivisors(-1) == [1]
421
+ assert udivisors(0) == []
422
+ assert udivisors(1) == [1]
423
+ assert udivisors(2) == [1, 2]
424
+ assert udivisors(3) == [1, 3]
425
+ assert udivisors(17) == [1, 17]
426
+ assert udivisors(10) == [1, 2, 5, 10]
427
+ assert udivisors(100) == [1, 4, 25, 100]
428
+ assert udivisors(101) == [1, 101]
429
+ assert udivisors(1000) == [1, 8, 125, 1000]
430
+ assert type(udivisors(2, generator=True)) is not list
431
+
432
+ assert udivisor_count(0) == 0
433
+ assert udivisor_count(-1) == 1
434
+ assert udivisor_count(1) == 1
435
+ assert udivisor_count(6) == 4
436
+ assert udivisor_count(12) == 4
437
+
438
+ assert udivisor_count(180) == 8
439
+ assert udivisor_count(2*3*5*7) == 16
440
+
441
+
442
+ def test_issue_6981():
443
+ S = set(divisors(4)).union(set(divisors(Integer(2))))
444
+ assert S == {1,2,4}
445
+
446
+
447
+ def test_issue_4356():
448
+ assert factorint(1030903) == {53: 2, 367: 1}
449
+
450
+
451
+ def test_divisors():
452
+ assert divisors(28) == [1, 2, 4, 7, 14, 28]
453
+ assert list(divisors(3*5*7, 1)) == [1, 3, 5, 15, 7, 21, 35, 105]
454
+ assert divisors(0) == []
455
+
456
+
457
+ def test_divisor_count():
458
+ assert divisor_count(0) == 0
459
+ assert divisor_count(6) == 4
460
+
461
+
462
+ def test_proper_divisors():
463
+ assert proper_divisors(-1) == []
464
+ assert proper_divisors(28) == [1, 2, 4, 7, 14]
465
+ assert list(proper_divisors(3*5*7, True)) == [1, 3, 5, 15, 7, 21, 35]
466
+
467
+
468
+ def test_proper_divisor_count():
469
+ assert proper_divisor_count(6) == 3
470
+ assert proper_divisor_count(108) == 11
471
+
472
+
473
+ def test_antidivisors():
474
+ assert antidivisors(-1) == []
475
+ assert antidivisors(-3) == [2]
476
+ assert antidivisors(14) == [3, 4, 9]
477
+ assert antidivisors(237) == [2, 5, 6, 11, 19, 25, 43, 95, 158]
478
+ assert antidivisors(12345) == [2, 6, 7, 10, 30, 1646, 3527, 4938, 8230]
479
+ assert antidivisors(393216) == [262144]
480
+ assert sorted(x for x in antidivisors(3*5*7, 1)) == \
481
+ [2, 6, 10, 11, 14, 19, 30, 42, 70]
482
+ assert antidivisors(1) == []
483
+ assert type(antidivisors(2, generator=True)) is not list
484
+
485
+ def test_antidivisor_count():
486
+ assert antidivisor_count(0) == 0
487
+ assert antidivisor_count(-1) == 0
488
+ assert antidivisor_count(-4) == 1
489
+ assert antidivisor_count(20) == 3
490
+ assert antidivisor_count(25) == 5
491
+ assert antidivisor_count(38) == 7
492
+ assert antidivisor_count(180) == 6
493
+ assert antidivisor_count(2*3*5) == 3
494
+
495
+
496
+ def test_smoothness_and_smoothness_p():
497
+ assert smoothness(1) == (1, 1)
498
+ assert smoothness(2**4*3**2) == (3, 16)
499
+
500
+ assert smoothness_p(10431, m=1) == \
501
+ (1, [(3, (2, 2, 4)), (19, (1, 5, 5)), (61, (1, 31, 31))])
502
+ assert smoothness_p(10431) == \
503
+ (-1, [(3, (2, 2, 2)), (19, (1, 3, 9)), (61, (1, 5, 5))])
504
+ assert smoothness_p(10431, power=1) == \
505
+ (-1, [(3, (2, 2, 2)), (61, (1, 5, 5)), (19, (1, 3, 9))])
506
+ assert smoothness_p(21477639576571, visual=1) == \
507
+ 'p**i=4410317**1 has p-1 B=1787, B-pow=1787\n' + \
508
+ 'p**i=4869863**1 has p-1 B=2434931, B-pow=2434931'
509
+
510
+
511
+ def test_visual_factorint():
512
+ assert factorint(1, visual=1) == 1
513
+ forty2 = factorint(42, visual=True)
514
+ assert type(forty2) == Mul
515
+ assert str(forty2) == '2**1*3**1*7**1'
516
+ assert factorint(1, visual=True) is S.One
517
+ no = {"evaluate": False}
518
+ assert factorint(42**2, visual=True) == Mul(Pow(2, 2, **no),
519
+ Pow(3, 2, **no),
520
+ Pow(7, 2, **no), **no)
521
+ assert -1 in factorint(-42, visual=True).args
522
+
523
+
524
+ def test_factorrat():
525
+ assert str(factorrat(S(12)/1, visual=True)) == '2**2*3**1'
526
+ assert str(factorrat(Rational(1, 1), visual=True)) == '1'
527
+ assert str(factorrat(S(25)/14, visual=True)) == '5**2/(2*7)'
528
+ assert str(factorrat(Rational(25, 14), visual=True)) == '5**2/(2*7)'
529
+ assert str(factorrat(S(-25)/14/9, visual=True)) == '-1*5**2/(2*3**2*7)'
530
+
531
+ assert factorrat(S(12)/1, multiple=True) == [2, 2, 3]
532
+ assert factorrat(Rational(1, 1), multiple=True) == []
533
+ assert factorrat(S(25)/14, multiple=True) == [Rational(1, 7), S.Half, 5, 5]
534
+ assert factorrat(Rational(25, 14), multiple=True) == [Rational(1, 7), S.Half, 5, 5]
535
+ assert factorrat(Rational(12, 1), multiple=True) == [2, 2, 3]
536
+ assert factorrat(S(-25)/14/9, multiple=True) == \
537
+ [-1, Rational(1, 7), Rational(1, 3), Rational(1, 3), S.Half, 5, 5]
538
+
539
+
540
+ def test_visual_io():
541
+ sm = smoothness_p
542
+ fi = factorint
543
+ # with smoothness_p
544
+ n = 124
545
+ d = fi(n)
546
+ m = fi(d, visual=True)
547
+ t = sm(n)
548
+ s = sm(t)
549
+ for th in [d, s, t, n, m]:
550
+ assert sm(th, visual=True) == s
551
+ assert sm(th, visual=1) == s
552
+ for th in [d, s, t, n, m]:
553
+ assert sm(th, visual=False) == t
554
+ assert [sm(th, visual=None) for th in [d, s, t, n, m]] == [s, d, s, t, t]
555
+ assert [sm(th, visual=2) for th in [d, s, t, n, m]] == [s, d, s, t, t]
556
+
557
+ # with factorint
558
+ for th in [d, m, n]:
559
+ assert fi(th, visual=True) == m
560
+ assert fi(th, visual=1) == m
561
+ for th in [d, m, n]:
562
+ assert fi(th, visual=False) == d
563
+ assert [fi(th, visual=None) for th in [d, m, n]] == [m, d, d]
564
+ assert [fi(th, visual=0) for th in [d, m, n]] == [m, d, d]
565
+
566
+ # test reevaluation
567
+ no = {"evaluate": False}
568
+ assert sm({4: 2}, visual=False) == sm(16)
569
+ assert sm(Mul(*[Pow(k, v, **no) for k, v in {4: 2, 2: 6}.items()], **no),
570
+ visual=False) == sm(2**10)
571
+
572
+ assert fi({4: 2}, visual=False) == fi(16)
573
+ assert fi(Mul(*[Pow(k, v, **no) for k, v in {4: 2, 2: 6}.items()], **no),
574
+ visual=False) == fi(2**10)
575
+
576
+
577
+ def test_core():
578
+ assert core(35**13, 10) == 42875
579
+ assert core(210**2) == 1
580
+ assert core(7776, 3) == 36
581
+ assert core(10**27, 22) == 10**5
582
+ assert core(537824) == 14
583
+ assert core(1, 6) == 1
584
+
585
+
586
+ def test__divisor_sigma():
587
+ assert _divisor_sigma(23450) == 50592
588
+ assert _divisor_sigma(23450, 0) == 24
589
+ assert _divisor_sigma(23450, 1) == 50592
590
+ assert _divisor_sigma(23450, 2) == 730747500
591
+ assert _divisor_sigma(23450, 3) == 14666785333344
592
+ A000005 = [1, 2, 2, 3, 2, 4, 2, 4, 3, 4, 2, 6, 2, 4, 4, 5, 2, 6, 2, 6, 4,
593
+ 4, 2, 8, 3, 4, 4, 6, 2, 8, 2, 6, 4, 4, 4, 9, 2, 4, 4, 8, 2, 8]
594
+ for n, val in enumerate(A000005, 1):
595
+ assert _divisor_sigma(n, 0) == val
596
+ A000203 = [1, 3, 4, 7, 6, 12, 8, 15, 13, 18, 12, 28, 14, 24, 24, 31, 18,
597
+ 39, 20, 42, 32, 36, 24, 60, 31, 42, 40, 56, 30, 72, 32, 63, 48]
598
+ for n, val in enumerate(A000203, 1):
599
+ assert _divisor_sigma(n, 1) == val
600
+ A001157 = [1, 5, 10, 21, 26, 50, 50, 85, 91, 130, 122, 210, 170, 250, 260,
601
+ 341, 290, 455, 362, 546, 500, 610, 530, 850, 651, 850, 820, 1050]
602
+ for n, val in enumerate(A001157, 1):
603
+ assert _divisor_sigma(n, 2) == val
604
+
605
+
606
+ def test_mersenne_prime_exponent():
607
+ assert mersenne_prime_exponent(1) == 2
608
+ assert mersenne_prime_exponent(4) == 7
609
+ assert mersenne_prime_exponent(10) == 89
610
+ assert mersenne_prime_exponent(25) == 21701
611
+ raises(ValueError, lambda: mersenne_prime_exponent(52))
612
+ raises(ValueError, lambda: mersenne_prime_exponent(0))
613
+
614
+
615
+ def test_is_perfect():
616
+ assert is_perfect(-6) is False
617
+ assert is_perfect(6) is True
618
+ assert is_perfect(15) is False
619
+ assert is_perfect(28) is True
620
+ assert is_perfect(400) is False
621
+ assert is_perfect(496) is True
622
+ assert is_perfect(8128) is True
623
+ assert is_perfect(10000) is False
624
+
625
+
626
+ def test_is_abundant():
627
+ assert is_abundant(10) is False
628
+ assert is_abundant(12) is True
629
+ assert is_abundant(18) is True
630
+ assert is_abundant(21) is False
631
+ assert is_abundant(945) is True
632
+
633
+
634
+ def test_is_deficient():
635
+ assert is_deficient(10) is True
636
+ assert is_deficient(22) is True
637
+ assert is_deficient(56) is False
638
+ assert is_deficient(20) is False
639
+ assert is_deficient(36) is False
640
+
641
+
642
+ def test_is_amicable():
643
+ assert is_amicable(173, 129) is False
644
+ assert is_amicable(220, 284) is True
645
+ assert is_amicable(8756, 8756) is False
646
+
647
+
648
+ def test_is_carmichael():
649
+ A002997 = [561, 1105, 1729, 2465, 2821, 6601, 8911, 10585, 15841,
650
+ 29341, 41041, 46657, 52633, 62745, 63973, 75361, 101101]
651
+ for n in range(1, 5000):
652
+ assert is_carmichael(n) == (n in A002997)
653
+ for n in A002997:
654
+ assert is_carmichael(n)
655
+
656
+
657
+ def test_find_carmichael_numbers_in_range():
658
+ assert find_carmichael_numbers_in_range(0, 561) == []
659
+ assert find_carmichael_numbers_in_range(561, 562) == [561]
660
+ assert find_carmichael_numbers_in_range(561, 1105) == find_carmichael_numbers_in_range(561, 562)
661
+ raises(ValueError, lambda: find_carmichael_numbers_in_range(-2, 2))
662
+ raises(ValueError, lambda: find_carmichael_numbers_in_range(22, 2))
663
+
664
+
665
+ def test_find_first_n_carmichaels():
666
+ assert find_first_n_carmichaels(0) == []
667
+ assert find_first_n_carmichaels(1) == [561]
668
+ assert find_first_n_carmichaels(2) == [561, 1105]
669
+
670
+
671
+ def test_dra():
672
+ assert dra(19, 12) == 8
673
+ assert dra(2718, 10) == 9
674
+ assert dra(0, 22) == 0
675
+ assert dra(23456789, 10) == 8
676
+ raises(ValueError, lambda: dra(24, -2))
677
+ raises(ValueError, lambda: dra(24.2, 5))
678
+
679
+ def test_drm():
680
+ assert drm(19, 12) == 7
681
+ assert drm(2718, 10) == 2
682
+ assert drm(0, 15) == 0
683
+ assert drm(234161, 10) == 6
684
+ raises(ValueError, lambda: drm(24, -2))
685
+ raises(ValueError, lambda: drm(11.6, 9))
686
+
687
+
688
+ def test_deprecated_ntheory_symbolic_functions():
689
+ from sympy.testing.pytest import warns_deprecated_sympy
690
+
691
+ with warns_deprecated_sympy():
692
+ assert primenu(3) == 1
693
+ with warns_deprecated_sympy():
694
+ assert primeomega(3) == 1
695
+ with warns_deprecated_sympy():
696
+ assert totient(3) == 2
697
+ with warns_deprecated_sympy():
698
+ assert reduced_totient(3) == 2
699
+ with warns_deprecated_sympy():
700
+ assert divisor_sigma(3) == 4
701
+ with warns_deprecated_sympy():
702
+ assert udivisor_sigma(3) == 4
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_generate.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bisect import bisect, bisect_left
2
+
3
+ from sympy.functions.combinatorial.numbers import mobius, totient
4
+ from sympy.ntheory.generate import (sieve, Sieve)
5
+
6
+ from sympy.ntheory import isprime, randprime, nextprime, prevprime, \
7
+ primerange, primepi, prime, primorial, composite, compositepi
8
+ from sympy.ntheory.generate import cycle_length, _primepi
9
+ from sympy.ntheory.primetest import mr
10
+ from sympy.testing.pytest import raises
11
+
12
+ def test_prime():
13
+ assert prime(1) == 2
14
+ assert prime(2) == 3
15
+ assert prime(5) == 11
16
+ assert prime(11) == 31
17
+ assert prime(57) == 269
18
+ assert prime(296) == 1949
19
+ assert prime(559) == 4051
20
+ assert prime(3000) == 27449
21
+ assert prime(4096) == 38873
22
+ assert prime(9096) == 94321
23
+ assert prime(25023) == 287341
24
+ assert prime(10000000) == 179424673 # issue #20951
25
+ assert prime(99999999) == 2038074739
26
+ raises(ValueError, lambda: prime(0))
27
+ sieve.extend(3000)
28
+ assert prime(401) == 2749
29
+ raises(ValueError, lambda: prime(-1))
30
+
31
+
32
+ def test__primepi():
33
+ assert _primepi(-1) == 0
34
+ assert _primepi(1) == 0
35
+ assert _primepi(2) == 1
36
+ assert _primepi(5) == 3
37
+ assert _primepi(11) == 5
38
+ assert _primepi(57) == 16
39
+ assert _primepi(296) == 62
40
+ assert _primepi(559) == 102
41
+ assert _primepi(3000) == 430
42
+ assert _primepi(4096) == 564
43
+ assert _primepi(9096) == 1128
44
+ assert _primepi(25023) == 2763
45
+ assert _primepi(10**8) == 5761455
46
+ assert _primepi(253425253) == 13856396
47
+ assert _primepi(8769575643) == 401464322
48
+ sieve.extend(3000)
49
+ assert _primepi(2000) == 303
50
+
51
+
52
+ def test_composite():
53
+ from sympy.ntheory.generate import sieve
54
+ sieve._reset()
55
+ assert composite(1) == 4
56
+ assert composite(2) == 6
57
+ assert composite(5) == 10
58
+ assert composite(11) == 20
59
+ assert composite(41) == 58
60
+ assert composite(57) == 80
61
+ assert composite(296) == 370
62
+ assert composite(559) == 684
63
+ assert composite(3000) == 3488
64
+ assert composite(4096) == 4736
65
+ assert composite(9096) == 10368
66
+ assert composite(25023) == 28088
67
+ sieve.extend(3000)
68
+ assert composite(1957) == 2300
69
+ assert composite(2568) == 2998
70
+ raises(ValueError, lambda: composite(0))
71
+
72
+
73
+ def test_compositepi():
74
+ assert compositepi(1) == 0
75
+ assert compositepi(2) == 0
76
+ assert compositepi(5) == 1
77
+ assert compositepi(11) == 5
78
+ assert compositepi(57) == 40
79
+ assert compositepi(296) == 233
80
+ assert compositepi(559) == 456
81
+ assert compositepi(3000) == 2569
82
+ assert compositepi(4096) == 3531
83
+ assert compositepi(9096) == 7967
84
+ assert compositepi(25023) == 22259
85
+ assert compositepi(10**8) == 94238544
86
+ assert compositepi(253425253) == 239568856
87
+ assert compositepi(8769575643) == 8368111320
88
+ sieve.extend(3000)
89
+ assert compositepi(2321) == 1976
90
+
91
+
92
+ def test_generate():
93
+ from sympy.ntheory.generate import sieve
94
+ sieve._reset()
95
+ assert nextprime(-4) == 2
96
+ assert nextprime(2) == 3
97
+ assert nextprime(5) == 7
98
+ assert nextprime(12) == 13
99
+ assert prevprime(3) == 2
100
+ assert prevprime(7) == 5
101
+ assert prevprime(13) == 11
102
+ assert prevprime(19) == 17
103
+ assert prevprime(20) == 19
104
+
105
+ sieve.extend_to_no(9)
106
+ assert sieve._list[-1] == 23
107
+
108
+ assert sieve._list[-1] < 31
109
+ assert 31 in sieve
110
+
111
+ assert nextprime(90) == 97
112
+ assert nextprime(10**40) == (10**40 + 121)
113
+ primelist = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31,
114
+ 37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
115
+ 79, 83, 89, 97, 101, 103, 107, 109, 113,
116
+ 127, 131, 137, 139, 149, 151, 157, 163,
117
+ 167, 173, 179, 181, 191, 193, 197, 199,
118
+ 211, 223, 227, 229, 233, 239, 241, 251,
119
+ 257, 263, 269, 271, 277, 281, 283, 293]
120
+ for i in range(len(primelist) - 2):
121
+ for j in range(2, len(primelist) - i):
122
+ assert nextprime(primelist[i], j) == primelist[i + j]
123
+ if 3 < i:
124
+ assert nextprime(primelist[i] - 1, j) == primelist[i + j - 1]
125
+ raises(ValueError, lambda: nextprime(2, 0))
126
+ raises(ValueError, lambda: nextprime(2, -1))
127
+ assert prevprime(97) == 89
128
+ assert prevprime(10**40) == (10**40 - 17)
129
+
130
+ raises(ValueError, lambda: Sieve(0))
131
+ raises(ValueError, lambda: Sieve(-1))
132
+ for sieve_interval in [1, 10, 11, 1_000_000]:
133
+ s = Sieve(sieve_interval=sieve_interval)
134
+ for head in range(s._list[-1] + 1, (s._list[-1] + 1)**2, 2):
135
+ for tail in range(head + 1, (s._list[-1] + 1)**2):
136
+ A = list(s._primerange(head, tail))
137
+ B = primelist[bisect(primelist, head):bisect_left(primelist, tail)]
138
+ assert A == B
139
+ for k in range(s._list[-1], primelist[-1] - 1, 2):
140
+ s = Sieve(sieve_interval=sieve_interval)
141
+ s.extend(k)
142
+ assert list(s._list) == primelist[:bisect(primelist, k)]
143
+ s.extend(primelist[-1])
144
+ assert list(s._list) == primelist
145
+
146
+ assert list(sieve.primerange(10, 1)) == []
147
+ assert list(sieve.primerange(5, 9)) == [5, 7]
148
+ sieve._reset(prime=True)
149
+ assert list(sieve.primerange(2, 13)) == [2, 3, 5, 7, 11]
150
+ assert list(sieve.primerange(13)) == [2, 3, 5, 7, 11]
151
+ assert list(sieve.primerange(8)) == [2, 3, 5, 7]
152
+ assert list(sieve.primerange(-2)) == []
153
+ assert list(sieve.primerange(29)) == [2, 3, 5, 7, 11, 13, 17, 19, 23]
154
+ assert list(sieve.primerange(34)) == [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
155
+
156
+ assert list(sieve.totientrange(5, 15)) == [4, 2, 6, 4, 6, 4, 10, 4, 12, 6]
157
+ sieve._reset(totient=True)
158
+ assert list(sieve.totientrange(3, 13)) == [2, 2, 4, 2, 6, 4, 6, 4, 10, 4]
159
+ assert list(sieve.totientrange(900, 1000)) == [totient(x) for x in range(900, 1000)]
160
+ assert list(sieve.totientrange(0, 1)) == []
161
+ assert list(sieve.totientrange(1, 2)) == [1]
162
+
163
+ assert list(sieve.mobiusrange(5, 15)) == [-1, 1, -1, 0, 0, 1, -1, 0, -1, 1]
164
+ sieve._reset(mobius=True)
165
+ assert list(sieve.mobiusrange(3, 13)) == [-1, 0, -1, 1, -1, 0, 0, 1, -1, 0]
166
+ assert list(sieve.mobiusrange(1050, 1100)) == [mobius(x) for x in range(1050, 1100)]
167
+ assert list(sieve.mobiusrange(0, 1)) == []
168
+ assert list(sieve.mobiusrange(1, 2)) == [1]
169
+
170
+ assert list(primerange(10, 1)) == []
171
+ assert list(primerange(2, 7)) == [2, 3, 5]
172
+ assert list(primerange(2, 10)) == [2, 3, 5, 7]
173
+ assert list(primerange(1050, 1100)) == [1051, 1061,
174
+ 1063, 1069, 1087, 1091, 1093, 1097]
175
+ s = Sieve()
176
+ for i in range(30, 2350, 376):
177
+ for j in range(2, 5096, 1139):
178
+ A = list(s.primerange(i, i + j))
179
+ B = list(primerange(i, i + j))
180
+ assert A == B
181
+ s = Sieve()
182
+ sieve._reset(prime=True)
183
+ sieve.extend(13)
184
+ for i in range(200):
185
+ for j in range(i, 200):
186
+ A = list(s.primerange(i, j))
187
+ B = list(primerange(i, j))
188
+ assert A == B
189
+ sieve.extend(1000)
190
+ for a, b in [(901, 1103), # a < 1000 < b < 1000**2
191
+ (806, 1002007), # a < 1000 < 1000**2 < b
192
+ (2000, 30001), # 1000 < a < b < 1000**2
193
+ (100005, 1010001), # 1000 < a < 1000**2 < b
194
+ (1003003, 1005000), # 1000**2 < a < b
195
+ ]:
196
+ assert list(primerange(a, b)) == list(s.primerange(a, b))
197
+ sieve._reset(prime=True)
198
+ sieve.extend(100000)
199
+ assert len(sieve._list) == len(set(sieve._list))
200
+ s = Sieve()
201
+ assert s[10] == 29
202
+
203
+ assert nextprime(2, 2) == 5
204
+
205
+ raises(ValueError, lambda: totient(0))
206
+
207
+ raises(ValueError, lambda: primorial(0))
208
+
209
+ assert mr(1, [2]) is False
210
+
211
+ func = lambda i: (i**2 + 1) % 51
212
+ assert next(cycle_length(func, 4)) == (6, 3)
213
+ assert list(cycle_length(func, 4, values=True)) == \
214
+ [4, 17, 35, 2, 5, 26, 14, 44, 50, 2, 5, 26, 14]
215
+ assert next(cycle_length(func, 4, nmax=5)) == (5, None)
216
+ assert list(cycle_length(func, 4, nmax=5, values=True)) == \
217
+ [4, 17, 35, 2, 5]
218
+ sieve.extend(3000)
219
+ assert nextprime(2968) == 2969
220
+ assert prevprime(2930) == 2927
221
+ raises(ValueError, lambda: prevprime(1))
222
+ raises(ValueError, lambda: prevprime(-4))
223
+
224
+
225
+ def test_randprime():
226
+ assert randprime(10, 1) is None
227
+ assert randprime(3, -3) is None
228
+ assert randprime(2, 3) == 2
229
+ assert randprime(1, 3) == 2
230
+ assert randprime(3, 5) == 3
231
+ raises(ValueError, lambda: randprime(-12, -2))
232
+ raises(ValueError, lambda: randprime(-10, 0))
233
+ raises(ValueError, lambda: randprime(20, 22))
234
+ raises(ValueError, lambda: randprime(0, 2))
235
+ raises(ValueError, lambda: randprime(1, 2))
236
+ for a in [100, 300, 500, 250000]:
237
+ for b in [100, 300, 500, 250000]:
238
+ p = randprime(a, a + b)
239
+ assert a <= p < (a + b) and isprime(p)
240
+
241
+
242
+ def test_primorial():
243
+ assert primorial(1) == 2
244
+ assert primorial(1, nth=0) == 1
245
+ assert primorial(2) == 6
246
+ assert primorial(2, nth=0) == 2
247
+ assert primorial(4, nth=0) == 6
248
+
249
+
250
+ def test_search():
251
+ assert 2 in sieve
252
+ assert 2.1 not in sieve
253
+ assert 1 not in sieve
254
+ assert 2**1000 not in sieve
255
+ raises(ValueError, lambda: sieve.search(1))
256
+
257
+
258
+ def test_sieve_slice():
259
+ assert sieve[5] == 11
260
+ assert list(sieve[5:10]) == [sieve[x] for x in range(5, 10)]
261
+ assert list(sieve[5:10:2]) == [sieve[x] for x in range(5, 10, 2)]
262
+ assert list(sieve[1:5]) == [2, 3, 5, 7]
263
+ raises(IndexError, lambda: sieve[:5])
264
+ raises(IndexError, lambda: sieve[0])
265
+ raises(IndexError, lambda: sieve[0:5])
266
+
267
+ def test_sieve_iter():
268
+ values = []
269
+ for value in sieve:
270
+ if value > 7:
271
+ break
272
+ values.append(value)
273
+ assert values == list(sieve[1:5])
274
+
275
+
276
+ def test_sieve_repr():
277
+ assert "sieve" in repr(sieve)
278
+ assert "prime" in repr(sieve)
279
+
280
+
281
+ def test_deprecated_ntheory_symbolic_functions():
282
+ from sympy.testing.pytest import warns_deprecated_sympy
283
+
284
+ with warns_deprecated_sympy():
285
+ assert primepi(0) == 0
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_hypothesis.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hypothesis import given
2
+ from hypothesis import strategies as st
3
+ from sympy import divisors
4
+ from sympy.functions.combinatorial.numbers import divisor_sigma, totient
5
+ from sympy.ntheory.primetest import is_square
6
+
7
+
8
+ @given(n=st.integers(1, 10**10))
9
+ def test_tau_hypothesis(n):
10
+ div = divisors(n)
11
+ tau_n = len(div)
12
+ assert is_square(n) == (tau_n % 2 == 1)
13
+ sigmas = [divisor_sigma(i) for i in div]
14
+ totients = [totient(n // i) for i in div]
15
+ mul = [a * b for a, b in zip(sigmas, totients)]
16
+ assert n * tau_n == sum(mul)
17
+
18
+
19
+ @given(n=st.integers(1, 10**10))
20
+ def test_totient_hypothesis(n):
21
+ assert totient(n) <= n
22
+ div = divisors(n)
23
+ totients = [totient(i) for i in div]
24
+ assert n == sum(totients)
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_modular.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.ntheory.modular import crt, crt1, crt2, solve_congruence
2
+ from sympy.testing.pytest import raises
3
+
4
+
5
+ def test_crt():
6
+ def mcrt(m, v, r, symmetric=False):
7
+ assert crt(m, v, symmetric)[0] == r
8
+ mm, e, s = crt1(m)
9
+ assert crt2(m, v, mm, e, s, symmetric) == (r, mm)
10
+
11
+ mcrt([2, 3, 5], [0, 0, 0], 0)
12
+ mcrt([2, 3, 5], [1, 1, 1], 1)
13
+
14
+ mcrt([2, 3, 5], [-1, -1, -1], -1, True)
15
+ mcrt([2, 3, 5], [-1, -1, -1], 2*3*5 - 1, False)
16
+
17
+ assert crt([656, 350], [811, 133], symmetric=True) == (-56917, 114800)
18
+
19
+
20
+ def test_modular():
21
+ assert solve_congruence(*list(zip([3, 4, 2], [12, 35, 17]))) == (1719, 7140)
22
+ assert solve_congruence(*list(zip([3, 4, 2], [12, 6, 17]))) is None
23
+ assert solve_congruence(*list(zip([3, 4, 2], [13, 7, 17]))) == (172, 1547)
24
+ assert solve_congruence(*list(zip([-10, -3, -15], [13, 7, 17]))) == (172, 1547)
25
+ assert solve_congruence(*list(zip([-10, -3, 1, -15], [13, 7, 7, 17]))) is None
26
+ assert solve_congruence(
27
+ *list(zip([-10, -5, 2, -15], [13, 7, 7, 17]))) == (835, 1547)
28
+ assert solve_congruence(
29
+ *list(zip([-10, -5, 2, -15], [13, 7, 14, 17]))) == (2382, 3094)
30
+ assert solve_congruence(
31
+ *list(zip([-10, 2, 2, -15], [13, 7, 14, 17]))) == (2382, 3094)
32
+ assert solve_congruence(*list(zip((1, 1, 2), (3, 2, 4)))) is None
33
+ raises(
34
+ ValueError, lambda: solve_congruence(*list(zip([3, 4, 2], [12.1, 35, 17]))))
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_multinomial.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.ntheory.multinomial import (binomial_coefficients, binomial_coefficients_list, multinomial_coefficients)
2
+ from sympy.ntheory.multinomial import multinomial_coefficients_iterator
3
+
4
+
5
+ def test_binomial_coefficients_list():
6
+ assert binomial_coefficients_list(0) == [1]
7
+ assert binomial_coefficients_list(1) == [1, 1]
8
+ assert binomial_coefficients_list(2) == [1, 2, 1]
9
+ assert binomial_coefficients_list(3) == [1, 3, 3, 1]
10
+ assert binomial_coefficients_list(4) == [1, 4, 6, 4, 1]
11
+ assert binomial_coefficients_list(5) == [1, 5, 10, 10, 5, 1]
12
+ assert binomial_coefficients_list(6) == [1, 6, 15, 20, 15, 6, 1]
13
+
14
+
15
+ def test_binomial_coefficients():
16
+ for n in range(15):
17
+ c = binomial_coefficients(n)
18
+ l = [c[k] for k in sorted(c)]
19
+ assert l == binomial_coefficients_list(n)
20
+
21
+
22
+ def test_multinomial_coefficients():
23
+ assert multinomial_coefficients(1, 1) == {(1,): 1}
24
+ assert multinomial_coefficients(1, 2) == {(2,): 1}
25
+ assert multinomial_coefficients(1, 3) == {(3,): 1}
26
+ assert multinomial_coefficients(2, 0) == {(0, 0): 1}
27
+ assert multinomial_coefficients(2, 1) == {(0, 1): 1, (1, 0): 1}
28
+ assert multinomial_coefficients(2, 2) == {(2, 0): 1, (0, 2): 1, (1, 1): 2}
29
+ assert multinomial_coefficients(2, 3) == {(3, 0): 1, (1, 2): 3, (0, 3): 1,
30
+ (2, 1): 3}
31
+ assert multinomial_coefficients(3, 1) == {(1, 0, 0): 1, (0, 1, 0): 1,
32
+ (0, 0, 1): 1}
33
+ assert multinomial_coefficients(3, 2) == {(0, 1, 1): 2, (0, 0, 2): 1,
34
+ (1, 1, 0): 2, (0, 2, 0): 1, (1, 0, 1): 2, (2, 0, 0): 1}
35
+ mc = multinomial_coefficients(3, 3)
36
+ assert mc == {(2, 1, 0): 3, (0, 3, 0): 1,
37
+ (1, 0, 2): 3, (0, 2, 1): 3, (0, 1, 2): 3, (3, 0, 0): 1,
38
+ (2, 0, 1): 3, (1, 2, 0): 3, (1, 1, 1): 6, (0, 0, 3): 1}
39
+ assert dict(multinomial_coefficients_iterator(2, 0)) == {(0, 0): 1}
40
+ assert dict(
41
+ multinomial_coefficients_iterator(2, 1)) == {(0, 1): 1, (1, 0): 1}
42
+ assert dict(multinomial_coefficients_iterator(2, 2)) == \
43
+ {(2, 0): 1, (0, 2): 1, (1, 1): 2}
44
+ assert dict(multinomial_coefficients_iterator(3, 3)) == mc
45
+ it = multinomial_coefficients_iterator(7, 2)
46
+ assert [next(it) for i in range(4)] == \
47
+ [((2, 0, 0, 0, 0, 0, 0), 1), ((1, 1, 0, 0, 0, 0, 0), 2),
48
+ ((0, 2, 0, 0, 0, 0, 0), 1), ((1, 0, 1, 0, 0, 0, 0), 2)]
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_partitions.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.ntheory.partitions_ import npartitions, _partition_rec, _partition
2
+
3
+
4
+ def test__partition_rec():
5
+ A000041 = [1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135,
6
+ 176, 231, 297, 385, 490, 627, 792, 1002, 1255, 1575]
7
+ for n, val in enumerate(A000041):
8
+ assert _partition_rec(n) == val
9
+
10
+
11
+ def test__partition():
12
+ assert [_partition(k) for k in range(13)] == \
13
+ [1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77]
14
+ assert _partition(100) == 190569292
15
+ assert _partition(200) == 3972999029388
16
+ assert _partition(1000) == 24061467864032622473692149727991
17
+ assert _partition(1001) == 25032297938763929621013218349796
18
+ assert _partition(2000) == 4720819175619413888601432406799959512200344166
19
+ assert _partition(10000) % 10**10 == 6916435144
20
+ assert _partition(100000) % 10**10 == 9421098519
21
+ assert _partition(10000000) % 10**10 == 7677288980
22
+
23
+
24
+ def test_deprecated_ntheory_symbolic_functions():
25
+ from sympy.testing.pytest import warns_deprecated_sympy
26
+
27
+ with warns_deprecated_sympy():
28
+ assert npartitions(0) == 1
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_primetest.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import gcd
2
+
3
+ from sympy.ntheory.generate import Sieve, sieve
4
+ from sympy.ntheory.primetest import (mr, _lucas_extrastrong_params, is_lucas_prp, is_square,
5
+ is_strong_lucas_prp, is_extra_strong_lucas_prp,
6
+ proth_test, isprime, is_euler_pseudoprime,
7
+ is_gaussian_prime, is_fermat_pseudoprime, is_euler_jacobi_pseudoprime,
8
+ MERSENNE_PRIME_EXPONENTS, _lucas_lehmer_primality_test,
9
+ is_mersenne_prime)
10
+
11
+ from sympy.testing.pytest import slow, raises
12
+ from sympy.core.numbers import I, Float
13
+
14
+
15
+ def test_is_fermat_pseudoprime():
16
+ assert is_fermat_pseudoprime(5, 1)
17
+ assert is_fermat_pseudoprime(9, 1)
18
+
19
+
20
+ def test_euler_pseudoprimes():
21
+ assert is_euler_pseudoprime(13, 1)
22
+ assert is_euler_pseudoprime(15, 1)
23
+ assert is_euler_pseudoprime(17, 6)
24
+ assert is_euler_pseudoprime(101, 7)
25
+ assert is_euler_pseudoprime(1009, 10)
26
+ assert is_euler_pseudoprime(11287, 41)
27
+
28
+ raises(ValueError, lambda: is_euler_pseudoprime(0, 4))
29
+ raises(ValueError, lambda: is_euler_pseudoprime(3, 0))
30
+ raises(ValueError, lambda: is_euler_pseudoprime(15, 6))
31
+
32
+ # A006970
33
+ euler_prp = [341, 561, 1105, 1729, 1905, 2047, 2465, 3277,
34
+ 4033, 4681, 5461, 6601, 8321, 8481, 10261, 10585]
35
+ for p in euler_prp:
36
+ assert is_euler_pseudoprime(p, 2)
37
+
38
+ # A048950
39
+ euler_prp = [121, 703, 1729, 1891, 2821, 3281, 7381, 8401, 8911, 10585,
40
+ 12403, 15457, 15841, 16531, 18721, 19345, 23521, 24661, 28009]
41
+ for p in euler_prp:
42
+ assert is_euler_pseudoprime(p, 3)
43
+
44
+ # A033181
45
+ absolute_euler_prp = [1729, 2465, 15841, 41041, 46657, 75361,
46
+ 162401, 172081, 399001, 449065, 488881]
47
+ for p in absolute_euler_prp:
48
+ for a in range(2, p):
49
+ if gcd(a, p) != 1:
50
+ continue
51
+ assert is_euler_pseudoprime(p, a)
52
+
53
+
54
+ def test_is_euler_jacobi_pseudoprime():
55
+ assert is_euler_jacobi_pseudoprime(11, 1)
56
+ assert is_euler_jacobi_pseudoprime(15, 1)
57
+
58
+
59
+ def test_lucas_extrastrong_params():
60
+ assert _lucas_extrastrong_params(3) == (5, 3, 1)
61
+ assert _lucas_extrastrong_params(5) == (12, 4, 1)
62
+ assert _lucas_extrastrong_params(7) == (5, 3, 1)
63
+ assert _lucas_extrastrong_params(9) == (0, 0, 0)
64
+ assert _lucas_extrastrong_params(11) == (21, 5, 1)
65
+ assert _lucas_extrastrong_params(59) == (32, 6, 1)
66
+ assert _lucas_extrastrong_params(479) == (117, 11, 1)
67
+
68
+
69
+ def test_is_extra_strong_lucas_prp():
70
+ assert is_extra_strong_lucas_prp(4) == False
71
+ assert is_extra_strong_lucas_prp(989) == True
72
+ assert is_extra_strong_lucas_prp(10877) == True
73
+ assert is_extra_strong_lucas_prp(9) == False
74
+ assert is_extra_strong_lucas_prp(16) == False
75
+ assert is_extra_strong_lucas_prp(169) == False
76
+
77
+ @slow
78
+ def test_prps():
79
+ oddcomposites = [n for n in range(1, 10**5) if
80
+ n % 2 and not isprime(n)]
81
+ # A checksum would be better.
82
+ assert sum(oddcomposites) == 2045603465
83
+ assert [n for n in oddcomposites if mr(n, [2])] == [
84
+ 2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141,
85
+ 52633, 65281, 74665, 80581, 85489, 88357, 90751]
86
+ assert [n for n in oddcomposites if mr(n, [3])] == [
87
+ 121, 703, 1891, 3281, 8401, 8911, 10585, 12403, 16531,
88
+ 18721, 19345, 23521, 31621, 44287, 47197, 55969, 63139,
89
+ 74593, 79003, 82513, 87913, 88573, 97567]
90
+ assert [n for n in oddcomposites if mr(n, [325])] == [
91
+ 9, 25, 27, 49, 65, 81, 325, 341, 343, 697, 1141, 2059,
92
+ 2149, 3097, 3537, 4033, 4681, 4941, 5833, 6517, 7987, 8911,
93
+ 12403, 12913, 15043, 16021, 20017, 22261, 23221, 24649,
94
+ 24929, 31841, 35371, 38503, 43213, 44173, 47197, 50041,
95
+ 55909, 56033, 58969, 59089, 61337, 65441, 68823, 72641,
96
+ 76793, 78409, 85879]
97
+ assert not any(mr(n, [9345883071009581737]) for n in oddcomposites)
98
+ assert [n for n in oddcomposites if is_lucas_prp(n)] == [
99
+ 323, 377, 1159, 1829, 3827, 5459, 5777, 9071, 9179, 10877,
100
+ 11419, 11663, 13919, 14839, 16109, 16211, 18407, 18971,
101
+ 19043, 22499, 23407, 24569, 25199, 25877, 26069, 27323,
102
+ 32759, 34943, 35207, 39059, 39203, 39689, 40309, 44099,
103
+ 46979, 47879, 50183, 51983, 53663, 56279, 58519, 60377,
104
+ 63881, 69509, 72389, 73919, 75077, 77219, 79547, 79799,
105
+ 82983, 84419, 86063, 90287, 94667, 97019, 97439]
106
+ assert [n for n in oddcomposites if is_strong_lucas_prp(n)] == [
107
+ 5459, 5777, 10877, 16109, 18971, 22499, 24569, 25199, 40309,
108
+ 58519, 75077, 97439]
109
+ assert [n for n in oddcomposites if is_extra_strong_lucas_prp(n)
110
+ ] == [
111
+ 989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059,
112
+ 72389, 73919, 75077]
113
+
114
+
115
+ def test_proth_test():
116
+ # Proth number
117
+ A080075 = [3, 5, 9, 13, 17, 25, 33, 41, 49, 57, 65,
118
+ 81, 97, 113, 129, 145, 161, 177, 193]
119
+ # Proth prime
120
+ A080076 = [3, 5, 13, 17, 41, 97, 113, 193]
121
+
122
+ for n in range(200):
123
+ if n in A080075:
124
+ assert proth_test(n) == (n in A080076)
125
+ else:
126
+ raises(ValueError, lambda: proth_test(n))
127
+
128
+
129
+ def test_lucas_lehmer_primality_test():
130
+ for p in sieve.primerange(3, 100):
131
+ assert _lucas_lehmer_primality_test(p) == (p in MERSENNE_PRIME_EXPONENTS)
132
+
133
+
134
+ def test_is_mersenne_prime():
135
+ assert is_mersenne_prime(-3) is False
136
+ assert is_mersenne_prime(3) is True
137
+ assert is_mersenne_prime(10) is False
138
+ assert is_mersenne_prime(127) is True
139
+ assert is_mersenne_prime(511) is False
140
+ assert is_mersenne_prime(131071) is True
141
+ assert is_mersenne_prime(2147483647) is True
142
+
143
+
144
+ def test_isprime():
145
+ s = Sieve()
146
+ s.extend(100000)
147
+ ps = set(s.primerange(2, 100001))
148
+ for n in range(100001):
149
+ # if (n in ps) != isprime(n): print n
150
+ assert (n in ps) == isprime(n)
151
+ assert isprime(179424673)
152
+ assert isprime(20678048681)
153
+ assert isprime(1968188556461)
154
+ assert isprime(2614941710599)
155
+ assert isprime(65635624165761929287)
156
+ assert isprime(1162566711635022452267983)
157
+ assert isprime(77123077103005189615466924501)
158
+ assert isprime(3991617775553178702574451996736229)
159
+ assert isprime(273952953553395851092382714516720001799)
160
+ assert isprime(int('''
161
+ 531137992816767098689588206552468627329593117727031923199444138200403\
162
+ 559860852242739162502265229285668889329486246501015346579337652707239\
163
+ 409519978766587351943831270835393219031728127'''))
164
+
165
+ # Some Mersenne primes
166
+ assert isprime(2**61 - 1)
167
+ assert isprime(2**89 - 1)
168
+ assert isprime(2**607 - 1)
169
+ # (but not all Mersenne's are primes
170
+ assert not isprime(2**601 - 1)
171
+
172
+ # pseudoprimes
173
+ #-------------
174
+ # to some small bases
175
+ assert not isprime(2152302898747)
176
+ assert not isprime(3474749660383)
177
+ assert not isprime(341550071728321)
178
+ assert not isprime(3825123056546413051)
179
+ # passes the base set [2, 3, 7, 61, 24251]
180
+ assert not isprime(9188353522314541)
181
+ # large examples
182
+ assert not isprime(877777777777777777777777)
183
+ # conjectured psi_12 given at http://mathworld.wolfram.com/StrongPseudoprime.html
184
+ assert not isprime(318665857834031151167461)
185
+ # conjectured psi_17 given at http://mathworld.wolfram.com/StrongPseudoprime.html
186
+ assert not isprime(564132928021909221014087501701)
187
+ # Arnault's 1993 number; a factor of it is
188
+ # 400958216639499605418306452084546853005188166041132508774506\
189
+ # 204738003217070119624271622319159721973358216316508535816696\
190
+ # 9145233813917169287527980445796800452592031836601
191
+ assert not isprime(int('''
192
+ 803837457453639491257079614341942108138837688287558145837488917522297\
193
+ 427376533365218650233616396004545791504202360320876656996676098728404\
194
+ 396540823292873879185086916685732826776177102938969773947016708230428\
195
+ 687109997439976544144845341155872450633409279022275296229414984230688\
196
+ 1685404326457534018329786111298960644845216191652872597534901'''))
197
+ # Arnault's 1995 number; can be factored as
198
+ # p1*(313*(p1 - 1) + 1)*(353*(p1 - 1) + 1) where p1 is
199
+ # 296744956686855105501541746429053327307719917998530433509950\
200
+ # 755312768387531717701995942385964281211880336647542183455624\
201
+ # 93168782883
202
+ assert not isprime(int('''
203
+ 288714823805077121267142959713039399197760945927972270092651602419743\
204
+ 230379915273311632898314463922594197780311092934965557841894944174093\
205
+ 380561511397999942154241693397290542371100275104208013496673175515285\
206
+ 922696291677532547504444585610194940420003990443211677661994962953925\
207
+ 045269871932907037356403227370127845389912612030924484149472897688540\
208
+ 6024976768122077071687938121709811322297802059565867'''))
209
+ sieve.extend(3000)
210
+ assert isprime(2819)
211
+ assert not isprime(2931)
212
+ raises(ValueError, lambda: isprime(2.0))
213
+ raises(ValueError, lambda: isprime(Float(2)))
214
+
215
+
216
+ def test_is_square():
217
+ assert [i for i in range(25) if is_square(i)] == [0, 1, 4, 9, 16]
218
+
219
+ # issue #17044
220
+ assert not is_square(60 ** 3)
221
+ assert not is_square(60 ** 5)
222
+ assert not is_square(84 ** 7)
223
+ assert not is_square(105 ** 9)
224
+ assert not is_square(120 ** 3)
225
+
226
+ def test_is_gaussianprime():
227
+ assert is_gaussian_prime(7*I)
228
+ assert is_gaussian_prime(7)
229
+ assert is_gaussian_prime(2 + 3*I)
230
+ assert not is_gaussian_prime(2 + 2*I)
231
+
232
+
233
+ def test_issue_27145():
234
+ #https://github.com/sympy/sympy/issues/27145
235
+ assert [mr(i,[2,3,5,7]) for i in (1, 2, 6)] == [False, True, False]
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_qs.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from sympy.core.random import _randint
5
+ from sympy.ntheory import qs, qs_factor
6
+ from sympy.ntheory.qs import SievePolynomial, _generate_factor_base, \
7
+ _generate_polynomial, \
8
+ _gen_sieve_array, _check_smoothness, _trial_division_stage, _find_factor
9
+ from sympy.testing.pytest import slow
10
+
11
+
12
+ @slow
13
+ def test_qs_1():
14
+ assert qs(10009202107, 100, 10000) == {100043, 100049}
15
+ assert qs(211107295182713951054568361, 1000, 10000) == \
16
+ {13791315212531, 15307263442931}
17
+ assert qs(980835832582657*990377764891511, 2000, 10000) == \
18
+ {980835832582657, 990377764891511}
19
+ assert qs(18640889198609*20991129234731, 1000, 50000) == \
20
+ {18640889198609, 20991129234731}
21
+
22
+
23
+ def test_qs_2() -> None:
24
+ n = 10009202107
25
+ M = 50
26
+ sieve_poly = SievePolynomial(10, 80, n)
27
+ assert sieve_poly.eval_v(10) == sieve_poly.eval_u(10)**2 - n == -10009169707
28
+ assert sieve_poly.eval_v(5) == sieve_poly.eval_u(5)**2 - n == -10009185207
29
+
30
+ idx_1000, idx_5000, factor_base = _generate_factor_base(2000, n)
31
+ assert idx_1000 == 82
32
+ assert [factor_base[i].prime for i in range(15)] == \
33
+ [2, 3, 7, 11, 17, 19, 29, 31, 43, 59, 61, 67, 71, 73, 79]
34
+ assert [factor_base[i].tmem_p for i in range(15)] == \
35
+ [1, 1, 3, 5, 3, 6, 6, 14, 1, 16, 24, 22, 18, 22, 15]
36
+ assert [factor_base[i].log_p for i in range(5)] == \
37
+ [710, 1125, 1993, 2455, 2901]
38
+
39
+ it = _generate_polynomial(
40
+ n, M, factor_base, idx_1000, idx_5000, _randint(0))
41
+ g = next(it)
42
+ assert g.a == 1133107
43
+ assert g.b == 682543
44
+ assert [factor_base[i].soln1 for i in range(15)] == \
45
+ [0, 0, 3, 7, 13, 0, 8, 19, 9, 43, 27, 25, 63, 29, 19]
46
+ assert [factor_base[i].soln2 for i in range(15)] == \
47
+ [0, 1, 1, 3, 12, 16, 15, 6, 15, 1, 56, 55, 61, 58, 16]
48
+ assert [factor_base[i].b_ainv for i in range(5)] == \
49
+ [[0, 0], [0, 2], [3, 0], [3, 9], [13, 13]]
50
+
51
+ g_1 = next(it)
52
+ assert g_1.a == 1133107
53
+ assert g_1.b == 136765
54
+
55
+ sieve_array = _gen_sieve_array(M, factor_base)
56
+ assert sieve_array[0:5] == [8424, 13603, 1835, 5335, 710]
57
+
58
+ assert _check_smoothness(9645, factor_base) == (36028797018963972, 5)
59
+ assert _check_smoothness(210313, factor_base) == (20992, 1)
60
+
61
+ partial_relations: dict[int, tuple[int, int]] = {}
62
+ smooth_relation, proper_factor = _trial_division_stage(
63
+ n, M, factor_base, sieve_array, sieve_poly, partial_relations,
64
+ ERROR_TERM=25*2**10)
65
+
66
+ assert partial_relations == {
67
+ 8699: (440, -10009008507, 75557863761098695507973),
68
+ 166741: (490, -10008962007, 524341),
69
+ 131449: (530, -10008921207, 664613997892457936451903530140172325),
70
+ 6653: (550, -10008899607, 19342813113834066795307021)
71
+ }
72
+ assert [smooth_relation[i][0] for i in range(5)] == [
73
+ -250, 1064469, 72819, 231957, 44167]
74
+ assert [smooth_relation[i][1] for i in range(5)] == [
75
+ -10009139607, 1133094251961, 5302606761, 53804049849, 1950723889]
76
+ assert smooth_relation[0][2] == 89213869829863962596973701078031812362502145
77
+ assert proper_factor == set()
78
+
79
+
80
+ def test_qs_3():
81
+ N = 1817
82
+ smooth_relations = [
83
+ (2455024, 637, 8),
84
+ (-27993000, 81536, 10),
85
+ (11461840, 12544, 0),
86
+ (149, 20384, 10),
87
+ (-31138074, 19208, 2)
88
+ ]
89
+ assert next(_find_factor(N, smooth_relations, 4)) == 23
90
+
91
+
92
+ def test_qs_4():
93
+ N = 10007**2 * 10009 * 10037**3 * 10039
94
+ for factor in qs(N, 1000, 2000):
95
+ assert N % factor == 0
96
+ N //= factor
97
+
98
+
99
+ def test_qs_factor():
100
+ assert qs_factor(1009 * 100003, 2000, 10000) == {1009: 1, 100003: 1}
101
+ n = 1009**2 * 2003**2*30011*400009
102
+ factors = qs_factor(n, 2000, 10000)
103
+ assert len(factors) > 1
104
+ assert math.prod(p**e for p, e in factors.items()) == n
105
+
106
+
107
+ def test_issue_27616():
108
+ #https://github.com/sympy/sympy/issues/27616
109
+ N = 9804659461513846513 + 1
110
+ assert qs(N, 5000, 20000) is not None
.venv/lib/python3.13/site-packages/sympy/ntheory/tests/test_residue.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from sympy.core.containers import Tuple
3
+ from sympy.core.singleton import S
4
+ from sympy.core.symbol import (Dummy, Symbol)
5
+ from sympy.functions.combinatorial.numbers import totient
6
+ from sympy.ntheory import n_order, is_primitive_root, is_quad_residue, \
7
+ legendre_symbol, jacobi_symbol, primerange, sqrt_mod, \
8
+ primitive_root, quadratic_residues, is_nthpow_residue, nthroot_mod, \
9
+ sqrt_mod_iter, mobius, discrete_log, quadratic_congruence, \
10
+ polynomial_congruence, sieve
11
+ from sympy.ntheory.residue_ntheory import _primitive_root_prime_iter, \
12
+ _primitive_root_prime_power_iter, _primitive_root_prime_power2_iter, \
13
+ _nthroot_mod_prime_power, _discrete_log_trial_mul, _discrete_log_shanks_steps, \
14
+ _discrete_log_pollard_rho, _discrete_log_index_calculus, _discrete_log_pohlig_hellman, \
15
+ _binomial_mod_prime_power, binomial_mod
16
+ from sympy.polys.domains import ZZ
17
+ from sympy.testing.pytest import raises
18
+ from sympy.core.random import randint, choice
19
+
20
+
21
+ def test_residue():
22
+ assert n_order(2, 13) == 12
23
+ assert [n_order(a, 7) for a in range(1, 7)] == \
24
+ [1, 3, 6, 3, 6, 2]
25
+ assert n_order(5, 17) == 16
26
+ assert n_order(17, 11) == n_order(6, 11)
27
+ assert n_order(101, 119) == 6
28
+ assert n_order(11, (10**50 + 151)**2) == 10000000000000000000000000000000000000000000000030100000000000000000000000000000000000000000000022650
29
+ raises(ValueError, lambda: n_order(6, 9))
30
+
31
+ assert is_primitive_root(2, 7) is False
32
+ assert is_primitive_root(3, 8) is False
33
+ assert is_primitive_root(11, 14) is False
34
+ assert is_primitive_root(12, 17) == is_primitive_root(29, 17)
35
+ raises(ValueError, lambda: is_primitive_root(3, 6))
36
+
37
+ for p in primerange(3, 100):
38
+ li = list(_primitive_root_prime_iter(p))
39
+ assert li[0] == min(li)
40
+ for g in li:
41
+ assert n_order(g, p) == p - 1
42
+ assert len(li) == totient(totient(p))
43
+ for e in range(1, 4):
44
+ li_power = list(_primitive_root_prime_power_iter(p, e))
45
+ li_power2 = list(_primitive_root_prime_power2_iter(p, e))
46
+ assert len(li_power) == len(li_power2) == totient(totient(p**e))
47
+ assert primitive_root(97) == 5
48
+ assert n_order(primitive_root(97, False), 97) == totient(97)
49
+ assert primitive_root(97**2) == 5
50
+ assert n_order(primitive_root(97**2, False), 97**2) == totient(97**2)
51
+ assert primitive_root(40487) == 5
52
+ assert n_order(primitive_root(40487, False), 40487) == totient(40487)
53
+ # note that primitive_root(40487) + 40487 = 40492 is a primitive root
54
+ # of 40487**2, but it is not the smallest
55
+ assert primitive_root(40487**2) == 10
56
+ assert n_order(primitive_root(40487**2, False), 40487**2) == totient(40487**2)
57
+ assert primitive_root(82) == 7
58
+ assert n_order(primitive_root(82, False), 82) == totient(82)
59
+ p = 10**50 + 151
60
+ assert primitive_root(p) == 11
61
+ assert n_order(primitive_root(p, False), p) == totient(p)
62
+ assert primitive_root(2*p) == 11
63
+ assert n_order(primitive_root(2*p, False), 2*p) == totient(2*p)
64
+ assert primitive_root(p**2) == 11
65
+ assert n_order(primitive_root(p**2, False), p**2) == totient(p**2)
66
+ assert primitive_root(4 * 11) is None and primitive_root(4 * 11, False) is None
67
+ assert primitive_root(15) is None and primitive_root(15, False) is None
68
+ raises(ValueError, lambda: primitive_root(-3))
69
+
70
+ assert is_quad_residue(3, 7) is False
71
+ assert is_quad_residue(10, 13) is True
72
+ assert is_quad_residue(12364, 139) == is_quad_residue(12364 % 139, 139)
73
+ assert is_quad_residue(207, 251) is True
74
+ assert is_quad_residue(0, 1) is True
75
+ assert is_quad_residue(1, 1) is True
76
+ assert is_quad_residue(0, 2) == is_quad_residue(1, 2) is True
77
+ assert is_quad_residue(1, 4) is True
78
+ assert is_quad_residue(2, 27) is False
79
+ assert is_quad_residue(13122380800, 13604889600) is True
80
+ assert [j for j in range(14) if is_quad_residue(j, 14)] == \
81
+ [0, 1, 2, 4, 7, 8, 9, 11]
82
+ raises(ValueError, lambda: is_quad_residue(1.1, 2))
83
+ raises(ValueError, lambda: is_quad_residue(2, 0))
84
+
85
+ assert quadratic_residues(S.One) == [0]
86
+ assert quadratic_residues(1) == [0]
87
+ assert quadratic_residues(12) == [0, 1, 4, 9]
88
+ assert quadratic_residues(13) == [0, 1, 3, 4, 9, 10, 12]
89
+ assert [len(quadratic_residues(i)) for i in range(1, 20)] == \
90
+ [1, 2, 2, 2, 3, 4, 4, 3, 4, 6, 6, 4, 7, 8, 6, 4, 9, 8, 10]
91
+
92
+ assert list(sqrt_mod_iter(6, 2)) == [0]
93
+ assert sqrt_mod(3, 13) == 4
94
+ assert sqrt_mod(3, -13) == 4
95
+ assert sqrt_mod(6, 23) == 11
96
+ assert sqrt_mod(345, 690) == 345
97
+ assert sqrt_mod(67, 101) == None
98
+ assert sqrt_mod(1020, 104729) == None
99
+
100
+ for p in range(3, 100):
101
+ d = defaultdict(list)
102
+ for i in range(p):
103
+ d[pow(i, 2, p)].append(i)
104
+ for i in range(1, p):
105
+ it = sqrt_mod_iter(i, p)
106
+ v = sqrt_mod(i, p, True)
107
+ if v:
108
+ v = sorted(v)
109
+ assert d[i] == v
110
+ else:
111
+ assert not d[i]
112
+
113
+ assert sqrt_mod(9, 27, True) == [3, 6, 12, 15, 21, 24]
114
+ assert sqrt_mod(9, 81, True) == [3, 24, 30, 51, 57, 78]
115
+ assert sqrt_mod(9, 3**5, True) == [3, 78, 84, 159, 165, 240]
116
+ assert sqrt_mod(81, 3**4, True) == [0, 9, 18, 27, 36, 45, 54, 63, 72]
117
+ assert sqrt_mod(81, 3**5, True) == [9, 18, 36, 45, 63, 72, 90, 99, 117,\
118
+ 126, 144, 153, 171, 180, 198, 207, 225, 234]
119
+ assert sqrt_mod(81, 3**6, True) == [9, 72, 90, 153, 171, 234, 252, 315,\
120
+ 333, 396, 414, 477, 495, 558, 576, 639, 657, 720]
121
+ assert sqrt_mod(81, 3**7, True) == [9, 234, 252, 477, 495, 720, 738, 963,\
122
+ 981, 1206, 1224, 1449, 1467, 1692, 1710, 1935, 1953, 2178]
123
+
124
+ for a, p in [(26214400, 32768000000), (26214400, 16384000000),
125
+ (262144, 1048576), (87169610025, 163443018796875),
126
+ (22315420166400, 167365651248000000)]:
127
+ assert pow(sqrt_mod(a, p), 2, p) == a
128
+
129
+ n = 70
130
+ a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+2)
131
+ it = sqrt_mod_iter(a, p)
132
+ for i in range(10):
133
+ assert pow(next(it), 2, p) == a
134
+ a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+3)
135
+ it = sqrt_mod_iter(a, p)
136
+ for i in range(2):
137
+ assert pow(next(it), 2, p) == a
138
+ n = 100
139
+ a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+1)
140
+ it = sqrt_mod_iter(a, p)
141
+ for i in range(2):
142
+ assert pow(next(it), 2, p) == a
143
+
144
+ assert type(next(sqrt_mod_iter(9, 27))) is int
145
+ assert type(next(sqrt_mod_iter(9, 27, ZZ))) is type(ZZ(1))
146
+ assert type(next(sqrt_mod_iter(1, 7, ZZ))) is type(ZZ(1))
147
+
148
+ assert is_nthpow_residue(2, 1, 5)
149
+
150
+ #issue 10816
151
+ assert is_nthpow_residue(1, 0, 1) is False
152
+ assert is_nthpow_residue(1, 0, 2) is True
153
+ assert is_nthpow_residue(3, 0, 2) is True
154
+ assert is_nthpow_residue(0, 1, 8) is True
155
+ assert is_nthpow_residue(2, 3, 2) is True
156
+ assert is_nthpow_residue(2, 3, 9) is False
157
+ assert is_nthpow_residue(3, 5, 30) is True
158
+ assert is_nthpow_residue(21, 11, 20) is True
159
+ assert is_nthpow_residue(7, 10, 20) is False
160
+ assert is_nthpow_residue(5, 10, 20) is True
161
+ assert is_nthpow_residue(3, 10, 48) is False
162
+ assert is_nthpow_residue(1, 10, 40) is True
163
+ assert is_nthpow_residue(3, 10, 24) is False
164
+ assert is_nthpow_residue(1, 10, 24) is True
165
+ assert is_nthpow_residue(3, 10, 24) is False
166
+ assert is_nthpow_residue(2, 10, 48) is False
167
+ assert is_nthpow_residue(81, 3, 972) is False
168
+ assert is_nthpow_residue(243, 5, 5103) is True
169
+ assert is_nthpow_residue(243, 3, 1240029) is False
170
+ assert is_nthpow_residue(36010, 8, 87382) is True
171
+ assert is_nthpow_residue(28552, 6, 2218) is True
172
+ assert is_nthpow_residue(92712, 9, 50026) is True
173
+ x = {pow(i, 56, 1024) for i in range(1024)}
174
+ assert {a for a in range(1024) if is_nthpow_residue(a, 56, 1024)} == x
175
+ x = { pow(i, 256, 2048) for i in range(2048)}
176
+ assert {a for a in range(2048) if is_nthpow_residue(a, 256, 2048)} == x
177
+ x = { pow(i, 11, 324000) for i in range(1000)}
178
+ assert [ is_nthpow_residue(a, 11, 324000) for a in x]
179
+ x = { pow(i, 17, 22217575536) for i in range(1000)}
180
+ assert [ is_nthpow_residue(a, 17, 22217575536) for a in x]
181
+ assert is_nthpow_residue(676, 3, 5364)
182
+ assert is_nthpow_residue(9, 12, 36)
183
+ assert is_nthpow_residue(32, 10, 41)
184
+ assert is_nthpow_residue(4, 2, 64)
185
+ assert is_nthpow_residue(31, 4, 41)
186
+ assert not is_nthpow_residue(2, 2, 5)
187
+ assert is_nthpow_residue(8547, 12, 10007)
188
+ assert is_nthpow_residue(Dummy(even=True) + 3, 3, 2) == True
189
+ # _nthroot_mod_prime_power
190
+ for p in primerange(2, 10):
191
+ for a in range(3):
192
+ for n in range(3, 5):
193
+ ans = _nthroot_mod_prime_power(a, n, p, 1)
194
+ assert isinstance(ans, list)
195
+ if len(ans) == 0:
196
+ for b in range(p):
197
+ assert pow(b, n, p) != a % p
198
+ for k in range(2, 10):
199
+ assert _nthroot_mod_prime_power(a, n, p, k) == []
200
+ else:
201
+ for b in range(p):
202
+ pred = pow(b, n, p) == a % p
203
+ assert not(pred ^ (b in ans))
204
+ for k in range(2, 10):
205
+ ans = _nthroot_mod_prime_power(a, n, p, k)
206
+ if not ans:
207
+ break
208
+ for b in ans:
209
+ assert pow(b, n , p**k) == a
210
+
211
+ assert nthroot_mod(Dummy(odd=True), 3, 2) == 1
212
+ assert nthroot_mod(29, 31, 74) == 45
213
+ assert nthroot_mod(1801, 11, 2663) == 44
214
+ for a, q, p in [(51922, 2, 203017), (43, 3, 109), (1801, 11, 2663),
215
+ (26118163, 1303, 33333347), (1499, 7, 2663), (595, 6, 2663),
216
+ (1714, 12, 2663), (28477, 9, 33343)]:
217
+ r = nthroot_mod(a, q, p)
218
+ assert pow(r, q, p) == a
219
+ assert nthroot_mod(11, 3, 109) is None
220
+ assert nthroot_mod(16, 5, 36, True) == [4, 22]
221
+ assert nthroot_mod(9, 16, 36, True) == [3, 9, 15, 21, 27, 33]
222
+ assert nthroot_mod(4, 3, 3249000) is None
223
+ assert nthroot_mod(36010, 8, 87382, True) == [40208, 47174]
224
+ assert nthroot_mod(0, 12, 37, True) == [0]
225
+ assert nthroot_mod(0, 7, 100, True) == [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
226
+ assert nthroot_mod(4, 4, 27, True) == [5, 22]
227
+ assert nthroot_mod(4, 4, 121, True) == [19, 102]
228
+ assert nthroot_mod(2, 3, 7, True) == []
229
+ for p in range(1, 20):
230
+ for a in range(p):
231
+ for n in range(1, p):
232
+ ans = nthroot_mod(a, n, p, True)
233
+ assert isinstance(ans, list)
234
+ for b in range(p):
235
+ pred = pow(b, n, p) == a
236
+ assert not(pred ^ (b in ans))
237
+ ans2 = nthroot_mod(a, n, p, False)
238
+ if ans2 is None:
239
+ assert ans == []
240
+ else:
241
+ assert ans2 in ans
242
+
243
+ x = Symbol('x', positive=True)
244
+ i = Symbol('i', integer=True)
245
+ assert _discrete_log_trial_mul(587, 2**7, 2) == 7
246
+ assert _discrete_log_trial_mul(941, 7**18, 7) == 18
247
+ assert _discrete_log_trial_mul(389, 3**81, 3) == 81
248
+ assert _discrete_log_trial_mul(191, 19**123, 19) == 123
249
+ assert _discrete_log_shanks_steps(442879, 7**2, 7) == 2
250
+ assert _discrete_log_shanks_steps(874323, 5**19, 5) == 19
251
+ assert _discrete_log_shanks_steps(6876342, 7**71, 7) == 71
252
+ assert _discrete_log_shanks_steps(2456747, 3**321, 3) == 321
253
+ assert _discrete_log_pollard_rho(6013199, 2**6, 2, rseed=0) == 6
254
+ assert _discrete_log_pollard_rho(6138719, 2**19, 2, rseed=0) == 19
255
+ assert _discrete_log_pollard_rho(36721943, 2**40, 2, rseed=0) == 40
256
+ assert _discrete_log_pollard_rho(24567899, 3**333, 3, rseed=0) == 333
257
+ raises(ValueError, lambda: _discrete_log_pollard_rho(11, 7, 31, rseed=0))
258
+ raises(ValueError, lambda: _discrete_log_pollard_rho(227, 3**7, 5, rseed=0))
259
+ assert _discrete_log_index_calculus(983, 948, 2, 491) == 183
260
+ assert _discrete_log_index_calculus(633383, 21794, 2, 316691) == 68048
261
+ assert _discrete_log_index_calculus(941762639, 68822582, 2, 470881319) == 338029275
262
+ assert _discrete_log_index_calculus(999231337607, 888188918786, 2, 499615668803) == 142811376514
263
+ assert _discrete_log_index_calculus(47747730623, 19410045286, 43425105668, 645239603) == 590504662
264
+ assert _discrete_log_pohlig_hellman(98376431, 11**9, 11) == 9
265
+ assert _discrete_log_pohlig_hellman(78723213, 11**31, 11) == 31
266
+ assert _discrete_log_pohlig_hellman(32942478, 11**98, 11) == 98
267
+ assert _discrete_log_pohlig_hellman(14789363, 11**444, 11) == 444
268
+ assert discrete_log(1, 0, 2) == 0
269
+ raises(ValueError, lambda: discrete_log(-4, 1, 3))
270
+ raises(ValueError, lambda: discrete_log(10, 3, 2))
271
+ assert discrete_log(587, 2**9, 2) == 9
272
+ assert discrete_log(2456747, 3**51, 3) == 51
273
+ assert discrete_log(32942478, 11**127, 11) == 127
274
+ assert discrete_log(432751500361, 7**324, 7) == 324
275
+ assert discrete_log(265390227570863,184500076053622, 2) == 17835221372061
276
+ assert discrete_log(22708823198678103974314518195029102158525052496759285596453269189798311427475159776411276642277139650833937,
277
+ 17463946429475485293747680247507700244427944625055089103624311227422110546803452417458985046168310373075327,
278
+ 123456) == 2068031853682195777930683306640554533145512201725884603914601918777510185469769997054750835368413389728895
279
+ args = 5779, 3528, 6215
280
+ assert discrete_log(*args) == 687
281
+ assert discrete_log(*Tuple(*args)) == 687
282
+ assert quadratic_congruence(400, 85, 125, 1600) == [295, 615, 935, 1255, 1575]
283
+ assert quadratic_congruence(3, 6, 5, 25) == [3, 20]
284
+ assert quadratic_congruence(120, 80, 175, 500) == []
285
+ assert quadratic_congruence(15, 14, 7, 2) == [1]
286
+ assert quadratic_congruence(8, 15, 7, 29) == [10, 28]
287
+ assert quadratic_congruence(160, 200, 300, 461) == [144, 431]
288
+ assert quadratic_congruence(100000, 123456, 7415263, 48112959837082048697) == [30417843635344493501, 36001135160550533083]
289
+ assert quadratic_congruence(65, 121, 72, 277) == [249, 252]
290
+ assert quadratic_congruence(5, 10, 14, 2) == [0]
291
+ assert quadratic_congruence(10, 17, 19, 2) == [1]
292
+ assert quadratic_congruence(10, 14, 20, 2) == [0, 1]
293
+ assert quadratic_congruence(2**48-7, 2**48-1, 4, 2**48) == [8249717183797, 31960993774868]
294
+ assert polynomial_congruence(6*x**5 + 10*x**4 + 5*x**3 + x**2 + x + 1,
295
+ 972000) == [220999, 242999, 463999, 485999, 706999, 728999, 949999, 971999]
296
+
297
+ assert polynomial_congruence(x**3 - 10*x**2 + 12*x - 82, 33075) == [30287]
298
+ assert polynomial_congruence(x**2 + x + 47, 2401) == [785, 1615]
299
+ assert polynomial_congruence(10*x**2 + 14*x + 20, 2) == [0, 1]
300
+ assert polynomial_congruence(x**3 + 3, 16) == [5]
301
+ assert polynomial_congruence(65*x**2 + 121*x + 72, 277) == [249, 252]
302
+ assert polynomial_congruence(x**4 - 4, 27) == [5, 22]
303
+ assert polynomial_congruence(35*x**3 - 6*x**2 - 567*x + 2308, 148225) == [86957,
304
+ 111157, 122531, 146731]
305
+ assert polynomial_congruence(x**16 - 9, 36) == [3, 9, 15, 21, 27, 33]
306
+ assert polynomial_congruence(x**6 - 2*x**5 - 35, 6125) == [3257]
307
+ raises(ValueError, lambda: polynomial_congruence(x**x, 6125))
308
+ raises(ValueError, lambda: polynomial_congruence(x**i, 6125))
309
+ raises(ValueError, lambda: polynomial_congruence(0.1*x**2 + 6, 100))
310
+
311
+ assert binomial_mod(-1, 1, 10) == 0
312
+ assert binomial_mod(1, -1, 10) == 0
313
+ raises(ValueError, lambda: binomial_mod(2, 1, -1))
314
+ assert binomial_mod(51, 10, 10) == 0
315
+ assert binomial_mod(10**3, 500, 3**6) == 567
316
+ assert binomial_mod(10**18 - 1, 123456789, 4) == 0
317
+ assert binomial_mod(10**18, 10**12, (10**5 + 3)**2) == 3744312326
318
+
319
+
320
+ def test_binomial_p_pow():
321
+ n, binomials, binomial = 1000, [1], 1
322
+ for i in range(1, n + 1):
323
+ binomial *= n - i + 1
324
+ binomial //= i
325
+ binomials.append(binomial)
326
+
327
+ # Test powers of two, which the algorithm treats slightly differently
328
+ trials_2 = 100
329
+ for _ in range(trials_2):
330
+ m, power = randint(0, n), randint(1, 20)
331
+ assert _binomial_mod_prime_power(n, m, 2, power) == binomials[m] % 2**power
332
+
333
+ # Test against other prime powers
334
+ primes = list(sieve.primerange(2*n))
335
+ trials = 1000
336
+ for _ in range(trials):
337
+ m, prime, power = randint(0, n), choice(primes), randint(1, 10)
338
+ assert _binomial_mod_prime_power(n, m, prime, power) == binomials[m] % prime**power
339
+
340
+
341
+ def test_deprecated_ntheory_symbolic_functions():
342
+ from sympy.testing.pytest import warns_deprecated_sympy
343
+
344
+ with warns_deprecated_sympy():
345
+ assert mobius(3) == -1
346
+ with warns_deprecated_sympy():
347
+ assert legendre_symbol(2, 3) == -1
348
+ with warns_deprecated_sympy():
349
+ assert jacobi_symbol(2, 3) == -1
.venv/lib/python3.13/site-packages/sympy/printing/pretty/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ASCII-ART 2D pretty-printer"""
2
+
3
+ from .pretty import (pretty, pretty_print, pprint, pprint_use_unicode,
4
+ pprint_try_use_unicode, pager_print)
5
+
6
+ # if unicode output is available -- let's use it
7
+ pprint_try_use_unicode()
8
+
9
+ __all__ = [
10
+ 'pretty', 'pretty_print', 'pprint', 'pprint_use_unicode',
11
+ 'pprint_try_use_unicode', 'pager_print',
12
+ ]
.venv/lib/python3.13/site-packages/sympy/printing/pretty/pretty.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/printing/pretty/pretty_symbology.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Symbolic primitives + unicode/ASCII abstraction for pretty.py"""
2
+
3
+ import sys
4
+ import warnings
5
+ from string import ascii_lowercase, ascii_uppercase
6
+ import unicodedata
7
+
8
+ unicode_warnings = ''
9
+
10
+ def U(name):
11
+ """
12
+ Get a unicode character by name or, None if not found.
13
+
14
+ This exists because older versions of Python use older unicode databases.
15
+ """
16
+ try:
17
+ return unicodedata.lookup(name)
18
+ except KeyError:
19
+ global unicode_warnings
20
+ unicode_warnings += 'No \'%s\' in unicodedata\n' % name
21
+ return None
22
+
23
+ from sympy.printing.conventions import split_super_sub
24
+ from sympy.core.alphabets import greeks
25
+ from sympy.utilities.exceptions import sympy_deprecation_warning
26
+
27
+ # prefix conventions when constructing tables
28
+ # L - LATIN i
29
+ # G - GREEK beta
30
+ # D - DIGIT 0
31
+ # S - SYMBOL +
32
+
33
+
34
+ __all__ = ['greek_unicode', 'sub', 'sup', 'xsym', 'vobj', 'hobj', 'pretty_symbol',
35
+ 'annotated', 'center_pad', 'center']
36
+
37
+
38
+ _use_unicode = False
39
+
40
+
41
+ def pretty_use_unicode(flag=None):
42
+ """Set whether pretty-printer should use unicode by default"""
43
+ global _use_unicode, unicode_warnings
44
+ if flag is None:
45
+ return _use_unicode
46
+
47
+ if flag and unicode_warnings:
48
+ # print warnings (if any) on first unicode usage
49
+ warnings.warn(unicode_warnings)
50
+ unicode_warnings = ''
51
+
52
+ use_unicode_prev = _use_unicode
53
+ _use_unicode = flag
54
+ return use_unicode_prev
55
+
56
+
57
+ def pretty_try_use_unicode():
58
+ """See if unicode output is available and leverage it if possible"""
59
+
60
+ encoding = getattr(sys.stdout, 'encoding', None)
61
+
62
+ # this happens when e.g. stdout is redirected through a pipe, or is
63
+ # e.g. a cStringIO.StringO
64
+ if encoding is None:
65
+ return # sys.stdout has no encoding
66
+
67
+ symbols = []
68
+
69
+ # see if we can represent greek alphabet
70
+ symbols += greek_unicode.values()
71
+
72
+ # and atoms
73
+ symbols += atoms_table.values()
74
+
75
+ for s in symbols:
76
+ if s is None:
77
+ return # common symbols not present!
78
+
79
+ try:
80
+ s.encode(encoding)
81
+ except UnicodeEncodeError:
82
+ return
83
+
84
+ # all the characters were present and encodable
85
+ pretty_use_unicode(True)
86
+
87
+
88
+ def xstr(*args):
89
+ sympy_deprecation_warning(
90
+ """
91
+ The sympy.printing.pretty.pretty_symbology.xstr() function is
92
+ deprecated. Use str() instead.
93
+ """,
94
+ deprecated_since_version="1.7",
95
+ active_deprecations_target="deprecated-pretty-printing-functions"
96
+ )
97
+ return str(*args)
98
+
99
+ # GREEK
100
+ g = lambda l: U('GREEK SMALL LETTER %s' % l.upper())
101
+ G = lambda l: U('GREEK CAPITAL LETTER %s' % l.upper())
102
+
103
+ greek_letters = list(greeks) # make a copy
104
+ # deal with Unicode's funny spelling of lambda
105
+ greek_letters[greek_letters.index('lambda')] = 'lamda'
106
+
107
+ # {} greek letter -> (g,G)
108
+ greek_unicode = {L: g(L) for L in greek_letters}
109
+ greek_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_letters)
110
+
111
+ # aliases
112
+ greek_unicode['lambda'] = greek_unicode['lamda']
113
+ greek_unicode['Lambda'] = greek_unicode['Lamda']
114
+ greek_unicode['varsigma'] = '\N{GREEK SMALL LETTER FINAL SIGMA}'
115
+
116
+ # BOLD
117
+ b = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper())
118
+ B = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper())
119
+
120
+ bold_unicode = {l: b(l) for l in ascii_lowercase}
121
+ bold_unicode.update((L, B(L)) for L in ascii_uppercase)
122
+
123
+ # GREEK BOLD
124
+ gb = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper())
125
+ GB = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper())
126
+
127
+ greek_bold_letters = list(greeks) # make a copy, not strictly required here
128
+ # deal with Unicode's funny spelling of lambda
129
+ greek_bold_letters[greek_bold_letters.index('lambda')] = 'lamda'
130
+
131
+ # {} greek letter -> (g,G)
132
+ greek_bold_unicode = {L: g(L) for L in greek_bold_letters}
133
+ greek_bold_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_bold_letters)
134
+ greek_bold_unicode['lambda'] = greek_unicode['lamda']
135
+ greek_bold_unicode['Lambda'] = greek_unicode['Lamda']
136
+ greek_bold_unicode['varsigma'] = '\N{MATHEMATICAL BOLD SMALL FINAL SIGMA}'
137
+
138
+ digit_2txt = {
139
+ '0': 'ZERO',
140
+ '1': 'ONE',
141
+ '2': 'TWO',
142
+ '3': 'THREE',
143
+ '4': 'FOUR',
144
+ '5': 'FIVE',
145
+ '6': 'SIX',
146
+ '7': 'SEVEN',
147
+ '8': 'EIGHT',
148
+ '9': 'NINE',
149
+ }
150
+
151
+ symb_2txt = {
152
+ '+': 'PLUS SIGN',
153
+ '-': 'MINUS',
154
+ '=': 'EQUALS SIGN',
155
+ '(': 'LEFT PARENTHESIS',
156
+ ')': 'RIGHT PARENTHESIS',
157
+ '[': 'LEFT SQUARE BRACKET',
158
+ ']': 'RIGHT SQUARE BRACKET',
159
+ '{': 'LEFT CURLY BRACKET',
160
+ '}': 'RIGHT CURLY BRACKET',
161
+
162
+ # non-std
163
+ '{}': 'CURLY BRACKET',
164
+ 'sum': 'SUMMATION',
165
+ 'int': 'INTEGRAL',
166
+ }
167
+
168
+ # SUBSCRIPT & SUPERSCRIPT
169
+ LSUB = lambda letter: U('LATIN SUBSCRIPT SMALL LETTER %s' % letter.upper())
170
+ GSUB = lambda letter: U('GREEK SUBSCRIPT SMALL LETTER %s' % letter.upper())
171
+ DSUB = lambda digit: U('SUBSCRIPT %s' % digit_2txt[digit])
172
+ SSUB = lambda symb: U('SUBSCRIPT %s' % symb_2txt[symb])
173
+
174
+ LSUP = lambda letter: U('SUPERSCRIPT LATIN SMALL LETTER %s' % letter.upper())
175
+ DSUP = lambda digit: U('SUPERSCRIPT %s' % digit_2txt[digit])
176
+ SSUP = lambda symb: U('SUPERSCRIPT %s' % symb_2txt[symb])
177
+
178
+ sub = {} # symb -> subscript symbol
179
+ sup = {} # symb -> superscript symbol
180
+
181
+ # latin subscripts
182
+ for l in 'aeioruvxhklmnpst':
183
+ sub[l] = LSUB(l)
184
+
185
+ for l in 'in':
186
+ sup[l] = LSUP(l)
187
+
188
+ for gl in ['beta', 'gamma', 'rho', 'phi', 'chi']:
189
+ sub[gl] = GSUB(gl)
190
+
191
+ for d in [str(i) for i in range(10)]:
192
+ sub[d] = DSUB(d)
193
+ sup[d] = DSUP(d)
194
+
195
+ for s in '+-=()':
196
+ sub[s] = SSUB(s)
197
+ sup[s] = SSUP(s)
198
+
199
+ # Variable modifiers
200
+ # TODO: Make brackets adjust to height of contents
201
+ modifier_dict = {
202
+ # Accents
203
+ 'mathring': lambda s: center_accent(s, '\N{COMBINING RING ABOVE}'),
204
+ 'ddddot': lambda s: center_accent(s, '\N{COMBINING FOUR DOTS ABOVE}'),
205
+ 'dddot': lambda s: center_accent(s, '\N{COMBINING THREE DOTS ABOVE}'),
206
+ 'ddot': lambda s: center_accent(s, '\N{COMBINING DIAERESIS}'),
207
+ 'dot': lambda s: center_accent(s, '\N{COMBINING DOT ABOVE}'),
208
+ 'check': lambda s: center_accent(s, '\N{COMBINING CARON}'),
209
+ 'breve': lambda s: center_accent(s, '\N{COMBINING BREVE}'),
210
+ 'acute': lambda s: center_accent(s, '\N{COMBINING ACUTE ACCENT}'),
211
+ 'grave': lambda s: center_accent(s, '\N{COMBINING GRAVE ACCENT}'),
212
+ 'tilde': lambda s: center_accent(s, '\N{COMBINING TILDE}'),
213
+ 'hat': lambda s: center_accent(s, '\N{COMBINING CIRCUMFLEX ACCENT}'),
214
+ 'bar': lambda s: center_accent(s, '\N{COMBINING OVERLINE}'),
215
+ 'vec': lambda s: center_accent(s, '\N{COMBINING RIGHT ARROW ABOVE}'),
216
+ 'prime': lambda s: s+'\N{PRIME}',
217
+ 'prm': lambda s: s+'\N{PRIME}',
218
+ # # Faces -- these are here for some compatibility with latex printing
219
+ # 'bold': lambda s: s,
220
+ # 'bm': lambda s: s,
221
+ # 'cal': lambda s: s,
222
+ # 'scr': lambda s: s,
223
+ # 'frak': lambda s: s,
224
+ # Brackets
225
+ 'norm': lambda s: '\N{DOUBLE VERTICAL LINE}'+s+'\N{DOUBLE VERTICAL LINE}',
226
+ 'avg': lambda s: '\N{MATHEMATICAL LEFT ANGLE BRACKET}'+s+'\N{MATHEMATICAL RIGHT ANGLE BRACKET}',
227
+ 'abs': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}',
228
+ 'mag': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}',
229
+ }
230
+
231
+ # VERTICAL OBJECTS
232
+ HUP = lambda symb: U('%s UPPER HOOK' % symb_2txt[symb])
233
+ CUP = lambda symb: U('%s UPPER CORNER' % symb_2txt[symb])
234
+ MID = lambda symb: U('%s MIDDLE PIECE' % symb_2txt[symb])
235
+ EXT = lambda symb: U('%s EXTENSION' % symb_2txt[symb])
236
+ HLO = lambda symb: U('%s LOWER HOOK' % symb_2txt[symb])
237
+ CLO = lambda symb: U('%s LOWER CORNER' % symb_2txt[symb])
238
+ TOP = lambda symb: U('%s TOP' % symb_2txt[symb])
239
+ BOT = lambda symb: U('%s BOTTOM' % symb_2txt[symb])
240
+
241
+ # {} '(' -> (extension, start, end, middle) 1-character
242
+ _xobj_unicode = {
243
+
244
+ # vertical symbols
245
+ # (( ext, top, bot, mid ), c1)
246
+ '(': (( EXT('('), HUP('('), HLO('(') ), '('),
247
+ ')': (( EXT(')'), HUP(')'), HLO(')') ), ')'),
248
+ '[': (( EXT('['), CUP('['), CLO('[') ), '['),
249
+ ']': (( EXT(']'), CUP(']'), CLO(']') ), ']'),
250
+ '{': (( EXT('{}'), HUP('{'), HLO('{'), MID('{') ), '{'),
251
+ '}': (( EXT('{}'), HUP('}'), HLO('}'), MID('}') ), '}'),
252
+ '|': U('BOX DRAWINGS LIGHT VERTICAL'),
253
+ 'Tee': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'),
254
+ 'UpTack': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL'),
255
+ 'corner_up_centre'
256
+ '(_ext': U('LEFT PARENTHESIS EXTENSION'),
257
+ ')_ext': U('RIGHT PARENTHESIS EXTENSION'),
258
+ '(_lower_hook': U('LEFT PARENTHESIS LOWER HOOK'),
259
+ ')_lower_hook': U('RIGHT PARENTHESIS LOWER HOOK'),
260
+ '(_upper_hook': U('LEFT PARENTHESIS UPPER HOOK'),
261
+ ')_upper_hook': U('RIGHT PARENTHESIS UPPER HOOK'),
262
+ '<': ((U('BOX DRAWINGS LIGHT VERTICAL'),
263
+ U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'),
264
+ U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT')), '<'),
265
+
266
+ '>': ((U('BOX DRAWINGS LIGHT VERTICAL'),
267
+ U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'),
268
+ U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), '>'),
269
+
270
+ 'lfloor': (( EXT('['), EXT('['), CLO('[') ), U('LEFT FLOOR')),
271
+ 'rfloor': (( EXT(']'), EXT(']'), CLO(']') ), U('RIGHT FLOOR')),
272
+ 'lceil': (( EXT('['), CUP('['), EXT('[') ), U('LEFT CEILING')),
273
+ 'rceil': (( EXT(']'), CUP(']'), EXT(']') ), U('RIGHT CEILING')),
274
+
275
+ 'int': (( EXT('int'), U('TOP HALF INTEGRAL'), U('BOTTOM HALF INTEGRAL') ), U('INTEGRAL')),
276
+ 'sum': (( U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), '_', U('OVERLINE'), U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), U('N-ARY SUMMATION')),
277
+
278
+ # horizontal objects
279
+ #'-': '-',
280
+ '-': U('BOX DRAWINGS LIGHT HORIZONTAL'),
281
+ '_': U('LOW LINE'),
282
+ # We used to use this, but LOW LINE looks better for roots, as it's a
283
+ # little lower (i.e., it lines up with the / perfectly. But perhaps this
284
+ # one would still be wanted for some cases?
285
+ # '_': U('HORIZONTAL SCAN LINE-9'),
286
+
287
+ # diagonal objects '\' & '/' ?
288
+ '/': U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'),
289
+ '\\': U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'),
290
+ }
291
+
292
+ _xobj_ascii = {
293
+ # vertical symbols
294
+ # (( ext, top, bot, mid ), c1)
295
+ '(': (( '|', '/', '\\' ), '('),
296
+ ')': (( '|', '\\', '/' ), ')'),
297
+
298
+ # XXX this looks ugly
299
+ # '[': (( '|', '-', '-' ), '['),
300
+ # ']': (( '|', '-', '-' ), ']'),
301
+ # XXX not so ugly :(
302
+ '[': (( '[', '[', '[' ), '['),
303
+ ']': (( ']', ']', ']' ), ']'),
304
+
305
+ '{': (( '|', '/', '\\', '<' ), '{'),
306
+ '}': (( '|', '\\', '/', '>' ), '}'),
307
+ '|': '|',
308
+
309
+ '<': (( '|', '/', '\\' ), '<'),
310
+ '>': (( '|', '\\', '/' ), '>'),
311
+
312
+ 'int': ( ' | ', ' /', '/ ' ),
313
+
314
+ # horizontal objects
315
+ '-': '-',
316
+ '_': '_',
317
+
318
+ # diagonal objects '\' & '/' ?
319
+ '/': '/',
320
+ '\\': '\\',
321
+ }
322
+
323
+
324
+ def xobj(symb, length):
325
+ """Construct spatial object of given length.
326
+
327
+ return: [] of equal-length strings
328
+ """
329
+
330
+ if length <= 0:
331
+ raise ValueError("Length should be greater than 0")
332
+
333
+ # TODO robustify when no unicodedat available
334
+ if _use_unicode:
335
+ _xobj = _xobj_unicode
336
+ else:
337
+ _xobj = _xobj_ascii
338
+
339
+ vinfo = _xobj[symb]
340
+
341
+ c1 = top = bot = mid = None
342
+
343
+ if not isinstance(vinfo, tuple): # 1 entry
344
+ ext = vinfo
345
+ else:
346
+ if isinstance(vinfo[0], tuple): # (vlong), c1
347
+ vlong = vinfo[0]
348
+ c1 = vinfo[1]
349
+ else: # (vlong), c1
350
+ vlong = vinfo
351
+
352
+ ext = vlong[0]
353
+
354
+ try:
355
+ top = vlong[1]
356
+ bot = vlong[2]
357
+ mid = vlong[3]
358
+ except IndexError:
359
+ pass
360
+
361
+ if c1 is None:
362
+ c1 = ext
363
+ if top is None:
364
+ top = ext
365
+ if bot is None:
366
+ bot = ext
367
+ if mid is not None:
368
+ if (length % 2) == 0:
369
+ # even height, but we have to print it somehow anyway...
370
+ # XXX is it ok?
371
+ length += 1
372
+
373
+ else:
374
+ mid = ext
375
+
376
+ if length == 1:
377
+ return c1
378
+
379
+ res = []
380
+ next = (length - 2)//2
381
+ nmid = (length - 2) - next*2
382
+
383
+ res += [top]
384
+ res += [ext]*next
385
+ res += [mid]*nmid
386
+ res += [ext]*next
387
+ res += [bot]
388
+
389
+ return res
390
+
391
+
392
+ def vobj(symb, height):
393
+ """Construct vertical object of a given height
394
+
395
+ see: xobj
396
+ """
397
+ return '\n'.join( xobj(symb, height) )
398
+
399
+
400
+ def hobj(symb, width):
401
+ """Construct horizontal object of a given width
402
+
403
+ see: xobj
404
+ """
405
+ return ''.join( xobj(symb, width) )
406
+
407
+ # RADICAL
408
+ # n -> symbol
409
+ root = {
410
+ 2: U('SQUARE ROOT'), # U('RADICAL SYMBOL BOTTOM')
411
+ 3: U('CUBE ROOT'),
412
+ 4: U('FOURTH ROOT'),
413
+ }
414
+
415
+
416
+ # RATIONAL
417
+ VF = lambda txt: U('VULGAR FRACTION %s' % txt)
418
+
419
+ # (p,q) -> symbol
420
+ frac = {
421
+ (1, 2): VF('ONE HALF'),
422
+ (1, 3): VF('ONE THIRD'),
423
+ (2, 3): VF('TWO THIRDS'),
424
+ (1, 4): VF('ONE QUARTER'),
425
+ (3, 4): VF('THREE QUARTERS'),
426
+ (1, 5): VF('ONE FIFTH'),
427
+ (2, 5): VF('TWO FIFTHS'),
428
+ (3, 5): VF('THREE FIFTHS'),
429
+ (4, 5): VF('FOUR FIFTHS'),
430
+ (1, 6): VF('ONE SIXTH'),
431
+ (5, 6): VF('FIVE SIXTHS'),
432
+ (1, 8): VF('ONE EIGHTH'),
433
+ (3, 8): VF('THREE EIGHTHS'),
434
+ (5, 8): VF('FIVE EIGHTHS'),
435
+ (7, 8): VF('SEVEN EIGHTHS'),
436
+ }
437
+
438
+
439
+ # atom symbols
440
+ _xsym = {
441
+ '==': ('=', '='),
442
+ '<': ('<', '<'),
443
+ '>': ('>', '>'),
444
+ '<=': ('<=', U('LESS-THAN OR EQUAL TO')),
445
+ '>=': ('>=', U('GREATER-THAN OR EQUAL TO')),
446
+ '!=': ('!=', U('NOT EQUAL TO')),
447
+ ':=': (':=', ':='),
448
+ '+=': ('+=', '+='),
449
+ '-=': ('-=', '-='),
450
+ '*=': ('*=', '*='),
451
+ '/=': ('/=', '/='),
452
+ '%=': ('%=', '%='),
453
+ '*': ('*', U('DOT OPERATOR')),
454
+ '-->': ('-->', U('EM DASH') + U('EM DASH') +
455
+ U('BLACK RIGHT-POINTING TRIANGLE') if U('EM DASH')
456
+ and U('BLACK RIGHT-POINTING TRIANGLE') else None),
457
+ '==>': ('==>', U('BOX DRAWINGS DOUBLE HORIZONTAL') +
458
+ U('BOX DRAWINGS DOUBLE HORIZONTAL') +
459
+ U('BLACK RIGHT-POINTING TRIANGLE') if
460
+ U('BOX DRAWINGS DOUBLE HORIZONTAL') and
461
+ U('BOX DRAWINGS DOUBLE HORIZONTAL') and
462
+ U('BLACK RIGHT-POINTING TRIANGLE') else None),
463
+ '.': ('*', U('RING OPERATOR')),
464
+ }
465
+
466
+
467
+ def xsym(sym):
468
+ """get symbology for a 'character'"""
469
+ op = _xsym[sym]
470
+
471
+ if _use_unicode:
472
+ return op[1]
473
+ else:
474
+ return op[0]
475
+
476
+
477
+ # SYMBOLS
478
+
479
+ atoms_table = {
480
+ # class how-to-display
481
+ 'Exp1': U('SCRIPT SMALL E'),
482
+ 'Pi': U('GREEK SMALL LETTER PI'),
483
+ 'Infinity': U('INFINITY'),
484
+ 'NegativeInfinity': U('INFINITY') and ('-' + U('INFINITY')), # XXX what to do here
485
+ #'ImaginaryUnit': U('GREEK SMALL LETTER IOTA'),
486
+ #'ImaginaryUnit': U('MATHEMATICAL ITALIC SMALL I'),
487
+ 'ImaginaryUnit': U('DOUBLE-STRUCK ITALIC SMALL I'),
488
+ 'EmptySet': U('EMPTY SET'),
489
+ 'Naturals': U('DOUBLE-STRUCK CAPITAL N'),
490
+ 'Naturals0': (U('DOUBLE-STRUCK CAPITAL N') and
491
+ (U('DOUBLE-STRUCK CAPITAL N') +
492
+ U('SUBSCRIPT ZERO'))),
493
+ 'Integers': U('DOUBLE-STRUCK CAPITAL Z'),
494
+ 'Rationals': U('DOUBLE-STRUCK CAPITAL Q'),
495
+ 'Reals': U('DOUBLE-STRUCK CAPITAL R'),
496
+ 'Complexes': U('DOUBLE-STRUCK CAPITAL C'),
497
+ 'Universe': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL U'),
498
+ 'IdentityMatrix': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL I'),
499
+ 'ZeroMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ZERO'),
500
+ 'OneMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ONE'),
501
+ 'Differential': U('DOUBLE-STRUCK ITALIC SMALL D'),
502
+ 'Union': U('UNION'),
503
+ 'ElementOf': U('ELEMENT OF'),
504
+ 'SmallElementOf': U('SMALL ELEMENT OF'),
505
+ 'SymmetricDifference': U('INCREMENT'),
506
+ 'Intersection': U('INTERSECTION'),
507
+ 'Ring': U('RING OPERATOR'),
508
+ 'Multiplication': U('MULTIPLICATION SIGN'),
509
+ 'TensorProduct': U('N-ARY CIRCLED TIMES OPERATOR'),
510
+ 'Dots': U('HORIZONTAL ELLIPSIS'),
511
+ 'Modifier Letter Low Ring':U('Modifier Letter Low Ring'),
512
+ 'EmptySequence': 'EmptySequence',
513
+ 'SuperscriptPlus': U('SUPERSCRIPT PLUS SIGN'),
514
+ 'SuperscriptMinus': U('SUPERSCRIPT MINUS'),
515
+ 'Dagger': U('DAGGER'),
516
+ 'Degree': U('DEGREE SIGN'),
517
+ #Logic Symbols
518
+ 'And': U('LOGICAL AND'),
519
+ 'Or': U('LOGICAL OR'),
520
+ 'Not': U('NOT SIGN'),
521
+ 'Nor': U('NOR'),
522
+ 'Nand': U('NAND'),
523
+ 'Xor': U('XOR'),
524
+ 'Equiv': U('LEFT RIGHT DOUBLE ARROW'),
525
+ 'NotEquiv': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'),
526
+ 'Implies': U('LEFT RIGHT DOUBLE ARROW'),
527
+ 'NotImplies': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'),
528
+ 'Arrow': U('RIGHTWARDS ARROW'),
529
+ 'ArrowFromBar': U('RIGHTWARDS ARROW FROM BAR'),
530
+ 'NotArrow': U('RIGHTWARDS ARROW WITH STROKE'),
531
+ 'Tautology': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'),
532
+ 'Contradiction': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL')
533
+ }
534
+
535
+
536
+ def pretty_atom(atom_name, default=None, printer=None):
537
+ """return pretty representation of an atom"""
538
+ if _use_unicode:
539
+ if printer is not None and atom_name == 'ImaginaryUnit' and printer._settings['imaginary_unit'] == 'j':
540
+ return U('DOUBLE-STRUCK ITALIC SMALL J')
541
+ else:
542
+ return atoms_table[atom_name]
543
+ else:
544
+ if default is not None:
545
+ return default
546
+
547
+ raise KeyError('only unicode') # send it default printer
548
+
549
+
550
+ def pretty_symbol(symb_name, bold_name=False):
551
+ """return pretty representation of a symbol"""
552
+ # let's split symb_name into symbol + index
553
+ # UC: beta1
554
+ # UC: f_beta
555
+
556
+ if not _use_unicode:
557
+ return symb_name
558
+
559
+ name, sups, subs = split_super_sub(symb_name)
560
+
561
+ def translate(s, bold_name) :
562
+ if bold_name:
563
+ gG = greek_bold_unicode.get(s)
564
+ else:
565
+ gG = greek_unicode.get(s)
566
+ if gG is not None:
567
+ return gG
568
+ for key in sorted(modifier_dict.keys(), key=lambda k:len(k), reverse=True) :
569
+ if s.lower().endswith(key) and len(s)>len(key):
570
+ return modifier_dict[key](translate(s[:-len(key)], bold_name))
571
+ if bold_name:
572
+ return ''.join([bold_unicode[c] for c in s])
573
+ return s
574
+
575
+ name = translate(name, bold_name)
576
+
577
+ # Let's prettify sups/subs. If it fails at one of them, pretty sups/subs are
578
+ # not used at all.
579
+ def pretty_list(l, mapping):
580
+ result = []
581
+ for s in l:
582
+ pretty = mapping.get(s)
583
+ if pretty is None:
584
+ try: # match by separate characters
585
+ pretty = ''.join([mapping[c] for c in s])
586
+ except (TypeError, KeyError):
587
+ return None
588
+ result.append(pretty)
589
+ return result
590
+
591
+ pretty_sups = pretty_list(sups, sup)
592
+ if pretty_sups is not None:
593
+ pretty_subs = pretty_list(subs, sub)
594
+ else:
595
+ pretty_subs = None
596
+
597
+ # glue the results into one string
598
+ if pretty_subs is None: # nice formatting of sups/subs did not work
599
+ if subs:
600
+ name += '_'+'_'.join([translate(s, bold_name) for s in subs])
601
+ if sups:
602
+ name += '__'+'__'.join([translate(s, bold_name) for s in sups])
603
+ return name
604
+ else:
605
+ sups_result = ' '.join(pretty_sups)
606
+ subs_result = ' '.join(pretty_subs)
607
+
608
+ return ''.join([name, sups_result, subs_result])
609
+
610
+
611
+ def annotated(letter):
612
+ """
613
+ Return a stylised drawing of the letter ``letter``, together with
614
+ information on how to put annotations (super- and subscripts to the
615
+ left and to the right) on it.
616
+
617
+ See pretty.py functions _print_meijerg, _print_hyper on how to use this
618
+ information.
619
+ """
620
+ ucode_pics = {
621
+ 'F': (2, 0, 2, 0, '\N{BOX DRAWINGS LIGHT DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n'
622
+ '\N{BOX DRAWINGS LIGHT VERTICAL AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n'
623
+ '\N{BOX DRAWINGS LIGHT UP}'),
624
+ 'G': (3, 0, 3, 1, '\N{BOX DRAWINGS LIGHT ARC DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC DOWN AND LEFT}\n'
625
+ '\N{BOX DRAWINGS LIGHT VERTICAL}\N{BOX DRAWINGS LIGHT RIGHT}\N{BOX DRAWINGS LIGHT DOWN AND LEFT}\n'
626
+ '\N{BOX DRAWINGS LIGHT ARC UP AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC UP AND LEFT}')
627
+ }
628
+ ascii_pics = {
629
+ 'F': (3, 0, 3, 0, ' _\n|_\n|\n'),
630
+ 'G': (3, 0, 3, 1, ' __\n/__\n\\_|')
631
+ }
632
+
633
+ if _use_unicode:
634
+ return ucode_pics[letter]
635
+ else:
636
+ return ascii_pics[letter]
637
+
638
+ _remove_combining = dict.fromkeys(list(range(ord('\N{COMBINING GRAVE ACCENT}'), ord('\N{COMBINING LATIN SMALL LETTER X}')))
639
+ + list(range(ord('\N{COMBINING LEFT HARPOON ABOVE}'), ord('\N{COMBINING ASTERISK ABOVE}'))))
640
+
641
+ def is_combining(sym):
642
+ """Check whether symbol is a unicode modifier. """
643
+
644
+ return ord(sym) in _remove_combining
645
+
646
+
647
+ def center_accent(string, accent):
648
+ """
649
+ Returns a string with accent inserted on the middle character. Useful to
650
+ put combining accents on symbol names, including multi-character names.
651
+
652
+ Parameters
653
+ ==========
654
+
655
+ string : string
656
+ The string to place the accent in.
657
+ accent : string
658
+ The combining accent to insert
659
+
660
+ References
661
+ ==========
662
+
663
+ .. [1] https://en.wikipedia.org/wiki/Combining_character
664
+ .. [2] https://en.wikipedia.org/wiki/Combining_Diacritical_Marks
665
+
666
+ """
667
+
668
+ # Accent is placed on the previous character, although it may not always look
669
+ # like that depending on console
670
+ midpoint = len(string) // 2 + 1
671
+ firstpart = string[:midpoint]
672
+ secondpart = string[midpoint:]
673
+ return firstpart + accent + secondpart
674
+
675
+
676
+ def line_width(line):
677
+ """Unicode combining symbols (modifiers) are not ever displayed as
678
+ separate symbols and thus should not be counted
679
+ """
680
+ return len(line.translate(_remove_combining))
681
+
682
+
683
+ def is_subscriptable_in_unicode(subscript):
684
+ """
685
+ Checks whether a string is subscriptable in unicode or not.
686
+
687
+ Parameters
688
+ ==========
689
+
690
+ subscript: the string which needs to be checked
691
+
692
+ Examples
693
+ ========
694
+
695
+ >>> from sympy.printing.pretty.pretty_symbology import is_subscriptable_in_unicode
696
+ >>> is_subscriptable_in_unicode('abc')
697
+ False
698
+ >>> is_subscriptable_in_unicode('123')
699
+ True
700
+
701
+ """
702
+ return all(character in sub for character in subscript)
703
+
704
+
705
+ def center_pad(wstring, wtarget, fillchar=' '):
706
+ """
707
+ Return the padding strings necessary to center a string of
708
+ wstring characters wide in a wtarget wide space.
709
+
710
+ The line_width wstring should always be less or equal to wtarget
711
+ or else a ValueError will be raised.
712
+ """
713
+ if wstring > wtarget:
714
+ raise ValueError('not enough space for string')
715
+ wdelta = wtarget - wstring
716
+
717
+ wleft = wdelta // 2 # favor left '1 '
718
+ wright = wdelta - wleft
719
+
720
+ left = fillchar * wleft
721
+ right = fillchar * wright
722
+
723
+ return left, right
724
+
725
+
726
+ def center(string, width, fillchar=' '):
727
+ """Return a centered string of length determined by `line_width`
728
+ that uses `fillchar` for padding.
729
+ """
730
+ left, right = center_pad(line_width(string), width, fillchar)
731
+ return ''.join([left, string, right])
.venv/lib/python3.13/site-packages/sympy/printing/pretty/stringpict.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prettyprinter by Jurjen Bos.
2
+ (I hate spammers: mail me at pietjepuk314 at the reverse of ku.oc.oohay).
3
+ All objects have a method that create a "stringPict",
4
+ that can be used in the str method for pretty printing.
5
+
6
+ Updates by Jason Gedge (email <my last name> at cs mun ca)
7
+ - terminal_string() method
8
+ - minor fixes and changes (mostly to prettyForm)
9
+
10
+ TODO:
11
+ - Allow left/center/right alignment options for above/below and
12
+ top/center/bottom alignment options for left/right
13
+ """
14
+
15
+ import shutil
16
+
17
+ from .pretty_symbology import hobj, vobj, xsym, xobj, pretty_use_unicode, line_width, center
18
+ from sympy.utilities.exceptions import sympy_deprecation_warning
19
+
20
+ _GLOBAL_WRAP_LINE = None
21
+
22
+ class stringPict:
23
+ """An ASCII picture.
24
+ The pictures are represented as a list of equal length strings.
25
+ """
26
+ #special value for stringPict.below
27
+ LINE = 'line'
28
+
29
+ def __init__(self, s, baseline=0):
30
+ """Initialize from string.
31
+ Multiline strings are centered.
32
+ """
33
+ self.s = s
34
+ #picture is a string that just can be printed
35
+ self.picture = stringPict.equalLengths(s.splitlines())
36
+ #baseline is the line number of the "base line"
37
+ self.baseline = baseline
38
+ self.binding = None
39
+
40
+ @staticmethod
41
+ def equalLengths(lines):
42
+ # empty lines
43
+ if not lines:
44
+ return ['']
45
+
46
+ width = max(line_width(line) for line in lines)
47
+ return [center(line, width) for line in lines]
48
+
49
+ def height(self):
50
+ """The height of the picture in characters."""
51
+ return len(self.picture)
52
+
53
+ def width(self):
54
+ """The width of the picture in characters."""
55
+ return line_width(self.picture[0])
56
+
57
+ @staticmethod
58
+ def next(*args):
59
+ """Put a string of stringPicts next to each other.
60
+ Returns string, baseline arguments for stringPict.
61
+ """
62
+ #convert everything to stringPicts
63
+ objects = []
64
+ for arg in args:
65
+ if isinstance(arg, str):
66
+ arg = stringPict(arg)
67
+ objects.append(arg)
68
+
69
+ #make a list of pictures, with equal height and baseline
70
+ newBaseline = max(obj.baseline for obj in objects)
71
+ newHeightBelowBaseline = max(
72
+ obj.height() - obj.baseline
73
+ for obj in objects)
74
+ newHeight = newBaseline + newHeightBelowBaseline
75
+
76
+ pictures = []
77
+ for obj in objects:
78
+ oneEmptyLine = [' '*obj.width()]
79
+ basePadding = newBaseline - obj.baseline
80
+ totalPadding = newHeight - obj.height()
81
+ pictures.append(
82
+ oneEmptyLine * basePadding +
83
+ obj.picture +
84
+ oneEmptyLine * (totalPadding - basePadding))
85
+
86
+ result = [''.join(lines) for lines in zip(*pictures)]
87
+ return '\n'.join(result), newBaseline
88
+
89
+ def right(self, *args):
90
+ r"""Put pictures next to this one.
91
+ Returns string, baseline arguments for stringPict.
92
+ (Multiline) strings are allowed, and are given a baseline of 0.
93
+
94
+ Examples
95
+ ========
96
+
97
+ >>> from sympy.printing.pretty.stringpict import stringPict
98
+ >>> print(stringPict("10").right(" + ",stringPict("1\r-\r2",1))[0])
99
+ 1
100
+ 10 + -
101
+ 2
102
+
103
+ """
104
+ return stringPict.next(self, *args)
105
+
106
+ def left(self, *args):
107
+ """Put pictures (left to right) at left.
108
+ Returns string, baseline arguments for stringPict.
109
+ """
110
+ return stringPict.next(*(args + (self,)))
111
+
112
+ @staticmethod
113
+ def stack(*args):
114
+ """Put pictures on top of each other,
115
+ from top to bottom.
116
+ Returns string, baseline arguments for stringPict.
117
+ The baseline is the baseline of the second picture.
118
+ Everything is centered.
119
+ Baseline is the baseline of the second picture.
120
+ Strings are allowed.
121
+ The special value stringPict.LINE is a row of '-' extended to the width.
122
+ """
123
+ #convert everything to stringPicts; keep LINE
124
+ objects = []
125
+ for arg in args:
126
+ if arg is not stringPict.LINE and isinstance(arg, str):
127
+ arg = stringPict(arg)
128
+ objects.append(arg)
129
+
130
+ #compute new width
131
+ newWidth = max(
132
+ obj.width()
133
+ for obj in objects
134
+ if obj is not stringPict.LINE)
135
+
136
+ lineObj = stringPict(hobj('-', newWidth))
137
+
138
+ #replace LINE with proper lines
139
+ for i, obj in enumerate(objects):
140
+ if obj is stringPict.LINE:
141
+ objects[i] = lineObj
142
+
143
+ #stack the pictures, and center the result
144
+ newPicture = [center(line, newWidth) for obj in objects for line in obj.picture]
145
+ newBaseline = objects[0].height() + objects[1].baseline
146
+ return '\n'.join(newPicture), newBaseline
147
+
148
+ def below(self, *args):
149
+ """Put pictures under this picture.
150
+ Returns string, baseline arguments for stringPict.
151
+ Baseline is baseline of top picture
152
+
153
+ Examples
154
+ ========
155
+
156
+ >>> from sympy.printing.pretty.stringpict import stringPict
157
+ >>> print(stringPict("x+3").below(
158
+ ... stringPict.LINE, '3')[0]) #doctest: +NORMALIZE_WHITESPACE
159
+ x+3
160
+ ---
161
+ 3
162
+
163
+ """
164
+ s, baseline = stringPict.stack(self, *args)
165
+ return s, self.baseline
166
+
167
+ def above(self, *args):
168
+ """Put pictures above this picture.
169
+ Returns string, baseline arguments for stringPict.
170
+ Baseline is baseline of bottom picture.
171
+ """
172
+ string, baseline = stringPict.stack(*(args + (self,)))
173
+ baseline = len(string.splitlines()) - self.height() + self.baseline
174
+ return string, baseline
175
+
176
+ def parens(self, left='(', right=')', ifascii_nougly=False):
177
+ """Put parentheses around self.
178
+ Returns string, baseline arguments for stringPict.
179
+
180
+ left or right can be None or empty string which means 'no paren from
181
+ that side'
182
+ """
183
+ h = self.height()
184
+ b = self.baseline
185
+
186
+ # XXX this is a hack -- ascii parens are ugly!
187
+ if ifascii_nougly and not pretty_use_unicode():
188
+ h = 1
189
+ b = 0
190
+
191
+ res = self
192
+
193
+ if left:
194
+ lparen = stringPict(vobj(left, h), baseline=b)
195
+ res = stringPict(*lparen.right(self))
196
+ if right:
197
+ rparen = stringPict(vobj(right, h), baseline=b)
198
+ res = stringPict(*res.right(rparen))
199
+
200
+ return ('\n'.join(res.picture), res.baseline)
201
+
202
+ def leftslash(self):
203
+ """Precede object by a slash of the proper size.
204
+ """
205
+ # XXX not used anywhere ?
206
+ height = max(
207
+ self.baseline,
208
+ self.height() - 1 - self.baseline)*2 + 1
209
+ slash = '\n'.join(
210
+ ' '*(height - i - 1) + xobj('/', 1) + ' '*i
211
+ for i in range(height)
212
+ )
213
+ return self.left(stringPict(slash, height//2))
214
+
215
+ def root(self, n=None):
216
+ """Produce a nice root symbol.
217
+ Produces ugly results for big n inserts.
218
+ """
219
+ # XXX not used anywhere
220
+ # XXX duplicate of root drawing in pretty.py
221
+ #put line over expression
222
+ result = self.above('_'*self.width())
223
+ #construct right half of root symbol
224
+ height = self.height()
225
+ slash = '\n'.join(
226
+ ' ' * (height - i - 1) + '/' + ' ' * i
227
+ for i in range(height)
228
+ )
229
+ slash = stringPict(slash, height - 1)
230
+ #left half of root symbol
231
+ if height > 2:
232
+ downline = stringPict('\\ \n \\', 1)
233
+ else:
234
+ downline = stringPict('\\')
235
+ #put n on top, as low as possible
236
+ if n is not None and n.width() > downline.width():
237
+ downline = downline.left(' '*(n.width() - downline.width()))
238
+ downline = downline.above(n)
239
+ #build root symbol
240
+ root = downline.right(slash)
241
+ #glue it on at the proper height
242
+ #normally, the root symbel is as high as self
243
+ #which is one less than result
244
+ #this moves the root symbol one down
245
+ #if the root became higher, the baseline has to grow too
246
+ root.baseline = result.baseline - result.height() + root.height()
247
+ return result.left(root)
248
+
249
+ def render(self, * args, **kwargs):
250
+ """Return the string form of self.
251
+
252
+ Unless the argument line_break is set to False, it will
253
+ break the expression in a form that can be printed
254
+ on the terminal without being broken up.
255
+ """
256
+ if _GLOBAL_WRAP_LINE is not None:
257
+ kwargs["wrap_line"] = _GLOBAL_WRAP_LINE
258
+
259
+ if kwargs["wrap_line"] is False:
260
+ return "\n".join(self.picture)
261
+
262
+ if kwargs["num_columns"] is not None:
263
+ # Read the argument num_columns if it is not None
264
+ ncols = kwargs["num_columns"]
265
+ else:
266
+ # Attempt to get a terminal width
267
+ ncols = self.terminal_width()
268
+
269
+ if ncols <= 0:
270
+ ncols = 80
271
+
272
+ # If smaller than the terminal width, no need to correct
273
+ if self.width() <= ncols:
274
+ return type(self.picture[0])(self)
275
+
276
+ """
277
+ Break long-lines in a visually pleasing format.
278
+ without overflow indicators | with overflow indicators
279
+ | 2 2 3 | | 2 2 3 ↪|
280
+ |6*x *y + 4*x*y + | |6*x *y + 4*x*y + ↪|
281
+ | | | |
282
+ | 3 4 4 | |↪ 3 4 4 |
283
+ |4*y*x + x + y | |↪ 4*y*x + x + y |
284
+ |a*c*e + a*c*f + a*d | |a*c*e + a*c*f + a*d ↪|
285
+ |*e + a*d*f + b*c*e | | |
286
+ |+ b*c*f + b*d*e + b | |↪ *e + a*d*f + b*c* ↪|
287
+ |*d*f | | |
288
+ | | |↪ e + b*c*f + b*d*e ↪|
289
+ | | | |
290
+ | | |↪ + b*d*f |
291
+ """
292
+
293
+ overflow_first = ""
294
+ if kwargs["use_unicode"] or pretty_use_unicode():
295
+ overflow_start = "\N{RIGHTWARDS ARROW WITH HOOK} "
296
+ overflow_end = " \N{RIGHTWARDS ARROW WITH HOOK}"
297
+ else:
298
+ overflow_start = "> "
299
+ overflow_end = " >"
300
+
301
+ def chunks(line):
302
+ """Yields consecutive chunks of line_width ncols"""
303
+ prefix = overflow_first
304
+ width, start = line_width(prefix + overflow_end), 0
305
+ for i, x in enumerate(line):
306
+ wx = line_width(x)
307
+ # Only flush the screen when the current character overflows.
308
+ # This way, combining marks can be appended even when width == ncols.
309
+ if width + wx > ncols:
310
+ yield prefix + line[start:i] + overflow_end
311
+ prefix = overflow_start
312
+ width, start = line_width(prefix + overflow_end), i
313
+ width += wx
314
+ yield prefix + line[start:]
315
+
316
+ # Concurrently assemble chunks of all lines into individual screens
317
+ pictures = zip(*map(chunks, self.picture))
318
+
319
+ # Join lines of each screen into sub-pictures
320
+ pictures = ["\n".join(picture) for picture in pictures]
321
+
322
+ # Add spacers between sub-pictures
323
+ return "\n\n".join(pictures)
324
+
325
+ def terminal_width(self):
326
+ """Return the terminal width if possible, otherwise return 0.
327
+ """
328
+ size = shutil.get_terminal_size(fallback=(0, 0))
329
+ return size.columns
330
+
331
+ def __eq__(self, o):
332
+ if isinstance(o, str):
333
+ return '\n'.join(self.picture) == o
334
+ elif isinstance(o, stringPict):
335
+ return o.picture == self.picture
336
+ return False
337
+
338
+ def __hash__(self):
339
+ return super().__hash__()
340
+
341
+ def __str__(self):
342
+ return '\n'.join(self.picture)
343
+
344
+ def __repr__(self):
345
+ return "stringPict(%r,%d)" % ('\n'.join(self.picture), self.baseline)
346
+
347
+ def __getitem__(self, index):
348
+ return self.picture[index]
349
+
350
+ def __len__(self):
351
+ return len(self.s)
352
+
353
+
354
+ class prettyForm(stringPict):
355
+ """
356
+ Extension of the stringPict class that knows about basic math applications,
357
+ optimizing double minus signs.
358
+
359
+ "Binding" is interpreted as follows::
360
+
361
+ ATOM this is an atom: never needs to be parenthesized
362
+ FUNC this is a function application: parenthesize if added (?)
363
+ DIV this is a division: make wider division if divided
364
+ POW this is a power: only parenthesize if exponent
365
+ MUL this is a multiplication: parenthesize if powered
366
+ ADD this is an addition: parenthesize if multiplied or powered
367
+ NEG this is a negative number: optimize if added, parenthesize if
368
+ multiplied or powered
369
+ OPEN this is an open object: parenthesize if added, multiplied, or
370
+ powered (example: Piecewise)
371
+ """
372
+ ATOM, FUNC, DIV, POW, MUL, ADD, NEG, OPEN = range(8)
373
+
374
+ def __init__(self, s, baseline=0, binding=0, unicode=None):
375
+ """Initialize from stringPict and binding power."""
376
+ stringPict.__init__(self, s, baseline)
377
+ self.binding = binding
378
+ if unicode is not None:
379
+ sympy_deprecation_warning(
380
+ """
381
+ The unicode argument to prettyForm is deprecated. Only the s
382
+ argument (the first positional argument) should be passed.
383
+ """,
384
+ deprecated_since_version="1.7",
385
+ active_deprecations_target="deprecated-pretty-printing-functions")
386
+ self._unicode = unicode or s
387
+
388
+ @property
389
+ def unicode(self):
390
+ sympy_deprecation_warning(
391
+ """
392
+ The prettyForm.unicode attribute is deprecated. Use the
393
+ prettyForm.s attribute instead.
394
+ """,
395
+ deprecated_since_version="1.7",
396
+ active_deprecations_target="deprecated-pretty-printing-functions")
397
+ return self._unicode
398
+
399
+ # Note: code to handle subtraction is in _print_Add
400
+
401
+ def __add__(self, *others):
402
+ """Make a pretty addition.
403
+ Addition of negative numbers is simplified.
404
+ """
405
+ arg = self
406
+ if arg.binding > prettyForm.NEG:
407
+ arg = stringPict(*arg.parens())
408
+ result = [arg]
409
+ for arg in others:
410
+ #add parentheses for weak binders
411
+ if arg.binding > prettyForm.NEG:
412
+ arg = stringPict(*arg.parens())
413
+ #use existing minus sign if available
414
+ if arg.binding != prettyForm.NEG:
415
+ result.append(' + ')
416
+ result.append(arg)
417
+ return prettyForm(binding=prettyForm.ADD, *stringPict.next(*result))
418
+
419
+ def __truediv__(self, den, slashed=False):
420
+ """Make a pretty division; stacked or slashed.
421
+ """
422
+ if slashed:
423
+ raise NotImplementedError("Can't do slashed fraction yet")
424
+ num = self
425
+ if num.binding == prettyForm.DIV:
426
+ num = stringPict(*num.parens())
427
+ if den.binding == prettyForm.DIV:
428
+ den = stringPict(*den.parens())
429
+
430
+ if num.binding==prettyForm.NEG:
431
+ num = num.right(" ")[0]
432
+
433
+ return prettyForm(binding=prettyForm.DIV, *stringPict.stack(
434
+ num,
435
+ stringPict.LINE,
436
+ den))
437
+
438
+ def __mul__(self, *others):
439
+ """Make a pretty multiplication.
440
+ Parentheses are needed around +, - and neg.
441
+ """
442
+ quantity = {
443
+ 'degree': "\N{DEGREE SIGN}"
444
+ }
445
+
446
+ if len(others) == 0:
447
+ return self # We aren't actually multiplying... So nothing to do here.
448
+
449
+ # add parens on args that need them
450
+ arg = self
451
+ if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG:
452
+ arg = stringPict(*arg.parens())
453
+ result = [arg]
454
+ for arg in others:
455
+ if arg.picture[0] not in quantity.values():
456
+ result.append(xsym('*'))
457
+ #add parentheses for weak binders
458
+ if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG:
459
+ arg = stringPict(*arg.parens())
460
+ result.append(arg)
461
+
462
+ len_res = len(result)
463
+ for i in range(len_res):
464
+ if i < len_res - 1 and result[i] == '-1' and result[i + 1] == xsym('*'):
465
+ # substitute -1 by -, like in -1*x -> -x
466
+ result.pop(i)
467
+ result.pop(i)
468
+ result.insert(i, '-')
469
+ if result[0][0] == '-':
470
+ # if there is a - sign in front of all
471
+ # This test was failing to catch a prettyForm.__mul__(prettyForm("-1", 0, 6)) being negative
472
+ bin = prettyForm.NEG
473
+ if result[0] == '-':
474
+ right = result[1]
475
+ if right.picture[right.baseline][0] == '-':
476
+ result[0] = '- '
477
+ else:
478
+ bin = prettyForm.MUL
479
+ return prettyForm(binding=bin, *stringPict.next(*result))
480
+
481
+ def __repr__(self):
482
+ return "prettyForm(%r,%d,%d)" % (
483
+ '\n'.join(self.picture),
484
+ self.baseline,
485
+ self.binding)
486
+
487
+ def __pow__(self, b):
488
+ """Make a pretty power.
489
+ """
490
+ a = self
491
+ use_inline_func_form = False
492
+ if b.binding == prettyForm.POW:
493
+ b = stringPict(*b.parens())
494
+ if a.binding > prettyForm.FUNC:
495
+ a = stringPict(*a.parens())
496
+ elif a.binding == prettyForm.FUNC:
497
+ # heuristic for when to use inline power
498
+ if b.height() > 1:
499
+ a = stringPict(*a.parens())
500
+ else:
501
+ use_inline_func_form = True
502
+
503
+ if use_inline_func_form:
504
+ # 2
505
+ # sin + + (x)
506
+ b.baseline = a.prettyFunc.baseline + b.height()
507
+ func = stringPict(*a.prettyFunc.right(b))
508
+ return prettyForm(*func.right(a.prettyArgs))
509
+ else:
510
+ # 2 <-- top
511
+ # (x+y) <-- bot
512
+ top = stringPict(*b.left(' '*a.width()))
513
+ bot = stringPict(*a.right(' '*b.width()))
514
+
515
+ return prettyForm(binding=prettyForm.POW, *bot.above(top))
516
+
517
+ simpleFunctions = ["sin", "cos", "tan"]
518
+
519
+ @staticmethod
520
+ def apply(function, *args):
521
+ """Functions of one or more variables.
522
+ """
523
+ if function in prettyForm.simpleFunctions:
524
+ #simple function: use only space if possible
525
+ assert len(
526
+ args) == 1, "Simple function %s must have 1 argument" % function
527
+ arg = args[0].__pretty__()
528
+ if arg.binding <= prettyForm.DIV:
529
+ #optimization: no parentheses necessary
530
+ return prettyForm(binding=prettyForm.FUNC, *arg.left(function + ' '))
531
+ argumentList = []
532
+ for arg in args:
533
+ argumentList.append(',')
534
+ argumentList.append(arg.__pretty__())
535
+ argumentList = stringPict(*stringPict.next(*argumentList[1:]))
536
+ argumentList = stringPict(*argumentList.parens())
537
+ return prettyForm(binding=prettyForm.ATOM, *argumentList.left(function))
.venv/lib/python3.13/site-packages/sympy/printing/pretty/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/printing/pretty/tests/test_pretty.py ADDED
The diff for this file is too large to render. See raw diff
 
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_fortran.py ADDED
@@ -0,0 +1,854 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.add import Add
2
+ from sympy.core.expr import Expr
3
+ from sympy.core.function import (Function, Lambda, diff)
4
+ from sympy.core.mod import Mod
5
+ from sympy.core import (Catalan, EulerGamma, GoldenRatio)
6
+ from sympy.core.numbers import (E, Float, I, Integer, Rational, pi)
7
+ from sympy.core.relational import Eq
8
+ from sympy.core.singleton import S
9
+ from sympy.core.symbol import (Dummy, symbols)
10
+ from sympy.functions.combinatorial.factorials import factorial
11
+ from sympy.functions.elementary.complexes import (conjugate, sign)
12
+ from sympy.functions.elementary.exponential import (exp, log)
13
+ from sympy.functions.elementary.miscellaneous import sqrt
14
+ from sympy.functions.elementary.piecewise import Piecewise
15
+ from sympy.functions.elementary.trigonometric import (atan2, cos, sin)
16
+ from sympy.functions.special.gamma_functions import gamma
17
+ from sympy.integrals.integrals import Integral
18
+ from sympy.sets.fancysets import Range
19
+
20
+ from sympy.codegen import For, Assignment, aug_assign
21
+ from sympy.codegen.ast import Declaration, Variable, float32, float64, \
22
+ value_const, real, bool_, While, FunctionPrototype, FunctionDefinition, \
23
+ integer, Return, Element
24
+ from sympy.core.expr import UnevaluatedExpr
25
+ from sympy.core.relational import Relational
26
+ from sympy.logic.boolalg import And, Or, Not, Equivalent, Xor
27
+ from sympy.matrices import Matrix, MatrixSymbol
28
+ from sympy.printing.fortran import fcode, FCodePrinter
29
+ from sympy.tensor import IndexedBase, Idx
30
+ from sympy.tensor.array.expressions import ArraySymbol, ArrayElement
31
+ from sympy.utilities.lambdify import implemented_function
32
+ from sympy.testing.pytest import raises
33
+
34
+
35
+ def test_UnevaluatedExpr():
36
+ p, q, r = symbols("p q r", real=True)
37
+ q_r = UnevaluatedExpr(q + r)
38
+ expr = abs(exp(p+q_r))
39
+ assert fcode(expr, source_format="free") == "exp(p + (q + r))"
40
+ x, y, z = symbols("x y z")
41
+ y_z = UnevaluatedExpr(y + z)
42
+ expr2 = abs(exp(x+y_z))
43
+ assert fcode(expr2, human=False)[2].lstrip() == "exp(re(x) + re(y + z))"
44
+ assert fcode(expr2, user_functions={"re": "realpart"}).lstrip() == "exp(realpart(x) + realpart(y + z))"
45
+
46
+
47
+ def test_printmethod():
48
+ x = symbols('x')
49
+
50
+ class nint(Function):
51
+ def _fcode(self, printer):
52
+ return "nint(%s)" % printer._print(self.args[0])
53
+ assert fcode(nint(x)) == " nint(x)"
54
+
55
+
56
+ def test_fcode_sign(): #issue 12267
57
+ x=symbols('x')
58
+ y=symbols('y', integer=True)
59
+ z=symbols('z', complex=True)
60
+ assert fcode(sign(x), standard=95, source_format='free') == "merge(0d0, dsign(1d0, x), x == 0d0)"
61
+ assert fcode(sign(y), standard=95, source_format='free') == "merge(0, isign(1, y), y == 0)"
62
+ assert fcode(sign(z), standard=95, source_format='free') == "merge(cmplx(0d0, 0d0), z/abs(z), abs(z) == 0d0)"
63
+ raises(NotImplementedError, lambda: fcode(sign(x)))
64
+
65
+
66
+ def test_fcode_Pow():
67
+ x, y = symbols('x,y')
68
+ n = symbols('n', integer=True)
69
+
70
+ assert fcode(x**3) == " x**3"
71
+ assert fcode(x**(y**3)) == " x**(y**3)"
72
+ assert fcode(1/(sin(x)*3.5)**(x - y**x)/(x**2 + y)) == \
73
+ " (3.5d0*sin(x))**(-x + y**x)/(x**2 + y)"
74
+ assert fcode(sqrt(x)) == ' sqrt(x)'
75
+ assert fcode(sqrt(n)) == ' sqrt(dble(n))'
76
+ assert fcode(x**0.5) == ' sqrt(x)'
77
+ assert fcode(sqrt(x)) == ' sqrt(x)'
78
+ assert fcode(sqrt(10)) == ' sqrt(10.0d0)'
79
+ assert fcode(x**-1.0) == ' 1d0/x'
80
+ assert fcode(x**-2.0, 'y', source_format='free') == 'y = x**(-2.0d0)' # 2823
81
+ assert fcode(x**Rational(3, 7)) == ' x**(3.0d0/7.0d0)'
82
+
83
+
84
+ def test_fcode_Rational():
85
+ x = symbols('x')
86
+ assert fcode(Rational(3, 7)) == " 3.0d0/7.0d0"
87
+ assert fcode(Rational(18, 9)) == " 2"
88
+ assert fcode(Rational(3, -7)) == " -3.0d0/7.0d0"
89
+ assert fcode(Rational(-3, -7)) == " 3.0d0/7.0d0"
90
+ assert fcode(x + Rational(3, 7)) == " x + 3.0d0/7.0d0"
91
+ assert fcode(Rational(3, 7)*x) == " (3.0d0/7.0d0)*x"
92
+
93
+
94
+ def test_fcode_Integer():
95
+ assert fcode(Integer(67)) == " 67"
96
+ assert fcode(Integer(-1)) == " -1"
97
+
98
+
99
+ def test_fcode_Float():
100
+ assert fcode(Float(42.0)) == " 42.0000000000000d0"
101
+ assert fcode(Float(-1e20)) == " -1.00000000000000d+20"
102
+
103
+
104
+ def test_fcode_functions():
105
+ x, y = symbols('x,y')
106
+ assert fcode(sin(x) ** cos(y)) == " sin(x)**cos(y)"
107
+ raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=66))
108
+ raises(NotImplementedError, lambda: fcode(x % y, standard=66))
109
+ raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=77))
110
+ raises(NotImplementedError, lambda: fcode(x % y, standard=77))
111
+ for standard in [90, 95, 2003, 2008]:
112
+ assert fcode(Mod(x, y), standard=standard) == " modulo(x, y)"
113
+ assert fcode(x % y, standard=standard) == " modulo(x, y)"
114
+
115
+
116
+ def test_case():
117
+ ob = FCodePrinter()
118
+ x,x_,x__,y,X,X_,Y = symbols('x,x_,x__,y,X,X_,Y')
119
+ assert fcode(exp(x_) + sin(x*y) + cos(X*Y)) == \
120
+ ' exp(x_) + sin(x*y) + cos(X__*Y_)'
121
+ assert fcode(exp(x__) + 2*x*Y*X_**Rational(7, 2)) == \
122
+ ' 2*X_**(7.0d0/2.0d0)*Y*x + exp(x__)'
123
+ assert fcode(exp(x_) + sin(x*y) + cos(X*Y), name_mangling=False) == \
124
+ ' exp(x_) + sin(x*y) + cos(X*Y)'
125
+ assert fcode(x - cos(X), name_mangling=False) == ' x - cos(X)'
126
+ assert ob.doprint(X*sin(x) + x_, assign_to='me') == ' me = X*sin(x_) + x__'
127
+ assert ob.doprint(X*sin(x), assign_to='mu') == ' mu = X*sin(x_)'
128
+ assert ob.doprint(x_, assign_to='ad') == ' ad = x__'
129
+ n, m = symbols('n,m', integer=True)
130
+ A = IndexedBase('A')
131
+ x = IndexedBase('x')
132
+ y = IndexedBase('y')
133
+ i = Idx('i', m)
134
+ I = Idx('I', n)
135
+ assert fcode(A[i, I]*x[I], assign_to=y[i], source_format='free') == (
136
+ "do i = 1, m\n"
137
+ " y(i) = 0\n"
138
+ "end do\n"
139
+ "do i = 1, m\n"
140
+ " do I_ = 1, n\n"
141
+ " y(i) = A(i, I_)*x(I_) + y(i)\n"
142
+ " end do\n"
143
+ "end do" )
144
+
145
+
146
+ #issue 6814
147
+ def test_fcode_functions_with_integers():
148
+ x= symbols('x')
149
+ log10_17 = log(10).evalf(17)
150
+ loglog10_17 = '0.8340324452479558d0'
151
+ assert fcode(x * log(10)) == " x*%sd0" % log10_17
152
+ assert fcode(x * log(10)) == " x*%sd0" % log10_17
153
+ assert fcode(x * log(S(10))) == " x*%sd0" % log10_17
154
+ assert fcode(log(S(10))) == " %sd0" % log10_17
155
+ assert fcode(exp(10)) == " %sd0" % exp(10).evalf(17)
156
+ assert fcode(x * log(log(10))) == " x*%s" % loglog10_17
157
+ assert fcode(x * log(log(S(10)))) == " x*%s" % loglog10_17
158
+
159
+
160
+ def test_fcode_NumberSymbol():
161
+ prec = 17
162
+ p = FCodePrinter()
163
+ assert fcode(Catalan) == ' parameter (Catalan = %sd0)\n Catalan' % Catalan.evalf(prec)
164
+ assert fcode(EulerGamma) == ' parameter (EulerGamma = %sd0)\n EulerGamma' % EulerGamma.evalf(prec)
165
+ assert fcode(E) == ' parameter (E = %sd0)\n E' % E.evalf(prec)
166
+ assert fcode(GoldenRatio) == ' parameter (GoldenRatio = %sd0)\n GoldenRatio' % GoldenRatio.evalf(prec)
167
+ assert fcode(pi) == ' parameter (pi = %sd0)\n pi' % pi.evalf(prec)
168
+ assert fcode(
169
+ pi, precision=5) == ' parameter (pi = %sd0)\n pi' % pi.evalf(5)
170
+ assert fcode(Catalan, human=False) == ({
171
+ (Catalan, p._print(Catalan.evalf(prec)))}, set(), ' Catalan')
172
+ assert fcode(EulerGamma, human=False) == ({(EulerGamma, p._print(
173
+ EulerGamma.evalf(prec)))}, set(), ' EulerGamma')
174
+ assert fcode(E, human=False) == (
175
+ {(E, p._print(E.evalf(prec)))}, set(), ' E')
176
+ assert fcode(GoldenRatio, human=False) == ({(GoldenRatio, p._print(
177
+ GoldenRatio.evalf(prec)))}, set(), ' GoldenRatio')
178
+ assert fcode(pi, human=False) == (
179
+ {(pi, p._print(pi.evalf(prec)))}, set(), ' pi')
180
+ assert fcode(pi, precision=5, human=False) == (
181
+ {(pi, p._print(pi.evalf(5)))}, set(), ' pi')
182
+
183
+
184
+ def test_fcode_complex():
185
+ assert fcode(I) == " cmplx(0,1)"
186
+ x = symbols('x')
187
+ assert fcode(4*I) == " cmplx(0,4)"
188
+ assert fcode(3 + 4*I) == " cmplx(3,4)"
189
+ assert fcode(3 + 4*I + x) == " cmplx(3,4) + x"
190
+ assert fcode(I*x) == " cmplx(0,1)*x"
191
+ assert fcode(3 + 4*I - x) == " cmplx(3,4) - x"
192
+ x = symbols('x', imaginary=True)
193
+ assert fcode(5*x) == " 5*x"
194
+ assert fcode(I*x) == " cmplx(0,1)*x"
195
+ assert fcode(3 + x) == " x + 3"
196
+
197
+
198
+ def test_implicit():
199
+ x, y = symbols('x,y')
200
+ assert fcode(sin(x)) == " sin(x)"
201
+ assert fcode(atan2(x, y)) == " atan2(x, y)"
202
+ assert fcode(conjugate(x)) == " conjg(x)"
203
+
204
+
205
+ def test_not_fortran():
206
+ x = symbols('x')
207
+ g = Function('g')
208
+ with raises(NotImplementedError):
209
+ fcode(gamma(x))
210
+ assert fcode(Integral(sin(x)), strict=False) == "C Not supported in Fortran:\nC Integral\n Integral(sin(x), x)"
211
+ with raises(NotImplementedError):
212
+ fcode(g(x))
213
+
214
+
215
+ def test_user_functions():
216
+ x = symbols('x')
217
+ assert fcode(sin(x), user_functions={"sin": "zsin"}) == " zsin(x)"
218
+ x = symbols('x')
219
+ assert fcode(
220
+ gamma(x), user_functions={"gamma": "mygamma"}) == " mygamma(x)"
221
+ g = Function('g')
222
+ assert fcode(g(x), user_functions={"g": "great"}) == " great(x)"
223
+ n = symbols('n', integer=True)
224
+ assert fcode(
225
+ factorial(n), user_functions={"factorial": "fct"}) == " fct(n)"
226
+
227
+
228
+ def test_inline_function():
229
+ x = symbols('x')
230
+ g = implemented_function('g', Lambda(x, 2*x))
231
+ assert fcode(g(x)) == " 2*x"
232
+ g = implemented_function('g', Lambda(x, 2*pi/x))
233
+ assert fcode(g(x)) == (
234
+ " parameter (pi = %sd0)\n"
235
+ " 2*pi/x"
236
+ ) % pi.evalf(17)
237
+ A = IndexedBase('A')
238
+ i = Idx('i', symbols('n', integer=True))
239
+ g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
240
+ assert fcode(g(A[i]), assign_to=A[i]) == (
241
+ " do i = 1, n\n"
242
+ " A(i) = (A(i) + 1)*(A(i) + 2)*A(i)\n"
243
+ " end do"
244
+ )
245
+
246
+
247
+ def test_assign_to():
248
+ x = symbols('x')
249
+ assert fcode(sin(x), assign_to="s") == " s = sin(x)"
250
+
251
+
252
+ def test_line_wrapping():
253
+ x, y = symbols('x,y')
254
+ assert fcode(((x + y)**10).expand(), assign_to="var") == (
255
+ " var = x**10 + 10*x**9*y + 45*x**8*y**2 + 120*x**7*y**3 + 210*x**6*\n"
256
+ " @ y**4 + 252*x**5*y**5 + 210*x**4*y**6 + 120*x**3*y**7 + 45*x**2*y\n"
257
+ " @ **8 + 10*x*y**9 + y**10"
258
+ )
259
+ e = [x**i for i in range(11)]
260
+ assert fcode(Add(*e)) == (
261
+ " x**10 + x**9 + x**8 + x**7 + x**6 + x**5 + x**4 + x**3 + x**2 + x\n"
262
+ " @ + 1"
263
+ )
264
+
265
+
266
+ def test_fcode_precedence():
267
+ x, y = symbols("x y")
268
+ assert fcode(And(x < y, y < x + 1), source_format="free") == \
269
+ "x < y .and. y < x + 1"
270
+ assert fcode(Or(x < y, y < x + 1), source_format="free") == \
271
+ "x < y .or. y < x + 1"
272
+ assert fcode(Xor(x < y, y < x + 1, evaluate=False),
273
+ source_format="free") == "x < y .neqv. y < x + 1"
274
+ assert fcode(Equivalent(x < y, y < x + 1), source_format="free") == \
275
+ "x < y .eqv. y < x + 1"
276
+
277
+
278
+ def test_fcode_Logical():
279
+ x, y, z = symbols("x y z")
280
+ # unary Not
281
+ assert fcode(Not(x), source_format="free") == ".not. x"
282
+ # binary And
283
+ assert fcode(And(x, y), source_format="free") == "x .and. y"
284
+ assert fcode(And(x, Not(y)), source_format="free") == "x .and. .not. y"
285
+ assert fcode(And(Not(x), y), source_format="free") == "y .and. .not. x"
286
+ assert fcode(And(Not(x), Not(y)), source_format="free") == \
287
+ ".not. x .and. .not. y"
288
+ assert fcode(Not(And(x, y), evaluate=False), source_format="free") == \
289
+ ".not. (x .and. y)"
290
+ # binary Or
291
+ assert fcode(Or(x, y), source_format="free") == "x .or. y"
292
+ assert fcode(Or(x, Not(y)), source_format="free") == "x .or. .not. y"
293
+ assert fcode(Or(Not(x), y), source_format="free") == "y .or. .not. x"
294
+ assert fcode(Or(Not(x), Not(y)), source_format="free") == \
295
+ ".not. x .or. .not. y"
296
+ assert fcode(Not(Or(x, y), evaluate=False), source_format="free") == \
297
+ ".not. (x .or. y)"
298
+ # mixed And/Or
299
+ assert fcode(And(Or(y, z), x), source_format="free") == "x .and. (y .or. z)"
300
+ assert fcode(And(Or(z, x), y), source_format="free") == "y .and. (x .or. z)"
301
+ assert fcode(And(Or(x, y), z), source_format="free") == "z .and. (x .or. y)"
302
+ assert fcode(Or(And(y, z), x), source_format="free") == "x .or. y .and. z"
303
+ assert fcode(Or(And(z, x), y), source_format="free") == "y .or. x .and. z"
304
+ assert fcode(Or(And(x, y), z), source_format="free") == "z .or. x .and. y"
305
+ # trinary And
306
+ assert fcode(And(x, y, z), source_format="free") == "x .and. y .and. z"
307
+ assert fcode(And(x, y, Not(z)), source_format="free") == \
308
+ "x .and. y .and. .not. z"
309
+ assert fcode(And(x, Not(y), z), source_format="free") == \
310
+ "x .and. z .and. .not. y"
311
+ assert fcode(And(Not(x), y, z), source_format="free") == \
312
+ "y .and. z .and. .not. x"
313
+ assert fcode(Not(And(x, y, z), evaluate=False), source_format="free") == \
314
+ ".not. (x .and. y .and. z)"
315
+ # trinary Or
316
+ assert fcode(Or(x, y, z), source_format="free") == "x .or. y .or. z"
317
+ assert fcode(Or(x, y, Not(z)), source_format="free") == \
318
+ "x .or. y .or. .not. z"
319
+ assert fcode(Or(x, Not(y), z), source_format="free") == \
320
+ "x .or. z .or. .not. y"
321
+ assert fcode(Or(Not(x), y, z), source_format="free") == \
322
+ "y .or. z .or. .not. x"
323
+ assert fcode(Not(Or(x, y, z), evaluate=False), source_format="free") == \
324
+ ".not. (x .or. y .or. z)"
325
+
326
+
327
+ def test_fcode_Xlogical():
328
+ x, y, z = symbols("x y z")
329
+ # binary Xor
330
+ assert fcode(Xor(x, y, evaluate=False), source_format="free") == \
331
+ "x .neqv. y"
332
+ assert fcode(Xor(x, Not(y), evaluate=False), source_format="free") == \
333
+ "x .neqv. .not. y"
334
+ assert fcode(Xor(Not(x), y, evaluate=False), source_format="free") == \
335
+ "y .neqv. .not. x"
336
+ assert fcode(Xor(Not(x), Not(y), evaluate=False),
337
+ source_format="free") == ".not. x .neqv. .not. y"
338
+ assert fcode(Not(Xor(x, y, evaluate=False), evaluate=False),
339
+ source_format="free") == ".not. (x .neqv. y)"
340
+ # binary Equivalent
341
+ assert fcode(Equivalent(x, y), source_format="free") == "x .eqv. y"
342
+ assert fcode(Equivalent(x, Not(y)), source_format="free") == \
343
+ "x .eqv. .not. y"
344
+ assert fcode(Equivalent(Not(x), y), source_format="free") == \
345
+ "y .eqv. .not. x"
346
+ assert fcode(Equivalent(Not(x), Not(y)), source_format="free") == \
347
+ ".not. x .eqv. .not. y"
348
+ assert fcode(Not(Equivalent(x, y), evaluate=False),
349
+ source_format="free") == ".not. (x .eqv. y)"
350
+ # mixed And/Equivalent
351
+ assert fcode(Equivalent(And(y, z), x), source_format="free") == \
352
+ "x .eqv. y .and. z"
353
+ assert fcode(Equivalent(And(z, x), y), source_format="free") == \
354
+ "y .eqv. x .and. z"
355
+ assert fcode(Equivalent(And(x, y), z), source_format="free") == \
356
+ "z .eqv. x .and. y"
357
+ assert fcode(And(Equivalent(y, z), x), source_format="free") == \
358
+ "x .and. (y .eqv. z)"
359
+ assert fcode(And(Equivalent(z, x), y), source_format="free") == \
360
+ "y .and. (x .eqv. z)"
361
+ assert fcode(And(Equivalent(x, y), z), source_format="free") == \
362
+ "z .and. (x .eqv. y)"
363
+ # mixed Or/Equivalent
364
+ assert fcode(Equivalent(Or(y, z), x), source_format="free") == \
365
+ "x .eqv. y .or. z"
366
+ assert fcode(Equivalent(Or(z, x), y), source_format="free") == \
367
+ "y .eqv. x .or. z"
368
+ assert fcode(Equivalent(Or(x, y), z), source_format="free") == \
369
+ "z .eqv. x .or. y"
370
+ assert fcode(Or(Equivalent(y, z), x), source_format="free") == \
371
+ "x .or. (y .eqv. z)"
372
+ assert fcode(Or(Equivalent(z, x), y), source_format="free") == \
373
+ "y .or. (x .eqv. z)"
374
+ assert fcode(Or(Equivalent(x, y), z), source_format="free") == \
375
+ "z .or. (x .eqv. y)"
376
+ # mixed Xor/Equivalent
377
+ assert fcode(Equivalent(Xor(y, z, evaluate=False), x),
378
+ source_format="free") == "x .eqv. (y .neqv. z)"
379
+ assert fcode(Equivalent(Xor(z, x, evaluate=False), y),
380
+ source_format="free") == "y .eqv. (x .neqv. z)"
381
+ assert fcode(Equivalent(Xor(x, y, evaluate=False), z),
382
+ source_format="free") == "z .eqv. (x .neqv. y)"
383
+ assert fcode(Xor(Equivalent(y, z), x, evaluate=False),
384
+ source_format="free") == "x .neqv. (y .eqv. z)"
385
+ assert fcode(Xor(Equivalent(z, x), y, evaluate=False),
386
+ source_format="free") == "y .neqv. (x .eqv. z)"
387
+ assert fcode(Xor(Equivalent(x, y), z, evaluate=False),
388
+ source_format="free") == "z .neqv. (x .eqv. y)"
389
+ # mixed And/Xor
390
+ assert fcode(Xor(And(y, z), x, evaluate=False), source_format="free") == \
391
+ "x .neqv. y .and. z"
392
+ assert fcode(Xor(And(z, x), y, evaluate=False), source_format="free") == \
393
+ "y .neqv. x .and. z"
394
+ assert fcode(Xor(And(x, y), z, evaluate=False), source_format="free") == \
395
+ "z .neqv. x .and. y"
396
+ assert fcode(And(Xor(y, z, evaluate=False), x), source_format="free") == \
397
+ "x .and. (y .neqv. z)"
398
+ assert fcode(And(Xor(z, x, evaluate=False), y), source_format="free") == \
399
+ "y .and. (x .neqv. z)"
400
+ assert fcode(And(Xor(x, y, evaluate=False), z), source_format="free") == \
401
+ "z .and. (x .neqv. y)"
402
+ # mixed Or/Xor
403
+ assert fcode(Xor(Or(y, z), x, evaluate=False), source_format="free") == \
404
+ "x .neqv. y .or. z"
405
+ assert fcode(Xor(Or(z, x), y, evaluate=False), source_format="free") == \
406
+ "y .neqv. x .or. z"
407
+ assert fcode(Xor(Or(x, y), z, evaluate=False), source_format="free") == \
408
+ "z .neqv. x .or. y"
409
+ assert fcode(Or(Xor(y, z, evaluate=False), x), source_format="free") == \
410
+ "x .or. (y .neqv. z)"
411
+ assert fcode(Or(Xor(z, x, evaluate=False), y), source_format="free") == \
412
+ "y .or. (x .neqv. z)"
413
+ assert fcode(Or(Xor(x, y, evaluate=False), z), source_format="free") == \
414
+ "z .or. (x .neqv. y)"
415
+ # trinary Xor
416
+ assert fcode(Xor(x, y, z, evaluate=False), source_format="free") == \
417
+ "x .neqv. y .neqv. z"
418
+ assert fcode(Xor(x, y, Not(z), evaluate=False), source_format="free") == \
419
+ "x .neqv. y .neqv. .not. z"
420
+ assert fcode(Xor(x, Not(y), z, evaluate=False), source_format="free") == \
421
+ "x .neqv. z .neqv. .not. y"
422
+ assert fcode(Xor(Not(x), y, z, evaluate=False), source_format="free") == \
423
+ "y .neqv. z .neqv. .not. x"
424
+
425
+
426
+ def test_fcode_Relational():
427
+ x, y = symbols("x y")
428
+ assert fcode(Relational(x, y, "=="), source_format="free") == "x == y"
429
+ assert fcode(Relational(x, y, "!="), source_format="free") == "x /= y"
430
+ assert fcode(Relational(x, y, ">="), source_format="free") == "x >= y"
431
+ assert fcode(Relational(x, y, "<="), source_format="free") == "x <= y"
432
+ assert fcode(Relational(x, y, ">"), source_format="free") == "x > y"
433
+ assert fcode(Relational(x, y, "<"), source_format="free") == "x < y"
434
+
435
+
436
+ def test_fcode_Piecewise():
437
+ x = symbols('x')
438
+ expr = Piecewise((x, x < 1), (x**2, True))
439
+ # Check that inline conditional (merge) fails if standard isn't 95+
440
+ raises(NotImplementedError, lambda: fcode(expr))
441
+ code = fcode(expr, standard=95)
442
+ expected = " merge(x, x**2, x < 1)"
443
+ assert code == expected
444
+ assert fcode(Piecewise((x, x < 1), (x**2, True)), assign_to="var") == (
445
+ " if (x < 1) then\n"
446
+ " var = x\n"
447
+ " else\n"
448
+ " var = x**2\n"
449
+ " end if"
450
+ )
451
+ a = cos(x)/x
452
+ b = sin(x)/x
453
+ for i in range(10):
454
+ a = diff(a, x)
455
+ b = diff(b, x)
456
+ expected = (
457
+ " if (x < 0) then\n"
458
+ " weird_name = -cos(x)/x + 10*sin(x)/x**2 + 90*cos(x)/x**3 - 720*\n"
459
+ " @ sin(x)/x**4 - 5040*cos(x)/x**5 + 30240*sin(x)/x**6 + 151200*cos(x\n"
460
+ " @ )/x**7 - 604800*sin(x)/x**8 - 1814400*cos(x)/x**9 + 3628800*sin(x\n"
461
+ " @ )/x**10 + 3628800*cos(x)/x**11\n"
462
+ " else\n"
463
+ " weird_name = -sin(x)/x - 10*cos(x)/x**2 + 90*sin(x)/x**3 + 720*\n"
464
+ " @ cos(x)/x**4 - 5040*sin(x)/x**5 - 30240*cos(x)/x**6 + 151200*sin(x\n"
465
+ " @ )/x**7 + 604800*cos(x)/x**8 - 1814400*sin(x)/x**9 - 3628800*cos(x\n"
466
+ " @ )/x**10 + 3628800*sin(x)/x**11\n"
467
+ " end if"
468
+ )
469
+ code = fcode(Piecewise((a, x < 0), (b, True)), assign_to="weird_name")
470
+ assert code == expected
471
+ code = fcode(Piecewise((x, x < 1), (x**2, x > 1), (sin(x), True)), standard=95)
472
+ expected = " merge(x, merge(x**2, sin(x), x > 1), x < 1)"
473
+ assert code == expected
474
+ # Check that Piecewise without a True (default) condition error
475
+ expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
476
+ raises(ValueError, lambda: fcode(expr))
477
+
478
+
479
+ def test_wrap_fortran():
480
+ # "########################################################################"
481
+ printer = FCodePrinter()
482
+ lines = [
483
+ "C This is a long comment on a single line that must be wrapped properly to produce nice output",
484
+ " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
485
+ " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
486
+ " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly",
487
+ " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
488
+ " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
489
+ " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
490
+ " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly",
491
+ " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
492
+ " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
493
+ " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
494
+ " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
495
+ " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly",
496
+ " this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly",
497
+ " this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly",
498
+ ]
499
+ wrapped_lines = printer._wrap_fortran(lines)
500
+ expected_lines = [
501
+ "C This is a long comment on a single line that must be wrapped",
502
+ "C properly to produce nice output",
503
+ " this = is + a + long + and + nasty + fortran + statement + that *",
504
+ " @ must + be + wrapped + properly",
505
+ " this = is + a + long + and + nasty + fortran + statement + that *",
506
+ " @ must + be + wrapped + properly",
507
+ " this = is + a + long + and + nasty + fortran + statement + that",
508
+ " @ * must + be + wrapped + properly",
509
+ " this = is + a + long + and + nasty + fortran + statement + that*",
510
+ " @ must + be + wrapped + properly",
511
+ " this = is + a + long + and + nasty + fortran + statement + that*",
512
+ " @ must + be + wrapped + properly",
513
+ " this = is + a + long + and + nasty + fortran + statement + that",
514
+ " @ *must + be + wrapped + properly",
515
+ " this = is + a + long + and + nasty + fortran + statement +",
516
+ " @ that*must + be + wrapped + properly",
517
+ " this = is + a + long + and + nasty + fortran + statement + that**",
518
+ " @ must + be + wrapped + properly",
519
+ " this = is + a + long + and + nasty + fortran + statement + that**",
520
+ " @ must + be + wrapped + properly",
521
+ " this = is + a + long + and + nasty + fortran + statement + that",
522
+ " @ **must + be + wrapped + properly",
523
+ " this = is + a + long + and + nasty + fortran + statement + that",
524
+ " @ **must + be + wrapped + properly",
525
+ " this = is + a + long + and + nasty + fortran + statement +",
526
+ " @ that**must + be + wrapped + properly",
527
+ " this = is + a + long + and + nasty + fortran + statement(that)/",
528
+ " @ must + be + wrapped + properly",
529
+ " this = is + a + long + and + nasty + fortran + statement(that)",
530
+ " @ /must + be + wrapped + properly",
531
+ ]
532
+ for line in wrapped_lines:
533
+ assert len(line) <= 72
534
+ for w, e in zip(wrapped_lines, expected_lines):
535
+ assert w == e
536
+ assert len(wrapped_lines) == len(expected_lines)
537
+
538
+
539
+ def test_wrap_fortran_keep_d0():
540
+ printer = FCodePrinter()
541
+ lines = [
542
+ ' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0',
543
+ ' this_variable_is_very_long_because_we_try_to_test_line_break =1.0d0',
544
+ ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
545
+ ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
546
+ ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0',
547
+ ' this_variable_is_very_long_because_we_try_to_test_line_break = 10.0d0'
548
+ ]
549
+ expected = [
550
+ ' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0',
551
+ ' this_variable_is_very_long_because_we_try_to_test_line_break =',
552
+ ' @ 1.0d0',
553
+ ' this_variable_is_very_long_because_we_try_to_test_line_break =',
554
+ ' @ 1.0d0',
555
+ ' this_variable_is_very_long_because_we_try_to_test_line_break =',
556
+ ' @ 1.0d0',
557
+ ' this_variable_is_very_long_because_we_try_to_test_line_break =',
558
+ ' @ 1.0d0',
559
+ ' this_variable_is_very_long_because_we_try_to_test_line_break =',
560
+ ' @ 10.0d0'
561
+ ]
562
+ assert printer._wrap_fortran(lines) == expected
563
+
564
+
565
+ def test_settings():
566
+ raises(TypeError, lambda: fcode(S(4), method="garbage"))
567
+
568
+
569
+ def test_free_form_code_line():
570
+ x, y = symbols('x,y')
571
+ assert fcode(cos(x) + sin(y), source_format='free') == "sin(y) + cos(x)"
572
+
573
+
574
+ def test_free_form_continuation_line():
575
+ x, y = symbols('x,y')
576
+ result = fcode(((cos(x) + sin(y))**(7)).expand(), source_format='free')
577
+ expected = (
578
+ 'sin(y)**7 + 7*sin(y)**6*cos(x) + 21*sin(y)**5*cos(x)**2 + 35*sin(y)**4* &\n'
579
+ ' cos(x)**3 + 35*sin(y)**3*cos(x)**4 + 21*sin(y)**2*cos(x)**5 + 7* &\n'
580
+ ' sin(y)*cos(x)**6 + cos(x)**7'
581
+ )
582
+ assert result == expected
583
+
584
+
585
+ def test_free_form_comment_line():
586
+ printer = FCodePrinter({'source_format': 'free'})
587
+ lines = [ "! This is a long comment on a single line that must be wrapped properly to produce nice output"]
588
+ expected = [
589
+ '! This is a long comment on a single line that must be wrapped properly',
590
+ '! to produce nice output']
591
+ assert printer._wrap_fortran(lines) == expected
592
+
593
+
594
+ def test_loops():
595
+ n, m = symbols('n,m', integer=True)
596
+ A = IndexedBase('A')
597
+ x = IndexedBase('x')
598
+ y = IndexedBase('y')
599
+ i = Idx('i', m)
600
+ j = Idx('j', n)
601
+
602
+ expected = (
603
+ 'do i = 1, m\n'
604
+ ' y(i) = 0\n'
605
+ 'end do\n'
606
+ 'do i = 1, m\n'
607
+ ' do j = 1, n\n'
608
+ ' y(i) = %(rhs)s\n'
609
+ ' end do\n'
610
+ 'end do'
611
+ )
612
+
613
+ code = fcode(A[i, j]*x[j], assign_to=y[i], source_format='free')
614
+ assert (code == expected % {'rhs': 'y(i) + A(i, j)*x(j)'} or
615
+ code == expected % {'rhs': 'y(i) + x(j)*A(i, j)'} or
616
+ code == expected % {'rhs': 'x(j)*A(i, j) + y(i)'} or
617
+ code == expected % {'rhs': 'A(i, j)*x(j) + y(i)'})
618
+
619
+
620
+ def test_dummy_loops():
621
+ i, m = symbols('i m', integer=True, cls=Dummy)
622
+ x = IndexedBase('x')
623
+ y = IndexedBase('y')
624
+ i = Idx(i, m)
625
+
626
+ expected = (
627
+ 'do i_%(icount)i = 1, m_%(mcount)i\n'
628
+ ' y(i_%(icount)i) = x(i_%(icount)i)\n'
629
+ 'end do'
630
+ ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
631
+ code = fcode(x[i], assign_to=y[i], source_format='free')
632
+ assert code == expected
633
+
634
+
635
+ def test_fcode_Indexed_without_looking_for_contraction():
636
+ len_y = 5
637
+ y = IndexedBase('y', shape=(len_y,))
638
+ x = IndexedBase('x', shape=(len_y,))
639
+ Dy = IndexedBase('Dy', shape=(len_y-1,))
640
+ i = Idx('i', len_y-1)
641
+ e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
642
+ code0 = fcode(e.rhs, assign_to=e.lhs, contract=False)
643
+ assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
644
+
645
+
646
+ def test_element_like_objects():
647
+ len_y = 5
648
+ y = ArraySymbol('y', shape=(len_y,))
649
+ x = ArraySymbol('x', shape=(len_y,))
650
+ Dy = ArraySymbol('Dy', shape=(len_y-1,))
651
+ i = Idx('i', len_y-1)
652
+ e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
653
+ code0 = fcode(Assignment(e.lhs, e.rhs))
654
+ assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
655
+
656
+ class ElementExpr(Element, Expr):
657
+ pass
658
+
659
+ e = e.subs((a, ElementExpr(a.name, a.indices)) for a in e.atoms(ArrayElement) )
660
+ e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
661
+ code0 = fcode(Assignment(e.lhs, e.rhs))
662
+ assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))')
663
+
664
+
665
+ def test_derived_classes():
666
+ class MyFancyFCodePrinter(FCodePrinter):
667
+ _default_settings = FCodePrinter._default_settings.copy()
668
+
669
+ printer = MyFancyFCodePrinter()
670
+ x = symbols('x')
671
+ assert printer.doprint(sin(x), "bork") == " bork = sin(x)"
672
+
673
+
674
+ def test_indent():
675
+ codelines = (
676
+ 'subroutine test(a)\n'
677
+ 'integer :: a, i, j\n'
678
+ '\n'
679
+ 'do\n'
680
+ 'do \n'
681
+ 'do j = 1, 5\n'
682
+ 'if (a>b) then\n'
683
+ 'if(b>0) then\n'
684
+ 'a = 3\n'
685
+ 'donot_indent_me = 2\n'
686
+ 'do_not_indent_me_either = 2\n'
687
+ 'ifIam_indented_something_went_wrong = 2\n'
688
+ 'if_I_am_indented_something_went_wrong = 2\n'
689
+ 'end should not be unindented here\n'
690
+ 'end if\n'
691
+ 'endif\n'
692
+ 'end do\n'
693
+ 'end do\n'
694
+ 'enddo\n'
695
+ 'end subroutine\n'
696
+ '\n'
697
+ 'subroutine test2(a)\n'
698
+ 'integer :: a\n'
699
+ 'do\n'
700
+ 'a = a + 1\n'
701
+ 'end do \n'
702
+ 'end subroutine\n'
703
+ )
704
+ expected = (
705
+ 'subroutine test(a)\n'
706
+ 'integer :: a, i, j\n'
707
+ '\n'
708
+ 'do\n'
709
+ ' do \n'
710
+ ' do j = 1, 5\n'
711
+ ' if (a>b) then\n'
712
+ ' if(b>0) then\n'
713
+ ' a = 3\n'
714
+ ' donot_indent_me = 2\n'
715
+ ' do_not_indent_me_either = 2\n'
716
+ ' ifIam_indented_something_went_wrong = 2\n'
717
+ ' if_I_am_indented_something_went_wrong = 2\n'
718
+ ' end should not be unindented here\n'
719
+ ' end if\n'
720
+ ' endif\n'
721
+ ' end do\n'
722
+ ' end do\n'
723
+ 'enddo\n'
724
+ 'end subroutine\n'
725
+ '\n'
726
+ 'subroutine test2(a)\n'
727
+ 'integer :: a\n'
728
+ 'do\n'
729
+ ' a = a + 1\n'
730
+ 'end do \n'
731
+ 'end subroutine\n'
732
+ )
733
+ p = FCodePrinter({'source_format': 'free'})
734
+ result = p.indent_code(codelines)
735
+ assert result == expected
736
+
737
+ def test_Matrix_printing():
738
+ x, y, z = symbols('x,y,z')
739
+ # Test returning a Matrix
740
+ mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
741
+ A = MatrixSymbol('A', 3, 1)
742
+ assert fcode(mat, A) == (
743
+ " A(1, 1) = x*y\n"
744
+ " if (y > 0) then\n"
745
+ " A(2, 1) = x + 2\n"
746
+ " else\n"
747
+ " A(2, 1) = y\n"
748
+ " end if\n"
749
+ " A(3, 1) = sin(z)")
750
+ # Test using MatrixElements in expressions
751
+ expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
752
+ assert fcode(expr, standard=95) == (
753
+ " merge(2*A(3, 1), A(3, 1), x > 0) + sin(A(2, 1)) + A(1, 1)")
754
+ # Test using MatrixElements in a Matrix
755
+ q = MatrixSymbol('q', 5, 1)
756
+ M = MatrixSymbol('M', 3, 3)
757
+ m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
758
+ [q[1,0] + q[2,0], q[3, 0], 5],
759
+ [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
760
+ assert fcode(m, M) == (
761
+ " M(1, 1) = sin(q(2, 1))\n"
762
+ " M(2, 1) = q(2, 1) + q(3, 1)\n"
763
+ " M(3, 1) = 2*q(5, 1)/q(2, 1)\n"
764
+ " M(1, 2) = 0\n"
765
+ " M(2, 2) = q(4, 1)\n"
766
+ " M(3, 2) = sqrt(q(1, 1)) + 4\n"
767
+ " M(1, 3) = cos(q(3, 1))\n"
768
+ " M(2, 3) = 5\n"
769
+ " M(3, 3) = 0")
770
+
771
+
772
+ def test_fcode_For():
773
+ x, y = symbols('x y')
774
+
775
+ f = For(x, Range(0, 10, 2), [Assignment(y, x * y)])
776
+ sol = fcode(f)
777
+ assert sol == (" do x = 0, 9, 2\n"
778
+ " y = x*y\n"
779
+ " end do")
780
+
781
+
782
+ def test_fcode_Declaration():
783
+ def check(expr, ref, **kwargs):
784
+ assert fcode(expr, standard=95, source_format='free', **kwargs) == ref
785
+
786
+ i = symbols('i', integer=True)
787
+ var1 = Variable.deduced(i)
788
+ dcl1 = Declaration(var1)
789
+ check(dcl1, "integer*4 :: i")
790
+
791
+
792
+ x, y = symbols('x y')
793
+ var2 = Variable(x, float32, value=42, attrs={value_const})
794
+ dcl2b = Declaration(var2)
795
+ check(dcl2b, 'real*4, parameter :: x = 42')
796
+
797
+ var3 = Variable(y, type=bool_)
798
+ dcl3 = Declaration(var3)
799
+ check(dcl3, 'logical :: y')
800
+
801
+ check(float32, "real*4")
802
+ check(float64, "real*8")
803
+ check(real, "real*4", type_aliases={real: float32})
804
+ check(real, "real*8", type_aliases={real: float64})
805
+
806
+
807
+ def test_MatrixElement_printing():
808
+ # test cases for issue #11821
809
+ A = MatrixSymbol("A", 1, 3)
810
+ B = MatrixSymbol("B", 1, 3)
811
+ C = MatrixSymbol("C", 1, 3)
812
+
813
+ assert(fcode(A[0, 0]) == " A(1, 1)")
814
+ assert(fcode(3 * A[0, 0]) == " 3*A(1, 1)")
815
+
816
+ F = C[0, 0].subs(C, A - B)
817
+ assert(fcode(F) == " (A - B)(1, 1)")
818
+
819
+
820
+ def test_aug_assign():
821
+ x = symbols('x')
822
+ assert fcode(aug_assign(x, '+', 1), source_format='free') == 'x = x + 1'
823
+
824
+
825
+ def test_While():
826
+ x = symbols('x')
827
+ assert fcode(While(abs(x) > 1, [aug_assign(x, '-', 1)]), source_format='free') == (
828
+ 'do while (abs(x) > 1)\n'
829
+ ' x = x - 1\n'
830
+ 'end do'
831
+ )
832
+
833
+
834
+ def test_FunctionPrototype_print():
835
+ x = symbols('x')
836
+ n = symbols('n', integer=True)
837
+ vx = Variable(x, type=real)
838
+ vn = Variable(n, type=integer)
839
+ fp1 = FunctionPrototype(real, 'power', [vx, vn])
840
+ # Should be changed to proper test once multi-line generation is working
841
+ # see https://github.com/sympy/sympy/issues/15824
842
+ raises(NotImplementedError, lambda: fcode(fp1))
843
+
844
+
845
+ def test_FunctionDefinition_print():
846
+ x = symbols('x')
847
+ n = symbols('n', integer=True)
848
+ vx = Variable(x, type=real)
849
+ vn = Variable(n, type=integer)
850
+ body = [Assignment(x, x**n), Return(x)]
851
+ fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
852
+ # Should be changed to proper test once multi-line generation is working
853
+ # see https://github.com/sympy/sympy/issues/15824
854
+ raises(NotImplementedError, lambda: fcode(fd1))
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_llvmjit.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.external import import_module
2
+ from sympy.testing.pytest import raises
3
+ import ctypes
4
+
5
+
6
+ if import_module('llvmlite'):
7
+ import sympy.printing.llvmjitcode as g
8
+ else:
9
+ disabled = True
10
+
11
+ import sympy
12
+ from sympy.abc import a, b, n
13
+
14
+
15
+ # copied from numpy.isclose documentation
16
+ def isclose(a, b):
17
+ rtol = 1e-5
18
+ atol = 1e-8
19
+ return abs(a-b) <= atol + rtol*abs(b)
20
+
21
+
22
+ def test_simple_expr():
23
+ e = a + 1.0
24
+ f = g.llvm_callable([a], e)
25
+ res = float(e.subs({a: 4.0}).evalf())
26
+ jit_res = f(4.0)
27
+
28
+ assert isclose(jit_res, res)
29
+
30
+
31
+ def test_two_arg():
32
+ e = 4.0*a + b + 3.0
33
+ f = g.llvm_callable([a, b], e)
34
+ res = float(e.subs({a: 4.0, b: 3.0}).evalf())
35
+ jit_res = f(4.0, 3.0)
36
+
37
+ assert isclose(jit_res, res)
38
+
39
+
40
+ def test_func():
41
+ e = 4.0*sympy.exp(-a)
42
+ f = g.llvm_callable([a], e)
43
+ res = float(e.subs({a: 1.5}).evalf())
44
+ jit_res = f(1.5)
45
+
46
+ assert isclose(jit_res, res)
47
+
48
+
49
+ def test_two_func():
50
+ e = 4.0*sympy.exp(-a) + sympy.exp(b)
51
+ f = g.llvm_callable([a, b], e)
52
+ res = float(e.subs({a: 1.5, b: 2.0}).evalf())
53
+ jit_res = f(1.5, 2.0)
54
+
55
+ assert isclose(jit_res, res)
56
+
57
+
58
+ def test_two_sqrt():
59
+ e = 4.0*sympy.sqrt(a) + sympy.sqrt(b)
60
+ f = g.llvm_callable([a, b], e)
61
+ res = float(e.subs({a: 1.5, b: 2.0}).evalf())
62
+ jit_res = f(1.5, 2.0)
63
+
64
+ assert isclose(jit_res, res)
65
+
66
+
67
+ def test_two_pow():
68
+ e = a**1.5 + b**7
69
+ f = g.llvm_callable([a, b], e)
70
+ res = float(e.subs({a: 1.5, b: 2.0}).evalf())
71
+ jit_res = f(1.5, 2.0)
72
+
73
+ assert isclose(jit_res, res)
74
+
75
+
76
+ def test_callback():
77
+ e = a + 1.2
78
+ f = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
79
+ m = ctypes.c_int(1)
80
+ array_type = ctypes.c_double * 1
81
+ inp = {a: 2.2}
82
+ array = array_type(inp[a])
83
+ jit_res = f(m, array)
84
+
85
+ res = float(e.subs(inp).evalf())
86
+
87
+ assert isclose(jit_res, res)
88
+
89
+
90
+ def test_callback_cubature():
91
+ e = a + 1.2
92
+ f = g.llvm_callable([a], e, callback_type='cubature')
93
+ m = ctypes.c_int(1)
94
+ array_type = ctypes.c_double * 1
95
+ inp = {a: 2.2}
96
+ array = array_type(inp[a])
97
+ out_array = array_type(0.0)
98
+ jit_ret = f(m, array, None, m, out_array)
99
+
100
+ assert jit_ret == 0
101
+
102
+ res = float(e.subs(inp).evalf())
103
+
104
+ assert isclose(out_array[0], res)
105
+
106
+
107
+ def test_callback_two():
108
+ e = 3*a*b
109
+ f = g.llvm_callable([a, b], e, callback_type='scipy.integrate.test')
110
+ m = ctypes.c_int(2)
111
+ array_type = ctypes.c_double * 2
112
+ inp = {a: 0.2, b: 1.7}
113
+ array = array_type(inp[a], inp[b])
114
+ jit_res = f(m, array)
115
+
116
+ res = float(e.subs(inp).evalf())
117
+
118
+ assert isclose(jit_res, res)
119
+
120
+
121
+ def test_callback_alt_two():
122
+ d = sympy.IndexedBase('d')
123
+ e = 3*d[0]*d[1]
124
+ f = g.llvm_callable([n, d], e, callback_type='scipy.integrate.test')
125
+ m = ctypes.c_int(2)
126
+ array_type = ctypes.c_double * 2
127
+ inp = {d[0]: 0.2, d[1]: 1.7}
128
+ array = array_type(inp[d[0]], inp[d[1]])
129
+ jit_res = f(m, array)
130
+
131
+ res = float(e.subs(inp).evalf())
132
+
133
+ assert isclose(jit_res, res)
134
+
135
+
136
+ def test_multiple_statements():
137
+ # Match return from CSE
138
+ e = [[(b, 4.0*a)], [b + 5]]
139
+ f = g.llvm_callable([a], e)
140
+ b_val = e[0][0][1].subs({a: 1.5})
141
+ res = float(e[1][0].subs({b: b_val}).evalf())
142
+ jit_res = f(1.5)
143
+ assert isclose(jit_res, res)
144
+
145
+ f_callback = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
146
+ m = ctypes.c_int(1)
147
+ array_type = ctypes.c_double * 1
148
+ array = array_type(1.5)
149
+ jit_callback_res = f_callback(m, array)
150
+ assert isclose(jit_callback_res, res)
151
+
152
+
153
+ def test_cse():
154
+ e = a*a + b*b + sympy.exp(-a*a - b*b)
155
+ e2 = sympy.cse(e)
156
+ f = g.llvm_callable([a, b], e2)
157
+ res = float(e.subs({a: 2.3, b: 0.1}).evalf())
158
+ jit_res = f(2.3, 0.1)
159
+
160
+ assert isclose(jit_res, res)
161
+
162
+
163
+ def eval_cse(e, sub_dict):
164
+ tmp_dict = {}
165
+ for tmp_name, tmp_expr in e[0]:
166
+ e2 = tmp_expr.subs(sub_dict)
167
+ e3 = e2.subs(tmp_dict)
168
+ tmp_dict[tmp_name] = e3
169
+ return [e.subs(sub_dict).subs(tmp_dict) for e in e[1]]
170
+
171
+
172
+ def test_cse_multiple():
173
+ e1 = a*a
174
+ e2 = a*a + b*b
175
+ e3 = sympy.cse([e1, e2])
176
+
177
+ raises(NotImplementedError,
178
+ lambda: g.llvm_callable([a, b], e3, callback_type='scipy.integrate'))
179
+
180
+ f = g.llvm_callable([a, b], e3)
181
+ jit_res = f(0.1, 1.5)
182
+ assert len(jit_res) == 2
183
+ res = eval_cse(e3, {a: 0.1, b: 1.5})
184
+ assert isclose(res[0], jit_res[0])
185
+ assert isclose(res[1], jit_res[1])
186
+
187
+
188
+ def test_callback_cubature_multiple():
189
+ e1 = a*a
190
+ e2 = a*a + b*b
191
+ e3 = sympy.cse([e1, e2, 4*e2])
192
+ f = g.llvm_callable([a, b], e3, callback_type='cubature')
193
+
194
+ # Number of input variables
195
+ ndim = 2
196
+ # Number of output expression values
197
+ outdim = 3
198
+
199
+ m = ctypes.c_int(ndim)
200
+ fdim = ctypes.c_int(outdim)
201
+ array_type = ctypes.c_double * ndim
202
+ out_array_type = ctypes.c_double * outdim
203
+ inp = {a: 0.2, b: 1.5}
204
+ array = array_type(inp[a], inp[b])
205
+ out_array = out_array_type()
206
+ jit_ret = f(m, array, None, fdim, out_array)
207
+
208
+ assert jit_ret == 0
209
+
210
+ res = eval_cse(e3, inp)
211
+
212
+ assert isclose(out_array[0], res[0])
213
+ assert isclose(out_array[1], res[1])
214
+ assert isclose(out_array[2], res[2])
215
+
216
+
217
+ def test_symbol_not_found():
218
+ e = a*a + b
219
+ raises(LookupError, lambda: g.llvm_callable([a], e))
220
+
221
+
222
+ def test_bad_callback():
223
+ e = a
224
+ raises(ValueError, lambda: g.llvm_callable([a], e, callback_type='bad_callback'))
.venv/lib/python3.13/site-packages/sympy/printing/tests/test_rcode.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import (S, pi, oo, Symbol, symbols, Rational, Integer,
2
+ GoldenRatio, EulerGamma, Catalan, Lambda, Dummy)
3
+ from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
4
+ gamma, sign, Max, Min, factorial, beta)
5
+ from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
6
+ from sympy.sets import Range
7
+ from sympy.logic import ITE
8
+ from sympy.codegen import For, aug_assign, Assignment
9
+ from sympy.testing.pytest import raises
10
+ from sympy.printing.rcode import RCodePrinter
11
+ from sympy.utilities.lambdify import implemented_function
12
+ from sympy.tensor import IndexedBase, Idx
13
+ from sympy.matrices import Matrix, MatrixSymbol
14
+
15
+ from sympy.printing.rcode import rcode
16
+
17
+ x, y, z = symbols('x,y,z')
18
+
19
+
20
+ def test_printmethod():
21
+ class fabs(Abs):
22
+ def _rcode(self, printer):
23
+ return "abs(%s)" % printer._print(self.args[0])
24
+
25
+ assert rcode(fabs(x)) == "abs(x)"
26
+
27
+
28
+ def test_rcode_sqrt():
29
+ assert rcode(sqrt(x)) == "sqrt(x)"
30
+ assert rcode(x**0.5) == "sqrt(x)"
31
+ assert rcode(sqrt(x)) == "sqrt(x)"
32
+
33
+
34
+ def test_rcode_Pow():
35
+ assert rcode(x**3) == "x^3"
36
+ assert rcode(x**(y**3)) == "x^(y^3)"
37
+ g = implemented_function('g', Lambda(x, 2*x))
38
+ assert rcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
39
+ "(3.5*2*x)^(-x + y^x)/(x^2 + y)"
40
+ assert rcode(x**-1.0) == '1.0/x'
41
+ assert rcode(x**Rational(2, 3)) == 'x^(2.0/3.0)'
42
+ _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"),
43
+ (lambda base, exp: not exp.is_integer, "pow")]
44
+ assert rcode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)'
45
+ assert rcode(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 3.2)'
46
+
47
+
48
+ def test_rcode_Max():
49
+ # Test for gh-11926
50
+ assert rcode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))'
51
+
52
+
53
+ def test_rcode_constants_mathh():
54
+ assert rcode(exp(1)) == "exp(1)"
55
+ assert rcode(pi) == "pi"
56
+ assert rcode(oo) == "Inf"
57
+ assert rcode(-oo) == "-Inf"
58
+
59
+
60
+ def test_rcode_constants_other():
61
+ assert rcode(2*GoldenRatio) == "GoldenRatio = 1.61803398874989;\n2*GoldenRatio"
62
+ assert rcode(
63
+ 2*Catalan) == "Catalan = 0.915965594177219;\n2*Catalan"
64
+ assert rcode(2*EulerGamma) == "EulerGamma = 0.577215664901533;\n2*EulerGamma"
65
+
66
+
67
+ def test_rcode_Rational():
68
+ assert rcode(Rational(3, 7)) == "3.0/7.0"
69
+ assert rcode(Rational(18, 9)) == "2"
70
+ assert rcode(Rational(3, -7)) == "-3.0/7.0"
71
+ assert rcode(Rational(-3, -7)) == "3.0/7.0"
72
+ assert rcode(x + Rational(3, 7)) == "x + 3.0/7.0"
73
+ assert rcode(Rational(3, 7)*x) == "(3.0/7.0)*x"
74
+
75
+
76
+ def test_rcode_Integer():
77
+ assert rcode(Integer(67)) == "67"
78
+ assert rcode(Integer(-1)) == "-1"
79
+
80
+
81
+ def test_rcode_functions():
82
+ assert rcode(sin(x) ** cos(x)) == "sin(x)^cos(x)"
83
+ assert rcode(factorial(x) + gamma(y)) == "factorial(x) + gamma(y)"
84
+ assert rcode(beta(Min(x, y), Max(x, y))) == "beta(min(x, y), max(x, y))"
85
+
86
+
87
+ def test_rcode_inline_function():
88
+ x = symbols('x')
89
+ g = implemented_function('g', Lambda(x, 2*x))
90
+ assert rcode(g(x)) == "2*x"
91
+ g = implemented_function('g', Lambda(x, 2*x/Catalan))
92
+ assert rcode(
93
+ g(x)) == "Catalan = %s;\n2*x/Catalan" % Catalan.n()
94
+ A = IndexedBase('A')
95
+ i = Idx('i', symbols('n', integer=True))
96
+ g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
97
+ res=rcode(g(A[i]), assign_to=A[i])
98
+ ref=(
99
+ "for (i in 1:n){\n"
100
+ " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
101
+ "}"
102
+ )
103
+ assert res == ref
104
+
105
+
106
+ def test_rcode_exceptions():
107
+ assert rcode(ceiling(x)) == "ceiling(x)"
108
+ assert rcode(Abs(x)) == "abs(x)"
109
+ assert rcode(gamma(x)) == "gamma(x)"
110
+
111
+
112
+ def test_rcode_user_functions():
113
+ x = symbols('x', integer=False)
114
+ n = symbols('n', integer=True)
115
+ custom_functions = {
116
+ "ceiling": "myceil",
117
+ "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
118
+ }
119
+ assert rcode(ceiling(x), user_functions=custom_functions) == "myceil(x)"
120
+ assert rcode(Abs(x), user_functions=custom_functions) == "fabs(x)"
121
+ assert rcode(Abs(n), user_functions=custom_functions) == "abs(n)"
122
+
123
+
124
+ def test_rcode_boolean():
125
+ assert rcode(True) == "True"
126
+ assert rcode(S.true) == "True"
127
+ assert rcode(False) == "False"
128
+ assert rcode(S.false) == "False"
129
+ assert rcode(x & y) == "x & y"
130
+ assert rcode(x | y) == "x | y"
131
+ assert rcode(~x) == "!x"
132
+ assert rcode(x & y & z) == "x & y & z"
133
+ assert rcode(x | y | z) == "x | y | z"
134
+ assert rcode((x & y) | z) == "z | x & y"
135
+ assert rcode((x | y) & z) == "z & (x | y)"
136
+
137
+ def test_rcode_Relational():
138
+ assert rcode(Eq(x, y)) == "x == y"
139
+ assert rcode(Ne(x, y)) == "x != y"
140
+ assert rcode(Le(x, y)) == "x <= y"
141
+ assert rcode(Lt(x, y)) == "x < y"
142
+ assert rcode(Gt(x, y)) == "x > y"
143
+ assert rcode(Ge(x, y)) == "x >= y"
144
+
145
+
146
+ def test_rcode_Piecewise():
147
+ expr = Piecewise((x, x < 1), (x**2, True))
148
+ res=rcode(expr)
149
+ ref="ifelse(x < 1,x,x^2)"
150
+ assert res == ref
151
+ tau=Symbol("tau")
152
+ res=rcode(expr,tau)
153
+ ref="tau = ifelse(x < 1,x,x^2);"
154
+ assert res == ref
155
+
156
+ expr = 2*Piecewise((x, x < 1), (x**2, x<2), (x**3,True))
157
+ assert rcode(expr) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3))"
158
+ res = rcode(expr, assign_to='c')
159
+ assert res == "c = 2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3));"
160
+
161
+ # Check that Piecewise without a True (default) condition error
162
+ #expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
163
+ #raises(ValueError, lambda: rcode(expr))
164
+ expr = 2*Piecewise((x, x < 1), (x**2, x<2))
165
+ assert(rcode(expr))== "2*ifelse(x < 1,x,ifelse(x < 2,x^2,NA))"
166
+
167
+
168
+ def test_rcode_sinc():
169
+ from sympy.functions.elementary.trigonometric import sinc
170
+ expr = sinc(x)
171
+ res = rcode(expr)
172
+ ref = "(ifelse(x != 0,sin(x)/x,1))"
173
+ assert res == ref
174
+
175
+
176
+ def test_rcode_Piecewise_deep():
177
+ p = rcode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)))
178
+ assert p == "2*ifelse(x < 1,x,ifelse(x < 2,x + 1,x^2))"
179
+ expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
180
+ p = rcode(expr)
181
+ ref="x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1"
182
+ assert p == ref
183
+
184
+ ref="c = x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1;"
185
+ p = rcode(expr, assign_to='c')
186
+ assert p == ref
187
+
188
+
189
+ def test_rcode_ITE():
190
+ expr = ITE(x < 1, y, z)
191
+ p = rcode(expr)
192
+ ref="ifelse(x < 1,y,z)"
193
+ assert p == ref
194
+
195
+
196
+ def test_rcode_settings():
197
+ raises(TypeError, lambda: rcode(sin(x), method="garbage"))
198
+
199
+
200
+ def test_rcode_Indexed():
201
+ n, m, o = symbols('n m o', integer=True)
202
+ i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
203
+ p = RCodePrinter()
204
+ p._not_r = set()
205
+
206
+ x = IndexedBase('x')[j]
207
+ assert p._print_Indexed(x) == 'x[j]'
208
+ A = IndexedBase('A')[i, j]
209
+ assert p._print_Indexed(A) == 'A[i, j]'
210
+ B = IndexedBase('B')[i, j, k]
211
+ assert p._print_Indexed(B) == 'B[i, j, k]'
212
+
213
+ assert p._not_r == set()
214
+
215
+ def test_rcode_Indexed_without_looking_for_contraction():
216
+ len_y = 5
217
+ y = IndexedBase('y', shape=(len_y,))
218
+ x = IndexedBase('x', shape=(len_y,))
219
+ Dy = IndexedBase('Dy', shape=(len_y-1,))
220
+ i = Idx('i', len_y-1)
221
+ e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
222
+ code0 = rcode(e.rhs, assign_to=e.lhs, contract=False)
223
+ assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1)
224
+
225
+
226
+ def test_rcode_loops_matrix_vector():
227
+ n, m = symbols('n m', integer=True)
228
+ A = IndexedBase('A')
229
+ x = IndexedBase('x')
230
+ y = IndexedBase('y')
231
+ i = Idx('i', m)
232
+ j = Idx('j', n)
233
+
234
+ s = (
235
+ 'for (i in 1:m){\n'
236
+ ' y[i] = 0;\n'
237
+ '}\n'
238
+ 'for (i in 1:m){\n'
239
+ ' for (j in 1:n){\n'
240
+ ' y[i] = A[i, j]*x[j] + y[i];\n'
241
+ ' }\n'
242
+ '}'
243
+ )
244
+ c = rcode(A[i, j]*x[j], assign_to=y[i])
245
+ assert c == s
246
+
247
+
248
+ def test_dummy_loops():
249
+ # the following line could also be
250
+ # [Dummy(s, integer=True) for s in 'im']
251
+ # or [Dummy(integer=True) for s in 'im']
252
+ i, m = symbols('i m', integer=True, cls=Dummy)
253
+ x = IndexedBase('x')
254
+ y = IndexedBase('y')
255
+ i = Idx(i, m)
256
+
257
+ expected = (
258
+ 'for (i_%(icount)i in 1:m_%(mcount)i){\n'
259
+ ' y[i_%(icount)i] = x[i_%(icount)i];\n'
260
+ '}'
261
+ ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
262
+ code = rcode(x[i], assign_to=y[i])
263
+ assert code == expected
264
+
265
+
266
+ def test_rcode_loops_add():
267
+ n, m = symbols('n m', integer=True)
268
+ A = IndexedBase('A')
269
+ x = IndexedBase('x')
270
+ y = IndexedBase('y')
271
+ z = IndexedBase('z')
272
+ i = Idx('i', m)
273
+ j = Idx('j', n)
274
+
275
+ s = (
276
+ 'for (i in 1:m){\n'
277
+ ' y[i] = x[i] + z[i];\n'
278
+ '}\n'
279
+ 'for (i in 1:m){\n'
280
+ ' for (j in 1:n){\n'
281
+ ' y[i] = A[i, j]*x[j] + y[i];\n'
282
+ ' }\n'
283
+ '}'
284
+ )
285
+ c = rcode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
286
+ assert c == s
287
+
288
+
289
+ def test_rcode_loops_multiple_contractions():
290
+ n, m, o, p = symbols('n m o p', integer=True)
291
+ a = IndexedBase('a')
292
+ b = IndexedBase('b')
293
+ y = IndexedBase('y')
294
+ i = Idx('i', m)
295
+ j = Idx('j', n)
296
+ k = Idx('k', o)
297
+ l = Idx('l', p)
298
+
299
+ s = (
300
+ 'for (i in 1:m){\n'
301
+ ' y[i] = 0;\n'
302
+ '}\n'
303
+ 'for (i in 1:m){\n'
304
+ ' for (j in 1:n){\n'
305
+ ' for (k in 1:o){\n'
306
+ ' for (l in 1:p){\n'
307
+ ' y[i] = a[i, j, k, l]*b[j, k, l] + y[i];\n'
308
+ ' }\n'
309
+ ' }\n'
310
+ ' }\n'
311
+ '}'
312
+ )
313
+ c = rcode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
314
+ assert c == s
315
+
316
+
317
+ def test_rcode_loops_addfactor():
318
+ n, m, o, p = symbols('n m o p', integer=True)
319
+ a = IndexedBase('a')
320
+ b = IndexedBase('b')
321
+ c = IndexedBase('c')
322
+ y = IndexedBase('y')
323
+ i = Idx('i', m)
324
+ j = Idx('j', n)
325
+ k = Idx('k', o)
326
+ l = Idx('l', p)
327
+
328
+ s = (
329
+ 'for (i in 1:m){\n'
330
+ ' y[i] = 0;\n'
331
+ '}\n'
332
+ 'for (i in 1:m){\n'
333
+ ' for (j in 1:n){\n'
334
+ ' for (k in 1:o){\n'
335
+ ' for (l in 1:p){\n'
336
+ ' y[i] = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l] + y[i];\n'
337
+ ' }\n'
338
+ ' }\n'
339
+ ' }\n'
340
+ '}'
341
+ )
342
+ c = rcode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
343
+ assert c == s
344
+
345
+
346
+ def test_rcode_loops_multiple_terms():
347
+ n, m, o, p = symbols('n m o p', integer=True)
348
+ a = IndexedBase('a')
349
+ b = IndexedBase('b')
350
+ c = IndexedBase('c')
351
+ y = IndexedBase('y')
352
+ i = Idx('i', m)
353
+ j = Idx('j', n)
354
+ k = Idx('k', o)
355
+
356
+ s0 = (
357
+ 'for (i in 1:m){\n'
358
+ ' y[i] = 0;\n'
359
+ '}\n'
360
+ )
361
+ s1 = (
362
+ 'for (i in 1:m){\n'
363
+ ' for (j in 1:n){\n'
364
+ ' for (k in 1:o){\n'
365
+ ' y[i] = b[j]*b[k]*c[i, j, k] + y[i];\n'
366
+ ' }\n'
367
+ ' }\n'
368
+ '}\n'
369
+ )
370
+ s2 = (
371
+ 'for (i in 1:m){\n'
372
+ ' for (k in 1:o){\n'
373
+ ' y[i] = a[i, k]*b[k] + y[i];\n'
374
+ ' }\n'
375
+ '}\n'
376
+ )
377
+ s3 = (
378
+ 'for (i in 1:m){\n'
379
+ ' for (j in 1:n){\n'
380
+ ' y[i] = a[i, j]*b[j] + y[i];\n'
381
+ ' }\n'
382
+ '}\n'
383
+ )
384
+ c = rcode(
385
+ b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
386
+
387
+ ref={}
388
+ ref[0] = s0 + s1 + s2 + s3[:-1]
389
+ ref[1] = s0 + s1 + s3 + s2[:-1]
390
+ ref[2] = s0 + s2 + s1 + s3[:-1]
391
+ ref[3] = s0 + s2 + s3 + s1[:-1]
392
+ ref[4] = s0 + s3 + s1 + s2[:-1]
393
+ ref[5] = s0 + s3 + s2 + s1[:-1]
394
+
395
+ assert (c == ref[0] or
396
+ c == ref[1] or
397
+ c == ref[2] or
398
+ c == ref[3] or
399
+ c == ref[4] or
400
+ c == ref[5])
401
+
402
+
403
+ def test_dereference_printing():
404
+ expr = x + y + sin(z) + z
405
+ assert rcode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))"
406
+
407
+
408
+ def test_Matrix_printing():
409
+ # Test returning a Matrix
410
+ mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
411
+ A = MatrixSymbol('A', 3, 1)
412
+ p = rcode(mat, A)
413
+ assert p == (
414
+ "A[0] = x*y;\n"
415
+ "A[1] = ifelse(y > 0,x + 2,y);\n"
416
+ "A[2] = sin(z);")
417
+ # Test using MatrixElements in expressions
418
+ expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
419
+ p = rcode(expr)
420
+ assert p == ("ifelse(x > 0,2*A[2],A[2]) + sin(A[1]) + A[0]")
421
+ # Test using MatrixElements in a Matrix
422
+ q = MatrixSymbol('q', 5, 1)
423
+ M = MatrixSymbol('M', 3, 3)
424
+ m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
425
+ [q[1,0] + q[2,0], q[3, 0], 5],
426
+ [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
427
+ assert rcode(m, M) == (
428
+ "M[0] = sin(q[1]);\n"
429
+ "M[1] = 0;\n"
430
+ "M[2] = cos(q[2]);\n"
431
+ "M[3] = q[1] + q[2];\n"
432
+ "M[4] = q[3];\n"
433
+ "M[5] = 5;\n"
434
+ "M[6] = 2*q[4]/q[1];\n"
435
+ "M[7] = sqrt(q[0]) + 4;\n"
436
+ "M[8] = 0;")
437
+
438
+
439
+ def test_rcode_sgn():
440
+
441
+ expr = sign(x) * y
442
+ assert rcode(expr) == 'y*sign(x)'
443
+ p = rcode(expr, 'z')
444
+ assert p == 'z = y*sign(x);'
445
+
446
+ p = rcode(sign(2 * x + x**2) * x + x**2)
447
+ assert p == "x^2 + x*sign(x^2 + 2*x)"
448
+
449
+ expr = sign(cos(x))
450
+ p = rcode(expr)
451
+ assert p == 'sign(cos(x))'
452
+
453
+ def test_rcode_Assignment():
454
+ assert rcode(Assignment(x, y + z)) == 'x = y + z;'
455
+ assert rcode(aug_assign(x, '+', y + z)) == 'x += y + z;'
456
+
457
+
458
+ def test_rcode_For():
459
+ f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)])
460
+ sol = rcode(f)
461
+ assert sol == ("for(x in seq(from=0, to=9, by=2){\n"
462
+ " y *= x;\n"
463
+ "}")
464
+
465
+
466
+ def test_MatrixElement_printing():
467
+ # test cases for issue #11821
468
+ A = MatrixSymbol("A", 1, 3)
469
+ B = MatrixSymbol("B", 1, 3)
470
+ C = MatrixSymbol("C", 1, 3)
471
+
472
+ assert(rcode(A[0, 0]) == "A[0]")
473
+ assert(rcode(3 * A[0, 0]) == "3*A[0]")
474
+
475
+ F = C[0, 0].subs(C, A - B)
476
+ assert(rcode(F) == "(A - B)[0]")
.venv/lib/python3.13/site-packages/sympy/strategies/branch/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import traverse
2
+ from .core import (
3
+ condition, debug, multiplex, exhaust, notempty,
4
+ chain, onaction, sfilter, yieldify, do_one, identity)
5
+ from .tools import canon
6
+
7
+ __all__ = [
8
+ 'traverse',
9
+
10
+ 'condition', 'debug', 'multiplex', 'exhaust', 'notempty', 'chain',
11
+ 'onaction', 'sfilter', 'yieldify', 'do_one', 'identity',
12
+
13
+ 'canon',
14
+ ]
.venv/lib/python3.13/site-packages/sympy/strategies/branch/core.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Generic SymPy-Independent Strategies """
2
+
3
+
4
+ def identity(x):
5
+ yield x
6
+
7
+
8
+ def exhaust(brule):
9
+ """ Apply a branching rule repeatedly until it has no effect """
10
+ def exhaust_brl(expr):
11
+ seen = {expr}
12
+ for nexpr in brule(expr):
13
+ if nexpr not in seen:
14
+ seen.add(nexpr)
15
+ yield from exhaust_brl(nexpr)
16
+ if seen == {expr}:
17
+ yield expr
18
+ return exhaust_brl
19
+
20
+
21
+ def onaction(brule, fn):
22
+ def onaction_brl(expr):
23
+ for result in brule(expr):
24
+ if result != expr:
25
+ fn(brule, expr, result)
26
+ yield result
27
+ return onaction_brl
28
+
29
+
30
+ def debug(brule, file=None):
31
+ """ Print the input and output expressions at each rule application """
32
+ if not file:
33
+ from sys import stdout
34
+ file = stdout
35
+
36
+ def write(brl, expr, result):
37
+ file.write("Rule: %s\n" % brl.__name__)
38
+ file.write("In: %s\nOut: %s\n\n" % (expr, result))
39
+
40
+ return onaction(brule, write)
41
+
42
+
43
+ def multiplex(*brules):
44
+ """ Multiplex many branching rules into one """
45
+ def multiplex_brl(expr):
46
+ seen = set()
47
+ for brl in brules:
48
+ for nexpr in brl(expr):
49
+ if nexpr not in seen:
50
+ seen.add(nexpr)
51
+ yield nexpr
52
+ return multiplex_brl
53
+
54
+
55
+ def condition(cond, brule):
56
+ """ Only apply branching rule if condition is true """
57
+ def conditioned_brl(expr):
58
+ if cond(expr):
59
+ yield from brule(expr)
60
+ else:
61
+ pass
62
+ return conditioned_brl
63
+
64
+
65
+ def sfilter(pred, brule):
66
+ """ Yield only those results which satisfy the predicate """
67
+ def filtered_brl(expr):
68
+ yield from filter(pred, brule(expr))
69
+ return filtered_brl
70
+
71
+
72
+ def notempty(brule):
73
+ def notempty_brl(expr):
74
+ yielded = False
75
+ for nexpr in brule(expr):
76
+ yielded = True
77
+ yield nexpr
78
+ if not yielded:
79
+ yield expr
80
+ return notempty_brl
81
+
82
+
83
+ def do_one(*brules):
84
+ """ Execute one of the branching rules """
85
+ def do_one_brl(expr):
86
+ yielded = False
87
+ for brl in brules:
88
+ for nexpr in brl(expr):
89
+ yielded = True
90
+ yield nexpr
91
+ if yielded:
92
+ return
93
+ return do_one_brl
94
+
95
+
96
+ def chain(*brules):
97
+ """
98
+ Compose a sequence of brules so that they apply to the expr sequentially
99
+ """
100
+ def chain_brl(expr):
101
+ if not brules:
102
+ yield expr
103
+ return
104
+
105
+ head, tail = brules[0], brules[1:]
106
+ for nexpr in head(expr):
107
+ yield from chain(*tail)(nexpr)
108
+
109
+ return chain_brl
110
+
111
+
112
+ def yieldify(rl):
113
+ """ Turn a rule into a branching rule """
114
+ def brl(expr):
115
+ yield rl(expr)
116
+ return brl
.venv/lib/python3.13/site-packages/sympy/strategies/branch/tests/test_traverse.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.basic import Basic
2
+ from sympy.core.numbers import Integer
3
+ from sympy.core.singleton import S
4
+ from sympy.strategies.branch.traverse import top_down, sall
5
+ from sympy.strategies.branch.core import do_one, identity
6
+
7
+
8
+ def inc(x):
9
+ if isinstance(x, Integer):
10
+ yield x + 1
11
+
12
+
13
+ def test_top_down_easy():
14
+ expr = Basic(S(1), S(2))
15
+ expected = Basic(S(2), S(3))
16
+ brl = top_down(inc)
17
+
18
+ assert set(brl(expr)) == {expected}
19
+
20
+
21
+ def test_top_down_big_tree():
22
+ expr = Basic(S(1), Basic(S(2)), Basic(S(3), Basic(S(4)), S(5)))
23
+ expected = Basic(S(2), Basic(S(3)), Basic(S(4), Basic(S(5)), S(6)))
24
+ brl = top_down(inc)
25
+
26
+ assert set(brl(expr)) == {expected}
27
+
28
+
29
+ def test_top_down_harder_function():
30
+ def split5(x):
31
+ if x == 5:
32
+ yield x - 1
33
+ yield x + 1
34
+
35
+ expr = Basic(Basic(S(5), S(6)), S(1))
36
+ expected = {Basic(Basic(S(4), S(6)), S(1)), Basic(Basic(S(6), S(6)), S(1))}
37
+ brl = top_down(split5)
38
+
39
+ assert set(brl(expr)) == expected
40
+
41
+
42
+ def test_sall():
43
+ expr = Basic(S(1), S(2))
44
+ expected = Basic(S(2), S(3))
45
+ brl = sall(inc)
46
+
47
+ assert list(brl(expr)) == [expected]
48
+
49
+ expr = Basic(S(1), S(2), Basic(S(3), S(4)))
50
+ expected = Basic(S(2), S(3), Basic(S(3), S(4)))
51
+ brl = sall(do_one(inc, identity))
52
+
53
+ assert list(brl(expr)) == [expected]
.venv/lib/python3.13/site-packages/sympy/strategies/branch/tools.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .core import exhaust, multiplex
2
+ from .traverse import top_down
3
+
4
+
5
+ def canon(*rules):
6
+ """ Strategy for canonicalization
7
+
8
+ Apply each branching rule in a top-down fashion through the tree.
9
+ Multiplex through all branching rule traversals
10
+ Keep doing this until there is no change.
11
+ """
12
+ return exhaust(multiplex(*map(top_down, rules)))
.venv/lib/python3.13/site-packages/sympy/strategies/branch/traverse.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Branching Strategies to Traverse a Tree """
2
+ from itertools import product
3
+ from sympy.strategies.util import basic_fns
4
+ from .core import chain, identity, do_one
5
+
6
+
7
+ def top_down(brule, fns=basic_fns):
8
+ """ Apply a rule down a tree running it on the top nodes first """
9
+ return chain(do_one(brule, identity),
10
+ lambda expr: sall(top_down(brule, fns), fns)(expr))
11
+
12
+
13
+ def sall(brule, fns=basic_fns):
14
+ """ Strategic all - apply rule to args """
15
+ op, new, children, leaf = map(fns.get, ('op', 'new', 'children', 'leaf'))
16
+
17
+ def all_rl(expr):
18
+ if leaf(expr):
19
+ yield expr
20
+ else:
21
+ myop = op(expr)
22
+ argss = product(*map(brule, children(expr)))
23
+ for args in argss:
24
+ yield new(myop, *args)
25
+ return all_rl
.venv/lib/python3.13/site-packages/sympy/strategies/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/strategies/tests/test_core.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from sympy.core.singleton import S
3
+ from sympy.core.basic import Basic
4
+ from sympy.strategies.core import (
5
+ null_safe, exhaust, memoize, condition,
6
+ chain, tryit, do_one, debug, switch, minimize)
7
+ from io import StringIO
8
+
9
+
10
+ def posdec(x: int) -> int:
11
+ if x > 0:
12
+ return x - 1
13
+ return x
14
+
15
+
16
+ def inc(x: int) -> int:
17
+ return x + 1
18
+
19
+
20
+ def dec(x: int) -> int:
21
+ return x - 1
22
+
23
+
24
+ def test_null_safe():
25
+ def rl(expr: int) -> int | None:
26
+ if expr == 1:
27
+ return 2
28
+ return None
29
+
30
+ safe_rl = null_safe(rl)
31
+ assert rl(1) == safe_rl(1)
32
+ assert rl(3) is None
33
+ assert safe_rl(3) == 3
34
+
35
+
36
+ def test_exhaust():
37
+ sink = exhaust(posdec)
38
+ assert sink(5) == 0
39
+ assert sink(10) == 0
40
+
41
+
42
+ def test_memoize():
43
+ rl = memoize(posdec)
44
+ assert rl(5) == posdec(5)
45
+ assert rl(5) == posdec(5)
46
+ assert rl(-2) == posdec(-2)
47
+
48
+
49
+ def test_condition():
50
+ rl = condition(lambda x: x % 2 == 0, posdec)
51
+ assert rl(5) == 5
52
+ assert rl(4) == 3
53
+
54
+
55
+ def test_chain():
56
+ rl = chain(posdec, posdec)
57
+ assert rl(5) == 3
58
+ assert rl(1) == 0
59
+
60
+
61
+ def test_tryit():
62
+ def rl(expr: Basic) -> Basic:
63
+ assert False
64
+
65
+ safe_rl = tryit(rl, AssertionError)
66
+ assert safe_rl(S(1)) == S(1)
67
+
68
+
69
+ def test_do_one():
70
+ rl = do_one(posdec, posdec)
71
+ assert rl(5) == 4
72
+
73
+ def rl1(x: int) -> int:
74
+ if x == 1:
75
+ return 2
76
+ return x
77
+
78
+ def rl2(x: int) -> int:
79
+ if x == 2:
80
+ return 3
81
+ return x
82
+
83
+ rule = do_one(rl1, rl2)
84
+ assert rule(1) == 2
85
+ assert rule(rule(1)) == 3
86
+
87
+
88
+ def test_debug():
89
+ file = StringIO()
90
+ rl = debug(posdec, file)
91
+ rl(5)
92
+ log = file.getvalue()
93
+ file.close()
94
+
95
+ assert posdec.__name__ in log
96
+ assert '5' in log
97
+ assert '4' in log
98
+
99
+
100
+ def test_switch():
101
+ def key(x: int) -> int:
102
+ return x % 3
103
+
104
+ rl = switch(key, {0: inc, 1: dec})
105
+ assert rl(3) == 4
106
+ assert rl(4) == 3
107
+ assert rl(5) == 5
108
+
109
+
110
+ def test_minimize():
111
+ def key(x: int) -> int:
112
+ return -x
113
+
114
+ rl = minimize(inc, dec)
115
+ assert rl(4) == 3
116
+
117
+ rl = minimize(inc, dec, objective=key)
118
+ assert rl(4) == 5
.venv/lib/python3.13/site-packages/sympy/strategies/tests/test_rl.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.singleton import S
2
+ from sympy.strategies.rl import (
3
+ rm_id, glom, flatten, unpack, sort, distribute, subs, rebuild)
4
+ from sympy.core.basic import Basic
5
+ from sympy.core.add import Add
6
+ from sympy.core.mul import Mul
7
+ from sympy.core.symbol import symbols
8
+ from sympy.abc import x
9
+
10
+
11
+ def test_rm_id():
12
+ rmzeros = rm_id(lambda x: x == 0)
13
+ assert rmzeros(Basic(S(0), S(1))) == Basic(S(1))
14
+ assert rmzeros(Basic(S(0), S(0))) == Basic(S(0))
15
+ assert rmzeros(Basic(S(2), S(1))) == Basic(S(2), S(1))
16
+
17
+
18
+ def test_glom():
19
+ def key(x):
20
+ return x.as_coeff_Mul()[1]
21
+
22
+ def count(x):
23
+ return x.as_coeff_Mul()[0]
24
+
25
+ def newargs(cnt, arg):
26
+ return cnt * arg
27
+
28
+ rl = glom(key, count, newargs)
29
+
30
+ result = rl(Add(x, -x, 3 * x, 2, 3, evaluate=False))
31
+ expected = Add(3 * x, 5)
32
+ assert set(result.args) == set(expected.args)
33
+
34
+
35
+ def test_flatten():
36
+ assert flatten(Basic(S(1), S(2), Basic(S(3), S(4)))) == \
37
+ Basic(S(1), S(2), S(3), S(4))
38
+
39
+
40
+ def test_unpack():
41
+ assert unpack(Basic(S(2))) == 2
42
+ assert unpack(Basic(S(2), S(3))) == Basic(S(2), S(3))
43
+
44
+
45
+ def test_sort():
46
+ assert sort(str)(Basic(S(3), S(1), S(2))) == Basic(S(1), S(2), S(3))
47
+
48
+
49
+ def test_distribute():
50
+ class T1(Basic):
51
+ pass
52
+
53
+ class T2(Basic):
54
+ pass
55
+
56
+ distribute_t12 = distribute(T1, T2)
57
+ assert distribute_t12(T1(S(1), S(2), T2(S(3), S(4)), S(5))) == \
58
+ T2(T1(S(1), S(2), S(3), S(5)), T1(S(1), S(2), S(4), S(5)))
59
+ assert distribute_t12(T1(S(1), S(2), S(3))) == T1(S(1), S(2), S(3))
60
+
61
+
62
+ def test_distribute_add_mul():
63
+ x, y = symbols('x, y')
64
+ expr = Mul(2, Add(x, y), evaluate=False)
65
+ expected = Add(Mul(2, x), Mul(2, y))
66
+ distribute_mul = distribute(Mul, Add)
67
+ assert distribute_mul(expr) == expected
68
+
69
+
70
+ def test_subs():
71
+ rl = subs(1, 2)
72
+ assert rl(1) == 2
73
+ assert rl(3) == 3
74
+
75
+
76
+ def test_rebuild():
77
+ expr = Basic.__new__(Add, S(1), S(2))
78
+ assert rebuild(expr) == 3
.venv/lib/python3.13/site-packages/sympy/strategies/tests/test_tools.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.strategies.tools import subs, typed
2
+ from sympy.strategies.rl import rm_id
3
+ from sympy.core.basic import Basic
4
+ from sympy.core.singleton import S
5
+
6
+
7
+ def test_subs():
8
+ from sympy.core.symbol import symbols
9
+ a, b, c, d, e, f = symbols('a,b,c,d,e,f')
10
+ mapping = {a: d, d: a, Basic(e): Basic(f)}
11
+ expr = Basic(a, Basic(b, c), Basic(d, Basic(e)))
12
+ result = Basic(d, Basic(b, c), Basic(a, Basic(f)))
13
+ assert subs(mapping)(expr) == result
14
+
15
+
16
+ def test_subs_empty():
17
+ assert subs({})(Basic(S(1), S(2))) == Basic(S(1), S(2))
18
+
19
+
20
+ def test_typed():
21
+ class A(Basic):
22
+ pass
23
+
24
+ class B(Basic):
25
+ pass
26
+
27
+ rmzeros = rm_id(lambda x: x == S(0))
28
+ rmones = rm_id(lambda x: x == S(1))
29
+ remove_something = typed({A: rmzeros, B: rmones})
30
+
31
+ assert remove_something(A(S(0), S(1))) == A(S(1))
32
+ assert remove_something(B(S(0), S(1))) == B(S(0))
.venv/lib/python3.13/site-packages/sympy/strategies/tests/test_traverse.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.strategies.traverse import (
2
+ top_down, bottom_up, sall, top_down_once, bottom_up_once, basic_fns)
3
+ from sympy.strategies.rl import rebuild
4
+ from sympy.strategies.util import expr_fns
5
+ from sympy.core.add import Add
6
+ from sympy.core.basic import Basic
7
+ from sympy.core.numbers import Integer
8
+ from sympy.core.singleton import S
9
+ from sympy.core.symbol import Str, Symbol
10
+ from sympy.abc import x, y, z
11
+
12
+
13
+ def zero_symbols(expression):
14
+ return S.Zero if isinstance(expression, Symbol) else expression
15
+
16
+
17
+ def test_sall():
18
+ zero_onelevel = sall(zero_symbols)
19
+
20
+ assert zero_onelevel(Basic(x, y, Basic(x, z))) == \
21
+ Basic(S(0), S(0), Basic(x, z))
22
+
23
+
24
+ def test_bottom_up():
25
+ _test_global_traversal(bottom_up)
26
+ _test_stop_on_non_basics(bottom_up)
27
+
28
+
29
+ def test_top_down():
30
+ _test_global_traversal(top_down)
31
+ _test_stop_on_non_basics(top_down)
32
+
33
+
34
+ def _test_global_traversal(trav):
35
+ zero_all_symbols = trav(zero_symbols)
36
+
37
+ assert zero_all_symbols(Basic(x, y, Basic(x, z))) == \
38
+ Basic(S(0), S(0), Basic(S(0), S(0)))
39
+
40
+
41
+ def _test_stop_on_non_basics(trav):
42
+ def add_one_if_can(expr):
43
+ try:
44
+ return expr + 1
45
+ except TypeError:
46
+ return expr
47
+
48
+ expr = Basic(S(1), Str('a'), Basic(S(2), Str('b')))
49
+ expected = Basic(S(2), Str('a'), Basic(S(3), Str('b')))
50
+ rl = trav(add_one_if_can)
51
+
52
+ assert rl(expr) == expected
53
+
54
+
55
+ class Basic2(Basic):
56
+ pass
57
+
58
+
59
+ def rl(x):
60
+ if x.args and not isinstance(x.args[0], Integer):
61
+ return Basic2(*x.args)
62
+ return x
63
+
64
+
65
+ def test_top_down_once():
66
+ top_rl = top_down_once(rl)
67
+
68
+ assert top_rl(Basic(S(1.0), S(2.0), Basic(S(3), S(4)))) == \
69
+ Basic2(S(1.0), S(2.0), Basic(S(3), S(4)))
70
+
71
+
72
+ def test_bottom_up_once():
73
+ bottom_rl = bottom_up_once(rl)
74
+
75
+ assert bottom_rl(Basic(S(1), S(2), Basic(S(3.0), S(4.0)))) == \
76
+ Basic(S(1), S(2), Basic2(S(3.0), S(4.0)))
77
+
78
+
79
+ def test_expr_fns():
80
+ expr = x + y**3
81
+ e = bottom_up(lambda v: v + 1, expr_fns)(expr)
82
+ b = bottom_up(lambda v: Basic.__new__(Add, v, S(1)), basic_fns)(expr)
83
+
84
+ assert rebuild(b) == e
.venv/lib/python3.13/site-packages/sympy/strategies/tests/test_tree.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.strategies.tree import treeapply, greedy, allresults, brute
2
+ from functools import partial, reduce
3
+
4
+
5
+ def inc(x):
6
+ return x + 1
7
+
8
+
9
+ def dec(x):
10
+ return x - 1
11
+
12
+
13
+ def double(x):
14
+ return 2 * x
15
+
16
+
17
+ def square(x):
18
+ return x**2
19
+
20
+
21
+ def add(*args):
22
+ return sum(args)
23
+
24
+
25
+ def mul(*args):
26
+ return reduce(lambda a, b: a * b, args, 1)
27
+
28
+
29
+ def test_treeapply():
30
+ tree = ([3, 3], [4, 1], 2)
31
+ assert treeapply(tree, {list: min, tuple: max}) == 3
32
+ assert treeapply(tree, {list: add, tuple: mul}) == 60
33
+
34
+
35
+ def test_treeapply_leaf():
36
+ assert treeapply(3, {}, leaf=lambda x: x**2) == 9
37
+ tree = ([3, 3], [4, 1], 2)
38
+ treep1 = ([4, 4], [5, 2], 3)
39
+ assert treeapply(tree, {list: min, tuple: max}, leaf=lambda x: x + 1) == \
40
+ treeapply(treep1, {list: min, tuple: max})
41
+
42
+
43
+ def test_treeapply_strategies():
44
+ from sympy.strategies import chain, minimize
45
+ join = {list: chain, tuple: minimize}
46
+
47
+ assert treeapply(inc, join) == inc
48
+ assert treeapply((inc, dec), join)(5) == minimize(inc, dec)(5)
49
+ assert treeapply([inc, dec], join)(5) == chain(inc, dec)(5)
50
+ tree = (inc, [dec, double]) # either inc or dec-then-double
51
+ assert treeapply(tree, join)(5) == 6
52
+ assert treeapply(tree, join)(1) == 0
53
+
54
+ maximize = partial(minimize, objective=lambda x: -x)
55
+ join = {list: chain, tuple: maximize}
56
+ fn = treeapply(tree, join)
57
+ assert fn(4) == 6 # highest value comes from the dec then double
58
+ assert fn(1) == 2 # highest value comes from the inc
59
+
60
+
61
+ def test_greedy():
62
+ tree = [inc, (dec, double)] # either inc or dec-then-double
63
+
64
+ fn = greedy(tree, objective=lambda x: -x)
65
+ assert fn(4) == 6 # highest value comes from the dec then double
66
+ assert fn(1) == 2 # highest value comes from the inc
67
+
68
+ tree = [inc, dec, [inc, dec, [(inc, inc), (dec, dec)]]]
69
+ lowest = greedy(tree)
70
+ assert lowest(10) == 8
71
+
72
+ highest = greedy(tree, objective=lambda x: -x)
73
+ assert highest(10) == 12
74
+
75
+
76
+ def test_allresults():
77
+ # square = lambda x: x**2
78
+
79
+ assert set(allresults(inc)(3)) == {inc(3)}
80
+ assert set(allresults([inc, dec])(3)) == {2, 4}
81
+ assert set(allresults((inc, dec))(3)) == {3}
82
+ assert set(allresults([inc, (dec, double)])(4)) == {5, 6}
83
+
84
+
85
+ def test_brute():
86
+ tree = ([inc, dec], square)
87
+ fn = brute(tree, lambda x: -x)
88
+
89
+ assert fn(2) == (2 + 1)**2
90
+ assert fn(-2) == (-2 - 1)**2
91
+
92
+ assert brute(inc)(1) == 2