# research_agent/src/clarifier.py
# Luis Vizcaya
# pipeline working
# 61f47ab
import logging
import time
import json
from typing import List, Dict
from huggingface_hub import InferenceClient
from .prompts import CLARIFIER_DIRECTION
logger = logging.getLogger(__name__)

# JSON schema sent as ``response_format`` to constrain the model's reply to
# ``{"suggestions": [{"title": str, "description": str}, ...]}``.
# "additionalProperties": False is set on every object because strict
# structured-output backends (OpenAI-compatible "strict": True) require it;
# without it such backends may reject or ignore the schema.
CLARIFICATION_SCHEMA = {
    "name": "clarification_suggestions",
    "schema": {
        "type": "object",
        "properties": {
            "suggestions": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "title": {"type": "string"},
                        "description": {"type": "string"},
                    },
                    "required": ["title", "description"],
                    "additionalProperties": False,
                },
            },
        },
        "required": ["suggestions"],
        "additionalProperties": False,
    },
    "strict": True,
}
class Clarifier:
    """Turn a raw research topic into a more specific, user-chosen topic.

    Asks an LLM (via ``huggingface_hub.InferenceClient``) for refined topic
    suggestions shaped like ``CLARIFICATION_SCHEMA`` and lets the user pick
    one interactively in :meth:`clarify`.
    """

    def __init__(self, model_name: str, hf_key: str):
        """Bind the clarifier to a model.

        Args:
            model_name: Model id passed to ``InferenceClient.chat_completion``.
            hf_key: Hugging Face API token used by the inference client.
        """
        self.model_name = model_name
        self.client = InferenceClient(token=hf_key, timeout=120)

    def get_suggestions(self, topic: str, max_retries: int = 3) -> List[Dict[str, str]]:
        """Ask the model for refined topic suggestions, retrying on failure.

        The JSON-schema ``response_format`` is sent only on the first attempt;
        later attempts drop it as a fallback for models/providers that do not
        support schema-constrained output. Between attempts the method sleeps
        ``2 ** attempt`` seconds.

        Args:
            topic: The raw research topic entered by the user.
            max_retries: Total number of attempts (default 3).

        Returns:
            A list of suggestion dicts (normally with ``title`` and
            ``description`` keys), or ``[]`` if no attempt produced parseable
            suggestions.
        """
        logger.info(f'Clarifying Topic: {topic} using model {self.model_name}')
        for attempt in range(max_retries):
            logger.info(f"Attempt {attempt + 1}/{max_retries} using {self.model_name}")
            try:
                full_content = self._stream_completion(topic, use_schema=(attempt == 0))
                if not full_content:
                    logger.warning(f"No content received on attempt {attempt + 1}")
                else:
                    suggestions = self._parse_suggestions(full_content)
                    if suggestions:
                        return suggestions
                    logger.warning(f"Failed to parse suggestions from content on attempt {attempt + 1}")
            except Exception as e:
                logger.error(f"Error during get_suggestions (Attempt {attempt + 1}): {e}")
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                logger.info(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
        return []

    def _stream_completion(self, topic: str, use_schema: bool) -> str:
        """Stream one chat completion and return its text (possibly partial or empty).

        Answer deltas (``content``) and reasoning deltas (``reasoning_content``,
        else ``reasoning``) are collected separately. If any reasoning arrived,
        it is returned as a *closed* ``<think>...</think>`` block placed before
        the answer text, so the parser can drop it and keep the answer.
        Streaming errors are logged and whatever was received so far is returned.
        """
        reasoning_parts: List[str] = []
        content_parts: List[str] = []
        try:
            stream = self.client.chat_completion(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": CLARIFIER_DIRECTION},
                    {"role": "user", "content": topic},
                ],
                max_tokens=2000,
                stream=True,
                temperature=1.0,
                top_p=1.0,
                # Only the first attempt is schema-constrained; retries omit it
                # for models that don't support JSON schema output.
                response_format={
                    "type": "json_schema",
                    "json_schema": CLARIFICATION_SCHEMA,
                } if use_schema else None,
            )
            for chunk in stream:
                # Some chunks may carry no choices; indexing [0] would raise
                # and abort the rest of the stream.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                text = getattr(delta, "content", None)
                if text:
                    content_parts.append(text)
                thought = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
                if thought:
                    reasoning_parts.append(thought)
        except StopIteration:
            # The for-loop absorbs the iterator's own StopIteration, so this
            # presumably comes from the chat_completion call itself.
            if content_parts or reasoning_parts:
                logger.warning("Stream ended abruptly.")
            else:
                logger.error(f"Model {self.model_name} returned an empty stream. It may not be supported on the current Inference API endpoint.")
        except Exception as stream_err:
            logger.error(f"Streaming error from {self.model_name}: {stream_err}")

        answer = "".join(content_parts)
        if reasoning_parts:
            reasoning = "".join(reasoning_parts)
            return f"<think>{reasoning}</think>{answer}"
        return answer

    def _parse_suggestions(self, content: str) -> List[Dict[str, str]]:
        """Extract suggestion dicts from a raw model reply.

        Removes ``<think>`` reasoning and Markdown code fences, then scans for
        the first JSON object/array that yields suggestions. Accepts either
        ``{"suggestions": [...]}`` or a bare list; non-dict items are dropped.

        Returns:
            The suggestion dicts, or ``[]`` if none could be decoded.
        """
        if not content:
            return []
        logger.debug(f"Raw response: {content}")
        clean_content = self._strip_code_fence(self._strip_reasoning(content.strip()))
        for data in self._iter_json_values(clean_content):
            suggestions = self._extract_suggestion_list(data)
            if suggestions:
                logger.info(f"Successfully extracted {len(suggestions)} suggestions")
                return suggestions
        logger.error(f"JSON parsing failed: no suggestions found. Snippet: {clean_content[:150]}...")
        return []

    @staticmethod
    def _strip_reasoning(text: str) -> str:
        """Return only the text after any ``<think>`` reasoning block."""
        # Checked independently of "<think>": some chat templates put the
        # opening tag in the prompt, so the reply only contains "</think>".
        if "</think>" in text:
            return text.rsplit("</think>", 1)[-1].strip()
        if "<think>" in text:
            # Unclosed block: keep what follows the tag; the JSON scan skips prose.
            return text.rsplit("<think>", 1)[-1].strip()
        return text

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return the body of the first Markdown code fence if it looks like JSON."""
        if "```json" in text:
            return text.split("```json", 1)[1].split("```", 1)[0].strip()
        if "```" in text:
            block = text.split("```", 2)[1].strip()
            if "{" in block or "[" in block:
                return block
        return text

    @staticmethod
    def _iter_json_values(text: str):
        """Yield each JSON object/array that decodes starting at a ``{`` or ``[`` in *text*.

        Candidates are yielded in order of their start position, so an
        enclosing value is yielded before the values nested inside it.
        """
        decoder = json.JSONDecoder()
        for index, char in enumerate(text):
            if char not in "{[":
                continue
            try:
                value, _ = decoder.raw_decode(text, index)
            except json.JSONDecodeError:
                continue
            yield value

    @staticmethod
    def _extract_suggestion_list(data) -> List[Dict[str, str]]:
        """Pull the suggestion list out of a decoded JSON value, keeping only dict items."""
        if isinstance(data, dict):
            data = data.get("suggestions", [])
        if not isinstance(data, list):
            return []
        return [item for item in data if isinstance(item, dict)]

    def clarify(self, topic: str) -> str:
        """Interactively let the user choose a refined research topic.

        Prints the model's suggestions and reads from stdin until a valid
        answer is given:
          * ``0`` keeps the original ``topic``;
          * ``1..N`` selects a suggestion, returned as ``"title: description"``
            (missing parts are omitted);
          * any other non-empty text is returned verbatim as a custom topic.
        EOF on stdin, or no suggestions at all, also returns ``topic``.
        """
        suggestions = self.get_suggestions(topic)
        if not suggestions:
            logger.warning("No suggestions generated, using original topic.")
            return topic

        print("\n\033[93m--- Research Topic Suggestions ---\033[0m")
        for i, sug in enumerate(suggestions, 1):
            print(f"[{i}] {sug.get('title')}")
            print(f" {sug.get('description')}\n")
        print(f"[0] Use original: {topic}")
        print("\033[93m----------------------------------\033[0m")

        while True:
            try:
                user_input = input("Select a choice (number) or enter a custom topic: ").strip()
            except EOFError:
                return topic
            if not user_input:
                print("Input cannot be empty. Please select a number or type a topic.")
                continue
            # isdecimal (not isdigit): isdigit accepts characters such as "²"
            # that int() rejects with ValueError.
            if not user_input.isdecimal():
                logger.info("User provided custom topic.")
                return user_input
            choice = int(user_input)
            if choice == 0:
                return topic
            if choice <= len(suggestions):
                selected = suggestions[choice - 1]
                logger.info(f"User selected suggestion {user_input}")
                # Combine title and description as the refined topic, skipping
                # missing parts so we never emit a literal "None".
                parts = [str(selected[key]) for key in ("title", "description") if selected.get(key)]
                return ": ".join(parts) or topic
            print(f"Please enter a number between 0 and {len(suggestions)}.")