File size: 9,140 Bytes
61f47ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import logging
import time
import json
from typing import List, Dict
from huggingface_hub import InferenceClient
from .prompts import CLARIFIER_DIRECTION

logger = logging.getLogger(__name__)

# JSON schema handed to chat_completion(response_format=...) to force the model
# to emit {"suggestions": [{"title", "description"}, ...]}.
# NOTE: strict structured-output modes (OpenAI-compatible providers) require
# "additionalProperties": false on every object level alongside "strict": true;
# without it the schema is not enforced strictly.
CLARIFICATION_SCHEMA = {
    "name": "clarification_suggestions",
    "schema": {
        "type": "object",
        "properties": {
            "suggestions": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "title": {"type": "string"},
                        "description": {"type": "string"}
                    },
                    "required": ["title", "description"],
                    "additionalProperties": False
                }
            }
        },
        "required": ["suggestions"],
        "additionalProperties": False
    },
    "strict": True
}

class Clarifier:
    """Interactively refines a research topic via LLM-generated suggestions.

    Asks the configured Hugging Face model for alternative framings of a
    topic (as structured JSON), then lets the user pick one on the CLI.
    """

    def __init__(self, model_name: str, hf_key: str):
        self.model_name = model_name
        # Generous timeout: thinking models can stream for a long time.
        self.client = InferenceClient(token=hf_key, timeout=120)

    def get_suggestions(self, topic: str) -> List[Dict[str, str]]:
        """Return a list of {"title", "description"} suggestions for *topic*.

        Retries up to 3 times with exponential backoff. Returns [] when no
        parsable suggestions could be obtained.
        """
        logger.info("Clarifying Topic: %s using model %s", topic, self.model_name)

        max_retries = 3
        for attempt in range(max_retries):
            try:
                logger.info("Attempt %d/%d using %s", attempt + 1, max_retries, self.model_name)

                # Only request a JSON schema on the first attempt; some models
                # reject response_format, so retries fall back to free text.
                full_content = self._stream_completion(topic, use_schema=attempt == 0)

                if not full_content:
                    logger.warning("No content received on attempt %d", attempt + 1)
                    if attempt < max_retries - 1:
                        time.sleep(2 ** attempt)
                        continue
                    return []

                suggestions = self._parse_suggestions(full_content)
                if suggestions:
                    return suggestions

                logger.warning("Failed to parse suggestions from content on attempt %d", attempt + 1)

            except Exception as e:
                logger.error("Error during get_suggestions (Attempt %d): %s", attempt + 1, e)
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt
                    logger.info("Retrying in %d seconds...", wait_time)
                    time.sleep(wait_time)

        return []

    def _stream_completion(self, topic: str, use_schema: bool) -> str:
        """Stream one chat completion and return the accumulated text.

        Reasoning tokens (DeepSeek-R1 style ``reasoning_content``/``reasoning``
        deltas) are collected in a separate buffer and returned wrapped in
        <think>...</think> *ahead of* the answer text. Accumulating them
        separately (rather than into one shared buffer) keeps the JSON answer
        outside the think block even when the model interleaves reasoning and
        content, so _parse_suggestions can strip the reasoning safely.
        """
        reasoning = ""
        content = ""
        try:
            # Streaming is more resilient to StopIteration/timeout issues on
            # thinking models than a single blocking call.
            stream = self.client.chat_completion(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": CLARIFIER_DIRECTION},
                    {"role": "user", "content": topic},
                ],
                max_tokens=2000,
                stream=True,
                temperature=1.0,
                top_p=1.0,
                # Fallback for models that don't support JSON schema.
                response_format={
                    "type": "json_schema",
                    "json_schema": CLARIFICATION_SCHEMA,
                } if use_schema else None,
            )

            for chunk in stream:
                # Some providers emit metadata-only chunks with no choices;
                # indexing choices[0] unconditionally would raise IndexError.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if getattr(delta, "content", None):
                    content += delta.content
                # Capture DeepSeek-R1 style reasoning deltas.
                if getattr(delta, "reasoning_content", None):
                    reasoning += delta.reasoning_content
                elif getattr(delta, "reasoning", None):
                    reasoning += delta.reasoning

        except StopIteration:
            if not (reasoning or content):
                logger.error(
                    "Model %s returned an empty stream. It may not be supported "
                    "on the current Inference API endpoint.",
                    self.model_name,
                )
            else:
                logger.warning("Stream ended abruptly.")
        except Exception as stream_err:
            logger.error("Streaming error from %s: %s", self.model_name, stream_err)

        # Wrap collected reasoning in think tags (unless the model already
        # emitted its own) so downstream parsing can discard it.
        if reasoning and "<think>" not in reasoning:
            reasoning = f"<think>{reasoning}</think>"

        full_content = reasoning + content
        # Close a think tag the model opened but never closed.
        if "<think>" in full_content and "</think>" not in full_content:
            full_content += "</think>"
        return full_content

    def _parse_suggestions(self, content: str) -> List[Dict[str, str]]:
        """Extract the suggestions list from raw model output.

        Strips <think> blocks and markdown code fences, trims surrounding
        prose, parses the JSON, and drops malformed entries so callers can
        rely on each item being a dict with "title" and "description".
        Returns [] when nothing parsable is found.
        """
        if not content:
            return []

        logger.debug("Raw response: %s", content)

        clean_content = content.strip()

        # Remove thinking blocks if present.
        if "<think>" in clean_content:
            if "</think>" in clean_content:
                clean_content = clean_content.split("</think>")[-1].strip()
            else:
                # Think opened but never closed: look for the first JSON
                # object after the tag.
                parts = clean_content.split("<think>")
                clean_content = parts[-1].strip()
                if "{" in clean_content:
                    clean_content = "{" + clean_content.split("{", 1)[1]

        # Extract JSON from fenced code blocks.
        if "```json" in clean_content:
            clean_content = clean_content.split("```json")[1].split("```")[0].strip()
        elif "```" in clean_content:
            # Only adopt a generic code block if it looks like JSON.
            block_content = clean_content.split("```")[1].split("```")[0].strip()
            if block_content.startswith("{") or "[" in block_content:
                clean_content = block_content

        # Last resort: trim leading/trailing prose around a JSON object.
        if not (clean_content.startswith("{") or clean_content.startswith("[")):
            if "{" in clean_content:
                clean_content = "{" + clean_content.split("{", 1)[1]
            if "}" in clean_content:
                clean_content = clean_content.rsplit("}", 1)[0] + "}"

        try:
            data = json.loads(clean_content)
        except json.JSONDecodeError as e:
            logger.error("JSON parsing failed: %s. Snippet: %s...", e, clean_content[:150])
            return []

        # Handle a nested 'suggestions' key or a direct list.
        if isinstance(data, dict):
            raw = data.get("suggestions", [])
        elif isinstance(data, list):
            raw = data
        else:
            raw = []

        # Drop malformed entries so clarify() can rely on dict access.
        suggestions = [
            s for s in raw
            if isinstance(s, dict) and "title" in s and "description" in s
        ]
        logger.info("Successfully extracted %d suggestions", len(suggestions))
        return suggestions

    def clarify(self, topic: str) -> str:
        """Present suggestions for *topic* on the CLI and return the choice.

        Returns the original topic when no suggestions are available, when the
        user picks option 0, or when input is interrupted (EOF / Ctrl-C).
        A non-numeric entry is treated as a custom replacement topic.
        """
        suggestions = self.get_suggestions(topic)

        if not suggestions:
            logger.warning("No suggestions generated, using original topic.")
            return topic

        print("\n\033[93m--- Research Topic Suggestions ---\033[0m")
        for i, sug in enumerate(suggestions, 1):
            print(f"[{i}] {sug.get('title')}")
            print(f"    {sug.get('description')}\n")
        print(f"[0] Use original: {topic}")
        print("\033[93m----------------------------------\033[0m")

        while True:
            try:
                user_input = input("Select a choice (number) or enter a custom topic: ").strip()

                if user_input == "0":
                    return topic

                if user_input.isdigit():
                    idx = int(user_input) - 1
                    if 0 <= idx < len(suggestions):
                        selected = suggestions[idx]
                        # Combined title + description becomes the refined topic.
                        final_topic = f"{selected.get('title')}: {selected.get('description')}"
                        logger.info("User selected suggestion %s", user_input)
                        return final_topic
                    print(f"Please enter a number between 0 and {len(suggestions)}.")
                    continue

                if user_input:
                    # Non-numeric, non-empty input is a custom topic.
                    logger.info("User provided custom topic.")
                    return user_input
                print("Input cannot be empty. Please select a number or type a topic.")

            except (EOFError, KeyboardInterrupt):
                # No usable interactive input: keep the original topic
                # instead of crashing the CLI.
                return topic