File size: 5,293 Bytes
61f47ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import logging
import time
import json
from huggingface_hub import InferenceClient
from .prompts import PLANNER_DIRECTION

# Module-level logger namespaced to this module, per stdlib convention.
logger = logging.getLogger(__name__)

class Planner:
    """Generate a research plan for a topic via an LLM on the HF Inference API.

    The model's streamed reply may interleave "reasoning" tokens with regular
    content; reasoning is wrapped in <think>...</think> markers while the
    stream is collected so it can be stripped before the JSON plan is parsed.
    """

    # Number of end-to-end attempts before giving up and returning "".
    MAX_RETRIES = 3

    def __init__(self, model_name: str, hf_key: str):
        self.model_name = model_name
        # Generous timeout: reasoning models can take a while to stream a plan.
        self.client = InferenceClient(token=hf_key, timeout=120)

    def plan(self, topic: str) -> str:
        """Return the research plan text for *topic*, or "" if all retries fail."""
        logger.info("Starting Planning: %s using model %s", topic, self.model_name)

        for attempt in range(self.MAX_RETRIES):
            try:
                full_content = self._stream_completion(topic)

                if not full_content:
                    logger.warning("Empty content received on attempt %d", attempt + 1)
                    # Fix: original retried immediately here with no backoff.
                    self._backoff(attempt)
                    continue

                research_plan = self._parse_plan(full_content)
                if not research_plan:
                    logger.warning("Failed to parse plan from content on attempt %d", attempt + 1)
                    self._backoff(attempt)
                    continue

                logger.info("Generated research plan")
                return research_plan
            except Exception as e:
                logger.error("Error during API call (Attempt %d/%d): %s",
                             attempt + 1, self.MAX_RETRIES, e)
                if attempt >= self.MAX_RETRIES - 1:
                    return ""
                self._backoff(attempt)
        return ""

    def _backoff(self, attempt: int) -> None:
        """Sleep with exponential backoff (1s, 2s, 4s, ...) before the next retry.

        No-op on the final attempt — there is no retry left to wait for.
        """
        if attempt >= self.MAX_RETRIES - 1:
            return
        wait_time = 2 ** attempt
        logger.info("Retrying in %d seconds...", wait_time)
        time.sleep(wait_time)

    def _stream_completion(self, topic: str) -> str:
        """Stream one chat completion and return the accumulated text.

        Reasoning deltas (``reasoning_content`` or ``reasoning`` attributes on
        the chunk delta, depending on the backend) are wrapped in a single
        <think>...</think> block. Returns "" when the stream yields nothing or
        fails; errors are logged, not raised, so the caller's retry loop decides.
        """
        full_content = ""
        try:
            stream = self.client.chat_completion(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": PLANNER_DIRECTION},
                    {"role": "user", "content": topic},
                ],
                max_tokens=4000,
                stream=True,
                temperature=1.0,
                top_p=1.0,
            )

            for chunk in stream:
                delta = chunk.choices[0].delta
                text = getattr(delta, "content", None)
                if text:
                    full_content += text
                # Some backends expose reasoning under different attribute names.
                reasoning = (getattr(delta, "reasoning_content", None)
                             or getattr(delta, "reasoning", None))
                if reasoning:
                    if "<think>" not in full_content:
                        full_content += "<think>"
                    full_content += reasoning

            # Close a reasoning block the model never terminated.
            if "<think>" in full_content and "</think>" not in full_content:
                full_content += "</think>"

        except StopIteration:
            # NOTE: in Python 3.7+ a StopIteration escaping a generator becomes
            # RuntimeError; kept for endpoints whose client raises it directly.
            if not full_content:
                logger.error("Model %s returned an empty stream. It may not be supported "
                             "on the current Inference API endpoint.", self.model_name)
            else:
                logger.warning("Stream ended with StopIteration.")
        except Exception as stream_err:
            logger.error("Streaming failed: %s", stream_err)

        return full_content

    def _parse_plan(self, content: str) -> str:
        """Extract the "research_plan" value from a (possibly noisy) model reply.

        Strips <think> blocks and markdown code fences, then isolates the
        outermost {...} span before parsing. Falls back to the raw (de-thinked)
        text when no valid JSON can be recovered.
        """
        if not content:
            return ""

        clean = content.strip()

        # Remove thinking blocks if present.
        if "<think>" in clean:
            if "</think>" in clean:
                clean = clean.split("</think>")[-1].strip()
            else:
                # Unterminated block: keep whatever follows the marker and
                # re-anchor on the first brace if one exists.
                clean = clean.split("<think>")[-1].strip()
                if "{" in clean:
                    clean = "{" + clean.split("{", 1)[1]

        # Extract JSON from potential code fences.
        if "```json" in clean:
            clean = clean.split("```json")[1].split("```")[0].strip()
        elif "```" in clean:
            fenced = clean.split("```")[1].strip()
            if fenced.startswith("{") or "research_plan" in fenced:
                clean = fenced

        # Isolate the outermost {...} span. Fix: the original only trimmed
        # leading/trailing junk when the text did NOT start with "{", so valid
        # JSON followed by trailing text failed to parse.
        start, end = clean.find("{"), clean.rfind("}")
        if start != -1 and end > start:
            clean = clean[start:end + 1]

        try:
            data = json.loads(clean)
            if isinstance(data, dict):
                # Missing key falls back to the full JSON text, as before.
                return data.get("research_plan", clean)
            return clean
        except json.JSONDecodeError:
            # Fallback for models that don't return valid JSON.
            return content.split("</think>")[-1].strip() if "</think>" in content else content.strip()