Spaces:
Sleeping
Sleeping
Commit ·
2fd8593
0
Parent(s):
initial commit
Browse files- Dockerfile +10 -0
- app.py +1 -0
- codes/0_blog_process.py +136 -0
- codes/1_1_extract_config.py +66 -0
- codes/1_planning.py +385 -0
- codes/2_analyzing.py +228 -0
- codes/3_coding.py +249 -0
- codes/eval.py +277 -0
- codes/example_use_gemma.py +27 -0
- codes/llm_provider.py +342 -0
- codes/rate_limiter.py +97 -0
- codes/test_gemma.py +39 -0
- codes/utils.py +440 -0
- main.py +136 -0
- requirements.txt +13 -0
Dockerfile
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
COPY requirements.txt .
|
| 6 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 7 |
+
|
| 8 |
+
COPY . .
|
| 9 |
+
|
| 10 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from main import app
|
codes/0_blog_process.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
import requests
|
| 4 |
+
from bs4 import BeautifulSoup
|
| 5 |
+
import markdown
|
| 6 |
+
from urllib.parse import urlparse
|
| 7 |
+
|
| 8 |
+
def fetch_blog_from_url(url):
|
| 9 |
+
"""Fetch blog content from URL"""
|
| 10 |
+
try:
|
| 11 |
+
# Add user agent to avoid 403 errors from sites like Medium
|
| 12 |
+
headers = {
|
| 13 |
+
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
|
| 14 |
+
}
|
| 15 |
+
response = requests.get(url, timeout=30, headers=headers)
|
| 16 |
+
response.raise_for_status()
|
| 17 |
+
soup = BeautifulSoup(response.content, 'html.parser')
|
| 18 |
+
|
| 19 |
+
# Extract main content (adjust selectors based on blog platform)
|
| 20 |
+
title = soup.find('h1').get_text() if soup.find('h1') else "Untitled"
|
| 21 |
+
|
| 22 |
+
# Common content selectors - try multiple strategies
|
| 23 |
+
content = (soup.find('article') or
|
| 24 |
+
soup.find('main') or
|
| 25 |
+
soup.find('div', class_='content') or
|
| 26 |
+
soup.find('div', class_='post-content') or
|
| 27 |
+
soup.find('div', class_='entry-content'))
|
| 28 |
+
|
| 29 |
+
if content:
|
| 30 |
+
text = content.get_text(separator='\n', strip=True)
|
| 31 |
+
# Also extract code blocks separately
|
| 32 |
+
code_blocks = content.find_all(['pre', 'code'])
|
| 33 |
+
codes = [block.get_text() for block in code_blocks]
|
| 34 |
+
else:
|
| 35 |
+
text = soup.get_text(separator='\n', strip=True)
|
| 36 |
+
codes = []
|
| 37 |
+
|
| 38 |
+
return {
|
| 39 |
+
'title': title,
|
| 40 |
+
'url': url,
|
| 41 |
+
'content': text,
|
| 42 |
+
'code_snippets': codes
|
| 43 |
+
}
|
| 44 |
+
except Exception as e:
|
| 45 |
+
print(f"[ERROR] Failed to fetch URL: {e}")
|
| 46 |
+
raise
|
| 47 |
+
|
| 48 |
+
def process_markdown_file(file_path):
|
| 49 |
+
"""Process markdown blog file"""
|
| 50 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 51 |
+
md_content = f.read()
|
| 52 |
+
|
| 53 |
+
# Convert markdown to HTML then extract text
|
| 54 |
+
html = markdown.markdown(md_content, extensions=['fenced_code', 'codehilite'])
|
| 55 |
+
soup = BeautifulSoup(html, 'html.parser')
|
| 56 |
+
|
| 57 |
+
# Extract title (first h1)
|
| 58 |
+
title = soup.find('h1')
|
| 59 |
+
title_text = title.get_text() if title else "Untitled"
|
| 60 |
+
|
| 61 |
+
# Extract code blocks
|
| 62 |
+
code_blocks = soup.find_all(['pre', 'code'])
|
| 63 |
+
codes = [block.get_text() for block in code_blocks]
|
| 64 |
+
|
| 65 |
+
return {
|
| 66 |
+
'title': title_text,
|
| 67 |
+
'content': md_content,
|
| 68 |
+
'html': html,
|
| 69 |
+
'code_snippets': codes
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
def process_html_file(file_path):
|
| 73 |
+
"""Process HTML blog file"""
|
| 74 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 75 |
+
html_content = f.read()
|
| 76 |
+
|
| 77 |
+
soup = BeautifulSoup(html_content, 'html.parser')
|
| 78 |
+
title = soup.find('h1').get_text() if soup.find('h1') else "Untitled"
|
| 79 |
+
text = soup.get_text(separator='\n', strip=True)
|
| 80 |
+
|
| 81 |
+
# Extract code blocks
|
| 82 |
+
code_blocks = soup.find_all(['pre', 'code'])
|
| 83 |
+
codes = [block.get_text() for block in code_blocks]
|
| 84 |
+
|
| 85 |
+
return {
|
| 86 |
+
'title': title,
|
| 87 |
+
'content': text,
|
| 88 |
+
'html': html_content,
|
| 89 |
+
'code_snippets': codes
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
def process_text_file(file_path):
|
| 93 |
+
"""Process plain text file"""
|
| 94 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 95 |
+
content = f.read()
|
| 96 |
+
|
| 97 |
+
return {
|
| 98 |
+
'title': 'Blog Post',
|
| 99 |
+
'content': content,
|
| 100 |
+
'code_snippets': []
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
def main(args):
|
| 104 |
+
if args.url:
|
| 105 |
+
print(f"[INFO] Fetching blog from URL: {args.url}")
|
| 106 |
+
blog_data = fetch_blog_from_url(args.url)
|
| 107 |
+
elif args.input_path:
|
| 108 |
+
print(f"[INFO] Processing local file: {args.input_path}")
|
| 109 |
+
if args.input_path.endswith('.md'):
|
| 110 |
+
blog_data = process_markdown_file(args.input_path)
|
| 111 |
+
elif args.input_path.endswith('.html'):
|
| 112 |
+
blog_data = process_html_file(args.input_path)
|
| 113 |
+
else:
|
| 114 |
+
# Plain text
|
| 115 |
+
blog_data = process_text_file(args.input_path)
|
| 116 |
+
else:
|
| 117 |
+
print("[ERROR] Must provide either --url or --input_path")
|
| 118 |
+
return
|
| 119 |
+
|
| 120 |
+
# Save as JSON
|
| 121 |
+
with open(args.output_json_path, 'w', encoding='utf-8') as f:
|
| 122 |
+
json.dump(blog_data, f, indent=2, ensure_ascii=False)
|
| 123 |
+
|
| 124 |
+
print(f"[SAVED] {args.output_json_path}")
|
| 125 |
+
print(f"[INFO] Title: {blog_data['title']}")
|
| 126 |
+
print(f"[INFO] Content length: {len(blog_data['content'])} characters")
|
| 127 |
+
print(f"[INFO] Code snippets found: {len(blog_data.get('code_snippets', []))}")
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
parser = argparse.ArgumentParser(description="Process blog posts into JSON format for Blog2Code")
|
| 131 |
+
parser.add_argument("--url", type=str, help="Blog URL to fetch")
|
| 132 |
+
parser.add_argument("--input_path", type=str, help="Local blog file path (.md, .html, or .txt)")
|
| 133 |
+
parser.add_argument("--output_json_path", type=str, required=True, help="Output JSON file path")
|
| 134 |
+
|
| 135 |
+
args = parser.parse_args()
|
| 136 |
+
main(args)
|
codes/1_1_extract_config.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
import shutil
|
| 6 |
+
from utils import extract_planning, content_to_json, format_json_data
|
| 7 |
+
|
| 8 |
+
parser = argparse.ArgumentParser()
|
| 9 |
+
|
| 10 |
+
parser.add_argument('--paper_name',type=str)
|
| 11 |
+
parser.add_argument('--output_dir',type=str, default="")
|
| 12 |
+
|
| 13 |
+
args = parser.parse_args()
|
| 14 |
+
|
| 15 |
+
output_dir = args.output_dir
|
| 16 |
+
|
| 17 |
+
with open(f'{output_dir}/planning_trajectories.json', encoding='utf8') as f:
|
| 18 |
+
traj = json.load(f)
|
| 19 |
+
|
| 20 |
+
yaml_raw_content = ""
|
| 21 |
+
for turn_idx, turn in enumerate(traj):
|
| 22 |
+
if turn_idx == 8:
|
| 23 |
+
yaml_raw_content = turn['content']
|
| 24 |
+
|
| 25 |
+
if "</think>" in yaml_raw_content:
|
| 26 |
+
yaml_raw_content = yaml_raw_content.split("</think>")[-1]
|
| 27 |
+
|
| 28 |
+
match = re.search(r"```yaml\n(.*?)\n```", yaml_raw_content, re.DOTALL)
|
| 29 |
+
if match:
|
| 30 |
+
yaml_content = match.group(1)
|
| 31 |
+
with open(f'{output_dir}/planning_config.yaml', 'w', encoding='utf8') as f:
|
| 32 |
+
f.write(yaml_content)
|
| 33 |
+
else:
|
| 34 |
+
# print("No YAML content found.")
|
| 35 |
+
match2 = re.search(r"```yaml\\n(.*?)\\n```", yaml_raw_content, re.DOTALL)
|
| 36 |
+
if match2:
|
| 37 |
+
yaml_content = match2.group(1)
|
| 38 |
+
with open(f'{output_dir}/planning_config.yaml', 'w', encoding='utf8') as f:
|
| 39 |
+
f.write(yaml_content)
|
| 40 |
+
else:
|
| 41 |
+
print("No YAML content found.")
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------
|
| 44 |
+
|
| 45 |
+
artifact_output_dir=f"{output_dir}/planning_artifacts"
|
| 46 |
+
|
| 47 |
+
os.makedirs(artifact_output_dir, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
context_lst = extract_planning(f'{output_dir}/planning_trajectories.json')
|
| 50 |
+
|
| 51 |
+
arch_design = content_to_json(context_lst[1])
|
| 52 |
+
logic_design = content_to_json(context_lst[2])
|
| 53 |
+
|
| 54 |
+
formatted_arch_design = format_json_data(arch_design)
|
| 55 |
+
formatted_logic_design = format_json_data(logic_design)
|
| 56 |
+
|
| 57 |
+
with open(f"{artifact_output_dir}/1.1_overall_plan.txt", "w", encoding="utf-8") as f:
|
| 58 |
+
f.write(context_lst[0])
|
| 59 |
+
|
| 60 |
+
with open(f"{artifact_output_dir}/1.2_arch_design.txt", "w", encoding="utf-8") as f:
|
| 61 |
+
f.write(formatted_arch_design)
|
| 62 |
+
|
| 63 |
+
with open(f"{artifact_output_dir}/1.3_logic_design.txt", "w", encoding="utf-8") as f:
|
| 64 |
+
f.write(formatted_logic_design)
|
| 65 |
+
|
| 66 |
+
shutil.copy(f"{output_dir}/planning_config.yaml", f"{artifact_output_dir}/1.4_config.yaml")
|
codes/1_planning.py
ADDED
|
@@ -0,0 +1,385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from utils import print_response, print_log_cost, load_accumulated_cost, save_accumulated_cost
|
| 7 |
+
from rate_limiter import RateLimiter, estimate_tokens
|
| 8 |
+
from llm_provider import get_provider, get_default_model
|
| 9 |
+
|
| 10 |
+
parser = argparse.ArgumentParser()
|
| 11 |
+
|
| 12 |
+
# Support both paper and blog inputs
|
| 13 |
+
parser.add_argument('--paper_name', type=str, help='Name of the paper (deprecated, use --content_name)')
|
| 14 |
+
parser.add_argument('--blog_name', type=str, help='Name of the blog')
|
| 15 |
+
parser.add_argument('--content_name', type=str, help='Name of the content (paper or blog)')
|
| 16 |
+
parser.add_argument('--gpt_version', type=str, help='Model version (deprecated, use --model)')
|
| 17 |
+
parser.add_argument('--model', type=str, help='Model name (e.g., gpt-4o-mini, gemini-1.5-flash)')
|
| 18 |
+
parser.add_argument('--provider', type=str, default='gemini', choices=['openai', 'gemini', 'gemma'], help='LLM provider to use')
|
| 19 |
+
parser.add_argument('--paper_format', type=str, default="JSON", choices=["JSON", "LaTeX"], help='Format for papers')
|
| 20 |
+
parser.add_argument('--blog_format', type=str, default="JSON", choices=["JSON", "Markdown", "HTML"], help='Format for blogs')
|
| 21 |
+
parser.add_argument('--content_format', type=str, default="JSON", help='Format of the content')
|
| 22 |
+
parser.add_argument('--pdf_json_path', type=str, help='Path to paper JSON file')
|
| 23 |
+
parser.add_argument('--pdf_latex_path', type=str, help='Path to paper LaTeX file')
|
| 24 |
+
parser.add_argument('--blog_json_path', type=str, help='Path to blog JSON file')
|
| 25 |
+
parser.add_argument('--blog_md_path', type=str, help='Path to blog Markdown file')
|
| 26 |
+
parser.add_argument('--blog_html_path', type=str, help='Path to blog HTML file')
|
| 27 |
+
parser.add_argument('--content_type', type=str, default="paper", choices=["paper", "blog"], help='Type of content to process')
|
| 28 |
+
parser.add_argument('--output_dir', type=str, default="")
|
| 29 |
+
|
| 30 |
+
args = parser.parse_args()
|
| 31 |
+
|
| 32 |
+
# Initialize LLM provider
|
| 33 |
+
provider_name = args.provider
|
| 34 |
+
llm_provider = get_provider(provider_name)
|
| 35 |
+
model = args.model or args.gpt_version or get_default_model(provider_name)
|
| 36 |
+
|
| 37 |
+
print(f"🤖 Using {provider_name.upper()} with model: {model}")
|
| 38 |
+
|
| 39 |
+
# Determine content type and set variables
|
| 40 |
+
if args.blog_name or args.blog_json_path or args.blog_md_path or args.blog_html_path:
|
| 41 |
+
content_type = "blog"
|
| 42 |
+
content_name = args.blog_name or args.content_name or "BlogPost"
|
| 43 |
+
content_format = args.blog_format or args.content_format
|
| 44 |
+
content_path = args.blog_json_path or args.blog_md_path or args.blog_html_path
|
| 45 |
+
else:
|
| 46 |
+
content_type = args.content_type
|
| 47 |
+
content_name = args.paper_name or args.content_name or "Paper"
|
| 48 |
+
content_format = args.paper_format or args.content_format
|
| 49 |
+
content_path = args.pdf_json_path or args.pdf_latex_path
|
| 50 |
+
|
| 51 |
+
gpt_version = args.gpt_version
|
| 52 |
+
output_dir = args.output_dir
|
| 53 |
+
|
| 54 |
+
# Create output directory if it doesn't exist
|
| 55 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 56 |
+
|
| 57 |
+
# Load content based on format
|
| 58 |
+
if content_format in ["JSON"]:
|
| 59 |
+
with open(f'{content_path}') as f:
|
| 60 |
+
content_data = json.load(f)
|
| 61 |
+
elif content_format in ["LaTeX", "Markdown", "HTML"]:
|
| 62 |
+
with open(f'{content_path}') as f:
|
| 63 |
+
content_data = f.read()
|
| 64 |
+
else:
|
| 65 |
+
print(f"[ERROR] Invalid format. Please select JSON, LaTeX, Markdown, or HTML.")
|
| 66 |
+
sys.exit(0)
|
| 67 |
+
|
| 68 |
+
if content_type == "blog":
|
| 69 |
+
plan_msg = [
|
| 70 |
+
{'role': "system", "content": f"""You are an expert software engineer and technical content analyst with deep understanding of tutorial implementation and code reproduction.
|
| 71 |
+
You will receive a technical blog post in {content_format} format.
|
| 72 |
+
Your task is to create a detailed and efficient plan to implement the code, algorithms, or systems described in the blog post.
|
| 73 |
+
This plan should align precisely with the blog's tutorial steps, code examples, and technical specifications.
|
| 74 |
+
|
| 75 |
+
Instructions:
|
| 76 |
+
|
| 77 |
+
1. Align with the Blog: Your plan must strictly follow the methods, code examples, configurations, and implementation steps described in the blog.
|
| 78 |
+
2. Extract Code Snippets: Identify and organize any existing code snippets from the blog.
|
| 79 |
+
3. Fill Gaps: Identify missing implementation details that need to be inferred or completed.
|
| 80 |
+
4. Be Clear and Structured: Present the plan in a well-organized and easy-to-follow format, breaking it down into actionable steps.
|
| 81 |
+
5. Prioritize Efficiency: Optimize the plan for clarity and practical implementation while ensuring fidelity to the original tutorial.
|
| 82 |
+
6. Add Production Features: Plan for error handling, logging, testing, and documentation that may not be in the blog."""},
|
| 83 |
+
{"role": "user",
|
| 84 |
+
"content" : f"""## Blog Post
|
| 85 |
+
{content_data}
|
| 86 |
+
|
| 87 |
+
## Task
|
| 88 |
+
1. We want to implement the tutorial/system described in this blog post.
|
| 89 |
+
2. The blog may contain partial code snippets that we need to organize and complete.
|
| 90 |
+
3. Before writing the final code, please outline a comprehensive plan that covers:
|
| 91 |
+
- Key implementation steps from the blog
|
| 92 |
+
- Code architecture and structure
|
| 93 |
+
- Dependencies and libraries mentioned
|
| 94 |
+
- Configuration requirements
|
| 95 |
+
- Any code snippets already provided in the blog
|
| 96 |
+
- Missing details that need to be inferred or completed
|
| 97 |
+
4. The plan should be as **detailed and practical** as possible to help us write production-ready code.
|
| 98 |
+
|
| 99 |
+
## Requirements
|
| 100 |
+
- Extract and organize any existing code snippets from the blog
|
| 101 |
+
- Identify gaps in the blog's explanation that need to be filled
|
| 102 |
+
- Focus on creating a **working, complete implementation**
|
| 103 |
+
- If something is unclear from the blog, mention it explicitly and suggest reasonable defaults
|
| 104 |
+
|
| 105 |
+
## Instruction
|
| 106 |
+
The response should give us a strong roadmap for turning this blog tutorial into production code."""}]
|
| 107 |
+
else:
|
| 108 |
+
plan_msg = [
|
| 109 |
+
{'role': "system", "content": f"""You are an expert researcher and strategic planner with a deep understanding of experimental design and reproducibility in scientific research.
|
| 110 |
+
You will receive a research paper in {content_format} format.
|
| 111 |
+
Your task is to create a detailed and efficient plan to reproduce the experiments and methodologies described in the paper.
|
| 112 |
+
This plan should align precisely with the paper's methodology, experimental setup, and evaluation metrics.
|
| 113 |
+
|
| 114 |
+
Instructions:
|
| 115 |
+
|
| 116 |
+
1. Align with the Paper: Your plan must strictly follow the methods, datasets, model configurations, hyperparameters, and experimental setups described in the paper.
|
| 117 |
+
2. Be Clear and Structured: Present the plan in a well-organized and easy-to-follow format, breaking it down into actionable steps.
|
| 118 |
+
3. Prioritize Efficiency: Optimize the plan for clarity and practical implementation while ensuring fidelity to the original experiments."""},
|
| 119 |
+
{"role": "user",
|
| 120 |
+
"content" : f"""## Paper
|
| 121 |
+
{content_data}
|
| 122 |
+
|
| 123 |
+
## Task
|
| 124 |
+
1. We want to reproduce the method described in the attached paper.
|
| 125 |
+
2. The authors did not release any official code, so we have to plan our own implementation.
|
| 126 |
+
3. Before writing any Python code, please outline a comprehensive plan that covers:
|
| 127 |
+
- Key details from the paper's **Methodology**.
|
| 128 |
+
- Important aspects of **Experiments**, including dataset requirements, experimental settings, hyperparameters, or evaluation metrics.
|
| 129 |
+
4. The plan should be as **detailed and informative** as possible to help us write the final code later.
|
| 130 |
+
|
| 131 |
+
## Requirements
|
| 132 |
+
- You don't need to provide the actual code yet; focus on a **thorough, clear strategy**.
|
| 133 |
+
- If something is unclear from the paper, mention it explicitly.
|
| 134 |
+
|
| 135 |
+
## Instruction
|
| 136 |
+
The response should give us a strong roadmap, making it easier to write the code later."""}]
|
| 137 |
+
|
| 138 |
+
file_list_msg = [
|
| 139 |
+
{"role": "user", "content": """Your goal is to create a concise, usable, and complete software system design for reproducing the paper's method. Use appropriate open-source libraries and keep the overall architecture simple.
|
| 140 |
+
|
| 141 |
+
Based on the plan for reproducing the paper’s main method, please design a concise, usable, and complete software system.
|
| 142 |
+
Keep the architecture simple and make effective use of open-source libraries.
|
| 143 |
+
|
| 144 |
+
-----
|
| 145 |
+
|
| 146 |
+
## Format Example
|
| 147 |
+
[CONTENT]
|
| 148 |
+
{
|
| 149 |
+
"Implementation approach": "We will ... ,
|
| 150 |
+
"File list": [
|
| 151 |
+
"main.py",
|
| 152 |
+
"dataset_loader.py",
|
| 153 |
+
"model.py",
|
| 154 |
+
"trainer.py",
|
| 155 |
+
"evaluation.py"
|
| 156 |
+
],
|
| 157 |
+
"Data structures and interfaces": "\nclassDiagram\n class Main {\n +__init__()\n +run_experiment()\n }\n class DatasetLoader {\n +__init__(config: dict)\n +load_data() -> Any\n }\n class Model {\n +__init__(params: dict)\n +forward(x: Tensor) -> Tensor\n }\n class Trainer {\n +__init__(model: Model, data: Any)\n +train() -> None\n }\n class Evaluation {\n +__init__(model: Model, data: Any)\n +evaluate() -> dict\n }\n Main --> DatasetLoader\n Main --> Trainer\n Main --> Evaluation\n Trainer --> Model\n",
|
| 158 |
+
"Program call flow": "\nsequenceDiagram\n participant M as Main\n participant DL as DatasetLoader\n participant MD as Model\n participant TR as Trainer\n participant EV as Evaluation\n M->>DL: load_data()\n DL-->>M: return dataset\n M->>MD: initialize model()\n M->>TR: train(model, dataset)\n TR->>MD: forward(x)\n MD-->>TR: predictions\n TR-->>M: training complete\n M->>EV: evaluate(model, dataset)\n EV->>MD: forward(x)\n MD-->>EV: predictions\n EV-->>M: metrics\n",
|
| 159 |
+
"Anything UNCLEAR": "Need clarification on the exact dataset format and any specialized hyperparameters."
|
| 160 |
+
}
|
| 161 |
+
[/CONTENT]
|
| 162 |
+
|
| 163 |
+
## Nodes: "<node>: <type> # <instruction>"
|
| 164 |
+
- Implementation approach: <class 'str'> # Summarize the chosen solution strategy.
|
| 165 |
+
- File list: typing.List[str] # Only need relative paths. ALWAYS write a main.py or app.py here.
|
| 166 |
+
- Data structures and interfaces: typing.Optional[str] # Use mermaid classDiagram code syntax, including classes, method(__init__ etc.) and functions with type annotations, CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.
|
| 167 |
+
- Program call flow: typing.Optional[str] # Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.
|
| 168 |
+
- Anything UNCLEAR: <class 'str'> # Mention ambiguities and ask for clarifications.
|
| 169 |
+
|
| 170 |
+
## Constraint
|
| 171 |
+
Format: output wrapped inside [CONTENT][/CONTENT] like the format example, nothing else.
|
| 172 |
+
|
| 173 |
+
## Action
|
| 174 |
+
Follow the instructions for the nodes, generate the output, and ensure it follows the format example."""}
|
| 175 |
+
]
|
| 176 |
+
|
| 177 |
+
task_list_msg = [
|
| 178 |
+
{'role': 'user', 'content': """Your goal is break down tasks according to PRD/technical design, generate a task list, and analyze task dependencies.
|
| 179 |
+
You will break down tasks, analyze dependencies.
|
| 180 |
+
|
| 181 |
+
You outline a clear PRD/technical design for reproducing the paper’s method and experiments.
|
| 182 |
+
|
| 183 |
+
Now, let's break down tasks according to PRD/technical design, generate a task list, and analyze task dependencies.
|
| 184 |
+
The Logic Analysis should not only consider the dependencies between files but also provide detailed descriptions to assist in writing the code needed to reproduce the paper.
|
| 185 |
+
|
| 186 |
+
-----
|
| 187 |
+
|
| 188 |
+
## Format Example
|
| 189 |
+
[CONTENT]
|
| 190 |
+
{
|
| 191 |
+
"Required packages": [
|
| 192 |
+
"numpy==1.21.0",
|
| 193 |
+
"torch==1.9.0"
|
| 194 |
+
],
|
| 195 |
+
"Required Other language third-party packages": [
|
| 196 |
+
"No third-party dependencies required"
|
| 197 |
+
],
|
| 198 |
+
"Logic Analysis": [
|
| 199 |
+
[
|
| 200 |
+
"data_preprocessing.py",
|
| 201 |
+
"DataPreprocessing class ........"
|
| 202 |
+
],
|
| 203 |
+
[
|
| 204 |
+
"trainer.py",
|
| 205 |
+
"Trainer ....... "
|
| 206 |
+
],
|
| 207 |
+
[
|
| 208 |
+
"dataset_loader.py",
|
| 209 |
+
"Handles loading and ........"
|
| 210 |
+
],
|
| 211 |
+
[
|
| 212 |
+
"model.py",
|
| 213 |
+
"Defines the model ......."
|
| 214 |
+
],
|
| 215 |
+
[
|
| 216 |
+
"evaluation.py",
|
| 217 |
+
"Evaluation class ........ "
|
| 218 |
+
],
|
| 219 |
+
[
|
| 220 |
+
"main.py",
|
| 221 |
+
"Entry point ......."
|
| 222 |
+
]
|
| 223 |
+
],
|
| 224 |
+
"Task list": [
|
| 225 |
+
"dataset_loader.py",
|
| 226 |
+
"model.py",
|
| 227 |
+
"trainer.py",
|
| 228 |
+
"evaluation.py",
|
| 229 |
+
"main.py"
|
| 230 |
+
],
|
| 231 |
+
"Full API spec": "openapi: 3.0.0 ...",
|
| 232 |
+
"Shared Knowledge": "Both data_preprocessing.py and trainer.py share ........",
|
| 233 |
+
"Anything UNCLEAR": "Clarification needed on recommended hardware configuration for large-scale experiments."
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
[/CONTENT]
|
| 237 |
+
|
| 238 |
+
## Nodes: "<node>: <type> # <instruction>"
|
| 239 |
+
- Required packages: typing.Optional[typing.List[str]] # Provide required third-party packages in requirements.txt format.(e.g., 'numpy==1.21.0').
|
| 240 |
+
- Required Other language third-party packages: typing.List[str] # List down packages required for non-Python languages. If none, specify "No third-party dependencies required".
|
| 241 |
+
- Logic Analysis: typing.List[typing.List[str]] # Provide a list of files with the classes/methods/functions to be implemented, including dependency analysis and imports. Include as much detailed description as possible.
|
| 242 |
+
- Task list: typing.List[str] # Break down the tasks into a list of filenames, prioritized based on dependency order. The task list must include the previously generated file list.
|
| 243 |
+
- Full API spec: <class 'str'> # Describe all APIs using OpenAPI 3.0 spec that may be used by both frontend and backend. If front-end and back-end communication is not required, leave it blank.
|
| 244 |
+
- Shared Knowledge: <class 'str'> # Detail any shared knowledge, like common utility functions or configuration variables.
|
| 245 |
+
- Anything UNCLEAR: <class 'str'> # Mention any unresolved questions or clarifications needed from the paper or project scope.
|
| 246 |
+
|
| 247 |
+
## Constraint
|
| 248 |
+
Format: output wrapped inside [CONTENT][/CONTENT] like the format example, nothing else.
|
| 249 |
+
|
| 250 |
+
## Action
|
| 251 |
+
Follow the node instructions above, generate your output accordingly, and ensure it follows the given format example."""}]
|
| 252 |
+
|
| 253 |
+
# config
|
| 254 |
+
config_msg = [
|
| 255 |
+
{'role': 'user', 'content': """You write elegant, modular, and maintainable code. Adhere to Google-style guidelines.
|
| 256 |
+
|
| 257 |
+
Based on the paper, plan, design specified previously, follow the "Format Example" and generate the code.
|
| 258 |
+
Extract the training details from the above paper (e.g., learning rate, batch size, epochs, etc.), follow the "Format example" and generate the code.
|
| 259 |
+
DO NOT FABRICATE DETAILS — only use what the paper provides.
|
| 260 |
+
|
| 261 |
+
You must write `config.yaml`.
|
| 262 |
+
|
| 263 |
+
ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Your output format must follow the example below exactly.
|
| 264 |
+
|
| 265 |
+
-----
|
| 266 |
+
|
| 267 |
+
# Format Example
|
| 268 |
+
## Code: config.yaml
|
| 269 |
+
```yaml
|
| 270 |
+
## config.yaml
|
| 271 |
+
training:
|
| 272 |
+
learning_rate: ...
|
| 273 |
+
batch_size: ...
|
| 274 |
+
epochs: ...
|
| 275 |
+
...
|
| 276 |
+
```
|
| 277 |
+
|
| 278 |
+
-----
|
| 279 |
+
|
| 280 |
+
## Code: config.yaml
|
| 281 |
+
"""
|
| 282 |
+
}]
|
| 283 |
+
|
| 284 |
+
def api_call(msg, model_name):
|
| 285 |
+
"""Make API call using the configured provider"""
|
| 286 |
+
# Special handling for o3-mini reasoning effort
|
| 287 |
+
if "o3-mini" in model_name and provider_name == 'openai':
|
| 288 |
+
completion = llm_provider.create_completion(
|
| 289 |
+
messages=msg,
|
| 290 |
+
model=model_name,
|
| 291 |
+
reasoning_effort="high"
|
| 292 |
+
)
|
| 293 |
+
else:
|
| 294 |
+
completion = llm_provider.create_completion(
|
| 295 |
+
messages=msg,
|
| 296 |
+
model=model_name
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
return completion
|
| 300 |
+
|
| 301 |
+
responses = []
|
| 302 |
+
trajectories = []
|
| 303 |
+
total_accumulated_cost = 0
|
| 304 |
+
|
| 305 |
+
# Initialize rate limiter to avoid hitting TPM limits
|
| 306 |
+
rate_limiter = RateLimiter(max_tokens_per_minute=95000) # 95K with 5K buffer
|
| 307 |
+
print("🛡️ Rate limiter initialized (95K TPM limit)")
|
| 308 |
+
|
| 309 |
+
for idx, instruction_msg in enumerate([plan_msg, file_list_msg, task_list_msg, config_msg]):
|
| 310 |
+
current_stage = ""
|
| 311 |
+
if idx == 0 :
|
| 312 |
+
current_stage = f"[Planning] Overall plan"
|
| 313 |
+
elif idx == 1:
|
| 314 |
+
current_stage = f"[Planning] Architecture design"
|
| 315 |
+
elif idx == 2:
|
| 316 |
+
current_stage = f"[Planning] Logic design"
|
| 317 |
+
elif idx == 3:
|
| 318 |
+
current_stage = f"[Planning] Configuration file generation"
|
| 319 |
+
print(current_stage)
|
| 320 |
+
|
| 321 |
+
trajectories.extend(instruction_msg)
|
| 322 |
+
|
| 323 |
+
# Estimate tokens for this request and wait if needed
|
| 324 |
+
estimated_tokens = estimate_tokens(str(trajectories))
|
| 325 |
+
rate_limiter.wait_if_needed(estimated_tokens)
|
| 326 |
+
|
| 327 |
+
completion = api_call(trajectories, model)
|
| 328 |
+
|
| 329 |
+
# Extract response text using provider abstraction
|
| 330 |
+
response_text = llm_provider.get_response_text(completion)
|
| 331 |
+
usage_info = llm_provider.get_usage_info(completion)
|
| 332 |
+
|
| 333 |
+
# Create completion JSON for logging (compatible format)
|
| 334 |
+
completion_json = {
|
| 335 |
+
'choices': [{'message': {'role': 'assistant', 'content': response_text}}],
|
| 336 |
+
'usage': usage_info,
|
| 337 |
+
'model': model
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
# print and logging
|
| 341 |
+
print_response(completion_json)
|
| 342 |
+
temp_total_accumulated_cost = print_log_cost(completion_json, model, current_stage, output_dir, total_accumulated_cost)
|
| 343 |
+
total_accumulated_cost = temp_total_accumulated_cost
|
| 344 |
+
|
| 345 |
+
responses.append(completion_json)
|
| 346 |
+
|
| 347 |
+
# trajectories
|
| 348 |
+
message = {'role': 'assistant', 'content': response_text}
|
| 349 |
+
trajectories.append(message)
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# save
|
| 353 |
+
save_accumulated_cost(f"{output_dir}/accumulated_cost.json", total_accumulated_cost)
|
| 354 |
+
|
| 355 |
+
# Print rate limiter statistics
|
| 356 |
+
rate_limiter.print_stats()
|
| 357 |
+
|
| 358 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 359 |
+
|
| 360 |
+
with open(f'{output_dir}/planning_response.json', 'w') as f:
|
| 361 |
+
json.dump(responses, f)
|
| 362 |
+
|
| 363 |
+
with open(f'{output_dir}/planning_trajectories.json', 'w') as f:
|
| 364 |
+
json.dump(trajectories, f)
|
| 365 |
+
|
| 366 |
+
# Export planning as markdown for easy reference
|
| 367 |
+
print("\n📝 Exporting planning to markdown...")
|
| 368 |
+
with open(f'{output_dir}/planning_output.md', 'w', encoding='utf-8') as f:
|
| 369 |
+
f.write(f"# Planning Output for {content_name}\n\n")
|
| 370 |
+
f.write(f"**Model:** {model}\n")
|
| 371 |
+
f.write(f"**Provider:** {provider_name}\n")
|
| 372 |
+
f.write(f"**Content Type:** {content_type}\n\n")
|
| 373 |
+
f.write("---\n\n")
|
| 374 |
+
|
| 375 |
+
for idx, response in enumerate(responses):
|
| 376 |
+
stage_names = ["Overall Plan", "Architecture Design", "Logic Design", "Configuration"]
|
| 377 |
+
stage_name = stage_names[idx] if idx < len(stage_names) else f"Stage {idx+1}"
|
| 378 |
+
|
| 379 |
+
f.write(f"## {stage_name}\n\n")
|
| 380 |
+
content = response['choices'][0]['message']['content']
|
| 381 |
+
f.write(content)
|
| 382 |
+
f.write("\n\n---\n\n")
|
| 383 |
+
|
| 384 |
+
print(f"✅ Planning saved to: {output_dir}/planning_output.md")
|
| 385 |
+
|
codes/2_analyzing.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import sys
|
| 5 |
+
from utils import extract_planning, content_to_json, print_response, print_log_cost, load_accumulated_cost, save_accumulated_cost
|
| 6 |
+
from llm_provider import get_provider, get_default_model
|
| 7 |
+
import copy
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
parser = argparse.ArgumentParser()
|
| 12 |
+
|
| 13 |
+
parser.add_argument('--paper_name', type=str)
|
| 14 |
+
parser.add_argument('--gpt_version', type=str, help='Model version (deprecated, use --model)')
|
| 15 |
+
parser.add_argument('--model', type=str, help='Model name')
|
| 16 |
+
parser.add_argument('--provider', type=str, default='gemini', choices=['openai', 'gemini', 'gemma'], help='LLM provider')
|
| 17 |
+
parser.add_argument('--paper_format', type=str, default="JSON", choices=["JSON", "LaTeX"])
|
| 18 |
+
parser.add_argument('--pdf_json_path', type=str)
|
| 19 |
+
parser.add_argument('--pdf_latex_path', type=str)
|
| 20 |
+
parser.add_argument('--output_dir', type=str, default="")
|
| 21 |
+
|
| 22 |
+
args = parser.parse_args()
|
| 23 |
+
|
| 24 |
+
# Initialize LLM provider
|
| 25 |
+
provider_name = args.provider
|
| 26 |
+
llm_provider = get_provider(provider_name)
|
| 27 |
+
model = args.model or args.gpt_version or get_default_model(provider_name)
|
| 28 |
+
|
| 29 |
+
print(f"🤖 Using {provider_name.upper()} with model: {model}")
|
| 30 |
+
|
| 31 |
+
paper_name = args.paper_name
|
| 32 |
+
gpt_version = args.gpt_version # Keep for backward compatibility if needed in print_log_cost or other places
|
| 33 |
+
paper_format = args.paper_format
|
| 34 |
+
pdf_json_path = args.pdf_json_path
|
| 35 |
+
pdf_latex_path = args.pdf_latex_path
|
| 36 |
+
output_dir = args.output_dir
|
| 37 |
+
|
| 38 |
+
if paper_format == "JSON":
|
| 39 |
+
with open(f'{pdf_json_path}') as f:
|
| 40 |
+
paper_content = json.load(f)
|
| 41 |
+
elif paper_format == "LaTeX":
|
| 42 |
+
with open(f'{pdf_latex_path}') as f:
|
| 43 |
+
paper_content = f.read()
|
| 44 |
+
else:
|
| 45 |
+
print(f"[ERROR] Invalid paper format. Please select either 'JSON' or 'LaTeX.")
|
| 46 |
+
sys.exit(0)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
with open(f'{output_dir}/planning_config.yaml') as f:
|
| 50 |
+
config_yaml = f.read()
|
| 51 |
+
|
| 52 |
+
context_lst = extract_planning(f'{output_dir}/planning_trajectories.json')
|
| 53 |
+
|
| 54 |
+
# 0: overview, 1: detailed, 2: PRD
|
| 55 |
+
if os.path.exists(f'{output_dir}/task_list.json'):
|
| 56 |
+
with open(f'{output_dir}/task_list.json') as f:
|
| 57 |
+
task_list = json.load(f)
|
| 58 |
+
else:
|
| 59 |
+
task_list = content_to_json(context_lst[2])
|
| 60 |
+
|
| 61 |
+
if 'Task list' in task_list:
|
| 62 |
+
todo_file_lst = task_list['Task list']
|
| 63 |
+
elif 'task_list' in task_list:
|
| 64 |
+
todo_file_lst = task_list['task_list']
|
| 65 |
+
elif 'task list' in task_list:
|
| 66 |
+
todo_file_lst = task_list['task list']
|
| 67 |
+
else:
|
| 68 |
+
print(f"[ERROR] 'Task list' does not exist. Please re-generate the planning.")
|
| 69 |
+
sys.exit(0)
|
| 70 |
+
|
| 71 |
+
if 'Logic Analysis' in task_list:
|
| 72 |
+
logic_analysis = task_list['Logic Analysis']
|
| 73 |
+
elif 'logic_analysis' in task_list:
|
| 74 |
+
logic_analysis = task_list['logic_analysis']
|
| 75 |
+
elif 'logic analysis' in task_list:
|
| 76 |
+
logic_analysis = task_list['logic analysis']
|
| 77 |
+
else:
|
| 78 |
+
print(f"[ERROR] 'Logic Analysis' does not exist. Please re-generate the planning.")
|
| 79 |
+
sys.exit(0)
|
| 80 |
+
|
| 81 |
+
done_file_lst = ['config.yaml']
|
| 82 |
+
logic_analysis_dict = {}
|
| 83 |
+
for desc in task_list['Logic Analysis']:
|
| 84 |
+
logic_analysis_dict[desc[0]] = desc[1]
|
| 85 |
+
|
| 86 |
+
analysis_msg = [
|
| 87 |
+
{"role": "system", "content": f"""You are an expert researcher, strategic analyzer and software engineer with a deep understanding of experimental design and reproducibility in scientific research.
|
| 88 |
+
You will receive a research paper in {paper_format} format, an overview of the plan, a design in JSON format consisting of "Implementation approach", "File list", "Data structures and interfaces", and "Program call flow", followed by a task in JSON format that includes "Required packages", "Required other language third-party packages", "Logic Analysis", and "Task list", along with a configuration file named "config.yaml".
|
| 89 |
+
|
| 90 |
+
Your task is to conduct a comprehensive logic analysis to accurately reproduce the experiments and methodologies described in the research paper.
|
| 91 |
+
This analysis must align precisely with the paper’s methodology, experimental setup, and evaluation criteria.
|
| 92 |
+
|
| 93 |
+
1. Align with the Paper: Your analysis must strictly follow the methods, datasets, model configurations, hyperparameters, and experimental setups described in the paper.
|
| 94 |
+
2. Be Clear and Structured: Present your analysis in a logical, well-organized, and actionable format that is easy to follow and implement.
|
| 95 |
+
3. Prioritize Efficiency: Optimize the analysis for clarity and practical implementation while ensuring fidelity to the original experiments.
|
| 96 |
+
4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
|
| 97 |
+
5. REFER TO CONFIGURATION: Always reference settings from the config.yaml file. Do not invent or assume any values—only use configurations explicitly provided.
|
| 98 |
+
|
| 99 |
+
"""}]
|
| 100 |
+
|
| 101 |
+
def get_write_msg(todo_file_name, todo_file_desc):
|
| 102 |
+
|
| 103 |
+
draft_desc = f"Write the logic analysis in '{todo_file_name}', which is intended for '{todo_file_desc}'."
|
| 104 |
+
if len(todo_file_desc.strip()) == 0:
|
| 105 |
+
draft_desc = f"Write the logic analysis in '{todo_file_name}'."
|
| 106 |
+
|
| 107 |
+
write_msg=[{'role': 'user', "content": f"""## Paper
|
| 108 |
+
{paper_content}
|
| 109 |
+
|
| 110 |
+
-----
|
| 111 |
+
|
| 112 |
+
## Overview of the plan
|
| 113 |
+
{context_lst[0]}
|
| 114 |
+
|
| 115 |
+
-----
|
| 116 |
+
|
| 117 |
+
## Design
|
| 118 |
+
{context_lst[1]}
|
| 119 |
+
|
| 120 |
+
-----
|
| 121 |
+
|
| 122 |
+
## Task
|
| 123 |
+
{context_lst[2]}
|
| 124 |
+
|
| 125 |
+
-----
|
| 126 |
+
|
| 127 |
+
## Configuration file
|
| 128 |
+
```yaml
|
| 129 |
+
{config_yaml}
|
| 130 |
+
```
|
| 131 |
+
-----
|
| 132 |
+
|
| 133 |
+
## Instruction
|
| 134 |
+
Conduct a Logic Analysis to assist in writing the code, based on the paper, the plan, the design, the task and the previously specified configuration file (config.yaml).
|
| 135 |
+
You DON'T need to provide the actual code yet; focus on a thorough, clear analysis.
|
| 136 |
+
|
| 137 |
+
{draft_desc}
|
| 138 |
+
|
| 139 |
+
-----
|
| 140 |
+
|
| 141 |
+
## Logic Analysis: {todo_file_name}"""}]
|
| 142 |
+
return write_msg
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def api_call(msg):
|
| 146 |
+
"""Make API call using the configured provider"""
|
| 147 |
+
if "o3-mini" in model and provider_name == 'openai':
|
| 148 |
+
completion = llm_provider.create_completion(
|
| 149 |
+
messages=msg,
|
| 150 |
+
model=model,
|
| 151 |
+
reasoning_effort="high"
|
| 152 |
+
)
|
| 153 |
+
else:
|
| 154 |
+
completion = llm_provider.create_completion(
|
| 155 |
+
messages=msg,
|
| 156 |
+
model=model
|
| 157 |
+
)
|
| 158 |
+
return completion
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
artifact_output_dir=f'{output_dir}/analyzing_artifacts'
|
| 162 |
+
os.makedirs(artifact_output_dir, exist_ok=True)
|
| 163 |
+
|
| 164 |
+
total_accumulated_cost = load_accumulated_cost(f"{output_dir}/accumulated_cost.json")
|
| 165 |
+
for todo_file_name in tqdm(todo_file_lst):
|
| 166 |
+
responses = []
|
| 167 |
+
trajectories = copy.deepcopy(analysis_msg)
|
| 168 |
+
|
| 169 |
+
current_stage=f"[ANALYSIS] {todo_file_name}"
|
| 170 |
+
print(current_stage)
|
| 171 |
+
if todo_file_name == "config.yaml":
|
| 172 |
+
continue
|
| 173 |
+
|
| 174 |
+
if todo_file_name not in logic_analysis_dict:
|
| 175 |
+
# print(f"[DEBUG ANALYSIS] {paper_name} {todo_file_name} is not exist in the logic analysis")
|
| 176 |
+
logic_analysis_dict[todo_file_name] = ""
|
| 177 |
+
|
| 178 |
+
instruction_msg = get_write_msg(todo_file_name, logic_analysis_dict[todo_file_name])
|
| 179 |
+
trajectories.extend(instruction_msg)
|
| 180 |
+
|
| 181 |
+
completion = llm_provider.create_completion(
|
| 182 |
+
messages=trajectories,
|
| 183 |
+
model=model
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Extract response using provider abstraction
|
| 187 |
+
response_text = llm_provider.get_response_text(completion)
|
| 188 |
+
usage_info = llm_provider.get_usage_info(completion)
|
| 189 |
+
|
| 190 |
+
# Create completion JSON for logging
|
| 191 |
+
completion_json = {
|
| 192 |
+
'choices': [{'message': {'role': 'assistant', 'content': response_text}}],
|
| 193 |
+
'usage': usage_info,
|
| 194 |
+
'model': model
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
# print and logging
|
| 198 |
+
print_response(completion_json)
|
| 199 |
+
temp_total_accumulated_cost = print_log_cost(completion_json, model, current_stage, output_dir, total_accumulated_cost)
|
| 200 |
+
total_accumulated_cost = temp_total_accumulated_cost
|
| 201 |
+
|
| 202 |
+
responses.append(completion_json)
|
| 203 |
+
|
| 204 |
+
# trajectories
|
| 205 |
+
message = {'role': 'assistant', 'content': response_text}
|
| 206 |
+
trajectories.append(message)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# save - create subdirectories if needed
|
| 210 |
+
artifact_file_path = f'{artifact_output_dir}/{todo_file_name}_simple_analysis.txt'
|
| 211 |
+
artifact_file_dir = os.path.dirname(artifact_file_path)
|
| 212 |
+
os.makedirs(artifact_file_dir, exist_ok=True)
|
| 213 |
+
|
| 214 |
+
with open(artifact_file_path, 'w') as f:
|
| 215 |
+
f.write(completion_json['choices'][0]['message']['content'])
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
done_file_lst.append(todo_file_name)
|
| 219 |
+
|
| 220 |
+
# save for next stage(coding)
|
| 221 |
+
todo_file_name = todo_file_name.replace("/", "_")
|
| 222 |
+
with open(f'{output_dir}/{todo_file_name}_simple_analysis_response.json', 'w') as f:
|
| 223 |
+
json.dump(responses, f)
|
| 224 |
+
|
| 225 |
+
with open(f'{output_dir}/{todo_file_name}_simple_analysis_trajectories.json', 'w') as f:
|
| 226 |
+
json.dump(trajectories, f)
|
| 227 |
+
|
| 228 |
+
save_accumulated_cost(f"{output_dir}/accumulated_cost.json", total_accumulated_cost)
|
codes/3_coding.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import re
|
| 5 |
+
import sys
|
| 6 |
+
import copy
|
| 7 |
+
from utils import extract_planning, content_to_json, extract_code_from_content, print_response, print_log_cost, load_accumulated_cost, save_accumulated_cost
|
| 8 |
+
from llm_provider import get_provider, get_default_model
|
| 9 |
+
import argparse
|
| 10 |
+
|
| 11 |
+
parser = argparse.ArgumentParser()
|
| 12 |
+
|
| 13 |
+
parser.add_argument('--paper_name', type=str)
|
| 14 |
+
parser.add_argument('--gpt_version', type=str, default="o3-mini", help='Model version (deprecated, use --model)')
|
| 15 |
+
parser.add_argument('--model', type=str, help='Model name')
|
| 16 |
+
parser.add_argument('--provider', type=str, default='gemini', choices=['openai', 'gemini', 'gemma'], help='LLM provider')
|
| 17 |
+
parser.add_argument('--paper_format', type=str, default="JSON", choices=["JSON", "LaTeX"])
|
| 18 |
+
parser.add_argument('--pdf_json_path', type=str) # json format
|
| 19 |
+
parser.add_argument('--pdf_latex_path', type=str) # latex format
|
| 20 |
+
parser.add_argument('--output_dir',type=str, default="")
|
| 21 |
+
parser.add_argument('--output_repo_dir',type=str, default="")
|
| 22 |
+
|
| 23 |
+
args = parser.parse_args()
|
| 24 |
+
|
| 25 |
+
# Initialize LLM provider
|
| 26 |
+
provider_name = args.provider
|
| 27 |
+
llm_provider = get_provider(provider_name)
|
| 28 |
+
model = args.model or args.gpt_version or get_default_model(provider_name)
|
| 29 |
+
|
| 30 |
+
print(f"🤖 Using {provider_name.upper()} with model: {model}")
|
| 31 |
+
|
| 32 |
+
paper_name = args.paper_name
|
| 33 |
+
gpt_version = args.gpt_version
|
| 34 |
+
paper_format = args.paper_format
|
| 35 |
+
pdf_json_path = args.pdf_json_path
|
| 36 |
+
pdf_latex_path = args.pdf_latex_path
|
| 37 |
+
output_dir = args.output_dir
|
| 38 |
+
output_repo_dir = args.output_repo_dir
|
| 39 |
+
|
| 40 |
+
if paper_format == "JSON":
|
| 41 |
+
with open(f'{pdf_json_path}') as f:
|
| 42 |
+
paper_content = json.load(f)
|
| 43 |
+
elif paper_format == "LaTeX":
|
| 44 |
+
with open(f'{pdf_latex_path}') as f:
|
| 45 |
+
paper_content = f.read()
|
| 46 |
+
else:
|
| 47 |
+
print(f"[ERROR] Invalid paper format. Please select either 'JSON' or 'LaTeX.")
|
| 48 |
+
sys.exit(0)
|
| 49 |
+
|
| 50 |
+
with open(f'{output_dir}/planning_config.yaml') as f:
|
| 51 |
+
config_yaml = f.read()
|
| 52 |
+
|
| 53 |
+
context_lst = extract_planning(f'{output_dir}/planning_trajectories.json')
|
| 54 |
+
# 0: overview, 1: detailed, 2: PRD
|
| 55 |
+
# file_list = content_to_json(context_lst[1])
|
| 56 |
+
task_list = content_to_json(context_lst[2])
|
| 57 |
+
|
| 58 |
+
todo_file_lst = task_list['Task list']
|
| 59 |
+
done_file_lst = ['config.yaml']
|
| 60 |
+
done_file_dict = {}
|
| 61 |
+
|
| 62 |
+
code_msg = [
|
| 63 |
+
{"role": "system", "content": f"""You are an expert researcher and software engineer with a deep understanding of experimental design and reproducibility in scientific research.
|
| 64 |
+
You will receive a research paper in {paper_format} format, an overview of the plan, a Design in JSON format consisting of "Implementation approach", "File list", "Data structures and interfaces", and "Program call flow", followed by a Task in JSON format that includes "Required packages", "Required other language third-party packages", "Logic Analysis", and "Task list", along with a configuration file named "config.yaml".
|
| 65 |
+
Your task is to write code to reproduce the experiments and methodologies described in the paper.
|
| 66 |
+
|
| 67 |
+
The code you write must be elegant, modular, and maintainable, adhering to Google-style guidelines.
|
| 68 |
+
The code must strictly align with the paper's methodology, experimental setup, and evaluation metrics.
|
| 69 |
+
Write code with triple quoto."""}]
|
| 70 |
+
|
| 71 |
+
def get_write_msg(todo_file_name, detailed_logic_analysis, done_file_lst):
|
| 72 |
+
code_files = ""
|
| 73 |
+
for done_file in done_file_lst:
|
| 74 |
+
if done_file.endswith(".yaml"): continue
|
| 75 |
+
code_files += f"""
|
| 76 |
+
```python
|
| 77 |
+
{done_file_dict[done_file]}
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
write_msg=[
|
| 83 |
+
{'role': 'user', "content": f"""# Context
|
| 84 |
+
## Paper
|
| 85 |
+
{paper_content}
|
| 86 |
+
|
| 87 |
+
-----
|
| 88 |
+
|
| 89 |
+
## Overview of the plan
|
| 90 |
+
{context_lst[0]}
|
| 91 |
+
|
| 92 |
+
-----
|
| 93 |
+
|
| 94 |
+
## Design
|
| 95 |
+
{context_lst[1]}
|
| 96 |
+
|
| 97 |
+
-----
|
| 98 |
+
|
| 99 |
+
## Task
|
| 100 |
+
{context_lst[2]}
|
| 101 |
+
|
| 102 |
+
-----
|
| 103 |
+
|
| 104 |
+
## Configuration file
|
| 105 |
+
```yaml
|
| 106 |
+
{config_yaml}
|
| 107 |
+
```
|
| 108 |
+
-----
|
| 109 |
+
|
| 110 |
+
## Code Files
|
| 111 |
+
{code_files}
|
| 112 |
+
|
| 113 |
+
-----
|
| 114 |
+
|
| 115 |
+
# Format example
|
| 116 |
+
## Code: {todo_file_name}
|
| 117 |
+
```python
|
| 118 |
+
## {todo_file_name}
|
| 119 |
+
...
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
-----
|
| 123 |
+
|
| 124 |
+
# Instruction
|
| 125 |
+
Based on the paper, plan, design, task and configuration file(config.yaml) specified previously, follow "Format example", write the code.
|
| 126 |
+
|
| 127 |
+
We have {done_file_lst}.
|
| 128 |
+
Next, you must write only the "{todo_file_name}".
|
| 129 |
+
1. Only One file: do your best to implement THIS ONLY ONE FILE.
|
| 130 |
+
2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.
|
| 131 |
+
3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import.
|
| 132 |
+
4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
|
| 133 |
+
5. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
|
| 134 |
+
6. Before using a external variable/module, make sure you import it first.
|
| 135 |
+
7. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.
|
| 136 |
+
8. REFER TO CONFIGURATION: you must use configuration from "config.yaml". DO NOT FABRICATE any configuration values.
|
| 137 |
+
|
| 138 |
+
{detailed_logic_analysis}
|
| 139 |
+
|
| 140 |
+
## Code: {todo_file_name}"""}]
|
| 141 |
+
return write_msg
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def api_call(msg):
|
| 145 |
+
"""Make API call using the configured provider"""
|
| 146 |
+
if "o3-mini" in model and provider_name == 'openai':
|
| 147 |
+
completion = llm_provider.create_completion(
|
| 148 |
+
messages=msg,
|
| 149 |
+
model=model,
|
| 150 |
+
reasoning_effort="high"
|
| 151 |
+
)
|
| 152 |
+
else:
|
| 153 |
+
completion = llm_provider.create_completion(
|
| 154 |
+
messages=msg,
|
| 155 |
+
model=model
|
| 156 |
+
)
|
| 157 |
+
return completion
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# testing for checking
|
| 161 |
+
detailed_logic_analysis_dict = {}
|
| 162 |
+
retrieved_section_dict = {}
|
| 163 |
+
for todo_file_name in todo_file_lst:
|
| 164 |
+
# simple analysis
|
| 165 |
+
save_todo_file_name = todo_file_name.replace("/", "_")
|
| 166 |
+
|
| 167 |
+
if todo_file_name == "config.yaml":
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
with open(f"{output_dir}/{save_todo_file_name}_simple_analysis_response.json") as f:
|
| 171 |
+
detailed_logic_analysis_response = json.load(f)
|
| 172 |
+
detailed_logic_analysis_dict[todo_file_name] = detailed_logic_analysis_response[0]['choices'][0]['message']['content']
|
| 173 |
+
|
| 174 |
+
artifact_output_dir=f'{output_dir}/coding_artifacts'
|
| 175 |
+
os.makedirs(artifact_output_dir, exist_ok=True)
|
| 176 |
+
|
| 177 |
+
total_accumulated_cost = load_accumulated_cost(f"{output_dir}/accumulated_cost.json")
|
| 178 |
+
for todo_idx, todo_file_name in enumerate(tqdm(todo_file_lst)):
|
| 179 |
+
responses = []
|
| 180 |
+
trajectories = copy.deepcopy(code_msg)
|
| 181 |
+
|
| 182 |
+
current_stage = f"[CODING] {todo_file_name}"
|
| 183 |
+
print(current_stage)
|
| 184 |
+
|
| 185 |
+
if todo_file_name == "config.yaml":
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
instruction_msg = get_write_msg(todo_file_name, detailed_logic_analysis_dict[todo_file_name], done_file_lst)
|
| 189 |
+
trajectories.extend(instruction_msg)
|
| 190 |
+
|
| 191 |
+
completion = api_call(trajectories)
|
| 192 |
+
|
| 193 |
+
# Extract response using provider abstraction
|
| 194 |
+
response_text = llm_provider.get_response_text(completion)
|
| 195 |
+
usage_info = llm_provider.get_usage_info(completion)
|
| 196 |
+
|
| 197 |
+
# Create completion JSON for logging
|
| 198 |
+
completion_json = {
|
| 199 |
+
'choices': [{'message': {'role': 'assistant', 'content': response_text}}],
|
| 200 |
+
'usage': usage_info,
|
| 201 |
+
'model': model
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
# print and logging
|
| 205 |
+
print_response(completion_json)
|
| 206 |
+
temp_total_accumulated_cost = print_log_cost(completion_json, model, current_stage, output_dir, total_accumulated_cost)
|
| 207 |
+
total_accumulated_cost = temp_total_accumulated_cost
|
| 208 |
+
|
| 209 |
+
responses.append(completion_json)
|
| 210 |
+
|
| 211 |
+
# trajectories
|
| 212 |
+
message = {'role': 'assistant', 'content': response_text}
|
| 213 |
+
trajectories.append(message)
|
| 214 |
+
|
| 215 |
+
done_file_lst.append(todo_file_name)
|
| 216 |
+
|
| 217 |
+
# save
|
| 218 |
+
# save_dir_name = f"{paper_name}_repo"
|
| 219 |
+
os.makedirs(f'{output_repo_dir}', exist_ok=True)
|
| 220 |
+
save_todo_file_name = todo_file_name.replace("/", "_")
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# save artifacts - create subdirectories if needed
|
| 224 |
+
artifact_file_path = f'{artifact_output_dir}/{save_todo_file_name}_coding.txt'
|
| 225 |
+
artifact_file_dir = os.path.dirname(artifact_file_path)
|
| 226 |
+
os.makedirs(artifact_file_dir, exist_ok=True)
|
| 227 |
+
|
| 228 |
+
with open(artifact_file_path, 'w') as f:
|
| 229 |
+
f.write(completion_json['choices'][0]['message']['content'])
|
| 230 |
+
|
| 231 |
+
# extract code save
|
| 232 |
+
code = extract_code_from_content(completion_json['choices'][0]['message']['content'])
|
| 233 |
+
if len(code) == 0:
|
| 234 |
+
code = completion_json['choices'][0]['message']['content']
|
| 235 |
+
|
| 236 |
+
done_file_dict[todo_file_name] = code
|
| 237 |
+
if save_todo_file_name != todo_file_name:
|
| 238 |
+
todo_file_dir = '/'.join(todo_file_name.split("/")[:-1])
|
| 239 |
+
os.makedirs(f"{output_repo_dir}/{todo_file_dir}", exist_ok=True)
|
| 240 |
+
|
| 241 |
+
# save code file - create subdirectories if needed
|
| 242 |
+
code_file_path = f"{output_repo_dir}/{todo_file_name}"
|
| 243 |
+
code_file_dir = os.path.dirname(code_file_path)
|
| 244 |
+
os.makedirs(code_file_dir, exist_ok=True)
|
| 245 |
+
|
| 246 |
+
with open(code_file_path, 'w') as f:
|
| 247 |
+
f.write(code)
|
| 248 |
+
|
| 249 |
+
save_accumulated_cost(f"{output_dir}/accumulated_cost.json", total_accumulated_cost)
|
codes/eval.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import OpenAI
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import argparse
|
| 6 |
+
from utils import read_python_files, extract_planning, content_to_json, \
|
| 7 |
+
num_tokens_from_messages, read_all_files, extract_json_from_string, get_now_str, print_log_cost
|
| 8 |
+
|
| 9 |
+
client = OpenAI(api_key = os.environ["OPENAI_API_KEY"])
|
| 10 |
+
|
| 11 |
+
def api_call(request_json):
|
| 12 |
+
completion = client.chat.completions.create(**request_json)
|
| 13 |
+
return completion
|
| 14 |
+
|
| 15 |
+
def main(args):
|
| 16 |
+
|
| 17 |
+
paper_name = args.paper_name
|
| 18 |
+
pdf_json_path = args.pdf_json_path
|
| 19 |
+
output_dir = args.output_dir
|
| 20 |
+
target_repo_dir = args.target_repo_dir
|
| 21 |
+
eval_result_dir = args.eval_result_dir
|
| 22 |
+
gpt_version = args.gpt_version
|
| 23 |
+
generated_n = args.generated_n
|
| 24 |
+
data_dir = args.data_dir
|
| 25 |
+
eval_type = args.eval_type
|
| 26 |
+
is_papercoder = True if args.papercoder else False
|
| 27 |
+
|
| 28 |
+
gold_repo_dir = args.gold_repo_dir
|
| 29 |
+
|
| 30 |
+
# paper
|
| 31 |
+
with open(f'{pdf_json_path}') as f:
|
| 32 |
+
paper_json = json.load(f)
|
| 33 |
+
|
| 34 |
+
codes = ""
|
| 35 |
+
if is_papercoder:
|
| 36 |
+
# python files
|
| 37 |
+
target_files_dict = read_python_files(target_repo_dir)
|
| 38 |
+
|
| 39 |
+
# configuration
|
| 40 |
+
with open(f'{output_dir}/planning_config.yaml') as f:
|
| 41 |
+
config_yaml = f.read()
|
| 42 |
+
|
| 43 |
+
context_lst = extract_planning(f'{output_dir}/planning_trajectories.json')
|
| 44 |
+
|
| 45 |
+
if os.path.exists(f'{output_dir}/task_list.json'):
|
| 46 |
+
with open(f'{output_dir}/task_list.json') as f:
|
| 47 |
+
task_list = json.load(f)
|
| 48 |
+
else:
|
| 49 |
+
task_list = content_to_json(context_lst[2])
|
| 50 |
+
|
| 51 |
+
todo_file_lst = task_list['Task list']
|
| 52 |
+
for todo_file in todo_file_lst:
|
| 53 |
+
if todo_file.endswith(".yaml"):
|
| 54 |
+
continue
|
| 55 |
+
codes += f"```python\n## File name: {todo_file}\n{target_files_dict[todo_file]}\n```\n\n"
|
| 56 |
+
|
| 57 |
+
codes += f"```yaml\n## File name: config.yaml\n{config_yaml}\n```\n\n"
|
| 58 |
+
else:
|
| 59 |
+
target_files_dict = read_all_files(target_repo_dir, allowed_ext=[".py", ".yaml", ".yml", ".md", ".sh", ".bash"], is_print=False)
|
| 60 |
+
for file_name, code in target_files_dict.items():
|
| 61 |
+
codes += f"```## File name: {file_name}\n{code}\n```\n\n"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
prompt = open(f"{data_dir}/prompts/{eval_type}.txt").read()
|
| 65 |
+
|
| 66 |
+
cur_prompt = prompt.replace('{{Paper}}', f"{paper_json}").replace('{{Code}}', codes)
|
| 67 |
+
|
| 68 |
+
# refernce-based
|
| 69 |
+
if "ref_based" == eval_type and len(gold_repo_dir) > 0:
|
| 70 |
+
all_files_dict = read_all_files(gold_repo_dir, allowed_ext=[".py", ".yaml", ".yml", ".md", ".sh", ".bash"], is_print=False)
|
| 71 |
+
|
| 72 |
+
goldcodes = ""
|
| 73 |
+
gold_cnt = 0
|
| 74 |
+
if len(args.selected_file_path) > 0:
|
| 75 |
+
selected_file_lst = []
|
| 76 |
+
with open(args.selected_file_path) as f:
|
| 77 |
+
selected_file_lst = f.readlines()
|
| 78 |
+
|
| 79 |
+
for s_idx in range(len(selected_file_lst)):
|
| 80 |
+
selected_file_lst[s_idx] = selected_file_lst[s_idx].strip()
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
for all_file, all_file_code in all_files_dict.items():
|
| 84 |
+
if all_file not in selected_file_lst:
|
| 85 |
+
continue
|
| 86 |
+
|
| 87 |
+
goldcodes += f"```## File name: {all_file}\n{all_file_code}\n```\n\n"
|
| 88 |
+
|
| 89 |
+
gold_cnt += 1
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
else:
|
| 93 |
+
for all_file, all_file_code in all_files_dict.items():
|
| 94 |
+
goldcodes += f"```## File name: {all_file}\n{all_file_code}\n```\n\n"
|
| 95 |
+
|
| 96 |
+
gold_cnt += 1
|
| 97 |
+
|
| 98 |
+
cur_prompt = cur_prompt.replace('{{GoldCode}}', f"{goldcodes}")
|
| 99 |
+
|
| 100 |
+
msg = [{"role": "system", "content": cur_prompt}]
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
num_tokens = num_tokens_from_messages(msg)
|
| 104 |
+
except Exception as e:
|
| 105 |
+
print(f"[WARNING] An exception was raised while counting tokens for the target repository of {args.paper_name}.")
|
| 106 |
+
print(e)
|
| 107 |
+
print("-"*40)
|
| 108 |
+
num_tokens = 0
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if num_tokens > 128000:
|
| 112 |
+
print(f"[ERROR] {args.paper_name} more than 128k")
|
| 113 |
+
sys.exit(0)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
if "o3-mini" in gpt_version:
|
| 117 |
+
if generated_n > 8:
|
| 118 |
+
print(f"[WARNING] o3-mini does not support n > 8. Setting generated_n to 8.")
|
| 119 |
+
generated_n = 8
|
| 120 |
+
|
| 121 |
+
request_json = {
|
| 122 |
+
"model": gpt_version,
|
| 123 |
+
"messages": msg,
|
| 124 |
+
"reasoning_effort": "high",
|
| 125 |
+
"n": generated_n
|
| 126 |
+
}
|
| 127 |
+
else:
|
| 128 |
+
request_json = {
|
| 129 |
+
"model": gpt_version,
|
| 130 |
+
"messages": msg,
|
| 131 |
+
"temperature": 1,
|
| 132 |
+
"frequency_penalty": 0,
|
| 133 |
+
"presence_penalty": 0,
|
| 134 |
+
"stop": None,
|
| 135 |
+
"n": generated_n # 10
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
completion = api_call(request_json)
|
| 139 |
+
completion_json = json.loads(completion.model_dump_json())
|
| 140 |
+
|
| 141 |
+
score_key = "score"
|
| 142 |
+
rationale_key = "critique_list"
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
all_scores = []
|
| 146 |
+
rationales = []
|
| 147 |
+
for n in range(generated_n):
|
| 148 |
+
choice = completion_json['choices'][n]
|
| 149 |
+
|
| 150 |
+
output = choice['message']['content'].strip()
|
| 151 |
+
|
| 152 |
+
try:
|
| 153 |
+
output_json2 = json.loads(output)
|
| 154 |
+
score = int(output_json2[score_key])
|
| 155 |
+
|
| 156 |
+
if isinstance(output_json2[rationale_key], str):
|
| 157 |
+
rationale = output_json2[rationale_key]
|
| 158 |
+
else:
|
| 159 |
+
rationale = json.dumps(output_json2[rationale_key])
|
| 160 |
+
except Exception as e:
|
| 161 |
+
# print(e)
|
| 162 |
+
try:
|
| 163 |
+
output_json2 = json.loads(extract_json_from_string(output))
|
| 164 |
+
score = int(output_json2[score_key])
|
| 165 |
+
|
| 166 |
+
if isinstance(output_json2[rationale_key], str):
|
| 167 |
+
rationale = output_json2[rationale_key]
|
| 168 |
+
else:
|
| 169 |
+
rationale = json.dumps(output_json2[rationale_key])
|
| 170 |
+
except Exception as e2: # Parsing Error
|
| 171 |
+
print(f"[WARNING] Invalid repsponse: parsing error")
|
| 172 |
+
print(e2)
|
| 173 |
+
print("-"*40)
|
| 174 |
+
|
| 175 |
+
continue
|
| 176 |
+
|
| 177 |
+
# score
|
| 178 |
+
if score < 1 or score > 5:
|
| 179 |
+
print(f"[WARNING] Invalid repsponse: score {score}, Score must be in the range of 1–5.")
|
| 180 |
+
continue
|
| 181 |
+
|
| 182 |
+
all_scores.append(int(score))
|
| 183 |
+
rationales.append(rationale)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
avg_score = sum(all_scores) / len(all_scores)
|
| 187 |
+
|
| 188 |
+
output_json= {
|
| 189 |
+
"paper_name": paper_name,
|
| 190 |
+
"target_repo_dir": target_repo_dir,
|
| 191 |
+
"eval_type": eval_type,
|
| 192 |
+
"gold_repo_dir": gold_repo_dir,
|
| 193 |
+
"generated_n": generated_n,
|
| 194 |
+
"request_json": request_json,
|
| 195 |
+
"completion_json": completion_json,
|
| 196 |
+
"eval_result": {
|
| 197 |
+
"score": avg_score,
|
| 198 |
+
"valid_n": len(all_scores),
|
| 199 |
+
"scroe_lst": all_scores,
|
| 200 |
+
"rationale_lst": rationales,
|
| 201 |
+
},
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
now_str = get_now_str()
|
| 205 |
+
os.makedirs(eval_result_dir, exist_ok=True)
|
| 206 |
+
with open(f"{eval_result_dir}/{paper_name}_eval_{eval_type}_{gpt_version}_{now_str}.json", 'w', encoding='utf-8') as f:
|
| 207 |
+
json.dump(output_json, f)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# ---------------
|
| 211 |
+
print()
|
| 212 |
+
print("=" * 40)
|
| 213 |
+
print("🌟 Evaluation Summary 🌟")
|
| 214 |
+
print(f"📄 Paper name: {paper_name}")
|
| 215 |
+
print(f"🧪 Evaluation type: {eval_type}")
|
| 216 |
+
print(f"📁 Target repo directory: {target_repo_dir}")
|
| 217 |
+
print(f"📊 Evaluation result:")
|
| 218 |
+
print(f"\t📈 Score: {avg_score:.4f}")
|
| 219 |
+
print(f"\t✅ Valid: {output_json['eval_result']['valid_n']}/{generated_n}")
|
| 220 |
+
print("=" * 40)
|
| 221 |
+
|
| 222 |
+
print_log_cost(completion_json, gpt_version, f"[Evaluation] {paper_name} - {eval_type}", output_dir, 0)
|
| 223 |
+
# ---------------
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
if __name__ == '__main__':
|
| 227 |
+
|
| 228 |
+
argparser = argparse.ArgumentParser()
|
| 229 |
+
|
| 230 |
+
argparser.add_argument('--paper_name', type=str)
|
| 231 |
+
argparser.add_argument('--pdf_json_path', type=str)
|
| 232 |
+
argparser.add_argument('--data_dir',type=str, default="../data")
|
| 233 |
+
|
| 234 |
+
argparser.add_argument('--output_dir',type=str)
|
| 235 |
+
|
| 236 |
+
argparser.add_argument('--target_repo_dir', type=str)
|
| 237 |
+
argparser.add_argument('--gold_repo_dir', type=str, default="")
|
| 238 |
+
argparser.add_argument('--eval_result_dir',type=str)
|
| 239 |
+
|
| 240 |
+
argparser.add_argument('--eval_type', type=str, default="ref_free", choices=["ref_free", "ref_based"])
|
| 241 |
+
|
| 242 |
+
argparser.add_argument('--generated_n', type=int, default=8)
|
| 243 |
+
argparser.add_argument('--gpt_version', type=str, default="o3-mini")
|
| 244 |
+
|
| 245 |
+
argparser.add_argument('--selected_file_path', type=str, default="")
|
| 246 |
+
argparser.add_argument('--papercoder', action="store_true")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
args = argparser.parse_args()
|
| 251 |
+
main(args)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ref-free
|
| 255 |
+
# python eval.py \
|
| 256 |
+
# --paper_name Transformer \
|
| 257 |
+
# --pdf_json_path ../examples/Transformer_cleaned.json \
|
| 258 |
+
# --data_dir ../data \
|
| 259 |
+
# --output_dir ../outputs/Transformer \
|
| 260 |
+
# --target_repo_dir ../outputs/Transformer_repo \
|
| 261 |
+
# --eval_result_dir ../results \
|
| 262 |
+
# --eval_type ref_free \
|
| 263 |
+
# --generated_n 8 \
|
| 264 |
+
# --papercoder
|
| 265 |
+
|
| 266 |
+
# ref-based
|
| 267 |
+
# python eval.py \
|
| 268 |
+
# --paper_name Transformer \
|
| 269 |
+
# --pdf_json_path ../examples/Transformer_cleaned.json \
|
| 270 |
+
# --data_dir ../data \
|
| 271 |
+
# --output_dir ../outputs/Transformer \
|
| 272 |
+
# --target_repo_dir ../outputs/Transformer_repo \
|
| 273 |
+
# --gold_repo_dir ../examples/Transformer_gold_repo \
|
| 274 |
+
# --eval_result_dir ../results \
|
| 275 |
+
# --eval_type ref_based \
|
| 276 |
+
# --generated_n 8 \
|
| 277 |
+
# --papercoder
|
codes/example_use_gemma.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Example: How to use Gemma provider in Blog2Code pipeline
|
| 3 |
+
|
| 4 |
+
This demonstrates how to modify any of the pipeline scripts to use Gemma instead of OpenAI/Gemini.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
# Set the NVIDIA API key
|
| 9 |
+
os.environ['NVIDIA_API_KEY'] = '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'
|
| 10 |
+
|
| 11 |
+
# When running any pipeline script, simply change the provider_name to 'gemma'
|
| 12 |
+
# For example, in 1_planning.py, 2_analyzing.py, or 3_coding.py:
|
| 13 |
+
|
| 14 |
+
# OLD (using OpenAI):
|
| 15 |
+
# provider_name = 'openai'
|
| 16 |
+
|
| 17 |
+
# NEW (using Gemma):
|
| 18 |
+
provider_name = 'gemma'
|
| 19 |
+
|
| 20 |
+
# The rest of the code remains the same!
|
| 21 |
+
# The scripts will automatically use the Gemma model with default settings.
|
| 22 |
+
|
| 23 |
+
print(f"✅ Pipeline configured to use: {provider_name}")
|
| 24 |
+
print("You can now run any of the pipeline scripts:")
|
| 25 |
+
print(" - python 1_planning.py")
|
| 26 |
+
print(" - python 2_analyzing.py")
|
| 27 |
+
print(" - python 3_coding.py")
|
codes/llm_provider.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Provider abstraction layer for Blog2Code.
|
| 3 |
+
Supports multiple LLM providers: OpenAI, Google Gemini, NVIDIA Gemma
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from typing import Dict, List, Any, Optional
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class LLMProvider(ABC):
|
| 11 |
+
"""Base class for LLM providers"""
|
| 12 |
+
|
| 13 |
+
@abstractmethod
|
| 14 |
+
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 15 |
+
"""Create a chat completion"""
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
@abstractmethod
|
| 19 |
+
def get_response_text(self, completion: Any) -> str:
|
| 20 |
+
"""Extract text from completion response"""
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
@abstractmethod
|
| 24 |
+
def get_usage_info(self, completion: Any) -> Dict:
|
| 25 |
+
"""Extract token usage information"""
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
@abstractmethod
|
| 29 |
+
def calculate_cost(self, usage: Dict, model: str) -> float:
|
| 30 |
+
"""Calculate cost based on usage"""
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class OpenAIProvider(LLMProvider):
|
| 35 |
+
"""OpenAI API implementation"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, api_key: Optional[str] = None):
|
| 38 |
+
from openai import OpenAI
|
| 39 |
+
self.client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
|
| 40 |
+
|
| 41 |
+
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 42 |
+
"""Create OpenAI chat completion"""
|
| 43 |
+
return self.client.chat.completions.create(
|
| 44 |
+
model=model,
|
| 45 |
+
messages=messages,
|
| 46 |
+
**kwargs
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
def get_response_text(self, completion: Any) -> str:
|
| 50 |
+
"""Extract text from OpenAI response"""
|
| 51 |
+
return completion.choices[0].message.content
|
| 52 |
+
|
| 53 |
+
def get_usage_info(self, completion: Any) -> Dict:
|
| 54 |
+
"""Extract usage from OpenAI response"""
|
| 55 |
+
return {
|
| 56 |
+
'prompt_tokens': completion.usage.prompt_tokens,
|
| 57 |
+
'completion_tokens': completion.usage.completion_tokens,
|
| 58 |
+
'total_tokens': completion.usage.total_tokens,
|
| 59 |
+
'cached_tokens': getattr(completion.usage.prompt_tokens_details, 'cached_tokens', 0) if hasattr(completion.usage, 'prompt_tokens_details') else 0
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
def calculate_cost(self, usage: Dict, model: str) -> float:
|
| 63 |
+
"""Calculate OpenAI cost"""
|
| 64 |
+
# Pricing per 1M tokens
|
| 65 |
+
model_costs = {
|
| 66 |
+
"gpt-4o-mini": {"input": 0.150, "cached": 0.075, "output": 0.600},
|
| 67 |
+
"gpt-4o": {"input": 2.50, "cached": 1.25, "output": 10.00},
|
| 68 |
+
"gpt-3.5-turbo": {"input": 0.50, "cached": 0.25, "output": 1.50},
|
| 69 |
+
"o3-mini": {"input": 1.10, "cached": 0.55, "output": 4.40},
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
costs = model_costs.get(model, model_costs["gpt-4o-mini"])
|
| 73 |
+
|
| 74 |
+
prompt_tokens = usage['prompt_tokens']
|
| 75 |
+
cached_tokens = usage.get('cached_tokens', 0)
|
| 76 |
+
completion_tokens = usage['completion_tokens']
|
| 77 |
+
|
| 78 |
+
actual_input_tokens = prompt_tokens - cached_tokens
|
| 79 |
+
|
| 80 |
+
input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
|
| 81 |
+
cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
|
| 82 |
+
output_cost = (completion_tokens / 1_000_000) * costs["output"]
|
| 83 |
+
|
| 84 |
+
return input_cost + cached_cost + output_cost
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class GeminiProvider(LLMProvider):
|
| 88 |
+
"""Google Gemini API implementation"""
|
| 89 |
+
|
| 90 |
+
def __init__(self, api_key: Optional[str] = None):
|
| 91 |
+
try:
|
| 92 |
+
import google.generativeai as genai
|
| 93 |
+
self.genai = genai
|
| 94 |
+
genai.configure(api_key=api_key or os.environ.get("GEMINI_API_KEY"))
|
| 95 |
+
except ImportError:
|
| 96 |
+
raise ImportError(
|
| 97 |
+
"google-generativeai not installed. "
|
| 98 |
+
"Install with: pip install google-generativeai"
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 102 |
+
"""Create Gemini chat completion"""
|
| 103 |
+
# Convert OpenAI message format to Gemini format
|
| 104 |
+
gemini_messages = self._convert_messages(messages)
|
| 105 |
+
|
| 106 |
+
# Fix model name - Gemini expects models/model-name format
|
| 107 |
+
if not model.startswith('models/'):
|
| 108 |
+
model = f'models/{model}'
|
| 109 |
+
|
| 110 |
+
# Create model
|
| 111 |
+
gemini_model = self.genai.GenerativeModel(model)
|
| 112 |
+
|
| 113 |
+
# Generate response
|
| 114 |
+
response = gemini_model.generate_content(
|
| 115 |
+
gemini_messages,
|
| 116 |
+
generation_config=self._get_generation_config(**kwargs)
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
return response
|
| 120 |
+
|
| 121 |
+
def _convert_messages(self, messages: List[Dict]) -> str:
|
| 122 |
+
"""Convert OpenAI messages to Gemini prompt format"""
|
| 123 |
+
# Gemini uses a simpler format - concatenate all messages
|
| 124 |
+
prompt_parts = []
|
| 125 |
+
|
| 126 |
+
for msg in messages:
|
| 127 |
+
role = msg['role']
|
| 128 |
+
content = msg['content']
|
| 129 |
+
|
| 130 |
+
if role == 'system':
|
| 131 |
+
prompt_parts.append(f"System Instructions:\n{content}\n")
|
| 132 |
+
elif role == 'user':
|
| 133 |
+
prompt_parts.append(f"User:\n{content}\n")
|
| 134 |
+
elif role == 'assistant':
|
| 135 |
+
prompt_parts.append(f"Assistant:\n{content}\n")
|
| 136 |
+
|
| 137 |
+
return "\n".join(prompt_parts)
|
| 138 |
+
|
| 139 |
+
def _get_generation_config(self, **kwargs):
|
| 140 |
+
"""Convert OpenAI kwargs to Gemini generation config"""
|
| 141 |
+
config = {}
|
| 142 |
+
|
| 143 |
+
# Map common parameters
|
| 144 |
+
if 'temperature' in kwargs:
|
| 145 |
+
config['temperature'] = kwargs['temperature']
|
| 146 |
+
if 'max_tokens' in kwargs:
|
| 147 |
+
config['max_output_tokens'] = kwargs['max_tokens']
|
| 148 |
+
if 'top_p' in kwargs:
|
| 149 |
+
config['top_p'] = kwargs['top_p']
|
| 150 |
+
|
| 151 |
+
return config
|
| 152 |
+
|
| 153 |
+
def get_response_text(self, completion: Any) -> str:
|
| 154 |
+
"""Extract text from Gemini response"""
|
| 155 |
+
return completion.text
|
| 156 |
+
|
| 157 |
+
def get_usage_info(self, completion: Any) -> Dict:
|
| 158 |
+
"""Extract usage from Gemini response"""
|
| 159 |
+
# Gemini provides token counts in metadata
|
| 160 |
+
try:
|
| 161 |
+
metadata = completion.usage_metadata
|
| 162 |
+
return {
|
| 163 |
+
'prompt_tokens': metadata.prompt_token_count,
|
| 164 |
+
'completion_tokens': metadata.candidates_token_count,
|
| 165 |
+
'total_tokens': metadata.total_token_count,
|
| 166 |
+
'cached_tokens': getattr(metadata, 'cached_content_token_count', 0)
|
| 167 |
+
}
|
| 168 |
+
except:
|
| 169 |
+
# Fallback if metadata not available
|
| 170 |
+
return {
|
| 171 |
+
'prompt_tokens': 0,
|
| 172 |
+
'completion_tokens': 0,
|
| 173 |
+
'total_tokens': 0,
|
| 174 |
+
'cached_tokens': 0
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
def calculate_cost(self, usage: Dict, model: str) -> float:
|
| 178 |
+
"""Calculate Gemini cost"""
|
| 179 |
+
# Gemini pricing per 1M tokens (as of Jan 2026)
|
| 180 |
+
model_costs = {
|
| 181 |
+
"gemini-1.5-flash": {"input": 0.075, "cached": 0.01875, "output": 0.30},
|
| 182 |
+
"gemini-1.5-pro": {"input": 1.25, "cached": 0.3125, "output": 5.00},
|
| 183 |
+
"gemini-2.0-flash-exp": {"input": 0.0, "cached": 0.0, "output": 0.0}, # Free during preview
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
costs = model_costs.get(model, model_costs["gemini-1.5-flash"])
|
| 187 |
+
|
| 188 |
+
prompt_tokens = usage['prompt_tokens']
|
| 189 |
+
cached_tokens = usage.get('cached_tokens', 0)
|
| 190 |
+
completion_tokens = usage['completion_tokens']
|
| 191 |
+
|
| 192 |
+
actual_input_tokens = prompt_tokens - cached_tokens
|
| 193 |
+
|
| 194 |
+
input_cost = (actual_input_tokens / 1_000_000) * costs["input"]
|
| 195 |
+
cached_cost = (cached_tokens / 1_000_000) * costs["cached"]
|
| 196 |
+
output_cost = (completion_tokens / 1_000_000) * costs["output"]
|
| 197 |
+
|
| 198 |
+
return input_cost + cached_cost + output_cost
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class GemmaProvider(LLMProvider):
|
| 202 |
+
"""NVIDIA Gemma API implementation"""
|
| 203 |
+
|
| 204 |
+
def __init__(self, api_key: Optional[str] = None):
|
| 205 |
+
import requests
|
| 206 |
+
self.requests = requests
|
| 207 |
+
self.api_key = api_key or os.environ.get("NVIDIA_API_KEY")
|
| 208 |
+
if not self.api_key:
|
| 209 |
+
raise ValueError(
|
| 210 |
+
"NVIDIA_API_KEY not found. "
|
| 211 |
+
"Set it as an environment variable or pass it to the constructor."
|
| 212 |
+
)
|
| 213 |
+
self.invoke_url = "https://integrate.api.nvidia.com/v1/chat/completions"
|
| 214 |
+
|
| 215 |
+
def create_completion(self, messages: List[Dict], model: str, **kwargs) -> Any:
|
| 216 |
+
"""Create Gemma chat completion"""
|
| 217 |
+
# Prepare headers
|
| 218 |
+
headers = {
|
| 219 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 220 |
+
"Accept": "application/json" # Non-streaming for simplicity
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
# Prepare payload
|
| 224 |
+
payload = {
|
| 225 |
+
"model": model,
|
| 226 |
+
"messages": messages,
|
| 227 |
+
"max_tokens": kwargs.get('max_tokens', 512),
|
| 228 |
+
"temperature": kwargs.get('temperature', 0.20),
|
| 229 |
+
"top_p": kwargs.get('top_p', 0.70),
|
| 230 |
+
"stream": False # Disable streaming for now
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
# Make request
|
| 234 |
+
response = self.requests.post(self.invoke_url, headers=headers, json=payload)
|
| 235 |
+
response.raise_for_status()
|
| 236 |
+
|
| 237 |
+
return response.json()
|
| 238 |
+
|
| 239 |
+
def get_response_text(self, completion: Any) -> str:
|
| 240 |
+
"""Extract text from Gemma response"""
|
| 241 |
+
# NVIDIA API returns OpenAI-compatible format
|
| 242 |
+
if isinstance(completion, dict):
|
| 243 |
+
return completion['choices'][0]['message']['content']
|
| 244 |
+
return str(completion)
|
| 245 |
+
|
| 246 |
+
def get_usage_info(self, completion: Any) -> Dict:
|
| 247 |
+
"""Extract usage from Gemma response"""
|
| 248 |
+
try:
|
| 249 |
+
usage = completion.get('usage', {})
|
| 250 |
+
return {
|
| 251 |
+
'prompt_tokens': usage.get('prompt_tokens', 0),
|
| 252 |
+
'completion_tokens': usage.get('completion_tokens', 0),
|
| 253 |
+
'total_tokens': usage.get('total_tokens', 0),
|
| 254 |
+
'cached_tokens': 0 # NVIDIA API doesn't provide cached token info
|
| 255 |
+
}
|
| 256 |
+
except:
|
| 257 |
+
return {
|
| 258 |
+
'prompt_tokens': 0,
|
| 259 |
+
'completion_tokens': 0,
|
| 260 |
+
'total_tokens': 0,
|
| 261 |
+
'cached_tokens': 0
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
def calculate_cost(self, usage: Dict, model: str) -> float:
|
| 265 |
+
"""Calculate Gemma cost"""
|
| 266 |
+
# NVIDIA API pricing (check current pricing at build.nvidia.com)
|
| 267 |
+
# For now, using placeholder values - update with actual pricing
|
| 268 |
+
model_costs = {
|
| 269 |
+
"google/gemma-3-27b-it": {"input": 0.0, "output": 0.0}, # Free tier or update with actual costs
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
costs = model_costs.get(model, {"input": 0.0, "output": 0.0})
|
| 273 |
+
|
| 274 |
+
prompt_tokens = usage['prompt_tokens']
|
| 275 |
+
completion_tokens = usage['completion_tokens']
|
| 276 |
+
|
| 277 |
+
input_cost = (prompt_tokens / 1_000_000) * costs["input"]
|
| 278 |
+
output_cost = (completion_tokens / 1_000_000) * costs["output"]
|
| 279 |
+
|
| 280 |
+
return input_cost + output_cost
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def get_provider(provider_name: str, api_key: Optional[str] = None) -> LLMProvider:
|
| 284 |
+
"""
|
| 285 |
+
Factory function to get LLM provider.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
provider_name: Name of provider ('openai' or 'gemini')
|
| 289 |
+
api_key: Optional API key (uses env var if not provided)
|
| 290 |
+
|
| 291 |
+
Returns:
|
| 292 |
+
LLMProvider instance
|
| 293 |
+
"""
|
| 294 |
+
providers = {
|
| 295 |
+
'openai': OpenAIProvider,
|
| 296 |
+
'gemini': GeminiProvider,
|
| 297 |
+
'gemma': GemmaProvider,
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
if provider_name not in providers:
|
| 301 |
+
raise ValueError(
|
| 302 |
+
f"Unknown provider: {provider_name}. "
|
| 303 |
+
f"Available providers: {list(providers.keys())}"
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
return providers[provider_name](api_key=api_key)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def get_default_model(provider_name: str) -> str:
|
| 310 |
+
"""Get default model for a provider"""
|
| 311 |
+
defaults = {
|
| 312 |
+
'openai': 'gpt-4o-mini',
|
| 313 |
+
'gemini': 'gemini-1.5-flash',
|
| 314 |
+
'gemma': 'google/gemma-3-27b-it',
|
| 315 |
+
}
|
| 316 |
+
return defaults.get(provider_name, 'gpt-4o-mini')
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
if __name__ == "__main__":
|
| 320 |
+
# Test script
|
| 321 |
+
print("Testing LLM Provider abstraction...")
|
| 322 |
+
|
| 323 |
+
# Test OpenAI
|
| 324 |
+
try:
|
| 325 |
+
provider = get_provider('openai')
|
| 326 |
+
print("✅ OpenAI provider initialized")
|
| 327 |
+
except Exception as e:
|
| 328 |
+
print(f"❌ OpenAI provider failed: {e}")
|
| 329 |
+
|
| 330 |
+
# Test Gemini
|
| 331 |
+
try:
|
| 332 |
+
provider = get_provider('gemini')
|
| 333 |
+
print("✅ Gemini provider initialized")
|
| 334 |
+
except Exception as e:
|
| 335 |
+
print(f"❌ Gemini provider failed: {e}")
|
| 336 |
+
|
| 337 |
+
# Test Gemma
|
| 338 |
+
try:
|
| 339 |
+
provider = get_provider('gemma')
|
| 340 |
+
print("✅ Gemma provider initialized")
|
| 341 |
+
except Exception as e:
|
| 342 |
+
print(f"❌ Gemma provider failed: {e}")
|
codes/rate_limiter.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Rate Limiter for OpenAI API to avoid hitting TPM (tokens per minute) limits.
|
| 3 |
+
"""
|
| 4 |
+
import time
|
| 5 |
+
from typing import List, Tuple
|
| 6 |
+
|
| 7 |
+
class RateLimiter:
|
| 8 |
+
"""Smart rate limiter that tracks token usage and sleeps only when necessary."""
|
| 9 |
+
|
| 10 |
+
def __init__(self, max_tokens_per_minute: int = 95000, buffer: int = 5000):
|
| 11 |
+
"""
|
| 12 |
+
Initialize the rate limiter.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
max_tokens_per_minute: Maximum tokens allowed per minute (default: 95K for safety)
|
| 16 |
+
buffer: Safety buffer to stay under limit (default: 5K)
|
| 17 |
+
"""
|
| 18 |
+
self.max_tokens = max_tokens_per_minute - buffer
|
| 19 |
+
self.tokens_used: List[Tuple[float, int]] = [] # [(timestamp, tokens), ...]
|
| 20 |
+
self.total_waits = 0
|
| 21 |
+
self.total_wait_time = 0.0
|
| 22 |
+
|
| 23 |
+
def wait_if_needed(self, tokens_needed: int) -> None:
|
| 24 |
+
"""
|
| 25 |
+
Check if we need to wait before making the next API call.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
tokens_needed: Estimated tokens for the next API call
|
| 29 |
+
"""
|
| 30 |
+
now = time.time()
|
| 31 |
+
|
| 32 |
+
# Remove tokens older than 60 seconds (sliding window)
|
| 33 |
+
self.tokens_used = [
|
| 34 |
+
(ts, tok) for ts, tok in self.tokens_used
|
| 35 |
+
if now - ts < 60
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
# Calculate tokens used in last 60 seconds
|
| 39 |
+
tokens_in_window = sum(tok for _, tok in self.tokens_used)
|
| 40 |
+
|
| 41 |
+
# If adding new request would exceed limit, wait
|
| 42 |
+
if tokens_in_window + tokens_needed > self.max_tokens:
|
| 43 |
+
# Calculate how long to wait
|
| 44 |
+
oldest_timestamp = self.tokens_used[0][0]
|
| 45 |
+
wait_time = 60 - (now - oldest_timestamp) + 1 # +1 for safety
|
| 46 |
+
|
| 47 |
+
print(f"⏰ Rate limit approaching ({tokens_in_window + tokens_needed}/{self.max_tokens} tokens)")
|
| 48 |
+
print(f" Waiting {wait_time:.1f}s for rate limit window to reset...")
|
| 49 |
+
|
| 50 |
+
time.sleep(wait_time)
|
| 51 |
+
self.total_waits += 1
|
| 52 |
+
self.total_wait_time += wait_time
|
| 53 |
+
|
| 54 |
+
# Clear old tokens after waiting
|
| 55 |
+
now = time.time()
|
| 56 |
+
self.tokens_used = [
|
| 57 |
+
(ts, tok) for ts, tok in self.tokens_used
|
| 58 |
+
if now - ts < 60
|
| 59 |
+
]
|
| 60 |
+
|
| 61 |
+
# Record this request
|
| 62 |
+
self.tokens_used.append((now, tokens_needed))
|
| 63 |
+
|
| 64 |
+
def get_stats(self) -> dict:
|
| 65 |
+
"""Get statistics about rate limiting."""
|
| 66 |
+
return {
|
| 67 |
+
'total_waits': self.total_waits,
|
| 68 |
+
'total_wait_time': self.total_wait_time,
|
| 69 |
+
'current_window_tokens': sum(tok for _, tok in self.tokens_used)
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
def print_stats(self) -> None:
|
| 73 |
+
"""Print rate limiting statistics."""
|
| 74 |
+
stats = self.get_stats()
|
| 75 |
+
print("\n" + "="*50)
|
| 76 |
+
print("📊 Rate Limiter Statistics")
|
| 77 |
+
print("="*50)
|
| 78 |
+
print(f"Total waits: {stats['total_waits']}")
|
| 79 |
+
print(f"Total wait time: {stats['total_wait_time']:.1f}s")
|
| 80 |
+
print(f"Current window usage: {stats['current_window_tokens']} tokens")
|
| 81 |
+
print("="*50 + "\n")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def estimate_tokens(text: str, overhead: int = 800) -> int:
|
| 85 |
+
"""
|
| 86 |
+
Estimate tokens for a text string.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
text: Input text
|
| 90 |
+
overhead: Additional tokens for system prompts, formatting, etc.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
Estimated token count
|
| 94 |
+
"""
|
| 95 |
+
# Rough estimation: 1 token ≈ 4 characters
|
| 96 |
+
content_tokens = len(str(text)) // 4
|
| 97 |
+
return content_tokens + overhead
|
codes/test_gemma.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test script to verify Gemma provider works with real API calls
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from llm_provider import get_provider, get_default_model
|
| 6 |
+
|
| 7 |
+
# Set API key
|
| 8 |
+
os.environ['NVIDIA_API_KEY'] = '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'
|
| 9 |
+
|
| 10 |
+
# Initialize Gemma provider
|
| 11 |
+
print("Initializing Gemma provider...")
|
| 12 |
+
provider = get_provider('gemma')
|
| 13 |
+
model = get_default_model('gemma')
|
| 14 |
+
print(f"✅ Provider initialized with model: {model}")
|
| 15 |
+
|
| 16 |
+
# Test a simple completion
|
| 17 |
+
print("\nTesting completion...")
|
| 18 |
+
messages = [
|
| 19 |
+
{"role": "user", "content": "Say 'Hello, I am Gemma!' in exactly those words."}
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
completion = provider.create_completion(messages, model, max_tokens=50)
|
| 24 |
+
response_text = provider.get_response_text(completion)
|
| 25 |
+
usage = provider.get_usage_info(completion)
|
| 26 |
+
cost = provider.calculate_cost(usage, model)
|
| 27 |
+
|
| 28 |
+
print(f"\n✅ Completion successful!")
|
| 29 |
+
print(f"Response: {response_text}")
|
| 30 |
+
print(f"\nUsage:")
|
| 31 |
+
print(f" - Prompt tokens: {usage['prompt_tokens']}")
|
| 32 |
+
print(f" - Completion tokens: {usage['completion_tokens']}")
|
| 33 |
+
print(f" - Total tokens: {usage['total_tokens']}")
|
| 34 |
+
print(f" - Cost: ${cost:.6f}")
|
| 35 |
+
|
| 36 |
+
except Exception as e:
|
| 37 |
+
print(f"\n❌ Completion failed: {e}")
|
| 38 |
+
import traceback
|
| 39 |
+
traceback.print_exc()
|
codes/utils.py
ADDED
|
@@ -0,0 +1,440 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
import os
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
def extract_planning(trajectories_json_file_path):
|
| 7 |
+
with open(trajectories_json_file_path) as f:
|
| 8 |
+
traj = json.load(f)
|
| 9 |
+
|
| 10 |
+
context_lst = []
|
| 11 |
+
for turn in traj:
|
| 12 |
+
if turn['role'] == 'assistant':
|
| 13 |
+
# context_lst.append(turn['content'])
|
| 14 |
+
content = turn['content']
|
| 15 |
+
if "</think>" in content:
|
| 16 |
+
content = content.split("</think>")[-1].strip()
|
| 17 |
+
context_lst.append(content)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
context_lst = context_lst[:3]
|
| 21 |
+
|
| 22 |
+
return context_lst
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def content_to_json(data):
|
| 27 |
+
clean_data = re.sub(r'\[CONTENT\]|\[/CONTENT\]', '', data).strip()
|
| 28 |
+
|
| 29 |
+
clean_data = re.sub(r'(".*?"),\s*#.*', r'\1,', clean_data)
|
| 30 |
+
|
| 31 |
+
clean_data = re.sub(r',\s*\]', ']', clean_data)
|
| 32 |
+
|
| 33 |
+
clean_data = re.sub(r'\n\s*', '', clean_data)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# JSON parsing
|
| 37 |
+
try:
|
| 38 |
+
json_data = json.loads(clean_data)
|
| 39 |
+
return json_data
|
| 40 |
+
except json.JSONDecodeError as e:
|
| 41 |
+
# print(e)
|
| 42 |
+
return content_to_json2(data)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def content_to_json2(data):
|
| 46 |
+
# remove [CONTENT][/CONTENT]
|
| 47 |
+
clean_data = re.sub(r'\[CONTENT\]|\[/CONTENT\]', '', data).strip()
|
| 48 |
+
|
| 49 |
+
# "~~~~", #comment -> "~~~~",
|
| 50 |
+
clean_data = re.sub(r'(".*?"),\s*#.*', r'\1,', clean_data)
|
| 51 |
+
|
| 52 |
+
# "~~~~" #comment → "~~~~"
|
| 53 |
+
clean_data = re.sub(r'(".*?")\s*#.*', r'\1', clean_data)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# ("~~~~",] -> "~~~~"])
|
| 57 |
+
clean_data = re.sub(r',\s*\]', ']', clean_data)
|
| 58 |
+
|
| 59 |
+
clean_data = re.sub(r'\n\s*', '', clean_data)
|
| 60 |
+
|
| 61 |
+
# JSON parsing
|
| 62 |
+
try:
|
| 63 |
+
json_data = json.loads(clean_data)
|
| 64 |
+
return json_data
|
| 65 |
+
|
| 66 |
+
except json.JSONDecodeError as e:
|
| 67 |
+
# print("Json parsing error", e)
|
| 68 |
+
return content_to_json3(data)
|
| 69 |
+
|
| 70 |
+
def content_to_json3(data):
|
| 71 |
+
# remove [CONTENT] [/CONTENT]
|
| 72 |
+
clean_data = re.sub(r'\[CONTENT\]|\[/CONTENT\]', '', data).strip()
|
| 73 |
+
|
| 74 |
+
# "~~~~", #comment -> "~~~~",
|
| 75 |
+
clean_data = re.sub(r'(".*?"),\s*#.*', r'\1,', clean_data)
|
| 76 |
+
|
| 77 |
+
# "~~~~" #comment → "~~~~"
|
| 78 |
+
clean_data = re.sub(r'(".*?")\s*#.*', r'\1', clean_data)
|
| 79 |
+
|
| 80 |
+
# remove ("~~~~",] -> "~~~~"])
|
| 81 |
+
clean_data = re.sub(r',\s*\]', ']', clean_data)
|
| 82 |
+
|
| 83 |
+
clean_data = re.sub(r'\n\s*', '', clean_data)
|
| 84 |
+
clean_data = re.sub(r'"""', '"', clean_data) # Replace triple double quotes
|
| 85 |
+
clean_data = re.sub(r"'''", "'", clean_data) # Replace triple single quotes
|
| 86 |
+
clean_data = re.sub(r"\\", "'", clean_data) # Replace \
|
| 87 |
+
|
| 88 |
+
# JSON parsing
|
| 89 |
+
try:
|
| 90 |
+
json_data = json.loads(f"""{clean_data}""")
|
| 91 |
+
return json_data
|
| 92 |
+
|
| 93 |
+
except json.JSONDecodeError as e:
|
| 94 |
+
# print(e)
|
| 95 |
+
|
| 96 |
+
# print(f"[DEBUG] utils.py > content_to_json3 ")
|
| 97 |
+
# return None
|
| 98 |
+
return content_to_json4(data)
|
| 99 |
+
|
| 100 |
+
def content_to_json4(data):
|
| 101 |
+
# 1. Extract Logic Analysis, Task list
|
| 102 |
+
pattern = r'"Logic Analysis":\s*(\[[\s\S]*?\])\s*,\s*"Task list":\s*(\[[\s\S]*?\])'
|
| 103 |
+
match = re.search(pattern, data)
|
| 104 |
+
|
| 105 |
+
if match:
|
| 106 |
+
logic_analysis = json.loads(match.group(1))
|
| 107 |
+
task_list = json.loads(match.group(2))
|
| 108 |
+
|
| 109 |
+
result = {
|
| 110 |
+
"Logic Analysis": logic_analysis,
|
| 111 |
+
"Task list": task_list
|
| 112 |
+
}
|
| 113 |
+
else:
|
| 114 |
+
result = {}
|
| 115 |
+
|
| 116 |
+
# print(json.dumps(result, indent=2))
|
| 117 |
+
return result
|
| 118 |
+
|
| 119 |
+
def extract_code_from_content(content):
|
| 120 |
+
pattern = r'^```(?:\w+)?\s*\n(.*?)(?=^```)```'
|
| 121 |
+
code = re.findall(pattern, content, re.DOTALL | re.MULTILINE)
|
| 122 |
+
if len(code) == 0:
|
| 123 |
+
return ""
|
| 124 |
+
else:
|
| 125 |
+
return code[0]
|
| 126 |
+
|
| 127 |
+
def extract_code_from_content2(content):
|
| 128 |
+
pattern = r'```python\s*(.*?)```'
|
| 129 |
+
result = re.search(pattern, content, re.DOTALL)
|
| 130 |
+
|
| 131 |
+
if result:
|
| 132 |
+
extracted_code = result.group(1).strip()
|
| 133 |
+
else:
|
| 134 |
+
extracted_code = ""
|
| 135 |
+
print("[WARNING] No Python code found.")
|
| 136 |
+
return extracted_code
|
| 137 |
+
|
| 138 |
+
def format_json_data(data):
|
| 139 |
+
formatted_text = ""
|
| 140 |
+
for key, value in data.items():
|
| 141 |
+
formatted_text += "-" * 40 + "\n"
|
| 142 |
+
formatted_text += "[" + key + "]\n"
|
| 143 |
+
if isinstance(value, list):
|
| 144 |
+
for item in value:
|
| 145 |
+
formatted_text += f"- {item}\n"
|
| 146 |
+
else:
|
| 147 |
+
formatted_text += str(value) + "\n"
|
| 148 |
+
formatted_text += "\n"
|
| 149 |
+
return formatted_text
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def cal_cost(response_json, model_name):
|
| 153 |
+
model_cost = {
|
| 154 |
+
# OpenAI Models
|
| 155 |
+
"gpt-4o-mini": {"input": 0.150, "cached_input": 0.075, "output": 0.600},
|
| 156 |
+
"gpt-4o": {"input": 2.50, "cached_input": 1.25, "output": 10.00},
|
| 157 |
+
|
| 158 |
+
# gpt-4o-realtime-preview
|
| 159 |
+
"gpt-4o-realtime-preview": {"input": 5.00, "cached_input": 2.50, "output": 20.00},
|
| 160 |
+
"gpt-4o-realtime-preview-2024-12-17": {"input": 5.00, "cached_input": 2.50, "output": 20.00},
|
| 161 |
+
"gpt-4o-realtime-preview-2024-10-01": {"input": 5.00, "cached_input": 2.50, "output": 20.00},
|
| 162 |
+
|
| 163 |
+
# gpt-4o-mini
|
| 164 |
+
"gpt-4o-mini": {"input": 0.15, "cached_input": 0.075, "output": 0.60},
|
| 165 |
+
"gpt-4o-mini-2024-07-18": {"input": 0.15, "cached_input": 0.075, "output": 0.60},
|
| 166 |
+
|
| 167 |
+
# gpt-4o-mini-audio-preview
|
| 168 |
+
"gpt-4o-mini-audio-preview": {"input": 0.15, "cached_input": None, "output": 0.60},
|
| 169 |
+
"gpt-4o-mini-audio-preview-2024-12-17": {"input": 0.15, "cached_input": None, "output": 0.60},
|
| 170 |
+
|
| 171 |
+
# gpt-4o-mini-realtime-preview
|
| 172 |
+
"gpt-4o-mini-realtime-preview": {"input": 0.60, "cached_input": 0.30, "output": 2.40},
|
| 173 |
+
"gpt-4o-mini-realtime-preview-2024-12-17": {"input": 0.60, "cached_input": 0.30, "output": 2.40},
|
| 174 |
+
|
| 175 |
+
# o1
|
| 176 |
+
"o1": {"input": 15.00, "cached_input": 7.50, "output": 60.00},
|
| 177 |
+
"o1-2024-12-17": {"input": 15.00, "cached_input": 7.50, "output": 60.00},
|
| 178 |
+
"o1-preview-2024-09-12": {"input": 15.00, "cached_input": 7.50, "output": 60.00},
|
| 179 |
+
|
| 180 |
+
# o1-pro
|
| 181 |
+
"o1-pro": {"input": 150.00, "cached_input": None, "output": 600.00},
|
| 182 |
+
"o1-pro-2025-03-19": {"input": 150.00, "cached_input": None, "output": 600.00},
|
| 183 |
+
|
| 184 |
+
# o3
|
| 185 |
+
"o3": {"input": 10.00, "cached_input": 2.50, "output": 40.00},
|
| 186 |
+
"o3-2025-04-16": {"input": 10.00, "cached_input": 2.50, "output": 40.00},
|
| 187 |
+
|
| 188 |
+
# o4-mini
|
| 189 |
+
"o4-mini": {"input": 1.10, "cached_input": 0.275, "output": 4.40},
|
| 190 |
+
"o4-mini-2025-04-16": {"input": 1.10, "cached_input": 0.275, "output": 4.40},
|
| 191 |
+
|
| 192 |
+
# o3-mini
|
| 193 |
+
"o3-mini": {"input": 1.10, "cached_input": 0.55, "output": 4.40},
|
| 194 |
+
"o3-mini-2025-01-31": {"input": 1.10, "cached_input": 0.55, "output": 4.40},
|
| 195 |
+
|
| 196 |
+
# o1-mini
|
| 197 |
+
"o1-mini": {"input": 1.10, "cached_input": 0.55, "output": 4.40},
|
| 198 |
+
"o1-mini-2024-09-12": {"input": 1.10, "cached_input": 0.55, "output": 4.40},
|
| 199 |
+
|
| 200 |
+
# gpt-4o-mini-search-preview
|
| 201 |
+
"gpt-4o-mini-search-preview": {"input": 0.15, "cached_input": None, "output": 0.60},
|
| 202 |
+
"gpt-4o-mini-search-preview-2025-03-11": {"input": 0.15, "cached_input": None, "output": 0.60},
|
| 203 |
+
|
| 204 |
+
# gpt-4o-search-preview
|
| 205 |
+
"gpt-4o-search-preview": {"input": 2.50, "cached_input": None, "output": 10.00},
|
| 206 |
+
"gpt-4o-search-preview-2025-03-11": {"input": 2.50, "cached_input": None, "output": 10.00},
|
| 207 |
+
|
| 208 |
+
# computer-use-preview
|
| 209 |
+
"computer-use-preview": {"input": 3.00, "cached_input": None, "output": 12.00},
|
| 210 |
+
"computer-use-preview-2025-03-11": {"input": 3.00, "cached_input": None, "output": 12.00},
|
| 211 |
+
|
| 212 |
+
# gpt-image-1
|
| 213 |
+
"gpt-image-1": {"input": 5.00, "cached_input": None, "output": None},
|
| 214 |
+
|
| 215 |
+
# Google Gemini Models
|
| 216 |
+
"gemini-1.5-flash": {"input": 0.075, "cached_input": 0.01875, "output": 0.30},
|
| 217 |
+
"gemini-1.5-pro": {"input": 1.25, "cached_input": 0.3125, "output": 5.00},
|
| 218 |
+
"gemini-2.0-flash-exp": {"input": 0.0, "cached_input": 0.0, "output": 0.0},
|
| 219 |
+
"gemini-3-flash-preview": {"input": 0.0, "cached_input": 0.0, "output": 0.0},
|
| 220 |
+
"models/gemini-1.5-flash": {"input": 0.075, "cached_input": 0.01875, "output": 0.30},
|
| 221 |
+
"models/gemini-1.5-pro": {"input": 1.25, "cached_input": 0.3125, "output": 5.00},
|
| 222 |
+
"models/gemini-2.0-flash-exp": {"input": 0.0, "cached_input": 0.0, "output": 0.0},
|
| 223 |
+
"models/gemini-3-flash-preview": {"input": 0.0, "cached_input": 0.0, "output": 0.0},
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
# Extract token counts
|
| 227 |
+
prompt_tokens = response_json["usage"]["prompt_tokens"]
|
| 228 |
+
completion_tokens = response_json["usage"]["completion_tokens"]
|
| 229 |
+
|
| 230 |
+
# Handle cached tokens (may not exist in all providers)
|
| 231 |
+
cached_tokens = 0
|
| 232 |
+
if "prompt_tokens_details" in response_json["usage"]:
|
| 233 |
+
cached_tokens = response_json["usage"]["prompt_tokens_details"].get("cached_tokens", 0)
|
| 234 |
+
elif "cached_tokens" in response_json["usage"]:
|
| 235 |
+
cached_tokens = response_json["usage"]["cached_tokens"]
|
| 236 |
+
|
| 237 |
+
# input token = (prompt_tokens - cached_tokens)
|
| 238 |
+
actual_input_tokens = prompt_tokens - cached_tokens
|
| 239 |
+
output_tokens = completion_tokens
|
| 240 |
+
|
| 241 |
+
# Get cost info with fallback for unknown models
|
| 242 |
+
if model_name not in model_cost:
|
| 243 |
+
print(f"⚠️ Warning: Unknown model '{model_name}', assuming free tier")
|
| 244 |
+
cost_info = {"input": 0.0, "cached_input": 0.0, "output": 0.0}
|
| 245 |
+
else:
|
| 246 |
+
cost_info = model_cost[model_name]
|
| 247 |
+
|
| 248 |
+
input_cost = (actual_input_tokens / 1_000_000) * cost_info['input']
|
| 249 |
+
cached_input_cost = (cached_tokens / 1_000_000) * cost_info['cached_input']
|
| 250 |
+
output_cost = (output_tokens / 1_000_000) * cost_info['output']
|
| 251 |
+
|
| 252 |
+
total_cost = input_cost + cached_input_cost + output_cost
|
| 253 |
+
|
| 254 |
+
return {
|
| 255 |
+
'model_name': model_name,
|
| 256 |
+
'actual_input_tokens': actual_input_tokens,
|
| 257 |
+
'input_cost': input_cost,
|
| 258 |
+
'cached_tokens': cached_tokens,
|
| 259 |
+
'cached_input_cost': cached_input_cost,
|
| 260 |
+
'output_tokens': output_tokens,
|
| 261 |
+
'output_cost': output_cost,
|
| 262 |
+
'total_cost': total_cost,
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
def load_accumulated_cost(accumulated_cost_file):
|
| 266 |
+
if os.path.exists(accumulated_cost_file):
|
| 267 |
+
with open(accumulated_cost_file, "r", encoding="utf-8") as f:
|
| 268 |
+
data = json.load(f)
|
| 269 |
+
return data.get("total_cost", 0.0)
|
| 270 |
+
else:
|
| 271 |
+
return 0.0
|
| 272 |
+
|
| 273 |
+
def save_accumulated_cost(accumulated_cost_file, cost):
|
| 274 |
+
with open(accumulated_cost_file, "w", encoding="utf-8") as f:
|
| 275 |
+
json.dump({"total_cost": cost}, f)
|
| 276 |
+
|
| 277 |
+
def print_response(completion_json, is_llm=False):
|
| 278 |
+
print("============================================")
|
| 279 |
+
if is_llm:
|
| 280 |
+
print(completion_json['text'])
|
| 281 |
+
else:
|
| 282 |
+
print(completion_json['choices'][0]['message']['content'])
|
| 283 |
+
print("============================================\n")
|
| 284 |
+
|
| 285 |
+
def print_log_cost(completion_json, gpt_version, current_stage, output_dir, total_accumulated_cost):
|
| 286 |
+
usage_info = cal_cost(completion_json, gpt_version)
|
| 287 |
+
|
| 288 |
+
current_cost = usage_info['total_cost']
|
| 289 |
+
total_accumulated_cost += current_cost
|
| 290 |
+
|
| 291 |
+
output_lines = []
|
| 292 |
+
output_lines.append("🌟 Usage Summary 🌟")
|
| 293 |
+
output_lines.append(f"{current_stage}")
|
| 294 |
+
output_lines.append(f"🛠️ Model: {usage_info['model_name']}")
|
| 295 |
+
output_lines.append(f"📥 Input tokens: {usage_info['actual_input_tokens']} (Cost: ${usage_info['input_cost']:.8f})")
|
| 296 |
+
output_lines.append(f"📦 Cached input tokens: {usage_info['cached_tokens']} (Cost: ${usage_info['cached_input_cost']:.8f})")
|
| 297 |
+
output_lines.append(f"📤 Output tokens: {usage_info['output_tokens']} (Cost: ${usage_info['output_cost']:.8f})")
|
| 298 |
+
output_lines.append(f"💵 Current total cost: ${current_cost:.8f}")
|
| 299 |
+
output_lines.append(f"🪙 Accumulated total cost so far: ${total_accumulated_cost:.8f}")
|
| 300 |
+
output_lines.append("============================================\n")
|
| 301 |
+
|
| 302 |
+
output_text = "\n".join(output_lines)
|
| 303 |
+
|
| 304 |
+
print(output_text)
|
| 305 |
+
|
| 306 |
+
with open(f"{output_dir}/cost_info.log", "a", encoding="utf-8") as f:
|
| 307 |
+
f.write(output_text + "\n")
|
| 308 |
+
|
| 309 |
+
return total_accumulated_cost
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def num_tokens_from_messages(messages, model="gpt-4o-2024-08-06"):
|
| 313 |
+
import tiktoken
|
| 314 |
+
|
| 315 |
+
"""Return the number of tokens used by a list of messages."""
|
| 316 |
+
try:
|
| 317 |
+
encoding = tiktoken.encoding_for_model(model)
|
| 318 |
+
except KeyError:
|
| 319 |
+
print("Warning: model not found. Using o200k_base encoding.")
|
| 320 |
+
encoding = tiktoken.get_encoding("o200k_base")
|
| 321 |
+
if model in {
|
| 322 |
+
"gpt-3.5-turbo-0125",
|
| 323 |
+
"gpt-4-0314",
|
| 324 |
+
"gpt-4-32k-0314",
|
| 325 |
+
"gpt-4-0613",
|
| 326 |
+
"gpt-4-32k-0613",
|
| 327 |
+
"gpt-4o-mini-2024-07-18",
|
| 328 |
+
"gpt-4o-2024-08-06"
|
| 329 |
+
}:
|
| 330 |
+
tokens_per_message = 3
|
| 331 |
+
tokens_per_name = 1
|
| 332 |
+
elif "gpt-3.5-turbo" in model:
|
| 333 |
+
print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0125.")
|
| 334 |
+
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0125")
|
| 335 |
+
elif "gpt-4o-mini" in model:
|
| 336 |
+
print("Warning: gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-mini-2024-07-18.")
|
| 337 |
+
return num_tokens_from_messages(messages, model="gpt-4o-mini-2024-07-18")
|
| 338 |
+
elif "gpt-4o" in model:
|
| 339 |
+
print("Warning: gpt-4o and gpt-4o-mini may update over time. Returning num tokens assuming gpt-4o-2024-08-06.")
|
| 340 |
+
return num_tokens_from_messages(messages, model="gpt-4o-2024-08-06")
|
| 341 |
+
|
| 342 |
+
elif "gpt-4" in model:
|
| 343 |
+
print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
| 344 |
+
return num_tokens_from_messages(messages, model="gpt-4-0613")
|
| 345 |
+
else:
|
| 346 |
+
raise NotImplementedError(
|
| 347 |
+
f"""num_tokens_from_messages() is not implemented for model {model}."""
|
| 348 |
+
)
|
| 349 |
+
num_tokens = 0
|
| 350 |
+
for message in messages:
|
| 351 |
+
num_tokens += tokens_per_message
|
| 352 |
+
for key, value in message.items():
|
| 353 |
+
# num_tokens += len(encoding.encode(value)
|
| 354 |
+
num_tokens += len(encoding.encode(value, allowed_special={"<|endoftext|>"},disallowed_special=()))
|
| 355 |
+
|
| 356 |
+
if key == "name":
|
| 357 |
+
num_tokens += tokens_per_name
|
| 358 |
+
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
| 359 |
+
return num_tokens
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def read_all_files(directory, allowed_ext, is_print=True):
|
| 364 |
+
"""Recursively read all .py files in the specified directory and return their contents."""
|
| 365 |
+
all_files_content = {}
|
| 366 |
+
|
| 367 |
+
for root, _, files in os.walk(directory): # Recursively traverse directories
|
| 368 |
+
for filename in files:
|
| 369 |
+
relative_path = os.path.relpath(os.path.join(root, filename), directory) # Preserve directory structure
|
| 370 |
+
|
| 371 |
+
# print(f"fn: {filename}\tdirectory: {directory}")
|
| 372 |
+
_file_name, ext = os.path.splitext(filename)
|
| 373 |
+
|
| 374 |
+
is_skip = False
|
| 375 |
+
if len(directory) < len(root):
|
| 376 |
+
root2 = root[len(directory)+1:]
|
| 377 |
+
for dirname in root2.split("/"):
|
| 378 |
+
if dirname.startswith("."):
|
| 379 |
+
is_skip = True
|
| 380 |
+
break
|
| 381 |
+
|
| 382 |
+
if filename.startswith(".") or "requirements.txt" in filename or ext == "" or is_skip:
|
| 383 |
+
if is_print and ext == "":
|
| 384 |
+
print(f"[SKIP] {os.path.join(root, filename)}")
|
| 385 |
+
continue
|
| 386 |
+
|
| 387 |
+
if ext not in allowed_ext:
|
| 388 |
+
if _file_name.lower() != "readme":
|
| 389 |
+
if is_print:
|
| 390 |
+
print(f"[SKIP] {os.path.join(root, filename)}")
|
| 391 |
+
continue
|
| 392 |
+
|
| 393 |
+
try:
|
| 394 |
+
filepath = os.path.join(root, filename)
|
| 395 |
+
file_size = os.path.getsize(filepath) # bytes
|
| 396 |
+
|
| 397 |
+
if file_size > 204800: # > 200KB
|
| 398 |
+
print(f"[BIG] {filepath} {file_size}")
|
| 399 |
+
|
| 400 |
+
with open(filepath, "r") as file: # encoding="utf-8"
|
| 401 |
+
all_files_content[relative_path] = file.read()
|
| 402 |
+
except Exception as e:
|
| 403 |
+
print(e)
|
| 404 |
+
print(f"[SKIP] {os.path.join(root, filename)}")
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
return all_files_content
|
| 408 |
+
|
| 409 |
+
def read_python_files(directory):
|
| 410 |
+
"""Recursively read all .py files in the specified directory and return their contents."""
|
| 411 |
+
python_files_content = {}
|
| 412 |
+
|
| 413 |
+
for root, _, files in os.walk(directory): # Recursively traverse directories
|
| 414 |
+
for filename in files:
|
| 415 |
+
if filename.endswith(".py"): # Check if file has .py extension
|
| 416 |
+
relative_path = os.path.relpath(os.path.join(root, filename), directory) # Preserve directory structure
|
| 417 |
+
with open(os.path.join(root, filename), "r", encoding="utf-8") as file:
|
| 418 |
+
python_files_content[relative_path] = file.read()
|
| 419 |
+
|
| 420 |
+
return python_files_content
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def extract_json_from_string(text):
|
| 424 |
+
# Extract content inside ```yaml\n...\n```
|
| 425 |
+
match = re.search(r"```json\n(.*?)\n```", text, re.DOTALL)
|
| 426 |
+
|
| 427 |
+
if match:
|
| 428 |
+
yaml_content = match.group(1)
|
| 429 |
+
return yaml_content
|
| 430 |
+
else:
|
| 431 |
+
print("No JSON content found.")
|
| 432 |
+
return ""
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def get_now_str():
|
| 436 |
+
now = datetime.now()
|
| 437 |
+
now = str(now)
|
| 438 |
+
now = now.split(".")[0]
|
| 439 |
+
now = now.replace("-","").replace(" ","_").replace(":","")
|
| 440 |
+
return now # now - "20250427_205124"
|
main.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, shutil, tempfile, zipfile, asyncio, subprocess
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
| 4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 5 |
+
from fastapi.responses import FileResponse
|
| 6 |
+
|
| 7 |
+
REPO_ROOT = Path(__file__).parent.resolve()
|
| 8 |
+
CODES_DIR = REPO_ROOT / "codes"
|
| 9 |
+
|
| 10 |
+
app = FastAPI(title="Blog2Code API", version="1.0.0")
|
| 11 |
+
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*").split(",")
|
| 12 |
+
app.add_middleware(
|
| 13 |
+
CORSMiddleware,
|
| 14 |
+
allow_origins=ALLOWED_ORIGINS,
|
| 15 |
+
allow_methods=["*"],
|
| 16 |
+
allow_headers=["*"],
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
def _run(script: str, args: list, extra_env: dict) -> None:
|
| 20 |
+
cmd = [sys.executable, str(CODES_DIR / script)] + args
|
| 21 |
+
result = subprocess.run(
|
| 22 |
+
cmd,
|
| 23 |
+
cwd=str(REPO_ROOT),
|
| 24 |
+
env={**os.environ, **extra_env},
|
| 25 |
+
capture_output=True,
|
| 26 |
+
text=True,
|
| 27 |
+
)
|
| 28 |
+
if result.returncode != 0:
|
| 29 |
+
raise RuntimeError(
|
| 30 |
+
f"{script} failed (exit {result.returncode}):\n"
|
| 31 |
+
f"STDOUT: {result.stdout[-2000:]}\n"
|
| 32 |
+
f"STDERR: {result.stderr[-2000:]}"
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
@app.get("/health")
|
| 36 |
+
def health():
|
| 37 |
+
return {"status": "ok"}
|
| 38 |
+
|
| 39 |
+
@app.post("/generate")
|
| 40 |
+
async def generate(
|
| 41 |
+
url: str = Form(None),
|
| 42 |
+
file: UploadFile = File(None),
|
| 43 |
+
):
|
| 44 |
+
if not url and not file:
|
| 45 |
+
raise HTTPException(400, "Provide either 'url' or 'file'.")
|
| 46 |
+
|
| 47 |
+
tmp = Path(tempfile.mkdtemp())
|
| 48 |
+
data_dir = tmp / "data"
|
| 49 |
+
output_dir = tmp / "output"
|
| 50 |
+
data_dir.mkdir(parents=True)
|
| 51 |
+
output_dir.mkdir(parents=True)
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
if file:
|
| 55 |
+
suffix = Path(file.filename).suffix or ".md"
|
| 56 |
+
input_path = tmp / f"blog{suffix}"
|
| 57 |
+
input_path.write_bytes(await file.read())
|
| 58 |
+
source_args = ["--input_path", str(input_path)]
|
| 59 |
+
else:
|
| 60 |
+
source_args = ["--url", url.strip()]
|
| 61 |
+
|
| 62 |
+
provider = os.getenv("PROVIDER", "gemini")
|
| 63 |
+
model = os.getenv("MODEL", "")
|
| 64 |
+
extra_env = {"MODEL": model} if model else {}
|
| 65 |
+
|
| 66 |
+
blog_json = data_dir / "blog_data.json"
|
| 67 |
+
|
| 68 |
+
def run_pipeline():
|
| 69 |
+
# Stage 0 – parse blog
|
| 70 |
+
_run("0_blog_process.py",
|
| 71 |
+
source_args + ["--output_json_path", str(blog_json)],
|
| 72 |
+
extra_env)
|
| 73 |
+
|
| 74 |
+
if not blog_json.exists():
|
| 75 |
+
candidates = list(data_dir.glob("*.json"))
|
| 76 |
+
if not candidates:
|
| 77 |
+
raise RuntimeError("Stage 0: no JSON output found.")
|
| 78 |
+
blog_json_path = candidates[0]
|
| 79 |
+
else:
|
| 80 |
+
blog_json_path = blog_json
|
| 81 |
+
|
| 82 |
+
# Stage 1 – planning
|
| 83 |
+
_run("1_planning.py", [
|
| 84 |
+
"--blog_json_path", str(blog_json_path),
|
| 85 |
+
"--output_dir", str(data_dir),
|
| 86 |
+
"--provider", provider,
|
| 87 |
+
"--content_type", "blog",
|
| 88 |
+
], extra_env)
|
| 89 |
+
|
| 90 |
+
# Stage 1.1 – extract config
|
| 91 |
+
_run("1_1_extract_config.py", [
|
| 92 |
+
"--output_dir", str(data_dir),
|
| 93 |
+
], extra_env)
|
| 94 |
+
|
| 95 |
+
config_yaml = data_dir / "planning_config.yaml"
|
| 96 |
+
if not config_yaml.exists():
|
| 97 |
+
raise RuntimeError("Stage 1.1: planning_config.yaml not found.")
|
| 98 |
+
|
| 99 |
+
# Stage 2 – analysis
|
| 100 |
+
_run("2_analyzing.py", [
|
| 101 |
+
"--pdf_json_path", str(blog_json_path),
|
| 102 |
+
"--output_dir", str(data_dir),
|
| 103 |
+
"--provider", provider,
|
| 104 |
+
], extra_env)
|
| 105 |
+
|
| 106 |
+
# Stage 3 – code generation
|
| 107 |
+
_run("3_coding.py", [
|
| 108 |
+
"--pdf_json_path", str(blog_json_path),
|
| 109 |
+
"--output_dir", str(data_dir),
|
| 110 |
+
"--output_repo_dir", str(output_dir),
|
| 111 |
+
"--provider", provider,
|
| 112 |
+
], extra_env)
|
| 113 |
+
|
| 114 |
+
await asyncio.get_event_loop().run_in_executor(None, run_pipeline)
|
| 115 |
+
|
| 116 |
+
zip_path = tmp / "repo.zip"
|
| 117 |
+
files = [f for f in output_dir.rglob("*") if f.is_file()]
|
| 118 |
+
if not files:
|
| 119 |
+
raise HTTPException(500, "Pipeline produced no output files.")
|
| 120 |
+
|
| 121 |
+
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
| 122 |
+
for f in files:
|
| 123 |
+
zf.write(f, f.relative_to(output_dir))
|
| 124 |
+
|
| 125 |
+
return FileResponse(
|
| 126 |
+
path=str(zip_path),
|
| 127 |
+
media_type="application/zip",
|
| 128 |
+
filename="generated-repo.zip",
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
except HTTPException:
|
| 132 |
+
shutil.rmtree(tmp, ignore_errors=True)
|
| 133 |
+
raise
|
| 134 |
+
except Exception as exc:
|
| 135 |
+
shutil.rmtree(tmp, ignore_errors=True)
|
| 136 |
+
raise HTTPException(500, str(exc)) from exc
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn
|
| 3 |
+
python-multipart
|
| 4 |
+
openai>=1.65.4
|
| 5 |
+
tiktoken>=0.9.0
|
| 6 |
+
google-generativeai>=0.8.0
|
| 7 |
+
beautifulsoup4>=4.12.0
|
| 8 |
+
requests>=2.31.0
|
| 9 |
+
markdown>=3.5.0
|
| 10 |
+
html2text>=2020.1.16
|
| 11 |
+
lxml>=5.0.0
|
| 12 |
+
tqdm>=4.60.0
|
| 13 |
+
pyyaml>=6.0
|