#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CodePilot v3 — multi-model AI coding assistant with knowledge distillation
==========================================================================

Model backends (see PROVIDER_CONFIGS):
  🏠 local      : Qwen2.5-Coder-3B via transformers (your local model)
  ☁️ openai     : OpenAI (GPT-4o / GPT-5)
  ☁️ anthropic  : Anthropic (Claude Sonnet / Opus)
  🔗 openrouter : OpenRouter — one API key for many models
  🏠 ollama     : Ollama's OpenAI-compatible local server (no API key needed)

Distillation mode:
  Answers from a strong cloud model (Opus / GPT-5) are recorded automatically
  and later used to fine-tune a LoRA adapter for the local model.

Usage:
  # Local model
  codepilot

  # Claude (feedback is collected as training data)
  codepilot --provider anthropic --api-key sk-xxx

  # OpenRouter (one key for every model)
  codepilot --provider openrouter --api-key sk-xxx --cloud-model anthropic/claude-opus-4

  # Distillation: cloud answers become training data for the local model
  codepilot --distill --provider openrouter --api-key sk-xxx

  # Train the local model on the collected data
  codepilot --train
"""
import argparse
import difflib
import json
import os
import re
import shutil
import sqlite3
import subprocess
import sys
from datetime import datetime
from pathlib import Path

# NOTE: heavy / backend-specific dependencies (torch, transformers, peft, trl,
# datasets, httpx, rich) are imported inside the functions that need them, so
# cloud-only users do not need torch and `--stats` starts without loading it.

# ============================================================
# CONFIG
# ============================================================
DEFAULT_LOCAL_MODEL = "Qwen/Qwen2.5-Coder-3B-Instruct"
CONFIG_DIR = os.path.expanduser("~/.codepilot")
DB_PATH = os.path.join(CONFIG_DIR, "feedback.db")

# Backend presets. "type" selects the wire protocol used by CloudModel.
PROVIDER_CONFIGS = {
    "local": {
        "name": "Local (Qwen2.5-Coder-3B)",
        "type": "local",
    },
    "openai": {
        "name": "OpenAI",
        "type": "openai",
        "base_url": "https://api.openai.com/v1",
        "default_model": "gpt-4o",
    },
    "anthropic": {
        "name": "Anthropic",
        "type": "anthropic",
        "base_url": "https://api.anthropic.com/v1",
        "default_model": "claude-sonnet-4-20250514",
    },
    "openrouter": {
        "name": "OpenRouter (所有模型)",
        "type": "openai",  # OpenRouter speaks the OpenAI-compatible API
        "base_url": "https://openrouter.ai/api/v1",
        "default_model": "anthropic/claude-sonnet-4",
    },
    "ollama": {
        "name": "Ollama (本地)",
        "type": "openai",
        "base_url": "http://localhost:11434/v1",
        "default_model": "qwen2.5-coder:3b",
    },
}

# Providers that work without an API key.
KEYLESS_PROVIDERS = frozenset({"local", "ollama"})
# Maximum model calls per user turn (each call may request tools).
MAX_TOOL_ROUNDS = 10
# Returned by ProjectTools.git_context() outside a git work tree.
NO_GIT_REPO = "(not a git repo)"
# Directories list_files() never descends into.
IGNORED_DIRS = frozenset({".git", "node_modules", "__pycache__", ".venv", "dist", "build"})
# Substrings that make run_command() refuse a command (a coarse guard, not a sandbox).
DANGEROUS_COMMANDS = ("rm -rf /", "git push --force", "git reset --hard")


# ============================================================
# FEEDBACK DB (records which model/provider produced each answer)
# ============================================================
class FeedbackDB:
    """SQLite store of rated model answers, exportable as SFT / DPO / KTO data."""

    def __init__(self, db_path=None):
        """Open (and create if needed) the feedback database.

        Args:
            db_path: SQLite path; defaults to DB_PATH. ":memory:" is supported.
        """
        path = db_path or DB_PATH
        if path != ":memory:":
            os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        self.conn = sqlite3.connect(path)
        self.conn.execute("""CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY, timestamp TEXT, prompt TEXT,
            completion TEXT, label INTEGER, edited_completion TEXT,
            project TEXT, source_model TEXT, provider TEXT)""")
        self.conn.commit()

    def save(self, prompt, completion, label, edited=None, project=None,
             source_model=None, provider=None):
        """Insert one rated answer (label truthy = 👍, falsy = 👎)."""
        self.conn.execute(
            "INSERT INTO feedback VALUES (NULL,?,?,?,?,?,?,?,?)",
            (datetime.now().isoformat(), prompt, completion, int(label),
             edited, project, source_model, provider))
        self.conn.commit()

    def count(self, provider=None):
        """Return {"total", "up", "edits"} counts, optionally for one provider."""
        query = ("SELECT COUNT(*), COALESCE(SUM(label), 0), "
                 "COALESCE(SUM(CASE WHEN edited_completion IS NOT NULL THEN 1 ELSE 0 END), 0) "
                 "FROM feedback")
        params = ()
        if provider:
            query += " WHERE provider=?"
            params = (provider,)
        total, up, edits = self.conn.execute(query, params).fetchone()
        return {"total": total, "up": int(up), "edits": int(edits)}

    def export_sft(self, only_cloud=False):
        """Export 👍 rows as chat-format SFT examples.

        Args:
            only_cloud: keep only rows whose provider is set and is not
                "local" (the distillation set); otherwise use every 👍 row.

        Returns:
            list of {"messages": [user, assistant]} dicts. When the user
            supplied an edited completion, it replaces the original one.
        """
        query = ("SELECT prompt, COALESCE(edited_completion, completion) "
                 "FROM feedback WHERE label=1")
        if only_cloud:
            query += " AND provider != 'local' AND provider IS NOT NULL"
        rows = self.conn.execute(query).fetchall()
        return [{"messages": [
            {"role": "user", "content": p},
            {"role": "assistant", "content": c},
        ]} for p, c in rows]

    def export_dpo(self):
        """Pair 👍 cloud answers with 👎 local answers to the same prompt.

        Returns:
            list of {"prompt", "chosen", "rejected"} dicts in chat format;
            identical pairs (e.g. from re-saved rows) are emitted once.
        """
        rows = self.conn.execute("""
            SELECT DISTINCT c.prompt,
                   COALESCE(c.edited_completion, c.completion),
                   l.completion
            FROM feedback c
            JOIN feedback l ON c.prompt = l.prompt
            WHERE c.provider != 'local' AND c.provider IS NOT NULL AND c.label = 1
              AND l.provider = 'local' AND l.label = 0
        """).fetchall()
        return [{
            "prompt": [{"role": "user", "content": p}],
            "chosen": [{"role": "assistant", "content": cloud}],
            "rejected": [{"role": "assistant", "content": local}],
        } for p, cloud, local in rows]

    def export_kto(self):
        """Export every row as an unpaired KTO example (label -> bool)."""
        rows = self.conn.execute(
            "SELECT prompt, completion, label FROM feedback").fetchall()
        return [{
            "prompt": [{"role": "user", "content": p}],
            "completion": [{"role": "assistant", "content": c}],
            "label": bool(l),
        } for p, c, l in rows]


# ============================================================
# MODEL BACKENDS
# ============================================================
class LocalModel:
    """Local Hugging Face causal LM, optionally with a PEFT/LoRA adapter."""

    def __init__(self, model_name=DEFAULT_LOCAL_MODEL, adapter_path=None):
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.name = model_name.split("/")[-1]
        self.provider = "local"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # generate() needs a pad token; fall back to EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto",
            trust_remote_code=True)
        if adapter_path:
            if os.path.exists(adapter_path):
                from peft import PeftModel
                self.model = PeftModel.from_pretrained(self.model, adapter_path)
            else:
                # Don't silently run the base model when the user asked for an adapter.
                print(f"⚠️ adapter not found, using base model: {adapter_path}",
                      file=sys.stderr)
        self.model.eval()

    def chat(self, messages, max_tokens=4096):
        """Sample a reply (temperature 0.7, top_p 0.9) for an OpenAI-style message list."""
        import torch

        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            out = self.model.generate(
                **inputs, max_new_tokens=max_tokens, do_sample=True,
                temperature=0.7, top_p=0.9, repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id)
        # Decode only the newly generated tokens, not the echoed prompt.
        return self.tokenizer.decode(
            out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)


class CloudModel:
    """Remote chat model behind an OpenAI-compatible or Anthropic Messages API."""

    def __init__(self, provider_key, api_key, model_name=None, timeout=120):
        """
        Args:
            provider_key: key into PROVIDER_CONFIGS (must have a base_url).
            api_key: API key; may be empty/None for keyless servers (Ollama).
            model_name: overrides the provider's default model.
            timeout: HTTP timeout in seconds.
        """
        config = PROVIDER_CONFIGS[provider_key]
        self.provider = provider_key
        self.base_url = config["base_url"]
        self.name = model_name or config["default_model"]
        self.api_key = api_key
        self.api_type = config["type"]
        self.timeout = timeout

    def chat(self, messages, max_tokens=4096):
        """Return the assistant's text reply for an OpenAI-style message list."""
        if self.api_type == "anthropic":
            return self._chat_anthropic(messages, max_tokens)
        return self._chat_openai(messages, max_tokens)

    def _post(self, endpoint, headers, payload):
        """POST JSON to base_url/endpoint; raise on HTTP errors; return parsed JSON."""
        import httpx  # only cloud backends need httpx

        resp = httpx.post(f"{self.base_url}/{endpoint}", headers=headers,
                          json=payload, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def _chat_openai(self, messages, max_tokens):
        """OpenAI / OpenRouter / Ollama chat-completions call."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        if self.provider == "openrouter":
            # OpenRouter attribution headers.
            headers["HTTP-Referer"] = "https://codepilot.local"
            headers["X-Title"] = "CodePilot"
        payload = {
            "model": self.name,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        data = self._post("chat/completions", headers, payload)
        # "content" may be null in some responses; callers expect a str.
        return data["choices"][0]["message"].get("content") or ""

    def _chat_anthropic(self, messages, max_tokens):
        """Anthropic Messages API call (system prompt goes in a separate field)."""
        system = "\n\n".join(m["content"] for m in messages if m["role"] == "system")
        chat_msgs = [m for m in messages if m["role"] != "system"]
        headers = {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
        }
        payload = {
            "model": self.name,
            "messages": chat_msgs,
            "max_tokens": max_tokens,
            "temperature": 0.7,
        }
        if system:
            payload["system"] = system
        data = self._post("messages", headers, payload)
        # The reply is a list of content blocks; keep only the text ones.
        return "".join(block.get("text", "") for block in data.get("content", [])
                       if block.get("type") == "text")


# ============================================================
# PROJECT TOOLS
# ============================================================
class ProjectTools:
    """File, shell, search and git helpers exposed to the model as tools.

    Methods return human-readable strings; most failures are reported as
    "❌ ..." strings (execute_tool() catches anything that still raises).
    """

    def __init__(self, project_dir):
        self.project_dir = os.path.abspath(project_dir)
        self.cwd = self.project_dir
        # abs path -> {"time": mtime after our last read/write, "content": text};
        # edit_file uses it to require a prior read and detect external changes.
        self.read_cache = {}

    def _resolve(self, path):
        """Return *path* as absolute, resolving relative paths against cwd."""
        if os.path.isabs(path):
            return path
        return os.path.normpath(os.path.join(self.cwd, path))

    def read_file(self, path, offset=1, limit=200):
        """Return up to *limit* numbered lines starting at 1-based line *offset*."""
        full = self._resolve(path)
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        try:
            content = Path(full).read_text(encoding="utf-8", errors="replace")
            self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
            lines = content.splitlines()
            offset = max(1, offset)
            start = offset - 1
            selected = lines[start:start + limit]
            result = "\n".join(f"{n:4d} │ {line}"
                               for n, line in enumerate(selected, start=offset))
            remaining = len(lines) - (start + limit)
            if remaining > 0:
                result += f"\n... ({remaining} more)"
            return result
        except Exception as e:
            return f"❌ {e}"

    def edit_file(self, path, old_string, new_string):
        """Replace the single occurrence of *old_string*; return a unified diff.

        Requires a prior read_file of the same path, and refuses when the file
        changed on disk since then or *old_string* is missing / not unique.
        """
        full = self._resolve(path)
        if full not in self.read_cache:
            return "❌ 必須先 read_file"
        content = Path(full).read_text(encoding="utf-8")
        if os.path.getmtime(full) != self.read_cache[full]["time"]:
            return "❌ 文件已被外部修改"
        count = content.count(old_string)
        if count == 0:
            return "❌ 找不到要替換的文字"
        if count > 1:
            return f"❌ 找到 {count} 處,請提供更多上下文"
        new_content = content.replace(old_string, new_string, 1)
        diff = "".join(difflib.unified_diff(
            content.splitlines(keepends=True), new_content.splitlines(keepends=True),
            fromfile=f"a/{path}", tofile=f"b/{path}"))
        Path(full).write_text(new_content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": new_content}
        return "✅ 已修改:\n" + diff

    def write_file(self, path, content):
        """Create or overwrite *path* (creating parent directories)."""
        full = self._resolve(path)
        os.makedirs(os.path.dirname(full) or ".", exist_ok=True)
        is_new = not os.path.exists(full)
        Path(full).write_text(content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
        return f"✅ {'建立' if is_new else '覆寫'}: {path}"

    def run_command(self, command, timeout=120):
        """Run a shell command in cwd; return stdout (+stderr), capped at 10k chars.

        The DANGEROUS_COMMANDS substring check is a speed bump, not a sandbox.
        """
        if any(bad in command for bad in DANGEROUS_COMMANDS):
            return f"⛔ 危險: {command}"
        try:
            r = subprocess.run(command, shell=True, cwd=self.cwd,
                               capture_output=True, text=True, timeout=timeout)
            return (r.stdout + (f"\nSTDERR:\n{r.stderr}" if r.stderr else ""))[:10000]
        except subprocess.TimeoutExpired:
            return "⏰ 超時"
        except Exception as e:
            return f"❌ {e}"

    def search_files(self, pattern, glob_pattern=None):
        """Search file contents under cwd with ripgrep (or grep as a fallback)."""
        rg = shutil.which("rg")
        if rg:
            # rg recurses by default; "-r" would mean --replace in rg, so use "-n" only.
            cmd = [rg, "-n", "--color=never", "--max-count=50"]
            if glob_pattern:
                cmd += ["--glob", glob_pattern]
        else:
            cmd = ["grep", "-rn"]
            if glob_pattern:
                cmd.append(f"--include={glob_pattern}")
        # "-e" keeps patterns that start with "-" from being parsed as options.
        cmd += ["-e", pattern, self.cwd]
        try:
            out = subprocess.run(cmd, capture_output=True, text=True, timeout=30).stdout
            return out[:5000] or "無匹配"
        except Exception as e:
            return f"❌ {e}"

    def list_files(self, pattern="*", max_depth=3):
        """List up to 100 sorted relative paths matching *pattern*.

        Only directories fewer than *max_depth* levels below cwd are scanned.
        """
        files = []
        for root, dirs, fnames in os.walk(self.cwd):
            dirs[:] = [d for d in dirs if d not in IGNORED_DIRS]
            rel_root = os.path.relpath(root, self.cwd)
            depth = 0 if rel_root == "." else rel_root.count(os.sep) + 1
            if depth >= max_depth:
                dirs[:] = []  # deeper levels would be skipped too; don't walk them
                continue
            files.extend(os.path.relpath(os.path.join(root, f), self.cwd)
                         for f in fnames if Path(f).match(pattern))
        return "\n".join(sorted(files)[:100])

    def _git(self, *git_args):
        """Run `git <git_args>` in the project directory and return the result."""
        return subprocess.run(["git", *git_args], cwd=self.project_dir,
                              capture_output=True, text=True, timeout=10)

    def git_context(self):
        """Summarize branch, short status and last 5 commits, or NO_GIT_REPO."""
        try:
            # git exits non-zero (rather than raising) outside a work tree.
            if self._git("rev-parse", "--is-inside-work-tree").returncode != 0:
                return NO_GIT_REPO
            branch = self._git("branch", "--show-current").stdout.strip()
            status = self._git("status", "--short").stdout.strip()
            log = self._git("log", "--oneline", "-5").stdout.strip()
        except (OSError, subprocess.SubprocessError):  # git missing, bad cwd, timeout
            return NO_GIT_REPO
        return f"Branch: {branch}\nStatus:\n{status}\nRecent:\n{log}"


# ============================================================
# TOOL PARSER + EXECUTOR + SYSTEM PROMPT
# ============================================================
# A tool call is "<tool>NAME" followed by a JSON object (or "key: value"
# lines) and closed by "</tool>", as described in build_system_prompt().
TOOL_PATTERN = re.compile(r"<tool>\s*(\w+)\s*(.*?)\s*</tool>", re.DOTALL)


def _parse_tool_params(raw):
    """Parse a tool-call body: a JSON object, else "key: value" lines."""
    try:
        params = json.loads(raw)
    except ValueError:
        params = None
    if isinstance(params, dict):
        return params
    params = {}
    for line in raw.split("\n"):
        if ":" in line:
            key, value = line.split(":", 1)
            params[key.strip()] = value.strip().strip('"')
    return params


def parse_tool_calls(text):
    """Extract [{"tool": name, "params": dict}, ...] from a model response."""
    return [{"tool": m.group(1), "params": _parse_tool_params(m.group(2).strip())}
            for m in TOOL_PATTERN.finditer(text)]


def execute_tool(tools, call):
    """Dispatch one parsed tool call to *tools*; always returns a string."""
    n, p = call["tool"], call["params"]
    try:
        if n == "read_file":
            return tools.read_file(p.get("path", ""), int(p.get("offset", 1)),
                                   int(p.get("limit", 200)))
        if n == "edit_file":
            return tools.edit_file(p.get("path", ""), p.get("old_string", ""),
                                   p.get("new_string", ""))
        if n == "write_file":
            return tools.write_file(p.get("path", ""), p.get("content", ""))
        if n == "run_command":
            return tools.run_command(p.get("command", ""), int(p.get("timeout", 120)))
        if n == "search_files":
            return tools.search_files(p.get("pattern", ""), p.get("glob"))
        if n == "list_files":
            return tools.list_files(p.get("pattern", "*"), int(p.get("max_depth", 3)))
        if n == "git_status":
            return tools.git_context()
        return f"❌ 未知: {n}"
    except Exception as e:
        return f"❌ {e}"


def build_system_prompt(tools):
    """Build the system prompt: role, project context and tool-call syntax."""
    return f"""You are CodePilot, an expert AI programming assistant working in the user's project.
Working directory: {tools.cwd}
{tools.git_context()}

## Tools (use <tool>name\n{{json}}</tool>)
- read_file: {{"path":"...","offset":1,"limit":200}}
- edit_file: {{"path":"...","old_string":"...","new_string":"..."}} (must read first)
- write_file: {{"path":"...","content":"..."}}
- run_command: {{"command":"...","timeout":120}}
- search_files: {{"pattern":"...","glob":"*.py"}}
- list_files: {{"pattern":"*","max_depth":3}}
- git_status: {{}}

Rules: read before edit, old_string must be unique, prefer edit over write, verify changes."""


# ============================================================
# MAIN AGENT LOOP
# ============================================================
def _read_multiline(terminator="END"):
    """Read stdin lines until one equals *terminator* (stripped) or EOF.

    Returns the collected lines joined by newlines; Ctrl-C cancels and
    returns an empty string.
    """
    lines = []
    while True:
        try:
            line = input()
        except EOFError:
            break
        except KeyboardInterrupt:
            return ""
        if line.strip() == terminator:
            break
        lines.append(line)
    return "\n".join(lines)


def run_agent_loop(args):
    """Run the interactive CodePilot REPL.

    Loads the selected backend (plus a local model for /compare when
    possible), then loops: handle slash commands, or let the model answer
    with up to MAX_TOOL_ROUNDS rounds of tool calls, then record feedback
    (automatically as 👍 in distill mode, otherwise by asking the user).

    Args:
        args: parsed CLI namespace from main().
    """
    from rich.console import Console
    from rich.markdown import Markdown
    from rich.panel import Panel
    from rich.prompt import Prompt
    from rich.syntax import Syntax
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    project_dir = args.project or os.getcwd()
    tools = ProjectTools(project_dir)
    provider_key = args.provider or "local"

    # ---- choose the primary model ----
    if provider_key == "local":
        model_label = args.model or DEFAULT_LOCAL_MODEL
        with console.status("[bold green]載入本地模型..."):
            model = LocalModel(model_label, args.adapter)
    else:
        if not args.api_key and provider_key not in KEYLESS_PROVIDERS:
            console.print(f"[red]❌ 使用 {provider_key} 需要 --api-key[/]")
            sys.exit(1)
        cloud_model = args.cloud_model or PROVIDER_CONFIGS[provider_key]["default_model"]
        model = CloudModel(provider_key, args.api_key, cloud_model)
        model_label = f"{PROVIDER_CONFIGS[provider_key]['name']}/{model.name}"

    # Distillation only auto-records answers produced by a cloud model.
    distill_mode = args.distill and provider_key != "local"

    # ---- banner ----
    banner = "[bold cyan]CodePilot v3[/]"
    if distill_mode:
        banner += " [bold yellow]⚗️ 蒸餾模式[/]"
    banner += f"\n[dim]Model: {model_label}\nProject: {project_dir}[/]"
    if distill_mode:
        banner += "\n[yellow]雲端回答將自動收集為本地模型的訓練數據[/]"
    console.print(Panel.fit(banner, border_style="cyan"))
    if provider_key == "local":
        console.print("[green]✅ 本地模型載入完成[/]")
    else:
        console.print(f"[green]✅ 已連接 {PROVIDER_CONFIGS[provider_key]['name']}[/]")
    git_ctx = tools.git_context()
    if git_ctx != NO_GIT_REPO:
        console.print(Panel(git_ctx, title="📂 Project", border_style="dim"))
    console.print("[dim]指令: /ls /git /clear /switch /compare /status /train /quit[/]\n")

    # Keep one handle per model kind so /compare can query both.
    local_model_ref = model if provider_key == "local" else None
    cloud_model_ref = None if provider_key == "local" else model
    if provider_key != "local" and (args.adapter or distill_mode):
        try:
            with console.status("[dim]同時載入本地模型 (for /compare)..."):
                local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
            console.print("[dim]✅ 本地模型也已載入,可用 /compare[/]")
        except Exception:
            # Best effort: the session still works, only /compare is unavailable.
            console.print("[dim]⚠️ 本地模型載入失敗,/compare 不可用[/]")

    system_prompt = build_system_prompt(tools)
    messages = [{"role": "system", "content": system_prompt}]

    def record(prompt, completion, label, source_model, provider, edited=None):
        """Save one rated answer for the current project."""
        db.save(prompt, completion, label, edited=edited, project=project_dir,
                source_model=source_model, provider=provider)

    def show_status():
        """Print per-provider counts and the amount of distillation data."""
        totals = db.count()
        table = Table(title="📊 數據統計")
        table.add_column("來源")
        table.add_column("數量")
        table.add_column("👍")
        table.add_row("全部", str(totals["total"]), str(totals["up"]))
        for pk in PROVIDER_CONFIGS:
            sc = db.count(pk)
            if sc["total"] > 0:
                table.add_row(pk, str(sc["total"]), str(sc["up"]))
        console.print(table)
        sft = db.export_sft(only_cloud=True)
        dpo = db.export_dpo()
        if sft or dpo:
            console.print(f"\n [yellow]⚗️ 可蒸餾數據: SFT {len(sft)} / DPO {len(dpo)}[/]")
            console.print(" [dim]運行 codepilot --train 開始蒸餾訓練[/]")

    def run_compare(question):
        """Ask the local and the cloud model the same question; record the verdict."""
        lm, cm = local_model_ref, cloud_model_ref
        if lm is None or cm is None:
            console.print("[yellow]⚠️ /compare 需要同時有本地和雲端模型[/]")
            console.print("[dim]啟動方式: codepilot --provider openrouter --api-key sk-xxx --adapter ./adapter[/]")
            return
        compare_msgs = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        # The two models are queried one after the other.
        with console.status("[bold cyan]本地模型思考中..."):
            try:
                local_resp = lm.chat(compare_msgs)
            except Exception as e:
                local_resp = f"(錯誤: {e})"
        with console.status("[bold magenta]雲端模型思考中..."):
            try:
                cloud_resp = cm.chat(compare_msgs)
            except Exception as e:
                cloud_resp = f"(錯誤: {e})"

        ellipsis = "..." if len(question) > 60 else ""
        console.print(f"\n[bold]🔄 Compare: {question[:60]}{ellipsis}[/]\n")
        console.print(Panel(Markdown(local_resp), title=f"🏠 Local ({lm.name})", border_style="blue"))
        console.print(Panel(Markdown(cloud_resp), title=f"☁️ Cloud ({cm.name})", border_style="magenta"))

        console.print("\n [green]1[/] = 本地較好 [magenta]2[/] = 雲端較好 [yellow]b[/] = 都好 [red]x[/] = 都差 Enter = 跳過")
        choice = Prompt.ask(" ", choices=["1", "2", "b", "x", ""], default="", show_choices=False)
        # verdict -> (local label, cloud label); "2" yields a DPO pair (chosen=cloud).
        labels = {"1": (1, 0), "2": (0, 1), "b": (1, 1), "x": (0, 0)}
        if choice not in labels:
            return
        local_label, cloud_label = labels[choice]
        record(question, local_resp, local_label, lm.name, "local")
        record(question, cloud_resp, cloud_label, cm.name, cm.provider)
        if choice == "2":
            console.print(f" [magenta]☁️ 雲端勝 → DPO +1[/] (累計 DPO 對: {len(db.export_dpo())})")
        elif choice == "1":
            console.print(" [green]🏠 本地勝!你的模型在進步![/]")
        elif choice == "b":
            console.print(" [yellow]👍 都好 → SFT +2[/]")
        else:
            console.print(" [red]👎 都差[/]")

    while True:
        try:
            user_input = Prompt.ask("\n[bold green]🧑 You")
        except (EOFError, KeyboardInterrupt):
            break
        cmd = user_input.strip()
        if not cmd:
            continue

        # ---- slash commands ----
        if cmd in ("/quit", "/exit"):
            break
        if cmd == "/status":
            show_status()
            continue
        if cmd == "/train":
            try:
                trigger_training(db, console, args)
            except Exception as e:
                # Keep the REPL alive if training fails (missing deps, OOM, ...).
                console.print(f"[red]❌ 訓練失敗: {e}[/]")
            continue
        if cmd == "/clear":
            messages = [{"role": "system", "content": system_prompt}]
            console.print("[dim]已清除[/]")
            continue
        if cmd == "/git":
            console.print(Panel(tools.git_context(), title="Git", border_style="dim"))
            continue
        if cmd.startswith("/ls"):
            console.print(tools.list_files(cmd[3:].strip() or "*"))
            continue
        if cmd == "/compare" or cmd.startswith("/compare "):
            question = cmd[len("/compare"):].strip() or Prompt.ask(" 輸入要比較的問題")
            if question.strip():
                run_compare(question)
            continue
        if cmd == "/switch":
            console.print("可用: local, openai, anthropic, openrouter, ollama")
            new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys()))
            if new_p == "local":
                if local_model_ref is None:
                    with console.status("載入本地模型..."):
                        local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
                # Reuse the already-loaded local model instead of loading a second copy.
                model = local_model_ref
            else:
                if args.api_key or new_p in KEYLESS_PROVIDERS:
                    key = args.api_key
                else:
                    key = Prompt.ask("API Key")
                model_name = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"])
                model = CloudModel(new_p, key, model_name)
                cloud_model_ref = model
            provider_key = new_p
            # Re-evaluate: --distill must never auto-accept local answers.
            distill_mode = args.distill and provider_key != "local"
            console.print(f"[green]✅ 切換到 {provider_key}[/]")
            continue

        # ---- model turn with tool rounds ----
        messages.append({"role": "user", "content": user_input})
        turn_start = len(messages) - 1
        full_response = ""
        for rnd in range(MAX_TOOL_ROUNDS):
            phase = "思考中" if rnd == 0 else f"工具 round {rnd + 1}"
            with console.status(f"[bold cyan]{phase}..."):
                try:
                    response = model.chat(messages)
                except Exception as e:
                    console.print(f"[red]❌ API 錯誤: {e}[/]")
                    break
            tool_calls = parse_tool_calls(response)
            text_parts = TOOL_PATTERN.sub("", response).strip()
            if text_parts:
                console.print(f"\n[bold blue]🤖 CodePilot ({getattr(model, 'name', 'local')}):[/]")
                console.print(Markdown(text_parts))
            full_response += response + "\n"
            # Keep every model reply in history so roles alternate user/assistant.
            messages.append({"role": "assistant", "content": response})
            if not tool_calls:
                break

            results = []
            for call in tool_calls:
                console.print(f" [dim]🔧 {call['tool']}[/]")
                result = execute_tool(tools, call)
                if call["tool"] == "edit_file" and "✅" in result:
                    diff = result.split("\n", 1)[1] if "\n" in result else ""
                    if diff:
                        console.print(Syntax(diff, "diff", theme="monokai"))
                elif call["tool"] == "run_command":
                    console.print(Panel(result[:500], title="Terminal", border_style="dim"))
                else:
                    console.print(f" [dim]{result[:200]}[/]")
                results.append(f"[{call['tool']}] {result}")
            messages.append({"role": "user", "content": "Tool results:\n" + "\n\n".join(results)})

        if not full_response:
            # The first model call failed: drop the dangling user turn and skip
            # feedback, so no empty completion is stored as training data.
            del messages[turn_start:]
            continue
        if messages[-1]["role"] == "user":
            # Round limit hit, or a call failed right after tool results:
            # close the turn so the next user message doesn't follow a user message.
            messages.append({"role": "assistant", "content": full_response})

        # ---- feedback ----
        source_model = getattr(model, "name", "local")
        if distill_mode:
            # Distill mode: cloud answers are accepted automatically.
            record(user_input, full_response, 1, source_model, provider_key)
            console.print(" [yellow]⚗️ 自動記錄為訓練數據[/]")
            continue
        console.print("\n[dim][green]y[/]=👍 [red]n[/]=👎 [yellow]e[/]=✏️ Enter=跳過[/]")
        fb = Prompt.ask(" ", choices=["y", "n", "e", ""], default="", show_choices=False)
        if fb == "y":
            record(user_input, full_response, 1, source_model, provider_key)
            console.print(" [green]👍[/]")
        elif fb == "n":
            record(user_input, full_response, 0, source_model, provider_key)
            console.print(" [red]👎[/]")
        elif fb == "e":
            console.print(" [yellow]貼上修改版(END結束):[/]")
            edited = _read_multiline("END")
            if edited.strip():
                record(user_input, full_response, 1, source_model, provider_key, edited=edited)
                console.print(" [yellow]✏️[/]")

    console.print("\n[cyan]👋[/]")


# ============================================================
# TRAINING (distillation)
# ============================================================
def trigger_training(db, console, args):
    """Fine-tune a 4-bit QLoRA adapter for the local model on collected feedback.

    Trains with SFT on distilled cloud answers when any exist, otherwise on
    every 👍 answer. DPO / KTO data are only counted and reported here; no
    preference-optimization stage is run.

    Args:
        db: FeedbackDB to read training data from.
        console: rich Console for progress output.
        args: parsed CLI namespace; ``args.model`` overrides the base model.
    """
    stats = db.count()
    if stats["total"] == 0:
        console.print("[yellow]⚠️ 無數據[/]")
        return

    cloud_sft = db.export_sft(only_cloud=True)  # distillation data from cloud models
    dpo_data = db.export_dpo()
    all_sft = db.export_sft(only_cloud=False)
    kto_data = db.export_kto()

    console.print("\n[bold]🚀 訓練數據統計[/]")
    console.print(f" ⚗️ 雲端蒸餾 SFT: {len(cloud_sft)} 條 (Opus/GPT 的回答)")
    console.print(f" 📊 雲端 vs 本地 DPO: {len(dpo_data)} 對")
    console.print(f" 📚 全部 SFT: {len(all_sft)} 條")
    console.print(f" 👍👎 KTO: {len(kto_data)} 條")

    train_data = cloud_sft or all_sft
    if not train_data:
        # Nothing rated 👍 yet: don't announce an adapter that was never written.
        console.print("[yellow]⚠️ 沒有 👍 的回答可供 SFT 訓練[/]")
        return

    import torch
    from datasets import Dataset
    from peft import LoraConfig, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from trl import SFTConfig, SFTTrainer

    model_name = args.model or DEFAULT_LOCAL_MODEL
    output_dir = os.path.join(CONFIG_DIR, f"adapter_{datetime.now().strftime('%Y%m%d_%H%M')}")
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=torch.bfloat16,
                             bnb_4bit_use_double_quant=True)
    peft_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
                          task_type="CAUSAL_LM",
                          target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                          "gate_proj", "up_proj", "down_proj"])

    label = "⚗️ 蒸餾 SFT" if cloud_sft else "📚 SFT"
    console.print(f"\n[bold]{label} ({len(train_data)} 條)...[/]")
    model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=bnb, device_map="auto", trust_remote_code=True)
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = prepare_model_for_kbit_training(model)
    trainer = SFTTrainer(
        model=model,
        args=SFTConfig(
            output_dir=output_dir, learning_rate=2e-4, num_train_epochs=3,
            per_device_train_batch_size=1, gradient_accumulation_steps=8,
            # NOTE(review): newer TRL releases rename max_seq_length -> max_length;
            # confirm against the installed TRL version.
            max_seq_length=1024, gradient_checkpointing=True, bf16=True,
            optim="paged_adamw_8bit", logging_steps=5, save_total_limit=1,
            logging_strategy="steps", logging_first_step=True),
        processing_class=tok,
        train_dataset=Dataset.from_list(train_data),
        peft_config=peft_cfg)
    trainer.train()
    trainer.save_model(output_dir)
    # The trainer also references the model; drop both before freeing GPU memory.
    del trainer, model
    torch.cuda.empty_cache()

    console.print("\n[bold green]🎉 完成![/]")
    console.print(f" Adapter: {output_dir}")
    console.print(f" 使用: codepilot --adapter {output_dir}")


def show_stats():
    """Print per-provider feedback counts and how much distillation data exists."""
    from rich.console import Console
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    table = Table(title="📊 CodePilot 數據統計")
    table.add_column("來源", style="cyan")
    table.add_column("總計")
    table.add_column("👍")
    table.add_column("✏️")
    for pk in [None, *PROVIDER_CONFIGS]:  # None = all providers
        s = db.count(pk)
        if s["total"] > 0:
            table.add_row(pk or "全部", str(s["total"]), str(s["up"]), str(s["edits"]))
    console.print(table)
    cloud_sft = db.export_sft(only_cloud=True)
    dpo = db.export_dpo()
    if cloud_sft or dpo:
        console.print(f"\n [yellow]⚗️ 可蒸餾: SFT {len(cloud_sft)} / DPO {len(dpo)}[/]")


def main():
    """CLI entry point: --stats, --train, or the interactive agent (default)."""
    p = argparse.ArgumentParser(description="CodePilot v3 — 多模型 AI 開發助手")
    p.add_argument("--model", type=str, default=None, help="本地模型")
    p.add_argument("--adapter", type=str, default=None, help="LoRA adapter")
    p.add_argument("--project", type=str, default=None, help="專案目錄")
    p.add_argument("--provider", type=str, default=None,
                   choices=list(PROVIDER_CONFIGS), help="模型提供者")
    p.add_argument("--api-key", type=str, default=None, help="API key")
    p.add_argument("--cloud-model", type=str, default=None, help="雲端模型名稱")
    p.add_argument("--distill", action="store_true", help="蒸餾模式(自動收集雲端回答)")
    p.add_argument("--stats", action="store_true")
    p.add_argument("--train", action="store_true")
    a = p.parse_args()
    if a.stats:
        show_stats()
    elif a.train:
        from rich.console import Console
        trigger_training(FeedbackDB(), Console(), a)
    else:
        run_agent_loop(a)


if __name__ == "__main__":
    main()