| from __future__ import annotations |
|
|
| import threading |
| import traceback |
| from types import TracebackType |
| from typing import Any |
| from typing import Callable |
| from typing import Generator |
| from typing import TYPE_CHECKING |
| import warnings |
|
|
| import pytest |
|
|
|
|
| if TYPE_CHECKING: |
| from typing_extensions import Self |
|
|
|
|
| |
class catch_threading_exception:
    """Context manager that records an exception raised inside a thread.

    While the manager is active, ``threading.excepthook`` is replaced by a
    hook that stores the ``threading.ExceptHookArgs`` it receives on
    ``self.args``. A later exception overwrites an earlier one, so only the
    most recent is kept. On exit the previously installed hook is put back
    and ``self.args`` is deleted.

    Why ``args`` is deleted on exit:

    * Storing ``exc_value`` from a custom hook can create a reference cycle;
      removing the attribute breaks that cycle explicitly.
    * Storing ``thread`` can resurrect a thread object that is being
      finalized; removing the attribute drops that reference.

    Usage::

        with catch_threading_exception() as cm:
            # code spawning a thread which raises an exception
            ...
            # check the thread exception: use cm.args
            ...
        # cm.args attribute no longer exists at this point
        # (to break a reference cycle)
    """

    def __init__(self) -> None:
        # Filled in by _hook; stays None until a thread exception arrives.
        self.args: threading.ExceptHookArgs | None = None
        # Hook that was active before __enter__; reinstated by __exit__.
        self._old_hook: Callable[[threading.ExceptHookArgs], Any] | None = None

    def _hook(self, args: threading.ExceptHookArgs) -> None:
        """Replacement excepthook: remember the latest thread exception."""
        self.args = args

    def __enter__(self) -> Self:
        """Install the recording hook, remembering the one it replaces."""
        current_hook = threading.excepthook
        self._old_hook = current_hook
        threading.excepthook = self._hook
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Restore the previous hook and drop the captured arguments.

        Never suppresses an exception raised inside the ``with`` body.
        """
        saved_hook = self._old_hook
        assert saved_hook is not None
        threading.excepthook = saved_hook
        self._old_hook = None
        # Break the exc_value/traceback cycle and release the thread object.
        del self.args
|
|
|
|
def thread_exception_runtest_hook() -> Generator[None]:
    """Run one test phase and warn about an exception raised in a thread.

    A hook wrapper is meant to drive this generator with ``yield from``. Its
    single ``yield`` is where the wrapped hook implementations run. While they
    run, thread exceptions are captured with ``catch_threading_exception``.
    If one was captured, a ``pytest.PytestUnhandledThreadExceptionWarning``
    is issued. Its message holds the thread name and the formatted traceback.

    The check sits in ``finally``, so the warning is issued even when the
    phase itself raises.
    """
    with catch_threading_exception() as cm:
        try:
            yield
        finally:
            # Snapshot cm.args once. Other threads may still raise and
            # overwrite it, and re-reading it per field could mix details of
            # different exceptions in one message. The read also has to
            # happen inside the ``with``, because the context manager's exit
            # deletes ``cm.args``.
            hook_args = cm.args
            if hook_args is not None:
                thread = hook_args.thread
                # ExceptHookArgs.thread may be None; use a placeholder name.
                thread_name = "<unknown>" if thread is None else thread.name
                formatted_tb = "".join(
                    traceback.format_exception(
                        hook_args.exc_type,
                        hook_args.exc_value,
                        hook_args.exc_traceback,
                    )
                )
                msg = f"Exception in thread {thread_name}\n\n{formatted_tb}"
                warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
|
|
|
|
@pytest.hookimpl(wrapper=True, trylast=True)
def pytest_runtest_setup() -> Generator[None]:
    """Wrap the setup phase so an unhandled thread exception becomes a warning.

    NOTE(review): this wrapper is registered with ``trylast=True``, whereas the
    call/teardown wrappers in this module use ``tryfirst=True``. That is
    presumably deliberate ordering; confirm before changing it.
    """
    phase_guard = thread_exception_runtest_hook()
    yield from phase_guard
|
|
|
|
@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_call() -> Generator[None]:
    """Wrap the call phase so an unhandled thread exception becomes a warning.

    All of the work is delegated to ``thread_exception_runtest_hook``, which
    yields once around the wrapped hook implementations.
    """
    phase_guard = thread_exception_runtest_hook()
    yield from phase_guard
|
|
|
|
@pytest.hookimpl(wrapper=True, tryfirst=True)
def pytest_runtest_teardown() -> Generator[None]:
    """Wrap the teardown phase so an unhandled thread exception becomes a warning.

    All of the work is delegated to ``thread_exception_runtest_hook``, which
    yields once around the wrapped hook implementations.
    """
    phase_guard = thread_exception_runtest_hook()
    yield from phase_guard
|
|