# NOTE(review): the lines "Spaces: / Sleeping / Sleeping" were Hugging Face
# Spaces page chrome captured during extraction, not source code; kept here
# only as a comment so the module remains valid Python.
"""
CBT Agent implementation using the Hugging Face Inference API directly.
This avoids smolagents version drift and works cleanly on Spaces.
Adds optional provider, billing, and timeout support plus true
token streaming for lower-latency responses in Spaces.
"""
| import os | |
| from collections.abc import Iterable | |
| from huggingface_hub import InferenceClient | |
# Prefer the dedicated timeout error shipped with newer huggingface_hub
# releases; degrade to a local stub so `except InferenceTimeoutError`
# clauses elsewhere in this module keep working on older installs.
try:
    from huggingface_hub import InferenceTimeoutError
except Exception:  # pragma: no cover - symbol absent on older hub versions
    class InferenceTimeoutError(Exception):  # type: ignore
        pass
class CBTAgent:
    """Agent that composes warm CBT responses using HF Inference API.

    Tries, in order of preference, the OpenAI-compatible chat API, the
    older ``chat_completion`` API, and plain ``text_generation``, so it
    keeps working across huggingface_hub versions and providers.

    Args:
        model_name: HF model id or endpoint URL to use.
        provider: Optional provider routing (e.g., "together", "groq", "hf-inference").
        bill_to: Optional org to bill (Enterprise Hub).
        timeout: Request timeout in seconds.
        use_local: Ignored; kept for backward compatibility.
    """

    def __init__(
        self,
        model_name: str | None = None,
        *,
        provider: str | None = None,
        bill_to: str | None = None,
        timeout: float | None = None,
        use_local: bool = False,  # use_local ignored
    ) -> None:
        # Model comes from the argument or the MODEL_NAME env var.
        self.model_name = model_name or os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
        # Optional routing / billing / timeout via env overrides
        self.provider = provider or os.getenv("HF_PROVIDER")
        self.bill_to = bill_to or os.getenv("HF_BILL_TO")
        # Parse env timeout if not provided; a malformed value quietly
        # disables the timeout rather than crashing startup.
        if timeout is None:
            try:
                timeout_env = os.getenv("HF_TIMEOUT_SECONDS")
                timeout = float(timeout_env) if timeout_env else None
            except Exception:
                timeout = None
        self.timeout = timeout
        token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
        if not token:
            # Fail fast: every code path below needs an authenticated client.
            raise RuntimeError("HF token missing: set HF_TOKEN or HUGGINGFACEHUB_API_TOKEN")
        # Initialize client with explicit named params to satisfy typing
        self.client = InferenceClient(
            model=self.model_name,
            provider=(self.provider if self.provider is not None else None),  # type: ignore[arg-type]
            token=token,
            timeout=self.timeout,
            bill_to=self.bill_to,
        )

    def _get_system_prompt(self) -> str:
        """Return the system prompt that frames every model call."""
        return (
            "You are a compassionate CBT assistant. Identify likely cognitive distortions, "
            "validate feelings, and suggest balanced, realistic alternatives. Keep responses "
            "concise, supportive, and practical. This is self-help guidance, not therapy."
        )

    def _generate(self, prompt: str, messages: list[dict] | None = None) -> str:
        """Run one non-streaming completion and return the assistant text.

        Uses *messages* when given; otherwise wraps *prompt* with the system
        prompt. Any failure in the chat paths falls through silently to
        plain text generation.
        """
        # Prefer chat-completions if available
        try:
            msgs = messages or [
                {"role": "system", "content": self._get_system_prompt()},
                {"role": "user", "content": prompt},
            ]
            if hasattr(self.client, "chat") and hasattr(self.client.chat, "completions"):
                resp = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=msgs,
                    max_tokens=400,
                    temperature=0.6,
                )
                msg = getattr(resp.choices[0], "message", None)
                return getattr(msg, "content", "") or ""
            if hasattr(self.client, "chat_completion"):
                resp = self.client.chat_completion(
                    messages=msgs,
                    max_tokens=400,
                    temperature=0.6,
                )
                # NOTE(review): assumes a dict-shaped response; newer hub
                # versions return dataclasses, in which case .get raises and
                # we fall through to text_generation below — confirm intended.
                choices = resp.get("choices") or []
                if choices:
                    return choices[0].get("message", {}).get("content", "")
        except Exception:
            pass
        # Fallback to plain text generation
        return self.client.text_generation(
            prompt if prompt else (messages[-1]["content"] if messages else ""),
            max_new_tokens=400,
            temperature=0.6,
            return_full_text=False,
        )

    def _build_messages(self, message: str, context: list[dict] | None = None) -> list[dict]:
        """Assemble a chat transcript: system prompt, recent context, new message.

        Context entries are dicts with optional "user"/"assistant" keys; only
        the last three turns are kept to bound prompt size.
        """
        messages: list[dict] = [
            {"role": "system", "content": self._get_system_prompt()},
        ]
        if context:
            turns = context[-3:]
            for turn in turns:
                u = turn.get("user", "")
                a = turn.get("assistant", "")
                if u:
                    messages.append({"role": "user", "content": u})
                if a:
                    messages.append({"role": "assistant", "content": a})
        messages.append({"role": "user", "content": message})
        return messages

    def analyze_thought(self, thought: str) -> dict:
        """Detect cognitive distortions in *thought* and generate a reframe.

        Returns a dict with the original thought, detected distortions
        (iterable of (code, detail) pairs per the unpacking below), an
        LLM-written reframe, and similar situations for the first code.
        """
        # Imported lazily so module import does not require cbt_knowledge.
        from cbt_knowledge import detect_distortions, find_similar_situations
        distortions = detect_distortions(thought)
        codes = [c for c, _ in distortions]
        prompt = (
            f"Thought: {thought}\n"
            f"Detected codes: {', '.join(codes) if codes else 'none'}\n"
            "Write a gentle, balanced reframe (2-4 sentences). Validate feelings first."
        )
        reframe = self._generate(prompt)
        primary = codes[0] if codes else None
        similar = find_similar_situations(primary) if primary else []
        return {
            "thought": thought,
            "distortions": distortions,
            "reframe": reframe,
            "similar_situations": similar,
        }

    def generate_response(self, message: str, context: list[dict] | None = None) -> str:
        """Return a full assistant response (non-streaming)."""
        messages = self._build_messages(message, context)
        return self._generate(prompt="", messages=messages)

    def stream_generate_response(
        self, message: str, context: list[dict] | None = None
    ) -> Iterable[str]:
        """Yield assistant tokens via provider streaming if supported.

        Falls back to a single full message chunk when streaming
        is not supported by the provider/client.
        """
        messages = self._build_messages(message, context)
        # Try OpenAI-compatible chat streaming first
        try:
            if hasattr(self.client, "chat") and hasattr(self.client.chat, "completions"):
                stream = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    temperature=0.6,
                    max_tokens=400,
                    stream=True,
                )
                for chunk in stream:  # ChatCompletionStreamOutput-like
                    # Handle both object-like and dict-like chunks
                    delta = (
                        getattr(chunk.choices[0], "delta", None)
                        if hasattr(chunk, "choices")
                        else None
                    )
                    content = None
                    if delta is not None:
                        content = getattr(delta, "content", None)
                    elif isinstance(chunk, dict):
                        content = chunk.get("choices", [{}])[0].get("delta", {}).get("content")
                    if content:
                        yield content
                return
            if hasattr(self.client, "chat_completion"):
                stream = self.client.chat_completion(
                    messages=messages,
                    temperature=0.6,
                    max_tokens=400,
                    stream=True,
                )
                for chunk in stream:
                    # older dict-shaped stream
                    content = (chunk.get("choices") or [{}])[0].get("delta", {}).get("content")
                    if content:
                        yield content
                return
        except InferenceTimeoutError as e:  # pragma: no cover
            # Surface a short message; UI will present it
            yield f"[timeout] {e}"
            return
        except Exception:
            # Fall through to plain generation
            pass
        # Fallback: text generation streaming if available.
        # Streams only the final user message, not the full transcript.
        try:
            gen = self.client.text_generation(
                messages[-1]["content"],
                max_new_tokens=400,
                temperature=0.6,
                stream=True,
                return_full_text=False,
            )
            for token in gen:
                if token:
                    yield str(token)
            return
        except Exception:
            pass
        # Final fallback: non-streaming
        yield self.generate_response(message, context)

    def get_crisis_response(self) -> str:
        """Return a fixed safety message pointing at crisis-support resources."""
        return (
            "I'm concerned about what you're sharing and want you to know that support is "
            "available.\n\n"
            "Immediate help: see international resources at "
            "https://blog.opencounseling.com/suicide-hotlines/\n\n"
            "You don't have to go through this alone. Please reach out to a professional who "
            "can provide the support you deserve."
        )
def create_simple_cbt_assistant():
    """Legacy factory kept for older callers; builds a default CBTAgent."""
    agent = CBTAgent()
    return agent