diff --git a/.venv/lib/python3.13/site-packages/sympy/algebras/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/algebras/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/algebras/tests/test_quaternion.py b/.venv/lib/python3.13/site-packages/sympy/algebras/tests/test_quaternion.py new file mode 100644 index 0000000000000000000000000000000000000000..a4331cd6afa05c96e8e11d59df4b7520e4810930 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/algebras/tests/test_quaternion.py @@ -0,0 +1,437 @@ +from sympy.testing.pytest import slow +from sympy.core.function import diff +from sympy.core.function import expand +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (Abs, conjugate, im, re, sign) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, cos, sin, atan2, atan) +from sympy.integrals.integrals import integrate +from sympy.matrices.dense import Matrix +from sympy.simplify import simplify +from sympy.simplify.trigsimp import trigsimp +from sympy.algebras.quaternion import Quaternion +from sympy.testing.pytest import raises +import math +from itertools import permutations, product + +w, x, y, z = symbols('w:z') +phi = symbols('phi') + +def test_quaternion_construction(): + q = Quaternion(w, x, y, z) + assert q + q == Quaternion(2*w, 2*x, 2*y, 2*z) + + q2 = Quaternion.from_axis_angle((sqrt(3)/3, sqrt(3)/3, sqrt(3)/3), + pi*Rational(2, 3)) + assert q2 == Quaternion(S.Half, S.Half, + S.Half, S.Half) + + M = Matrix([[cos(phi), -sin(phi), 0], [sin(phi), cos(phi), 0], [0, 0, 1]]) + q3 = trigsimp(Quaternion.from_rotation_matrix(M)) + assert q3 == Quaternion( + sqrt(2)*sqrt(cos(phi) + 1)/2, 0, 0, sqrt(2 - 2*cos(phi))*sign(sin(phi))/2) + + nc = Symbol('nc', commutative=False) + raises(ValueError, lambda: Quaternion(w, x, nc, z)) + + +def test_quaternion_construction_norm(): + q1 = Quaternion(*symbols('a:d')) + + q2 = Quaternion(w, x, y, z) + assert expand((q1*q2).norm()**2 - (q1.norm()**2 * q2.norm()**2)) == 0 + + q3 = Quaternion(w, x, y, z, norm=1) + assert (q1 * q3).norm() == q1.norm() + + +def test_issue_25254(): + # calculating the inverse cached the norm which caused problems + # when multiplying + p = Quaternion(1, 0, 0, 0) + q = Quaternion.from_axis_angle((1, 1, 1), 3 * math.pi/4) + qi = q.inverse() # this operation cached the norm + test = q * p * qi + assert ((test - p).norm() < 1E-10) + + +def test_to_and_from_Matrix(): + q = Quaternion(w, x, y, z) + q_full = Quaternion.from_Matrix(q.to_Matrix()) + q_vect = Quaternion.from_Matrix(q.to_Matrix(True)) + assert (q - q_full).is_zero_quaternion() + assert (q.vector_part() - q_vect).is_zero_quaternion() + + +def test_product_matrices(): + q1 = Quaternion(w, x, y, z) + q2 = Quaternion(*(symbols("a:d"))) + assert (q1 * q2).to_Matrix() == q1.product_matrix_left * q2.to_Matrix() + assert (q1 * q2).to_Matrix() == q2.product_matrix_right * q1.to_Matrix() + + R1 = (q1.product_matrix_left * q1.product_matrix_right.T)[1:, 1:] + R2 = simplify(q1.to_rotation_matrix()*q1.norm()**2) + assert R1 == R2 + + +def test_quaternion_axis_angle(): + + test_data = [ # axis, angle, expected_quaternion + ((1, 0, 0), 0, (1, 0, 0, 0)), + ((1, 0, 0), pi/2, (sqrt(2)/2, sqrt(2)/2, 0, 0)), + ((0, 1, 0), pi/2, (sqrt(2)/2, 0, sqrt(2)/2, 0)), + ((0, 0, 1), pi/2, (sqrt(2)/2, 0, 0, sqrt(2)/2)), + ((1, 0, 0), pi, (0, 1, 0, 0)), + ((0, 1, 0), pi, (0, 0, 1, 0)), + ((0, 0, 1), pi, (0, 0, 0, 1)), + ((1, 1, 1), pi, (0, 1/sqrt(3),1/sqrt(3),1/sqrt(3))), + ((sqrt(3)/3, sqrt(3)/3, sqrt(3)/3), pi*2/3, (S.Half, S.Half, S.Half, S.Half)) + ] + + for axis, angle, expected in test_data: + assert Quaternion.from_axis_angle(axis, angle) == Quaternion(*expected) + + +def test_quaternion_axis_angle_simplification(): + result = Quaternion.from_axis_angle((1, 2, 3), asin(4)) + assert result.a == cos(asin(4)/2) + assert result.b == sqrt(14)*sin(asin(4)/2)/14 + assert result.c == sqrt(14)*sin(asin(4)/2)/7 + assert result.d == 3*sqrt(14)*sin(asin(4)/2)/14 + +def test_quaternion_complex_real_addition(): + a = symbols("a", complex=True) + b = symbols("b", real=True) + # This symbol is not complex: + c = symbols("c", commutative=False) + + q = Quaternion(w, x, y, z) + assert a + q == Quaternion(w + re(a), x + im(a), y, z) + assert 1 + q == Quaternion(1 + w, x, y, z) + assert I + q == Quaternion(w, 1 + x, y, z) + assert b + q == Quaternion(w + b, x, y, z) + raises(ValueError, lambda: c + q) + raises(ValueError, lambda: q * c) + raises(ValueError, lambda: c * q) + + assert -q == Quaternion(-w, -x, -y, -z) + + q1 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False) + q2 = Quaternion(1, 4, 7, 8) + + assert q1 + (2 + 3*I) == Quaternion(5 + 7*I, 2 + 5*I, 0, 7 + 8*I) + assert q2 + (2 + 3*I) == Quaternion(3, 7, 7, 8) + assert q1 * (2 + 3*I) == \ + Quaternion((2 + 3*I)*(3 + 4*I), (2 + 3*I)*(2 + 5*I), 0, (2 + 3*I)*(7 + 8*I)) + assert q2 * (2 + 3*I) == Quaternion(-10, 11, 38, -5) + + q1 = Quaternion(1, 2, 3, 4) + q0 = Quaternion(0, 0, 0, 0) + assert q1 + q0 == q1 + assert q1 - q0 == q1 + assert q1 - q1 == q0 + + +def test_quaternion_subs(): + q = Quaternion.from_axis_angle((0, 0, 1), phi) + assert q.subs(phi, 0) == Quaternion(1, 0, 0, 0) + + +def test_quaternion_evalf(): + assert (Quaternion(sqrt(2), 0, 0, sqrt(3)).evalf() == + Quaternion(sqrt(2).evalf(), 0, 0, sqrt(3).evalf())) + assert (Quaternion(1/sqrt(2), 0, 0, 1/sqrt(2)).evalf() == + Quaternion((1/sqrt(2)).evalf(), 0, 0, (1/sqrt(2)).evalf())) + + +def test_quaternion_functions(): + q = Quaternion(w, x, y, z) + q1 = Quaternion(1, 2, 3, 4) + q0 = Quaternion(0, 0, 0, 0) + + assert conjugate(q) == Quaternion(w, -x, -y, -z) + assert q.norm() == sqrt(w**2 + x**2 + y**2 + z**2) + assert q.normalize() == Quaternion(w, x, y, z) / sqrt(w**2 + x**2 + y**2 + z**2) + assert q.inverse() == Quaternion(w, -x, -y, -z) / (w**2 + x**2 + y**2 + z**2) + assert q.inverse() == q.pow(-1) + raises(ValueError, lambda: q0.inverse()) + assert q.pow(2) == Quaternion(w**2 - x**2 - y**2 - z**2, 2*w*x, 2*w*y, 2*w*z) + assert q**(2) == Quaternion(w**2 - x**2 - y**2 - z**2, 2*w*x, 2*w*y, 2*w*z) + assert q1.pow(-2) == Quaternion( + Rational(-7, 225), Rational(-1, 225), Rational(-1, 150), Rational(-2, 225)) + assert q1**(-2) == Quaternion( + Rational(-7, 225), Rational(-1, 225), Rational(-1, 150), Rational(-2, 225)) + assert q1.pow(-0.5) == NotImplemented + raises(TypeError, lambda: q1**(-0.5)) + + assert q1.exp() == \ + Quaternion(E * cos(sqrt(29)), + 2 * sqrt(29) * E * sin(sqrt(29)) / 29, + 3 * sqrt(29) * E * sin(sqrt(29)) / 29, + 4 * sqrt(29) * E * sin(sqrt(29)) / 29) + assert q1.log() == \ + Quaternion(log(sqrt(30)), + 2 * sqrt(29) * acos(sqrt(30)/30) / 29, + 3 * sqrt(29) * acos(sqrt(30)/30) / 29, + 4 * sqrt(29) * acos(sqrt(30)/30) / 29) + + assert q1.pow_cos_sin(2) == \ + Quaternion(30 * cos(2 * acos(sqrt(30)/30)), + 60 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29, + 90 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29, + 120 * sqrt(29) * sin(2 * acos(sqrt(30)/30)) / 29) + + assert diff(Quaternion(x, x, x, x), x) == Quaternion(1, 1, 1, 1) + + assert integrate(Quaternion(x, x, x, x), x) == \ + Quaternion(x**2 / 2, x**2 / 2, x**2 / 2, x**2 / 2) + + assert Quaternion(1, x, x**2, x**3).integrate(x) == \ + Quaternion(x, x**2/2, x**3/3, x**4/4) + + assert Quaternion(sin(x), cos(x), sin(2*x), cos(2*x)).integrate(x) == \ + Quaternion(-cos(x), sin(x), -cos(2*x)/2, sin(2*x)/2) + + assert Quaternion(x**2, y**2, z**2, x*y*z).integrate(x, y) == \ + Quaternion(x**3*y/3, x*y**3/3, x*y*z**2, x**2*y**2*z/4) + + assert Quaternion.rotate_point((1, 1, 1), q1) == (S.One / 5, 1, S(7) / 5) + n = Symbol('n') + raises(TypeError, lambda: q1**n) + n = Symbol('n', integer=True) + raises(TypeError, lambda: q1**n) + + assert Quaternion(22, 23, 55, 8).scalar_part() == 22 + assert Quaternion(w, x, y, z).scalar_part() == w + + assert Quaternion(22, 23, 55, 8).vector_part() == Quaternion(0, 23, 55, 8) + assert Quaternion(w, x, y, z).vector_part() == Quaternion(0, x, y, z) + + assert q1.axis() == Quaternion(0, 2*sqrt(29)/29, 3*sqrt(29)/29, 4*sqrt(29)/29) + assert q1.axis().pow(2) == Quaternion(-1, 0, 0, 0) + assert q0.axis().scalar_part() == 0 + assert (q.axis() == Quaternion(0, + x/sqrt(x**2 + y**2 + z**2), + y/sqrt(x**2 + y**2 + z**2), + z/sqrt(x**2 + y**2 + z**2))) + + assert q0.is_pure() is True + assert q1.is_pure() is False + assert Quaternion(0, 0, 0, 3).is_pure() is True + assert Quaternion(0, 2, 10, 3).is_pure() is True + assert Quaternion(w, 2, 10, 3).is_pure() is None + + assert q1.angle() == 2*atan(sqrt(29)) + assert q.angle() == 2*atan2(sqrt(x**2 + y**2 + z**2), w) + + assert Quaternion.arc_coplanar(q1, Quaternion(2, 4, 6, 8)) is True + assert Quaternion.arc_coplanar(q1, Quaternion(1, -2, -3, -4)) is True + assert Quaternion.arc_coplanar(q1, Quaternion(1, 8, 12, 16)) is True + assert Quaternion.arc_coplanar(q1, Quaternion(1, 2, 3, 4)) is True + assert Quaternion.arc_coplanar(q1, Quaternion(w, 4, 6, 8)) is True + assert Quaternion.arc_coplanar(q1, Quaternion(2, 7, 4, 1)) is False + assert Quaternion.arc_coplanar(q1, Quaternion(w, x, y, z)) is None + raises(ValueError, lambda: Quaternion.arc_coplanar(q1, q0)) + + assert Quaternion.vector_coplanar( + Quaternion(0, 8, 12, 16), + Quaternion(0, 4, 6, 8), + Quaternion(0, 2, 3, 4)) is True + assert Quaternion.vector_coplanar( + Quaternion(0, 0, 0, 0), Quaternion(0, 4, 6, 8), Quaternion(0, 2, 3, 4)) is True + assert Quaternion.vector_coplanar( + Quaternion(0, 8, 2, 6), Quaternion(0, 1, 6, 6), Quaternion(0, 0, 3, 4)) is False + assert Quaternion.vector_coplanar( + Quaternion(0, 1, 3, 4), + Quaternion(0, 4, w, 6), + Quaternion(0, 6, 8, 1)) is None + raises(ValueError, lambda: + Quaternion.vector_coplanar(q0, Quaternion(0, 4, 6, 8), q1)) + + assert Quaternion(0, 1, 2, 3).parallel(Quaternion(0, 2, 4, 6)) is True + assert Quaternion(0, 1, 2, 3).parallel(Quaternion(0, 2, 2, 6)) is False + assert Quaternion(0, 1, 2, 3).parallel(Quaternion(w, x, y, 6)) is None + raises(ValueError, lambda: q0.parallel(q1)) + + assert Quaternion(0, 1, 2, 3).orthogonal(Quaternion(0, -2, 1, 0)) is True + assert Quaternion(0, 2, 4, 7).orthogonal(Quaternion(0, 2, 2, 6)) is False + assert Quaternion(0, 2, 4, 7).orthogonal(Quaternion(w, x, y, 6)) is None + raises(ValueError, lambda: q0.orthogonal(q1)) + + assert q1.index_vector() == Quaternion( + 0, 2*sqrt(870)/29, + 3*sqrt(870)/29, + 4*sqrt(870)/29) + assert Quaternion(0, 3, 9, 4).index_vector() == Quaternion(0, 3, 9, 4) + + assert Quaternion(4, 3, 9, 4).mensor() == log(sqrt(122)) + assert Quaternion(3, 3, 0, 2).mensor() == log(sqrt(22)) + + assert q0.is_zero_quaternion() is True + assert q1.is_zero_quaternion() is False + assert Quaternion(w, 0, 0, 0).is_zero_quaternion() is None + +def test_quaternion_conversions(): + q1 = Quaternion(1, 2, 3, 4) + + assert q1.to_axis_angle() == ((2 * sqrt(29)/29, + 3 * sqrt(29)/29, + 4 * sqrt(29)/29), + 2 * acos(sqrt(30)/30)) + + assert (q1.to_rotation_matrix() == + Matrix([[Rational(-2, 3), Rational(2, 15), Rational(11, 15)], + [Rational(2, 3), Rational(-1, 3), Rational(2, 3)], + [Rational(1, 3), Rational(14, 15), Rational(2, 15)]])) + + assert (q1.to_rotation_matrix((1, 1, 1)) == + Matrix([ + [Rational(-2, 3), Rational(2, 15), Rational(11, 15), Rational(4, 5)], + [Rational(2, 3), Rational(-1, 3), Rational(2, 3), S.Zero], + [Rational(1, 3), Rational(14, 15), Rational(2, 15), Rational(-2, 5)], + [S.Zero, S.Zero, S.Zero, S.One]])) + + theta = symbols("theta", real=True) + q2 = Quaternion(cos(theta/2), 0, 0, sin(theta/2)) + + assert trigsimp(q2.to_rotation_matrix()) == Matrix([ + [cos(theta), -sin(theta), 0], + [sin(theta), cos(theta), 0], + [0, 0, 1]]) + + assert q2.to_axis_angle() == ((0, 0, sin(theta/2)/Abs(sin(theta/2))), + 2*acos(cos(theta/2))) + + assert trigsimp(q2.to_rotation_matrix((1, 1, 1))) == Matrix([ + [cos(theta), -sin(theta), 0, sin(theta) - cos(theta) + 1], + [sin(theta), cos(theta), 0, -sin(theta) - cos(theta) + 1], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + +def test_rotation_matrix_homogeneous(): + q = Quaternion(w, x, y, z) + R1 = q.to_rotation_matrix(homogeneous=True) * q.norm()**2 + R2 = simplify(q.to_rotation_matrix(homogeneous=False) * q.norm()**2) + assert R1 == R2 + + +def test_quaternion_rotation_iss1593(): + """ + There was a sign mistake in the definition, + of the rotation matrix. This tests that particular sign mistake. + See issue 1593 for reference. + See wikipedia + https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation#Quaternion-derived_rotation_matrix + for the correct definition + """ + q = Quaternion(cos(phi/2), sin(phi/2), 0, 0) + assert(trigsimp(q.to_rotation_matrix()) == Matrix([ + [1, 0, 0], + [0, cos(phi), -sin(phi)], + [0, sin(phi), cos(phi)]])) + + +def test_quaternion_multiplication(): + q1 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False) + q2 = Quaternion(1, 2, 3, 5) + q3 = Quaternion(1, 1, 1, y) + + assert Quaternion._generic_mul(S(4), S.One) == 4 + assert (Quaternion._generic_mul(S(4), q1) == + Quaternion(12 + 16*I, 8 + 20*I, 0, 28 + 32*I)) + assert q2.mul(2) == Quaternion(2, 4, 6, 10) + assert q2.mul(q3) == Quaternion(-5*y - 4, 3*y - 2, 9 - 2*y, y + 4) + assert q2.mul(q3) == q2*q3 + + z = symbols('z', complex=True) + z_quat = Quaternion(re(z), im(z), 0, 0) + q = Quaternion(*symbols('q:4', real=True)) + + assert z * q == z_quat * q + assert q * z == q * z_quat + + +def test_issue_16318(): + #for rtruediv + q0 = Quaternion(0, 0, 0, 0) + raises(ValueError, lambda: 1/q0) + #for rotate_point + q = Quaternion(1, 2, 3, 4) + (axis, angle) = q.to_axis_angle() + assert Quaternion.rotate_point((1, 1, 1), (axis, angle)) == (S.One / 5, 1, S(7) / 5) + #test for to_axis_angle + q = Quaternion(-1, 1, 1, 1) + axis = (-sqrt(3)/3, -sqrt(3)/3, -sqrt(3)/3) + angle = 2*pi/3 + assert (axis, angle) == q.to_axis_angle() + + +@slow +def test_to_euler(): + q = Quaternion(w, x, y, z) + q_normalized = q.normalize() + + seqs = ['zxy', 'zyx', 'zyz', 'zxz'] + seqs += [seq.upper() for seq in seqs] + + for seq in seqs: + euler_from_q = q.to_euler(seq) + q_back = simplify(Quaternion.from_euler(euler_from_q, seq)) + assert q_back == q_normalized + + +def test_to_euler_iss24504(): + """ + There was a mistake in the degenerate case testing + See issue 24504 for reference. + """ + q = Quaternion.from_euler((phi, 0, 0), 'zyz') + assert trigsimp(q.to_euler('zyz'), inverse=True) == (phi, 0, 0) + + +def test_to_euler_numerical_singilarities(): + + def test_one_case(angles, seq): + q = Quaternion.from_euler(angles, seq) + assert q.to_euler(seq) == angles + + # symmetric + test_one_case((pi/2, 0, 0), 'zyz') + test_one_case((pi/2, 0, 0), 'ZYZ') + test_one_case((pi/2, pi, 0), 'zyz') + test_one_case((pi/2, pi, 0), 'ZYZ') + + # asymmetric + test_one_case((pi/2, pi/2, 0), 'zyx') + test_one_case((pi/2, -pi/2, 0), 'zyx') + test_one_case((pi/2, pi/2, 0), 'ZYX') + test_one_case((pi/2, -pi/2, 0), 'ZYX') + + +@slow +def test_to_euler_options(): + def test_one_case(q): + angles1 = Matrix(q.to_euler(seq, True, True)) + angles2 = Matrix(q.to_euler(seq, False, False)) + angle_errors = simplify(angles1-angles2).evalf() + for angle_error in angle_errors: + # forcing angles to set {-pi, pi} + angle_error = (angle_error + pi) % (2 * pi) - pi + assert angle_error < 10e-7 + + for xyz in ('xyz', 'XYZ'): + for seq_tuple in permutations(xyz): + for symmetric in (True, False): + if symmetric: + seq = ''.join([seq_tuple[0], seq_tuple[1], seq_tuple[0]]) + else: + seq = ''.join(seq_tuple) + + for elements in product([-1, 0, 1], repeat=4): + q = Quaternion(*elements) + if not q.is_zero_quaternion(): + test_one_case(q) diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/__init__.py b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll.py b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll.py new file mode 100644 index 0000000000000000000000000000000000000000..40e6802f7626c982a9a6cd7146baea3ac6b8b6e0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll.py @@ -0,0 +1,308 @@ +"""Implementation of DPLL algorithm + +Further improvements: eliminate calls to pl_true, implement branching rules, +efficient unit propagation. + +References: + - https://en.wikipedia.org/wiki/DPLL_algorithm + - https://www.researchgate.net/publication/242384772_Implementations_of_the_DPLL_Algorithm +""" + +from sympy.core.sorting import default_sort_key +from sympy.logic.boolalg import Or, Not, conjuncts, disjuncts, to_cnf, \ + to_int_repr, _find_predicates +from sympy.assumptions.cnf import CNF +from sympy.logic.inference import pl_true, literal_symbol + + +def dpll_satisfiable(expr): + """ + Check satisfiability of a propositional sentence. + It returns a model rather than True when it succeeds + + >>> from sympy.abc import A, B + >>> from sympy.logic.algorithms.dpll import dpll_satisfiable + >>> dpll_satisfiable(A & ~B) + {A: True, B: False} + >>> dpll_satisfiable(A & ~A) + False + + """ + if not isinstance(expr, CNF): + clauses = conjuncts(to_cnf(expr)) + else: + clauses = expr.clauses + if False in clauses: + return False + symbols = sorted(_find_predicates(expr), key=default_sort_key) + symbols_int_repr = set(range(1, len(symbols) + 1)) + clauses_int_repr = to_int_repr(clauses, symbols) + result = dpll_int_repr(clauses_int_repr, symbols_int_repr, {}) + if not result: + return result + output = {} + for key in result: + output.update({symbols[key - 1]: result[key]}) + return output + + +def dpll(clauses, symbols, model): + """ + Compute satisfiability in a partial model. + Clauses is an array of conjuncts. + + >>> from sympy.abc import A, B, D + >>> from sympy.logic.algorithms.dpll import dpll + >>> dpll([A, B, D], [A, B], {D: False}) + False + + """ + # compute DP kernel + P, value = find_unit_clause(clauses, model) + while P: + model.update({P: value}) + symbols.remove(P) + if not value: + P = ~P + clauses = unit_propagate(clauses, P) + P, value = find_unit_clause(clauses, model) + P, value = find_pure_symbol(symbols, clauses) + while P: + model.update({P: value}) + symbols.remove(P) + if not value: + P = ~P + clauses = unit_propagate(clauses, P) + P, value = find_pure_symbol(symbols, clauses) + # end DP kernel + unknown_clauses = [] + for c in clauses: + val = pl_true(c, model) + if val is False: + return False + if val is not True: + unknown_clauses.append(c) + if not unknown_clauses: + return model + if not clauses: + return model + P = symbols.pop() + model_copy = model.copy() + model.update({P: True}) + model_copy.update({P: False}) + symbols_copy = symbols[:] + return (dpll(unit_propagate(unknown_clauses, P), symbols, model) or + dpll(unit_propagate(unknown_clauses, Not(P)), symbols_copy, model_copy)) + + +def dpll_int_repr(clauses, symbols, model): + """ + Compute satisfiability in a partial model. + Arguments are expected to be in integer representation + + >>> from sympy.logic.algorithms.dpll import dpll_int_repr + >>> dpll_int_repr([{1}, {2}, {3}], {1, 2}, {3: False}) + False + + """ + # compute DP kernel + P, value = find_unit_clause_int_repr(clauses, model) + while P: + model.update({P: value}) + symbols.remove(P) + if not value: + P = -P + clauses = unit_propagate_int_repr(clauses, P) + P, value = find_unit_clause_int_repr(clauses, model) + P, value = find_pure_symbol_int_repr(symbols, clauses) + while P: + model.update({P: value}) + symbols.remove(P) + if not value: + P = -P + clauses = unit_propagate_int_repr(clauses, P) + P, value = find_pure_symbol_int_repr(symbols, clauses) + # end DP kernel + unknown_clauses = [] + for c in clauses: + val = pl_true_int_repr(c, model) + if val is False: + return False + if val is not True: + unknown_clauses.append(c) + if not unknown_clauses: + return model + P = symbols.pop() + model_copy = model.copy() + model.update({P: True}) + model_copy.update({P: False}) + symbols_copy = symbols.copy() + return (dpll_int_repr(unit_propagate_int_repr(unknown_clauses, P), symbols, model) or + dpll_int_repr(unit_propagate_int_repr(unknown_clauses, -P), symbols_copy, model_copy)) + +### helper methods for DPLL + + +def pl_true_int_repr(clause, model={}): + """ + Lightweight version of pl_true. + Argument clause represents the set of args of an Or clause. This is used + inside dpll_int_repr, it is not meant to be used directly. + + >>> from sympy.logic.algorithms.dpll import pl_true_int_repr + >>> pl_true_int_repr({1, 2}, {1: False}) + >>> pl_true_int_repr({1, 2}, {1: False, 2: False}) + False + + """ + result = False + for lit in clause: + if lit < 0: + p = model.get(-lit) + if p is not None: + p = not p + else: + p = model.get(lit) + if p is True: + return True + elif p is None: + result = None + return result + + +def unit_propagate(clauses, symbol): + """ + Returns an equivalent set of clauses + If a set of clauses contains the unit clause l, the other clauses are + simplified by the application of the two following rules: + + 1. every clause containing l is removed + 2. in every clause that contains ~l this literal is deleted + + Arguments are expected to be in CNF. + + >>> from sympy.abc import A, B, D + >>> from sympy.logic.algorithms.dpll import unit_propagate + >>> unit_propagate([A | B, D | ~B, B], B) + [D, B] + + """ + output = [] + for c in clauses: + if c.func != Or: + output.append(c) + continue + for arg in c.args: + if arg == ~symbol: + output.append(Or(*[x for x in c.args if x != ~symbol])) + break + if arg == symbol: + break + else: + output.append(c) + return output + + +def unit_propagate_int_repr(clauses, s): + """ + Same as unit_propagate, but arguments are expected to be in integer + representation + + >>> from sympy.logic.algorithms.dpll import unit_propagate_int_repr + >>> unit_propagate_int_repr([{1, 2}, {3, -2}, {2}], 2) + [{3}] + + """ + negated = {-s} + return [clause - negated for clause in clauses if s not in clause] + + +def find_pure_symbol(symbols, unknown_clauses): + """ + Find a symbol and its value if it appears only as a positive literal + (or only as a negative) in clauses. + + >>> from sympy.abc import A, B, D + >>> from sympy.logic.algorithms.dpll import find_pure_symbol + >>> find_pure_symbol([A, B, D], [A|~B,~B|~D,D|A]) + (A, True) + + """ + for sym in symbols: + found_pos, found_neg = False, False + for c in unknown_clauses: + if not found_pos and sym in disjuncts(c): + found_pos = True + if not found_neg and Not(sym) in disjuncts(c): + found_neg = True + if found_pos != found_neg: + return sym, found_pos + return None, None + + +def find_pure_symbol_int_repr(symbols, unknown_clauses): + """ + Same as find_pure_symbol, but arguments are expected + to be in integer representation + + >>> from sympy.logic.algorithms.dpll import find_pure_symbol_int_repr + >>> find_pure_symbol_int_repr({1,2,3}, + ... [{1, -2}, {-2, -3}, {3, 1}]) + (1, True) + + """ + all_symbols = set().union(*unknown_clauses) + found_pos = all_symbols.intersection(symbols) + found_neg = all_symbols.intersection([-s for s in symbols]) + for p in found_pos: + if -p not in found_neg: + return p, True + for p in found_neg: + if -p not in found_pos: + return -p, False + return None, None + + +def find_unit_clause(clauses, model): + """ + A unit clause has only 1 variable that is not bound in the model. + + >>> from sympy.abc import A, B, D + >>> from sympy.logic.algorithms.dpll import find_unit_clause + >>> find_unit_clause([A | B | D, B | ~D, A | ~B], {A:True}) + (B, False) + + """ + for clause in clauses: + num_not_in_model = 0 + for literal in disjuncts(clause): + sym = literal_symbol(literal) + if sym not in model: + num_not_in_model += 1 + P, value = sym, not isinstance(literal, Not) + if num_not_in_model == 1: + return P, value + return None, None + + +def find_unit_clause_int_repr(clauses, model): + """ + Same as find_unit_clause, but arguments are expected to be in + integer representation. + + >>> from sympy.logic.algorithms.dpll import find_unit_clause_int_repr + >>> find_unit_clause_int_repr([{1, 2, 3}, + ... {2, -3}, {1, -2}], {1: True}) + (2, False) + + """ + bound = set(model) | {-sym for sym in model} + for clause in clauses: + unbound = clause - bound + if len(unbound) == 1: + p = unbound.pop() + if p < 0: + return -p, False + else: + return p, True + return None, None diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll2.py b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll2.py new file mode 100644 index 0000000000000000000000000000000000000000..4f18c81189d6be565dc9b7caa3f0bf48e978bb56 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/dpll2.py @@ -0,0 +1,688 @@ +"""Implementation of DPLL algorithm + +Features: + - Clause learning + - Watch literal scheme + - VSIDS heuristic + +References: + - https://en.wikipedia.org/wiki/DPLL_algorithm +""" + +from collections import defaultdict +from heapq import heappush, heappop + +from sympy.core.sorting import ordered +from sympy.assumptions.cnf import EncodedCNF + +from sympy.logic.algorithms.lra_theory import LRASolver + + +def dpll_satisfiable(expr, all_models=False, use_lra_theory=False): + """ + Check satisfiability of a propositional sentence. + It returns a model rather than True when it succeeds. + Returns a generator of all models if all_models is True. + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.algorithms.dpll2 import dpll_satisfiable + >>> dpll_satisfiable(A & ~B) + {A: True, B: False} + >>> dpll_satisfiable(A & ~A) + False + + """ + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + if use_lra_theory: + lra, immediate_conflicts = LRASolver.from_encoded_cnf(expr) + else: + lra = None + immediate_conflicts = [] + solver = SATSolver(expr.data + immediate_conflicts, expr.variables, set(), expr.symbols, lra_theory=lra) + models = solver._find_model() + + if all_models: + return _all_models(models) + + try: + return next(models) + except StopIteration: + return False + + # Uncomment to confirm the solution is valid (hitting set for the clauses) + #else: + #for cls in clauses_int_repr: + #assert solver.var_settings.intersection(cls) + + +def _all_models(models): + satisfiable = False + try: + while True: + yield next(models) + satisfiable = True + except StopIteration: + if not satisfiable: + yield False + + +class SATSolver: + """ + Class for representing a SAT solver capable of + finding a model to a boolean theory in conjunctive + normal form. + """ + + def __init__(self, clauses, variables, var_settings, symbols=None, + heuristic='vsids', clause_learning='none', INTERVAL=500, + lra_theory = None): + + self.var_settings = var_settings + self.heuristic = heuristic + self.is_unsatisfied = False + self._unit_prop_queue = [] + self.update_functions = [] + self.INTERVAL = INTERVAL + + if symbols is None: + self.symbols = list(ordered(variables)) + else: + self.symbols = symbols + + self._initialize_variables(variables) + self._initialize_clauses(clauses) + + if 'vsids' == heuristic: + self._vsids_init() + self.heur_calculate = self._vsids_calculate + self.heur_lit_assigned = self._vsids_lit_assigned + self.heur_lit_unset = self._vsids_lit_unset + self.heur_clause_added = self._vsids_clause_added + + # Note: Uncomment this if/when clause learning is enabled + #self.update_functions.append(self._vsids_decay) + + else: + raise NotImplementedError + + if 'simple' == clause_learning: + self.add_learned_clause = self._simple_add_learned_clause + self.compute_conflict = self._simple_compute_conflict + self.update_functions.append(self._simple_clean_clauses) + elif 'none' == clause_learning: + self.add_learned_clause = lambda x: None + self.compute_conflict = lambda: None + else: + raise NotImplementedError + + # Create the base level + self.levels = [Level(0)] + self._current_level.varsettings = var_settings + + # Keep stats + self.num_decisions = 0 + self.num_learned_clauses = 0 + self.original_num_clauses = len(self.clauses) + + self.lra = lra_theory + + def _initialize_variables(self, variables): + """Set up the variable data structures needed.""" + self.sentinels = defaultdict(set) + self.occurrence_count = defaultdict(int) + self.variable_set = [False] * (len(variables) + 1) + + def _initialize_clauses(self, clauses): + """Set up the clause data structures needed. + + For each clause, the following changes are made: + - Unit clauses are queued for propagation right away. + - Non-unit clauses have their first and last literals set as sentinels. + - The number of clauses a literal appears in is computed. + """ + self.clauses = [list(clause) for clause in clauses] + + for i, clause in enumerate(self.clauses): + + # Handle the unit clauses + if 1 == len(clause): + self._unit_prop_queue.append(clause[0]) + continue + + self.sentinels[clause[0]].add(i) + self.sentinels[clause[-1]].add(i) + + for lit in clause: + self.occurrence_count[lit] += 1 + + def _find_model(self): + """ + Main DPLL loop. Returns a generator of models. + + Variables are chosen successively, and assigned to be either + True or False. If a solution is not found with this setting, + the opposite is chosen and the search continues. The solver + halts when every variable has a setting. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> list(l._find_model()) + [{1: True, 2: False, 3: False}, {1: True, 2: True, 3: True}] + + >>> from sympy.abc import A, B, C + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set(), [A, B, C]) + >>> list(l._find_model()) + [{A: True, B: False, C: False}, {A: True, B: True, C: True}] + + """ + + # We use this variable to keep track of if we should flip a + # variable setting in successive rounds + flip_var = False + + # Check if unit prop says the theory is unsat right off the bat + self._simplify() + if self.is_unsatisfied: + return + + # While the theory still has clauses remaining + while True: + # Perform cleanup / fixup at regular intervals + if self.num_decisions % self.INTERVAL == 0: + for func in self.update_functions: + func() + + if flip_var: + # We have just backtracked and we are trying to opposite literal + flip_var = False + lit = self._current_level.decision + + else: + # Pick a literal to set + lit = self.heur_calculate() + self.num_decisions += 1 + + # Stopping condition for a satisfying theory + if 0 == lit: + + # check if assignment satisfies lra theory + if self.lra: + for enc_var in self.var_settings: + res = self.lra.assert_lit(enc_var) + if res is not None: + break + res = self.lra.check() + self.lra.reset_bounds() + else: + res = None + if res is None or res[0]: + yield {self.symbols[abs(lit) - 1]: + lit > 0 for lit in self.var_settings} + else: + self._simple_add_learned_clause(res[1]) + + # backtrack until we unassign one of the literals causing the conflict + while not any(-lit in res[1] for lit in self._current_level.var_settings): + self._undo() + + while self._current_level.flipped: + self._undo() + if len(self.levels) == 1: + return + flip_lit = -self._current_level.decision + self._undo() + self.levels.append(Level(flip_lit, flipped=True)) + flip_var = True + continue + + # Start the new decision level + self.levels.append(Level(lit)) + + # Assign the literal, updating the clauses it satisfies + self._assign_literal(lit) + + # _simplify the theory + self._simplify() + + # Check if we've made the theory unsat + if self.is_unsatisfied: + + self.is_unsatisfied = False + + # We unroll all of the decisions until we can flip a literal + while self._current_level.flipped: + self._undo() + + # If we've unrolled all the way, the theory is unsat + if 1 == len(self.levels): + return + + # Detect and add a learned clause + self.add_learned_clause(self.compute_conflict()) + + # Try the opposite setting of the most recent decision + flip_lit = -self._current_level.decision + self._undo() + self.levels.append(Level(flip_lit, flipped=True)) + flip_var = True + + ######################## + # Helper Methods # + ######################## + @property + def _current_level(self): + """The current decision level data structure + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{1}, {2}], {1, 2}, set()) + >>> next(l._find_model()) + {1: True, 2: True} + >>> l._current_level.decision + 0 + >>> l._current_level.flipped + False + >>> l._current_level.var_settings + {1, 2} + + """ + return self.levels[-1] + + def _clause_sat(self, cls): + """Check if a clause is satisfied by the current variable setting. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{1}, {-1}], {1}, set()) + >>> try: + ... next(l._find_model()) + ... except StopIteration: + ... pass + >>> l._clause_sat(0) + False + >>> l._clause_sat(1) + True + + """ + for lit in self.clauses[cls]: + if lit in self.var_settings: + return True + return False + + def _is_sentinel(self, lit, cls): + """Check if a literal is a sentinel of a given clause. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l._is_sentinel(2, 3) + True + >>> l._is_sentinel(-3, 1) + False + + """ + return cls in self.sentinels[lit] + + def _assign_literal(self, lit): + """Make a literal assignment. + + The literal assignment must be recorded as part of the current + decision level. Additionally, if the literal is marked as a + sentinel of any clause, then a new sentinel must be chosen. If + this is not possible, then unit propagation is triggered and + another literal is added to the queue to be set in the future. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l.var_settings + {-3, -2, 1} + + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l._assign_literal(-1) + >>> try: + ... next(l._find_model()) + ... except StopIteration: + ... pass + >>> l.var_settings + {-1} + + """ + self.var_settings.add(lit) + self._current_level.var_settings.add(lit) + self.variable_set[abs(lit)] = True + self.heur_lit_assigned(lit) + + sentinel_list = list(self.sentinels[-lit]) + + for cls in sentinel_list: + if not self._clause_sat(cls): + other_sentinel = None + for newlit in self.clauses[cls]: + if newlit != -lit: + if self._is_sentinel(newlit, cls): + other_sentinel = newlit + elif not self.variable_set[abs(newlit)]: + self.sentinels[-lit].remove(cls) + self.sentinels[newlit].add(cls) + other_sentinel = None + break + + # Check if no sentinel update exists + if other_sentinel: + self._unit_prop_queue.append(other_sentinel) + + def _undo(self): + """ + _undo the changes of the most recent decision level. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> level = l._current_level + >>> level.decision, level.var_settings, level.flipped + (-3, {-3, -2}, False) + >>> l._undo() + >>> level = l._current_level + >>> level.decision, level.var_settings, level.flipped + (0, {1}, False) + + """ + # Undo the variable settings + for lit in self._current_level.var_settings: + self.var_settings.remove(lit) + self.heur_lit_unset(lit) + self.variable_set[abs(lit)] = False + + # Pop the level off the stack + self.levels.pop() + + ######################### + # Propagation # + ######################### + """ + Propagation methods should attempt to soundly simplify the boolean + theory, and return True if any simplification occurred and False + otherwise. + """ + def _simplify(self): + """Iterate over the various forms of propagation to simplify the theory. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l.variable_set + [False, False, False, False] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}} + + >>> l._simplify() + + >>> l.variable_set + [False, True, False, False] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, -1: set(), 2: {0, 3}, + ...3: {2, 4}} + + """ + changed = True + while changed: + changed = False + changed |= self._unit_prop() + changed |= self._pure_literal() + + def _unit_prop(self): + """Perform unit propagation on the current theory.""" + result = len(self._unit_prop_queue) > 0 + while self._unit_prop_queue: + next_lit = self._unit_prop_queue.pop() + if -next_lit in self.var_settings: + self.is_unsatisfied = True + self._unit_prop_queue = [] + return False + else: + self._assign_literal(next_lit) + + return result + + def _pure_literal(self): + """Look for pure literals and assign them when found.""" + return False + + ######################### + # Heuristics # + ######################### + def _vsids_init(self): + """Initialize the data structures needed for the VSIDS heuristic.""" + self.lit_heap = [] + self.lit_scores = {} + + for var in range(1, len(self.variable_set)): + self.lit_scores[var] = float(-self.occurrence_count[var]) + self.lit_scores[-var] = float(-self.occurrence_count[-var]) + heappush(self.lit_heap, (self.lit_scores[var], var)) + heappush(self.lit_heap, (self.lit_scores[-var], -var)) + + def _vsids_decay(self): + """Decay the VSIDS scores for every literal. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.lit_scores + {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0} + + >>> l._vsids_decay() + + >>> l.lit_scores + {-3: -1.0, -2: -1.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -1.0} + + """ + # We divide every literal score by 2 for a decay factor + # Note: This doesn't change the heap property + for lit in self.lit_scores.keys(): + self.lit_scores[lit] /= 2.0 + + def _vsids_calculate(self): + """ + VSIDS Heuristic Calculation + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.lit_heap + [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)] + + >>> l._vsids_calculate() + -3 + + >>> l.lit_heap + [(-2.0, -2), (-2.0, 2), (0.0, -1), (0.0, 1), (-2.0, 3)] + + """ + if len(self.lit_heap) == 0: + return 0 + + # Clean out the front of the heap as long the variables are set + while self.variable_set[abs(self.lit_heap[0][1])]: + heappop(self.lit_heap) + if len(self.lit_heap) == 0: + return 0 + + return heappop(self.lit_heap)[1] + + def _vsids_lit_assigned(self, lit): + """Handle the assignment of a literal for the VSIDS heuristic.""" + pass + + def _vsids_lit_unset(self, lit): + """Handle the unsetting of a literal for the VSIDS heuristic. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l.lit_heap + [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)] + + >>> l._vsids_lit_unset(2) + + >>> l.lit_heap + [(-2.0, -3), (-2.0, -2), (-2.0, -2), (-2.0, 2), (-2.0, 3), (0.0, -1), + ...(-2.0, 2), (0.0, 1)] + + """ + var = abs(lit) + heappush(self.lit_heap, (self.lit_scores[var], var)) + heappush(self.lit_heap, (self.lit_scores[-var], -var)) + + def _vsids_clause_added(self, cls): + """Handle the addition of a new clause for the VSIDS heuristic. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.num_learned_clauses + 0 + >>> l.lit_scores + {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0} + + >>> l._vsids_clause_added({2, -3}) + + >>> l.num_learned_clauses + 1 + >>> l.lit_scores + {-3: -1.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -2.0} + + """ + self.num_learned_clauses += 1 + for lit in cls: + self.lit_scores[lit] += 1 + + ######################## + # Clause Learning # + ######################## + def _simple_add_learned_clause(self, cls): + """Add a new clause to the theory. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.num_learned_clauses + 0 + >>> l.clauses + [[2, -3], [1], [3, -3], [2, -2], [3, -2]] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}} + + >>> l._simple_add_learned_clause([3]) + + >>> l.clauses + [[2, -3], [1], [3, -3], [2, -2], [3, -2], [3]] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4, 5}} + + """ + cls_num = len(self.clauses) + self.clauses.append(cls) + + for lit in cls: + self.occurrence_count[lit] += 1 + + self.sentinels[cls[0]].add(cls_num) + self.sentinels[cls[-1]].add(cls_num) + + self.heur_clause_added(cls) + + def _simple_compute_conflict(self): + """ Build a clause representing the fact that at least one decision made + so far is wrong. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l._simple_compute_conflict() + [3] + + """ + return [-(level.decision) for level in self.levels[1:]] + + def _simple_clean_clauses(self): + """Clean up learned clauses.""" + pass + + +class Level: + """ + Represents a single level in the DPLL algorithm, and contains + enough information for a sound backtracking procedure. + """ + + def __init__(self, decision, flipped=False): + self.decision = decision + self.var_settings = set() + self.flipped = flipped diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/lra_theory.py b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/lra_theory.py new file mode 100644 index 0000000000000000000000000000000000000000..1690760d36003aed6866f593120c05a5b8f92c83 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/lra_theory.py @@ -0,0 +1,912 @@ +"""Implements "A Fast Linear-Arithmetic Solver for DPLL(T)" + +The LRASolver class defined in this file can be used +in conjunction with a SAT solver to check the +satisfiability of formulas involving inequalities. + +Here's an example of how that would work: + + Suppose you want to check the satisfiability of + the following formula: + + >>> from sympy.core.relational import Eq + >>> from sympy.abc import x, y + >>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & (~Eq(y, 1) | Eq(1, 2)) + + First a preprocessing step should be done on f. During preprocessing, + f should be checked for any predicates such as `Q.prime` that can't be + handled. Also unequality like `~Eq(y, 1)` should be split. + + I should mention that the paper says to split both equalities and + unequality, but this implementation only requires that unequality + be split. + + >>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & ((y < 1) | (y > 1) | Eq(1, 2)) + + Then an LRASolver instance needs to be initialized with this formula. + + >>> from sympy.assumptions.cnf import CNF, EncodedCNF + >>> from sympy.assumptions.ask import Q + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> cnf = CNF.from_prop(f) + >>> enc = EncodedCNF() + >>> enc.add_from_cnf(cnf) + >>> lra, conflicts = LRASolver.from_encoded_cnf(enc) + + Any immediate one-lital conflicts clauses will be detected here. + In this example, `~Eq(1, 2)` is one such conflict clause. We'll + want to add it to `f` so that the SAT solver is forced to + assign Eq(1, 2) to False. + + >>> f = f & ~Eq(1, 2) + + Now that the one-literal conflict clauses have been added + and an lra object has been initialized, we can pass `f` + to a SAT solver. The SAT solver will give us a satisfying + assignment such as: + + (1 = 2): False + (y = 1): True + (y < 1): True + (y > 1): True + (x = 0): True + (x < 0): True + (x > 0): True + + Next you would pass this assignment to the LRASolver + which will be able to determine that this particular + assignment is satisfiable or not. + + Note that since EncodedCNF is inherently non-deterministic, + the int each predicate is encoded as is not consistent. As a + result, the code below likely does not reflect the assignment + given above. + + >>> lra.assert_lit(-1) #doctest: +SKIP + >>> lra.assert_lit(2) #doctest: +SKIP + >>> lra.assert_lit(3) #doctest: +SKIP + >>> lra.assert_lit(4) #doctest: +SKIP + >>> lra.assert_lit(5) #doctest: +SKIP + >>> lra.assert_lit(6) #doctest: +SKIP + >>> lra.assert_lit(7) #doctest: +SKIP + >>> is_sat, conflict_or_assignment = lra.check() + + As the particular assignment suggested is not satisfiable, + the LRASolver will return unsat and a conflict clause when + given that assignment. The conflict clause will always be + minimal, but there can be multiple minimal conflict clauses. + One possible conflict clause could be `~(x < 0) | ~(x > 0)`. + + We would then add whatever conflict clause is given to + `f` to prevent the SAT solver from coming up with an + assignment with the same conflicting literals. In this case, + the conflict clause `~(x < 0) | ~(x > 0)` would prevent + any assignment where both (x < 0) and (x > 0) were both + true. + + The SAT solver would then find another assignment + and we would check that assignment with the LRASolver + and so on. Eventually either a satisfying assignment + that the SAT solver and LRASolver agreed on would be found + or enough conflict clauses would be added so that the + boolean formula was unsatisfiable. + + +This implementation is based on [1]_, which includes a +detailed explanation of the algorithm and pseudocode +for the most important functions. + +[1]_ also explains how backtracking and theory propagation +could be implemented to speed up the current implementation, +but these are not currently implemented. + +TODO: + - Handle non-rational real numbers + - Handle positive and negative infinity + - Implement backtracking and theory proposition + - Simplify matrix by removing unused variables using Gaussian elimination + +References +========== + +.. [1] Dutertre, B., de Moura, L.: + A Fast Linear-Arithmetic Solver for DPLL(T) + https://link.springer.com/chapter/10.1007/11817963_11 +""" +from sympy.solvers.solveset import linear_eq_to_matrix +from sympy.matrices.dense import eye +from sympy.assumptions import Predicate +from sympy.assumptions.assume import AppliedPredicate +from sympy.assumptions.ask import Q +from sympy.core import Dummy +from sympy.core.mul import Mul +from sympy.core.add import Add +from sympy.core.relational import Eq, Ne +from sympy.core.sympify import sympify +from sympy.core.singleton import S +from sympy.core.numbers import Rational, oo +from sympy.matrices.dense import Matrix + +class UnhandledInput(Exception): + """ + Raised while creating an LRASolver if non-linearity + or non-rational numbers are present. + """ + +# predicates that LRASolver understands and makes use of +ALLOWED_PRED = {Q.eq, Q.gt, Q.lt, Q.le, Q.ge} + +# if true ~Q.gt(x, y) implies Q.le(x, y) +HANDLE_NEGATION = True + +class LRASolver(): + """ + Linear Arithmetic Solver for DPLL(T) implemented with an algorithm based on + the Dual Simplex method. Uses Bland's pivoting rule to avoid cycling. + + References + ========== + + .. [1] Dutertre, B., de Moura, L.: + A Fast Linear-Arithmetic Solver for DPLL(T) + https://link.springer.com/chapter/10.1007/11817963_11 + """ + + def __init__(self, A, slack_variables, nonslack_variables, enc_to_boundary, s_subs, testing_mode): + """ + Use the "from_encoded_cnf" method to create a new LRASolver. + """ + self.run_checks = testing_mode + self.s_subs = s_subs # used only for test_lra_theory.test_random_problems + + if any(not isinstance(a, Rational) for a in A): + raise UnhandledInput("Non-rational numbers are not handled") + if any(not isinstance(b.bound, Rational) for b in enc_to_boundary.values()): + raise UnhandledInput("Non-rational numbers are not handled") + m, n = len(slack_variables), len(slack_variables)+len(nonslack_variables) + if m != 0: + assert A.shape == (m, n) + if self.run_checks: + assert A[:, n-m:] == -eye(m) + + self.enc_to_boundary = enc_to_boundary # mapping of int to Boundary objects + self.boundary_to_enc = {value: key for key, value in enc_to_boundary.items()} + self.A = A + self.slack = slack_variables + self.nonslack = nonslack_variables + self.all_var = nonslack_variables + slack_variables + + self.slack_set = set(slack_variables) + + self.is_sat = True # While True, all constraints asserted so far are satisfiable + self.result = None # always one of: (True, assignment), (False, conflict clause), None + + @staticmethod + def from_encoded_cnf(encoded_cnf, testing_mode=False): + """ + Creates an LRASolver from an EncodedCNF object + and a list of conflict clauses for propositions + that can be simplified to True or False. + + Parameters + ========== + + encoded_cnf : EncodedCNF + + testing_mode : bool + Setting testing_mode to True enables some slow assert statements + and sorting to reduce nonterministic behavior. + + Returns + ======= + + (lra, conflicts) + + lra : LRASolver + + conflicts : list + Contains a one-literal conflict clause for each proposition + that can be simplified to True or False. + + Example + ======= + + >>> from sympy.core.relational import Eq + >>> from sympy.assumptions.cnf import CNF, EncodedCNF + >>> from sympy.assumptions.ask import Q + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> from sympy.abc import x, y, z + >>> phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6)) + >>> phi = phi & (Eq(x + y, 2) | (x + 2 * y - z > 4)) + >>> phi = phi & Q.gt(2, 1) + >>> cnf = CNF.from_prop(phi) + >>> enc = EncodedCNF() + >>> enc.from_cnf(cnf) + >>> lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + >>> lra #doctest: +SKIP + + >>> conflicts #doctest: +SKIP + [[4]] + """ + # This function has three main jobs: + # - raise errors if the input formula is not handled + # - preprocesses the formula into a matrix and single variable constraints + # - create one-literal conflict clauses from predicates that are always True + # or always False such as Q.gt(3, 2) + # + # See the preprocessing section of "A Fast Linear-Arithmetic Solver for DPLL(T)" + # for an explanation of how the formula is converted into a matrix + # and a set of single variable constraints. + + encoding = {} # maps int to boundary + A = [] + + basic = [] + s_count = 0 + s_subs = {} + nonbasic = [] + + if testing_mode: + # sort to reduce nondeterminism + encoded_cnf_items = sorted(encoded_cnf.encoding.items(), key=lambda x: str(x)) + else: + encoded_cnf_items = encoded_cnf.encoding.items() + + empty_var = Dummy() + var_to_lra_var = {} + conflicts = [] + + for prop, enc in encoded_cnf_items: + if isinstance(prop, Predicate): + prop = prop(empty_var) + if not isinstance(prop, AppliedPredicate): + if prop == True: + conflicts.append([enc]) + continue + if prop == False: + conflicts.append([-enc]) + continue + + raise ValueError(f"Unhandled Predicate: {prop}") + + assert prop.function in ALLOWED_PRED + if prop.lhs == S.NaN or prop.rhs == S.NaN: + raise ValueError(f"{prop} contains nan") + if prop.lhs.is_imaginary or prop.rhs.is_imaginary: + raise UnhandledInput(f"{prop} contains an imaginary component") + if prop.lhs == oo or prop.rhs == oo: + raise UnhandledInput(f"{prop} contains infinity") + + prop = _eval_binrel(prop) # simplify variable-less quantities to True / False if possible + if prop == True: + conflicts.append([enc]) + continue + elif prop == False: + conflicts.append([-enc]) + continue + elif prop is None: + raise UnhandledInput(f"{prop} could not be simplified") + + expr = prop.lhs - prop.rhs + if prop.function in [Q.ge, Q.gt]: + expr = -expr + + # expr should be less than (or equal to) 0 + # otherwise prop is False + if prop.function in [Q.le, Q.ge]: + bool = (expr <= 0) + elif prop.function in [Q.lt, Q.gt]: + bool = (expr < 0) + else: + assert prop.function == Q.eq + bool = Eq(expr, 0) + + if bool == True: + conflicts.append([enc]) + continue + elif bool == False: + conflicts.append([-enc]) + continue + + + vars, const = _sep_const_terms(expr) # example: (2x + 3y + 2) --> (2x + 3y), (2) + vars, var_coeff = _sep_const_coeff(vars) # examples: (2x) --> (x, 2); (2x + 3y) --> (2x + 3y), (1) + const = const / var_coeff + + terms = _list_terms(vars) # example: (2x + 3y) --> [2x, 3y] + for term in terms: + term, _ = _sep_const_coeff(term) + assert len(term.free_symbols) > 0 + if term not in var_to_lra_var: + var_to_lra_var[term] = LRAVariable(term) + nonbasic.append(term) + + if len(terms) > 1: + if vars not in s_subs: + s_count += 1 + d = Dummy(f"s{s_count}") + var_to_lra_var[d] = LRAVariable(d) + basic.append(d) + s_subs[vars] = d + A.append(vars - d) + var = s_subs[vars] + else: + var = terms[0] + + assert var_coeff != 0 + + equality = prop.function == Q.eq + upper = var_coeff > 0 if not equality else None + strict = prop.function in [Q.gt, Q.lt] + b = Boundary(var_to_lra_var[var], -const, upper, equality, strict) + encoding[enc] = b + + fs = [v.free_symbols for v in nonbasic + basic] + assert all(len(syms) > 0 for syms in fs) + fs_count = sum(len(syms) for syms in fs) + if len(fs) > 0 and len(set.union(*fs)) < fs_count: + raise UnhandledInput("Nonlinearity is not handled") + + A, _ = linear_eq_to_matrix(A, nonbasic + basic) + nonbasic = [var_to_lra_var[nb] for nb in nonbasic] + basic = [var_to_lra_var[b] for b in basic] + for idx, var in enumerate(nonbasic + basic): + var.col_idx = idx + + return LRASolver(A, basic, nonbasic, encoding, s_subs, testing_mode), conflicts + + def reset_bounds(self): + """ + Resets the state of the LRASolver to before + anything was asserted. + """ + self.result = None + for var in self.all_var: + var.lower = LRARational(-float("inf"), 0) + var.lower_from_eq = False + var.lower_from_neg = False + var.upper = LRARational(float("inf"), 0) + var.upper_from_eq= False + var.lower_from_neg = False + var.assign = LRARational(0, 0) + + def assert_lit(self, enc_constraint): + """ + Assert a literal representing a constraint + and update the internal state accordingly. + + Note that due to peculiarities of this implementation + asserting ~(x > 0) will assert (x <= 0) but asserting + ~Eq(x, 0) will not do anything. + + Parameters + ========== + + enc_constraint : int + A mapping of encodings to constraints + can be found in `self.enc_to_boundary`. + + Returns + ======= + + None or (False, explanation) + + explanation : set of ints + A conflict clause that "explains" why + the literals asserted so far are unsatisfiable. + """ + if abs(enc_constraint) not in self.enc_to_boundary: + return None + + if not HANDLE_NEGATION and enc_constraint < 0: + return None + + boundary = self.enc_to_boundary[abs(enc_constraint)] + sym, c, negated = boundary.var, boundary.bound, enc_constraint < 0 + + if boundary.equality and negated: + return None # negated equality is not handled and should only appear in conflict clauses + + upper = boundary.upper != negated + if boundary.strict != negated: + delta = -1 if upper else 1 + c = LRARational(c, delta) + else: + c = LRARational(c, 0) + + if boundary.equality: + res1 = self._assert_lower(sym, c, from_equality=True, from_neg=negated) + if res1 and res1[0] == False: + res = res1 + else: + res2 = self._assert_upper(sym, c, from_equality=True, from_neg=negated) + res = res2 + elif upper: + res = self._assert_upper(sym, c, from_neg=negated) + else: + res = self._assert_lower(sym, c, from_neg=negated) + + if self.is_sat and sym not in self.slack_set: + self.is_sat = res is None + else: + self.is_sat = False + + return res + + def _assert_upper(self, xi, ci, from_equality=False, from_neg=False): + """ + Adjusts the upper bound on variable xi if the new upper bound is + more limiting. The assignment of variable xi is adjusted to be + within the new bound if needed. + + Also calls `self._update` to update the assignment for slack variables + to keep all equalities satisfied. + """ + if self.result: + assert self.result[0] != False + self.result = None + if ci >= xi.upper: + return None + if ci < xi.lower: + assert (xi.lower[1] >= 0) is True + assert (ci[1] <= 0) is True + + lit1, neg1 = Boundary.from_lower(xi) + + lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=True, equality=from_equality) + if from_neg: + lit2 = lit2.get_negated() + neg2 = -1 if from_neg else 1 + + conflict = [-neg1*self.boundary_to_enc[lit1], -neg2*self.boundary_to_enc[lit2]] + self.result = False, conflict + return self.result + xi.upper = ci + xi.upper_from_eq = from_equality + xi.upper_from_neg = from_neg + if xi in self.nonslack and xi.assign > ci: + self._update(xi, ci) + + if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + M = self.A + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10 ** (-10) for val in M * X) + + return None + + def _assert_lower(self, xi, ci, from_equality=False, from_neg=False): + """ + Adjusts the lower bound on variable xi if the new lower bound is + more limiting. The assignment of variable xi is adjusted to be + within the new bound if needed. + + Also calls `self._update` to update the assignment for slack variables + to keep all equalities satisfied. + """ + if self.result: + assert self.result[0] != False + self.result = None + if ci <= xi.lower: + return None + if ci > xi.upper: + assert (xi.upper[1] <= 0) is True + assert (ci[1] >= 0) is True + + lit1, neg1 = Boundary.from_upper(xi) + + lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=False, equality=from_equality) + if from_neg: + lit2 = lit2.get_negated() + neg2 = -1 if from_neg else 1 + + conflict = [-neg1*self.boundary_to_enc[lit1],-neg2*self.boundary_to_enc[lit2]] + self.result = False, conflict + return self.result + xi.lower = ci + xi.lower_from_eq = from_equality + xi.lower_from_neg = from_neg + if xi in self.nonslack and xi.assign < ci: + self._update(xi, ci) + + if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + M = self.A + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10 ** (-10) for val in M * X) + + return None + + def _update(self, xi, v): + """ + Updates all slack variables that have equations that contain + variable xi so that they stay satisfied given xi is equal to v. + """ + i = xi.col_idx + for j, b in enumerate(self.slack): + aji = self.A[j, i] + b.assign = b.assign + (v - xi.assign)*aji + xi.assign = v + + def check(self): + """ + Searches for an assignment that satisfies all constraints + or determines that no such assignment exists and gives + a minimal conflict clause that "explains" why the + constraints are unsatisfiable. + + Returns + ======= + + (True, assignment) or (False, explanation) + + assignment : dict of LRAVariables to values + Assigned values are tuples that represent a rational number + plus some infinatesimal delta. + + explanation : set of ints + """ + if self.is_sat: + return True, {var: var.assign for var in self.all_var} + if self.result: + return self.result + + from sympy.matrices.dense import Matrix + M = self.A.copy() + basic = {s: i for i, s in enumerate(self.slack)} # contains the row index associated with each basic variable + nonbasic = set(self.nonslack) + while True: + if self.run_checks: + # nonbasic variables must always be within bounds + assert all(((nb.assign >= nb.lower) == True) and ((nb.assign <= nb.upper) == True) for nb in nonbasic) + + # assignments for x must always satisfy Ax = 0 + # probably have to turn this off when dealing with strict ineq + if all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10**(-10) for val in M*X) + + # check upper and lower match this format: + # x <= rat + delta iff x < rat + # x >= rat - delta iff x > rat + # this wouldn't make sense: + # x <= rat - delta + # x >= rat + delta + assert all(x.upper[1] <= 0 for x in self.all_var) + assert all(x.lower[1] >= 0 for x in self.all_var) + + cand = [b for b in basic if b.assign < b.lower or b.assign > b.upper] + + if len(cand) == 0: + return True, {var: var.assign for var in self.all_var} + + xi = min(cand, key=lambda v: v.col_idx) # Bland's rule + i = basic[xi] + + if xi.assign < xi.lower: + cand = [nb for nb in nonbasic + if (M[i, nb.col_idx] > 0 and nb.assign < nb.upper) + or (M[i, nb.col_idx] < 0 and nb.assign > nb.lower)] + if len(cand) == 0: + N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0] + N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0] + + conflict = [] + conflict += [Boundary.from_upper(nb) for nb in N_plus] + conflict += [Boundary.from_lower(nb) for nb in N_minus] + conflict.append(Boundary.from_lower(xi)) + conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict] + return False, conflict + xj = min(cand, key=str) + M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.lower) + + if xi.assign > xi.upper: + cand = [nb for nb in nonbasic + if (M[i, nb.col_idx] < 0 and nb.assign < nb.upper) + or (M[i, nb.col_idx] > 0 and nb.assign > nb.lower)] + + if len(cand) == 0: + N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0] + N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0] + + conflict = [] + conflict += [Boundary.from_upper(nb) for nb in N_minus] + conflict += [Boundary.from_lower(nb) for nb in N_plus] + conflict.append(Boundary.from_upper(xi)) + + conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict] + return False, conflict + xj = min(cand, key=lambda v: v.col_idx) + M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.upper) + + def _pivot_and_update(self, M, basic, nonbasic, xi, xj, v): + """ + Pivots basic variable xi with nonbasic variable xj, + and sets value of xi to v and adjusts the values of all basic variables + to keep equations satisfied. + """ + i, j = basic[xi], xj.col_idx + assert M[i, j] != 0 + theta = (v - xi.assign)*(1/M[i, j]) + xi.assign = v + xj.assign = xj.assign + theta + for xk in basic: + if xk != xi: + k = basic[xk] + akj = M[k, j] + xk.assign = xk.assign + theta*akj + # pivot + basic[xj] = basic[xi] + del basic[xi] + nonbasic.add(xi) + nonbasic.remove(xj) + return self._pivot(M, i, j) + + @staticmethod + def _pivot(M, i, j): + """ + Performs a pivot operation about entry i, j of M by performing + a series of row operations on a copy of M and returning the result. + The original M is left unmodified. + + Conceptually, M represents a system of equations and pivoting + can be thought of as rearranging equation i to be in terms of + variable j and then substituting in the rest of the equations + to get rid of other occurances of variable j. + + Example + ======= + + >>> from sympy.matrices.dense import Matrix + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> from sympy import var + >>> Matrix(3, 3, var('a:i')) + Matrix([ + [a, b, c], + [d, e, f], + [g, h, i]]) + + This matrix is equivalent to: + 0 = a*x + b*y + c*z + 0 = d*x + e*y + f*z + 0 = g*x + h*y + i*z + + >>> LRASolver._pivot(_, 1, 0) + Matrix([ + [ 0, -a*e/d + b, -a*f/d + c], + [-1, -e/d, -f/d], + [ 0, h - e*g/d, i - f*g/d]]) + + We rearrange equation 1 in terms of variable 0 (x) + and substitute to remove x from the other equations. + + 0 = 0 + (-a*e/d + b)*y + (-a*f/d + c)*z + 0 = -x + (-e/d)*y + (-f/d)*z + 0 = 0 + (h - e*g/d)*y + (i - f*g/d)*z + """ + _, _, Mij = M[i, :], M[:, j], M[i, j] + if Mij == 0: + raise ZeroDivisionError("Tried to pivot about zero-valued entry.") + A = M.copy() + A[i, :] = -A[i, :]/Mij + for row in range(M.shape[0]): + if row != i: + A[row, :] = A[row, :] + A[row, j] * A[i, :] + + return A + + +def _sep_const_coeff(expr): + """ + Example + ======= + + >>> from sympy.logic.algorithms.lra_theory import _sep_const_coeff + >>> from sympy.abc import x, y + >>> _sep_const_coeff(2*x) + (x, 2) + >>> _sep_const_coeff(2*x + 3*y) + (2*x + 3*y, 1) + """ + if isinstance(expr, Add): + return expr, sympify(1) + + if isinstance(expr, Mul): + coeffs = expr.args + else: + coeffs = [expr] + + var, const = [], [] + for c in coeffs: + c = sympify(c) + if len(c.free_symbols)==0: + const.append(c) + else: + var.append(c) + return Mul(*var), Mul(*const) + + +def _list_terms(expr): + if not isinstance(expr, Add): + return [expr] + + return expr.args + + +def _sep_const_terms(expr): + """ + Example + ======= + + >>> from sympy.logic.algorithms.lra_theory import _sep_const_terms + >>> from sympy.abc import x, y + >>> _sep_const_terms(2*x + 3*y + 2) + (2*x + 3*y, 2) + """ + if isinstance(expr, Add): + terms = expr.args + else: + terms = [expr] + + var, const = [], [] + for t in terms: + if len(t.free_symbols) == 0: + const.append(t) + else: + var.append(t) + return sum(var), sum(const) + + +def _eval_binrel(binrel): + """ + Simplify binary relation to True / False if possible. + """ + if not (len(binrel.lhs.free_symbols) == 0 and len(binrel.rhs.free_symbols) == 0): + return binrel + if binrel.function == Q.lt: + res = binrel.lhs < binrel.rhs + elif binrel.function == Q.gt: + res = binrel.lhs > binrel.rhs + elif binrel.function == Q.le: + res = binrel.lhs <= binrel.rhs + elif binrel.function == Q.ge: + res = binrel.lhs >= binrel.rhs + elif binrel.function == Q.eq: + res = Eq(binrel.lhs, binrel.rhs) + elif binrel.function == Q.ne: + res = Ne(binrel.lhs, binrel.rhs) + + if res == True or res == False: + return res + else: + return None + + +class Boundary: + """ + Represents an upper or lower bound or an equality between a symbol + and some constant. + """ + def __init__(self, var, const, upper, equality, strict=None): + if not equality in [True, False]: + assert equality in [True, False] + + + self.var = var + if isinstance(const, tuple): + s = const[1] != 0 + if strict: + assert s == strict + self.bound = const[0] + self.strict = s + else: + self.bound = const + self.strict = strict + self.upper = upper if not equality else None + self.equality = equality + self.strict = strict + assert self.strict is not None + + @staticmethod + def from_upper(var): + neg = -1 if var.upper_from_neg else 1 + b = Boundary(var, var.upper[0], True, var.upper_from_eq, var.upper[1] != 0) + if neg < 0: + b = b.get_negated() + return b, neg + + @staticmethod + def from_lower(var): + neg = -1 if var.lower_from_neg else 1 + b = Boundary(var, var.lower[0], False, var.lower_from_eq, var.lower[1] != 0) + if neg < 0: + b = b.get_negated() + return b, neg + + def get_negated(self): + return Boundary(self.var, self.bound, not self.upper, self.equality, not self.strict) + + def get_inequality(self): + if self.equality: + return Eq(self.var.var, self.bound) + elif self.upper and self.strict: + return self.var.var < self.bound + elif not self.upper and self.strict: + return self.var.var > self.bound + elif self.upper: + return self.var.var <= self.bound + else: + return self.var.var >= self.bound + + def __repr__(self): + return repr("Boundary(" + repr(self.get_inequality()) + ")") + + def __eq__(self, other): + other = (other.var, other.bound, other.strict, other.upper, other.equality) + return (self.var, self.bound, self.strict, self.upper, self.equality) == other + + def __hash__(self): + return hash((self.var, self.bound, self.strict, self.upper, self.equality)) + + +class LRARational(): + """ + Represents a rational plus or minus some amount + of arbitrary small deltas. + """ + def __init__(self, rational, delta): + self.value = (rational, delta) + + def __lt__(self, other): + return self.value < other.value + + def __le__(self, other): + return self.value <= other.value + + def __eq__(self, other): + return self.value == other.value + + def __add__(self, other): + return LRARational(self.value[0] + other.value[0], self.value[1] + other.value[1]) + + def __sub__(self, other): + return LRARational(self.value[0] - other.value[0], self.value[1] - other.value[1]) + + def __mul__(self, other): + assert not isinstance(other, LRARational) + return LRARational(self.value[0] * other, self.value[1] * other) + + def __getitem__(self, index): + return self.value[index] + + def __repr__(self): + return repr(self.value) + + +class LRAVariable(): + """ + Object to keep track of upper and lower bounds + on `self.var`. + """ + def __init__(self, var): + self.upper = LRARational(float("inf"), 0) + self.upper_from_eq = False + self.upper_from_neg = False + self.lower = LRARational(-float("inf"), 0) + self.lower_from_eq = False + self.lower_from_neg = False + self.assign = LRARational(0,0) + self.var = var + self.col_idx = None + + def __repr__(self): + return repr(self.var) + + def __eq__(self, other): + if not isinstance(other, LRAVariable): + return False + return other.var == self.var + + def __hash__(self): + return hash(self.var) diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/minisat22_wrapper.py b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/minisat22_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5c1f8f14f04309f7cb8197cc05d01a3c108545 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/minisat22_wrapper.py @@ -0,0 +1,46 @@ +from sympy.assumptions.cnf import EncodedCNF + +def minisat22_satisfiable(expr, all_models=False, minimal=False): + + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + from pysat.solvers import Minisat22 + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + r = Minisat22(expr.data) + + if minimal: + r.set_phases([-(i+1) for i in range(r.nof_vars())]) + + if not r.solve(): + return False + + if not all_models: + return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r.get_model()} + + else: + # Make solutions SymPy compatible by creating a generator + def _gen(results): + satisfiable = False + while results.solve(): + sol = results.get_model() + yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol} + if minimal: + results.add_clause([-i for i in sol if i>0]) + else: + results.add_clause([-i for i in sol]) + satisfiable = True + if not satisfiable: + yield False + raise StopIteration + + + return _gen(r) diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/pycosat_wrapper.py b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/pycosat_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..5ff498b7e3f6b73d95e9b949598ef32df4ecf226 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/pycosat_wrapper.py @@ -0,0 +1,41 @@ +from sympy.assumptions.cnf import EncodedCNF + + +def pycosat_satisfiable(expr, all_models=False): + import pycosat + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + if not all_models: + r = pycosat.solve(expr.data) + result = (r != "UNSAT") + if not result: + return result + return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r} + else: + r = pycosat.itersolve(expr.data) + result = (r != "UNSAT") + if not result: + return result + + # Make solutions SymPy compatible by creating a generator + def _gen(results): + satisfiable = False + try: + while True: + sol = next(results) + yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol} + satisfiable = True + except StopIteration: + if not satisfiable: + yield False + + return _gen(r) diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/z3_wrapper.py b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/z3_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..fe44f713a2edfd5286c0f81b737212146766b11b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/algorithms/z3_wrapper.py @@ -0,0 +1,115 @@ +from sympy.printing.smtlib import smtlib_code +from sympy.assumptions.assume import AppliedPredicate +from sympy.assumptions.cnf import EncodedCNF +from sympy.assumptions.ask import Q + +from sympy.core import Add, Mul +from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import Pow +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import And, Or, Xor, Implies +from sympy.logic.boolalg import Not, ITE +from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate +from sympy.external import import_module + +def z3_satisfiable(expr, all_models=False): + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + z3 = import_module("z3") + if z3 is None: + raise ImportError("z3 is not installed") + + s = encoded_cnf_to_z3_solver(expr, z3) + + res = str(s.check()) + if res == "unsat": + return False + elif res == "sat": + return z3_model_to_sympy_model(s.model(), expr) + else: + return None + + +def z3_model_to_sympy_model(z3_model, enc_cnf): + rev_enc = {value : key for key, value in enc_cnf.encoding.items()} + return {rev_enc[int(var.name()[1:])] : bool(z3_model[var]) for var in z3_model} + + +def clause_to_assertion(clause): + clause_strings = [f"d{abs(lit)}" if lit > 0 else f"(not d{abs(lit)})" for lit in clause] + return "(assert (or " + " ".join(clause_strings) + "))" + + +def encoded_cnf_to_z3_solver(enc_cnf, z3): + def dummify_bool(pred): + return False + assert isinstance(pred, AppliedPredicate) + + if pred.function in [Q.positive, Q.negative, Q.zero]: + return pred + else: + return False + + s = z3.Solver() + + declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables] + assertions = [clause_to_assertion(clause) for clause in enc_cnf.data] + + symbols = set() + for pred, enc in enc_cnf.encoding.items(): + if not isinstance(pred, AppliedPredicate): + continue + if pred.function not in (Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq, Q.positive, Q.negative, Q.extended_negative, Q.extended_positive, Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive, Q.extended_nonzero, Q.extended_nonnegative, Q.extended_nonpositive): + continue + + pred_str = smtlib_code(pred, auto_declare=False, auto_assert=False, known_functions=known_functions) + + symbols |= pred.free_symbols + pred = pred_str + clause = f"(implies d{enc} {pred})" + assertion = "(assert " + clause + ")" + assertions.append(assertion) + + for sym in symbols: + declarations.append(f"(declare-const {sym} Real)") + + declarations = "\n".join(declarations) + assertions = "\n".join(assertions) + s.from_string(declarations) + s.from_string(assertions) + + return s + + +known_functions = { + Add: '+', + Mul: '*', + + Equality: '=', + LessThan: '<=', + GreaterThan: '>=', + StrictLessThan: '<', + StrictGreaterThan: '>', + + EqualityPredicate(): '=', + LessThanPredicate(): '<=', + GreaterThanPredicate(): '>=', + StrictLessThanPredicate(): '<', + StrictGreaterThanPredicate(): '>', + + Abs: 'abs', + Min: 'min', + Max: 'max', + Pow: '^', + + And: 'and', + Or: 'or', + Xor: 'xor', + Not: 'not', + ITE: 'ite', + Implies: '=>', + } diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/logic/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_boolalg.py b/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_boolalg.py new file mode 100644 index 0000000000000000000000000000000000000000..88cdd647fdcc723faee328f71df96030841a3edb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_boolalg.py @@ -0,0 +1,1367 @@ +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine +from sympy.core.numbers import oo +from sympy.core.relational import Equality, Eq, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions import Piecewise +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.sets.sets import Interval, Union +from sympy.sets.contains import Contains +from sympy.simplify.simplify import simplify +from sympy.logic.boolalg import ( + And, Boolean, Equivalent, ITE, Implies, Nand, Nor, Not, Or, + POSform, SOPform, Xor, Xnor, conjuncts, disjuncts, + distribute_or_over_and, distribute_and_over_or, + eliminate_implications, is_nnf, is_cnf, is_dnf, simplify_logic, + to_nnf, to_cnf, to_dnf, to_int_repr, bool_map, true, false, + BooleanAtom, is_literal, term_to_integer, + truth_table, as_Boolean, to_anf, is_anf, distribute_xor_over_and, + anf_coeffs, ANFform, bool_minterm, bool_maxterm, bool_monomial, + _check_pair, _convert_to_varsSOP, _convert_to_varsPOS, Exclusive, + gateinputcount) +from sympy.assumptions.cnf import CNF + +from sympy.testing.pytest import raises, XFAIL, slow + +from itertools import combinations, permutations, product + +A, B, C, D = symbols('A:D') +a, b, c, d, e, w, x, y, z = symbols('a:e w:z') + + +def test_overloading(): + """Test that |, & are overloaded as expected""" + + assert A & B == And(A, B) + assert A | B == Or(A, B) + assert (A & B) | C == Or(And(A, B), C) + assert A >> B == Implies(A, B) + assert A << B == Implies(B, A) + assert ~A == Not(A) + assert A ^ B == Xor(A, B) + + +def test_And(): + assert And() is true + assert And(A) == A + assert And(True) is true + assert And(False) is false + assert And(True, True) is true + assert And(True, False) is false + assert And(False, False) is false + assert And(True, A) == A + assert And(False, A) is false + assert And(True, True, True) is true + assert And(True, True, A) == A + assert And(True, False, A) is false + assert And(1, A) == A + raises(TypeError, lambda: And(2, A)) + assert And(A < 1, A >= 1) is false + e = A > 1 + assert And(e, e.canonical) == e.canonical + g, l, ge, le = A > B, B < A, A >= B, B <= A + assert And(g, l, ge, le) == And(ge, g) + assert {And(*i) for i in permutations((l, g, le, ge))} == {And(ge, g)} + assert And(And(Eq(a, 0), Eq(b, 0)), And(Ne(a, 0), Eq(c, 0))) is false + + +def test_Or(): + assert Or() is false + assert Or(A) == A + assert Or(True) is true + assert Or(False) is false + assert Or(True, True) is true + assert Or(True, False) is true + assert Or(False, False) is false + assert Or(True, A) is true + assert Or(False, A) == A + assert Or(True, False, False) is true + assert Or(True, False, A) is true + assert Or(False, False, A) == A + assert Or(1, A) is true + raises(TypeError, lambda: Or(2, A)) + assert Or(A < 1, A >= 1) is true + e = A > 1 + assert Or(e, e.canonical) == e + g, l, ge, le = A > B, B < A, A >= B, B <= A + assert Or(g, l, ge, le) == Or(g, ge) + + +def test_Xor(): + assert Xor() is false + assert Xor(A) == A + assert Xor(A, A) is false + assert Xor(True, A, A) is true + assert Xor(A, A, A, A, A) == A + assert Xor(True, False, False, A, B) == ~Xor(A, B) + assert Xor(True) is true + assert Xor(False) is false + assert Xor(True, True) is false + assert Xor(True, False) is true + assert Xor(False, False) is false + assert Xor(True, A) == ~A + assert Xor(False, A) == A + assert Xor(True, False, False) is true + assert Xor(True, False, A) == ~A + assert Xor(False, False, A) == A + assert isinstance(Xor(A, B), Xor) + assert Xor(A, B, Xor(C, D)) == Xor(A, B, C, D) + assert Xor(A, B, Xor(B, C)) == Xor(A, C) + assert Xor(A < 1, A >= 1, B) == Xor(0, 1, B) == Xor(1, 0, B) + e = A > 1 + assert Xor(e, e.canonical) == Xor(0, 0) == Xor(1, 1) + + +def test_rewrite_as_And(): + expr = x ^ y + assert expr.rewrite(And) == (x | y) & (~x | ~y) + + +def test_rewrite_as_Or(): + expr = x ^ y + assert expr.rewrite(Or) == (x & ~y) | (y & ~x) + + +def test_rewrite_as_Nand(): + expr = (y & z) | (z & ~w) + assert expr.rewrite(Nand) == ~(~(y & z) & ~(z & ~w)) + + +def test_rewrite_as_Nor(): + expr = z & (y | ~w) + assert expr.rewrite(Nor) == ~(~z | ~(y | ~w)) + + +def test_Not(): + raises(TypeError, lambda: Not(True, False)) + assert Not(True) is false + assert Not(False) is true + assert Not(0) is true + assert Not(1) is false + assert Not(2) is false + + +def test_Nand(): + assert Nand() is false + assert Nand(A) == ~A + assert Nand(True) is false + assert Nand(False) is true + assert Nand(True, True) is false + assert Nand(True, False) is true + assert Nand(False, False) is true + assert Nand(True, A) == ~A + assert Nand(False, A) is true + assert Nand(True, True, True) is false + assert Nand(True, True, A) == ~A + assert Nand(True, False, A) is true + + +def test_Nor(): + assert Nor() is true + assert Nor(A) == ~A + assert Nor(True) is false + assert Nor(False) is true + assert Nor(True, True) is false + assert Nor(True, False) is false + assert Nor(False, False) is true + assert Nor(True, A) is false + assert Nor(False, A) == ~A + assert Nor(True, True, True) is false + assert Nor(True, True, A) is false + assert Nor(True, False, A) is false + + +def test_Xnor(): + assert Xnor() is true + assert Xnor(A) == ~A + assert Xnor(A, A) is true + assert Xnor(True, A, A) is false + assert Xnor(A, A, A, A, A) == ~A + assert Xnor(True) is false + assert Xnor(False) is true + assert Xnor(True, True) is true + assert Xnor(True, False) is false + assert Xnor(False, False) is true + assert Xnor(True, A) == A + assert Xnor(False, A) == ~A + assert Xnor(True, False, False) is false + assert Xnor(True, False, A) == A + assert Xnor(False, False, A) == ~A + + +def test_Implies(): + raises(ValueError, lambda: Implies(A, B, C)) + assert Implies(True, True) is true + assert Implies(True, False) is false + assert Implies(False, True) is true + assert Implies(False, False) is true + assert Implies(0, A) is true + assert Implies(1, 1) is true + assert Implies(1, 0) is false + assert A >> B == B << A + assert (A < 1) >> (A >= 1) == (A >= 1) + assert (A < 1) >> (S.One > A) is true + assert A >> A is true + + +def test_Equivalent(): + assert Equivalent(A, B) == Equivalent(B, A) == Equivalent(A, B, A) + assert Equivalent() is true + assert Equivalent(A, A) == Equivalent(A) is true + assert Equivalent(True, True) == Equivalent(False, False) is true + assert Equivalent(True, False) == Equivalent(False, True) is false + assert Equivalent(A, True) == A + assert Equivalent(A, False) == Not(A) + assert Equivalent(A, B, True) == A & B + assert Equivalent(A, B, False) == ~A & ~B + assert Equivalent(1, A) == A + assert Equivalent(0, A) == Not(A) + assert Equivalent(A, Equivalent(B, C)) != Equivalent(Equivalent(A, B), C) + assert Equivalent(A < 1, A >= 1) is false + assert Equivalent(A < 1, A >= 1, 0) is false + assert Equivalent(A < 1, A >= 1, 1) is false + assert Equivalent(A < 1, S.One > A) == Equivalent(1, 1) == Equivalent(0, 0) + assert Equivalent(Equality(A, B), Equality(B, A)) is true + + +def test_Exclusive(): + assert Exclusive(False, False, False) is true + assert Exclusive(True, False, False) is true + assert Exclusive(True, True, False) is false + assert Exclusive(True, True, True) is false + + +def test_equals(): + assert Not(Or(A, B)).equals(And(Not(A), Not(B))) is True + assert Equivalent(A, B).equals((A >> B) & (B >> A)) is True + assert ((A | ~B) & (~A | B)).equals((~A & ~B) | (A & B)) is True + assert (A >> B).equals(~A >> ~B) is False + assert (A >> (B >> A)).equals(A >> (C >> A)) is False + raises(NotImplementedError, lambda: (A & B).equals(A > B)) + + +def test_simplification_boolalg(): + """ + Test working of simplification methods. + """ + set1 = [[0, 0, 1], [0, 1, 1], [1, 0, 0], [1, 1, 0]] + set2 = [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]] + assert SOPform([x, y, z], set1) == Or(And(Not(x), z), And(Not(z), x)) + assert Not(SOPform([x, y, z], set2)) == \ + Not(Or(And(Not(x), Not(z)), And(x, z))) + assert POSform([x, y, z], set1 + set2) is true + assert SOPform([x, y, z], set1 + set2) is true + assert SOPform([Dummy(), Dummy(), Dummy()], set1 + set2) is true + + minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], + [1, 1, 1, 1]] + dontcares = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, 3, 7, 11, 15] + dontcares = [0, 2, 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, [0, 0, 1, 1], 7, [1, 0, 1, 1], + [1, 1, 1, 1]] + dontcares = [0, [0, 0, 1, 0], 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, {y: 1, z: 1}] + dontcares = [0, [0, 0, 1, 0], 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [{y: 1, z: 1}, 1] + dontcares = [[0, 0, 0, 0]] + + minterms = [[0, 0, 0]] + raises(ValueError, lambda: SOPform([w, x, y, z], minterms)) + raises(ValueError, lambda: POSform([w, x, y, z], minterms)) + + raises(TypeError, lambda: POSform([w, x, y, z], ["abcdefg"])) + + # test simplification + ans = And(A, Or(B, C)) + assert simplify_logic(A & (B | C)) == ans + assert simplify_logic((A & B) | (A & C)) == ans + assert simplify_logic(Implies(A, B)) == Or(Not(A), B) + assert simplify_logic(Equivalent(A, B)) == \ + Or(And(A, B), And(Not(A), Not(B))) + assert simplify_logic(And(Equality(A, 2), C)) == And(Equality(A, 2), C) + assert simplify_logic(And(Equality(A, 2), A)) == And(Equality(A, 2), A) + assert simplify_logic(And(Equality(A, B), C)) == And(Equality(A, B), C) + assert simplify_logic(Or(And(Equality(A, 3), B), And(Equality(A, 3), C))) \ + == And(Equality(A, 3), Or(B, C)) + b = (~x & ~y & ~z) | (~x & ~y & z) + e = And(A, b) + assert simplify_logic(e) == A & ~x & ~y + raises(ValueError, lambda: simplify_logic(A & (B | C), form='blabla')) + assert simplify(Or(x <= y, And(x < y, z))) == (x <= y) + assert simplify(Or(x <= y, And(y > x, z))) == (x <= y) + assert simplify(Or(x >= y, And(y < x, z))) == (x >= y) + + # Check that expressions with nine variables or more are not simplified + # (without the force-flag) + a, b, c, d, e, f, g, h, j = symbols('a b c d e f g h j') + expr = a & b & c & d & e & f & g & h & j | \ + a & b & c & d & e & f & g & h & ~j + # This expression can be simplified to get rid of the j variables + assert simplify_logic(expr) == expr + + # Test dontcare + assert simplify_logic((a & b) | c | d, dontcare=(a & b)) == c | d + + # check input + ans = SOPform([x, y], [[1, 0]]) + assert SOPform([x, y], [[1, 0]]) == ans + assert POSform([x, y], [[1, 0]]) == ans + + raises(ValueError, lambda: SOPform([x], [[1]], [[1]])) + assert SOPform([x], [[1]], [[0]]) is true + assert SOPform([x], [[0]], [[1]]) is true + assert SOPform([x], [], []) is false + + raises(ValueError, lambda: POSform([x], [[1]], [[1]])) + assert POSform([x], [[1]], [[0]]) is true + assert POSform([x], [[0]], [[1]]) is true + assert POSform([x], [], []) is false + + # check working of simplify + assert simplify((A & B) | (A & C)) == And(A, Or(B, C)) + assert simplify(And(x, Not(x))) == False + assert simplify(Or(x, Not(x))) == True + assert simplify(And(Eq(x, 0), Eq(x, y))) == And(Eq(x, 0), Eq(y, 0)) + assert And(Eq(x - 1, 0), Eq(x, y)).simplify() == And(Eq(x, 1), Eq(y, 1)) + assert And(Ne(x - 1, 0), Ne(x, y)).simplify() == And(Ne(x, 1), Ne(x, y)) + assert And(Eq(x - 1, 0), Ne(x, y)).simplify() == And(Eq(x, 1), Ne(y, 1)) + assert And(Eq(x - 1, 0), Eq(x, z + y), Eq(y + x, 0)).simplify( + ) == And(Eq(x, 1), Eq(y, -1), Eq(z, 2)) + assert And(Eq(x - 1, 0), Eq(x + 2, 3)).simplify() == Eq(x, 1) + assert And(Ne(x - 1, 0), Ne(x + 2, 3)).simplify() == Ne(x, 1) + assert And(Eq(x - 1, 0), Eq(x + 2, 2)).simplify() == False + assert And(Ne(x - 1, 0), Ne(x + 2, 2)).simplify( + ) == And(Ne(x, 1), Ne(x, 0)) + assert simplify(Xor(x, ~x)) == True + + +def test_bool_map(): + """ + Test working of bool_map function. + """ + + minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], + [1, 1, 1, 1]] + assert bool_map(Not(Not(a)), a) == (a, {a: a}) + assert bool_map(SOPform([w, x, y, z], minterms), + POSform([w, x, y, z], minterms)) == \ + (And(Or(Not(w), y), Or(Not(x), y), z), {x: x, w: w, z: z, y: y}) + assert bool_map(SOPform([x, z, y], [[1, 0, 1]]), + SOPform([a, b, c], [[1, 0, 1]])) != False + function1 = SOPform([x, z, y], [[1, 0, 1], [0, 0, 1]]) + function2 = SOPform([a, b, c], [[1, 0, 1], [1, 0, 0]]) + assert bool_map(function1, function2) == \ + (function1, {y: a, z: b}) + assert bool_map(Xor(x, y), ~Xor(x, y)) == False + assert bool_map(And(x, y), Or(x, y)) is None + assert bool_map(And(x, y), And(x, y, z)) is None + # issue 16179 + assert bool_map(Xor(x, y, z), ~Xor(x, y, z)) == False + assert bool_map(Xor(a, x, y, z), ~Xor(a, x, y, z)) == False + + +def test_bool_symbol(): + """Test that mixing symbols with boolean values + works as expected""" + + assert And(A, True) == A + assert And(A, True, True) == A + assert And(A, False) is false + assert And(A, True, False) is false + assert Or(A, True) is true + assert Or(A, False) == A + + +def test_is_boolean(): + assert isinstance(True, Boolean) is False + assert isinstance(true, Boolean) is True + assert 1 == True + assert 1 != true + assert (1 == true) is False + assert 0 == False + assert 0 != false + assert (0 == false) is False + assert true.is_Boolean is True + assert (A & B).is_Boolean + assert (A | B).is_Boolean + assert (~A).is_Boolean + assert (A ^ B).is_Boolean + assert A.is_Boolean != isinstance(A, Boolean) + assert isinstance(A, Boolean) + + +def test_subs(): + assert (A & B).subs(A, True) == B + assert (A & B).subs(A, False) is false + assert (A & B).subs(B, True) == A + assert (A & B).subs(B, False) is false + assert (A & B).subs({A: True, B: True}) is true + assert (A | B).subs(A, True) is true + assert (A | B).subs(A, False) == B + assert (A | B).subs(B, True) is true + assert (A | B).subs(B, False) == A + assert (A | B).subs({A: True, B: True}) is true + + +""" +we test for axioms of boolean algebra +see https://en.wikipedia.org/wiki/Boolean_algebra_(structure) +""" + + +def test_commutative(): + """Test for commutativity of And and Or""" + A, B = map(Boolean, symbols('A,B')) + + assert A & B == B & A + assert A | B == B | A + + +def test_and_associativity(): + """Test for associativity of And""" + + assert (A & B) & C == A & (B & C) + + +def test_or_assicativity(): + assert ((A | B) | C) == (A | (B | C)) + + +def test_double_negation(): + a = Boolean() + assert ~(~a) == a + + +# test methods + +def test_eliminate_implications(): + assert eliminate_implications(Implies(A, B, evaluate=False)) == (~A) | B + assert eliminate_implications( + A >> (C >> Not(B))) == Or(Or(Not(B), Not(C)), Not(A)) + assert eliminate_implications(Equivalent(A, B, C, D)) == \ + (~A | B) & (~B | C) & (~C | D) & (~D | A) + + +def test_conjuncts(): + assert conjuncts(A & B & C) == {A, B, C} + assert conjuncts((A | B) & C) == {A | B, C} + assert conjuncts(A) == {A} + assert conjuncts(True) == {True} + assert conjuncts(False) == {False} + + +def test_disjuncts(): + assert disjuncts(A | B | C) == {A, B, C} + assert disjuncts((A | B) & C) == {(A | B) & C} + assert disjuncts(A) == {A} + assert disjuncts(True) == {True} + assert disjuncts(False) == {False} + + +def test_distribute(): + assert distribute_and_over_or(Or(And(A, B), C)) == And(Or(A, C), Or(B, C)) + assert distribute_or_over_and(And(A, Or(B, C))) == Or(And(A, B), And(A, C)) + assert distribute_xor_over_and(And(A, Xor(B, C))) == Xor(And(A, B), And(A, C)) + + +def test_to_anf(): + x, y, z = symbols('x,y,z') + assert to_anf(And(x, y)) == And(x, y) + assert to_anf(Or(x, y)) == Xor(x, y, And(x, y)) + assert to_anf(Or(Implies(x, y), And(x, y), y)) == \ + Xor(x, True, x & y, remove_true=False) + assert to_anf(Or(Nand(x, y), Nor(x, y), Xnor(x, y), Implies(x, y))) == True + assert to_anf(Or(x, Not(y), Nor(x, z), And(x, y), Nand(y, z))) == \ + Xor(True, And(y, z), And(x, y, z), remove_true=False) + assert to_anf(Xor(x, y)) == Xor(x, y) + assert to_anf(Not(x)) == Xor(x, True, remove_true=False) + assert to_anf(Nand(x, y)) == Xor(True, And(x, y), remove_true=False) + assert to_anf(Nor(x, y)) == Xor(x, y, True, And(x, y), remove_true=False) + assert to_anf(Implies(x, y)) == Xor(x, True, And(x, y), remove_true=False) + assert to_anf(Equivalent(x, y)) == Xor(x, y, True, remove_true=False) + assert to_anf(Nand(x | y, x >> y), deep=False) == \ + Xor(True, And(Or(x, y), Implies(x, y)), remove_true=False) + assert to_anf(Nor(x ^ y, x & y), deep=False) == \ + Xor(True, Or(Xor(x, y), And(x, y)), remove_true=False) + # issue 25218 + assert to_anf(x ^ ~(x ^ y ^ ~y)) == False + + +def test_to_nnf(): + assert to_nnf(true) is true + assert to_nnf(false) is false + assert to_nnf(A) == A + assert to_nnf(A | ~A | B) is true + assert to_nnf(A & ~A & B) is false + assert to_nnf(A >> B) == ~A | B + assert to_nnf(Equivalent(A, B, C)) == (~A | B) & (~B | C) & (~C | A) + assert to_nnf(A ^ B ^ C) == \ + (A | B | C) & (~A | ~B | C) & (A | ~B | ~C) & (~A | B | ~C) + assert to_nnf(ITE(A, B, C)) == (~A | B) & (A | C) + assert to_nnf(Not(A | B | C)) == ~A & ~B & ~C + assert to_nnf(Not(A & B & C)) == ~A | ~B | ~C + assert to_nnf(Not(A >> B)) == A & ~B + assert to_nnf(Not(Equivalent(A, B, C))) == And(Or(A, B, C), Or(~A, ~B, ~C)) + assert to_nnf(Not(A ^ B ^ C)) == \ + (~A | B | C) & (A | ~B | C) & (A | B | ~C) & (~A | ~B | ~C) + assert to_nnf(Not(ITE(A, B, C))) == (~A | ~B) & (A | ~C) + assert to_nnf((A >> B) ^ (B >> A)) == (A & ~B) | (~A & B) + assert to_nnf((A >> B) ^ (B >> A), False) == \ + (~A | ~B | A | B) & ((A & ~B) | (~A & B)) + assert ITE(A, 1, 0).to_nnf() == A + assert ITE(A, 0, 1).to_nnf() == ~A + # although ITE can hold non-Boolean, it will complain if + # an attempt is made to convert the ITE to Boolean nnf + raises(TypeError, lambda: ITE(A < 1, [1], B).to_nnf()) + + +def test_to_cnf(): + assert to_cnf(~(B | C)) == And(Not(B), Not(C)) + assert to_cnf((A & B) | C) == And(Or(A, C), Or(B, C)) + assert to_cnf(A >> B) == (~A) | B + assert to_cnf(A >> (B & C)) == (~A | B) & (~A | C) + assert to_cnf(A & (B | C) | ~A & (B | C), True) == B | C + assert to_cnf(A & B) == And(A, B) + + assert to_cnf(Equivalent(A, B)) == And(Or(A, Not(B)), Or(B, Not(A))) + assert to_cnf(Equivalent(A, B & C)) == \ + (~A | B) & (~A | C) & (~B | ~C | A) + assert to_cnf(Equivalent(A, B | C), True) == \ + And(Or(Not(B), A), Or(Not(C), A), Or(B, C, Not(A))) + assert to_cnf(A + 1) == A + 1 + + +def test_issue_18904(): + x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 = symbols('x1:16') + eq = ((x1 & x2 & x3 & x4 & x5 & x6 & x7 & x8 & x9) | + (x1 & x2 & x3 & x4 & x5 & x6 & x7 & x10 & x9) | + (x1 & x11 & x3 & x12 & x5 & x13 & x14 & x15 & x9)) + assert is_cnf(to_cnf(eq)) + raises(ValueError, lambda: to_cnf(eq, simplify=True)) + for f, t in zip((And, Or), (to_cnf, to_dnf)): + eq = f(x1, x2, x3, x4, x5, x6, x7, x8, x9) + raises(ValueError, lambda: to_cnf(eq, simplify=True)) + assert t(eq, simplify=True, force=True) == eq + + +def test_issue_9949(): + assert is_cnf(to_cnf((b > -5) | (a > 2) & (a < 4))) + + +def test_to_CNF(): + assert CNF.CNF_to_cnf(CNF.to_CNF(~(B | C))) == to_cnf(~(B | C)) + assert CNF.CNF_to_cnf(CNF.to_CNF((A & B) | C)) == to_cnf((A & B) | C) + assert CNF.CNF_to_cnf(CNF.to_CNF(A >> B)) == to_cnf(A >> B) + assert CNF.CNF_to_cnf(CNF.to_CNF(A >> (B & C))) == to_cnf(A >> (B & C)) + assert CNF.CNF_to_cnf(CNF.to_CNF(A & (B | C) | ~A & (B | C))) == to_cnf(A & (B | C) | ~A & (B | C)) + assert CNF.CNF_to_cnf(CNF.to_CNF(A & B)) == to_cnf(A & B) + + +def test_to_dnf(): + assert to_dnf(~(B | C)) == And(Not(B), Not(C)) + assert to_dnf(A & (B | C)) == Or(And(A, B), And(A, C)) + assert to_dnf(A >> B) == (~A) | B + assert to_dnf(A >> (B & C)) == (~A) | (B & C) + assert to_dnf(A | B) == A | B + + assert to_dnf(Equivalent(A, B), True) == \ + Or(And(A, B), And(Not(A), Not(B))) + assert to_dnf(Equivalent(A, B & C), True) == \ + Or(And(A, B, C), And(Not(A), Not(B)), And(Not(A), Not(C))) + assert to_dnf(A + 1) == A + 1 + + +def test_to_int_repr(): + x, y, z = map(Boolean, symbols('x,y,z')) + + def sorted_recursive(arg): + try: + return sorted(sorted_recursive(x) for x in arg) + except TypeError: # arg is not a sequence + return arg + + assert sorted_recursive(to_int_repr([x | y, z | x], [x, y, z])) == \ + sorted_recursive([[1, 2], [1, 3]]) + assert sorted_recursive(to_int_repr([x | y, z | ~x], [x, y, z])) == \ + sorted_recursive([[1, 2], [3, -1]]) + + +def test_is_anf(): + x, y = symbols('x,y') + assert is_anf(true) is True + assert is_anf(false) is True + assert is_anf(x) is True + assert is_anf(And(x, y)) is True + assert is_anf(Xor(x, y, And(x, y))) is True + assert is_anf(Xor(x, y, Or(x, y))) is False + assert is_anf(Xor(Not(x), y)) is False + + +def test_is_nnf(): + assert is_nnf(true) is True + assert is_nnf(A) is True + assert is_nnf(~A) is True + assert is_nnf(A & B) is True + assert is_nnf((A & B) | (~A & A) | (~B & B) | (~A & ~B), False) is True + assert is_nnf((A | B) & (~A | ~B)) is True + assert is_nnf(Not(Or(A, B))) is False + assert is_nnf(A ^ B) is False + assert is_nnf((A & B) | (~A & A) | (~B & B) | (~A & ~B), True) is False + + +def test_is_cnf(): + assert is_cnf(x) is True + assert is_cnf(x | y | z) is True + assert is_cnf(x & y & z) is True + assert is_cnf((x | y) & z) is True + assert is_cnf((x & y) | z) is False + assert is_cnf(~(x & y) | z) is False + + +def test_is_dnf(): + assert is_dnf(x) is True + assert is_dnf(x | y | z) is True + assert is_dnf(x & y & z) is True + assert is_dnf((x & y) | z) is True + assert is_dnf((x | y) & z) is False + assert is_dnf(~(x | y) & z) is False + + +def test_ITE(): + A, B, C = symbols('A:C') + assert ITE(True, False, True) is false + assert ITE(True, True, False) is true + assert ITE(False, True, False) is false + assert ITE(False, False, True) is true + assert isinstance(ITE(A, B, C), ITE) + + A = True + assert ITE(A, B, C) == B + A = False + assert ITE(A, B, C) == C + B = True + assert ITE(And(A, B), B, C) == C + assert ITE(Or(A, False), And(B, True), False) is false + assert ITE(x, A, B) == Not(x) + assert ITE(x, B, A) == x + assert ITE(1, x, y) == x + assert ITE(0, x, y) == y + raises(TypeError, lambda: ITE(2, x, y)) + raises(TypeError, lambda: ITE(1, [], y)) + raises(TypeError, lambda: ITE(1, (), y)) + raises(TypeError, lambda: ITE(1, y, [])) + assert ITE(1, 1, 1) is S.true + assert isinstance(ITE(1, 1, 1, evaluate=False), ITE) + + assert ITE(Eq(x, True), y, x) == ITE(x, y, x) + assert ITE(Eq(x, False), y, x) == ITE(~x, y, x) + assert ITE(Ne(x, True), y, x) == ITE(~x, y, x) + assert ITE(Ne(x, False), y, x) == ITE(x, y, x) + assert ITE(Eq(S.true, x), y, x) == ITE(x, y, x) + assert ITE(Eq(S.false, x), y, x) == ITE(~x, y, x) + assert ITE(Ne(S.true, x), y, x) == ITE(~x, y, x) + assert ITE(Ne(S.false, x), y, x) == ITE(x, y, x) + # 0 and 1 in the context are not treated as True/False + # so the equality must always be False since dissimilar + # objects cannot be equal + assert ITE(Eq(x, 0), y, x) == x + assert ITE(Eq(x, 1), y, x) == x + assert ITE(Ne(x, 0), y, x) == y + assert ITE(Ne(x, 1), y, x) == y + assert ITE(Eq(x, 0), y, z).subs(x, 0) == y + assert ITE(Eq(x, 0), y, z).subs(x, 1) == z + raises(ValueError, lambda: ITE(x > 1, y, x, z)) + + +def test_is_literal(): + assert is_literal(True) is True + assert is_literal(False) is True + assert is_literal(A) is True + assert is_literal(~A) is True + assert is_literal(Or(A, B)) is False + assert is_literal(Q.zero(A)) is True + assert is_literal(Not(Q.zero(A))) is True + assert is_literal(Or(A, B)) is False + assert is_literal(And(Q.zero(A), Q.zero(B))) is False + assert is_literal(x < 3) + assert not is_literal(x + y < 3) + + +def test_operators(): + # Mostly test __and__, __rand__, and so on + assert True & A == A & True == A + assert False & A == A & False == False + assert A & B == And(A, B) + assert True | A == A | True == True + assert False | A == A | False == A + assert A | B == Or(A, B) + assert ~A == Not(A) + assert True >> A == A << True == A + assert False >> A == A << False == True + assert A >> True == True << A == True + assert A >> False == False << A == ~A + assert A >> B == B << A == Implies(A, B) + assert True ^ A == A ^ True == ~A + assert False ^ A == A ^ False == A + assert A ^ B == Xor(A, B) + + +def test_true_false(): + assert true is S.true + assert false is S.false + assert true is not True + assert false is not False + assert true + assert not false + assert true == True + assert false == False + assert not (true == False) + assert not (false == True) + assert not (true == false) + + assert hash(true) == hash(True) + assert hash(false) == hash(False) + assert len({true, True}) == len({false, False}) == 1 + + assert isinstance(true, BooleanAtom) + assert isinstance(false, BooleanAtom) + # We don't want to subclass from bool, because bool subclasses from + # int. But operators like &, |, ^, <<, >>, and ~ act differently on 0 and + # 1 then we want them to on true and false. See the docstrings of the + # various And, Or, etc. functions for examples. + assert not isinstance(true, bool) + assert not isinstance(false, bool) + + # Note: using 'is' comparison is important here. We want these to return + # true and false, not True and False + + assert Not(true) is false + assert Not(True) is false + assert Not(false) is true + assert Not(False) is true + assert ~true is false + assert ~false is true + + for T, F in product((True, true), (False, false)): + assert And(T, F) is false + assert And(F, T) is false + assert And(F, F) is false + assert And(T, T) is true + assert And(T, x) == x + assert And(F, x) is false + if not (T is True and F is False): + assert T & F is false + assert F & T is false + if F is not False: + assert F & F is false + if T is not True: + assert T & T is true + + assert Or(T, F) is true + assert Or(F, T) is true + assert Or(F, F) is false + assert Or(T, T) is true + assert Or(T, x) is true + assert Or(F, x) == x + if not (T is True and F is False): + assert T | F is true + assert F | T is true + if F is not False: + assert F | F is false + if T is not True: + assert T | T is true + + assert Xor(T, F) is true + assert Xor(F, T) is true + assert Xor(F, F) is false + assert Xor(T, T) is false + assert Xor(T, x) == ~x + assert Xor(F, x) == x + if not (T is True and F is False): + assert T ^ F is true + assert F ^ T is true + if F is not False: + assert F ^ F is false + if T is not True: + assert T ^ T is false + + assert Nand(T, F) is true + assert Nand(F, T) is true + assert Nand(F, F) is true + assert Nand(T, T) is false + assert Nand(T, x) == ~x + assert Nand(F, x) is true + + assert Nor(T, F) is false + assert Nor(F, T) is false + assert Nor(F, F) is true + assert Nor(T, T) is false + assert Nor(T, x) is false + assert Nor(F, x) == ~x + + assert Implies(T, F) is false + assert Implies(F, T) is true + assert Implies(F, F) is true + assert Implies(T, T) is true + assert Implies(T, x) == x + assert Implies(F, x) is true + assert Implies(x, T) is true + assert Implies(x, F) == ~x + if not (T is True and F is False): + assert T >> F is false + assert F << T is false + assert F >> T is true + assert T << F is true + if F is not False: + assert F >> F is true + assert F << F is true + if T is not True: + assert T >> T is true + assert T << T is true + + assert Equivalent(T, F) is false + assert Equivalent(F, T) is false + assert Equivalent(F, F) is true + assert Equivalent(T, T) is true + assert Equivalent(T, x) == x + assert Equivalent(F, x) == ~x + assert Equivalent(x, T) == x + assert Equivalent(x, F) == ~x + + assert ITE(T, T, T) is true + assert ITE(T, T, F) is true + assert ITE(T, F, T) is false + assert ITE(T, F, F) is false + assert ITE(F, T, T) is true + assert ITE(F, T, F) is false + assert ITE(F, F, T) is true + assert ITE(F, F, F) is false + + assert all(i.simplify(1, 2) is i for i in (S.true, S.false)) + + +def test_bool_as_set(): + assert ITE(y <= 0, False, y >= 1).as_set() == Interval(1, oo) + assert And(x <= 2, x >= -2).as_set() == Interval(-2, 2) + assert Or(x >= 2, x <= -2).as_set() == Interval(-oo, -2) + Interval(2, oo) + assert Not(x > 2).as_set() == Interval(-oo, 2) + # issue 10240 + assert Not(And(x > 2, x < 3)).as_set() == \ + Union(Interval(-oo, 2), Interval(3, oo)) + assert true.as_set() == S.UniversalSet + assert false.as_set() is S.EmptySet + assert x.as_set() == S.UniversalSet + assert And(Or(x < 1, x > 3), x < 2).as_set() == Interval.open(-oo, 1) + assert And(x < 1, sin(x) < 3).as_set() == (x < 1).as_set() + raises(NotImplementedError, lambda: (sin(x) < 1).as_set()) + # watch for object morph in as_set + assert Eq(-1, cos(2 * x) ** 2 / sin(2 * x) ** 2).as_set() is S.EmptySet + + +@XFAIL +def test_multivariate_bool_as_set(): + x, y = symbols('x,y') + + assert And(x >= 0, y >= 0).as_set() == Interval(0, oo) * Interval(0, oo) + assert Or(x >= 0, y >= 0).as_set() == S.Reals * S.Reals - \ + Interval(-oo, 0, True, True) * Interval(-oo, 0, True, True) + + +def test_all_or_nothing(): + x = symbols('x', extended_real=True) + args = x >= -oo, x <= oo + v = And(*args) + if v.func is And: + assert len(v.args) == len(args) - args.count(S.true) + else: + assert v == True + v = Or(*args) + if v.func is Or: + assert len(v.args) == 2 + else: + assert v == True + + +def test_canonical_atoms(): + assert true.canonical == true + assert false.canonical == false + + +def test_negated_atoms(): + assert true.negated == false + assert false.negated == true + + +def test_issue_8777(): + assert And(x > 2, x < oo).as_set() == Interval(2, oo, left_open=True) + assert And(x >= 1, x < oo).as_set() == Interval(1, oo) + assert (x < oo).as_set() == Interval(-oo, oo) + assert (x > -oo).as_set() == Interval(-oo, oo) + + +def test_issue_8975(): + assert Or(And(-oo < x, x <= -2), And(2 <= x, x < oo)).as_set() == \ + Interval(-oo, -2) + Interval(2, oo) + + +def test_term_to_integer(): + assert term_to_integer([1, 0, 1, 0, 0, 1, 0]) == 82 + assert term_to_integer('0010101000111001') == 10809 + + +def test_issue_21971(): + a, b, c, d = symbols('a b c d') + f = a & b & c | a & c + assert f.subs(a & c, d) == b & d | d + assert f.subs(a & b & c, d) == a & c | d + + f = (a | b | c) & (a | c) + assert f.subs(a | c, d) == (b | d) & d + assert f.subs(a | b | c, d) == (a | c) & d + + f = (a ^ b ^ c) & (a ^ c) + assert f.subs(a ^ c, d) == (b ^ d) & d + assert f.subs(a ^ b ^ c, d) == (a ^ c) & d + + +def test_truth_table(): + assert list(truth_table(And(x, y), [x, y], input=False)) == \ + [False, False, False, True] + assert list(truth_table(x | y, [x, y], input=False)) == \ + [False, True, True, True] + assert list(truth_table(x >> y, [x, y], input=False)) == \ + [True, True, False, True] + assert list(truth_table(And(x, y), [x, y])) == \ + [([0, 0], False), ([0, 1], False), ([1, 0], False), ([1, 1], True)] + + +def test_issue_8571(): + for t in (S.true, S.false): + raises(TypeError, lambda: +t) + raises(TypeError, lambda: -t) + raises(TypeError, lambda: abs(t)) + # use int(bool(t)) to get 0 or 1 + raises(TypeError, lambda: int(t)) + + for o in [S.Zero, S.One, x]: + for _ in range(2): + raises(TypeError, lambda: o + t) + raises(TypeError, lambda: o - t) + raises(TypeError, lambda: o % t) + raises(TypeError, lambda: o * t) + raises(TypeError, lambda: o / t) + raises(TypeError, lambda: o ** t) + o, t = t, o # do again in reversed order + + +def test_expand_relational(): + n = symbols('n', negative=True) + p, q = symbols('p q', positive=True) + r = ((n + q * (-n / q + 1)) / (q * (-n / q + 1)) < 0) + assert r is not S.false + assert r.expand() is S.false + assert (q > 0).expand() is S.true + + +def test_issue_12717(): + assert S.true.is_Atom == True + assert S.false.is_Atom == True + + +def test_as_Boolean(): + nz = symbols('nz', nonzero=True) + assert all(as_Boolean(i) is S.true for i in (True, S.true, 1, nz)) + z = symbols('z', zero=True) + assert all(as_Boolean(i) is S.false for i in (False, S.false, 0, z)) + assert all(as_Boolean(i) == i for i in (x, x < 0)) + for i in (2, S(2), x + 1, []): + raises(TypeError, lambda: as_Boolean(i)) + + +def test_binary_symbols(): + assert ITE(x < 1, y, z).binary_symbols == {y, z} + for f in (Eq, Ne): + assert f(x, 1).binary_symbols == set() + assert f(x, True).binary_symbols == {x} + assert f(x, False).binary_symbols == {x} + assert S.true.binary_symbols == set() + assert S.false.binary_symbols == set() + assert x.binary_symbols == {x} + assert And(x, Eq(y, False), Eq(z, 1)).binary_symbols == {x, y} + assert Q.prime(x).binary_symbols == set() + assert Q.lt(x, 1).binary_symbols == set() + assert Q.is_true(x).binary_symbols == {x} + assert Q.eq(x, True).binary_symbols == {x} + assert Q.prime(x).binary_symbols == set() + + +def test_BooleanFunction_diff(): + assert And(x, y).diff(x) == Piecewise((0, Eq(y, False)), (1, True)) + + +def test_issue_14700(): + A, B, C, D, E, F, G, H = symbols('A B C D E F G H') + q = ((B & D & H & ~F) | (B & H & ~C & ~D) | (B & H & ~C & ~F) | + (B & H & ~D & ~G) | (B & H & ~F & ~G) | (C & G & ~B & ~D) | + (C & G & ~D & ~H) | (C & G & ~F & ~H) | (D & F & H & ~B) | + (D & F & ~G & ~H) | (B & D & F & ~C & ~H) | (D & E & F & ~B & ~C) | + (D & F & ~A & ~B & ~C) | (D & F & ~A & ~C & ~H) | + (A & B & D & F & ~E & ~H)) + soldnf = ((B & D & H & ~F) | (D & F & H & ~B) | (B & H & ~C & ~D) | + (B & H & ~D & ~G) | (C & G & ~B & ~D) | (C & G & ~D & ~H) | + (C & G & ~F & ~H) | (D & F & ~G & ~H) | (D & E & F & ~C & ~H) | + (D & F & ~A & ~C & ~H) | (A & B & D & F & ~E & ~H)) + solcnf = ((B | C | D) & (B | D | G) & (C | D | H) & (C | F | H) & + (D | G | H) & (F | G | H) & (B | F | ~D | ~H) & + (~B | ~D | ~F | ~H) & (D | ~B | ~C | ~G | ~H) & + (A | H | ~C | ~D | ~F | ~G) & (H | ~C | ~D | ~E | ~F | ~G) & + (B | E | H | ~A | ~D | ~F | ~G)) + assert simplify_logic(q, "dnf") == soldnf + assert simplify_logic(q, "cnf") == solcnf + + minterms = [[0, 1, 0, 0], [0, 1, 0, 1], [0, 1, 1, 0], [0, 1, 1, 1], + [0, 0, 1, 1], [1, 0, 1, 1]] + dontcares = [[1, 0, 0, 0], [1, 0, 0, 1], [1, 1, 0, 0], [1, 1, 0, 1]] + assert SOPform([w, x, y, z], minterms) == (x & ~w) | (y & z & ~x) + # Should not be more complicated with don't cares + assert SOPform([w, x, y, z], minterms, dontcares) == \ + (x & ~w) | (y & z & ~x) + + +def test_issue_25115(): + cond = Contains(x, S.Integers) + # Previously this raised an exception: + assert simplify_logic(cond) == cond + + +def test_relational_simplification(): + w, x, y, z = symbols('w x y z', real=True) + d, e = symbols('d e', real=False) + # Test all combinations or sign and order + assert Or(x >= y, x < y).simplify() == S.true + assert Or(x >= y, y > x).simplify() == S.true + assert Or(x >= y, -x > -y).simplify() == S.true + assert Or(x >= y, -y < -x).simplify() == S.true + assert Or(-x <= -y, x < y).simplify() == S.true + assert Or(-x <= -y, -x > -y).simplify() == S.true + assert Or(-x <= -y, y > x).simplify() == S.true + assert Or(-x <= -y, -y < -x).simplify() == S.true + assert Or(y <= x, x < y).simplify() == S.true + assert Or(y <= x, y > x).simplify() == S.true + assert Or(y <= x, -x > -y).simplify() == S.true + assert Or(y <= x, -y < -x).simplify() == S.true + assert Or(-y >= -x, x < y).simplify() == S.true + assert Or(-y >= -x, y > x).simplify() == S.true + assert Or(-y >= -x, -x > -y).simplify() == S.true + assert Or(-y >= -x, -y < -x).simplify() == S.true + + assert Or(x < y, x >= y).simplify() == S.true + assert Or(y > x, x >= y).simplify() == S.true + assert Or(-x > -y, x >= y).simplify() == S.true + assert Or(-y < -x, x >= y).simplify() == S.true + assert Or(x < y, -x <= -y).simplify() == S.true + assert Or(-x > -y, -x <= -y).simplify() == S.true + assert Or(y > x, -x <= -y).simplify() == S.true + assert Or(-y < -x, -x <= -y).simplify() == S.true + assert Or(x < y, y <= x).simplify() == S.true + assert Or(y > x, y <= x).simplify() == S.true + assert Or(-x > -y, y <= x).simplify() == S.true + assert Or(-y < -x, y <= x).simplify() == S.true + assert Or(x < y, -y >= -x).simplify() == S.true + assert Or(y > x, -y >= -x).simplify() == S.true + assert Or(-x > -y, -y >= -x).simplify() == S.true + assert Or(-y < -x, -y >= -x).simplify() == S.true + + # Some other tests + assert Or(x >= y, w < z, x <= y).simplify() == S.true + assert And(x >= y, x < y).simplify() == S.false + assert Or(x >= y, Eq(y, x)).simplify() == (x >= y) + assert And(x >= y, Eq(y, x)).simplify() == Eq(x, y) + assert And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify() == \ + (Eq(x, y) & (x >= 1) & (y >= 5) & (y > z)) + assert Or(Eq(x, y), x >= y, w < y, z < y).simplify() == \ + (x >= y) | (y > z) | (w < y) + assert And(Eq(x, y), x >= y, w < y, y >= z, z < y).simplify() == \ + Eq(x, y) & (y > z) & (w < y) + # assert And(Eq(x, y), x >= y, w < y, y >= z, z < y).simplify(relational_minmax=True) == \ + # And(Eq(x, y), y > Max(w, z)) + # assert Or(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify(relational_minmax=True) == \ + # (Eq(x, y) | (x >= 1) | (y > Min(2, z))) + assert And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify() == \ + (Eq(x, y) & (x >= 1) & (y >= 5) & (y > z)) + assert (Eq(x, y) & Eq(d, e) & (x >= y) & (d >= e)).simplify() == \ + (Eq(x, y) & Eq(d, e) & (d >= e)) + assert And(Eq(x, y), Eq(x, -y)).simplify() == And(Eq(x, 0), Eq(y, 0)) + assert Xor(x >= y, x <= y).simplify() == Ne(x, y) + assert And(x > 1, x < -1, Eq(x, y)).simplify() == S.false + # From #16690 + assert And(x >= y, Eq(y, 0)).simplify() == And(x >= 0, Eq(y, 0)) + assert Or(Ne(x, 1), Ne(x, 2)).simplify() == S.true + assert And(Eq(x, 1), Ne(2, x)).simplify() == Eq(x, 1) + assert Or(Eq(x, 1), Ne(2, x)).simplify() == Ne(x, 2) + + +def test_issue_8373(): + x = symbols('x', real=True) + assert Or(x < 1, x > -1).simplify() == S.true + assert Or(x < 1, x >= 1).simplify() == S.true + assert And(x < 1, x >= 1).simplify() == S.false + assert Or(x <= 1, x >= 1).simplify() == S.true + + +def test_issue_7950(): + x = symbols('x', real=True) + assert And(Eq(x, 1), Eq(x, 2)).simplify() == S.false + + +@slow +def test_relational_simplification_numerically(): + def test_simplification_numerically_function(original, simplified): + symb = original.free_symbols + n = len(symb) + valuelist = list(set(combinations(list(range(-(n - 1), n)) * n, n))) + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.subs(sublist) + simplifiedvalue = simplified.subs(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for {}" \ + "".format(original, simplified, sublist) + + w, x, y, z = symbols('w x y z', real=True) + d, e = symbols('d e', real=False) + + expressions = (And(Eq(x, y), x >= y, w < y, y >= z, z < y), + And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y), + Or(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y), + And(x >= y, Eq(y, x)), + Or(And(Eq(x, y), x >= y, w < y, Or(y >= z, z < y)), + And(Eq(x, y), x >= 1, 2 < y, y >= -1, z < y)), + (Eq(x, y) & Eq(d, e) & (x >= y) & (d >= e)), + ) + + for expression in expressions: + test_simplification_numerically_function(expression, + expression.simplify()) + + +def test_relational_simplification_patterns_numerically(): + from sympy.core import Wild + from sympy.logic.boolalg import _simplify_patterns_and, \ + _simplify_patterns_or, _simplify_patterns_xor + a = Wild('a') + b = Wild('b') + c = Wild('c') + symb = [a, b, c] + patternlists = [[And, _simplify_patterns_and()], + [Or, _simplify_patterns_or()], + [Xor, _simplify_patterns_xor()]] + valuelist = list(set(combinations(list(range(-2, 3)) * 3, 3))) + # Skip combinations of +/-2 and 0, except for all 0 + valuelist = [v for v in valuelist if any(w % 2 for w in v) or not any(v)] + for func, patternlist in patternlists: + for pattern in patternlist: + original = func(*pattern[0].args) + simplified = pattern[1] + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.xreplace(sublist) + simplifiedvalue = simplified.xreplace(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for" \ + "{}".format(pattern[0], simplified, sublist) + + +def test_issue_16803(): + n = symbols('n') + # No simplification done, but should not raise an exception + assert ((n > 3) | (n < 0) | ((n > 0) & (n < 3))).simplify() == \ + (n > 3) | (n < 0) | ((n > 0) & (n < 3)) + + +def test_issue_17530(): + r = {x: oo, y: oo} + assert Or(x + y > 0, x - y < 0).subs(r) + assert not And(x + y < 0, x - y < 0).subs(r) + raises(TypeError, lambda: Or(x + y < 0, x - y < 0).subs(r)) + raises(TypeError, lambda: And(x + y > 0, x - y < 0).subs(r)) + raises(TypeError, lambda: And(x + y > 0, x - y < 0).subs(r)) + + +def test_anf_coeffs(): + assert anf_coeffs([1, 0]) == [1, 1] + assert anf_coeffs([0, 0, 0, 1]) == [0, 0, 0, 1] + assert anf_coeffs([0, 1, 1, 1]) == [0, 1, 1, 1] + assert anf_coeffs([1, 1, 1, 0]) == [1, 0, 0, 1] + assert anf_coeffs([1, 0, 0, 0]) == [1, 1, 1, 1] + assert anf_coeffs([1, 0, 0, 1]) == [1, 1, 1, 0] + assert anf_coeffs([1, 1, 0, 1]) == [1, 0, 1, 1] + + +def test_ANFform(): + x, y = symbols('x,y') + assert ANFform([x], [1, 1]) == True + assert ANFform([x], [0, 0]) == False + assert ANFform([x], [1, 0]) == Xor(x, True, remove_true=False) + assert ANFform([x, y], [1, 1, 1, 0]) == \ + Xor(True, And(x, y), remove_true=False) + + +def test_bool_minterm(): + x, y = symbols('x,y') + assert bool_minterm(3, [x, y]) == And(x, y) + assert bool_minterm([1, 0], [x, y]) == And(Not(y), x) + + +def test_bool_maxterm(): + x, y = symbols('x,y') + assert bool_maxterm(2, [x, y]) == Or(Not(x), y) + assert bool_maxterm([0, 1], [x, y]) == Or(Not(y), x) + + +def test_bool_monomial(): + x, y = symbols('x,y') + assert bool_monomial(1, [x, y]) == y + assert bool_monomial([1, 1], [x, y]) == And(x, y) + + +def test_check_pair(): + assert _check_pair([0, 1, 0], [0, 1, 1]) == 2 + assert _check_pair([0, 1, 0], [1, 1, 1]) == -1 + + +def test_issue_19114(): + expr = (B & C) | (A & ~C) | (~A & ~B) + # Expression is minimal, but there are multiple minimal forms possible + res1 = (A & B) | (C & ~A) | (~B & ~C) + result = to_dnf(expr, simplify=True) + assert result in (expr, res1) + + +def test_issue_20870(): + result = SOPform([a, b, c, d], [1, 2, 3, 4, 5, 6, 8, 9, 11, 12, 14, 15]) + expected = ((d & ~b) | (a & b & c) | (a & ~c & ~d) | + (b & ~a & ~c) | (c & ~a & ~d)) + assert result == expected + + +def test_convert_to_varsSOP(): + assert _convert_to_varsSOP([0, 1, 0], [x, y, z]) == And(Not(x), y, Not(z)) + assert _convert_to_varsSOP([3, 1, 0], [x, y, z]) == And(y, Not(z)) + + +def test_convert_to_varsPOS(): + assert _convert_to_varsPOS([0, 1, 0], [x, y, z]) == Or(x, Not(y), z) + assert _convert_to_varsPOS([3, 1, 0], [x, y, z]) == Or(Not(y), z) + + +def test_gateinputcount(): + a, b, c, d, e = symbols('a:e') + assert gateinputcount(And(a, b)) == 2 + assert gateinputcount(a | b & c & d ^ (e | a)) == 9 + assert gateinputcount(And(a, True)) == 0 + raises(TypeError, lambda: gateinputcount(a * b)) + + +def test_refine(): + # relational + assert not refine(x < 0, ~(x < 0)) + assert refine(x < 0, (x < 0)) + assert refine(x < 0, (0 > x)) is S.true + assert refine(x < 0, (y < 0)) == (x < 0) + assert not refine(x <= 0, ~(x <= 0)) + assert refine(x <= 0, (x <= 0)) + assert refine(x <= 0, (0 >= x)) is S.true + assert refine(x <= 0, (y <= 0)) == (x <= 0) + assert not refine(x > 0, ~(x > 0)) + assert refine(x > 0, (x > 0)) + assert refine(x > 0, (0 < x)) is S.true + assert refine(x > 0, (y > 0)) == (x > 0) + assert not refine(x >= 0, ~(x >= 0)) + assert refine(x >= 0, (x >= 0)) + assert refine(x >= 0, (0 <= x)) is S.true + assert refine(x >= 0, (y >= 0)) == (x >= 0) + assert not refine(Eq(x, 0), ~(Eq(x, 0))) + assert refine(Eq(x, 0), (Eq(x, 0))) + assert refine(Eq(x, 0), (Eq(0, x))) is S.true + assert refine(Eq(x, 0), (Eq(y, 0))) == Eq(x, 0) + assert not refine(Ne(x, 0), ~(Ne(x, 0))) + assert refine(Ne(x, 0), (Ne(0, x))) is S.true + assert refine(Ne(x, 0), (Ne(x, 0))) + assert refine(Ne(x, 0), (Ne(y, 0))) == (Ne(x, 0)) + + # boolean functions + assert refine(And(x > 0, y > 0), (x > 0)) == (y > 0) + assert refine(And(x > 0, y > 0), (x > 0) & (y > 0)) is S.true + + # predicates + assert refine(Q.positive(x), Q.positive(x)) is S.true + assert refine(Q.positive(x), Q.negative(x)) is S.false + assert refine(Q.positive(x), Q.real(x)) == Q.positive(x) + + +def test_relational_threeterm_simplification_patterns_numerically(): + from sympy.core import Wild + from sympy.logic.boolalg import _simplify_patterns_and3 + a = Wild('a') + b = Wild('b') + c = Wild('c') + symb = [a, b, c] + patternlists = [[And, _simplify_patterns_and3()]] + valuelist = list(set(combinations(list(range(-2, 3)) * 3, 3))) + # Skip combinations of +/-2 and 0, except for all 0 + valuelist = [v for v in valuelist if any(w % 2 for w in v) or not any(v)] + for func, patternlist in patternlists: + for pattern in patternlist: + original = func(*pattern[0].args) + simplified = pattern[1] + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.xreplace(sublist) + simplifiedvalue = simplified.xreplace(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for" \ + "{}".format(pattern[0], simplified, sublist) + + +def test_issue_25451(): + x = Or(And(a, c), Eq(a, b)) + assert isinstance(x, Or) + assert set(x.args) == {And(a, c), Eq(a, b)} + + +def test_issue_26985(): + a, b, c, d = symbols('a b c d') + + # Expression before applying to_anf + x = Xor(c, And(a, b), And(a, c)) + y = Xor(a, b, And(a, c)) + + # Applying to_anf + result = Xor(Xor(d, And(x, y)), And(x, y)) + result_anf = to_anf(Xor(to_anf(Xor(d, And(x, y))), And(x, y))) + + assert result_anf == d + assert result == d diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_dimacs.py b/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_dimacs.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9a51a39d33fb807688614cb5809b621ce21a2c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_dimacs.py @@ -0,0 +1,234 @@ +"""Various tests on satisfiability using dimacs cnf file syntax +You can find lots of cnf files in +ftp://dimacs.rutgers.edu/pub/challenge/satisfiability/benchmarks/cnf/ +""" + +from sympy.logic.utilities.dimacs import load +from sympy.logic.algorithms.dpll import dpll_satisfiable + + +def test_f1(): + assert bool(dpll_satisfiable(load(f1))) + + +def test_f2(): + assert bool(dpll_satisfiable(load(f2))) + + +def test_f3(): + assert bool(dpll_satisfiable(load(f3))) + + +def test_f4(): + assert not bool(dpll_satisfiable(load(f4))) + + +def test_f5(): + assert bool(dpll_satisfiable(load(f5))) + +f1 = """c simple example +c Resolution: SATISFIABLE +c +p cnf 3 2 +1 -3 0 +2 3 -1 0 +""" + + +f2 = """c an example from Quinn's text, 16 variables and 18 clauses. +c Resolution: SATISFIABLE +c +p cnf 16 18 + 1 2 0 + -2 -4 0 + 3 4 0 + -4 -5 0 + 5 -6 0 + 6 -7 0 + 6 7 0 + 7 -16 0 + 8 -9 0 + -8 -14 0 + 9 10 0 + 9 -10 0 +-10 -11 0 + 10 12 0 + 11 12 0 + 13 14 0 + 14 -15 0 + 15 16 0 +""" + +f3 = """c +p cnf 6 9 +-1 0 +-3 0 +2 -1 0 +2 -4 0 +5 -4 0 +-1 -3 0 +-4 -6 0 +1 3 -2 0 +4 6 -2 -5 0 +""" + +f4 = """c +c file: hole6.cnf [http://people.sc.fsu.edu/~jburkardt/data/cnf/hole6.cnf] +c +c SOURCE: John Hooker (jh38+@andrew.cmu.edu) +c +c DESCRIPTION: Pigeon hole problem of placing n (for file 'holen.cnf') pigeons +c in n+1 holes without placing 2 pigeons in the same hole +c +c NOTE: Part of the collection at the Forschungsinstitut fuer +c anwendungsorientierte Wissensverarbeitung in Ulm Germany. +c +c NOTE: Not satisfiable +c +p cnf 42 133 +-1 -7 0 +-1 -13 0 +-1 -19 0 +-1 -25 0 +-1 -31 0 +-1 -37 0 +-7 -13 0 +-7 -19 0 +-7 -25 0 +-7 -31 0 +-7 -37 0 +-13 -19 0 +-13 -25 0 +-13 -31 0 +-13 -37 0 +-19 -25 0 +-19 -31 0 +-19 -37 0 +-25 -31 0 +-25 -37 0 +-31 -37 0 +-2 -8 0 +-2 -14 0 +-2 -20 0 +-2 -26 0 +-2 -32 0 +-2 -38 0 +-8 -14 0 +-8 -20 0 +-8 -26 0 +-8 -32 0 +-8 -38 0 +-14 -20 0 +-14 -26 0 +-14 -32 0 +-14 -38 0 +-20 -26 0 +-20 -32 0 +-20 -38 0 +-26 -32 0 +-26 -38 0 +-32 -38 0 +-3 -9 0 +-3 -15 0 +-3 -21 0 +-3 -27 0 +-3 -33 0 +-3 -39 0 +-9 -15 0 +-9 -21 0 +-9 -27 0 +-9 -33 0 +-9 -39 0 +-15 -21 0 +-15 -27 0 +-15 -33 0 +-15 -39 0 +-21 -27 0 +-21 -33 0 +-21 -39 0 +-27 -33 0 +-27 -39 0 +-33 -39 0 +-4 -10 0 +-4 -16 0 +-4 -22 0 +-4 -28 0 +-4 -34 0 +-4 -40 0 +-10 -16 0 +-10 -22 0 +-10 -28 0 +-10 -34 0 +-10 -40 0 +-16 -22 0 +-16 -28 0 +-16 -34 0 +-16 -40 0 +-22 -28 0 +-22 -34 0 +-22 -40 0 +-28 -34 0 +-28 -40 0 +-34 -40 0 +-5 -11 0 +-5 -17 0 +-5 -23 0 +-5 -29 0 +-5 -35 0 +-5 -41 0 +-11 -17 0 +-11 -23 0 +-11 -29 0 +-11 -35 0 +-11 -41 0 +-17 -23 0 +-17 -29 0 +-17 -35 0 +-17 -41 0 +-23 -29 0 +-23 -35 0 +-23 -41 0 +-29 -35 0 +-29 -41 0 +-35 -41 0 +-6 -12 0 +-6 -18 0 +-6 -24 0 +-6 -30 0 +-6 -36 0 +-6 -42 0 +-12 -18 0 +-12 -24 0 +-12 -30 0 +-12 -36 0 +-12 -42 0 +-18 -24 0 +-18 -30 0 +-18 -36 0 +-18 -42 0 +-24 -30 0 +-24 -36 0 +-24 -42 0 +-30 -36 0 +-30 -42 0 +-36 -42 0 + 6 5 4 3 2 1 0 + 12 11 10 9 8 7 0 + 18 17 16 15 14 13 0 + 24 23 22 21 20 19 0 + 30 29 28 27 26 25 0 + 36 35 34 33 32 31 0 + 42 41 40 39 38 37 0 +""" + +f5 = """c simple example requiring variable selection +c +c NOTE: Satisfiable +c +p cnf 5 5 +1 2 3 0 +1 -2 3 0 +4 5 -3 0 +1 -4 -3 0 +-1 -5 0 +""" diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_inference.py b/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ff37b1b104f6f106ec5df7809fd34959bce35917 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_inference.py @@ -0,0 +1,396 @@ +"""For more tests on satisfiability, see test_dimacs""" + +from sympy.assumptions.ask import Q +from sympy.core.symbol import symbols +from sympy.core.relational import Unequality +from sympy.logic.boolalg import And, Or, Implies, Equivalent, true, false +from sympy.logic.inference import literal_symbol, \ + pl_true, satisfiable, valid, entails, PropKB +from sympy.logic.algorithms.dpll import dpll, dpll_satisfiable, \ + find_pure_symbol, find_unit_clause, unit_propagate, \ + find_pure_symbol_int_repr, find_unit_clause_int_repr, \ + unit_propagate_int_repr +from sympy.logic.algorithms.dpll2 import dpll_satisfiable as dpll2_satisfiable + +from sympy.logic.algorithms.z3_wrapper import z3_satisfiable +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.logic.tests.test_lra_theory import make_random_problem +from sympy.core.random import randint + +from sympy.testing.pytest import raises, skip +from sympy.external import import_module + + +def test_literal(): + A, B = symbols('A,B') + assert literal_symbol(True) is True + assert literal_symbol(False) is False + assert literal_symbol(A) is A + assert literal_symbol(~A) is A + + +def test_find_pure_symbol(): + A, B, C = symbols('A,B,C') + assert find_pure_symbol([A], [A]) == (A, True) + assert find_pure_symbol([A, B], [~A | B, ~B | A]) == (None, None) + assert find_pure_symbol([A, B, C], [ A | ~B, ~B | ~C, C | A]) == (A, True) + assert find_pure_symbol([A, B, C], [~A | B, B | ~C, C | A]) == (B, True) + assert find_pure_symbol([A, B, C], [~A | ~B, ~B | ~C, C | A]) == (B, False) + assert find_pure_symbol( + [A, B, C], [~A | B, ~B | ~C, C | A]) == (None, None) + + +def test_find_pure_symbol_int_repr(): + assert find_pure_symbol_int_repr([1], [{1}]) == (1, True) + assert find_pure_symbol_int_repr([1, 2], + [{-1, 2}, {-2, 1}]) == (None, None) + assert find_pure_symbol_int_repr([1, 2, 3], + [{1, -2}, {-2, -3}, {3, 1}]) == (1, True) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, 2}, {2, -3}, {3, 1}]) == (2, True) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, -2}, {-2, -3}, {3, 1}]) == (2, False) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, 2}, {-2, -3}, {3, 1}]) == (None, None) + + +def test_unit_clause(): + A, B, C = symbols('A,B,C') + assert find_unit_clause([A], {}) == (A, True) + assert find_unit_clause([A, ~A], {}) == (A, True) # Wrong ?? + assert find_unit_clause([A | B], {A: True}) == (B, True) + assert find_unit_clause([A | B], {B: True}) == (A, True) + assert find_unit_clause( + [A | B | C, B | ~C, A | ~B], {A: True}) == (B, False) + assert find_unit_clause([A | B | C, B | ~C, A | B], {A: True}) == (B, True) + assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True) + + +def test_unit_clause_int_repr(): + assert find_unit_clause_int_repr(map(set, [[1]]), {}) == (1, True) + assert find_unit_clause_int_repr(map(set, [[1], [-1]]), {}) == (1, True) + assert find_unit_clause_int_repr([{1, 2}], {1: True}) == (2, True) + assert find_unit_clause_int_repr([{1, 2}], {2: True}) == (1, True) + assert find_unit_clause_int_repr(map(set, + [[1, 2, 3], [2, -3], [1, -2]]), {1: True}) == (2, False) + assert find_unit_clause_int_repr(map(set, + [[1, 2, 3], [3, -3], [1, 2]]), {1: True}) == (2, True) + + A, B, C = symbols('A,B,C') + assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True) + + +def test_unit_propagate(): + A, B, C = symbols('A,B,C') + assert unit_propagate([A | B], A) == [] + assert unit_propagate([A | B, ~A | C, ~C | B, A], A) == [C, ~C | B, A] + + +def test_unit_propagate_int_repr(): + assert unit_propagate_int_repr([{1, 2}], 1) == [] + assert unit_propagate_int_repr(map(set, + [[1, 2], [-1, 3], [-3, 2], [1]]), 1) == [{3}, {-3, 2}] + + +def test_dpll(): + """This is also tested in test_dimacs""" + A, B, C = symbols('A,B,C') + assert dpll([A | B], [A, B], {A: True, B: True}) == {A: True, B: True} + + +def test_dpll_satisfiable(): + A, B, C = symbols('A,B,C') + assert dpll_satisfiable( A & ~A ) is False + assert dpll_satisfiable( A & ~B ) == {A: True, B: False} + assert dpll_satisfiable( + A | B ) in ({A: True}, {B: True}, {A: True, B: True}) + assert dpll_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert dpll_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False}, + {A: True, C: True}, {B: True, C: True}) + assert dpll_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert dpll_satisfiable( (A | B) & (A >> B) ) == {B: True} + assert dpll_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert dpll_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + + +def test_dpll2_satisfiable(): + A, B, C = symbols('A,B,C') + assert dpll2_satisfiable( A & ~A ) is False + assert dpll2_satisfiable( A & ~B ) == {A: True, B: False} + assert dpll2_satisfiable( + A | B ) in ({A: True}, {B: True}, {A: True, B: True}) + assert dpll2_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert dpll2_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}) + assert dpll2_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert dpll2_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert dpll2_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert dpll2_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + + +def test_minisat22_satisfiable(): + A, B, C = symbols('A,B,C') + minisat22_satisfiable = lambda expr: satisfiable(expr, algorithm="minisat22") + assert minisat22_satisfiable( A & ~A ) is False + assert minisat22_satisfiable( A & ~B ) == {A: True, B: False} + assert minisat22_satisfiable( + A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False}) + assert minisat22_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False}) + assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + +def test_minisat22_minimal_satisfiable(): + A, B, C = symbols('A,B,C') + minisat22_satisfiable = lambda expr, minimal=True: satisfiable(expr, algorithm="minisat22", minimal=True) + assert minisat22_satisfiable( A & ~A ) is False + assert minisat22_satisfiable( A & ~B ) == {A: True, B: False} + assert minisat22_satisfiable( + A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False}) + assert minisat22_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False}) + assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + g = satisfiable((A | B | C),algorithm="minisat22",minimal=True,all_models=True) + sol = next(g) + first_solution = {key for key, value in sol.items() if value} + sol=next(g) + second_solution = {key for key, value in sol.items() if value} + sol=next(g) + third_solution = {key for key, value in sol.items() if value} + assert not first_solution <= second_solution + assert not second_solution <= third_solution + assert not first_solution <= third_solution + +def test_satisfiable(): + A, B, C = symbols('A,B,C') + assert satisfiable(A & (A >> B) & ~B) is False + + +def test_valid(): + A, B, C = symbols('A,B,C') + assert valid(A >> (B >> A)) is True + assert valid((A >> (B >> C)) >> ((A >> B) >> (A >> C))) is True + assert valid((~B >> ~A) >> (A >> B)) is True + assert valid(A | B | C) is False + assert valid(A >> B) is False + + +def test_pl_true(): + A, B, C = symbols('A,B,C') + assert pl_true(True) is True + assert pl_true( A & B, {A: True, B: True}) is True + assert pl_true( A | B, {A: True}) is True + assert pl_true( A | B, {B: True}) is True + assert pl_true( A | B, {A: None, B: True}) is True + assert pl_true( A >> B, {A: False}) is True + assert pl_true( A | B | ~C, {A: False, B: True, C: True}) is True + assert pl_true(Equivalent(A, B), {A: False, B: False}) is True + + # test for false + assert pl_true(False) is False + assert pl_true( A & B, {A: False, B: False}) is False + assert pl_true( A & B, {A: False}) is False + assert pl_true( A & B, {B: False}) is False + assert pl_true( A | B, {A: False, B: False}) is False + + #test for None + assert pl_true(B, {B: None}) is None + assert pl_true( A & B, {A: True, B: None}) is None + assert pl_true( A >> B, {A: True, B: None}) is None + assert pl_true(Equivalent(A, B), {A: None}) is None + assert pl_true(Equivalent(A, B), {A: True, B: None}) is None + + # Test for deep + assert pl_true(A | B, {A: False}, deep=True) is None + assert pl_true(~A & ~B, {A: False}, deep=True) is None + assert pl_true(A | B, {A: False, B: False}, deep=True) is False + assert pl_true(A & B & (~A | ~B), {A: True}, deep=True) is False + assert pl_true((C >> A) >> (B >> A), {C: True}, deep=True) is True + + +def test_pl_true_wrong_input(): + from sympy.core.numbers import pi + raises(ValueError, lambda: pl_true('John Cleese')) + raises(ValueError, lambda: pl_true(42 + pi + pi ** 2)) + raises(ValueError, lambda: pl_true(42)) + + +def test_entails(): + A, B, C = symbols('A, B, C') + assert entails(A, [A >> B, ~B]) is False + assert entails(B, [Equivalent(A, B), A]) is True + assert entails((A >> B) >> (~A >> ~B)) is False + assert entails((A >> B) >> (~B >> ~A)) is True + + +def test_PropKB(): + A, B, C = symbols('A,B,C') + kb = PropKB() + assert kb.ask(A >> B) is False + assert kb.ask(A >> (B >> A)) is True + kb.tell(A >> B) + kb.tell(B >> C) + assert kb.ask(A) is False + assert kb.ask(B) is False + assert kb.ask(C) is False + assert kb.ask(~A) is False + assert kb.ask(~B) is False + assert kb.ask(~C) is False + assert kb.ask(A >> C) is True + kb.tell(A) + assert kb.ask(A) is True + assert kb.ask(B) is True + assert kb.ask(C) is True + assert kb.ask(~C) is False + kb.retract(A) + assert kb.ask(C) is False + + +def test_propKB_tolerant(): + """"tolerant to bad input""" + kb = PropKB() + A, B, C = symbols('A,B,C') + assert kb.ask(B) is False + +def test_satisfiable_non_symbols(): + x, y = symbols('x y') + assumptions = Q.zero(x*y) + facts = Implies(Q.zero(x*y), Q.zero(x) | Q.zero(y)) + query = ~Q.zero(x) & ~Q.zero(y) + refutations = [ + {Q.zero(x): True, Q.zero(x*y): True}, + {Q.zero(y): True, Q.zero(x*y): True}, + {Q.zero(x): True, Q.zero(y): True, Q.zero(x*y): True}, + {Q.zero(x): True, Q.zero(y): False, Q.zero(x*y): True}, + {Q.zero(x): False, Q.zero(y): True, Q.zero(x*y): True}] + assert not satisfiable(And(assumptions, facts, query), algorithm='dpll') + assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll') in refutations + assert not satisfiable(And(assumptions, facts, query), algorithm='dpll2') + assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll2') in refutations + +def test_satisfiable_bool(): + from sympy.core.singleton import S + assert satisfiable(true) == {true: true} + assert satisfiable(S.true) == {true: true} + assert satisfiable(false) is False + assert satisfiable(S.false) is False + + +def test_satisfiable_all_models(): + from sympy.abc import A, B + assert next(satisfiable(False, all_models=True)) is False + assert list(satisfiable((A >> ~A) & A, all_models=True)) == [False] + assert list(satisfiable(True, all_models=True)) == [{true: true}] + + models = [{A: True, B: False}, {A: False, B: True}] + result = satisfiable(A ^ B, all_models=True) + models.remove(next(result)) + models.remove(next(result)) + raises(StopIteration, lambda: next(result)) + assert not models + + assert list(satisfiable(Equivalent(A, B), all_models=True)) == \ + [{A: False, B: False}, {A: True, B: True}] + + models = [{A: False, B: False}, {A: False, B: True}, {A: True, B: True}] + for model in satisfiable(A >> B, all_models=True): + models.remove(model) + assert not models + + # This is a santiy test to check that only the required number + # of solutions are generated. The expr below has 2**100 - 1 models + # which would time out the test if all are generated at once. + from sympy.utilities.iterables import numbered_symbols + from sympy.logic.boolalg import Or + sym = numbered_symbols() + X = [next(sym) for i in range(100)] + result = satisfiable(Or(*X), all_models=True) + for i in range(10): + assert next(result) + + +def test_z3(): + z3 = import_module("z3") + + if not z3: + skip("z3 not installed.") + A, B, C = symbols('A,B,C') + x, y, z = symbols('x,y,z') + assert z3_satisfiable((x >= 2) & (x < 1)) is False + assert z3_satisfiable( A & ~A ) is False + + model = z3_satisfiable(A & (~A | B | C)) + assert bool(model) is True + assert model[A] is True + + # test nonlinear function + assert z3_satisfiable((x ** 2 >= 2) & (x < 1) & (x > -1)) is False + + +def test_z3_vs_lra_dpll2(): + z3 = import_module("z3") + if z3 is None: + skip("z3 not installed.") + + def boolean_formula_to_encoded_cnf(bf): + cnf = CNF.from_prop(bf) + enc = EncodedCNF() + enc.from_cnf(cnf) + return enc + + def make_random_cnf(num_clauses=5, num_constraints=10, num_var=2): + assert num_clauses <= num_constraints + constraints = make_random_problem(num_variables=num_var, num_constraints=num_constraints, rational=False) + clauses = [[cons] for cons in constraints[:num_clauses]] + for cons in constraints[num_clauses:]: + if isinstance(cons, Unequality): + cons = ~cons + i = randint(0, num_clauses-1) + clauses[i].append(cons) + + clauses = [Or(*clause) for clause in clauses] + cnf = And(*clauses) + return boolean_formula_to_encoded_cnf(cnf) + + lra_dpll2_satisfiable = lambda x: dpll2_satisfiable(x, use_lra_theory=True) + + for _ in range(50): + cnf = make_random_cnf(num_clauses=10, num_constraints=15, num_var=2) + + try: + z3_sat = z3_satisfiable(cnf) + except z3.z3types.Z3Exception: + continue + + lra_dpll2_sat = lra_dpll2_satisfiable(cnf) is not False + + assert z3_sat == lra_dpll2_sat + +def test_issue_27733(): + x, y = symbols('x,y') + clauses = [[1, -3, -2], [5, 7, -8, -6, -4], [-10, -9, 10, 11, -4], [-12, 13, 14], [-10, 9, -6, 11, -4], + [16, -15, 18, -19, -17], [11, -6, 10, -9], [9, 11, -10, -9], [2, -3, -1], [-13, 12], [-15, 3, -17], + [-16, -15, 19, -17], [-6, -9, 10, 11, -4], [20, -1, -2], [-23, -22, -21], [10, 11, -10, -9], + [9, 11, -4, -10], [24, -6, -4], [-14, 12], [-10, -9, 9, -6, 11], [25, -27, -26], [-15, 19, -18, -17], + [5, 8, -7, -6, -4], [-30, -29, 28], [12], [14]] + + encoding = {Q.gt(y, i): i for i in range(1, 31) if i != 11 and i != 12} + encoding[Q.gt(x, 0)] = 11 + encoding[Q.lt(x, 0)] = 12 + + cnf = EncodedCNF(clauses, encoding) + assert satisfiable(cnf, use_lra_theory=True) is False diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_lra_theory.py b/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_lra_theory.py new file mode 100644 index 0000000000000000000000000000000000000000..207a3c5ba2c1b16ee5323382deee0863a5dfb595 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/tests/test_lra_theory.py @@ -0,0 +1,440 @@ +from sympy.core.numbers import Rational, I, oo +from sympy.core.relational import Eq +from sympy.core.symbol import symbols +from sympy.core.singleton import S +from sympy.matrices.dense import Matrix +from sympy.matrices.dense import randMatrix +from sympy.assumptions.ask import Q +from sympy.logic.boolalg import And +from sympy.abc import x, y, z +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.functions.elementary.trigonometric import cos +from sympy.external import import_module + +from sympy.logic.algorithms.lra_theory import LRASolver, UnhandledInput, LRARational, HANDLE_NEGATION +from sympy.core.random import random, choice, randint +from sympy.core.sympify import sympify +from sympy.ntheory.generate import randprime +from sympy.core.relational import StrictLessThan, StrictGreaterThan +import itertools + +from sympy.testing.pytest import raises, XFAIL, skip + +def make_random_problem(num_variables=2, num_constraints=2, sparsity=.1, rational=True, + disable_strict = False, disable_nonstrict=False, disable_equality=False): + def rand(sparsity=sparsity): + if random() < sparsity: + return sympify(0) + if rational: + int1, int2 = [randprime(0, 50) for _ in range(2)] + return Rational(int1, int2) * choice([-1, 1]) + else: + return randint(1, 10) * choice([-1, 1]) + + variables = symbols('x1:%s' % (num_variables + 1)) + constraints = [] + for _ in range(num_constraints): + lhs, rhs = sum(rand() * x for x in variables), rand(sparsity=0) # sparsity=0 bc of bug with smtlib_code + options = [] + if not disable_equality: + options += [Eq(lhs, rhs)] + if not disable_nonstrict: + options += [lhs <= rhs, lhs >= rhs] + if not disable_strict: + options += [lhs < rhs, lhs > rhs] + + constraints.append(choice(options)) + + return constraints + +def check_if_satisfiable_with_z3(constraints): + from sympy.external.importtools import import_module + from sympy.printing.smtlib import smtlib_code + from sympy.logic.boolalg import And + boolean_formula = And(*constraints) + z3 = import_module("z3") + if z3: + smtlib_string = smtlib_code(boolean_formula) + s = z3.Solver() + s.from_string(smtlib_string) + res = str(s.check()) + if res == 'sat': + return True + elif res == 'unsat': + return False + else: + raise ValueError(f"z3 was not able to check the satisfiability of {boolean_formula}") + +def find_rational_assignment(constr, assignment, iter=20): + eps = sympify(1) + + for _ in range(iter): + assign = {key: val[0] + val[1]*eps for key, val in assignment.items()} + try: + for cons in constr: + assert cons.subs(assign) == True + return assign + except AssertionError: + eps = eps/2 + + return None + +def boolean_formula_to_encoded_cnf(bf): + cnf = CNF.from_prop(bf) + enc = EncodedCNF() + enc.from_cnf(cnf) + return enc + + +def test_from_encoded_cnf(): + s1, s2 = symbols("s1 s2") + + # Test preprocessing + # Example is from section 3 of paper. + phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6)) & (Eq(x + y, 2) | (x + 2 * y - z > 4)) + enc = boolean_formula_to_encoded_cnf(phi) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert lra.A.shape == (2, 5) + assert str(lra.slack) == '[_s1, _s2]' + assert str(lra.nonslack) == '[x, y, z]' + assert lra.A == Matrix([[ 1, 1, 0, -1, 0], + [-1, -2, 1, 0, -1]]) + assert {(str(b.var), b.bound, b.upper, b.equality, b.strict) for b in lra.enc_to_boundary.values()} == {('_s1', 2, None, True, False), + ('_s1', 2, True, False, False), + ('_s2', -4, True, False, True), + ('_s2', -6, True, False, False), + ('x', 0, False, False, False)} + + +def test_problem(): + from sympy.logic.algorithms.lra_theory import LRASolver + from sympy.assumptions.cnf import CNF, EncodedCNF + cons = [-2 * x - 2 * y >= 7, -9 * y >= 7, -6 * y >= 5] + cnf = CNF().from_prop(And(*cons)) + enc = EncodedCNF() + enc.from_cnf(cnf) + lra, _ = LRASolver.from_encoded_cnf(enc) + lra.assert_lit(1) + lra.assert_lit(2) + lra.assert_lit(3) + is_sat, assignment = lra.check() + assert is_sat is True + + +def test_random_problems(): + z3 = import_module("z3") + if z3 is None: + skip("z3 is not installed") + + special_cases = []; x1, x2, x3 = symbols("x1 x2 x3") + special_cases.append([x1 - 3 * x2 <= -5, 6 * x1 + 4 * x2 <= 0, -7 * x1 + 3 * x2 <= 3]) + special_cases.append([-3 * x1 >= 3, Eq(4 * x1, -1)]) + special_cases.append([-4 * x1 < 4, 6 * x1 <= -6]) + special_cases.append([-3 * x2 >= 7, 6 * x1 <= -5, -3 * x2 <= -4]) + special_cases.append([x + y >= 2, x + y <= 1]) + special_cases.append([x >= 0, x + y <= 2, x + 2 * y - z >= 6]) # from paper example + special_cases.append([-2 * x1 - 2 * x2 >= 7, -9 * x1 >= 7, -6 * x1 >= 5]) + special_cases.append([2 * x1 > -3, -9 * x1 < -6, 9 * x1 <= 6]) + special_cases.append([-2*x1 < -4, 9*x1 > -9]) + special_cases.append([-6*x1 >= -1, -8*x1 + x2 >= 5, -8*x1 + 7*x2 < 4, x1 > 7]) + special_cases.append([Eq(x1, 2), Eq(5*x1, -2), Eq(-7*x2, -6), Eq(9*x1 + 10*x2, 9)]) + special_cases.append([Eq(3*x1, 6), Eq(x1 - 8*x2, -9), Eq(-7*x1 + 5*x2, 3), Eq(3*x2, 7)]) + special_cases.append([-4*x1 < 4, 6*x1 <= -6]) + special_cases.append([-3*x1 + 8*x2 >= -8, -10*x2 > 9, 8*x1 - 4*x2 < 8, 10*x1 - 9*x2 >= -9]) + special_cases.append([x1 + 5*x2 >= -6, 9*x1 - 3*x2 >= -9, 6*x1 + 6*x2 < -10, -3*x1 + 3*x2 < -7]) + special_cases.append([-9*x1 < 7, -5*x1 - 7*x2 < -1, 3*x1 + 7*x2 > 1, -6*x1 - 6*x2 > 9]) + special_cases.append([9*x1 - 6*x2 >= -7, 9*x1 + 4*x2 < -8, -7*x2 <= 1, 10*x2 <= -7]) + + feasible_count = 0 + for i in range(50): + if i % 8 == 0: + constraints = make_random_problem(num_variables=1, num_constraints=2, rational=False) + elif i % 8 == 1: + constraints = make_random_problem(num_variables=2, num_constraints=4, rational=False, disable_equality=True, + disable_nonstrict=True) + elif i % 8 == 2: + constraints = make_random_problem(num_variables=2, num_constraints=4, rational=False, disable_strict=True) + elif i % 8 == 3: + constraints = make_random_problem(num_variables=3, num_constraints=12, rational=False) + else: + constraints = make_random_problem(num_variables=3, num_constraints=6, rational=False) + + if i < len(special_cases): + constraints = special_cases[i] + + if False in constraints or True in constraints: + continue + + phi = And(*constraints) + if phi == False: + continue + cnf = CNF.from_prop(phi); enc = EncodedCNF() + enc.from_cnf(cnf) + assert all(0 not in clause for clause in enc.data) + + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + s_subs = lra.s_subs + + lra.run_checks = True + s_subs_rev = {value: key for key, value in s_subs.items()} + lits = {lit for clause in enc.data for lit in clause} + + bounds = [(lra.enc_to_boundary[l], l) for l in lits if l in lra.enc_to_boundary] + bounds = sorted(bounds, key=lambda x: (str(x[0].var), x[0].bound, str(x[0].upper))) # to remove nondeterminism + + for b, l in bounds: + if lra.result and lra.result[0] == False: + break + lra.assert_lit(l) + + feasible = lra.check() + + if feasible[0] == True: + feasible_count += 1 + assert check_if_satisfiable_with_z3(constraints) is True + cons_funcs = [cons.func for cons in constraints] + assignment = feasible[1] + assignment = {key.var : value for key, value in assignment.items()} + if not (StrictLessThan in cons_funcs or StrictGreaterThan in cons_funcs): + assignment = {key: value[0] for key, value in assignment.items()} + for cons in constraints: + assert cons.subs(assignment) == True + + else: + rat_assignment = find_rational_assignment(constraints, assignment) + assert rat_assignment is not None + else: + assert check_if_satisfiable_with_z3(constraints) is False + + conflict = feasible[1] + assert len(conflict) >= 2 + conflict = {lra.enc_to_boundary[-l].get_inequality() for l in conflict} + conflict = {clause.subs(s_subs_rev) for clause in conflict} + assert check_if_satisfiable_with_z3(conflict) is False + + # check that conflict clause is probably minimal + for subset in itertools.combinations(conflict, len(conflict)-1): + assert check_if_satisfiable_with_z3(subset) is True + + +@XFAIL +def test_pos_neg_zero(): + bf = Q.positive(x) & Q.negative(x) & Q.zero(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.lt(x, -1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.zero(x) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.zero(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == True + + +@XFAIL +def test_pos_neg_infinite(): + bf = Q.positive_infinite(x) & Q.lt(x, 10000000) & Q.positive_infinite(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + + bf = Q.positive_infinite(x) & Q.gt(x, 10000000) & Q.positive_infinite(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == True + + bf = Q.positive_infinite(x) & Q.negative_infinite(x) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + +def test_binrel_evaluation(): + bf = Q.gt(3, 2) + enc = boolean_formula_to_encoded_cnf(bf) + lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert len(lra.enc_to_boundary) == 0 + assert conflicts == [[1]] + + bf = Q.lt(3, 2) + enc = boolean_formula_to_encoded_cnf(bf) + lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert len(lra.enc_to_boundary) == 0 + assert conflicts == [[-1]] + + +def test_negation(): + assert HANDLE_NEGATION is True + bf = Q.gt(x, 1) & ~Q.gt(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + assert sorted(lra.check()[1]) in [[-1, 2], [-2, 1]] + + bf = ~Q.gt(x, 1) & ~Q.lt(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == True + + bf = ~Q.gt(x, 0) & ~Q.lt(x, 1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = ~Q.gt(x, 0) & ~Q.le(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = ~Q.le(x+y, 2) & ~Q.ge(x-y, 2) & ~Q.ge(y, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + assert len(lra.check()[1]) == 3 + assert all(i > 0 for i in lra.check()[1]) + + +def test_unhandled_input(): + nan = S.NaN + bf = Q.gt(3, nan) & Q.gt(x, nan) + enc = boolean_formula_to_encoded_cnf(bf) + raises(ValueError, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, I) & Q.gt(x, I) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, float("inf")) & Q.gt(x, float("inf")) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, oo) & Q.gt(x, oo) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + # test non-linearity + bf = Q.gt(x**2 + x, 2) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(cos(x) + x, 2) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + +@XFAIL +def test_infinite_strict_inequalities(): + # Extensive testing of the interaction between strict inequalities + # and constraints containing infinity is needed because + # the paper's rule for strict inequalities don't work when + # infinite numbers are allowed. Using the paper's rules you + # can end up with situations where oo + delta > oo is considered + # True when oo + delta should be equal to oo. + # See https://math.stackexchange.com/questions/4757069/can-this-method-of-converting-strict-inequalities-to-equisatisfiable-nonstrict-i + bf = (-x - y >= -float("inf")) & (x > 0) & (y >= float("inf")) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in sorted(enc.encoding.values()): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == True + + +def test_pivot(): + for _ in range(10): + m = randMatrix(5) + rref = m.rref() + for _ in range(5): + i, j = randint(0, 4), randint(0, 4) + if m[i, j] != 0: + assert LRASolver._pivot(m, i, j).rref() == rref + + +def test_reset_bounds(): + bf = Q.ge(x, 1) & Q.lt(x, 1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + lra.reset_bounds() + assert lra.check()[0] == True + for var in lra.all_var: + assert var.upper == LRARational(float("inf"), 0) + assert var.upper_from_eq == False + assert var.upper_from_neg == False + assert var.lower == LRARational(-float("inf"), 0) + assert var.lower_from_eq == False + assert var.lower_from_neg == False + assert var.assign == LRARational(0, 0) + assert var.var is not None + assert var.col_idx is not None + + +def test_empty_cnf(): + cnf = CNF() + enc = EncodedCNF() + enc.from_cnf(cnf) + lra, conflict = LRASolver.from_encoded_cnf(enc) + assert len(conflict) == 0 + assert lra.check() == (True, {}) diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/utilities/__init__.py b/.venv/lib/python3.13/site-packages/sympy/logic/utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3526c3e53d624bc70afe2df05f123c835781364c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/utilities/__init__.py @@ -0,0 +1,3 @@ +from .dimacs import load_file + +__all__ = ['load_file'] diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/utilities/dimacs.py b/.venv/lib/python3.13/site-packages/sympy/logic/utilities/dimacs.py new file mode 100644 index 0000000000000000000000000000000000000000..51302d8052c8ed8443239c1e21a2f063cf34e4ab --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/utilities/dimacs.py @@ -0,0 +1,69 @@ +"""For reading in DIMACS file format + +www.cs.ubc.ca/~hoos/SATLIB/Benchmarks/SAT/satformat.ps + +""" + +from sympy.core import Symbol +from sympy.logic.boolalg import And, Or +import re +from pathlib import Path + + +def load(s): + """Loads a boolean expression from a string. + + Examples + ======== + + >>> from sympy.logic.utilities.dimacs import load + >>> load('1') + cnf_1 + >>> load('1 2') + cnf_1 | cnf_2 + >>> load('1 \\n 2') + cnf_1 & cnf_2 + >>> load('1 2 \\n 3') + cnf_3 & (cnf_1 | cnf_2) + """ + clauses = [] + + lines = s.split('\n') + + pComment = re.compile(r'c.*') + pStats = re.compile(r'p\s*cnf\s*(\d*)\s*(\d*)') + + while len(lines) > 0: + line = lines.pop(0) + + # Only deal with lines that aren't comments + if not pComment.match(line): + m = pStats.match(line) + + if not m: + nums = line.rstrip('\n').split(' ') + list = [] + for lit in nums: + if lit != '': + if int(lit) == 0: + continue + num = abs(int(lit)) + sign = True + if int(lit) < 0: + sign = False + + if sign: + list.append(Symbol("cnf_%s" % num)) + else: + list.append(~Symbol("cnf_%s" % num)) + + if len(list) > 0: + clauses.append(Or(*list)) + + return And(*clauses) + + +def load_file(location): + """Loads a boolean expression from a file.""" + s = Path(location).read_text() + return load(s) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_aesaracode.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_aesaracode.py new file mode 100644 index 0000000000000000000000000000000000000000..13308af65b382e77de33302bcd75344d2b00adbf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_aesaracode.py @@ -0,0 +1,633 @@ +""" +Important note on tests in this module - the Aesara printing functions use a +global cache by default, which means that tests using it will modify global +state and thus not be independent from each other. Instead of using the "cache" +keyword argument each time, this module uses the aesara_code_ and +aesara_function_ functions defined below which default to using a new, empty +cache instead. +""" + +import logging + +from sympy.external import import_module +from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy + +from sympy.utilities.exceptions import ignore_warnings + + +aesaralogger = logging.getLogger('aesara.configdefaults') +aesaralogger.setLevel(logging.CRITICAL) +aesara = import_module('aesara') +aesaralogger.setLevel(logging.WARNING) + + +if aesara: + import numpy as np + aet = aesara.tensor + from aesara.scalar.basic import ScalarType + from aesara.graph.basic import Variable + from aesara.tensor.var import TensorVariable + from aesara.tensor.elemwise import Elemwise, DimShuffle + from aesara.tensor.math import Dot + + from sympy.printing.aesaracode import true_divide + + xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz'] + Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ'] +else: + #bin/test will not execute any tests now + disabled = True + +import sympy as sy +from sympy.core.singleton import S +from sympy.abc import x, y, z, t +from sympy.printing.aesaracode import (aesara_code, dim_handling, + aesara_function) + + +# Default set of matrix symbols for testing - make square so we can both +# multiply and perform elementwise operations between them. +X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ'] + +# For testing AppliedUndef +f_t = sy.Function('f')(t) + + +def aesara_code_(expr, **kwargs): + """ Wrapper for aesara_code that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return aesara_code(expr, **kwargs) + +def aesara_function_(inputs, outputs, **kwargs): + """ Wrapper for aesara_function that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return aesara_function(inputs, outputs, **kwargs) + + +def fgraph_of(*exprs): + """ Transform SymPy expressions into Aesara Computation. + + Parameters + ========== + exprs + SymPy expressions + + Returns + ======= + aesara.graph.fg.FunctionGraph + """ + outs = list(map(aesara_code_, exprs)) + ins = list(aesara.graph.basic.graph_inputs(outs)) + ins, outs = aesara.graph.basic.clone(ins, outs) + return aesara.graph.fg.FunctionGraph(ins, outs) + + +def aesara_simplify(fgraph): + """ Simplify a Aesara Computation. + + Parameters + ========== + fgraph : aesara.graph.fg.FunctionGraph + + Returns + ======= + aesara.graph.fg.FunctionGraph + """ + mode = aesara.compile.get_default_mode().excluding("fusion") + fgraph = fgraph.clone() + mode.optimizer.rewrite(fgraph) + return fgraph + + +def theq(a, b): + """ Test two Aesara objects for equality. + + Also accepts numeric types and lists/tuples of supported types. + + Note - debugprint() has a bug where it will accept numeric types but does + not respect the "file" argument and in this case and instead prints the number + to stdout and returns an empty string. This can lead to tests passing where + they should fail because any two numbers will always compare as equal. To + prevent this we treat numbers as a separate case. + """ + numeric_types = (int, float, np.number) + a_is_num = isinstance(a, numeric_types) + b_is_num = isinstance(b, numeric_types) + + # Compare numeric types using regular equality + if a_is_num or b_is_num: + if not (a_is_num and b_is_num): + return False + + return a == b + + # Compare sequences element-wise + a_is_seq = isinstance(a, (tuple, list)) + b_is_seq = isinstance(b, (tuple, list)) + + if a_is_seq or b_is_seq: + if not (a_is_seq and b_is_seq) or type(a) != type(b): + return False + + return list(map(theq, a)) == list(map(theq, b)) + + # Otherwise, assume debugprint() can handle it + astr = aesara.printing.debugprint(a, file='str') + bstr = aesara.printing.debugprint(b, file='str') + + # Check for bug mentioned above + for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]: + if argstr == '': + raise TypeError( + 'aesara.printing.debugprint(%s) returned empty string ' + '(%s is instance of %r)' + % (argname, argname, type(argval)) + ) + + return astr == bstr + + +def test_example_symbols(): + """ + Check that the example symbols in this module print to their Aesara + equivalents, as many of the other tests depend on this. + """ + assert theq(xt, aesara_code_(x)) + assert theq(yt, aesara_code_(y)) + assert theq(zt, aesara_code_(z)) + assert theq(Xt, aesara_code_(X)) + assert theq(Yt, aesara_code_(Y)) + assert theq(Zt, aesara_code_(Z)) + + +def test_Symbol(): + """ Test printing a Symbol to a aesara variable. """ + xx = aesara_code_(x) + assert isinstance(xx, Variable) + assert xx.broadcastable == () + assert xx.name == x.name + + xx2 = aesara_code_(x, broadcastables={x: (False,)}) + assert xx2.broadcastable == (False,) + assert xx2.name == x.name + +def test_MatrixSymbol(): + """ Test printing a MatrixSymbol to a aesara variable. """ + XX = aesara_code_(X) + assert isinstance(XX, TensorVariable) + assert XX.broadcastable == (False, False) + +@SKIP # TODO - this is currently not checked but should be implemented +def test_MatrixSymbol_wrong_dims(): + """ Test MatrixSymbol with invalid broadcastable. """ + bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)] + for bc in bcs: + with raises(ValueError): + aesara_code_(X, broadcastables={X: bc}) + +def test_AppliedUndef(): + """ Test printing AppliedUndef instance, which works similarly to Symbol. """ + ftt = aesara_code_(f_t) + assert isinstance(ftt, TensorVariable) + assert ftt.broadcastable == () + assert ftt.name == 'f_t' + + +def test_add(): + expr = x + y + comp = aesara_code_(expr) + assert comp.owner.op == aesara.tensor.add + +def test_trig(): + assert theq(aesara_code_(sy.sin(x)), aet.sin(xt)) + assert theq(aesara_code_(sy.tan(x)), aet.tan(xt)) + +def test_many(): + """ Test printing a complex expression with multiple symbols. """ + expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z) + comp = aesara_code_(expr) + expected = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt) + assert theq(comp, expected) + + +def test_dtype(): + """ Test specifying specific data types through the dtype argument. """ + for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']: + assert aesara_code_(x, dtypes={x: dtype}).type.dtype == dtype + + # "floatX" type + assert aesara_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64') + + # Type promotion + assert aesara_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32' + assert aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64' + + +def test_broadcastables(): + """ Test the "broadcastables" argument when printing symbol-like objects. """ + + # No restrictions on shape + for s in [x, f_t]: + for bc in [(), (False,), (True,), (False, False), (True, False)]: + assert aesara_code_(s, broadcastables={s: bc}).broadcastable == bc + + # TODO - matrix broadcasting? + +def test_broadcasting(): + """ Test "broadcastable" attribute after applying element-wise binary op. """ + + expr = x + y + + cases = [ + [(), (), ()], + [(False,), (False,), (False,)], + [(True,), (False,), (False,)], + [(False, True), (False, False), (False, False)], + [(True, False), (False, False), (False, False)], + ] + + for bc1, bc2, bc3 in cases: + comp = aesara_code_(expr, broadcastables={x: bc1, y: bc2}) + assert comp.broadcastable == bc3 + + +def test_MatMul(): + expr = X*Y*Z + expr_t = aesara_code_(expr) + assert isinstance(expr_t.owner.op, Dot) + assert theq(expr_t, Xt.dot(Yt).dot(Zt)) + +def test_Transpose(): + assert isinstance(aesara_code_(X.T).owner.op, DimShuffle) + +def test_MatAdd(): + expr = X+Y+Z + assert isinstance(aesara_code_(expr).owner.op, Elemwise) + + +def test_Rationals(): + assert theq(aesara_code_(sy.Integer(2) / 3), true_divide(2, 3)) + assert theq(aesara_code_(S.Half), true_divide(1, 2)) + +def test_Integers(): + assert aesara_code_(sy.Integer(3)) == 3 + +def test_factorial(): + n = sy.Symbol('n') + assert aesara_code_(sy.factorial(n)) + +def test_Derivative(): + with ignore_warnings(UserWarning): + simp = lambda expr: aesara_simplify(fgraph_of(expr)) + assert theq(simp(aesara_code_(sy.Derivative(sy.sin(x), x, evaluate=False))), + simp(aesara.grad(aet.sin(xt), xt))) + + +def test_aesara_function_simple(): + """ Test aesara_function() with single output. """ + f = aesara_function_([x, y], [x+y]) + assert f(2, 3) == 5 + +def test_aesara_function_multi(): + """ Test aesara_function() with multiple outputs. """ + f = aesara_function_([x, y], [x+y, x-y]) + o1, o2 = f(2, 3) + assert o1 == 5 + assert o2 == -1 + +def test_aesara_function_numpy(): + """ Test aesara_function() vs Numpy implementation. """ + f = aesara_function_([x, y], [x+y], dim=1, + dtypes={x: 'float64', y: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9 + + f = aesara_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'}, + dim=1) + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9 + + +def test_aesara_function_matrix(): + m = sy.Matrix([[x, y], [z, x + y + z]]) + expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]]) + f = aesara_function_([x, y, z], [m]) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = aesara_function_([x, y, z], [m], scalar=True) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = aesara_function_([x, y, z], [m, m]) + assert isinstance(f(1.0, 2.0, 3.0), type([])) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected) + +def test_dim_handling(): + assert dim_handling([x], dim=2) == {x: (False, False)} + assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True), + y: (False, False)} + assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)} + +def test_aesara_function_kwargs(): + """ + Test passing additional kwargs from aesara_function() to aesara.function(). + """ + import numpy as np + f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore', + dtypes={x: 'float64', y: 'float64', z: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9 + + f = aesara_function_([x, y, z], [x+y], + dtypes={x: 'float64', y: 'float64', z: 'float64'}, + dim=1, on_unused_input='ignore') + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + zz = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9 + +def test_aesara_function_scalar(): + """ Test the "scalar" argument to aesara_function(). """ + from aesara.compile.function.types import Function + + args = [ + ([x, y], [x + y], None, [0]), # Single 0d output + ([X, Y], [X + Y], None, [2]), # Single 2d output + ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output + ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs + ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d + ] + + # Create and test functions with and without the scalar setting + for inputs, outputs, in_dims, out_dims in args: + for scalar in [False, True]: + + f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar) + + # Check the aesara_function attribute is set whether wrapped or not + assert isinstance(f.aesara_function, Function) + + # Feed in inputs of the appropriate size and get outputs + in_values = [ + np.ones([1 if bc else 5 for bc in i.type.broadcastable]) + for i in f.aesara_function.input_storage + ] + out_values = f(*in_values) + if not isinstance(out_values, list): + out_values = [out_values] + + # Check output types and shapes + assert len(out_dims) == len(out_values) + for d, value in zip(out_dims, out_values): + + if scalar and d == 0: + # Should have been converted to a scalar value + assert isinstance(value, np.number) + + else: + # Otherwise should be an array + assert isinstance(value, np.ndarray) + assert value.ndim == d + +def test_aesara_function_bad_kwarg(): + """ + Passing an unknown keyword argument to aesara_function() should raise an + exception. + """ + raises(Exception, lambda : aesara_function_([x], [x+1], foobar=3)) + + +def test_slice(): + assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3) + + def theq_slice(s1, s2): + for attr in ['start', 'stop', 'step']: + a1 = getattr(s1, attr) + a2 = getattr(s2, attr) + if a1 is None or a2 is None: + if not (a1 is None or a2 is None): + return False + elif not theq(a1, a2): + return False + return True + + dtypes = {x: 'int32', y: 'int32'} + assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt)) + assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3)) + +def test_MatrixSlice(): + cache = {} + + n = sy.Symbol('n', integer=True) + X = sy.MatrixSymbol('X', n, n) + + Y = X[1:2:3, 4:5:6] + Yt = aesara_code_(Y, cache=cache) + + s = ScalarType('int64') + assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s)) + assert Yt.owner.inputs[0] == aesara_code_(X, cache=cache) + # == doesn't work in Aesara like it does in SymPy. You have to use + # equals. + assert all(Yt.owner.inputs[i].data == i for i in range(1, 7)) + + k = sy.Symbol('k') + aesara_code_(k, dtypes={k: 'int32'}) + start, stop, step = 4, k, 2 + Y = X[start:stop:step] + Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'}) + # assert Yt.owner.op.idx_list[0].stop == kt + +def test_BlockMatrix(): + n = sy.Symbol('n', integer=True) + A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD'] + At, Bt, Ct, Dt = map(aesara_code_, (A, B, C, D)) + Block = sy.BlockMatrix([[A, B], [C, D]]) + Blockt = aesara_code_(Block) + solutions = [aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt)), + aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))] + assert any(theq(Blockt, solution) for solution in solutions) + +@SKIP +def test_BlockMatrix_Inverse_execution(): + k, n = 2, 4 + dtype = 'float32' + A = sy.MatrixSymbol('A', n, k) + B = sy.MatrixSymbol('B', n, n) + inputs = A, B + output = B.I*A + + cutsizes = {A: [(n//2, n//2), (k//2, k//2)], + B: [(n//2, n//2), (n//2, n//2)]} + cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs] + cutoutput = output.subs(dict(zip(inputs, cutinputs))) + + dtypes = dict(zip(inputs, [dtype]*len(inputs))) + f = aesara_function_(inputs, [output], dtypes=dtypes, cache={}) + fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)], + dtypes=dtypes, cache={}) + + ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs] + ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype), + np.eye(n).astype(dtype)] + ninputs[1] += np.ones(B.shape)*1e-5 + + assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5) + +def test_DenseMatrix(): + from aesara.tensor.basic import Join + + t = sy.Symbol('theta') + for MatrixType in [sy.Matrix, sy.ImmutableMatrix]: + X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]]) + tX = aesara_code_(X) + assert isinstance(tX, TensorVariable) + assert isinstance(tX.owner.op, Join) + + +def test_cache_basic(): + """ Test single symbol-like objects are cached when printed by themselves. """ + + # Pairs of objects which should be considered equivalent with respect to caching + pairs = [ + (x, sy.Symbol('x')), + (X, sy.MatrixSymbol('X', *X.shape)), + (f_t, sy.Function('f')(sy.Symbol('t'))), + ] + + for s1, s2 in pairs: + cache = {} + st = aesara_code_(s1, cache=cache) + + # Test hit with same instance + assert aesara_code_(s1, cache=cache) is st + + # Test miss with same instance but new cache + assert aesara_code_(s1, cache={}) is not st + + # Test hit with different but equivalent instance + assert aesara_code_(s2, cache=cache) is st + +def test_global_cache(): + """ Test use of the global cache. """ + from sympy.printing.aesaracode import global_cache + + backup = dict(global_cache) + try: + # Temporarily empty global cache + global_cache.clear() + + for s in [x, X, f_t]: + with warns_deprecated_sympy(): + st = aesara_code(s) + assert aesara_code(s) is st + + finally: + # Restore global cache + global_cache.update(backup) + +def test_cache_types_distinct(): + """ + Test that symbol-like objects of different types (Symbol, MatrixSymbol, + AppliedUndef) are distinguished by the cache even if they have the same + name. + """ + symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t] + + cache = {} # Single shared cache + printed = {} + + for s in symbols: + st = aesara_code_(s, cache=cache) + assert st not in printed.values() + printed[s] = st + + # Check all printed objects are distinct + assert len(set(map(id, printed.values()))) == len(symbols) + + # Check retrieving + for s, st in printed.items(): + with warns_deprecated_sympy(): + assert aesara_code(s, cache=cache) is st + +def test_symbols_are_created_once(): + """ + Test that a symbol is cached and reused when it appears in an expression + more than once. + """ + expr = sy.Add(x, x, evaluate=False) + comp = aesara_code_(expr) + + assert theq(comp, xt + xt) + assert not theq(comp, xt + aesara_code_(x)) + +def test_cache_complex(): + """ + Test caching on a complicated expression with multiple symbols appearing + multiple times. + """ + expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y) + symbol_names = {s.name for s in expr.free_symbols} + expr_t = aesara_code_(expr) + + # Iterate through variables in the Aesara computational graph that the + # printed expression depends on + seen = set() + for v in aesara.graph.basic.ancestors([expr_t]): + # Owner-less, non-constant variables should be our symbols + if v.owner is None and not isinstance(v, aesara.graph.basic.Constant): + # Check it corresponds to a symbol and appears only once + assert v.name in symbol_names + assert v.name not in seen + seen.add(v.name) + + # Check all were present + assert seen == symbol_names + + +def test_Piecewise(): + # A piecewise linear + expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III + result = aesara_code_(expr) + assert result.owner.op == aet.switch + + expected = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1)) + assert theq(result, expected) + + expr = sy.Piecewise((x, x < 0)) + result = aesara_code_(expr) + expected = aet.switch(xt < 0, xt, np.nan) + assert theq(result, expected) + + expr = sy.Piecewise((0, sy.And(x>0, x<2)), \ + (x, sy.Or(x>2, x<0))) + result = aesara_code_(expr) + expected = aet.switch(aet.and_(xt>0,xt<2), 0, \ + aet.switch(aet.or_(xt>2, xt<0), xt, np.nan)) + assert theq(result, expected) + + +def test_Relationals(): + assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt)) + # assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt)) # TODO - implement + assert theq(aesara_code_(x > y), xt > yt) + assert theq(aesara_code_(x < y), xt < yt) + assert theq(aesara_code_(x >= y), xt >= yt) + assert theq(aesara_code_(x <= y), xt <= yt) + + +def test_complexfunctions(): + dtypes = {x:'complex128', y:'complex128'} + with warns_deprecated_sympy(): + xt, yt = aesara_code(x, dtypes=dtypes), aesara_code(y, dtypes=dtypes) + from sympy.functions.elementary.complexes import conjugate + from aesara.tensor import as_tensor_variable as atv + from aesara.tensor import complex as cplx + with warns_deprecated_sympy(): + assert theq(aesara_code(y*conjugate(x), dtypes=dtypes), yt*(xt.conj())) + assert theq(aesara_code((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1))) + + +def test_constantfunctions(): + with warns_deprecated_sympy(): + tf = aesara_function([],[1+1j]) + assert(tf()==1+1j) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_c.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_c.py new file mode 100644 index 0000000000000000000000000000000000000000..626e7b6f244ea3227b886cd897d327f5d7bf66ec --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_c.py @@ -0,0 +1,888 @@ +from sympy.core import ( + S, pi, oo, Symbol, symbols, Rational, Integer, Float, Function, Mod, GoldenRatio, EulerGamma, Catalan, + Lambda, Dummy, nan, Mul, Pow, UnevaluatedExpr +) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.functions import ( + Abs, acos, acosh, asin, asinh, atan, atanh, atan2, ceiling, cos, cosh, erf, + erfc, exp, floor, gamma, log, loggamma, Max, Min, Piecewise, sign, sin, sinh, + sqrt, tan, tanh, fibonacci, lucas +) +from sympy.sets import Range +from sympy.logic import ITE, Implies, Equivalent +from sympy.codegen import For, aug_assign, Assignment +from sympy.testing.pytest import raises, XFAIL +from sympy.printing.codeprinter import PrintMethodNotImplementedError +from sympy.printing.c import C89CodePrinter, C99CodePrinter, get_math_macros +from sympy.codegen.ast import ( + AddAugmentedAssignment, Element, Type, FloatType, Declaration, Pointer, Variable, value_const, pointer_const, + While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall, Return, + real, float32, float64, float80, float128, intc, Comment, CodeBlock, stderr, QuotedString +) +from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, fma, log10, Cbrt, hypot, Sqrt, isnan, isinf +from sympy.codegen.cnodes import restrict +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix + +from sympy.printing.codeprinter import ccode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + class fabs(Abs): + def _ccode(self, printer): + return "fabs(%s)" % printer._print(self.args[0]) + + assert ccode(fabs(x)) == "fabs(x)" + + +def test_ccode_sqrt(): + assert ccode(sqrt(x)) == "sqrt(x)" + assert ccode(x**0.5) == "sqrt(x)" + assert ccode(sqrt(x)) == "sqrt(x)" + + +def test_ccode_Pow(): + assert ccode(x**3) == "pow(x, 3)" + assert ccode(x**(y**3)) == "pow(x, pow(y, 3))" + g = implemented_function('g', Lambda(x, 2*x)) + assert ccode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2) + y)" + assert ccode(x**-1.0) == '1.0/x' + assert ccode(x**Rational(2, 3)) == 'pow(x, 2.0/3.0)' + assert ccode(x**Rational(2, 3), type_aliases={real: float80}) == 'powl(x, 2.0L/3.0L)' + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"), + (lambda base, exp: not exp.is_integer, "pow")] + assert ccode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)' + assert ccode(x**0.5, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 0.5)' + assert ccode(x**Rational(16, 5), user_functions={'Pow': _cond_cfunc}) == 'pow(x, 16.0/5.0)' + _cond_cfunc2 = [(lambda base, exp: base == 2, lambda base, exp: 'exp2(%s)' % exp), + (lambda base, exp: base != 2, 'pow')] + # Related to gh-11353 + assert ccode(2**x, user_functions={'Pow': _cond_cfunc2}) == 'exp2(x)' + assert ccode(x**2, user_functions={'Pow': _cond_cfunc2}) == 'pow(x, 2)' + # For issue 14160 + assert ccode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + + +def test_ccode_Max(): + # Test for gh-11926 + assert ccode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))' + + +def test_ccode_Min_performance(): + #Shouldn't take more than a few seconds + big_min = Min(*symbols('a[0:50]')) + for curr_standard in ('c89', 'c99', 'c11'): + output = ccode(big_min, standard=curr_standard) + assert output.count('(') == output.count(')') + + +def test_ccode_constants_mathh(): + assert ccode(exp(1)) == "M_E" + assert ccode(pi) == "M_PI" + assert ccode(oo, standard='c89') == "HUGE_VAL" + assert ccode(-oo, standard='c89') == "-HUGE_VAL" + assert ccode(oo) == "INFINITY" + assert ccode(-oo, standard='c99') == "-INFINITY" + assert ccode(pi, type_aliases={real: float80}) == "M_PIl" + + +def test_ccode_constants_other(): + assert ccode(2*GoldenRatio) == "const double GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert ccode( + 2*Catalan) == "const double Catalan = %s;\n2*Catalan" % Catalan.evalf(17) + assert ccode(2*EulerGamma) == "const double EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_ccode_Rational(): + assert ccode(Rational(3, 7)) == "3.0/7.0" + assert ccode(Rational(3, 7), type_aliases={real: float80}) == "3.0L/7.0L" + assert ccode(Rational(18, 9)) == "2" + assert ccode(Rational(3, -7)) == "-3.0/7.0" + assert ccode(Rational(3, -7), type_aliases={real: float80}) == "-3.0L/7.0L" + assert ccode(Rational(-3, -7)) == "3.0/7.0" + assert ccode(Rational(-3, -7), type_aliases={real: float80}) == "3.0L/7.0L" + assert ccode(x + Rational(3, 7)) == "x + 3.0/7.0" + assert ccode(x + Rational(3, 7), type_aliases={real: float80}) == "x + 3.0L/7.0L" + assert ccode(Rational(3, 7)*x) == "(3.0/7.0)*x" + assert ccode(Rational(3, 7)*x, type_aliases={real: float80}) == "(3.0L/7.0L)*x" + + +def test_ccode_Integer(): + assert ccode(Integer(67)) == "67" + assert ccode(Integer(-1)) == "-1" + + +def test_ccode_functions(): + assert ccode(sin(x) ** cos(x)) == "pow(sin(x), cos(x))" + + +def test_ccode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert ccode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert ccode( + g(x)) == "const double Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert ccode(g(A[i]), assign_to=A[i]) == ( + "for (int i=0; i y" + assert ccode(Ge(x, y)) == "x >= y" + + +def test_ccode_Piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " x\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + "))") + assert ccode(expr, assign_to="c") == ( + "if (x < 1) {\n" + " c = x;\n" + "}\n" + "else {\n" + " c = pow(x, 2);\n" + "}") + expr = Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " x\n" + ")\n" + ": ((x < 2) ? (\n" + " x + 1\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + ")))") + assert ccode(expr, assign_to='c') == ( + "if (x < 1) {\n" + " c = x;\n" + "}\n" + "else if (x < 2) {\n" + " c = x + 1;\n" + "}\n" + "else {\n" + " c = pow(x, 2);\n" + "}") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: ccode(expr)) + + +def test_ccode_sinc(): + from sympy.functions.elementary.trigonometric import sinc + expr = sinc(x) + assert ccode(expr) == ( + "(((x != 0) ? (\n" + " sin(x)/x\n" + ")\n" + ": (\n" + " 1\n" + ")))") + + +def test_ccode_Piecewise_deep(): + p = ccode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))) + assert p == ( + "2*((x < 1) ? (\n" + " x\n" + ")\n" + ": ((x < 2) ? (\n" + " x + 1\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + ")))") + expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1 + assert ccode(expr) == ( + "pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n" + " 0\n" + ")\n" + ": (\n" + " 1\n" + ")) + cos(z) - 1") + assert ccode(expr, assign_to='c') == ( + "c = pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n" + " 0\n" + ")\n" + ": (\n" + " 1\n" + ")) + cos(z) - 1;") + + +def test_ccode_ITE(): + expr = ITE(x < 1, y, z) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " y\n" + ")\n" + ": (\n" + " z\n" + "))") + + +def test_ccode_settings(): + raises(TypeError, lambda: ccode(sin(x), method="garbage")) + + +def test_ccode_Indexed(): + s, n, m, o = symbols('s n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + + x = IndexedBase('x')[j] + A = IndexedBase('A')[i, j] + B = IndexedBase('B')[i, j, k] + + p = C99CodePrinter() + + assert p._print_Indexed(x) == 'x[j]' + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + A = IndexedBase('A', shape=(5,3))[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (3*i + j) + + A = IndexedBase('A', shape=(5,3), strides='F')[i, j] + assert ccode(A) == 'A[%s]' % (i + 5*j) + + A = IndexedBase('A', shape=(29,29), strides=(1, s), offset=o)[i, j] + assert ccode(A) == 'A[o + s*j + i]' + + Abase = IndexedBase('A', strides=(s, m, n), offset=o) + assert ccode(Abase[i, j, k]) == 'A[m*j + n*k + o + s*i]' + assert ccode(Abase[2, 3, k]) == 'A[3*m + n*k + o + 2*s]' + + +def test_Element(): + assert ccode(Element('x', 'ij')) == 'x[i][j]' + assert ccode(Element('x', 'ij', strides='kl', offset='o')) == 'x[i*k + j*l + o]' + assert ccode(Element('x', (3,))) == 'x[3]' + assert ccode(Element('x', (3,4,5))) == 'x[3][4][5]' + + +def test_ccode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = ccode(e.rhs, assign_to=e.lhs, contract=False) + assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1) + + +def test_ccode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (int i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert ccode(mat, A) == ( + "A[0] = x*y;\n" + "if (y > 0) {\n" + " A[1] = x + 2;\n" + "}\n" + "else {\n" + " A[1] = y;\n" + "}\n" + "A[2] = sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert ccode(expr) == ( + "((x > 0) ? (\n" + " 2*A[2]\n" + ")\n" + ": (\n" + " A[2]\n" + ")) + sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert ccode(m, M) == ( + "M[0] = sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_sparse_matrix(): + # gh-15791 + with raises(PrintMethodNotImplementedError): + ccode(SparseMatrix([[1, 2, 3]])) + + assert 'Not supported in C' in C89CodePrinter({'strict': False}).doprint(SparseMatrix([[1, 2, 3]])) + + + +def test_ccode_reserved_words(): + x, y = symbols('x, if') + with raises(ValueError): + ccode(y**2, error_on_reserved=True, standard='C99') + assert ccode(y**2) == 'pow(if_, 2)' + assert ccode(x * y**2, dereference=[y]) == 'pow((*if_), 2)*x' + assert ccode(y**2, reserved_word_suffix='_unreserved') == 'pow(if_unreserved, 2)' + + +def test_ccode_sign(): + expr1, ref1 = sign(x) * y, 'y*(((x) > 0) - ((x) < 0))' + expr2, ref2 = sign(cos(x)), '(((cos(x)) > 0) - ((cos(x)) < 0))' + expr3, ref3 = sign(2 * x + x**2) * x + x**2, 'pow(x, 2) + x*(((pow(x, 2) + 2*x) > 0) - ((pow(x, 2) + 2*x) < 0))' + assert ccode(expr1) == ref1 + assert ccode(expr1, 'z') == 'z = %s;' % ref1 + assert ccode(expr2) == ref2 + assert ccode(expr3) == ref3 + +def test_ccode_Assignment(): + assert ccode(Assignment(x, y + z)) == 'x = y + z;' + assert ccode(aug_assign(x, '+', y + z)) == 'x += y + z;' + + +def test_ccode_For(): + f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)]) + assert ccode(f) == ("for (x = 0; x < 10; x += 2) {\n" + " y *= x;\n" + "}") + +def test_ccode_Max_Min(): + assert ccode(Max(x, 0), standard='C89') == '((0 > x) ? 0 : x)' + assert ccode(Max(x, 0), standard='C99') == 'fmax(0, x)' + assert ccode(Min(x, 0, sqrt(x)), standard='c89') == ( + '((0 < ((x < sqrt(x)) ? x : sqrt(x))) ? 0 : ((x < sqrt(x)) ? x : sqrt(x)))' + ) + +def test_ccode_standard(): + assert ccode(expm1(x), standard='c99') == 'expm1(x)' + assert ccode(nan, standard='c99') == 'NAN' + assert ccode(float('nan'), standard='c99') == 'NAN' + + +def test_C89CodePrinter(): + c89printer = C89CodePrinter() + assert c89printer.language == 'C' + assert c89printer.standard == 'C89' + assert 'void' in c89printer.reserved_words + assert 'template' not in c89printer.reserved_words + assert c89printer.doprint(log10(x)) == 'log10(x)' + + +def test_C99CodePrinter(): + assert C99CodePrinter().doprint(expm1(x)) == 'expm1(x)' + assert C99CodePrinter().doprint(log1p(x)) == 'log1p(x)' + assert C99CodePrinter().doprint(exp2(x)) == 'exp2(x)' + assert C99CodePrinter().doprint(log2(x)) == 'log2(x)' + assert C99CodePrinter().doprint(fma(x, y, -z)) == 'fma(x, y, -z)' + assert C99CodePrinter().doprint(log10(x)) == 'log10(x)' + assert C99CodePrinter().doprint(Cbrt(x)) == 'cbrt(x)' # note Cbrt due to cbrt already taken. + assert C99CodePrinter().doprint(hypot(x, y)) == 'hypot(x, y)' + assert C99CodePrinter().doprint(loggamma(x)) == 'lgamma(x)' + assert C99CodePrinter().doprint(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))' + assert C99CodePrinter().doprint(Min(x, 3)) == 'fmin(3, x)' + c99printer = C99CodePrinter() + assert c99printer.language == 'C' + assert c99printer.standard == 'C99' + assert 'restrict' in c99printer.reserved_words + assert 'using' not in c99printer.reserved_words + + +@XFAIL +def test_C99CodePrinter__precision_f80(): + f80_printer = C99CodePrinter({"type_aliases": {real: float80}}) + assert f80_printer.doprint(sin(x + Float('2.1'))) == 'sinl(x + 2.1L)' + + +def test_C99CodePrinter__precision(): + n = symbols('n', integer=True) + p = symbols('p', integer=True, positive=True) + f32_printer = C99CodePrinter({"type_aliases": {real: float32}}) + f64_printer = C99CodePrinter({"type_aliases": {real: float64}}) + f80_printer = C99CodePrinter({"type_aliases": {real: float80}}) + assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)' + assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)' + assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)' + + for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']): + def check(expr, ref): + assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper()) + check(Abs(n), 'abs(n)') + check(Abs(x + 2.0), 'fabs{s}(x + 2.0{S})') + check(sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))') + check(exp(x*8.0), 'exp{s}(8.0{S}*x)') + check(exp2(x), 'exp2{s}(x)') + check(expm1(x*4.0), 'expm1{s}(4.0{S}*x)') + check(Mod(p, 2), 'p % 2') + check(Mod(2*p + 3, 3*p + 5, evaluate=False), '(2*p + 3) % (3*p + 5)') + check(Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})') + check(Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})') + check(log(x/2), 'log{s}((1.0{S}/2.0{S})*x)') + check(log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)') + check(log2(x*8.0), 'log2{s}(8.0{S}*x)') + check(log1p(x), 'log1p{s}(x)') + check(2**x, 'pow{s}(2, x)') + check(2.0**x, 'pow{s}(2.0{S}, x)') + check(x**3, 'pow{s}(x, 3)') + check(x**4.0, 'pow{s}(x, 4.0{S})') + check(sqrt(3+x), 'sqrt{s}(x + 3)') + check(Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})') + check(hypot(x, y), 'hypot{s}(x, y)') + check(sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})') + check(cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})') + check(tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})') + check(asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})') + check(acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})') + check(atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})') + check(atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)') + + check(sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})') + check(cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})') + check(tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})') + check(asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})') + check(acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})') + check(atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})') + check(erf(42.*x), 'erf{s}(42.0{S}*x)') + check(erfc(42.*x), 'erfc{s}(42.0{S}*x)') + check(gamma(x), 'tgamma{s}(x)') + check(loggamma(x), 'lgamma{s}(x)') + + check(ceiling(x + 2.), "ceil{s}(x) + 2") + check(floor(x + 2.), "floor{s}(x) + 2") + check(fma(x, y, -z), 'fma{s}(x, y, -z)') + check(Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))') + check(Min(x, 2.0), 'fmin{s}(2.0{S}, x)') + + +def test_get_math_macros(): + macros = get_math_macros() + assert macros[exp(1)] == 'M_E' + assert macros[1/Sqrt(2)] == 'M_SQRT1_2' + + +def test_ccode_Declaration(): + i = symbols('i', integer=True) + var1 = Variable(i, type=Type.from_expr(i)) + dcl1 = Declaration(var1) + assert ccode(dcl1) == 'int i' + + var2 = Variable(x, type=float32, attrs={value_const}) + dcl2a = Declaration(var2) + assert ccode(dcl2a) == 'const float x' + dcl2b = var2.as_Declaration(value=pi) + assert ccode(dcl2b) == 'const float x = M_PI' + + var3 = Variable(y, type=Type('bool')) + dcl3 = Declaration(var3) + printer = C89CodePrinter() + assert 'stdbool.h' not in printer.headers + assert printer.doprint(dcl3) == 'bool y' + assert 'stdbool.h' in printer.headers + + u = symbols('u', real=True) + ptr4 = Pointer.deduced(u, attrs={pointer_const, restrict}) + dcl4 = Declaration(ptr4) + assert ccode(dcl4) == 'double * const restrict u' + + var5 = Variable(x, Type('__float128'), attrs={value_const}) + dcl5a = Declaration(var5) + assert ccode(dcl5a) == 'const __float128 x' + var5b = Variable(var5.symbol, var5.type, pi, attrs=var5.attrs) + dcl5b = Declaration(var5b) + assert ccode(dcl5b) == 'const __float128 x = M_PI' + + +def test_C99CodePrinter_custom_type(): + # We will look at __float128 (new in glibc 2.26) + f128 = FloatType('_Float128', float128.nbits, float128.nmant, float128.nexp) + p128 = C99CodePrinter({ + "type_aliases": {real: f128}, + "type_literal_suffixes": {f128: 'Q'}, + "type_func_suffixes": {f128: 'f128'}, + "type_math_macro_suffixes": { + real: 'f128', + f128: 'f128' + }, + "type_macros": { + f128: ('__STDC_WANT_IEC_60559_TYPES_EXT__',) + } + }) + assert p128.doprint(x) == 'x' + assert not p128.headers + assert not p128.libraries + assert not p128.macros + assert p128.doprint(2.0) == '2.0Q' + assert not p128.headers + assert not p128.libraries + assert p128.macros == {'__STDC_WANT_IEC_60559_TYPES_EXT__'} + + assert p128.doprint(Rational(1, 2)) == '1.0Q/2.0Q' + assert p128.doprint(sin(x)) == 'sinf128(x)' + assert p128.doprint(cos(2., evaluate=False)) == 'cosf128(2.0Q)' + assert p128.doprint(x**-1.0) == '1.0Q/x' + + var5 = Variable(x, f128, attrs={value_const}) + + dcl5a = Declaration(var5) + assert ccode(dcl5a) == 'const _Float128 x' + var5b = Variable(x, f128, pi, attrs={value_const}) + dcl5b = Declaration(var5b) + assert p128.doprint(dcl5b) == 'const _Float128 x = M_PIf128' + var5b = Variable(x, f128, value=Catalan.evalf(38), attrs={value_const}) + dcl5c = Declaration(var5b) + assert p128.doprint(dcl5c) == 'const _Float128 x = %sQ' % Catalan.evalf(f128.decimal_dig) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(ccode(A[0, 0]) == "A[0]") + assert(ccode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(ccode(F) == "(A - B)[0]") + +def test_ccode_math_macros(): + assert ccode(z + exp(1)) == 'z + M_E' + assert ccode(z + log2(exp(1))) == 'z + M_LOG2E' + assert ccode(z + 1/log(2)) == 'z + M_LOG2E' + assert ccode(z + log(2)) == 'z + M_LN2' + assert ccode(z + log(10)) == 'z + M_LN10' + assert ccode(z + pi) == 'z + M_PI' + assert ccode(z + pi/2) == 'z + M_PI_2' + assert ccode(z + pi/4) == 'z + M_PI_4' + assert ccode(z + 1/pi) == 'z + M_1_PI' + assert ccode(z + 2/pi) == 'z + M_2_PI' + assert ccode(z + 2/sqrt(pi)) == 'z + M_2_SQRTPI' + assert ccode(z + 2/Sqrt(pi)) == 'z + M_2_SQRTPI' + assert ccode(z + sqrt(2)) == 'z + M_SQRT2' + assert ccode(z + Sqrt(2)) == 'z + M_SQRT2' + assert ccode(z + 1/sqrt(2)) == 'z + M_SQRT1_2' + assert ccode(z + 1/Sqrt(2)) == 'z + M_SQRT1_2' + + +def test_ccode_Type(): + assert ccode(Type('float')) == 'float' + assert ccode(intc) == 'int' + + +def test_ccode_codegen_ast(): + # Note that C only allows comments of the form /* ... */, double forward + # slash is not standard C, and some C compilers will grind to a halt upon + # encountering them. + assert ccode(Comment("this is a comment")) == "/* this is a comment */" # not // + assert ccode(While(abs(x) > 1, [aug_assign(x, '-', 1)])) == ( + 'while (fabs(x) > 1) {\n' + ' x -= 1;\n' + '}' + ) + assert ccode(Scope([AddAugmentedAssignment(x, 1)])) == ( + '{\n' + ' x += 1;\n' + '}' + ) + inp_x = Declaration(Variable(x, type=real)) + assert ccode(FunctionPrototype(real, 'pwer', [inp_x])) == 'double pwer(double x)' + assert ccode(FunctionDefinition(real, 'pwer', [inp_x], [Assignment(x, x**2)])) == ( + 'double pwer(double x){\n' + ' x = pow(x, 2);\n' + '}' + ) + + # Elements of CodeBlock are formatted as statements: + block = CodeBlock( + x, + Print([x, y], "%d %d"), + Print([QuotedString('hello'), y], "%s %d", file=stderr), + FunctionCall('pwer', [x]), + Return(x), + ) + assert ccode(block) == '\n'.join([ + 'x;', + 'printf("%d %d", x, y);', + 'fprintf(stderr, "%s %d", "hello", y);', + 'pwer(x);', + 'return x;', + ]) + +def test_ccode_UnevaluatedExpr(): + assert ccode(UnevaluatedExpr(y * x) + z) == "z + x*y" + assert ccode(UnevaluatedExpr(y + x) + z) == "z + (x + y)" # gh-21955 + w = symbols('w') + assert ccode(UnevaluatedExpr(y + x) + UnevaluatedExpr(z + w)) == "(w + z) + (x + y)" + + p, q, r = symbols("p q r", real=True) + q_r = UnevaluatedExpr(q + r) + expr = abs(exp(p+q_r)) + assert ccode(expr) == "exp(p + (q + r))" + + +def test_ccode_array_like_containers(): + assert ccode([2,3,4]) == "{2, 3, 4}" + assert ccode((2,3,4)) == "{2, 3, 4}" + +def test_ccode__isinf_isnan(): + assert ccode(isinf(x)) == 'isinf(x)' + assert ccode(isnan(x)) == 'isnan(x)' diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_codeprinter.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_codeprinter.py new file mode 100644 index 0000000000000000000000000000000000000000..4b077037eb84e218fcfd4a05fc03e40b211e45b9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_codeprinter.py @@ -0,0 +1,77 @@ +from sympy.printing.codeprinter import CodePrinter, PrintMethodNotImplementedError +from sympy.core import symbols +from sympy.core.symbol import Dummy +from sympy.testing.pytest import raises +from sympy import cos +from sympy.utilities.lambdify import lambdify +from math import cos as math_cos +from sympy.printing.lambdarepr import LambdaPrinter + + +def setup_test_printer(**kwargs): + p = CodePrinter(settings=kwargs) + p._not_supported = set() + p._number_symbols = set() + return p + + +def test_print_Dummy(): + d = Dummy('d') + p = setup_test_printer() + assert p._print_Dummy(d) == "d_%i" % d.dummy_index + +def test_print_Symbol(): + + x, y = symbols('x, if') + + p = setup_test_printer() + assert p._print(x) == 'x' + assert p._print(y) == 'if' + + p.reserved_words.update(['if']) + assert p._print(y) == 'if_' + + p = setup_test_printer(error_on_reserved=True) + p.reserved_words.update(['if']) + with raises(ValueError): + p._print(y) + + p = setup_test_printer(reserved_word_suffix='_He_Man') + p.reserved_words.update(['if']) + assert p._print(y) == 'if_He_Man' + + +def test_lambdify_LaTeX_symbols_issue_23374(): + # Create symbols with Latex style names + x1, x2 = symbols("x_{1} x_2") + + # Lambdify the function + f1 = lambdify([x1, x2], cos(x1 ** 2 + x2 ** 2)) + + # Test that the function works correctly (numerically) + assert f1(1, 2) == math_cos(1 ** 2 + 2 ** 2) + + # Explicitly generate a custom printer to verify the naming convention + p = LambdaPrinter() + expr_str = p.doprint(cos(x1 ** 2 + x2 ** 2)) + assert 'x_1' in expr_str + assert 'x_2' in expr_str + + +def test_issue_15791(): + class CrashingCodePrinter(CodePrinter): + def emptyPrinter(self, obj): + raise NotImplementedError + + from sympy.matrices import ( + MutableSparseMatrix, + ImmutableSparseMatrix, + ) + + c = CrashingCodePrinter() + + # these should not silently succeed + with raises(PrintMethodNotImplementedError): + c.doprint(ImmutableSparseMatrix(2, 2, {})) + with raises(PrintMethodNotImplementedError): + c.doprint(MutableSparseMatrix(2, 2, {})) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_conventions.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_conventions.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f1fa8532f96130828b89d1ba5ba11fd5bed7a4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_conventions.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import cos +from sympy.integrals.integrals import Integral +from sympy.functions.special.bessel import besselj +from sympy.functions.special.polynomials import legendre +from sympy.functions.combinatorial.numbers import bell +from sympy.printing.conventions import split_super_sub, requires_partial +from sympy.testing.pytest import XFAIL + +def test_super_sub(): + assert split_super_sub("beta_13_2") == ("beta", [], ["13", "2"]) + assert split_super_sub("beta_132_20") == ("beta", [], ["132", "20"]) + assert split_super_sub("beta_13") == ("beta", [], ["13"]) + assert split_super_sub("x_a_b") == ("x", [], ["a", "b"]) + assert split_super_sub("x_1_2_3") == ("x", [], ["1", "2", "3"]) + assert split_super_sub("x_a_b1") == ("x", [], ["a", "b1"]) + assert split_super_sub("x_a_1") == ("x", [], ["a", "1"]) + assert split_super_sub("x_1_a") == ("x", [], ["1", "a"]) + assert split_super_sub("x_1^aa") == ("x", ["aa"], ["1"]) + assert split_super_sub("x_1__aa") == ("x", ["aa"], ["1"]) + assert split_super_sub("x_11^a") == ("x", ["a"], ["11"]) + assert split_super_sub("x_11__a") == ("x", ["a"], ["11"]) + assert split_super_sub("x_a_b_c_d") == ("x", [], ["a", "b", "c", "d"]) + assert split_super_sub("x_a_b^c^d") == ("x", ["c", "d"], ["a", "b"]) + assert split_super_sub("x_a_b__c__d") == ("x", ["c", "d"], ["a", "b"]) + assert split_super_sub("x_a^b_c^d") == ("x", ["b", "d"], ["a", "c"]) + assert split_super_sub("x_a__b_c__d") == ("x", ["b", "d"], ["a", "c"]) + assert split_super_sub("x^a^b_c_d") == ("x", ["a", "b"], ["c", "d"]) + assert split_super_sub("x__a__b_c_d") == ("x", ["a", "b"], ["c", "d"]) + assert split_super_sub("x^a^b^c^d") == ("x", ["a", "b", "c", "d"], []) + assert split_super_sub("x__a__b__c__d") == ("x", ["a", "b", "c", "d"], []) + assert split_super_sub("alpha_11") == ("alpha", [], ["11"]) + assert split_super_sub("alpha_11_11") == ("alpha", [], ["11", "11"]) + assert split_super_sub("w1") == ("w", [], ["1"]) + assert split_super_sub("w𝟙") == ("w", [], ["𝟙"]) + assert split_super_sub("w11") == ("w", [], ["11"]) + assert split_super_sub("w𝟙𝟙") == ("w", [], ["𝟙𝟙"]) + assert split_super_sub("w𝟙2𝟙") == ("w", [], ["𝟙2𝟙"]) + assert split_super_sub("w1^a") == ("w", ["a"], ["1"]) + assert split_super_sub("ω1") == ("ω", [], ["1"]) + assert split_super_sub("ω11") == ("ω", [], ["11"]) + assert split_super_sub("ω1^a") == ("ω", ["a"], ["1"]) + assert split_super_sub("ω𝟙^α") == ("ω", ["α"], ["𝟙"]) + assert split_super_sub("ω𝟙2^3α") == ("ω", ["3α"], ["𝟙2"]) + assert split_super_sub("") == ("", [], []) + + +def test_requires_partial(): + x, y, z, t, nu = symbols('x y z t nu') + n = symbols('n', integer=True) + + f = x * y + assert requires_partial(Derivative(f, x)) is True + assert requires_partial(Derivative(f, y)) is True + + ## integrating out one of the variables + assert requires_partial(Derivative(Integral(exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False + + ## bessel function with smooth parameter + f = besselj(nu, x) + assert requires_partial(Derivative(f, x)) is True + assert requires_partial(Derivative(f, nu)) is True + + ## bessel function with integer parameter + f = besselj(n, x) + assert requires_partial(Derivative(f, x)) is False + # this is not really valid (differentiating with respect to an integer) + # but there's no reason to use the partial derivative symbol there. make + # sure we don't throw an exception here, though + assert requires_partial(Derivative(f, n)) is False + + ## bell polynomial + f = bell(n, x) + assert requires_partial(Derivative(f, x)) is False + # again, invalid + assert requires_partial(Derivative(f, n)) is False + + ## legendre polynomial + f = legendre(0, x) + assert requires_partial(Derivative(f, x)) is False + + f = legendre(n, x) + assert requires_partial(Derivative(f, x)) is False + # again, invalid + assert requires_partial(Derivative(f, n)) is False + + f = x ** n + assert requires_partial(Derivative(f, x)) is False + + assert requires_partial(Derivative(Integral((x*y) ** n * exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False + + # parametric equation + f = (exp(t), cos(t)) + g = sum(f) + assert requires_partial(Derivative(g, t)) is False + + f = symbols('f', cls=Function) + assert requires_partial(Derivative(f(x), x)) is False + assert requires_partial(Derivative(f(x), y)) is False + assert requires_partial(Derivative(f(x, y), x)) is True + assert requires_partial(Derivative(f(x, y), y)) is True + assert requires_partial(Derivative(f(x, y), z)) is True + assert requires_partial(Derivative(f(x, y), x, y)) is True + +@XFAIL +def test_requires_partial_unspecified_variables(): + x, y = symbols('x y') + # function of unspecified variables + f = symbols('f', cls=Function) + assert requires_partial(Derivative(f, x)) is False + assert requires_partial(Derivative(f, x, y)) is True diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_cupy.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_cupy.py new file mode 100644 index 0000000000000000000000000000000000000000..cf111ec1623390a3dbbf489235d2ed387624a36c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_cupy.py @@ -0,0 +1,56 @@ +from sympy.concrete.summations import Sum +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.utilities.lambdify import lambdify +from sympy.abc import x, i, a, b +from sympy.codegen.numpy_nodes import logaddexp +from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +cp = import_module('cupy') + +def test_cupy_print(): + prntr = CuPyPrinter() + assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)' + assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)' + assert prntr.doprint(log(x)) == 'cupy.log(x)' + assert prntr.doprint("acos(x)") == 'cupy.arccos(x)' + assert prntr.doprint("exp(x)") == 'cupy.exp(x)' + assert prntr.doprint("Abs(x)") == 'abs(x)' + +def test_not_cupy_print(): + prntr = CuPyPrinter() + with raises(NotImplementedError): + prntr.doprint("abcd(x)") + +def test_cupy_sum(): + if not cp: + skip("CuPy not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'cupy') + + a_, b_ = 0, 10 + x_ = cp.linspace(-1, +1, 10) + assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = cp.linspace(-1, +1, 10) + assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + +def test_cupy_known_funcs_consts(): + assert _cupy_known_constants['NaN'] == 'cupy.nan' + assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma' + + assert _cupy_known_functions['acos'] == 'cupy.arccos' + assert _cupy_known_functions['log'] == 'cupy.log' + +def test_cupy_print_methods(): + prntr = CuPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_cxx.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_cxx.py new file mode 100644 index 0000000000000000000000000000000000000000..d84ec75cbf0eeb60a1176b9cb3b401a3384454e7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_cxx.py @@ -0,0 +1,86 @@ +from sympy.core.numbers import Float, Integer, Rational +from sympy.core.symbol import symbols +from sympy.functions import beta, Ei, zeta, Max, Min, sqrt, riemann_xi, frac +from sympy.printing.cxx import CXX98CodePrinter, CXX11CodePrinter, CXX17CodePrinter, cxxcode +from sympy.codegen.cfunctions import log1p + + +x, y, u, v = symbols('x y u v') + + +def test_CXX98CodePrinter(): + assert CXX98CodePrinter().doprint(Max(x, 3)) in ('std::max(x, 3)', 'std::max(3, x)') + assert CXX98CodePrinter().doprint(Min(x, 3, sqrt(x))) == 'std::min(3, std::min(x, std::sqrt(x)))' + cxx98printer = CXX98CodePrinter() + assert cxx98printer.language == 'C++' + assert cxx98printer.standard == 'C++98' + assert 'template' in cxx98printer.reserved_words + assert 'alignas' not in cxx98printer.reserved_words + + +def test_CXX11CodePrinter(): + assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)' + + cxx11printer = CXX11CodePrinter() + assert cxx11printer.language == 'C++' + assert cxx11printer.standard == 'C++11' + assert 'operator' in cxx11printer.reserved_words + assert 'noexcept' in cxx11printer.reserved_words + assert 'concept' not in cxx11printer.reserved_words + + +def test_subclass_print_method(): + class MyPrinter(CXX11CodePrinter): + def _print_log1p(self, expr): + return 'my_library::log1p(%s)' % ', '.join(map(self._print, expr.args)) + + assert MyPrinter().doprint(log1p(x)) == 'my_library::log1p(x)' + + +def test_subclass_print_method__ns(): + class MyPrinter(CXX11CodePrinter): + _ns = 'my_library::' + + p = CXX11CodePrinter() + myp = MyPrinter() + + assert p.doprint(log1p(x)) == 'std::log1p(x)' + assert myp.doprint(log1p(x)) == 'my_library::log1p(x)' + + +def test_CXX17CodePrinter(): + assert CXX17CodePrinter().doprint(beta(x, y)) == 'std::beta(x, y)' + assert CXX17CodePrinter().doprint(Ei(x)) == 'std::expint(x)' + assert CXX17CodePrinter().doprint(zeta(x)) == 'std::riemann_zeta(x)' + + # Automatic rewrite + assert CXX17CodePrinter().doprint(frac(x)) == '(x - std::floor(x))' + assert CXX17CodePrinter().doprint(riemann_xi(x)) == '((1.0/2.0)*std::pow(M_PI, -1.0/2.0*x)*x*(x - 1)*std::tgamma((1.0/2.0)*x)*std::riemann_zeta(x))' + + +def test_cxxcode(): + assert sorted(cxxcode(sqrt(x)*.5).split('*')) == sorted(['0.5', 'std::sqrt(x)']) + +def test_cxxcode_nested_minmax(): + assert cxxcode(Max(Min(x, y), Min(u, v))) \ + == 'std::max(std::min(u, v), std::min(x, y))' + assert cxxcode(Min(Max(x, y), Max(u, v))) \ + == 'std::min(std::max(u, v), std::max(x, y))' + +def test_subclass_Integer_Float(): + class MyPrinter(CXX17CodePrinter): + def _print_Integer(self, arg): + return 'bigInt("%s")' % super()._print_Integer(arg) + + def _print_Float(self, arg): + rat = Rational(arg) + return 'bigFloat(%s, %s)' % ( + self._print(Integer(rat.p)), + self._print(Integer(rat.q)) + ) + + p = MyPrinter() + for i in range(13): + assert p.doprint(i) == 'bigInt("%d")' % i + assert p.doprint(Float(0.5)) == 'bigFloat(bigInt("1"), bigInt("2"))' + assert p.doprint(x**-1.0) == 'bigFloat(bigInt("1"), bigInt("1"))/x' diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_dot.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_dot.py new file mode 100644 index 0000000000000000000000000000000000000000..6213e237fb7aac6460a956b4c9fc1f7c8710fec6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_dot.py @@ -0,0 +1,134 @@ +from sympy.printing.dot import (purestr, styleof, attrprint, dotnode, + dotedges, dotprint) +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.numbers import (Float, Integer) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.printing.repr import srepr +from sympy.abc import x + + +def test_purestr(): + assert purestr(Symbol('x')) == "Symbol('x')" + assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))" + assert purestr(Float(2)) == "Float('2.0', precision=53)" + + assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ()) + assert purestr(Basic(S(1), S(2)), with_args=True) == \ + ('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)')) + assert purestr(Float(2), with_args=True) == \ + ("Float('2.0', precision=53)", ()) + + +def test_styleof(): + styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'})] + assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'} + + assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'} + + +def test_attrprint(): + assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \ + '"color"="blue", "shape"="ellipse"' + +def test_dotnode(): + + assert dotnode(x, repeat=False) == \ + '"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];' + assert dotnode(x+2, repeat=False) == \ + '"Add(Integer(2), Symbol(\'x\'))" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];', \ + dotnode(x+2,repeat=0) + + assert dotnode(x + x**2, repeat=False) == \ + '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];' + assert dotnode(x + x**2, repeat=True) == \ + '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];' + +def test_dotedges(): + assert sorted(dotedges(x+2, repeat=False)) == [ + '"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";', + '"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";' + ] + assert sorted(dotedges(x + 2, repeat=True)) == [ + '"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";', + '"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";' + ] + +def test_dotprint(): + text = dotprint(x+2, repeat=False) + assert all(e in text for e in dotedges(x+2, repeat=False)) + assert all( + n in text for n in [dotnode(expr, repeat=False) + for expr in (x, Integer(2), x+2)]) + assert 'digraph' in text + + text = dotprint(x+x**2, repeat=False) + assert all(e in text for e in dotedges(x+x**2, repeat=False)) + assert all( + n in text for n in [dotnode(expr, repeat=False) + for expr in (x, Integer(2), x**2)]) + assert 'digraph' in text + + text = dotprint(x+x**2, repeat=True) + assert all(e in text for e in dotedges(x+x**2, repeat=True)) + assert all( + n in text for n in [dotnode(expr, pos=()) + for expr in [x + x**2]]) + + text = dotprint(x**x, repeat=True) + assert all(e in text for e in dotedges(x**x, repeat=True)) + assert all( + n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))]) + assert 'digraph' in text + +def test_dotprint_depth(): + text = dotprint(3*x+2, depth=1) + assert dotnode(3*x+2) in text + assert dotnode(x) not in text + text = dotprint(3*x+2) + assert "depth" not in text + +def test_Matrix_and_non_basics(): + from sympy.matrices.expressions.matexpr import MatrixSymbol + n = Symbol('n') + assert dotprint(MatrixSymbol('X', n, n)) == \ +"""digraph{ + +# Graph style +"ordering"="out" +"rankdir"="TD" + +######### +# Nodes # +######### + +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"]; +"Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"]; +"Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"]; +"Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"]; + +######### +# Edges # +######### + +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)"; +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)"; +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)"; +}""" + + +def test_labelfunc(): + text = dotprint(x + 2, labelfunc=srepr) + assert "Symbol('x')" in text + assert "Integer(2)" in text + + +def test_commutative(): + x, y = symbols('x y', commutative=False) + assert dotprint(x + y) == dotprint(y + x) + assert dotprint(x*y) != dotprint(y*x) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_glsl.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_glsl.py new file mode 100644 index 0000000000000000000000000000000000000000..86ec1dfe4a37d141e8435c369cb692d3a9a3b7bc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_glsl.py @@ -0,0 +1,998 @@ +from sympy.core import (pi, symbols, Rational, Integer, GoldenRatio, EulerGamma, + Catalan, Lambda, Dummy, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.functions import Piecewise, sin, cos, Abs, exp, ceiling, sqrt +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.printing.glsl import GLSLPrinter +from sympy.printing.str import StrPrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol +from sympy.core import Tuple +from sympy.printing.glsl import glsl_code +import textwrap + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + assert glsl_code(Abs(x)) == "abs(x)" + +def test_print_without_operators(): + assert glsl_code(x*y,use_operators = False) == 'mul(x, y)' + assert glsl_code(x**y+z,use_operators = False) == 'add(pow(x, y), z)' + assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))' + assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))' + assert glsl_code(x*(y+z**y**0.5),use_operators = False) == 'mul(x, add(y, pow(z, sqrt(y))))' + assert glsl_code(-x-y, use_operators=False, zero='zero()') == 'sub(zero(), add(x, y))' + assert glsl_code(-x-y, use_operators=False) == 'sub(0.0, add(x, y))' + +def test_glsl_code_sqrt(): + assert glsl_code(sqrt(x)) == "sqrt(x)" + assert glsl_code(x**0.5) == "sqrt(x)" + assert glsl_code(sqrt(x)) == "sqrt(x)" + + +def test_glsl_code_Pow(): + g = implemented_function('g', Lambda(x, 2*x)) + assert glsl_code(x**3) == "pow(x, 3.0)" + assert glsl_code(x**(y**3)) == "pow(x, pow(y, 3.0))" + assert glsl_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2.0) + y)" + assert glsl_code(x**-1.0) == '1.0/x' + + +def test_glsl_code_Relational(): + assert glsl_code(Eq(x, y)) == "x == y" + assert glsl_code(Ne(x, y)) == "x != y" + assert glsl_code(Le(x, y)) == "x <= y" + assert glsl_code(Lt(x, y)) == "x < y" + assert glsl_code(Gt(x, y)) == "x > y" + assert glsl_code(Ge(x, y)) == "x >= y" + + +def test_glsl_code_constants_mathh(): + assert glsl_code(exp(1)) == "float E = 2.71828183;\nE" + assert glsl_code(pi) == "float pi = 3.14159265;\npi" + # assert glsl_code(oo) == "Number.POSITIVE_INFINITY" + # assert glsl_code(-oo) == "Number.NEGATIVE_INFINITY" + + +def test_glsl_code_constants_other(): + assert glsl_code(2*GoldenRatio) == "float GoldenRatio = 1.61803399;\n2*GoldenRatio" + assert glsl_code(2*Catalan) == "float Catalan = 0.915965594;\n2*Catalan" + assert glsl_code(2*EulerGamma) == "float EulerGamma = 0.577215665;\n2*EulerGamma" + + +def test_glsl_code_Rational(): + assert glsl_code(Rational(3, 7)) == "3.0/7.0" + assert glsl_code(Rational(18, 9)) == "2" + assert glsl_code(Rational(3, -7)) == "-3.0/7.0" + assert glsl_code(Rational(-3, -7)) == "3.0/7.0" + + +def test_glsl_code_Integer(): + assert glsl_code(Integer(67)) == "67" + assert glsl_code(Integer(-1)) == "-1" + + +def test_glsl_code_functions(): + assert glsl_code(sin(x) ** cos(x)) == "pow(sin(x), cos(x))" + + +def test_glsl_code_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert glsl_code(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert glsl_code(g(x)) == "float Catalan = 0.915965594;\n2*x/Catalan" + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert glsl_code(g(A[i]), assign_to=A[i]) == ( + "for (int i=0; i 1), (sin(x), x > 0)) + raises(ValueError, lambda: glsl_code(expr)) + + +def test_glsl_code_Piecewise_deep(): + p = glsl_code(2*Piecewise((x, x < 1), (x**2, True))) + s = \ +"""\ +2*((x < 1) ? ( + x +) +: ( + pow(x, 2.0) +))\ +""" + assert p == s + + +def test_glsl_code_settings(): + raises(TypeError, lambda: glsl_code(sin(x), method="garbage")) + + +def test_glsl_code_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = GLSLPrinter() + p._not_c = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + assert p._not_c == set() + +def test_glsl_code_list_tuple_Tuple(): + assert glsl_code([1,2,3,4]) == 'vec4(1, 2, 3, 4)' + assert glsl_code([1,2,3],glsl_types=False) == 'float[3](1, 2, 3)' + assert glsl_code([1,2,3]) == glsl_code((1,2,3)) + assert glsl_code([1,2,3]) == glsl_code(Tuple(1,2,3)) + + m = MatrixSymbol('A',3,4) + assert glsl_code([m[0],m[1]]) + +def test_glsl_code_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (int i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert glsl_code(mat, assign_to=A) == ( +'''A[0][0] = x*y; +if (y > 0) { + A[1][0] = x + 2; +} +else { + A[1][0] = y; +} +A[2][0] = sin(z);''' ) + assert glsl_code(Matrix([A[0],A[1]])) + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert glsl_code(expr) == ( +'''((x > 0) ? ( + 2*A[2][0] +) +: ( + A[2][0] +)) + sin(A[1][0]) + A[0][0]''' ) + + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert glsl_code(m,M) == ( +'''M[0][0] = sin(q[1]); +M[0][1] = 0; +M[0][2] = cos(q[2]); +M[1][0] = q[1] + q[2]; +M[1][1] = q[3]; +M[1][2] = 5; +M[2][0] = 2*q[4]/q[1]; +M[2][1] = sqrt(q[0]) + 4; +M[2][2] = 0;''' + ) + +def test_Matrices_1x7(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assert gl(A) == 'float[7](1, 2, 3, 4, 5, 6, 7)' + assert gl(A.transpose()) == 'float[7](1, 2, 3, 4, 5, 6, 7)' + +def test_Matrices_1x7_array_type_int(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assert gl(A, array_type='int') == 'int[7](1, 2, 3, 4, 5, 6, 7)' + +def test_Tuple_array_type_custom(): + gl = glsl_code + A = symbols('a b c') + assert gl(A, array_type='AbcType', glsl_types=False) == 'AbcType[3](a, b, c)' + +def test_Matrices_1x7_spread_assign_to_symbols(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assign_to = symbols('x.a x.b x.c x.d x.e x.f x.g') + assert gl(A, assign_to=assign_to) == textwrap.dedent('''\ + x.a = 1; + x.b = 2; + x.c = 3; + x.d = 4; + x.e = 5; + x.f = 6; + x.g = 7;''' + ) + +def test_spread_assign_to_nested_symbols(): + gl = glsl_code + expr = ((1,2,3), (1,2,3)) + assign_to = (symbols('a b c'), symbols('x y z')) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + x = 1; + y = 2; + z = 3;''' + ) + +def test_spread_assign_to_deeply_nested_symbols(): + gl = glsl_code + a, b, c, x, y, z = symbols('a b c x y z') + expr = (((1,2),3), ((1,2),3)) + assign_to = (((a, b), c), ((x, y), z)) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + x = 1; + y = 2; + z = 3;''' + ) + +def test_matrix_of_tuples_spread_assign_to_symbols(): + gl = glsl_code + with warns_deprecated_sympy(): + expr = Matrix([[(1,2),(3,4)],[(5,6),(7,8)]]) + assign_to = (symbols('a b'), symbols('c d'), symbols('e f'), symbols('g h')) + assert gl(expr, assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + d = 4; + e = 5; + f = 6; + g = 7; + h = 8;''' + ) + +def test_cannot_assign_to_cause_mismatched_length(): + expr = (1, 2) + assign_to = symbols('x y z') + raises(ValueError, lambda: glsl_code(expr, assign_to)) + +def test_matrix_4x4_assign(): + gl = glsl_code + expr = MatrixSymbol('A',4,4) * MatrixSymbol('B',4,4) + MatrixSymbol('C',4,4) + assign_to = MatrixSymbol('X',4,4) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + X[0][0] = A[0][0]*B[0][0] + A[0][1]*B[1][0] + A[0][2]*B[2][0] + A[0][3]*B[3][0] + C[0][0]; + X[0][1] = A[0][0]*B[0][1] + A[0][1]*B[1][1] + A[0][2]*B[2][1] + A[0][3]*B[3][1] + C[0][1]; + X[0][2] = A[0][0]*B[0][2] + A[0][1]*B[1][2] + A[0][2]*B[2][2] + A[0][3]*B[3][2] + C[0][2]; + X[0][3] = A[0][0]*B[0][3] + A[0][1]*B[1][3] + A[0][2]*B[2][3] + A[0][3]*B[3][3] + C[0][3]; + X[1][0] = A[1][0]*B[0][0] + A[1][1]*B[1][0] + A[1][2]*B[2][0] + A[1][3]*B[3][0] + C[1][0]; + X[1][1] = A[1][0]*B[0][1] + A[1][1]*B[1][1] + A[1][2]*B[2][1] + A[1][3]*B[3][1] + C[1][1]; + X[1][2] = A[1][0]*B[0][2] + A[1][1]*B[1][2] + A[1][2]*B[2][2] + A[1][3]*B[3][2] + C[1][2]; + X[1][3] = A[1][0]*B[0][3] + A[1][1]*B[1][3] + A[1][2]*B[2][3] + A[1][3]*B[3][3] + C[1][3]; + X[2][0] = A[2][0]*B[0][0] + A[2][1]*B[1][0] + A[2][2]*B[2][0] + A[2][3]*B[3][0] + C[2][0]; + X[2][1] = A[2][0]*B[0][1] + A[2][1]*B[1][1] + A[2][2]*B[2][1] + A[2][3]*B[3][1] + C[2][1]; + X[2][2] = A[2][0]*B[0][2] + A[2][1]*B[1][2] + A[2][2]*B[2][2] + A[2][3]*B[3][2] + C[2][2]; + X[2][3] = A[2][0]*B[0][3] + A[2][1]*B[1][3] + A[2][2]*B[2][3] + A[2][3]*B[3][3] + C[2][3]; + X[3][0] = A[3][0]*B[0][0] + A[3][1]*B[1][0] + A[3][2]*B[2][0] + A[3][3]*B[3][0] + C[3][0]; + X[3][1] = A[3][0]*B[0][1] + A[3][1]*B[1][1] + A[3][2]*B[2][1] + A[3][3]*B[3][1] + C[3][1]; + X[3][2] = A[3][0]*B[0][2] + A[3][1]*B[1][2] + A[3][2]*B[2][2] + A[3][3]*B[3][2] + C[3][2]; + X[3][3] = A[3][0]*B[0][3] + A[3][1]*B[1][3] + A[3][2]*B[2][3] + A[3][3]*B[3][3] + C[3][3];''' + ) + +def test_1xN_vecs(): + gl = glsl_code + for i in range(1,10): + A = Matrix(range(i)) + assert gl(A.transpose()) == gl(A) + assert gl(A,mat_transpose=True) == gl(A) + if i > 1: + if i <= 4: + assert gl(A) == 'vec%s(%s)' % (i,', '.join(str(s) for s in range(i))) + else: + assert gl(A) == 'float[%s](%s)' % (i,', '.join(str(s) for s in range(i))) + +def test_MxN_mats(): + generatedAssertions='def test_misc_mats():\n' + for i in range(1,6): + for j in range(1,6): + A = Matrix([[x + y*j for x in range(j)] for y in range(i)]) + gl = glsl_code(A) + glTransposed = glsl_code(A,mat_transpose=True) + generatedAssertions+=' mat = '+StrPrinter()._print(A)+'\n\n' + generatedAssertions+=' gl = \'\'\''+gl+'\'\'\'\n' + generatedAssertions+=' glTransposed = \'\'\''+glTransposed+'\'\'\'\n\n' + generatedAssertions+=' assert glsl_code(mat) == gl\n' + generatedAssertions+=' assert glsl_code(mat,mat_transpose=True) == glTransposed\n' + if i == 1 and j == 1: + assert gl == '0' + elif i <= 4 and j <= 4 and i>1 and j>1: + assert gl.startswith('mat%s' % j) + assert glTransposed.startswith('mat%s' % i) + elif i == 1 and j <= 4: + assert gl.startswith('vec') + elif j == 1 and i <= 4: + assert gl.startswith('vec') + elif i == 1: + assert gl.startswith('float[%s]('% j*i) + assert glTransposed.startswith('float[%s]('% j*i) + elif j == 1: + assert gl.startswith('float[%s]('% i*j) + assert glTransposed.startswith('float[%s]('% i*j) + else: + assert gl.startswith('float[%s](' % (i*j)) + assert glTransposed.startswith('float[%s](' % (i*j)) + glNested = glsl_code(A,mat_nested=True) + glNestedTransposed = glsl_code(A,mat_transpose=True,mat_nested=True) + assert glNested.startswith('float[%s][%s]' % (i,j)) + assert glNestedTransposed.startswith('float[%s][%s]' % (j,i)) + generatedAssertions+=' glNested = \'\'\''+glNested+'\'\'\'\n' + generatedAssertions+=' glNestedTransposed = \'\'\''+glNestedTransposed+'\'\'\'\n\n' + generatedAssertions+=' assert glsl_code(mat,mat_nested=True) == glNested\n' + generatedAssertions+=' assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed\n\n' + generateAssertions = False # set this to true to write bake these generated tests to a file + if generateAssertions: + gen = open('test_glsl_generated_matrices.py','w') + gen.write(generatedAssertions) + gen.close() + + +# these assertions were generated from the previous function +# glsl has complicated rules and this makes it easier to look over all the cases +def test_misc_mats(): + + mat = Matrix([[0]]) + + gl = '''0''' + glTransposed = '''0''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1]]) + + gl = '''vec2(0, 1)''' + glTransposed = '''vec2(0, 1)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2]]) + + gl = '''vec3(0, 1, 2)''' + glTransposed = '''vec3(0, 1, 2)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2, 3]]) + + gl = '''vec4(0, 1, 2, 3)''' + glTransposed = '''vec4(0, 1, 2, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2, 3, 4]]) + + gl = '''float[5](0, 1, 2, 3, 4)''' + glTransposed = '''float[5](0, 1, 2, 3, 4)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0], +[1]]) + + gl = '''vec2(0, 1)''' + glTransposed = '''vec2(0, 1)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3]]) + + gl = '''mat2(0, 1, 2, 3)''' + glTransposed = '''mat2(0, 2, 1, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5]]) + + gl = '''mat3x2(0, 1, 2, 3, 4, 5)''' + glTransposed = '''mat2x3(0, 3, 1, 4, 2, 5)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7]]) + + gl = '''mat4x2(0, 1, 2, 3, 4, 5, 6, 7)''' + glTransposed = '''mat2x4(0, 4, 1, 5, 2, 6, 3, 7)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3, 4], +[5, 6, 7, 8, 9]]) + + gl = '''float[10]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9 +) /* a 2x5 matrix */''' + glTransposed = '''float[10]( + 0, 5, + 1, 6, + 2, 7, + 3, 8, + 4, 9 +) /* a 5x2 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[2][5]( + float[](0, 1, 2, 3, 4), + float[](5, 6, 7, 8, 9) +)''' + glNestedTransposed = '''float[5][2]( + float[](0, 5), + float[](1, 6), + float[](2, 7), + float[](3, 8), + float[](4, 9) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2]]) + + gl = '''vec3(0, 1, 2)''' + glTransposed = '''vec3(0, 1, 2)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5]]) + + gl = '''mat2x3(0, 1, 2, 3, 4, 5)''' + glTransposed = '''mat3x2(0, 2, 4, 1, 3, 5)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5], +[6, 7, 8]]) + + gl = '''mat3(0, 1, 2, 3, 4, 5, 6, 7, 8)''' + glTransposed = '''mat3(0, 3, 6, 1, 4, 7, 2, 5, 8)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7], +[8, 9, 10, 11]]) + + gl = '''mat4x3(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)''' + glTransposed = '''mat3x4(0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14]]) + + gl = '''float[15]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14 +) /* a 3x5 matrix */''' + glTransposed = '''float[15]( + 0, 5, 10, + 1, 6, 11, + 2, 7, 12, + 3, 8, 13, + 4, 9, 14 +) /* a 5x3 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[3][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14) +)''' + glNestedTransposed = '''float[5][3]( + float[](0, 5, 10), + float[](1, 6, 11), + float[](2, 7, 12), + float[](3, 8, 13), + float[](4, 9, 14) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2], +[3]]) + + gl = '''vec4(0, 1, 2, 3)''' + glTransposed = '''vec4(0, 1, 2, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5], +[6, 7]]) + + gl = '''mat2x4(0, 1, 2, 3, 4, 5, 6, 7)''' + glTransposed = '''mat4x2(0, 2, 4, 6, 1, 3, 5, 7)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5], +[6, 7, 8], +[9, 10, 11]]) + + gl = '''mat3x4(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)''' + glTransposed = '''mat4x3(0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3], +[ 4, 5, 6, 7], +[ 8, 9, 10, 11], +[12, 13, 14, 15]]) + + gl = '''mat4( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)''' + glTransposed = '''mat4(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14], +[15, 16, 17, 18, 19]]) + + gl = '''float[20]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19 +) /* a 4x5 matrix */''' + glTransposed = '''float[20]( + 0, 5, 10, 15, + 1, 6, 11, 16, + 2, 7, 12, 17, + 3, 8, 13, 18, + 4, 9, 14, 19 +) /* a 5x4 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[4][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14), + float[](15, 16, 17, 18, 19) +)''' + glNestedTransposed = '''float[5][4]( + float[](0, 5, 10, 15), + float[](1, 6, 11, 16), + float[](2, 7, 12, 17), + float[](3, 8, 13, 18), + float[](4, 9, 14, 19) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2], +[3], +[4]]) + + gl = '''float[5](0, 1, 2, 3, 4)''' + glTransposed = '''float[5](0, 1, 2, 3, 4)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5], +[6, 7], +[8, 9]]) + + gl = '''float[10]( + 0, 1, + 2, 3, + 4, 5, + 6, 7, + 8, 9 +) /* a 5x2 matrix */''' + glTransposed = '''float[10]( + 0, 2, 4, 6, 8, + 1, 3, 5, 7, 9 +) /* a 2x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][2]( + float[](0, 1), + float[](2, 3), + float[](4, 5), + float[](6, 7), + float[](8, 9) +)''' + glNestedTransposed = '''float[2][5]( + float[](0, 2, 4, 6, 8), + float[](1, 3, 5, 7, 9) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2], +[ 3, 4, 5], +[ 6, 7, 8], +[ 9, 10, 11], +[12, 13, 14]]) + + gl = '''float[15]( + 0, 1, 2, + 3, 4, 5, + 6, 7, 8, + 9, 10, 11, + 12, 13, 14 +) /* a 5x3 matrix */''' + glTransposed = '''float[15]( + 0, 3, 6, 9, 12, + 1, 4, 7, 10, 13, + 2, 5, 8, 11, 14 +) /* a 3x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][3]( + float[]( 0, 1, 2), + float[]( 3, 4, 5), + float[]( 6, 7, 8), + float[]( 9, 10, 11), + float[](12, 13, 14) +)''' + glNestedTransposed = '''float[3][5]( + float[](0, 3, 6, 9, 12), + float[](1, 4, 7, 10, 13), + float[](2, 5, 8, 11, 14) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2, 3], +[ 4, 5, 6, 7], +[ 8, 9, 10, 11], +[12, 13, 14, 15], +[16, 17, 18, 19]]) + + gl = '''float[20]( + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19 +) /* a 5x4 matrix */''' + glTransposed = '''float[20]( + 0, 4, 8, 12, 16, + 1, 5, 9, 13, 17, + 2, 6, 10, 14, 18, + 3, 7, 11, 15, 19 +) /* a 4x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][4]( + float[]( 0, 1, 2, 3), + float[]( 4, 5, 6, 7), + float[]( 8, 9, 10, 11), + float[](12, 13, 14, 15), + float[](16, 17, 18, 19) +)''' + glNestedTransposed = '''float[4][5]( + float[](0, 4, 8, 12, 16), + float[](1, 5, 9, 13, 17), + float[](2, 6, 10, 14, 18), + float[](3, 7, 11, 15, 19) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14], +[15, 16, 17, 18, 19], +[20, 21, 22, 23, 24]]) + + gl = '''float[25]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24 +) /* a 5x5 matrix */''' + glTransposed = '''float[25]( + 0, 5, 10, 15, 20, + 1, 6, 11, 16, 21, + 2, 7, 12, 17, 22, + 3, 8, 13, 18, 23, + 4, 9, 14, 19, 24 +) /* a 5x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14), + float[](15, 16, 17, 18, 19), + float[](20, 21, 22, 23, 24) +)''' + glNestedTransposed = '''float[5][5]( + float[](0, 5, 10, 15, 20), + float[](1, 6, 11, 16, 21), + float[](2, 7, 12, 17, 22), + float[](3, 8, 13, 18, 23), + float[](4, 9, 14, 19, 24) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_gtk.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_gtk.py new file mode 100644 index 0000000000000000000000000000000000000000..5a595ab04d3a29d23e06ec12207bf917392aebce --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_gtk.py @@ -0,0 +1,18 @@ +from sympy.functions.elementary.trigonometric import sin +from sympy.printing.gtk import print_gtk +from sympy.testing.pytest import XFAIL, raises + +# this test fails if python-lxml isn't installed. We don't want to depend on +# anything with SymPy + + +@XFAIL +def test_1(): + from sympy.abc import x + print_gtk(x**2, start_viewer=False) + print_gtk(x**2 + sin(x)/4, start_viewer=False) + + +def test_settings(): + from sympy.abc import x + raises(TypeError, lambda: print_gtk(x, method="garbage")) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_jax.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_jax.py new file mode 100644 index 0000000000000000000000000000000000000000..365d87c5b91fdd49a8e46cfde9c2b5792c23a03c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_jax.py @@ -0,0 +1,370 @@ +from sympy.concrete.summations import Sum +from sympy.core.mod import Mod +from sympy.core.relational import (Equality, Unequality) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.utilities.lambdify import lambdify + +from sympy.abc import x, i, j, a, b, c, d +from sympy.core import Function, Pow, Symbol +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt +from sympy.tensor.array import Array +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.printing.numpy import JaxPrinter, _jax_known_constants, _jax_known_functions +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +# Unlike NumPy which will aggressively promote operands to double precision, +# jax always uses single precision. Double precision in jax can be +# configured before the call to `import jax`, however this must be explicitly +# configured and is not fully supported. Thus, the tests here have been modified +# from the tests in test_numpy.py, only in the fact that they assert lambdify +# function accuracy to only single precision accuracy. +# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + +jax = import_module('jax') + +if jax: + deafult_float_info = jax.numpy.finfo(jax.numpy.array([]).dtype) + JAX_DEFAULT_EPSILON = deafult_float_info.eps + + +def test_jax_piecewise_regression(): + """ + NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid + breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+. + See gh-9747 and gh-9749 for details. + """ + printer = JaxPrinter() + p = Piecewise((1, x < 0), (0, True)) + assert printer.doprint(p) == \ + 'jax.numpy.select([jax.numpy.less(x, 0),True], [1,0], default=jax.numpy.nan)' + assert printer.module_imports == {'jax.numpy': {'select', 'less', 'nan'}} + + +def test_jax_logaddexp(): + lae = logaddexp(a, b) + assert JaxPrinter().doprint(lae) == 'jax.numpy.logaddexp(a, b)' + lae2 = logaddexp2(a, b) + assert JaxPrinter().doprint(lae2) == 'jax.numpy.logaddexp2(a, b)' + + +def test_jax_sum(): + if not jax: + skip("JAX not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'jax') + + a_, b_ = 0, 10 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'jax') + + a_, b_ = 0, 10 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + + +def test_jax_multiple_sums(): + if not jax: + skip("JAX not installed") + + s = Sum((x + j) * i, (i, a, b), (j, c, d)) + f = lambdify((a, b, c, d, x), s, 'jax') + + a_, b_ = 0, 10 + c_, d_ = 11, 21 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, c_, d_, x_), + sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1))) + + +def test_jax_codegen_einsum(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'jax') + + ma = jax.numpy.array([[1, 2], [3, 4]]) + mb = jax.numpy.array([[1,-2], [-1, 3]]) + assert (f(ma, mb) == jax.numpy.matmul(ma, mb)).all() + + +def test_jax_codegen_extra(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = jax.numpy.array([[1, 2], [3, 4]]) + mb = jax.numpy.array([[1,-2], [-1, 3]]) + mc = jax.numpy.array([[2, 0], [1, 2]]) + md = jax.numpy.array([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.einsum(ma, [0, 1], mb, [2, 3])).all() + + cg = ArrayAdd(M, N) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == ma+mb).all() + + cg = ArrayAdd(M, N, P) + f = lambdify((M, N, P), cg, 'jax') + assert (f(ma, mb, mc) == ma+mb+mc).all() + + cg = ArrayAdd(M, N, P, Q) + f = lambdify((M, N, P, Q), cg, 'jax') + assert (f(ma, mb, mc, md) == ma+mb+mc+md).all() + + cg = PermuteDims(M, [1, 0]) + f = lambdify((M,), cg, 'jax') + assert (f(ma) == ma.T).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.transpose(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.diagonal(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all() + + +def test_jax_relational(): + if not jax: + skip("JAX not installed") + + e = Equality(x, 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, False]) + + e = Unequality(x, 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, False, True]) + + e = (x < 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, False, False]) + + e = (x <= 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, True, False]) + + e = (x > 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, False, True]) + + e = (x >= 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, True]) + + # Multi-condition expressions + e = (x >= 1) & (x < 2) + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, False]) + + e = (x >= 1) | (x < 2) + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, True, True]) + +def test_jax_mod(): + if not jax: + skip("JAX not installed") + + e = Mod(a, b) + f = lambdify((a, b), e, 'jax') + + a_ = jax.numpy.array([0, 1, 2, 3]) + b_ = 2 + assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = jax.numpy.array([0, 1, 2, 3]) + b_ = jax.numpy.array([2, 2, 2, 2]) + assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = jax.numpy.array([2, 3, 4, 5]) + b_ = jax.numpy.array([2, 3, 4, 5]) + assert jax.numpy.array_equal(f(a_, b_), [0, 0, 0, 0]) + + +def test_jax_pow(): + if not jax: + skip('JAX not installed') + + expr = Pow(2, -1, evaluate=False) + f = lambdify([], expr, 'jax') + assert f() == 0.5 + + +def test_jax_expm1(): + if not jax: + skip("JAX not installed") + + f = lambdify((a,), expm1(a), 'jax') + assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * JAX_DEFAULT_EPSILON + + +def test_jax_log1p(): + if not jax: + skip("JAX not installed") + + f = lambdify((a,), log1p(a), 'jax') + assert abs(f(1e-99) - 1e-99) <= 1e-99 * JAX_DEFAULT_EPSILON + +def test_jax_hypot(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a, b), hypot(a, b), 'jax')(3, 4) - 5) <= JAX_DEFAULT_EPSILON + +def test_jax_log10(): + if not jax: + skip("JAX not installed") + + assert abs(lambdify((a,), log10(a), 'jax')(100) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_exp2(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), exp2(a), 'jax')(5) - 32) <= JAX_DEFAULT_EPSILON + + +def test_jax_log2(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), log2(a), 'jax')(256) - 8) <= JAX_DEFAULT_EPSILON + + +def test_jax_Sqrt(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), Sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_sqrt(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_matsolve(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 3, 3) + x = MatrixSymbol("x", 3, 1) + + expr = M**(-1) * x + x + matsolve_expr = MatrixSolve(M, x) + x + + f = lambdify((M, x), expr, 'jax') + f_matsolve = lambdify((M, x), matsolve_expr, 'jax') + + m0 = jax.numpy.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]]) + assert jax.numpy.linalg.matrix_rank(m0) == 3 + + x0 = jax.numpy.array([3, 4, 5]) + + assert jax.numpy.allclose(f_matsolve(m0, x0), f(m0, x0)) + + +def test_16857(): + if not jax: + skip("JAX not installed") + + a_1 = MatrixSymbol('a_1', 10, 3) + a_2 = MatrixSymbol('a_2', 10, 3) + a_3 = MatrixSymbol('a_3', 10, 3) + a_4 = MatrixSymbol('a_4', 10, 3) + A = BlockMatrix([[a_1, a_2], [a_3, a_4]]) + assert A.shape == (20, 6) + + printer = JaxPrinter() + assert printer.doprint(A) == 'jax.numpy.block([[a_1, a_2], [a_3, a_4]])' + + +def test_issue_17006(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + + f = lambdify(M, M + Identity(2), 'jax') + ma = jax.numpy.array([[1, 2], [3, 4]]) + mr = jax.numpy.array([[2, 2], [3, 5]]) + + assert (f(ma) == mr).all() + + from sympy.core.symbol import symbols + n = symbols('n', integer=True) + N = MatrixSymbol("M", n, n) + raises(NotImplementedError, lambda: lambdify(N, N + Identity(n), 'jax')) + + +def test_jax_array(): + assert JaxPrinter().doprint(Array(((1, 2), (3, 5)))) == 'jax.numpy.array([[1, 2], [3, 5]])' + assert JaxPrinter().doprint(Array((1, 2))) == 'jax.numpy.array([1, 2])' + + +def test_jax_known_funcs_consts(): + assert _jax_known_constants['NaN'] == 'jax.numpy.nan' + assert _jax_known_constants['EulerGamma'] == 'jax.numpy.euler_gamma' + + assert _jax_known_functions['acos'] == 'jax.numpy.arccos' + assert _jax_known_functions['log'] == 'jax.numpy.log' + + +def test_jax_print_methods(): + prntr = JaxPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + + +def test_jax_printmethod(): + printer = JaxPrinter() + assert hasattr(printer, 'printmethod') + assert printer.printmethod == '_jaxcode' + + +def test_jax_custom_print_method(): + + class expm1(Function): + + def _jaxcode(self, printer): + x, = self.args + function = f'expm1({printer._print(x)})' + return printer._module_format(printer._module + '.' + function) + + printer = JaxPrinter() + assert printer.doprint(expm1(Symbol('x'))) == 'jax.numpy.expm1(x)' diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_jscode.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_jscode.py new file mode 100644 index 0000000000000000000000000000000000000000..9199a8e0d62e87f2e964cb1712726a21c894fd20 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_jscode.py @@ -0,0 +1,396 @@ +from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio, + EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le, + Lt, Gt, Ge, Mod) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + sinh, cosh, tanh, asin, acos, acosh, Max, Min) +from sympy.testing.pytest import raises +from sympy.printing.jscode import JavascriptCodePrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol + +from sympy.printing.jscode import jscode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + assert jscode(Abs(x)) == "Math.abs(x)" + + +def test_jscode_sqrt(): + assert jscode(sqrt(x)) == "Math.sqrt(x)" + assert jscode(x**0.5) == "Math.sqrt(x)" + assert jscode(x**(S.One/3)) == "Math.cbrt(x)" + + +def test_jscode_Pow(): + g = implemented_function('g', Lambda(x, 2*x)) + assert jscode(x**3) == "Math.pow(x, 3)" + assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))" + assert jscode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)" + assert jscode(x**-1.0) == '1/x' + + +def test_jscode_constants_mathh(): + assert jscode(exp(1)) == "Math.E" + assert jscode(pi) == "Math.PI" + assert jscode(oo) == "Number.POSITIVE_INFINITY" + assert jscode(-oo) == "Number.NEGATIVE_INFINITY" + + +def test_jscode_constants_other(): + assert jscode( + 2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17) + assert jscode( + 2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_jscode_Rational(): + assert jscode(Rational(3, 7)) == "3/7" + assert jscode(Rational(18, 9)) == "2" + assert jscode(Rational(3, -7)) == "-3/7" + assert jscode(Rational(-3, -7)) == "3/7" + + +def test_Relational(): + assert jscode(Eq(x, y)) == "x == y" + assert jscode(Ne(x, y)) == "x != y" + assert jscode(Le(x, y)) == "x <= y" + assert jscode(Lt(x, y)) == "x < y" + assert jscode(Gt(x, y)) == "x > y" + assert jscode(Ge(x, y)) == "x >= y" + + +def test_Mod(): + assert jscode(Mod(x, y)) == '((x % y) + y) % y' + assert jscode(Mod(x, x + y)) == '((x % (x + y)) + (x + y)) % (x + y)' + p1, p2 = symbols('p1 p2', positive=True) + assert jscode(Mod(p1, p2)) == 'p1 % p2' + assert jscode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)' + assert jscode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)' + assert jscode(-Mod(p1, p2)) == '-(p1 % p2)' + assert jscode(x*Mod(p1, p2)) == 'x*(p1 % p2)' + + +def test_jscode_Integer(): + assert jscode(Integer(67)) == "67" + assert jscode(Integer(-1)) == "-1" + + +def test_jscode_functions(): + assert jscode(sin(x) ** cos(x)) == "Math.pow(Math.sin(x), Math.cos(x))" + assert jscode(sinh(x) * cosh(x)) == "Math.sinh(x)*Math.cosh(x)" + assert jscode(Max(x, y) + Min(x, y)) == "Math.max(x, y) + Math.min(x, y)" + assert jscode(tanh(x)*acosh(y)) == "Math.tanh(x)*Math.acosh(y)" + assert jscode(asin(x)-acos(y)) == "-Math.acos(y) + Math.asin(x)" + + +def test_jscode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert jscode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert jscode(g(x)) == "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert jscode(g(A[i]), assign_to=A[i]) == ( + "for (var i=0; i 1), (sin(x), x > 0)) + raises(ValueError, lambda: jscode(expr)) + + +def test_jscode_Piecewise_deep(): + p = jscode(2*Piecewise((x, x < 1), (x**2, True))) + s = \ +"""\ +2*((x < 1) ? ( + x +) +: ( + Math.pow(x, 2) +))\ +""" + assert p == s + + +def test_jscode_settings(): + raises(TypeError, lambda: jscode(sin(x), method="garbage")) + + +def test_jscode_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = JavascriptCodePrinter() + p._not_c = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + assert p._not_c == set() + + +def test_jscode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (var i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert jscode(mat, A) == ( + "A[0] = x*y;\n" + "if (y > 0) {\n" + " A[1] = x + 2;\n" + "}\n" + "else {\n" + " A[1] = y;\n" + "}\n" + "A[2] = Math.sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert jscode(expr) == ( + "((x > 0) ? (\n" + " 2*A[2]\n" + ")\n" + ": (\n" + " A[2]\n" + ")) + Math.sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert jscode(m, M) == ( + "M[0] = Math.sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = Math.cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = Math.sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(jscode(A[0, 0]) == "A[0]") + assert(jscode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(jscode(F) == "(A - B)[0]") diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_julia.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_julia.py new file mode 100644 index 0000000000000000000000000000000000000000..b19c7b4fd4f21d8402ca2f577605322b3ec10f5b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_julia.py @@ -0,0 +1,390 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow +from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix) +from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli, + besselk, hankel1, hankel2, airyai, + airybi, airyaiprime, airybiprime) +from sympy.testing.pytest import XFAIL + +from sympy.printing.julia import julia_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert julia_code(Integer(67)) == "67" + assert julia_code(Integer(-1)) == "-1" + + +def test_Rational(): + assert julia_code(Rational(3, 7)) == "3 // 7" + assert julia_code(Rational(18, 9)) == "2" + assert julia_code(Rational(3, -7)) == "-3 // 7" + assert julia_code(Rational(-3, -7)) == "3 // 7" + assert julia_code(x + Rational(3, 7)) == "x + 3 // 7" + assert julia_code(Rational(3, 7)*x) == "(3 // 7) * x" + + +def test_Relational(): + assert julia_code(Eq(x, y)) == "x == y" + assert julia_code(Ne(x, y)) == "x != y" + assert julia_code(Le(x, y)) == "x <= y" + assert julia_code(Lt(x, y)) == "x < y" + assert julia_code(Gt(x, y)) == "x > y" + assert julia_code(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert julia_code(sin(x) ** cos(x)) == "sin(x) .^ cos(x)" + assert julia_code(abs(x)) == "abs(x)" + assert julia_code(ceiling(x)) == "ceil(x)" + + +def test_Pow(): + assert julia_code(x**3) == "x .^ 3" + assert julia_code(x**(y**3)) == "x .^ (y .^ 3)" + assert julia_code(x**Rational(2, 3)) == 'x .^ (2 // 3)' + g = implemented_function('g', Lambda(x, 2*x)) + assert julia_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5 * 2 * x) .^ (-x + y .^ x) ./ (x .^ 2 + y)" + # For issue 14160 + assert julia_code(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2 * x ./ (y .* y)' + + +def test_basic_ops(): + assert julia_code(x*y) == "x .* y" + assert julia_code(x + y) == "x + y" + assert julia_code(x - y) == "x - y" + assert julia_code(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert julia_code(1/x) == '1 ./ x' + assert julia_code(x**-1) == julia_code(x**-1.0) == '1 ./ x' + assert julia_code(1/sqrt(x)) == '1 ./ sqrt(x)' + assert julia_code(x**-S.Half) == julia_code(x**-0.5) == '1 ./ sqrt(x)' + assert julia_code(sqrt(x)) == 'sqrt(x)' + assert julia_code(x**S.Half) == julia_code(x**0.5) == 'sqrt(x)' + assert julia_code(1/pi) == '1 / pi' + assert julia_code(pi**-1) == julia_code(pi**-1.0) == '1 / pi' + assert julia_code(pi**-0.5) == '1 / sqrt(pi)' + + +def test_mix_number_mult_symbols(): + assert julia_code(3*x) == "3 * x" + assert julia_code(pi*x) == "pi * x" + assert julia_code(3/x) == "3 ./ x" + assert julia_code(pi/x) == "pi ./ x" + assert julia_code(x/3) == "x / 3" + assert julia_code(x/pi) == "x / pi" + assert julia_code(x*y) == "x .* y" + assert julia_code(3*x*y) == "3 * x .* y" + assert julia_code(3*pi*x*y) == "3 * pi * x .* y" + assert julia_code(x/y) == "x ./ y" + assert julia_code(3*x/y) == "3 * x ./ y" + assert julia_code(x*y/z) == "x .* y ./ z" + assert julia_code(x/y*z) == "x .* z ./ y" + assert julia_code(1/x/y) == "1 ./ (x .* y)" + assert julia_code(2*pi*x/y/z) == "2 * pi * x ./ (y .* z)" + assert julia_code(3*pi/x) == "3 * pi ./ x" + assert julia_code(S(3)/5) == "3 // 5" + assert julia_code(S(3)/5*x) == "(3 // 5) * x" + assert julia_code(x/y/z) == "x ./ (y .* z)" + assert julia_code((x+y)/z) == "(x + y) ./ z" + assert julia_code((x+y)/(z+x)) == "(x + y) ./ (x + z)" + assert julia_code((x+y)/EulerGamma) == "(x + y) / eulergamma" + assert julia_code(x/3/pi) == "x / (3 * pi)" + assert julia_code(S(3)/5*x*y/pi) == "(3 // 5) * x .* y / pi" + + +def test_mix_number_pow_symbols(): + assert julia_code(pi**3) == 'pi ^ 3' + assert julia_code(x**2) == 'x .^ 2' + assert julia_code(x**(pi**3)) == 'x .^ (pi ^ 3)' + assert julia_code(x**y) == 'x .^ y' + assert julia_code(x**(y**z)) == 'x .^ (y .^ z)' + assert julia_code((x**y)**z) == '(x .^ y) .^ z' + + +def test_imag(): + I = S('I') + assert julia_code(I) == "im" + assert julia_code(5*I) == "5im" + assert julia_code((S(3)/2)*I) == "(3 // 2) * im" + assert julia_code(3+4*I) == "3 + 4im" + + +def test_constants(): + assert julia_code(pi) == "pi" + assert julia_code(oo) == "Inf" + assert julia_code(-oo) == "-Inf" + assert julia_code(S.NegativeInfinity) == "-Inf" + assert julia_code(S.NaN) == "NaN" + assert julia_code(S.Exp1) == "e" + assert julia_code(exp(1)) == "e" + + +def test_constants_other(): + assert julia_code(2*GoldenRatio) == "2 * golden" + assert julia_code(2*Catalan) == "2 * catalan" + assert julia_code(2*EulerGamma) == "2 * eulergamma" + + +def test_boolean(): + assert julia_code(x & y) == "x && y" + assert julia_code(x | y) == "x || y" + assert julia_code(~x) == "!x" + assert julia_code(x & y & z) == "x && y && z" + assert julia_code(x | y | z) == "x || y || z" + assert julia_code((x & y) | z) == "z || x && y" + assert julia_code((x | y) & z) == "z && (x || y)" + +def test_sinc(): + assert julia_code(sinc(x)) == 'sinc(x / pi)' + assert julia_code(sinc(x + 3)) == 'sinc((x + 3) / pi)' + assert julia_code(sinc(pi * (x + 3))) == 'sinc(x + 3)' + +def test_Matrices(): + assert julia_code(Matrix(1, 1, [10])) == "[10]" + A = Matrix([[1, sin(x/2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = ("[1 sin(x / 2) abs(x);\n" + "0 1 pi;\n" + "0 e ceil(x)]") + assert julia_code(A) == expected + # row and columns + assert julia_code(A[:,0]) == "[1, 0, 0]" + assert julia_code(A[0,:]) == "[1 sin(x / 2) abs(x)]" + # empty matrices + assert julia_code(Matrix(0, 0, [])) == 'zeros(0, 0)' + assert julia_code(Matrix(0, 3, [])) == 'zeros(0, 3)' + # annoying to read but correct + assert julia_code(Matrix([[x, x - y, -y]])) == "[x x - y -y]" + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2/x), 3*pi/x/5]]) + assert julia_code(A) == "[1 sin(2 ./ x) (3 // 5) * pi ./ x]" + assert julia_code(A.T) == "[1, sin(2 ./ x), (3 // 5) * pi ./ x]" + + +@XFAIL +def test_Matrices_entries_not_hadamard(): + # For Matrix with col >= 2, row >= 2, they need to be scalars + # FIXME: is it worth worrying about this? Its not wrong, just + # leave it user's responsibility to put scalar data for x. + A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]]) + expected = ("[1 sin(2/x) 3*pi/(5*x);\n" + "1 2 x*y]") # <- we give x.*y + assert julia_code(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert julia_code(A*B) == "A * B" + assert julia_code(B*A) == "B * A" + assert julia_code(2*A*B) == "2 * A * B" + assert julia_code(B*2*A) == "2 * B * A" + assert julia_code(A*(B + 3*Identity(n))) == "A * (3 * eye(n) + B)" + assert julia_code(A**(x**2)) == "A ^ (x .^ 2)" + assert julia_code(A**3) == "A ^ 3" + assert julia_code(A**S.Half) == "A ^ (1 // 2)" + + +def test_special_matrices(): + assert julia_code(6*Identity(3)) == "6 * eye(3)" + + +def test_containers(): + assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]" + assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))" + assert julia_code([1]) == "Any[1]" + assert julia_code((1,)) == "(1,)" + assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)" + assert julia_code((1, x*y, (3, x**2))) == "(1, x .* y, (3, x .^ 2))" + # scalar, matrix, empty matrix and empty list + assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])" + + +def test_julia_noninline(): + source = julia_code((x+y)/Catalan, assign_to='me', inline=False) + expected = ( + "const Catalan = %s\n" + "me = (x + y) / Catalan" + ) % Catalan.evalf(17) + assert source == expected + + +def test_julia_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert julia_code(expr) == "((x < 1) ? (x) : (x .^ 2))" + assert julia_code(expr, assign_to="r") == ( + "r = ((x < 1) ? (x) : (x .^ 2))") + assert julia_code(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x\n" + "else\n" + " r = x .^ 2\n" + "end") + expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True)) + expected = ("((x < 1) ? (x .^ 2) :\n" + "(x < 2) ? (x .^ 3) :\n" + "(x < 3) ? (x .^ 4) : (x .^ 5))") + assert julia_code(expr) == expected + assert julia_code(expr, assign_to="r") == "r = " + expected + assert julia_code(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x .^ 2\n" + "elseif (x < 2)\n" + " r = x .^ 3\n" + "elseif (x < 3)\n" + " r = x .^ 4\n" + "else\n" + " r = x .^ 5\n" + "end") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: julia_code(expr)) + + +def test_julia_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x**2, True)) + assert julia_code(2*pw) == "2 * ((x < 1) ? (x) : (x .^ 2))" + assert julia_code(pw/x) == "((x < 1) ? (x) : (x .^ 2)) ./ x" + assert julia_code(pw/(x*y)) == "((x < 1) ? (x) : (x .^ 2)) ./ (x .* y)" + assert julia_code(pw/3) == "((x < 1) ? (x) : (x .^ 2)) / 3" + + +def test_julia_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert julia_code(A, assign_to='a') == "a = [1 2 3]" + A = Matrix([[1, 2], [3, 4]]) + assert julia_code(A, assign_to='A') == "A = [1 2;\n3 4]" + + +def test_julia_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert julia_code(A, assign_to=B) == "B = [1 2 3]" + raises(ValueError, lambda: julia_code(A, assign_to=x)) + raises(ValueError, lambda: julia_code(A, assign_to=C)) + + +def test_julia_matrix_1x1(): + A = Matrix([[3]]) + B = MatrixSymbol('B', 1, 1) + C = MatrixSymbol('C', 1, 2) + assert julia_code(A, assign_to=B) == "B = [3]" + # FIXME? + #assert julia_code(A, assign_to=x) == "x = [3]" + raises(ValueError, lambda: julia_code(A, assign_to=C)) + + +def test_julia_matrix_elements(): + A = Matrix([[x, 2, x*y]]) + assert julia_code(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2" + A = MatrixSymbol('AA', 1, 3) + assert julia_code(A) == "AA" + assert julia_code(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \ + "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]" + assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]" + + +def test_julia_boolean(): + assert julia_code(True) == "true" + assert julia_code(S.true) == "true" + assert julia_code(False) == "false" + assert julia_code(S.false) == "false" + + +def test_julia_not_supported(): + with raises(NotImplementedError): + julia_code(S.ComplexInfinity) + + f = Function('f') + assert julia_code(f(x).diff(x), strict=False) == ( + "# Not supported in Julia:\n" + "# Derivative\n" + "Derivative(f(x), x)" + ) + + +def test_trick_indent_with_end_else_words(): + # words starting with "end" or "else" do not confuse the indenter + t1 = S('endless') + t2 = S('elsewhere') + pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True)) + assert julia_code(pw, inline=False) == ( + "if (x < 0)\n" + " endless\n" + "elseif (x <= 1)\n" + " elsewhere\n" + "else\n" + " 1\n" + "end") + + +def test_haramard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + assert julia_code(C) == "A .* B" + assert julia_code(C*v) == "(A .* B) * v" + assert julia_code(h*C*v) == "h * (A .* B) * v" + assert julia_code(C*A) == "(A .* B) * A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + assert julia_code(C*x*y) == "(x .* y) * (A .* B)" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x*y + assert julia_code(M) == ( + "sparse([4, 2, 3, 1, 2], [1, 3, 3, 4, 4], [x .* y, 20, 10, 30, 22], 5, 6)" + ) + + +def test_specfun(): + n = Symbol('n') + for f in [besselj, bessely, besseli, besselk]: + assert julia_code(f(n, x)) == f.__name__ + '(n, x)' + for f in [airyai, airyaiprime, airybi, airybiprime]: + assert julia_code(f(x)) == f.__name__ + '(x)' + assert julia_code(hankel1(n, x)) == 'hankelh1(n, x)' + assert julia_code(hankel2(n, x)) == 'hankelh2(n, x)' + assert julia_code(jn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* besselj(n + 1 // 2, x) / 2' + assert julia_code(yn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* bessely(n + 1 // 2, x) / 2' + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(julia_code(A[0, 0]) == "A[1,1]") + assert(julia_code(3 * A[0, 0]) == "3 * A[1,1]") + + F = C[0, 0].subs(C, A - B) + assert(julia_code(F) == "(A - B)[1,1]") diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_lambdarepr.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_lambdarepr.py new file mode 100644 index 0000000000000000000000000000000000000000..94e09ada7a9ce7d01667edd8fc6ec35ebfbb9639 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_lambdarepr.py @@ -0,0 +1,246 @@ +from sympy.concrete.summations import Sum +from sympy.core.expr import Expr +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.sets.sets import Interval +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import raises + +from sympy.printing.tensorflow import TensorflowPrinter +from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter + + +x, y, z = symbols("x,y,z") +i, a, b = symbols("i,a,b") +j, c, d = symbols("j,c,d") + + +def test_basic(): + assert lambdarepr(x*y) == "x*y" + assert lambdarepr(x + y) in ["y + x", "x + y"] + assert lambdarepr(x**y) == "x**y" + + +def test_matrix(): + # Test printing a Matrix that has an element that is printed differently + # with the LambdaPrinter than with the StrPrinter. + e = x % 2 + assert lambdarepr(e) != str(e) + assert lambdarepr(Matrix([e])) == 'ImmutableDenseMatrix([[x % 2]])' + + +def test_piecewise(): + # In each case, test eval() the lambdarepr() to make sure there are a + # correct number of parentheses. It will give a SyntaxError if there aren't. + + h = "lambda x: " + + p = Piecewise((x, x < 0)) + l = lambdarepr(p) + eval(h + l) + assert l == "((x) if (x < 0) else None)" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + (0, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else (0))" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else None)" + + p = Piecewise( + (x, x < 1), + (x**2, Interval(3, 4, True, False).contains(x)), + (0, True), + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))" + + p = Piecewise( + (x**2, x < 0), + (x, x < 1), + (2 - x, x >= 1), + (0, True), evaluate=False + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\ + " else (2 - x) if (x >= 1) else (0))" + + p = Piecewise( + (x**2, x < 0), + (x, x < 1), + (2 - x, x >= 1), evaluate=False + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\ + " else (2 - x) if (x >= 1) else None)" + + p = Piecewise( + (1, x >= 1), + (2, x >= 2), + (3, x >= 3), + (4, x >= 4), + (5, x >= 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"\ + " else (4) if (x >= 4) else (5) if (x >= 5) else (6))" + + p = Piecewise( + (1, x <= 1), + (2, x <= 2), + (3, x <= 3), + (4, x <= 4), + (5, x <= 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"\ + " else (4) if (x <= 4) else (5) if (x <= 5) else (6))" + + p = Piecewise( + (1, x > 1), + (2, x > 2), + (3, x > 3), + (4, x > 4), + (5, x > 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l =="((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"\ + " else (4) if (x > 4) else (5) if (x > 5) else (6))" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + (3, x < 3), + (4, x < 4), + (5, x < 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"\ + " else (4) if (x < 4) else (5) if (x < 5) else (6))" + + p = Piecewise( + (Piecewise( + (1, x > 0), + (2, True) + ), y > 0), + (3, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((((1) if (x > 0) else (2))) if (y > 0) else (3))" + + +def test_sum__1(): + # In each case, test eval() the lambdarepr() to make sure that + # it evaluates to the same results as the symbolic expression + s = Sum(x ** i, (i, a, b)) + l = lambdarepr(s) + assert l == "(builtins.sum(x**i for i in range(a, b+1)))" + + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() + +def test_sum__2(): + s = Sum(i * x, (i, a, b)) + l = lambdarepr(s) + assert l == "(builtins.sum(i*x for i in range(a, b+1)))" + + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() + + +def test_multiple_sums(): + s = Sum(i * x + j, (i, a, b), (j, c, d)) + + l = lambdarepr(s) + assert l == "(builtins.sum(i*x + j for j in range(c, d+1) for i in range(a, b+1)))" + + args = x, a, b, c, d + f = lambdify(args, s) + vals = 2, 3, 4, 5, 6 + f_ref = s.subs(zip(args, vals)).doit() + f_res = f(*vals) + assert f_res == f_ref + + +def test_sqrt(): + prntr = LambdaPrinter({'standard' : 'python3'}) + assert prntr._print_Pow(sqrt(x), rational=False) == 'sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + +def test_settings(): + raises(TypeError, lambda: lambdarepr(sin(x), method="garbage")) + + +def test_numexpr(): + # test ITE rewrite as Piecewise + from sympy.logic.boolalg import ITE + expr = ITE(x > 0, True, False, evaluate=False) + assert NumExprPrinter().doprint(expr) == \ + "numexpr.evaluate('where((x > 0), True, False)', truediv=True)" + + from sympy.codegen.ast import Return, FunctionDefinition, Variable, Assignment + func_def = FunctionDefinition(None, 'foo', [Variable(x)], [Assignment(y,x), Return(y**2)]) + expected = "def foo(x):\n"\ + " y = numexpr.evaluate('x', truediv=True)\n"\ + " return numexpr.evaluate('y**2', truediv=True)" + assert NumExprPrinter().doprint(func_def) == expected + + +class CustomPrintedObject(Expr): + def _lambdacode(self, printer): + return 'lambda' + + def _tensorflowcode(self, printer): + return 'tensorflow' + + def _numpycode(self, printer): + return 'numpy' + + def _numexprcode(self, printer): + return 'numexpr' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + # In each case, printmethod is called to test + # its working + + obj = CustomPrintedObject() + assert LambdaPrinter().doprint(obj) == 'lambda' + assert TensorflowPrinter().doprint(obj) == 'tensorflow' + assert NumExprPrinter().doprint(obj) == "numexpr.evaluate('numexpr', truediv=True)" + + assert NumExprPrinter().doprint(Piecewise((y, x >= 0), (z, x < 0))) == \ + "numexpr.evaluate('where((x >= 0), y, z)', truediv=True)" diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_latex.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..063611d09a923881cd94bd693f3f3f721535fd0c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_latex.py @@ -0,0 +1,3164 @@ +from sympy import MatAdd, MatMul, Array +from sympy.algebras.quaternion import Quaternion +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.combinatorics.permutations import Cycle, Permutation, AppliedPermutation +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple, Dict +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import (Derivative, Function, Lambda, Subs, diff) +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core.numbers import (AlgebraicNumber, Float, I, Integer, Rational, oo, pi) +from sympy.core.parameters import evaluate +from sympy.core.power import Pow +from sympy.core.relational import Eq, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial, factorial2, subfactorial) +from sympy.functions.combinatorial.numbers import (bernoulli, bell, catalan, euler, genocchi, + lucas, fibonacci, tribonacci, divisor_sigma, udivisor_sigma, + mobius, primenu, primeomega, + totient, reduced_totient) +from sympy.functions.elementary.complexes import (Abs, arg, conjugate, im, polar_lift, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (asinh, coth) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (Max, Min, root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acsc, asin, cos, cot, sin, tan) +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.functions.special.elliptic_integrals import (elliptic_e, elliptic_f, elliptic_k, elliptic_pi) +from sympy.functions.special.error_functions import (Chi, Ci, Ei, Shi, Si, expint) +from sympy.functions.special.gamma_functions import (gamma, uppergamma) +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.functions.special.mathieu_functions import (mathieuc, mathieucprime, mathieus, mathieusprime) +from sympy.functions.special.polynomials import (assoc_laguerre, assoc_legendre, chebyshevt, chebyshevu, gegenbauer, hermite, jacobi, laguerre, legendre) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.spherical_harmonics import (Ynm, Znm) +from sympy.functions.special.tensor_functions import (KroneckerDelta, LeviCivita) +from sympy.functions.special.zeta_functions import (dirichlet_eta, lerchphi, polylog, stieltjes, zeta) +from sympy.integrals.integrals import Integral +from sympy.integrals.transforms import (CosineTransform, FourierTransform, InverseCosineTransform, InverseFourierTransform, InverseLaplaceTransform, InverseMellinTransform, InverseSineTransform, LaplaceTransform, MellinTransform, SineTransform) +from sympy.logic import Implies +from sympy.logic.boolalg import (And, Or, Xor, Equivalent, false, Not, true) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.kronecker import KroneckerProduct +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.permutation import PermutationMatrix +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions.dotproduct import DotProduct +from sympy.physics.control.lti import TransferFunction, Series, Parallel, Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback +from sympy.physics.quantum import Commutator, Operator +from sympy.physics.quantum.trace import Tr +from sympy.physics.units import meter, gibibyte, gram, microgram, second, milli, micro +from sympy.polys.domains.integerring import ZZ +from sympy.polys.fields import field +from sympy.polys.polytools import Poly +from sympy.polys.rings import ring +from sympy.polys.rootoftools import (RootSum, rootof) +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.series.limits import Limit +from sympy.series.order import Order +from sympy.series.sequences import (SeqAdd, SeqFormula, SeqMul, SeqPer) +from sympy.sets.conditionset import ConditionSet +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ComplexRegion, ImageSet, Range) +from sympy.sets.ordinals import Ordinal, OrdinalOmega, OmegaPower +from sympy.sets.powerset import PowerSet +from sympy.sets.sets import (FiniteSet, Interval, Union, Intersection, Complement, SymmetricDifference, ProductSet) +from sympy.sets.setexpr import SetExpr +from sympy.stats.crv_types import Normal +from sympy.stats.symbolic_probability import (Covariance, Expectation, + Probability, Variance) +from sympy.tensor.array import (ImmutableDenseNDimArray, + ImmutableSparseNDimArray, + MutableSparseNDimArray, + MutableDenseNDimArray, + tensorproduct) +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement +from sympy.tensor.indexed import (Idx, Indexed, IndexedBase) +from sympy.tensor.toperators import PartialDerivative +from sympy.vector import CoordSys3D, Cross, Curl, Dot, Divergence, Gradient, Laplacian + + +from sympy.testing.pytest import (XFAIL, raises, _both_exp_pow, + warns_deprecated_sympy) +from sympy.printing.latex import (latex, translate, greek_letters_set, + tex_greek_dictionary, multiline_latex, + latex_escape, LatexPrinter) + +import sympy as sym + +from sympy.abc import mu, tau + + +class lowergamma(sym.lowergamma): + pass # testing notation inheritance by a subclass with same name + + +x, y, z, t, w, a, b, c, s, p = symbols('x y z t w a b c s p') +k, m, n = symbols('k m n', integer=True) + + +def test_printmethod(): + class R(Abs): + def _latex(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert latex(R(x)) == r"foo(x)" + + class R(Abs): + def _latex(self, printer): + return "foo" + assert latex(R(x)) == r"foo" + + +def test_latex_basic(): + assert latex(1 + x) == r"x + 1" + assert latex(x**2) == r"x^{2}" + assert latex(x**(1 + x)) == r"x^{x + 1}" + assert latex(x**3 + x + 1 + x**2) == r"x^{3} + x^{2} + x + 1" + + assert latex(2*x*y) == r"2 x y" + assert latex(2*x*y, mul_symbol='dot') == r"2 \cdot x \cdot y" + assert latex(3*x**2*y, mul_symbol='\\,') == r"3\,x^{2}\,y" + assert latex(1.5*3**x, mul_symbol='\\,') == r"1.5 \cdot 3^{x}" + + assert latex(x**S.Half**5) == r"\sqrt[32]{x}" + assert latex(Mul(S.Half, x**2, -5, evaluate=False)) == r"\frac{1}{2} x^{2} \left(-5\right)" + assert latex(Mul(S.Half, x**2, 5, evaluate=False)) == r"\frac{1}{2} x^{2} \cdot 5" + assert latex(Mul(-5, -5, evaluate=False)) == r"\left(-5\right) \left(-5\right)" + assert latex(Mul(5, -5, evaluate=False)) == r"5 \left(-5\right)" + assert latex(Mul(S.Half, -5, S.Half, evaluate=False)) == r"\frac{1}{2} \left(-5\right) \frac{1}{2}" + assert latex(Mul(5, I, 5, evaluate=False)) == r"5 i 5" + assert latex(Mul(5, I, -5, evaluate=False)) == r"5 i \left(-5\right)" + assert latex(Mul(Pow(x, 2), S.Half*x + 1)) == r"x^{2} \left(\frac{x}{2} + 1\right)" + assert latex(Mul(Pow(x, 3), Rational(2, 3)*x + 1)) == r"x^{3} \left(\frac{2 x}{3} + 1\right)" + assert latex(Mul(Pow(x, 11), 2*x + 1)) == r"x^{11} \left(2 x + 1\right)" + + assert latex(Mul(0, 1, evaluate=False)) == r'0 \cdot 1' + assert latex(Mul(1, 0, evaluate=False)) == r'1 \cdot 0' + assert latex(Mul(1, 1, evaluate=False)) == r'1 \cdot 1' + assert latex(Mul(-1, 1, evaluate=False)) == r'\left(-1\right) 1' + assert latex(Mul(1, 1, 1, evaluate=False)) == r'1 \cdot 1 \cdot 1' + assert latex(Mul(1, 2, evaluate=False)) == r'1 \cdot 2' + assert latex(Mul(1, S.Half, evaluate=False)) == r'1 \cdot \frac{1}{2}' + assert latex(Mul(1, 1, S.Half, evaluate=False)) == \ + r'1 \cdot 1 \cdot \frac{1}{2}' + assert latex(Mul(1, 1, 2, 3, x, evaluate=False)) == \ + r'1 \cdot 1 \cdot 2 \cdot 3 x' + assert latex(Mul(1, -1, evaluate=False)) == r'1 \left(-1\right)' + assert latex(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == \ + r'4 \cdot 3 \cdot 2 \cdot 1 \cdot 0 y x' + assert latex(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == \ + r'4 \cdot 3 \cdot 2 \left(z + 1\right) 0 y x' + assert latex(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == \ + r'\frac{2}{3} \cdot \frac{5}{7}' + + assert latex(1/x) == r"\frac{1}{x}" + assert latex(1/x, fold_short_frac=True) == r"1 / x" + assert latex(-S(3)/2) == r"- \frac{3}{2}" + assert latex(-S(3)/2, fold_short_frac=True) == r"- 3 / 2" + assert latex(1/x**2) == r"\frac{1}{x^{2}}" + assert latex(1/(x + y)/2) == r"\frac{1}{2 \left(x + y\right)}" + assert latex(x/2) == r"\frac{x}{2}" + assert latex(x/2, fold_short_frac=True) == r"x / 2" + assert latex((x + y)/(2*x)) == r"\frac{x + y}{2 x}" + assert latex((x + y)/(2*x), fold_short_frac=True) == \ + r"\left(x + y\right) / 2 x" + assert latex((x + y)/(2*x), long_frac_ratio=0) == \ + r"\frac{1}{2 x} \left(x + y\right)" + assert latex((x + y)/x) == r"\frac{x + y}{x}" + assert latex((x + y)/x, long_frac_ratio=3) == r"\frac{x + y}{x}" + assert latex((2*sqrt(2)*x)/3) == r"\frac{2 \sqrt{2} x}{3}" + assert latex((2*sqrt(2)*x)/3, long_frac_ratio=2) == \ + r"\frac{2 x}{3} \sqrt{2}" + assert latex(binomial(x, y)) == r"{\binom{x}{y}}" + + x_star = Symbol('x^*') + f = Function('f') + assert latex(x_star**2) == r"\left(x^{*}\right)^{2}" + assert latex(x_star**2, parenthesize_super=False) == r"{x^{*}}^{2}" + assert latex(Derivative(f(x_star), x_star,2)) == r"\frac{d^{2}}{d \left(x^{*}\right)^{2}} f{\left(x^{*} \right)}" + assert latex(Derivative(f(x_star), x_star,2), parenthesize_super=False) == r"\frac{d^{2}}{d {x^{*}}^{2}} f{\left(x^{*} \right)}" + + assert latex(2*Integral(x, x)/3) == r"\frac{2 \int x\, dx}{3}" + assert latex(2*Integral(x, x)/3, fold_short_frac=True) == \ + r"\left(2 \int x\, dx\right) / 3" + + assert latex(sqrt(x)) == r"\sqrt{x}" + assert latex(x**Rational(1, 3)) == r"\sqrt[3]{x}" + assert latex(x**Rational(1, 3), root_notation=False) == r"x^{\frac{1}{3}}" + assert latex(sqrt(x)**3) == r"x^{\frac{3}{2}}" + assert latex(sqrt(x), itex=True) == r"\sqrt{x}" + assert latex(x**Rational(1, 3), itex=True) == r"\root{3}{x}" + assert latex(sqrt(x)**3, itex=True) == r"x^{\frac{3}{2}}" + assert latex(x**Rational(3, 4)) == r"x^{\frac{3}{4}}" + assert latex(x**Rational(3, 4), fold_frac_powers=True) == r"x^{3/4}" + assert latex((x + 1)**Rational(3, 4)) == \ + r"\left(x + 1\right)^{\frac{3}{4}}" + assert latex((x + 1)**Rational(3, 4), fold_frac_powers=True) == \ + r"\left(x + 1\right)^{3/4}" + assert latex(AlgebraicNumber(sqrt(2))) == r"\sqrt{2}" + assert latex(AlgebraicNumber(sqrt(2), [3, -7])) == r"-7 + 3 \sqrt{2}" + assert latex(AlgebraicNumber(sqrt(2), alias='alpha')) == r"\alpha" + assert latex(AlgebraicNumber(sqrt(2), [3, -7], alias='alpha')) == \ + r"3 \alpha - 7" + assert latex(AlgebraicNumber(2**(S(1)/3), [1, 3, -7], alias='beta')) == \ + r"\beta^{2} + 3 \beta - 7" + + k = ZZ.cyclotomic_field(5) + assert latex(k.ext.field_element([1, 2, 3, 4])) == \ + r"\zeta^{3} + 2 \zeta^{2} + 3 \zeta + 4" + assert latex(k.ext.field_element([1, 2, 3, 4]), order='old') == \ + r"4 + 3 \zeta + 2 \zeta^{2} + \zeta^{3}" + assert latex(k.primes_above(19)[0]) == \ + r"\left(19, \zeta^{2} + 5 \zeta + 1\right)" + assert latex(k.primes_above(19)[0], order='old') == \ + r"\left(19, 1 + 5 \zeta + \zeta^{2}\right)" + assert latex(k.primes_above(7)[0]) == r"\left(7\right)" + + assert latex(1.5e20*x) == r"1.5 \cdot 10^{20} x" + assert latex(1.5e20*x, mul_symbol='dot') == r"1.5 \cdot 10^{20} \cdot x" + assert latex(1.5e20*x, mul_symbol='times') == \ + r"1.5 \times 10^{20} \times x" + + assert latex(1/sin(x)) == r"\frac{1}{\sin{\left(x \right)}}" + assert latex(sin(x)**-1) == r"\frac{1}{\sin{\left(x \right)}}" + assert latex(sin(x)**Rational(3, 2)) == \ + r"\sin^{\frac{3}{2}}{\left(x \right)}" + assert latex(sin(x)**Rational(3, 2), fold_frac_powers=True) == \ + r"\sin^{3/2}{\left(x \right)}" + + assert latex(~x) == r"\neg x" + assert latex(x & y) == r"x \wedge y" + assert latex(x & y & z) == r"x \wedge y \wedge z" + assert latex(x | y) == r"x \vee y" + assert latex(x | y | z) == r"x \vee y \vee z" + assert latex((x & y) | z) == r"z \vee \left(x \wedge y\right)" + assert latex(Implies(x, y)) == r"x \Rightarrow y" + assert latex(~(x >> ~y)) == r"x \not\Rightarrow \neg y" + assert latex(Implies(Or(x,y), z)) == r"\left(x \vee y\right) \Rightarrow z" + assert latex(Implies(z, Or(x,y))) == r"z \Rightarrow \left(x \vee y\right)" + assert latex(~(x & y)) == r"\neg \left(x \wedge y\right)" + + assert latex(~x, symbol_names={x: "x_i"}) == r"\neg x_i" + assert latex(x & y, symbol_names={x: "x_i", y: "y_i"}) == \ + r"x_i \wedge y_i" + assert latex(x & y & z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"x_i \wedge y_i \wedge z_i" + assert latex(x | y, symbol_names={x: "x_i", y: "y_i"}) == r"x_i \vee y_i" + assert latex(x | y | z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"x_i \vee y_i \vee z_i" + assert latex((x & y) | z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"z_i \vee \left(x_i \wedge y_i\right)" + assert latex(Implies(x, y), symbol_names={x: "x_i", y: "y_i"}) == \ + r"x_i \Rightarrow y_i" + assert latex(Pow(Rational(1, 3), -1, evaluate=False)) == r"\frac{1}{\frac{1}{3}}" + assert latex(Pow(Rational(1, 3), -2, evaluate=False)) == r"\frac{1}{(\frac{1}{3})^{2}}" + assert latex(Pow(Integer(1)/100, -1, evaluate=False)) == r"\frac{1}{\frac{1}{100}}" + + p = Symbol('p', positive=True) + assert latex(exp(-p)*log(p)) == r"e^{- p} \log{\left(p \right)}" + + assert latex(Pow(Rational(2, 3), -1, evaluate=False)) == r'\frac{1}{\frac{2}{3}}' + assert latex(Pow(Rational(4, 3), -1, evaluate=False)) == r'\frac{1}{\frac{4}{3}}' + assert latex(Pow(Rational(-3, 4), -1, evaluate=False)) == r'\frac{1}{- \frac{3}{4}}' + assert latex(Pow(Rational(-4, 4), -1, evaluate=False)) == r'\frac{1}{-1}' + assert latex(Pow(Rational(1, 3), -1, evaluate=False)) == r'\frac{1}{\frac{1}{3}}' + assert latex(Pow(Rational(-1, 3), -1, evaluate=False)) == r'\frac{1}{- \frac{1}{3}}' + + +def test_latex_builtins(): + assert latex(True) == r"\text{True}" + assert latex(False) == r"\text{False}" + assert latex(None) == r"\text{None}" + assert latex(true) == r"\text{True}" + assert latex(false) == r'\text{False}' + + +def test_latex_SingularityFunction(): + assert latex(SingularityFunction(x, 4, 5)) == \ + r"{\left\langle x - 4 \right\rangle}^{5}" + assert latex(SingularityFunction(x, -3, 4)) == \ + r"{\left\langle x + 3 \right\rangle}^{4}" + assert latex(SingularityFunction(x, 0, 4)) == \ + r"{\left\langle x \right\rangle}^{4}" + assert latex(SingularityFunction(x, a, n)) == \ + r"{\left\langle - a + x \right\rangle}^{n}" + assert latex(SingularityFunction(x, 4, -2)) == \ + r"{\left\langle x - 4 \right\rangle}^{-2}" + assert latex(SingularityFunction(x, 4, -1)) == \ + r"{\left\langle x - 4 \right\rangle}^{-1}" + + assert latex(SingularityFunction(x, 4, 5)**3) == \ + r"{\left({\langle x - 4 \rangle}^{5}\right)}^{3}" + assert latex(SingularityFunction(x, -3, 4)**3) == \ + r"{\left({\langle x + 3 \rangle}^{4}\right)}^{3}" + assert latex(SingularityFunction(x, 0, 4)**3) == \ + r"{\left({\langle x \rangle}^{4}\right)}^{3}" + assert latex(SingularityFunction(x, a, n)**3) == \ + r"{\left({\langle - a + x \rangle}^{n}\right)}^{3}" + assert latex(SingularityFunction(x, 4, -2)**3) == \ + r"{\left({\langle x - 4 \rangle}^{-2}\right)}^{3}" + assert latex((SingularityFunction(x, 4, -1)**3)**3) == \ + r"{\left({\langle x - 4 \rangle}^{-1}\right)}^{9}" + + +def test_latex_cycle(): + assert latex(Cycle(1, 2, 4)) == r"\left( 1\; 2\; 4\right)" + assert latex(Cycle(1, 2)(4, 5, 6)) == \ + r"\left( 1\; 2\right)\left( 4\; 5\; 6\right)" + assert latex(Cycle()) == r"\left( \right)" + + +def test_latex_permutation(): + assert latex(Permutation(1, 2, 4)) == r"\left( 1\; 2\; 4\right)" + assert latex(Permutation(1, 2)(4, 5, 6)) == \ + r"\left( 1\; 2\right)\left( 4\; 5\; 6\right)" + assert latex(Permutation()) == r"\left( \right)" + assert latex(Permutation(2, 4)*Permutation(5)) == \ + r"\left( 2\; 4\right)\left( 5\right)" + assert latex(Permutation(5)) == r"\left( 5\right)" + + assert latex(Permutation(0, 1), perm_cyclic=False) == \ + r"\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}" + assert latex(Permutation(0, 1)(2, 3), perm_cyclic=False) == \ + r"\begin{pmatrix} 0 & 1 & 2 & 3 \\ 1 & 0 & 3 & 2 \end{pmatrix}" + assert latex(Permutation(), perm_cyclic=False) == \ + r"\left( \right)" + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert latex(Permutation(0, 1)(2, 3)) == \ + r"\begin{pmatrix} 0 & 1 & 2 & 3 \\ 1 & 0 & 3 & 2 \end{pmatrix}" + Permutation.print_cyclic = old_print_cyclic + +def test_latex_Float(): + assert latex(Float(1.0e100)) == r"1.0 \cdot 10^{100}" + assert latex(Float(1.0e-100)) == r"1.0 \cdot 10^{-100}" + assert latex(Float(1.0e-100), mul_symbol="times") == \ + r"1.0 \times 10^{-100}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=2) == \ + r"1.0 \cdot 10^{4}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=4) == \ + r"1.0 \cdot 10^{4}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=5) == \ + r"10000.0" + assert latex(Float('0.099999'), full_prec=True, min=-2, max=5) == \ + r"9.99990000000000 \cdot 10^{-2}" + + +def test_latex_vector_expressions(): + A = CoordSys3D('A') + + assert latex(Cross(A.i, A.j*A.x*3+A.k)) == \ + r"\mathbf{\hat{i}_{A}} \times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}} + \mathbf{\hat{k}_{A}}\right)" + assert latex(Cross(A.i, A.j)) == \ + r"\mathbf{\hat{i}_{A}} \times \mathbf{\hat{j}_{A}}" + assert latex(x*Cross(A.i, A.j)) == \ + r"x \left(\mathbf{\hat{i}_{A}} \times \mathbf{\hat{j}_{A}}\right)" + assert latex(Cross(x*A.i, A.j)) == \ + r'- \mathbf{\hat{j}_{A}} \times \left(\left(x\right)\mathbf{\hat{i}_{A}}\right)' + + assert latex(Curl(3*A.x*A.j)) == \ + r"\nabla\times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Curl(3*A.x*A.j+A.i)) == \ + r"\nabla\times \left(\mathbf{\hat{i}_{A}} + \left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Curl(3*x*A.x*A.j)) == \ + r"\nabla\times \left(\left(3 \mathbf{{x}_{A}} x\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(x*Curl(3*A.x*A.j)) == \ + r"x \left(\nabla\times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)\right)" + + assert latex(Divergence(3*A.x*A.j+A.i)) == \ + r"\nabla\cdot \left(\mathbf{\hat{i}_{A}} + \left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Divergence(3*A.x*A.j)) == \ + r"\nabla\cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(x*Divergence(3*A.x*A.j)) == \ + r"x \left(\nabla\cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)\right)" + + assert latex(Dot(A.i, A.j*A.x*3+A.k)) == \ + r"\mathbf{\hat{i}_{A}} \cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}} + \mathbf{\hat{k}_{A}}\right)" + assert latex(Dot(A.i, A.j)) == \ + r"\mathbf{\hat{i}_{A}} \cdot \mathbf{\hat{j}_{A}}" + assert latex(Dot(x*A.i, A.j)) == \ + r"\mathbf{\hat{j}_{A}} \cdot \left(\left(x\right)\mathbf{\hat{i}_{A}}\right)" + assert latex(x*Dot(A.i, A.j)) == \ + r"x \left(\mathbf{\hat{i}_{A}} \cdot \mathbf{\hat{j}_{A}}\right)" + + assert latex(Gradient(A.x)) == r"\nabla \mathbf{{x}_{A}}" + assert latex(Gradient(A.x + 3*A.y)) == \ + r"\nabla \left(\mathbf{{x}_{A}} + 3 \mathbf{{y}_{A}}\right)" + assert latex(x*Gradient(A.x)) == r"x \left(\nabla \mathbf{{x}_{A}}\right)" + assert latex(Gradient(x*A.x)) == r"\nabla \left(\mathbf{{x}_{A}} x\right)" + + assert latex(Laplacian(A.x)) == r"\Delta \mathbf{{x}_{A}}" + assert latex(Laplacian(A.x + 3*A.y)) == \ + r"\Delta \left(\mathbf{{x}_{A}} + 3 \mathbf{{y}_{A}}\right)" + assert latex(x*Laplacian(A.x)) == r"x \left(\Delta \mathbf{{x}_{A}}\right)" + assert latex(Laplacian(x*A.x)) == r"\Delta \left(\mathbf{{x}_{A}} x\right)" + +def test_latex_symbols(): + Gamma, lmbda, rho = symbols('Gamma, lambda, rho') + tau, Tau, TAU, taU = symbols('tau, Tau, TAU, taU') + assert latex(tau) == r"\tau" + assert latex(Tau) == r"\mathrm{T}" + assert latex(TAU) == r"\tau" + assert latex(taU) == r"\tau" + # Check that all capitalized greek letters are handled explicitly + capitalized_letters = {l.capitalize() for l in greek_letters_set} + assert len(capitalized_letters - set(tex_greek_dictionary.keys())) == 0 + assert latex(Gamma + lmbda) == r"\Gamma + \lambda" + assert latex(Gamma * lmbda) == r"\Gamma \lambda" + assert latex(Symbol('q1')) == r"q_{1}" + assert latex(Symbol('q21')) == r"q_{21}" + assert latex(Symbol('epsilon0')) == r"\epsilon_{0}" + assert latex(Symbol('omega1')) == r"\omega_{1}" + assert latex(Symbol('91')) == r"91" + assert latex(Symbol('alpha_new')) == r"\alpha_{new}" + assert latex(Symbol('C^orig')) == r"C^{orig}" + assert latex(Symbol('x^alpha')) == r"x^{\alpha}" + assert latex(Symbol('beta^alpha')) == r"\beta^{\alpha}" + assert latex(Symbol('e^Alpha')) == r"e^{\mathrm{A}}" + assert latex(Symbol('omega_alpha^beta')) == r"\omega^{\beta}_{\alpha}" + assert latex(Symbol('omega') ** Symbol('beta')) == r"\omega^{\beta}" + + +@XFAIL +def test_latex_symbols_failing(): + rho, mass, volume = symbols('rho, mass, volume') + assert latex( + volume * rho == mass) == r"\rho \mathrm{volume} = \mathrm{mass}" + assert latex(volume / mass * rho == 1) == \ + r"\rho \mathrm{volume} {\mathrm{mass}}^{(-1)} = 1" + assert latex(mass**3 * volume**3) == \ + r"{\mathrm{mass}}^{3} \cdot {\mathrm{volume}}^{3}" + + +@_both_exp_pow +def test_latex_functions(): + assert latex(exp(x)) == r"e^{x}" + assert latex(exp(1) + exp(2)) == r"e + e^{2}" + + f = Function('f') + assert latex(f(x)) == r'f{\left(x \right)}' + assert latex(f) == r'f' + + g = Function('g') + assert latex(g(x, y)) == r'g{\left(x,y \right)}' + assert latex(g) == r'g' + + h = Function('h') + assert latex(h(x, y, z)) == r'h{\left(x,y,z \right)}' + assert latex(h) == r'h' + + Li = Function('Li') + assert latex(Li) == r'\operatorname{Li}' + assert latex(Li(x)) == r'\operatorname{Li}{\left(x \right)}' + + mybeta = Function('beta') + # not to be confused with the beta function + assert latex(mybeta(x, y, z)) == r"\beta{\left(x,y,z \right)}" + assert latex(beta(x, y)) == r'\operatorname{B}\left(x, y\right)' + assert latex(beta(x, evaluate=False)) == r'\operatorname{B}\left(x, x\right)' + assert latex(beta(x, y)**2) == r'\operatorname{B}^{2}\left(x, y\right)' + assert latex(mybeta(x)) == r"\beta{\left(x \right)}" + assert latex(mybeta) == r"\beta" + + g = Function('gamma') + # not to be confused with the gamma function + assert latex(g(x, y, z)) == r"\gamma{\left(x,y,z \right)}" + assert latex(g(x)) == r"\gamma{\left(x \right)}" + assert latex(g) == r"\gamma" + + a_1 = Function('a_1') + assert latex(a_1) == r"a_{1}" + assert latex(a_1(x)) == r"a_{1}{\left(x \right)}" + assert latex(Function('a_1')) == r"a_{1}" + + # Issue #16925 + # multi letter function names + # > simple + assert latex(Function('ab')) == r"\operatorname{ab}" + assert latex(Function('ab1')) == r"\operatorname{ab}_{1}" + assert latex(Function('ab12')) == r"\operatorname{ab}_{12}" + assert latex(Function('ab_1')) == r"\operatorname{ab}_{1}" + assert latex(Function('ab_12')) == r"\operatorname{ab}_{12}" + assert latex(Function('ab_c')) == r"\operatorname{ab}_{c}" + assert latex(Function('ab_cd')) == r"\operatorname{ab}_{cd}" + # > with argument + assert latex(Function('ab')(Symbol('x'))) == r"\operatorname{ab}{\left(x \right)}" + assert latex(Function('ab1')(Symbol('x'))) == r"\operatorname{ab}_{1}{\left(x \right)}" + assert latex(Function('ab12')(Symbol('x'))) == r"\operatorname{ab}_{12}{\left(x \right)}" + assert latex(Function('ab_1')(Symbol('x'))) == r"\operatorname{ab}_{1}{\left(x \right)}" + assert latex(Function('ab_c')(Symbol('x'))) == r"\operatorname{ab}_{c}{\left(x \right)}" + assert latex(Function('ab_cd')(Symbol('x'))) == r"\operatorname{ab}_{cd}{\left(x \right)}" + + # > with power + # does not work on functions without brackets + + # > with argument and power combined + assert latex(Function('ab')()**2) == r"\operatorname{ab}^{2}{\left( \right)}" + assert latex(Function('ab1')()**2) == r"\operatorname{ab}_{1}^{2}{\left( \right)}" + assert latex(Function('ab12')()**2) == r"\operatorname{ab}_{12}^{2}{\left( \right)}" + assert latex(Function('ab_1')()**2) == r"\operatorname{ab}_{1}^{2}{\left( \right)}" + assert latex(Function('ab_12')()**2) == r"\operatorname{ab}_{12}^{2}{\left( \right)}" + assert latex(Function('ab')(Symbol('x'))**2) == r"\operatorname{ab}^{2}{\left(x \right)}" + assert latex(Function('ab1')(Symbol('x'))**2) == r"\operatorname{ab}_{1}^{2}{\left(x \right)}" + assert latex(Function('ab12')(Symbol('x'))**2) == r"\operatorname{ab}_{12}^{2}{\left(x \right)}" + assert latex(Function('ab_1')(Symbol('x'))**2) == r"\operatorname{ab}_{1}^{2}{\left(x \right)}" + assert latex(Function('ab_12')(Symbol('x'))**2) == \ + r"\operatorname{ab}_{12}^{2}{\left(x \right)}" + + # single letter function names + # > simple + assert latex(Function('a')) == r"a" + assert latex(Function('a1')) == r"a_{1}" + assert latex(Function('a12')) == r"a_{12}" + assert latex(Function('a_1')) == r"a_{1}" + assert latex(Function('a_12')) == r"a_{12}" + + # > with argument + assert latex(Function('a')()) == r"a{\left( \right)}" + assert latex(Function('a1')()) == r"a_{1}{\left( \right)}" + assert latex(Function('a12')()) == r"a_{12}{\left( \right)}" + assert latex(Function('a_1')()) == r"a_{1}{\left( \right)}" + assert latex(Function('a_12')()) == r"a_{12}{\left( \right)}" + + # > with power + # does not work on functions without brackets + + # > with argument and power combined + assert latex(Function('a')()**2) == r"a^{2}{\left( \right)}" + assert latex(Function('a1')()**2) == r"a_{1}^{2}{\left( \right)}" + assert latex(Function('a12')()**2) == r"a_{12}^{2}{\left( \right)}" + assert latex(Function('a_1')()**2) == r"a_{1}^{2}{\left( \right)}" + assert latex(Function('a_12')()**2) == r"a_{12}^{2}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**2) == r"a^{2}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**2) == r"a_{1}^{2}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**2) == r"a_{12}^{2}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**2) == r"a_{1}^{2}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**2) == r"a_{12}^{2}{\left(x \right)}" + + assert latex(Function('a')()**32) == r"a^{32}{\left( \right)}" + assert latex(Function('a1')()**32) == r"a_{1}^{32}{\left( \right)}" + assert latex(Function('a12')()**32) == r"a_{12}^{32}{\left( \right)}" + assert latex(Function('a_1')()**32) == r"a_{1}^{32}{\left( \right)}" + assert latex(Function('a_12')()**32) == r"a_{12}^{32}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**32) == r"a^{32}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**32) == r"a_{1}^{32}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**32) == r"a_{12}^{32}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**32) == r"a_{1}^{32}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**32) == r"a_{12}^{32}{\left(x \right)}" + + assert latex(Function('a')()**a) == r"a^{a}{\left( \right)}" + assert latex(Function('a1')()**a) == r"a_{1}^{a}{\left( \right)}" + assert latex(Function('a12')()**a) == r"a_{12}^{a}{\left( \right)}" + assert latex(Function('a_1')()**a) == r"a_{1}^{a}{\left( \right)}" + assert latex(Function('a_12')()**a) == r"a_{12}^{a}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**a) == r"a^{a}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**a) == r"a_{1}^{a}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**a) == r"a_{12}^{a}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**a) == r"a_{1}^{a}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**a) == r"a_{12}^{a}{\left(x \right)}" + + ab = Symbol('ab') + assert latex(Function('a')()**ab) == r"a^{ab}{\left( \right)}" + assert latex(Function('a1')()**ab) == r"a_{1}^{ab}{\left( \right)}" + assert latex(Function('a12')()**ab) == r"a_{12}^{ab}{\left( \right)}" + assert latex(Function('a_1')()**ab) == r"a_{1}^{ab}{\left( \right)}" + assert latex(Function('a_12')()**ab) == r"a_{12}^{ab}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**ab) == r"a^{ab}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**ab) == r"a_{1}^{ab}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**ab) == r"a_{12}^{ab}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**ab) == r"a_{1}^{ab}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**ab) == r"a_{12}^{ab}{\left(x \right)}" + + assert latex(Function('a^12')(x)) == R"a^{12}{\left(x \right)}" + assert latex(Function('a^12')(x) ** ab) == R"\left(a^{12}\right)^{ab}{\left(x \right)}" + assert latex(Function('a__12')(x)) == R"a^{12}{\left(x \right)}" + assert latex(Function('a__12')(x) ** ab) == R"\left(a^{12}\right)^{ab}{\left(x \right)}" + assert latex(Function('a_1__1_2')(x)) == R"a^{1}_{1 2}{\left(x \right)}" + + # issue 5868 + omega1 = Function('omega1') + assert latex(omega1) == r"\omega_{1}" + assert latex(omega1(x)) == r"\omega_{1}{\left(x \right)}" + + assert latex(sin(x)) == r"\sin{\left(x \right)}" + assert latex(sin(x), fold_func_brackets=True) == r"\sin {x}" + assert latex(sin(2*x**2), fold_func_brackets=True) == \ + r"\sin {2 x^{2}}" + assert latex(sin(x**2), fold_func_brackets=True) == \ + r"\sin {x^{2}}" + + assert latex(asin(x)**2) == r"\operatorname{asin}^{2}{\left(x \right)}" + assert latex(asin(x)**2, inv_trig_style="full") == \ + r"\arcsin^{2}{\left(x \right)}" + assert latex(asin(x)**2, inv_trig_style="power") == \ + r"\sin^{-1}{\left(x \right)}^{2}" + assert latex(asin(x**2), inv_trig_style="power", + fold_func_brackets=True) == \ + r"\sin^{-1} {x^{2}}" + assert latex(acsc(x), inv_trig_style="full") == \ + r"\operatorname{arccsc}{\left(x \right)}" + assert latex(asinh(x), inv_trig_style="full") == \ + r"\operatorname{arsinh}{\left(x \right)}" + + assert latex(factorial(k)) == r"k!" + assert latex(factorial(-k)) == r"\left(- k\right)!" + assert latex(factorial(k)**2) == r"k!^{2}" + + assert latex(subfactorial(k)) == r"!k" + assert latex(subfactorial(-k)) == r"!\left(- k\right)" + assert latex(subfactorial(k)**2) == r"\left(!k\right)^{2}" + + assert latex(factorial2(k)) == r"k!!" + assert latex(factorial2(-k)) == r"\left(- k\right)!!" + assert latex(factorial2(k)**2) == r"k!!^{2}" + + assert latex(binomial(2, k)) == r"{\binom{2}{k}}" + assert latex(binomial(2, k)**2) == r"{\binom{2}{k}}^{2}" + + assert latex(FallingFactorial(3, k)) == r"{\left(3\right)}_{k}" + assert latex(RisingFactorial(3, k)) == r"{3}^{\left(k\right)}" + + assert latex(floor(x)) == r"\left\lfloor{x}\right\rfloor" + assert latex(ceiling(x)) == r"\left\lceil{x}\right\rceil" + assert latex(frac(x)) == r"\operatorname{frac}{\left(x\right)}" + assert latex(floor(x)**2) == r"\left\lfloor{x}\right\rfloor^{2}" + assert latex(ceiling(x)**2) == r"\left\lceil{x}\right\rceil^{2}" + assert latex(frac(x)**2) == r"\operatorname{frac}{\left(x\right)}^{2}" + + assert latex(Min(x, 2, x**3)) == r"\min\left(2, x, x^{3}\right)" + assert latex(Min(x, y)**2) == r"\min\left(x, y\right)^{2}" + assert latex(Max(x, 2, x**3)) == r"\max\left(2, x, x^{3}\right)" + assert latex(Max(x, y)**2) == r"\max\left(x, y\right)^{2}" + assert latex(Abs(x)) == r"\left|{x}\right|" + assert latex(Abs(x)**2) == r"\left|{x}\right|^{2}" + assert latex(re(x)) == r"\operatorname{re}{\left(x\right)}" + assert latex(re(x + y)) == \ + r"\operatorname{re}{\left(x\right)} + \operatorname{re}{\left(y\right)}" + assert latex(im(x)) == r"\operatorname{im}{\left(x\right)}" + assert latex(conjugate(x)) == r"\overline{x}" + assert latex(conjugate(x)**2) == r"\overline{x}^{2}" + assert latex(conjugate(x**2)) == r"\overline{x}^{2}" + assert latex(gamma(x)) == r"\Gamma\left(x\right)" + w = Wild('w') + assert latex(gamma(w)) == r"\Gamma\left(w\right)" + assert latex(Order(x)) == r"O\left(x\right)" + assert latex(Order(x, x)) == r"O\left(x\right)" + assert latex(Order(x, (x, 0))) == r"O\left(x\right)" + assert latex(Order(x, (x, oo))) == r"O\left(x; x\rightarrow \infty\right)" + assert latex(Order(x - y, (x, y))) == \ + r"O\left(x - y; x\rightarrow y\right)" + assert latex(Order(x, x, y)) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( 0, \ 0\right)\right)" + assert latex(Order(x, x, y)) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( 0, \ 0\right)\right)" + assert latex(Order(x, (x, oo), (y, oo))) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( \infty, \ \infty\right)\right)" + assert latex(lowergamma(x, y)) == r'\gamma\left(x, y\right)' + assert latex(lowergamma(x, y)**2) == r'\gamma^{2}\left(x, y\right)' + assert latex(uppergamma(x, y)) == r'\Gamma\left(x, y\right)' + assert latex(uppergamma(x, y)**2) == r'\Gamma^{2}\left(x, y\right)' + + assert latex(cot(x)) == r'\cot{\left(x \right)}' + assert latex(coth(x)) == r'\coth{\left(x \right)}' + assert latex(re(x)) == r'\operatorname{re}{\left(x\right)}' + assert latex(im(x)) == r'\operatorname{im}{\left(x\right)}' + assert latex(root(x, y)) == r'x^{\frac{1}{y}}' + assert latex(arg(x)) == r'\arg{\left(x \right)}' + + assert latex(zeta(x)) == r"\zeta\left(x\right)" + assert latex(zeta(x)**2) == r"\zeta^{2}\left(x\right)" + assert latex(zeta(x, y)) == r"\zeta\left(x, y\right)" + assert latex(zeta(x, y)**2) == r"\zeta^{2}\left(x, y\right)" + assert latex(dirichlet_eta(x)) == r"\eta\left(x\right)" + assert latex(dirichlet_eta(x)**2) == r"\eta^{2}\left(x\right)" + assert latex(polylog(x, y)) == r"\operatorname{Li}_{x}\left(y\right)" + assert latex( + polylog(x, y)**2) == r"\operatorname{Li}_{x}^{2}\left(y\right)" + assert latex(lerchphi(x, y, n)) == r"\Phi\left(x, y, n\right)" + assert latex(lerchphi(x, y, n)**2) == r"\Phi^{2}\left(x, y, n\right)" + assert latex(stieltjes(x)) == r"\gamma_{x}" + assert latex(stieltjes(x)**2) == r"\gamma_{x}^{2}" + assert latex(stieltjes(x, y)) == r"\gamma_{x}\left(y\right)" + assert latex(stieltjes(x, y)**2) == r"\gamma_{x}\left(y\right)^{2}" + + assert latex(elliptic_k(z)) == r"K\left(z\right)" + assert latex(elliptic_k(z)**2) == r"K^{2}\left(z\right)" + assert latex(elliptic_f(x, y)) == r"F\left(x\middle| y\right)" + assert latex(elliptic_f(x, y)**2) == r"F^{2}\left(x\middle| y\right)" + assert latex(elliptic_e(x, y)) == r"E\left(x\middle| y\right)" + assert latex(elliptic_e(x, y)**2) == r"E^{2}\left(x\middle| y\right)" + assert latex(elliptic_e(z)) == r"E\left(z\right)" + assert latex(elliptic_e(z)**2) == r"E^{2}\left(z\right)" + assert latex(elliptic_pi(x, y, z)) == r"\Pi\left(x; y\middle| z\right)" + assert latex(elliptic_pi(x, y, z)**2) == \ + r"\Pi^{2}\left(x; y\middle| z\right)" + assert latex(elliptic_pi(x, y)) == r"\Pi\left(x\middle| y\right)" + assert latex(elliptic_pi(x, y)**2) == r"\Pi^{2}\left(x\middle| y\right)" + + assert latex(Ei(x)) == r'\operatorname{Ei}{\left(x \right)}' + assert latex(Ei(x)**2) == r'\operatorname{Ei}^{2}{\left(x \right)}' + assert latex(expint(x, y)) == r'\operatorname{E}_{x}\left(y\right)' + assert latex(expint(x, y)**2) == r'\operatorname{E}_{x}^{2}\left(y\right)' + assert latex(Shi(x)**2) == r'\operatorname{Shi}^{2}{\left(x \right)}' + assert latex(Si(x)**2) == r'\operatorname{Si}^{2}{\left(x \right)}' + assert latex(Ci(x)**2) == r'\operatorname{Ci}^{2}{\left(x \right)}' + assert latex(Chi(x)**2) == r'\operatorname{Chi}^{2}\left(x\right)' + assert latex(Chi(x)) == r'\operatorname{Chi}\left(x\right)' + assert latex(jacobi(n, a, b, x)) == \ + r'P_{n}^{\left(a,b\right)}\left(x\right)' + assert latex(jacobi(n, a, b, x)**2) == \ + r'\left(P_{n}^{\left(a,b\right)}\left(x\right)\right)^{2}' + assert latex(gegenbauer(n, a, x)) == \ + r'C_{n}^{\left(a\right)}\left(x\right)' + assert latex(gegenbauer(n, a, x)**2) == \ + r'\left(C_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(chebyshevt(n, x)) == r'T_{n}\left(x\right)' + assert latex(chebyshevt(n, x)**2) == \ + r'\left(T_{n}\left(x\right)\right)^{2}' + assert latex(chebyshevu(n, x)) == r'U_{n}\left(x\right)' + assert latex(chebyshevu(n, x)**2) == \ + r'\left(U_{n}\left(x\right)\right)^{2}' + assert latex(legendre(n, x)) == r'P_{n}\left(x\right)' + assert latex(legendre(n, x)**2) == r'\left(P_{n}\left(x\right)\right)^{2}' + assert latex(assoc_legendre(n, a, x)) == \ + r'P_{n}^{\left(a\right)}\left(x\right)' + assert latex(assoc_legendre(n, a, x)**2) == \ + r'\left(P_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(laguerre(n, x)) == r'L_{n}\left(x\right)' + assert latex(laguerre(n, x)**2) == r'\left(L_{n}\left(x\right)\right)^{2}' + assert latex(assoc_laguerre(n, a, x)) == \ + r'L_{n}^{\left(a\right)}\left(x\right)' + assert latex(assoc_laguerre(n, a, x)**2) == \ + r'\left(L_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(hermite(n, x)) == r'H_{n}\left(x\right)' + assert latex(hermite(n, x)**2) == r'\left(H_{n}\left(x\right)\right)^{2}' + + theta = Symbol("theta", real=True) + phi = Symbol("phi", real=True) + assert latex(Ynm(n, m, theta, phi)) == r'Y_{n}^{m}\left(\theta,\phi\right)' + assert latex(Ynm(n, m, theta, phi)**3) == \ + r'\left(Y_{n}^{m}\left(\theta,\phi\right)\right)^{3}' + assert latex(Znm(n, m, theta, phi)) == r'Z_{n}^{m}\left(\theta,\phi\right)' + assert latex(Znm(n, m, theta, phi)**3) == \ + r'\left(Z_{n}^{m}\left(\theta,\phi\right)\right)^{3}' + + # Test latex printing of function names with "_" + assert latex(polar_lift(0)) == \ + r"\operatorname{polar\_lift}{\left(0 \right)}" + assert latex(polar_lift(0)**3) == \ + r"\operatorname{polar\_lift}^{3}{\left(0 \right)}" + + assert latex(totient(n)) == r'\phi\left(n\right)' + assert latex(totient(n) ** 2) == r'\left(\phi\left(n\right)\right)^{2}' + + assert latex(reduced_totient(n)) == r'\lambda\left(n\right)' + assert latex(reduced_totient(n) ** 2) == \ + r'\left(\lambda\left(n\right)\right)^{2}' + + assert latex(divisor_sigma(x)) == r"\sigma\left(x\right)" + assert latex(divisor_sigma(x)**2) == r"\sigma^{2}\left(x\right)" + assert latex(divisor_sigma(x, y)) == r"\sigma_y\left(x\right)" + assert latex(divisor_sigma(x, y)**2) == r"\sigma^{2}_y\left(x\right)" + + assert latex(udivisor_sigma(x)) == r"\sigma^*\left(x\right)" + assert latex(udivisor_sigma(x)**2) == r"\sigma^*^{2}\left(x\right)" + assert latex(udivisor_sigma(x, y)) == r"\sigma^*_y\left(x\right)" + assert latex(udivisor_sigma(x, y)**2) == r"\sigma^*^{2}_y\left(x\right)" + + assert latex(primenu(n)) == r'\nu\left(n\right)' + assert latex(primenu(n) ** 2) == r'\left(\nu\left(n\right)\right)^{2}' + + assert latex(primeomega(n)) == r'\Omega\left(n\right)' + assert latex(primeomega(n) ** 2) == \ + r'\left(\Omega\left(n\right)\right)^{2}' + + assert latex(LambertW(n)) == r'W\left(n\right)' + assert latex(LambertW(n, -1)) == r'W_{-1}\left(n\right)' + assert latex(LambertW(n, k)) == r'W_{k}\left(n\right)' + assert latex(LambertW(n) * LambertW(n)) == r"W^{2}\left(n\right)" + assert latex(Pow(LambertW(n), 2)) == r"W^{2}\left(n\right)" + assert latex(LambertW(n)**k) == r"W^{k}\left(n\right)" + assert latex(LambertW(n, k)**p) == r"W^{p}_{k}\left(n\right)" + + assert latex(Mod(x, 7)) == r'x \bmod 7' + assert latex(Mod(x + 1, 7)) == r'\left(x + 1\right) \bmod 7' + assert latex(Mod(7, x + 1)) == r'7 \bmod \left(x + 1\right)' + assert latex(Mod(2 * x, 7)) == r'2 x \bmod 7' + assert latex(Mod(7, 2 * x)) == r'7 \bmod 2 x' + assert latex(Mod(x, 7) + 1) == r'\left(x \bmod 7\right) + 1' + assert latex(2 * Mod(x, 7)) == r'2 \left(x \bmod 7\right)' + assert latex(Mod(7, 2 * x)**n) == r'\left(7 \bmod 2 x\right)^{n}' + + # some unknown function name should get rendered with \operatorname + fjlkd = Function('fjlkd') + assert latex(fjlkd(x)) == r'\operatorname{fjlkd}{\left(x \right)}' + # even when it is referred to without an argument + assert latex(fjlkd) == r'\operatorname{fjlkd}' + + +# test that notation passes to subclasses of the same name only +def test_function_subclass_different_name(): + class mygamma(gamma): + pass + assert latex(mygamma) == r"\operatorname{mygamma}" + assert latex(mygamma(x)) == r"\operatorname{mygamma}{\left(x \right)}" + + +def test_hyper_printing(): + from sympy.abc import x, z + + assert latex(meijerg(Tuple(pi, pi, x), Tuple(1), + (0, 1), Tuple(1, 2, 3/pi), z)) == \ + r'{G_{4, 5}^{2, 3}\left(\begin{matrix} \pi, \pi, x & 1 \\0, 1 & 1, 2, '\ + r'\frac{3}{\pi} \end{matrix} \middle| {z} \right)}' + assert latex(meijerg(Tuple(), Tuple(1), (0,), Tuple(), z)) == \ + r'{G_{1, 1}^{1, 0}\left(\begin{matrix} & 1 \\0 & \end{matrix} \middle| {z} \right)}' + assert latex(hyper((x, 2), (3,), z)) == \ + r'{{}_{2}F_{1}\left(\begin{matrix} 2, x ' \ + r'\\ 3 \end{matrix}\middle| {z} \right)}' + assert latex(hyper(Tuple(), Tuple(1), z)) == \ + r'{{}_{0}F_{1}\left(\begin{matrix} ' \ + r'\\ 1 \end{matrix}\middle| {z} \right)}' + + +def test_latex_bessel(): + from sympy.functions.special.bessel import (besselj, bessely, besseli, + besselk, hankel1, hankel2, + jn, yn, hn1, hn2) + from sympy.abc import z + assert latex(besselj(n, z**2)**k) == r'J^{k}_{n}\left(z^{2}\right)' + assert latex(bessely(n, z)) == r'Y_{n}\left(z\right)' + assert latex(besseli(n, z)) == r'I_{n}\left(z\right)' + assert latex(besselk(n, z)) == r'K_{n}\left(z\right)' + assert latex(hankel1(n, z**2)**2) == \ + r'\left(H^{(1)}_{n}\left(z^{2}\right)\right)^{2}' + assert latex(hankel2(n, z)) == r'H^{(2)}_{n}\left(z\right)' + assert latex(jn(n, z)) == r'j_{n}\left(z\right)' + assert latex(yn(n, z)) == r'y_{n}\left(z\right)' + assert latex(hn1(n, z)) == r'h^{(1)}_{n}\left(z\right)' + assert latex(hn2(n, z)) == r'h^{(2)}_{n}\left(z\right)' + + +def test_latex_fresnel(): + from sympy.functions.special.error_functions import (fresnels, fresnelc) + from sympy.abc import z + assert latex(fresnels(z)) == r'S\left(z\right)' + assert latex(fresnelc(z)) == r'C\left(z\right)' + assert latex(fresnels(z)**2) == r'S^{2}\left(z\right)' + assert latex(fresnelc(z)**2) == r'C^{2}\left(z\right)' + + +def test_latex_brackets(): + assert latex((-1)**x) == r"\left(-1\right)^{x}" + + +def test_latex_indexed(): + Psi_symbol = Symbol('Psi_0', complex=True, real=False) + Psi_indexed = IndexedBase(Symbol('Psi', complex=True, real=False)) + symbol_latex = latex(Psi_symbol * conjugate(Psi_symbol)) + indexed_latex = latex(Psi_indexed[0] * conjugate(Psi_indexed[0])) + # \\overline{{\\Psi}_{0}} {\\Psi}_{0} vs. \\Psi_{0} \\overline{\\Psi_{0}} + assert symbol_latex == r'\Psi_{0} \overline{\Psi_{0}}' + assert indexed_latex == r'\overline{{\Psi}_{0}} {\Psi}_{0}' + + # Symbol('gamma') gives r'\gamma' + interval = '\\mathrel{..}\\nobreak ' + assert latex(Indexed('x1', Symbol('i'))) == r'{x_{1}}_{i}' + assert latex(Indexed('x2', Idx('i'))) == r'{x_{2}}_{i}' + assert latex(Indexed('x3', Idx('i', Symbol('N')))) == r'{x_{3}}_{{i}_{0'+interval+'N - 1}}' + assert latex(Indexed('x3', Idx('i', Symbol('N')+1))) == r'{x_{3}}_{{i}_{0'+interval+'N}}' + assert latex(Indexed('x4', Idx('i', (Symbol('a'),Symbol('b'))))) == r'{x_{4}}_{{i}_{a'+interval+'b}}' + assert latex(IndexedBase('gamma')) == r'\gamma' + assert latex(IndexedBase('a b')) == r'a b' + assert latex(IndexedBase('a_b')) == r'a_{b}' + + +def test_latex_derivatives(): + # regular "d" for ordinary derivatives + assert latex(diff(x**3, x, evaluate=False)) == \ + r"\frac{d}{d x} x^{3}" + assert latex(diff(sin(x) + x**2, x, evaluate=False)) == \ + r"\frac{d}{d x} \left(x^{2} + \sin{\left(x \right)}\right)" + assert latex(diff(diff(sin(x) + x**2, x, evaluate=False), evaluate=False))\ + == \ + r"\frac{d^{2}}{d x^{2}} \left(x^{2} + \sin{\left(x \right)}\right)" + assert latex(diff(diff(diff(sin(x) + x**2, x, evaluate=False), evaluate=False), evaluate=False)) == \ + r"\frac{d^{3}}{d x^{3}} \left(x^{2} + \sin{\left(x \right)}\right)" + + # \partial for partial derivatives + assert latex(diff(sin(x * y), x, evaluate=False)) == \ + r"\frac{\partial}{\partial x} \sin{\left(x y \right)}" + assert latex(diff(sin(x * y) + x**2, x, evaluate=False)) == \ + r"\frac{\partial}{\partial x} \left(x^{2} + \sin{\left(x y \right)}\right)" + assert latex(diff(diff(sin(x*y) + x**2, x, evaluate=False), x, evaluate=False)) == \ + r"\frac{\partial^{2}}{\partial x^{2}} \left(x^{2} + \sin{\left(x y \right)}\right)" + assert latex(diff(diff(diff(sin(x*y) + x**2, x, evaluate=False), x, evaluate=False), x, evaluate=False)) == \ + r"\frac{\partial^{3}}{\partial x^{3}} \left(x^{2} + \sin{\left(x y \right)}\right)" + + # mixed partial derivatives + f = Function("f") + assert latex(diff(diff(f(x, y), x, evaluate=False), y, evaluate=False)) == \ + r"\frac{\partial^{2}}{\partial y\partial x} " + latex(f(x, y)) + + assert latex(diff(diff(diff(f(x, y), x, evaluate=False), x, evaluate=False), y, evaluate=False)) == \ + r"\frac{\partial^{3}}{\partial y\partial x^{2}} " + latex(f(x, y)) + + # for negative nested Derivative + assert latex(diff(-diff(y**2,x,evaluate=False),x,evaluate=False)) == r'\frac{d}{d x} \left(- \frac{d}{d x} y^{2}\right)' + assert latex(diff(diff(-diff(diff(y,x,evaluate=False),x,evaluate=False),x,evaluate=False),x,evaluate=False)) == \ + r'\frac{d^{2}}{d x^{2}} \left(- \frac{d^{2}}{d x^{2}} y\right)' + + # use ordinary d when one of the variables has been integrated out + assert latex(diff(Integral(exp(-x*y), (x, 0, oo)), y, evaluate=False)) == \ + r"\frac{d}{d y} \int\limits_{0}^{\infty} e^{- x y}\, dx" + + # Derivative wrapped in power: + assert latex(diff(x, x, evaluate=False)**2) == \ + r"\left(\frac{d}{d x} x\right)^{2}" + + assert latex(diff(f(x), x)**2) == \ + r"\left(\frac{d}{d x} f{\left(x \right)}\right)^{2}" + + assert latex(diff(f(x), (x, n))) == \ + r"\frac{d^{n}}{d x^{n}} f{\left(x \right)}" + + x1 = Symbol('x1') + x2 = Symbol('x2') + assert latex(diff(f(x1, x2), x1)) == r'\frac{\partial}{\partial x_{1}} f{\left(x_{1},x_{2} \right)}' + + n1 = Symbol('n1') + assert latex(diff(f(x), (x, n1))) == r'\frac{d^{n_{1}}}{d x^{n_{1}}} f{\left(x \right)}' + + n2 = Symbol('n2') + assert latex(diff(f(x), (x, Max(n1, n2)))) == \ + r'\frac{d^{\max\left(n_{1}, n_{2}\right)}}{d x^{\max\left(n_{1}, n_{2}\right)}} f{\left(x \right)}' + + # set diff operator + assert latex(diff(f(x), x), diff_operator="rd") == r'\frac{\mathrm{d}}{\mathrm{d} x} f{\left(x \right)}' + + +def test_latex_subs(): + assert latex(Subs(x*y, (x, y), (1, 2))) == r'\left. x y \right|_{\substack{ x=1\\ y=2 }}' + + +def test_latex_integrals(): + assert latex(Integral(log(x), x)) == r"\int \log{\left(x \right)}\, dx" + assert latex(Integral(x**2, (x, 0, 1))) == \ + r"\int\limits_{0}^{1} x^{2}\, dx" + assert latex(Integral(x**2, (x, 10, 20))) == \ + r"\int\limits_{10}^{20} x^{2}\, dx" + assert latex(Integral(y*x**2, (x, 0, 1), y)) == \ + r"\int\int\limits_{0}^{1} x^{2} y\, dx\, dy" + assert latex(Integral(y*x**2, (x, 0, 1), y), mode='equation*') == \ + r"\begin{equation*}\int\int\limits_{0}^{1} x^{2} y\, dx\, dy\end{equation*}" + assert latex(Integral(y*x**2, (x, 0, 1), y), mode='equation*', itex=True) \ + == r"$$\int\int_{0}^{1} x^{2} y\, dx\, dy$$" + assert latex(Integral(x, (x, 0))) == r"\int\limits^{0} x\, dx" + assert latex(Integral(x*y, x, y)) == r"\iint x y\, dx\, dy" + assert latex(Integral(x*y*z, x, y, z)) == r"\iiint x y z\, dx\, dy\, dz" + assert latex(Integral(x*y*z*t, x, y, z, t)) == \ + r"\iiiint t x y z\, dx\, dy\, dz\, dt" + assert latex(Integral(x, x, x, x, x, x, x)) == \ + r"\int\int\int\int\int\int x\, dx\, dx\, dx\, dx\, dx\, dx" + assert latex(Integral(x, x, y, (z, 0, 1))) == \ + r"\int\limits_{0}^{1}\int\int x\, dx\, dy\, dz" + + # for negative nested Integral + assert latex(Integral(-Integral(y**2,x),x)) == \ + r'\int \left(- \int y^{2}\, dx\right)\, dx' + assert latex(Integral(-Integral(-Integral(y,x),x),x)) == \ + r'\int \left(- \int \left(- \int y\, dx\right)\, dx\right)\, dx' + + # fix issue #10806 + assert latex(Integral(z, z)**2) == r"\left(\int z\, dz\right)^{2}" + assert latex(Integral(x + z, z)) == r"\int \left(x + z\right)\, dz" + assert latex(Integral(x+z/2, z)) == \ + r"\int \left(x + \frac{z}{2}\right)\, dz" + assert latex(Integral(x**y, z)) == r"\int x^{y}\, dz" + + # set diff operator + assert latex(Integral(x, x), diff_operator="rd") == r'\int x\, \mathrm{d}x' + assert latex(Integral(x, (x, 0, 1)), diff_operator="rd") == r'\int\limits_{0}^{1} x\, \mathrm{d}x' + + +def test_latex_sets(): + for s in (frozenset, set): + assert latex(s([x*y, x**2])) == r"\left\{x^{2}, x y\right\}" + assert latex(s(range(1, 6))) == r"\left\{1, 2, 3, 4, 5\right\}" + assert latex(s(range(1, 13))) == \ + r"\left\{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12\right\}" + + s = FiniteSet + assert latex(s(*[x*y, x**2])) == r"\left\{x^{2}, x y\right\}" + assert latex(s(*range(1, 6))) == r"\left\{1, 2, 3, 4, 5\right\}" + assert latex(s(*range(1, 13))) == \ + r"\left\{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12\right\}" + + +def test_latex_SetExpr(): + iv = Interval(1, 3) + se = SetExpr(iv) + assert latex(se) == r"SetExpr\left(\left[1, 3\right]\right)" + + +def test_latex_Range(): + assert latex(Range(1, 51)) == r'\left\{1, 2, \ldots, 50\right\}' + assert latex(Range(1, 4)) == r'\left\{1, 2, 3\right\}' + assert latex(Range(0, 3, 1)) == r'\left\{0, 1, 2\right\}' + assert latex(Range(0, 30, 1)) == r'\left\{0, 1, \ldots, 29\right\}' + assert latex(Range(30, 1, -1)) == r'\left\{30, 29, \ldots, 2\right\}' + assert latex(Range(0, oo, 2)) == r'\left\{0, 2, \ldots\right\}' + assert latex(Range(oo, -2, -2)) == r'\left\{\ldots, 2, 0\right\}' + assert latex(Range(-2, -oo, -1)) == r'\left\{-2, -3, \ldots\right\}' + assert latex(Range(-oo, oo)) == r'\left\{\ldots, -1, 0, 1, \ldots\right\}' + assert latex(Range(oo, -oo, -1)) == r'\left\{\ldots, 1, 0, -1, \ldots\right\}' + + a, b, c = symbols('a:c') + assert latex(Range(a, b, c)) == r'\text{Range}\left(a, b, c\right)' + assert latex(Range(a, 10, 1)) == r'\text{Range}\left(a, 10\right)' + assert latex(Range(0, b, 1)) == r'\text{Range}\left(b\right)' + assert latex(Range(0, 10, c)) == r'\text{Range}\left(0, 10, c\right)' + + i = Symbol('i', integer=True) + n = Symbol('n', negative=True, integer=True) + p = Symbol('p', positive=True, integer=True) + + assert latex(Range(i, i + 3)) == r'\left\{i, i + 1, i + 2\right\}' + assert latex(Range(-oo, n, 2)) == r'\left\{\ldots, n - 4, n - 2\right\}' + assert latex(Range(p, oo)) == r'\left\{p, p + 1, \ldots\right\}' + # The following will work if __iter__ is improved + # assert latex(Range(-3, p + 7)) == r'\left\{-3, -2, \ldots, p + 6\right\}' + # Must have integer assumptions + assert latex(Range(a, a + 3)) == r'\text{Range}\left(a, a + 3\right)' + + +def test_latex_sequences(): + s1 = SeqFormula(a**2, (0, oo)) + s2 = SeqPer((1, 2)) + + latex_str = r'\left[0, 1, 4, 9, \ldots\right]' + assert latex(s1) == latex_str + + latex_str = r'\left[1, 2, 1, 2, \ldots\right]' + assert latex(s2) == latex_str + + s3 = SeqFormula(a**2, (0, 2)) + s4 = SeqPer((1, 2), (0, 2)) + + latex_str = r'\left[0, 1, 4\right]' + assert latex(s3) == latex_str + + latex_str = r'\left[1, 2, 1\right]' + assert latex(s4) == latex_str + + s5 = SeqFormula(a**2, (-oo, 0)) + s6 = SeqPer((1, 2), (-oo, 0)) + + latex_str = r'\left[\ldots, 9, 4, 1, 0\right]' + assert latex(s5) == latex_str + + latex_str = r'\left[\ldots, 2, 1, 2, 1\right]' + assert latex(s6) == latex_str + + latex_str = r'\left[1, 3, 5, 11, \ldots\right]' + assert latex(SeqAdd(s1, s2)) == latex_str + + latex_str = r'\left[1, 3, 5\right]' + assert latex(SeqAdd(s3, s4)) == latex_str + + latex_str = r'\left[\ldots, 11, 5, 3, 1\right]' + assert latex(SeqAdd(s5, s6)) == latex_str + + latex_str = r'\left[0, 2, 4, 18, \ldots\right]' + assert latex(SeqMul(s1, s2)) == latex_str + + latex_str = r'\left[0, 2, 4\right]' + assert latex(SeqMul(s3, s4)) == latex_str + + latex_str = r'\left[\ldots, 18, 4, 2, 0\right]' + assert latex(SeqMul(s5, s6)) == latex_str + + # Sequences with symbolic limits, issue 12629 + s7 = SeqFormula(a**2, (a, 0, x)) + latex_str = r'\left\{a^{2}\right\}_{a=0}^{x}' + assert latex(s7) == latex_str + + b = Symbol('b') + s8 = SeqFormula(b*a**2, (a, 0, 2)) + latex_str = r'\left[0, b, 4 b\right]' + assert latex(s8) == latex_str + + +def test_latex_FourierSeries(): + latex_str = \ + r'2 \sin{\left(x \right)} - \sin{\left(2 x \right)} + \frac{2 \sin{\left(3 x \right)}}{3} + \ldots' + assert latex(fourier_series(x, (x, -pi, pi))) == latex_str + + +def test_latex_FormalPowerSeries(): + latex_str = r'\sum_{k=1}^{\infty} - \frac{\left(-1\right)^{- k} x^{k}}{k}' + assert latex(fps(log(1 + x))) == latex_str + + +def test_latex_intervals(): + a = Symbol('a', real=True) + assert latex(Interval(0, 0)) == r"\left\{0\right\}" + assert latex(Interval(0, a)) == r"\left[0, a\right]" + assert latex(Interval(0, a, False, False)) == r"\left[0, a\right]" + assert latex(Interval(0, a, True, False)) == r"\left(0, a\right]" + assert latex(Interval(0, a, False, True)) == r"\left[0, a\right)" + assert latex(Interval(0, a, True, True)) == r"\left(0, a\right)" + + +def test_latex_AccumuBounds(): + a = Symbol('a', real=True) + assert latex(AccumBounds(0, 1)) == r"\left\langle 0, 1\right\rangle" + assert latex(AccumBounds(0, a)) == r"\left\langle 0, a\right\rangle" + assert latex(AccumBounds(a + 1, a + 2)) == \ + r"\left\langle a + 1, a + 2\right\rangle" + + +def test_latex_emptyset(): + assert latex(S.EmptySet) == r"\emptyset" + + +def test_latex_universalset(): + assert latex(S.UniversalSet) == r"\mathbb{U}" + + +def test_latex_commutator(): + A = Operator('A') + B = Operator('B') + comm = Commutator(B, A) + assert latex(comm.doit()) == r"- (A B - B A)" + + +def test_latex_union(): + assert latex(Union(Interval(0, 1), Interval(2, 3))) == \ + r"\left[0, 1\right] \cup \left[2, 3\right]" + assert latex(Union(Interval(1, 1), Interval(2, 2), Interval(3, 4))) == \ + r"\left\{1, 2\right\} \cup \left[3, 4\right]" + + +def test_latex_intersection(): + assert latex(Intersection(Interval(0, 1), Interval(x, y))) == \ + r"\left[0, 1\right] \cap \left[x, y\right]" + + +def test_latex_symmetric_difference(): + assert latex(SymmetricDifference(Interval(2, 5), Interval(4, 7), + evaluate=False)) == \ + r'\left[2, 5\right] \triangle \left[4, 7\right]' + + +def test_latex_Complement(): + assert latex(Complement(S.Reals, S.Naturals)) == \ + r"\mathbb{R} \setminus \mathbb{N}" + + +def test_latex_productset(): + line = Interval(0, 1) + bigline = Interval(0, 10) + fset = FiniteSet(1, 2, 3) + assert latex(line**2) == r"%s^{2}" % latex(line) + assert latex(line**10) == r"%s^{10}" % latex(line) + assert latex((line * bigline * fset).flatten()) == r"%s \times %s \times %s" % ( + latex(line), latex(bigline), latex(fset)) + + +def test_latex_powerset(): + fset = FiniteSet(1, 2, 3) + assert latex(PowerSet(fset)) == r'\mathcal{P}\left(\left\{1, 2, 3\right\}\right)' + + +def test_latex_ordinals(): + w = OrdinalOmega() + assert latex(w) == r"\omega" + wp = OmegaPower(2, 3) + assert latex(wp) == r'3 \omega^{2}' + assert latex(Ordinal(wp, OmegaPower(1, 1))) == r'3 \omega^{2} + \omega' + assert latex(Ordinal(OmegaPower(2, 1), OmegaPower(1, 2))) == r'\omega^{2} + 2 \omega' + + +def test_set_operators_parenthesis(): + a, b, c, d = symbols('a:d') + A = FiniteSet(a) + B = FiniteSet(b) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(A, B, evaluate=False) + U2 = Union(C, D, evaluate=False) + I1 = Intersection(A, B, evaluate=False) + I2 = Intersection(C, D, evaluate=False) + C1 = Complement(A, B, evaluate=False) + C2 = Complement(C, D, evaluate=False) + D1 = SymmetricDifference(A, B, evaluate=False) + D2 = SymmetricDifference(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(A, B) + P2 = ProductSet(C, D) + + assert latex(Intersection(A, U2, evaluate=False)) == \ + r'\left\{a\right\} \cap ' \ + r'\left(\left\{c\right\} \cup \left\{d\right\}\right)' + assert latex(Intersection(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\cap \left(\left\{c\right\} \cup \left\{d\right\}\right)' + assert latex(Intersection(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \cap \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Intersection(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \cap \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(Intersection(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\cap \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + assert latex(Union(A, I2, evaluate=False)) == \ + r'\left\{a\right\} \cup ' \ + r'\left(\left\{c\right\} \cap \left\{d\right\}\right)' + assert latex(Union(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\cup \left(\left\{c\right\} \cap \left\{d\right\}\right)' + assert latex(Union(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \cup \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Union(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \cup \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(Union(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\cup \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + assert latex(Complement(A, C2, evaluate=False)) == \ + r'\left\{a\right\} \setminus \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Complement(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\setminus \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(Complement(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\setminus \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(Complement(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \setminus ' \ + r'\left(\left\{c\right\} \triangle \left\{d\right\}\right)' + assert latex(Complement(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) '\ + r'\setminus \left(\left\{c\right\} \times '\ + r'\left\{d\right\}\right)' + + assert latex(SymmetricDifference(A, D2, evaluate=False)) == \ + r'\left\{a\right\} \triangle \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(SymmetricDifference(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(SymmetricDifference(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(SymmetricDifference(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \triangle ' \ + r'\left(\left\{c\right\} \setminus \left\{d\right\}\right)' + assert latex(SymmetricDifference(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + # XXX This can be incorrect since cartesian product is not associative + assert latex(ProductSet(A, P2).flatten()) == \ + r'\left\{a\right\} \times \left\{c\right\} \times ' \ + r'\left\{d\right\}' + assert latex(ProductSet(U1, U2)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\times \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(ProductSet(I1, I2)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\times \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(ProductSet(C1, C2)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \times \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(ProductSet(D1, D2)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \times \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + + +def test_latex_Complexes(): + assert latex(S.Complexes) == r"\mathbb{C}" + + +def test_latex_Naturals(): + assert latex(S.Naturals) == r"\mathbb{N}" + + +def test_latex_Naturals0(): + assert latex(S.Naturals0) == r"\mathbb{N}_0" + + +def test_latex_Integers(): + assert latex(S.Integers) == r"\mathbb{Z}" + + +def test_latex_ImageSet(): + x = Symbol('x') + assert latex(ImageSet(Lambda(x, x**2), S.Naturals)) == \ + r"\left\{x^{2}\; \middle|\; x \in \mathbb{N}\right\}" + + y = Symbol('y') + imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4}) + assert latex(imgset) == \ + r"\left\{x + y\; \middle|\; x \in \left\{1, 2, 3\right\}, y \in \left\{3, 4\right\}\right\}" + + imgset = ImageSet(Lambda(((x, y),), x + y), ProductSet({1, 2, 3}, {3, 4})) + assert latex(imgset) == \ + r"\left\{x + y\; \middle|\; \left( x, \ y\right) \in \left\{1, 2, 3\right\} \times \left\{3, 4\right\}\right\}" + + +def test_latex_ConditionSet(): + x = Symbol('x') + assert latex(ConditionSet(x, Eq(x**2, 1), S.Reals)) == \ + r"\left\{x\; \middle|\; x \in \mathbb{R} \wedge x^{2} = 1 \right\}" + assert latex(ConditionSet(x, Eq(x**2, 1), S.UniversalSet)) == \ + r"\left\{x\; \middle|\; x^{2} = 1 \right\}" + + +def test_latex_ComplexRegion(): + assert latex(ComplexRegion(Interval(3, 5)*Interval(4, 6))) == \ + r"\left\{x + y i\; \middle|\; x, y \in \left[3, 5\right] \times \left[4, 6\right] \right\}" + assert latex(ComplexRegion(Interval(0, 1)*Interval(0, 2*pi), polar=True)) == \ + r"\left\{r \left(i \sin{\left(\theta \right)} + \cos{\left(\theta "\ + r"\right)}\right)\; \middle|\; r, \theta \in \left[0, 1\right] \times \left[0, 2 \pi\right) \right\}" + + +def test_latex_Contains(): + x = Symbol('x') + assert latex(Contains(x, S.Naturals)) == r"x \in \mathbb{N}" + + +def test_latex_sum(): + assert latex(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + r"\sum_{\substack{-2 \leq x \leq 2\\-5 \leq y \leq 5}} x y^{2}" + assert latex(Sum(x**2, (x, -2, 2))) == \ + r"\sum_{x=-2}^{2} x^{2}" + assert latex(Sum(x**2 + y, (x, -2, 2))) == \ + r"\sum_{x=-2}^{2} \left(x^{2} + y\right)" + assert latex(Sum(x**2 + y, (x, -2, 2))**2) == \ + r"\left(\sum_{x=-2}^{2} \left(x^{2} + y\right)\right)^{2}" + + +def test_latex_product(): + assert latex(Product(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + r"\prod_{\substack{-2 \leq x \leq 2\\-5 \leq y \leq 5}} x y^{2}" + assert latex(Product(x**2, (x, -2, 2))) == \ + r"\prod_{x=-2}^{2} x^{2}" + assert latex(Product(x**2 + y, (x, -2, 2))) == \ + r"\prod_{x=-2}^{2} \left(x^{2} + y\right)" + + assert latex(Product(x, (x, -2, 2))**2) == \ + r"\left(\prod_{x=-2}^{2} x\right)^{2}" + + +def test_latex_limits(): + assert latex(Limit(x, x, oo)) == r"\lim_{x \to \infty} x" + + # issue 8175 + f = Function('f') + assert latex(Limit(f(x), x, 0)) == r"\lim_{x \to 0^+} f{\left(x \right)}" + assert latex(Limit(f(x), x, 0, "-")) == \ + r"\lim_{x \to 0^-} f{\left(x \right)}" + + # issue #10806 + assert latex(Limit(f(x), x, 0)**2) == \ + r"\left(\lim_{x \to 0^+} f{\left(x \right)}\right)^{2}" + # bi-directional limit + assert latex(Limit(f(x), x, 0, dir='+-')) == \ + r"\lim_{x \to 0} f{\left(x \right)}" + + +def test_latex_log(): + assert latex(log(x)) == r"\log{\left(x \right)}" + assert latex(log(x), ln_notation=True) == r"\ln{\left(x \right)}" + assert latex(log(x) + log(y)) == \ + r"\log{\left(x \right)} + \log{\left(y \right)}" + assert latex(log(x) + log(y), ln_notation=True) == \ + r"\ln{\left(x \right)} + \ln{\left(y \right)}" + assert latex(pow(log(x), x)) == r"\log{\left(x \right)}^{x}" + assert latex(pow(log(x), x), ln_notation=True) == \ + r"\ln{\left(x \right)}^{x}" + + +def test_issue_3568(): + beta = Symbol(r'\beta') + y = beta + x + assert latex(y) in [r'\beta + x', r'x + \beta'] + + beta = Symbol(r'beta') + y = beta + x + assert latex(y) in [r'\beta + x', r'x + \beta'] + + +def test_latex(): + assert latex((2*tau)**Rational(7, 2)) == r"8 \sqrt{2} \tau^{\frac{7}{2}}" + assert latex((2*mu)**Rational(7, 2), mode='equation*') == \ + r"\begin{equation*}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation*}" + assert latex((2*mu)**Rational(7, 2), mode='equation', itex=True) == \ + r"$$8 \sqrt{2} \mu^{\frac{7}{2}}$$" + assert latex([2/x, y]) == r"\left[ \frac{2}{x}, \ y\right]" + + +def test_latex_dict(): + d = {Rational(1): 1, x**2: 2, x: 3, x**3: 4} + assert latex(d) == \ + r'\left\{ 1 : 1, \ x : 3, \ x^{2} : 2, \ x^{3} : 4\right\}' + D = Dict(d) + assert latex(D) == \ + r'\left\{ 1 : 1, \ x : 3, \ x^{2} : 2, \ x^{3} : 4\right\}' + + +def test_latex_list(): + ll = [Symbol('omega1'), Symbol('a'), Symbol('alpha')] + assert latex(ll) == r'\left[ \omega_{1}, \ a, \ \alpha\right]' + + +def test_latex_NumberSymbols(): + assert latex(S.Catalan) == "G" + assert latex(S.EulerGamma) == r"\gamma" + assert latex(S.Exp1) == "e" + assert latex(S.GoldenRatio) == r"\phi" + assert latex(S.Pi) == r"\pi" + assert latex(S.TribonacciConstant) == r"\text{TribonacciConstant}" + + +def test_latex_rational(): + # tests issue 3973 + assert latex(-Rational(1, 2)) == r"- \frac{1}{2}" + assert latex(Rational(-1, 2)) == r"- \frac{1}{2}" + assert latex(Rational(1, -2)) == r"- \frac{1}{2}" + assert latex(-Rational(-1, 2)) == r"\frac{1}{2}" + assert latex(-Rational(1, 2)*x) == r"- \frac{x}{2}" + assert latex(-Rational(1, 2)*x + Rational(-2, 3)*y) == \ + r"- \frac{x}{2} - \frac{2 y}{3}" + + +def test_latex_inverse(): + # tests issue 4129 + assert latex(1/x) == r"\frac{1}{x}" + assert latex(1/(x + y)) == r"\frac{1}{x + y}" + + +def test_latex_DiracDelta(): + assert latex(DiracDelta(x)) == r"\delta\left(x\right)" + assert latex(DiracDelta(x)**2) == r"\left(\delta\left(x\right)\right)^{2}" + assert latex(DiracDelta(x, 0)) == r"\delta\left(x\right)" + assert latex(DiracDelta(x, 5)) == \ + r"\delta^{\left( 5 \right)}\left( x \right)" + assert latex(DiracDelta(x, 5)**2) == \ + r"\left(\delta^{\left( 5 \right)}\left( x \right)\right)^{2}" + + +def test_latex_Heaviside(): + assert latex(Heaviside(x)) == r"\theta\left(x\right)" + assert latex(Heaviside(x)**2) == r"\left(\theta\left(x\right)\right)^{2}" + + +def test_latex_KroneckerDelta(): + assert latex(KroneckerDelta(x, y)) == r"\delta_{x y}" + assert latex(KroneckerDelta(x, y + 1)) == r"\delta_{x, y + 1}" + # issue 6578 + assert latex(KroneckerDelta(x + 1, y)) == r"\delta_{y, x + 1}" + assert latex(Pow(KroneckerDelta(x, y), 2, evaluate=False)) == \ + r"\left(\delta_{x y}\right)^{2}" + + +def test_latex_LeviCivita(): + assert latex(LeviCivita(x, y, z)) == r"\varepsilon_{x y z}" + assert latex(LeviCivita(x, y, z)**2) == \ + r"\left(\varepsilon_{x y z}\right)^{2}" + assert latex(LeviCivita(x, y, z + 1)) == r"\varepsilon_{x, y, z + 1}" + assert latex(LeviCivita(x, y + 1, z)) == r"\varepsilon_{x, y + 1, z}" + assert latex(LeviCivita(x + 1, y, z)) == r"\varepsilon_{x + 1, y, z}" + + +def test_mode(): + expr = x + y + assert latex(expr) == r'x + y' + assert latex(expr, mode='plain') == r'x + y' + assert latex(expr, mode='inline') == r'$x + y$' + assert latex( + expr, mode='equation*') == r'\begin{equation*}x + y\end{equation*}' + assert latex( + expr, mode='equation') == r'\begin{equation}x + y\end{equation}' + raises(ValueError, lambda: latex(expr, mode='foo')) + + +def test_latex_mathieu(): + assert latex(mathieuc(x, y, z)) == r"C\left(x, y, z\right)" + assert latex(mathieus(x, y, z)) == r"S\left(x, y, z\right)" + assert latex(mathieuc(x, y, z)**2) == r"C\left(x, y, z\right)^{2}" + assert latex(mathieus(x, y, z)**2) == r"S\left(x, y, z\right)^{2}" + assert latex(mathieucprime(x, y, z)) == r"C^{\prime}\left(x, y, z\right)" + assert latex(mathieusprime(x, y, z)) == r"S^{\prime}\left(x, y, z\right)" + assert latex(mathieucprime(x, y, z)**2) == r"C^{\prime}\left(x, y, z\right)^{2}" + assert latex(mathieusprime(x, y, z)**2) == r"S^{\prime}\left(x, y, z\right)^{2}" + +def test_latex_Piecewise(): + p = Piecewise((x, x < 1), (x**2, True)) + assert latex(p) == r"\begin{cases} x & \text{for}\: x < 1 \\x^{2} &" \ + r" \text{otherwise} \end{cases}" + assert latex(p, itex=True) == \ + r"\begin{cases} x & \text{for}\: x \lt 1 \\x^{2} &" \ + r" \text{otherwise} \end{cases}" + p = Piecewise((x, x < 0), (0, x >= 0)) + assert latex(p) == r'\begin{cases} x & \text{for}\: x < 0 \\0 &' \ + r' \text{otherwise} \end{cases}' + A, B = symbols("A B", commutative=False) + p = Piecewise((A**2, Eq(A, B)), (A*B, True)) + s = r"\begin{cases} A^{2} & \text{for}\: A = B \\A B & \text{otherwise} \end{cases}" + assert latex(p) == s + assert latex(A*p) == r"A \left(%s\right)" % s + assert latex(p*A) == r"\left(%s\right) A" % s + assert latex(Piecewise((x, x < 1), (x**2, x < 2))) == \ + r'\begin{cases} x & ' \ + r'\text{for}\: x < 1 \\x^{2} & \text{for}\: x < 2 \end{cases}' + + +def test_latex_Matrix(): + M = Matrix([[1 + x, y], [y, x - 1]]) + assert latex(M) == \ + r'\left[\begin{matrix}x + 1 & y\\y & x - 1\end{matrix}\right]' + assert latex(M, mode='inline') == \ + r'$\left[\begin{smallmatrix}x + 1 & y\\' \ + r'y & x - 1\end{smallmatrix}\right]$' + assert latex(M, mat_str='array') == \ + r'\left[\begin{array}{cc}x + 1 & y\\y & x - 1\end{array}\right]' + assert latex(M, mat_str='bmatrix') == \ + r'\left[\begin{bmatrix}x + 1 & y\\y & x - 1\end{bmatrix}\right]' + assert latex(M, mat_delim=None, mat_str='bmatrix') == \ + r'\begin{bmatrix}x + 1 & y\\y & x - 1\end{bmatrix}' + + M2 = Matrix(1, 11, range(11)) + assert latex(M2) == \ + r'\left[\begin{array}{ccccccccccc}' \ + r'0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}\right]' + + +def test_latex_matrix_with_functions(): + t = symbols('t') + theta1 = symbols('theta1', cls=Function) + + M = Matrix([[sin(theta1(t)), cos(theta1(t))], + [cos(theta1(t).diff(t)), sin(theta1(t).diff(t))]]) + + expected = (r'\left[\begin{matrix}\sin{\left(' + r'\theta_{1}{\left(t \right)} \right)} & ' + r'\cos{\left(\theta_{1}{\left(t \right)} \right)' + r'}\\\cos{\left(\frac{d}{d t} \theta_{1}{\left(t ' + r'\right)} \right)} & \sin{\left(\frac{d}{d t} ' + r'\theta_{1}{\left(t \right)} \right' + r')}\end{matrix}\right]') + + assert latex(M) == expected + + +def test_latex_NDimArray(): + x, y, z, w = symbols("x y z w") + + for ArrayType in (ImmutableDenseNDimArray, ImmutableSparseNDimArray, + MutableDenseNDimArray, MutableSparseNDimArray): + # Basic: scalar array + M = ArrayType(x) + + assert latex(M) == r"x" + + M = ArrayType([[1 / x, y], [z, w]]) + M1 = ArrayType([1 / x, y, z]) + + M2 = tensorproduct(M1, M) + M3 = tensorproduct(M, M) + + assert latex(M) == \ + r'\left[\begin{matrix}\frac{1}{x} & y\\z & w\end{matrix}\right]' + assert latex(M1) == \ + r"\left[\begin{matrix}\frac{1}{x} & y & z\end{matrix}\right]" + assert latex(M2) == \ + r"\left[\begin{matrix}" \ + r"\left[\begin{matrix}\frac{1}{x^{2}} & \frac{y}{x}\\\frac{z}{x} & \frac{w}{x}\end{matrix}\right] & " \ + r"\left[\begin{matrix}\frac{y}{x} & y^{2}\\y z & w y\end{matrix}\right] & " \ + r"\left[\begin{matrix}\frac{z}{x} & y z\\z^{2} & w z\end{matrix}\right]" \ + r"\end{matrix}\right]" + assert latex(M3) == \ + r"""\left[\begin{matrix}"""\ + r"""\left[\begin{matrix}\frac{1}{x^{2}} & \frac{y}{x}\\\frac{z}{x} & \frac{w}{x}\end{matrix}\right] & """\ + r"""\left[\begin{matrix}\frac{y}{x} & y^{2}\\y z & w y\end{matrix}\right]\\"""\ + r"""\left[\begin{matrix}\frac{z}{x} & y z\\z^{2} & w z\end{matrix}\right] & """\ + r"""\left[\begin{matrix}\frac{w}{x} & w y\\w z & w^{2}\end{matrix}\right]"""\ + r"""\end{matrix}\right]""" + + Mrow = ArrayType([[x, y, 1/z]]) + Mcolumn = ArrayType([[x], [y], [1/z]]) + Mcol2 = ArrayType([Mcolumn.tolist()]) + + assert latex(Mrow) == \ + r"\left[\left[\begin{matrix}x & y & \frac{1}{z}\end{matrix}\right]\right]" + assert latex(Mcolumn) == \ + r"\left[\begin{matrix}x\\y\\\frac{1}{z}\end{matrix}\right]" + assert latex(Mcol2) == \ + r'\left[\begin{matrix}\left[\begin{matrix}x\\y\\\frac{1}{z}\end{matrix}\right]\end{matrix}\right]' + + +def test_latex_mul_symbol(): + assert latex(4*4**x, mul_symbol='times') == r"4 \times 4^{x}" + assert latex(4*4**x, mul_symbol='dot') == r"4 \cdot 4^{x}" + assert latex(4*4**x, mul_symbol='ldot') == r"4 \,.\, 4^{x}" + + assert latex(4*x, mul_symbol='times') == r"4 \times x" + assert latex(4*x, mul_symbol='dot') == r"4 \cdot x" + assert latex(4*x, mul_symbol='ldot') == r"4 \,.\, x" + + +def test_latex_issue_4381(): + y = 4*4**log(2) + assert latex(y) == r'4 \cdot 4^{\log{\left(2 \right)}}' + assert latex(1/y) == r'\frac{1}{4 \cdot 4^{\log{\left(2 \right)}}}' + + +def test_latex_issue_4576(): + assert latex(Symbol("beta_13_2")) == r"\beta_{13 2}" + assert latex(Symbol("beta_132_20")) == r"\beta_{132 20}" + assert latex(Symbol("beta_13")) == r"\beta_{13}" + assert latex(Symbol("x_a_b")) == r"x_{a b}" + assert latex(Symbol("x_1_2_3")) == r"x_{1 2 3}" + assert latex(Symbol("x_a_b1")) == r"x_{a b1}" + assert latex(Symbol("x_a_1")) == r"x_{a 1}" + assert latex(Symbol("x_1_a")) == r"x_{1 a}" + assert latex(Symbol("x_1^aa")) == r"x^{aa}_{1}" + assert latex(Symbol("x_1__aa")) == r"x^{aa}_{1}" + assert latex(Symbol("x_11^a")) == r"x^{a}_{11}" + assert latex(Symbol("x_11__a")) == r"x^{a}_{11}" + assert latex(Symbol("x_a_a_a_a")) == r"x_{a a a a}" + assert latex(Symbol("x_a_a^a^a")) == r"x^{a a}_{a a}" + assert latex(Symbol("x_a_a__a__a")) == r"x^{a a}_{a a}" + assert latex(Symbol("alpha_11")) == r"\alpha_{11}" + assert latex(Symbol("alpha_11_11")) == r"\alpha_{11 11}" + assert latex(Symbol("alpha_alpha")) == r"\alpha_{\alpha}" + assert latex(Symbol("alpha^aleph")) == r"\alpha^{\aleph}" + assert latex(Symbol("alpha__aleph")) == r"\alpha^{\aleph}" + + +def test_latex_pow_fraction(): + x = Symbol('x') + # Testing exp + assert r'e^{-x}' in latex(exp(-x)/2).replace(' ', '') # Remove Whitespace + + # Testing e^{-x} in case future changes alter behavior of muls or fracs + # In particular current output is \frac{1}{2}e^{- x} but perhaps this will + # change to \frac{e^{-x}}{2} + + # Testing general, non-exp, power + assert r'3^{-x}' in latex(3**-x/2).replace(' ', '') + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + assert latex(A*B*C**-1) == r"A B C^{-1}" + assert latex(C**-1*A*B) == r"C^{-1} A B" + assert latex(A*C**-1*B) == r"A C^{-1} B" + + +def test_latex_order(): + expr = x**3 + x**2*y + y**4 + 3*x*y**3 + + assert latex(expr, order='lex') == r"x^{3} + x^{2} y + 3 x y^{3} + y^{4}" + assert latex( + expr, order='rev-lex') == r"y^{4} + 3 x y^{3} + x^{2} y + x^{3}" + assert latex(expr, order='none') == r"x^{3} + y^{4} + y x^{2} + 3 x y^{3}" + + +def test_latex_Lambda(): + assert latex(Lambda(x, x + 1)) == r"\left( x \mapsto x + 1 \right)" + assert latex(Lambda((x, y), x + 1)) == r"\left( \left( x, \ y\right) \mapsto x + 1 \right)" + assert latex(Lambda(x, x)) == r"\left( x \mapsto x \right)" + +def test_latex_PolyElement(): + Ruv, u, v = ring("u,v", ZZ) + Rxyz, x, y, z = ring("x,y,z", Ruv) + + assert latex(x - x) == r"0" + assert latex(x - 1) == r"x - 1" + assert latex(x + 1) == r"x + 1" + + assert latex((u**2 + 3*u*v + 1)*x**2*y + u + 1) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + u + 1" + assert latex((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + \left(u + 1\right) x" + assert latex((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + \left(u + 1\right) x + 1" + assert latex((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == \ + r"-\left({u}^{2} - 3 u v + 1\right) {x}^{2} y - \left(u + 1\right) x - 1" + + assert latex(-(v**2 + v + 1)*x + 3*u*v + 1) == \ + r"-\left({v}^{2} + v + 1\right) x + 3 u v + 1" + assert latex(-(v**2 + v + 1)*x - 3*u*v + 1) == \ + r"-\left({v}^{2} + v + 1\right) x - 3 u v + 1" + + +def test_latex_FracElement(): + Fuv, u, v = field("u,v", ZZ) + Fxyzt, x, y, z, t = field("x,y,z,t", Fuv) + + assert latex(x - x) == r"0" + assert latex(x - 1) == r"x - 1" + assert latex(x + 1) == r"x + 1" + + assert latex(x/3) == r"\frac{x}{3}" + assert latex(x/z) == r"\frac{x}{z}" + assert latex(x*y/z) == r"\frac{x y}{z}" + assert latex(x/(z*t)) == r"\frac{x}{z t}" + assert latex(x*y/(z*t)) == r"\frac{x y}{z t}" + + assert latex((x - 1)/y) == r"\frac{x - 1}{y}" + assert latex((x + 1)/y) == r"\frac{x + 1}{y}" + assert latex((-x - 1)/y) == r"\frac{-x - 1}{y}" + assert latex((x + 1)/(y*z)) == r"\frac{x + 1}{y z}" + assert latex(-y/(x + 1)) == r"\frac{-y}{x + 1}" + assert latex(y*z/(x + 1)) == r"\frac{y z}{x + 1}" + + assert latex(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == \ + r"\frac{\left(u + 1\right) x y + 1}{\left(v - 1\right) z - 1}" + assert latex(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == \ + r"\frac{\left(u + 1\right) x y + 1}{\left(v - 1\right) z - u v t - 1}" + + +def test_latex_Poly(): + assert latex(Poly(x**2 + 2 * x, x)) == \ + r"\operatorname{Poly}{\left( x^{2} + 2 x, x, domain=\mathbb{Z} \right)}" + assert latex(Poly(x/y, x)) == \ + r"\operatorname{Poly}{\left( \frac{1}{y} x, x, domain=\mathbb{Z}\left(y\right) \right)}" + assert latex(Poly(2.0*x + y)) == \ + r"\operatorname{Poly}{\left( 2.0 x + 1.0 y, x, y, domain=\mathbb{R} \right)}" + + +def test_latex_Poly_order(): + assert latex(Poly([a, 1, b, 2, c, 3], x)) == \ + r'\operatorname{Poly}{\left( a x^{5} + x^{4} + b x^{3} + 2 x^{2} + c'\ + r' x + 3, x, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + assert latex(Poly([a, 1, b+c, 2, 3], x)) == \ + r'\operatorname{Poly}{\left( a x^{4} + x^{3} + \left(b + c\right) '\ + r'x^{2} + 2 x + 3, x, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + assert latex(Poly(a*x**3 + x**2*y - x*y - c*y**3 - b*x*y**2 + y - a*x + b, + (x, y))) == \ + r'\operatorname{Poly}{\left( a x^{3} + x^{2}y - b xy^{2} - xy - '\ + r'a x - c y^{3} + y + b, x, y, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + + +def test_latex_ComplexRootOf(): + assert latex(rootof(x**5 + x + 3, 0)) == \ + r"\operatorname{CRootOf} {\left(x^{5} + x + 3, 0\right)}" + + +def test_latex_RootSum(): + assert latex(RootSum(x**5 + x + 3, sin)) == \ + r"\operatorname{RootSum} {\left(x^{5} + x + 3, \left( x \mapsto \sin{\left(x \right)} \right)\right)}" + + +def test_settings(): + raises(TypeError, lambda: latex(x*y, method="garbage")) + + +def test_latex_numbers(): + assert latex(catalan(n)) == r"C_{n}" + assert latex(catalan(n)**2) == r"C_{n}^{2}" + assert latex(bernoulli(n)) == r"B_{n}" + assert latex(bernoulli(n, x)) == r"B_{n}\left(x\right)" + assert latex(bernoulli(n)**2) == r"B_{n}^{2}" + assert latex(bernoulli(n, x)**2) == r"B_{n}^{2}\left(x\right)" + assert latex(genocchi(n)) == r"G_{n}" + assert latex(genocchi(n, x)) == r"G_{n}\left(x\right)" + assert latex(genocchi(n)**2) == r"G_{n}^{2}" + assert latex(genocchi(n, x)**2) == r"G_{n}^{2}\left(x\right)" + assert latex(bell(n)) == r"B_{n}" + assert latex(bell(n, x)) == r"B_{n}\left(x\right)" + assert latex(bell(n, m, (x, y))) == r"B_{n, m}\left(x, y\right)" + assert latex(bell(n)**2) == r"B_{n}^{2}" + assert latex(bell(n, x)**2) == r"B_{n}^{2}\left(x\right)" + assert latex(bell(n, m, (x, y))**2) == r"B_{n, m}^{2}\left(x, y\right)" + assert latex(fibonacci(n)) == r"F_{n}" + assert latex(fibonacci(n, x)) == r"F_{n}\left(x\right)" + assert latex(fibonacci(n)**2) == r"F_{n}^{2}" + assert latex(fibonacci(n, x)**2) == r"F_{n}^{2}\left(x\right)" + assert latex(lucas(n)) == r"L_{n}" + assert latex(lucas(n)**2) == r"L_{n}^{2}" + assert latex(tribonacci(n)) == r"T_{n}" + assert latex(tribonacci(n, x)) == r"T_{n}\left(x\right)" + assert latex(tribonacci(n)**2) == r"T_{n}^{2}" + assert latex(tribonacci(n, x)**2) == r"T_{n}^{2}\left(x\right)" + assert latex(mobius(n)) == r"\mu\left(n\right)" + assert latex(mobius(n)**2) == r"\mu^{2}\left(n\right)" + + +def test_latex_euler(): + assert latex(euler(n)) == r"E_{n}" + assert latex(euler(n, x)) == r"E_{n}\left(x\right)" + assert latex(euler(n, x)**2) == r"E_{n}^{2}\left(x\right)" + + +def test_lamda(): + assert latex(Symbol('lamda')) == r"\lambda" + assert latex(Symbol('Lamda')) == r"\Lambda" + + +def test_custom_symbol_names(): + x = Symbol('x') + y = Symbol('y') + assert latex(x) == r"x" + assert latex(x, symbol_names={x: "x_i"}) == r"x_i" + assert latex(x + y, symbol_names={x: "x_i"}) == r"x_i + y" + assert latex(x**2, symbol_names={x: "x_i"}) == r"x_i^{2}" + assert latex(x + y, symbol_names={x: "x_i", y: "y_j"}) == r"x_i + y_j" + + +def test_matAdd(): + C = MatrixSymbol('C', 5, 5) + B = MatrixSymbol('B', 5, 5) + + n = symbols("n") + h = MatrixSymbol("h", 1, 1) + + assert latex(C - 2*B) in [r'- 2 B + C', r'C -2 B'] + assert latex(C + 2*B) in [r'2 B + C', r'C + 2 B'] + assert latex(B - 2*C) in [r'B - 2 C', r'- 2 C + B'] + assert latex(B + 2*C) in [r'B + 2 C', r'2 C + B'] + + assert latex(n * h - (-h + h.T) * (h + h.T)) == 'n h - \\left(- h + h^{T}\\right) \\left(h + h^{T}\\right)' + assert latex(MatAdd(MatAdd(h, h), MatAdd(h, h))) == '\\left(h + h\\right) + \\left(h + h\\right)' + assert latex(MatMul(MatMul(h, h), MatMul(h, h))) == '\\left(h h\\right) \\left(h h\\right)' + + +def test_matMul(): + A = MatrixSymbol('A', 5, 5) + B = MatrixSymbol('B', 5, 5) + x = Symbol('x') + assert latex(2*A) == r'2 A' + assert latex(2*x*A) == r'2 x A' + assert latex(-2*A) == r'- 2 A' + assert latex(1.5*A) == r'1.5 A' + assert latex(sqrt(2)*A) == r'\sqrt{2} A' + assert latex(-sqrt(2)*A) == r'- \sqrt{2} A' + assert latex(2*sqrt(2)*x*A) == r'2 \sqrt{2} x A' + assert latex(-2*A*(A + 2*B)) in [r'- 2 A \left(A + 2 B\right)', + r'- 2 A \left(2 B + A\right)'] + + +def test_latex_MatrixSlice(): + n = Symbol('n', integer=True) + x, y, z, w, t, = symbols('x y z w t') + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + assert latex(MatrixSlice(X, (None, None, None), (None, None, None))) == r'X\left[:, :\right]' + assert latex(X[x:x + 1, y:y + 1]) == r'X\left[x:x + 1, y:y + 1\right]' + assert latex(X[x:x + 1:2, y:y + 1:2]) == r'X\left[x:x + 1:2, y:y + 1:2\right]' + assert latex(X[:x, y:]) == r'X\left[:x, y:\right]' + assert latex(X[:x, y:]) == r'X\left[:x, y:\right]' + assert latex(X[x:, :y]) == r'X\left[x:, :y\right]' + assert latex(X[x:y, z:w]) == r'X\left[x:y, z:w\right]' + assert latex(X[x:y:t, w:t:x]) == r'X\left[x:y:t, w:t:x\right]' + assert latex(X[x::y, t::w]) == r'X\left[x::y, t::w\right]' + assert latex(X[:x:y, :t:w]) == r'X\left[:x:y, :t:w\right]' + assert latex(X[::x, ::y]) == r'X\left[::x, ::y\right]' + assert latex(MatrixSlice(X, (0, None, None), (0, None, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (None, n, None), (None, n, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (0, n, None), (0, n, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (0, n, 2), (0, n, 2))) == r'X\left[::2, ::2\right]' + assert latex(X[1:2:3, 4:5:6]) == r'X\left[1:2:3, 4:5:6\right]' + assert latex(X[1:3:5, 4:6:8]) == r'X\left[1:3:5, 4:6:8\right]' + assert latex(X[1:10:2]) == r'X\left[1:10:2, :\right]' + assert latex(Y[:5, 1:9:2]) == r'Y\left[:5, 1:9:2\right]' + assert latex(Y[:5, 1:10:2]) == r'Y\left[:5, 1::2\right]' + assert latex(Y[5, :5:2]) == r'Y\left[5:6, :5:2\right]' + assert latex(X[0:1, 0:1]) == r'X\left[:1, :1\right]' + assert latex(X[0:1:2, 0:1:2]) == r'X\left[:1:2, :1:2\right]' + assert latex((Y + Z)[2:, 2:]) == r'\left(Y + Z\right)\left[2:, 2:\right]' + + +def test_latex_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + from sympy.stats.rv import RandomDomain + + X = Normal('x1', 0, 1) + assert latex(where(X > 0)) == r"\text{Domain: }0 < x_{1} \wedge x_{1} < \infty" + + D = Die('d1', 6) + assert latex(where(D > 4)) == r"\text{Domain: }d_{1} = 5 \vee d_{1} = 6" + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert latex( + pspace(Tuple(A, B)).domain) == \ + r"\text{Domain: }0 \leq a \wedge 0 \leq b \wedge a < \infty \wedge b < \infty" + + assert latex(RandomDomain(FiniteSet(x), FiniteSet(1, 2))) == \ + r'\text{Domain: }\left\{x\right\} \in \left\{1, 2\right\}' + +def test_PrettyPoly(): + from sympy.polys.domains import QQ + F = QQ.frac_field(x, y) + R = QQ[x, y] + + assert latex(F.convert(x/(x + y))) == latex(x/(x + y)) + assert latex(R.convert(x + y)) == latex(x + y) + + +def test_integral_transforms(): + x = Symbol("x") + k = Symbol("k") + f = Function("f") + a = Symbol("a") + b = Symbol("b") + + assert latex(MellinTransform(f(x), x, k)) == \ + r"\mathcal{M}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseMellinTransform(f(k), k, x, a, b)) == \ + r"\mathcal{M}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(LaplaceTransform(f(x), x, k)) == \ + r"\mathcal{L}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseLaplaceTransform(f(k), k, x, (a, b))) == \ + r"\mathcal{L}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(FourierTransform(f(x), x, k)) == \ + r"\mathcal{F}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseFourierTransform(f(k), k, x)) == \ + r"\mathcal{F}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(CosineTransform(f(x), x, k)) == \ + r"\mathcal{COS}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseCosineTransform(f(k), k, x)) == \ + r"\mathcal{COS}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(SineTransform(f(x), x, k)) == \ + r"\mathcal{SIN}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseSineTransform(f(k), k, x)) == \ + r"\mathcal{SIN}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + +def test_PolynomialRingBase(): + from sympy.polys.domains import QQ + assert latex(QQ.old_poly_ring(x, y)) == r"\mathbb{Q}\left[x, y\right]" + assert latex(QQ.old_poly_ring(x, y, order="ilex")) == \ + r"S_<^{-1}\mathbb{Q}\left[x, y\right]" + + +def test_categories(): + from sympy.categories import (Object, IdentityMorphism, + NamedMorphism, Category, Diagram, + DiagramGrid) + + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A2, A3, "f2") + id_A1 = IdentityMorphism(A1) + + K1 = Category("K1") + + assert latex(A1) == r"A_{1}" + assert latex(f1) == r"f_{1}:A_{1}\rightarrow A_{2}" + assert latex(id_A1) == r"id:A_{1}\rightarrow A_{1}" + assert latex(f2*f1) == r"f_{2}\circ f_{1}:A_{1}\rightarrow A_{3}" + + assert latex(K1) == r"\mathbf{K_{1}}" + + d = Diagram() + assert latex(d) == r"\emptyset" + + d = Diagram({f1: "unique", f2: S.EmptySet}) + assert latex(d) == r"\left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \emptyset, \ id:A_{1}\rightarrow " \ + r"A_{1} : \emptyset, \ id:A_{2}\rightarrow A_{2} : " \ + r"\emptyset, \ id:A_{3}\rightarrow A_{3} : \emptyset, " \ + r"\ f_{1}:A_{1}\rightarrow A_{2} : \left\{unique\right\}, " \ + r"\ f_{2}:A_{2}\rightarrow A_{3} : \emptyset\right\}" + + d = Diagram({f1: "unique", f2: S.EmptySet}, {f2 * f1: "unique"}) + assert latex(d) == r"\left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \emptyset, \ id:A_{1}\rightarrow " \ + r"A_{1} : \emptyset, \ id:A_{2}\rightarrow A_{2} : " \ + r"\emptyset, \ id:A_{3}\rightarrow A_{3} : \emptyset, " \ + r"\ f_{1}:A_{1}\rightarrow A_{2} : \left\{unique\right\}," \ + r" \ f_{2}:A_{2}\rightarrow A_{3} : \emptyset\right\}" \ + r"\Longrightarrow \left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \left\{unique\right\}\right\}" + + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert latex(grid) == r"\begin{array}{cc}" + "\n" \ + r"A & B \\" + "\n" \ + r" & C " + "\n" \ + r"\end{array}" + "\n" + + +def test_Modules(): + from sympy.polys.domains import QQ + from sympy.polys.agca import homomorphism + + R = QQ.old_poly_ring(x, y) + F = R.free_module(2) + M = F.submodule([x, y], [1, x**2]) + + assert latex(F) == r"{\mathbb{Q}\left[x, y\right]}^{2}" + assert latex(M) == \ + r"\left\langle {\left[ {x},{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle" + + I = R.ideal(x**2, y) + assert latex(I) == r"\left\langle {x^{2}},{y} \right\rangle" + + Q = F / M + assert latex(Q) == \ + r"\frac{{\mathbb{Q}\left[x, y\right]}^{2}}{\left\langle {\left[ {x},"\ + r"{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle}" + assert latex(Q.submodule([1, x**3/2], [2, y])) == \ + r"\left\langle {{\left[ {1},{\frac{x^{3}}{2}} \right]} + {\left"\ + r"\langle {\left[ {x},{y} \right]},{\left[ {1},{x^{2}} \right]} "\ + r"\right\rangle}},{{\left[ {2},{y} \right]} + {\left\langle {\left[ "\ + r"{x},{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle}} \right\rangle" + + h = homomorphism(QQ.old_poly_ring(x).free_module(2), + QQ.old_poly_ring(x).free_module(2), [0, 0]) + + assert latex(h) == \ + r"{\left[\begin{matrix}0 & 0\\0 & 0\end{matrix}\right]} : "\ + r"{{\mathbb{Q}\left[x\right]}^{2}} \to {{\mathbb{Q}\left[x\right]}^{2}}" + + +def test_QuotientRing(): + from sympy.polys.domains import QQ + R = QQ.old_poly_ring(x)/[x**2 + 1] + + assert latex(R) == \ + r"\frac{\mathbb{Q}\left[x\right]}{\left\langle {x^{2} + 1} \right\rangle}" + assert latex(R.one) == r"{1} + {\left\langle {x^{2} + 1} \right\rangle}" + + +def test_Tr(): + #TODO: Handle indices + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert latex(t) == r'\operatorname{tr}\left(A B\right)' + + +def test_Determinant(): + from sympy.matrices import Determinant, Inverse, BlockMatrix, OneMatrix, ZeroMatrix + m = Matrix(((1, 2), (3, 4))) + assert latex(Determinant(m)) == '\\left|{\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}}\\right|' + assert latex(Determinant(Inverse(m))) == \ + '\\left|{\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{-1}}\\right|' + X = MatrixSymbol('X', 2, 2) + assert latex(Determinant(X)) == '\\left|{X}\\right|' + assert latex(Determinant(X + m)) == \ + '\\left|{\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X}\\right|' + assert latex(Determinant(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left|{\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}}\\right|' + + +def test_Adjoint(): + from sympy.matrices import Adjoint, Inverse, Transpose + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(Adjoint(X)) == r'X^{\dagger}' + assert latex(Adjoint(X + Y)) == r'\left(X + Y\right)^{\dagger}' + assert latex(Adjoint(X) + Adjoint(Y)) == r'X^{\dagger} + Y^{\dagger}' + assert latex(Adjoint(X*Y)) == r'\left(X Y\right)^{\dagger}' + assert latex(Adjoint(Y)*Adjoint(X)) == r'Y^{\dagger} X^{\dagger}' + assert latex(Adjoint(X**2)) == r'\left(X^{2}\right)^{\dagger}' + assert latex(Adjoint(X)**2) == r'\left(X^{\dagger}\right)^{2}' + assert latex(Adjoint(Inverse(X))) == r'\left(X^{-1}\right)^{\dagger}' + assert latex(Inverse(Adjoint(X))) == r'\left(X^{\dagger}\right)^{-1}' + assert latex(Adjoint(Transpose(X))) == r'\left(X^{T}\right)^{\dagger}' + assert latex(Transpose(Adjoint(X))) == r'\left(X^{\dagger}\right)^{T}' + assert latex(Transpose(Adjoint(X) + Y)) == r'\left(X^{\dagger} + Y\right)^{T}' + m = Matrix(((1, 2), (3, 4))) + assert latex(Adjoint(m)) == '\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{\\dagger}' + assert latex(Adjoint(m+X)) == \ + '\\left(\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X\\right)^{\\dagger}' + from sympy.matrices import BlockMatrix, OneMatrix, ZeroMatrix + assert latex(Adjoint(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left[\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}\\right]^{\\dagger}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(Adjoint(Mx)) == r'\left(M^{x}\right)^{\dagger}' + + # adjoint style + assert latex(Adjoint(X), adjoint_style="star") == r'X^{\ast}' + assert latex(Adjoint(X + Y), adjoint_style="hermitian") == r'\left(X + Y\right)^{\mathsf{H}}' + assert latex(Adjoint(X) + Adjoint(Y), adjoint_style="dagger") == r'X^{\dagger} + Y^{\dagger}' + assert latex(Adjoint(Y)*Adjoint(X)) == r'Y^{\dagger} X^{\dagger}' + assert latex(Adjoint(X**2), adjoint_style="star") == r'\left(X^{2}\right)^{\ast}' + assert latex(Adjoint(X)**2, adjoint_style="hermitian") == r'\left(X^{\mathsf{H}}\right)^{2}' + +def test_Transpose(): + from sympy.matrices import Transpose, MatPow, HadamardPower + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(Transpose(X)) == r'X^{T}' + assert latex(Transpose(X + Y)) == r'\left(X + Y\right)^{T}' + + assert latex(Transpose(HadamardPower(X, 2))) == r'\left(X^{\circ {2}}\right)^{T}' + assert latex(HadamardPower(Transpose(X), 2)) == r'\left(X^{T}\right)^{\circ {2}}' + assert latex(Transpose(MatPow(X, 2))) == r'\left(X^{2}\right)^{T}' + assert latex(MatPow(Transpose(X), 2)) == r'\left(X^{T}\right)^{2}' + m = Matrix(((1, 2), (3, 4))) + assert latex(Transpose(m)) == '\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{T}' + assert latex(Transpose(m+X)) == \ + '\\left(\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X\\right)^{T}' + from sympy.matrices import BlockMatrix, OneMatrix, ZeroMatrix + assert latex(Transpose(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left[\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}\\right]^{T}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(Transpose(Mx)) == r'\left(M^{x}\right)^{T}' + + +def test_Hadamard(): + from sympy.matrices import HadamardProduct, HadamardPower + from sympy.matrices.expressions import MatAdd, MatMul, MatPow + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(HadamardProduct(X, Y*Y)) == r'X \circ Y^{2}' + assert latex(HadamardProduct(X, Y)*Y) == r'\left(X \circ Y\right) Y' + + assert latex(HadamardPower(X, 2)) == r'X^{\circ {2}}' + assert latex(HadamardPower(X, -1)) == r'X^{\circ \left({-1}\right)}' + assert latex(HadamardPower(MatAdd(X, Y), 2)) == \ + r'\left(X + Y\right)^{\circ {2}}' + assert latex(HadamardPower(MatMul(X, Y), 2)) == \ + r'\left(X Y\right)^{\circ {2}}' + + assert latex(HadamardPower(MatPow(X, -1), -1)) == \ + r'\left(X^{-1}\right)^{\circ \left({-1}\right)}' + assert latex(MatPow(HadamardPower(X, -1), -1)) == \ + r'\left(X^{\circ \left({-1}\right)}\right)^{-1}' + + assert latex(HadamardPower(X, n+1)) == \ + r'X^{\circ \left({n + 1}\right)}' + + +def test_MatPow(): + from sympy.matrices.expressions import MatPow + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(MatPow(X, 2)) == 'X^{2}' + assert latex(MatPow(X*X, 2)) == '\\left(X^{2}\\right)^{2}' + assert latex(MatPow(X*Y, 2)) == '\\left(X Y\\right)^{2}' + assert latex(MatPow(X + Y, 2)) == '\\left(X + Y\\right)^{2}' + assert latex(MatPow(X + X, 2)) == '\\left(2 X\\right)^{2}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(MatPow(Mx, 2)) == r'\left(M^{x}\right)^{2}' + + +def test_ElementwiseApplyFunction(): + X = MatrixSymbol('X', 2, 2) + expr = (X.T*X).applyfunc(sin) + assert latex(expr) == r"{\left( d \mapsto \sin{\left(d \right)} \right)}_{\circ}\left({X^{T} X}\right)" + expr = X.applyfunc(Lambda(x, 1/x)) + assert latex(expr) == r'{\left( x \mapsto \frac{1}{x} \right)}_{\circ}\left({X}\right)' + + +def test_ZeroMatrix(): + from sympy.matrices.expressions.special import ZeroMatrix + assert latex(ZeroMatrix(1, 1), mat_symbol_style='plain') == r"0" + assert latex(ZeroMatrix(1, 1), mat_symbol_style='bold') == r"\mathbf{0}" + + +def test_OneMatrix(): + from sympy.matrices.expressions.special import OneMatrix + assert latex(OneMatrix(3, 4), mat_symbol_style='plain') == r"1" + assert latex(OneMatrix(3, 4), mat_symbol_style='bold') == r"\mathbf{1}" + + +def test_Identity(): + from sympy.matrices.expressions.special import Identity + assert latex(Identity(1), mat_symbol_style='plain') == r"\mathbb{I}" + assert latex(Identity(1), mat_symbol_style='bold') == r"\mathbf{I}" + + +def test_latex_DFT_IDFT(): + from sympy.matrices.expressions.fourier import DFT, IDFT + assert latex(DFT(13)) == r"\text{DFT}_{13}" + assert latex(IDFT(x)) == r"\text{IDFT}_{x}" + + +def test_boolean_args_order(): + syms = symbols('a:f') + + expr = And(*syms) + assert latex(expr) == r'a \wedge b \wedge c \wedge d \wedge e \wedge f' + + expr = Or(*syms) + assert latex(expr) == r'a \vee b \vee c \vee d \vee e \vee f' + + expr = Equivalent(*syms) + assert latex(expr) == \ + r'a \Leftrightarrow b \Leftrightarrow c \Leftrightarrow d \Leftrightarrow e \Leftrightarrow f' + + expr = Xor(*syms) + assert latex(expr) == \ + r'a \veebar b \veebar c \veebar d \veebar e \veebar f' + + +def test_imaginary(): + i = sqrt(-1) + assert latex(i) == r'i' + + +def test_builtins_without_args(): + assert latex(sin) == r'\sin' + assert latex(cos) == r'\cos' + assert latex(tan) == r'\tan' + assert latex(log) == r'\log' + assert latex(Ei) == r'\operatorname{Ei}' + assert latex(zeta) == r'\zeta' + + +def test_latex_greek_functions(): + # bug because capital greeks that have roman equivalents should not use + # \Alpha, \Beta, \Eta, etc. + s = Function('Alpha') + assert latex(s) == r'\mathrm{A}' + assert latex(s(x)) == r'\mathrm{A}{\left(x \right)}' + s = Function('Beta') + assert latex(s) == r'\mathrm{B}' + s = Function('Eta') + assert latex(s) == r'\mathrm{H}' + assert latex(s(x)) == r'\mathrm{H}{\left(x \right)}' + + # bug because sympy.core.numbers.Pi is special + p = Function('Pi') + # assert latex(p(x)) == r'\Pi{\left(x \right)}' + assert latex(p) == r'\Pi' + + # bug because not all greeks are included + c = Function('chi') + assert latex(c(x)) == r'\chi{\left(x \right)}' + assert latex(c) == r'\chi' + + +def test_translate(): + s = 'Alpha' + assert translate(s) == r'\mathrm{A}' + s = 'Beta' + assert translate(s) == r'\mathrm{B}' + s = 'Eta' + assert translate(s) == r'\mathrm{H}' + s = 'omicron' + assert translate(s) == r'o' + s = 'Pi' + assert translate(s) == r'\Pi' + s = 'pi' + assert translate(s) == r'\pi' + s = 'LamdaHatDOT' + assert translate(s) == r'\dot{\hat{\Lambda}}' + + +def test_other_symbols(): + from sympy.printing.latex import other_symbols + for s in other_symbols: + assert latex(symbols(s)) == r"" "\\" + s + + +def test_modifiers(): + # Test each modifier individually in the simplest case + # (with funny capitalizations) + assert latex(symbols("xMathring")) == r"\mathring{x}" + assert latex(symbols("xCheck")) == r"\check{x}" + assert latex(symbols("xBreve")) == r"\breve{x}" + assert latex(symbols("xAcute")) == r"\acute{x}" + assert latex(symbols("xGrave")) == r"\grave{x}" + assert latex(symbols("xTilde")) == r"\tilde{x}" + assert latex(symbols("xPrime")) == r"{x}'" + assert latex(symbols("xddDDot")) == r"\ddddot{x}" + assert latex(symbols("xDdDot")) == r"\dddot{x}" + assert latex(symbols("xDDot")) == r"\ddot{x}" + assert latex(symbols("xBold")) == r"\boldsymbol{x}" + assert latex(symbols("xnOrM")) == r"\left\|{x}\right\|" + assert latex(symbols("xAVG")) == r"\left\langle{x}\right\rangle" + assert latex(symbols("xHat")) == r"\hat{x}" + assert latex(symbols("xDot")) == r"\dot{x}" + assert latex(symbols("xBar")) == r"\bar{x}" + assert latex(symbols("xVec")) == r"\vec{x}" + assert latex(symbols("xAbs")) == r"\left|{x}\right|" + assert latex(symbols("xMag")) == r"\left|{x}\right|" + assert latex(symbols("xPrM")) == r"{x}'" + assert latex(symbols("xBM")) == r"\boldsymbol{x}" + # Test strings that are *only* the names of modifiers + assert latex(symbols("Mathring")) == r"Mathring" + assert latex(symbols("Check")) == r"Check" + assert latex(symbols("Breve")) == r"Breve" + assert latex(symbols("Acute")) == r"Acute" + assert latex(symbols("Grave")) == r"Grave" + assert latex(symbols("Tilde")) == r"Tilde" + assert latex(symbols("Prime")) == r"Prime" + assert latex(symbols("DDot")) == r"\dot{D}" + assert latex(symbols("Bold")) == r"Bold" + assert latex(symbols("NORm")) == r"NORm" + assert latex(symbols("AVG")) == r"AVG" + assert latex(symbols("Hat")) == r"Hat" + assert latex(symbols("Dot")) == r"Dot" + assert latex(symbols("Bar")) == r"Bar" + assert latex(symbols("Vec")) == r"Vec" + assert latex(symbols("Abs")) == r"Abs" + assert latex(symbols("Mag")) == r"Mag" + assert latex(symbols("PrM")) == r"PrM" + assert latex(symbols("BM")) == r"BM" + assert latex(symbols("hbar")) == r"\hbar" + # Check a few combinations + assert latex(symbols("xvecdot")) == r"\dot{\vec{x}}" + assert latex(symbols("xDotVec")) == r"\vec{\dot{x}}" + assert latex(symbols("xHATNorm")) == r"\left\|{\hat{x}}\right\|" + # Check a couple big, ugly combinations + assert latex(symbols('xMathringBm_yCheckPRM__zbreveAbs')) == \ + r"\boldsymbol{\mathring{x}}^{\left|{\breve{z}}\right|}_{{\check{y}}'}" + assert latex(symbols('alphadothat_nVECDOT__tTildePrime')) == \ + r"\hat{\dot{\alpha}}^{{\tilde{t}}'}_{\dot{\vec{n}}}" + + +def test_greek_symbols(): + assert latex(Symbol('alpha')) == r'\alpha' + assert latex(Symbol('beta')) == r'\beta' + assert latex(Symbol('gamma')) == r'\gamma' + assert latex(Symbol('delta')) == r'\delta' + assert latex(Symbol('epsilon')) == r'\epsilon' + assert latex(Symbol('zeta')) == r'\zeta' + assert latex(Symbol('eta')) == r'\eta' + assert latex(Symbol('theta')) == r'\theta' + assert latex(Symbol('iota')) == r'\iota' + assert latex(Symbol('kappa')) == r'\kappa' + assert latex(Symbol('lambda')) == r'\lambda' + assert latex(Symbol('mu')) == r'\mu' + assert latex(Symbol('nu')) == r'\nu' + assert latex(Symbol('xi')) == r'\xi' + assert latex(Symbol('omicron')) == r'o' + assert latex(Symbol('pi')) == r'\pi' + assert latex(Symbol('rho')) == r'\rho' + assert latex(Symbol('sigma')) == r'\sigma' + assert latex(Symbol('tau')) == r'\tau' + assert latex(Symbol('upsilon')) == r'\upsilon' + assert latex(Symbol('phi')) == r'\phi' + assert latex(Symbol('chi')) == r'\chi' + assert latex(Symbol('psi')) == r'\psi' + assert latex(Symbol('omega')) == r'\omega' + + assert latex(Symbol('Alpha')) == r'\mathrm{A}' + assert latex(Symbol('Beta')) == r'\mathrm{B}' + assert latex(Symbol('Gamma')) == r'\Gamma' + assert latex(Symbol('Delta')) == r'\Delta' + assert latex(Symbol('Epsilon')) == r'\mathrm{E}' + assert latex(Symbol('Zeta')) == r'\mathrm{Z}' + assert latex(Symbol('Eta')) == r'\mathrm{H}' + assert latex(Symbol('Theta')) == r'\Theta' + assert latex(Symbol('Iota')) == r'\mathrm{I}' + assert latex(Symbol('Kappa')) == r'\mathrm{K}' + assert latex(Symbol('Lambda')) == r'\Lambda' + assert latex(Symbol('Mu')) == r'\mathrm{M}' + assert latex(Symbol('Nu')) == r'\mathrm{N}' + assert latex(Symbol('Xi')) == r'\Xi' + assert latex(Symbol('Omicron')) == r'\mathrm{O}' + assert latex(Symbol('Pi')) == r'\Pi' + assert latex(Symbol('Rho')) == r'\mathrm{P}' + assert latex(Symbol('Sigma')) == r'\Sigma' + assert latex(Symbol('Tau')) == r'\mathrm{T}' + assert latex(Symbol('Upsilon')) == r'\Upsilon' + assert latex(Symbol('Phi')) == r'\Phi' + assert latex(Symbol('Chi')) == r'\mathrm{X}' + assert latex(Symbol('Psi')) == r'\Psi' + assert latex(Symbol('Omega')) == r'\Omega' + + assert latex(Symbol('varepsilon')) == r'\varepsilon' + assert latex(Symbol('varkappa')) == r'\varkappa' + assert latex(Symbol('varphi')) == r'\varphi' + assert latex(Symbol('varpi')) == r'\varpi' + assert latex(Symbol('varrho')) == r'\varrho' + assert latex(Symbol('varsigma')) == r'\varsigma' + assert latex(Symbol('vartheta')) == r'\vartheta' + + +def test_fancyset_symbols(): + assert latex(S.Rationals) == r'\mathbb{Q}' + assert latex(S.Naturals) == r'\mathbb{N}' + assert latex(S.Naturals0) == r'\mathbb{N}_0' + assert latex(S.Integers) == r'\mathbb{Z}' + assert latex(S.Reals) == r'\mathbb{R}' + assert latex(S.Complexes) == r'\mathbb{C}' + + +@XFAIL +def test_builtin_without_args_mismatched_names(): + assert latex(CosineTransform) == r'\mathcal{COS}' + + +def test_builtin_no_args(): + assert latex(Chi) == r'\operatorname{Chi}' + assert latex(beta) == r'\operatorname{B}' + assert latex(gamma) == r'\Gamma' + assert latex(KroneckerDelta) == r'\delta' + assert latex(DiracDelta) == r'\delta' + assert latex(lowergamma) == r'\gamma' + + +def test_issue_6853(): + p = Function('Pi') + assert latex(p(x)) == r"\Pi{\left(x \right)}" + + +def test_Mul(): + e = Mul(-2, x + 1, evaluate=False) + assert latex(e) == r'- 2 \left(x + 1\right)' + e = Mul(2, x + 1, evaluate=False) + assert latex(e) == r'2 \left(x + 1\right)' + e = Mul(S.Half, x + 1, evaluate=False) + assert latex(e) == r'\frac{x + 1}{2}' + e = Mul(y, x + 1, evaluate=False) + assert latex(e) == r'y \left(x + 1\right)' + e = Mul(-y, x + 1, evaluate=False) + assert latex(e) == r'- y \left(x + 1\right)' + e = Mul(-2, x + 1) + assert latex(e) == r'- 2 x - 2' + e = Mul(2, x + 1) + assert latex(e) == r'2 x + 2' + + +def test_Pow(): + e = Pow(2, 2, evaluate=False) + assert latex(e) == r'2^{2}' + assert latex(x**(Rational(-1, 3))) == r'\frac{1}{\sqrt[3]{x}}' + x2 = Symbol(r'x^2') + assert latex(x2**2) == r'\left(x^{2}\right)^{2}' + # Issue 11011 + assert latex(S('1.453e4500')**x) == r'{1.453 \cdot 10^{4500}}^{x}' + + +def test_issue_7180(): + assert latex(Equivalent(x, y)) == r"x \Leftrightarrow y" + assert latex(Not(Equivalent(x, y))) == r"x \not\Leftrightarrow y" + + +def test_issue_8409(): + assert latex(S.Half**n) == r"\left(\frac{1}{2}\right)^{n}" + + +def test_issue_8470(): + from sympy.parsing.sympy_parser import parse_expr + e = parse_expr("-B*A", evaluate=False) + assert latex(e) == r"A \left(- B\right)" + + +def test_issue_15439(): + x = MatrixSymbol('x', 2, 2) + y = MatrixSymbol('y', 2, 2) + assert latex((x * y).subs(y, -y)) == r"x \left(- y\right)" + assert latex((x * y).subs(y, -2*y)) == r"x \left(- 2 y\right)" + assert latex((x * y).subs(x, -x)) == r"\left(- x\right) y" + + +def test_issue_2934(): + assert latex(Symbol(r'\frac{a_1}{b_1}')) == r'\frac{a_1}{b_1}' + + +def test_issue_10489(): + latexSymbolWithBrace = r'C_{x_{0}}' + s = Symbol(latexSymbolWithBrace) + assert latex(s) == latexSymbolWithBrace + assert latex(cos(s)) == r'\cos{\left(C_{x_{0}} \right)}' + + +def test_issue_12886(): + m__1, l__1 = symbols('m__1, l__1') + assert latex(m__1**2 + l__1**2) == \ + r'\left(l^{1}\right)^{2} + \left(m^{1}\right)^{2}' + + +def test_issue_13559(): + from sympy.parsing.sympy_parser import parse_expr + expr = parse_expr('5/1', evaluate=False) + assert latex(expr) == r"\frac{5}{1}" + + +def test_issue_13651(): + expr = c + Mul(-1, a + b, evaluate=False) + assert latex(expr) == r"c - \left(a + b\right)" + + +def test_latex_UnevaluatedExpr(): + x = symbols("x") + he = UnevaluatedExpr(1/x) + assert latex(he) == latex(1/x) == r"\frac{1}{x}" + assert latex(he**2) == r"\left(\frac{1}{x}\right)^{2}" + assert latex(he + 1) == r"1 + \frac{1}{x}" + assert latex(x*he) == r"x \frac{1}{x}" + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert latex(A[0, 0]) == r"{A}_{0,0}" + assert latex(3 * A[0, 0]) == r"3 {A}_{0,0}" + + F = C[0, 0].subs(C, A - B) + assert latex(F) == r"{\left(A - B\right)}_{0,0}" + + i, j, k = symbols("i j k") + M = MatrixSymbol("M", k, k) + N = MatrixSymbol("N", k, k) + assert latex((M*N)[i, j]) == \ + r'\sum_{i_{1}=0}^{k - 1} {M}_{i,i_{1}} {N}_{i_{1},j}' + + X_a = MatrixSymbol('X_a', 3, 3) + assert latex(X_a[0, 0]) == r"{X_{a}}_{0,0}" + + +def test_MatrixSymbol_printing(): + # test cases for issue #14237 + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert latex(-A) == r"- A" + assert latex(A - A*B - B) == r"A - A B - B" + assert latex(-A*B - A*B*C - B) == r"- A B - A B C - B" + + +def test_DotProduct_printing(): + X = MatrixSymbol('X', 3, 1) + Y = MatrixSymbol('Y', 3, 1) + a = Symbol('a') + assert latex(DotProduct(X, Y)) == r"X \cdot Y" + assert latex(DotProduct(a * X, Y)) == r"a X \cdot Y" + assert latex(a * DotProduct(X, Y)) == r"a \left(X \cdot Y\right)" + + +def test_KroneckerProduct_printing(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 2, 2) + assert latex(KroneckerProduct(A, B)) == r'A \otimes B' + + +def test_Series_printing(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert latex(Series(tf1, tf2)) == \ + r'\left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right) \left(\frac{x - y}{x + y}\right)' + assert latex(Series(tf1, tf2, tf3)) == \ + r'\left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right) \left(\frac{x - y}{x + y}\right) \left(\frac{t x^{2} - t^{w} x + w}{t - y}\right)' + assert latex(Series(-tf2, tf1)) == \ + r'\left(\frac{- x + y}{x + y}\right) \left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right)' + + M_1 = Matrix([[5/s], [5/(2*s)]]) + T_1 = TransferFunctionMatrix.from_Matrix(M_1, s) + M_2 = Matrix([[5, 6*s**3]]) + T_2 = TransferFunctionMatrix.from_Matrix(M_2, s) + # Brackets + assert latex(T_1*(T_2 + T_2)) == \ + r'\left[\begin{matrix}\frac{5}{s}\\\frac{5}{2 s}\end{matrix}\right]_\tau\cdot\left(\left[\begin{matrix}\frac{5}{1} &' \ + r' \frac{6 s^{3}}{1}\end{matrix}\right]_\tau + \left[\begin{matrix}\frac{5}{1} & \frac{6 s^{3}}{1}\end{matrix}\right]_\tau\right)' \ + == latex(MIMOSeries(MIMOParallel(T_2, T_2), T_1)) + # No Brackets + M_3 = Matrix([[5, 6], [6, 5/s]]) + T_3 = TransferFunctionMatrix.from_Matrix(M_3, s) + assert latex(T_1*T_2 + T_3) == r'\left[\begin{matrix}\frac{5}{s}\\\frac{5}{2 s}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}' \ + r'\frac{5}{1} & \frac{6 s^{3}}{1}\end{matrix}\right]_\tau + \left[\begin{matrix}\frac{5}{1} & \frac{6}{1}\\\frac{6}{1} & ' \ + r'\frac{5}{s}\end{matrix}\right]_\tau' == latex(MIMOParallel(MIMOSeries(T_2, T_1), T_3)) + + +def test_TransferFunction_printing(): + tf1 = TransferFunction(x - 1, x + 1, x) + assert latex(tf1) == r"\frac{x - 1}{x + 1}" + tf2 = TransferFunction(x + 1, 2 - y, x) + assert latex(tf2) == r"\frac{x + 1}{2 - y}" + tf3 = TransferFunction(y, y**2 + 2*y + 3, y) + assert latex(tf3) == r"\frac{y}{y^{2} + 2 y + 3}" + + +def test_Parallel_printing(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + assert latex(Parallel(tf1, tf2)) == \ + r'\frac{x y^{2} - z}{- t^{3} + y^{3}} + \frac{x - y}{x + y}' + assert latex(Parallel(-tf2, tf1)) == \ + r'\frac{- x + y}{x + y} + \frac{x y^{2} - z}{- t^{3} + y^{3}}' + + M_1 = Matrix([[5, 6], [6, 5/s]]) + T_1 = TransferFunctionMatrix.from_Matrix(M_1, s) + M_2 = Matrix([[5/s, 6], [6, 5/(s - 1)]]) + T_2 = TransferFunctionMatrix.from_Matrix(M_2, s) + M_3 = Matrix([[6, 5/(s*(s - 1))], [5, 6]]) + T_3 = TransferFunctionMatrix.from_Matrix(M_3, s) + assert latex(T_1 + T_2 + T_3) == r'\left[\begin{matrix}\frac{5}{1} & \frac{6}{1}\\\frac{6}{1} & \frac{5}{s}\end{matrix}\right]' \ + r'_\tau + \left[\begin{matrix}\frac{5}{s} & \frac{6}{1}\\\frac{6}{1} & \frac{5}{s - 1}\end{matrix}\right]_\tau + \left[\begin{matrix}' \ + r'\frac{6}{1} & \frac{5}{s \left(s - 1\right)}\\\frac{5}{1} & \frac{6}{1}\end{matrix}\right]_\tau' \ + == latex(MIMOParallel(T_1, T_2, T_3)) == latex(MIMOParallel(T_1, MIMOParallel(T_2, T_3))) == latex(MIMOParallel(MIMOParallel(T_1, T_2), T_3)) + + +def test_TransferFunctionMatrix_printing(): + tf1 = TransferFunction(p, p + x, p) + tf2 = TransferFunction(-s + p, p + s, p) + tf3 = TransferFunction(p, y**2 + 2*y + 3, p) + assert latex(TransferFunctionMatrix([[tf1], [tf2]])) == \ + r'\left[\begin{matrix}\frac{p}{p + x}\\\frac{p - s}{p + s}\end{matrix}\right]_\tau' + assert latex(TransferFunctionMatrix([[tf1, tf2], [tf3, -tf1]])) == \ + r'\left[\begin{matrix}\frac{p}{p + x} & \frac{p - s}{p + s}\\\frac{p}{y^{2} + 2 y + 3} & \frac{\left(-1\right) p}{p + x}\end{matrix}\right]_\tau' + + +def test_Feedback_printing(): + tf1 = TransferFunction(p, p + x, p) + tf2 = TransferFunction(-s + p, p + s, p) + # Negative Feedback (Default) + assert latex(Feedback(tf1, tf2)) == \ + r'\frac{\frac{p}{p + x}}{\frac{1}{1} + \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + assert latex(Feedback(tf1*tf2, TransferFunction(1, 1, p))) == \ + r'\frac{\left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}{\frac{1}{1} + \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + # Positive Feedback + assert latex(Feedback(tf1, tf2, 1)) == \ + r'\frac{\frac{p}{p + x}}{\frac{1}{1} - \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + assert latex(Feedback(tf1*tf2, sign=1)) == \ + r'\frac{\left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}{\frac{1}{1} - \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + + +def test_MIMOFeedback_printing(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s**2 - 1, s) + tf3 = TransferFunction(s, s - 1, s) + tf4 = TransferFunction(s**2, s**2 - 1, s) + + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm_2 = TransferFunctionMatrix([[tf4, tf3], [tf2, tf1]]) + + # Negative Feedback (Default) + assert latex(MIMOFeedback(tfm_1, tfm_2)) == \ + r'\left(I_{\tau} + \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left[' \ + r'\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1} & \frac{1}{s}\end{matrix}\right]_\tau\right)^{-1} \cdot \left[\begin{matrix}' \ + r'\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau' + + # Positive Feedback + assert latex(MIMOFeedback(tfm_1*tfm_2, tfm_1, 1)) == \ + r'\left(I_{\tau} - \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left' \ + r'[\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1} & \frac{1}{s}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}' \ + r'\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\right)^{-1} \cdot \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}' \ + r'\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1}' \ + r' & \frac{1}{s}\end{matrix}\right]_\tau' + + +def test_Quaternion_latex_printing(): + q = Quaternion(x, y, z, t) + assert latex(q) == r"x + y i + z j + t k" + q = Quaternion(x, y, z, x*t) + assert latex(q) == r"x + y i + z j + t x k" + q = Quaternion(x, y, z, x + t) + assert latex(q) == r"x + y i + z j + \left(t + x\right) k" + + +def test_TensorProduct_printing(): + from sympy.tensor.functions import TensorProduct + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + assert latex(TensorProduct(A, B)) == r"A \otimes B" + + +def test_WedgeProduct_printing(): + from sympy.diffgeom.rn import R2 + from sympy.diffgeom import WedgeProduct + wp = WedgeProduct(R2.dx, R2.dy) + assert latex(wp) == r"\operatorname{d}x \wedge \operatorname{d}y" + + +def test_issue_9216(): + expr_1 = Pow(1, -1, evaluate=False) + assert latex(expr_1) == r"1^{-1}" + + expr_2 = Pow(1, Pow(1, -1, evaluate=False), evaluate=False) + assert latex(expr_2) == r"1^{1^{-1}}" + + expr_3 = Pow(3, -2, evaluate=False) + assert latex(expr_3) == r"\frac{1}{9}" + + expr_4 = Pow(1, -2, evaluate=False) + assert latex(expr_4) == r"1^{-2}" + + +def test_latex_printer_tensor(): + from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, tensor_heads + L = TensorIndexType("L") + i, j, k, l = tensor_indices("i j k l", L) + i0 = tensor_indices("i_0", L) + A, B, C, D = tensor_heads("A B C D", [L]) + H = TensorHead("H", [L, L]) + K = TensorHead("K", [L, L, L, L]) + + assert latex(i) == r"{}^{i}" + assert latex(-i) == r"{}_{i}" + + expr = A(i) + assert latex(expr) == r"A{}^{i}" + + expr = A(i0) + assert latex(expr) == r"A{}^{i_{0}}" + + expr = A(-i) + assert latex(expr) == r"A{}_{i}" + + expr = -3*A(i) + assert latex(expr) == r"-3A{}^{i}" + + expr = K(i, j, -k, -i0) + assert latex(expr) == r"K{}^{ij}{}_{ki_{0}}" + + expr = K(i, -j, -k, i0) + assert latex(expr) == r"K{}^{i}{}_{jk}{}^{i_{0}}" + + expr = K(i, -j, k, -i0) + assert latex(expr) == r"K{}^{i}{}_{j}{}^{k}{}_{i_{0}}" + + expr = H(i, -j) + assert latex(expr) == r"H{}^{i}{}_{j}" + + expr = H(i, j) + assert latex(expr) == r"H{}^{ij}" + + expr = H(-i, -j) + assert latex(expr) == r"H{}_{ij}" + + expr = (1+x)*A(i) + assert latex(expr) == r"\left(x + 1\right)A{}^{i}" + + expr = H(i, -i) + assert latex(expr) == r"H{}^{L_{0}}{}_{L_{0}}" + + expr = H(i, -j)*A(j)*B(k) + assert latex(expr) == r"H{}^{i}{}_{L_{0}}A{}^{L_{0}}B{}^{k}" + + expr = A(i) + 3*B(i) + assert latex(expr) == r"3B{}^{i} + A{}^{i}" + + # Test ``TensorElement``: + from sympy.tensor.tensor import TensorElement + + expr = TensorElement(K(i, j, k, l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3,j,k=2,l}' + + expr = TensorElement(K(i, j, k, l), {i: 3}) + assert latex(expr) == r'K{}^{i=3,jkl}' + + expr = TensorElement(K(i, -j, k, l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3}{}_{j}{}^{k=2,l}' + + expr = TensorElement(K(i, -j, k, -l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3}{}_{j}{}^{k=2}{}_{l}' + + expr = TensorElement(K(i, j, -k, -l), {i: 3, -k: 2}) + assert latex(expr) == r'K{}^{i=3,j}{}_{k=2,l}' + + expr = TensorElement(K(i, j, -k, -l), {i: 3}) + assert latex(expr) == r'K{}^{i=3,j}{}_{kl}' + + expr = PartialDerivative(A(i), A(i)) + assert latex(expr) == r"\frac{\partial}{\partial {A{}^{L_{0}}}}{A{}^{L_{0}}}" + + expr = PartialDerivative(A(-i), A(-j)) + assert latex(expr) == r"\frac{\partial}{\partial {A{}_{j}}}{A{}_{i}}" + + expr = PartialDerivative(K(i, j, -k, -l), A(m), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}^{m}} \partial {A{}_{n}}}{K{}^{ij}{}_{kl}}" + + expr = PartialDerivative(B(-i) + A(-i), A(-j), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}_{j}} \partial {A{}_{n}}}{\left(A{}_{i} + B{}_{i}\right)}" + + expr = PartialDerivative(3*A(-i), A(-j), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}_{j}} \partial {A{}_{n}}}{\left(3A{}_{i}\right)}" + + +def test_multiline_latex(): + a, b, c, d, e, f = symbols('a b c d e f') + expr = -a + 2*b -3*c +4*d -5*e + expected = r"\begin{eqnarray}" + "\n"\ + r"f & = &- a \nonumber\\" + "\n"\ + r"& & + 2 b \nonumber\\" + "\n"\ + r"& & - 3 c \nonumber\\" + "\n"\ + r"& & + 4 d \nonumber\\" + "\n"\ + r"& & - 5 e " + "\n"\ + r"\end{eqnarray}" + assert multiline_latex(f, expr, environment="eqnarray") == expected + + expected2 = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b \nonumber\\' + '\n'\ + r'& & - 3 c + 4 d \nonumber\\' + '\n'\ + r'& & - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 2, environment="eqnarray") == expected2 + + expected3 = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b - 3 c \nonumber\\'+ '\n'\ + r'& & + 4 d - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 3, environment="eqnarray") == expected3 + + expected3dots = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b - 3 c \dots\nonumber\\'+ '\n'\ + r'& & + 4 d - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 3, environment="eqnarray", use_dots=True) == expected3dots + + expected3align = r'\begin{align*}' + '\n'\ + r'f = &- a + 2 b - 3 c \\'+ '\n'\ + r'& + 4 d - 5 e ' + '\n'\ + r'\end{align*}' + + assert multiline_latex(f, expr, 3) == expected3align + assert multiline_latex(f, expr, 3, environment='align*') == expected3align + + expected2ieee = r'\begin{IEEEeqnarray}{rCl}' + '\n'\ + r'f & = &- a + 2 b \nonumber\\' + '\n'\ + r'& & - 3 c + 4 d \nonumber\\' + '\n'\ + r'& & - 5 e ' + '\n'\ + r'\end{IEEEeqnarray}' + + assert multiline_latex(f, expr, 2, environment="IEEEeqnarray") == expected2ieee + + raises(ValueError, lambda: multiline_latex(f, expr, environment="foo")) + +def test_issue_15353(): + a, x = symbols('a x') + # Obtained from nonlinsolve([(sin(a*x)),cos(a*x)],[x,a]) + sol = ConditionSet( + Tuple(x, a), Eq(sin(a*x), 0) & Eq(cos(a*x), 0), S.Complexes**2) + assert latex(sol) == \ + r'\left\{\left( x, \ a\right)\; \middle|\; \left( x, \ a\right) \in ' \ + r'\mathbb{C}^{2} \wedge \sin{\left(a x \right)} = 0 \wedge ' \ + r'\cos{\left(a x \right)} = 0 \right\}' + + +def test_latex_symbolic_probability(): + mu = symbols("mu") + sigma = symbols("sigma", positive=True) + X = Normal("X", mu, sigma) + assert latex(Expectation(X)) == r'\operatorname{E}\left[X\right]' + assert latex(Variance(X)) == r'\operatorname{Var}\left(X\right)' + assert latex(Probability(X > 0)) == r'\operatorname{P}\left(X > 0\right)' + Y = Normal("Y", mu, sigma) + assert latex(Covariance(X, Y)) == r'\operatorname{Cov}\left(X, Y\right)' + + +def test_trace(): + # Issue 15303 + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert latex(trace(A)) == r"\operatorname{tr}\left(A \right)" + assert latex(trace(A**2)) == r"\operatorname{tr}\left(A^{2} \right)" + + +def test_print_basic(): + # Issue 15303 + from sympy.core.basic import Basic + from sympy.core.expr import Expr + + # dummy class for testing printing where the function is not + # implemented in latex.py + class UnimplementedExpr(Expr): + def __new__(cls, e): + return Basic.__new__(cls, e) + + # dummy function for testing + def unimplemented_expr(expr): + return UnimplementedExpr(expr).doit() + + # override class name to use superscript / subscript + def unimplemented_expr_sup_sub(expr): + result = UnimplementedExpr(expr) + result.__class__.__name__ = 'UnimplementedExpr_x^1' + return result + + assert latex(unimplemented_expr(x)) == r'\operatorname{UnimplementedExpr}\left(x\right)' + assert latex(unimplemented_expr(x**2)) == \ + r'\operatorname{UnimplementedExpr}\left(x^{2}\right)' + assert latex(unimplemented_expr_sup_sub(x)) == \ + r'\operatorname{UnimplementedExpr^{1}_{x}}\left(x\right)' + + +def test_MatrixSymbol_bold(): + # Issue #15871 + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert latex(trace(A), mat_symbol_style='bold') == \ + r"\operatorname{tr}\left(\mathbf{A} \right)" + assert latex(trace(A), mat_symbol_style='plain') == \ + r"\operatorname{tr}\left(A \right)" + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert latex(-A, mat_symbol_style='bold') == r"- \mathbf{A}" + assert latex(A - A*B - B, mat_symbol_style='bold') == \ + r"\mathbf{A} - \mathbf{A} \mathbf{B} - \mathbf{B}" + assert latex(-A*B - A*B*C - B, mat_symbol_style='bold') == \ + r"- \mathbf{A} \mathbf{B} - \mathbf{A} \mathbf{B} \mathbf{C} - \mathbf{B}" + + A_k = MatrixSymbol("A_k", 3, 3) + assert latex(A_k, mat_symbol_style='bold') == r"\mathbf{A}_{k}" + + A = MatrixSymbol(r"\nabla_k", 3, 3) + assert latex(A, mat_symbol_style='bold') == r"\mathbf{\nabla}_{k}" + +def test_AppliedPermutation(): + p = Permutation(0, 1, 2) + x = Symbol('x') + assert latex(AppliedPermutation(p, x)) == \ + r'\sigma_{\left( 0\; 1\; 2\right)}(x)' + + +def test_PermutationMatrix(): + p = Permutation(0, 1, 2) + assert latex(PermutationMatrix(p)) == r'P_{\left( 0\; 1\; 2\right)}' + p = Permutation(0, 3)(1, 2) + assert latex(PermutationMatrix(p)) == \ + r'P_{\left( 0\; 3\right)\left( 1\; 2\right)}' + + +def test_issue_21758(): + from sympy.functions.elementary.piecewise import piecewise_fold + from sympy.series.fourier import FourierSeries + x = Symbol('x') + k, n = symbols('k n') + fo = FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), SeqFormula( + Piecewise((-2*pi*cos(n*pi)/n + 2*sin(n*pi)/n**2, (n > -oo) & (n < oo) & Ne(n, 0)), + (0, True))*sin(n*x)/pi, (n, 1, oo)))) + assert latex(piecewise_fold(fo)) == '\\begin{cases} 2 \\sin{\\left(x \\right)}' \ + ' - \\sin{\\left(2 x \\right)} + \\frac{2 \\sin{\\left(3 x \\right)}}{3} +' \ + ' \\ldots & \\text{for}\\: n > -\\infty \\wedge n < \\infty \\wedge ' \ + 'n \\neq 0 \\\\0 & \\text{otherwise} \\end{cases}' + assert latex(FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), + SeqFormula(0, (n, 1, oo))))) == '0' + + +def test_imaginary_unit(): + assert latex(1 + I) == r'1 + i' + assert latex(1 + I, imaginary_unit='i') == r'1 + i' + assert latex(1 + I, imaginary_unit='j') == r'1 + j' + assert latex(1 + I, imaginary_unit='foo') == r'1 + foo' + assert latex(I, imaginary_unit="ti") == r'\text{i}' + assert latex(I, imaginary_unit="tj") == r'\text{j}' + + +def test_text_re_im(): + assert latex(im(x), gothic_re_im=True) == r'\Im{\left(x\right)}' + assert latex(im(x), gothic_re_im=False) == r'\operatorname{im}{\left(x\right)}' + assert latex(re(x), gothic_re_im=True) == r'\Re{\left(x\right)}' + assert latex(re(x), gothic_re_im=False) == r'\operatorname{re}{\left(x\right)}' + + +def test_latex_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField, Differential + from sympy.diffgeom.rn import R2 + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert latex(m) == r'\text{M}' + p = Patch('P', m) + assert latex(p) == r'\text{P}_{\text{M}}' + rect = CoordSystem('rect', p, [x, y]) + assert latex(rect) == r'\text{rect}^{\text{P}}_{\text{M}}' + b = BaseScalarField(rect, 0) + assert latex(b) == r'\mathbf{x}' + + g = Function('g') + s_field = g(R2.x, R2.y) + assert latex(Differential(s_field)) == \ + r'\operatorname{d}\left(g{\left(\mathbf{x},\mathbf{y} \right)}\right)' + + +def test_unit_printing(): + assert latex(5*meter) == r'5 \text{m}' + assert latex(3*gibibyte) == r'3 \text{gibibyte}' + assert latex(4*microgram/second) == r'\frac{4 \mu\text{g}}{\text{s}}' + assert latex(4*micro*gram/second) == r'\frac{4 \mu \text{g}}{\text{s}}' + assert latex(5*milli*meter) == r'5 \text{m} \text{m}' + assert latex(milli) == r'\text{m}' + + +def test_issue_17092(): + x_star = Symbol('x^*') + assert latex(Derivative(x_star, x_star,2)) == r'\frac{d^{2}}{d \left(x^{*}\right)^{2}} x^{*}' + + +def test_latex_decimal_separator(): + + x, y, z, t = symbols('x y z t') + k, m, n = symbols('k m n', integer=True) + f, g, h = symbols('f g h', cls=Function) + + # comma decimal_separator + assert(latex([1, 2.3, 4.5], decimal_separator='comma') == r'\left[ 1; \ 2{,}3; \ 4{,}5\right]') + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='comma') == r'\left\{1; 2{,}3; 4{,}5\right\}') + assert(latex((1, 2.3, 4.6), decimal_separator = 'comma') == r'\left( 1; \ 2{,}3; \ 4{,}6\right)') + assert(latex((1,), decimal_separator='comma') == r'\left( 1;\right)') + + # period decimal_separator + assert(latex([1, 2.3, 4.5], decimal_separator='period') == r'\left[ 1, \ 2.3, \ 4.5\right]' ) + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}') + assert(latex((1, 2.3, 4.6), decimal_separator = 'period') == r'\left( 1, \ 2.3, \ 4.6\right)') + assert(latex((1,), decimal_separator='period') == r'\left( 1,\right)') + + # default decimal_separator + assert(latex([1, 2.3, 4.5]) == r'\left[ 1, \ 2.3, \ 4.5\right]') + assert(latex(FiniteSet(1, 2.3, 4.5)) == r'\left\{1, 2.3, 4.5\right\}') + assert(latex((1, 2.3, 4.6)) == r'\left( 1, \ 2.3, \ 4.6\right)') + assert(latex((1,)) == r'\left( 1,\right)') + + assert(latex(Mul(3.4,5.3), decimal_separator = 'comma') == r'18{,}02') + assert(latex(3.4*5.3, decimal_separator = 'comma') == r'18{,}02') + x = symbols('x') + y = symbols('y') + z = symbols('z') + assert(latex(x*5.3 + 2**y**3.4 + 4.5 + z, decimal_separator = 'comma') == r'2^{y^{3{,}4}} + 5{,}3 x + z + 4{,}5') + + assert(latex(0.987, decimal_separator='comma') == r'0{,}987') + assert(latex(S(0.987), decimal_separator='comma') == r'0{,}987') + assert(latex(.3, decimal_separator='comma') == r'0{,}3') + assert(latex(S(.3), decimal_separator='comma') == r'0{,}3') + + + assert(latex(5.8*10**(-7), decimal_separator='comma') == r'5{,}8 \cdot 10^{-7}') + assert(latex(S(5.7)*10**(-7), decimal_separator='comma') == r'5{,}7 \cdot 10^{-7}') + assert(latex(S(5.7*10**(-7)), decimal_separator='comma') == r'5{,}7 \cdot 10^{-7}') + + x = symbols('x') + assert(latex(1.2*x+3.4, decimal_separator='comma') == r'1{,}2 x + 3{,}4') + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}') + + # Error Handling tests + raises(ValueError, lambda: latex([1,2.3,4.5], decimal_separator='non_existing_decimal_separator_in_list')) + raises(ValueError, lambda: latex(FiniteSet(1,2.3,4.5), decimal_separator='non_existing_decimal_separator_in_set')) + raises(ValueError, lambda: latex((1,2.3,4.5), decimal_separator='non_existing_decimal_separator_in_tuple')) + +def test_Str(): + from sympy.core.symbol import Str + assert str(Str('x')) == r'x' + +def test_latex_escape(): + assert latex_escape(r"~^\&%$#_{}") == "".join([ + r'\textasciitilde', + r'\textasciicircum', + r'\textbackslash', + r'\&', + r'\%', + r'\$', + r'\#', + r'\_', + r'\{', + r'\}', + ]) + +def test_emptyPrinter(): + class MyObject: + def __repr__(self): + return "" + + # unknown objects are monospaced + assert latex(MyObject()) == r"\mathtt{\text{}}" + + # even if they are nested within other objects + assert latex((MyObject(),)) == r"\left( \mathtt{\text{}},\right)" + +def test_global_settings(): + import inspect + + # settings should be visible in the signature of `latex` + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'i' + assert latex(I) == r'i' + try: + # but changing the defaults... + LatexPrinter.set_global_settings(imaginary_unit='j') + # ... should change the signature + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'j' + assert latex(I) == r'j' + finally: + # there's no public API to undo this, but we need to make sure we do + # so as not to impact other tests + del LatexPrinter._global_settings['imaginary_unit'] + + # check we really did undo it + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'i' + assert latex(I) == r'i' + +def test_pickleable(): + # this tests that the _PrintFunction instance is pickleable + import pickle + assert pickle.loads(pickle.dumps(latex)) is latex + +def test_printing_latex_array_expressions(): + assert latex(ArraySymbol("A", (2, 3, 4))) == "A" + assert latex(ArrayElement("A", (2, 1/(1-x), 0))) == "{{A}_{2, \\frac{1}{1 - x}, 0}}" + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + assert latex(ArrayElement(M*N, [x, 0])) == "{{\\left(M N\\right)}_{x, 0}}" + +def test_Array(): + arr = Array(range(10)) + assert latex(arr) == r'\left[\begin{matrix}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9\end{matrix}\right]' + + arr = Array(range(11)) + # fill the empty argument with a bunch of 'c' to avoid latex errors + assert latex(arr) == r'\left[\begin{array}{ccccccccccc}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}\right]' + +def test_latex_with_unevaluated(): + with evaluate(False): + assert latex(a * a) == r"a a" + + +def test_latex_disable_split_super_sub(): + assert latex(Symbol('u^a_b')) == 'u^{a}_{b}' + assert latex(Symbol('u^a_b'), disable_split_super_sub=False) == 'u^{a}_{b}' + assert latex(Symbol('u^a_b'), disable_split_super_sub=True) == 'u\\^a\\_b' diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_maple.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_maple.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb4c512ad3203bd64ae56b350e15734b3a6afb0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_maple.py @@ -0,0 +1,381 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow +from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc, lucas +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix) +from sympy.functions.special.bessel import besseli + +from sympy.printing.maple import maple_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert maple_code(Integer(67)) == "67" + assert maple_code(Integer(-1)) == "-1" + + +def test_Rational(): + assert maple_code(Rational(3, 7)) == "3/7" + assert maple_code(Rational(18, 9)) == "2" + assert maple_code(Rational(3, -7)) == "-3/7" + assert maple_code(Rational(-3, -7)) == "3/7" + assert maple_code(x + Rational(3, 7)) == "x + 3/7" + assert maple_code(Rational(3, 7) * x) == '(3/7)*x' + + +def test_Relational(): + assert maple_code(Eq(x, y)) == "x = y" + assert maple_code(Ne(x, y)) == "x <> y" + assert maple_code(Le(x, y)) == "x <= y" + assert maple_code(Lt(x, y)) == "x < y" + assert maple_code(Gt(x, y)) == "x > y" + assert maple_code(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert maple_code(sin(x) ** cos(x)) == "sin(x)^cos(x)" + assert maple_code(abs(x)) == "abs(x)" + assert maple_code(ceiling(x)) == "ceil(x)" + + +def test_Pow(): + assert maple_code(x ** 3) == "x^3" + assert maple_code(x ** (y ** 3)) == "x^(y^3)" + + assert maple_code((x ** 3) ** y) == "(x^3)^y" + assert maple_code(x ** Rational(2, 3)) == 'x^(2/3)' + + g = implemented_function('g', Lambda(x, 2 * x)) + assert maple_code(1 / (g(x) * 3.5) ** (x - y ** x) / (x ** 2 + y)) == \ + "(3.5*2*x)^(-x + y^x)/(x^2 + y)" + # For issue 14160 + assert maple_code(Mul(-2, x, Pow(Mul(y, y, evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + + +def test_basic_ops(): + assert maple_code(x * y) == "x*y" + assert maple_code(x + y) == "x + y" + assert maple_code(x - y) == "x - y" + assert maple_code(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert maple_code(1 / x) == '1/x' + assert maple_code(x ** -1) == maple_code(x ** -1.0) == '1/x' + assert maple_code(1 / sqrt(x)) == '1/sqrt(x)' + assert maple_code(x ** -S.Half) == maple_code(x ** -0.5) == '1/sqrt(x)' + assert maple_code(sqrt(x)) == 'sqrt(x)' + assert maple_code(x ** S.Half) == maple_code(x ** 0.5) == 'sqrt(x)' + assert maple_code(1 / pi) == '1/Pi' + assert maple_code(pi ** -1) == maple_code(pi ** -1.0) == '1/Pi' + assert maple_code(pi ** -0.5) == '1/sqrt(Pi)' + + +def test_mix_number_mult_symbols(): + assert maple_code(3 * x) == "3*x" + assert maple_code(pi * x) == "Pi*x" + assert maple_code(3 / x) == "3/x" + assert maple_code(pi / x) == "Pi/x" + assert maple_code(x / 3) == '(1/3)*x' + assert maple_code(x / pi) == "x/Pi" + assert maple_code(x * y) == "x*y" + assert maple_code(3 * x * y) == "3*x*y" + assert maple_code(3 * pi * x * y) == "3*Pi*x*y" + assert maple_code(x / y) == "x/y" + assert maple_code(3 * x / y) == "3*x/y" + assert maple_code(x * y / z) == "x*y/z" + assert maple_code(x / y * z) == "x*z/y" + assert maple_code(1 / x / y) == "1/(x*y)" + assert maple_code(2 * pi * x / y / z) == "2*Pi*x/(y*z)" + assert maple_code(3 * pi / x) == "3*Pi/x" + assert maple_code(S(3) / 5) == "3/5" + assert maple_code(S(3) / 5 * x) == '(3/5)*x' + assert maple_code(x / y / z) == "x/(y*z)" + assert maple_code((x + y) / z) == "(x + y)/z" + assert maple_code((x + y) / (z + x)) == "(x + y)/(x + z)" + assert maple_code((x + y) / EulerGamma) == '(x + y)/gamma' + assert maple_code(x / 3 / pi) == '(1/3)*x/Pi' + assert maple_code(S(3) / 5 * x * y / pi) == '(3/5)*x*y/Pi' + + +def test_mix_number_pow_symbols(): + assert maple_code(pi ** 3) == 'Pi^3' + assert maple_code(x ** 2) == 'x^2' + + assert maple_code(x ** (pi ** 3)) == 'x^(Pi^3)' + assert maple_code(x ** y) == 'x^y' + + assert maple_code(x ** (y ** z)) == 'x^(y^z)' + assert maple_code((x ** y) ** z) == '(x^y)^z' + + +def test_imag(): + I = S('I') + assert maple_code(I) == "I" + assert maple_code(5 * I) == "5*I" + + assert maple_code((S(3) / 2) * I) == "(3/2)*I" + assert maple_code(3 + 4 * I) == "3 + 4*I" + + +def test_constants(): + assert maple_code(pi) == "Pi" + assert maple_code(oo) == "infinity" + assert maple_code(-oo) == "-infinity" + assert maple_code(S.NegativeInfinity) == "-infinity" + assert maple_code(S.NaN) == "undefined" + assert maple_code(S.Exp1) == "exp(1)" + assert maple_code(exp(1)) == "exp(1)" + + +def test_constants_other(): + assert maple_code(2 * GoldenRatio) == '2*(1/2 + (1/2)*sqrt(5))' + assert maple_code(2 * Catalan) == '2*Catalan' + assert maple_code(2 * EulerGamma) == "2*gamma" + + +def test_boolean(): + assert maple_code(x & y) == "x and y" + assert maple_code(x | y) == "x or y" + assert maple_code(~x) == "not x" + assert maple_code(x & y & z) == "x and y and z" + assert maple_code(x | y | z) == "x or y or z" + assert maple_code((x & y) | z) == "z or x and y" + assert maple_code((x | y) & z) == "z and (x or y)" + + +def test_Matrices(): + assert maple_code(Matrix(1, 1, [10])) == \ + 'Matrix([[10]], storage = rectangular)' + + A = Matrix([[1, sin(x / 2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = \ + 'Matrix(' \ + '[[1, sin((1/2)*x), abs(x)],' \ + ' [0, 1, Pi],' \ + ' [0, exp(1), ceil(x)]], ' \ + 'storage = rectangular)' + assert maple_code(A) == expected + + # row and columns + assert maple_code(A[:, 0]) == \ + 'Matrix([[1], [0], [0]], storage = rectangular)' + assert maple_code(A[0, :]) == \ + 'Matrix([[1, sin((1/2)*x), abs(x)]], storage = rectangular)' + assert maple_code(Matrix([[x, x - y, -y]])) == \ + 'Matrix([[x, x - y, -y]], storage = rectangular)' + + # empty matrices + assert maple_code(Matrix(0, 0, [])) == \ + 'Matrix([], storage = rectangular)' + assert maple_code(Matrix(0, 3, [])) == \ + 'Matrix([], storage = rectangular)' + +def test_SparseMatrices(): + assert maple_code(SparseMatrix(Identity(2))) == 'Matrix([[1, 0], [0, 1]], storage = sparse)' + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2 / x), 3 * pi / x / 5]]) + assert maple_code(A) == \ + 'Matrix([[1, sin(2/x), (3/5)*Pi/x]], storage = rectangular)' + assert maple_code(A.T) == \ + 'Matrix([[1], [sin(2/x)], [(3/5)*Pi/x]], storage = rectangular)' + + +def test_Matrices_entries_not_hadamard(): + A = Matrix([[1, sin(2 / x), 3 * pi / x / 5], [1, 2, x * y]]) + expected = \ + 'Matrix([[1, sin(2/x), (3/5)*Pi/x], [1, 2, x*y]], ' \ + 'storage = rectangular)' + assert maple_code(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert maple_code(A * B) == "A.B" + assert maple_code(B * A) == "B.A" + assert maple_code(2 * A * B) == "2*A.B" + assert maple_code(B * 2 * A) == "2*B.A" + + assert maple_code( + A * (B + 3 * Identity(n))) == "A.(3*Matrix(n, shape = identity) + B)" + + assert maple_code(A ** (x ** 2)) == "MatrixPower(A, x^2)" + assert maple_code(A ** 3) == "MatrixPower(A, 3)" + assert maple_code(A ** (S.Half)) == "MatrixPower(A, 1/2)" + + +def test_special_matrices(): + assert maple_code(6 * Identity(3)) == "6*Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = sparse)" + assert maple_code(Identity(x)) == 'Matrix(x, shape = identity)' + + +def test_containers(): + assert maple_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "[1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]" + + assert maple_code((1, 2, (3, 4))) == "[1, 2, [3, 4]]" + assert maple_code([1]) == "[1]" + assert maple_code((1,)) == "[1]" + assert maple_code(Tuple(*[1, 2, 3])) == "[1, 2, 3]" + assert maple_code((1, x * y, (3, x ** 2))) == "[1, x*y, [3, x^2]]" + # scalar, matrix, empty matrix and empty list + + assert maple_code((1, eye(3), Matrix(0, 0, []), [])) == \ + "[1, Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = rectangular), Matrix([], storage = rectangular), []]" + + +def test_maple_noninline(): + source = maple_code((x + y)/Catalan, assign_to='me', inline=False) + expected = "me := (x + y)/Catalan" + + assert source == expected + + +def test_maple_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert maple_code(A, assign_to='a') == "a := Matrix([[1, 2, 3]], storage = rectangular)" + A = Matrix([[1, 2], [3, 4]]) + assert maple_code(A, assign_to='A') == "A := Matrix([[1, 2], [3, 4]], storage = rectangular)" + + +def test_maple_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert maple_code(A, assign_to=B) == "B := Matrix([[1, 2, 3]], storage = rectangular)" + raises(ValueError, lambda: maple_code(A, assign_to=x)) + raises(ValueError, lambda: maple_code(A, assign_to=C)) + + +def test_maple_matrix_1x1(): + A = Matrix([[3]]) + assert maple_code(A, assign_to='B') == "B := Matrix([[3]], storage = rectangular)" + + +def test_maple_matrix_elements(): + A = Matrix([[x, 2, x * y]]) + + assert maple_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x^2 + x*y + 2" + AA = MatrixSymbol('AA', 1, 3) + assert maple_code(AA) == "AA" + + assert maple_code(AA[0, 0] ** 2 + sin(AA[0, 1]) + AA[0, 2]) == \ + "sin(AA[1, 2]) + AA[1, 1]^2 + AA[1, 3]" + assert maple_code(sum(AA)) == "AA[1, 1] + AA[1, 2] + AA[1, 3]" + + +def test_maple_boolean(): + assert maple_code(True) == "true" + assert maple_code(S.true) == "true" + assert maple_code(False) == "false" + assert maple_code(S.false) == "false" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x * y + assert maple_code(M) == \ + 'Matrix([[0, 0, 0, 30, 0, 0],' \ + ' [0, 0, 20, 22, 0, 0],' \ + ' [0, 0, 10, 0, 0, 0],' \ + ' [x*y, 0, 0, 0, 0, 0],' \ + ' [0, 0, 0, 0, 0, 0]], ' \ + 'storage = sparse)' + +# Not an important point. +def test_maple_not_supported(): + with raises(NotImplementedError): + maple_code(S.ComplexInfinity) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + + assert (maple_code(A[0, 0]) == "A[1, 1]") + assert (maple_code(3 * A[0, 0]) == "3*A[1, 1]") + + F = A-B + + assert (maple_code(F[0,0]) == "A[1, 1] - B[1, 1]") + + +def test_hadamard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + assert maple_code(C) == "A*B" + + assert maple_code(C * v) == "(A*B).v" + # HadamardProduct is higher than dot product. + + assert maple_code(h * C * v) == "h.(A*B).v" + + assert maple_code(C * A) == "(A*B).A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + + assert maple_code(C * x * y) == "x*y*(A*B)" + + +def test_maple_piecewise(): + expr = Piecewise((x, x < 1), (x ** 2, True)) + + assert maple_code(expr) == "piecewise(x < 1, x, x^2)" + assert maple_code(expr, assign_to="r") == ( + "r := piecewise(x < 1, x, x^2)") + + expr = Piecewise((x ** 2, x < 1), (x ** 3, x < 2), (x ** 4, x < 3), (x ** 5, True)) + expected = "piecewise(x < 1, x^2, x < 2, x^3, x < 3, x^4, x^5)" + assert maple_code(expr) == expected + assert maple_code(expr, assign_to="r") == "r := " + expected + + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: maple_code(expr)) + + +def test_maple_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x ** 2, True)) + + assert maple_code(2 * pw) == "2*piecewise(x < 1, x, x^2)" + assert maple_code(pw / x) == "piecewise(x < 1, x, x^2)/x" + assert maple_code(pw / (x * y)) == "piecewise(x < 1, x, x^2)/(x*y)" + assert maple_code(pw / 3) == "(1/3)*piecewise(x < 1, x, x^2)" + + +def test_maple_derivatives(): + f = Function('f') + assert maple_code(f(x).diff(x)) == 'diff(f(x), x)' + assert maple_code(f(x).diff(x, 2)) == 'diff(f(x), x$2)' + + +def test_automatic_rewrites(): + assert maple_code(lucas(x)) == '(2^(-x)*((1 - sqrt(5))^x + (1 + sqrt(5))^x))' + assert maple_code(sinc(x)) == '(piecewise(x <> 0, sin(x)/x, 1))' + + +def test_specfun(): + assert maple_code('asin(x)') == 'arcsin(x)' + assert maple_code(besseli(x, y)) == 'BesselI(x, y)' diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathematica.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf6b537677442ae59a4f1bbd2b5774d6646f4e2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathematica.py @@ -0,0 +1,287 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, Tuple, + Derivative, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.integrals import Integral +from sympy.concrete import Sum +from sympy.functions import (exp, sin, cos, fresnelc, fresnels, conjugate, Max, + Min, gamma, polygamma, loggamma, erf, erfi, erfc, + erf2, expint, erfinv, erfcinv, Ei, Si, Ci, li, + Shi, Chi, uppergamma, beta, subfactorial, erf2inv, + factorial, factorial2, catalan, RisingFactorial, + FallingFactorial, harmonic, atan2, sec, acsc, + hermite, laguerre, assoc_laguerre, jacobi, + gegenbauer, chebyshevt, chebyshevu, legendre, + assoc_legendre, Li, LambertW) + +from sympy.printing.mathematica import mathematica_code as mcode + +x, y, z, w = symbols('x,y,z,w') +f = Function('f') + + +def test_Integer(): + assert mcode(Integer(67)) == "67" + assert mcode(Integer(-1)) == "-1" + + +def test_Rational(): + assert mcode(Rational(3, 7)) == "3/7" + assert mcode(Rational(18, 9)) == "2" + assert mcode(Rational(3, -7)) == "-3/7" + assert mcode(Rational(-3, -7)) == "3/7" + assert mcode(x + Rational(3, 7)) == "x + 3/7" + assert mcode(Rational(3, 7)*x) == "(3/7)*x" + + +def test_Relational(): + assert mcode(Eq(x, y)) == "x == y" + assert mcode(Ne(x, y)) == "x != y" + assert mcode(Le(x, y)) == "x <= y" + assert mcode(Lt(x, y)) == "x < y" + assert mcode(Gt(x, y)) == "x > y" + assert mcode(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert mcode(f(x, y, z)) == "f[x, y, z]" + assert mcode(sin(x) ** cos(x)) == "Sin[x]^Cos[x]" + assert mcode(sec(x) * acsc(x)) == "ArcCsc[x]*Sec[x]" + assert mcode(atan2(y, x)) == "ArcTan[x, y]" + assert mcode(conjugate(x)) == "Conjugate[x]" + assert mcode(Max(x, y, z)*Min(y, z)) == "Max[x, y, z]*Min[y, z]" + assert mcode(fresnelc(x)) == "FresnelC[x]" + assert mcode(fresnels(x)) == "FresnelS[x]" + assert mcode(gamma(x)) == "Gamma[x]" + assert mcode(uppergamma(x, y)) == "Gamma[x, y]" + assert mcode(polygamma(x, y)) == "PolyGamma[x, y]" + assert mcode(loggamma(x)) == "LogGamma[x]" + assert mcode(erf(x)) == "Erf[x]" + assert mcode(erfc(x)) == "Erfc[x]" + assert mcode(erfi(x)) == "Erfi[x]" + assert mcode(erf2(x, y)) == "Erf[x, y]" + assert mcode(expint(x, y)) == "ExpIntegralE[x, y]" + assert mcode(erfcinv(x)) == "InverseErfc[x]" + assert mcode(erfinv(x)) == "InverseErf[x]" + assert mcode(erf2inv(x, y)) == "InverseErf[x, y]" + assert mcode(Ei(x)) == "ExpIntegralEi[x]" + assert mcode(Ci(x)) == "CosIntegral[x]" + assert mcode(li(x)) == "LogIntegral[x]" + assert mcode(Si(x)) == "SinIntegral[x]" + assert mcode(Shi(x)) == "SinhIntegral[x]" + assert mcode(Chi(x)) == "CoshIntegral[x]" + assert mcode(beta(x, y)) == "Beta[x, y]" + assert mcode(factorial(x)) == "Factorial[x]" + assert mcode(factorial2(x)) == "Factorial2[x]" + assert mcode(subfactorial(x)) == "Subfactorial[x]" + assert mcode(FallingFactorial(x, y)) == "FactorialPower[x, y]" + assert mcode(RisingFactorial(x, y)) == "Pochhammer[x, y]" + assert mcode(catalan(x)) == "CatalanNumber[x]" + assert mcode(harmonic(x)) == "HarmonicNumber[x]" + assert mcode(harmonic(x, y)) == "HarmonicNumber[x, y]" + assert mcode(Li(x)) == "LogIntegral[x] - LogIntegral[2]" + assert mcode(LambertW(x)) == "ProductLog[x]" + assert mcode(LambertW(x, -1)) == "ProductLog[-1, x]" + assert mcode(LambertW(x, y)) == "ProductLog[y, x]" + + +def test_special_polynomials(): + assert mcode(hermite(x, y)) == "HermiteH[x, y]" + assert mcode(laguerre(x, y)) == "LaguerreL[x, y]" + assert mcode(assoc_laguerre(x, y, z)) == "LaguerreL[x, y, z]" + assert mcode(jacobi(x, y, z, w)) == "JacobiP[x, y, z, w]" + assert mcode(gegenbauer(x, y, z)) == "GegenbauerC[x, y, z]" + assert mcode(chebyshevt(x, y)) == "ChebyshevT[x, y]" + assert mcode(chebyshevu(x, y)) == "ChebyshevU[x, y]" + assert mcode(legendre(x, y)) == "LegendreP[x, y]" + assert mcode(assoc_legendre(x, y, z)) == "LegendreP[x, y, z]" + + +def test_Pow(): + assert mcode(x**3) == "x^3" + assert mcode(x**(y**3)) == "x^(y^3)" + assert mcode(1/(f(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*f[x])^(-x + y^x)/(x^2 + y)" + assert mcode(x**-1.0) == 'x^(-1.0)' + assert mcode(x**Rational(2, 3)) == 'x^(2/3)' + + +def test_Mul(): + A, B, C, D = symbols('A B C D', commutative=False) + assert mcode(x*y*z) == "x*y*z" + assert mcode(x*y*A) == "x*y*A" + assert mcode(x*y*A*B) == "x*y*A**B" + assert mcode(x*y*A*B*C) == "x*y*A**B**C" + assert mcode(x*A*B*(C + D)*A*y) == "x*y*A**B**(C + D)**A" + + +def test_constants(): + assert mcode(S.Zero) == "0" + assert mcode(S.One) == "1" + assert mcode(S.NegativeOne) == "-1" + assert mcode(S.Half) == "1/2" + assert mcode(S.ImaginaryUnit) == "I" + + assert mcode(oo) == "Infinity" + assert mcode(S.NegativeInfinity) == "-Infinity" + assert mcode(S.ComplexInfinity) == "ComplexInfinity" + assert mcode(S.NaN) == "Indeterminate" + + assert mcode(S.Exp1) == "E" + assert mcode(pi) == "Pi" + assert mcode(S.GoldenRatio) == "GoldenRatio" + assert mcode(S.TribonacciConstant) == \ + "(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \ + "(1/3)*(3*33^(1/2) + 19)^(1/3))" + assert mcode(2*S.TribonacciConstant) == \ + "2*(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \ + "(1/3)*(3*33^(1/2) + 19)^(1/3))" + assert mcode(S.EulerGamma) == "EulerGamma" + assert mcode(S.Catalan) == "Catalan" + + +def test_containers(): + assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}" + assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}" + assert mcode([1]) == "{1}" + assert mcode((1,)) == "{1}" + assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}" + + +def test_matrices(): + from sympy.matrices import MutableDenseMatrix, MutableSparseMatrix, \ + ImmutableDenseMatrix, ImmutableSparseMatrix + A = MutableDenseMatrix( + [[1, -1, 0, 0], + [0, 1, -1, 0], + [0, 0, 1, -1], + [0, 0, 0, 1]] + ) + B = MutableSparseMatrix(A) + C = ImmutableDenseMatrix(A) + D = ImmutableSparseMatrix(A) + + assert mcode(C) == mcode(A) == \ + "{{1, -1, 0, 0}, " \ + "{0, 1, -1, 0}, " \ + "{0, 0, 1, -1}, " \ + "{0, 0, 0, 1}}" + + assert mcode(D) == mcode(B) == \ + "SparseArray[{" \ + "{1, 1} -> 1, {1, 2} -> -1, {2, 2} -> 1, {2, 3} -> -1, " \ + "{3, 3} -> 1, {3, 4} -> -1, {4, 4} -> 1" \ + "}, {4, 4}]" + + # Trivial cases of matrices + assert mcode(MutableDenseMatrix(0, 0, [])) == '{}' + assert mcode(MutableSparseMatrix(0, 0, [])) == 'SparseArray[{}, {0, 0}]' + assert mcode(MutableDenseMatrix(0, 3, [])) == '{}' + assert mcode(MutableSparseMatrix(0, 3, [])) == 'SparseArray[{}, {0, 3}]' + assert mcode(MutableDenseMatrix(3, 0, [])) == '{{}, {}, {}}' + assert mcode(MutableSparseMatrix(3, 0, [])) == 'SparseArray[{}, {3, 0}]' + +def test_NDArray(): + from sympy.tensor.array import ( + MutableDenseNDimArray, ImmutableDenseNDimArray, + MutableSparseNDimArray, ImmutableSparseNDimArray) + + example = MutableDenseNDimArray( + [[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]], + [[13, 14, 15, 16], + [17, 18, 19, 20], + [21, 22, 23, 24]]] + ) + + assert mcode(example) == \ + "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \ + "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}" + + example = ImmutableDenseNDimArray(example) + + assert mcode(example) == \ + "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \ + "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}" + + example = MutableSparseNDimArray(example) + + assert mcode(example) == \ + "SparseArray[{" \ + "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \ + "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \ + "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \ + "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \ + "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \ + "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \ + "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \ + "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \ + "}, {2, 3, 4}]" + + example = ImmutableSparseNDimArray(example) + + assert mcode(example) == \ + "SparseArray[{" \ + "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \ + "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \ + "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \ + "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \ + "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \ + "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \ + "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \ + "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \ + "}, {2, 3, 4}]" + + +def test_Integral(): + assert mcode(Integral(sin(sin(x)), x)) == "Hold[Integrate[Sin[Sin[x]], x]]" + assert mcode(Integral(exp(-x**2 - y**2), + (x, -oo, oo), + (y, -oo, oo))) == \ + "Hold[Integrate[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \ + "{y, -Infinity, Infinity}]]" + + +def test_Derivative(): + assert mcode(Derivative(sin(x), x)) == "Hold[D[Sin[x], x]]" + assert mcode(Derivative(x, x)) == "Hold[D[x, x]]" + assert mcode(Derivative(sin(x)*y**4, x, 2)) == "Hold[D[y^4*Sin[x], {x, 2}]]" + assert mcode(Derivative(sin(x)*y**4, x, y, x)) == "Hold[D[y^4*Sin[x], x, y, x]]" + assert mcode(Derivative(sin(x)*y**4, x, y, 3, x)) == "Hold[D[y^4*Sin[x], x, {y, 3}, x]]" + + +def test_Sum(): + assert mcode(Sum(sin(x), (x, 0, 10))) == "Hold[Sum[Sin[x], {x, 0, 10}]]" + assert mcode(Sum(exp(-x**2 - y**2), + (x, -oo, oo), + (y, -oo, oo))) == \ + "Hold[Sum[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \ + "{y, -Infinity, Infinity}]]" + + +def test_comment(): + from sympy.printing.mathematica import MCodePrinter + assert MCodePrinter()._get_comment("Hello World") == \ + "(* Hello World *)" + + +def test_userfuncs(): + # Dictionary mutation test + some_function = symbols("some_function", cls=Function) + my_user_functions = {"some_function": "SomeFunction"} + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeFunction[z]' + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeFunction[z]' + + # List argument test + my_user_functions = \ + {"some_function": [(lambda x: True, "SomeOtherFunction")]} + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeOtherFunction[z]' diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathml.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathml.py new file mode 100644 index 0000000000000000000000000000000000000000..4e7c2253c98fb1a4e99375774ad158df9b80b439 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_mathml.py @@ -0,0 +1,2048 @@ +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.function import Derivative, Lambda, diff, Function +from sympy.core.numbers import (zoo, Float, Integer, I, oo, pi, E, + Rational) +from sympy.core.relational import Lt, Ge, Ne, Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols, Symbol +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (factorial2, + binomial, factorial) +from sympy.functions.combinatorial.numbers import (lucas, bell, + catalan, euler, tribonacci, fibonacci, bernoulli, primenu, primeomega, + totient, reduced_totient) +from sympy.functions.elementary.complexes import re, im, conjugate, Abs +from sympy.functions.elementary.exponential import exp, LambertW, log +from sympy.functions.elementary.hyperbolic import (tanh, acoth, atanh, + coth, asinh, acsch, asech, acosh, csch, sinh, cosh, sech) +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import Max, Min +from sympy.functions.elementary.trigonometric import (csc, sec, tan, + atan, sin, asec, cot, cos, acot, acsc, asin, acos) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.elliptic_integrals import (elliptic_pi, + elliptic_f, elliptic_k, elliptic_e) +from sympy.functions.special.error_functions import (fresnelc, + fresnels, Ei, expint) +from sympy.functions.special.gamma_functions import (gamma, uppergamma, + lowergamma) +from sympy.functions.special.mathieu_functions import (mathieusprime, + mathieus, mathieucprime, mathieuc) +from sympy.functions.special.polynomials import (jacobi, chebyshevu, + chebyshevt, hermite, assoc_legendre, gegenbauer, assoc_laguerre, + legendre, laguerre) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.zeta_functions import (polylog, stieltjes, + lerchphi, dirichlet_eta, zeta) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (Xor, Or, false, true, And, Equivalent, + Implies, Not) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.determinant import Determinant +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.physics.quantum import (ComplexSpace, FockSpace, hbar, + HilbertSpace, Dagger) +from sympy.printing.mathml import (MathMLPresentationPrinter, + MathMLPrinter, MathMLContentPrinter, mathml) +from sympy.series.limits import Limit +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Interval, Union, SymmetricDifference, + Complement, FiniteSet, Intersection, ProductSet) +from sympy.stats.rv import RandomSymbol +from sympy.tensor.indexed import IndexedBase +from sympy.vector import (Divergence, CoordSys3D, Cross, Curl, Dot, + Laplacian, Gradient) +from sympy.testing.pytest import raises, XFAIL + +x, y, z, a, b, c, d, e, n = symbols('x:z a:e n') +mp = MathMLContentPrinter() +mpp = MathMLPresentationPrinter() + + +def test_mathml_printer(): + m = MathMLPrinter() + assert m.doprint(1+x) == mp.doprint(1+x) + + +def test_content_printmethod(): + assert mp.doprint(1 + x) == 'x1' + + +def test_content_mathml_core(): + mml_1 = mp._print(1 + x) + assert mml_1.nodeName == 'apply' + nodes = mml_1.childNodes + assert len(nodes) == 3 + assert nodes[0].nodeName == 'plus' + assert nodes[0].hasChildNodes() is False + assert nodes[0].nodeValue is None + assert nodes[1].nodeName in ['cn', 'ci'] + if nodes[1].nodeName == 'cn': + assert nodes[1].childNodes[0].nodeValue == '1' + assert nodes[2].childNodes[0].nodeValue == 'x' + else: + assert nodes[1].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mp._print(x**2) + assert mml_2.nodeName == 'apply' + nodes = mml_2.childNodes + assert nodes[1].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '2' + + mml_3 = mp._print(2*x) + assert mml_3.nodeName == 'apply' + nodes = mml_3.childNodes + assert nodes[0].nodeName == 'times' + assert nodes[1].childNodes[0].nodeValue == '2' + assert nodes[2].childNodes[0].nodeValue == 'x' + + mml = mp._print(Float(1.0, 2)*x) + assert mml.nodeName == 'apply' + nodes = mml.childNodes + assert nodes[0].nodeName == 'times' + assert nodes[1].childNodes[0].nodeValue == '1.0' + assert nodes[2].childNodes[0].nodeValue == 'x' + + +def test_content_mathml_functions(): + mml_1 = mp._print(sin(x)) + assert mml_1.nodeName == 'apply' + assert mml_1.childNodes[0].nodeName == 'sin' + assert mml_1.childNodes[1].nodeName == 'ci' + + mml_2 = mp._print(diff(sin(x), x, evaluate=False)) + assert mml_2.nodeName == 'apply' + assert mml_2.childNodes[0].nodeName == 'diff' + assert mml_2.childNodes[1].nodeName == 'bvar' + assert mml_2.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + + mml_3 = mp._print(diff(cos(x*y), x, evaluate=False)) + assert mml_3.nodeName == 'apply' + assert mml_3.childNodes[0].nodeName == 'partialdiff' + assert mml_3.childNodes[1].nodeName == 'bvar' + assert mml_3.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + + mml_4 = mp._print(Lambda((x, y), x * y)) + assert mml_4.nodeName == 'lambda' + assert mml_4.childNodes[0].nodeName == 'bvar' + assert mml_4.childNodes[0].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + assert mml_4.childNodes[1].nodeName == 'bvar' + assert mml_4.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's y/ci> + assert mml_4.childNodes[2].nodeName == 'apply' + + +def test_content_mathml_limits(): + # XXX No unevaluated limits + lim_fun = sin(x)/x + mml_1 = mp._print(Limit(lim_fun, x, 0)) + assert mml_1.childNodes[0].nodeName == 'limit' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].toxml() == mp._print(lim_fun).toxml() + + +def test_content_mathml_integrals(): + integrand = x + mml_1 = mp._print(Integral(integrand, (x, 0, 1))) + assert mml_1.childNodes[0].nodeName == 'int' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].nodeName == 'uplimit' + assert mml_1.childNodes[4].toxml() == mp._print(integrand).toxml() + + +def test_content_mathml_matrices(): + A = Matrix([1, 2, 3]) + B = Matrix([[0, 5, 4], [2, 3, 1], [9, 7, 9]]) + mll_1 = mp._print(A) + assert mll_1.childNodes[0].nodeName == 'matrixrow' + assert mll_1.childNodes[0].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[0].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_1.childNodes[1].nodeName == 'matrixrow' + assert mll_1.childNodes[1].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_1.childNodes[2].nodeName == 'matrixrow' + assert mll_1.childNodes[2].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[2].childNodes[0].childNodes[0].nodeValue == '3' + mll_2 = mp._print(B) + assert mll_2.childNodes[0].nodeName == 'matrixrow' + assert mll_2.childNodes[0].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[0].childNodes[0].nodeValue == '0' + assert mll_2.childNodes[0].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[1].childNodes[0].nodeValue == '5' + assert mll_2.childNodes[0].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[2].childNodes[0].nodeValue == '4' + assert mll_2.childNodes[1].nodeName == 'matrixrow' + assert mll_2.childNodes[1].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_2.childNodes[1].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[1].childNodes[0].nodeValue == '3' + assert mll_2.childNodes[1].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[2].childNodes[0].nodeValue == '1' + assert mll_2.childNodes[2].nodeName == 'matrixrow' + assert mll_2.childNodes[2].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[0].childNodes[0].nodeValue == '9' + assert mll_2.childNodes[2].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[1].childNodes[0].nodeValue == '7' + assert mll_2.childNodes[2].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[2].childNodes[0].nodeValue == '9' + + +def test_content_mathml_sums(): + summand = x + mml_1 = mp._print(Sum(summand, (x, 1, 10))) + assert mml_1.childNodes[0].nodeName == 'sum' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].nodeName == 'uplimit' + assert mml_1.childNodes[4].toxml() == mp._print(summand).toxml() + + +def test_content_mathml_tuples(): + mml_1 = mp._print([2]) + assert mml_1.nodeName == 'list' + assert mml_1.childNodes[0].nodeName == 'cn' + assert len(mml_1.childNodes) == 1 + + mml_2 = mp._print([2, Integer(1)]) + assert mml_2.nodeName == 'list' + assert mml_2.childNodes[0].nodeName == 'cn' + assert mml_2.childNodes[1].nodeName == 'cn' + assert len(mml_2.childNodes) == 2 + + +def test_content_mathml_add(): + mml = mp._print(x**5 - x**4 + x) + assert mml.childNodes[0].nodeName == 'plus' + assert mml.childNodes[1].childNodes[0].nodeName == 'minus' + assert mml.childNodes[1].childNodes[1].nodeName == 'apply' + + +def test_content_mathml_Rational(): + mml_1 = mp._print(Rational(1, 1)) + """should just return a number""" + assert mml_1.nodeName == 'cn' + + mml_2 = mp._print(Rational(2, 5)) + assert mml_2.childNodes[0].nodeName == 'divide' + + +def test_content_mathml_constants(): + mml = mp._print(I) + assert mml.nodeName == 'imaginaryi' + + mml = mp._print(E) + assert mml.nodeName == 'exponentiale' + + mml = mp._print(oo) + assert mml.nodeName == 'infinity' + + mml = mp._print(pi) + assert mml.nodeName == 'pi' + + assert mathml(hbar) == '' + assert mathml(S.TribonacciConstant) == '' + assert mathml(S.GoldenRatio) == 'φ' + mml = mathml(S.EulerGamma) + assert mml == '' + + mml = mathml(S.EmptySet) + assert mml == '' + + mml = mathml(S.true) + assert mml == '' + + mml = mathml(S.false) + assert mml == '' + + mml = mathml(S.NaN) + assert mml == '' + + +def test_content_mathml_trig(): + mml = mp._print(sin(x)) + assert mml.childNodes[0].nodeName == 'sin' + + mml = mp._print(cos(x)) + assert mml.childNodes[0].nodeName == 'cos' + + mml = mp._print(tan(x)) + assert mml.childNodes[0].nodeName == 'tan' + + mml = mp._print(cot(x)) + assert mml.childNodes[0].nodeName == 'cot' + + mml = mp._print(csc(x)) + assert mml.childNodes[0].nodeName == 'csc' + + mml = mp._print(sec(x)) + assert mml.childNodes[0].nodeName == 'sec' + + mml = mp._print(asin(x)) + assert mml.childNodes[0].nodeName == 'arcsin' + + mml = mp._print(acos(x)) + assert mml.childNodes[0].nodeName == 'arccos' + + mml = mp._print(atan(x)) + assert mml.childNodes[0].nodeName == 'arctan' + + mml = mp._print(acot(x)) + assert mml.childNodes[0].nodeName == 'arccot' + + mml = mp._print(acsc(x)) + assert mml.childNodes[0].nodeName == 'arccsc' + + mml = mp._print(asec(x)) + assert mml.childNodes[0].nodeName == 'arcsec' + + mml = mp._print(sinh(x)) + assert mml.childNodes[0].nodeName == 'sinh' + + mml = mp._print(cosh(x)) + assert mml.childNodes[0].nodeName == 'cosh' + + mml = mp._print(tanh(x)) + assert mml.childNodes[0].nodeName == 'tanh' + + mml = mp._print(coth(x)) + assert mml.childNodes[0].nodeName == 'coth' + + mml = mp._print(csch(x)) + assert mml.childNodes[0].nodeName == 'csch' + + mml = mp._print(sech(x)) + assert mml.childNodes[0].nodeName == 'sech' + + mml = mp._print(asinh(x)) + assert mml.childNodes[0].nodeName == 'arcsinh' + + mml = mp._print(atanh(x)) + assert mml.childNodes[0].nodeName == 'arctanh' + + mml = mp._print(acosh(x)) + assert mml.childNodes[0].nodeName == 'arccosh' + + mml = mp._print(acoth(x)) + assert mml.childNodes[0].nodeName == 'arccoth' + + mml = mp._print(acsch(x)) + assert mml.childNodes[0].nodeName == 'arccsch' + + mml = mp._print(asech(x)) + assert mml.childNodes[0].nodeName == 'arcsech' + + +def test_content_mathml_relational(): + mml_1 = mp._print(Eq(x, 1)) + assert mml_1.nodeName == 'apply' + assert mml_1.childNodes[0].nodeName == 'eq' + assert mml_1.childNodes[1].nodeName == 'ci' + assert mml_1.childNodes[1].childNodes[0].nodeValue == 'x' + assert mml_1.childNodes[2].nodeName == 'cn' + assert mml_1.childNodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mp._print(Ne(1, x)) + assert mml_2.nodeName == 'apply' + assert mml_2.childNodes[0].nodeName == 'neq' + assert mml_2.childNodes[1].nodeName == 'cn' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_2.childNodes[2].nodeName == 'ci' + assert mml_2.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_3 = mp._print(Ge(1, x)) + assert mml_3.nodeName == 'apply' + assert mml_3.childNodes[0].nodeName == 'geq' + assert mml_3.childNodes[1].nodeName == 'cn' + assert mml_3.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_3.childNodes[2].nodeName == 'ci' + assert mml_3.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_4 = mp._print(Lt(1, x)) + assert mml_4.nodeName == 'apply' + assert mml_4.childNodes[0].nodeName == 'lt' + assert mml_4.childNodes[1].nodeName == 'cn' + assert mml_4.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_4.childNodes[2].nodeName == 'ci' + assert mml_4.childNodes[2].childNodes[0].nodeValue == 'x' + + +def test_content_symbol(): + mml = mp._print(x) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeValue == 'x' + del mml + + mml = mp._print(Symbol("x^2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x__2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msub' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x^3_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msubsup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[0].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mp._print(Symbol("x__3_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msubsup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[0].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mp._print(Symbol("x_2_a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msub' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + mml = mp._print(Symbol("x^2^a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + mml = mp._print(Symbol("x__2__a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + +def test_content_mathml_greek(): + mml = mp._print(Symbol('alpha')) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeValue == '\N{GREEK SMALL LETTER ALPHA}' + + assert mp.doprint(Symbol('alpha')) == 'α' + assert mp.doprint(Symbol('beta')) == 'β' + assert mp.doprint(Symbol('gamma')) == 'γ' + assert mp.doprint(Symbol('delta')) == 'δ' + assert mp.doprint(Symbol('epsilon')) == 'ε' + assert mp.doprint(Symbol('zeta')) == 'ζ' + assert mp.doprint(Symbol('eta')) == 'η' + assert mp.doprint(Symbol('theta')) == 'θ' + assert mp.doprint(Symbol('iota')) == 'ι' + assert mp.doprint(Symbol('kappa')) == 'κ' + assert mp.doprint(Symbol('lambda')) == 'λ' + assert mp.doprint(Symbol('mu')) == 'μ' + assert mp.doprint(Symbol('nu')) == 'ν' + assert mp.doprint(Symbol('xi')) == 'ξ' + assert mp.doprint(Symbol('omicron')) == 'ο' + assert mp.doprint(Symbol('pi')) == 'π' + assert mp.doprint(Symbol('rho')) == 'ρ' + assert mp.doprint(Symbol('varsigma')) == 'ς' + assert mp.doprint(Symbol('sigma')) == 'σ' + assert mp.doprint(Symbol('tau')) == 'τ' + assert mp.doprint(Symbol('upsilon')) == 'υ' + assert mp.doprint(Symbol('phi')) == 'φ' + assert mp.doprint(Symbol('chi')) == 'χ' + assert mp.doprint(Symbol('psi')) == 'ψ' + assert mp.doprint(Symbol('omega')) == 'ω' + + assert mp.doprint(Symbol('Alpha')) == 'Α' + assert mp.doprint(Symbol('Beta')) == 'Β' + assert mp.doprint(Symbol('Gamma')) == 'Γ' + assert mp.doprint(Symbol('Delta')) == 'Δ' + assert mp.doprint(Symbol('Epsilon')) == 'Ε' + assert mp.doprint(Symbol('Zeta')) == 'Ζ' + assert mp.doprint(Symbol('Eta')) == 'Η' + assert mp.doprint(Symbol('Theta')) == 'Θ' + assert mp.doprint(Symbol('Iota')) == 'Ι' + assert mp.doprint(Symbol('Kappa')) == 'Κ' + assert mp.doprint(Symbol('Lambda')) == 'Λ' + assert mp.doprint(Symbol('Mu')) == 'Μ' + assert mp.doprint(Symbol('Nu')) == 'Ν' + assert mp.doprint(Symbol('Xi')) == 'Ξ' + assert mp.doprint(Symbol('Omicron')) == 'Ο' + assert mp.doprint(Symbol('Pi')) == 'Π' + assert mp.doprint(Symbol('Rho')) == 'Ρ' + assert mp.doprint(Symbol('Sigma')) == 'Σ' + assert mp.doprint(Symbol('Tau')) == 'Τ' + assert mp.doprint(Symbol('Upsilon')) == 'Υ' + assert mp.doprint(Symbol('Phi')) == 'Φ' + assert mp.doprint(Symbol('Chi')) == 'Χ' + assert mp.doprint(Symbol('Psi')) == 'Ψ' + assert mp.doprint(Symbol('Omega')) == 'Ω' + + +def test_content_mathml_order(): + expr = x**3 + x**2*y + 3*x*y**3 + y**4 + + mp = MathMLContentPrinter({'order': 'lex'}) + mml = mp._print(expr) + + assert mml.childNodes[1].childNodes[0].nodeName == 'power' + assert mml.childNodes[1].childNodes[1].childNodes[0].data == 'x' + assert mml.childNodes[1].childNodes[2].childNodes[0].data == '3' + + assert mml.childNodes[4].childNodes[0].nodeName == 'power' + assert mml.childNodes[4].childNodes[1].childNodes[0].data == 'y' + assert mml.childNodes[4].childNodes[2].childNodes[0].data == '4' + + mp = MathMLContentPrinter({'order': 'rev-lex'}) + mml = mp._print(expr) + + assert mml.childNodes[1].childNodes[0].nodeName == 'power' + assert mml.childNodes[1].childNodes[1].childNodes[0].data == 'y' + assert mml.childNodes[1].childNodes[2].childNodes[0].data == '4' + + assert mml.childNodes[4].childNodes[0].nodeName == 'power' + assert mml.childNodes[4].childNodes[1].childNodes[0].data == 'x' + assert mml.childNodes[4].childNodes[2].childNodes[0].data == '3' + + +def test_content_settings(): + raises(TypeError, lambda: mathml(x, method="garbage")) + + +def test_content_mathml_logic(): + assert mathml(And(x, y)) == 'xy' + assert mathml(Or(x, y)) == 'xy' + assert mathml(Xor(x, y)) == 'xy' + assert mathml(Implies(x, y)) == 'xy' + assert mathml(Not(x)) == 'x' + + +def test_content_finite_sets(): + assert mathml(FiniteSet(a)) == 'a' + assert mathml(FiniteSet(a, b)) == 'ab' + assert mathml(FiniteSet(FiniteSet(a, b), c)) == \ + 'cab' + + A = FiniteSet(a) + B = FiniteSet(b) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(A, B, evaluate=False) + U2 = Union(C, D, evaluate=False) + I1 = Intersection(A, B, evaluate=False) + I2 = Intersection(C, D, evaluate=False) + C1 = Complement(A, B, evaluate=False) + C2 = Complement(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(A, B) + P2 = ProductSet(C, D) + + assert mathml(U1) == \ + 'ab' + assert mathml(I1) == \ + 'ab' \ + '' + assert mathml(C1) == \ + 'ab' + assert mathml(P1) == \ + 'ab' \ + '' + + assert mathml(Intersection(A, U2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Intersection(U1, U2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + + # XXX Does the parenthesis appear correctly for these examples in mathjax? + assert mathml(Intersection(C1, C2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Intersection(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(Union(A, I2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Union(I1, I2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Union(C1, C2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Union(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(Complement(A, C2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Complement(U1, U2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Complement(I1, I2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Complement(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(ProductSet(A, P2)) == \ + 'a' \ + 'c' \ + 'd' + assert mathml(ProductSet(U1, U2)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(ProductSet(I1, I2)) == \ + 'a' \ + 'b' \ + 'cd' + assert mathml(ProductSet(C1, C2)) == \ + 'a' \ + 'b' \ + 'cd' + + +def test_presentation_printmethod(): + assert mpp.doprint(1 + x) == 'x+1' + assert mpp.doprint(x**2) == 'x2' + assert mpp.doprint(x**-1) == '1x' + assert mpp.doprint(x**-2) == \ + '1x2' + assert mpp.doprint(2*x) == \ + '2x' + + +def test_presentation_mathml_core(): + mml_1 = mpp._print(1 + x) + assert mml_1.nodeName == 'mrow' + nodes = mml_1.childNodes + assert len(nodes) == 3 + assert nodes[0].nodeName in ['mi', 'mn'] + assert nodes[1].nodeName == 'mo' + if nodes[0].nodeName == 'mn': + assert nodes[0].childNodes[0].nodeValue == '1' + assert nodes[2].childNodes[0].nodeValue == 'x' + else: + assert nodes[0].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mpp._print(x**2) + assert mml_2.nodeName == 'msup' + nodes = mml_2.childNodes + assert nodes[0].childNodes[0].nodeValue == 'x' + assert nodes[1].childNodes[0].nodeValue == '2' + + mml_3 = mpp._print(2*x) + assert mml_3.nodeName == 'mrow' + nodes = mml_3.childNodes + assert nodes[0].childNodes[0].nodeValue == '2' + assert nodes[1].childNodes[0].nodeValue == '⁢' + assert nodes[2].childNodes[0].nodeValue == 'x' + + mml = mpp._print(Float(1.0, 2)*x) + assert mml.nodeName == 'mrow' + nodes = mml.childNodes + assert nodes[0].childNodes[0].nodeValue == '1.0' + assert nodes[1].childNodes[0].nodeValue == '⁢' + assert nodes[2].childNodes[0].nodeValue == 'x' + + +def test_presentation_mathml_functions(): + mml_1 = mpp._print(sin(x)) + assert mml_1.childNodes[0].childNodes[0 + ].nodeValue == 'sin' + assert mml_1.childNodes[1].childNodes[1 + ].childNodes[0].nodeValue == 'x' + + mml_2 = mpp._print(diff(sin(x), x, evaluate=False)) + assert mml_2.nodeName == 'mrow' + assert mml_2.childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == 'ⅆ' + assert mml_2.childNodes[1].childNodes[1 + ].nodeName == 'mrow' + assert mml_2.childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == 'ⅆ' + + mml_3 = mpp._print(diff(cos(x*y), x, evaluate=False)) + assert mml_3.childNodes[0].nodeName == 'mfrac' + assert mml_3.childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '∂' + assert mml_3.childNodes[1].childNodes[0 + ].childNodes[0].nodeValue == 'cos' + + +def test_print_derivative(): + f = Function('f') + d = Derivative(f(x, y, z), x, z, x, z, z, y) + assert mathml(d) == \ + 'yz2xzxxyz' + assert mathml(d, printer='presentation') == \ + '6y2zxzxf(x,y,z)' + + +def test_presentation_mathml_limits(): + lim_fun = sin(x)/x + mml_1 = mpp._print(Limit(lim_fun, x, 0)) + assert mml_1.childNodes[0].nodeName == 'munder' + assert mml_1.childNodes[0].childNodes[0 + ].childNodes[0].nodeValue == 'lim' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[1].childNodes[0 + ].nodeValue == '→' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[2].childNodes[0 + ].nodeValue == '0' + + +def test_presentation_mathml_integrals(): + assert mpp.doprint(Integral(x, (x, 0, 1))) == \ + '01'\ + 'xx' + assert mpp.doprint(Integral(log(x), x)) == \ + 'log(x' \ + ')x' + assert mpp.doprint(Integral(x*y, x, y)) == \ + 'x'\ + 'yyx' + z, w = symbols('z w') + assert mpp.doprint(Integral(x*y*z, x, y, z)) == \ + 'x'\ + 'yz'\ + 'zyx' + assert mpp.doprint(Integral(x*y*z*w, x, y, z, w)) == \ + ''\ + 'w'\ + 'xy'\ + 'zw'\ + 'zyx' + assert mpp.doprint(Integral(x, x, y, (z, 0, 1))) == \ + '01'\ + 'xz'\ + 'yx' + assert mpp.doprint(Integral(x, (x, 0))) == \ + '0x'\ + 'x' + + +def test_presentation_mathml_matrices(): + A = Matrix([1, 2, 3]) + B = Matrix([[0, 5, 4], [2, 3, 1], [9, 7, 9]]) + mll_1 = mpp._print(A) + assert mll_1.childNodes[1].nodeName == 'mtable' + assert mll_1.childNodes[1].childNodes[0].nodeName == 'mtr' + assert len(mll_1.childNodes[1].childNodes) == 3 + assert mll_1.childNodes[1].childNodes[0].childNodes[0].nodeName == 'mtd' + assert len(mll_1.childNodes[1].childNodes[0].childNodes) == 1 + assert mll_1.childNodes[1].childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_1.childNodes[1].childNodes[1].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_1.childNodes[1].childNodes[2].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '3' + mll_2 = mpp._print(B) + assert mll_2.childNodes[1].nodeName == 'mtable' + assert mll_2.childNodes[1].childNodes[0].nodeName == 'mtr' + assert len(mll_2.childNodes[1].childNodes) == 3 + assert mll_2.childNodes[1].childNodes[0].childNodes[0].nodeName == 'mtd' + assert len(mll_2.childNodes[1].childNodes[0].childNodes) == 3 + assert mll_2.childNodes[1].childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '0' + assert mll_2.childNodes[1].childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '5' + assert mll_2.childNodes[1].childNodes[0].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '4' + assert mll_2.childNodes[1].childNodes[1].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_2.childNodes[1].childNodes[1].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '3' + assert mll_2.childNodes[1].childNodes[1].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_2.childNodes[1].childNodes[2].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '9' + assert mll_2.childNodes[1].childNodes[2].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '7' + assert mll_2.childNodes[1].childNodes[2].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '9' + + +def test_presentation_mathml_sums(): + mml_1 = mpp._print(Sum(x, (x, 1, 10))) + assert mml_1.childNodes[0].nodeName == 'munderover' + assert len(mml_1.childNodes[0].childNodes) == 3 + assert mml_1.childNodes[0].childNodes[0].childNodes[0 + ].nodeValue == '∑' + assert len(mml_1.childNodes[0].childNodes[1].childNodes) == 3 + assert mml_1.childNodes[0].childNodes[2].childNodes[0 + ].nodeValue == '10' + assert mml_1.childNodes[1].childNodes[0].nodeValue == 'x' + + assert mpp.doprint(Sum(x, (x, 1, 10))) == \ + 'x=110x' + assert mpp.doprint(Sum(x + y, (x, 1, 10))) == \ + 'x=110(x+y)' + + +def test_presentation_mathml_add(): + mml = mpp._print(x**5 - x**4 + x) + assert len(mml.childNodes) == 5 + assert mml.childNodes[0].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].childNodes[0 + ].nodeValue == '5' + assert mml.childNodes[1].childNodes[0].nodeValue == '-' + assert mml.childNodes[2].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml.childNodes[2].childNodes[1].childNodes[0 + ].nodeValue == '4' + assert mml.childNodes[3].childNodes[0].nodeValue == '+' + assert mml.childNodes[4].childNodes[0].nodeValue == 'x' + + +def test_presentation_mathml_Rational(): + mml_1 = mpp._print(Rational(1, 1)) + assert mml_1.nodeName == 'mn' + + mml_2 = mpp._print(Rational(2, 5)) + assert mml_2.nodeName == 'mfrac' + assert mml_2.childNodes[0].childNodes[0].nodeValue == '2' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '5' + + +def test_presentation_mathml_constants(): + mml = mpp._print(I) + assert mml.childNodes[0].nodeValue == 'ⅈ' + + mml = mpp._print(E) + assert mml.childNodes[0].nodeValue == 'ⅇ' + + mml = mpp._print(oo) + assert mml.childNodes[0].nodeValue == '∞' + + mml = mpp._print(pi) + assert mml.childNodes[0].nodeValue == 'π' + + assert mathml(hbar, printer='presentation') == '' + assert mathml(S.TribonacciConstant, printer='presentation' + ) == 'TribonacciConstant' + assert mathml(S.EulerGamma, printer='presentation' + ) == 'γ' + assert mathml(S.GoldenRatio, printer='presentation' + ) == 'Φ' + + assert mathml(zoo, printer='presentation') == \ + '~' + + assert mathml(S.NaN, printer='presentation') == 'NaN' + +def test_presentation_mathml_trig(): + mml = mpp._print(sin(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'sin' + + mml = mpp._print(cos(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'cos' + + mml = mpp._print(tan(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'tan' + + mml = mpp._print(asin(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arcsin' + + mml = mpp._print(acos(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arccos' + + mml = mpp._print(atan(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arctan' + + mml = mpp._print(sinh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'sinh' + + mml = mpp._print(cosh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'cosh' + + mml = mpp._print(tanh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'tanh' + + mml = mpp._print(asinh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arcsinh' + + mml = mpp._print(atanh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arctanh' + + mml = mpp._print(acosh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arccosh' + + +def test_presentation_mathml_relational(): + mml_1 = mpp._print(Eq(x, 1)) + assert len(mml_1.childNodes) == 3 + assert mml_1.childNodes[0].nodeName == 'mi' + assert mml_1.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml_1.childNodes[1].nodeName == 'mo' + assert mml_1.childNodes[1].childNodes[0].nodeValue == '=' + assert mml_1.childNodes[2].nodeName == 'mn' + assert mml_1.childNodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mpp._print(Ne(1, x)) + assert len(mml_2.childNodes) == 3 + assert mml_2.childNodes[0].nodeName == 'mn' + assert mml_2.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_2.childNodes[1].nodeName == 'mo' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '≠' + assert mml_2.childNodes[2].nodeName == 'mi' + assert mml_2.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_3 = mpp._print(Ge(1, x)) + assert len(mml_3.childNodes) == 3 + assert mml_3.childNodes[0].nodeName == 'mn' + assert mml_3.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_3.childNodes[1].nodeName == 'mo' + assert mml_3.childNodes[1].childNodes[0].nodeValue == '≥' + assert mml_3.childNodes[2].nodeName == 'mi' + assert mml_3.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_4 = mpp._print(Lt(1, x)) + assert len(mml_4.childNodes) == 3 + assert mml_4.childNodes[0].nodeName == 'mn' + assert mml_4.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_4.childNodes[1].nodeName == 'mo' + assert mml_4.childNodes[1].childNodes[0].nodeValue == '<' + assert mml_4.childNodes[2].nodeName == 'mi' + assert mml_4.childNodes[2].childNodes[0].nodeValue == 'x' + + +def test_presentation_symbol(): + mml = mpp._print(x) + assert mml.nodeName == 'mi' + assert mml.childNodes[0].nodeValue == 'x' + del mml + + mml = mpp._print(Symbol("x^2")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x__2")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x_2")) + assert mml.nodeName == 'msub' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x^3_2")) + assert mml.nodeName == 'msubsup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[2].nodeName == 'mi' + assert mml.childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mpp._print(Symbol("x__3_2")) + assert mml.nodeName == 'msubsup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[2].nodeName == 'mi' + assert mml.childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mpp._print(Symbol("x_2_a")) + assert mml.nodeName == 'msub' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + mml = mpp._print(Symbol("x^2^a")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + mml = mpp._print(Symbol("x__2__a")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + +def test_presentation_mathml_greek(): + mml = mpp._print(Symbol('alpha')) + assert mml.nodeName == 'mi' + assert mml.childNodes[0].nodeValue == '\N{GREEK SMALL LETTER ALPHA}' + + assert mpp.doprint(Symbol('alpha')) == 'α' + assert mpp.doprint(Symbol('beta')) == 'β' + assert mpp.doprint(Symbol('gamma')) == 'γ' + assert mpp.doprint(Symbol('delta')) == 'δ' + assert mpp.doprint(Symbol('epsilon')) == 'ε' + assert mpp.doprint(Symbol('zeta')) == 'ζ' + assert mpp.doprint(Symbol('eta')) == 'η' + assert mpp.doprint(Symbol('theta')) == 'θ' + assert mpp.doprint(Symbol('iota')) == 'ι' + assert mpp.doprint(Symbol('kappa')) == 'κ' + assert mpp.doprint(Symbol('lambda')) == 'λ' + assert mpp.doprint(Symbol('mu')) == 'μ' + assert mpp.doprint(Symbol('nu')) == 'ν' + assert mpp.doprint(Symbol('xi')) == 'ξ' + assert mpp.doprint(Symbol('omicron')) == 'ο' + assert mpp.doprint(Symbol('pi')) == 'π' + assert mpp.doprint(Symbol('rho')) == 'ρ' + assert mpp.doprint(Symbol('varsigma')) == 'ς' + assert mpp.doprint(Symbol('sigma')) == 'σ' + assert mpp.doprint(Symbol('tau')) == 'τ' + assert mpp.doprint(Symbol('upsilon')) == 'υ' + assert mpp.doprint(Symbol('phi')) == 'φ' + assert mpp.doprint(Symbol('chi')) == 'χ' + assert mpp.doprint(Symbol('psi')) == 'ψ' + assert mpp.doprint(Symbol('omega')) == 'ω' + + assert mpp.doprint(Symbol('Alpha')) == 'Α' + assert mpp.doprint(Symbol('Beta')) == 'Β' + assert mpp.doprint(Symbol('Gamma')) == 'Γ' + assert mpp.doprint(Symbol('Delta')) == 'Δ' + assert mpp.doprint(Symbol('Epsilon')) == 'Ε' + assert mpp.doprint(Symbol('Zeta')) == 'Ζ' + assert mpp.doprint(Symbol('Eta')) == 'Η' + assert mpp.doprint(Symbol('Theta')) == 'Θ' + assert mpp.doprint(Symbol('Iota')) == 'Ι' + assert mpp.doprint(Symbol('Kappa')) == 'Κ' + assert mpp.doprint(Symbol('Lambda')) == 'Λ' + assert mpp.doprint(Symbol('Mu')) == 'Μ' + assert mpp.doprint(Symbol('Nu')) == 'Ν' + assert mpp.doprint(Symbol('Xi')) == 'Ξ' + assert mpp.doprint(Symbol('Omicron')) == 'Ο' + assert mpp.doprint(Symbol('Pi')) == 'Π' + assert mpp.doprint(Symbol('Rho')) == 'Ρ' + assert mpp.doprint(Symbol('Sigma')) == 'Σ' + assert mpp.doprint(Symbol('Tau')) == 'Τ' + assert mpp.doprint(Symbol('Upsilon')) == 'Υ' + assert mpp.doprint(Symbol('Phi')) == 'Φ' + assert mpp.doprint(Symbol('Chi')) == 'Χ' + assert mpp.doprint(Symbol('Psi')) == 'Ψ' + assert mpp.doprint(Symbol('Omega')) == 'Ω' + + +def test_presentation_mathml_order(): + expr = x**3 + x**2*y + 3*x*y**3 + y**4 + + mp = MathMLPresentationPrinter({'order': 'lex'}) + mml = mp._print(expr) + assert mml.childNodes[0].nodeName == 'msup' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '3' + + assert mml.childNodes[6].nodeName == 'msup' + assert mml.childNodes[6].childNodes[0].childNodes[0].nodeValue == 'y' + assert mml.childNodes[6].childNodes[1].childNodes[0].nodeValue == '4' + + mp = MathMLPresentationPrinter({'order': 'rev-lex'}) + mml = mp._print(expr) + + assert mml.childNodes[0].nodeName == 'msup' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'y' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '4' + + assert mml.childNodes[6].nodeName == 'msup' + assert mml.childNodes[6].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[6].childNodes[1].childNodes[0].nodeValue == '3' + + +def test_print_intervals(): + a = Symbol('a', real=True) + assert mpp.doprint(Interval(0, a)) == \ + '[0,a]' + assert mpp.doprint(Interval(0, a, False, False)) == \ + '[0,a]' + assert mpp.doprint(Interval(0, a, True, False)) == \ + '(0,a]' + assert mpp.doprint(Interval(0, a, False, True)) == \ + '[0,a)' + assert mpp.doprint(Interval(0, a, True, True)) == \ + '(0,a)' + + +def test_print_tuples(): + assert mpp.doprint(Tuple(0,)) == \ + '(0)' + assert mpp.doprint(Tuple(0, a)) == \ + '(0,a)' + assert mpp.doprint(Tuple(0, a, a)) == \ + '(0,a,a)' + assert mpp.doprint(Tuple(0, 1, 2, 3, 4)) == \ + '(0,1,2,3,4)' + assert mpp.doprint(Tuple(0, 1, Tuple(2, 3, 4))) == \ + '(0,1,(2,3'\ + ',4))' + + +def test_print_re_im(): + assert mpp.doprint(re(x)) == \ + '(x)' + assert mpp.doprint(im(x)) == \ + '(x)' + assert mpp.doprint(re(x + 1, evaluate=False)) == \ + '(x+1)' + assert mpp.doprint(im(x + 1, evaluate=False)) == \ + '(x+1)' + + +def test_print_Abs(): + assert mpp.doprint(Abs(x)) == \ + '|x|' + assert mpp.doprint(Abs(x + 1)) == \ + '|x+1|' + + +def test_print_Determinant(): + assert mpp.doprint(Determinant(Matrix([[1, 2], [3, 4]]))) == \ + '|[1234]|' + + +def test_presentation_settings(): + raises(TypeError, lambda: mathml(x, printer='presentation', + method="garbage")) + + +def test_print_domains(): + from sympy.sets import Integers, Naturals, Naturals0, Reals, Complexes + + assert mpp.doprint(Complexes) == '' + assert mpp.doprint(Integers) == '' + assert mpp.doprint(Naturals) == '' + assert mpp.doprint(Naturals0) == \ + '0' + assert mpp.doprint(Reals) == '' + + +def test_print_expression_with_minus(): + assert mpp.doprint(-x) == '-x' + assert mpp.doprint(-x/y) == \ + '-xy' + assert mpp.doprint(-Rational(1, 2)) == \ + '-12' + + +def test_print_AssocOp(): + from sympy.core.operations import AssocOp + + class TestAssocOp(AssocOp): + identity = 0 + + expr = TestAssocOp(1, 2) + assert mpp.doprint(expr) == \ + 'testassocop12' + + +def test_print_basic(): + expr = Basic(S(1), S(2)) + assert mpp.doprint(expr) == \ + 'basic(1,2)' + assert mp.doprint(expr) == '12' + + +def test_mat_delim_print(): + expr = Matrix([[1, 2], [3, 4]]) + assert mathml(expr, printer='presentation', mat_delim='[') == \ + '[1'\ + '234'\ + ']' + assert mathml(expr, printer='presentation', mat_delim='(') == \ + '(12'\ + '34)' + assert mathml(expr, printer='presentation', mat_delim='') == \ + '12'\ + '34' + + +def test_ln_notation_print(): + expr = log(x) + assert mathml(expr, printer='presentation') == \ + 'log(x)' + assert mathml(expr, printer='presentation', ln_notation=False) == \ + 'log(x)' + assert mathml(expr, printer='presentation', ln_notation=True) == \ + 'ln(x)' + + +def test_mul_symbol_print(): + expr = x * y + assert mathml(expr, printer='presentation') == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol=None) == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol='dot') == \ + 'x·y' + assert mathml(expr, printer='presentation', mul_symbol='ldot') == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol='times') == \ + 'x×y' + + +def test_print_lerchphi(): + assert mpp.doprint(lerchphi(1, 2, 3)) == \ + 'Φ(1,2,3)' + + +def test_print_polylog(): + assert mp.doprint(polylog(x, y)) == \ + 'xy' + assert mpp.doprint(polylog(x, y)) == \ + 'Lix(y)' + + +def test_print_set_frozenset(): + f = frozenset({1, 5, 3}) + assert mpp.doprint(f) == \ + '{1,3,5}' + s = set({1, 2, 3}) + assert mpp.doprint(s) == \ + '{1,2,3}' + + +def test_print_FiniteSet(): + f1 = FiniteSet(x, 1, 3) + assert mpp.doprint(f1) == \ + '{1,3,x}' + + +def test_print_LambertW(): + assert mpp.doprint(LambertW(x)) == 'W(x)' + assert mpp.doprint(LambertW(x, y)) == 'W(x,y)' + + +def test_print_EmptySet(): + assert mpp.doprint(S.EmptySet) == '' + + +def test_print_UniversalSet(): + assert mpp.doprint(S.UniversalSet) == '𝕌' + + +def test_print_spaces(): + assert mpp.doprint(HilbertSpace()) == '' + assert mpp.doprint(ComplexSpace(2)) == '𝒞2' + assert mpp.doprint(FockSpace()) == '' + + +def test_print_constants(): + assert mpp.doprint(hbar) == '' + assert mpp.doprint(S.TribonacciConstant) == 'TribonacciConstant' + assert mpp.doprint(S.GoldenRatio) == 'Φ' + assert mpp.doprint(S.EulerGamma) == 'γ' + + +def test_print_Contains(): + assert mpp.doprint(Contains(x, S.Naturals)) == \ + 'x' + + +def test_print_Dagger(): + x = symbols('x', commutative=False) + assert mpp.doprint(Dagger(x)) == 'x' + + +def test_print_SetOp(): + f1 = FiniteSet(x, 1, 3) + f2 = FiniteSet(y, 2, 4) + + prntr = lambda x: mathml(x, printer='presentation') + + assert prntr(Union(f1, f2, evaluate=False)) == \ + '{1,3,x'\ + '}{2,'\ + '4,y}' + assert prntr(Intersection(f1, f2, evaluate=False)) == \ + '{1,3,x'\ + '}{2'\ + ',4,y}' + assert prntr(Complement(f1, f2, evaluate=False)) == \ + '{1,3,x'\ + '}{2'\ + ',4,y}' + assert prntr(SymmetricDifference(f1, f2, evaluate=False)) == \ + '{1,3,x'\ + '}{2'\ + ',4,y}' + + A = FiniteSet(a) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(C, D, evaluate=False) + I1 = Intersection(C, D, evaluate=False) + C1 = Complement(C, D, evaluate=False) + D1 = SymmetricDifference(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(C, D) + + assert prntr(Union(A, I1, evaluate=False)) == \ + '{a}' \ + '({' \ + 'c}{' \ + 'd})' + assert prntr(Intersection(A, C1, evaluate=False)) == \ + '{a}' \ + '({' \ + 'c}{' \ + 'd})' + assert prntr(Complement(A, D1, evaluate=False)) == \ + '{a}' \ + '({' \ + 'c}{' \ + 'd})' + assert prntr(SymmetricDifference(A, P1, evaluate=False)) == \ + '{a}' \ + '({' \ + 'c}×{' \ + 'd})' + assert prntr(ProductSet(A, U1)) == \ + '{a}' \ + '×({' \ + 'c}{' \ + 'd})' + + +def test_print_logic(): + assert mpp.doprint(And(x, y)) == \ + 'xy' + assert mpp.doprint(Or(x, y)) == \ + 'xy' + assert mpp.doprint(Xor(x, y)) == \ + 'xy' + assert mpp.doprint(Implies(x, y)) == \ + 'xy' + assert mpp.doprint(Equivalent(x, y)) == \ + 'xy' + + assert mpp.doprint(And(Eq(x, y), x > 4)) == \ + 'x=y'\ + 'x>4' + assert mpp.doprint(And(Eq(x, 3), y < 3, x > y + 1)) == \ + 'x=3'\ + 'x>y+1'\ + 'y<3' + assert mpp.doprint(Or(Eq(x, y), x > 4)) == \ + 'x=y'\ + 'x>4' + assert mpp.doprint(And(Eq(x, 3), Or(y < 3, x > y + 1))) == \ + 'x=3'\ + '(x>'\ + 'y+1'\ + 'y<3)' + + assert mpp.doprint(Not(x)) == '¬x' + assert mpp.doprint(Not(And(x, y))) == \ + '¬(xy)' + + +def test_root_notation_print(): + assert mathml(x**(S.One/3), printer='presentation') == \ + 'x3' + assert mathml(x**(S.One/3), printer='presentation', root_notation=False) ==\ + 'x13' + assert mathml(x**(S.One/3), printer='content') == \ + '3x' + assert mathml(x**(S.One/3), printer='content', root_notation=False) == \ + 'x13' + assert mathml(x**(Rational(-1, 3)), printer='presentation') == \ + '1x3' + assert mathml(x**(Rational(-1, 3)), printer='presentation', root_notation=False) \ + == '1x13' + + +def test_fold_frac_powers_print(): + expr = x ** Rational(5, 2) + assert mathml(expr, printer='presentation') == \ + 'x52' + assert mathml(expr, printer='presentation', fold_frac_powers=True) == \ + 'x52' + assert mathml(expr, printer='presentation', fold_frac_powers=False) == \ + 'x52' + + +def test_fold_short_frac_print(): + expr = Rational(2, 5) + assert mathml(expr, printer='presentation') == \ + '25' + assert mathml(expr, printer='presentation', fold_short_frac=True) == \ + '25' + assert mathml(expr, printer='presentation', fold_short_frac=False) == \ + '25' + + +def test_print_factorials(): + assert mpp.doprint(factorial(x)) == 'x!' + assert mpp.doprint(factorial(x + 1)) == \ + '(x+1)!' + assert mpp.doprint(factorial2(x)) == 'x!!' + assert mpp.doprint(factorial2(x + 1)) == \ + '(x+1)!!' + assert mpp.doprint(binomial(x, y)) == \ + '(xy)' + assert mpp.doprint(binomial(4, x + y)) == \ + '(4x'\ + '+y)' + + +def test_print_floor(): + expr = floor(x) + assert mathml(expr, printer='presentation') == \ + 'x' + + +def test_print_ceiling(): + expr = ceiling(x) + assert mathml(expr, printer='presentation') == \ + 'x' + + +def test_print_Lambda(): + expr = Lambda(x, x+1) + assert mathml(expr, printer='presentation') == \ + '(xx+1)' + expr = Lambda((x, y), x + y) + assert mathml(expr, printer='presentation') == \ + '((x,y)x+y)' + + +def test_print_conjugate(): + assert mpp.doprint(conjugate(x)) == \ + 'x' + assert mpp.doprint(conjugate(x + 1)) == \ + 'x+1' + + +def test_print_AccumBounds(): + a = Symbol('a', real=True) + assert mpp.doprint(AccumBounds(0, 1)) == '0,1' + assert mpp.doprint(AccumBounds(0, a)) == '0,a' + assert mpp.doprint(AccumBounds(a + 1, a + 2)) == 'a+1,a+2' + + +def test_print_Float(): + assert mpp.doprint(Float(1e100)) == '1.0·10100' + assert mpp.doprint(Float(1e-100)) == '1.0·10-100' + assert mpp.doprint(Float(-1e100)) == '-1.0·10100' + assert mpp.doprint(Float(1.0*oo)) == '' + assert mpp.doprint(Float(-1.0*oo)) == '-' + + +def test_print_different_functions(): + assert mpp.doprint(gamma(x)) == 'Γ(x)' + assert mpp.doprint(lowergamma(x, y)) == 'γ(x,y)' + assert mpp.doprint(uppergamma(x, y)) == 'Γ(x,y)' + assert mpp.doprint(zeta(x)) == 'ζ(x)' + assert mpp.doprint(zeta(x, y)) == 'ζ(x,y)' + assert mpp.doprint(dirichlet_eta(x)) == 'η(x)' + assert mpp.doprint(elliptic_k(x)) == 'Κ(x)' + assert mpp.doprint(totient(x)) == 'ϕ(x)' + assert mpp.doprint(reduced_totient(x)) == 'λ(x)' + assert mpp.doprint(primenu(x)) == 'ν(x)' + assert mpp.doprint(primeomega(x)) == 'Ω(x)' + assert mpp.doprint(fresnels(x)) == 'S(x)' + assert mpp.doprint(fresnelc(x)) == 'C(x)' + assert mpp.doprint(Heaviside(x)) == 'Θ(x,12)' + + +def test_mathml_builtins(): + assert mpp.doprint(None) == 'None' + assert mpp.doprint(true) == 'True' + assert mpp.doprint(false) == 'False' + + +def test_mathml_Range(): + assert mpp.doprint(Range(1, 51)) == \ + '{1,2,,50}' + assert mpp.doprint(Range(1, 4)) == \ + '{1,2,3}' + assert mpp.doprint(Range(0, 3, 1)) == \ + '{0,1,2}' + assert mpp.doprint(Range(0, 30, 1)) == \ + '{0,1,,29}' + assert mpp.doprint(Range(30, 1, -1)) == \ + '{30,29,,2}' + assert mpp.doprint(Range(0, oo, 2)) == \ + '{0,2,}' + assert mpp.doprint(Range(oo, -2, -2)) == \ + '{,2,0}' + assert mpp.doprint(Range(-2, -oo, -1)) == \ + '{-2,-3,}' + + +def test_print_exp(): + assert mpp.doprint(exp(x)) == \ + 'x' + assert mpp.doprint(exp(1) + exp(2)) == \ + '+2' + + +def test_print_MinMax(): + assert mpp.doprint(Min(x, y)) == \ + 'min(x,y)' + assert mpp.doprint(Min(x, 2, x**3)) == \ + 'min(2,x,x3)' + assert mpp.doprint(Max(x, y)) == \ + 'max(x,y)' + assert mpp.doprint(Max(x, 2, x**3)) == \ + 'max(2,x,x3)' + + +def test_mathml_presentation_numbers(): + n = Symbol('n') + assert mathml(catalan(n), printer='presentation') == \ + 'Cn' + assert mathml(bernoulli(n), printer='presentation') == \ + 'Bn' + assert mathml(bell(n), printer='presentation') == \ + 'Bn' + assert mathml(euler(n), printer='presentation') == \ + 'En' + assert mathml(fibonacci(n), printer='presentation') == \ + 'Fn' + assert mathml(lucas(n), printer='presentation') == \ + 'Ln' + assert mathml(tribonacci(n), printer='presentation') == \ + 'Tn' + assert mathml(bernoulli(n, x), printer='presentation') == \ + mathml(bell(n, x), printer='presentation') == \ + 'Bn(x)' + assert mathml(euler(n, x), printer='presentation') == \ + 'En(x)' + assert mathml(fibonacci(n, x), printer='presentation') == \ + 'Fn(x)' + assert mathml(tribonacci(n, x), printer='presentation') == \ + 'Tn(x)' + + +def test_mathml_presentation_mathieu(): + assert mathml(mathieuc(x, y, z), printer='presentation') == \ + 'C(x,y,z)' + assert mathml(mathieus(x, y, z), printer='presentation') == \ + 'S(x,y,z)' + assert mathml(mathieucprime(x, y, z), printer='presentation') == \ + 'C′(x,y,z)' + assert mathml(mathieusprime(x, y, z), printer='presentation') == \ + 'S′(x,y,z)' + + +def test_mathml_presentation_stieltjes(): + assert mathml(stieltjes(n), printer='presentation') == \ + 'γn' + assert mathml(stieltjes(n, x), printer='presentation') == \ + 'γn(x)' + + +def test_print_matrix_symbol(): + A = MatrixSymbol('A', 1, 2) + assert mpp.doprint(A) == 'A' + assert mp.doprint(A) == 'A' + assert mathml(A, printer='presentation', mat_symbol_style="bold") == \ + 'A' + # No effect in content printer + assert mathml(A, mat_symbol_style="bold") == 'A' + + +def test_print_hadamard(): + from sympy.matrices.expressions import HadamardProduct + from sympy.matrices.expressions import Transpose + + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + + assert mathml(HadamardProduct(X, Y*Y), printer="presentation") == \ + '' \ + 'X' \ + '' \ + 'Y2' \ + '' + + assert mathml(HadamardProduct(X, Y)*Y, printer="presentation") == \ + '' \ + '(' \ + 'XY' \ + ')' \ + 'Y' \ + '' + + assert mathml(HadamardProduct(X, Y, Y), printer="presentation") == \ + '' \ + 'X' \ + 'Y' \ + 'Y' \ + '' + + assert mathml( + Transpose(HadamardProduct(X, Y)), printer="presentation") == \ + '' \ + '(' \ + 'XY' \ + ')' \ + 'T' \ + '' + + +def test_print_random_symbol(): + R = RandomSymbol(Symbol('R')) + assert mpp.doprint(R) == 'R' + assert mp.doprint(R) == 'R' + + +def test_print_IndexedBase(): + assert mathml(IndexedBase(a)[b], printer='presentation') == \ + 'ab' + assert mathml(IndexedBase(a)[b, c, d], printer='presentation') == \ + 'a(b,c,d)' + assert mathml(IndexedBase(a)[b]*IndexedBase(c)[d]*IndexedBase(e), + printer='presentation') == \ + 'ab⁢'\ + 'cde' + + +def test_print_Indexed(): + assert mathml(IndexedBase(a), printer='presentation') == 'a' + assert mathml(IndexedBase(a/b), printer='presentation') == \ + 'ab' + assert mathml(IndexedBase((a, b)), printer='presentation') == \ + '(a,b)' + +def test_print_MatrixElement(): + i, j = symbols('i j') + A = MatrixSymbol('A', i, j) + assert mathml(A[0,0],printer = 'presentation') == \ + 'A0,0' + assert mathml(A[i,j], printer = 'presentation') == \ + 'Ai,j' + assert mathml(A[i*j,0], printer = 'presentation') == \ + 'Aij,0' + + +def test_print_Vector(): + ACS = CoordSys3D('A') + assert mathml(Cross(ACS.i, ACS.j*ACS.x*3 + ACS.k), printer='presentation') == \ + 'i^'\ + 'A×('\ + '(3'\ + 'xA'\ + ')'\ + 'j^'\ + 'A+'\ + 'k^'\ + 'A)' + assert mathml(Cross(ACS.i, ACS.j), printer='presentation') == \ + 'i^'\ + 'A×'\ + 'j^'\ + 'A' + assert mathml(x*Cross(ACS.i, ACS.j), printer='presentation') == \ + 'x('\ + 'i^'\ + 'A×'\ + 'j^'\ + 'A)' + assert mathml(Cross(x*ACS.i, ACS.j), printer='presentation') == \ + '-j'\ + '^A'\ + '×((x)'\ + 'i'\ + '^A'\ + ')' + assert mathml(Curl(3*ACS.x*ACS.j), printer='presentation') == \ + '×(('\ + '3x'\ + 'A)'\ + 'j^'\ + 'A)' + assert mathml(Curl(3*x*ACS.x*ACS.j), printer='presentation') == \ + '×(('\ + '3x'\ + 'A'\ + 'x)'\ + 'j^'\ + 'A)' + assert mathml(x*Curl(3*ACS.x*ACS.j), printer='presentation') == \ + 'x('\ + '×((3'\ + 'x'\ + 'A)'\ + 'j'\ + '^A)'\ + ')' + assert mathml(Curl(3*x*ACS.x*ACS.j + ACS.i), printer='presentation') == \ + '×('\ + 'i^'\ + 'A+('\ + '3x'\ + 'A'\ + 'x)'\ + 'j^'\ + 'A)' + assert mathml(Divergence(3*ACS.x*ACS.j), printer='presentation') == \ + '·(('\ + '3x'\ + 'A)'\ + 'j'\ + '^A)' + assert mathml(x*Divergence(3*ACS.x*ACS.j), printer='presentation') == \ + 'x('\ + '·((3'\ + 'x'\ + 'A)'\ + 'j'\ + '^A'\ + '))' + assert mathml(Divergence(3*x*ACS.x*ACS.j + ACS.i), printer='presentation') == \ + '·('\ + 'i^'\ + 'A+('\ + '3'\ + 'xA'\ + 'x)'\ + 'j'\ + '^A'\ + ')' + assert mathml(Dot(ACS.i, ACS.j*ACS.x*3+ACS.k), printer='presentation') == \ + 'i^'\ + 'A·('\ + '(3'\ + 'xA'\ + ')'\ + 'j^'\ + 'A+'\ + 'k^'\ + 'A)' + assert mathml(Dot(ACS.i, ACS.j), printer='presentation') == \ + 'i^'\ + 'A·'\ + 'j^'\ + 'A' + assert mathml(Dot(x*ACS.i, ACS.j), printer='presentation') == \ + 'j^'\ + 'A·('\ + '(x)'\ + 'i^'\ + 'A)' + assert mathml(x*Dot(ACS.i, ACS.j), printer='presentation') == \ + 'x('\ + 'i^'\ + 'A·'\ + 'j^'\ + 'A)' + assert mathml(Gradient(ACS.x), printer='presentation') == \ + 'x'\ + 'A' + assert mathml(Gradient(ACS.x + 3*ACS.y), printer='presentation') == \ + '('\ + 'xA+3'\ + 'y'\ + 'A)' + assert mathml(x*Gradient(ACS.x), printer='presentation') == \ + 'x('\ + 'xA'\ + ')' + assert mathml(Gradient(x*ACS.x), printer='presentation') == \ + '('\ + 'xA'\ + 'x)' + assert mathml(Cross(ACS.z, ACS.x), printer='presentation') == \ + '-x'\ + 'A×'\ + 'zA' + assert mathml(Laplacian(ACS.x), printer='presentation') == \ + 'x'\ + 'A' + assert mathml(Laplacian(ACS.x + 3*ACS.y), printer='presentation') == \ + '('\ + 'xA+3'\ + 'y'\ + 'A)' + assert mathml(x*Laplacian(ACS.x), printer='presentation') == \ + 'x('\ + 'xA'\ + ')' + assert mathml(Laplacian(x*ACS.x), printer='presentation') == \ + '('\ + 'xA'\ + 'x)' + +@XFAIL +def test_vector_cross_xfail(): + ACS = CoordSys3D('A') + assert mathml(Cross(ACS.x, ACS.z) + Cross(ACS.z, ACS.x), printer='presentation') == \ + '0^' + +def test_print_elliptic_f(): + assert mathml(elliptic_f(x, y), printer = 'presentation') == \ + '𝖥(x|y)' + assert mathml(elliptic_f(x/y, y), printer = 'presentation') == \ + '𝖥(xy|y)' + +def test_print_elliptic_e(): + assert mathml(elliptic_e(x), printer = 'presentation') == \ + '𝖤(x)' + assert mathml(elliptic_e(x, y), printer = 'presentation') == \ + '𝖤(x|y)' + +def test_print_elliptic_pi(): + assert mathml(elliptic_pi(x, y), printer = 'presentation') == \ + '𝛱(x|y)' + assert mathml(elliptic_pi(x, y, z), printer = 'presentation') == \ + '𝛱(x;y|z)' + +def test_print_Ei(): + assert mathml(Ei(x), printer = 'presentation') == \ + 'Ei(x)' + assert mathml(Ei(x**y), printer = 'presentation') == \ + 'Ei(xy)' + +def test_print_expint(): + assert mathml(expint(x, y), printer = 'presentation') == \ + 'Ex(y)' + assert mathml(expint(IndexedBase(x)[1], IndexedBase(x)[2]), printer = 'presentation') == \ + 'Ex1(x2)' + +def test_print_jacobi(): + assert mathml(jacobi(n, a, b, x), printer = 'presentation') == \ + 'Pn(a,b)(x)' + +def test_print_gegenbauer(): + assert mathml(gegenbauer(n, a, x), printer = 'presentation') == \ + 'Cn(a)(x)' + +def test_print_chebyshevt(): + assert mathml(chebyshevt(n, x), printer = 'presentation') == \ + 'Tn(x)' + +def test_print_chebyshevu(): + assert mathml(chebyshevu(n, x), printer = 'presentation') == \ + 'Un(x)' + +def test_print_legendre(): + assert mathml(legendre(n, x), printer = 'presentation') == \ + 'Pn(x)' + +def test_print_assoc_legendre(): + assert mathml(assoc_legendre(n, a, x), printer = 'presentation') == \ + 'Pn(a)(x)' + +def test_print_laguerre(): + assert mathml(laguerre(n, x), printer = 'presentation') == \ + 'Ln(x)' + +def test_print_assoc_laguerre(): + assert mathml(assoc_laguerre(n, a, x), printer = 'presentation') == \ + 'Ln(a)(x)' + +def test_print_hermite(): + assert mathml(hermite(n, x), printer = 'presentation') == \ + 'Hn(x)' + +def test_mathml_SingularityFunction(): + assert mathml(SingularityFunction(x, 4, 5), printer='presentation') == \ + 'x-45' + assert mathml(SingularityFunction(x, -3, 4), printer='presentation') == \ + 'x+34' + assert mathml(SingularityFunction(x, 0, 4), printer='presentation') == \ + 'x4' + assert mathml(SingularityFunction(x, a, n), printer='presentation') == \ + '-a+xn' + assert mathml(SingularityFunction(x, 4, -2), printer='presentation') == \ + 'x-4-2' + assert mathml(SingularityFunction(x, 4, -1), printer='presentation') == \ + 'x-4-1' + + +def test_mathml_matrix_functions(): + from sympy.matrices import Adjoint, Inverse, Transpose + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert mathml(Adjoint(X), printer='presentation') == \ + 'X' + assert mathml(Adjoint(X + Y), printer='presentation') == \ + '(X+Y)' + assert mathml(Adjoint(X) + Adjoint(Y), printer='presentation') == \ + 'X+' \ + 'Y' + assert mathml(Adjoint(X*Y), printer='presentation') == \ + '(X' \ + 'Y)' + assert mathml(Adjoint(Y)*Adjoint(X), printer='presentation') == \ + 'Y⁢' \ + 'X' + assert mathml(Adjoint(X**2), printer='presentation') == \ + '(X2)' + assert mathml(Adjoint(X)**2, printer='presentation') == \ + '(X)2' + assert mathml(Adjoint(Inverse(X)), printer='presentation') == \ + '(X-1)' + assert mathml(Inverse(Adjoint(X)), printer='presentation') == \ + '(X)-1' + assert mathml(Adjoint(Transpose(X)), printer='presentation') == \ + '(XT)' + assert mathml(Transpose(Adjoint(X)), printer='presentation') == \ + '(X)T' + assert mathml(Transpose(Adjoint(X) + Y), printer='presentation') == \ + '(X' \ + '+Y)T' + assert mathml(Transpose(X), printer='presentation') == \ + 'XT' + assert mathml(Transpose(X + Y), printer='presentation') == \ + '(X+Y)T' + + +def test_mathml_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert mathml(Identity(4), printer='presentation') == '𝕀' + assert mathml(ZeroMatrix(2, 2), printer='presentation') == '𝟘' + assert mathml(OneMatrix(2, 2), printer='presentation') == '𝟙' + +def test_mathml_piecewise(): + from sympy.functions.elementary.piecewise import Piecewise + # Content MathML + assert mathml(Piecewise((x, x <= 1), (x**2, True))) == \ + 'xx1x2' + + raises(ValueError, lambda: mathml(Piecewise((x, x <= 1)))) + + +def test_issue_17857(): + assert mathml(Range(-oo, oo), printer='presentation') == \ + '{,-1,0,1,}' + assert mathml(Range(oo, -oo, -1), printer='presentation') == \ + '{,1,0,-1,}' + + +def test_float_roundtrip(): + x = sympify(0.8975979010256552) + y = float(mp.doprint(x).strip('')) + assert x == y + + +def test_content_mathml_disable_split_super_sub(): + mp = MathMLContentPrinter() + assert mp.doprint(Symbol('u_b')) == 'ub' + mp = MathMLContentPrinter({'disable_split_super_sub': False}) + assert mp.doprint(Symbol('u_b')) == 'ub' + mp = MathMLContentPrinter({'disable_split_super_sub': True}) + assert mp.doprint(Symbol('u_b')) == 'u_b' + +def test_presentation_mathml_disable_split_super_sub(): + mpp = MathMLPresentationPrinter() + assert mpp.doprint(Symbol('u_b')) == 'ub' + mpp = MathMLPresentationPrinter({'disable_split_super_sub': False}) + assert mpp.doprint(Symbol('u_b')) == 'ub' + mpp = MathMLPresentationPrinter({'disable_split_super_sub': True}) + assert mpp.doprint(Symbol('u_b')) == 'u_b' diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_numpy.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..fee1c6bd95e54790a048220f37b8e5de79017d2f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_numpy.py @@ -0,0 +1,381 @@ +from sympy.concrete.summations import Sum +from sympy.core.mod import Mod +from sympy.core.relational import (Equality, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.gamma_functions import polygamma +from sympy.functions.special.error_functions import (Si, Ci) +from sympy.matrices import Matrix +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.utilities.lambdify import lambdify +from sympy import symbols, Min, Max + +from sympy.abc import x, i, j, a, b, c, d +from sympy.core import Pow +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt +from sympy.tensor.array import Array +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.printing.numpy import NumPyPrinter, SciPyPrinter, _numpy_known_constants, \ + _numpy_known_functions, _scipy_known_constants, _scipy_known_functions +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +np = import_module('numpy') +jax = import_module('jax') + +if np: + deafult_float_info = np.finfo(np.array([]).dtype) + NUMPY_DEFAULT_EPSILON = deafult_float_info.eps + +def test_numpy_piecewise_regression(): + """ + NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid + breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+. + See gh-9747 and gh-9749 for details. + """ + printer = NumPyPrinter() + p = Piecewise((1, x < 0), (0, True)) + assert printer.doprint(p) == \ + 'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)' + assert printer.module_imports == {'numpy': {'select', 'less', 'nan'}} + +def test_numpy_logaddexp(): + lae = logaddexp(a, b) + assert NumPyPrinter().doprint(lae) == 'numpy.logaddexp(a, b)' + lae2 = logaddexp2(a, b) + assert NumPyPrinter().doprint(lae2) == 'numpy.logaddexp2(a, b)' + + +def test_sum(): + if not np: + skip("NumPy not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + + +def test_multiple_sums(): + if not np: + skip("NumPy not installed") + + s = Sum((x + j) * i, (i, a, b), (j, c, d)) + f = lambdify((a, b, c, d, x), s, 'numpy') + + a_, b_ = 0, 10 + c_, d_ = 11, 21 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, c_, d_, x_), + sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1))) + + +def test_codegen_einsum(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'numpy') + + ma = np.array([[1, 2], [3, 4]]) + mb = np.array([[1,-2], [-1, 3]]) + assert (f(ma, mb) == np.matmul(ma, mb)).all() + + +def test_codegen_extra(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = np.array([[1, 2], [3, 4]]) + mb = np.array([[1,-2], [-1, 3]]) + mc = np.array([[2, 0], [1, 2]]) + md = np.array([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.einsum(ma, [0, 1], mb, [2, 3])).all() + + cg = ArrayAdd(M, N) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == ma+mb).all() + + cg = ArrayAdd(M, N, P) + f = lambdify((M, N, P), cg, 'numpy') + assert (f(ma, mb, mc) == ma+mb+mc).all() + + cg = ArrayAdd(M, N, P, Q) + f = lambdify((M, N, P, Q), cg, 'numpy') + assert (f(ma, mb, mc, md) == ma+mb+mc+md).all() + + cg = PermuteDims(M, [1, 0]) + f = lambdify((M,), cg, 'numpy') + assert (f(ma) == ma.T).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.transpose(np.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.diagonal(np.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all() + + +def test_relational(): + if not np: + skip("NumPy not installed") + + e = Equality(x, 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, True, False]) + + e = Unequality(x, 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, False, True]) + + e = (x < 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, False, False]) + + e = (x <= 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, True, False]) + + e = (x > 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, False, True]) + + e = (x >= 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, True, True]) + + +def test_mod(): + if not np: + skip("NumPy not installed") + + e = Mod(a, b) + f = lambdify((a, b), e) + + a_ = np.array([0, 1, 2, 3]) + b_ = 2 + assert np.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = np.array([0, 1, 2, 3]) + b_ = np.array([2, 2, 2, 2]) + assert np.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = np.array([2, 3, 4, 5]) + b_ = np.array([2, 3, 4, 5]) + assert np.array_equal(f(a_, b_), [0, 0, 0, 0]) + + +def test_pow(): + if not np: + skip('NumPy not installed') + + expr = Pow(2, -1, evaluate=False) + f = lambdify([], expr, 'numpy') + assert f() == 0.5 + + +def test_expm1(): + if not np: + skip("NumPy not installed") + + f = lambdify((a,), expm1(a), 'numpy') + assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * NUMPY_DEFAULT_EPSILON + + +def test_log1p(): + if not np: + skip("NumPy not installed") + + f = lambdify((a,), log1p(a), 'numpy') + assert abs(f(1e-99) - 1e-99) <= 1e-99 * NUMPY_DEFAULT_EPSILON + +def test_hypot(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a, b), hypot(a, b), 'numpy')(3, 4) - 5) <= NUMPY_DEFAULT_EPSILON + +def test_log10(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), log10(a), 'numpy')(100) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_exp2(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), exp2(a), 'numpy')(5) - 32) <= NUMPY_DEFAULT_EPSILON + + +def test_log2(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), log2(a), 'numpy')(256) - 8) <= NUMPY_DEFAULT_EPSILON + + +def test_Sqrt(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), Sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_sqrt(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_matsolve(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 3, 3) + x = MatrixSymbol("x", 3, 1) + + expr = M**(-1) * x + x + matsolve_expr = MatrixSolve(M, x) + x + + f = lambdify((M, x), expr) + f_matsolve = lambdify((M, x), matsolve_expr) + + m0 = np.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]]) + assert np.linalg.matrix_rank(m0) == 3 + + x0 = np.array([3, 4, 5]) + + assert np.allclose(f_matsolve(m0, x0), f(m0, x0)) + + +def test_16857(): + if not np: + skip("NumPy not installed") + + a_1 = MatrixSymbol('a_1', 10, 3) + a_2 = MatrixSymbol('a_2', 10, 3) + a_3 = MatrixSymbol('a_3', 10, 3) + a_4 = MatrixSymbol('a_4', 10, 3) + A = BlockMatrix([[a_1, a_2], [a_3, a_4]]) + assert A.shape == (20, 6) + + printer = NumPyPrinter() + assert printer.doprint(A) == 'numpy.block([[a_1, a_2], [a_3, a_4]])' + + +def test_issue_17006(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + + f = lambdify(M, M + Identity(2)) + ma = np.array([[1, 2], [3, 4]]) + mr = np.array([[2, 2], [3, 5]]) + + assert (f(ma) == mr).all() + + from sympy.core.symbol import symbols + n = symbols('n', integer=True) + N = MatrixSymbol("M", n, n) + raises(NotImplementedError, lambda: lambdify(N, N + Identity(n))) + +def test_jax_tuple_compatibility(): + if not jax: + skip("Jax not installed") + + x, y, z = symbols('x y z') + expr = Max(x, y, z) + Min(x, y, z) + func = lambdify((x, y, z), expr, 'jax') + input_tuple1, input_tuple2 = (1, 2, 3), (4, 5, 6) + input_array1, input_array2 = jax.numpy.asarray(input_tuple1), jax.numpy.asarray(input_tuple2) + assert np.allclose(func(*input_tuple1), func(*input_array1)) + assert np.allclose(func(*input_tuple2), func(*input_array2)) + +def test_numpy_array(): + p = NumPyPrinter() + assert p.doprint(Array([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])' + assert p.doprint(Array([1, 2])) == 'numpy.array([1, 2])' + assert p.doprint(Array([[[1, 2, 3]]])) == 'numpy.array([[[1, 2, 3]]])' + assert p.doprint(Array([], (0,))) == 'numpy.zeros((0,))' + assert p.doprint(Array([], (0, 0))) == 'numpy.zeros((0, 0))' + assert p.doprint(Array([], (0, 1))) == 'numpy.zeros((0, 1))' + assert p.doprint(Array([], (1, 0))) == 'numpy.zeros((1, 0))' + assert p.doprint(Array([1], ())) == 'numpy.array(1)' + +def test_numpy_matrix(): + p = NumPyPrinter() + assert p.doprint(Matrix([[1, 2], [3, 5]])) == 'numpy.array([[1, 2], [3, 5]])' + assert p.doprint(Matrix([1, 2])) == 'numpy.array([[1], [2]])' + assert p.doprint(Matrix(0, 0, [])) == 'numpy.zeros((0, 0))' + assert p.doprint(Matrix(0, 1, [])) == 'numpy.zeros((0, 1))' + assert p.doprint(Matrix(1, 0, [])) == 'numpy.zeros((1, 0))' + +def test_numpy_known_funcs_consts(): + assert _numpy_known_constants['NaN'] == 'numpy.nan' + assert _numpy_known_constants['EulerGamma'] == 'numpy.euler_gamma' + + assert _numpy_known_functions['acos'] == 'numpy.arccos' + assert _numpy_known_functions['log'] == 'numpy.log' + +def test_scipy_known_funcs_consts(): + assert _scipy_known_constants['GoldenRatio'] == 'scipy.constants.golden_ratio' + assert _scipy_known_constants['Pi'] == 'scipy.constants.pi' + + assert _scipy_known_functions['erf'] == 'scipy.special.erf' + assert _scipy_known_functions['factorial'] == 'scipy.special.factorial' + +def test_numpy_print_methods(): + prntr = NumPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + +def test_scipy_print_methods(): + prntr = SciPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + assert hasattr(prntr, '_print_erf') + assert hasattr(prntr, '_print_factorial') + assert hasattr(prntr, '_print_chebyshevt') + k = Symbol('k', integer=True, nonnegative=True) + x = Symbol('x', real=True) + assert prntr.doprint(polygamma(k, x)) == "scipy.special.polygamma(k, x)" + assert prntr.doprint(Si(x)) == "scipy.special.sici(x)[0]" + assert prntr.doprint(Ci(x)) == "scipy.special.sici(x)[1]" diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_octave.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_octave.py new file mode 100644 index 0000000000000000000000000000000000000000..1aba318f873c48ec702f1b4e3a6cc047f75d647d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_octave.py @@ -0,0 +1,515 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, EulerGamma, GoldenRatio, Catalan, + Lambda, Mul, Pow, Mod, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.functions import (arg, atan2, bernoulli, beta, ceiling, chebyshevu, + chebyshevt, conjugate, DiracDelta, exp, expint, + factorial, floor, harmonic, Heaviside, im, + laguerre, LambertW, log, Max, Min, Piecewise, + polylog, re, RisingFactorial, sign, sinc, sqrt, + zeta, binomial, legendre, dirichlet_eta, + riemann_xi) +from sympy.functions import (sin, cos, tan, cot, sec, csc, asin, acos, acot, + atan, asec, acsc, sinh, cosh, tanh, coth, csch, + sech, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.testing.pytest import raises, XFAIL +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix, HadamardPower) +from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli, + besselk, hankel1, hankel2, airyai, + airybi, airyaiprime, airybiprime) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, + uppergamma, loggamma, + polygamma) +from sympy.functions.special.error_functions import (Chi, Ci, erf, erfc, erfi, + erfcinv, erfinv, fresnelc, + fresnels, li, Shi, Si, Li, + erf2, Ei) +from sympy.printing.octave import octave_code, octave_code as mcode + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert mcode(Integer(67)) == "67" + assert mcode(Integer(-1)) == "-1" + + +def test_Rational(): + assert mcode(Rational(3, 7)) == "3/7" + assert mcode(Rational(18, 9)) == "2" + assert mcode(Rational(3, -7)) == "-3/7" + assert mcode(Rational(-3, -7)) == "3/7" + assert mcode(x + Rational(3, 7)) == "x + 3/7" + assert mcode(Rational(3, 7)*x) == "3*x/7" + + +def test_Relational(): + assert mcode(Eq(x, y)) == "x == y" + assert mcode(Ne(x, y)) == "x != y" + assert mcode(Le(x, y)) == "x <= y" + assert mcode(Lt(x, y)) == "x < y" + assert mcode(Gt(x, y)) == "x > y" + assert mcode(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert mcode(sin(x) ** cos(x)) == "sin(x).^cos(x)" + assert mcode(sign(x)) == "sign(x)" + assert mcode(exp(x)) == "exp(x)" + assert mcode(log(x)) == "log(x)" + assert mcode(factorial(x)) == "factorial(x)" + assert mcode(floor(x)) == "floor(x)" + assert mcode(atan2(y, x)) == "atan2(y, x)" + assert mcode(beta(x, y)) == 'beta(x, y)' + assert mcode(polylog(x, y)) == 'polylog(x, y)' + assert mcode(harmonic(x)) == 'harmonic(x)' + assert mcode(bernoulli(x)) == "bernoulli(x)" + assert mcode(bernoulli(x, y)) == "bernoulli(x, y)" + assert mcode(legendre(x, y)) == "legendre(x, y)" + + +def test_Function_change_name(): + assert mcode(abs(x)) == "abs(x)" + assert mcode(ceiling(x)) == "ceil(x)" + assert mcode(arg(x)) == "angle(x)" + assert mcode(im(x)) == "imag(x)" + assert mcode(re(x)) == "real(x)" + assert mcode(conjugate(x)) == "conj(x)" + assert mcode(chebyshevt(y, x)) == "chebyshevT(y, x)" + assert mcode(chebyshevu(y, x)) == "chebyshevU(y, x)" + assert mcode(laguerre(x, y)) == "laguerreL(x, y)" + assert mcode(Chi(x)) == "coshint(x)" + assert mcode(Shi(x)) == "sinhint(x)" + assert mcode(Ci(x)) == "cosint(x)" + assert mcode(Si(x)) == "sinint(x)" + assert mcode(li(x)) == "logint(x)" + assert mcode(loggamma(x)) == "gammaln(x)" + assert mcode(polygamma(x, y)) == "psi(x, y)" + assert mcode(RisingFactorial(x, y)) == "pochhammer(x, y)" + assert mcode(DiracDelta(x)) == "dirac(x)" + assert mcode(DiracDelta(x, 3)) == "dirac(3, x)" + assert mcode(Heaviside(x)) == "heaviside(x, 1/2)" + assert mcode(Heaviside(x, y)) == "heaviside(x, y)" + assert mcode(binomial(x, y)) == "bincoeff(x, y)" + assert mcode(Mod(x, y)) == "mod(x, y)" + + +def test_minmax(): + assert mcode(Max(x, y) + Min(x, y)) == "max(x, y) + min(x, y)" + assert mcode(Max(x, y, z)) == "max(x, max(y, z))" + assert mcode(Min(x, y, z)) == "min(x, min(y, z))" + + +def test_Pow(): + assert mcode(x**3) == "x.^3" + assert mcode(x**(y**3)) == "x.^(y.^3)" + assert mcode(x**Rational(2, 3)) == 'x.^(2/3)' + g = implemented_function('g', Lambda(x, 2*x)) + assert mcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2*x).^(-x + y.^x)./(x.^2 + y)" + # For issue 14160 + assert mcode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x./(y.*y)' + + +def test_basic_ops(): + assert mcode(x*y) == "x.*y" + assert mcode(x + y) == "x + y" + assert mcode(x - y) == "x - y" + assert mcode(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert mcode(1/x) == '1./x' + assert mcode(x**-1) == mcode(x**-1.0) == '1./x' + assert mcode(1/sqrt(x)) == '1./sqrt(x)' + assert mcode(x**-S.Half) == mcode(x**-0.5) == '1./sqrt(x)' + assert mcode(sqrt(x)) == 'sqrt(x)' + assert mcode(x**S.Half) == mcode(x**0.5) == 'sqrt(x)' + assert mcode(1/pi) == '1/pi' + assert mcode(pi**-1) == mcode(pi**-1.0) == '1/pi' + assert mcode(pi**-0.5) == '1/sqrt(pi)' + + +def test_mix_number_mult_symbols(): + assert mcode(3*x) == "3*x" + assert mcode(pi*x) == "pi*x" + assert mcode(3/x) == "3./x" + assert mcode(pi/x) == "pi./x" + assert mcode(x/3) == "x/3" + assert mcode(x/pi) == "x/pi" + assert mcode(x*y) == "x.*y" + assert mcode(3*x*y) == "3*x.*y" + assert mcode(3*pi*x*y) == "3*pi*x.*y" + assert mcode(x/y) == "x./y" + assert mcode(3*x/y) == "3*x./y" + assert mcode(x*y/z) == "x.*y./z" + assert mcode(x/y*z) == "x.*z./y" + assert mcode(1/x/y) == "1./(x.*y)" + assert mcode(2*pi*x/y/z) == "2*pi*x./(y.*z)" + assert mcode(3*pi/x) == "3*pi./x" + assert mcode(S(3)/5) == "3/5" + assert mcode(S(3)/5*x) == "3*x/5" + assert mcode(x/y/z) == "x./(y.*z)" + assert mcode((x+y)/z) == "(x + y)./z" + assert mcode((x+y)/(z+x)) == "(x + y)./(x + z)" + assert mcode((x+y)/EulerGamma) == "(x + y)/%s" % EulerGamma.evalf(17) + assert mcode(x/3/pi) == "x/(3*pi)" + assert mcode(S(3)/5*x*y/pi) == "3*x.*y/(5*pi)" + + +def test_mix_number_pow_symbols(): + assert mcode(pi**3) == 'pi^3' + assert mcode(x**2) == 'x.^2' + assert mcode(x**(pi**3)) == 'x.^(pi^3)' + assert mcode(x**y) == 'x.^y' + assert mcode(x**(y**z)) == 'x.^(y.^z)' + assert mcode((x**y)**z) == '(x.^y).^z' + + +def test_imag(): + I = S('I') + assert mcode(I) == "1i" + assert mcode(5*I) == "5i" + assert mcode((S(3)/2)*I) == "3*1i/2" + assert mcode(3+4*I) == "3 + 4i" + assert mcode(sqrt(3)*I) == "sqrt(3)*1i" + + +def test_constants(): + assert mcode(pi) == "pi" + assert mcode(oo) == "inf" + assert mcode(-oo) == "-inf" + assert mcode(S.NegativeInfinity) == "-inf" + assert mcode(S.NaN) == "NaN" + assert mcode(S.Exp1) == "exp(1)" + assert mcode(exp(1)) == "exp(1)" + + +def test_constants_other(): + assert mcode(2*GoldenRatio) == "2*(1+sqrt(5))/2" + assert mcode(2*Catalan) == "2*%s" % Catalan.evalf(17) + assert mcode(2*EulerGamma) == "2*%s" % EulerGamma.evalf(17) + + +def test_boolean(): + assert mcode(x & y) == "x & y" + assert mcode(x | y) == "x | y" + assert mcode(~x) == "~x" + assert mcode(x & y & z) == "x & y & z" + assert mcode(x | y | z) == "x | y | z" + assert mcode((x & y) | z) == "z | x & y" + assert mcode((x | y) & z) == "z & (x | y)" + + +def test_KroneckerDelta(): + from sympy.functions import KroneckerDelta + assert mcode(KroneckerDelta(x, y)) == "double(x == y)" + assert mcode(KroneckerDelta(x, y + 1)) == "double(x == (y + 1))" + assert mcode(KroneckerDelta(2**x, y)) == "double((2.^x) == y)" + + +def test_Matrices(): + assert mcode(Matrix(1, 1, [10])) == "10" + A = Matrix([[1, sin(x/2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = "[1 sin(x/2) abs(x); 0 1 pi; 0 exp(1) ceil(x)]" + assert mcode(A) == expected + # row and columns + assert mcode(A[:,0]) == "[1; 0; 0]" + assert mcode(A[0,:]) == "[1 sin(x/2) abs(x)]" + # empty matrices + assert mcode(Matrix(0, 0, [])) == '[]' + assert mcode(Matrix(0, 3, [])) == 'zeros(0, 3)' + # annoying to read but correct + assert mcode(Matrix([[x, x - y, -y]])) == "[x x - y -y]" + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2/x), 3*pi/x/5]]) + assert mcode(A) == "[1 sin(2./x) 3*pi./(5*x)]" + assert mcode(A.T) == "[1; sin(2./x); 3*pi./(5*x)]" + + +@XFAIL +def test_Matrices_entries_not_hadamard(): + # For Matrix with col >= 2, row >= 2, they need to be scalars + # FIXME: is it worth worrying about this? Its not wrong, just + # leave it user's responsibility to put scalar data for x. + A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]]) + expected = ("[1 sin(2/x) 3*pi/(5*x);\n" + "1 2 x*y]") # <- we give x.*y + assert mcode(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert mcode(A*B) == "A*B" + assert mcode(B*A) == "B*A" + assert mcode(2*A*B) == "2*A*B" + assert mcode(B*2*A) == "2*B*A" + assert mcode(A*(B + 3*Identity(n))) == "A*(3*eye(n) + B)" + assert mcode(A**(x**2)) == "A^(x.^2)" + assert mcode(A**3) == "A^3" + assert mcode(A**S.Half) == "A^(1/2)" + + +def test_MatrixSolve(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + x = MatrixSymbol('x', n, 1) + assert mcode(MatrixSolve(A, x)) == "A \\ x" + +def test_special_matrices(): + assert mcode(6*Identity(3)) == "6*eye(3)" + + +def test_containers(): + assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}" + assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}" + assert mcode([1]) == "{1}" + assert mcode((1,)) == "{1}" + assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}" + assert mcode((1, x*y, (3, x**2))) == "{1, x.*y, {3, x.^2}}" + # scalar, matrix, empty matrix and empty list + assert mcode((1, eye(3), Matrix(0, 0, []), [])) == "{1, [1 0 0; 0 1 0; 0 0 1], [], {}}" + + +def test_octave_noninline(): + source = mcode((x+y)/Catalan, assign_to='me', inline=False) + expected = ( + "Catalan = %s;\n" + "me = (x + y)/Catalan;" + ) % Catalan.evalf(17) + assert source == expected + + +def test_octave_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert mcode(expr) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))" + assert mcode(expr, assign_to="r") == ( + "r = ((x < 1).*(x) + (~(x < 1)).*(x.^2));") + assert mcode(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x;\n" + "else\n" + " r = x.^2;\n" + "end") + expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True)) + expected = ("((x < 1).*(x.^2) + (~(x < 1)).*( ...\n" + "(x < 2).*(x.^3) + (~(x < 2)).*( ...\n" + "(x < 3).*(x.^4) + (~(x < 3)).*(x.^5))))") + assert mcode(expr) == expected + assert mcode(expr, assign_to="r") == "r = " + expected + ";" + assert mcode(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x.^2;\n" + "elseif (x < 2)\n" + " r = x.^3;\n" + "elseif (x < 3)\n" + " r = x.^4;\n" + "else\n" + " r = x.^5;\n" + "end") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: mcode(expr)) + + +def test_octave_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x**2, True)) + assert mcode(2*pw) == "2*((x < 1).*(x) + (~(x < 1)).*(x.^2))" + assert mcode(pw/x) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./x" + assert mcode(pw/(x*y)) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./(x.*y)" + assert mcode(pw/3) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))/3" + + +def test_octave_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert mcode(A, assign_to='a') == "a = [1 2 3];" + A = Matrix([[1, 2], [3, 4]]) + assert mcode(A, assign_to='A') == "A = [1 2; 3 4];" + + +def test_octave_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert mcode(A, assign_to=B) == "B = [1 2 3];" + raises(ValueError, lambda: mcode(A, assign_to=x)) + raises(ValueError, lambda: mcode(A, assign_to=C)) + + +def test_octave_matrix_1x1(): + A = Matrix([[3]]) + B = MatrixSymbol('B', 1, 1) + C = MatrixSymbol('C', 1, 2) + assert mcode(A, assign_to=B) == "B = 3;" + # FIXME? + #assert mcode(A, assign_to=x) == "x = 3;" + raises(ValueError, lambda: mcode(A, assign_to=C)) + + +def test_octave_matrix_elements(): + A = Matrix([[x, 2, x*y]]) + assert mcode(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x.^2 + x.*y + 2" + A = MatrixSymbol('AA', 1, 3) + assert mcode(A) == "AA" + assert mcode(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \ + "sin(AA(1, 2)) + AA(1, 1).^2 + AA(1, 3)" + assert mcode(sum(A)) == "AA(1, 1) + AA(1, 2) + AA(1, 3)" + + +def test_octave_boolean(): + assert mcode(True) == "true" + assert mcode(S.true) == "true" + assert mcode(False) == "false" + assert mcode(S.false) == "false" + + +def test_octave_not_supported(): + with raises(NotImplementedError): + mcode(S.ComplexInfinity) + f = Function('f') + assert mcode(f(x).diff(x), strict=False) == ( + "% Not supported in Octave:\n" + "% Derivative\n" + "Derivative(f(x), x)" + ) + + +def test_octave_not_supported_not_on_whitelist(): + from sympy.functions.special.polynomials import assoc_laguerre + with raises(NotImplementedError): + mcode(assoc_laguerre(x, y, z)) + + +def test_octave_expint(): + assert mcode(expint(1, x)) == "expint(x)" + with raises(NotImplementedError): + mcode(expint(2, x)) + assert mcode(expint(y, x), strict=False) == ( + "% Not supported in Octave:\n" + "% expint\n" + "expint(y, x)" + ) + + +def test_trick_indent_with_end_else_words(): + # words starting with "end" or "else" do not confuse the indenter + t1 = S('endless') + t2 = S('elsewhere') + pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True)) + assert mcode(pw, inline=False) == ( + "if (x < 0)\n" + " endless\n" + "elseif (x <= 1)\n" + " elsewhere\n" + "else\n" + " 1\n" + "end") + + +def test_hadamard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + n = Symbol('n') + assert mcode(C) == "A.*B" + assert mcode(C*v) == "(A.*B)*v" + assert mcode(h*C*v) == "h*(A.*B)*v" + assert mcode(C*A) == "(A.*B)*A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + assert mcode(C*x*y) == "(x.*y)*(A.*B)" + + # Testing HadamardPower: + assert mcode(HadamardPower(A, n)) == "A.**n" + assert mcode(HadamardPower(A, 1+n)) == "A.**(n + 1)" + assert mcode(HadamardPower(A*B.T, 1+n)) == "(A*B.T).**(n + 1)" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x*y + assert mcode(M) == ( + "sparse([4 2 3 1 2], [1 3 3 4 4], [x.*y 20 10 30 22], 5, 6)" + ) + + +def test_sinc(): + assert mcode(sinc(x)) == 'sinc(x/pi)' + assert mcode(sinc(x + 3)) == 'sinc((x + 3)/pi)' + assert mcode(sinc(pi*(x + 3))) == 'sinc(x + 3)' + + +def test_trigfun(): + for f in (sin, cos, tan, cot, sec, csc, asin, acos, acot, atan, asec, acsc, + sinh, cosh, tanh, coth, csch, sech, asinh, acosh, atanh, acoth, + asech, acsch): + assert octave_code(f(x) == f.__name__ + '(x)') + + +def test_specfun(): + n = Symbol('n') + for f in [besselj, bessely, besseli, besselk]: + assert octave_code(f(n, x)) == f.__name__ + '(n, x)' + for f in (erfc, erfi, erf, erfinv, erfcinv, fresnelc, fresnels, gamma): + assert octave_code(f(x)) == f.__name__ + '(x)' + assert octave_code(hankel1(n, x)) == 'besselh(n, 1, x)' + assert octave_code(hankel2(n, x)) == 'besselh(n, 2, x)' + assert octave_code(airyai(x)) == 'airy(0, x)' + assert octave_code(airyaiprime(x)) == 'airy(1, x)' + assert octave_code(airybi(x)) == 'airy(2, x)' + assert octave_code(airybiprime(x)) == 'airy(3, x)' + assert octave_code(uppergamma(n, x)) == '(gammainc(x, n, \'upper\').*gamma(n))' + assert octave_code(lowergamma(n, x)) == '(gammainc(x, n).*gamma(n))' + assert octave_code(z**lowergamma(n, x)) == 'z.^(gammainc(x, n).*gamma(n))' + assert octave_code(jn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*besselj(n + 1/2, x)/2' + assert octave_code(yn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*bessely(n + 1/2, x)/2' + assert octave_code(LambertW(x)) == 'lambertw(x)' + assert octave_code(LambertW(x, n)) == 'lambertw(n, x)' + + # Automatic rewrite + assert octave_code(Ei(x)) == '(logint(exp(x)))' + assert octave_code(dirichlet_eta(x)) == '(((x == 1).*(log(2)) + (~(x == 1)).*((1 - 2.^(1 - x)).*zeta(x))))' + assert octave_code(riemann_xi(x)) == '(pi.^(-x/2).*x.*(x - 1).*gamma(x/2).*zeta(x)/2)' + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert mcode(A[0, 0]) == "A(1, 1)" + assert mcode(3 * A[0, 0]) == "3*A(1, 1)" + + F = C[0, 0].subs(C, A - B) + assert mcode(F) == "(A - B)(1, 1)" + + +def test_zeta_printing_issue_14820(): + assert octave_code(zeta(x)) == 'zeta(x)' + with raises(NotImplementedError): + octave_code(zeta(x, y)) + + +def test_automatic_rewrite(): + assert octave_code(Li(x)) == '(logint(x) - logint(2))' + assert octave_code(erf2(x, y)) == '(-erf(x) + erf(y))' diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_precedence.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_precedence.py new file mode 100644 index 0000000000000000000000000000000000000000..d08ea07483857e8c2ee7f930aa53d2dacdc58193 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_precedence.py @@ -0,0 +1,128 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import Derivative, Function +from sympy.core.numbers import Integer, Rational, Float, oo +from sympy.core.relational import Rel +from sympy.core.symbol import symbols +from sympy.functions import sin +from sympy.integrals.integrals import Integral +from sympy.series.order import Order + +from sympy.printing.precedence import precedence, PRECEDENCE + +x, y = symbols("x,y") + + +def test_Add(): + assert precedence(x + y) == PRECEDENCE["Add"] + assert precedence(x*y + 1) == PRECEDENCE["Add"] + + +def test_Function(): + assert precedence(sin(x)) == PRECEDENCE["Func"] + +def test_Derivative(): + assert precedence(Derivative(x, y)) == PRECEDENCE["Atom"] + +def test_Integral(): + assert precedence(Integral(x, y)) == PRECEDENCE["Atom"] + + +def test_Mul(): + assert precedence(x*y) == PRECEDENCE["Mul"] + assert precedence(-x*y) == PRECEDENCE["Add"] + + +def test_Number(): + assert precedence(Integer(0)) == PRECEDENCE["Atom"] + assert precedence(Integer(1)) == PRECEDENCE["Atom"] + assert precedence(Integer(-1)) == PRECEDENCE["Add"] + assert precedence(Integer(10)) == PRECEDENCE["Atom"] + assert precedence(Rational(5, 2)) == PRECEDENCE["Mul"] + assert precedence(Rational(-5, 2)) == PRECEDENCE["Add"] + assert precedence(Float(5)) == PRECEDENCE["Atom"] + assert precedence(Float(-5)) == PRECEDENCE["Add"] + assert precedence(oo) == PRECEDENCE["Atom"] + assert precedence(-oo) == PRECEDENCE["Add"] + + +def test_Order(): + assert precedence(Order(x)) == PRECEDENCE["Atom"] + + +def test_Pow(): + assert precedence(x**y) == PRECEDENCE["Pow"] + assert precedence(-x**y) == PRECEDENCE["Add"] + assert precedence(x**-y) == PRECEDENCE["Pow"] + + +def test_Product(): + assert precedence(Product(x, (x, y, y + 1))) == PRECEDENCE["Atom"] + + +def test_Relational(): + assert precedence(Rel(x + y, y, "<")) == PRECEDENCE["Relational"] + + +def test_Sum(): + assert precedence(Sum(x, (x, y, y + 1))) == PRECEDENCE["Atom"] + + +def test_Symbol(): + assert precedence(x) == PRECEDENCE["Atom"] + + +def test_And_Or(): + # precedence relations between logical operators, ... + assert precedence(x & y) > precedence(x | y) + assert precedence(~y) > precedence(x & y) + # ... and with other operators (cfr. other programming languages) + assert precedence(x + y) > precedence(x | y) + assert precedence(x + y) > precedence(x & y) + assert precedence(x*y) > precedence(x | y) + assert precedence(x*y) > precedence(x & y) + assert precedence(~y) > precedence(x*y) + assert precedence(~y) > precedence(x - y) + # double checks + assert precedence(x & y) == PRECEDENCE["And"] + assert precedence(x | y) == PRECEDENCE["Or"] + assert precedence(~y) == PRECEDENCE["Not"] + + +def test_custom_function_precedence_comparison(): + """ + Test cases for custom functions with different precedence values, + specifically handling: + 1. Functions with precedence < PRECEDENCE["Mul"] (50) + 2. Functions with precedence = Func (70) + + Key distinction: + 1. Lower precedence functions (45) need parentheses: -2*(x F y) + 2. Higher precedence functions (70) don't: -2*x F y + """ + class LowPrecedenceF(Function): + precedence = PRECEDENCE["Mul"] - 5 + def _sympystr(self, printer): + return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}" + + class HighPrecedenceF(Function): + precedence = PRECEDENCE["Func"] + def _sympystr(self, printer): + return f"{printer._print(self.args[0])} F {printer._print(self.args[1])}" + + def test_low_precedence(): + expr1 = 2 * LowPrecedenceF(x, y) + assert str(expr1) == "2*(x F y)" + + expr2 = -2 * LowPrecedenceF(x, y) + assert str(expr2) == "-2*(x F y)" + + def test_high_precedence(): + expr1 = 2 * HighPrecedenceF(x, y) + assert str(expr1) == "2*x F y" + + expr2 = -2 * HighPrecedenceF(x, y) + assert str(expr2) == "-2*x F y" + + test_low_precedence() + test_high_precedence() diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_preview.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_preview.py new file mode 100644 index 0000000000000000000000000000000000000000..91771ceb0466d6b0fee00570426713d02da14872 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_preview.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol +from sympy.functions.elementary.piecewise import Piecewise +from sympy.printing.preview import preview + +from io import BytesIO + + +def test_preview(): + x = Symbol('x') + obj = BytesIO() + try: + preview(x, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server + + +def test_preview_unicode_symbol(): + # issue 9107 + a = Symbol('α') + obj = BytesIO() + try: + preview(a, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server + + +def test_preview_latex_construct_in_expr(): + # see PR 9801 + x = Symbol('x') + pw = Piecewise((1, Eq(x, 0)), (0, True)) + obj = BytesIO() + try: + preview(pw, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_pycode.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_pycode.py new file mode 100644 index 0000000000000000000000000000000000000000..2c38fe81d830149cdce6b55f15e6e07513fdd146 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_pycode.py @@ -0,0 +1,493 @@ +from sympy import Not +from sympy.codegen import Assignment +from sympy.codegen.ast import none +from sympy.codegen.cfunctions import expm1, log1p +from sympy.codegen.scipy_nodes import cosm1 +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow +from sympy.core.function import Derivative +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth, sec, log, sin, cos, tan, asin, atan, sinh, cosh, tanh, asinh, acosh, atanh +from sympy.functions.elementary.trigonometric import atan2 +from sympy.logic import And, Or +from sympy.matrices import SparseMatrix, MatrixSymbol, Identity +from sympy.printing.codeprinter import PrintMethodNotImplementedError +from sympy.printing.pycode import ( + MpmathPrinter, CmathPrinter, PythonCodePrinter, pycode, SymPyPrinter +) +from sympy.printing.tensorflow import TensorflowPrinter +from sympy.printing.numpy import NumPyPrinter, SciPyPrinter +from sympy.testing.pytest import raises, skip +from sympy.tensor import IndexedBase, Idx +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray +from sympy.external import import_module +from sympy.functions.special.gamma_functions import loggamma + + + +x, y, z = symbols('x y z') +p = IndexedBase("p") + + +def test_PythonCodePrinter(): + prntr = PythonCodePrinter() + + assert not prntr.module_imports + + assert prntr.doprint(x**y) == 'x**y' + assert prntr.doprint(Mod(x, 2)) == 'x % 2' + assert prntr.doprint(-Mod(x, y)) == '-(x % y)' + assert prntr.doprint(Mod(-x, y)) == '(-x) % y' + assert prntr.doprint(And(x, y)) == 'x and y' + assert prntr.doprint(Or(x, y)) == 'x or y' + assert prntr.doprint(1/(x+y)) == '1/(x + y)' + assert prntr.doprint(Not(x)) == 'not x' + assert not prntr.module_imports + + assert prntr.doprint(pi) == 'math.pi' + assert prntr.module_imports == {'math': {'pi'}} + + assert prntr.doprint(x**Rational(1, 2)) == 'math.sqrt(x)' + assert prntr.doprint(sqrt(x)) == 'math.sqrt(x)' + assert prntr.module_imports == {'math': {'pi', 'sqrt'}} + + assert prntr.doprint(acos(x)) == 'math.acos(x)' + assert prntr.doprint(cot(x)) == '(1/math.tan(x))' + assert prntr.doprint(coth(x)) == '((math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x)))' + assert prntr.doprint(asec(x)) == '(math.acos(1/x))' + assert prntr.doprint(acsch(x)) == '(math.log(math.sqrt(1 + x**(-2)) + 1/x))' + + assert prntr.doprint(Assignment(x, 2)) == 'x = 2' + assert prntr.doprint(Piecewise((1, Eq(x, 0)), + (2, x>6))) == '((1) if (x == 0) else (2) if (x > 6) else None)' + assert prntr.doprint(Piecewise((2, Le(x, 0)), + (3, Gt(x, 0)), evaluate=False)) == '((2) if (x <= 0) else'\ + ' (3) if (x > 0) else None)' + assert prntr.doprint(sign(x)) == '(0.0 if x == 0 else math.copysign(1, x))' + assert prntr.doprint(p[0, 1]) == 'p[0, 1]' + assert prntr.doprint(KroneckerDelta(x,y)) == '(1 if x == y else 0)' + + assert prntr.doprint((2,3)) == "(2, 3)" + assert prntr.doprint([2,3]) == "[2, 3]" + + assert prntr.doprint(Min(x, y)) == "min(x, y)" + assert prntr.doprint(Max(x, y)) == "max(x, y)" + + +def test_PythonCodePrinter_standard(): + prntr = PythonCodePrinter() + + assert prntr.standard == 'python3' + + raises(ValueError, lambda: PythonCodePrinter({'standard':'python4'})) + + +def test_CmathPrinter(): + p = CmathPrinter() + + assert p.doprint(sqrt(x)) == 'cmath.sqrt(x)' + assert p.doprint(log(x)) == 'cmath.log(x)' + + assert p.doprint(sin(x)) == 'cmath.sin(x)' + assert p.doprint(cos(x)) == 'cmath.cos(x)' + assert p.doprint(tan(x)) == 'cmath.tan(x)' + + assert p.doprint(asin(x)) == 'cmath.asin(x)' + assert p.doprint(acos(x)) == 'cmath.acos(x)' + assert p.doprint(atan(x)) == 'cmath.atan(x)' + + assert p.doprint(sinh(x)) == 'cmath.sinh(x)' + assert p.doprint(cosh(x)) == 'cmath.cosh(x)' + assert p.doprint(tanh(x)) == 'cmath.tanh(x)' + + assert p.doprint(asinh(x)) == 'cmath.asinh(x)' + assert p.doprint(acosh(x)) == 'cmath.acosh(x)' + assert p.doprint(atanh(x)) == 'cmath.atanh(x)' + + +def test_MpmathPrinter(): + p = MpmathPrinter() + assert p.doprint(sign(x)) == 'mpmath.sign(x)' + assert p.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)' + + assert p.doprint(S.Exp1) == 'mpmath.e' + assert p.doprint(S.Pi) == 'mpmath.pi' + assert p.doprint(S.GoldenRatio) == 'mpmath.phi' + assert p.doprint(S.EulerGamma) == 'mpmath.euler' + assert p.doprint(S.NaN) == 'mpmath.nan' + assert p.doprint(S.Infinity) == 'mpmath.inf' + assert p.doprint(S.NegativeInfinity) == 'mpmath.ninf' + assert p.doprint(loggamma(x)) == 'mpmath.loggamma(x)' + + +def test_NumPyPrinter(): + from sympy.core.function import Lambda + from sympy.matrices.expressions.adjoint import Adjoint + from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf) + from sympy.matrices.expressions.funcmatrix import FunctionMatrix + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.kronecker import KroneckerProduct + from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix) + from sympy.abc import a, b + p = NumPyPrinter() + assert p.doprint(sign(x)) == 'numpy.sign(x)' + A = MatrixSymbol("A", 2, 2) + B = MatrixSymbol("B", 2, 2) + C = MatrixSymbol("C", 1, 5) + D = MatrixSymbol("D", 3, 4) + assert p.doprint(A**(-1)) == "numpy.linalg.inv(A)" + assert p.doprint(A**5) == "numpy.linalg.matrix_power(A, 5)" + assert p.doprint(Identity(3)) == "numpy.eye(3)" + + u = MatrixSymbol('x', 2, 1) + v = MatrixSymbol('y', 2, 1) + assert p.doprint(MatrixSolve(A, u)) == 'numpy.linalg.solve(A, x)' + assert p.doprint(MatrixSolve(A, u) + v) == 'numpy.linalg.solve(A, x) + y' + + assert p.doprint(ZeroMatrix(2, 3)) == "numpy.zeros((2, 3))" + assert p.doprint(OneMatrix(2, 3)) == "numpy.ones((2, 3))" + assert p.doprint(FunctionMatrix(4, 5, Lambda((a, b), a + b))) == \ + "numpy.fromfunction(lambda a, b: a + b, (4, 5))" + assert p.doprint(HadamardProduct(A, B)) == "numpy.multiply(A, B)" + assert p.doprint(KroneckerProduct(A, B)) == "numpy.kron(A, B)" + assert p.doprint(Adjoint(A)) == "numpy.conjugate(numpy.transpose(A))" + assert p.doprint(DiagonalOf(A)) == "numpy.reshape(numpy.diag(A), (-1, 1))" + assert p.doprint(DiagMatrix(C)) == "numpy.diagflat(C)" + assert p.doprint(DiagonalMatrix(D)) == "numpy.multiply(D, numpy.eye(3, 4))" + + # Workaround for numpy negative integer power errors + assert p.doprint(x**-1) == 'x**(-1.0)' + assert p.doprint(x**-2) == 'x**(-2.0)' + + expr = Pow(2, -1, evaluate=False) + assert p.doprint(expr) == "2**(-1.0)" + + assert p.doprint(S.Exp1) == 'numpy.e' + assert p.doprint(S.Pi) == 'numpy.pi' + assert p.doprint(S.EulerGamma) == 'numpy.euler_gamma' + assert p.doprint(S.NaN) == 'numpy.nan' + assert p.doprint(S.Infinity) == 'numpy.inf' + assert p.doprint(S.NegativeInfinity) == '-numpy.inf' + + # Function rewriting operator precedence fix + assert p.doprint(sec(x)**2) == '(numpy.cos(x)**(-1.0))**2' + + +def test_issue_18770(): + numpy = import_module('numpy') + if not numpy: + skip("numpy not installed.") + + from sympy.functions.elementary.miscellaneous import (Max, Min) + from sympy.utilities.lambdify import lambdify + + expr1 = Min(0.1*x + 3, x + 1, 0.5*x + 1) + func = lambdify(x, expr1, "numpy") + assert (func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all() + assert func(4) == 3 + + expr1 = Max(x**2, x**3) + func = lambdify(x,expr1, "numpy") + assert (func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all() + assert func(4) == 64 + + +def test_SciPyPrinter(): + p = SciPyPrinter() + expr = acos(x) + assert 'numpy' not in p.module_imports + assert p.doprint(expr) == 'numpy.arccos(x)' + assert 'numpy' in p.module_imports + assert not any(m.startswith('scipy') for m in p.module_imports) + smat = SparseMatrix(2, 5, {(0, 1): 3}) + assert p.doprint(smat) == \ + 'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))' + assert 'scipy.sparse' in p.module_imports + + assert p.doprint(S.GoldenRatio) == 'scipy.constants.golden_ratio' + assert p.doprint(S.Pi) == 'scipy.constants.pi' + assert p.doprint(S.Exp1) == 'numpy.e' + + +def test_pycode_reserved_words(): + s1, s2 = symbols('if else') + raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True)) + py_str = pycode(s1 + s2) + assert py_str in ('else_ + if_', 'if_ + else_') + + +def test_issue_20762(): + # Make sure pycode removes curly braces from subscripted variables + a_b, b, a_11 = symbols('a_{b} b a_{11}') + expr = a_b*b + assert pycode(expr) == 'a_b*b' + expr = a_11*b + assert pycode(expr) == 'a_11*b' + + +def test_sqrt(): + prntr = PythonCodePrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'math.sqrt(x)' + assert prntr._print_Pow(1/sqrt(x), rational=False) == '1/math.sqrt(x)' + + prntr = PythonCodePrinter({'standard' : 'python3'}) + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + assert prntr._print_Pow(1/sqrt(x), rational=True) == 'x**(-1/2)' + + prntr = MpmathPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'mpmath.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == \ + "x**(mpmath.mpf(1)/mpmath.mpf(2))" + + prntr = NumPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + prntr = SciPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + prntr = SymPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'sympy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + +def test_frac(): + from sympy.functions.elementary.integers import frac + + expr = frac(x) + prntr = NumPyPrinter() + assert prntr.doprint(expr) == 'numpy.mod(x, 1)' + + prntr = SciPyPrinter() + assert prntr.doprint(expr) == 'numpy.mod(x, 1)' + + prntr = PythonCodePrinter() + assert prntr.doprint(expr) == 'x % 1' + + prntr = MpmathPrinter() + assert prntr.doprint(expr) == 'mpmath.frac(x)' + + prntr = SymPyPrinter() + assert prntr.doprint(expr) == 'sympy.functions.elementary.integers.frac(x)' + + +class CustomPrintedObject(Expr): + def _numpycode(self, printer): + return 'numpy' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + obj = CustomPrintedObject() + assert NumPyPrinter().doprint(obj) == 'numpy' + assert MpmathPrinter().doprint(obj) == 'mpmath' + + +def test_codegen_ast_nodes(): + assert pycode(none) == 'None' + + +def test_issue_14283(): + prntr = PythonCodePrinter() + + assert prntr.doprint(zoo) == "math.nan" + assert prntr.doprint(-oo) == "float('-inf')" + + +def test_NumPyPrinter_print_seq(): + n = NumPyPrinter() + + assert n._print_seq(range(2)) == '(0, 1,)' + + +def test_issue_16535_16536(): + from sympy.functions.special.gamma_functions import (lowergamma, uppergamma) + + a = symbols('a') + expr1 = lowergamma(a, x) + expr2 = uppergamma(a, x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)' + assert prntr.doprint(expr2) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)' + + p_numpy = NumPyPrinter() + p_pycode = PythonCodePrinter({'strict': False}) + + for expr in [expr1, expr2]: + with raises(NotImplementedError): + p_numpy.doprint(expr1) + assert "Not supported" in p_pycode.doprint(expr) + + +def test_Integral(): + from sympy.functions.elementary.exponential import exp + from sympy.integrals.integrals import Integral + + single = Integral(exp(-x), (x, 0, oo)) + double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z)) + indefinite = Integral(x**2, x) + evaluateat = Integral(x**2, (x, 1)) + + prntr = SciPyPrinter() + assert prntr.doprint(single) == 'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.inf)[0]' + assert prntr.doprint(double) == 'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]' + raises(NotImplementedError, lambda: prntr.doprint(indefinite)) + raises(NotImplementedError, lambda: prntr.doprint(evaluateat)) + + prntr = MpmathPrinter() + assert prntr.doprint(single) == 'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))' + assert prntr.doprint(double) == 'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))' + raises(NotImplementedError, lambda: prntr.doprint(indefinite)) + raises(NotImplementedError, lambda: prntr.doprint(evaluateat)) + + +def test_fresnel_integrals(): + from sympy.functions.special.error_functions import (fresnelc, fresnels) + + expr1 = fresnelc(x) + expr2 = fresnels(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.fresnel(x)[1]' + assert prntr.doprint(expr2) == 'scipy.special.fresnel(x)[0]' + + p_numpy = NumPyPrinter() + p_pycode = PythonCodePrinter() + p_mpmath = MpmathPrinter() + for expr in [expr1, expr2]: + with raises(NotImplementedError): + p_numpy.doprint(expr) + with raises(NotImplementedError): + p_pycode.doprint(expr) + + assert p_mpmath.doprint(expr1) == 'mpmath.fresnelc(x)' + assert p_mpmath.doprint(expr2) == 'mpmath.fresnels(x)' + + +def test_beta(): + from sympy.functions.special.beta_functions import beta + + expr = beta(x, y) + + prntr = SciPyPrinter() + assert prntr.doprint(expr) == 'scipy.special.beta(x, y)' + + prntr = NumPyPrinter() + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = PythonCodePrinter() + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = PythonCodePrinter({'allow_unknown_functions': True}) + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = MpmathPrinter() + assert prntr.doprint(expr) == 'mpmath.beta(x, y)' + +def test_airy(): + from sympy.functions.special.bessel import (airyai, airybi) + + expr1 = airyai(x) + expr2 = airybi(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.airy(x)[0]' + assert prntr.doprint(expr2) == 'scipy.special.airy(x)[2]' + + prntr = NumPyPrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + prntr = PythonCodePrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + +def test_airy_prime(): + from sympy.functions.special.bessel import (airyaiprime, airybiprime) + + expr1 = airyaiprime(x) + expr2 = airybiprime(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.airy(x)[1]' + assert prntr.doprint(expr2) == 'scipy.special.airy(x)[3]' + + prntr = NumPyPrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + prntr = PythonCodePrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + +def test_numerical_accuracy_functions(): + prntr = SciPyPrinter() + assert prntr.doprint(expm1(x)) == 'numpy.expm1(x)' + assert prntr.doprint(log1p(x)) == 'numpy.log1p(x)' + assert prntr.doprint(cosm1(x)) == 'scipy.special.cosm1(x)' + +def test_array_printer(): + A = ArraySymbol('A', (4,4,6,6,6)) + I = IndexedBase('I') + i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5)) + + prntr = NumPyPrinter() + assert prntr.doprint(ZeroArray(5)) == 'numpy.zeros((5,))' + assert prntr.doprint(OneArray(5)) == 'numpy.ones((5,))' + assert prntr.doprint(ArrayContraction(A, [2,3])) == 'numpy.einsum("abccd->abd", A)' + assert prntr.doprint(I) == 'I' + assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'numpy.einsum("abccc->abc", A)' + assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'numpy.einsum("aabbc->cab", A)' + assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'numpy.einsum("abcde->abe", A)' + assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I' + + prntr = TensorflowPrinter() + assert prntr.doprint(ZeroArray(5)) == 'tensorflow.zeros((5,))' + assert prntr.doprint(OneArray(5)) == 'tensorflow.ones((5,))' + assert prntr.doprint(ArrayContraction(A, [2,3])) == 'tensorflow.linalg.einsum("abccd->abd", A)' + assert prntr.doprint(I) == 'I' + assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'tensorflow.linalg.einsum("abccc->abc", A)' + assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'tensorflow.linalg.einsum("aabbc->cab", A)' + assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'tensorflow.linalg.einsum("abcde->abe", A)' + assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I' + + +def test_custom_Derivative_methods(): + class MyPrinter(SciPyPrinter): + def _print_Derivative_cosm1(self, args, seq_orders): + arg, = args + order, = seq_orders + return 'my_custom_cosm1(%s, deriv_order=%d)' % (self._print(arg), order) + + def _print_Derivative_atan2(self, args, seq_orders): + arg1, arg2 = args + ord1, ord2 = seq_orders + return 'my_custom_atan2(%s, %s, deriv1=%d, deriv2=%d)' % ( + self._print(arg1), self._print(arg2), ord1, ord2 + ) + + p = MyPrinter() + cosm1_1 = cosm1(x).diff(x, evaluate=False) + assert p.doprint(cosm1_1) == 'my_custom_cosm1(x, deriv_order=1)' + atan2_2_3 = atan2(x, y).diff(x, 2, y, 3, evaluate=False) + assert p.doprint(atan2_2_3) == 'my_custom_atan2(x, y, deriv1=2, deriv2=3)' + + try: + p.doprint(expm1(x).diff(x, evaluate=False)) + except PrintMethodNotImplementedError as e: + assert '_print_Derivative_expm1' in repr(e) + else: + assert False # should have thrown + + try: + p.doprint(Derivative(cosm1(x**2),x)) + except ValueError as e: + assert '_print_Derivative(' in repr(e) + else: + assert False # should have thrown diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_python.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_python.py new file mode 100644 index 0000000000000000000000000000000000000000..fb94a662be90934a672d08b3de44a22e2580d8b6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_python.py @@ -0,0 +1,203 @@ +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (Abs, conjugate) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from sympy.series.limits import limit + +from sympy.printing.python import python + +from sympy.testing.pytest import raises, XFAIL + +x, y = symbols('x,y') +th = Symbol('theta') +ph = Symbol('phi') + + +def test_python_basic(): + # Simple numbers/symbols + assert python(-Rational(1)/2) == "e = Rational(-1, 2)" + assert python(-Rational(13)/22) == "e = Rational(-13, 22)" + assert python(oo) == "e = oo" + + # Powers + assert python(x**2) == "x = Symbol(\'x\')\ne = x**2" + assert python(1/x) == "x = Symbol('x')\ne = 1/x" + assert python(y*x**-2) == "y = Symbol('y')\nx = Symbol('x')\ne = y/x**2" + assert python( + x**Rational(-5, 2)) == "x = Symbol('x')\ne = x**Rational(-5, 2)" + + # Sums of terms + assert python(x**2 + x + 1) in [ + "x = Symbol('x')\ne = 1 + x + x**2", + "x = Symbol('x')\ne = x + x**2 + 1", + "x = Symbol('x')\ne = x**2 + x + 1", ] + assert python(1 - x) in [ + "x = Symbol('x')\ne = 1 - x", + "x = Symbol('x')\ne = -x + 1"] + assert python(1 - 2*x) in [ + "x = Symbol('x')\ne = 1 - 2*x", + "x = Symbol('x')\ne = -2*x + 1"] + assert python(1 - Rational(3, 2)*y/x) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3/2*y/x", + "y = Symbol('y')\nx = Symbol('x')\ne = -3/2*y/x + 1", + "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3*y/(2*x)"] + + # Multiplication + assert python(x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = x/y" + assert python(-x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = -x/y" + assert python((x + 2)/y) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(2 + x)", + "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(x + 2)", + "x = Symbol('x')\ny = Symbol('y')\ne = 1/y*(2 + x)", + "x = Symbol('x')\ny = Symbol('y')\ne = (2 + x)/y", + "x = Symbol('x')\ny = Symbol('y')\ne = (x + 2)/y"] + assert python((1 + x)*y) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = y*(1 + x)", + "y = Symbol('y')\nx = Symbol('x')\ne = y*(x + 1)", ] + + # Check for proper placement of negative sign + assert python(-5*x/(x + 10)) == "x = Symbol('x')\ne = -5*x/(x + 10)" + assert python(1 - Rational(3, 2)*(x + 1)) in [ + "x = Symbol('x')\ne = Rational(-3, 2)*x + Rational(-1, 2)", + "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)", + "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)" + ] + + +def test_python_keyword_symbol_name_escaping(): + # Check for escaping of keywords + assert python( + 5*Symbol("lambda")) == "lambda_ = Symbol('lambda')\ne = 5*lambda_" + assert (python(5*Symbol("lambda") + 7*Symbol("lambda_")) == + "lambda__ = Symbol('lambda')\nlambda_ = Symbol('lambda_')\ne = 7*lambda_ + 5*lambda__") + assert (python(5*Symbol("for") + Function("for_")(8)) == + "for__ = Symbol('for')\nfor_ = Function('for_')\ne = 5*for__ + for_(8)") + + +def test_python_keyword_function_name_escaping(): + assert python( + 5*Function("for")(8)) == "for_ = Function('for')\ne = 5*for_(8)" + + +def test_python_relational(): + assert python(Eq(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = Eq(x, y)" + assert python(Ge(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x >= y" + assert python(Le(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x <= y" + assert python(Gt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x > y" + assert python(Lt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x < y" + assert python(Ne(x/(y + 1), y**2)) in [ + "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(1 + y), y**2)", + "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(y + 1), y**2)"] + + +def test_python_functions(): + # Simple + assert python(2*x + exp(x)) in "x = Symbol('x')\ne = 2*x + exp(x)" + assert python(sqrt(2)) == 'e = sqrt(2)' + assert python(2**Rational(1, 3)) == 'e = 2**Rational(1, 3)' + assert python(sqrt(2 + pi)) == 'e = sqrt(2 + pi)' + assert python((2 + pi)**Rational(1, 3)) == 'e = (2 + pi)**Rational(1, 3)' + assert python(2**Rational(1, 4)) == 'e = 2**Rational(1, 4)' + assert python(Abs(x)) == "x = Symbol('x')\ne = Abs(x)" + assert python( + Abs(x/(x**2 + 1))) in ["x = Symbol('x')\ne = Abs(x/(1 + x**2))", + "x = Symbol('x')\ne = Abs(x/(x**2 + 1))"] + + # Univariate/Multivariate functions + f = Function('f') + assert python(f(x)) == "x = Symbol('x')\nf = Function('f')\ne = f(x)" + assert python(f(x, y)) == "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x, y)" + assert python(f(x/(y + 1), y)) in [ + "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(1 + y), y)", + "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(y + 1), y)"] + + # Nesting of square roots + assert python(sqrt((sqrt(x + 1)) + 1)) in [ + "x = Symbol('x')\ne = sqrt(1 + sqrt(1 + x))", + "x = Symbol('x')\ne = sqrt(sqrt(x + 1) + 1)"] + + # Nesting of powers + assert python((((x + 1)**Rational(1, 3)) + 1)**Rational(1, 3)) in [ + "x = Symbol('x')\ne = (1 + (1 + x)**Rational(1, 3))**Rational(1, 3)", + "x = Symbol('x')\ne = ((x + 1)**Rational(1, 3) + 1)**Rational(1, 3)"] + + # Function powers + assert python(sin(x)**2) == "x = Symbol('x')\ne = sin(x)**2" + + +@XFAIL +def test_python_functions_conjugates(): + a, b = map(Symbol, 'ab') + assert python( conjugate(a + b*I) ) == '_ _\na - I*b' + assert python( conjugate(exp(a + b*I)) ) == ' _ _\n a - I*b\ne ' + + +def test_python_derivatives(): + # Simple + f_1 = Derivative(log(x), x, evaluate=False) + assert python(f_1) == "x = Symbol('x')\ne = Derivative(log(x), x)" + + f_2 = Derivative(log(x), x, evaluate=False) + x + assert python(f_2) == "x = Symbol('x')\ne = x + Derivative(log(x), x)" + + # Multiple symbols + f_3 = Derivative(log(x) + x**2, x, y, evaluate=False) + assert python(f_3) == \ + "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(x**2 + log(x), x, y)" + + f_4 = Derivative(2*x*y, y, x, evaluate=False) + x**2 + assert python(f_4) in [ + "x = Symbol('x')\ny = Symbol('y')\ne = x**2 + Derivative(2*x*y, y, x)", + "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(2*x*y, y, x) + x**2"] + + +def test_python_integrals(): + # Simple + f_1 = Integral(log(x), x) + assert python(f_1) == "x = Symbol('x')\ne = Integral(log(x), x)" + + f_2 = Integral(x**2, x) + assert python(f_2) == "x = Symbol('x')\ne = Integral(x**2, x)" + + # Double nesting of pow + f_3 = Integral(x**(2**x), x) + assert python(f_3) == "x = Symbol('x')\ne = Integral(x**(2**x), x)" + + # Definite integrals + f_4 = Integral(x**2, (x, 1, 2)) + assert python(f_4) == "x = Symbol('x')\ne = Integral(x**2, (x, 1, 2))" + + f_5 = Integral(x**2, (x, Rational(1, 2), 10)) + assert python( + f_5) == "x = Symbol('x')\ne = Integral(x**2, (x, Rational(1, 2), 10))" + + # Nested integrals + f_6 = Integral(x**2*y**2, x, y) + assert python(f_6) == "x = Symbol('x')\ny = Symbol('y')\ne = Integral(x**2*y**2, x, y)" + + +def test_python_matrix(): + p = python(Matrix([[x**2+1, 1], [y, x+y]])) + s = "x = Symbol('x')\ny = Symbol('y')\ne = MutableDenseMatrix([[x**2 + 1, 1], [y, x + y]])" + assert p == s + +def test_python_limits(): + assert python(limit(x, x, oo)) == 'e = oo' + assert python(limit(x**2, x, 0)) == 'e = 0' + +def test_issue_20762(): + # Make sure Python removes curly braces from subscripted variables + a_b = Symbol('a_{b}') + b = Symbol('b') + expr = a_b*b + assert python(expr) == "a_b = Symbol('a_{b}')\nb = Symbol('b')\ne = a_b*b" + + +def test_settings(): + raises(TypeError, lambda: python(x, method="garbage")) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_repr.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_repr.py new file mode 100644 index 0000000000000000000000000000000000000000..da58883b4fb027ed82db842a0a1ce5f76a49a8bb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_repr.py @@ -0,0 +1,382 @@ +from __future__ import annotations +from typing import Any + +from sympy.external.gmpy import GROUND_TYPES +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.assumptions.ask import Q +from sympy.core.function import (Function, WildFunction) +from sympy.core.numbers import (AlgebraicNumber, Float, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import sin +from sympy.functions.special.delta_functions import Heaviside +from sympy.logic.boolalg import (false, true) +from sympy.matrices.dense import (Matrix, ones) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.combinatorics import Cycle, Permutation +from sympy.core.symbol import Str +from sympy.geometry import Point, Ellipse +from sympy.printing import srepr +from sympy.polys import ring, field, ZZ, QQ, lex, grlex, Poly +from sympy.polys.polyclasses import DMP +from sympy.polys.agca.extensions import FiniteExtension + +x, y = symbols('x,y') + +# eval(srepr(expr)) == expr has to succeed in the right environment. The right +# environment is the scope of "from sympy import *" for most cases. +ENV: dict[str, Any] = {"Str": Str} +exec("from sympy import *", ENV) + + +def sT(expr, string, import_stmt=None, **kwargs): + """ + sT := sreprTest + + Tests that srepr delivers the expected string and that + the condition eval(srepr(expr))==expr holds. + """ + if import_stmt is None: + ENV2 = ENV + else: + ENV2 = ENV.copy() + exec(import_stmt, ENV2) + + assert srepr(expr, **kwargs) == string + assert eval(string, ENV2) == expr + + +def test_printmethod(): + class R(Abs): + def _sympyrepr(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert srepr(R(x)) == "foo(Symbol('x'))" + + +def test_Add(): + sT(x + y, "Add(Symbol('x'), Symbol('y'))") + assert srepr(x**2 + 1, order='lex') == "Add(Pow(Symbol('x'), Integer(2)), Integer(1))" + assert srepr(x**2 + 1, order='old') == "Add(Integer(1), Pow(Symbol('x'), Integer(2)))" + assert srepr(sympify('x + 3 - 2', evaluate=False), order='none') == "Add(Symbol('x'), Integer(3), Mul(Integer(-1), Integer(2)))" + + +def test_more_than_255_args_issue_10259(): + from sympy.core.add import Add + from sympy.core.mul import Mul + for op in (Add, Mul): + expr = op(*symbols('x:256')) + assert eval(srepr(expr)) == expr + + +def test_Function(): + sT(Function("f")(x), "Function('f')(Symbol('x'))") + # test unapplied Function + sT(Function('f'), "Function('f')") + + sT(sin(x), "sin(Symbol('x'))") + sT(sin, "sin") + + +def test_Heaviside(): + sT(Heaviside(x), "Heaviside(Symbol('x'))") + sT(Heaviside(x, 1), "Heaviside(Symbol('x'), Integer(1))") + + +def test_Geometry(): + sT(Point(0, 0), "Point2D(Integer(0), Integer(0))") + sT(Ellipse(Point(0, 0), 5, 1), + "Ellipse(Point2D(Integer(0), Integer(0)), Integer(5), Integer(1))") + # TODO more tests + + +def test_Singletons(): + sT(S.Catalan, 'Catalan') + sT(S.ComplexInfinity, 'zoo') + sT(S.EulerGamma, 'EulerGamma') + sT(S.Exp1, 'E') + sT(S.GoldenRatio, 'GoldenRatio') + sT(S.TribonacciConstant, 'TribonacciConstant') + sT(S.Half, 'Rational(1, 2)') + sT(S.ImaginaryUnit, 'I') + sT(S.Infinity, 'oo') + sT(S.NaN, 'nan') + sT(S.NegativeInfinity, '-oo') + sT(S.NegativeOne, 'Integer(-1)') + sT(S.One, 'Integer(1)') + sT(S.Pi, 'pi') + sT(S.Zero, 'Integer(0)') + sT(S.Complexes, 'Complexes') + sT(S.EmptySequence, 'EmptySequence') + sT(S.EmptySet, 'EmptySet') + # sT(S.IdentityFunction, 'Lambda(_x, _x)') + sT(S.Naturals, 'Naturals') + sT(S.Naturals0, 'Naturals0') + sT(S.Rationals, 'Rationals') + sT(S.Reals, 'Reals') + sT(S.UniversalSet, 'UniversalSet') + + +def test_Integer(): + sT(Integer(4), "Integer(4)") + + +def test_list(): + sT([x, Integer(4)], "[Symbol('x'), Integer(4)]") + + +def test_Matrix(): + for cls, name in [(Matrix, "MutableDenseMatrix"), (ImmutableDenseMatrix, "ImmutableDenseMatrix")]: + sT(cls([[x**+1, 1], [y, x + y]]), + "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name) + + sT(cls(), "%s([])" % name) + + sT(cls([[x**+1, 1], [y, x + y]]), "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name) + + +def test_empty_Matrix(): + sT(ones(0, 3), "MutableDenseMatrix(0, 3, [])") + sT(ones(4, 0), "MutableDenseMatrix(4, 0, [])") + sT(ones(0, 0), "MutableDenseMatrix([])") + + +def test_Rational(): + sT(Rational(1, 3), "Rational(1, 3)") + sT(Rational(-1, 3), "Rational(-1, 3)") + + +def test_Float(): + sT(Float('1.23', dps=3), "Float('1.22998', precision=13)") + sT(Float('1.23456789', dps=9), "Float('1.23456788994', precision=33)") + sT(Float('1.234567890123456789', dps=19), + "Float('1.234567890123456789013', precision=66)") + sT(Float('0.60038617995049726', dps=15), + "Float('0.60038617995049726', precision=53)") + + sT(Float('1.23', precision=13), "Float('1.22998', precision=13)") + sT(Float('1.23456789', precision=33), + "Float('1.23456788994', precision=33)") + sT(Float('1.234567890123456789', precision=66), + "Float('1.234567890123456789013', precision=66)") + sT(Float('0.60038617995049726', precision=53), + "Float('0.60038617995049726', precision=53)") + + sT(Float('0.60038617995049726', 15), + "Float('0.60038617995049726', precision=53)") + + +def test_Symbol(): + sT(x, "Symbol('x')") + sT(y, "Symbol('y')") + sT(Symbol('x', negative=True), "Symbol('x', negative=True)") + + +def test_Symbol_two_assumptions(): + x = Symbol('x', negative=0, integer=1) + # order could vary + s1 = "Symbol('x', integer=True, negative=False)" + s2 = "Symbol('x', negative=False, integer=True)" + assert srepr(x) in (s1, s2) + assert eval(srepr(x), ENV) == x + + +def test_Symbol_no_special_commutative_treatment(): + sT(Symbol('x'), "Symbol('x')") + sT(Symbol('x', commutative=False), "Symbol('x', commutative=False)") + sT(Symbol('x', commutative=0), "Symbol('x', commutative=False)") + sT(Symbol('x', commutative=True), "Symbol('x', commutative=True)") + sT(Symbol('x', commutative=1), "Symbol('x', commutative=True)") + + +def test_Wild(): + sT(Wild('x', even=True), "Wild('x', even=True)") + + +def test_Dummy(): + d = Dummy('d') + sT(d, "Dummy('d', dummy_index=%s)" % str(d.dummy_index)) + + +def test_Dummy_assumption(): + d = Dummy('d', nonzero=True) + assert d == eval(srepr(d)) + s1 = "Dummy('d', dummy_index=%s, nonzero=True)" % str(d.dummy_index) + s2 = "Dummy('d', nonzero=True, dummy_index=%s)" % str(d.dummy_index) + assert srepr(d) in (s1, s2) + + +def test_Dummy_from_Symbol(): + # should not get the full dictionary of assumptions + n = Symbol('n', integer=True) + d = n.as_dummy() + assert srepr(d + ) == "Dummy('n', dummy_index=%s)" % str(d.dummy_index) + + +def test_tuple(): + sT((x,), "(Symbol('x'),)") + sT((x, y), "(Symbol('x'), Symbol('y'))") + + +def test_WildFunction(): + sT(WildFunction('w'), "WildFunction('w')") + + +def test_settins(): + raises(TypeError, lambda: srepr(x, method="garbage")) + + +def test_Mul(): + sT(3*x**3*y, "Mul(Integer(3), Pow(Symbol('x'), Integer(3)), Symbol('y'))") + assert srepr(3*x**3*y, order='old') == "Mul(Integer(3), Symbol('y'), Pow(Symbol('x'), Integer(3)))" + assert srepr(sympify('(x+4)*2*x*7', evaluate=False), order='none') == "Mul(Add(Symbol('x'), Integer(4)), Integer(2), Symbol('x'), Integer(7))" + + +def test_AlgebraicNumber(): + a = AlgebraicNumber(sqrt(2)) + sT(a, "AlgebraicNumber(Pow(Integer(2), Rational(1, 2)), [Integer(1), Integer(0)])") + a = AlgebraicNumber(root(-2, 3)) + sT(a, "AlgebraicNumber(Pow(Integer(-2), Rational(1, 3)), [Integer(1), Integer(0)])") + + +def test_PolyRing(): + assert srepr(ring("x", ZZ, lex)[0]) == "PolyRing((Symbol('x'),), ZZ, lex)" + assert srepr(ring("x,y", QQ, grlex)[0]) == "PolyRing((Symbol('x'), Symbol('y')), QQ, grlex)" + assert srepr(ring("x,y,z", ZZ["t"], lex)[0]) == "PolyRing((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)" + + +def test_FracField(): + assert srepr(field("x", ZZ, lex)[0]) == "FracField((Symbol('x'),), ZZ, lex)" + assert srepr(field("x,y", QQ, grlex)[0]) == "FracField((Symbol('x'), Symbol('y')), QQ, grlex)" + assert srepr(field("x,y,z", ZZ["t"], lex)[0]) == "FracField((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)" + + +def test_PolyElement(): + R, x, y = ring("x,y", ZZ) + assert srepr(3*x**2*y + 1) == "PolyElement(PolyRing((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)])" + + +def test_FracElement(): + F, x, y = field("x,y", ZZ) + assert srepr((3*x**2*y + 1)/(x - y**2)) == "FracElement(FracField((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)], [((1, 0), 1), ((0, 2), -1)])" + + +def test_FractionField(): + assert srepr(QQ.frac_field(x)) == \ + "FractionField(FracField((Symbol('x'),), QQ, lex))" + assert srepr(QQ.frac_field(x, y, order=grlex)) == \ + "FractionField(FracField((Symbol('x'), Symbol('y')), QQ, grlex))" + + +def test_PolynomialRingBase(): + assert srepr(ZZ.old_poly_ring(x)) == \ + "GlobalPolynomialRing(ZZ, Symbol('x'))" + assert srepr(ZZ[x].old_poly_ring(y)) == \ + "GlobalPolynomialRing(ZZ[x], Symbol('y'))" + assert srepr(QQ.frac_field(x).old_poly_ring(y)) == \ + "GlobalPolynomialRing(FractionField(FracField((Symbol('x'),), QQ, lex)), Symbol('y'))" + + +def test_DMP(): + p1 = DMP([1, 2], ZZ) + p2 = ZZ.old_poly_ring(x)([1, 2]) + if GROUND_TYPES != 'flint': + assert srepr(p1) == "DMP_Python([1, 2], ZZ)" + assert srepr(p2) == "DMP_Python([1, 2], ZZ)" + else: + assert srepr(p1) == "DUP_Flint([1, 2], ZZ)" + assert srepr(p2) == "DUP_Flint([1, 2], ZZ)" + + +def test_FiniteExtension(): + assert srepr(FiniteExtension(Poly(x**2 + 1, x))) == \ + "FiniteExtension(Poly(x**2 + 1, x, domain='ZZ'))" + + +def test_ExtensionElement(): + A = FiniteExtension(Poly(x**2 + 1, x)) + if GROUND_TYPES != 'flint': + ans = "ExtElem(DMP_Python([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))" + else: + ans = "ExtElem(DUP_Flint([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))" + assert srepr(A.generator) == ans + +def test_BooleanAtom(): + assert srepr(true) == "true" + assert srepr(false) == "false" + + +def test_Integers(): + sT(S.Integers, "Integers") + + +def test_Naturals(): + sT(S.Naturals, "Naturals") + + +def test_Naturals0(): + sT(S.Naturals0, "Naturals0") + + +def test_Reals(): + sT(S.Reals, "Reals") + + +def test_matrix_expressions(): + n = symbols('n', integer=True) + A = MatrixSymbol("A", n, n) + B = MatrixSymbol("B", n, n) + sT(A, "MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True))") + sT(A*B, "MatMul(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))") + sT(A + B, "MatAdd(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))") + + +def test_Cycle(): + # FIXME: sT fails because Cycle is not immutable and calling srepr(Cycle(1, 2)) + # adds keys to the Cycle dict (GH-17661) + #import_stmt = "from sympy.combinatorics import Cycle" + #sT(Cycle(1, 2), "Cycle(1, 2)", import_stmt) + assert srepr(Cycle(1, 2)) == "Cycle(1, 2)" + + +def test_Permutation(): + import_stmt = "from sympy.combinatorics import Permutation" + sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt, perm_cyclic=False) + sT(Permutation(1, 2)(3, 4), "Permutation(1, 2)(3, 4)", import_stmt, perm_cyclic=True) + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt) + Permutation.print_cyclic = old_print_cyclic + +def test_dict(): + from sympy.abc import x, y, z + d = {} + assert srepr(d) == "{}" + d = {x: y} + assert srepr(d) == "{Symbol('x'): Symbol('y')}" + d = {x: y, y: z} + assert srepr(d) in ( + "{Symbol('x'): Symbol('y'), Symbol('y'): Symbol('z')}", + "{Symbol('y'): Symbol('z'), Symbol('x'): Symbol('y')}", + ) + d = {x: {y: z}} + assert srepr(d) == "{Symbol('x'): {Symbol('y'): Symbol('z')}}" + +def test_set(): + from sympy.abc import x, y + s = set() + assert srepr(s) == "set()" + s = {x, y} + assert srepr(s) in ("{Symbol('x'), Symbol('y')}", "{Symbol('y'), Symbol('x')}") + +def test_Predicate(): + sT(Q.even, "Q.even") + +def test_AppliedPredicate(): + sT(Q.even(Symbol('z')), "AppliedPredicate(Q.even, Symbol('z'))") diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_rust.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_rust.py new file mode 100644 index 0000000000000000000000000000000000000000..c81d592faca0d4a31e5a9618a48d67cb19ca94d8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_rust.py @@ -0,0 +1,363 @@ +from sympy.core import (S, pi, oo, symbols, Rational, Integer, + GoldenRatio, EulerGamma, Catalan, Lambda, Dummy, + Eq, Ne, Le, Lt, Gt, Ge, Mod) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + sign, floor) +from sympy.logic import ITE +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix + +from sympy.printing.codeprinter import rust_code + +x, y, z = symbols('x,y,z', integer=False, real=True) +k, m, n = symbols('k,m,n', integer=True) + + +def test_Integer(): + assert rust_code(Integer(42)) == "42" + assert rust_code(Integer(-56)) == "-56" + + +def test_Relational(): + assert rust_code(Eq(x, y)) == "x == y" + assert rust_code(Ne(x, y)) == "x != y" + assert rust_code(Le(x, y)) == "x <= y" + assert rust_code(Lt(x, y)) == "x < y" + assert rust_code(Gt(x, y)) == "x > y" + assert rust_code(Ge(x, y)) == "x >= y" + + +def test_Rational(): + assert rust_code(Rational(3, 7)) == "3_f64/7.0" + assert rust_code(Rational(18, 9)) == "2" + assert rust_code(Rational(3, -7)) == "-3_f64/7.0" + assert rust_code(Rational(-3, -7)) == "3_f64/7.0" + assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0" + assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x" + + +def test_basic_ops(): + assert rust_code(x + y) == "x + y" + assert rust_code(x - y) == "x - y" + assert rust_code(x * y) == "x*y" + assert rust_code(x / y) == "x*y.recip()" + assert rust_code(-x) == "-x" + assert rust_code(2 * x) == "2.0*x" + assert rust_code(y + 2) == "y + 2.0" + assert rust_code(x + n) == "n as f64 + x" + +def test_printmethod(): + class fabs(Abs): + def _rust_code(self, printer): + return "%s.fabs()" % printer._print(self.args[0]) + assert rust_code(fabs(x)) == "x.fabs()" + a = MatrixSymbol("a", 1, 3) + assert rust_code(a[0,0]) == 'a[0]' + + +def test_Functions(): + assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())" + assert rust_code(abs(x)) == "x.abs()" + assert rust_code(ceiling(x)) == "x.ceil()" + assert rust_code(floor(x)) == "x.floor()" + + # Automatic rewrite + assert rust_code(Mod(x, 3)) == 'x - 3.0*((1_f64/3.0)*x).floor()' + + +def test_Pow(): + assert rust_code(1/x) == "x.recip()" + assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()" + assert rust_code(sqrt(x)) == "x.sqrt()" + assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()" + + assert rust_code(1/sqrt(x)) == "x.sqrt().recip()" + assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()" + + assert rust_code(1/pi) == "PI.recip()" + assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()" + assert rust_code(pi**-0.5) == "PI.sqrt().recip()" + + assert rust_code(x**Rational(1, 3)) == "x.cbrt()" + assert rust_code(2**x) == "x.exp2()" + assert rust_code(exp(x)) == "x.exp()" + assert rust_code(x**3) == "x.powi(3)" + assert rust_code(x**(y**3)) == "x.powf(y.powi(3))" + assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)" + + g = implemented_function('g', Lambda(x, 2*x)) + assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2.0*x).powf(-x + y.powf(x))/(x.powi(2) + y)" + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1), + (lambda base, exp: not exp.is_integer, "pow", 1)] + assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)' + assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)' + + +def test_constants(): + assert rust_code(pi) == "PI" + assert rust_code(oo) == "INFINITY" + assert rust_code(S.Infinity) == "INFINITY" + assert rust_code(-oo) == "NEG_INFINITY" + assert rust_code(S.NegativeInfinity) == "NEG_INFINITY" + assert rust_code(S.NaN) == "NAN" + assert rust_code(exp(1)) == "E" + assert rust_code(S.Exp1) == "E" + + +def test_constants_other(): + assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2.0*GoldenRatio" % GoldenRatio.evalf(17) + assert rust_code( + 2*Catalan) == "const Catalan: f64 = %s;\n2.0*Catalan" % Catalan.evalf(17) + assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2.0*EulerGamma" % EulerGamma.evalf(17) + + +def test_boolean(): + assert rust_code(True) == "true" + assert rust_code(S.true) == "true" + assert rust_code(False) == "false" + assert rust_code(S.false) == "false" + assert rust_code(k & m) == "k && m" + assert rust_code(k | m) == "k || m" + assert rust_code(~k) == "!k" + assert rust_code(k & m & n) == "k && m && n" + assert rust_code(k | m | n) == "k || m || n" + assert rust_code((k & m) | n) == "n || k && m" + assert rust_code((k | m) & n) == "n && (k || m)" + + +def test_Piecewise(): + expr = Piecewise((x, x < 1), (x + 2, True)) + assert rust_code(expr) == ( + "if (x < 1.0) {\n" + " x\n" + "} else {\n" + " x + 2.0\n" + "}") + assert rust_code(expr, assign_to="r") == ( + "r = if (x < 1.0) {\n" + " x\n" + "} else {\n" + " x + 2.0\n" + "};") + assert rust_code(expr, assign_to="r", inline=True) == ( + "r = if (x < 1.0) { x } else { x + 2.0 };") + expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) + assert rust_code(expr, inline=True) == ( + "if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }") + assert rust_code(expr, assign_to="r", inline=True) == ( + "r = if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 };") + assert rust_code(expr, assign_to="r") == ( + "r = if (x < 1.0) {\n" + " x\n" + "} else if (x < 5.0) {\n" + " x + 1.0\n" + "} else {\n" + " x + 2.0\n" + "};") + expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) + assert rust_code(expr, inline=True) == ( + "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 }") + expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42 + assert rust_code(expr, inline=True) == ( + "2.0*if (x < 1.0) { x } else if (x < 5.0) { x + 1.0 } else { x + 2.0 } - 42.0") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: rust_code(expr)) + + +def test_dereference_printing(): + expr = x + y + sin(z) + z + assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()" + + +def test_sign(): + expr = sign(x) * y + assert rust_code(expr) == "y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64" + assert rust_code(expr, assign_to='r') == "r = y*(if (x == 0.0) { 0.0 } else { (x).signum() }) as f64;" + + expr = sign(x + y) + 42 + assert rust_code(expr) == "(if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42" + assert rust_code(expr, assign_to='r') == "r = (if (x + y == 0.0) { 0.0 } else { (x + y).signum() }) + 42;" + + expr = sign(cos(x)) + assert rust_code(expr) == "(if (x.cos() == 0.0) { 0.0 } else { (x.cos()).signum() })" + + +def test_reserved_words(): + + x, y = symbols("x if") + + expr = sin(y) + assert rust_code(expr) == "if_.sin()" + assert rust_code(expr, dereference=[y]) == "(*if_).sin()" + assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()" + + with raises(ValueError): + rust_code(expr, error_on_reserved=True) + + +def test_ITE(): + ekpr = ITE(k < 1, m, n) + assert rust_code(ekpr) == ( + "if (k < 1) {\n" + " m\n" + "} else {\n" + " n\n" + "}") + + +def test_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + + x = IndexedBase('x')[j] + assert rust_code(x) == "x[j]" + + A = IndexedBase('A')[i, j] + assert rust_code(A) == "A[m*i + j]" + + B = IndexedBase('B')[i, j, k] + assert rust_code(B) == "B[m*o*i + o*j + k]" + + +def test_dummy_loops(): + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + assert rust_code(x[i], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = x[i];\n" + "}") + + +def test_loops(): + m, n = symbols('m n', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i = Idx('i', m) + j = Idx('j', n) + + assert rust_code(A[i, j]*x[j], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " y[i] = A[n*i + j]*x[j] + y[i];\n" + " }\n" + "}") + + assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = x[i] + z[i];\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " y[i] = A[n*i + j]*x[j] + y[i];\n" + " }\n" + "}") + + +def test_loops_multiple_contractions(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " for k in 0..o {\n" + " for l in 0..p {\n" + " y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ + " }\n" + " }\n" + " }\n" + "}") + + +def test_loops_addfactor(): + m, n, o, p = symbols('m n o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) + assert code == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " for k in 0..o {\n" + " for l in 0..p {\n" + " y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ + " }\n" + " }\n" + " }\n" + "}") + + +def test_settings(): + raises(TypeError, lambda: rust_code(sin(x), method="garbage")) + + +def test_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert rust_code(g(x)) == "2*x" + + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert rust_code(g(x)) == ( + "const Catalan: f64 = %s;\n2.0*x/Catalan" % Catalan.evalf(17)) + + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert rust_code(g(A[i]), assign_to=A[i]) == ( + "for i in 0..n {\n" + " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n" + "}") + + +def test_user_functions(): + x = symbols('x', integer=False) + n = symbols('n', integer=True) + custom_functions = { + "ceiling": "ceil", + "Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)], + } + assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()" + assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)" + assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)" + + +def test_matrix(): + assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]' + with raises(ValueError): + rust_code(Matrix([[1, 2, 3]])) + + +def test_sparse_matrix(): + # gh-15791 + with raises(NotImplementedError): + rust_code(SparseMatrix([[1, 2, 3]])) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_smtlib.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_smtlib.py new file mode 100644 index 0000000000000000000000000000000000000000..48ff3d432d9042bf178f4e52dc46c787059937a3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_smtlib.py @@ -0,0 +1,553 @@ +import contextlib +import itertools +import re +import typing +from enum import Enum +from typing import Callable + +import sympy +from sympy import Add, Implies, sqrt +from sympy.core import Mul, Pow +from sympy.core import (S, pi, symbols, Function, Rational, Integer, + Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.functions import Piecewise, exp, sin, cos +from sympy.assumptions.ask import Q +from sympy.printing.smtlib import smtlib_code +from sympy.testing.pytest import raises, Failed + +x, y, z = symbols('x,y,z') + + +class _W(Enum): + DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.IGNORECASE) + WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.IGNORECASE) + WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.IGNORECASE) + + +@contextlib.contextmanager +def _check_warns(expected: typing.Iterable[_W]): + warns: typing.List[str] = [] + log_warn = warns.append + yield log_warn + + errors = [] + for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)): + if not e: + errors += [f"[{i}] Received unexpected warning `{w}`."] + elif not w: + errors += [f"[{i}] Did not receive expected warning `{e.name}`."] + elif not e.value.match(w): + errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."] + + if errors: raise Failed('\n'.join(errors)) + + +def test_Integer(): + with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w: + assert smtlib_code(Integer(67), log_warn=w) == "67" + assert smtlib_code(Integer(-1), log_warn=w) == "-1" + with _check_warns([]) as w: + assert smtlib_code(Integer(67)) == "67" + assert smtlib_code(Integer(-1)) == "-1" + + +def test_Rational(): + with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w: + assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)" + assert smtlib_code(Rational(18, 9), log_warn=w) == "2" + assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)" + assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w: + assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)" + assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \ + "(* (/ 3 7) x)" + + +def test_Relational(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w: + assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))" + assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))" + assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))" + assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))" + assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))" + assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))" + + +def test_AppliedBinaryRelation(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w: + assert smtlib_code(Q.eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))" + assert smtlib_code(Q.ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))" + assert smtlib_code(Q.lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))" + assert smtlib_code(Q.le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))" + assert smtlib_code(Q.gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))" + assert smtlib_code(Q.ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))" + + raises(ValueError, lambda: smtlib_code(Q.complex(x), log_warn=w)) + + +def test_AppliedPredicate(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 6) as w: + assert smtlib_code(Q.positive(x), auto_declare=False, log_warn=w) == "(assert (> x 0))" + assert smtlib_code(Q.negative(x), auto_declare=False, log_warn=w) == "(assert (< x 0))" + assert smtlib_code(Q.zero(x), auto_declare=False, log_warn=w) == "(assert (= x 0))" + assert smtlib_code(Q.nonpositive(x), auto_declare=False, log_warn=w) == "(assert (<= x 0))" + assert smtlib_code(Q.nonnegative(x), auto_declare=False, log_warn=w) == "(assert (>= x 0))" + assert smtlib_code(Q.nonzero(x), auto_declare=False, log_warn=w) == "(assert (not (= x 0)))" + +def test_Function(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))" + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + abs(x), + symbol_table={x: int, y: bool}, + known_types={int: "INTEGER_TYPE"}, + known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"}, + log_warn=w + ) == "(declare-const x INTEGER_TYPE)\n" \ + "(ABSOLUTE_VALUE_OF x)" + + my_fun1 = Function('f1') + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + my_fun1(x), + symbol_table={my_fun1: Callable[[bool], float]}, + log_warn=w + ) == "(declare-const x Bool)\n" \ + "(declare-fun f1 (Bool) Real)\n" \ + "(f1 x)" + + with _check_warns([]) as w: + assert smtlib_code( + my_fun1(x), + symbol_table={my_fun1: Callable[[bool], bool]}, + log_warn=w + ) == "(declare-const x Bool)\n" \ + "(declare-fun f1 (Bool) Bool)\n" \ + "(assert (f1 x))" + + assert smtlib_code( + Eq(my_fun1(x, z), y), + symbol_table={my_fun1: Callable[[int, bool], bool]}, + log_warn=w + ) == "(declare-const x Int)\n" \ + "(declare-const y Bool)\n" \ + "(declare-const z Bool)\n" \ + "(declare-fun f1 (Int Bool) Bool)\n" \ + "(assert (= (f1 x z) y))" + + assert smtlib_code( + Eq(my_fun1(x, z), y), + symbol_table={my_fun1: Callable[[int, bool], bool]}, + known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='}, + log_warn=w + ) == "(declare-const x Int)\n" \ + "(declare-const y Bool)\n" \ + "(declare-const z Bool)\n" \ + "(assert (== (MY_KNOWN_FUN x z) y))" + + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w: + assert smtlib_code( + Eq(my_fun1(x, z), y), + known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='}, + log_warn=w + ) == "(declare-const x Real)\n" \ + "(declare-const y Real)\n" \ + "(declare-const z Real)\n" \ + "(assert (== (MY_KNOWN_FUN x z) y))" + + +def test_Pow(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)" + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))" + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))' + + a = Symbol('a', integer=True) + b = Symbol('b', real=True) + c = Symbol('c') + + def g(x): return 2 * x + + # if x=1, y=2, then expr=2.333... + expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b) + + with _check_warns([]) as w: + assert smtlib_code( + [ + Eq(a < 2, c), + Eq(b > a, c), + c & True, + Eq(expr, 2 + Rational(1, 3)) + ], + log_warn=w + ) == '(declare-const a Int)\n' \ + '(declare-const b Real)\n' \ + '(declare-const c Bool)\n' \ + '(assert (= (< a 2) c))\n' \ + '(assert (= (> b a) c))\n' \ + '(assert c)\n' \ + '(assert (= ' \ + '(* (pow (* 7.0 a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \ + '(/ 7 3)' \ + '))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False), + log_warn=w + ) == '(declare-const b Real)\n' \ + '(declare-const c Real)\n' \ + '(* -2 c (pow (* b b) -1))' + + +def test_basic_ops(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)" + + # with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w: + # todo: implement re-write, currently does '(+ x (* -1 y))' instead + # assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)" + + +def test_quantifier_extensions(): + from sympy.logic.boolalg import Boolean + from sympy import Interval, Tuple, sympify + + # start For-all quantifier class example + class ForAll(Boolean): + def _smtlib(self, printer): + bound_symbol_declarations = [ + printer._s_expr(sym.name, [ + printer._known_types[printer.symbol_table[sym]], + Interval(start, end) + ]) for sym, start, end in self.limits + ] + return printer._s_expr('forall', [ + printer._s_expr('', bound_symbol_declarations), + self.function + ]) + + @property + def bound_symbols(self): + return {s for s, _, _ in self.limits} + + @property + def free_symbols(self): + bound_symbol_names = {s.name for s in self.bound_symbols} + return { + s for s in self.function.free_symbols + if s.name not in bound_symbol_names + } + + def __new__(cls, *args): + limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))] + function = [sympify(a) for a in args if isinstance(a, Boolean)] + assert len(limits) + len(function) == len(args) + assert len(function) == 1 + function = function[0] + + if isinstance(function, ForAll): return ForAll.__new__( + ForAll, *(limits + function.limits), function.function + ) + inst = Boolean.__new__(cls) + inst._args = tuple(limits + [function]) + inst.limits = limits + inst.function = function + return inst + + # end For-All Quantifier class example + + f = Function('f') + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code( + ForAll((x, -42, +21), Eq(f(x), f(x))), + symbol_table={f: Callable[[float], float]}, + log_warn=w + ) == '(assert (forall ( (x Real [-42, 21])) true))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w: + assert smtlib_code( + ForAll( + (x, -42, +21), (y, -100, 3), + Implies(Eq(x, y), Eq(f(x), f(y))) + ), + symbol_table={f: Callable[[float], float]}, + log_warn=w + ) == '(declare-fun f (Real) Real)\n' \ + '(assert (' \ + 'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \ + '(=> (= x y) (= (f x) (f y)))' \ + '))' + + a = Symbol('a', integer=True) + b = Symbol('b', real=True) + c = Symbol('c') + + with _check_warns([]) as w: + assert smtlib_code( + ForAll( + (a, 2, 100), ForAll( + (b, 2, 100), + Implies(a < b, sqrt(a) < b) | c + )), + log_warn=w + ) == '(declare-const c Bool)\n' \ + '(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \ + '(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \ + '))' + + +def test_mix_number_mult_symbols(): + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + 1 / pi, + known_constants={pi: "MY_PI"}, + log_warn=w + ) == '(pow MY_PI -1)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + [ + Eq(pi, 3.14, evaluate=False), + 1 / pi, + ], + known_constants={pi: "MY_PI"}, + log_warn=w + ) == '(assert (= MY_PI 3.14))\n' \ + '(pow MY_PI -1)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={ + S.Pi: 'p', S.GoldenRatio: 'g', + S.Exp1: 'e' + }, + known_functions={ + Add: 'plus', + exp: 'exp' + }, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={ + S.Pi: 'p' + }, + known_functions={ + Add: 'plus', + exp: 'exp' + }, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_functions={Add: 'plus'}, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={S.Exp1: 'e'}, + known_functions={Add: 'plus'}, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)' + + +def test_boolean(): + with _check_warns([]) as w: + assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(assert (and x y))' + assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(assert (or x y))' + assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \ + '(assert (not x))' + assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(declare-const z Bool)\n' \ + '(assert (and x y z))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(declare-const z Real)\n' \ + '(assert (or (> z 3) (and x (not y))))' + + f = Function('f') + g = Function('g') + h = Function('h') + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code( + [Gt(f(x), y), + Lt(y, g(z))], + symbol_table={ + f: Callable[[bool], int], g: Callable[[bool], int], + }, log_warn=w + ) == '(declare-const x Bool)\n' \ + '(declare-const y Real)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Bool) Int)\n' \ + '(declare-fun g (Bool) Int)\n' \ + '(assert (> (f x) y))\n' \ + '(assert (< y (g z)))' + + with _check_warns([]) as w: + assert smtlib_code( + [Eq(f(x), y), + Lt(y, g(z))], + symbol_table={ + f: Callable[[bool], int], g: Callable[[bool], int], + }, log_warn=w + ) == '(declare-const x Bool)\n' \ + '(declare-const y Int)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Bool) Int)\n' \ + '(declare-fun g (Bool) Int)\n' \ + '(assert (= (f x) y))\n' \ + '(assert (< y (g z)))' + + with _check_warns([]) as w: + assert smtlib_code( + [Eq(f(x), y), + Eq(g(f(x)), z), + Eq(h(g(f(x))), x)], + symbol_table={ + f: Callable[[float], int], + g: Callable[[int], bool], + h: Callable[[bool], float] + }, + log_warn=w + ) == '(declare-const x Real)\n' \ + '(declare-const y Int)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Real) Int)\n' \ + '(declare-fun g (Int) Bool)\n' \ + '(declare-fun h (Bool) Real)\n' \ + '(assert (= (f x) y))\n' \ + '(assert (= (g (f x)) z))\n' \ + '(assert (= (h (g (f x))) x))' + + +# todo: make smtlib_code support arrays +# def test_containers(): +# assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ +# "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]" +# assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))" +# assert julia_code([1]) == "Any[1]" +# assert julia_code((1,)) == "(1,)" +# assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)" +# assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))" +# # scalar, matrix, empty matrix and empty list +# assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])" + +def test_smtlib_piecewise(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Piecewise((x, x < 1), + (x ** 2, True)), + auto_declare=False, + log_warn=w + ) == '(ite (< x 1) x (pow x 2))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Piecewise((x ** 2, x < 1), + (x ** 3, x < 2), + (x ** 4, x < 3), + (x ** 5, True)), + auto_declare=False, + log_warn=w + ) == '(ite (< x 1) (pow x 2) ' \ + '(ite (< x 2) (pow x 3) ' \ + '(ite (< x 3) (pow x 4) ' \ + '(pow x 5))))' + + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0)) + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + raises(AssertionError, lambda: smtlib_code(expr, log_warn=w)) + + +def test_smtlib_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x ** 2, True)) + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))' + + +# todo: make smtlib_code support arrays / matrices ? +# def test_smtlib_matrix_assign_to(): +# A = Matrix([[1, 2, 3]]) +# assert smtlib_code(A, assign_to='a') == "a = [1 2 3]" +# A = Matrix([[1, 2], [3, 4]]) +# assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]" + +# def test_julia_matrix_1x1(): +# A = Matrix([[3]]) +# B = MatrixSymbol('B', 1, 1) +# C = MatrixSymbol('C', 1, 2) +# assert julia_code(A, assign_to=B) == "B = [3]" +# raises(ValueError, lambda: julia_code(A, assign_to=C)) + +# def test_julia_matrix_elements(): +# A = Matrix([[x, 2, x * y]]) +# assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2" +# A = MatrixSymbol('AA', 1, 3) +# assert julia_code(A) == "AA" +# assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \ +# "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]" +# assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]" + +def test_smtlib_boolean(): + with _check_warns([]) as w: + assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true' + assert smtlib_code(True, log_warn=w) == '(assert true)' + assert smtlib_code(S.true, log_warn=w) == '(assert true)' + assert smtlib_code(S.false, log_warn=w) == '(assert false)' + assert smtlib_code(False, log_warn=w) == '(assert false)' + assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false' + + +def test_not_supported(): + f = Function('f') + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w)) + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w)) + + +def test_Float(): + assert smtlib_code(0.0) == "0.0" + assert smtlib_code(0.000000000000000003) == '(* 3.0 (pow 10 -18))' + assert smtlib_code(5.3) == "5.3" diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_str.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_str.py new file mode 100644 index 0000000000000000000000000000000000000000..675212964b03bf9a9806088225c28d7f70971ca7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_str.py @@ -0,0 +1,1206 @@ +from sympy import MatAdd +from sympy.algebras.quaternion import Quaternion +from sympy.assumptions.ask import Q +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.combinatorics.partitions import Partition +from sympy.concrete.summations import (Sum, summation) +from sympy.core.add import Add +from sympy.core.containers import (Dict, Tuple) +from sympy.core.expr import UnevaluatedExpr, Expr +from sympy.core.function import (Derivative, Function, Lambda, Subs, WildFunction) +from sympy.core.mul import Mul +from sympy.core import (Catalan, EulerGamma, GoldenRatio, TribonacciConstant) +from sympy.core.numbers import (E, Float, I, Integer, Rational, nan, oo, pi, zoo) +from sympy.core.parameters import _exp_is_pow +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Rel, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.functions.combinatorial.factorials import (factorial, factorial2, subfactorial) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.zeta_functions import zeta +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (Equivalent, false, true, Xor) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions import Identity +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices import SparseMatrix +from sympy.polys.polytools import factor +from sympy.series.limits import Limit +from sympy.series.order import O +from sympy.sets.sets import (Complement, FiniteSet, Interval, SymmetricDifference) +from sympy.stats import (Covariance, Expectation, Probability, Variance) +from sympy.stats.rv import RandomSymbol +from sympy.external import import_module +from sympy.physics.control.lti import TransferFunction, Series, Parallel, \ + Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback +from sympy.physics.units import second, joule +from sympy.polys import (Poly, rootof, RootSum, groebner, ring, field, ZZ, QQ, + ZZ_I, QQ_I, lex, grlex) +from sympy.geometry import Point, Circle, Polygon, Ellipse, Triangle +from sympy.tensor import NDimArray +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement + +from sympy.testing.pytest import raises, warns_deprecated_sympy + +from sympy.printing import sstr, sstrrepr, StrPrinter +from sympy.physics.quantum.trace import Tr + +x, y, z, w, t = symbols('x,y,z,w,t') +d = Dummy('d') + + +def test_printmethod(): + class R(Abs): + def _sympystr(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert sstr(R(x)) == "foo(x)" + + class R(Abs): + def _sympystr(self, printer): + return "foo" + assert sstr(R(x)) == "foo" + + +def test_Abs(): + assert str(Abs(x)) == "Abs(x)" + assert str(Abs(Rational(1, 6))) == "1/6" + assert str(Abs(Rational(-1, 6))) == "1/6" + + +def test_Add(): + assert str(x + y) == "x + y" + assert str(x + 1) == "x + 1" + assert str(x + x**2) == "x**2 + x" + assert str(Add(0, 1, evaluate=False)) == "0 + 1" + assert str(Add(0, 0, 1, evaluate=False)) == "0 + 0 + 1" + assert str(1.0*x) == "1.0*x" + assert str(5 + x + y + x*y + x**2 + y**2) == "x**2 + x*y + x + y**2 + y + 5" + assert str(1 + x + x**2/2 + x**3/3) == "x**3/3 + x**2/2 + x + 1" + assert str(2*x - 7*x**2 + 2 + 3*y) == "-7*x**2 + 2*x + 3*y + 2" + assert str(x - y) == "x - y" + assert str(2 - x) == "2 - x" + assert str(x - 2) == "x - 2" + assert str(x - y - z - w) == "-w + x - y - z" + assert str(x - z*y**2*z*w) == "-w*y**2*z**2 + x" + assert str(x - 1*y*x*y) == "-x*y**2 + x" + assert str(sin(x).series(x, 0, 15)) == "x - x**3/6 + x**5/120 - x**7/5040 + x**9/362880 - x**11/39916800 + x**13/6227020800 + O(x**15)" + assert str(Add(Add(-w, x, evaluate=False), Add(-y, z, evaluate=False), evaluate=False)) == "(-w + x) + (-y + z)" + assert str(Add(Add(-x, -y, evaluate=False), -z, evaluate=False)) == "-z + (-x - y)" + assert str(Add(Add(Add(-x, -y, evaluate=False), -z, evaluate=False), -t, evaluate=False)) == "-t + (-z + (-x - y))" + + +def test_Catalan(): + assert str(Catalan) == "Catalan" + + +def test_ComplexInfinity(): + assert str(zoo) == "zoo" + + +def test_Derivative(): + assert str(Derivative(x, y)) == "Derivative(x, y)" + assert str(Derivative(x**2, x, evaluate=False)) == "Derivative(x**2, x)" + assert str(Derivative( + x**2/y, x, y, evaluate=False)) == "Derivative(x**2/y, x, y)" + + +def test_dict(): + assert str({1: 1 + x}) == sstr({1: 1 + x}) == "{1: x + 1}" + assert str({1: x**2, 2: y*x}) in ("{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}") + assert sstr({1: x**2, 2: y*x}) == "{1: x**2, 2: x*y}" + + +def test_Dict(): + assert str(Dict({1: 1 + x})) == sstr({1: 1 + x}) == "{1: x + 1}" + assert str(Dict({1: x**2, 2: y*x})) in ( + "{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}") + assert sstr(Dict({1: x**2, 2: y*x})) == "{1: x**2, 2: x*y}" + + +def test_Dummy(): + assert str(d) == "_d" + assert str(d + x) == "_d + x" + + +def test_EulerGamma(): + assert str(EulerGamma) == "EulerGamma" + + +def test_Exp(): + assert str(E) == "E" + with _exp_is_pow(True): + assert str(exp(x)) == "E**x" + + +def test_factorial(): + n = Symbol('n', integer=True) + assert str(factorial(-2)) == "zoo" + assert str(factorial(0)) == "1" + assert str(factorial(7)) == "5040" + assert str(factorial(n)) == "factorial(n)" + assert str(factorial(2*n)) == "factorial(2*n)" + assert str(factorial(factorial(n))) == 'factorial(factorial(n))' + assert str(factorial(factorial2(n))) == 'factorial(factorial2(n))' + assert str(factorial2(factorial(n))) == 'factorial2(factorial(n))' + assert str(factorial2(factorial2(n))) == 'factorial2(factorial2(n))' + assert str(subfactorial(3)) == "2" + assert str(subfactorial(n)) == "subfactorial(n)" + assert str(subfactorial(2*n)) == "subfactorial(2*n)" + + +def test_Function(): + f = Function('f') + fx = f(x) + w = WildFunction('w') + assert str(f) == "f" + assert str(fx) == "f(x)" + assert str(w) == "w_" + + +def test_Geometry(): + assert sstr(Point(0, 0)) == 'Point2D(0, 0)' + assert sstr(Circle(Point(0, 0), 3)) == 'Circle(Point2D(0, 0), 3)' + assert sstr(Ellipse(Point(1, 2), 3, 4)) == 'Ellipse(Point2D(1, 2), 3, 4)' + assert sstr(Triangle(Point(1, 1), Point(7, 8), Point(0, -1))) == \ + 'Triangle(Point2D(1, 1), Point2D(7, 8), Point2D(0, -1))' + assert sstr(Polygon(Point(5, 6), Point(-2, -3), Point(0, 0), Point(4, 7))) == \ + 'Polygon(Point2D(5, 6), Point2D(-2, -3), Point2D(0, 0), Point2D(4, 7))' + assert sstr(Triangle(Point(0, 0), Point(1, 0), Point(0, 1)), sympy_integers=True) == \ + 'Triangle(Point2D(S(0), S(0)), Point2D(S(1), S(0)), Point2D(S(0), S(1)))' + assert sstr(Ellipse(Point(1, 2), 3, 4), sympy_integers=True) == \ + 'Ellipse(Point2D(S(1), S(2)), S(3), S(4))' + + +def test_GoldenRatio(): + assert str(GoldenRatio) == "GoldenRatio" + + +def test_Heaviside(): + assert str(Heaviside(x)) == str(Heaviside(x, S.Half)) == "Heaviside(x)" + assert str(Heaviside(x, 1)) == "Heaviside(x, 1)" + + +def test_TribonacciConstant(): + assert str(TribonacciConstant) == "TribonacciConstant" + + +def test_ImaginaryUnit(): + assert str(I) == "I" + + +def test_Infinity(): + assert str(oo) == "oo" + assert str(oo*I) == "oo*I" + + +def test_Integer(): + assert str(Integer(-1)) == "-1" + assert str(Integer(1)) == "1" + assert str(Integer(-3)) == "-3" + assert str(Integer(0)) == "0" + assert str(Integer(25)) == "25" + + +def test_Integral(): + assert str(Integral(sin(x), y)) == "Integral(sin(x), y)" + assert str(Integral(sin(x), (y, 0, 1))) == "Integral(sin(x), (y, 0, 1))" + + +def test_Interval(): + n = (S.NegativeInfinity, 1, 2, S.Infinity) + for i in range(len(n)): + for j in range(i + 1, len(n)): + for l in (True, False): + for r in (True, False): + ival = Interval(n[i], n[j], l, r) + assert S(str(ival)) == ival + + +def test_AccumBounds(): + a = Symbol('a', real=True) + assert str(AccumBounds(0, a)) == "AccumBounds(0, a)" + assert str(AccumBounds(0, 1)) == "AccumBounds(0, 1)" + + +def test_Lambda(): + assert str(Lambda(d, d**2)) == "Lambda(_d, _d**2)" + # issue 2908 + assert str(Lambda((), 1)) == "Lambda((), 1)" + assert str(Lambda((), x)) == "Lambda((), x)" + assert str(Lambda((x, y), x+y)) == "Lambda((x, y), x + y)" + assert str(Lambda(((x, y),), x+y)) == "Lambda(((x, y),), x + y)" + + +def test_Limit(): + assert str(Limit(sin(x)/x, x, y)) == "Limit(sin(x)/x, x, y, dir='+')" + assert str(Limit(1/x, x, 0)) == "Limit(1/x, x, 0, dir='+')" + assert str( + Limit(sin(x)/x, x, y, dir="-")) == "Limit(sin(x)/x, x, y, dir='-')" + + +def test_list(): + assert str([x]) == sstr([x]) == "[x]" + assert str([x**2, x*y + 1]) == sstr([x**2, x*y + 1]) == "[x**2, x*y + 1]" + assert str([x**2, [y + x]]) == sstr([x**2, [y + x]]) == "[x**2, [x + y]]" + + +def test_Matrix_str(): + M = Matrix([[x**+1, 1], [y, x + y]]) + assert str(M) == "Matrix([[x, 1], [y, x + y]])" + assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])" + M = Matrix([[1]]) + assert str(M) == sstr(M) == "Matrix([[1]])" + M = Matrix([[1, 2]]) + assert str(M) == sstr(M) == "Matrix([[1, 2]])" + M = Matrix() + assert str(M) == sstr(M) == "Matrix(0, 0, [])" + M = Matrix(0, 1, lambda i, j: 0) + assert str(M) == sstr(M) == "Matrix(0, 1, [])" + + +def test_Mul(): + assert str(x/y) == "x/y" + assert str(y/x) == "y/x" + assert str(x/y/z) == "x/(y*z)" + assert str((x + 1)/(y + 2)) == "(x + 1)/(y + 2)" + assert str(2*x/3) == '2*x/3' + assert str(-2*x/3) == '-2*x/3' + assert str(-1.0*x) == '-1.0*x' + assert str(1.0*x) == '1.0*x' + assert str(Mul(0, 1, evaluate=False)) == '0*1' + assert str(Mul(1, 0, evaluate=False)) == '1*0' + assert str(Mul(1, 1, evaluate=False)) == '1*1' + assert str(Mul(1, 1, 1, evaluate=False)) == '1*1*1' + assert str(Mul(1, 2, evaluate=False)) == '1*2' + assert str(Mul(1, S.Half, evaluate=False)) == '1*(1/2)' + assert str(Mul(1, 1, S.Half, evaluate=False)) == '1*1*(1/2)' + assert str(Mul(1, 1, 2, 3, x, evaluate=False)) == '1*1*2*3*x' + assert str(Mul(1, -1, evaluate=False)) == '1*(-1)' + assert str(Mul(-1, 1, evaluate=False)) == '-1*1' + assert str(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == '4*3*2*1*0*y*x' + assert str(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == '4*3*2*(z + 1)*0*y*x' + assert str(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == '(2/3)*(5/7)' + # For issue 14160 + assert str(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + # issue 21537 + assert str(Mul(x, Pow(1/y, -1, evaluate=False), evaluate=False)) == 'x/(1/y)' + + # Issue 24108 + from sympy.core.parameters import evaluate + with evaluate(False): + assert str(Mul(Pow(Integer(2), Integer(-1)), Add(Integer(-1), Mul(Integer(-1), Integer(1))))) == "(-1 - 1*1)/2" + + class CustomClass1(Expr): + is_commutative = True + + class CustomClass2(Expr): + is_commutative = True + cc1 = CustomClass1() + cc2 = CustomClass2() + assert str(Rational(2)*cc1) == '2*CustomClass1()' + assert str(cc1*Rational(2)) == '2*CustomClass1()' + assert str(cc1*Float("1.5")) == '1.5*CustomClass1()' + assert str(cc2*Rational(2)) == '2*CustomClass2()' + assert str(cc2*Rational(2)*cc1) == '2*CustomClass1()*CustomClass2()' + assert str(cc1*Rational(2)*cc2) == '2*CustomClass1()*CustomClass2()' + + +def test_NaN(): + assert str(nan) == "nan" + + +def test_NegativeInfinity(): + assert str(-oo) == "-oo" + +def test_Order(): + assert str(O(x)) == "O(x)" + assert str(O(x**2)) == "O(x**2)" + assert str(O(x*y)) == "O(x*y, x, y)" + assert str(O(x, x)) == "O(x)" + assert str(O(x, (x, 0))) == "O(x)" + assert str(O(x, (x, oo))) == "O(x, (x, oo))" + assert str(O(x, x, y)) == "O(x, x, y)" + assert str(O(x, x, y)) == "O(x, x, y)" + assert str(O(x, (x, oo), (y, oo))) == "O(x, (x, oo), (y, oo))" + + +def test_Permutation_Cycle(): + from sympy.combinatorics import Permutation, Cycle + + # general principle: economically, canonically show all moved elements + # and the size of the permutation. + + for p, s in [ + (Cycle(), + '()'), + (Cycle(2), + '(2)'), + (Cycle(2, 1), + '(1 2)'), + (Cycle(1, 2)(5)(6, 7)(10), + '(1 2)(6 7)(10)'), + (Cycle(3, 4)(1, 2)(3, 4), + '(1 2)(4)'), + ]: + assert sstr(p) == s + + for p, s in [ + (Permutation([]), + 'Permutation([])'), + (Permutation([], size=1), + 'Permutation([0])'), + (Permutation([], size=2), + 'Permutation([0, 1])'), + (Permutation([], size=10), + 'Permutation([], size=10)'), + (Permutation([1, 0, 2]), + 'Permutation([1, 0, 2])'), + (Permutation([1, 0, 2, 3, 4, 5]), + 'Permutation([1, 0], size=6)'), + (Permutation([1, 0, 2, 3, 4, 5], size=10), + 'Permutation([1, 0], size=10)'), + ]: + assert sstr(p, perm_cyclic=False) == s + + for p, s in [ + (Permutation([]), + '()'), + (Permutation([], size=1), + '(0)'), + (Permutation([], size=2), + '(1)'), + (Permutation([], size=10), + '(9)'), + (Permutation([1, 0, 2]), + '(2)(0 1)'), + (Permutation([1, 0, 2, 3, 4, 5]), + '(5)(0 1)'), + (Permutation([1, 0, 2, 3, 4, 5], size=10), + '(9)(0 1)'), + (Permutation([0, 1, 3, 2, 4, 5], size=10), + '(9)(2 3)'), + ]: + assert sstr(p) == s + + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert sstr(Permutation([1, 0, 2])) == 'Permutation([1, 0, 2])' + Permutation.print_cyclic = old_print_cyclic + +def test_Pi(): + assert str(pi) == "pi" + + +def test_Poly(): + assert str(Poly(0, x)) == "Poly(0, x, domain='ZZ')" + assert str(Poly(1, x)) == "Poly(1, x, domain='ZZ')" + assert str(Poly(x, x)) == "Poly(x, x, domain='ZZ')" + + assert str(Poly(2*x + 1, x)) == "Poly(2*x + 1, x, domain='ZZ')" + assert str(Poly(2*x - 1, x)) == "Poly(2*x - 1, x, domain='ZZ')" + + assert str(Poly(-1, x)) == "Poly(-1, x, domain='ZZ')" + assert str(Poly(-x, x)) == "Poly(-x, x, domain='ZZ')" + + assert str(Poly(-2*x + 1, x)) == "Poly(-2*x + 1, x, domain='ZZ')" + assert str(Poly(-2*x - 1, x)) == "Poly(-2*x - 1, x, domain='ZZ')" + + assert str(Poly(x - 1, x)) == "Poly(x - 1, x, domain='ZZ')" + assert str(Poly(2*x + x**5, x)) == "Poly(x**5 + 2*x, x, domain='ZZ')" + + assert str(Poly(3**(2*x), 3**x)) == "Poly((3**x)**2, 3**x, domain='ZZ')" + assert str(Poly((x**2)**x)) == "Poly(((x**2)**x), (x**2)**x, domain='ZZ')" + + assert str(Poly((x + y)**3, (x + y), expand=False) + ) == "Poly((x + y)**3, x + y, domain='ZZ')" + assert str(Poly((x - 1)**2, (x - 1), expand=False) + ) == "Poly((x - 1)**2, x - 1, domain='ZZ')" + + assert str( + Poly(x**2 + 1 + y, x)) == "Poly(x**2 + y + 1, x, domain='ZZ[y]')" + assert str( + Poly(x**2 - 1 + y, x)) == "Poly(x**2 + y - 1, x, domain='ZZ[y]')" + + assert str(Poly(x**2 + I*x, x)) == "Poly(x**2 + I*x, x, domain='ZZ_I')" + assert str(Poly(x**2 - I*x, x)) == "Poly(x**2 - I*x, x, domain='ZZ_I')" + + assert str(Poly(-x*y*z + x*y - 1, x, y, z) + ) == "Poly(-x*y*z + x*y - 1, x, y, z, domain='ZZ')" + assert str(Poly(-w*x**21*y**7*z + (1 + w)*z**3 - 2*x*z + 1, x, y, z)) == \ + "Poly(-w*x**21*y**7*z - 2*x*z + (w + 1)*z**3 + 1, x, y, z, domain='ZZ[w]')" + + assert str(Poly(x**2 + 1, x, modulus=2)) == "Poly(x**2 + 1, x, modulus=2)" + assert str(Poly(2*x**2 + 3*x + 4, x, modulus=17)) == "Poly(2*x**2 + 3*x + 4, x, modulus=17)" + + +def test_PolyRing(): + assert str(ring("x", ZZ, lex)[0]) == "Polynomial ring in x over ZZ with lex order" + assert str(ring("x,y", QQ, grlex)[0]) == "Polynomial ring in x, y over QQ with grlex order" + assert str(ring("x,y,z", ZZ["t"], lex)[0]) == "Polynomial ring in x, y, z over ZZ[t] with lex order" + + +def test_FracField(): + assert str(field("x", ZZ, lex)[0]) == "Rational function field in x over ZZ with lex order" + assert str(field("x,y", QQ, grlex)[0]) == "Rational function field in x, y over QQ with grlex order" + assert str(field("x,y,z", ZZ["t"], lex)[0]) == "Rational function field in x, y, z over ZZ[t] with lex order" + + +def test_PolyElement(): + Ruv, u,v = ring("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Ruv) + Rx_zzi, xz = ring("x", ZZ_I) + + assert str(x - x) == "0" + assert str(x - 1) == "x - 1" + assert str(x + 1) == "x + 1" + assert str(x**2) == "x**2" + + assert str((u**2 + 3*u*v + 1)*x**2*y + u + 1) == "(u**2 + 3*u*v + 1)*x**2*y + u + 1" + assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x" + assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1" + assert str((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == "-(u**2 - 3*u*v + 1)*x**2*y - (u + 1)*x - 1" + + assert str(-(v**2 + v + 1)*x + 3*u*v + 1) == "-(v**2 + v + 1)*x + 3*u*v + 1" + assert str(-(v**2 + v + 1)*x - 3*u*v + 1) == "-(v**2 + v + 1)*x - 3*u*v + 1" + + assert str((1+I)*xz + 2) == "(1 + 1*I)*x + (2 + 0*I)" + + +def test_FracElement(): + Fuv, u,v = field("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Fuv) + Rx_zzi, xz = field("x", QQ_I) + i = QQ_I(0, 1) + + assert str(x - x) == "0" + assert str(x - 1) == "x - 1" + assert str(x + 1) == "x + 1" + + assert str(x/3) == "x/3" + assert str(x/z) == "x/z" + assert str(x*y/z) == "x*y/z" + assert str(x/(z*t)) == "x/(z*t)" + assert str(x*y/(z*t)) == "x*y/(z*t)" + + assert str((x - 1)/y) == "(x - 1)/y" + assert str((x + 1)/y) == "(x + 1)/y" + assert str((-x - 1)/y) == "(-x - 1)/y" + assert str((x + 1)/(y*z)) == "(x + 1)/(y*z)" + assert str(-y/(x + 1)) == "-y/(x + 1)" + assert str(y*z/(x + 1)) == "y*z/(x + 1)" + + assert str(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - 1)" + assert str(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - u*v*t - 1)" + + assert str((1+i)/xz) == "(1 + 1*I)/x" + assert str(((1+i)*xz - i)/xz) == "((1 + 1*I)*x + (0 + -1*I))/x" + + +def test_GaussianInteger(): + assert str(ZZ_I(1, 0)) == "1" + assert str(ZZ_I(-1, 0)) == "-1" + assert str(ZZ_I(0, 1)) == "I" + assert str(ZZ_I(0, -1)) == "-I" + assert str(ZZ_I(0, 2)) == "2*I" + assert str(ZZ_I(0, -2)) == "-2*I" + assert str(ZZ_I(1, 1)) == "1 + I" + assert str(ZZ_I(-1, -1)) == "-1 - I" + assert str(ZZ_I(-1, -2)) == "-1 - 2*I" + + +def test_GaussianRational(): + assert str(QQ_I(1, 0)) == "1" + assert str(QQ_I(QQ(2, 3), 0)) == "2/3" + assert str(QQ_I(0, QQ(2, 3))) == "2*I/3" + assert str(QQ_I(QQ(1, 2), QQ(-2, 3))) == "1/2 - 2*I/3" + + +def test_Pow(): + assert str(x**-1) == "1/x" + assert str(x**-2) == "x**(-2)" + assert str(x**2) == "x**2" + assert str((x + y)**-1) == "1/(x + y)" + assert str((x + y)**-2) == "(x + y)**(-2)" + assert str((x + y)**2) == "(x + y)**2" + assert str((x + y)**(1 + x)) == "(x + y)**(x + 1)" + assert str(x**Rational(1, 3)) == "x**(1/3)" + assert str(1/x**Rational(1, 3)) == "x**(-1/3)" + assert str(sqrt(sqrt(x))) == "x**(1/4)" + # not the same as x**-1 + assert str(x**-1.0) == 'x**(-1.0)' + # see issue #2860 + assert str(Pow(S(2), -1.0, evaluate=False)) == '2**(-1.0)' + + +def test_sqrt(): + assert str(sqrt(x)) == "sqrt(x)" + assert str(sqrt(x**2)) == "sqrt(x**2)" + assert str(1/sqrt(x)) == "1/sqrt(x)" + assert str(1/sqrt(x**2)) == "1/sqrt(x**2)" + assert str(y/sqrt(x)) == "y/sqrt(x)" + assert str(x**0.5) == "x**0.5" + assert str(1/x**0.5) == "x**(-0.5)" + + +def test_Rational(): + n1 = Rational(1, 4) + n2 = Rational(1, 3) + n3 = Rational(2, 4) + n4 = Rational(2, -4) + n5 = Rational(0) + n7 = Rational(3) + n8 = Rational(-3) + assert str(n1*n2) == "1/12" + assert str(n1*n2) == "1/12" + assert str(n3) == "1/2" + assert str(n1*n3) == "1/8" + assert str(n1 + n3) == "3/4" + assert str(n1 + n2) == "7/12" + assert str(n1 + n4) == "-1/4" + assert str(n4*n4) == "1/4" + assert str(n4 + n2) == "-1/6" + assert str(n4 + n5) == "-1/2" + assert str(n4*n5) == "0" + assert str(n3 + n4) == "0" + assert str(n1**n7) == "1/64" + assert str(n2**n7) == "1/27" + assert str(n2**n8) == "27" + assert str(n7**n8) == "1/27" + assert str(Rational("-25")) == "-25" + assert str(Rational("1.25")) == "5/4" + assert str(Rational("-2.6e-2")) == "-13/500" + assert str(S("25/7")) == "25/7" + assert str(S("-123/569")) == "-123/569" + assert str(S("0.1[23]", rational=1)) == "61/495" + assert str(S("5.1[666]", rational=1)) == "31/6" + assert str(S("-5.1[666]", rational=1)) == "-31/6" + assert str(S("0.[9]", rational=1)) == "1" + assert str(S("-0.[9]", rational=1)) == "-1" + + assert str(sqrt(Rational(1, 4))) == "1/2" + assert str(sqrt(Rational(1, 36))) == "1/6" + + assert str((123**25) ** Rational(1, 25)) == "123" + assert str((123**25 + 1)**Rational(1, 25)) != "123" + assert str((123**25 - 1)**Rational(1, 25)) != "123" + assert str((123**25 - 1)**Rational(1, 25)) != "122" + + assert str(sqrt(Rational(81, 36))**3) == "27/8" + assert str(1/sqrt(Rational(81, 36))**3) == "8/27" + + assert str(sqrt(-4)) == str(2*I) + assert str(2**Rational(1, 10**10)) == "2**(1/10000000000)" + + assert sstr(Rational(2, 3), sympy_integers=True) == "S(2)/3" + x = Symbol("x") + assert sstr(x**Rational(2, 3), sympy_integers=True) == "x**(S(2)/3)" + assert sstr(Eq(x, Rational(2, 3)), sympy_integers=True) == "Eq(x, S(2)/3)" + assert sstr(Limit(x, x, Rational(7, 2)), sympy_integers=True) == \ + "Limit(x, x, S(7)/2, dir='+')" + + +def test_Float(): + # NOTE dps is the whole number of decimal digits + assert str(Float('1.23', dps=1 + 2)) == '1.23' + assert str(Float('1.23456789', dps=1 + 8)) == '1.23456789' + assert str( + Float('1.234567890123456789', dps=1 + 18)) == '1.234567890123456789' + assert str(pi.evalf(1 + 2)) == '3.14' + assert str(pi.evalf(1 + 14)) == '3.14159265358979' + assert str(pi.evalf(1 + 64)) == ('3.141592653589793238462643383279' + '5028841971693993751058209749445923') + assert str(pi.round(-1)) == '0.0' + assert str((pi**400 - (pi**400).round(1)).n(2)) == '-0.e+88' + assert sstr(Float("100"), full_prec=False, min=-2, max=2) == '1.0e+2' + assert sstr(Float("100"), full_prec=False, min=-2, max=3) == '100.0' + assert sstr(Float("0.1"), full_prec=False, min=-2, max=3) == '0.1' + assert sstr(Float("0.099"), min=-2, max=3) == '9.90000000000000e-2' + + +def test_Relational(): + assert str(Rel(x, y, "<")) == "x < y" + assert str(Rel(x + y, y, "==")) == "Eq(x + y, y)" + assert str(Rel(x, y, "!=")) == "Ne(x, y)" + assert str(Eq(x, 1) | Eq(x, 2)) == "Eq(x, 1) | Eq(x, 2)" + assert str(Ne(x, 1) & Ne(x, 2)) == "Ne(x, 1) & Ne(x, 2)" + + +def test_AppliedBinaryRelation(): + assert str(Q.eq(x, y)) == "Q.eq(x, y)" + assert str(Q.ne(x, y)) == "Q.ne(x, y)" + + +def test_CRootOf(): + assert str(rootof(x**5 + 2*x - 1, 0)) == "CRootOf(x**5 + 2*x - 1, 0)" + + +def test_RootSum(): + f = x**5 + 2*x - 1 + + assert str( + RootSum(f, Lambda(z, z), auto=False)) == "RootSum(x**5 + 2*x - 1)" + assert str(RootSum(f, Lambda( + z, z**2), auto=False)) == "RootSum(x**5 + 2*x - 1, Lambda(z, z**2))" + + +def test_GroebnerBasis(): + assert str(groebner( + [], x, y)) == "GroebnerBasis([], x, y, domain='ZZ', order='lex')" + + F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1] + + assert str(groebner(F, order='grlex')) == \ + "GroebnerBasis([x**2 - x - 3*y + 1, y**2 - 2*x + y - 1], x, y, domain='ZZ', order='grlex')" + assert str(groebner(F, order='lex')) == \ + "GroebnerBasis([2*x - y**2 - y + 1, y**4 + 2*y**3 - 3*y**2 - 16*y + 7], x, y, domain='ZZ', order='lex')" + +def test_set(): + assert sstr(set()) == 'set()' + assert sstr(frozenset()) == 'frozenset()' + + assert sstr({1}) == '{1}' + assert sstr(frozenset([1])) == 'frozenset({1})' + assert sstr({1, 2, 3}) == '{1, 2, 3}' + assert sstr(frozenset([1, 2, 3])) == 'frozenset({1, 2, 3})' + + assert sstr( + {1, x, x**2, x**3, x**4}) == '{1, x, x**2, x**3, x**4}' + assert sstr( + frozenset([1, x, x**2, x**3, x**4])) == 'frozenset({1, x, x**2, x**3, x**4})' + + +def test_SparseMatrix(): + M = SparseMatrix([[x**+1, 1], [y, x + y]]) + assert str(M) == "Matrix([[x, 1], [y, x + y]])" + assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])" + + +def test_Sum(): + assert str(summation(cos(3*z), (z, x, y))) == "Sum(cos(3*z), (z, x, y))" + assert str(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + "Sum(x*y**2, (x, -2, 2), (y, -5, 5))" + + +def test_Symbol(): + assert str(y) == "y" + assert str(x) == "x" + e = x + assert str(e) == "x" + + +def test_tuple(): + assert str((x,)) == sstr((x,)) == "(x,)" + assert str((x + y, 1 + x)) == sstr((x + y, 1 + x)) == "(x + y, x + 1)" + assert str((x + y, ( + 1 + x, x**2))) == sstr((x + y, (1 + x, x**2))) == "(x + y, (x + 1, x**2))" + + +def test_Series_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Series(tf1, tf2)) == \ + "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))" + assert str(Series(tf1, tf2, tf3)) == \ + "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))" + assert str(Series(-tf2, tf1)) == \ + "Series(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))" + + +def test_MIMOSeries_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + assert str(MIMOSeries(tfm_1, tfm_2)) == \ + "MIMOSeries(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\ + "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\ + "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\ + "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))" + + +def test_TransferFunction_str(): + tf1 = TransferFunction(x - 1, x + 1, x) + assert str(tf1) == "TransferFunction(x - 1, x + 1, x)" + tf2 = TransferFunction(x + 1, 2 - y, x) + assert str(tf2) == "TransferFunction(x + 1, 2 - y, x)" + tf3 = TransferFunction(y, y**2 + 2*y + 3, y) + assert str(tf3) == "TransferFunction(y, y**2 + 2*y + 3, y)" + + +def test_Parallel_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Parallel(tf1, tf2)) == \ + "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))" + assert str(Parallel(tf1, tf2, tf3)) == \ + "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))" + assert str(Parallel(-tf2, tf1)) == \ + "Parallel(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))" + + +def test_MIMOParallel_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + assert str(MIMOParallel(tfm_1, tfm_2)) == \ + "MIMOParallel(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\ + "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\ + "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\ + "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))" + + +def test_Feedback_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Feedback(tf1*tf2, tf3)) == \ + "Feedback(Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), " \ + "TransferFunction(t*x**2 - t**w*x + w, t - y, y), -1)" + assert str(Feedback(tf1, TransferFunction(1, 1, y), 1)) == \ + "Feedback(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(1, 1, y), 1)" + + +def test_MIMOFeedback_str(): + tf1 = TransferFunction(x**2 - y**3, y - z, x) + tf2 = TransferFunction(y - x, z + y, x) + tfm_1 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + tfm_2 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + assert (str(MIMOFeedback(tfm_1, tfm_2)) \ + == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x))," \ + " (TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \ + "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), " \ + "TransferFunction(-x + y, y + z, x)), (TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), -1)") + assert (str(MIMOFeedback(tfm_1, tfm_2, 1)) \ + == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)), " \ + "(TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \ + "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)), "\ + "(TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), 1)") + + +def test_TransferFunctionMatrix_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(TransferFunctionMatrix([[tf1], [tf2]])) == \ + "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y),), (TransferFunction(x - y, x + y, y),)))" + assert str(TransferFunctionMatrix([[tf1, tf2], [tf3, tf2]])) == \ + "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), (TransferFunction(t*x**2 - t**w*x + w, t - y, y), TransferFunction(x - y, x + y, y))))" + + +def test_Quaternion_str_printer(): + q = Quaternion(x, y, z, t) + assert str(q) == "x + y*i + z*j + t*k" + q = Quaternion(x,y,z,x*t) + assert str(q) == "x + y*i + z*j + t*x*k" + q = Quaternion(x,y,z,x+t) + assert str(q) == "x + y*i + z*j + (t + x)*k" + + +def test_Quantity_str(): + assert sstr(second, abbrev=True) == "s" + assert sstr(joule, abbrev=True) == "J" + assert str(second) == "second" + assert str(joule) == "joule" + + +def test_wild_str(): + # Check expressions containing Wild not causing infinite recursion + w = Wild('x') + assert str(w + 1) == 'x_ + 1' + assert str(exp(2**w) + 5) == 'exp(2**x_) + 5' + assert str(3*w + 1) == '3*x_ + 1' + assert str(1/w + 1) == '1 + 1/x_' + assert str(w**2 + 1) == 'x_**2 + 1' + assert str(1/(1 - w)) == '1/(1 - x_)' + + +def test_wild_matchpy(): + from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar + + matchpy = import_module("matchpy") + + if matchpy is None: + return + + wd = WildDot('w_') + wp = WildPlus('w__') + ws = WildStar('w___') + + assert str(wd) == 'w_' + assert str(wp) == 'w__' + assert str(ws) == 'w___' + + assert str(wp/ws + 2**wd) == '2**w_ + w__/w___' + assert str(sin(wd)*cos(wp)*sqrt(ws)) == 'sqrt(w___)*sin(w_)*cos(w__)' + + +def test_zeta(): + assert str(zeta(3)) == "zeta(3)" + + +def test_issue_3101(): + e = x - y + a = str(e) + b = str(e) + assert a == b + + +def test_issue_3103(): + e = -2*sqrt(x) - y/sqrt(x)/2 + assert str(e) not in ["(-2)*x**1/2(-1/2)*x**(-1/2)*y", + "-2*x**1/2(-1/2)*x**(-1/2)*y", "-2*x**1/2-1/2*x**-1/2*w"] + assert str(e) == "-2*sqrt(x) - y/(2*sqrt(x))" + + +def test_issue_4021(): + e = Integral(x, x) + 1 + assert str(e) == 'Integral(x, x) + 1' + + +def test_sstrrepr(): + assert sstr('abc') == 'abc' + assert sstrrepr('abc') == "'abc'" + + e = ['a', 'b', 'c', x] + assert sstr(e) == "[a, b, c, x]" + assert sstrrepr(e) == "['a', 'b', 'c', x]" + + +def test_infinity(): + assert sstr(oo*I) == "oo*I" + + +def test_full_prec(): + assert sstr(S("0.3"), full_prec=True) == "0.300000000000000" + assert sstr(S("0.3"), full_prec="auto") == "0.300000000000000" + assert sstr(S("0.3"), full_prec=False) == "0.3" + assert sstr(S("0.3")*x, full_prec=True) in [ + "0.300000000000000*x", + "x*0.300000000000000" + ] + assert sstr(S("0.3")*x, full_prec="auto") in [ + "0.3*x", + "x*0.3" + ] + assert sstr(S("0.3")*x, full_prec=False) in [ + "0.3*x", + "x*0.3" + ] + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + assert sstr(A*B*C**-1) == "A*B*C**(-1)" + assert sstr(C**-1*A*B) == "C**(-1)*A*B" + assert sstr(A*C**-1*B) == "A*C**(-1)*B" + assert sstr(sqrt(A)) == "sqrt(A)" + assert sstr(1/sqrt(A)) == "A**(-1/2)" + + +def test_empty_printer(): + str_printer = StrPrinter() + assert str_printer.emptyPrinter("foo") == "foo" + assert str_printer.emptyPrinter(x*y) == "x*y" + assert str_printer.emptyPrinter(32) == "32" + +def test_decimal_printer(): + dec_printer = StrPrinter(settings={"dps":3}) + f = Function('f') + assert dec_printer.doprint(f(1.329294)) == "f(1.33)" + + +def test_settings(): + raises(TypeError, lambda: sstr(S(4), method="garbage")) + + +def test_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + X = Normal('x1', 0, 1) + assert str(where(X > 0)) == "Domain: (0 < x1) & (x1 < oo)" + + D = Die('d1', 6) + assert str(where(D > 4)) == "Domain: Eq(d1, 5) | Eq(d1, 6)" + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert str(pspace(Tuple(A, B)).domain) == "Domain: (0 <= a) & (0 <= b) & (a < oo) & (b < oo)" + + +def test_FiniteSet(): + assert str(FiniteSet(*range(1, 51))) == ( + '{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,' + ' 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,' + ' 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50}' + ) + assert str(FiniteSet(*range(1, 6))) == '{1, 2, 3, 4, 5}' + assert str(FiniteSet(*[x*y, x**2])) == '{x**2, x*y}' + assert str(FiniteSet(FiniteSet(FiniteSet(x, y), 5), FiniteSet(x,y), 5) + ) == 'FiniteSet(5, FiniteSet(5, {x, y}), {x, y})' + + +def test_Partition(): + assert str(Partition(FiniteSet(x, y), {z})) == 'Partition({z}, {x, y})' + +def test_UniversalSet(): + assert str(S.UniversalSet) == 'UniversalSet' + + +def test_PrettyPoly(): + F = QQ.frac_field(x, y) + R = QQ[x, y] + assert sstr(F.convert(x/(x + y))) == sstr(x/(x + y)) + assert sstr(R.convert(x + y)) == sstr(x + y) + + +def test_categories(): + from sympy.categories import (Object, NamedMorphism, + IdentityMorphism, Category) + + A = Object("A") + B = Object("B") + + f = NamedMorphism(A, B, "f") + id_A = IdentityMorphism(A) + + K = Category("K") + + assert str(A) == 'Object("A")' + assert str(f) == 'NamedMorphism(Object("A"), Object("B"), "f")' + assert str(id_A) == 'IdentityMorphism(Object("A"))' + + assert str(K) == 'Category("K")' + + +def test_Tr(): + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert str(t) == 'Tr(A*B)' + + +def test_issue_6387(): + assert str(factor(-3.0*z + 3)) == '-3.0*(1.0*z - 1.0)' + + +def test_MatMul_MatAdd(): + X, Y = MatrixSymbol("X", 2, 2), MatrixSymbol("Y", 2, 2) + assert str(2*(X + Y)) == "2*X + 2*Y" + + assert str(I*X) == "I*X" + assert str(-I*X) == "-I*X" + assert str((1 + I)*X) == '(1 + I)*X' + assert str(-(1 + I)*X) == '(-1 - I)*X' + assert str(MatAdd(MatAdd(X, Y), MatAdd(X, Y))) == '(X + Y) + (X + Y)' + + +def test_MatrixSlice(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + assert str(MatrixSlice(X, (None, None, None), (None, None, None))) == 'X[:, :]' + assert str(X[x:x + 1, y:y + 1]) == 'X[x:x + 1, y:y + 1]' + assert str(X[x:x + 1:2, y:y + 1:2]) == 'X[x:x + 1:2, y:y + 1:2]' + assert str(X[:x, y:]) == 'X[:x, y:]' + assert str(X[:x, y:]) == 'X[:x, y:]' + assert str(X[x:, :y]) == 'X[x:, :y]' + assert str(X[x:y, z:w]) == 'X[x:y, z:w]' + assert str(X[x:y:t, w:t:x]) == 'X[x:y:t, w:t:x]' + assert str(X[x::y, t::w]) == 'X[x::y, t::w]' + assert str(X[:x:y, :t:w]) == 'X[:x:y, :t:w]' + assert str(X[::x, ::y]) == 'X[::x, ::y]' + assert str(MatrixSlice(X, (0, None, None), (0, None, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (None, n, None), (None, n, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (0, n, None), (0, n, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (0, n, 2), (0, n, 2))) == 'X[::2, ::2]' + assert str(X[1:2:3, 4:5:6]) == 'X[1:2:3, 4:5:6]' + assert str(X[1:3:5, 4:6:8]) == 'X[1:3:5, 4:6:8]' + assert str(X[1:10:2]) == 'X[1:10:2, :]' + assert str(Y[:5, 1:9:2]) == 'Y[:5, 1:9:2]' + assert str(Y[:5, 1:10:2]) == 'Y[:5, 1::2]' + assert str(Y[5, :5:2]) == 'Y[5:6, :5:2]' + assert str(X[0:1, 0:1]) == 'X[:1, :1]' + assert str(X[0:1:2, 0:1:2]) == 'X[:1:2, :1:2]' + assert str((Y + Z)[2:, 2:]) == '(Y + Z)[2:, 2:]' + +def test_true_false(): + assert str(true) == repr(true) == sstr(true) == "True" + assert str(false) == repr(false) == sstr(false) == "False" + +def test_Equivalent(): + assert str(Equivalent(y, x)) == "Equivalent(x, y)" + +def test_Xor(): + assert str(Xor(y, x, evaluate=False)) == "x ^ y" + +def test_Complement(): + assert str(Complement(S.Reals, S.Naturals)) == 'Complement(Reals, Naturals)' + +def test_SymmetricDifference(): + assert str(SymmetricDifference(Interval(2, 3), Interval(3, 4),evaluate=False)) == \ + 'SymmetricDifference(Interval(2, 3), Interval(3, 4))' + + +def test_UnevaluatedExpr(): + a, b = symbols("a b") + expr1 = 2*UnevaluatedExpr(a+b) + assert str(expr1) == "2*(a + b)" + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(str(A[0, 0]) == "A[0, 0]") + assert(str(3 * A[0, 0]) == "3*A[0, 0]") + + F = C[0, 0].subs(C, A - B) + assert str(F) == "(A - B)[0, 0]" + + +def test_MatrixSymbol_printing(): + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + + assert str(A - A*B - B) == "A - A*B - B" + assert str(A*B - (A+B)) == "-A + A*B - B" + assert str(A**(-1)) == "A**(-1)" + assert str(A**3) == "A**3" + + +def test_MatrixExpressions(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + + assert str(X) == "X" + + # Apply function elementwise (`ElementwiseApplyFunc`): + + expr = (X.T*X).applyfunc(sin) + assert str(expr) == 'Lambda(_d, sin(_d)).(X.T*X)' + + lamda = Lambda(x, 1/x) + expr = (n*X).applyfunc(lamda) + assert str(expr) == 'Lambda(x, 1/x).(n*X)' + + +def test_Subs_printing(): + assert str(Subs(x, (x,), (1,))) == 'Subs(x, x, 1)' + assert str(Subs(x + y, (x, y), (1, 2))) == 'Subs(x + y, (x, y), (1, 2))' + + +def test_issue_15716(): + e = Integral(factorial(x), (x, -oo, oo)) + assert e.as_terms() == ([(e, ((1.0, 0.0), (1,), ()))], [e]) + + +def test_str_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert str(Identity(4)) == 'I' + assert str(ZeroMatrix(2, 2)) == '0' + assert str(OneMatrix(2, 2)) == '1' + + +def test_issue_14567(): + assert factorial(Sum(-1, (x, 0, 0))) + y # doesn't raise an error + + +def test_issue_21823(): + assert str(Partition([1, 2])) == 'Partition({1, 2})' + assert str(Partition({1, 2})) == 'Partition({1, 2})' + + +def test_issue_22689(): + assert str(Mul(Pow(x,-2, evaluate=False), Pow(3,-1,evaluate=False), evaluate=False)) == "1/(x**2*3)" + + +def test_issue_21119_21460(): + ss = lambda x: str(S(x, evaluate=False)) + assert ss('4/2') == '4/2' + assert ss('4/-2') == '4/(-2)' + assert ss('-4/2') == '-4/2' + assert ss('-4/-2') == '-4/(-2)' + assert ss('-2*3/-1') == '-2*3/(-1)' + assert ss('-2*3/-1/2') == '-2*3/(-1*2)' + assert ss('4/2/1') == '4/(2*1)' + assert ss('-2/-1/2') == '-2/(-1*2)' + assert ss('2*3*4**(-2*3)') == '2*3/4**(2*3)' + assert ss('2*3*1*4**(-2*3)') == '2*3*1/4**(2*3)' + + +def test_Str(): + from sympy.core.symbol import Str + assert str(Str('x')) == 'x' + assert sstrrepr(Str('x')) == "Str('x')" + + +def test_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert str(m) == "M" + p = Patch('P', m) + assert str(p) == "P" + rect = CoordSystem('rect', p, [x, y]) + assert str(rect) == "rect" + b = BaseScalarField(rect, 0) + assert str(b) == "x" + +def test_NDimArray(): + assert sstr(NDimArray(1.0), full_prec=True) == '1.00000000000000' + assert sstr(NDimArray(1.0), full_prec=False) == '1.0' + assert sstr(NDimArray([1.0, 2.0]), full_prec=True) == '[1.00000000000000, 2.00000000000000]' + assert sstr(NDimArray([1.0, 2.0]), full_prec=False) == '[1.0, 2.0]' + assert sstr(NDimArray([], (0,))) == 'ImmutableDenseNDimArray([], (0,))' + assert sstr(NDimArray([], (0, 0))) == 'ImmutableDenseNDimArray([], (0, 0))' + assert sstr(NDimArray([], (0, 1))) == 'ImmutableDenseNDimArray([], (0, 1))' + assert sstr(NDimArray([], (1, 0))) == 'ImmutableDenseNDimArray([], (1, 0))' + +def test_Predicate(): + assert sstr(Q.even) == 'Q.even' + +def test_AppliedPredicate(): + assert sstr(Q.even(x)) == 'Q.even(x)' + +def test_printing_str_array_expressions(): + assert sstr(ArraySymbol("A", (2, 3, 4))) == "A" + assert sstr(ArrayElement("A", (2, 1/(1-x), 0))) == "A[2, 1/(1 - x), 0]" + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + assert sstr(ArrayElement(M*N, [x, 0])) == "(M*N)[x, 0]" + +def test_printing_stats(): + # issue 24132 + x = RandomSymbol("x") + y = RandomSymbol("y") + z1 = Probability(x > 0)*Identity(2) + z2 = Expectation(x)*Identity(2) + z3 = Variance(x)*Identity(2) + z4 = Covariance(x, y) * Identity(2) + + assert str(z1) == "Probability(x > 0)*I" + assert str(z2) == "Expectation(x)*I" + assert str(z3) == "Variance(x)*I" + assert str(z4) == "Covariance(x, y)*I" + assert z1.is_commutative == False + assert z2.is_commutative == False + assert z3.is_commutative == False + assert z4.is_commutative == False + assert z2._eval_is_commutative() == False + assert z3._eval_is_commutative() == False + assert z4._eval_is_commutative() == False diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tableform.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tableform.py new file mode 100644 index 0000000000000000000000000000000000000000..05802dd104a12f2f53d137167ecf31d201ff8dfc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tableform.py @@ -0,0 +1,182 @@ +from sympy.core.singleton import S +from sympy.printing.tableform import TableForm +from sympy.printing.latex import latex +from sympy.abc import x +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.testing.pytest import raises + +from textwrap import dedent + + +def test_TableForm(): + s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]], + headings="automatic")) + assert s == ( + ' | 1 2\n' + '-------\n' + '1 | a b\n' + '2 | c d\n' + '3 | e ' + ) + s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]], + headings="automatic", wipe_zeros=False)) + assert s == dedent('''\ + | 1 2 + ------- + 1 | a b + 2 | c d + 3 | e 0''') + s = str(TableForm([[x**2, "b"], ["c", x**2], ["e", "f"]], + headings=("automatic", None))) + assert s == ( + '1 | x**2 b \n' + '2 | c x**2\n' + '3 | e f ' + ) + s = str(TableForm([["a", "b"], ["c", "d"], ["e", "f"]], + headings=(None, "automatic"))) + assert s == dedent('''\ + 1 2 + --- + a b + c d + e f''') + s = str(TableForm([[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]])) + assert s == ( + ' | y1 y2\n' + '---------------\n' + 'Group A | 5 7 \n' + 'Group B | 4 2 \n' + 'Group C | 10 3 ' + ) + raises( + ValueError, + lambda: + TableForm( + [[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]], + alignments="middle") + ) + s = str(TableForm([[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]], + alignments="right")) + assert s == dedent('''\ + | y1 y2 + --------------- + Group A | 5 7 + Group B | 4 2 + Group C | 10 3''') + + # other alignment permutations + d = [[1, 100], [100, 1]] + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='l') + assert str(s) == ( + 'xxx | 1 100\n' + ' x | 100 1 ' + ) + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='lr') + assert str(s) == dedent('''\ + xxx | 1 100 + x | 100 1''') + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='clr') + assert str(s) == dedent('''\ + xxx | 1 100 + x | 100 1''') + + s = TableForm(d, headings=(('xxx', 'x'), None)) + assert str(s) == ( + 'xxx | 1 100\n' + ' x | 100 1 ' + ) + + raises(ValueError, lambda: TableForm(d, alignments='clr')) + + #pad + s = str(TableForm([[None, "-", 2], [1]], pad='?')) + assert s == dedent('''\ + ? - 2 + 1 ? ?''') + + +def test_TableForm_latex(): + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"), alignments='l')) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'*3)) + assert s == ( + '\\begin{tabular}{l l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & $a$ & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + formats=['(%s)', None], headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & (a) & $x^{3}$ \\\\\n' + '2 & (c) & $\\frac{1}{4}$ \\\\\n' + '3 & (sqrt(x)) & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + + def neg_in_paren(x, i, j): + if i % 2: + return ('(%s)' if x < 0 else '%s') % x + else: + pass # use default print + s = latex(TableForm([[-1, 2], [-3, 4]], + formats=[neg_in_paren]*2, headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & -1 & 2 \\\\\n' + '2 & (-3) & 4 \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]])) + assert s == ( + '\\begin{tabular}{l l}\n' + '$a$ & $x^{3}$ \\\\\n' + '$c$ & $\\frac{1}{4}$ \\\\\n' + '$\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tensorflow.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c92cd17b13e1148ebf83f13f66854b983491fe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tensorflow.py @@ -0,0 +1,493 @@ +import random +from sympy.core.function import Derivative +from sympy.core.symbol import symbols +from sympy import Piecewise +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt +from sympy.external import import_module +from sympy.functions import \ + Abs, ceiling, exp, floor, sign, sin, asin, sqrt, cos, \ + acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \ + re, im, arg, erf, loggamma, log +from sympy.codegen.cfunctions import isnan, isinf +from sympy.matrices import Matrix, MatrixBase, eye, randMatrix +from sympy.matrices.expressions import \ + Determinant, HadamardProduct, Inverse, MatrixSymbol, Trace +from sympy.printing.tensorflow import tensorflow_code +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import skip +from sympy.testing.pytest import XFAIL + + +tf = tensorflow = import_module("tensorflow") + +if tensorflow: + # Hide Tensorflow warnings + import os + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + + +M = MatrixSymbol("M", 3, 3) +N = MatrixSymbol("N", 3, 3) +P = MatrixSymbol("P", 3, 3) +Q = MatrixSymbol("Q", 3, 3) + +x, y, z, t = symbols("x y z t") + +if tf is not None: + llo = [list(range(i, i+3)) for i in range(0, 9, 3)] + m3x3 = tf.constant(llo) + m3x3sympy = Matrix(llo) + + +def _compare_tensorflow_matrix(variables, expr, use_float=False): + f = lambdify(variables, expr, 'tensorflow') + if not use_float: + random_matrices = [randMatrix(v.rows, v.cols) for v in variables] + else: + random_matrices = [randMatrix(v.rows, v.cols)/100. for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + if e.is_Matrix: + if not isinstance(e, MatrixBase): + e = e.as_explicit() + e = e.tolist() + + if not use_float: + assert (r == e).all() + else: + r = [i for row in r for i in row] + e = [i for row in e for i in row] + assert all( + abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e)) + + +# Creating a custom inverse test. +# See https://github.com/sympy/sympy/issues/18469 +def _compare_tensorflow_matrix_inverse(variables, expr, use_float=False): + f = lambdify(variables, expr, 'tensorflow') + if not use_float: + random_matrices = [eye(v.rows, v.cols)*4 for v in variables] + else: + random_matrices = [eye(v.rows, v.cols)*3.14 for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + if e.is_Matrix: + if not isinstance(e, MatrixBase): + e = e.as_explicit() + e = e.tolist() + + if not use_float: + assert (r == e).all() + else: + r = [i for row in r for i in row] + e = [i for row in e for i in row] + assert all( + abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e)) + + +def _compare_tensorflow_matrix_scalar(variables, expr): + f = lambdify(variables, expr, 'tensorflow') + random_matrices = [ + randMatrix(v.rows, v.cols).evalf() / 100 for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + assert abs(r-e) < 10**-6 + + +def _compare_tensorflow_scalar( + variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'tensorflow') + rvs = [rng() for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + tf_rvs = [eval(tensorflow_code(i)) for i in rvs] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*tf_rvs)) + + e = expr.subs(dict(zip(variables, rvs))).evalf().doit() + assert abs(r-e) < 10**-6 + + +def _compare_tensorflow_relational( + variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'tensorflow') + rvs = [rng() for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + tf_rvs = [eval(tensorflow_code(i)) for i in rvs] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*tf_rvs)) + + e = expr.subs(dict(zip(variables, rvs))).doit() + assert r == e + + +def test_tensorflow_printing(): + assert tensorflow_code(eye(3)) == \ + "tensorflow.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])" + + expr = Matrix([[x, sin(y)], [exp(z), -t]]) + assert tensorflow_code(expr) == \ + "tensorflow.Variable(" \ + "[[x, tensorflow.math.sin(y)]," \ + " [tensorflow.math.exp(z), -t]])" + + +# This (random) test is XFAIL because it fails occasionally +# See https://github.com/sympy/sympy/issues/18469 +@XFAIL +def test_tensorflow_math(): + if not tf: + skip("TensorFlow not installed") + + expr = Abs(x) + assert tensorflow_code(expr) == "tensorflow.math.abs(x)" + _compare_tensorflow_scalar((x,), expr) + + expr = sign(x) + assert tensorflow_code(expr) == "tensorflow.math.sign(x)" + _compare_tensorflow_scalar((x,), expr) + + expr = ceiling(x) + assert tensorflow_code(expr) == "tensorflow.math.ceil(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = floor(x) + assert tensorflow_code(expr) == "tensorflow.math.floor(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = exp(x) + assert tensorflow_code(expr) == "tensorflow.math.exp(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = sqrt(x) + assert tensorflow_code(expr) == "tensorflow.math.sqrt(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = x ** 4 + assert tensorflow_code(expr) == "tensorflow.math.pow(x, 4)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = cos(x) + assert tensorflow_code(expr) == "tensorflow.math.cos(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = acos(x) + assert tensorflow_code(expr) == "tensorflow.math.acos(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(0, 0.95)) + + expr = sin(x) + assert tensorflow_code(expr) == "tensorflow.math.sin(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = asin(x) + assert tensorflow_code(expr) == "tensorflow.math.asin(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = tan(x) + assert tensorflow_code(expr) == "tensorflow.math.tan(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = atan(x) + assert tensorflow_code(expr) == "tensorflow.math.atan(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = atan2(y, x) + assert tensorflow_code(expr) == "tensorflow.math.atan2(y, x)" + _compare_tensorflow_scalar((y, x), expr, rng=lambda: random.random()) + + expr = cosh(x) + assert tensorflow_code(expr) == "tensorflow.math.cosh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = acosh(x) + assert tensorflow_code(expr) == "tensorflow.math.acosh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = sinh(x) + assert tensorflow_code(expr) == "tensorflow.math.sinh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = asinh(x) + assert tensorflow_code(expr) == "tensorflow.math.asinh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = tanh(x) + assert tensorflow_code(expr) == "tensorflow.math.tanh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = atanh(x) + assert tensorflow_code(expr) == "tensorflow.math.atanh(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.uniform(-.5, .5)) + + expr = erf(x) + assert tensorflow_code(expr) == "tensorflow.math.erf(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.random()) + + expr = loggamma(x) + assert tensorflow_code(expr) == "tensorflow.math.lgamma(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.random()) + + +def test_tensorflow_complexes(): + assert tensorflow_code(re(x)) == "tensorflow.math.real(x)" + assert tensorflow_code(im(x)) == "tensorflow.math.imag(x)" + assert tensorflow_code(arg(x)) == "tensorflow.math.angle(x)" + + +def test_tensorflow_relational(): + if not tf: + skip("TensorFlow not installed") + + expr = Eq(x, y) + assert tensorflow_code(expr) == "tensorflow.math.equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Ne(x, y) + assert tensorflow_code(expr) == "tensorflow.math.not_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Ge(x, y) + assert tensorflow_code(expr) == "tensorflow.math.greater_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Gt(x, y) + assert tensorflow_code(expr) == "tensorflow.math.greater(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Le(x, y) + assert tensorflow_code(expr) == "tensorflow.math.less_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Lt(x, y) + assert tensorflow_code(expr) == "tensorflow.math.less(x, y)" + _compare_tensorflow_relational((x, y), expr) + + +# This (random) test is XFAIL because it fails occasionally +# See https://github.com/sympy/sympy/issues/18469 +@XFAIL +def test_tensorflow_matrices(): + if not tf: + skip("TensorFlow not installed") + + expr = M + assert tensorflow_code(expr) == "M" + _compare_tensorflow_matrix((M,), expr) + + expr = M + N + assert tensorflow_code(expr) == "tensorflow.math.add(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = M * N + assert tensorflow_code(expr) == "tensorflow.linalg.matmul(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = HadamardProduct(M, N) + assert tensorflow_code(expr) == "tensorflow.math.multiply(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = M*N*P*Q + assert tensorflow_code(expr) == \ + "tensorflow.linalg.matmul(" \ + "tensorflow.linalg.matmul(" \ + "tensorflow.linalg.matmul(M, N), P), Q)" + _compare_tensorflow_matrix((M, N, P, Q), expr) + + expr = M**3 + assert tensorflow_code(expr) == \ + "tensorflow.linalg.matmul(tensorflow.linalg.matmul(M, M), M)" + _compare_tensorflow_matrix((M,), expr) + + expr = Trace(M) + assert tensorflow_code(expr) == "tensorflow.linalg.trace(M)" + _compare_tensorflow_matrix((M,), expr) + + expr = Determinant(M) + assert tensorflow_code(expr) == "tensorflow.linalg.det(M)" + _compare_tensorflow_matrix_scalar((M,), expr) + + expr = Inverse(M) + assert tensorflow_code(expr) == "tensorflow.linalg.inv(M)" + _compare_tensorflow_matrix_inverse((M,), expr, use_float=True) + + expr = M.T + assert tensorflow_code(expr, tensorflow_version='1.14') == \ + "tensorflow.linalg.matrix_transpose(M)" + assert tensorflow_code(expr, tensorflow_version='1.13') == \ + "tensorflow.matrix_transpose(M)" + + _compare_tensorflow_matrix((M,), expr) + + +def test_codegen_einsum(): + if not tf: + skip("TensorFlow not installed") + + graph = tf.Graph() + with graph.as_default(): + session = tf.compat.v1.Session(graph=graph) + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'tensorflow') + + ma = tf.constant([[1, 2], [3, 4]]) + mb = tf.constant([[1,-2], [-1, 3]]) + y = session.run(f(ma, mb)) + c = session.run(tf.matmul(ma, mb)) + assert (y == c).all() + + +def test_codegen_extra(): + if not tf: + skip("TensorFlow not installed") + + graph = tf.Graph() + with graph.as_default(): + session = tf.compat.v1.Session() + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = tf.constant([[1, 2], [3, 4]]) + mb = tf.constant([[1,-2], [-1, 3]]) + mc = tf.constant([[2, 0], [1, 2]]) + md = tf.constant([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + assert tensorflow_code(cg) == \ + 'tensorflow.linalg.einsum("ab,cd", M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.einsum("ij,kl", ma, mb)) + assert (y == c).all() + + cg = ArrayAdd(M, N) + assert tensorflow_code(cg) == 'tensorflow.math.add(M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(ma + mb) + assert (y == c).all() + + cg = ArrayAdd(M, N, P) + assert tensorflow_code(cg) == \ + 'tensorflow.math.add(tensorflow.math.add(M, N), P)' + f = lambdify((M, N, P), cg, 'tensorflow') + y = session.run(f(ma, mb, mc)) + c = session.run(ma + mb + mc) + assert (y == c).all() + + cg = ArrayAdd(M, N, P, Q) + assert tensorflow_code(cg) == \ + 'tensorflow.math.add(' \ + 'tensorflow.math.add(tensorflow.math.add(M, N), P), Q)' + f = lambdify((M, N, P, Q), cg, 'tensorflow') + y = session.run(f(ma, mb, mc, md)) + c = session.run(ma + mb + mc + md) + assert (y == c).all() + + cg = PermuteDims(M, [1, 0]) + assert tensorflow_code(cg) == 'tensorflow.transpose(M, [1, 0])' + f = lambdify((M,), cg, 'tensorflow') + y = session.run(f(ma)) + c = session.run(tf.transpose(ma)) + assert (y == c).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + assert tensorflow_code(cg) == \ + 'tensorflow.transpose(' \ + 'tensorflow.linalg.einsum("ab,cd", M, N), [1, 2, 3, 0])' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.transpose(tf.einsum("ab,cd", ma, mb), [1, 2, 3, 0])) + assert (y == c).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + assert tensorflow_code(cg) == \ + 'tensorflow.linalg.einsum("ab,bc->acb", M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.einsum("ab,bc->acb", ma, mb)) + assert (y == c).all() + + +def test_MatrixElement_printing(): + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert tensorflow_code(A[0, 0]) == "A[0, 0]" + assert tensorflow_code(3 * A[0, 0]) == "3*A[0, 0]" + + F = C[0, 0].subs(C, A - B) + assert tensorflow_code(F) == "(tensorflow.math.add((-1)*B, A))[0, 0]" + + +def test_tensorflow_Derivative(): + expr = Derivative(sin(x), x) + assert tensorflow_code(expr) == \ + "tensorflow.gradients(tensorflow.math.sin(x), x)[0]" + +def test_tensorflow_isnan_isinf(): + if not tf: + skip("TensorFlow not installed") + + # Test for isnan + x = symbols("x") + # Return 0 if x is of nan value, and 1 otherwise + expression = Piecewise((0.0, isnan(x)), (1.0, True)) + printed_code = tensorflow_code(expression) + expected_printed_code = "tensorflow.where(tensorflow.math.is_nan(x), 0.0, 1.0)" + assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}" + for _input, _expected in [(float('nan'), 0.0), (float('inf'), 1.0), (float('-inf'), 1.0), (1.0, 1.0)]: + _output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input])) + assert (_output == _expected).numpy().all() + + # Test for isinf + x = symbols("x") + # Return 0 if x is of nan value, and 1 otherwise + expression = Piecewise((0.0, isinf(x)), (1.0, True)) + printed_code = tensorflow_code(expression) + expected_printed_code = "tensorflow.where(tensorflow.math.is_inf(x), 0.0, 1.0)" + assert tensorflow_code(expression) == expected_printed_code, f"Incorrect printed result {printed_code}, expected {expected_printed_code}" + for _input, _expected in [(float('inf'), 0.0), (float('-inf'), 0.0), (float('nan'), 1.0), (1.0, 1.0)]: + _output = lambdify((x), expression, modules="tensorflow")(x=tf.constant([_input])) + assert (_output == _expected).numpy().all() diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_theanocode.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_theanocode.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff40f78cb4de16149cb5e780756b7e32b574b71 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_theanocode.py @@ -0,0 +1,639 @@ +""" +Important note on tests in this module - the Theano printing functions use a +global cache by default, which means that tests using it will modify global +state and thus not be independent from each other. Instead of using the "cache" +keyword argument each time, this module uses the theano_code_ and +theano_function_ functions defined below which default to using a new, empty +cache instead. +""" + +import logging + +from sympy.external import import_module +from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy + +theanologger = logging.getLogger('theano.configdefaults') +theanologger.setLevel(logging.CRITICAL) +theano = import_module('theano') +theanologger.setLevel(logging.WARNING) + + +if theano: + import numpy as np + ts = theano.scalar + tt = theano.tensor + xt, yt, zt = [tt.scalar(name, 'floatX') for name in 'xyz'] + Xt, Yt, Zt = [tt.tensor('floatX', (False, False), name=n) for n in 'XYZ'] +else: + #bin/test will not execute any tests now + disabled = True + +import sympy as sy +from sympy.core.singleton import S +from sympy.abc import x, y, z, t +from sympy.printing.theanocode import (theano_code, dim_handling, + theano_function) + + +# Default set of matrix symbols for testing - make square so we can both +# multiply and perform elementwise operations between them. +X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ'] + +# For testing AppliedUndef +f_t = sy.Function('f')(t) + + +def theano_code_(expr, **kwargs): + """ Wrapper for theano_code that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return theano_code(expr, **kwargs) + +def theano_function_(inputs, outputs, **kwargs): + """ Wrapper for theano_function that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return theano_function(inputs, outputs, **kwargs) + + +def fgraph_of(*exprs): + """ Transform SymPy expressions into Theano Computation. + + Parameters + ========== + exprs + SymPy expressions + + Returns + ======= + theano.gof.FunctionGraph + """ + outs = list(map(theano_code_, exprs)) + ins = theano.gof.graph.inputs(outs) + ins, outs = theano.gof.graph.clone(ins, outs) + return theano.gof.FunctionGraph(ins, outs) + + +def theano_simplify(fgraph): + """ Simplify a Theano Computation. + + Parameters + ========== + fgraph : theano.gof.FunctionGraph + + Returns + ======= + theano.gof.FunctionGraph + """ + mode = theano.compile.get_default_mode().excluding("fusion") + fgraph = fgraph.clone() + mode.optimizer.optimize(fgraph) + return fgraph + + +def theq(a, b): + """ Test two Theano objects for equality. + + Also accepts numeric types and lists/tuples of supported types. + + Note - debugprint() has a bug where it will accept numeric types but does + not respect the "file" argument and in this case and instead prints the number + to stdout and returns an empty string. This can lead to tests passing where + they should fail because any two numbers will always compare as equal. To + prevent this we treat numbers as a separate case. + """ + numeric_types = (int, float, np.number) + a_is_num = isinstance(a, numeric_types) + b_is_num = isinstance(b, numeric_types) + + # Compare numeric types using regular equality + if a_is_num or b_is_num: + if not (a_is_num and b_is_num): + return False + + return a == b + + # Compare sequences element-wise + a_is_seq = isinstance(a, (tuple, list)) + b_is_seq = isinstance(b, (tuple, list)) + + if a_is_seq or b_is_seq: + if not (a_is_seq and b_is_seq) or type(a) != type(b): + return False + + return list(map(theq, a)) == list(map(theq, b)) + + # Otherwise, assume debugprint() can handle it + astr = theano.printing.debugprint(a, file='str') + bstr = theano.printing.debugprint(b, file='str') + + # Check for bug mentioned above + for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]: + if argstr == '': + raise TypeError( + 'theano.printing.debugprint(%s) returned empty string ' + '(%s is instance of %r)' + % (argname, argname, type(argval)) + ) + + return astr == bstr + + +def test_example_symbols(): + """ + Check that the example symbols in this module print to their Theano + equivalents, as many of the other tests depend on this. + """ + assert theq(xt, theano_code_(x)) + assert theq(yt, theano_code_(y)) + assert theq(zt, theano_code_(z)) + assert theq(Xt, theano_code_(X)) + assert theq(Yt, theano_code_(Y)) + assert theq(Zt, theano_code_(Z)) + + +def test_Symbol(): + """ Test printing a Symbol to a theano variable. """ + xx = theano_code_(x) + assert isinstance(xx, (tt.TensorVariable, ts.ScalarVariable)) + assert xx.broadcastable == () + assert xx.name == x.name + + xx2 = theano_code_(x, broadcastables={x: (False,)}) + assert xx2.broadcastable == (False,) + assert xx2.name == x.name + +def test_MatrixSymbol(): + """ Test printing a MatrixSymbol to a theano variable. """ + XX = theano_code_(X) + assert isinstance(XX, tt.TensorVariable) + assert XX.broadcastable == (False, False) + +@SKIP # TODO - this is currently not checked but should be implemented +def test_MatrixSymbol_wrong_dims(): + """ Test MatrixSymbol with invalid broadcastable. """ + bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)] + for bc in bcs: + with raises(ValueError): + theano_code_(X, broadcastables={X: bc}) + +def test_AppliedUndef(): + """ Test printing AppliedUndef instance, which works similarly to Symbol. """ + ftt = theano_code_(f_t) + assert isinstance(ftt, tt.TensorVariable) + assert ftt.broadcastable == () + assert ftt.name == 'f_t' + + +def test_add(): + expr = x + y + comp = theano_code_(expr) + assert comp.owner.op == theano.tensor.add + +def test_trig(): + assert theq(theano_code_(sy.sin(x)), tt.sin(xt)) + assert theq(theano_code_(sy.tan(x)), tt.tan(xt)) + +def test_many(): + """ Test printing a complex expression with multiple symbols. """ + expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z) + comp = theano_code_(expr) + expected = tt.exp(xt**2 + tt.cos(yt)) * tt.log(2*zt) + assert theq(comp, expected) + + +def test_dtype(): + """ Test specifying specific data types through the dtype argument. """ + for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']: + assert theano_code_(x, dtypes={x: dtype}).type.dtype == dtype + + # "floatX" type + assert theano_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64') + + # Type promotion + assert theano_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32' + assert theano_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64' + + +def test_broadcastables(): + """ Test the "broadcastables" argument when printing symbol-like objects. """ + + # No restrictions on shape + for s in [x, f_t]: + for bc in [(), (False,), (True,), (False, False), (True, False)]: + assert theano_code_(s, broadcastables={s: bc}).broadcastable == bc + + # TODO - matrix broadcasting? + +def test_broadcasting(): + """ Test "broadcastable" attribute after applying element-wise binary op. """ + + expr = x + y + + cases = [ + [(), (), ()], + [(False,), (False,), (False,)], + [(True,), (False,), (False,)], + [(False, True), (False, False), (False, False)], + [(True, False), (False, False), (False, False)], + ] + + for bc1, bc2, bc3 in cases: + comp = theano_code_(expr, broadcastables={x: bc1, y: bc2}) + assert comp.broadcastable == bc3 + + +def test_MatMul(): + expr = X*Y*Z + expr_t = theano_code_(expr) + assert isinstance(expr_t.owner.op, tt.Dot) + assert theq(expr_t, Xt.dot(Yt).dot(Zt)) + +def test_Transpose(): + assert isinstance(theano_code_(X.T).owner.op, tt.DimShuffle) + +def test_MatAdd(): + expr = X+Y+Z + assert isinstance(theano_code_(expr).owner.op, tt.Elemwise) + + +def test_Rationals(): + assert theq(theano_code_(sy.Integer(2) / 3), tt.true_div(2, 3)) + assert theq(theano_code_(S.Half), tt.true_div(1, 2)) + +def test_Integers(): + assert theano_code_(sy.Integer(3)) == 3 + +def test_factorial(): + n = sy.Symbol('n') + assert theano_code_(sy.factorial(n)) + +def test_Derivative(): + simp = lambda expr: theano_simplify(fgraph_of(expr)) + assert theq(simp(theano_code_(sy.Derivative(sy.sin(x), x, evaluate=False))), + simp(theano.grad(tt.sin(xt), xt))) + + +def test_theano_function_simple(): + """ Test theano_function() with single output. """ + f = theano_function_([x, y], [x+y]) + assert f(2, 3) == 5 + +def test_theano_function_multi(): + """ Test theano_function() with multiple outputs. """ + f = theano_function_([x, y], [x+y, x-y]) + o1, o2 = f(2, 3) + assert o1 == 5 + assert o2 == -1 + +def test_theano_function_numpy(): + """ Test theano_function() vs Numpy implementation. """ + f = theano_function_([x, y], [x+y], dim=1, + dtypes={x: 'float64', y: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9 + + f = theano_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'}, + dim=1) + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9 + + +def test_theano_function_matrix(): + m = sy.Matrix([[x, y], [z, x + y + z]]) + expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]]) + f = theano_function_([x, y, z], [m]) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = theano_function_([x, y, z], [m], scalar=True) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = theano_function_([x, y, z], [m, m]) + assert isinstance(f(1.0, 2.0, 3.0), type([])) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected) + +def test_dim_handling(): + assert dim_handling([x], dim=2) == {x: (False, False)} + assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True), + y: (False, False)} + assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)} + +def test_theano_function_kwargs(): + """ + Test passing additional kwargs from theano_function() to theano.function(). + """ + import numpy as np + f = theano_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore', + dtypes={x: 'float64', y: 'float64', z: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9 + + f = theano_function_([x, y, z], [x+y], + dtypes={x: 'float64', y: 'float64', z: 'float64'}, + dim=1, on_unused_input='ignore') + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + zz = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9 + +def test_theano_function_scalar(): + """ Test the "scalar" argument to theano_function(). """ + + args = [ + ([x, y], [x + y], None, [0]), # Single 0d output + ([X, Y], [X + Y], None, [2]), # Single 2d output + ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output + ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs + ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d + ] + + # Create and test functions with and without the scalar setting + for inputs, outputs, in_dims, out_dims in args: + for scalar in [False, True]: + + f = theano_function_(inputs, outputs, dims=in_dims, scalar=scalar) + + # Check the theano_function attribute is set whether wrapped or not + assert isinstance(f.theano_function, theano.compile.function_module.Function) + + # Feed in inputs of the appropriate size and get outputs + in_values = [ + np.ones([1 if bc else 5 for bc in i.type.broadcastable]) + for i in f.theano_function.input_storage + ] + out_values = f(*in_values) + if not isinstance(out_values, list): + out_values = [out_values] + + # Check output types and shapes + assert len(out_dims) == len(out_values) + for d, value in zip(out_dims, out_values): + + if scalar and d == 0: + # Should have been converted to a scalar value + assert isinstance(value, np.number) + + else: + # Otherwise should be an array + assert isinstance(value, np.ndarray) + assert value.ndim == d + +def test_theano_function_bad_kwarg(): + """ + Passing an unknown keyword argument to theano_function() should raise an + exception. + """ + raises(Exception, lambda : theano_function_([x], [x+1], foobar=3)) + + +def test_slice(): + assert theano_code_(slice(1, 2, 3)) == slice(1, 2, 3) + + def theq_slice(s1, s2): + for attr in ['start', 'stop', 'step']: + a1 = getattr(s1, attr) + a2 = getattr(s2, attr) + if a1 is None or a2 is None: + if not (a1 is None or a2 is None): + return False + elif not theq(a1, a2): + return False + return True + + dtypes = {x: 'int32', y: 'int32'} + assert theq_slice(theano_code_(slice(x, y), dtypes=dtypes), slice(xt, yt)) + assert theq_slice(theano_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3)) + +def test_MatrixSlice(): + from theano import Constant + + cache = {} + + n = sy.Symbol('n', integer=True) + X = sy.MatrixSymbol('X', n, n) + + Y = X[1:2:3, 4:5:6] + Yt = theano_code_(Y, cache=cache) + + s = ts.Scalar('int64') + assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s)) + assert Yt.owner.inputs[0] == theano_code_(X, cache=cache) + # == doesn't work in theano like it does in SymPy. You have to use + # equals. + assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7)) + + k = sy.Symbol('k') + theano_code_(k, dtypes={k: 'int32'}) + start, stop, step = 4, k, 2 + Y = X[start:stop:step] + Yt = theano_code_(Y, dtypes={n: 'int32', k: 'int32'}) + # assert Yt.owner.op.idx_list[0].stop == kt + +def test_BlockMatrix(): + n = sy.Symbol('n', integer=True) + A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD'] + At, Bt, Ct, Dt = map(theano_code_, (A, B, C, D)) + Block = sy.BlockMatrix([[A, B], [C, D]]) + Blockt = theano_code_(Block) + solutions = [tt.join(0, tt.join(1, At, Bt), tt.join(1, Ct, Dt)), + tt.join(1, tt.join(0, At, Ct), tt.join(0, Bt, Dt))] + assert any(theq(Blockt, solution) for solution in solutions) + +@SKIP +def test_BlockMatrix_Inverse_execution(): + k, n = 2, 4 + dtype = 'float32' + A = sy.MatrixSymbol('A', n, k) + B = sy.MatrixSymbol('B', n, n) + inputs = A, B + output = B.I*A + + cutsizes = {A: [(n//2, n//2), (k//2, k//2)], + B: [(n//2, n//2), (n//2, n//2)]} + cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs] + cutoutput = output.subs(dict(zip(inputs, cutinputs))) + + dtypes = dict(zip(inputs, [dtype]*len(inputs))) + f = theano_function_(inputs, [output], dtypes=dtypes, cache={}) + fblocked = theano_function_(inputs, [sy.block_collapse(cutoutput)], + dtypes=dtypes, cache={}) + + ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs] + ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype), + np.eye(n).astype(dtype)] + ninputs[1] += np.ones(B.shape)*1e-5 + + assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5) + +def test_DenseMatrix(): + t = sy.Symbol('theta') + for MatrixType in [sy.Matrix, sy.ImmutableMatrix]: + X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]]) + tX = theano_code_(X) + assert isinstance(tX, tt.TensorVariable) + assert tX.owner.op == tt.join_ + + +def test_cache_basic(): + """ Test single symbol-like objects are cached when printed by themselves. """ + + # Pairs of objects which should be considered equivalent with respect to caching + pairs = [ + (x, sy.Symbol('x')), + (X, sy.MatrixSymbol('X', *X.shape)), + (f_t, sy.Function('f')(sy.Symbol('t'))), + ] + + for s1, s2 in pairs: + cache = {} + st = theano_code_(s1, cache=cache) + + # Test hit with same instance + assert theano_code_(s1, cache=cache) is st + + # Test miss with same instance but new cache + assert theano_code_(s1, cache={}) is not st + + # Test hit with different but equivalent instance + assert theano_code_(s2, cache=cache) is st + +def test_global_cache(): + """ Test use of the global cache. """ + from sympy.printing.theanocode import global_cache + + backup = dict(global_cache) + try: + # Temporarily empty global cache + global_cache.clear() + + for s in [x, X, f_t]: + with warns_deprecated_sympy(): + st = theano_code(s) + assert theano_code(s) is st + + finally: + # Restore global cache + global_cache.update(backup) + +def test_cache_types_distinct(): + """ + Test that symbol-like objects of different types (Symbol, MatrixSymbol, + AppliedUndef) are distinguished by the cache even if they have the same + name. + """ + symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t] + + cache = {} # Single shared cache + printed = {} + + for s in symbols: + st = theano_code_(s, cache=cache) + assert st not in printed.values() + printed[s] = st + + # Check all printed objects are distinct + assert len(set(map(id, printed.values()))) == len(symbols) + + # Check retrieving + for s, st in printed.items(): + with warns_deprecated_sympy(): + assert theano_code(s, cache=cache) is st + +def test_symbols_are_created_once(): + """ + Test that a symbol is cached and reused when it appears in an expression + more than once. + """ + expr = sy.Add(x, x, evaluate=False) + comp = theano_code_(expr) + + assert theq(comp, xt + xt) + assert not theq(comp, xt + theano_code_(x)) + +def test_cache_complex(): + """ + Test caching on a complicated expression with multiple symbols appearing + multiple times. + """ + expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y) + symbol_names = {s.name for s in expr.free_symbols} + expr_t = theano_code_(expr) + + # Iterate through variables in the Theano computational graph that the + # printed expression depends on + seen = set() + for v in theano.gof.graph.ancestors([expr_t]): + # Owner-less, non-constant variables should be our symbols + if v.owner is None and not isinstance(v, theano.gof.graph.Constant): + # Check it corresponds to a symbol and appears only once + assert v.name in symbol_names + assert v.name not in seen + seen.add(v.name) + + # Check all were present + assert seen == symbol_names + + +def test_Piecewise(): + # A piecewise linear + expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III + result = theano_code_(expr) + assert result.owner.op == tt.switch + + expected = tt.switch(xt<0, 0, tt.switch(xt<2, xt, 1)) + assert theq(result, expected) + + expr = sy.Piecewise((x, x < 0)) + result = theano_code_(expr) + expected = tt.switch(xt < 0, xt, np.nan) + assert theq(result, expected) + + expr = sy.Piecewise((0, sy.And(x>0, x<2)), \ + (x, sy.Or(x>2, x<0))) + result = theano_code_(expr) + expected = tt.switch(tt.and_(xt>0,xt<2), 0, \ + tt.switch(tt.or_(xt>2, xt<0), xt, np.nan)) + assert theq(result, expected) + + +def test_Relationals(): + assert theq(theano_code_(sy.Eq(x, y)), tt.eq(xt, yt)) + # assert theq(theano_code_(sy.Ne(x, y)), tt.neq(xt, yt)) # TODO - implement + assert theq(theano_code_(x > y), xt > yt) + assert theq(theano_code_(x < y), xt < yt) + assert theq(theano_code_(x >= y), xt >= yt) + assert theq(theano_code_(x <= y), xt <= yt) + + +def test_complexfunctions(): + with warns_deprecated_sympy(): + xt, yt = theano_code_(x, dtypes={x:'complex128'}), theano_code_(y, dtypes={y: 'complex128'}) + from sympy.functions.elementary.complexes import conjugate + from theano.tensor import as_tensor_variable as atv + from theano.tensor import complex as cplx + with warns_deprecated_sympy(): + assert theq(theano_code_(y*conjugate(x)), yt*(xt.conj())) + assert theq(theano_code_((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1))) + + +def test_constantfunctions(): + with warns_deprecated_sympy(): + tf = theano_function_([],[1+1j]) + assert(tf()==1+1j) + + +def test_Exp1(): + """ + Test that exp(1) prints without error and evaluates close to SymPy's E + """ + # sy.exp(1) should yield same instance of E as sy.E (singleton), but extra + # check added for sanity + e_a = sy.exp(1) + e_b = sy.E + + np.testing.assert_allclose(float(e_a), np.e) + np.testing.assert_allclose(float(e_b), np.e) + + e = theano_code_(e_a) + np.testing.assert_allclose(float(e_a), e.eval()) + + e = theano_code_(e_b) + np.testing.assert_allclose(float(e_b), e.eval()) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_torch.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce2c6cec75e03264f93b472a79eb073742e3486 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_torch.py @@ -0,0 +1,531 @@ +import random +import math + +from sympy import symbols, Derivative +from sympy.printing.pytorch import torch_code +from sympy import (eye, MatrixSymbol, Matrix) +from sympy.tensor.array import NDimArray +from sympy.tensor.array.expressions.array_expressions import ( + ArrayTensorProduct, ArrayAdd, + PermuteDims, ArrayDiagonal, _CodegenArrayAbstract) +from sympy.utilities.lambdify import lambdify +from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt +from sympy.functions import \ + Abs, ceiling, exp, floor, sign, sin, asin, cos, \ + acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \ + re, im, arg, erf, loggamma, sqrt +from sympy.testing.pytest import skip +from sympy.external import import_module +from sympy.matrices.expressions import \ + Determinant, HadamardProduct, Inverse, Trace +from sympy.matrices import randMatrix +from sympy.matrices import Identity, ZeroMatrix, OneMatrix +from sympy import conjugate, I +from sympy import Heaviside, gamma, polygamma + + + +torch = import_module("torch") + +M = MatrixSymbol("M", 3, 3) +N = MatrixSymbol("N", 3, 3) +P = MatrixSymbol("P", 3, 3) +Q = MatrixSymbol("Q", 3, 3) + +x, y, z, t = symbols("x y z t") + +if torch is not None: + llo = [list(range(i, i + 3)) for i in range(0, 9, 3)] + m3x3 = torch.tensor(llo, dtype=torch.float64) + m3x3sympy = Matrix(llo) + + +def _compare_torch_matrix(variables, expr): + f = lambdify(variables, expr, 'torch') + + random_matrices = [randMatrix(i.shape[0], i.shape[1]) for i in variables] + random_variables = [torch.tensor(i.tolist(), dtype=torch.float64) for i in random_matrices] + r = f(*random_variables) + e = expr.subs(dict(zip(variables, random_matrices))).doit() + + if isinstance(e, _CodegenArrayAbstract): + e = e.doit() + + if hasattr(e, 'is_number') and e.is_number: + if isinstance(r, torch.Tensor) and r.dim() == 0: + r = r.item() + e = float(e) + assert abs(r - e) < 1e-6 + return + + if e.is_Matrix or isinstance(e, NDimArray): + e = torch.tensor(e.tolist(), dtype=torch.float64) + assert torch.allclose(r, e, atol=1e-6) + else: + raise TypeError(f"Cannot compare {type(r)} with {type(e)}") + + +def _compare_torch_scalar(variables, expr, rng=lambda: random.uniform(-5, 5)): + f = lambdify(variables, expr, 'torch') + rvs = [rng() for v in variables] + t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs] + r = f(*t_rvs) + if isinstance(r, torch.Tensor): + r = r.item() + e = expr.subs(dict(zip(variables, rvs))).doit() + assert abs(r - e) < 1e-6 + + +def _compare_torch_relational(variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'torch') + rvs = [rng() for v in variables] + t_rvs = [torch.tensor(i, dtype=torch.float64) for i in rvs] + r = f(*t_rvs) + e = bool(expr.subs(dict(zip(variables, rvs))).doit()) + assert r.item() == e + + +def test_torch_math(): + if not torch: + skip("PyTorch not installed") + + expr = Abs(x) + assert torch_code(expr) == "torch.abs(x)" + f = lambdify(x, expr, 'torch') + ma = torch.tensor([[-1, 2, -3, -4]], dtype=torch.float64) + y_abs = f(ma) + c = torch.abs(ma) + assert torch.all(y_abs == c) + + expr = sign(x) + assert torch_code(expr) == "torch.sign(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-10, 10)) + + expr = ceiling(x) + assert torch_code(expr) == "torch.ceil(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = floor(x) + assert torch_code(expr) == "torch.floor(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = exp(x) + assert torch_code(expr) == "torch.exp(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = sqrt(x) + assert torch_code(expr) == "torch.sqrt(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = x ** 4 + assert torch_code(expr) == "torch.pow(x, 4)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = cos(x) + assert torch_code(expr) == "torch.cos(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = acos(x) + assert torch_code(expr) == "torch.acos(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99)) + + expr = sin(x) + assert torch_code(expr) == "torch.sin(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.random()) + + expr = asin(x) + assert torch_code(expr) == "torch.asin(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.99, 0.99)) + + expr = tan(x) + assert torch_code(expr) == "torch.tan(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-1.5, 1.5)) + + expr = atan(x) + assert torch_code(expr) == "torch.atan(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5)) + + expr = atan2(y, x) + assert torch_code(expr) == "torch.atan2(y, x)" + _compare_torch_scalar((y, x), expr, rng=lambda: random.uniform(-5, 5)) + + expr = cosh(x) + assert torch_code(expr) == "torch.cosh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = acosh(x) + assert torch_code(expr) == "torch.acosh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(1.1, 5)) + + expr = sinh(x) + assert torch_code(expr) == "torch.sinh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = asinh(x) + assert torch_code(expr) == "torch.asinh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-5, 5)) + + expr = tanh(x) + assert torch_code(expr) == "torch.tanh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = atanh(x) + assert torch_code(expr) == "torch.atanh(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-0.9, 0.9)) + + expr = erf(x) + assert torch_code(expr) == "torch.erf(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(-2, 2)) + + expr = loggamma(x) + assert torch_code(expr) == "torch.lgamma(x)" + _compare_torch_scalar((x,), expr, rng=lambda: random.uniform(0.5, 5)) + + +def test_torch_complexes(): + assert torch_code(re(x)) == "torch.real(x)" + assert torch_code(im(x)) == "torch.imag(x)" + assert torch_code(arg(x)) == "torch.angle(x)" + + +def test_torch_relational(): + if not torch: + skip("PyTorch not installed") + + expr = Eq(x, y) + assert torch_code(expr) == "torch.eq(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Ne(x, y) + assert torch_code(expr) == "torch.ne(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Ge(x, y) + assert torch_code(expr) == "torch.ge(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Gt(x, y) + assert torch_code(expr) == "torch.gt(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Le(x, y) + assert torch_code(expr) == "torch.le(x, y)" + _compare_torch_relational((x, y), expr) + + expr = Lt(x, y) + assert torch_code(expr) == "torch.lt(x, y)" + _compare_torch_relational((x, y), expr) + + +def test_torch_matrix(): + if torch is None: + skip("PyTorch not installed") + + expr = M + assert torch_code(expr) == "M" + f = lambdify((M,), expr, "torch") + eye_mat = eye(3) + eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64) + assert torch.allclose(f(eye_tensor), eye_tensor) + + expr = M * N + assert torch_code(expr) == "torch.matmul(M, N)" + _compare_torch_matrix((M, N), expr) + + expr = M ** 3 + assert torch_code(expr) == "torch.mm(torch.mm(M, M), M)" + _compare_torch_matrix((M,), expr) + + expr = M * N * P * Q + assert torch_code(expr) == "torch.matmul(torch.matmul(torch.matmul(M, N), P), Q)" + _compare_torch_matrix((M, N, P, Q), expr) + + expr = Trace(M) + assert torch_code(expr) == "torch.trace(M)" + _compare_torch_matrix((M,), expr) + + expr = Determinant(M) + assert torch_code(expr) == "torch.det(M)" + _compare_torch_matrix((M,), expr) + + expr = HadamardProduct(M, N) + assert torch_code(expr) == "torch.mul(M, N)" + _compare_torch_matrix((M, N), expr) + + expr = Inverse(M) + assert torch_code(expr) == "torch.linalg.inv(M)" + + # For inverse, use a matrix that's guaranteed to be invertible + eye_mat = eye(3) + eye_tensor = torch.tensor(eye_mat.tolist(), dtype=torch.float64) + f = lambdify((M,), expr, "torch") + result = f(eye_tensor) + expected = torch.linalg.inv(eye_tensor) + assert torch.allclose(result, expected) + + +def test_torch_array_operations(): + if not torch: + skip("PyTorch not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + + ma = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float64) + mb = torch.tensor([[1., -2.], [-1., 3.]], dtype=torch.float64) + mc = torch.tensor([[2., 0.], [1., 2.]], dtype=torch.float64) + md = torch.tensor([[1., -1.], [4., 7.]], dtype=torch.float64) + + cg = ArrayTensorProduct(M, N) + assert torch_code(cg) == 'torch.einsum("ab,cd", M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ij,kl", ma, mb) + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N) + assert torch_code(cg) == 'torch.add(M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = ma + mb + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N, P) + assert torch_code(cg) == 'torch.add(torch.add(M, N), P)' + f = lambdify((M, N, P), cg, 'torch') + y = f(ma, mb, mc) + c = ma + mb + mc + assert torch.allclose(y, c) + + cg = ArrayAdd(M, N, P, Q) + assert torch_code(cg) == 'torch.add(torch.add(torch.add(M, N), P), Q)' + f = lambdify((M, N, P, Q), cg, 'torch') + y = f(ma, mb, mc, md) + c = ma + mb + mc + md + assert torch.allclose(y, c) + + cg = PermuteDims(M, [1, 0]) + assert torch_code(cg) == 'M.permute(1, 0)' + f = lambdify((M,), cg, 'torch') + y = f(ma) + c = ma.T + assert torch.allclose(y, c) + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + assert torch_code(cg) == 'torch.einsum("ab,cd", M, N).permute(1, 2, 3, 0)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ab,cd", ma, mb).permute(1, 2, 3, 0) + assert torch.allclose(y, c) + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + assert torch_code(cg) == 'torch.einsum("ab,bc->acb", M, N)' + f = lambdify((M, N), cg, 'torch') + y = f(ma, mb) + c = torch.einsum("ab,bc->acb", ma, mb) + assert torch.allclose(y, c) + + +def test_torch_derivative(): + """Test derivative handling.""" + expr = Derivative(sin(x), x) + assert torch_code(expr) == 'torch.autograd.grad(torch.sin(x), x)[0]' + + +def test_torch_printing_dtype(): + if not torch: + skip("PyTorch not installed") + + # matrix printing with default dtype + expr = Matrix([[x, sin(y)], [exp(z), -t]]) + assert "dtype=torch.float64" in torch_code(expr) + + # explicit dtype + assert "dtype=torch.float32" in torch_code(expr, dtype="torch.float32") + + # with requires_grad + result = torch_code(expr, requires_grad=True) + assert "requires_grad=True" in result + assert "dtype=torch.float64" in result + + # both + result = torch_code(expr, requires_grad=True, dtype="torch.float32") + assert "requires_grad=True" in result + assert "dtype=torch.float32" in result + + +def test_requires_grad(): + if not torch: + skip("PyTorch not installed") + + expr = sin(x) + cos(y) + f = lambdify([x, y], expr, 'torch') + + # make sure the gradients flow + x_val = torch.tensor(1.0, requires_grad=True) + y_val = torch.tensor(2.0, requires_grad=True) + result = f(x_val, y_val) + assert result.requires_grad + result.backward() + + # x_val.grad should be cos(x_val) which is close to cos(1.0) + assert abs(x_val.grad.item() - float(cos(1.0).evalf())) < 1e-6 + + # y_val.grad should be -sin(y_val) which is close to -sin(2.0) + assert abs(y_val.grad.item() - float(-sin(2.0).evalf())) < 1e-6 + + +def test_torch_multi_variable_derivatives(): + if not torch: + skip("PyTorch not installed") + + x, y, z = symbols("x y z") + + expr = Derivative(sin(x), x) + assert torch_code(expr) == "torch.autograd.grad(torch.sin(x), x)[0]" + + expr = Derivative(sin(x), (x, 2)) + assert torch_code( + expr) == "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]" + + expr = Derivative(sin(x * y), x, y) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x*y), x, create_graph=True)[0], y, create_graph=True)[0]" + normalized_result = result.replace(" ", "") + normalized_expected = expected.replace(" ", "") + assert normalized_result == normalized_expected + + expr = Derivative(sin(x), x, x) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.sin(x), x, create_graph=True)[0], x, create_graph=True)[0]" + assert result == expected + + expr = Derivative(sin(x * y * z), x, (y, 2), z) + result = torch_code(expr) + expected = "torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.autograd.grad(torch.sin(x*y*z), x, create_graph=True)[0], y, create_graph=True)[0], y, create_graph=True)[0], z, create_graph=True)[0]" + normalized_result = result.replace(" ", "") + normalized_expected = expected.replace(" ", "") + assert normalized_result == normalized_expected + + +def test_torch_derivative_lambdify(): + if not torch: + skip("PyTorch not installed") + + x = symbols("x") + y = symbols("y") + + expr = Derivative(x ** 2, x) + f = lambdify(x, expr, 'torch') + x_val = torch.tensor(2.0, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(4.0)) + + expr = Derivative(sin(x), (x, 2)) + f = lambdify(x, expr, 'torch') + # Second derivative of sin(x) at x=0 is 0, not -1 + x_val = torch.tensor(0.0, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(0.0), atol=1e-5) + + x_val = torch.tensor(math.pi / 2, requires_grad=True) + result = f(x_val) + assert torch.isclose(result, torch.tensor(-1.0), atol=1e-5) + + expr = Derivative(x * y ** 2, x, y) + f = lambdify((x, y), expr, 'torch') + x_val = torch.tensor(2.0, requires_grad=True) + y_val = torch.tensor(3.0, requires_grad=True) + result = f(x_val, y_val) + assert torch.isclose(result, torch.tensor(6.0)) + + +def test_torch_special_matrices(): + if not torch: + skip("PyTorch not installed") + + expr = Identity(3) + assert torch_code(expr) == "torch.eye(3)" + + n = symbols("n") + expr = Identity(n) + assert torch_code(expr) == "torch.eye(n, n)" + + expr = ZeroMatrix(2, 3) + assert torch_code(expr) == "torch.zeros((2, 3))" + + m, n = symbols("m n") + expr = ZeroMatrix(m, n) + assert torch_code(expr) == "torch.zeros((m, n))" + + expr = OneMatrix(2, 3) + assert torch_code(expr) == "torch.ones((2, 3))" + + expr = OneMatrix(m, n) + assert torch_code(expr) == "torch.ones((m, n))" + + +def test_torch_special_matrices_lambdify(): + if not torch: + skip("PyTorch not installed") + + expr = Identity(3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.eye(3) + assert torch.allclose(result, expected) + + expr = ZeroMatrix(2, 3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.zeros((2, 3)) + assert torch.allclose(result, expected) + + expr = OneMatrix(2, 3) + f = lambdify([], expr, 'torch') + result = f() + expected = torch.ones((2, 3)) + assert torch.allclose(result, expected) + + +def test_torch_complex_operations(): + if not torch: + skip("PyTorch not installed") + + expr = conjugate(x) + assert torch_code(expr) == "torch.conj(x)" + + # SymPy distributes conjugate over addition and applies specific rules for each term + expr = conjugate(sin(x) + I * cos(y)) + assert torch_code(expr) == "torch.sin(torch.conj(x)) - 1j*torch.cos(torch.conj(y))" + + expr = I + assert torch_code(expr) == "1j" + + expr = 2 * I + x + assert torch_code(expr) == "x + 2*1j" + + expr = exp(I * x) + assert torch_code(expr) == "torch.exp(1j*x)" + + +def test_torch_special_functions(): + if not torch: + skip("PyTorch not installed") + + expr = Heaviside(x) + assert torch_code(expr) == "torch.heaviside(x, 1/2)" + + expr = Heaviside(x, 0) + assert torch_code(expr) == "torch.heaviside(x, 0)" + + expr = gamma(x) + assert torch_code(expr) == "torch.special.gamma(x)" + + expr = polygamma(0, x) # Use polygamma instead of digamma because sympy will default to that anyway + assert torch_code(expr) == "torch.special.digamma(x)" + + expr = gamma(sin(x)) + assert torch_code(expr) == "torch.special.gamma(torch.sin(x))" diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tree.py b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..cf116d0cac5d38f225815fcd2d4ac90cd0dd96d7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tests/test_tree.py @@ -0,0 +1,196 @@ +from sympy.printing.tree import tree +from sympy.testing.pytest import XFAIL + + +# Remove this flag after making _assumptions cache deterministic. +@XFAIL +def test_print_tree_MatAdd(): + from sympy.matrices.expressions import MatrixSymbol + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + + test_str = [ + 'MatAdd: A + B\n', + 'algebraic: False\n', + 'commutative: False\n', + 'complex: False\n', + 'composite: False\n', + 'even: False\n', + 'extended_negative: False\n', + 'extended_nonnegative: False\n', + 'extended_nonpositive: False\n', + 'extended_nonzero: False\n', + 'extended_positive: False\n', + 'extended_real: False\n', + 'imaginary: False\n', + 'integer: False\n', + 'irrational: False\n', + 'negative: False\n', + 'noninteger: False\n', + 'nonnegative: False\n', + 'nonpositive: False\n', + 'nonzero: False\n', + 'odd: False\n', + 'positive: False\n', + 'prime: False\n', + 'rational: False\n', + 'real: False\n', + 'transcendental: False\n', + 'zero: False\n', + '+-MatrixSymbol: A\n', + '| algebraic: False\n', + '| commutative: False\n', + '| complex: False\n', + '| composite: False\n', + '| even: False\n', + '| extended_negative: False\n', + '| extended_nonnegative: False\n', + '| extended_nonpositive: False\n', + '| extended_nonzero: False\n', + '| extended_positive: False\n', + '| extended_real: False\n', + '| imaginary: False\n', + '| integer: False\n', + '| irrational: False\n', + '| negative: False\n', + '| noninteger: False\n', + '| nonnegative: False\n', + '| nonpositive: False\n', + '| nonzero: False\n', + '| odd: False\n', + '| positive: False\n', + '| prime: False\n', + '| rational: False\n', + '| real: False\n', + '| transcendental: False\n', + '| zero: False\n', + '| +-Symbol: A\n', + '| | commutative: True\n', + '| +-Integer: 3\n', + '| | algebraic: True\n', + '| | commutative: True\n', + '| | complex: True\n', + '| | extended_negative: False\n', + '| | extended_nonnegative: True\n', + '| | extended_real: True\n', + '| | finite: True\n', + '| | hermitian: True\n', + '| | imaginary: False\n', + '| | infinite: False\n', + '| | integer: True\n', + '| | irrational: False\n', + '| | negative: False\n', + '| | noninteger: False\n', + '| | nonnegative: True\n', + '| | rational: True\n', + '| | real: True\n', + '| | transcendental: False\n', + '| +-Integer: 3\n', + '| algebraic: True\n', + '| commutative: True\n', + '| complex: True\n', + '| extended_negative: False\n', + '| extended_nonnegative: True\n', + '| extended_real: True\n', + '| finite: True\n', + '| hermitian: True\n', + '| imaginary: False\n', + '| infinite: False\n', + '| integer: True\n', + '| irrational: False\n', + '| negative: False\n', + '| noninteger: False\n', + '| nonnegative: True\n', + '| rational: True\n', + '| real: True\n', + '| transcendental: False\n', + '+-MatrixSymbol: B\n', + ' algebraic: False\n', + ' commutative: False\n', + ' complex: False\n', + ' composite: False\n', + ' even: False\n', + ' extended_negative: False\n', + ' extended_nonnegative: False\n', + ' extended_nonpositive: False\n', + ' extended_nonzero: False\n', + ' extended_positive: False\n', + ' extended_real: False\n', + ' imaginary: False\n', + ' integer: False\n', + ' irrational: False\n', + ' negative: False\n', + ' noninteger: False\n', + ' nonnegative: False\n', + ' nonpositive: False\n', + ' nonzero: False\n', + ' odd: False\n', + ' positive: False\n', + ' prime: False\n', + ' rational: False\n', + ' real: False\n', + ' transcendental: False\n', + ' zero: False\n', + ' +-Symbol: B\n', + ' | commutative: True\n', + ' +-Integer: 3\n', + ' | algebraic: True\n', + ' | commutative: True\n', + ' | complex: True\n', + ' | extended_negative: False\n', + ' | extended_nonnegative: True\n', + ' | extended_real: True\n', + ' | finite: True\n', + ' | hermitian: True\n', + ' | imaginary: False\n', + ' | infinite: False\n', + ' | integer: True\n', + ' | irrational: False\n', + ' | negative: False\n', + ' | noninteger: False\n', + ' | nonnegative: True\n', + ' | rational: True\n', + ' | real: True\n', + ' | transcendental: False\n', + ' +-Integer: 3\n', + ' algebraic: True\n', + ' commutative: True\n', + ' complex: True\n', + ' extended_negative: False\n', + ' extended_nonnegative: True\n', + ' extended_real: True\n', + ' finite: True\n', + ' hermitian: True\n', + ' imaginary: False\n', + ' infinite: False\n', + ' integer: True\n', + ' irrational: False\n', + ' negative: False\n', + ' noninteger: False\n', + ' nonnegative: True\n', + ' rational: True\n', + ' real: True\n', + ' transcendental: False\n' + ] + + assert tree(A + B) == "".join(test_str) + + +def test_print_tree_MatAdd_noassumptions(): + from sympy.matrices.expressions import MatrixSymbol + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + + test_str = \ +"""MatAdd: A + B ++-MatrixSymbol: A +| +-Str: A +| +-Integer: 3 +| +-Integer: 3 ++-MatrixSymbol: B + +-Str: B + +-Integer: 3 + +-Integer: 3 +""" + + assert tree(A + B, assumptions=False) == test_str diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/benchmarks/__init__.py b/.venv/lib/python3.13/site-packages/sympy/solvers/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/benchmarks/bench_solvers.py b/.venv/lib/python3.13/site-packages/sympy/solvers/benchmarks/bench_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..d18102873f7efcde1d111e0e8eca12e208f94663 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/benchmarks/bench_solvers.py @@ -0,0 +1,12 @@ +from sympy.core.symbol import Symbol +from sympy.matrices.dense import (eye, zeros) +from sympy.solvers.solvers import solve_linear_system + +N = 8 +M = zeros(N, N + 1) +M[:, :N] = eye(N) +S = [Symbol('A%i' % i) for i in range(N)] + + +def timeit_linsolve_trivial(): + solve_linear_system(M, *S) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py b/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b031a1e63fa445e3dbb0b425f84cfe88888667 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py @@ -0,0 +1,1071 @@ +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.matrices.dense import Matrix +from sympy.ntheory.factor_ import factorint +from sympy.simplify.powsimp import powsimp +from sympy.core.function import _mexpand +from sympy.core.sorting import default_sort_key, ordered +from sympy.functions.elementary.trigonometric import sin +from sympy.solvers.diophantine import diophantine +from sympy.solvers.diophantine.diophantine import (diop_DN, + diop_solve, diop_ternary_quadratic_normal, + diop_general_pythagorean, diop_ternary_quadratic, diop_linear, + diop_quadratic, diop_general_sum_of_squares, diop_general_sum_of_even_powers, + descent, diop_bf_DN, divisible, equivalent, find_DN, ldescent, length, + reconstruct, partition, power_representation, + prime_as_sum_of_two_squares, square_factor, sum_of_four_squares, + sum_of_three_squares, transformation_to_DN, transformation_to_normal, + classify_diop, base_solution_linear, cornacchia, sqf_normal, gaussian_reduce, holzer, + check_param, parametrize_ternary_quadratic, sum_of_powers, sum_of_squares, + _diop_ternary_quadratic_normal, _nint_or_floor, + _odd, _even, _remove_gcd, _can_do_sum_of_squares, DiophantineSolutionSet, GeneralPythagorean, + BinaryQuadratic) + +from sympy.testing.pytest import slow, raises, XFAIL +from sympy.utilities.iterables import ( + signed_permutations) + +a, b, c, d, p, q, x, y, z, w, t, u, v, X, Y, Z = symbols( + "a, b, c, d, p, q, x, y, z, w, t, u, v, X, Y, Z", integer=True) +t_0, t_1, t_2, t_3, t_4, t_5, t_6 = symbols("t_:7", integer=True) +m1, m2, m3 = symbols('m1:4', integer=True) +n1 = symbols('n1', integer=True) + + +def diop_simplify(eq): + return _mexpand(powsimp(_mexpand(eq))) + + +def test_input_format(): + raises(TypeError, lambda: diophantine(sin(x))) + raises(TypeError, lambda: diophantine(x/pi - 3)) + + +def test_nosols(): + # diophantine should sympify eq so that these are equivalent + assert diophantine(3) == set() + assert diophantine(S(3)) == set() + + +def test_univariate(): + assert diop_solve((x - 1)*(x - 2)**2) == {(1,), (2,)} + assert diop_solve((x - 1)*(x - 2)) == {(1,), (2,)} + + +def test_classify_diop(): + raises(TypeError, lambda: classify_diop(x**2/3 - 1)) + raises(ValueError, lambda: classify_diop(1)) + raises(NotImplementedError, lambda: classify_diop(w*x*y*z - 1)) + raises(NotImplementedError, lambda: classify_diop(x**3 + y**3 + z**4 - 90)) + assert classify_diop(14*x**2 + 15*x - 42) == ( + [x], {1: -42, x: 15, x**2: 14}, 'univariate') + assert classify_diop(x*y + z) == ( + [x, y, z], {x*y: 1, z: 1}, 'inhomogeneous_ternary_quadratic') + assert classify_diop(x*y + z + w + x**2) == ( + [w, x, y, z], {x*y: 1, w: 1, x**2: 1, z: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + x*z + x**2 + 1) == ( + [x, y, z], {x*y: 1, x*z: 1, x**2: 1, 1: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + z + w + 42) == ( + [w, x, y, z], {x*y: 1, w: 1, 1: 42, z: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + z*w) == ( + [w, x, y, z], {x*y: 1, w*z: 1}, 'homogeneous_general_quadratic') + assert classify_diop(x*y**2 + 1) == ( + [x, y], {x*y**2: 1, 1: 1}, 'cubic_thue') + assert classify_diop(x**4 + y**4 + z**4 - (1 + 16 + 81)) == ( + [x, y, z], {1: -98, x**4: 1, z**4: 1, y**4: 1}, 'general_sum_of_even_powers') + assert classify_diop(x**2 + y**2 + z**2) == ( + [x, y, z], {x**2: 1, y**2: 1, z**2: 1}, 'homogeneous_ternary_quadratic_normal') + + +def test_linear(): + assert diop_solve(x) == (0,) + assert diop_solve(1*x) == (0,) + assert diop_solve(3*x) == (0,) + assert diop_solve(x + 1) == (-1,) + assert diop_solve(2*x + 1) == (None,) + assert diop_solve(2*x + 4) == (-2,) + assert diop_solve(y + x) == (t_0, -t_0) + assert diop_solve(y + x + 0) == (t_0, -t_0) + assert diop_solve(y + x - 0) == (t_0, -t_0) + assert diop_solve(0*x - y - 5) == (-5,) + assert diop_solve(3*y + 2*x - 5) == (3*t_0 - 5, -2*t_0 + 5) + assert diop_solve(2*x - 3*y - 5) == (3*t_0 - 5, 2*t_0 - 5) + assert diop_solve(-2*x - 3*y - 5) == (3*t_0 + 5, -2*t_0 - 5) + assert diop_solve(7*x + 5*y) == (5*t_0, -7*t_0) + assert diop_solve(2*x + 4*y) == (-2*t_0, t_0) + assert diop_solve(4*x + 6*y - 4) == (3*t_0 - 2, -2*t_0 + 2) + assert diop_solve(4*x + 6*y - 3) == (None, None) + assert diop_solve(0*x + 3*y - 4*z + 5) == (4*t_0 + 5, 3*t_0 + 5) + assert diop_solve(4*x + 3*y - 4*z + 5) == (t_0, 8*t_0 + 4*t_1 + 5, 7*t_0 + 3*t_1 + 5) + assert diop_solve(4*x + 3*y - 4*z + 5, None) == (0, 5, 5) + assert diop_solve(4*x + 2*y + 8*z - 5) == (None, None, None) + assert diop_solve(5*x + 7*y - 2*z - 6) == (t_0, -3*t_0 + 2*t_1 + 6, -8*t_0 + 7*t_1 + 18) + assert diop_solve(3*x - 6*y + 12*z - 9) == (2*t_0 + 3, t_0 + 2*t_1, t_1) + assert diop_solve(6*w + 9*x + 20*y - z) == (t_0, t_1, t_1 + t_2, 6*t_0 + 29*t_1 + 20*t_2) + + # to ignore constant factors, use diophantine + raises(TypeError, lambda: diop_solve(x/2)) + + +def test_quadratic_simple_hyperbolic_case(): + # Simple Hyperbolic case: A = C = 0 and B != 0 + assert diop_solve(3*x*y + 34*x - 12*y + 1) == \ + {(-133, -11), (5, -57)} + assert diop_solve(6*x*y + 2*x + 3*y + 1) == set() + assert diop_solve(-13*x*y + 2*x - 4*y - 54) == {(27, 0)} + assert diop_solve(-27*x*y - 30*x - 12*y - 54) == {(-14, -1)} + assert diop_solve(2*x*y + 5*x + 56*y + 7) == {(-161, -3), (-47, -6), (-35, -12), + (-29, -69), (-27, 64), (-21, 7), + (-9, 1), (105, -2)} + assert diop_solve(6*x*y + 9*x + 2*y + 3) == set() + assert diop_solve(x*y + x + y + 1) == {(-1, t), (t, -1)} + assert diophantine(48*x*y) + + +def test_quadratic_elliptical_case(): + # Elliptical case: B**2 - 4AC < 0 + + assert diop_solve(42*x**2 + 8*x*y + 15*y**2 + 23*x + 17*y - 4915) == {(-11, -1)} + assert diop_solve(4*x**2 + 3*y**2 + 5*x - 11*y + 12) == set() + assert diop_solve(x**2 + y**2 + 2*x + 2*y + 2) == {(-1, -1)} + assert diop_solve(15*x**2 - 9*x*y + 14*y**2 - 23*x - 14*y - 4950) == {(-15, 6)} + assert diop_solve(10*x**2 + 12*x*y + 12*y**2 - 34) == \ + {(-1, -1), (-1, 2), (1, -2), (1, 1)} + + +def test_quadratic_parabolic_case(): + # Parabolic case: B**2 - 4AC = 0 + assert check_solutions(8*x**2 - 24*x*y + 18*y**2 + 5*x + 7*y + 16) + assert check_solutions(8*x**2 - 24*x*y + 18*y**2 + 6*x + 12*y - 6) + assert check_solutions(8*x**2 + 24*x*y + 18*y**2 + 4*x + 6*y - 7) + assert check_solutions(-4*x**2 + 4*x*y - y**2 + 2*x - 3) + assert check_solutions(x**2 + 2*x*y + y**2 + 2*x + 2*y + 1) + assert check_solutions(x**2 - 2*x*y + y**2 + 2*x + 2*y + 1) + assert check_solutions(y**2 - 41*x + 40) + + +def test_quadratic_perfect_square(): + # B**2 - 4*A*C > 0 + # B**2 - 4*A*C is a perfect square + assert check_solutions(48*x*y) + assert check_solutions(4*x**2 - 5*x*y + y**2 + 2) + assert check_solutions(-2*x**2 - 3*x*y + 2*y**2 -2*x - 17*y + 25) + assert check_solutions(12*x**2 + 13*x*y + 3*y**2 - 2*x + 3*y - 12) + assert check_solutions(8*x**2 + 10*x*y + 2*y**2 - 32*x - 13*y - 23) + assert check_solutions(4*x**2 - 4*x*y - 3*y- 8*x - 3) + assert check_solutions(- 4*x*y - 4*y**2 - 3*y- 5*x - 10) + assert check_solutions(x**2 - y**2 - 2*x - 2*y) + assert check_solutions(x**2 - 9*y**2 - 2*x - 6*y) + assert check_solutions(4*x**2 - 9*y**2 - 4*x - 12*y - 3) + + +def test_quadratic_non_perfect_square(): + # B**2 - 4*A*C is not a perfect square + # Used check_solutions() since the solutions are complex expressions involving + # square roots and exponents + assert check_solutions(x**2 - 2*x - 5*y**2) + assert check_solutions(3*x**2 - 2*y**2 - 2*x - 2*y) + assert check_solutions(x**2 - x*y - y**2 - 3*y) + assert check_solutions(x**2 - 9*y**2 - 2*x - 6*y) + assert BinaryQuadratic(x**2 + y**2 + 2*x + 2*y + 2).solve() == {(-1, -1)} + + +def test_issue_9106(): + eq = -48 - 2*x*(3*x - 1) + y*(3*y - 1) + v = (x, y) + for sol in diophantine(eq): + assert not diop_simplify(eq.xreplace(dict(zip(v, sol)))) + + +def test_issue_18138(): + eq = x**2 - x - y**2 + v = (x, y) + for sol in diophantine(eq): + assert not diop_simplify(eq.xreplace(dict(zip(v, sol)))) + + +@slow +def test_quadratic_non_perfect_slow(): + assert check_solutions(8*x**2 + 10*x*y - 2*y**2 - 32*x - 13*y - 23) + # This leads to very large numbers. + # assert check_solutions(5*x**2 - 13*x*y + y**2 - 4*x - 4*y - 15) + assert check_solutions(-3*x**2 - 2*x*y + 7*y**2 - 5*x - 7) + assert check_solutions(-4 - x + 4*x**2 - y - 3*x*y - 4*y**2) + assert check_solutions(1 + 2*x + 2*x**2 + 2*y + x*y - 2*y**2) + + +def test_DN(): + # Most of the test cases were adapted from, + # Solving the generalized Pell equation x**2 - D*y**2 = N, John P. Robertson, July 31, 2004. + # https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + # others are verified using Wolfram Alpha. + + # Covers cases where D <= 0 or D > 0 and D is a square or N = 0 + # Solutions are straightforward in these cases. + assert diop_DN(3, 0) == [(0, 0)] + assert diop_DN(-17, -5) == [] + assert diop_DN(-19, 23) == [(2, 1)] + assert diop_DN(-13, 17) == [(2, 1)] + assert diop_DN(-15, 13) == [] + assert diop_DN(0, 5) == [] + assert diop_DN(0, 9) == [(3, t)] + assert diop_DN(9, 0) == [(3*t, t)] + assert diop_DN(16, 24) == [] + assert diop_DN(9, 180) == [(18, 4)] + assert diop_DN(9, -180) == [(12, 6)] + assert diop_DN(7, 0) == [(0, 0)] + + # When equation is x**2 + y**2 = N + # Solutions are interchangeable + assert diop_DN(-1, 5) == [(2, 1), (1, 2)] + assert diop_DN(-1, 169) == [(12, 5), (5, 12), (13, 0), (0, 13)] + + # D > 0 and D is not a square + + # N = 1 + assert diop_DN(13, 1) == [(649, 180)] + assert diop_DN(980, 1) == [(51841, 1656)] + assert diop_DN(981, 1) == [(158070671986249, 5046808151700)] + assert diop_DN(986, 1) == [(49299, 1570)] + assert diop_DN(991, 1) == [(379516400906811930638014896080, 12055735790331359447442538767)] + assert diop_DN(17, 1) == [(33, 8)] + assert diop_DN(19, 1) == [(170, 39)] + + # N = -1 + assert diop_DN(13, -1) == [(18, 5)] + assert diop_DN(991, -1) == [] + assert diop_DN(41, -1) == [(32, 5)] + assert diop_DN(290, -1) == [(17, 1)] + assert diop_DN(21257, -1) == [(13913102721304, 95427381109)] + assert diop_DN(32, -1) == [] + + # |N| > 1 + # Some tests were created using calculator at + # http://www.numbertheory.org/php/patz.html + + assert diop_DN(13, -4) == [(3, 1), (393, 109), (36, 10)] + # Source I referred returned (3, 1), (393, 109) and (-3, 1) as fundamental solutions + # So (-3, 1) and (393, 109) should be in the same equivalent class + assert equivalent(-3, 1, 393, 109, 13, -4) == True + + assert diop_DN(13, 27) == [(220, 61), (40, 11), (768, 213), (12, 3)] + assert set(diop_DN(157, 12)) == {(13, 1), (10663, 851), (579160, 46222), + (483790960, 38610722), (26277068347, 2097138361), + (21950079635497, 1751807067011)} + assert diop_DN(13, 25) == [(3245, 900)] + assert diop_DN(192, 18) == [] + assert diop_DN(23, 13) == [(-6, 1), (6, 1)] + assert diop_DN(167, 2) == [(13, 1)] + assert diop_DN(167, -2) == [] + + assert diop_DN(123, -2) == [(11, 1)] + # One calculator returned [(11, 1), (-11, 1)] but both of these are in + # the same equivalence class + assert equivalent(11, 1, -11, 1, 123, -2) + + assert diop_DN(123, -23) == [(-10, 1), (10, 1)] + + assert diop_DN(0, 0, t) == [(0, t)] + assert diop_DN(0, -1, t) == [] + + +def test_bf_pell(): + assert diop_bf_DN(13, -4) == [(3, 1), (-3, 1), (36, 10)] + assert diop_bf_DN(13, 27) == [(12, 3), (-12, 3), (40, 11), (-40, 11)] + assert diop_bf_DN(167, -2) == [] + assert diop_bf_DN(1729, 1) == [(44611924489705, 1072885712316)] + assert diop_bf_DN(89, -8) == [(9, 1), (-9, 1)] + assert diop_bf_DN(21257, -1) == [(13913102721304, 95427381109)] + assert diop_bf_DN(340, -4) == [(756, 41)] + assert diop_bf_DN(-1, 0, t) == [(0, 0)] + assert diop_bf_DN(0, 0, t) == [(0, t)] + assert diop_bf_DN(4, 0, t) == [(2*t, t), (-2*t, t)] + assert diop_bf_DN(3, 0, t) == [(0, 0)] + assert diop_bf_DN(1, -2, t) == [] + + +def test_length(): + assert length(2, 1, 0) == 1 + assert length(-2, 4, 5) == 3 + assert length(-5, 4, 17) == 4 + assert length(0, 4, 13) == 6 + assert length(7, 13, 11) == 23 + assert length(1, 6, 4) == 2 + + +def is_pell_transformation_ok(eq): + """ + Test whether X*Y, X, or Y terms are present in the equation + after transforming the equation using the transformation returned + by transformation_to_pell(). If they are not present we are good. + Moreover, coefficient of X**2 should be a divisor of coefficient of + Y**2 and the constant term. + """ + A, B = transformation_to_DN(eq) + u = (A*Matrix([X, Y]) + B)[0] + v = (A*Matrix([X, Y]) + B)[1] + simplified = diop_simplify(eq.subs(zip((x, y), (u, v)))) + + coeff = dict([reversed(t.as_independent(*[X, Y])) for t in simplified.args]) + + for term in [X*Y, X, Y]: + if term in coeff.keys(): + return False + + for term in [X**2, Y**2, 1]: + if term not in coeff.keys(): + coeff[term] = 0 + + if coeff[X**2] != 0: + return divisible(coeff[Y**2], coeff[X**2]) and \ + divisible(coeff[1], coeff[X**2]) + + return True + + +def test_transformation_to_pell(): + assert is_pell_transformation_ok(-13*x**2 - 7*x*y + y**2 + 2*x - 2*y - 14) + assert is_pell_transformation_ok(-17*x**2 + 19*x*y - 7*y**2 - 5*x - 13*y - 23) + assert is_pell_transformation_ok(x**2 - y**2 + 17) + assert is_pell_transformation_ok(-x**2 + 7*y**2 - 23) + assert is_pell_transformation_ok(25*x**2 - 45*x*y + 5*y**2 - 5*x - 10*y + 5) + assert is_pell_transformation_ok(190*x**2 + 30*x*y + y**2 - 3*y - 170*x - 130) + assert is_pell_transformation_ok(x**2 - 2*x*y -190*y**2 - 7*y - 23*x - 89) + assert is_pell_transformation_ok(15*x**2 - 9*x*y + 14*y**2 - 23*x - 14*y - 4950) + + +def test_find_DN(): + assert find_DN(x**2 - 2*x - y**2) == (1, 1) + assert find_DN(x**2 - 3*y**2 - 5) == (3, 5) + assert find_DN(x**2 - 2*x*y - 4*y**2 - 7) == (5, 7) + assert find_DN(4*x**2 - 8*x*y - y**2 - 9) == (20, 36) + assert find_DN(7*x**2 - 2*x*y - y**2 - 12) == (8, 84) + assert find_DN(-3*x**2 + 4*x*y -y**2) == (1, 0) + assert find_DN(-13*x**2 - 7*x*y + y**2 + 2*x - 2*y -14) == (101, -7825480) + + +def test_ldescent(): + # Equations which have solutions + u = ([(13, 23), (3, -11), (41, -113), (4, -7), (-7, 4), (91, -3), (1, 1), (1, -1), + (4, 32), (17, 13), (123689, 1), (19, -570)]) + for a, b in u: + w, x, y = ldescent(a, b) + assert a*x**2 + b*y**2 == w**2 + assert ldescent(-1, -1) is None + assert ldescent(2, 6) is None + + +def test_diop_ternary_quadratic_normal(): + assert check_solutions(234*x**2 - 65601*y**2 - z**2) + assert check_solutions(23*x**2 + 616*y**2 - z**2) + assert check_solutions(5*x**2 + 4*y**2 - z**2) + assert check_solutions(3*x**2 + 6*y**2 - 3*z**2) + assert check_solutions(x**2 + 3*y**2 - z**2) + assert check_solutions(4*x**2 + 5*y**2 - z**2) + assert check_solutions(x**2 + y**2 - z**2) + assert check_solutions(16*x**2 + y**2 - 25*z**2) + assert check_solutions(6*x**2 - y**2 + 10*z**2) + assert check_solutions(213*x**2 + 12*y**2 - 9*z**2) + assert check_solutions(34*x**2 - 3*y**2 - 301*z**2) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + + +def is_normal_transformation_ok(eq): + A = transformation_to_normal(eq) + X, Y, Z = A*Matrix([x, y, z]) + simplified = diop_simplify(eq.subs(zip((x, y, z), (X, Y, Z)))) + + coeff = dict([reversed(t.as_independent(*[X, Y, Z])) for t in simplified.args]) + for term in [X*Y, Y*Z, X*Z]: + if term in coeff.keys(): + return False + + return True + + +def test_transformation_to_normal(): + assert is_normal_transformation_ok(x**2 + 3*y**2 + z**2 - 13*x*y - 16*y*z + 12*x*z) + assert is_normal_transformation_ok(x**2 + 3*y**2 - 100*z**2) + assert is_normal_transformation_ok(x**2 + 23*y*z) + assert is_normal_transformation_ok(3*y**2 - 100*z**2 - 12*x*y) + assert is_normal_transformation_ok(x**2 + 23*x*y - 34*y*z + 12*x*z) + assert is_normal_transformation_ok(z**2 + 34*x*y - 23*y*z + x*z) + assert is_normal_transformation_ok(x**2 + y**2 + z**2 - x*y - y*z - x*z) + assert is_normal_transformation_ok(x**2 + 2*y*z + 3*z**2) + assert is_normal_transformation_ok(x*y + 2*x*z + 3*y*z) + assert is_normal_transformation_ok(2*x*z + 3*y*z) + + +def test_diop_ternary_quadratic(): + assert check_solutions(2*x**2 + z**2 + y**2 - 4*x*y) + assert check_solutions(x**2 - y**2 - z**2 - x*y - y*z) + assert check_solutions(3*x**2 - x*y - y*z - x*z) + assert check_solutions(x**2 - y*z - x*z) + assert check_solutions(5*x**2 - 3*x*y - x*z) + assert check_solutions(4*x**2 - 5*y**2 - x*z) + assert check_solutions(3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) + assert check_solutions(8*x**2 - 12*y*z) + assert check_solutions(45*x**2 - 7*y**2 - 8*x*y - z**2) + assert check_solutions(x**2 - 49*y**2 - z**2 + 13*z*y -8*x*y) + assert check_solutions(90*x**2 + 3*y**2 + 5*x*y + 2*z*y + 5*x*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - x*y - 17*y*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - x*y - 16*y*z + 12*x*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - 13*x*y - 16*y*z + 12*x*z) + assert check_solutions(x*y - 7*y*z + 13*x*z) + + assert diop_ternary_quadratic_normal(x**2 + y**2 + z**2) == (None, None, None) + assert diop_ternary_quadratic_normal(x**2 + y**2) is None + raises(ValueError, lambda: + _diop_ternary_quadratic_normal((x, y, z), + {x*y: 1, x**2: 2, y**2: 3, z**2: 0})) + eq = -2*x*y - 6*x*z + 7*y**2 - 3*y*z + 4*z**2 + assert diop_ternary_quadratic(eq) == (7, 2, 0) + assert diop_ternary_quadratic_normal(4*x**2 + 5*y**2 - z**2) == \ + (1, 0, 2) + assert diop_ternary_quadratic(x*y + 2*y*z) == \ + (-2, 0, n1) + eq = -5*x*y - 8*x*z - 3*y*z + 8*z**2 + assert parametrize_ternary_quadratic(eq) == \ + (8*p**2 - 3*p*q, -8*p*q + 8*q**2, 5*p*q) + # this cannot be tested with diophantine because it will + # factor into a product + assert diop_solve(x*y + 2*y*z) == (-2*p*q, -n1*p**2 + p**2, p*q) + + +def test_square_factor(): + assert square_factor(1) == square_factor(-1) == 1 + assert square_factor(0) == 1 + assert square_factor(5) == square_factor(-5) == 1 + assert square_factor(4) == square_factor(-4) == 2 + assert square_factor(12) == square_factor(-12) == 2 + assert square_factor(6) == 1 + assert square_factor(18) == 3 + assert square_factor(52) == 2 + assert square_factor(49) == 7 + assert square_factor(392) == 14 + assert square_factor(factorint(-12)) == 2 + + +def test_parametrize_ternary_quadratic(): + assert check_solutions(x**2 + y**2 - z**2) + assert check_solutions(x**2 + 2*x*y + z**2) + assert check_solutions(234*x**2 - 65601*y**2 - z**2) + assert check_solutions(3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) + assert check_solutions(x**2 - y**2 - z**2) + assert check_solutions(x**2 - 49*y**2 - z**2 + 13*z*y - 8*x*y) + assert check_solutions(8*x*y + z**2) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + assert check_solutions(236*x**2 - 225*y**2 - 11*x*y - 13*y*z - 17*x*z) + assert check_solutions(90*x**2 + 3*y**2 + 5*x*y + 2*z*y + 5*x*z) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + + +def test_no_square_ternary_quadratic(): + assert check_solutions(2*x*y + y*z - 3*x*z) + assert check_solutions(189*x*y - 345*y*z - 12*x*z) + assert check_solutions(23*x*y + 34*y*z) + assert check_solutions(x*y + y*z + z*x) + assert check_solutions(23*x*y + 23*y*z + 23*x*z) + + +def test_descent(): + + u = ([(13, 23), (3, -11), (41, -113), (91, -3), (1, 1), (1, -1), (17, 13), (123689, 1), (19, -570)]) + for a, b in u: + w, x, y = descent(a, b) + assert a*x**2 + b*y**2 == w**2 + # the docstring warns against bad input, so these are expected results + # - can't both be negative + raises(TypeError, lambda: descent(-1, -3)) + # A can't be zero unless B != 1 + raises(ZeroDivisionError, lambda: descent(0, 3)) + # supposed to be square-free + raises(TypeError, lambda: descent(4, 3)) + + +def test_diophantine(): + assert check_solutions((x - y)*(y - z)*(z - x)) + assert check_solutions((x - y)*(x**2 + y**2 - z**2)) + assert check_solutions((x - 3*y + 7*z)*(x**2 + y**2 - z**2)) + assert check_solutions(x**2 - 3*y**2 - 1) + assert check_solutions(y**2 + 7*x*y) + assert check_solutions(x**2 - 3*x*y + y**2) + assert check_solutions(z*(x**2 - y**2 - 15)) + assert check_solutions(x*(2*y - 2*z + 5)) + assert check_solutions((x**2 - 3*y**2 - 1)*(x**2 - y**2 - 15)) + assert check_solutions((x**2 - 3*y**2 - 1)*(y - 7*z)) + assert check_solutions((x**2 + y**2 - z**2)*(x - 7*y - 3*z + 4*w)) + # Following test case caused problems in parametric representation + # But this can be solved by factoring out y. + # No need to use methods for ternary quadratic equations. + assert check_solutions(y**2 - 7*x*y + 4*y*z) + assert check_solutions(x**2 - 2*x + 1) + + assert diophantine(x - y) == diophantine(Eq(x, y)) + # 18196 + eq = x**4 + y**4 - 97 + assert diophantine(eq, permute=True) == diophantine(-eq, permute=True) + assert diophantine(3*x*pi - 2*y*pi) == {(2*t_0, 3*t_0)} + eq = x**2 + y**2 + z**2 - 14 + base_sol = {(1, 2, 3)} + assert diophantine(eq) == base_sol + complete_soln = set(signed_permutations(base_sol.pop())) + assert diophantine(eq, permute=True) == complete_soln + + assert diophantine(x**2 + x*Rational(15, 14) - 3) == set() + # test issue 11049 + eq = 92*x**2 - 99*y**2 - z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(9, 7, 51)} + assert diophantine(eq) == {( + 891*p**2 + 9*q**2, -693*p**2 - 102*p*q + 7*q**2, + 5049*p**2 - 1386*p*q - 51*q**2)} + eq = 2*x**2 + 2*y**2 - z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(1, 1, 2)} + assert diophantine(eq) == {( + 2*p**2 - q**2, -2*p**2 + 4*p*q - q**2, + 4*p**2 - 4*p*q + 2*q**2)} + eq = 411*x**2+57*y**2-221*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(2021, 2645, 3066)} + assert diophantine(eq) == \ + {(115197*p**2 - 446641*q**2, -150765*p**2 + 1355172*p*q - + 584545*q**2, 174762*p**2 - 301530*p*q + 677586*q**2)} + eq = 573*x**2+267*y**2-984*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(49, 233, 127)} + assert diophantine(eq) == \ + {(4361*p**2 - 16072*q**2, -20737*p**2 + 83312*p*q - 76424*q**2, + 11303*p**2 - 41474*p*q + 41656*q**2)} + # this produces factors during reconstruction + eq = x**2 + 3*y**2 - 12*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(0, 2, 1)} + assert diophantine(eq) == \ + {(24*p*q, 2*p**2 - 24*q**2, p**2 + 12*q**2)} + # solvers have not been written for every type + raises(NotImplementedError, lambda: diophantine(x*y**2 + 1)) + + # rational expressions + assert diophantine(1/x) == set() + assert diophantine(1/x + 1/y - S.Half) == {(6, 3), (-2, 1), (4, 4), (1, -2), (3, 6)} + assert diophantine(x**2 + y**2 +3*x- 5, permute=True) == \ + {(-1, 1), (-4, -1), (1, -1), (1, 1), (-4, 1), (-1, -1), (4, 1), (4, -1)} + + + #test issue 18186 + assert diophantine(y**4 + x**4 - 2**4 - 3**4, syms=(x, y), permute=True) == \ + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + assert diophantine(y**4 + x**4 - 2**4 - 3**4, syms=(y, x), permute=True) == \ + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + + # issue 18122 + assert check_solutions(x**2 - y) + assert check_solutions(y**2 - x) + assert diophantine((x**2 - y), t) == {(t, t**2)} + assert diophantine((y**2 - x), t) == {(t**2, t)} + + +def test_general_pythagorean(): + from sympy.abc import a, b, c, d, e + + assert check_solutions(a**2 + b**2 + c**2 - d**2) + assert check_solutions(a**2 + 4*b**2 + 4*c**2 - d**2) + assert check_solutions(9*a**2 + 4*b**2 + 4*c**2 - d**2) + assert check_solutions(9*a**2 + 4*b**2 - 25*d**2 + 4*c**2 ) + assert check_solutions(9*a**2 - 16*d**2 + 4*b**2 + 4*c**2) + assert check_solutions(-e**2 + 9*a**2 + 4*b**2 + 4*c**2 + 25*d**2) + assert check_solutions(16*a**2 - b**2 + 9*c**2 + d**2 + 25*e**2) + + assert GeneralPythagorean(a**2 + b**2 + c**2 - d**2).solve(parameters=[x, y, z]) == \ + {(x**2 + y**2 - z**2, 2*x*z, 2*y*z, x**2 + y**2 + z**2)} + + +def test_diop_general_sum_of_squares_quick(): + for i in range(3, 10): + assert check_solutions(sum(i**2 for i in symbols(':%i' % i)) - i) + + assert diop_general_sum_of_squares(x**2 + y**2 - 2) is None + assert diop_general_sum_of_squares(x**2 + y**2 + z**2 + 2) == set() + eq = x**2 + y**2 + z**2 - (1 + 4 + 9) + assert diop_general_sum_of_squares(eq) == \ + {(1, 2, 3)} + eq = u**2 + v**2 + x**2 + y**2 + z**2 - 1313 + assert len(diop_general_sum_of_squares(eq, 3)) == 3 + # issue 11016 + var = symbols(':5') + (symbols('6', negative=True),) + eq = Add(*[i**2 for i in var]) - 112 + + base_soln = {(0, 1, 1, 5, 6, -7), (1, 1, 1, 3, 6, -8), (2, 3, 3, 4, 5, -7), (0, 1, 1, 1, 3, -10), + (0, 0, 4, 4, 4, -8), (1, 2, 3, 3, 5, -8), (0, 1, 2, 3, 7, -7), (2, 2, 4, 4, 6, -6), + (1, 1, 3, 4, 6, -7), (0, 2, 3, 3, 3, -9), (0, 0, 2, 2, 2, -10), (1, 1, 2, 3, 4, -9), + (0, 1, 1, 2, 5, -9), (0, 0, 2, 6, 6, -6), (1, 3, 4, 5, 5, -6), (0, 2, 2, 2, 6, -8), + (0, 3, 3, 3, 6, -7), (0, 2, 3, 5, 5, -7), (0, 1, 5, 5, 5, -6)} + assert diophantine(eq) == base_soln + assert len(diophantine(eq, permute=True)) == 196800 + + # handle negated squares with signsimp + assert diophantine(12 - x**2 - y**2 - z**2) == {(2, 2, 2)} + # diophantine handles simplification, so classify_diop should + # not have to look for additional patterns that are removed + # by diophantine + eq = a**2 + b**2 + c**2 + d**2 - 4 + raises(NotImplementedError, lambda: classify_diop(-eq)) + + +def test_issue_23807(): + # fixes recursion error + eq = x**2 + y**2 + z**2 - 1000000 + base_soln = {(0, 0, 1000), (0, 352, 936), (480, 600, 640), (24, 640, 768), (192, 640, 744), + (192, 480, 856), (168, 224, 960), (0, 600, 800), (280, 576, 768), (152, 480, 864), + (0, 280, 960), (352, 360, 864), (424, 480, 768), (360, 480, 800), (224, 600, 768), + (96, 360, 928), (168, 576, 800), (96, 480, 872)} + + assert diophantine(eq) == base_soln + + +def test_diop_partition(): + for n in [8, 10]: + for k in range(1, 8): + for p in partition(n, k): + assert len(p) == k + assert list(partition(3, 5)) == [] + assert [list(p) for p in partition(3, 5, 1)] == [ + [0, 0, 0, 0, 3], [0, 0, 0, 1, 2], [0, 0, 1, 1, 1]] + assert list(partition(0)) == [()] + assert list(partition(1, 0)) == [()] + assert [list(i) for i in partition(3)] == [[1, 1, 1], [1, 2], [3]] + + +def test_prime_as_sum_of_two_squares(): + for i in [5, 13, 17, 29, 37, 41, 2341, 3557, 34841, 64601]: + a, b = prime_as_sum_of_two_squares(i) + assert a**2 + b**2 == i + assert prime_as_sum_of_two_squares(7) is None + ans = prime_as_sum_of_two_squares(800029) + assert ans == (450, 773) and type(ans[0]) is int + + +def test_sum_of_three_squares(): + for i in [0, 1, 2, 34, 123, 34304595905, 34304595905394941, 343045959052344, + 800, 801, 802, 803, 804, 805, 806]: + a, b, c = sum_of_three_squares(i) + assert a**2 + b**2 + c**2 == i + assert a >= 0 + + # error + raises(ValueError, lambda: sum_of_three_squares(-1)) + + assert sum_of_three_squares(7) is None + assert sum_of_three_squares((4**5)*15) is None + # if there are two zeros, there might be a solution + # with only one zero, e.g. 25 => (0, 3, 4) or + # with no zeros, e.g. 49 => (2, 3, 6) + assert sum_of_three_squares(25) == (0, 0, 5) + assert sum_of_three_squares(4) == (0, 0, 2) + + +def test_sum_of_four_squares(): + from sympy.core.random import randint + + # this should never fail + n = randint(1, 100000000000000) + assert sum(i**2 for i in sum_of_four_squares(n)) == n + + # error + raises(ValueError, lambda: sum_of_four_squares(-1)) + + for n in range(1000): + result = sum_of_four_squares(n) + assert len(result) == 4 + assert all(r >= 0 for r in result) + assert sum(r**2 for r in result) == n + assert list(result) == sorted(result) + + +def test_power_representation(): + tests = [(1729, 3, 2), (234, 2, 4), (2, 1, 2), (3, 1, 3), (5, 2, 2), (12352, 2, 4), + (32760, 2, 3)] + + for test in tests: + n, p, k = test + f = power_representation(n, p, k) + + while True: + try: + l = next(f) + assert len(l) == k + + chk_sum = 0 + for l_i in l: + chk_sum = chk_sum + l_i**p + assert chk_sum == n + + except StopIteration: + break + + assert list(power_representation(20, 2, 4, True)) == \ + [(1, 1, 3, 3), (0, 0, 2, 4)] + raises(ValueError, lambda: list(power_representation(1.2, 2, 2))) + raises(ValueError, lambda: list(power_representation(2, 0, 2))) + raises(ValueError, lambda: list(power_representation(2, 2, 0))) + assert list(power_representation(-1, 2, 2)) == [] + assert list(power_representation(1, 1, 1)) == [(1,)] + assert list(power_representation(3, 2, 1)) == [] + assert list(power_representation(4, 2, 1)) == [(2,)] + assert list(power_representation(3**4, 4, 6, zeros=True)) == \ + [(1, 2, 2, 2, 2, 2), (0, 0, 0, 0, 0, 3)] + assert list(power_representation(3**4, 4, 5, zeros=False)) == [] + assert list(power_representation(-2, 3, 2)) == [(-1, -1)] + assert list(power_representation(-2, 4, 2)) == [] + assert list(power_representation(0, 3, 2, True)) == [(0, 0)] + assert list(power_representation(0, 3, 2, False)) == [] + # when we are dealing with squares, do feasibility checks + assert len(list(power_representation(4**10*(8*10 + 7), 2, 3))) == 0 + # there will be a recursion error if these aren't recognized + big = 2**30 + for i in [13, 10, 7, 5, 4, 2, 1]: + assert list(sum_of_powers(big, 2, big - i)) == [] + + +def test_assumptions(): + """ + Test whether diophantine respects the assumptions. + """ + #Test case taken from the below so question regarding assumptions in diophantine module + #https://stackoverflow.com/questions/23301941/how-can-i-declare-natural-symbols-with-sympy + m, n = symbols('m n', integer=True, positive=True) + diof = diophantine(n**2 + m*n - 500) + assert diof == {(5, 20), (40, 10), (95, 5), (121, 4), (248, 2), (499, 1)} + + a, b = symbols('a b', integer=True, positive=False) + diof = diophantine(a*b + 2*a + 3*b - 6) + assert diof == {(-15, -3), (-9, -4), (-7, -5), (-6, -6), (-5, -8), (-4, -14)} + + x, y = symbols('x y', integer=True) + diof = diophantine(10*x**2 + 5*x*y - 3*y) + assert diof == {(1, -5), (-3, 5), (0, 0)} + + x, y = symbols('x y', integer=True, positive=True) + diof = diophantine(10*x**2 + 5*x*y - 3*y) + assert diof == set() + + x, y = symbols('x y', integer=True, negative=False) + diof = diophantine(10*x**2 + 5*x*y - 3*y) + assert diof == {(0, 0)} + + +def check_solutions(eq): + """ + Determines whether solutions returned by diophantine() satisfy the original + equation. Hope to generalize this so we can remove functions like check_ternay_quadratic, + check_solutions_normal, check_solutions() + """ + s = diophantine(eq) + + factors = Mul.make_args(eq) + + var = list(eq.free_symbols) + var.sort(key=default_sort_key) + + while s: + solution = s.pop() + for f in factors: + if diop_simplify(f.subs(zip(var, solution))) == 0: + break + else: + return False + return True + + +def test_diopcoverage(): + eq = (2*x + y + 1)**2 + assert diop_solve(eq) == {(t_0, -2*t_0 - 1)} + eq = 2*x**2 + 6*x*y + 12*x + 4*y**2 + 18*y + 18 + assert diop_solve(eq) == {(t, -t - 3), (-2*t - 3, t)} + assert diop_quadratic(x + y**2 - 3) == {(-t**2 + 3, t)} + + assert diop_linear(x + y - 3) == (t_0, 3 - t_0) + + assert base_solution_linear(0, 1, 2, t=None) == (0, 0) + ans = (3*t - 1, -2*t + 1) + assert base_solution_linear(4, 8, 12, t) == ans + assert base_solution_linear(4, 8, 12, t=None) == tuple(_.subs(t, 0) for _ in ans) + + assert cornacchia(1, 1, 20) == set() + assert cornacchia(1, 1, 5) == {(2, 1)} + assert cornacchia(1, 2, 17) == {(3, 2)} + + raises(ValueError, lambda: reconstruct(4, 20, 1)) + + assert gaussian_reduce(4, 1, 3) == (1, 1) + eq = -w**2 - x**2 - y**2 + z**2 + + assert diop_general_pythagorean(eq) == \ + diop_general_pythagorean(-eq) == \ + (m1**2 + m2**2 - m3**2, 2*m1*m3, + 2*m2*m3, m1**2 + m2**2 + m3**2) + + assert len(check_param(S(3) + x/3, S(4) + x/2, S(2), [x])) == 0 + assert len(check_param(Rational(3, 2), S(4) + x, S(2), [x])) == 0 + assert len(check_param(S(4) + x, Rational(3, 2), S(2), [x])) == 0 + + assert _nint_or_floor(16, 10) == 2 + assert _odd(1) == (not _even(1)) == True + assert _odd(0) == (not _even(0)) == False + assert _remove_gcd(2, 4, 6) == (1, 2, 3) + raises(TypeError, lambda: _remove_gcd((2, 4, 6))) + assert sqf_normal(2*3**2*5, 2*5*11, 2*7**2*11) == \ + (11, 1, 5) + + # it's ok if these pass some day when the solvers are implemented + raises(NotImplementedError, lambda: diophantine(x**2 + y**2 + x*y + 2*y*z - 12)) + raises(NotImplementedError, lambda: diophantine(x**3 + y**2)) + assert diop_quadratic(x**2 + y**2 - 1**2 - 3**4) == \ + {(-9, -1), (-9, 1), (-1, -9), (-1, 9), (1, -9), (1, 9), (9, -1), (9, 1)} + + +def test_holzer(): + # if the input is good, don't let it diverge in holzer() + # (but see test_fail_holzer below) + assert holzer(2, 7, 13, 4, 79, 23) == (2, 7, 13) + + # None in uv condition met; solution is not Holzer reduced + # so this will hopefully change but is here for coverage + assert holzer(2, 6, 2, 1, 1, 10) == (2, 6, 2) + + raises(ValueError, lambda: holzer(2, 7, 14, 4, 79, 23)) + + +@XFAIL +def test_fail_holzer(): + eq = lambda x, y, z: a*x**2 + b*y**2 - c*z**2 + a, b, c = 4, 79, 23 + x, y, z = xyz = 26, 1, 11 + X, Y, Z = ans = 2, 7, 13 + assert eq(*xyz) == 0 + assert eq(*ans) == 0 + assert max(a*x**2, b*y**2, c*z**2) <= a*b*c + assert max(a*X**2, b*Y**2, c*Z**2) <= a*b*c + h = holzer(x, y, z, a, b, c) + assert h == ans # it would be nice to get the smaller soln + + +def test_issue_9539(): + assert diophantine(6*w + 9*y + 20*x - z) == \ + {(t_0, t_1, t_1 + t_2, 6*t_0 + 29*t_1 + 9*t_2)} + + +def test_issue_8943(): + assert diophantine( + 3*(x**2 + y**2 + z**2) - 14*(x*y + y*z + z*x)) == \ + {(0, 0, 0)} + + +def test_diop_sum_of_even_powers(): + eq = x**4 + y**4 + z**4 - 2673 + assert diop_solve(eq) == {(3, 6, 6), (2, 4, 7)} + assert diop_general_sum_of_even_powers(eq, 2) == {(3, 6, 6), (2, 4, 7)} + raises(NotImplementedError, lambda: diop_general_sum_of_even_powers(-eq, 2)) + neg = symbols('neg', negative=True) + eq = x**4 + y**4 + neg**4 - 2673 + assert diop_general_sum_of_even_powers(eq) == {(-3, 6, 6)} + assert diophantine(x**4 + y**4 + 2) == set() + assert diop_general_sum_of_even_powers(x**4 + y**4 - 2, limit=0) == set() + + +def test_sum_of_squares_powers(): + tru = {(0, 0, 1, 1, 11), (0, 0, 5, 7, 7), (0, 1, 3, 7, 8), (0, 1, 4, 5, 9), (0, 3, 4, 7, 7), (0, 3, 5, 5, 8), + (1, 1, 2, 6, 9), (1, 1, 6, 6, 7), (1, 2, 3, 3, 10), (1, 3, 4, 4, 9), (1, 5, 5, 6, 6), (2, 2, 3, 5, 9), + (2, 3, 5, 6, 7), (3, 3, 4, 5, 8)} + eq = u**2 + v**2 + x**2 + y**2 + z**2 - 123 + ans = diop_general_sum_of_squares(eq, oo) # allow oo to be used + assert len(ans) == 14 + assert ans == tru + + raises(ValueError, lambda: list(sum_of_squares(10, -1))) + assert list(sum_of_squares(1, 1)) == [(1,)] + assert list(sum_of_squares(1, 2)) == [] + assert list(sum_of_squares(1, 2, True)) == [(0, 1)] + assert list(sum_of_squares(-10, 2)) == [] + assert list(sum_of_squares(2, 3)) == [] + assert list(sum_of_squares(0, 3, True)) == [(0, 0, 0)] + assert list(sum_of_squares(0, 3)) == [] + assert list(sum_of_squares(4, 1)) == [(2,)] + assert list(sum_of_squares(5, 1)) == [] + assert list(sum_of_squares(50, 2)) == [(5, 5), (1, 7)] + assert list(sum_of_squares(11, 5, True)) == [ + (1, 1, 1, 2, 2), (0, 0, 1, 1, 3)] + assert list(sum_of_squares(8, 8)) == [(1, 1, 1, 1, 1, 1, 1, 1)] + + assert [len(list(sum_of_squares(i, 5, True))) for i in range(30)] == [ + 1, 1, 1, 1, 2, + 2, 1, 1, 2, 2, + 2, 2, 2, 3, 2, + 1, 3, 3, 3, 3, + 4, 3, 3, 2, 2, + 4, 4, 4, 4, 5] + assert [len(list(sum_of_squares(i, 5))) for i in range(30)] == [ + 0, 0, 0, 0, 0, + 1, 0, 0, 1, 0, + 0, 1, 0, 1, 1, + 0, 1, 1, 0, 1, + 2, 1, 1, 1, 1, + 1, 1, 1, 1, 3] + for i in range(30): + s1 = set(sum_of_squares(i, 5, True)) + assert not s1 or all(sum(j**2 for j in t) == i for t in s1) + s2 = set(sum_of_squares(i, 5)) + assert all(sum(j**2 for j in t) == i for t in s2) + + raises(ValueError, lambda: list(sum_of_powers(2, -1, 1))) + raises(ValueError, lambda: list(sum_of_powers(2, 1, -1))) + assert list(sum_of_powers(-2, 3, 2)) == [(-1, -1)] + assert list(sum_of_powers(-2, 4, 2)) == [] + assert list(sum_of_powers(2, 1, 1)) == [(2,)] + assert list(sum_of_powers(2, 1, 3, True)) == [(0, 0, 2), (0, 1, 1)] + assert list(sum_of_powers(5, 1, 2, True)) == [(0, 5), (1, 4), (2, 3)] + assert list(sum_of_powers(6, 2, 2)) == [] + assert list(sum_of_powers(3**5, 3, 1)) == [] + assert list(sum_of_powers(3**6, 3, 1)) == [(9,)] and (9**3 == 3**6) + assert list(sum_of_powers(2**1000, 5, 2)) == [] + + +def test__can_do_sum_of_squares(): + assert _can_do_sum_of_squares(3, -1) is False + assert _can_do_sum_of_squares(-3, 1) is False + assert _can_do_sum_of_squares(0, 1) + assert _can_do_sum_of_squares(4, 1) + assert _can_do_sum_of_squares(1, 2) + assert _can_do_sum_of_squares(2, 2) + assert _can_do_sum_of_squares(3, 2) is False + + +def test_diophantine_permute_sign(): + from sympy.abc import a, b, c, d, e + eq = a**4 + b**4 - (2**4 + 3**4) + base_sol = {(2, 3)} + assert diophantine(eq) == base_sol + complete_soln = set(signed_permutations(base_sol.pop())) + assert diophantine(eq, permute=True) == complete_soln + + eq = a**2 + b**2 + c**2 + d**2 + e**2 - 234 + assert len(diophantine(eq)) == 35 + assert len(diophantine(eq, permute=True)) == 62000 + soln = {(-1, -1), (-1, 2), (1, -2), (1, 1)} + assert diophantine(10*x**2 + 12*x*y + 12*y**2 - 34, permute=True) == soln + + +@XFAIL +def test_not_implemented(): + eq = x**2 + y**4 - 1**2 - 3**4 + assert diophantine(eq, syms=[x, y]) == {(9, 1), (1, 3)} + + +def test_issue_9538(): + eq = x - 3*y + 2 + assert diophantine(eq, syms=[y,x]) == {(t_0, 3*t_0 - 2)} + raises(TypeError, lambda: diophantine(eq, syms={y, x})) + + +def test_ternary_quadratic(): + # solution with 3 parameters + s = diophantine(2*x**2 + y**2 - 2*z**2) + p, q, r = ordered(S(s).free_symbols) + assert s == {( + p**2 - 2*q**2, + -2*p**2 + 4*p*q - 4*p*r - 4*q**2, + p**2 - 4*p*q + 2*q**2 - 4*q*r)} + # solution with Mul in solution + s = diophantine(x**2 + 2*y**2 - 2*z**2) + assert s == {(4*p*q, p**2 - 2*q**2, p**2 + 2*q**2)} + # solution with no Mul in solution + s = diophantine(2*x**2 + 2*y**2 - z**2) + assert s == {(2*p**2 - q**2, -2*p**2 + 4*p*q - q**2, + 4*p**2 - 4*p*q + 2*q**2)} + # reduced form when parametrized + s = diophantine(3*x**2 + 72*y**2 - 27*z**2) + assert s == {(24*p**2 - 9*q**2, 6*p*q, 8*p**2 + 3*q**2)} + assert parametrize_ternary_quadratic( + 3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) == ( + 2*p**2 - 2*p*q - q**2, 2*p**2 + 2*p*q - q**2, 2*p**2 - + 2*p*q + 3*q**2) + assert parametrize_ternary_quadratic( + 124*x**2 - 30*y**2 - 7729*z**2) == ( + -1410*p**2 - 363263*q**2, 2700*p**2 + 30916*p*q - + 695610*q**2, -60*p**2 + 5400*p*q + 15458*q**2) + + +def test_diophantine_solution_set(): + s1 = DiophantineSolutionSet([], []) + assert set(s1) == set() + assert s1.symbols == () + assert s1.parameters == () + raises(ValueError, lambda: s1.add((x,))) + assert list(s1.dict_iterator()) == [] + + s2 = DiophantineSolutionSet([x, y], [t, u]) + assert s2.symbols == (x, y) + assert s2.parameters == (t, u) + raises(ValueError, lambda: s2.add((1,))) + s2.add((3, 4)) + assert set(s2) == {(3, 4)} + s2.update((3, 4), (-1, u)) + assert set(s2) == {(3, 4), (-1, u)} + raises(ValueError, lambda: s1.update(s2)) + assert list(s2.dict_iterator()) == [{x: -1, y: u}, {x: 3, y: 4}] + + s3 = DiophantineSolutionSet([x, y, z], [t, u]) + assert len(s3.parameters) == 2 + s3.add((t**2 + u, t - u, 1)) + assert set(s3) == {(t**2 + u, t - u, 1)} + assert s3.subs(t, 2) == {(u + 4, 2 - u, 1)} + assert s3(2) == {(u + 4, 2 - u, 1)} + assert s3.subs({t: 7, u: 8}) == {(57, -1, 1)} + assert s3(7, 8) == {(57, -1, 1)} + assert s3.subs({t: 5}) == {(u + 25, 5 - u, 1)} + assert s3(5) == {(u + 25, 5 - u, 1)} + assert s3.subs(u, -3) == {(t**2 - 3, t + 3, 1)} + assert s3(None, -3) == {(t**2 - 3, t + 3, 1)} + assert s3.subs({t: 2, u: 8}) == {(12, -6, 1)} + assert s3(2, 8) == {(12, -6, 1)} + assert s3.subs({t: 5, u: -3}) == {(22, 8, 1)} + assert s3(5, -3) == {(22, 8, 1)} + raises(TypeError, lambda: s3.subs(x=1)) + raises(TypeError, lambda: s3.subs(1, 2, 3)) + raises(ValueError, lambda: s3.add(())) + raises(ValueError, lambda: s3.add((1, 2, 3, 4))) + raises(ValueError, lambda: s3.add((1, 2))) + raises(ValueError, lambda: s3(1, 2, 3)) + raises(TypeError, lambda: s3(t=1)) + + s4 = DiophantineSolutionSet([x, y], [t, u]) + s4.add((t, 11*t)) + s4.add((-t, 22*t)) + assert s4(0, 0) == {(0, 0)} + + +def test_quadratic_parameter_passing(): + eq = -33*x*y + 3*y**2 + solution = BinaryQuadratic(eq).solve(parameters=[t, u]) + # test that parameters are passed all the way to the final solution + assert solution == {(t, 11*t), (t, -22*t)} + assert solution(0, 0) == {(0, 0)} + +def test_issue_18628(): + eq1 = x**2 - 15*x + y**2 - 8*y + sol = diophantine(eq1) + assert sol == {(15, 0), (15, 8), (-1, 4), (0, 0), (0, 8), (16, 4)} + eq2 = 2*x**2 - 9*x + 4*y**2 - 8*y + 14 + sol = diophantine(eq2) + assert sol == {(2, 1)} diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/__init__.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b543425251dea6380a1860279cb6d636f3dd629 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/__init__.py @@ -0,0 +1,16 @@ +from .ode import (allhints, checkinfsol, classify_ode, + constantsimp, dsolve, homogeneous_order) + +from .lie_group import infinitesimals + +from .subscheck import checkodesol + +from .systems import (canonical_odes, linear_ode_to_matrix, + linodesolve) + + +__all__ = [ + 'allhints', 'checkinfsol', 'checkodesol', 'classify_ode', 'constantsimp', + 'dsolve', 'homogeneous_order', 'infinitesimals', 'canonical_odes', 'linear_ode_to_matrix', + 'linodesolve' +] diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/hypergeometric.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/hypergeometric.py new file mode 100644 index 0000000000000000000000000000000000000000..5699d2418058acaf48d1e4f87f6635a7b1f7284c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/hypergeometric.py @@ -0,0 +1,272 @@ +r''' +This module contains the implementation of the 2nd_hypergeometric hint for +dsolve. This is an incomplete implementation of the algorithm described in [1]. +The algorithm solves 2nd order linear ODEs of the form + +.. math:: y'' + A(x) y' + B(x) y = 0\text{,} + +where `A` and `B` are rational functions. The algorithm should find any +solution of the form + +.. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,} + +where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function". +Currently only the 2F1 case is implemented in SymPy but the other cases are +described in the paper and could be implemented in future (contributions +welcome!). + +References +========== + +.. [1] L. Chan, E.S. Cheb-Terrab, Non-Liouvillian solutions for second order + linear ODEs, (2004). + https://arxiv.org/abs/math-ph/0402063 +''' + +from sympy.core import S, Pow +from sympy.core.function import expand +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol, Wild +from sympy.functions import exp, sqrt, hyper +from sympy.integrals import Integral +from sympy.polys import roots, gcd +from sympy.polys.polytools import cancel, factor +from sympy.simplify import collect, simplify, logcombine # type: ignore +from sympy.simplify.powsimp import powdenest +from sympy.solvers.ode.ode import get_numbered_constants + + +def match_2nd_hypergeometric(eq, func): + x = func.args[0] + df = func.diff(x) + a3 = Wild('a3', exclude=[func, func.diff(x), func.diff(x, 2)]) + b3 = Wild('b3', exclude=[func, func.diff(x), func.diff(x, 2)]) + c3 = Wild('c3', exclude=[func, func.diff(x), func.diff(x, 2)]) + deq = a3*(func.diff(x, 2)) + b3*df + c3*func + r = collect(eq, + [func.diff(x, 2), func.diff(x), func]).match(deq) + if r: + if not all(val.is_polynomial() for val in r.values()): + n, d = eq.as_numer_denom() + eq = expand(n) + r = collect(eq, [func.diff(x, 2), func.diff(x), func]).match(deq) + + if r and r[a3]!=0: + A = cancel(r[b3]/r[a3]) + B = cancel(r[c3]/r[a3]) + return [A, B] + else: + return [] + + +def equivalence_hypergeometric(A, B, func): + # This method for finding the equivalence is only for 2F1 type. + # We can extend it for 1F1 and 0F1 type also. + x = func.args[0] + + # making given equation in normal form + I1 = factor(cancel(A.diff(x)/2 + A**2/4 - B)) + + # computing shifted invariant(J1) of the equation + J1 = factor(cancel(x**2*I1 + S(1)/4)) + num, dem = J1.as_numer_denom() + num = powdenest(expand(num)) + dem = powdenest(expand(dem)) + # this function will compute the different powers of variable(x) in J1. + # then it will help in finding value of k. k is power of x such that we can express + # J1 = x**k * J0(x**k) then all the powers in J0 become integers. + def _power_counting(num): + _pow = {0} + for val in num: + if val.has(x): + if isinstance(val, Pow) and val.as_base_exp()[0] == x: + _pow.add(val.as_base_exp()[1]) + elif val == x: + _pow.add(val.as_base_exp()[1]) + else: + _pow.update(_power_counting(val.args)) + return _pow + + pow_num = _power_counting((num, )) + pow_dem = _power_counting((dem, )) + pow_dem.update(pow_num) + + _pow = pow_dem + k = gcd(_pow) + + # computing I0 of the given equation + I0 = powdenest(simplify(factor(((J1/k**2) - S(1)/4)/((x**k)**2))), force=True) + I0 = factor(cancel(powdenest(I0.subs(x, x**(S(1)/k)), force=True))) + + # Before this point I0, J1 might be functions of e.g. sqrt(x) but replacing + # x with x**(1/k) should result in I0 being a rational function of x or + # otherwise the hypergeometric solver cannot be used. Note that k can be a + # non-integer rational such as 2/7. + if not I0.is_rational_function(x): + return None + + num, dem = I0.as_numer_denom() + + max_num_pow = max(_power_counting((num, ))) + dem_args = dem.args + sing_point = [] + dem_pow = [] + # calculating singular point of I0. + for arg in dem_args: + if arg.has(x): + if isinstance(arg, Pow): + # (x-a)**n + dem_pow.append(arg.as_base_exp()[1]) + sing_point.append(list(roots(arg.as_base_exp()[0], x).keys())[0]) + else: + # (x-a) type + dem_pow.append(arg.as_base_exp()[1]) + sing_point.append(list(roots(arg, x).keys())[0]) + + dem_pow.sort() + # checking if equivalence is exists or not. + + if equivalence(max_num_pow, dem_pow) == "2F1": + return {'I0':I0, 'k':k, 'sing_point':sing_point, 'type':"2F1"} + else: + return None + + +def match_2nd_2F1_hypergeometric(I, k, sing_point, func): + x = func.args[0] + a = Wild("a") + b = Wild("b") + c = Wild("c") + t = Wild("t") + s = Wild("s") + r = Wild("r") + alpha = Wild("alpha") + beta = Wild("beta") + gamma = Wild("gamma") + delta = Wild("delta") + # I0 of the standard 2F1 equation. + I0 = ((a-b+1)*(a-b-1)*x**2 + 2*((1-a-b)*c + 2*a*b)*x + c*(c-2))/(4*x**2*(x-1)**2) + if sing_point != [0, 1]: + # If singular point is [0, 1] then we have standard equation. + eqs = [] + sing_eqs = [-beta/alpha, -delta/gamma, (delta-beta)/(alpha-gamma)] + # making equations for the finding the mobius transformation + for i in range(3): + if i>> from sympy import Function, Eq, pprint + >>> from sympy.abc import x, y + >>> xi, eta, h = map(Function, ['xi', 'eta', 'h']) + >>> h = h(x, y) # dy/dx = h + >>> eta = eta(x, y) + >>> xi = xi(x, y) + >>> genform = Eq(eta.diff(x) + (eta.diff(y) - xi.diff(x))*h + ... - (xi.diff(y))*h**2 - xi*(h.diff(x)) - eta*(h.diff(y)), 0) + >>> pprint(genform) + /d d \ d 2 d d d + |--(eta(x, y)) - --(xi(x, y))|*h(x, y) - eta(x, y)*--(h(x, y)) - h (x, y)*--(xi(x, y)) - xi(x, y)*--(h(x, y)) + --(eta(x, y)) = 0 + \dy dx / dy dy dx dx + + Solving the above mentioned PDE is not trivial, and can be solved only by + making intelligent assumptions for `\xi` and `\eta` (heuristics). Once an + infinitesimal is found, the attempt to find more heuristics stops. This is done to + optimise the speed of solving the differential equation. If a list of all the + infinitesimals is needed, ``hint`` should be flagged as ``all``, which gives + the complete list of infinitesimals. If the infinitesimals for a particular + heuristic needs to be found, it can be passed as a flag to ``hint``. + + Examples + ======== + + >>> from sympy import Function + >>> from sympy.solvers.ode.lie_group import infinitesimals + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = f(x).diff(x) - x**2*f(x) + >>> infinitesimals(eq) + [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0}] + + References + ========== + + - Solving differential equations by Symmetry Groups, + John Starrett, pp. 1 - pp. 14 + + """ + + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + if not func: + eq, func = _preprocess(eq) + variables = func.args + if len(variables) != 1: + raise ValueError("ODE's have only one independent variable") + else: + x = variables[0] + if not order: + order = ode_order(eq, func) + if order != 1: + raise NotImplementedError("Infinitesimals for only " + "first order ODE's have been implemented") + else: + df = func.diff(x) + # Matching differential equation of the form a*df + b + a = Wild('a', exclude = [df]) + b = Wild('b', exclude = [df]) + if match: # Used by lie_group hint + h = match['h'] + y = match['y'] + else: + match = collect(expand(eq), df).match(a*df + b) + if match: + h = -simplify(match[b]/match[a]) + else: + try: + sol = solve(eq, df) + except NotImplementedError: + raise NotImplementedError("Infinitesimals for the " + "first order ODE could not be found") + else: + h = sol[0] # Find infinitesimals for one solution + y = Dummy("y") + h = h.subs(func, y) + + u = Dummy("u") + hx = h.diff(x) + hy = h.diff(y) + hinv = ((1/h).subs([(x, u), (y, x)])).subs(u, y) # Inverse ODE + match = {'h': h, 'func': func, 'hx': hx, 'hy': hy, 'y': y, 'hinv': hinv} + if hint == 'all': + xieta = [] + for heuristic in lie_heuristics: + function = globals()['lie_heuristic_' + heuristic] + inflist = function(match, comp=True) + if inflist: + xieta.extend([inf for inf in inflist if inf not in xieta]) + if xieta: + return xieta + else: + raise NotImplementedError("Infinitesimals could not be found for " + "the given ODE") + + elif hint == 'default': + for heuristic in lie_heuristics: + function = globals()['lie_heuristic_' + heuristic] + xieta = function(match, comp=False) + if xieta: + return xieta + + raise NotImplementedError("Infinitesimals could not be found for" + " the given ODE") + + elif hint not in lie_heuristics: + raise ValueError("Heuristic not recognized: " + hint) + + else: + function = globals()['lie_heuristic_' + hint] + xieta = function(match, comp=True) + if xieta: + return xieta + else: + raise ValueError("Infinitesimals could not be found using the" + " given heuristic") + + +def lie_heuristic_abaco1_simple(match, comp=False): + r""" + The first heuristic uses the following four sets of + assumptions on `\xi` and `\eta` + + .. math:: \xi = 0, \eta = f(x) + + .. math:: \xi = 0, \eta = f(y) + + .. math:: \xi = f(x), \eta = 0 + + .. math:: \xi = f(y), \eta = 0 + + The success of this heuristic is determined by algebraic factorisation. + For the first assumption `\xi = 0` and `\eta` to be a function of `x`, the PDE + + .. math:: \frac{\partial \eta}{\partial x} + (\frac{\partial \eta}{\partial y} + - \frac{\partial \xi}{\partial x})*h + - \frac{\partial \xi}{\partial y}*h^{2} + - \xi*\frac{\partial h}{\partial x} - \eta*\frac{\partial h}{\partial y} = 0 + + reduces to `f'(x) - f\frac{\partial h}{\partial y} = 0` + If `\frac{\partial h}{\partial y}` is a function of `x`, then this can usually + be integrated easily. A similar idea is applied to the other 3 assumptions as well. + + + References + ========== + + - E.S Cheb-Terrab, L.G.S Duarte and L.A,C.P da Mota, Computer Algebra + Solving of First Order ODEs Using Symmetry Methods, pp. 8 + + + """ + + xieta = [] + y = match['y'] + h = match['h'] + func = match['func'] + x = func.args[0] + hx = match['hx'] + hy = match['hy'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + hysym = hy.free_symbols + if y not in hysym: + try: + fx = exp(integrate(hy, x)) + except NotImplementedError: + pass + else: + inf = {xi: S.Zero, eta: fx} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = hy/h + facsym = factor.free_symbols + if x not in facsym: + try: + fy = exp(integrate(factor, y)) + except NotImplementedError: + pass + else: + inf = {xi: S.Zero, eta: fy.subs(y, func)} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = -hx/h + facsym = factor.free_symbols + if y not in facsym: + try: + fx = exp(integrate(factor, x)) + except NotImplementedError: + pass + else: + inf = {xi: fx, eta: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = -hx/(h**2) + facsym = factor.free_symbols + if x not in facsym: + try: + fy = exp(integrate(factor, y)) + except NotImplementedError: + pass + else: + inf = {xi: fy.subs(y, func), eta: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_abaco1_product(match, comp=False): + r""" + The second heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = 0, \xi = f(x)*g(y) + + .. math:: \eta = f(x)*g(y), \xi = 0 + + The first assumption of this heuristic holds good if + `\frac{1}{h^{2}}\frac{\partial^2}{\partial x \partial y}\log(h)` is + separable in `x` and `y`, then the separated factors containing `x` + is `f(x)`, and `g(y)` is obtained by + + .. math:: e^{\int f\frac{\partial}{\partial x}\left(\frac{1}{f*h}\right)\,dy} + + provided `f\frac{\partial}{\partial x}\left(\frac{1}{f*h}\right)` is a function + of `y` only. + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first assumption + satisfies. After obtaining `f(x)` and `g(y)`, the coordinates are again + interchanged, to get `\eta` as `f(x)*g(y)` + + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 7 - pp. 8 + + """ + + xieta = [] + y = match['y'] + h = match['h'] + hinv = match['hinv'] + func = match['func'] + x = func.args[0] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + + inf = separatevars(((log(h).diff(y)).diff(x))/h**2, dict=True, symbols=[x, y]) + if inf and inf['coeff']: + fx = inf[x] + gy = simplify(fx*((1/(fx*h)).diff(x))) + gysyms = gy.free_symbols + if x not in gysyms: + gy = exp(integrate(gy, y)) + inf = {eta: S.Zero, xi: (fx*gy).subs(y, func)} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + u1 = Dummy("u1") + inf = separatevars(((log(hinv).diff(y)).diff(x))/hinv**2, dict=True, symbols=[x, y]) + if inf and inf['coeff']: + fx = inf[x] + gy = simplify(fx*((1/(fx*hinv)).diff(x))) + gysyms = gy.free_symbols + if x not in gysyms: + gy = exp(integrate(gy, y)) + etaval = fx*gy + etaval = (etaval.subs([(x, u1), (y, x)])).subs(u1, y) + inf = {eta: etaval.subs(y, func), xi: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_bivariate(match, comp=False): + r""" + The third heuristic assumes the infinitesimals `\xi` and `\eta` + to be bi-variate polynomials in `x` and `y`. The assumption made here + for the logic below is that `h` is a rational function in `x` and `y` + though that may not be necessary for the infinitesimals to be + bivariate polynomials. The coefficients of the infinitesimals + are found out by substituting them in the PDE and grouping similar terms + that are polynomials and since they form a linear system, solve and check + for non trivial solutions. The degree of the assumed bivariates + are increased till a certain maximum value. + + References + ========== + - Lie Groups and Differential Equations + pp. 327 - pp. 329 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + if h.is_rational_function(): + # The maximum degree that the infinitesimals can take is + # calculated by this technique. + etax, etay, etad, xix, xiy, xid = symbols("etax etay etad xix xiy xid") + ipde = etax + (etay - xix)*h - xiy*h**2 - xid*hx - etad*hy + num, denom = cancel(ipde).as_numer_denom() + deg = Poly(num, x, y).total_degree() + deta = Function('deta')(x, y) + dxi = Function('dxi')(x, y) + ipde = (deta.diff(x) + (deta.diff(y) - dxi.diff(x))*h - (dxi.diff(y))*h**2 + - dxi*hx - deta*hy) + xieq = Symbol("xi0") + etaeq = Symbol("eta0") + + for i in range(deg + 1): + if i: + xieq += Add(*[ + Symbol("xi_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + etaeq += Add(*[ + Symbol("eta_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + pden, denom = (ipde.subs({dxi: xieq, deta: etaeq}).doit()).as_numer_denom() + pden = expand(pden) + + # If the individual terms are monomials, the coefficients + # are grouped + if pden.is_polynomial(x, y) and pden.is_Add: + polyy = Poly(pden, x, y).as_dict() + if polyy: + symset = xieq.free_symbols.union(etaeq.free_symbols) - {x, y} + soldict = solve(polyy.values(), *symset) + if isinstance(soldict, list): + soldict = soldict[0] + if any(soldict.values()): + xired = xieq.subs(soldict) + etared = etaeq.subs(soldict) + # Scaling is done by substituting one for the parameters + # This can be any number except zero. + dict_ = dict.fromkeys(symset, 1) + inf = {eta: etared.subs(dict_).subs(y, func), + xi: xired.subs(dict_).subs(y, func)} + return [inf] + +def lie_heuristic_chi(match, comp=False): + r""" + The aim of the fourth heuristic is to find the function `\chi(x, y)` + that satisfies the PDE `\frac{d\chi}{dx} + h\frac{d\chi}{dx} + - \frac{\partial h}{\partial y}\chi = 0`. + + This assumes `\chi` to be a bivariate polynomial in `x` and `y`. By intuition, + `h` should be a rational function in `x` and `y`. The method used here is + to substitute a general binomial for `\chi` up to a certain maximum degree + is reached. The coefficients of the polynomials, are calculated by by collecting + terms of the same order in `x` and `y`. + + After finding `\chi`, the next step is to use `\eta = \xi*h + \chi`, to + determine `\xi` and `\eta`. This can be done by dividing `\chi` by `h` + which would give `-\xi` as the quotient and `\eta` as the remainder. + + + References + ========== + - E.S Cheb-Terrab, L.G.S Duarte and L.A,C.P da Mota, Computer Algebra + Solving of First Order ODEs Using Symmetry Methods, pp. 8 + + """ + + h = match['h'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + if h.is_rational_function(): + schi, schix, schiy = symbols("schi, schix, schiy") + cpde = schix + h*schiy - hy*schi + num, denom = cancel(cpde).as_numer_denom() + deg = Poly(num, x, y).total_degree() + + chi = Function('chi')(x, y) + chix = chi.diff(x) + chiy = chi.diff(y) + cpde = chix + h*chiy - hy*chi + chieq = Symbol("chi") + for i in range(1, deg + 1): + chieq += Add(*[ + Symbol("chi_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + cnum, cden = cancel(cpde.subs({chi : chieq}).doit()).as_numer_denom() + cnum = expand(cnum) + if cnum.is_polynomial(x, y) and cnum.is_Add: + cpoly = Poly(cnum, x, y).as_dict() + if cpoly: + solsyms = chieq.free_symbols - {x, y} + soldict = solve(cpoly.values(), *solsyms) + if isinstance(soldict, list): + soldict = soldict[0] + if any(soldict.values()): + chieq = chieq.subs(soldict) + dict_ = dict.fromkeys(solsyms, 1) + chieq = chieq.subs(dict_) + # After finding chi, the main aim is to find out + # eta, xi by the equation eta = xi*h + chi + # One method to set xi, would be rearranging it to + # (eta/h) - xi = (chi/h). This would mean dividing + # chi by h would give -xi as the quotient and eta + # as the remainder. Thanks to Sean Vig for suggesting + # this method. + xic, etac = div(chieq, h) + inf = {eta: etac.subs(y, func), xi: -xic.subs(y, func)} + return [inf] + +def lie_heuristic_function_sum(match, comp=False): + r""" + This heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = 0, \xi = f(x) + g(y) + + .. math:: \eta = f(x) + g(y), \xi = 0 + + The first assumption of this heuristic holds good if + + .. math:: \frac{\partial}{\partial y}[(h\frac{\partial^{2}}{ + \partial x^{2}}(h^{-1}))^{-1}] + + is separable in `x` and `y`, + + 1. The separated factors containing `y` is `\frac{\partial g}{\partial y}`. + From this `g(y)` can be determined. + 2. The separated factors containing `x` is `f''(x)`. + 3. `h\frac{\partial^{2}}{\partial x^{2}}(h^{-1})` equals + `\frac{f''(x)}{f(x) + g(y)}`. From this `f(x)` can be determined. + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first + assumption satisfies. After obtaining `f(x)` and `g(y)`, the coordinates + are again interchanged, to get `\eta` as `f(x) + g(y)`. + + For both assumptions, the constant factors are separated among `g(y)` + and `f''(x)`, such that `f''(x)` obtained from 3] is the same as that + obtained from 2]. If not possible, then this heuristic fails. + + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 7 - pp. 8 + + """ + + xieta = [] + h = match['h'] + func = match['func'] + hinv = match['hinv'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + for odefac in [h, hinv]: + factor = odefac*((1/odefac).diff(x, 2)) + sep = separatevars((1/factor).diff(y), dict=True, symbols=[x, y]) + if sep and sep['coeff'] and sep[x].has(x) and sep[y].has(y): + k = Dummy("k") + try: + gy = k*integrate(sep[y], y) + except NotImplementedError: + pass + else: + fdd = 1/(k*sep[x]*sep['coeff']) + fx = simplify(fdd/factor - gy) + check = simplify(fx.diff(x, 2) - fdd) + if fx: + if not check: + fx = fx.subs(k, 1) + gy = (gy/k) + else: + sol = solve(check, k) + if sol: + sol = sol[0] + fx = fx.subs(k, sol) + gy = (gy/k)*sol + else: + continue + if odefac == hinv: # Inverse ODE + fx = fx.subs(x, y) + gy = gy.subs(y, x) + etaval = factor_terms(fx + gy) + if etaval.is_Mul: + etaval = Mul(*[arg for arg in etaval.args if arg.has(x, y)]) + if odefac == hinv: # Inverse ODE + inf = {eta: etaval.subs(y, func), xi : S.Zero} + else: + inf = {xi: etaval.subs(y, func), eta : S.Zero} + if not comp: + return [inf] + else: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_abaco2_similar(match, comp=False): + r""" + This heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = g(x), \xi = f(x) + + .. math:: \eta = f(y), \xi = g(y) + + For the first assumption, + + 1. First `\frac{\frac{\partial h}{\partial y}}{\frac{\partial^{2} h}{ + \partial yy}}` is calculated. Let us say this value is A + + 2. If this is constant, then `h` is matched to the form `A(x) + B(x)e^{ + \frac{y}{C}}` then, `\frac{e^{\int \frac{A(x)}{C} \,dx}}{B(x)}` gives `f(x)` + and `A(x)*f(x)` gives `g(x)` + + 3. Otherwise `\frac{\frac{\partial A}{\partial X}}{\frac{\partial A}{ + \partial Y}} = \gamma` is calculated. If + + a] `\gamma` is a function of `x` alone + + b] `\frac{\gamma\frac{\partial h}{\partial y} - \gamma'(x) - \frac{ + \partial h}{\partial x}}{h + \gamma} = G` is a function of `x` alone. + then, `e^{\int G \,dx}` gives `f(x)` and `-\gamma*f(x)` gives `g(x)` + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first assumption + satisfies. After obtaining `f(x)` and `g(x)`, the coordinates are again + interchanged, to get `\xi` as `f(x^*)` and `\eta` as `g(y^*)` + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + hinv = match['hinv'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + factor = cancel(h.diff(y)/h.diff(y, 2)) + factorx = factor.diff(x) + factory = factor.diff(y) + if not factor.has(x) and not factor.has(y): + A = Wild('A', exclude=[y]) + B = Wild('B', exclude=[y]) + C = Wild('C', exclude=[x, y]) + match = h.match(A + B*exp(y/C)) + try: + tau = exp(-integrate(match[A]/match[C]), x)/match[B] + except NotImplementedError: + pass + else: + gx = match[A]*tau + return [{xi: tau, eta: gx}] + + else: + gamma = cancel(factorx/factory) + if not gamma.has(y): + tauint = cancel((gamma*hy - gamma.diff(x) - hx)/(h + gamma)) + if not tauint.has(y): + try: + tau = exp(integrate(tauint, x)) + except NotImplementedError: + pass + else: + gx = -tau*gamma + return [{xi: tau, eta: gx}] + + factor = cancel(hinv.diff(y)/hinv.diff(y, 2)) + factorx = factor.diff(x) + factory = factor.diff(y) + if not factor.has(x) and not factor.has(y): + A = Wild('A', exclude=[y]) + B = Wild('B', exclude=[y]) + C = Wild('C', exclude=[x, y]) + match = h.match(A + B*exp(y/C)) + try: + tau = exp(-integrate(match[A]/match[C]), x)/match[B] + except NotImplementedError: + pass + else: + gx = match[A]*tau + return [{eta: tau.subs(x, func), xi: gx.subs(x, func)}] + + else: + gamma = cancel(factorx/factory) + if not gamma.has(y): + tauint = cancel((gamma*hinv.diff(y) - gamma.diff(x) - hinv.diff(x))/( + hinv + gamma)) + if not tauint.has(y): + try: + tau = exp(integrate(tauint, x)) + except NotImplementedError: + pass + else: + gx = -tau*gamma + return [{eta: tau.subs(x, func), xi: gx.subs(x, func)}] + + +def lie_heuristic_abaco2_unique_unknown(match, comp=False): + r""" + This heuristic assumes the presence of unknown functions or known functions + with non-integer powers. + + 1. A list of all functions and non-integer powers containing x and y + 2. Loop over each element `f` in the list, find `\frac{\frac{\partial f}{\partial x}}{ + \frac{\partial f}{\partial x}} = R` + + If it is separable in `x` and `y`, let `X` be the factors containing `x`. Then + + a] Check if `\xi = X` and `\eta = -\frac{X}{R}` satisfy the PDE. If yes, then return + `\xi` and `\eta` + b] Check if `\xi = \frac{-R}{X}` and `\eta = -\frac{1}{X}` satisfy the PDE. + If yes, then return `\xi` and `\eta` + + If not, then check if + + a] :math:`\xi = -R,\eta = 1` + + b] :math:`\xi = 1, \eta = -\frac{1}{R}` + + are solutions. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + funclist = [] + for atom in h.atoms(Pow): + base, exp = atom.as_base_exp() + if base.has(x) and base.has(y): + if not exp.is_Integer: + funclist.append(atom) + + for function in h.atoms(AppliedUndef): + syms = function.free_symbols + if x in syms and y in syms: + funclist.append(function) + + for f in funclist: + frac = cancel(f.diff(y)/f.diff(x)) + sep = separatevars(frac, dict=True, symbols=[x, y]) + if sep and sep['coeff']: + xitry1 = sep[x] + etatry1 = -1/(sep[y]*sep['coeff']) + pde1 = etatry1.diff(y)*h - xitry1.diff(x)*h - xitry1*hx - etatry1*hy + if not simplify(pde1): + return [{xi: xitry1, eta: etatry1.subs(y, func)}] + xitry2 = 1/etatry1 + etatry2 = 1/xitry1 + pde2 = etatry2.diff(x) - (xitry2.diff(y))*h**2 - xitry2*hx - etatry2*hy + if not simplify(expand(pde2)): + return [{xi: xitry2.subs(y, func), eta: etatry2}] + + else: + etatry = -1/frac + pde = etatry.diff(x) + etatry.diff(y)*h - hx - etatry*hy + if not simplify(pde): + return [{xi: S.One, eta: etatry.subs(y, func)}] + xitry = -frac + pde = -xitry.diff(x)*h -xitry.diff(y)*h**2 - xitry*hx -hy + if not simplify(expand(pde)): + return [{xi: xitry.subs(y, func), eta: S.One}] + + +def lie_heuristic_abaco2_unique_general(match, comp=False): + r""" + This heuristic finds if infinitesimals of the form `\eta = f(x)`, `\xi = g(y)` + without making any assumptions on `h`. + + The complete sequence of steps is given in the paper mentioned below. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + A = hx.diff(y) + B = hy.diff(y) + hy**2 + C = hx.diff(x) - hx**2 + + if not (A and B and C): + return + + Ax = A.diff(x) + Ay = A.diff(y) + Axy = Ax.diff(y) + Axx = Ax.diff(x) + Ayy = Ay.diff(y) + D = simplify(2*Axy + hx*Ay - Ax*hy + (hx*hy + 2*A)*A)*A - 3*Ax*Ay + if not D: + E1 = simplify(3*Ax**2 + ((hx**2 + 2*C)*A - 2*Axx)*A) + if E1: + E2 = simplify((2*Ayy + (2*B - hy**2)*A)*A - 3*Ay**2) + if not E2: + E3 = simplify( + E1*((28*Ax + 4*hx*A)*A**3 - E1*(hy*A + Ay)) - E1.diff(x)*8*A**4) + if not E3: + etaval = cancel((4*A**3*(Ax - hx*A) + E1*(hy*A - Ay))/(S(2)*A*E1)) + if x not in etaval: + try: + etaval = exp(integrate(etaval, y)) + except NotImplementedError: + pass + else: + xival = -4*A**3*etaval/E1 + if y not in xival: + return [{xi: xival, eta: etaval.subs(y, func)}] + + else: + E1 = simplify((2*Ayy + (2*B - hy**2)*A)*A - 3*Ay**2) + if E1: + E2 = simplify( + 4*A**3*D - D**2 + E1*((2*Axx - (hx**2 + 2*C)*A)*A - 3*Ax**2)) + if not E2: + E3 = simplify( + -(A*D)*E1.diff(y) + ((E1.diff(x) - hy*D)*A + 3*Ay*D + + (A*hx - 3*Ax)*E1)*E1) + if not E3: + etaval = cancel(((A*hx - Ax)*E1 - (Ay + A*hy)*D)/(S(2)*A*D)) + if x not in etaval: + try: + etaval = exp(integrate(etaval, y)) + except NotImplementedError: + pass + else: + xival = -E1*etaval/D + if y not in xival: + return [{xi: xival, eta: etaval.subs(y, func)}] + + +def lie_heuristic_linear(match, comp=False): + r""" + This heuristic assumes + + 1. `\xi = ax + by + c` and + 2. `\eta = fx + gy + h` + + After substituting the following assumptions in the determining PDE, it + reduces to + + .. math:: f + (g - a)h - bh^{2} - (ax + by + c)\frac{\partial h}{\partial x} + - (fx + gy + c)\frac{\partial h}{\partial y} + + Solving the reduced PDE obtained, using the method of characteristics, becomes + impractical. The method followed is grouping similar terms and solving the system + of linear equations obtained. The difference between the bivariate heuristic is that + `h` need not be a rational function in this case. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + coeffdict = {} + symbols = numbered_symbols("c", cls=Dummy) + symlist = [next(symbols) for _ in islice(symbols, 6)] + C0, C1, C2, C3, C4, C5 = symlist + pde = C3 + (C4 - C0)*h - (C0*x + C1*y + C2)*hx - (C3*x + C4*y + C5)*hy - C1*h**2 + pde, denom = pde.as_numer_denom() + pde = powsimp(expand(pde)) + if pde.is_Add: + terms = pde.args + for term in terms: + if term.is_Mul: + rem = Mul(*[m for m in term.args if not m.has(x, y)]) + xypart = term/rem + if xypart not in coeffdict: + coeffdict[xypart] = rem + else: + coeffdict[xypart] += rem + else: + if term not in coeffdict: + coeffdict[term] = S.One + else: + coeffdict[term] += S.One + + sollist = coeffdict.values() + soldict = solve(sollist, symlist) + if soldict: + if isinstance(soldict, list): + soldict = soldict[0] + subval = soldict.values() + if any(t for t in subval): + onedict = dict(zip(symlist, [1]*6)) + xival = C0*x + C1*func + C2 + etaval = C3*x + C4*func + C5 + xival = xival.subs(soldict) + etaval = etaval.subs(soldict) + xival = xival.subs(onedict) + etaval = etaval.subs(onedict) + return [{xi: xival, eta: etaval}] + + +def _lie_group_remove(coords): + r""" + This function is strictly meant for internal use by the Lie group ODE solving + method. It replaces arbitrary functions returned by pdsolve as follows: + + 1] If coords is an arbitrary function, then its argument is returned. + 2] An arbitrary function in an Add object is replaced by zero. + 3] An arbitrary function in a Mul object is replaced by one. + 4] If there is no arbitrary function coords is returned unchanged. + + Examples + ======== + + >>> from sympy.solvers.ode.lie_group import _lie_group_remove + >>> from sympy import Function + >>> from sympy.abc import x, y + >>> F = Function("F") + >>> eq = x**2*y + >>> _lie_group_remove(eq) + x**2*y + >>> eq = F(x**2*y) + >>> _lie_group_remove(eq) + x**2*y + >>> eq = x*y**2 + F(x**3) + >>> _lie_group_remove(eq) + x*y**2 + >>> eq = (F(x**3) + y)*x**4 + >>> _lie_group_remove(eq) + x**4*y + + """ + if isinstance(coords, AppliedUndef): + return coords.args[0] + elif coords.is_Add: + subfunc = coords.atoms(AppliedUndef) + if subfunc: + for func in subfunc: + coords = coords.subs(func, 0) + return coords + elif coords.is_Pow: + base, expr = coords.as_base_exp() + base = _lie_group_remove(base) + expr = _lie_group_remove(expr) + return base**expr + elif coords.is_Mul: + mulargs = [] + coordargs = coords.args + for arg in coordargs: + if not isinstance(coords, AppliedUndef): + mulargs.append(_lie_group_remove(arg)) + return Mul(*mulargs) + return coords diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/nonhomogeneous.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/nonhomogeneous.py new file mode 100644 index 0000000000000000000000000000000000000000..ae39d55664e4850168ca7d68f65cf02171979957 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/nonhomogeneous.py @@ -0,0 +1,484 @@ +r""" +This File contains helper functions for nth_linear_constant_coeff_undetermined_coefficients, +nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients, +nth_linear_constant_coeff_variation_of_parameters, +and nth_linear_euler_eq_nonhomogeneous_variation_of_parameters. + +All the functions in this file are used by more than one solvers so, instead of creating +instances in other classes for using them it is better to keep it here as separate helpers. + +""" +from collections import Counter +from sympy.core import Add, S +from sympy.core.function import diff, expand, _mexpand, expand_mul +from sympy.core.relational import Eq +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy, Wild +from sympy.functions import exp, cos, cosh, im, log, re, sin, sinh, \ + atan2, conjugate +from sympy.integrals import Integral +from sympy.polys import (Poly, RootOf, rootof, roots) +from sympy.simplify import collect, simplify, separatevars, powsimp, trigsimp # type: ignore +from sympy.utilities import numbered_symbols +from sympy.solvers.solvers import solve +from sympy.matrices import wronskian +from .subscheck import sub_func_doit +from sympy.solvers.ode.ode import get_numbered_constants + + +def _test_term(coeff, func, order): + r""" + Linear Euler ODEs have the form K*x**order*diff(y(x), x, order) = F(x), + where K is independent of x and y(x), order>= 0. + So we need to check that for each term, coeff == K*x**order from + some K. We have a few cases, since coeff may have several + different types. + """ + x = func.args[0] + f = func.func + if order < 0: + raise ValueError("order should be greater than 0") + if coeff == 0: + return True + if order == 0: + if x in coeff.free_symbols: + return False + return True + if coeff.is_Mul: + if coeff.has(f(x)): + return False + return x**order in coeff.args + elif coeff.is_Pow: + return coeff.as_base_exp() == (x, order) + elif order == 1: + return x == coeff + return False + + +def _get_euler_characteristic_eq_sols(eq, func, match_obj): + r""" + Returns the solution of homogeneous part of the linear euler ODE and + the list of roots of characteristic equation. + + The parameter ``match_obj`` is a dict of order:coeff terms, where order is the order + of the derivative on each term, and coeff is the coefficient of that derivative. + + """ + x = func.args[0] + f = func.func + + # First, set up characteristic equation. + chareq, symbol = S.Zero, Dummy('x') + + for i in match_obj: + if i >= 0: + chareq += (match_obj[i]*diff(x**symbol, x, i)*x**-symbol).expand() + + chareq = Poly(chareq, symbol) + chareqroots = [rootof(chareq, k) for k in range(chareq.degree())] + collectterms = [] + + # A generator of constants + constants = list(get_numbered_constants(eq, num=chareq.degree()*2)) + constants.reverse() + + # Create a dict root: multiplicity or charroots + charroots = Counter(chareqroots) + gsol = S.Zero + ln = log + for root, multiplicity in charroots.items(): + for i in range(multiplicity): + if isinstance(root, RootOf): + gsol += (x**root) * constants.pop() + if multiplicity != 1: + raise ValueError("Value should be 1") + collectterms = [(0, root, 0)] + collectterms + elif root.is_real: + gsol += ln(x)**i*(x**root) * constants.pop() + collectterms = [(i, root, 0)] + collectterms + else: + reroot = re(root) + imroot = im(root) + gsol += ln(x)**i * (x**reroot) * ( + constants.pop() * sin(abs(imroot)*ln(x)) + + constants.pop() * cos(imroot*ln(x))) + collectterms = [(i, reroot, imroot)] + collectterms + + gsol = Eq(f(x), gsol) + + gensols = [] + # Keep track of when to use sin or cos for nonzero imroot + for i, reroot, imroot in collectterms: + if imroot == 0: + gensols.append(ln(x)**i*x**reroot) + else: + sin_form = ln(x)**i*x**reroot*sin(abs(imroot)*ln(x)) + if sin_form in gensols: + cos_form = ln(x)**i*x**reroot*cos(imroot*ln(x)) + gensols.append(cos_form) + else: + gensols.append(sin_form) + return gsol, gensols + + +def _solve_variation_of_parameters(eq, func, roots, homogen_sol, order, match_obj, simplify_flag=True): + r""" + Helper function for the method of variation of parameters and nonhomogeneous euler eq. + + See the + :py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffVariationOfParameters` + docstring for more information on this method. + + The parameter are ``match_obj`` should be a dictionary that has the following + keys: + + ``list`` + A list of solutions to the homogeneous equation. + + ``sol`` + The general solution. + + """ + f = func.func + x = func.args[0] + r = match_obj + psol = 0 + wr = wronskian(roots, x) + + if simplify_flag: + wr = simplify(wr) # We need much better simplification for + # some ODEs. See issue 4662, for example. + # To reduce commonly occurring sin(x)**2 + cos(x)**2 to 1 + wr = trigsimp(wr, deep=True, recursive=True) + if not wr: + # The wronskian will be 0 iff the solutions are not linearly + # independent. + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply " + + "variation of parameters to " + str(eq) + " (Wronskian == 0)") + if len(roots) != order: + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply " + + "variation of parameters to " + + str(eq) + " (number of terms != order)") + negoneterm = S.NegativeOne**(order) + for i in roots: + psol += negoneterm*Integral(wronskian([sol for sol in roots if sol != i], x)*r[-1]/wr, x)*i/r[order] + negoneterm *= -1 + + if simplify_flag: + psol = simplify(psol) + psol = trigsimp(psol, deep=True) + return Eq(f(x), homogen_sol.rhs + psol) + + +def _get_const_characteristic_eq_sols(r, func, order): + r""" + Returns the roots of characteristic equation of constant coefficient + linear ODE and list of collectterms which is later on used by simplification + to use collect on solution. + + The parameter `r` is a dict of order:coeff terms, where order is the order of the + derivative on each term, and coeff is the coefficient of that derivative. + + """ + x = func.args[0] + # First, set up characteristic equation. + chareq, symbol = S.Zero, Dummy('x') + + for i in r.keys(): + if isinstance(i, str) or i < 0: + pass + else: + chareq += r[i]*symbol**i + + chareq = Poly(chareq, symbol) + # Can't just call roots because it doesn't return rootof for unsolveable + # polynomials. + chareqroots = roots(chareq, multiple=True) + if len(chareqroots) != order: + chareqroots = [rootof(chareq, k) for k in range(chareq.degree())] + + chareq_is_complex = not all(i.is_real for i in chareq.all_coeffs()) + + # Create a dict root: multiplicity or charroots + charroots = Counter(chareqroots) + # We need to keep track of terms so we can run collect() at the end. + # This is necessary for constantsimp to work properly. + collectterms = [] + gensols = [] + conjugate_roots = [] # used to prevent double-use of conjugate roots + # Loop over roots in theorder provided by roots/rootof... + for root in chareqroots: + # but don't repoeat multiple roots. + if root not in charroots: + continue + multiplicity = charroots.pop(root) + for i in range(multiplicity): + if chareq_is_complex: + gensols.append(x**i*exp(root*x)) + collectterms = [(i, root, 0)] + collectterms + continue + reroot = re(root) + imroot = im(root) + if imroot.has(atan2) and reroot.has(atan2): + # Remove this condition when re and im stop returning + # circular atan2 usages. + gensols.append(x**i*exp(root*x)) + collectterms = [(i, root, 0)] + collectterms + else: + if root in conjugate_roots: + collectterms = [(i, reroot, imroot)] + collectterms + continue + if imroot == 0: + gensols.append(x**i*exp(reroot*x)) + collectterms = [(i, reroot, 0)] + collectterms + continue + conjugate_roots.append(conjugate(root)) + gensols.append(x**i*exp(reroot*x) * sin(abs(imroot) * x)) + gensols.append(x**i*exp(reroot*x) * cos( imroot * x)) + + # This ordering is important + collectterms = [(i, reroot, imroot)] + collectterms + return gensols, collectterms + + +# Ideally these kind of simplification functions shouldn't be part of solvers. +# odesimp should be improved to handle these kind of specific simplifications. +def _get_simplified_sol(sol, func, collectterms): + r""" + Helper function which collects the solution on + collectterms. Ideally this should be handled by odesimp.It is used + only when the simplify is set to True in dsolve. + + The parameter ``collectterms`` is a list of tuple (i, reroot, imroot) where `i` is + the multiplicity of the root, reroot is real part and imroot being the imaginary part. + + """ + f = func.func + x = func.args[0] + collectterms.sort(key=default_sort_key) + collectterms.reverse() + assert len(sol) == 1 and sol[0].lhs == f(x) + sol = sol[0].rhs + sol = expand_mul(sol) + for i, reroot, imroot in collectterms: + sol = collect(sol, x**i*exp(reroot*x)*sin(abs(imroot)*x)) + sol = collect(sol, x**i*exp(reroot*x)*cos(imroot*x)) + for i, reroot, imroot in collectterms: + sol = collect(sol, x**i*exp(reroot*x)) + sol = powsimp(sol) + return Eq(f(x), sol) + + +def _undetermined_coefficients_match(expr, x, func=None, eq_homogeneous=S.Zero): + r""" + Returns a trial function match if undetermined coefficients can be applied + to ``expr``, and ``None`` otherwise. + + A trial expression can be found for an expression for use with the method + of undetermined coefficients if the expression is an + additive/multiplicative combination of constants, polynomials in `x` (the + independent variable of expr), `\sin(a x + b)`, `\cos(a x + b)`, and + `e^{a x}` terms (in other words, it has a finite number of linearly + independent derivatives). + + Note that you may still need to multiply each term returned here by + sufficient `x` to make it linearly independent with the solutions to the + homogeneous equation. + + This is intended for internal use by ``undetermined_coefficients`` hints. + + SymPy currently has no way to convert `\sin^n(x) \cos^m(y)` into a sum of + only `\sin(a x)` and `\cos(b x)` terms, so these are not implemented. So, + for example, you will need to manually convert `\sin^2(x)` into `[1 + + \cos(2 x)]/2` to properly apply the method of undetermined coefficients on + it. + + Examples + ======== + + >>> from sympy import log, exp + >>> from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match + >>> from sympy.abc import x + >>> _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x) + {'test': True, 'trialset': {x*exp(x), exp(-x), exp(x)}} + >>> _undetermined_coefficients_match(log(x), x) + {'test': False} + + """ + a = Wild('a', exclude=[x]) + b = Wild('b', exclude=[x]) + expr = powsimp(expr, combine='exp') # exp(x)*exp(2*x + 1) => exp(3*x + 1) + retdict = {} + + def _test_term(expr, x) -> bool: + r""" + Test if ``expr`` fits the proper form for undetermined coefficients. + """ + if not expr.has(x): + return True + if expr.is_Add: + return all(_test_term(i, x) for i in expr.args) + if expr.is_Mul: + if expr.has(sin, cos): + foundtrig = False + # Make sure that there is only one trig function in the args. + # See the docstring. + for i in expr.args: + if i.has(sin, cos): + if foundtrig: + return False + else: + foundtrig = True + return all(_test_term(i, x) for i in expr.args) + if expr.is_Function: + return expr.func in (sin, cos, exp, sinh, cosh) and \ + bool(expr.args[0].match(a*x + b)) + if expr.is_Pow and expr.base.is_Symbol and expr.exp.is_Integer and \ + expr.exp >= 0: + return True + if expr.is_Pow and expr.base.is_number: + return bool(expr.exp.match(a*x + b)) + return expr.is_Symbol or bool(expr.is_number) + + def _get_trial_set(expr, x, exprs=set()): + r""" + Returns a set of trial terms for undetermined coefficients. + + The idea behind undetermined coefficients is that the terms expression + repeat themselves after a finite number of derivatives, except for the + coefficients (they are linearly dependent). So if we collect these, + we should have the terms of our trial function. + """ + def _remove_coefficient(expr, x): + r""" + Returns the expression without a coefficient. + + Similar to expr.as_independent(x)[1], except it only works + multiplicatively. + """ + term = S.One + if expr.is_Mul: + for i in expr.args: + if i.has(x): + term *= i + elif expr.has(x): + term = expr + return term + + expr = expand_mul(expr) + if expr.is_Add: + for term in expr.args: + if _remove_coefficient(term, x) in exprs: + pass + else: + exprs.add(_remove_coefficient(term, x)) + exprs = exprs.union(_get_trial_set(term, x, exprs)) + else: + term = _remove_coefficient(expr, x) + tmpset = exprs.union({term}) + oldset = set() + while tmpset != oldset: + # If you get stuck in this loop, then _test_term is probably + # broken + oldset = tmpset.copy() + expr = expr.diff(x) + term = _remove_coefficient(expr, x) + if term.is_Add: + tmpset = tmpset.union(_get_trial_set(term, x, tmpset)) + else: + tmpset.add(term) + exprs = tmpset + return exprs + + def is_homogeneous_solution(term): + r""" This function checks whether the given trialset contains any root + of homogeneous equation""" + return expand(sub_func_doit(eq_homogeneous, func, term)).is_zero + + retdict['test'] = _test_term(expr, x) + if retdict['test']: + # Try to generate a list of trial solutions that will have the + # undetermined coefficients. Note that if any of these are not linearly + # independent with any of the solutions to the homogeneous equation, + # then they will need to be multiplied by sufficient x to make them so. + # This function DOES NOT do that (it doesn't even look at the + # homogeneous equation). + temp_set = set() + for i in Add.make_args(expr): + act = _get_trial_set(i, x) + if eq_homogeneous is not S.Zero: + while any(is_homogeneous_solution(ts) for ts in act): + act = {x*ts for ts in act} + temp_set = temp_set.union(act) + + retdict['trialset'] = temp_set + return retdict + + +def _solve_undetermined_coefficients(eq, func, order, match, trialset): + r""" + Helper function for the method of undetermined coefficients. + + See the + :py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffUndeterminedCoefficients` + docstring for more information on this method. + + The parameter ``trialset`` is the set of trial functions as returned by + ``_undetermined_coefficients_match()['trialset']``. + + The parameter ``match`` should be a dictionary that has the following + keys: + + ``list`` + A list of solutions to the homogeneous equation. + + ``sol`` + The general solution. + + """ + r = match + coeffs = numbered_symbols('a', cls=Dummy) + coefflist = [] + gensols = r['list'] + gsol = r['sol'] + f = func.func + x = func.args[0] + + if len(gensols) != order: + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply" + + " undetermined coefficients to " + str(eq) + + " (number of terms != order)") + + trialfunc = 0 + for i in trialset: + c = next(coeffs) + coefflist.append(c) + trialfunc += c*i + + eqs = sub_func_doit(eq, f(x), trialfunc) + + coeffsdict = dict(list(zip(trialset, [0]*(len(trialset) + 1)))) + + eqs = _mexpand(eqs) + + for i in Add.make_args(eqs): + s = separatevars(i, dict=True, symbols=[x]) + if coeffsdict.get(s[x]): + coeffsdict[s[x]] += s['coeff'] + else: + coeffsdict[s[x]] = s['coeff'] + + coeffvals = solve(list(coeffsdict.values()), coefflist) + + if not coeffvals: + raise NotImplementedError( + "Could not solve `%s` using the " + "method of undetermined coefficients " + "(unable to solve for coefficients)." % eq) + + psol = trialfunc.subs(coeffvals) + + return Eq(f(x), gsol.rhs + psol) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/ode.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/ode.py new file mode 100644 index 0000000000000000000000000000000000000000..6a28b2162b38a2dd8612c14e91fd588912f6756a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/ode.py @@ -0,0 +1,3572 @@ +r""" +This module contains :py:meth:`~sympy.solvers.ode.dsolve` and different helper +functions that it uses. + +:py:meth:`~sympy.solvers.ode.dsolve` solves ordinary differential equations. +See the docstring on the various functions for their uses. Note that partial +differential equations support is in ``pde.py``. Note that hint functions +have docstrings describing their various methods, but they are intended for +internal use. Use ``dsolve(ode, func, hint=hint)`` to solve an ODE using a +specific hint. See also the docstring on +:py:meth:`~sympy.solvers.ode.dsolve`. + +**Functions in this module** + + These are the user functions in this module: + + - :py:meth:`~sympy.solvers.ode.dsolve` - Solves ODEs. + - :py:meth:`~sympy.solvers.ode.classify_ode` - Classifies ODEs into + possible hints for :py:meth:`~sympy.solvers.ode.dsolve`. + - :py:meth:`~sympy.solvers.ode.checkodesol` - Checks if an equation is the + solution to an ODE. + - :py:meth:`~sympy.solvers.ode.homogeneous_order` - Returns the + homogeneous order of an expression. + - :py:meth:`~sympy.solvers.ode.infinitesimals` - Returns the infinitesimals + of the Lie group of point transformations of an ODE, such that it is + invariant. + - :py:meth:`~sympy.solvers.ode.checkinfsol` - Checks if the given infinitesimals + are the actual infinitesimals of a first order ODE. + + These are the non-solver helper functions that are for internal use. The + user should use the various options to + :py:meth:`~sympy.solvers.ode.dsolve` to obtain the functionality provided + by these functions: + + - :py:meth:`~sympy.solvers.ode.ode.odesimp` - Does all forms of ODE + simplification. + - :py:meth:`~sympy.solvers.ode.ode.ode_sol_simplicity` - A key function for + comparing solutions by simplicity. + - :py:meth:`~sympy.solvers.ode.constantsimp` - Simplifies arbitrary + constants. + - :py:meth:`~sympy.solvers.ode.ode.constant_renumber` - Renumber arbitrary + constants. + - :py:meth:`~sympy.solvers.ode.ode._handle_Integral` - Evaluate unevaluated + Integrals. + + See also the docstrings of these functions. + +**Currently implemented solver methods** + +The following methods are implemented for solving ordinary differential +equations. See the docstrings of the various hint functions for more +information on each (run ``help(ode)``): + + - 1st order separable differential equations. + - 1st order differential equations whose coefficients or `dx` and `dy` are + functions homogeneous of the same order. + - 1st order exact differential equations. + - 1st order linear differential equations. + - 1st order Bernoulli differential equations. + - Power series solutions for first order differential equations. + - Lie Group method of solving first order differential equations. + - 2nd order Liouville differential equations. + - Power series solutions for second order differential equations + at ordinary and regular singular points. + - `n`\th order differential equation that can be solved with algebraic + rearrangement and integration. + - `n`\th order linear homogeneous differential equation with constant + coefficients. + - `n`\th order linear inhomogeneous differential equation with constant + coefficients using the method of undetermined coefficients. + - `n`\th order linear inhomogeneous differential equation with constant + coefficients using the method of variation of parameters. + +**Philosophy behind this module** + +This module is designed to make it easy to add new ODE solving methods without +having to mess with the solving code for other methods. The idea is that +there is a :py:meth:`~sympy.solvers.ode.classify_ode` function, which takes in +an ODE and tells you what hints, if any, will solve the ODE. It does this +without attempting to solve the ODE, so it is fast. Each solving method is a +hint, and it has its own function, named ``ode_``. That function takes +in the ODE and any match expression gathered by +:py:meth:`~sympy.solvers.ode.classify_ode` and returns a solved result. If +this result has any integrals in it, the hint function will return an +unevaluated :py:class:`~sympy.integrals.integrals.Integral` class. +:py:meth:`~sympy.solvers.ode.dsolve`, which is the user wrapper function +around all of this, will then call :py:meth:`~sympy.solvers.ode.ode.odesimp` on +the result, which, among other things, will attempt to solve the equation for +the dependent variable (the function we are solving for), simplify the +arbitrary constants in the expression, and evaluate any integrals, if the hint +allows it. + +**How to add new solution methods** + +If you have an ODE that you want :py:meth:`~sympy.solvers.ode.dsolve` to be +able to solve, try to avoid adding special case code here. Instead, try +finding a general method that will solve your ODE, as well as others. This +way, the :py:mod:`~sympy.solvers.ode` module will become more robust, and +unhindered by special case hacks. WolphramAlpha and Maple's +DETools[odeadvisor] function are two resources you can use to classify a +specific ODE. It is also better for a method to work with an `n`\th order ODE +instead of only with specific orders, if possible. + +To add a new method, there are a few things that you need to do. First, you +need a hint name for your method. Try to name your hint so that it is +unambiguous with all other methods, including ones that may not be implemented +yet. If your method uses integrals, also include a ``hint_Integral`` hint. +If there is more than one way to solve ODEs with your method, include a hint +for each one, as well as a ``_best`` hint. Your ``ode__best()`` +function should choose the best using min with ``ode_sol_simplicity`` as the +key argument. See +:obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest`, for example. +The function that uses your method will be called ``ode_()``, so the +hint must only use characters that are allowed in a Python function name +(alphanumeric characters and the underscore '``_``' character). Include a +function for every hint, except for ``_Integral`` hints +(:py:meth:`~sympy.solvers.ode.dsolve` takes care of those automatically). +Hint names should be all lowercase, unless a word is commonly capitalized +(such as Integral or Bernoulli). If you have a hint that you do not want to +run with ``all_Integral`` that does not have an ``_Integral`` counterpart (such +as a best hint that would defeat the purpose of ``all_Integral``), you will +need to remove it manually in the :py:meth:`~sympy.solvers.ode.dsolve` code. +See also the :py:meth:`~sympy.solvers.ode.classify_ode` docstring for +guidelines on writing a hint name. + +Determine *in general* how the solutions returned by your method compare with +other methods that can potentially solve the same ODEs. Then, put your hints +in the :py:data:`~sympy.solvers.ode.allhints` tuple in the order that they +should be called. The ordering of this tuple determines which hints are +default. Note that exceptions are ok, because it is easy for the user to +choose individual hints with :py:meth:`~sympy.solvers.ode.dsolve`. In +general, ``_Integral`` variants should go at the end of the list, and +``_best`` variants should go before the various hints they apply to. For +example, the ``undetermined_coefficients`` hint comes before the +``variation_of_parameters`` hint because, even though variation of parameters +is more general than undetermined coefficients, undetermined coefficients +generally returns cleaner results for the ODEs that it can solve than +variation of parameters does, and it does not require integration, so it is +much faster. + +Next, you need to have a match expression or a function that matches the type +of the ODE, which you should put in :py:meth:`~sympy.solvers.ode.classify_ode` +(if the match function is more than just a few lines. It should match the +ODE without solving for it as much as possible, so that +:py:meth:`~sympy.solvers.ode.classify_ode` remains fast and is not hindered by +bugs in solving code. Be sure to consider corner cases. For example, if your +solution method involves dividing by something, make sure you exclude the case +where that division will be 0. + +In most cases, the matching of the ODE will also give you the various parts +that you need to solve it. You should put that in a dictionary (``.match()`` +will do this for you), and add that as ``matching_hints['hint'] = matchdict`` +in the relevant part of :py:meth:`~sympy.solvers.ode.classify_ode`. +:py:meth:`~sympy.solvers.ode.classify_ode` will then send this to +:py:meth:`~sympy.solvers.ode.dsolve`, which will send it to your function as +the ``match`` argument. Your function should be named ``ode_(eq, func, +order, match)`. If you need to send more information, put it in the ``match`` +dictionary. For example, if you had to substitute in a dummy variable in +:py:meth:`~sympy.solvers.ode.classify_ode` to match the ODE, you will need to +pass it to your function using the `match` dict to access it. You can access +the independent variable using ``func.args[0]``, and the dependent variable +(the function you are trying to solve for) as ``func.func``. If, while trying +to solve the ODE, you find that you cannot, raise ``NotImplementedError``. +:py:meth:`~sympy.solvers.ode.dsolve` will catch this error with the ``all`` +meta-hint, rather than causing the whole routine to fail. + +Add a docstring to your function that describes the method employed. Like +with anything else in SymPy, you will need to add a doctest to the docstring, +in addition to real tests in ``test_ode.py``. Try to maintain consistency +with the other hint functions' docstrings. Add your method to the list at the +top of this docstring. Also, add your method to ``ode.rst`` in the +``docs/src`` directory, so that the Sphinx docs will pull its docstring into +the main SymPy documentation. Be sure to make the Sphinx documentation by +running ``make html`` from within the doc directory to verify that the +docstring formats correctly. + +If your solution method involves integrating, use :py:obj:`~.Integral` instead of +:py:meth:`~sympy.core.expr.Expr.integrate`. This allows the user to bypass +hard/slow integration by using the ``_Integral`` variant of your hint. In +most cases, calling :py:meth:`sympy.core.basic.Basic.doit` will integrate your +solution. If this is not the case, you will need to write special code in +:py:meth:`~sympy.solvers.ode.ode._handle_Integral`. Arbitrary constants should be +symbols named ``C1``, ``C2``, and so on. All solution methods should return +an equality instance. If you need an arbitrary number of arbitrary constants, +you can use ``constants = numbered_symbols(prefix='C', cls=Symbol, start=1)``. +If it is possible to solve for the dependent function in a general way, do so. +Otherwise, do as best as you can, but do not call solve in your +``ode_()`` function. :py:meth:`~sympy.solvers.ode.ode.odesimp` will attempt +to solve the solution for you, so you do not need to do that. Lastly, if your +ODE has a common simplification that can be applied to your solutions, you can +add a special case in :py:meth:`~sympy.solvers.ode.ode.odesimp` for it. For +example, solutions returned from the ``1st_homogeneous_coeff`` hints often +have many :obj:`~sympy.functions.elementary.exponential.log` terms, so +:py:meth:`~sympy.solvers.ode.ode.odesimp` calls +:py:meth:`~sympy.simplify.simplify.logcombine` on them (it also helps to write +the arbitrary constant as ``log(C1)`` instead of ``C1`` in this case). Also +consider common ways that you can rearrange your solution to have +:py:meth:`~sympy.solvers.ode.constantsimp` take better advantage of it. It is +better to put simplification in :py:meth:`~sympy.solvers.ode.ode.odesimp` than in +your method, because it can then be turned off with the simplify flag in +:py:meth:`~sympy.solvers.ode.dsolve`. If you have any extraneous +simplification in your function, be sure to only run it using ``if +match.get('simplify', True):``, especially if it can be slow or if it can +reduce the domain of the solution. + +Finally, as with every contribution to SymPy, your method will need to be +tested. Add a test for each method in ``test_ode.py``. Follow the +conventions there, i.e., test the solver using ``dsolve(eq, f(x), +hint=your_hint)``, and also test the solution using +:py:meth:`~sympy.solvers.ode.checkodesol` (you can put these in a separate +tests and skip/XFAIL if it runs too slow/does not work). Be sure to call your +hint specifically in :py:meth:`~sympy.solvers.ode.dsolve`, that way the test +will not be broken simply by the introduction of another matching hint. If your +method works for higher order (>1) ODEs, you will need to run ``sol = +constant_renumber(sol, 'C', 1, order)`` for each solution, where ``order`` is +the order of the ODE. This is because ``constant_renumber`` renumbers the +arbitrary constants by printing order, which is platform dependent. Try to +test every corner case of your solver, including a range of orders if it is a +`n`\th order solver, but if your solver is slow, such as if it involves hard +integration, try to keep the test run time down. + +Feel free to refactor existing hints to avoid duplicating code or creating +inconsistencies. If you can show that your method exactly duplicates an +existing method, including in the simplicity and speed of obtaining the +solutions, then you can remove the old, less general method. The existing +code is tested extensively in ``test_ode.py``, so if anything is broken, one +of those tests will surely fail. + +""" + +from sympy.core import Add, S, Mul, Pow, oo +from sympy.core.containers import Tuple +from sympy.core.expr import AtomicExpr, Expr +from sympy.core.function import (Function, Derivative, AppliedUndef, diff, + expand, expand_mul, Subs) +from sympy.core.multidimensional import vectorize +from sympy.core.numbers import nan, zoo, Number +from sympy.core.relational import Equality, Eq +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Symbol, Wild, Dummy, symbols +from sympy.core.sympify import sympify +from sympy.core.traversal import preorder_traversal + +from sympy.logic.boolalg import (BooleanAtom, BooleanTrue, + BooleanFalse) +from sympy.functions import exp, log, sqrt +from sympy.functions.combinatorial.factorials import factorial +from sympy.integrals.integrals import Integral +from sympy.polys import (Poly, terms_gcd, PolynomialError, lcm) +from sympy.polys.polytools import cancel +from sympy.series import Order +from sympy.series.series import series +from sympy.simplify import (collect, logcombine, powsimp, # type: ignore + separatevars, simplify, cse) +from sympy.simplify.radsimp import collect_const +from sympy.solvers import checksol, solve + +from sympy.utilities import numbered_symbols +from sympy.utilities.iterables import uniq, sift, iterable +from sympy.solvers.deutils import _preprocess, ode_order, _desolve + + +#: This is a list of hints in the order that they should be preferred by +#: :py:meth:`~sympy.solvers.ode.classify_ode`. In general, hints earlier in the +#: list should produce simpler solutions than those later in the list (for +#: ODEs that fit both). For now, the order of this list is based on empirical +#: observations by the developers of SymPy. +#: +#: The hint used by :py:meth:`~sympy.solvers.ode.dsolve` for a specific ODE +#: can be overridden (see the docstring). +#: +#: In general, ``_Integral`` hints are grouped at the end of the list, unless +#: there is a method that returns an unevaluable integral most of the time +#: (which go near the end of the list anyway). ``default``, ``all``, +#: ``best``, and ``all_Integral`` meta-hints should not be included in this +#: list, but ``_best`` and ``_Integral`` hints should be included. +allhints = ( + "factorable", + "nth_algebraic", + "separable", + "1st_exact", + "1st_linear", + "Bernoulli", + "1st_rational_riccati", + "Riccati_special_minus2", + "1st_homogeneous_coeff_best", + "1st_homogeneous_coeff_subs_indep_div_dep", + "1st_homogeneous_coeff_subs_dep_div_indep", + "almost_linear", + "linear_coefficients", + "separable_reduced", + "1st_power_series", + "lie_group", + "nth_linear_constant_coeff_homogeneous", + "nth_linear_euler_eq_homogeneous", + "nth_linear_constant_coeff_undetermined_coefficients", + "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients", + "nth_linear_constant_coeff_variation_of_parameters", + "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters", + "Liouville", + "2nd_linear_airy", + "2nd_linear_bessel", + "2nd_hypergeometric", + "2nd_hypergeometric_Integral", + "nth_order_reducible", + "2nd_power_series_ordinary", + "2nd_power_series_regular", + "nth_algebraic_Integral", + "separable_Integral", + "1st_exact_Integral", + "1st_linear_Integral", + "Bernoulli_Integral", + "1st_homogeneous_coeff_subs_indep_div_dep_Integral", + "1st_homogeneous_coeff_subs_dep_div_indep_Integral", + "almost_linear_Integral", + "linear_coefficients_Integral", + "separable_reduced_Integral", + "nth_linear_constant_coeff_variation_of_parameters_Integral", + "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters_Integral", + "Liouville_Integral", + "2nd_nonlinear_autonomous_conserved", + "2nd_nonlinear_autonomous_conserved_Integral", + ) + + + +def get_numbered_constants(eq, num=1, start=1, prefix='C'): + """ + Returns a list of constants that do not occur + in eq already. + """ + + ncs = iter_numbered_constants(eq, start, prefix) + Cs = [next(ncs) for i in range(num)] + return (Cs[0] if num == 1 else tuple(Cs)) + + +def iter_numbered_constants(eq, start=1, prefix='C'): + """ + Returns an iterator of constants that do not occur + in eq already. + """ + + if isinstance(eq, (Expr, Eq)): + eq = [eq] + elif not iterable(eq): + raise ValueError("Expected Expr or iterable but got %s" % eq) + + atom_set = set().union(*[i.free_symbols for i in eq]) + func_set = set().union(*[i.atoms(Function) for i in eq]) + if func_set: + atom_set |= {Symbol(str(f.func)) for f in func_set} + return numbered_symbols(start=start, prefix=prefix, exclude=atom_set) + + +def dsolve(eq, func=None, hint="default", simplify=True, + ics= None, xi=None, eta=None, x0=0, n=6, **kwargs): + r""" + Solves any (supported) kind of ordinary differential equation and + system of ordinary differential equations. + + For single ordinary differential equation + ========================================= + + It is classified under this when number of equation in ``eq`` is one. + **Usage** + + ``dsolve(eq, f(x), hint)`` -> Solve ordinary differential equation + ``eq`` for function ``f(x)``, using method ``hint``. + + **Details** + + ``eq`` can be any supported ordinary differential equation (see the + :py:mod:`~sympy.solvers.ode` docstring for supported methods). + This can either be an :py:class:`~sympy.core.relational.Equality`, + or an expression, which is assumed to be equal to ``0``. + + ``f(x)`` is a function of one variable whose derivatives in that + variable make up the ordinary differential equation ``eq``. In + many cases it is not necessary to provide this; it will be + autodetected (and an error raised if it could not be detected). + + ``hint`` is the solving method that you want dsolve to use. Use + ``classify_ode(eq, f(x))`` to get all of the possible hints for an + ODE. The default hint, ``default``, will use whatever hint is + returned first by :py:meth:`~sympy.solvers.ode.classify_ode`. See + Hints below for more options that you can use for hint. + + ``simplify`` enables simplification by + :py:meth:`~sympy.solvers.ode.ode.odesimp`. See its docstring for more + information. Turn this off, for example, to disable solving of + solutions for ``func`` or simplification of arbitrary constants. + It will still integrate with this hint. Note that the solution may + contain more arbitrary constants than the order of the ODE with + this option enabled. + + ``xi`` and ``eta`` are the infinitesimal functions of an ordinary + differential equation. They are the infinitesimals of the Lie group + of point transformations for which the differential equation is + invariant. The user can specify values for the infinitesimals. If + nothing is specified, ``xi`` and ``eta`` are calculated using + :py:meth:`~sympy.solvers.ode.infinitesimals` with the help of various + heuristics. + + ``ics`` is the set of initial/boundary conditions for the differential equation. + It should be given in the form of ``{f(x0): x1, f(x).diff(x).subs(x, x2): + x3}`` and so on. For power series solutions, if no initial + conditions are specified ``f(0)`` is assumed to be ``C0`` and the power + series solution is calculated about 0. + + ``x0`` is the point about which the power series solution of a differential + equation is to be evaluated. + + ``n`` gives the exponent of the dependent variable up to which the power series + solution of a differential equation is to be evaluated. + + **Hints** + + Aside from the various solving methods, there are also some meta-hints + that you can pass to :py:meth:`~sympy.solvers.ode.dsolve`: + + ``default``: + This uses whatever hint is returned first by + :py:meth:`~sympy.solvers.ode.classify_ode`. This is the + default argument to :py:meth:`~sympy.solvers.ode.dsolve`. + + ``all``: + To make :py:meth:`~sympy.solvers.ode.dsolve` apply all + relevant classification hints, use ``dsolve(ODE, func, + hint="all")``. This will return a dictionary of + ``hint:solution`` terms. If a hint causes dsolve to raise the + ``NotImplementedError``, value of that hint's key will be the + exception object raised. The dictionary will also include + some special keys: + + - ``order``: The order of the ODE. See also + :py:meth:`~sympy.solvers.deutils.ode_order` in + ``deutils.py``. + - ``best``: The simplest hint; what would be returned by + ``best`` below. + - ``best_hint``: The hint that would produce the solution + given by ``best``. If more than one hint produces the best + solution, the first one in the tuple returned by + :py:meth:`~sympy.solvers.ode.classify_ode` is chosen. + - ``default``: The solution that would be returned by default. + This is the one produced by the hint that appears first in + the tuple returned by + :py:meth:`~sympy.solvers.ode.classify_ode`. + + ``all_Integral``: + This is the same as ``all``, except if a hint also has a + corresponding ``_Integral`` hint, it only returns the + ``_Integral`` hint. This is useful if ``all`` causes + :py:meth:`~sympy.solvers.ode.dsolve` to hang because of a + difficult or impossible integral. This meta-hint will also be + much faster than ``all``, because + :py:meth:`~sympy.core.expr.Expr.integrate` is an expensive + routine. + + ``best``: + To have :py:meth:`~sympy.solvers.ode.dsolve` try all methods + and return the simplest one. This takes into account whether + the solution is solvable in the function, whether it contains + any Integral classes (i.e. unevaluatable integrals), and + which one is the shortest in size. + + See also the :py:meth:`~sympy.solvers.ode.classify_ode` docstring for + more info on hints, and the :py:mod:`~sympy.solvers.ode` docstring for + a list of all supported hints. + + **Tips** + + - You can declare the derivative of an unknown function this way: + + >>> from sympy import Function, Derivative + >>> from sympy.abc import x # x is the independent variable + >>> f = Function("f")(x) # f is a function of x + >>> # f_ will be the derivative of f with respect to x + >>> f_ = Derivative(f, x) + + - See ``test_ode.py`` for many tests, which serves also as a set of + examples for how to use :py:meth:`~sympy.solvers.ode.dsolve`. + - :py:meth:`~sympy.solvers.ode.dsolve` always returns an + :py:class:`~sympy.core.relational.Equality` class (except for the + case when the hint is ``all`` or ``all_Integral``). If possible, it + solves the solution explicitly for the function being solved for. + Otherwise, it returns an implicit solution. + - Arbitrary constants are symbols named ``C1``, ``C2``, and so on. + - Because all solutions should be mathematically equivalent, some + hints may return the exact same result for an ODE. Often, though, + two different hints will return the same solution formatted + differently. The two should be equivalent. Also note that sometimes + the values of the arbitrary constants in two different solutions may + not be the same, because one constant may have "absorbed" other + constants into it. + - Do ``help(ode.ode_)`` to get help more information on a + specific hint, where ```` is the name of a hint without + ``_Integral``. + + For system of ordinary differential equations + ============================================= + + **Usage** + ``dsolve(eq, func)`` -> Solve a system of ordinary differential + equations ``eq`` for ``func`` being list of functions including + `x(t)`, `y(t)`, `z(t)` where number of functions in the list depends + upon the number of equations provided in ``eq``. + + **Details** + + ``eq`` can be any supported system of ordinary differential equations + This can either be an :py:class:`~sympy.core.relational.Equality`, + or an expression, which is assumed to be equal to ``0``. + + ``func`` holds ``x(t)`` and ``y(t)`` being functions of one variable which + together with some of their derivatives make up the system of ordinary + differential equation ``eq``. It is not necessary to provide this; it + will be autodetected (and an error raised if it could not be detected). + + **Hints** + + The hints are formed by parameters returned by classify_sysode, combining + them give hints name used later for forming method name. + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, Derivative, sin, cos, symbols + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(Derivative(f(x), x, x) + 9*f(x), f(x)) + Eq(f(x), C1*sin(3*x) + C2*cos(3*x)) + + >>> eq = sin(x)*cos(f(x)) + cos(x)*sin(f(x))*f(x).diff(x) + >>> dsolve(eq, hint='1st_exact') + [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))] + >>> dsolve(eq, hint='almost_linear') + [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))] + >>> t = symbols('t') + >>> x, y = symbols('x, y', cls=Function) + >>> eq = (Eq(Derivative(x(t),t), 12*t*x(t) + 8*y(t)), Eq(Derivative(y(t),t), 21*x(t) + 7*t*y(t))) + >>> dsolve(eq) + [Eq(x(t), C1*x0(t) + C2*x0(t)*Integral(8*exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)**2, t)), + Eq(y(t), C1*y0(t) + C2*(y0(t)*Integral(8*exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)**2, t) + + exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)))] + >>> eq = (Eq(Derivative(x(t),t),x(t)*y(t)*sin(t)), Eq(Derivative(y(t),t),y(t)**2*sin(t))) + >>> dsolve(eq) + {Eq(x(t), -exp(C1)/(C2*exp(C1) - cos(t))), Eq(y(t), -1/(C1 - cos(t)))} + """ + if iterable(eq): + from sympy.solvers.ode.systems import dsolve_system + + # This may have to be changed in future + # when we have weakly and strongly + # connected components. This have to + # changed to show the systems that haven't + # been solved. + try: + sol = dsolve_system(eq, funcs=func, ics=ics, doit=True) + return sol[0] if len(sol) == 1 else sol + except NotImplementedError: + pass + + match = classify_sysode(eq, func) + + eq = match['eq'] + order = match['order'] + func = match['func'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + + # keep highest order term coefficient positive + for i in range(len(eq)): + for func_ in func: + if isinstance(func_, list): + pass + else: + if eq[i].coeff(diff(func[i],t,ode_order(eq[i], func[i]))).is_negative: + eq[i] = -eq[i] + match['eq'] = eq + if len(set(order.values()))!=1: + raise ValueError("It solves only those systems of equations whose orders are equal") + match['order'] = list(order.values())[0] + def recur_len(l): + return sum(recur_len(item) if isinstance(item,list) else 1 for item in l) + if recur_len(func) != len(eq): + raise ValueError("dsolve() and classify_sysode() work with " + "number of functions being equal to number of equations") + if match['type_of_equation'] is None: + raise NotImplementedError + else: + if match['is_linear'] == True: + solvefunc = globals()['sysode_linear_%(no_of_equation)seq_order%(order)s' % match] + else: + solvefunc = globals()['sysode_nonlinear_%(no_of_equation)seq_order%(order)s' % match] + sols = solvefunc(match) + if ics: + constants = Tuple(*sols).free_symbols - Tuple(*eq).free_symbols + solved_constants = solve_ics(sols, func, constants, ics) + return [sol.subs(solved_constants) for sol in sols] + return sols + else: + given_hint = hint # hint given by the user + + # See the docstring of _desolve for more details. + hints = _desolve(eq, func=func, + hint=hint, simplify=True, xi=xi, eta=eta, type='ode', ics=ics, + x0=x0, n=n, **kwargs) + eq = hints.pop('eq', eq) + all_ = hints.pop('all', False) + if all_: + retdict = {} + failed_hints = {} + gethints = classify_ode(eq, dict=True, hint='all') + orderedhints = gethints['ordered_hints'] + for hint in hints: + try: + rv = _helper_simplify(eq, hint, hints[hint], simplify) + except NotImplementedError as detail: + failed_hints[hint] = detail + else: + retdict[hint] = rv + func = hints[hint]['func'] + + retdict['best'] = min(list(retdict.values()), key=lambda x: + ode_sol_simplicity(x, func, trysolving=not simplify)) + if given_hint == 'best': + return retdict['best'] + for i in orderedhints: + if retdict['best'] == retdict.get(i, None): + retdict['best_hint'] = i + break + retdict['default'] = gethints['default'] + retdict['order'] = gethints['order'] + retdict.update(failed_hints) + return retdict + + else: + # The key 'hint' stores the hint needed to be solved for. + hint = hints['hint'] + return _helper_simplify(eq, hint, hints, simplify, ics=ics) + + +def _helper_simplify(eq, hint, match, simplify=True, ics=None, **kwargs): + r""" + Helper function of dsolve that calls the respective + :py:mod:`~sympy.solvers.ode` functions to solve for the ordinary + differential equations. This minimizes the computation in calling + :py:meth:`~sympy.solvers.deutils._desolve` multiple times. + """ + r = match + func = r['func'] + order = r['order'] + match = r[hint] + + if isinstance(match, SingleODESolver): + solvefunc = match + else: + solvefunc = globals()['ode_' + hint.removesuffix('_Integral')] + + free = eq.free_symbols + cons = lambda s: s.free_symbols.difference(free) + + if simplify: + # odesimp() will attempt to integrate, if necessary, apply constantsimp(), + # attempt to solve for func, and apply any other hint specific + # simplifications + if isinstance(solvefunc, SingleODESolver): + sols = solvefunc.get_general_solution() + else: + sols = solvefunc(eq, func, order, match) + if iterable(sols): + rv = [] + for s in sols: + simp = odesimp(eq, s, func, hint) + if iterable(simp): + rv.extend(simp) + else: + rv.append(simp) + else: + rv = odesimp(eq, sols, func, hint) + else: + # We still want to integrate (you can disable it separately with the hint) + if isinstance(solvefunc, SingleODESolver): + exprs = solvefunc.get_general_solution(simplify=False) + else: + match['simplify'] = False # Some hints can take advantage of this option + exprs = solvefunc(eq, func, order, match) + if isinstance(exprs, list): + rv = [_handle_Integral(expr, func, hint) for expr in exprs] + else: + rv = _handle_Integral(exprs, func, hint) + + if isinstance(rv, list): + assert all(isinstance(i, Eq) for i in rv), rv # if not => internal error + if simplify: + rv = _remove_redundant_solutions(eq, rv, order, func.args[0]) + if len(rv) == 1: + rv = rv[0] + if ics and 'power_series' not in hint: + if isinstance(rv, (Expr, Eq)): + solved_constants = solve_ics([rv], [r['func']], cons(rv), ics) + rv = rv.subs(solved_constants) + else: + rv1 = [] + for s in rv: + try: + solved_constants = solve_ics([s], [r['func']], cons(s), ics) + except ValueError: + continue + rv1.append(s.subs(solved_constants)) + if len(rv1) == 1: + return rv1[0] + rv = rv1 + return rv + + +def solve_ics(sols, funcs, constants, ics): + """ + Solve for the constants given initial conditions + + ``sols`` is a list of solutions. + + ``funcs`` is a list of functions. + + ``constants`` is a list of constants. + + ``ics`` is the set of initial/boundary conditions for the differential + equation. It should be given in the form of ``{f(x0): x1, + f(x).diff(x).subs(x, x2): x3}`` and so on. + + Returns a dictionary mapping constants to values. + ``solution.subs(constants)`` will replace the constants in ``solution``. + + Example + ======= + >>> # From dsolve(f(x).diff(x) - f(x), f(x)) + >>> from sympy import symbols, Eq, exp, Function + >>> from sympy.solvers.ode.ode import solve_ics + >>> f = Function('f') + >>> x, C1 = symbols('x C1') + >>> sols = [Eq(f(x), C1*exp(x))] + >>> funcs = [f(x)] + >>> constants = [C1] + >>> ics = {f(0): 2} + >>> solved_constants = solve_ics(sols, funcs, constants, ics) + >>> solved_constants + {C1: 2} + >>> sols[0].subs(solved_constants) + Eq(f(x), 2*exp(x)) + + """ + # Assume ics are of the form f(x0): value or Subs(diff(f(x), x, n), (x, + # x0)): value (currently checked by classify_ode). To solve, replace x + # with x0, f(x0) with value, then solve for constants. For f^(n)(x0), + # differentiate the solution n times, so that f^(n)(x) appears. + x = funcs[0].args[0] + diff_sols = [] + subs_sols = [] + diff_variables = set() + for funcarg, value in ics.items(): + if isinstance(funcarg, AppliedUndef): + x0 = funcarg.args[0] + matching_func = [f for f in funcs if f.func == funcarg.func][0] + S = sols + elif isinstance(funcarg, (Subs, Derivative)): + if isinstance(funcarg, Subs): + # Make sure it stays a subs. Otherwise subs below will produce + # a different looking term. + funcarg = funcarg.doit() + if isinstance(funcarg, Subs): + deriv = funcarg.expr + x0 = funcarg.point[0] + variables = funcarg.expr.variables + matching_func = deriv + elif isinstance(funcarg, Derivative): + deriv = funcarg + x0 = funcarg.variables[0] + variables = (x,)*len(funcarg.variables) + matching_func = deriv.subs(x0, x) + for sol in sols: + if sol.has(deriv.expr.func): + diff_sols.append(Eq(sol.lhs.diff(*variables), sol.rhs.diff(*variables))) + diff_variables.add(variables) + S = diff_sols + else: + raise NotImplementedError("Unrecognized initial condition") + + for sol in S: + if sol.has(matching_func): + sol2 = sol + sol2 = sol2.subs(x, x0) + sol2 = sol2.subs(funcarg, value) + # This check is necessary because of issue #15724 + if not isinstance(sol2, BooleanAtom) or not subs_sols: + subs_sols = [s for s in subs_sols if not isinstance(s, BooleanAtom)] + subs_sols.append(sol2) + + # TODO: Use solveset here + try: + solved_constants = solve(subs_sols, constants, dict=True) + except NotImplementedError: + solved_constants = [] + + # XXX: We can't differentiate between the solution not existing because of + # invalid initial conditions, and not existing because solve is not smart + # enough. If we could use solveset, this might be improvable, but for now, + # we use NotImplementedError in this case. + if not solved_constants: + raise ValueError("Couldn't solve for initial conditions") + + if solved_constants == True: + raise ValueError("Initial conditions did not produce any solutions for constants. Perhaps they are degenerate.") + + if len(solved_constants) > 1: + raise NotImplementedError("Initial conditions produced too many solutions for constants") + + return solved_constants[0] + +def classify_ode(eq, func=None, dict=False, ics=None, *, prep=True, xi=None, eta=None, n=None, **kwargs): + r""" + Returns a tuple of possible :py:meth:`~sympy.solvers.ode.dsolve` + classifications for an ODE. + + The tuple is ordered so that first item is the classification that + :py:meth:`~sympy.solvers.ode.dsolve` uses to solve the ODE by default. In + general, classifications at the near the beginning of the list will + produce better solutions faster than those near the end, thought there are + always exceptions. To make :py:meth:`~sympy.solvers.ode.dsolve` use a + different classification, use ``dsolve(ODE, func, + hint=)``. See also the + :py:meth:`~sympy.solvers.ode.dsolve` docstring for different meta-hints + you can use. + + If ``dict`` is true, :py:meth:`~sympy.solvers.ode.classify_ode` will + return a dictionary of ``hint:match`` expression terms. This is intended + for internal use by :py:meth:`~sympy.solvers.ode.dsolve`. Note that + because dictionaries are ordered arbitrarily, this will most likely not be + in the same order as the tuple. + + You can get help on different hints by executing + ``help(ode.ode_hintname)``, where ``hintname`` is the name of the hint + without ``_Integral``. + + See :py:data:`~sympy.solvers.ode.allhints` or the + :py:mod:`~sympy.solvers.ode` docstring for a list of all supported hints + that can be returned from :py:meth:`~sympy.solvers.ode.classify_ode`. + + Notes + ===== + + These are remarks on hint names. + + ``_Integral`` + + If a classification has ``_Integral`` at the end, it will return the + expression with an unevaluated :py:class:`~.Integral` + class in it. Note that a hint may do this anyway if + :py:meth:`~sympy.core.expr.Expr.integrate` cannot do the integral, + though just using an ``_Integral`` will do so much faster. Indeed, an + ``_Integral`` hint will always be faster than its corresponding hint + without ``_Integral`` because + :py:meth:`~sympy.core.expr.Expr.integrate` is an expensive routine. + If :py:meth:`~sympy.solvers.ode.dsolve` hangs, it is probably because + :py:meth:`~sympy.core.expr.Expr.integrate` is hanging on a tough or + impossible integral. Try using an ``_Integral`` hint or + ``all_Integral`` to get it return something. + + Note that some hints do not have ``_Integral`` counterparts. This is + because :py:func:`~sympy.integrals.integrals.integrate` is not used in + solving the ODE for those method. For example, `n`\th order linear + homogeneous ODEs with constant coefficients do not require integration + to solve, so there is no + ``nth_linear_homogeneous_constant_coeff_Integrate`` hint. You can + easily evaluate any unevaluated + :py:class:`~sympy.integrals.integrals.Integral`\s in an expression by + doing ``expr.doit()``. + + Ordinals + + Some hints contain an ordinal such as ``1st_linear``. This is to help + differentiate them from other hints, as well as from other methods + that may not be implemented yet. If a hint has ``nth`` in it, such as + the ``nth_linear`` hints, this means that the method used to applies + to ODEs of any order. + + ``indep`` and ``dep`` + + Some hints contain the words ``indep`` or ``dep``. These reference + the independent variable and the dependent function, respectively. For + example, if an ODE is in terms of `f(x)`, then ``indep`` will refer to + `x` and ``dep`` will refer to `f`. + + ``subs`` + + If a hints has the word ``subs`` in it, it means that the ODE is solved + by substituting the expression given after the word ``subs`` for a + single dummy variable. This is usually in terms of ``indep`` and + ``dep`` as above. The substituted expression will be written only in + characters allowed for names of Python objects, meaning operators will + be spelled out. For example, ``indep``/``dep`` will be written as + ``indep_div_dep``. + + ``coeff`` + + The word ``coeff`` in a hint refers to the coefficients of something + in the ODE, usually of the derivative terms. See the docstring for + the individual methods for more info (``help(ode)``). This is + contrast to ``coefficients``, as in ``undetermined_coefficients``, + which refers to the common name of a method. + + ``_best`` + + Methods that have more than one fundamental way to solve will have a + hint for each sub-method and a ``_best`` meta-classification. This + will evaluate all hints and return the best, using the same + considerations as the normal ``best`` meta-hint. + + + Examples + ======== + + >>> from sympy import Function, classify_ode, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> classify_ode(Eq(f(x).diff(x), 0), f(x)) + ('nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + >>> classify_ode(f(x).diff(x, 2) + 3*f(x).diff(x) + 2*f(x) - 4) + ('factorable', 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + + """ + ics = sympify(ics) + + if func and len(func.args) != 1: + raise ValueError("dsolve() and classify_ode() only " + "work with functions of one variable, not %s" % func) + + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + + # Some methods want the unprocessed equation + eq_orig = eq + + if prep or func is None: + eq, func_ = _preprocess(eq, func) + if func is None: + func = func_ + x = func.args[0] + f = func.func + y = Dummy('y') + terms = 5 if n is None else n + + order = ode_order(eq, f(x)) + # hint:matchdict or hint:(tuple of matchdicts) + # Also will contain "default": and "order":order items. + matching_hints = {"order": order} + + df = f(x).diff(x) + a = Wild('a', exclude=[f(x)]) + d = Wild('d', exclude=[df, f(x).diff(x, 2)]) + e = Wild('e', exclude=[df]) + n = Wild('n', exclude=[x, f(x), df]) + c1 = Wild('c1', exclude=[x]) + a3 = Wild('a3', exclude=[f(x), df, f(x).diff(x, 2)]) + b3 = Wild('b3', exclude=[f(x), df, f(x).diff(x, 2)]) + c3 = Wild('c3', exclude=[f(x), df, f(x).diff(x, 2)]) + boundary = {} # Used to extract initial conditions + C1 = Symbol("C1") + + # Preprocessing to get the initial conditions out + if ics is not None: + for funcarg in ics: + # Separating derivatives + if isinstance(funcarg, (Subs, Derivative)): + # f(x).diff(x).subs(x, 0) is a Subs, but f(x).diff(x).subs(x, + # y) is a Derivative + if isinstance(funcarg, Subs): + deriv = funcarg.expr + old = funcarg.variables[0] + new = funcarg.point[0] + elif isinstance(funcarg, Derivative): + deriv = funcarg + # No information on this. Just assume it was x + old = x + new = funcarg.variables[0] + + if (isinstance(deriv, Derivative) and isinstance(deriv.args[0], + AppliedUndef) and deriv.args[0].func == f and + len(deriv.args[0].args) == 1 and old == x and not + new.has(x) and all(i == deriv.variables[0] for i in + deriv.variables) and x not in ics[funcarg].free_symbols): + + dorder = ode_order(deriv, x) + temp = 'f' + str(dorder) + boundary.update({temp: new, temp + 'val': ics[funcarg]}) + else: + raise ValueError("Invalid boundary conditions for Derivatives") + + + # Separating functions + elif isinstance(funcarg, AppliedUndef): + if (funcarg.func == f and len(funcarg.args) == 1 and + not funcarg.args[0].has(x) and x not in ics[funcarg].free_symbols): + boundary.update({'f0': funcarg.args[0], 'f0val': ics[funcarg]}) + else: + raise ValueError("Invalid boundary conditions for Function") + + else: + raise ValueError("Enter boundary conditions of the form ics={f(point): value, f(x).diff(x, order).subs(x, point): value}") + + ode = SingleODEProblem(eq_orig, func, x, prep=prep, xi=xi, eta=eta) + user_hint = kwargs.get('hint', 'default') + # Used when dsolve is called without an explicit hint. + # We exit early to return the first valid match + early_exit = (user_hint=='default') + user_hint = user_hint.removesuffix('_Integral') + user_map = solver_map + # An explicit hint has been given to dsolve + # Skip matching code for other hints + if user_hint not in ['default', 'all', 'all_Integral', 'best'] and user_hint in solver_map: + user_map = {user_hint: solver_map[user_hint]} + + for hint in user_map: + solver = user_map[hint](ode) + if solver.matches(): + matching_hints[hint] = solver + if user_map[hint].has_integral: + matching_hints[hint + "_Integral"] = solver + if dict and early_exit: + matching_hints["default"] = hint + return matching_hints + + eq = expand(eq) + # Precondition to try remove f(x) from highest order derivative + reduced_eq = None + if eq.is_Add: + deriv_coef = eq.coeff(f(x).diff(x, order)) + if deriv_coef not in (1, 0): + r = deriv_coef.match(a*f(x)**c1) + if r and r[c1]: + den = f(x)**r[c1] + reduced_eq = Add(*[arg/den for arg in eq.args]) + if not reduced_eq: + reduced_eq = eq + + if order == 1: + + # NON-REDUCED FORM OF EQUATION matches + r = collect(eq, df, exact=True).match(d + e * df) + if r: + r['d'] = d + r['e'] = e + r['y'] = y + r[d] = r[d].subs(f(x), y) + r[e] = r[e].subs(f(x), y) + + # FIRST ORDER POWER SERIES WHICH NEEDS INITIAL CONDITIONS + # TODO: Hint first order series should match only if d/e is analytic. + # For now, only d/e and (d/e).diff(arg) is checked for existence at + # at a given point. + # This is currently done internally in ode_1st_power_series. + point = boundary.get('f0', 0) + value = boundary.get('f0val', C1) + check = cancel(r[d]/r[e]) + check1 = check.subs({x: point, y: value}) + if not check1.has(oo) and not check1.has(zoo) and \ + not check1.has(nan) and not check1.has(-oo): + check2 = (check1.diff(x)).subs({x: point, y: value}) + if not check2.has(oo) and not check2.has(zoo) and \ + not check2.has(nan) and not check2.has(-oo): + rseries = r.copy() + rseries.update({'terms': terms, 'f0': point, 'f0val': value}) + matching_hints["1st_power_series"] = rseries + + elif order == 2: + # Homogeneous second order differential equation of the form + # a3*f(x).diff(x, 2) + b3*f(x).diff(x) + c3 + # It has a definite power series solution at point x0 if, b3/a3 and c3/a3 + # are analytic at x0. + deq = a3*(f(x).diff(x, 2)) + b3*df + c3*f(x) + r = collect(reduced_eq, + [f(x).diff(x, 2), f(x).diff(x), f(x)]).match(deq) + ordinary = False + if r: + if not all(r[key].is_polynomial() for key in r): + n, d = reduced_eq.as_numer_denom() + reduced_eq = expand(n) + r = collect(reduced_eq, + [f(x).diff(x, 2), f(x).diff(x), f(x)]).match(deq) + if r and r[a3] != 0: + p = cancel(r[b3]/r[a3]) # Used below + q = cancel(r[c3]/r[a3]) # Used below + point = kwargs.get('x0', 0) + check = p.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + check = q.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + ordinary = True + r.update({'a3': a3, 'b3': b3, 'c3': c3, 'x0': point, 'terms': terms}) + matching_hints["2nd_power_series_ordinary"] = r + + # Checking if the differential equation has a regular singular point + # at x0. It has a regular singular point at x0, if (b3/a3)*(x - x0) + # and (c3/a3)*((x - x0)**2) are analytic at x0. + if not ordinary: + p = cancel((x - point)*p) + check = p.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + q = cancel(((x - point)**2)*q) + check = q.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + coeff_dict = {'p': p, 'q': q, 'x0': point, 'terms': terms} + matching_hints["2nd_power_series_regular"] = coeff_dict + + + # Order keys based on allhints. + retlist = [i for i in allhints if i in matching_hints] + if dict: + # Dictionaries are ordered arbitrarily, so make note of which + # hint would come first for dsolve(). Use an ordered dict in Py 3. + matching_hints["default"] = retlist[0] if retlist else None + matching_hints["ordered_hints"] = tuple(retlist) + return matching_hints + else: + return tuple(retlist) + + +def classify_sysode(eq, funcs=None, **kwargs): + r""" + Returns a dictionary of parameter names and values that define the system + of ordinary differential equations in ``eq``. + The parameters are further used in + :py:meth:`~sympy.solvers.ode.dsolve` for solving that system. + + Some parameter names and values are: + + 'is_linear' (boolean), which tells whether the given system is linear. + Note that "linear" here refers to the operator: terms such as ``x*diff(x,t)`` are + nonlinear, whereas terms like ``sin(t)*diff(x,t)`` are still linear operators. + + 'func' (list) contains the :py:class:`~sympy.core.function.Function`s that + appear with a derivative in the ODE, i.e. those that we are trying to solve + the ODE for. + + 'order' (dict) with the maximum derivative for each element of the 'func' + parameter. + + 'func_coeff' (dict or Matrix) with the coefficient for each triple ``(equation number, + function, order)```. The coefficients are those subexpressions that do not + appear in 'func', and hence can be considered constant for purposes of ODE + solving. The value of this parameter can also be a Matrix if the system of ODEs are + linear first order of the form X' = AX where X is the vector of dependent variables. + Here, this function returns the coefficient matrix A. + + 'eq' (list) with the equations from ``eq``, sympified and transformed into + expressions (we are solving for these expressions to be zero). + + 'no_of_equations' (int) is the number of equations (same as ``len(eq)``). + + 'type_of_equation' (string) is an internal classification of the type of + ODE. + + 'is_constant' (boolean), which tells if the system of ODEs is constant coefficient + or not. This key is temporary addition for now and is in the match dict only when + the system of ODEs is linear first order constant coefficient homogeneous. So, this + key's value is True for now if it is available else it does not exist. + + 'is_homogeneous' (boolean), which tells if the system of ODEs is homogeneous. Like the + key 'is_constant', this key is a temporary addition and it is True since this key value + is available only when the system is linear first order constant coefficient homogeneous. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode-toc1.htm + -A. D. Polyanin and A. V. Manzhirov, Handbook of Mathematics for Engineers and Scientists + + Examples + ======== + + >>> from sympy import Function, Eq, symbols, diff + >>> from sympy.solvers.ode.ode import classify_sysode + >>> from sympy.abc import t + >>> f, x, y = symbols('f, x, y', cls=Function) + >>> k, l, m, n = symbols('k, l, m, n', Integer=True) + >>> x1 = diff(x(t), t) ; y1 = diff(y(t), t) + >>> x2 = diff(x(t), t, t) ; y2 = diff(y(t), t, t) + >>> eq = (Eq(x1, 12*x(t) - 6*y(t)), Eq(y1, 11*x(t) + 3*y(t))) + >>> classify_sysode(eq) + {'eq': [-12*x(t) + 6*y(t) + Derivative(x(t), t), -11*x(t) - 3*y(t) + Derivative(y(t), t)], 'func': [x(t), y(t)], + 'func_coeff': {(0, x(t), 0): -12, (0, x(t), 1): 1, (0, y(t), 0): 6, (0, y(t), 1): 0, (1, x(t), 0): -11, (1, x(t), 1): 0, (1, y(t), 0): -3, (1, y(t), 1): 1}, 'is_linear': True, 'no_of_equation': 2, 'order': {x(t): 1, y(t): 1}, 'type_of_equation': None} + >>> eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t) + 2), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + >>> classify_sysode(eq) + {'eq': [-t**2*y(t) - 5*t*x(t) + Derivative(x(t), t) - 2, t**2*x(t) - 5*t*y(t) + Derivative(y(t), t)], + 'func': [x(t), y(t)], 'func_coeff': {(0, x(t), 0): -5*t, (0, x(t), 1): 1, (0, y(t), 0): -t**2, (0, y(t), 1): 0, + (1, x(t), 0): t**2, (1, x(t), 1): 0, (1, y(t), 0): -5*t, (1, y(t), 1): 1}, 'is_linear': True, 'no_of_equation': 2, + 'order': {x(t): 1, y(t): 1}, 'type_of_equation': None} + + """ + + # Sympify equations and convert iterables of equations into + # a list of equations + def _sympify(eq): + return list(map(sympify, eq if iterable(eq) else [eq])) + + eq, funcs = (_sympify(w) for w in [eq, funcs]) + for i, fi in enumerate(eq): + if isinstance(fi, Equality): + eq[i] = fi.lhs - fi.rhs + + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + matching_hints = {"no_of_equation":i+1} + matching_hints['eq'] = eq + if i==0: + raise ValueError("classify_sysode() works for systems of ODEs. " + "For scalar ODEs, classify_ode should be used") + + # find all the functions if not given + order = {} + if funcs==[None]: + funcs = _extract_funcs(eq) + + funcs = list(set(funcs)) + if len(funcs) != len(eq): + raise ValueError("Number of functions given is not equal to the number of equations %s" % funcs) + + # This logic of list of lists in funcs to + # be replaced later. + func_dict = {} + for func in funcs: + if not order.get(func, False): + max_order = 0 + for i, eqs_ in enumerate(eq): + order_ = ode_order(eqs_,func) + if max_order < order_: + max_order = order_ + eq_no = i + if eq_no in func_dict: + func_dict[eq_no] = [func_dict[eq_no], func] + else: + func_dict[eq_no] = func + order[func] = max_order + + funcs = [func_dict[i] for i in range(len(func_dict))] + matching_hints['func'] = funcs + for func in funcs: + if isinstance(func, list): + for func_elem in func: + if len(func_elem.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + else: + if func and len(func.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + + # find the order of all equation in system of odes + matching_hints["order"] = order + + # find coefficients of terms f(t), diff(f(t),t) and higher derivatives + # and similarly for other functions g(t), diff(g(t),t) in all equations. + # Here j denotes the equation number, funcs[l] denotes the function about + # which we are talking about and k denotes the order of function funcs[l] + # whose coefficient we are calculating. + def linearity_check(eqs, j, func, is_linear_): + for k in range(order[func] + 1): + func_coef[j, func, k] = collect(eqs.expand(), [diff(func, t, k)]).coeff(diff(func, t, k)) + if is_linear_ == True: + if func_coef[j, func, k] == 0: + if k == 0: + coef = eqs.as_independent(func, as_Add=True)[1] + for xr in range(1, ode_order(eqs,func) + 1): + coef -= eqs.as_independent(diff(func, t, xr), as_Add=True)[1] + if coef != 0: + is_linear_ = False + else: + if eqs.as_independent(diff(func, t, k), as_Add=True)[1]: + is_linear_ = False + else: + for func_ in funcs: + if isinstance(func_, list): + for elem_func_ in func_: + dep = func_coef[j, func, k].as_independent(elem_func_, as_Add=True)[1] + if dep != 0: + is_linear_ = False + else: + dep = func_coef[j, func, k].as_independent(func_, as_Add=True)[1] + if dep != 0: + is_linear_ = False + return is_linear_ + + func_coef = {} + is_linear = True + for j, eqs in enumerate(eq): + for func in funcs: + if isinstance(func, list): + for func_elem in func: + is_linear = linearity_check(eqs, j, func_elem, is_linear) + else: + is_linear = linearity_check(eqs, j, func, is_linear) + matching_hints['func_coeff'] = func_coef + matching_hints['is_linear'] = is_linear + + + if len(set(order.values())) == 1: + order_eq = list(matching_hints['order'].values())[0] + if matching_hints['is_linear'] == True: + if matching_hints['no_of_equation'] == 2: + if order_eq == 1: + type_of_equation = check_linear_2eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + # If the equation does not match up with any of the + # general case solvers in systems.py and the number + # of equations is greater than 2, then NotImplementedError + # should be raised. + else: + type_of_equation = None + + else: + if matching_hints['no_of_equation'] == 2: + if order_eq == 1: + type_of_equation = check_nonlinear_2eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + elif matching_hints['no_of_equation'] == 3: + if order_eq == 1: + type_of_equation = check_nonlinear_3eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + else: + type_of_equation = None + else: + type_of_equation = None + + matching_hints['type_of_equation'] = type_of_equation + + return matching_hints + + +def check_linear_2eq_order1(eq, func, func_coef): + x = func[0].func + y = func[1].func + fc = func_coef + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + r = {} + # for equations Eq(a1*diff(x(t),t), b1*x(t) + c1*y(t) + d1) + # and Eq(a2*diff(y(t),t), b2*x(t) + c2*y(t) + d2) + r['a1'] = fc[0,x(t),1] ; r['a2'] = fc[1,y(t),1] + r['b1'] = -fc[0,x(t),0]/fc[0,x(t),1] ; r['b2'] = -fc[1,x(t),0]/fc[1,y(t),1] + r['c1'] = -fc[0,y(t),0]/fc[0,x(t),1] ; r['c2'] = -fc[1,y(t),0]/fc[1,y(t),1] + forcing = [S.Zero,S.Zero] + for i in range(2): + for j in Add.make_args(eq[i]): + if not j.has(x(t), y(t)): + forcing[i] += j + if not (forcing[0].has(t) or forcing[1].has(t)): + # We can handle homogeneous case and simple constant forcings + r['d1'] = forcing[0] + r['d2'] = forcing[1] + else: + # Issue #9244: nonhomogeneous linear systems are not supported + return None + + # Conditions to check for type 6 whose equations are Eq(diff(x(t),t), f(t)*x(t) + g(t)*y(t)) and + # Eq(diff(y(t),t), a*[f(t) + a*h(t)]x(t) + a*[g(t) - h(t)]*y(t)) + p = 0 + q = 0 + p1 = cancel(r['b2']/(cancel(r['b2']/r['c2']).as_numer_denom()[0])) + p2 = cancel(r['b1']/(cancel(r['b1']/r['c1']).as_numer_denom()[0])) + for n, i in enumerate([p1, p2]): + for j in Mul.make_args(collect_const(i)): + if not j.has(t): + q = j + if q and n==0: + if ((r['b2']/j - r['b1'])/(r['c1'] - r['c2']/j)) == j: + p = 1 + elif q and n==1: + if ((r['b1']/j - r['b2'])/(r['c2'] - r['c1']/j)) == j: + p = 2 + # End of condition for type 6 + + if r['d1']!=0 or r['d2']!=0: + return None + else: + if not any(r[k].has(t) for k in 'a1 a2 b1 b2 c1 c2'.split()): + return None + else: + r['b1'] = r['b1']/r['a1'] ; r['b2'] = r['b2']/r['a2'] + r['c1'] = r['c1']/r['a1'] ; r['c2'] = r['c2']/r['a2'] + if p: + return "type6" + else: + # Equations for type 7 are Eq(diff(x(t),t), f(t)*x(t) + g(t)*y(t)) and Eq(diff(y(t),t), h(t)*x(t) + p(t)*y(t)) + return "type7" +def check_nonlinear_2eq_order1(eq, func, func_coef): + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + f = Wild('f') + g = Wild('g') + u, v = symbols('u, v', cls=Dummy) + def check_type(x, y): + r1 = eq[0].match(t*diff(x(t),t) - x(t) + f) + r2 = eq[1].match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = eq[0].match(diff(x(t),t) - x(t)/t + f/t) + r2 = eq[1].match(diff(y(t),t) - y(t)/t + g/t) + if not (r1 and r2): + r1 = (-eq[0]).match(t*diff(x(t),t) - x(t) + f) + r2 = (-eq[1]).match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = (-eq[0]).match(diff(x(t),t) - x(t)/t + f/t) + r2 = (-eq[1]).match(diff(y(t),t) - y(t)/t + g/t) + if r1 and r2 and not (r1[f].subs(diff(x(t),t),u).subs(diff(y(t),t),v).has(t) \ + or r2[g].subs(diff(x(t),t),u).subs(diff(y(t),t),v).has(t)): + return 'type5' + else: + return None + for func_ in func: + if isinstance(func_, list): + x = func[0][0].func + y = func[0][1].func + eq_type = check_type(x, y) + if not eq_type: + eq_type = check_type(y, x) + return eq_type + x = func[0].func + y = func[1].func + fc = func_coef + n = Wild('n', exclude=[x(t),y(t)]) + f1 = Wild('f1', exclude=[v,t]) + f2 = Wild('f2', exclude=[v,t]) + g1 = Wild('g1', exclude=[u,t]) + g2 = Wild('g2', exclude=[u,t]) + for i in range(2): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + r = eq[0].match(diff(x(t),t) - x(t)**n*f) + if r: + g = (diff(y(t),t) - eq[1])/r[f] + if r and not (g.has(x(t)) or g.subs(y(t),v).has(t) or r[f].subs(x(t),u).subs(y(t),v).has(t)): + return 'type1' + r = eq[0].match(diff(x(t),t) - exp(n*x(t))*f) + if r: + g = (diff(y(t),t) - eq[1])/r[f] + if r and not (g.has(x(t)) or g.subs(y(t),v).has(t) or r[f].subs(x(t),u).subs(y(t),v).has(t)): + return 'type2' + g = Wild('g') + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + if r1 and r2 and not (r1[f].subs(x(t),u).subs(y(t),v).has(t) or \ + r2[g].subs(x(t),u).subs(y(t),v).has(t)): + return 'type3' + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + num, den = ( + (r1[f].subs(x(t),u).subs(y(t),v))/ + (r2[g].subs(x(t),u).subs(y(t),v))).as_numer_denom() + R1 = num.match(f1*g1) + R2 = den.match(f2*g2) + # phi = (r1[f].subs(x(t),u).subs(y(t),v))/num + if R1 and R2: + return 'type4' + return None + + +def check_nonlinear_2eq_order2(eq, func, func_coef): + return None + +def check_nonlinear_3eq_order1(eq, func, func_coef): + x = func[0].func + y = func[1].func + z = func[2].func + fc = func_coef + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + u, v, w = symbols('u, v, w', cls=Dummy) + a = Wild('a', exclude=[x(t), y(t), z(t), t]) + b = Wild('b', exclude=[x(t), y(t), z(t), t]) + c = Wild('c', exclude=[x(t), y(t), z(t), t]) + f = Wild('f') + F1 = Wild('F1') + F2 = Wild('F2') + F3 = Wild('F3') + for i in range(3): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + r1 = eq[0].match(diff(x(t),t) - a*y(t)*z(t)) + r2 = eq[1].match(diff(y(t),t) - b*z(t)*x(t)) + r3 = eq[2].match(diff(z(t),t) - c*x(t)*y(t)) + if r1 and r2 and r3: + num1, den1 = r1[a].as_numer_denom() + num2, den2 = r2[b].as_numer_denom() + num3, den3 = r3[c].as_numer_denom() + if solve([num1*u-den1*(v-w), num2*v-den2*(w-u), num3*w-den3*(u-v)],[u, v]): + return 'type1' + r = eq[0].match(diff(x(t),t) - y(t)*z(t)*f) + if r: + r1 = collect_const(r[f]).match(a*f) + r2 = ((diff(y(t),t) - eq[1])/r1[f]).match(b*z(t)*x(t)) + r3 = ((diff(z(t),t) - eq[2])/r1[f]).match(c*x(t)*y(t)) + if r1 and r2 and r3: + num1, den1 = r1[a].as_numer_denom() + num2, den2 = r2[b].as_numer_denom() + num3, den3 = r3[c].as_numer_denom() + if solve([num1*u-den1*(v-w), num2*v-den2*(w-u), num3*w-den3*(u-v)],[u, v]): + return 'type2' + r = eq[0].match(diff(x(t),t) - (F2-F3)) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = eq[1].match(diff(y(t),t) - a*r1[F3] + r1[c]*F1) + if r2: + r3 = (eq[2] == diff(z(t),t) - r1[b]*r2[F1] + r2[a]*r1[F2]) + if r1 and r2 and r3: + return 'type3' + r = eq[0].match(diff(x(t),t) - z(t)*F2 + y(t)*F3) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = (diff(y(t),t) - eq[1]).match(a*x(t)*r1[F3] - r1[c]*z(t)*F1) + if r2: + r3 = (diff(z(t),t) - eq[2] == r1[b]*y(t)*r2[F1] - r2[a]*x(t)*r1[F2]) + if r1 and r2 and r3: + return 'type4' + r = (diff(x(t),t) - eq[0]).match(x(t)*(F2 - F3)) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = (diff(y(t),t) - eq[1]).match(y(t)*(a*r1[F3] - r1[c]*F1)) + if r2: + r3 = (diff(z(t),t) - eq[2] == z(t)*(r1[b]*r2[F1] - r2[a]*r1[F2])) + if r1 and r2 and r3: + return 'type5' + return None + + +def check_nonlinear_3eq_order2(eq, func, func_coef): + return None + + +@vectorize(0) +def odesimp(ode, eq, func, hint): + r""" + Simplifies solutions of ODEs, including trying to solve for ``func`` and + running :py:meth:`~sympy.solvers.ode.constantsimp`. + + It may use knowledge of the type of solution that the hint returns to + apply additional simplifications. + + It also attempts to integrate any :py:class:`~sympy.integrals.integrals.Integral`\s + in the expression, if the hint is not an ``_Integral`` hint. + + This function should have no effect on expressions returned by + :py:meth:`~sympy.solvers.ode.dsolve`, as + :py:meth:`~sympy.solvers.ode.dsolve` already calls + :py:meth:`~sympy.solvers.ode.ode.odesimp`, but the individual hint functions + do not call :py:meth:`~sympy.solvers.ode.ode.odesimp` (because the + :py:meth:`~sympy.solvers.ode.dsolve` wrapper does). Therefore, this + function is designed for mainly internal use. + + Examples + ======== + + >>> from sympy import sin, symbols, dsolve, pprint, Function + >>> from sympy.solvers.ode.ode import odesimp + >>> x, u2, C1= symbols('x,u2,C1') + >>> f = Function('f') + + >>> eq = dsolve(x*f(x).diff(x) - f(x) - x*sin(f(x)/x), f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep_Integral', + ... simplify=False) + >>> pprint(eq, wrap_line=False) + x + ---- + f(x) + / + | + | / 1 \ + | -|u1 + -------| + | | /1 \| + | | sin|--|| + | \ \u1// + log(f(x)) = log(C1) + | ---------------- d(u1) + | 2 + | u1 + | + / + + >>> pprint(odesimp(eq, f(x), 1, {C1}, + ... hint='1st_homogeneous_coeff_subs_indep_div_dep' + ... )) #doctest: +SKIP + x + --------- = C1 + /f(x)\ + tan|----| + \2*x / + + """ + x = func.args[0] + f = func.func + C1 = get_numbered_constants(eq, num=1) + constants = eq.free_symbols - ode.free_symbols + + # First, integrate if the hint allows it. + eq = _handle_Integral(eq, func, hint) + if hint.startswith("nth_linear_euler_eq_nonhomogeneous"): + eq = simplify(eq) + if not isinstance(eq, Equality): + raise TypeError("eq should be an instance of Equality") + + # allow simplifications under assumption that symbols are nonzero + eq = eq.xreplace((_:={i: Dummy(nonzero=True) for i in constants})).xreplace({_[i]: i for i in _}) + + # Second, clean up the arbitrary constants. + # Right now, nth linear hints can put as many as 2*order constants in an + # expression. If that number grows with another hint, the third argument + # here should be raised accordingly, or constantsimp() rewritten to handle + # an arbitrary number of constants. + eq = constantsimp(eq, constants) + + # Lastly, now that we have cleaned up the expression, try solving for func. + # When CRootOf is implemented in solve(), we will want to return a CRootOf + # every time instead of an Equality. + + # Get the f(x) on the left if possible. + if eq.rhs == func and not eq.lhs.has(func): + eq = [Eq(eq.rhs, eq.lhs)] + + # make sure we are working with lists of solutions in simplified form. + if eq.lhs == func and not eq.rhs.has(func): + # The solution is already solved + eq = [eq] + + else: + # The solution is not solved, so try to solve it + try: + floats = any(i.is_Float for i in eq.atoms(Number)) + eqsol = solve(eq, func, force=True, rational=False if floats else None) + if not eqsol: + raise NotImplementedError + except (NotImplementedError, PolynomialError): + eq = [eq] + else: + def _expand(expr): + numer, denom = expr.as_numer_denom() + + if denom.is_Add: + return expr + else: + return powsimp(expr.expand(), combine='exp', deep=True) + + # XXX: the rest of odesimp() expects each ``t`` to be in a + # specific normal form: rational expression with numerator + # expanded, but with combined exponential functions (at + # least in this setup all tests pass). + eq = [Eq(f(x), _expand(t)) for t in eqsol] + + # special simplification of the lhs. + if hint.startswith("1st_homogeneous_coeff"): + for j, eqi in enumerate(eq): + newi = logcombine(eqi, force=True) + if isinstance(newi.lhs, log) and newi.rhs == 0: + newi = Eq(newi.lhs.args[0]/C1, C1) + eq[j] = newi + + # We cleaned up the constants before solving to help the solve engine with + # a simpler expression, but the solved expression could have introduced + # things like -C1, so rerun constantsimp() one last time before returning. + for i, eqi in enumerate(eq): + eq[i] = constantsimp(eqi, constants) + eq[i] = constant_renumber(eq[i], ode.free_symbols) + + # If there is only 1 solution, return it; + # otherwise return the list of solutions. + if len(eq) == 1: + eq = eq[0] + return eq + + +def ode_sol_simplicity(sol, func, trysolving=True): + r""" + Returns an extended integer representing how simple a solution to an ODE + is. + + The following things are considered, in order from most simple to least: + + - ``sol`` is solved for ``func``. + - ``sol`` is not solved for ``func``, but can be if passed to solve (e.g., + a solution returned by ``dsolve(ode, func, simplify=False``). + - If ``sol`` is not solved for ``func``, then base the result on the + length of ``sol``, as computed by ``len(str(sol))``. + - If ``sol`` has any unevaluated :py:class:`~sympy.integrals.integrals.Integral`\s, + this will automatically be considered less simple than any of the above. + + This function returns an integer such that if solution A is simpler than + solution B by above metric, then ``ode_sol_simplicity(sola, func) < + ode_sol_simplicity(solb, func)``. + + Currently, the following are the numbers returned, but if the heuristic is + ever improved, this may change. Only the ordering is guaranteed. + + +----------------------------------------------+-------------------+ + | Simplicity | Return | + +==============================================+===================+ + | ``sol`` solved for ``func`` | ``-2`` | + +----------------------------------------------+-------------------+ + | ``sol`` not solved for ``func`` but can be | ``-1`` | + +----------------------------------------------+-------------------+ + | ``sol`` is not solved nor solvable for | ``len(str(sol))`` | + | ``func`` | | + +----------------------------------------------+-------------------+ + | ``sol`` contains an | ``oo`` | + | :obj:`~sympy.integrals.integrals.Integral` | | + +----------------------------------------------+-------------------+ + + ``oo`` here means the SymPy infinity, which should compare greater than + any integer. + + If you already know :py:meth:`~sympy.solvers.solvers.solve` cannot solve + ``sol``, you can use ``trysolving=False`` to skip that step, which is the + only potentially slow step. For example, + :py:meth:`~sympy.solvers.ode.dsolve` with the ``simplify=False`` flag + should do this. + + If ``sol`` is a list of solutions, if the worst solution in the list + returns ``oo`` it returns that, otherwise it returns ``len(str(sol))``, + that is, the length of the string representation of the whole list. + + Examples + ======== + + This function is designed to be passed to ``min`` as the key argument, + such as ``min(listofsolutions, key=lambda i: ode_sol_simplicity(i, + f(x)))``. + + >>> from sympy import symbols, Function, Eq, tan, Integral + >>> from sympy.solvers.ode.ode import ode_sol_simplicity + >>> x, C1, C2 = symbols('x, C1, C2') + >>> f = Function('f') + + >>> ode_sol_simplicity(Eq(f(x), C1*x**2), f(x)) + -2 + >>> ode_sol_simplicity(Eq(x**2 + f(x), C1), f(x)) + -1 + >>> ode_sol_simplicity(Eq(f(x), C1*Integral(2*x, x)), f(x)) + oo + >>> eq1 = Eq(f(x)/tan(f(x)/(2*x)), C1) + >>> eq2 = Eq(f(x)/tan(f(x)/(2*x) + f(x)), C2) + >>> [ode_sol_simplicity(eq, f(x)) for eq in [eq1, eq2]] + [28, 35] + >>> min([eq1, eq2], key=lambda i: ode_sol_simplicity(i, f(x))) + Eq(f(x)/tan(f(x)/(2*x)), C1) + + """ + # TODO: if two solutions are solved for f(x), we still want to be + # able to get the simpler of the two + + # See the docstring for the coercion rules. We check easier (faster) + # things here first, to save time. + + if iterable(sol): + # See if there are Integrals + for i in sol: + if ode_sol_simplicity(i, func, trysolving=trysolving) == oo: + return oo + + return len(str(sol)) + + if sol.has(Integral): + return oo + + # Next, try to solve for func. This code will change slightly when CRootOf + # is implemented in solve(). Probably a CRootOf solution should fall + # somewhere between a normal solution and an unsolvable expression. + + # First, see if they are already solved + if sol.lhs == func and not sol.rhs.has(func) or \ + sol.rhs == func and not sol.lhs.has(func): + return -2 + # We are not so lucky, try solving manually + if trysolving: + try: + sols = solve(sol, func) + if not sols: + raise NotImplementedError + except NotImplementedError: + pass + else: + return -1 + + # Finally, a naive computation based on the length of the string version + # of the expression. This may favor combined fractions because they + # will not have duplicate denominators, and may slightly favor expressions + # with fewer additions and subtractions, as those are separated by spaces + # by the printer. + + # Additional ideas for simplicity heuristics are welcome, like maybe + # checking if a equation has a larger domain, or if constantsimp has + # introduced arbitrary constants numbered higher than the order of a + # given ODE that sol is a solution of. + return len(str(sol)) + + +def _extract_funcs(eqs): + funcs = [] + for eq in eqs: + derivs = [node for node in preorder_traversal(eq) if isinstance(node, Derivative)] + func = [] + for d in derivs: + func += list(d.atoms(AppliedUndef)) + for func_ in func: + funcs.append(func_) + funcs = list(uniq(funcs)) + + return funcs + + +def _get_constant_subexpressions(expr, Cs): + Cs = set(Cs) + Ces = [] + def _recursive_walk(expr): + expr_syms = expr.free_symbols + if expr_syms and expr_syms.issubset(Cs): + Ces.append(expr) + else: + if expr.func == exp: + expr = expr.expand(mul=True) + if expr.func in (Add, Mul): + d = sift(expr.args, lambda i : i.free_symbols.issubset(Cs)) + if len(d[True]) > 1: + x = expr.func(*d[True]) + if not x.is_number: + Ces.append(x) + elif isinstance(expr, Integral): + if expr.free_symbols.issubset(Cs) and \ + all(len(x) == 3 for x in expr.limits): + Ces.append(expr) + for i in expr.args: + _recursive_walk(i) + return + _recursive_walk(expr) + return Ces + +def __remove_linear_redundancies(expr, Cs): + cnts = {i: expr.count(i) for i in Cs} + Cs = [i for i in Cs if cnts[i] > 0] + + def _linear(expr): + if isinstance(expr, Add): + xs = [i for i in Cs if expr.count(i)==cnts[i] \ + and 0 == expr.diff(i, 2)] + d = {} + for x in xs: + y = expr.diff(x) + if y not in d: + d[y]=[] + d[y].append(x) + for y in d: + if len(d[y]) > 1: + d[y].sort(key=str) + for x in d[y][1:]: + expr = expr.subs(x, 0) + return expr + + def _recursive_walk(expr): + if len(expr.args) != 0: + expr = expr.func(*[_recursive_walk(i) for i in expr.args]) + expr = _linear(expr) + return expr + + if isinstance(expr, Equality): + lhs, rhs = [_recursive_walk(i) for i in expr.args] + f = lambda i: isinstance(i, Number) or i in Cs + if isinstance(lhs, Symbol) and lhs in Cs: + rhs, lhs = lhs, rhs + if lhs.func in (Add, Symbol) and rhs.func in (Add, Symbol): + dlhs = sift([lhs] if isinstance(lhs, AtomicExpr) else lhs.args, f) + drhs = sift([rhs] if isinstance(rhs, AtomicExpr) else rhs.args, f) + for i in [True, False]: + for hs in [dlhs, drhs]: + if i not in hs: + hs[i] = [0] + # this calculation can be simplified + lhs = Add(*dlhs[False]) - Add(*drhs[False]) + rhs = Add(*drhs[True]) - Add(*dlhs[True]) + elif lhs.func in (Mul, Symbol) and rhs.func in (Mul, Symbol): + dlhs = sift([lhs] if isinstance(lhs, AtomicExpr) else lhs.args, f) + if True in dlhs: + if False not in dlhs: + dlhs[False] = [1] + lhs = Mul(*dlhs[False]) + rhs = rhs/Mul(*dlhs[True]) + return Eq(lhs, rhs) + else: + return _recursive_walk(expr) + +@vectorize(0) +def constantsimp(expr, constants): + r""" + Simplifies an expression with arbitrary constants in it. + + This function is written specifically to work with + :py:meth:`~sympy.solvers.ode.dsolve`, and is not intended for general use. + + Simplification is done by "absorbing" the arbitrary constants into other + arbitrary constants, numbers, and symbols that they are not independent + of. + + The symbols must all have the same name with numbers after it, for + example, ``C1``, ``C2``, ``C3``. The ``symbolname`` here would be + '``C``', the ``startnumber`` would be 1, and the ``endnumber`` would be 3. + If the arbitrary constants are independent of the variable ``x``, then the + independent symbol would be ``x``. There is no need to specify the + dependent function, such as ``f(x)``, because it already has the + independent symbol, ``x``, in it. + + Because terms are "absorbed" into arbitrary constants and because + constants are renumbered after simplifying, the arbitrary constants in + expr are not necessarily equal to the ones of the same name in the + returned result. + + If two or more arbitrary constants are added, multiplied, or raised to the + power of each other, they are first absorbed together into a single + arbitrary constant. Then the new constant is combined into other terms if + necessary. + + Absorption of constants is done with limited assistance: + + 1. terms of :py:class:`~sympy.core.add.Add`\s are collected to try join + constants so `e^x (C_1 \cos(x) + C_2 \cos(x))` will simplify to `e^x + C_1 \cos(x)`; + + 2. powers with exponents that are :py:class:`~sympy.core.add.Add`\s are + expanded so `e^{C_1 + x}` will be simplified to `C_1 e^x`. + + Use :py:meth:`~sympy.solvers.ode.ode.constant_renumber` to renumber constants + after simplification or else arbitrary numbers on constants may appear, + e.g. `C_1 + C_3 x`. + + In rare cases, a single constant can be "simplified" into two constants. + Every differential equation solution should have as many arbitrary + constants as the order of the differential equation. The result here will + be technically correct, but it may, for example, have `C_1` and `C_2` in + an expression, when `C_1` is actually equal to `C_2`. Use your discretion + in such situations, and also take advantage of the ability to use hints in + :py:meth:`~sympy.solvers.ode.dsolve`. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.solvers.ode.ode import constantsimp + >>> C1, C2, C3, x, y = symbols('C1, C2, C3, x, y') + >>> constantsimp(2*C1*x, {C1, C2, C3}) + C1*x + >>> constantsimp(C1 + 2 + x, {C1, C2, C3}) + C1 + x + >>> constantsimp(C1*C2 + 2 + C2 + C3*x, {C1, C2, C3}) + C1 + C3*x + + """ + # This function works recursively. The idea is that, for Mul, + # Add, Pow, and Function, if the class has a constant in it, then + # we can simplify it, which we do by recursing down and + # simplifying up. Otherwise, we can skip that part of the + # expression. + + Cs = constants + + orig_expr = expr + + constant_subexprs = _get_constant_subexpressions(expr, Cs) + for xe in constant_subexprs: + xes = list(xe.free_symbols) + if not xes: + continue + if all(expr.count(c) == xe.count(c) for c in xes): + xes.sort(key=str) + expr = expr.subs(xe, xes[0]) + + # try to perform common sub-expression elimination of constant terms + try: + commons, rexpr = cse(expr) + commons.reverse() + rexpr = rexpr[0] + for s in commons: + cs = list(s[1].atoms(Symbol)) + if len(cs) == 1 and cs[0] in Cs and \ + cs[0] not in rexpr.atoms(Symbol) and \ + not any(cs[0] in ex for ex in commons if ex != s): + rexpr = rexpr.subs(s[0], cs[0]) + else: + rexpr = rexpr.subs(*s) + expr = rexpr + except IndexError: + pass + expr = __remove_linear_redundancies(expr, Cs) + + def _conditional_term_factoring(expr): + new_expr = terms_gcd(expr, clear=False, deep=True, expand=False) + + # we do not want to factor exponentials, so handle this separately + if new_expr.is_Mul: + infac = False + asfac = False + for m in new_expr.args: + if isinstance(m, exp): + asfac = True + elif m.is_Add: + infac = any(isinstance(fi, exp) for t in m.args + for fi in Mul.make_args(t)) + if asfac and infac: + new_expr = expr + break + return new_expr + + expr = _conditional_term_factoring(expr) + + # call recursively if more simplification is possible + if orig_expr != expr: + return constantsimp(expr, Cs) + return expr + + +def constant_renumber(expr, variables=None, newconstants=None): + r""" + Renumber arbitrary constants in ``expr`` to use the symbol names as given + in ``newconstants``. In the process, this reorders expression terms in a + standard way. + + If ``newconstants`` is not provided then the new constant names will be + ``C1``, ``C2`` etc. Otherwise ``newconstants`` should be an iterable + giving the new symbols to use for the constants in order. + + The ``variables`` argument is a list of non-constant symbols. All other + free symbols found in ``expr`` are assumed to be constants and will be + renumbered. If ``variables`` is not given then any numbered symbol + beginning with ``C`` (e.g. ``C1``) is assumed to be a constant. + + Symbols are renumbered based on ``.sort_key()``, so they should be + numbered roughly in the order that they appear in the final, printed + expression. Note that this ordering is based in part on hashes, so it can + produce different results on different machines. + + The structure of this function is very similar to that of + :py:meth:`~sympy.solvers.ode.constantsimp`. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.solvers.ode.ode import constant_renumber + >>> x, C1, C2, C3 = symbols('x,C1:4') + >>> expr = C3 + C2*x + C1*x**2 + >>> expr + C1*x**2 + C2*x + C3 + >>> constant_renumber(expr) + C1 + C2*x + C3*x**2 + + The ``variables`` argument specifies which are constants so that the + other symbols will not be renumbered: + + >>> constant_renumber(expr, [C1, x]) + C1*x**2 + C2 + C3*x + + The ``newconstants`` argument is used to specify what symbols to use when + replacing the constants: + + >>> constant_renumber(expr, [x], newconstants=symbols('E1:4')) + E1 + E2*x + E3*x**2 + + """ + + # System of expressions + if isinstance(expr, (set, list, tuple)): + return type(expr)(constant_renumber(Tuple(*expr), + variables=variables, newconstants=newconstants)) + + # Symbols in solution but not ODE are constants + if variables is not None: + variables = set(variables) + free_symbols = expr.free_symbols + constantsymbols = list(free_symbols - variables) + # Any Cn is a constant... + else: + variables = set() + isconstant = lambda s: s.startswith('C') and s[1:].isdigit() + constantsymbols = [sym for sym in expr.free_symbols if isconstant(sym.name)] + + # Find new constants checking that they aren't already in the ODE + if newconstants is None: + iter_constants = numbered_symbols(start=1, prefix='C', exclude=variables) + else: + iter_constants = (sym for sym in newconstants if sym not in variables) + + constants_found = [] + + # make a mapping to send all constantsymbols to S.One and use + # that to make sure that term ordering is not dependent on + # the indexed value of C + C_1 = [(ci, S.One) for ci in constantsymbols] + sort_key=lambda arg: default_sort_key(arg.subs(C_1)) + + def _constant_renumber(expr): + r""" + We need to have an internal recursive function + """ + + # For system of expressions + if isinstance(expr, Tuple): + renumbered = [_constant_renumber(e) for e in expr] + return Tuple(*renumbered) + + if isinstance(expr, Equality): + return Eq( + _constant_renumber(expr.lhs), + _constant_renumber(expr.rhs)) + + if type(expr) not in (Mul, Add, Pow) and not expr.is_Function and \ + not expr.has(*constantsymbols): + # Base case, as above. Hope there aren't constants inside + # of some other class, because they won't be renumbered. + return expr + elif expr.is_Piecewise: + return expr + elif expr in constantsymbols: + if expr not in constants_found: + constants_found.append(expr) + return expr + elif expr.is_Function or expr.is_Pow: + return expr.func( + *[_constant_renumber(x) for x in expr.args]) + else: + sortedargs = list(expr.args) + sortedargs.sort(key=sort_key) + return expr.func(*[_constant_renumber(x) for x in sortedargs]) + expr = _constant_renumber(expr) + + # Don't renumber symbols present in the ODE. + constants_found = [c for c in constants_found if c not in variables] + + # Renumbering happens here + subs_dict = dict(zip(constants_found, iter_constants)) + expr = expr.subs(subs_dict, simultaneous=True) + + return expr + + +def _handle_Integral(expr, func, hint): + r""" + Converts a solution with Integrals in it into an actual solution. + + For most hints, this simply runs ``expr.doit()``. + + """ + if hint == "nth_linear_constant_coeff_homogeneous": + sol = expr + elif not hint.endswith("_Integral"): + sol = expr.doit() + else: + sol = expr + return sol + + +# XXX: Should this function maybe go somewhere else? + + +def homogeneous_order(eq, *symbols): + r""" + Returns the order `n` if `g` is homogeneous and ``None`` if it is not + homogeneous. + + Determines if a function is homogeneous and if so of what order. A + function `f(x, y, \cdots)` is homogeneous of order `n` if `f(t x, t y, + \cdots) = t^n f(x, y, \cdots)`. + + If the function is of two variables, `F(x, y)`, then `f` being homogeneous + of any order is equivalent to being able to rewrite `F(x, y)` as `G(x/y)` + or `H(y/x)`. This fact is used to solve 1st order ordinary differential + equations whose coefficients are homogeneous of the same order (see the + docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep`). + + Symbols can be functions, but every argument of the function must be a + symbol, and the arguments of the function that appear in the expression + must match those given in the list of symbols. If a declared function + appears with different arguments than given in the list of symbols, + ``None`` is returned. + + Examples + ======== + + >>> from sympy import Function, homogeneous_order, sqrt + >>> from sympy.abc import x, y + >>> f = Function('f') + >>> homogeneous_order(f(x), f(x)) is None + True + >>> homogeneous_order(f(x,y), f(y, x), x, y) is None + True + >>> homogeneous_order(f(x), f(x), x) + 1 + >>> homogeneous_order(x**2*f(x)/sqrt(x**2+f(x)**2), x, f(x)) + 2 + >>> homogeneous_order(x**2+f(x), x, f(x)) is None + True + + """ + + if not symbols: + raise ValueError("homogeneous_order: no symbols were given.") + symset = set(symbols) + eq = sympify(eq) + + # The following are not supported + if eq.has(Order, Derivative): + return None + + # These are all constants + if (eq.is_Number or + eq.is_NumberSymbol or + eq.is_number + ): + return S.Zero + + # Replace all functions with dummy variables + dum = numbered_symbols(prefix='d', cls=Dummy) + newsyms = set() + for i in [j for j in symset if getattr(j, 'is_Function')]: + iargs = set(i.args) + if iargs.difference(symset): + return None + else: + dummyvar = next(dum) + eq = eq.subs(i, dummyvar) + symset.remove(i) + newsyms.add(dummyvar) + symset.update(newsyms) + + if not eq.free_symbols & symset: + return None + + # assuming order of a nested function can only be equal to zero + if isinstance(eq, Function): + return None if homogeneous_order( + eq.args[0], *tuple(symset)) != 0 else S.Zero + + # make the replacement of x with x*t and see if t can be factored out + t = Dummy('t', positive=True) # It is sufficient that t > 0 + eqs = separatevars(eq.subs([(i, t*i) for i in symset]), [t], dict=True)[t] + if eqs is S.One: + return S.Zero # there was no term with only t + i, d = eqs.as_independent(t, as_Add=False) + b, e = d.as_base_exp() + if b == t: + return e + + +def ode_2nd_power_series_ordinary(eq, func, order, match): + r""" + Gives a power series solution to a second order homogeneous differential + equation with polynomial coefficients at an ordinary point. A homogeneous + differential equation is of the form + + .. math :: P(x)\frac{d^2y}{dx^2} + Q(x)\frac{dy}{dx} + R(x) y(x) = 0 + + For simplicity it is assumed that `P(x)`, `Q(x)` and `R(x)` are polynomials, + it is sufficient that `\frac{Q(x)}{P(x)}` and `\frac{R(x)}{P(x)}` exists at + `x_{0}`. A recurrence relation is obtained by substituting `y` as `\sum_{n=0}^\infty a_{n}x^{n}`, + in the differential equation, and equating the nth term. Using this relation + various terms can be generated. + + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = f(x).diff(x, 2) + f(x) + >>> pprint(dsolve(eq, hint='2nd_power_series_ordinary')) + / 4 2 \ / 2\ + |x x | | x | / 6\ + f(x) = C2*|-- - -- + 1| + C1*x*|1 - --| + O\x / + \24 2 / \ 6 / + + + References + ========== + - https://tutorial.math.lamar.edu/Classes/DE/SeriesSolutions.aspx + - George E. Simmons, "Differential Equations with Applications and + Historical Notes", p.p 176 - 184 + + """ + x = func.args[0] + f = func.func + C0, C1 = get_numbered_constants(eq, num=2) + n = Dummy("n", integer=True) + s = Wild("s") + k = Wild("k", exclude=[x]) + x0 = match['x0'] + terms = match['terms'] + p = match[match['a3']] + q = match[match['b3']] + r = match[match['c3']] + seriesdict = {} + recurr = Function("r") + + # Generating the recurrence relation which works this way: + # for the second order term the summation begins at n = 2. The coefficients + # p is multiplied with an*(n - 1)*(n - 2)*x**n-2 and a substitution is made such that + # the exponent of x becomes n. + # For example, if p is x, then the second degree recurrence term is + # an*(n - 1)*(n - 2)*x**n-1, substituting (n - 1) as n, it transforms to + # an+1*n*(n - 1)*x**n. + # A similar process is done with the first order and zeroth order term. + + coefflist = [(recurr(n), r), (n*recurr(n), q), (n*(n - 1)*recurr(n), p)] + for index, coeff in enumerate(coefflist): + if coeff[1]: + f2 = powsimp(expand((coeff[1]*(x - x0)**(n - index)).subs(x, x + x0))) + if f2.is_Add: + addargs = f2.args + else: + addargs = [f2] + for arg in addargs: + powm = arg.match(s*x**k) + term = coeff[0]*powm[s] + if not powm[k].is_Symbol: + term = term.subs(n, n - powm[k].as_independent(n)[0]) + startind = powm[k].subs(n, index) + # Seeing if the startterm can be reduced further. + # If it vanishes for n lesser than startind, it is + # equal to summation from n. + if startind: + for i in reversed(range(startind)): + if not term.subs(n, i): + seriesdict[term] = i + else: + seriesdict[term] = i + 1 + break + else: + seriesdict[term] = S.Zero + + # Stripping of terms so that the sum starts with the same number. + teq = S.Zero + suminit = seriesdict.values() + rkeys = seriesdict.keys() + req = Add(*rkeys) + if any(suminit): + maxval = max(suminit) + for term in seriesdict: + val = seriesdict[term] + if val != maxval: + for i in range(val, maxval): + teq += term.subs(n, val) + + finaldict = {} + if teq: + fargs = teq.atoms(AppliedUndef) + if len(fargs) == 1: + finaldict[fargs.pop()] = 0 + else: + maxf = max(fargs, key = lambda x: x.args[0]) + sol = solve(teq, maxf) + if isinstance(sol, list): + sol = sol[0] + finaldict[maxf] = sol + + # Finding the recurrence relation in terms of the largest term. + fargs = req.atoms(AppliedUndef) + maxf = max(fargs, key = lambda x: x.args[0]) + minf = min(fargs, key = lambda x: x.args[0]) + if minf.args[0].is_Symbol: + startiter = 0 + else: + startiter = -minf.args[0].as_independent(n)[0] + lhs = maxf + rhs = solve(req, maxf) + if isinstance(rhs, list): + rhs = rhs[0] + + # Checking how many values are already present + tcounter = len([t for t in finaldict.values() if t]) + + for _ in range(tcounter, terms - 3): # Assuming c0 and c1 to be arbitrary + check = rhs.subs(n, startiter) + nlhs = lhs.subs(n, startiter) + nrhs = check.subs(finaldict) + finaldict[nlhs] = nrhs + startiter += 1 + + # Post processing + series = C0 + C1*(x - x0) + for term in finaldict: + if finaldict[term]: + fact = term.args[0] + series += (finaldict[term].subs([(recurr(0), C0), (recurr(1), C1)])*( + x - x0)**fact) + series = collect(expand_mul(series), [C0, C1]) + Order(x**terms) + return Eq(f(x), series) + + +def ode_2nd_power_series_regular(eq, func, order, match): + r""" + Gives a power series solution to a second order homogeneous differential + equation with polynomial coefficients at a regular point. A second order + homogeneous differential equation is of the form + + .. math :: P(x)\frac{d^2y}{dx^2} + Q(x)\frac{dy}{dx} + R(x) y(x) = 0 + + A point is said to regular singular at `x0` if `x - x0\frac{Q(x)}{P(x)}` + and `(x - x0)^{2}\frac{R(x)}{P(x)}` are analytic at `x0`. For simplicity + `P(x)`, `Q(x)` and `R(x)` are assumed to be polynomials. The algorithm for + finding the power series solutions is: + + 1. Try expressing `(x - x0)P(x)` and `((x - x0)^{2})Q(x)` as power series + solutions about x0. Find `p0` and `q0` which are the constants of the + power series expansions. + 2. Solve the indicial equation `f(m) = m(m - 1) + m*p0 + q0`, to obtain the + roots `m1` and `m2` of the indicial equation. + 3. If `m1 - m2` is a non integer there exists two series solutions. If + `m1 = m2`, there exists only one solution. If `m1 - m2` is an integer, + then the existence of one solution is confirmed. The other solution may + or may not exist. + + The power series solution is of the form `x^{m}\sum_{n=0}^\infty a_{n}x^{n}`. The + coefficients are determined by the following recurrence relation. + `a_{n} = -\frac{\sum_{k=0}^{n-1} q_{n-k} + (m + k)p_{n-k}}{f(m + n)}`. For the case + in which `m1 - m2` is an integer, it can be seen from the recurrence relation + that for the lower root `m`, when `n` equals the difference of both the + roots, the denominator becomes zero. So if the numerator is not equal to zero, + a second series solution exists. + + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = x*(f(x).diff(x, 2)) + 2*(f(x).diff(x)) + x*f(x) + >>> pprint(dsolve(eq, hint='2nd_power_series_regular')) + / 6 4 2 \ + | x x x | + / 4 2 \ C1*|- --- + -- - -- + 1| + |x x | \ 720 24 2 / / 6\ + f(x) = C2*|--- - -- + 1| + ------------------------ + O\x / + \120 6 / x + + + References + ========== + - George E. Simmons, "Differential Equations with Applications and + Historical Notes", p.p 176 - 184 + + """ + x = func.args[0] + f = func.func + C0, C1 = get_numbered_constants(eq, num=2) + m = Dummy("m") # for solving the indicial equation + x0 = match['x0'] + terms = match['terms'] + p = match['p'] + q = match['q'] + + # Generating the indicial equation + indicial = [] + for term in [p, q]: + if not term.has(x): + indicial.append(term) + else: + term = series(term, x=x, n=1, x0=x0) + if isinstance(term, Order): + indicial.append(S.Zero) + else: + for arg in term.args: + if not arg.has(x): + indicial.append(arg) + break + + p0, q0 = indicial + sollist = solve(m*(m - 1) + m*p0 + q0, m) + if sollist and isinstance(sollist, list) and all( + sol.is_real for sol in sollist): + serdict1 = {} + serdict2 = {} + if len(sollist) == 1: + # Only one series solution exists in this case. + m1 = m2 = sollist.pop() + if terms-m1-1 <= 0: + return Eq(f(x), Order(terms)) + serdict1 = _frobenius(terms-m1-1, m1, p0, q0, p, q, x0, x, C0) + + else: + m1 = sollist[0] + m2 = sollist[1] + if m1 < m2: + m1, m2 = m2, m1 + # Irrespective of whether m1 - m2 is an integer or not, one + # Frobenius series solution exists. + serdict1 = _frobenius(terms-m1-1, m1, p0, q0, p, q, x0, x, C0) + if not (m1 - m2).is_integer: + # Second frobenius series solution exists. + serdict2 = _frobenius(terms-m2-1, m2, p0, q0, p, q, x0, x, C1) + else: + # Check if second frobenius series solution exists. + serdict2 = _frobenius(terms-m2-1, m2, p0, q0, p, q, x0, x, C1, check=m1) + + if serdict1: + finalseries1 = C0 + for key in serdict1: + power = int(key.name[1:]) + finalseries1 += serdict1[key]*(x - x0)**power + finalseries1 = (x - x0)**m1*finalseries1 + finalseries2 = S.Zero + if serdict2: + for key in serdict2: + power = int(key.name[1:]) + finalseries2 += serdict2[key]*(x - x0)**power + finalseries2 += C1 + finalseries2 = (x - x0)**m2*finalseries2 + return Eq(f(x), collect(finalseries1 + finalseries2, + [C0, C1]) + Order(x**terms)) + + +def _frobenius(n, m, p0, q0, p, q, x0, x, c, check=None): + r""" + Returns a dict with keys as coefficients and values as their values in terms of C0 + """ + n = int(n) + # In cases where m1 - m2 is not an integer + m2 = check + + d = Dummy("d") + numsyms = numbered_symbols("C", start=0) + numsyms = [next(numsyms) for i in range(n + 1)] + serlist = [] + for ser in [p, q]: + # Order term not present + if ser.is_polynomial(x) and Poly(ser, x).degree() <= n: + if x0: + ser = ser.subs(x, x + x0) + dict_ = Poly(ser, x).as_dict() + # Order term present + else: + tseries = series(ser, x=x0, n=n+1) + # Removing order + dict_ = Poly(list(ordered(tseries.args))[: -1], x).as_dict() + # Fill in with zeros, if coefficients are zero. + for i in range(n + 1): + if (i,) not in dict_: + dict_[(i,)] = S.Zero + serlist.append(dict_) + + pseries = serlist[0] + qseries = serlist[1] + indicial = d*(d - 1) + d*p0 + q0 + frobdict = {} + for i in range(1, n + 1): + num = c*(m*pseries[(i,)] + qseries[(i,)]) + for j in range(1, i): + sym = Symbol("C" + str(j)) + num += frobdict[sym]*((m + j)*pseries[(i - j,)] + qseries[(i - j,)]) + + # Checking for cases when m1 - m2 is an integer. If num equals zero + # then a second Frobenius series solution cannot be found. If num is not zero + # then set constant as zero and proceed. + if m2 is not None and i == m2 - m: + if num: + return False + else: + frobdict[numsyms[i]] = S.Zero + else: + frobdict[numsyms[i]] = -num/(indicial.subs(d, m+i)) + + return frobdict + +def _remove_redundant_solutions(eq, solns, order, var): + r""" + Remove redundant solutions from the set of solutions. + + This function is needed because otherwise dsolve can return + redundant solutions. As an example consider: + + eq = Eq((f(x).diff(x, 2))*f(x).diff(x), 0) + + There are two ways to find solutions to eq. The first is to solve f(x).diff(x, 2) = 0 + leading to solution f(x)=C1 + C2*x. The second is to solve the equation f(x).diff(x) = 0 + leading to the solution f(x) = C1. In this particular case we then see + that the second solution is a special case of the first and we do not + want to return it. + + This does not always happen. If we have + + eq = Eq((f(x)**2-4)*(f(x).diff(x)-4), 0) + + then we get the algebraic solution f(x) = [-2, 2] and the integral solution + f(x) = x + C1 and in this case the two solutions are not equivalent wrt + initial conditions so both should be returned. + """ + def is_special_case_of(soln1, soln2): + return _is_special_case_of(soln1, soln2, eq, order, var) + + unique_solns = [] + for soln1 in solns: + for soln2 in unique_solns.copy(): + if is_special_case_of(soln1, soln2): + break + elif is_special_case_of(soln2, soln1): + unique_solns.remove(soln2) + else: + unique_solns.append(soln1) + + return unique_solns + +def _is_special_case_of(soln1, soln2, eq, order, var): + r""" + True if soln1 is found to be a special case of soln2 wrt some value of the + constants that appear in soln2. False otherwise. + """ + # The solutions returned by dsolve may be given explicitly or implicitly. + # We will equate the sol1=(soln1.rhs - soln1.lhs), sol2=(soln2.rhs - soln2.lhs) + # of the two solutions. + # + # Since this is supposed to hold for all x it also holds for derivatives. + # For an order n ode we should be able to differentiate + # each solution n times to get n+1 equations. + # + # We then try to solve those n+1 equations for the integrations constants + # in sol2. If we can find a solution that does not depend on x then it + # means that some value of the constants in sol1 is a special case of + # sol2 corresponding to a particular choice of the integration constants. + + # In case the solution is in implicit form we subtract the sides + soln1 = soln1.rhs - soln1.lhs + soln2 = soln2.rhs - soln2.lhs + + # Work for the series solution + if soln1.has(Order) and soln2.has(Order): + if soln1.getO() == soln2.getO(): + soln1 = soln1.removeO() + soln2 = soln2.removeO() + else: + return False + elif soln1.has(Order) or soln2.has(Order): + return False + + constants1 = soln1.free_symbols.difference(eq.free_symbols) + constants2 = soln2.free_symbols.difference(eq.free_symbols) + + constants1_new = get_numbered_constants(Tuple(soln1, soln2), len(constants1)) + if len(constants1) == 1: + constants1_new = {constants1_new} + for c_old, c_new in zip(constants1, constants1_new): + soln1 = soln1.subs(c_old, c_new) + + # n equations for sol1 = sol2, sol1'=sol2', ... + lhs = soln1 + rhs = soln2 + eqns = [Eq(lhs, rhs)] + for n in range(1, order): + lhs = lhs.diff(var) + rhs = rhs.diff(var) + eq = Eq(lhs, rhs) + eqns.append(eq) + + # BooleanTrue/False awkwardly show up for trivial equations + if any(isinstance(eq, BooleanFalse) for eq in eqns): + return False + eqns = [eq for eq in eqns if not isinstance(eq, BooleanTrue)] + + try: + constant_solns = solve(eqns, constants2) + except NotImplementedError: + return False + + # Sometimes returns a dict and sometimes a list of dicts + if isinstance(constant_solns, dict): + constant_solns = [constant_solns] + + # after solving the issue 17418, maybe we don't need the following checksol code. + for constant_soln in constant_solns: + for eq in eqns: + eq=eq.rhs-eq.lhs + if checksol(eq, constant_soln) is not True: + return False + + # If any solution gives all constants as expressions that don't depend on + # x then there exists constants for soln2 that give soln1 + for constant_soln in constant_solns: + if not any(c.has(var) for c in constant_soln.values()): + return True + + return False + + +def ode_1st_power_series(eq, func, order, match): + r""" + The power series solution is a method which gives the Taylor series expansion + to the solution of a differential equation. + + For a first order differential equation `\frac{dy}{dx} = h(x, y)`, a power + series solution exists at a point `x = x_{0}` if `h(x, y)` is analytic at `x_{0}`. + The solution is given by + + .. math:: y(x) = y(x_{0}) + \sum_{n = 1}^{\infty} \frac{F_{n}(x_{0},b)(x - x_{0})^n}{n!}, + + where `y(x_{0}) = b` is the value of y at the initial value of `x_{0}`. + To compute the values of the `F_{n}(x_{0},b)` the following algorithm is + followed, until the required number of terms are generated. + + 1. `F_1 = h(x_{0}, b)` + 2. `F_{n+1} = \frac{\partial F_{n}}{\partial x} + \frac{\partial F_{n}}{\partial y}F_{1}` + + Examples + ======== + + >>> from sympy import Function, pprint, exp, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = exp(x)*(f(x).diff(x)) - f(x) + >>> pprint(dsolve(eq, hint='1st_power_series')) + 3 4 5 + C1*x C1*x C1*x / 6\ + f(x) = C1 + C1*x - ----- + ----- + ----- + O\x / + 6 24 60 + + + References + ========== + + - Travis W. Walker, Analytic power series technique for solving first-order + differential equations, p.p 17, 18 + + """ + x = func.args[0] + y = match['y'] + f = func.func + h = -match[match['d']]/match[match['e']] + point = match['f0'] + value = match['f0val'] + terms = match['terms'] + + # First term + F = h + if not h: + return Eq(f(x), value) + + # Initialization + series = value + if terms > 1: + hc = h.subs({x: point, y: value}) + if hc.has(oo) or hc.has(nan) or hc.has(zoo): + # Derivative does not exist, not analytic + return Eq(f(x), oo) + elif hc: + series += hc*(x - point) + + for factcount in range(2, terms): + Fnew = F.diff(x) + F.diff(y)*h + Fnewc = Fnew.subs({x: point, y: value}) + # Same logic as above + if Fnewc.has(oo) or Fnewc.has(nan) or Fnewc.has(-oo) or Fnewc.has(zoo): + return Eq(f(x), oo) + series += Fnewc*((x - point)**factcount)/factorial(factcount) + F = Fnew + series += Order(x**terms) + return Eq(f(x), series) + + +def checkinfsol(eq, infinitesimals, func=None, order=None): + r""" + This function is used to check if the given infinitesimals are the + actual infinitesimals of the given first order differential equation. + This method is specific to the Lie Group Solver of ODEs. + + As of now, it simply checks, by substituting the infinitesimals in the + partial differential equation. + + + .. math:: \frac{\partial \eta}{\partial x} + \left(\frac{\partial \eta}{\partial y} + - \frac{\partial \xi}{\partial x}\right)*h + - \frac{\partial \xi}{\partial y}*h^{2} + - \xi\frac{\partial h}{\partial x} - \eta\frac{\partial h}{\partial y} = 0 + + + where `\eta`, and `\xi` are the infinitesimals and `h(x,y) = \frac{dy}{dx}` + + The infinitesimals should be given in the form of a list of dicts + ``[{xi(x, y): inf, eta(x, y): inf}]``, corresponding to the + output of the function infinitesimals. It returns a list + of values of the form ``[(True/False, sol)]`` where ``sol`` is the value + obtained after substituting the infinitesimals in the PDE. If it + is ``True``, then ``sol`` would be 0. + + """ + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + if not func: + eq, func = _preprocess(eq) + variables = func.args + if len(variables) != 1: + raise ValueError("ODE's have only one independent variable") + else: + x = variables[0] + if not order: + order = ode_order(eq, func) + if order != 1: + raise NotImplementedError("Lie groups solver has been implemented " + "only for first order differential equations") + else: + df = func.diff(x) + a = Wild('a', exclude = [df]) + b = Wild('b', exclude = [df]) + match = collect(expand(eq), df).match(a*df + b) + + if match: + h = -simplify(match[b]/match[a]) + else: + try: + sol = solve(eq, df) + except NotImplementedError: + raise NotImplementedError("Infinitesimals for the " + "first order ODE could not be found") + else: + h = sol[0] # Find infinitesimals for one solution + + y = Dummy('y') + h = h.subs(func, y) + xi = Function('xi')(x, y) + eta = Function('eta')(x, y) + dxi = Function('xi')(x, func) + deta = Function('eta')(x, func) + pde = (eta.diff(x) + (eta.diff(y) - xi.diff(x))*h - + (xi.diff(y))*h**2 - xi*(h.diff(x)) - eta*(h.diff(y))) + soltup = [] + for sol in infinitesimals: + tsol = {xi: S(sol[dxi]).subs(func, y), + eta: S(sol[deta]).subs(func, y)} + sol = simplify(pde.subs(tsol).doit()) + if sol: + soltup.append((False, sol.subs(y, func))) + else: + soltup.append((True, 0)) + return soltup + + +def sysode_linear_2eq_order1(match_): + x = match_['func'][0].func + y = match_['func'][1].func + func = match_['func'] + fc = match_['func_coeff'] + eq = match_['eq'] + r = {} + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + for i in range(2): + eq[i] = Add(*[terms/fc[i,func[i],1] for terms in Add.make_args(eq[i])]) + + # for equations Eq(a1*diff(x(t),t), a*x(t) + b*y(t) + k1) + # and Eq(a2*diff(x(t),t), c*x(t) + d*y(t) + k2) + r['a'] = -fc[0,x(t),0]/fc[0,x(t),1] + r['c'] = -fc[1,x(t),0]/fc[1,y(t),1] + r['b'] = -fc[0,y(t),0]/fc[0,x(t),1] + r['d'] = -fc[1,y(t),0]/fc[1,y(t),1] + forcing = [S.Zero,S.Zero] + for i in range(2): + for j in Add.make_args(eq[i]): + if not j.has(x(t), y(t)): + forcing[i] += j + if not (forcing[0].has(t) or forcing[1].has(t)): + r['k1'] = forcing[0] + r['k2'] = forcing[1] + else: + raise NotImplementedError("Only homogeneous problems are supported" + + " (and constant inhomogeneity)") + + if match_['type_of_equation'] == 'type6': + sol = _linear_2eq_order1_type6(x, y, t, r, eq) + if match_['type_of_equation'] == 'type7': + sol = _linear_2eq_order1_type7(x, y, t, r, eq) + return sol + +def _linear_2eq_order1_type6(x, y, t, r, eq): + r""" + The equations of this type of ode are . + + .. math:: x' = f(t) x + g(t) y + + .. math:: y' = a [f(t) + a h(t)] x + a [g(t) - h(t)] y + + This is solved by first multiplying the first equation by `-a` and adding + it to the second equation to obtain + + .. math:: y' - a x' = -a h(t) (y - a x) + + Setting `U = y - ax` and integrating the equation we arrive at + + .. math:: y - ax = C_1 e^{-a \int h(t) \,dt} + + and on substituting the value of y in first equation give rise to first order ODEs. After solving for + `x`, we can obtain `y` by substituting the value of `x` in second equation. + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + p = 0 + q = 0 + p1 = cancel(r['c']/cancel(r['c']/r['d']).as_numer_denom()[0]) + p2 = cancel(r['a']/cancel(r['a']/r['b']).as_numer_denom()[0]) + for n, i in enumerate([p1, p2]): + for j in Mul.make_args(collect_const(i)): + if not j.has(t): + q = j + if q!=0 and n==0: + if ((r['c']/j - r['a'])/(r['b'] - r['d']/j)) == j: + p = 1 + s = j + break + if q!=0 and n==1: + if ((r['a']/j - r['c'])/(r['d'] - r['b']/j)) == j: + p = 2 + s = j + break + + if p == 1: + equ = diff(x(t),t) - r['a']*x(t) - r['b']*(s*x(t) + C1*exp(-s*Integral(r['b'] - r['d']/s, t))) + hint1 = classify_ode(equ)[1] + sol1 = dsolve(equ, hint=hint1+'_Integral').rhs + sol2 = s*sol1 + C1*exp(-s*Integral(r['b'] - r['d']/s, t)) + elif p ==2: + equ = diff(y(t),t) - r['c']*y(t) - r['d']*s*y(t) + C1*exp(-s*Integral(r['d'] - r['b']/s, t)) + hint1 = classify_ode(equ)[1] + sol2 = dsolve(equ, hint=hint1+'_Integral').rhs + sol1 = s*sol2 + C1*exp(-s*Integral(r['d'] - r['b']/s, t)) + return [Eq(x(t), sol1), Eq(y(t), sol2)] + +def _linear_2eq_order1_type7(x, y, t, r, eq): + r""" + The equations of this type of ode are . + + .. math:: x' = f(t) x + g(t) y + + .. math:: y' = h(t) x + p(t) y + + Differentiating the first equation and substituting the value of `y` + from second equation will give a second-order linear equation + + .. math:: g x'' - (fg + gp + g') x' + (fgp - g^{2} h + f g' - f' g) x = 0 + + This above equation can be easily integrated if following conditions are satisfied. + + 1. `fgp - g^{2} h + f g' - f' g = 0` + + 2. `fgp - g^{2} h + f g' - f' g = ag, fg + gp + g' = bg` + + If first condition is satisfied then it is solved by current dsolve solver and in second case it becomes + a constant coefficient differential equation which is also solved by current solver. + + Otherwise if the above condition fails then, + a particular solution is assumed as `x = x_0(t)` and `y = y_0(t)` + Then the general solution is expressed as + + .. math:: x = C_1 x_0(t) + C_2 x_0(t) \int \frac{g(t) F(t) P(t)}{x_0^{2}(t)} \,dt + + .. math:: y = C_1 y_0(t) + C_2 [\frac{F(t) P(t)}{x_0(t)} + y_0(t) \int \frac{g(t) F(t) P(t)}{x_0^{2}(t)} \,dt] + + where C1 and C2 are arbitrary constants and + + .. math:: F(t) = e^{\int f(t) \,dt}, P(t) = e^{\int p(t) \,dt} + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + e1 = r['a']*r['b']*r['c'] - r['b']**2*r['c'] + r['a']*diff(r['b'],t) - diff(r['a'],t)*r['b'] + e2 = r['a']*r['c']*r['d'] - r['b']*r['c']**2 + diff(r['c'],t)*r['d'] - r['c']*diff(r['d'],t) + m1 = r['a']*r['b'] + r['b']*r['d'] + diff(r['b'],t) + m2 = r['a']*r['c'] + r['c']*r['d'] + diff(r['c'],t) + if e1 == 0: + sol1 = dsolve(r['b']*diff(x(t),t,t) - m1*diff(x(t),t)).rhs + sol2 = dsolve(diff(y(t),t) - r['c']*sol1 - r['d']*y(t)).rhs + elif e2 == 0: + sol2 = dsolve(r['c']*diff(y(t),t,t) - m2*diff(y(t),t)).rhs + sol1 = dsolve(diff(x(t),t) - r['a']*x(t) - r['b']*sol2).rhs + elif not (e1/r['b']).has(t) and not (m1/r['b']).has(t): + sol1 = dsolve(diff(x(t),t,t) - (m1/r['b'])*diff(x(t),t) - (e1/r['b'])*x(t)).rhs + sol2 = dsolve(diff(y(t),t) - r['c']*sol1 - r['d']*y(t)).rhs + elif not (e2/r['c']).has(t) and not (m2/r['c']).has(t): + sol2 = dsolve(diff(y(t),t,t) - (m2/r['c'])*diff(y(t),t) - (e2/r['c'])*y(t)).rhs + sol1 = dsolve(diff(x(t),t) - r['a']*x(t) - r['b']*sol2).rhs + else: + x0 = Function('x0')(t) # x0 and y0 being particular solutions + y0 = Function('y0')(t) + F = exp(Integral(r['a'],t)) + P = exp(Integral(r['d'],t)) + sol1 = C1*x0 + C2*x0*Integral(r['b']*F*P/x0**2, t) + sol2 = C1*y0 + C2*(F*P/x0 + y0*Integral(r['b']*F*P/x0**2, t)) + return [Eq(x(t), sol1), Eq(y(t), sol2)] + + +def sysode_nonlinear_2eq_order1(match_): + func = match_['func'] + eq = match_['eq'] + fc = match_['func_coeff'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + if match_['type_of_equation'] == 'type5': + sol = _nonlinear_2eq_order1_type5(func, t, eq) + return sol + x = func[0].func + y = func[1].func + for i in range(2): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + if match_['type_of_equation'] == 'type1': + sol = _nonlinear_2eq_order1_type1(x, y, t, eq) + elif match_['type_of_equation'] == 'type2': + sol = _nonlinear_2eq_order1_type2(x, y, t, eq) + elif match_['type_of_equation'] == 'type3': + sol = _nonlinear_2eq_order1_type3(x, y, t, eq) + elif match_['type_of_equation'] == 'type4': + sol = _nonlinear_2eq_order1_type4(x, y, t, eq) + return sol + + +def _nonlinear_2eq_order1_type1(x, y, t, eq): + r""" + Equations: + + .. math:: x' = x^n F(x,y) + + .. math:: y' = g(y) F(x,y) + + Solution: + + .. math:: x = \varphi(y), \int \frac{1}{g(y) F(\varphi(y),y)} \,dy = t + C_2 + + where + + if `n \neq 1` + + .. math:: \varphi = [C_1 + (1-n) \int \frac{1}{g(y)} \,dy]^{\frac{1}{1-n}} + + if `n = 1` + + .. math:: \varphi = C_1 e^{\int \frac{1}{g(y)} \,dy} + + where `C_1` and `C_2` are arbitrary constants. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + n = Wild('n', exclude=[x(t),y(t)]) + f = Wild('f') + u, v = symbols('u, v') + r = eq[0].match(diff(x(t),t) - x(t)**n*f) + g = ((diff(y(t),t) - eq[1])/r[f]).subs(y(t),v) + F = r[f].subs(x(t),u).subs(y(t),v) + n = r[n] + if n!=1: + phi = (C1 + (1-n)*Integral(1/g, v))**(1/(1-n)) + else: + phi = C1*exp(Integral(1/g, v)) + phi = phi.doit() + sol2 = solve(Integral(1/(g*F.subs(u,phi)), v).doit() - t - C2, v) + sol = [] + for sols in sol2: + sol.append(Eq(x(t),phi.subs(v, sols))) + sol.append(Eq(y(t), sols)) + return sol + +def _nonlinear_2eq_order1_type2(x, y, t, eq): + r""" + Equations: + + .. math:: x' = e^{\lambda x} F(x,y) + + .. math:: y' = g(y) F(x,y) + + Solution: + + .. math:: x = \varphi(y), \int \frac{1}{g(y) F(\varphi(y),y)} \,dy = t + C_2 + + where + + if `\lambda \neq 0` + + .. math:: \varphi = -\frac{1}{\lambda} log(C_1 - \lambda \int \frac{1}{g(y)} \,dy) + + if `\lambda = 0` + + .. math:: \varphi = C_1 + \int \frac{1}{g(y)} \,dy + + where `C_1` and `C_2` are arbitrary constants. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + n = Wild('n', exclude=[x(t),y(t)]) + f = Wild('f') + u, v = symbols('u, v') + r = eq[0].match(diff(x(t),t) - exp(n*x(t))*f) + g = ((diff(y(t),t) - eq[1])/r[f]).subs(y(t),v) + F = r[f].subs(x(t),u).subs(y(t),v) + n = r[n] + if n: + phi = -1/n*log(C1 - n*Integral(1/g, v)) + else: + phi = C1 + Integral(1/g, v) + phi = phi.doit() + sol2 = solve(Integral(1/(g*F.subs(u,phi)), v).doit() - t - C2, v) + sol = [] + for sols in sol2: + sol.append(Eq(x(t),phi.subs(v, sols))) + sol.append(Eq(y(t), sols)) + return sol + +def _nonlinear_2eq_order1_type3(x, y, t, eq): + r""" + Autonomous system of general form + + .. math:: x' = F(x,y) + + .. math:: y' = G(x,y) + + Assuming `y = y(x, C_1)` where `C_1` is an arbitrary constant is the general + solution of the first-order equation + + .. math:: F(x,y) y'_x = G(x,y) + + Then the general solution of the original system of equations has the form + + .. math:: \int \frac{1}{F(x,y(x,C_1))} \,dx = t + C_1 + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + v = Function('v') + u = Symbol('u') + f = Wild('f') + g = Wild('g') + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + F = r1[f].subs(x(t), u).subs(y(t), v(u)) + G = r2[g].subs(x(t), u).subs(y(t), v(u)) + sol2r = dsolve(Eq(diff(v(u), u), G/F)) + if isinstance(sol2r, Equality): + sol2r = [sol2r] + for sol2s in sol2r: + sol1 = solve(Integral(1/F.subs(v(u), sol2s.rhs), u).doit() - t - C2, u) + sol = [] + for sols in sol1: + sol.append(Eq(x(t), sols)) + sol.append(Eq(y(t), (sol2s.rhs).subs(u, sols))) + return sol + +def _nonlinear_2eq_order1_type4(x, y, t, eq): + r""" + Equation: + + .. math:: x' = f_1(x) g_1(y) \phi(x,y,t) + + .. math:: y' = f_2(x) g_2(y) \phi(x,y,t) + + First integral: + + .. math:: \int \frac{f_2(x)}{f_1(x)} \,dx - \int \frac{g_1(y)}{g_2(y)} \,dy = C + + where `C` is an arbitrary constant. + + On solving the first integral for `x` (resp., `y` ) and on substituting the + resulting expression into either equation of the original solution, one + arrives at a first-order equation for determining `y` (resp., `x` ). + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v = symbols('u, v') + U, V = symbols('U, V', cls=Function) + f = Wild('f') + g = Wild('g') + f1 = Wild('f1', exclude=[v,t]) + f2 = Wild('f2', exclude=[v,t]) + g1 = Wild('g1', exclude=[u,t]) + g2 = Wild('g2', exclude=[u,t]) + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + num, den = ( + (r1[f].subs(x(t),u).subs(y(t),v))/ + (r2[g].subs(x(t),u).subs(y(t),v))).as_numer_denom() + R1 = num.match(f1*g1) + R2 = den.match(f2*g2) + phi = (r1[f].subs(x(t),u).subs(y(t),v))/num + F1 = R1[f1]; F2 = R2[f2] + G1 = R1[g1]; G2 = R2[g2] + sol1r = solve(Integral(F2/F1, u).doit() - Integral(G1/G2,v).doit() - C1, u) + sol2r = solve(Integral(F2/F1, u).doit() - Integral(G1/G2,v).doit() - C1, v) + sol = [] + for sols in sol1r: + sol.append(Eq(y(t), dsolve(diff(V(t),t) - F2.subs(u,sols).subs(v,V(t))*G2.subs(v,V(t))*phi.subs(u,sols).subs(v,V(t))).rhs)) + for sols in sol2r: + sol.append(Eq(x(t), dsolve(diff(U(t),t) - F1.subs(u,U(t))*G1.subs(v,sols).subs(u,U(t))*phi.subs(v,sols).subs(u,U(t))).rhs)) + return set(sol) + +def _nonlinear_2eq_order1_type5(func, t, eq): + r""" + Clairaut system of ODEs + + .. math:: x = t x' + F(x',y') + + .. math:: y = t y' + G(x',y') + + The following are solutions of the system + + `(i)` straight lines: + + .. math:: x = C_1 t + F(C_1, C_2), y = C_2 t + G(C_1, C_2) + + where `C_1` and `C_2` are arbitrary constants; + + `(ii)` envelopes of the above lines; + + `(iii)` continuously differentiable lines made up from segments of the lines + `(i)` and `(ii)`. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + f = Wild('f') + g = Wild('g') + def check_type(x, y): + r1 = eq[0].match(t*diff(x(t),t) - x(t) + f) + r2 = eq[1].match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = eq[0].match(diff(x(t),t) - x(t)/t + f/t) + r2 = eq[1].match(diff(y(t),t) - y(t)/t + g/t) + if not (r1 and r2): + r1 = (-eq[0]).match(t*diff(x(t),t) - x(t) + f) + r2 = (-eq[1]).match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = (-eq[0]).match(diff(x(t),t) - x(t)/t + f/t) + r2 = (-eq[1]).match(diff(y(t),t) - y(t)/t + g/t) + return [r1, r2] + for func_ in func: + if isinstance(func_, list): + x = func[0][0].func + y = func[0][1].func + [r1, r2] = check_type(x, y) + if not (r1 and r2): + [r1, r2] = check_type(y, x) + x, y = y, x + x1 = diff(x(t),t); y1 = diff(y(t),t) + return {Eq(x(t), C1*t + r1[f].subs(x1,C1).subs(y1,C2)), Eq(y(t), C2*t + r2[g].subs(x1,C1).subs(y1,C2))} + +def sysode_nonlinear_3eq_order1(match_): + x = match_['func'][0].func + y = match_['func'][1].func + z = match_['func'][2].func + eq = match_['eq'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + if match_['type_of_equation'] == 'type1': + sol = _nonlinear_3eq_order1_type1(x, y, z, t, eq) + if match_['type_of_equation'] == 'type2': + sol = _nonlinear_3eq_order1_type2(x, y, z, t, eq) + if match_['type_of_equation'] == 'type3': + sol = _nonlinear_3eq_order1_type3(x, y, z, t, eq) + if match_['type_of_equation'] == 'type4': + sol = _nonlinear_3eq_order1_type4(x, y, z, t, eq) + if match_['type_of_equation'] == 'type5': + sol = _nonlinear_3eq_order1_type5(x, y, z, t, eq) + return sol + +def _nonlinear_3eq_order1_type1(x, y, z, t, eq): + r""" + Equations: + + .. math:: a x' = (b - c) y z, \enspace b y' = (c - a) z x, \enspace c z' = (a - b) x y + + First Integrals: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + .. math:: a^{2} x^{2} + b^{2} y^{2} + c^{2} z^{2} = C_2 + + where `C_1` and `C_2` are arbitrary constants. On solving the integrals for `y` and + `z` and on substituting the resulting expressions into the first equation of the + system, we arrives at a separable first-order equation on `x`. Similarly doing that + for other two equations, we will arrive at first order equation on `y` and `z` too. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0401.pdf + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + r = (diff(x(t),t) - eq[0]).match(p*y(t)*z(t)) + r.update((diff(y(t),t) - eq[1]).match(q*z(t)*x(t))) + r.update((diff(z(t),t) - eq[2]).match(s*x(t)*y(t))) + n1, d1 = r[p].as_numer_denom() + n2, d2 = r[q].as_numer_denom() + n3, d3 = r[s].as_numer_denom() + val = solve([n1*u-d1*v+d1*w, d2*u+n2*v-d2*w, d3*u-d3*v-n3*w],[u,v]) + vals = [val[v], val[u]] + c = lcm(vals[0].as_numer_denom()[1], vals[1].as_numer_denom()[1]) + b = vals[0].subs(w, c) + a = vals[1].subs(w, c) + y_x = sqrt(((c*C1-C2) - a*(c-a)*x(t)**2)/(b*(c-b))) + z_x = sqrt(((b*C1-C2) - a*(b-a)*x(t)**2)/(c*(b-c))) + z_y = sqrt(((a*C1-C2) - b*(a-b)*y(t)**2)/(c*(a-c))) + x_y = sqrt(((c*C1-C2) - b*(c-b)*y(t)**2)/(a*(c-a))) + x_z = sqrt(((b*C1-C2) - c*(b-c)*z(t)**2)/(a*(b-a))) + y_z = sqrt(((a*C1-C2) - c*(a-c)*z(t)**2)/(b*(a-b))) + sol1 = dsolve(a*diff(x(t),t) - (b-c)*y_x*z_x) + sol2 = dsolve(b*diff(y(t),t) - (c-a)*z_y*x_y) + sol3 = dsolve(c*diff(z(t),t) - (a-b)*x_z*y_z) + return [sol1, sol2, sol3] + + +def _nonlinear_3eq_order1_type2(x, y, z, t, eq): + r""" + Equations: + + .. math:: a x' = (b - c) y z f(x, y, z, t) + + .. math:: b y' = (c - a) z x f(x, y, z, t) + + .. math:: c z' = (a - b) x y f(x, y, z, t) + + First Integrals: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + .. math:: a^{2} x^{2} + b^{2} y^{2} + c^{2} z^{2} = C_2 + + where `C_1` and `C_2` are arbitrary constants. On solving the integrals for `y` and + `z` and on substituting the resulting expressions into the first equation of the + system, we arrives at a first-order differential equations on `x`. Similarly doing + that for other two equations we will arrive at first order equation on `y` and `z`. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0402.pdf + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + f = Wild('f') + r1 = (diff(x(t),t) - eq[0]).match(y(t)*z(t)*f) + r = collect_const(r1[f]).match(p*f) + r.update(((diff(y(t),t) - eq[1])/r[f]).match(q*z(t)*x(t))) + r.update(((diff(z(t),t) - eq[2])/r[f]).match(s*x(t)*y(t))) + n1, d1 = r[p].as_numer_denom() + n2, d2 = r[q].as_numer_denom() + n3, d3 = r[s].as_numer_denom() + val = solve([n1*u-d1*v+d1*w, d2*u+n2*v-d2*w, -d3*u+d3*v+n3*w],[u,v]) + vals = [val[v], val[u]] + c = lcm(vals[0].as_numer_denom()[1], vals[1].as_numer_denom()[1]) + a = vals[0].subs(w, c) + b = vals[1].subs(w, c) + y_x = sqrt(((c*C1-C2) - a*(c-a)*x(t)**2)/(b*(c-b))) + z_x = sqrt(((b*C1-C2) - a*(b-a)*x(t)**2)/(c*(b-c))) + z_y = sqrt(((a*C1-C2) - b*(a-b)*y(t)**2)/(c*(a-c))) + x_y = sqrt(((c*C1-C2) - b*(c-b)*y(t)**2)/(a*(c-a))) + x_z = sqrt(((b*C1-C2) - c*(b-c)*z(t)**2)/(a*(b-a))) + y_z = sqrt(((a*C1-C2) - c*(a-c)*z(t)**2)/(b*(a-b))) + sol1 = dsolve(a*diff(x(t),t) - (b-c)*y_x*z_x*r[f]) + sol2 = dsolve(b*diff(y(t),t) - (c-a)*z_y*x_y*r[f]) + sol3 = dsolve(c*diff(z(t),t) - (a-b)*x_z*y_z*r[f]) + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type3(x, y, z, t, eq): + r""" + Equations: + + .. math:: x' = c F_2 - b F_3, \enspace y' = a F_3 - c F_1, \enspace z' = b F_1 - a F_2 + + where `F_n = F_n(x, y, z, t)`. + + 1. First Integral: + + .. math:: a x + b y + c z = C_1, + + where C is an arbitrary constant. + + 2. If we assume function `F_n` to be independent of `t`,i.e, `F_n` = `F_n (x, y, z)` + Then, on eliminating `t` and `z` from the first two equation of the system, one + arrives at the first-order equation + + .. math:: \frac{dy}{dx} = \frac{a F_3 (x, y, z) - c F_1 (x, y, z)}{c F_2 (x, y, z) - + b F_3 (x, y, z)} + + where `z = \frac{1}{c} (C_1 - a x - b y)` + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0404.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + fu, fv, fw = symbols('u, v, w', cls=Function) + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = (diff(x(t), t) - eq[0]).match(F2-F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t), t) - eq[1]).match(p*r[F3] - r[s]*F1)) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t), u).subs(y(t),v).subs(z(t), w) + F2 = r[F2].subs(x(t), u).subs(y(t),v).subs(z(t), w) + F3 = r[F3].subs(x(t), u).subs(y(t),v).subs(z(t), w) + z_xy = (C1-a*u-b*v)/c + y_zx = (C1-a*u-c*w)/b + x_yz = (C1-b*v-c*w)/a + y_x = dsolve(diff(fv(u),u) - ((a*F3-c*F1)/(c*F2-b*F3)).subs(w,z_xy).subs(v,fv(u))).rhs + z_x = dsolve(diff(fw(u),u) - ((b*F1-a*F2)/(c*F2-b*F3)).subs(v,y_zx).subs(w,fw(u))).rhs + z_y = dsolve(diff(fw(v),v) - ((b*F1-a*F2)/(a*F3-c*F1)).subs(u,x_yz).subs(w,fw(v))).rhs + x_y = dsolve(diff(fu(v),v) - ((c*F2-b*F3)/(a*F3-c*F1)).subs(w,z_xy).subs(u,fu(v))).rhs + y_z = dsolve(diff(fv(w),w) - ((a*F3-c*F1)/(b*F1-a*F2)).subs(u,x_yz).subs(v,fv(w))).rhs + x_z = dsolve(diff(fu(w),w) - ((c*F2-b*F3)/(b*F1-a*F2)).subs(v,y_zx).subs(u,fu(w))).rhs + sol1 = dsolve(diff(fu(t),t) - (c*F2 - b*F3).subs(v,y_x).subs(w,z_x).subs(u,fu(t))).rhs + sol2 = dsolve(diff(fv(t),t) - (a*F3 - c*F1).subs(u,x_y).subs(w,z_y).subs(v,fv(t))).rhs + sol3 = dsolve(diff(fw(t),t) - (b*F1 - a*F2).subs(u,x_z).subs(v,y_z).subs(w,fw(t))).rhs + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type4(x, y, z, t, eq): + r""" + Equations: + + .. math:: x' = c z F_2 - b y F_3, \enspace y' = a x F_3 - c z F_1, \enspace z' = b y F_1 - a x F_2 + + where `F_n = F_n (x, y, z, t)` + + 1. First integral: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + where `C` is an arbitrary constant. + + 2. Assuming the function `F_n` is independent of `t`: `F_n = F_n (x, y, z)`. Then on + eliminating `t` and `z` from the first two equations of the system, one arrives at + the first-order equation + + .. math:: \frac{dy}{dx} = \frac{a x F_3 (x, y, z) - c z F_1 (x, y, z)} + {c z F_2 (x, y, z) - b y F_3 (x, y, z)} + + where `z = \pm \sqrt{\frac{1}{c} (C_1 - a x^{2} - b y^{2})}` + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0405.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = eq[0].match(diff(x(t),t) - z(t)*F2 + y(t)*F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t),t) - eq[1]).match(p*x(t)*r[F3] - r[s]*z(t)*F1)) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t),u).subs(y(t),v).subs(z(t),w) + F2 = r[F2].subs(x(t),u).subs(y(t),v).subs(z(t),w) + F3 = r[F3].subs(x(t),u).subs(y(t),v).subs(z(t),w) + x_yz = sqrt((C1 - b*v**2 - c*w**2)/a) + y_zx = sqrt((C1 - c*w**2 - a*u**2)/b) + z_xy = sqrt((C1 - a*u**2 - b*v**2)/c) + y_x = dsolve(diff(v(u),u) - ((a*u*F3-c*w*F1)/(c*w*F2-b*v*F3)).subs(w,z_xy).subs(v,v(u))).rhs + z_x = dsolve(diff(w(u),u) - ((b*v*F1-a*u*F2)/(c*w*F2-b*v*F3)).subs(v,y_zx).subs(w,w(u))).rhs + z_y = dsolve(diff(w(v),v) - ((b*v*F1-a*u*F2)/(a*u*F3-c*w*F1)).subs(u,x_yz).subs(w,w(v))).rhs + x_y = dsolve(diff(u(v),v) - ((c*w*F2-b*v*F3)/(a*u*F3-c*w*F1)).subs(w,z_xy).subs(u,u(v))).rhs + y_z = dsolve(diff(v(w),w) - ((a*u*F3-c*w*F1)/(b*v*F1-a*u*F2)).subs(u,x_yz).subs(v,v(w))).rhs + x_z = dsolve(diff(u(w),w) - ((c*w*F2-b*v*F3)/(b*v*F1-a*u*F2)).subs(v,y_zx).subs(u,u(w))).rhs + sol1 = dsolve(diff(u(t),t) - (c*w*F2 - b*v*F3).subs(v,y_x).subs(w,z_x).subs(u,u(t))).rhs + sol2 = dsolve(diff(v(t),t) - (a*u*F3 - c*w*F1).subs(u,x_y).subs(w,z_y).subs(v,v(t))).rhs + sol3 = dsolve(diff(w(t),t) - (b*v*F1 - a*u*F2).subs(u,x_z).subs(v,y_z).subs(w,w(t))).rhs + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type5(x, y, z, t, eq): + r""" + .. math:: x' = x (c F_2 - b F_3), \enspace y' = y (a F_3 - c F_1), \enspace z' = z (b F_1 - a F_2) + + where `F_n = F_n (x, y, z, t)` and are arbitrary functions. + + First Integral: + + .. math:: \left|x\right|^{a} \left|y\right|^{b} \left|z\right|^{c} = C_1 + + where `C` is an arbitrary constant. If the function `F_n` is independent of `t`, + then, by eliminating `t` and `z` from the first two equations of the system, one + arrives at a first-order equation. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0406.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + fu, fv, fw = symbols('u, v, w', cls=Function) + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = eq[0].match(diff(x(t), t) - x(t)*F2 + x(t)*F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t), t) - eq[1]).match(y(t)*(p*r[F3] - r[s]*F1))) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t), u).subs(y(t), v).subs(z(t), w) + F2 = r[F2].subs(x(t), u).subs(y(t), v).subs(z(t), w) + F3 = r[F3].subs(x(t), u).subs(y(t), v).subs(z(t), w) + x_yz = (C1*v**-b*w**-c)**-a + y_zx = (C1*w**-c*u**-a)**-b + z_xy = (C1*u**-a*v**-b)**-c + y_x = dsolve(diff(fv(u), u) - ((v*(a*F3 - c*F1))/(u*(c*F2 - b*F3))).subs(w, z_xy).subs(v, fv(u))).rhs + z_x = dsolve(diff(fw(u), u) - ((w*(b*F1 - a*F2))/(u*(c*F2 - b*F3))).subs(v, y_zx).subs(w, fw(u))).rhs + z_y = dsolve(diff(fw(v), v) - ((w*(b*F1 - a*F2))/(v*(a*F3 - c*F1))).subs(u, x_yz).subs(w, fw(v))).rhs + x_y = dsolve(diff(fu(v), v) - ((u*(c*F2 - b*F3))/(v*(a*F3 - c*F1))).subs(w, z_xy).subs(u, fu(v))).rhs + y_z = dsolve(diff(fv(w), w) - ((v*(a*F3 - c*F1))/(w*(b*F1 - a*F2))).subs(u, x_yz).subs(v, fv(w))).rhs + x_z = dsolve(diff(fu(w), w) - ((u*(c*F2 - b*F3))/(w*(b*F1 - a*F2))).subs(v, y_zx).subs(u, fu(w))).rhs + sol1 = dsolve(diff(fu(t), t) - (u*(c*F2 - b*F3)).subs(v, y_x).subs(w, z_x).subs(u, fu(t))).rhs + sol2 = dsolve(diff(fv(t), t) - (v*(a*F3 - c*F1)).subs(u, x_y).subs(w, z_y).subs(v, fv(t))).rhs + sol3 = dsolve(diff(fw(t), t) - (w*(b*F1 - a*F2)).subs(u, x_z).subs(v, y_z).subs(w, fw(t))).rhs + return [sol1, sol2, sol3] + + +#This import is written at the bottom to avoid circular imports. +from .single import SingleODEProblem, SingleODESolver, solver_map diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/riccati.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/riccati.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef66ed0896d39bee8fba1b74a0c93734742fc1f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/riccati.py @@ -0,0 +1,893 @@ +r""" +This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`, +a function which gives all rational particular solutions to first order +Riccati ODEs. A general first order Riccati ODE is given by - + +.. math:: y' = b_0(x) + b_1(x)w + b_2(x)w^2 + +where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x` +with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE +anymore and becomes a Linear ODE. Similarly, when `b_0 = 0`, the equation +is a Bernoulli ODE. The algorithm presented below can find rational +solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution, +or prove that no rational solution exists for the equation. + +Background +========== + +A Riccati equation can be transformed to its normal form + +.. math:: y' + y^2 = a(x) + +using the transformation + +.. math:: y = -b_2(x) - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2} + +where `a(x)` is given by + +.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2 + +Thus, we can develop an algorithm to solve for the Riccati equation +in its normal form, which would in turn give us the solution for +the original Riccati equation. + +Algorithm +========= + +The algorithm implemented here is presented in the Ph.D thesis +"Rational and Algebraic Solutions of First-Order Algebraic ODEs" +by N. Thieu Vo. The entire thesis can be found here - +https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf + +We have only implemented the Rational Riccati solver (Algorithm 11, +Pg 78-82 in Thesis). Before we proceed towards the implementation +of the algorithm, a few definitions to understand are - + +1. Valuation of a Rational Function at `\infty`: + The valuation of a rational function `p(x)` at `\infty` is equal + to the difference between the degree of the denominator and the + numerator of `p(x)`. + + NOTE: A general definition of valuation of a rational function + at any value of `x` can be found in Pg 63 of the thesis, but + is not of any interest for this algorithm. + +2. Zeros and Poles of a Rational Function: + Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function + of `x`. Then - + + a. The Zeros of `a(x)` are the roots of `S(x)`. + b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty` + can also be a pole of a(x). We say that `a(x)` has a pole at + `\infty` if `a(\frac{1}{x})` has a pole at 0. + +Every pole is associated with an order that is equal to the multiplicity +of its appearance as a root of `T(x)`. A pole is called a simple pole if +it has an order 1. Similarly, a pole is called a multiple pole if it has +an order `\ge` 2. + +Necessary Conditions +==================== + +For a Riccati equation in its normal form, + +.. math:: y' + y^2 = a(x) + +we can define + +a. A pole is called a movable pole if it is a pole of `y(x)` and is not +a pole of `a(x)`. +b. Similarly, a pole is called a non-movable pole if it is a pole of both +`y(x)` and `a(x)`. + +Then, the algorithm states that a rational solution exists only if - + +a. Every pole of `a(x)` must be either a simple pole or a multiple pole +of even order. +b. The valuation of `a(x)` at `\infty` must be even or be `\ge` 2. + +This algorithm finds all possible rational solutions for the Riccati ODE. +If no rational solutions are found, it means that no rational solutions +exist. + +The algorithm works for Riccati ODEs where the coefficients are rational +functions in the independent variable `x` with rational number coefficients +i.e. in `Q(x)`. The coefficients in the rational function cannot be floats, +irrational numbers, symbols or any other kind of expression. The reasons +for this are - + +1. When using symbols, different symbols could take the same value and this +would affect the multiplicity of poles if symbols are present here. + +2. An integer degree bound is required to calculate a polynomial solution +to an auxiliary differential equation, which in turn gives the particular +solution for the original ODE. If symbols/floats/irrational numbers are +present, we cannot determine if the expression for the degree bound is an +integer or not. + +Solution +======== + +With these definitions, we can state a general form for the solution of +the equation. `y(x)` must have the form - + +.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i + +where `x_1, x_2, \dots, x_n` are non-movable poles of `a(x)`, +`\chi_1, \chi_2, \dots, \chi_m` are movable poles of `a(x)`, and the values +of `N, n, r_1, r_2, \dots, r_n` can be determined from `a(x)`. The +coefficient vectors `(d_0, d_1, \dots, d_N)` and `(c_{i1}, c_{i2}, \dots, c_{i r_i})` +can be determined from `a(x)`. We will have 2 choices each of these vectors +and part of the procedure is figuring out which of the 2 should be used +to get the solution correctly. + +Implementation +============== + +In this implementation, we use ``Poly`` to represent a rational function +rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot +represent rational functions directly using ``Poly``, we instead represent +a rational function with 2 ``Poly`` objects - one for its numerator and +the other for its denominator. + +The code is written to match the steps given in the thesis (Pg 82) + +Step 0 : Match the equation - +Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise +an error + +Step 1 : Transform the equation to its normal form as explained in the +theory section. + +Step 2 : Initialize an empty set of solutions, ``sol``. + +Step 3 : If `a(x) = 0`, append `\frac{1}/{(x - C1)}` to ``sol``. + +Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}` +to ``sol``. + +Step 5 : Find the poles and their multiplicities of `a(x)`. Let +the number of poles be `n`. Also find the valuation of `a(x)` at +`\infty` using ``val_at_inf``. + +NOTE: Although the algorithm considers `\infty` as a pole, it is +not mentioned if it a part of the set of finite poles. `\infty` +is NOT a part of the set of finite poles. If a pole exists at +`\infty`, we use its multiplicity to find the laurent series of +`a(x)` about `\infty`. + +Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using +``construct_c`` and ``construct_d``. Now, determine all the ``2**(n + 1)`` +combinations of choosing between 2 choices for each of the `n` c-vectors +and 1 d-vector. + +NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printinig +mistake. The term `- d_N` must be replaced with `-N d_N`. The same +has been explained in the code as well. + +For each of these above combinations, do + +Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of +the polynomial solution we must find for the auxiliary equation. + +Step 9 : In ``compute_m_ybar``, compute ybar as well where ``ybar`` is +one part of y(x) - + +.. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i + +Step 10 : If `m` is a non-negative integer - + +Step 11: Find a polynomial solution of degree `m` for the auxiliary equation. + +There are 2 cases possible - + + a. `m` is a non-negative integer: We can solve for the coefficients + in `p(x)` using Undetermined Coefficients. + + b. `m` is not a non-negative integer: In this case, we cannot find + a polynomial solution to the auxiliary equation, and hence, we ignore + this value of `m`. + +Step 12 : For each `p(x)` that exists, append `ybar + \frac{p'(x)}{p(x)}` +to ``sol``. + +Step 13 : For each solution in ``sol``, apply an inverse transformation, +so that the solutions of the original equation are found using the +solutions of the equation in its normal form. +""" + + +from itertools import product +from sympy.core import S +from sympy.core.add import Add +from sympy.core.numbers import oo, Float +from sympy.core.function import count_ops +from sympy.core.relational import Eq +from sympy.core.symbol import symbols, Symbol, Dummy +from sympy.functions import sqrt, exp +from sympy.functions.elementary.complexes import sign +from sympy.integrals.integrals import Integral +from sympy.polys.domains import ZZ +from sympy.polys.polytools import Poly +from sympy.polys.polyroots import roots +from sympy.solvers.solveset import linsolve + + +def riccati_normal(w, x, b1, b2): + """ + Given a solution `w(x)` to the equation + + .. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2 + + and rational function coefficients `b_1(x)` and + `b_2(x)`, this function transforms the solution to + give a solution `y(x)` for its corresponding normal + Riccati ODE + + .. math:: y'(x) + y(x)^2 = a(x) + + using the transformation + + .. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2 + """ + return -b2*w - b2.diff(x)/(2*b2) - b1/2 + + +def riccati_inverse_normal(y, x, b1, b2, bp=None): + """ + Inverse transforming the solution to the normal + Riccati ODE to get the solution to the Riccati ODE. + """ + # bp is the expression which is independent of the solution + # and hence, it need not be computed again + if bp is None: + bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2) + # w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x)) + return -y/b2 + bp + + +def riccati_reduced(eq, f, x): + """ + Convert a Riccati ODE into its corresponding + normal Riccati ODE. + """ + match, funcs = match_riccati(eq, f, x) + # If equation is not a Riccati ODE, exit + if not match: + return False + # Using the rational functions, find the expression for a(x) + b0, b1, b2 = funcs + a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \ + b2.diff(x, 2)/(2*b2) + # Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x) + return f(x).diff(x) + f(x)**2 - a + +def linsolve_dict(eq, syms): + """ + Get the output of linsolve as a dict + """ + # Convert tuple type return value of linsolve + # to a dictionary for ease of use + sol = linsolve(eq, syms) + if not sol: + return {} + return dict(zip(syms, list(sol)[0])) + + +def match_riccati(eq, f, x): + """ + A function that matches and returns the coefficients + if an equation is a Riccati ODE + + Parameters + ========== + + eq: Equation to be matched + f: Dependent variable + x: Independent variable + + Returns + ======= + + match: True if equation is a Riccati ODE, False otherwise + funcs: [b0, b1, b2] if match is True, [] otherwise. Here, + b0, b1 and b2 are rational functions which match the equation. + """ + # Group terms based on f(x) + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + eq = eq.expand().collect(f(x)) + cf = eq.coeff(f(x).diff(x)) + + # There must be an f(x).diff(x) term. + # eq must be an Add object since we are using the expanded + # equation and it must have atleast 2 terms (b2 != 0) + if cf != 0 and isinstance(eq, Add): + + # Divide all coefficients by the coefficient of f(x).diff(x) + # and add the terms again to get the same equation + eq = Add(*((x/cf).cancel() for x in eq.args)).collect(f(x)) + + # Match the equation with the pattern + b1 = -eq.coeff(f(x)) + b2 = -eq.coeff(f(x)**2) + b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand() + funcs = [b0, b1, b2] + + # Check if coefficients are not symbols and floats + if any(len(x.atoms(Symbol)) > 1 or len(x.atoms(Float)) for x in funcs): + return False, [] + + # If b_0(x) contains f(x), it is not a Riccati ODE + if len(b0.atoms(f)) or not all((b2 != 0, b0.is_rational_function(x), + b1.is_rational_function(x), b2.is_rational_function(x))): + return False, [] + return True, funcs + return False, [] + + +def val_at_inf(num, den, x): + # Valuation of a rational function at oo = deg(denom) - deg(numer) + return den.degree(x) - num.degree(x) + + +def check_necessary_conds(val_inf, muls): + """ + The necessary conditions for a rational solution + to exist are as follows - + + i) Every pole of a(x) must be either a simple pole + or a multiple pole of even order. + + ii) The valuation of a(x) at infinity must be even + or be greater than or equal to 2. + + Here, a simple pole is a pole with multiplicity 1 + and a multiple pole is a pole with multiplicity + greater than 1. + """ + return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \ + all(mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls) + + +def inverse_transform_poly(num, den, x): + """ + A function to make the substitution + x -> 1/x in a rational function that + is represented using Poly objects for + numerator and denominator. + """ + # Declare for reuse + one = Poly(1, x) + xpoly = Poly(x, x) + + # Check if degree of numerator is same as denominator + pwr = val_at_inf(num, den, x) + if pwr >= 0: + # Denominator has greater degree. Substituting x with + # 1/x would make the extra power go to the numerator + if num.expr != 0: + num = num.transform(one, xpoly) * x**pwr + den = den.transform(one, xpoly) + else: + # Numerator has greater degree. Substituting x with + # 1/x would make the extra power go to the denominator + num = num.transform(one, xpoly) + den = den.transform(one, xpoly) * x**(-pwr) + return num.cancel(den, include=True) + + +def limit_at_inf(num, den, x): + """ + Find the limit of a rational function + at oo + """ + # pwr = degree(num) - degree(den) + pwr = -val_at_inf(num, den, x) + # Numerator has a greater degree than denominator + # Limit at infinity would depend on the sign of the + # leading coefficients of numerator and denominator + if pwr > 0: + return oo*sign(num.LC()/den.LC()) + # Degree of numerator is equal to that of denominator + # Limit at infinity is just the ratio of leading coeffs + elif pwr == 0: + return num.LC()/den.LC() + # Degree of numerator is less than that of denominator + # Limit at infinity is just 0 + else: + return 0 + + +def construct_c_case_1(num, den, x, pole): + # Find the coefficient of 1/(x - pole)**2 in the + # Laurent series expansion of a(x) about pole. + num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True) + r = (num1.subs(x, pole))/(den1.subs(x, pole)) + + # If multiplicity is 2, the coefficient to be added + # in the c-vector is c = (1 +- sqrt(1 + 4*r))/2 + if r != -S(1)/4: + return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]] + return [[S.Half]] + + +def construct_c_case_2(num, den, x, pole, mul): + # Generate the coefficients using the recurrence + # relation mentioned in (5.14) in the thesis (Pg 80) + + # r_i = mul/2 + ri = mul//2 + + # Find the Laurent series coefficients about the pole + ser = rational_laurent_series(num, den, x, pole, mul, 6) + + # Start with an empty memo to store the coefficients + # This is for the plus case + cplus = [0 for i in range(ri)] + + # Base Case + cplus[ri-1] = sqrt(ser[2*ri]) + + # Iterate backwards to find all coefficients + s = ri - 1 + sm = 0 + for s in range(ri-1, 0, -1): + sm = 0 + for j in range(s+1, ri): + sm += cplus[j-1]*cplus[ri+s-j-1] + if s!= 1: + cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1]) + + # Memo for the minus case + cminus = [-x for x in cplus] + + # Find the 0th coefficient in the recurrence + cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1]) + cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1]) + + # Add both the plus and minus cases' coefficients + if cplus != cminus: + return [cplus, cminus] + return cplus + + +def construct_c_case_3(): + # If multiplicity is 1, the coefficient to be added + # in the c-vector is 1 (no choice) + return [[1]] + + +def construct_c(num, den, x, poles, muls): + """ + Helper function to calculate the coefficients + in the c-vector for each pole. + """ + c = [] + for pole, mul in zip(poles, muls): + c.append([]) + + # Case 3 + if mul == 1: + # Add the coefficients from Case 3 + c[-1].extend(construct_c_case_3()) + + # Case 1 + elif mul == 2: + # Add the coefficients from Case 1 + c[-1].extend(construct_c_case_1(num, den, x, pole)) + + # Case 2 + else: + # Add the coefficients from Case 2 + c[-1].extend(construct_c_case_2(num, den, x, pole, mul)) + + return c + + +def construct_d_case_4(ser, N): + # Initialize an empty vector + dplus = [0 for i in range(N+2)] + # d_N = sqrt(a_{2*N}) + dplus[N] = sqrt(ser[2*N]) + + # Use the recurrence relations to find + # the value of d_s + for s in range(N-1, -2, -1): + sm = 0 + for j in range(s+1, N): + sm += dplus[j]*dplus[N+s-j] + if s != -1: + dplus[s] = (ser[N+s] - sm)/(2*dplus[N]) + + # Coefficients for the case of d_N = -sqrt(a_{2*N}) + dminus = [-x for x in dplus] + + # The third equation in Eq 5.15 of the thesis is WRONG! + # d_N must be replaced with N*d_N in that equation. + dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N]) + dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N]) + + if dplus != dminus: + return [dplus, dminus] + return dplus + + +def construct_d_case_5(ser): + # List to store coefficients for plus case + dplus = [0, 0] + + # d_0 = sqrt(a_0) + dplus[0] = sqrt(ser[0]) + + # d_(-1) = a_(-1)/(2*d_0) + dplus[-1] = ser[-1]/(2*dplus[0]) + + # Coefficients for the minus case are just the negative + # of the coefficients for the positive case. + dminus = [-x for x in dplus] + + if dplus != dminus: + return [dplus, dminus] + return dplus + + +def construct_d_case_6(num, den, x): + # s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to + # s_oo = lim x->oo x**2 * a(x) + s_inf = limit_at_inf(Poly(x**2, x)*num, den, x) + + # d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2 + if s_inf != -S(1)/4: + return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]] + return [[S.Half]] + + +def construct_d(num, den, x, val_inf): + """ + Helper function to calculate the coefficients + in the d-vector based on the valuation of the + function at oo. + """ + N = -val_inf//2 + # Multiplicity of oo as a pole + mul = -val_inf if val_inf < 0 else 0 + ser = rational_laurent_series(num, den, x, oo, mul, 1) + + # Case 4 + if val_inf < 0: + d = construct_d_case_4(ser, N) + + # Case 5 + elif val_inf == 0: + d = construct_d_case_5(ser) + + # Case 6 + else: + d = construct_d_case_6(num, den, x) + + return d + + +def rational_laurent_series(num, den, x, r, m, n): + r""" + The function computes the Laurent series coefficients + of a rational function. + + Parameters + ========== + + num: A Poly object that is the numerator of `f(x)`. + den: A Poly object that is the denominator of `f(x)`. + x: The variable of expansion of the series. + r: The point of expansion of the series. + m: Multiplicity of r if r is a pole of `f(x)`. Should + be zero otherwise. + n: Order of the term upto which the series is expanded. + + Returns + ======= + + series: A dictionary that has power of the term as key + and coefficient of that term as value. + + Below is a basic outline of how the Laurent series of a + rational function `f(x)` about `x_0` is being calculated - + + 1. Substitute `x + x_0` in place of `x`. If `x_0` + is a pole of `f(x)`, multiply the expression by `x^m` + where `m` is the multiplicity of `x_0`. Denote the + the resulting expression as g(x). We do this substitution + so that we can now find the Laurent series of g(x) about + `x = 0`. + + 2. We can then assume that the Laurent series of `g(x)` + takes the following form - + + .. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{m = 0}^{\infty} a_m x^m + + where `a_m` denotes the Laurent series coefficients. + + 3. Multiply the denominator to the RHS of the equation + and form a recurrence relation for the coefficients `a_m`. + """ + one = Poly(1, x, extension=True) + + if r == oo: + # Series at x = oo is equal to first transforming + # the function from x -> 1/x and finding the + # series at x = 0 + num, den = inverse_transform_poly(num, den, x) + r = S(0) + + if r: + # For an expansion about a non-zero point, a + # transformation from x -> x + r must be made + num = num.transform(Poly(x + r, x, extension=True), one) + den = den.transform(Poly(x + r, x, extension=True), one) + + # Remove the pole from the denominator if the series + # expansion is about one of the poles + num, den = (num*x**m).cancel(den, include=True) + + # Equate coefficients for the first terms (base case) + maxdegree = 1 + max(num.degree(), den.degree()) + syms = symbols(f'a:{maxdegree}', cls=Dummy) + diff = num - den * Poly(syms[::-1], x) + coeff_diffs = diff.all_coeffs()[::-1][:maxdegree] + (coeffs, ) = linsolve(coeff_diffs, syms) + + # Use the recursion relation for the rest + recursion = den.all_coeffs()[::-1] + div, rec_rhs = recursion[0], recursion[1:] + series = list(coeffs) + while len(series) < n: + next_coeff = Add(*(c*series[-1-n] for n, c in enumerate(rec_rhs))) / div + series.append(-next_coeff) + series = {m - i: val for i, val in enumerate(series)} + return series + +def compute_m_ybar(x, poles, choice, N): + """ + Helper function to calculate - + + 1. m - The degree bound for the polynomial + solution that must be found for the auxiliary + differential equation. + + 2. ybar - Part of the solution which can be + computed using the poles, c and d vectors. + """ + ybar = 0 + m = Poly(choice[-1][-1], x, extension=True) + + # Calculate the first (nested) summation for ybar + # as given in Step 9 of the Thesis (Pg 82) + dybar = [] + for i, polei in enumerate(poles): + for j, cij in enumerate(choice[i]): + dybar.append(cij/(x - polei)**(j + 1)) + m -=Poly(choice[i][0], x, extension=True) # can't accumulate Poly and use with Add + ybar += Add(*dybar) + + # Calculate the second summation for ybar + for i in range(N+1): + ybar += choice[-1][i]*x**i + return (m.expr, ybar) + + +def solve_aux_eq(numa, dena, numy, deny, x, m): + """ + Helper function to find a polynomial solution + of degree m for the auxiliary differential + equation. + """ + # Assume that the solution is of the type + # p(x) = C_0 + C_1*x + ... + C_{m-1}*x**(m-1) + x**m + psyms = symbols(f'C0:{m}', cls=Dummy) + K = ZZ[psyms] + psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K) + + # Eq (5.16) in Thesis - Pg 81 + auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol + if m >= 1: + px = psol.diff(x) + auxeq += px*(2*numy*deny*dena) + if m >= 2: + auxeq += px.diff(x)*(deny**2*dena) + if m != 0: + # m is a non-zero integer. Find the constant terms using undetermined coefficients + return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True + else: + # m == 0 . Check if 1 (x**0) is a solution to the auxiliary equation + return S.One, auxeq, auxeq == 0 + + +def remove_redundant_sols(sol1, sol2, x): + """ + Helper function to remove redundant + solutions to the differential equation. + """ + # If y1 and y2 are redundant solutions, there is + # some value of the arbitrary constant for which + # they will be equal + + syms1 = sol1.atoms(Symbol, Dummy) + syms2 = sol2.atoms(Symbol, Dummy) + num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()] + num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()] + # Cross multiply + e = num1*den2 - den1*num2 + # Check if there are any constants + syms = list(e.atoms(Symbol, Dummy)) + if len(syms): + # Find values of constants for which solutions are equal + redn = linsolve(e.all_coeffs(), syms) + if len(redn): + # Return the general solution over a particular solution + if len(syms1) > len(syms2): + return sol2 + # If both have constants, return the lesser complex solution + elif len(syms1) == len(syms2): + return sol1 if count_ops(syms1) >= count_ops(syms2) else sol2 + else: + return sol1 + + +def get_gen_sol_from_part_sol(part_sols, a, x): + """" + Helper function which computes the general + solution for a Riccati ODE from its particular + solutions. + + There are 3 cases to find the general solution + from the particular solutions for a Riccati ODE + depending on the number of particular solution(s) + we have - 1, 2 or 3. + + For more information, see Section 6 of + "Methods of Solution of the Riccati Differential Equation" + by D. R. Haaheim and F. M. Stein + """ + + # If no particular solutions are found, a general + # solution cannot be found + if len(part_sols) == 0: + return [] + + # In case of a single particular solution, the general + # solution can be found by using the substitution + # y = y1 + 1/z and solving a Bernoulli ODE to find z. + elif len(part_sols) == 1: + y1 = part_sols[0] + i = exp(Integral(2*y1, x)) + z = i * Integral(a/i, x) + z = z.doit() + if a == 0 or z == 0: + return y1 + return y1 + 1/z + + # In case of 2 particular solutions, the general solution + # can be found by solving a separable equation. This is + # the most common case, i.e. most Riccati ODEs have 2 + # rational particular solutions. + elif len(part_sols) == 2: + y1, y2 = part_sols + # One of them already has a constant + if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0: + u = exp(Integral(y2 - y1, x)).doit() + # Introduce a constant + else: + C1 = Dummy('C1') + u = C1*exp(Integral(y2 - y1, x)).doit() + if u == 1: + return y2 + return (y2*u - y1)/(u - 1) + + # In case of 3 particular solutions, a closed form + # of the general solution can be obtained directly + else: + y1, y2, y3 = part_sols[:3] + C1 = Dummy('C1') + return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3) + + +def solve_riccati(fx, x, b0, b1, b2, gensol=False): + """ + The main function that gives particular/general + solutions to Riccati ODEs that have atleast 1 + rational particular solution. + """ + # Step 1 : Convert to Normal Form + a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \ + b2.diff(x, 2)/(2*b2) + a_t = a.together() + num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()] + num, den = num.cancel(den, include=True) + + # Step 2 + presol = [] + + # Step 3 : a(x) is 0 + if num == 0: + presol.append(1/(x + Dummy('C1'))) + + # Step 4 : a(x) is a non-zero constant + elif x not in num.free_symbols.union(den.free_symbols): + presol.extend([sqrt(a), -sqrt(a)]) + + # Step 5 : Find poles and valuation at infinity + poles = roots(den, x) + poles, muls = list(poles.keys()), list(poles.values()) + val_inf = val_at_inf(num, den, x) + + if len(poles): + # Check necessary conditions (outlined in the module docstring) + if not check_necessary_conds(val_inf, muls): + raise ValueError("Rational Solution doesn't exist") + + # Step 6 + # Construct c-vectors for each singular point + c = construct_c(num, den, x, poles, muls) + + # Construct d vectors for each singular point + d = construct_d(num, den, x, val_inf) + + # Step 7 : Iterate over all possible combinations and return solutions + # For each possible combination, generate an array of 0's and 1's + # where 0 means pick 1st choice and 1 means pick the second choice. + + # NOTE: We could exit from the loop if we find 3 particular solutions, + # but it is not implemented here as - + # a. Finding 3 particular solutions is very rare. Most of the time, + # only 2 particular solutions are found. + # b. In case we exit after finding 3 particular solutions, it might + # happen that 1 or 2 of them are redundant solutions. So, instead of + # spending some more time in computing the particular solutions, + # we will end up computing the general solution from a single + # particular solution which is usually slower than computing the + # general solution from 2 or 3 particular solutions. + c.append(d) + choices = product(*c) + for choice in choices: + m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2) + numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()] + # Step 10 : Check if a valid solution exists. If yes, also check + # if m is a non-negative integer + if m.is_nonnegative == True and m.is_integer == True: + + # Step 11 : Find polynomial solutions of degree m for the auxiliary equation + psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m) + + # Step 12 : If valid polynomial solution exists, append solution. + if exists: + # m == 0 case + if psol == 1 and coeffs == 0: + # p(x) = 1, so p'(x)/p(x) term need not be added + presol.append(ybar) + # m is a positive integer and there are valid coefficients + elif len(coeffs): + # Substitute the valid coefficients to get p(x) + psol = psol.xreplace(coeffs) + # y(x) = ybar(x) + p'(x)/p(x) + presol.append(ybar + psol.diff(x)/psol) + + # Remove redundant solutions from the list of existing solutions + remove = set() + for i in range(len(presol)): + for j in range(i+1, len(presol)): + rem = remove_redundant_sols(presol[i], presol[j], x) + if rem is not None: + remove.add(rem) + sols = [x for x in presol if x not in remove] + + # Step 15 : Inverse transform the solutions of the equation in normal form + bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2) + + # If general solution is required, compute it from the particular solutions + if gensol: + sols = [get_gen_sol_from_part_sol(sols, a, x)] + + # Inverse transform the particular solutions + presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols] + return presol diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/single.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/single.py new file mode 100644 index 0000000000000000000000000000000000000000..c4829acf41293c2f6f20af3e9c45d37457802102 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/single.py @@ -0,0 +1,2977 @@ +# +# This is the module for ODE solver classes for single ODEs. +# + +from __future__ import annotations +from typing import ClassVar, Iterator + +from .riccati import match_riccati, solve_riccati +from sympy.core import Add, S, Pow, Rational +from sympy.core.cache import cached_property +from sympy.core.exprtools import factor_terms +from sympy.core.expr import Expr +from sympy.core.function import AppliedUndef, Derivative, diff, Function, expand, Subs, _mexpand +from sympy.core.numbers import zoo +from sympy.core.relational import Equality, Eq +from sympy.core.symbol import Symbol, Dummy, Wild +from sympy.core.mul import Mul +from sympy.functions import exp, tan, log, sqrt, besselj, bessely, cbrt, airyai, airybi +from sympy.integrals import Integral +from sympy.polys import Poly +from sympy.polys.polytools import cancel, factor, degree +from sympy.simplify import collect, simplify, separatevars, logcombine, posify # type: ignore +from sympy.simplify.radsimp import fraction +from sympy.utilities import numbered_symbols +from sympy.solvers.solvers import solve +from sympy.solvers.deutils import ode_order, _preprocess +from sympy.polys.matrices.linsolve import _lin_eq2dict +from sympy.polys.solvers import PolyNonlinearError +from .hypergeometric import equivalence_hypergeometric, match_2nd_2F1_hypergeometric, \ + get_sol_2F1_hypergeometric, match_2nd_hypergeometric +from .nonhomogeneous import _get_euler_characteristic_eq_sols, _get_const_characteristic_eq_sols, \ + _solve_undetermined_coefficients, _solve_variation_of_parameters, _test_term, _undetermined_coefficients_match, \ + _get_simplified_sol +from .lie_group import _ode_lie_group + + +class ODEMatchError(NotImplementedError): + """Raised if a SingleODESolver is asked to solve an ODE it does not match""" + pass + + +class SingleODEProblem: + """Represents an ordinary differential equation (ODE) + + This class is used internally in the by dsolve and related + functions/classes so that properties of an ODE can be computed + efficiently. + + Examples + ======== + + This class is used internally by dsolve. To instantiate an instance + directly first define an ODE problem: + + >>> from sympy import Function, Symbol + >>> x = Symbol('x') + >>> f = Function('f') + >>> eq = f(x).diff(x, 2) + + Now you can create a SingleODEProblem instance and query its properties: + + >>> from sympy.solvers.ode.single import SingleODEProblem + >>> problem = SingleODEProblem(f(x).diff(x), f(x), x) + >>> problem.eq + Derivative(f(x), x) + >>> problem.func + f(x) + >>> problem.sym + x + """ + + # Instance attributes: + eq: Expr + func: AppliedUndef + sym: Symbol + _order: int + _eq_expanded: Expr + _eq_preprocessed: Expr + _eq_high_order_free = None + + def __init__(self, eq, func, sym, prep=True, **kwargs): + assert isinstance(eq, Expr) + assert isinstance(func, AppliedUndef) + assert isinstance(sym, Symbol) + assert isinstance(prep, bool) + self.eq = eq + self.func = func + self.sym = sym + self.prep = prep + self.params = kwargs + + @cached_property + def order(self) -> int: + return ode_order(self.eq, self.func) + + @cached_property + def eq_preprocessed(self) -> Expr: + return self._get_eq_preprocessed() + + @cached_property + def eq_high_order_free(self) -> Expr: + a = Wild('a', exclude=[self.func]) + c1 = Wild('c1', exclude=[self.sym]) + # Precondition to try remove f(x) from highest order derivative + reduced_eq = None + if self.eq.is_Add: + deriv_coef = self.eq.coeff(self.func.diff(self.sym, self.order)) + if deriv_coef not in (1, 0): + r = deriv_coef.match(a*self.func**c1) + if r and r[c1]: + den = self.func**r[c1] + reduced_eq = Add(*[arg/den for arg in self.eq.args]) + if reduced_eq is None: + reduced_eq = expand(self.eq) + return reduced_eq + + @cached_property + def eq_expanded(self) -> Expr: + return expand(self.eq_preprocessed) + + def _get_eq_preprocessed(self) -> Expr: + if self.prep: + process_eq, process_func = _preprocess(self.eq, self.func) + if process_func != self.func: + raise ValueError + else: + process_eq = self.eq + return process_eq + + def get_numbered_constants(self, num=1, start=1, prefix='C') -> list[Symbol]: + """ + Returns a list of constants that do not occur + in eq already. + """ + ncs = self.iter_numbered_constants(start, prefix) + Cs = [next(ncs) for i in range(num)] + return Cs + + def iter_numbered_constants(self, start=1, prefix='C') -> Iterator[Symbol]: + """ + Returns an iterator of constants that do not occur + in eq already. + """ + atom_set = self.eq.free_symbols + func_set = self.eq.atoms(Function) + if func_set: + atom_set |= {Symbol(str(f.func)) for f in func_set} + return numbered_symbols(start=start, prefix=prefix, exclude=atom_set) + + @cached_property + def is_autonomous(self): + u = Dummy('u') + x = self.sym + syms = self.eq.subs(self.func, u).free_symbols + return x not in syms + + def get_linear_coefficients(self, eq, func, order): + r""" + Matches a differential equation to the linear form: + + .. math:: a_n(x) y^{(n)} + \cdots + a_1(x)y' + a_0(x) y + B(x) = 0 + + Returns a dict of order:coeff terms, where order is the order of the + derivative on each term, and coeff is the coefficient of that derivative. + The key ``-1`` holds the function `B(x)`. Returns ``None`` if the ODE is + not linear. This function assumes that ``func`` has already been checked + to be good. + + Examples + ======== + + >>> from sympy import Function, cos, sin + >>> from sympy.abc import x + >>> from sympy.solvers.ode.single import SingleODEProblem + >>> f = Function('f') + >>> eq = f(x).diff(x, 3) + 2*f(x).diff(x) + \ + ... x*f(x).diff(x, 2) + cos(x)*f(x).diff(x) + x - f(x) - \ + ... sin(x) + >>> obj = SingleODEProblem(eq, f(x), x) + >>> obj.get_linear_coefficients(eq, f(x), 3) + {-1: x - sin(x), 0: -1, 1: cos(x) + 2, 2: x, 3: 1} + >>> eq = f(x).diff(x, 3) + 2*f(x).diff(x) + \ + ... x*f(x).diff(x, 2) + cos(x)*f(x).diff(x) + x - f(x) - \ + ... sin(f(x)) + >>> obj = SingleODEProblem(eq, f(x), x) + >>> obj.get_linear_coefficients(eq, f(x), 3) == None + True + + """ + f = func.func + x = func.args[0] + symset = {Derivative(f(x), x, i) for i in range(order+1)} + try: + rhs, lhs_terms = _lin_eq2dict(eq, symset) + except PolyNonlinearError: + return None + + if rhs.has(func) or any(c.has(func) for c in lhs_terms.values()): + return None + terms = {i: lhs_terms.get(f(x).diff(x, i), S.Zero) for i in range(order+1)} + terms[-1] = rhs + return terms + + # TODO: Add methods that can be used by many ODE solvers: + # order + # is_linear() + # get_linear_coefficients() + # eq_prepared (the ODE in prepared form) + + +class SingleODESolver: + """ + Base class for Single ODE solvers. + + Subclasses should implement the _matches and _get_general_solution + methods. This class is not intended to be instantiated directly but its + subclasses are as part of dsolve. + + Examples + ======== + + You can use a subclass of SingleODEProblem to solve a particular type of + ODE. We first define a particular ODE problem: + + >>> from sympy import Function, Symbol + >>> x = Symbol('x') + >>> f = Function('f') + >>> eq = f(x).diff(x, 2) + + Now we solve this problem using the NthAlgebraic solver which is a + subclass of SingleODESolver: + + >>> from sympy.solvers.ode.single import NthAlgebraic, SingleODEProblem + >>> problem = SingleODEProblem(eq, f(x), x) + >>> solver = NthAlgebraic(problem) + >>> solver.get_general_solution() + [Eq(f(x), _C*x + _C)] + + The normal way to solve an ODE is to use dsolve (which would use + NthAlgebraic and other solvers internally). When using dsolve a number of + other things are done such as evaluating integrals, simplifying the + solution and renumbering the constants: + + >>> from sympy import dsolve + >>> dsolve(eq, hint='nth_algebraic') + Eq(f(x), C1 + C2*x) + """ + + # Subclasses should store the hint name (the argument to dsolve) in this + # attribute + hint: ClassVar[str] + + # Subclasses should define this to indicate if they support an _Integral + # hint. + has_integral: ClassVar[bool] + + # The ODE to be solved + ode_problem: SingleODEProblem + + # Cache whether or not the equation has matched the method + _matched: bool | None = None + + # Subclasses should store in this attribute the list of order(s) of ODE + # that subclass can solve or leave it to None if not specific to any order + order: list | None = None + + def __init__(self, ode_problem): + self.ode_problem = ode_problem + + def matches(self) -> bool: + if self.order is not None and self.ode_problem.order not in self.order: + self._matched = False + return self._matched + + if self._matched is None: + self._matched = self._matches() + return self._matched + + def get_general_solution(self, *, simplify: bool = True) -> list[Equality]: + if not self.matches(): + msg = "%s solver cannot solve:\n%s" + raise ODEMatchError(msg % (self.hint, self.ode_problem.eq)) + return self._get_general_solution(simplify_flag=simplify) + + def _matches(self) -> bool: + msg = "Subclasses of SingleODESolver should implement matches." + raise NotImplementedError(msg) + + def _get_general_solution(self, *, simplify_flag: bool = True) -> list[Equality]: + msg = "Subclasses of SingleODESolver should implement get_general_solution." + raise NotImplementedError(msg) + + +class SinglePatternODESolver(SingleODESolver): + '''Superclass for ODE solvers based on pattern matching''' + + def wilds(self): + prob = self.ode_problem + f = prob.func.func + x = prob.sym + order = prob.order + return self._wilds(f, x, order) + + def wilds_match(self): + match = self._wilds_match + return [match.get(w, S.Zero) for w in self.wilds()] + + def _matches(self): + eq = self.ode_problem.eq_expanded + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + df = f(x).diff(x, order) + + if order not in [1, 2]: + return False + + pattern = self._equation(f(x), x, order) + + if not pattern.coeff(df).has(Wild): + eq = expand(eq / eq.coeff(df)) + eq = eq.collect([f(x).diff(x), f(x)], func = cancel) + + self._wilds_match = match = eq.match(pattern) + if match is not None: + return self._verify(f(x)) + return False + + def _verify(self, fx) -> bool: + return True + + def _wilds(self, f, x, order): + msg = "Subclasses of SingleODESolver should implement _wilds" + raise NotImplementedError(msg) + + def _equation(self, fx, x, order): + msg = "Subclasses of SingleODESolver should implement _equation" + raise NotImplementedError(msg) + + +class NthAlgebraic(SingleODESolver): + r""" + Solves an `n`\th order ordinary differential equation using algebra and + integrals. + + There is no general form for the kind of equation that this can solve. The + the equation is solved algebraically treating differentiation as an + invertible algebraic function. + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = Eq(f(x) * (f(x).diff(x)**2 - 1), 0) + >>> dsolve(eq, f(x), hint='nth_algebraic') + [Eq(f(x), 0), Eq(f(x), C1 - x), Eq(f(x), C1 + x)] + + Note that this solver can return algebraic solutions that do not have any + integration constants (f(x) = 0 in the above example). + """ + + hint = 'nth_algebraic' + has_integral = True # nth_algebraic_Integral hint + + def _matches(self): + r""" + Matches any differential equation that nth_algebraic can solve. Uses + `sympy.solve` but teaches it how to integrate derivatives. + + This involves calling `sympy.solve` and does most of the work of finding a + solution (apart from evaluating the integrals). + """ + eq = self.ode_problem.eq + func = self.ode_problem.func + var = self.ode_problem.sym + + # Derivative that solve can handle: + diffx = self._get_diffx(var) + + # Replace derivatives wrt the independent variable with diffx + def replace(eq, var): + def expand_diffx(*args): + differand, diffs = args[0], args[1:] + toreplace = differand + for v, n in diffs: + for _ in range(n): + if v == var: + toreplace = diffx(toreplace) + else: + toreplace = Derivative(toreplace, v) + return toreplace + return eq.replace(Derivative, expand_diffx) + + # Restore derivatives in solution afterwards + def unreplace(eq, var): + return eq.replace(diffx, lambda e: Derivative(e, var)) + + subs_eqn = replace(eq, var) + try: + # turn off simplification to protect Integrals that have + # _t instead of fx in them and would otherwise factor + # as t_*Integral(1, x) + solns = solve(subs_eqn, func, simplify=False) + except NotImplementedError: + solns = [] + + solns = [simplify(unreplace(soln, var)) for soln in solns] + solns = [Equality(func, soln) for soln in solns] + + self.solutions = solns + return len(solns) != 0 + + def _get_general_solution(self, *, simplify_flag: bool = True): + return self.solutions + + # This needs to produce an invertible function but the inverse depends + # which variable we are integrating with respect to. Since the class can + # be stored in cached results we need to ensure that we always get the + # same class back for each particular integration variable so we store these + # classes in a global dict: + _diffx_stored: dict[Symbol, type[Function]] = {} + + @staticmethod + def _get_diffx(var): + diffcls = NthAlgebraic._diffx_stored.get(var, None) + + if diffcls is None: + # A class that behaves like Derivative wrt var but is "invertible". + class diffx(Function): + def inverse(self): + # don't use integrate here because fx has been replaced by _t + # in the equation; integrals will not be correct while solve + # is at work. + return lambda expr: Integral(expr, var) + Dummy('C') + + diffcls = NthAlgebraic._diffx_stored.setdefault(var, diffx) + + return diffcls + + +class FirstExact(SinglePatternODESolver): + r""" + Solves 1st order exact ordinary differential equations. + + A 1st order differential equation is called exact if it is the total + differential of a function. That is, the differential equation + + .. math:: P(x, y) \,\partial{}x + Q(x, y) \,\partial{}y = 0 + + is exact if there is some function `F(x, y)` such that `P(x, y) = + \partial{}F/\partial{}x` and `Q(x, y) = \partial{}F/\partial{}y`. It can + be shown that a necessary and sufficient condition for a first order ODE + to be exact is that `\partial{}P/\partial{}y = \partial{}Q/\partial{}x`. + Then, the solution will be as given below:: + + >>> from sympy import Function, Eq, Integral, symbols, pprint + >>> x, y, t, x0, y0, C1= symbols('x,y,t,x0,y0,C1') + >>> P, Q, F= map(Function, ['P', 'Q', 'F']) + >>> pprint(Eq(Eq(F(x, y), Integral(P(t, y), (t, x0, x)) + + ... Integral(Q(x0, t), (t, y0, y))), C1)) + x y + / / + | | + F(x, y) = | P(t, y) dt + | Q(x0, t) dt = C1 + | | + / / + x0 y0 + + Where the first partials of `P` and `Q` exist and are continuous in a + simply connected region. + + A note: SymPy currently has no way to represent inert substitution on an + expression, so the hint ``1st_exact_Integral`` will return an integral + with `dy`. This is supposed to represent the function that you are + solving for. + + Examples + ======== + + >>> from sympy import Function, dsolve, cos, sin + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + ... f(x), hint='1st_exact') + Eq(x*cos(f(x)) + f(x)**3/3, C1) + + References + ========== + + - https://en.wikipedia.org/wiki/Exact_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 73 + + # indirect doctest + + """ + hint = "1st_exact" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x).diff(x)]) + Q = Wild('Q', exclude=[f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return P + Q*fx.diff(x) + + def _verify(self, fx) -> bool: + P, Q = self.wilds() + x = self.ode_problem.sym + y = Dummy('y') + + m, n = self.wilds_match() + + m = m.subs(fx, y) + n = n.subs(fx, y) + numerator = cancel(m.diff(y) - n.diff(x)) + + if numerator.is_zero: + # Is exact + return True + else: + # The following few conditions try to convert a non-exact + # differential equation into an exact one. + # References: + # 1. Differential equations with applications + # and historical notes - George E. Simmons + # 2. https://math.okstate.edu/people/binegar/2233-S99/2233-l12.pdf + + factor_n = cancel(numerator/n) + factor_m = cancel(-numerator/m) + if y not in factor_n.free_symbols: + # If (dP/dy - dQ/dx) / Q = f(x) + # then exp(integral(f(x))*equation becomes exact + factor = factor_n + integration_variable = x + elif x not in factor_m.free_symbols: + # If (dP/dy - dQ/dx) / -P = f(y) + # then exp(integral(f(y))*equation becomes exact + factor = factor_m + integration_variable = y + else: + # Couldn't convert to exact + return False + + factor = exp(Integral(factor, integration_variable)) + m *= factor + n *= factor + self._wilds_match[P] = m.subs(y, fx) + self._wilds_match[Q] = n.subs(y, fx) + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + m, n = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + y = Dummy('y') + + m = m.subs(fx, y) + n = n.subs(fx, y) + + gen_sol = Eq(Subs(Integral(m, x) + + Integral(n - Integral(m, x).diff(y), y), y, fx), C1) + return [gen_sol] + + +class FirstLinear(SinglePatternODESolver): + r""" + Solves 1st order linear differential equations. + + These are differential equations of the form + + .. math:: dy/dx + P(x) y = Q(x)\text{.} + + These kinds of differential equations can be solved in a general way. The + integrating factor `e^{\int P(x) \,dx}` will turn the equation into a + separable equation. The general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint, diff, sin + >>> from sympy.abc import x + >>> f, P, Q = map(Function, ['f', 'P', 'Q']) + >>> genform = Eq(f(x).diff(x) + P(x)*f(x), Q(x)) + >>> pprint(genform) + d + P(x)*f(x) + --(f(x)) = Q(x) + dx + >>> pprint(dsolve(genform, f(x), hint='1st_linear_Integral')) + / / \ + | | | + | | / | / + | | | | | + | | | P(x) dx | - | P(x) dx + | | | | | + | | / | / + f(x) = |C1 + | Q(x)*e dx|*e + | | | + \ / / + + + Examples + ======== + + >>> f = Function('f') + >>> pprint(dsolve(Eq(x*diff(f(x), x) - f(x), x**2*sin(x)), + ... f(x), '1st_linear')) + f(x) = x*(C1 - cos(x)) + + References + ========== + + - https://en.wikipedia.org/wiki/Linear_differential_equation#First-order_equation_with_variable_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 92 + + # indirect doctest + + """ + hint = '1st_linear' + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x)]) + Q = Wild('Q', exclude=[f(x), f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return fx.diff(x) + P*fx - Q + + def _get_general_solution(self, *, simplify_flag: bool = True): + P, Q = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + gensol = Eq(fx, ((C1 + Integral(Q*exp(Integral(P, x)), x)) + * exp(-Integral(P, x)))) + return [gensol] + + +class AlmostLinear(SinglePatternODESolver): + r""" + Solves an almost-linear differential equation. + + The general form of an almost linear differential equation is + + .. math:: a(x) g'(f(x)) f'(x) + b(x) g(f(x)) + c(x) + + Here `f(x)` is the function to be solved for (the dependent variable). + The substitution `g(f(x)) = u(x)` leads to a linear differential equation + for `u(x)` of the form `a(x) u' + b(x) u + c(x) = 0`. This can be solved + for `u(x)` by the `first_linear` hint and then `f(x)` is found by solving + `g(f(x)) = u(x)`. + + See Also + ======== + :obj:`sympy.solvers.ode.single.FirstLinear` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint, sin, cos + >>> from sympy.abc import x + >>> f = Function('f') + >>> d = f(x).diff(x) + >>> eq = x*d + x*f(x) + 1 + >>> dsolve(eq, f(x), hint='almost_linear') + Eq(f(x), (C1 - Ei(x))*exp(-x)) + >>> pprint(dsolve(eq, f(x), hint='almost_linear')) + -x + f(x) = (C1 - Ei(x))*e + >>> example = cos(f(x))*f(x).diff(x) + sin(f(x)) + 1 + >>> pprint(example) + d + sin(f(x)) + cos(f(x))*--(f(x)) + 1 + dx + >>> pprint(dsolve(example, f(x), hint='almost_linear')) + / -x \ / -x \ + [f(x) = pi - asin\C1*e - 1/, f(x) = asin\C1*e - 1/] + + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "almost_linear" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x).diff(x)]) + Q = Wild('Q', exclude=[f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return P*fx.diff(x) + Q + + def _verify(self, fx): + a, b = self.wilds_match() + c, b = b.as_independent(fx) if b.is_Add else (S.Zero, b) + # a, b and c are the function a(x), b(x) and c(x) respectively. + # c(x) is obtained by separating out b as terms with and without fx i.e, l(y) + # The following conditions checks if the given equation is an almost-linear differential equation using the fact that + # a(x)*(l(y))' / l(y)' is independent of l(y) + + if b.diff(fx) != 0 and not simplify(b.diff(fx)/a).has(fx): + self.ly = factor_terms(b).as_independent(fx, as_Add=False)[1] # Gives the term containing fx i.e., l(y) + self.ax = a / self.ly.diff(fx) + self.cx = -c # cx is taken as -c(x) to simplify expression in the solution integral + self.bx = factor_terms(b) / self.ly + return True + + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + gensol = Eq(self.ly, ((C1 + Integral((self.cx/self.ax)*exp(Integral(self.bx/self.ax, x)), x)) + * exp(-Integral(self.bx/self.ax, x)))) + + return [gensol] + + +class Bernoulli(SinglePatternODESolver): + r""" + Solves Bernoulli differential equations. + + These are equations of the form + + .. math:: dy/dx + P(x) y = Q(x) y^n\text{, }n \ne 1`\text{.} + + The substitution `w = 1/y^{1-n}` will transform an equation of this form + into one that is linear (see the docstring of + :obj:`~sympy.solvers.ode.single.FirstLinear`). The general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x, n + >>> f, P, Q = map(Function, ['f', 'P', 'Q']) + >>> genform = Eq(f(x).diff(x) + P(x)*f(x), Q(x)*f(x)**n) + >>> pprint(genform) + d n + P(x)*f(x) + --(f(x)) = Q(x)*f (x) + dx + >>> pprint(dsolve(genform, f(x), hint='Bernoulli_Integral'), num_columns=110) + -1 + ----- + n - 1 + // / / \ \ + || | | | | + || | / | / | / | + || | | | | | | | + || | -(n - 1)* | P(x) dx | -(n - 1)* | P(x) dx | (n - 1)* | P(x) dx| + || | | | | | | | + || | / | / | / | + f(x) = ||C1 - n* | Q(x)*e dx + | Q(x)*e dx|*e | + || | | | | + \\ / / / / + + + Note that the equation is separable when `n = 1` (see the docstring of + :obj:`~sympy.solvers.ode.single.Separable`). + + >>> pprint(dsolve(Eq(f(x).diff(x) + P(x)*f(x), Q(x)*f(x)), f(x), + ... hint='separable_Integral')) + f(x) + / + | / + | 1 | + | - dy = C1 + | (-P(x) + Q(x)) dx + | y | + | / + / + + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, pprint, log + >>> from sympy.abc import x + >>> f = Function('f') + + >>> pprint(dsolve(Eq(x*f(x).diff(x) + f(x), log(x)*f(x)**2), + ... f(x), hint='Bernoulli')) + 1 + f(x) = ----------------- + C1*x + log(x) + 1 + + References + ========== + + - https://en.wikipedia.org/wiki/Bernoulli_differential_equation + + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 95 + + # indirect doctest + + """ + hint = "Bernoulli" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x)]) + Q = Wild('Q', exclude=[f(x)]) + n = Wild('n', exclude=[x, f(x), f(x).diff(x)]) + return P, Q, n + + def _equation(self, fx, x, order): + P, Q, n = self.wilds() + return fx.diff(x) + P*fx - Q*fx**n + + def _get_general_solution(self, *, simplify_flag: bool = True): + P, Q, n = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + if n==1: + gensol = Eq(log(fx), ( + C1 + Integral((-P + Q), x) + )) + else: + gensol = Eq(fx**(1-n), ( + (C1 - (n - 1) * Integral(Q*exp(-n*Integral(P, x)) + * exp(Integral(P, x)), x) + ) * exp(-(1 - n)*Integral(P, x))) + ) + return [gensol] + + +class Factorable(SingleODESolver): + r""" + Solves equations having a solvable factor. + + This function is used to solve the equation having factors. Factors may be of type algebraic or ode. It + will try to solve each factor independently. Factors will be solved by calling dsolve. We will return the + list of solutions. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = (f(x)**2-4)*(f(x).diff(x)+f(x)) + >>> pprint(dsolve(eq, f(x))) + -x + [f(x) = 2, f(x) = -2, f(x) = C1*e ] + + + """ + hint = "factorable" + has_integral = False + + def _matches(self): + eq_orig = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + df = f(x).diff(x) + self.eqs = [] + eq = eq_orig.collect(f(x), func = cancel) + eq = fraction(factor(eq))[0] + factors = Mul.make_args(factor(eq)) + roots = [fac.as_base_exp() for fac in factors if len(fac.args)!=0] + if len(roots)>1 or roots[0][1]>1: + for base, expo in roots: + if base.has(f(x)): + self.eqs.append(base) + if len(self.eqs)>0: + return True + roots = solve(eq, df) + if len(roots)>0: + self.eqs = [(df - root) for root in roots] + # Avoid infinite recursion + matches = self.eqs != [eq_orig] + return matches + for i in factors: + if i.has(f(x)): + self.eqs.append(i) + return len(self.eqs)>0 and len(factors)>1 + + def _get_general_solution(self, *, simplify_flag: bool = True): + func = self.ode_problem.func.func + x = self.ode_problem.sym + eqns = self.eqs + sols = [] + for eq in eqns: + try: + sol = dsolve(eq, func(x)) + except NotImplementedError: + continue + else: + if isinstance(sol, list): + sols.extend(sol) + else: + sols.append(sol) + + if sols == []: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the factorable group method") + return sols + + +class RiccatiSpecial(SinglePatternODESolver): + r""" + The general Riccati equation has the form + + .. math:: dy/dx = f(x) y^2 + g(x) y + h(x)\text{.} + + While it does not have a general solution [1], the "special" form, `dy/dx + = a y^2 - b x^c`, does have solutions in many cases [2]. This routine + returns a solution for `a(dy/dx) = b y^2 + c y/x + d/x^2` that is obtained + by using a suitable change of variables to reduce it to the special form + and is valid when neither `a` nor `b` are zero and either `c` or `d` is + zero. + + >>> from sympy.abc import x, a, b, c, d + >>> from sympy import dsolve, checkodesol, pprint, Function + >>> f = Function('f') + >>> y = f(x) + >>> genform = a*y.diff(x) - (b*y**2 + c*y/x + d/x**2) + >>> sol = dsolve(genform, y, hint="Riccati_special_minus2") + >>> pprint(sol, wrap_line=False) + / / __________________ \\ + | __________________ | / 2 || + | / 2 | \/ 4*b*d - (a + c) *log(x)|| + -|a + c - \/ 4*b*d - (a + c) *tan|C1 + ----------------------------|| + \ \ 2*a // + f(x) = ------------------------------------------------------------------------ + 2*b*x + + >>> checkodesol(genform, sol, order=1)[0] + True + + References + ========== + + - https://www.maplesoft.com/support/help/Maple/view.aspx?path=odeadvisor/Riccati + - https://eqworld.ipmnet.ru/en/solutions/ode/ode0106.pdf - + https://eqworld.ipmnet.ru/en/solutions/ode/ode0123.pdf + """ + hint = "Riccati_special_minus2" + has_integral = False + order = [1] + + def _wilds(self, f, x, order): + a = Wild('a', exclude=[x, f(x), f(x).diff(x), 0]) + b = Wild('b', exclude=[x, f(x), f(x).diff(x), 0]) + c = Wild('c', exclude=[x, f(x), f(x).diff(x)]) + d = Wild('d', exclude=[x, f(x), f(x).diff(x)]) + return a, b, c, d + + def _equation(self, fx, x, order): + a, b, c, d = self.wilds() + return a*fx.diff(x) + b*fx**2 + c*fx/x + d/x**2 + + def _get_general_solution(self, *, simplify_flag: bool = True): + a, b, c, d = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + mu = sqrt(4*d*b - (a - c)**2) + + gensol = Eq(fx, (a - c - mu*tan(mu/(2*a)*log(x) + C1))/(2*b*x)) + return [gensol] + + +class RationalRiccati(SinglePatternODESolver): + r""" + Gives general solutions to the first order Riccati differential + equations that have atleast one rational particular solution. + + .. math :: y' = b_0(x) + b_1(x) y + b_2(x) y^2 + + where `b_0`, `b_1` and `b_2` are rational functions of `x` + with `b_2 \ne 0` (`b_2 = 0` would make it a Bernoulli equation). + + Examples + ======== + + >>> from sympy import Symbol, Function, dsolve, checkodesol + >>> f = Function('f') + >>> x = Symbol('x') + + >>> eq = -x**4*f(x)**2 + x**3*f(x).diff(x) + x**2*f(x) + 20 + >>> sol = dsolve(eq, hint="1st_rational_riccati") + >>> sol + Eq(f(x), (4*C1 - 5*x**9 - 4)/(x**2*(C1 + x**9 - 1))) + >>> checkodesol(eq, sol) + (True, 0) + + References + ========== + + - Riccati ODE: https://en.wikipedia.org/wiki/Riccati_equation + - N. Thieu Vo - Rational and Algebraic Solutions of First-Order Algebraic ODEs: + Algorithm 11, pp. 78 - https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf + """ + has_integral = False + hint = "1st_rational_riccati" + order = [1] + + def _wilds(self, f, x, order): + b0 = Wild('b0', exclude=[f(x), f(x).diff(x)]) + b1 = Wild('b1', exclude=[f(x), f(x).diff(x)]) + b2 = Wild('b2', exclude=[f(x), f(x).diff(x)]) + return (b0, b1, b2) + + def _equation(self, fx, x, order): + b0, b1, b2 = self.wilds() + return fx.diff(x) - b0 - b1*fx - b2*fx**2 + + def _matches(self): + eq = self.ode_problem.eq_expanded + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + + if order != 1: + return False + + match, funcs = match_riccati(eq, f, x) + if not match: + return False + _b0, _b1, _b2 = funcs + b0, b1, b2 = self.wilds() + self._wilds_match = match = {b0: _b0, b1: _b1, b2: _b2} + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + # Match the equation + b0, b1, b2 = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + return solve_riccati(fx, x, b0, b1, b2, gensol=True) + + +class SecondNonlinearAutonomousConserved(SinglePatternODESolver): + r""" + Gives solution for the autonomous second order nonlinear + differential equation of the form + + .. math :: f''(x) = g(f(x)) + + The solution for this differential equation can be computed + by multiplying by `f'(x)` and integrating on both sides, + converting it into a first order differential equation. + + Examples + ======== + + >>> from sympy import Function, symbols, dsolve + >>> f, g = symbols('f g', cls=Function) + >>> x = symbols('x') + + >>> eq = f(x).diff(x, 2) - g(f(x)) + >>> dsolve(eq, simplify=False) + [Eq(Integral(1/sqrt(C1 + 2*Integral(g(_u), _u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 + 2*Integral(g(_u), _u)), (_u, f(x))), C2 - x)] + + >>> from sympy import exp, log + >>> eq = f(x).diff(x, 2) - exp(f(x)) + log(f(x)) + >>> dsolve(eq, simplify=False) + [Eq(Integral(1/sqrt(-2*_u*log(_u) + 2*_u + C1 + 2*exp(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(-2*_u*log(_u) + 2*_u + C1 + 2*exp(_u)), (_u, f(x))), C2 - x)] + + References + ========== + + - https://eqworld.ipmnet.ru/en/solutions/ode/ode0301.pdf + """ + hint = "2nd_nonlinear_autonomous_conserved" + has_integral = True + order = [2] + + def _wilds(self, f, x, order): + fy = Wild('fy', exclude=[0, f(x).diff(x), f(x).diff(x, 2)]) + return (fy, ) + + def _equation(self, fx, x, order): + fy = self.wilds()[0] + return fx.diff(x, 2) + fy + + def _verify(self, fx): + return self.ode_problem.is_autonomous + + def _get_general_solution(self, *, simplify_flag: bool = True): + g = self.wilds_match()[0] + fx = self.ode_problem.func + x = self.ode_problem.sym + u = Dummy('u') + g = g.subs(fx, u) + C1, C2 = self.ode_problem.get_numbered_constants(num=2) + inside = -2*Integral(g, u) + C1 + lhs = Integral(1/sqrt(inside), (u, fx)) + return [Eq(lhs, C2 + x), Eq(lhs, C2 - x)] + + +class Liouville(SinglePatternODESolver): + r""" + Solves 2nd order Liouville differential equations. + + The general form of a Liouville ODE is + + .. math:: \frac{d^2 y}{dx^2} + g(y) \left(\! + \frac{dy}{dx}\!\right)^2 + h(x) + \frac{dy}{dx}\text{.} + + The general solution is: + + >>> from sympy import Function, dsolve, Eq, pprint, diff + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = Eq(diff(f(x),x,x) + g(f(x))*diff(f(x),x)**2 + + ... h(x)*diff(f(x),x), 0) + >>> pprint(genform) + 2 2 + /d \ d d + g(f(x))*|--(f(x))| + h(x)*--(f(x)) + ---(f(x)) = 0 + \dx / dx 2 + dx + >>> pprint(dsolve(genform, f(x), hint='Liouville_Integral')) + f(x) + / / + | | + | / | / + | | | | + | - | h(x) dx | | g(y) dy + | | | | + | / | / + C1 + C2* | e dx + | e dy = 0 + | | + / / + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(diff(f(x), x, x) + diff(f(x), x)**2/f(x) + + ... diff(f(x), x)/x, f(x), hint='Liouville')) + ________________ ________________ + [f(x) = -\/ C1 + C2*log(x) , f(x) = \/ C1 + C2*log(x) ] + + References + ========== + + - Goldstein and Braun, "Advanced Methods for the Solution of Differential + Equations", pp. 98 + - https://www.maplesoft.com/support/help/Maple/view.aspx?path=odeadvisor/Liouville + + # indirect doctest + + """ + hint = "Liouville" + has_integral = True + order = [2] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + k = Wild('k', exclude=[f(x).diff(x)]) + return d, e, k + + def _equation(self, fx, x, order): + # Liouville ODE in the form + # f(x).diff(x, 2) + g(f(x))*(f(x).diff(x))**2 + h(x)*f(x).diff(x) + # See Goldstein and Braun, "Advanced Methods for the Solution of + # Differential Equations", pg. 98 + d, e, k = self.wilds() + return d*fx.diff(x, 2) + e*fx.diff(x)**2 + k*fx.diff(x) + + def _verify(self, fx): + d, e, k = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.g = simplify(e/d).subs(fx, self.y) + self.h = simplify(k/d).subs(fx, self.y) + if self.y in self.h.free_symbols or x in self.g.free_symbols: + return False + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, k = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + C1, C2 = self.ode_problem.get_numbered_constants(num=2) + int = Integral(exp(Integral(self.g, self.y)), (self.y, None, fx)) + gen_sol = Eq(int + C1*Integral(exp(-Integral(self.h, x)), x) + C2, 0) + + return [gen_sol] + + +class Separable(SinglePatternODESolver): + r""" + Solves separable 1st order differential equations. + + This is any differential equation that can be written as `P(y) + \tfrac{dy}{dx} = Q(x)`. The solution can then just be found by + rearranging terms and integrating: `\int P(y) \,dy = \int Q(x) \,dx`. + This hint uses :py:meth:`sympy.simplify.simplify.separatevars` as its back + end, so if a separable equation is not caught by this solver, it is most + likely the fault of that function. + :py:meth:`~sympy.simplify.simplify.separatevars` is + smart enough to do most expansion and factoring necessary to convert a + separable equation `F(x, y)` into the proper form `P(x)\cdot{}Q(y)`. The + general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x + >>> a, b, c, d, f = map(Function, ['a', 'b', 'c', 'd', 'f']) + >>> genform = Eq(a(x)*b(f(x))*f(x).diff(x), c(x)*d(f(x))) + >>> pprint(genform) + d + a(x)*b(f(x))*--(f(x)) = c(x)*d(f(x)) + dx + >>> pprint(dsolve(genform, f(x), hint='separable_Integral')) + f(x) + / / + | | + | b(y) | c(x) + | ---- dy = C1 + | ---- dx + | d(y) | a(x) + | | + / / + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(Eq(f(x)*f(x).diff(x) + x, 3*x*f(x)**2), f(x), + ... hint='separable', simplify=False)) + / 2 \ 2 + log\3*f (x) - 1/ x + ---------------- = C1 + -- + 6 2 + + References + ========== + + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 52 + + # indirect doctest + + """ + hint = "separable" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + d, e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + d = separatevars(d.subs(fx, self.y)) + e = separatevars(e.subs(fx, self.y)) + # m1[coeff]*m1[x]*m1[y] + m2[coeff]*m2[x]*m2[y]*y' + self.m1 = separatevars(d, dict=True, symbols=(x, self.y)) + self.m2 = separatevars(e, dict=True, symbols=(x, self.y)) + return bool(self.m1 and self.m2) + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + return self.m1, self.m2, x, fx + + def _get_general_solution(self, *, simplify_flag: bool = True): + m1, m2, x, fx = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral(m2['coeff']*m2[self.y]/m1[self.y], + (self.y, None, fx)) + gen_sol = Eq(int, Integral(-m1['coeff']*m1[x]/ + m2[x], x) + C1) + return [gen_sol] + + +class SeparableReduced(Separable): + r""" + Solves a differential equation that can be reduced to the separable form. + + The general form of this equation is + + .. math:: y' + (y/x) H(x^n y) = 0\text{}. + + This can be solved by substituting `u(y) = x^n y`. The equation then + reduces to the separable form `\frac{u'}{u (\mathrm{power} - H(u))} - + \frac{1}{x} = 0`. + + The general solution is: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x, n + >>> f, g = map(Function, ['f', 'g']) + >>> genform = f(x).diff(x) + (f(x)/x)*g(x**n*f(x)) + >>> pprint(genform) + / n \ + d f(x)*g\x *f(x)/ + --(f(x)) + --------------- + dx x + >>> pprint(dsolve(genform, hint='separable_reduced')) + n + x *f(x) + / + | + | 1 + | ------------ dy = C1 + log(x) + | y*(n - g(y)) + | + / + + See Also + ======== + :obj:`sympy.solvers.ode.single.Separable` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> d = f(x).diff(x) + >>> eq = (x - x**2*f(x))*d - f(x) + >>> dsolve(eq, hint='separable_reduced') + [Eq(f(x), (1 - sqrt(C1*x**2 + 1))/x), Eq(f(x), (sqrt(C1*x**2 + 1) + 1)/x)] + >>> pprint(dsolve(eq, hint='separable_reduced')) + ___________ ___________ + / 2 / 2 + 1 - \/ C1*x + 1 \/ C1*x + 1 + 1 + [f(x) = ------------------, f(x) = ------------------] + x x + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "separable_reduced" + has_integral = True + order = [1] + + def _degree(self, expr, x): + # Made this function to calculate the degree of + # x in an expression. If expr will be of form + # x**p*y, (wheare p can be variables/rationals) then it + # will return p. + for val in expr: + if val.has(x): + if isinstance(val, Pow) and val.as_base_exp()[0] == x: + return (val.as_base_exp()[1]) + elif val == x: + return (val.as_base_exp()[1]) + else: + return self._degree(val.args, x) + return 0 + + def _powers(self, expr): + # this function will return all the different relative power of x w.r.t f(x). + # expr = x**p * f(x)**q then it will return {p/q}. + pows = set() + fx = self.ode_problem.func + x = self.ode_problem.sym + self.y = Dummy('y') + if isinstance(expr, Add): + exprs = expr.atoms(Add) + elif isinstance(expr, Mul): + exprs = expr.atoms(Mul) + elif isinstance(expr, Pow): + exprs = expr.atoms(Pow) + else: + exprs = {expr} + + for arg in exprs: + if arg.has(x): + _, u = arg.as_independent(x, fx) + pow = self._degree((u.subs(fx, self.y), ), x)/self._degree((u.subs(fx, self.y), ), self.y) + pows.add(pow) + return pows + + def _verify(self, fx): + num, den = self.wilds_match() + x = self.ode_problem.sym + factor = simplify(x/fx*num/den) + # Try representing factor in terms of x^n*y + # where n is lowest power of x in factor; + # first remove terms like sqrt(2)*3 from factor.atoms(Mul) + num, dem = factor.as_numer_denom() + num = expand(num) + dem = expand(dem) + pows = self._powers(num) + pows.update(self._powers(dem)) + pows = list(pows) + if(len(pows)==1) and pows[0]!=zoo: + self.t = Dummy('t') + self.r2 = {'t': self.t} + num = num.subs(x**pows[0]*fx, self.t) + dem = dem.subs(x**pows[0]*fx, self.t) + test = num/dem + free = test.free_symbols + if len(free) == 1 and free.pop() == self.t: + self.r2.update({'power' : pows[0], 'u' : test}) + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + u = self.r2['u'].subs(self.r2['t'], self.y) + ycoeff = 1/(self.y*(self.r2['power'] - u)) + m1 = {self.y: 1, x: -1/x, 'coeff': 1} + m2 = {self.y: ycoeff, x: 1, 'coeff': 1} + return m1, m2, x, x**self.r2['power']*fx + + +class HomogeneousCoeffSubsDepDivIndep(SinglePatternODESolver): + r""" + Solves a 1st order differential equation with homogeneous coefficients + using the substitution `u_1 = \frac{\text{}}{\text{}}`. + + This is a differential equation + + .. math:: P(x, y) + Q(x, y) dy/dx = 0 + + such that `P` and `Q` are homogeneous and of the same order. A function + `F(x, y)` is homogeneous of order `n` if `F(x t, y t) = t^n F(x, y)`. + Equivalently, `F(x, y)` can be rewritten as `G(y/x)` or `H(x/y)`. See + also the docstring of :py:meth:`~sympy.solvers.ode.homogeneous_order`. + + If the coefficients `P` and `Q` in the differential equation above are + homogeneous functions of the same order, then it can be shown that the + substitution `y = u_1 x` (i.e. `u_1 = y/x`) will turn the differential + equation into an equation separable in the variables `x` and `u`. If + `h(u_1)` is the function that results from making the substitution `u_1 = + f(x)/x` on `P(x, f(x))` and `g(u_2)` is the function that results from the + substitution on `Q(x, f(x))` in the differential equation `P(x, f(x)) + + Q(x, f(x)) f'(x) = 0`, then the general solution is:: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = g(f(x)/x) + h(f(x)/x)*f(x).diff(x) + >>> pprint(genform) + /f(x)\ /f(x)\ d + g|----| + h|----|*--(f(x)) + \ x / \ x / dx + >>> pprint(dsolve(genform, f(x), + ... hint='1st_homogeneous_coeff_subs_dep_div_indep_Integral')) + f(x) + ---- + x + / + | + | -h(u1) + log(x) = C1 + | ---------------- d(u1) + | u1*h(u1) + g(u1) + | + / + + Where `u_1 h(u_1) + g(u_1) \ne 0` and `x \ne 0`. + + See also the docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep`. + + Examples + ======== + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_subs_dep_div_indep', simplify=False)) + / 3 \ + |3*f(x) f (x)| + log|------ + -----| + | x 3 | + \ x / + log(x) = log(C1) - ------------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_subs_dep_div_indep" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.d = separatevars(self.d.subs(fx, self.y)) + self.e = separatevars(self.e.subs(fx, self.y)) + ordera = homogeneous_order(self.d, x, self.y) + orderb = homogeneous_order(self.e, x, self.y) + if ordera == orderb and ordera is not None: + self.u = Dummy('u') + if simplify((self.d + self.u*self.e).subs({x: 1, self.y: self.u})) != 0: + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + xarg = 0 + yarg = 0 + return [self.d, self.e, fx, x, self.u, self.u1, self.y, xarg, yarg] + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, fx, x, u, u1, y, xarg, yarg = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral( + (-e/(d + u1*e)).subs({x: 1, y: u1}), + (u1, None, fx/x)) + sol = logcombine(Eq(log(x), int + log(C1)), force=True) + gen_sol = sol.subs(fx, u).subs(((u, u - yarg), (x, x - xarg), (u, fx))) + return [gen_sol] + + +class HomogeneousCoeffSubsIndepDivDep(SinglePatternODESolver): + r""" + Solves a 1st order differential equation with homogeneous coefficients + using the substitution `u_2 = \frac{\text{}}{\text{}}`. + + This is a differential equation + + .. math:: P(x, y) + Q(x, y) dy/dx = 0 + + such that `P` and `Q` are homogeneous and of the same order. A function + `F(x, y)` is homogeneous of order `n` if `F(x t, y t) = t^n F(x, y)`. + Equivalently, `F(x, y)` can be rewritten as `G(y/x)` or `H(x/y)`. See + also the docstring of :py:meth:`~sympy.solvers.ode.homogeneous_order`. + + If the coefficients `P` and `Q` in the differential equation above are + homogeneous functions of the same order, then it can be shown that the + substitution `x = u_2 y` (i.e. `u_2 = x/y`) will turn the differential + equation into an equation separable in the variables `y` and `u_2`. If + `h(u_2)` is the function that results from making the substitution `u_2 = + x/f(x)` on `P(x, f(x))` and `g(u_2)` is the function that results from the + substitution on `Q(x, f(x))` in the differential equation `P(x, f(x)) + + Q(x, f(x)) f'(x) = 0`, then the general solution is: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = g(x/f(x)) + h(x/f(x))*f(x).diff(x) + >>> pprint(genform) + / x \ / x \ d + g|----| + h|----|*--(f(x)) + \f(x)/ \f(x)/ dx + >>> pprint(dsolve(genform, f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep_Integral')) + x + ---- + f(x) + / + | + | -g(u1) + | ---------------- d(u1) + | u1*g(u1) + h(u1) + | + / + + f(x) = C1*e + + Where `u_1 g(u_1) + h(u_1) \ne 0` and `f(x) \ne 0`. + + See also the docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep`. + + Examples + ======== + + >>> from sympy import Function, pprint, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep', + ... simplify=False)) + / 2 \ + |3*x | + log|----- + 1| + | 2 | + \f (x) / + log(f(x)) = log(C1) - -------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_subs_indep_div_dep" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.d = separatevars(self.d.subs(fx, self.y)) + self.e = separatevars(self.e.subs(fx, self.y)) + ordera = homogeneous_order(self.d, x, self.y) + orderb = homogeneous_order(self.e, x, self.y) + if ordera == orderb and ordera is not None: + self.u = Dummy('u') + if simplify((self.e + self.u*self.d).subs({x: self.u, self.y: 1})) != 0: + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + xarg = 0 + yarg = 0 + return [self.d, self.e, fx, x, self.u, self.u1, self.y, xarg, yarg] + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, fx, x, u, u1, y, xarg, yarg = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral(simplify((-d/(e + u1*d)).subs({x: u1, y: 1})), (u1, None, x/fx)) # type: ignore + sol = logcombine(Eq(log(fx), int + log(C1)), force=True) + gen_sol = sol.subs(fx, u).subs(((u, u - yarg), (x, x - xarg), (u, fx))) + return [gen_sol] + + +class HomogeneousCoeffBest(HomogeneousCoeffSubsIndepDivDep, HomogeneousCoeffSubsDepDivIndep): + r""" + Returns the best solution to an ODE from the two hints + ``1st_homogeneous_coeff_subs_dep_div_indep`` and + ``1st_homogeneous_coeff_subs_indep_div_dep``. + + This is as determined by :py:meth:`~sympy.solvers.ode.ode.ode_sol_simplicity`. + + See the + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep` + and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` + docstrings for more information on these hints. Note that there is no + ``ode_1st_homogeneous_coeff_best_Integral`` hint. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_best', simplify=False)) + / 2 \ + |3*x | + log|----- + 1| + | 2 | + \f (x) / + log(f(x)) = log(C1) - -------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_best" + has_integral = False + order = [1] + + def _verify(self, fx): + return HomogeneousCoeffSubsIndepDivDep._verify(self, fx) and \ + HomogeneousCoeffSubsDepDivIndep._verify(self, fx) + + def _get_general_solution(self, *, simplify_flag: bool = True): + # There are two substitutions that solve the equation, u1=y/x and u2=x/y + # # They produce different integrals, so try them both and see which + # # one is easier + sol1 = HomogeneousCoeffSubsIndepDivDep._get_general_solution(self) + sol2 = HomogeneousCoeffSubsDepDivIndep._get_general_solution(self) + fx = self.ode_problem.func + if simplify_flag: + sol1 = odesimp(self.ode_problem.eq, *sol1, fx, "1st_homogeneous_coeff_subs_indep_div_dep") + sol2 = odesimp(self.ode_problem.eq, *sol2, fx, "1st_homogeneous_coeff_subs_dep_div_indep") + # XXX: not simplify should be not simplify_flag. mypy correctly complains + return min([sol1, sol2], key=lambda x: ode_sol_simplicity(x, fx, trysolving=not simplify)) # type: ignore + + +class LinearCoefficients(HomogeneousCoeffBest): + r""" + Solves a differential equation with linear coefficients. + + The general form of a differential equation with linear coefficients is + + .. math:: y' + F\left(\!\frac{a_1 x + b_1 y + c_1}{a_2 x + b_2 y + + c_2}\!\right) = 0\text{,} + + where `a_1`, `b_1`, `c_1`, `a_2`, `b_2`, `c_2` are constants and `a_1 b_2 + - a_2 b_1 \ne 0`. + + This can be solved by substituting: + + .. math:: x = x' + \frac{b_2 c_1 - b_1 c_2}{a_2 b_1 - a_1 b_2} + + y = y' + \frac{a_1 c_2 - a_2 c_1}{a_2 b_1 - a_1 + b_2}\text{.} + + This substitution reduces the equation to a homogeneous differential + equation. + + See Also + ======== + :obj:`sympy.solvers.ode.single.HomogeneousCoeffBest` + :obj:`sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep` + :obj:`sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> df = f(x).diff(x) + >>> eq = (x + f(x) + 1)*df + (f(x) - 6*x + 1) + >>> dsolve(eq, hint='linear_coefficients') + [Eq(f(x), -x - sqrt(C1 + 7*x**2) - 1), Eq(f(x), -x + sqrt(C1 + 7*x**2) - 1)] + >>> pprint(dsolve(eq, hint='linear_coefficients')) + ___________ ___________ + / 2 / 2 + [f(x) = -x - \/ C1 + 7*x - 1, f(x) = -x + \/ C1 + 7*x - 1] + + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "linear_coefficients" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + a, b = self.wilds() + F = self.d/self.e + x = self.ode_problem.sym + params = self._linear_coeff_match(F, fx) + if params: + self.xarg, self.yarg = params + u = Dummy('u') + t = Dummy('t') + self.y = Dummy('y') + # Dummy substitution for df and f(x). + dummy_eq = self.ode_problem.eq.subs(((fx.diff(x), t), (fx, u))) + reps = ((x, x + self.xarg), (u, u + self.yarg), (t, fx.diff(x)), (u, fx)) + dummy_eq = simplify(dummy_eq.subs(reps)) + # get the re-cast values for e and d + r2 = collect(expand(dummy_eq), [fx.diff(x), fx]).match(a*fx.diff(x) + b) + if r2: + self.d, self.e = r2[b], r2[a] + orderd = homogeneous_order(self.d, x, fx) + ordere = homogeneous_order(self.e, x, fx) + if orderd == ordere and orderd is not None: + self.d = self.d.subs(fx, self.y) + self.e = self.e.subs(fx, self.y) + return True + return False + return False + + def _linear_coeff_match(self, expr, func): + r""" + Helper function to match hint ``linear_coefficients``. + + Matches the expression to the form `(a_1 x + b_1 f(x) + c_1)/(a_2 x + b_2 + f(x) + c_2)` where the following conditions hold: + + 1. `a_1`, `b_1`, `c_1`, `a_2`, `b_2`, `c_2` are Rationals; + 2. `c_1` or `c_2` are not equal to zero; + 3. `a_2 b_1 - a_1 b_2` is not equal to zero. + + Return ``xarg``, ``yarg`` where + + 1. ``xarg`` = `(b_2 c_1 - b_1 c_2)/(a_2 b_1 - a_1 b_2)` + 2. ``yarg`` = `(a_1 c_2 - a_2 c_1)/(a_2 b_1 - a_1 b_2)` + + + Examples + ======== + + >>> from sympy import Function, sin + >>> from sympy.abc import x + >>> from sympy.solvers.ode.single import LinearCoefficients + >>> f = Function('f') + >>> eq = (-25*f(x) - 8*x + 62)/(4*f(x) + 11*x - 11) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + (1/9, 22/9) + >>> eq = sin((-5*f(x) - 8*x + 6)/(4*f(x) + x - 1)) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + (19/27, 2/27) + >>> eq = sin(f(x)/x) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + + """ + f = func.func + x = func.args[0] + def abc(eq): + r''' + Internal function of _linear_coeff_match + that returns Rationals a, b, c + if eq is a*x + b*f(x) + c, else None. + ''' + eq = _mexpand(eq) + c = eq.as_independent(x, f(x), as_Add=True)[0] + if not c.is_Rational: + return + a = eq.coeff(x) + if not a.is_Rational: + return + b = eq.coeff(f(x)) + if not b.is_Rational: + return + if eq == a*x + b*f(x) + c: + return a, b, c + + def match(arg): + r''' + Internal function of _linear_coeff_match that returns Rationals a1, + b1, c1, a2, b2, c2 and a2*b1 - a1*b2 of the expression (a1*x + b1*f(x) + + c1)/(a2*x + b2*f(x) + c2) if one of c1 or c2 and a2*b1 - a1*b2 is + non-zero, else None. + ''' + n, d = arg.together().as_numer_denom() + m = abc(n) + if m is not None: + a1, b1, c1 = m + m = abc(d) + if m is not None: + a2, b2, c2 = m + d = a2*b1 - a1*b2 + if (c1 or c2) and d: + return a1, b1, c1, a2, b2, c2, d + + m = [fi.args[0] for fi in expr.atoms(Function) if fi.func != f and + len(fi.args) == 1 and not fi.args[0].is_Function] or {expr} + m1 = match(m.pop()) + if m1 and all(match(mi) == m1 for mi in m): + a1, b1, c1, a2, b2, c2, denom = m1 + return (b2*c1 - b1*c2)/denom, (a1*c2 - a2*c1)/denom + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + u = Dummy('u') + return [self.d, self.e, fx, x, u, self.u1, self.y, self.xarg, self.yarg] + + +class NthOrderReducible(SingleODESolver): + r""" + Solves ODEs that only involve derivatives of the dependent variable using + a substitution of the form `f^n(x) = g(x)`. + + For example any second order ODE of the form `f''(x) = h(f'(x), x)` can be + transformed into a pair of 1st order ODEs `g'(x) = h(g(x), x)` and + `f'(x) = g(x)`. Usually the 1st order ODE for `g` is easier to solve. If + that gives an explicit solution for `g` then `f` is found simply by + integration. + + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = Eq(x*f(x).diff(x)**2 + f(x).diff(x, 2), 0) + >>> dsolve(eq, f(x), hint='nth_order_reducible') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), C1 - sqrt(-1/C2)*log(-C2*sqrt(-1/C2) + x) + sqrt(-1/C2)*log(C2*sqrt(-1/C2) + x)) + + """ + hint = "nth_order_reducible" + has_integral = False + + def _matches(self): + # Any ODE that can be solved with a substitution and + # repeated integration e.g.: + # `d^2/dx^2(y) + x*d/dx(y) = constant + #f'(x) must be finite for this to work + eq = self.ode_problem.eq_preprocessed + func = self.ode_problem.func + x = self.ode_problem.sym + r""" + Matches any differential equation that can be rewritten with a smaller + order. Only derivatives of ``func`` alone, wrt a single variable, + are considered, and only in them should ``func`` appear. + """ + # ODE only handles functions of 1 variable so this affirms that state + assert len(func.args) == 1 + vc = [d.variable_count[0] for d in eq.atoms(Derivative) + if d.expr == func and len(d.variable_count) == 1] + ords = [c for v, c in vc if v == x] + if len(ords) < 2: + return False + self.smallest = min(ords) + # make sure func does not appear outside of derivatives + D = Dummy() + if eq.subs(func.diff(x, self.smallest), D).has(func): + return False + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + n = self.smallest + # get a unique function name for g + names = [a.name for a in eq.atoms(AppliedUndef)] + while True: + name = Dummy().name + if name not in names: + g = Function(name) + break + w = f(x).diff(x, n) + geq = eq.subs(w, g(x)) + gsol = dsolve(geq, g(x)) + + if not isinstance(gsol, list): + gsol = [gsol] + + # Might be multiple solutions to the reduced ODE: + fsol = [] + for gsoli in gsol: + fsoli = dsolve(gsoli.subs(g(x), w), f(x)) # or do integration n times + fsol.append(fsoli) + + return fsol + + +class SecondHypergeometric(SingleODESolver): + r""" + Solves 2nd order linear differential equations. + + It computes special function solutions which can be expressed using the + 2F1, 1F1 or 0F1 hypergeometric functions. + + .. math:: y'' + A(x) y' + B(x) y = 0\text{,} + + where `A` and `B` are rational functions. + + These kinds of differential equations have solution of non-Liouvillian form. + + Given linear ODE can be obtained from 2F1 given by + + .. math:: (x^2 - x) y'' + ((a + b + 1) x - c) y' + b a y = 0\text{,} + + where {a, b, c} are arbitrary constants. + + Notes + ===== + + The algorithm should find any solution of the form + + .. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,} + + where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function". + Currently only the 2F1 case is implemented in SymPy but the other cases are + described in the paper and could be implemented in future (contributions + welcome!). + + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = (x*x - x)*f(x).diff(x,2) + (5*x - 1)*f(x).diff(x) + 4*f(x) + >>> pprint(dsolve(eq, f(x), '2nd_hypergeometric')) + _ + / / 4 \\ |_ /-1, -1 | \ + |C1 + C2*|log(x) + -----||* | | | x| + \ \ x + 1// 2 1 \ 1 | / + f(x) = -------------------------------------------- + 3 + (x - 1) + + + References + ========== + + - "Non-Liouvillian solutions for second order linear ODEs" by L. Chan, E.S. Cheb-Terrab + + """ + hint = "2nd_hypergeometric" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + func = self.ode_problem.func + r = match_2nd_hypergeometric(eq, func) + self.match_object = None + if r: + A, B = r + d = equivalence_hypergeometric(A, B, func) + if d: + if d['type'] == "2F1": + self.match_object = match_2nd_2F1_hypergeometric(d['I0'], d['k'], d['sing_point'], func) + if self.match_object is not None: + self.match_object.update({'A':A, 'B':B}) + # We can extend it for 1F1 and 0F1 type also. + return self.match_object is not None + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + func = self.ode_problem.func + if self.match_object['type'] == "2F1": + sol = get_sol_2F1_hypergeometric(eq, func, self.match_object) + if sol is None: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the hypergeometric method") + + return [sol] + + +class NthLinearConstantCoeffHomogeneous(SingleODESolver): + r""" + Solves an `n`\th order linear homogeneous differential equation with + constant coefficients. + + This is an equation of the form + + .. math:: a_n f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + + a_0 f(x) = 0\text{.} + + These equations can be solved in a general manner, by taking the roots of + the characteristic equation `a_n m^n + a_{n-1} m^{n-1} + \cdots + a_1 m + + a_0 = 0`. The solution will then be the sum of `C_n x^i e^{r x}` terms, + for each where `C_n` is an arbitrary constant, `r` is a root of the + characteristic equation and `i` is one of each from 0 to the multiplicity + of the root - 1 (for example, a root 3 of multiplicity 2 would create the + terms `C_1 e^{3 x} + C_2 x e^{3 x}`). The exponential is usually expanded + for complex roots using Euler's equation `e^{I x} = \cos(x) + I \sin(x)`. + Complex roots always come in conjugate pairs in polynomials with real + coefficients, so the two roots will be represented (after simplifying the + constants) as `e^{a x} \left(C_1 \cos(b x) + C_2 \sin(b x)\right)`. + + If SymPy cannot find exact roots to the characteristic equation, a + :py:class:`~sympy.polys.rootoftools.ComplexRootOf` instance will be return + instead. + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(f(x).diff(x, 5) + 10*f(x).diff(x) - 2*f(x), f(x), + ... hint='nth_linear_constant_coeff_homogeneous') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), C5*exp(x*CRootOf(_x**5 + 10*_x - 2, 0)) + + (C1*sin(x*im(CRootOf(_x**5 + 10*_x - 2, 1))) + + C2*cos(x*im(CRootOf(_x**5 + 10*_x - 2, 1))))*exp(x*re(CRootOf(_x**5 + 10*_x - 2, 1))) + + (C3*sin(x*im(CRootOf(_x**5 + 10*_x - 2, 3))) + + C4*cos(x*im(CRootOf(_x**5 + 10*_x - 2, 3))))*exp(x*re(CRootOf(_x**5 + 10*_x - 2, 3)))) + + Note that because this method does not involve integration, there is no + ``nth_linear_constant_coeff_homogeneous_Integral`` hint. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 4) + 2*f(x).diff(x, 3) - + ... 2*f(x).diff(x, 2) - 6*f(x).diff(x) + 5*f(x), f(x), + ... hint='nth_linear_constant_coeff_homogeneous')) + x -2*x + f(x) = (C1 + C2*x)*e + (C3*sin(x) + C4*cos(x))*e + + References + ========== + + - https://en.wikipedia.org/wiki/Linear_differential_equation section: + Nonhomogeneous_equation_with_constant_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 211 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_homogeneous" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if not self.r[-1]: + return True + else: + return False + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + fx = self.ode_problem.func + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, fx, order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + gsol_rhs = Add(*[i*j for (i, j) in zip(constants, roots)]) + gsol = Eq(fx, gsol_rhs) + if simplify_flag: + gsol = _get_simplified_sol([gsol], fx, collectterms) + + return [gsol] + + +class NthLinearConstantCoeffVariationOfParameters(SingleODESolver): + r""" + Solves an `n`\th order linear differential equation with constant + coefficients using the method of variation of parameters. + + This method works on any differential equations of the form + + .. math:: f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + a_0 + f(x) = P(x)\text{.} + + This method works by assuming that the particular solution takes the form + + .. math:: \sum_{x=1}^{n} c_i(x) y_i(x)\text{,} + + where `y_i` is the `i`\th solution to the homogeneous equation. The + solution is then solved using Wronskian's and Cramer's Rule. The + particular solution is given by + + .. math:: \sum_{x=1}^n \left( \int \frac{W_i(x)}{W(x)} \,dx + \right) y_i(x) \text{,} + + where `W(x)` is the Wronskian of the fundamental system (the system of `n` + linearly independent solutions to the homogeneous equation), and `W_i(x)` + is the Wronskian of the fundamental system with the `i`\th column replaced + with `[0, 0, \cdots, 0, P(x)]`. + + This method is general enough to solve any `n`\th order inhomogeneous + linear differential equation with constant coefficients, but sometimes + SymPy cannot simplify the Wronskian well enough to integrate it. If this + method hangs, try using the + ``nth_linear_constant_coeff_variation_of_parameters_Integral`` hint and + simplifying the integrals manually. Also, prefer using + ``nth_linear_constant_coeff_undetermined_coefficients`` when it + applies, because it does not use integration, making it faster and more + reliable. + + Warning, using simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters' in + :py:meth:`~sympy.solvers.ode.dsolve` may cause it to hang, because it will + not attempt to simplify the Wronskian before integrating. It is + recommended that you only use simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters_Integral' for this + method, especially if the solution to the homogeneous equation has + trigonometric functions in it. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint, exp, log + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 3) - 3*f(x).diff(x, 2) + + ... 3*f(x).diff(x) - f(x) - exp(x)*log(x), f(x), + ... hint='nth_linear_constant_coeff_variation_of_parameters')) + / / / x*log(x) 11*x\\\ x + f(x) = |C1 + x*|C2 + x*|C3 + -------- - ----|||*e + \ \ \ 6 36 /// + + References + ========== + + - https://en.wikipedia.org/wiki/Variation_of_parameters + - https://planetmath.org/VariationOfParameters + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 233 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_variation_of_parameters" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if self.r[-1]: + return True + else: + return False + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, f(x), order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + homogen_sol_rhs = Add(*[i*j for (i, j) in zip(constants, roots)]) + homogen_sol = Eq(f(x), homogen_sol_rhs) + homogen_sol = _solve_variation_of_parameters(eq, f(x), roots, homogen_sol, order, self.r, simplify_flag) + if simplify_flag: + homogen_sol = _get_simplified_sol([homogen_sol], f(x), collectterms) + return [homogen_sol] + + +class NthLinearConstantCoeffUndeterminedCoefficients(SingleODESolver): + r""" + Solves an `n`\th order linear differential equation with constant + coefficients using the method of undetermined coefficients. + + This method works on differential equations of the form + + .. math:: a_n f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + + a_0 f(x) = P(x)\text{,} + + where `P(x)` is a function that has a finite number of linearly + independent derivatives. + + Functions that fit this requirement are finite sums functions of the form + `a x^i e^{b x} \sin(c x + d)` or `a x^i e^{b x} \cos(c x + d)`, where `i` + is a non-negative integer and `a`, `b`, `c`, and `d` are constants. For + example any polynomial in `x`, functions like `x^2 e^{2 x}`, `x \sin(x)`, + and `e^x \cos(x)` can all be used. Products of `\sin`'s and `\cos`'s have + a finite number of derivatives, because they can be expanded into `\sin(a + x)` and `\cos(b x)` terms. However, SymPy currently cannot do that + expansion, so you will need to manually rewrite the expression in terms of + the above to use this method. So, for example, you will need to manually + convert `\sin^2(x)` into `(1 + \cos(2 x))/2` to properly apply the method + of undetermined coefficients on it. + + This method works by creating a trial function from the expression and all + of its linear independent derivatives and substituting them into the + original ODE. The coefficients for each term will be a system of linear + equations, which are be solved for and substituted, giving the solution. + If any of the trial functions are linearly dependent on the solution to + the homogeneous equation, they are multiplied by sufficient `x` to make + them linearly independent. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint, exp, cos + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 2) + 2*f(x).diff(x) + f(x) - + ... 4*exp(-x)*x**2 + cos(2*x), f(x), + ... hint='nth_linear_constant_coeff_undetermined_coefficients')) + / / 3\\ + | | x || -x 4*sin(2*x) 3*cos(2*x) + f(x) = |C1 + x*|C2 + --||*e - ---------- + ---------- + \ \ 3 // 25 25 + + References + ========== + + - https://en.wikipedia.org/wiki/Method_of_undetermined_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 221 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_undetermined_coefficients" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + does_match = False + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if self.r[-1]: + eq_homogeneous = Add(eq, -self.r[-1]) + undetcoeff = _undetermined_coefficients_match(self.r[-1], x, func, eq_homogeneous) + if undetcoeff['test']: + self.trialset = undetcoeff['trialset'] + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, f(x), order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + homogen_sol_rhs = Add(*[i*j for (i, j) in zip(constants, roots)]) + homogen_sol = Eq(f(x), homogen_sol_rhs) + self.r.update({'list': roots, 'sol': homogen_sol, 'simpliy_flag': simplify_flag}) + gsol = _solve_undetermined_coefficients(eq, f(x), order, self.r, self.trialset) + if simplify_flag: + gsol = _get_simplified_sol([gsol], f(x), collectterms) + return [gsol] + + +class NthLinearEulerEqHomogeneous(SingleODESolver): + r""" + Solves an `n`\th order linear homogeneous variable-coefficient + Cauchy-Euler equidimensional ordinary differential equation. + + This is an equation with form `0 = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + These equations can be solved in a general manner, by substituting + solutions of the form `f(x) = x^r`, and deriving a characteristic equation + for `r`. When there are repeated roots, we include extra terms of the + form `C_{r k} \ln^k(x) x^r`, where `C_{r k}` is an arbitrary integration + constant, `r` is a root of the characteristic equation, and `k` ranges + over the multiplicity of `r`. In the cases where the roots are complex, + solutions of the form `C_1 x^a \sin(b \log(x)) + C_2 x^a \cos(b \log(x))` + are returned, based on expansions with Euler's formula. The general + solution is the sum of the terms found. If SymPy cannot find exact roots + to the characteristic equation, a + :py:obj:`~.ComplexRootOf` instance will be returned + instead. + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(4*x**2*f(x).diff(x, 2) + f(x), f(x), + ... hint='nth_linear_euler_eq_homogeneous') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), sqrt(x)*(C1 + C2*log(x))) + + Note that because this method does not involve integration, there is no + ``nth_linear_euler_eq_homogeneous_Integral`` hint. + + The following is for internal use: + + - ``returns = 'sol'`` returns the solution to the ODE. + - ``returns = 'list'`` returns a list of linearly independent solutions, + corresponding to the fundamental solution set, for use with non + homogeneous solution methods like variation of parameters and + undetermined coefficients. Note that, though the solutions should be + linearly independent, this function does not explicitly check that. You + can do ``assert simplify(wronskian(sollist)) != 0`` to check for linear + independence. Also, ``assert len(sollist) == order`` will need to pass. + - ``returns = 'both'``, return a dictionary ``{'sol': , + 'list': }``. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = f(x).diff(x, 2)*x**2 - 4*f(x).diff(x)*x + 6*f(x) + >>> pprint(dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_homogeneous')) + 2 + f(x) = x *(C1 + C2*x) + + References + ========== + + - https://en.wikipedia.org/wiki/Cauchy%E2%80%93Euler_equation + - C. Bender & S. Orszag, "Advanced Mathematical Methods for Scientists and + Engineers", Springer 1999, pp. 12 + + # indirect doctest + + """ + hint = "nth_linear_euler_eq_homogeneous" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if not self.r[-1]: + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + fx = self.ode_problem.func + eq = self.ode_problem.eq + homogen_sol = _get_euler_characteristic_eq_sols(eq, fx, self.r)[0] + return [homogen_sol] + + +class NthLinearEulerEqNonhomogeneousVariationOfParameters(SingleODESolver): + r""" + Solves an `n`\th order linear non homogeneous Cauchy-Euler equidimensional + ordinary differential equation using variation of parameters. + + This is an equation with form `g(x) = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + This method works by assuming that the particular solution takes the form + + .. math:: \sum_{x=1}^{n} c_i(x) y_i(x) {a_n} {x^n} \text{, } + + where `y_i` is the `i`\th solution to the homogeneous equation. The + solution is then solved using Wronskian's and Cramer's Rule. The + particular solution is given by multiplying eq given below with `a_n x^{n}` + + .. math:: \sum_{x=1}^n \left( \int \frac{W_i(x)}{W(x)} \, dx + \right) y_i(x) \text{, } + + where `W(x)` is the Wronskian of the fundamental system (the system of `n` + linearly independent solutions to the homogeneous equation), and `W_i(x)` + is the Wronskian of the fundamental system with the `i`\th column replaced + with `[0, 0, \cdots, 0, \frac{x^{- n}}{a_n} g{\left(x \right)}]`. + + This method is general enough to solve any `n`\th order inhomogeneous + linear differential equation, but sometimes SymPy cannot simplify the + Wronskian well enough to integrate it. If this method hangs, try using the + ``nth_linear_constant_coeff_variation_of_parameters_Integral`` hint and + simplifying the integrals manually. Also, prefer using + ``nth_linear_constant_coeff_undetermined_coefficients`` when it + applies, because it does not use integration, making it faster and more + reliable. + + Warning, using simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters' in + :py:meth:`~sympy.solvers.ode.dsolve` may cause it to hang, because it will + not attempt to simplify the Wronskian before integrating. It is + recommended that you only use simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters_Integral' for this + method, especially if the solution to the homogeneous equation has + trigonometric functions in it. + + Examples + ======== + + >>> from sympy import Function, dsolve, Derivative + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - x**4 + >>> dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_nonhomogeneous_variation_of_parameters').expand() + Eq(f(x), C1*x + C2*x**2 + x**4/6) + + """ + hint = "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if self.r[-1]: + does_match = True + + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + homogen_sol, roots = _get_euler_characteristic_eq_sols(eq, f(x), self.r) + self.r[-1] = self.r[-1]/self.r[order] + sol = _solve_variation_of_parameters(eq, f(x), roots, homogen_sol, order, self.r, simplify_flag) + + return [Eq(f(x), homogen_sol.rhs + (sol.rhs - homogen_sol.rhs)*self.r[order])] + + +class NthLinearEulerEqNonhomogeneousUndeterminedCoefficients(SingleODESolver): + r""" + Solves an `n`\th order linear non homogeneous Cauchy-Euler equidimensional + ordinary differential equation using undetermined coefficients. + + This is an equation with form `g(x) = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + These equations can be solved in a general manner, by substituting + solutions of the form `x = exp(t)`, and deriving a characteristic equation + of form `g(exp(t)) = b_0 f(t) + b_1 f'(t) + b_2 f''(t) \cdots` which can + be then solved by nth_linear_constant_coeff_undetermined_coefficients if + g(exp(t)) has finite number of linearly independent derivatives. + + Functions that fit this requirement are finite sums functions of the form + `a x^i e^{b x} \sin(c x + d)` or `a x^i e^{b x} \cos(c x + d)`, where `i` + is a non-negative integer and `a`, `b`, `c`, and `d` are constants. For + example any polynomial in `x`, functions like `x^2 e^{2 x}`, `x \sin(x)`, + and `e^x \cos(x)` can all be used. Products of `\sin`'s and `\cos`'s have + a finite number of derivatives, because they can be expanded into `\sin(a + x)` and `\cos(b x)` terms. However, SymPy currently cannot do that + expansion, so you will need to manually rewrite the expression in terms of + the above to use this method. So, for example, you will need to manually + convert `\sin^2(x)` into `(1 + \cos(2 x))/2` to properly apply the method + of undetermined coefficients on it. + + After replacement of x by exp(t), this method works by creating a trial function + from the expression and all of its linear independent derivatives and + substituting them into the original ODE. The coefficients for each term + will be a system of linear equations, which are be solved for and + substituted, giving the solution. If any of the trial functions are linearly + dependent on the solution to the homogeneous equation, they are multiplied + by sufficient `x` to make them linearly independent. + + Examples + ======== + + >>> from sympy import dsolve, Function, Derivative, log + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - log(x) + >>> dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients').expand() + Eq(f(x), C1*x + C2*x**2 + log(x)/2 + 3/4) + + """ + hint = "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if self.r[-1]: + e, re = posify(self.r[-1].subs(x, exp(x))) + undetcoeff = _undetermined_coefficients_match(e.subs(re), x) + if undetcoeff['test']: + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + chareq, eq, symbol = S.Zero, S.Zero, Dummy('x') + for i in self.r.keys(): + if i >= 0: + chareq += (self.r[i]*diff(x**symbol, x, i)*x**-symbol).expand() + + for i in range(1, degree(Poly(chareq, symbol))+1): + eq += chareq.coeff(symbol**i)*diff(f(x), x, i) + + if chareq.as_coeff_add(symbol)[0]: + eq += chareq.as_coeff_add(symbol)[0]*f(x) + e, re = posify(self.r[-1].subs(x, exp(x))) + eq += e.subs(re) + + self.const_undet_instance = NthLinearConstantCoeffUndeterminedCoefficients(SingleODEProblem(eq, f(x), x)) + sol = self.const_undet_instance.get_general_solution(simplify = simplify_flag)[0] + sol = sol.subs(x, log(x)) # type: ignore + sol = sol.subs(f(log(x)), f(x)).expand() # type: ignore + + return [sol] + + +class SecondLinearBessel(SingleODESolver): + r""" + Gives solution of the Bessel differential equation + + .. math :: x^2 \frac{d^2y}{dx^2} + x \frac{dy}{dx} y(x) + (x^2-n^2) y(x) + + if `n` is integer then the solution is of the form ``Eq(f(x), C0 besselj(n,x) + + C1 bessely(n,x))`` as both the solutions are linearly independent else if + `n` is a fraction then the solution is of the form ``Eq(f(x), C0 besselj(n,x) + + C1 besselj(-n,x))`` which can also transform into ``Eq(f(x), C0 besselj(n,x) + + C1 bessely(n,x))``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Symbol + >>> v = Symbol('v', positive=True) + >>> from sympy import dsolve, Function + >>> f = Function('f') + >>> y = f(x) + >>> genform = x**2*y.diff(x, 2) + x*y.diff(x) + (x**2 - v**2)*y + >>> dsolve(genform) + Eq(f(x), C1*besselj(v, x) + C2*bessely(v, x)) + + References + ========== + + https://math24.net/bessel-differential-equation.html + + """ + hint = "2nd_linear_bessel" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f.diff(x) + a = Wild('a', exclude=[f,df]) + b = Wild('b', exclude=[x, f,df]) + a4 = Wild('a4', exclude=[x,f,df]) + b4 = Wild('b4', exclude=[x,f,df]) + c4 = Wild('c4', exclude=[x,f,df]) + d4 = Wild('d4', exclude=[x,f,df]) + a3 = Wild('a3', exclude=[f, df, f.diff(x, 2)]) + b3 = Wild('b3', exclude=[f, df, f.diff(x, 2)]) + c3 = Wild('c3', exclude=[f, df, f.diff(x, 2)]) + deq = a3*(f.diff(x, 2)) + b3*df + c3*f + r = collect(eq, + [f.diff(x, 2), df, f]).match(deq) + if order == 2 and r: + if not all(r[key].is_polynomial() for key in r): + n, d = eq.as_numer_denom() + eq = expand(n) + r = collect(eq, + [f.diff(x, 2), df, f]).match(deq) + + if r and r[a3] != 0: + # leading coeff of f(x).diff(x, 2) + coeff = factor(r[a3]).match(a4*(x-b)**b4) + + if coeff: + # if coeff[b4] = 0 means constant coefficient + if coeff[b4] == 0: + return False + point = coeff[b] + else: + return False + + if point: + r[a3] = simplify(r[a3].subs(x, x+point)) + r[b3] = simplify(r[b3].subs(x, x+point)) + r[c3] = simplify(r[c3].subs(x, x+point)) + + # making a3 in the form of x**2 + r[a3] = cancel(r[a3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + r[b3] = cancel(r[b3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + r[c3] = cancel(r[c3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + # checking if b3 is of form c*(x-b) + coeff1 = factor(r[b3]).match(a4*(x)) + if coeff1 is None: + return False + # c3 maybe of very complex form so I am simply checking (a - b) form + # if yes later I will match with the standard form of bessel in a and b + # a, b are wild variable defined above. + _coeff2 = expand(r[c3]).match(a - b) + if _coeff2 is None: + return False + # matching with standard form for c3 + coeff2 = factor(_coeff2[a]).match(c4**2*(x)**(2*a4)) + if coeff2 is None: + return False + + if _coeff2[b] == 0: + coeff2[d4] = 0 + else: + coeff2[d4] = factor(_coeff2[b]).match(d4**2)[d4] + + self.rn = {'n':coeff2[d4], 'a4':coeff2[c4], 'd4':coeff2[a4]} + self.rn['c4'] = coeff1[a4] + self.rn['b4'] = point + return True + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + n = self.rn['n'] + a4 = self.rn['a4'] + c4 = self.rn['c4'] + d4 = self.rn['d4'] + b4 = self.rn['b4'] + n = sqrt(n**2 + Rational(1, 4)*(c4 - 1)**2) + (C1, C2) = self.ode_problem.get_numbered_constants(num=2) + return [Eq(f(x), ((x**(Rational(1-c4,2)))*(C1*besselj(n/d4,a4*x**d4/d4) + + C2*bessely(n/d4,a4*x**d4/d4))).subs(x, x-b4))] + + +class SecondLinearAiry(SingleODESolver): + r""" + Gives solution of the Airy differential equation + + .. math :: \frac{d^2y}{dx^2} + (a + b x) y(x) = 0 + + in terms of Airy special functions airyai and airybi. + + Examples + ======== + + >>> from sympy import dsolve, Function + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = f(x).diff(x, 2) - x*f(x) + >>> dsolve(eq) + Eq(f(x), C1*airyai(x) + C2*airybi(x)) + """ + hint = "2nd_linear_airy" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f.diff(x) + a4 = Wild('a4', exclude=[x,f,df]) + b4 = Wild('b4', exclude=[x,f,df]) + match = self.ode_problem.get_linear_coefficients(eq, f, order) + does_match = False + if order == 2 and match and match[2] != 0: + if match[1].is_zero: + self.rn = cancel(match[0]/match[2]).match(a4+b4*x) + if self.rn and self.rn[b4] != 0: + self.rn = {'b':self.rn[a4],'m':self.rn[b4]} + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + (C1, C2) = self.ode_problem.get_numbered_constants(num=2) + b = self.rn['b'] + m = self.rn['m'] + if m.is_positive: + arg = - b/cbrt(m)**2 - cbrt(m)*x + elif m.is_negative: + arg = - b/cbrt(-m)**2 + cbrt(-m)*x + else: + arg = - b/cbrt(-m)**2 + cbrt(-m)*x + + return [Eq(f(x), C1*airyai(arg) + C2*airybi(arg))] + + +class LieGroup(SingleODESolver): + r""" + This hint implements the Lie group method of solving first order differential + equations. The aim is to convert the given differential equation from the + given coordinate system into another coordinate system where it becomes + invariant under the one-parameter Lie group of translations. The converted + ODE can be easily solved by quadrature. It makes use of the + :py:meth:`sympy.solvers.ode.infinitesimals` function which returns the + infinitesimals of the transformation. + + The coordinates `r` and `s` can be found by solving the following Partial + Differential Equations. + + .. math :: \xi\frac{\partial r}{\partial x} + \eta\frac{\partial r}{\partial y} + = 0 + + .. math :: \xi\frac{\partial s}{\partial x} + \eta\frac{\partial s}{\partial y} + = 1 + + The differential equation becomes separable in the new coordinate system + + .. math :: \frac{ds}{dr} = \frac{\frac{\partial s}{\partial x} + + h(x, y)\frac{\partial s}{\partial y}}{ + \frac{\partial r}{\partial x} + h(x, y)\frac{\partial r}{\partial y}} + + After finding the solution by integration, it is then converted back to the original + coordinate system by substituting `r` and `s` in terms of `x` and `y` again. + + Examples + ======== + + >>> from sympy import Function, dsolve, exp, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x) + 2*x*f(x) - x*exp(-x**2), f(x), + ... hint='lie_group')) + / 2\ 2 + | x | -x + f(x) = |C1 + --|*e + \ 2 / + + + References + ========== + + - Solving differential equations by Symmetry Groups, + John Starrett, pp. 1 - pp. 14 + + """ + hint = "lie_group" + has_integral = False + + def _has_additional_params(self): + return 'xi' in self.ode_problem.params and 'eta' in self.ode_problem.params + + def _matches(self): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f(x).diff(x) + y = Dummy('y') + d = Wild('d', exclude=[df, f(x).diff(x, 2)]) + e = Wild('e', exclude=[df]) + does_match = False + if self._has_additional_params() and order == 1: + xi = self.ode_problem.params['xi'] + eta = self.ode_problem.params['eta'] + self.r3 = {'xi': xi, 'eta': eta} + r = collect(eq, df, exact=True).match(d + e * df) + if r: + r['d'] = d + r['e'] = e + r['y'] = y + r[d] = r[d].subs(f(x), y) + r[e] = r[e].subs(f(x), y) + self.r3.update(r) + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + x = self.ode_problem.sym + func = self.ode_problem.func + order = self.ode_problem.order + df = func.diff(x) + + try: + eqsol = solve(eq, df) + except NotImplementedError: + eqsol = [] + + desols = [] + for s in eqsol: + sol = _ode_lie_group(s, func, order, match=self.r3) + if sol: + desols.extend(sol) + + if desols == []: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the lie group method") + return desols + + +solver_map = { + 'factorable': Factorable, + 'nth_linear_constant_coeff_homogeneous': NthLinearConstantCoeffHomogeneous, + 'nth_linear_euler_eq_homogeneous': NthLinearEulerEqHomogeneous, + 'nth_linear_constant_coeff_undetermined_coefficients': NthLinearConstantCoeffUndeterminedCoefficients, + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients': NthLinearEulerEqNonhomogeneousUndeterminedCoefficients, + 'separable': Separable, + '1st_exact': FirstExact, + '1st_linear': FirstLinear, + 'Bernoulli': Bernoulli, + 'Riccati_special_minus2': RiccatiSpecial, + '1st_rational_riccati': RationalRiccati, + '1st_homogeneous_coeff_best': HomogeneousCoeffBest, + '1st_homogeneous_coeff_subs_indep_div_dep': HomogeneousCoeffSubsIndepDivDep, + '1st_homogeneous_coeff_subs_dep_div_indep': HomogeneousCoeffSubsDepDivIndep, + 'almost_linear': AlmostLinear, + 'linear_coefficients': LinearCoefficients, + 'separable_reduced': SeparableReduced, + 'nth_linear_constant_coeff_variation_of_parameters': NthLinearConstantCoeffVariationOfParameters, + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters': NthLinearEulerEqNonhomogeneousVariationOfParameters, + 'Liouville': Liouville, + '2nd_linear_airy': SecondLinearAiry, + '2nd_linear_bessel': SecondLinearBessel, + '2nd_hypergeometric': SecondHypergeometric, + 'nth_order_reducible': NthOrderReducible, + '2nd_nonlinear_autonomous_conserved': SecondNonlinearAutonomousConserved, + 'nth_algebraic': NthAlgebraic, + 'lie_group': LieGroup, + } + +# Avoid circular import: +from .ode import dsolve, ode_sol_simplicity, odesimp, homogeneous_order diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/subscheck.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/subscheck.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac7fba7d364bf599e928ccf591b5bef096576d0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/subscheck.py @@ -0,0 +1,392 @@ +from sympy.core import S, Pow +from sympy.core.function import (Derivative, AppliedUndef, diff) +from sympy.core.relational import Equality, Eq +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify + +from sympy.logic.boolalg import BooleanAtom +from sympy.functions import exp +from sympy.series import Order +from sympy.simplify.simplify import simplify, posify, besselsimp +from sympy.simplify.trigsimp import trigsimp +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.solvers import solve +from sympy.solvers.deutils import _preprocess, ode_order +from sympy.utilities.iterables import iterable, is_sequence + + +def sub_func_doit(eq, func, new): + r""" + When replacing the func with something else, we usually want the + derivative evaluated, so this function helps in making that happen. + + Examples + ======== + + >>> from sympy import Derivative, symbols, Function + >>> from sympy.solvers.ode.subscheck import sub_func_doit + >>> x, z = symbols('x, z') + >>> y = Function('y') + + >>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x) + 2 + + >>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x), + ... 1/(x*(z + 1/x))) + x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x)) + ...- 1/(x**2*(z + 1/x)**2) + """ + reps= {func: new} + for d in eq.atoms(Derivative): + if d.expr == func: + reps[d] = new.diff(*d.variable_count) + else: + reps[d] = d.xreplace({func: new}).doit(deep=False) + return eq.xreplace(reps) + + +def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True): + r""" + Substitutes ``sol`` into ``ode`` and checks that the result is ``0``. + + This works when ``func`` is one function, like `f(x)` or a list of + functions like `[f(x), g(x)]` when `ode` is a system of ODEs. ``sol`` can + be a single solution or a list of solutions. Each solution may be an + :py:class:`~sympy.core.relational.Equality` that the solution satisfies, + e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an + :py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it + will not be necessary to explicitly identify the function, but if the + function cannot be inferred from the original equation it can be supplied + through the ``func`` argument. + + If a sequence of solutions is passed, the same sort of container will be + used to return the result for each solution. + + It tries the following methods, in order, until it finds zero equivalence: + + 1. Substitute the solution for `f` in the original equation. This only + works if ``ode`` is solved for `f`. It will attempt to solve it first + unless ``solve_for_func == False``. + 2. Take `n` derivatives of the solution, where `n` is the order of + ``ode``, and check to see if that is equal to the solution. This only + works on exact ODEs. + 3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time + solving for the derivative of `f` of that order (this will always be + possible because `f` is a linear operator). Then back substitute each + derivative into ``ode`` in reverse order. + + This function returns a tuple. The first item in the tuple is ``True`` if + the substitution results in ``0``, and ``False`` otherwise. The second + item in the tuple is what the substitution results in. It should always + be ``0`` if the first item is ``True``. Sometimes this function will + return ``False`` even when an expression is identically equal to ``0``. + This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not + reduce the expression to ``0``. If an expression returned by this + function vanishes identically, then ``sol`` really is a solution to + the ``ode``. + + If this function seems to hang, it is probably because of a hard + simplification. + + To use this function to test, test the first item of the tuple. + + Examples + ======== + + >>> from sympy import (Eq, Function, checkodesol, symbols, + ... Derivative, exp) + >>> x, C1, C2 = symbols('x,C1,C2') + >>> f, g = symbols('f g', cls=Function) + >>> checkodesol(f(x).diff(x), Eq(f(x), C1)) + (True, 0) + >>> assert checkodesol(f(x).diff(x), C1)[0] + >>> assert not checkodesol(f(x).diff(x), x)[0] + >>> checkodesol(f(x).diff(x, 2), x**2) + (False, 2) + + >>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))] + >>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + """ + if iterable(ode): + return checksysodesol(ode, sol, func=func) + + if not isinstance(ode, Equality): + ode = Eq(ode, 0) + if func is None: + try: + _, func = _preprocess(ode.lhs) + except ValueError: + funcs = [s.atoms(AppliedUndef) for s in ( + sol if is_sequence(sol, set) else [sol])] + funcs = set().union(*funcs) + if len(funcs) != 1: + raise ValueError( + 'must pass func arg to checkodesol for this case.') + func = funcs.pop() + if not isinstance(func, AppliedUndef) or len(func.args) != 1: + raise ValueError( + "func must be a function of one variable, not %s" % func) + if is_sequence(sol, set): + return type(sol)([checkodesol(ode, i, order=order, solve_for_func=solve_for_func) for i in sol]) + + if not isinstance(sol, Equality): + sol = Eq(func, sol) + elif sol.rhs == func: + sol = sol.reversed + + if order == 'auto': + order = ode_order(ode, func) + solved = sol.lhs == func and not sol.rhs.has(func) + if solve_for_func and not solved: + rhs = solve(sol, func) + if rhs: + eqs = [Eq(func, t) for t in rhs] + if len(rhs) == 1: + eqs = eqs[0] + return checkodesol(ode, eqs, order=order, + solve_for_func=False) + + x = func.args[0] + + # Handle series solutions here + if sol.has(Order): + assert sol.lhs == func + Oterm = sol.rhs.getO() + solrhs = sol.rhs.removeO() + + Oexpr = Oterm.expr + assert isinstance(Oexpr, Pow) + sorder = Oexpr.exp + assert Oterm == Order(x**sorder) + + odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand() + + neworder = Order(x**(sorder - order)) + odesubs = odesubs + neworder + assert odesubs.getO() == neworder + residual = odesubs.removeO() + + return (residual == 0, residual) + + s = True + testnum = 0 + while s: + if testnum == 0: + # First pass, try substituting a solved solution directly into the + # ODE. This has the highest chance of succeeding. + ode_diff = ode.lhs - ode.rhs + + if sol.lhs == func: + s = sub_func_doit(ode_diff, func, sol.rhs) + s = besselsimp(s) + else: + testnum += 1 + continue + ss = simplify(s.rewrite(exp)) + if ss: + # with the new numer_denom in power.py, if we do a simple + # expansion then testnum == 0 verifies all solutions. + s = ss.expand(force=True) + else: + s = 0 + testnum += 1 + elif testnum == 1: + # Second pass. If we cannot substitute f, try seeing if the nth + # derivative is equal, this will only work for odes that are exact, + # by definition. + s = simplify( + trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) - + trigsimp(ode.lhs) + trigsimp(ode.rhs)) + # s2 = simplify( + # diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \ + # ode.lhs + ode.rhs) + testnum += 1 + elif testnum == 2: + # Third pass. Try solving for df/dx and substituting that into the + # ODE. Thanks to Chris Smith for suggesting this method. Many of + # the comments below are his, too. + # The method: + # - Take each of 1..n derivatives of the solution. + # - Solve each nth derivative for d^(n)f/dx^(n) + # (the differential of that order) + # - Back substitute into the ODE in decreasing order + # (i.e., n, n-1, ...) + # - Check the result for zero equivalence + if sol.lhs == func and not sol.rhs.has(func): + diffsols = {0: sol.rhs} + elif sol.rhs == func and not sol.lhs.has(func): + diffsols = {0: sol.lhs} + else: + diffsols = {} + sol = sol.lhs - sol.rhs + for i in range(1, order + 1): + # Differentiation is a linear operator, so there should always + # be 1 solution. Nonetheless, we test just to make sure. + # We only need to solve once. After that, we automatically + # have the solution to the differential in the order we want. + if i == 1: + ds = sol.diff(x) + try: + sdf = solve(ds, func.diff(x, i)) + if not sdf: + raise NotImplementedError + except NotImplementedError: + testnum += 1 + break + else: + diffsols[i] = sdf[0] + else: + # This is what the solution says df/dx should be. + diffsols[i] = diffsols[i - 1].diff(x) + + # Make sure the above didn't fail. + if testnum > 2: + continue + else: + # Substitute it into ODE to check for self consistency. + lhs, rhs = ode.lhs, ode.rhs + for i in range(order, -1, -1): + if i == 0 and 0 not in diffsols: + # We can only substitute f(x) if the solution was + # solved for f(x). + break + lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i]) + rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i]) + ode_or_bool = Eq(lhs, rhs) + ode_or_bool = simplify(ode_or_bool) + + if isinstance(ode_or_bool, (bool, BooleanAtom)): + if ode_or_bool: + lhs = rhs = S.Zero + else: + lhs = ode_or_bool.lhs + rhs = ode_or_bool.rhs + # No sense in overworking simplify -- just prove that the + # numerator goes to zero + num = trigsimp((lhs - rhs).as_numer_denom()[0]) + # since solutions are obtained using force=True we test + # using the same level of assumptions + ## replace function with dummy so assumptions will work + _func = Dummy('func') + num = num.subs(func, _func) + ## posify the expression + num, reps = posify(num) + s = simplify(num).xreplace(reps).xreplace({_func: func}) + testnum += 1 + else: + break + + if not s: + return (True, s) + elif s is True: # The code above never was able to change s + raise NotImplementedError("Unable to test if " + str(sol) + + " is a solution to " + str(ode) + ".") + else: + return (False, s) + + +def checksysodesol(eqs, sols, func=None): + r""" + Substitutes corresponding ``sols`` for each functions into each ``eqs`` and + checks that the result of substitutions for each equation is ``0``. The + equations and solutions passed can be any iterable. + + This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`. + For each function, ``sols`` can have a single solution or a list of solutions. + In most cases it will not be necessary to explicitly identify the function, + but if the function cannot be inferred from the original equation it + can be supplied through the ``func`` argument. + + When a sequence of equations is passed, the same sequence is used to return + the result for each equation with each function substituted with corresponding + solutions. + + It tries the following method to find zero equivalence for each equation: + + Substitute the solutions for functions, like `x(t)` and `y(t)` into the + original equations containing those functions. + This function returns a tuple. The first item in the tuple is ``True`` if + the substitution results for each equation is ``0``, and ``False`` otherwise. + The second item in the tuple is what the substitution results in. Each element + of the ``list`` should always be ``0`` corresponding to each equation if the + first item is ``True``. Note that sometimes this function may return ``False``, + but with an expression that is identically equal to ``0``, instead of returning + ``True``. This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot + reduce the expression to ``0``. If an expression returned by each function + vanishes identically, then ``sols`` really is a solution to ``eqs``. + + If this function seems to hang, it is probably because of a difficult simplification. + + Examples + ======== + + >>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function + >>> from sympy.solvers.ode.subscheck import checksysodesol + >>> C1, C2 = symbols('C1:3') + >>> t = symbols('t') + >>> x, y = symbols('x, y', cls=Function) + >>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12)) + >>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3), + ... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)] + >>> checksysodesol(eq, sol) + (True, [0, 0]) + >>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3)) + >>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2), + ... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)] + >>> checksysodesol(eq, sol) + (True, [0, 0]) + + """ + def _sympify(eq): + return list(map(sympify, eq if iterable(eq) else [eq])) + eqs = _sympify(eqs) + for i in range(len(eqs)): + if isinstance(eqs[i], Equality): + eqs[i] = eqs[i].lhs - eqs[i].rhs + if func is None: + funcs = [] + for eq in eqs: + derivs = eq.atoms(Derivative) + func = set().union(*[d.atoms(AppliedUndef) for d in derivs]) + funcs.extend(func) + funcs = list(set(funcs)) + if not all(isinstance(func, AppliedUndef) and len(func.args) == 1 for func in funcs)\ + and len({func.args for func in funcs})!=1: + raise ValueError("func must be a function of one variable, not %s" % func) + for sol in sols: + if len(sol.atoms(AppliedUndef)) != 1: + raise ValueError("solutions should have one function only") + if len(funcs) != len({sol.lhs for sol in sols}): + raise ValueError("number of solutions provided does not match the number of equations") + dictsol = {} + for sol in sols: + func = list(sol.atoms(AppliedUndef))[0] + if sol.rhs == func: + sol = sol.reversed + solved = sol.lhs == func and not sol.rhs.has(func) + if not solved: + rhs = solve(sol, func) + if not rhs: + raise NotImplementedError + else: + rhs = sol.rhs + dictsol[func] = rhs + checkeq = [] + for eq in eqs: + for func in funcs: + eq = sub_func_doit(eq, func, dictsol[func]) + ss = simplify(eq) + if ss != 0: + eq = ss.expand(force=True) + if eq != 0: + eq = sqrtdenest(eq).simplify() + else: + eq = 0 + checkeq.append(eq) + if len(set(checkeq)) == 1 and list(set(checkeq))[0] == 0: + return (True, checkeq) + else: + return (False, checkeq) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/systems.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/systems.py new file mode 100644 index 0000000000000000000000000000000000000000..2d2c9b57a969c7fb5c67c06ce952fa398e22a48d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/systems.py @@ -0,0 +1,2135 @@ +from sympy.core import Add, Mul, S +from sympy.core.containers import Tuple +from sympy.core.exprtools import factor_terms +from sympy.core.numbers import I +from sympy.core.relational import Eq, Equality +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Dummy, Symbol +from sympy.core.function import (expand_mul, expand, Derivative, + AppliedUndef, Function, Subs) +from sympy.functions import (exp, im, cos, sin, re, Piecewise, + piecewise_fold, sqrt, log) +from sympy.functions.combinatorial.factorials import factorial +from sympy.matrices import zeros, Matrix, NonSquareMatrixError, MatrixBase, eye +from sympy.polys import Poly, together +from sympy.simplify import collect, radsimp, signsimp # type: ignore +from sympy.simplify.powsimp import powdenest, powsimp +from sympy.simplify.ratsimp import ratsimp +from sympy.simplify.simplify import simplify +from sympy.sets.sets import FiniteSet +from sympy.solvers.deutils import ode_order +from sympy.solvers.solveset import NonlinearError, solveset +from sympy.utilities.iterables import (connected_components, iterable, + strongly_connected_components) +from sympy.utilities.misc import filldedent +from sympy.integrals.integrals import Integral, integrate + + +def _get_func_order(eqs, funcs): + return {func: max(ode_order(eq, func) for eq in eqs) for func in funcs} + + +class ODEOrderError(ValueError): + """Raised by linear_ode_to_matrix if the system has the wrong order""" + pass + + +class ODENonlinearError(NonlinearError): + """Raised by linear_ode_to_matrix if the system is nonlinear""" + pass + + +def _simpsol(soleq): + lhs = soleq.lhs + sol = soleq.rhs + sol = powsimp(sol) + gens = list(sol.atoms(exp)) + p = Poly(sol, *gens, expand=False) + gens = [factor_terms(g) for g in gens] + if not gens: + gens = p.gens + syms = [Symbol('C1'), Symbol('C2')] + terms = [] + for coeff, monom in zip(p.coeffs(), p.monoms()): + coeff = piecewise_fold(coeff) + if isinstance(coeff, Piecewise): + coeff = Piecewise(*((ratsimp(coef).collect(syms), cond) for coef, cond in coeff.args)) + else: + coeff = ratsimp(coeff).collect(syms) + monom = Mul(*(g ** i for g, i in zip(gens, monom))) + terms.append(coeff * monom) + return Eq(lhs, Add(*terms)) + + +def _solsimp(e, t): + no_t, has_t = powsimp(expand_mul(e)).as_independent(t) + + no_t = ratsimp(no_t) + has_t = has_t.replace(exp, lambda a: exp(factor_terms(a))) + + return no_t + has_t + + +def simpsol(sol, wrt1, wrt2, doit=True): + """Simplify solutions from dsolve_system.""" + + # The parameter sol is the solution as returned by dsolve (list of Eq). + # + # The parameters wrt1 and wrt2 are lists of symbols to be collected for + # with those in wrt1 being collected for first. This allows for collecting + # on any factors involving the independent variable before collecting on + # the integration constants or vice versa using e.g.: + # + # sol = simpsol(sol, [t], [C1, C2]) # t first, constants after + # sol = simpsol(sol, [C1, C2], [t]) # constants first, t after + # + # If doit=True (default) then simpsol will begin by evaluating any + # unevaluated integrals. Since many integrals will appear multiple times + # in the solutions this is done intelligently by computing each integral + # only once. + # + # The strategy is to first perform simple cancellation with factor_terms + # and then multiply out all brackets with expand_mul. This gives an Add + # with many terms. + # + # We split each term into two multiplicative factors dep and coeff where + # all factors that involve wrt1 are in dep and any constant factors are in + # coeff e.g. + # sqrt(2)*C1*exp(t) -> ( exp(t), sqrt(2)*C1 ) + # + # The dep factors are simplified using powsimp to combine expanded + # exponential factors e.g. + # exp(a*t)*exp(b*t) -> exp(t*(a+b)) + # + # We then collect coefficients for all terms having the same (simplified) + # dep. The coefficients are then simplified using together and ratsimp and + # lastly by recursively applying the same transformation to the + # coefficients to collect on wrt2. + # + # Finally the result is recombined into an Add and signsimp is used to + # normalise any minus signs. + + def simprhs(rhs, rep, wrt1, wrt2): + """Simplify the rhs of an ODE solution""" + if rep: + rhs = rhs.subs(rep) + rhs = factor_terms(rhs) + rhs = simp_coeff_dep(rhs, wrt1, wrt2) + rhs = signsimp(rhs) + return rhs + + def simp_coeff_dep(expr, wrt1, wrt2=None): + """Split rhs into terms, split terms into dep and coeff and collect on dep""" + add_dep_terms = lambda e: e.is_Add and e.has(*wrt1) + expandable = lambda e: e.is_Mul and any(map(add_dep_terms, e.args)) + expand_func = lambda e: expand_mul(e, deep=False) + expand_mul_mod = lambda e: e.replace(expandable, expand_func) + terms = Add.make_args(expand_mul_mod(expr)) + dc = {} + for term in terms: + coeff, dep = term.as_independent(*wrt1, as_Add=False) + # Collect together the coefficients for terms that have the same + # dependence on wrt1 (after dep is normalised using simpdep). + dep = simpdep(dep, wrt1) + + # See if the dependence on t cancels out... + if dep is not S.One: + dep2 = factor_terms(dep) + if not dep2.has(*wrt1): + coeff *= dep2 + dep = S.One + + if dep not in dc: + dc[dep] = coeff + else: + dc[dep] += coeff + # Apply the method recursively to the coefficients but this time + # collecting on wrt2 rather than wrt2. + termpairs = ((simpcoeff(c, wrt2), d) for d, c in dc.items()) + if wrt2 is not None: + termpairs = ((simp_coeff_dep(c, wrt2), d) for c, d in termpairs) + return Add(*(c * d for c, d in termpairs)) + + def simpdep(term, wrt1): + """Normalise factors involving t with powsimp and recombine exp""" + def canonicalise(a): + # Using factor_terms here isn't quite right because it leads to things + # like exp(t*(1+t)) that we don't want. We do want to cancel factors + # and pull out a common denominator but ideally the numerator would be + # expressed as a standard form polynomial in t so we expand_mul + # and collect afterwards. + a = factor_terms(a) + num, den = a.as_numer_denom() + num = expand_mul(num) + num = collect(num, wrt1) + return num / den + + term = powsimp(term) + rep = {e: exp(canonicalise(e.args[0])) for e in term.atoms(exp)} + term = term.subs(rep) + return term + + def simpcoeff(coeff, wrt2): + """Bring to a common fraction and cancel with ratsimp""" + coeff = together(coeff) + if coeff.is_polynomial(): + # Calling ratsimp can be expensive. The main reason is to simplify + # sums of terms with irrational denominators so we limit ourselves + # to the case where the expression is polynomial in any symbols. + # Maybe there's a better approach... + coeff = ratsimp(radsimp(coeff)) + # collect on secondary variables first and any remaining symbols after + if wrt2 is not None: + syms = list(wrt2) + list(ordered(coeff.free_symbols - set(wrt2))) + else: + syms = list(ordered(coeff.free_symbols)) + coeff = collect(coeff, syms) + coeff = together(coeff) + return coeff + + # There are often repeated integrals. Collect unique integrals and + # evaluate each once and then substitute into the final result to replace + # all occurrences in each of the solution equations. + if doit: + integrals = set().union(*(s.atoms(Integral) for s in sol)) + rep = {i: factor_terms(i).doit() for i in integrals} + else: + rep = {} + + sol = [Eq(s.lhs, simprhs(s.rhs, rep, wrt1, wrt2)) for s in sol] + return sol + + +def linodesolve_type(A, t, b=None): + r""" + Helper function that determines the type of the system of ODEs for solving with :obj:`sympy.solvers.ode.systems.linodesolve()` + + Explanation + =========== + + This function takes in the coefficient matrix and/or the non-homogeneous term + and returns the type of the equation that can be solved by :obj:`sympy.solvers.ode.systems.linodesolve()`. + + If the system is constant coefficient homogeneous, then "type1" is returned + + If the system is constant coefficient non-homogeneous, then "type2" is returned + + If the system is non-constant coefficient homogeneous, then "type3" is returned + + If the system is non-constant coefficient non-homogeneous, then "type4" is returned + + If the system has a non-constant coefficient matrix which can be factorized into constant + coefficient matrix, then "type5" or "type6" is returned for when the system is homogeneous or + non-homogeneous respectively. + + Note that, if the system of ODEs is of "type3" or "type4", then along with the type, + the commutative antiderivative of the coefficient matrix is also returned. + + If the system cannot be solved by :obj:`sympy.solvers.ode.systems.linodesolve()`, then + NotImplementedError is raised. + + Parameters + ========== + + A : Matrix + Coefficient matrix of the system of ODEs + b : Matrix or None + Non-homogeneous term of the system. The default value is None. + If this argument is None, then the system is assumed to be homogeneous. + + Examples + ======== + + >>> from sympy import symbols, Matrix + >>> from sympy.solvers.ode.systems import linodesolve_type + >>> t = symbols("t") + >>> A = Matrix([[1, 1], [2, 3]]) + >>> b = Matrix([t, 1]) + + >>> linodesolve_type(A, t) + {'antiderivative': None, 'type_of_equation': 'type1'} + + >>> linodesolve_type(A, t, b=b) + {'antiderivative': None, 'type_of_equation': 'type2'} + + >>> A_t = Matrix([[1, t], [-t, 1]]) + + >>> linodesolve_type(A_t, t) + {'antiderivative': Matrix([ + [ t, t**2/2], + [-t**2/2, t]]), 'type_of_equation': 'type3'} + + >>> linodesolve_type(A_t, t, b=b) + {'antiderivative': Matrix([ + [ t, t**2/2], + [-t**2/2, t]]), 'type_of_equation': 'type4'} + + >>> A_non_commutative = Matrix([[1, t], [t, -1]]) + >>> linodesolve_type(A_non_commutative, t) + Traceback (most recent call last): + ... + NotImplementedError: + The system does not have a commutative antiderivative, it cannot be + solved by linodesolve. + + Returns + ======= + + Dict + + Raises + ====== + + NotImplementedError + When the coefficient matrix does not have a commutative antiderivative + + See Also + ======== + + linodesolve: Function for which linodesolve_type gets the information + + """ + + match = {} + is_non_constant = not _matrix_is_constant(A, t) + is_non_homogeneous = not (b is None or b.is_zero_matrix) + type = "type{}".format(int("{}{}".format(int(is_non_constant), int(is_non_homogeneous)), 2) + 1) + + B = None + match.update({"type_of_equation": type, "antiderivative": B}) + + if is_non_constant: + B, is_commuting = _is_commutative_anti_derivative(A, t) + if not is_commuting: + raise NotImplementedError(filldedent(''' + The system does not have a commutative antiderivative, it cannot be solved + by linodesolve. + ''')) + + match['antiderivative'] = B + match.update(_first_order_type5_6_subs(A, t, b=b)) + + return match + + +def _first_order_type5_6_subs(A, t, b=None): + match = {} + + factor_terms = _factor_matrix(A, t) + is_homogeneous = b is None or b.is_zero_matrix + + if factor_terms is not None: + t_ = Symbol("{}_".format(t)) + F_t = integrate(factor_terms[0], t) + inverse = solveset(Eq(t_, F_t), t) + + # Note: A simple way to check if a function is invertible + # or not. + if isinstance(inverse, FiniteSet) and not inverse.has(Piecewise)\ + and len(inverse) == 1: + + A = factor_terms[1] + if not is_homogeneous: + b = b / factor_terms[0] + b = b.subs(t, list(inverse)[0]) + type = "type{}".format(5 + (not is_homogeneous)) + match.update({'func_coeff': A, 'tau': F_t, + 't_': t_, 'type_of_equation': type, 'rhs': b}) + + return match + + +def linear_ode_to_matrix(eqs, funcs, t, order): + r""" + Convert a linear system of ODEs to matrix form + + Explanation + =========== + + Express a system of linear ordinary differential equations as a single + matrix differential equation [1]. For example the system $x' = x + y + 1$ + and $y' = x - y$ can be represented as + + .. math:: A_1 X' = A_0 X + b + + where $A_1$ and $A_0$ are $2 \times 2$ matrices and $b$, $X$ and $X'$ are + $2 \times 1$ matrices with $X = [x, y]^T$. + + Higher-order systems are represented with additional matrices e.g. a + second-order system would look like + + .. math:: A_2 X'' = A_1 X' + A_0 X + b + + Examples + ======== + + >>> from sympy import Function, Symbol, Matrix, Eq + >>> from sympy.solvers.ode.systems import linear_ode_to_matrix + >>> t = Symbol('t') + >>> x = Function('x') + >>> y = Function('y') + + We can create a system of linear ODEs like + + >>> eqs = [ + ... Eq(x(t).diff(t), x(t) + y(t) + 1), + ... Eq(y(t).diff(t), x(t) - y(t)), + ... ] + >>> funcs = [x(t), y(t)] + >>> order = 1 # 1st order system + + Now ``linear_ode_to_matrix`` can represent this as a matrix + differential equation. + + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, order) + >>> A1 + Matrix([ + [1, 0], + [0, 1]]) + >>> A0 + Matrix([ + [1, 1], + [1, -1]]) + >>> b + Matrix([ + [1], + [0]]) + + The original equations can be recovered from these matrices: + + >>> eqs_mat = Matrix([eq.lhs - eq.rhs for eq in eqs]) + >>> X = Matrix(funcs) + >>> A1 * X.diff(t) - A0 * X - b == eqs_mat + True + + If the system of equations has a maximum order greater than the + order of the system specified, a ODEOrderError exception is raised. + + >>> eqs = [Eq(x(t).diff(t, 2), x(t).diff(t) + x(t)), Eq(y(t).diff(t), y(t) + x(t))] + >>> linear_ode_to_matrix(eqs, funcs, t, 1) + Traceback (most recent call last): + ... + ODEOrderError: Cannot represent system in 1-order form + + If the system of equations is nonlinear, then ODENonlinearError is + raised. + + >>> eqs = [Eq(x(t).diff(t), x(t) + y(t)), Eq(y(t).diff(t), y(t)**2 + x(t))] + >>> linear_ode_to_matrix(eqs, funcs, t, 1) + Traceback (most recent call last): + ... + ODENonlinearError: The system of ODEs is nonlinear. + + Parameters + ========== + + eqs : list of SymPy expressions or equalities + The equations as expressions (assumed equal to zero). + funcs : list of applied functions + The dependent variables of the system of ODEs. + t : symbol + The independent variable. + order : int + The order of the system of ODEs. + + Returns + ======= + + The tuple ``(As, b)`` where ``As`` is a tuple of matrices and ``b`` is the + the matrix representing the rhs of the matrix equation. + + Raises + ====== + + ODEOrderError + When the system of ODEs have an order greater than what was specified + ODENonlinearError + When the system of ODEs is nonlinear + + See Also + ======== + + linear_eq_to_matrix: for systems of linear algebraic equations. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_differential_equation + + """ + from sympy.solvers.solveset import linear_eq_to_matrix + + if any(ode_order(eq, func) > order for eq in eqs for func in funcs): + msg = "Cannot represent system in {}-order form" + raise ODEOrderError(msg.format(order)) + + As = [] + + for o in range(order, -1, -1): + # Work from the highest derivative down + syms = [func.diff(t, o) for func in funcs] + + # Ai is the matrix for X(t).diff(t, o) + # eqs is minus the remainder of the equations. + try: + Ai, b = linear_eq_to_matrix(eqs, syms) + except NonlinearError: + raise ODENonlinearError("The system of ODEs is nonlinear.") + + Ai = Ai.applyfunc(expand_mul) + + As.append(Ai if o == order else -Ai) + + if o: + eqs = [-eq for eq in b] + else: + rhs = b + + return As, rhs + + +def matrix_exp(A, t): + r""" + Matrix exponential $\exp(A*t)$ for the matrix ``A`` and scalar ``t``. + + Explanation + =========== + + This functions returns the $\exp(A*t)$ by doing a simple + matrix multiplication: + + .. math:: \exp(A*t) = P * expJ * P^{-1} + + where $expJ$ is $\exp(J*t)$. $J$ is the Jordan normal + form of $A$ and $P$ is matrix such that: + + .. math:: A = P * J * P^{-1} + + The matrix exponential $\exp(A*t)$ appears in the solution of linear + differential equations. For example if $x$ is a vector and $A$ is a matrix + then the initial value problem + + .. math:: \frac{dx(t)}{dt} = A \times x(t), x(0) = x0 + + has the unique solution + + .. math:: x(t) = \exp(A t) x0 + + Examples + ======== + + >>> from sympy import Symbol, Matrix, pprint + >>> from sympy.solvers.ode.systems import matrix_exp + >>> t = Symbol('t') + + We will consider a 2x2 matrix for comupting the exponential + + >>> A = Matrix([[2, -5], [2, -4]]) + >>> pprint(A) + [2 -5] + [ ] + [2 -4] + + Now, exp(A*t) is given as follows: + + >>> pprint(matrix_exp(A, t)) + [ -t -t -t ] + [3*e *sin(t) + e *cos(t) -5*e *sin(t) ] + [ ] + [ -t -t -t ] + [ 2*e *sin(t) - 3*e *sin(t) + e *cos(t)] + + Parameters + ========== + + A : Matrix + The matrix $A$ in the expression $\exp(A*t)$ + t : Symbol + The independent variable + + See Also + ======== + + matrix_exp_jordan_form: For exponential of Jordan normal form + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jordan_normal_form + .. [2] https://en.wikipedia.org/wiki/Matrix_exponential + + """ + P, expJ = matrix_exp_jordan_form(A, t) + return P * expJ * P.inv() + + +def matrix_exp_jordan_form(A, t): + r""" + Matrix exponential $\exp(A*t)$ for the matrix *A* and scalar *t*. + + Explanation + =========== + + Returns the Jordan form of the $\exp(A*t)$ along with the matrix $P$ such that: + + .. math:: + \exp(A*t) = P * expJ * P^{-1} + + Examples + ======== + + >>> from sympy import Matrix, Symbol + >>> from sympy.solvers.ode.systems import matrix_exp, matrix_exp_jordan_form + >>> t = Symbol('t') + + We will consider a 2x2 defective matrix. This shows that our method + works even for defective matrices. + + >>> A = Matrix([[1, 1], [0, 1]]) + + It can be observed that this function gives us the Jordan normal form + and the required invertible matrix P. + + >>> P, expJ = matrix_exp_jordan_form(A, t) + + Here, it is shown that P and expJ returned by this function is correct + as they satisfy the formula: P * expJ * P_inverse = exp(A*t). + + >>> P * expJ * P.inv() == matrix_exp(A, t) + True + + Parameters + ========== + + A : Matrix + The matrix $A$ in the expression $\exp(A*t)$ + t : Symbol + The independent variable + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Defective_matrix + .. [2] https://en.wikipedia.org/wiki/Jordan_matrix + .. [3] https://en.wikipedia.org/wiki/Jordan_normal_form + + """ + + N, M = A.shape + if N != M: + raise ValueError('Needed square matrix but got shape (%s, %s)' % (N, M)) + elif A.has(t): + raise ValueError('Matrix A should not depend on t') + + def jordan_chains(A): + '''Chains from Jordan normal form analogous to M.eigenvects(). + Returns a dict with eignevalues as keys like: + {e1: [[v111,v112,...], [v121, v122,...]], e2:...} + where vijk is the kth vector in the jth chain for eigenvalue i. + ''' + P, blocks = A.jordan_cells() + basis = [P[:,i] for i in range(P.shape[1])] + n = 0 + chains = {} + for b in blocks: + eigval = b[0, 0] + size = b.shape[0] + if eigval not in chains: + chains[eigval] = [] + chains[eigval].append(basis[n:n+size]) + n += size + return chains + + eigenchains = jordan_chains(A) + + # Needed for consistency across Python versions + eigenchains_iter = sorted(eigenchains.items(), key=default_sort_key) + isreal = not A.has(I) + + blocks = [] + vectors = [] + seen_conjugate = set() + for e, chains in eigenchains_iter: + for chain in chains: + n = len(chain) + if isreal and e != e.conjugate() and e.conjugate() in eigenchains: + if e in seen_conjugate: + continue + seen_conjugate.add(e.conjugate()) + exprt = exp(re(e) * t) + imrt = im(e) * t + imblock = Matrix([[cos(imrt), sin(imrt)], + [-sin(imrt), cos(imrt)]]) + expJblock2 = Matrix(n, n, lambda i,j: + imblock * t**(j-i) / factorial(j-i) if j >= i + else zeros(2, 2)) + expJblock = Matrix(2*n, 2*n, lambda i,j: expJblock2[i//2,j//2][i%2,j%2]) + + blocks.append(exprt * expJblock) + for i in range(n): + vectors.append(re(chain[i])) + vectors.append(im(chain[i])) + else: + vectors.extend(chain) + fun = lambda i,j: t**(j-i)/factorial(j-i) if j >= i else 0 + expJblock = Matrix(n, n, fun) + blocks.append(exp(e * t) * expJblock) + + expJ = Matrix.diag(*blocks) + P = Matrix(N, N, lambda i,j: vectors[j][i]) + + return P, expJ + + +# Note: To add a docstring example with tau +def linodesolve(A, t, b=None, B=None, type="auto", doit=False, + tau=None): + r""" + System of n equations linear first-order differential equations + + Explanation + =========== + + This solver solves the system of ODEs of the following form: + + .. math:: + X'(t) = A(t) X(t) + b(t) + + Here, $A(t)$ is the coefficient matrix, $X(t)$ is the vector of n independent variables, + $b(t)$ is the non-homogeneous term and $X'(t)$ is the derivative of $X(t)$ + + Depending on the properties of $A(t)$ and $b(t)$, this solver evaluates the solution + differently. + + When $A(t)$ is constant coefficient matrix and $b(t)$ is zero vector i.e. system is homogeneous, + the system is "type1". The solution is: + + .. math:: + X(t) = \exp(A t) C + + Here, $C$ is a vector of constants and $A$ is the constant coefficient matrix. + + When $A(t)$ is constant coefficient matrix and $b(t)$ is non-zero i.e. system is non-homogeneous, + the system is "type2". The solution is: + + .. math:: + X(t) = e^{A t} ( \int e^{- A t} b \,dt + C) + + When $A(t)$ is coefficient matrix such that its commutative with its antiderivative $B(t)$ and + $b(t)$ is a zero vector i.e. system is homogeneous, the system is "type3". The solution is: + + .. math:: + X(t) = \exp(B(t)) C + + When $A(t)$ is commutative with its antiderivative $B(t)$ and $b(t)$ is non-zero i.e. system is + non-homogeneous, the system is "type4". The solution is: + + .. math:: + X(t) = e^{B(t)} ( \int e^{-B(t)} b(t) \,dt + C) + + When $A(t)$ is a coefficient matrix such that it can be factorized into a scalar and a constant + coefficient matrix: + + .. math:: + A(t) = f(t) * A + + Where $f(t)$ is a scalar expression in the independent variable $t$ and $A$ is a constant matrix, + then we can do the following substitutions: + + .. math:: + tau = \int f(t) dt, X(t) = Y(tau), b(t) = b(f^{-1}(tau)) + + Here, the substitution for the non-homogeneous term is done only when its non-zero. + Using these substitutions, our original system becomes: + + .. math:: + Y'(tau) = A * Y(tau) + b(tau)/f(tau) + + The above system can be easily solved using the solution for "type1" or "type2" depending + on the homogeneity of the system. After we get the solution for $Y(tau)$, we substitute the + solution for $tau$ as $t$ to get back $X(t)$ + + .. math:: + X(t) = Y(tau) + + Systems of "type5" and "type6" have a commutative antiderivative but we use this solution + because its faster to compute. + + The final solution is the general solution for all the four equations since a constant coefficient + matrix is always commutative with its antidervative. + + An additional feature of this function is, if someone wants to substitute for value of the independent + variable, they can pass the substitution `tau` and the solution will have the independent variable + substituted with the passed expression(`tau`). + + Parameters + ========== + + A : Matrix + Coefficient matrix of the system of linear first order ODEs. + t : Symbol + Independent variable in the system of ODEs. + b : Matrix or None + Non-homogeneous term in the system of ODEs. If None is passed, + a homogeneous system of ODEs is assumed. + B : Matrix or None + Antiderivative of the coefficient matrix. If the antiderivative + is not passed and the solution requires the term, then the solver + would compute it internally. + type : String + Type of the system of ODEs passed. Depending on the type, the + solution is evaluated. The type values allowed and the corresponding + system it solves are: "type1" for constant coefficient homogeneous + "type2" for constant coefficient non-homogeneous, "type3" for non-constant + coefficient homogeneous, "type4" for non-constant coefficient non-homogeneous, + "type5" and "type6" for non-constant coefficient homogeneous and non-homogeneous + systems respectively where the coefficient matrix can be factorized to a constant + coefficient matrix. + The default value is "auto" which will let the solver decide the correct type of + the system passed. + doit : Boolean + Evaluate the solution if True, default value is False + tau: Expression + Used to substitute for the value of `t` after we get the solution of the system. + + Examples + ======== + + To solve the system of ODEs using this function directly, several things must be + done in the right order. Wrong inputs to the function will lead to incorrect results. + + >>> from sympy import symbols, Function, Eq + >>> from sympy.solvers.ode.systems import canonical_odes, linear_ode_to_matrix, linodesolve, linodesolve_type + >>> from sympy.solvers.ode.subscheck import checkodesol + >>> f, g = symbols("f, g", cls=Function) + >>> x, a = symbols("x, a") + >>> funcs = [f(x), g(x)] + >>> eqs = [Eq(f(x).diff(x) - f(x), a*g(x) + 1), Eq(g(x).diff(x) + g(x), a*f(x))] + + Here, it is important to note that before we derive the coefficient matrix, it is + important to get the system of ODEs into the desired form. For that we will use + :obj:`sympy.solvers.ode.systems.canonical_odes()`. + + >>> eqs = canonical_odes(eqs, funcs, x) + >>> eqs + [[Eq(Derivative(f(x), x), a*g(x) + f(x) + 1), Eq(Derivative(g(x), x), a*f(x) - g(x))]] + + Now, we will use :obj:`sympy.solvers.ode.systems.linear_ode_to_matrix()` to get the coefficient matrix and the + non-homogeneous term if it is there. + + >>> eqs = eqs[0] + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, x, 1) + >>> A = A0 + + We have the coefficient matrices and the non-homogeneous term ready. Now, we can use + :obj:`sympy.solvers.ode.systems.linodesolve_type()` to get the information for the system of ODEs + to finally pass it to the solver. + + >>> system_info = linodesolve_type(A, x, b=b) + >>> sol_vector = linodesolve(A, x, b=b, B=system_info['antiderivative'], type=system_info['type_of_equation']) + + Now, we can prove if the solution is correct or not by using :obj:`sympy.solvers.ode.checkodesol()` + + >>> sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + We can also use the doit method to evaluate the solutions passed by the function. + + >>> sol_vector_evaluated = linodesolve(A, x, b=b, type="type2", doit=True) + + Now, we will look at a system of ODEs which is non-constant. + + >>> eqs = [Eq(f(x).diff(x), f(x) + x*g(x)), Eq(g(x).diff(x), -x*f(x) + g(x))] + + The system defined above is already in the desired form, so we do not have to convert it. + + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, x, 1) + >>> A = A0 + + A user can also pass the commutative antiderivative required for type3 and type4 system of ODEs. + Passing an incorrect one will lead to incorrect results. If the coefficient matrix is not commutative + with its antiderivative, then :obj:`sympy.solvers.ode.systems.linodesolve_type()` raises a NotImplementedError. + If it does have a commutative antiderivative, then the function just returns the information about the system. + + >>> system_info = linodesolve_type(A, x, b=b) + + Now, we can pass the antiderivative as an argument to get the solution. If the system information is not + passed, then the solver will compute the required arguments internally. + + >>> sol_vector = linodesolve(A, x, b=b) + + Once again, we can verify the solution obtained. + + >>> sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + Returns + ======= + + List + + Raises + ====== + + ValueError + This error is raised when the coefficient matrix, non-homogeneous term + or the antiderivative, if passed, are not a matrix or + do not have correct dimensions + NonSquareMatrixError + When the coefficient matrix or its antiderivative, if passed is not a + square matrix + NotImplementedError + If the coefficient matrix does not have a commutative antiderivative + + See Also + ======== + + linear_ode_to_matrix: Coefficient matrix computation function + canonical_odes: System of ODEs representation change + linodesolve_type: Getting information about systems of ODEs to pass in this solver + + """ + + if not isinstance(A, MatrixBase): + raise ValueError(filldedent('''\ + The coefficients of the system of ODEs should be of type Matrix + ''')) + + if not A.is_square: + raise NonSquareMatrixError(filldedent('''\ + The coefficient matrix must be a square + ''')) + + if b is not None: + if not isinstance(b, MatrixBase): + raise ValueError(filldedent('''\ + The non-homogeneous terms of the system of ODEs should be of type Matrix + ''')) + + if A.rows != b.rows: + raise ValueError(filldedent('''\ + The system of ODEs should have the same number of non-homogeneous terms and the number of + equations + ''')) + + if B is not None: + if not isinstance(B, MatrixBase): + raise ValueError(filldedent('''\ + The antiderivative of coefficients of the system of ODEs should be of type Matrix + ''')) + + if not B.is_square: + raise NonSquareMatrixError(filldedent('''\ + The antiderivative of the coefficient matrix must be a square + ''')) + + if A.rows != B.rows: + raise ValueError(filldedent('''\ + The coefficient matrix and its antiderivative should have same dimensions + ''')) + + if not any(type == "type{}".format(i) for i in range(1, 7)) and not type == "auto": + raise ValueError(filldedent('''\ + The input type should be a valid one + ''')) + + n = A.rows + + # constants = numbered_symbols(prefix='C', cls=Dummy, start=const_idx+1) + Cvect = Matrix([Dummy() for _ in range(n)]) + + if b is None and any(type == typ for typ in ["type2", "type4", "type6"]): + b = zeros(n, 1) + + is_transformed = tau is not None + passed_type = type + + if type == "auto": + system_info = linodesolve_type(A, t, b=b) + type = system_info["type_of_equation"] + B = system_info["antiderivative"] + + if type in ("type5", "type6"): + is_transformed = True + if passed_type != "auto": + if tau is None: + system_info = _first_order_type5_6_subs(A, t, b=b) + if not system_info: + raise ValueError(filldedent(''' + The system passed isn't {}. + '''.format(type))) + + tau = system_info['tau'] + t = system_info['t_'] + A = system_info['A'] + b = system_info['b'] + + intx_wrtt = lambda x: Integral(x, t) if x else 0 + if type in ("type1", "type2", "type5", "type6"): + P, J = matrix_exp_jordan_form(A, t) + P = simplify(P) + + if type in ("type1", "type5"): + sol_vector = P * (J * Cvect) + else: + Jinv = J.subs(t, -t) + sol_vector = P * J * ((Jinv * P.inv() * b).applyfunc(intx_wrtt) + Cvect) + else: + if B is None: + B, _ = _is_commutative_anti_derivative(A, t) + + if type == "type3": + sol_vector = B.exp() * Cvect + else: + sol_vector = B.exp() * (((-B).exp() * b).applyfunc(intx_wrtt) + Cvect) + + if is_transformed: + sol_vector = sol_vector.subs(t, tau) + + gens = sol_vector.atoms(exp) + + if type != "type1": + sol_vector = [expand_mul(s) for s in sol_vector] + + sol_vector = [collect(s, ordered(gens), exact=True) for s in sol_vector] + + if doit: + sol_vector = [s.doit() for s in sol_vector] + + return sol_vector + + +def _matrix_is_constant(M, t): + """Checks if the matrix M is independent of t or not.""" + return all(coef.as_independent(t, as_Add=True)[1] == 0 for coef in M) + + +def canonical_odes(eqs, funcs, t): + r""" + Function that solves for highest order derivatives in a system + + Explanation + =========== + + This function inputs a system of ODEs and based on the system, + the dependent variables and their highest order, returns the system + in the following form: + + .. math:: + X'(t) = A(t) X(t) + b(t) + + Here, $X(t)$ is the vector of dependent variables of lower order, $A(t)$ is + the coefficient matrix, $b(t)$ is the non-homogeneous term and $X'(t)$ is the + vector of dependent variables in their respective highest order. We use the term + canonical form to imply the system of ODEs which is of the above form. + + If the system passed has a non-linear term with multiple solutions, then a list of + systems is returned in its canonical form. + + Parameters + ========== + + eqs : List + List of the ODEs + funcs : List + List of dependent variables + t : Symbol + Independent variable + + Examples + ======== + + >>> from sympy import symbols, Function, Eq, Derivative + >>> from sympy.solvers.ode.systems import canonical_odes + >>> f, g = symbols("f g", cls=Function) + >>> x, y = symbols("x y") + >>> funcs = [f(x), g(x)] + >>> eqs = [Eq(f(x).diff(x) - 7*f(x), 12*g(x)), Eq(g(x).diff(x) + g(x), 20*f(x))] + + >>> canonical_eqs = canonical_odes(eqs, funcs, x) + >>> canonical_eqs + [[Eq(Derivative(f(x), x), 7*f(x) + 12*g(x)), Eq(Derivative(g(x), x), 20*f(x) - g(x))]] + + >>> system = [Eq(Derivative(f(x), x)**2 - 2*Derivative(f(x), x) + 1, 4), Eq(-y*f(x) + Derivative(g(x), x), 0)] + + >>> canonical_system = canonical_odes(system, funcs, x) + >>> canonical_system + [[Eq(Derivative(f(x), x), -1), Eq(Derivative(g(x), x), y*f(x))], [Eq(Derivative(f(x), x), 3), Eq(Derivative(g(x), x), y*f(x))]] + + Returns + ======= + + List + + """ + from sympy.solvers.solvers import solve + + order = _get_func_order(eqs, funcs) + + canon_eqs = solve(eqs, *[func.diff(t, order[func]) for func in funcs], dict=True) + + systems = [] + for eq in canon_eqs: + system = [Eq(func.diff(t, order[func]), eq[func.diff(t, order[func])]) for func in funcs] + systems.append(system) + + return systems + + +def _is_commutative_anti_derivative(A, t): + r""" + Helper function for determining if the Matrix passed is commutative with its antiderivative + + Explanation + =========== + + This function checks if the Matrix $A$ passed is commutative with its antiderivative with respect + to the independent variable $t$. + + .. math:: + B(t) = \int A(t) dt + + The function outputs two values, first one being the antiderivative $B(t)$, second one being a + boolean value, if True, then the matrix $A(t)$ passed is commutative with $B(t)$, else the matrix + passed isn't commutative with $B(t)$. + + Parameters + ========== + + A : Matrix + The matrix which has to be checked + t : Symbol + Independent variable + + Examples + ======== + + >>> from sympy import symbols, Matrix + >>> from sympy.solvers.ode.systems import _is_commutative_anti_derivative + >>> t = symbols("t") + >>> A = Matrix([[1, t], [-t, 1]]) + + >>> B, is_commuting = _is_commutative_anti_derivative(A, t) + >>> is_commuting + True + + Returns + ======= + + Matrix, Boolean + + """ + B = integrate(A, t) + is_commuting = (B*A - A*B).applyfunc(expand).applyfunc(factor_terms).is_zero_matrix + + is_commuting = False if is_commuting is None else is_commuting + + return B, is_commuting + + +def _factor_matrix(A, t): + term = None + for element in A: + temp_term = element.as_independent(t)[1] + if temp_term.has(t): + term = temp_term + break + + if term is not None: + A_factored = (A/term).applyfunc(ratsimp) + can_factor = _matrix_is_constant(A_factored, t) + term = (term, A_factored) if can_factor else None + + return term + + +def _is_second_order_type2(A, t): + term = _factor_matrix(A, t) + is_type2 = False + + if term is not None: + term = 1/term[0] + is_type2 = term.is_polynomial() + + if is_type2: + poly = Poly(term.expand(), t) + monoms = poly.monoms() + + if monoms[0][0] in (2, 4): + cs = _get_poly_coeffs(poly, 4) + a, b, c, d, e = cs + + a1 = powdenest(sqrt(a), force=True) + c1 = powdenest(sqrt(e), force=True) + b1 = powdenest(sqrt(c - 2*a1*c1), force=True) + + is_type2 = (b == 2*a1*b1) and (d == 2*b1*c1) + term = a1*t**2 + b1*t + c1 + + else: + is_type2 = False + + return is_type2, term + + +def _get_poly_coeffs(poly, order): + cs = [0 for _ in range(order+1)] + for c, m in zip(poly.coeffs(), poly.monoms()): + cs[-1-m[0]] = c + return cs + + +def _match_second_order_type(A1, A0, t, b=None): + r""" + Works only for second order system in its canonical form. + + Type 0: Constant coefficient matrix, can be simply solved by + introducing dummy variables. + Type 1: When the substitution: $U = t*X' - X$ works for reducing + the second order system to first order system. + Type 2: When the system is of the form: $poly * X'' = A*X$ where + $poly$ is square of a quadratic polynomial with respect to + *t* and $A$ is a constant coefficient matrix. + + """ + match = {"type_of_equation": "type0"} + n = A1.shape[0] + + if _matrix_is_constant(A1, t) and _matrix_is_constant(A0, t): + return match + + if (A1 + A0*t).applyfunc(expand_mul).is_zero_matrix: + match.update({"type_of_equation": "type1", "A1": A1}) + + elif A1.is_zero_matrix and (b is None or b.is_zero_matrix): + is_type2, term = _is_second_order_type2(A0, t) + if is_type2: + a, b, c = _get_poly_coeffs(Poly(term, t), 2) + A = (A0*(term**2).expand()).applyfunc(ratsimp) + (b**2/4 - a*c)*eye(n, n) + tau = integrate(1/term, t) + t_ = Symbol("{}_".format(t)) + match.update({"type_of_equation": "type2", "A0": A, + "g(t)": sqrt(term), "tau": tau, "is_transformed": True, + "t_": t_}) + + return match + + +def _second_order_subs_type1(A, b, funcs, t): + r""" + For a linear, second order system of ODEs, a particular substitution. + + A system of the below form can be reduced to a linear first order system of + ODEs: + .. math:: + X'' = A(t) * (t*X' - X) + b(t) + + By substituting: + .. math:: U = t*X' - X + + To get the system: + .. math:: U' = t*(A(t)*U + b(t)) + + Where $U$ is the vector of dependent variables, $X$ is the vector of dependent + variables in `funcs` and $X'$ is the first order derivative of $X$ with respect to + $t$. It may or may not reduce the system into linear first order system of ODEs. + + Then a check is made to determine if the system passed can be reduced or not, if + this substitution works, then the system is reduced and its solved for the new + substitution. After we get the solution for $U$: + + .. math:: U = a(t) + + We substitute and return the reduced system: + + .. math:: + a(t) = t*X' - X + + Parameters + ========== + + A: Matrix + Coefficient matrix($A(t)*t$) of the second order system of this form. + b: Matrix + Non-homogeneous term($b(t)$) of the system of ODEs. + funcs: List + List of dependent variables + t: Symbol + Independent variable of the system of ODEs. + + Returns + ======= + + List + + """ + + U = Matrix([t*func.diff(t) - func for func in funcs]) + + sol = linodesolve(A, t, t*b) + reduced_eqs = [Eq(u, s) for s, u in zip(sol, U)] + reduced_eqs = canonical_odes(reduced_eqs, funcs, t)[0] + + return reduced_eqs + + +def _second_order_subs_type2(A, funcs, t_): + r""" + Returns a second order system based on the coefficient matrix passed. + + Explanation + =========== + + This function returns a system of second order ODE of the following form: + + .. math:: + X'' = A * X + + Here, $X$ is the vector of dependent variables, but a bit modified, $A$ is the + coefficient matrix passed. + + Along with returning the second order system, this function also returns the new + dependent variables with the new independent variable `t_` passed. + + Parameters + ========== + + A: Matrix + Coefficient matrix of the system + funcs: List + List of old dependent variables + t_: Symbol + New independent variable + + Returns + ======= + + List, List + + """ + func_names = [func.func.__name__ for func in funcs] + new_funcs = [Function(Dummy("{}_".format(name)))(t_) for name in func_names] + rhss = A * Matrix(new_funcs) + new_eqs = [Eq(func.diff(t_, 2), rhs) for func, rhs in zip(new_funcs, rhss)] + + return new_eqs, new_funcs + + +def _is_euler_system(As, t): + return all(_matrix_is_constant((A*t**i).applyfunc(ratsimp), t) for i, A in enumerate(As)) + + +def _classify_linear_system(eqs, funcs, t, is_canon=False): + r""" + Returns a dictionary with details of the eqs if the system passed is linear + and can be classified by this function else returns None + + Explanation + =========== + + This function takes the eqs, converts it into a form Ax = b where x is a vector of terms + containing dependent variables and their derivatives till their maximum order. If it is + possible to convert eqs into Ax = b, then all the equations in eqs are linear otherwise + they are non-linear. + + To check if the equations are constant coefficient, we need to check if all the terms in + A obtained above are constant or not. + + To check if the equations are homogeneous or not, we need to check if b is a zero matrix + or not. + + Parameters + ========== + + eqs: List + List of ODEs + funcs: List + List of dependent variables + t: Symbol + Independent variable of the equations in eqs + is_canon: Boolean + If True, then this function will not try to get the + system in canonical form. Default value is False + + Returns + ======= + + match = { + 'no_of_equation': len(eqs), + 'eq': eqs, + 'func': funcs, + 'order': order, + 'is_linear': is_linear, + 'is_constant': is_constant, + 'is_homogeneous': is_homogeneous, + } + + Dict or list of Dicts or None + Dict with values for keys: + 1. no_of_equation: Number of equations + 2. eq: The set of equations + 3. func: List of dependent variables + 4. order: A dictionary that gives the order of the + dependent variable in eqs + 5. is_linear: Boolean value indicating if the set of + equations are linear or not. + 6. is_constant: Boolean value indicating if the set of + equations have constant coefficients or not. + 7. is_homogeneous: Boolean value indicating if the set of + equations are homogeneous or not. + 8. commutative_antiderivative: Antiderivative of the coefficient + matrix if the coefficient matrix is non-constant + and commutative with its antiderivative. This key + may or may not exist. + 9. is_general: Boolean value indicating if the system of ODEs is + solvable using one of the general case solvers or not. + 10. rhs: rhs of the non-homogeneous system of ODEs in Matrix form. This + key may or may not exist. + 11. is_higher_order: True if the system passed has an order greater than 1. + This key may or may not exist. + 12. is_second_order: True if the system passed is a second order ODE. This + key may or may not exist. + This Dict is the answer returned if the eqs are linear and constant + coefficient. Otherwise, None is returned. + + """ + + # Error for i == 0 can be added but isn't for now + + # Check for len(funcs) == len(eqs) + if len(funcs) != len(eqs): + raise ValueError("Number of functions given is not equal to the number of equations %s" % funcs) + + # ValueError when functions have more than one arguments + for func in funcs: + if len(func.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + + # Getting the func_dict and order using the helper + # function + order = _get_func_order(eqs, funcs) + system_order = max(order[func] for func in funcs) + is_higher_order = system_order > 1 + is_second_order = system_order == 2 and all(order[func] == 2 for func in funcs) + + # Not adding the check if the len(func.args) for + # every func in funcs is 1 + + # Linearity check + try: + + canon_eqs = canonical_odes(eqs, funcs, t) if not is_canon else [eqs] + if len(canon_eqs) == 1: + As, b = linear_ode_to_matrix(canon_eqs[0], funcs, t, system_order) + else: + + match = { + 'is_implicit': True, + 'canon_eqs': canon_eqs + } + + return match + + # When the system of ODEs is non-linear, an ODENonlinearError is raised. + # This function catches the error and None is returned. + except ODENonlinearError: + return None + + is_linear = True + + # Homogeneous check + is_homogeneous = True if b.is_zero_matrix else False + + # Is general key is used to identify if the system of ODEs can be solved by + # one of the general case solvers or not. + match = { + 'no_of_equation': len(eqs), + 'eq': eqs, + 'func': funcs, + 'order': order, + 'is_linear': is_linear, + 'is_homogeneous': is_homogeneous, + 'is_general': True + } + + if not is_homogeneous: + match['rhs'] = b + + is_constant = all(_matrix_is_constant(A_, t) for A_ in As) + + # The match['is_linear'] check will be added in the future when this + # function becomes ready to deal with non-linear systems of ODEs + + if not is_higher_order: + A = As[1] + match['func_coeff'] = A + + # Constant coefficient check + is_constant = _matrix_is_constant(A, t) + match['is_constant'] = is_constant + + try: + system_info = linodesolve_type(A, t, b=b) + except NotImplementedError: + return None + + match.update(system_info) + antiderivative = match.pop("antiderivative") + + if not is_constant: + match['commutative_antiderivative'] = antiderivative + + return match + else: + match['type_of_equation'] = "type0" + + if is_second_order: + A1, A0 = As[1:] + + match_second_order = _match_second_order_type(A1, A0, t) + match.update(match_second_order) + + match['is_second_order'] = True + + # If system is constant, then no need to check if its in euler + # form or not. It will be easier and faster to directly proceed + # to solve it. + if match['type_of_equation'] == "type0" and not is_constant: + is_euler = _is_euler_system(As, t) + if is_euler: + t_ = Symbol('{}_'.format(t)) + match.update({'is_transformed': True, 'type_of_equation': 'type1', + 't_': t_}) + else: + is_jordan = lambda M: M == Matrix.jordan_block(M.shape[0], M[0, 0]) + terms = _factor_matrix(As[-1], t) + if all(A.is_zero_matrix for A in As[1:-1]) and terms is not None and not is_jordan(terms[1]): + P, J = terms[1].jordan_form() + match.update({'type_of_equation': 'type2', 'J': J, + 'f(t)': terms[0], 'P': P, 'is_transformed': True}) + + if match['type_of_equation'] != 'type0' and is_second_order: + match.pop('is_second_order', None) + + match['is_higher_order'] = is_higher_order + + return match + +def _preprocess_eqs(eqs): + processed_eqs = [] + for eq in eqs: + processed_eqs.append(eq if isinstance(eq, Equality) else Eq(eq, 0)) + + return processed_eqs + + +def _eqs2dict(eqs, funcs): + eqsorig = {} + eqsmap = {} + funcset = set(funcs) + for eq in eqs: + f1, = eq.lhs.atoms(AppliedUndef) + f2s = (eq.rhs.atoms(AppliedUndef) - {f1}) & funcset + eqsmap[f1] = f2s + eqsorig[f1] = eq + return eqsmap, eqsorig + + +def _dict2graph(d): + nodes = list(d) + edges = [(f1, f2) for f1, f2s in d.items() for f2 in f2s] + G = (nodes, edges) + return G + + +def _is_type1(scc, t): + eqs, funcs = scc + + try: + (A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, 1) + except (ODENonlinearError, ODEOrderError): + return False + + if _matrix_is_constant(A0, t) and b.is_zero_matrix: + return True + + return False + + +def _combine_type1_subsystems(subsystem, funcs, t): + indices = [i for i, sys in enumerate(zip(subsystem, funcs)) if _is_type1(sys, t)] + remove = set() + for ip, i in enumerate(indices): + for j in indices[ip+1:]: + if any(eq2.has(funcs[i]) for eq2 in subsystem[j]): + subsystem[j] = subsystem[i] + subsystem[j] + remove.add(i) + subsystem = [sys for i, sys in enumerate(subsystem) if i not in remove] + return subsystem + + +def _component_division(eqs, funcs, t): + + # Assuming that each eq in eqs is in canonical form, + # that is, [f(x).diff(x) = .., g(x).diff(x) = .., etc] + # and that the system passed is in its first order + eqsmap, eqsorig = _eqs2dict(eqs, funcs) + + subsystems = [] + for cc in connected_components(_dict2graph(eqsmap)): + eqsmap_c = {f: eqsmap[f] for f in cc} + sccs = strongly_connected_components(_dict2graph(eqsmap_c)) + subsystem = [[eqsorig[f] for f in scc] for scc in sccs] + subsystem = _combine_type1_subsystems(subsystem, sccs, t) + subsystems.append(subsystem) + + return subsystems + + +# Returns: List of equations +def _linear_ode_solver(match): + t = match['t'] + funcs = match['func'] + + rhs = match.get('rhs', None) + tau = match.get('tau', None) + t = match['t_'] if 't_' in match else t + A = match['func_coeff'] + + # Note: To make B None when the matrix has constant + # coefficient + B = match.get('commutative_antiderivative', None) + type = match['type_of_equation'] + + sol_vector = linodesolve(A, t, b=rhs, B=B, + type=type, tau=tau) + + sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + + return sol + + +def _select_equations(eqs, funcs, key=lambda x: x): + eq_dict = {e.lhs: e.rhs for e in eqs} + return [Eq(f, eq_dict[key(f)]) for f in funcs] + + +def _higher_order_ode_solver(match): + eqs = match["eq"] + funcs = match["func"] + t = match["t"] + sysorder = match['order'] + type = match.get('type_of_equation', "type0") + + is_second_order = match.get('is_second_order', False) + is_transformed = match.get('is_transformed', False) + is_euler = is_transformed and type == "type1" + is_higher_order_type2 = is_transformed and type == "type2" and 'P' in match + + if is_second_order: + new_eqs, new_funcs = _second_order_to_first_order(eqs, funcs, t, + A1=match.get("A1", None), A0=match.get("A0", None), + b=match.get("rhs", None), type=type, + t_=match.get("t_", None)) + else: + new_eqs, new_funcs = _higher_order_to_first_order(eqs, sysorder, t, funcs=funcs, + type=type, J=match.get('J', None), + f_t=match.get('f(t)', None), + P=match.get('P', None), b=match.get('rhs', None)) + + if is_transformed: + t = match.get('t_', t) + + if not is_higher_order_type2: + new_eqs = _select_equations(new_eqs, [f.diff(t) for f in new_funcs]) + + sol = None + + # NotImplementedError may be raised when the system may be actually + # solvable if it can be just divided into sub-systems + try: + if not is_higher_order_type2: + sol = _strong_component_solver(new_eqs, new_funcs, t) + except NotImplementedError: + sol = None + + # Dividing the system only when it becomes essential + if sol is None: + try: + sol = _component_solver(new_eqs, new_funcs, t) + except NotImplementedError: + sol = None + + if sol is None: + return sol + + is_second_order_type2 = is_second_order and type == "type2" + + underscores = '__' if is_transformed else '_' + + sol = _select_equations(sol, funcs, + key=lambda x: Function(Dummy('{}{}0'.format(x.func.__name__, underscores)))(t)) + + if match.get("is_transformed", False): + if is_second_order_type2: + g_t = match["g(t)"] + tau = match["tau"] + sol = [Eq(s.lhs, s.rhs.subs(t, tau) * g_t) for s in sol] + elif is_euler: + t = match['t'] + tau = match['t_'] + sol = [s.subs(tau, log(t)) for s in sol] + elif is_higher_order_type2: + P = match['P'] + sol_vector = P * Matrix([s.rhs for s in sol]) + sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + + return sol + + +# Returns: List of equations or None +# If None is returned by this solver, then the system +# of ODEs cannot be solved directly by dsolve_system. +def _strong_component_solver(eqs, funcs, t): + from sympy.solvers.ode.ode import dsolve, constant_renumber + + match = _classify_linear_system(eqs, funcs, t, is_canon=True) + sol = None + + # Assuming that we can't get an implicit system + # since we are already canonical equations from + # dsolve_system + if match: + match['t'] = t + + if match.get('is_higher_order', False): + sol = _higher_order_ode_solver(match) + + elif match.get('is_linear', False): + sol = _linear_ode_solver(match) + + # Note: For now, only linear systems are handled by this function + # hence, the match condition is added. This can be removed later. + if sol is None and len(eqs) == 1: + sol = dsolve(eqs[0], func=funcs[0]) + variables = Tuple(eqs[0]).free_symbols + new_constants = [Dummy() for _ in range(ode_order(eqs[0], funcs[0]))] + sol = constant_renumber(sol, variables=variables, newconstants=new_constants) + sol = [sol] + + # To add non-linear case here in future + + return sol + + +def _get_funcs_from_canon(eqs): + return [eq.lhs.args[0] for eq in eqs] + + +# Returns: List of Equations(a solution) +def _weak_component_solver(wcc, t): + + # We will divide the systems into sccs + # only when the wcc cannot be solved as + # a whole + eqs = [] + for scc in wcc: + eqs += scc + funcs = _get_funcs_from_canon(eqs) + + sol = _strong_component_solver(eqs, funcs, t) + if sol: + return sol + + sol = [] + + for scc in wcc: + eqs = scc + funcs = _get_funcs_from_canon(eqs) + + # Substituting solutions for the dependent + # variables solved in previous SCC, if any solved. + comp_eqs = [eq.subs({s.lhs: s.rhs for s in sol}) for eq in eqs] + scc_sol = _strong_component_solver(comp_eqs, funcs, t) + + if scc_sol is None: + raise NotImplementedError(filldedent(''' + The system of ODEs passed cannot be solved by dsolve_system. + ''')) + + # scc_sol: List of equations + # scc_sol is a solution + sol += scc_sol + + return sol + + +# Returns: List of Equations(a solution) +def _component_solver(eqs, funcs, t): + components = _component_division(eqs, funcs, t) + sol = [] + + for wcc in components: + + # wcc_sol: List of Equations + sol += _weak_component_solver(wcc, t) + + # sol: List of Equations + return sol + + +def _second_order_to_first_order(eqs, funcs, t, type="auto", A1=None, + A0=None, b=None, t_=None): + r""" + Expects the system to be in second order and in canonical form + + Explanation + =========== + + Reduces a second order system into a first order one depending on the type of second + order system. + 1. "type0": If this is passed, then the system will be reduced to first order by + introducing dummy variables. + 2. "type1": If this is passed, then a particular substitution will be used to reduce the + the system into first order. + 3. "type2": If this is passed, then the system will be transformed with new dependent + variables and independent variables. This transformation is a part of solving + the corresponding system of ODEs. + + `A1` and `A0` are the coefficient matrices from the system and it is assumed that the + second order system has the form given below: + + .. math:: + A2 * X'' = A1 * X' + A0 * X + b + + Here, $A2$ is the coefficient matrix for the vector $X''$ and $b$ is the non-homogeneous + term. + + Default value for `b` is None but if `A1` and `A0` are passed and `b` is not passed, then the + system will be assumed homogeneous. + + """ + is_a1 = A1 is None + is_a0 = A0 is None + + if (type == "type1" and is_a1) or (type == "type2" and is_a0)\ + or (type == "auto" and (is_a1 or is_a0)): + (A2, A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, 2) + + if not A2.is_Identity: + raise ValueError(filldedent(''' + The system must be in its canonical form. + ''')) + + if type == "auto": + match = _match_second_order_type(A1, A0, t) + type = match["type_of_equation"] + A1 = match.get("A1", None) + A0 = match.get("A0", None) + + sys_order = dict.fromkeys(funcs, 2) + + if type == "type1": + if b is None: + b = zeros(len(eqs)) + eqs = _second_order_subs_type1(A1, b, funcs, t) + sys_order = dict.fromkeys(funcs, 1) + + if type == "type2": + if t_ is None: + t_ = Symbol("{}_".format(t)) + t = t_ + eqs, funcs = _second_order_subs_type2(A0, funcs, t_) + sys_order = dict.fromkeys(funcs, 2) + + return _higher_order_to_first_order(eqs, sys_order, t, funcs=funcs) + + +def _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order, b=None, P=None): + + # Note: To add a test for this ValueError + if J is None or f_t is None or not _matrix_is_constant(J, t): + raise ValueError(filldedent(''' + Correctly input for args 'A' and 'f_t' for Linear, Higher Order, + Type 2 + ''')) + + if P is None and b is not None and not b.is_zero_matrix: + raise ValueError(filldedent(''' + Provide the keyword 'P' for matrix P in A = P * J * P-1. + ''')) + + new_funcs = Matrix([Function(Dummy('{}__0'.format(f.func.__name__)))(t) for f in funcs]) + new_eqs = new_funcs.diff(t, max_order) - f_t * J * new_funcs + + if b is not None and not b.is_zero_matrix: + new_eqs -= P.inv() * b + + new_eqs = canonical_odes(new_eqs, new_funcs, t)[0] + + return new_eqs, new_funcs + + +def _higher_order_to_first_order(eqs, sys_order, t, funcs=None, type="type0", **kwargs): + if funcs is None: + funcs = sys_order.keys() + + # Standard Cauchy Euler system + if type == "type1": + t_ = Symbol('{}_'.format(t)) + new_funcs = [Function(Dummy('{}_'.format(f.func.__name__)))(t_) for f in funcs] + max_order = max(sys_order[func] for func in funcs) + subs_dict = dict(zip(funcs, new_funcs)) + subs_dict[t] = exp(t_) + + free_function = Function(Dummy()) + + def _get_coeffs_from_subs_expression(expr): + if isinstance(expr, Subs): + free_symbol = expr.args[1][0] + term = expr.args[0] + return {ode_order(term, free_symbol): 1} + + if isinstance(expr, Mul): + coeff = expr.args[0] + order = list(_get_coeffs_from_subs_expression(expr.args[1]).keys())[0] + return {order: coeff} + + if isinstance(expr, Add): + coeffs = {} + for arg in expr.args: + + if isinstance(arg, Mul): + coeffs.update(_get_coeffs_from_subs_expression(arg)) + + else: + order = list(_get_coeffs_from_subs_expression(arg).keys())[0] + coeffs[order] = 1 + + return coeffs + + for o in range(1, max_order + 1): + expr = free_function(log(t_)).diff(t_, o)*t_**o + coeff_dict = _get_coeffs_from_subs_expression(expr) + coeffs = [coeff_dict[order] if order in coeff_dict else 0 for order in range(o + 1)] + expr_to_subs = sum(free_function(t_).diff(t_, i) * c for i, c in + enumerate(coeffs)) / t**o + subs_dict.update({f.diff(t, o): expr_to_subs.subs(free_function(t_), nf) + for f, nf in zip(funcs, new_funcs)}) + + new_eqs = [eq.subs(subs_dict) for eq in eqs] + new_sys_order = {nf: sys_order[f] for f, nf in zip(funcs, new_funcs)} + + new_eqs = canonical_odes(new_eqs, new_funcs, t_)[0] + + return _higher_order_to_first_order(new_eqs, new_sys_order, t_, funcs=new_funcs) + + # Systems of the form: X(n)(t) = f(t)*A*X + b + # where X(n)(t) is the nth derivative of the vector of dependent variables + # with respect to the independent variable and A is a constant matrix. + if type == "type2": + J = kwargs.get('J', None) + f_t = kwargs.get('f_t', None) + b = kwargs.get('b', None) + P = kwargs.get('P', None) + max_order = max(sys_order[func] for func in funcs) + + return _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order, P=P, b=b) + + # Note: To be changed to this after doit option is disabled for default cases + # new_sysorder = _get_func_order(new_eqs, new_funcs) + # + # return _higher_order_to_first_order(new_eqs, new_sysorder, t, funcs=new_funcs) + + new_funcs = [] + + for prev_func in funcs: + func_name = prev_func.func.__name__ + func = Function(Dummy('{}_0'.format(func_name)))(t) + new_funcs.append(func) + subs_dict = {prev_func: func} + new_eqs = [] + + for i in range(1, sys_order[prev_func]): + new_func = Function(Dummy('{}_{}'.format(func_name, i)))(t) + subs_dict[prev_func.diff(t, i)] = new_func + new_funcs.append(new_func) + + prev_f = subs_dict[prev_func.diff(t, i-1)] + new_eq = Eq(prev_f.diff(t), new_func) + new_eqs.append(new_eq) + + eqs = [eq.subs(subs_dict) for eq in eqs] + new_eqs + + return eqs, new_funcs + + +def dsolve_system(eqs, funcs=None, t=None, ics=None, doit=False, simplify=True): + r""" + Solves any(supported) system of Ordinary Differential Equations + + Explanation + =========== + + This function takes a system of ODEs as an input, determines if the + it is solvable by this function, and returns the solution if found any. + + This function can handle: + 1. Linear, First Order, Constant coefficient homogeneous system of ODEs + 2. Linear, First Order, Constant coefficient non-homogeneous system of ODEs + 3. Linear, First Order, non-constant coefficient homogeneous system of ODEs + 4. Linear, First Order, non-constant coefficient non-homogeneous system of ODEs + 5. Any implicit system which can be divided into system of ODEs which is of the above 4 forms + 6. Any higher order linear system of ODEs that can be reduced to one of the 5 forms of systems described above. + + The types of systems described above are not limited by the number of equations, i.e. this + function can solve the above types irrespective of the number of equations in the system passed. + But, the bigger the system, the more time it will take to solve the system. + + This function returns a list of solutions. Each solution is a list of equations where LHS is + the dependent variable and RHS is an expression in terms of the independent variable. + + Among the non constant coefficient types, not all the systems are solvable by this function. Only + those which have either a coefficient matrix with a commutative antiderivative or those systems which + may be divided further so that the divided systems may have coefficient matrix with commutative antiderivative. + + Parameters + ========== + + eqs : List + system of ODEs to be solved + funcs : List or None + List of dependent variables that make up the system of ODEs + t : Symbol or None + Independent variable in the system of ODEs + ics : Dict or None + Set of initial boundary/conditions for the system of ODEs + doit : Boolean + Evaluate the solutions if True. Default value is True. Can be + set to false if the integral evaluation takes too much time and/or + is not required. + simplify: Boolean + Simplify the solutions for the systems. Default value is True. + Can be set to false if simplification takes too much time and/or + is not required. + + Examples + ======== + + >>> from sympy import symbols, Eq, Function + >>> from sympy.solvers.ode.systems import dsolve_system + >>> f, g = symbols("f g", cls=Function) + >>> x = symbols("x") + + >>> eqs = [Eq(f(x).diff(x), g(x)), Eq(g(x).diff(x), f(x))] + >>> dsolve_system(eqs) + [[Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), C1*exp(-x) + C2*exp(x))]] + + You can also pass the initial conditions for the system of ODEs: + + >>> dsolve_system(eqs, ics={f(0): 1, g(0): 0}) + [[Eq(f(x), exp(x)/2 + exp(-x)/2), Eq(g(x), exp(x)/2 - exp(-x)/2)]] + + Optionally, you can pass the dependent variables and the independent + variable for which the system is to be solved: + + >>> funcs = [f(x), g(x)] + >>> dsolve_system(eqs, funcs=funcs, t=x) + [[Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), C1*exp(-x) + C2*exp(x))]] + + Lets look at an implicit system of ODEs: + + >>> eqs = [Eq(f(x).diff(x)**2, g(x)**2), Eq(g(x).diff(x), g(x))] + >>> dsolve_system(eqs) + [[Eq(f(x), C1 - C2*exp(x)), Eq(g(x), C2*exp(x))], [Eq(f(x), C1 + C2*exp(x)), Eq(g(x), C2*exp(x))]] + + Returns + ======= + + List of List of Equations + + Raises + ====== + + NotImplementedError + When the system of ODEs is not solvable by this function. + ValueError + When the parameters passed are not in the required form. + + """ + from sympy.solvers.ode.ode import solve_ics, _extract_funcs, constant_renumber + + if not iterable(eqs): + raise ValueError(filldedent(''' + List of equations should be passed. The input is not valid. + ''')) + + eqs = _preprocess_eqs(eqs) + + if funcs is not None and not isinstance(funcs, list): + raise ValueError(filldedent(''' + Input to the funcs should be a list of functions. + ''')) + + if funcs is None: + funcs = _extract_funcs(eqs) + + if any(len(func.args) != 1 for func in funcs): + raise ValueError(filldedent(''' + dsolve_system can solve a system of ODEs with only one independent + variable. + ''')) + + if len(eqs) != len(funcs): + raise ValueError(filldedent(''' + Number of equations and number of functions do not match + ''')) + + if t is not None and not isinstance(t, Symbol): + raise ValueError(filldedent(''' + The independent variable must be of type Symbol + ''')) + + if t is None: + t = list(list(eqs[0].atoms(Derivative))[0].atoms(Symbol))[0] + + sols = [] + canon_eqs = canonical_odes(eqs, funcs, t) + + for canon_eq in canon_eqs: + try: + sol = _strong_component_solver(canon_eq, funcs, t) + except NotImplementedError: + sol = None + + if sol is None: + sol = _component_solver(canon_eq, funcs, t) + + sols.append(sol) + + if sols: + final_sols = [] + variables = Tuple(*eqs).free_symbols + + for sol in sols: + + sol = _select_equations(sol, funcs) + sol = constant_renumber(sol, variables=variables) + + if ics: + constants = Tuple(*sol).free_symbols - variables + solved_constants = solve_ics(sol, funcs, constants, ics) + sol = [s.subs(solved_constants) for s in sol] + + if simplify: + constants = Tuple(*sol).free_symbols - variables + sol = simpsol(sol, [t], constants, doit=doit) + + final_sols.append(sol) + + sols = final_sols + + return sols diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/__init__.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_lie_group.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_lie_group.py new file mode 100644 index 0000000000000000000000000000000000000000..153d30ff563773819e49c989f447c1ec7962169b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_lie_group.py @@ -0,0 +1,152 @@ +from sympy.core.function import Function +from sympy.core.numbers import Rational +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan, sin, tan) + +from sympy.solvers.ode import (classify_ode, checkinfsol, dsolve, infinitesimals) + +from sympy.solvers.ode.subscheck import checkodesol + +from sympy.testing.pytest import XFAIL + + +C1 = Symbol('C1') +x, y = symbols("x y") +f = Function('f') +xi = Function('xi') +eta = Function('eta') + + +def test_heuristic1(): + a, b, c, a4, a3, a2, a1, a0 = symbols("a b c a4 a3 a2 a1 a0") + df = f(x).diff(x) + eq = Eq(df, x**2*f(x)) + eq1 = f(x).diff(x) + a*f(x) - c*exp(b*x) + eq2 = f(x).diff(x) + 2*x*f(x) - x*exp(-x**2) + eq3 = (1 + 2*x)*df + 2 - 4*exp(-f(x)) + eq4 = f(x).diff(x) - (a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**Rational(-1, 2) + eq5 = x**2*df - f(x) + x**2*exp(x - (1/x)) + eqlist = [eq, eq1, eq2, eq3, eq4, eq5] + + i = infinitesimals(eq, hint='abaco1_simple') + assert i == [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0}, + {eta(x, f(x)): f(x), xi(x, f(x)): 0}, + {eta(x, f(x)): 0, xi(x, f(x)): x**(-2)}] + i1 = infinitesimals(eq1, hint='abaco1_simple') + assert i1 == [{eta(x, f(x)): exp(-a*x), xi(x, f(x)): 0}] + i2 = infinitesimals(eq2, hint='abaco1_simple') + assert i2 == [{eta(x, f(x)): exp(-x**2), xi(x, f(x)): 0}] + i3 = infinitesimals(eq3, hint='abaco1_simple') + assert i3 == [{eta(x, f(x)): 0, xi(x, f(x)): 2*x + 1}, + {eta(x, f(x)): 0, xi(x, f(x)): 1/(exp(f(x)) - 2)}] + i4 = infinitesimals(eq4, hint='abaco1_simple') + assert i4 == [{eta(x, f(x)): 1, xi(x, f(x)): 0}, + {eta(x, f(x)): 0, + xi(x, f(x)): sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4)}] + i5 = infinitesimals(eq5, hint='abaco1_simple') + assert i5 == [{xi(x, f(x)): 0, eta(x, f(x)): exp(-1/x)}] + + ilist = [i, i1, i2, i3, i4, i5] + for eq, i in (zip(eqlist, ilist)): + check = checkinfsol(eq, i) + assert check[0] + + # This ODE can be solved by the Lie Group method, when there are + # better assumptions + eq6 = df - (f(x)/x)*(x*log(x**2/f(x)) + 2) + i = infinitesimals(eq6, hint='abaco1_product') + assert i == [{eta(x, f(x)): f(x)*exp(-x), xi(x, f(x)): 0}] + assert checkinfsol(eq6, i)[0] + + eq7 = x*(f(x).diff(x)) + 1 - f(x)**2 + i = infinitesimals(eq7, hint='chi') + assert checkinfsol(eq7, i)[0] + + +def test_heuristic3(): + a, b = symbols("a b") + df = f(x).diff(x) + + eq = x**2*df + x*f(x) + f(x)**2 + x**2 + i = infinitesimals(eq, hint='bivariate') + assert i == [{eta(x, f(x)): f(x), xi(x, f(x)): x}] + assert checkinfsol(eq, i)[0] + + eq = x**2*(-f(x)**2 + df)- a*x**2*f(x) + 2 - a*x + i = infinitesimals(eq, hint='bivariate') + assert checkinfsol(eq, i)[0] + + +def test_heuristic_function_sum(): + eq = f(x).diff(x) - (3*(1 + x**2/f(x)**2)*atan(f(x)/x) + (1 - 2*f(x))/x + + (1 - 3*f(x))*(x/f(x)**2)) + i = infinitesimals(eq, hint='function_sum') + assert i == [{eta(x, f(x)): f(x)**(-2) + x**(-2), xi(x, f(x)): 0}] + assert checkinfsol(eq, i)[0] + + +def test_heuristic_abaco2_similar(): + a, b = symbols("a b") + F = Function('F') + eq = f(x).diff(x) - F(a*x + b*f(x)) + i = infinitesimals(eq, hint='abaco2_similar') + assert i == [{eta(x, f(x)): -a/b, xi(x, f(x)): 1}] + assert checkinfsol(eq, i)[0] + + eq = f(x).diff(x) - (f(x)**2 / (sin(f(x) - x) - x**2 + 2*x*f(x))) + i = infinitesimals(eq, hint='abaco2_similar') + assert i == [{eta(x, f(x)): f(x)**2, xi(x, f(x)): f(x)**2}] + assert checkinfsol(eq, i)[0] + + +def test_heuristic_abaco2_unique_unknown(): + + a, b = symbols("a b") + F = Function('F') + eq = f(x).diff(x) - x**(a - 1)*(f(x)**(1 - b))*F(x**a/a + f(x)**b/b) + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert i == [{eta(x, f(x)): -f(x)*f(x)**(-b), xi(x, f(x)): x*x**(-a)}] + assert checkinfsol(eq, i)[0] + + eq = f(x).diff(x) + tan(F(x**2 + f(x)**2) + atan(x/f(x))) + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert i == [{eta(x, f(x)): x, xi(x, f(x)): -f(x)}] + assert checkinfsol(eq, i)[0] + + eq = (x*f(x).diff(x) + f(x) + 2*x)**2 -4*x*f(x) -4*x**2 -4*a + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert checkinfsol(eq, i)[0] + + +def test_heuristic_linear(): + a, b, m, n = symbols("a b m n") + + eq = x**(n*(m + 1) - m)*(f(x).diff(x)) - a*f(x)**n -b*x**(n*(m + 1)) + i = infinitesimals(eq, hint='linear') + assert checkinfsol(eq, i)[0] + + +@XFAIL +def test_kamke(): + a, b, alpha, c = symbols("a b alpha c") + eq = x**2*(a*f(x)**2+(f(x).diff(x))) + b*x**alpha + c + i = infinitesimals(eq, hint='sum_function') # XFAIL + assert checkinfsol(eq, i)[0] + + +def test_user_infinitesimals(): + x = Symbol("x") # assuming x is real generates an error + eq = x*(f(x).diff(x)) + 1 - f(x)**2 + sol = Eq(f(x), (C1 + x**2)/(C1 - x**2)) + infinitesimals = {'xi':sqrt(f(x) - 1)/sqrt(f(x) + 1), 'eta':0} + assert dsolve(eq, hint='lie_group', **infinitesimals) == sol + assert checkodesol(eq, sol) == (True, 0) + + +@XFAIL +def test_lie_group_issue15219(): + eqn = exp(f(x).diff(x)-f(x)) + assert 'lie_group' not in classify_ode(eqn, f(x)) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_ode.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_ode.py new file mode 100644 index 0000000000000000000000000000000000000000..65e0fa62d52445a4669f3cdc5ef278dbf9c88ea4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_ode.py @@ -0,0 +1,1105 @@ +from sympy.core.function import (Derivative, Function, Subs, diff) +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import acosh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin, tan) +from sympy.integrals.integrals import Integral +from sympy.polys.polytools import Poly +from sympy.series.order import O +from sympy.simplify.radsimp import collect + +from sympy.solvers.ode import (classify_ode, + homogeneous_order, dsolve) + +from sympy.solvers.ode.subscheck import checkodesol +from sympy.solvers.ode.ode import (classify_sysode, + constant_renumber, constantsimp, get_numbered_constants, solve_ics) + +from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match +from sympy.solvers.ode.single import LinearCoefficients +from sympy.solvers.deutils import ode_order +from sympy.testing.pytest import XFAIL, raises, slow, SKIP +from sympy.utilities.misc import filldedent + + +C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C0:11') +u, x, y, z = symbols('u,x:z', real=True) +f = Function('f') +g = Function('g') +h = Function('h') + +# Note: Examples which were specifically testing Single ODE solver are moved to test_single.py +# and all the system of ode examples are moved to test_systems.py +# Note: the tests below may fail (but still be correct) if ODE solver, +# the integral engine, solve(), or even simplify() changes. Also, in +# differently formatted solutions, the arbitrary constants might not be +# equal. Using specific hints in tests can help to avoid this. + +# Tests of order higher than 1 should run the solutions through +# constant_renumber because it will normalize it (constant_renumber causes +# dsolve() to return different results on different machines) + + +def test_get_numbered_constants(): + with raises(ValueError): + get_numbered_constants(None) + + +def test_dsolve_all_hint(): + eq = f(x).diff(x) + output = dsolve(eq, hint='all') + + # Match the Dummy variables: + sol1 = output['separable_Integral'] + _y = sol1.lhs.args[1][0] + sol1 = output['1st_homogeneous_coeff_subs_dep_div_indep_Integral'] + _u1 = sol1.rhs.args[1].args[1][0] + + expected = {'Bernoulli_Integral': Eq(f(x), C1 + Integral(0, x)), + '1st_homogeneous_coeff_best': Eq(f(x), C1), + 'Bernoulli': Eq(f(x), C1), + 'nth_algebraic': Eq(f(x), C1), + 'nth_linear_euler_eq_homogeneous': Eq(f(x), C1), + 'nth_linear_constant_coeff_homogeneous': Eq(f(x), C1), + 'separable': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_indep_div_dep': Eq(f(x), C1), + 'nth_algebraic_Integral': Eq(f(x), C1), + '1st_linear': Eq(f(x), C1), + '1st_linear_Integral': Eq(f(x), C1 + Integral(0, x)), + '1st_exact': Eq(f(x), C1), + '1st_exact_Integral': Eq(Subs(Integral(0, x) + Integral(1, _y), _y, f(x)), C1), + 'lie_group': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_dep_div_indep': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_dep_div_indep_Integral': Eq(log(x), C1 + Integral(-1/_u1, (_u1, f(x)/x))), + '1st_power_series': Eq(f(x), C1), + 'separable_Integral': Eq(Integral(1, (_y, f(x))), C1 + Integral(0, x)), + '1st_homogeneous_coeff_subs_indep_div_dep_Integral': Eq(f(x), C1), + 'best': Eq(f(x), C1), + 'best_hint': 'nth_algebraic', + 'default': 'nth_algebraic', + 'order': 1} + assert output == expected + + assert dsolve(eq, hint='best') == Eq(f(x), C1) + + +def test_dsolve_ics(): + # Maybe this should just use one of the solutions instead of raising... + with raises(NotImplementedError): + dsolve(f(x).diff(x) - sqrt(f(x)), ics={f(1):1}) + + +@slow +def test_dsolve_options(): + eq = x*f(x).diff(x) + f(x) + a = dsolve(eq, hint='all') + b = dsolve(eq, hint='all', simplify=False) + c = dsolve(eq, hint='all_Integral') + keys = ['1st_exact', '1st_exact_Integral', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', '1st_linear', + '1st_linear_Integral', 'Bernoulli', 'Bernoulli_Integral', + 'almost_linear', 'almost_linear_Integral', 'best', 'best_hint', + 'default', 'factorable', 'lie_group', + 'nth_linear_euler_eq_homogeneous', 'order', + 'separable', 'separable_Integral'] + Integral_keys = ['1st_exact_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', '1st_linear_Integral', + 'Bernoulli_Integral', 'almost_linear_Integral', 'best', 'best_hint', 'default', + 'factorable', 'nth_linear_euler_eq_homogeneous', + 'order', 'separable_Integral'] + assert sorted(a.keys()) == keys + assert a['order'] == ode_order(eq, f(x)) + assert a['best'] == Eq(f(x), C1/x) + assert dsolve(eq, hint='best') == Eq(f(x), C1/x) + assert a['default'] == 'factorable' + assert a['best_hint'] == 'factorable' + assert not a['1st_exact'].has(Integral) + assert not a['separable'].has(Integral) + assert not a['1st_homogeneous_coeff_best'].has(Integral) + assert not a['1st_homogeneous_coeff_subs_dep_div_indep'].has(Integral) + assert not a['1st_homogeneous_coeff_subs_indep_div_dep'].has(Integral) + assert not a['1st_linear'].has(Integral) + assert a['1st_linear_Integral'].has(Integral) + assert a['1st_exact_Integral'].has(Integral) + assert a['1st_homogeneous_coeff_subs_dep_div_indep_Integral'].has(Integral) + assert a['1st_homogeneous_coeff_subs_indep_div_dep_Integral'].has(Integral) + assert a['separable_Integral'].has(Integral) + assert sorted(b.keys()) == keys + assert b['order'] == ode_order(eq, f(x)) + assert b['best'] == Eq(f(x), C1/x) + assert dsolve(eq, hint='best', simplify=False) == Eq(f(x), C1/x) + assert b['default'] == 'factorable' + assert b['best_hint'] == 'factorable' + assert a['separable'] != b['separable'] + assert a['1st_homogeneous_coeff_subs_dep_div_indep'] != \ + b['1st_homogeneous_coeff_subs_dep_div_indep'] + assert a['1st_homogeneous_coeff_subs_indep_div_dep'] != \ + b['1st_homogeneous_coeff_subs_indep_div_dep'] + assert not b['1st_exact'].has(Integral) + assert not b['separable'].has(Integral) + assert not b['1st_homogeneous_coeff_best'].has(Integral) + assert not b['1st_homogeneous_coeff_subs_dep_div_indep'].has(Integral) + assert not b['1st_homogeneous_coeff_subs_indep_div_dep'].has(Integral) + assert not b['1st_linear'].has(Integral) + assert b['1st_linear_Integral'].has(Integral) + assert b['1st_exact_Integral'].has(Integral) + assert b['1st_homogeneous_coeff_subs_dep_div_indep_Integral'].has(Integral) + assert b['1st_homogeneous_coeff_subs_indep_div_dep_Integral'].has(Integral) + assert b['separable_Integral'].has(Integral) + assert sorted(c.keys()) == Integral_keys + raises(ValueError, lambda: dsolve(eq, hint='notarealhint')) + raises(ValueError, lambda: dsolve(eq, hint='Liouville')) + assert dsolve(f(x).diff(x) - 1/f(x)**2, hint='all')['best'] == \ + dsolve(f(x).diff(x) - 1/f(x)**2, hint='best') + assert dsolve(f(x) + f(x).diff(x) + sin(x).diff(x) + 1, f(x), + hint="1st_linear_Integral") == \ + Eq(f(x), (C1 + Integral((-sin(x).diff(x) - 1)* + exp(Integral(1, x)), x))*exp(-Integral(1, x))) + + +def test_classify_ode(): + assert classify_ode(f(x).diff(x, 2), f(x)) == \ + ( + 'nth_algebraic', + 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'Liouville', + '2nd_power_series_ordinary', + 'nth_algebraic_Integral', + 'Liouville_Integral', + ) + assert classify_ode(f(x), f(x)) == ('nth_algebraic', 'nth_algebraic_Integral') + assert classify_ode(Eq(f(x).diff(x), 0), f(x)) == ( + 'nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', + 'separable_Integral', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + assert classify_ode(f(x).diff(x)**2, f(x)) == ('factorable', + 'nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', + 'lie_group', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', + 'separable_Integral', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + # issue 4749: f(x) should be cleared from highest derivative before classifying + a = classify_ode(Eq(f(x).diff(x) + f(x), x), f(x)) + b = classify_ode(f(x).diff(x)*f(x) + f(x)*f(x) - x*f(x), f(x)) + c = classify_ode(f(x).diff(x)/f(x) + f(x)/f(x) - x/f(x), f(x)) + assert a == ('1st_exact', + '1st_linear', + 'Bernoulli', + 'almost_linear', + '1st_power_series', "lie_group", + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'almost_linear_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + assert b == ('factorable', + '1st_linear', + 'Bernoulli', + '1st_power_series', + 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + assert c == ('factorable', + '1st_linear', + 'Bernoulli', + '1st_power_series', + 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + + assert classify_ode( + 2*x*f(x)*f(x).diff(x) + (1 + x)*f(x)**2 - exp(x), f(x) + ) == ('factorable', '1st_exact', 'Bernoulli', 'almost_linear', 'lie_group', + '1st_exact_Integral', 'Bernoulli_Integral', 'almost_linear_Integral') + assert 'Riccati_special_minus2' in \ + classify_ode(2*f(x).diff(x) + f(x)**2 - f(x)/x + 3*x**(-2), f(x)) + raises(ValueError, lambda: classify_ode(x + f(x, y).diff(x).diff( + y), f(x, y))) + # issue 5176 + k = Symbol('k') + assert classify_ode(f(x).diff(x)/(k*f(x) + k*x*f(x)) + 2*f(x)/(k*f(x) + + k*x*f(x)) + x*f(x).diff(x)/(k*f(x) + k*x*f(x)) + z, f(x)) == \ + ('factorable', 'separable', '1st_exact', '1st_linear', 'Bernoulli', + '1st_power_series', 'lie_group', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral') + # preprocessing + ans = ('factorable', 'nth_algebraic', 'separable', '1st_exact', '1st_linear', 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters', + 'nth_algebraic_Integral', + 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters_Integral') + # w/o f(x) given + assert classify_ode(diff(f(x) + x, x) + diff(f(x), x)) == ans + # w/ f(x) and prep=True + assert classify_ode(diff(f(x) + x, x) + diff(f(x), x), f(x), + prep=True) == ans + + assert classify_ode(Eq(2*x**3*f(x).diff(x), 0), f(x)) == \ + ('factorable', 'nth_algebraic', 'separable', '1st_exact', + '1st_linear', 'Bernoulli', '1st_power_series', + 'lie_group', 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral') + + + assert classify_ode(Eq(2*f(x)**3*f(x).diff(x), 0), f(x)) == \ + ('factorable', 'nth_algebraic', 'separable', '1st_exact', '1st_linear', + 'Bernoulli', '1st_power_series', 'lie_group', 'nth_algebraic_Integral', + 'separable_Integral', '1st_exact_Integral', '1st_linear_Integral', + 'Bernoulli_Integral') + # test issue 13864 + assert classify_ode(Eq(diff(f(x), x) - f(x)**x, 0), f(x)) == \ + ('1st_power_series', 'lie_group') + assert isinstance(classify_ode(Eq(f(x), 5), f(x), dict=True), dict) + + #This is for new behavior of classify_ode when called internally with default, It should + # return the first hint which matches therefore, 'ordered_hints' key will not be there. + assert sorted(classify_ode(Eq(f(x).diff(x), 0), f(x), dict=True).keys()) == \ + ['default', 'nth_linear_constant_coeff_homogeneous', 'order'] + a = classify_ode(2*x*f(x)*f(x).diff(x) + (1 + x)*f(x)**2 - exp(x), f(x), dict=True, hint='Bernoulli') + assert sorted(a.keys()) == ['Bernoulli', 'Bernoulli_Integral', 'default', 'order', 'ordered_hints'] + + # test issue 22155 + a = classify_ode(f(x).diff(x) - exp(f(x) - x), f(x)) + assert a == ('separable', + '1st_exact', '1st_power_series', + 'lie_group', 'separable_Integral', + '1st_exact_Integral') + + +def test_classify_ode_ics(): + # Dummy + eq = f(x).diff(x, x) - f(x) + + # Not f(0) or f'(0) + ics = {x: 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + + ############################ + # f(0) type (AppliedUndef) # + ############################ + + + # Wrong function + ics = {g(0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(0, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(0): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(0): f(0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(0): 1} + classify_ode(eq, f(x), ics=ics) + + + ##################### + # f'(0) type (Subs) # + ##################### + + # Wrong function + ics = {g(x).diff(x).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(y).diff(y).subs(y, x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Wrong variable + ics = {f(y).diff(y).subs(y, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(x, y).diff(x).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Derivative wrt wrong vars + ics = {Derivative(f(x), x, y).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(x).diff(x).subs(x, 0): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): f(x).diff(x).subs(x, 0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): 1} + classify_ode(eq, f(x), ics=ics) + + ########################### + # f'(y) type (Derivative) # + ########################### + + # Wrong function + ics = {g(x).diff(x).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(y).diff(y).subs(y, x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(x, y).diff(x).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Derivative wrt wrong vars + ics = {Derivative(f(x), x, z).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(x).diff(x).subs(x, y): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): f(0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(x).diff(x).subs(x, y): 1} + classify_ode(eq, f(x), ics=ics) + +def test_classify_sysode(): + # Here x is assumed to be x(t) and y as y(t) for simplicity. + # Similarly diff(x,t) and diff(y,y) is assumed to be x1 and y1 respectively. + k, l, m, n = symbols('k, l, m, n', Integer=True) + k1, k2, k3, l1, l2, l3, m1, m2, m3 = symbols('k1, k2, k3, l1, l2, l3, m1, m2, m3', Integer=True) + P, Q, R, p, q, r = symbols('P, Q, R, p, q, r', cls=Function) + P1, P2, P3, Q1, Q2, R1, R2 = symbols('P1, P2, P3, Q1, Q2, R1, R2', cls=Function) + x, y, z = symbols('x, y, z', cls=Function) + t = symbols('t') + x1 = diff(x(t),t) + y1 = diff(y(t),t) + + eq6 = (Eq(x1, exp(k*x(t))*P(x(t),y(t))), Eq(y1,r(y(t))*P(x(t),y(t)))) + sol6 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): 0, (1, x(t), 1): 0, (0, x(t), 1): 1, (1, y(t), 0): 0, \ + (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): 0, (1, y(t), 1): 1}, 'type_of_equation': 'type2', 'func': \ + [x(t), y(t)], 'is_linear': False, 'eq': [-P(x(t), y(t))*exp(k*x(t)) + Derivative(x(t), t), -P(x(t), \ + y(t))*r(y(t)) + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq6) == sol6 + + eq7 = (Eq(x1, x(t)**2+y(t)/x(t)), Eq(y1, x(t)/y(t))) + sol7 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): 0, (1, x(t), 1): 0, (0, x(t), 1): 1, (1, y(t), 0): 0, \ + (1, x(t), 0): -1/y(t), (0, y(t), 1): 0, (0, y(t), 0): -1/x(t), (1, y(t), 1): 1}, 'type_of_equation': 'type3', \ + 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)**2 + Derivative(x(t), t) - y(t)/x(t), -x(t)/y(t) + \ + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq7) == sol7 + + eq8 = (Eq(x1, P1(x(t))*Q1(y(t))*R(x(t),y(t),t)), Eq(y1, P1(x(t))*Q1(y(t))*R(x(t),y(t),t))) + sol8 = {'func': [x(t), y(t)], 'is_linear': False, 'type_of_equation': 'type4', 'eq': \ + [-P1(x(t))*Q1(y(t))*R(x(t), y(t), t) + Derivative(x(t), t), -P1(x(t))*Q1(y(t))*R(x(t), y(t), t) + \ + Derivative(y(t), t)], 'func_coeff': {(0, y(t), 1): 0, (1, y(t), 1): 1, (1, x(t), 1): 0, (0, y(t), 0): 0, \ + (1, x(t), 0): 0, (0, x(t), 0): 0, (1, y(t), 0): 0, (0, x(t), 1): 1}, 'order': {y(t): 1, x(t): 1}, 'no_of_equation': 2} + assert classify_sysode(eq8) == sol8 + + eq11 = (Eq(x1,x(t)*y(t)**3), Eq(y1,y(t)**5)) + sol11 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): -y(t)**3, (1, x(t), 1): 0, (0, x(t), 1): 1, \ + (1, y(t), 0): 0, (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): 0, (1, y(t), 1): 1}, 'type_of_equation': \ + 'type1', 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)*y(t)**3 + Derivative(x(t), t), \ + -y(t)**5 + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq11) == sol11 + + eq13 = (Eq(x1,x(t)*y(t)*sin(t)**2), Eq(y1,y(t)**2*sin(t)**2)) + sol13 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): -y(t)*sin(t)**2, (1, x(t), 1): 0, (0, x(t), 1): 1, \ + (1, y(t), 0): 0, (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): -x(t)*sin(t)**2, (1, y(t), 1): 1}, \ + 'type_of_equation': 'type4', 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)*y(t)*sin(t)**2 + \ + Derivative(x(t), t), -y(t)**2*sin(t)**2 + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq13) == sol13 + + +def test_solve_ics(): + # Basic tests that things work from dsolve. + assert dsolve(f(x).diff(x) - 1/f(x), f(x), ics={f(1): 2}) == \ + Eq(f(x), sqrt(2 * x + 2)) + assert dsolve(f(x).diff(x) - f(x), f(x), ics={f(0): 1}) == Eq(f(x), exp(x)) + assert dsolve(f(x).diff(x) - f(x), f(x), ics={f(x).diff(x).subs(x, 0): 1}) == Eq(f(x), exp(x)) + assert dsolve(f(x).diff(x, x) + f(x), f(x), ics={f(0): 1, + f(x).diff(x).subs(x, 0): 1}) == Eq(f(x), sin(x) + cos(x)) + assert dsolve([f(x).diff(x) - f(x) + g(x), g(x).diff(x) - g(x) - f(x)], + [f(x), g(x)], ics={f(0): 1, g(0): 0}) == [Eq(f(x), exp(x)*cos(x)), Eq(g(x), exp(x)*sin(x))] + + # Test cases where dsolve returns two solutions. + eq = (x**2*f(x)**2 - x).diff(x) + assert dsolve(eq, f(x), ics={f(1): 0}) == [Eq(f(x), + -sqrt(x - 1)/x), Eq(f(x), sqrt(x - 1)/x)] + assert dsolve(eq, f(x), ics={f(x).diff(x).subs(x, 1): 0}) == [Eq(f(x), + -sqrt(x - S.Half)/x), Eq(f(x), sqrt(x - S.Half)/x)] + + eq = cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x) + assert dsolve(eq, f(x), + ics={f(0):1}, hint='1st_exact', simplify=False) == Eq(x*cos(f(x)) + f(x)**3/3, Rational(1, 3)) + assert dsolve(eq, f(x), + ics={f(0):1}, hint='1st_exact', simplify=True) == Eq(x*cos(f(x)) + f(x)**3/3, Rational(1, 3)) + + assert solve_ics([Eq(f(x), C1*exp(x))], [f(x)], [C1], {f(0): 1}) == {C1: 1} + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], + {f(0): 1, f(pi/2): 1}) == {C1: 1, C2: 1} + + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], + {f(0): 1, f(x).diff(x).subs(x, 0): 1}) == {C1: 1, C2: 1} + + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], {f(0): 1}) == \ + {C2: 1} + + # Some more complicated tests Refer to PR #16098 + + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0, f(x).diff(x).subs(x, 1):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6 - x / 2)} + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0})) == \ + {Eq(f(x), 0), Eq(f(x), C2*x + x**3/6)} + + K, r, f0 = symbols('K r f0') + sol = Eq(f(x), K*f0*exp(r*x)/((-K + f0)*(f0*exp(r*x)/(-K + f0) - 1))) + assert (dsolve(Eq(f(x).diff(x), r * f(x) * (1 - f(x) / K)), f(x), ics={f(0): f0})) == sol + + + #Order dependent issues Refer to PR #16098 + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(x).diff(x).subs(x,0):0, f(0):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6)} + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0, f(x).diff(x).subs(x,0):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6)} + + # XXX: Ought to be ValueError + raises(ValueError, lambda: solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], {f(0): 1, f(pi): 1})) + + # Degenerate case. f'(0) is identically 0. + raises(ValueError, lambda: solve_ics([Eq(f(x), sqrt(C1 - x**2))], [f(x)], [C1], {f(x).diff(x).subs(x, 0): 0})) + + EI, q, L = symbols('EI q L') + + # eq = Eq(EI*diff(f(x), x, 4), q) + sols = [Eq(f(x), C1 + C2*x + C3*x**2 + C4*x**3 + q*x**4/(24*EI))] + funcs = [f(x)] + constants = [C1, C2, C3, C4] + # Test both cases, Derivative (the default from f(x).diff(x).subs(x, L)), + # and Subs + ics1 = {f(0): 0, + f(x).diff(x).subs(x, 0): 0, + f(L).diff(L, 2): 0, + f(L).diff(L, 3): 0} + ics2 = {f(0): 0, + f(x).diff(x).subs(x, 0): 0, + Subs(f(x).diff(x, 2), x, L): 0, + Subs(f(x).diff(x, 3), x, L): 0} + + solved_constants1 = solve_ics(sols, funcs, constants, ics1) + solved_constants2 = solve_ics(sols, funcs, constants, ics2) + assert solved_constants1 == solved_constants2 == { + C1: 0, + C2: 0, + C3: L**2*q/(4*EI), + C4: -L*q/(6*EI)} + + # Allow the ics to refer to f + ics = {f(0): f(0)} + assert dsolve(f(x).diff(x) - f(x), f(x), ics=ics) == Eq(f(x), f(0)*exp(x)) + + ics = {f(x).diff(x).subs(x, 0): f(x).diff(x).subs(x, 0), f(0): f(0)} + assert dsolve(f(x).diff(x, x) + f(x), f(x), ics=ics) == \ + Eq(f(x), f(0)*cos(x) + f(x).diff(x).subs(x, 0)*sin(x)) + +def test_ode_order(): + f = Function('f') + g = Function('g') + x = Symbol('x') + assert ode_order(3*x*exp(f(x)), f(x)) == 0 + assert ode_order(x*diff(f(x), x) + 3*x*f(x) - sin(x)/x, f(x)) == 1 + assert ode_order(x**2*f(x).diff(x, x) + x*diff(f(x), x) - f(x), f(x)) == 2 + assert ode_order(diff(x*exp(f(x)), x, x), f(x)) == 2 + assert ode_order(diff(x*diff(x*exp(f(x)), x, x), x), f(x)) == 3 + assert ode_order(diff(f(x), x, x), g(x)) == 0 + assert ode_order(diff(f(x), x, x)*diff(g(x), x), f(x)) == 2 + assert ode_order(diff(f(x), x, x)*diff(g(x), x), g(x)) == 1 + assert ode_order(diff(x*diff(x*exp(f(x)), x, x), x), g(x)) == 0 + # issue 5835: ode_order has to also work for unevaluated derivatives + # (ie, without using doit()). + assert ode_order(Derivative(x*f(x), x), f(x)) == 1 + assert ode_order(x*sin(Derivative(x*f(x)**2, x, x)), f(x)) == 2 + assert ode_order(Derivative(x*Derivative(x*exp(f(x)), x, x), x), g(x)) == 0 + assert ode_order(Derivative(f(x), x, x), g(x)) == 0 + assert ode_order(Derivative(x*exp(f(x)), x, x), f(x)) == 2 + assert ode_order(Derivative(f(x), x, x)*Derivative(g(x), x), g(x)) == 1 + assert ode_order(Derivative(x*Derivative(f(x), x, x), x), f(x)) == 3 + assert ode_order( + x*sin(Derivative(x*Derivative(f(x), x)**2, x, x)), f(x)) == 3 + + +def test_homogeneous_order(): + assert homogeneous_order(exp(y/x) + tan(y/x), x, y) == 0 + assert homogeneous_order(x**2 + sin(x)*cos(y), x, y) is None + assert homogeneous_order(x - y - x*sin(y/x), x, y) == 1 + assert homogeneous_order((x*y + sqrt(x**4 + y**4) + x**2*(log(x) - log(y)))/ + (pi*x**Rational(2, 3)*sqrt(y)**3), x, y) == Rational(-1, 6) + assert homogeneous_order(y/x*cos(y/x) - x/y*sin(y/x) + cos(y/x), x, y) == 0 + assert homogeneous_order(f(x), x, f(x)) == 1 + assert homogeneous_order(f(x)**2, x, f(x)) == 2 + assert homogeneous_order(x*y*z, x, y) == 2 + assert homogeneous_order(x*y*z, x, y, z) == 3 + assert homogeneous_order(x**2*f(x)/sqrt(x**2 + f(x)**2), f(x)) is None + assert homogeneous_order(f(x, y)**2, x, f(x, y), y) == 2 + assert homogeneous_order(f(x, y)**2, x, f(x), y) is None + assert homogeneous_order(f(x, y)**2, x, f(x, y)) is None + assert homogeneous_order(f(y, x)**2, x, y, f(x, y)) is None + assert homogeneous_order(f(y), f(x), x) is None + assert homogeneous_order(-f(x)/x + 1/sin(f(x)/ x), f(x), x) == 0 + assert homogeneous_order(log(1/y) + log(x**2), x, y) is None + assert homogeneous_order(log(1/y) + log(x), x, y) == 0 + assert homogeneous_order(log(x/y), x, y) == 0 + assert homogeneous_order(2*log(1/y) + 2*log(x), x, y) == 0 + a = Symbol('a') + assert homogeneous_order(a*log(1/y) + a*log(x), x, y) == 0 + assert homogeneous_order(f(x).diff(x), x, y) is None + assert homogeneous_order(-f(x).diff(x) + x, x, y) is None + assert homogeneous_order(O(x), x, y) is None + assert homogeneous_order(x + O(x**2), x, y) is None + assert homogeneous_order(x**pi, x) == pi + assert homogeneous_order(x**x, x) is None + raises(ValueError, lambda: homogeneous_order(x*y)) + + +@XFAIL +def test_noncircularized_real_imaginary_parts(): + # If this passes, lines numbered 3878-3882 (at the time of this commit) + # of sympy/solvers/ode.py for nth_linear_constant_coeff_homogeneous + # should be removed. + y = sqrt(1+x) + i, r = im(y), re(y) + assert not (i.has(atan2) and r.has(atan2)) + + +def test_collect_respecting_exponentials(): + # If this test passes, lines 1306-1311 (at the time of this commit) + # of sympy/solvers/ode.py should be removed. + sol = 1 + exp(x/2) + assert sol == collect( sol, exp(x/3)) + + +def test_undetermined_coefficients_match(): + assert _undetermined_coefficients_match(g(x), x) == {'test': False} + assert _undetermined_coefficients_match(sin(2*x + sqrt(5)), x) == \ + {'test': True, 'trialset': + {cos(2*x + sqrt(5)), sin(2*x + sqrt(5))}} + assert _undetermined_coefficients_match(sin(x)*cos(x), x) == \ + {'test': False} + s = {cos(x), x*cos(x), x**2*cos(x), x**2*sin(x), x*sin(x), sin(x)} + assert _undetermined_coefficients_match(sin(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': s} + assert _undetermined_coefficients_match( + sin(x)*x**2 + sin(x)*x + sin(x), x) == {'test': True, 'trialset': s} + assert _undetermined_coefficients_match( + exp(2*x)*sin(x)*(x**2 + x + 1), x + ) == { + 'test': True, 'trialset': {exp(2*x)*sin(x), x**2*exp(2*x)*sin(x), + cos(x)*exp(2*x), x**2*cos(x)*exp(2*x), x*cos(x)*exp(2*x), + x*exp(2*x)*sin(x)}} + assert _undetermined_coefficients_match(1/sin(x), x) == {'test': False} + assert _undetermined_coefficients_match(log(x), x) == {'test': False} + assert _undetermined_coefficients_match(2**(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': {2**x, x*2**x, x**2*2**x}} + assert _undetermined_coefficients_match(x**y, x) == {'test': False} + assert _undetermined_coefficients_match(exp(x)*exp(2*x + 1), x) == \ + {'test': True, 'trialset': {exp(1 + 3*x)}} + assert _undetermined_coefficients_match(sin(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': {x*cos(x), x*sin(x), x**2*cos(x), + x**2*sin(x), cos(x), sin(x)}} + assert _undetermined_coefficients_match(sin(x)*(x + sin(x)), x) == \ + {'test': False} + assert _undetermined_coefficients_match(sin(x)*(x + sin(2*x)), x) == \ + {'test': False} + assert _undetermined_coefficients_match(sin(x)*tan(x), x) == \ + {'test': False} + assert _undetermined_coefficients_match( + x**2*sin(x)*exp(x) + x*sin(x) + x, x + ) == { + 'test': True, 'trialset': {x**2*cos(x)*exp(x), x, cos(x), S.One, + exp(x)*sin(x), sin(x), x*exp(x)*sin(x), x*cos(x), x*cos(x)*exp(x), + x*sin(x), cos(x)*exp(x), x**2*exp(x)*sin(x)}} + assert _undetermined_coefficients_match(4*x*sin(x - 2), x) == { + 'trialset': {x*cos(x - 2), x*sin(x - 2), cos(x - 2), sin(x - 2)}, + 'test': True, + } + assert _undetermined_coefficients_match(2**x*x, x) == \ + {'test': True, 'trialset': {2**x, x*2**x}} + assert _undetermined_coefficients_match(2**x*exp(2*x), x) == \ + {'test': True, 'trialset': {2**x*exp(2*x)}} + assert _undetermined_coefficients_match(exp(-x)/x, x) == \ + {'test': False} + # Below are from Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 231 + assert _undetermined_coefficients_match(S(4), x) == \ + {'test': True, 'trialset': {S.One}} + assert _undetermined_coefficients_match(12*exp(x), x) == \ + {'test': True, 'trialset': {exp(x)}} + assert _undetermined_coefficients_match(exp(I*x), x) == \ + {'test': True, 'trialset': {exp(I*x)}} + assert _undetermined_coefficients_match(sin(x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x)}} + assert _undetermined_coefficients_match(cos(x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x)}} + assert _undetermined_coefficients_match(8 + 6*exp(x) + 2*sin(x), x) == \ + {'test': True, 'trialset': {S.One, cos(x), sin(x), exp(x)}} + assert _undetermined_coefficients_match(x**2, x) == \ + {'test': True, 'trialset': {S.One, x, x**2}} + assert _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(x), exp(x), exp(-x)}} + assert _undetermined_coefficients_match(2*exp(2*x)*sin(x), x) == \ + {'test': True, 'trialset': {exp(2*x)*sin(x), cos(x)*exp(2*x)}} + assert _undetermined_coefficients_match(x - sin(x), x) == \ + {'test': True, 'trialset': {S.One, x, cos(x), sin(x)}} + assert _undetermined_coefficients_match(x**2 + 2*x, x) == \ + {'test': True, 'trialset': {S.One, x, x**2}} + assert _undetermined_coefficients_match(4*x*sin(x), x) == \ + {'test': True, 'trialset': {x*cos(x), x*sin(x), cos(x), sin(x)}} + assert _undetermined_coefficients_match(x*sin(2*x), x) == \ + {'test': True, 'trialset': + {x*cos(2*x), x*sin(2*x), cos(2*x), sin(2*x)}} + assert _undetermined_coefficients_match(x**2*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), x**2*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(2*exp(-x) - x**2*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), x**2*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(exp(-2*x) + x**2, x) == \ + {'test': True, 'trialset': {S.One, x, x**2, exp(-2*x)}} + assert _undetermined_coefficients_match(x*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(x + exp(2*x), x) == \ + {'test': True, 'trialset': {S.One, x, exp(2*x)}} + assert _undetermined_coefficients_match(sin(x) + exp(-x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x), exp(-x)}} + assert _undetermined_coefficients_match(exp(x), x) == \ + {'test': True, 'trialset': {exp(x)}} + # converted from sin(x)**2 + assert _undetermined_coefficients_match(S.Half - cos(2*x)/2, x) == \ + {'test': True, 'trialset': {S.One, cos(2*x), sin(2*x)}} + # converted from exp(2*x)*sin(x)**2 + assert _undetermined_coefficients_match( + exp(2*x)*(S.Half + cos(2*x)/2), x + ) == { + 'test': True, 'trialset': {exp(2*x)*sin(2*x), cos(2*x)*exp(2*x), + exp(2*x)}} + assert _undetermined_coefficients_match(2*x + sin(x) + cos(x), x) == \ + {'test': True, 'trialset': {S.One, x, cos(x), sin(x)}} + # converted from sin(2*x)*sin(x) + assert _undetermined_coefficients_match(cos(x)/2 - cos(3*x)/2, x) == \ + {'test': True, 'trialset': {cos(x), cos(3*x), sin(x), sin(3*x)}} + assert _undetermined_coefficients_match(cos(x**2), x) == {'test': False} + assert _undetermined_coefficients_match(2**(x**2), x) == {'test': False} + + +def test_issue_4785_22462(): + from sympy.abc import A + eq = x + A*(x + diff(f(x), x) + f(x)) + diff(f(x), x) + f(x) + 2 + assert classify_ode(eq, f(x)) == ('factorable', '1st_exact', '1st_linear', + 'Bernoulli', 'almost_linear', '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_exact_Integral', '1st_linear_Integral', 'Bernoulli_Integral', + 'almost_linear_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + # issue 4864 + eq = (x**2 + f(x)**2)*f(x).diff(x) - 2*x*f(x) + assert classify_ode(eq, f(x)) == ('factorable', '1st_exact', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', + 'lie_group', '1st_exact_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + + +def test_issue_4825(): + raises(ValueError, lambda: dsolve(f(x, y).diff(x) - y*f(x, y), f(x))) + assert classify_ode(f(x, y).diff(x) - y*f(x, y), f(x), dict=True) == \ + {'order': 0, 'default': None, 'ordered_hints': ()} + # See also issue 3793, test Z13. + raises(ValueError, lambda: dsolve(f(x).diff(x), f(y))) + assert classify_ode(f(x).diff(x), f(y), dict=True) == \ + {'order': 0, 'default': None, 'ordered_hints': ()} + + +def test_constant_renumber_order_issue_5308(): + from sympy.utilities.iterables import variations + + assert constant_renumber(C1*x + C2*y) == \ + constant_renumber(C1*y + C2*x) == \ + C1*x + C2*y + e = C1*(C2 + x)*(C3 + y) + for a, b, c in variations([C1, C2, C3], 3): + assert constant_renumber(a*(b + x)*(c + y)) == e + + +def test_constant_renumber(): + e1, e2, x, y = symbols("e1:3 x y") + exprs = [e2*x, e1*x + e2*y] + + assert constant_renumber(exprs[0]) == e2*x + assert constant_renumber(exprs[0], variables=[x]) == C1*x + assert constant_renumber(exprs[0], variables=[x], newconstants=[C2]) == C2*x + assert constant_renumber(exprs, variables=[x, y]) == [C1*x, C1*y + C2*x] + assert constant_renumber(exprs, variables=[x, y], newconstants=symbols("C3:5")) == [C3*x, C3*y + C4*x] + + +def test_issue_5770(): + k = Symbol("k", real=True) + t = Symbol('t') + w = Function('w') + sol = dsolve(w(t).diff(t, 6) - k**6*w(t), w(t)) + assert len([s for s in sol.free_symbols if s.name.startswith('C')]) == 6 + assert constantsimp((C1*cos(x) + C2*cos(x))*exp(x), {C1, C2}) == \ + C1*cos(x)*exp(x) + assert constantsimp(C1*cos(x) + C2*cos(x) + C3*sin(x), {C1, C2, C3}) == \ + C1*cos(x) + C3*sin(x) + assert constantsimp(exp(C1 + x), {C1}) == C1*exp(x) + assert constantsimp(x + C1 + y, {C1, y}) == C1 + x + assert constantsimp(x + C1 + Integral(x, (x, 1, 2)), {C1}) == C1 + x + + +def test_issue_5112_5430(): + assert homogeneous_order(-log(x) + acosh(x), x) is None + assert homogeneous_order(y - log(x), x, y) is None + + +def test_issue_5095(): + f = Function('f') + raises(ValueError, lambda: dsolve(f(x).diff(x)**2, f(x), 'fdsjf')) + + +def test_homogeneous_function(): + f = Function('f') + eq1 = tan(x + f(x)) + eq2 = sin((3*x)/(4*f(x))) + eq3 = cos(x*f(x)*Rational(3, 4)) + eq4 = log((3*x + 4*f(x))/(5*f(x) + 7*x)) + eq5 = exp((2*x**2)/(3*f(x)**2)) + eq6 = log((3*x + 4*f(x))/(5*f(x) + 7*x) + exp((2*x**2)/(3*f(x)**2))) + eq7 = sin((3*x)/(5*f(x) + x**2)) + assert homogeneous_order(eq1, x, f(x)) == None + assert homogeneous_order(eq2, x, f(x)) == 0 + assert homogeneous_order(eq3, x, f(x)) == None + assert homogeneous_order(eq4, x, f(x)) == 0 + assert homogeneous_order(eq5, x, f(x)) == 0 + assert homogeneous_order(eq6, x, f(x)) == 0 + assert homogeneous_order(eq7, x, f(x)) == None + + +def test_linear_coeff_match(): + n, d = z*(2*x + 3*f(x) + 5), z*(7*x + 9*f(x) + 11) + rat = n/d + eq1 = sin(rat) + cos(rat.expand()) + obj1 = LinearCoefficients(eq1) + eq2 = rat + obj2 = LinearCoefficients(eq2) + eq3 = log(sin(rat)) + obj3 = LinearCoefficients(eq3) + ans = (4, Rational(-13, 3)) + assert obj1._linear_coeff_match(eq1, f(x)) == ans + assert obj2._linear_coeff_match(eq2, f(x)) == ans + assert obj3._linear_coeff_match(eq3, f(x)) == ans + + # no c + eq4 = (3*x)/f(x) + obj4 = LinearCoefficients(eq4) + # not x and f(x) + eq5 = (3*x + 2)/x + obj5 = LinearCoefficients(eq5) + # denom will be zero + eq6 = (3*x + 2*f(x) + 1)/(3*x + 2*f(x) + 5) + obj6 = LinearCoefficients(eq6) + # not rational coefficient + eq7 = (3*x + 2*f(x) + sqrt(2))/(3*x + 2*f(x) + 5) + obj7 = LinearCoefficients(eq7) + assert obj4._linear_coeff_match(eq4, f(x)) is None + assert obj5._linear_coeff_match(eq5, f(x)) is None + assert obj6._linear_coeff_match(eq6, f(x)) is None + assert obj7._linear_coeff_match(eq7, f(x)) is None + + +def test_constantsimp_take_problem(): + c = exp(C1) + 2 + assert len(Poly(constantsimp(exp(C1) + c + c*x, [C1])).gens) == 2 + + +def test_series(): + C1 = Symbol("C1") + eq = f(x).diff(x) - f(x) + sol = Eq(f(x), C1 + C1*x + C1*x**2/2 + C1*x**3/6 + C1*x**4/24 + + C1*x**5/120 + O(x**6)) + assert dsolve(eq, hint='1st_power_series') == sol + assert checkodesol(eq, sol, order=1)[0] + + eq = f(x).diff(x) - x*f(x) + sol = Eq(f(x), C1*x**4/8 + C1*x**2/2 + C1 + O(x**6)) + assert dsolve(eq, hint='1st_power_series') == sol + assert checkodesol(eq, sol, order=1)[0] + + eq = f(x).diff(x) - sin(x*f(x)) + sol = Eq(f(x), (x - 2)**2*(1+ sin(4))*cos(4) + (x - 2)*sin(4) + 2 + O(x**3)) + assert dsolve(eq, hint='1st_power_series', ics={f(2): 2}, n=3) == sol + # FIXME: The solution here should be O((x-2)**3) so is incorrect + #assert checkodesol(eq, sol, order=1)[0] + + +@slow +def test_2nd_power_series_ordinary(): + C1, C2 = symbols("C1 C2") + + eq = f(x).diff(x, 2) - x*f(x) + assert classify_ode(eq) == ('2nd_linear_airy', '2nd_power_series_ordinary') + sol = Eq(f(x), C2*(x**3/6 + 1) + C1*x*(x**3/12 + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary') == sol + assert checkodesol(eq, sol) == (True, 0) + + sol = Eq(f(x), C2*((x + 2)**4/6 + (x + 2)**3/6 - (x + 2)**2 + 1) + + C1*(x + (x + 2)**4/12 - (x + 2)**3/3 + S(2)) + + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary', x0=-2) == sol + # FIXME: Solution should be O((x+2)**6) + # assert checkodesol(eq, sol) == (True, 0) + + sol = Eq(f(x), C2*x + C1 + O(x**2)) + assert dsolve(eq, hint='2nd_power_series_ordinary', n=2) == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = (1 + x**2)*(f(x).diff(x, 2)) + 2*x*(f(x).diff(x)) -2*f(x) + assert classify_ode(eq) == ('factorable', '2nd_hypergeometric', '2nd_hypergeometric_Integral', + '2nd_power_series_ordinary') + + sol = Eq(f(x), C2*(-x**4/3 + x**2 + 1) + C1*x + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + x*(f(x).diff(x)) + f(x) + assert classify_ode(eq) == ('factorable', '2nd_power_series_ordinary',) + sol = Eq(f(x), C2*(x**4/8 - x**2/2 + 1) + C1*x*(-x**2/3 + 1) + O(x**6)) + assert dsolve(eq) == sol + # FIXME: checkodesol fails for this solution... + # assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + f(x).diff(x) - x*f(x) + assert classify_ode(eq) == ('2nd_power_series_ordinary',) + sol = Eq(f(x), C2*(-x**4/24 + x**3/6 + 1) + + C1*x*(x**3/24 + x**2/6 - x/2 + 1) + O(x**6)) + assert dsolve(eq) == sol + # FIXME: checkodesol fails for this solution... + # assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + x*f(x) + assert classify_ode(eq) == ('2nd_linear_airy', '2nd_power_series_ordinary') + sol = Eq(f(x), C2*(x**6/180 - x**3/6 + 1) + C1*x*(-x**3/12 + 1) + O(x**7)) + assert dsolve(eq, hint='2nd_power_series_ordinary', n=7) == sol + assert checkodesol(eq, sol) == (True, 0) + + +def test_2nd_power_series_regular(): + C1, C2, a = symbols("C1 C2 a") + eq = x**2*(f(x).diff(x, 2)) - 3*x*(f(x).diff(x)) + (4*x + 4)*f(x) + sol = Eq(f(x), C1*x**2*(-16*x**3/9 + 4*x**2 - 4*x + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = 4*x**2*(f(x).diff(x, 2)) -8*x**2*(f(x).diff(x)) + (4*x**2 + + 1)*f(x) + sol = Eq(f(x), C1*sqrt(x)*(x**4/24 + x**3/6 + x**2/2 + x + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x**2*(f(x).diff(x, 2)) - x**2*(f(x).diff(x)) + ( + x**2 - 2)*f(x) + sol = Eq(f(x), C1*(-x**6/720 - 3*x**5/80 - x**4/8 + x**2/2 + x/2 + 1)/x + + C2*x**2*(-x**3/60 + x**2/20 + x/2 + 1) + O(x**6)) + assert dsolve(eq) == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - Rational(1, 4))*f(x) + sol = Eq(f(x), C1*(x**4/24 - x**2/2 + 1)/sqrt(x) + + C2*sqrt(x)*(x**4/120 - x**2/6 + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x*f(x).diff(x, 2) + f(x).diff(x) - a*x*f(x) + sol = Eq(f(x), C1*(a**2*x**4/64 + a*x**2/4 + 1) + O(x**6)) + assert dsolve(eq, f(x), hint="2nd_power_series_regular") == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + ((1 - x)/x)*f(x).diff(x) + (a/x)*f(x) + sol = Eq(f(x), C1*(-a*x**5*(a - 4)*(a - 3)*(a - 2)*(a - 1)/14400 + \ + a*x**4*(a - 3)*(a - 2)*(a - 1)/576 - a*x**3*(a - 2)*(a - 1)/36 + \ + a*x**2*(a - 1)/4 - a*x + 1) + O(x**6)) + assert dsolve(eq, f(x), hint="2nd_power_series_regular") == sol + assert checkodesol(eq, sol) == (True, 0) + + +def test_issue_15056(): + t = Symbol('t') + C3 = Symbol('C3') + assert get_numbered_constants(Symbol('C1') * Function('C2')(t)) == C3 + + +def test_issue_15913(): + eq = -C1/x - 2*x*f(x) - f(x) + Derivative(f(x), x) + sol = C2*exp(x**2 + x) + exp(x**2 + x)*Integral(C1*exp(-x**2 - x)/x, x) + assert checkodesol(eq, sol) == (True, 0) + sol = C1 + C2*exp(-x*y) + eq = Derivative(y*f(x), x) + f(x).diff(x, 2) + assert checkodesol(eq, sol, f(x)) == (True, 0) + + +def test_issue_16146(): + raises(ValueError, lambda: dsolve([f(x).diff(x), g(x).diff(x)], [f(x), g(x), h(x)])) + raises(ValueError, lambda: dsolve([f(x).diff(x), g(x).diff(x)], [f(x)])) + + +def test_dsolve_remove_redundant_solutions(): + + eq = (f(x)-2)*f(x).diff(x) + sol = Eq(f(x), C1) + assert dsolve(eq) == sol + + eq = (f(x)-sin(x))*(f(x).diff(x, 2)) + sol = {Eq(f(x), C1 + C2*x), Eq(f(x), sin(x))} + assert set(dsolve(eq)) == sol + + eq = (f(x)**2-2*f(x)+1)*f(x).diff(x, 3) + sol = Eq(f(x), C1 + C2*x + C3*x**2) + assert dsolve(eq) == sol + + +def test_issue_13060(): + A, B = symbols("A B", cls=Function) + t = Symbol("t") + eq = [Eq(Derivative(A(t), t), A(t)*B(t)), Eq(Derivative(B(t), t), A(t)*B(t))] + sol = dsolve(eq) + assert checkodesol(eq, sol) == (True, [0, 0]) + + +def test_issue_22523(): + N, s = symbols('N s') + rho = Function('rho') + # intentionally use 4.0 to confirm issue with nfloat + # works here + eqn = 4.0*N*sqrt(N - 1)*rho(s) + (4*s**2*(N - 1) + (N - 2*s*(N - 1))**2 + )*Derivative(rho(s), (s, 2)) + match = classify_ode(eqn, dict=True, hint='all') + assert match['2nd_power_series_ordinary']['terms'] == 5 + C1, C2 = symbols('C1,C2') + sol = dsolve(eqn, hint='2nd_power_series_ordinary') + # there is no r(2.0) in this result + assert filldedent(sol) == filldedent(str(''' + Eq(rho(s), C2*(1 - 4.0*s**4*sqrt(N - 1.0)/N + 0.666666666666667*s**4/N + - 2.66666666666667*s**3*sqrt(N - 1.0)/N - 2.0*s**2*sqrt(N - 1.0)/N + + 9.33333333333333*s**4*sqrt(N - 1.0)/N**2 - 0.666666666666667*s**4/N**2 + + 2.66666666666667*s**3*sqrt(N - 1.0)/N**2 - + 5.33333333333333*s**4*sqrt(N - 1.0)/N**3) + C1*s*(1.0 - + 1.33333333333333*s**3*sqrt(N - 1.0)/N - 0.666666666666667*s**2*sqrt(N + - 1.0)/N + 1.33333333333333*s**3*sqrt(N - 1.0)/N**2) + O(s**6))''')) + + +def test_issue_22604(): + x1, x2 = symbols('x1, x2', cls = Function) + t, k1, k2, m1, m2 = symbols('t k1 k2 m1 m2', real = True) + k1, k2, m1, m2 = 1, 1, 1, 1 + eq1 = Eq(m1*diff(x1(t), t, 2) + k1*x1(t) - k2*(x2(t) - x1(t)), 0) + eq2 = Eq(m2*diff(x2(t), t, 2) + k2*(x2(t) - x1(t)), 0) + eqs = [eq1, eq2] + [x1sol, x2sol] = dsolve(eqs, [x1(t), x2(t)], ics = {x1(0):0, x1(t).diff().subs(t,0):0, \ + x2(0):1, x2(t).diff().subs(t,0):0}) + assert x1sol == Eq(x1(t), sqrt(3 - sqrt(5))*(sqrt(10) + 5*sqrt(2))*cos(sqrt(2)*t*sqrt(3 - sqrt(5))/2)/20 + \ + (-5*sqrt(2) + sqrt(10))*sqrt(sqrt(5) + 3)*cos(sqrt(2)*t*sqrt(sqrt(5) + 3)/2)/20) + assert x2sol == Eq(x2(t), (sqrt(5) + 5)*cos(sqrt(2)*t*sqrt(3 - sqrt(5))/2)/10 + (5 - sqrt(5))*cos(sqrt(2)*t*sqrt(sqrt(5) + 3)/2)/10) + + +def test_issue_22462(): + for de in [ + Eq(f(x).diff(x), -20*f(x)**2 - 500*f(x)/7200), + Eq(f(x).diff(x), -2*f(x)**2 - 5*f(x)/7)]: + assert 'Bernoulli' in classify_ode(de, f(x)) + + +def test_issue_23425(): + x = symbols('x') + y = Function('y') + eq = Eq(-E**x*y(x).diff().diff() + y(x).diff(), 0) + assert classify_ode(eq) == \ + ('Liouville', 'nth_order_reducible', \ + '2nd_power_series_ordinary', 'Liouville_Integral') + + +@SKIP("too slow for @slow") +def test_issue_25820(): + x = Symbol('x') + y = Function('y') + eq = y(x)**3*y(x).diff(x, 2) + 49 + assert dsolve(eq, y(x)) is not None # doesn't raise diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_riccati.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_riccati.py new file mode 100644 index 0000000000000000000000000000000000000000..548a1ee5b5e82d88f1b0aa319af09b8b9d1d9bfe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_riccati.py @@ -0,0 +1,877 @@ +from sympy.core.random import randint +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, oo) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.polys.polytools import Poly +from sympy.simplify.ratsimp import ratsimp +from sympy.solvers.ode.subscheck import checkodesol +from sympy.testing.pytest import slow +from sympy.solvers.ode.riccati import (riccati_normal, riccati_inverse_normal, + riccati_reduced, match_riccati, inverse_transform_poly, limit_at_inf, + check_necessary_conds, val_at_inf, construct_c_case_1, + construct_c_case_2, construct_c_case_3, construct_d_case_4, + construct_d_case_5, construct_d_case_6, rational_laurent_series, + solve_riccati) + +f = Function('f') +x = symbols('x') + +# These are the functions used to generate the tests +# SHOULD NOT BE USED DIRECTLY IN TESTS + +def rand_rational(maxint): + return Rational(randint(-maxint, maxint), randint(1, maxint)) + + +def rand_poly(x, degree, maxint): + return Poly([rand_rational(maxint) for _ in range(degree+1)], x) + + +def rand_rational_function(x, degree, maxint): + degnum = randint(1, degree) + degden = randint(1, degree) + num = rand_poly(x, degnum, maxint) + den = rand_poly(x, degden, maxint) + while den == Poly(0, x): + den = rand_poly(x, degden, maxint) + return num / den + + +def find_riccati_ode(ratfunc, x, yf): + y = ratfunc + yp = y.diff(x) + q1 = rand_rational_function(x, 1, 3) + q2 = rand_rational_function(x, 1, 3) + while q2 == 0: + q2 = rand_rational_function(x, 1, 3) + q0 = ratsimp(yp - q1*y - q2*y**2) + eq = Eq(yf.diff(), q0 + q1*yf + q2*yf**2) + sol = Eq(yf, y) + assert checkodesol(eq, sol) == (True, 0) + return eq, q0, q1, q2 + + +# Testing functions start + +def test_riccati_transformation(): + """ + This function tests the transformation of the + solution of a Riccati ODE to the solution of + its corresponding normal Riccati ODE. + + Each test case 4 values - + + 1. w - The solution to be transformed + 2. b1 - The coefficient of f(x) in the ODE. + 3. b2 - The coefficient of f(x)**2 in the ODE. + 4. y - The solution to the normal Riccati ODE. + """ + tests = [ + ( + x/(x - 1), + (x**2 + 7)/3*x, + x, + -x**2/(x - 1) - x*(x**2/3 + S(7)/3)/2 - 1/(2*x) + ), + ( + (2*x + 3)/(2*x + 2), + (3 - 3*x)/(x + 1), + 5*x, + -5*x*(2*x + 3)/(2*x + 2) - (3 - 3*x)/(Mul(2, x + 1, evaluate=False)) - 1/(2*x) + ), + ( + -1/(2*x**2 - 1), + 0, + (2 - x)/(4*x - 2), + (2 - x)/((4*x - 2)*(2*x**2 - 1)) - (4*x - 2)*(Mul(-4, 2 - x, evaluate=False)/(4*x - \ + 2)**2 - 1/(4*x - 2))/(Mul(2, 2 - x, evaluate=False)) + ), + ( + x, + (8*x - 12)/(12*x + 9), + x**3/(6*x - 9), + -x**4/(6*x - 9) - (8*x - 12)/(Mul(2, 12*x + 9, evaluate=False)) - (6*x - 9)*(-6*x**3/(6*x \ + - 9)**2 + 3*x**2/(6*x - 9))/(2*x**3) + )] + for w, b1, b2, y in tests: + assert y == riccati_normal(w, x, b1, b2) + assert w == riccati_inverse_normal(y, x, b1, b2).cancel() + + # Test bp parameter in riccati_inverse_normal + tests = [ + ( + (-2*x - 1)/(2*x**2 + 2*x - 2), + -2/x, + (-x - 1)/(4*x), + 8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1), + -2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) - (-2*x - 1)*(-x - 1)/(4*x*(2*x**2 + 2*x \ + - 2)) + 1/x + ), + ( + 3/(2*x**2), + -2/x, + (-x - 1)/(4*x), + 8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1), + -2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) + 1/x - Mul(3, -x - 1, evaluate=False)/(8*x**3) + )] + for w, b1, b2, bp, y in tests: + assert y == riccati_normal(w, x, b1, b2) + assert w == riccati_inverse_normal(y, x, b1, b2, bp).cancel() + + +def test_riccati_reduced(): + """ + This function tests the transformation of a + Riccati ODE to its normal Riccati ODE. + + Each test case 2 values - + + 1. eq - A Riccati ODE. + 2. normal_eq - The normal Riccati ODE of eq. + """ + tests = [ + ( + f(x).diff(x) - x**2 - x*f(x) - x*f(x)**2, + + f(x).diff(x) + f(x)**2 + x**3 - x**2/4 - 3/(4*x**2) + ), + ( + 6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)**2/x, + + -3*x**2*(1/x + (-x - 1)/x**2)**2/(4*(-x - 1)**2) + Mul(6, \ + -x - 1, evaluate=False)/(2*x + 9) + f(x)**2 + f(x).diff(x) \ + - (-1 + (x + 1)/x)/(x*(-x - 1)) + ), + ( + f(x)**2 + f(x).diff(x) - (x - 1)*f(x)/(-x - S(1)/2), + + -(2*x - 2)**2/(4*(2*x + 1)**2) + (2*x - 2)/(2*x + 1)**2 + \ + f(x)**2 + f(x).diff(x) - 1/(2*x + 1) + ), + ( + f(x).diff(x) - f(x)**2/x, + + f(x)**2 + f(x).diff(x) + 1/(4*x**2) + ), + ( + -3*(-x**2 - x + 1)/(x**2 + 6*x + 1) + f(x).diff(x) + f(x)**2/x, + + f(x)**2 + f(x).diff(x) + (3*x**2/(x**2 + 6*x + 1) + 3*x/(x**2 \ + + 6*x + 1) - 3/(x**2 + 6*x + 1))/x + 1/(4*x**2) + ), + ( + 6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)/x, + + False + ), + ( + f(x)*f(x).diff(x) - 1/x + f(x)/3 + f(x)**2/(x**2 - 2), + + False + )] + for eq, normal_eq in tests: + assert normal_eq == riccati_reduced(eq, f, x) + + +def test_match_riccati(): + """ + This function tests if an ODE is Riccati or not. + + Each test case has 5 values - + + 1. eq - The Riccati ODE. + 2. match - Boolean indicating if eq is a Riccati ODE. + 3. b0 - + 4. b1 - Coefficient of f(x) in eq. + 5. b2 - Coefficient of f(x)**2 in eq. + """ + tests = [ + # Test Rational Riccati ODEs + ( + f(x).diff(x) - (405*x**3 - 882*x**2 - 78*x + 92)/(243*x**4 \ + - 945*x**3 + 846*x**2 + 180*x - 72) - 2 - f(x)**2/(3*x + 1) \ + - (S(1)/3 - x)*f(x)/(S(1)/3 - 3*x/2), + + True, + + 45*x**3/(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 98*x**2/ \ + (27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 26*x/(81*x**4 - \ + 315*x**3 + 282*x**2 + 60*x - 24) + 2 + 92/(243*x**4 - 945*x**3 \ + + 846*x**2 + 180*x - 72), + + Mul(-1, 2 - 6*x, evaluate=False)/(9*x - 2), + + 1/(3*x + 1) + ), + ( + f(x).diff(x) + 4*x/27 - (x/3 - 1)*f(x)**2 - (2*x/3 + \ + 1)*f(x)/(3*x + 2) - S(10)/27 - (265*x**2 + 423*x + 162) \ + /(324*x**3 + 216*x**2), + + True, + + -4*x/27 + S(10)/27 + 3/(6*x**3 + 4*x**2) + 47/(36*x**2 \ + + 24*x) + 265/(324*x + 216), + + Mul(-1, -2*x - 3, evaluate=False)/(9*x + 6), + + x/3 - 1 + ), + ( + f(x).diff(x) - (304*x**5 - 745*x**4 + 631*x**3 - 876*x**2 \ + + 198*x - 108)/(36*x**6 - 216*x**5 + 477*x**4 - 567*x**3 + \ + 360*x**2 - 108*x) - S(17)/9 - (x - S(3)/2)*f(x)/(x/2 - \ + S(3)/2) - (x/3 - 3)*f(x)**2/(3*x), + + True, + + 304*x**4/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + 360*x - \ + 108) - 745*x**3/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + \ + 360*x - 108) + 631*x**2/(36*x**5 - 216*x**4 + 477*x**3 - 567* \ + x**2 + 360*x - 108) - 292*x/(12*x**5 - 72*x**4 + 159*x**3 - \ + 189*x**2 + 120*x - 36) + S(17)/9 - 12/(4*x**6 - 24*x**5 + \ + 53*x**4 - 63*x**3 + 40*x**2 - 12*x) + 22/(4*x**5 - 24*x**4 \ + + 53*x**3 - 63*x**2 + 40*x - 12), + + Mul(-1, 3 - 2*x, evaluate=False)/(x - 3), + + Mul(-1, 9 - x, evaluate=False)/(9*x) + ), + # Test Non-Rational Riccati ODEs + ( + f(x).diff(x) - x**(S(3)/2)/(x**(S(1)/2) - 2) + x**2*f(x) + \ + x*f(x)**2/(x**(S(3)/4)), + False, 0, 0, 0 + ), + ( + f(x).diff(x) - sin(x**2) + exp(x)*f(x) + log(x)*f(x)**2, + False, 0, 0, 0 + ), + ( + f(x).diff(x) - tanh(x + sqrt(x)) + f(x) + x**4*f(x)**2, + False, 0, 0, 0 + ), + # Test Non-Riccati ODEs + ( + (1 - x**2)*f(x).diff(x, 2) - 2*x*f(x).diff(x) + 20*f(x), + False, 0, 0, 0 + ), + ( + f(x).diff(x) - x**2 + x**3*f(x) + (x**2/(x + 1))*f(x)**3, + False, 0, 0, 0 + ), + ( + f(x).diff(x)*f(x)**2 + (x**2 - 1)/(x**3 + 1)*f(x) + 1/(2*x \ + + 3) + f(x)**2, + False, 0, 0, 0 + )] + for eq, res, b0, b1, b2 in tests: + match, funcs = match_riccati(eq, f, x) + assert match == res + if res: + assert [b0, b1, b2] == funcs + + +def test_val_at_inf(): + """ + This function tests the valuation of rational + function at oo. + + Each test case has 3 values - + + 1. num - Numerator of rational function. + 2. den - Denominator of rational function. + 3. val_inf - Valuation of rational function at oo + """ + tests = [ + # degree(denom) > degree(numer) + ( + Poly(10*x**3 + 8*x**2 - 13*x + 6, x), + Poly(-13*x**10 - x**9 + 5*x**8 + 7*x**7 + 10*x**6 + 6*x**5 - 7*x**4 + 11*x**3 - 8*x**2 + 5*x + 13, x), + 7 + ), + ( + Poly(1, x), + Poly(-9*x**4 + 3*x**3 + 15*x**2 - 6*x - 14, x), + 4 + ), + # degree(denom) == degree(numer) + ( + Poly(-6*x**3 - 8*x**2 + 8*x - 6, x), + Poly(-5*x**3 + 12*x**2 - 6*x - 9, x), + 0 + ), + # degree(denom) < degree(numer) + ( + Poly(12*x**8 - 12*x**7 - 11*x**6 + 8*x**5 + 3*x**4 - x**3 + x**2 - 11*x, x), + Poly(-14*x**2 + x, x), + -6 + ), + ( + Poly(5*x**6 + 9*x**5 - 11*x**4 - 9*x**3 + x**2 - 4*x + 4, x), + Poly(15*x**4 + 3*x**3 - 8*x**2 + 15*x + 12, x), + -2 + )] + for num, den, val in tests: + assert val_at_inf(num, den, x) == val + + +def test_necessary_conds(): + """ + This function tests the necessary conditions for + a Riccati ODE to have a rational particular solution. + """ + # Valuation at Infinity is an odd negative integer + assert check_necessary_conds(-3, [1, 2, 4]) == False + # Valuation at Infinity is a positive integer lesser than 2 + assert check_necessary_conds(1, [1, 2, 4]) == False + # Multiplicity of a pole is an odd integer greater than 1 + assert check_necessary_conds(2, [3, 1, 6]) == False + # All values are correct + assert check_necessary_conds(-10, [1, 2, 8, 12]) == True + + +def test_inverse_transform_poly(): + """ + This function tests the substitution x -> 1/x + in rational functions represented using Poly. + """ + fns = [ + (15*x**3 - 8*x**2 - 2*x - 6)/(18*x + 6), + + (180*x**5 + 40*x**4 + 80*x**3 + 30*x**2 - 60*x - 80)/(180*x**3 - 150*x**2 + 75*x + 12), + + (-15*x**5 - 36*x**4 + 75*x**3 - 60*x**2 - 80*x - 60)/(80*x**4 + 60*x**3 + 60*x**2 + 60*x - 80), + + (60*x**7 + 24*x**6 - 15*x**5 - 20*x**4 + 30*x**2 + 100*x - 60)/(240*x**2 - 20*x - 30), + + (30*x**6 - 12*x**5 + 15*x**4 - 15*x**2 + 10*x + 60)/(3*x**10 - 45*x**9 + 15*x**5 + 15*x**4 - 5*x**3 \ + + 15*x**2 + 45*x - 15) + ] + for f in fns: + num, den = [Poly(e, x) for e in f.as_numer_denom()] + num, den = inverse_transform_poly(num, den, x) + assert f.subs(x, 1/x).cancel() == num/den + + +def test_limit_at_inf(): + """ + This function tests the limit at oo of a + rational function. + + Each test case has 3 values - + + 1. num - Numerator of rational function. + 2. den - Denominator of rational function. + 3. limit_at_inf - Limit of rational function at oo + """ + tests = [ + # deg(denom) > deg(numer) + ( + Poly(-12*x**2 + 20*x + 32, x), + Poly(32*x**3 + 72*x**2 + 3*x - 32, x), + 0 + ), + # deg(denom) < deg(numer) + ( + Poly(1260*x**4 - 1260*x**3 - 700*x**2 - 1260*x + 1400, x), + Poly(6300*x**3 - 1575*x**2 + 756*x - 540, x), + oo + ), + # deg(denom) < deg(numer), one of the leading coefficients is negative + ( + Poly(-735*x**8 - 1400*x**7 + 1680*x**6 - 315*x**5 - 600*x**4 + 840*x**3 - 525*x**2 \ + + 630*x + 3780, x), + Poly(1008*x**7 - 2940*x**6 - 84*x**5 + 2940*x**4 - 420*x**3 + 1512*x**2 + 105*x + 168, x), + -oo + ), + # deg(denom) == deg(numer) + ( + Poly(105*x**7 - 960*x**6 + 60*x**5 + 60*x**4 - 80*x**3 + 45*x**2 + 120*x + 15, x), + Poly(735*x**7 + 525*x**6 + 720*x**5 + 720*x**4 - 8400*x**3 - 2520*x**2 + 2800*x + 280, x), + S(1)/7 + ), + ( + Poly(288*x**4 - 450*x**3 + 280*x**2 - 900*x - 90, x), + Poly(607*x**4 + 840*x**3 - 1050*x**2 + 420*x + 420, x), + S(288)/607 + )] + for num, den, lim in tests: + assert limit_at_inf(num, den, x) == lim + + +def test_construct_c_case_1(): + """ + This function tests the Case 1 in the step + to calculate coefficients of c-vectors. + + Each test case has 4 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. pole - Pole of a(x) for which c-vector is being + calculated. + 4. c - The c-vector for the pole. + """ + tests = [ + ( + Poly(-3*x**3 + 3*x**2 + 4*x - 5, x, extension=True), + Poly(4*x**8 + 16*x**7 + 9*x**5 + 12*x**4 + 6*x**3 + 12*x**2, x, extension=True), + S(0), + [[S(1)/2 + sqrt(6)*I/6], [S(1)/2 - sqrt(6)*I/6]] + ), + ( + Poly(1200*x**3 + 1440*x**2 + 816*x + 560, x, extension=True), + Poly(128*x**5 - 656*x**4 + 1264*x**3 - 1125*x**2 + 385*x + 49, x, extension=True), + S(7)/4, + [[S(1)/2 + sqrt(16367978)/634], [S(1)/2 - sqrt(16367978)/634]] + ), + ( + Poly(4*x + 2, x, extension=True), + Poly(18*x**4 + (2 - 18*sqrt(3))*x**3 + (14 - 11*sqrt(3))*x**2 + (4 - 6*sqrt(3))*x \ + + 8*sqrt(3) + 16, x, domain='QQ'), + (S(1) + sqrt(3))/2, + [[S(1)/2 + sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2], \ + [S(1)/2 - sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2]] + )] + for num, den, pole, c in tests: + assert construct_c_case_1(num, den, x, pole) == c + + +def test_construct_c_case_2(): + """ + This function tests the Case 2 in the step + to calculate coefficients of c-vectors. + + Each test case has 5 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. pole - Pole of a(x) for which c-vector is being + calculated. + 4. mul - The multiplicity of the pole. + 5. c - The c-vector for the pole. + """ + tests = [ + # Testing poles with multiplicity 2 + ( + Poly(1, x, extension=True), + Poly((x - 1)**2*(x - 2), x, extension=True), + 1, 2, + [[-I*(-1 - I)/2], [I*(-1 + I)/2]] + ), + ( + Poly(3*x**5 - 12*x**4 - 7*x**3 + 1, x, extension=True), + Poly((3*x - 1)**2*(x + 2)**2, x, extension=True), + S(1)/3, 2, + [[-S(89)/98], [-S(9)/98]] + ), + # Testing poles with multiplicity 4 + ( + Poly(x**3 - x**2 + 4*x, x, extension=True), + Poly((x - 2)**4*(x + 5)**2, x, extension=True), + 2, 4, + [[7*sqrt(3)*(S(60)/343 - 4*sqrt(3)/7)/12, 2*sqrt(3)/7], \ + [-7*sqrt(3)*(S(60)/343 + 4*sqrt(3)/7)/12, -2*sqrt(3)/7]] + ), + ( + Poly(3*x**5 + x**4 + 3, x, extension=True), + Poly((4*x + 1)**4*(x + 2), x, extension=True), + -S(1)/4, 4, + [[128*sqrt(439)*(-sqrt(439)/128 - S(55)/14336)/439, sqrt(439)/256], \ + [-128*sqrt(439)*(sqrt(439)/128 - S(55)/14336)/439, -sqrt(439)/256]] + ), + # Testing poles with multiplicity 6 + ( + Poly(x**3 + 2, x, extension=True), + Poly((3*x - 1)**6*(x**2 + 1), x, extension=True), + S(1)/3, 6, + [[27*sqrt(66)*(-sqrt(66)/54 - S(131)/267300)/22, -2*sqrt(66)/1485, sqrt(66)/162], \ + [-27*sqrt(66)*(sqrt(66)/54 - S(131)/267300)/22, 2*sqrt(66)/1485, -sqrt(66)/162]] + ), + ( + Poly(x**2 + 12, x, extension=True), + Poly((x - sqrt(2))**6, x, extension=True), + sqrt(2), 6, + [[sqrt(14)*(S(6)/7 - 3*sqrt(14))/28, sqrt(7)/7, sqrt(14)], \ + [-sqrt(14)*(S(6)/7 + 3*sqrt(14))/28, -sqrt(7)/7, -sqrt(14)]] + )] + for num, den, pole, mul, c in tests: + assert construct_c_case_2(num, den, x, pole, mul) == c + + +def test_construct_c_case_3(): + """ + This function tests the Case 3 in the step + to calculate coefficients of c-vectors. + """ + assert construct_c_case_3() == [[1]] + + +def test_construct_d_case_4(): + """ + This function tests the Case 4 in the step + to calculate coefficients of the d-vector. + + Each test case has 4 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. mul - Multiplicity of oo as a pole. + 4. d - The d-vector. + """ + tests = [ + # Tests with multiplicity at oo = 2 + ( + Poly(-x**5 - 2*x**4 + 4*x**3 + 2*x + 5, x, extension=True), + Poly(9*x**3 - 2*x**2 + 10*x - 2, x, extension=True), + 2, + [[10*I/27, I/3, -3*I*(S(158)/243 - I/3)/2], \ + [-10*I/27, -I/3, 3*I*(S(158)/243 + I/3)/2]] + ), + ( + Poly(-x**6 + 9*x**5 + 5*x**4 + 6*x**3 + 5*x**2 + 6*x + 7, x, extension=True), + Poly(x**4 + 3*x**3 + 12*x**2 - x + 7, x, extension=True), + 2, + [[-6*I, I, -I*(17 - I)/2], [6*I, -I, I*(17 + I)/2]] + ), + # Tests with multiplicity at oo = 4 + ( + Poly(-2*x**6 - x**5 - x**4 - 2*x**3 - x**2 - 3*x - 3, x, extension=True), + Poly(3*x**2 + 10*x + 7, x, extension=True), + 4, + [[269*sqrt(6)*I/288, -17*sqrt(6)*I/36, sqrt(6)*I/3, -sqrt(6)*I*(S(16969)/2592 \ + - 2*sqrt(6)*I/3)/4], [-269*sqrt(6)*I/288, 17*sqrt(6)*I/36, -sqrt(6)*I/3, \ + sqrt(6)*I*(S(16969)/2592 + 2*sqrt(6)*I/3)/4]] + ), + ( + Poly(-3*x**5 - 3*x**4 - 3*x**3 - x**2 - 1, x, extension=True), + Poly(12*x - 2, x, extension=True), + 4, + [[41*I/192, 7*I/24, I/2, -I*(-S(59)/6912 - I)], \ + [-41*I/192, -7*I/24, -I/2, I*(-S(59)/6912 + I)]] + ), + # Tests with multiplicity at oo = 4 + ( + Poly(-x**7 - x**5 - x**4 - x**2 - x, x, extension=True), + Poly(x + 2, x, extension=True), + 6, + [[-5*I/2, 2*I, -I, I, -I*(-9 - 3*I)/2], [5*I/2, -2*I, I, -I, I*(-9 + 3*I)/2]] + ), + ( + Poly(-x**7 - x**6 - 2*x**5 - 2*x**4 - x**3 - x**2 + 2*x - 2, x, extension=True), + Poly(2*x - 2, x, extension=True), + 6, + [[3*sqrt(2)*I/4, 3*sqrt(2)*I/4, sqrt(2)*I/2, sqrt(2)*I/2, -sqrt(2)*I*(-S(7)/8 - \ + 3*sqrt(2)*I/2)/2], [-3*sqrt(2)*I/4, -3*sqrt(2)*I/4, -sqrt(2)*I/2, -sqrt(2)*I/2, \ + sqrt(2)*I*(-S(7)/8 + 3*sqrt(2)*I/2)/2]] + )] + for num, den, mul, d in tests: + ser = rational_laurent_series(num, den, x, oo, mul, 1) + assert construct_d_case_4(ser, mul//2) == d + + +def test_construct_d_case_5(): + """ + This function tests the Case 5 in the step + to calculate coefficients of the d-vector. + + Each test case has 3 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. d - The d-vector. + """ + tests = [ + ( + Poly(2*x**3 + x**2 + x - 2, x, extension=True), + Poly(9*x**3 + 5*x**2 + 2*x - 1, x, extension=True), + [[sqrt(2)/3, -sqrt(2)/108], [-sqrt(2)/3, sqrt(2)/108]] + ), + ( + Poly(3*x**5 + x**4 - x**3 + x**2 - 2*x - 2, x, domain='ZZ'), + Poly(9*x**5 + 7*x**4 + 3*x**3 + 2*x**2 + 5*x + 7, x, domain='ZZ'), + [[sqrt(3)/3, -2*sqrt(3)/27], [-sqrt(3)/3, 2*sqrt(3)/27]] + ), + ( + Poly(x**2 - x + 1, x, domain='ZZ'), + Poly(3*x**2 + 7*x + 3, x, domain='ZZ'), + [[sqrt(3)/3, -5*sqrt(3)/9], [-sqrt(3)/3, 5*sqrt(3)/9]] + )] + for num, den, d in tests: + # Multiplicity of oo is 0 + ser = rational_laurent_series(num, den, x, oo, 0, 1) + assert construct_d_case_5(ser) == d + + +def test_construct_d_case_6(): + """ + This function tests the Case 6 in the step + to calculate coefficients of the d-vector. + + Each test case has 3 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. d - The d-vector. + """ + tests = [ + ( + Poly(-2*x**2 - 5, x, domain='ZZ'), + Poly(4*x**4 + 2*x**2 + 10*x + 2, x, domain='ZZ'), + [[S(1)/2 + I/2], [S(1)/2 - I/2]] + ), + ( + Poly(-2*x**3 - 4*x**2 - 2*x - 5, x, domain='ZZ'), + Poly(x**6 - x**5 + 2*x**4 - 4*x**3 - 5*x**2 - 5*x + 9, x, domain='ZZ'), + [[1], [0]] + ), + ( + Poly(-5*x**3 + x**2 + 11*x + 12, x, domain='ZZ'), + Poly(6*x**8 - 26*x**7 - 27*x**6 - 10*x**5 - 44*x**4 - 46*x**3 - 34*x**2 \ + - 27*x - 42, x, domain='ZZ'), + [[1], [0]] + )] + for num, den, d in tests: + assert construct_d_case_6(num, den, x) == d + + +def test_rational_laurent_series(): + """ + This function tests the computation of coefficients + of Laurent series of a rational function. + + Each test case has 5 values - + + 1. num - Numerator of the rational function. + 2. den - Denominator of the rational function. + 3. x0 - Point about which Laurent series is to + be calculated. + 4. mul - Multiplicity of x0 if x0 is a pole of + the rational function (0 otherwise). + 5. n - Number of terms upto which the series + is to be calculated. + """ + tests = [ + # Laurent series about simple pole (Multiplicity = 1) + ( + Poly(x**2 - 3*x + 9, x, extension=True), + Poly(x**2 - x, x, extension=True), + S(1), 1, 6, + {1: 7, 0: -8, -1: 9, -2: -9, -3: 9, -4: -9} + ), + # Laurent series about multiple pole (Multiplicity > 1) + ( + Poly(64*x**3 - 1728*x + 1216, x, extension=True), + Poly(64*x**4 - 80*x**3 - 831*x**2 + 1809*x - 972, x, extension=True), + S(9)/8, 2, 3, + {0: S(32177152)/46521675, 2: S(1019)/984, -1: S(11947565056)/28610830125, \ + 1: S(209149)/75645} + ), + ( + Poly(1, x, extension=True), + Poly(x**5 + (-4*sqrt(2) - 1)*x**4 + (4*sqrt(2) + 12)*x**3 + (-12 - 8*sqrt(2))*x**2 \ + + (4 + 8*sqrt(2))*x - 4, x, extension=True), + sqrt(2), 4, 6, + {4: 1 + sqrt(2), 3: -3 - 2*sqrt(2), 2: Mul(-1, -3 - 2*sqrt(2), evaluate=False)/(-1 \ + + sqrt(2)), 1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**2, 0: Mul(-1, -3 - 2*sqrt(2), evaluate=False \ + )/(-1 + sqrt(2))**3, -1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**4} + ), + # Laurent series about oo + ( + Poly(x**5 - 4*x**3 + 6*x**2 + 10*x - 13, x, extension=True), + Poly(x**2 - 5, x, extension=True), + oo, 3, 6, + {3: 1, 2: 0, 1: 1, 0: 6, -1: 15, -2: 17} + ), + # Laurent series at x0 where x0 is not a pole of the function + # Using multiplicity as 0 (as x0 will not be a pole) + ( + Poly(3*x**3 + 6*x**2 - 2*x + 5, x, extension=True), + Poly(9*x**4 - x**3 - 3*x**2 + 4*x + 4, x, extension=True), + S(2)/5, 0, 1, + {0: S(3345)/3304, -1: S(399325)/2729104, -2: S(3926413375)/4508479808, \ + -3: S(-5000852751875)/1862002160704, -4: S(-6683640101653125)/6152055138966016} + ), + ( + Poly(-7*x**2 + 2*x - 4, x, extension=True), + Poly(7*x**5 + 9*x**4 + 8*x**3 + 3*x**2 + 6*x + 9, x, extension=True), + oo, 0, 6, + {0: 0, -2: 0, -5: -S(71)/49, -1: 0, -3: -1, -4: S(11)/7} + )] + for num, den, x0, mul, n, ser in tests: + assert ser == rational_laurent_series(num, den, x, x0, mul, n) + + +def check_dummy_sol(eq, solse, dummy_sym): + """ + Helper function to check if actual solution + matches expected solution if actual solution + contains dummy symbols. + """ + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + _, funcs = match_riccati(eq, f, x) + + sols = solve_riccati(f(x), x, *funcs) + C1 = Dummy('C1') + sols = [sol.subs(C1, dummy_sym) for sol in sols] + + assert all(x[0] for x in checkodesol(eq, sols)) + assert all(s1.dummy_eq(s2, dummy_sym) for s1, s2 in zip(sols, solse)) + + +def test_solve_riccati(): + """ + This function tests the computation of rational + particular solutions for a Riccati ODE. + + Each test case has 2 values - + + 1. eq - Riccati ODE to be solved. + 2. sol - Expected solution to the equation. + + Some examples have been taken from the paper - "Statistical Investigation of + First-Order Algebraic ODEs and their Rational General Solutions" by + Georg Grasegger, N. Thieu Vo, Franz Winkler + + https://www3.risc.jku.at/publications/download/risc_5197/RISCReport15-19.pdf + """ + C0 = Dummy('C0') + # Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2, + # a, b, c are rational functions of x + + tests = [ + # a(x) is a constant + ( + Eq(f(x).diff(x) + f(x)**2 - 2, 0), + [Eq(f(x), sqrt(2)), Eq(f(x), -sqrt(2))] + ), + # a(x) is a constant + ( + f(x)**2 + f(x).diff(x) + 4*f(x)/x + 2/x**2, + [Eq(f(x), (-2*C0 - x)/(C0*x + x**2))] + ), + # a(x) is a constant + ( + 2*x**2*f(x).diff(x) - x*(4*f(x) + f(x).diff(x) - 4) + (f(x) - 1)*f(x), + [Eq(f(x), (C0 + 2*x**2)/(C0 + x))] + ), + # Pole with multiplicity 1 + ( + Eq(f(x).diff(x), -f(x)**2 - 2/(x**3 - x**2)), + [Eq(f(x), 1/(x**2 - x))] + ), + # One pole of multiplicity 2 + ( + x**2 - (2*x + 1/x)*f(x) + f(x)**2 + f(x).diff(x), + [Eq(f(x), (C0*x + x**3 + 2*x)/(C0 + x**2)), Eq(f(x), x)] + ), + ( + x**4*f(x).diff(x) + x**2 - x*(2*f(x)**2 + f(x).diff(x)) + f(x), + [Eq(f(x), (C0*x**2 + x)/(C0 + x**2)), Eq(f(x), x**2)] + ), + # Multiple poles of multiplicity 2 + ( + -f(x)**2 + f(x).diff(x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \ + - 1)**2), + [Eq(f(x), (9*C0*x - 6*C0 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 \ + - 30*x + 6)/(6*C0*x**2 - 9*C0*x + 3*C0 + 6*x**6 - 29*x**5 + \ + 57*x**4 - 58*x**3 + 30*x**2 - 6*x)), Eq(f(x), (3*x - 2)/(2*x**2 \ + - 3*x + 1))] + ), + # Regression: Poles with even multiplicity > 2 fixed + ( + f(x)**2 + f(x).diff(x) - (4*x**6 - 8*x**5 + 12*x**4 + 4*x**3 + \ + 7*x**2 - 20*x + 4)/(4*x**4), + [Eq(f(x), (2*x**5 - 2*x**4 - x**3 + 4*x**2 + 3*x - 2)/(2*x**4 \ + - 2*x**2))] + ), + # Regression: Poles with even multiplicity > 2 fixed + ( + Eq(f(x).diff(x), (-x**6 + 15*x**4 - 40*x**3 + 45*x**2 - 24*x + 4)/\ + (x**12 - 12*x**11 + 66*x**10 - 220*x**9 + 495*x**8 - 792*x**7 + 924*x**6 - \ + 792*x**5 + 495*x**4 - 220*x**3 + 66*x**2 - 12*x + 1) + f(x)**2 + f(x)), + [Eq(f(x), 1/(x**6 - 6*x**5 + 15*x**4 - 20*x**3 + 15*x**2 - 6*x + 1))] + ), + # More than 2 poles with multiplicity 2 + # Regression: Fixed mistake in necessary conditions + ( + Eq(f(x).diff(x), x*f(x) + 2*x + (3*x - 2)*f(x)**2/(4*x + 2) + \ + (8*x**2 - 7*x + 26)/(16*x**3 - 24*x**2 + 8) - S(3)/2), + [Eq(f(x), (1 - 4*x)/(2*x - 2))] + ), + # Regression: Fixed mistake in necessary conditions + ( + Eq(f(x).diff(x), (-12*x**2 - 48*x - 15)/(24*x**3 - 40*x**2 + 8*x + 8) \ + + 3*f(x)**2/(6*x + 2)), + [Eq(f(x), (2*x + 1)/(2*x - 2))] + ), + # Imaginary poles + ( + f(x).diff(x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \ + - 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2), + [Eq(f(x), (-C0 - x**3 + x**2 - 2*x)/(C0*x - C0 + x**4 - x**3 + x**2 \ + - x)), Eq(f(x), -1/(x - 1))], + ), + # Imaginary coefficients in equation + ( + f(x).diff(x) - 2*I*(f(x)**2 + 1)/x, + [Eq(f(x), (-I*C0 + I*x**4)/(C0 + x**4)), Eq(f(x), -I)] + ), + # Regression: linsolve returning empty solution + # Large value of m (> 10) + ( + Eq(f(x).diff(x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\ + (2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)), + [Eq(f(x), (9 - x)/x), Eq(f(x), (40*x**14 + 28*x**13 + 420*x**12 + 2940*x**11 + \ + 18480*x**10 + 103950*x**9 + 519750*x**8 + 2286900*x**7 + 8731800*x**6 + 28378350*\ + x**5 + 76403250*x**4 + 163721250*x**3 + 261954000*x**2 + 278326125*x + 147349125)/\ + ((24*x**14 + 140*x**13 + 840*x**12 + 4620*x**11 + 23100*x**10 + 103950*x**9 + \ + 415800*x**8 + 1455300*x**7 + 4365900*x**6 + 10914750*x**5 + 21829500*x**4 + 32744250\ + *x**3 + 32744250*x**2 + 16372125*x)))] + ), + # Regression: Fixed bug due to a typo in paper + ( + Eq(f(x).diff(x), 18*x**3 + 18*x**2 + (-x/2 - S(1)/2)*f(x)**2 + 6), + [Eq(f(x), 6*x)] + ), + # Regression: Fixed bug due to a typo in paper + ( + Eq(f(x).diff(x), -3*x**3/4 + 15*x/2 + (x/3 - S(4)/3)*f(x)**2 \ + + 9 + (1 - x)*f(x)/x + 3/x), + [Eq(f(x), -3*x/2 - 3)] + )] + for eq, sol in tests: + check_dummy_sol(eq, sol, C0) + + +@slow +def test_solve_riccati_slow(): + """ + This function tests the computation of rational + particular solutions for a Riccati ODE. + + Each test case has 2 values - + + 1. eq - Riccati ODE to be solved. + 2. sol - Expected solution to the equation. + """ + C0 = Dummy('C0') + tests = [ + # Very large values of m (989 and 991) + ( + Eq(f(x).diff(x), (1 - x)*f(x)/(x - 3) + (2 - 12*x)*f(x)**2/(2*x - 9) + \ + (54924*x**3 - 405264*x**2 + 1084347*x - 1087533)/(8*x**4 - 132*x**3 + 810*x**2 - \ + 2187*x + 2187) + 495), + [Eq(f(x), (18*x + 6)/(2*x - 9))] + )] + for eq, sol in tests: + check_dummy_sol(eq, sol, C0) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_single.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_single.py new file mode 100644 index 0000000000000000000000000000000000000000..45b38029e97b9e74236c45f2f3efb6aa87c26e5c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_single.py @@ -0,0 +1,2902 @@ +# +# The main tests for the code in single.py are currently located in +# sympy/solvers/tests/test_ode.py +# +r""" +This File contains test functions for the individual hints used for solving ODEs. + +Examples of each solver will be returned by _get_examples_ode_sol_name_of_solver. + +Examples should have a key 'XFAIL' which stores the list of hints if they are +expected to fail for that hint. + +Functions that are for internal use: + +1) _ode_solver_test(ode_examples) - It takes a dictionary of examples returned by + _get_examples method and tests them with their respective hints. + +2) _test_particular_example(our_hint, example_name) - It tests the ODE example corresponding + to the hint provided. + +3) _test_all_hints(runxfail=False) - It is used to test all the examples with all the hints + currently implemented. It calls _test_all_examples_for_one_hint() which outputs whether the + given hint functions properly if it classifies the ODE example. + If runxfail flag is set to True then it will only test the examples which are expected to fail. + + Everytime the ODE of a particular solver is added, _test_all_hints() is to be executed to find + the possible failures of different solver hints. + +4) _test_all_examples_for_one_hint(our_hint, all_examples) - It takes hint as argument and checks + this hint against all the ODE examples and gives output as the number of ODEs matched, number + of ODEs which were solved correctly, list of ODEs which gives incorrect solution and list of + ODEs which raises exception. + +""" +from sympy.core.function import (Derivative, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (asinh, cosh, sinh, tanh) +from sympy.functions.elementary.miscellaneous import (cbrt, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sec, sin, tan) +from sympy.functions.special.error_functions import (Ei, erfi) +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import (Integral, integrate) +from sympy.polys.rootoftools import rootof + +from sympy.core import Function, Symbol +from sympy.functions import airyai, airybi, besselj, bessely, lowergamma +from sympy.integrals.risch import NonElementaryIntegral +from sympy.solvers.ode import classify_ode, dsolve +from sympy.solvers.ode.ode import allhints, _remove_redundant_solutions +from sympy.solvers.ode.single import (FirstLinear, ODEMatchError, + SingleODEProblem, SingleODESolver, NthOrderReducible) + +from sympy.solvers.ode.subscheck import checkodesol + +from sympy.testing.pytest import raises, slow +import traceback + + +x = Symbol('x') +u = Symbol('u') +_u = Dummy('u') +y = Symbol('y') +f = Function('f') +g = Function('g') +C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C1:11') +a, b, c = symbols('a b c') + + +hint_message = """\ +Hint did not match the example {example}. + +The ODE is: +{eq}. + +The expected hint was +{our_hint}\ +""" + +expected_sol_message = """\ +Different solution found from dsolve for example {example}. + +The ODE is: +{eq} + +The expected solution was +{sol} + +What dsolve returned is: +{dsolve_sol}\ +""" + +checkodesol_msg = """\ +solution found is not correct for example {example}. + +The ODE is: +{eq}\ +""" + +dsol_incorrect_msg = """\ +solution returned by dsolve is incorrect when using {hint}. + +The ODE is: +{eq} + +The expected solution was +{sol} + +what dsolve returned is: +{dsolve_sol} + +You can test this with: + +eq = {eq} +sol = dsolve(eq, hint='{hint}') +print(sol) +print(checkodesol(eq, sol)) + +""" + +exception_msg = """\ +dsolve raised exception : {e} + +when using {hint} for the example {example} + +You can test this with: + +from sympy.solvers.ode.tests.test_single import _test_an_example + +_test_an_example('{hint}', example_name = '{example}') + +The ODE is: +{eq} + +\ +""" + +check_hint_msg = """\ +Tested hint was : {hint} + +Total of {matched} examples matched with this hint. + +Out of which {solve} gave correct results. + +Examples which gave incorrect results are {unsolve}. + +Examples which raised exceptions are {exceptions} +\ +""" + + +def _add_example_keys(func): + def inner(): + solver=func() + examples=[] + for example in solver['examples']: + temp={ + 'eq': solver['examples'][example]['eq'], + 'sol': solver['examples'][example]['sol'], + 'XFAIL': solver['examples'][example].get('XFAIL', []), + 'func': solver['examples'][example].get('func',solver['func']), + 'example_name': example, + 'slow': solver['examples'][example].get('slow', False), + 'simplify_flag':solver['examples'][example].get('simplify_flag',True), + 'checkodesol_XFAIL': solver['examples'][example].get('checkodesol_XFAIL', False), + 'dsolve_too_slow':solver['examples'][example].get('dsolve_too_slow',False), + 'checkodesol_too_slow':solver['examples'][example].get('checkodesol_too_slow',False), + 'hint': solver['hint'] + } + examples.append(temp) + return examples + return inner() + + +def _ode_solver_test(ode_examples, run_slow_test=False): + for example in ode_examples: + if ((not run_slow_test) and example['slow']) or (run_slow_test and (not example['slow'])): + continue + + result = _test_particular_example(example['hint'], example, solver_flag=True) + if result['xpass_msg'] != "": + print(result['xpass_msg']) + + +def _test_all_hints(runxfail=False): + all_hints = list(allhints)+["default"] + all_examples = _get_all_examples() + + for our_hint in all_hints: + if our_hint.endswith('_Integral') or 'series' in our_hint: + continue + _test_all_examples_for_one_hint(our_hint, all_examples, runxfail) + + +def _test_dummy_sol(expected_sol,dsolve_sol): + if type(dsolve_sol)==list: + return any(expected_sol.dummy_eq(sub_dsol) for sub_dsol in dsolve_sol) + else: + return expected_sol.dummy_eq(dsolve_sol) + + +def _test_an_example(our_hint, example_name): + all_examples = _get_all_examples() + for example in all_examples: + if example['example_name'] == example_name: + _test_particular_example(our_hint, example) + + +def _test_particular_example(our_hint, ode_example, solver_flag=False): + eq = ode_example['eq'] + expected_sol = ode_example['sol'] + example = ode_example['example_name'] + xfail = our_hint in ode_example['XFAIL'] + func = ode_example['func'] + result = {'msg': '', 'xpass_msg': ''} + simplify_flag=ode_example['simplify_flag'] + checkodesol_XFAIL = ode_example['checkodesol_XFAIL'] + dsolve_too_slow = ode_example['dsolve_too_slow'] + checkodesol_too_slow = ode_example['checkodesol_too_slow'] + xpass = True + if solver_flag: + if our_hint not in classify_ode(eq, func): + message = hint_message.format(example=example, eq=eq, our_hint=our_hint) + raise AssertionError(message) + + if our_hint in classify_ode(eq, func): + result['match_list'] = example + try: + if not (dsolve_too_slow): + dsolve_sol = dsolve(eq, func, simplify=simplify_flag,hint=our_hint) + else: + if len(expected_sol)==1: + dsolve_sol = expected_sol[0] + else: + dsolve_sol = expected_sol + + except Exception as e: + dsolve_sol = [] + result['exception_list'] = example + if not solver_flag: + traceback.print_exc() + result['msg'] = exception_msg.format(e=str(e), hint=our_hint, example=example, eq=eq) + if solver_flag and not xfail: + print(result['msg']) + raise + xpass = False + + if solver_flag and dsolve_sol!=[]: + expect_sol_check = False + if type(dsolve_sol)==list: + for sub_sol in expected_sol: + if sub_sol.has(Dummy): + expect_sol_check = not _test_dummy_sol(sub_sol, dsolve_sol) + else: + expect_sol_check = sub_sol not in dsolve_sol + if expect_sol_check: + break + else: + expect_sol_check = dsolve_sol not in expected_sol + for sub_sol in expected_sol: + if sub_sol.has(Dummy): + expect_sol_check = not _test_dummy_sol(sub_sol, dsolve_sol) + + if expect_sol_check: + message = expected_sol_message.format(example=example, eq=eq, sol=expected_sol, dsolve_sol=dsolve_sol) + raise AssertionError(message) + + expected_checkodesol = [(True, 0) for i in range(len(expected_sol))] + if len(expected_sol) == 1: + expected_checkodesol = (True, 0) + + if not checkodesol_too_slow: + if not checkodesol_XFAIL: + if checkodesol(eq, dsolve_sol, func, solve_for_func=False) != expected_checkodesol: + result['unsolve_list'] = example + xpass = False + message = dsol_incorrect_msg.format(hint=our_hint, eq=eq, sol=expected_sol,dsolve_sol=dsolve_sol) + if solver_flag: + message = checkodesol_msg.format(example=example, eq=eq) + raise AssertionError(message) + else: + result['msg'] = 'AssertionError: ' + message + + if xpass and xfail: + result['xpass_msg'] = example + "is now passing for the hint" + our_hint + return result + + +def _test_all_examples_for_one_hint(our_hint, all_examples=[], runxfail=None): + if all_examples == []: + all_examples = _get_all_examples() + match_list, unsolve_list, exception_list = [], [], [] + for ode_example in all_examples: + xfail = our_hint in ode_example['XFAIL'] + if runxfail and not xfail: + continue + if xfail: + continue + result = _test_particular_example(our_hint, ode_example) + match_list += result.get('match_list',[]) + unsolve_list += result.get('unsolve_list',[]) + exception_list += result.get('exception_list',[]) + if runxfail is not None: + msg = result['msg'] + if msg!='': + print(result['msg']) + # print(result.get('xpass_msg','')) + if runxfail is None: + match_count = len(match_list) + solved = len(match_list)-len(unsolve_list)-len(exception_list) + msg = check_hint_msg.format(hint=our_hint, matched=match_count, solve=solved, unsolve=unsolve_list, exceptions=exception_list) + print(msg) + + +def test_SingleODESolver(): + # Test that not implemented methods give NotImplementedError + # Subclasses should override these methods. + problem = SingleODEProblem(f(x).diff(x), f(x), x) + solver = SingleODESolver(problem) + raises(NotImplementedError, lambda: solver.matches()) + raises(NotImplementedError, lambda: solver.get_general_solution()) + raises(NotImplementedError, lambda: solver._matches()) + raises(NotImplementedError, lambda: solver._get_general_solution()) + + # This ODE can not be solved by the FirstLinear solver. Here we test that + # it does not match and the asking for a general solution gives + # ODEMatchError + + problem = SingleODEProblem(f(x).diff(x) + f(x)*f(x), f(x), x) + + solver = FirstLinear(problem) + raises(ODEMatchError, lambda: solver.get_general_solution()) + + solver = FirstLinear(problem) + assert solver.matches() is False + + #These are just test for order of ODE + + problem = SingleODEProblem(f(x).diff(x) + f(x), f(x), x) + assert problem.order == 1 + + problem = SingleODEProblem(f(x).diff(x,4) + f(x).diff(x,2) - f(x).diff(x,3), f(x), x) + assert problem.order == 4 + + problem = SingleODEProblem(f(x).diff(x, 3) + f(x).diff(x, 2) - f(x)**2, f(x), x) + assert problem.is_autonomous == True + + problem = SingleODEProblem(f(x).diff(x, 3) + x*f(x).diff(x, 2) - f(x)**2, f(x), x) + assert problem.is_autonomous == False + + +def test_linear_coefficients(): + _ode_solver_test(_get_examples_ode_sol_linear_coefficients) + + +@slow +def test_1st_homogeneous_coeff_ode(): + #These were marked as test_1st_homogeneous_coeff_corner_case + eq1 = f(x).diff(x) - f(x)/x + c1 = classify_ode(eq1, f(x)) + eq2 = x*f(x).diff(x) - f(x) + c2 = classify_ode(eq2, f(x)) + sdi = "1st_homogeneous_coeff_subs_dep_div_indep" + sid = "1st_homogeneous_coeff_subs_indep_div_dep" + assert sid not in c1 and sdi not in c1 + assert sid not in c2 and sdi not in c2 + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep) + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_best) + + +@slow +def test_slow_examples_1st_homogeneous_coeff_ode(): + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep, run_slow_test=True) + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_best, run_slow_test=True) + + +@slow +def test_nth_linear_constant_coeff_homogeneous(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_constant_coeff_homogeneous) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_homogeneous(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_constant_coeff_homogeneous, run_slow_test=True) + + +def test_Airy_equation(): + _ode_solver_test(_get_examples_ode_sol_2nd_linear_airy) + + +@slow +def test_lie_group(): + _ode_solver_test(_get_examples_ode_sol_lie_group) + + +@slow +def test_separable_reduced(): + df = f(x).diff(x) + eq = (x / f(x))*df + tan(x**2*f(x) / (x**2*f(x) - 1)) + assert classify_ode(eq) == ('factorable', 'separable_reduced', 'lie_group', + 'separable_reduced_Integral') + _ode_solver_test(_get_examples_ode_sol_separable_reduced) + + +@slow +def test_slow_examples_separable_reduced(): + _ode_solver_test(_get_examples_ode_sol_separable_reduced, run_slow_test=True) + + +@slow +def test_2nd_2F1_hypergeometric(): + _ode_solver_test(_get_examples_ode_sol_2nd_2F1_hypergeometric) + + +def test_2nd_2F1_hypergeometric_integral(): + eq = x*(x-1)*f(x).diff(x, 2) + (-1+ S(7)/2*x)*f(x).diff(x) + f(x) + sol = Eq(f(x), (C1 + C2*Integral(exp(Integral((1 - x/2)/(x*(x - 1)), x))/(1 - + x/2)**2, x))*exp(Integral(1/(x - 1), x)/4)*exp(-Integral(7/(x - + 1), x)/4)*hyper((S(1)/2, -1), (1,), x)) + assert sol == dsolve(eq, hint='2nd_hypergeometric_Integral') + assert checkodesol(eq, sol) == (True, 0) + + +@slow +def test_2nd_nonlinear_autonomous_conserved(): + _ode_solver_test(_get_examples_ode_sol_2nd_nonlinear_autonomous_conserved) + + +def test_2nd_nonlinear_autonomous_conserved_integral(): + eq = f(x).diff(x, 2) + asin(f(x)) + actual = [Eq(Integral(1/sqrt(C1 - 2*Integral(asin(_u), _u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*Integral(asin(_u), _u)), (_u, f(x))), C2 - x)] + solved = dsolve(eq, hint='2nd_nonlinear_autonomous_conserved_Integral', simplify=False) + for a,s in zip(actual, solved): + assert a.dummy_eq(s) + # checkodesol unable to simplify solutions with f(x) in an integral equation + assert checkodesol(eq, [s.doit() for s in solved]) == [(True, 0), (True, 0)] + + +@slow +def test_2nd_linear_bessel_equation(): + _ode_solver_test(_get_examples_ode_sol_2nd_linear_bessel) + + +@slow +def test_nth_algebraic(): + eqn = f(x) + f(x)*f(x).diff(x) + solns = [Eq(f(x), exp(x)), + Eq(f(x), C1*exp(C2*x))] + solns_final = _remove_redundant_solutions(eqn, solns, 2, x) + assert solns_final == [Eq(f(x), C1*exp(C2*x))] + + _ode_solver_test(_get_examples_ode_sol_nth_algebraic) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_var_of_parameters(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_var_of_parameters, run_slow_test=True) + + +def test_nth_linear_constant_coeff_var_of_parameters(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_var_of_parameters) + + +@slow +def test_nth_linear_constant_coeff_variation_of_parameters__integral(): + # solve_variation_of_parameters shouldn't attempt to simplify the + # Wronskian if simplify=False. If wronskian() ever gets good enough + # to simplify the result itself, this test might fail. + our_hint = 'nth_linear_constant_coeff_variation_of_parameters_Integral' + eq = f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x) + sol_simp = dsolve(eq, f(x), hint=our_hint, simplify=True) + sol_nsimp = dsolve(eq, f(x), hint=our_hint, simplify=False) + assert sol_simp != sol_nsimp + assert checkodesol(eq, sol_simp, order=5, solve_for_func=False) == (True, 0) + assert checkodesol(eq, sol_simp, order=5, solve_for_func=False) == (True, 0) + + +@slow +def test_slow_examples_1st_exact(): + _ode_solver_test(_get_examples_ode_sol_1st_exact, run_slow_test=True) + + +@slow +def test_1st_exact(): + _ode_solver_test(_get_examples_ode_sol_1st_exact) + + +def test_1st_exact_integral(): + eq = cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x) + sol_1 = dsolve(eq, f(x), simplify=False, hint='1st_exact_Integral') + assert checkodesol(eq, sol_1, order=1, solve_for_func=False) + + +@slow +def test_slow_examples_nth_order_reducible(): + _ode_solver_test(_get_examples_ode_sol_nth_order_reducible, run_slow_test=True) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_undetermined_coefficients(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_undetermined_coefficients, run_slow_test=True) + + +@slow +def test_slow_examples_separable(): + _ode_solver_test(_get_examples_ode_sol_separable, run_slow_test=True) + + +@slow +def test_nth_linear_constant_coeff_undetermined_coefficients(): + #issue-https://github.com/sympy/sympy/issues/5787 + # This test case is to show the classification of imaginary constants under + # nth_linear_constant_coeff_undetermined_coefficients + eq = Eq(diff(f(x), x), I*f(x) + S.Half - I) + our_hint = 'nth_linear_constant_coeff_undetermined_coefficients' + assert our_hint in classify_ode(eq) + _ode_solver_test(_get_examples_ode_sol_nth_linear_undetermined_coefficients) + + +def test_nth_order_reducible(): + F = lambda eq: NthOrderReducible(SingleODEProblem(eq, f(x), x))._matches() + D = Derivative + assert F(D(y*f(x), x, y) + D(f(x), x)) == False + assert F(D(y*f(y), y, y) + D(f(y), y)) == False + assert F(f(x)*D(f(x), x) + D(f(x), x, 2))== False + assert F(D(x*f(y), y, 2) + D(u*y*f(x), x, 3)) == False # no simplification by design + assert F(D(f(y), y, 2) + D(f(y), y, 3) + D(f(x), x, 4)) == False + assert F(D(f(x), x, 2) + D(f(x), x, 3)) == True + _ode_solver_test(_get_examples_ode_sol_nth_order_reducible) + + +@slow +def test_separable(): + _ode_solver_test(_get_examples_ode_sol_separable) + + +@slow +def test_factorable(): + assert integrate(-asin(f(2*x)+pi), x) == -Integral(asin(pi + f(2*x)), x) + _ode_solver_test(_get_examples_ode_sol_factorable) + + +@slow +def test_slow_examples_factorable(): + _ode_solver_test(_get_examples_ode_sol_factorable, run_slow_test=True) + + +def test_Riccati_special_minus2(): + _ode_solver_test(_get_examples_ode_sol_riccati) + + +@slow +def test_1st_rational_riccati(): + _ode_solver_test(_get_examples_ode_sol_1st_rational_riccati) + + +def test_Bernoulli(): + _ode_solver_test(_get_examples_ode_sol_bernoulli) + + +def test_1st_linear(): + _ode_solver_test(_get_examples_ode_sol_1st_linear) + + +def test_almost_linear(): + _ode_solver_test(_get_examples_ode_sol_almost_linear) + + +@slow +def test_Liouville_ODE(): + hint = 'Liouville' + not_Liouville1 = classify_ode(diff(f(x), x)/x + f(x)*diff(f(x), x, x)/2 - + diff(f(x), x)**2/2, f(x)) + not_Liouville2 = classify_ode(diff(f(x), x)/x + diff(f(x), x, x)/2 - + x*diff(f(x), x)**2/2, f(x)) + assert hint not in not_Liouville1 + assert hint not in not_Liouville2 + assert hint + '_Integral' not in not_Liouville1 + assert hint + '_Integral' not in not_Liouville2 + + _ode_solver_test(_get_examples_ode_sol_liouville) + + +def test_nth_order_linear_euler_eq_homogeneous(): + x, t, a, b, c = symbols('x t a b c') + y = Function('y') + our_hint = "nth_linear_euler_eq_homogeneous" + + eq = diff(f(t), t, 4)*t**4 - 13*diff(f(t), t, 2)*t**2 + 36*f(t) + assert our_hint in classify_ode(eq) + + eq = a*y(t) + b*t*diff(y(t), t) + c*t**2*diff(y(t), t, 2) + assert our_hint in classify_ode(eq) + + _ode_solver_test(_get_examples_ode_sol_euler_homogeneous) + + +def test_nth_order_linear_euler_eq_nonhomogeneous_undetermined_coefficients(): + x, t = symbols('x t') + a, b, c, d = symbols('a b c d', integer=True) + our_hint = "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients" + + eq = x**4*diff(f(x), x, 4) - 13*x**2*diff(f(x), x, 2) + 36*f(x) + x + assert our_hint in classify_ode(eq, f(x)) + + eq = a*x**2*diff(f(x), x, 2) + b*x*diff(f(x), x) + c*f(x) + d*log(x) + assert our_hint in classify_ode(eq, f(x)) + + _ode_solver_test(_get_examples_ode_sol_euler_undetermined_coeff) + + +@slow +def test_nth_order_linear_euler_eq_nonhomogeneous_variation_of_parameters(): + x, t = symbols('x, t') + a, b, c, d = symbols('a, b, c, d', integer=True) + our_hint = "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters" + + eq = Eq(x**2*diff(f(x),x,2) - 8*x*diff(f(x),x) + 12*f(x), x**2) + assert our_hint in classify_ode(eq, f(x)) + + eq = Eq(a*x**3*diff(f(x),x,3) + b*x**2*diff(f(x),x,2) + c*x*diff(f(x),x) + d*f(x), x*log(x)) + assert our_hint in classify_ode(eq, f(x)) + + _ode_solver_test(_get_examples_ode_sol_euler_var_para) + + +@_add_example_keys +def _get_examples_ode_sol_euler_homogeneous(): + r1, r2, r3, r4, r5 = [rootof(x**5 - 14*x**4 + 71*x**3 - 154*x**2 + 120*x - 1, n) for n in range(5)] + return { + 'hint': "nth_linear_euler_eq_homogeneous", + 'func': f(x), + 'examples':{ + 'euler_hom_01': { + 'eq': Eq(-3*diff(f(x), x)*x + 2*x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), C1 + C2*x**Rational(5, 2))], + }, + + 'euler_hom_02': { + 'eq': Eq(3*f(x) - 5*diff(f(x), x)*x + 2*x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), C1*sqrt(x) + C2*x**3)] + }, + + 'euler_hom_03': { + 'eq': Eq(4*f(x) + 5*diff(f(x), x)*x + x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), (C1 + C2*log(x))/x**2)] + }, + + 'euler_hom_04': { + 'eq': Eq(6*f(x) - 6*diff(f(x), x)*x + 1*x**2*diff(f(x), x, x) + x**3*diff(f(x), x, x, x), 0), + 'sol': [Eq(f(x), C1/x**2 + C2*x + C3*x**3)] + }, + + 'euler_hom_05': { + 'eq': Eq(-125*f(x) + 61*diff(f(x), x)*x - 12*x**2*diff(f(x), x, x) + x**3*diff(f(x), x, x, x), 0), + 'sol': [Eq(f(x), x**5*(C1 + C2*log(x) + C3*log(x)**2))] + }, + + 'euler_hom_06': { + 'eq': x**2*diff(f(x), x, 2) + x*diff(f(x), x) - 9*f(x), + 'sol': [Eq(f(x), C1*x**-3 + C2*x**3)] + }, + + 'euler_hom_07': { + 'eq': sin(x)*x**2*f(x).diff(x, 2) + sin(x)*x*f(x).diff(x) + sin(x)*f(x), + 'sol': [Eq(f(x), C1*sin(log(x)) + C2*cos(log(x)))], + 'XFAIL': ['2nd_power_series_regular','nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients'] + }, + + 'euler_hom_08': { + 'eq': x**6 * f(x).diff(x, 6) - x*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*x + C2*x**r1 + C3*x**r2 + C4*x**r3 + C5*x**r4 + C6*x**r5)], + 'checkodesol_XFAIL':True + }, + + #This example is from issue: https://github.com/sympy/sympy/issues/15237 #This example is from issue: + # https://github.com/sympy/sympy/issues/15237 + 'euler_hom_09': { + 'eq': Derivative(x*f(x), x, x, x), + 'sol': [Eq(f(x), C1 + C2/x + C3*x)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_euler_undetermined_coeff(): + return { + 'hint': "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients", + 'func': f(x), + 'examples':{ + 'euler_undet_01': { + 'eq': Eq(x**2*diff(f(x), x, x) + x*diff(f(x), x), 1), + 'sol': [Eq(f(x), C1 + C2*log(x) + log(x)**2/2)] + }, + + 'euler_undet_02': { + 'eq': Eq(x**2*diff(f(x), x, x) - 2*x*diff(f(x), x) + 2*f(x), x**3), + 'sol': [Eq(f(x), x*(C1 + C2*x + Rational(1, 2)*x**2))] + }, + + 'euler_undet_03': { + 'eq': Eq(x**2*diff(f(x), x, x) - x*diff(f(x), x) - 3*f(x), log(x)/x), + 'sol': [Eq(f(x), (C1 + C2*x**4 - log(x)**2/8 - log(x)/16)/x)] + }, + + 'euler_undet_04': { + 'eq': Eq(x**2*diff(f(x), x, x) + 3*x*diff(f(x), x) - 8*f(x), log(x)**3 - log(x)), + 'sol': [Eq(f(x), C1/x**4 + C2*x**2 - Rational(1,8)*log(x)**3 - Rational(3,32)*log(x)**2 - Rational(1,64)*log(x) - Rational(7, 256))] + }, + + 'euler_undet_05': { + 'eq': Eq(x**3*diff(f(x), x, x, x) - 3*x**2*diff(f(x), x, x) + 6*x*diff(f(x), x) - 6*f(x), log(x)), + 'sol': [Eq(f(x), C1*x + C2*x**2 + C3*x**3 - Rational(1, 6)*log(x) - Rational(11, 36))] + }, + + #Below examples were added for the issue: https://github.com/sympy/sympy/issues/5096 + 'euler_undet_06': { + 'eq': 2*x**2*f(x).diff(x, 2) + f(x) + sqrt(2*x)*sin(log(2*x)/2), + 'sol': [Eq(f(x), sqrt(x)*(C1*sin(log(x)/2) + C2*cos(log(x)/2) + sqrt(2)*log(x)*cos(log(2*x)/2)/2))] + }, + + 'euler_undet_07': { + 'eq': 2*x**2*f(x).diff(x, 2) + f(x) + sin(log(2*x)/2), + 'sol': [Eq(f(x), C1*sqrt(x)*sin(log(x)/2) + C2*sqrt(x)*cos(log(x)/2) - 2*sin(log(2*x)/2)/5 - 4*cos(log(2*x)/2)/5)] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_euler_var_para(): + return { + 'hint': "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters", + 'func': f(x), + 'examples':{ + 'euler_var_01': { + 'eq': Eq(x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x), x**4), + 'sol': [Eq(f(x), x*(C1 + C2*x + x**3/6))] + }, + + 'euler_var_02': { + 'eq': Eq(3*x**2*diff(f(x), x, x) + 6*x*diff(f(x), x) - 6*f(x), x**3*exp(x)), + 'sol': [Eq(f(x), C1/x**2 + C2*x + x*exp(x)/3 - 4*exp(x)/3 + 8*exp(x)/(3*x) - 8*exp(x)/(3*x**2))] + }, + + 'euler_var_03': { + 'eq': Eq(x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x), x**4*exp(x)), + 'sol': [Eq(f(x), x*(C1 + C2*x + x*exp(x) - 2*exp(x)))] + }, + + 'euler_var_04': { + 'eq': x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - log(x), + 'sol': [Eq(f(x), C1*x + C2*x**2 + log(x)/2 + Rational(3, 4))] + }, + + 'euler_var_05': { + 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))] + }, + + 'euler_var_06': { + 'eq': x**2 * f(x).diff(x, 2) + x * f(x).diff(x) + 4 * f(x) - 1/x, + 'sol': [Eq(f(x), C1*sin(2*log(x)) + C2*cos(2*log(x)) + 1/(5*x))] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_bernoulli(): + # Type: Bernoulli, f'(x) + p(x)*f(x) == q(x)*f(x)**n + return { + 'hint': "Bernoulli", + 'func': f(x), + 'examples':{ + 'bernoulli_01': { + 'eq': Eq(x*f(x).diff(x) + f(x) - f(x)**2, 0), + 'sol': [Eq(f(x), 1/(C1*x + 1))], + 'XFAIL': ['separable_reduced'] + }, + + 'bernoulli_02': { + 'eq': f(x).diff(x) - y*f(x), + 'sol': [Eq(f(x), C1*exp(x*y))] + }, + + 'bernoulli_03': { + 'eq': f(x)*f(x).diff(x) - 1, + 'sol': [Eq(f(x), -sqrt(C1 + 2*x)), Eq(f(x), sqrt(C1 + 2*x))] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_riccati(): + # Type: Riccati special alpha = -2, a*dy/dx + b*y**2 + c*y/x +d/x**2 + return { + 'hint': "Riccati_special_minus2", + 'func': f(x), + 'examples':{ + 'riccati_01': { + 'eq': 2*f(x).diff(x) + f(x)**2 - f(x)/x + 3*x**(-2), + 'sol': [Eq(f(x), (-sqrt(3)*tan(C1 + sqrt(3)*log(x)/4) + 3)/(2*x))], + }, + }, + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_rational_riccati(): + # Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2, + # a, b, c are rational functions of x + return { + 'hint': "1st_rational_riccati", + 'func': f(x), + 'examples':{ + # a(x) is a constant + "rational_riccati_01": { + "eq": Eq(f(x).diff(x) + f(x)**2 - 2, 0), + "sol": [Eq(f(x), sqrt(2)*(-C1 - exp(2*sqrt(2)*x))/(C1 - exp(2*sqrt(2)*x)))] + }, + # a(x) is a constant + "rational_riccati_02": { + "eq": f(x)**2 + Derivative(f(x), x) + 4*f(x)/x + 2/x**2, + "sol": [Eq(f(x), (-2*C1 - x)/(x*(C1 + x)))] + }, + # a(x) is a constant + "rational_riccati_03": { + "eq": 2*x**2*Derivative(f(x), x) - x*(4*f(x) + Derivative(f(x), x) - 4) + (f(x) - 1)*f(x), + "sol": [Eq(f(x), (C1 + 2*x**2)/(C1 + x))] + }, + # Constant coefficients + "rational_riccati_04": { + "eq": f(x).diff(x) - 6 - 5*f(x) - f(x)**2, + "sol": [Eq(f(x), (-2*C1 + 3*exp(x))/(C1 - exp(x)))] + }, + # One pole of multiplicity 2 + "rational_riccati_05": { + "eq": x**2 - (2*x + 1/x)*f(x) + f(x)**2 + Derivative(f(x), x), + "sol": [Eq(f(x), x*(C1 + x**2 + 1)/(C1 + x**2 - 1))] + }, + # One pole of multiplicity 2 + "rational_riccati_06": { + "eq": x**4*Derivative(f(x), x) + x**2 - x*(2*f(x)**2 + Derivative(f(x), x)) + f(x), + "sol": [Eq(f(x), x*(C1*x - x + 1)/(C1 + x**2 - 1))] + }, + # Multiple poles of multiplicity 2 + "rational_riccati_07": { + "eq": -f(x)**2 + Derivative(f(x), x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \ + - 1)**2), + "sol": [Eq(f(x), (9*C1*x - 6*C1 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 - \ + 33*x + 8)/(6*C1*x**2 - 9*C1*x + 3*C1 + 6*x**6 - 29*x**5 + 57*x**4 - \ + 58*x**3 + 28*x**2 - 3*x - 1))] + }, + # Imaginary poles + "rational_riccati_08": { + "eq": Derivative(f(x), x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \ + - 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2), + "sol": [Eq(f(x), (-C1 - x**3 + x**2 - 2*x + 1)/(C1*x - C1 + x**4 - x**3 + x**2 - \ + 2*x + 1))], + }, + # Imaginary coefficients in equation + "rational_riccati_09": { + "eq": Derivative(f(x), x) - 2*I*(f(x)**2 + 1)/x, + "sol": [Eq(f(x), (-I*C1 + I*x**4 + I)/(C1 + x**4 - 1))] + }, + # Regression: linsolve returning empty solution + # Large value of m (> 10) + "rational_riccati_10": { + "eq": Eq(Derivative(f(x), x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\ + (2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)), + "sol": [Eq(f(x), (40*C1*x**14 + 28*C1*x**13 + 420*C1*x**12 + 2940*C1*x**11 + \ + 18480*C1*x**10 + 103950*C1*x**9 + 519750*C1*x**8 + 2286900*C1*x**7 + \ + 8731800*C1*x**6 + 28378350*C1*x**5 + 76403250*C1*x**4 + 163721250*C1*x**3 \ + + 261954000*C1*x**2 + 278326125*C1*x + 147349125*C1 + x*exp(2*x) - 9*exp(2*x) \ + )/(x*(24*C1*x**13 + 140*C1*x**12 + 840*C1*x**11 + 4620*C1*x**10 + 23100*C1*x**9 \ + + 103950*C1*x**8 + 415800*C1*x**7 + 1455300*C1*x**6 + 4365900*C1*x**5 + \ + 10914750*C1*x**4 + 21829500*C1*x**3 + 32744250*C1*x**2 + 32744250*C1*x + \ + 16372125*C1 - exp(2*x))))] + } + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_1st_linear(): + # Type: first order linear form f'(x)+p(x)f(x)=q(x) + return { + 'hint': "1st_linear", + 'func': f(x), + 'examples':{ + 'linear_01': { + 'eq': Eq(f(x).diff(x) + x*f(x), x**2), + 'sol': [Eq(f(x), (C1 + x*exp(x**2/2)- sqrt(2)*sqrt(pi)*erfi(sqrt(2)*x/2)/2)*exp(-x**2/2))], + }, + }, + } + + +@_add_example_keys +def _get_examples_ode_sol_factorable(): + """ some hints are marked as xfail for examples because they missed additional algebraic solution + which could be found by Factorable hint. Fact_01 raise exception for + nth_linear_constant_coeff_undetermined_coefficients""" + + y = Dummy('y') + a0,a1,a2,a3,a4 = symbols('a0, a1, a2, a3, a4') + return { + 'hint': "factorable", + 'func': f(x), + 'examples':{ + 'fact_01': { + 'eq': f(x) + f(x)*f(x).diff(x), + 'sol': [Eq(f(x), 0), Eq(f(x), C1 - x)], + 'XFAIL': ['separable', '1st_exact', '1st_linear', 'Bernoulli', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', '1st_homogeneous_coeff_subs_dep_div_indep', + 'lie_group', 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters', + 'nth_linear_constant_coeff_undetermined_coefficients'] + }, + + 'fact_02': { + 'eq': f(x)*(f(x).diff(x)+f(x)*x+2), + 'sol': [Eq(f(x), (C1 - sqrt(2)*sqrt(pi)*erfi(sqrt(2)*x/2))*exp(-x**2/2)), Eq(f(x), 0)], + 'XFAIL': ['Bernoulli', '1st_linear', 'lie_group'] + }, + + 'fact_03': { + 'eq': (f(x).diff(x)+f(x)*x**2)*(f(x).diff(x, 2) + x*f(x)), + 'sol': [Eq(f(x), C1*airyai(-x) + C2*airybi(-x)),Eq(f(x), C1*exp(-x**3/3))] + }, + + 'fact_04': { + 'eq': (f(x).diff(x)+f(x)*x**2)*(f(x).diff(x, 2) + f(x)), + 'sol': [Eq(f(x), C1*exp(-x**3/3)), Eq(f(x), C1*sin(x) + C2*cos(x))] + }, + + 'fact_05': { + 'eq': (f(x).diff(x)**2-1)*(f(x).diff(x)**2-4), + 'sol': [Eq(f(x), C1 - x), Eq(f(x), C1 + x), Eq(f(x), C1 + 2*x), Eq(f(x), C1 - 2*x)] + }, + + 'fact_06': { + 'eq': (f(x).diff(x, 2)-exp(f(x)))*f(x).diff(x), + 'sol': [ + Eq(f(x), log(-C1/(cos(sqrt(-C1)*(C2 + x)) + 1))), + Eq(f(x), log(-C1/(cos(sqrt(-C1)*(C2 - x)) + 1))), + Eq(f(x), C1) + ], + 'slow': True, + }, + + 'fact_07': { + 'eq': (f(x).diff(x)**2-1)*(f(x)*f(x).diff(x)-1), + 'sol': [Eq(f(x), C1 - x), Eq(f(x), -sqrt(C1 + 2*x)),Eq(f(x), sqrt(C1 + 2*x)), Eq(f(x), C1 + x)] + }, + + 'fact_08': { + 'eq': Derivative(f(x), x)**4 - 2*Derivative(f(x), x)**2 + 1, + 'sol': [Eq(f(x), C1 - x), Eq(f(x), C1 + x)] + }, + + 'fact_09': { + 'eq': f(x)**2*Derivative(f(x), x)**6 - 2*f(x)**2*Derivative(f(x), + x)**4 + f(x)**2*Derivative(f(x), x)**2 - 2*f(x)*Derivative(f(x), + x)**5 + 4*f(x)*Derivative(f(x), x)**3 - 2*f(x)*Derivative(f(x), + x) + Derivative(f(x), x)**4 - 2*Derivative(f(x), x)**2 + 1, + 'sol': [ + Eq(f(x), C1 - x), Eq(f(x), -sqrt(C1 + 2*x)), + Eq(f(x), sqrt(C1 + 2*x)), Eq(f(x), C1 + x) + ] + }, + + 'fact_10': { + 'eq': x**4*f(x)**2 + 2*x**4*f(x)*Derivative(f(x), (x, 2)) + x**4*Derivative(f(x), + (x, 2))**2 + 2*x**3*f(x)*Derivative(f(x), x) + 2*x**3*Derivative(f(x), + x)*Derivative(f(x), (x, 2)) - 7*x**2*f(x)**2 - 7*x**2*f(x)*Derivative(f(x), + (x, 2)) + x**2*Derivative(f(x), x)**2 - 7*x*f(x)*Derivative(f(x), x) + 12*f(x)**2, + 'sol': [ + Eq(f(x), C1*besselj(2, x) + C2*bessely(2, x)), + Eq(f(x), C1*besselj(sqrt(3), x) + C2*bessely(sqrt(3), x)) + ], + 'slow': True, + }, + + 'fact_11': { + 'eq': (f(x).diff(x, 2)-exp(f(x)))*(f(x).diff(x, 2)+exp(f(x))), + 'sol': [ + Eq(f(x), log(C1/(cos(C1*sqrt(-1/C1)*(C2 + x)) - 1))), + Eq(f(x), log(C1/(cos(C1*sqrt(-1/C1)*(C2 - x)) - 1))), + Eq(f(x), log(C1/(1 - cos(C1*sqrt(-1/C1)*(C2 + x))))), + Eq(f(x), log(C1/(1 - cos(C1*sqrt(-1/C1)*(C2 - x))))) + ], + 'dsolve_too_slow': True, + }, + + #Below examples were added for the issue: https://github.com/sympy/sympy/issues/15889 + 'fact_12': { + 'eq': exp(f(x).diff(x))-f(x)**2, + 'sol': [Eq(NonElementaryIntegral(1/log(y**2), (y, f(x))), C1 + x)], + 'XFAIL': ['lie_group'] #It shows not implemented error for lie_group. + }, + + 'fact_13': { + 'eq': f(x).diff(x)**2 - f(x)**3, + 'sol': [Eq(f(x), 4/(C1**2 - 2*C1*x + x**2))], + 'XFAIL': ['lie_group'] #It shows not implemented error for lie_group. + }, + + 'fact_14': { + 'eq': f(x).diff(x)**2 - f(x), + 'sol': [Eq(f(x), C1**2/4 - C1*x/2 + x**2/4)] + }, + + 'fact_15': { + 'eq': f(x).diff(x)**2 - f(x)**2, + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1*exp(-x))] + }, + + 'fact_16': { + 'eq': f(x).diff(x)**2 - f(x)**3, + 'sol': [Eq(f(x), 4/(C1**2 - 2*C1*x + x**2))], + }, + + # kamke ode 1.1 + 'fact_17': { + 'eq': f(x).diff(x)-(a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**(-1/2), + 'sol': [Eq(f(x), C1 + Integral(1/sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4), x))], + 'slow': True + }, + + # This is from issue: https://github.com/sympy/sympy/issues/9446 + 'fact_18':{ + 'eq': Eq(f(2 * x), sin(Derivative(f(x)))), + 'sol': [Eq(f(x), C1 + Integral(pi - asin(f(2*x)), x)), Eq(f(x), C1 + Integral(asin(f(2*x)), x))], + 'checkodesol_XFAIL':True + }, + + # This is from issue: https://github.com/sympy/sympy/issues/7093 + 'fact_19': { + 'eq': Derivative(f(x), x)**2 - x**3, + 'sol': [Eq(f(x), C1 - 2*x**Rational(5,2)/5), Eq(f(x), C1 + 2*x**Rational(5,2)/5)], + }, + + 'fact_20': { + 'eq': x*f(x).diff(x, 2) - x*f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(x))], + }, + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_almost_linear(): + from sympy.functions.special.error_functions import Ei + A = Symbol('A', positive=True) + f = Function('f') + d = f(x).diff(x) + + return { + 'hint': "almost_linear", + 'func': f(x), + 'examples':{ + 'almost_lin_01': { + 'eq': x**2*f(x)**2*d + f(x)**3 + 1, + 'sol': [Eq(f(x), (C1*exp(3/x) - 1)**Rational(1, 3)), + Eq(f(x), (-1 - sqrt(3)*I)*(C1*exp(3/x) - 1)**Rational(1, 3)/2), + Eq(f(x), (-1 + sqrt(3)*I)*(C1*exp(3/x) - 1)**Rational(1, 3)/2)], + + }, + + 'almost_lin_02': { + 'eq': x*f(x)*d + 2*x*f(x)**2 + 1, + 'sol': [Eq(f(x), -sqrt((C1 - 2*Ei(4*x))*exp(-4*x))), Eq(f(x), sqrt((C1 - 2*Ei(4*x))*exp(-4*x)))] + }, + + 'almost_lin_03': { + 'eq': x*d + x*f(x) + 1, + 'sol': [Eq(f(x), (C1 - Ei(x))*exp(-x))] + }, + + 'almost_lin_04': { + 'eq': x*exp(f(x))*d + exp(f(x)) + 3*x, + 'sol': [Eq(f(x), log(C1/x - x*Rational(3, 2)))], + }, + + 'almost_lin_05': { + 'eq': x + A*(x + diff(f(x), x) + f(x)) + diff(f(x), x) + f(x) + 2, + 'sol': [Eq(f(x), (C1 + Piecewise( + (x, Eq(A + 1, 0)), ((-A*x + A - x - 1)*exp(x)/(A + 1), True)))*exp(-x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_liouville(): + n = Symbol('n') + _y = Dummy('y') + return { + 'hint': "Liouville", + 'func': f(x), + 'examples':{ + 'liouville_01': { + 'eq': diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2, + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))], + + }, + + 'liouville_02': { + 'eq': diff(x*exp(-f(x)), x, x), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))] + }, + + 'liouville_03': { + 'eq': ((diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2)*exp(-f(x))/exp(f(x))).expand(), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))] + }, + + 'liouville_04': { + 'eq': diff(f(x), x, x) + 1/f(x)*(diff(f(x), x))**2 + 1/x*diff(f(x), x), + 'sol': [Eq(f(x), -sqrt(C1 + C2*log(x))), Eq(f(x), sqrt(C1 + C2*log(x)))], + }, + + 'liouville_05': { + 'eq': x*diff(f(x), x, x) + x/f(x)*diff(f(x), x)**2 + x*diff(f(x), x), + 'sol': [Eq(f(x), -sqrt(C1 + C2*exp(-x))), Eq(f(x), sqrt(C1 + C2*exp(-x)))], + }, + + 'liouville_06': { + 'eq': Eq((x*exp(f(x))).diff(x, x), 0), + 'sol': [Eq(f(x), log(C1 + C2/x))], + }, + + 'liouville_07': { + 'eq': (diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2)*exp(-f(x))/exp(f(x)), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))], + }, + + 'liouville_08': { + 'eq': x**2*diff(f(x),x) + (n*f(x) + f(x)**2)*diff(f(x),x)**2 + diff(f(x), (x, 2)), + 'sol': [Eq(C1 + C2*lowergamma(Rational(1,3), x**3/3) + NonElementaryIntegral(exp(_y**3/3)*exp(_y**2*n/2), (_y, f(x))), 0)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_algebraic(): + M, m, r, t = symbols('M m r t') + phi = Function('phi') + k = Symbol('k') + # This one needs a substitution f' = g. + # 'algeb_12': { + # 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + # 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + # }, + return { + 'hint': "nth_algebraic", + 'func': f(x), + 'examples':{ + 'algeb_01': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x) * (f(x) - 1) * (f(x).diff(x) - x), + 'sol': [Eq(f(x), C1 + x**2/2), Eq(f(x), C1 + C2*x)] + }, + + 'algeb_02': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x) * (f(x) - 1), + 'sol': [Eq(f(x), C1 + C2*x)] + }, + + 'algeb_03': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x), + 'sol': [Eq(f(x), C1 + C2*x)] + }, + + 'algeb_04': { + 'eq': Eq(-M * phi(t).diff(t), + Rational(3, 2) * m * r**2 * phi(t).diff(t) * phi(t).diff(t,t)), + 'sol': [Eq(phi(t), C1), Eq(phi(t), C1 + C2*t - M*t**2/(3*m*r**2))], + 'func': phi(t) + }, + + 'algeb_05': { + 'eq': (1 - sin(f(x))) * f(x).diff(x), + 'sol': [Eq(f(x), C1)], + 'XFAIL': ['separable'] #It raised exception. + }, + + 'algeb_06': { + 'eq': (diff(f(x)) - x)*(diff(f(x)) + x), + 'sol': [Eq(f(x), C1 - x**2/2), Eq(f(x), C1 + x**2/2)] + }, + + 'algeb_07': { + 'eq': Eq(Derivative(f(x), x), Derivative(g(x), x)), + 'sol': [Eq(f(x), C1 + g(x))], + }, + + 'algeb_08': { + 'eq': f(x).diff(x) - C1, #this example is from issue 15999 + 'sol': [Eq(f(x), C1*x + C2)], + }, + + 'algeb_09': { + 'eq': f(x)*f(x).diff(x), + 'sol': [Eq(f(x), C1)], + }, + + 'algeb_10': { + 'eq': (diff(f(x)) - x)*(diff(f(x)) + x), + 'sol': [Eq(f(x), C1 - x**2/2), Eq(f(x), C1 + x**2/2)], + }, + + 'algeb_11': { + 'eq': f(x) + f(x)*f(x).diff(x), + 'sol': [Eq(f(x), 0), Eq(f(x), C1 - x)], + 'XFAIL': ['separable', '1st_exact', '1st_linear', 'Bernoulli', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', '1st_homogeneous_coeff_subs_dep_div_indep', + 'lie_group', 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters'] + #nth_linear_constant_coeff_undetermined_coefficients raises exception rest all of them misses a solution. + }, + + 'algeb_12': { + 'eq': Derivative(x*f(x), x, x, x), + 'sol': [Eq(f(x), (C1 + C2*x + C3*x**2) / x)], + 'XFAIL': ['nth_algebraic'] # It passes only when prep=False is set in dsolve. + }, + + 'algeb_13': { + 'eq': Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x)), + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + 'XFAIL': ['nth_algebraic'] # It passes only when prep=False is set in dsolve. + }, + + # These are simple tests from the old ode module example 14-18 + 'algeb_14': { + 'eq': Eq(f(x).diff(x), 0), + 'sol': [Eq(f(x), C1)], + }, + + 'algeb_15': { + 'eq': Eq(3*f(x).diff(x) - 5, 0), + 'sol': [Eq(f(x), C1 + x*Rational(5, 3))], + }, + + 'algeb_16': { + 'eq': Eq(3*f(x).diff(x), 5), + 'sol': [Eq(f(x), C1 + x*Rational(5, 3))], + }, + + # Type: 2nd order, constant coefficients (two complex roots) + 'algeb_17': { + 'eq': Eq(3*f(x).diff(x) - 1, 0), + 'sol': [Eq(f(x), C1 + x/3)], + }, + + 'algeb_18': { + 'eq': Eq(x*f(x).diff(x) - 1, 0), + 'sol': [Eq(f(x), C1 + log(x))], + }, + + # https://github.com/sympy/sympy/issues/6989 + 'algeb_19': { + 'eq': f(x).diff(x) - x*exp(-k*x), + 'sol': [Eq(f(x), C1 + Piecewise(((-k*x - 1)*exp(-k*x)/k**2, Ne(k**2, 0)),(x**2/2, True)))], + }, + + 'algeb_20': { + 'eq': -f(x).diff(x) + x*exp(-k*x), + 'sol': [Eq(f(x), C1 + Piecewise(((-k*x - 1)*exp(-k*x)/k**2, Ne(k**2, 0)),(x**2/2, True)))], + }, + + # https://github.com/sympy/sympy/issues/10867 + 'algeb_21': { + 'eq': Eq(g(x).diff(x).diff(x), (x-2)**2 + (x-3)**3), + 'sol': [Eq(g(x), C1 + C2*x + x**5/20 - 2*x**4/3 + 23*x**3/6 - 23*x**2/2)], + 'func': g(x), + }, + + # https://github.com/sympy/sympy/issues/13691 + 'algeb_22': { + 'eq': f(x).diff(x) - C1*g(x).diff(x), + 'sol': [Eq(f(x), C2 + C1*g(x))], + 'func': f(x), + }, + + # https://github.com/sympy/sympy/issues/4838 + 'algeb_23': { + 'eq': f(x).diff(x) - 3*C1 - 3*x**2, + 'sol': [Eq(f(x), C2 + 3*C1*x + x**3)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_order_reducible(): + return { + 'hint': "nth_order_reducible", + 'func': f(x), + 'examples':{ + 'reducible_01': { + 'eq': Eq(x*Derivative(f(x), x)**2 + Derivative(f(x), x, 2), 0), + 'sol': [Eq(f(x),C1 - sqrt(-1/C2)*log(-C2*sqrt(-1/C2) + x) + + sqrt(-1/C2)*log(C2*sqrt(-1/C2) + x))], + 'slow': True, + }, + + 'reducible_02': { + 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + 'slow': True, + }, + + 'reducible_03': { + 'eq': Eq(sqrt(2) * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(2**Rational(3, 4)*x/2) + C3*cos(2**Rational(3, 4)*x/2))], + 'slow': True, + }, + + 'reducible_04': { + 'eq': f(x).diff(x, 2) + 2*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x))], + }, + + 'reducible_05': { + 'eq': f(x).diff(x, 3) + f(x).diff(x, 2) - 6*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-3*x) + C3*exp(2*x))], + 'slow': True, + }, + + 'reducible_06': { + 'eq': f(x).diff(x, 4) - f(x).diff(x, 3) - 4*f(x).diff(x, 2) + \ + 4*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x) + C3*exp(x) + C4*exp(2*x))], + 'slow': True, + }, + + 'reducible_07': { + 'eq': f(x).diff(x, 4) + 3*f(x).diff(x, 3), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*exp(-3*x))], + 'slow': True, + }, + + 'reducible_08': { + 'eq': f(x).diff(x, 4) - 2*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*exp(-sqrt(2)*x) + C4*exp(sqrt(2)*x))], + 'slow': True, + }, + + 'reducible_09': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*sin(2*x) + C4*cos(2*x))], + 'slow': True, + }, + + 'reducible_10': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*x*sin(x) + C2*cos(x) - C3*x*cos(x) + C3*sin(x) + C4*sin(x) + C5*cos(x))], + 'slow': True, + }, + + 'reducible_11': { + 'eq': f(x).diff(x, 2) - f(x).diff(x)**3, + 'sol': [Eq(f(x), C1 - sqrt(2)*sqrt(-1/(C2 + x))*(C2 + x)), + Eq(f(x), C1 + sqrt(2)*sqrt(-1/(C2 + x))*(C2 + x))], + 'slow': True, + }, + + # Needs to be a way to know how to combine derivatives in the expression + 'reducible_12': { + 'eq': Derivative(x*f(x), x, x, x) + Derivative(f(x), x, x, x), + 'sol': [Eq(f(x), C1 + C3/Mul(2, (x**2 + 2*x + 1), evaluate=False) + + x*(C2 + C3/Mul(2, (x**2 + 2*x + 1), evaluate=False)))], # 2-arg Mul! + 'slow': True, + }, + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_undetermined_coefficients(): + # examples 3-27 below are from Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 231 + g = exp(-x) + f2 = f(x).diff(x, 2) + c = 3*f(x).diff(x, 3) + 5*f2 + f(x).diff(x) - f(x) - x + t = symbols("t") + u = symbols("u",cls=Function) + R, L, C, E_0, alpha = symbols("R L C E_0 alpha",positive=True) + omega = Symbol('omega') + return { + 'hint': "nth_linear_constant_coeff_undetermined_coefficients", + 'func': f(x), + 'examples':{ + 'undet_01': { + 'eq': c - x*g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x**2/24 - 3*x/32))*exp(-x) - 1)], + 'slow': True, + }, + + 'undet_02': { + 'eq': c - g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x/8))*exp(-x) - 1)], + 'slow': True, + }, + + 'undet_03': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 4, + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2)], + 'slow': True, + }, + + 'undet_04': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 12*exp(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2*exp(x))], + 'slow': True, + }, + + 'undet_05': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - exp(I*x), + 'sol': [Eq(f(x), (S(3)/10 + I/10)*(C1*exp(-2*x) + C2*exp(-x) - I*exp(I*x)))], + 'slow': True, + }, + + 'undet_06': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - sin(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + sin(x)/10 - 3*cos(x)/10)], + 'slow': True, + }, + + 'undet_07': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - cos(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 3*sin(x)/10 + cos(x)/10)], + 'slow': True, + }, + + 'undet_08': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - (8 + 6*exp(x) + 2*sin(x)), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + exp(x) + sin(x)/5 - 3*cos(x)/5 + 4)], + 'slow': True, + }, + + 'undet_09': { + 'eq': f2 + f(x).diff(x) + f(x) - x**2, + 'sol': [Eq(f(x), -2*x + x**2 + (C1*sin(x*sqrt(3)/2) + C2*cos(x*sqrt(3)/2))*exp(-x/2))], + 'slow': True, + }, + + 'undet_10': { + 'eq': f2 - 2*f(x).diff(x) - 8*f(x) - 9*x*exp(x) - 10*exp(-x), + 'sol': [Eq(f(x), -x*exp(x) - 2*exp(-x) + C1*exp(-2*x) + C2*exp(4*x))], + 'slow': True, + }, + + 'undet_11': { + 'eq': f2 - 3*f(x).diff(x) - 2*exp(2*x)*sin(x), + 'sol': [Eq(f(x), C1 + C2*exp(3*x) - 3*exp(2*x)*sin(x)/5 - exp(2*x)*cos(x)/5)], + 'slow': True, + }, + + 'undet_12': { + 'eq': f(x).diff(x, 4) - 2*f2 + f(x) - x + sin(x), + 'sol': [Eq(f(x), x - sin(x)/4 + (C1 + C2*x)*exp(-x) + (C3 + C4*x)*exp(x))], + 'slow': True, + }, + + 'undet_13': { + 'eq': f2 + f(x).diff(x) - x**2 - 2*x, + 'sol': [Eq(f(x), C1 + x**3/3 + C2*exp(-x))], + 'slow': True, + }, + + 'undet_14': { + 'eq': f2 + f(x).diff(x) - x - sin(2*x), + 'sol': [Eq(f(x), C1 - x - sin(2*x)/5 - cos(2*x)/10 + x**2/2 + C2*exp(-x))], + 'slow': True, + }, + + 'undet_15': { + 'eq': f2 + f(x) - 4*x*sin(x), + 'sol': [Eq(f(x), (C1 - x**2)*cos(x) + (C2 + x)*sin(x))], + 'slow': True, + }, + + 'undet_16': { + 'eq': f2 + 4*f(x) - x*sin(2*x), + 'sol': [Eq(f(x), (C1 - x**2/8)*cos(2*x) + (C2 + x/16)*sin(2*x))], + 'slow': True, + }, + + 'undet_17': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x**3/12))*exp(-x))], + 'slow': True, + }, + + 'undet_18': { + 'eq': f(x).diff(x, 3) + 3*f2 + 3*f(x).diff(x) + f(x) - 2*exp(-x) + \ + x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 - x**3/60 + x/3)))*exp(-x))], + 'slow': True, + }, + + 'undet_19': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - exp(-2*x) - x**2, + 'sol': [Eq(f(x), C2*exp(-x) + x**2/2 - x*Rational(3,2) + (C1 - x)*exp(-2*x) + Rational(7,4))], + 'slow': True, + }, + + 'undet_20': { + 'eq': f2 - 3*f(x).diff(x) + 2*f(x) - x*exp(-x), + 'sol': [Eq(f(x), C1*exp(x) + C2*exp(2*x) + (6*x + 5)*exp(-x)/36)], + 'slow': True, + }, + + 'undet_21': { + 'eq': f2 + f(x).diff(x) - 6*f(x) - x - exp(2*x), + 'sol': [Eq(f(x), Rational(-1, 36) - x/6 + C2*exp(-3*x) + (C1 + x/5)*exp(2*x))], + 'slow': True, + }, + + 'undet_22': { + 'eq': f2 + f(x) - sin(x) - exp(-x), + 'sol': [Eq(f(x), C2*sin(x) + (C1 - x/2)*cos(x) + exp(-x)/2)], + 'slow': True, + }, + + 'undet_23': { + 'eq': f(x).diff(x, 3) - 3*f2 + 3*f(x).diff(x) - f(x) - exp(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 + x/6)))*exp(x))], + 'slow': True, + }, + + 'undet_24': { + 'eq': f2 + f(x) - S.Half - cos(2*x)/2, + 'sol': [Eq(f(x), S.Half - cos(2*x)/6 + C1*sin(x) + C2*cos(x))], + 'slow': True, + }, + + 'undet_25': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - exp(2*x)*(S.Half - cos(2*x)/2), + 'sol': [Eq(f(x), C1 + C2*exp(-x) + C3*exp(x) + (-21*sin(2*x) + 27*cos(2*x) + 130)*exp(2*x)/1560)], + 'slow': True, + }, + + #Note: 'undet_26' is referred in 'undet_37' + 'undet_26': { + 'eq': (f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - + sin(x) - cos(x)), + 'sol': [Eq(f(x), C1 + x**2 + (C2 + x*(C3 - x/8))*sin(x) + (C4 + x*(C5 + x/8))*cos(x))], + 'slow': True, + }, + + 'undet_27': { + 'eq': f2 + f(x) - cos(x)/2 + cos(3*x)/2, + 'sol': [Eq(f(x), cos(3*x)/16 + C2*cos(x) + (C1 + x/4)*sin(x))], + 'slow': True, + }, + + 'undet_28': { + 'eq': f(x).diff(x) - 1, + 'sol': [Eq(f(x), C1 + x)], + 'slow': True, + }, + + # https://github.com/sympy/sympy/issues/19358 + 'undet_29': { + 'eq': f2 + f(x).diff(x) + exp(x-C1), + 'sol': [Eq(f(x), C2 + C3*exp(-x) - exp(-C1 + x)/2)], + 'slow': True, + }, + + # https://github.com/sympy/sympy/issues/18408 + 'undet_30': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - sinh(x), + 'sol': [Eq(f(x), C1 + C2*exp(-x) + C3*exp(x) + x*sinh(x)/2)], + }, + + 'undet_31': { + 'eq': f(x).diff(x, 2) - 49*f(x) - sinh(3*x), + 'sol': [Eq(f(x), C1*exp(-7*x) + C2*exp(7*x) - sinh(3*x)/40)], + }, + + 'undet_32': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - sinh(x) - exp(x), + 'sol': [Eq(f(x), C1 + C3*exp(-x) + x*sinh(x)/2 + (C2 + x/2)*exp(x))], + }, + + # https://github.com/sympy/sympy/issues/5096 + 'undet_33': { + 'eq': f(x).diff(x, x) + f(x) - x*sin(x - 2), + 'sol': [Eq(f(x), C1*sin(x) + C2*cos(x) - x**2*cos(x - 2)/4 + x*sin(x - 2)/4)], + }, + + 'undet_34': { + 'eq': f(x).diff(x, 2) + f(x) - x**4*sin(x-1), + 'sol': [ Eq(f(x), C1*sin(x) + C2*cos(x) - x**5*cos(x - 1)/10 + x**4*sin(x - 1)/4 + x**3*cos(x - 1)/2 - 3*x**2*sin(x - 1)/4 - 3*x*cos(x - 1)/4)], + }, + + 'undet_35': { + 'eq': f(x).diff(x, 2) - f(x) - exp(x - 1), + 'sol': [Eq(f(x), C2*exp(-x) + (C1 + x*exp(-1)/2)*exp(x))], + }, + + 'undet_36': { + 'eq': f(x).diff(x, 2)+f(x)-(sin(x-2)+1), + 'sol': [Eq(f(x), C1*sin(x) + C2*cos(x) - x*cos(x - 2)/2 + 1)], + }, + + # Equivalent to example_name 'undet_26'. + # This previously failed because the algorithm for undetermined coefficients + # didn't know to multiply exp(I*x) by sufficient x because it is linearly + # dependent on sin(x) and cos(x). + 'undet_37': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x), + 'sol': [Eq(f(x), C1 + x**2*(I*exp(I*x)/8 + 1) + (C2 + C3*x)*sin(x) + (C4 + C5*x)*cos(x))], + }, + + # https://github.com/sympy/sympy/issues/12623 + 'undet_38': { + 'eq': Eq( u(t).diff(t,t) + R /L*u(t).diff(t) + 1/(L*C)*u(t), alpha), + 'sol': [Eq(u(t), C*L*alpha + C2*exp(-t*(R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + + C1*exp(t*(-R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)))], + 'func': u(t) + }, + + 'undet_39': { + 'eq': Eq( L*C*u(t).diff(t,t) + R*C*u(t).diff(t) + u(t), E_0*exp(I*omega*t) ), + 'sol': [Eq(u(t), C2*exp(-t*(R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + + C1*exp(t*(-R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + - E_0*exp(I*omega*t)/(C*L*omega**2 - I*C*R*omega - 1))], + 'func': u(t), + }, + + # https://github.com/sympy/sympy/issues/6879 + 'undet_40': { + 'eq': Eq(Derivative(f(x), x, 2) - 2*Derivative(f(x), x) + f(x), sin(x)), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(x) + cos(x)/2)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_separable(): + # test_separable1-5 are from Ordinary Differential Equations, Tenenbaum and + # Pollard, pg. 55 + t,a = symbols('a,t') + m = 96 + g = 9.8 + k = .2 + f1 = g * m + v = Function('v') + return { + 'hint': "separable", + 'func': f(x), + 'examples':{ + 'separable_01': { + 'eq': f(x).diff(x) - f(x), + 'sol': [Eq(f(x), C1*exp(x))], + }, + + 'separable_02': { + 'eq': x*f(x).diff(x) - f(x), + 'sol': [Eq(f(x), C1*x)], + }, + + 'separable_03': { + 'eq': f(x).diff(x) + sin(x), + 'sol': [Eq(f(x), C1 + cos(x))], + }, + + 'separable_04': { + 'eq': f(x)**2 + 1 - (x**2 + 1)*f(x).diff(x), + 'sol': [Eq(f(x), tan(C1 + atan(x)))], + }, + + 'separable_05': { + 'eq': f(x).diff(x)/tan(x) - f(x) - 2, + 'sol': [Eq(f(x), C1/cos(x) - 2)], + }, + + 'separable_06': { + 'eq': f(x).diff(x) * (1 - sin(f(x))) - 1, + 'sol': [Eq(-x + f(x) + cos(f(x)), C1)], + }, + + 'separable_07': { + 'eq': f(x)*x**2*f(x).diff(x) - f(x)**3 - 2*x**2*f(x).diff(x), + 'sol': [Eq(f(x), (-x - sqrt(x*(4*C1*x + x - 4)))/(C1*x - 1)/2), + Eq(f(x), (-x + sqrt(x*(4*C1*x + x - 4)))/(C1*x - 1)/2)], + 'slow': True, + }, + + 'separable_08': { + 'eq': f(x)**2 - 1 - (2*f(x) + x*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -sqrt(C1*x**2 + 4*C1*x + 4*C1 + 1)), + Eq(f(x), sqrt(C1*x**2 + 4*C1*x + 4*C1 + 1))], + 'slow': True, + }, + + 'separable_09': { + 'eq': x*log(x)*f(x).diff(x) + sqrt(1 + f(x)**2), + 'sol': [Eq(f(x), sinh(C1 - log(log(x))))], #One more solution is f(x)=I + 'slow': True, + 'checkodesol_XFAIL': True, + }, + + 'separable_10': { + 'eq': exp(x + 1)*tan(f(x)) + cos(f(x))*f(x).diff(x), + 'sol': [Eq(E*exp(x) + log(cos(f(x)) - 1)/2 - log(cos(f(x)) + 1)/2 + cos(f(x)), C1)], + 'slow': True, + }, + + 'separable_11': { + 'eq': (x*cos(f(x)) + x**2*sin(f(x))*f(x).diff(x) - a**2*sin(f(x))*f(x).diff(x)), + 'sol': [ + Eq(f(x), -acos(C1*sqrt(-a**2 + x**2)) + 2*pi), + Eq(f(x), acos(C1*sqrt(-a**2 + x**2))) + ], + 'slow': True, + }, + + 'separable_12': { + 'eq': f(x).diff(x) - f(x)*tan(x), + 'sol': [Eq(f(x), C1/cos(x))], + }, + + 'separable_13': { + 'eq': (x - 1)*cos(f(x))*f(x).diff(x) - 2*x*sin(f(x)), + 'sol': [ + Eq(f(x), pi - asin(C1*(x**2 - 2*x + 1)*exp(2*x))), + Eq(f(x), asin(C1*(x**2 - 2*x + 1)*exp(2*x))) + ], + }, + + 'separable_14': { + 'eq': f(x).diff(x) - f(x)*log(f(x))/tan(x), + 'sol': [Eq(f(x), exp(C1*sin(x)))], + }, + + 'separable_15': { + 'eq': x*f(x).diff(x) + (1 + f(x)**2)*atan(f(x)), + 'sol': [Eq(f(x), tan(C1/x))], #Two more solutions are f(x)=0 and f(x)=I + 'slow': True, + 'checkodesol_XFAIL': True, + }, + + 'separable_16': { + 'eq': f(x).diff(x) + x*(f(x) + 1), + 'sol': [Eq(f(x), -1 + C1*exp(-x**2/2))], + }, + + 'separable_17': { + 'eq': exp(f(x)**2)*(x**2 + 2*x + 1) + (x*f(x) + f(x))*f(x).diff(x), + 'sol': [ + Eq(f(x), -sqrt(log(1/(C1 + x**2 + 2*x)))), + Eq(f(x), sqrt(log(1/(C1 + x**2 + 2*x)))) + ], + }, + + 'separable_18': { + 'eq': f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*exp(-x))], + }, + + 'separable_19': { + 'eq': sin(x)*cos(2*f(x)) + cos(x)*sin(2*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), pi - acos(C1/cos(x)**2)/2), Eq(f(x), acos(C1/cos(x)**2)/2)], + }, + + 'separable_20': { + 'eq': (1 - x)*f(x).diff(x) - x*(f(x) + 1), + 'sol': [Eq(f(x), (C1*exp(-x) - x + 1)/(x - 1))], + }, + + 'separable_21': { + 'eq': f(x)*diff(f(x), x) + x - 3*x*f(x)**2, + 'sol': [Eq(f(x), -sqrt(3)*sqrt(C1*exp(3*x**2) + 1)/3), + Eq(f(x), sqrt(3)*sqrt(C1*exp(3*x**2) + 1)/3)], + }, + + 'separable_22': { + 'eq': f(x).diff(x) - exp(x + f(x)), + 'sol': [Eq(f(x), log(-1/(C1 + exp(x))))], + 'XFAIL': ['lie_group'] #It shows 'NoneType' object is not subscriptable for lie_group. + }, + + # https://github.com/sympy/sympy/issues/7081 + 'separable_23': { + 'eq': x*(f(x).diff(x)) + 1 - f(x)**2, + 'sol': [Eq(f(x), (-C1 - x**2)/(-C1 + x**2))], + }, + + # https://github.com/sympy/sympy/issues/10379 + 'separable_24': { + 'eq': f(t).diff(t)-(1-51.05*y*f(t)), + 'sol': [Eq(f(t), (0.019588638589618023*exp(y*(C1 - 51.049999999999997*t)) + 0.019588638589618023)/y)], + 'func': f(t), + }, + + # https://github.com/sympy/sympy/issues/15999 + 'separable_25': { + 'eq': f(x).diff(x) - C1*f(x), + 'sol': [Eq(f(x), C2*exp(C1*x))], + }, + + 'separable_26': { + 'eq': f1 - k * (v(t) ** 2) - m * Derivative(v(t)), + 'sol': [Eq(v(t), -68.585712797928991/tanh(C1 - 0.14288690166235204*t))], + 'func': v(t), + 'checkodesol_XFAIL': True, + }, + + #https://github.com/sympy/sympy/issues/22155 + 'separable_27': { + 'eq': f(x).diff(x) - exp(f(x) - x), + 'sol': [Eq(f(x), log(-exp(x)/(C1*exp(x) - 1)))], + } + } + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_exact(): + # Type: Exact differential equation, p(x,f) + q(x,f)*f' == 0, + # where dp/df == dq/dx + ''' + Example 7 is an exact equation that fails under the exact engine. It is caught + by first order homogeneous albeit with a much contorted solution. The + exact engine fails because of a poorly simplified integral of q(0,y)dy, + where q is the function multiplying f'. The solutions should be + Eq(sqrt(x**2+f(x)**2)**3+y**3, C1). The equation below is + equivalent, but it is so complex that checkodesol fails, and takes a long + time to do so. + ''' + return { + 'hint': "1st_exact", + 'func': f(x), + 'examples':{ + '1st_exact_01': { + 'eq': sin(x)*cos(f(x)) + cos(x)*sin(f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))], + 'slow': True, + }, + + '1st_exact_02': { + 'eq': (2*x*f(x) + 1)/f(x) + (f(x) - x)/f(x)**2*f(x).diff(x), + 'sol': [Eq(f(x), exp(C1 - x**2 + LambertW(-x*exp(-C1 + x**2))))], + 'XFAIL': ['lie_group'], #It shows dsolve raises an exception: List index out of range for lie_group + 'slow': True, + 'checkodesol_XFAIL':True + }, + + '1st_exact_03': { + 'eq': 2*x + f(x)*cos(x) + (2*f(x) + sin(x) - sin(f(x)))*f(x).diff(x), + 'sol': [Eq(f(x)*sin(x) + cos(f(x)) + x**2 + f(x)**2, C1)], + 'XFAIL': ['lie_group'], #It goes into infinite loop for lie_group. + 'slow': True, + }, + + '1st_exact_04': { + 'eq': cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + 'sol': [Eq(x*cos(f(x)) + f(x)**3/3, C1)], + 'slow': True, + }, + + '1st_exact_05': { + 'eq': 2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), + 'sol': [Eq(x**2*f(x) + f(x)**3/3, C1)], + 'slow': True, + 'simplify_flag':False + }, + + # This was from issue: https://github.com/sympy/sympy/issues/11290 + '1st_exact_06': { + 'eq': cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + 'sol': [Eq(x*cos(f(x)) + f(x)**3/3, C1)], + 'simplify_flag':False + }, + + '1st_exact_07': { + 'eq': x*sqrt(x**2 + f(x)**2) - (x**2*f(x)/(f(x) - sqrt(x**2 + f(x)**2)))*f(x).diff(x), + 'sol': [Eq(log(x), + C1 - 9*sqrt(1 + f(x)**2/x**2)*asinh(f(x)/x)/(-27*f(x)/x + + 27*sqrt(1 + f(x)**2/x**2)) - 9*sqrt(1 + f(x)**2/x**2)* + log(1 - sqrt(1 + f(x)**2/x**2)*f(x)/x + 2*f(x)**2/x**2)/ + (-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2)) + + 9*asinh(f(x)/x)*f(x)/(x*(-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2))) + + 9*f(x)*log(1 - sqrt(1 + f(x)**2/x**2)*f(x)/x + 2*f(x)**2/x**2)/ + (x*(-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2))))], + 'slow': True, + 'dsolve_too_slow':True + }, + + # Type: a(x)f'(x)+b(x)*f(x)+c(x)=0 + '1st_exact_08': { + 'eq': Eq(x**2*f(x).diff(x) + 3*x*f(x) - sin(x)/x, 0), + 'sol': [Eq(f(x), (C1 - cos(x))/x**3)], + }, + + # these examples are from test_exact_enhancement + '1st_exact_09': { + 'eq': f(x)/x**2 + ((f(x)*x - 1)/x)*f(x).diff(x), + 'sol': [Eq(f(x), (i*sqrt(C1*x**2 + 1) + 1)/x) for i in (-1, 1)], + }, + + '1st_exact_10': { + 'eq': (x*f(x) - 1) + f(x).diff(x)*(x**2 - x*f(x)), + 'sol': [Eq(f(x), x - sqrt(C1 + x**2 - 2*log(x))), Eq(f(x), x + sqrt(C1 + x**2 - 2*log(x)))], + }, + + '1st_exact_11': { + 'eq': (x + 2)*sin(f(x)) + f(x).diff(x)*x*cos(f(x)), + 'sol': [Eq(f(x), -asin(C1*exp(-x)/x**2) + pi), Eq(f(x), asin(C1*exp(-x)/x**2))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_var_of_parameters(): + g = exp(-x) + f2 = f(x).diff(x, 2) + c = 3*f(x).diff(x, 3) + 5*f2 + f(x).diff(x) - f(x) - x + return { + 'hint': "nth_linear_constant_coeff_variation_of_parameters", + 'func': f(x), + 'examples':{ + 'var_of_parameters_01': { + 'eq': c - x*g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x**2/24 - 3*x/32))*exp(-x) - 1)], + 'slow': True, + }, + + 'var_of_parameters_02': { + 'eq': c - g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x/8))*exp(-x) - 1)], + 'slow': True, + }, + + 'var_of_parameters_03': { + 'eq': f(x).diff(x) - 1, + 'sol': [Eq(f(x), C1 + x)], + 'slow': True, + }, + + 'var_of_parameters_04': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 4, + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2)], + 'slow': True, + }, + + 'var_of_parameters_05': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 12*exp(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2*exp(x))], + 'slow': True, + }, + + 'var_of_parameters_06': { + 'eq': f2 - 2*f(x).diff(x) - 8*f(x) - 9*x*exp(x) - 10*exp(-x), + 'sol': [Eq(f(x), -x*exp(x) - 2*exp(-x) + C1*exp(-2*x) + C2*exp(4*x))], + 'slow': True, + }, + + 'var_of_parameters_07': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x**3/12))*exp(-x))], + 'slow': True, + }, + + 'var_of_parameters_08': { + 'eq': f2 - 3*f(x).diff(x) + 2*f(x) - x*exp(-x), + 'sol': [Eq(f(x), C1*exp(x) + C2*exp(2*x) + (6*x + 5)*exp(-x)/36)], + 'slow': True, + }, + + 'var_of_parameters_09': { + 'eq': f(x).diff(x, 3) - 3*f2 + 3*f(x).diff(x) - f(x) - exp(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 + x/6)))*exp(x))], + 'slow': True, + }, + + 'var_of_parameters_10': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - exp(-x)/x, + 'sol': [Eq(f(x), (C1 + x*(C2 + log(x)))*exp(-x))], + 'slow': True, + }, + + 'var_of_parameters_11': { + 'eq': f2 + f(x) - 1/sin(x)*1/cos(x), + 'sol': [Eq(f(x), (C1 + log(sin(x) - 1)/2 - log(sin(x) + 1)/2 + )*cos(x) + (C2 + log(cos(x) - 1)/2 - log(cos(x) + 1)/2)*sin(x))], + 'slow': True, + }, + + 'var_of_parameters_12': { + 'eq': f(x).diff(x, 4) - 1/x, + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + x**3*(C4 + log(x)/6))], + 'slow': True, + }, + + # These were from issue: https://github.com/sympy/sympy/issues/15996 + 'var_of_parameters_13': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x), + 'sol': [Eq(f(x), C1 + x**2 + (C2 + x*(C3 - x/8 + 3*exp(I*x)/2 + 3*exp(-I*x)/2) + 5*exp(2*I*x)/16 + 2*I*exp(I*x) - 2*I*exp(-I*x))*sin(x) + (C4 + x*(C5 + I*x/8 + 3*I*exp(I*x)/2 - 3*I*exp(-I*x)/2) + + 5*I*exp(2*I*x)/16 - 2*exp(I*x) - 2*exp(-I*x))*cos(x) - I*exp(I*x))], + }, + + 'var_of_parameters_14': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - exp(I*x), + 'sol': [Eq(f(x), C1 + (C2 + x*(C3 - x/8) + 5*exp(2*I*x)/16)*sin(x) + (C4 + x*(C5 + I*x/8) + 5*I*exp(2*I*x)/16)*cos(x) - I*exp(I*x))], + }, + + # https://github.com/sympy/sympy/issues/14395 + 'var_of_parameters_15': { + 'eq': Derivative(f(x), x, x) + 9*f(x) - sec(x), + 'sol': [Eq(f(x), (C1 - x/3 + sin(2*x)/3)*sin(3*x) + (C2 + log(cos(x)) + - 2*log(cos(x)**2)/3 + 2*cos(x)**2/3)*cos(3*x))], + 'slow': True, + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_linear_bessel(): + return { + 'hint': "2nd_linear_bessel", + 'func': f(x), + 'examples':{ + '2nd_lin_bessel_01': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - 4)*f(x), + 'sol': [Eq(f(x), C1*besselj(2, x) + C2*bessely(2, x))], + }, + + '2nd_lin_bessel_02': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 +25)*f(x), + 'sol': [Eq(f(x), C1*besselj(5*I, x) + C2*bessely(5*I, x))], + }, + + '2nd_lin_bessel_03': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2)*f(x), + 'sol': [Eq(f(x), C1*besselj(0, x) + C2*bessely(0, x))], + }, + + '2nd_lin_bessel_04': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (81*x**2 -S(1)/9)*f(x), + 'sol': [Eq(f(x), C1*besselj(S(1)/3, 9*x) + C2*bessely(S(1)/3, 9*x))], + }, + + '2nd_lin_bessel_05': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**4 - 4)*f(x), + 'sol': [Eq(f(x), C1*besselj(1, x**2/2) + C2*bessely(1, x**2/2))], + }, + + '2nd_lin_bessel_06': { + 'eq': x**2*(f(x).diff(x, 2)) + 2*x*(f(x).diff(x)) + (x**4 - 4)*f(x), + 'sol': [Eq(f(x), (C1*besselj(sqrt(17)/4, x**2/2) + C2*bessely(sqrt(17)/4, x**2/2))/sqrt(x))], + }, + + '2nd_lin_bessel_07': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - S(1)/4)*f(x), + 'sol': [Eq(f(x), C1*besselj(S(1)/2, x) + C2*bessely(S(1)/2, x))], + }, + + '2nd_lin_bessel_08': { + 'eq': x**2*(f(x).diff(x, 2)) - 3*x*(f(x).diff(x)) + (4*x + 4)*f(x), + 'sol': [Eq(f(x), x**2*(C1*besselj(0, 4*sqrt(x)) + C2*bessely(0, 4*sqrt(x))))], + }, + + '2nd_lin_bessel_09': { + 'eq': x*(f(x).diff(x, 2)) - f(x).diff(x) + 4*x**3*f(x), + 'sol': [Eq(f(x), x*(C1*besselj(S(1)/2, x**2) + C2*bessely(S(1)/2, x**2)))], + }, + + '2nd_lin_bessel_10': { + 'eq': (x-2)**2*(f(x).diff(x, 2)) - (x-2)*f(x).diff(x) + 4*(x-2)**2*f(x), + 'sol': [Eq(f(x), (x - 2)*(C1*besselj(1, 2*x - 4) + C2*bessely(1, 2*x - 4)))], + }, + + # https://github.com/sympy/sympy/issues/4414 + '2nd_lin_bessel_11': { + 'eq': f(x).diff(x, x) + 2/x*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*besselj(S(1)/2, x) + C2*bessely(S(1)/2, x))/sqrt(x))], + }, + '2nd_lin_bessel_12': { + 'eq': x**2*f(x).diff(x, 2) + x*f(x).diff(x) + (a**2*x**2/c**2 - b**2)*f(x), + 'sol': [Eq(f(x), C1*besselj(sqrt(b**2), x*sqrt(a**2/c**2)) + C2*bessely(sqrt(b**2), x*sqrt(a**2/c**2)))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_2F1_hypergeometric(): + return { + 'hint': "2nd_hypergeometric", + 'func': f(x), + 'examples':{ + '2nd_2F1_hyper_01': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(3)/2 -2*x)*f(x).diff(x) + 2*f(x), + 'sol': [Eq(f(x), C1*x**(S(5)/2)*hyper((S(3)/2, S(1)/2), (S(7)/2,), x) + C2*hyper((-1, -2), (-S(3)/2,), x))], + }, + + '2nd_2F1_hyper_02': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(7)/2*x)*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*(1 - x)**(S(5)/2)*hyper((S(1)/2, 2), (S(7)/2,), 1 - x) + + C2*hyper((-S(1)/2, -2), (-S(3)/2,), 1 - x))/(x - 1)**(S(5)/2))], + }, + + '2nd_2F1_hyper_03': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(3)+ S(7)/2*x)*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*(1 - x)**(S(11)/2)*hyper((S(1)/2, 2), (S(13)/2,), 1 - x) + + C2*hyper((-S(7)/2, -5), (-S(9)/2,), 1 - x))/(x - 1)**(S(11)/2))], + }, + + '2nd_2F1_hyper_04': { + 'eq': -x**(S(5)/7)*(-416*x**(S(9)/7)/9 - 2385*x**(S(5)/7)/49 + S(298)*x/3)*f(x)/(196*(-x**(S(6)/7) + + x)**2*(x**(S(6)/7) + x)**2) + Derivative(f(x), (x, 2)), + 'sol': [Eq(f(x), x**(S(45)/98)*(C1*x**(S(4)/49)*hyper((S(1)/3, -S(1)/2), (S(9)/7,), x**(S(2)/7)) + + C2*hyper((S(1)/21, -S(11)/14), (S(5)/7,), x**(S(2)/7)))/(x**(S(2)/7) - 1)**(S(19)/84))], + 'checkodesol_XFAIL':True, + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_2nd_nonlinear_autonomous_conserved(): + return { + 'hint': "2nd_nonlinear_autonomous_conserved", + 'func': f(x), + 'examples': { + '2nd_nonlinear_autonomous_conserved_01': { + 'eq': f(x).diff(x, 2) + exp(f(x)) + log(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*_u*log(_u) + 2*_u - 2*exp(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*_u*log(_u) + 2*_u - 2*exp(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_02': { + 'eq': f(x).diff(x, 2) + cbrt(f(x)) + 1/f(x), + 'sol': [ + Eq(sqrt(2)*Integral(1/sqrt(2*C1 - 3*_u**Rational(4, 3) - 4*log(_u)), (_u, f(x))), C2 + x), + Eq(sqrt(2)*Integral(1/sqrt(2*C1 - 3*_u**Rational(4, 3) - 4*log(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_03': { + 'eq': f(x).diff(x, 2) + sin(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 + 2*cos(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 + 2*cos(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_04': { + 'eq': f(x).diff(x, 2) + cosh(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*sinh(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*sinh(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_05': { + 'eq': f(x).diff(x, 2) + asin(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*_u*asin(_u) - 2*sqrt(1 - _u**2)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*_u*asin(_u) - 2*sqrt(1 - _u**2)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + 'XFAIL': ['2nd_nonlinear_autonomous_conserved_Integral'] + } + } + } + + +@_add_example_keys +def _get_examples_ode_sol_separable_reduced(): + df = f(x).diff(x) + return { + 'hint': "separable_reduced", + 'func': f(x), + 'examples':{ + 'separable_reduced_01': { + 'eq': x* df + f(x)* (1 / (x**2*f(x) - 1)), + 'sol': [Eq(log(x**2*f(x))/3 + log(x**2*f(x) - Rational(3, 2))/6, C1 + log(x))], + 'simplify_flag': False, + 'XFAIL': ['lie_group'], #It hangs. + }, + + #Note: 'separable_reduced_02' is referred in 'separable_reduced_11' + 'separable_reduced_02': { + 'eq': f(x).diff(x) + (f(x) / (x**4*f(x) - x)), + 'sol': [Eq(log(x**3*f(x))/4 + log(x**3*f(x) - Rational(4,3))/12, C1 + log(x))], + 'simplify_flag': False, + 'checkodesol_XFAIL':True, #It hangs for this. + }, + + 'separable_reduced_03': { + 'eq': x*df + f(x)*(x**2*f(x)), + 'sol': [Eq(log(x**2*f(x))/2 - log(x**2*f(x) - 2)/2, C1 + log(x))], + 'simplify_flag': False, + }, + + 'separable_reduced_04': { + 'eq': Eq(f(x).diff(x) + f(x)/x * (1 + (x**(S(2)/3)*f(x))**2), 0), + 'sol': [Eq(-3*log(x**(S(2)/3)*f(x)) + 3*log(3*x**(S(4)/3)*f(x)**2 + 1)/2, C1 + log(x))], + 'simplify_flag': False, + }, + + 'separable_reduced_05': { + 'eq': Eq(f(x).diff(x) + f(x)/x * (1 + (x*f(x))**2), 0), + 'sol': [Eq(f(x), -sqrt(2)*sqrt(1/(C1 + log(x)))/(2*x)),\ + Eq(f(x), sqrt(2)*sqrt(1/(C1 + log(x)))/(2*x))], + }, + + 'separable_reduced_06': { + 'eq': Eq(f(x).diff(x) + (x**4*f(x)**2 + x**2*f(x))*f(x)/(x*(x**6*f(x)**3 + x**4*f(x)**2)), 0), + 'sol': [Eq(f(x), C1 + 1/(2*x**2))], + }, + + 'separable_reduced_07': { + 'eq': Eq(f(x).diff(x) + (f(x)**2)*f(x)/(x), 0), + 'sol': [ + Eq(f(x), -sqrt(2)*sqrt(1/(C1 + log(x)))/2), + Eq(f(x), sqrt(2)*sqrt(1/(C1 + log(x)))/2) + ], + }, + + 'separable_reduced_08': { + 'eq': Eq(f(x).diff(x) + (f(x)+3)*f(x)/(x*(f(x)+2)), 0), + 'sol': [Eq(-log(f(x) + 3)/3 - 2*log(f(x))/3, C1 + log(x))], + 'simplify_flag': False, + 'XFAIL': ['lie_group'], #It hangs. + }, + + 'separable_reduced_09': { + 'eq': Eq(f(x).diff(x) + (f(x)+3)*f(x)/x, 0), + 'sol': [Eq(f(x), 3/(C1*x**3 - 1))], + }, + + 'separable_reduced_10': { + 'eq': Eq(f(x).diff(x) + (f(x)**2+f(x))*f(x)/(x), 0), + 'sol': [Eq(- log(x) - log(f(x) + 1) + log(f(x)) + 1/f(x), C1)], + 'XFAIL': ['lie_group'],#No algorithms are implemented to solve equation -C1 + x*(_y + 1)*exp(-1/_y)/_y + + }, + + # Equivalent to example_name 'separable_reduced_02'. Only difference is testing with simplify=True + 'separable_reduced_11': { + 'eq': f(x).diff(x) + (f(x) / (x**4*f(x) - x)), + 'sol': [Eq(f(x), -sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 +- sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 +- 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), -sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 ++ sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 +- 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 +- sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 4/x**6 + 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 ++ sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) ++ x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 + 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) +- exp(12*C1)/x**6)**Rational(1,3) - 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3))], + 'checkodesol_XFAIL':True, #It hangs for this. + 'slow': True, + }, + + #These were from issue: https://github.com/sympy/sympy/issues/6247 + 'separable_reduced_12': { + 'eq': x**2*f(x)**2 + x*Derivative(f(x), x), + 'sol': [Eq(f(x), 2*C1/(C1*x**2 - 1))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_lie_group(): + a, b, c = symbols("a b c") + return { + 'hint': "lie_group", + 'func': f(x), + 'examples':{ + #Example 1-4 and 19-20 were from issue: https://github.com/sympy/sympy/issues/17322 + 'lie_group_01': { + 'eq': x*f(x).diff(x)*(f(x)+4) + (f(x)**2) -2*f(x)-2*x, + 'sol': [], + 'dsolve_too_slow': True, + 'checkodesol_too_slow': True, + }, + + 'lie_group_02': { + 'eq': x*f(x).diff(x)*(f(x)+4) + (f(x)**2) -2*f(x)-2*x, + 'sol': [], + 'dsolve_too_slow': True, + }, + + 'lie_group_03': { + 'eq': Eq(x**7*Derivative(f(x), x) + 5*x**3*f(x)**2 - (2*x**2 + 2)*f(x)**3, 0), + 'sol': [], + 'dsolve_too_slow': True, + }, + + 'lie_group_04': { + 'eq': f(x).diff(x) - (f(x) - x*log(x))**2/x**2 + log(x), + 'sol': [], + 'XFAIL': ['lie_group'], + }, + + 'lie_group_05': { + 'eq': f(x).diff(x)**2, + 'sol': [Eq(f(x), C1)], + 'XFAIL': ['factorable'], #It raises Not Implemented error + }, + + 'lie_group_06': { + 'eq': Eq(f(x).diff(x), x**2*f(x)), + 'sol': [Eq(f(x), C1*exp(x**3)**Rational(1, 3))], + }, + + 'lie_group_07': { + 'eq': f(x).diff(x) + a*f(x) - c*exp(b*x), + 'sol': [Eq(f(x), Piecewise(((-C1*(a + b) + c*exp(x*(a + b)))*exp(-a*x)/(a + b),\ + Ne(a, -b)), ((-C1 + c*x)*exp(-a*x), True)))], + }, + + 'lie_group_08': { + 'eq': f(x).diff(x) + 2*x*f(x) - x*exp(-x**2), + 'sol': [Eq(f(x), (C1 + x**2/2)*exp(-x**2))], + }, + + 'lie_group_09': { + 'eq': (1 + 2*x)*(f(x).diff(x)) + 2 - 4*exp(-f(x)), + 'sol': [Eq(f(x), log(C1/(2*x + 1) + 2))], + }, + + 'lie_group_10': { + 'eq': x**2*(f(x).diff(x)) - f(x) + x**2*exp(x - (1/x)), + 'sol': [Eq(f(x), (C1 - exp(x))*exp(-1/x))], + 'XFAIL': ['factorable'], #It raises Recursion Error (maixmum depth exceeded) + }, + + 'lie_group_11': { + 'eq': x**2*f(x)**2 + x*Derivative(f(x), x), + 'sol': [Eq(f(x), 2/(C1 + x**2))], + }, + + 'lie_group_12': { + 'eq': diff(f(x),x) + 2*x*f(x) - x*exp(-x**2), + 'sol': [Eq(f(x), exp(-x**2)*(C1 + x**2/2))], + }, + + 'lie_group_13': { + 'eq': diff(f(x),x) + f(x)*cos(x) - exp(2*x), + 'sol': [Eq(f(x), exp(-sin(x))*(C1 + Integral(exp(2*x)*exp(sin(x)), x)))], + }, + + 'lie_group_14': { + 'eq': diff(f(x),x) + f(x)*cos(x) - sin(2*x)/2, + 'sol': [Eq(f(x), C1*exp(-sin(x)) + sin(x) - 1)], + }, + + 'lie_group_15': { + 'eq': x*diff(f(x),x) + f(x) - x*sin(x), + 'sol': [Eq(f(x), (C1 - x*cos(x) + sin(x))/x)], + }, + + 'lie_group_16': { + 'eq': x*diff(f(x),x) - f(x) - x/log(x), + 'sol': [Eq(f(x), x*(C1 + log(log(x))))], + }, + + 'lie_group_17': { + 'eq': (f(x).diff(x)-f(x)) * (f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1*exp(-x))], + }, + + 'lie_group_18': { + 'eq': f(x).diff(x) * (f(x).diff(x) - f(x)), + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1)], + }, + + 'lie_group_19': { + 'eq': (f(x).diff(x)-f(x)) * (f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1*exp(-x)), Eq(f(x), C1*exp(x))], + }, + + 'lie_group_20': { + 'eq': f(x).diff(x)*(f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1), Eq(f(x), C1*exp(-x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_linear_airy(): + return { + 'hint': "2nd_linear_airy", + 'func': f(x), + 'examples':{ + '2nd_lin_airy_01': { + 'eq': f(x).diff(x, 2) - x*f(x), + 'sol': [Eq(f(x), C1*airyai(x) + C2*airybi(x))], + }, + + '2nd_lin_airy_02': { + 'eq': f(x).diff(x, 2) + 2*x*f(x), + 'sol': [Eq(f(x), C1*airyai(-2**(S(1)/3)*x) + C2*airybi(-2**(S(1)/3)*x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_constant_coeff_homogeneous(): + # From Exercise 20, in Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 220 + a = Symbol('a', positive=True) + k = Symbol('k', real=True) + r1, r2, r3, r4, r5 = [rootof(x**5 + 11*x - 2, n) for n in range(5)] + r6, r7, r8, r9, r10 = [rootof(x**5 - 3*x + 1, n) for n in range(5)] + r11, r12, r13, r14, r15 = [rootof(x**5 - 100*x**3 + 1000*x + 1, n) for n in range(5)] + r16, r17, r18, r19, r20 = [rootof(x**5 - x**4 + 10, n) for n in range(5)] + r21, r22, r23, r24, r25 = [rootof(x**5 - x + 1, n) for n in range(5)] + E = exp(1) + return { + 'hint': "nth_linear_constant_coeff_homogeneous", + 'func': f(x), + 'examples':{ + 'lin_const_coeff_hom_01': { + 'eq': f(x).diff(x, 2) + 2*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x))], + }, + + 'lin_const_coeff_hom_02': { + 'eq': f(x).diff(x, 2) - 3*f(x).diff(x) + 2*f(x), + 'sol': [Eq(f(x), (C1 + C2*exp(x))*exp(x))], + }, + + 'lin_const_coeff_hom_03': { + 'eq': f(x).diff(x, 2) - f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(x))], + }, + + 'lin_const_coeff_hom_04': { + 'eq': f(x).diff(x, 3) + f(x).diff(x, 2) - 6*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-3*x) + C3*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_05': { + 'eq': 6*f(x).diff(x, 2) - 11*f(x).diff(x) + 4*f(x), + 'sol': [Eq(f(x), C1*exp(x/2) + C2*exp(x*Rational(4, 3)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_06': { + 'eq': Eq(f(x).diff(x, 2) + 2*f(x).diff(x) - f(x), 0), + 'sol': [Eq(f(x), C1*exp(x*(-1 + sqrt(2))) + C2*exp(-x*(sqrt(2) + 1)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_07': { + 'eq': diff(f(x), x, 3) + diff(f(x), x, 2) - 10*diff(f(x), x) - 6*f(x), + 'sol': [Eq(f(x), C1*exp(3*x) + C3*exp(-x*(2 + sqrt(2))) + C2*exp(x*(-2 + sqrt(2))))], + 'slow': True, + }, + + 'lin_const_coeff_hom_08': { + 'eq': f(x).diff(x, 4) - f(x).diff(x, 3) - 4*f(x).diff(x, 2) + \ + 4*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x) + C3*exp(x) + C4*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_09': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 3) + f(x).diff(x, 2) - \ + 4*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), C3*exp(-x) + C4*exp(x) + (C1*exp(-sqrt(2)*x) + C2*exp(sqrt(2)*x))*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_10': { + 'eq': f(x).diff(x, 4) - a**2*f(x), + 'sol': [Eq(f(x), C1*exp(-sqrt(a)*x) + C2*exp(sqrt(a)*x) + C3*sin(sqrt(a)*x) + C4*cos(sqrt(a)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_11': { + 'eq': f(x).diff(x, 2) - 2*k*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), C1*exp(x*(k - sqrt(k**2 + 2))) + C2*exp(x*(k + sqrt(k**2 + 2))))], + 'slow': True, + }, + + 'lin_const_coeff_hom_12': { + 'eq': f(x).diff(x, 2) + 4*k*f(x).diff(x) - 12*k**2*f(x), + 'sol': [Eq(f(x), C1*exp(-6*k*x) + C2*exp(2*k*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_13': { + 'eq': f(x).diff(x, 4), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*x**3)], + 'slow': True, + }, + + 'lin_const_coeff_hom_14': { + 'eq': f(x).diff(x, 2) + 4*f(x).diff(x) + 4*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_15': { + 'eq': 3*f(x).diff(x, 3) + 5*f(x).diff(x, 2) + f(x).diff(x) - f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-x) + C3*exp(x/3))], + 'slow': True, + }, + + 'lin_const_coeff_hom_16': { + 'eq': f(x).diff(x, 3) - 6*f(x).diff(x, 2) + 12*f(x).diff(x) - 8*f(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + C3*x))*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_17': { + 'eq': f(x).diff(x, 2) - 2*a*f(x).diff(x) + a**2*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(a*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_18': { + 'eq': f(x).diff(x, 4) + 3*f(x).diff(x, 3), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*exp(-3*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_19': { + 'eq': f(x).diff(x, 4) - 2*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*exp(-sqrt(2)*x) + C4*exp(sqrt(2)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_20': { + 'eq': f(x).diff(x, 4) + 2*f(x).diff(x, 3) - 11*f(x).diff(x, 2) - \ + 12*f(x).diff(x) + 36*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-3*x) + (C3 + C4*x)*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_21': { + 'eq': 36*f(x).diff(x, 4) - 37*f(x).diff(x, 2) + 4*f(x).diff(x) + 5*f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(-x/3) + C3*exp(x/2) + C4*exp(x*Rational(5, 6)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_22': { + 'eq': f(x).diff(x, 4) - 8*f(x).diff(x, 2) + 16*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-2*x) + (C3 + C4*x)*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_23': { + 'eq': f(x).diff(x, 2) - 2*f(x).diff(x) + 5*f(x), + 'sol': [Eq(f(x), (C1*sin(2*x) + C2*cos(2*x))*exp(x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_24': { + 'eq': f(x).diff(x, 2) - f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(3)/2) + C2*cos(x*sqrt(3)/2))*exp(x/2))], + 'slow': True, + }, + + 'lin_const_coeff_hom_25': { + 'eq': f(x).diff(x, 4) + 5*f(x).diff(x, 2) + 6*f(x), + 'sol': [Eq(f(x), + C1*sin(sqrt(2)*x) + C2*sin(sqrt(3)*x) + C3*cos(sqrt(2)*x) + C4*cos(sqrt(3)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_26': { + 'eq': f(x).diff(x, 2) - 4*f(x).diff(x) + 20*f(x), + 'sol': [Eq(f(x), (C1*sin(4*x) + C2*cos(4*x))*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_27': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2) + 4*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*sin(x*sqrt(2)) + (C3 + C4*x)*cos(x*sqrt(2)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_28': { + 'eq': f(x).diff(x, 3) + 8*f(x), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(3)) + C2*cos(x*sqrt(3)))*exp(x) + C3*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_29': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*sin(2*x) + C4*cos(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_30': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x), + 'sol': [Eq(f(x), C1 + (C2 + C3*x)*sin(x) + (C4 + C5*x)*cos(x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_31': { + 'eq': f(x).diff(x, 4) + f(x).diff(x, 2) + f(x), + 'sol': [Eq(f(x), (C1*sin(sqrt(3)*x/2) + C2*cos(sqrt(3)*x/2))*exp(-x/2) + + (C3*sin(sqrt(3)*x/2) + C4*cos(sqrt(3)*x/2))*exp(x/2))], + 'slow': True, + }, + + 'lin_const_coeff_hom_32': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2) + f(x), + 'sol': [Eq(f(x), C1*sin(x*sqrt(-sqrt(3) + 2)) + C2*sin(x*sqrt(sqrt(3) + 2)) + + C3*cos(x*sqrt(-sqrt(3) + 2)) + C4*cos(x*sqrt(sqrt(3) + 2)))], + 'slow': True, + }, + + # One real root, two complex conjugate pairs + 'lin_const_coeff_hom_33': { + 'eq': f(x).diff(x, 5) + 11*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), + C5*exp(r1*x) + exp(re(r2)*x) * (C1*sin(im(r2)*x) + C2*cos(im(r2)*x)) + + exp(re(r4)*x) * (C3*sin(im(r4)*x) + C4*cos(im(r4)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Three real roots, one complex conjugate pair + 'lin_const_coeff_hom_34': { + 'eq': f(x).diff(x,5) - 3*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), + C3*exp(r6*x) + C4*exp(r7*x) + C5*exp(r8*x) + + exp(re(r9)*x) * (C1*sin(im(r9)*x) + C2*cos(im(r9)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Five distinct real roots + 'lin_const_coeff_hom_35': { + 'eq': f(x).diff(x,5) - 100*f(x).diff(x,3) + 1000*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*exp(r11*x) + C2*exp(r12*x) + C3*exp(r13*x) + C4*exp(r14*x) + C5*exp(r15*x))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Rational root and unsolvable quintic + 'lin_const_coeff_hom_36': { + 'eq': f(x).diff(x, 6) - 6*f(x).diff(x, 5) + 5*f(x).diff(x, 4) + 10*f(x).diff(x) - 50 * f(x), + 'sol': [Eq(f(x), + C5*exp(5*x) + + C6*exp(x*r16) + + exp(re(r17)*x) * (C1*sin(im(r17)*x) + C2*cos(im(r17)*x)) + + exp(re(r19)*x) * (C3*sin(im(r19)*x) + C4*cos(im(r19)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Five double roots (this is (x**5 - x + 1)**2) + 'lin_const_coeff_hom_37': { + 'eq': f(x).diff(x, 10) - 2*f(x).diff(x, 6) + 2*f(x).diff(x, 5) + + f(x).diff(x, 2) - 2*f(x).diff(x, 1) + f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(x*r21) + (-((C3 + C4*x)*sin(x*im(r22))) + + (C5 + C6*x)*cos(x*im(r22)))*exp(x*re(r22)) + (-((C7 + C8*x)*sin(x*im(r24))) + + (C10*x + C9)*cos(x*im(r24)))*exp(x*re(r24)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + 'lin_const_coeff_hom_38': { + 'eq': Eq(sqrt(2) * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(2**Rational(3, 4)*x/2) + C3*cos(2**Rational(3, 4)*x/2))], + }, + + 'lin_const_coeff_hom_39': { + 'eq': Eq(E * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(x/sqrt(E)) + C3*cos(x/sqrt(E)))], + }, + + 'lin_const_coeff_hom_40': { + 'eq': Eq(pi * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(x/sqrt(pi)) + C3*cos(x/sqrt(pi)))], + }, + + 'lin_const_coeff_hom_41': { + 'eq': Eq(I * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*exp(-sqrt(I)*x) + C3*exp(sqrt(I)*x))], + }, + + 'lin_const_coeff_hom_42': { + 'eq': f(x).diff(x, x) + y*f(x), + 'sol': [Eq(f(x), C1*exp(-x*sqrt(-y)) + C2*exp(x*sqrt(-y)))], + }, + + 'lin_const_coeff_hom_43': { + 'eq': Eq(9*f(x).diff(x, x) + f(x), 0), + 'sol': [Eq(f(x), C1*sin(x/3) + C2*cos(x/3))], + }, + + 'lin_const_coeff_hom_44': { + 'eq': Eq(9*f(x).diff(x, x), f(x)), + 'sol': [Eq(f(x), C1*exp(-x/3) + C2*exp(x/3))], + }, + + 'lin_const_coeff_hom_45': { + 'eq': Eq(f(x).diff(x, x) - 3*diff(f(x), x) + 2*f(x), 0), + 'sol': [Eq(f(x), (C1 + C2*exp(x))*exp(x))], + }, + + 'lin_const_coeff_hom_46': { + 'eq': Eq(f(x).diff(x, x) - 4*diff(f(x), x) + 4*f(x), 0), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(2*x))], + }, + + # Type: 2nd order, constant coefficients (two real equal roots) + 'lin_const_coeff_hom_47': { + 'eq': Eq(f(x).diff(x, x) + 2*diff(f(x), x) + 3*f(x), 0), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(2)) + C2*cos(x*sqrt(2)))*exp(-x))], + }, + + #These were from issue: https://github.com/sympy/sympy/issues/6247 + 'lin_const_coeff_hom_48': { + 'eq': f(x).diff(x, x) + 4*f(x), + 'sol': [Eq(f(x), C1*sin(2*x) + C2*cos(2*x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep(): + return { + 'hint': "1st_homogeneous_coeff_subs_dep_div_indep", + 'func': f(x), + 'examples':{ + 'dep_div_indep_01': { + 'eq': f(x)/x*cos(f(x)/x) - (x/f(x)*sin(f(x)/x) + cos(f(x)/x))*f(x).diff(x), + 'sol': [Eq(log(x), C1 - log(f(x)*sin(f(x)/x)/x))], + 'slow': True + }, + + #indep_div_dep actually has a simpler solution for example 2 but it runs too slow. + 'dep_div_indep_02': { + 'eq': x*f(x).diff(x) - f(x) - x*sin(f(x)/x), + 'sol': [Eq(log(x), log(C1) + log(cos(f(x)/x) - 1)/2 - log(cos(f(x)/x) + 1)/2)], + 'simplify_flag':False, + }, + + 'dep_div_indep_03': { + 'eq': x*exp(f(x)/x) - f(x)*sin(f(x)/x) + x*sin(f(x)/x)*f(x).diff(x), + 'sol': [Eq(log(x), C1 + exp(-f(x)/x)*sin(f(x)/x)/2 + exp(-f(x)/x)*cos(f(x)/x)/2)], + 'slow': True + }, + + 'dep_div_indep_04': { + 'eq': f(x).diff(x) - f(x)/x + 1/sin(f(x)/x), + 'sol': [Eq(f(x), x*(-acos(C1 + log(x)) + 2*pi)), Eq(f(x), x*acos(C1 + log(x)))], + 'slow': True + }, + + # previous code was testing with these other solution: + # example5_solb = Eq(f(x), log(log(C1/x)**(-x))) + 'dep_div_indep_05': { + 'eq': x*exp(f(x)/x) + f(x) - x*f(x).diff(x), + 'sol': [Eq(f(x), log((1/(C1 - log(x)))**x))], + 'checkodesol_XFAIL':True, #(because of **x?) + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_linear_coefficients(): + return { + 'hint': "linear_coefficients", + 'func': f(x), + 'examples':{ + 'linear_coeff_01': { + 'eq': f(x).diff(x) + (3 + 2*f(x))/(x + 3), + 'sol': [Eq(f(x), C1/(x**2 + 6*x + 9) - Rational(3, 2))], + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_1st_homogeneous_coeff_best(): + return { + 'hint': "1st_homogeneous_coeff_best", + 'func': f(x), + 'examples':{ + # previous code was testing this with other solution: + # example1_solb = Eq(-f(x)/(1 + log(x/f(x))), C1) + '1st_homogeneous_coeff_best_01': { + 'eq': f(x) + (x*log(f(x)/x) - 2*x)*diff(f(x), x), + 'sol': [Eq(f(x), -exp(C1)*LambertW(-x*exp(-C1 + 1)))], + 'checkodesol_XFAIL':True, #(because of LambertW?) + }, + + '1st_homogeneous_coeff_best_02': { + 'eq': 2*f(x)*exp(x/f(x)) + f(x)*f(x).diff(x) - 2*x*exp(x/f(x))*f(x).diff(x), + 'sol': [Eq(log(f(x)), C1 - 2*exp(x/f(x)))], + }, + + # previous code was testing this with other solution: + # example3_solb = Eq(log(C1*x*sqrt(1/x)*sqrt(f(x))) + x**2/(2*f(x)**2), 0) + '1st_homogeneous_coeff_best_03': { + 'eq': 2*x**2*f(x) + f(x)**3 + (x*f(x)**2 - 2*x**3)*f(x).diff(x), + 'sol': [Eq(f(x), exp(2*C1 + LambertW(-2*x**4*exp(-4*C1))/2)/x)], + 'checkodesol_XFAIL':True, #(because of LambertW?) + }, + + '1st_homogeneous_coeff_best_04': { + 'eq': (x + sqrt(f(x)**2 - x*f(x)))*f(x).diff(x) - f(x), + 'sol': [Eq(log(f(x)), C1 - 2*sqrt(-x/f(x) + 1))], + 'slow': True, + }, + + '1st_homogeneous_coeff_best_05': { + 'eq': x + f(x) - (x - f(x))*f(x).diff(x), + 'sol': [Eq(log(x), C1 - log(sqrt(1 + f(x)**2/x**2)) + atan(f(x)/x))], + }, + + '1st_homogeneous_coeff_best_06': { + 'eq': x*f(x).diff(x) - f(x) - x*sin(f(x)/x), + 'sol': [Eq(f(x), 2*x*atan(C1*x))], + }, + + '1st_homogeneous_coeff_best_07': { + 'eq': x**2 + f(x)**2 - 2*x*f(x)*f(x).diff(x), + 'sol': [Eq(f(x), -sqrt(x*(C1 + x))), Eq(f(x), sqrt(x*(C1 + x)))], + }, + + '1st_homogeneous_coeff_best_08': { + 'eq': f(x)**2 + (x*sqrt(f(x)**2 - x**2) - x*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -C1*sqrt(-x/(x - 2*C1))), Eq(f(x), C1*sqrt(-x/(x - 2*C1)))], + 'checkodesol_XFAIL': True # solutions are valid in a range + }, + } + } + + +def _get_all_examples(): + all_examples = _get_examples_ode_sol_euler_homogeneous + \ + _get_examples_ode_sol_euler_undetermined_coeff + \ + _get_examples_ode_sol_euler_var_para + \ + _get_examples_ode_sol_factorable + \ + _get_examples_ode_sol_bernoulli + \ + _get_examples_ode_sol_nth_algebraic + \ + _get_examples_ode_sol_riccati + \ + _get_examples_ode_sol_1st_linear + \ + _get_examples_ode_sol_1st_exact + \ + _get_examples_ode_sol_almost_linear + \ + _get_examples_ode_sol_nth_order_reducible + \ + _get_examples_ode_sol_nth_linear_undetermined_coefficients + \ + _get_examples_ode_sol_liouville + \ + _get_examples_ode_sol_separable + \ + _get_examples_ode_sol_1st_rational_riccati + \ + _get_examples_ode_sol_nth_linear_var_of_parameters + \ + _get_examples_ode_sol_2nd_linear_bessel + \ + _get_examples_ode_sol_2nd_2F1_hypergeometric + \ + _get_examples_ode_sol_2nd_nonlinear_autonomous_conserved + \ + _get_examples_ode_sol_separable_reduced + \ + _get_examples_ode_sol_lie_group + \ + _get_examples_ode_sol_2nd_linear_airy + \ + _get_examples_ode_sol_nth_linear_constant_coeff_homogeneous +\ + _get_examples_ode_sol_1st_homogeneous_coeff_best +\ + _get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep +\ + _get_examples_ode_sol_linear_coefficients + + return all_examples diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_subscheck.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_subscheck.py new file mode 100644 index 0000000000000000000000000000000000000000..799c2854e878208721b600767de350cda08cd7e5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_subscheck.py @@ -0,0 +1,203 @@ +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.error_functions import (Ei, erf, erfi) +from sympy.integrals.integrals import Integral + +from sympy.solvers.ode.subscheck import checkodesol, checksysodesol + +from sympy.functions import besselj, bessely + +from sympy.testing.pytest import raises, slow + + +C0, C1, C2, C3, C4 = symbols('C0:5') +u, x, y, z = symbols('u,x:z', real=True) +f = Function('f') +g = Function('g') +h = Function('h') + + +@slow +def test_checkodesol(): + # For the most part, checkodesol is well tested in the tests below. + # These tests only handle cases not checked below. + raises(ValueError, lambda: checkodesol(f(x, y).diff(x), Eq(f(x, y), x))) + raises(ValueError, lambda: checkodesol(f(x).diff(x), Eq(f(x, y), + x), f(x, y))) + assert checkodesol(f(x).diff(x), Eq(f(x, y), x)) == \ + (False, -f(x).diff(x) + f(x, y).diff(x) - 1) + assert checkodesol(f(x).diff(x), Eq(f(x), x)) is not True + assert checkodesol(f(x).diff(x), Eq(f(x), x)) == (False, 1) + sol1 = Eq(f(x)**5 + 11*f(x) - 2*f(x) + x, 0) + assert checkodesol(diff(sol1.lhs, x), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 2), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 2)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3), Eq(f(x), x*log(x))) == \ + (False, 60*x**4*((log(x) + 1)**2 + log(x))*( + log(x) + 1)*log(x)**2 - 5*x**4*log(x)**4 - 9) + assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0)) == \ + (True, 0) + assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0), + solve_for_func=False) == (True, 0) + assert checkodesol(f(x).diff(x, 2), [Eq(f(x), C1 + C2*x), + Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)]) == \ + [(True, 0), (True, 0), (False, C2)] + assert checkodesol(f(x).diff(x, 2), {Eq(f(x), C1 + C2*x), + Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)}) == \ + {(True, 0), (True, 0), (False, C2)} + assert checkodesol(f(x).diff(x) - 1/f(x)/2, Eq(f(x)**2, x)) == \ + [(True, 0), (True, 0)] + assert checkodesol(f(x).diff(x) - f(x), Eq(C1*exp(x), f(x))) == (True, 0) + # Based on test_1st_homogeneous_coeff_ode2_eq3sol. Make sure that + # checkodesol tries back substituting f(x) when it can. + eq3 = x*exp(f(x)/x) + f(x) - x*f(x).diff(x) + sol3 = Eq(f(x), log(log(C1/x)**(-x))) + assert not checkodesol(eq3, sol3)[1].has(f(x)) + # This case was failing intermittently depending on hash-seed: + eqn = Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x)) + sol = Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x)) + assert checkodesol(eqn, sol, order=2, solve_for_func=False)[0] + eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (2*x**2 +25)*f(x) + sol = Eq(f(x), C1*besselj(5*I, sqrt(2)*x) + C2*bessely(5*I, sqrt(2)*x)) + assert checkodesol(eq, sol) == (True, 0) + + eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))] + sol = [Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))] + assert checkodesol(eqs, sol) == (True, [0, 0]) + + +def test_checksysodesol(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + eq = (Eq(diff(x(t),t), 9*y(t)), Eq(diff(y(t),t), 12*x(t))) + sol = [Eq(x(t), 9*C1*exp(-6*sqrt(3)*t) + 9*C2*exp(6*sqrt(3)*t)), \ + Eq(y(t), -6*sqrt(3)*C1*exp(-6*sqrt(3)*t) + 6*sqrt(3)*C2*exp(6*sqrt(3)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 2*x(t) + 4*y(t)), Eq(diff(y(t),t), 12*x(t) + 41*y(t))) + sol = [Eq(x(t), 4*C1*exp(t*(-sqrt(1713)/2 + Rational(43, 2))) + 4*C2*exp(t*(sqrt(1713)/2 + \ + Rational(43, 2)))), Eq(y(t), C1*(-sqrt(1713)/2 + Rational(39, 2))*exp(t*(-sqrt(1713)/2 + \ + Rational(43, 2))) + C2*(Rational(39, 2) + sqrt(1713)/2)*exp(t*(sqrt(1713)/2 + Rational(43, 2))))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t)), Eq(diff(y(t),t), -2*x(t) + 2*y(t))) + sol = [Eq(x(t), (C1*sin(sqrt(7)*t/2) + C2*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2))), \ + Eq(y(t), ((C1/2 - sqrt(7)*C2/2)*sin(sqrt(7)*t/2) + (sqrt(7)*C1/2 + \ + C2/2)*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2)))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23)) + sol = [Eq(x(t), C1*exp(t*(-sqrt(6) + 3)) + C2*exp(t*(sqrt(6) + 3)) - \ + Rational(22, 3)), Eq(y(t), C1*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) + C2*(2 + \ + sqrt(6))*exp(t*(sqrt(6) + 3)) - Rational(5, 3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23)) + sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - Rational(58, 3)), \ + Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - Rational(185, 3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t))) + sol = [Eq(x(t), (C1*exp(Integral(2, t).doit()) + C2*exp(-(Integral(2, t)).doit()))*\ + exp((Integral(5*t, t)).doit())), Eq(y(t), (C1*exp((Integral(2, t)).doit()) - \ + C2*exp(-(Integral(2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + sol = [Eq(x(t), (C1*cos((Integral(t**2, t)).doit()) + C2*sin((Integral(t**2, t)).doit()))*\ + exp((Integral(5*t, t)).doit())), Eq(y(t), (-C1*sin((Integral(t**2, t)).doit()) + \ + C2*cos((Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t))) + sol = [Eq(x(t), (C1*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \ + C2*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit())), \ + Eq(y(t), (C1*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \ + C2*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 5*x(t) + 43*y(t)), Eq(diff(y(t),t,t), x(t) + 9*y(t))) + root0 = -sqrt(-sqrt(47) + 7) + root1 = sqrt(-sqrt(47) + 7) + root2 = -sqrt(sqrt(47) + 7) + root3 = sqrt(sqrt(47) + 7) + sol = [Eq(x(t), 43*C1*exp(t*root0) + 43*C2*exp(t*root1) + 43*C3*exp(t*root2) + 43*C4*exp(t*root3)), \ + Eq(y(t), C1*(root0**2 - 5)*exp(t*root0) + C2*(root1**2 - 5)*exp(t*root1) + \ + C3*(root2**2 - 5)*exp(t*root2) + C4*(root3**2 - 5)*exp(t*root3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 8*x(t)+3*y(t)+31), Eq(diff(y(t),t,t), 9*x(t)+7*y(t)+12)) + root0 = -sqrt(-sqrt(109)/2 + Rational(15, 2)) + root1 = sqrt(-sqrt(109)/2 + Rational(15, 2)) + root2 = -sqrt(sqrt(109)/2 + Rational(15, 2)) + root3 = sqrt(sqrt(109)/2 + Rational(15, 2)) + sol = [Eq(x(t), 3*C1*exp(t*root0) + 3*C2*exp(t*root1) + 3*C3*exp(t*root2) + 3*C4*exp(t*root3) - Rational(181, 29)), \ + Eq(y(t), C1*(root0**2 - 8)*exp(t*root0) + C2*(root1**2 - 8)*exp(t*root1) + \ + C3*(root2**2 - 8)*exp(t*root2) + C4*(root3**2 - 8)*exp(t*root3) + Rational(183, 29))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t) - 9*diff(y(t),t) + 7*x(t),0), Eq(diff(y(t),t,t) + 9*diff(x(t),t) + 7*y(t),0)) + sol = [Eq(x(t), C1*cos(t*(Rational(9, 2) + sqrt(109)/2)) + C2*sin(t*(Rational(9, 2) + sqrt(109)/2)) + \ + C3*cos(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*sin(t*(-sqrt(109)/2 + Rational(9, 2)))), Eq(y(t), -C1*sin(t*(Rational(9, 2) + sqrt(109)/2)) \ + + C2*cos(t*(Rational(9, 2) + sqrt(109)/2)) - C3*sin(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*cos(t*(-sqrt(109)/2 + Rational(9, 2))))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 9*t*diff(y(t),t)-9*y(t)), Eq(diff(y(t),t,t),7*t*diff(x(t),t)-7*x(t))) + I1 = sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erfi(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(3*sqrt(7)*t**2/2)/t + I2 = -sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erf(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(-3*sqrt(7)*t**2/2)/t + sol = [Eq(x(t), C3*t + t*(9*C1*I1 + 9*C2*I2)), Eq(y(t), C4*t + t*(3*sqrt(7)*C1*I1 - 3*sqrt(7)*C2*I2))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 21*x(t)), Eq(diff(y(t),t), 17*x(t)+3*y(t)), Eq(diff(z(t),t), 5*x(t)+7*y(t)+9*z(t))) + sol = [Eq(x(t), C1*exp(21*t)), Eq(y(t), 17*C1*exp(21*t)/18 + C2*exp(3*t)), \ + Eq(z(t), 209*C1*exp(21*t)/216 - 7*C2*exp(3*t)/6 + C3*exp(9*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),3*y(t)-11*z(t)),Eq(diff(y(t),t),7*z(t)-3*x(t)),Eq(diff(z(t),t),11*x(t)-7*y(t))) + sol = [Eq(x(t), 7*C0 + sqrt(179)*C1*cos(sqrt(179)*t) + (77*C1/3 + 130*C2/3)*sin(sqrt(179)*t)), \ + Eq(y(t), 11*C0 + sqrt(179)*C2*cos(sqrt(179)*t) + (-58*C1/3 - 77*C2/3)*sin(sqrt(179)*t)), \ + Eq(z(t), 3*C0 + sqrt(179)*(-7*C1/3 - 11*C2/3)*cos(sqrt(179)*t) + (11*C1 - 7*C2)*sin(sqrt(179)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(3*diff(x(t),t),4*5*(y(t)-z(t))),Eq(4*diff(y(t),t),3*5*(z(t)-x(t))),Eq(5*diff(z(t),t),3*4*(x(t)-y(t)))) + sol = [Eq(x(t), C0 + 5*sqrt(2)*C1*cos(5*sqrt(2)*t) + (12*C1/5 + 164*C2/15)*sin(5*sqrt(2)*t)), \ + Eq(y(t), C0 + 5*sqrt(2)*C2*cos(5*sqrt(2)*t) + (-51*C1/10 - 12*C2/5)*sin(5*sqrt(2)*t)), \ + Eq(z(t), C0 + 5*sqrt(2)*(-9*C1/25 - 16*C2/25)*cos(5*sqrt(2)*t) + (12*C1/5 - 12*C2/5)*sin(5*sqrt(2)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),4*x(t) - z(t)),Eq(diff(y(t),t),2*x(t)+2*y(t)-z(t)),Eq(diff(z(t),t),3*x(t)+y(t))) + sol = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t) + C3*exp(2*t)), \ + Eq(y(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t)), \ + Eq(z(t), 2*C1*exp(2*t) + 2*C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t) + C3*t*exp(2*t) + C3*exp(2*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),4*x(t) - y(t) - 2*z(t)),Eq(diff(y(t),t),2*x(t) + y(t)- 2*z(t)),Eq(diff(z(t),t),5*x(t)-3*z(t))) + sol = [Eq(x(t), C1*exp(2*t) + C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), \ + Eq(y(t), C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), Eq(z(t), C1*exp(2*t) + 5*C2*cos(t) + 5*C3*sin(t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol = [Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5)) + sol = [Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2)) + sol = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)} + assert checksysodesol(eq, sol) == (True, [0, 0]) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_systems.py b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_systems.py new file mode 100644 index 0000000000000000000000000000000000000000..9d206129dfcf38c7b8c2e0ab42bd875003253f35 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/ode/tests/test_systems.py @@ -0,0 +1,2544 @@ +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.hyperbolic import sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.core.containers import Tuple +from sympy.functions import exp, cos, sin, log, Ci, Si, erf, erfi +from sympy.matrices import dotprodsimp, NonSquareMatrixError +from sympy.solvers.ode import dsolve +from sympy.solvers.ode.ode import constant_renumber +from sympy.solvers.ode.subscheck import checksysodesol +from sympy.solvers.ode.systems import (_classify_linear_system, linear_ode_to_matrix, + ODEOrderError, ODENonlinearError, _simpsol, + _is_commutative_anti_derivative, linodesolve, + canonical_odes, dsolve_system, _component_division, + _eqs2dict, _dict2graph) +from sympy.functions import airyai, airybi +from sympy.integrals.integrals import Integral +from sympy.simplify.ratsimp import ratsimp +from sympy.testing.pytest import raises, slow, tooslow, XFAIL + + +C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C0:11') +x = symbols('x') +f = Function('f') +g = Function('g') +h = Function('h') + + +def test_linear_ode_to_matrix(): + f, g, h = symbols("f, g, h", cls=Function) + t = Symbol("t") + funcs = [f(t), g(t), h(t)] + f1 = f(t).diff(t) + g1 = g(t).diff(t) + h1 = h(t).diff(t) + f2 = f(t).diff(t, 2) + g2 = g(t).diff(t, 2) + h2 = h(t).diff(t, 2) + + eqs_1 = [Eq(f1, g(t)), Eq(g1, f(t))] + sol_1 = ([Matrix([[1, 0], [0, 1]]), Matrix([[ 0, 1], [1, 0]])], Matrix([[0],[0]])) + assert linear_ode_to_matrix(eqs_1, funcs[:-1], t, 1) == sol_1 + + eqs_2 = [Eq(f1, f(t) + 2*g(t)), Eq(g1, h(t)), Eq(h1, g(t) + h(t) + f(t))] + sol_2 = ([Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), Matrix([[1, 2, 0], [ 0, 0, 1], [1, 1, 1]])], + Matrix([[0], [0], [0]])) + assert linear_ode_to_matrix(eqs_2, funcs, t, 1) == sol_2 + + eqs_3 = [Eq(2*f1 + 3*h1, f(t) + g(t)), Eq(4*h1 + 5*g1, f(t) + h(t)), Eq(5*f1 + 4*g1, g(t) + h(t))] + sol_3 = ([Matrix([[2, 0, 3], [0, 5, 4], [5, 4, 0]]), Matrix([[1, 1, 0], [1, 0, 1], [0, 1, 1]])], + Matrix([[0], [0], [0]])) + assert linear_ode_to_matrix(eqs_3, funcs, t, 1) == sol_3 + + eqs_4 = [Eq(f2 + h(t), f1 + g(t)), Eq(2*h2 + g2 + g1 + g(t), 0), Eq(3*h1, 4)] + sol_4 = ([Matrix([[1, 0, 0], [0, 1, 2], [0, 0, 0]]), Matrix([[1, 0, 0], [0, -1, 0], [0, 0, -3]]), + Matrix([[0, 1, -1], [0, -1, 0], [0, 0, 0]])], Matrix([[0], [0], [4]])) + assert linear_ode_to_matrix(eqs_4, funcs, t, 2) == sol_4 + + eqs_5 = [Eq(f2, g(t)), Eq(f1 + g1, f(t))] + raises(ODEOrderError, lambda: linear_ode_to_matrix(eqs_5, funcs[:-1], t, 1)) + + eqs_6 = [Eq(f1, f(t)**2), Eq(g1, f(t) + g(t))] + raises(ODENonlinearError, lambda: linear_ode_to_matrix(eqs_6, funcs[:-1], t, 1)) + + +def test__classify_linear_system(): + x, y, z, w = symbols('x, y, z, w', cls=Function) + t, k, l = symbols('t k l') + x1 = diff(x(t), t) + y1 = diff(y(t), t) + z1 = diff(z(t), t) + w1 = diff(w(t), t) + x2 = diff(x(t), t, t) + y2 = diff(y(t), t, t) + funcs = [x(t), y(t)] + funcs_2 = funcs + [z(t), w(t)] + + eqs_1 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * t * x(t) + 3 * y(t) + t)) + assert _classify_linear_system(eqs_1, funcs, t) is None + + eqs_2 = (5 * (x1**2) + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * t * x(t) + 3 * y(t) + t)) + sol2 = {'is_implicit': True, + 'canon_eqs': [[Eq(Derivative(x(t), t), -sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)], + [Eq(Derivative(x(t), t), sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)]]} + assert _classify_linear_system(eqs_2, funcs, t) == sol2 + + eqs_2_1 = [Eq(Derivative(x(t), t), -sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)] + assert _classify_linear_system(eqs_2_1, funcs, t) is None + + eqs_2_2 = [Eq(Derivative(x(t), t), sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)] + assert _classify_linear_system(eqs_2_2, funcs, t) is None + + eqs_3 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * x(t) + 3 * y(t)), (5 * w1 + z(t)), (z1 + w(t))) + answer_3 = {'no_of_equation': 4, + 'eq': (12*x(t) - 6*y(t) + 5*Derivative(x(t), t), + -11*x(t) + 3*y(t) + 2*Derivative(y(t), t), + z(t) + 5*Derivative(w(t), t), + w(t) + Derivative(z(t), t)), + 'func': [x(t), y(t), z(t), w(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, + 'is_linear': True, + 'is_constant': True, + 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [Rational(12, 5), Rational(-6, 5), 0, 0], + [Rational(-11, 2), Rational(3, 2), 0, 0], + [0, 0, 0, 1], + [0, 0, Rational(1, 5), 0]]), + 'type_of_equation': 'type1', + 'is_general': True} + assert _classify_linear_system(eqs_3, funcs_2, t) == answer_3 + + eqs_4 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * x(t) + 3 * y(t)), (z1 - w(t)), (w1 - z(t))) + answer_4 = {'no_of_equation': 4, + 'eq': (12 * x(t) - 6 * y(t) + 5 * Derivative(x(t), t), + -11 * x(t) + 3 * y(t) + 2 * Derivative(y(t), t), + -w(t) + Derivative(z(t), t), + -z(t) + Derivative(w(t), t)), + 'func': [x(t), y(t), z(t), w(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, + 'is_linear': True, + 'is_constant': True, + 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [Rational(12, 5), Rational(-6, 5), 0, 0], + [Rational(-11, 2), Rational(3, 2), 0, 0], + [0, 0, 0, -1], + [0, 0, -1, 0]]), + 'type_of_equation': 'type1', + 'is_general': True} + assert _classify_linear_system(eqs_4, funcs_2, t) == answer_4 + + eqs_5 = (5*x1 + 12*x(t) - 6*(y(t)) + x2, (2*y1 - 11*x(t) + 3*y(t)), (z1 - w(t)), (w1 - z(t))) + answer_5 = {'no_of_equation': 4, 'eq': (12*x(t) - 6*y(t) + 5*Derivative(x(t), t) + Derivative(x(t), (t, 2)), + -11*x(t) + 3*y(t) + 2*Derivative(y(t), t), -w(t) + Derivative(z(t), t), -z(t) + Derivative(w(t), + t)), 'func': [x(t), y(t), z(t), w(t)], 'order': {x(t): 2, y(t): 1, z(t): 1, w(t): 1}, 'is_linear': + True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': 'type0', 'is_higher_order': True} + assert _classify_linear_system(eqs_5, funcs_2, t) == answer_5 + + eqs_6 = (Eq(x1, 3*y(t) - 11*z(t)), Eq(y1, 7*z(t) - 3*x(t)), Eq(z1, 11*x(t) - 7*y(t))) + answer_6 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ 0, -3, 11], + [ 3, 0, -7], + [-11, 7, 0]]), + 'type_of_equation': 'type1', 'is_general': True} + + assert _classify_linear_system(eqs_6, funcs_2[:-1], t) == answer_6 + + eqs_7 = (Eq(x1, y(t)), Eq(y1, x(t))) + answer_7 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, + 'is_homogeneous': True, 'func_coeff': -Matrix([ + [ 0, -1], + [-1, 0]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_7, funcs, t) == answer_7 + + eqs_8 = (Eq(x1, 21*x(t)), Eq(y1, 17*x(t) + 3*y(t)), Eq(z1, 5*x(t) + 7*y(t) + 9*z(t))) + answer_8 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 21*x(t)), Eq(Derivative(y(t), t), 17*x(t) + 3*y(t)), + Eq(Derivative(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-21, 0, 0], + [-17, -3, 0], + [ -5, -7, -9]]), + 'type_of_equation': 'type1', 'is_general': True} + + assert _classify_linear_system(eqs_8, funcs_2[:-1], t) == answer_8 + + eqs_9 = (Eq(x1, 4*x(t) + 5*y(t) + 2*z(t)), Eq(y1, x(t) + 13*y(t) + 9*z(t)), Eq(z1, 32*x(t) + 41*y(t) + 11*z(t))) + answer_9 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 4*x(t) + 5*y(t) + 2*z(t)), + Eq(Derivative(y(t), t), x(t) + 13*y(t) + 9*z(t)), Eq(Derivative(z(t), t), 32*x(t) + 41*y(t) + 11*z(t))), + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, + 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ -4, -5, -2], + [ -1, -13, -9], + [-32, -41, -11]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_9, funcs_2[:-1], t) == answer_9 + + eqs_10 = (Eq(3*x1, 4*5*(y(t) - z(t))), Eq(4*y1, 3*5*(z(t) - x(t))), Eq(5*z1, 3*4*(x(t) - y(t)))) + answer_10 = {'no_of_equation': 3, 'eq': (Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), + Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))), + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, + 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ 0, Rational(-20, 3), Rational(20, 3)], + [Rational(15, 4), 0, Rational(-15, 4)], + [Rational(-12, 5), Rational(12, 5), 0]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_10, funcs_2[:-1], t) == answer_10 + + eq11 = (Eq(x1, 3*y(t) - 11*z(t)), Eq(y1, 7*z(t) - 3*x(t)), Eq(z1, 11*x(t) - 7*y(t))) + sol11 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, 'func_coeff': -Matrix([ + [ 0, -3, 11], [ 3, 0, -7], [-11, 7, 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq11, funcs_2[:-1], t) == sol11 + + eq12 = (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))) + sol12 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, + 'is_homogeneous': True, 'func_coeff': -Matrix([ + [0, -1], + [-1, 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq12, [x(t), y(t)], t) == sol12 + + eq13 = (Eq(Derivative(x(t), t), 21*x(t)), Eq(Derivative(y(t), t), 17*x(t) + 3*y(t)), + Eq(Derivative(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))) + sol13 = {'no_of_equation': 3, 'eq': ( + Eq(Derivative(x(t), t), 21 * x(t)), Eq(Derivative(y(t), t), 17 * x(t) + 3 * y(t)), + Eq(Derivative(z(t), t), 5 * x(t) + 7 * y(t) + 9 * z(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-21, 0, 0], + [-17, -3, 0], + [-5, -7, -9]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq13, [x(t), y(t), z(t)], t) == sol13 + + eq14 = ( + Eq(Derivative(x(t), t), 4*x(t) + 5*y(t) + 2*z(t)), Eq(Derivative(y(t), t), x(t) + 13*y(t) + 9*z(t)), + Eq(Derivative(z(t), t), 32*x(t) + 41*y(t) + 11*z(t))) + sol14 = {'no_of_equation': 3, 'eq': ( + Eq(Derivative(x(t), t), 4 * x(t) + 5 * y(t) + 2 * z(t)), Eq(Derivative(y(t), t), x(t) + 13 * y(t) + 9 * z(t)), + Eq(Derivative(z(t), t), 32 * x(t) + 41 * y(t) + 11 * z(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-4, -5, -2], + [-1, -13, -9], + [-32, -41, -11]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq14, [x(t), y(t), z(t)], t) == sol14 + + eq15 = (Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), + Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))) + sol15 = {'no_of_equation': 3, 'eq': ( + Eq(3 * Derivative(x(t), t), 20 * y(t) - 20 * z(t)), Eq(4 * Derivative(y(t), t), -15 * x(t) + 15 * z(t)), + Eq(5 * Derivative(z(t), t), 12 * x(t) - 12 * y(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [0, Rational(-20, 3), Rational(20, 3)], + [Rational(15, 4), 0, Rational(-15, 4)], + [Rational(-12, 5), Rational(12, 5), 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq15, [x(t), y(t), z(t)], t) == sol15 + + # Constant coefficient homogeneous ODEs + eq1 = (Eq(diff(x(t), t), x(t) + y(t) + 9), Eq(diff(y(t), t), 2*x(t) + 5*y(t) + 23)) + sol1 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), x(t) + y(t) + 9), + Eq(Derivative(y(t), t), 2*x(t) + 5*y(t) + 23)), 'func': [x(t), y(t)], + 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': False, 'is_general': True, + 'func_coeff': -Matrix([[-1, -1], [-2, -5]]), 'rhs': Matrix([[ 9], [23]]), 'type_of_equation': 'type2'} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + # Non constant coefficient homogeneous ODEs + eq1 = (Eq(diff(x(t), t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t), t), 2*x(t) + 5*t*y(t))) + sol1 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), 5*t*x(t) + 2*y(t)), Eq(Derivative(y(t), t), 5*t*y(t) + 2*x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': True, 'func_coeff': -Matrix([ [-5*t, -2], [ -2, -5*t]]), 'commutative_antiderivative': Matrix([ + [5*t**2/2, 2*t], [ 2*t, 5*t**2/2]]), 'type_of_equation': 'type3', 'is_general': True} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + # Non constant coefficient non-homogeneous ODEs + eq1 = [Eq(x1, x(t) + t*y(t) + t), Eq(y1, t*x(t) + y(t))] + sol1 = {'no_of_equation': 2, 'eq': [Eq(Derivative(x(t), t), t*y(t) + t + x(t)), Eq(Derivative(y(t), t), + t*x(t) + y(t))], 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, + 'is_constant': False, 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-1, -t], + [-t, -1]]), 'commutative_antiderivative': Matrix([ [ t, t**2/2], [t**2/2, t]]), 'rhs': + Matrix([ [t], [0]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + eq2 = [Eq(x1, t*x(t) + t*y(t) + t), Eq(y1, t*x(t) + t*y(t) + cos(t))] + sol2 = {'no_of_equation': 2, 'eq': [Eq(Derivative(x(t), t), t*x(t) + t*y(t) + t), Eq(Derivative(y(t), t), + t*x(t) + t*y(t) + cos(t))], 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, + 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ [ t], [cos(t)]]), 'func_coeff': + Matrix([ [t, t], [t, t]]), 'is_constant': False, 'type_of_equation': 'type4', + 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2], [t**2/2, t**2/2]])} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = [Eq(x1, t*(x(t) + y(t) + z(t) + 1)), Eq(y1, t*(x(t) + y(t) + z(t))), Eq(z1, t*(x(t) + y(t) + z(t)))] + sol3 = {'no_of_equation': 3, 'eq': [Eq(Derivative(x(t), t), t*(x(t) + y(t) + z(t) + 1)), + Eq(Derivative(y(t), t), t*(x(t) + y(t) + z(t))), Eq(Derivative(z(t), t), t*(x(t) + y(t) + z(t)))], + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': + False, 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-t, -t, -t], [-t, -t, + -t], [-t, -t, -t]]), 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2, t**2/2], [t**2/2, + t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2]]), 'rhs': Matrix([ [t], [0], [0]]), 'type_of_equation': + 'type4'} + assert _classify_linear_system(eq3, funcs_2[:-1], t) == sol3 + + eq4 = [Eq(x1, x(t) + y(t) + t*z(t) + 1), Eq(y1, x(t) + t*y(t) + z(t) + 10), Eq(z1, t*x(t) + y(t) + z(t) + t)] + sol4 = {'no_of_equation': 3, 'eq': [Eq(Derivative(x(t), t), t*z(t) + x(t) + y(t) + 1), Eq(Derivative(y(t), + t), t*y(t) + x(t) + z(t) + 10), Eq(Derivative(z(t), t), t*x(t) + t + y(t) + z(t))], 'func': [x(t), + y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-1, -1, -t], [-1, -t, -1], [-t, + -1, -1]]), 'commutative_antiderivative': Matrix([ [ t, t, t**2/2], [ t, t**2/2, + t], [t**2/2, t, t]]), 'rhs': Matrix([ [ 1], [10], [ t]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq4, funcs_2[:-1], t) == sol4 + + sum_terms = t*(x(t) + y(t) + z(t) + w(t)) + eq5 = [Eq(x1, sum_terms), Eq(y1, sum_terms), Eq(z1, sum_terms + 1), Eq(w1, sum_terms)] + sol5 = {'no_of_equation': 4, 'eq': [Eq(Derivative(x(t), t), t*(w(t) + x(t) + y(t) + z(t))), + Eq(Derivative(y(t), t), t*(w(t) + x(t) + y(t) + z(t))), Eq(Derivative(z(t), t), t*(w(t) + x(t) + + y(t) + z(t)) + 1), Eq(Derivative(w(t), t), t*(w(t) + x(t) + y(t) + z(t)))], 'func': [x(t), y(t), + z(t), w(t)], 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-t, -t, -t, -t], [-t, -t, -t, + -t], [-t, -t, -t, -t], [-t, -t, -t, -t]]), 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2, + t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2, t**2/2], [t**2/2, + t**2/2, t**2/2, t**2/2]]), 'rhs': Matrix([ [0], [0], [1], [0]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq5, funcs_2, t) == sol5 + + # Second Order + t_ = symbols("t_") + + eq1 = (Eq(9*x(t) + 7*y(t) + 4*Derivative(x(t), t) + Derivative(x(t), (t, 2)) + 3*Derivative(y(t), t), 11*exp(I*t)), + Eq(3*x(t) + 12*y(t) + 5*Derivative(x(t), t) + 8*Derivative(y(t), t) + Derivative(y(t), (t, 2)), 2*exp(I*t))) + sol1 = {'no_of_equation': 2, 'eq': (Eq(9*x(t) + 7*y(t) + 4*Derivative(x(t), t) + Derivative(x(t), (t, 2)) + + 3*Derivative(y(t), t), 11*exp(I*t)), Eq(3*x(t) + 12*y(t) + 5*Derivative(x(t), t) + + 8*Derivative(y(t), t) + Derivative(y(t), (t, 2)), 2*exp(I*t))), 'func': [x(t), y(t)], 'order': + {x(t): 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ + [11*exp(I*t)], [ 2*exp(I*t)]]), 'type_of_equation': 'type0', 'is_second_order': True, + 'is_higher_order': True} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + eq2 = (Eq((4*t**2 + 7*t + 1)**2*Derivative(x(t), (t, 2)), 5*x(t) + 35*y(t)), + Eq((4*t**2 + 7*t + 1)**2*Derivative(y(t), (t, 2)), x(t) + 9*y(t))) + sol2 = {'no_of_equation': 2, 'eq': (Eq((4*t**2 + 7*t + 1)**2*Derivative(x(t), (t, 2)), 5*x(t) + 35*y(t)), + Eq((4*t**2 + 7*t + 1)**2*Derivative(y(t), (t, 2)), x(t) + 9*y(t))), 'func': [x(t), y(t)], 'order': + {x(t): 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': True, 'is_general': True, + 'type_of_equation': 'type2', 'A0': Matrix([ [Rational(53, 4), 35], [ 1, Rational(69, 4)]]), 'g(t)': sqrt(4*t**2 + 7*t + + 1), 'tau': sqrt(33)*log(t - sqrt(33)/8 + Rational(7, 8))/33 - sqrt(33)*log(t + sqrt(33)/8 + Rational(7, 8))/33, + 'is_transformed': True, 't_': t_, 'is_second_order': True, 'is_higher_order': True} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = ((t*Derivative(x(t), t) - x(t))*log(t) + (t*Derivative(y(t), t) - y(t))*exp(t) + Derivative(x(t), (t, 2)), + t**2*(t*Derivative(x(t), t) - x(t)) + t*(t*Derivative(y(t), t) - y(t)) + Derivative(y(t), (t, 2))) + sol3 = {'no_of_equation': 2, 'eq': ((t*Derivative(x(t), t) - x(t))*log(t) + (t*Derivative(y(t), t) - + y(t))*exp(t) + Derivative(x(t), (t, 2)), t**2*(t*Derivative(x(t), t) - x(t)) + t*(t*Derivative(y(t), + t) - y(t)) + Derivative(y(t), (t, 2))), 'func': [x(t), y(t)], 'order': {x(t): 2, y(t): 2}, + 'is_linear': True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': 'type1', 'A1': + Matrix([ [-t*log(t), -t*exp(t)], [ -t**3, -t**2]]), 'is_second_order': True, + 'is_higher_order': True} + assert _classify_linear_system(eq3, funcs, t) == sol3 + + eq4 = (Eq(x2, k*x(t) - l*y1), Eq(y2, l*x1 + k*y(t))) + sol4 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), (t, 2)), k*x(t) - l*Derivative(y(t), t)), + Eq(Derivative(y(t), (t, 2)), k*y(t) + l*Derivative(x(t), t))), 'func': [x(t), y(t)], 'order': {x(t): + 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': + 'type0', 'is_second_order': True, 'is_higher_order': True} + assert _classify_linear_system(eq4, funcs, t) == sol4 + + + # Multiple matches + + f, g = symbols("f g", cls=Function) + y, t_ = symbols("y t_") + funcs = [f(t), g(t)] + + eq1 = [Eq(Derivative(f(t), t)**2 - 2*Derivative(f(t), t) + 1, 4), + Eq(-y*f(t) + Derivative(g(t), t), 0)] + sol1 = {'is_implicit': True, + 'canon_eqs': [[Eq(Derivative(f(t), t), -1), Eq(Derivative(g(t), t), y*f(t))], + [Eq(Derivative(f(t), t), 3), Eq(Derivative(g(t), t), y*f(t))]]} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + raises(ValueError, lambda: _classify_linear_system(eq1, funcs[:1], t)) + + eq2 = [Eq(Derivative(f(t), t), (2*f(t) + g(t) + 1)/t), Eq(Derivative(g(t), t), (f(t) + 2*g(t))/t)] + sol2 = {'no_of_equation': 2, 'eq': [Eq(Derivative(f(t), t), (2*f(t) + g(t) + 1)/t), Eq(Derivative(g(t), t), + (f(t) + 2*g(t))/t)], 'func': [f(t), g(t)], 'order': {f(t): 1, g(t): 1}, 'is_linear': True, + 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ [1], [0]]), 'func_coeff': Matrix([ [2, + 1], [1, 2]]), 'is_constant': False, 'type_of_equation': 'type6', 't_': t_, 'tau': log(t), + 'commutative_antiderivative': Matrix([ [2*log(t), log(t)], [ log(t), 2*log(t)]])} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = [Eq(Derivative(f(t), t), (2*f(t) + g(t))/t), Eq(Derivative(g(t), t), (f(t) + 2*g(t))/t)] + sol3 = {'no_of_equation': 2, 'eq': [Eq(Derivative(f(t), t), (2*f(t) + g(t))/t), Eq(Derivative(g(t), t), + (f(t) + 2*g(t))/t)], 'func': [f(t), g(t)], 'order': {f(t): 1, g(t): 1}, 'is_linear': True, + 'is_homogeneous': True, 'is_general': True, 'func_coeff': Matrix([ [2, 1], [1, 2]]), 'is_constant': + False, 'type_of_equation': 'type5', 't_': t_, 'rhs': Matrix([ [0], [0]]), 'tau': log(t), + 'commutative_antiderivative': Matrix([ [2*log(t), log(t)], [ log(t), 2*log(t)]])} + assert _classify_linear_system(eq3, funcs, t) == sol3 + + +def test_matrix_exp(): + from sympy.matrices.dense import Matrix, eye, zeros + from sympy.solvers.ode.systems import matrix_exp + t = Symbol('t') + + for n in range(1, 6+1): + assert matrix_exp(zeros(n), t) == eye(n) + + for n in range(1, 6+1): + A = eye(n) + expAt = exp(t) * eye(n) + assert matrix_exp(A, t) == expAt + + for n in range(1, 6+1): + A = Matrix(n, n, lambda i,j: i+1 if i==j else 0) + expAt = Matrix(n, n, lambda i,j: exp((i+1)*t) if i==j else 0) + assert matrix_exp(A, t) == expAt + + A = Matrix([[0, 1], [-1, 0]]) + expAt = Matrix([[cos(t), sin(t)], [-sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[2, -5], [2, -4]]) + expAt = Matrix([ + [3*exp(-t)*sin(t) + exp(-t)*cos(t), -5*exp(-t)*sin(t)], + [2*exp(-t)*sin(t), -3*exp(-t)*sin(t) + exp(-t)*cos(t)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[21, 17, 6], [-5, -1, -6], [4, 4, 16]]) + # TO update this. + # expAt = Matrix([ + # [(8*t*exp(12*t) + 5*exp(12*t) - 1)*exp(4*t)/4, + # (8*t*exp(12*t) + 5*exp(12*t) - 5)*exp(4*t)/4, + # (exp(12*t) - 1)*exp(4*t)/2], + # [(-8*t*exp(12*t) - exp(12*t) + 1)*exp(4*t)/4, + # (-8*t*exp(12*t) - exp(12*t) + 5)*exp(4*t)/4, + # (-exp(12*t) + 1)*exp(4*t)/2], + # [4*t*exp(16*t), 4*t*exp(16*t), exp(16*t)]]) + expAt = Matrix([ + [2*t*exp(16*t) + 5*exp(16*t)/4 - exp(4*t)/4, 2*t*exp(16*t) + 5*exp(16*t)/4 - 5*exp(4*t)/4, exp(16*t)/2 - exp(4*t)/2], + [ -2*t*exp(16*t) - exp(16*t)/4 + exp(4*t)/4, -2*t*exp(16*t) - exp(16*t)/4 + 5*exp(4*t)/4, -exp(16*t)/2 + exp(4*t)/2], + [ 4*t*exp(16*t), 4*t*exp(16*t), exp(16*t)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, -S(1)/8], + [0, 0, S(1)/2, S(1)/2]]) + expAt = Matrix([ + [exp(t), t*exp(t), 4*t*exp(3*t/4) + 8*t*exp(t) + 48*exp(3*t/4) - 48*exp(t), + -2*t*exp(3*t/4) - 2*t*exp(t) - 16*exp(3*t/4) + 16*exp(t)], + [0, exp(t), -t*exp(3*t/4) - 8*exp(3*t/4) + 8*exp(t), t*exp(3*t/4)/2 + 2*exp(3*t/4) - 2*exp(t)], + [0, 0, t*exp(3*t/4)/4 + exp(3*t/4), -t*exp(3*t/4)/8], + [0, 0, t*exp(3*t/4)/2, -t*exp(3*t/4)/4 + exp(3*t/4)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([ + [ 0, 1, 0, 0], + [-1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 0, -1, 0]]) + + expAt = Matrix([ + [ cos(t), sin(t), 0, 0], + [-sin(t), cos(t), 0, 0], + [ 0, 0, cos(t), sin(t)], + [ 0, 0, -sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + A = Matrix([ + [ 0, 1, 1, 0], + [-1, 0, 0, 1], + [ 0, 0, 0, 1], + [ 0, 0, -1, 0]]) + + expAt = Matrix([ + [ cos(t), sin(t), t*cos(t), t*sin(t)], + [-sin(t), cos(t), -t*sin(t), t*cos(t)], + [ 0, 0, cos(t), sin(t)], + [ 0, 0, -sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + # This case is unacceptably slow right now but should be solvable... + #a, b, c, d, e, f = symbols('a b c d e f') + #A = Matrix([ + #[-a, b, c, d], + #[ a, -b, e, 0], + #[ 0, 0, -c - e - f, 0], + #[ 0, 0, f, -d]]) + + A = Matrix([[0, I], [I, 0]]) + expAt = Matrix([ + [exp(I*t)/2 + exp(-I*t)/2, exp(I*t)/2 - exp(-I*t)/2], + [exp(I*t)/2 - exp(-I*t)/2, exp(I*t)/2 + exp(-I*t)/2]]) + assert matrix_exp(A, t) == expAt + + # Testing Errors + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 7, 7]]) + M1 = Matrix([[t, 1], [1, 1]]) + + raises(ValueError, lambda: matrix_exp(M[:, :2], t)) + raises(ValueError, lambda: matrix_exp(M[:2, :], t)) + raises(ValueError, lambda: matrix_exp(M1, t)) + raises(ValueError, lambda: matrix_exp(M1[:1, :1], t)) + + +def test_canonical_odes(): + f, g, h = symbols('f g h', cls=Function) + x = symbols('x') + funcs = [f(x), g(x), h(x)] + + eqs1 = [Eq(f(x).diff(x, x), f(x) + 2*g(x)), Eq(g(x) + 1, g(x).diff(x) + f(x))] + sol1 = [[Eq(Derivative(f(x), (x, 2)), f(x) + 2*g(x)), Eq(Derivative(g(x), x), -f(x) + g(x) + 1)]] + assert canonical_odes(eqs1, funcs[:2], x) == sol1 + + eqs2 = [Eq(f(x).diff(x), h(x).diff(x) + f(x)), Eq(g(x).diff(x)**2, f(x) + h(x)), Eq(h(x).diff(x), f(x))] + sol2 = [[Eq(Derivative(f(x), x), 2*f(x)), Eq(Derivative(g(x), x), -sqrt(f(x) + h(x))), Eq(Derivative(h(x), x), f(x))], + [Eq(Derivative(f(x), x), 2*f(x)), Eq(Derivative(g(x), x), sqrt(f(x) + h(x))), Eq(Derivative(h(x), x), f(x))]] + assert canonical_odes(eqs2, funcs, x) == sol2 + + +def test_sysode_linear_neq_order1_type1(): + + f, g, x, y, h = symbols('f g x y h', cls=Function) + a, b, c, t = symbols('a b c t') + + eqs1 = [Eq(Derivative(x(t), t), x(t)), + Eq(Derivative(y(t), t), y(t))] + sol1 = [Eq(x(t), C1*exp(t)), + Eq(y(t), C2*exp(t))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(x(t), t), 2*x(t)), + Eq(Derivative(y(t), t), 3*y(t))] + sol2 = [Eq(x(t), C1*exp(2*t)), + Eq(y(t), C2*exp(3*t))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(x(t), t), a*x(t)), + Eq(Derivative(y(t), t), a*y(t))] + sol3 = [Eq(x(t), C1*exp(a*t)), + Eq(y(t), C2*exp(a*t))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + eqs4 = [Eq(Derivative(x(t), t), a*x(t)), + Eq(Derivative(y(t), t), b*y(t))] + sol4 = [Eq(x(t), C1*exp(a*t)), + Eq(y(t), C2*exp(b*t))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(Derivative(x(t), t), -y(t)), + Eq(Derivative(y(t), t), x(t))] + sol5 = [Eq(x(t), -C1*sin(t) - C2*cos(t)), + Eq(y(t), C1*cos(t) - C2*sin(t))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(x(t), t), -2*y(t)), + Eq(Derivative(y(t), t), 2*x(t))] + sol6 = [Eq(x(t), -C1*sin(2*t) - C2*cos(2*t)), + Eq(y(t), C1*cos(2*t) - C2*sin(2*t))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + eqs7 = [Eq(Derivative(x(t), t), I*y(t)), + Eq(Derivative(y(t), t), I*x(t))] + sol7 = [Eq(x(t), -C1*exp(-I*t) + C2*exp(I*t)), + Eq(y(t), C1*exp(-I*t) + C2*exp(I*t))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + eqs8 = [Eq(Derivative(x(t), t), -a*y(t)), + Eq(Derivative(y(t), t), a*x(t))] + sol8 = [Eq(x(t), -I*C1*exp(-I*a*t) + I*C2*exp(I*a*t)), + Eq(y(t), C1*exp(-I*a*t) + C2*exp(I*a*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + eqs9 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), x(t) - y(t))] + sol9 = [Eq(x(t), C1*(1 - sqrt(2))*exp(-sqrt(2)*t) + C2*(1 + sqrt(2))*exp(sqrt(2)*t)), + Eq(y(t), C1*exp(-sqrt(2)*t) + C2*exp(sqrt(2)*t))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0]) + + eqs10 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol10 = [Eq(x(t), -C1 + C2*exp(2*t)), + Eq(y(t), C1 + C2*exp(2*t))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + eqs11 = [Eq(Derivative(x(t), t), 2*x(t) + y(t)), + Eq(Derivative(y(t), t), -x(t) + 2*y(t))] + sol11 = [Eq(x(t), C1*exp(2*t)*sin(t) + C2*exp(2*t)*cos(t)), + Eq(y(t), C1*exp(2*t)*cos(t) - C2*exp(2*t)*sin(t))] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0]) + + eqs12 = [Eq(Derivative(x(t), t), x(t) + 2*y(t)), + Eq(Derivative(y(t), t), 2*x(t) + y(t))] + sol12 = [Eq(x(t), -C1*exp(-t) + C2*exp(3*t)), + Eq(y(t), C1*exp(-t) + C2*exp(3*t))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + eqs13 = [Eq(Derivative(x(t), t), 4*x(t) + y(t)), + Eq(Derivative(y(t), t), -x(t) + 2*y(t))] + sol13 = [Eq(x(t), C2*t*exp(3*t) + (C1 + C2)*exp(3*t)), + Eq(y(t), -C1*exp(3*t) - C2*t*exp(3*t))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + eqs14 = [Eq(Derivative(x(t), t), a*y(t)), + Eq(Derivative(y(t), t), a*x(t))] + sol14 = [Eq(x(t), -C1*exp(-a*t) + C2*exp(a*t)), + Eq(y(t), C1*exp(-a*t) + C2*exp(a*t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0]) + + eqs15 = [Eq(Derivative(x(t), t), a*y(t)), + Eq(Derivative(y(t), t), b*x(t))] + sol15 = [Eq(x(t), -C1*a*exp(-t*sqrt(a*b))/sqrt(a*b) + C2*a*exp(t*sqrt(a*b))/sqrt(a*b)), + Eq(y(t), C1*exp(-t*sqrt(a*b)) + C2*exp(t*sqrt(a*b)))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0]) + + eqs16 = [Eq(Derivative(x(t), t), a*x(t) + b*y(t)), + Eq(Derivative(y(t), t), c*x(t))] + sol16 = [Eq(x(t), -2*C1*b*exp(t*(a + sqrt(a**2 + 4*b*c))/2)/(a - sqrt(a**2 + 4*b*c)) - 2*C2*b*exp(t*(a - + sqrt(a**2 + 4*b*c))/2)/(a + sqrt(a**2 + 4*b*c))), + Eq(y(t), C1*exp(t*(a + sqrt(a**2 + 4*b*c))/2) + C2*exp(t*(a - sqrt(a**2 + 4*b*c))/2))] + assert dsolve(eqs16) == sol16 + assert checksysodesol(eqs16, sol16) == (True, [0, 0]) + + # Regression test case for issue #18562 + # https://github.com/sympy/sympy/issues/18562 + eqs17 = [Eq(Derivative(x(t), t), a*y(t) + x(t)), + Eq(Derivative(y(t), t), a*x(t) - y(t))] + sol17 = [Eq(x(t), C1*a*exp(t*sqrt(a**2 + 1))/(sqrt(a**2 + 1) - 1) - C2*a*exp(-t*sqrt(a**2 + 1))/(sqrt(a**2 + + 1) + 1)), + Eq(y(t), C1*exp(t*sqrt(a**2 + 1)) + C2*exp(-t*sqrt(a**2 + 1)))] + assert dsolve(eqs17) == sol17 + assert checksysodesol(eqs17, sol17) == (True, [0, 0]) + + eqs18 = [Eq(Derivative(x(t), t), 0), + Eq(Derivative(y(t), t), 0)] + sol18 = [Eq(x(t), C1), + Eq(y(t), C2)] + assert dsolve(eqs18) == sol18 + assert checksysodesol(eqs18, sol18) == (True, [0, 0]) + + eqs19 = [Eq(Derivative(x(t), t), 2*x(t) - y(t)), + Eq(Derivative(y(t), t), x(t))] + sol19 = [Eq(x(t), C2*t*exp(t) + (C1 + C2)*exp(t)), + Eq(y(t), C1*exp(t) + C2*t*exp(t))] + assert dsolve(eqs19) == sol19 + assert checksysodesol(eqs19, sol19) == (True, [0, 0]) + + eqs20 = [Eq(Derivative(x(t), t), x(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol20 = [Eq(x(t), C1*exp(t)), + Eq(y(t), C1*t*exp(t) + C2*exp(t))] + assert dsolve(eqs20) == sol20 + assert checksysodesol(eqs20, sol20) == (True, [0, 0]) + + eqs21 = [Eq(Derivative(x(t), t), 3*x(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol21 = [Eq(x(t), 2*C1*exp(3*t)), + Eq(y(t), C1*exp(3*t) + C2*exp(t))] + assert dsolve(eqs21) == sol21 + assert checksysodesol(eqs21, sol21) == (True, [0, 0]) + + eqs22 = [Eq(Derivative(x(t), t), 3*x(t)), + Eq(Derivative(y(t), t), y(t))] + sol22 = [Eq(x(t), C1*exp(3*t)), + Eq(y(t), C2*exp(t))] + assert dsolve(eqs22) == sol22 + assert checksysodesol(eqs22, sol22) == (True, [0, 0]) + + +@slow +def test_sysode_linear_neq_order1_type1_slow(): + + t = Symbol('t') + Z0 = Function('Z0') + Z1 = Function('Z1') + Z2 = Function('Z2') + Z3 = Function('Z3') + + k01, k10, k20, k21, k23, k30 = symbols('k01 k10 k20 k21 k23 k30') + + eqs1 = [Eq(Derivative(Z0(t), t), -k01*Z0(t) + k10*Z1(t) + k20*Z2(t) + k30*Z3(t)), + Eq(Derivative(Z1(t), t), k01*Z0(t) - k10*Z1(t) + k21*Z2(t)), + Eq(Derivative(Z2(t), t), (-k20 - k21 - k23)*Z2(t)), + Eq(Derivative(Z3(t), t), k23*Z2(t) - k30*Z3(t))] + sol1 = [Eq(Z0(t), C1*k10/k01 - C2*(k10 - k30)*exp(-k30*t)/(k01 + k10 - k30) - C3*(k10*(k20 + k21 - k30) - + k20**2 - k20*(k21 + k23 - k30) + k23*k30)*exp(-t*(k20 + k21 + k23))/(k23*(-k01 - k10 + k20 + k21 + + k23)) - C4*exp(-t*(k01 + k10))), + Eq(Z1(t), C1 - C2*k01*exp(-k30*t)/(k01 + k10 - k30) + C3*(-k01*(k20 + k21 - k30) + k20*k21 + k21**2 + + k21*(k23 - k30))*exp(-t*(k20 + k21 + k23))/(k23*(-k01 - k10 + k20 + k21 + k23)) + C4*exp(-t*(k01 + + k10))), + Eq(Z2(t), -C3*(k20 + k21 + k23 - k30)*exp(-t*(k20 + k21 + k23))/k23), + Eq(Z3(t), C2*exp(-k30*t) + C3*exp(-t*(k20 + k21 + k23)))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0, 0, 0]) + + x, y, z, u, v, w = symbols('x y z u v w', cls=Function) + k2, k3 = symbols('k2 k3') + a_b, a_c = symbols('a_b a_c', real=True) + + eqs2 = [Eq(Derivative(z(t), t), k2*y(t)), + Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t))] + sol2 = [Eq(z(t), C1 - C2*k2*exp(-t*(k2 + k3))/(k2 + k3)), + Eq(x(t), -C2*k3*exp(-t*(k2 + k3))/(k2 + k3) + C3), + Eq(y(t), C2*exp(-t*(k2 + k3)))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0, 0]) + + eqs3 = [4*u(t) - v(t) - 2*w(t) + Derivative(u(t), t), + 2*u(t) + v(t) - 2*w(t) + Derivative(v(t), t), + 5*u(t) + v(t) - 3*w(t) + Derivative(w(t), t)] + sol3 = [Eq(u(t), C3*exp(-2*t) + (C1/2 + sqrt(3)*C2/6)*cos(sqrt(3)*t) + sin(sqrt(3)*t)*(sqrt(3)*C1/6 + + C2*Rational(-1, 2))), + Eq(v(t), (C1/2 + sqrt(3)*C2/6)*cos(sqrt(3)*t) + sin(sqrt(3)*t)*(sqrt(3)*C1/6 + C2*Rational(-1, 2))), + Eq(w(t), C1*cos(sqrt(3)*t) - C2*sin(sqrt(3)*t) + C3*exp(-2*t))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0]) + + eqs4 = [Eq(Derivative(x(t), t), w(t)*Rational(-2, 9) + 2*x(t) + y(t) + z(t)*Rational(-8, 9)), + Eq(Derivative(y(t), t), w(t)*Rational(4, 9) + 2*y(t) + z(t)*Rational(16, 9)), + Eq(Derivative(z(t), t), w(t)*Rational(-2, 9) + z(t)*Rational(37, 9)), + Eq(Derivative(w(t), t), w(t)*Rational(44, 9) + z(t)*Rational(-4, 9))] + sol4 = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t)), + Eq(y(t), C2*exp(2*t) + 2*C3*exp(4*t)), + Eq(z(t), 2*C3*exp(4*t) + C4*exp(5*t)*Rational(-1, 4)), + Eq(w(t), C3*exp(4*t) + C4*exp(5*t))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq5 = [Eq(x(t).diff(t), x(t)), Eq(y(t).diff(t), y(t)), Eq(z(t).diff(t), z(t)), Eq(w(t).diff(t), w(t))] + sol5 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t)), Eq(w(t), C4*exp(t))] + assert dsolve(eq5) == sol5 + assert checksysodesol(eq5, sol5) == (True, [0, 0, 0, 0]) + + eqs6 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), y(t) + z(t)), + Eq(Derivative(z(t), t), w(t)*Rational(-1, 8) + z(t)), + Eq(Derivative(w(t), t), w(t)/2 + z(t)/2)] + sol6 = [Eq(x(t), C1*exp(t) + C2*t*exp(t) + 4*C4*t*exp(t*Rational(3, 4)) + (4*C3 + 48*C4)*exp(t*Rational(3, + 4))), + Eq(y(t), C2*exp(t) - C4*t*exp(t*Rational(3, 4)) - (C3 + 8*C4)*exp(t*Rational(3, 4))), + Eq(z(t), C4*t*exp(t*Rational(3, 4))/4 + (C3/4 + C4)*exp(t*Rational(3, 4))), + Eq(w(t), C3*exp(t*Rational(3, 4))/2 + C4*t*exp(t*Rational(3, 4))/2)] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq7 = [Eq(Derivative(x(t), t), x(t)), Eq(Derivative(y(t), t), y(t)), Eq(Derivative(z(t), t), z(t)), + Eq(Derivative(w(t), t), w(t)), Eq(Derivative(u(t), t), u(t))] + sol7 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t)), Eq(w(t), C4*exp(t)), + Eq(u(t), C5*exp(t))] + assert dsolve(eq7) == sol7 + assert checksysodesol(eq7, sol7) == (True, [0, 0, 0, 0, 0]) + + eqs8 = [Eq(Derivative(x(t), t), 2*x(t) + y(t)), + Eq(Derivative(y(t), t), 2*y(t)), + Eq(Derivative(z(t), t), 4*z(t)), + Eq(Derivative(w(t), t), u(t) + 5*w(t)), + Eq(Derivative(u(t), t), 5*u(t))] + sol8 = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t)), + Eq(y(t), C2*exp(2*t)), + Eq(z(t), C3*exp(4*t)), + Eq(w(t), C4*exp(5*t) + C5*t*exp(5*t)), + Eq(u(t), C5*exp(5*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq9 = [Eq(Derivative(x(t), t), x(t)), Eq(Derivative(y(t), t), y(t)), Eq(Derivative(z(t), t), z(t))] + sol9 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t))] + assert dsolve(eq9) == sol9 + assert checksysodesol(eq9, sol9) == (True, [0, 0, 0]) + + # Regression test case for issue #15407 + # https://github.com/sympy/sympy/issues/15407 + eqs10 = [Eq(Derivative(x(t), t), (-a_b - a_c)*x(t)), + Eq(Derivative(y(t), t), a_b*y(t)), + Eq(Derivative(z(t), t), a_c*x(t))] + sol10 = [Eq(x(t), -C1*(a_b + a_c)*exp(-t*(a_b + a_c))/a_c), + Eq(y(t), C2*exp(a_b*t)), + Eq(z(t), C1*exp(-t*(a_b + a_c)) + C3)] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0, 0]) + + # Regression test case for issue #14312 + # https://github.com/sympy/sympy/issues/14312 + eqs11 = [Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t)), + Eq(Derivative(z(t), t), k2*y(t))] + sol11 = [Eq(x(t), C1 + C2*k3*exp(-t*(k2 + k3))/k2), + Eq(y(t), -C2*(k2 + k3)*exp(-t*(k2 + k3))/k2), + Eq(z(t), C2*exp(-t*(k2 + k3)) + C3)] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0, 0]) + + # Regression test case for issue #14312 + # https://github.com/sympy/sympy/issues/14312 + eqs12 = [Eq(Derivative(z(t), t), k2*y(t)), + Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t))] + sol12 = [Eq(z(t), C1 - C2*k2*exp(-t*(k2 + k3))/(k2 + k3)), + Eq(x(t), -C2*k3*exp(-t*(k2 + k3))/(k2 + k3) + C3), + Eq(y(t), C2*exp(-t*(k2 + k3)))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0, 0]) + + f, g, h = symbols('f, g, h', cls=Function) + a, b, c = symbols('a, b, c') + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + eqs13 = [Eq(Derivative(f(t), t), 2*f(t) + g(t)), + Eq(Derivative(g(t), t), a*f(t))] + sol13 = [Eq(f(t), C1*exp(t*(sqrt(a + 1) + 1))/(sqrt(a + 1) - 1) - C2*exp(-t*(sqrt(a + 1) - 1))/(sqrt(a + 1) + + 1)), + Eq(g(t), C1*exp(t*(sqrt(a + 1) + 1)) + C2*exp(-t*(sqrt(a + 1) - 1)))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + eqs14 = [Eq(Derivative(f(t), t), 2*g(t) - 3*h(t)), + Eq(Derivative(g(t), t), -2*f(t) + 4*h(t)), + Eq(Derivative(h(t), t), 3*f(t) - 4*g(t))] + sol14 = [Eq(f(t), 2*C1 - sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(3, 25) + C3*Rational(-8, 25)) - + cos(sqrt(29)*t)*(C2*Rational(8, 25) + sqrt(29)*C3*Rational(3, 25))), + Eq(g(t), C1*Rational(3, 2) + sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(4, 25) + C3*Rational(6, 25)) - + cos(sqrt(29)*t)*(C2*Rational(6, 25) + sqrt(29)*C3*Rational(-4, 25))), + Eq(h(t), C1 + C2*cos(sqrt(29)*t) - C3*sin(sqrt(29)*t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0, 0]) + + eqs15 = [Eq(2*Derivative(f(t), t), 12*g(t) - 12*h(t)), + Eq(3*Derivative(g(t), t), -8*f(t) + 8*h(t)), + Eq(4*Derivative(h(t), t), 6*f(t) - 6*g(t))] + sol15 = [Eq(f(t), C1 - sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(6, 13) + C3*Rational(-16, 13)) - + cos(sqrt(29)*t)*(C2*Rational(16, 13) + sqrt(29)*C3*Rational(6, 13))), + Eq(g(t), C1 + sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(8, 39) + C3*Rational(16, 13)) - + cos(sqrt(29)*t)*(C2*Rational(16, 13) + sqrt(29)*C3*Rational(-8, 39))), + Eq(h(t), C1 + C2*cos(sqrt(29)*t) - C3*sin(sqrt(29)*t))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0, 0]) + + eq16 = (Eq(diff(x(t), t), 21*x(t)), Eq(diff(y(t), t), 17*x(t) + 3*y(t)), + Eq(diff(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))) + sol16 = [Eq(x(t), 216*C1*exp(21*t)/209), + Eq(y(t), 204*C1*exp(21*t)/209 - 6*C2*exp(3*t)/7), + Eq(z(t), C1*exp(21*t) + C2*exp(3*t) + C3*exp(9*t))] + assert dsolve(eq16) == sol16 + assert checksysodesol(eq16, sol16) == (True, [0, 0, 0]) + + eqs17 = [Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), + Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))] + sol17 = [Eq(x(t), C1*Rational(7, 3) - sin(sqrt(179)*t)*(sqrt(179)*C2*Rational(11, 170) + C3*Rational(-21, + 170)) - cos(sqrt(179)*t)*(C2*Rational(21, 170) + sqrt(179)*C3*Rational(11, 170))), + Eq(y(t), C1*Rational(11, 3) + sin(sqrt(179)*t)*(sqrt(179)*C2*Rational(7, 170) + C3*Rational(33, + 170)) - cos(sqrt(179)*t)*(C2*Rational(33, 170) + sqrt(179)*C3*Rational(-7, 170))), + Eq(z(t), C1 + C2*cos(sqrt(179)*t) - C3*sin(sqrt(179)*t))] + assert dsolve(eqs17) == sol17 + assert checksysodesol(eqs17, sol17) == (True, [0, 0, 0]) + + eqs18 = [Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), + Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), + Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))] + sol18 = [Eq(x(t), C1 - sin(5*sqrt(2)*t)*(sqrt(2)*C2*Rational(4, 3) - C3) - cos(5*sqrt(2)*t)*(C2 + + sqrt(2)*C3*Rational(4, 3))), + Eq(y(t), C1 + sin(5*sqrt(2)*t)*(sqrt(2)*C2*Rational(3, 4) + C3) - cos(5*sqrt(2)*t)*(C2 + + sqrt(2)*C3*Rational(-3, 4))), + Eq(z(t), C1 + C2*cos(5*sqrt(2)*t) - C3*sin(5*sqrt(2)*t))] + assert dsolve(eqs18) == sol18 + assert checksysodesol(eqs18, sol18) == (True, [0, 0, 0]) + + eqs19 = [Eq(Derivative(x(t), t), 4*x(t) - z(t)), + Eq(Derivative(y(t), t), 2*x(t) + 2*y(t) - z(t)), + Eq(Derivative(z(t), t), 3*x(t) + y(t))] + sol19 = [Eq(x(t), C2*t**2*exp(2*t)/2 + t*(2*C2 + C3)*exp(2*t) + (C1 + C2 + 2*C3)*exp(2*t)), + Eq(y(t), C2*t**2*exp(2*t)/2 + t*(2*C2 + C3)*exp(2*t) + (C1 + 2*C3)*exp(2*t)), + Eq(z(t), C2*t**2*exp(2*t) + t*(3*C2 + 2*C3)*exp(2*t) + (2*C1 + 3*C3)*exp(2*t))] + assert dsolve(eqs19) == sol19 + assert checksysodesol(eqs19, sol19) == (True, [0, 0, 0]) + + eqs20 = [Eq(Derivative(x(t), t), 4*x(t) - y(t) - 2*z(t)), + Eq(Derivative(y(t), t), 2*x(t) + y(t) - 2*z(t)), + Eq(Derivative(z(t), t), 5*x(t) - 3*z(t))] + sol20 = [Eq(x(t), C1*exp(2*t) - sin(t)*(C2*Rational(3, 5) + C3/5) - cos(t)*(C2/5 + C3*Rational(-3, 5))), + Eq(y(t), -sin(t)*(C2*Rational(3, 5) + C3/5) - cos(t)*(C2/5 + C3*Rational(-3, 5))), + Eq(z(t), C1*exp(2*t) - C2*sin(t) + C3*cos(t))] + assert dsolve(eqs20) == sol20 + assert checksysodesol(eqs20, sol20) == (True, [0, 0, 0]) + + eq21 = (Eq(diff(x(t), t), 9*y(t)), Eq(diff(y(t), t), 12*x(t))) + sol21 = [Eq(x(t), -sqrt(3)*C1*exp(-6*sqrt(3)*t)/2 + sqrt(3)*C2*exp(6*sqrt(3)*t)/2), + Eq(y(t), C1*exp(-6*sqrt(3)*t) + C2*exp(6*sqrt(3)*t))] + + assert dsolve(eq21) == sol21 + assert checksysodesol(eq21, sol21) == (True, [0, 0]) + + eqs22 = [Eq(Derivative(x(t), t), 2*x(t) + 4*y(t)), + Eq(Derivative(y(t), t), 12*x(t) + 41*y(t))] + sol22 = [Eq(x(t), C1*(39 - sqrt(1713))*exp(t*(sqrt(1713) + 43)/2)*Rational(-1, 24) + C2*(39 + + sqrt(1713))*exp(t*(43 - sqrt(1713))/2)*Rational(-1, 24)), + Eq(y(t), C1*exp(t*(sqrt(1713) + 43)/2) + C2*exp(t*(43 - sqrt(1713))/2))] + assert dsolve(eqs22) == sol22 + assert checksysodesol(eqs22, sol22) == (True, [0, 0]) + + eqs23 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), -2*x(t) + 2*y(t))] + sol23 = [Eq(x(t), (C1/4 + sqrt(7)*C2/4)*cos(sqrt(7)*t/2)*exp(t*Rational(3, 2)) + + sin(sqrt(7)*t/2)*(sqrt(7)*C1/4 + C2*Rational(-1, 4))*exp(t*Rational(3, 2))), + Eq(y(t), C1*cos(sqrt(7)*t/2)*exp(t*Rational(3, 2)) - C2*sin(sqrt(7)*t/2)*exp(t*Rational(3, 2)))] + assert dsolve(eqs23) == sol23 + assert checksysodesol(eqs23, sol23) == (True, [0, 0]) + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + a = Symbol("a", real=True) + eq24 = [x(t).diff(t) - a*y(t), y(t).diff(t) + a*x(t)] + sol24 = [Eq(x(t), C1*sin(a*t) + C2*cos(a*t)), Eq(y(t), C1*cos(a*t) - C2*sin(a*t))] + assert dsolve(eq24) == sol24 + assert checksysodesol(eq24, sol24) == (True, [0, 0]) + + # Regression test case for issue #19150 + # https://github.com/sympy/sympy/issues/19150 + eqs25 = [Eq(Derivative(f(t), t), 0), + Eq(Derivative(g(t), t), (f(t) - 2*g(t) + x(t))/(b*c)), + Eq(Derivative(x(t), t), (g(t) - 2*x(t) + y(t))/(b*c)), + Eq(Derivative(y(t), t), (h(t) + x(t) - 2*y(t))/(b*c)), + Eq(Derivative(h(t), t), 0)] + sol25 = [Eq(f(t), -3*C1 + 4*C2), + Eq(g(t), -2*C1 + 3*C2 - C3*exp(-2*t/(b*c)) + C4*exp(-t*(sqrt(2) + 2)/(b*c)) + C5*exp(-t*(2 - + sqrt(2))/(b*c))), + Eq(x(t), -C1 + 2*C2 - sqrt(2)*C4*exp(-t*(sqrt(2) + 2)/(b*c)) + sqrt(2)*C5*exp(-t*(2 - + sqrt(2))/(b*c))), + Eq(y(t), C2 + C3*exp(-2*t/(b*c)) + C4*exp(-t*(sqrt(2) + 2)/(b*c)) + C5*exp(-t*(2 - sqrt(2))/(b*c))), + Eq(h(t), C1)] + assert dsolve(eqs25) == sol25 + assert checksysodesol(eqs25, sol25) == (True, [0, 0, 0, 0, 0]) + + eq26 = [Eq(Derivative(f(t), t), 2*f(t)), Eq(Derivative(g(t), t), 3*f(t) + 7*g(t))] + sol26 = [Eq(f(t), -5*C1*exp(2*t)/3), Eq(g(t), C1*exp(2*t) + C2*exp(7*t))] + assert dsolve(eq26) == sol26 + assert checksysodesol(eq26, sol26) == (True, [0, 0]) + + eq27 = [Eq(Derivative(f(t), t), -9*I*f(t) - 4*g(t)), Eq(Derivative(g(t), t), -4*I*g(t))] + sol27 = [Eq(f(t), 4*I*C1*exp(-4*I*t)/5 + C2*exp(-9*I*t)), Eq(g(t), C1*exp(-4*I*t))] + assert dsolve(eq27) == sol27 + assert checksysodesol(eq27, sol27) == (True, [0, 0]) + + eq28 = [Eq(Derivative(f(t), t), -9*I*f(t)), Eq(Derivative(g(t), t), -4*I*g(t))] + sol28 = [Eq(f(t), C1*exp(-9*I*t)), Eq(g(t), C2*exp(-4*I*t))] + assert dsolve(eq28) == sol28 + assert checksysodesol(eq28, sol28) == (True, [0, 0]) + + eq29 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), 0)] + sol29 = [Eq(f(t), C1), Eq(g(t), C2)] + assert dsolve(eq29) == sol29 + assert checksysodesol(eq29, sol29) == (True, [0, 0]) + + eq30 = [Eq(Derivative(f(t), t), f(t)), Eq(Derivative(g(t), t), 0)] + sol30 = [Eq(f(t), C1*exp(t)), Eq(g(t), C2)] + assert dsolve(eq30) == sol30 + assert checksysodesol(eq30, sol30) == (True, [0, 0]) + + eq31 = [Eq(Derivative(f(t), t), g(t)), Eq(Derivative(g(t), t), 0)] + sol31 = [Eq(f(t), C1 + C2*t), Eq(g(t), C2)] + assert dsolve(eq31) == sol31 + assert checksysodesol(eq31, sol31) == (True, [0, 0]) + + eq32 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), f(t))] + sol32 = [Eq(f(t), C1), Eq(g(t), C1*t + C2)] + assert dsolve(eq32) == sol32 + assert checksysodesol(eq32, sol32) == (True, [0, 0]) + + eq33 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), g(t))] + sol33 = [Eq(f(t), C1), Eq(g(t), C2*exp(t))] + assert dsolve(eq33) == sol33 + assert checksysodesol(eq33, sol33) == (True, [0, 0]) + + eq34 = [Eq(Derivative(f(t), t), f(t)), Eq(Derivative(g(t), t), I*g(t))] + sol34 = [Eq(f(t), C1*exp(t)), Eq(g(t), C2*exp(I*t))] + assert dsolve(eq34) == sol34 + assert checksysodesol(eq34, sol34) == (True, [0, 0]) + + eq35 = [Eq(Derivative(f(t), t), I*f(t)), Eq(Derivative(g(t), t), -I*g(t))] + sol35 = [Eq(f(t), C1*exp(I*t)), Eq(g(t), C2*exp(-I*t))] + assert dsolve(eq35) == sol35 + assert checksysodesol(eq35, sol35) == (True, [0, 0]) + + eq36 = [Eq(Derivative(f(t), t), I*g(t)), Eq(Derivative(g(t), t), 0)] + sol36 = [Eq(f(t), I*C1 + I*C2*t), Eq(g(t), C2)] + assert dsolve(eq36) == sol36 + assert checksysodesol(eq36, sol36) == (True, [0, 0]) + + eq37 = [Eq(Derivative(f(t), t), I*g(t)), Eq(Derivative(g(t), t), I*f(t))] + sol37 = [Eq(f(t), -C1*exp(-I*t) + C2*exp(I*t)), Eq(g(t), C1*exp(-I*t) + C2*exp(I*t))] + assert dsolve(eq37) == sol37 + assert checksysodesol(eq37, sol37) == (True, [0, 0]) + + # Multiple systems + eq1 = [Eq(Derivative(f(t), t)**2, g(t)**2), Eq(-f(t) + Derivative(g(t), t), 0)] + sol1 = [[Eq(f(t), -C1*sin(t) - C2*cos(t)), + Eq(g(t), C1*cos(t) - C2*sin(t))], + [Eq(f(t), -C1*exp(-t) + C2*exp(t)), + Eq(g(t), C1*exp(-t) + C2*exp(t))]] + assert dsolve(eq1) == sol1 + for sol in sol1: + assert checksysodesol(eq1, sol) == (True, [0, 0]) + + +def test_sysode_linear_neq_order1_type2(): + + f, g, h, k = symbols('f g h k', cls=Function) + x, t, a, b, c, d, y = symbols('x t a b c d y') + k1, k2 = symbols('k1 k2') + + + eqs1 = [Eq(Derivative(f(x), x), f(x) + g(x) + 5), + Eq(Derivative(g(x), x), -f(x) - g(x) + 7)] + sol1 = [Eq(f(x), C1 + C2 + 6*x**2 + x*(C2 + 5)), + Eq(g(x), -C1 - 6*x**2 - x*(C2 - 7))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(f(x), x), f(x) + g(x) + 5), + Eq(Derivative(g(x), x), f(x) + g(x) + 7)] + sol2 = [Eq(f(x), -C1 + C2*exp(2*x) - x - 3), + Eq(g(x), C1 + C2*exp(2*x) + x - 3)] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), x), f(x) + 5), + Eq(Derivative(g(x), x), f(x) + 7)] + sol3 = [Eq(f(x), C1*exp(x) - 5), + Eq(g(x), C1*exp(x) + C2 + 2*x - 5)] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(Derivative(f(x), x), f(x) + exp(x)), + Eq(Derivative(g(x), x), x*exp(x) + f(x) + g(x))] + sol4 = [Eq(f(x), C1*exp(x) + x*exp(x)), + Eq(g(x), C1*x*exp(x) + C2*exp(x) + x**2*exp(x))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), 5*x + f(x) + g(x)), + Eq(Derivative(g(x), x), f(x) - g(x))] + sol5 = [Eq(f(x), C1*(1 + sqrt(2))*exp(sqrt(2)*x) + C2*(1 - sqrt(2))*exp(-sqrt(2)*x) + x*Rational(-5, 2) + + Rational(-5, 2)), + Eq(g(x), C1*exp(sqrt(2)*x) + C2*exp(-sqrt(2)*x) + x*Rational(-5, 2))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), -9*f(x) - 4*g(x)), + Eq(Derivative(g(x), x), -4*g(x)), + Eq(Derivative(h(x), x), h(x) + exp(x))] + sol6 = [Eq(f(x), C2*exp(-4*x)*Rational(-4, 5) + C1*exp(-9*x)), + Eq(g(x), C2*exp(-4*x)), + Eq(h(x), C3*exp(x) + x*exp(x))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0]) + + # Regression test case for issue #8859 + # https://github.com/sympy/sympy/issues/8859 + eqs7 = [Eq(Derivative(f(t), t), 3*t + f(t)), + Eq(Derivative(g(t), t), g(t))] + sol7 = [Eq(f(t), C1*exp(t) - 3*t - 3), + Eq(g(t), C2*exp(t))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + # Regression test case for issue #8567 + # https://github.com/sympy/sympy/issues/8567 + eqs8 = [Eq(Derivative(f(t), t), f(t) + 2*g(t)), + Eq(Derivative(g(t), t), -2*f(t) + g(t) + 2*exp(t))] + sol8 = [Eq(f(t), C1*exp(t)*sin(2*t) + C2*exp(t)*cos(2*t) + + exp(t)*sin(2*t)**2 + exp(t)*cos(2*t)**2), + Eq(g(t), C1*exp(t)*cos(2*t) - C2*exp(t)*sin(2*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + # Regression test case for issue #19150 + # https://github.com/sympy/sympy/issues/19150 + eqs9 = [Eq(Derivative(f(t), t), (c - 2*f(t) + g(t))/(a*b)), + Eq(Derivative(g(t), t), (f(t) - 2*g(t) + h(t))/(a*b)), + Eq(Derivative(h(t), t), (d + g(t) - 2*h(t))/(a*b))] + sol9 = [Eq(f(t), -C1*exp(-2*t/(a*b)) + C2*exp(-t*(sqrt(2) + 2)/(a*b)) + C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 4), 3*c + d, evaluate=False)), + Eq(g(t), -sqrt(2)*C2*exp(-t*(sqrt(2) + 2)/(a*b)) + sqrt(2)*C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 2), c + d, evaluate=False)), + Eq(h(t), C1*exp(-2*t/(a*b)) + C2*exp(-t*(sqrt(2) + 2)/(a*b)) + C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 4), c + 3*d, evaluate=False))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0, 0]) + + # Regression test case for issue #16635 + # https://github.com/sympy/sympy/issues/16635 + eqs10 = [Eq(Derivative(f(t), t), 15*t + f(t) - g(t) - 10), + Eq(Derivative(g(t), t), -15*t + f(t) - g(t) - 5)] + sol10 = [Eq(f(t), C1 + C2 + 5*t**3 + 5*t**2 + t*(C2 - 10)), + Eq(g(t), C1 + 5*t**3 - 10*t**2 + t*(C2 - 5))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + # Multiple solutions + eqs11 = [Eq(Derivative(f(t), t)**2 - 2*Derivative(f(t), t) + 1, 4), + Eq(-y*f(t) + Derivative(g(t), t), 0)] + sol11 = [[Eq(f(t), C1 - t), Eq(g(t), C1*t*y + C2*y + t**2*y*Rational(-1, 2))], + [Eq(f(t), C1 + 3*t), Eq(g(t), C1*t*y + C2*y + t**2*y*Rational(3, 2))]] + assert dsolve(eqs11) == sol11 + for s11 in sol11: + assert checksysodesol(eqs11, s11) == (True, [0, 0]) + + # test case for issue #19831 + # https://github.com/sympy/sympy/issues/19831 + n = symbols('n', positive=True) + x0 = symbols('x_0') + t0 = symbols('t_0') + x_0 = symbols('x_0') + t_0 = symbols('t_0') + t = symbols('t') + x = Function('x') + y = Function('y') + T = symbols('T') + + eqs12 = [Eq(Derivative(y(t), t), x(t)), + Eq(Derivative(x(t), t), n*(y(t) + 1))] + sol12 = [Eq(y(t), C1*exp(sqrt(n)*t)*n**Rational(-1, 2) - C2*exp(-sqrt(n)*t)*n**Rational(-1, 2) - 1), + Eq(x(t), C1*exp(sqrt(n)*t) + C2*exp(-sqrt(n)*t))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + sol12b = [ + Eq(y(t), (T*exp(-sqrt(n)*t_0)/2 + exp(-sqrt(n)*t_0)/2 + + x_0*exp(-sqrt(n)*t_0)/(2*sqrt(n)))*exp(sqrt(n)*t) + + (T*exp(sqrt(n)*t_0)/2 + exp(sqrt(n)*t_0)/2 - + x_0*exp(sqrt(n)*t_0)/(2*sqrt(n)))*exp(-sqrt(n)*t) - 1), + Eq(x(t), (T*sqrt(n)*exp(-sqrt(n)*t_0)/2 + sqrt(n)*exp(-sqrt(n)*t_0)/2 + + x_0*exp(-sqrt(n)*t_0)/2)*exp(sqrt(n)*t) + - (T*sqrt(n)*exp(sqrt(n)*t_0)/2 + sqrt(n)*exp(sqrt(n)*t_0)/2 - + x_0*exp(sqrt(n)*t_0)/2)*exp(-sqrt(n)*t)) + ] + assert dsolve(eqs12, ics={y(t0): T, x(t0): x0}) == sol12b + assert checksysodesol(eqs12, sol12b) == (True, [0, 0]) + + #Test cases added for the issue 19763 + #https://github.com/sympy/sympy/issues/19763 + + eq13 = [Eq(Derivative(f(t), t), f(t) + g(t) + 9), + Eq(Derivative(g(t), t), 2*f(t) + 5*g(t) + 23)] + sol13 = [Eq(f(t), -C1*(2 + sqrt(6))*exp(t*(3 - sqrt(6)))/2 - C2*(2 - sqrt(6))*exp(t*(sqrt(6) + 3))/2 - + Rational(22,3)), + Eq(g(t), C1*exp(t*(3 - sqrt(6))) + C2*exp(t*(sqrt(6) + 3)) - Rational(5,3))] + assert dsolve(eq13) == sol13 + assert checksysodesol(eq13, sol13) == (True, [0, 0]) + + eq14 = [Eq(Derivative(f(t), t), f(t) + g(t) + 81), + Eq(Derivative(g(t), t), -2*f(t) + g(t) + 23)] + sol14 = [Eq(f(t), sqrt(2)*C1*exp(t)*sin(sqrt(2)*t)/2 + + sqrt(2)*C2*exp(t)*cos(sqrt(2)*t)/2 + - 58*sin(sqrt(2)*t)**2/3 - 58*cos(sqrt(2)*t)**2/3), + Eq(g(t), C1*exp(t)*cos(sqrt(2)*t) - C2*exp(t)*sin(sqrt(2)*t) + - 185*sin(sqrt(2)*t)**2/3 - 185*cos(sqrt(2)*t)**2/3)] + assert dsolve(eq14) == sol14 + assert checksysodesol(eq14, sol14) == (True, [0,0]) + + eq15 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), 3*f(t) + 4*g(t) + k2)] + sol15 = [Eq(f(t), -C1*(3 - sqrt(33))*exp(t*(5 + sqrt(33))/2)/6 - + C2*(3 + sqrt(33))*exp(t*(5 - sqrt(33))/2)/6 + 2*k1 - k2), + Eq(g(t), C1*exp(t*(5 + sqrt(33))/2) + C2*exp(t*(5 - sqrt(33))/2) - + Mul(Rational(1,2), 3*k1 - k2, evaluate = False))] + assert dsolve(eq15) == sol15 + assert checksysodesol(eq15, sol15) == (True, [0,0]) + + eq16 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), k2)] + sol16 = [Eq(f(t), C1 + k1*t), + Eq(g(t), C2 + k2*t)] + assert dsolve(eq16) == sol16 + assert checksysodesol(eq16, sol16) == (True, [0,0]) + + eq17 = [Eq(Derivative(f(t), t), 0), + Eq(Derivative(g(t), t), c*f(t) + k2)] + sol17 = [Eq(f(t), C1), + Eq(g(t), C2*c + t*(C1*c + k2))] + assert dsolve(eq17) == sol17 + assert checksysodesol(eq17 , sol17) == (True , [0,0]) + + eq18 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), f(t) + k2)] + sol18 = [Eq(f(t), C1 + k1*t), + Eq(g(t), C2 + k1*t**2/2 + t*(C1 + k2))] + assert dsolve(eq18) == sol18 + assert checksysodesol(eq18 , sol18) == (True , [0,0]) + + eq19 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), f(t) + 2*g(t) + k2)] + sol19 = [Eq(f(t), -2*C1 + k1*t), + Eq(g(t), C1 + C2*exp(2*t) - k1*t/2 - Mul(Rational(1,4), k1 + 2*k2 , evaluate = False))] + assert dsolve(eq19) == sol19 + assert checksysodesol(eq19 , sol19) == (True , [0,0]) + + eq20 = [Eq(diff(f(t), t), f(t) + k1), + Eq(diff(g(t), t), k2)] + sol20 = [Eq(f(t), C1*exp(t) - k1), + Eq(g(t), C2 + k2*t)] + assert dsolve(eq20) == sol20 + assert checksysodesol(eq20 , sol20) == (True , [0,0]) + + eq21 = [Eq(diff(f(t), t), g(t) + k1), + Eq(diff(g(t), t), 0)] + sol21 = [Eq(f(t), C1 + t*(C2 + k1)), + Eq(g(t), C2)] + assert dsolve(eq21) == sol21 + assert checksysodesol(eq21 , sol21) == (True , [0,0]) + + eq22 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), k2)] + sol22 = [Eq(f(t), -2*C1 + C2*exp(t) - k1 - 2*k2*t - 2*k2), + Eq(g(t), C1 + k2*t)] + assert dsolve(eq22) == sol22 + assert checksysodesol(eq22 , sol22) == (True , [0,0]) + + eq23 = [Eq(Derivative(f(t), t), g(t) + k1), + Eq(Derivative(g(t), t), 2*g(t) + k2)] + sol23 = [Eq(f(t), C1 + C2*exp(2*t)/2 - k2/4 + t*(2*k1 - k2)/2), + Eq(g(t), C2*exp(2*t) - k2/2)] + assert dsolve(eq23) == sol23 + assert checksysodesol(eq23 , sol23) == (True , [0,0]) + + eq24 = [Eq(Derivative(f(t), t), f(t) + k1), + Eq(Derivative(g(t), t), 2*f(t) + k2)] + sol24 = [Eq(f(t), C1*exp(t)/2 - k1), + Eq(g(t), C1*exp(t) + C2 - 2*k1 - t*(2*k1 - k2))] + assert dsolve(eq24) == sol24 + assert checksysodesol(eq24 , sol24) == (True , [0,0]) + + eq25 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), 3*f(t) + 6*g(t) + k2)] + sol25 = [Eq(f(t), -2*C1 + C2*exp(7*t)/3 + 2*t*(3*k1 - k2)/7 - + Mul(Rational(1,49), k1 + 2*k2 , evaluate = False)), + Eq(g(t), C1 + C2*exp(7*t) - t*(3*k1 - k2)/7 - + Mul(Rational(3,49), k1 + 2*k2 , evaluate = False))] + assert dsolve(eq25) == sol25 + assert checksysodesol(eq25 , sol25) == (True , [0,0]) + + eq26 = [Eq(Derivative(f(t), t), 2*f(t) - g(t) + k1), + Eq(Derivative(g(t), t), 4*f(t) - 2*g(t) + 2*k1)] + sol26 = [Eq(f(t), C1 + 2*C2 + t*(2*C1 + k1)), + Eq(g(t), 4*C2 + t*(4*C1 + 2*k1))] + assert dsolve(eq26) == sol26 + assert checksysodesol(eq26 , sol26) == (True , [0,0]) + + # Test Case added for issue #22715 + # https://github.com/sympy/sympy/issues/22715 + + eq27 = [Eq(diff(x(t),t),-1*y(t)+10), Eq(diff(y(t),t),5*x(t)-2*y(t)+3)] + sol27 = [Eq(x(t), (C1/5 - 2*C2/5)*exp(-t)*cos(2*t) + - (2*C1/5 + C2/5)*exp(-t)*sin(2*t) + + 17*sin(2*t)**2/5 + 17*cos(2*t)**2/5), + Eq(y(t), C1*exp(-t)*cos(2*t) - C2*exp(-t)*sin(2*t) + + 10*sin(2*t)**2 + 10*cos(2*t)**2)] + assert dsolve(eq27) == sol27 + assert checksysodesol(eq27 , sol27) == (True , [0,0]) + + +def test_sysode_linear_neq_order1_type3(): + + f, g, h, k, x0 , y0 = symbols('f g h k x0 y0', cls=Function) + x, t, a = symbols('x t a') + r = symbols('r', real=True) + + eqs1 = [Eq(Derivative(f(r), r), r*g(r) + f(r)), + Eq(Derivative(g(r), r), -r*f(r) + g(r))] + sol1 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(f(x), x), x**2*g(x) + x*f(x)), + Eq(Derivative(g(x), x), 2*x**2*f(x) + (3*x**2 + x)*g(x))] + sol2 = [Eq(f(x), (sqrt(17)*C1/17 + C2*(17 - 3*sqrt(17))/34)*exp(x**3*(3 + sqrt(17))/6 + x**2/2) - + exp(x**3*(3 - sqrt(17))/6 + x**2/2)*(sqrt(17)*C1/17 + C2*(3*sqrt(17) + 17)*Rational(-1, 34))), + Eq(g(x), exp(x**3*(3 - sqrt(17))/6 + x**2/2)*(C1*(17 - 3*sqrt(17))/34 + sqrt(17)*C2*Rational(-2, + 17)) + exp(x**3*(3 + sqrt(17))/6 + x**2/2)*(C1*(3*sqrt(17) + 17)/34 + sqrt(17)*C2*Rational(2, 17)))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(f(x).diff(x), x*f(x) + g(x)), + Eq(g(x).diff(x), -f(x) + x*g(x))] + sol3 = [Eq(f(x), (C1/2 + I*C2/2)*exp(x**2/2 - I*x) + exp(x**2/2 + I*x)*(C1/2 + I*C2*Rational(-1, 2))), + Eq(g(x), (I*C1/2 + C2/2)*exp(x**2/2 + I*x) - exp(x**2/2 - I*x)*(I*C1/2 + C2*Rational(-1, 2)))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(f(x).diff(x), x*(f(x) + g(x) + h(x))), Eq(g(x).diff(x), x*(f(x) + g(x) + h(x))), + Eq(h(x).diff(x), x*(f(x) + g(x) + h(x)))] + sol4 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0]) + + eqs5 = [Eq(f(x).diff(x), x**2*(f(x) + g(x) + h(x))), Eq(g(x).diff(x), x**2*(f(x) + g(x) + h(x))), + Eq(h(x).diff(x), x**2*(f(x) + g(x) + h(x)))] + sol5 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(k(x), x), x*(f(x) + g(x) + h(x) + k(x)))] + sol6 = [Eq(f(x), -C1/4 - C2/4 - C3/4 + 3*C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(g(x), 3*C1/4 - C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(h(x), -C1/4 + 3*C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(k(x), -C1/4 - C2/4 + 3*C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + y = symbols("y", real=True) + + eqs7 = [Eq(Derivative(f(y), y), y*f(y) + g(y)), + Eq(Derivative(g(y), y), y*g(y) - f(y))] + sol7 = [Eq(f(y), C1*exp(y**2/2)*sin(y) + C2*exp(y**2/2)*cos(y)), + Eq(g(y), C1*exp(y**2/2)*cos(y) - C2*exp(y**2/2)*sin(y))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + #Test cases added for the issue 19763 + #https://github.com/sympy/sympy/issues/19763 + + eqs8 = [Eq(Derivative(f(t), t), 5*t*f(t) + 2*h(t)), + Eq(Derivative(h(t), t), 2*f(t) + 5*t*h(t))] + sol8 = [Eq(f(t), Mul(-1, (C1/2 - C2/2), evaluate = False)*exp(5*t**2/2 - 2*t) + (C1/2 + C2/2)*exp(5*t**2/2 + 2*t)), + Eq(h(t), (C1/2 - C2/2)*exp(5*t**2/2 - 2*t) + (C1/2 + C2/2)*exp(5*t**2/2 + 2*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + eqs9 = [Eq(diff(f(t), t), 5*t*f(t) + t**2*g(t)), + Eq(diff(g(t), t), -t**2*f(t) + 5*t*g(t))] + sol9 = [Eq(f(t), (C1/2 - I*C2/2)*exp(I*t**3/3 + 5*t**2/2) + (C1/2 + I*C2/2)*exp(-I*t**3/3 + 5*t**2/2)), + Eq(g(t), Mul(-1, (I*C1/2 - C2/2) , evaluate = False)*exp(-I*t**3/3 + 5*t**2/2) + (I*C1/2 + C2/2)*exp(I*t**3/3 + 5*t**2/2))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9 , sol9) == (True , [0,0]) + + eqs10 = [Eq(diff(f(t), t), t**2*g(t) + 5*t*f(t)), + Eq(diff(g(t), t), -t**2*f(t) + (9*t**2 + 5*t)*g(t))] + sol10 = [Eq(f(t), (C1*(77 - 9*sqrt(77))/154 + sqrt(77)*C2/77)*exp(t**3*(sqrt(77) + 9)/6 + 5*t**2/2) + (C1*(77 + 9*sqrt(77))/154 - sqrt(77)*C2/77)*exp(t**3*(9 - sqrt(77))/6 + 5*t**2/2)), + Eq(g(t), (sqrt(77)*C1/77 + C2*(77 - 9*sqrt(77))/154)*exp(t**3*(9 - sqrt(77))/6 + 5*t**2/2) - (sqrt(77)*C1/77 - C2*(77 + 9*sqrt(77))/154)*exp(t**3*(sqrt(77) + 9)/6 + 5*t**2/2))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10 , sol10) == (True , [0,0]) + + eqs11 = [Eq(diff(f(t), t), 5*t*f(t) + t**2*g(t)), + Eq(diff(g(t), t), (1-t**2)*f(t) + (5*t + 9*t**2)*g(t))] + sol11 = [Eq(f(t), C1*x0(t) + C2*x0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t)), + Eq(g(t), C1*y0(t) + C2*(y0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t) + exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)))] + assert dsolve(eqs11) == sol11 + +@slow +def test_sysode_linear_neq_order1_type4(): + + f, g, h, k = symbols('f g h k', cls=Function) + x, t, a = symbols('x t a') + r = symbols('r', real=True) + + eqs1 = [Eq(diff(f(r), r), f(r) + r*g(r) + r**2), Eq(diff(g(r), r), -r*f(r) + g(r) + r)] + sol1 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2) + exp(r)*sin(r**2/2)*Integral(r**2*exp(-r)*sin(r**2/2) + + r*exp(-r)*cos(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r**2*exp(-r)*cos(r**2/2) - r*exp(-r)*sin(r**2/2), r)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2) - exp(r)*sin(r**2/2)*Integral(r**2*exp(-r)*cos(r**2/2) - + r*exp(-r)*sin(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r**2*exp(-r)*sin(r**2/2) + r*exp(-r)*cos(r**2/2), r))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(diff(f(r), r), f(r) + r*g(r) + r), Eq(diff(g(r), r), -r*f(r) + g(r) + log(r))] + sol2 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2) + exp(r)*sin(r**2/2)*Integral(r*exp(-r)*sin(r**2/2) + + exp(-r)*log(r)*cos(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r*exp(-r)*cos(r**2/2) - exp(-r)*log(r)*sin( + r**2/2), r)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2) - exp(r)*sin(r**2/2)*Integral(r*exp(-r)*cos(r**2/2) - + exp(-r)*log(r)*sin(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r*exp(-r)*sin(r**2/2) + exp(-r)*log(r)*cos( + r**2/2), r))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs2, simplify=False, doit=False) == [sol2] + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x)) + x), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x)) + x), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x)) + 1)] + sol3 = [Eq(f(x), C1*Rational(-1, 3) + C2*Rational(-1, 3) + C3*Rational(2, 3) + x**2/6 + x*Rational(-1, 3) + + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9)), + Eq(g(x), C1*Rational(2, 3) + C2*Rational(-1, 3) + C3*Rational(-1, 3) + x**2/6 + x*Rational(-1, 3) + + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9)), + Eq(h(x), C1*Rational(-1, 3) + C2*Rational(2, 3) + C3*Rational(-1, 3) + x**2*Rational(-1, 3) + + x*Rational(2, 3) + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0]) + + eqs4 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x)) + sin(x)), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x)) + sin(x)), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x)) + sin(x))] + sol4 = [Eq(f(x), C1*Rational(-1, 3) + C2*Rational(-1, 3) + C3*Rational(2, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2))), + Eq(g(x), C1*Rational(2, 3) + C2*Rational(-1, 3) + C3*Rational(-1, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2))), + Eq(h(x), C1*Rational(-1, 3) + C2*Rational(2, 3) + C3*Rational(-1, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2)))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(k(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1))] + sol5 = [Eq(f(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(3, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(g(x), C1*Rational(3, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(h(x), C1*Rational(-1, 4) + C2*Rational(3, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(k(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(3, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(g(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(h(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(k(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1))] + sol6 = [Eq(f(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(3, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(g(x), C1*Rational(3, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(h(x), C1*Rational(-1, 4) + C2*Rational(3, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(k(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(3, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + eqs7 = [Eq(Derivative(f(x), x), (f(x) + g(x) + h(x))*log(x) + sin(x)), Eq(Derivative(g(x), x), (f(x) + g(x) + + h(x))*log(x) + sin(x)), Eq(Derivative(h(x), x), (f(x) + g(x) + h(x))*log(x) + sin(x))] + sol7 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x))] + with dotprodsimp(True): + assert dsolve(eqs7, simplify=False, doit=False) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0, 0]) + + eqs8 = [Eq(Derivative(f(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + sin(x)), Eq(Derivative(g(x), x), (f(x) + + g(x) + h(x) + k(x))*log(x) + sin(x)), Eq(Derivative(h(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + + sin(x)), Eq(Derivative(k(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + sin(x))] + sol8 = [Eq(f(x), -C1/4 - C2/4 - C3/4 + 3*C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(g(x), 3*C1/4 - C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(h(x), -C1/4 + 3*C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(k(x), -C1/4 - C2/4 + 3*C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x))] + with dotprodsimp(True): + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0, 0, 0]) + + +def test_sysode_linear_neq_order1_type5_type6(): + f, g = symbols("f g", cls=Function) + x, x_ = symbols("x x_") + + # Type 5 + eqs1 = [Eq(Derivative(f(x), x), (2*f(x) + g(x))/x), Eq(Derivative(g(x), x), (f(x) + 2*g(x))/x)] + sol1 = [Eq(f(x), -C1*x + C2*x**3), Eq(g(x), C1*x + C2*x**3)] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + # Type 6 + eqs2 = [Eq(Derivative(f(x), x), (2*f(x) + g(x) + 1)/x), + Eq(Derivative(g(x), x), (x + f(x) + 2*g(x))/x)] + sol2 = [Eq(f(x), C2*x**3 - x*(C1 + Rational(1, 4)) + x*log(x)*Rational(-1, 2) + Rational(-2, 3)), + Eq(g(x), C2*x**3 + x*log(x)/2 + x*(C1 + Rational(-1, 4)) + Rational(1, 3))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + +def test_higher_order_to_first_order(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + eqs1 = [Eq(Derivative(f(x), (x, 2)), 2*f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), -f(x))] + sol1 = [Eq(f(x), -C2*x*exp(-x) + C3*x*exp(x) - (C1 - C2)*exp(-x) + (C3 + C4)*exp(x)), + Eq(g(x), C2*x*exp(-x) - C3*x*exp(x) + (C1 + C2)*exp(-x) + (C3 - C4)*exp(x))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(f(x).diff(x, 2), 0), Eq(g(x).diff(x, 2), f(x))] + sol2 = [Eq(f(x), C1 + C2*x), Eq(g(x), C1*x**2/2 + C2*x**3/6 + C3 + C4*x)] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), (x, 2)), 2*f(x)), + Eq(Derivative(g(x), (x, 2)), -f(x) + 2*g(x))] + sol3 = [Eq(f(x), 4*C1*exp(-sqrt(2)*x) + 4*C2*exp(sqrt(2)*x)), + Eq(g(x), sqrt(2)*C1*x*exp(-sqrt(2)*x) - sqrt(2)*C2*x*exp(sqrt(2)*x) + (C1 + + sqrt(2)*C4)*exp(-sqrt(2)*x) + (C2 - sqrt(2)*C3)*exp(sqrt(2)*x))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(Derivative(f(x), (x, 2)), 2*f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), 2*g(x))] + sol4 = [Eq(f(x), C1*x*exp(sqrt(2)*x)/4 + C3*x*exp(-sqrt(2)*x)/4 + (C2/4 + sqrt(2)*C3/8)*exp(-sqrt(2)*x) - + exp(sqrt(2)*x)*(sqrt(2)*C1/8 + C4*Rational(-1, 4))), + Eq(g(x), sqrt(2)*C1*exp(sqrt(2)*x)/2 + sqrt(2)*C3*exp(-sqrt(2)*x)*Rational(-1, 2))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(f(x).diff(x, 2), f(x)), Eq(g(x).diff(x, 2), f(x))] + sol5 = [Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), -C1*exp(-x) + C2*exp(x) + C3 + C4*x)] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), -f(x) - g(x))] + sol6 = [Eq(f(x), C1 + C2*x**2/2 + C2 + C4*x**3/6 + x*(C3 + C4)), + Eq(g(x), -C1 + C2*x**2*Rational(-1, 2) - C3*x + C4*x**3*Rational(-1, 6))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + eqs7 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x) + 1), + Eq(Derivative(g(x), (x, 2)), f(x) + g(x) + 1)] + sol7 = [Eq(f(x), -C1 - C2*x + sqrt(2)*C3*exp(sqrt(2)*x)/2 + sqrt(2)*C4*exp(-sqrt(2)*x)*Rational(-1, 2) + + Rational(-1, 2)), + Eq(g(x), C1 + C2*x + sqrt(2)*C3*exp(sqrt(2)*x)/2 + sqrt(2)*C4*exp(-sqrt(2)*x)*Rational(-1, 2) + + Rational(-1, 2))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + eqs8 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x) + 1), + Eq(Derivative(g(x), (x, 2)), -f(x) - g(x) + 1)] + sol8 = [Eq(f(x), C1 + C2 + C4*x**3/6 + x**4/12 + x**2*(C2/2 + Rational(1, 2)) + x*(C3 + C4)), + Eq(g(x), -C1 - C3*x + C4*x**3*Rational(-1, 6) + x**4*Rational(-1, 12) - x**2*(C2/2 + Rational(-1, + 2)))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + x, y = symbols('x, y', cls=Function) + t, l = symbols('t, l') + + eqs10 = [Eq(Derivative(x(t), (t, 2)), 5*x(t) + 43*y(t)), + Eq(Derivative(y(t), (t, 2)), x(t) + 9*y(t))] + sol10 = [Eq(x(t), C1*(61 - 9*sqrt(47))*sqrt(sqrt(47) + 7)*exp(-t*sqrt(sqrt(47) + 7))/2 + C2*sqrt(7 - + sqrt(47))*(61 + 9*sqrt(47))*exp(-t*sqrt(7 - sqrt(47)))/2 + C3*(61 - 9*sqrt(47))*sqrt(sqrt(47) + + 7)*exp(t*sqrt(sqrt(47) + 7))*Rational(-1, 2) + C4*sqrt(7 - sqrt(47))*(61 + 9*sqrt(47))*exp(t*sqrt(7 + - sqrt(47)))*Rational(-1, 2)), + Eq(y(t), C1*(7 - sqrt(47))*sqrt(sqrt(47) + 7)*exp(-t*sqrt(sqrt(47) + 7))*Rational(-1, 2) + C2*sqrt(7 + - sqrt(47))*(sqrt(47) + 7)*exp(-t*sqrt(7 - sqrt(47)))*Rational(-1, 2) + C3*(7 - + sqrt(47))*sqrt(sqrt(47) + 7)*exp(t*sqrt(sqrt(47) + 7))/2 + C4*sqrt(7 - sqrt(47))*(sqrt(47) + + 7)*exp(t*sqrt(7 - sqrt(47)))/2)] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + eqs11 = [Eq(7*x(t) + Derivative(x(t), (t, 2)) - 9*Derivative(y(t), t), 0), + Eq(7*y(t) + 9*Derivative(x(t), t) + Derivative(y(t), (t, 2)), 0)] + sol11 = [Eq(y(t), C1*(9 - sqrt(109))*sin(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)/14 + C2*(9 - + sqrt(109))*cos(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C3*(9 + + sqrt(109))*sin(sqrt(2)*t*sqrt(95 - 9*sqrt(109))/2)/14 + C4*(9 + sqrt(109))*cos(sqrt(2)*t*sqrt(95 - + 9*sqrt(109))/2)*Rational(-1, 14)), + Eq(x(t), C1*(9 - sqrt(109))*cos(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C2*(9 - + sqrt(109))*sin(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C3*(9 + + sqrt(109))*cos(sqrt(2)*t*sqrt(95 - 9*sqrt(109))/2)/14 + C4*(9 + sqrt(109))*sin(sqrt(2)*t*sqrt(95 - + 9*sqrt(109))/2)/14)] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0]) + + # Euler Systems + # Note: To add examples of euler systems solver with non-homogeneous term. + eqs13 = [Eq(Derivative(f(t), (t, 2)), Derivative(f(t), t)/t + f(t)/t**2 + g(t)/t**2), + Eq(Derivative(g(t), (t, 2)), g(t)/t**2)] + sol13 = [Eq(f(t), C1*(sqrt(5) + 3)*Rational(-1, 2)*t**(Rational(1, 2) + + sqrt(5)*Rational(-1, 2)) + C2*t**(Rational(1, 2) + + sqrt(5)/2)*(3 - sqrt(5))*Rational(-1, 2) - C3*t**(1 - + sqrt(2))*(1 + sqrt(2)) - C4*t**(1 + sqrt(2))*(1 - sqrt(2))), + Eq(g(t), C1*(1 + sqrt(5))*Rational(-1, 2)*t**(Rational(1, 2) + + sqrt(5)*Rational(-1, 2)) + C2*t**(Rational(1, 2) + + sqrt(5)/2)*(1 - sqrt(5))*Rational(-1, 2))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + # Solving systems using dsolve separately + eqs14 = [Eq(Derivative(f(t), (t, 2)), t*f(t)), + Eq(Derivative(g(t), (t, 2)), t*g(t))] + sol14 = [Eq(f(t), C1*airyai(t) + C2*airybi(t)), + Eq(g(t), C3*airyai(t) + C4*airybi(t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0]) + + + eqs15 = [Eq(Derivative(x(t), (t, 2)), t*(4*Derivative(x(t), t) + 8*Derivative(y(t), t))), + Eq(Derivative(y(t), (t, 2)), t*(12*Derivative(x(t), t) - 6*Derivative(y(t), t)))] + sol15 = [Eq(x(t), C1 - erf(sqrt(6)*t)*(sqrt(6)*sqrt(pi)*C2/33 + sqrt(6)*sqrt(pi)*C3*Rational(-1, 44)) + + erfi(sqrt(5)*t)*(sqrt(5)*sqrt(pi)*C2*Rational(2, 55) + sqrt(5)*sqrt(pi)*C3*Rational(4, 55))), + Eq(y(t), C4 + erf(sqrt(6)*t)*(sqrt(6)*sqrt(pi)*C2*Rational(2, 33) + sqrt(6)*sqrt(pi)*C3*Rational(-1, + 22)) + erfi(sqrt(5)*t)*(sqrt(5)*sqrt(pi)*C2*Rational(3, 110) + sqrt(5)*sqrt(pi)*C3*Rational(3, 55)))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0]) + + +@slow +def test_higher_order_to_first_order_9(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + eqs9 = [f(x) + g(x) - 2*exp(I*x) + 2*Derivative(f(x), x) + Derivative(f(x), (x, 2)), + f(x) + g(x) - 2*exp(I*x) + 2*Derivative(g(x), x) + Derivative(g(x), (x, 2))] + sol9 = [Eq(f(x), -C1 + C4*exp(-2*x)/2 - (C2/2 - C3/2)*exp(-x)*cos(x) + + (C2/2 + C3/2)*exp(-x)*sin(x) + 2*((1 - 2*I)*exp(I*x)*sin(x)**2/5) + + 2*((1 - 2*I)*exp(I*x)*cos(x)**2/5)), + Eq(g(x), C1 - C4*exp(-2*x)/2 - (C2/2 - C3/2)*exp(-x)*cos(x) + + (C2/2 + C3/2)*exp(-x)*sin(x) + 2*((1 - 2*I)*exp(I*x)*sin(x)**2/5) + + 2*((1 - 2*I)*exp(I*x)*cos(x)**2/5))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0]) + + +def test_higher_order_to_first_order_12(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + x, y = symbols('x, y', cls=Function) + t, l = symbols('t, l') + + eqs12 = [Eq(4*x(t) + Derivative(x(t), (t, 2)) + 8*Derivative(y(t), t), 0), + Eq(4*y(t) - 8*Derivative(x(t), t) + Derivative(y(t), (t, 2)), 0)] + sol12 = [Eq(y(t), C1*(2 - sqrt(5))*sin(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C2*(2 - + sqrt(5))*cos(2*t*sqrt(4*sqrt(5) + 9))/2 + C3*(2 + sqrt(5))*sin(2*t*sqrt(9 - 4*sqrt(5)))*Rational(-1, + 2) + C4*(2 + sqrt(5))*cos(2*t*sqrt(9 - 4*sqrt(5)))/2), + Eq(x(t), C1*(2 - sqrt(5))*cos(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C2*(2 - + sqrt(5))*sin(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C3*(2 + sqrt(5))*cos(2*t*sqrt(9 - + 4*sqrt(5)))/2 + C4*(2 + sqrt(5))*sin(2*t*sqrt(9 - 4*sqrt(5)))/2)] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + +def test_second_order_to_first_order_2(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + eqs2 = [Eq(f(x).diff(x, 2), 2*(x*g(x).diff(x) - g(x))), + Eq(g(x).diff(x, 2),-2*(x*f(x).diff(x) - f(x)))] + sol2 = [Eq(f(x), C1*x + x*Integral(C2*exp(-x_)*exp(I*exp(2*x_))/2 + C2*exp(-x_)*exp(-I*exp(2*x_))/2 - + I*C3*exp(-x_)*exp(I*exp(2*x_))/2 + I*C3*exp(-x_)*exp(-I*exp(2*x_))/2, (x_, log(x)))), + Eq(g(x), C4*x + x*Integral(I*C2*exp(-x_)*exp(I*exp(2*x_))/2 - I*C2*exp(-x_)*exp(-I*exp(2*x_))/2 + + C3*exp(-x_)*exp(I*exp(2*x_))/2 + C3*exp(-x_)*exp(-I*exp(2*x_))/2, (x_, log(x))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs2, simplify=False, doit=False) == [sol2] + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = (Eq(diff(f(t),t,t), 9*t*diff(g(t),t)-9*g(t)), Eq(diff(g(t),t,t),7*t*diff(f(t),t)-7*f(t))) + sol3 = [Eq(f(t), C1*t + t*Integral(C2*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/2 + C2*exp(-t_)* + exp(-3*sqrt(7)*exp(2*t_)/2)/2 + 3*sqrt(7)*C3*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/14 - + 3*sqrt(7)*C3*exp(-t_)*exp(-3*sqrt(7)*exp(2*t_)/2)/14, (t_, log(t)))), + Eq(g(t), C4*t + t*Integral(sqrt(7)*C2*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/6 - sqrt(7)*C2*exp(-t_)* + exp(-3*sqrt(7)*exp(2*t_)/2)/6 + C3*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/2 + C3*exp(-t_)*exp(-3*sqrt(7)* + exp(2*t_)/2)/2, (t_, log(t))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs3, simplify=False, doit=False) == [sol3] + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + # Regression Test case for sympy#19238 + # https://github.com/sympy/sympy/issues/19238 + # Note: When the doit method is removed, these particular types of systems + # can be divided first so that we have lesser number of big matrices. + eqs5 = [Eq(Derivative(g(t), (t, 2)), a*m), + Eq(Derivative(f(t), (t, 2)), 0)] + sol5 = [Eq(g(t), C1 + C2*t + a*m*t**2/2), + Eq(f(t), C3 + C4*t)] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + # Type 2 + eqs6 = [Eq(Derivative(f(t), (t, 2)), f(t)/t**4), + Eq(Derivative(g(t), (t, 2)), d*g(t)/t**4)] + sol6 = [Eq(f(t), C1*sqrt(t**2)*exp(-1/t) - C2*sqrt(t**2)*exp(1/t)), + Eq(g(t), C3*sqrt(t**2)*exp(-sqrt(d)/t)*d**Rational(-1, 2) - + C4*sqrt(t**2)*exp(sqrt(d)/t)*d**Rational(-1, 2))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + +@slow +def test_second_order_to_first_order_slow1(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + # Type 1 + + eqs1 = [Eq(f(x).diff(x, 2), 2/x *(x*g(x).diff(x) - g(x))), + Eq(g(x).diff(x, 2),-2/x *(x*f(x).diff(x) - f(x)))] + sol1 = [Eq(f(x), C1*x + 2*C2*x*Ci(2*x) - C2*sin(2*x) - 2*C3*x*Si(2*x) - C3*cos(2*x)), + Eq(g(x), -2*C2*x*Si(2*x) - C2*cos(2*x) - 2*C3*x*Ci(2*x) + C3*sin(2*x) + C4*x)] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + +def test_second_order_to_first_order_slow4(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + eqs4 = [Eq(Derivative(f(t), (t, 2)), t*sin(t)*Derivative(g(t), t) - g(t)*sin(t)), + Eq(Derivative(g(t), (t, 2)), t*sin(t)*Derivative(f(t), t) - f(t)*sin(t))] + sol4 = [Eq(f(t), C1*t + t*Integral(C2*exp(-t_)*exp(exp(t_)*cos(exp(t_)))*exp(-sin(exp(t_)))/2 + + C2*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2 - C3*exp(-t_)*exp(exp(t_)*cos(exp(t_)))* + exp(-sin(exp(t_)))/2 + + C3*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2, (t_, log(t)))), + Eq(g(t), C4*t + t*Integral(-C2*exp(-t_)*exp(exp(t_)*cos(exp(t_)))*exp(-sin(exp(t_)))/2 + + C2*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2 + C3*exp(-t_)*exp(exp(t_)*cos(exp(t_)))* + exp(-sin(exp(t_)))/2 + C3*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2, (t_, log(t))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs4, simplify=False, doit=False) == [sol4] + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + +def test_component_division(): + f, g, h, k = symbols('f g h k', cls=Function) + x = symbols("x") + funcs = [f(x), g(x), h(x), k(x)] + + eqs1 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), h(x)**4 + k(x))] + sol1 = [Eq(f(x), 2*C1*exp(2*x)), + Eq(g(x), C1*exp(2*x) + C2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C3**4*exp(4*x)/3 + C4*exp(x))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0, 0, 0]) + + components1 = {((Eq(Derivative(f(x), x), 2*f(x)),), (Eq(Derivative(g(x), x), f(x)),)), + ((Eq(Derivative(h(x), x), h(x)),), (Eq(Derivative(k(x), x), h(x)**4 + k(x)),))} + eqsdict1 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {h(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), h(x)**4 + k(x))}) + graph1 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), h(x))}] + assert {tuple(tuple(scc) for scc in wcc) for wcc in _component_division(eqs1, funcs, x)} == components1 + assert _eqs2dict(eqs1, funcs) == eqsdict1 + assert [set(element) for element in _dict2graph(eqsdict1[0])] == graph1 + + eqs2 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol2 = [Eq(f(x), C1*exp(2*x)), + Eq(g(x), C1*exp(2*x)/2 + C2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C1**4*exp(8*x)/7 + C4*exp(x))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0, 0, 0]) + + components2 = {frozenset([(Eq(Derivative(f(x), x), 2*f(x)),), + (Eq(Derivative(g(x), x), f(x)),), + (Eq(Derivative(k(x), x), f(x)**4 + k(x)),)]), + frozenset([(Eq(Derivative(h(x), x), h(x)),)])} + eqsdict2 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph2 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), f(x))}] + assert {frozenset(tuple(scc) for scc in wcc) for wcc in _component_division(eqs2, funcs, x)} == components2 + assert _eqs2dict(eqs2, funcs) == eqsdict2 + assert [set(element) for element in _dict2graph(eqsdict2[0])] == graph2 + + eqs3 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), x + f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol3 = [Eq(f(x), C1*exp(2*x)), + Eq(g(x), C1*exp(2*x)/2 + C2 + x**2/2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C1**4*exp(8*x)/7 + C4*exp(x))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0, 0]) + + components3 = {frozenset([(Eq(Derivative(f(x), x), 2*f(x)),), + (Eq(Derivative(g(x), x), x + f(x)),), + (Eq(Derivative(k(x), x), f(x)**4 + k(x)),)]), + frozenset([(Eq(Derivative(h(x), x), h(x)),),])} + eqsdict3 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), x + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph3 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), f(x))}] + assert {frozenset(tuple(scc) for scc in wcc) for wcc in _component_division(eqs3, funcs, x)} == components3 + assert _eqs2dict(eqs3, funcs) == eqsdict3 + assert [set(l) for l in _dict2graph(eqsdict3[0])] == graph3 + + # Note: To be uncommented when the default option to call dsolve first for + # single ODE system can be rearranged. This can be done after the doit + # option in dsolve is made False by default. + + eqs4 = [Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), f(x) + x*g(x) + x), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol4 = [Eq(f(x), (C1/2 - sqrt(2)*C2/2 - sqrt(2)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 +\ + sqrt(2)*x)/2, x)/2 + Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 +\ + sqrt(2)*x)/2, x)/2)*exp(x**2/2 - sqrt(2)*x) + (C1/2 + sqrt(2)*C2/2 + sqrt(2)*Integral(x*exp(-x**2/2 + - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 + - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(g(x), (-sqrt(2)*C1/4 + C2/2 + Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 -\ + sqrt(2)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, + x)/4)*exp(x**2/2 - sqrt(2)*x) + (sqrt(2)*C1/4 + C2/2 + Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + sqrt(2)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - + sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/4)*exp(x**2/2 + sqrt(2)*x)), + Eq(h(x), C3*exp(x)), + Eq(k(x), C4*exp(x) + exp(x)*Integral((C1*exp(x**2/2 - sqrt(2)*x)/2 + C1*exp(x**2/2 + sqrt(2)*x)/2 - + sqrt(2)*C2*exp(x**2/2 - sqrt(2)*x)/2 + sqrt(2)*C2*exp(x**2/2 + sqrt(2)*x)/2 - sqrt(2)*exp(x**2/2 - + sqrt(2)*x)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + exp(x**2/2 - + sqrt(2)*x)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, + x)/2 + sqrt(2)*exp(x**2/2 + sqrt(2)*x)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + + sqrt(2)*x)/2, x)/2 + exp(x**2/2 + sqrt(2)*x)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - + sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2)**4*exp(-x), x))] + components4 = {(frozenset([Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + x + f(x))]), + frozenset([Eq(Derivative(k(x), x), f(x)**4 + k(x)),])), + (frozenset([Eq(Derivative(h(x), x), h(x)),]),)} + eqsdict4 = ({f(x): {g(x)}, g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + g(x): Eq(Derivative(g(x), x), x*g(x) + x + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph4 = [{f(x), g(x), h(x), k(x)}, {(f(x), g(x)), (g(x), f(x)), (k(x), f(x))}] + assert {tuple(frozenset(scc) for scc in wcc) for wcc in _component_division(eqs4, funcs, x)} == components4 + assert _eqs2dict(eqs4, funcs) == eqsdict4 + assert [set(element) for element in _dict2graph(eqsdict4[0])] == graph4 + # XXX: dsolve hangs in integration here: + assert dsolve_system(eqs4, simplify=False, doit=False) == [sol4] + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol5 = [Eq(f(x), (C1/2 - sqrt(2)*C2/2)*exp(x**2/2 - sqrt(2)*x) + (C1/2 + sqrt(2)*C2/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(g(x), (-sqrt(2)*C1/4 + C2/2)*exp(x**2/2 - sqrt(2)*x) + (sqrt(2)*C1/4 + C2/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(h(x), C3*exp(x)), + Eq(k(x), C4*exp(x) + exp(x)*Integral((C1*exp(x**2/2 - sqrt(2)*x)/2 + C1*exp(x**2/2 + sqrt(2)*x)/2 - + sqrt(2)*C2*exp(x**2/2 - sqrt(2)*x)/2 + sqrt(2)*C2*exp(x**2/2 + sqrt(2)*x)/2)**4*exp(-x), x))] + components5 = {(frozenset([Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + f(x))]), + frozenset([Eq(Derivative(k(x), x), f(x)**4 + k(x)),])), + (frozenset([Eq(Derivative(h(x), x), h(x)),]),)} + eqsdict5 = ({f(x): {g(x)}, g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + g(x): Eq(Derivative(g(x), x), x*g(x) + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph5 = [{f(x), g(x), h(x), k(x)}, {(f(x), g(x)), (g(x), f(x)), (k(x), f(x))}] + assert {tuple(frozenset(scc) for scc in wcc) for wcc in _component_division(eqs5, funcs, x)} == components5 + assert _eqs2dict(eqs5, funcs) == eqsdict5 + assert [set(element) for element in _dict2graph(eqsdict5[0])] == graph5 + # XXX: dsolve hangs in integration here: + assert dsolve_system(eqs5, simplify=False, doit=False) == [sol5] + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0, 0]) + + +def test_linodesolve(): + t, x, a = symbols("t x a") + f, g, h = symbols("f g h", cls=Function) + + # Testing the Errors + raises(ValueError, lambda: linodesolve(1, t)) + raises(ValueError, lambda: linodesolve(a, t)) + + A1 = Matrix([[1, 2], [2, 4], [4, 6]]) + raises(NonSquareMatrixError, lambda: linodesolve(A1, t)) + + A2 = Matrix([[1, 2, 1], [3, 1, 2]]) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t)) + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t) + g(t).diff(t), g(t)), Eq(g(t).diff(t), f(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [C1*(-Rational(1, 2) + sqrt(5)/2)*exp(t*(-Rational(1, 2) + sqrt(5)/2)) + C2*(-sqrt(5)/2 - Rational(1, 2))* + exp(t*(-sqrt(5)/2 - Rational(1, 2))), + C1*exp(t*(-Rational(1, 2) + sqrt(5)/2)) + C2*exp(t*(-sqrt(5)/2 - Rational(1, 2)))] + assert constant_renumber(linodesolve(A, t), variables=Tuple(*eq).free_symbols) == sol + + # Testing the Errors + raises(ValueError, lambda: linodesolve(1, t, b=Matrix([t+1]))) + raises(ValueError, lambda: linodesolve(a, t, b=Matrix([log(t) + sin(t)]))) + + raises(ValueError, lambda: linodesolve(Matrix([7]), t, b=t**2)) + raises(ValueError, lambda: linodesolve(Matrix([a+10]), t, b=log(t)*cos(t))) + + raises(ValueError, lambda: linodesolve(7, t, b=t**2)) + raises(ValueError, lambda: linodesolve(a, t, b=log(t) + sin(t))) + + A1 = Matrix([[1, 2], [2, 4], [4, 6]]) + b1 = Matrix([t, 1, t**2]) + raises(NonSquareMatrixError, lambda: linodesolve(A1, t, b=b1)) + + A2 = Matrix([[1, 2, 1], [3, 1, 2]]) + b2 = Matrix([t, t**2]) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t, b=b2)) + + raises(ValueError, lambda: linodesolve(A1[:2, :], t, b=b1)) + raises(ValueError, lambda: linodesolve(A1[:2, :], t, b=b1[:1])) + + # DOIT check + A1 = Matrix([[1, -1], [1, -1]]) + b1 = Matrix([15*t - 10, -15*t - 5]) + sol1 = [C1 + C2*t + C2 - 10*t**3 + 10*t**2 + t*(15*t**2 - 5*t) - 10*t, + C1 + C2*t - 10*t**3 - 5*t**2 + t*(15*t**2 - 5*t) - 5*t] + assert constant_renumber(linodesolve(A1, t, b=b1, type="type2", doit=True), + variables=[t]) == sol1 + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t) + g(t).diff(t), g(t) + t), Eq(g(t).diff(t), f(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [-C1*exp(-t/2 + sqrt(5)*t/2)/2 + sqrt(5)*C1*exp(-t/2 + sqrt(5)*t/2)/2 - sqrt(5)*C2*exp(-sqrt(5)*t/2 - + t/2)/2 - C2*exp(-sqrt(5)*t/2 - t/2)/2 - exp(-t/2 + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)), t)/2 + sqrt(5)*exp(-t/2 + + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)), t)/2 - sqrt(5)*exp(-sqrt(5)*t/2 - t/2)*Integral(-sqrt(5)*t*exp(t/2 + + sqrt(5)*t/2)/5, t)/2 - exp(-sqrt(5)*t/2 - t/2)*Integral(-sqrt(5)*t*exp(t/2 + sqrt(5)*t/2)/5, t)/2, + C1*exp(-t/2 + sqrt(5)*t/2) + C2*exp(-sqrt(5)*t/2 - t/2) + exp(-t/2 + + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)), t) + exp(-sqrt(5)*t/2 - + t/2)*Integral(-sqrt(5)*t*exp(t/2 + sqrt(5)*t/2)/5, t)] + assert constant_renumber(linodesolve(A, t, b=b), variables=[t]) == sol + + # non-homogeneous term assumed to be 0 + sol1 = [-C1*exp(-t/2 + sqrt(5)*t/2)/2 + sqrt(5)*C1*exp(-t/2 + sqrt(5)*t/2)/2 - sqrt(5)*C2*exp(-sqrt(5)*t/2 + - t/2)/2 - C2*exp(-sqrt(5)*t/2 - t/2)/2, + C1*exp(-t/2 + sqrt(5)*t/2) + C2*exp(-sqrt(5)*t/2 - t/2)] + assert constant_renumber(linodesolve(A, t, type="type2"), variables=[t]) == sol1 + + # Testing the Errors + raises(ValueError, lambda: linodesolve(t+10, t)) + raises(ValueError, lambda: linodesolve(a*t, t)) + + A1 = Matrix([[1, t], [-t, 1]]) + B1, _ = _is_commutative_anti_derivative(A1, t) + raises(NonSquareMatrixError, lambda: linodesolve(A1[:, :1], t, B=B1)) + raises(ValueError, lambda: linodesolve(A1, t, B=1)) + + A2 = Matrix([[t, t, t], [t, t, t], [t, t, t]]) + B2, _ = _is_commutative_anti_derivative(A2, t) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t, B=B2[:2, :])) + raises(ValueError, lambda: linodesolve(A2, t, B=2)) + raises(ValueError, lambda: linodesolve(A2, t, B=B2, type="type31")) + + raises(ValueError, lambda: linodesolve(A1, t, B=B2)) + raises(ValueError, lambda: linodesolve(A2, t, B=B1)) + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t), f(t) + t*g(t)), Eq(g(t).diff(t), -t*f(t) + g(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [(C1/2 - I*C2/2)*exp(I*t**2/2 + t) + (C1/2 + I*C2/2)*exp(-I*t**2/2 + t), + (-I*C1/2 + C2/2)*exp(-I*t**2/2 + t) + (I*C1/2 + C2/2)*exp(I*t**2/2 + t)] + assert constant_renumber(linodesolve(A, t), variables=Tuple(*eq).free_symbols) == sol + assert constant_renumber(linodesolve(A, t, type="type3"), variables=Tuple(*eq).free_symbols) == sol + + A1 = Matrix([[t, 1], [t, -1]]) + raises(NotImplementedError, lambda: linodesolve(A1, t)) + + # Testing the Errors + raises(ValueError, lambda: linodesolve(t+10, t, b=Matrix([t+1]))) + raises(ValueError, lambda: linodesolve(a*t, t, b=Matrix([log(t) + sin(t)]))) + + raises(ValueError, lambda: linodesolve(Matrix([7*t]), t, b=t**2)) + raises(ValueError, lambda: linodesolve(Matrix([a + 10*log(t)]), t, b=log(t)*cos(t))) + + raises(ValueError, lambda: linodesolve(7*t, t, b=t**2)) + raises(ValueError, lambda: linodesolve(a*t**2, t, b=log(t) + sin(t))) + + A1 = Matrix([[1, t], [-t, 1]]) + b1 = Matrix([t, t ** 2]) + B1, _ = _is_commutative_anti_derivative(A1, t) + raises(NonSquareMatrixError, lambda: linodesolve(A1[:, :1], t, b=b1)) + + A2 = Matrix([[t, t, t], [t, t, t], [t, t, t]]) + b2 = Matrix([t, 1, t**2]) + B2, _ = _is_commutative_anti_derivative(A2, t) + raises(NonSquareMatrixError, lambda: linodesolve(A2[:2, :], t, b=b2)) + + raises(ValueError, lambda: linodesolve(A1, t, b=b2)) + raises(ValueError, lambda: linodesolve(A2, t, b=b1)) + + raises(ValueError, lambda: linodesolve(A1, t, b=b1, B=B2)) + raises(ValueError, lambda: linodesolve(A2, t, b=b2, B=B1)) + + # Testing auto functionality + func = [f(x), g(x), h(x)] + eq = [Eq(f(x).diff(x), x*(f(x) + g(x) + h(x)) + x), + Eq(g(x).diff(x), x*(f(x) + g(x) + h(x)) + x), + Eq(h(x).diff(x), x*(f(x) + g(x) + h(x)) + 1)] + ceq = canonical_odes(eq, func, x) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, x, 1) + A = A0 + _x1 = exp(-3*x**2/2) + _x2 = exp(3*x**2/2) + _x3 = Integral(2*_x1*x/3 + _x1/3 + x/3 - Rational(1, 3), x) + _x4 = 2*_x2*_x3/3 + _x5 = Integral(2*_x1*x/3 + _x1/3 - 2*x/3 + Rational(2, 3), x) + sol = [ + C1*_x2/3 - C1/3 + C2*_x2/3 - C2/3 + C3*_x2/3 + 2*C3/3 + _x2*_x5/3 + _x3/3 + _x4 - _x5/3, + C1*_x2/3 + 2*C1/3 + C2*_x2/3 - C2/3 + C3*_x2/3 - C3/3 + _x2*_x5/3 + _x3/3 + _x4 - _x5/3, + C1*_x2/3 - C1/3 + C2*_x2/3 + 2*C2/3 + C3*_x2/3 - C3/3 + _x2*_x5/3 - 2*_x3/3 + _x4 + 2*_x5/3, + ] + assert constant_renumber(linodesolve(A, x, b=b), variables=Tuple(*eq).free_symbols) == sol + assert constant_renumber(linodesolve(A, x, b=b, type="type4"), + variables=Tuple(*eq).free_symbols) == sol + + A1 = Matrix([[t, 1], [t, -1]]) + raises(NotImplementedError, lambda: linodesolve(A1, t, b=b1)) + + # non-homogeneous term not passed + sol1 = [-C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2), + -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)] + assert constant_renumber(linodesolve(A, x, type="type4", doit=True), variables=Tuple(*eq).free_symbols) == sol1 + + +@slow +def test_linear_3eq_order1_type4_slow(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + + f = t ** 3 + log(t) + g = t ** 2 + sin(t) + eq1 = (Eq(diff(x(t), t), (4 * f + g) * x(t) - f * y(t) - 2 * f * z(t)), + Eq(diff(y(t), t), 2 * f * x(t) + (f + g) * y(t) - 2 * f * z(t)), Eq(diff(z(t), t), 5 * f * x(t) + f * y( + t) + (-3 * f + g) * z(t))) + with dotprodsimp(True): + dsolve(eq1) + + +@slow +def test_linear_neq_order1_type2_slow1(): + i, r1, c1, r2, c2, t = symbols('i, r1, c1, r2, c2, t') + x1 = Function('x1') + x2 = Function('x2') + + eq1 = r1*c1*Derivative(x1(t), t) + x1(t) - x2(t) - r1*i + eq2 = r2*c1*Derivative(x1(t), t) + r2*c2*Derivative(x2(t), t) + x2(t) - r2*i + eq = [eq1, eq2] + + # XXX: Solution is too complicated + [sol] = dsolve_system(eq, simplify=False, doit=False) + assert checksysodesol(eq, sol) == (True, [0, 0]) + + +# Regression test case for issue #9204 +# https://github.com/sympy/sympy/issues/9204 +@tooslow +def test_linear_new_order1_type2_de_lorentz_slow_check(): + m = Symbol("m", real=True) + q = Symbol("q", real=True) + t = Symbol("t", real=True) + + e1, e2, e3 = symbols("e1:4", real=True) + b1, b2, b3 = symbols("b1:4", real=True) + v1, v2, v3 = symbols("v1:4", cls=Function, real=True) + + eqs = [ + -e1*q + m*Derivative(v1(t), t) - q*(-b2*v3(t) + b3*v2(t)), + -e2*q + m*Derivative(v2(t), t) - q*(b1*v3(t) - b3*v1(t)), + -e3*q + m*Derivative(v3(t), t) - q*(-b1*v2(t) + b2*v1(t)) + ] + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0, 0]) + + +# Regression test case for issue #14001 +# https://github.com/sympy/sympy/issues/14001 +@slow +def test_linear_neq_order1_type2_slow_check(): + RC, t, C, Vs, L, R1, V0, I0 = symbols("RC t C Vs L R1 V0 I0") + V = Function("V") + I = Function("I") + system = [Eq(V(t).diff(t), -1/RC*V(t) + I(t)/C), Eq(I(t).diff(t), -R1/L*I(t) - 1/L*V(t) + Vs/L)] + [sol] = dsolve_system(system, simplify=False, doit=False) + + assert checksysodesol(system, sol) == (True, [0, 0]) + + +def _linear_3eq_order1_type4_long(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + + f = t ** 3 + log(t) + g = t ** 2 + sin(t) + + eq1 = (Eq(diff(x(t), t), (4*f + g)*x(t) - f*y(t) - 2*f*z(t)), + Eq(diff(y(t), t), 2*f*x(t) + (f + g)*y(t) - 2*f*z(t)), Eq(diff(z(t), t), 5*f*x(t) + f*y( + t) + (-3*f + g)*z(t))) + + dsolve_sol = dsolve(eq1) + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + + x_1 = sqrt(-t**6 - 8*t**3*log(t) + 8*t**3 - 16*log(t)**2 + 32*log(t) - 16) + x_2 = sqrt(3) + x_3 = 8324372644*C1*x_1*x_2 + 4162186322*C2*x_1*x_2 - 8324372644*C3*x_1*x_2 + x_4 = 1 / (1903457163*t**3 + 3825881643*x_1*x_2 + 7613828652*log(t) - 7613828652) + x_5 = exp(t**3/3 + t*x_1*x_2/4 - cos(t)) + x_6 = exp(t**3/3 - t*x_1*x_2/4 - cos(t)) + x_7 = exp(t**4/2 + t**3/3 + 2*t*log(t) - 2*t - cos(t)) + x_8 = 91238*C1*x_1*x_2 + 91238*C2*x_1*x_2 - 91238*C3*x_1*x_2 + x_9 = 1 / (66049*t**3 - 50629*x_1*x_2 + 264196*log(t) - 264196) + x_10 = 50629 * C1 / 25189 + 37909*C2/25189 - 50629*C3/25189 - x_3*x_4 + x_11 = -50629*C1/25189 - 12720*C2/25189 + 50629*C3/25189 + x_3*x_4 + sol = [Eq(x(t), x_10*x_5 + x_11*x_6 + x_7*(C1 - C2)), Eq(y(t), x_10*x_5 + x_11*x_6), Eq(z(t), x_5*( + -424*C1/257 - 167*C2/257 + 424*C3/257 - x_8*x_9) + x_6*(167*C1/257 + 424*C2/257 - + 167*C3/257 + x_8*x_9) + x_7*(C1 - C2))] + + assert dsolve_sol1 == sol + assert checksysodesol(eq1, dsolve_sol1) == (True, [0, 0, 0]) + + +@slow +def test_neq_order1_type4_slow_check1(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + + eqs = [Eq(diff(f(x), x), x*f(x) + x**2*g(x) + x), + Eq(diff(g(x), x), 2*x**2*f(x) + (x + 3*x**2)*g(x) + 1)] + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0]) + + +@slow +def test_neq_order1_type4_slow_check2(): + f, g, h = symbols("f, g, h", cls=Function) + x = Symbol("x") + + eqs = [ + Eq(Derivative(f(x), x), x*h(x) + f(x) + g(x) + 1), + Eq(Derivative(g(x), x), x*g(x) + f(x) + h(x) + 10), + Eq(Derivative(h(x), x), x*f(x) + x + g(x) + h(x)) + ] + with dotprodsimp(True): + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0, 0]) + + +def _neq_order1_type4_slow3(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + + eqs = [ + Eq(Derivative(f(x), x), x*f(x) + g(x) + sin(x)), + Eq(Derivative(g(x), x), x**2 + x*g(x) - f(x)) + ] + sol = [ + Eq(f(x), (C1/2 - I*C2/2 - I*Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 + Integral(-I*x**2*exp(-x**2/2 + - I*x)/2 + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - + I*x)*sin(x)/2 + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 + + I*x) + (C1/2 + I*C2/2 + I*Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 + Integral(-I*x**2*exp(-x**2/2 + - I*x)/2 + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - + I*x)*sin(x)/2 + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 - + I*x)), + Eq(g(x), (-I*C1/2 + C2/2 + Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 - + I*Integral(-I*x**2*exp(-x**2/2 - I*x)/2 + I*x**2*exp(-x**2/2 + + I*x)/2 + exp(-x**2/2 - I*x)*sin(x)/2 + exp(-x**2/2 + + I*x)*sin(x)/2, x)/2)*exp(x**2/2 - I*x) + (I*C1/2 + C2/2 + + Integral(x**2*exp(-x**2/2 - I*x)/2 + x**2*exp(-x**2/2 + I*x)/2 + + I*exp(-x**2/2 - I*x)*sin(x)/2 - I*exp(-x**2/2 + I*x)*sin(x)/2, + x)/2 + I*Integral(-I*x**2*exp(-x**2/2 - I*x)/2 + + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - I*x)*sin(x)/2 + + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 + I*x)) + ] + + return eqs, sol + + +def test_neq_order1_type4_slow3(): + eqs, sol = _neq_order1_type4_slow3() + assert dsolve_system(eqs, simplify=False, doit=False) == [sol] + # XXX: dsolve gives an error in integration: + # assert dsolve(eqs) == sol + # https://github.com/sympy/sympy/issues/20155 + + +@slow +def test_neq_order1_type4_slow_check3(): + eqs, sol = _neq_order1_type4_slow3() + assert checksysodesol(eqs, sol) == (True, [0, 0]) + + +@tooslow +@XFAIL +def test_linear_3eq_order1_type4_long_dsolve_slow_xfail(): + eq, sol = _linear_3eq_order1_type4_long() + + dsolve_sol = dsolve(eq) + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + + assert dsolve_sol1 == sol + + +@tooslow +def test_linear_3eq_order1_type4_long_dsolve_dotprodsimp(): + eq, sol = _linear_3eq_order1_type4_long() + + # XXX: Only works with dotprodsimp see + # test_linear_3eq_order1_type4_long_dsolve_slow_xfail which is too slow + with dotprodsimp(True): + dsolve_sol = dsolve(eq) + + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + assert dsolve_sol1 == sol + + +@tooslow +def test_linear_3eq_order1_type4_long_check(): + eq, sol = _linear_3eq_order1_type4_long() + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + +def test_dsolve_system(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))] + funcs = [f(x), g(x)] + + sol = [[Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))]] + assert dsolve_system(eqs, funcs=funcs, t=x, doit=True) == sol + + raises(ValueError, lambda: dsolve_system(1)) + raises(ValueError, lambda: dsolve_system(eqs, 1)) + raises(ValueError, lambda: dsolve_system(eqs, funcs, 1)) + raises(ValueError, lambda: dsolve_system(eqs, funcs[:1], x)) + + eq = (Eq(f(x).diff(x), 12 * f(x) - 6 * g(x)), Eq(g(x).diff(x) ** 2, 11 * f(x) + 3 * g(x))) + raises(NotImplementedError, lambda: dsolve_system(eq) == ([], [])) + + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)]) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], t=x) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], t=x, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, t=x, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], ics={f(0): 1, g(0): 1}) == ([], [])) + +def test_dsolve(): + + f, g = symbols('f g', cls=Function) + x, y = symbols('x y') + + eqs = [f(x).diff(x) - x, f(x).diff(x) + x] + with raises(ValueError): + dsolve(eqs) + + eqs = [f(x, y).diff(x)] + with raises(ValueError): + dsolve(eqs) + + eqs = [f(x, y).diff(x)+g(x).diff(x), g(x).diff(x)] + with raises(ValueError): + dsolve(eqs) + + +@slow +def test_higher_order1_slow1(): + x, y = symbols("x y", cls=Function) + t = symbols("t") + + eq = [ + Eq(diff(x(t),t,t), (log(t)+t**2)*diff(x(t),t)+(log(t)+t**2)*3*diff(y(t),t)), + Eq(diff(y(t),t,t), (log(t)+t**2)*2*diff(x(t),t)+(log(t)+t**2)*9*diff(y(t),t)) + ] + sol, = dsolve_system(eq, simplify=False, doit=False) + # The solution is too long to write out explicitly and checkodesol is too + # slow so we test for particular values of t: + for e in eq: + res = (e.lhs - e.rhs).subs({sol[0].lhs:sol[0].rhs, sol[1].lhs:sol[1].rhs}) + res = res.subs({d: d.doit(deep=False) for d in res.atoms(Derivative)}) + assert ratsimp(res.subs(t, 1)) == 0 + + +def test_second_order_type2_slow1(): + x, y, z = symbols('x, y, z', cls=Function) + t, l = symbols('t, l') + + eqs1 = [Eq(Derivative(x(t), (t, 2)), t*(2*x(t) + y(t))), + Eq(Derivative(y(t), (t, 2)), t*(-x(t) + 2*y(t)))] + sol1 = [Eq(x(t), I*C1*airyai(t*(2 - I)**(S(1)/3)) + I*C2*airybi(t*(2 - I)**(S(1)/3)) - I*C3*airyai(t*(2 + + I)**(S(1)/3)) - I*C4*airybi(t*(2 + I)**(S(1)/3))), + Eq(y(t), C1*airyai(t*(2 - I)**(S(1)/3)) + C2*airybi(t*(2 - I)**(S(1)/3)) + C3*airyai(t*(2 + I)**(S(1)/3)) + + C4*airybi(t*(2 + I)**(S(1)/3)))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + +@tooslow +@XFAIL +def test_nonlinear_3eq_order1_type1(): + a, b, c = symbols('a b c') + + eqs = [ + a * f(x).diff(x) - (b - c) * g(x) * h(x), + b * g(x).diff(x) - (c - a) * h(x) * f(x), + c * h(x).diff(x) - (a - b) * f(x) * g(x), + ] + + assert dsolve(eqs) # NotImplementedError + + +@XFAIL +def test_nonlinear_3eq_order1_type4(): + eqs = [ + Eq(f(x).diff(x), (2*h(x)*g(x) - 3*g(x)*h(x))), + Eq(g(x).diff(x), (4*f(x)*h(x) - 2*h(x)*f(x))), + Eq(h(x).diff(x), (3*g(x)*f(x) - 4*f(x)*g(x))), + ] + dsolve(eqs) # KeyError when matching + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +@tooslow +@XFAIL +def test_nonlinear_3eq_order1_type3(): + eqs = [ + Eq(f(x).diff(x), (2*f(x)**2 - 3 )), + Eq(g(x).diff(x), (4 - 2*h(x) )), + Eq(h(x).diff(x), (3*h(x) - 4*f(x)**2)), + ] + dsolve(eqs) # Not sure if this finishes... + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +@XFAIL +def test_nonlinear_3eq_order1_type5(): + eqs = [ + Eq(f(x).diff(x), f(x)*(2*f(x) - 3*g(x))), + Eq(g(x).diff(x), g(x)*(4*g(x) - 2*h(x))), + Eq(h(x).diff(x), h(x)*(3*h(x) - 4*f(x))), + ] + dsolve(eqs) # KeyError + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +def test_linear_2eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + k, l, m, n = symbols('k, l, m, n', Integer=True) + t = Symbol('t') + x0, y0 = symbols('x0, y0', cls=Function) + + eq1 = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23)) + sol1 = [Eq(x(t), C1*exp(t*(sqrt(6) + 3)) + C2*exp(t*(-sqrt(6) + 3)) - Rational(22, 3)), \ + Eq(y(t), C1*(2 + sqrt(6))*exp(t*(sqrt(6) + 3)) + C2*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) - Rational(5, 3))] + assert checksysodesol(eq1, sol1) == (True, [0, 0]) + + eq2 = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23)) + sol2 = [Eq(x(t), (C1*cos(sqrt(2)*t) + C2*sin(sqrt(2)*t))*exp(t) - Rational(58, 3)), \ + Eq(y(t), (-sqrt(2)*C1*sin(sqrt(2)*t) + sqrt(2)*C2*cos(sqrt(2)*t))*exp(t) - Rational(185, 3))] + assert checksysodesol(eq2, sol2) == (True, [0, 0]) + + eq3 = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t))) + sol3 = [Eq(x(t), (C1*exp(2*t) + C2*exp(-2*t))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (C1*exp(2*t) - C2*exp(-2*t))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq3, sol3) == (True, [0, 0]) + + eq4 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + sol4 = [Eq(x(t), (C1*cos((t**3)/3) + C2*sin((t**3)/3))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (-C1*sin((t**3)/3) + C2*cos((t**3)/3))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq4, sol4) == (True, [0, 0]) + + eq5 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t))) + sol5 = [Eq(x(t), (C1*exp((sqrt(77)/2 + Rational(9, 2))*(t**3)/3) + \ + C2*exp((-sqrt(77)/2 + Rational(9, 2))*(t**3)/3))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (C1*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(t**3)/3) + \ + C2*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(t**3)/3))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq5, sol5) == (True, [0, 0]) + + eq6 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), (1-t**2)*x(t) + (5*t+9*t**2)*y(t))) + sol6 = [Eq(x(t), C1*x0(t) + C2*x0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t)), \ + Eq(y(t), C1*y0(t) + C2*(y0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t) + \ + exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)))] + s = dsolve(eq6) + assert s == sol6 # too complicated to test with subs and simplify + # assert checksysodesol(eq10, sol10) == (True, [0, 0]) # this one fails + + +def test_nonlinear_2eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + eq1 = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol1 = [ + Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq1) == sol1 + assert checksysodesol(eq1, sol1) == (True, [0, 0]) + + eq2 = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5)) + sol2 = [ + Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq2) == sol2 + assert checksysodesol(eq2, sol2) == (True, [0, 0]) + + eq3 = (Eq(diff(x(t),t), y(t)*x(t)), Eq(diff(y(t),t), x(t)**3)) + tt = Rational(2, 3) + sol3 = [ + Eq(x(t), 6**tt/(6*(-sinh(sqrt(C1)*(C2 + t)/2)/sqrt(C1))**tt)), + Eq(y(t), sqrt(C1 + C1/sinh(sqrt(C1)*(C2 + t)/2)**2)/3)] + assert dsolve(eq3) == sol3 + # FIXME: assert checksysodesol(eq3, sol3) == (True, [0, 0]) + + eq4 = (Eq(diff(x(t),t),x(t)*y(t)*sin(t)**2), Eq(diff(y(t),t),y(t)**2*sin(t)**2)) + sol4 = {Eq(x(t), -2*exp(C1)/(C2*exp(C1) + t - sin(2*t)/2)), Eq(y(t), -2/(C1 + t - sin(2*t)/2))} + assert dsolve(eq4) == sol4 + # FIXME: assert checksysodesol(eq4, sol4) == (True, [0, 0]) + + eq5 = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2)) + sol5 = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)} + assert dsolve(eq5) == sol5 + assert checksysodesol(eq5, sol5) == (True, [0, 0]) + + eq6 = (Eq(diff(x(t),t),x(t)**2*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol6 = [ + Eq(x(t), 1/(C1 - 1/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 + (-1/(4*C2 + 4*t))**(Rational(-1, 4)))), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 + I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 - I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq6) == sol6 + assert checksysodesol(eq6, sol6) == (True, [0, 0]) + + +@slow +def test_nonlinear_3eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + t, u = symbols('t u') + eq1 = (4*diff(x(t),t) + 2*y(t)*z(t), 3*diff(y(t),t) - z(t)*x(t), 5*diff(z(t),t) - x(t)*y(t)) + sol1 = [Eq(4*Integral(1/(sqrt(-4*u**2 - 3*C1 + C2)*sqrt(-4*u**2 + 5*C1 - C2)), (u, x(t))), + C3 - sqrt(15)*t/15), Eq(3*Integral(1/(sqrt(-6*u**2 - C1 + 5*C2)*sqrt(3*u**2 + C1 - 4*C2)), + (u, y(t))), C3 + sqrt(5)*t/10), Eq(5*Integral(1/(sqrt(-10*u**2 - 3*C1 + C2)* + sqrt(5*u**2 + 4*C1 - C2)), (u, z(t))), C3 + sqrt(3)*t/6)] + assert [i.dummy_eq(j) for i, j in zip(dsolve(eq1), sol1)] + # FIXME: assert checksysodesol(eq1, sol1) == (True, [0, 0, 0]) + + eq2 = (4*diff(x(t),t) + 2*y(t)*z(t)*sin(t), 3*diff(y(t),t) - z(t)*x(t)*sin(t), 5*diff(z(t),t) - x(t)*y(t)*sin(t)) + sol2 = [Eq(3*Integral(1/(sqrt(-6*u**2 - C1 + 5*C2)*sqrt(3*u**2 + C1 - 4*C2)), (u, x(t))), C3 + + sqrt(5)*cos(t)/10), Eq(4*Integral(1/(sqrt(-4*u**2 - 3*C1 + C2)*sqrt(-4*u**2 + 5*C1 - C2)), + (u, y(t))), C3 - sqrt(15)*cos(t)/15), Eq(5*Integral(1/(sqrt(-10*u**2 - 3*C1 + C2)* + sqrt(5*u**2 + 4*C1 - C2)), (u, z(t))), C3 + sqrt(3)*cos(t)/6)] + assert [i.dummy_eq(j) for i, j in zip(dsolve(eq2), sol2)] + # FIXME: assert checksysodesol(eq2, sol2) == (True, [0, 0, 0]) + + +def test_C1_function_9239(): + t = Symbol('t') + C1 = Function('C1') + C2 = Function('C2') + C3 = Symbol('C3') + C4 = Symbol('C4') + eq = (Eq(diff(C1(t), t), 9*C2(t)), Eq(diff(C2(t), t), 12*C1(t))) + sol = [Eq(C1(t), 9*C3*exp(6*sqrt(3)*t) + 9*C4*exp(-6*sqrt(3)*t)), + Eq(C2(t), 6*sqrt(3)*C3*exp(6*sqrt(3)*t) - 6*sqrt(3)*C4*exp(-6*sqrt(3)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + +def test_dsolve_linsystem_symbol(): + eps = Symbol('epsilon', positive=True) + eq1 = (Eq(diff(f(x), x), -eps*g(x)), Eq(diff(g(x), x), eps*f(x))) + sol1 = [Eq(f(x), -C1*eps*cos(eps*x) - C2*eps*sin(eps*x)), + Eq(g(x), -C1*eps*sin(eps*x) + C2*eps*cos(eps*x))] + assert checksysodesol(eq1, sol1) == (True, [0, 0]) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_numeric.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..12abd38c80f07279ed41aefc7952762da0f9f430 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_numeric.py @@ -0,0 +1,139 @@ +from sympy.core.function import nfloat +from sympy.core.numbers import (Float, I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from mpmath import mnorm, mpf +from sympy.solvers import nsolve +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import raises, XFAIL +from sympy.utilities.decorator import conserve_mpmath_dps + +@XFAIL +def test_nsolve_fail(): + x = symbols('x') + # Sometimes it is better to use the numerator (issue 4829) + # but sometimes it is not (issue 11768) so leave this to + # the discretion of the user + ans = nsolve(x**2/(1 - x)/(1 - 2*x)**2 - 100, x, 0) + assert ans > 0.46 and ans < 0.47 + + +def test_nsolve_denominator(): + x = symbols('x') + # Test that nsolve uses the full expression (numerator and denominator). + ans = nsolve((x**2 + 3*x + 2)/(x + 2), -2.1) + # The root -2 was divided out, so make sure we don't find it. + assert ans == -1.0 + +def test_nsolve(): + # onedimensional + x = Symbol('x') + assert nsolve(sin(x), 2) - pi.evalf() < 1e-15 + assert nsolve(Eq(2*x, 2), x, -10) == nsolve(2*x - 2, -10) + # Testing checks on number of inputs + raises(TypeError, lambda: nsolve(Eq(2*x, 2))) + raises(TypeError, lambda: nsolve(Eq(2*x, 2), x, 1, 2)) + # multidimensional + x1 = Symbol('x1') + x2 = Symbol('x2') + f1 = 3 * x1**2 - 2 * x2**2 - 1 + f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8 + f = Matrix((f1, f2)).T + F = lambdify((x1, x2), f.T, modules='mpmath') + for x0 in [(-1, 1), (1, -2), (4, 4), (-4, -4)]: + x = nsolve(f, (x1, x2), x0, tol=1.e-8) + assert mnorm(F(*x), 1) <= 1.e-10 + # The Chinese mathematician Zhu Shijie was the very first to solve this + # nonlinear system 700 years ago (z was added to make it 3-dimensional) + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + f1 = -x + 2*y + f2 = (x**2 + x*(y**2 - 2) - 4*y) / (x + 4) + f3 = sqrt(x**2 + y**2)*z + f = Matrix((f1, f2, f3)).T + F = lambdify((x, y, z), f.T, modules='mpmath') + + def getroot(x0): + root = nsolve(f, (x, y, z), x0) + assert mnorm(F(*root), 1) <= 1.e-8 + return root + assert list(map(round, getroot((1, 1, 1)))) == [2, 1, 0] + assert nsolve([Eq( + f1, 0), Eq(f2, 0), Eq(f3, 0)], [x, y, z], (1, 1, 1)) # just see that it works + a = Symbol('a') + assert abs(nsolve(1/(0.001 + a)**3 - 6/(0.9 - a)**3, a, 0.3) - + mpf('0.31883011387318591')) < 1e-15 + + +def test_issue_6408(): + x = Symbol('x') + assert nsolve(Piecewise((x, x < 1), (x**2, True)), x, 2) == 0 + + +def test_issue_6408_integral(): + x, y = symbols('x y') + assert nsolve(Integral(x*y, (x, 0, 5)), y, 2) == 0 + + +@conserve_mpmath_dps +def test_increased_dps(): + # Issue 8564 + import mpmath + mpmath.mp.dps = 128 + x = Symbol('x') + e1 = x**2 - pi + q = nsolve(e1, x, 3.0) + + assert abs(sqrt(pi).evalf(128) - q) < 1e-128 + +def test_nsolve_precision(): + x, y = symbols('x y') + sol = nsolve(x**2 - pi, x, 3, prec=128) + assert abs(sqrt(pi).evalf(128) - sol) < 1e-128 + assert isinstance(sol, Float) + + sols = nsolve((y**2 - x, x**2 - pi), (x, y), (3, 3), prec=128) + assert isinstance(sols, Matrix) + assert sols.shape == (2, 1) + assert abs(sqrt(pi).evalf(128) - sols[0]) < 1e-128 + assert abs(sqrt(sqrt(pi)).evalf(128) - sols[1]) < 1e-128 + assert all(isinstance(i, Float) for i in sols) + +def test_nsolve_complex(): + x, y = symbols('x y') + + assert nsolve(x**2 + 2, 1j) == sqrt(2.)*I + assert nsolve(x**2 + 2, I) == sqrt(2.)*I + + assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I]) + assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I]) + +def test_nsolve_dict_kwarg(): + x, y = symbols('x y') + # one variable + assert nsolve(x**2 - 2, 1, dict = True) == \ + [{x: sqrt(2.)}] + # one variable with complex solution + assert nsolve(x**2 + 2, I, dict = True) == \ + [{x: sqrt(2.)*I}] + # two variables + assert nsolve([x**2 + y**2 - 5, x**2 - y**2 + 1], [x, y], [1, 1], dict = True) == \ + [{x: sqrt(2.), y: sqrt(3.)}] + +def test_nsolve_rational(): + x = symbols('x') + assert nsolve(x - Rational(1, 3), 0, prec=100) == Rational(1, 3).evalf(100) + + +def test_issue_14950(): + x = Matrix(symbols('t s')) + x0 = Matrix([17, 23]) + eqn = x + x0 + assert nsolve(eqn, x, x0) == nfloat(-x0) + assert nsolve(eqn.T, x.T, x0.T) == nfloat(-x0) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_pde.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_pde.py new file mode 100644 index 0000000000000000000000000000000000000000..948d90c7be21a9e0e03753e723ef04f1fb08a5d6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_pde.py @@ -0,0 +1,239 @@ +from sympy.core.function import (Derivative as D, Function) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.core import S +from sympy.solvers.pde import (pde_separate, pde_separate_add, pde_separate_mul, + pdsolve, classify_pde, checkpdesol) +from sympy.testing.pytest import raises + + +a, b, c, x, y = symbols('a b c x y') + +def test_pde_separate_add(): + x, y, z, t = symbols("x,y,z,t") + F, T, X, Y, Z, u = map(Function, 'FTXYZu') + + eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t))) + res = pde_separate_add(eq, u(x, t), [X(x), T(t)]) + assert res == [D(X(x), x)*exp(-X(x)), D(T(t), t)*exp(T(t))] + + +def test_pde_separate(): + x, y, z, t = symbols("x,y,z,t") + F, T, X, Y, Z, u = map(Function, 'FTXYZu') + + eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t))) + raises(ValueError, lambda: pde_separate(eq, u(x, t), [X(x), T(t)], 'div')) + + +def test_pde_separate_mul(): + x, y, z, t = symbols("x,y,z,t") + c = Symbol("C", real=True) + Phi = Function('Phi') + F, R, T, X, Y, Z, u = map(Function, 'FRTXYZu') + r, theta, z = symbols('r,theta,z') + + # Something simple :) + eq = Eq(D(F(x, y, z), x) + D(F(x, y, z), y) + D(F(x, y, z), z), 0) + + # Duplicate arguments in functions + raises( + ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), u(z, z)])) + # Wrong number of arguments + raises(ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), Y(y)])) + # Wrong variables: [x, y] -> [x, z] + raises( + ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(t), Y(x, y)])) + + assert pde_separate_mul(eq, F(x, y, z), [Y(y), u(x, z)]) == \ + [D(Y(y), y)/Y(y), -D(u(x, z), x)/u(x, z) - D(u(x, z), z)/u(x, z)] + assert pde_separate_mul(eq, F(x, y, z), [X(x), Y(y), Z(z)]) == \ + [D(X(x), x)/X(x), -D(Z(z), z)/Z(z) - D(Y(y), y)/Y(y)] + + # wave equation + wave = Eq(D(u(x, t), t, t), c**2*D(u(x, t), x, x)) + res = pde_separate_mul(wave, u(x, t), [X(x), T(t)]) + assert res == [D(X(x), x, x)/X(x), D(T(t), t, t)/(c**2*T(t))] + + # Laplace equation in cylindrical coords + eq = Eq(1/r * D(Phi(r, theta, z), r) + D(Phi(r, theta, z), r, 2) + + 1/r**2 * D(Phi(r, theta, z), theta, 2) + D(Phi(r, theta, z), z, 2), 0) + # Separate z + res = pde_separate_mul(eq, Phi(r, theta, z), [Z(z), u(theta, r)]) + assert res == [D(Z(z), z, z)/Z(z), + -D(u(theta, r), r, r)/u(theta, r) - + D(u(theta, r), r)/(r*u(theta, r)) - + D(u(theta, r), theta, theta)/(r**2*u(theta, r))] + # Lets use the result to create a new equation... + eq = Eq(res[1], c) + # ...and separate theta... + res = pde_separate_mul(eq, u(theta, r), [T(theta), R(r)]) + assert res == [D(T(theta), theta, theta)/T(theta), + -r*D(R(r), r)/R(r) - r**2*D(R(r), r, r)/R(r) - c*r**2] + # ...or r... + res = pde_separate_mul(eq, u(theta, r), [R(r), T(theta)]) + assert res == [r*D(R(r), r)/R(r) + r**2*D(R(r), r, r)/R(r) + c*r**2, + -D(T(theta), theta, theta)/T(theta)] + + +def test_issue_11726(): + x, t = symbols("x t") + f = symbols("f", cls=Function) + X, T = symbols("X T", cls=Function) + + u = f(x, t) + eq = u.diff(x, 2) - u.diff(t, 2) + res = pde_separate(eq, u, [T(x), X(t)]) + assert res == [D(T(x), x, x)/T(x),D(X(t), t, t)/X(t)] + + +def test_pde_classify(): + # When more number of hints are added, add tests for classifying here. + f = Function('f') + eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y) + eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y) + eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y) + eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y) + eq5 = x**2*f(x,y) + x*f(x,y).diff(x) + x*y*f(x,y).diff(y) + eq6 = y*x**2*f(x,y) + y*f(x,y).diff(x) + f(x,y).diff(y) + for eq in [eq1, eq2, eq3]: + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + for eq in [eq4, eq5, eq6]: + assert classify_pde(eq) == ('1st_linear_variable_coeff',) + + +def test_checkpdesol(): + f, F = map(Function, ['f', 'F']) + eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y) + eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y) + eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y) + for eq in [eq1, eq2, eq3]: + assert checkpdesol(eq, pdsolve(eq))[0] + eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y) + eq5 = 2*f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y) + eq6 = f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y) + assert checkpdesol(eq4, [pdsolve(eq5), pdsolve(eq6)]) == [ + (False, (x - 2)*F(3*x - y)*exp(-x/S(5) - 3*y/S(5))), + (False, (x - 1)*F(3*x - y)*exp(-x/S(10) - 3*y/S(10)))] + for eq in [eq4, eq5, eq6]: + assert checkpdesol(eq, pdsolve(eq))[0] + sol = pdsolve(eq4) + sol4 = Eq(sol.lhs - sol.rhs, 0) + raises(NotImplementedError, lambda: + checkpdesol(eq4, sol4, solve_for_func=False)) + + +def test_solvefun(): + f, F, G, H = map(Function, ['f', 'F', 'G', 'H']) + eq1 = f(x,y) + f(x,y).diff(x) + f(x,y).diff(y) + assert pdsolve(eq1) == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2)) + assert pdsolve(eq1, solvefun=G) == Eq(f(x, y), G(x - y)*exp(-x/2 - y/2)) + assert pdsolve(eq1, solvefun=H) == Eq(f(x, y), H(x - y)*exp(-x/2 - y/2)) + + +def test_pde_1st_linear_constant_coeff_homogeneous(): + f, F = map(Function, ['f', 'F']) + u = f(x, y) + eq = 2*u + u.diff(x) + u.diff(y) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(x - y)*exp(-x - y)) + assert checkpdesol(eq, sol)[0] + + eq = 4 + (3*u.diff(x)/u) + (2*u.diff(y)/u) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(2*x - 3*y)*exp(-S(12)*x/13 - S(8)*y/13)) + assert checkpdesol(eq, sol)[0] + + eq = u + (6*u.diff(x)) + (7*u.diff(y)) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(7*x - 6*y)*exp(-6*x/S(85) - 7*y/S(85))) + assert checkpdesol(eq, sol)[0] + + eq = a*u + b*u.diff(x) + c*u.diff(y) + sol = pdsolve(eq) + assert checkpdesol(eq, sol)[0] + + +def test_pde_1st_linear_constant_coeff(): + f, F = map(Function, ['f', 'F']) + u = f(x,y) + eq = -2*u.diff(x) + 4*u.diff(y) + 5*u - exp(x + 3*y) + sol = pdsolve(eq) + assert sol == Eq(f(x,y), + (F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y)) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = (u.diff(x)/u) + (u.diff(y)/u) + 1 - (exp(x + y)/u) + sol = pdsolve(eq) + assert sol == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2) + exp(x + y)/3) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = 2*u + -u.diff(x) + 3*u.diff(y) + sin(x) + sol = pdsolve(eq) + assert sol == Eq(f(x, y), + F(3*x + y)*exp(x/5 - 3*y/5) - 2*sin(x)/5 - cos(x)/5) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = u + u.diff(x) + u.diff(y) + x*y + sol = pdsolve(eq) + assert sol.expand() == Eq(f(x, y), + x + y + (x - y)**2/4 - (x + y)**2/4 + F(x - y)*exp(-x/2 - y/2) - 2).expand() + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + eq = u + u.diff(x) + u.diff(y) + log(x) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + + +def test_pdsolve_all(): + f, F = map(Function, ['f', 'F']) + u = f(x,y) + eq = u + u.diff(x) + u.diff(y) + x**2*y + sol = pdsolve(eq, hint = 'all') + keys = ['1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral', 'default', 'order'] + assert sorted(sol.keys()) == keys + assert sol['order'] == 1 + assert sol['default'] == '1st_linear_constant_coeff' + assert sol['1st_linear_constant_coeff'].expand() == Eq(f(x, y), + -x**2*y + x**2 + 2*x*y - 4*x - 2*y + F(x - y)*exp(-x/2 - y/2) + 6).expand() + + +def test_pdsolve_variable_coeff(): + f, F = map(Function, ['f', 'F']) + u = f(x, y) + eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2 + sol = pdsolve(eq, hint="1st_linear_variable_coeff") + assert sol == Eq(u, F(x*y)*exp(y**2/2) + 1) + assert checkpdesol(eq, sol)[0] + + eq = x**2*u + x*u.diff(x) + x*y*u.diff(y) + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(y*exp(-x))*exp(-x**2/2)) + assert checkpdesol(eq, sol)[0] + + eq = y*x**2*u + y*u.diff(x) + u.diff(y) + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(-2*x + y**2)*exp(-x**3/3)) + assert checkpdesol(eq, sol)[0] + + eq = exp(x)**2*(u.diff(x)) + y + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, y*exp(-2*x)/2 + F(y)) + assert checkpdesol(eq, sol)[0] + + eq = exp(2*x)*(u.diff(y)) + y*u - u + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(x)*exp(-y*(y - 2)*exp(-2*x)/2)) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_polysys.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_polysys.py new file mode 100644 index 0000000000000000000000000000000000000000..a119591a0354ba377a18767eae6e8b7812810a0d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_polysys.py @@ -0,0 +1,462 @@ +"""Tests for solvers of systems of polynomial equations. """ +from sympy.polys.domains import ZZ, QQ_I +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.polyerrors import UnsolvableFactorError +from sympy.polys.polyoptions import Options +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.solvers.solvers import solve +from sympy.utilities.iterables import flatten +from sympy.abc import a, b, c, x, y, z +from sympy.polys import PolynomialError +from sympy.solvers.polysys import (solve_poly_system, + solve_triangulated, + solve_biquadratic, SolveFailed, + solve_generic, factor_system_bool, + factor_system_cond, factor_system_poly, + factor_system, _factor_sets, _factor_sets_slow) +from sympy.polys.polytools import parallel_poly_from_expr +from sympy.testing.pytest import raises +from sympy.core.relational import Eq +from sympy.functions.elementary.trigonometric import sin, cos + +from sympy.functions.elementary.exponential import exp + + +def test_solve_poly_system(): + assert solve_poly_system([x - 1], x) == [(S.One,)] + + assert solve_poly_system([y - x, y - x - 1], x, y) is None + + assert solve_poly_system([y - x**2, y + x**2], x, y) == [(S.Zero, S.Zero)] + + assert solve_poly_system([2*x - 3, y*Rational(3, 2) - 2*x, z - 5*y], x, y, z) == \ + [(Rational(3, 2), Integer(2), Integer(10))] + + assert solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y) == \ + [(0, 0), (2, -sqrt(2)), (2, sqrt(2))] + + assert solve_poly_system([y - x**2, y + x**2 + 1], x, y) == \ + [(-I*sqrt(S.Half), Rational(-1, 2)), (I*sqrt(S.Half), Rational(-1, 2))] + + f_1 = x**2 + y + z - 1 + f_2 = x + y**2 + z - 1 + f_3 = x + y + z**2 - 1 + + a, b = sqrt(2) - 1, -sqrt(2) - 1 + + assert solve_poly_system([f_1, f_2, f_3], x, y, z) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)] + + solution = [(1, -1), (1, 1)] + + assert solve_poly_system([Poly(x**2 - y**2), Poly(x - 1)]) == solution + assert solve_poly_system([x**2 - y**2, x - 1], x, y) == solution + assert solve_poly_system([x**2 - y**2, x - 1]) == solution + + assert solve_poly_system( + [x + x*y - 3, y + x*y - 4], x, y) == [(-3, -2), (1, 2)] + + raises(NotImplementedError, lambda: solve_poly_system([x**3 - y**3], x, y)) + raises(NotImplementedError, lambda: solve_poly_system( + [z, -2*x*y**2 + x + y**2*z, y**2*(-z - 4) + 2])) + raises(PolynomialError, lambda: solve_poly_system([1/x], x)) + + raises(NotImplementedError, lambda: solve_poly_system( + [x-1,], (x, y))) + raises(NotImplementedError, lambda: solve_poly_system( + [y-1,], (x, y))) + + # solve_poly_system should ideally construct solutions using + # CRootOf for the following four tests + assert solve_poly_system([x**5 - x + 1], [x], strict=False) == [] + raises(UnsolvableFactorError, lambda: solve_poly_system( + [x**5 - x + 1], [x], strict=True)) + + assert solve_poly_system([(x - 1)*(x**5 - x + 1), y**2 - 1], [x, y], + strict=False) == [(1, -1), (1, 1)] + raises(UnsolvableFactorError, + lambda: solve_poly_system([(x - 1)*(x**5 - x + 1), y**2-1], + [x, y], strict=True)) + + +def test_solve_generic(): + NewOption = Options((x, y), {'domain': 'ZZ'}) + assert solve_generic([x**2 - 2*y**2, y**2 - y + 1], NewOption) == \ + [(-sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2), + (sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2), + (-sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2), + (sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2)] + + # solve_generic should ideally construct solutions using + # CRootOf for the following two tests + assert solve_generic( + [2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=False) == \ + [(Rational(1, 2), 1)] + raises(UnsolvableFactorError, lambda: solve_generic( + [2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=True)) + + +def test_solve_biquadratic(): + x0, y0, x1, y1, r = symbols('x0 y0 x1 y1 r') + + f_1 = (x - 1)**2 + (y - 1)**2 - r**2 + f_2 = (x - 2)**2 + (y - 2)**2 - r**2 + s = sqrt(2*r**2 - 1) + a = (3 - s)/2 + b = (3 + s)/2 + assert solve_poly_system([f_1, f_2], x, y) == [(a, b), (b, a)] + + f_1 = (x - 1)**2 + (y - 2)**2 - r**2 + f_2 = (x - 1)**2 + (y - 1)**2 - r**2 + + assert solve_poly_system([f_1, f_2], x, y) == \ + [(1 - sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2)), + (1 + sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2))] + + query = lambda expr: expr.is_Pow and expr.exp is S.Half + + f_1 = (x - 1 )**2 + (y - 2)**2 - r**2 + f_2 = (x - x1)**2 + (y - 1)**2 - r**2 + + result = solve_poly_system([f_1, f_2], x, y) + + assert len(result) == 2 and all(len(r) == 2 for r in result) + assert all(r.count(query) == 1 for r in flatten(result)) + + f_1 = (x - x0)**2 + (y - y0)**2 - r**2 + f_2 = (x - x1)**2 + (y - y1)**2 - r**2 + + result = solve_poly_system([f_1, f_2], x, y) + + assert len(result) == 2 and all(len(r) == 2 for r in result) + assert all(len(r.find(query)) == 1 for r in flatten(result)) + + s1 = (x*y - y, x**2 - x) + assert solve(s1) == [{x: 1}, {x: 0, y: 0}] + s2 = (x*y - x, y**2 - y) + assert solve(s2) == [{y: 1}, {x: 0, y: 0}] + gens = (x, y) + for seq in (s1, s2): + (f, g), opt = parallel_poly_from_expr(seq, *gens) + raises(SolveFailed, lambda: solve_biquadratic(f, g, opt)) + seq = (x**2 + y**2 - 2, y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == [ + (-1, -1), (-1, 1), (1, -1), (1, 1)] + ans = [(0, -1), (0, 1)] + seq = (x**2 + y**2 - 1, y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == ans + seq = (x**2 + y**2 - 1, x**2 - x + y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == ans + + +def test_solve_triangulated(): + f_1 = x**2 + y + z - 1 + f_2 = x + y**2 + z - 1 + f_3 = x + y + z**2 - 1 + + a, b = sqrt(2) - 1, -sqrt(2) - 1 + + assert solve_triangulated([f_1, f_2, f_3], x, y, z) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0)] + + dom = QQ.algebraic_field(sqrt(2)) + + assert solve_triangulated([f_1, f_2, f_3], x, y, z, domain=dom) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)] + + a, b = CRootOf(z**2 + 2*z - 1, 0), CRootOf(z**2 + 2*z - 1, 1) + assert solve_triangulated([f_1, f_2, f_3], x, y, z, extension=True) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)] + + +def test_solve_issue_3686(): + roots = solve_poly_system([((x - 5)**2/250000 + (y - Rational(5, 10))**2/250000) - 1, x], x, y) + assert roots == [(0, S.Half - 15*sqrt(1111)), (0, S.Half + 15*sqrt(1111))] + + roots = solve_poly_system([((x - 5)**2/250000 + (y - 5.0/10)**2/250000) - 1, x], x, y) + # TODO: does this really have to be so complicated?! + assert len(roots) == 2 + assert roots[0][0] == 0 + assert roots[0][1].epsilon_eq(-499.474999374969, 1e12) + assert roots[1][0] == 0 + assert roots[1][1].epsilon_eq(500.474999374969, 1e12) + + +def test_factor_system(): + + assert factor_system([x**2 + 2*x + 1]) == [[x + 1]] + assert factor_system([x**2 + 2*x + 1, y**2 + 2*y + 1]) == [[x + 1, y + 1]] + assert factor_system([x**2 + 1]) == [[x**2 + 1]] + assert factor_system([]) == [[]] + + assert factor_system([x**2 + y**2 + 2*x*y, x**2 - 2], extension=sqrt(2)) == [ + [x + y, x + sqrt(2)], + [x + y, x - sqrt(2)], + ] + + assert factor_system([x**2 + 1, y**2 + 1], gaussian=True) == [ + [x + I, y + I], + [x + I, y - I], + [x - I, y + I], + [x - I, y - I], + ] + + assert factor_system([x**2 + 1, y**2 + 1], domain=QQ_I) == [ + [x + I, y + I], + [x + I, y - I], + [x - I, y + I], + [x - I, y - I], + ] + + assert factor_system([0]) == [[]] + assert factor_system([1]) == [] + assert factor_system([0 , x]) == [[x]] + assert factor_system([1, 0, x]) == [] + + assert factor_system([x**4 - 1, y**6 - 1]) == [ + [x**2 + 1, y**2 + y + 1], + [x**2 + 1, y**2 - y + 1], + [x**2 + 1, y + 1], + [x**2 + 1, y - 1], + [x + 1, y**2 + y + 1], + [x + 1, y**2 - y + 1], + [x - 1, y**2 + y + 1], + [x - 1, y**2 - y + 1], + [x + 1, y + 1], + [x + 1, y - 1], + [x - 1, y + 1], + [x - 1, y - 1], + ] + + assert factor_system([(x - 1)*(y - 2), (y - 2)*(z - 3)]) == [ + [x - 1, z - 3], + [y - 2] + ] + + assert factor_system([sin(x)**2 + cos(x)**2 - 1, x]) == [ + [x, sin(x)**2 + cos(x)**2 - 1], + ] + + assert factor_system([sin(x)**2 + cos(x)**2 - 1]) == [ + [sin(x)**2 + cos(x)**2 - 1] + ] + + assert factor_system([sin(x)**2 + cos(x)**2]) == [ + [sin(x)**2 + cos(x)**2] + ] + + assert factor_system([a*x, y, a]) == [[y, a]] + + assert factor_system([a*x, y, a], [x, y]) == [] + + assert factor_system([a ** 2 * x, y], [x, y]) == [[x, y]] + + assert factor_system([a*x*(x - 1), b*y, c], [x, y]) == [] + + assert factor_system([a*x*(x - 1), b*y, c], [x, y, c]) == [ + [x - 1, y, c], + [x, y, c], + ] + + assert factor_system([a*x*(x - 1), b*y, c]) == [ + [x - 1, y, c], + [x, y, c], + [x - 1, b, c], + [x, b, c], + [y, a, c], + [a, b, c], + ] + + assert factor_system([x**2 - 2], [y]) == [] + + assert factor_system([x**2 - 2], [x]) == [[x**2 - 2]] + + assert factor_system([cos(x)**2 - sin(x)**2, cos(x)**2 + sin(x)**2 - 1]) == [ + [sin(x)**2 + cos(x)**2 - 1, sin(x) + cos(x)], + [sin(x)**2 + cos(x)**2 - 1, -sin(x) + cos(x)], + ] + + assert factor_system([(cos(x) + sin(x))**2 - 1, cos(x)**2 - sin(x)**2 - cos(2*x)]) == [ + [sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) + 1], + [sin(x)**2 - cos(x)**2 + cos(2*x), sin(x) + cos(x) - 1], + ] + + assert factor_system([(cos(x) + sin(x))*exp(y) - 1, (cos(x) - sin(x))*exp(y) - 1]) == [ + [exp(y)*sin(x) + exp(y)*cos(x) - 1, -exp(y)*sin(x) + exp(y)*cos(x) - 1] + ] + + +def test_factor_system_poly(): + + px = lambda e: Poly(e, x) + pxab = lambda e: Poly(e, x, domain=ZZ[a, b]) + pxI = lambda e: Poly(e, x, domain=QQ_I) + pxyz = lambda e: Poly(e, (x, y, z)) + + assert factor_system_poly([px(x**2 - 1), px(x**2 - 4)]) == [ + [px(x + 2), px(x + 1)], + [px(x + 2), px(x - 1)], + [px(x + 1), px(x - 2)], + [px(x - 1), px(x - 2)], + ] + + assert factor_system_poly([px(x**2 - 1)]) == [[px(x + 1)], [px(x - 1)]] + + assert factor_system_poly([pxyz(x**2*y - y), pxyz(x**2*z - z)]) == [ + [pxyz(x + 1)], + [pxyz(x - 1)], + [pxyz(y), pxyz(z)], + ] + + assert factor_system_poly([px(x**2*(x - 1)**2), px(x*(x - 1))]) == [ + [px(x)], + [px(x - 1)], + ] + + assert factor_system_poly([pxyz(x**2 + y*x), pxyz(x**2 + z*x)]) == [ + [pxyz(x + y), pxyz(x + z)], + [pxyz(x)], + ] + + assert factor_system_poly([pxab((a - 1)*(x - 2)), pxab((b - 3)*(x - 2))]) == [ + [pxab(x - 2)], + [pxab(a - 1), pxab(b - 3)], + ] + + assert factor_system_poly([pxI(x**2 + 1)]) == [[pxI(x + I)], [pxI(x - I)]] + + assert factor_system_poly([]) == [[]] + + assert factor_system_poly([px(1)]) == [] + assert factor_system_poly([px(0), px(x)]) == [[px(x)]] + + +def test_factor_system_cond(): + + assert factor_system_cond([x ** 2 - 1, x ** 2 - 4]) == [ + [x + 2, x + 1], + [x + 2, x - 1], + [x + 1, x - 2], + [x - 1, x - 2], + ] + + assert factor_system_cond([1]) == [] + assert factor_system_cond([0]) == [[]] + assert factor_system_cond([1, x]) == [] + assert factor_system_cond([0, x]) == [[x]] + assert factor_system_cond([]) == [[]] + + assert factor_system_cond([x**2 + y*x]) == [[x + y], [x]] + + assert factor_system_cond([(a - 1)*(x - 2), (b - 3)*(x - 2)], [x]) == [ + [x - 2], + [a - 1, b - 3], + ] + + assert factor_system_cond([a * (x - 1), b], [x]) == [[x - 1, b], [a, b]] + + assert factor_system_cond([a*x*(x-1), b*y, c], [x, y]) == [ + [x - 1, y, c], + [x, y, c], + [x - 1, b, c], + [x, b, c], + [y, a, c], + [a, b, c], + ] + + assert factor_system_cond([x*(x-1), y], [x, y]) == [[x - 1, y], [x, y]] + + assert factor_system_cond([a*x, y, a], [x, y]) == [[y, a]] + + assert factor_system_cond([a*x, b*x], [x, y]) == [[x], [a, b]] + + assert factor_system_cond([a*b*x, y], [x, y]) == [[x, y], [y, a*b]] + + assert factor_system_cond([a*b*x, y]) == [[x, y], [y, a], [y, b]] + + assert factor_system_cond([a**2*x, y], [x, y]) == [[x, y], [y, a]] + +def test_factor_system_bool(): + + eqs = [a*(x - 1)*(y - 1), b*(x - 2)*(y - 1)*(y - 2)] + assert factor_system_bool(eqs, [x, y]) == ( + Eq(y - 1, 0) + | (Eq(a, 0) & Eq(b, 0)) + | (Eq(a, 0) & Eq(x - 2, 0)) + | (Eq(a, 0) & Eq(y - 2, 0)) + | (Eq(b, 0) & Eq(x - 1, 0)) + | (Eq(x - 2, 0) & Eq(x - 1, 0)) + | (Eq(x - 1, 0) & Eq(y - 2, 0)) + ) + + assert factor_system_bool([x - 1], [x]) == Eq(x - 1, 0) + + assert factor_system_bool([(x - 1)*(x - 2)], [x]) == Eq(x - 2, 0) | Eq(x - 1, 0) + + assert factor_system_bool([], [x]) == True + assert factor_system_bool([0], [x]) == True + assert factor_system_bool([1], [x]) == False + assert factor_system_bool([a], [x]) == Eq(a, 0) + + assert factor_system_bool([a * x, y, a], [x, y]) == Eq(a, 0) & Eq(y, 0) + + assert (factor_system_bool([a*x, b*y*x, a], [x, y]) == ( + Eq(a, 0) & Eq(b, 0)) + | (Eq(a, 0) & Eq(x, 0)) + | (Eq(a, 0) & Eq(y, 0))) + + assert (factor_system_bool([a*x, b*x], [x, y]) == Eq(x, 0) | + (Eq(a, 0) & Eq(b, 0))) + + assert (factor_system_bool([a*b*x, y], [x, y]) == ( + Eq(x, 0) & Eq(y, 0)) | + (Eq(y, 0) & Eq(a*b, 0))) + + assert (factor_system_bool([a**2*x, y], [x, y]) == ( + Eq(a, 0) & Eq(y, 0)) | + (Eq(x, 0) & Eq(y, 0))) + + assert factor_system_bool([a*x*y, b*y*z], [x, y, z]) == ( + Eq(y, 0) + | (Eq(a, 0) & Eq(b, 0)) + | (Eq(a, 0) & Eq(z, 0)) + | (Eq(b, 0) & Eq(x, 0)) + | (Eq(x, 0) & Eq(z, 0)) + ) + + assert factor_system_bool([a*(x - 1), b], [x]) == ( + (Eq(a, 0) & Eq(b, 0)) + | (Eq(x - 1, 0) & Eq(b, 0)) + ) + + +def test_factor_sets(): + # + from random import randint + + def generate_random_system(n_eqs=3, n_factors=2, max_val=10): + return [ + [randint(0, max_val) for _ in range(randint(1, n_factors))] + for _ in range(n_eqs) + ] + + test_cases = [ + [[1, 2], [1, 3]], + [[1, 2], [3, 4]], + [[1], [1, 2], [2]], + ] + + for case in test_cases: + assert _factor_sets(case) == _factor_sets_slow(case) + + for _ in range(100): + system = generate_random_system() + assert _factor_sets(system) == _factor_sets_slow(system) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_recurr.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_recurr.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6306b51a5cf33ccd9fae131430a24690d540a7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_recurr.py @@ -0,0 +1,295 @@ +from sympy.core.function import (Function, Lambda, expand) +from sympy.core.numbers import (I, Rational) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import (rf, binomial, factorial) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.polys.polytools import factor +from sympy.solvers.recurr import rsolve, rsolve_hyper, rsolve_poly, rsolve_ratio +from sympy.testing.pytest import raises, slow, XFAIL +from sympy.abc import a, b + +y = Function('y') +n, k = symbols('n,k', integer=True) +C0, C1, C2 = symbols('C0,C1,C2') + + +def test_rsolve_poly(): + assert rsolve_poly([-1, -1, 1], 0, n) == 0 + assert rsolve_poly([-1, -1, 1], 1, n) == -1 + + assert rsolve_poly([-1, n + 1], n, n) == 1 + assert rsolve_poly([-1, 1], n, n) == C0 + (n**2 - n)/2 + assert rsolve_poly([-n - 1, n], 1, n) == C0*n - 1 + assert rsolve_poly([-4*n - 2, 1], 4*n + 1, n) == -1 + + assert rsolve_poly([-1, 1], n**5 + n**3, n) == \ + C0 - n**3 / 2 - n**5 / 2 + n**2 / 6 + n**6 / 6 + 2*n**4 / 3 + + +def test_rsolve_ratio(): + solution = rsolve_ratio([-2*n**3 + n**2 + 2*n - 1, 2*n**3 + n**2 - 6*n, + -2*n**3 - 11*n**2 - 18*n - 9, 2*n**3 + 13*n**2 + 22*n + 8], 0, n) + assert solution == C0*(2*n - 3)/(n**2 - 1)/2 + + +def test_rsolve_hyper(): + assert rsolve_hyper([-1, -1, 1], 0, n) in [ + C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n, + C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n, + ] + + assert rsolve_hyper([n**2 - 2, -2*n - 1, 1], 0, n) in [ + C0*rf(sqrt(2), n) + C1*rf(-sqrt(2), n), + C1*rf(sqrt(2), n) + C0*rf(-sqrt(2), n), + ] + + assert rsolve_hyper([n**2 - k, -2*n - 1, 1], 0, n) in [ + C0*rf(sqrt(k), n) + C1*rf(-sqrt(k), n), + C1*rf(sqrt(k), n) + C0*rf(-sqrt(k), n), + ] + + assert rsolve_hyper( + [2*n*(n + 1), -n**2 - 3*n + 2, n - 1], 0, n) == C1*factorial(n) + C0*2**n + + assert rsolve_hyper( + [n + 2, -(2*n + 3)*(17*n**2 + 51*n + 39), n + 1], 0, n) == 0 + + assert rsolve_hyper([-n - 1, -1, 1], 0, n) == 0 + + assert rsolve_hyper([-1, 1], n, n).expand() == C0 + n**2/2 - n/2 + + assert rsolve_hyper([-1, 1], 1 + n, n).expand() == C0 + n**2/2 + n/2 + + assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n + + assert rsolve_hyper([-a, 1],0,n).expand() == C0*a**n + + assert rsolve_hyper([-a, 0, 1], 0, n).expand() == (-1)**n*C1*a**(n/2) + C0*a**(n/2) + + assert rsolve_hyper([1, 1, 1], 0, n).expand() == \ + C0*(Rational(-1, 2) - sqrt(3)*I/2)**n + C1*(Rational(-1, 2) + sqrt(3)*I/2)**n + + assert rsolve_hyper([1, -2*n/a - 2/a, 1], 0, n) == 0 + + +@XFAIL +def test_rsolve_ratio_missed(): + # this arises during computation + # assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n + assert rsolve_ratio([-n, n + 2], n, n) is not None + + +def recurrence_term(c, f): + """Compute RHS of recurrence in f(n) with coefficients in c.""" + return sum(c[i]*f.subs(n, n + i) for i in range(len(c))) + + +def test_rsolve_bulk(): + """Some bulk-generated tests.""" + funcs = [ n, n + 1, n**2, n**3, n**4, n + n**2, 27*n + 52*n**2 - 3* + n**3 + 12*n**4 - 52*n**5 ] + coeffs = [ [-2, 1], [-2, -1, 1], [-1, 1, 1, -1, 1], [-n, 1], [n**2 - + n + 12, 1] ] + for p in funcs: + # compute difference + for c in coeffs: + q = recurrence_term(c, p) + if p.is_polynomial(n): + assert rsolve_poly(c, q, n) == p + # See issue 3956: + if p.is_hypergeometric(n) and len(c) <= 3: + assert rsolve_hyper(c, q, n).subs(zip(symbols('C:3'), [0, 0, 0])).expand() == p + + +def test_rsolve_0_sol_homogeneous(): + # fixed by cherry-pick from + # https://github.com/diofant/diofant/commit/e1d2e52125199eb3df59f12e8944f8a5f24b00a5 + assert rsolve_hyper([n**2 - n + 12, 1], n*(n**2 - n + 12) + n + 1, n) == n + + +def test_rsolve(): + f = y(n + 2) - y(n + 1) - y(n) + h = sqrt(5)*(S.Half + S.Half*sqrt(5))**n \ + - sqrt(5)*(S.Half - S.Half*sqrt(5))**n + + assert rsolve(f, y(n)) in [ + C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n, + C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n, + ] + + assert rsolve(f, y(n), [0, 5]) == h + assert rsolve(f, y(n), {0: 0, 1: 5}) == h + assert rsolve(f, y(n), {y(0): 0, y(1): 5}) == h + assert rsolve(y(n) - y(n - 1) - y(n - 2), y(n), [0, 5]) == h + assert rsolve(Eq(y(n), y(n - 1) + y(n - 2)), y(n), [0, 5]) == h + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n) + g = C1*factorial(n) + C0*2**n + h = -3*factorial(n) + 3*2**n + + assert rsolve(f, y(n)) == g + assert rsolve(f, y(n), []) == g + assert rsolve(f, y(n), {}) == g + + assert rsolve(f, y(n), [0, 3]) == h + assert rsolve(f, y(n), {0: 0, 1: 3}) == h + assert rsolve(f, y(n), {y(0): 0, y(1): 3}) == h + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - y(n - 1) - 2 + + assert rsolve(f, y(n), {y(0): 0}) == 2*n + assert rsolve(f, y(n), {y(0): 1}) == 2*n + 1 + assert rsolve(f, y(n), {y(0): 0, y(1): 1}) is None + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = 3*y(n - 1) - y(n) - 1 + + assert rsolve(f, y(n), {y(0): 0}) == -3**n/2 + S.Half + assert rsolve(f, y(n), {y(0): 1}) == 3**n/2 + S.Half + assert rsolve(f, y(n), {y(0): 2}) == 3*3**n/2 + S.Half + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - 1/n*y(n - 1) + assert rsolve(f, y(n)) == C0/factorial(n) + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - 1/n*y(n - 1) - 1 + assert rsolve(f, y(n)) is None + + f = 2*y(n - 1) + (1 - n)*y(n)/n + + assert rsolve(f, y(n), {y(1): 1}) == 2**(n - 1)*n + assert rsolve(f, y(n), {y(1): 2}) == 2**(n - 1)*n*2 + assert rsolve(f, y(n), {y(1): 3}) == 2**(n - 1)*n*3 + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = (n - 1)*(n - 2)*y(n + 2) - (n + 1)*(n + 2)*y(n) + + assert rsolve(f, y(n), {y(3): 6, y(4): 24}) == n*(n - 1)*(n - 2) + assert rsolve( + f, y(n), {y(3): 6, y(4): -24}) == -n*(n - 1)*(n - 2)*(-1)**(n) + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + assert rsolve(Eq(y(n + 1), a*y(n)), y(n), {y(1): a}).simplify() == a**n + + assert rsolve(y(n) - a*y(n-2),y(n), \ + {y(1): sqrt(a)*(a + b), y(2): a*(a - b)}).simplify() == \ + a**(n/2 + 1) - b*(-sqrt(a))**n + + f = (-16*n**2 + 32*n - 12)*y(n - 1) + (4*n**2 - 12*n + 9)*y(n) + + yn = rsolve(f, y(n), {y(1): binomial(2*n + 1, 3)}) + sol = 2**(2*n)*n*(2*n - 1)**2*(2*n + 1)/12 + assert factor(expand(yn, func=True)) == sol + + sol = rsolve(y(n) + a*(y(n + 1) + y(n - 1))/2, y(n)) + assert str(sol) == 'C0*((-sqrt(1 - a**2) - 1)/a)**n + C1*((sqrt(1 - a**2) - 1)/a)**n' + + assert rsolve((k + 1)*y(k), y(k)) is None + assert (rsolve((k + 1)*y(k) + (k + 3)*y(k + 1) + (k + 5)*y(k + 2), y(k)) + is None) + + assert rsolve(y(n) + y(n + 1) + 2**n + 3**n, y(n)) == (-1)**n*C0 - 2**n/3 - 3**n/4 + + +def test_rsolve_raises(): + x = Function('x') + raises(ValueError, lambda: rsolve(y(n) - y(k + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - y(n + 1), x(n))) + raises(ValueError, lambda: rsolve(y(n) - x(n + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - sqrt(n)*y(n + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - y(n + 1), y(n), {x(0): 0})) + raises(ValueError, lambda: rsolve(y(n) + y(n + 1) + 2**n + cos(n), y(n))) + + +def test_issue_6844(): + f = y(n + 2) - y(n + 1) + y(n)/4 + assert rsolve(f, y(n)) == 2**(-n + 1)*C1*n + 2**(-n)*C0 + assert rsolve(f, y(n), {y(0): 0, y(1): 1}) == 2**(1 - n)*n + + +def test_issue_18751(): + r = Symbol('r', positive=True) + theta = Symbol('theta', real=True) + f = y(n) - 2 * r * cos(theta) * y(n - 1) + r**2 * y(n - 2) + assert rsolve(f, y(n)) == \ + C0*(r*(cos(theta) - I*Abs(sin(theta))))**n + C1*(r*(cos(theta) + I*Abs(sin(theta))))**n + + +def test_constant_naming(): + #issue 8697 + assert rsolve(y(n+3) - y(n+2) - y(n+1) + y(n), y(n)) == (-1)**n*C1 + C0 + C2*n + assert rsolve(y(n+3)+3*y(n+2)+3*y(n+1)+y(n), y(n)).expand() == (-1)**n*C0 - (-1)**n*C1*n - (-1)**n*C2*n**2 + assert rsolve(y(n) - 2*y(n - 3) + 5*y(n - 2) - 4*y(n - 1),y(n),[1,3,8]) == 3*2**n - n - 2 + + #issue 19630 + assert rsolve(y(n+3) - 3*y(n+1) + 2*y(n), y(n), {y(1):0, y(2):8, y(3):-2}) == (-2)**n + 2*n + + +@slow +def test_issue_15751(): + f = y(n) + 21*y(n + 1) - 273*y(n + 2) - 1092*y(n + 3) + 1820*y(n + 4) + 1092*y(n + 5) - 273*y(n + 6) - 21*y(n + 7) + y(n + 8) + assert rsolve(f, y(n)) is not None + + +def test_issue_17990(): + f = -10*y(n) + 4*y(n + 1) + 6*y(n + 2) + 46*y(n + 3) + sol = rsolve(f, y(n)) + expected = C0*((86*18**(S(1)/3)/69 + (-12 + (-1 + sqrt(3)*I)*(290412 + + 3036*sqrt(9165))**(S(1)/3))*(1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))** + (S(1)/3)/276)/((1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3)) + )**n + C1*((86*18**(S(1)/3)/69 + (-12 + (-1 - sqrt(3)*I)*(290412 + 3036 + *sqrt(9165))**(S(1)/3))*(1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))** + (S(1)/3)/276)/((1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3)) + )**n + C2*(-43*18**(S(1)/3)/(69*(24201 + 253*sqrt(9165))**(S(1)/3)) - + S(1)/23 + (290412 + 3036*sqrt(9165))**(S(1)/3)/138)**n + assert sol == expected + e = sol.subs({C0: 1, C1: 1, C2: 1, n: 1}).evalf() + assert abs(e + 0.130434782608696) < 1e-13 + + +def test_issue_8697(): + a = Function('a') + eq = a(n + 3) - a(n + 2) - a(n + 1) + a(n) + assert rsolve(eq, a(n)) == (-1)**n*C1 + C0 + C2*n + eq2 = a(n + 3) + 3*a(n + 2) + 3*a(n + 1) + a(n) + assert (rsolve(eq2, a(n)) == + (-1)**n*C0 + (-1)**(n + 1)*C1*n + (-1)**(n + 1)*C2*n**2) + + assert rsolve(a(n) - 2*a(n - 3) + 5*a(n - 2) - 4*a(n - 1), + a(n), {a(0): 1, a(1): 3, a(2): 8}) == 3*2**n - n - 2 + + # From issue thread (but fixed by https://github.com/diofant/diofant/commit/da9789c6cd7d0c2ceeea19fbf59645987125b289): + assert rsolve(a(n) - 2*a(n - 1) - n, a(n), {a(0): 1}) == 3*2**n - n - 2 + + +def test_diofantissue_294(): + f = y(n) - y(n - 1) - 2*y(n - 2) - 2*n + assert rsolve(f, y(n)) == (-1)**n*C0 + 2**n*C1 - n - Rational(5, 2) + # issue sympy/sympy#11261 + assert rsolve(f, y(n), {y(0): -1, y(1): 1}) == (-(-1)**n/2 + 2*2**n - + n - Rational(5, 2)) + # issue sympy/sympy#7055 + assert rsolve(-2*y(n) + y(n + 1) + n - 1, y(n)) == 2**n*C0 + n + + +def test_issue_15553(): + f = Function("f") + assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n)) == 2**n*C0 - n - 2 + assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n)) == 2**n*C0 - n**2 - 2*n - 4 + assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n), {f(1): 0}) == 7*2**n/2 - n**2 - 2*n - 4 + assert rsolve(Eq(f(n), 2*f(n - 1) + 3*n**2), f(n)) == 2**n*C0 - 3*n**2 - 12*n - 18 + assert rsolve(Eq(f(n), 2*f(n - 1) + n**2), f(n)) == 2**n*C0 - n**2 - 4*n - 6 + assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n), {f(0): 1}) == 3*2**n - n - 2 diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_simplex.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_simplex.py new file mode 100644 index 0000000000000000000000000000000000000000..611205f5df009a6d0de6e687501695b63bb932c9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_simplex.py @@ -0,0 +1,254 @@ +from sympy.core.numbers import Rational +from sympy.core.relational import Eq, Ne +from sympy.core.symbol import symbols +from sympy.core.sympify import sympify +from sympy.core.singleton import S +from sympy.core.random import random, choice +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.ntheory.generate import randprime +from sympy.matrices.dense import Matrix +from sympy.solvers.solveset import linear_eq_to_matrix +from sympy.solvers.simplex import (_lp as lp, _primal_dual, + UnboundedLPError, InfeasibleLPError, lpmin, lpmax, + _m, _abcd, _simplex, linprog) + +from sympy.external.importtools import import_module + +from sympy.testing.pytest import raises + +from sympy.abc import x, y, z + + +np = import_module("numpy") +scipy = import_module("scipy") + + +def test_lp(): + r1 = y + 2*z <= 3 + r2 = -x - 3*z <= -2 + r3 = 2*x + y + 7*z <= 5 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + objective = -x - y - 5 * z + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + r1 = x - y + 2*z <= 3 + r2 = -x + 2*y - 3*z <= -2 + r3 = 2*x + y - 7*z <= -5 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + objective = -x - y - 5*z + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + r1 = x - y + 2*z <= -4 + r2 = -x + 2*y - 3*z <= 8 + r3 = 2*x + y - 7*z <= 10 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + const = 2 + objective = -x-y-5*z+const # has constant term + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + # Section 4 Problem 1 from + # http://web.tecnico.ulisboa.pt/mcasquilho/acad/or/ftp/FergusonUCLA_LP.pdf + # answer on page 55 + v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4') + r1 = x1 - x2 - 2*x3 - x4 <= 4 + r2 = 2*x1 + x3 -4*x4 <= 2 + r3 = -2*x1 + x2 + x4 <= 1 + objective, constraints = x1 - 2*x2 - 3*x3 - x4, [r1, r2, r3] + [ + i >= 0 for i in v] + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert ans == (4, {x1: 7, x2: 0, x3: 0, x4: 3}) + + # input contains Floats + r1 = x - y + 2.0*z <= -4 + r2 = -x + 2*y - 3.0*z <= 8 + r3 = 2*x + y - 7*z <= 10 + constraints = [r1, r2, r3] + [i >= 0 for i in (x, y, z)] + objective = -x-y-5*z + optimum, argmax = lp(max, objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + # input contains non-float or non-Rational + r1 = x - y + sqrt(2) * z <= -4 + r2 = -x + 2*y - 3*z <= 8 + r3 = 2*x + y - 7*z <= 10 + raises(TypeError, lambda: lp(max, -x-y-5*z, [r1, r2, r3])) + + r1 = x >= 0 + raises(UnboundedLPError, lambda: lp(max, x, [r1])) + r2 = x <= -1 + raises(InfeasibleLPError, lambda: lp(max, x, [r1, r2])) + + # strict inequalities are not allowed + r1 = x > 0 + raises(TypeError, lambda: lp(max, x, [r1])) + + # not equals not allowed + r1 = Ne(x, 0) + raises(TypeError, lambda: lp(max, x, [r1])) + + def make_random_problem(nvar=2, num_constraints=2, sparsity=.1): + def rand(): + if random() < sparsity: + return sympify(0) + int1, int2 = [randprime(0, 200) for _ in range(2)] + return Rational(int1, int2)*choice([-1, 1]) + variables = symbols('x1:%s' % (nvar + 1)) + constraints = [(sum(rand()*x for x in variables) <= rand()) + for _ in range(num_constraints)] + objective = sum(rand() * x for x in variables) + return objective, constraints, variables + + # equality + r1 = Eq(x, y) + r2 = Eq(y, z) + r3 = z <= 3 + constraints = [r1, r2, r3] + objective = x + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + +def test_simplex(): + L = [ + [[1, 1], [-1, 1], [0, 1], [-1, 0]], + [5, 1, 2, -1], + [[1, 1]], + [-1]] + A, B, C, D = _abcd(_m(*L), list=False) + assert _simplex(A, B, -C, -D) == (-6, [3, 2], [1, 0, 0, 0]) + assert _simplex(A, B, -C, -D, dual=True) == (-6, + [1, 0, 0, 0], [5, 0]) + + assert _simplex([[]],[],[[1]],[0]) == (0, [0], []) + + # handling of Eq (or Eq-like x<=y, x>=y conditions) + assert lpmax(x - y, [x <= y + 2, x >= y + 2, x >= 0, y >= 0] + ) == (2, {x: 2, y: 0}) + assert lpmax(x - y, [x <= y + 2, Eq(x, y + 2), x >= 0, y >= 0] + ) == (2, {x: 2, y: 0}) + assert lpmax(x - y, [x <= y + 2, Eq(x, 2)]) == (2, {x: 2, y: 0}) + assert lpmax(y, [Eq(y, 2)]) == (2, {y: 2}) + + # the conditions are equivalent to Eq(x, y + 2) + assert lpmin(y, [x <= y + 2, x >= y + 2, y >= 0] + ) == (0, {x: 2, y: 0}) + # equivalent to Eq(y, -2) + assert lpmax(y, [0 <= y + 2, 0 >= y + 2]) == (-2, {y: -2}) + assert lpmax(y, [0 <= y + 2, 0 >= y + 2, y <= 0] + ) == (-2, {y: -2}) + + # extra symbols symbols + assert lpmin(x, [y >= 1, x >= y]) == (1, {x: 1, y: 1}) + assert lpmin(x, [y >= 1, x >= y + z, x >= 0, z >= 0] + ) == (1, {x: 1, y: 1, z: 0}) + + # detect oscillation + # o1 + v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4') + raises(InfeasibleLPError, lambda: lpmin( + 9*x2 - 8*x3 + 3*x4 + 6, + [5*x2 - 2*x3 <= 0, + -x1 - 8*x2 + 9*x3 <= -3, + 10*x1 - x2+ 9*x4 <= -4] + [i >= 0 for i in v])) + # o2 - equations fed to lpmin are changed into a matrix + # system that doesn't oscillate and has the same solution + # as below + M = linear_eq_to_matrix + f = 5*x2 + x3 + 4*x4 - x1 + L = 5*x2 + 2*x3 + 5*x4 - (x1 + 5) + cond = [L <= 0] + [Eq(3*x2 + x4, 2), Eq(-x1 + x3 + 2*x4, 1)] + c, d = M(f, v) + a, b = M(L, v) + aeq, beq = M(cond[1:], v) + ans = (S(9)/2, [0, S(1)/2, 0, S(1)/2]) + assert linprog(c, a, b, aeq, beq, bounds=(0, 1)) == ans + lpans = lpmin(f, cond + [x1 >= 0, x1 <= 1, + x2 >= 0, x2 <= 1, x3 >= 0, x3 <= 1, x4 >= 0, x4 <= 1]) + assert (lpans[0], list(lpans[1].values())) == ans + + +def test_lpmin_lpmax(): + v = x1, x2, y1, y2 = symbols('x1 x2 y1 y2') + L = [[1, -1]], [1], [[1, 1]], [2] + a, b, c, d = [Matrix(i) for i in L] + m = Matrix([[a, b], [c, d]]) + f, constr = _primal_dual(m)[0] + ans = lpmin(f, constr + [i >= 0 for i in v[:2]]) + assert ans == (-1, {x1: 1, x2: 0}),ans + + L = [[1, -1], [1, 1]], [1, 1], [[1, 1]], [2] + a, b, c, d = [Matrix(i) for i in L] + m = Matrix([[a, b], [c, d]]) + f, constr = _primal_dual(m)[1] + ans = lpmax(f, constr + [i >= 0 for i in v[-2:]]) + assert ans == (-1, {y1: 1, y2: 0}) + + +def test_linprog(): + for do in range(2): + if not do: + M = lambda a, b: linear_eq_to_matrix(a, b) + else: + # check matrices as list + M = lambda a, b: tuple([ + i.tolist() for i in linear_eq_to_matrix(a, b)]) + + v = x, y, z = symbols('x1:4') + f = x + y - 2*z + c = M(f, v)[0] + ineq = [7*x + 4*y - 7*z <= 3, + 3*x - y + 10*z <= 6, + x >= 0, y >= 0, z >= 0] + ab = M([i.lts - i.gts for i in ineq], v) + ans = (-S(6)/5, [0, 0, S(3)/5]) + assert lpmin(f, ineq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab) == ans + + f += 1 + c = M(f, v)[0] + eq = [Eq(y - 9*x, 1)] + abeq = M([i.lhs - i.rhs for i in eq], v) + ans = (1 - S(2)/5, [0, 1, S(7)/10]) + assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1]) + + eq = [z - y <= S.Half] + abeq = M([i.lhs - i.rhs for i in eq], v) + ans = (1 - S(10)/9, [0, S(1)/9, S(11)/18]) + assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1]) + + bounds = [(0, None), (0, None), (None, S.Half)] + ans = (0, [0, 0, S.Half]) + assert lpmin(f, ineq + [z <= S.Half]) == ( + ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, bounds=bounds) == (ans[0] - 1, ans[1]) + assert linprog(c, *ab, bounds={v.index(z): bounds[-1]} + ) == (ans[0] - 1, ans[1]) + eq = [z - y <= S.Half] + + assert linprog([[1]], [], [], bounds=(2, 3)) == (2, [2]) + assert linprog([1], [], [], bounds=(2, 3)) == (2, [2]) + assert linprog([1], bounds=(2, 3)) == (2, [2]) + assert linprog([1, -1], [[1, 1]], [2], bounds={1:(None, None)} + ) == (-2, [0, 2]) + assert linprog([1, -1], [[1, 1]], [5], bounds={1:(3, None)} + ) == (-5, [0, 5]) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_solveset.py b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_solveset.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ba7a11e68ed518c4d83c050947b78756ade181 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/tests/test_solveset.py @@ -0,0 +1,3548 @@ +from math import isclose + +from sympy.calculus.util import stationary_points +from sympy.core.containers import Tuple +from sympy.core.function import (Function, Lambda, nfloat, diff) +from sympy.core.mod import Mod +from sympy.core.numbers import (E, I, Rational, oo, pi, Integer, all_close) +from sympy.core.relational import (Eq, Gt, Ne, Ge) +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import (Abs, arg, im, re, sign, conjugate) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, + sinh, cosh, tanh, coth, sech, csch, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.functions.elementary.miscellaneous import sqrt, Min, Max +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import ( + TrigonometricFunction, acos, acot, acsc, asec, asin, atan, atan2, + cos, cot, csc, sec, sin, tan) +from sympy.functions.special.error_functions import (erf, erfc, + erfcinv, erfinv) +from sympy.logic.boolalg import And +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.sets.contains import Contains +from sympy.sets.conditionset import ConditionSet +from sympy.sets.fancysets import ImageSet, Range +from sympy.sets.sets import (Complement, FiniteSet, + Intersection, Interval, Union, imageset, ProductSet) +from sympy.simplify import simplify +from sympy.tensor.indexed import Indexed +from sympy.utilities.iterables import numbered_symbols + +from sympy.testing.pytest import (XFAIL, raises, skip, slow, SKIP, _both_exp_pow) +from sympy.core.random import verify_numerically as tn +from sympy.physics.units import cm + +from sympy.solvers import solve +from sympy.solvers.solveset import ( + solveset_real, domain_check, solveset_complex, linear_eq_to_matrix, + linsolve, _is_function_class_equation, invert_real, invert_complex, + _invert_trig_hyp_real, solveset, solve_decomposition, substitution, + nonlinsolve, solvify, + _is_finite_with_finite_vars, _transolve, _is_exponential, + _solve_exponential, _is_logarithmic, _is_lambert, + _solve_logarithm, _term_factors, _is_modular, NonlinearError) + +from sympy.abc import (a, b, c, d, e, f, g, h, i, j, k, l, m, n, q, r, + t, w, x, y, z) + + +def dumeq(i, j): + if type(i) in (list, tuple): + return all(dumeq(i, j) for i, j in zip(i, j)) + return i == j or i.dummy_eq(j) + + +def assert_close_ss(sol1, sol2): + """Test solutions with floats from solveset are close""" + sol1 = sympify(sol1) + sol2 = sympify(sol2) + assert isinstance(sol1, FiniteSet) + assert isinstance(sol2, FiniteSet) + assert len(sol1) == len(sol2) + assert all(isclose(v1, v2) for v1, v2 in zip(sol1, sol2)) + + +def assert_close_nl(sol1, sol2): + """Test solutions with floats from nonlinsolve are close""" + sol1 = sympify(sol1) + sol2 = sympify(sol2) + assert isinstance(sol1, FiniteSet) + assert isinstance(sol2, FiniteSet) + assert len(sol1) == len(sol2) + for s1, s2 in zip(sol1, sol2): + assert len(s1) == len(s2) + assert all(isclose(v1, v2) for v1, v2 in zip(s1, s2)) + + +@_both_exp_pow +def test_invert_real(): + x = Symbol('x', real=True) + + def ireal(x, s=S.Reals): + return Intersection(s, x) + + assert invert_real(exp(x), z, x) == (x, ireal(FiniteSet(log(z)))) + + y = Symbol('y', positive=True) + n = Symbol('n', real=True) + assert invert_real(x + 3, y, x) == (x, FiniteSet(y - 3)) + assert invert_real(x*3, y, x) == (x, FiniteSet(y / 3)) + + assert invert_real(exp(x), y, x) == (x, FiniteSet(log(y))) + assert invert_real(exp(3*x), y, x) == (x, FiniteSet(log(y) / 3)) + assert invert_real(exp(x + 3), y, x) == (x, FiniteSet(log(y) - 3)) + + assert invert_real(exp(x) + 3, y, x) == (x, ireal(FiniteSet(log(y - 3)))) + assert invert_real(exp(x)*3, y, x) == (x, FiniteSet(log(y / 3))) + + assert invert_real(log(x), y, x) == (x, FiniteSet(exp(y))) + assert invert_real(log(3*x), y, x) == (x, FiniteSet(exp(y) / 3)) + assert invert_real(log(x + 3), y, x) == (x, FiniteSet(exp(y) - 3)) + + assert invert_real(Abs(x), y, x) == (x, FiniteSet(y, -y)) + + assert invert_real(2**x, y, x) == (x, FiniteSet(log(y)/log(2))) + assert invert_real(2**exp(x), y, x) == (x, ireal(FiniteSet(log(log(y)/log(2))))) + + assert invert_real(x**2, y, x) == (x, FiniteSet(sqrt(y), -sqrt(y))) + assert invert_real(x**S.Half, y, x) == (x, FiniteSet(y**2)) + + raises(ValueError, lambda: invert_real(x, x, x)) + + # issue 21236 + assert invert_real(x**pi, y, x) == (x, FiniteSet(y**(1/pi))) + assert invert_real(x**pi, -E, x) == (x, S.EmptySet) + assert invert_real(x**Rational(3/2), 1000, x) == (x, FiniteSet(100)) + assert invert_real(x**1.0, 1, x) == (x**1.0, FiniteSet(1)) + + raises(ValueError, lambda: invert_real(S.One, y, x)) + + assert invert_real(x**31 + x, y, x) == (x**31 + x, FiniteSet(y)) + + lhs = x**31 + x + base_values = FiniteSet(y - 1, -y - 1) + assert invert_real(Abs(x**31 + x + 1), y, x) == (lhs, base_values) + + assert dumeq(invert_real(sin(x), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + asin(y)), S.Integers), + ImageSet(Lambda(n, pi*2*n + pi - asin(y)), S.Integers))))) + + assert dumeq(invert_real(sin(exp(x)), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, log(2*n*pi + asin(y))), S.Integers), + ImageSet(Lambda(n, log(pi*2*n + pi - asin(y))), S.Integers))))) + + assert dumeq(invert_real(csc(x), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, 2*n*pi + acsc(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - acsc(y) + pi), S.Integers))))) + + assert dumeq(invert_real(csc(exp(x)), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, log(2*n*pi + acsc(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi - acsc(y) + pi)), S.Integers))))) + + assert dumeq(invert_real(cos(x), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + acos(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - acos(y)), S.Integers))))) + + assert dumeq(invert_real(cos(exp(x)), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, log(2*n*pi + acos(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi - acos(y))), S.Integers))))) + + assert dumeq(invert_real(sec(x), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, 2*n*pi + asec(y)), S.Integers), \ + ImageSet(Lambda(n, 2*n*pi - asec(y)), S.Integers))))) + + assert dumeq(invert_real(sec(exp(x)), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, log(2*n*pi - asec(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi + asec(y))), S.Integers))))) + + assert dumeq(invert_real(tan(x), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + atan(y)), S.Integers)))) + + assert dumeq(invert_real(tan(exp(x)), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, log(n*pi + atan(y))), S.Integers)))) + + assert dumeq(invert_real(cot(x), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + acot(y)), S.Integers)))) + + assert dumeq(invert_real(cot(exp(x)), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, log(n*pi + acot(y))), S.Integers)))) + + assert dumeq(invert_real(tan(tan(x)), y, x), + (x, ConditionSet(x, Eq(tan(tan(x)), y), S.Reals))) + # slight regression compared to previous result: + # (tan(x), imageset(Lambda(n, n*pi + atan(y)), S.Integers))) + + x = Symbol('x', positive=True) + assert invert_real(x**pi, y, x) == (x, FiniteSet(y**(1/pi))) + + r = Symbol('r', real=True) + p = Symbol('p', positive=True) + assert invert_real(sinh(x), r, x) == (x, FiniteSet(asinh(r))) + assert invert_real(sinh(log(x)), p, x) == (x, FiniteSet(exp(asinh(p)))) + + assert invert_real(cosh(x), r, x) == (x, Intersection( + FiniteSet(-acosh(r), acosh(r)), S.Reals)) + assert invert_real(cosh(x), p + 1, x) == (x, + FiniteSet(-acosh(p + 1), acosh(p + 1))) + + assert invert_real(tanh(x), r, x) == (x, Intersection(FiniteSet(atanh(r)), S.Reals)) + assert invert_real(coth(x), p+1, x) == (x, FiniteSet(acoth(p+1))) + assert invert_real(sech(x), r, x) == (x, Intersection( + FiniteSet(-asech(r), asech(r)), S.Reals)) + assert invert_real(csch(x), p, x) == (x, FiniteSet(acsch(p))) + + assert dumeq(invert_real(tanh(sin(x)), r, x), (x, + ConditionSet(x, (S(-1) <= atanh(r)) & (atanh(r) <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + asin(atanh(r))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - asin(atanh(r)) + pi), S.Integers))))) + + +def test_invert_trig_hyp_real(): + # check some codepaths that are not as easily reached otherwise + n = Dummy('n') + assert _invert_trig_hyp_real(cosh(x), Range(-5, 10, 1), x)[1].dummy_eq(Union( + ImageSet(Lambda(n, -acosh(n)), Range(1, 10, 1)), + ImageSet(Lambda(n, acosh(n)), Range(1, 10, 1)))) + assert _invert_trig_hyp_real(coth(x), Interval(-3, 2), x) == (x, Union( + Interval(-oo, -acoth(3)), Interval(acoth(2), oo))) + assert _invert_trig_hyp_real(tanh(x), Interval(-S.Half, 1), x) == (x, + Interval(-atanh(S.Half), oo)) + assert _invert_trig_hyp_real(sech(x), imageset(n, S.Half + n/3, S.Naturals0), x) == \ + (x, FiniteSet(-asech(S(1)/2), asech(S(1)/2), -asech(S(5)/6), asech(S(5)/6))) + assert _invert_trig_hyp_real(csch(x), S.Reals, x) == (x, + Union(Interval.open(-oo, 0), Interval.open(0, oo))) + + +def test_invert_complex(): + assert invert_complex(x + 3, y, x) == (x, FiniteSet(y - 3)) + assert invert_complex(x*3, y, x) == (x, FiniteSet(y / 3)) + assert invert_complex((x - 1)**3, 0, x) == (x, FiniteSet(1)) + + assert dumeq(invert_complex(exp(x), y, x), + (x, imageset(Lambda(n, I*(2*pi*n + arg(y)) + log(Abs(y))), S.Integers))) + + assert invert_complex(log(x), y, x) == (x, FiniteSet(exp(y))) + + raises(ValueError, lambda: invert_real(1, y, x)) + raises(ValueError, lambda: invert_complex(x, x, x)) + raises(ValueError, lambda: invert_complex(x, x, 1)) + + assert dumeq(invert_complex(sin(x), I, x), (x, Union( + ImageSet(Lambda(n, 2*n*pi + I*log(1 + sqrt(2))), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi - I*log(1 + sqrt(2))), S.Integers)))) + assert dumeq(invert_complex(cos(x), 1+I, x), (x, Union( + ImageSet(Lambda(n, 2*n*pi - acos(1 + I)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(1 + I)), S.Integers)))) + assert dumeq(invert_complex(tan(2*x), 1, x), (x, + ImageSet(Lambda(n, n*pi/2 + pi/8), S.Integers))) + assert dumeq(invert_complex(cot(x), 2*I, x), (x, + ImageSet(Lambda(n, n*pi - I*acoth(2)), S.Integers))) + + assert dumeq(invert_complex(sinh(x), 0, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + I*pi), S.Integers)))) + assert dumeq(invert_complex(cosh(x), 0, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi + I*pi/2), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + 3*I*pi/2), S.Integers)))) + assert invert_complex(tanh(x), 1, x) == (x, S.EmptySet) + assert dumeq(invert_complex(tanh(x), a, x), (x, + ConditionSet(x, Ne(a, -1) & Ne(a, 1), + ImageSet(Lambda(n, n*I*pi + atanh(a)), S.Integers)))) + assert invert_complex(coth(x), 1, x) == (x, S.EmptySet) + assert dumeq(invert_complex(coth(x), a, x), (x, + ConditionSet(x, Ne(a, -1) & Ne(a, 1), + ImageSet(Lambda(n, n*I*pi + acoth(a)), S.Integers)))) + assert dumeq(invert_complex(sech(x), 2, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi + I*pi/3), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + 5*I*pi/3), S.Integers)))) + + +def test_domain_check(): + assert domain_check(1/(1 + (1/(x+1))**2), x, -1) is False + assert domain_check(x**2, x, 0) is True + assert domain_check(x, x, oo) is False + assert domain_check(0, x, oo) is False + + +def test_issue_11536(): + assert solveset(0**x - 100, x, S.Reals) == S.EmptySet + assert solveset(0**x - 1, x, S.Reals) == FiniteSet(0) + + +def test_issue_17479(): + f = (x**2 + y**2)**2 + (x**2 + z**2)**2 - 2*(2*x**2 + y**2 + z**2) + fx = f.diff(x) + fy = f.diff(y) + fz = f.diff(z) + sol = nonlinsolve([fx, fy, fz], [x, y, z]) + assert len(sol) >= 4 and len(sol) <= 20 + # nonlinsolve has been giving a varying number of solutions + # (originally 18, then 20, now 19) due to various internal changes. + # Unfortunately not all the solutions are actually valid and some are + # redundant. Since the original issue was that an exception was raised, + # this first test only checks that nonlinsolve returns a "plausible" + # solution set. The next test checks the result for correctness. + + +@XFAIL +def test_issue_18449(): + x, y, z = symbols("x, y, z") + f = (x**2 + y**2)**2 + (x**2 + z**2)**2 - 2*(2*x**2 + y**2 + z**2) + fx = diff(f, x) + fy = diff(f, y) + fz = diff(f, z) + sol = nonlinsolve([fx, fy, fz], [x, y, z]) + for (xs, ys, zs) in sol: + d = {x: xs, y: ys, z: zs} + assert tuple(_.subs(d).simplify() for _ in (fx, fy, fz)) == (0, 0, 0) + # After simplification and removal of duplicate elements, there should + # only be 4 parametric solutions left: + # simplifiedsolutions = FiniteSet((sqrt(1 - z**2), z, z), + # (-sqrt(1 - z**2), z, z), + # (sqrt(1 - z**2), -z, z), + # (-sqrt(1 - z**2), -z, z)) + # TODO: Is the above solution set definitely complete? + + +def test_issue_21047(): + f = (2 - x)**2 + (sqrt(x - 1) - 1)**6 + assert solveset(f, x, S.Reals) == FiniteSet(2) + + f = (sqrt(x)-1)**2 + (sqrt(x)+1)**2 -2*x**2 + sqrt(2) + assert solveset(f, x, S.Reals) == FiniteSet( + S.Half - sqrt(2*sqrt(2) + 5)/2, S.Half + sqrt(2*sqrt(2) + 5)/2) + + +def test_is_function_class_equation(): + assert _is_function_class_equation(TrigonometricFunction, + tan(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + sin(x) - a, x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x + a) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x*a) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + a*tan(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x)**2 + sin(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + x, x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x**2), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x**2) + sin(x), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x)**sin(x), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(sin(x)) + sin(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + sinh(x) - a, x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x + a) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x*a) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + a*tanh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x)**2 + sinh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + x, x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x**2), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x**2) + sinh(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x)**sinh(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(sinh(x)) + sinh(x), x) is False + + +def test_garbage_input(): + raises(ValueError, lambda: solveset_real([y], y)) + x = Symbol('x', real=True) + assert solveset_real(x, 1) == S.EmptySet + assert solveset_real(x - 1, 1) == FiniteSet(x) + assert solveset_real(x, pi) == S.EmptySet + assert solveset_real(x, x**2) == S.EmptySet + + raises(ValueError, lambda: solveset_complex([x], x)) + assert solveset_complex(x, pi) == S.EmptySet + + raises(ValueError, lambda: solveset((x, y), x)) + raises(ValueError, lambda: solveset(x + 1, S.Reals)) + raises(ValueError, lambda: solveset(x + 1, x, 2)) + + +def test_solve_mul(): + assert solveset_real((a*x + b)*(exp(x) - 3), x) == \ + Union({log(3)}, Intersection({-b/a}, S.Reals)) + anz = Symbol('anz', nonzero=True) + bb = Symbol('bb', real=True) + assert solveset_real((anz*x + bb)*(exp(x) - 3), x) == \ + FiniteSet(-bb/anz, log(3)) + assert solveset_real((2*x + 8)*(8 + exp(x)), x) == FiniteSet(S(-4)) + assert solveset_real(x/log(x), x) is S.EmptySet + + +def test_solve_invert(): + assert solveset_real(exp(x) - 3, x) == FiniteSet(log(3)) + assert solveset_real(log(x) - 3, x) == FiniteSet(exp(3)) + + assert solveset_real(3**(x + 2), x) == FiniteSet() + assert solveset_real(3**(2 - x), x) == FiniteSet() + + assert solveset_real(y - b*exp(a/x), x) == Intersection( + S.Reals, FiniteSet(a/log(y/b))) + + # issue 4504 + assert solveset_real(2**x - 10, x) == FiniteSet(1 + log(5)/log(2)) + + +def test_issue_25768(): + assert dumeq(solveset_real(sin(x) - S.Half, x), Union( + ImageSet(Lambda(n, pi*2*n + pi/6), S.Integers), + ImageSet(Lambda(n, pi*2*n + pi*5/6), S.Integers))) + n1 = solveset_real(sin(x) - 0.5, x).n(5) + n2 = solveset_real(sin(x) - S.Half, x).n(5) + # help pass despite fp differences + eq = [i.replace( + lambda x:x.is_Float, + lambda x:Rational(x).limit_denominator(1000)) for i in (n1, n2)] + assert dumeq(*eq),(n1,n2) + + +def test_errorinverses(): + assert solveset_real(erf(x) - S.Half, x) == \ + FiniteSet(erfinv(S.Half)) + assert solveset_real(erfinv(x) - 2, x) == \ + FiniteSet(erf(2)) + assert solveset_real(erfc(x) - S.One, x) == \ + FiniteSet(erfcinv(S.One)) + assert solveset_real(erfcinv(x) - 2, x) == FiniteSet(erfc(2)) + + +def test_solve_polynomial(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset_real(3*x - 2, x) == FiniteSet(Rational(2, 3)) + + assert solveset_real(x**2 - 1, x) == FiniteSet(-S.One, S.One) + assert solveset_real(x - y**3, x) == FiniteSet(y ** 3) + + assert solveset_real(x**3 - 15*x - 4, x) == FiniteSet( + -2 + 3 ** S.Half, + S(4), + -2 - 3 ** S.Half) + + assert solveset_real(sqrt(x) - 1, x) == FiniteSet(1) + assert solveset_real(sqrt(x) - 2, x) == FiniteSet(4) + assert solveset_real(x**Rational(1, 4) - 2, x) == FiniteSet(16) + assert solveset_real(x**Rational(1, 3) - 3, x) == FiniteSet(27) + assert len(solveset_real(x**5 + x**3 + 1, x)) == 1 + assert len(solveset_real(-2*x**3 + 4*x**2 - 2*x + 6, x)) > 0 + assert solveset_real(x**6 + x**4 + I, x) is S.EmptySet + + +def test_return_root_of(): + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = list(solveset_complex(f, x)) + for root in s: + assert root.func == CRootOf + + # if one uses solve to get the roots of a polynomial that has a CRootOf + # solution, make sure that the use of nfloat during the solve process + # doesn't fail. Note: if you want numerical solutions to a polynomial + # it is *much* faster to use nroots to get them than to solve the + # equation only to get CRootOf solutions which are then numerically + # evaluated. So for eq = x**5 + 3*x + 7 do Poly(eq).nroots() rather + # than [i.n() for i in solve(eq)] to get the numerical roots of eq. + assert nfloat(list(solveset_complex(x**5 + 3*x**3 + 7, x))[0], + exponent=False) == CRootOf(x**5 + 3*x**3 + 7, 0).n() + + sol = list(solveset_complex(x**6 - 2*x + 2, x)) + assert all(isinstance(i, CRootOf) for i in sol) and len(sol) == 6 + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = list(solveset_complex(f, x)) + for root in s: + assert root.func == CRootOf + + s = x**5 + 4*x**3 + 3*x**2 + Rational(7, 4) + assert solveset_complex(s, x) == \ + FiniteSet(*Poly(s*4, domain='ZZ').all_roots()) + + # Refer issue #7876 + eq = x*(x - 1)**2*(x + 1)*(x**6 - x + 1) + assert solveset_complex(eq, x) == \ + FiniteSet(-1, 0, 1, CRootOf(x**6 - x + 1, 0), + CRootOf(x**6 - x + 1, 1), + CRootOf(x**6 - x + 1, 2), + CRootOf(x**6 - x + 1, 3), + CRootOf(x**6 - x + 1, 4), + CRootOf(x**6 - x + 1, 5)) + + +def test_solveset_sqrt_1(): + assert solveset_real(sqrt(5*x + 6) - 2 - x, x) == \ + FiniteSet(-S.One, S(2)) + assert solveset_real(sqrt(x - 1) - x + 7, x) == FiniteSet(10) + assert solveset_real(sqrt(x - 2) - 5, x) == FiniteSet(27) + assert solveset_real(sqrt(x) - 2 - 5, x) == FiniteSet(49) + assert solveset_real(sqrt(x**3), x) == FiniteSet(0) + assert solveset_real(sqrt(x - 1), x) == FiniteSet(1) + assert solveset_real(sqrt((x-3)/x), x) == FiniteSet(3) + assert solveset_real(sqrt((x-3)/x)-Rational(1, 2), x) == \ + FiniteSet(4) + +def test_solveset_sqrt_2(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + # http://tutorial.math.lamar.edu/Classes/Alg/SolveRadicalEqns.aspx#Solve_Rad_Ex2_a + assert solveset_real(sqrt(2*x - 1) - sqrt(x - 4) - 2, x) == \ + FiniteSet(S(5), S(13)) + assert solveset_real(sqrt(x + 7) + 2 - sqrt(3 - x), x) == \ + FiniteSet(-6) + + # http://www.purplemath.com/modules/solverad.htm + assert solveset_real(sqrt(17*x - sqrt(x**2 - 5)) - 7, x) == \ + FiniteSet(3) + + eq = x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4) + assert solveset_real(eq, x) == FiniteSet(Rational(-1, 2), Rational(-1, 3)) + + eq = sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4) + assert solveset_real(eq, x) == FiniteSet(0) + + eq = sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1) + assert solveset_real(eq, x) == FiniteSet(5) + + eq = sqrt(x)*sqrt(x - 7) - 12 + assert solveset_real(eq, x) == FiniteSet(16) + + eq = sqrt(x - 3) + sqrt(x) - 3 + assert solveset_real(eq, x) == FiniteSet(4) + + eq = sqrt(2*x**2 - 7) - (3 - x) + assert solveset_real(eq, x) == FiniteSet(-S(8), S(2)) + + # others + eq = sqrt(9*x**2 + 4) - (3*x + 2) + assert solveset_real(eq, x) == FiniteSet(0) + + assert solveset_real(sqrt(x - 3) - sqrt(x) - 3, x) == FiniteSet() + + eq = (2*x - 5)**Rational(1, 3) - 3 + assert solveset_real(eq, x) == FiniteSet(16) + + assert solveset_real(sqrt(x) + sqrt(sqrt(x)) - 4, x) == \ + FiniteSet((Rational(-1, 2) + sqrt(17)/2)**4) + + eq = sqrt(x) - sqrt(x - 1) + sqrt(sqrt(x)) + assert solveset_real(eq, x) == FiniteSet() + + eq = (x - 4)**2 + (sqrt(x) - 2)**4 + assert solveset_real(eq, x) == FiniteSet(-4, 4) + + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + ans = solveset_real(eq, x) + ra = S('''-1484/375 - 4*(-S(1)/2 + sqrt(3)*I/2)*(-12459439/52734375 + + 114*sqrt(12657)/78125)**(S(1)/3) - 172564/(140625*(-S(1)/2 + + sqrt(3)*I/2)*(-12459439/52734375 + 114*sqrt(12657)/78125)**(S(1)/3))''') + rb = Rational(4, 5) + assert all(abs(eq.subs(x, i).n()) < 1e-10 for i in (ra, rb)) and \ + len(ans) == 2 and \ + {i.n(chop=True) for i in ans} == \ + {i.n(chop=True) for i in (ra, rb)} + + assert solveset_real(sqrt(x) + x**Rational(1, 3) + + x**Rational(1, 4), x) == FiniteSet(0) + + assert solveset_real(x/sqrt(x**2 + 1), x) == FiniteSet(0) + + eq = (x - y**3)/((y**2)*sqrt(1 - y**2)) + assert solveset_real(eq, x) == FiniteSet(y**3) + + # issue 4497 + assert solveset_real(1/(5 + x)**Rational(1, 5) - 9, x) == \ + FiniteSet(Rational(-295244, 59049)) + + +@XFAIL +def test_solve_sqrt_fail(): + # this only works if we check real_root(eq.subs(x, Rational(1, 3))) + # but checksol doesn't work like that + eq = (x**3 - 3*x**2)**Rational(1, 3) + 1 - x + assert solveset_real(eq, x) == FiniteSet(Rational(1, 3)) + + +@slow +def test_solve_sqrt_3(): + R = Symbol('R') + eq = sqrt(2)*R*sqrt(1/(R + 1)) + (R + 1)*(sqrt(2)*sqrt(1/(R + 1)) - 1) + sol = solveset_complex(eq, R) + fset = [Rational(5, 3) + 4*sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3, + -sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3 + + 40*re(1/((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 + + sqrt(30)*sin(atan(3*sqrt(111)/251)/3)/3 + Rational(5, 3) + + I*(-sqrt(30)*cos(atan(3*sqrt(111)/251)/3)/3 - + sqrt(10)*sin(atan(3*sqrt(111)/251)/3)/3 + + 40*im(1/((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9)] + cset = [40*re(1/((Rational(-1, 2) + sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 - + sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3 - sqrt(30)*sin(atan(3*sqrt(111)/251)/3)/3 + + Rational(5, 3) + + I*(40*im(1/((Rational(-1, 2) + sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 - + sqrt(10)*sin(atan(3*sqrt(111)/251)/3)/3 + + sqrt(30)*cos(atan(3*sqrt(111)/251)/3)/3)] + + fs = FiniteSet(*fset) + cs = ConditionSet(R, Eq(eq, 0), FiniteSet(*cset)) + assert sol == (fs - {-1}) | (cs - {-1}) + + # the number of real roots will depend on the value of m: for m=1 there are 4 + # and for m=-1 there are none. + eq = -sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) + sqrt((-m**2/2 - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2) + unsolved_object = ConditionSet(q, Eq(sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) - + sqrt((-m**2/2 - sqrt(4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - + sqrt(4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2), 0), S.Reals) + assert solveset_real(eq, q) == unsolved_object + + +def test_solve_polynomial_symbolic_param(): + assert solveset_complex((x**2 - 1)**2 - a, x) == \ + FiniteSet(sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)), + sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))) + + # issue 4507 + assert solveset_complex(y - b/(1 + a*x), x) == \ + FiniteSet((b/y - 1)/a) - FiniteSet(-1/a) + + # issue 4508 + assert solveset_complex(y - b*x/(a + x), x) == \ + FiniteSet(-a*y/(y - b)) - FiniteSet(-a) + + +def test_solve_rational(): + assert solveset_real(1/x + 1, x) == FiniteSet(-S.One) + assert solveset_real(1/exp(x) - 1, x) == FiniteSet(0) + assert solveset_real(x*(1 - 5/x), x) == FiniteSet(5) + assert solveset_real(2*x/(x + 2) - 1, x) == FiniteSet(2) + assert solveset_real((x**2/(7 - x)).diff(x), x) == \ + FiniteSet(S.Zero, S(14)) + + +def test_solveset_real_gen_is_pow(): + assert solveset_real(sqrt(1) + 1, x) is S.EmptySet + + +def test_no_sol(): + assert solveset(1 - oo*x) is S.EmptySet + assert solveset(oo*x, x) is S.EmptySet + assert solveset(oo*x - oo, x) is S.EmptySet + assert solveset_real(4, x) is S.EmptySet + assert solveset_real(exp(x), x) is S.EmptySet + assert solveset_real(x**2 + 1, x) is S.EmptySet + assert solveset_real(-3*a/sqrt(x), x) is S.EmptySet + assert solveset_real(1/x, x) is S.EmptySet + assert solveset_real(-(1 + x)/(2 + x)**2 + 1/(2 + x), x + ) is S.EmptySet + + +def test_sol_zero_real(): + assert solveset_real(0, x) == S.Reals + assert solveset(0, x, Interval(1, 2)) == Interval(1, 2) + assert solveset_real(-x**2 - 2*x + (x + 1)**2 - 1, x) == S.Reals + + +def test_no_sol_rational_extragenous(): + assert solveset_real((x/(x + 1) + 3)**(-2), x) is S.EmptySet + assert solveset_real((x - 1)/(1 + 1/(x - 1)), x) is S.EmptySet + + +def test_solve_polynomial_cv_1a(): + """ + Test for solving on equations that can be converted to + a polynomial equation using the change of variable y -> x**Rational(p, q) + """ + assert solveset_real(sqrt(x) - 1, x) == FiniteSet(1) + assert solveset_real(sqrt(x) - 2, x) == FiniteSet(4) + assert solveset_real(x**Rational(1, 4) - 2, x) == FiniteSet(16) + assert solveset_real(x**Rational(1, 3) - 3, x) == FiniteSet(27) + assert solveset_real(x*(x**(S.One / 3) - 3), x) == \ + FiniteSet(S.Zero, S(27)) + + +def test_solveset_real_rational(): + """Test solveset_real for rational functions""" + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset_real((x - y**3) / ((y**2)*sqrt(1 - y**2)), x) \ + == FiniteSet(y**3) + # issue 4486 + assert solveset_real(2*x/(x + 2) - 1, x) == FiniteSet(2) + + +def test_solveset_real_log(): + assert solveset_real(log((x-1)*(x+1)), x) == \ + FiniteSet(sqrt(2), -sqrt(2)) + + +def test_poly_gens(): + assert solveset_real(4**(2*(x**2) + 2*x) - 8, x) == \ + FiniteSet(Rational(-3, 2), S.Half) + + +def test_solve_abs(): + n = Dummy('n') + raises(ValueError, lambda: solveset(Abs(x) - 1, x)) + assert solveset(Abs(x) - n, x, S.Reals).dummy_eq( + ConditionSet(x, Contains(n, Interval(0, oo)), {-n, n})) + assert solveset_real(Abs(x) - 2, x) == FiniteSet(-2, 2) + assert solveset_real(Abs(x) + 2, x) is S.EmptySet + assert solveset_real(Abs(x + 3) - 2*Abs(x - 3), x) == \ + FiniteSet(1, 9) + assert solveset_real(2*Abs(x) - Abs(x - 1), x) == \ + FiniteSet(-1, Rational(1, 3)) + + sol = ConditionSet( + x, + And( + Contains(b, Interval(0, oo)), + Contains(a + b, Interval(0, oo)), + Contains(a - b, Interval(0, oo))), + FiniteSet(-a - b - 3, -a + b - 3, a - b - 3, a + b - 3)) + eq = Abs(Abs(x + 3) - a) - b + assert invert_real(eq, 0, x)[1] == sol + reps = {a: 3, b: 1} + eqab = eq.subs(reps) + for si in sol.subs(reps): + assert not eqab.subs(x, si) + assert dumeq(solveset(Eq(sin(Abs(x)), 1), x, domain=S.Reals), Union( + Intersection(Interval(0, oo), Union( + Intersection(ImageSet(Lambda(n, 2*n*pi + 3*pi/2), S.Integers), + Interval(-oo, 0)), + Intersection(ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers), + Interval(0, oo)))))) + + +def test_issue_9824(): + assert dumeq(solveset(sin(x)**2 - 2*sin(x) + 1, x), ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers)) + assert dumeq(solveset(cos(x)**2 - 2*cos(x) + 1, x), ImageSet(Lambda(n, 2*n*pi), S.Integers)) + + +def test_issue_9565(): + assert solveset_real(Abs((x - 1)/(x - 5)) <= Rational(1, 3), x) == Interval(-1, 2) + + +def test_issue_10069(): + eq = abs(1/(x - 1)) - 1 > 0 + assert solveset_real(eq, x) == Union( + Interval.open(0, 1), Interval.open(1, 2)) + + +def test_real_imag_splitting(): + a, b = symbols('a b', real=True) + assert solveset_real(sqrt(a**2 - b**2) - 3, a) == \ + FiniteSet(-sqrt(b**2 + 9), sqrt(b**2 + 9)) + assert solveset_real(sqrt(a**2 + b**2) - 3, a) != \ + S.EmptySet + + +def test_units(): + assert solveset_real(1/x - 1/(2*cm), x) == FiniteSet(2*cm) + + +def test_solve_only_exp_1(): + y = Symbol('y', positive=True) + assert solveset_real(exp(x) - y, x) == FiniteSet(log(y)) + assert solveset_real(exp(x) + exp(-x) - 4, x) == \ + FiniteSet(log(-sqrt(3) + 2), log(sqrt(3) + 2)) + assert solveset_real(exp(x) + exp(-x) - y, x) != S.EmptySet + + +def test_atan2(): + # The .inverse() method on atan2 works only if x.is_real is True and the + # second argument is a real constant + assert solveset_real(atan2(x, 2) - pi/3, x) == FiniteSet(2*sqrt(3)) + + +def test_piecewise_solveset(): + eq = Piecewise((x - 2, Gt(x, 2)), (2 - x, True)) - 3 + assert set(solveset_real(eq, x)) == set(FiniteSet(-1, 5)) + + absxm3 = Piecewise( + (x - 3, 0 <= x - 3), + (3 - x, 0 > x - 3)) + y = Symbol('y', positive=True) + assert solveset_real(absxm3 - y, x) == FiniteSet(-y + 3, y + 3) + + f = Piecewise(((x - 2)**2, x >= 0), (0, True)) + assert solveset(f, x, domain=S.Reals) == Union(FiniteSet(2), Interval(-oo, 0, True, True)) + + assert solveset( + Piecewise((x + 1, x > 0), (I, True)) - I, x, S.Reals + ) == Interval(-oo, 0) + + assert solveset(Piecewise((x - 1, Ne(x, I)), (x, True)), x) == FiniteSet(1) + + # issue 19718 + g = Piecewise((1, x > 10), (0, True)) + assert solveset(g > 0, x, S.Reals) == Interval.open(10, oo) + + from sympy.logic.boolalg import BooleanTrue + f = BooleanTrue() + assert solveset(f, x, domain=Interval(-3, 10)) == Interval(-3, 10) + + # issue 20552 + f = Piecewise((0, Eq(x, 0)), (x**2/Abs(x), True)) + g = Piecewise((0, Eq(x, pi)), ((x - pi)/sin(x), True)) + assert solveset(f, x, domain=S.Reals) == FiniteSet(0) + assert solveset(g) == FiniteSet(pi) + + +def test_solveset_complex_polynomial(): + assert solveset_complex(a*x**2 + b*x + c, x) == \ + FiniteSet(-b/(2*a) - sqrt(-4*a*c + b**2)/(2*a), + -b/(2*a) + sqrt(-4*a*c + b**2)/(2*a)) + + assert solveset_complex(x - y**3, y) == FiniteSet( + (-x**Rational(1, 3))/2 + I*sqrt(3)*x**Rational(1, 3)/2, + x**Rational(1, 3), + (-x**Rational(1, 3))/2 - I*sqrt(3)*x**Rational(1, 3)/2) + + assert solveset_complex(x + 1/x - 1, x) == \ + FiniteSet(S.Half + I*sqrt(3)/2, S.Half - I*sqrt(3)/2) + + +def test_sol_zero_complex(): + assert solveset_complex(0, x) is S.Complexes + + +def test_solveset_complex_rational(): + assert solveset_complex((x - 1)*(x - I)/(x - 3), x) == \ + FiniteSet(1, I) + + assert solveset_complex((x - y**3)/((y**2)*sqrt(1 - y**2)), x) == \ + FiniteSet(y**3) + assert solveset_complex(-x**2 - I, x) == \ + FiniteSet(-sqrt(2)/2 + sqrt(2)*I/2, sqrt(2)/2 - sqrt(2)*I/2) + + +def test_solve_quintics(): + skip("This test is too slow") + f = x**5 - 110*x**3 - 55*x**2 + 2310*x + 979 + s = solveset_complex(f, x) + for root in s: + res = f.subs(x, root.n()).n() + assert tn(res, 0) + + f = x**5 + 15*x + 12 + s = solveset_complex(f, x) + for root in s: + res = f.subs(x, root.n()).n() + assert tn(res, 0) + + +def test_solveset_complex_exp(): + assert dumeq(solveset_complex(exp(x) - 1, x), + imageset(Lambda(n, I*2*n*pi), S.Integers)) + assert dumeq(solveset_complex(exp(x) - I, x), + imageset(Lambda(n, I*(2*n*pi + pi/2)), S.Integers)) + assert solveset_complex(1/exp(x), x) == S.EmptySet + assert dumeq(solveset_complex(sinh(x).rewrite(exp), x), + imageset(Lambda(n, n*pi*I), S.Integers)) + + +def test_solveset_real_exp(): + assert solveset(Eq((-2)**x, 4), x, S.Reals) == FiniteSet(2) + assert solveset(Eq(-2**x, 4), x, S.Reals) == S.EmptySet + assert solveset(Eq((-3)**x, 27), x, S.Reals) == S.EmptySet + assert solveset(Eq((-5)**(x+1), 625), x, S.Reals) == FiniteSet(3) + assert solveset(Eq(2**(x-3), -16), x, S.Reals) == S.EmptySet + assert solveset(Eq((-3)**(x - 3), -3**39), x, S.Reals) == FiniteSet(42) + assert solveset(Eq(2**x, y), x, S.Reals) == Intersection(S.Reals, FiniteSet(log(y)/log(2))) + + assert invert_real((-2)**(2*x) - 16, 0, x) == (x, FiniteSet(2)) + + +def test_solve_complex_log(): + assert solveset_complex(log(x), x) == FiniteSet(1) + assert solveset_complex(1 - log(a + 4*x**2), x) == \ + FiniteSet(-sqrt(-a + E)/2, sqrt(-a + E)/2) + + +def test_solve_complex_sqrt(): + assert solveset_complex(sqrt(5*x + 6) - 2 - x, x) == \ + FiniteSet(-S.One, S(2)) + assert solveset_complex(sqrt(5*x + 6) - (2 + 2*I) - x, x) == \ + FiniteSet(-S(2), 3 - 4*I) + assert solveset_complex(4*x*(1 - a * sqrt(x)), x) == \ + FiniteSet(S.Zero, 1 / a ** 2) + + +def test_solveset_complex_tan(): + s = solveset_complex(tan(x).rewrite(exp), x) + assert dumeq(s, imageset(Lambda(n, pi*n), S.Integers) - \ + imageset(Lambda(n, pi*n + pi/2), S.Integers)) + + +@_both_exp_pow +def test_solve_trig(): + assert dumeq(solveset_real(sin(x), x), + Union(imageset(Lambda(n, 2*pi*n), S.Integers), + imageset(Lambda(n, 2*pi*n + pi), S.Integers))) + + assert dumeq(solveset_real(sin(x) - 1, x), + imageset(Lambda(n, 2*pi*n + pi/2), S.Integers)) + + assert dumeq(solveset_real(cos(x), x), + Union(imageset(Lambda(n, 2*pi*n + pi/2), S.Integers), + imageset(Lambda(n, 2*pi*n + pi*Rational(3, 2)), S.Integers))) + + assert dumeq(solveset_real(sin(x) + cos(x), x), + Union(imageset(Lambda(n, 2*n*pi + pi*Rational(3, 4)), S.Integers), + imageset(Lambda(n, 2*n*pi + pi*Rational(7, 4)), S.Integers))) + + assert solveset_real(sin(x)**2 + cos(x)**2, x) == S.EmptySet + + assert dumeq(solveset_complex(cos(x) - S.Half, x), + Union(imageset(Lambda(n, 2*n*pi + pi*Rational(5, 3)), S.Integers), + imageset(Lambda(n, 2*n*pi + pi/3), S.Integers))) + + assert dumeq(solveset(sin(y + a) - sin(y), a, domain=S.Reals), + ConditionSet(a, (S(-1) <= sin(y)) & (sin(y) <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi - y + asin(sin(y))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - y - asin(sin(y)) + pi), S.Integers)))) + + assert dumeq(solveset_real(sin(2*x)*cos(x) + cos(2*x)*sin(x)-1, x), + ImageSet(Lambda(n, n*pi*Rational(2, 3) + pi/6), S.Integers)) + + assert dumeq(solveset_real(2*tan(x)*sin(x) + 1, x), Union( + ImageSet(Lambda(n, 2*n*pi + atan(sqrt(2)*sqrt(-1 + sqrt(17))/ + (1 - sqrt(17))) + pi), S.Integers), + ImageSet(Lambda(n, 2*n*pi - atan(sqrt(2)*sqrt(-1 + sqrt(17))/ + (1 - sqrt(17))) + pi), S.Integers))) + + assert dumeq(solveset_real(cos(2*x)*cos(4*x) - 1, x), + ImageSet(Lambda(n, n*pi), S.Integers)) + + assert dumeq(solveset(sin(x/10) + Rational(3, 4)), Union( + ImageSet(Lambda(n, 20*n*pi - 10*asin(S(3)/4) + 20*pi), S.Integers), + ImageSet(Lambda(n, 20*n*pi + 10*asin(S(3)/4) + 10*pi), S.Integers))) + + assert dumeq(solveset(cos(x/15) + cos(x/5)), Union( + ImageSet(Lambda(n, 30*n*pi + 15*pi/2), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 45*pi/2), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 75*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 45*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 105*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 15*pi/4), S.Integers))) + + assert dumeq(solveset(sec(sqrt(2)*x/3) + 5), Union( + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*pi - asec(-5))/2), S.Integers), + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*pi + asec(-5))/2), S.Integers))) + + assert dumeq(simplify(solveset(tan(pi*x) - cot(pi/2*x))), Union( + ImageSet(Lambda(n, 4*n + 1), S.Integers), + ImageSet(Lambda(n, 4*n + 3), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(7, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(5, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(11, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(1, 3)), S.Integers))) + + assert dumeq(solveset(cos(9*x)), Union( + ImageSet(Lambda(n, 2*n*pi/9 + pi/18), S.Integers), + ImageSet(Lambda(n, 2*n*pi/9 + pi/6), S.Integers))) + + assert dumeq(solveset(sin(8*x) + cot(12*x), x, S.Reals), Union( + ImageSet(Lambda(n, n*pi/2 + pi/8), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 3*pi/8), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 5*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 3*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 7*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + pi/16), S.Integers))) + + # This is the only remaining solveset test that actually ends up being solved + # by _solve_trig2(). All others are handled by the improved _solve_trig1. + assert dumeq(solveset_real(2*cos(x)*cos(2*x) - 1, x), + Union(ImageSet(Lambda(n, 2*n*pi + 2*atan(sqrt(-2*2**Rational(1, 3)*(67 + + 9*sqrt(57))**Rational(2, 3) + 8*2**Rational(2, 3) + 11*(67 + + 9*sqrt(57))**Rational(1, 3))/(3*(67 + 9*sqrt(57))**Rational(1, 6)))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - 2*atan(sqrt(-2*2**Rational(1, 3)*(67 + + 9*sqrt(57))**Rational(2, 3) + 8*2**Rational(2, 3) + 11*(67 + + 9*sqrt(57))**Rational(1, 3))/(3*(67 + 9*sqrt(57))**Rational(1, 6))) + + 2*pi), S.Integers))) + + # issue #16870 + assert dumeq(simplify(solveset(sin(x/180*pi) - S.Half, x, S.Reals)), Union( + ImageSet(Lambda(n, 360*n + 150), S.Integers), + ImageSet(Lambda(n, 360*n + 30), S.Integers))) + + +def test_solve_trig_hyp_by_inversion(): + n = Dummy('n') + assert solveset_real(sin(2*x + 3) - S(1)/2, x).dummy_eq(Union( + ImageSet(Lambda(n, n*pi - S(3)/2 + 13*pi/12), S.Integers), + ImageSet(Lambda(n, n*pi - S(3)/2 + 17*pi/12), S.Integers))) + assert solveset_complex(sin(2*x + 3) - S(1)/2, x).dummy_eq(Union( + ImageSet(Lambda(n, n*pi - S(3)/2 + 13*pi/12), S.Integers), + ImageSet(Lambda(n, n*pi - S(3)/2 + 17*pi/12), S.Integers))) + assert solveset_real(tan(x) - tan(pi/10), x).dummy_eq( + ImageSet(Lambda(n, n*pi + pi/10), S.Integers)) + assert solveset_complex(tan(x) - tan(pi/10), x).dummy_eq( + ImageSet(Lambda(n, n*pi + pi/10), S.Integers)) + + assert solveset_real(3*cosh(2*x) - 5, x) == FiniteSet( + -acosh(S(5)/3)/2, acosh(S(5)/3)/2) + assert solveset_complex(3*cosh(2*x) - 5, x).dummy_eq(Union( + ImageSet(Lambda(n, n*I*pi - acosh(S(5)/3)/2), S.Integers), + ImageSet(Lambda(n, n*I*pi + acosh(S(5)/3)/2), S.Integers))) + assert solveset_real(sinh(x - 3) - 2, x) == FiniteSet( + asinh(2) + 3) + assert solveset_complex(sinh(x - 3) - 2, x).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi + asinh(2) + 3), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi - asinh(2) + 3 + I*pi), S.Integers))) + + assert solveset_real(cos(sinh(x))-cos(pi/12), x).dummy_eq(Union( + ImageSet(Lambda(n, asinh(2*n*pi + pi/12)), S.Integers), + ImageSet(Lambda(n, asinh(2*n*pi + 23*pi/12)), S.Integers))) + assert solveset(cos(sinh(x))-cos(pi/12), x, Interval(2,3)) == \ + FiniteSet(asinh(23*pi/12), asinh(25*pi/12)) + assert solveset_real(cosh(x**2-1)-2, x) == FiniteSet( + -sqrt(1 + acosh(2)), sqrt(1 + acosh(2))) + + assert solveset_real(sin(x) - 2, x) == S.EmptySet # issue #17334 + assert solveset_real(cos(x) + 2, x) == S.EmptySet + assert solveset_real(sec(x), x) == S.EmptySet + assert solveset_real(csc(x), x) == S.EmptySet + assert solveset_real(cosh(x) + 1, x) == S.EmptySet + assert solveset_real(coth(x), x) == S.EmptySet + assert solveset_real(sech(x) - 2, x) == S.EmptySet + assert solveset_real(sech(x), x) == S.EmptySet + assert solveset_real(tanh(x) + 1, x) == S.EmptySet + assert solveset_complex(tanh(x), 1) == S.EmptySet + assert solveset_complex(coth(x), -1) == S.EmptySet + assert solveset_complex(sech(x), 0) == S.EmptySet + assert solveset_complex(csch(x), 0) == S.EmptySet + + assert solveset_real(abs(csch(x)) - 3, x) == FiniteSet(-acsch(3), acsch(3)) + + assert solveset_real(tanh(x**2 - 1) - exp(-9), x) == FiniteSet( + -sqrt(atanh(exp(-9)) + 1), sqrt(atanh(exp(-9)) + 1)) + + assert solveset_real(coth(log(x)) + 2, x) == FiniteSet(exp(-acoth(2))) + assert solveset_real(coth(exp(x)) + 2, x) == S.EmptySet + + assert solveset_complex(sinh(x) - I/2, x).dummy_eq(Union( + ImageSet(Lambda(n, 2*I*pi*n + 5*I*pi/6), S.Integers), + ImageSet(Lambda(n, 2*I*pi*n + I*pi/6), S.Integers))) + assert solveset_complex(sinh(x/10) + Rational(3, 4), x).dummy_eq(Union( + ImageSet(Lambda(n, 20*n*I*pi - 10*asinh(S(3)/4)), S.Integers), + ImageSet(Lambda(n, 20*n*I*pi + 10*asinh(S(3)/4) + 10*I*pi), S.Integers))) + assert solveset_complex(sech(sqrt(2)*x/3) + 5, x).dummy_eq(Union( + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*I*pi - asech(-5))/2), S.Integers), + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*I*pi + asech(-5))/2), S.Integers))) + assert solveset_complex(cosh(9*x), x).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi/9 + I*pi/18), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi/9 + I*pi/6), S.Integers))) + + eq = (x**5 -4*x + 1).subs(x, coth(z)) + assert solveset(eq, z, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 0))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 1))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 2))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 3))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 4))), S.Integers))) + assert solveset(eq, z, S.Reals) == FiniteSet( + acoth(CRootOf(x**5 - 4*x + 1, 0)), acoth(CRootOf(x**5 - 4*x + 1, 2))) + + eq = ((x-sqrt(3)/2)*(x+2)).expand().subs(x, cos(x)) + assert solveset(eq, x, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi - acos(-2)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(-2)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + 11*pi/6), S.Integers))) + assert solveset(eq, x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + 11*pi/6), S.Integers))) + + assert solveset((1+sec(sqrt(3)*x+4)**2)/(1-sec(sqrt(3)*x+4))).dummy_eq(Union( + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 - asec(I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 + asec(I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 - asec(-I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 + asec(-I))/3), S.Integers))) + + assert all_close(solveset(tan(3.14*x)**(S(3)/2)-5.678, x, Interval(0, 3)), + FiniteSet(0.403301114561067, 0.403301114561067 + 0.318471337579618*pi, + 0.403301114561067 + 0.636942675159236*pi)) + + +def test_old_trig_issues(): + # issues #9606 / #9531: + assert solveset(sinh(x), x, S.Reals) == FiniteSet(0) + assert solveset(sinh(x), x, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + I*pi), S.Integers))) + + # issues #11218 / #18427 + assert solveset(sin(pi*x), x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, (2*n*pi + pi)/pi), S.Integers), + ImageSet(Lambda(n, 2*n), S.Integers))) + assert solveset(sin(pi*x), x).dummy_eq(Union( + ImageSet(Lambda(n, (2*n*pi + pi)/pi), S.Integers), + ImageSet(Lambda(n, 2*n), S.Integers))) + + # issue #17543 + assert solveset(I*cot(8*x - 8*E), x).dummy_eq( + ImageSet(Lambda(n, pi*n/8 - 13*pi/16 + E), S.Integers)) + + # issue #20798 + assert all_close(solveset(cos(2*x) - 0.5, x, Interval(0, 2*pi)), FiniteSet( + 0.523598775598299, -0.523598775598299 + pi, + -0.523598775598299 + 2*pi, 0.523598775598299 + pi)) + sol = Union(ImageSet(Lambda(n, n*pi - 0.523598775598299), S.Integers), + ImageSet(Lambda(n, n*pi + 0.523598775598299), S.Integers)) + ret = solveset(cos(2*x) - 0.5, x, S.Reals) + # replace Dummy n by the regular Symbol n to allow all_close comparison. + ret = ret.subs(ret.atoms(Dummy).pop(), n) + assert all_close(ret, sol) + ret = solveset(cos(2*x) - 0.5, x, S.Complexes) + ret = ret.subs(ret.atoms(Dummy).pop(), n) + assert all_close(ret, sol) + + # issue #21296 / #17667 + assert solveset(tan(x)-sqrt(2), x, Interval(0, pi/2)) == FiniteSet(atan(sqrt(2))) + assert solveset(tan(x)-pi, x, Interval(0, pi/2)) == FiniteSet(atan(pi)) + + # issue #17667 + # not yet working properly: + # solveset(cos(x)-y, x, Interval(0, pi)) + assert solveset(cos(x)-y, x, S.Reals).dummy_eq( + ConditionSet(x,(S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi - acos(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(y)), S.Integers)))) + + # issue #17579 + # Valid result, but the intersection could potentially be simplified. + assert solveset(sin(log(x)), x, Interval(0,1, True, False)).dummy_eq( + Union(Intersection(ImageSet(Lambda(n, exp(2*n*pi)), S.Integers), Interval.Lopen(0, 1)), + Intersection(ImageSet(Lambda(n, exp(2*n*pi + pi)), S.Integers), Interval.Lopen(0, 1)))) + + # issue #17334 + assert solveset(sin(x) - sin(1), x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + 1), S.Integers), + ImageSet(Lambda(n, 2*n*pi - 1 + pi), S.Integers))) + assert solveset(sin(x) - sqrt(5)/3, x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + asin(sqrt(5)/3)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - asin(sqrt(5)/3) + pi), S.Integers))) + assert solveset(sinh(x)-cosh(2), x, S.Reals) == FiniteSet(asinh(cosh(2))) + + # issue 9825 + assert solveset(Eq(tan(x), y), x, domain=S.Reals).dummy_eq( + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + atan(y)), S.Integers))) + r = Symbol('r', real=True) + assert solveset(Eq(tan(x), r), x, domain=S.Reals).dummy_eq( + ImageSet(Lambda(n, n*pi + atan(r)), S.Integers)) + + +def test_solve_hyperbolic(): + # actual solver: _solve_trig1 + n = Dummy('n') + assert solveset(sinh(x) + cosh(x), x) == S.EmptySet + assert solveset(sinh(x) + cos(x), x) == ConditionSet(x, + Eq(cos(x) + sinh(x), 0), S.Complexes) + assert solveset_real(sinh(x) + sech(x), x) == FiniteSet( + log(sqrt(sqrt(5) - 2))) + assert solveset_real(cosh(2*x) + 2*sinh(x) - 5, x) == FiniteSet( + log(-2 + sqrt(5)), log(1 + sqrt(2))) + assert solveset_real((coth(x) + sinh(2*x))/cosh(x) - 3, x) == FiniteSet( + log(S.Half + sqrt(5)/2), log(1 + sqrt(2))) + assert solveset_real(cosh(x)*sinh(x) - 2, x) == FiniteSet( + log(4 + sqrt(17))/2) + assert solveset_real(sinh(x) + tanh(x) - 1, x) == FiniteSet( + log(sqrt(2)/2 + sqrt(-S(1)/2 + sqrt(2)))) + + assert dumeq(solveset_complex(sinh(x) + sech(x), x), Union( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(-2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi/2) + log(sqrt(2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi) + log(sqrt(-2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - pi/2) + log(sqrt(2 + sqrt(5)))), S.Integers))) + + assert dumeq(solveset(cosh(x/15) + cosh(x/5)), Union( + ImageSet(Lambda(n, 15*I*(2*n*pi + pi/2)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - pi/2)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - 3*pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi + 3*pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi + pi/4)), S.Integers))) + + assert dumeq(solveset(tanh(pi*x) - coth(pi/2*x)), Union( + ImageSet(Lambda(n, 2*I*(2*n*pi + pi/2)/pi), S.Integers), + ImageSet(Lambda(n, 2*I*(2*n*pi - pi/2)/pi), S.Integers))) + + # issues #18490 / #19489 + assert solveset(cosh(x) + cosh(3*x) - cosh(5*x), x, S.Reals + ).dummy_eq(ConditionSet(x, + Eq(cosh(x) + cosh(3*x) - cosh(5*x), 0), S.Reals)) + assert solveset(sinh(8*x) + coth(12*x)).dummy_eq( + ConditionSet(x, Eq(sinh(8*x) + coth(12*x), 0), S.Complexes)) + + +def test_solve_trig_hyp_symbolic(): + # actual solver: invert_trig_hyp + assert dumeq(solveset(sin(a*x), x), ConditionSet(x, Ne(a, 0), Union( + ImageSet(Lambda(n, (2*n*pi + pi)/a), S.Integers), + ImageSet(Lambda(n, 2*n*pi/a), S.Integers)))) + + assert dumeq(solveset(cosh(x/a), x), ConditionSet(x, Ne(a, 0), Union( + ImageSet(Lambda(n, a*(2*n*I*pi + I*pi/2)), S.Integers), + ImageSet(Lambda(n, a*(2*n*I*pi + 3*I*pi/2)), S.Integers)))) + + assert dumeq(solveset(sin(2*sqrt(3)/3*a**2/(b*pi)*x) + + cos(4*sqrt(3)/3*a**2/(b*pi)*x), x), + ConditionSet(x, Ne(b, 0) & Ne(a**2, 0), Union( + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi + pi/2)/(2*a**2)), S.Integers), + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi - 5*pi/6)/(2*a**2)), S.Integers), + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi - pi/6)/(2*a**2)), S.Integers)))) + + assert dumeq(solveset(cosh((a**2 + 1)*x) - 3, x), ConditionSet( + x, Ne(a**2 + 1, 0), Union( + ImageSet(Lambda(n, (2*n*I*pi - acosh(3))/(a**2 + 1)), S.Integers), + ImageSet(Lambda(n, (2*n*I*pi + acosh(3))/(a**2 + 1)), S.Integers)))) + + ar = Symbol('ar', real=True) + assert solveset(cosh((ar**2 + 1)*x) - 2, x, S.Reals) == FiniteSet( + -acosh(2)/(ar**2 + 1), acosh(2)/(ar**2 + 1)) + + # actual solver: _solve_trig1 + assert dumeq(simplify(solveset(cot((1 + I)*x) - cot((3 + 3*I)*x), x)), Union( + ImageSet(Lambda(n, pi*(1 - I)*(4*n + 1)/4), S.Integers), + ImageSet(Lambda(n, pi*(1 - I)*(4*n - 1)/4), S.Integers))) + + +def test_issue_9616(): + assert dumeq(solveset(sinh(x) + tanh(x) - 1, x), Union( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - atan(sqrt(2)*sqrt(S.Half + sqrt(2))) + pi) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi) + log(-sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - pi + atan(sqrt(2)*sqrt(S.Half + sqrt(2)))) + + log(sqrt(1 + sqrt(2)))), S.Integers))) + f1 = (sinh(x)).rewrite(exp) + f2 = (tanh(x)).rewrite(exp) + assert dumeq(solveset(f1 + f2 - 1, x), Union( + Complement(ImageSet( + Lambda(n, I*(2*n*pi + pi) + log(-sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement(ImageSet(Lambda(n, I*(2*n*pi - pi + atan(sqrt(2)*sqrt(S.Half + sqrt(2)))) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement(ImageSet(Lambda(n, I*(2*n*pi - atan(sqrt(2)*sqrt(S.Half + sqrt(2))) + pi) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)))) + + +def test_solve_invalid_sol(): + assert 0 not in solveset_real(sin(x)/x, x) + assert 0 not in solveset_complex((exp(x) - 1)/x, x) + + +@XFAIL +def test_solve_trig_simplified(): + n = Dummy('n') + assert dumeq(solveset_real(sin(x), x), + imageset(Lambda(n, n*pi), S.Integers)) + + assert dumeq(solveset_real(cos(x), x), + imageset(Lambda(n, n*pi + pi/2), S.Integers)) + + assert dumeq(solveset_real(cos(x) + sin(x), x), + imageset(Lambda(n, n*pi - pi/4), S.Integers)) + + +@XFAIL +def test_solve_lambert(): + assert solveset_real(x*exp(x) - 1, x) == FiniteSet(LambertW(1)) + assert solveset_real(exp(x) + x, x) == FiniteSet(-LambertW(1)) + assert solveset_real(x + 2**x, x) == \ + FiniteSet(-LambertW(log(2))/log(2)) + + # issue 4739 + ans = solveset_real(3*x + 5 + 2**(-5*x + 3), x) + assert ans == FiniteSet(Rational(-5, 3) + + LambertW(-10240*2**Rational(1, 3)*log(2)/3)/(5*log(2))) + + eq = 2*(3*x + 4)**5 - 6*7**(3*x + 9) + result = solveset_real(eq, x) + ans = FiniteSet((log(2401) + + 5*LambertW(-log(7**(7*3**Rational(1, 5)/5))))/(3*log(7))/-1) + assert result == ans + assert solveset_real(eq.expand(), x) == result + + assert solveset_real(5*x - 1 + 3*exp(2 - 7*x), x) == \ + FiniteSet(Rational(1, 5) + LambertW(-21*exp(Rational(3, 5))/5)/7) + + assert solveset_real(2*x + 5 + log(3*x - 2), x) == \ + FiniteSet(Rational(2, 3) + LambertW(2*exp(Rational(-19, 3))/3)/2) + + assert solveset_real(3*x + log(4*x), x) == \ + FiniteSet(LambertW(Rational(3, 4))/3) + + assert solveset_real(x**x - 2) == FiniteSet(exp(LambertW(log(2)))) + + a = Symbol('a') + assert solveset_real(-a*x + 2*x*log(x), x) == FiniteSet(exp(a/2)) + a = Symbol('a', real=True) + assert solveset_real(a/x + exp(x/2), x) == \ + FiniteSet(2*LambertW(-a/2)) + assert solveset_real((a/x + exp(x/2)).diff(x), x) == \ + FiniteSet(4*LambertW(sqrt(2)*sqrt(a)/4)) + + # coverage test + assert solveset_real(tanh(x + 3)*tanh(x - 3) - 1, x) is S.EmptySet + + assert solveset_real((x**2 - 2*x + 1).subs(x, log(x) + 3*x), x) == \ + FiniteSet(LambertW(3*S.Exp1)/3) + assert solveset_real((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1), x) == \ + FiniteSet(LambertW(3*exp(-sqrt(2)))/3, LambertW(3*exp(sqrt(2)))/3) + assert solveset_real((x**2 - 2*x - 2).subs(x, log(x) + 3*x), x) == \ + FiniteSet(LambertW(3*exp(1 + sqrt(3)))/3, LambertW(3*exp(-sqrt(3) + 1))/3) + assert solveset_real(x*log(x) + 3*x + 1, x) == \ + FiniteSet(exp(-3 + LambertW(-exp(3)))) + eq = (x*exp(x) - 3).subs(x, x*exp(x)) + assert solveset_real(eq, x) == \ + FiniteSet(LambertW(3*exp(-LambertW(3)))) + + assert solveset_real(3*log(a**(3*x + 5)) + a**(3*x + 5), x) == \ + FiniteSet(-((log(a**5) + LambertW(Rational(1, 3)))/(3*log(a)))) + p = symbols('p', positive=True) + assert solveset_real(3*log(p**(3*x + 5)) + p**(3*x + 5), x) == \ + FiniteSet( + log((-3**Rational(1, 3) - 3**Rational(5, 6)*I)*LambertW(Rational(1, 3))**Rational(1, 3)/(2*p**Rational(5, 3)))/log(p), + log((-3**Rational(1, 3) + 3**Rational(5, 6)*I)*LambertW(Rational(1, 3))**Rational(1, 3)/(2*p**Rational(5, 3)))/log(p), + log((3*LambertW(Rational(1, 3))/p**5)**(1/(3*log(p)))),) # checked numerically + # check collection + b = Symbol('b') + eq = 3*log(a**(3*x + 5)) + b*log(a**(3*x + 5)) + a**(3*x + 5) + assert solveset_real(eq, x) == FiniteSet( + -((log(a**5) + LambertW(1/(b + 3)))/(3*log(a)))) + + # issue 4271 + assert solveset_real((a/x + exp(x/2)).diff(x, 2), x) == FiniteSet( + 6*LambertW((-1)**Rational(1, 3)*a**Rational(1, 3)/3)) + + assert solveset_real(x**3 - 3**x, x) == \ + FiniteSet(-3/log(3)*LambertW(-log(3)/3)) + assert solveset_real(3**cos(x) - cos(x)**3) == FiniteSet( + acos(-3*LambertW(-log(3)/3)/log(3))) + + assert solveset_real(x**2 - 2**x, x) == \ + solveset_real(-x**2 + 2**x, x) + + assert solveset_real(3*log(x) - x*log(3)) == FiniteSet( + -3*LambertW(-log(3)/3)/log(3), + -3*LambertW(-log(3)/3, -1)/log(3)) + + assert solveset_real(LambertW(2*x) - y) == FiniteSet( + y*exp(y)/2) + + +@XFAIL +def test_other_lambert(): + a = Rational(6, 5) + assert solveset_real(x**a - a**x, x) == FiniteSet( + a, -a*LambertW(-log(a)/a)/log(a)) + + +@_both_exp_pow +def test_solveset(): + f = Function('f') + raises(ValueError, lambda: solveset(x + y)) + assert solveset(x, 1) == S.EmptySet + assert solveset(f(1)**2 + y + 1, f(1) + ) == FiniteSet(-sqrt(-y - 1), sqrt(-y - 1)) + assert solveset(f(1)**2 - 1, f(1), S.Reals) == FiniteSet(-1, 1) + assert solveset(f(1)**2 + 1, f(1)) == FiniteSet(-I, I) + assert solveset(x - 1, 1) == FiniteSet(x) + assert solveset(sin(x) - cos(x), sin(x)) == FiniteSet(cos(x)) + + assert solveset(0, domain=S.Reals) == S.Reals + assert solveset(1) == S.EmptySet + assert solveset(True, domain=S.Reals) == S.Reals # issue 10197 + assert solveset(False, domain=S.Reals) == S.EmptySet + + assert solveset(exp(x) - 1, domain=S.Reals) == FiniteSet(0) + assert solveset(exp(x) - 1, x, S.Reals) == FiniteSet(0) + assert solveset(Eq(exp(x), 1), x, S.Reals) == FiniteSet(0) + assert solveset(exp(x) - 1, exp(x), S.Reals) == FiniteSet(1) + A = Indexed('A', x) + assert solveset(A - 1, A, S.Reals) == FiniteSet(1) + + assert solveset(x - 1 >= 0, x, S.Reals) == Interval(1, oo) + assert solveset(exp(x) - 1 >= 0, x, S.Reals) == Interval(0, oo) + + assert dumeq(solveset(exp(x) - 1, x), imageset(Lambda(n, 2*I*pi*n), S.Integers)) + assert dumeq(solveset(Eq(exp(x), 1), x), imageset(Lambda(n, 2*I*pi*n), + S.Integers)) + # issue 13825 + assert solveset(x**2 + f(0) + 1, x) == {-sqrt(-f(0) - 1), sqrt(-f(0) - 1)} + + # issue 19977 + assert solveset(atan(log(x)) > 0, x, domain=Interval.open(0, oo)) == Interval.open(1, oo) + + +@_both_exp_pow +def test_multi_exp(): + k1, k2, k3 = symbols('k1, k2, k3') + assert dumeq(solveset(exp(exp(x)) - 5, x),\ + imageset(Lambda(((k1, n),), I*(2*k1*pi + arg(2*n*I*pi + log(5))) + log(Abs(2*n*I*pi + log(5)))),\ + ProductSet(S.Integers, S.Integers))) + assert dumeq(solveset((d*exp(exp(a*x + b)) + c), x),\ + imageset(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k1, n),), \ + I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))), \ + ProductSet(S.Integers, S.Integers)))) + + assert dumeq(solveset((d*exp(exp(exp(a*x + b))) + c), x),\ + imageset(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k2, k1, n),), \ + I*(2*k2*pi + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + \ + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + \ + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))))), \ + ProductSet(S.Integers, S.Integers, S.Integers)))) + + assert dumeq(solveset((d*exp(exp(exp(exp(a*x + b)))) + c), x),\ + ImageSet(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k3, k2, k1, n),), \ + I*(2*k3*pi + arg(I*(2*k2*pi + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + \ + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + \ + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))))) + log(Abs(I*(2*k2*pi + \ + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + \ + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))))))), \ + ProductSet(S.Integers, S.Integers, S.Integers, S.Integers)))) + + +def test__solveset_multi(): + from sympy.solvers.solveset import _solveset_multi + from sympy.sets import Reals + + # Basic univariate case: + assert _solveset_multi([x**2-1], [x], [S.Reals]) == FiniteSet((1,), (-1,)) + + # Linear systems of two equations + assert _solveset_multi([x+y, x+1], [x, y], [Reals, Reals]) == FiniteSet((-1, 1)) + assert _solveset_multi([x+y, x+1], [y, x], [Reals, Reals]) == FiniteSet((1, -1)) + assert _solveset_multi([x+y, x-y-1], [x, y], [Reals, Reals]) == FiniteSet((S(1)/2, -S(1)/2)) + assert _solveset_multi([x-1, y-2], [x, y], [Reals, Reals]) == FiniteSet((1, 2)) + # assert dumeq(_solveset_multi([x+y], [x, y], [Reals, Reals]), ImageSet(Lambda(x, (x, -x)), Reals)) + assert dumeq(_solveset_multi([x+y], [x, y], [Reals, Reals]), Union( + ImageSet(Lambda(((x,),), (x, -x)), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (-y, y)), ProductSet(Reals)))) + assert _solveset_multi([x+y, x+y+1], [x, y], [Reals, Reals]) == S.EmptySet + assert _solveset_multi([x+y, x-y, x-1], [x, y], [Reals, Reals]) == S.EmptySet + assert _solveset_multi([x+y, x-y, x-1], [y, x], [Reals, Reals]) == S.EmptySet + + # Systems of three equations: + assert _solveset_multi([x+y+z-1, x+y-z-2, x-y-z-3], [x, y, z], [Reals, + Reals, Reals]) == FiniteSet((2, -S.Half, -S.Half)) + + # Nonlinear systems: + from sympy.abc import theta + assert _solveset_multi([x**2+y**2-2, x+y], [x, y], [Reals, Reals]) == FiniteSet((-1, 1), (1, -1)) + assert _solveset_multi([x**2-1, y], [x, y], [Reals, Reals]) == FiniteSet((1, 0), (-1, 0)) + #assert _solveset_multi([x**2-y**2], [x, y], [Reals, Reals]) == Union( + # ImageSet(Lambda(x, (x, -x)), Reals), ImageSet(Lambda(x, (x, x)), Reals)) + assert dumeq(_solveset_multi([x**2-y**2], [x, y], [Reals, Reals]), Union( + ImageSet(Lambda(((x,),), (x, -Abs(x))), ProductSet(Reals)), + ImageSet(Lambda(((x,),), (x, Abs(x))), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (-Abs(y), y)), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (Abs(y), y)), ProductSet(Reals)))) + assert _solveset_multi([r*cos(theta)-1, r*sin(theta)], [theta, r], + [Interval(0, pi), Interval(-1, 1)]) == FiniteSet((0, 1), (pi, -1)) + assert _solveset_multi([r*cos(theta)-1, r*sin(theta)], [r, theta], + [Interval(0, 1), Interval(0, pi)]) == FiniteSet((1, 0)) + assert _solveset_multi([r*cos(theta)-r, r*sin(theta)], [r, theta], + [Interval(0, 1), Interval(0, pi)]) == Union( + ImageSet(Lambda(((r,),), (r, 0)), + ImageSet(Lambda(r, (r,)), Interval(0, 1))), + ImageSet(Lambda(((theta,),), (0, theta)), + ImageSet(Lambda(theta, (theta,)), Interval(0, pi)))) + + +def test_conditionset(): + assert solveset(Eq(sin(x)**2 + cos(x)**2, 1), x, domain=S.Reals + ) is S.Reals + + assert solveset(Eq(x**2 + x*sin(x), 1), x, domain=S.Reals + ).dummy_eq(ConditionSet(x, Eq(x**2 + x*sin(x) - 1, 0), S.Reals)) + + assert dumeq(solveset(Eq(-I*(exp(I*x) - exp(-I*x))/2, 1), x + ), imageset(Lambda(n, 2*n*pi + pi/2), S.Integers)) + + assert solveset(x + sin(x) > 1, x, domain=S.Reals + ).dummy_eq(ConditionSet(x, x + sin(x) > 1, S.Reals)) + + assert solveset(Eq(sin(Abs(x)), x), x, domain=S.Reals + ).dummy_eq(ConditionSet(x, Eq(-x + sin(Abs(x)), 0), S.Reals)) + + assert solveset(y**x-z, x, S.Reals + ).dummy_eq(ConditionSet(x, Eq(y**x - z, 0), S.Reals)) + + +@XFAIL +def test_conditionset_equality(): + ''' Checking equality of different representations of ConditionSet''' + assert solveset(Eq(tan(x), y), x) == ConditionSet(x, Eq(tan(x), y), S.Complexes) + + +def test_solveset_domain(): + assert solveset(x**2 - x - 6, x, Interval(0, oo)) == FiniteSet(3) + assert solveset(x**2 - 1, x, Interval(0, oo)) == FiniteSet(1) + assert solveset(x**4 - 16, x, Interval(0, 10)) == FiniteSet(2) + + +def test_improve_coverage(): + solution = solveset(exp(x) + sin(x), x, S.Reals) + unsolved_object = ConditionSet(x, Eq(exp(x) + sin(x), 0), S.Reals) + assert solution.dummy_eq(unsolved_object) + + +def test_issue_9522(): + expr1 = Eq(1/(x**2 - 4) + x, 1/(x**2 - 4) + 2) + expr2 = Eq(1/x + x, 1/x) + + assert solveset(expr1, x, S.Reals) is S.EmptySet + assert solveset(expr2, x, S.Reals) is S.EmptySet + + +def test_solvify(): + assert solvify(x**2 + 10, x, S.Reals) == [] + assert solvify(x**3 + 1, x, S.Complexes) == [-1, S.Half - sqrt(3)*I/2, + S.Half + sqrt(3)*I/2] + assert solvify(log(x), x, S.Reals) == [1] + assert solvify(cos(x), x, S.Reals) == [pi/2, pi*Rational(3, 2)] + assert solvify(sin(x) + 1, x, S.Reals) == [pi*Rational(3, 2)] + raises(NotImplementedError, lambda: solvify(sin(exp(x)), x, S.Complexes)) + + +def test_solvify_piecewise(): + p1 = Piecewise((0, x < -1), (x**2, x <= 1), (log(x), True)) + p2 = Piecewise((0, x < -10), (x**2 + 5*x - 6, x >= -9)) + p3 = Piecewise((0, Eq(x, 0)), (x**2/Abs(x), True)) + p4 = Piecewise((0, Eq(x, pi)), ((x - pi)/sin(x), True)) + + # issue 21079 + assert solvify(p1, x, S.Reals) == [0] + assert solvify(p2, x, S.Reals) == [-6, 1] + assert solvify(p3, x, S.Reals) == [0] + assert solvify(p4, x, S.Reals) == [pi] + + +def test_abs_invert_solvify(): + + x = Symbol('x',positive=True) + assert solvify(sin(Abs(x)), x, S.Reals) == [0, pi] + x = Symbol('x') + assert solvify(sin(Abs(x)), x, S.Reals) is None + + +def test_linear_eq_to_matrix(): + assert linear_eq_to_matrix(0, x) == (Matrix([[0]]), Matrix([[0]])) + assert linear_eq_to_matrix(1, x) == (Matrix([[0]]), Matrix([[-1]])) + + # integer coefficients + eqns1 = [2*x + y - 2*z - 3, x - y - z, x + y + 3*z - 12] + eqns2 = [Eq(3*x + 2*y - z, 1), Eq(2*x - 2*y + 4*z, -2), -2*x + y - 2*z] + + A, B = linear_eq_to_matrix(eqns1, x, y, z) + assert A == Matrix([[2, 1, -2], [1, -1, -1], [1, 1, 3]]) + assert B == Matrix([[3], [0], [12]]) + + A, B = linear_eq_to_matrix(eqns2, x, y, z) + assert A == Matrix([[3, 2, -1], [2, -2, 4], [-2, 1, -2]]) + assert B == Matrix([[1], [-2], [0]]) + + # Pure symbolic coefficients + eqns3 = [a*b*x + b*y + c*z - d, e*x + d*x + f*y + g*z - h, i*x + j*y + k*z - l] + A, B = linear_eq_to_matrix(eqns3, x, y, z) + assert A == Matrix([[a*b, b, c], [d + e, f, g], [i, j, k]]) + assert B == Matrix([[d], [h], [l]]) + + # raise Errors if + # 1) no symbols are given + raises(ValueError, lambda: linear_eq_to_matrix(eqns3)) + # 2) there are duplicates + raises(ValueError, lambda: linear_eq_to_matrix(eqns3, [x, x, y])) + # 3) a nonlinear term is detected in the original expression + raises(NonlinearError, lambda: linear_eq_to_matrix(Eq(1/x + x, 1/x), [x])) + raises(NonlinearError, lambda: linear_eq_to_matrix([x**2], [x])) + raises(NonlinearError, lambda: linear_eq_to_matrix([x*y], [x, y])) + # 4) Eq being used to represent equations autoevaluates + # (use unevaluated Eq instead) + raises(ValueError, lambda: linear_eq_to_matrix(Eq(x, x), x)) + raises(ValueError, lambda: linear_eq_to_matrix(Eq(x, x + 1), x)) + + + # if non-symbols are passed, the user is responsible for interpreting + assert linear_eq_to_matrix([x], [1/x]) == (Matrix([[0]]), Matrix([[-x]])) + + # issue 15195 + assert linear_eq_to_matrix(x + y*(z*(3*x + 2) + 3), x) == ( + Matrix([[3*y*z + 1]]), Matrix([[-y*(2*z + 3)]])) + assert linear_eq_to_matrix(Matrix( + [[a*x + b*y - 7], [5*x + 6*y - c]]), x, y) == ( + Matrix([[a, b], [5, 6]]), Matrix([[7], [c]])) + + # issue 15312 + assert linear_eq_to_matrix(Eq(x + 2, 1), x) == ( + Matrix([[1]]), Matrix([[-1]])) + + # issue 25423 + raises(TypeError, lambda: linear_eq_to_matrix([], {x, y})) + raises(TypeError, lambda: linear_eq_to_matrix([x + y], {x, y})) + raises(ValueError, lambda: linear_eq_to_matrix({x + y}, (x, y))) + + +def test_issue_16577(): + assert linear_eq_to_matrix(Eq(a*(2*x + 3*y) + 4*y, 5), x, y) == ( + Matrix([[2*a, 3*a + 4]]), Matrix([[5]])) + + +def test_issue_10085(): + assert invert_real(exp(x),0,x) == (x, S.EmptySet) + + +def test_linsolve(): + x1, x2, x3, x4 = symbols('x1, x2, x3, x4') + + # Test for different input forms + + M = Matrix([[1, 2, 1, 1, 7], [1, 2, 2, -1, 12], [2, 4, 0, 6, 4]]) + system1 = A, B = M[:, :-1], M[:, -1] + Eqns = [x1 + 2*x2 + x3 + x4 - 7, x1 + 2*x2 + 2*x3 - x4 - 12, + 2*x1 + 4*x2 + 6*x4 - 4] + + sol = FiniteSet((-2*x2 - 3*x4 + 2, x2, 2*x4 + 5, x4)) + assert linsolve(Eqns, (x1, x2, x3, x4)) == sol + assert linsolve(Eqns, *(x1, x2, x3, x4)) == sol + assert linsolve(system1, (x1, x2, x3, x4)) == sol + assert linsolve(system1, *(x1, x2, x3, x4)) == sol + # issue 9667 - symbols can be Dummy symbols + x1, x2, x3, x4 = symbols('x:4', cls=Dummy) + assert linsolve(system1, x1, x2, x3, x4) == FiniteSet( + (-2*x2 - 3*x4 + 2, x2, 2*x4 + 5, x4)) + + # raise ValueError for garbage value + raises(ValueError, lambda: linsolve(Eqns)) + raises(ValueError, lambda: linsolve(x1)) + raises(ValueError, lambda: linsolve(x1, x2)) + raises(ValueError, lambda: linsolve((A,), x1, x2)) + raises(ValueError, lambda: linsolve(A, B, x1, x2)) + raises(ValueError, lambda: linsolve([x1], x1, x1)) + raises(ValueError, lambda: linsolve([x1], (i for i in (x1, x1)))) + + #raise ValueError if equations are non-linear in given variables + raises(NonlinearError, lambda: linsolve([x + y - 1, x ** 2 + y - 3], [x, y])) + raises(NonlinearError, lambda: linsolve([cos(x) + y, x + y], [x, y])) + assert linsolve([x + z - 1, x ** 2 + y - 3], [z, y]) == {(-x + 1, -x**2 + 3)} + + # Fully symbolic test + A = Matrix([[a, b], [c, d]]) + B = Matrix([[e], [g]]) + system2 = (A, B) + sol = FiniteSet(((-b*g + d*e)/(a*d - b*c), (a*g - c*e)/(a*d - b*c))) + assert linsolve(system2, [x, y]) == sol + + # No solution + A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + B = Matrix([0, 0, 1]) + assert linsolve((A, B), (x, y, z)) is S.EmptySet + + # Issue #10056 + A, B, J1, J2 = symbols('A B J1 J2') + Augmatrix = Matrix([ + [2*I*J1, 2*I*J2, -2/J1], + [-2*I*J2, -2*I*J1, 2/J2], + [0, 2, 2*I/(J1*J2)], + [2, 0, 0], + ]) + + assert linsolve(Augmatrix, A, B) == FiniteSet((0, I/(J1*J2))) + + # Issue #10121 - Assignment of free variables + Augmatrix = Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]]) + assert linsolve(Augmatrix, a, b, c, d, e) == FiniteSet((a, 0, c, 0, e)) + #raises(IndexError, lambda: linsolve(Augmatrix, a, b, c)) + + x0, x1, x2, _x0 = symbols('tau0 tau1 tau2 _tau0') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + x0, x1, x2, _x0 = symbols('tau00 tau01 tau02 tau0') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + x0, x1, x2, _x0 = symbols('tau00 tau01 tau02 tau1') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + # symbols can be given as generators + x0, x2, x4 = symbols('x0, x2, x4') + assert linsolve(Augmatrix, numbered_symbols('x') + ) == FiniteSet((x0, 0, x2, 0, x4)) + Augmatrix[-1, -1] = x0 + # use Dummy to avoid clash; the names may clash but the symbols + # will not + Augmatrix[-1, -1] = symbols('_x0') + assert len(linsolve( + Augmatrix, numbered_symbols('x', cls=Dummy)).free_symbols) == 4 + + # Issue #12604 + f = Function('f') + assert linsolve([f(x) - 5], f(x)) == FiniteSet((5,)) + + # Issue #14860 + from sympy.physics.units import meter, newton, kilo + kN = kilo*newton + Eqns = [8*kN + x + y, 28*kN*meter + 3*x*meter] + assert linsolve(Eqns, x, y) == { + (kilo*newton*Rational(-28, 3), kN*Rational(4, 3))} + + # linsolve does not allow expansion (real or implemented) + # to remove singularities, but it will cancel linear terms + assert linsolve([Eq(x, x + y)], [x, y]) == {(x, 0)} + assert linsolve([Eq(x + x*y, 1 + y)], [x]) == {(1,)} + assert linsolve([Eq(1 + y, x + x*y)], [x]) == {(1,)} + raises(NonlinearError, lambda: + linsolve([Eq(x**2, x**2 + y)], [x, y])) + + # corner cases + # + # XXX: The case below should give the same as for [0] + # assert linsolve([], [x]) == {(x,)} + assert linsolve([], [x]) is S.EmptySet + assert linsolve([0], [x]) == {(x,)} + assert linsolve([x], [x, y]) == {(0, y)} + assert linsolve([x, 0], [x, y]) == {(0, y)} + + +def test_linsolve_large_sparse(): + # + # This is mainly a performance test + # + + def _mk_eqs_sol(n): + xs = symbols('x:{}'.format(n)) + ys = symbols('y:{}'.format(n)) + syms = xs + ys + eqs = [] + sol = (-S.Half,) * n + (S.Half,) * n + for xi, yi in zip(xs, ys): + eqs.extend([xi + yi, xi - yi + 1]) + return eqs, syms, FiniteSet(sol) + + n = 500 + eqs, syms, sol = _mk_eqs_sol(n) + assert linsolve(eqs, syms) == sol + + +def test_linsolve_immutable(): + A = ImmutableDenseMatrix([[1, 1, 2], [0, 1, 2], [0, 0, 1]]) + B = ImmutableDenseMatrix([2, 1, -1]) + assert linsolve([A, B], (x, y, z)) == FiniteSet((1, 3, -1)) + + A = ImmutableDenseMatrix([[1, 1, 7], [1, -1, 3]]) + assert linsolve(A) == FiniteSet((5, 2)) + + +def test_solve_decomposition(): + n = Dummy('n') + + f1 = exp(3*x) - 6*exp(2*x) + 11*exp(x) - 6 + f2 = sin(x)**2 - 2*sin(x) + 1 + f3 = sin(x)**2 - sin(x) + f4 = sin(x + 1) + f5 = exp(x + 2) - 1 + f6 = 1/log(x) + f7 = 1/x + + s1 = ImageSet(Lambda(n, 2*n*pi), S.Integers) + s2 = ImageSet(Lambda(n, 2*n*pi + pi), S.Integers) + s3 = ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers) + s4 = ImageSet(Lambda(n, 2*n*pi - 1), S.Integers) + s5 = ImageSet(Lambda(n, 2*n*pi - 1 + pi), S.Integers) + + assert solve_decomposition(f1, x, S.Reals) == FiniteSet(0, log(2), log(3)) + assert dumeq(solve_decomposition(f2, x, S.Reals), s3) + assert dumeq(solve_decomposition(f3, x, S.Reals), Union(s1, s2, s3)) + assert dumeq(solve_decomposition(f4, x, S.Reals), Union(s4, s5)) + assert solve_decomposition(f5, x, S.Reals) == FiniteSet(-2) + assert solve_decomposition(f6, x, S.Reals) == S.EmptySet + assert solve_decomposition(f7, x, S.Reals) == S.EmptySet + assert solve_decomposition(x, x, Interval(1, 2)) == S.EmptySet + + +# nonlinsolve testcases +def test_nonlinsolve_basic(): + assert nonlinsolve([],[]) == S.EmptySet + assert nonlinsolve([],[x, y]) == S.EmptySet + + system = [x, y - x - 5] + assert nonlinsolve([x],[x, y]) == FiniteSet((0, y)) + assert nonlinsolve(system, [y]) == S.EmptySet + soln = (ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers),) + assert dumeq(nonlinsolve([sin(x) - 1], [x]), FiniteSet(tuple(soln))) + soln = ((ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), 1), + (ImageSet(Lambda(n, 2*n*pi), S.Integers), 1)) + assert dumeq(nonlinsolve([sin(x), y - 1], [x, y]), FiniteSet(*soln)) + assert nonlinsolve([x**2 - 1], [x]) == FiniteSet((-1,), (1,)) + + soln = FiniteSet((y, y)) + assert nonlinsolve([x - y, 0], x, y) == soln + assert nonlinsolve([0, x - y], x, y) == soln + assert nonlinsolve([x - y, x - y], x, y) == soln + assert nonlinsolve([x, 0], x, y) == FiniteSet((0, y)) + f = Function('f') + assert nonlinsolve([f(x), 0], f(x), y) == FiniteSet((0, y)) + assert nonlinsolve([f(x), 0], f(x), f(y)) == FiniteSet((0, f(y))) + A = Indexed('A', x) + assert nonlinsolve([A, 0], A, y) == FiniteSet((0, y)) + assert nonlinsolve([x**2 -1], [sin(x)]) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([x**2 -1], sin(x)) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([x**2 -1], 1) == FiniteSet((x**2,)) + assert nonlinsolve([x**2 -1], x + y) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([Eq(1, x + y), Eq(1, -x + y - 1), Eq(1, -x + y - 1)], x, y) == FiniteSet( + (-S.Half, 3*S.Half)) + + +def test_nonlinsolve_abs(): + soln = FiniteSet((y, y), (-y, y)) + assert nonlinsolve([Abs(x) - y], x, y) == soln + + +def test_raise_exception_nonlinsolve(): + raises(IndexError, lambda: nonlinsolve([x**2 -1], [])) + raises(ValueError, lambda: nonlinsolve([x**2 -1])) + + +def test_trig_system(): + # TODO: add more simple testcases when solveset returns + # simplified soln for Trig eq + assert nonlinsolve([sin(x) - 1, cos(x) -1 ], x) == S.EmptySet + soln1 = (ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers),) + soln = FiniteSet(soln1) + assert dumeq(nonlinsolve([sin(x) - 1, cos(x)], x), soln) + + +@XFAIL +def test_trig_system_fail(): + # fails because solveset trig solver is not much smart. + sys = [x + y - pi/2, sin(x) + sin(y) - 1] + # solveset returns conditionset for sin(x) + sin(y) - 1 + soln_1 = (ImageSet(Lambda(n, n*pi + pi/2), S.Integers), + ImageSet(Lambda(n, n*pi), S.Integers)) + soln_1 = FiniteSet(soln_1) + soln_2 = (ImageSet(Lambda(n, n*pi), S.Integers), + ImageSet(Lambda(n, n*pi+ pi/2), S.Integers)) + soln_2 = FiniteSet(soln_2) + soln = soln_1 + soln_2 + assert dumeq(nonlinsolve(sys, [x, y]), soln) + + # Add more cases from here + # http://www.vitutor.com/geometry/trigonometry/equations_systems.html#uno + sys = [sin(x) + sin(y) - (sqrt(3)+1)/2, sin(x) - sin(y) - (sqrt(3) - 1)/2] + soln_x = Union(ImageSet(Lambda(n, 2*n*pi + pi/3), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi*Rational(2, 3)), S.Integers)) + soln_y = Union(ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi*Rational(5, 6)), S.Integers)) + assert dumeq(nonlinsolve(sys, [x, y]), FiniteSet((soln_x, soln_y))) + + +def test_nonlinsolve_positive_dimensional(): + x, y, a, b, c, d = symbols('x, y, a, b, c, d', extended_real=True) + assert nonlinsolve([x*y, x*y - x], [x, y]) == FiniteSet((0, y)) + + system = [a**2 + a*c, a - b] + assert nonlinsolve(system, [a, b]) == FiniteSet((0, 0), (-c, -c)) + # here (a= 0, b = 0) is independent soln so both is printed. + # if symbols = [a, b, c] then only {a : -c ,b : -c} + + eq1 = a + b + c + d + eq2 = a*b + b*c + c*d + d*a + eq3 = a*b*c + b*c*d + c*d*a + d*a*b + eq4 = a*b*c*d - 1 + system = [eq1, eq2, eq3, eq4] + sol1 = (-1/d, -d, 1/d, FiniteSet(d) - FiniteSet(0)) + sol2 = (1/d, -d, -1/d, FiniteSet(d) - FiniteSet(0)) + soln = FiniteSet(sol1, sol2) + assert nonlinsolve(system, [a, b, c, d]) == soln + + assert nonlinsolve([x**4 - 3*x**2 + y*x, x*z**2, y*z - 1], [x, y, z]) == \ + {(0, 1/z, z)} + + +def test_nonlinsolve_polysys(): + x, y, z = symbols('x, y, z', real=True) + assert nonlinsolve([x**2 + y - 2, x**2 + y], [x, y]) == S.EmptySet + + s = (-y + 2, y) + assert nonlinsolve([(x + y)**2 - 4, x + y - 2], [x, y]) == FiniteSet(s) + + system = [x**2 - y**2] + soln_real = FiniteSet((-y, y), (y, y)) + soln_complex = FiniteSet((-Abs(y), y), (Abs(y), y)) + soln =soln_real + soln_complex + assert nonlinsolve(system, [x, y]) == soln + + system = [x**2 - y**2] + soln_real= FiniteSet((y, -y), (y, y)) + soln_complex = FiniteSet((y, -Abs(y)), (y, Abs(y))) + soln = soln_real + soln_complex + assert nonlinsolve(system, [y, x]) == soln + + system = [x**2 + y - 3, x - y - 4] + assert nonlinsolve(system, (x, y)) != nonlinsolve(system, (y, x)) + + assert nonlinsolve([-x**2 - y**2 + z, -2*x, -2*y, S.One], [x, y, z]) == S.EmptySet + assert nonlinsolve([x + y + z, S.One, S.One, S.One], [x, y, z]) == S.EmptySet + + system = [-x**2*z**2 + x*y*z + y**4, -2*x*z**2 + y*z, x*z + 4*y**3, -2*x**2*z + x*y] + assert nonlinsolve(system, [x, y, z]) == FiniteSet((0, 0, z), (x, 0, 0)) + + +def test_nonlinsolve_using_substitution(): + x, y, z, n = symbols('x, y, z, n', real = True) + system = [(x + y)*n - y**2 + 2] + s_x = (n*y - y**2 + 2)/n + soln = (-s_x, y) + assert nonlinsolve(system, [x, y]) == FiniteSet(soln) + + system = [z**2*x**2 - z**2*y**2/exp(x)] + soln_real_1 = (y, x, 0) + soln_real_2 = (-exp(x/2)*Abs(x), x, z) + soln_real_3 = (exp(x/2)*Abs(x), x, z) + soln_complex_1 = (-x*exp(x/2), x, z) + soln_complex_2 = (x*exp(x/2), x, z) + syms = [y, x, z] + soln = FiniteSet(soln_real_1, soln_complex_1, soln_complex_2,\ + soln_real_2, soln_real_3) + assert nonlinsolve(system,syms) == soln + + +def test_nonlinsolve_complex(): + n = Dummy('n') + assert dumeq(nonlinsolve([exp(x) - sin(y), 1/y - 3], [x, y]), { + (ImageSet(Lambda(n, 2*n*I*pi + log(sin(Rational(1, 3)))), S.Integers), Rational(1, 3))}) + + system = [exp(x) - sin(y), 1/exp(y) - 3] + assert dumeq(nonlinsolve(system, [x, y]), { + (ImageSet(Lambda(n, I*(2*n*pi + pi) + + log(sin(log(3)))), S.Integers), -log(3)), + (ImageSet(Lambda(n, I*(2*n*pi + arg(sin(2*n*I*pi - log(3)))) + + log(Abs(sin(2*n*I*pi - log(3))))), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi - log(3)), S.Integers))}) + + system = [exp(x) - sin(y), y**2 - 4] + assert dumeq(nonlinsolve(system, [x, y]), { + (ImageSet(Lambda(n, I*(2*n*pi + pi) + log(sin(2))), S.Integers), -2), + (ImageSet(Lambda(n, 2*n*I*pi + log(sin(2))), S.Integers), 2)}) + + system = [exp(x) - 2, y ** 2 - 2] + assert dumeq(nonlinsolve(system, [x, y]), { + (log(2), -sqrt(2)), (log(2), sqrt(2)), + (ImageSet(Lambda(n, 2*n*I*pi + log(2)), S.Integers), -sqrt(2)), + (ImageSet(Lambda(n, 2 * n * I * pi + log(2)), S.Integers), sqrt(2))}) + + +def test_nonlinsolve_radical(): + assert nonlinsolve([sqrt(y) - x - z, y - 1], [x, y, z]) == {(1 - z, 1, z)} + + +def test_nonlinsolve_inexact(): + sol = [(-1.625, -1.375), (1.625, 1.375)] + res = nonlinsolve([(x + y)**2 - 9, x**2 - y**2 - 0.75], [x, y]) + assert all(abs(res.args[i][j]-sol[i][j]) < 1e-9 + for i in range(2) for j in range(2)) + + assert nonlinsolve([(x + y)**2 - 9, (x + y)**2 - 0.75], [x, y]) == S.EmptySet + + assert nonlinsolve([y**2 + (x - 0.5)**2 - 0.0625, 2*x - 1.0, 2*y], [x, y]) == \ + S.EmptySet + + res = nonlinsolve([x**2 + y - 0.5, (x + y)**2, log(z)], [x, y, z]) + sol = [(-0.366025403784439, 0.366025403784439, 1), + (-0.366025403784439, 0.366025403784439, 1), + (1.36602540378444, -1.36602540378444, 1)] + assert all(abs(res.args[i][j]-sol[i][j]) < 1e-9 + for i in range(3) for j in range(3)) + + res = nonlinsolve([y - x**2, x**5 - x + 1.0], [x, y]) + sol = [(-1.16730397826142, 1.36259857766493), + (-0.181232444469876 - 1.08395410131771*I, + -1.14211129483496 + 0.392895302949911*I), + (-0.181232444469876 + 1.08395410131771*I, + -1.14211129483496 - 0.392895302949911*I), + (0.764884433600585 - 0.352471546031726*I, + 0.460812006002492 - 0.539199997693599*I), + (0.764884433600585 + 0.352471546031726*I, + 0.460812006002492 + 0.539199997693599*I)] + assert all(abs(res.args[i][j] - sol[i][j]) < 1e-9 + for i in range(5) for j in range(2)) + +@XFAIL +def test_solve_nonlinear_trans(): + # After the transcendental equation solver these will work + x, y = symbols('x, y', real=True) + soln1 = FiniteSet((2*LambertW(y/2), y)) + soln2 = FiniteSet((-x*sqrt(exp(x)), y), (x*sqrt(exp(x)), y)) + soln3 = FiniteSet((x*exp(x/2), x)) + soln4 = FiniteSet(2*LambertW(y/2), y) + assert nonlinsolve([x**2 - y**2/exp(x)], [x, y]) == soln1 + assert nonlinsolve([x**2 - y**2/exp(x)], [y, x]) == soln2 + assert nonlinsolve([x**2 - y**2/exp(x)], [y, x]) == soln3 + assert nonlinsolve([x**2 - y**2/exp(x)], [x, y]) == soln4 + + +def test_nonlinsolve_issue_25182(): + a1, b1, c1, ca, cb, cg = symbols('a1, b1, c1, ca, cb, cg') + eq1 = a1*a1 + b1*b1 - 2.*a1*b1*cg - c1*c1 + eq2 = a1*a1 + c1*c1 - 2.*a1*c1*cb - b1*b1 + eq3 = b1*b1 + c1*c1 - 2.*b1*c1*ca - a1*a1 + assert nonlinsolve([eq1, eq2, eq3], [c1, cb, cg]) == FiniteSet( + (1.0*b1*ca - 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2), + -1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1, + -1.0*b1*(ca - 1)*(ca + 1)/a1 + 1.0*ca*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1), + (1.0*b1*ca + 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2), + 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1, + -1.0*b1*(ca - 1)*(ca + 1)/a1 - 1.0*ca*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1)) + + +def test_issue_14642(): + x = Symbol('x') + n1 = 0.5*x**3+x**2+0.5+I #add I in the Polynomials + solution = solveset(n1, x) + assert abs(solution.args[0] - (-2.28267560928153 - 0.312325580497716*I)) <= 1e-9 + assert abs(solution.args[1] - (-0.297354141679308 + 1.01904778618762*I)) <= 1e-9 + assert abs(solution.args[2] - (0.580029750960839 - 0.706722205689907*I)) <= 1e-9 + + # Symbolic + n1 = S.Half*x**3+x**2+S.Half+I + res = FiniteSet(-((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49) + /2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2))/3)/3 - S(2)/3 - 4*cos(atan((27 + + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)* + 31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2))/3)/(3*((3* + sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)**2)**(S(1)/ + 6)) + I*(-((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/ + 2)/2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos( + atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49) + /2)/2 + S(43)/2))/3)/3 + 4*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172) + /49)/2)/2 + S(43)/2))/3)/(3*((3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6))), -S(2)/3 - sqrt(3)*((3* + sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)**2)**(S(1) + /6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)) + /3)/6 - 4*re(1/((-S(1)/2 - sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + + (43 + 54*I)**2)/2)**(S(1)/3)))/3 + ((3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)* + 31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)* + sin(atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 + I*(-4*im(1/((-S(1)/2 - + sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1)/ + 3)))/3 + ((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/ + 49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/ + 4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2))/3)/6), -S(2)/3 - 4*re(1/((-S(1)/2 + + sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1) + /3)))/3 + sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + ((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + I*(-sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos( + atan(S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**( + S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 + ((3*sqrt(3)*31985**(S(1)/4)* + sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**( + S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 - 4*im(1/((-S(1)/2 + sqrt(3)*I/2)* + (S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1)/3)))/3)) + + assert solveset(n1, x) == res + + +def test_issue_13961(): + V = (ax, bx, cx, gx, jx, lx, mx, nx, q) = symbols('ax bx cx gx jx lx mx nx q') + S = (ax*q - lx*q - mx, ax - gx*q - lx, bx*q**2 + cx*q - jx*q - nx, q*(-ax*q + lx*q + mx), q*(-ax + gx*q + lx)) + + sol = FiniteSet((lx + mx/q, (-cx*q + jx*q + nx)/q**2, cx, mx/q**2, jx, lx, mx, nx, Complement({q}, {0})), + (lx + mx/q, (cx*q - jx*q - nx)/q**2*-1, cx, mx/q**2, jx, lx, mx, nx, Complement({q}, {0}))) + assert nonlinsolve(S, *V) == sol + # The two solutions are in fact identical, so even better if only one is returned + + +def test_issue_14541(): + solutions = solveset(sqrt(-x**2 - 2.0), x) + assert abs(solutions.args[0]+1.4142135623731*I) <= 1e-9 + assert abs(solutions.args[1]-1.4142135623731*I) <= 1e-9 + + +def test_issue_13396(): + expr = -2*y*exp(-x**2 - y**2)*Abs(x) + sol = FiniteSet(0) + + assert solveset(expr, y, domain=S.Reals) == sol + + # Related type of equation also solved here + assert solveset(atan(x**2 - y**2)-pi/2, y, S.Reals) is S.EmptySet + + +def test_issue_12032(): + sol = FiniteSet(-sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 + + sqrt(Abs(-2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2, + -sqrt(Abs(-2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2 - + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2, + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 - + I*sqrt(Abs(-2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) - + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2, + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 + + I*sqrt(Abs(-2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) - + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1,3)))))/2) + assert solveset(x**4 + x - 1, x) == sol + + +def test_issue_10876(): + assert solveset(1/sqrt(x), x) == S.EmptySet + + +def test_issue_19050(): + # test_issue_19050 --> TypeError removed + assert dumeq(nonlinsolve([x + y, sin(y)], [x, y]), + FiniteSet((ImageSet(Lambda(n, -2*n*pi), S.Integers), ImageSet(Lambda(n, 2*n*pi), S.Integers)),\ + (ImageSet(Lambda(n, -2*n*pi - pi), S.Integers), ImageSet(Lambda(n, 2*n*pi + pi), S.Integers)))) + assert dumeq(nonlinsolve([x + y, sin(y) + cos(y)], [x, y]), + FiniteSet((ImageSet(Lambda(n, -2*n*pi - 3*pi/4), S.Integers), ImageSet(Lambda(n, 2*n*pi + 3*pi/4), S.Integers)), \ + (ImageSet(Lambda(n, -2*n*pi - 7*pi/4), S.Integers), ImageSet(Lambda(n, 2*n*pi + 7*pi/4), S.Integers)))) + + +def test_issue_16618(): + eqn = [sin(x)*sin(y), cos(x)*cos(y) - 1] + # nonlinsolve's answer is still suspicious since it contains only three + # distinct Dummys instead of 4. (Both 'x' ImageSets share the same Dummy.) + ans = FiniteSet((ImageSet(Lambda(n, 2*n*pi), S.Integers), ImageSet(Lambda(n, 2*n*pi), S.Integers)), + (ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), ImageSet(Lambda(n, 2*n*pi + pi), S.Integers))) + sol = nonlinsolve(eqn, [x, y]) + + for i0, j0 in zip(ordered(sol), ordered(ans)): + assert len(i0) == len(j0) == 2 + assert all(a.dummy_eq(b) for a, b in zip(i0, j0)) + assert len(sol) == len(ans) + + +def test_issue_17566(): + assert nonlinsolve([32*(2**x)/2**(-y) - 4**y, 27*(3**x) - S(1)/3**y], x, y) ==\ + FiniteSet((-log(81)/log(3), 1)) + + +def test_issue_16643(): + n = Dummy('n') + assert solveset(x**2*sin(x), x).dummy_eq(Union(ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), + ImageSet(Lambda(n, 2*n*pi), S.Integers))) + + +def test_issue_19587(): + n,m = symbols('n m') + assert nonlinsolve([32*2**m*2**n - 4**n, 27*3**m - 3**(-n)], m, n) ==\ + FiniteSet((-log(81)/log(3), 1)) + + +def test_issue_5132_1(): + system = [sqrt(x**2 + y**2) - sqrt(10), x + y - 4] + assert nonlinsolve(system, [x, y]) == FiniteSet((1, 3), (3, 1)) + + n = Dummy('n') + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + s_real_y = -log(3) + s_real_z = sqrt(-exp(2*x) - sin(log(3))) + soln_real = FiniteSet((s_real_y, s_real_z), (s_real_y, -s_real_z)) + lam = Lambda(n, 2*n*I*pi + -log(3)) + s_complex_y = ImageSet(lam, S.Integers) + lam = Lambda(n, sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_1 = ImageSet(lam, S.Integers) + lam = Lambda(n, -sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_2 = ImageSet(lam, S.Integers) + soln_complex = FiniteSet( + (s_complex_y, s_complex_z_1), + (s_complex_y, s_complex_z_2) + ) + soln = soln_real + soln_complex + assert dumeq(nonlinsolve(eqs, [y, z]), soln) + + +def test_issue_5132_2(): + x, y = symbols('x, y', real=True) + eqs = [exp(x)**2 - sin(y) + z**2] + n = Dummy('n') + soln_real = (log(-z**2 + sin(y))/2, z) + lam = Lambda( n, I*(2*n*pi + arg(-z**2 + sin(y)))/2 + log(Abs(z**2 - sin(y)))/2) + img = ImageSet(lam, S.Integers) + # not sure about the complex soln. But it looks correct. + soln_complex = (img, z) + soln = FiniteSet(soln_real, soln_complex) + assert dumeq(nonlinsolve(eqs, [x, z]), soln) + + system = [r - x**2 - y**2, tan(t) - y/x] + s_x = sqrt(r/(tan(t)**2 + 1)) + s_y = sqrt(r/(tan(t)**2 + 1))*tan(t) + soln = FiniteSet((s_x, s_y), (-s_x, -s_y)) + assert nonlinsolve(system, [x, y]) == soln + + +def test_issue_6752(): + a, b = symbols('a, b', real=True) + assert nonlinsolve([a**2 + a, a - b], [a, b]) == {(-1, -1), (0, 0)} + + +@SKIP("slow") +def test_issue_5114_solveset(): + # slow testcase + from sympy.abc import o, p + + # there is no 'a' in the equation set but this is how the + # problem was originally posed + syms = [a, b, c, f, h, k, n] + eqs = [b + r/d - c/d, + c*(1/d + 1/e + 1/g) - f/g - r/d, + f*(1/g + 1/i + 1/j) - c/g - h/i, + h*(1/i + 1/l + 1/m) - f/i - k/m, + k*(1/m + 1/o + 1/p) - h/m - n/p, + n*(1/p + 1/q) - k/p] + assert len(nonlinsolve(eqs, syms)) == 1 + + +@SKIP("Hangs") +def _test_issue_5335(): + # Not able to check zero dimensional system. + # is_zero_dimensional Hangs + lam, a0, conc = symbols('lam a0 conc') + eqs = [lam + 2*y - a0*(1 - x/2)*x - 0.005*x/2*x, + a0*(1 - x/2)*x - 1*y - 0.743436700916726*y, + x + y - conc] + sym = [x, y, a0] + # there are 4 solutions but only two are valid + assert len(nonlinsolve(eqs, sym)) == 2 + # float + eqs = [lam + 2*y - a0*(1 - x/2)*x - 0.005*x/2*x, + a0*(1 - x/2)*x - 1*y - 0.743436700916726*y, + x + y - conc] + sym = [x, y, a0] + assert len(nonlinsolve(eqs, sym)) == 2 + + +def test_issue_2777(): + # the equations represent two circles + x, y = symbols('x y', real=True) + e1, e2 = sqrt(x**2 + y**2) - 10, sqrt(y**2 + (-x + 10)**2) - 3 + a, b = Rational(191, 20), 3*sqrt(391)/20 + ans = {(a, -b), (a, b)} + assert nonlinsolve((e1, e2), (x, y)) == ans + assert nonlinsolve((e1, e2/(x - a)), (x, y)) == S.EmptySet + # make the 2nd circle's radius be -3 + e2 += 6 + assert nonlinsolve((e1, e2), (x, y)) == S.EmptySet + + +def test_issue_8828(): + x1 = 0 + y1 = -620 + r1 = 920 + x2 = 126 + y2 = 276 + x3 = 51 + y3 = 205 + r3 = 104 + v = [x, y, z] + + f1 = (x - x1)**2 + (y - y1)**2 - (r1 - z)**2 + f2 = (x2 - x)**2 + (y2 - y)**2 - z**2 + f3 = (x - x3)**2 + (y - y3)**2 - (r3 - z)**2 + F = [f1, f2, f3] + + g1 = sqrt((x - x1)**2 + (y - y1)**2) + z - r1 + g2 = f2 + g3 = sqrt((x - x3)**2 + (y - y3)**2) + z - r3 + G = [g1, g2, g3] + + # both soln same + A = nonlinsolve(F, v) + B = nonlinsolve(G, v) + assert A == B + + +def test_nonlinsolve_conditionset(): + # when solveset failed to solve all the eq + # return conditionset + f = Function('f') + f1 = f(x) - pi/2 + f2 = f(y) - pi*Rational(3, 2) + intermediate_system = Eq(2*f(x) - pi, 0) & Eq(2*f(y) - 3*pi, 0) + syms = Tuple(x, y) + soln = ConditionSet( + syms, + intermediate_system, + S.Complexes**2) + assert nonlinsolve([f1, f2], [x, y]) == soln + + +def test_substitution_basic(): + assert substitution([], [x, y]) == S.EmptySet + assert substitution([], []) == S.EmptySet + system = [2*x**2 + 3*y**2 - 30, 3*x**2 - 2*y**2 - 19] + soln = FiniteSet((-3, -2), (-3, 2), (3, -2), (3, 2)) + assert substitution(system, [x, y]) == soln + + soln = FiniteSet((-1, 1)) + assert substitution([x + y], [x], [{y: 1}], [y], set(), [x, y]) == soln + assert substitution( + [x + y], [x], [{y: 1}], [y], + {x + 1}, [y, x]) == S.EmptySet + + +def test_substitution_incorrect(): + # the solutions in the following two tests are incorrect. The + # correct result is EmptySet in both cases. + assert substitution([h - 1, k - 1, f - 2, f - 4, -2 * k], + [h, k, f]) == {(1, 1, f)} + assert substitution([x + y + z, S.One, S.One, S.One], [x, y, z]) == \ + {(-y - z, y, z)} + + # the correct result in the test below is {(-I, I, I, -I), + # (I, -I, -I, I)} + assert substitution([a - d, b + d, c + d, d**2 + 1], [a, b, c, d]) == \ + {(d, -d, -d, d)} + + # the result in the test below is incomplete. The complete result + # is {(0, b), (log(2), 2)} + assert substitution([a*(a - log(b)), a*(b - 2)], [a, b]) == \ + {(0, b)} + + # The system in the test below is zero-dimensional, so the result + # should have no free symbols + assert substitution([-k*y + 6*x - 4*y, -81*k + 49*y**2 - 270, + -3*k*z + k + z**3, k**2 - 2*k + 4], + [x, y, z, k]).free_symbols == {z} + + +def test_substitution_redundant(): + # the third and fourth solutions are redundant in the test below + assert substitution([x**2 - y**2, z - 1], [x, z]) == \ + {(-y, 1), (y, 1), (-sqrt(y**2), 1), (sqrt(y**2), 1)} + + # the system below has three solutions. Two of the solutions + # returned by substitution are redundant. + res = substitution([x - y, y**3 - 3*y**2 + 1], [x, y]) + assert len(res) == 5 + + +def test_issue_5132_substitution(): + x, y, z, r, t = symbols('x, y, z, r, t', real=True) + system = [r - x**2 - y**2, tan(t) - y/x] + s_x_1 = Complement(FiniteSet(-sqrt(r/(tan(t)**2 + 1))), FiniteSet(0)) + s_x_2 = Complement(FiniteSet(sqrt(r/(tan(t)**2 + 1))), FiniteSet(0)) + s_y = sqrt(r/(tan(t)**2 + 1))*tan(t) + soln = FiniteSet((s_x_2, s_y)) + FiniteSet((s_x_1, -s_y)) + assert substitution(system, [x, y]) == soln + + n = Dummy('n') + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + s_real_y = -log(3) + s_real_z = sqrt(-exp(2*x) - sin(log(3))) + soln_real = FiniteSet((s_real_y, s_real_z), (s_real_y, -s_real_z)) + lam = Lambda(n, 2*n*I*pi + -log(3)) + s_complex_y = ImageSet(lam, S.Integers) + lam = Lambda(n, sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_1 = ImageSet(lam, S.Integers) + lam = Lambda(n, -sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_2 = ImageSet(lam, S.Integers) + soln_complex = FiniteSet( + (s_complex_y, s_complex_z_1), + (s_complex_y, s_complex_z_2)) + soln = soln_real + soln_complex + assert dumeq(substitution(eqs, [y, z]), soln) + + +def test_raises_substitution(): + raises(ValueError, lambda: substitution([x**2 -1], [])) + raises(TypeError, lambda: substitution([x**2 -1])) + raises(ValueError, lambda: substitution([x**2 -1], [sin(x)])) + raises(TypeError, lambda: substitution([x**2 -1], x)) + raises(TypeError, lambda: substitution([x**2 -1], 1)) + + +def test_issue_21022(): + from sympy.core.sympify import sympify + + eqs = [ + 'k-16', + 'p-8', + 'y*y+z*z-x*x', + 'd - x + p', + 'd*d+k*k-y*y', + 'z*z-p*p-k*k', + 'abc-efg', + ] + efg = Symbol('efg') + eqs = [sympify(x) for x in eqs] + + syb = list(ordered(set.union(*[x.free_symbols for x in eqs]))) + res = nonlinsolve(eqs, syb) + + ans = FiniteSet( + (efg, 32, efg, 16, 8, 40, -16*sqrt(5), -8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, -16*sqrt(5), 8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, 16*sqrt(5), -8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, 16*sqrt(5), 8*sqrt(5)), + ) + assert len(res) == len(ans) == 4 + assert res == ans + for result in res.args: + assert len(result) == 8 + + +def test_issue_17940(): + n = Dummy('n') + k1 = Dummy('k1') + sol = ImageSet(Lambda(((k1, n),), I*(2*k1*pi + arg(2*n*I*pi + log(5))) + + log(Abs(2*n*I*pi + log(5)))), + ProductSet(S.Integers, S.Integers)) + assert solveset(exp(exp(x)) - 5, x).dummy_eq(sol) + + +def test_issue_17906(): + assert solveset(7**(x**2 - 80) - 49**x, x) == FiniteSet(-8, 10) + + +@XFAIL +def test_issue_17933(): + eq1 = x*sin(45) - y*cos(q) + eq2 = x*cos(45) - y*sin(q) + eq3 = 9*x*sin(45)/10 + y*cos(q) + eq4 = 9*x*cos(45)/10 + y*sin(z) - z + assert nonlinsolve([eq1, eq2, eq3, eq4], x, y, z, q) ==\ + FiniteSet((0, 0, 0, q)) + +def test_issue_17933_bis(): + # nonlinsolve's result depends on the 'default_sort_key' ordering of + # the unknowns. + eq1 = x*sin(45) - y*cos(q) + eq2 = x*cos(45) - y*sin(q) + eq3 = 9*x*sin(45)/10 + y*cos(q) + eq4 = 9*x*cos(45)/10 + y*sin(z) - z + zz = Symbol('zz') + eqs = [e.subs(q, zz) for e in (eq1, eq2, eq3, eq4)] + assert nonlinsolve(eqs, x, y, z, zz) == FiniteSet((0, 0, 0, zz)) + + +def test_issue_14565(): + # removed redundancy + assert dumeq(nonlinsolve([k + m, k + m*exp(-2*pi*k)], [k, m]) , + FiniteSet((-n*I, ImageSet(Lambda(n, n*I), S.Integers)))) + + +# end of tests for nonlinsolve + + +def test_issue_9556(): + b = Symbol('b', positive=True) + + assert solveset(Abs(x) + 1, x, S.Reals) is S.EmptySet + assert solveset(Abs(x) + b, x, S.Reals) is S.EmptySet + assert solveset(Eq(b, -1), b, S.Reals) is S.EmptySet + + +def test_issue_9611(): + assert solveset(Eq(x - x + a, a), x, S.Reals) == S.Reals + assert solveset(Eq(y - y + a, a), y) == S.Complexes + + +def test_issue_9557(): + assert solveset(x**2 + a, x, S.Reals) == Intersection(S.Reals, + FiniteSet(-sqrt(-a), sqrt(-a))) + + +def test_issue_9778(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset(x**3 + 1, x, S.Reals) == FiniteSet(-1) + assert solveset(x**Rational(3, 5) + 1, x, S.Reals) == S.EmptySet + assert solveset(x**3 + y, x, S.Reals) == \ + FiniteSet(-Abs(y)**Rational(1, 3)*sign(y)) + + +def test_issue_10214(): + assert solveset(x**Rational(3, 2) + 4, x, S.Reals) == S.EmptySet + assert solveset(x**(Rational(-3, 2)) + 4, x, S.Reals) == S.EmptySet + + ans = FiniteSet(-2**Rational(2, 3)) + assert solveset(x**(S(3)) + 4, x, S.Reals) == ans + assert (x**(S(3)) + 4).subs(x,list(ans)[0]) == 0 # substituting ans and verifying the result. + assert (x**(S(3)) + 4).subs(x,-(-2)**Rational(2, 3)) == 0 + + +def test_issue_9849(): + assert solveset(Abs(sin(x)) + 1, x, S.Reals) == S.EmptySet + + +def test_issue_9953(): + assert linsolve([ ], x) == S.EmptySet + + +def test_issue_9913(): + assert solveset(2*x + 1/(x - 10)**2, x, S.Reals) == \ + FiniteSet(-(3*sqrt(24081)/4 + Rational(4027, 4))**Rational(1, 3)/3 - 100/ + (3*(3*sqrt(24081)/4 + Rational(4027, 4))**Rational(1, 3)) + Rational(20, 3)) + + +def test_issue_10397(): + assert solveset(sqrt(x), x, S.Complexes) == FiniteSet(0) + + +def test_issue_14987(): + raises(ValueError, lambda: linear_eq_to_matrix( + [x**2], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [x*(-3/x + 1) + 2*y - a], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x**2 - 3*x)/(x - 3) - 3], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x + 1)**3 - x**3 - 3*x**2 + 7], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [x*(1/x + 1) + y], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x + 1)*y], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(1/x, 1/x + y)], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(y/x, y/x + y)], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(x*(x + 1), x**2 + y)], [x, y])) + + +def test_simplification(): + eq = x + (a - b)/(-2*a + 2*b) + assert solveset(eq, x) == FiniteSet(S.Half) + assert solveset(eq, x, S.Reals) == Intersection({-((a - b)/(-2*a + 2*b))}, S.Reals) + # So that ap - bn is not zero: + ap = Symbol('ap', positive=True) + bn = Symbol('bn', negative=True) + eq = x + (ap - bn)/(-2*ap + 2*bn) + assert solveset(eq, x) == FiniteSet(S.Half) + assert solveset(eq, x, S.Reals) == FiniteSet(S.Half) + + +def test_integer_domain_relational(): + eq1 = 2*x + 3 > 0 + eq2 = x**2 + 3*x - 2 >= 0 + eq3 = x + 1/x > -2 + 1/x + eq4 = x + sqrt(x**2 - 5) > 0 + eq = x + 1/x > -2 + 1/x + eq5 = eq.subs(x,log(x)) + eq6 = log(x)/x <= 0 + eq7 = log(x)/x < 0 + eq8 = x/(x-3) < 3 + eq9 = x/(x**2-3) < 3 + + assert solveset(eq1, x, S.Integers) == Range(-1, oo, 1) + assert solveset(eq2, x, S.Integers) == Union(Range(-oo, -3, 1), Range(1, oo, 1)) + assert solveset(eq3, x, S.Integers) == Union(Range(-1, 0, 1), Range(1, oo, 1)) + assert solveset(eq4, x, S.Integers) == Range(3, oo, 1) + assert solveset(eq5, x, S.Integers) == Range(2, oo, 1) + assert solveset(eq6, x, S.Integers) == Range(1, 2, 1) + assert solveset(eq7, x, S.Integers) == S.EmptySet + assert solveset(eq8, x, domain=Range(0,5)) == Range(0, 3, 1) + assert solveset(eq9, x, domain=Range(0,5)) == Union(Range(0, 2, 1), Range(2, 5, 1)) + + # test_issue_19794 + assert solveset(x + 2 < 0, x, S.Integers) == Range(-oo, -2, 1) + + +def test_issue_10555(): + f = Function('f') + g = Function('g') + assert solveset(f(x) - pi/2, x, S.Reals).dummy_eq( + ConditionSet(x, Eq(f(x) - pi/2, 0), S.Reals)) + assert solveset(f(g(x)) - pi/2, g(x), S.Reals).dummy_eq( + ConditionSet(g(x), Eq(f(g(x)) - pi/2, 0), S.Reals)) + + +def test_issue_8715(): + eq = x + 1/x > -2 + 1/x + assert solveset(eq, x, S.Reals) == \ + (Interval.open(-2, oo) - FiniteSet(0)) + assert solveset(eq.subs(x,log(x)), x, S.Reals) == \ + Interval.open(exp(-2), oo) - FiniteSet(1) + + +def test_issue_11174(): + eq = z**2 + exp(2*x) - sin(y) + soln = Intersection(S.Reals, FiniteSet(log(-z**2 + sin(y))/2)) + assert solveset(eq, x, S.Reals) == soln + + eq = sqrt(r)*Abs(tan(t))/sqrt(tan(t)**2 + 1) + x*tan(t) + s = -sqrt(r)*Abs(tan(t))/(sqrt(tan(t)**2 + 1)*tan(t)) + soln = Intersection(S.Reals, FiniteSet(s)) + assert solveset(eq, x, S.Reals) == soln + + +def test_issue_11534(): + # eq1 and eq2 should not have the same solutions because squaring both + # sides of the radical equation introduces a spurious solution branch. + # The equations have a symbolic parameter y and it is easy to see that for + # y != 0 the solution s1 will not be valid for eq1. + x = Symbol('x', real=True) + y = Symbol('y', real=True) + eq1 = -y + x/sqrt(-x**2 + 1) + eq2 = -y**2 + x**2/(-x**2 + 1) + + # We get a ConditionSet here because s1 works in eq1 if y is equal to zero + # although not for any other value of y. That case is redundant though + # because if y=0 then s1=s2 so the solution for eq1 could just be returned + # as s2 - {-1, 1}. In fact we have + # |y/sqrt(y**2 + 1)| < 1 + # So the complements are not needed either. The ideal output here would be + # sol1 = s2 + # sol2 = s1 | s2. + s1, s2 = FiniteSet(-y/sqrt(y**2 + 1)), FiniteSet(y/sqrt(y**2 + 1)) + cset = ConditionSet(x, Eq(eq1, 0), s1) + sol1 = (s2 - {-1, 1}) | (cset - {-1, 1}) + sol2 = (s1 | s2) - {-1, 1} + + assert solveset(eq1, x, S.Reals) == sol1 + assert solveset(eq2, x, S.Reals) == sol2 + + +def test_issue_10477(): + assert solveset((x**2 + 4*x - 3)/x < 2, x, S.Reals) == \ + Union(Interval.open(-oo, -3), Interval.open(0, 1)) + + +def test_issue_10671(): + assert solveset(sin(y), y, Interval(0, pi)) == FiniteSet(0, pi) + i = Interval(1, 10) + assert solveset((1/x).diff(x) < 0, x, i) == i + + +def test_issue_11064(): + eq = x + sqrt(x**2 - 5) + assert solveset(eq > 0, x, S.Reals) == \ + Interval(sqrt(5), oo) + assert solveset(eq < 0, x, S.Reals) == \ + Interval(-oo, -sqrt(5)) + assert solveset(eq > sqrt(5), x, S.Reals) == \ + Interval.Lopen(sqrt(5), oo) + + +def test_issue_12478(): + eq = sqrt(x - 2) + 2 + soln = solveset_real(eq, x) + assert soln is S.EmptySet + assert solveset(eq < 0, x, S.Reals) is S.EmptySet + assert solveset(eq > 0, x, S.Reals) == Interval(2, oo) + + +def test_issue_12429(): + eq = solveset(log(x)/x <= 0, x, S.Reals) + sol = Interval.Lopen(0, 1) + assert eq == sol + + +def test_issue_19506(): + eq = arg(x + I) + C = Dummy('C') + assert solveset(eq).dummy_eq(Intersection(ConditionSet(C, Eq(im(C) + 1, 0), S.Complexes), + ConditionSet(C, re(C) > 0, S.Complexes))) + + +def test_solveset_arg(): + assert solveset(arg(x), x, S.Reals) == Interval.open(0, oo) + assert solveset(arg(4*x -3), x, S.Reals) == Interval.open(Rational(3, 4), oo) + + +def test__is_finite_with_finite_vars(): + f = _is_finite_with_finite_vars + # issue 12482 + assert all(f(1/x) is None for x in ( + Dummy(), Dummy(real=True), Dummy(complex=True))) + assert f(1/Dummy(real=False)) is True # b/c it's finite but not 0 + + +def test_issue_13550(): + assert solveset(x**2 - 2*x - 15, symbol = x, domain = Interval(-oo, 0)) == FiniteSet(-3) + + +def test_issue_13849(): + assert nonlinsolve((t*(sqrt(5) + sqrt(2)) - sqrt(2), t), t) is S.EmptySet + + +def test_issue_14223(): + assert solveset((Abs(x + Min(x, 2)) - 2).rewrite(Piecewise), x, + S.Reals) == FiniteSet(-1, 1) + assert solveset((Abs(x + Min(x, 2)) - 2).rewrite(Piecewise), x, + Interval(0, 2)) == FiniteSet(1) + assert solveset(x, x, FiniteSet(1, 2)) is S.EmptySet + + +def test_issue_10158(): + dom = S.Reals + assert solveset(x*Max(x, 15) - 10, x, dom) == FiniteSet(Rational(2, 3)) + assert solveset(x*Min(x, 15) - 10, x, dom) == FiniteSet(-sqrt(10), sqrt(10)) + assert solveset(Max(Abs(x - 3) - 1, x + 2) - 3, x, dom) == FiniteSet(-1, 1) + assert solveset(Abs(x - 1) - Abs(y), x, dom) == FiniteSet(-Abs(y) + 1, Abs(y) + 1) + assert solveset(Abs(x + 4*Abs(x + 1)), x, dom) == FiniteSet(Rational(-4, 3), Rational(-4, 5)) + assert solveset(2*Abs(x + Abs(x + Max(3, x))) - 2, x, S.Reals) == FiniteSet(-1, -2) + dom = S.Complexes + raises(ValueError, lambda: solveset(x*Max(x, 15) - 10, x, dom)) + raises(ValueError, lambda: solveset(x*Min(x, 15) - 10, x, dom)) + raises(ValueError, lambda: solveset(Max(Abs(x - 3) - 1, x + 2) - 3, x, dom)) + raises(ValueError, lambda: solveset(Abs(x - 1) - Abs(y), x, dom)) + raises(ValueError, lambda: solveset(Abs(x + 4*Abs(x + 1)), x, dom)) + + +def test_issue_14300(): + f = 1 - exp(-18000000*x) - y + a1 = FiniteSet(-log(-y + 1)/18000000) + + assert solveset(f, x, S.Reals) == \ + Intersection(S.Reals, a1) + assert dumeq(solveset(f, x), + ImageSet(Lambda(n, -I*(2*n*pi + arg(-y + 1))/18000000 - + log(Abs(y - 1))/18000000), S.Integers)) + + +def test_issue_14454(): + number = CRootOf(x**4 + x - 1, 2) + raises(ValueError, lambda: invert_real(number, 0, x)) + assert invert_real(x**2, number, x) # no error + + +def test_issue_17882(): + assert solveset(-8*x**2/(9*(x**2 - 1)**(S(4)/3)) + 4/(3*(x**2 - 1)**(S(1)/3)), x, S.Complexes) == \ + FiniteSet(sqrt(3), -sqrt(3)) + + +def test_term_factors(): + assert list(_term_factors(3**x - 2)) == [-2, 3**x] + expr = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + assert set(_term_factors(expr)) == { + 3**(x + 2), 4**(x + 2), 3**(x + 3), 4**(x - 1), -1, 4**(x + 1)} + + +#################### tests for transolve and its helpers ############### + +def test_transolve(): + + assert _transolve(3**x, x, S.Reals) == S.EmptySet + assert _transolve(3**x - 9**(x + 5), x, S.Reals) == FiniteSet(-10) + + +def test_issue_21276(): + eq = (2*x*(y - z) - y*erf(y - z) - y + z*erf(y - z) + z)**2 + assert solveset(eq.expand(), y) == FiniteSet(z, z + erfinv(2*x - 1)) + + +# exponential tests +def test_exponential_real(): + from sympy.abc import y + + e1 = 3**(2*x) - 2**(x + 3) + e2 = 4**(5 - 9*x) - 8**(2 - x) + e3 = 2**x + 4**x + e4 = exp(log(5)*x) - 2**x + e5 = exp(x/y)*exp(-z/y) - 2 + e6 = 5**(x/2) - 2**(x/3) + e7 = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + e8 = -9*exp(-2*x + 5) + 4*exp(3*x + 1) + e9 = 2**x + 4**x + 8**x - 84 + e10 = 29*2**(x + 1)*615**(x) - 123*2726**(x) + + assert solveset(e1, x, S.Reals) == FiniteSet( + -3*log(2)/(-2*log(3) + log(2))) + assert solveset(e2, x, S.Reals) == FiniteSet(Rational(4, 15)) + assert solveset(e3, x, S.Reals) == S.EmptySet + assert solveset(e4, x, S.Reals) == FiniteSet(0) + assert solveset(e5, x, S.Reals) == Intersection( + S.Reals, FiniteSet(y*log(2*exp(z/y)))) + assert solveset(e6, x, S.Reals) == FiniteSet(0) + assert solveset(e7, x, S.Reals) == FiniteSet(2) + assert solveset(e8, x, S.Reals) == FiniteSet(-2*log(2)/5 + 2*log(3)/5 + Rational(4, 5)) + assert solveset(e9, x, S.Reals) == FiniteSet(2) + assert solveset(e10,x, S.Reals) == FiniteSet((-log(29) - log(2) + log(123))/(-log(2726) + log(2) + log(615))) + + assert solveset_real(-9*exp(-2*x + 5) + 2**(x + 1), x) == FiniteSet( + -((-5 - 2*log(3) + log(2))/(log(2) + 2))) + assert solveset_real(4**(x/2) - 2**(x/3), x) == FiniteSet(0) + b = sqrt(6)*sqrt(log(2))/sqrt(log(5)) + assert solveset_real(5**(x/2) - 2**(3/x), x) == FiniteSet(-b, b) + + # coverage test + C1, C2 = symbols('C1 C2') + f = Function('f') + assert solveset_real(C1 + C2/x**2 - exp(-f(x)), f(x)) == Intersection( + S.Reals, FiniteSet(-log(C1 + C2/x**2))) + y = symbols('y', positive=True) + assert solveset_real(x**2 - y**2/exp(x), y) == Intersection( + S.Reals, FiniteSet(-sqrt(x**2*exp(x)), sqrt(x**2*exp(x)))) + p = Symbol('p', positive=True) + assert solveset_real((1/p + 1)**(p + 1), p).dummy_eq( + ConditionSet(x, Eq((1 + 1/x)**(x + 1), 0), S.Reals)) + assert solveset(2**x - 4**x + 12, x, S.Reals) == {2} + assert solveset(2**x - 2**(2*x) + 12, x, S.Reals) == {2} + + +@XFAIL +def test_exponential_complex(): + n = Dummy('n') + + assert dumeq(solveset_complex(2**x + 4**x, x),imageset( + Lambda(n, I*(2*n*pi + pi)/log(2)), S.Integers)) + assert solveset_complex(x**z*y**z - 2, z) == FiniteSet( + log(2)/(log(x) + log(y))) + assert dumeq(solveset_complex(4**(x/2) - 2**(x/3), x), imageset( + Lambda(n, 3*n*I*pi/log(2)), S.Integers)) + assert dumeq(solveset(2**x + 32, x), imageset( + Lambda(n, (I*(2*n*pi + pi) + 5*log(2))/log(2)), S.Integers)) + + eq = (2**exp(y**2/x) + 2)/(x**2 + 15) + a = sqrt(x)*sqrt(-log(log(2)) + log(log(2) + 2*n*I*pi)) + assert solveset_complex(eq, y) == FiniteSet(-a, a) + + union1 = imageset(Lambda(n, I*(2*n*pi - pi*Rational(2, 3))/log(2)), S.Integers) + union2 = imageset(Lambda(n, I*(2*n*pi + pi*Rational(2, 3))/log(2)), S.Integers) + assert dumeq(solveset(2**x + 4**x + 8**x, x), Union(union1, union2)) + + eq = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + res = solveset(eq, x) + num = 2*n*I*pi - 4*log(2) + 2*log(3) + den = -2*log(2) + log(3) + ans = imageset(Lambda(n, num/den), S.Integers) + assert dumeq(res, ans) + + +def test_expo_conditionset(): + + f1 = (exp(x) + 1)**x - 2 + f2 = (x + 2)**y*x - 3 + f3 = 2**x - exp(x) - 3 + f4 = log(x) - exp(x) + f5 = 2**x + 3**x - 5**x + + assert solveset(f1, x, S.Reals).dummy_eq(ConditionSet( + x, Eq((exp(x) + 1)**x - 2, 0), S.Reals)) + assert solveset(f2, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(x*(x + 2)**y - 3, 0), S.Reals)) + assert solveset(f3, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(2**x - exp(x) - 3, 0), S.Reals)) + assert solveset(f4, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(-exp(x) + log(x), 0), S.Reals)) + assert solveset(f5, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(2**x + 3**x - 5**x, 0), S.Reals)) + + +def test_exponential_symbols(): + x, y, z = symbols('x y z', positive=True) + xr, zr = symbols('xr, zr', real=True) + + assert solveset(z**x - y, x, S.Reals) == Intersection( + S.Reals, FiniteSet(log(y)/log(z))) + + f1 = 2*x**w - 4*y**w + f2 = (x/y)**w - 2 + sol1 = Intersection({log(2)/(log(x) - log(y))}, S.Reals) + sol2 = Intersection({log(2)/log(x/y)}, S.Reals) + assert solveset(f1, w, S.Reals) == sol1, solveset(f1, w, S.Reals) + assert solveset(f2, w, S.Reals) == sol2, solveset(f2, w, S.Reals) + + assert solveset(x**x, x, Interval.Lopen(0,oo)).dummy_eq( + ConditionSet(w, Eq(w**w, 0), Interval.open(0, oo))) + assert solveset(x**y - 1, y, S.Reals) == FiniteSet(0) + assert solveset(exp(x/y)*exp(-z/y) - 2, y, S.Reals) == \ + Complement(ConditionSet(y, Eq(im(x)/y, 0) & Eq(im(z)/y, 0), \ + Complement(Intersection(FiniteSet((x - z)/log(2)), S.Reals), FiniteSet(0))), FiniteSet(0)) + assert solveset(exp(xr/y)*exp(-zr/y) - 2, y, S.Reals) == \ + Complement(FiniteSet((xr - zr)/log(2)), FiniteSet(0)) + + assert solveset(a**x - b**x, x).dummy_eq(ConditionSet( + w, Ne(a, 0) & Ne(b, 0), FiniteSet(0))) + + +def test_ignore_assumptions(): + # make sure assumptions are ignored + xpos = symbols('x', positive=True) + x = symbols('x') + assert solveset_complex(xpos**2 - 4, xpos + ) == solveset_complex(x**2 - 4, x) + + +@XFAIL +def test_issue_10864(): + assert solveset(x**(y*z) - x, x, S.Reals) == FiniteSet(1) + + +@XFAIL +def test_solve_only_exp_2(): + assert solveset_real(sqrt(exp(x)) + sqrt(exp(-x)) - 4, x) == \ + FiniteSet(2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)) + + +def test_is_exponential(): + assert _is_exponential(y, x) is False + assert _is_exponential(3**x - 2, x) is True + assert _is_exponential(5**x - 7**(2 - x), x) is True + assert _is_exponential(sin(2**x) - 4*x, x) is False + assert _is_exponential(x**y - z, y) is True + assert _is_exponential(x**y - z, x) is False + assert _is_exponential(2**x + 4**x - 1, x) is True + assert _is_exponential(x**(y*z) - x, x) is False + assert _is_exponential(x**(2*x) - 3**x, x) is False + assert _is_exponential(x**y - y*z, y) is False + assert _is_exponential(x**y - x*z, y) is True + + +def test_solve_exponential(): + assert _solve_exponential(3**(2*x) - 2**(x + 3), 0, x, S.Reals) == \ + FiniteSet(-3*log(2)/(-2*log(3) + log(2))) + assert _solve_exponential(2**y + 4**y, 1, y, S.Reals) == \ + FiniteSet(log(Rational(-1, 2) + sqrt(5)/2)/log(2)) + assert _solve_exponential(2**y + 4**y, 0, y, S.Reals) == \ + S.EmptySet + assert _solve_exponential(2**x + 3**x - 5**x, 0, x, S.Reals) == \ + ConditionSet(x, Eq(2**x + 3**x - 5**x, 0), S.Reals) + +# end of exponential tests + + +# logarithmic tests +def test_logarithmic(): + assert solveset_real(log(x - 3) + log(x + 3), x) == FiniteSet( + -sqrt(10), sqrt(10)) + assert solveset_real(log(x + 1) - log(2*x - 1), x) == FiniteSet(2) + assert solveset_real(log(x + 3) + log(1 + 3/x) - 3, x) == FiniteSet( + -3 + sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 + exp(3)/2, + -sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 - 3 + exp(3)/2) + + eq = z - log(x) + log(y/(x*(-1 + y**2/x**2))) + assert solveset_real(eq, x) == \ + Intersection(S.Reals, FiniteSet(-sqrt(y**2 - y*exp(z)), + sqrt(y**2 - y*exp(z)))) - \ + Intersection(S.Reals, FiniteSet(-sqrt(y**2), sqrt(y**2))) + assert solveset_real( + log(3*x) - log(-x + 1) - log(4*x + 1), x) == FiniteSet(Rational(-1, 2), S.Half) + assert solveset(log(x**y) - y*log(x), x, S.Reals) == S.Reals + +@XFAIL +def test_uselogcombine_2(): + eq = log(exp(2*x) + 1) + log(-tanh(x) + 1) - log(2) + assert solveset_real(eq, x) is S.EmptySet + eq = log(8*x) - log(sqrt(x) + 1) - 2 + assert solveset_real(eq, x) is S.EmptySet + + +def test_is_logarithmic(): + assert _is_logarithmic(y, x) is False + assert _is_logarithmic(log(x), x) is True + assert _is_logarithmic(log(x) - 3, x) is True + assert _is_logarithmic(log(x)*log(y), x) is True + assert _is_logarithmic(log(x)**2, x) is False + assert _is_logarithmic(log(x - 3) + log(x + 3), x) is True + assert _is_logarithmic(log(x**y) - y*log(x), x) is True + assert _is_logarithmic(sin(log(x)), x) is False + assert _is_logarithmic(x + y, x) is False + assert _is_logarithmic(log(3*x) - log(1 - x) + 4, x) is True + assert _is_logarithmic(log(x) + log(y) + x, x) is False + assert _is_logarithmic(log(log(x - 3)) + log(x - 3), x) is True + assert _is_logarithmic(log(log(3) + x) + log(x), x) is True + assert _is_logarithmic(log(x)*(y + 3) + log(x), y) is False + + +def test_solve_logarithm(): + y = Symbol('y') + assert _solve_logarithm(log(x**y) - y*log(x), 0, x, S.Reals) == S.Reals + y = Symbol('y', positive=True) + assert _solve_logarithm(log(x)*log(y), 0, x, S.Reals) == FiniteSet(1) + +# end of logarithmic tests + + +# lambert tests +def test_is_lambert(): + a, b, c = symbols('a,b,c') + assert _is_lambert(x**2, x) is False + assert _is_lambert(a**x**2+b*x+c, x) is True + assert _is_lambert(E**2, x) is False + assert _is_lambert(x*E**2, x) is False + assert _is_lambert(3*log(x) - x*log(3), x) is True + assert _is_lambert(log(log(x - 3)) + log(x-3), x) is True + assert _is_lambert(5*x - 1 + 3*exp(2 - 7*x), x) is True + assert _is_lambert((a/x + exp(x/2)).diff(x, 2), x) is True + assert _is_lambert((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1), x) is True + assert _is_lambert(x*sinh(x) - 1, x) is True + assert _is_lambert(x*cos(x) - 5, x) is True + assert _is_lambert(tanh(x) - 5*x, x) is True + assert _is_lambert(cosh(x) - sinh(x), x) is False + +# end of lambert tests + + +def test_linear_coeffs(): + from sympy.solvers.solveset import linear_coeffs + assert linear_coeffs(0, x) == [0, 0] + assert all(i is S.Zero for i in linear_coeffs(0, x)) + assert linear_coeffs(x + 2*y + 3, x, y) == [1, 2, 3] + assert linear_coeffs(x + 2*y + 3, y, x) == [2, 1, 3] + assert linear_coeffs(x + 2*x**2 + 3, x, x**2) == [1, 2, 3] + raises(ValueError, lambda: + linear_coeffs(x + 2*x**2 + x**3, x, x**2)) + raises(ValueError, lambda: + linear_coeffs(1/x*(x - 1) + 1/x, x)) + raises(ValueError, lambda: + linear_coeffs(x, x, x)) + assert linear_coeffs(a*(x + y), x, y) == [a, a, 0] + assert linear_coeffs(1.0, x, y) == [0, 0, 1.0] + # don't include coefficients of 0 + assert linear_coeffs(Eq(x, x + y), x, y, dict=True) == {y: -1} + assert linear_coeffs(0, x, y, dict=True) == {} + + +def test_is_modular(): + assert _is_modular(y, x) is False + assert _is_modular(Mod(x, 3) - 1, x) is True + assert _is_modular(Mod(x**3 - 3*x**2 - x + 1, 3) - 1, x) is True + assert _is_modular(Mod(exp(x + y), 3) - 2, x) is True + assert _is_modular(Mod(exp(x + y), 3) - log(x), x) is True + assert _is_modular(Mod(x, 3) - 1, y) is False + assert _is_modular(Mod(x, 3)**2 - 5, x) is False + assert _is_modular(Mod(x, 3)**2 - y, x) is False + assert _is_modular(exp(Mod(x, 3)) - 1, x) is False + assert _is_modular(Mod(3, y) - 1, y) is False + + +def test_invert_modular(): + n = Dummy('n', integer=True) + from sympy.solvers.solveset import _invert_modular as invert_modular + + # no solutions + assert invert_modular(Mod(x, 12), S(1)/2, n, x) == (x, S.EmptySet) + # non invertible cases + assert invert_modular(Mod(sin(x), 7), S(5), n, x) == (Mod(sin(x), 7), 5) + assert invert_modular(Mod(exp(x), 7), S(5), n, x) == (Mod(exp(x), 7), 5) + assert invert_modular(Mod(log(x), 7), S(5), n, x) == (Mod(log(x), 7), 5) + # a is symbol + assert dumeq(invert_modular(Mod(x, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 5), S.Integers))) + # a.is_Add + assert dumeq(invert_modular(Mod(x + 8, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 4), S.Integers))) + assert invert_modular(Mod(x**2 + x, 7), S(5), n, x) == \ + (Mod(x**2 + x, 7), 5) + # a.is_Mul + assert dumeq(invert_modular(Mod(3*x, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 4), S.Integers))) + assert invert_modular(Mod((x + 1)*(x + 2), 7), S(5), n, x) == \ + (Mod((x + 1)*(x + 2), 7), 5) + # a.is_Pow + assert invert_modular(Mod(x**4, 7), S(5), n, x) == \ + (x, S.EmptySet) + assert dumeq(invert_modular(Mod(3**x, 4), S(3), n, x), + (x, ImageSet(Lambda(n, 2*n + 1), S.Naturals0))) + assert dumeq(invert_modular(Mod(2**(x**2 + x + 1), 7), S(2), n, x), + (x**2 + x + 1, ImageSet(Lambda(n, 3*n + 1), S.Naturals0))) + assert invert_modular(Mod(sin(x)**4, 7), S(5), n, x) == (x, S.EmptySet) + + +def test_solve_modular(): + n = Dummy('n', integer=True) + # if rhs has symbol (need to be implemented in future). + assert solveset(Mod(x, 4) - x, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(-x + Mod(x, 4), 0), + S.Integers)) + # when _invert_modular fails to invert + assert solveset(3 - Mod(sin(x), 7), x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(sin(x), 7) - 3, 0), S.Integers)) + assert solveset(3 - Mod(log(x), 7), x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(log(x), 7) - 3, 0), S.Integers)) + assert solveset(3 - Mod(exp(x), 7), x, S.Integers + ).dummy_eq(ConditionSet(x, Eq(Mod(exp(x), 7) - 3, 0), + S.Integers)) + # EmptySet solution definitely + assert solveset(7 - Mod(x, 5), x, S.Integers) is S.EmptySet + assert solveset(5 - Mod(x, 5), x, S.Integers) is S.EmptySet + # Negative m + assert dumeq(solveset(2 + Mod(x, -3), x, S.Integers), + ImageSet(Lambda(n, -3*n - 2), S.Integers)) + assert solveset(4 + Mod(x, -3), x, S.Integers) is S.EmptySet + # linear expression in Mod + assert dumeq(solveset(3 - Mod(x, 5), x, S.Integers), + ImageSet(Lambda(n, 5*n + 3), S.Integers)) + assert dumeq(solveset(3 - Mod(5*x - 8, 7), x, S.Integers), + ImageSet(Lambda(n, 7*n + 5), S.Integers)) + assert dumeq(solveset(3 - Mod(5*x, 7), x, S.Integers), + ImageSet(Lambda(n, 7*n + 2), S.Integers)) + # higher degree expression in Mod + assert dumeq(solveset(Mod(x**2, 160) - 9, x, S.Integers), + Union(ImageSet(Lambda(n, 160*n + 3), S.Integers), + ImageSet(Lambda(n, 160*n + 13), S.Integers), + ImageSet(Lambda(n, 160*n + 67), S.Integers), + ImageSet(Lambda(n, 160*n + 77), S.Integers), + ImageSet(Lambda(n, 160*n + 83), S.Integers), + ImageSet(Lambda(n, 160*n + 93), S.Integers), + ImageSet(Lambda(n, 160*n + 147), S.Integers), + ImageSet(Lambda(n, 160*n + 157), S.Integers))) + assert solveset(3 - Mod(x**4, 7), x, S.Integers) is S.EmptySet + assert dumeq(solveset(Mod(x**4, 17) - 13, x, S.Integers), + Union(ImageSet(Lambda(n, 17*n + 3), S.Integers), + ImageSet(Lambda(n, 17*n + 5), S.Integers), + ImageSet(Lambda(n, 17*n + 12), S.Integers), + ImageSet(Lambda(n, 17*n + 14), S.Integers))) + # a.is_Pow tests + assert dumeq(solveset(Mod(7**x, 41) - 15, x, S.Integers), + ImageSet(Lambda(n, 40*n + 3), S.Naturals0)) + assert dumeq(solveset(Mod(12**x, 21) - 18, x, S.Integers), + ImageSet(Lambda(n, 6*n + 2), S.Naturals0)) + assert dumeq(solveset(Mod(3**x, 4) - 3, x, S.Integers), + ImageSet(Lambda(n, 2*n + 1), S.Naturals0)) + assert dumeq(solveset(Mod(2**x, 7) - 2 , x, S.Integers), + ImageSet(Lambda(n, 3*n + 1), S.Naturals0)) + assert dumeq(solveset(Mod(3**(3**x), 4) - 3, x, S.Integers), + Intersection(ImageSet(Lambda(n, Intersection({log(2*n + 1)/log(3)}, + S.Integers)), S.Naturals0), S.Integers)) + # Implemented for m without primitive root + assert solveset(Mod(x**3, 7) - 2, x, S.Integers) is S.EmptySet + assert dumeq(solveset(Mod(x**3, 8) - 1, x, S.Integers), + ImageSet(Lambda(n, 8*n + 1), S.Integers)) + assert dumeq(solveset(Mod(x**4, 9) - 4, x, S.Integers), + Union(ImageSet(Lambda(n, 9*n + 4), S.Integers), + ImageSet(Lambda(n, 9*n + 5), S.Integers))) + # domain intersection + assert dumeq(solveset(3 - Mod(5*x - 8, 7), x, S.Naturals0), + Intersection(ImageSet(Lambda(n, 7*n + 5), S.Integers), S.Naturals0)) + # Complex args + assert solveset(Mod(x, 3) - I, x, S.Integers) == \ + S.EmptySet + assert solveset(Mod(I*x, 3) - 2, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(I*x, 3) - 2, 0), S.Integers)) + assert solveset(Mod(I + x, 3) - 2, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(x + I, 3) - 2, 0), S.Integers)) + + # issue 17373 (https://github.com/sympy/sympy/issues/17373) + assert dumeq(solveset(Mod(x**4, 14) - 11, x, S.Integers), + Union(ImageSet(Lambda(n, 14*n + 3), S.Integers), + ImageSet(Lambda(n, 14*n + 11), S.Integers))) + assert dumeq(solveset(Mod(x**31, 74) - 43, x, S.Integers), + ImageSet(Lambda(n, 74*n + 31), S.Integers)) + + # issue 13178 + n = symbols('n', integer=True) + a = 742938285 + b = 1898888478 + m = 2**31 - 1 + c = 20170816 + assert dumeq(solveset(c - Mod(a**n*b, m), n, S.Integers), + ImageSet(Lambda(n, 2147483646*n + 100), S.Naturals0)) + assert dumeq(solveset(c - Mod(a**n*b, m), n, S.Naturals0), + Intersection(ImageSet(Lambda(n, 2147483646*n + 100), S.Naturals0), + S.Naturals0)) + assert dumeq(solveset(c - Mod(a**(2*n)*b, m), n, S.Integers), + Intersection(ImageSet(Lambda(n, 1073741823*n + 50), S.Naturals0), + S.Integers)) + assert solveset(c - Mod(a**(2*n + 7)*b, m), n, S.Integers) is S.EmptySet + assert dumeq(solveset(c - Mod(a**(n - 4)*b, m), n, S.Integers), + Intersection(ImageSet(Lambda(n, 2147483646*n + 104), S.Naturals0), + S.Integers)) + +# end of modular tests + +def test_issue_17276(): + assert nonlinsolve([Eq(x, 5**(S(1)/5)), Eq(x*y, 25*sqrt(5))], x, y) == \ + FiniteSet((5**(S(1)/5), 25*5**(S(3)/10))) + + +def test_issue_10426(): + x = Dummy('x') + a = Symbol('a') + n = Dummy('n') + assert (solveset(sin(x + a) - sin(x), a)).dummy_eq(Dummy('x')) == (Union( + ImageSet(Lambda(n, 2*n*pi), S.Integers), + Intersection(S.Complexes, ImageSet(Lambda(n, -I*(I*(2*n*pi + arg(-exp(-2*I*x))) + 2*im(x))), + S.Integers)))).dummy_eq(Dummy('x,n')) + + +def test_solveset_conjugate(): + """Test solveset for simple conjugate functions""" + assert solveset(conjugate(x) -3 + I) == FiniteSet(3 + I) + + +def test_issue_18208(): + variables = symbols('x0:16') + symbols('y0:12') + x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15,\ + y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11 = variables + + eqs = [x0 + x1 + x2 + x3 - 51, + x0 + x1 + x4 + x5 - 46, + x2 + x3 + x6 + x7 - 39, + x0 + x3 + x4 + x7 - 50, + x1 + x2 + x5 + x6 - 35, + x4 + x5 + x6 + x7 - 34, + x4 + x5 + x8 + x9 - 46, + x10 + x11 + x6 + x7 - 23, + x11 + x4 + x7 + x8 - 25, + x10 + x5 + x6 + x9 - 44, + x10 + x11 + x8 + x9 - 35, + x12 + x13 + x8 + x9 - 35, + x10 + x11 + x14 + x15 - 29, + x11 + x12 + x15 + x8 - 35, + x10 + x13 + x14 + x9 - 29, + x12 + x13 + x14 + x15 - 29, + y0 + y1 + y2 + y3 - 55, + y0 + y1 + y4 + y5 - 53, + y2 + y3 + y6 + y7 - 56, + y0 + y3 + y4 + y7 - 57, + y1 + y2 + y5 + y6 - 52, + y4 + y5 + y6 + y7 - 54, + y4 + y5 + y8 + y9 - 48, + y10 + y11 + y6 + y7 - 60, + y11 + y4 + y7 + y8 - 51, + y10 + y5 + y6 + y9 - 57, + y10 + y11 + y8 + y9 - 54, + x10 - 2, + x11 - 5, + x12 - 1, + x13 - 6, + x14 - 1, + x15 - 21, + y0 - 12, + y1 - 20] + + expected = [38 - x3, x3 - 10, 23 - x3, x3, 12 - x7, x7 + 6, 16 - x7, x7, + 8, 20, 2, 5, 1, 6, 1, 21, 12, 20, -y11 + y9 + 2, y11 - y9 + 21, + -y11 - y7 + y9 + 24, y11 + y7 - y9 - 3, 33 - y7, y7, 27 - y9, y9, + 27 - y11, y11] + + A, b = linear_eq_to_matrix(eqs, variables) + + # solve + solve_expected = {v:eq for v, eq in zip(variables, expected) if v != eq} + + assert solve(eqs, variables) == solve_expected + + # linsolve + linsolve_expected = FiniteSet(Tuple(*expected)) + + assert linsolve(eqs, variables) == linsolve_expected + assert linsolve((A, b), variables) == linsolve_expected + + # gauss_jordan_solve + gj_solve, new_vars = A.gauss_jordan_solve(b) + gj_solve = list(gj_solve) + + gj_expected = linsolve_expected.subs(zip([x3, x7, y7, y9, y11], new_vars)) + + assert FiniteSet(Tuple(*gj_solve)) == gj_expected + + # nonlinsolve + # The solution set of nonlinsolve is currently equivalent to linsolve and is + # also correct. However, we would prefer to use the same symbols as parameters + # for the solution to the underdetermined system in all cases if possible. + # We want a solution that is not just equivalent but also given in the same form. + # This test may be changed should nonlinsolve be modified in this way. + + nonlinsolve_expected = FiniteSet((38 - x3, x3 - 10, 23 - x3, x3, 12 - x7, x7 + 6, + 16 - x7, x7, 8, 20, 2, 5, 1, 6, 1, 21, 12, 20, + -y5 + y7 - 1, y5 - y7 + 24, 21 - y5, y5, 33 - y7, + y7, 27 - y9, y9, -y5 + y7 - y9 + 24, y5 - y7 + y9 + 3)) + + assert nonlinsolve(eqs, variables) == nonlinsolve_expected + + +def test_substitution_with_infeasible_solution(): + a00, a01, a10, a11, l0, l1, l2, l3, m0, m1, m2, m3, m4, m5, m6, m7, c00, c01, c10, c11, p00, p01, p10, p11 = symbols( + 'a00, a01, a10, a11, l0, l1, l2, l3, m0, m1, m2, m3, m4, m5, m6, m7, c00, c01, c10, c11, p00, p01, p10, p11' + ) + solvefor = [p00, p01, p10, p11, c00, c01, c10, c11, m0, m1, m3, l0, l1, l2, l3] + system = [ + -l0 * c00 - l1 * c01 + m0 + c00 + c01, + -l0 * c10 - l1 * c11 + m1, + -l2 * c00 - l3 * c01 + c00 + c01, + -l2 * c10 - l3 * c11 + m3, + -l0 * p00 - l2 * p10 + p00 + p10, + -l1 * p00 - l3 * p10 + p00 + p10, + -l0 * p01 - l2 * p11, + -l1 * p01 - l3 * p11, + -a00 + c00 * p00 + c10 * p01, + -a01 + c01 * p00 + c11 * p01, + -a10 + c00 * p10 + c10 * p11, + -a11 + c01 * p10 + c11 * p11, + -m0 * p00, + -m1 * p01, + -m2 * p10, + -m3 * p11, + -m4 * c00, + -m5 * c01, + -m6 * c10, + -m7 * c11, + m2, + m4, + m5, + m6, + m7 + ] + sol = FiniteSet( + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, l2, l3), + (p00, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, 1, 1, -p01/p11, -p01/p11), + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, 1, -l3*p11/p01, -p01/p11, l3), + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, -l2*p11/p01, -l3*p11/p01, l2, l3), + ) + assert sol != nonlinsolve(system, solvefor) + + +def test_issue_20097(): + assert solveset(1/sqrt(x)) is S.EmptySet + + +def test_issue_15350(): + assert solveset(diff(sqrt(1/x+x))) == FiniteSet(-1, 1) + + +def test_issue_18359(): + c1 = Piecewise((0, x < 0), (Min(1, x)/2 - Min(2, x)/2 + Min(3, x)/2, True)) + c2 = Piecewise((Piecewise((0, x < 0), (Min(1, x)/2 - Min(2, x)/2 + Min(3, x)/2, True)), x >= 0), (0, True)) + correct_result = Interval(1, 2) + result1 = solveset(c1 - Rational(1, 2), x, Interval(0, 3)) + result2 = solveset(c2 - Rational(1, 2), x, Interval(0, 3)) + assert result1 == correct_result + assert result2 == correct_result + + +def test_issue_17604(): + lhs = -2**(3*x/11)*exp(x/11) + pi**(x/11) + assert _is_exponential(lhs, x) + assert _solve_exponential(lhs, 0, x, S.Complexes) == FiniteSet(0) + + +def test_issue_17580(): + assert solveset(1/(1 - x**3)**2, x, S.Reals) is S.EmptySet + + +def test_issue_17566_actual(): + sys = [2**x + 2**y - 3, 4**x + 9**y - 5] + # Not clear this is the correct result, but at least no recursion error + assert nonlinsolve(sys, x, y) == FiniteSet((log(3 - 2**y)/log(2), y)) + + +def test_issue_17565(): + eq = Ge(2*(x - 2)**2/(3*(x + 1)**(Integer(1)/3)) + 2*(x - 2)*(x + 1)**(Integer(2)/3), 0) + res = Union(Interval.Lopen(-1, -Rational(1, 4)), Interval(2, oo)) + assert solveset(eq, x, S.Reals) == res + + +def test_issue_15024(): + function = (x + 5)/sqrt(-x**2 - 10*x) + assert solveset(function, x, S.Reals) == FiniteSet(Integer(-5)) + + +def test_issue_16877(): + assert dumeq(nonlinsolve([x - 1, sin(y)], x, y), + FiniteSet((1, ImageSet(Lambda(n, 2*n*pi), S.Integers)), + (1, ImageSet(Lambda(n, 2*n*pi + pi), S.Integers)))) + # Even better if (1, ImageSet(Lambda(n, n*pi), S.Integers)) is obtained + + +def test_issue_16876(): + assert dumeq(nonlinsolve([sin(x), 2*x - 4*y], x, y), + FiniteSet((ImageSet(Lambda(n, 2*n*pi), S.Integers), + ImageSet(Lambda(n, n*pi), S.Integers)), + (ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), + ImageSet(Lambda(n, n*pi + pi/2), S.Integers)))) + # Even better if (ImageSet(Lambda(n, n*pi), S.Integers), + # ImageSet(Lambda(n, n*pi/2), S.Integers)) is obtained + +def test_issue_21236(): + x, z = symbols("x z") + y = symbols('y', rational=True) + assert solveset(x**y - z, x, S.Reals) == ConditionSet(x, Eq(x**y - z, 0), S.Reals) + e1, e2 = symbols('e1 e2', even=True) + y = e1/e2 # don't know if num or den will be odd and the other even + assert solveset(x**y - z, x, S.Reals) == ConditionSet(x, Eq(x**y - z, 0), S.Reals) + + +def test_issue_21908(): + assert nonlinsolve([(x**2 + 2*x - y**2)*exp(x), -2*y*exp(x)], x, y + ) == {(-2, 0), (0, 0)} + + +def test_issue_19144(): + # test case 1 + expr1 = [x + y - 1, y**2 + 1] + eq1 = [Eq(i, 0) for i in expr1] + soln1 = {(1 - I, I), (1 + I, -I)} + soln_expr1 = nonlinsolve(expr1, [x, y]) + soln_eq1 = nonlinsolve(eq1, [x, y]) + assert soln_eq1 == soln_expr1 == soln1 + # test case 2 - with denoms + expr2 = [x/y - 1, y**2 + 1] + eq2 = [Eq(i, 0) for i in expr2] + soln2 = {(-I, -I), (I, I)} + soln_expr2 = nonlinsolve(expr2, [x, y]) + soln_eq2 = nonlinsolve(eq2, [x, y]) + assert soln_eq2 == soln_expr2 == soln2 + # denominators that cancel in expression + assert nonlinsolve([Eq(x + 1/x, 1/x)], [x]) == FiniteSet((S.EmptySet,)) + + +def test_issue_22413(): + res = nonlinsolve((4*y*(2*x + 2*exp(y) + 1)*exp(2*x), + 4*x*exp(2*x) + 4*y*exp(2*x + y) + 4*exp(2*x + y) + 1), + x, y) + # First solution is not correct, but the issue was an exception + sols = FiniteSet((x, S.Zero), (-exp(y) - S.Half, y)) + assert res == sols + + +def test_issue_23318(): + eqs_eq = [ + Eq(53.5780461486929, x * log(y / (5.0 - y) + 1) / y), + Eq(x, 0.0015 * z), + Eq(0.0015, 7845.32 * y / z), + ] + eqs_expr = [eq.lhs - eq.rhs for eq in eqs_eq] + + sol = {(266.97755814852, 0.0340301680681629, 177985.03876568)} + + assert_close_nl(nonlinsolve(eqs_eq, [x, y, z]), sol) + assert_close_nl(nonlinsolve(eqs_expr, [x, y, z]), sol) + + logterm = log(1.91196789933362e-7*z/(5.0 - 1.91196789933362e-7*z) + 1) + eq = -0.0015*z*logterm + 1.02439504345316e-5*z + assert_close_ss(solveset(eq, z), {0, 177985.038765679}) + + +def test_issue_19814(): + assert nonlinsolve([ 2**m - 2**(2*n), 4*2**m - 2**(4*n)], m, n + ) == FiniteSet((log(2**(2*n))/log(2), S.Complexes)) + + +def test_issue_22058(): + sol = solveset(-sqrt(t)*x**2 + 2*x + sqrt(t), x, S.Reals) + # doesn't fail (and following numerical check) + assert sol.xreplace({t: 1}) == {1 - sqrt(2), 1 + sqrt(2)}, sol.xreplace({t: 1}) + + +def test_issue_11184(): + assert solveset(20*sqrt(y**2 + (sqrt(-(y - 10)*(y + 10)) + 10)**2) - 60, y, S.Reals) is S.EmptySet + + +def test_issue_21890(): + e = S(2)/3 + assert nonlinsolve([4*x**3*y**4 - 2*y, 4*x**4*y**3 - 2*x], x, y) == { + (2**e/(2*y), y), ((-2**e/4 - 2**e*sqrt(3)*I/4)/y, y), + ((-2**e/4 + 2**e*sqrt(3)*I/4)/y, y)} + assert nonlinsolve([(1 - 4*x**2)*exp(-2*x**2 - 2*y**2), + -4*x*y*exp(-2*x**2)*exp(-2*y**2)], x, y) == {(-S(1)/2, 0), (S(1)/2, 0)} + rx, ry = symbols('x y', real=True) + sol = nonlinsolve([4*rx**3*ry**4 - 2*ry, 4*rx**4*ry**3 - 2*rx], rx, ry) + ans = {(2**(S(2)/3)/(2*ry), ry), + ((-2**(S(2)/3)/4 - 2**(S(2)/3)*sqrt(3)*I/4)/ry, ry), + ((-2**(S(2)/3)/4 + 2**(S(2)/3)*sqrt(3)*I/4)/ry, ry)} + assert sol == ans + + +def test_issue_22628(): + assert nonlinsolve([h - 1, k - 1, f - 2, f - 4, -2*k], h, k, f) == S.EmptySet + assert nonlinsolve([x**3 - 1, x + y, x**2 - 4], [x, y]) == S.EmptySet + + +def test_issue_25781(): + assert solve(sqrt(x/2) - x) == [0, S.Half] + + +def test_issue_26077(): + _n = Symbol('_n') + function = x*cot(5*x) + critical_points = stationary_points(function, x, S.Reals) + excluded_points = Union( + ImageSet(Lambda(_n, 2*_n*pi/5), S.Integers), + ImageSet(Lambda(_n, 2*_n*pi/5 + pi/5), S.Integers) + ) + solution = ConditionSet(x, + Eq(x*(-5*cot(5*x)**2 - 5) + cot(5*x), 0), + Complement(S.Reals, excluded_points) + ) + assert solution.as_dummy() == critical_points.as_dummy()