Spaces:
Running on CPU Upgrade
Run jobs under user or org accounts with upgrade / org-pick UX (#132)
Browse files* agent/core/agent_loop.py: jobs access rewiring
* agent/core/telemetry.py: jobs access rewiring
* agent/tools/jobs_tool.py: jobs access rewiring
* backend/dependencies.py: jobs access rewiring
* backend/models.py: jobs access rewiring
* backend/routes/agent.py: jobs access rewiring
* frontend/src/components/Chat/ChatInput.tsx: jobs access rewiring
* frontend/src/components/ClaudeCapDialog.tsx: jobs access rewiring
* frontend/src/components/SessionChat.tsx: jobs access rewiring
* frontend/src/hooks/useAgentChat.ts: jobs access rewiring
* frontend/src/lib/sse-chat-transport.ts: jobs access rewiring
* frontend/src/store/agentStore.ts: jobs access rewiring
* frontend/src/types/agent.ts: jobs access rewiring
* scripts/build_kpis.py: jobs access rewiring
* tests/unit/test_build_kpis.py: jobs access rewiring
* agent/core/hf_access.py: jobs access rewiring
* frontend/src/components/JobsUpgradeDialog.tsx: jobs access rewiring
* tests/unit/test_hf_access.py: jobs access rewiring
* ci: re-trigger Claude review after OIDC fix
* hf_access: avoid blocking fallback whoami call
* agent routes: remove dead can_run_jobs branch
- agent/core/agent_loop.py +3 -0
- agent/core/hf_access.py +181 -0
- agent/core/telemetry.py +38 -0
- agent/tools/jobs_tool.py +20 -2
- backend/dependencies.py +6 -50
- backend/models.py +1 -0
- backend/routes/agent.py +154 -2
- frontend/src/components/Chat/ChatInput.tsx +55 -1
- frontend/src/components/ClaudeCapDialog.tsx +3 -0
- frontend/src/components/JobsUpgradeDialog.tsx +191 -0
- frontend/src/components/SessionChat.tsx +3 -1
- frontend/src/hooks/useAgentChat.ts +130 -1
- frontend/src/lib/sse-chat-transport.ts +26 -0
- frontend/src/store/agentStore.ts +37 -0
- frontend/src/types/agent.ts +1 -0
- scripts/build_kpis.py +24 -2
- tests/unit/test_build_kpis.py +32 -0
- tests/unit/test_hf_access.py +39 -0
|
@@ -1137,6 +1137,9 @@ class Handlers:
|
|
| 1137 |
tool_args["script"] = edited_script
|
| 1138 |
was_edited = True
|
| 1139 |
logger.info(f"Using user-edited script for {tool_name} ({tc.id})")
|
|
|
|
|
|
|
|
|
|
| 1140 |
approved_tasks.append((tc, tool_name, tool_args, was_edited))
|
| 1141 |
else:
|
| 1142 |
rejected_tasks.append((tc, tool_name, approval_decision))
|
|
|
|
| 1137 |
tool_args["script"] = edited_script
|
| 1138 |
was_edited = True
|
| 1139 |
logger.info(f"Using user-edited script for {tool_name} ({tc.id})")
|
| 1140 |
+
selected_namespace = approval_decision.get("namespace")
|
| 1141 |
+
if selected_namespace and tool_name == "hf_jobs":
|
| 1142 |
+
tool_args["namespace"] = selected_namespace
|
| 1143 |
approved_tasks.append((tc, tool_name, tool_args, was_edited))
|
| 1144 |
else:
|
| 1145 |
rejected_tasks.append((tc, tool_name, approval_decision))
|
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers for Hugging Face account / org access decisions."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
import os
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import httpx
|
| 11 |
+
|
| 12 |
+
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(frozen=True)
|
| 16 |
+
class JobsAccess:
|
| 17 |
+
"""Jobs entitlement derived from whoami-v2."""
|
| 18 |
+
|
| 19 |
+
username: str | None
|
| 20 |
+
plan: str
|
| 21 |
+
personal_can_run_jobs: bool
|
| 22 |
+
paid_org_names: list[str]
|
| 23 |
+
eligible_namespaces: list[str]
|
| 24 |
+
default_namespace: str | None
|
| 25 |
+
access_known: bool = True
|
| 26 |
+
|
| 27 |
+
@property
|
| 28 |
+
def can_run_jobs(self) -> bool:
|
| 29 |
+
return bool(self.default_namespace)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class JobsAccessError(Exception):
|
| 33 |
+
"""Structured jobs access error for upgrade / namespace gating."""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
message: str,
|
| 38 |
+
*,
|
| 39 |
+
access: JobsAccess | None = None,
|
| 40 |
+
upgrade_required: bool = False,
|
| 41 |
+
namespace_required: bool = False,
|
| 42 |
+
) -> None:
|
| 43 |
+
super().__init__(message)
|
| 44 |
+
self.access = access
|
| 45 |
+
self.upgrade_required = upgrade_required
|
| 46 |
+
self.namespace_required = namespace_required
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _extract_username(whoami: dict[str, Any]) -> str | None:
|
| 50 |
+
for key in ("name", "user", "preferred_username"):
|
| 51 |
+
value = whoami.get(key)
|
| 52 |
+
if isinstance(value, str) and value:
|
| 53 |
+
return value
|
| 54 |
+
return None
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _normalize_personal_plan(whoami: dict[str, Any]) -> str:
|
| 58 |
+
plan_str = ""
|
| 59 |
+
for key in ("plan", "type", "accountType"):
|
| 60 |
+
value = whoami.get(key)
|
| 61 |
+
if isinstance(value, str) and value:
|
| 62 |
+
plan_str = value.lower()
|
| 63 |
+
break
|
| 64 |
+
|
| 65 |
+
if not plan_str and (whoami.get("isPro") is True or whoami.get("is_pro") is True):
|
| 66 |
+
return "pro"
|
| 67 |
+
|
| 68 |
+
if any(tag in plan_str for tag in ("pro", "enterprise", "team")):
|
| 69 |
+
return "pro"
|
| 70 |
+
return "free"
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _paid_org_names(whoami: dict[str, Any]) -> list[str]:
|
| 74 |
+
names: list[str] = []
|
| 75 |
+
orgs = whoami.get("orgs") or []
|
| 76 |
+
if not isinstance(orgs, list):
|
| 77 |
+
return names
|
| 78 |
+
|
| 79 |
+
for org in orgs:
|
| 80 |
+
if not isinstance(org, dict):
|
| 81 |
+
continue
|
| 82 |
+
name = org.get("name")
|
| 83 |
+
if not isinstance(name, str) or not name:
|
| 84 |
+
continue
|
| 85 |
+
org_plan = str(org.get("plan") or org.get("type") or "").lower()
|
| 86 |
+
if any(tag in org_plan for tag in ("pro", "enterprise", "team")):
|
| 87 |
+
names.append(name)
|
| 88 |
+
return sorted(set(names))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def jobs_access_from_whoami(whoami: dict[str, Any]) -> JobsAccess:
|
| 92 |
+
username = _extract_username(whoami)
|
| 93 |
+
personal_plan = _normalize_personal_plan(whoami)
|
| 94 |
+
paid_orgs = _paid_org_names(whoami)
|
| 95 |
+
personal_can_run = personal_plan == "pro"
|
| 96 |
+
|
| 97 |
+
eligible_namespaces: list[str] = []
|
| 98 |
+
if personal_can_run and username:
|
| 99 |
+
eligible_namespaces.append(username)
|
| 100 |
+
eligible_namespaces.extend(paid_orgs)
|
| 101 |
+
|
| 102 |
+
plan = "pro" if personal_can_run else ("org" if paid_orgs else "free")
|
| 103 |
+
default_namespace = username if personal_can_run and username else None
|
| 104 |
+
|
| 105 |
+
return JobsAccess(
|
| 106 |
+
username=username,
|
| 107 |
+
plan=plan,
|
| 108 |
+
personal_can_run_jobs=personal_can_run,
|
| 109 |
+
paid_org_names=paid_orgs,
|
| 110 |
+
eligible_namespaces=eligible_namespaces,
|
| 111 |
+
default_namespace=default_namespace,
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
async def fetch_whoami_v2(token: str, timeout: float = 5.0) -> dict[str, Any] | None:
|
| 116 |
+
if not token:
|
| 117 |
+
return None
|
| 118 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 119 |
+
try:
|
| 120 |
+
response = await client.get(
|
| 121 |
+
f"{OPENID_PROVIDER_URL}/api/whoami-v2",
|
| 122 |
+
headers={"Authorization": f"Bearer {token}"},
|
| 123 |
+
)
|
| 124 |
+
if response.status_code != 200:
|
| 125 |
+
return None
|
| 126 |
+
payload = response.json()
|
| 127 |
+
return payload if isinstance(payload, dict) else None
|
| 128 |
+
except (httpx.HTTPError, ValueError):
|
| 129 |
+
return None
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
async def get_jobs_access(token: str) -> JobsAccess | None:
|
| 133 |
+
whoami = await fetch_whoami_v2(token)
|
| 134 |
+
if whoami is None:
|
| 135 |
+
return None
|
| 136 |
+
return jobs_access_from_whoami(whoami)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
async def resolve_jobs_namespace(
|
| 140 |
+
token: str,
|
| 141 |
+
requested_namespace: str | None = None,
|
| 142 |
+
) -> tuple[str, JobsAccess | None]:
|
| 143 |
+
"""Return the namespace to use for jobs.
|
| 144 |
+
|
| 145 |
+
If whoami-v2 is unavailable, fall back to the token owner's username.
|
| 146 |
+
"""
|
| 147 |
+
access = await get_jobs_access(token)
|
| 148 |
+
if access:
|
| 149 |
+
if requested_namespace:
|
| 150 |
+
if requested_namespace in access.eligible_namespaces:
|
| 151 |
+
return requested_namespace, access
|
| 152 |
+
raise JobsAccessError(
|
| 153 |
+
f"You can only run jobs under your own Pro account or a paid org you belong to. "
|
| 154 |
+
f"Allowed namespaces: {', '.join(access.eligible_namespaces) or '(none)'}",
|
| 155 |
+
access=access,
|
| 156 |
+
)
|
| 157 |
+
if access.default_namespace:
|
| 158 |
+
return access.default_namespace, access
|
| 159 |
+
if access.paid_org_names:
|
| 160 |
+
raise JobsAccessError(
|
| 161 |
+
"Choose which paid organization should own this job run.",
|
| 162 |
+
access=access,
|
| 163 |
+
namespace_required=True,
|
| 164 |
+
)
|
| 165 |
+
raise JobsAccessError(
|
| 166 |
+
"Hugging Face Jobs are available only to Pro users and Team or Enterprise organizations. "
|
| 167 |
+
"Upgrade to Pro, or run the job under a paid org you belong to.",
|
| 168 |
+
access=access,
|
| 169 |
+
upgrade_required=True,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Fallback: whoami-v2 unavailable. Do not block the call pre-emptively.
|
| 173 |
+
from huggingface_hub import HfApi
|
| 174 |
+
|
| 175 |
+
username = None
|
| 176 |
+
if token:
|
| 177 |
+
whoami = await asyncio.to_thread(HfApi(token=token).whoami)
|
| 178 |
+
username = whoami.get("name")
|
| 179 |
+
if not username:
|
| 180 |
+
raise JobsAccessError("No HF token available to resolve a jobs namespace.")
|
| 181 |
+
return requested_namespace or username, None
|
|
@@ -141,6 +141,7 @@ async def record_hf_job_submit(
|
|
| 141 |
"timeout": args.get("timeout", "30m"),
|
| 142 |
"job_type": job_type,
|
| 143 |
"image": image,
|
|
|
|
| 144 |
"push_to_hub": _infer_push_to_hub(script_text),
|
| 145 |
},
|
| 146 |
))
|
|
@@ -239,6 +240,43 @@ async def record_feedback(
|
|
| 239 |
logger.debug("record_feedback failed (non-fatal): %s", e)
|
| 240 |
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
# ── heartbeat ──────────────────────────────────────────────────────────────
|
| 243 |
|
| 244 |
# Module-level reference set for fire-and-forget heartbeat tasks. asyncio only
|
|
|
|
| 141 |
"timeout": args.get("timeout", "30m"),
|
| 142 |
"job_type": job_type,
|
| 143 |
"image": image,
|
| 144 |
+
"namespace": args.get("namespace"),
|
| 145 |
"push_to_hub": _infer_push_to_hub(script_text),
|
| 146 |
},
|
| 147 |
))
|
|
|
|
| 240 |
logger.debug("record_feedback failed (non-fatal): %s", e)
|
| 241 |
|
| 242 |
|
| 243 |
+
async def record_jobs_access_blocked(
|
| 244 |
+
session: Any,
|
| 245 |
+
*,
|
| 246 |
+
tool_call_ids: list[str],
|
| 247 |
+
plan: str,
|
| 248 |
+
eligible_namespaces: list[str],
|
| 249 |
+
) -> None:
|
| 250 |
+
from agent.core.session import Event
|
| 251 |
+
try:
|
| 252 |
+
await session.send_event(Event(
|
| 253 |
+
event_type="jobs_access_blocked",
|
| 254 |
+
data={
|
| 255 |
+
"tool_call_ids": tool_call_ids,
|
| 256 |
+
"plan": plan,
|
| 257 |
+
"eligible_namespaces": eligible_namespaces,
|
| 258 |
+
},
|
| 259 |
+
))
|
| 260 |
+
except Exception as e:
|
| 261 |
+
logger.debug("record_jobs_access_blocked failed (non-fatal): %s", e)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
async def record_pro_cta_click(
|
| 265 |
+
session: Any,
|
| 266 |
+
*,
|
| 267 |
+
source: str,
|
| 268 |
+
target: str = "pro_pricing",
|
| 269 |
+
) -> None:
|
| 270 |
+
from agent.core.session import Event
|
| 271 |
+
try:
|
| 272 |
+
await session.send_event(Event(
|
| 273 |
+
event_type="pro_cta_click",
|
| 274 |
+
data={"source": source, "target": target},
|
| 275 |
+
))
|
| 276 |
+
except Exception as e:
|
| 277 |
+
logger.debug("record_pro_cta_click failed (non-fatal): %s", e)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
# ── heartbeat ──────────────────────────────────────────────────────────────
|
| 281 |
|
| 282 |
# Module-level reference set for fire-and-forget heartbeat tasks. asyncio only
|
|
@@ -17,6 +17,7 @@ import httpx
|
|
| 17 |
from huggingface_hub import HfApi
|
| 18 |
from huggingface_hub.utils import HfHubHTTPError
|
| 19 |
|
|
|
|
| 20 |
from agent.core.session import Event
|
| 21 |
from agent.tools.types import ToolResult
|
| 22 |
|
|
@@ -298,6 +299,7 @@ class HfJobsTool:
|
|
| 298 |
self,
|
| 299 |
hf_token: Optional[str] = None,
|
| 300 |
namespace: Optional[str] = None,
|
|
|
|
| 301 |
log_callback: Optional[Callable[[str], Awaitable[None]]] = None,
|
| 302 |
session: Any = None,
|
| 303 |
tool_call_id: Optional[str] = None,
|
|
@@ -305,6 +307,7 @@ class HfJobsTool:
|
|
| 305 |
self.hf_token = hf_token
|
| 306 |
self.api = HfApi(token=hf_token)
|
| 307 |
self.namespace = namespace
|
|
|
|
| 308 |
self.log_callback = log_callback
|
| 309 |
self.session = session
|
| 310 |
self.tool_call_id = tool_call_id
|
|
@@ -565,7 +568,7 @@ class HfJobsTool:
|
|
| 565 |
from agent.core import telemetry
|
| 566 |
submit_ts = await telemetry.record_hf_job_submit(
|
| 567 |
self.session, job,
|
| 568 |
-
{**args, "hardware_flavor": flavor, "timeout": timeout_str},
|
| 569 |
image=image, job_type=job_type,
|
| 570 |
)
|
| 571 |
|
|
@@ -1057,6 +1060,14 @@ HF_JOBS_TOOL_SPEC = {
|
|
| 1057 |
"type": "object",
|
| 1058 |
"description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
|
| 1059 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1060 |
"job_id": {
|
| 1061 |
"type": "string",
|
| 1062 |
"description": "Job ID. Required for: logs, inspect, cancel.",
|
|
@@ -1099,11 +1110,18 @@ async def hf_jobs_handler(
|
|
| 1099 |
arguments = {**arguments, "script": content}
|
| 1100 |
|
| 1101 |
hf_token = session.hf_token if session else None
|
| 1102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1103 |
|
| 1104 |
tool = HfJobsTool(
|
| 1105 |
namespace=namespace,
|
| 1106 |
hf_token=hf_token,
|
|
|
|
| 1107 |
log_callback=log_callback if session else None,
|
| 1108 |
session=session,
|
| 1109 |
tool_call_id=tool_call_id,
|
|
|
|
| 17 |
from huggingface_hub import HfApi
|
| 18 |
from huggingface_hub.utils import HfHubHTTPError
|
| 19 |
|
| 20 |
+
from agent.core.hf_access import JobsAccessError, resolve_jobs_namespace
|
| 21 |
from agent.core.session import Event
|
| 22 |
from agent.tools.types import ToolResult
|
| 23 |
|
|
|
|
| 299 |
self,
|
| 300 |
hf_token: Optional[str] = None,
|
| 301 |
namespace: Optional[str] = None,
|
| 302 |
+
jobs_access: Any = None,
|
| 303 |
log_callback: Optional[Callable[[str], Awaitable[None]]] = None,
|
| 304 |
session: Any = None,
|
| 305 |
tool_call_id: Optional[str] = None,
|
|
|
|
| 307 |
self.hf_token = hf_token
|
| 308 |
self.api = HfApi(token=hf_token)
|
| 309 |
self.namespace = namespace
|
| 310 |
+
self.jobs_access = jobs_access
|
| 311 |
self.log_callback = log_callback
|
| 312 |
self.session = session
|
| 313 |
self.tool_call_id = tool_call_id
|
|
|
|
| 568 |
from agent.core import telemetry
|
| 569 |
submit_ts = await telemetry.record_hf_job_submit(
|
| 570 |
self.session, job,
|
| 571 |
+
{**args, "hardware_flavor": flavor, "timeout": timeout_str, "namespace": self.namespace},
|
| 572 |
image=image, job_type=job_type,
|
| 573 |
)
|
| 574 |
|
|
|
|
| 1060 |
"type": "object",
|
| 1061 |
"description": "Environment variables {'KEY': 'VALUE'}. HF_TOKEN is auto-included.",
|
| 1062 |
},
|
| 1063 |
+
"namespace": {
|
| 1064 |
+
"type": "string",
|
| 1065 |
+
"description": (
|
| 1066 |
+
"Optional namespace to run the job under. Must be your own Pro account "
|
| 1067 |
+
"or a paid org you belong to. If omitted, the tool prefers your personal "
|
| 1068 |
+
"account when eligible, otherwise the first eligible paid org."
|
| 1069 |
+
),
|
| 1070 |
+
},
|
| 1071 |
"job_id": {
|
| 1072 |
"type": "string",
|
| 1073 |
"description": "Job ID. Required for: logs, inspect, cancel.",
|
|
|
|
| 1110 |
arguments = {**arguments, "script": content}
|
| 1111 |
|
| 1112 |
hf_token = session.hf_token if session else None
|
| 1113 |
+
try:
|
| 1114 |
+
namespace, jobs_access = await resolve_jobs_namespace(
|
| 1115 |
+
hf_token or "",
|
| 1116 |
+
arguments.get("namespace"),
|
| 1117 |
+
)
|
| 1118 |
+
except JobsAccessError as e:
|
| 1119 |
+
return str(e), False
|
| 1120 |
|
| 1121 |
tool = HfJobsTool(
|
| 1122 |
namespace=namespace,
|
| 1123 |
hf_token=hf_token,
|
| 1124 |
+
jobs_access=jobs_access,
|
| 1125 |
log_callback=log_callback if session else None,
|
| 1126 |
session=session,
|
| 1127 |
tool_call_id=tool_call_id,
|
|
@@ -12,6 +12,8 @@ from typing import Any
|
|
| 12 |
import httpx
|
| 13 |
from fastapi import HTTPException, Request, status
|
| 14 |
|
|
|
|
|
|
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
| 17 |
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
|
|
@@ -80,41 +82,6 @@ def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]:
|
|
| 80 |
}
|
| 81 |
|
| 82 |
|
| 83 |
-
def _normalize_plan(whoami: dict[str, Any]) -> str:
|
| 84 |
-
"""Map an HF /api/whoami-v2 payload to one of: 'free' | 'pro' | 'org'.
|
| 85 |
-
|
| 86 |
-
The exact field shape in whoami-v2 isn't documented for our purposes,
|
| 87 |
-
so we try a handful of likely keys and fall back to 'free'. The first
|
| 88 |
-
call logs the raw shape at DEBUG (see `_fetch_user_plan`) so we can
|
| 89 |
-
pin the real key post-deploy.
|
| 90 |
-
"""
|
| 91 |
-
plan_str = ""
|
| 92 |
-
for key in ("plan", "type", "accountType"):
|
| 93 |
-
val = whoami.get(key)
|
| 94 |
-
if isinstance(val, str) and val:
|
| 95 |
-
plan_str = val.lower()
|
| 96 |
-
break
|
| 97 |
-
|
| 98 |
-
if not plan_str:
|
| 99 |
-
if whoami.get("isPro") is True or whoami.get("is_pro") is True:
|
| 100 |
-
return "pro"
|
| 101 |
-
|
| 102 |
-
if "pro" in plan_str or "enterprise" in plan_str or "team" in plan_str:
|
| 103 |
-
return "pro"
|
| 104 |
-
|
| 105 |
-
# Org tier: anyone in a paid / enterprise org. We don't pay for this
|
| 106 |
-
# right now, but the "pro" cap applies identically.
|
| 107 |
-
orgs = whoami.get("orgs") or []
|
| 108 |
-
if isinstance(orgs, list):
|
| 109 |
-
for org in orgs:
|
| 110 |
-
if isinstance(org, dict):
|
| 111 |
-
org_plan = str(org.get("plan") or org.get("type") or "").lower()
|
| 112 |
-
if "pro" in org_plan or "enterprise" in org_plan or "team" in org_plan:
|
| 113 |
-
return "org"
|
| 114 |
-
|
| 115 |
-
return "free"
|
| 116 |
-
|
| 117 |
-
|
| 118 |
async def _fetch_user_plan(token: str) -> str:
|
| 119 |
"""Look up the user's HF plan via /api/whoami-v2.
|
| 120 |
|
|
@@ -123,19 +90,9 @@ async def _fetch_user_plan(token: str) -> str:
|
|
| 123 |
grant the Pro cap than over-grant it on bad data.
|
| 124 |
"""
|
| 125 |
global _WHOAMI_SHAPE_LOGGED
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
f"{OPENID_PROVIDER_URL}/api/whoami-v2",
|
| 130 |
-
headers={"Authorization": f"Bearer {token}"},
|
| 131 |
-
)
|
| 132 |
-
if resp.status_code != 200:
|
| 133 |
-
return "free"
|
| 134 |
-
whoami = resp.json()
|
| 135 |
-
except httpx.HTTPError:
|
| 136 |
-
return "free"
|
| 137 |
-
except ValueError:
|
| 138 |
-
return "free"
|
| 139 |
|
| 140 |
if not _WHOAMI_SHAPE_LOGGED:
|
| 141 |
_WHOAMI_SHAPE_LOGGED = True
|
|
@@ -149,7 +106,7 @@ async def _fetch_user_plan(token: str) -> str:
|
|
| 149 |
|
| 150 |
if not isinstance(whoami, dict):
|
| 151 |
return "free"
|
| 152 |
-
return
|
| 153 |
|
| 154 |
|
| 155 |
async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
|
|
@@ -246,4 +203,3 @@ async def require_huggingface_org_member(request: Request) -> bool:
|
|
| 246 |
return False
|
| 247 |
return await check_org_membership(token, HF_EMPLOYEE_ORG)
|
| 248 |
|
| 249 |
-
|
|
|
|
| 12 |
import httpx
|
| 13 |
from fastapi import HTTPException, Request, status
|
| 14 |
|
| 15 |
+
from agent.core.hf_access import fetch_whoami_v2, jobs_access_from_whoami
|
| 16 |
+
|
| 17 |
logger = logging.getLogger(__name__)
|
| 18 |
|
| 19 |
OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co")
|
|
|
|
| 82 |
}
|
| 83 |
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
async def _fetch_user_plan(token: str) -> str:
|
| 86 |
"""Look up the user's HF plan via /api/whoami-v2.
|
| 87 |
|
|
|
|
| 90 |
grant the Pro cap than over-grant it on bad data.
|
| 91 |
"""
|
| 92 |
global _WHOAMI_SHAPE_LOGGED
|
| 93 |
+
whoami = await fetch_whoami_v2(token)
|
| 94 |
+
if whoami is None:
|
| 95 |
+
return "free"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
if not _WHOAMI_SHAPE_LOGGED:
|
| 98 |
_WHOAMI_SHAPE_LOGGED = True
|
|
|
|
| 106 |
|
| 107 |
if not isinstance(whoami, dict):
|
| 108 |
return "free"
|
| 109 |
+
return jobs_access_from_whoami(whoami).plan
|
| 110 |
|
| 111 |
|
| 112 |
async def _extract_user_from_token(token: str) -> dict[str, Any] | None:
|
|
|
|
| 203 |
return False
|
| 204 |
return await check_org_membership(token, HF_EMPLOYEE_ORG)
|
| 205 |
|
|
|
|
@@ -38,6 +38,7 @@ class ToolApproval(BaseModel):
|
|
| 38 |
approved: bool
|
| 39 |
feedback: str | None = None
|
| 40 |
edited_script: str | None = None
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
class ApprovalRequest(BaseModel):
|
|
|
|
| 38 |
approved: bool
|
| 39 |
feedback: str | None = None
|
| 40 |
edited_script: str | None = None
|
| 41 |
+
namespace: str | None = None
|
| 42 |
|
| 43 |
|
| 44 |
class ApprovalRequest(BaseModel):
|
|
@@ -32,6 +32,7 @@ from session_manager import MAX_SESSIONS, AgentSession, SessionCapacityError, se
|
|
| 32 |
|
| 33 |
import user_quotas
|
| 34 |
|
|
|
|
| 35 |
from agent.core.llm_params import _resolve_llm_params
|
| 36 |
|
| 37 |
logger = logging.getLogger(__name__)
|
|
@@ -136,6 +137,105 @@ async def _enforce_claude_quota(
|
|
| 136 |
agent_session.claude_counted = True
|
| 137 |
|
| 138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
|
| 140 |
"""Verify the user has access to the given session. Raises 403 or 404."""
|
| 141 |
info = session_manager.get_session_info(session_id)
|
|
@@ -442,6 +542,27 @@ async def get_user_quota(user: dict = Depends(get_current_user)) -> dict:
|
|
| 442 |
}
|
| 443 |
|
| 444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
@router.get("/sessions", response_model=list[SessionInfo])
|
| 446 |
async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
|
| 447 |
"""List sessions belonging to the authenticated user."""
|
|
@@ -482,15 +603,20 @@ async def submit_approval(
|
|
| 482 |
) -> dict:
|
| 483 |
"""Submit tool approvals to a session. Only accessible by the session owner."""
|
| 484 |
_check_session_access(request.session_id, user)
|
|
|
|
|
|
|
|
|
|
| 485 |
approvals = [
|
| 486 |
{
|
| 487 |
"tool_call_id": a.tool_call_id,
|
| 488 |
"approved": a.approved,
|
| 489 |
"feedback": a.feedback,
|
| 490 |
"edited_script": a.edited_script,
|
|
|
|
| 491 |
}
|
| 492 |
for a in request.approvals
|
| 493 |
]
|
|
|
|
| 494 |
success = await session_manager.submit_approval(request.session_id, approvals)
|
| 495 |
if not success:
|
| 496 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
@@ -540,9 +666,11 @@ async def chat_sse(
|
|
| 540 |
"approved": a["approved"],
|
| 541 |
"feedback": a.get("feedback"),
|
| 542 |
"edited_script": a.get("edited_script"),
|
|
|
|
| 543 |
}
|
| 544 |
for a in approvals
|
| 545 |
]
|
|
|
|
| 546 |
success = await session_manager.submit_approval(session_id, formatted)
|
| 547 |
elif text is not None:
|
| 548 |
success = await session_manager.submit_user_input(session_id, text)
|
|
@@ -554,6 +682,7 @@ async def chat_sse(
|
|
| 554 |
broadcaster.unsubscribe(sub_id)
|
| 555 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 556 |
except HTTPException:
|
|
|
|
| 557 |
raise
|
| 558 |
except Exception:
|
| 559 |
broadcaster.unsubscribe(sub_id)
|
|
@@ -562,6 +691,31 @@ async def chat_sse(
|
|
| 562 |
return _sse_response(broadcaster, event_queue, sub_id)
|
| 563 |
|
| 564 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
# ---------------------------------------------------------------------------
|
| 566 |
# Shared SSE helpers
|
| 567 |
# ---------------------------------------------------------------------------
|
|
@@ -729,5 +883,3 @@ async def submit_feedback(
|
|
| 729 |
agent_session.session.config.session_dataset_repo
|
| 730 |
)
|
| 731 |
return {"status": "ok"}
|
| 732 |
-
|
| 733 |
-
|
|
|
|
| 32 |
|
| 33 |
import user_quotas
|
| 34 |
|
| 35 |
+
from agent.core.hf_access import get_jobs_access
|
| 36 |
from agent.core.llm_params import _resolve_llm_params
|
| 37 |
|
| 38 |
logger = logging.getLogger(__name__)
|
|
|
|
| 137 |
agent_session.claude_counted = True
|
| 138 |
|
| 139 |
|
| 140 |
+
async def _enforce_jobs_access_for_approvals(
|
| 141 |
+
user: dict[str, Any],
|
| 142 |
+
agent_session: AgentSession,
|
| 143 |
+
approvals: list[dict[str, Any]],
|
| 144 |
+
) -> None:
|
| 145 |
+
"""Block approved hf_jobs tool calls when the user has no eligible jobs namespace."""
|
| 146 |
+
pending = agent_session.session.pending_approval or {}
|
| 147 |
+
tool_calls = pending.get("tool_calls") or []
|
| 148 |
+
if not tool_calls:
|
| 149 |
+
return
|
| 150 |
+
|
| 151 |
+
approved_ids = {
|
| 152 |
+
a.get("tool_call_id")
|
| 153 |
+
for a in approvals
|
| 154 |
+
if a.get("approved")
|
| 155 |
+
}
|
| 156 |
+
if not approved_ids:
|
| 157 |
+
return
|
| 158 |
+
|
| 159 |
+
hf_job_ids = [
|
| 160 |
+
tc.id for tc in tool_calls
|
| 161 |
+
if tc.id in approved_ids and tc.function.name == "hf_jobs"
|
| 162 |
+
]
|
| 163 |
+
if not hf_job_ids:
|
| 164 |
+
return
|
| 165 |
+
|
| 166 |
+
token = agent_session.hf_token or agent_session.session.hf_token
|
| 167 |
+
if not token:
|
| 168 |
+
return
|
| 169 |
+
|
| 170 |
+
access = await get_jobs_access(token)
|
| 171 |
+
if access is None:
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
approval_map = {a.get("tool_call_id"): a for a in approvals}
|
| 175 |
+
if access.personal_can_run_jobs:
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
+
if access.paid_org_names:
|
| 179 |
+
invalid_namespace = [
|
| 180 |
+
tool_call_id
|
| 181 |
+
for tool_call_id in hf_job_ids
|
| 182 |
+
if (
|
| 183 |
+
approval_map.get(tool_call_id, {}).get("namespace")
|
| 184 |
+
and approval_map.get(tool_call_id, {}).get("namespace") not in access.paid_org_names
|
| 185 |
+
)
|
| 186 |
+
]
|
| 187 |
+
if invalid_namespace:
|
| 188 |
+
raise HTTPException(
|
| 189 |
+
status_code=400,
|
| 190 |
+
detail={
|
| 191 |
+
"error": "hf_jobs_invalid_namespace",
|
| 192 |
+
"message": (
|
| 193 |
+
"The selected jobs namespace is not one of your eligible paid organizations. "
|
| 194 |
+
f"Allowed namespaces: {', '.join(access.paid_org_names)}"
|
| 195 |
+
),
|
| 196 |
+
},
|
| 197 |
+
)
|
| 198 |
+
missing_namespace = [
|
| 199 |
+
tool_call_id
|
| 200 |
+
for tool_call_id in hf_job_ids
|
| 201 |
+
if not approval_map.get(tool_call_id, {}).get("namespace")
|
| 202 |
+
]
|
| 203 |
+
if missing_namespace:
|
| 204 |
+
raise HTTPException(
|
| 205 |
+
status_code=409,
|
| 206 |
+
detail={
|
| 207 |
+
"error": "hf_jobs_namespace_required",
|
| 208 |
+
"message": "Choose which paid organization should own this job run.",
|
| 209 |
+
"plan": user.get("plan", "free"),
|
| 210 |
+
"tool_call_ids": missing_namespace,
|
| 211 |
+
"eligible_namespaces": access.paid_org_names,
|
| 212 |
+
},
|
| 213 |
+
)
|
| 214 |
+
return
|
| 215 |
+
|
| 216 |
+
from agent.core import telemetry
|
| 217 |
+
await telemetry.record_jobs_access_blocked(
|
| 218 |
+
agent_session.session,
|
| 219 |
+
tool_call_ids=hf_job_ids,
|
| 220 |
+
plan=user.get("plan", "free"),
|
| 221 |
+
eligible_namespaces=access.eligible_namespaces,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
raise HTTPException(
|
| 225 |
+
status_code=402,
|
| 226 |
+
detail={
|
| 227 |
+
"error": "hf_jobs_upgrade_required",
|
| 228 |
+
"message": (
|
| 229 |
+
"Hugging Face Jobs are available only to Pro users and Team or Enterprise organizations. "
|
| 230 |
+
"Upgrade to Pro, or decline the job tool call so the agent can choose another path."
|
| 231 |
+
),
|
| 232 |
+
"plan": user.get("plan", "free"),
|
| 233 |
+
"tool_call_ids": hf_job_ids,
|
| 234 |
+
"eligible_namespaces": access.eligible_namespaces,
|
| 235 |
+
},
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
def _check_session_access(session_id: str, user: dict[str, Any]) -> None:
|
| 240 |
"""Verify the user has access to the given session. Raises 403 or 404."""
|
| 241 |
info = session_manager.get_session_info(session_id)
|
|
|
|
| 542 |
}
|
| 543 |
|
| 544 |
|
| 545 |
+
@router.get("/user/jobs-access")
|
| 546 |
+
async def get_jobs_access_info(request: Request, user: dict = Depends(get_current_user)) -> dict:
|
| 547 |
+
"""Return whether the current token can run HF Jobs and under which namespaces."""
|
| 548 |
+
token = None
|
| 549 |
+
auth_header = request.headers.get("Authorization", "")
|
| 550 |
+
if auth_header.startswith("Bearer "):
|
| 551 |
+
token = auth_header[7:]
|
| 552 |
+
if not token:
|
| 553 |
+
token = request.cookies.get("hf_access_token")
|
| 554 |
+
if not token:
|
| 555 |
+
token = os.environ.get("HF_TOKEN")
|
| 556 |
+
|
| 557 |
+
access = await get_jobs_access(token or "")
|
| 558 |
+
return {
|
| 559 |
+
"plan": user.get("plan", "free"),
|
| 560 |
+
"can_run_jobs": bool(access and (access.personal_can_run_jobs or access.paid_org_names)),
|
| 561 |
+
"eligible_namespaces": access.eligible_namespaces if access else [],
|
| 562 |
+
"default_namespace": access.default_namespace if access else None,
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
|
| 566 |
@router.get("/sessions", response_model=list[SessionInfo])
|
| 567 |
async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]:
|
| 568 |
"""List sessions belonging to the authenticated user."""
|
|
|
|
| 603 |
) -> dict:
|
| 604 |
"""Submit tool approvals to a session. Only accessible by the session owner."""
|
| 605 |
_check_session_access(request.session_id, user)
|
| 606 |
+
agent_session = session_manager.sessions.get(request.session_id)
|
| 607 |
+
if agent_session is None:
|
| 608 |
+
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 609 |
approvals = [
|
| 610 |
{
|
| 611 |
"tool_call_id": a.tool_call_id,
|
| 612 |
"approved": a.approved,
|
| 613 |
"feedback": a.feedback,
|
| 614 |
"edited_script": a.edited_script,
|
| 615 |
+
"namespace": a.namespace,
|
| 616 |
}
|
| 617 |
for a in request.approvals
|
| 618 |
]
|
| 619 |
+
await _enforce_jobs_access_for_approvals(user, agent_session, approvals)
|
| 620 |
success = await session_manager.submit_approval(request.session_id, approvals)
|
| 621 |
if not success:
|
| 622 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
|
|
|
| 666 |
"approved": a["approved"],
|
| 667 |
"feedback": a.get("feedback"),
|
| 668 |
"edited_script": a.get("edited_script"),
|
| 669 |
+
"namespace": a.get("namespace"),
|
| 670 |
}
|
| 671 |
for a in approvals
|
| 672 |
]
|
| 673 |
+
await _enforce_jobs_access_for_approvals(user, agent_session, formatted)
|
| 674 |
success = await session_manager.submit_approval(session_id, formatted)
|
| 675 |
elif text is not None:
|
| 676 |
success = await session_manager.submit_user_input(session_id, text)
|
|
|
|
| 682 |
broadcaster.unsubscribe(sub_id)
|
| 683 |
raise HTTPException(status_code=404, detail="Session not found or inactive")
|
| 684 |
except HTTPException:
|
| 685 |
+
broadcaster.unsubscribe(sub_id)
|
| 686 |
raise
|
| 687 |
except Exception:
|
| 688 |
broadcaster.unsubscribe(sub_id)
|
|
|
|
| 691 |
return _sse_response(broadcaster, event_queue, sub_id)
|
| 692 |
|
| 693 |
|
| 694 |
+
@router.post("/pro-click/{session_id}")
|
| 695 |
+
async def record_pro_click(
|
| 696 |
+
session_id: str,
|
| 697 |
+
body: dict,
|
| 698 |
+
user: dict = Depends(get_current_user),
|
| 699 |
+
) -> dict:
|
| 700 |
+
"""Record a click on a Pro upgrade CTA shown from inside a session."""
|
| 701 |
+
_check_session_access(session_id, user)
|
| 702 |
+
agent_session = session_manager.sessions.get(session_id)
|
| 703 |
+
if not agent_session:
|
| 704 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
| 705 |
+
|
| 706 |
+
from agent.core import telemetry
|
| 707 |
+
await telemetry.record_pro_cta_click(
|
| 708 |
+
agent_session.session,
|
| 709 |
+
source=str(body.get("source") or "unknown"),
|
| 710 |
+
target=str(body.get("target") or "pro_pricing"),
|
| 711 |
+
)
|
| 712 |
+
if agent_session.session.config.save_sessions:
|
| 713 |
+
agent_session.session.save_and_upload_detached(
|
| 714 |
+
agent_session.session.config.session_dataset_repo
|
| 715 |
+
)
|
| 716 |
+
return {"status": "ok"}
|
| 717 |
+
|
| 718 |
+
|
| 719 |
# ---------------------------------------------------------------------------
|
| 720 |
# Shared SSE helpers
|
| 721 |
# ---------------------------------------------------------------------------
|
|
|
|
| 883 |
agent_session.session.config.session_dataset_repo
|
| 884 |
)
|
| 885 |
return {"status": "ok"}
|
|
|
|
|
|
|
@@ -6,6 +6,7 @@ import StopIcon from '@mui/icons-material/Stop';
|
|
| 6 |
import { apiFetch } from '@/utils/api';
|
| 7 |
import { useUserQuota } from '@/hooks/useUserQuota';
|
| 8 |
import ClaudeCapDialog from '@/components/ClaudeCapDialog';
|
|
|
|
| 9 |
import { useAgentStore } from '@/store/agentStore';
|
| 10 |
import { CLAUDE_MODEL_PATH, FIRST_FREE_MODEL_PATH, isClaudePath } from '@/utils/model';
|
| 11 |
|
|
@@ -65,6 +66,8 @@ interface ChatInputProps {
|
|
| 65 |
sessionId?: string;
|
| 66 |
onSend: (text: string) => void;
|
| 67 |
onStop?: () => void;
|
|
|
|
|
|
|
| 68 |
isProcessing?: boolean;
|
| 69 |
disabled?: boolean;
|
| 70 |
placeholder?: string;
|
|
@@ -73,7 +76,7 @@ interface ChatInputProps {
|
|
| 73 |
const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
|
| 74 |
const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0];
|
| 75 |
|
| 76 |
-
export default function ChatInput({ sessionId, onSend, onStop, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
|
| 77 |
const [input, setInput] = useState('');
|
| 78 |
const inputRef = useRef<HTMLTextAreaElement>(null);
|
| 79 |
const [selectedModelId, setSelectedModelId] = useState<string>(MODEL_OPTIONS[0].id);
|
|
@@ -86,6 +89,8 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa
|
|
| 86 |
// the hook layer can flip it without threading props through.
|
| 87 |
const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted);
|
| 88 |
const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted);
|
|
|
|
|
|
|
| 89 |
const lastSentRef = useRef<string>('');
|
| 90 |
|
| 91 |
// Model is per-session: fetch this tab's current model every time the
|
|
@@ -197,6 +202,44 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa
|
|
| 197 |
} catch { /* ignore */ }
|
| 198 |
}, [sessionId, onSend, setClaudeQuotaExhausted]);
|
| 199 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
// Hide the chip until the user has actually burned quota — an unused
|
| 201 |
// Opus session shouldn't populate a counter.
|
| 202 |
const claudeChip = (() => {
|
|
@@ -435,6 +478,17 @@ export default function ChatInput({ sessionId, onSend, onStop, isProcessing = fa
|
|
| 435 |
cap={quota?.claudeDailyCap ?? 1}
|
| 436 |
onClose={handleCapDialogClose}
|
| 437 |
onUseFreeModel={handleUseFreeModel}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
/>
|
| 439 |
</Box>
|
| 440 |
</Box>
|
|
|
|
| 6 |
import { apiFetch } from '@/utils/api';
|
| 7 |
import { useUserQuota } from '@/hooks/useUserQuota';
|
| 8 |
import ClaudeCapDialog from '@/components/ClaudeCapDialog';
|
| 9 |
+
import JobsUpgradeDialog from '@/components/JobsUpgradeDialog';
|
| 10 |
import { useAgentStore } from '@/store/agentStore';
|
| 11 |
import { CLAUDE_MODEL_PATH, FIRST_FREE_MODEL_PATH, isClaudePath } from '@/utils/model';
|
| 12 |
|
|
|
|
| 66 |
sessionId?: string;
|
| 67 |
onSend: (text: string) => void;
|
| 68 |
onStop?: () => void;
|
| 69 |
+
onDeclineBlockedJobs?: () => Promise<boolean>;
|
| 70 |
+
onContinueBlockedJobsWithNamespace?: (namespace: string) => Promise<boolean>;
|
| 71 |
isProcessing?: boolean;
|
| 72 |
disabled?: boolean;
|
| 73 |
placeholder?: string;
|
|
|
|
| 76 |
const isClaudeModel = (m: ModelOption) => isClaudePath(m.modelPath);
|
| 77 |
const firstFreeModel = () => MODEL_OPTIONS.find(m => !isClaudeModel(m)) ?? MODEL_OPTIONS[0];
|
| 78 |
|
| 79 |
+
export default function ChatInput({ sessionId, onSend, onStop, onDeclineBlockedJobs, onContinueBlockedJobsWithNamespace, isProcessing = false, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) {
|
| 80 |
const [input, setInput] = useState('');
|
| 81 |
const inputRef = useRef<HTMLTextAreaElement>(null);
|
| 82 |
const [selectedModelId, setSelectedModelId] = useState<string>(MODEL_OPTIONS[0].id);
|
|
|
|
| 89 |
// the hook layer can flip it without threading props through.
|
| 90 |
const claudeQuotaExhausted = useAgentStore((s) => s.claudeQuotaExhausted);
|
| 91 |
const setClaudeQuotaExhausted = useAgentStore((s) => s.setClaudeQuotaExhausted);
|
| 92 |
+
const jobsUpgradeRequired = useAgentStore((s) => s.jobsUpgradeRequired);
|
| 93 |
+
const setJobsUpgradeRequired = useAgentStore((s) => s.setJobsUpgradeRequired);
|
| 94 |
const lastSentRef = useRef<string>('');
|
| 95 |
|
| 96 |
// Model is per-session: fetch this tab's current model every time the
|
|
|
|
| 202 |
} catch { /* ignore */ }
|
| 203 |
}, [sessionId, onSend, setClaudeQuotaExhausted]);
|
| 204 |
|
| 205 |
+
const handleClaudeUpgradeClick = useCallback(async () => {
|
| 206 |
+
if (!sessionId) return;
|
| 207 |
+
try {
|
| 208 |
+
await apiFetch(`/api/pro-click/${sessionId}`, {
|
| 209 |
+
method: 'POST',
|
| 210 |
+
body: JSON.stringify({ source: 'claude_cap_dialog', target: 'pro_pricing' }),
|
| 211 |
+
});
|
| 212 |
+
} catch {
|
| 213 |
+
/* tracking is best-effort */
|
| 214 |
+
}
|
| 215 |
+
}, [sessionId]);
|
| 216 |
+
|
| 217 |
+
const handleJobsUpgradeClose = useCallback(() => {
|
| 218 |
+
setJobsUpgradeRequired(null);
|
| 219 |
+
}, [setJobsUpgradeRequired]);
|
| 220 |
+
|
| 221 |
+
const handleJobsUpgradeClick = useCallback(async () => {
|
| 222 |
+
if (!sessionId || !jobsUpgradeRequired) return;
|
| 223 |
+
try {
|
| 224 |
+
await apiFetch(`/api/pro-click/${sessionId}`, {
|
| 225 |
+
method: 'POST',
|
| 226 |
+
body: JSON.stringify({ source: 'hf_jobs_upgrade_dialog', target: 'pro_pricing' }),
|
| 227 |
+
});
|
| 228 |
+
} catch {
|
| 229 |
+
/* tracking is best-effort */
|
| 230 |
+
}
|
| 231 |
+
}, [sessionId, jobsUpgradeRequired]);
|
| 232 |
+
|
| 233 |
+
const handleDeclineBlockedJobs = useCallback(async () => {
|
| 234 |
+
if (!onDeclineBlockedJobs) return;
|
| 235 |
+
await onDeclineBlockedJobs();
|
| 236 |
+
}, [onDeclineBlockedJobs]);
|
| 237 |
+
|
| 238 |
+
const handleContinueBlockedJobsWithNamespace = useCallback(async (namespace: string) => {
|
| 239 |
+
if (!onContinueBlockedJobsWithNamespace) return;
|
| 240 |
+
await onContinueBlockedJobsWithNamespace(namespace);
|
| 241 |
+
}, [onContinueBlockedJobsWithNamespace]);
|
| 242 |
+
|
| 243 |
// Hide the chip until the user has actually burned quota — an unused
|
| 244 |
// Opus session shouldn't populate a counter.
|
| 245 |
const claudeChip = (() => {
|
|
|
|
| 478 |
cap={quota?.claudeDailyCap ?? 1}
|
| 479 |
onClose={handleCapDialogClose}
|
| 480 |
onUseFreeModel={handleUseFreeModel}
|
| 481 |
+
onUpgrade={handleClaudeUpgradeClick}
|
| 482 |
+
/>
|
| 483 |
+
<JobsUpgradeDialog
|
| 484 |
+
open={!!jobsUpgradeRequired}
|
| 485 |
+
mode={jobsUpgradeRequired?.mode || 'upgrade'}
|
| 486 |
+
message={jobsUpgradeRequired?.message || ''}
|
| 487 |
+
eligibleNamespaces={jobsUpgradeRequired?.eligibleNamespaces || []}
|
| 488 |
+
onClose={handleJobsUpgradeClose}
|
| 489 |
+
onUpgrade={handleJobsUpgradeClick}
|
| 490 |
+
onDecline={handleDeclineBlockedJobs}
|
| 491 |
+
onContinueWithNamespace={handleContinueBlockedJobsWithNamespace}
|
| 492 |
/>
|
| 493 |
</Box>
|
| 494 |
</Box>
|
|
@@ -19,6 +19,7 @@ interface ClaudeCapDialogProps {
|
|
| 19 |
cap: number;
|
| 20 |
onClose: () => void;
|
| 21 |
onUseFreeModel: () => void;
|
|
|
|
| 22 |
}
|
| 23 |
|
| 24 |
export default function ClaudeCapDialog({
|
|
@@ -27,6 +28,7 @@ export default function ClaudeCapDialog({
|
|
| 27 |
cap,
|
| 28 |
onClose,
|
| 29 |
onUseFreeModel,
|
|
|
|
| 30 |
}: ClaudeCapDialogProps) {
|
| 31 |
// plan not surfaced in copy right now — Pro users see the same dialog and
|
| 32 |
// can upgrade their org if they're also capped.
|
|
@@ -100,6 +102,7 @@ export default function ClaudeCapDialog({
|
|
| 100 |
href={HF_PRICING_URL}
|
| 101 |
target="_blank"
|
| 102 |
rel="noopener noreferrer"
|
|
|
|
| 103 |
variant="contained"
|
| 104 |
size="small"
|
| 105 |
sx={{
|
|
|
|
| 19 |
cap: number;
|
| 20 |
onClose: () => void;
|
| 21 |
onUseFreeModel: () => void;
|
| 22 |
+
onUpgrade: () => void;
|
| 23 |
}
|
| 24 |
|
| 25 |
export default function ClaudeCapDialog({
|
|
|
|
| 28 |
cap,
|
| 29 |
onClose,
|
| 30 |
onUseFreeModel,
|
| 31 |
+
onUpgrade,
|
| 32 |
}: ClaudeCapDialogProps) {
|
| 33 |
// plan not surfaced in copy right now — Pro users see the same dialog and
|
| 34 |
// can upgrade their org if they're also capped.
|
|
|
|
| 102 |
href={HF_PRICING_URL}
|
| 103 |
target="_blank"
|
| 104 |
rel="noopener noreferrer"
|
| 105 |
+
onClick={onUpgrade}
|
| 106 |
variant="contained"
|
| 107 |
size="small"
|
| 108 |
sx={{
|
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import { useEffect, useState } from 'react';
|
| 2 |
+
import {
|
| 3 |
+
Box,
|
| 4 |
+
Button,
|
| 5 |
+
Dialog,
|
| 6 |
+
DialogActions,
|
| 7 |
+
DialogContent,
|
| 8 |
+
DialogContentText,
|
| 9 |
+
DialogTitle,
|
| 10 |
+
FormControl,
|
| 11 |
+
InputLabel,
|
| 12 |
+
MenuItem,
|
| 13 |
+
Select,
|
| 14 |
+
Typography,
|
| 15 |
+
} from '@mui/material';
|
| 16 |
+
|
| 17 |
+
const HF_PRICING_URL = 'https://huggingface.co/pricing';
|
| 18 |
+
|
| 19 |
+
interface JobsUpgradeDialogProps {
|
| 20 |
+
open: boolean;
|
| 21 |
+
mode: 'upgrade' | 'namespace';
|
| 22 |
+
message: string;
|
| 23 |
+
eligibleNamespaces: string[];
|
| 24 |
+
onUpgrade: () => void;
|
| 25 |
+
onDecline: () => void;
|
| 26 |
+
onClose: () => void;
|
| 27 |
+
onContinueWithNamespace: (namespace: string) => void;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
export default function JobsUpgradeDialog({
|
| 31 |
+
open,
|
| 32 |
+
mode,
|
| 33 |
+
message,
|
| 34 |
+
eligibleNamespaces,
|
| 35 |
+
onUpgrade,
|
| 36 |
+
onDecline,
|
| 37 |
+
onClose,
|
| 38 |
+
onContinueWithNamespace,
|
| 39 |
+
}: JobsUpgradeDialogProps) {
|
| 40 |
+
const [selectedNamespace, setSelectedNamespace] = useState('');
|
| 41 |
+
|
| 42 |
+
useEffect(() => {
|
| 43 |
+
if (!open) return;
|
| 44 |
+
setSelectedNamespace(eligibleNamespaces[0] || '');
|
| 45 |
+
}, [open, eligibleNamespaces]);
|
| 46 |
+
|
| 47 |
+
return (
|
| 48 |
+
<Dialog
|
| 49 |
+
open={open}
|
| 50 |
+
onClose={onClose}
|
| 51 |
+
slotProps={{
|
| 52 |
+
backdrop: { sx: { backgroundColor: 'rgba(0,0,0,0.5)', backdropFilter: 'blur(4px)' } },
|
| 53 |
+
}}
|
| 54 |
+
PaperProps={{
|
| 55 |
+
sx: {
|
| 56 |
+
bgcolor: 'var(--panel)',
|
| 57 |
+
border: '1px solid var(--border)',
|
| 58 |
+
borderRadius: 'var(--radius-md)',
|
| 59 |
+
boxShadow: 'var(--shadow-1)',
|
| 60 |
+
maxWidth: 500,
|
| 61 |
+
mx: 2,
|
| 62 |
+
},
|
| 63 |
+
}}
|
| 64 |
+
>
|
| 65 |
+
<DialogTitle
|
| 66 |
+
sx={{ color: 'var(--text)', fontWeight: 700, fontSize: '1rem', pt: 2.5, pb: 0, px: 3 }}
|
| 67 |
+
>
|
| 68 |
+
{mode === 'namespace' ? 'Choose the org for this job' : 'Jobs need Pro or a paid org'}
|
| 69 |
+
</DialogTitle>
|
| 70 |
+
<DialogContent sx={{ px: 3, pt: 1.25, pb: 0 }}>
|
| 71 |
+
<DialogContentText
|
| 72 |
+
sx={{ color: 'var(--muted-text)', fontSize: '0.85rem', lineHeight: 1.6 }}
|
| 73 |
+
>
|
| 74 |
+
{message}
|
| 75 |
+
</DialogContentText>
|
| 76 |
+
{eligibleNamespaces.length > 0 && (
|
| 77 |
+
<Box
|
| 78 |
+
sx={{
|
| 79 |
+
mt: 2,
|
| 80 |
+
p: 1.5,
|
| 81 |
+
borderRadius: '8px',
|
| 82 |
+
bgcolor: 'var(--accent-yellow-weak)',
|
| 83 |
+
border: '1px solid var(--border)',
|
| 84 |
+
}}
|
| 85 |
+
>
|
| 86 |
+
<Typography
|
| 87 |
+
variant="caption"
|
| 88 |
+
sx={{
|
| 89 |
+
display: 'block',
|
| 90 |
+
fontWeight: 700,
|
| 91 |
+
color: 'var(--text)',
|
| 92 |
+
fontSize: '0.78rem',
|
| 93 |
+
mb: 1,
|
| 94 |
+
letterSpacing: '0.02em',
|
| 95 |
+
}}
|
| 96 |
+
>
|
| 97 |
+
Eligible namespaces
|
| 98 |
+
</Typography>
|
| 99 |
+
{mode === 'namespace' ? (
|
| 100 |
+
<FormControl fullWidth size="small">
|
| 101 |
+
<InputLabel id="jobs-namespace-label">Organization</InputLabel>
|
| 102 |
+
<Select
|
| 103 |
+
labelId="jobs-namespace-label"
|
| 104 |
+
value={selectedNamespace}
|
| 105 |
+
label="Organization"
|
| 106 |
+
onChange={(e) => setSelectedNamespace(String(e.target.value))}
|
| 107 |
+
>
|
| 108 |
+
{eligibleNamespaces.map((namespace) => (
|
| 109 |
+
<MenuItem key={namespace} value={namespace}>
|
| 110 |
+
{namespace}
|
| 111 |
+
</MenuItem>
|
| 112 |
+
))}
|
| 113 |
+
</Select>
|
| 114 |
+
</FormControl>
|
| 115 |
+
) : (
|
| 116 |
+
<Typography
|
| 117 |
+
variant="caption"
|
| 118 |
+
sx={{ display: 'block', color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
|
| 119 |
+
>
|
| 120 |
+
{eligibleNamespaces.join(', ')}
|
| 121 |
+
</Typography>
|
| 122 |
+
)}
|
| 123 |
+
</Box>
|
| 124 |
+
)}
|
| 125 |
+
<Typography
|
| 126 |
+
variant="caption"
|
| 127 |
+
sx={{ display: 'block', mt: 2, color: 'var(--muted-text)', fontSize: '0.78rem', lineHeight: 1.55 }}
|
| 128 |
+
>
|
| 129 |
+
If you decline, the agent will have to find another way forward without `hf_jobs`.
|
| 130 |
+
</Typography>
|
| 131 |
+
</DialogContent>
|
| 132 |
+
<DialogActions sx={{ px: 3, pb: 2.5, pt: 2, gap: 1 }}>
|
| 133 |
+
{mode === 'namespace' ? (
|
| 134 |
+
<Button
|
| 135 |
+
onClick={() => onContinueWithNamespace(selectedNamespace)}
|
| 136 |
+
disabled={!selectedNamespace}
|
| 137 |
+
variant="contained"
|
| 138 |
+
size="small"
|
| 139 |
+
sx={{
|
| 140 |
+
fontSize: '0.82rem',
|
| 141 |
+
px: 2.5,
|
| 142 |
+
bgcolor: 'var(--accent-yellow)',
|
| 143 |
+
color: '#000',
|
| 144 |
+
textTransform: 'none',
|
| 145 |
+
fontWeight: 700,
|
| 146 |
+
boxShadow: 'none',
|
| 147 |
+
'&:hover': { bgcolor: '#FFB340', boxShadow: 'none' },
|
| 148 |
+
}}
|
| 149 |
+
>
|
| 150 |
+
Run under selected org
|
| 151 |
+
</Button>
|
| 152 |
+
) : (
|
| 153 |
+
<Button
|
| 154 |
+
component="a"
|
| 155 |
+
href={HF_PRICING_URL}
|
| 156 |
+
target="_blank"
|
| 157 |
+
rel="noopener noreferrer"
|
| 158 |
+
onClick={onUpgrade}
|
| 159 |
+
variant="contained"
|
| 160 |
+
size="small"
|
| 161 |
+
sx={{
|
| 162 |
+
fontSize: '0.82rem',
|
| 163 |
+
px: 2.5,
|
| 164 |
+
bgcolor: 'var(--accent-yellow)',
|
| 165 |
+
color: '#000',
|
| 166 |
+
textTransform: 'none',
|
| 167 |
+
fontWeight: 700,
|
| 168 |
+
boxShadow: 'none',
|
| 169 |
+
'&:hover': { bgcolor: '#FFB340', boxShadow: 'none' },
|
| 170 |
+
}}
|
| 171 |
+
>
|
| 172 |
+
Upgrade to Pro
|
| 173 |
+
</Button>
|
| 174 |
+
)}
|
| 175 |
+
<Button
|
| 176 |
+
onClick={onDecline}
|
| 177 |
+
size="small"
|
| 178 |
+
sx={{
|
| 179 |
+
color: 'var(--muted-text)',
|
| 180 |
+
fontSize: '0.82rem',
|
| 181 |
+
px: 2,
|
| 182 |
+
textTransform: 'none',
|
| 183 |
+
'&:hover': { bgcolor: 'var(--hover-bg)' },
|
| 184 |
+
}}
|
| 185 |
+
>
|
| 186 |
+
Decline tool call
|
| 187 |
+
</Button>
|
| 188 |
+
</DialogActions>
|
| 189 |
+
</Dialog>
|
| 190 |
+
);
|
| 191 |
+
}
|
|
@@ -26,7 +26,7 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
|
|
| 26 |
const { updateSessionTitle, sessions } = useSessionStore();
|
| 27 |
const isExpired = sessions.find((s) => s.id === sessionId)?.expired === true;
|
| 28 |
|
| 29 |
-
const { messages, sendMessage, stop, status, undoLastTurn, editAndRegenerate, approveTools } = useAgentChat({
|
| 30 |
sessionId,
|
| 31 |
isActive,
|
| 32 |
onReady: () => logger.log(`Session ${sessionId} ready`),
|
|
@@ -114,6 +114,8 @@ export default function SessionChat({ sessionId, isActive, onSessionDead }: Sess
|
|
| 114 |
sessionId={sessionId}
|
| 115 |
onSend={handleSendMessage}
|
| 116 |
onStop={handleStop}
|
|
|
|
|
|
|
| 117 |
isProcessing={busy}
|
| 118 |
disabled={!isConnected || activityStatus.type === 'waiting-approval'}
|
| 119 |
placeholder={
|
|
|
|
| 26 |
const { updateSessionTitle, sessions } = useSessionStore();
|
| 27 |
const isExpired = sessions.find((s) => s.id === sessionId)?.expired === true;
|
| 28 |
|
| 29 |
+
const { messages, sendMessage, stop, status, undoLastTurn, editAndRegenerate, approveTools, declineBlockedJobs, continueBlockedJobsWithNamespace } = useAgentChat({
|
| 30 |
sessionId,
|
| 31 |
isActive,
|
| 32 |
onReady: () => logger.log(`Session ${sessionId} ready`),
|
|
|
|
| 114 |
sessionId={sessionId}
|
| 115 |
onSend={handleSendMessage}
|
| 116 |
onStop={handleStop}
|
| 117 |
+
onDeclineBlockedJobs={declineBlockedJobs}
|
| 118 |
+
onContinueBlockedJobsWithNamespace={continueBlockedJobsWithNamespace}
|
| 119 |
isProcessing={busy}
|
| 120 |
disabled={!isConnected || activityStatus.type === 'waiting-approval'}
|
| 121 |
placeholder={
|
|
@@ -330,6 +330,49 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 330 |
messages: UIMessage[];
|
| 331 |
}>({ setMessages: null, messages: [] });
|
| 332 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
// -- useChat from Vercel AI SDK -----------------------------------------
|
| 334 |
const chat = useChat({
|
| 335 |
id: sessionId,
|
|
@@ -354,6 +397,56 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 354 |
}
|
| 355 |
return;
|
| 356 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
logger.error('useChat error:', error);
|
| 358 |
if (isActiveRef.current) {
|
| 359 |
useAgentStore.getState().setError(error.message);
|
|
@@ -672,12 +765,15 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 672 |
|
| 673 |
// -- Approve tools ------------------------------------------------------
|
| 674 |
const approveTools = useCallback(
|
| 675 |
-
async (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>) => {
|
| 676 |
// Store edited scripts so the transport can read them when sendMessages is called
|
| 677 |
for (const a of approvals) {
|
| 678 |
if (a.edited_script) {
|
| 679 |
useAgentStore.getState().setEditedScript(a.tool_call_id, a.edited_script);
|
| 680 |
}
|
|
|
|
|
|
|
|
|
|
| 681 |
}
|
| 682 |
|
| 683 |
// Update SDK tool state — this triggers sendMessages() via the transport
|
|
@@ -707,6 +803,37 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 707 |
[sessionId, chat, updateSession, setNeedsAttention],
|
| 708 |
);
|
| 709 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
// -- Stop (interrupt backend agent loop, keep SSE open for events) --------
|
| 711 |
const stop = useCallback(() => {
|
| 712 |
// Don't call chat.stop() — keep the SSE stream open so the backend's
|
|
@@ -763,5 +890,7 @@ export function useAgentChat({ sessionId, isActive, onReady, onError, onSessionD
|
|
| 763 |
undoLastTurn,
|
| 764 |
editAndRegenerate,
|
| 765 |
approveTools,
|
|
|
|
|
|
|
| 766 |
};
|
| 767 |
}
|
|
|
|
| 330 |
messages: UIMessage[];
|
| 331 |
}>({ setMessages: null, messages: [] });
|
| 332 |
|
| 333 |
+
const hydrateFromBackend = useCallback(async () => {
|
| 334 |
+
try {
|
| 335 |
+
const [msgsRes, infoRes] = await Promise.all([
|
| 336 |
+
apiFetch(`/api/session/${sessionId}/messages`),
|
| 337 |
+
apiFetch(`/api/session/${sessionId}`),
|
| 338 |
+
]);
|
| 339 |
+
if (!msgsRes.ok) return null;
|
| 340 |
+
const data = await msgsRes.json();
|
| 341 |
+
if (!Array.isArray(data) || data.length === 0) return null;
|
| 342 |
+
|
| 343 |
+
saveBackendMessages(sessionId, data);
|
| 344 |
+
|
| 345 |
+
let pendingIds: Set<string> | undefined;
|
| 346 |
+
let info: Record<string, unknown> | null = null;
|
| 347 |
+
if (infoRes.ok) {
|
| 348 |
+
info = await infoRes.json();
|
| 349 |
+
const pendingApproval = info?.pending_approval;
|
| 350 |
+
if (pendingApproval && Array.isArray(pendingApproval)) {
|
| 351 |
+
pendingIds = new Set(
|
| 352 |
+
pendingApproval.map((t: { tool_call_id: string }) => t.tool_call_id),
|
| 353 |
+
);
|
| 354 |
+
if (pendingIds.size > 0) {
|
| 355 |
+
setNeedsAttention(sessionId, true);
|
| 356 |
+
}
|
| 357 |
+
}
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
const uiMsgs = llmMessagesToUIMessages(data, pendingIds, chatActionsRef.current.messages);
|
| 361 |
+
if (uiMsgs.length > 0) {
|
| 362 |
+
chatActionsRef.current.setMessages?.(uiMsgs);
|
| 363 |
+
saveMessages(sessionId, uiMsgs);
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
if (pendingIds && pendingIds.size > 0) {
|
| 367 |
+
updateSession(sessionId, { activityStatus: { type: 'waiting-approval' }, isProcessing: false });
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
return { data, pendingIds, info };
|
| 371 |
+
} catch {
|
| 372 |
+
return null;
|
| 373 |
+
}
|
| 374 |
+
}, [sessionId, setNeedsAttention]);
|
| 375 |
+
|
| 376 |
// -- useChat from Vercel AI SDK -----------------------------------------
|
| 377 |
const chat = useChat({
|
| 378 |
id: sessionId,
|
|
|
|
| 397 |
}
|
| 398 |
return;
|
| 399 |
}
|
| 400 |
+
if (error.message === 'HF_JOBS_UPGRADE_REQUIRED') {
|
| 401 |
+
const typed = error as Error & {
|
| 402 |
+
detail?: Record<string, unknown>;
|
| 403 |
+
approvals?: Array<{
|
| 404 |
+
tool_call_id: string;
|
| 405 |
+
approved: boolean;
|
| 406 |
+
feedback?: string | null;
|
| 407 |
+
edited_script?: string | null;
|
| 408 |
+
}>;
|
| 409 |
+
};
|
| 410 |
+
void hydrateFromBackend();
|
| 411 |
+
if (isActiveRef.current) {
|
| 412 |
+
useAgentStore.getState().setJobsUpgradeRequired({
|
| 413 |
+
approvals: typed.approvals || [],
|
| 414 |
+
toolCallIds: (typed.detail?.tool_call_ids as string[]) || [],
|
| 415 |
+
message: String(
|
| 416 |
+
typed.detail?.message
|
| 417 |
+
|| 'Hugging Face Jobs are available only to Pro users and Team or Enterprise organizations.',
|
| 418 |
+
),
|
| 419 |
+
eligibleNamespaces: (typed.detail?.eligible_namespaces as string[]) || [],
|
| 420 |
+
plan: ((typed.detail?.plan as 'free' | 'pro' | 'org') || 'free'),
|
| 421 |
+
mode: 'upgrade',
|
| 422 |
+
});
|
| 423 |
+
}
|
| 424 |
+
return;
|
| 425 |
+
}
|
| 426 |
+
if (error.message === 'HF_JOBS_NAMESPACE_REQUIRED') {
|
| 427 |
+
const typed = error as Error & {
|
| 428 |
+
detail?: Record<string, unknown>;
|
| 429 |
+
approvals?: Array<{
|
| 430 |
+
tool_call_id: string;
|
| 431 |
+
approved: boolean;
|
| 432 |
+
feedback?: string | null;
|
| 433 |
+
edited_script?: string | null;
|
| 434 |
+
namespace?: string | null;
|
| 435 |
+
}>;
|
| 436 |
+
};
|
| 437 |
+
void hydrateFromBackend();
|
| 438 |
+
if (isActiveRef.current) {
|
| 439 |
+
useAgentStore.getState().setJobsUpgradeRequired({
|
| 440 |
+
approvals: typed.approvals || [],
|
| 441 |
+
toolCallIds: (typed.detail?.tool_call_ids as string[]) || [],
|
| 442 |
+
message: String(typed.detail?.message || 'Choose which organization should own this job run.'),
|
| 443 |
+
eligibleNamespaces: (typed.detail?.eligible_namespaces as string[]) || [],
|
| 444 |
+
plan: ((typed.detail?.plan as 'free' | 'pro' | 'org') || 'free'),
|
| 445 |
+
mode: 'namespace',
|
| 446 |
+
});
|
| 447 |
+
}
|
| 448 |
+
return;
|
| 449 |
+
}
|
| 450 |
logger.error('useChat error:', error);
|
| 451 |
if (isActiveRef.current) {
|
| 452 |
useAgentStore.getState().setError(error.message);
|
|
|
|
| 765 |
|
| 766 |
// -- Approve tools ------------------------------------------------------
|
| 767 |
const approveTools = useCallback(
|
| 768 |
+
async (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null; namespace?: string | null }>) => {
|
| 769 |
// Store edited scripts so the transport can read them when sendMessages is called
|
| 770 |
for (const a of approvals) {
|
| 771 |
if (a.edited_script) {
|
| 772 |
useAgentStore.getState().setEditedScript(a.tool_call_id, a.edited_script);
|
| 773 |
}
|
| 774 |
+
if (a.namespace) {
|
| 775 |
+
useAgentStore.getState().setApprovalNamespace(a.tool_call_id, a.namespace);
|
| 776 |
+
}
|
| 777 |
}
|
| 778 |
|
| 779 |
// Update SDK tool state — this triggers sendMessages() via the transport
|
|
|
|
| 803 |
[sessionId, chat, updateSession, setNeedsAttention],
|
| 804 |
);
|
| 805 |
|
| 806 |
+
const declineBlockedJobs = useCallback(async () => {
|
| 807 |
+
const blocked = useAgentStore.getState().jobsUpgradeRequired;
|
| 808 |
+
if (!blocked) return false;
|
| 809 |
+
|
| 810 |
+
const approvals = blocked.approvals.map((approval) => ({
|
| 811 |
+
...approval,
|
| 812 |
+
approved: blocked.toolCallIds.includes(approval.tool_call_id) ? false : approval.approved,
|
| 813 |
+
feedback: blocked.toolCallIds.includes(approval.tool_call_id)
|
| 814 |
+
? 'Rejected because this account cannot launch Hugging Face Jobs.'
|
| 815 |
+
: approval.feedback,
|
| 816 |
+
}));
|
| 817 |
+
|
| 818 |
+
useAgentStore.getState().setJobsUpgradeRequired(null);
|
| 819 |
+
return approveTools(approvals);
|
| 820 |
+
}, [approveTools]);
|
| 821 |
+
|
| 822 |
+
const continueBlockedJobsWithNamespace = useCallback(async (namespace: string) => {
|
| 823 |
+
const blocked = useAgentStore.getState().jobsUpgradeRequired;
|
| 824 |
+
if (!blocked) return false;
|
| 825 |
+
|
| 826 |
+
const approvals = blocked.approvals.map((approval) => ({
|
| 827 |
+
...approval,
|
| 828 |
+
namespace: blocked.toolCallIds.includes(approval.tool_call_id)
|
| 829 |
+
? namespace
|
| 830 |
+
: approval.namespace,
|
| 831 |
+
}));
|
| 832 |
+
|
| 833 |
+
useAgentStore.getState().setJobsUpgradeRequired(null);
|
| 834 |
+
return approveTools(approvals);
|
| 835 |
+
}, [approveTools]);
|
| 836 |
+
|
| 837 |
// -- Stop (interrupt backend agent loop, keep SSE open for events) --------
|
| 838 |
const stop = useCallback(() => {
|
| 839 |
// Don't call chat.stop() — keep the SSE stream open so the backend's
|
|
|
|
| 890 |
undoLastTurn,
|
| 891 |
editAndRegenerate,
|
| 892 |
approveTools,
|
| 893 |
+
declineBlockedJobs,
|
| 894 |
+
continueBlockedJobsWithNamespace,
|
| 895 |
};
|
| 896 |
}
|
|
@@ -320,11 +320,13 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
|
|
| 320 |
const approved = p.approval?.approved ?? true;
|
| 321 |
// Get edited script from agentStore if available
|
| 322 |
const editedScript = useAgentStore.getState().getEditedScript(p.toolCallId);
|
|
|
|
| 323 |
return {
|
| 324 |
tool_call_id: p.toolCallId,
|
| 325 |
approved,
|
| 326 |
feedback: approved ? null : (p.approval?.reason || 'Rejected by user'),
|
| 327 |
edited_script: editedScript ?? null,
|
|
|
|
| 328 |
};
|
| 329 |
}).filter(Boolean);
|
| 330 |
body = { approvals };
|
|
@@ -362,6 +364,30 @@ export class SSEChatTransport implements ChatTransport<UIMessage> {
|
|
| 362 |
// instead of a generic error banner.
|
| 363 |
throw new Error('CLAUDE_QUOTA_EXHAUSTED');
|
| 364 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
if (!response.ok) {
|
| 366 |
const errorText = await response.text().catch(() => 'Request failed');
|
| 367 |
throw new Error(`Chat request failed: ${response.status} ${errorText}`);
|
|
|
|
| 320 |
const approved = p.approval?.approved ?? true;
|
| 321 |
// Get edited script from agentStore if available
|
| 322 |
const editedScript = useAgentStore.getState().getEditedScript(p.toolCallId);
|
| 323 |
+
const namespace = useAgentStore.getState().getApprovalNamespace(p.toolCallId);
|
| 324 |
return {
|
| 325 |
tool_call_id: p.toolCallId,
|
| 326 |
approved,
|
| 327 |
feedback: approved ? null : (p.approval?.reason || 'Rejected by user'),
|
| 328 |
edited_script: editedScript ?? null,
|
| 329 |
+
namespace: namespace ?? null,
|
| 330 |
};
|
| 331 |
}).filter(Boolean);
|
| 332 |
body = { approvals };
|
|
|
|
| 364 |
// instead of a generic error banner.
|
| 365 |
throw new Error('CLAUDE_QUOTA_EXHAUSTED');
|
| 366 |
}
|
| 367 |
+
if (response.status === 402) {
|
| 368 |
+
const payload = await response.json().catch(() => null);
|
| 369 |
+
if (payload?.detail?.error === 'hf_jobs_upgrade_required') {
|
| 370 |
+
const err = new Error('HF_JOBS_UPGRADE_REQUIRED') as Error & {
|
| 371 |
+
detail?: Record<string, unknown>;
|
| 372 |
+
approvals?: Array<Record<string, unknown>>;
|
| 373 |
+
};
|
| 374 |
+
err.detail = payload.detail as Record<string, unknown>;
|
| 375 |
+
err.approvals = (body.approvals as Array<Record<string, unknown>> | undefined) || [];
|
| 376 |
+
throw err;
|
| 377 |
+
}
|
| 378 |
+
}
|
| 379 |
+
if (response.status === 409) {
|
| 380 |
+
const payload = await response.json().catch(() => null);
|
| 381 |
+
if (payload?.detail?.error === 'hf_jobs_namespace_required') {
|
| 382 |
+
const err = new Error('HF_JOBS_NAMESPACE_REQUIRED') as Error & {
|
| 383 |
+
detail?: Record<string, unknown>;
|
| 384 |
+
approvals?: Array<Record<string, unknown>>;
|
| 385 |
+
};
|
| 386 |
+
err.detail = payload.detail as Record<string, unknown>;
|
| 387 |
+
err.approvals = (body.approvals as Array<Record<string, unknown>> | undefined) || [];
|
| 388 |
+
throw err;
|
| 389 |
+
}
|
| 390 |
+
}
|
| 391 |
if (!response.ok) {
|
| 392 |
const errorText = await response.text().catch(() => 'Request failed');
|
| 393 |
throw new Error(`Chat request failed: ${response.status} ${errorText}`);
|
|
@@ -45,6 +45,21 @@ export interface LLMHealthError {
|
|
| 45 |
model: string;
|
| 46 |
}
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
export type ActivityStatus =
|
| 49 |
| { type: 'idle' }
|
| 50 |
| { type: 'thinking' }
|
|
@@ -110,6 +125,7 @@ interface AgentStore {
|
|
| 110 |
llmHealthError: LLMHealthError | null;
|
| 111 |
/** Set when a Claude-send hits the daily quota — ChatInput opens the cap dialog in response. */
|
| 112 |
claudeQuotaExhausted: boolean;
|
|
|
|
| 113 |
|
| 114 |
// Right panel (single-artifact pattern)
|
| 115 |
panelData: PanelData | null;
|
|
@@ -122,6 +138,9 @@ interface AgentStore {
|
|
| 122 |
// Edited scripts (tool_call_id -> edited content)
|
| 123 |
editedScripts: Record<string, string>;
|
| 124 |
|
|
|
|
|
|
|
|
|
|
| 125 |
// Job URLs (tool_call_id -> job URL) for HF jobs
|
| 126 |
jobUrls: Record<string, string>;
|
| 127 |
|
|
@@ -156,6 +175,7 @@ interface AgentStore {
|
|
| 156 |
setError: (error: string | null) => void;
|
| 157 |
setLlmHealthError: (error: LLMHealthError | null) => void;
|
| 158 |
setClaudeQuotaExhausted: (exhausted: boolean) => void;
|
|
|
|
| 159 |
|
| 160 |
setPanel: (data: PanelData, view?: PanelView, editable?: boolean) => void;
|
| 161 |
setPanelView: (view: PanelView) => void;
|
|
@@ -170,6 +190,10 @@ interface AgentStore {
|
|
| 170 |
getEditedScript: (toolCallId: string) => string | undefined;
|
| 171 |
clearEditedScripts: () => void;
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
setJobUrl: (toolCallId: string, jobUrl: string) => void;
|
| 174 |
getJobUrl: (toolCallId: string) => string | undefined;
|
| 175 |
|
|
@@ -251,6 +275,7 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 251 |
error: null,
|
| 252 |
llmHealthError: null,
|
| 253 |
claudeQuotaExhausted: false,
|
|
|
|
| 254 |
|
| 255 |
panelData: null,
|
| 256 |
panelView: 'script',
|
|
@@ -259,6 +284,7 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 259 |
plan: [],
|
| 260 |
|
| 261 |
editedScripts: {},
|
|
|
|
| 262 |
jobUrls: {},
|
| 263 |
jobStatuses: {},
|
| 264 |
toolErrors: loadToolErrors(),
|
|
@@ -363,6 +389,7 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 363 |
setError: (error) => set({ error }),
|
| 364 |
setLlmHealthError: (error) => set({ llmHealthError: error }),
|
| 365 |
setClaudeQuotaExhausted: (exhausted) => set({ claudeQuotaExhausted: exhausted }),
|
|
|
|
| 366 |
|
| 367 |
// ── Panel (single-artifact) ───────────────────────────────────────
|
| 368 |
// Each setter also patches the active session's snapshot so that
|
|
@@ -428,6 +455,16 @@ export const useAgentStore = create<AgentStore>()((set, get) => ({
|
|
| 428 |
|
| 429 |
clearEditedScripts: () => set({ editedScripts: {} }),
|
| 430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
// ── Job URLs ────────────────────────────────────────────────────────
|
| 432 |
|
| 433 |
setJobUrl: (toolCallId, jobUrl) => {
|
|
|
|
| 45 |
model: string;
|
| 46 |
}
|
| 47 |
|
| 48 |
+
export interface JobsUpgradeState {
|
| 49 |
+
approvals: Array<{
|
| 50 |
+
tool_call_id: string;
|
| 51 |
+
approved: boolean;
|
| 52 |
+
feedback?: string | null;
|
| 53 |
+
edited_script?: string | null;
|
| 54 |
+
namespace?: string | null;
|
| 55 |
+
}>;
|
| 56 |
+
toolCallIds: string[];
|
| 57 |
+
message: string;
|
| 58 |
+
eligibleNamespaces: string[];
|
| 59 |
+
plan: 'free' | 'pro' | 'org';
|
| 60 |
+
mode: 'upgrade' | 'namespace';
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
export type ActivityStatus =
|
| 64 |
| { type: 'idle' }
|
| 65 |
| { type: 'thinking' }
|
|
|
|
| 125 |
llmHealthError: LLMHealthError | null;
|
| 126 |
/** Set when a Claude-send hits the daily quota — ChatInput opens the cap dialog in response. */
|
| 127 |
claudeQuotaExhausted: boolean;
|
| 128 |
+
jobsUpgradeRequired: JobsUpgradeState | null;
|
| 129 |
|
| 130 |
// Right panel (single-artifact pattern)
|
| 131 |
panelData: PanelData | null;
|
|
|
|
| 138 |
// Edited scripts (tool_call_id -> edited content)
|
| 139 |
editedScripts: Record<string, string>;
|
| 140 |
|
| 141 |
+
// Namespace overrides chosen for hf_jobs approvals (tool_call_id -> namespace)
|
| 142 |
+
approvalNamespaces: Record<string, string>;
|
| 143 |
+
|
| 144 |
// Job URLs (tool_call_id -> job URL) for HF jobs
|
| 145 |
jobUrls: Record<string, string>;
|
| 146 |
|
|
|
|
| 175 |
setError: (error: string | null) => void;
|
| 176 |
setLlmHealthError: (error: LLMHealthError | null) => void;
|
| 177 |
setClaudeQuotaExhausted: (exhausted: boolean) => void;
|
| 178 |
+
setJobsUpgradeRequired: (state: JobsUpgradeState | null) => void;
|
| 179 |
|
| 180 |
setPanel: (data: PanelData, view?: PanelView, editable?: boolean) => void;
|
| 181 |
setPanelView: (view: PanelView) => void;
|
|
|
|
| 190 |
getEditedScript: (toolCallId: string) => string | undefined;
|
| 191 |
clearEditedScripts: () => void;
|
| 192 |
|
| 193 |
+
setApprovalNamespace: (toolCallId: string, namespace: string) => void;
|
| 194 |
+
getApprovalNamespace: (toolCallId: string) => string | undefined;
|
| 195 |
+
clearApprovalNamespaces: () => void;
|
| 196 |
+
|
| 197 |
setJobUrl: (toolCallId: string, jobUrl: string) => void;
|
| 198 |
getJobUrl: (toolCallId: string) => string | undefined;
|
| 199 |
|
|
|
|
| 275 |
error: null,
|
| 276 |
llmHealthError: null,
|
| 277 |
claudeQuotaExhausted: false,
|
| 278 |
+
jobsUpgradeRequired: null,
|
| 279 |
|
| 280 |
panelData: null,
|
| 281 |
panelView: 'script',
|
|
|
|
| 284 |
plan: [],
|
| 285 |
|
| 286 |
editedScripts: {},
|
| 287 |
+
approvalNamespaces: {},
|
| 288 |
jobUrls: {},
|
| 289 |
jobStatuses: {},
|
| 290 |
toolErrors: loadToolErrors(),
|
|
|
|
| 389 |
setError: (error) => set({ error }),
|
| 390 |
setLlmHealthError: (error) => set({ llmHealthError: error }),
|
| 391 |
setClaudeQuotaExhausted: (exhausted) => set({ claudeQuotaExhausted: exhausted }),
|
| 392 |
+
setJobsUpgradeRequired: (state) => set({ jobsUpgradeRequired: state }),
|
| 393 |
|
| 394 |
// ── Panel (single-artifact) ───────────────────────────────────────
|
| 395 |
// Each setter also patches the active session's snapshot so that
|
|
|
|
| 455 |
|
| 456 |
clearEditedScripts: () => set({ editedScripts: {} }),
|
| 457 |
|
| 458 |
+
setApprovalNamespace: (toolCallId, namespace) => {
|
| 459 |
+
set((state) => ({
|
| 460 |
+
approvalNamespaces: { ...state.approvalNamespaces, [toolCallId]: namespace },
|
| 461 |
+
}));
|
| 462 |
+
},
|
| 463 |
+
|
| 464 |
+
getApprovalNamespace: (toolCallId) => get().approvalNamespaces[toolCallId],
|
| 465 |
+
|
| 466 |
+
clearApprovalNamespaces: () => set({ approvalNamespaces: {} }),
|
| 467 |
+
|
| 468 |
// ── Job URLs ────────────────────────────────────────────────────────
|
| 469 |
|
| 470 |
setJobUrl: (toolCallId, jobUrl) => {
|
|
@@ -27,6 +27,7 @@ export interface ToolApproval {
|
|
| 27 |
tool_call_id: string;
|
| 28 |
approved: boolean;
|
| 29 |
feedback?: string | null;
|
|
|
|
| 30 |
}
|
| 31 |
|
| 32 |
export interface User {
|
|
|
|
| 27 |
tool_call_id: string;
|
| 28 |
approved: boolean;
|
| 29 |
feedback?: string | null;
|
| 30 |
+
namespace?: string | null;
|
| 31 |
}
|
| 32 |
|
| 33 |
export interface User {
|
|
@@ -44,7 +44,8 @@ re-running the same hour overwrites.
|
|
| 44 |
regenerate_rate — sessions with any `undo_complete` event / sessions
|
| 45 |
time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
|
| 46 |
thumbs_up / thumbs_down
|
| 47 |
-
hf_jobs_submitted / _succeeded
|
|
|
|
| 48 |
gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
|
| 49 |
|
| 50 |
================================================================================
|
|
@@ -210,7 +211,8 @@ def _session_metrics(session: dict) -> dict:
|
|
| 210 |
"tool_calls_total": 0, "tool_calls_success": 0,
|
| 211 |
"failures": 0, "regenerate_sessions": 0,
|
| 212 |
"thumbs_up": 0, "thumbs_down": 0,
|
| 213 |
-
"hf_jobs_submitted": 0, "hf_jobs_succeeded": 0,
|
|
|
|
| 214 |
"first_tool_s": -1,
|
| 215 |
}
|
| 216 |
events = session.get("events") or []
|
|
@@ -229,8 +231,11 @@ def _session_metrics(session: dict) -> dict:
|
|
| 229 |
gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
|
| 230 |
jobs_submitted = 0
|
| 231 |
jobs_succeeded = 0
|
|
|
|
| 232 |
thumbs_up = 0
|
| 233 |
thumbs_down = 0
|
|
|
|
|
|
|
| 234 |
|
| 235 |
start_dt = _parse_ts(session_start)
|
| 236 |
|
|
@@ -283,6 +288,14 @@ def _session_metrics(session: dict) -> dict:
|
|
| 283 |
if status in ("completed", "succeeded", "success"):
|
| 284 |
jobs_succeeded += 1
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
out["tool_calls_total"] = tool_total
|
| 287 |
out["tool_calls_success"] = tool_success
|
| 288 |
out["failures"] = 1 if had_error else 0
|
|
@@ -291,8 +304,11 @@ def _session_metrics(session: dict) -> dict:
|
|
| 291 |
out["thumbs_down"] = thumbs_down
|
| 292 |
out["hf_jobs_submitted"] = jobs_submitted
|
| 293 |
out["hf_jobs_succeeded"] = jobs_succeeded
|
|
|
|
|
|
|
| 294 |
out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
|
| 295 |
out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
|
|
|
|
| 296 |
out["_user"] = session.get("user_id") or session.get("session_id")
|
| 297 |
return dict(out)
|
| 298 |
|
|
@@ -301,9 +317,12 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 301 |
"""Collapse a bucket's worth of session rollups into the final KPI row."""
|
| 302 |
ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
|
| 303 |
gpu_hours: dict[str, float] = defaultdict(float)
|
|
|
|
| 304 |
for s in per_session:
|
| 305 |
for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
|
| 306 |
gpu_hours[f] += h
|
|
|
|
|
|
|
| 307 |
|
| 308 |
total_sessions = sum(s["sessions"] for s in per_session)
|
| 309 |
total_turns = sum(s["turns"] for s in per_session)
|
|
@@ -340,7 +359,10 @@ def _aggregate(per_session: list[dict]) -> dict:
|
|
| 340 |
"thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
|
| 341 |
"hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
|
| 342 |
"hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
|
|
|
|
|
|
|
| 343 |
"gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
|
|
|
|
| 344 |
}
|
| 345 |
|
| 346 |
|
|
|
|
| 44 |
regenerate_rate — sessions with any `undo_complete` event / sessions
|
| 45 |
time_to_first_action_s_p50 / _p95 — from session_start to first tool_call
|
| 46 |
thumbs_up / thumbs_down
|
| 47 |
+
hf_jobs_submitted / _succeeded / _blocked
|
| 48 |
+
pro_cta_clicks
|
| 49 |
gpu_hours_by_flavor_json — JSON-serialised {flavor: gpu-hours}
|
| 50 |
|
| 51 |
================================================================================
|
|
|
|
| 211 |
"tool_calls_total": 0, "tool_calls_success": 0,
|
| 212 |
"failures": 0, "regenerate_sessions": 0,
|
| 213 |
"thumbs_up": 0, "thumbs_down": 0,
|
| 214 |
+
"hf_jobs_submitted": 0, "hf_jobs_succeeded": 0, "hf_jobs_blocked": 0,
|
| 215 |
+
"pro_cta_clicks": 0,
|
| 216 |
"first_tool_s": -1,
|
| 217 |
}
|
| 218 |
events = session.get("events") or []
|
|
|
|
| 231 |
gpu_hours_by_flavor: dict[str, float] = defaultdict(float)
|
| 232 |
jobs_submitted = 0
|
| 233 |
jobs_succeeded = 0
|
| 234 |
+
jobs_blocked = 0
|
| 235 |
thumbs_up = 0
|
| 236 |
thumbs_down = 0
|
| 237 |
+
pro_cta_clicks = 0
|
| 238 |
+
pro_cta_by_source: dict[str, int] = defaultdict(int)
|
| 239 |
|
| 240 |
start_dt = _parse_ts(session_start)
|
| 241 |
|
|
|
|
| 288 |
if status in ("completed", "succeeded", "success"):
|
| 289 |
jobs_succeeded += 1
|
| 290 |
|
| 291 |
+
elif et == "jobs_access_blocked":
|
| 292 |
+
jobs_blocked += 1
|
| 293 |
+
|
| 294 |
+
elif et == "pro_cta_click":
|
| 295 |
+
pro_cta_clicks += 1
|
| 296 |
+
source = str(data.get("source") or "unknown")
|
| 297 |
+
pro_cta_by_source[source] += 1
|
| 298 |
+
|
| 299 |
out["tool_calls_total"] = tool_total
|
| 300 |
out["tool_calls_success"] = tool_success
|
| 301 |
out["failures"] = 1 if had_error else 0
|
|
|
|
| 304 |
out["thumbs_down"] = thumbs_down
|
| 305 |
out["hf_jobs_submitted"] = jobs_submitted
|
| 306 |
out["hf_jobs_succeeded"] = jobs_succeeded
|
| 307 |
+
out["hf_jobs_blocked"] = jobs_blocked
|
| 308 |
+
out["pro_cta_clicks"] = pro_cta_clicks
|
| 309 |
out["first_tool_s"] = first_tool_ts if first_tool_ts is not None else -1
|
| 310 |
out["_gpu_hours_by_flavor"] = dict(gpu_hours_by_flavor)
|
| 311 |
+
out["_pro_cta_by_source"] = dict(pro_cta_by_source)
|
| 312 |
out["_user"] = session.get("user_id") or session.get("session_id")
|
| 313 |
return dict(out)
|
| 314 |
|
|
|
|
| 317 |
"""Collapse a bucket's worth of session rollups into the final KPI row."""
|
| 318 |
ttfa_values = [s["first_tool_s"] for s in per_session if s.get("first_tool_s", -1) >= 0]
|
| 319 |
gpu_hours: dict[str, float] = defaultdict(float)
|
| 320 |
+
pro_cta_by_source: dict[str, int] = defaultdict(int)
|
| 321 |
for s in per_session:
|
| 322 |
for f, h in (s.get("_gpu_hours_by_flavor") or {}).items():
|
| 323 |
gpu_hours[f] += h
|
| 324 |
+
for source, count in (s.get("_pro_cta_by_source") or {}).items():
|
| 325 |
+
pro_cta_by_source[source] += int(count)
|
| 326 |
|
| 327 |
total_sessions = sum(s["sessions"] for s in per_session)
|
| 328 |
total_turns = sum(s["turns"] for s in per_session)
|
|
|
|
| 359 |
"thumbs_down": int(sum(s["thumbs_down"] for s in per_session)),
|
| 360 |
"hf_jobs_submitted": int(sum(s["hf_jobs_submitted"] for s in per_session)),
|
| 361 |
"hf_jobs_succeeded": int(sum(s["hf_jobs_succeeded"] for s in per_session)),
|
| 362 |
+
"hf_jobs_blocked": int(sum(s["hf_jobs_blocked"] for s in per_session)),
|
| 363 |
+
"pro_cta_clicks": int(sum(s["pro_cta_clicks"] for s in per_session)),
|
| 364 |
"gpu_hours_by_flavor_json": json.dumps(dict(gpu_hours), sort_keys=True),
|
| 365 |
+
"pro_cta_by_source_json": json.dumps(dict(pro_cta_by_source), sort_keys=True),
|
| 366 |
}
|
| 367 |
|
| 368 |
|
|
@@ -88,6 +88,22 @@ def test_hf_job_gpu_hours():
|
|
| 88 |
assert abs(m["_gpu_hours_by_flavor"]["a100-large"] - 1.0) < 1e-6
|
| 89 |
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
def test_feedback_counts():
|
| 92 |
mod = _load()
|
| 93 |
events = [
|
|
@@ -120,6 +136,22 @@ def test_aggregate_day_cache_hit_and_users():
|
|
| 120 |
assert abs(row["cost_usd"] - 1.5) < 1e-9
|
| 121 |
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
def test_failure_and_regenerate_rates():
|
| 124 |
mod = _load()
|
| 125 |
s1 = mod._session_metrics(_session([_ev("error", {"error": "boom"})], user_id="a"))
|
|
|
|
| 88 |
assert abs(m["_gpu_hours_by_flavor"]["a100-large"] - 1.0) < 1e-6
|
| 89 |
|
| 90 |
|
| 91 |
+
def test_hf_job_blocked_and_pro_clicks_are_counted():
|
| 92 |
+
mod = _load()
|
| 93 |
+
events = [
|
| 94 |
+
_ev("jobs_access_blocked", {"tool_call_ids": ["tc1"], "plan": "free"}),
|
| 95 |
+
_ev("pro_cta_click", {"source": "hf_jobs_upgrade_dialog"}),
|
| 96 |
+
_ev("pro_cta_click", {"source": "claude_cap_dialog"}),
|
| 97 |
+
]
|
| 98 |
+
m = mod._session_metrics(_session(events))
|
| 99 |
+
assert m["hf_jobs_blocked"] == 1
|
| 100 |
+
assert m["pro_cta_clicks"] == 2
|
| 101 |
+
assert m["_pro_cta_by_source"] == {
|
| 102 |
+
"hf_jobs_upgrade_dialog": 1,
|
| 103 |
+
"claude_cap_dialog": 1,
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
|
| 107 |
def test_feedback_counts():
|
| 108 |
mod = _load()
|
| 109 |
events = [
|
|
|
|
| 136 |
assert abs(row["cost_usd"] - 1.5) < 1e-9
|
| 137 |
|
| 138 |
|
| 139 |
+
def test_aggregate_day_sums_pro_click_sources():
|
| 140 |
+
mod = _load()
|
| 141 |
+
s1 = mod._session_metrics(_session([
|
| 142 |
+
_ev("pro_cta_click", {"source": "hf_jobs_upgrade_dialog"}),
|
| 143 |
+
_ev("pro_cta_click", {"source": "hf_jobs_upgrade_dialog"}),
|
| 144 |
+
], user_id="u1"))
|
| 145 |
+
s2 = mod._session_metrics(_session([
|
| 146 |
+
_ev("pro_cta_click", {"source": "claude_cap_dialog"}),
|
| 147 |
+
], user_id="u2"))
|
| 148 |
+
row = mod._aggregate_day([s1, s2])
|
| 149 |
+
assert row["pro_cta_clicks"] == 3
|
| 150 |
+
assert row["pro_cta_by_source_json"] == (
|
| 151 |
+
'{"claude_cap_dialog": 1, "hf_jobs_upgrade_dialog": 2}'
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
def test_failure_and_regenerate_rates():
|
| 156 |
mod = _load()
|
| 157 |
s1 = mod._session_metrics(_session([_ev("error", {"error": "boom"})], user_id="a"))
|
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agent.core.hf_access import jobs_access_from_whoami
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_personal_pro_prefers_username_namespace():
|
| 5 |
+
access = jobs_access_from_whoami({
|
| 6 |
+
"name": "alice",
|
| 7 |
+
"plan": "pro",
|
| 8 |
+
"orgs": [],
|
| 9 |
+
})
|
| 10 |
+
assert access.plan == "pro"
|
| 11 |
+
assert access.eligible_namespaces == ["alice"]
|
| 12 |
+
assert access.default_namespace == "alice"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_free_user_with_paid_org_uses_org_namespace():
|
| 16 |
+
access = jobs_access_from_whoami({
|
| 17 |
+
"name": "alice",
|
| 18 |
+
"plan": "free",
|
| 19 |
+
"orgs": [
|
| 20 |
+
{"name": "team-a", "plan": "team"},
|
| 21 |
+
{"name": "oss-friends", "plan": "free"},
|
| 22 |
+
],
|
| 23 |
+
})
|
| 24 |
+
assert access.plan == "org"
|
| 25 |
+
assert access.personal_can_run_jobs is False
|
| 26 |
+
assert access.eligible_namespaces == ["team-a"]
|
| 27 |
+
assert access.default_namespace is None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def test_free_user_without_paid_org_cannot_run_jobs():
|
| 31 |
+
access = jobs_access_from_whoami({
|
| 32 |
+
"name": "alice",
|
| 33 |
+
"plan": "free",
|
| 34 |
+
"orgs": [{"name": "community", "plan": "free"}],
|
| 35 |
+
})
|
| 36 |
+
assert access.plan == "free"
|
| 37 |
+
assert access.can_run_jobs is False
|
| 38 |
+
assert access.eligible_namespaces == []
|
| 39 |
+
assert access.default_namespace is None
|