#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CodePilot v3 — 多模型 AI 開發助手 + 知識蒸餾
==============================================
支援多種模型後端:
🏠 Local: Qwen2.5-Coder-3B(你的本地模型)
☁️ Cloud: OpenAI (GPT-4o/5), Anthropic (Claude Opus), Google (Gemini)
🔗 Proxy: OpenRouter(一個 API 接所有模型)
知識蒸餾模式:
用 Opus/GPT-5 的回答,自動訓練你的本地模型 → 免費版的 Opus!
Usage:
# 用本地模型
codepilot
# 用 Claude Opus(並自動收集訓練數據)
codepilot --provider anthropic --api-key sk-xxx
# 用 OpenRouter(最方便,一個 key 用所有模型)
codepilot --provider openrouter --api-key sk-xxx --cloud-model anthropic/claude-opus-4
# 蒸餾模式:用雲端模型產生數據,訓練本地模型
codepilot --distill --provider openrouter --api-key sk-xxx
# 用收集的雲端數據訓練本地模型
codepilot --train
"""
import argparse, difflib, json, os, re, shutil, sqlite3, subprocess, sys, torch, httpx
from datetime import datetime
from pathlib import Path
# ============================================================
# CONFIG
# ============================================================
# Hugging Face checkpoint used when the "local" provider is selected.
DEFAULT_LOCAL_MODEL = "Qwen/Qwen2.5-Coder-3B-Instruct"
# Per-user state directory; the feedback database (and trained adapters) live here.
CONFIG_DIR = os.path.expanduser("~/.codepilot")
DB_PATH = os.path.join(CONFIG_DIR, "feedback.db")

# Provider registry keyed by the value accepted on the command line.
#   name          - human-readable label shown in the UI
#   type          - "local" (in-process transformers model), "openai"
#                   (OpenAI-compatible /chat/completions) or "anthropic"
#                   (Anthropic /messages); CloudModel dispatches on this
#   base_url      - API root that request paths are appended to
#   default_model - model id used when no explicit cloud model is given
PROVIDER_CONFIGS = {
    "local": dict(
        name="Local (Qwen2.5-Coder-3B)",
        type="local",
    ),
    "openai": dict(
        name="OpenAI",
        type="openai",
        base_url="https://api.openai.com/v1",
        default_model="gpt-4o",
    ),
    "anthropic": dict(
        name="Anthropic",
        type="anthropic",
        base_url="https://api.anthropic.com/v1",
        default_model="claude-sonnet-4-20250514",
    ),
    "openrouter": dict(
        name="OpenRouter (所有模型)",
        # OpenRouter exposes an OpenAI-compatible API.
        type="openai",
        base_url="https://openrouter.ai/api/v1",
        default_model="anthropic/claude-sonnet-4",
    ),
    "ollama": dict(
        name="Ollama (本地)",
        type="openai",
        base_url="http://localhost:11434/v1",
        default_model="qwen2.5-coder:3b",
    ),
}
# ============================================================
# FEEDBACK DB (升級版 — 記錄來源模型)
# ============================================================
class FeedbackDB:
    """SQLite store of rated model answers, exportable as SFT / DPO / KTO data.

    Each row holds one answer (``completion``) to a ``prompt``, a binary
    ``label`` (1 = accepted, 0 = rejected), an optional user-corrected answer
    (``edited_completion``), and which ``provider`` / ``source_model`` produced
    it. Wherever a user edit exists, exports prefer it over the raw answer.
    """

    def __init__(self, db_path=None):
        """Open (creating if needed) the feedback database.

        Args:
            db_path: SQLite path. Defaults to ``DB_PATH`` (the config
                directory is created first). ``":memory:"`` gives a
                throwaway database.
        """
        if db_path is None:
            os.makedirs(CONFIG_DIR, exist_ok=True)
            db_path = DB_PATH
        self.conn = sqlite3.connect(db_path)
        self.conn.execute("""CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY, timestamp TEXT, prompt TEXT, completion TEXT,
            label INTEGER, edited_completion TEXT, project TEXT,
            source_model TEXT, provider TEXT)""")
        self.conn.commit()

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def save(self, prompt, completion, label, edited=None, project=None,
             source_model=None, provider=None):
        """Insert one rated answer (``label`` is coerced to int) and commit."""
        self.conn.execute(
            "INSERT INTO feedback VALUES (NULL,?,?,?,?,?,?,?,?)",
            (datetime.now().isoformat(), prompt, completion, int(label),
             edited, project, source_model, provider))
        self.conn.commit()

    def count(self, provider=None):
        """Return ``{"total", "up", "edits"}`` counts, optionally for one provider."""
        sql = ("SELECT COUNT(*), COALESCE(SUM(label),0), "
               "SUM(CASE WHEN edited_completion IS NOT NULL THEN 1 ELSE 0 END) "
               "FROM feedback")
        params = ()
        if provider:
            sql += " WHERE provider=?"
            params = (provider,)
        total, up, edits = self.conn.execute(sql, params).fetchone()
        # SUM over zero rows is NULL, hence the `or 0`.
        return {"total": total, "up": int(up), "edits": int(edits or 0)}

    def export_sft(self, only_cloud=False):
        """Export accepted answers as chat-format SFT records.

        Args:
            only_cloud: keep only rows whose provider is set and is not
                ``"local"`` (distillation data from cloud models).

        The user's edited answer is used whenever one exists.
        """
        where = "WHERE label=1"
        if only_cloud:
            where += " AND provider != 'local' AND provider IS NOT NULL"
        rows = self.conn.execute(
            "SELECT prompt, COALESCE(edited_completion, completion) FROM feedback "
            f"{where} ORDER BY id"
        ).fetchall()
        return [{"messages": [
            {"role": "user", "content": p},
            {"role": "assistant", "content": c},
        ]} for p, c in rows]

    def export_dpo(self):
        """Pair accepted cloud answers with rejected local answers to the same prompt.

        ``DISTINCT`` collapses the duplicates the self-join produces when the
        same prompt/answer pair was recorded more than once.
        """
        rows = self.conn.execute("""
            SELECT DISTINCT c.prompt,
                   COALESCE(c.edited_completion, c.completion),
                   l.completion
            FROM feedback c
            JOIN feedback l ON c.prompt = l.prompt
            WHERE c.provider != 'local' AND c.label = 1
              AND l.provider = 'local' AND l.label = 0
        """).fetchall()
        return [{
            "prompt": [{"role": "user", "content": p}],
            "chosen": [{"role": "assistant", "content": cloud}],
            "rejected": [{"role": "assistant", "content": local}],
        } for p, cloud, local in rows]

    def export_kto(self):
        """Export every row as an unpaired KTO example (``label`` as bool).

        For edited rows the corrected text is the desirable completion.
        """
        rows = self.conn.execute(
            "SELECT prompt, COALESCE(edited_completion, completion), label "
            "FROM feedback ORDER BY id"
        ).fetchall()
        return [{
            "prompt": [{"role": "user", "content": p}],
            "completion": [{"role": "assistant", "content": c}],
            "label": bool(lab),
        } for p, c, lab in rows]
# ============================================================
# MODEL BACKENDS
# ============================================================
class LocalModel:
    """Chat wrapper around a Hugging Face causal LM loaded in this process.

    Offers the same ``name`` / ``provider`` / ``chat()`` surface as the cloud
    backend so the agent loop can use either interchangeably.
    """

    def __init__(self, model_name=DEFAULT_LOCAL_MODEL, adapter_path=None):
        """Load tokenizer and bf16 weights, attaching a LoRA adapter when present.

        Args:
            model_name: Hugging Face model id or local checkpoint directory.
            adapter_path: optional PEFT adapter directory; skipped silently
                when falsy or when the path does not exist.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.provider = "local"
        self.name = model_name.rsplit("/", 1)[-1]

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # No pad token in the checkpoint: reuse EOS so generate() can pad.
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        base = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        if adapter_path and os.path.exists(adapter_path):
            from peft import PeftModel
            base = PeftModel.from_pretrained(base, adapter_path)
        self.model = base
        self.model.eval()

    def chat(self, messages, max_tokens=4096):
        """Sample one assistant reply for an OpenAI-style ``messages`` list.

        Sampling settings are fixed (temperature 0.7, top-p 0.9, repetition
        penalty 1.1). Only the newly generated tokens are decoded.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = encoded["input_ids"].shape[1]
        sampling = dict(
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        with torch.no_grad():
            generated = self.model.generate(**encoded, **sampling)
        new_tokens = generated[0][prompt_len:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
class CloudModel:
    """Remote chat model reached over HTTP.

    Speaks either the OpenAI-compatible ``/chat/completions`` protocol
    (OpenAI, OpenRouter, Ollama) or Anthropic's ``/messages`` protocol,
    chosen by the provider's ``type`` in ``PROVIDER_CONFIGS``.
    """

    def __init__(self, provider_key, api_key, model_name=None, timeout=120):
        """
        Args:
            provider_key: key into ``PROVIDER_CONFIGS``; the entry must have
                ``base_url``, ``default_model`` and ``type``.
            api_key: credential for the provider. When falsy, no auth header
                is sent on the OpenAI-compatible path.
            model_name: model id; falls back to the provider's ``default_model``.
            timeout: per-request HTTP timeout in seconds.
        """
        config = PROVIDER_CONFIGS[provider_key]
        self.provider = provider_key
        self.base_url = config["base_url"]
        self.name = model_name or config["default_model"]
        self.api_key = api_key
        self.api_type = config["type"]
        self.timeout = timeout

    def chat(self, messages, max_tokens=4096):
        """Return the assistant reply text for an OpenAI-style ``messages`` list.

        Raises the HTTP client's status error on non-2xx responses.
        """
        if self.api_type == "anthropic":
            return self._chat_anthropic(messages, max_tokens)
        return self._chat_openai(messages, max_tokens)

    def _chat_openai(self, messages, max_tokens):
        """Call an OpenAI-compatible ``/chat/completions`` endpoint."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        # OpenRouter wants extra attribution headers.
        if self.provider == "openrouter":
            headers["HTTP-Referer"] = "https://codepilot.local"
            headers["X-Title"] = "CodePilot"
        data = {
            "model": self.name,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        resp = httpx.post(
            f"{self.base_url}/chat/completions",
            headers=headers, json=data, timeout=self.timeout)
        resp.raise_for_status()
        # `content` may be null in some OpenAI-compatible responses; callers
        # parse and render the result as text, so normalise to a string.
        return resp.json()["choices"][0]["message"].get("content") or ""

    def _chat_anthropic(self, messages, max_tokens):
        """Call Anthropic's ``/messages`` endpoint."""
        system, chat_msgs = self._split_system(messages)
        headers = {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }
        data = {
            "model": self.name,
            "messages": chat_msgs,
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        if system:
            data["system"] = system
        resp = httpx.post(
            f"{self.base_url}/messages",
            headers=headers, json=data, timeout=self.timeout)
        resp.raise_for_status()
        return self._extract_anthropic_text(resp.json())

    @staticmethod
    def _split_system(messages):
        """Split system prompts out of the turn list.

        Anthropic takes the system prompt as a top-level field, not as a
        message. All system messages are kept, joined by blank lines,
        instead of only the last one. Returns ``(system_or_None, turns)``.
        """
        system_parts = [m["content"] for m in messages if m["role"] == "system"]
        turns = [m for m in messages if m["role"] != "system"]
        return ("\n\n".join(system_parts) or None), turns

    @staticmethod
    def _extract_anthropic_text(payload):
        """Concatenate every ``text`` block of a Messages API response body.

        Non-text blocks (for example ``thinking``) are skipped instead of
        assuming the first block is text. Returns ``""`` when there is none.
        """
        return "".join(
            block.get("text", "")
            for block in payload.get("content", [])
            if block.get("type") == "text"
        )
# ============================================================
# PROJECT TOOLS(和 v2 相同)
# ============================================================
class ProjectTools:
    """File-system, shell, search and git helpers exposed to the model as tools.

    Relative paths resolve against ``cwd`` (initially the project root).
    ``read_cache`` maps absolute paths to ``{"time": mtime, "content": text}``.
    ``edit_file`` uses it to refuse edits to files that were never read or
    that changed on disk since the last read/write. Errors are returned as
    strings so the model can see them.
    """

    # Substrings that make run_command refuse to run. A minimal guard, not a sandbox.
    DANGEROUS_COMMANDS = ("rm -rf /", "git push --force", "git reset --hard")
    # Directories list_files never descends into.
    IGNORED_DIRS = frozenset({".git", "node_modules", "__pycache__", ".venv", "dist", "build"})

    def __init__(self, project_dir):
        self.project_dir = os.path.abspath(project_dir)
        self.cwd = self.project_dir
        self.read_cache = {}

    def _resolve(self, path):
        """Return ``path`` unchanged if absolute, else normalized relative to ``cwd``."""
        return path if os.path.isabs(path) else os.path.normpath(os.path.join(self.cwd, path))

    def read_file(self, path, offset=1, limit=200):
        """Return up to ``limit`` numbered lines starting at 1-based line ``offset``.

        Also records the file in ``read_cache``, which unlocks ``edit_file``.
        A trailing ``... (N more)`` note reports lines left after the window.
        """
        full = self._resolve(path)
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        try:
            content = Path(full).read_text(encoding="utf-8", errors="replace")
            self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
        except Exception as e:  # e.g. a directory or a permission error
            return f"❌ {e}"
        lines = content.splitlines()
        # An offset of 0 or below would turn the slice start negative (read from the end).
        offset = max(1, offset)
        selected = lines[offset - 1:offset - 1 + limit]
        result = "\n".join(f"{i + offset:4d} │ {line}" for i, line in enumerate(selected))
        remaining = len(lines) - (offset - 1) - len(selected)
        if remaining > 0:
            result += f"\n... ({remaining} more)"
        return result

    def edit_file(self, path, old_string, new_string):
        """Replace the single occurrence of ``old_string`` with ``new_string``.

        Refuses (with an error string) when the file was not read first, is
        gone, changed on disk since, or when ``old_string`` is empty, missing
        or ambiguous. On success returns a unified diff of the change.
        """
        full = self._resolve(path)
        if full not in self.read_cache:
            return "❌ 必須先 read_file"
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        if os.path.getmtime(full) != self.read_cache[full]["time"]:
            return "❌ 文件已被外部修改"
        # An empty needle "matches" between every character; reject it outright.
        if not old_string:
            return "❌ old_string 不可為空"
        content = Path(full).read_text(encoding="utf-8")
        count = content.count(old_string)
        if count == 0:
            return "❌ 找不到要替換的文字"
        if count > 1:
            return f"❌ 找到 {count} 處,請提供更多上下文"
        new_content = content.replace(old_string, new_string, 1)
        diff = "".join(difflib.unified_diff(
            content.splitlines(keepends=True),
            new_content.splitlines(keepends=True),
            fromfile=f"a/{path}", tofile=f"b/{path}"))
        Path(full).write_text(new_content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": new_content}
        return "✅ 已修改:\n" + diff

    def write_file(self, path, content):
        """Create or overwrite ``path`` (creating parent dirs) and cache its state."""
        full = self._resolve(path)
        os.makedirs(os.path.dirname(full) or ".", exist_ok=True)
        is_new = not os.path.exists(full)
        Path(full).write_text(content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
        return f"✅ {'建立' if is_new else '覆寫'}: {path}"

    def run_command(self, command, timeout=120):
        """Run ``command`` through the shell in ``cwd`` and return its output.

        stdout and stderr are combined and truncated to 10,000 characters. A
        non-zero exit status is appended so failures are visible even when
        the command prints nothing. NOTE: ``command`` comes from the model and
        runs with ``shell=True``; the block-list above is only a coarse guard.
        """
        if any(bad in command for bad in self.DANGEROUS_COMMANDS):
            return f"⛔ 危險: {command}"
        try:
            r = subprocess.run(command, shell=True, cwd=self.cwd,
                               capture_output=True, text=True, timeout=timeout)
        except subprocess.TimeoutExpired:
            return "⏰ 超時"
        except Exception as e:
            return f"❌ {e}"
        out = (r.stdout + (f"\nSTDERR:\n{r.stderr}" if r.stderr else ""))[:10000]
        if r.returncode != 0:
            out += f"\n(exit code {r.returncode})"
        return out

    def search_files(self, pattern, glob_pattern=None):
        """Search file contents under ``cwd`` for the regex ``pattern``.

        Uses ripgrep when installed, otherwise ``grep -rn``. ``glob_pattern``
        limits which files are searched with either tool. At most 5,000
        characters are returned.
        """
        rg = shutil.which("rg")
        if rg:
            # Only -n here: bundling "-rn" would mean "--replace n" to
            # ripgrep, rewriting every match to "n".
            cmd = [rg, "-n", "--color=never", "--max-count=50"]
            if glob_pattern:
                cmd += ["--glob", glob_pattern]
        else:
            cmd = ["grep", "-rn"]
            if glob_pattern:
                cmd.append(f"--include={glob_pattern}")
        # -e keeps a pattern that starts with "-" from being read as an option.
        cmd += ["-e", pattern, self.cwd]
        try:
            out = subprocess.run(cmd, capture_output=True, text=True, timeout=30).stdout
        except Exception as e:
            return f"❌ {e}"
        return out[:5000] or "無匹配"

    def list_files(self, pattern="*", max_depth=3):
        """List up to 100 files (sorted, relative to ``cwd``) whose name matches ``pattern``.

        Only files in directories fewer than ``max_depth`` levels below
        ``cwd`` are listed. Deeper and ignored directories are never walked.
        """
        files = []
        for root, dirs, fnames in os.walk(self.cwd):
            rel = os.path.relpath(root, self.cwd)
            depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
            if depth >= max_depth:
                dirs[:] = []  # prune: nothing below here can be listed either
                continue
            dirs[:] = [d for d in dirs if d not in self.IGNORED_DIRS]
            files.extend(os.path.relpath(os.path.join(root, f), self.cwd)
                         for f in fnames if Path(f).match(pattern))
        return "\n".join(sorted(files)[:100])

    def git_context(self):
        """Summarize branch, short status and the last 5 commits of ``project_dir``.

        Returns exactly ``"(not a git repo)"`` when git is unavailable or the
        directory is not inside a work tree. git reports that case through
        its exit status rather than by raising.
        """
        def git(*git_args):
            return subprocess.run(["git", *git_args], cwd=self.project_dir,
                                  capture_output=True, text=True)

        try:
            if git("rev-parse", "--is-inside-work-tree").returncode != 0:
                return "(not a git repo)"
            branch = git("branch", "--show-current").stdout.strip()
            status = git("status", "--short").stdout.strip()
            recent = git("log", "--oneline", "-5").stdout.strip()
        except OSError:  # git not installed, or project_dir is unusable
            return "(not a git repo)"
        return f"Branch: {branch}\nStatus:\n{status}\nRecent:\n{recent}"
# ============================================================
# TOOL PARSER + EXECUTOR + SYSTEM PROMPT(和 v2 相同)
# ============================================================
# The model issues tool calls as
#   <tool_call>tool_name
#   {"arg": "value"}
#   </tool_call>
# The closing tag is what delimits the argument payload. Without a terminator,
# a trailing lazy ``(.*?)`` group always matches the empty string, and any
# word at the end of a line would be taken for a tool name.
TOOL_PATTERN = re.compile(r"<tool_call>\s*(\w+)(.*?)</tool_call>", re.DOTALL)


def parse_tool_calls(text):
    """Extract ``{"tool": name, "params": dict}`` entries from a model reply.

    Arguments are parsed as a JSON object first. If that fails (or the JSON
    is not an object), each ``key: value`` line becomes a string parameter,
    with surrounding double quotes stripped. ``None`` is treated as ``""``.
    """
    calls = []
    for m in TOOL_PATTERN.finditer(text or ""):
        body = m.group(2).strip()
        try:
            params = json.loads(body) if body else {}
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            params = None
        if not isinstance(params, dict):
            params = {}
            for line in body.splitlines():
                if ":" in line:
                    k, v = line.split(":", 1)
                    params[k.strip()] = v.strip().strip('"')
        calls.append({"tool": m.group(1), "params": params})
    return calls


def execute_tool(tools, call):
    """Dispatch one parsed tool call to ``tools`` and return its result.

    Numeric parameters may arrive as strings from the line-based fallback
    parser, hence the ``int()`` conversions. Any exception is turned into an
    error string so one bad call does not abort the agent loop.
    """
    n, p = call["tool"], call["params"]
    try:
        if n == "read_file":
            return tools.read_file(p.get("path", ""), int(p.get("offset", 1)), int(p.get("limit", 200)))
        elif n == "edit_file":
            return tools.edit_file(p.get("path", ""), p.get("old_string", ""), p.get("new_string", ""))
        elif n == "write_file":
            return tools.write_file(p.get("path", ""), p.get("content", ""))
        elif n == "run_command":
            return tools.run_command(p.get("command", ""), int(p.get("timeout", 120)))
        elif n == "search_files":
            return tools.search_files(p.get("pattern", ""), p.get("glob"))
        elif n == "list_files":
            return tools.list_files(p.get("pattern", "*"), int(p.get("max_depth", 3)))
        elif n == "git_status":
            return tools.git_context()
        else:
            return f"❌ 未知: {n}"
    except Exception as e:
        return f"❌ {e}"


def build_system_prompt(tools):
    """Build the system prompt: role, working directory, git summary and tool syntax.

    The tool-call syntax described here must match ``TOOL_PATTERN``.
    """
    return f"""You are CodePilot, an expert AI programming assistant working in the user's project.
Working directory: {tools.cwd}
{tools.git_context()}
## Tools
Call a tool by writing exactly:
<tool_call>tool_name
{{"arg": "value"}}
</tool_call>
You may make several tool calls in one reply; their results arrive in the next user message.
- read_file: {{"path":"...","offset":1,"limit":200}}
- edit_file: {{"path":"...","old_string":"...","new_string":"..."}} (must read first)
- write_file: {{"path":"...","content":"..."}}
- run_command: {{"command":"...","timeout":120}}
- search_files: {{"pattern":"...","glob":"*.py"}}
- list_files: {{"pattern":"*","max_depth":3}}
- git_status: {{}}
Rules: read before edit, old_string must be unique, prefer edit over write, verify changes."""
# ============================================================
# MAIN AGENT LOOP
# ============================================================
def _save_compare_result(db, choice, question, local_resp, cloud_resp, lm, cm, project_dir):
    """Persist a /compare verdict as two labelled rows, one per model.

    ``choice``: "1" local better, "2" cloud better, "b" both good, "x" both
    bad. Any other value saves nothing. Returns True when rows were written.
    """
    labels = {"1": (1, 0), "2": (0, 1), "b": (1, 1), "x": (0, 0)}.get(choice)
    if labels is None:
        return False
    local_label, cloud_label = labels
    db.save(question, local_resp, local_label, project=project_dir,
            source_model=getattr(lm, "name", "local"), provider="local")
    db.save(question, cloud_resp, cloud_label, project=project_dir,
            source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
    return True


def run_agent_loop(args):
    """Interactive REPL: chat with the chosen model, run its tool calls, collect feedback.

    Slash commands: /ls /git /clear /switch /compare /status /train /quit.
    In distill mode (``--distill`` while a non-local model is active) every
    answer is saved as accepted training data without asking.
    """
    from rich.console import Console
    from rich.markdown import Markdown
    from rich.panel import Panel
    from rich.prompt import Prompt
    from rich.syntax import Syntax
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    project_dir = args.project or os.getcwd()
    tools = ProjectTools(project_dir)
    provider_key = args.provider or "local"

    # --- Choose the backend ------------------------------------------------
    if provider_key == "local":
        model_label = args.model or DEFAULT_LOCAL_MODEL
        with console.status("[bold green]載入本地模型..."):
            model = LocalModel(model_label, args.adapter)
    else:
        api_key = args.api_key
        if not api_key:
            if provider_key != "ollama":
                console.print(f"[red]❌ 使用 {provider_key} 需要 --api-key[/]")
                sys.exit(1)
            # A local Ollama server normally needs no real key (assumption;
            # verify for secured deployments), so send a placeholder.
            api_key = "ollama"
        cloud_model = args.cloud_model or PROVIDER_CONFIGS[provider_key]["default_model"]
        model = CloudModel(provider_key, api_key, cloud_model)
        model_label = f"{PROVIDER_CONFIGS[provider_key]['name']}/{model.name}"

    # Distill mode at startup (banner only; re-checked per turn after /switch).
    distill_mode = args.distill and provider_key != "local"

    # --- Banner ------------------------------------------------------------
    banner = "[bold cyan]CodePilot v3[/]"
    if distill_mode:
        banner += " [bold yellow]⚗️ 蒸餾模式[/]"
    banner += f"\n[dim]Model: {model_label}\nProject: {project_dir}[/]"
    if distill_mode:
        banner += "\n[yellow]雲端回答將自動收集為本地模型的訓練數據[/]"
    console.print(Panel.fit(banner, border_style="cyan"))
    if provider_key == "local":
        console.print("[green]✅ 本地模型載入完成[/]")
    else:
        console.print(f"[green]✅ 已連接 {PROVIDER_CONFIGS[provider_key]['name']}[/]")
    git_ctx = tools.git_context()
    if git_ctx != "(not a git repo)":
        console.print(Panel(git_ctx, title="📂 Project", border_style="dim"))
    console.print("[dim]指令: /ls /git /clear /switch /compare /status /train /quit[/]\n")

    # Keep one handle per kind of model so /compare can use both.
    local_model_ref = model if provider_key == "local" else None
    cloud_model_ref = None if provider_key == "local" else model
    # With a cloud model plus an adapter or distill mode, also load the local
    # model for /compare. It is never loaded twice.
    if provider_key != "local" and (args.adapter or distill_mode):
        try:
            with console.status("[dim]同時載入本地模型 (for /compare)..."):
                local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
            console.print("[dim]✅ 本地模型也已載入,可用 /compare[/]")
        except Exception:
            console.print("[dim]⚠️ 本地模型載入失敗,/compare 不可用[/]")

    system_prompt = build_system_prompt(tools)
    messages = [{"role": "system", "content": system_prompt}]

    while True:
        try:
            user_input = Prompt.ask("\n[bold green]🧑 You")
        except (EOFError, KeyboardInterrupt):
            break
        cmd = user_input.strip()
        if not cmd:
            continue

        # --- Slash commands ------------------------------------------------
        if cmd in ("/quit", "/exit"):
            break
        elif cmd == "/status":
            s_all = db.count()
            t = Table(title="📊 數據統計")
            t.add_column("來源")
            t.add_column("數量")
            t.add_column("👍")
            t.add_row("全部", str(s_all["total"]), str(s_all["up"]))
            for pk in PROVIDER_CONFIGS:
                sc = db.count(pk)
                if sc["total"] > 0:
                    t.add_row(pk, str(sc["total"]), str(sc["up"]))
            console.print(t)
            sft = db.export_sft(only_cloud=True)
            dpo = db.export_dpo()
            if sft or dpo:
                console.print(f"\n [yellow]⚗️ 可蒸餾數據: SFT {len(sft)} / DPO {len(dpo)}[/]")
                console.print(" [dim]運行 codepilot --train 開始蒸餾訓練[/]")
            continue
        elif cmd == "/train":
            trigger_training(db, console, args)
            continue
        elif cmd == "/clear":
            messages = [{"role": "system", "content": system_prompt}]
            console.print("[dim]已清除[/]")
            continue
        elif cmd == "/git":
            console.print(Panel(tools.git_context(), title="Git", border_style="dim"))
            continue
        elif cmd.startswith("/ls"):
            console.print(tools.list_files(cmd[3:].strip() or "*"))
            continue
        elif cmd == "/compare" or cmd.startswith("/compare "):
            # Send one question to both models, show the answers side by side,
            # and turn the user's verdict into labelled (DPO-ready) rows.
            compare_question = cmd[8:].strip() if cmd.startswith("/compare ") else None
            if not compare_question:
                compare_question = Prompt.ask(" 輸入要比較的問題")
            if not compare_question.strip():
                continue
            if local_model_ref is None or cloud_model_ref is None:
                console.print("[yellow]⚠️ /compare 需要同時有本地和雲端模型[/]")
                console.print("[dim]啟動方式: codepilot --provider openrouter --api-key sk-xxx --adapter ./adapter[/]")
                continue
            lm, cm = local_model_ref, cloud_model_ref
            compare_msgs = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": compare_question},
            ]
            with console.status("[bold cyan]本地模型思考中..."):
                try:
                    local_resp, local_ok = lm.chat(compare_msgs) or "", True
                except Exception as e:
                    local_resp, local_ok = f"(錯誤: {e})", False
            with console.status("[bold magenta]雲端模型思考中..."):
                try:
                    cloud_resp, cloud_ok = cm.chat(compare_msgs) or "", True
                except Exception as e:
                    cloud_resp, cloud_ok = f"(錯誤: {e})", False
            console.print(f"\n[bold]🔄 Compare: {compare_question[:60]}{'...' if len(compare_question)>60 else ''}[/]\n")
            console.print(Panel(Markdown(local_resp), title=f"🏠 Local ({lm.name})", border_style="blue"))
            console.print(Panel(Markdown(cloud_resp), title=f"☁️ Cloud ({cm.name})", border_style="magenta"))
            # Error placeholders must never end up as training data.
            if not (local_ok and cloud_ok):
                console.print(" [yellow]⚠️ 有模型回答失敗,本次不記錄[/]")
                continue
            console.print(f"\n [green]1[/] = 本地較好 [magenta]2[/] = 雲端較好 [yellow]b[/] = 都好 [red]x[/] = 都差 Enter = 跳過")
            choice = Prompt.ask(" ", choices=["1", "2", "b", "x", ""], default="", show_choices=False)
            if _save_compare_result(db, choice, compare_question, local_resp, cloud_resp,
                                    lm, cm, project_dir):
                if choice == "2":
                    dpo_count = len(db.export_dpo())
                    console.print(f" [magenta]☁️ 雲端勝 → DPO +1[/] (累計 DPO 對: {dpo_count})")
                else:
                    console.print({
                        "1": " [green]🏠 本地勝!你的模型在進步![/]",
                        "b": " [yellow]👍 都好 → SFT +2[/]",
                        "x": " [red]👎 都差[/]",
                    }[choice])
            continue
        elif cmd == "/switch":
            console.print("可用: local, openai, anthropic, openrouter, ollama")
            new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys()))
            if new_p == "local":
                # Reuse an already-loaded local model rather than loading a second copy.
                if local_model_ref is None:
                    with console.status("載入本地模型..."):
                        local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
                model = local_model_ref
                provider_key = "local"
            else:
                key = args.api_key or Prompt.ask("API Key")
                model_name = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"])
                model = CloudModel(new_p, key, model_name)
                cloud_model_ref = model
                provider_key = new_p
            console.print(f"[green]✅ 切換到 {provider_key}[/]")
            continue

        # --- Agent turn: up to 10 model/tool rounds --------------------------
        messages.append({"role": "user", "content": user_input})
        full_response = ""
        answered = False
        for rnd in range(10):
            label = "思考中" if rnd == 0 else f"工具 round {rnd + 1}"
            with console.status(f"[bold cyan]{label}..."):
                try:
                    response = model.chat(messages) or ""
                except Exception as e:
                    console.print(f"[red]❌ API 錯誤: {e}[/]")
                    break
            answered = True
            tool_calls = parse_tool_calls(response)
            text_parts = TOOL_PATTERN.sub("", response).strip()
            if text_parts:
                console.print(f"\n[bold blue]🤖 CodePilot ({getattr(model, 'name', 'local')}):[/]")
                console.print(Markdown(text_parts))
            full_response += response + "\n"
            # Every reply, including the final one, goes into the history.
            messages.append({"role": "assistant", "content": response})
            if not tool_calls:
                break
            results = []
            for call in tool_calls:
                console.print(f" [dim]🔧 {call['tool']}[/]")
                result = execute_tool(tools, call)
                if call["tool"] == "edit_file" and "✅" in result:
                    diff = result.split("\n", 1)[1] if "\n" in result else ""
                    if diff:
                        console.print(Syntax(diff, "diff", theme="monokai"))
                elif call["tool"] == "run_command":
                    console.print(Panel(result[:500], title="Terminal", border_style="dim"))
                else:
                    console.print(f" [dim]{result[:200]}[/]")
                results.append(f"[{call['tool']}] {result}")
            messages.append({"role": "user", "content": "Tool results:\n" + "\n\n".join(results)})

        # Keep user/assistant roles alternating in the history for the next turn.
        if messages[-1]["role"] == "user":
            if not answered:
                messages.pop()  # the model never replied: drop the unanswered prompt
            else:
                messages.append({"role": "assistant",
                                 "content": "(tool loop ended without a final answer)"})

        # --- Feedback ------------------------------------------------------
        if not full_response.strip():
            continue  # nothing was generated, so there is nothing to rate or record
        source = getattr(model, "name", "local")
        if args.distill and provider_key != "local":
            # Distill mode: cloud answers are auto-accepted as training data.
            db.save(user_input, full_response, 1, project=project_dir,
                    source_model=source, provider=provider_key)
            console.print(" [yellow]⚗️ 自動記錄為訓練數據[/]")
            continue
        console.print("\n[dim][green]y[/]=👍 [red]n[/]=👎 [yellow]e[/]=✏️ Enter=跳過[/]")
        fb = Prompt.ask(" ", choices=["y", "n", "e", ""], default="", show_choices=False)
        if fb == "y":
            db.save(user_input, full_response, 1, project=project_dir,
                    source_model=source, provider=provider_key)
            console.print(" [green]👍[/]")
        elif fb == "n":
            db.save(user_input, full_response, 0, project=project_dir,
                    source_model=source, provider=provider_key)
            console.print(" [red]👎[/]")
        elif fb == "e":
            console.print(" [yellow]貼上修改版(END結束):[/]")
            lines = []
            while True:
                try:
                    line = input()
                except EOFError:
                    break
                if line.strip() == "END":
                    break
                lines.append(line)
            edited = "\n".join(lines)
            if edited.strip():
                db.save(user_input, full_response, 1, edited=edited, project=project_dir,
                        source_model=source, provider=provider_key)
                console.print(" [yellow]✏️[/]")
    console.print("\n[cyan]👋[/]")
# ============================================================
# TRAINING(支援蒸餾)
# ============================================================
def trigger_training(db, console, args):
    """Fine-tune a LoRA adapter (QLoRA on a 4-bit base) from collected SFT data.

    Trains on accepted cloud answers when any exist (distillation), otherwise
    on all accepted answers. DPO and KTO counts are only reported; those
    datasets are not trained on here. The adapter is saved to a timestamped
    directory under ``CONFIG_DIR``.

    Args:
        db: feedback store exposing ``count`` and the ``export_*`` methods.
        console: object with a rich-style ``print``.
        args: parsed CLI namespace; ``args.model`` selects the base model.
    """
    stats = db.count()
    if stats["total"] == 0:
        console.print("[yellow]⚠️ 無數據[/]")
        return

    cloud_sft = db.export_sft(only_cloud=True)  # answers from cloud models
    dpo_data = db.export_dpo()
    all_sft = db.export_sft(only_cloud=False)
    kto_data = db.export_kto()
    console.print(f"\n[bold]🚀 訓練數據統計[/]")
    console.print(f" ⚗️ 雲端蒸餾 SFT: {len(cloud_sft)} 條 (Opus/GPT 的回答)")
    console.print(f" 📊 雲端 vs 本地 DPO: {len(dpo_data)} 對")
    console.print(f" 📚 全部 SFT: {len(all_sft)} 條")
    console.print(f" 👍👎 KTO: {len(kto_data)} 條")

    # Distilled cloud SFT when available, otherwise every accepted answer.
    train_data = cloud_sft or all_sft
    if not train_data:
        # Only rejected answers so far: there is nothing to fine-tune on, and
        # reporting success would point at an adapter that was never written.
        console.print("[yellow]⚠️ 沒有可用的 SFT 數據(需要至少一條 👍 回答)[/]")
        return

    # Heavy imports only once training is actually going to run.
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, prepare_model_for_kbit_training
    from trl import SFTTrainer, SFTConfig

    model_name = args.model or DEFAULT_LOCAL_MODEL
    output_dir = os.path.join(CONFIG_DIR, f"adapter_{datetime.now().strftime('%Y%m%d_%H%M')}")
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
    peft_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
                          task_type="CAUSAL_LM",
                          target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                          "gate_proj", "up_proj", "down_proj"])

    label = "⚗️ 蒸餾 SFT" if cloud_sft else "📚 SFT"
    console.print(f"\n[bold]{label} ({len(train_data)} 條)...[/]")
    model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=bnb, device_map="auto", trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = prepare_model_for_kbit_training(model)
    trainer = SFTTrainer(
        model=model,
        args=SFTConfig(
            output_dir=output_dir, learning_rate=2e-4, num_train_epochs=3,
            per_device_train_batch_size=1, gradient_accumulation_steps=8,
            # NOTE(review): newer TRL releases name this `max_length`;
            # confirm against the installed TRL version.
            max_seq_length=1024, gradient_checkpointing=True, bf16=True,
            optim="paged_adamw_8bit", logging_steps=5, save_total_limit=1,
            logging_strategy="steps", logging_first_step=True),
        processing_class=tok,
        train_dataset=Dataset.from_list(train_data),
        peft_config=peft_cfg)
    trainer.train()
    trainer.save_model(output_dir)
    # The trainer holds a model reference too; drop both before freeing cached GPU memory.
    del trainer, model
    torch.cuda.empty_cache()
    console.print(f"\n[bold green]🎉 完成![/]")
    console.print(f" Adapter: {output_dir}")
    console.print(f" 使用: codepilot --adapter {output_dir}")
def show_stats():
    """Print feedback counts (overall and per provider) and available distillation data."""
    from rich.console import Console
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    table = Table(title="📊 CodePilot 數據統計")
    table.add_column("來源", style="cyan")
    table.add_column("總計")
    table.add_column("👍")
    table.add_column("✏️")
    # None = all rows. Then iterate the provider registry itself, so every
    # configured provider (ollama included) is reported; rows with no data are skipped.
    for pk in (None, *PROVIDER_CONFIGS):
        s = db.count(pk)
        if s["total"] > 0:
            table.add_row(pk or "全部", str(s["total"]), str(s["up"]), str(s["edits"]))
    console.print(table)
    cloud_sft = db.export_sft(only_cloud=True)
    dpo = db.export_dpo()
    if cloud_sft or dpo:
        console.print(f"\n [yellow]⚗️ 可蒸餾: SFT {len(cloud_sft)} / DPO {len(dpo)}[/]")
def main(argv=None):
    """Command-line entry point: ``--stats``, ``--train``, or the interactive agent.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    p = argparse.ArgumentParser(description="CodePilot v3 — 多模型 AI 開發助手")
    p.add_argument("--model", type=str, default=None, help="本地模型")
    p.add_argument("--adapter", type=str, default=None, help="LoRA adapter")
    p.add_argument("--project", type=str, default=None, help="專案目錄")
    # Choices come from the provider registry so the CLI cannot drift from it.
    p.add_argument("--provider", type=str, default=None,
                   choices=list(PROVIDER_CONFIGS),
                   help="模型提供者")
    p.add_argument("--api-key", type=str, default=None, help="API key")
    p.add_argument("--cloud-model", type=str, default=None, help="雲端模型名稱")
    p.add_argument("--distill", action="store_true", help="蒸餾模式(自動收集雲端回答)")
    p.add_argument("--stats", action="store_true", help="顯示數據統計")
    p.add_argument("--train", action="store_true", help="用收集的數據訓練本地模型")
    a = p.parse_args(argv)
    if a.stats:
        show_stats()
    elif a.train:
        from rich.console import Console
        trigger_training(FeedbackDB(), Console(), a)
    else:
        run_agent_loop(a)


if __name__ == "__main__":
    main()