diff --git a/README.md b/README.md index 729674aabe970a7c0cf11778b679e84cbf4a371f..fed2a689ba76144abbf89ce49d0b74a973325062 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,11 @@ hf_oauth: true hf_oauth_scopes: - read-repos - write-repos + - contribute-repos + - manage-repos - inference-api + - jobs + - write-discussions --- # HF Agent diff --git a/agent/context_manager/manager.py b/agent/context_manager/manager.py index 8d7296ae0f8952dd976b522d93a34156db41ed72..d1a3f3aa64157e56e161794737e0983b143ee5ef 100644 --- a/agent/context_manager/manager.py +++ b/agent/context_manager/manager.py @@ -2,6 +2,7 @@ Context management for conversation history """ +import logging import os import zoneinfo from datetime import datetime @@ -13,6 +14,72 @@ from huggingface_hub import HfApi from jinja2 import Template from litellm import Message, acompletion +logger = logging.getLogger(__name__) + +# Module-level cache for HF username — avoids repeating the slow whoami() call +_hf_username_cache: str | None = None + +_HF_WHOAMI_URL = "https://huggingface.co/api/whoami-v2" +_HF_WHOAMI_TIMEOUT = 5 # seconds + + +def _get_hf_username() -> str: + """Return the HF username, cached after the first call. + + Uses subprocess + curl to avoid Python HTTP client IPv6 issues that + cause 40+ second hangs (httpx/urllib try IPv6 first which times out + at OS level before falling back to IPv4 — the "Happy Eyeballs" problem). + """ + import json + import subprocess + import time as _t + + global _hf_username_cache + if _hf_username_cache is not None: + return _hf_username_cache + + hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") + if not hf_token: + logger.warning("No HF_TOKEN set, using 'unknown' as username") + _hf_username_cache = "unknown" + return _hf_username_cache + + t0 = _t.monotonic() + try: + result = subprocess.run( + [ + "curl", + "-s", + "-4", # force IPv4 + "-m", + str(_HF_WHOAMI_TIMEOUT), # max time + "-H", + f"Authorization: Bearer {hf_token}", + _HF_WHOAMI_URL, + ], + capture_output=True, + text=True, + timeout=_HF_WHOAMI_TIMEOUT + 2, + ) + t1 = _t.monotonic() + if result.returncode == 0 and result.stdout: + data = json.loads(result.stdout) + _hf_username_cache = data.get("name", "unknown") + logger.info( + f"HF username resolved to '{_hf_username_cache}' in {t1 - t0:.2f}s" + ) + else: + logger.warning( + f"curl whoami failed (rc={result.returncode}) in {t1 - t0:.2f}s" + ) + _hf_username_cache = "unknown" + except Exception as e: + t1 = _t.monotonic() + logger.warning(f"HF whoami failed in {t1 - t0:.2f}s: {e}") + _hf_username_cache = "unknown" + + return _hf_username_cache + class ContextManager: """Manages conversation context and message history for the agent""" @@ -54,9 +121,8 @@ class ContextManager: current_time = now.strftime("%H:%M:%S.%f")[:-3] current_timezone = f"{now.strftime('%Z')} (UTC{now.strftime('%z')[:3]}:{now.strftime('%z')[3:]})" - # Get HF user info with explicit token from env - hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") - hf_user_info = HfApi(token=hf_token).whoami().get("name", "unknown") + # Get HF user info (cached after the first call) + hf_user_info = _get_hf_username() template = Template(template_str) return template.render( @@ -110,11 +176,15 @@ class ContextManager: ) ) + hf_key = os.environ.get("INFERENCE_TOKEN") response = await acompletion( model=model_name, messages=messages_to_summarize, max_completion_tokens=self.compact_size, tools=tool_specs, + api_key=hf_key + if hf_key and model_name.startswith("huggingface/") + else None, ) summarized_message = Message( role="assistant", content=response.choices[0].message.content diff --git a/agent/core/agent_loop.py b/agent/core/agent_loop.py index b1f85cd156ce8f79f1e51ffae72f4ee818b28480..1e9085df13847d876c81efc4be79a00fb1bb8675 100644 --- a/agent/core/agent_loop.py +++ b/agent/core/agent_loop.py @@ -4,8 +4,10 @@ Main agent implementation with integrated tool system and MCP support import asyncio import json +import logging +import os -from litellm import ChatCompletionMessageToolCall, Message, ModelResponse, acompletion +from litellm import ChatCompletionMessageToolCall, Message, acompletion from litellm.exceptions import ContextWindowExceededError from lmnr import observe @@ -14,7 +16,42 @@ from agent.core.session import Event, OpType, Session from agent.core.tools import ToolRouter from agent.tools.jobs_tool import CPU_FLAVORS +logger = logging.getLogger(__name__) + ToolCall = ChatCompletionMessageToolCall +# Explicit inference token — needed because litellm checks HF_TOKEN before +# HUGGINGFACE_API_KEY, and HF_TOKEN (used for Hub ops) may lack inference permissions. +_INFERENCE_API_KEY = os.environ.get("INFERENCE_TOKEN") + + +def _resolve_hf_router_params(model_name: str) -> dict: + """ + Build LiteLLM kwargs for HuggingFace Router models. + + api-inference.huggingface.co is deprecated; the new router lives at + router.huggingface.co//v3/openai. LiteLLM's built-in + ``huggingface/`` provider still targets the old endpoint, so we + rewrite model names to ``openai/`` and supply the correct api_base. + + Input format: huggingface/// + Example: huggingface/novita/moonshotai/kimi-k2.5 + """ + if not model_name.startswith("huggingface/"): + return {"model": model_name} + + parts = model_name.split("/", 2) # ['huggingface', 'novita', 'moonshotai/kimi-k2.5'] + if len(parts) < 3: + return {"model": model_name} + + router_provider = parts[1] + actual_model = parts[2] + api_key = _INFERENCE_API_KEY or os.environ.get("HF_TOKEN") + + return { + "model": f"openai/{actual_model}", + "api_base": f"https://router.huggingface.co/{router_provider}/v3/openai", + "api_key": api_key, + } def _validate_tool_args(tool_args: dict) -> tuple[bool, str | None]: @@ -130,6 +167,42 @@ async def _compact_and_notify(session: Session) -> None: class Handlers: """Handler functions for each operation type""" + @staticmethod + async def _abandon_pending_approval(session: Session) -> None: + """Cancel pending approval tools when the user continues the conversation. + + Injects rejection tool-result messages into the LLM context (so the + history stays valid) and notifies the frontend that those tools were + abandoned. + """ + tool_calls = session.pending_approval.get("tool_calls", []) + for tc in tool_calls: + tool_name = tc.function.name + abandon_msg = "Task abandoned — user continued the conversation without approving." + + # Keep LLM context valid: every tool_call needs a tool result + tool_msg = Message( + role="tool", + content=abandon_msg, + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "abandoned", + }, + ) + ) + + session.pending_approval = None + logger.info("Abandoned %d pending approval tool(s)", len(tool_calls)) + @staticmethod @observe(name="run_agent") async def run_agent( @@ -145,6 +218,11 @@ class Handlers: Laminar.set_trace_session_id(session_id=session.session_id) + # If there's a pending approval and the user sent a new message, + # abandon the pending tools so the LLM context stays valid. + if text and session.pending_approval: + await Handlers._abandon_pending_approval(session) + # Add user message to history only if there's actual content if text: user_msg = Message(role="user", content=text) @@ -165,37 +243,100 @@ class Handlers: messages = session.context_manager.get_messages() tools = session.tool_router.get_tool_specs_for_llm() - try: - response: ModelResponse = await acompletion( - model=session.config.model_name, + # ── Stream the LLM response ────────────────────────── + llm_params = _resolve_hf_router_params(session.config.model_name) + response = await acompletion( messages=messages, tools=tools, tool_choice="auto", + stream=True, + stream_options={"include_usage": True}, + **llm_params, ) - # Extract text response, token usage, and tool calls - message = response.choices[0].message - content = message.content - token_count = response.usage.total_tokens - tool_calls: list[ToolCall] = message.get("tool_calls", []) + full_content = "" + tool_calls_acc: dict[int, dict] = {} + token_count = 0 + + async for chunk in response: + choice = chunk.choices[0] if chunk.choices else None + if not choice: + # Last chunk may carry only usage info + if hasattr(chunk, "usage") and chunk.usage: + token_count = chunk.usage.total_tokens + continue + + delta = choice.delta + + # Stream text deltas to the frontend + if delta.content: + full_content += delta.content + await session.send_event( + Event( + event_type="assistant_chunk", + data={"content": delta.content}, + ) + ) + + # Accumulate tool-call deltas (name + args arrive in pieces) + if delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in tool_calls_acc: + tool_calls_acc[idx] = { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + if tc_delta.id: + tool_calls_acc[idx]["id"] = tc_delta.id + if tc_delta.function: + if tc_delta.function.name: + tool_calls_acc[idx]["function"]["name"] += ( + tc_delta.function.name + ) + if tc_delta.function.arguments: + tool_calls_acc[idx]["function"]["arguments"] += ( + tc_delta.function.arguments + ) + + # Capture usage from the final chunk + if hasattr(chunk, "usage") and chunk.usage: + token_count = chunk.usage.total_tokens + + # ── Stream finished — reconstruct full message ─────── + content = full_content or None + + # Build tool_calls list from accumulated deltas + tool_calls: list[ToolCall] = [] + for idx in sorted(tool_calls_acc.keys()): + tc_data = tool_calls_acc[idx] + tool_calls.append( + ToolCall( + id=tc_data["id"], + type="function", + function={ + "name": tc_data["function"]["name"], + "arguments": tc_data["function"]["arguments"], + }, + ) + ) + + # Signal end of streaming to the frontend + await session.send_event( + Event(event_type="assistant_stream_end", data={}) + ) # If no tool calls, add assistant message and we're done if not tool_calls: if content: assistant_msg = Message(role="assistant", content=content) session.context_manager.add_message(assistant_msg, token_count) - await session.send_event( - Event( - event_type="assistant_message", - data={"content": content}, - ) - ) final_response = content break # Add assistant message with tool calls to history - # LiteLLM will format this correctly for the provider assistant_msg = Message( role="assistant", content=content, @@ -203,66 +344,97 @@ class Handlers: ) session.context_manager.add_message(assistant_msg, token_count) - if content: - await session.send_event( - Event(event_type="assistant_message", data={"content": content}) - ) - # Separate tools into those requiring approval and those that don't approval_required_tools = [] non_approval_tools = [] for tc in tool_calls: tool_name = tc.function.name - tool_args = json.loads(tc.function.arguments) + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"Malformed tool arguments for {tool_name}: {e}") + tool_args = {} if _needs_approval(tool_name, tool_args, session.config): approval_required_tools.append(tc) else: non_approval_tools.append(tc) + # Execute non-approval tools (in parallel when possible) + if non_approval_tools: + # 1. Parse args and validate upfront + parsed_tools: list[ + tuple[ChatCompletionMessageToolCall, str, dict, bool, str] + ] = [] + for tc in non_approval_tools: + tool_name = tc.function.name + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + tool_args = {} + + args_valid, error_msg = _validate_tool_args(tool_args) + parsed_tools.append( + (tc, tool_name, tool_args, args_valid, error_msg) + ) - # Execute non-approval tools first - for tc in non_approval_tools: - tool_name = tc.function.name - tool_args = json.loads(tc.function.arguments) - - # Validate tool arguments before calling - args_valid, error_msg = _validate_tool_args(tool_args) - if not args_valid: - # Return error to agent instead of calling tool - output = error_msg - success = False - else: - await session.send_event( - Event( - event_type="tool_call", - data={"tool": tool_name, "arguments": tool_args}, + # 2. Send all tool_call events upfront (so frontend shows them all) + for tc, tool_name, tool_args, args_valid, _ in parsed_tools: + if args_valid: + await session.send_event( + Event( + event_type="tool_call", + data={ + "tool": tool_name, + "arguments": tool_args, + "tool_call_id": tc.id, + }, + ) ) - ) - output, success = await session.tool_router.call_tool( - tool_name, tool_args, session=session + # 3. Execute all valid tools in parallel + async def _exec_tool( + tc: ChatCompletionMessageToolCall, + name: str, + args: dict, + valid: bool, + err: str, + ) -> tuple[ChatCompletionMessageToolCall, str, dict, str, bool]: + if not valid: + return (tc, name, args, err, False) + out, ok = await session.tool_router.call_tool( + name, args, session=session ) + return (tc, name, args, out, ok) - # Add tool result to history - tool_msg = Message( - role="tool", - content=output, - tool_call_id=tc.id, - name=tool_name, + results = await asyncio.gather( + *[ + _exec_tool(tc, name, args, valid, err) + for tc, name, args, valid, err in parsed_tools + ] ) - session.context_manager.add_message(tool_msg) - await session.send_event( - Event( - event_type="tool_output", - data={ - "tool": tool_name, - "output": output, - "success": success, - }, + # 4. Record results and send outputs (order preserved) + for tc, tool_name, tool_args, output, success in results: + tool_msg = Message( + role="tool", + content=output, + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tool_name, + "tool_call_id": tc.id, + "output": output, + "success": success, + }, + ) ) - ) # If there are tools requiring approval, ask for batch approval if approval_required_tools: @@ -270,7 +442,10 @@ class Handlers: tools_data = [] for tc in approval_required_tools: tool_name = tc.function.name - tool_args = json.loads(tc.function.arguments) + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError): + tool_args = {} tools_data.append( { "tool": tool_name, @@ -339,11 +514,27 @@ class Handlers: @staticmethod async def undo(session: Session) -> None: - """Handle undo (like undo in codex.rs:1314)""" - # Remove last user turn and all following items - # Simplified: just remove last 2 items - for _ in range(min(2, len(session.context_manager.items))): - session.context_manager.items.pop() + """Remove the last complete turn (user msg + all assistant/tool msgs that follow). + + Anthropic requires every tool_use to have a matching tool_result, + so we can't just pop 2 items — we must pop everything back to + (and including) the last user message to keep the history valid. + """ + items = session.context_manager.items + if not items: + await session.send_event(Event(event_type="undo_complete")) + return + + # Pop from the end until we've removed the last user message + removed_user = False + while items: + msg = items.pop() + if getattr(msg, "role", None) == "user": + removed_user = True + break + + if not removed_user: + logger.warning("Undo: no user message found to remove") await session.send_event(Event(event_type="undo_complete")) @@ -371,6 +562,9 @@ class Handlers: # Create a map of tool_call_id -> approval decision approval_map = {a["tool_call_id"]: a for a in approvals} + for a in approvals: + if a.get("edited_script"): + logger.info(f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)") # Separate approved and rejected tool calls approved_tasks = [] @@ -378,36 +572,99 @@ class Handlers: for tc in tool_calls: tool_name = tc.function.name - tool_args = json.loads(tc.function.arguments) + try: + tool_args = json.loads(tc.function.arguments) + except (json.JSONDecodeError, TypeError) as e: + # Malformed arguments — treat as failed, notify agent + logger.warning(f"Malformed tool arguments for {tool_name}: {e}") + tool_msg = Message( + role="tool", + content=f"Malformed arguments: {e}", + tool_call_id=tc.id, + name=tool_name, + ) + session.context_manager.add_message(tool_msg) + await session.send_event( + Event( + event_type="tool_output", + data={ + "tool": tool_name, + "tool_call_id": tc.id, + "output": f"Malformed arguments: {e}", + "success": False, + }, + ) + ) + continue + approval_decision = approval_map.get(tc.id, {"approved": False}) if approval_decision.get("approved", False): - approved_tasks.append((tc, tool_name, tool_args)) + edited_script = approval_decision.get("edited_script") + was_edited = False + if edited_script and "script" in tool_args: + tool_args["script"] = edited_script + was_edited = True + logger.info(f"Using user-edited script for {tool_name} ({tc.id})") + approved_tasks.append((tc, tool_name, tool_args, was_edited)) else: rejected_tasks.append((tc, tool_name, approval_decision)) + # Notify frontend of approval decisions immediately (before execution) + for tc, tool_name, tool_args, _was_edited in approved_tasks: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "approved", + }, + ) + ) + for tc, tool_name, approval_decision in rejected_tasks: + await session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "rejected", + }, + ) + ) + # Execute all approved tools concurrently - async def execute_tool(tc, tool_name, tool_args): - """Execute a single tool and return its result""" + async def execute_tool(tc, tool_name, tool_args, was_edited): + """Execute a single tool and return its result. + + The TraceLog already exists on the frontend (created by + approval_required), so we send tool_state_change instead of + tool_call to avoid creating a duplicate. + """ await session.send_event( Event( - event_type="tool_call", - data={"tool": tool_name, "arguments": tool_args}, + event_type="tool_state_change", + data={ + "tool_call_id": tc.id, + "tool": tool_name, + "state": "running", + }, ) ) output, success = await session.tool_router.call_tool( - tool_name, tool_args, session=session + tool_name, tool_args, session=session, tool_call_id=tc.id ) - return (tc, tool_name, output, success) + return (tc, tool_name, output, success, was_edited) # Execute all approved tools concurrently and wait for ALL to complete if approved_tasks: results = await asyncio.gather( *[ - execute_tool(tc, tool_name, tool_args) - for tc, tool_name, tool_args in approved_tasks + execute_tool(tc, tool_name, tool_args, was_edited) + for tc, tool_name, tool_args, was_edited in approved_tasks ], return_exceptions=True, ) @@ -416,10 +673,13 @@ class Handlers: for result in results: if isinstance(result, Exception): # Handle execution error - print(f"Tool execution error: {result}") + logger.error(f"Tool execution error: {result}") continue - tc, tool_name, output, success = result + tc, tool_name, output, success, was_edited = result + + if was_edited: + output = f"[Note: The user edited the script before execution. The output below reflects the user-modified version, not your original script.]\n\n{output}" # Add tool result to context tool_msg = Message( @@ -435,6 +695,7 @@ class Handlers: event_type="tool_output", data={ "tool": tool_name, + "tool_call_id": tc.id, "output": output, "success": success, }, @@ -446,7 +707,14 @@ class Handlers: rejection_msg = "Job execution cancelled by user" user_feedback = approval_decision.get("feedback") if user_feedback: - rejection_msg += f". User feedback: {user_feedback}" + # Ensure feedback is a string and sanitize any problematic characters + feedback_str = str(user_feedback).strip() + # Remove any control characters that might break JSON parsing + feedback_str = "".join(char for char in feedback_str if ord(char) >= 32 or char in "\n\t") + rejection_msg += f". User feedback: {feedback_str}" + + # Ensure rejection_msg is a clean string + rejection_msg = str(rejection_msg).strip() tool_msg = Message( role="tool", @@ -461,6 +729,7 @@ class Handlers: event_type="tool_output", data={ "tool": tool_name, + "tool_call_id": tc.id, "output": rejection_msg, "success": False, }, @@ -478,11 +747,9 @@ class Handlers: """Handle shutdown (like shutdown in codex.rs:1329)""" # Save session trajectory if enabled (fire-and-forget, returns immediately) if session.config.save_sessions: - print("💾 Saving session...") + logger.info("Saving session...") repo_id = session.config.session_dataset_repo _ = session.save_and_upload_detached(repo_id) - # if local_path: - # print("✅ Session saved locally, upload in progress") session.is_running = False await session.send_event(Event(event_type="shutdown")) @@ -497,7 +764,7 @@ async def process_submission(session: Session, submission) -> bool: bool: True to continue, False to shutdown """ op = submission.operation - # print(f"📨 Received: {op.op_type.value}") + logger.debug("Received operation: %s", op.op_type.value) if op.op_type == OpType.USER_INPUT: text = op.data.get("text", "") if op.data else "" @@ -509,7 +776,6 @@ async def process_submission(session: Session, submission) -> bool: return True if op.op_type == OpType.COMPACT: - # compact from the frontend await _compact_and_notify(session) return True @@ -525,7 +791,7 @@ async def process_submission(session: Session, submission) -> bool: if op.op_type == OpType.SHUTDOWN: return not await Handlers.shutdown(session) - print(f"⚠️ Unknown operation: {op.op_type}") + logger.warning(f"Unknown operation: {op.op_type}") return True @@ -543,7 +809,7 @@ async def submission_loop( # Create session with tool router session = Session(event_queue, config=config, tool_router=tool_router) - print("Agent loop started") + logger.info("Agent loop started") # Retry any failed uploads from previous sessions (fire-and-forget) if config and config.save_sessions: @@ -567,25 +833,25 @@ async def submission_loop( if not should_continue: break except asyncio.CancelledError: - print("\n⚠️ Agent loop cancelled") + logger.warning("Agent loop cancelled") break except Exception as e: - print(f"❌ Error in agent loop: {e}") + logger.error(f"Error in agent loop: {e}") await session.send_event( Event(event_type="error", data={"error": str(e)}) ) - print("🛑 Agent loop exited") + logger.info("Agent loop exited") finally: # Emergency save if session saving is enabled and shutdown wasn't called properly if session.config.save_sessions and session.is_running: - print("\n💾 Emergency save: preserving session before exit...") + logger.info("Emergency save: preserving session before exit...") try: local_path = session.save_and_upload_detached( session.config.session_dataset_repo ) if local_path: - print("✅ Emergency save successful, upload in progress") + logger.info("Emergency save successful, upload in progress") except Exception as e: - print(f"❌ Emergency save failed: {e}") + logger.error(f"Emergency save failed: {e}") diff --git a/agent/core/session.py b/agent/core/session.py index 439260d5a9fc3992cc5ccb8aee3a26e56b41bc8c..6f88e3b44a6ef3469aeb90459cbcfde02e792d6b 100644 --- a/agent/core/session.py +++ b/agent/core/session.py @@ -1,5 +1,6 @@ import asyncio import json +import logging import subprocess import sys import uuid @@ -9,11 +10,48 @@ from enum import Enum from pathlib import Path from typing import Any, Optional -from litellm import get_max_tokens - from agent.config import Config from agent.context_manager.manager import ContextManager +logger = logging.getLogger(__name__) + +# Local max-token lookup — avoids litellm.get_max_tokens() which can hang +# on network calls for certain providers (known litellm issue). +_MAX_TOKENS_MAP: dict[str, int] = { + # Anthropic + "anthropic/claude-opus-4-5-20251101": 200_000, + "anthropic/claude-sonnet-4-5-20250929": 200_000, + "anthropic/claude-sonnet-4-20250514": 200_000, + "anthropic/claude-haiku-3-5-20241022": 200_000, + "anthropic/claude-3-5-sonnet-20241022": 200_000, + "anthropic/claude-3-opus-20240229": 200_000, + "huggingface/novita/minimax/minimax-m2.1": 196_608, + "huggingface/novita/moonshotai/kimi-k2.5": 262_144, + "huggingface/novita/zai-org/glm-5": 200_000, +} +_DEFAULT_MAX_TOKENS = 200_000 + + +def _get_max_tokens_safe(model_name: str) -> int: + """Return the max context window for a model without network calls.""" + tokens = _MAX_TOKENS_MAP.get(model_name) + if tokens: + return tokens + # Fallback: try litellm but with a short timeout via threading + try: + from litellm import get_max_tokens + + result = get_max_tokens(model_name) + if result and isinstance(result, int): + return result + logger.warning( + f"get_max_tokens returned {result} for {model_name}, using default" + ) + return _DEFAULT_MAX_TOKENS + except Exception as e: + logger.warning(f"get_max_tokens failed for {model_name}, using default: {e}") + return _DEFAULT_MAX_TOKENS + class OpType(Enum): USER_INPUT = "user_input" @@ -46,7 +84,7 @@ class Session: self.tool_router = tool_router tool_specs = tool_router.get_tool_specs_for_llm() if tool_router else [] self.context_manager = context_manager or ContextManager( - max_context=get_max_tokens(config.model_name), + max_context=_get_max_tokens_safe(config.model_name), compact_size=0.1, untouched_messages=5, tool_specs=tool_specs, @@ -59,6 +97,8 @@ class Session: self.is_running = True self.current_task: asyncio.Task | None = None self.pending_approval: Optional[dict[str, Any]] = None + # User's HF OAuth token — set by session_manager after construction + self.hf_token: Optional[str] = None self.sandbox = None # Session trajectory logging @@ -100,7 +140,7 @@ class Session: turns_since_last_save = self.turn_count - self.last_auto_save_turn if turns_since_last_save >= interval: - print(f"\n💾 Auto-saving session (turn {self.turn_count})...") + logger.info(f"Auto-saving session (turn {self.turn_count})...") # Fire-and-forget save - returns immediately self.save_and_upload_detached(self.config.session_dataset_repo) self.last_auto_save_turn = self.turn_count @@ -152,7 +192,7 @@ class Session: return str(filepath) except Exception as e: - print(f"Failed to save session locally: {e}") + logger.error(f"Failed to save session locally: {e}") return None def update_local_save_status( @@ -172,7 +212,7 @@ class Session: return True except Exception as e: - print(f"Failed to update local save status: {e}") + logger.error(f"Failed to update local save status: {e}") return False def save_and_upload_detached(self, repo_id: str) -> Optional[str]: @@ -203,7 +243,7 @@ class Session: start_new_session=True, # Detach from parent ) except Exception as e: - print(f"⚠️ Failed to spawn upload subprocess: {e}") + logger.warning(f"Failed to spawn upload subprocess: {e}") return local_path @@ -233,4 +273,4 @@ class Session: start_new_session=True, # Detach from parent ) except Exception as e: - print(f"⚠️ Failed to spawn retry subprocess: {e}") + logger.warning(f"Failed to spawn retry subprocess: {e}") diff --git a/agent/core/session_uploader.py b/agent/core/session_uploader.py index 2cbef9e306530e40b4cad310233a9952f737120e..ef2f9496d87f832489010f9a9529c538d939bedb 100644 --- a/agent/core/session_uploader.py +++ b/agent/core/session_uploader.py @@ -15,10 +15,8 @@ from dotenv import load_dotenv load_dotenv() -# Fallback token for session uploads (write-only access to akseljoonas/hf-agent-sessions) -_SESSION_TOKEN = "".join([ - "hf_", "Nzya", "Eeb", "ESz", "DtA", "BoW", "Czj", "SEC", "ZZv", "kVL", "Ac", "Vf", "Sz" -]) +# Token for session uploads — loaded from env var (never hardcode tokens in source) +_SESSION_TOKEN = os.environ.get("HF_SESSION_UPLOAD_TOKEN", "") def upload_session_as_file( diff --git a/agent/core/tools.py b/agent/core/tools.py index 586afab5eb074b4e1c391f140edc30d90c52c0a2..21da0ffa0a9643344946ff567240dc3a0c595529 100644 --- a/agent/core/tools.py +++ b/agent/core/tools.py @@ -3,10 +3,13 @@ Tool system for the agent Provides ToolSpec and ToolRouter for managing both built-in and MCP tools """ +import logging import warnings from dataclasses import dataclass from typing import Any, Awaitable, Callable, Optional +logger = logging.getLogger(__name__) + from fastmcp import Client from fastmcp.exceptions import ToolError from lmnr import observe @@ -132,6 +135,7 @@ class ToolRouter: for tool in create_builtin_tools(): self.register_tool(tool) + self.mcp_client: Client | None = None if mcp_servers: mcp_servers_payload = {} for name, server in mcp_servers.items(): @@ -159,7 +163,7 @@ class ToolRouter: handler=None, ) ) - print( + logger.info( f"Loaded {len(registered_names)} MCP tools: {', '.join(registered_names)} ({skipped_count} disabled)" ) @@ -180,7 +184,7 @@ class ToolRouter: handler=search_openapi_handler, ) ) - print(f"Loaded OpenAPI search tool: {openapi_spec['name']}") + logger.info(f"Loaded OpenAPI search tool: {openapi_spec['name']}") def get_tool_specs_for_llm(self) -> list[dict[str, Any]]: """Get tool specifications in OpenAI format""" @@ -209,7 +213,7 @@ class ToolRouter: await self.register_openapi_tool() total_tools = len(self.tools) - print(f"\nAgent ready with {total_tools} tools total\n") + logger.info(f"Agent ready with {total_tools} tools total") return self @@ -220,7 +224,7 @@ class ToolRouter: @observe(name="call_tool") async def call_tool( - self, tool_name: str, arguments: dict[str, Any], session: Any = None + self, tool_name: str, arguments: dict[str, Any], session: Any = None, tool_call_id: str | None = None ) -> tuple[str, bool]: """ Call a tool and return (output_string, success_bool). @@ -236,6 +240,9 @@ class ToolRouter: # Check if handler accepts session argument sig = inspect.signature(tool.handler) if "session" in sig.parameters: + # Check if handler also accepts tool_call_id parameter + if "tool_call_id" in sig.parameters: + return await tool.handler(arguments, session=session, tool_call_id=tool_call_id) return await tool.handler(arguments, session=session) return await tool.handler(arguments) @@ -328,10 +335,10 @@ def create_builtin_tools() -> list[ToolSpec]: ), ] - # Sandbox tools + # Sandbox tools (highest priority) tools = get_sandbox_tools() + tools tool_names = ", ".join([t.name for t in tools]) - print(f"Loaded {len(tools)} built-in tools: {tool_names}") + logger.info(f"Loaded {len(tools)} built-in tools: {tool_names}") return tools diff --git a/agent/prompts/system_prompt.yaml b/agent/prompts/system_prompt.yaml index fc9607dd2e1d38cd5e0bde7ccd938cdf6a131645..00f28be1718457e1d1004c7e3745f903f14d5ab1 100644 --- a/agent/prompts/system_prompt.yaml +++ b/agent/prompts/system_prompt.yaml @@ -1,5 +1,5 @@ system_prompt: | - You are Hugging Face Agent, a skilled AI assistant for machine learning engineering. Hugging Face is a company that provides two main services : libraries to write deep learning tasks, and ressources (models, datasets, compute) to execute them. You will aid users to do theses tasks, interacting with the Hugging Face stack via {{ num_tools }}. + You are Hugging Face Agent, a skilled AI assistant for machine learning engineering. Hugging Face is a company that provides two main services : libraries to write deep learning tasks, and resources (models, datasets, compute) to execute them. You will aid users to do these tasks, interacting with the Hugging Face stack via {{ num_tools }}. # General behavior @@ -9,7 +9,7 @@ system_prompt: | **CRITICAL : Research first, Then Implement** - For ANY implementation task (training, fine-tuning, inference, data processing, etc.), you should proceed in thoses three mandatory steps: + For ANY implementation task (training, fine-tuning, inference, data processing, etc.), you should proceed in these three mandatory steps: 1. **FIRST**: Search HF documentation to find the correct approach. - Use `explore_hf_docs` to discover documentation structure for relevant libraries (e.g., "trl", "transformers", "diffusers"). diff --git a/agent/prompts/system_prompt_v2.yaml b/agent/prompts/system_prompt_v2.yaml index 9f80bbefb98424d6ac281c628442322aea3c9fc7..d404b2788fe887a1a6f0f326961b284efbc9ca09 100644 --- a/agent/prompts/system_prompt_v2.yaml +++ b/agent/prompts/system_prompt_v2.yaml @@ -186,59 +186,61 @@ system_prompt: | 3. ✅ Determine optimal processing approach based on requirements 4. ✅ Plan output format and destination - ## PHASE 3: IMPLEMENT (Develop in Sandbox, Launch via Jobs) - - ⚠️ **CRITICAL WORKFLOW: Sandbox First, Jobs Second** - - For ANY implementation task (training, data processing, inference), follow this pattern: - - **Step 1: Create a sandbox** — `sandbox_create` with appropriate hardware (cpu-basic for scripting, t4-small for GPU testing) - **Step 2: Develop & iterate** — Write scripts, install dependencies, test with small runs, fix errors interactively - **Step 3: Launch via hf_jobs** — Once the script works, pass the sandbox file path directly: `hf_jobs(operation="run", script="/app/train.py", ...)` - - This is the CORRECT pattern: - ``` - sandbox_create(hardware="t4-small") # interactive dev environment - bash("pip install trl transformers") # install deps - write("/app/train.py", "...") # write training script - bash("cd /app && python train.py --max_steps 10") # test run - edit("/app/train.py", ...) # fix issues - bash("cd /app && python train.py --max_steps 10") # verify fix - hf_jobs(operation="run", script="/app/train.py", hardware_flavor="a10g-large", timeout="4h") # launch at scale - ``` - - Do NOT write long inline scripts directly in hf_jobs if necessary — develop in sandbox first. - - ### Training Script Requirements - - **Script MUST Include:** - - Imports from researched documentation (current APIs) - - Trackio initialization with project/run_name/config - - Model and tokenizer loading - - Dataset loading with verified columns and conversational format - - Training config with ALL critical settings: + ## PHASE 3: IMPLEMENT (Execute with Researched Approaches) + + ### For Training Tasks + + ⚠️ **TRAINING REQUIREMENTS CHECKLIST:** + + **Before Submission:** + - [ ] Researched current TRL documentation + - [ ] Found and verified base model + - [ ] Found dataset and VALIDATED columns and conversational format matches method + - [ ] Selected optimal model + dataset + hardware configuration + - [ ] Created plan with plan_tool + - [ ] Researched Trackio monitoring setup + + **Training Script MUST Include:** + - [ ] Imports from researched documentation (current APIs) + - [ ] Trackio initialization with project/run_name/config + - [ ] Model and tokenizer loading + - [ ] Dataset loading with verified columns and conversational format + - [ ] Training config with ALL critical settings: - `push_to_hub=True` ⚠️ MANDATORY - `hub_model_id="username/model-name"` ⚠️ MANDATORY - `report_to=["trackio"]` (for monitoring) - `output_dir="./output"` - `num_train_epochs`, `per_device_train_batch_size`, `learning_rate` - `logging_steps`, `save_steps` - - `trainer.train()` call - - `trainer.push_to_hub()` at end ⚠️ MANDATORY - - **hf_jobs Launch Configuration:** - - `script`: Path to sandbox file (e.g. "/app/train.py") or inline code - - `dependencies`: ['transformers', 'trl', 'torch', 'datasets', 'trackio'] - - `hardware_flavor`: Based on model size: - - 1-3B models: `t4-small` or `a10g-small` - - 7-13B models: `a10g-large` - - 30B+ models: `a100-large` - - 70B+ models: `h100` or `h100x8` - - `timeout`: ⚠️ CRITICAL — Small (2-4h), Medium (4-8h), Large (8-24h). NEVER default 30m for training. + - `max_length` if needed (default 1024 usually fine) + - [ ] Trainer initialization with model, args, dataset, tokenizer + - [ ] `trainer.train()` call + - [ ] `trainer.push_to_hub()` at end ⚠️ MANDATORY + - [ ] `tracker.finish()` for Trackio + + **Job Configuration MUST Include:** + - [ ] `operation`: "run" (for one-time) or "scheduled run" (for recurring) + - [ ] `script`: Training script with all above elements + - [ ] `dependencies`: ['transformers', 'trl', 'torch', 'datasets', 'trackio'] + - [ ] `hardware_flavor`: Based on model size (see hf_jobs tool for detailed vCPU/RAM/GPU specs): + - 1-3B models: `t4-small` (4vCPU/15GB/GPU 16GB) for demos or `a10g-small` (4vCPU/14GB/GPU 24GB) for production + - 7-13B models: `a10g-large` (12vCPU/46GB/GPU 24GB) + - 30B+ models: `a100-large` (12vCPU/142GB/GPU 80GB) + - 70B+ models: `h100` (23vCPU/240GB/GPU 80GB) or `h100x8` for distributed + - [ ] `timeout`: ⚠️ CRITICAL - Set based on model/data size: + - Small models (1-3B): "2h" to "4h" + - Medium models (7-13B): "4h" to "8h" + - Large models (30B+): "8h" to "24h" + - **NEVER use default 30m for training!** ### For Data Processing Tasks - **Same pattern:** develop script in sandbox, test on subset, launch via hf_jobs. + **Script Requirements:** + - Load dataset with `load_dataset` + - Process according to user requirements + - Push results with `push_to_hub()` or upload to `hf_private_repos` + + **Job Configuration:** - Use `cpu-upgrade` or `cpu-performance` for most data tasks - Set timeout based on dataset size (1-4 hours typical) @@ -339,21 +341,6 @@ system_prompt: | - ⚠️ Include HF_TOKEN for Hub operations - ⚠️ Storage is EPHEMERAL - must push_to_hub - ## Sandbox (Interactive Development Environment) - - **sandbox_create:** - - ⚠️ **Create a sandbox FIRST for any implementation task** — develop and test before launching jobs - - Persistent remote Linux environment on HF Spaces - - First call sandbox_create with hardware choice, then use bash/read/write/edit freely - - Hardware: cpu-basic (free tier), cpu-upgrade (8vCPU/32GB), t4-small (16GB GPU), a10g-small (24GB GPU), a10g-large (24GB GPU + 46GB RAM), a100-large (80GB GPU) - - `pip install` works out of the box — no special flags needed - - Workflow: sandbox_create → write script → test → fix → hf_jobs(script="/app/script.py") to launch at scale - - **bash / read / write / edit:** - - Available after sandbox_create — no additional approvals needed - - Same semantics as local file/shell operations, but run on the remote sandbox - - bash: run shell commands; read/write/edit: file operations - **hf_private_repos:** - Store job outputs persistently in datasets with push_to_hub (jobs lose files after completion) - Upload logs, scripts, results that can't push_to_hub diff --git a/agent/tools/jobs_tool.py b/agent/tools/jobs_tool.py index c9819b5c75f65e9c9154b40f66637c8935a8fe03..156e3f88110103af2b0292867b5ffeaf34d01594 100644 --- a/agent/tools/jobs_tool.py +++ b/agent/tools/jobs_tool.py @@ -9,7 +9,9 @@ import base64 import http.client import os import re -from typing import Any, Awaitable, Callable, Dict, Literal, Optional +from typing import Any, Dict, Literal, Optional, Callable, Awaitable + +import logging import httpx from huggingface_hub import HfApi @@ -17,6 +19,8 @@ from huggingface_hub.utils import HfHubHTTPError from agent.core.session import Event from agent.tools.types import ToolResult + +logger = logging.getLogger(__name__) from agent.tools.utilities import ( format_job_details, format_jobs_table, @@ -128,8 +132,11 @@ def _add_default_env(params: Dict[str, Any] | None) -> Dict[str, Any]: return result -def _add_environment_variables(params: Dict[str, Any] | None) -> Dict[str, Any]: - token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or "" +def _add_environment_variables( + params: Dict[str, Any] | None, user_token: str | None = None +) -> Dict[str, Any]: + # Prefer the authenticated user's OAuth token, fall back to global env var + token = user_token or os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") or "" # Start with user-provided env vars, then force-set token last result = dict(params or {}) @@ -285,10 +292,15 @@ class HfJobsTool: hf_token: Optional[str] = None, namespace: Optional[str] = None, log_callback: Optional[Callable[[str], Awaitable[None]]] = None, + session: Any = None, + tool_call_id: Optional[str] = None, ): + self.hf_token = hf_token self.api = HfApi(token=hf_token) self.namespace = namespace self.log_callback = log_callback + self.session = session + self.tool_call_id = tool_call_id async def execute(self, params: Dict[str, Any]) -> ToolResult: """Execute the specified operation""" @@ -384,9 +396,7 @@ class HfJobsTool: def log_producer(): try: # fetch_job_logs is a blocking sync generator - logs_gen = self.api.fetch_job_logs( - job_id=job_id, namespace=namespace - ) + logs_gen = self.api.fetch_job_logs(job_id=job_id, namespace=namespace) for line in logs_gen: # Push line to queue thread-safely loop.call_soon_threadsafe(queue.put_nowait, line) @@ -413,7 +423,7 @@ class HfJobsTool: # Process log line log_line = item - print("\t" + log_line) + logger.debug(log_line) if self.log_callback: await self.log_callback(log_line) all_logs.append(log_line) @@ -441,19 +451,19 @@ class HfJobsTool: if current_status in terminal_states: # Job finished, no need to retry - print(f"\tJob reached terminal state: {current_status}") + logger.info(f"Job reached terminal state: {current_status}") break # Job still running, retry connection - print( - f"\tConnection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..." + logger.warning( + f"Connection interrupted ({str(e)[:50]}...), reconnecting in {retry_delay}s..." ) await asyncio.sleep(retry_delay) continue except (ConnectionError, TimeoutError, OSError): # Can't even check job status, wait and retry - print(f"\tConnection error, retrying in {retry_delay}s...") + logger.warning(f"Connection error, retrying in {retry_delay}s...") await asyncio.sleep(retry_delay) continue @@ -510,15 +520,29 @@ class HfJobsTool: image=image, command=command, env=_add_default_env(args.get("env")), - secrets=_add_environment_variables(args.get("secrets")), + secrets=_add_environment_variables(args.get("secrets"), self.hf_token), flavor=args.get("hardware_flavor", "cpu-basic"), timeout=args.get("timeout", "30m"), namespace=self.namespace, ) + # Send job URL immediately after job creation (before waiting for completion) + if self.session and self.tool_call_id: + await self.session.send_event( + Event( + event_type="tool_state_change", + data={ + "tool_call_id": self.tool_call_id, + "tool": "hf_jobs", + "state": "running", + "jobUrl": job.url, + }, + ) + ) + # Wait for completion and stream logs - print(f"{job_type} job started: {job.url}") - print("Streaming logs...\n---\n") + logger.info(f"{job_type} job started: {job.url}") + logger.info("Streaming logs...") final_status, all_logs = await self._wait_for_job_completion( job_id=job.id, @@ -728,7 +752,7 @@ To verify, call this tool with `{{"operation": "inspect", "job_id": "{job_id}"}} command=command, schedule=schedule, env=_add_default_env(args.get("env")), - secrets=_add_environment_variables(args.get("secrets")), + secrets=_add_environment_variables(args.get("secrets"), self.hf_token), flavor=args.get("hardware_flavor", "cpu-basic"), timeout=args.get("timeout", "30m"), namespace=self.namespace, @@ -998,7 +1022,7 @@ HF_JOBS_TOOL_SPEC = { async def hf_jobs_handler( - arguments: Dict[str, Any], session: Any = None + arguments: Dict[str, Any], session: Any = None, tool_call_id: str | None = None ) -> tuple[str, bool]: """Handler for agent tool router""" try: @@ -1031,14 +1055,20 @@ async def hf_jobs_handler( return f"Failed to read {script} from sandbox: {result.error}", False arguments = {**arguments, "script": result.output} - # Get token and namespace from HF token - hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN") - namespace = HfApi(token=hf_token).whoami().get("name") if hf_token else None + # Prefer the authenticated user's OAuth token, fall back to global env + hf_token = ( + (getattr(session, "hf_token", None) if session else None) + or os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACE_HUB_TOKEN") + ) + namespace = os.environ.get("HF_NAMESPACE") or (HfApi(token=hf_token).whoami().get("name") if hf_token else None) tool = HfJobsTool( namespace=namespace, hf_token=hf_token, log_callback=log_callback if session else None, + session=session, + tool_call_id=tool_call_id, ) result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) diff --git a/agent/tools/sandbox_tool.py b/agent/tools/sandbox_tool.py index 61dea149a2605ecc4789d82f96ddeab9264058c9..936803b7866535bdd5faa04d2a6f3f32e10d65d2 100644 --- a/agent/tools/sandbox_tool.py +++ b/agent/tools/sandbox_tool.py @@ -38,9 +38,13 @@ async def _ensure_sandbox( if not session: return None, "No session available." - token = os.environ.get("HF_TOKEN") + token = ( + getattr(session, "hf_token", None) + or os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACE_HUB_TOKEN") + ) if not token: - return None, "HF_TOKEN environment variable not set. Cannot create sandbox." + return None, "No HF token available. Cannot create sandbox." api = HfApi(token=token) user_info = api.whoami() diff --git a/backend/dependencies.py b/backend/dependencies.py new file mode 100644 index 0000000000000000000000000000000000000000..03a1bb284507b6b78ad2a7534492934e416f6bed --- /dev/null +++ b/backend/dependencies.py @@ -0,0 +1,144 @@ +"""Authentication dependencies for FastAPI routes. + +Provides auth validation for both REST and WebSocket endpoints. +- In dev mode (OAUTH_CLIENT_ID not set): auth is bypassed, returns a default "dev" user. +- In production: validates Bearer tokens or cookies against HF OAuth. +""" + +import logging +import os +import time +from typing import Any + +import httpx +from fastapi import HTTPException, Request, WebSocket, status + +logger = logging.getLogger(__name__) + +OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") +AUTH_ENABLED = bool(os.environ.get("OAUTH_CLIENT_ID", "")) + +# Simple in-memory token cache: token -> (user_info, expiry_time) +_token_cache: dict[str, tuple[dict[str, Any], float]] = {} +TOKEN_CACHE_TTL = 300 # 5 minutes + +DEV_USER: dict[str, Any] = { + "user_id": "dev", + "username": "dev", + "authenticated": True, +} + + +async def _validate_token(token: str) -> dict[str, Any] | None: + """Validate a token against HF OAuth userinfo endpoint. + + Results are cached for TOKEN_CACHE_TTL seconds to avoid excessive API calls. + """ + now = time.time() + + # Check cache + if token in _token_cache: + user_info, expiry = _token_cache[token] + if now < expiry: + return user_info + del _token_cache[token] + + # Validate against HF + async with httpx.AsyncClient(timeout=10.0) as client: + try: + response = await client.get( + f"{OPENID_PROVIDER_URL}/oauth/userinfo", + headers={"Authorization": f"Bearer {token}"}, + ) + if response.status_code != 200: + logger.debug("Token validation failed: status %d", response.status_code) + return None + user_info = response.json() + _token_cache[token] = (user_info, now + TOKEN_CACHE_TTL) + return user_info + except httpx.HTTPError as e: + logger.warning("Token validation error: %s", e) + return None + + +def _user_from_info(user_info: dict[str, Any]) -> dict[str, Any]: + """Build a normalized user dict from HF userinfo response.""" + return { + "user_id": user_info.get("sub", user_info.get("preferred_username", "unknown")), + "username": user_info.get("preferred_username", "unknown"), + "name": user_info.get("name"), + "picture": user_info.get("picture"), + "authenticated": True, + } + + +async def _extract_user_from_token(token: str) -> dict[str, Any] | None: + """Validate a token and return a user dict, or None.""" + user_info = await _validate_token(token) + if user_info: + return _user_from_info(user_info) + return None + + +async def get_current_user(request: Request) -> dict[str, Any]: + """FastAPI dependency: extract and validate the current user. + + Checks (in order): + 1. Authorization: Bearer header + 2. hf_access_token cookie + + In dev mode (AUTH_ENABLED=False), returns a default dev user. + """ + if not AUTH_ENABLED: + return DEV_USER + + # Try Authorization header + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + token = auth_header[7:] + user = await _extract_user_from_token(token) + if user: + return user + + # Try cookie + token = request.cookies.get("hf_access_token") + if token: + user = await _extract_user_from_token(token) + if user: + return user + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated. Please log in via /auth/login.", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +async def get_ws_user(websocket: WebSocket) -> dict[str, Any] | None: + """Extract and validate user from WebSocket connection. + + WebSocket doesn't support custom headers from browser, so we check: + 1. ?token= query parameter + 2. hf_access_token cookie (sent automatically for same-origin) + + Returns user dict or None if not authenticated. + In dev mode, returns the default dev user. + """ + if not AUTH_ENABLED: + return DEV_USER + + # Try query param + token = websocket.query_params.get("token") + if token: + user = await _extract_user_from_token(token) + if user: + return user + + # Try cookie (works for same-origin WebSocket) + token = websocket.cookies.get("hf_access_token") + if token: + user = await _extract_user_from_token(token) + if user: + return user + + return None diff --git a/backend/main.py b/backend/main.py index 2ea33e05b92332f2bc4e32ba160b4029fdb69e31..fc75ab9e11696664776cc2370d68e589196af7ad 100644 --- a/backend/main.py +++ b/backend/main.py @@ -5,6 +5,14 @@ import os from contextlib import asynccontextmanager from pathlib import Path +from dotenv import load_dotenv + +load_dotenv() + +# Ensure HF_TOKEN is set — fall back to HF_ADMIN_TOKEN if available (HF Spaces) +if not os.environ.get("HF_TOKEN") and os.environ.get("HF_ADMIN_TOKEN"): + os.environ["HF_TOKEN"] = os.environ["HF_ADMIN_TOKEN"] + from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles diff --git a/backend/models.py b/backend/models.py index 0c529522f88804e8ccdc97224359a6ed62462ab2..f22ab3048b1d1a18e75517e1475521f59bce526d 100644 --- a/backend/models.py +++ b/backend/models.py @@ -37,6 +37,7 @@ class ToolApproval(BaseModel): tool_call_id: str approved: bool feedback: str | None = None + edited_script: str | None = None class ApprovalRequest(BaseModel): @@ -67,6 +68,7 @@ class SessionInfo(BaseModel): created_at: str is_active: bool message_count: int + user_id: str = "dev" class HealthResponse(BaseModel): @@ -74,3 +76,13 @@ class HealthResponse(BaseModel): status: str = "ok" active_sessions: int = 0 + max_sessions: int = 0 + + +class LLMHealthResponse(BaseModel): + """LLM provider health check response.""" + + status: str # "ok" | "error" + model: str + error: str | None = None + error_type: str | None = None # "auth" | "credits" | "rate_limit" | "network" | "unknown" diff --git a/backend/routes/agent.py b/backend/routes/agent.py index 8a3db071ef8707d8e79ac272facb203f44cd9857..fed198d672ff147e60ecbced0c60c474e964eced 100644 --- a/backend/routes/agent.py +++ b/backend/routes/agent.py @@ -1,58 +1,252 @@ -"""Agent API routes - WebSocket and REST endpoints.""" +"""Agent API routes - WebSocket and REST endpoints. -import logging +All routes (except /health) require authentication via the get_current_user +dependency. In dev mode (no OAUTH_CLIENT_ID), auth is bypassed automatically. +""" -from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect +import logging +import os +from typing import Any + +from dependencies import get_current_user, get_ws_user +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + WebSocket, + WebSocketDisconnect, +) +from litellm import acompletion +from agent.core.agent_loop import _resolve_hf_router_params from models import ( ApprovalRequest, HealthResponse, + LLMHealthResponse, SessionInfo, SessionResponse, SubmitRequest, ) -from session_manager import session_manager +from session_manager import MAX_SESSIONS, SessionCapacityError, session_manager from websocket import manager as ws_manager logger = logging.getLogger(__name__) router = APIRouter(prefix="/api", tags=["agent"]) +AVAILABLE_MODELS = [ + { + "id": "huggingface/novita/minimax/minimax-m2.1", + "label": "MiniMax M2.1", + "provider": "huggingface", + "recommended": True, + }, + { + "id": "anthropic/claude-opus-4-5-20251101", + "label": "Claude Opus 4.5", + "provider": "anthropic", + "recommended": True, + }, + { + "id": "huggingface/novita/moonshotai/kimi-k2.5", + "label": "Kimi K2.5", + "provider": "huggingface", + }, + { + "id": "huggingface/novita/zai-org/glm-5", + "label": "GLM 5", + "provider": "huggingface", + }, +] + + +def _check_session_access(session_id: str, user: dict[str, Any]) -> None: + """Verify the user has access to the given session. Raises 403 or 404.""" + info = session_manager.get_session_info(session_id) + if not info: + raise HTTPException(status_code=404, detail="Session not found") + if not session_manager.verify_session_access(session_id, user["user_id"]): + raise HTTPException(status_code=403, detail="Access denied to this session") + @router.get("/health", response_model=HealthResponse) async def health_check() -> HealthResponse: """Health check endpoint.""" return HealthResponse( - status="ok", active_sessions=session_manager.active_session_count + status="ok", + active_sessions=session_manager.active_session_count, + max_sessions=MAX_SESSIONS, ) +@router.get("/health/llm", response_model=LLMHealthResponse) +async def llm_health_check() -> LLMHealthResponse: + """Check if the LLM provider is reachable and the API key is valid. + + Makes a minimal 1-token completion call. Catches common errors: + - 401 → invalid API key + - 402/insufficient_quota → out of credits + - 429 → rate limited + - timeout / network → provider unreachable + """ + model = session_manager.config.model_name + try: + llm_params = _resolve_hf_router_params(model) + await acompletion( + messages=[{"role": "user", "content": "hi"}], + max_tokens=1, + timeout=10, + **llm_params, + ) + return LLMHealthResponse(status="ok", model=model) + except Exception as e: + err_str = str(e).lower() + error_type = "unknown" + + if ( + "401" in err_str + or "auth" in err_str + or "invalid" in err_str + or "api key" in err_str + ): + error_type = "auth" + elif ( + "402" in err_str + or "credit" in err_str + or "quota" in err_str + or "insufficient" in err_str + or "billing" in err_str + ): + error_type = "credits" + elif "429" in err_str or "rate" in err_str: + error_type = "rate_limit" + elif "timeout" in err_str or "connect" in err_str or "network" in err_str: + error_type = "network" + + logger.warning(f"LLM health check failed ({error_type}): {e}") + return LLMHealthResponse( + status="error", + model=model, + error=str(e)[:500], + error_type=error_type, + ) + + +@router.get("/config/model") +async def get_model() -> dict: + """Get current model and available models. No auth required.""" + return { + "current": session_manager.config.model_name, + "available": AVAILABLE_MODELS, + } + + +@router.post("/config/model") +async def set_model(body: dict, user: dict = Depends(get_current_user)) -> dict: + """Set the LLM model. Applies to new conversations.""" + model_id = body.get("model") + if not model_id: + raise HTTPException(status_code=400, detail="Missing 'model' field") + valid_ids = {m["id"] for m in AVAILABLE_MODELS} + if model_id not in valid_ids: + raise HTTPException(status_code=400, detail=f"Unknown model: {model_id}") + session_manager.config.model_name = model_id + logger.info(f"Model changed to {model_id} by {user.get('username', 'unknown')}") + return {"model": model_id} + + +@router.post("/title") +async def generate_title( + request: SubmitRequest, user: dict = Depends(get_current_user) +) -> dict: + """Generate a short title for a chat session based on the first user message.""" + model = session_manager.config.model_name + llm_params = _resolve_hf_router_params(model) + try: + response = await acompletion( + messages=[ + { + "role": "system", + "content": ( + "Generate a very short title (max 6 words) for a chat conversation " + "that starts with the following user message. " + "Reply with ONLY the title, no quotes, no punctuation at the end." + ), + }, + {"role": "user", "content": request.text[:500]}, + ], + max_tokens=20, + temperature=0.3, + timeout=8, + **llm_params, + ) + title = response.choices[0].message.content.strip().strip('"').strip("'") + # Safety: cap at 50 chars + if len(title) > 50: + title = title[:50].rstrip() + "…" + return {"title": title} + except Exception as e: + logger.warning(f"Title generation failed: {e}") + # Fallback: truncate the message + fallback = request.text.strip() + title = fallback[:40].rstrip() + "…" if len(fallback) > 40 else fallback + return {"title": title} + + @router.post("/session", response_model=SessionResponse) -async def create_session() -> SessionResponse: - """Create a new agent session.""" - session_id = await session_manager.create_session() +async def create_session( + request: Request, user: dict = Depends(get_current_user) +) -> SessionResponse: + """Create a new agent session bound to the authenticated user. + + The user's HF access token is extracted from the Authorization header + and stored in the session so that tools (e.g. hf_jobs) can act on + behalf of the user. + + Returns 503 if the server or user has reached the session limit. + """ + # Extract the user's HF token (Bearer header or HttpOnly cookie) + hf_token = None + auth_header = request.headers.get("Authorization", "") + if auth_header.startswith("Bearer "): + hf_token = auth_header[7:] + if not hf_token: + hf_token = request.cookies.get("hf_access_token") + + try: + session_id = await session_manager.create_session( + user_id=user["user_id"], hf_token=hf_token + ) + except SessionCapacityError as e: + raise HTTPException(status_code=503, detail=str(e)) + return SessionResponse(session_id=session_id, ready=True) @router.get("/session/{session_id}", response_model=SessionInfo) -async def get_session(session_id: str) -> SessionInfo: - """Get session information.""" +async def get_session( + session_id: str, user: dict = Depends(get_current_user) +) -> SessionInfo: + """Get session information. Only accessible by the session owner.""" + _check_session_access(session_id, user) info = session_manager.get_session_info(session_id) - if not info: - raise HTTPException(status_code=404, detail="Session not found") return SessionInfo(**info) @router.get("/sessions", response_model=list[SessionInfo]) -async def list_sessions() -> list[SessionInfo]: - """List all sessions.""" - sessions = session_manager.list_sessions() +async def list_sessions(user: dict = Depends(get_current_user)) -> list[SessionInfo]: + """List sessions belonging to the authenticated user.""" + sessions = session_manager.list_sessions(user_id=user["user_id"]) return [SessionInfo(**s) for s in sessions] @router.delete("/session/{session_id}") -async def delete_session(session_id: str) -> dict: - """Delete a session.""" +async def delete_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: + """Delete a session. Only accessible by the session owner.""" + _check_session_access(session_id, user) success = await session_manager.delete_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found") @@ -60,8 +254,11 @@ async def delete_session(session_id: str) -> dict: @router.post("/submit") -async def submit_input(request: SubmitRequest) -> dict: - """Submit user input to a session.""" +async def submit_input( + request: SubmitRequest, user: dict = Depends(get_current_user) +) -> dict: + """Submit user input to a session. Only accessible by the session owner.""" + _check_session_access(request.session_id, user) success = await session_manager.submit_user_input(request.session_id, request.text) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -69,13 +266,17 @@ async def submit_input(request: SubmitRequest) -> dict: @router.post("/approve") -async def submit_approval(request: ApprovalRequest) -> dict: - """Submit tool approvals to a session.""" +async def submit_approval( + request: ApprovalRequest, user: dict = Depends(get_current_user) +) -> dict: + """Submit tool approvals to a session. Only accessible by the session owner.""" + _check_session_access(request.session_id, user) approvals = [ { "tool_call_id": a.tool_call_id, "approved": a.approved, "feedback": a.feedback, + "edited_script": a.edited_script, } for a in request.approvals ] @@ -86,8 +287,11 @@ async def submit_approval(request: ApprovalRequest) -> dict: @router.post("/interrupt/{session_id}") -async def interrupt_session(session_id: str) -> dict: +async def interrupt_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: """Interrupt the current operation in a session.""" + _check_session_access(session_id, user) success = await session_manager.interrupt(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -95,8 +299,9 @@ async def interrupt_session(session_id: str) -> dict: @router.post("/undo/{session_id}") -async def undo_session(session_id: str) -> dict: +async def undo_session(session_id: str, user: dict = Depends(get_current_user)) -> dict: """Undo the last turn in a session.""" + _check_session_access(session_id, user) success = await session_manager.undo(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -104,8 +309,11 @@ async def undo_session(session_id: str) -> dict: @router.post("/compact/{session_id}") -async def compact_session(session_id: str) -> dict: +async def compact_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: """Compact the context in a session.""" + _check_session_access(session_id, user) success = await session_manager.compact(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -113,8 +321,11 @@ async def compact_session(session_id: str) -> dict: @router.post("/shutdown/{session_id}") -async def shutdown_session(session_id: str) -> dict: +async def shutdown_session( + session_id: str, user: dict = Depends(get_current_user) +) -> dict: """Shutdown a session.""" + _check_session_access(session_id, user) success = await session_manager.shutdown_session(session_id) if not success: raise HTTPException(status_code=404, detail="Session not found or inactive") @@ -123,17 +334,61 @@ async def shutdown_session(session_id: str) -> dict: @router.websocket("/ws/{session_id}") async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None: - """WebSocket endpoint for real-time events.""" + """WebSocket endpoint for real-time events. + + Authentication is done via: + - ?token= query parameter (for browsers that can't send WS headers) + - Cookie (automatic for same-origin connections) + - Dev mode bypass (when OAUTH_CLIENT_ID is not set) + + NOTE: We must accept() before close() so the browser receives our custom + close codes (4001, 4003, 4004). If we close() before accept(), Starlette + sends HTTP 403 and the browser only sees code 1006 (abnormal closure). + """ logger.info(f"WebSocket connection request for session {session_id}") + + # Authenticate the WebSocket connection + user = await get_ws_user(websocket) + if not user: + logger.warning( + f"WebSocket rejected: authentication failed for session {session_id}" + ) + await websocket.accept() + await websocket.close(code=4001, reason="Authentication required") + return + # Verify session exists info = session_manager.get_session_info(session_id) if not info: - logger.warning(f"WebSocket connection rejected: Session {session_id} not found") + logger.warning(f"WebSocket rejected: session {session_id} not found") + await websocket.accept() await websocket.close(code=4004, reason="Session not found") return + # Verify user owns the session + if not session_manager.verify_session_access(session_id, user["user_id"]): + logger.warning( + f"WebSocket rejected: user {user['user_id']} denied access to session {session_id}" + ) + await websocket.accept() + await websocket.close(code=4003, reason="Access denied") + return + await ws_manager.connect(websocket, session_id) + # Send "ready" immediately on WebSocket connection so the frontend + # knows the session is alive. The original ready event from _run_session + # fires before the WS is connected and is always lost. + try: + await websocket.send_json( + { + "event_type": "ready", + "data": {"message": "Agent initialized"}, + } + ) + except Exception as e: + logger.error(f"Failed to send ready event for session {session_id}: {e}") + try: while True: # Keep connection alive, handle ping/pong diff --git a/backend/routes/auth.py b/backend/routes/auth.py index a39bacd0b4696d2c3989ff161be3c54ffec2b14f..224febf4b926890eb58943e3103a985fe0ed4626 100644 --- a/backend/routes/auth.py +++ b/backend/routes/auth.py @@ -1,11 +1,17 @@ -"""Authentication routes for HF OAuth.""" +"""Authentication routes for HF OAuth. + +Handles the OAuth 2.0 authorization code flow with HF as provider. +After successful auth, sets an HttpOnly cookie with the access token. +""" import os import secrets +import time from urllib.parse import urlencode import httpx -from fastapi import APIRouter, HTTPException, Request +from dependencies import AUTH_ENABLED, get_current_user +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import RedirectResponse router = APIRouter(prefix="/auth", tags=["auth"]) @@ -15,10 +21,19 @@ OAUTH_CLIENT_ID = os.environ.get("OAUTH_CLIENT_ID", "") OAUTH_CLIENT_SECRET = os.environ.get("OAUTH_CLIENT_SECRET", "") OPENID_PROVIDER_URL = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co") -# In-memory session store (replace with proper session management in production) +# In-memory OAuth state store with expiry (5 min TTL) +_OAUTH_STATE_TTL = 300 oauth_states: dict[str, dict] = {} +def _cleanup_expired_states() -> None: + """Remove expired OAuth states to prevent memory growth.""" + now = time.time() + expired = [k for k, v in oauth_states.items() if now > v.get("expires_at", 0)] + for k in expired: + del oauth_states[k] + + def get_redirect_uri(request: Request) -> str: """Get the OAuth callback redirect URI.""" # In HF Spaces, use the SPACE_HOST if available @@ -38,17 +53,26 @@ async def oauth_login(request: Request) -> RedirectResponse: detail="OAuth not configured. Set OAUTH_CLIENT_ID environment variable.", ) + # Clean up expired states to prevent memory growth + _cleanup_expired_states() + # Generate state for CSRF protection state = secrets.token_urlsafe(32) - oauth_states[state] = {"redirect_uri": get_redirect_uri(request)} + oauth_states[state] = { + "redirect_uri": get_redirect_uri(request), + "expires_at": time.time() + _OAUTH_STATE_TTL, + } # Build authorization URL params = { "client_id": OAUTH_CLIENT_ID, "redirect_uri": get_redirect_uri(request), - "scope": "openid profile", + "scope": "openid profile read-repos write-repos contribute-repos manage-repos inference-api jobs write-discussions", "response_type": "code", "state": state, + "orgIds": os.environ.get( + "HF_OAUTH_ORG_ID", "698dbf55845d85df163175f1" + ), # ml-agent-explorers } auth_url = f"{OPENID_PROVIDER_URL}/oauth/authorize?{urlencode(params)}" @@ -91,58 +115,57 @@ async def oauth_callback( # Get user info access_token = token_data.get("access_token") - if access_token: - async with httpx.AsyncClient() as client: - try: - userinfo_response = await client.get( - f"{OPENID_PROVIDER_URL}/oauth/userinfo", - headers={"Authorization": f"Bearer {access_token}"}, - ) - userinfo_response.raise_for_status() - user_info = userinfo_response.json() - except httpx.HTTPError: - user_info = {} - else: - user_info = {} - - # For now, redirect to home with token in query params - # In production, use secure cookies or session storage - redirect_params = { - "access_token": access_token, - "username": user_info.get("preferred_username", ""), - } + if not access_token: + raise HTTPException( + status_code=500, + detail="Token exchange succeeded but no access_token was returned.", + ) - return RedirectResponse(url=f"/?{urlencode(redirect_params)}") + # Fetch user info (optional — failure is not fatal) + async with httpx.AsyncClient() as client: + try: + userinfo_response = await client.get( + f"{OPENID_PROVIDER_URL}/oauth/userinfo", + headers={"Authorization": f"Bearer {access_token}"}, + ) + userinfo_response.raise_for_status() + except httpx.HTTPError: + pass # user_info not required for auth flow + + # Set access token as HttpOnly cookie (not in URL — avoids leaks via + # Referrer headers, browser history, and server logs) + is_production = bool(os.environ.get("SPACE_HOST")) + response = RedirectResponse(url="/", status_code=302) + response.set_cookie( + key="hf_access_token", + value=access_token, + httponly=True, + secure=is_production, # Secure flag only in production (HTTPS) + samesite="lax", + max_age=3600 * 24, # 24 hours + path="/", + ) + return response @router.get("/logout") async def logout() -> RedirectResponse: - """Log out the user.""" - return RedirectResponse(url="/") + """Log out the user by clearing the auth cookie.""" + response = RedirectResponse(url="/") + response.delete_cookie(key="hf_access_token", path="/") + return response -@router.get("/me") -async def get_current_user(request: Request) -> dict: - """Get current user info from Authorization header.""" - auth_header = request.headers.get("Authorization", "") - if not auth_header.startswith("Bearer "): - return {"authenticated": False} +@router.get("/status") +async def auth_status() -> dict: + """Check if OAuth is enabled on this instance.""" + return {"auth_enabled": AUTH_ENABLED} - token = auth_header.split(" ")[1] - async with httpx.AsyncClient() as client: - try: - response = await client.get( - f"{OPENID_PROVIDER_URL}/oauth/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - user_info = response.json() - return { - "authenticated": True, - "username": user_info.get("preferred_username"), - "name": user_info.get("name"), - "picture": user_info.get("picture"), - } - except httpx.HTTPError: - return {"authenticated": False} +@router.get("/me") +async def get_me(user: dict = Depends(get_current_user)) -> dict: + """Get current user info. Returns the authenticated user or dev user. + + Uses the shared auth dependency which handles cookie + Bearer token. + """ + return user diff --git a/backend/session_manager.py b/backend/session_manager.py index 058b376123ff9d00013ce4fa7e49e7bbee5585d0..03d9b2d9b8d706f1fa391f69e43e759d77246b86 100644 --- a/backend/session_manager.py +++ b/backend/session_manager.py @@ -48,11 +48,28 @@ class AgentSession: session: Session tool_router: ToolRouter submission_queue: asyncio.Queue + user_id: str = "dev" # Owner of this session + hf_token: str | None = None # User's HF OAuth token for tool execution task: asyncio.Task | None = None created_at: datetime = field(default_factory=datetime.utcnow) is_active: bool = True +class SessionCapacityError(Exception): + """Raised when no more sessions can be created.""" + + def __init__(self, message: str, error_type: str = "global") -> None: + super().__init__(message) + self.error_type = error_type # "global" or "per_user" + + +# ── Capacity limits ───────────────────────────────────────────────── +# Estimated for HF Spaces cpu-basic (2 vCPU, 16 GB RAM). +# Each session uses ~10-20 MB (context, tools, queues, task). +MAX_SESSIONS: int = 50 +MAX_SESSIONS_PER_USER: int = 10 + + class SessionManager: """Manages multiple concurrent agent sessions.""" @@ -61,19 +78,69 @@ class SessionManager: self.sessions: dict[str, AgentSession] = {} self._lock = asyncio.Lock() - async def create_session(self) -> str: - """Create a new agent session and return its ID.""" + def _count_user_sessions(self, user_id: str) -> int: + """Count active sessions owned by a specific user.""" + return sum( + 1 + for s in self.sessions.values() + if s.user_id == user_id and s.is_active + ) + + async def create_session(self, user_id: str = "dev", hf_token: str | None = None) -> str: + """Create a new agent session and return its ID. + + Session() and ToolRouter() constructors contain blocking I/O + (e.g. HfApi().whoami(), litellm.get_max_tokens()) so they are + executed in a thread pool to avoid freezing the async event loop. + + Args: + user_id: The ID of the user who owns this session. + + Raises: + SessionCapacityError: If the server or user has reached the + maximum number of concurrent sessions. + """ + # ── Capacity checks ────────────────────────────────────────── + async with self._lock: + active_count = self.active_session_count + if active_count >= MAX_SESSIONS: + raise SessionCapacityError( + f"Server is at capacity ({active_count}/{MAX_SESSIONS} sessions). " + "Please try again later.", + error_type="global", + ) + if user_id != "dev": + user_count = self._count_user_sessions(user_id) + if user_count >= MAX_SESSIONS_PER_USER: + raise SessionCapacityError( + f"You have reached the maximum of {MAX_SESSIONS_PER_USER} " + "concurrent sessions. Please close an existing session first.", + error_type="per_user", + ) + session_id = str(uuid.uuid4()) # Create queues for this session submission_queue: asyncio.Queue = asyncio.Queue() event_queue: asyncio.Queue = asyncio.Queue() - # Create tool router - tool_router = ToolRouter(self.config.mcpServers) + # Run blocking constructors in a thread to keep the event loop responsive. + # Without this, Session.__init__ → ContextManager → litellm.get_max_tokens() + # blocks all HTTP/WebSocket handling. + import time as _time + + def _create_session_sync(): + t0 = _time.monotonic() + tool_router = ToolRouter(self.config.mcpServers) + session = Session(event_queue, config=self.config, tool_router=tool_router) + t1 = _time.monotonic() + logger.info(f"Session initialized in {t1 - t0:.2f}s") + return tool_router, session - # Create the agent session - session = Session(event_queue, config=self.config, tool_router=tool_router) + tool_router, session = await asyncio.to_thread(_create_session_sync) + + # Store user's HF token on the session so tools can use it + session.hf_token = hf_token # Create wrapper agent_session = AgentSession( @@ -81,6 +148,8 @@ class SessionManager: session=session, tool_router=tool_router, submission_queue=submission_queue, + user_id=user_id, + hf_token=hf_token, ) async with self._lock: @@ -92,7 +161,7 @@ class SessionManager: ) agent_session.task = task - logger.info(f"Created session {session_id}") + logger.info(f"Created session {session_id} for user {user_id}") return session_id async def _run_session( @@ -245,6 +314,27 @@ class SessionManager: return True + def get_session_owner(self, session_id: str) -> str | None: + """Get the user_id that owns a session, or None if session doesn't exist.""" + agent_session = self.sessions.get(session_id) + if not agent_session: + return None + return agent_session.user_id + + def verify_session_access(self, session_id: str, user_id: str) -> bool: + """Check if a user has access to a session. + + Returns True if: + - The session exists AND the user owns it + - The user_id is "dev" (dev mode bypass) + """ + owner = self.get_session_owner(session_id) + if owner is None: + return False + if user_id == "dev" or owner == "dev": + return True + return owner == user_id + def get_session_info(self, session_id: str) -> dict[str, Any] | None: """Get information about a session.""" agent_session = self.sessions.get(session_id) @@ -256,15 +346,25 @@ class SessionManager: "created_at": agent_session.created_at.isoformat(), "is_active": agent_session.is_active, "message_count": len(agent_session.session.context_manager.items), + "user_id": agent_session.user_id, } - def list_sessions(self) -> list[dict[str, Any]]: - """List all sessions.""" - return [ - self.get_session_info(sid) - for sid in self.sessions - if self.get_session_info(sid) - ] + def list_sessions(self, user_id: str | None = None) -> list[dict[str, Any]]: + """List sessions, optionally filtered by user. + + Args: + user_id: If provided, only return sessions owned by this user. + If "dev", return all sessions (dev mode). + """ + results = [] + for sid in self.sessions: + info = self.get_session_info(sid) + if not info: + continue + if user_id and user_id != "dev" and info.get("user_id") != user_id: + continue + results.append(info) + return results @property def active_session_count(self) -> int: diff --git a/backend/websocket.py b/backend/websocket.py index 924d7a831273890db939d32a8094373ee0a69fd3..bc09ed747b164bbe99ddebd6d35a36ae6a2faad8 100644 --- a/backend/websocket.py +++ b/backend/websocket.py @@ -1,6 +1,5 @@ """WebSocket connection manager for real-time communication.""" -import asyncio import logging from typing import Any @@ -15,23 +14,18 @@ class ConnectionManager: def __init__(self) -> None: # session_id -> WebSocket self.active_connections: dict[str, WebSocket] = {} - # session_id -> asyncio.Queue for outgoing messages - self.message_queues: dict[str, asyncio.Queue] = {} async def connect(self, websocket: WebSocket, session_id: str) -> None: """Accept a WebSocket connection and register it.""" logger.info(f"Attempting to accept WebSocket for session {session_id}") await websocket.accept() self.active_connections[session_id] = websocket - self.message_queues[session_id] = asyncio.Queue() logger.info(f"WebSocket connected and registered for session {session_id}") def disconnect(self, session_id: str) -> None: """Remove a WebSocket connection.""" if session_id in self.active_connections: del self.active_connections[session_id] - if session_id in self.message_queues: - del self.message_queues[session_id] logger.info(f"WebSocket disconnected for session {session_id}") async def send_event( @@ -63,10 +57,6 @@ class ConnectionManager: """Check if a session has an active WebSocket connection.""" return session_id in self.active_connections - def get_queue(self, session_id: str) -> asyncio.Queue | None: - """Get the message queue for a session.""" - return self.message_queues.get(session_id) - # Global connection manager instance manager = ConnectionManager() diff --git a/configs/main_agent_config.json b/configs/main_agent_config.json index 18a414b3bfced18b47d2737579e3db9c9d137cd6..1ef25f251b053a73dc461cf4bdf617bb11c983d5 100644 --- a/configs/main_agent_config.json +++ b/configs/main_agent_config.json @@ -1,9 +1,9 @@ { - "model_name": "anthropic/claude-opus-4-5-20251101", + "model_name": "huggingface/novita/moonshotai/kimi-k2.5", "save_sessions": true, "session_dataset_repo": "akseljoonas/hf-agent-sessions", "yolo_mode": false, - "confirm_cpu_jobs": false, + "confirm_cpu_jobs": true, "auto_file_upload": true, "mcpServers": { "hf-mcp-server": { diff --git a/frontend/package-lock.json b/frontend/package-lock.json index a800dd3f254b2ff725890c4f250e34d7490bf52d..1e1a41bb0535aa7abb2b7a8ac165ab216c1bd384 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -8,10 +8,12 @@ "name": "hf-agent-frontend", "version": "1.0.0", "dependencies": { + "@ai-sdk/react": "^3.0.93", "@emotion/react": "^11.13.0", "@emotion/styled": "^11.13.0", "@mui/icons-material": "^6.1.0", "@mui/material": "^6.1.0", + "ai": "^6.0.91", "react": "^18.3.1", "react-dom": "^18.3.1", "react-markdown": "^9.0.1", @@ -34,6 +36,70 @@ "vite": "^5.4.10" } }, + "node_modules/@ai-sdk/gateway": { + "version": "3.0.50", + "resolved": "https://registry.npmjs.org/@ai-sdk/gateway/-/gateway-3.0.50.tgz", + "integrity": "sha512-Jdd1a8VgbD7l7r+COj0h5SuaYRfPvOJ/AO6l0OrmTPEcI2MUQPr3C4JttfpNkcheEN+gOdy0CtZWuG17bW2fjw==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.15", + "@vercel/oidc": "3.1.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/provider": { + "version": "3.0.8", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-3.0.8.tgz", + "integrity": "sha512-oGMAgGoQdBXbZqNG0Ze56CHjDZ1IDYOwGYxYjO5KLSlz5HiNQ9udIXsPZ61VWaHGZ5XW/jyjmr6t2xz2jGVwbQ==", + "license": "Apache-2.0", + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/provider-utils": { + "version": "4.0.15", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-4.0.15.tgz", + "integrity": "sha512-8XiKWbemmCbvNN0CLR9u3PQiet4gtEVIrX4zzLxnCj06AwsEDJwJVBbKrEI4t6qE8XRSIvU2irka0dcpziKW6w==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "3.0.8", + "@standard-schema/spec": "^1.1.0", + "eventsource-parser": "^3.0.6" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, + "node_modules/@ai-sdk/react": { + "version": "3.0.93", + "resolved": "https://registry.npmjs.org/@ai-sdk/react/-/react-3.0.93.tgz", + "integrity": "sha512-FY1HmeAfCpiAGLhIZh2QR8QFzHFZfhjMmkA9D5KC/O3eGqPeY7CwBABLkzRH+5Gkf+MfxXnEm4VF0MpmvDMjpg==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider-utils": "4.0.15", + "ai": "6.0.91", + "swr": "^2.2.5", + "throttleit": "2.1.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "react": "^18 || ~19.0.1 || ~19.1.2 || ^19.2.1" + } + }, "node_modules/@babel/code-frame": { "version": "7.28.6", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.28.6.tgz", @@ -1348,6 +1414,15 @@ } } }, + "node_modules/@opentelemetry/api": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", + "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", + "license": "Apache-2.0", + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/@popperjs/core": { "version": "2.11.8", "resolved": "https://registry.npmjs.org/@popperjs/core/-/core-2.11.8.tgz", @@ -1715,6 +1790,12 @@ "win32" ] }, + "node_modules/@standard-schema/spec": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@standard-schema/spec/-/spec-1.1.0.tgz", + "integrity": "sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==", + "license": "MIT" + }, "node_modules/@types/babel__core": { "version": "7.20.5", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.5.tgz", @@ -2155,6 +2236,15 @@ "integrity": "sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==", "license": "ISC" }, + "node_modules/@vercel/oidc": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/@vercel/oidc/-/oidc-3.1.0.tgz", + "integrity": "sha512-Fw28YZpRnA3cAHHDlkt7xQHiJ0fcL+NRcIqsocZQUSmbzeIKRpwttJjik5ZGanXP+vlA4SbTg+AbA3bP363l+w==", + "license": "Apache-2.0", + "engines": { + "node": ">= 20" + } + }, "node_modules/@vitejs/plugin-react": { "version": "4.7.0", "resolved": "https://registry.npmjs.org/@vitejs/plugin-react/-/plugin-react-4.7.0.tgz", @@ -2200,6 +2290,24 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, + "node_modules/ai": { + "version": "6.0.91", + "resolved": "https://registry.npmjs.org/ai/-/ai-6.0.91.tgz", + "integrity": "sha512-k1/8BusZMhYVxxLZt0BUZzm9HVDCCh117nyWfWUx5xjR2+tWisJbXgysL7EBMq2lgyHwgpA1jDR3tVjWSdWZXw==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/gateway": "3.0.50", + "@ai-sdk/provider": "3.0.8", + "@ai-sdk/provider-utils": "4.0.15", + "@opentelemetry/api": "1.9.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -2848,6 +2956,15 @@ "node": ">=0.10.0" } }, + "node_modules/eventsource-parser": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", + "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==", + "license": "MIT", + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/extend": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", @@ -3356,6 +3473,12 @@ "integrity": "sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==", "license": "MIT" }, + "node_modules/json-schema": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/json-schema/-/json-schema-0.4.0.tgz", + "integrity": "sha512-es94M3nTIfsEPisRafak+HDLfHXnKBhV3vU5eqPcS3flIWqcxJWgXHXiey3YrpaNsanY5ei1VoYEbOzijuq9BA==", + "license": "(AFL-2.1 OR BSD-3-Clause)" + }, "node_modules/json-schema-traverse": { "version": "0.4.1", "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", @@ -5052,6 +5175,31 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/swr": { + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/swr/-/swr-2.4.0.tgz", + "integrity": "sha512-sUlC20T8EOt1pHmDiqueUWMmRRX03W7w5YxovWX7VR2KHEPCTMly85x05vpkP5i6Bu4h44ePSMD9Tc+G2MItFw==", + "license": "MIT", + "dependencies": { + "dequal": "^2.0.3", + "use-sync-external-store": "^1.6.0" + }, + "peerDependencies": { + "react": "^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, + "node_modules/throttleit": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/throttleit/-/throttleit-2.1.0.tgz", + "integrity": "sha512-nt6AMGKW1p/70DF/hGBdJB57B8Tspmbp5gfJ8ilhLnt7kkr2ye7hzD6NVG8GGErk2HWF34igrL2CXmNIkzKqKw==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/tinyglobby": { "version": "0.2.15", "resolved": "https://registry.npmjs.org/tinyglobby/-/tinyglobby-0.2.15.tgz", @@ -5282,6 +5430,16 @@ "punycode": "^2.1.0" } }, + "node_modules/use-sync-external-store": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz", + "integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==", + "license": "MIT", + "peer": true, + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, "node_modules/vfile": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz", @@ -5426,6 +5584,16 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/zod": { + "version": "4.3.6", + "resolved": "https://registry.npmjs.org/zod/-/zod-4.3.6.tgz", + "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", + "license": "MIT", + "peer": true, + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, "node_modules/zustand": { "version": "5.0.10", "resolved": "https://registry.npmjs.org/zustand/-/zustand-5.0.10.tgz", diff --git a/frontend/package.json b/frontend/package.json index 553726bae62a96f8869c8bec29bf3fbad511bc0c..9efe3dced3118cbf0976e413f376f1050f1b2853 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -10,10 +10,12 @@ "preview": "vite preview" }, "dependencies": { + "@ai-sdk/react": "^3.0.93", "@emotion/react": "^11.13.0", "@emotion/styled": "^11.13.0", "@mui/icons-material": "^6.1.0", "@mui/material": "^6.1.0", + "ai": "^6.0.91", "react": "^18.3.1", "react-dom": "^18.3.1", "react-markdown": "^9.0.1", diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index e9aecc9ed99aee276dc509c7078d4c7404b50d89..de1f785734359130675174b487a4d30d1ca34f50 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,7 +1,12 @@ import { Box } from '@mui/material'; import AppLayout from '@/components/Layout/AppLayout'; +import { useAuth } from '@/hooks/useAuth'; function App() { + // Non-blocking auth check — fires in background, updates store when done. + // If auth fails later, apiFetch redirects to /auth/login. + useAuth(); + return ( diff --git a/frontend/src/components/ApprovalModal/ApprovalModal.tsx b/frontend/src/components/ApprovalModal/ApprovalModal.tsx deleted file mode 100644 index 98414524160bf8c81c7efdb3d6ce1adc2578435f..0000000000000000000000000000000000000000 --- a/frontend/src/components/ApprovalModal/ApprovalModal.tsx +++ /dev/null @@ -1,208 +0,0 @@ -import { useState, useCallback } from 'react'; -import { - Dialog, - DialogTitle, - DialogContent, - DialogActions, - Button, - Box, - Typography, - Checkbox, - FormControlLabel, - Accordion, - AccordionSummary, - AccordionDetails, - TextField, - Chip, -} from '@mui/material'; -import ExpandMoreIcon from '@mui/icons-material/ExpandMore'; -import WarningIcon from '@mui/icons-material/Warning'; -import { useAgentStore } from '@/store/agentStore'; - -interface ApprovalModalProps { - sessionId: string | null; -} - -interface ApprovalState { - [toolCallId: string]: { - approved: boolean; - feedback: string; - }; -} - -export default function ApprovalModal({ sessionId }: ApprovalModalProps) { - const { pendingApprovals, setPendingApprovals } = useAgentStore(); - const [approvalState, setApprovalState] = useState({}); - - const isOpen = pendingApprovals !== null && pendingApprovals.tools.length > 0; - - const handleApprovalChange = useCallback( - (toolCallId: string, approved: boolean) => { - setApprovalState((prev) => ({ - ...prev, - [toolCallId]: { - ...prev[toolCallId], - approved, - feedback: prev[toolCallId]?.feedback || '', - }, - })); - }, - [] - ); - - const handleFeedbackChange = useCallback( - (toolCallId: string, feedback: string) => { - setApprovalState((prev) => ({ - ...prev, - [toolCallId]: { - ...prev[toolCallId], - feedback, - }, - })); - }, - [] - ); - - const handleSubmit = useCallback(async () => { - if (!sessionId || !pendingApprovals) return; - - const approvals = pendingApprovals.tools.map((tool) => ({ - tool_call_id: tool.tool_call_id, - approved: approvalState[tool.tool_call_id]?.approved ?? false, - feedback: approvalState[tool.tool_call_id]?.feedback || null, - })); - - try { - await fetch('/api/approve', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - session_id: sessionId, - approvals, - }), - }); - setPendingApprovals(null); - setApprovalState({}); - } catch (e) { - console.error('Approval submission failed:', e); - } - }, [sessionId, pendingApprovals, approvalState, setPendingApprovals]); - - const handleApproveAll = useCallback(() => { - if (!pendingApprovals) return; - const newState: ApprovalState = {}; - pendingApprovals.tools.forEach((tool) => { - newState[tool.tool_call_id] = { approved: true, feedback: '' }; - }); - setApprovalState(newState); - }, [pendingApprovals]); - - const handleRejectAll = useCallback(() => { - if (!pendingApprovals) return; - const newState: ApprovalState = {}; - pendingApprovals.tools.forEach((tool) => { - newState[tool.tool_call_id] = { approved: false, feedback: '' }; - }); - setApprovalState(newState); - }, [pendingApprovals]); - - if (!isOpen || !pendingApprovals) return null; - - const approvedCount = Object.values(approvalState).filter((s) => s.approved).length; - - return ( - - - - Approval Required - 1 ? 's' : ''}`} - size="small" - sx={{ ml: 1 }} - /> - - - - The following tool calls require your approval before execution: - - {pendingApprovals.tools.map((tool, index) => ( - - }> - - { - e.stopPropagation(); - handleApprovalChange(tool.tool_call_id, e.target.checked); - }} - onClick={(e) => e.stopPropagation()} - /> - } - label="" - sx={{ m: 0 }} - /> - - - {approvalState[tool.tool_call_id]?.approved ? 'Approved' : 'Pending'} - - - - - - Arguments: - - - {JSON.stringify(tool.arguments, null, 2)} - - {!approvalState[tool.tool_call_id]?.approved && ( - handleFeedbackChange(tool.tool_call_id, e.target.value)} - sx={{ mt: 2 }} - /> - )} - - - ))} - - - - - - - {approvedCount} of {pendingApprovals.count} approved - - - - - ); -} diff --git a/frontend/src/components/Chat/ActivityStatusBar.tsx b/frontend/src/components/Chat/ActivityStatusBar.tsx new file mode 100644 index 0000000000000000000000000000000000000000..5aac135875b951d97fdb942eda9b6e1f4943e193 --- /dev/null +++ b/frontend/src/components/Chat/ActivityStatusBar.tsx @@ -0,0 +1,57 @@ +import { Box, Typography } from '@mui/material'; +import { keyframes } from '@mui/system'; +import { useAgentStore, type ActivityStatus } from '@/store/agentStore'; + +const shimmer = keyframes` + 0% { background-position: -100% center; } + 50% { background-position: 200% center; } + 100% { background-position: -100% center; } +`; + +const TOOL_LABELS: Record = { + hf_jobs: 'Running job', + hf_repo_files: 'Uploading file', + hf_repo_git: 'Git operation', + hf_inspect_dataset: 'Inspecting dataset', + hf_search: 'Searching', + plan_tool: 'Planning', +}; + +function statusLabel(status: ActivityStatus): string { + switch (status.type) { + case 'thinking': return 'Thinking'; + case 'streaming': return 'Writing'; + case 'tool': return TOOL_LABELS[status.toolName] || `Running ${status.toolName}`; + case 'waiting-approval': return 'Waiting for approval'; + default: return ''; + } +} + +export default function ActivityStatusBar() { + const activityStatus = useAgentStore(s => s.activityStatus); + + if (activityStatus.type === 'idle') return null; + + const label = statusLabel(activityStatus); + + return ( + + + {label}… + + + ); +} diff --git a/frontend/src/components/Chat/ApprovalFlow.tsx b/frontend/src/components/Chat/ApprovalFlow.tsx deleted file mode 100644 index 58c1d8e6520ac561ae341965c2f421c74112ad63..0000000000000000000000000000000000000000 --- a/frontend/src/components/Chat/ApprovalFlow.tsx +++ /dev/null @@ -1,515 +0,0 @@ -import { useState, useCallback, useEffect } from 'react'; -import { Box, Typography, Button, TextField, IconButton, Link } from '@mui/material'; -import SendIcon from '@mui/icons-material/Send'; -import OpenInNewIcon from '@mui/icons-material/OpenInNew'; -import CheckCircleIcon from '@mui/icons-material/CheckCircle'; -import CancelIcon from '@mui/icons-material/Cancel'; -import LaunchIcon from '@mui/icons-material/Launch'; -import { useAgentStore } from '@/store/agentStore'; -import { useLayoutStore } from '@/store/layoutStore'; -import { useSessionStore } from '@/store/sessionStore'; -import type { Message, ToolApproval } from '@/types/agent'; - -interface ApprovalFlowProps { - message: Message; -} - -export default function ApprovalFlow({ message }: ApprovalFlowProps) { - const { setPanelContent, setPanelTab, setActivePanelTab, clearPanelTabs, updateMessage } = useAgentStore(); - const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore(); - const { activeSessionId } = useSessionStore(); - const [currentIndex, setCurrentIndex] = useState(0); - const [feedback, setFeedback] = useState(''); - const [decisions, setDecisions] = useState([]); - - const approvalData = message.approval; - - if (!approvalData) return null; - - const { batch, status } = approvalData; - - // Parse toolOutput to extract job info (URL, status, logs, errors) - let logsContent = ''; - let showLogsButton = false; - let jobUrl = ''; - let jobStatus = ''; - let jobFailed = false; - let errorMessage = ''; - - if (message.toolOutput) { - const output = message.toolOutput; - - // Extract job URL: **View at:** https://... - const urlMatch = output.match(/\*\*View at:\*\*\s*(https:\/\/[^\s\n]+)/); - if (urlMatch) { - jobUrl = urlMatch[1]; - } - - // Extract job status: **Final Status:** ... - const statusMatch = output.match(/\*\*Final Status:\*\*\s*([^\n]+)/); - if (statusMatch) { - jobStatus = statusMatch[1].trim(); - jobFailed = jobStatus.toLowerCase().includes('error') || jobStatus.toLowerCase().includes('failed'); - } - - // Extract logs - if (output.includes('**Logs:**')) { - const parts = output.split('**Logs:**'); - if (parts.length > 1) { - const logsPart = parts[1].trim(); - const codeBlockMatch = logsPart.match(/```([\s\S]*?)```/); - if (codeBlockMatch) { - logsContent = codeBlockMatch[1].trim(); - showLogsButton = true; - } - } - } - - // Detect errors - if output exists but doesn't have the expected job completion format - // This catches early failures (validation errors, API errors, etc.) - const isExpectedFormat = output.includes('**Job ID:**') || output.includes('**View at:**'); - const looksLikeError = output.toLowerCase().includes('error') || - output.toLowerCase().includes('failed') || - output.toLowerCase().includes('exception') || - output.includes('Traceback'); - - if (!isExpectedFormat || (looksLikeError && !logsContent)) { - // This is likely an error message - show it - errorMessage = output; - jobFailed = true; - } - } - - // Sync right panel with current tool - useEffect(() => { - if (!batch || currentIndex >= batch.tools.length) return; - - // Only auto-open panel if pending - if (status !== 'pending') return; - - const tool = batch.tools[currentIndex]; - const args = tool.arguments as any; - - if (tool.tool === 'hf_jobs' && (args.operation === 'run' || args.operation === 'scheduled run') && args.script) { - setPanelContent({ - title: 'Compute Job Script', - content: args.script, - language: 'python', - parameters: args - }); - // Don't auto-open if already resolved - } else if (tool.tool === 'hf_repo_files' && args.operation === 'upload' && args.content) { - setPanelContent({ - title: `File Upload: ${args.path || 'unnamed'}`, - content: args.content, - parameters: args - }); - } - }, [currentIndex, batch, status, setPanelContent]); - - const handleResolve = useCallback(async (approved: boolean) => { - if (!batch || !activeSessionId) return; - - const currentTool = batch.tools[currentIndex]; - const newDecisions = [ - ...decisions, - { - tool_call_id: currentTool.tool_call_id, - approved, - feedback: approved ? null : feedback || 'Rejected by user', - }, - ]; - - if (currentIndex < batch.tools.length - 1) { - setDecisions(newDecisions); - setCurrentIndex(currentIndex + 1); - setFeedback(''); - } else { - // All tools in batch resolved - try { - await fetch('/api/approve', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - session_id: activeSessionId, - approvals: newDecisions, - }), - }); - - // Update message status - updateMessage(activeSessionId, message.id, { - approval: { - ...approvalData!, - status: approved ? 'approved' : 'rejected', - decisions: newDecisions - } - }); - - } catch (e) { - console.error('Approval submission failed:', e); - } - } - }, [activeSessionId, message.id, batch, currentIndex, feedback, decisions, approvalData, updateMessage]); - - if (!batch || currentIndex >= batch.tools.length) return null; - - const currentTool = batch.tools[currentIndex]; - - // Check if script contains push_to_hub or upload_file - const args = currentTool.arguments as any; - const containsPushToHub = currentTool.tool === 'hf_jobs' && args.script && (args.script.includes('push_to_hub') || args.script.includes('upload_file')); - - const getToolDescription = (toolName: string, args: any) => { - if (toolName === 'hf_jobs') { - return ( - - - The agent wants to execute hf_jobs on{' '} - {args.hardware_flavor || 'default'} with a timeout of{' '} - {args.timeout || '30m'} - - - ); - } - return ( - - The agent wants to execute {toolName} - - ); - }; - - const showCode = () => { - const args = currentTool.arguments as any; - if (currentTool.tool === 'hf_jobs' && args.script) { - // Clear existing tabs and set up script tab (and logs if available) - clearPanelTabs(); - setPanelTab({ - id: 'script', - title: 'Script', - content: args.script, - language: 'python', - parameters: args - }); - // If logs are available (job completed), also add logs tab - if (logsContent) { - setPanelTab({ - id: 'logs', - title: 'Logs', - content: logsContent, - language: 'text' - }); - } - setActivePanelTab('script'); - setRightPanelOpen(true); - setLeftSidebarOpen(false); - } else { - setPanelContent({ - title: `Tool: ${currentTool.tool}`, - content: JSON.stringify(args, null, 2), - language: 'json', - parameters: args - }); - setRightPanelOpen(true); - setLeftSidebarOpen(false); - } - }; - - const handleViewLogs = (e: React.MouseEvent) => { - e.stopPropagation(); - const args = currentTool.arguments as any; - // Set up both tabs so user can switch between script and logs - clearPanelTabs(); - if (currentTool.tool === 'hf_jobs' && args.script) { - setPanelTab({ - id: 'script', - title: 'Script', - content: args.script, - language: 'python', - parameters: args - }); - } - setPanelTab({ - id: 'logs', - title: 'Logs', - content: logsContent, - language: 'text' - }); - setActivePanelTab('logs'); - setRightPanelOpen(true); - setLeftSidebarOpen(false); - }; - - return ( - - - - {status === 'pending' ? 'Approval Required' : status === 'approved' ? 'Approved' : 'Rejected'} - - - ({currentIndex + 1}/{batch.count}) - - {status === 'approved' && } - {status === 'rejected' && } - - - - {getToolDescription(currentTool.tool, currentTool.arguments)} - - - - {/* Script/Logs buttons for hf_jobs - always show when we have a script */} - {currentTool.tool === 'hf_jobs' && args.script && ( - - - - - - - {/* Job URL - only show when we have a specific URL */} - {jobUrl && ( - - - View Job on Hugging Face - - )} - - {/* Show job status if available */} - {jobStatus && ( - - Status: {jobStatus} - - )} - - )} - - {containsPushToHub && ( - - We've detected the result will be pushed to hub. - - )} - - {/* Show error message if job failed */} - {errorMessage && status !== 'pending' && ( - - - Error - - - {errorMessage.length > 500 ? errorMessage.substring(0, 500) + '...' : errorMessage} - - - )} - - - {status === 'pending' && ( - - - setFeedback(e.target.value)} - variant="outlined" - sx={{ - '& .MuiOutlinedInput-root': { - bgcolor: 'rgba(0,0,0,0.2)', - fontFamily: 'inherit', - fontSize: '0.9rem' - } - }} - /> - handleResolve(false)} - disabled={!feedback} - title="Reject with feedback" - sx={{ - color: 'var(--accent-red)', - border: '1px solid rgba(255,255,255,0.05)', - borderRadius: '8px', - width: 40, - height: 40, - '&:hover': { - bgcolor: 'rgba(224, 90, 79, 0.1)', - borderColor: 'var(--accent-red)', - }, - '&.Mui-disabled': { - color: 'rgba(255,255,255,0.1)', - borderColor: 'rgba(255,255,255,0.02)' - } - }} - > - - - - - - - - - - )} - - {status === 'rejected' && decisions.some(d => d.feedback) && ( - - Feedback: {decisions.find(d => d.feedback)?.feedback} - - )} - - ); -} \ No newline at end of file diff --git a/frontend/src/components/Chat/AssistantMessage.tsx b/frontend/src/components/Chat/AssistantMessage.tsx new file mode 100644 index 0000000000000000000000000000000000000000..83bd8cae505808781908a2292eaa8acc1242536b --- /dev/null +++ b/frontend/src/components/Chat/AssistantMessage.tsx @@ -0,0 +1,119 @@ +import { useMemo } from 'react'; +import { Box, Stack, Typography } from '@mui/material'; +import MarkdownContent from './MarkdownContent'; +import ToolCallGroup from './ToolCallGroup'; +import type { UIMessage } from 'ai'; +import type { MessageMeta } from '@/types/agent'; + +interface AssistantMessageProps { + message: UIMessage; + isStreaming?: boolean; + approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; +} + +/** + * Groups consecutive tool parts together so they render as a single + * ToolCallGroup (visually identical to the old segments approach). + */ +type DynamicToolPart = Extract; + +function groupParts(parts: UIMessage['parts']) { + const groups: Array< + | { kind: 'text'; text: string; idx: number } + | { kind: 'tools'; tools: DynamicToolPart[]; idx: number } + > = []; + + for (let i = 0; i < parts.length; i++) { + const part = parts[i]; + + if (part.type === 'text') { + groups.push({ kind: 'text', text: part.text, idx: i }); + } else if (part.type === 'dynamic-tool') { + const toolPart = part as DynamicToolPart; + const last = groups[groups.length - 1]; + if (last?.kind === 'tools') { + last.tools.push(toolPart); + } else { + groups.push({ kind: 'tools', tools: [toolPart], idx: i }); + } + } + // step-start, step-end, etc. are ignored visually + } + + return groups; +} + +export default function AssistantMessage({ message, isStreaming = false, approveTools }: AssistantMessageProps) { + const groups = useMemo(() => groupParts(message.parts), [message.parts]); + + // Find the last text group index for streaming cursor + let lastTextIdx = -1; + for (let i = groups.length - 1; i >= 0; i--) { + if (groups[i].kind === 'text') { lastTextIdx = i; break; } + } + + const meta = message.metadata as MessageMeta | undefined; + const timeStr = meta?.createdAt + ? new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) + : null; + + if (groups.length === 0) return null; + + return ( + + + + Assistant + + {timeStr && ( + + {timeStr} + + )} + + + + {groups.map((group, i) => { + if (group.kind === 'text' && group.text) { + return ( + + ); + } + if (group.kind === 'tools' && group.tools.length > 0) { + return ( + + ); + } + return null; + })} + + + ); +} diff --git a/frontend/src/components/Chat/ChatInput.tsx b/frontend/src/components/Chat/ChatInput.tsx index a242eee54a373b6559ef565803e0be76d15df5c1..5fa7bd5f03c321a8d39245a95014b314221cd1b2 100644 --- a/frontend/src/components/Chat/ChatInput.tsx +++ b/frontend/src/components/Chat/ChatInput.tsx @@ -1,14 +1,103 @@ -import { useState, useCallback, KeyboardEvent } from 'react'; -import { Box, TextField, IconButton, CircularProgress, Typography } from '@mui/material'; +import { useState, useCallback, useEffect, useRef, KeyboardEvent } from 'react'; +import { Box, TextField, IconButton, CircularProgress, Typography, Menu, MenuItem, ListItemIcon, ListItemText, Chip } from '@mui/material'; import ArrowUpwardIcon from '@mui/icons-material/ArrowUpward'; +import ArrowDropDownIcon from '@mui/icons-material/ArrowDropDown'; +import { apiFetch } from '@/utils/api'; + +// Model configuration +interface ModelOption { + id: string; + name: string; + description: string; + modelPath: string; + avatarUrl: string; + recommended?: boolean; +} + +const getHfAvatarUrl = (modelId: string) => { + const org = modelId.split('/')[0]; + return `https://huggingface.co/api/avatars/${org}`; +}; + +const MODEL_OPTIONS: ModelOption[] = [ + { + id: 'minimax-m2.1', + name: 'MiniMax M2.1', + description: 'Via Novita', + modelPath: 'huggingface/novita/minimax/minimax-m2.1', + avatarUrl: getHfAvatarUrl('MiniMaxAI/MiniMax-M2.1'), + recommended: true, + }, + { + id: 'claude-opus', + name: 'Claude Opus 4.5', + description: 'Anthropic', + modelPath: 'anthropic/claude-opus-4-5-20251101', + avatarUrl: 'https://huggingface.co/api/avatars/Anthropic', + recommended: true, + }, + { + id: 'kimi-k2.5', + name: 'Kimi K2.5', + description: 'Via Novita', + modelPath: 'huggingface/novita/moonshotai/kimi-k2.5', + avatarUrl: getHfAvatarUrl('moonshotai/Kimi-K2.5'), + }, + { + id: 'glm-5', + name: 'GLM 5', + description: 'Via Novita', + modelPath: 'huggingface/novita/zai-org/glm-5', + avatarUrl: getHfAvatarUrl('zai-org/GLM-5'), + }, +]; + +const findModelByPath = (path: string): ModelOption | undefined => { + return MODEL_OPTIONS.find(m => m.modelPath === path || path?.includes(m.id)); +}; interface ChatInputProps { onSend: (text: string) => void; disabled?: boolean; + placeholder?: string; } -export default function ChatInput({ onSend, disabled = false }: ChatInputProps) { +export default function ChatInput({ onSend, disabled = false, placeholder = 'Ask anything...' }: ChatInputProps) { const [input, setInput] = useState(''); + const inputRef = useRef(null); + const [selectedModelId, setSelectedModelId] = useState(() => { + try { + const stored = localStorage.getItem('hf-agent-model'); + if (stored && MODEL_OPTIONS.some(m => m.id === stored)) return stored; + } catch { /* localStorage unavailable */ } + return MODEL_OPTIONS[0].id; + }); + const [modelAnchorEl, setModelAnchorEl] = useState(null); + + // Sync with backend on mount (backend is source of truth, localStorage is just a cache) + useEffect(() => { + fetch('/api/config/model') + .then((res) => (res.ok ? res.json() : null)) + .then((data) => { + if (data?.current) { + const model = findModelByPath(data.current); + if (model) { + setSelectedModelId(model.id); + try { localStorage.setItem('hf-agent-model', model.id); } catch { /* ignore */ } + } + } + }) + .catch(() => { /* ignore */ }); + }, []); + + const selectedModel = MODEL_OPTIONS.find(m => m.id === selectedModelId) || MODEL_OPTIONS[0]; + + // Auto-focus the textarea when the session becomes ready (disabled -> false) + useEffect(() => { + if (!disabled && inputRef.current) { + inputRef.current.focus(); + } + }, [disabled]); const handleSend = useCallback(() => { if (input.trim() && !disabled) { @@ -27,26 +116,48 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps) [handleSend] ); + const handleModelClick = (event: React.MouseEvent) => { + setModelAnchorEl(event.currentTarget); + }; + + const handleModelClose = () => { + setModelAnchorEl(null); + }; + + const handleSelectModel = async (model: ModelOption) => { + handleModelClose(); + try { + const res = await apiFetch('/api/config/model', { + method: 'POST', + body: JSON.stringify({ model: model.modelPath }), + }); + if (res.ok) { + setSelectedModelId(model.id); + try { localStorage.setItem('hf-agent-model', model.id); } catch { /* ignore */ } + } + } catch { /* ignore */ } + }; + return ( - + setInput(e.target.value)} onKeyDown={handleKeyDown} - placeholder="Ask anything..." + placeholder={placeholder} disabled={disabled} variant="standard" + inputRef={inputRef} InputProps={{ disableUnderline: true, sx: { @@ -72,7 +184,7 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps) fontFamily: 'inherit', padding: 0, lineHeight: 1.5, - minHeight: '56px', + minHeight: { xs: '44px', md: '56px' }, alignItems: 'flex-start', } }} @@ -99,7 +211,7 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps) transition: 'all 0.2s', '&:hover': { color: 'var(--accent-yellow)', - bgcolor: 'rgba(255,255,255,0.05)', + bgcolor: 'var(--hover-bg)', }, '&.Mui-disabled': { opacity: 0.3, @@ -109,17 +221,108 @@ export default function ChatInput({ onSend, disabled = false }: ChatInputProps) {disabled ? : } - + {/* Powered By Badge */} - + powered by - Claude + {selectedModel.name} - claude-opus-4-5-20251101 + {selectedModel.name} + + + {/* Model Selection Menu */} + + {MODEL_OPTIONS.map((model) => ( + handleSelectModel(model)} + selected={selectedModelId === model.id} + sx={{ + py: 1.5, + '&.Mui-selected': { + bgcolor: 'rgba(255,255,255,0.05)', + } + }} + > + + {model.name} + + + {model.name} + {model.recommended && ( + + )} + + } + secondary={model.description} + secondaryTypographyProps={{ + sx: { fontSize: '12px', color: 'var(--muted-text)' } + }} + /> + + ))} + ); diff --git a/frontend/src/components/Chat/MarkdownContent.tsx b/frontend/src/components/Chat/MarkdownContent.tsx new file mode 100644 index 0000000000000000000000000000000000000000..beb682720bf2b4d846b67a86d45607bc4544044b --- /dev/null +++ b/frontend/src/components/Chat/MarkdownContent.tsx @@ -0,0 +1,160 @@ +import { useMemo, useRef, useState, useEffect } from 'react'; +import { Box } from '@mui/material'; +import ReactMarkdown from 'react-markdown'; +import remarkGfm from 'remark-gfm'; +import type { SxProps, Theme } from '@mui/material/styles'; + +interface MarkdownContentProps { + content: string; + sx?: SxProps; + /** When true, shows a blinking cursor and throttles renders. */ + isStreaming?: boolean; +} + +/** Shared markdown styles — adapts to light/dark via CSS variables. */ +const markdownSx: SxProps = { + fontSize: '0.925rem', + lineHeight: 1.7, + color: 'var(--text)', + wordBreak: 'break-word', + + '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } }, + + '& h1, & h2, & h3, & h4': { mt: 2.5, mb: 1, fontWeight: 600, lineHeight: 1.3 }, + '& h1': { fontSize: '1.35rem' }, + '& h2': { fontSize: '1.15rem' }, + '& h3': { fontSize: '1.05rem' }, + + '& pre': { + bgcolor: 'var(--code-bg)', + p: 2, + borderRadius: 2, + overflow: 'auto', + fontSize: '0.82rem', + lineHeight: 1.6, + border: '1px solid var(--tool-border)', + my: 2, + }, + '& code': { + bgcolor: 'var(--hover-bg)', + px: 0.75, + py: 0.25, + borderRadius: 0.5, + fontSize: '0.84rem', + fontFamily: '"JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', + }, + '& pre code': { bgcolor: 'transparent', p: 0 }, + + '& a': { + color: 'var(--accent-yellow)', + textDecoration: 'none', + fontWeight: 500, + '&:hover': { textDecoration: 'underline' }, + }, + + '& ul, & ol': { pl: 3, my: 1 }, + '& li': { mb: 0.5 }, + '& li::marker': { color: 'var(--muted-text)' }, + + '& blockquote': { + borderLeft: '3px solid var(--accent-yellow)', + pl: 2, + ml: 0, + my: 1.5, + color: 'var(--muted-text)', + fontStyle: 'italic', + }, + + '& table': { + borderCollapse: 'collapse', + width: '100%', + my: 2, + fontSize: '0.85rem', + }, + '& th': { + borderBottom: '2px solid var(--border-hover)', + textAlign: 'left', + p: 1, + fontWeight: 600, + }, + '& td': { + borderBottom: '1px solid var(--tool-border)', + p: 1, + }, + + '& hr': { + border: 'none', + borderTop: '1px solid var(--border)', + my: 2, + }, + + '& img': { + maxWidth: '100%', + borderRadius: 2, + }, +}; + +/** + * Throttled content for streaming: render the full markdown through + * ReactMarkdown but only re-parse every ~80ms to avoid layout thrashing. + * This is the Claude approach — always render as markdown, never split + * into raw text. The parser handles incomplete tables gracefully. + */ +function useThrottledValue(value: string, isStreaming: boolean, intervalMs = 80): string { + const [throttled, setThrottled] = useState(value); + const lastUpdate = useRef(0); + const pending = useRef | null>(null); + const latestValue = useRef(value); + latestValue.current = value; + + useEffect(() => { + if (!isStreaming) { + // Not streaming — always use latest value immediately + setThrottled(value); + return; + } + + const now = Date.now(); + const elapsed = now - lastUpdate.current; + + if (elapsed >= intervalMs) { + // Enough time passed — update immediately + setThrottled(value); + lastUpdate.current = now; + } else { + // Schedule an update for the remaining time + if (pending.current) clearTimeout(pending.current); + pending.current = setTimeout(() => { + setThrottled(latestValue.current); + lastUpdate.current = Date.now(); + pending.current = null; + }, intervalMs - elapsed); + } + + return () => { + if (pending.current) clearTimeout(pending.current); + }; + }, [value, isStreaming, intervalMs]); + + // When streaming ends, flush immediately + useEffect(() => { + if (!isStreaming) { + setThrottled(latestValue.current); + } + }, [isStreaming]); + + return throttled; +} + +export default function MarkdownContent({ content, sx, isStreaming = false }: MarkdownContentProps) { + // Throttle re-parses during streaming to ~12fps (every 80ms) + const displayContent = useThrottledValue(content, isStreaming); + + const remarkPlugins = useMemo(() => [remarkGfm], []); + + return ( + + {displayContent} + + ); +} diff --git a/frontend/src/components/Chat/MessageBubble.tsx b/frontend/src/components/Chat/MessageBubble.tsx index 5e7e0f197afc796beb443bf4b326d8b93f57023a..af5d5a49a1d426039b1bb65ef52abae3b601a104 100644 --- a/frontend/src/components/Chat/MessageBubble.tsx +++ b/frontend/src/components/Chat/MessageBubble.tsx @@ -1,215 +1,44 @@ -import { Box, Paper, Typography } from '@mui/material'; -import ReactMarkdown from 'react-markdown'; -import remarkGfm from 'remark-gfm'; -import ApprovalFlow from './ApprovalFlow'; -import type { Message, TraceLog } from '@/types/agent'; -import { useAgentStore } from '@/store/agentStore'; -import { useLayoutStore } from '@/store/layoutStore'; +import UserMessage from './UserMessage'; +import AssistantMessage from './AssistantMessage'; +import type { UIMessage } from 'ai'; interface MessageBubbleProps { - message: Message; + message: UIMessage; + isLastTurn?: boolean; + onUndoTurn?: () => void; + isProcessing?: boolean; + isStreaming?: boolean; + approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; } -// Render a tools segment with clickable tool calls -function ToolsSegment({ tools }: { tools: TraceLog[] }) { - const { showToolOutput } = useAgentStore(); - const { setRightPanelOpen } = useLayoutStore(); - - const handleToolClick = (log: TraceLog) => { - if (log.completed && log.output) { - showToolOutput(log); - setRightPanelOpen(true); - } - }; - - return ( - - - {tools.map((log) => { - const isClickable = log.completed && log.output; - return ( - handleToolClick(log)} - sx={{ - color: 'var(--muted-text)', - fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', - fontSize: '0.75rem', - display: 'flex', - alignItems: 'center', - gap: 0.5, - cursor: isClickable ? 'pointer' : 'default', - borderRadius: 0.5, - px: 0.5, - mx: -0.5, - transition: 'background-color 0.15s ease', - '&:hover': isClickable ? { - bgcolor: 'rgba(255,255,255,0.05)', - } : {}, - }} - > - - {log.completed ? (log.success === false ? '✗' : '✓') : '•'} - - - {log.tool} - - {!log.completed && ...} - {isClickable && ( - - click to view - - )} - - ); - })} - - - ); -} - -// Markdown styles -const markdownStyles = { - '& p': { m: 0, mb: 1, '&:last-child': { mb: 0 } }, - '& pre': { - bgcolor: 'rgba(0,0,0,0.5)', - p: 1.5, - borderRadius: 1, - overflow: 'auto', - fontSize: '0.85rem', - border: '1px solid rgba(255,255,255,0.05)', - }, - '& code': { - bgcolor: 'rgba(255,255,255,0.05)', - px: 0.5, - py: 0.25, - borderRadius: 0.5, - fontSize: '0.85rem', - fontFamily: '"JetBrains Mono", monospace', - }, - '& pre code': { bgcolor: 'transparent', p: 0 }, - '& a': { - color: 'var(--accent-yellow)', - textDecoration: 'none', - '&:hover': { textDecoration: 'underline' }, - }, - '& ul, & ol': { pl: 2, my: 1 }, - '& table': { - borderCollapse: 'collapse', - width: '100%', - my: 2, - fontSize: '0.875rem', - }, - '& th': { - borderBottom: '1px solid rgba(255,255,255,0.1)', - textAlign: 'left', - p: 1, - bgcolor: 'rgba(255,255,255,0.02)', - }, - '& td': { - borderBottom: '1px solid rgba(255,255,255,0.05)', - p: 1, - }, -}; - -export default function MessageBubble({ message }: MessageBubbleProps) { - const isUser = message.role === 'user'; - const isAssistant = message.role === 'assistant'; - - if (message.approval) { +export default function MessageBubble({ + message, + isLastTurn = false, + onUndoTurn, + isProcessing = false, + isStreaming = false, + approveTools, +}: MessageBubbleProps) { + if (message.role === 'user') { return ( - - - + ); } - // Render segments chronologically if available, otherwise fall back to content - const renderContent = () => { - if (message.segments && message.segments.length > 0) { - return message.segments.map((segment, idx) => { - if (segment.type === 'text' && segment.content) { - return ( - - {segment.content} - - ); - } - if (segment.type === 'tools' && segment.tools && segment.tools.length > 0) { - return ; - } - return null; - }); - } - // Fallback: just render content + if (message.role === 'assistant') { return ( - - {message.content} - + ); - }; - - return ( - - - {renderContent()} + } - - {new Date(message.timestamp).toLocaleTimeString()} - - - - ); + return null; } diff --git a/frontend/src/components/Chat/MessageList.tsx b/frontend/src/components/Chat/MessageList.tsx index c54d4761c83681f1f1eebe7c7eb8619de3c5d962..21729fdc441dea8ebfcad1189686645ff1844e7b 100644 --- a/frontend/src/components/Chat/MessageList.tsx +++ b/frontend/src/components/Chat/MessageList.tsx @@ -1,100 +1,151 @@ -import { useEffect, useRef } from 'react'; -import { Box, Typography } from '@mui/material'; -import { useSessionStore } from '@/store/sessionStore'; +import { useCallback, useEffect, useRef, useMemo } from 'react'; +import { Box, Stack, Typography } from '@mui/material'; import MessageBubble from './MessageBubble'; -import type { Message } from '@/types/agent'; +import ActivityStatusBar from './ActivityStatusBar'; +import { useAgentStore } from '@/store/agentStore'; +import type { UIMessage } from 'ai'; interface MessageListProps { - messages: Message[]; + messages: UIMessage[]; isProcessing: boolean; + approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null }>) => Promise; + onUndoLastTurn: () => void | Promise; } -const TechnicalIndicator = () => ( - -); +function getGreeting(): string { + const h = new Date().getHours(); + if (h < 12) return 'Morning'; + if (h < 17) return 'Afternoon'; + return 'Evening'; +} + +function WelcomeGreeting() { + const { user } = useAgentStore(); + const firstName = user?.name?.split(' ')[0] || user?.username; + const greeting = firstName ? `${getGreeting()}, ${firstName}` : getGreeting(); -export default function MessageList({ messages, isProcessing }: MessageListProps) { - const bottomRef = useRef(null); - const { activeSessionId } = useSessionStore(); + return ( + + + {greeting} + + + Let's build something impressive? + + + ); +} + +export default function MessageList({ messages, isProcessing, approveTools, onUndoLastTurn }: MessageListProps) { + const scrollContainerRef = useRef(null); + const stickToBottom = useRef(true); + + const scrollToBottom = useCallback(() => { + const el = scrollContainerRef.current; + if (el) el.scrollTop = el.scrollHeight; + }, []); + + useEffect(() => { + const el = scrollContainerRef.current; + if (!el) return; + const onScroll = () => { + const distFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight; + stickToBottom.current = distFromBottom < 80; + }; + el.addEventListener('scroll', onScroll, { passive: true }); + return () => el.removeEventListener('scroll', onScroll); + }, []); + + useEffect(() => { + if (stickToBottom.current) scrollToBottom(); + }, [messages, isProcessing, scrollToBottom]); - // Auto-scroll to bottom when new messages arrive useEffect(() => { - bottomRef.current?.scrollIntoView({ behavior: 'smooth' }); - }, [messages, isProcessing]); + const el = scrollContainerRef.current; + if (!el) return; + const observer = new MutationObserver(() => { + if (stickToBottom.current) el.scrollTop = el.scrollHeight; + }); + observer.observe(el, { childList: true, subtree: true, characterData: true }); + return () => observer.disconnect(); + }, []); + + const lastUserMsgId = useMemo(() => { + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'user') return messages[i].id; + } + return null; + }, [messages]); + + // The last assistant message is "streaming" when we're processing + const lastAssistantId = useMemo(() => { + for (let i = messages.length - 1; i >= 0; i--) { + if (messages[i].role === 'assistant') return messages[i].id; + } + return null; + }, [messages]); return ( - + {messages.length === 0 && !isProcessing ? ( - - - Awaiting input… - - + ) : ( - messages.map((message) => ( - + messages.map((msg) => ( + )) )} - - {isProcessing && ( - - - - Thinking - - - - - )} - {activeSessionId && ( - // ApprovalFlow is now handled within messages - null - )} - -
- + + +
+ ); -} \ No newline at end of file +} diff --git a/frontend/src/components/Chat/ThinkingIndicator.tsx b/frontend/src/components/Chat/ThinkingIndicator.tsx new file mode 100644 index 0000000000000000000000000000000000000000..b8c37181f5de70eeb26a6f42049311393cc3d73e --- /dev/null +++ b/frontend/src/components/Chat/ThinkingIndicator.tsx @@ -0,0 +1,48 @@ +import { Box, Typography } from '@mui/material'; + +/** Pulsing dots shown while the agent is processing. */ +export default function ThinkingIndicator() { + return ( + + + Thinking + + + + + + + + ); +} diff --git a/frontend/src/components/Chat/ToolCallGroup.tsx b/frontend/src/components/Chat/ToolCallGroup.tsx new file mode 100644 index 0000000000000000000000000000000000000000..65bc752ba9020a3a78f036271e94adb3db62960c --- /dev/null +++ b/frontend/src/components/Chat/ToolCallGroup.tsx @@ -0,0 +1,655 @@ +import { useCallback, useMemo, useRef, useState } from 'react'; +import { Box, Stack, Typography, Chip, Button, TextField, IconButton, Link, CircularProgress } from '@mui/material'; +import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline'; +import ErrorOutlineIcon from '@mui/icons-material/ErrorOutline'; +import OpenInNewIcon from '@mui/icons-material/OpenInNew'; +import HourglassEmptyIcon from '@mui/icons-material/HourglassEmpty'; +import LaunchIcon from '@mui/icons-material/Launch'; +import SendIcon from '@mui/icons-material/Send'; +import BlockIcon from '@mui/icons-material/Block'; +import { useAgentStore } from '@/store/agentStore'; +import { useLayoutStore } from '@/store/layoutStore'; +import { logger } from '@/utils/logger'; +import type { UIMessage } from 'ai'; + +// --------------------------------------------------------------------------- +// Type helpers — extract the dynamic-tool part type from UIMessage +// --------------------------------------------------------------------------- +type DynamicToolPart = Extract; + +type ToolPartState = DynamicToolPart['state']; + +interface ToolCallGroupProps { + tools: DynamicToolPart[]; + approveTools: (approvals: Array<{ tool_call_id: string; approved: boolean; feedback?: string | null; edited_script?: string | null }>) => Promise; +} + +// --------------------------------------------------------------------------- +// Visual helpers +// --------------------------------------------------------------------------- + +function StatusIcon({ state }: { state: ToolPartState }) { + switch (state) { + case 'approval-requested': + return ; + case 'output-available': + return ; + case 'output-error': + return ; + case 'output-denied': + return ; + case 'input-streaming': + case 'input-available': + default: + return ; + } +} + +function statusLabel(state: ToolPartState): string | null { + switch (state) { + case 'approval-requested': return 'awaiting approval'; + case 'input-streaming': + case 'input-available': return 'running'; + case 'output-denied': return 'denied'; + case 'output-error': return 'error'; + default: return null; + } +} + +function statusColor(state: ToolPartState): string { + switch (state) { + case 'approval-requested': return 'var(--accent-yellow)'; + case 'output-available': return 'var(--accent-green)'; + case 'output-error': return 'var(--accent-red)'; + case 'output-denied': return 'var(--muted-text)'; + default: return 'var(--accent-yellow)'; + } +} + +// --------------------------------------------------------------------------- +// Inline approval UI (per-tool) +// --------------------------------------------------------------------------- + +function InlineApproval({ + toolCallId, + toolName, + input, + scriptLabel, + onResolve, +}: { + toolCallId: string; + toolName: string; + input: unknown; + scriptLabel: string; + onResolve: (toolCallId: string, approved: boolean, feedback?: string) => void; +}) { + const [feedback, setFeedback] = useState(''); + const args = input as Record | undefined; + const { setPanel, getEditedScript } = useAgentStore(); + const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore(); + const hasEditedScript = !!getEditedScript(toolCallId); + + const handleScriptClick = useCallback(() => { + if (toolName === 'hf_jobs' && args?.script) { + const scriptContent = getEditedScript(toolCallId) || String(args.script); + setPanel( + { title: scriptLabel, script: { content: scriptContent, language: 'python' }, parameters: { tool_call_id: toolCallId } }, + 'script', + true, + ); + setRightPanelOpen(true); + setLeftSidebarOpen(false); + } + }, [toolCallId, toolName, args, scriptLabel, setPanel, getEditedScript, setRightPanelOpen, setLeftSidebarOpen]); + + return ( + + {toolName === 'hf_jobs' && args && ( + + + Execute {scriptLabel.replace('Script', 'Job')} on{' '} + + {String(args.hardware_flavor || 'default')} + + {!!args.timeout && ( + <> with timeout + {String(args.timeout)} + + )} + + {typeof args.script === 'string' && args.script && ( + + + {String(args.script).trim()} + + + Click to view & edit + + + )} + + )} + + + setFeedback(e.target.value)} + variant="outlined" + sx={{ + '& .MuiOutlinedInput-root': { + bgcolor: 'var(--hover-bg)', + fontFamily: 'inherit', + fontSize: '0.8rem', + '& fieldset': { borderColor: 'var(--tool-border)' }, + '&:hover fieldset': { borderColor: 'var(--border-hover)' }, + '&.Mui-focused fieldset': { borderColor: 'var(--accent-yellow)' }, + }, + '& .MuiOutlinedInput-input': { + color: 'var(--text)', + '&::placeholder': { color: 'var(--muted-text)', opacity: 0.7 }, + }, + }} + /> + onResolve(toolCallId, false, feedback || 'Rejected by user')} + disabled={!feedback} + size="small" + sx={{ + color: 'var(--accent-red)', + border: '1px solid var(--tool-border)', + borderRadius: '6px', + '&:hover': { bgcolor: 'rgba(224,90,79,0.1)', borderColor: 'var(--accent-red)' }, + '&.Mui-disabled': { color: 'var(--muted-text)', opacity: 0.3 }, + }} + > + + + + + + + + + + ); +} + +// --------------------------------------------------------------------------- +// Main component +// --------------------------------------------------------------------------- + +export default function ToolCallGroup({ tools, approveTools }: ToolCallGroupProps) { + const { setPanel, lockPanel, getJobUrl, getEditedScript } = useAgentStore(); + const { setRightPanelOpen, setLeftSidebarOpen } = useLayoutStore(); + + // ── Batch approval state ────────────────────────────────────────── + const pendingTools = useMemo( + () => tools.filter(t => t.state === 'approval-requested'), + [tools], + ); + + const [decisions, setDecisions] = useState>({}); + const [isSubmitting, setIsSubmitting] = useState(false); + const submittingRef = useRef(false); + + const { scriptLabelMap, toolDisplayMap } = useMemo(() => { + const hfJobs = tools.filter(t => t.toolName === 'hf_jobs' && (t.input as Record)?.script); + const scriptMap: Record = {}; + const displayMap: Record = {}; + for (let i = 0; i < hfJobs.length; i++) { + const id = hfJobs[i].toolCallId; + if (hfJobs.length > 1) { + scriptMap[id] = `Script ${i + 1}`; + displayMap[id] = `hf_jobs #${i + 1}`; + } else { + scriptMap[id] = 'Script'; + displayMap[id] = 'hf_jobs'; + } + } + return { scriptLabelMap: scriptMap, toolDisplayMap: displayMap }; + }, [tools]); + + // ── Send all decisions as a single batch ────────────────────────── + const sendBatch = useCallback( + async (batch: Record) => { + if (submittingRef.current) return; + submittingRef.current = true; + setIsSubmitting(true); + + const approvals = Object.entries(batch).map(([toolCallId, d]) => { + const editedScript = d.approved ? (getEditedScript(toolCallId) ?? null) : null; + if (editedScript) { + logger.log(`Sending edited script for ${toolCallId} (${editedScript.length} chars)`); + } + return { + tool_call_id: toolCallId, + approved: d.approved, + feedback: d.approved ? null : (d.feedback || 'Rejected by user'), + edited_script: editedScript, + }; + }); + + const ok = await approveTools(approvals); + if (ok) { + lockPanel(); + } else { + logger.error('Batch approval failed'); + submittingRef.current = false; + setIsSubmitting(false); + } + }, + [approveTools, lockPanel, getEditedScript], + ); + + const handleApproveAll = useCallback(() => { + const batch: Record = {}; + for (const t of pendingTools) batch[t.toolCallId] = { approved: true }; + sendBatch(batch); + }, [pendingTools, sendBatch]); + + const handleRejectAll = useCallback(() => { + const batch: Record = {}; + for (const t of pendingTools) batch[t.toolCallId] = { approved: false }; + sendBatch(batch); + }, [pendingTools, sendBatch]); + + const handleIndividualDecision = useCallback( + (toolCallId: string, approved: boolean, feedback?: string) => { + setDecisions(prev => { + const next = { ...prev, [toolCallId]: { approved, feedback } }; + if (pendingTools.every(t => next[t.toolCallId])) { + queueMicrotask(() => sendBatch(next)); + } + return next; + }); + }, + [pendingTools, sendBatch], + ); + + const undoDecision = useCallback((toolCallId: string) => { + setDecisions(prev => { + const next = { ...prev }; + delete next[toolCallId]; + return next; + }); + }, []); + + // ── Panel click handler ─────────────────────────────────────────── + const handleClick = useCallback( + (tool: DynamicToolPart) => { + const args = tool.input as Record | undefined; + const displayName = toolDisplayMap[tool.toolCallId] || tool.toolName; + + if (tool.toolName === 'hf_jobs' && args?.script) { + const hasOutput = (tool.state === 'output-available' || tool.state === 'output-error') && tool.output; + const scriptContent = getEditedScript(tool.toolCallId) || String(args.script); + setPanel( + { + title: displayName, + script: { content: scriptContent, language: 'python' }, + ...(hasOutput ? { output: { content: String(tool.output), language: 'markdown' } } : {}), + parameters: { tool_call_id: tool.toolCallId }, + }, + hasOutput ? 'output' : 'script', + ); + setRightPanelOpen(true); + setLeftSidebarOpen(false); + return; + } + + if ((tool.state === 'output-available' || tool.state === 'output-error') && tool.output) { + let language = 'text'; + const content = String(tool.output); + if (content.trim().startsWith('{') || content.trim().startsWith('[')) language = 'json'; + else if (content.includes('```')) language = 'markdown'; + + setPanel({ title: displayName, output: { content, language } }, 'output'); + setRightPanelOpen(true); + } else if (args) { + const content = JSON.stringify(args, null, 2); + setPanel({ title: displayName, output: { content, language: 'json' } }, 'output'); + setRightPanelOpen(true); + } + }, + [toolDisplayMap, setPanel, getEditedScript, setRightPanelOpen, setLeftSidebarOpen], + ); + + // ── Parse hf_jobs metadata from output ──────────────────────────── + function parseJobMeta(output: unknown): { jobUrl?: string; jobStatus?: string } { + if (typeof output !== 'string') return {}; + const urlMatch = output.match(/\*\*View at:\*\*\s*(https:\/\/[^\s\n]+)/); + const statusMatch = output.match(/\*\*Final Status:\*\*\s*([^\n]+)/); + return { + jobUrl: urlMatch?.[1], + jobStatus: statusMatch?.[1]?.trim(), + }; + } + + // ── Render ──────────────────────────────────────────────────────── + const decidedCount = pendingTools.filter(t => decisions[t.toolCallId]).length; + + return ( + + {/* Batch approval header — hidden once user starts deciding individually */} + {pendingTools.length > 1 && !isSubmitting && decidedCount === 0 && ( + + + {`${pendingTools.length} tool${pendingTools.length > 1 ? 's' : ''} pending`} + + + + + )} + + {/* Tool list */} + }> + {tools.map((tool) => { + const state = tool.state; + const isPending = state === 'approval-requested'; + const clickable = + state === 'output-available' || + state === 'output-error' || + !!tool.input; + const localDecision = decisions[tool.toolCallId]; + + const displayState = isPending && localDecision + ? (localDecision.approved ? 'input-available' : 'output-denied') + : state; + const label = statusLabel(displayState as ToolPartState); + + // Parse job metadata from hf_jobs output and store + const jobUrlFromStore = tool.toolName === 'hf_jobs' ? getJobUrl(tool.toolCallId) : undefined; + const jobMetaFromOutput = tool.toolName === 'hf_jobs' && tool.state === 'output-available' + ? parseJobMeta(tool.output) + : {}; + + // Combine job URL from store (available immediately) with output metadata (available at completion) + const jobMeta = { + jobUrl: jobUrlFromStore || jobMetaFromOutput.jobUrl, + jobStatus: jobMetaFromOutput.jobStatus, + }; + + return ( + + {/* Main tool row */} + !isPending && handleClick(tool)} + sx={{ + px: 1.5, + py: 1, + cursor: isPending ? 'default' : clickable ? 'pointer' : 'default', + transition: 'background-color 0.15s', + '&:hover': clickable && !isPending ? { bgcolor: 'var(--hover-bg)' } : {}, + }} + > + + + + {toolDisplayMap[tool.toolCallId] || tool.toolName} + + + {/* Status chip (non hf_jobs, or hf_jobs without final status) */} + {label && !(tool.toolName === 'hf_jobs' && jobMeta.jobStatus) && ( + + )} + + {/* HF Jobs: final status chip from job metadata */} + {tool.toolName === 'hf_jobs' && jobMeta.jobStatus && ( + + )} + + {/* View on HF link — single place, shown whenever URL is available */} + {tool.toolName === 'hf_jobs' && jobMeta.jobUrl && ( + e.stopPropagation()} + sx={{ + display: 'inline-flex', + alignItems: 'center', + gap: 0.5, + color: 'var(--accent-yellow)', + fontSize: '0.68rem', + textDecoration: 'none', + ml: 0.5, + '&:hover': { textDecoration: 'underline' }, + }} + > + + View on HF + + )} + + {clickable && !isPending && ( + + )} + + + + {/* Per-tool approval: undecided */} + {isPending && !localDecision && !isSubmitting && ( + + )} + + {/* Per-tool approval: locally decided (undo available) */} + {isPending && localDecision && !isSubmitting && ( + + + {localDecision.approved + ? 'Marked for approval' + : `Marked for rejection${localDecision.feedback ? `: ${localDecision.feedback}` : ''}`} + + + + )} + + ); + })} + + + ); +} diff --git a/frontend/src/components/Chat/UserMessage.tsx b/frontend/src/components/Chat/UserMessage.tsx new file mode 100644 index 0000000000000000000000000000000000000000..4bcea89867df46d92e2630b12ac9978bc5da476f --- /dev/null +++ b/frontend/src/components/Chat/UserMessage.tsx @@ -0,0 +1,105 @@ +import { Box, Stack, Typography, IconButton, Tooltip } from '@mui/material'; +import CloseIcon from '@mui/icons-material/Close'; +import type { UIMessage } from 'ai'; +import type { MessageMeta } from '@/types/agent'; + +interface UserMessageProps { + message: UIMessage; + isLastTurn?: boolean; + onUndoTurn?: () => void; + isProcessing?: boolean; +} + +function extractText(message: UIMessage): string { + return message.parts + .filter((p): p is Extract => p.type === 'text') + .map(p => p.text) + .join(''); +} + +export default function UserMessage({ + message, + isLastTurn = false, + onUndoTurn, + isProcessing = false, +}: UserMessageProps) { + const showUndo = isLastTurn && !isProcessing && !!onUndoTurn; + const text = extractText(message); + const meta = message.metadata as MessageMeta | undefined; + const timeStr = meta?.createdAt + ? new Date(meta.createdAt).toLocaleTimeString([], { hour: '2-digit', minute: '2-digit' }) + : null; + return ( + + {showUndo && ( + + + + + + + + )} + + + + {text} + + + {timeStr && ( + + {timeStr} + + )} + + + ); +} diff --git a/frontend/src/components/CodePanel/CodePanel.tsx b/frontend/src/components/CodePanel/CodePanel.tsx index 3cda17dab23b7e2bbc4940b3ca171dc9b776e661..4b38fbf5f7bda3664bfacac65bb20d67d2e6a626 100644 --- a/frontend/src/components/CodePanel/CodePanel.tsx +++ b/frontend/src/components/CodePanel/CodePanel.tsx @@ -1,138 +1,463 @@ -import { useRef, useEffect, useMemo } from 'react'; -import { Box, Typography, IconButton } from '@mui/material'; +import { useRef, useEffect, useMemo, useState, useCallback } from 'react'; +import { Box, Stack, Typography, IconButton, Button, Tooltip } from '@mui/material'; import CloseIcon from '@mui/icons-material/Close'; import RadioButtonUncheckedIcon from '@mui/icons-material/RadioButtonUnchecked'; import CheckCircleIcon from '@mui/icons-material/CheckCircle'; import PlayCircleOutlineIcon from '@mui/icons-material/PlayCircleOutline'; import CodeIcon from '@mui/icons-material/Code'; -import TerminalIcon from '@mui/icons-material/Terminal'; import ArticleIcon from '@mui/icons-material/Article'; +import EditIcon from '@mui/icons-material/Edit'; +import UndoIcon from '@mui/icons-material/Undo'; +import ContentCopyIcon from '@mui/icons-material/ContentCopy'; +import CheckIcon from '@mui/icons-material/Check'; import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter'; -import { vscDarkPlus } from 'react-syntax-highlighter/dist/esm/styles/prism'; +import { vscDarkPlus, vs } from 'react-syntax-highlighter/dist/esm/styles/prism'; import ReactMarkdown from 'react-markdown'; import remarkGfm from 'remark-gfm'; import { useAgentStore } from '@/store/agentStore'; import { useLayoutStore } from '@/store/layoutStore'; import { processLogs } from '@/utils/logProcessor'; +import type { PanelView } from '@/store/agentStore'; + +// ── Helpers ────────────────────────────────────────────────────── + +function PlanStatusIcon({ status }: { status: string }) { + if (status === 'completed') return ; + if (status === 'in_progress') return ; + return ; +} + +// ── Markdown styles (adapts via CSS vars) ──────────────────────── +const markdownSx = { + color: 'var(--text)', + fontSize: '13px', + lineHeight: 1.6, + '& p': { m: 0, mb: 1.5, '&:last-child': { mb: 0 } }, + '& pre': { + bgcolor: 'var(--code-bg)', + p: 1.5, + borderRadius: 1, + overflow: 'auto', + fontSize: '12px', + border: '1px solid var(--tool-border)', + }, + '& code': { + bgcolor: 'var(--hover-bg)', + px: 0.5, + py: 0.25, + borderRadius: 0.5, + fontSize: '12px', + fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', + }, + '& pre code': { bgcolor: 'transparent', p: 0 }, + '& a': { + color: 'var(--accent-yellow)', + textDecoration: 'none', + '&:hover': { textDecoration: 'underline' }, + }, + '& ul, & ol': { pl: 2.5, my: 1 }, + '& li': { mb: 0.5 }, + '& table': { + borderCollapse: 'collapse', + width: '100%', + my: 2, + fontSize: '12px', + fontFamily: 'ui-monospace, SFMono-Regular, Menlo, Monaco, monospace', + }, + '& th': { + borderBottom: '2px solid var(--border-hover)', + textAlign: 'left', + p: 1, + fontWeight: 600, + }, + '& td': { + borderBottom: '1px solid var(--tool-border)', + p: 1, + }, + '& h1, & h2, & h3, & h4': { mt: 2, mb: 1, fontWeight: 600 }, + '& h1': { fontSize: '1.25rem' }, + '& h2': { fontSize: '1.1rem' }, + '& h3': { fontSize: '1rem' }, + '& blockquote': { + borderLeft: '3px solid var(--accent-yellow)', + pl: 2, + ml: 0, + color: 'var(--muted-text)', + }, +} as const; + +// ── View toggle button ────────────────────────────────────────── + +function ViewToggle({ view, icon, label, isActive, onClick }: { + view: PanelView; + icon: React.ReactNode; + label: string; + isActive: boolean; + onClick: (v: PanelView) => void; +}) { + return ( + onClick(view)} + sx={{ + display: 'flex', + alignItems: 'center', + gap: 0.5, + px: 1.5, + py: 0.75, + borderRadius: 1, + cursor: 'pointer', + fontSize: '0.7rem', + fontWeight: 600, + textTransform: 'uppercase', + letterSpacing: '0.05em', + whiteSpace: 'nowrap', + color: isActive ? 'var(--text)' : 'var(--muted-text)', + bgcolor: isActive ? 'var(--tab-active-bg)' : 'transparent', + border: '1px solid', + borderColor: isActive ? 'var(--tab-active-border)' : 'transparent', + transition: 'all 0.15s ease', + '&:hover': { bgcolor: 'var(--tab-hover-bg)' }, + }} + > + {icon} + {label} + + ); +} + +// ── Component ──────────────────────────────────────────────────── export default function CodePanel() { - const { panelContent, panelTabs, activePanelTab, setActivePanelTab, removePanelTab, plan } = useAgentStore(); - const { setRightPanelOpen } = useLayoutStore(); + const { panelData, panelView, panelEditable, setPanelView, updatePanelScript, setEditedScript, plan } = + useAgentStore(); + const { setRightPanelOpen, themeMode } = useLayoutStore(); const scrollRef = useRef(null); + const textareaRef = useRef(null); + const [isEditing, setIsEditing] = useState(false); + const [editedContent, setEditedContent] = useState(''); + const [originalContent, setOriginalContent] = useState(''); + const [copied, setCopied] = useState(false); + + const isDark = themeMode === 'dark'; + const syntaxTheme = isDark ? vscDarkPlus : vs; + + const activeSection = panelView === 'script' ? panelData?.script : panelData?.output; + const hasScript = !!panelData?.script; + const hasOutput = !!panelData?.output; + const hasBothViews = hasScript && hasOutput; + + const isEditableScript = panelView === 'script' && panelEditable; + const hasUnsavedChanges = isEditing && editedContent !== originalContent; + + // Sync edited content when panel data changes + useEffect(() => { + if (panelData?.script?.content && panelView === 'script' && panelEditable) { + setOriginalContent(panelData.script.content); + if (!isEditing) { + setEditedContent(panelData.script.content); + } + } + }, [panelData?.script?.content, panelView, panelEditable, isEditing]); + + // Exit editing when switching away from script view or losing editable + useEffect(() => { + if (!isEditableScript && isEditing) { + setIsEditing(false); + } + }, [isEditableScript, isEditing]); + + const handleStartEdit = useCallback(() => { + if (panelData?.script?.content) { + setEditedContent(panelData.script.content); + setOriginalContent(panelData.script.content); + setIsEditing(true); + setTimeout(() => textareaRef.current?.focus(), 0); + } + }, [panelData?.script?.content]); + + const handleCancelEdit = useCallback(() => { + setEditedContent(originalContent); + setIsEditing(false); + }, [originalContent]); + + const handleSaveEdit = useCallback(() => { + if (editedContent !== originalContent) { + updatePanelScript(editedContent); + const toolCallId = panelData?.parameters?.tool_call_id as string | undefined; + if (toolCallId) { + setEditedScript(toolCallId, editedContent); + } + setOriginalContent(editedContent); + } + setIsEditing(false); + }, [panelData?.parameters?.tool_call_id, editedContent, originalContent, updatePanelScript, setEditedScript]); - // Get the active tab content, or fall back to panelContent for backwards compatibility - const activeTab = panelTabs.find(t => t.id === activePanelTab); - const currentContent = activeTab || panelContent; + const handleCopy = useCallback(async () => { + const contentToCopy = isEditing ? editedContent : (activeSection?.content || ''); + if (contentToCopy) { + try { + await navigator.clipboard.writeText(contentToCopy); + setCopied(true); + setTimeout(() => setCopied(false), 2000); + } catch (err) { + console.error('Failed to copy:', err); + } + } + }, [isEditing, editedContent, activeSection?.content]); const displayContent = useMemo(() => { - if (!currentContent?.content) return ''; - // Apply log processing only for text/logs, not for code/json - if (!currentContent.language || currentContent.language === 'text') { - return processLogs(currentContent.content); + if (!activeSection?.content) return ''; + if (!activeSection.language || activeSection.language === 'text') { + return processLogs(activeSection.content); } - return currentContent.content; - }, [currentContent?.content, currentContent?.language]); + return activeSection.content; + }, [activeSection?.content, activeSection?.language]); useEffect(() => { - // Auto-scroll only for logs tab - if (scrollRef.current && activePanelTab === 'logs') { + if (scrollRef.current && panelView === 'output') { scrollRef.current.scrollTop = scrollRef.current.scrollHeight; } - }, [displayContent, activePanelTab]); + }, [displayContent, panelView]); + + // ── Syntax-highlighted code block (DRY) ──────────────────────── + const renderSyntaxBlock = (language: string) => ( + + {displayContent} + + ); + + // ── Content renderer ─────────────────────────────────────────── + const renderContent = () => { + if (!activeSection?.content) { + return ( + + NO CONTENT TO DISPLAY + + ); + } + + if (isEditing && isEditableScript) { + return ( + + + {editedContent || ' '} + +