Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| CodePilot v3 — 多模型 AI 開發助手 + 知識蒸餾 | |
| ============================================== | |
| 支援多種模型後端: | |
| 🏠 Local: Qwen2.5-Coder-3B(你的本地模型) | |
| ☁️ Cloud: OpenAI (GPT-4o/5), Anthropic (Claude Opus), Google (Gemini) | |
| 🔗 Proxy: OpenRouter(一個 API 接所有模型) | |
| 知識蒸餾模式: | |
| 用 Opus/GPT-5 的回答,自動訓練你的本地模型 → 免費版的 Opus! | |
| Usage: | |
| # 用本地模型 | |
| codepilot | |
| # 用 Claude Opus(並自動收集訓練數據) | |
| codepilot --provider anthropic --api-key sk-xxx | |
| # 用 OpenRouter(最方便,一個 key 用所有模型) | |
| codepilot --provider openrouter --api-key sk-xxx --cloud-model anthropic/claude-opus-4 | |
| # 蒸餾模式:用雲端模型產生數據,訓練本地模型 | |
| codepilot --distill --provider openrouter --api-key sk-xxx | |
| # 用收集的雲端數據訓練本地模型 | |
| codepilot --train | |
| """ | |
| import argparse, difflib, json, os, re, shutil, sqlite3, subprocess, sys, torch, httpx | |
| from datetime import datetime | |
| from pathlib import Path | |
# ============================================================
# CONFIG
# ============================================================
# Default local model checkpoint (Hugging Face repo id).
DEFAULT_LOCAL_MODEL = "Qwen/Qwen2.5-Coder-3B-Instruct"
# Per-user data directory, and the SQLite feedback database stored inside it.
CONFIG_DIR = os.path.expanduser("~/.codepilot")
DB_PATH = os.path.join(CONFIG_DIR, "feedback.db")
# Default settings for each model provider backend.
# "type" selects the wire protocol ("openai"-compatible vs "anthropic");
# "default_model" is used when no --cloud-model is given.
PROVIDER_CONFIGS = {
    "local": {
        "name": "Local (Qwen2.5-Coder-3B)",
        "type": "local",
    },
    "openai": {
        "name": "OpenAI",
        "type": "openai",
        "base_url": "https://api.openai.com/v1",
        "default_model": "gpt-4o",
    },
    "anthropic": {
        "name": "Anthropic",
        "type": "anthropic",
        "base_url": "https://api.anthropic.com/v1",
        "default_model": "claude-sonnet-4-20250514",
    },
    "openrouter": {
        "name": "OpenRouter (所有模型)",
        "type": "openai",  # OpenRouter speaks the OpenAI-compatible API
        "base_url": "https://openrouter.ai/api/v1",
        "default_model": "anthropic/claude-sonnet-4",
    },
    "ollama": {
        "name": "Ollama (本地)",
        "type": "openai",
        "base_url": "http://localhost:11434/v1",
        "default_model": "qwen2.5-coder:3b",
    },
}
| # ============================================================ | |
| # FEEDBACK DB (升級版 — 記錄來源模型) | |
| # ============================================================ | |
class FeedbackDB:
    """SQLite-backed store for feedback on model completions.

    Each row holds a prompt/completion pair, an accept/reject label, an
    optional user-edited completion, and the model/provider that produced
    it. Exporters convert the rows into SFT, DPO, and KTO training formats.
    """

    def __init__(self):
        # Create the per-user data directory and schema on first use.
        os.makedirs(CONFIG_DIR, exist_ok=True)
        self.conn = sqlite3.connect(DB_PATH)
        self.conn.execute("""CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY, timestamp TEXT, prompt TEXT, completion TEXT,
            label INTEGER, edited_completion TEXT, project TEXT,
            source_model TEXT, provider TEXT)""")
        self.conn.commit()

    def save(self, prompt, completion, label, edited=None, project=None,
             source_model=None, provider=None):
        """Insert one feedback row (label truthy = accepted)."""
        row = (datetime.now().isoformat(), prompt, completion, int(label),
               edited, project, source_model, provider)
        self.conn.execute("INSERT INTO feedback VALUES (NULL,?,?,?,?,?,?,?,?)", row)
        self.conn.commit()

    def count(self, provider=None):
        """Return {"total", "up", "edits"} counts, optionally for one provider."""
        query = ("SELECT COUNT(*), COALESCE(SUM(label),0), "
                 "SUM(CASE WHEN edited_completion IS NOT NULL THEN 1 ELSE 0 END) "
                 "FROM feedback")
        if provider:
            row = self.conn.execute(query + " WHERE provider=?", (provider,)).fetchone()
        else:
            row = self.conn.execute(query).fetchone()
        return {"total": row[0], "up": int(row[1]), "edits": int(row[2] or 0)}

    def export_sft(self, only_cloud=False):
        """Export accepted rows as chat-format SFT examples.

        With only_cloud=True, keep only rows produced by cloud providers
        (treated as high-quality distillation data) using the raw completion;
        otherwise export all accepted rows, preferring a user-edited
        completion when one exists.
        """
        if only_cloud:
            query = ("SELECT prompt, completion FROM feedback "
                     "WHERE label=1 AND provider != 'local' AND provider IS NOT NULL")
        else:
            query = ("SELECT prompt, COALESCE(edited_completion, completion) FROM feedback "
                     "WHERE label=1")
        examples = []
        for prompt, completion in self.conn.execute(query).fetchall():
            examples.append({"messages": [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": completion},
            ]})
        return examples

    def export_dpo(self):
        """Pair accepted cloud answers with rejected local answers to the same prompt."""
        pairs = self.conn.execute("""
            SELECT c.prompt, c.completion, l.completion
            FROM feedback c
            JOIN feedback l ON c.prompt = l.prompt
            WHERE c.provider != 'local' AND c.label = 1
            AND l.provider = 'local' AND l.label = 0
        """).fetchall()
        return [
            {
                "prompt": [{"role": "user", "content": prompt}],
                "chosen": [{"role": "assistant", "content": good}],
                "rejected": [{"role": "assistant", "content": bad}],
            }
            for prompt, good, bad in pairs
        ]

    def export_kto(self):
        """Export every row as a KTO example with a boolean desirability label."""
        rows = self.conn.execute(
            "SELECT prompt, completion, label FROM feedback").fetchall()
        return [
            {
                "prompt": [{"role": "user", "content": prompt}],
                "completion": [{"role": "assistant", "content": completion}],
                "label": bool(flag),
            }
            for prompt, completion, flag in rows
        ]
| # ============================================================ | |
| # MODEL BACKENDS | |
| # ============================================================ | |
class LocalModel:
    """Local HF-transformers chat backend, optionally stacked with a LoRA adapter."""

    def __init__(self, model_name=DEFAULT_LOCAL_MODEL, adapter_path=None):
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.name = model_name.split("/")[-1]
        self.provider = "local"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # Some checkpoints ship without a pad token; reuse EOS for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16,
            device_map="auto", trust_remote_code=True)
        if adapter_path and os.path.exists(adapter_path):
            # Load fine-tuned LoRA weights on top of the base model.
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model, adapter_path)
        self.model.eval()

    def chat(self, messages, max_tokens=4096):
        """Run one sampled chat turn and return only the newly generated text."""
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        encoded = self.tokenizer(prompt_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated = self.model.generate(
                **encoded, max_new_tokens=max_tokens, do_sample=True,
                temperature=0.7, top_p=0.9, repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id)
        # Slice off the prompt tokens so only the completion is decoded.
        prompt_len = encoded["input_ids"].shape[1]
        return self.tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
class CloudModel:
    """Cloud chat backend speaking either the OpenAI-compatible or Anthropic REST API."""

    def __init__(self, provider_key, api_key, model_name=None):
        cfg = PROVIDER_CONFIGS[provider_key]
        self.provider = provider_key
        self.base_url = cfg["base_url"]
        self.name = model_name or cfg["default_model"]
        self.api_key = api_key
        self.api_type = cfg["type"]

    def chat(self, messages, max_tokens=4096):
        """Send one chat request and return the assistant's reply text."""
        handler = self._chat_anthropic if self.api_type == "anthropic" else self._chat_openai
        return handler(messages, max_tokens)

    def _chat_openai(self, messages, max_tokens):
        """POST to an OpenAI-compatible /chat/completions endpoint (OpenAI / OpenRouter / Ollama)."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        if self.provider == "openrouter":
            # OpenRouter asks clients to identify themselves with extra headers.
            headers["HTTP-Referer"] = "https://codepilot.local"
            headers["X-Title"] = "CodePilot"
        payload = {
            "model": self.name,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        reply = httpx.post(
            f"{self.base_url}/chat/completions",
            headers=headers, json=payload, timeout=120)
        reply.raise_for_status()
        return reply.json()["choices"][0]["message"]["content"]

    def _chat_anthropic(self, messages, max_tokens):
        """POST to the Anthropic /messages endpoint (system prompt is a top-level field)."""
        # Anthropic takes the system prompt outside the messages list.
        system = None
        chat_messages = []
        for msg in messages:
            if msg["role"] == "system":
                system = msg["content"]
            else:
                chat_messages.append(msg)
        headers = {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }
        payload = {
            "model": self.name,
            "messages": chat_messages,
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        if system:
            payload["system"] = system
        reply = httpx.post(
            f"{self.base_url}/messages",
            headers=headers, json=payload, timeout=120)
        reply.raise_for_status()
        return reply.json()["content"][0]["text"]
| # ============================================================ | |
| # PROJECT TOOLS(和 v2 相同) | |
| # ============================================================ | |
class ProjectTools:
    """Sandboxed file/shell helpers rooted at a project directory.

    read_file/write_file populate ``read_cache`` (path -> mtime + content) so
    edit_file can enforce read-before-edit and detect external modification.
    """

    def __init__(self, project_dir):
        self.project_dir = os.path.abspath(project_dir)
        self.cwd = self.project_dir
        # path -> {"time": mtime, "content": text}; filled by read/write/edit.
        self.read_cache = {}

    def _resolve(self, path):
        """Resolve a possibly-relative path against the working directory."""
        return path if os.path.isabs(path) else os.path.normpath(os.path.join(self.cwd, path))

    def read_file(self, path, offset=1, limit=200):
        """Return numbered lines offset..offset+limit-1 of a file (1-based).

        Also caches the file so edit_file may be called on it afterwards.
        Returns a "❌ ..." string on any error.
        """
        full = self._resolve(path)
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        try:
            content = Path(full).read_text(encoding="utf-8", errors="replace")
            lines = content.splitlines()
            self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
            selected = lines[offset-1:offset-1+limit]
            result = "\n".join(f"{i+offset:4d} │ {line}" for i, line in enumerate(selected))
            # BUGFIX: the last displayed line is offset+limit-1, so lines remain
            # whenever len(lines) > offset-1+limit. The previous test
            # (offset + limit < len(lines)) silently dropped the "... (1 more)"
            # marker when exactly one line was left over.
            if offset - 1 + limit < len(lines):
                result += f"\n... ({len(lines)-offset-limit+1} more)"
            return result
        except Exception as e:
            return f"❌ {e}"

    def edit_file(self, path, old_string, new_string):
        """Replace one unique occurrence of old_string; requires a prior read.

        Refuses when the file was never read/cached, was modified externally
        since, or when old_string is missing or ambiguous. Returns a unified
        diff on success.
        """
        full = self._resolve(path)
        if full not in self.read_cache:
            return "❌ 必須先 read_file"
        content = Path(full).read_text(encoding="utf-8")
        if os.path.getmtime(full) != self.read_cache[full]["time"]:
            return "❌ 文件已被外部修改"
        count = content.count(old_string)
        if count == 0:
            return "❌ 找不到要替換的文字"
        if count > 1:
            return f"❌ 找到 {count} 處,請提供更多上下文"
        new_content = content.replace(old_string, new_string, 1)
        diff = "".join(difflib.unified_diff(
            content.splitlines(keepends=True), new_content.splitlines(keepends=True),
            fromfile=f"a/{path}", tofile=f"b/{path}"))
        Path(full).write_text(new_content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": new_content}
        return "✅ 已修改:\n" + diff

    def write_file(self, path, content):
        """Create or overwrite a file (making parent dirs) and cache its content."""
        full = self._resolve(path)
        os.makedirs(os.path.dirname(full) or ".", exist_ok=True)
        is_new = not os.path.exists(full)
        Path(full).write_text(content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
        return f"✅ {'建立' if is_new else '覆寫'}: {path}"

    def run_command(self, command, timeout=120):
        """Run a shell command in the project cwd; output capped at 10k chars.

        A small substring denylist blocks a few obviously destructive commands.
        """
        for d in {"rm -rf /", "git push --force", "git reset --hard"}:
            if d in command:
                return f"⛔ 危險: {command}"
        try:
            r = subprocess.run(command, shell=True, cwd=self.cwd,
                               capture_output=True, text=True, timeout=timeout)
            return (r.stdout + (f"\nSTDERR:\n{r.stderr}" if r.stderr else ""))[:10000]
        except subprocess.TimeoutExpired:
            return f"⏰ 超時"
        except Exception as e:
            return f"❌ {e}"

    def search_files(self, pattern, glob_pattern=None):
        """Search the project with ripgrep when available, else grep -rn."""
        rg = shutil.which("rg")
        cmd = [rg or "grep", "-rn"]
        if rg:
            cmd += ["--color=never", "--max-count=50"]
        if glob_pattern and rg:
            cmd += ["--glob", glob_pattern]
        cmd += [pattern, self.cwd]
        try:
            out = subprocess.run(cmd, capture_output=True, text=True, timeout=30).stdout
            return out[:5000] or "無匹配"
        except Exception as e:
            return f"❌ {e}"

    def list_files(self, pattern="*", max_depth=3):
        """List up to 100 project-relative files matching pattern, depth-limited."""
        skip_dirs = {".git", "node_modules", "__pycache__", ".venv", "dist", "build"}
        files = []
        for root, dirs, fnames in os.walk(self.cwd):
            dirs[:] = [d for d in dirs if d not in skip_dirs]
            depth = root.replace(self.cwd, "").count(os.sep)
            if depth >= max_depth:
                # Also stop descending: subtrees past the limit can never
                # contribute files, so walking them was pure wasted I/O.
                dirs[:] = []
                continue
            files.extend(os.path.relpath(os.path.join(root, f), self.cwd)
                         for f in fnames if Path(f).match(pattern))
        return "\n".join(sorted(files)[:100])

    def git_context(self):
        """Return branch + short status + last 5 commits, or a placeholder."""
        try:
            b = subprocess.run(["git", "branch", "--show-current"], cwd=self.project_dir,
                               capture_output=True, text=True).stdout.strip()
            s = subprocess.run(["git", "status", "--short"], cwd=self.project_dir,
                               capture_output=True, text=True).stdout.strip()
            l = subprocess.run(["git", "log", "--oneline", "-5"], cwd=self.project_dir,
                               capture_output=True, text=True).stdout.strip()
            return f"Branch: {b}\nStatus:\n{s}\nRecent:\n{l}"
        except Exception:
            # was a bare `except:`, which would also swallow KeyboardInterrupt
            return "(not a git repo)"
| # ============================================================ | |
| # TOOL PARSER + EXECUTOR + SYSTEM PROMPT(和 v2 相同) | |
| # ============================================================ | |
# Matches <tool>name\n...payload...</tool> blocks emitted by the model.
TOOL_PATTERN = re.compile(r'<tool>\s*(\w+)\s*\n(.*?)</tool>', re.DOTALL)


def parse_tool_calls(text):
    """Extract tool invocations of the form <tool>name\\n{json}</tool>.

    The payload is parsed as JSON; when that fails, a lenient line-based
    "key: value" parse is used instead. Returns a list of
    {"tool": name, "params": dict} in order of appearance.
    """
    calls = []
    for match in TOOL_PATTERN.finditer(text):
        raw = match.group(2).strip()
        try:
            params = json.loads(raw)
        # BUGFIX: was a bare `except:`, which would also catch
        # KeyboardInterrupt/SystemExit; JSONDecodeError subclasses ValueError.
        except ValueError:
            params = {}
            for line in raw.split("\n"):
                if ":" in line:
                    key, value = line.split(":", 1)
                    params[key.strip()] = value.strip().strip('"')
        calls.append({"tool": match.group(1), "params": params})
    return calls
def execute_tool(tools, call):
    """Run one parsed tool call against a ProjectTools instance.

    Returns the tool's string result; unknown tool names and any exception
    raised while coercing arguments or executing are reported as "❌ ..."
    strings instead of propagating.
    """
    name, params = call["tool"], call["params"]
    # Table of zero-argument thunks: argument coercion (e.g. int()) only
    # happens when the chosen thunk runs, i.e. inside the try below.
    dispatch = {
        "read_file": lambda: tools.read_file(
            params.get("path", ""), int(params.get("offset", 1)), int(params.get("limit", 200))),
        "edit_file": lambda: tools.edit_file(
            params.get("path", ""), params.get("old_string", ""), params.get("new_string", "")),
        "write_file": lambda: tools.write_file(params.get("path", ""), params.get("content", "")),
        "run_command": lambda: tools.run_command(params.get("command", ""), int(params.get("timeout", 120))),
        "search_files": lambda: tools.search_files(params.get("pattern", ""), params.get("glob")),
        "list_files": lambda: tools.list_files(params.get("pattern", "*"), int(params.get("max_depth", 3))),
        "git_status": lambda: tools.git_context(),
    }
    try:
        action = dispatch.get(name)
        if action is None:
            return f"❌ 未知: {name}"
        return action()
    except Exception as e:
        return f"❌ {e}"
def build_system_prompt(tools):
    """Build the agent system prompt: role, working directory, git context,
    and the <tool> calling protocol with the available tools' JSON schemas.

    NOTE: the doubled braces ({{...}}) render as literal braces in this
    f-string — they are JSON examples shown to the model, not placeholders.
    """
    return f"""You are CodePilot, an expert AI programming assistant working in the user's project.
Working directory: {tools.cwd}
{tools.git_context()}
## Tools (use <tool>name\n{{json}}</tool>)
- read_file: {{"path":"...","offset":1,"limit":200}}
- edit_file: {{"path":"...","old_string":"...","new_string":"..."}} (must read first)
- write_file: {{"path":"...","content":"..."}}
- run_command: {{"command":"...","timeout":120}}
- search_files: {{"pattern":"...","glob":"*.py"}}
- list_files: {{"pattern":"*","max_depth":3}}
- git_status: {{}}
Rules: read before edit, old_string must be unique, prefer edit over write, verify changes."""
| # ============================================================ | |
| # MAIN AGENT LOOP | |
| # ============================================================ | |
def run_agent_loop(args):
    """Interactive chat/agent REPL.

    Loads the selected backend (local or cloud), then loops: read user input,
    handle slash-commands (/status /train /clear /git /ls /compare /switch
    /quit), otherwise run up to 10 model/tool rounds and collect thumbs-up/
    down feedback into FeedbackDB. With --distill and a cloud provider, every
    cloud answer is auto-saved as accepted training data.
    """
    from rich.console import Console
    from rich.markdown import Markdown
    from rich.panel import Panel
    from rich.prompt import Prompt
    from rich.syntax import Syntax
    from rich.table import Table
    console = Console()
    db = FeedbackDB()
    project_dir = args.project or os.getcwd()
    tools = ProjectTools(project_dir)
    provider_key = args.provider or "local"
    # Pick the model backend.
    if provider_key == "local":
        model_label = args.model or DEFAULT_LOCAL_MODEL
        with console.status("[bold green]載入本地模型..."):
            model = LocalModel(model_label, args.adapter)
    else:
        if not args.api_key:
            console.print(f"[red]❌ 使用 {provider_key} 需要 --api-key[/]")
            sys.exit(1)
        cloud_model = args.cloud_model or PROVIDER_CONFIGS[provider_key]["default_model"]
        model = CloudModel(provider_key, args.api_key, cloud_model)
        model_label = f"{PROVIDER_CONFIGS[provider_key]['name']}/{model.name}"
    # Distillation only applies when answers come from a cloud provider.
    distill_mode = args.distill and provider_key != "local"
    # Banner
    banner = f"[bold cyan]CodePilot v3[/]"
    if distill_mode:
        banner += " [bold yellow]⚗️ 蒸餾模式[/]"
    banner += f"\n[dim]Model: {model_label}\nProject: {project_dir}[/]"
    if distill_mode:
        banner += f"\n[yellow]雲端回答將自動收集為本地模型的訓練數據[/]"
    console.print(Panel.fit(banner, border_style="cyan"))
    if provider_key == "local":
        console.print("[green]✅ 本地模型載入完成[/]")
    else:
        console.print(f"[green]✅ 已連接 {PROVIDER_CONFIGS[provider_key]['name']}[/]")
    git_ctx = tools.git_context()
    if git_ctx != "(not a git repo)":
        console.print(Panel(git_ctx, title="📂 Project", border_style="dim"))
    console.print("[dim]指令: /ls /git /clear /switch /compare /status /train /quit[/]\n")
    # Keep references to both backends so /compare can query them side by side.
    local_model_ref = None
    cloud_model_ref = None
    if provider_key == "local":
        local_model_ref = model
    else:
        cloud_model_ref = model
        # In distill/compare setups, also try to load the local model.
        if args.adapter or distill_mode:
            try:
                with console.status("[dim]同時載入本地模型 (for /compare)..."):
                    local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
                console.print("[dim]✅ 本地模型也已載入,可用 /compare[/]")
            except Exception:
                console.print("[dim]⚠️ 本地模型載入失敗,/compare 不可用[/]")
    system_prompt = build_system_prompt(tools)
    messages = [{"role": "system", "content": system_prompt}]
    while True:
        try: user_input = Prompt.ask("\n[bold green]🧑 You")
        except (EOFError, KeyboardInterrupt): break
        if not user_input.strip(): continue
        cmd = user_input.strip()
        if cmd in ("/quit", "/exit"): break
        elif cmd == "/status":
            # NOTE(review): s_cloud is assigned but never read below — and it
            # queries provider "local" despite its name. Dead code? Verify.
            s_all = db.count(); s_cloud = db.count("local")
            t = Table(title="📊 數據統計"); t.add_column("來源"); t.add_column("數量"); t.add_column("👍")
            t.add_row("全部", str(s_all["total"]), str(s_all["up"]))
            for pk in ["local", "openai", "anthropic", "openrouter"]:
                sc = db.count(pk)
                if sc["total"] > 0: t.add_row(pk, str(sc["total"]), str(sc["up"]))
            console.print(t)
            # Report how much distillation-ready data has accumulated.
            sft = db.export_sft(only_cloud=True)
            dpo = db.export_dpo()
            if sft or dpo:
                console.print(f"\n [yellow]⚗️ 可蒸餾數據: SFT {len(sft)} / DPO {len(dpo)}[/]")
                console.print(f" [dim]運行 codepilot --train 開始蒸餾訓練[/]")
            continue
        elif cmd == "/train":
            trigger_training(db, console, args); continue
        elif cmd == "/clear":
            messages = [{"role": "system", "content": system_prompt}]
            console.print("[dim]已清除[/]"); continue
        elif cmd == "/git":
            console.print(Panel(tools.git_context(), title="Git", border_style="dim")); continue
        elif cmd.startswith("/ls"):
            console.print(tools.list_files(cmd[3:].strip() or "*")); continue
        elif cmd == "/compare" or cmd.startswith("/compare "):
            # /compare: send the same question to local + cloud, show both
            # answers side by side, and turn the user's pick into DPO/SFT rows.
            compare_question = cmd[8:].strip() if cmd.startswith("/compare ") else None
            if not compare_question:
                compare_question = Prompt.ask(" 輸入要比較的問題")
                if not compare_question.strip():
                    continue
            need_local = local_model_ref or (provider_key == "local" and model)
            need_cloud = cloud_model_ref or (provider_key != "local" and model)
            if not need_local or not need_cloud:
                console.print("[yellow]⚠️ /compare 需要同時有本地和雲端模型[/]")
                console.print("[dim]啟動方式: codepilot --provider openrouter --api-key sk-xxx --adapter ./adapter[/]")
                continue
            lm = local_model_ref if local_model_ref else model
            cm = cloud_model_ref if cloud_model_ref else model
            compare_msgs = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": compare_question},
            ]
            # Generate both answers; errors become part of the displayed text.
            with console.status("[bold cyan]本地模型思考中..."):
                try: local_resp = lm.chat(compare_msgs)
                except Exception as e: local_resp = f"(錯誤: {e})"
            with console.status("[bold magenta]雲端模型思考中..."):
                try: cloud_resp = cm.chat(compare_msgs)
                except Exception as e: cloud_resp = f"(錯誤: {e})"
            # Side-by-side display.
            console.print(f"\n[bold]🔄 Compare: {compare_question[:60]}{'...' if len(compare_question)>60 else ''}[/]\n")
            console.print(Panel(Markdown(local_resp), title=f"🏠 Local ({lm.name})", border_style="blue"))
            console.print(Panel(Markdown(cloud_resp), title=f"☁️ Cloud ({cm.name})", border_style="magenta"))
            # Let the user judge the pair.
            console.print(f"\n [green]1[/] = 本地較好 [magenta]2[/] = 雲端較好 [yellow]b[/] = 都好 [red]x[/] = 都差 Enter = 跳過")
            choice = Prompt.ask(" ", choices=["1","2","b","x",""], default="", show_choices=False)
            if choice == "2":
                # Cloud wins → DPO pair: chosen=cloud, rejected=local.
                db.save(compare_question, cloud_resp, 1, project=project_dir,
                        source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
                db.save(compare_question, local_resp, 0, project=project_dir,
                        source_model=getattr(lm, "name", "local"), provider="local")
                dpo_count = len(db.export_dpo())
                console.print(f" [magenta]☁️ 雲端勝 → DPO +1[/] (累計 DPO 對: {dpo_count})")
            elif choice == "1":
                # Local wins → positive local row, negative cloud row.
                db.save(compare_question, local_resp, 1, project=project_dir,
                        source_model=getattr(lm, "name", "local"), provider="local")
                db.save(compare_question, cloud_resp, 0, project=project_dir,
                        source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
                console.print(f" [green]🏠 本地勝!你的模型在進步![/]")
            elif choice == "b":
                # Both good → two accepted SFT rows.
                db.save(compare_question, local_resp, 1, project=project_dir,
                        source_model=getattr(lm, "name", "local"), provider="local")
                db.save(compare_question, cloud_resp, 1, project=project_dir,
                        source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
                console.print(f" [yellow]👍 都好 → SFT +2[/]")
            elif choice == "x":
                # Both bad → two rejected rows.
                db.save(compare_question, local_resp, 0, project=project_dir,
                        source_model=getattr(lm, "name", "local"), provider="local")
                db.save(compare_question, cloud_resp, 0, project=project_dir,
                        source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
                console.print(f" [red]👎 都差[/]")
            continue
        elif cmd == "/switch":
            console.print("可用: local, openai, anthropic, openrouter, ollama")
            new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys()))
            if new_p == "local":
                with console.status("載入本地模型..."):
                    model = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
                local_model_ref = model; provider_key = "local"
            else:
                key = args.api_key or Prompt.ask("API Key")
                cm = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"])
                model = CloudModel(new_p, key, cm); cloud_model_ref = model; provider_key = new_p
            console.print(f"[green]✅ 切換到 {provider_key}[/]"); continue
        # Normal chat turn: up to 10 model/tool rounds.
        messages.append({"role": "user", "content": user_input})
        full_response = ""
        for rnd in range(10):
            with console.status(f"[bold cyan]{'思考中' if rnd == 0 else f'工具 round {rnd+1}'}..."):
                try:
                    response = model.chat(messages)
                except Exception as e:
                    console.print(f"[red]❌ API 錯誤: {e}[/]"); break
            tool_calls = parse_tool_calls(response)
            # Strip tool blocks so only the model's prose is rendered.
            text_parts = TOOL_PATTERN.sub("", response).strip()
            if text_parts:
                console.print(f"\n[bold blue]🤖 CodePilot ({model.name if hasattr(model,'name') else 'local'}):[/]")
                console.print(Markdown(text_parts))
            full_response += response + "\n"
            if not tool_calls: break
            messages.append({"role": "assistant", "content": response})
            results = []
            for call in tool_calls:
                console.print(f" [dim]🔧 {call['tool']}[/]")
                result = execute_tool(tools, call)
                if call["tool"] == "edit_file" and "✅" in result:
                    d = result.split("\n", 1)[1] if "\n" in result else ""
                    if d: console.print(Syntax(d, "diff", theme="monokai"))
                elif call["tool"] == "run_command":
                    console.print(Panel(result[:500], title="Terminal", border_style="dim"))
                else:
                    console.print(f" [dim]{result[:200]}[/]")
                results.append(f"[{call['tool']}] {result}")
            # Feed the tool output back to the model for the next round.
            messages.append({"role": "user", "content": "Tool results:\n" + "\n\n".join(results)})
        # Feedback collection.
        if distill_mode:
            # Distill mode: the cloud answer is auto-accepted as training data.
            db.save(user_input, full_response, 1, project=project_dir,
                    source_model=getattr(model, "name", "local"), provider=provider_key)
            console.print(f" [yellow]⚗️ 自動記錄為訓練數據[/]")
        else:
            console.print(f"\n[dim][green]y[/]=👍 [red]n[/]=👎 [yellow]e[/]=✏️ Enter=跳過[/]")
            fb = Prompt.ask(" ", choices=["y","n","e",""], default="", show_choices=False)
            if fb == "y":
                db.save(user_input, full_response, 1, project=project_dir,
                        source_model=getattr(model, "name", "local"), provider=provider_key)
                console.print(" [green]👍[/]")
            elif fb == "n":
                db.save(user_input, full_response, 0, project=project_dir,
                        source_model=getattr(model, "name", "local"), provider=provider_key)
                console.print(" [red]👎[/]")
            elif fb == "e":
                # The user pastes a corrected completion, terminated by "END".
                console.print(" [yellow]貼上修改版(END結束):[/]")
                lines = []
                while True:
                    try:
                        l = input()
                        if l.strip() == "END": break
                        lines.append(l)
                    except EOFError: break
                edited = "\n".join(lines)
                if edited.strip():
                    db.save(user_input, full_response, 1, edited=edited, project=project_dir,
                            source_model=getattr(model, "name", "local"), provider=provider_key)
                    console.print(" [yellow]✏️[/]")
        # If the turn ended right after tool results, close it with the
        # accumulated assistant response so the history keeps alternating.
        if messages[-1]["role"] == "user" and "Tool results:" in messages[-1]["content"]:
            messages.append({"role": "assistant", "content": full_response})
    console.print("\n[cyan]👋[/]")
| # ============================================================ | |
| # TRAINING(支援蒸餾) | |
| # ============================================================ | |
def trigger_training(db, console, args):
    """Fine-tune the local model on collected feedback via 4-bit QLoRA SFT.

    Reports SFT/DPO/KTO dataset sizes, then trains on cloud-distilled SFT
    examples when any exist, otherwise on all accepted SFT examples (only
    SFT training is performed here — DPO/KTO are counted but not trained).
    The resulting LoRA adapter is saved under CONFIG_DIR.
    """
    stats = db.count()
    if stats["total"] == 0:
        console.print("[yellow]⚠️ 無數據[/]"); return
    # Distillation data: answers produced by cloud models.
    cloud_sft = db.export_sft(only_cloud=True)
    dpo_data = db.export_dpo()
    all_sft = db.export_sft(only_cloud=False)
    kto_data = db.export_kto()
    console.print(f"\n[bold]🚀 訓練數據統計[/]")
    console.print(f" ⚗️ 雲端蒸餾 SFT: {len(cloud_sft)} 條 (Opus/GPT 的回答)")
    console.print(f" 📊 雲端 vs 本地 DPO: {len(dpo_data)} 對")
    console.print(f" 📚 全部 SFT: {len(all_sft)} 條")
    console.print(f" 👍👎 KTO: {len(kto_data)} 條")
    # Heavy ML imports deferred so the CLI stays fast when not training.
    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, prepare_model_for_kbit_training
    model_name = args.model or DEFAULT_LOCAL_MODEL
    output_dir = os.path.join(CONFIG_DIR, f"adapter_{datetime.now().strftime('%Y%m%d_%H%M')}")
    # 4-bit NF4 quantization so the base model trains on modest VRAM.
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
    peft_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
                          task_type="CAUSAL_LM",
                          target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"])
    # Dataset priority: cloud-distilled SFT first, else all accepted SFT.
    train_data = cloud_sft or all_sft
    if train_data:
        label = "⚗️ 蒸餾 SFT" if cloud_sft else "📚 SFT"
        console.print(f"\n[bold]{label} ({len(train_data)} 條)...[/]")
        from trl import SFTTrainer, SFTConfig
        model = AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=bnb, device_map="auto", trust_remote_code=True)
        tok = AutoTokenizer.from_pretrained(model_name)
        if tok.pad_token is None: tok.pad_token = tok.eos_token
        model = prepare_model_for_kbit_training(model)
        trainer = SFTTrainer(
            model=model,
            args=SFTConfig(
                output_dir=output_dir, learning_rate=2e-4, num_train_epochs=3,
                per_device_train_batch_size=1, gradient_accumulation_steps=8,
                max_seq_length=1024, gradient_checkpointing=True, bf16=True,
                optim="paged_adamw_8bit", logging_steps=5, save_total_limit=1,
                logging_strategy="steps", logging_first_step=True),
            processing_class=tok,
            train_dataset=Dataset.from_list(train_data),
            peft_config=peft_cfg)
        trainer.train()
        trainer.save_model(output_dir)
        # Free GPU memory before returning to the REPL.
        del model; torch.cuda.empty_cache()
        console.print(f"\n[bold green]🎉 完成![/]")
        console.print(f" Adapter: {output_dir}")
        console.print(f" 使用: codepilot --adapter {output_dir}")
def show_stats():
    """Print a per-provider summary of collected feedback plus distillable counts."""
    from rich.console import Console
    from rich.table import Table
    console = Console()
    db = FeedbackDB()
    table = Table(title="📊 CodePilot 數據統計")
    table.add_column("來源", style="cyan")
    table.add_column("總計")
    table.add_column("👍")
    table.add_column("✏️")
    # None means "all providers combined"; only non-empty sources get a row.
    for provider in (None, "local", "openai", "anthropic", "openrouter"):
        stats = db.count(provider)
        if stats["total"]:
            table.add_row(provider or "全部", str(stats["total"]),
                          str(stats["up"]), str(stats["edits"]))
    console.print(table)
    distill_sft = db.export_sft(only_cloud=True)
    dpo_pairs = db.export_dpo()
    if distill_sft or dpo_pairs:
        console.print(f"\n [yellow]⚗️ 可蒸餾: SFT {len(distill_sft)} / DPO {len(dpo_pairs)}[/]")
def main():
    """CLI entry point: parse arguments and dispatch to stats, training, or the agent loop."""
    parser = argparse.ArgumentParser(description="CodePilot v3 — 多模型 AI 開發助手")
    add = parser.add_argument
    add("--model", type=str, default=None, help="本地模型")
    add("--adapter", type=str, default=None, help="LoRA adapter")
    add("--project", type=str, default=None, help="專案目錄")
    add("--provider", type=str, default=None,
        choices=["local","openai","anthropic","openrouter","ollama"],
        help="模型提供者")
    add("--api-key", type=str, default=None, help="API key")
    add("--cloud-model", type=str, default=None, help="雲端模型名稱")
    add("--distill", action="store_true", help="蒸餾模式(自動收集雲端回答)")
    add("--stats", action="store_true")
    add("--train", action="store_true")
    opts = parser.parse_args()
    if opts.stats:
        show_stats()
    elif opts.train:
        from rich.console import Console
        trigger_training(FeedbackDB(), Console(), opts)
    else:
        run_agent_loop(opts)


if __name__ == "__main__":
    main()