File size: 16,672 Bytes
c38df78
 
81ff144
 
c38df78
 
aeb2234
c38df78
 
 
 
81ff144
 
 
 
 
ad68e43
 
 
 
aeb2234
 
 
 
 
 
ad68e43
 
aeb2234
ad68e43
 
 
 
 
c38df78
 
 
 
 
 
 
 
 
 
 
 
81ff144
c38df78
 
 
 
 
81ff144
 
 
 
 
 
 
 
 
0cb1aa7
 
 
81ff144
 
 
 
 
 
 
 
 
 
 
 
 
0cb1aa7
81ff144
c38df78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81ff144
c38df78
81ff144
 
 
 
 
 
 
 
 
c38df78
81ff144
 
 
 
 
 
 
 
 
 
 
 
c38df78
 
 
 
 
 
 
 
 
 
 
 
 
 
81ff144
 
 
c38df78
 
 
 
 
 
 
81ff144
 
 
 
 
 
ee2e59e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81ff144
c38df78
ad68e43
 
 
c38df78
ad68e43
81ff144
c38df78
 
 
 
 
 
 
 
 
 
 
81ff144
 
 
7ebd2cf
81ff144
7ebd2cf
 
 
 
 
 
 
 
c38df78
 
 
 
 
 
 
81ff144
0cb1aa7
c38df78
0cb1aa7
 
 
c38df78
ad68e43
0cb1aa7
ad68e43
0cb1aa7
 
 
ad68e43
 
0cb1aa7
ad68e43
 
 
a362a22
ad68e43
 
 
 
 
 
 
 
 
 
 
 
c38df78
 
 
 
0cb1aa7
 
 
 
 
 
 
 
 
 
 
c38df78
 
 
 
 
 
 
 
 
 
0cb1aa7
ad68e43
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from services.supabase_service import supabase
from services.agent_runner_service import AgentRunnerService
from services.config import settings
from services.audit_service import audit_service
from services.output_quality import report_text_from_output
from services.task_queue import TaskQueueService
from services.memory_service import memory_service
import asyncio
import json
import logging

router = APIRouter()
logger = logging.getLogger("uvicorn")


def _assert_task_quality(task: dict):
    """Raise HTTP 400 unless the task's output exists, rendered non-trivially,
    and was approved by quality review."""
    output = task.get("output_data") or {}
    if not isinstance(output, dict):
        raise HTTPException(status_code=400, detail="Task output is missing or malformed.")
    if output.get("error"):
        raise HTTPException(status_code=400, detail=f"Task execution failed: {output['error']}")
    # Reject outputs that render to nothing (or a bare empty container).
    text = report_text_from_output(output).strip()
    if not text or text in ("{}", "[]"):
        raise HTTPException(status_code=400, detail="Task has no usable output to approve.")
    review = output.get("quality_review")
    if not review:
        raise HTTPException(status_code=400, detail="Task output is missing quality validation.")
    if not review.get("approved"):
        reasons = review.get("fail_reasons") or ["Task output failed quality validation."]
        raise HTTPException(status_code=400, detail=f"Task output failed quality review: {'; '.join(reasons)}")

def _assert_project_is_mutable(project_id: str):
    """Raise 404 when the project is missing, 409 when it is completed (locked)."""
    row = (
        supabase.table("projects")
        .select("id,status")
        .eq("id", project_id)
        .single()
        .execute()
        .data
    )
    if not row:
        raise HTTPException(status_code=404, detail="Project not found")
    if row.get("status") == "completed":
        raise HTTPException(status_code=409, detail="Completed projects are locked and cannot be modified.")

def _assert_task_project_is_mutable(task: dict):
    project_id = task.get("project_id")
    if project_id:
        _assert_project_is_mutable(project_id)

def update_task_status(task_id: str, status: str):
    """Set a task's status, then roll the parent project's status up:
    project becomes 'completed' when every task is done, 'active' when a
    task leaves the done state. Returns the updated task row."""
    existing = supabase.table("tasks").select("project_id").eq("id", task_id).single().execute()
    if not existing.data:
        raise HTTPException(status_code=404, detail="Task not found")
    _assert_task_project_is_mutable(existing.data)

    updated = supabase.table("tasks").update({"status": status}).eq("id", task_id).execute()
    if not updated.data:
        raise HTTPException(status_code=404, detail="Task not found or status was not updated")

    task_row = updated.data[0]
    project_id = task_row.get("project_id")
    if project_id:
        siblings = (
            supabase.table("tasks")
            .select("id,status")
            .eq("project_id", project_id)
            .execute()
            .data
            or []
        )
        all_done = bool(siblings) and all(t.get("status") == "done" for t in siblings)
        if status == "done" and all_done:
            supabase.table("projects").update({"status": "completed"}).eq("id", project_id).execute()
        elif status != "done":
            supabase.table("projects").update({"status": "active"}).eq("id", project_id).execute()

    return task_row


def _sse_event(event: str, data: dict, event_id: str | None = None) -> str:
    lines = []
    if event_id:
        lines.append(f"id: {event_id}")
    lines.append(f"event: {event}")
    payload = json.dumps(data, default=str)
    for line in payload.splitlines() or ["{}"]:
        lines.append(f"data: {line}")
    return "\n".join(lines) + "\n\n"


def _project_task_ids(project_id: str) -> list[str]:
    """Return the ids of every task belonging to the given project."""
    response = (
        supabase.table("tasks")
        .select("id")
        .eq("project_id", project_id)
        .execute()
    )
    return [r["id"] for r in (response.data or []) if r.get("id")]


def _user_id_from_access_token(access_token: str | None) -> str:
    """Resolve a Supabase access token to a user id, raising HTTP 401 on
    any missing/invalid token or auth-service failure."""
    if not access_token:
        raise HTTPException(status_code=401, detail="Missing access token")
    try:
        response = supabase.auth.get_user(access_token)
        # The client may return an object with .user or a plain dict.
        user_id = getattr(getattr(response, "user", None), "id", None)
        if not user_id and isinstance(response, dict):
            user_id = response.get("user", {}).get("id")
    except Exception as exc:
        logger.warning("Could not validate log stream access token: %s", exc)
        raise HTTPException(status_code=401, detail="Invalid access token") from exc
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid access token")
    return user_id


def _team_ids_for_user(user_id: str) -> list[str]:
    """List team ids the user belongs to; best-effort, returning [] when the
    membership lookup is unavailable."""
    try:
        rows = (
            supabase.table("team_members")
            .select("team_id")
            .eq("user_id", user_id)
            .execute()
            .data
        ) or []
    except Exception as exc:
        logger.warning("Team membership lookup unavailable for log stream: %s", exc)
        return []
    return [row["team_id"] for row in rows if row.get("team_id")]


def _project_ids_for_user(user_id: str) -> list[str]:
    """Collect ids of all projects visible to the user: owned projects,
    public projects, and projects shared with any of the user's teams."""
    visible: set[str] = set()

    def _collect(query) -> None:
        # Fold the ids of one query's result rows into the visible set.
        rows = query.execute().data or []
        visible.update(row["id"] for row in rows if row.get("id"))

    _collect(supabase.table("projects").select("id").eq("owner_id", user_id))
    _collect(supabase.table("projects").select("id").eq("is_public", True))

    team_ids = _team_ids_for_user(user_id)
    if team_ids:
        _collect(supabase.table("projects").select("id").in_("team_id", team_ids))

    return list(visible)


def _can_view_project_for_user(project_id: str, user_id: str) -> bool:
    if not project_id:
        return False
    if project_id in _project_ids_for_user(user_id):
        return True
    return False


def _authorized_task_ids(user_id: str, project_id: str | None = None, task_id: str | None = None) -> list[str]:
    """Resolve the set of task ids whose logs the user may stream.

    Scope narrows in priority order: a single task, a whole project, or
    every project visible to the user. Raises HTTP 403 when the requested
    scope is not visible to the user.
    """
    if task_id:
        task = supabase.table("tasks").select("id,project_id").eq("id", task_id).single().execute().data
        if task and _can_view_project_for_user(task.get("project_id"), user_id):
            return [task_id]
        raise HTTPException(status_code=403, detail="Task logs are not visible to this user")

    if project_id:
        if not _can_view_project_for_user(project_id, user_id):
            raise HTTPException(status_code=403, detail="Project logs are not visible to this user")
        return _project_task_ids(project_id)

    # No explicit scope: every task in every project the user can view.
    project_ids = _project_ids_for_user(user_id)
    if not project_ids:
        return []
    rows = (
        supabase.table("tasks")
        .select("id")
        .in_("project_id", project_ids)
        .execute()
        .data
        or []
    )
    return [row["id"] for row in rows if row.get("id")]


def _fetch_recent_logs(
    limit: int = 50,
    after_created_at: str | None = None,
    *,
    task_ids: list[str],
) -> list[dict]:
    if not task_ids:
        return []
    query = (
        supabase.table("agent_logs")
        .select("id,task_id,run_id,action,content,metadata,created_at")
        .order("created_at", desc=after_created_at is None)
        .limit(limit)
        .in_("task_id", task_ids)
    )
    if after_created_at:
        query = query.gt("created_at", after_created_at)
    rows = query.execute().data or []
    return rows if after_created_at else list(reversed(rows))


@router.get("/logs/stream")
async def stream_agent_logs(
    request: Request,
    limit: int = 50,
    project_id: str | None = None,
    task_id: str | None = None,
    access_token: str | None = None,
):
    """
    Streams agent log inserts as Server-Sent Events.

    Authorization is resolved once up front from the access token; the
    stream then polls agent_logs every second for rows newer than the
    last created_at sent. Pass either project_id or task_id (not both)
    to narrow the scope.
    """
    if project_id and task_id:
        raise HTTPException(status_code=400, detail="Use either project_id or task_id, not both.")
    user_id = _user_id_from_access_token(access_token)
    task_ids = _authorized_task_ids(user_id, project_id=project_id, task_id=task_id)

    async def event_generator():
        last_created_at = None
        # Insertion-ordered dedup cache (dicts preserve insertion order).
        # The previous plain-set version trimmed via set iteration order,
        # which is arbitrary and could evict the *newest* ids, causing
        # duplicate log events to be re-sent.
        sent_ids: dict[str, None] = {}
        yield _sse_event("ready", {
            "message": "Agent log stream connected",
            "project_id": project_id,
            "task_id": task_id,
            "user_id": user_id,
        })

        while not await request.is_disconnected():
            try:
                rows = _fetch_recent_logs(
                    limit=max(1, min(limit, 100)),
                    after_created_at=last_created_at,
                    task_ids=task_ids,
                )
                for row in rows:
                    row_id = row.get("id")
                    if row_id in sent_ids:
                        continue
                    sent_ids[row_id] = None
                    if len(sent_ids) > 500:
                        # Drop the oldest entries, keeping the most recent 250.
                        sent_ids = dict.fromkeys(list(sent_ids)[-250:])
                    last_created_at = row.get("created_at") or last_created_at
                    yield _sse_event("log", row, row_id)
            except Exception as exc:
                logger.warning("Agent log SSE stream failed to fetch logs: %s", exc)
                yield _sse_event("error", {"message": str(exc)})

            # SSE comment line keeps intermediaries from closing an idle stream.
            yield ": keep-alive\n\n"
            await asyncio.sleep(1)

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


@router.post("/{task_id}/run")
async def run_task(task_id: str, background_tasks: BackgroundTasks, use_queue: bool | None = None):
    """
    Triggers the execution of a specific task.

    Depending on use_queue (or settings.TASK_EXECUTION_MODE), the task is
    either enqueued for a worker or executed directly in a background
    task. Raises 404 when the task or its assigned agent is missing, 400
    when no agent is assigned, 409 when the project is completed/locked.
    """
    # 1. Fetch task data (project joined in for owner/audit info)
    task_res = supabase.table("tasks").select("*, project:projects(*)").eq("id", task_id).single().execute()
    if not task_res.data:
        raise HTTPException(status_code=404, detail="Task not found")

    task = task_res.data
    _assert_task_project_is_mutable(task)

    # 2. Check if agent is assigned
    agent_id = task.get("assigned_agent_id")
    if not agent_id:
        raise HTTPException(status_code=400, detail="No agent assigned to this task")

    # 3. Fetch agent data
    agent_res = supabase.table("agents").select("*").eq("id", agent_id).single().execute()
    if not agent_res.data:
        raise HTTPException(status_code=404, detail="Assigned agent not found")

    agent_data = agent_res.data

    # The join can yield project=None (key present, value null), in which
    # case task.get("project", {}) returns None and .get() would raise
    # AttributeError — guard with `or {}`.
    owner_id = (task.get("project") or {}).get("owner_id")

    should_queue = use_queue if use_queue is not None else settings.TASK_EXECUTION_MODE == "queue"
    if should_queue:
        queued = await TaskQueueService.queue_task(task_id)
        if not queued or not queued.data:
            raise HTTPException(status_code=500, detail="Task could not be queued")
        await audit_service.log_action(
            user_id=owner_id,
            action="task_queued",
            agent_id=agent_id,
            task_id=task_id,
            metadata={"project_id": task.get("project_id"), "source": "task_run_endpoint"},
        )
        return {"message": "Task queued for worker execution", "task_id": task_id, "mode": "queue"}

    # 4. Update task status to in_progress
    supabase.table("tasks").update({"status": "in_progress"}).eq("id", task_id).execute()
    await audit_service.log_action(
        user_id=owner_id,
        action="task_run_started",
        agent_id=agent_id,
        task_id=task_id,
        metadata={"project_id": task.get("project_id"), "mode": "direct"},
    )

    # 5. Run in background
    background_tasks.add_task(AgentRunnerService.execute_agent_logic, task, agent_data)

    return {"message": "Task execution started", "task_id": task_id}

@router.patch("/{task_id}/output")
async def update_task_output(task_id: str, payload: dict):
    """
    Updates the output_data of a task. Allows for manual human corrections.
    """
    if "output_data" not in payload:
        raise HTTPException(status_code=400, detail="Missing output_data in payload")

    # Verify task existence and project state
    existing = supabase.table("tasks").select("id, project_id").eq("id", task_id).single().execute()
    if not existing.data:
        raise HTTPException(status_code=404, detail="Task not found")
    _assert_task_project_is_mutable(existing.data)

    updated = (
        supabase.table("tasks")
        .update({"output_data": payload["output_data"]})
        .eq("id", task_id)
        .execute()
    )
    if not updated.data:
        raise HTTPException(status_code=500, detail="Failed to update task output")

    await audit_service.log_action(
        user_id=None,
        action="task_output_manually_edited",
        task_id=task_id,
        metadata={"project_id": existing.data["project_id"]}
    )

    return {"message": "Task output updated", "task": updated.data[0]}

@router.post("/{task_id}/approve")
async def approve_task(task_id: str, background_tasks: BackgroundTasks):
    """Approve a task: validate its output quality, mark it done, schedule
    long-term-memory indexing, and write an audit entry."""
    record = supabase.table("tasks").select("*").eq("id", task_id).single().execute()
    if not record.data:
        raise HTTPException(status_code=404, detail="Task not found")
    _assert_task_project_is_mutable(record.data)
    _assert_task_quality(record.data)

    task = update_task_status(task_id, "done")

    # Index for Long-Term Memory
    background_tasks.add_task(memory_service.index_task_output, task)

    await audit_service.log_action(
        user_id=None,
        action="task_approved",
        agent_id=task.get("assigned_agent_id"),
        task_id=task_id,
        metadata={"project_id": task.get("project_id")},
    )
    return {"message": "Task approved", "task": task}

@router.post("/{task_id}/reject")
async def reject_task(task_id: str, background_tasks: BackgroundTasks, feedback: str | None = None):
    """Send a task back to 'todo', schedule rejection analysis with the
    optional feedback, and write an audit entry."""
    task = update_task_status(task_id, "todo")

    # Trigger Self-Optimization Loop
    background_tasks.add_task(
        memory_service.analyze_rejection,
        task_id=task_id,
        feedback=feedback,
    )

    await audit_service.log_action(
        user_id=None,
        action="task_rejected",
        agent_id=task.get("assigned_agent_id"),
        task_id=task_id,
        metadata={"project_id": task.get("project_id")},
    )
    return {"message": "Task rejected", "task": task}
@router.post("/project/{project_id}/approve-all")
async def approve_all_tasks(project_id: str, background_tasks: BackgroundTasks):
    """
    Approves all tasks in a project that are awaiting approval.

    Each task must pass the same quality gate as the single-task approve
    endpoint; tasks that fail are skipped and reported in `blocked` with
    a reason. (Previously `blocked` was declared but never populated and
    unvalidated outputs were bulk-approved, inconsistent with /approve.)
    """
    _assert_project_is_mutable(project_id)
    waiting_tasks = (
        supabase.table("tasks")
        .select("*")
        .eq("project_id", project_id)
        .eq("status", "awaiting_approval")
        .execute()
        .data
        or []
    )
    blocked = []
    approvable_ids = []
    for task in waiting_tasks:
        try:
            # Same quality gate as the single-task /approve endpoint.
            _assert_task_quality(task)
        except HTTPException as exc:
            blocked.append({"task_id": task["id"], "reason": exc.detail})
            continue
        approvable_ids.append(task["id"])

    # 1. Update tasks
    result_data = []
    if approvable_ids:
        result = (
            supabase.table("tasks")
            .update({"status": "done"})
            .eq("project_id", project_id)
            .in_("id", approvable_ids)
            .execute()
        )
        result_data = result.data or []

        # Index all approved tasks for Long-Term Memory
        for approved_task in result_data:
            background_tasks.add_task(memory_service.index_task_output, approved_task)

    # 2. Check if all tasks in project are now done
    task_result = (
        supabase.table("tasks")
        .select("status")
        .eq("project_id", project_id)
        .execute()
    )
    tasks = task_result.data or []
    if tasks and all(task.get("status") == "done" for task in tasks):
        supabase.table("projects").update({"status": "completed"}).eq("id", project_id).execute()

    await audit_service.log_action(
        user_id=None,
        action="tasks_approved_bulk",
        metadata={
            "project_id": project_id,
            "approved_count": len(result_data),
            "blocked_count": len(blocked),
        },
    )

    return {
        "message": f"Approved {len(result_data)} tasks",
        "count": len(result_data),
        "blocked": blocked
    }