adeshboudh16 commited on
Commit
08fce0d
·
1 Parent(s): da0480b

fix(api): avoid asyncio shadowing in lifespan

Browse files
src/civicsetu/api/main.py CHANGED
@@ -44,7 +44,6 @@ def create_checkpointer():
44
  async def lifespan(app: FastAPI):
45
  """Startup and shutdown events."""
46
  if sys.platform == "win32":
47
- import asyncio
48
  asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
49
  log.info("civicsetu_starting", env=settings.api_env)
50
 
 
44
  async def lifespan(app: FastAPI):
45
  """Startup and shutdown events."""
46
  if sys.platform == "win32":
 
47
  asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
48
  log.info("civicsetu_starting", env=settings.api_env)
49
 
tests/unit/api/test_query_route.py CHANGED
@@ -78,6 +78,33 @@ def test_app_startup_warms_reranker_from_retrieval_module():
78
  mock_get_ranker.assert_called_once()
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def test_query_returns_200_with_citations(client):
82
  test_client, mock_graph = client
83
  mock_graph.invoke.return_value = {
 
78
  mock_get_ranker.assert_called_once()
79
 
80
 
81
+ def test_app_startup_on_non_windows_does_not_shadow_asyncio():
82
+ fake_checkpointer = AsyncMock()
83
+
84
+ @asynccontextmanager
85
+ async def fake_checkpointer_context():
86
+ yield fake_checkpointer
87
+
88
+ with patch("civicsetu.api.main.sys.platform", "linux"), patch(
89
+ "civicsetu.api.main.create_checkpointer", return_value=fake_checkpointer_context()
90
+ ), patch("civicsetu.agent.graph.get_compiled_graph", return_value=MagicMock()), patch(
91
+ "civicsetu.api.main.get_driver", new=AsyncMock()
92
+ ), patch("civicsetu.api.main.close_driver", new=AsyncMock()), patch(
93
+ "civicsetu.retrieval.warm_embedding_model"
94
+ ) as mock_warm_embedding_model, patch(
95
+ "civicsetu.retrieval.reranker._get_ranker"
96
+ ) as mock_get_ranker:
97
+ from civicsetu.api.main import create_app
98
+
99
+ app = create_app()
100
+
101
+ with TestClient(app):
102
+ pass
103
+
104
+ mock_warm_embedding_model.assert_called_once()
105
+ mock_get_ranker.assert_called_once()
106
+
107
+
108
  def test_query_returns_200_with_citations(client):
109
  test_client, mock_graph = client
110
  mock_graph.invoke.return_value = {