Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Aksel Joonas Reedi
Run jobs under user or org accounts with upgrade / org-pick UX (#132)
ff8c636 unverified | """Helpers for Hugging Face account / org access decisions.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import os | |
| from dataclasses import dataclass | |
| from typing import Any | |
| import httpx | |
| OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") | |
| class JobsAccess: | |
| """Jobs entitlement derived from whoami-v2.""" | |
| username: str | None | |
| plan: str | |
| personal_can_run_jobs: bool | |
| paid_org_names: list[str] | |
| eligible_namespaces: list[str] | |
| default_namespace: str | None | |
| access_known: bool = True | |
| def can_run_jobs(self) -> bool: | |
| return bool(self.default_namespace) | |
| class JobsAccessError(Exception): | |
| """Structured jobs access error for upgrade / namespace gating.""" | |
| def __init__( | |
| self, | |
| message: str, | |
| *, | |
| access: JobsAccess | None = None, | |
| upgrade_required: bool = False, | |
| namespace_required: bool = False, | |
| ) -> None: | |
| super().__init__(message) | |
| self.access = access | |
| self.upgrade_required = upgrade_required | |
| self.namespace_required = namespace_required | |
| def _extract_username(whoami: dict[str, Any]) -> str | None: | |
| for key in ("name", "user", "preferred_username"): | |
| value = whoami.get(key) | |
| if isinstance(value, str) and value: | |
| return value | |
| return None | |
| def _normalize_personal_plan(whoami: dict[str, Any]) -> str: | |
| plan_str = "" | |
| for key in ("plan", "type", "accountType"): | |
| value = whoami.get(key) | |
| if isinstance(value, str) and value: | |
| plan_str = value.lower() | |
| break | |
| if not plan_str and (whoami.get("isPro") is True or whoami.get("is_pro") is True): | |
| return "pro" | |
| if any(tag in plan_str for tag in ("pro", "enterprise", "team")): | |
| return "pro" | |
| return "free" | |
| def _paid_org_names(whoami: dict[str, Any]) -> list[str]: | |
| names: list[str] = [] | |
| orgs = whoami.get("orgs") or [] | |
| if not isinstance(orgs, list): | |
| return names | |
| for org in orgs: | |
| if not isinstance(org, dict): | |
| continue | |
| name = org.get("name") | |
| if not isinstance(name, str) or not name: | |
| continue | |
| org_plan = str(org.get("plan") or org.get("type") or "").lower() | |
| if any(tag in org_plan for tag in ("pro", "enterprise", "team")): | |
| names.append(name) | |
| return sorted(set(names)) | |
| def jobs_access_from_whoami(whoami: dict[str, Any]) -> JobsAccess: | |
| username = _extract_username(whoami) | |
| personal_plan = _normalize_personal_plan(whoami) | |
| paid_orgs = _paid_org_names(whoami) | |
| personal_can_run = personal_plan == "pro" | |
| eligible_namespaces: list[str] = [] | |
| if personal_can_run and username: | |
| eligible_namespaces.append(username) | |
| eligible_namespaces.extend(paid_orgs) | |
| plan = "pro" if personal_can_run else ("org" if paid_orgs else "free") | |
| default_namespace = username if personal_can_run and username else None | |
| return JobsAccess( | |
| username=username, | |
| plan=plan, | |
| personal_can_run_jobs=personal_can_run, | |
| paid_org_names=paid_orgs, | |
| eligible_namespaces=eligible_namespaces, | |
| default_namespace=default_namespace, | |
| ) | |
| async def fetch_whoami_v2(token: str, timeout: float = 5.0) -> dict[str, Any] | None: | |
| if not token: | |
| return None | |
| async with httpx.AsyncClient(timeout=timeout) as client: | |
| try: | |
| response = await client.get( | |
| f"{OPENID_PROVIDER_URL}/api/whoami-v2", | |
| headers={"Authorization": f"Bearer {token}"}, | |
| ) | |
| if response.status_code != 200: | |
| return None | |
| payload = response.json() | |
| return payload if isinstance(payload, dict) else None | |
| except (httpx.HTTPError, ValueError): | |
| return None | |
| async def get_jobs_access(token: str) -> JobsAccess | None: | |
| whoami = await fetch_whoami_v2(token) | |
| if whoami is None: | |
| return None | |
| return jobs_access_from_whoami(whoami) | |
| async def resolve_jobs_namespace( | |
| token: str, | |
| requested_namespace: str | None = None, | |
| ) -> tuple[str, JobsAccess | None]: | |
| """Return the namespace to use for jobs. | |
| If whoami-v2 is unavailable, fall back to the token owner's username. | |
| """ | |
| access = await get_jobs_access(token) | |
| if access: | |
| if requested_namespace: | |
| if requested_namespace in access.eligible_namespaces: | |
| return requested_namespace, access | |
| raise JobsAccessError( | |
| f"You can only run jobs under your own Pro account or a paid org you belong to. " | |
| f"Allowed namespaces: {', '.join(access.eligible_namespaces) or '(none)'}", | |
| access=access, | |
| ) | |
| if access.default_namespace: | |
| return access.default_namespace, access | |
| if access.paid_org_names: | |
| raise JobsAccessError( | |
| "Choose which paid organization should own this job run.", | |
| access=access, | |
| namespace_required=True, | |
| ) | |
| raise JobsAccessError( | |
| "Hugging Face Jobs are available only to Pro users and Team or Enterprise organizations. " | |
| "Upgrade to Pro, or run the job under a paid org you belong to.", | |
| access=access, | |
| upgrade_required=True, | |
| ) | |
| # Fallback: whoami-v2 unavailable. Do not block the call pre-emptively. | |
| from huggingface_hub import HfApi | |
| username = None | |
| if token: | |
| whoami = await asyncio.to_thread(HfApi(token=token).whoami) | |
| username = whoami.get("name") | |
| if not username: | |
| raise JobsAccessError("No HF token available to resolve a jobs namespace.") | |
| return requested_namespace or username, None | |