import logging
import time
import json
from typing import List, Dict
from huggingface_hub import InferenceClient
from .prompts import CLARIFIER_DIRECTION
logger = logging.getLogger(__name__)
# JSON schema passed to the Inference API's `response_format` to steer the
# model toward a {"suggestions": [{"title": ..., "description": ...}, ...]}
# payload. The response is still parsed defensively downstream.
CLARIFICATION_SCHEMA = {
    "name": "clarification_suggestions",
    "schema": {
        "type": "object",
        "properties": {
            "suggestions": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "title": {"type": "string"},
                        "description": {"type": "string"},
                    },
                    "required": ["title", "description"],
                },
            },
        },
        "required": ["suggestions"],
    },
    "strict": True,
}
class Clarifier:
    """Interactively refines a research topic using an LLM.

    Streams chat completions from the Hugging Face Inference API, extracts a
    JSON list of suggestions from the (possibly noisy) model output, and lets
    the user pick a suggestion, keep the original topic, or type a custom one.
    """

    def __init__(self, model_name: str, hf_key: str):
        """
        Args:
            model_name: Model id passed to the Inference API.
            hf_key: Hugging Face API token.
        """
        self.model_name = model_name
        # Generous timeout: "thinking" models can stream slowly.
        self.client = InferenceClient(token=hf_key, timeout=120)

    def get_suggestions(self, topic: str) -> List[Dict[str, str]]:
        """Ask the model for clarification suggestions for *topic*.

        Retries up to 3 times with exponential backoff (1s, 2s).

        Returns:
            A list of {"title": ..., "description": ...} dicts, or [] when
            no parseable suggestions were obtained.
        """
        logger.info("Clarifying Topic: %s using model %s", topic, self.model_name)
        max_retries = 3
        for attempt in range(max_retries):
            try:
                logger.info("Attempt %d/%d using %s", attempt + 1, max_retries, self.model_name)
                # Use streaming to be more resilient to StopIteration/timeout
                # issues on thinking models.
                full_content = ""
                try:
                    # response_format guides the model on the first attempt
                    # only; retries drop it as a fallback for models that do
                    # not support JSON schema. We parse manually either way.
                    stream = self.client.chat_completion(
                        model=self.model_name,
                        messages=[
                            {"role": "system", "content": CLARIFIER_DIRECTION},
                            {"role": "user", "content": topic},
                        ],
                        max_tokens=2000,
                        stream=True,
                        temperature=1.0,
                        top_p=1.0,
                        response_format={
                            "type": "json_schema",
                            "json_schema": CLARIFICATION_SCHEMA,
                        } if attempt == 0 else None,
                    )
                    for chunk in stream:
                        delta = chunk.choices[0].delta
                        if hasattr(delta, 'content') and delta.content:
                            full_content += delta.content
                        # Capture DeepSeek-R1 style reasoning_content; wrap it
                        # in <think> tags (if not already present) so our
                        # parser can strip it later.
                        if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
                            if "<think>" not in full_content:
                                full_content += "<think>"
                            full_content += delta.reasoning_content
                        elif hasattr(delta, 'reasoning') and delta.reasoning:
                            if "<think>" not in full_content:
                                full_content += "<think>"
                            full_content += delta.reasoning
                    # Close the think tag if it was opened but never closed
                    # by the model.
                    if "<think>" in full_content and "</think>" not in full_content:
                        full_content += "</think>"
                except StopIteration:
                    if not full_content:
                        logger.error(
                            "Model %s returned an empty stream. It may not be supported on the current Inference API endpoint.",
                            self.model_name,
                        )
                    else:
                        logger.warning("Stream ended abruptly.")
                except Exception as stream_err:
                    logger.error("Streaming error from %s: %s", self.model_name, stream_err)

                if not full_content:
                    logger.warning("No content received on attempt %d", attempt + 1)
                    if attempt < max_retries - 1:
                        time.sleep(2 ** attempt)
                        continue
                    return []

                suggestions = self._parse_suggestions(full_content)
                if suggestions:
                    return suggestions
                logger.warning("Failed to parse suggestions from content on attempt %d", attempt + 1)
            except Exception as e:
                logger.error("Error during get_suggestions (Attempt %d): %s", attempt + 1, e)
            # Back off before the next attempt (parse failure or API error).
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                logger.info("Retrying in %d seconds...", wait_time)
                time.sleep(wait_time)
        return []

    def _parse_suggestions(self, content: str) -> List[Dict[str, str]]:
        """Extract a list of suggestion dicts from raw model output.

        Strips <think> blocks and markdown code fences, locates the JSON
        payload even when it is buried in surrounding prose, and validates
        the decoded structure.

        Returns:
            A list of suggestion dicts, or [] when nothing parseable is found.
        """
        if not content:
            return []
        logger.debug("Raw response: %s", content)
        clean_content = content.strip()
        # Remove thinking blocks if present.
        if "<think>" in clean_content:
            if "</think>" in clean_content:
                clean_content = clean_content.split("</think>")[-1].strip()
            else:
                # Think tag opened but never closed: take the text after it
                # and jump to the first JSON object.
                parts = clean_content.split("<think>")
                clean_content = parts[-1].strip()
                if "{" in clean_content:
                    clean_content = "{" + clean_content.split("{", 1)[1]
        # Extract JSON from potential code blocks.
        if "```json" in clean_content:
            clean_content = clean_content.split("```json")[1].split("```")[0].strip()
        elif "```" in clean_content:
            # Check if it looks like JSON inside a generic code block.
            block_content = clean_content.split("```")[1].split("```")[0].strip()
            if block_content.startswith("{") or "[" in block_content:
                clean_content = block_content
        # Final attempt to find a JSON-like structure if it's buried in text.
        if not (clean_content.startswith("{") or clean_content.startswith("[")):
            if "{" in clean_content:
                clean_content = "{" + clean_content.split("{", 1)[1]
                if "}" in clean_content:
                    clean_content = clean_content.rsplit("}", 1)[0] + "}"
        try:
            data = json.loads(clean_content)
        except json.JSONDecodeError as e:
            logger.error("JSON parsing failed: %s. Snippet: %s...", e, clean_content[:150])
            return []
        # Handle potential nested 'suggestions' key or direct list.
        if isinstance(data, dict):
            raw = data.get("suggestions", [])
        elif isinstance(data, list):
            raw = data
        else:
            raw = []
        # Bug fix: validate entries. A malformed payload (e.g. a list of
        # strings, or a non-list under "suggestions") previously flowed
        # through unchecked and crashed later at sug.get('title').
        if not isinstance(raw, list):
            raw = []
        suggestions = [s for s in raw if isinstance(s, dict)]
        logger.info("Successfully extracted %d suggestions", len(suggestions))
        return suggestions

    def clarify(self, topic: str) -> str:
        """Present suggestions to the user and return the chosen topic.

        Falls back to *topic* when no suggestions are available or stdin is
        closed (non-interactive environment).
        """
        suggestions = self.get_suggestions(topic)
        if not suggestions:
            logger.warning("No suggestions generated, using original topic.")
            return topic
        print("\n\033[93m--- Research Topic Suggestions ---\033[0m")
        for i, sug in enumerate(suggestions, 1):
            print(f"[{i}] {sug.get('title')}")
            print(f" {sug.get('description')}\n")
        print(f"[0] Use original: {topic}")
        print("\033[93m----------------------------------\033[0m")
        while True:
            try:
                user_input = input("Select a choice (number) or enter a custom topic: ").strip()
            except EOFError:
                # stdin closed (e.g. piped input exhausted): keep the original.
                return topic
            if user_input == "0":
                return topic
            if user_input.isdigit():
                idx = int(user_input) - 1
                if 0 <= idx < len(suggestions):
                    selected = suggestions[idx]
                    # Return the combined title and description as the
                    # refined topic.
                    final_topic = f"{selected.get('title')}: {selected.get('description')}"
                    logger.info("User selected suggestion %s", user_input)
                    return final_topic
                print(f"Please enter a number between 0 and {len(suggestions)}.")
                continue
            if user_input:
                # Treat non-numeric non-empty input as a custom topic.
                logger.info("User provided custom topic.")
                return user_input
            print("Input cannot be empty. Please select a number or type a topic.")
|