"""
Database operation and rollback tests for chat functionality.

Tests database transactions, connection pooling, and rollback scenarios.
"""

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError

from backend.src.main import app
from backend.src.core.database import get_session, get_engine
from backend.src.models.conversation import Conversation, ConversationCreate
# NOTE(review): module path for Message is assumed; confirm where the Message
# table model is defined (it may live next to Conversation instead).
from backend.src.models.message import Message
from backend.src.auth.security import create_access_token

# Fixed identity shared by every test in this module; the token is sent as a
# Bearer credential on API calls.
TEST_USER_ID = "test_user_123"
TEST_TOKEN = create_access_token(data={"user_id": TEST_USER_ID})

# Test client for the FastAPI app under test.
client = TestClient(app)

@pytest.fixture
def cleanup_conversations():
    """Delete the test user's conversations and messages around each test.

    Cleanup runs before the test, so leftovers from an aborted run cannot skew
    row counts. It runs again after the test: the fixture yields, then purges.

    NOTE(review): assumes ``get_session()`` returns a ready Session object. If
    it is a FastAPI generator dependency, obtain the session with
    ``next(get_session())`` instead.
    """

    def _purge():
        session = get_session()
        try:
            # Delete messages before the conversations their conversation_id points at.
            session.exec(
                text(
                    "DELETE FROM message WHERE conversation_id IN "
                    "(SELECT id FROM conversation WHERE user_id = :user_id)"
                ),
                params={"user_id": TEST_USER_ID},
            )
            session.exec(
                text("DELETE FROM conversation WHERE user_id = :user_id"),
                params={"user_id": TEST_USER_ID},
            )
            session.commit()
        finally:
            session.close()

    _purge()
    yield
    _purge()

@pytest.fixture
def test_conversation(cleanup_conversations):
    """Create a conversation through the chat API and return its id.

    Depends on ``cleanup_conversations`` so the test user starts from a clean
    slate.
    """
    response = client.post(
        "/api/v1/chat/",
        json={"title": "Test Conversation"},
        headers={"Authorization": f"Bearer {TEST_TOKEN}"},
    )
    # Surface the API's error body instead of an opaque KeyError on "id".
    assert 200 <= response.status_code < 300, response.text
    return response.json()["id"]

| class TestDatabaseOperations: |
| """Test database operations, transactions, and rollback scenarios""" |
|
|
| def test_database_connection_pooling(self, cleanup_conversations): |
| """Test database connection pooling""" |
| |
| sessions = [] |
| for i in range(10): |
| session = get_session() |
| sessions.append(session) |
|
|
| |
| for session in sessions: |
| assert session.is_active |
|
|
| |
| for session in sessions: |
| session.close() |
|
|
| def test_transaction_atomicity(self, cleanup_conversations): |
| """Test that transactions are atomic""" |
| |
| session = get_session() |
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Atomic Test" |
| ) |
| session.add(conversation) |
| session.flush() |
|
|
| |
| for i in range(3): |
| message = Conversation( |
| conversation_id=conversation.id, |
| content=f"Message {i}", |
| sender="user" |
| ) |
| session.add(message) |
|
|
| |
| session.commit() |
|
|
| |
| db_conversation = session.get(Conversation, conversation.id) |
| db_messages = session.exec( |
| "SELECT * FROM message WHERE conversation_id = :conversation_id", |
| {"conversation_id": conversation.id} |
| ).all() |
|
|
| assert db_conversation is not None |
| assert len(db_messages) == 3 |
|
|
| finally: |
| session.close() |
|
|
| def test_rollback_on_error(self, cleanup_conversations): |
| """Test that transactions are rolled back on error""" |
| session = get_session() |
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Rollback Test" |
| ) |
| session.add(conversation) |
| session.flush() |
|
|
| |
| message = Message( |
| conversation_id=conversation.id, |
| content="" * 10001, |
| sender="user" |
| ) |
| session.add(message) |
|
|
| |
| try: |
| session.commit() |
| assert False, "Commit should have failed" |
| except Exception: |
| |
| session.rollback() |
|
|
| |
| db_conversation = session.get(Conversation, conversation.id) |
| db_messages = session.exec( |
| "SELECT * FROM message WHERE conversation_id = :conversation_id", |
| {"conversation_id": conversation.id} |
| ).all() |
|
|
| assert db_conversation is None |
| assert len(db_messages) == 0 |
|
|
| finally: |
| session.close() |
|
|
| def test_concurrent_transaction_isolation(self, cleanup_conversations): |
| """Test transaction isolation levels""" |
| session1 = get_session() |
| session2 = get_session() |
|
|
| try: |
| |
| conversation1 = Conversation( |
| user_id=TEST_USER_ID, |
| title="Session 1" |
| ) |
| session1.add(conversation1) |
| session1.flush() |
|
|
| |
| db_conversation2 = session2.get(Conversation, conversation1.id) |
| assert db_conversation2 is None |
|
|
| |
| session1.commit() |
|
|
| |
| db_conversation2 = session2.get(Conversation, conversation1.id) |
| assert db_conversation2 is not None |
|
|
| finally: |
| session1.close() |
| session2.close() |
|
|
| def test_deadlock_detection_and_resolution(self, cleanup_conversations): |
| """Test deadlock detection and resolution""" |
| session1 = get_session() |
| session2 = get_session() |
|
|
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Deadlock Test" |
| ) |
| session1.add(conversation) |
| session1.commit() |
|
|
| |
| try: |
| |
| session1.begin() |
| conv1 = session1.get(Conversation, conversation.id) |
| conv1.title = "Session 1 Locked" |
|
|
| |
| session2.begin() |
| conv2 = session2.get(Conversation, conversation.id) |
| conv2.title = "Session 2 Locked" |
|
|
| |
| session2.commit() |
|
|
| |
| session1.commit() |
| assert False, "Session 1 commit should have failed" |
|
|
| except Exception as e: |
| |
| session1.rollback() |
|
|
| |
| final_session = get_session() |
| final_conv = final_session.get(Conversation, conversation.id) |
| assert final_conv.title == "Session 2 Locked" |
| final_session.close() |
|
|
| finally: |
| session1.close() |
| session2.close() |
|
|
| def test_savepoint_rollback(self, cleanup_conversations): |
| """Test savepoint rollback functionality""" |
| session = get_session() |
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Savepoint Test" |
| ) |
| session.add(conversation) |
| session.commit() |
|
|
| |
| session.begin() |
| session.add(Conversation( |
| conversation_id=conversation.id, |
| content="Message 1", |
| sender="user" |
| )) |
|
|
| |
| session.begin_nested() |
| session.add(Conversation( |
| conversation_id=conversation.id, |
| content="Message 2", |
| sender="user" |
| )) |
|
|
| |
| session.rollback() |
|
|
| |
| session.commit() |
|
|
| |
| final_session = get_session() |
| messages = final_session.exec( |
| "SELECT * FROM message WHERE conversation_id = :conversation_id", |
| {"conversation_id": conversation.id} |
| ).all() |
|
|
| assert len(messages) == 1 |
| assert messages[0].content == "Message 1" |
| final_session.close() |
|
|
| finally: |
| session.close() |
|
|
| def test_connection_reuse(self, cleanup_conversations): |
| """Test connection reuse in connection pool""" |
| |
| session1 = get_session() |
| session2 = get_session() |
|
|
| |
| assert session1.bind.url != session2.bind.url |
|
|
| |
| session1.close() |
| session2.close() |
|
|
| session3 = get_session() |
| session4 = get_session() |
|
|
| |
| |
|
|
| session3.close() |
| session4.close() |
|
|
| def test_database_schema_validation(self, cleanup_conversations): |
| """Test database schema validation""" |
| session = get_session() |
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Schema Test" |
| ) |
| session.add(conversation) |
| session.commit() |
|
|
| |
| message = Message( |
| conversation_id=conversation.id, |
| content="Test message", |
| sender="user" |
| ) |
| session.add(message) |
| session.commit() |
|
|
| |
| |
|
|
| finally: |
| session.close() |
|
|
| def test_batch_operations(self, cleanup_conversations): |
| """Test batch database operations""" |
| session = get_session() |
| try: |
| |
| conversations = [] |
| for i in range(100): |
| conversations.append(Conversation( |
| user_id=TEST_USER_ID, |
| title=f"Batch Conversation {i}" |
| )) |
| session.add_all(conversations) |
| session.commit() |
|
|
| |
| db_conversations = session.exec( |
| "SELECT * FROM conversation WHERE user_id = :user_id", |
| {"user_id": TEST_USER_ID} |
| ).all() |
| assert len(db_conversations) == 100 |
|
|
| finally: |
| session.close() |
|
|
| def test_large_data_handling(self, cleanup_conversations): |
| """Test handling of large data volumes""" |
| session = get_session() |
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Large Data Test" |
| ) |
| session.add(conversation) |
| session.commit() |
|
|
| |
| messages = [] |
| for i in range(1000): |
| messages.append(Message( |
| conversation_id=conversation.id, |
| content=f"Message {i}", |
| sender="user" |
| )) |
| session.add_all(messages) |
| session.commit() |
|
|
| |
| db_messages = session.exec( |
| "SELECT * FROM message WHERE conversation_id = :conversation_id", |
| {"conversation_id": conversation.id} |
| ).all() |
| assert len(db_messages) == 1000 |
|
|
| finally: |
| session.close() |
|
|
| def test_database_error_handling(self, cleanup_conversations): |
| """Test graceful handling of database errors""" |
| |
| session = get_session() |
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Duplicate Test" |
| ) |
| session.add(conversation) |
| session.commit() |
|
|
| |
| duplicate = Conversation( |
| user_id=TEST_USER_ID, |
| title="Duplicate Test" |
| ) |
| session.add(duplicate) |
|
|
| try: |
| session.commit() |
| assert False, "Commit should have failed due to duplicate" |
| except IntegrityError: |
| session.rollback() |
| assert True |
|
|
| finally: |
| session.close() |
|
|
| def test_connection_timeout_handling(self, cleanup_conversations): |
| """Test handling of connection timeouts""" |
| |
| |
| pass |
|
|
| def test_read_only_transactions(self, cleanup_conversations): |
| """Test read-only transactions""" |
| session = get_session() |
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Read Only Test" |
| ) |
| session.add(conversation) |
| session.commit() |
|
|
| |
| session.begin() |
| db_conversation = session.get(Conversation, conversation.id) |
| assert db_conversation is not None |
|
|
| |
| db_conversation.title = "Modified Title" |
| try: |
| session.commit() |
| assert False, "Commit should have failed in read-only transaction" |
| except Exception: |
| session.rollback() |
| assert True |
|
|
| finally: |
| session.close() |
|
|
| def test_transaction_retry_on_deadlock(self, cleanup_conversations): |
| """Test automatic retry on deadlock""" |
| |
| |
| pass |
|
|
| def test_database_migration_rollback(self, cleanup_conversations): |
| """Test database migration rollback""" |
| |
| |
| pass |
|
|
| def test_data_consistency_checks(self, cleanup_conversations): |
| """Test data consistency checks""" |
| session = get_session() |
| try: |
| |
| conversation = Conversation( |
| user_id=TEST_USER_ID, |
| title="Consistency Test" |
| ) |
| session.add(conversation) |
| session.commit() |
|
|
| |
| messages = [] |
| for i in range(5): |
| messages.append(Message( |
| conversation_id=conversation.id, |
| content=f"Message {i}", |
| sender="user" |
| )) |
| session.add_all(messages) |
| session.commit() |
|
|
| |
| db_conversation = session.get(Conversation, conversation.id) |
| db_messages = session.exec( |
| "SELECT * FROM message WHERE conversation_id = :conversation_id", |
| {"conversation_id": conversation.id} |
| ).all() |
|
|
| assert db_conversation is not None |
| assert len(db_messages) == 5 |
| assert all(msg.conversation_id == conversation.id for msg in db_messages) |
|
|
| finally: |
| session.close() |
|
|
| def test_database_performance_under_load(self, cleanup_conversations): |
| """Test database performance under load""" |
| |
| |
| pass |