"""Interactive topic clarifier backed by a HuggingFace Inference model.

Streams chat completions (tolerating reasoning/"thinking" models such as
DeepSeek-R1), extracts a JSON list of suggested research topics, and lets
the user pick one on the console.
"""

import json
import logging
import time
from typing import Dict, List

from huggingface_hub import InferenceClient

from .prompts import CLARIFIER_DIRECTION

logger = logging.getLogger(__name__)

# JSON schema passed as `response_format` to constrain the model's output.
CLARIFICATION_SCHEMA = {
    "name": "clarification_suggestions",
    "schema": {
        "type": "object",
        "properties": {
            "suggestions": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "title": {"type": "string"},
                        "description": {"type": "string"},
                    },
                    "required": ["title", "description"],
                },
            }
        },
        "required": ["suggestions"],
    },
    "strict": True,
}


class Clarifier:
    """Generates and lets the user select clarifying research-topic suggestions."""

    def __init__(self, model_name: str, hf_key: str):
        """Store the model name and build an InferenceClient with a 120s timeout."""
        self.model_name = model_name
        self.client = InferenceClient(token=hf_key, timeout=120)

    def get_suggestions(self, topic: str) -> List[Dict[str, str]]:
        """Ask the model for topic suggestions; retry up to 3 times.

        Returns a list of ``{"title": ..., "description": ...}`` dicts, or an
        empty list if every attempt fails or yields unparseable output.
        """
        logger.info(f"Clarifying Topic: {topic} using model {self.model_name}")
        max_retries = 3
        for attempt in range(max_retries):
            try:
                logger.info(f"Attempt {attempt + 1}/{max_retries} using {self.model_name}")
                # Use streaming to be more resilient to StopIteration/timeout
                # issues on thinking models.
                full_content = ""
                try:
                    # response_format guides the model, but we still parse
                    # manually for robustness. Only request the JSON schema on
                    # the first attempt: some models reject it, so retries
                    # fall back to free-form output.
                    stream = self.client.chat_completion(
                        model=self.model_name,
                        messages=[
                            {"role": "system", "content": CLARIFIER_DIRECTION},
                            {"role": "user", "content": topic},
                        ],
                        max_tokens=2000,
                        stream=True,
                        temperature=1.0,
                        top_p=1.0,
                        response_format={
                            "type": "json_schema",
                            "json_schema": CLARIFICATION_SCHEMA,
                        } if attempt == 0 else None,
                    )
                    for chunk in stream:
                        delta = chunk.choices[0].delta
                        if hasattr(delta, "content") and delta.content:
                            full_content += delta.content
                        # Capture DeepSeek-R1 style reasoning_content. Wrap it
                        # in <think> tags (if not already present in the
                        # stream) so _parse_suggestions can strip it.
                        if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                            if "<think>" not in full_content:
                                full_content += "<think>"
                            full_content += delta.reasoning_content
                        elif hasattr(delta, "reasoning") and delta.reasoning:
                            if "<think>" not in full_content:
                                full_content += "<think>"
                            full_content += delta.reasoning
                    # Close the think tag if it was opened but never closed by
                    # the model, so the parser's strip logic still works.
                    if "<think>" in full_content and "</think>" not in full_content:
                        full_content += "</think>"
                except StopIteration:
                    if not full_content:
                        logger.error(f"Model {self.model_name} returned an empty stream. It may not be supported on the current Inference API endpoint.")
                    else:
                        logger.warning("Stream ended abruptly.")
                except Exception as stream_err:
                    logger.error(f"Streaming error from {self.model_name}: {stream_err}")

                if not full_content:
                    logger.warning(f"No content received on attempt {attempt + 1}")
                    if attempt < max_retries - 1:
                        time.sleep(2 ** attempt)
                        continue
                    return []

                suggestions = self._parse_suggestions(full_content)
                if suggestions:
                    return suggestions
                logger.warning(f"Failed to parse suggestions from content on attempt {attempt + 1}")
            except Exception as e:
                logger.error(f"Error during get_suggestions (Attempt {attempt + 1}): {e}")
            # Exponential backoff before the next attempt (covers both parse
            # failures and unexpected exceptions).
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                logger.info(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
        return []

    def _parse_suggestions(self, content: str) -> List[Dict[str, str]]:
        """Extract the suggestions list from raw (possibly noisy) model output.

        Strips <think>...</think> reasoning blocks and markdown code fences,
        then attempts to locate and decode the JSON payload. Returns an empty
        list on any parse failure.
        """
        if not content:
            return []
        logger.debug(f"Raw response: {content}")

        # Clean the content: remove thinking blocks and extract JSON.
        clean_content = content.strip()

        # Remove thinking blocks if present.
        if "<think>" in clean_content:
            if "</think>" in clean_content:
                clean_content = clean_content.split("</think>")[-1].strip()
            else:
                # Think tag opened but never closed: take everything after it
                # and try to find the first JSON object.
                parts = clean_content.split("<think>")
                clean_content = parts[-1].strip()
                if "{" in clean_content:
                    clean_content = "{" + clean_content.split("{", 1)[1]

        # Extract JSON from potential code blocks.
        if "```json" in clean_content:
            clean_content = clean_content.split("```json")[1].split("```")[0].strip()
        elif "```" in clean_content:
            # Check if it looks like JSON inside a generic code block.
            block_content = clean_content.split("```")[1].split("```")[0].strip()
            if block_content.startswith("{") or "[" in block_content:
                clean_content = block_content

        # Final attempt to find a JSON-like structure buried in prose.
        if not (clean_content.startswith("{") or clean_content.startswith("[")):
            if "{" in clean_content:
                clean_content = "{" + clean_content.split("{", 1)[1]
                if "}" in clean_content:
                    clean_content = clean_content.rsplit("}", 1)[0] + "}"

        try:
            data = json.loads(clean_content)
            # Handle a nested 'suggestions' key or a direct list.
            if isinstance(data, dict):
                suggestions = data.get("suggestions", [])
            elif isinstance(data, list):
                suggestions = data
            else:
                suggestions = []
            logger.info(f"Successfully extracted {len(suggestions)} suggestions")
            return suggestions
        except json.JSONDecodeError as e:
            logger.error(f"JSON parsing failed: {e}. Snippet: {clean_content[:150]}...")
            return []

    def clarify(self, topic: str) -> str:
        """Interactively refine *topic* using model-generated suggestions.

        Prints the suggestions to the console and prompts the user to choose
        one by number, keep the original (0), or type a custom topic. Returns
        the chosen/entered topic string; falls back to the original topic on
        EOF or when no suggestions are available.
        """
        suggestions = self.get_suggestions(topic)
        if not suggestions:
            logger.warning("No suggestions generated, using original topic.")
            return topic

        print("\n\033[93m--- Research Topic Suggestions ---\033[0m")
        for i, sug in enumerate(suggestions, 1):
            print(f"[{i}] {sug.get('title')}")
            print(f"    {sug.get('description')}\n")
        print(f"[0] Use original: {topic}")
        print("\033[93m----------------------------------\033[0m")

        while True:
            try:
                user_input = input("Select a choice (number) or enter a custom topic: ").strip()
                if user_input == "0":
                    return topic
                if user_input.isdigit():
                    idx = int(user_input) - 1
                    if 0 <= idx < len(suggestions):
                        selected = suggestions[idx]
                        # Return the combined title and description as the
                        # refined topic.
                        final_topic = f"{selected.get('title')}: {selected.get('description')}"
                        logger.info(f"User selected suggestion {user_input}")
                        return final_topic
                    else:
                        print(f"Please enter a number between 0 and {len(suggestions)}.")
                        continue
                if user_input:
                    # Treat non-numeric non-empty input as a custom topic.
                    logger.info("User provided custom topic.")
                    return user_input
                else:
                    print("Input cannot be empty. Please select a number or type a topic.")
            except EOFError:
                return topic