import logging
import time
import json

from huggingface_hub import InferenceClient

from .prompts import PLANNER_DIRECTION

logger = logging.getLogger(__name__)


class Planner:
    """Generates a research plan for a topic via a chat model on the HF Inference API."""

    def __init__(self, model_name: str, hf_key: str):
        self.model_name = model_name
        # 120 s timeout: reasoning models can take a long time to first token.
        self.client = InferenceClient(token=hf_key, timeout=120)

    def plan(self, topic: str) -> str:
        """Ask the model for a research plan on *topic*.

        Streams the chat completion (robust for reasoning/large models),
        wraps any reasoning deltas in <think>...</think> markers, then
        parses the accumulated text with :meth:`_parse_plan`.

        Returns the plan string, or "" after ``max_retries`` failures.
        """
        logger.info("Starting Planning: %s using model %s", topic, self.model_name)
        max_retries = 3
        for attempt in range(max_retries):
            try:
                full_content = ""
                # Use streaming for robustness with reasoning/large models.
                try:
                    stream = self.client.chat_completion(
                        model=self.model_name,
                        messages=[
                            {"role": "system", "content": PLANNER_DIRECTION},
                            {"role": "user", "content": topic},
                        ],
                        max_tokens=4000,
                        stream=True,
                        temperature=1.0,
                        top_p=1.0,
                    )
                    for chunk in stream:
                        delta = chunk.choices[0].delta
                        if hasattr(delta, "content") and delta.content:
                            full_content += delta.content
                        # Some providers expose reasoning as `reasoning_content`,
                        # others as `reasoning`; wrap either in <think> markers so
                        # _parse_plan can strip them uniformly.
                        if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                            if "<think>" not in full_content:
                                full_content += "<think>"
                            full_content += delta.reasoning_content
                        elif hasattr(delta, "reasoning") and delta.reasoning:
                            if "<think>" not in full_content:
                                full_content += "<think>"
                            full_content += delta.reasoning
                    # Close a dangling <think> block once the stream ends.
                    if "<think>" in full_content and "</think>" not in full_content:
                        full_content += "</think>"
                except StopIteration:
                    if not full_content:
                        logger.error(
                            "Model %s returned an empty stream. It may not be "
                            "supported on the current Inference API endpoint.",
                            self.model_name,
                        )
                    else:
                        logger.warning("Stream ended with StopIteration.")
                except Exception as stream_err:
                    logger.error("Streaming failed: %s", stream_err)

                if not full_content:
                    # NOTE: retries immediately (no backoff) on empty content,
                    # matching the original control flow.
                    logger.warning("Empty content received on attempt %d", attempt + 1)
                    continue

                research_plan = self._parse_plan(full_content)
                if not research_plan:
                    logger.warning(
                        "Failed to parse plan from content on attempt %d", attempt + 1
                    )
                    continue

                logger.info("Generated research plan")
                return research_plan
            except Exception as e:
                logger.error(
                    "Error during API call (Attempt %d/%d): %s",
                    attempt + 1, max_retries, e,
                )
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt  # exponential backoff: 1s, 2s
                    logger.info("Retrying in %d seconds...", wait_time)
                    time.sleep(wait_time)
                else:
                    return ""
        return ""

    def _parse_plan(self, content: str) -> str:
        """Extract the research plan from raw model output.

        Strips <think>...</think> reasoning blocks and markdown code fences,
        isolates the JSON object, and returns its ``research_plan`` value.
        Falls back to the cleaned raw text when the JSON is unparseable.
        """
        if not content:
            return ""

        clean_content = content.strip()

        # Remove thinking blocks if present.
        if "<think>" in clean_content:
            if "</think>" in clean_content:
                clean_content = clean_content.split("</think>")[-1].strip()
            else:
                # Unclosed think block: keep whatever follows the last opener,
                # then jump to the first JSON brace if there is one.
                clean_content = clean_content.split("<think>")[-1].strip()
                if "{" in clean_content:
                    clean_content = "{" + clean_content.split("{", 1)[1]

        # Extract JSON from potential code blocks.
        if "```json" in clean_content:
            clean_content = clean_content.split("```json")[1].split("```")[0].strip()
        elif "```" in clean_content:
            block_content = clean_content.split("```")[1].strip()
            if block_content.startswith("{") or "research_plan" in block_content:
                clean_content = block_content

        # Final attempt to isolate a JSON-like structure.
        if not clean_content.startswith("{") and "{" in clean_content:
            clean_content = "{" + clean_content.split("{", 1)[1]
        # Trim any trailing text after the last closing brace.
        if "}" in clean_content:
            clean_content = clean_content.rsplit("}", 1)[0] + "}"

        try:
            data = json.loads(clean_content)
            if isinstance(data, dict):
                return data.get("research_plan", clean_content)
            return clean_content
        except json.JSONDecodeError:
            # Fallback for models that don't return valid JSON: return the
            # text after the reasoning block, or the whole response.
            if "</think>" in content:
                return content.split("</think>")[-1].strip()
            return content.strip()