Spaces:
Sleeping
Sleeping
| """ | |
| Re-frame: Cognitive Reframing Assistant | |
| A Gradio-based CBT tool for identifying and reframing cognitive distortions | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import os | |
| import shutil | |
| from datetime import datetime | |
| from typing import Optional | |
| import gradio as gr | |
| from huggingface_hub import whoami as _hf_whoami | |
| # Import our CBT knowledge base | |
| from cbt_knowledge import ( | |
| COGNITIVE_DISTORTIONS, | |
| find_similar_situations, | |
| ) | |
| # Import UI components | |
| from ui_components.landing import create_landing_tab | |
| from ui_components.learn import create_learn_tab | |
| # Agentic LLM support (Hugging Face Inference API) | |
| try: | |
| from agents import CBTAgent | |
| AGENT_AVAILABLE = True | |
| except Exception as e: | |
| print(f"Error importing CBTAgent: {e}") | |
| CBTAgent = None # type: ignore | |
| AGENT_AVAILABLE = False | |
| # Load translations | |
def load_translations():
    """Load locale dictionaries for internationalization.

    Tries to read ``locales/<lang>.json`` for each supported language.
    A language whose file is missing *or* contains invalid JSON falls
    back to the embedded translations below, so the app always has a
    complete string set for both ``en`` and ``es``.

    Returns:
        dict: Mapping of language code ('en', 'es') to its translation dict.
    """
    translations = {}
    for lang in ['en', 'es']:
        try:
            with open(f'locales/{lang}.json', encoding='utf-8') as f:
                translations[lang] = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            # A corrupt locale file is treated the same as a missing one:
            # log it and rely on the embedded fallback below.
            print(f"Error loading translations for {lang}: {type(e).__name__}")
    # Fallback translations (used whenever the file above could not be loaded)
    if 'en' not in translations:
        translations['en'] = {
            "app_title": "🧠 re-frame: Cognitive Reframing Assistant",
            "app_description": "Using CBT principles to help you find balanced perspectives",
            "welcome": {
                "title": "Welcome to re-frame",
                "subtitle": "Find a kinder perspective",
                "description": (
                    "Using ideas from Cognitive Behavioral Therapy (CBT), we help you notice "
                    "thinking patterns and explore gentler, more balanced perspectives."
                ),
                "how_it_works": "How it works",
                "step1": "Share your thoughts",
                "step1_desc": "Tell us what's on your mind",
                "step2": "Identify patterns",
                "step2_desc": "We'll help spot thinking traps",
                "step3": "Find balance",
                "step3_desc": "Explore alternative perspectives",
                "start_chat": "Start Chat",
                "disclaimer": "Important: This is a self-help tool, not therapy or medical advice.",
                "privacy": "Privacy: No data is stored beyond your session.",
            },
            "chat": {
                "title": "Chat",
                "placeholder": "Share what's on your mind...",
                "send": "Send",
                "clear": "New Session",
                "thinking": "Thinking...",
                "distortions_found": "Thinking patterns identified:",
                "reframe_suggestion": "Alternative perspective:",
                "similar_situations": "Similar situations:",
                "try_this": "You might try:",
            },
            "learn": {
                "title": "Learn",
                "select_distortion": "Select a thinking pattern to explore",
                "definition": "Definition",
                "examples": "Common Examples",
                "strategies": "Reframing Strategies",
                "actions": "Small Steps to Try",
            },
        }
    if 'es' not in translations:
        translations['es'] = {
            "app_title": "🧠 re-frame: Asistente de Reencuadre Cognitivo",
            "app_description": (
                "Usando principios de TCC para ayudarte a encontrar perspectivas equilibradas"
            ),
            "welcome": {
                "title": "Bienvenido a re-frame",
                "subtitle": "Encuentra una perspectiva más amable",
                "description": (
                    "Usando ideas de la Terapia Cognitivo-Conductual (TCC), te ayudamos a notar "
                    "patrones de pensamiento y explorar perspectivas más gentiles y equilibradas."
                ),
                "how_it_works": "Cómo funciona",
                "step1": "Comparte tus pensamientos",
                "step1_desc": "Cuéntanos qué piensas",
                "step2": "Identifica patrones",
                "step2_desc": "Te ayudamos a detectar trampas mentales",
                "step3": "Encuentra balance",
                "step3_desc": "Explora perspectivas alternativas",
                "start_chat": "Iniciar Chat",
                "disclaimer": (
                    "Importante: Esta es una herramienta de autoayuda, "
                    "no terapia ni consejo médico."
                ),
                "privacy": "Privacidad: No se almacenan datos más allá de tu sesión.",
            },
            "chat": {
                "title": "Chat",
                "placeholder": "Comparte lo que piensas...",
                "send": "Enviar",
                "clear": "Nueva Sesión",
                "thinking": "Pensando...",
                "distortions_found": "Patrones de pensamiento identificados:",
                "reframe_suggestion": "Perspectiva alternativa:",
                "similar_situations": "Situaciones similares:",
                "try_this": "Podrías intentar:",
            },
            "learn": {
                "title": "Aprender",
                "select_distortion": "Selecciona un patrón de pensamiento para explorar",
                "definition": "Definición",
                "examples": "Ejemplos Comunes",
                "strategies": "Estrategias de Reencuadre",
                "actions": "Pequeños Pasos a Intentar",
            },
        }
    return translations
class CBTChatbot:
    """Conversation manager that applies CBT analysis to user messages."""

    def __init__(self, language='en', memory_size: int = 6):
        self.language = language
        self.translations = load_translations()
        # Resolve the locale bundle, defaulting to English when the
        # requested language is unavailable.
        self.t = self.translations.get(language, self.translations['en'])
        self.conversation_history: list[list[str]] = []
        self.identified_distortions: list[tuple[str, float]] = []
        # Keep at least two exchanges of context, regardless of input.
        self.memory_size = max(2, int(memory_size))

    def _history_to_context(self, history) -> list[dict]:
        """Normalize Chatbot history into agent context.

        Accepts both the new Gradio messages format
        ([{role, content}, ...]) and the legacy pair format
        ([[user, assistant], ...]). The result is a list of
        {user, assistant} dicts trimmed to the last ``memory_size``
        exchanges.
        """
        if not history:
            return []
        pairs: list[dict] = []
        if isinstance(history, list) and isinstance(history[0], dict):
            # Messages format: pair each user turn with the next assistant turn.
            awaiting = None
            for entry in history:
                role = str(entry.get("role", ""))
                text = str(entry.get("content", ""))
                if role == "user":
                    awaiting = text
                elif role == "assistant" and awaiting is not None:
                    pairs.append({"user": awaiting, "assistant": text})
                    awaiting = None
        else:
            # Legacy tuple format: each entry is already a (user, assistant) pair.
            for entry in history:
                if isinstance(entry, (list, tuple)) and len(entry) == 2:
                    pairs.append({"user": entry[0] or "", "assistant": entry[1] or ""})
        return pairs[-self.memory_size:]

    def process_message(
        self,
        message: str,
        history: list[list[str]],
        use_agent: bool = False,
        agent: Optional[CBTAgent] = None,
    ) -> tuple[list[list[str]], str, str, str]:
        """Run one CBT exchange through the agent and build display panels.

        Returns:
            A 4-tuple of (updated chat history, distortions markdown,
            reframe markdown, similar-situations markdown).
        """
        if not message or message.strip() == "":
            return history or [], "", "", ""
        history = history or []
        # Agent-only design: the local heuristic fallback was removed.
        if not (use_agent and agent is not None):
            history.append(
                [message, "Agent-only mode: please enable the agent to generate responses."]
            )
            return history, "", "", ""
        try:
            analysis = agent.analyze_thought(message)
            response = agent.generate_response(
                message, context=self._history_to_context(history)
            )
            found = analysis.get("distortions", [])
            distortions_display = self._format_distortions(found)
            reframe_display = analysis.get("reframe", "")
            top_code = found[0][0] if found else None
            situations_display = (
                self._format_similar_situations(top_code) if top_code else ""
            )
        except Exception as e:
            # Surface the failure to the user instead of falling back.
            history.append([message, f"Agent error: {e}"])
            return history, "", "", ""
        history.append([message, response])
        # Cap the visible history so the UI stays light.
        if len(history) > self.memory_size:
            history = history[-self.memory_size:]
        return history, distortions_display, reframe_display, situations_display

    def _format_distortions(self, distortions: list[tuple[str, float]]) -> str:
        """Render up to three detected distortions as markdown."""
        if not distortions:
            return ""
        out = [f"### {self.t['chat']['distortions_found']}\n"]
        for code, confidence in distortions[:3]:
            # Look up the distortion entry whose 'code' matches.
            match = next(
                (info for info in COGNITIVE_DISTORTIONS.values() if info['code'] == code),
                None,
            )
            if match is not None:
                out.append(f"**{match['name']}** ({confidence * 100:.0f}% match)")
                out.append(f"*{match['definition']}*\n")
        return "\n".join(out)

    def _format_similar_situations(self, distortion_code: str) -> str:
        """Render example situations for the given distortion as markdown."""
        examples = find_similar_situations(distortion_code, num_situations=2)
        if not examples:
            return ""
        out = [f"### {self.t['chat']['similar_situations']}\n"]
        for idx, ex in enumerate(examples, 1):
            out.append(f"**Example {idx}:** {ex['situation']}")
            out.append(f"*Distorted:* \"{ex['distorted']}\"")
            out.append(f"*Reframed:* \"{ex['reframed']}\"\n")
        return "\n".join(out)

    def clear_session(self):
        """Reset per-session state and return empty values for the UI panels."""
        self.conversation_history = []
        self.identified_distortions = []
        return [], "", "", ""
| def create_app(language='en'): | |
| """Create and configure the Gradio application""" | |
| # Initialize chatbot | |
| chatbot = CBTChatbot(language) | |
| t = chatbot.t | |
| # Custom CSS for better styling | |
| custom_css = """ | |
| .gradio-container { | |
| font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; | |
| } | |
| .gr-button-primary { | |
| background-color: #2563eb !important; | |
| border-color: #2563eb !important; | |
| } | |
| .gr-button-primary:hover { | |
| background-color: #1e40af !important; | |
| } | |
| .info-box { | |
| background-color: #f0f9ff; | |
| border: 1px solid #3b82f6; | |
| border-radius: 8px; | |
| padding: 12px; | |
| margin: 8px 0; | |
| } | |
| """ | |
| with gr.Blocks(title=t['app_title'], theme=gr.themes.Soft(), css=custom_css) as app: | |
| gr.Markdown(f"# {t['app_title']}") | |
| gr.Markdown(f"*{t['app_description']}*") | |
| with gr.Tabs(): | |
| # Welcome Tab | |
| with gr.Tab(t['welcome']['title']): | |
| create_landing_tab(t['welcome']) | |
| # Chat Tab | |
| with gr.Tab(t['chat']['title']): | |
| # Settings row (agentic only) | |
| with gr.Row(): | |
| gr.LoginButton() | |
| billing_notice = gr.Markdown("") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| chatbot_ui = gr.Chatbot(height=400, label="Conversation", type='messages') | |
| with gr.Row(): | |
| msg_input = gr.Textbox( | |
| label="", placeholder=t['chat']['placeholder'], scale=4 | |
| ) | |
| send_btn = gr.Button(t['chat']['send'], variant="primary", scale=1) | |
| clear_btn = gr.Button(t['chat']['clear'], variant="secondary") | |
| with gr.Column(scale=1): | |
| gr.Markdown("### Analysis") | |
| distortions_output = gr.Markdown(label="Patterns Detected") | |
| reframe_output = gr.Markdown(label="Reframe Suggestion") | |
| situations_output = gr.Markdown(label="Similar Situations") | |
| # Owner-only Admin controls (gated at load) | |
| admin_accordion = gr.Accordion( | |
| "Owner Controls", open=False, visible=False | |
| ) | |
| with admin_accordion: | |
| # Locked message for non-owners (kept hidden unless needed) | |
| chat_locked_panel = gr.Markdown( | |
| "### Owner only\nPlease log in with your Hugging Face account.", | |
| visible=False, | |
| ) | |
| chat_admin_panel = gr.Column(visible=False) | |
| with chat_admin_panel: | |
| gr.Markdown("## Admin Dashboard") | |
| admin_summary = gr.Markdown("") | |
| admin_limit_info = gr.Markdown("") | |
| # Owner-only model selection | |
| model_dropdown = gr.Dropdown( | |
| label="Model (HF)", | |
| choices=[ | |
| "meta-llama/Llama-3.1-8B-Instruct", | |
| "meta-llama/Llama-3.1-70B-Instruct", | |
| "Qwen/Qwen2.5-7B-Instruct", | |
| "mistralai/Mixtral-8x7B-Instruct-v0.1", | |
| "google/gemma-2-9b-it", | |
| ], | |
| value=os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct"), | |
| allow_custom_value=True, | |
| info="Only visible to owner. Requires HF Inference API token.", | |
| ) | |
| with gr.Row(): | |
| override_tb = gr.Textbox( | |
| label="Per-user interaction limit override (blank to clear)" | |
| ) | |
| set_override_btn = gr.Button( | |
| "Set Limit Override", variant="secondary" | |
| ) | |
| refresh_btn = gr.Button( | |
| "Refresh Metrics", variant="secondary" | |
| ) | |
| gr.Markdown("### Debug") | |
| owner_identity_md = gr.Markdown("") | |
| with gr.Row(): | |
| identity_btn = gr.Button( | |
| "Refresh Identity", variant="secondary" | |
| ) | |
| storage_btn = gr.Button( | |
| "Check /data", variant="secondary" | |
| ) | |
| storage_info_md = gr.Markdown("") | |
| # Internal state for agent instance, selected model, and agentic enable flag | |
| agent_state = gr.State(value=None) | |
| model_state = gr.State( | |
| value=os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct") | |
| ) | |
| agentic_enabled_state = gr.State(value=True) | |
| # Admin runtime settings (e.g., per-user limit override) | |
| admin_state = gr.State(value={"per_user_limit_override": None}) | |
| # Connect chat interface (streaming) | |
| def _ensure_hf_token_env(): | |
| # Honor either HF_TOKEN or HUGGINGFACEHUB_API_TOKEN | |
| token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") | |
| if token and not os.getenv("HUGGINGFACEHUB_API_TOKEN"): | |
| os.environ["HUGGINGFACEHUB_API_TOKEN"] = token | |
| def _stream_chunks(text: str, chunk_words: int = 12): | |
| words = (text or "").split() | |
| buf = [] | |
| for i, w in enumerate(words, 1): | |
| buf.append(w) | |
| if i % chunk_words == 0: | |
| yield " ".join(buf) | |
| buf = [] | |
| if buf: | |
| yield " ".join(buf) | |
| # Budget guard helpers | |
| def _get_call_log_path(): | |
| return os.getenv("AGENT_CALL_LOG_PATH", "/tmp/agent_calls.json") | |
| # Simple privacy-preserving metrics (no raw PII/content) | |
| def _get_metrics_path(): | |
| return os.getenv("APP_METRICS_PATH", "/tmp/app_metrics.json") | |
| def _load_call_log(): | |
| try: | |
| with open(_get_call_log_path(), encoding="utf-8") as f: | |
| return json.load(f) | |
| except Exception: | |
| return {} | |
| def _load_metrics(): | |
| try: | |
| with open(_get_metrics_path(), encoding="utf-8") as f: | |
| return json.load(f) | |
| except Exception: | |
| return {} | |
| def _save_call_log(data): | |
| try: | |
| with open(_get_call_log_path(), "w", encoding="utf-8") as f: | |
| json.dump(data, f) | |
| except Exception: | |
| pass | |
| def _save_metrics(data): | |
| try: | |
| with open(_get_metrics_path(), "w", encoding="utf-8") as f: | |
| json.dump(data, f) | |
| except Exception: | |
| pass | |
| def _today_key(): | |
| return datetime.utcnow().strftime("%Y-%m-%d") | |
| # Metrics helpers | |
| def _metrics_today(): | |
| m = _load_metrics() | |
| return m.get(_today_key(), {}) | |
| def _write_metrics_today(d): | |
| m = _load_metrics() | |
| m[_today_key()] = d | |
| _save_metrics(m) | |
| def _inc_metric(key: str, inc: int = 1): | |
| d = _metrics_today() | |
| d[key] = int(d.get(key, 0)) + inc | |
| _write_metrics_today(d) | |
| def _record_distortion_counts(codes: list[str]): | |
| if not codes: | |
| return | |
| d = _metrics_today() | |
| dist = d.get("distortion_counts", {}) | |
| if not isinstance(dist, dict): | |
| dist = {} | |
| for c in codes: | |
| dist[c] = int(dist.get(c, 0)) + 1 | |
| d["distortion_counts"] = dist | |
| _write_metrics_today(d) | |
| def _record_response_chars(n: int): | |
| d = _metrics_today() | |
| d["response_chars_total"] = int(d.get("response_chars_total", 0)) + max( | |
| 0, int(n) | |
| ) | |
| d["response_count"] = int(d.get("response_count", 0)) + 1 | |
| _write_metrics_today(d) | |
| def _calls_today(): | |
| data = _load_call_log() | |
| day = _today_key() | |
| val = data.get(day, 0) | |
| # Backward compatible: handle both scalar int and per-day dict blob | |
| if isinstance(val, dict): | |
| return int(val.get("calls", 0)) | |
| try: | |
| return int(val) | |
| except Exception: | |
| return 0 | |
| def _inc_calls_today(): | |
| data = _load_call_log() | |
| day = _today_key() | |
| day_blob = data.get(day, {}) if isinstance(data.get(day, {}), dict) else {} | |
| try: | |
| day_blob["calls"] = int(day_blob.get("calls", 0)) + 1 | |
| except Exception: | |
| day_blob["calls"] = 1 | |
| data[day] = day_blob | |
| _save_call_log(data) | |
| def _agentic_budget_allows(): | |
| hard = os.getenv("HF_AGENT_HARD_DISABLE", "").lower() in ("1", "true", "yes") | |
| if hard: | |
| return False | |
| limit = os.getenv("HF_AGENT_MAX_CALLS_PER_DAY") | |
| if not limit: | |
| return True | |
| try: | |
| limit_i = int(limit) | |
| except Exception: | |
| return True | |
| return _calls_today() < max(0, limit_i) | |
| def respond_stream( | |
| message, | |
| history, | |
| model_value, | |
| agent_obj, | |
| agentic_ok, | |
| admin_settings, | |
| request: "gr.Request", | |
| profile: "gr.OAuthProfile | None" = None, | |
| ): | |
| if not message: | |
| yield history, "", "", "", agent_obj, "", agentic_ok | |
| return | |
| budget_ok = _agentic_budget_allows() | |
| notice = "" | |
| # Compute user id (salted hash) for per-user quotas | |
| def _user_id(req: "gr.Request", prof: "gr.OAuthProfile | None") -> str: | |
| try: | |
| salt = os.getenv("USAGE_SALT", "reframe_salt") | |
| # Prefer OAuth profile when available | |
| if prof is not None: | |
| # Try common fields in OAuth profile | |
| username = None | |
| for key in ( | |
| "preferred_username", | |
| "username", | |
| "login", | |
| "name", | |
| "sub", | |
| "id", | |
| ): | |
| try: | |
| if hasattr(prof, key): | |
| username = getattr(prof, key) | |
| elif isinstance(prof, dict) and key in prof: | |
| username = prof[key] | |
| if username: | |
| break | |
| except Exception as e: | |
| print(f"Error getting username from profile: {e}") | |
| pass | |
| raw = f"oauth:{username or 'unknown'}" | |
| # req is expected to be provided by Gradio | |
| elif getattr(req, "username", None): | |
| raw = f"user:{req.username}" | |
| else: | |
| ip = getattr(getattr(req, "client", None), "host", "?") | |
| ua = ( | |
| dict(req.headers).get("user-agent", "?") | |
| if getattr(req, "headers", None) | |
| else "?" | |
| ) | |
| sess = getattr(req, "session_hash", None) or "?" | |
| raw = f"ipua:{ip}|{ua}|{sess}" | |
| return hashlib.sha256(f"{salt}|{raw}".encode()).hexdigest() | |
| except Exception as e: | |
| print(f"Error getting user id: {e}") | |
| return "anon" | |
| user_id = _user_id(request, profile) | |
| # Helper functions for tracking per-user interactions | |
| def _interactions_today(uid: str) -> int: | |
| data = _load_call_log() | |
| day = _today_key() | |
| day_blob = data.get(day, {}) if isinstance(data.get(day, {}), dict) else {} | |
| inter = ( | |
| day_blob.get("interactions", {}) | |
| if isinstance(day_blob.get("interactions", {}), dict) | |
| else {} | |
| ) | |
| val = inter.get(uid, 0) | |
| if isinstance(val, (str, int, float)): | |
| try: | |
| return int(val) | |
| except Exception as e: | |
| print(f"Error getting interactions today: {e}") | |
| return 0 | |
| return 0 | |
| def _inc_interactions_today(uid: str): | |
| data = _load_call_log() | |
| day = _today_key() | |
| day_blob = data.get(day, {}) if isinstance(data.get(day, {}), dict) else {} | |
| inter = ( | |
| day_blob.get("interactions", {}) | |
| if isinstance(day_blob.get("interactions", {}), dict) | |
| else {} | |
| ) | |
| inter[uid] = int(inter.get(uid, 0)) + 1 | |
| day_blob["interactions"] = inter | |
| data[day] = day_blob | |
| _save_call_log(data) | |
| # Determine how many interactions the user has already had today | |
| try: | |
| interactions_before = _interactions_today(user_id) | |
| except Exception as e: | |
| print(f"Error getting interactions today: {e}") | |
| interactions_before = 0 | |
| # Per-user interaction quota (counts 1 per message) | |
| max_interactions_env = os.getenv("HF_AGENT_MAX_INTERACTIONS_PER_USER") | |
| try: | |
| # Default to a generous 12 if not configured | |
| per_user_limit_env = ( | |
| int(max_interactions_env) if max_interactions_env else 12 | |
| ) | |
| except Exception: | |
| per_user_limit_env = 12 | |
| per_user_limit = per_user_limit_env | |
| # Admin override (runtime) | |
| try: | |
| override = None | |
| if isinstance(admin_settings, dict): | |
| override = admin_settings.get("per_user_limit_override") | |
| if isinstance(override, int | float) and int(override) > 0: | |
| per_user_limit = int(override) | |
| except Exception: | |
| pass | |
| if per_user_limit is not None and interactions_before >= max( | |
| 0, per_user_limit | |
| ): | |
| _inc_metric("blocked_interactions") | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| f"Per-user limit reached ({per_user_limit} interactions).", | |
| agentic_ok, | |
| ) | |
| return | |
| if not AGENT_AVAILABLE: | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| "Agent not available. Check HF token and model name.", | |
| agentic_ok, | |
| ) | |
| return | |
| if not agentic_ok: | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| "Agentic mode disabled due to a prior quota/billing error.", | |
| agentic_ok, | |
| ) | |
| return | |
| if not budget_ok: | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| "Daily budget reached. Set HF_AGENT_MAX_CALLS_PER_DAY or try tomorrow.", | |
| agentic_ok, | |
| ) | |
| return | |
| # Count one interaction for this user upfront | |
| _inc_interactions_today(user_id) | |
| interactions_after = interactions_before + 1 | |
| # Lazily initialize agent if requested | |
| _ensure_hf_token_env() | |
| if agent_obj is None: | |
| try: | |
| agent_obj = CBTAgent(model_name=model_value) | |
| except Exception as e: | |
| err = str(e) | |
| yield ( | |
| history or [], | |
| "", | |
| "", | |
| "", | |
| agent_obj, | |
| f"Agent failed to initialize: {err}", | |
| agentic_ok, | |
| ) | |
| return | |
| # Prepare side panels first for a snappy UI | |
| try: | |
| analysis = agent_obj.analyze_thought(message) | |
| distortions_display = chatbot._format_distortions( | |
| analysis.get("distortions", []) | |
| ) | |
| reframe_display = analysis.get("reframe", "") | |
| primary = analysis.get("distortions", []) | |
| primary_code = primary[0][0] if primary else None | |
| situations_display = ( | |
| chatbot._format_similar_situations(primary_code) if primary_code else "" | |
| ) | |
| # Metrics: record this interaction | |
| _inc_metric("total_interactions") | |
| _record_distortion_counts([c for c, _ in analysis.get("distortions", [])]) | |
| _inc_calls_today() | |
| except Exception as e: | |
| distortions_display = reframe_display = situations_display = "" | |
| # Detect quota/billing signals and permanently disable agent for this run | |
| msg = str(e).lower() | |
| if any( | |
| k in msg | |
| for k in [ | |
| "quota", | |
| "limit", | |
| "billing", | |
| "payment", | |
| "insufficient", | |
| "402", | |
| "429", | |
| ] | |
| ): | |
| agentic_ok = False | |
| notice = "Agentic mode disabled due to quota/billing error." | |
| else: | |
| notice = f"Agent analysis failed: {e}" | |
| _inc_metric("agent_errors") | |
| yield ( | |
| history or [], | |
| distortions_display, | |
| reframe_display, | |
| situations_display, | |
| agent_obj, | |
| notice, | |
| agentic_ok, | |
| ) | |
| return | |
| # Start streaming the assistant reply (messages format) | |
| history = history or [] | |
| # Append user message then assistant placeholder | |
| try: | |
| history.append({"role": "user", "content": message}) | |
| except Exception: | |
| history = list(history) + [{"role": "user", "content": message}] | |
| history.append({"role": "assistant", "content": ""}) | |
| # Optional prune to last N pairs to keep UI light | |
| try: | |
| pairs = chatbot._history_to_context(history[:-1]) | |
| pruned: list[dict] = [] | |
| for p in pairs: | |
| pruned.append({"role": "user", "content": p.get("user", "")}) | |
| pruned.append({"role": "assistant", "content": p.get("assistant", "")}) | |
| pruned.append({"role": "user", "content": message}) | |
| pruned.append({"role": "assistant", "content": ""}) | |
| history = pruned | |
| except Exception: | |
| pass | |
| # Choose response source: true token streaming via HF Inference | |
| try: | |
| _inc_calls_today() | |
| stream = getattr(agent_obj, "stream_generate_response", None) | |
| if callable(stream): | |
| token_iter = stream( | |
| message, context=chatbot._history_to_context(history[:-1]) | |
| ) | |
| else: | |
| # Fallback to non-streaming | |
| def _one_shot(): | |
| yield agent_obj.generate_response( | |
| message, context=chatbot._history_to_context(history[:-1]) | |
| ) | |
| token_iter = _one_shot() | |
| except Exception as e: | |
| _inc_metric("agent_errors") | |
| yield ( | |
| history, | |
| distortions_display, | |
| reframe_display, | |
| situations_display, | |
| agent_obj, | |
| f"Agent response failed: {e}", | |
| agentic_ok, | |
| ) | |
| return | |
| acc = "" | |
| for chunk in token_iter: | |
| if not chunk: | |
| continue | |
| acc += str(chunk) | |
| if isinstance(history[-1], dict): | |
| history[-1]["content"] = acc | |
| else: | |
| try: | |
| history[-1][1] = acc | |
| except Exception: | |
| pass | |
| # yield streaming frame | |
| yield ( | |
| history, | |
| distortions_display, | |
| reframe_display, | |
| situations_display, | |
| agent_obj, | |
| notice, | |
| agentic_ok, | |
| ) | |
| # Final yield ensures the last state is consistent | |
| _record_response_chars(len(acc)) | |
| # Show remaining interactions | |
| try: | |
| remaining = ( | |
| None | |
| if per_user_limit is None | |
| else max(0, per_user_limit - interactions_after) | |
| ) | |
| if remaining is not None: | |
| notice = ( | |
| notice + f"\nRemaining interactions today: {remaining}" | |
| ).strip() | |
| except Exception: | |
| pass | |
| yield ( | |
| history, | |
| distortions_display, | |
| reframe_display, | |
| situations_display, | |
| agent_obj, | |
| notice, | |
| agentic_ok, | |
| ) | |
| def clear_input(): | |
| return "" | |
| msg_input.submit( | |
| respond_stream, | |
| inputs=[ | |
| msg_input, | |
| chatbot_ui, | |
| model_state, | |
| agent_state, | |
| agentic_enabled_state, | |
| admin_state, | |
| ], | |
| outputs=[ | |
| chatbot_ui, | |
| distortions_output, | |
| reframe_output, | |
| situations_output, | |
| agent_state, | |
| billing_notice, | |
| agentic_enabled_state, | |
| ], | |
| ).then(fn=clear_input, outputs=[msg_input]) | |
| send_btn.click( | |
| fn=respond_stream, | |
| inputs=[ | |
| msg_input, | |
| chatbot_ui, | |
| model_state, | |
| agent_state, | |
| agentic_enabled_state, | |
| admin_state, | |
| ], | |
| outputs=[ | |
| chatbot_ui, | |
| distortions_output, | |
| reframe_output, | |
| situations_output, | |
| agent_state, | |
| billing_notice, | |
| agentic_enabled_state, | |
| ], | |
| ).then(clear_input, outputs=[msg_input]) | |
| def _clear_session_and_notice(): | |
| h, d, r, s = chatbot.clear_session() | |
| return h, d, r, s, "" | |
| clear_btn.click( | |
| fn=_clear_session_and_notice, | |
| outputs=[ | |
| chatbot_ui, | |
| distortions_output, | |
| reframe_output, | |
| situations_output, | |
| billing_notice, | |
| ], | |
| ) | |
| # Learn Tab | |
| with gr.Tab(t['learn']['title']): | |
| create_learn_tab(t['learn'], COGNITIVE_DISTORTIONS) | |
| def _owner_is(profile: gr.OAuthProfile | None, request: gr.Request | None = None) -> bool: | |
| try: | |
| # Prefer explicit OWNER_USER, fallback to the Space author | |
| # (useful if OWNER_USER not set) | |
| owner = ( | |
| os.getenv("OWNER_USER") | |
| or os.getenv("SPACE_AUTHOR_NAME") | |
| or "" | |
| ).strip().lower() | |
| if not owner: | |
| return False | |
| # Try common OAuth profile fields | |
| username = None | |
| for key in ("preferred_username", "username", "login", "name", "sub", "id"): | |
| try: | |
| if hasattr(profile, key): | |
| username = getattr(profile, key) | |
| elif isinstance(profile, dict) and key in profile: | |
| username = profile[key] | |
| if username: | |
| break | |
| except Exception as e: | |
| print(f"Error getting username from profile: {e}") | |
| pass | |
| # Fallback to request.username provided by Gradio when OAuth is enabled | |
| if not username and request is not None: | |
| try: | |
| username = getattr(request, "username", None) | |
| except Exception as e: | |
| print(f"Error getting username from request: {e}") | |
| username = None | |
| if not username: | |
| return False | |
| return str(username).lower() == owner | |
| except Exception as e: | |
| print(f"Error checking owner: {e}") | |
| return False | |
| def _metrics_paths(): | |
| return ( | |
| os.getenv("APP_METRICS_PATH", "/tmp/app_metrics.json"), | |
| os.getenv("AGENT_CALL_LOG_PATH", "/tmp/agent_calls.json"), | |
| ) | |
| def _read_json(path: str) -> dict: | |
| try: | |
| with open(path, encoding="utf-8") as f: | |
| return json.load(f) | |
| except Exception: | |
| return {} | |
| def _summarize_metrics_md() -> str: | |
| mpath, _ = _metrics_paths() | |
| data = _read_json(mpath) | |
| if not data: | |
| return "No metrics recorded yet." | |
| # Summarize last 7 days | |
| days = sorted(data.keys())[-7:] | |
| total = blocked = errors = resp_chars = resp_count = 0 | |
| dist_counts: dict[str, int] = {} | |
| for d in days: | |
| day = data.get(d, {}) or {} | |
| total += int(day.get("total_interactions", 0)) | |
| blocked += int(day.get("blocked_interactions", 0)) | |
| errors += int(day.get("agent_errors", 0)) | |
| resp_chars += int(day.get("response_chars_total", 0)) | |
| resp_count += int(day.get("response_count", 0)) | |
| dist = day.get("distortion_counts", {}) | |
| if isinstance(dist, dict): | |
| for k, v in dist.items(): | |
| dist_counts[k] = int(dist_counts.get(k, 0)) + int(v) | |
| avg_len = (resp_chars / resp_count) if resp_count else 0 | |
| top = sorted(dist_counts.items(), key=lambda x: x[1], reverse=True)[:5] | |
| lines = [ | |
| "### Usage (last 7 days)", | |
| f"- Total interactions: {total}", | |
| f"- Blocked interactions: {blocked}", | |
| f"- Agent errors: {errors}", | |
| f"- Avg response length: {avg_len:.0f} chars", | |
| "", | |
| "### Top cognitive patterns", | |
| ] | |
| if top: | |
| for k, v in top: | |
| lines.append(f"- {k}: {v}") | |
| else: | |
| lines.append("- None recorded") | |
| return "\n".join(lines) | |
| def _limit_info_md(settings: dict | None) -> str: | |
| env_val = os.getenv("HF_AGENT_MAX_INTERACTIONS_PER_USER") | |
| try: | |
| env_limit = int(env_val) if env_val else 12 | |
| except Exception: | |
| env_limit = 12 | |
| override = None | |
| if isinstance(settings, dict): | |
| override = settings.get("per_user_limit_override") | |
| effective = ( | |
| int(override) | |
| if isinstance(override, int | float) and int(override) > 0 | |
| else env_limit | |
| ) | |
| return ( | |
| f"Per-user daily limit: {effective} (env: {env_limit}, override: " | |
| f"{override if override else 'None'})" | |
| ) | |
def admin_set_limit(override_text: str, settings: dict | None):
    """Apply a runtime per-user limit override from the admin textbox.

    Only updates the in-memory settings state; never touches the env var.
    Non-positive or unparseable input clears the override (falls back to
    the env-derived limit).

    Returns:
        Tuple of (updated settings dict, refreshed limit-info Markdown).
    """
    # Start from a fresh state only when no usable dict was passed in;
    # otherwise mutate in place so other runtime settings keys survive.
    if not isinstance(settings, dict):
        settings = {"per_user_limit_override": None}
    override = None
    text = (override_text or "").strip()
    if text:
        try:
            override = int(text)
        except ValueError as e:
            # Bad input clears the override but must not wipe other settings.
            print(f"Error setting limit override: {e}")
            override = None
        else:
            if override <= 0:
                override = None  # non-positive means "no override"
    settings["per_user_limit_override"] = override
    return settings, _limit_info_md(settings)
def admin_refresh():
    """Return freshly summarized usage metrics as Markdown for the admin panel."""
    return _summarize_metrics_md()
| def _profile_username(profile: gr.OAuthProfile | None, request: gr.Request | None = None) -> str: | |
| try: | |
| for key in ("preferred_username", "username", "login", "name", "sub", "id"): | |
| if hasattr(profile, key): | |
| v = getattr(profile, key) | |
| if v: | |
| return str(v) | |
| elif isinstance(profile, dict) and key in profile and profile[key]: | |
| return str(profile[key]) | |
| except Exception as e: | |
| print(f"Error getting username from profile: {e}") | |
| pass | |
| try: | |
| if request is not None and getattr(request, "username", None): | |
| return str(request.username) | |
| except Exception as e: | |
| print(f"Error getting username from request: {e}") | |
| pass | |
| return "unknown" | |
def identity_refresh(profile: gr.OAuthProfile | None, request: gr.Request | None = None):
    """Build a Markdown identity report for the admin panel.

    Shows the OAuth viewer name, the owner-match env configuration, and
    whatever `whoami` reveals about the server-side HF token.
    """
    viewer = _profile_username(profile, request)
    is_owner = _owner_is(profile, request)
    try:
        info = _hf_whoami()
        token_user = info.get("name") or info.get("fullname") or "?"
        token_type = info.get("type", "?")
        org_names = ", ".join(o.get("name", "?") for o in info.get("orgs", []))
        token_info = f"Token user: `{token_user}` (type: {token_type}); orgs: [{org_names}]"
    except Exception as e:
        # whoami needs a valid token; surface the failure in the panel.
        print(f"Error getting token info whoami: {e}")
        token_info = f"Token whoami failed: {e}"
    owner_user = (os.getenv('OWNER_USER') or '').strip()
    author_name = (os.getenv('SPACE_AUTHOR_NAME') or '').strip()
    return (
        f"Logged in as (OAuth): `{viewer}`\n\n"
        f"OWNER_USER: `{owner_user}`\n"
        f"SPACE_AUTHOR_NAME: `{author_name}`\n"
        f"Owner match: {'yes' if is_owner else 'no'}\n\n"
        f"{token_info}"
    )
def storage_check():
    """Report on the /data persistent volume as Markdown.

    Lists presence, disk usage, up to 20 directory entries, and the
    HF_HOME setting; any failure is folded into the returned string.
    """
    try:
        data_path = "/data"
        present = os.path.exists(data_path)
        report = [f"Path: `{data_path}` — {'present' if present else 'absent'}"]
        if present:
            total, used, free = shutil.disk_usage(data_path)
            gib = 1024 ** 3
            report.append(
                f"Disk: total {total/gib:.1f} GB, used {used/gib:.1f} GB, free {free/gib:.1f} GB"
            )
            try:
                # Directory listing is best-effort; permissions may deny it.
                names = sorted(os.listdir(data_path))[:20]
                if names:
                    report.append("Entries: " + ", ".join(names))
            except Exception:
                pass
        hf_home = os.getenv("HF_HOME", "(not set)")
        report.append(f"HF_HOME: `{hf_home}`")
        return "\n".join(report)
    except Exception as e:
        return f"/data check failed: {e}"
# Wire admin interactions
# Mirror the dropdown's selected model into session state for other handlers.
model_dropdown.change(fn=lambda v: v, inputs=[model_dropdown], outputs=[model_state])
# Apply a runtime per-user limit override and refresh the limit summary text.
set_override_btn.click(
    fn=admin_set_limit,
    inputs=[override_tb, admin_state],
    outputs=[admin_state, admin_limit_info],
)
# Re-render the last-7-days usage metrics summary on demand.
refresh_btn.click(fn=admin_refresh, outputs=[admin_summary])
# Rebuild the identity report (OAuth profile/request are injected by Gradio).
identity_btn.click(fn=identity_refresh, outputs=[owner_identity_md])
# Probe the /data persistent volume and show the result.
storage_btn.click(fn=storage_check, outputs=[storage_info_md])
# Gate admin panel visibility on load (OAuth)
try:
    # Also populate identity + storage placeholders
    def _load(profile: gr.OAuthProfile | None, request: gr.Request | None = None):
        # On page load: decide owner visibility and pre-fill admin widgets.
        # NOTE(review): the returned tuple must stay in the same order as the
        # `outputs` list passed to app.load below.
        visible = _owner_is(profile, request)
        ident = identity_refresh(profile, request) if visible else ""
        return (
            gr.update(visible=visible),       # admin_accordion
            gr.update(visible=visible),       # chat_admin_panel
            gr.update(visible=not visible),   # chat_locked_panel (shown to non-owners)
            _summarize_metrics_md() if visible else "",  # admin_summary
            _limit_info_md(admin_state.value if hasattr(admin_state, "value") else None)
            if visible
            else "",                          # admin_limit_info
            ident,                            # owner_identity_md
            "",                               # storage_info_md (filled on demand)
        )
    app.load(
        _load,
        outputs=[admin_accordion, chat_admin_panel, chat_locked_panel, admin_summary,
                 admin_limit_info, owner_identity_md, storage_info_md],
    )
except Exception as e:
    print(f"Error loading app: {e}")
    # If OAuth not available, keep admin hidden
    pass
# Enable queue for Spaces compatibility (queued event handling on HF Spaces)
return app.queue()
# Launch the app when run as a script (not on import).
if __name__ == "__main__":
    demo = create_app(language='en')
    demo.launch(share=False, show_error=True, show_api=False)