#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CodePilot v3 — 多模型 AI 開發助手 + 知識蒸餾
==============================================
支援多種模型後端:
🏠 Local: Qwen2.5-Coder-3B(你的本地模型)
☁️ Cloud: OpenAI (GPT-4o/5), Anthropic (Claude Opus), Google (Gemini)
🔗 Proxy: OpenRouter(一個 API 接所有模型)
知識蒸餾模式:
用 Opus/GPT-5 的回答,自動訓練你的本地模型 → 免費版的 Opus!
Usage:
# 用本地模型
codepilot
# 用 Claude Opus(並自動收集訓練數據)
codepilot --provider anthropic --api-key sk-xxx
# 用 OpenRouter(最方便,一個 key 用所有模型)
codepilot --provider openrouter --api-key sk-xxx --cloud-model anthropic/claude-opus-4
# 蒸餾模式:用雲端模型產生數據,訓練本地模型
codepilot --distill --provider openrouter --api-key sk-xxx
# 用收集的雲端數據訓練本地模型
codepilot --train
"""
import argparse, difflib, json, os, re, shutil, sqlite3, subprocess, sys, torch, httpx
from datetime import datetime
from pathlib import Path
# ============================================================
# CONFIG
# ============================================================
# Hugging Face model id used for the local backend when --model is not given.
DEFAULT_LOCAL_MODEL = "Qwen/Qwen2.5-Coder-3B-Instruct"
# Per-user directory holding the feedback database and trained adapters.
CONFIG_DIR = os.path.expanduser("~/.codepilot")
# SQLite file in which rated prompt/completion pairs are stored.
DB_PATH = os.path.join(CONFIG_DIR, "feedback.db")
# Presets for every supported provider, keyed by the --provider value.
#   name          : label shown in the UI
#   type          : wire protocol CloudModel speaks ("anthropic"; anything else
#                   is treated as OpenAI-compatible) or "local" for in-process
#   base_url      : API root the endpoint path is appended to
#   default_model : model id used when --cloud-model is not given
PROVIDER_CONFIGS = {
    "local": {
        "name": "Local (Qwen2.5-Coder-3B)",
        "type": "local",
    },
    "openai": {
        "name": "OpenAI",
        "type": "openai",
        "base_url": "https://api.openai.com/v1",
        "default_model": "gpt-4o",
    },
    "anthropic": {
        "name": "Anthropic",
        "type": "anthropic",
        "base_url": "https://api.anthropic.com/v1",
        "default_model": "claude-sonnet-4-20250514",
    },
    "openrouter": {
        "name": "OpenRouter (所有模型)",
        "type": "openai",  # OpenRouter exposes an OpenAI-compatible API
        "base_url": "https://openrouter.ai/api/v1",
        "default_model": "anthropic/claude-sonnet-4",
    },
    "ollama": {
        "name": "Ollama (本地)",
        "type": "openai",
        "base_url": "http://localhost:11434/v1",
        "default_model": "qwen2.5-coder:3b",
    },
}
# ============================================================
# FEEDBACK DB (升級版 — 記錄來源模型)
# ============================================================
class FeedbackDB:
    """SQLite-backed store of rated prompt/completion pairs.

    Each row records one model answer together with a 0/1 label, an optional
    user-edited version of the answer, the project directory, and which model
    and provider produced it. The ``export_*`` methods reshape those rows
    into chat-message dictionaries for SFT, DPO and KTO training.
    """

    # Explicit column list so INSERTs do not depend on physical column order.
    _COLUMNS = ("timestamp, prompt, completion, label, edited_completion, "
                "project, source_model, provider")

    def __init__(self, db_path=None):
        """Open (and create if needed) the feedback database.

        Args:
            db_path: Path of the SQLite file, or ``":memory:"``. When omitted,
                the module-level ``DB_PATH`` inside ``CONFIG_DIR`` is used and
                that directory is created first.
        """
        if db_path is None:
            os.makedirs(CONFIG_DIR, exist_ok=True)
            db_path = DB_PATH
        self.conn = sqlite3.connect(db_path)
        self.conn.execute("""CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY, timestamp TEXT, prompt TEXT, completion TEXT,
            label INTEGER, edited_completion TEXT, project TEXT,
            source_model TEXT, provider TEXT)""")
        self.conn.commit()

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def save(self, prompt, completion, label, edited=None, project=None,
             source_model=None, provider=None):
        """Insert one rated answer and commit immediately.

        Args:
            prompt: The user request.
            completion: The model's raw answer.
            label: Truthy for a good answer, falsy for a bad one (stored 1/0).
            edited: Optional user-corrected version of the answer.
            project: Optional project directory the exchange belongs to.
            source_model: Optional name of the model that answered.
            provider: Optional provider key (e.g. ``"local"``, ``"openai"``).
        """
        self.conn.execute(
            f"INSERT INTO feedback ({self._COLUMNS}) VALUES (?,?,?,?,?,?,?,?)",
            (datetime.now().isoformat(), prompt, completion, int(label),
             edited, project, source_model, provider))
        self.conn.commit()

    def count(self, provider=None):
        """Return row statistics, optionally restricted to one provider.

        Returns:
            ``{"total": rows, "up": rows labelled 1, "edits": rows that carry
            an edited completion}``; all zeros when nothing matches.
        """
        query = (
            "SELECT COUNT(*), COALESCE(SUM(label),0), "
            "SUM(CASE WHEN edited_completion IS NOT NULL THEN 1 ELSE 0 END) "
            "FROM feedback")
        params = ()
        if provider:
            query += " WHERE provider=?"
            params = (provider,)
        total, up, edits = self.conn.execute(query, params).fetchone()
        # SUM() over zero rows is NULL, hence the `or 0`.
        return {"total": total, "up": int(up), "edits": int(edits or 0)}

    def export_sft(self, only_cloud=False):
        """Export positively-labelled rows as SFT chat examples.

        Args:
            only_cloud: When true, keep only rows whose provider is set and is
                not ``"local"``, and use the raw completion. Otherwise every
                positive row is exported and a user edit, if present, replaces
                the completion.

        Returns:
            A list of ``{"messages": [user, assistant]}`` dictionaries in
            insertion order.
        """
        if only_cloud:
            rows = self.conn.execute(
                "SELECT prompt, completion FROM feedback "
                "WHERE label=1 AND provider != 'local' AND provider IS NOT NULL "
                "ORDER BY id"
            ).fetchall()
        else:
            rows = self.conn.execute(
                "SELECT prompt, COALESCE(edited_completion, completion) FROM feedback "
                "WHERE label=1 ORDER BY id"
            ).fetchall()
        return [{"messages": [
            {"role": "user", "content": p},
            {"role": "assistant", "content": c},
        ]} for p, c in rows]

    def export_dpo(self):
        """Pair accepted non-local answers with rejected local ones.

        Rows are joined on an identical prompt: every non-local row labelled 1
        is paired with every local row labelled 0 for that prompt.

        Returns:
            A list of ``{"prompt", "chosen", "rejected"}`` dictionaries.
        """
        rows = self.conn.execute("""
            SELECT c.prompt, c.completion, l.completion
            FROM feedback c
            JOIN feedback l ON c.prompt = l.prompt
            WHERE c.provider != 'local' AND c.label = 1
              AND l.provider = 'local' AND l.label = 0
            ORDER BY c.id, l.id
        """).fetchall()
        return [{
            "prompt": [{"role": "user", "content": p}],
            "chosen": [{"role": "assistant", "content": cloud}],
            "rejected": [{"role": "assistant", "content": local}],
        } for p, cloud, local in rows]

    def export_kto(self):
        """Export every row as a KTO example with a boolean label.

        The raw completion is used even when an edited version exists.
        """
        rows = self.conn.execute(
            "SELECT prompt, completion, label FROM feedback ORDER BY id"
        ).fetchall()
        return [{
            "prompt": [{"role": "user", "content": p}],
            "completion": [{"role": "assistant", "content": c}],
            "label": bool(flag),
        } for p, c, flag in rows]
# ============================================================
# MODEL BACKENDS
# ============================================================
class LocalModel:
    """Chat backend that runs a Hugging Face causal LM in-process."""

    def __init__(self, model_name=DEFAULT_LOCAL_MODEL, adapter_path=None):
        """Load tokenizer and model, optionally wrapping them with a PEFT adapter.

        Args:
            model_name: Model id or path passed to ``from_pretrained``.
            adapter_path: Adapter directory; applied only if the path exists.
        """
        from transformers import AutoTokenizer, AutoModelForCausalLM

        # Display name is the last path component of the model id.
        self.name = model_name.split("/")[-1]
        self.provider = "local"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # generate() is given a pad id below, so fall back to EOS.
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        # NOTE(review): trust_remote_code=True allows code shipped with the
        # model repository to run at load time.
        network = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        if adapter_path and os.path.exists(adapter_path):
            from peft import PeftModel
            network = PeftModel.from_pretrained(network, adapter_path)
        self.model = network
        self.model.eval()

    def chat(self, messages, max_tokens=4096):
        """Generate a reply for a list of chat messages.

        Args:
            messages: Chat messages in the form the tokenizer's chat template
                accepts (role/content dictionaries).
            max_tokens: Upper bound on newly generated tokens.

        Returns:
            The decoded text of the newly generated tokens only.
        """
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        encoded = self.tokenizer(prompt_text, return_tensors="pt").to(self.model.device)
        sampling = dict(
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        with torch.no_grad():
            generated = self.model.generate(**encoded, **sampling)
        # Drop the prompt tokens so only the continuation is decoded.
        prompt_len = encoded["input_ids"].shape[1]
        return self.tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
class CloudModel:
    """Chat backend for remote HTTP APIs.

    Speaks either the Anthropic Messages protocol or the OpenAI-compatible
    chat-completions protocol, chosen by the provider's ``type`` entry in
    ``PROVIDER_CONFIGS``.
    """

    def __init__(self, provider_key, api_key, model_name=None):
        """Bind the backend to one provider.

        Args:
            provider_key: Key into ``PROVIDER_CONFIGS``.
            api_key: Credential sent with each request; may be empty for
                providers that need none.
            model_name: Model id; defaults to the provider's ``default_model``.
        """
        config = PROVIDER_CONFIGS[provider_key]
        self.provider = provider_key
        self.base_url = config["base_url"]
        self.name = model_name or config["default_model"]
        self.api_key = api_key
        self.api_type = config["type"]

    def chat(self, messages, max_tokens=4096):
        """Send ``messages`` to the provider and return the reply text.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        if self.api_type == "anthropic":
            return self._chat_anthropic(messages, max_tokens)
        return self._chat_openai(messages, max_tokens)

    def _openai_headers(self):
        """Build request headers for the OpenAI-compatible protocol."""
        headers = {"Content-Type": "application/json"}
        # Without a key, skip the header instead of sending "Bearer None".
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        # OpenRouter-specific attribution headers.
        if self.provider == "openrouter":
            headers["HTTP-Referer"] = "https://codepilot.local"
            headers["X-Title"] = "CodePilot"
        return headers

    def _chat_openai(self, messages, max_tokens):
        """Call an OpenAI-compatible ``/chat/completions`` endpoint."""
        data = {
            "model": self.name,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        resp = httpx.post(
            f"{self.base_url}/chat/completions",
            headers=self._openai_headers(), json=data, timeout=120)
        resp.raise_for_status()
        # `content` may be null; callers expect a string.
        return resp.json()["choices"][0]["message"].get("content") or ""

    @staticmethod
    def _split_system(messages):
        """Separate system messages from the rest.

        Returns:
            ``(system, chat_messages)`` where ``system`` is every system
            message joined by blank lines, or ``None`` when there is none.
        """
        system_parts = []
        chat_msgs = []
        for m in messages:
            if m["role"] == "system":
                system_parts.append(m["content"])
            else:
                chat_msgs.append(m)
        system = "\n\n".join(system_parts) if system_parts else None
        return system, chat_msgs

    @staticmethod
    def _anthropic_text(payload):
        """Concatenate the text of all text blocks in a Messages response."""
        blocks = payload.get("content") or []
        return "".join(
            block.get("text", "")
            for block in blocks
            if block.get("type", "text") == "text")

    def _chat_anthropic(self, messages, max_tokens):
        """Call the Anthropic ``/messages`` endpoint."""
        # The system prompt travels as a top-level field in this request.
        system, chat_msgs = self._split_system(messages)
        headers = {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }
        data = {
            "model": self.name,
            "messages": chat_msgs,
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        if system:
            data["system"] = system
        resp = httpx.post(
            f"{self.base_url}/messages",
            headers=headers, json=data, timeout=120)
        resp.raise_for_status()
        return self._anthropic_text(resp.json())
# ============================================================
# PROJECT TOOLS(和 v2 相同)
# ============================================================
class ProjectTools:
    """File, shell, search and git helpers scoped to one project directory.

    Every method returns a human-readable string (results or an error line)
    so the output can be fed straight back to the model.
    """

    # Substrings that make run_command refuse a command outright.
    _DANGEROUS = ("rm -rf /", "git push --force", "git reset --hard")
    # Directories list_files never descends into.
    _SKIP_DIRS = frozenset(
        {".git", "node_modules", "__pycache__", ".venv", "dist", "build"})

    def __init__(self, project_dir):
        self.project_dir = os.path.abspath(project_dir)
        self.cwd = self.project_dir
        # Absolute path -> {"time": mtime, "content": text} of the last
        # read/write; edit_file uses it to detect stale edits.
        self.read_cache = {}

    def _resolve(self, path):
        """Return ``path`` unchanged if absolute, else joined onto ``cwd``."""
        return path if os.path.isabs(path) else os.path.normpath(os.path.join(self.cwd, path))

    def read_file(self, path, offset=1, limit=200):
        """Return up to ``limit`` numbered lines starting at line ``offset``.

        Args:
            path: File to read, relative to ``cwd`` or absolute.
            offset: 1-based first line; values below 1 are treated as 1.
            limit: Maximum number of lines; negative values are treated as 0.

        Returns:
            The selected lines prefixed with their line numbers, followed by a
            ``... (N more)`` marker when lines remain, or an error string.
        """
        full = self._resolve(path)
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        try:
            content = Path(full).read_text(encoding="utf-8", errors="replace")
            lines = content.splitlines()
            self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
            offset = max(1, offset)
            limit = max(0, limit)
            start = offset - 1
            selected = lines[start:start + limit]
            result = "\n".join(
                f"{number:4d} │ {line}"
                for number, line in enumerate(selected, start=offset))
            # Count what is actually left after the window.
            remaining = len(lines) - (start + limit)
            if remaining > 0:
                result += f"\n... ({remaining} more)"
            return result
        except Exception as e:  # tool boundary: report, never raise
            return f"❌ {e}"

    def edit_file(self, path, old_string, new_string):
        """Replace the single occurrence of ``old_string`` in a file.

        The file must have been read or written through this object first and
        must not have changed on disk since. Returns a unified diff on
        success or an error string when the match is missing or ambiguous.
        """
        full = self._resolve(path)
        if full not in self.read_cache:
            return "❌ 必須先 read_file"
        content = Path(full).read_text(encoding="utf-8")
        if os.path.getmtime(full) != self.read_cache[full]["time"]:
            return "❌ 文件已被外部修改"
        count = content.count(old_string)
        if count == 0:
            return "❌ 找不到要替換的文字"
        if count > 1:
            return f"❌ 找到 {count} 處,請提供更多上下文"
        new_content = content.replace(old_string, new_string, 1)
        diff = "".join(difflib.unified_diff(
            content.splitlines(keepends=True), new_content.splitlines(keepends=True),
            fromfile=f"a/{path}", tofile=f"b/{path}"))
        Path(full).write_text(new_content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": new_content}
        return "✅ 已修改:\n" + diff

    def write_file(self, path, content):
        """Create or overwrite a file, creating parent directories as needed."""
        full = self._resolve(path)
        os.makedirs(os.path.dirname(full) or ".", exist_ok=True)
        is_new = not os.path.exists(full)
        Path(full).write_text(content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
        return f"✅ {'建立' if is_new else '覆寫'}: {path}"

    def run_command(self, command, timeout=120):
        """Run a shell command in ``cwd`` and return its output.

        Commands containing one of the ``_DANGEROUS`` substrings are refused.
        Output (stdout plus a labelled stderr section) is capped at 10000
        characters.
        """
        if any(pattern in command for pattern in self._DANGEROUS):
            return f"⛔ 危險: {command}"
        try:
            # shell=True is deliberate: the command is a shell line.
            r = subprocess.run(command, shell=True, cwd=self.cwd,
                               capture_output=True, text=True, timeout=timeout)
            return (r.stdout + (f"\nSTDERR:\n{r.stderr}" if r.stderr else ""))[:10000]
        except subprocess.TimeoutExpired:
            return "⏰ 超時"
        except Exception as e:  # tool boundary: report, never raise
            return f"❌ {e}"

    def search_files(self, pattern, glob_pattern=None):
        """Search ``cwd`` recursively for ``pattern`` with ripgrep or grep.

        Args:
            pattern: Regular expression to search for.
            glob_pattern: Optional file-name glob restricting the search.

        Returns:
            Matching lines with file names and line numbers (capped at 5000
            characters), ``"無匹配"`` when nothing matched, or an error string.
        """
        rg = shutil.which("rg")
        if rg:
            # Use the long flag: ripgrep recurses by default and reads
            # "-rn" as "--replace n", which would rewrite every match.
            cmd = [rg, "--line-number", "--color=never", "--max-count=50"]
            if glob_pattern:
                cmd += ["--glob", glob_pattern]
        else:
            cmd = ["grep", "-rn"]
            if glob_pattern:
                cmd.append(f"--include={glob_pattern}")
        # "-e" and "--" keep a pattern starting with "-" from being parsed
        # as an option.
        cmd += ["-e", pattern, "--", self.cwd]
        try:
            return subprocess.run(cmd, capture_output=True, text=True, timeout=30).stdout[:5000] or "無匹配"
        except Exception as e:  # tool boundary: report, never raise
            return f"❌ {e}"

    def list_files(self, pattern="*", max_depth=3):
        """List up to 100 files under ``cwd`` whose name matches ``pattern``.

        Directories at ``max_depth`` levels below ``cwd`` or deeper are
        skipped, as are the ``_SKIP_DIRS`` directories. Paths are relative to
        ``cwd``, sorted, one per line.
        """
        files = []
        for root, dirs, fnames in os.walk(self.cwd):
            rel_root = os.path.relpath(root, self.cwd)
            depth = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
            if depth >= max_depth:
                # Everything below is deeper still, so stop descending.
                dirs[:] = []
                continue
            dirs[:] = [d for d in dirs if d not in self._SKIP_DIRS]
            files.extend(
                os.path.relpath(os.path.join(root, f), self.cwd)
                for f in fnames if Path(f).match(pattern))
        return "\n".join(sorted(files)[:100])

    def git_context(self):
        """Summarise branch, short status and the last five commits.

        Returns:
            A multi-line summary, or ``"(not a git repo)"`` when git is
            unavailable or ``project_dir`` is not inside a work tree.
        """
        def git(*argv):
            return subprocess.run(["git", *argv], cwd=self.project_dir,
                                  capture_output=True, text=True)

        try:
            # git exits non-zero outside a repository without raising, so the
            # result has to be checked explicitly.
            probe = git("rev-parse", "--is-inside-work-tree")
            if probe.returncode != 0 or probe.stdout.strip() != "true":
                return "(not a git repo)"
            b = git("branch", "--show-current").stdout.strip()
            s = git("status", "--short").stdout.strip()
            log = git("log", "--oneline", "-5").stdout.strip()
            return f"Branch: {b}\nStatus:\n{s}\nRecent:\n{log}"
        except (OSError, subprocess.SubprocessError, UnicodeError):
            return "(not a git repo)"
# ============================================================
# TOOL PARSER + EXECUTOR + SYSTEM PROMPT(和 v2 相同)
# ============================================================
# Matches "<tool>NAME\nBODY</tool>"; group 1 is the tool name, group 2 the body.
TOOL_PATTERN = re.compile(r'<tool>\s*(\w+)\s*\n(.*?)</tool>', re.DOTALL)


def _parse_key_value_lines(body):
    """Parse loose ``key: value`` lines into a dict, stripping double quotes."""
    params = {}
    for line in body.split("\n"):
        if ":" in line:
            key, value = line.split(":", 1)
            params[key.strip()] = value.strip().strip('"')
    return params


def parse_tool_calls(text):
    """Extract every ``<tool>`` invocation from a model response.

    The body of each tag is parsed as a JSON object. If it is not valid JSON,
    or is valid JSON that is not an object, it is parsed as ``key: value``
    lines instead, so ``params`` is always a dict.

    Args:
        text: Raw model output.

    Returns:
        A list of ``{"tool": name, "params": dict}`` in order of appearance.
    """
    calls = []
    for m in TOOL_PATTERN.finditer(text):
        body = m.group(2).strip()
        try:
            params = json.loads(body)
        except ValueError:  # json.JSONDecodeError is a ValueError
            params = None
        if not isinstance(params, dict):
            params = _parse_key_value_lines(body)
        calls.append({"tool": m.group(1), "params": params})
    return calls
def execute_tool(tools, call):
    """Run one parsed tool call against ``tools`` and return its result.

    Args:
        tools: Object providing the tool methods (``read_file``, ``edit_file``,
            ``write_file``, ``run_command``, ``search_files``, ``list_files``,
            ``git_context``).
        call: ``{"tool": name, "params": dict}`` as produced by the parser.

    Returns:
        Whatever the tool returns; an error string for an unknown tool name or
        when argument conversion or the tool itself raises.
    """
    name = call["tool"]
    params = call["params"]

    # Handlers are zero-argument closures so that argument lookup and int()
    # conversion happen inside the try block below, only for the chosen tool.
    handlers = {
        "read_file": lambda: tools.read_file(
            params.get("path", ""),
            int(params.get("offset", 1)),
            int(params.get("limit", 200))),
        "edit_file": lambda: tools.edit_file(
            params.get("path", ""),
            params.get("old_string", ""),
            params.get("new_string", "")),
        "write_file": lambda: tools.write_file(
            params.get("path", ""),
            params.get("content", "")),
        "run_command": lambda: tools.run_command(
            params.get("command", ""),
            int(params.get("timeout", 120))),
        "search_files": lambda: tools.search_files(
            params.get("pattern", ""),
            params.get("glob")),
        "list_files": lambda: tools.list_files(
            params.get("pattern", "*"),
            int(params.get("max_depth", 3))),
        "git_status": lambda: tools.git_context(),
    }

    try:
        handler = handlers.get(name)
        if handler is None:
            return f"❌ 未知: {name}"
        return handler()
    except Exception as exc:
        return f"❌ {exc}"
def build_system_prompt(tools):
    """Compose the system prompt sent at the start of every conversation.

    Args:
        tools: Object exposing ``cwd`` and ``git_context()``; both are
            embedded in the prompt.

    Returns:
        The prompt text: a role line, the working directory, the git summary,
        the tool catalogue with example arguments, and the usage rules.
    """
    workdir = tools.cwd
    git_summary = tools.git_context()
    prompt_lines = [
        "You are CodePilot, an expert AI programming assistant working in the user's project.",
        f"Working directory: {workdir}",
        f"{git_summary}",
        # The tag syntax is shown across two lines: name, then the JSON body.
        "## Tools (use <tool>name\n{json}</tool>)",
        '- read_file: {"path":"...","offset":1,"limit":200}',
        '- edit_file: {"path":"...","old_string":"...","new_string":"..."} (must read first)',
        '- write_file: {"path":"...","content":"..."}',
        '- run_command: {"command":"...","timeout":120}',
        '- search_files: {"pattern":"...","glob":"*.py"}',
        '- list_files: {"pattern":"*","max_depth":3}',
        "- git_status: {}",
        "Rules: read before edit, old_string must be unique, prefer edit over write, verify changes.",
    ]
    return "\n".join(prompt_lines)
# ============================================================
# MAIN AGENT LOOP
# ============================================================
# /compare verdict -> (label for the local answer, label for the cloud answer)
_COMPARE_LABELS = {"1": (1, 0), "2": (0, 1), "b": (1, 1), "x": (0, 0)}


def _record_comparison(db, choice, question, local_resp, cloud_resp, lm, cm, project_dir):
    """Store one /compare verdict as two feedback rows (local, then cloud)."""
    local_label, cloud_label = _COMPARE_LABELS[choice]
    db.save(question, local_resp, local_label, project=project_dir,
            source_model=getattr(lm, "name", "local"), provider="local")
    db.save(question, cloud_resp, cloud_label, project=project_dir,
            source_model=getattr(cm, "name", "cloud"),
            provider=getattr(cm, "provider", "cloud"))


def _read_until_end_marker():
    """Read stdin lines until a line equal to ``END`` (or EOF) and join them."""
    lines = []
    while True:
        try:
            line = input()
        except EOFError:
            break
        if line.strip() == "END":
            break
        lines.append(line)
    return "\n".join(lines)


def run_agent_loop(args):
    """Run the interactive chat/tool loop until the user quits.

    Loads the model selected by ``args.provider`` (local by default), prints a
    banner, then repeatedly reads user input. Slash commands (``/status``,
    ``/train``, ``/clear``, ``/git``, ``/ls``, ``/compare``, ``/switch``,
    ``/quit``) are handled locally; anything else is sent to the model, whose
    ``<tool>`` calls are executed for up to 10 rounds. Each finished answer is
    either rated by the user or, in distill mode with a non-local provider,
    stored automatically as a positive example.

    Args:
        args: Parsed command-line namespace (``project``, ``provider``,
            ``model``, ``adapter``, ``api_key``, ``cloud_model``, ``distill``).
    """
    from rich.console import Console
    from rich.markdown import Markdown
    from rich.markup import escape
    from rich.panel import Panel
    from rich.prompt import Prompt
    from rich.syntax import Syntax
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    project_dir = args.project or os.getcwd()
    tools = ProjectTools(project_dir)
    provider_key = args.provider or "local"

    # Pick the active model.
    if provider_key == "local":
        model_label = args.model or DEFAULT_LOCAL_MODEL
        with console.status("[bold green]載入本地模型..."):
            model = LocalModel(model_label, args.adapter)
    else:
        # The ollama preset targets a localhost server, so no key is demanded.
        if not args.api_key and provider_key != "ollama":
            console.print(f"[red]❌ 使用 {provider_key} 需要 --api-key[/]")
            sys.exit(1)
        cloud_model = args.cloud_model or PROVIDER_CONFIGS[provider_key]["default_model"]
        model = CloudModel(provider_key, args.api_key, cloud_model)
        model_label = f"{PROVIDER_CONFIGS[provider_key]['name']}/{model.name}"

    # Distill mode only applies while a non-local provider is active.
    distill_mode = args.distill and provider_key != "local"

    banner = "[bold cyan]CodePilot v3[/]"
    if distill_mode:
        banner += " [bold yellow]⚗️ 蒸餾模式[/]"
    banner += f"\n[dim]Model: {escape(model_label)}\nProject: {escape(project_dir)}[/]"
    if distill_mode:
        banner += "\n[yellow]雲端回答將自動收集為本地模型的訓練數據[/]"
    console.print(Panel.fit(banner, border_style="cyan"))
    if provider_key == "local":
        console.print("[green]✅ 本地模型載入完成[/]")
    else:
        console.print(f"[green]✅ 已連接 {PROVIDER_CONFIGS[provider_key]['name']}[/]")
    git_ctx = tools.git_context()
    if git_ctx != "(not a git repo)":
        console.print(Panel(escape(git_ctx), title="📂 Project", border_style="dim"))
    console.print("[dim]指令: /ls /git /clear /switch /compare /status /train /quit[/]\n")

    # Keep one reference per side so /compare can query both.
    local_model_ref = None
    cloud_model_ref = None
    if provider_key == "local":
        local_model_ref = model
    else:
        cloud_model_ref = model
        # With an adapter or in distill mode, also try to load the local model.
        if args.adapter or distill_mode:
            try:
                with console.status("[dim]同時載入本地模型 (for /compare)..."):
                    local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
                console.print("[dim]✅ 本地模型也已載入,可用 /compare[/]")
            except Exception:
                console.print("[dim]⚠️ 本地模型載入失敗,/compare 不可用[/]")

    system_prompt = build_system_prompt(tools)
    messages = [{"role": "system", "content": system_prompt}]

    while True:
        try:
            user_input = Prompt.ask("\n[bold green]🧑 You")
        except (EOFError, KeyboardInterrupt):
            break
        if not user_input.strip():
            continue
        cmd = user_input.strip()

        if cmd in ("/quit", "/exit"):
            break

        if cmd == "/status":
            s_all = db.count()
            t = Table(title="📊 數據統計")
            t.add_column("來源")
            t.add_column("數量")
            t.add_column("👍")
            t.add_row("全部", str(s_all["total"]), str(s_all["up"]))
            # Iterate the presets so every selectable provider is listed.
            for pk in PROVIDER_CONFIGS:
                sc = db.count(pk)
                if sc["total"] > 0:
                    t.add_row(pk, str(sc["total"]), str(sc["up"]))
            console.print(t)
            sft = db.export_sft(only_cloud=True)
            dpo = db.export_dpo()
            if sft or dpo:
                console.print(f"\n [yellow]⚗️ 可蒸餾數據: SFT {len(sft)} / DPO {len(dpo)}[/]")
                console.print(" [dim]運行 codepilot --train 開始蒸餾訓練[/]")
            continue

        if cmd == "/train":
            trigger_training(db, console, args)
            continue

        if cmd == "/clear":
            messages = [{"role": "system", "content": system_prompt}]
            console.print("[dim]已清除[/]")
            continue

        if cmd == "/git":
            console.print(Panel(escape(tools.git_context()), title="Git", border_style="dim"))
            continue

        if cmd.startswith("/ls"):
            # File names are data, not markup.
            console.print(tools.list_files(cmd[3:].strip() or "*"), markup=False)
            continue

        if cmd == "/compare" or cmd.startswith("/compare "):
            # Send one question to both models, show both answers, and record
            # the user's verdict as feedback rows.
            compare_question = cmd[len("/compare"):].strip()
            if not compare_question:
                compare_question = Prompt.ask(" 輸入要比較的問題")
                if not compare_question.strip():
                    continue
            lm = local_model_ref
            cm = cloud_model_ref
            if lm is None or cm is None:
                console.print("[yellow]⚠️ /compare 需要同時有本地和雲端模型[/]")
                console.print("[dim]啟動方式: codepilot --provider openrouter --api-key sk-xxx --adapter ./adapter[/]")
                continue
            compare_msgs = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": compare_question},
            ]
            failed = False
            with console.status("[bold cyan]本地模型思考中..."):
                try:
                    local_resp = lm.chat(compare_msgs)
                except Exception as e:
                    local_resp = f"(錯誤: {e})"
                    failed = True
            with console.status("[bold magenta]雲端模型思考中..."):
                try:
                    cloud_resp = cm.chat(compare_msgs)
                except Exception as e:
                    cloud_resp = f"(錯誤: {e})"
                    failed = True
            ellipsis = "..." if len(compare_question) > 60 else ""
            console.print(f"\n[bold]🔄 Compare: {escape(compare_question[:60])}{ellipsis}[/]\n")
            console.print(Panel(Markdown(local_resp), title=f"🏠 Local ({lm.name})", border_style="blue"))
            console.print(Panel(Markdown(cloud_resp), title=f"☁️ Cloud ({cm.name})", border_style="magenta"))
            if failed:
                # An error message is not a model answer; keep it out of the data.
                console.print(" [yellow]⚠️ 有模型回應失敗,略過評分[/]")
                continue
            console.print("\n [green]1[/] = 本地較好 [magenta]2[/] = 雲端較好 [yellow]b[/] = 都好 [red]x[/] = 都差 Enter = 跳過")
            choice = Prompt.ask(" ", choices=["1", "2", "b", "x", ""], default="", show_choices=False)
            if choice in _COMPARE_LABELS:
                _record_comparison(db, choice, compare_question, local_resp,
                                   cloud_resp, lm, cm, project_dir)
            if choice == "2":
                dpo_count = len(db.export_dpo())
                console.print(f" [magenta]☁️ 雲端勝 → DPO +1[/] (累計 DPO 對: {dpo_count})")
            elif choice == "1":
                console.print(" [green]🏠 本地勝!你的模型在進步![/]")
            elif choice == "b":
                console.print(" [yellow]👍 都好 → SFT +2[/]")
            elif choice == "x":
                console.print(" [red]👎 都差[/]")
            continue

        if cmd == "/switch":
            console.print("可用: local, openai, anthropic, openrouter, ollama")
            new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys()))
            if new_p == "local":
                # Reuse an already loaded local model instead of loading a second copy.
                if local_model_ref is None:
                    with console.status("載入本地模型..."):
                        local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
                model = local_model_ref
                provider_key = "local"
            else:
                key = args.api_key
                if not key and new_p != "ollama":
                    key = Prompt.ask("API Key", password=True)
                chosen_model = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"])
                model = CloudModel(new_p, key, chosen_model)
                cloud_model_ref = model
                provider_key = new_p
            console.print(f"[green]✅ 切換到 {provider_key}[/]")
            continue

        # ---- regular chat turn ----
        messages.append({"role": "user", "content": user_input})
        full_response = ""
        errored = False
        for rnd in range(10):
            phase = "思考中" if rnd == 0 else f"工具 round {rnd+1}"
            with console.status(f"[bold cyan]{phase}..."):
                try:
                    response = model.chat(messages)
                except Exception as e:
                    console.print(f"[red]❌ API 錯誤: {escape(str(e))}[/]")
                    errored = True
                    break
            tool_calls = parse_tool_calls(response)
            text_parts = TOOL_PATTERN.sub("", response).strip()
            if text_parts:
                console.print(f"\n[bold blue]🤖 CodePilot ({getattr(model, 'name', 'local')}):[/]")
                console.print(Markdown(text_parts))
            full_response += response + "\n"
            if not tool_calls:
                # Final answer: keep it in the history so later turns see it.
                if response.strip():
                    messages.append({"role": "assistant", "content": response})
                break
            messages.append({"role": "assistant", "content": response})
            results = []
            for call in tool_calls:
                console.print(f" [dim]🔧 {call['tool']}[/]")
                result = execute_tool(tools, call)
                if call["tool"] == "edit_file" and "✅" in result:
                    diff_text = result.split("\n", 1)[1] if "\n" in result else ""
                    if diff_text:
                        console.print(Syntax(diff_text, "diff", theme="monokai"))
                elif call["tool"] == "run_command":
                    console.print(Panel(escape(result[:500]), title="Terminal", border_style="dim"))
                else:
                    # Tool output may contain "[...]"; escape it so it is not
                    # parsed as console markup.
                    console.print(f" [dim]{escape(result[:200])}[/]")
                results.append(f"[{call['tool']}] {result}")
            messages.append({"role": "user", "content": "Tool results:\n" + "\n\n".join(results)})

        if not full_response.strip():
            # Nothing came back: drop the unanswered user turn and do not
            # ask for (or auto-record) feedback on an empty answer.
            if not errored:
                console.print("[yellow]⚠️ 模型沒有回應[/]")
            if messages[-1]["role"] == "user":
                messages.pop()
            continue

        if messages[-1]["role"] == "user":
            # The turn ended on tool results (round limit or a mid-run error);
            # close it with what the model said so far.
            messages.append({"role": "assistant", "content": full_response})

        # ---- feedback ----
        source_name = getattr(model, "name", "local")
        # Re-evaluated every turn because /switch can change the provider.
        if args.distill and provider_key != "local":
            db.save(user_input, full_response, 1, project=project_dir,
                    source_model=source_name, provider=provider_key)
            console.print(" [yellow]⚗️ 自動記錄為訓練數據[/]")
            continue

        console.print("\n[dim][green]y[/]=👍 [red]n[/]=👎 [yellow]e[/]=✏️ Enter=跳過[/]")
        fb = Prompt.ask(" ", choices=["y", "n", "e", ""], default="", show_choices=False)
        if fb == "y":
            db.save(user_input, full_response, 1, project=project_dir,
                    source_model=source_name, provider=provider_key)
            console.print(" [green]👍[/]")
        elif fb == "n":
            db.save(user_input, full_response, 0, project=project_dir,
                    source_model=source_name, provider=provider_key)
            console.print(" [red]👎[/]")
        elif fb == "e":
            console.print(" [yellow]貼上修改版(END結束):[/]")
            edited = _read_until_end_marker()
            if edited.strip():
                db.save(user_input, full_response, 1, edited=edited, project=project_dir,
                        source_model=source_name, provider=provider_key)
                console.print(" [yellow]✏️[/]")

    console.print("\n[cyan]👋[/]")
# ============================================================
# TRAINING(支援蒸餾)
# ============================================================
def trigger_training(db, console, args):
    """Fine-tune the local model with LoRA on the collected feedback.

    Prints how many SFT, DPO and KTO examples are available, then runs
    supervised fine-tuning on the cloud-only SFT set if it is non-empty,
    otherwise on all positively-labelled rows. The DPO and KTO sets are only
    reported; no DPO/KTO training is run here. The adapter is written to a
    timestamped directory under ``CONFIG_DIR``.

    Args:
        db: Feedback store providing ``count`` and the ``export_*`` methods.
        console: Object with a ``print`` method used for all output.
        args: Namespace whose ``model`` attribute may override the base model.

    Returns:
        None. Returns early, without training, when the store is empty or
        contains no positively-labelled example.
    """
    stats = db.count()
    if stats["total"] == 0:
        console.print("[yellow]⚠️ 無數據[/]")
        return

    cloud_sft = db.export_sft(only_cloud=True)
    dpo_data = db.export_dpo()
    all_sft = db.export_sft(only_cloud=False)
    kto_data = db.export_kto()
    console.print("\n[bold]🚀 訓練數據統計[/]")
    console.print(f" ⚗️ 雲端蒸餾 SFT: {len(cloud_sft)} 條 (Opus/GPT 的回答)")
    console.print(f" 📊 雲端 vs 本地 DPO: {len(dpo_data)} 對")
    console.print(f" 📚 全部 SFT: {len(all_sft)} 條")
    console.print(f" 👍👎 KTO: {len(kto_data)} 條")

    # Prefer distilled cloud answers; fall back to every accepted answer.
    train_data = cloud_sft or all_sft
    if not train_data:
        # Only rejected rows exist: stop before the heavy imports and before
        # announcing an adapter that was never written.
        console.print("[yellow]⚠️ 沒有 👍 的樣本,無法進行 SFT 訓練[/]")
        return

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, prepare_model_for_kbit_training
    from trl import SFTTrainer, SFTConfig

    model_name = args.model or DEFAULT_LOCAL_MODEL
    output_dir = os.path.join(CONFIG_DIR, f"adapter_{datetime.now().strftime('%Y%m%d_%H%M')}")
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
    peft_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
                          task_type="CAUSAL_LM",
                          target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                          "gate_proj", "up_proj", "down_proj"])

    label = "⚗️ 蒸餾 SFT" if cloud_sft else "📚 SFT"
    console.print(f"\n[bold]{label} ({len(train_data)} 條)...[/]")
    model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=bnb, device_map="auto", trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = prepare_model_for_kbit_training(model)
    # NOTE(review): argument names (max_seq_length, processing_class) depend
    # on the installed trl version; confirm against the pinned release.
    trainer = SFTTrainer(
        model=model,
        args=SFTConfig(
            output_dir=output_dir, learning_rate=2e-4, num_train_epochs=3,
            per_device_train_batch_size=1, gradient_accumulation_steps=8,
            max_seq_length=1024, gradient_checkpointing=True, bf16=True,
            optim="paged_adamw_8bit", logging_steps=5, save_total_limit=1,
            logging_strategy="steps", logging_first_step=True),
        processing_class=tok,
        train_dataset=Dataset.from_list(train_data),
        peft_config=peft_cfg)
    trainer.train()
    trainer.save_model(output_dir)
    # Drop both references; the trainer also holds the model.
    del trainer, model
    torch.cuda.empty_cache()

    console.print("\n[bold green]🎉 完成![/]")
    console.print(f" Adapter: {output_dir}")
    console.print(f" 使用: codepilot --adapter {output_dir}")
def show_stats():
    """Print a table of feedback counts, overall and per provider.

    One row is shown for the whole database and one for every provider in
    ``PROVIDER_CONFIGS`` that has at least one row. When cloud SFT examples or
    DPO pairs exist, their counts are printed below the table.
    """
    from rich.console import Console
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    table = Table(title="📊 CodePilot 數據統計")
    table.add_column("來源", style="cyan")
    table.add_column("總計")
    table.add_column("👍")
    table.add_column("✏️")
    # None selects all rows; the presets cover every selectable provider.
    for pk in [None, *PROVIDER_CONFIGS]:
        s = db.count(pk)
        if s["total"] > 0:
            table.add_row(pk or "全部", str(s["total"]), str(s["up"]), str(s["edits"]))
    console.print(table)
    cloud_sft = db.export_sft(only_cloud=True)
    dpo = db.export_dpo()
    if cloud_sft or dpo:
        console.print(f"\n [yellow]⚗️ 可蒸餾: SFT {len(cloud_sft)} / DPO {len(dpo)}[/]")
def main():
    """Parse the command line and dispatch to stats, training or the chat loop.

    ``--stats`` takes precedence over ``--train``; with neither flag the
    interactive agent loop is started.
    """
    parser = argparse.ArgumentParser(description="CodePilot v3 — 多模型 AI 開發助手")

    # Plain optional string options, registered in their help-output order.
    for flag, help_text in (
        ("--model", "本地模型"),
        ("--adapter", "LoRA adapter"),
        ("--project", "專案目錄"),
    ):
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument(
        "--provider",
        type=str,
        default=None,
        choices=["local", "openai", "anthropic", "openrouter", "ollama"],
        help="模型提供者",
    )

    for flag, help_text in (
        ("--api-key", "API key"),
        ("--cloud-model", "雲端模型名稱"),
    ):
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument("--distill", action="store_true", help="蒸餾模式(自動收集雲端回答)")
    for switch in ("--stats", "--train"):
        parser.add_argument(switch, action="store_true")

    options = parser.parse_args()

    if options.stats:
        show_stats()
        return
    if options.train:
        from rich.console import Console
        trigger_training(FeedbackDB(), Console(), options)
        return
    run_agent_loop(options)


if __name__ == "__main__":
    main()