cesjavi commited on
Commit
bea04ab
·
1 Parent(s): 93d283b

Backend: Added dynamic infrastructure management (DigitalOcean)

Browse files
backend/main.py CHANGED
@@ -1,11 +1,17 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import FileResponse, Response
 
 
4
  import os
5
  import json
6
  from pathlib import Path
7
  from dotenv import load_dotenv
8
  import sentry_sdk
 
 
 
 
9
 
10
 
11
  def _load_app_version() -> str:
@@ -21,6 +27,9 @@ def _load_app_version() -> str:
21
  load_dotenv()
22
  FRONTEND_DIST = Path(__file__).resolve().parent.parent / "frontend" / "dist"
23
  APP_VERSION = _load_app_version()
 
 
 
24
 
25
  # Sentry Initialization
26
  SENTRY_DSN = os.getenv("SENTRY_DSN")
@@ -48,6 +57,54 @@ app.add_middleware(
48
  allow_headers=["*"],
49
  )
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @app.get("/")
52
  async def root():
53
  index_path = FRONTEND_DIST / "index.html"
@@ -81,6 +138,20 @@ async def runtime_config():
81
  media_type="application/javascript",
82
  )
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @app.get("/{path:path}", include_in_schema=False)
85
  async def serve_frontend(path: str):
86
  if not FRONTEND_DIST.exists():
@@ -96,7 +167,32 @@ async def serve_frontend(path: str):
96
 
97
  return await root()
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  if __name__ == "__main__":
100
  import uvicorn
101
- from services.config import settings
102
- uvicorn.run("main:app", host="0.0.0.0", port=settings.PORT, reload=True)
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import FileResponse, Response
4
+ import asyncio
5
+ import logging
6
  import os
7
  import json
8
  from pathlib import Path
9
  from dotenv import load_dotenv
10
  import sentry_sdk
11
+ from services.orchestrator_service import orchestrator_service
12
+ from services.infrastructure_service import infrastructure_service
13
+ from services.config import settings
14
+ from worker import AubmWorker
15
 
16
 
17
  def _load_app_version() -> str:
 
27
  load_dotenv()
28
  FRONTEND_DIST = Path(__file__).resolve().parent.parent / "frontend" / "dist"
29
  APP_VERSION = _load_app_version()
30
+ logger = logging.getLogger("aubm.api")
31
+ embedded_worker: AubmWorker | None = None
32
+ embedded_worker_task: asyncio.Task | None = None
33
 
34
  # Sentry Initialization
35
  SENTRY_DSN = os.getenv("SENTRY_DSN")
 
57
  allow_headers=["*"],
58
  )
59
 
60
+
61
+ def _log_embedded_worker_result(task: asyncio.Task) -> None:
62
+ if task.cancelled():
63
+ return
64
+
65
+ exc = task.exception()
66
+ if exc:
67
+ logger.error(
68
+ "Embedded worker stopped unexpectedly",
69
+ exc_info=(type(exc), exc, exc.__traceback__),
70
+ )
71
+
72
+
73
+ @app.on_event("startup")
74
+ async def start_embedded_worker() -> None:
75
+ global embedded_worker, embedded_worker_task
76
+
77
+ if settings.TASK_EXECUTION_MODE != "queue" or not settings.TASK_QUEUE_EMBEDDED_WORKER:
78
+ return
79
+
80
+ if embedded_worker_task and not embedded_worker_task.done():
81
+ return
82
+
83
+ embedded_worker = AubmWorker()
84
+ embedded_worker_task = asyncio.create_task(embedded_worker.start())
85
+ embedded_worker_task.add_done_callback(_log_embedded_worker_result)
86
+ logger.info("Embedded task worker started: %s", embedded_worker.worker_id)
87
+
88
+
89
+ @app.on_event("shutdown")
90
+ async def stop_embedded_worker() -> None:
91
+ global embedded_worker, embedded_worker_task
92
+
93
+ if not embedded_worker or not embedded_worker_task:
94
+ return
95
+
96
+ embedded_worker.stop()
97
+ try:
98
+ await asyncio.wait_for(embedded_worker_task, timeout=10)
99
+ await embedded_worker.heartbeat("stopping")
100
+ except asyncio.TimeoutError:
101
+ embedded_worker_task.cancel()
102
+ logger.warning("Embedded task worker did not stop before timeout")
103
+ finally:
104
+ embedded_worker = None
105
+ embedded_worker_task = None
106
+
107
+
108
  @app.get("/")
109
  async def root():
110
  index_path = FRONTEND_DIST / "index.html"
 
138
  media_type="application/javascript",
139
  )
140
 
141
+ @app.get("/{path:path}", include_in_schema=False)
142
+ async def serve_frontend(path: str):
143
+ if not FRONTEND_DIST.exists():
144
+ return await root()
145
+
146
+ requested_path = FRONTEND_DIST / path
147
+ if requested_path.is_file():
148
+ return FileResponse(requested_path)
149
+
150
+ return Response(
151
+ content=f"window.__AUBM_CONFIG__ = {json.dumps(config)};",
152
+ media_type="application/javascript",
153
+ )
154
+
155
  @app.get("/{path:path}", include_in_schema=False)
156
  async def serve_frontend(path: str):
157
  if not FRONTEND_DIST.exists():
 
167
 
168
  return await root()
169
 
170
+ # --- Infrastructure Management ---
171
+
172
+ @app.post("/infrastructure/nodes/provision")
173
+ async def provision_node(name: str = "aubm-inference-node", size: str = "s-4vcpu-8gb-amd"):
174
+ """Creates a new inference node on DigitalOcean."""
175
+ node = await infrastructure_service.create_inference_node(name, size)
176
+ if not node:
177
+ raise HTTPException(status_code=500, detail="Failed to initiate node provisioning.")
178
+ return node
179
+
180
+ @app.get("/infrastructure/nodes/{droplet_id}/ip")
181
+ async def get_node_ip(droplet_id: int):
182
+ """Wait and return the public IP of a node."""
183
+ ip = await infrastructure_service.wait_for_ip(droplet_id)
184
+ if not ip:
185
+ raise HTTPException(status_code=404, detail="IP not assigned or timed out.")
186
+ return {"ip": ip}
187
+
188
+ @app.delete("/infrastructure/nodes/{droplet_id}")
189
+ async def terminate_node(droplet_id: int):
190
+ """Destroy an inference node."""
191
+ success = await infrastructure_service.terminate_node(droplet_id)
192
+ if not success:
193
+ raise HTTPException(status_code=500, detail="Failed to terminate node.")
194
+ return {"status": "termination_requested"}
195
+
196
  if __name__ == "__main__":
197
  import uvicorn
198
+ uvicorn.run(app, host="0.0.0.0", port=int(settings.PORT))
 
backend/services/config.py CHANGED
@@ -16,6 +16,10 @@ class Settings(BaseSettings):
16
  AMD_API_KEY: Optional[str] = None
17
  TAVILY_API_KEY: Optional[str] = None
18
 
 
 
 
 
19
  # App Config
20
  TASK_QUEUE_EMBEDDED_WORKER: bool = True
21
  TASK_QUEUE_HEARTBEAT_ENABLED: bool = True
 
16
  AMD_API_KEY: Optional[str] = None
17
  TAVILY_API_KEY: Optional[str] = None
18
 
19
+ # Infrastructure (DigitalOcean)
20
+ DO_API_TOKEN: Optional[str] = None
21
+ DO_REGION: str = "nyc3"
22
+
23
  # App Config
24
  TASK_QUEUE_EMBEDDED_WORKER: bool = True
25
  TASK_QUEUE_HEARTBEAT_ENABLED: bool = True
backend/services/infrastructure_service.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import httpx
2
+ import logging
3
+ import asyncio
4
+ from typing import Optional, Dict, Any
5
+ from .config import settings
6
+
7
+ logger = logging.getLogger("infrastructure")
8
+
9
+ class InfrastructureService:
10
+ """
11
+ Manages on-the-fly compute resources on DigitalOcean for AI inference.
12
+ """
13
+ API_URL = "https://api.digitalocean.com/v2"
14
+
15
+ def __init__(self):
16
+ self.headers = {
17
+ "Authorization": f"Bearer {settings.DO_API_TOKEN}",
18
+ "Content-Type": "application/json"
19
+ }
20
+
21
+ async def create_inference_node(self, name: str, size: str = "s-4vcpu-8gb-amd") -> Optional[Dict[str, Any]]:
22
+ """
23
+ Provision a new AMD-based droplet with Ollama pre-installed.
24
+ Default size is a capable AMD-based node.
25
+ """
26
+ if not settings.DO_API_TOKEN:
27
+ logger.error("DO_API_TOKEN not configured.")
28
+ return None
29
+
30
+ # Cloud-init script to setup the inference environment
31
+ user_data = """#cloud-config
32
+ runcmd:
33
+ - curl -fsSL https://get.docker.com | sh
34
+ - docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama -e OLLAMA_HOST=0.0.0.0 ollama/ollama
35
+ - sleep 10
36
+ - docker exec ollama ollama pull llama3
37
+ """
38
+
39
+ payload = {
40
+ "name": name,
41
+ "region": settings.DO_REGION,
42
+ "size": size,
43
+ "image": "ubuntu-22-04-x64",
44
+ "user_data": user_data,
45
+ "tags": ["aubm-worker", "inference-node"]
46
+ }
47
+
48
+ async with httpx.AsyncClient() as client:
49
+ try:
50
+ response = await client.post(f"{self.API_URL}/droplets", headers=self.headers, json=payload)
51
+ response.raise_for_status()
52
+ data = response.json()
53
+ droplet_id = data["droplet"]["id"]
54
+ logger.info(f"Inference node creation initiated: {name} (ID: {droplet_id})")
55
+ return data["droplet"]
56
+ except Exception as e:
57
+ logger.error(f"Failed to create droplet: {e}")
58
+ return None
59
+
60
+ async def wait_for_ip(self, droplet_id: int, timeout: int = 300) -> Optional[str]:
61
+ """
62
+ Polls the API until the droplet has a public IP assigned.
63
+ """
64
+ start_time = asyncio.get_event_loop().time()
65
+ async with httpx.AsyncClient() as client:
66
+ while (asyncio.get_event_loop().time() - start_time) < timeout:
67
+ try:
68
+ response = await client.get(f"{self.API_URL}/droplets/{droplet_id}", headers=self.headers)
69
+ response.raise_for_status()
70
+ droplet = response.json()["droplet"]
71
+
72
+ networks = droplet.get("networks", {}).get("v4", [])
73
+ for nw in networks:
74
+ if nw["type"] == "public":
75
+ return nw["ip_address"]
76
+
77
+ except Exception as e:
78
+ logger.warning(f"Error polling droplet {droplet_id}: {e}")
79
+
80
+ await asyncio.sleep(10)
81
+ return None
82
+
83
+ async def terminate_node(self, droplet_id: int):
84
+ """
85
+ Destroy the inference node to stop billing.
86
+ """
87
+ async with httpx.AsyncClient() as client:
88
+ try:
89
+ response = await client.delete(f"{self.API_URL}/droplets/{droplet_id}", headers=self.headers)
90
+ response.raise_for_status()
91
+ logger.info(f"Inference node {droplet_id} termination requested.")
92
+ return True
93
+ except Exception as e:
94
+ logger.error(f"Failed to terminate droplet {droplet_id}: {e}")
95
+ return False
96
+
97
+ infrastructure_service = InfrastructureService()