# reframe/agents/cbt_agent.py
# (Header residue from the HF Space file viewer: commit "first-features (#1)", c6ce43e.)
"""
CBT Agent implementation using Hugging Face Inference API directly.
This avoids smolagents version drift and works cleanly on Spaces.
Adds optional provider, billing, and timeout support plus true
token streaming for lower-latency responses in Spaces.
"""
import os
from collections.abc import Iterable
from huggingface_hub import InferenceClient
# huggingface_hub only added InferenceTimeoutError in newer releases; define a
# local stand-in on older versions so the `except InferenceTimeoutError` clause
# in CBTAgent.stream_generate_response below always resolves.
try: # Newer hub versions expose a specific timeout error
    from huggingface_hub import InferenceTimeoutError
except Exception: # pragma: no cover
    class InferenceTimeoutError(Exception): # type: ignore
        pass
class CBTAgent:
"""Agent that composes warm CBT responses using HF Inference API.
Args:
model_name: HF model id or endpoint URL to use.
provider: Optional provider routing (e.g., "together", "groq", "hf-inference").
bill_to: Optional org to bill (Enterprise Hub).
timeout: Request timeout in seconds.
use_local: Ignored; kept for backward compatibility.
"""
def __init__(
self,
model_name: str | None = None,
*,
provider: str | None = None,
bill_to: str | None = None,
timeout: float | None = None,
use_local: bool = False, # use_local ignored
):
self.model_name = model_name or os.getenv("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
# Optional routing / billing / timeout via env overrides
self.provider = provider or os.getenv("HF_PROVIDER")
self.bill_to = bill_to or os.getenv("HF_BILL_TO")
# Parse env timeout if not provided
if timeout is None:
try:
timeout_env = os.getenv("HF_TIMEOUT_SECONDS")
timeout = float(timeout_env) if timeout_env else None
except Exception:
timeout = None
self.timeout = timeout
token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
if not token:
raise RuntimeError("HF token missing: set HF_TOKEN or HUGGINGFACEHUB_API_TOKEN")
# Initialize client with explicit named params to satisfy typing
self.client = InferenceClient(
model=self.model_name,
provider=(self.provider if self.provider is not None else None), # type: ignore[arg-type]
token=token,
timeout=self.timeout,
bill_to=self.bill_to,
)
def _get_system_prompt(self) -> str:
return (
"You are a compassionate CBT assistant. Identify likely cognitive distortions, "
"validate feelings, and suggest balanced, realistic alternatives. Keep responses "
"concise, supportive, and practical. This is self-help guidance, not therapy."
)
def _generate(self, prompt: str, messages: list[dict] | None = None) -> str:
# Prefer chat-completions if available
try:
msgs = messages or [
{"role": "system", "content": self._get_system_prompt()},
{"role": "user", "content": prompt},
]
if hasattr(self.client, "chat") and hasattr(self.client.chat, "completions"):
resp = self.client.chat.completions.create(
model=self.model_name,
messages=msgs,
max_tokens=400,
temperature=0.6,
)
msg = getattr(resp.choices[0], "message", None)
return getattr(msg, "content", "") or ""
if hasattr(self.client, "chat_completion"):
resp = self.client.chat_completion(
messages=msgs,
max_tokens=400,
temperature=0.6,
)
choices = resp.get("choices") or []
if choices:
return choices[0].get("message", {}).get("content", "")
except Exception:
pass
# Fallback to plain text generation
return self.client.text_generation(
prompt if prompt else (messages[-1]["content"] if messages else ""),
max_new_tokens=400,
temperature=0.6,
return_full_text=False,
)
def _build_messages(self, message: str, context: list[dict] | None = None) -> list[dict]:
messages: list[dict] = [
{"role": "system", "content": self._get_system_prompt()},
]
if context:
turns = context[-3:]
for turn in turns:
u = turn.get("user", "")
a = turn.get("assistant", "")
if u:
messages.append({"role": "user", "content": u})
if a:
messages.append({"role": "assistant", "content": a})
messages.append({"role": "user", "content": message})
return messages
def analyze_thought(self, thought: str) -> dict:
from cbt_knowledge import detect_distortions, find_similar_situations
distortions = detect_distortions(thought)
codes = [c for c, _ in distortions]
prompt = (
f"Thought: {thought}\n"
f"Detected codes: {', '.join(codes) if codes else 'none'}\n"
"Write a gentle, balanced reframe (2-4 sentences). Validate feelings first."
)
reframe = self._generate(prompt)
primary = codes[0] if codes else None
similar = find_similar_situations(primary) if primary else []
return {
"thought": thought,
"distortions": distortions,
"reframe": reframe,
"similar_situations": similar,
}
def generate_response(self, message: str, context: list[dict] | None = None) -> str:
"""Return a full assistant response (non-streaming)."""
messages = self._build_messages(message, context)
return self._generate(prompt="", messages=messages)
def stream_generate_response(
self, message: str, context: list[dict] | None = None
) -> Iterable[str]:
"""Yield assistant tokens via provider streaming if supported.
Falls back to a single full message chunk when streaming
is not supported by the provider/client.
"""
messages = self._build_messages(message, context)
# Try OpenAI-compatible chat streaming first
try:
if hasattr(self.client, "chat") and hasattr(self.client.chat, "completions"):
stream = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
temperature=0.6,
max_tokens=400,
stream=True,
)
for chunk in stream: # ChatCompletionStreamOutput-like
# Handle both object-like and dict-like chunks
delta = (
getattr(chunk.choices[0], "delta", None)
if hasattr(chunk, "choices")
else None
)
content = None
if delta is not None:
content = getattr(delta, "content", None)
elif isinstance(chunk, dict):
content = chunk.get("choices", [{}])[0].get("delta", {}).get("content")
if content:
yield content
return
if hasattr(self.client, "chat_completion"):
stream = self.client.chat_completion(
messages=messages,
temperature=0.6,
max_tokens=400,
stream=True,
)
for chunk in stream:
# older dict-shaped stream
content = (chunk.get("choices") or [{}])[0].get("delta", {}).get("content")
if content:
yield content
return
except InferenceTimeoutError as e: # pragma: no cover
# Surface a short message; UI will present it
yield f"[timeout] {e}"
return
except Exception:
# Fall through to plain generation
pass
# Fallback: text generation streaming if available
try:
gen = self.client.text_generation(
messages[-1]["content"],
max_new_tokens=400,
temperature=0.6,
stream=True,
return_full_text=False,
)
for token in gen:
if token:
yield str(token)
return
except Exception:
pass
# Final fallback: non-streaming
yield self.generate_response(message, context)
def get_crisis_response(self) -> str:
return (
"I'm concerned about what you're sharing and want you to know that support is "
"available.\n\n"
"Immediate help: see international resources at "
"https://blog.opencounseling.com/suicide-hotlines/\n\n"
"You don't have to go through this alone. Please reach out to a professional who "
"can provide the support you deserve."
)
def create_simple_cbt_assistant():
    """Factory kept for backwards compatibility; builds a default CBTAgent."""
    agent = CBTAgent()
    return agent