# NOTE: extracted from a Hugging Face Spaces file view (status shown: Sleeping; file size: 6,805 bytes; ref 61f47ab).
import os
import json
import logging
import time
from typing import List
from pydantic import Field, BaseModel
from huggingface_hub import InferenceClient
from .prompts import SPLITTER_DIRECTION
from pprint import pprint
logger = logging.getLogger(__name__)
class Subtask(BaseModel):
    # One unit of research work produced by the splitter.
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model's docstring into the generated JSON schema, which would
    # change the schema sent to the model.
    # All three fields are required (`...` means "no default").

    # Identifier of the subtask.
    id: str = Field(..., description="Short identifier for the subtask (e.g. 'A', 'history', 'drivers').")

    # Human-readable title.
    title: str = Field(..., description="Short descriptive title of the subtask.")

    # Instructions for whoever researches the subtask.
    description: str = Field(..., description="Clear, detailed instructions for the sub-agent that will research this subtask.")
class SubtaskList(BaseModel):
    # Top-level container for the splitter output: a required list of Subtask.
    # Kept docstring-free so the generated JSON schema is unchanged (pydantic
    # would add a class docstring to the schema as its description).
    subtasks: List[Subtask] = Field(..., description="List of subtasks that together cover the whole research plan.")
# Value passed as "json_schema" in the chat completion's response_format:
# a name, the JSON schema generated from SubtaskList, and the strict flag.
TASK_SPLITTER_SCHEMA = dict(
    name="subtaskList",
    schema=SubtaskList.model_json_schema(),
    strict=True,
)
class Splitter:
    """Break a research plan into subtasks by prompting a hosted chat model.

    The model receives ``SPLITTER_DIRECTION`` as the system prompt and the
    research plan as the user message. Its streamed answer is cleaned of
    reasoning blocks / surrounding prose and decoded into a list of dicts.
    """

    def __init__(self, model_name: str = "moonshotai/Kimi-K2-Thinking", hf_key: str = None):
        """Create the inference client.

        Args:
            model_name: Model identifier passed to ``chat_completion``.
            hf_key: API key. When falsy, the ``HF_KEY`` and then ``HF_TOKEN``
                environment variables are tried.
        """
        self.model_name = model_name
        self.hf_key = hf_key or os.getenv("HF_KEY") or os.getenv("HF_TOKEN")
        self.client = InferenceClient(
            api_key=self.hf_key,
        )

    def split(self, research_plan: str) -> List[dict]:
        """Ask the model to split ``research_plan`` and return the subtasks.

        Up to three attempts are made. Only the first attempt requests
        schema-constrained output; later attempts send no ``response_format``.
        After an empty answer or an unexpected error the method backs off
        for ``2 ** attempt`` seconds before retrying.

        Returns:
            The list of subtask dicts, or ``[]`` if every attempt failed.
        """
        logger.info(f"Splitting the research plan into subtasks using {self.model_name}...")
        max_retries = 3
        for attempt in range(max_retries):
            has_more_attempts = attempt < max_retries - 1
            try:
                full_content = self._stream_completion(research_plan, use_schema=(attempt == 0))
                if not full_content:
                    logger.warning(f"Empty content on attempt {attempt + 1}")
                    if has_more_attempts:
                        time.sleep(2 ** attempt)
                    continue
                subtasks = self._parse_subtasks(full_content)
                if subtasks:
                    return subtasks
            except Exception as e:
                logger.error(f"Error during split (Attempt {attempt + 1}): {e}")
                if has_more_attempts:
                    time.sleep(2 ** attempt)
        return []

    def _stream_completion(self, research_plan: str, use_schema: bool) -> str:
        """Run one streamed chat completion and return the answer text.

        Stream failures are logged, not raised: whatever answer text arrived
        before the failure is still returned so the caller can try to parse it.
        """
        pieces = []
        try:
            # Using chat_completion with streaming
            stream = self.client.chat_completion(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": SPLITTER_DIRECTION},
                    {"role": "user", "content": research_plan},
                ],
                response_format={
                    "type": "json_schema",
                    "json_schema": TASK_SPLITTER_SCHEMA,
                } if use_schema else None,
                max_tokens=4000,
                stream=True,
                temperature=0.6,
                top_p=0.95,
            )
            for text in self._iter_answer_text(stream):
                pieces.append(text)
        except StopIteration:
            if not pieces:
                logger.error(f"Model {self.model_name} returned an empty stream. It may not be supported on the current Inference API endpoint.")
            else:
                logger.warning("Stream ended with StopIteration.")
        except Exception as stream_err:
            logger.error(f"Stream error: {stream_err}")
        return "".join(pieces)

    @staticmethod
    def _iter_answer_text(stream):
        """Yield the non-empty ``delta.content`` strings of a chunk stream.

        Reasoning deltas (``reasoning_content`` / ``reasoning``) are skipped
        on purpose: they are not part of the answer, and splicing them into
        the answer buffer made it possible for a closing ``</think>`` marker
        to land after the JSON, which the parser then threw away.
        Chunks without any choices are skipped instead of raising IndexError.
        """
        for chunk in stream:
            choices = getattr(chunk, "choices", None)
            if not choices:
                continue
            delta = getattr(choices[0], "delta", None)
            text = getattr(delta, "content", None)
            if text:
                yield text

    def _parse_subtasks(self, content: str) -> List[dict]:
        """Extract the subtask list from raw model output.

        Accepts a JSON object with a ``subtasks`` list, or a bare JSON list,
        optionally wrapped in a reasoning block, markdown code fences, or
        leading/trailing prose. Items that are not dicts are dropped.

        Returns:
            The subtask dicts (also printed to stdout), or ``[]`` when no
            usable JSON is found.
        """
        if not content:
            return []
        clean_content = self._strip_thinking(content.strip())
        for value in self._iter_json_values(clean_content):
            subtasks = self._subtasks_from(value)
            if subtasks:
                self._print_subtasks(subtasks)
                return subtasks
        logger.error(f"Failed to parse JSON content: no JSON object or array with subtasks found. Snippet: {clean_content[:150]}...")
        return []

    @staticmethod
    def _strip_thinking(text: str) -> str:
        """Drop a leading reasoning block delimited by <think>...</think>."""
        if "</think>" in text:
            # Also covers output where the opening tag is missing.
            return text.split("</think>")[-1].strip()
        if "<think>" in text:
            # Unterminated block: keep what follows the last opening tag.
            return text.split("<think>")[-1].strip()
        return text

    @staticmethod
    def _iter_json_values(text: str):
        """Yield each JSON object/array that can be decoded from ``text``.

        Scans left to right for '{' or '[' and decodes from there, so code
        fences and prose before, between, or after the JSON are ignored.
        """
        decoder = json.JSONDecoder()
        pos = 0
        while pos < len(text):
            starts = [i for i in (text.find("{", pos), text.find("[", pos)) if i != -1]
            if not starts:
                return
            start = min(starts)
            try:
                value, end = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                # Not valid JSON from this bracket; try the next one.
                pos = start + 1
                continue
            yield value
            pos = end

    @staticmethod
    def _subtasks_from(value) -> List[dict]:
        """Return the dict items of a decoded subtask payload, else ``[]``."""
        if isinstance(value, dict):
            value = value.get('subtasks')
        if not isinstance(value, list):
            return []
        return [item for item in value if isinstance(item, dict)]

    @staticmethod
    def _print_subtasks(subtasks: List[dict]) -> None:
        """Print the subtasks to stdout (yellow ANSI headings)."""
        print("\n\033[93m--- Generated Subtasks ---\033[0m")
        for task in subtasks:
            print(f"\033[93mID: {task.get('id')} - {task.get('title')}\033[0m")
            pprint(task.get('description'))
            print()
if __name__ == "__main__":
    # Manual smoke test: split a sample plan and report how many subtasks came back.
    # Needs an HF_KEY or HF_TOKEN value, read from the environment or a .env file.
    from dotenv import load_dotenv

    load_dotenv()
    # Without a configured handler the module's info-level messages are not shown.
    logging.basicConfig(level=logging.INFO)
    test_plan = "Research the current state of Solid State Batteries, focusing on major players and technical hurdles."
    splitter = Splitter()
    result = splitter.split(test_plan)
    logger.info(f"Splitter returned {len(result)} subtask(s).")