Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| CodePilot v4 — AI 開發助手 + 自動進化 | |
| ====================================== | |
| v4 新功能: | |
| 🔄 /duel on|off — 雙模型比較開關,開啟後每個問題自動 DPO 配對 | |
| 🧠 上下文記憶 — CODEPILOT.md 專案記憶 + 對話歷史 + 文件快取 | |
| 🏋️ /grind — LeetCode 自動刷題,無人值守產生訓練數據 | |
| Usage: | |
| codepilot # 本地模型 | |
| codepilot --provider openrouter --api-key sk-xxx # 雲端 | |
| codepilot --duel --provider openrouter --api-key sk-xxx --adapter ./my-adapter | |
| codepilot --grind # 自動刷 LeetCode | |
| codepilot --grind --provider openrouter --api-key sk-xxx # 用雲端刷題蒸餾 | |
| """ | |
| import argparse, difflib, json, os, re, shutil, sqlite3, subprocess, sys, torch, time | |
| from datetime import datetime | |
| from pathlib import Path | |
| try: | |
| import httpx | |
| except ImportError: | |
| httpx = None | |
# Hugging Face hub id loaded when no other local model is requested.
DEFAULT_LOCAL_MODEL = "Qwen/Qwen2.5-Coder-3B-Instruct"

# Per-user state directory and the SQLite feedback database stored inside it.
CONFIG_DIR = os.path.expanduser("~/.codepilot")
DB_PATH = os.path.join(CONFIG_DIR, "feedback.db")

# Provider registry keyed by the CLI provider name.
#   name          - display name
#   type          - client flavour; CloudModel sends "anthropic" to the Messages API and
#                   every other type to an OpenAI-compatible /chat/completions endpoint
#   base_url      - API root (HTTP providers only)
#   default_model - model used when none is given explicitly
PROVIDER_CONFIGS = {
    "local": {
        "name": "Local",
        "type": "local",
    },
    "openai": {
        "name": "OpenAI",
        "type": "openai",
        "base_url": "https://api.openai.com/v1",
        "default_model": "gpt-4o",
    },
    "anthropic": {
        "name": "Anthropic",
        "type": "anthropic",
        "base_url": "https://api.anthropic.com/v1",
        "default_model": "claude-sonnet-4-20250514",
    },
    "openrouter": {
        "name": "OpenRouter",
        "type": "openai",
        "base_url": "https://openrouter.ai/api/v1",
        "default_model": "anthropic/claude-sonnet-4",
    },
    "ollama": {
        "name": "Ollama",
        "type": "openai",
        "base_url": "http://localhost:11434/v1",
        "default_model": "qwen2.5-coder:3b",
    },
    "codex": {
        "name": "OpenAI Codex",
        "type": "codex",
        "default_model": "gpt-5.4",
    },
}
| # ============================================================ | |
| # FEEDBACK DB | |
| # ============================================================ | |
class FeedbackDB:
    """SQLite store of user feedback on model completions, plus training-set exporters.

    Each row holds a prompt, the model's completion, a 1/0 (thumbs up/down)
    label, an optional human-edited completion, the project, and the
    model/provider that produced the completion. ``export_sft``,
    ``export_dpo`` and ``export_kto`` turn those rows into chat-format
    records for fine-tuning.
    """

    def __init__(self, db_path=None):
        """Open (creating if needed) the feedback database.

        Args:
            db_path: SQLite database path. Defaults to ``DB_PATH`` (creating
                ``CONFIG_DIR`` first); pass e.g. ``":memory:"`` for a
                throwaway store.
        """
        if db_path is None:
            os.makedirs(CONFIG_DIR, exist_ok=True)
            db_path = DB_PATH
        self.conn = sqlite3.connect(db_path)
        self.conn.execute("""CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY, timestamp TEXT, prompt TEXT, completion TEXT,
            label INTEGER, edited_completion TEXT, project TEXT,
            source_model TEXT, provider TEXT)""")
        self.conn.commit()

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def save(self, prompt, completion, label, edited=None, project=None,
             source_model=None, provider=None):
        """Insert one feedback row; ``label`` is stored as ``int(label)`` (True -> 1)."""
        self.conn.execute(
            "INSERT INTO feedback VALUES (NULL,?,?,?,?,?,?,?,?)",
            (datetime.now().isoformat(), prompt, completion, int(label),
             edited, project, source_model, provider))
        self.conn.commit()

    def count(self, provider=None):
        """Return ``{"total", "up", "edits"}`` row counts, optionally for one provider."""
        query = ("SELECT COUNT(*), COALESCE(SUM(label),0), "
                 "SUM(CASE WHEN edited_completion IS NOT NULL THEN 1 ELSE 0 END) FROM feedback")
        args = ()
        if provider:
            query += " WHERE provider=?"
            args = (provider,)
        total, up, edits = self.conn.execute(query, args).fetchone()
        # SUM over zero rows is NULL, hence the `or 0`.
        return {"total": total, "up": int(up), "edits": int(edits or 0)}

    def export_sft(self, only_cloud=False):
        """Export thumbs-up rows as SFT chat records.

        The human-edited completion, when present, is used instead of the raw
        completion in both modes.

        Args:
            only_cloud: Restrict to rows whose provider is set and is not "local".

        Returns:
            ``[{"messages": [user, assistant]}, ...]`` in insertion order.
        """
        where = "label=1"
        if only_cloud:
            where += " AND provider != 'local' AND provider IS NOT NULL"
        rows = self.conn.execute(
            "SELECT prompt, COALESCE(edited_completion, completion) FROM feedback "
            f"WHERE {where} ORDER BY id").fetchall()
        return [{"messages": [{"role": "user", "content": p},
                              {"role": "assistant", "content": c}]} for p, c in rows]

    def export_dpo(self):
        """Export DPO preference pairs.

        Every thumbs-up non-local completion (chosen; its human edit if any)
        is paired with every thumbs-down local completion of the same prompt
        (rejected).
        """
        rows = self.conn.execute(
            """SELECT c.prompt, COALESCE(c.edited_completion, c.completion), l.completion
               FROM feedback c JOIN feedback l ON c.prompt = l.prompt
               WHERE c.provider != 'local' AND c.label = 1
                 AND l.provider = 'local' AND l.label = 0
               ORDER BY c.id, l.id""").fetchall()
        return [{"prompt": [{"role": "user", "content": p}],
                 "chosen": [{"role": "assistant", "content": c}],
                 "rejected": [{"role": "assistant", "content": l}]} for p, c, l in rows]

    def export_kto(self):
        """Export every row as an unpaired KTO example (``label`` True = thumbs-up)."""
        rows = self.conn.execute(
            "SELECT prompt, completion, label FROM feedback ORDER BY id").fetchall()
        return [{"prompt": [{"role": "user", "content": p}],
                 "completion": [{"role": "assistant", "content": c}],
                 "label": bool(l)} for p, c, l in rows]
| # ============================================================ | |
| # MEMORY SYSTEM — Claude Code 風格四層記憶 | |
| # ============================================================ | |
| # 匯入 memory.py 模組(如果存在),否則使用內建簡化版 | |
| try: | |
| from memory import ( | |
| load_instructions, load_memory, save_memory, append_memory, | |
| build_full_system_prompt, SessionTranscript, FileStateCache, | |
| should_compact, compact_messages, estimate_tokens | |
| ) | |
| MEMORY_MODULE_AVAILABLE = True | |
| except ImportError: | |
| MEMORY_MODULE_AVAILABLE = False | |
class ProjectContext:
    """Four-layer project memory.

    L1: CODEPILOT.md instructions (the original design searches from the CWD up
        to the root via the memory module; the fallback reads only the
        project directory's file).
    L2: MEMORY.md cross-session memory (memory module only).
    L3: Session transcript (JSONL via the memory module) plus a JSON fallback file.
    L4: Automatic compaction to keep the context window in bounds.

    When the optional ``memory`` module could not be imported, each layer falls
    back to a simplified local implementation.
    """

    # Fallback trimming: once a history exceeds _MAX_SESSION_MESSAGES entries,
    # keep the first message (normally the system prompt) plus the newest _KEEP_RECENT.
    _MAX_SESSION_MESSAGES = 42
    _KEEP_RECENT = 40

    def __init__(self, project_dir):
        self.project_dir = project_dir
        self.cwd = project_dir
        if MEMORY_MODULE_AVAILABLE:
            # Full implementation from memory.py.
            self.transcript = SessionTranscript.find_latest(project_dir)
            self.file_cache = FileStateCache()
        else:
            self.transcript = None
            self.file_cache = None
        # Fallback session file, keyed by the project directory name. normpath
        # strips a trailing separator, which would otherwise make basename "".
        project_name = os.path.basename(os.path.normpath(project_dir))
        self.session_file = os.path.join(CONFIG_DIR, "sessions", project_name + ".json")
        os.makedirs(os.path.dirname(self.session_file), exist_ok=True)

    @classmethod
    def _trim_messages(cls, messages):
        """Return *messages*, or its first entry plus the newest ``_KEEP_RECENT`` when too long."""
        if len(messages) > cls._MAX_SESSION_MESSAGES:
            return [messages[0]] + messages[-cls._KEEP_RECENT:]
        return messages

    def load_all_instructions(self):
        """L1: return CODEPILOT.md instructions.

        Fallback: only ``<project_dir>/CODEPILOT.md``, or "" when it is absent.
        """
        if MEMORY_MODULE_AVAILABLE:
            return load_instructions(self.cwd)
        path = os.path.join(self.project_dir, "CODEPILOT.md")
        return Path(path).read_text(encoding="utf-8") if os.path.exists(path) else ""

    def load_memory(self):
        """L2: return cross-session memory ("" without the memory module)."""
        if MEMORY_MODULE_AVAILABLE:
            return load_memory(self.cwd)
        return ""

    def save_memory_entry(self, entry):
        """L2: append one memory entry (no-op without the memory module)."""
        if MEMORY_MODULE_AVAILABLE:
            append_memory(self.cwd, entry)

    def load_session(self):
        """L3: return the previous conversation's messages, or None when unavailable.

        Fallback: read the JSON session file; a missing, unreadable or
        malformed file yields None instead of raising.
        """
        if MEMORY_MODULE_AVAILABLE and self.transcript:
            return self.transcript.load_messages()
        if not os.path.exists(self.session_file):
            return None
        try:
            data = json.loads(Path(self.session_file).read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Corrupt or undecodable session file: start fresh rather than crash.
            return None
        if not isinstance(data, dict):
            return None
        msgs = data.get("messages", [])
        if not isinstance(msgs, list):
            return None
        return self._trim_messages(msgs)

    def save_session(self, messages):
        """L3: persist the current conversation.

        With the memory module, the newest message is appended to the JSONL
        transcript. The trimmed history is always also written to the JSON
        fallback file (UTF-8, since non-ASCII text is kept unescaped).
        """
        if MEMORY_MODULE_AVAILABLE:
            if not self.transcript:
                self.transcript = SessionTranscript(self.cwd)
            if messages:
                last = messages[-1]
                self.transcript.append(last.get("role", "user"), last)
        payload = {"messages": self._trim_messages(messages),
                   "timestamp": datetime.now().isoformat()}
        Path(self.session_file).write_text(
            json.dumps(payload, ensure_ascii=False), encoding="utf-8")

    def check_compact(self, messages, model_chat_fn=None):
        """L4: return *messages*, compacted when the history is too large.

        Fallback: plain trimming. With the memory module: summarize through
        ``model_chat_fn`` when given, else keep the first message plus the last 30.
        """
        if not MEMORY_MODULE_AVAILABLE:
            return self._trim_messages(messages)
        if not should_compact(messages):
            return messages
        edited_files = self.file_cache.get_recently_edited() if self.file_cache else []
        if model_chat_fn:
            return compact_messages(messages, model_chat_fn, edited_files)
        return [messages[0]] + messages[-30:]

    def build_system_prompt(self, git_context=""):
        """Assemble the full system prompt (instructions, memory, working directory, git info)."""
        if MEMORY_MODULE_AVAILABLE:
            return build_full_system_prompt(self.cwd, git_context)
        instructions = self.load_all_instructions()
        memory = self.load_memory()
        parts = ["You are CodePilot, an expert AI programming assistant."]
        if instructions:
            parts.append(instructions)
        if memory:
            parts.append(f"## Memory\n{memory}")
        parts.append(f"Working directory: {self.cwd}\n{git_context}")
        return "\n\n".join(parts)
| # ============================================================ | |
| # MODEL BACKENDS | |
| # ============================================================ | |
class LocalModel:
    """Chat backend that runs a Hugging Face causal LM on this machine.

    Loads ``model_name`` in bfloat16 with ``device_map="auto"`` and, when
    ``adapter_path`` exists on disk, wraps it with a PEFT adapter.
    """

    def __init__(self, model_name=DEFAULT_LOCAL_MODEL, adapter_path=None):
        from transformers import AutoModelForCausalLM, AutoTokenizer

        self.provider = "local"
        # Display name: the last path component of the hub id.
        self.name = model_name.split("/")[-1]

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if tokenizer.pad_token is None:
            # chat() passes pad_token_id to generate(), so make sure one exists.
            tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
        if adapter_path and os.path.exists(adapter_path):
            from peft import PeftModel
            model = PeftModel.from_pretrained(model, adapter_path)
        model.eval()
        self.model = model

    def chat(self, messages, max_tokens=4096):
        """Sample a reply to chat-format *messages*; return only the newly generated text."""
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        prompt_len = encoded["input_ids"].shape[1]
        with torch.no_grad():
            output_ids = self.model.generate(
                **encoded,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Skip the echoed prompt tokens so only the completion is decoded.
        return self.tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
class CloudModel:
    """Chat client for the HTTP providers in PROVIDER_CONFIGS.

    Providers whose type is "anthropic" use the Anthropic Messages API; every
    other type is sent to an OpenAI-compatible ``/chat/completions`` endpoint.
    HTTP errors propagate as ``httpx`` exceptions (``raise_for_status``).
    """

    def __init__(self, provider_key, api_key, model_name=None):
        config = PROVIDER_CONFIGS[provider_key]
        self.provider = provider_key
        self.base_url = config["base_url"]
        self.name = model_name or config["default_model"]
        self.api_key = api_key
        self.api_type = config["type"]

    def chat(self, messages, max_tokens=4096):
        """Send chat-format *messages* and return the assistant's reply text.

        Raises:
            RuntimeError: httpx is not installed.
        """
        if httpx is None:
            raise RuntimeError("❌ 請安裝 httpx: pip install httpx")
        if self.api_type == "anthropic":
            return self._anthropic(messages, max_tokens)
        return self._openai(messages, max_tokens)

    def _openai(self, messages, max_tokens):
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        if self.provider == "openrouter":
            headers.update({"HTTP-Referer": "https://codepilot.local", "X-Title": "CodePilot"})
        payload = {"model": self.name, "messages": messages,
                   "max_tokens": max_tokens, "temperature": 0.7}
        resp = httpx.post(f"{self.base_url}/chat/completions", headers=headers,
                          json=payload, timeout=120)
        resp.raise_for_status()
        return self._openai_text(resp.json())

    def _anthropic(self, messages, max_tokens):
        system, chat_msgs = self._split_system(messages)
        data = {"model": self.name, "messages": chat_msgs,
                "max_tokens": max_tokens, "temperature": 0.7}
        if system:
            data["system"] = system
        headers = {"x-api-key": self.api_key, "Content-Type": "application/json",
                   "anthropic-version": "2023-06-01"}
        resp = httpx.post(f"{self.base_url}/messages", headers=headers, json=data, timeout=120)
        resp.raise_for_status()
        return self._anthropic_text(resp.json())

    @staticmethod
    def _split_system(messages):
        """Separate system messages (Anthropic takes them as a top-level field).

        Returns:
            ``(system, other_messages)``; *system* is None when there is none,
            the single system content as-is, or several joined by blank lines.
        """
        system_parts = [m["content"] for m in messages if m["role"] == "system"]
        chat_msgs = [m for m in messages if m["role"] != "system"]
        if not system_parts:
            return None, chat_msgs
        if len(system_parts) == 1:
            return system_parts[0], chat_msgs
        return "\n\n".join(system_parts), chat_msgs

    @staticmethod
    def _anthropic_text(body):
        """Concatenate all text blocks of a Messages API response (non-text blocks are skipped)."""
        return "".join(block.get("text", "") for block in body.get("content", [])
                       if block.get("type") == "text")

    @staticmethod
    def _openai_text(body):
        """Return the first choice's message content; a null content becomes ""."""
        return body["choices"][0]["message"].get("content") or ""
class CodexModel:
    """OpenAI Codex backend.

    Uses the ``codex_app_server`` Python SDK when it can be imported, otherwise
    invokes the ``codex`` CLI as a subprocess for every request.
    """

    def __init__(self, model_name="gpt-5.4"):
        """Connect to Codex.

        Raises:
            RuntimeError: the SDK is not importable and no ``codex`` executable is on PATH.
        """
        self.name = model_name
        self.provider = "codex"
        self._sdk_available = False
        self._thread = None
        try:
            from codex_app_server import Codex
        except ImportError:
            Codex = None
        if Codex is None:
            # Fallback: drive the codex CLI through subprocess.
            self._codex_bin = shutil.which("codex")
            if not self._codex_bin:
                raise RuntimeError(
                    "OpenAI Codex 未安裝。請先安裝:\n"
                    "  npm install -g @openai/codex\n"
                    "  # 或\n"
                    "  brew install --cask codex\n\n"
                    "然後執行 codex 登入你的 OpenAI 帳號。"
                )
            return
        codex = Codex()
        codex.__enter__()
        try:
            self._thread = codex.thread_start(model=model_name)
        except Exception:
            # Don't leave an entered SDK session behind if no thread could be started.
            codex.__exit__(None, None, None)
            raise
        self._codex = codex
        self._sdk_available = True

    @staticmethod
    def _build_prompt(messages, max_parts=6):
        """Flatten chat messages into one prompt string.

        System messages are always kept (first); only the last *max_parts*
        user/assistant turns are included to bound the prompt size. Messages
        with other roles are ignored.
        """
        system_parts, turn_parts = [], []
        for m in messages:
            role = m["role"]
            if role == "system":
                system_parts.append(f"[System Instructions]\n{m['content']}\n")
            elif role == "user":
                turn_parts.append(f"User: {m['content']}")
            elif role == "assistant":
                turn_parts.append(f"Assistant: {m['content']}")
        recent = turn_parts[-max_parts:] if max_parts > 0 else []
        return "\n\n".join(system_parts + recent)

    def chat(self, messages, max_tokens=4096):
        """Return Codex's reply to *messages* (``max_tokens`` is accepted but unused)."""
        prompt = self._build_prompt(messages)
        if self._sdk_available:
            return self._chat_sdk(prompt)
        return self._chat_subprocess(prompt)

    def _chat_sdk(self, prompt):
        """Run *prompt* on the SDK thread."""
        result = self._thread.run(prompt)
        return result.final_response or "(no response)"

    def _chat_subprocess(self, prompt):
        """Run *prompt* through the codex CLI; errors and timeouts are returned as text."""
        try:
            result = subprocess.run(
                [self._codex_bin, "--model", self.name,
                 "--approval-mode", "auto",  # auto-approve tool calls
                 "--quiet",                  # reduce output noise
                 prompt],
                capture_output=True, text=True, timeout=180,
                env={**os.environ, "NO_COLOR": "1"},  # disable ANSI colours
            )
            output = result.stdout.strip()
            if not output and result.stderr:
                output = result.stderr.strip()
            return output or "(no response)"
        except subprocess.TimeoutExpired:
            return "⏰ Codex 回應超時 (180s)"
        except Exception as e:
            return f"❌ Codex 錯誤: {e}"

    def __del__(self):
        if getattr(self, "_sdk_available", False) and hasattr(self, "_codex"):
            try:
                self._codex.__exit__(None, None, None)
            except Exception:
                pass  # best effort during interpreter shutdown / GC
| # ============================================================ | |
| # PROJECT TOOLS | |
| # ============================================================ | |
class ProjectTools:
    """File-system, shell, search and git helpers exposed to the model as tools.

    Relative paths resolve against ``self.cwd`` (initially the absolute project
    directory). ``read_cache`` records the mtime and content of files read or
    written here, so ``edit_file`` can refuse edits to unread or externally
    modified files. Every method returns a human-readable string; failures are
    reported as "❌ ..." text.
    """

    # Directories list_files never descends into.
    _SKIP_DIRS = frozenset({".git", "node_modules", "__pycache__", ".venv", "dist", "build"})

    def __init__(self, project_dir):
        self.project_dir = os.path.abspath(project_dir)
        self.cwd = self.project_dir
        self.read_cache = {}

    def _resolve(self, path):
        """Return *path* unchanged if absolute, else normalized relative to ``self.cwd``."""
        return path if os.path.isabs(path) else os.path.normpath(os.path.join(self.cwd, path))

    def read_file(self, path, offset=1, limit=200):
        """Return up to *limit* numbered lines starting at 1-based line *offset*.

        Images, PDFs and notebooks are delegated to ``read_multimodal``. When
        lines remain after the shown window, a "... (N more)" line is appended.
        """
        full = self._resolve(path)
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        mm = read_multimodal(full)
        if mm is not None:
            return mm
        offset = max(1, offset)
        limit = max(0, limit)
        try:
            content = Path(full).read_text(encoding="utf-8", errors="replace")
            lines = content.splitlines()
            self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
            start = offset - 1
            shown = lines[start:start + limit]
            result = "\n".join(f"{i + offset:4d} │ {line}" for i, line in enumerate(shown))
            remaining = len(lines) - (start + limit)
            if remaining > 0:
                result += f"\n... ({remaining} more)"
            return result
        except Exception as e:
            return f"❌ {e}"

    def edit_file(self, path, old_string, new_string):
        """Replace the single occurrence of *old_string* in a previously read file.

        Refused when the file was not read/written through this instance, is
        gone or changed on disk since, *old_string* is empty, or it matches
        zero or several times. Returns a unified diff on success.
        """
        full = self._resolve(path)
        if full not in self.read_cache:
            return "❌ 必須先 read_file"
        if not old_string:
            # "" "occurs" everywhere: count() is meaningless and replace() would prepend.
            return "❌ old_string 不可為空"
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        if os.path.getmtime(full) != self.read_cache[full]["time"]:
            return "❌ 文件已被外部修改"
        content = Path(full).read_text(encoding="utf-8")
        count = content.count(old_string)
        if count == 0:
            return "❌ 找不到要替換的文字"
        if count > 1:
            return f"❌ 找到 {count} 處,請提供更多上下文"
        new_content = content.replace(old_string, new_string, 1)
        diff = "".join(difflib.unified_diff(
            content.splitlines(keepends=True), new_content.splitlines(keepends=True),
            fromfile=f"a/{path}", tofile=f"b/{path}"))
        Path(full).write_text(new_content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": new_content}
        return "✅ 已修改:\n" + diff

    def write_file(self, path, content):
        """Create or overwrite *path* (creating parent directories) and cache its state."""
        full = self._resolve(path)
        os.makedirs(os.path.dirname(full) or ".", exist_ok=True)
        is_new = not os.path.exists(full)
        Path(full).write_text(content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
        return f"✅ {'建立' if is_new else '覆寫'}: {path}"

    def run_command(self, command, timeout=120):
        """Run a shell command in ``self.cwd`` after the rule-based safety check.

        "block"- and "warn"-classified commands are not executed. Output
        (stdout plus any stderr) is capped at 10000 characters.
        """
        safety, reason = classify_command(command)
        if safety == "block":
            return f"⛔ 危險指令被阻擋: {command}\n原因: {reason}"
        if safety == "warn":
            return f"⚠️ 警告: {reason}\n指令: {command}\n(在 --approval ask 模式下會要求確認)"
        try:
            r = subprocess.run(command, shell=True, cwd=self.cwd, capture_output=True,
                               text=True, timeout=timeout)
            return (r.stdout + (f"\nSTDERR:\n{r.stderr}" if r.stderr else ""))[:10000]
        except subprocess.TimeoutExpired:
            return "⏰ 超時"
        except Exception as e:
            return f"❌ {e}"

    def search_files(self, pattern, glob_pattern=None):
        """Search file contents under ``self.cwd`` with ripgrep, or ``grep -rn`` without it.

        *glob_pattern* filters file names (rg ``--glob`` / grep ``--include``).
        Returns at most 5000 characters of ``path:line:text`` matches, "無匹配"
        when nothing matched, or "❌ ..." when the search tool failed.
        """
        rg = shutil.which("rg")
        if rg:
            # rg recurses by default; in rg "-r" means --replace, so "-rn" would rewrite matches to "n".
            cmd = [rg, "-n", "--color=never", "--max-count=50"]
            if glob_pattern:
                cmd += ["--glob", glob_pattern]
        else:
            cmd = ["grep", "-rn"]
            if glob_pattern:
                cmd.append(f"--include={glob_pattern}")
        # "--" ends option parsing so a pattern starting with "-" is searched, not parsed.
        cmd += ["--", pattern, self.cwd]
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
        except Exception as e:
            return f"❌ {e}"
        # Exit status 1 means "no match" for both tools; >1 with no output is a real error.
        if proc.returncode > 1 and not proc.stdout:
            return f"❌ {proc.stderr.strip()[:500]}"
        return proc.stdout[:5000] or "無匹配"

    def list_files(self, pattern="*", max_depth=3):
        """List up to 100 sorted paths (relative to ``self.cwd``) whose names match *pattern*.

        Only directories fewer than *max_depth* levels below ``self.cwd`` are
        listed; directories in ``_SKIP_DIRS`` are skipped entirely.
        """
        files = []
        for root, dirs, fnames in os.walk(self.cwd):
            dirs[:] = [d for d in dirs if d not in self._SKIP_DIRS]
            rel_root = os.path.relpath(root, self.cwd)
            depth = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
            if depth >= max_depth:
                dirs[:] = []
                continue
            if depth + 1 >= max_depth:
                dirs[:] = []  # children would be too deep to list; don't walk them at all
            files.extend(os.path.relpath(os.path.join(root, f), self.cwd)
                         for f in fnames if Path(f).match(pattern))
        return "\n".join(sorted(files)[:100])

    def git_context(self):
        """Summarize branch, short status and the last 5 commits, or "(not a git repo)"."""
        def git(*args):
            return subprocess.run(["git", *args], cwd=self.project_dir,
                                  capture_output=True, text=True, timeout=10)
        try:
            status = git("status", "--short")
            if status.returncode != 0:
                return "(not a git repo)"
            branch = git("branch", "--show-current").stdout.strip()
            log = git("log", "--oneline", "-5").stdout.strip()
        except (OSError, subprocess.SubprocessError):
            return "(not a git repo)"
        return f"Branch: {branch}\nStatus:\n{status.stdout.strip()}\nRecent:\n{log}"
# Matches a "<tool>NAME\n...params...</tool>" block in model output:
# group(1) is the tool name, group(2) the raw parameter text (DOTALL lets it span lines).
TOOL_PATTERN = re.compile(r"<tool>\s*(\w+)\s*\n(.*?)</tool>", flags=re.DOTALL)
| # ============================================================ | |
| # P0-2: TOOL RESULT BUDGET REDUCTION(工具結果截斷) | |
| # ============================================================ | |
MAX_TOOL_RESULT_CHARS = 12000  # ~3000 tokens


def truncate_tool_result(result, max_chars=MAX_TOOL_RESULT_CHARS):
    """Cap a tool result at about *max_chars* characters (Claude Code-style budget reduction).

    Results within budget are returned unchanged. Longer ones keep the first
    ~2/3 and the last ~1/3 of the budget, joined by a marker reporting the
    total length and the exact number of characters dropped.
    """
    if len(result) <= max_chars:
        return result
    head = max_chars * 2 // 3
    tail = max_chars // 3
    omitted = len(result) - head - tail
    # Slice by absolute index: result[-0:] would return the whole string when tail == 0.
    return (result[:head]
            + f"\n\n... ⚠️ Output truncated ({len(result):,} chars total, {omitted:,} chars omitted) ...\n\n"
            + result[len(result) - tail:])
def parse_tool_calls(text):
    """Extract ``<tool>NAME`` ... ``</tool>`` calls from model output.

    Parameters are parsed as a JSON object when possible. Otherwise (invalid
    JSON, or JSON that is not an object, such as a list or ``null``) each
    ``key: value`` line becomes a string parameter with surrounding double
    quotes stripped, so ``params`` is always a dict.

    Returns:
        ``[{"tool": name, "params": dict}, ...]`` in order of appearance
        (empty for empty/None *text*).
    """
    calls = []
    for match in TOOL_PATTERN.finditer(text or ""):
        raw = match.group(2).strip()
        try:
            params = json.loads(raw)
        except ValueError:
            params = None
        if not isinstance(params, dict):
            params = {}
            for line in raw.split("\n"):
                if ":" in line:
                    key, value = line.split(":", 1)
                    params[key.strip()] = value.strip().strip('"')
        calls.append({"tool": match.group(1), "params": params})
    return calls
def execute_tool(tools, call):
    """Dispatch one parsed tool call and return its (size-capped) text result.

    File, shell, search and git tools go to *tools*; web tools go to the
    module-level helpers. An unknown tool name yields "❌ 未知: <name>", and any
    exception raised while running a tool becomes "❌ <error>". The output
    always passes through ``truncate_tool_result``.
    """
    name = call["tool"]
    params = call["params"]
    dispatch = {
        "read_file": lambda: tools.read_file(
            params.get("path", ""), int(params.get("offset", 1)), int(params.get("limit", 200))),
        "edit_file": lambda: tools.edit_file(
            params.get("path", ""), params.get("old_string", ""), params.get("new_string", "")),
        "write_file": lambda: tools.write_file(params.get("path", ""), params.get("content", "")),
        "run_command": lambda: tools.run_command(
            params.get("command", ""), int(params.get("timeout", 120))),
        "search_files": lambda: tools.search_files(params.get("pattern", ""), params.get("glob")),
        "list_files": lambda: tools.list_files(
            params.get("pattern", "*"), int(params.get("max_depth", 3))),
        "git_status": lambda: tools.git_context(),
        "web_fetch": lambda: web_fetch(params.get("url", "")),      # P2-1
        "web_search": lambda: web_search(params.get("query", "")),  # P2-1
    }
    handler = dispatch.get(name)
    try:
        output = f"❌ 未知: {name}" if handler is None else handler()
    except Exception as exc:
        output = f"❌ {exc}"
    return truncate_tool_result(output)
| # ============================================================ | |
| # P2-1: WEB FETCH / WEB SEARCH | |
| # ============================================================ | |
def html_to_text(content):
    """Reduce an HTML document to whitespace-collapsed plain text.

    Drops ``<script>``/``<style>`` elements (case-insensitively) and HTML
    comments, replaces the remaining tags with spaces, decodes entities such
    as ``&amp;`` and ``&nbsp;``, and collapses runs of whitespace.
    """
    import html

    content = re.sub(r"<script[^>]*>.*?</script>", "", content, flags=re.DOTALL | re.IGNORECASE)
    content = re.sub(r"<style[^>]*>.*?</style>", "", content, flags=re.DOTALL | re.IGNORECASE)
    content = re.sub(r"<!--.*?-->", "", content, flags=re.DOTALL)
    content = re.sub(r"<[^>]+>", " ", content)
    # Unescape only after tag removal so encoded "&lt;b&gt;" stays literal text.
    content = html.unescape(content)
    return re.sub(r"\s+", " ", content).strip()


def web_fetch(url, max_chars=8000):
    """Fetch *url* and return its text content (tags stripped), capped at *max_chars*.

    Every failure, including a missing httpx or an HTTP error status, is
    returned as a "❌ ..." string rather than raised.
    """
    try:
        if not httpx:
            return "❌ 請安裝 httpx: pip install httpx"
        resp = httpx.get(url, timeout=15, follow_redirects=True,
                         headers={"User-Agent": "CodePilot/1.0"})
        resp.raise_for_status()
        return html_to_text(resp.text)[:max_chars]
    except Exception as e:
        return f"❌ 抓取失敗: {e}"
def parse_ddg_results(page, max_results=5):
    """Extract search hits from a DuckDuckGo HTML results page.

    Each hit becomes a ``- [title](url)`` line, followed by an indented
    snippet line (at most 150 chars) when the result has one. DuckDuckGo's
    ``//duckduckgo.com/l/?uddg=<target>`` redirect links are decoded to the
    target URL; non-http(s) links are skipped.

    Returns:
        list of output lines for at most *max_results* hits.
    """
    import html
    from urllib.parse import parse_qs, urlparse

    def clean(fragment):
        return html.unescape(re.sub(r"<[^>]+>", "", fragment)).strip()

    # Attribute order varies (DDG puts class before href), so match the class
    # anywhere in the tag and pull href out separately.
    anchors = list(re.finditer(
        r'<a\b([^>]*\bclass="[^"]*\bresult__a\b[^"]*"[^>]*)>(.*?)</a>', page, re.DOTALL))
    lines = []
    hits = 0
    for i, anchor in enumerate(anchors):
        if hits >= max_results:
            break
        href = re.search(r'\bhref="([^"]*)"', anchor.group(1))
        if not href:
            continue
        url = html.unescape(href.group(1))
        if url.startswith("//"):
            url = "https:" + url
        target = parse_qs(urlparse(url).query).get("uddg")
        if target:
            url = target[0]
        if not url.startswith(("http://", "https://")):
            continue
        lines.append(f"- [{clean(anchor.group(2))}]({url})")
        hits += 1
        # Look for this result's snippet only between it and the next result title.
        end = anchors[i + 1].start() if i + 1 < len(anchors) else len(page)
        snippet = re.search(
            r'<(a|div|td)\b[^>]*\bclass="[^"]*\bresult__snippet\b[^"]*"[^>]*>(.*?)</\1>',
            page[anchor.end():end], re.DOTALL)
        if snippet:
            text = clean(snippet.group(2))
            if text:
                lines.append(f"  {text[:150]}")
    return lines


def web_search(query, max_results=5):
    """Search the web via DuckDuckGo's HTML endpoint (no API key needed).

    Returns markdown-style result lines, "無搜尋結果: <query>" when nothing was
    found, or a "❌ ..." string on any failure (missing httpx, HTTP error, ...).
    """
    try:
        if not httpx:
            return "❌ 請安裝 httpx: pip install httpx"
        resp = httpx.get("https://html.duckduckgo.com/html/",
                         params={"q": query}, timeout=10,
                         headers={"User-Agent": "CodePilot/1.0"})
        resp.raise_for_status()
        results = parse_ddg_results(resp.text, max_results)
        return "\n".join(results) if results else f"無搜尋結果: {query}"
    except Exception as e:
        return f"❌ 搜尋失敗: {e}"
| # ============================================================ | |
| # P2-2: STREAMING OUTPUT(逐字輸出) | |
| # ============================================================ | |
def stream_local_chat(model, messages, console, max_tokens=4096):
    """Generate a reply from a local transformers model, printing text as it streams.

    Objects without ``tokenizer`` and ``model`` attributes fall back to a
    blocking ``model.chat(messages, max_tokens)``.

    Generation runs in a worker thread that feeds a TextIteratorStreamer. If
    generation raises, the streamer is ended so the printing loop finishes,
    and the worker's exception is re-raised to the caller.

    Returns:
        The full generated text.
    """
    if not hasattr(model, 'tokenizer') or not hasattr(model, 'model'):
        return model.chat(messages, max_tokens)  # not a local model: no streaming
    from transformers import TextIteratorStreamer
    import threading

    text = model.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = model.tokenizer(text, return_tensors="pt").to(model.model.device)
    streamer = TextIteratorStreamer(model.tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(**inputs, max_new_tokens=max_tokens, do_sample=True,
                      temperature=0.7, top_p=0.9, repetition_penalty=1.1,
                      pad_token_id=model.tokenizer.pad_token_id, streamer=streamer)
    errors = []

    def _generate():
        try:
            model.model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()  # unblock the consumer loop below instead of hanging forever

    thread = threading.Thread(target=_generate)
    thread.start()
    console.print("\n[bold blue]🤖 CodePilot:[/]", end="")
    chunks = []
    for chunk in streamer:
        print(chunk, end="", flush=True)
        chunks.append(chunk)
    print()  # newline after the streamed reply
    thread.join()
    if errors:
        raise errors[0]
    return "".join(chunks)
| # ============================================================ | |
| # P2-3: MULTIMODAL(圖片/PDF 讀取) | |
| # ============================================================ | |
def read_multimodal(path):
    """Return a text rendering of an image, PDF or Jupyter notebook, or None for other files.

    * Images: a "[Image: ...]" summary (size and extension only; no pixels).
    * PDFs: text of the first 20 pages via pdfminer.six, else the ``pdftotext``
      CLI, capped at 10000 chars. An install hint is returned when neither is available.
    * ``.ipynb``: up to 30 markdown/code cells, capped at 10000 chars; code
      fences use the notebook's ``language_info.name`` (default "python").
    * Anything else: None, so the caller reads the file as plain text.
    """
    ext = Path(path).suffix.lower()
    if ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg"):
        try:
            size = os.path.getsize(path)
        except OSError:
            return f"❌ 無法讀取圖片: {path}"
        return f"[Image: {path}, {size/1024:.0f}KB, {ext}]\n(圖片內容無法在文字模式顯示。如需分析圖片,請使用支援多模態的雲端模型。)"

    if ext == ".pdf":
        try:
            from pdfminer.high_level import extract_text
        except ImportError:
            extract_text = None
        if extract_text is not None:
            text = extract_text(path, maxpages=20)
            return f"[PDF: {path}, {len(text)} chars extracted]\n\n{text[:10000]}"
        try:
            r = subprocess.run(["pdftotext", "-l", "20", path, "-"],
                               capture_output=True, text=True, errors="replace", timeout=30)
        except FileNotFoundError:
            return f"[PDF: {path}] (安裝 pdfminer.six 以讀取: pip install pdfminer.six)"
        except (OSError, subprocess.SubprocessError) as e:
            return f"❌ 無法讀取 PDF: {path}: {e}"
        if r.returncode != 0:
            return f"❌ 無法讀取 PDF: {path}: {r.stderr.strip()[:500]}"
        return f"[PDF: {path}]\n\n{r.stdout[:10000]}"

    if ext == ".ipynb":
        try:
            # Notebooks are JSON, which is UTF-8 by spec; don't depend on the locale encoding.
            nb = json.loads(Path(path).read_text(encoding="utf-8"))
            meta = nb.get("metadata") or {}
            lang = (meta.get("language_info") or {}).get("name") or "python"
            output = []
            for i, cell in enumerate(nb.get("cells", [])[:30]):
                ctype = cell.get("cell_type", "")
                source = "".join(cell.get("source", []))  # source may be a list of lines or a str
                if ctype == "markdown":
                    output.append(f"[Markdown Cell {i+1}]\n{source}")
                elif ctype == "code":
                    output.append(f"[Code Cell {i+1}]\n```{lang}\n{source}\n```")
            return "\n\n".join(output)[:10000]
        except Exception as e:
            return f"❌ 無法讀取 notebook: {e}"

    return None  # not a multimodal file
| # ============================================================ | |
| # P2-4: SHELL SANDBOX(指令安全分類) | |
| # ============================================================ | |
| # 不用 ML,用規則分類 — 比 ML 更可靠且不需要額外模型 | |
# Commands matching any of these are refused outright (checked case-insensitively).
DANGEROUS_PATTERNS = [
    r"rm\s+(-[a-z]*r[a-z]*|--recursive)\s+[/~]",  # rm -rf / , rm -fr ~ , rm -r /
    r"rm\s+-[a-z]*r[a-z]*\s+\.",                    # rm -rf . (also ./path)
    r">(>?)\s*/dev/sd",                               # overwrite a raw disk
    r"mkfs\.",                                        # format a filesystem
    r"dd\s+if=",                                      # raw disk copy
    r":\(\)\s*\{.*\|.*&\s*\}\s*;\s*:",                # fork bomb :(){ :|:& };:  (parens escaped)
    r"chmod\s+777\s+/",                               # world-writable system paths
    r"curl.*\|\s*(bash|sh)",                          # pipe download to shell
    r"wget.*\|\s*(bash|sh)",                          # pipe download to shell
]

# Commands matching any of these need explicit confirmation.
WARN_PATTERNS = [
    r"git\s+push\s+(.*\s)?(--force|-f\b)",  # force push (long or short flag)
    r"git\s+reset\s+--hard",                # hard reset
    r"git\s+clean\s+-[a-z]*f",              # forced clean (-f, -fd, -df, -xdf, ...)
    r"npm\s+publish",                       # publish package
    r"pip\s+install\s+--force",             # force install
    r"docker\s+system\s+prune",             # docker cleanup
    r"DROP\s+TABLE",                        # SQL drop
    r"DELETE\s+FROM\s+\w+\s*;?\s*$",        # SQL delete without WHERE
    r"sudo\s+",                             # privilege escalation
]


def classify_command(command):
    """Classify a shell command with the rule lists above.

    Returns:
        ``(level, reason)`` where level is "block" (refuse), "warn" (needs
        confirmation) or "safe"; reason names the matched pattern ("" when safe).
        Dangerous patterns are checked before warning patterns.
    """
    for p in DANGEROUS_PATTERNS:
        if re.search(p, command, re.IGNORECASE):
            return "block", f"危險指令匹配: {p}"
    for p in WARN_PATTERNS:
        if re.search(p, command, re.IGNORECASE):
            return "warn", f"需要確認: {p}"
    return "safe", ""
| # ============================================================ | |
| # P2-5: MCP LITE(簡易外部工具協議) | |
| # ============================================================ | |
class MCPLite:
    """Minimal MCP-style client driven by ``.codepilot/mcp.json``.

    Two transports are supported:
      * "http": POST one JSON-RPC request per call to ``server["url"]``.
      * "stdio" (default): start ``server["command"]`` once through the shell
        and exchange newline-delimited JSON-RPC messages over its stdin/stdout.

    No MCP ``initialize`` handshake is sent; servers must accept bare requests.

    Example ``.codepilot/mcp.json``::

        {
          "servers": {
            "database": {"command": "python db_mcp_server.py", "type": "stdio"},
            "api": {"url": "http://localhost:9000/mcp", "type": "http"}
          }
        }
    """

    # Seconds to wait for a stdio server's reply to one request.
    STDIO_TIMEOUT = 10

    def __init__(self, project_dir):
        self.servers = {}
        self.processes = {}
        self._stdout_queues = {}  # server name -> queue of stdout lines (None = EOF)
        self._next_id = 0
        mcp_file = os.path.join(project_dir, ".codepilot", "mcp.json")
        if os.path.exists(mcp_file):
            try:
                config = json.loads(Path(mcp_file).read_text(encoding="utf-8"))
            except (OSError, ValueError):
                config = {}  # unreadable or malformed config: no servers
            servers = config.get("servers", {}) if isinstance(config, dict) else {}
            if isinstance(servers, dict):
                self.servers = servers

    def _new_request(self, method, params):
        """Build a JSON-RPC request with a fresh id so replies can be matched to it."""
        self._next_id += 1
        request = {"jsonrpc": "2.0", "id": self._next_id, "method": method, "params": params or {}}
        return self._next_id, request

    def call(self, server_name, method, params=None):
        """Invoke *method* on the named server; returns pretty-printed JSON or an error string."""
        server = self.servers.get(server_name)
        if not server:
            return f"❌ MCP 伺服器不存在: {server_name}(可用: {', '.join(self.servers.keys())})"
        if server.get("type") == "http":
            return self._call_http(server, method, params)
        return self._call_stdio(server_name, server, method, params)

    def _call_http(self, server, method, params):
        try:
            if not httpx:
                return "❌ 需要 httpx"
            _, request = self._new_request(method, params)
            resp = httpx.post(server["url"], json=request, timeout=30)
            resp.raise_for_status()
            result = resp.json()
            return json.dumps(result.get("result", result), ensure_ascii=False, indent=2)
        except Exception as e:
            return f"❌ MCP HTTP 錯誤: {e}"

    def _start_stdio(self, name, server):
        """Spawn a stdio server plus a reader thread that queues its stdout lines."""
        import queue
        import threading

        # stderr goes to DEVNULL: an unread PIPE fills up and blocks the server.
        proc = subprocess.Popen(
            server["command"], shell=True,
            stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
            encoding="utf-8", errors="replace")
        lines = queue.Queue()

        def pump():
            for line in proc.stdout:
                lines.put(line)
            lines.put(None)  # EOF marker

        threading.Thread(target=pump, daemon=True).start()
        self.processes[name] = proc
        self._stdout_queues[name] = lines
        return proc, lines

    def _call_stdio(self, name, server, method, params):
        import queue

        try:
            proc = self.processes.get(name)
            if proc is None or proc.poll() is not None:
                proc, lines = self._start_stdio(name, server)
            else:
                lines = self._stdout_queues[name]
            req_id, request = self._new_request(method, params)
            proc.stdin.write(json.dumps(request) + "\n")
            proc.stdin.flush()
            deadline = time.monotonic() + self.STDIO_TIMEOUT
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return "⏰ MCP 伺服器無回應"
                try:
                    line = lines.get(timeout=remaining)
                except queue.Empty:
                    return "⏰ MCP 伺服器無回應"
                if line is None:
                    return "❌ MCP 伺服器已結束"
                try:
                    message = json.loads(line)
                except ValueError:
                    continue  # log output or other non-JSON noise on stdout
                # Skip notifications and late replies to earlier (timed-out) requests.
                if isinstance(message, dict) and message.get("id") == req_id:
                    return json.dumps(message.get("result", message), ensure_ascii=False, indent=2)
        except Exception as e:
            return f"❌ MCP stdio 錯誤: {e}"

    def list_servers(self):
        """Describe the configured servers, one per line."""
        if not self.servers:
            return "(無 MCP 伺服器。建立 .codepilot/mcp.json)"
        lines = []
        for name, cfg in self.servers.items():
            stype = cfg.get("type", "stdio")
            target = cfg.get("url", cfg.get("command", "?"))
            lines.append(f"  🔌 {name} ({stype}): {target}")
        return "\n".join(lines)

    def cleanup(self):
        """Kill and reap every started stdio server."""
        for proc in self.processes.values():
            try:
                proc.kill()
                proc.wait(timeout=5)
                if proc.stdin:
                    proc.stdin.close()
            except (OSError, ValueError, subprocess.SubprocessError):
                pass  # already gone or pipe already broken: nothing more to do at shutdown
        self.processes.clear()
        self._stdout_queues.clear()
| # ============================================================ | |
| # P0-1: /init 自動產生 CODEPILOT.md | |
| # ============================================================ | |
def cmd_init(tools, model, console):
    """Scan the project and have *model* write a CODEPILOT.md for it (the /init command).

    Collects a 2-level file listing, git context and the first 3000 chars of
    common config/readme files, and asks *model* for the markdown. A reply
    wrapped entirely in a ``` code fence is unwrapped. An existing
    CODEPILOT.md is copied to CODEPILOT.md.bak before being overwritten, and
    nothing is written when the model returns no content.

    Returns:
        The markdown written to CODEPILOT.md ("" when the model returned nothing).
    """
    console.print("\n[bold]🔍 掃描專案結構...[/]")
    file_list = tools.list_files("*", max_depth=2)
    git = tools.git_context()

    key_files = {}
    for name in ("README.md", "README.rst", "package.json", "pyproject.toml",
                 "requirements.txt", "Cargo.toml", "go.mod", "Makefile",
                 "docker-compose.yml", "Dockerfile", ".gitignore"):
        full = os.path.join(tools.project_dir, name)
        if not os.path.exists(full):
            continue
        try:
            key_files[name] = Path(full).read_text(encoding="utf-8", errors="replace")[:3000]
        except OSError:
            continue  # unreadable file: just leave it out of the prompt
    key_files_text = "\n\n".join(f"--- {k} ---\n{v}" for k, v in key_files.items())

    prompt = f"""Analyze this project and generate a CODEPILOT.md configuration file.
## Project Files (top 2 levels)
{file_list[:3000]}
## Git Info
{git}
## Key Config Files
{key_files_text[:6000]}
## Instructions
Generate a markdown file with these sections:
1. **Project Overview** — one-line description
2. **Tech Stack** — languages, frameworks, databases
3. **Code Style** — formatting tools, naming conventions
4. **Testing** — test framework, how to run tests
5. **Key Commands** — build, run, test, lint commands
6. **Architecture** — key directories and their purpose
7. **Rules** — important rules for AI to follow (e.g., "always write tests", "use TypeScript strict mode")
Be concise. Use bullet points. Write in the language matching the project (Chinese if README is Chinese, English otherwise)."""

    with console.status("[bold cyan]分析專案中..."):
        result = model.chat([{"role": "user", "content": prompt}], max_tokens=2048)
    result = result or ""
    # Models often wrap the whole file in ```markdown ... ```; keep only the inside.
    fenced = re.match(r"^\s*```[\w-]*[ \t]*\n(.*?)\n?```\s*$", result, re.DOTALL)
    if fenced:
        result = fenced.group(1)
    if not result.strip():
        console.print("[yellow]⚠️ 模型沒有回傳內容,未寫入 CODEPILOT.md[/]")
        return ""

    codepilot_path = os.path.join(tools.project_dir, "CODEPILOT.md")
    if os.path.exists(codepilot_path):
        # Never silently destroy hand-written instructions.
        backup_path = codepilot_path + ".bak"
        shutil.copyfile(codepilot_path, backup_path)
        console.print(f"[dim]已備份原有 CODEPILOT.md → {backup_path}[/]")
    Path(codepilot_path).write_text(result, encoding="utf-8")

    console.print("\n[green]✅ 已產生 CODEPILOT.md[/]")
    preview = result[:500] + ("..." if len(result) > 500 else "")
    # markup=False: model output often contains [brackets] Rich would parse as markup.
    console.print(preview, style="dim", markup=False)
    console.print(f"\n[dim]檢查並編輯: {codepilot_path}[/]")
    return result
| # ============================================================ | |
| # P0-3: ERROR RECOVERY(錯誤自動恢復) | |
| # ============================================================ | |
MAX_RETRIES = 3

# Substrings (matched against the lower-cased error text) used to classify a failure.
_CONTEXT_TOO_LONG_KEYS = ("too long", "too_long", "context_length", "max_tokens", "prompt_too_long")
_RATE_LIMIT_KEYS = ("rate_limit", "429", "too many")
_SERVER_ERROR_KEYS = ("500", "502", "503", "server", "timeout", "connection")


def chat_with_recovery(model, messages, ctx=None, console=None, fallback_model=None):
    """Call ``model.chat(messages)`` with automatic error recovery.

    Up to ``MAX_RETRIES`` attempts are made against ``model``:

    * context-length errors: compact the history with ``ctx.check_compact``
      when available, otherwise keep only the first message plus the last
      four (if there are more than six), then retry;
    * rate-limit errors: exponential back-off (5s, 10s, ...) then retry;
    * server / network errors: exponential back-off (3s, 6s, ...) then retry;
    * any other error stops retrying the primary model immediately.

    No back-off sleep is done after the final attempt. Once the primary model
    has been given up on, ``fallback_model`` (if given) gets one attempt with
    the possibly compacted messages.

    Args:
        model: object exposing ``chat(messages)``.
        messages: chat message list; the first entry is kept when truncating.
        ctx: optional object with ``check_compact(messages, model_chat_fn=...)``.
        console: optional rich-style console for progress output.
        fallback_model: optional model tried once after the primary gives up.

    Returns:
        The reply from ``model.chat`` or ``fallback_model.chat``.

    Raises:
        The last exception raised by the primary model when every attempt,
        including the fallback, failed.
    """
    def _log(text):
        if console:
            console.print(text)

    last_error = None
    for attempt in range(MAX_RETRIES):
        try:
            return model.chat(messages)
        except Exception as e:
            last_error = e
        error_str = str(last_error).lower()
        is_last_attempt = attempt == MAX_RETRIES - 1
        _log(f" [yellow]⚠️ 嘗試 {attempt+1}/{MAX_RETRIES}: {type(last_error).__name__}[/]")

        # Strategy 1: context too long -> shrink the conversation and retry.
        if any(k in error_str for k in _CONTEXT_TOO_LONG_KEYS):
            if ctx and hasattr(ctx, 'check_compact'):
                _log(" [dim]🔄 壓縮對話歷史...[/]")
                messages = ctx.check_compact(messages, model_chat_fn=model.chat)
                continue
            if len(messages) > 6:
                # Manual truncation: keep the system prompt and the recent turns.
                messages = [messages[0]] + messages[-4:]
                continue

        # Strategy 2: rate limited -> back off (pointless after the last attempt).
        if any(k in error_str for k in _RATE_LIMIT_KEYS):
            if is_last_attempt:
                break
            wait = 2 ** attempt * 5  # 5s, 10s, ...
            _log(f" [dim]⏳ Rate limit, 等待 {wait}s...[/]")
            time.sleep(wait)
            continue

        # Strategy 3: transient server / network error -> back off.
        if any(k in error_str for k in _SERVER_ERROR_KEYS):
            if is_last_attempt:
                break
            wait = 2 ** attempt * 3  # 3s, 6s, ...
            _log(f" [dim]⏳ 伺服器錯誤, 等待 {wait}s...[/]")
            time.sleep(wait)
            continue

        # Any other error is not worth retrying on the same model.
        break

    # Strategy 4: the primary model gave up -> one attempt on the fallback model.
    if fallback_model:
        _log(f" [yellow]🔄 切換到 fallback 模型...[/]")
        try:
            return fallback_model.chat(messages)
        except Exception as fb_error:
            _log(f" [red]❌ fallback 失敗: {type(fb_error).__name__}[/]")
    raise last_error or RuntimeError("chat failed after retries")
| # ============================================================ | |
| # P0-4: VERIFICATION SUB-AGENT(驗證子代理) | |
| # ============================================================ | |
| def run_verification(model, tools, console, edited_files=None): | |
| """完成修改後自動跑測試驗證""" | |
| console.print("\n[bold]🔍 Verification Agent[/]") | |
| checks = [] | |
| # 1. 語法檢查修改過的 Python 文件 | |
| if edited_files: | |
| for f in edited_files: | |
| if f.endswith(".py") and os.path.exists(f): | |
| try: | |
| content = Path(f).read_text() | |
| compile(content, f, "exec") | |
| checks.append(f" ✅ {os.path.basename(f)} 語法正確") | |
| except SyntaxError as e: | |
| checks.append(f" ❌ {os.path.basename(f)} 語法錯誤: {e.msg} (line {e.lineno})") | |
| # 2. 嘗試跑 pytest / npm test | |
| test_commands = [] | |
| if os.path.exists(os.path.join(tools.project_dir, "pytest.ini")) or \ | |
| os.path.exists(os.path.join(tools.project_dir, "tests")) or \ | |
| os.path.exists(os.path.join(tools.project_dir, "test")): | |
| test_commands.append(("pytest", f"{sys.executable} -m pytest --tb=short -q")) | |
| if os.path.exists(os.path.join(tools.project_dir, "package.json")): | |
| test_commands.append(("npm test", "npm test --if-present 2>&1 | head -30")) | |
| if os.path.exists(os.path.join(tools.project_dir, "Makefile")): | |
| # 檢查是否有 test target | |
| makefile = Path(os.path.join(tools.project_dir, "Makefile")).read_text(errors="replace") | |
| if "test:" in makefile: | |
| test_commands.append(("make test", "make test 2>&1 | tail -20")) | |
| for name, cmd in test_commands: | |
| console.print(f" [dim]🧪 Running {name}...[/]") | |
| result = tools.run_command(cmd, timeout=60) | |
| # 判斷通過/失敗 | |
| result_lower = result.lower() | |
| if any(k in result_lower for k in ["passed", "ok", "success", "0 error"]): | |
| passed_match = re.search(r'(\d+) passed', result) | |
| n = passed_match.group(1) if passed_match else "" | |
| checks.append(f" ✅ {name}: {n} passed" if n else f" ✅ {name}: OK") | |
| elif any(k in result_lower for k in ["failed", "error", "fail"]): | |
| # 只顯示最後幾行 | |
| last_lines = "\n".join(result.strip().split("\n")[-5:]) | |
| checks.append(f" ❌ {name}: FAILED\n{last_lines}") | |
| else: | |
| checks.append(f" ⚠️ {name}: {result[:200]}") | |
| if not checks: | |
| checks.append(" [dim]沒有找到測試框架[/]") | |
| for c in checks: | |
| console.print(c) | |
| return checks | |
| # ============================================================ | |
| # P0-BONUS: HOOKS SYSTEM(post-edit 自動格式化) | |
| # ============================================================ | |
class Hooks:
    """Minimal hooks system driven by ``<project>/.codepilot/hooks.json``.

    The file is a JSON object mapping an event name to a shell command
    template, e.g. ``{"post_edit": "black {file}"}``. ``{key}`` placeholders
    are filled from the ``context`` dict passed to :meth:`run`. A missing,
    unreadable or malformed file simply disables hooks.
    """

    def __init__(self, project_dir):
        self.project_dir = project_dir
        self.hooks = {}
        hooks_file = os.path.join(project_dir, ".codepilot", "hooks.json")
        if not os.path.exists(hooks_file):
            return
        try:
            loaded = json.loads(Path(hooks_file).read_text(encoding="utf-8"))
        except (OSError, UnicodeDecodeError, ValueError):
            # A broken hooks file must not prevent startup.
            return
        # Only a JSON object is a usable event -> command mapping.
        if isinstance(loaded, dict):
            self.hooks = loaded

    def run(self, event, context=None):
        """Run the hook registered for ``event``.

        Args:
            event: hook event name, e.g. ``"post_edit"``.
            context: optional mapping whose items replace ``{key}`` in the
                command template, e.g. ``{"file": "path/to/file.py"}``.

        Returns:
            None when no hook is registered or the hook exited with status 0.
            Otherwise the hook's combined stdout+stderr, or a short description
            when the hook timed out (30s) or could not be started.
        """
        cmd_template = self.hooks.get(event)
        if not cmd_template or not isinstance(cmd_template, str):
            return None
        cmd = cmd_template
        for k, v in (context or {}).items():
            # NOTE(review): values are substituted unquoted into a shell command,
            # so a value containing shell metacharacters is interpreted by the shell.
            cmd = cmd.replace(f"{{{k}}}", str(v))
        try:
            result = subprocess.run(cmd, shell=True, cwd=self.project_dir,
                                    capture_output=True, text=True, errors="replace",
                                    timeout=30)
        except subprocess.TimeoutExpired:
            return f"hook '{event}' timed out after 30s"
        except OSError as e:
            return f"hook '{event}' failed to start: {e}"
        return result.stdout + result.stderr if result.returncode != 0 else None
| # ============================================================ | |
| # SKILL SYSTEM(技能系統 — 仿 Claude Code SkillTool) | |
| # ============================================================ | |
| """ | |
| Skill 和 Agent 的關鍵差異(來自 Claude Code 原始碼): | |
| - Skill → 注入指令到「當前」context window(不建新 context) | |
| - Agent → spawn 一個「新的」隔離 context window | |
| Skill 定義方式: | |
| .codepilot/skills/<name>/SKILL.md | |
| SKILL.md 格式: | |
| --- | |
| name: API Generator | |
| description: Generate RESTful API endpoints from a data model | |
| tools: [read_file, edit_file, write_file, run_command] | |
| arguments: | |
| - name: model_file | |
| description: Path to the data model file | |
| - name: framework | |
| description: Web framework (fastapi, express, gin) | |
| default: fastapi | |
| hooks: | |
| post_edit_file: "black {file}" | |
| --- | |
| 你是一位 API 專家。根據用戶提供的 data model,生成完整的 RESTful CRUD API。 | |
| 步驟: | |
| 1. 讀取 model_file 了解數據結構 | |
| 2. 生成路由文件 | |
| 3. 生成測試文件 | |
| 4. 執行測試確認通過 | |
| 內建 Skills(bundled): | |
| - create-skill: 幫你建立新的 skill | |
| - refactor: 重構程式碼 | |
| - test-gen: 自動產生測試 | |
| - doc-gen: 自動產生文檔 | |
| - debug: 除錯助手 | |
| """ | |
class SkillManager:
    """Discover, list and invoke Skills.

    Skills are registered from three sources. For a given name, the first
    source that provides it wins:

    1. ``<project_dir>/.codepilot/skills/<name>/SKILL.md``
    2. ``CONFIG_DIR/skills/<name>/SKILL.md`` (global skills)
    3. the bundled defaults from :meth:`_register_bundled_skills`

    A skill is a dict with keys ``name``, ``path``, ``description``,
    ``prompt``, ``tools`` (None = every tool, list = allow-list),
    ``arguments``, ``hooks``, ``model`` and ``fork``.
    """

    def __init__(self, project_dir):
        self.project_dir = project_dir
        self.skills = {}
        self._load_skill_dir(Path(project_dir) / ".codepilot" / "skills")
        self._load_skill_dir(Path(CONFIG_DIR) / "skills")
        self._register_bundled_skills()

    def _load_skill_dir(self, skills_dir):
        """Register every ``<skills_dir>/<name>/SKILL.md`` whose name is not taken yet."""
        if not skills_dir.is_dir():
            return
        for skill_dir in sorted(skills_dir.iterdir()):
            skill_md = skill_dir / "SKILL.md"
            if skill_dir.is_dir() and skill_md.exists():
                skill = self._parse_skill(skill_md)
                if skill and skill["name"] not in self.skills:
                    self.skills[skill["name"]] = skill

    @staticmethod
    def _split_frontmatter(content):
        """Return ``(frontmatter_lines, body)``, or ``(None, content)`` when there is none.

        Frontmatter must start on the first line with ``---`` and end at the
        next line consisting of ``---``.
        """
        lines = content.splitlines()
        if not lines or lines[0].strip() != "---":
            return None, content
        for idx in range(1, len(lines)):
            if lines[idx].strip() == "---":
                return lines[1:idx], "\n".join(lines[idx + 1:])
        return None, content

    @staticmethod
    def _unquote(value):
        """Strip one pair of matching surrounding quotes from a scalar value."""
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
            return value[1:-1]
        return value

    @staticmethod
    def _parse_inline_list(value):
        """Parse ``[a, 'b']`` (or a bare ``a, b``) into a list of names."""
        if value.startswith("["):
            value = value[1:-1] if value.endswith("]") else value[1:]
        return [x.strip().strip("'\"") for x in value.split(",")]

    @classmethod
    def _parse_frontmatter(cls, lines):
        """Parse the small YAML subset used by SKILL.md frontmatter.

        Top-level ``key: value`` lines become scalars. A top-level key with an
        empty value opens a nested block whose indented (or ``-``-prefixed)
        lines belong to it. Supported blocks:

        * ``arguments:`` — a list of ``- name: ...`` mappings with further
          fields such as ``description`` and ``default``;
        * ``hooks:`` — an ``event: command`` mapping;
        * ``tools:`` — a block list of ``- tool`` items.

        Because nesting follows indentation, an argument's ``description``
        never overwrites the skill's own description.

        Returns:
            ``(scalars, arguments, hooks, tool_items)``.
        """
        scalars, hooks = {}, {}
        arguments, tool_items = [], []
        section = None      # top-level key that opened the current nested block
        current_arg = None  # argument mapping currently being filled
        for raw in lines:
            text = raw.strip()
            if not text or text.startswith("#"):
                continue
            is_item = text.startswith("-")
            nested = section is not None and (raw[:1].isspace() or is_item)
            if not nested:
                section, current_arg = None, None
                if ":" in text:
                    key, value = (part.strip() for part in text.split(":", 1))
                    if value:
                        scalars[key] = value
                    else:
                        section = key
                continue
            if is_item:
                text = text[1:].strip()
            if section == "arguments":
                if is_item:
                    current_arg = {}
                    arguments.append(current_arg)
                if current_arg is None or not text:
                    continue
                if ":" in text:
                    key, value = (part.strip() for part in text.split(":", 1))
                    current_arg[key] = cls._unquote(value)
                elif is_item:
                    current_arg["name"] = cls._unquote(text)  # "- model_file" shorthand
            elif section == "hooks" and ":" in text:
                key, value = (part.strip() for part in text.split(":", 1))
                hooks[key] = cls._unquote(value)
            elif section == "tools" and is_item and text:
                tool_items.append(cls._unquote(text))
        arguments = [a for a in arguments if "name" in a]
        return scalars, arguments, hooks, tool_items

    def _parse_skill(self, skill_md_path):
        """Parse a SKILL.md file into a skill dict; return None if it cannot be read."""
        skill_md_path = Path(skill_md_path)
        try:
            # utf-8-sig: tolerate a BOM in front of the leading "---".
            content = skill_md_path.read_text(encoding="utf-8-sig")
        except (OSError, UnicodeDecodeError):
            return None
        skill = {
            "name": skill_md_path.parent.name,
            "path": str(skill_md_path.parent),
            "description": "",
            "prompt": content,
            "tools": None,   # None = all tools, list = allow-list
            "arguments": [],
            "hooks": {},
            "model": None,
            "fork": False,   # True = run in an isolated context
        }
        front, body = self._split_frontmatter(content)
        if front is None:
            return skill
        scalars, arguments, hooks, tool_items = self._parse_frontmatter(front)
        if "name" in scalars:
            skill["name"] = self._unquote(scalars["name"])
        if "description" in scalars:
            skill["description"] = self._unquote(scalars["description"])
        if "model" in scalars:
            skill["model"] = self._unquote(scalars["model"])
        if "fork" in scalars:
            skill["fork"] = self._unquote(scalars["fork"]).lower() in ("true", "yes", "1")
        if "tools" in scalars:
            skill["tools"] = self._parse_inline_list(scalars["tools"])
        elif tool_items:
            skill["tools"] = tool_items
        skill["arguments"] = arguments
        skill["hooks"] = hooks
        skill["prompt"] = body.strip()
        return skill

    def _register_bundled_skills(self):
        """Register the built-in skills, unless a custom skill already uses the name."""
        bundled = {
            "create-skill": {
                "name": "create-skill",
                "description": "建立新的 skill",
                "prompt": """幫用戶在 .codepilot/skills/<name>/SKILL.md 建立一個新的 skill。
先問用戶:
1. Skill 名稱
2. 用途描述
3. 需要用到哪些工具
然後產生 SKILL.md,包含 YAML frontmatter 和詳細指令。""",
                "tools": ["write_file", "list_files"],
                "arguments": [{"name": "name", "description": "skill 名稱"}],
                "hooks": {},
                "fork": False,
                "path": "(bundled)",
            },
            "refactor": {
                "name": "refactor",
                "description": "重構程式碼:提取函數、重命名、簡化邏輯",
                "prompt": """你是重構專家。閱讀用戶指定的文件,進行以下改進:
1. 提取重複的程式碼為函數
2. 改善命名(變數、函數、類別)
3. 簡化複雜的條件邏輯
4. 加入或改進 docstring
5. 確保修改後測試仍然通過
每次只做一個小修改,驗證後再做下一個。""",
                "tools": ["read_file", "edit_file", "run_command", "search_files"],
                "arguments": [{"name": "file", "description": "要重構的文件路徑"}],
                "hooks": {},
                "fork": False,
                "path": "(bundled)",
            },
            "test-gen": {
                "name": "test-gen",
                "description": "自動產生測試",
                "prompt": """你是測試工程師。為用戶指定的文件或函數產生完整的測試。
步驟:
1. 讀取原始碼,了解所有公開函數和類別
2. 為每個函數產生:正常輸入、邊界值、錯誤輸入的測試
3. 使用專案現有的測試框架(pytest/jest/等)
4. 把測試寫入對應的 tests/ 目錄
5. 執行測試確認通過""",
                "tools": ["read_file", "write_file", "run_command", "search_files", "list_files"],
                "arguments": [{"name": "file", "description": "要產生測試的文件"}],
                "hooks": {},
                "fork": False,
                "path": "(bundled)",
            },
            "doc-gen": {
                "name": "doc-gen",
                "description": "自動產生文檔(docstring / README / API docs)",
                "prompt": """你是技術文件專家。為用戶的程式碼產生或改善文檔。
可以:
1. 為所有函數加上 docstring
2. 產生或更新 README.md
3. 產生 API 文檔(如有 web framework)
4. 產生 CHANGELOG
根據用戶的要求決定做哪個。""",
                "tools": ["read_file", "edit_file", "write_file", "search_files", "list_files"],
                "arguments": [{"name": "target", "description": "文件或目錄", "default": "."}],
                "hooks": {},
                "fork": False,
                "path": "(bundled)",
            },
            "debug": {
                "name": "debug",
                "description": "除錯助手:分析錯誤訊息、找出原因、修復",
                "prompt": """你是除錯專家。用戶會給你一個錯誤訊息或描述問題。
步驟:
1. 分析錯誤訊息,定位問題文件和行數
2. 讀取相關程式碼
3. 搜尋可能相關的其他文件
4. 找出根本原因
5. 提出修復方案
6. 實施修復
7. 跑測試驗證
先分析再動手,不要急著改。""",
                "tools": ["read_file", "edit_file", "run_command", "search_files", "list_files", "git_status"],
                "arguments": [{"name": "error", "description": "錯誤訊息或問題描述"}],
                "hooks": {},
                "fork": False,
                "path": "(bundled)",
            },
        }
        for name, skill in bundled.items():
            if name not in self.skills:
                self.skills[name] = skill

    def list_skills(self):
        """Return a rich-markup listing of all skills: custom ones first, then bundled."""
        custom, bundled = [], []
        for name, s in sorted(self.skills.items()):
            is_bundled = s.get("path") == "(bundled)"
            icon = "📦" if is_bundled else "🔧"
            entry = f" {icon} {name}: {s.get('description', '')}"
            args = ", ".join(a["name"] for a in s.get("arguments", []))
            if args:
                entry += f" [dim]({args})[/]"
            (bundled if is_bundled else custom).append(entry)
        lines = []
        if custom:
            lines.append("[bold]自訂 Skills:[/]")
            lines.extend(custom)
        if bundled:
            lines.append("[bold]內建 Skills:[/]")
            lines.extend(bundled)
        return "\n".join(lines) if lines else "(無 skill。用 /skill create-skill 建立)"

    def invoke(self, skill_name, args_dict, model, tools, console, messages=None):
        """Run a skill.

        A skill injects its instructions into the *current* context (unlike an
        agent, which gets a fresh one). Each ``{arg}`` placeholder in the skill
        prompt is filled from ``args_dict``, or from the argument's ``default``.

        Returns:
            ``(None, inject_msg)`` for normal skills; the caller appends
            ``inject_msg`` to the conversation. ``(full_response, None)`` for
            ``fork`` skills, which run in an isolated message list for up to
            8 rounds. ``(None, None)`` for an unknown skill name.
            ``messages`` is currently unused.
        """
        skill = self.skills.get(skill_name)
        if not skill:
            console.print(f"[red]❌ 未知 skill: {skill_name}[/]")
            console.print(self.list_skills())
            return None, None
        console.print(f"\n[bold magenta]⚡ Skill: {skill['name']}[/] — {skill.get('description','')}")
        arg_values = args_dict or {}
        skill_prompt = skill["prompt"]
        for arg_def in skill.get("arguments", []):
            arg_name = arg_def["name"]
            arg_val = arg_values.get(arg_name, arg_def.get("default", ""))
            skill_prompt = skill_prompt.replace(f"{{{arg_name}}}", str(arg_val))
        if skill.get("fork"):
            return self._run_forked(skill, skill_name, skill_prompt, args_dict, model, tools, console), None
        inject_msg = f"[Skill: {skill_name}]\n\n{skill_prompt}\n\nUser arguments: {json.dumps(args_dict, ensure_ascii=False)}"
        return None, inject_msg  # the main loop injects this into the current context

    def _run_forked(self, skill, skill_name, skill_prompt, args_dict, model, tools, console):
        """Run a fork-mode skill in its own message list; return the concatenated replies."""
        from rich.markdown import Markdown
        console.print(f" [dim](fork mode — 隔離 context)[/]")
        fork_messages = [
            {"role": "system", "content": skill_prompt},
            {"role": "user", "content": json.dumps(args_dict, ensure_ascii=False)},
        ]
        allowed = skill.get("tools")
        skill_hooks = skill.get("hooks") or {}
        full_response = ""
        for rnd in range(8):
            with console.status(f"[magenta]{skill_name} round {rnd+1}..."):
                try:
                    response = model.chat(fork_messages)
                except Exception as e:
                    console.print(f"[red]❌ {skill_name}: {e}[/]")
                    break
            tcalls = parse_tool_calls(response)
            text = TOOL_PATTERN.sub("", response).strip()
            if text:
                console.print(Markdown(text))
            full_response += response + "\n"
            if not tcalls:
                break
            fork_messages.append({"role": "assistant", "content": response})
            results = []
            for call in tcalls:
                tool_name = call["tool"]
                # Tool permission filter.
                if allowed and tool_name not in allowed:
                    results.append(f"[{tool_name}] ❌ 此 skill 不允許")
                    continue
                result = execute_tool(tools, call)
                console.print(f" [dim]🔧 {tool_name}[/]")
                results.append(f"[{tool_name}] {result}")
                # Fire the skill's own post-edit hooks.
                if tool_name in ("edit_file", "write_file"):
                    self._run_post_edit_hook(skill_hooks.get(f"post_{tool_name}"),
                                             (call.get("params") or {}).get("path", ""),
                                             tools.project_dir)
            fork_messages.append({"role": "user", "content": "Tool results:\n" + "\n\n".join(results)})
        return full_response

    @staticmethod
    def _run_post_edit_hook(hook_cmd, file_path, cwd):
        """Best-effort run of a skill hook with ``{file}`` replaced; failures are ignored."""
        if not hook_cmd or not file_path:
            return
        try:
            # NOTE(review): file_path is substituted unquoted into a shell command.
            subprocess.run(hook_cmd.replace("{file}", file_path), shell=True,
                           cwd=cwd, capture_output=True, timeout=30)
        except (subprocess.TimeoutExpired, OSError):
            # A slow or missing formatter must not abort the skill run.
            pass
| # ============================================================ | |
| # P1-1: APPROVAL SYSTEM(權限/審批) | |
| # ============================================================ | |
APPROVAL_MODES = {
    "auto": "全自動(只擋危險指令)",
    "auto-edit": "文件修改自動,shell 指令要確認",
    "ask": "每次工具呼叫都確認",
}
# Read-only tools never need user confirmation.
SAFE_TOOLS = {"read_file", "search_files", "list_files", "git_status"}


def check_approval(tool_name, params, approval_mode, console):
    """Decide whether a tool call may run.

    A call is auto-approved in ``auto`` mode (dangerous commands are blocked
    inside ``run_command`` itself), for any read-only tool in ``SAFE_TOOLS``,
    and for ``edit_file``/``write_file`` in ``auto-edit`` mode. Otherwise the
    call is shown to the user, who confirms it (default: yes).

    Returns:
        True to allow the call, False to reject it.
    """
    edit_is_auto = approval_mode == "auto-edit" and tool_name in ("edit_file", "write_file")
    if approval_mode == "auto" or tool_name in SAFE_TOOLS or edit_is_auto:
        return True

    from rich.prompt import Confirm
    preview = json.dumps(params, ensure_ascii=False)[:120]
    console.print(f" [yellow]⚠️ {tool_name}({preview})[/]")
    return Confirm.ask(" 允許執行?", default=True)
| # ============================================================ | |
| # P1-2: BACKGROUND TASKS(背景任務管理) | |
| # ============================================================ | |
import threading, uuid as _uuid


class BackgroundTaskManager:
    """Run long shell commands in the background without blocking the main loop.

    Each task's merged stdout/stderr is drained by a daemon reader thread.
    To bound memory for chatty commands, the buffer is trimmed back to the
    most recent ``max_output_lines`` lines whenever it exceeds twice that.
    """

    def __init__(self, max_output_lines=2000):
        # task_id -> {"process", "command", "start_time", "output_lines", "reader"}
        self._tasks = {}
        self._max_output_lines = max(1, int(max_output_lines))

    def start(self, command, cwd):
        """Start ``command`` through the shell in ``cwd``; return a 6-character task id."""
        task_id = str(_uuid.uuid4())[:6]
        proc = subprocess.Popen(
            command, shell=True, cwd=cwd,
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True, errors="replace")  # undecodable bytes must not kill the reader
        output_lines = []
        limit = self._max_output_lines

        def _reader():
            # Drain the pipe continuously so the child never blocks on a full buffer.
            with proc.stdout:
                for line in proc.stdout:
                    output_lines.append(line)
                    if len(output_lines) > 2 * limit:
                        del output_lines[:-limit]

        reader = threading.Thread(target=_reader, daemon=True)
        self._tasks[task_id] = {
            "process": proc, "command": command,
            "start_time": datetime.now(), "output_lines": output_lines,
            "reader": reader,
        }
        reader.start()
        return task_id

    @staticmethod
    def _elapsed_seconds(task):
        # total_seconds(): timedelta.seconds alone wraps around every 24 hours.
        return int((datetime.now() - task["start_time"]).total_seconds())

    def check(self, task_id):
        """Return status, exit code, elapsed seconds, last 20 output lines and command."""
        t = self._tasks.get(task_id)
        if not t:
            return {"status": "not_found"}
        running = t["process"].poll() is None
        return {
            "status": "running" if running else "done",
            "exit_code": t["process"].returncode,
            "elapsed": self._elapsed_seconds(t),
            "output": "".join(t["output_lines"][-20:]),
            "command": t["command"],
        }

    def list_tasks(self):
        """Return a one-line-per-task summary of all background tasks."""
        results = []
        for tid, t in self._tasks.items():
            running = t["process"].poll() is None
            elapsed = self._elapsed_seconds(t)
            results.append(f" {'🟢' if running else '⚫'} {tid}: {t['command'][:50]} ({elapsed}s)")
        return "\n".join(results) if results else " (無背景任務)"

    def kill(self, task_id):
        """Kill a running task; return True if it was running and has been killed.

        NOTE(review): with shell=True this kills the shell process; children
        the shell spawned may keep running.
        """
        t = self._tasks.get(task_id)
        if not t or t["process"].poll() is not None:
            return False
        t["process"].kill()
        try:
            t["process"].wait(timeout=5)  # reap the child so it does not linger as a zombie
        except subprocess.TimeoutExpired:
            pass
        return True
| # ============================================================ | |
| # P1-3: CUSTOM AGENTS(自訂代理 .codepilot/agents/*.md) | |
| # ============================================================ | |
def load_custom_agents(project_dir):
    """Load custom agents from ``<project_dir>/.codepilot/agents/*.md``.

    Each file becomes ``{"name": <file stem>, "prompt": <body>, **frontmatter}``
    and is keyed by its file stem. The optional frontmatter is a flat list of
    ``key: value`` lines between ``---`` markers; ``#`` lines are comments.
    Values written as ``[a, b]`` become lists. ``tools`` and
    ``disallowedTools`` given as a bare ``a, b`` are also split into lists, so
    permission checks use list membership instead of substring matching.
    Files that cannot be read as UTF-8 are skipped.

    Returns:
        dict mapping agent name (file stem) to its config dict.
    """
    agents_dir = os.path.join(project_dir, ".codepilot", "agents")
    agents = {}
    if not os.path.isdir(agents_dir):
        return agents
    for f in sorted(Path(agents_dir).glob("*.md")):
        try:
            content = f.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            continue  # one bad file must not hide the other agents
        config = {"name": f.stem, "prompt": content}
        if content.startswith("---"):
            parts = content.split("---", 2)
            if len(parts) >= 3:
                # Minimal YAML: flat "key: value" lines only.
                for line in parts[1].strip().split("\n"):
                    stripped = line.strip()
                    if not stripped or stripped.startswith("#") or ":" not in stripped:
                        continue
                    k, v = stripped.split(":", 1)
                    k, v = k.strip(), v.strip()
                    if v.startswith("[") and v.endswith("]"):
                        v = [x.strip().strip("'\"") for x in v[1:-1].split(",")]
                    config[k] = v
                config["prompt"] = parts[2].strip()
        for key in ("tools", "disallowedTools"):
            value = config.get(key)
            if isinstance(value, str):
                config[key] = [x.strip().strip("'\"") for x in value.split(",") if x.strip()]
        agents[f.stem] = config
    return agents
def run_custom_agent(agent_config, user_task, model, tools, console):
    """Run a custom agent in its own isolated message list.

    The agent gets at most 5 rounds. Each round asks the model, prints the
    first 300 characters of its prose, executes any tool calls and feeds the
    results back. Tool calls must pass the agent's ``tools`` allow-list (if
    non-empty) and ``disallowedTools`` deny-list. Both may be lists or
    comma-separated strings.

    Returns:
        The concatenated raw model responses.
    """
    def _as_list(value):
        # A scalar "a, b" must become a list; membership against a str would
        # otherwise be a substring test ("run" in "run_command").
        if not value:
            return []
        if isinstance(value, str):
            return [x.strip().strip("'\"") for x in value.split(",") if x.strip()]
        return list(value)

    name = agent_config["name"]
    prompt = agent_config["prompt"]
    allowed = _as_list(agent_config.get("tools"))
    denied = _as_list(agent_config.get("disallowedTools"))
    console.print(f"\n[bold magenta]🤖 Agent: {name}[/]")
    agent_messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": user_task},
    ]
    full_response = ""
    for rnd in range(5):  # sub-agents get at most 5 rounds
        with console.status(f"[magenta]{name} 思考中 (round {rnd+1})..."):
            try:
                response = model.chat(agent_messages)
            except Exception as e:
                console.print(f"[red]❌ {e}[/]")
                break
        tool_calls = parse_tool_calls(response)
        text_parts = TOOL_PATTERN.sub("", response).strip()
        if text_parts:
            console.print(f" [magenta][{name}][/] {text_parts[:300]}")
        full_response += response + "\n"
        if not tool_calls:
            break
        agent_messages.append({"role": "assistant", "content": response})
        results = []
        for call in tool_calls:
            # Permission checks: allow-list first, then deny-list.
            if allowed and call["tool"] not in allowed:
                results.append(f"[{call['tool']}] ❌ 此代理不允許使用 {call['tool']}")
                continue
            if call["tool"] in denied:
                results.append(f"[{call['tool']}] ❌ 此代理禁止使用 {call['tool']}")
                continue
            result = execute_tool(tools, call)
            results.append(f"[{call['tool']}] {result}")
        agent_messages.append({"role": "user", "content": "Tool results:\n" + "\n\n".join(results)})
    return full_response
| # ============================================================ | |
| # P1-4: AUTO GIT COMMIT | |
| # ============================================================ | |
def auto_git_commit(tools, model, edited_files, console):
    """Stage the files edited this session and commit them after confirmation.

    Only the given files are staged, never ``git add -A``. The commit message
    comes from ``model`` based on ``git diff --cached --stat``, falling back to
    a generic message. Declining the confirmation unstages the files again.
    Nothing is committed when the staged diff is empty.
    """
    if not edited_files:
        console.print("[dim]沒有修改的文件[/]")
        return
    project_dir = tools.project_dir
    rel_files = []
    for f in edited_files:
        try:
            rel_files.append(os.path.relpath(f, project_dir))
        except ValueError:
            continue  # e.g. a path on a different drive than project_dir (Windows)
    rel_files = list(dict.fromkeys(rel_files))  # de-duplicate, keep first-seen order
    if not rel_files:
        return
    console.print(f" [dim]📁 Stage: {', '.join(rel_files[:5])}{'...' if len(rel_files)>5 else ''}[/]")
    for f in rel_files:
        # "--" so a file name can never be mistaken for an option or revision.
        subprocess.run(["git", "add", "--", f], cwd=project_dir, capture_output=True)

    diff = subprocess.run(["git", "diff", "--cached", "--stat"], cwd=project_dir,
                          capture_output=True, text=True, errors="replace").stdout
    if not diff.strip():
        console.print(" [dim]沒有可提交的變更[/]")
        return

    commit_msg = ""
    with console.status("[dim]生成 commit message..."):
        msg_prompt = f"Generate a concise git commit message (1 line, max 72 chars) for:\n\n{diff[:2000]}"
        try:
            reply = model.chat([{"role": "user", "content": msg_prompt}], max_tokens=100)
            # First real line of the reply, skipping blank lines and ``` fences.
            candidates = [line.strip() for line in reply.strip().splitlines()
                          if line.strip() and not line.strip().startswith("```")]
            if candidates:
                commit_msg = candidates[0].strip('"').strip("'")
        except Exception:
            commit_msg = ""
    if not commit_msg:
        commit_msg = f"codepilot: update {len(rel_files)} file(s)"
    elif len(commit_msg) > 72:
        commit_msg = commit_msg[:69] + "..."
    console.print(f" [dim]💬 {commit_msg}[/]")

    from rich.prompt import Confirm
    if Confirm.ask(" Commit?", default=True):
        result = subprocess.run(["git", "commit", "-m", commit_msg],
                                cwd=project_dir, capture_output=True, text=True, errors="replace")
        if result.returncode == 0:
            console.print(f" [green]✅ Committed[/]")
        else:
            console.print(f" [red]❌ {(result.stderr or result.stdout)[:200]}[/]")
    else:
        # Unstage again; without an explicit HEAD this also works on an unborn branch.
        subprocess.run(["git", "reset", "-q", "--"] + rel_files,
                       cwd=project_dir, capture_output=True)
        console.print(" [dim]已取消[/]")
def build_system_prompt(tools, project_memory=""):
    """Assemble the CodePilot system prompt.

    The prompt contains the working directory (``tools.cwd``), the output of
    ``tools.git_context()``, an optional "Project Memory" section holding
    ``project_memory``, the tool-call syntax reference and the editing rules.
    """
    if project_memory:
        memory_block = f"\n\n## Project Memory (CODEPILOT.md)\n{project_memory}"
    else:
        memory_block = ""
    intro = (
        "You are CodePilot, an expert AI programming assistant working in the user's project.\n"
        f"Working directory: {tools.cwd}\n"
        f"{tools.git_context()}{memory_block}\n"
    )
    tool_reference = "\n".join([
        "## Tools (use <tool>name\n{json}</tool>)",
        '- read_file: {"path":"...","offset":1,"limit":200}',
        '- edit_file: {"path":"...","old_string":"...","new_string":"..."} (must read first)',
        '- write_file: {"path":"...","content":"..."}',
        '- run_command: {"command":"...","timeout":120}',
        '- search_files: {"pattern":"...","glob":"*.py"}',
        '- list_files: {"pattern":"*","max_depth":3}',
        "- git_status: {}",
        "Rules: read before edit, old_string must be unique, prefer edit over write, verify changes.",
    ])
    return intro + tool_reference
| # ============================================================ | |
| # MODEL FACTORY | |
| # ============================================================ | |
def _create_model(provider_key, args, console=None):
    """Build the chat model for ``provider_key`` (shared model factory).

    * ``local``: ``LocalModel`` using ``args.model`` (or DEFAULT_LOCAL_MODEL)
      with ``args.adapter``;
    * ``codex``: ``CodexModel`` using ``args.cloud_model`` or the provider default;
    * any other key in PROVIDER_CONFIGS: ``CloudModel``, which needs ``args.api_key``.

    ``console`` is unused; kept for interface compatibility.

    Raises:
        ValueError: unknown provider, or a cloud provider without ``--api-key``.
    """
    if provider_key == "local":
        return LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
    config = PROVIDER_CONFIGS.get(provider_key)
    if config is None:
        known = ", ".join(sorted(PROVIDER_CONFIGS))
        raise ValueError(f"未知的 provider: {provider_key}(可用: {known})")
    if provider_key == "codex":
        return CodexModel(args.cloud_model or config["default_model"])
    if not args.api_key:
        raise ValueError(f"使用 {provider_key} 需要 --api-key")
    model_name = args.cloud_model or config["default_model"]
    return CloudModel(provider_key, args.api_key, model_name)
| # ============================================================ | |
| # LEETCODE AUTO-GRIND | |
| # ============================================================ | |
def run_grind(args, num_problems=100):
    """Solve random KodCode problems unattended and record training data.

    For each of ``num_problems`` shuffled problems, the model's answer is
    written to ``solution.py`` and the problem's tests to ``test_solution.py``
    in a temporary directory, then run under pytest with a 15s timeout.

    * Passing answers are saved as positive feedback.
    * Failing answers are saved as negative feedback, and the reference
      solution is saved as a positive example.
    * Generation errors are counted but not recorded.

    Ctrl-C stops the run early and still prints the summary.
    """
    import tempfile
    from rich.console import Console
    from rich.progress import Progress
    console = Console()
    db = FeedbackDB()
    console.print(f"""
╔════════════════════════════════════════════════════════════╗
║ 🏋️ LeetCode Auto-Grind ║
║ 自動刷題,無人值守產生訓練數據 ║
╚════════════════════════════════════════════════════════════╝
""")
    # Load the model.
    provider_key = args.provider or "local"
    model = _create_model(provider_key, args)
    console.print(f"[green]✅ 模型: {model.name}[/]")
    source_provider = getattr(model, "provider", provider_key)
    # Load the KodCode problem set.
    console.print("📦 載入 KodCode 題庫...")
    from datasets import load_dataset
    dataset = load_dataset("KodCode/KodCode-V1", split="train")
    dataset = dataset.shuffle(seed=int(time.time()) % 10000).select(range(min(num_problems, len(dataset))))
    console.print(f" {len(dataset)} 題已載入\n")

    # A closed ``` fence tagged python / python3 / py (or untagged).
    fence_re = re.compile(r"```[ \t]*(?:python3?|py)?[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE)
    passed = failed = errors = 0
    try:
        with Progress() as progress:
            task = progress.add_task("[cyan]刷題中...", total=len(dataset))
            for problem in dataset:
                question = problem["question"]
                test_code = problem["test"]
                solution_ref = problem["solution"]
                prompt = f"Write a Python solution. Provide ONLY the code, no explanation.\n\n{question}"
                messages = [
                    {"role": "system", "content": "You are an expert Python programmer. Output only clean Python code."},
                    {"role": "user", "content": prompt},
                ]
                # Generate an answer.
                try:
                    response = model.chat(messages, max_tokens=1024)
                except Exception:
                    errors += 1
                    progress.update(task, advance=1)
                    continue

                # Extract the code.
                match = fence_re.search(response)
                if match:
                    code = match.group(1)
                elif "```python" in response:
                    # Unterminated fence (e.g. reply cut off at max_tokens).
                    code = response.split("```python")[1].split("```")[0]
                elif "```" in response:
                    code = response.split("```")[1].split("```")[0]
                    first, _, rest = code.partition("\n")
                    if first.strip().lower() in ("py", "python", "python3"):
                        code = rest  # drop a language tag left by an unterminated fence
                else:
                    code = response

                # Run the tests.
                solved = False
                try:
                    with tempfile.TemporaryDirectory() as tmpdir:
                        Path(tmpdir, "solution.py").write_text(code, encoding="utf-8")
                        Path(tmpdir, "test_solution.py").write_text(test_code, encoding="utf-8")
                        r = subprocess.run(
                            [sys.executable, "-m", "pytest", "test_solution.py", "-x", "--tb=no", "-q"],
                            cwd=tmpdir, capture_output=True, timeout=15)
                        solved = r.returncode == 0
                except (subprocess.TimeoutExpired, OSError, ValueError):
                    # Hanging, unwritable or unrunnable solution: counts as a failure.
                    # KeyboardInterrupt is deliberately not caught here.
                    pass

                # Record the data.
                if solved:
                    passed += 1
                    # Passed the tests -> good answer (SFT + KTO positive).
                    db.save(prompt, code, 1, source_model=model.name, provider=source_provider)
                else:
                    failed += 1
                    # Failed -> bad answer, plus the reference solution as a good one.
                    db.save(prompt, code, 0, source_model=model.name, provider=source_provider)
                    if solution_ref:
                        db.save(prompt, solution_ref, 1, source_model="ground_truth",
                                provider="reference")
                progress.update(task, advance=1,
                                description=f"[cyan]刷題中... ✅{passed} ❌{failed}")
    except KeyboardInterrupt:
        console.print("\n[yellow]⏹ 已中斷,統計目前結果[/]")

    # Summary.
    total = passed + failed + errors
    console.print(f"\n{'='*50}")
    console.print(f" 🏋️ 刷題完成!")
    console.print(f" ✅ 通過: {passed}/{total} ({100*passed/max(total,1):.0f}%)")
    console.print(f" ❌ 失敗: {failed}/{total}")
    console.print(f" ⚠️ 錯誤: {errors}")
    console.print(f"\n 📊 數據統計:")
    s = db.count()
    console.print(f" 總數據: {s['total']}")
    console.print(f" 👍: {s['up']} / 👎: {s['total']-s['up']}")
    console.print(f"\n 💡 運行 codepilot --train 開始訓練")
| # ============================================================ | |
| # MAIN AGENT LOOP | |
| # ============================================================ | |
| def run_agent_loop(args): | |
| from rich.console import Console, Group | |
| from rich.markdown import Markdown | |
| from rich.panel import Panel | |
| from rich.prompt import Prompt | |
| from rich.syntax import Syntax | |
| from rich.table import Table | |
| console = Console(); db = FeedbackDB() | |
| project_dir = args.project or os.getcwd() | |
| tools = ProjectTools(project_dir) | |
| ctx = ProjectContext(project_dir) | |
| provider_key = args.provider or "local" | |
| # 載入模型(支援 local, cloud API, codex, ollama) | |
| local_model_ref = None; cloud_model_ref = None | |
| try: | |
| if provider_key == "local": | |
| with console.status("[bold green]載入本地模型..."): | |
| model = _create_model(provider_key, args) | |
| local_model_ref = model | |
| elif provider_key == "codex": | |
| with console.status("[bold green]連接 OpenAI Codex..."): | |
| model = _create_model(provider_key, args) | |
| cloud_model_ref = model | |
| console.print(f"[green]✅ Codex ({model.name})[/]") | |
| else: | |
| model = _create_model(provider_key, args) | |
| cloud_model_ref = model | |
| except Exception as e: | |
| console.print(f"[red]❌ 模型載入失敗: {e}[/]"); sys.exit(1) | |
| if args.adapter and provider_key != "local": | |
| try: | |
| with console.status("[dim]載入本地模型 (for duel)..."): | |
| local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter) | |
| console.print("[dim]✅ 本地模型已載入[/]") | |
| except: pass | |
| # Duel 模式開關 | |
| duel_mode = args.duel and local_model_ref and cloud_model_ref | |
| # 專案記憶(四層) | |
| instructions = ctx.load_all_instructions() | |
| memory = ctx.load_memory() | |
| # Banner | |
| banner = f"[bold cyan]CodePilot v4[/]" | |
| if duel_mode: banner += " [bold yellow]⚔️ Duel ON[/]" | |
| banner += f"\n[dim]Model: {model.name}\nProject: {project_dir}[/]" | |
| if instructions: banner += f"\n[dim]📋 CODEPILOT.md loaded[/]" | |
| if memory: banner += f"\n[dim]🧠 MEMORY.md loaded ({len(memory)} chars)[/]" | |
| if MEMORY_MODULE_AVAILABLE: banner += f"\n[dim]💾 Session JSONL + Auto-compact enabled[/]" | |
| console.print(Panel.fit(banner, border_style="cyan")) | |
| git_ctx = tools.git_context() | |
| if git_ctx != "(not a git repo)": console.print(Panel(git_ctx, title="📂 Project", border_style="dim")) | |
| # 嘗試恢復上次對話 | |
| git_ctx = tools.git_context() | |
| system_prompt = ctx.build_system_prompt(git_ctx) | |
| prev_session = ctx.load_session() | |
| if prev_session and len(prev_session) > 1: | |
| messages = prev_session | |
| # 更新 system prompt | |
| messages[0] = {"role": "system", "content": system_prompt} | |
| console.print(f"[dim]🔄 已恢復上次對話 ({(len(messages)-1)//2} 輪)[/]") | |
| else: | |
| messages = [{"role": "system", "content": system_prompt}] | |
| hooks = Hooks(project_dir) | |
| bg_tasks = BackgroundTaskManager() | |
| custom_agents = load_custom_agents(project_dir) | |
| mcp = MCPLite(project_dir) | |
| skill_mgr = SkillManager(project_dir) # Skill 系統 | |
| approval_mode = args.approval or "auto" | |
| use_streaming = args.stream and provider_key == "local" | |
| edited_files_this_session = [] | |
| if custom_agents: | |
| console.print(f"[dim]🤖 自訂代理: {', '.join(custom_agents.keys())}[/]") | |
| console.print(f"[dim]指令: /init /verify /commit /agent /bg /approval /web /mcp /stream | /duel /memo /grind /ls /git /clear /status /train /quit[/]\n") | |
| while True: | |
| try: user_input = Prompt.ask("\n[bold green]🧑 You") | |
| except (EOFError, KeyboardInterrupt): break | |
| if not user_input.strip(): continue | |
| cmd = user_input.strip() | |
| # ---- 指令 ---- | |
| if cmd in ("/quit", "/exit"): break | |
| elif cmd == "/init": | |
| result = cmd_init(tools, model, console) | |
| # 重建 system prompt | |
| system_prompt = ctx.build_system_prompt(tools.git_context()) | |
| messages[0] = {"role": "system", "content": system_prompt} | |
| continue | |
| elif cmd == "/verify": | |
| run_verification(model, tools, console, edited_files_this_session) | |
| continue | |
| elif cmd == "/duel on": | |
| if local_model_ref and cloud_model_ref: | |
| duel_mode = True; console.print("[yellow]⚔️ Duel 模式已開啟 — 每個問題自動雙模型比較[/]") | |
| else: | |
| console.print("[red]需要同時有本地和雲端模型。啟動: codepilot --duel --provider openrouter --api-key xxx --adapter ./adapter[/]") | |
| continue | |
| elif cmd == "/duel off": | |
| duel_mode = False; console.print("[dim]Duel 模式已關閉[/]"); continue | |
| elif cmd == "/memo" or cmd.startswith("/memo "): | |
| # /memo → 編輯 CODEPILOT.md 指令 | |
| # /memo + 文字 → 快速追加到 MEMORY.md | |
| quick_note = cmd[5:].strip() if cmd.startswith("/memo ") else "" | |
| if quick_note: | |
| ctx.save_memory_entry(quick_note) | |
| console.print(f"[green]🧠 已追加到 MEMORY.md: {quick_note}[/]") | |
| else: | |
| console.print(f"[bold]📋 CODEPILOT.md[/] — 專案指令(提交到 repo)") | |
| console.print(f"[bold]🧠 MEMORY.md[/] — 自動記憶(跨 session)\n") | |
| console.print("[dim]快速追加: /memo 這是一條記憶[/]") | |
| console.print("[dim]編輯指令: 輸入內容(END 結束)[/]") | |
| cur = ctx.load_all_instructions() | |
| if cur: console.print(f"[dim]目前 CODEPILOT.md:\n{cur[:300]}...[/]\n") | |
| cur_mem = ctx.load_memory() | |
| if cur_mem: console.print(f"[dim]目前 MEMORY.md:\n{cur_mem[:300]}...[/]\n") | |
| console.print("選擇: [cyan]1[/]=編輯 CODEPILOT.md [cyan]2[/]=編輯 MEMORY.md Enter=取消") | |
| choice = Prompt.ask(" ", choices=["1","2",""], default="", show_choices=False) | |
| if choice in ("1", "2"): | |
| console.print("輸入內容(END 結束):") | |
| edit_lines = [] | |
| while True: | |
| try: | |
| l = input() | |
| if l.strip() == "END": break | |
| edit_lines.append(l) | |
| except EOFError: break | |
| if edit_lines: | |
| content = "\n".join(edit_lines) | |
| if choice == "1": | |
| codepilot_md = os.path.join(project_dir, "CODEPILOT.md") | |
| Path(codepilot_md).write_text(content, encoding="utf-8") | |
| console.print(f"[green]✅ CODEPILOT.md 已保存[/]") | |
| else: | |
| if MEMORY_MODULE_AVAILABLE: | |
| save_memory(project_dir, content) | |
| console.print(f"[green]✅ MEMORY.md 已保存[/]") | |
| # 重建 system prompt | |
| system_prompt = ctx.build_system_prompt(tools.git_context()) | |
| messages[0] = {"role": "system", "content": system_prompt} | |
| continue | |
| elif cmd == "/grind": | |
| n = Prompt.ask("刷幾題?", default="50") | |
| run_grind(args, int(n)); continue | |
| elif cmd == "/commit": | |
| # P1-4: 自動 git commit | |
| auto_git_commit(tools, model, edited_files_this_session, console) | |
| continue | |
| elif cmd.startswith("/skill"): | |
| # Skill 系統 | |
| parts = cmd.split(None, 2) | |
| if len(parts) < 2 or parts[1] == "list": | |
| console.print(skill_mgr.list_skills()) | |
| continue | |
| skill_name = parts[1] | |
| # 收集參數 | |
| skill_def = skill_mgr.skills.get(skill_name) | |
| skill_args = {} | |
| if skill_def: | |
| # 如果指令裡有第三段,用它作為第一個參數 | |
| if len(parts) > 2 and skill_def.get("arguments"): | |
| skill_args[skill_def["arguments"][0]["name"]] = parts[2] | |
| else: | |
| for arg_def in skill_def.get("arguments", []): | |
| default = arg_def.get("default", "") | |
| val = Prompt.ask(f" {arg_def['name']} ({arg_def.get('description','')})", default=default) | |
| if val: skill_args[arg_def["name"]] = val | |
| result, inject = skill_mgr.invoke(skill_name, skill_args, model, tools, console, messages) | |
| if inject: | |
| # 注入模式:加入當前對話 | |
| messages.append({"role": "user", "content": inject}) | |
| with console.status("[bold cyan]執行 skill..."): | |
| response = chat_with_recovery(model, messages, ctx=ctx, console=console) | |
| console.print(f"\n[bold blue]🤖 CodePilot:[/]") | |
| from rich.markdown import Markdown as _Md | |
| console.print(_Md(TOOL_PATTERN.sub("", response).strip())) | |
| messages.append({"role": "assistant", "content": response}) | |
| # 處理工具呼叫 | |
| tcalls = parse_tool_calls(response) | |
| if tcalls: | |
| results = [] | |
| for call in tcalls: | |
| console.print(f" [dim]🔧 {call['tool']}[/]") | |
| r = execute_tool(tools, call) | |
| results.append(f"[{call['tool']}] {r}") | |
| messages.append({"role": "user", "content": "Tool results:\n" + "\n\n".join(results)}) | |
| continue | |
| elif cmd.startswith("/agent"): | |
| # P1-3: 自訂代理 | |
| parts = cmd.split(None, 2) | |
| if len(parts) < 2: | |
| console.print("[bold]可用代理:[/]") | |
| if custom_agents: | |
| for name, cfg in custom_agents.items(): | |
| desc = cfg.get("description", "") | |
| console.print(f" 🤖 {name}: {desc}") | |
| console.print(f"\n[dim]用法: /agent <名稱> <任務>[/]") | |
| else: | |
| console.print("[dim]無自訂代理。建立 .codepilot/agents/*.md[/]") | |
| console.print("[dim]範例: .codepilot/agents/reviewer.md[/]") | |
| continue | |
| agent_name = parts[1] | |
| agent_task = parts[2] if len(parts) > 2 else Prompt.ask("任務") | |
| if agent_name in custom_agents: | |
| result = run_custom_agent(custom_agents[agent_name], agent_task, model, tools, console) | |
| elif agent_name == "explore": | |
| # 內建 Explore agent(只讀) | |
| result = run_custom_agent( | |
| {"name": "explore", "prompt": "You are an exploration agent. Read and search files to investigate. NEVER modify or create files.", | |
| "tools": ["read_file", "search_files", "list_files", "git_status"]}, | |
| agent_task, model, tools, console) | |
| elif agent_name == "plan": | |
| # 內建 Plan agent | |
| result = run_custom_agent( | |
| {"name": "plan", "prompt": "You are a planning agent. Analyze the task and create a detailed step-by-step plan. Do NOT execute any changes.", | |
| "tools": ["read_file", "search_files", "list_files", "git_status"]}, | |
| agent_task, model, tools, console) | |
| else: | |
| console.print(f"[red]未知代理: {agent_name}[/]") | |
| console.print(f"[dim]可用: {', '.join(list(custom_agents.keys()) + ['explore', 'plan'])}[/]") | |
| continue | |
| elif cmd.startswith("/bg"): | |
| # P1-2: 背景任務 | |
| parts = cmd.split(None, 1) | |
| if len(parts) < 2 or parts[1] == "list": | |
| console.print(bg_tasks.list_tasks()) | |
| elif parts[1].startswith("run "): | |
| bg_cmd = parts[1][4:] | |
| tid = bg_tasks.start(bg_cmd, tools.cwd) | |
| console.print(f" [green]🚀 背景任務 {tid}: {bg_cmd}[/]") | |
| elif parts[1].startswith("check "): | |
| tid = parts[1][6:].strip() | |
| info = bg_tasks.check(tid) | |
| console.print(f" 狀態: {info['status']} | 耗時: {info.get('elapsed',0)}s") | |
| if info.get("output"): console.print(Panel(info["output"][:500], title=f"bg:{tid}", border_style="dim")) | |
| elif parts[1].startswith("kill "): | |
| tid = parts[1][5:].strip() | |
| if bg_tasks.kill(tid): console.print(f" [red]⛔ 已終止 {tid}[/]") | |
| else: console.print(f" [dim]任務不存在或已結束[/]") | |
| else: | |
| console.print("[dim]/bg list | /bg run <cmd> | /bg check <id> | /bg kill <id>[/]") | |
| continue | |
| elif cmd.startswith("/web "): | |
| # P2-1: 快速網頁搜尋/抓取 | |
| query = cmd[5:].strip() | |
| if query.startswith("http"): | |
| console.print(f"[dim]🌐 抓取 {query}...[/]") | |
| result = web_fetch(query) | |
| else: | |
| console.print(f"[dim]🔍 搜尋: {query}...[/]") | |
| result = web_search(query) | |
| console.print(result[:2000]) | |
| continue | |
| elif cmd.startswith("/mcp"): | |
| # P2-5: MCP 伺服器管理 | |
| parts = cmd.split(None, 3) | |
| if len(parts) < 2 or parts[1] == "list": | |
| console.print(f"[bold]🔌 MCP 伺服器[/]") | |
| console.print(mcp.list_servers()) | |
| elif len(parts) >= 3: | |
| server = parts[1] | |
| method = parts[2] | |
| params = json.loads(parts[3]) if len(parts) > 3 else {} | |
| console.print(f"[dim]🔌 {server}.{method}...[/]") | |
| result = mcp.call(server, method, params) | |
| console.print(result[:1000]) | |
| else: | |
| console.print("[dim]/mcp list | /mcp <server> <method> [json_params][/]") | |
| continue | |
| elif cmd == "/stream on": | |
| use_streaming = (provider_key == "local") | |
| console.print(f"[green]{'✅ Streaming ON' if use_streaming else '❌ Streaming 只支援本地模型'}[/]") | |
| continue | |
| elif cmd == "/stream off": | |
| use_streaming = False; console.print("[dim]Streaming OFF[/]"); continue | |
| elif cmd.startswith("/approval"): | |
| # P1-1: 切換審批模式 | |
| parts = cmd.split() | |
| if len(parts) > 1 and parts[1] in APPROVAL_MODES: | |
| approval_mode = parts[1] | |
| console.print(f" [green]審批模式: {approval_mode} — {APPROVAL_MODES[approval_mode]}[/]") | |
| else: | |
| console.print(f" 目前: [bold]{approval_mode}[/] — {APPROVAL_MODES.get(approval_mode,'')}") | |
| for k, v in APPROVAL_MODES.items(): | |
| marker = "→" if k == approval_mode else " " | |
| console.print(f" {marker} /approval {k}: {v}") | |
| continue | |
| elif cmd == "/status": | |
| s = db.count() | |
| t = Table(title="📊 統計"); t.add_column("", style="cyan"); t.add_column("", style="green") | |
| t.add_row("Total", str(s["total"])); t.add_row("👍", str(s["up"])) | |
| t.add_row("👎", str(s["total"]-s["up"])); t.add_row("✏️", str(s["edits"])) | |
| t.add_row("DPO 對", str(len(db.export_dpo()))) | |
| t.add_row("Duel", "⚔️ ON" if duel_mode else "OFF") | |
| t.add_row("記憶", f"{len(project_memory)} chars" if project_memory else "無") | |
| t.add_row("對話輪數", str((len(messages)-1)//2)) | |
| console.print(t); continue | |
| elif cmd == "/train": trigger_training(db, console, args); continue | |
| elif cmd == "/clear": | |
| messages = [{"role": "system", "content": system_prompt}] | |
| ctx.save_session(messages); console.print("[dim]已清除[/]"); continue | |
| elif cmd == "/git": console.print(Panel(tools.git_context(), title="Git", border_style="dim")); continue | |
| elif cmd.startswith("/ls"): console.print(tools.list_files(cmd[3:].strip() or "*")); continue | |
| elif cmd == "/switch": | |
| new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys())) | |
| if new_p == "local": | |
| with console.status("載入..."): model = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter) | |
| local_model_ref = model; provider_key = "local" | |
| else: | |
| key = args.api_key or Prompt.ask("API Key") | |
| cm = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"]) | |
| model = CloudModel(new_p, key, cm); cloud_model_ref = model; provider_key = new_p | |
| console.print(f"[green]✅ {provider_key}[/]"); continue | |
| # ---- Duel 模式:自動雙模型比較 ---- | |
| if duel_mode and local_model_ref and cloud_model_ref: | |
| compare_msgs = list(messages) + [{"role": "user", "content": user_input}] | |
| with console.status("[bold cyan]🏠 本地模型..."): | |
| try: local_resp = local_model_ref.chat(compare_msgs) | |
| except Exception as e: local_resp = f"(錯誤: {e})" | |
| with console.status("[bold magenta]☁️ 雲端模型..."): | |
| try: cloud_resp = cloud_model_ref.chat(compare_msgs) | |
| except Exception as e: cloud_resp = f"(錯誤: {e})" | |
| console.print(Panel(Markdown(local_resp), title=f"🏠 {local_model_ref.name}", border_style="blue")) | |
| console.print(Panel(Markdown(cloud_resp), title=f"☁️ {cloud_model_ref.name}", border_style="magenta")) | |
| console.print(f"[dim][green]1[/]=🏠本地 [magenta]2[/]=☁️雲端 [yellow]b[/]=都好 [red]x[/]=都差 Enter=跳過[/]") | |
| choice = Prompt.ask(" ", choices=["1","2","b","x",""], default="", show_choices=False) | |
| if choice == "2": | |
| db.save(user_input, cloud_resp, 1, source_model=cloud_model_ref.name, provider=cloud_model_ref.provider) | |
| db.save(user_input, local_resp, 0, source_model=local_model_ref.name, provider="local") | |
| console.print(f" [magenta]☁️ 雲端勝 → DPO +1 ({len(db.export_dpo())} 對)[/]") | |
| messages.append({"role": "user", "content": user_input}) | |
| messages.append({"role": "assistant", "content": cloud_resp}) | |
| elif choice == "1": | |
| db.save(user_input, local_resp, 1, source_model=local_model_ref.name, provider="local") | |
| db.save(user_input, cloud_resp, 0, source_model=cloud_model_ref.name, provider=cloud_model_ref.provider) | |
| console.print(f" [green]🏠 本地勝![/]") | |
| messages.append({"role": "user", "content": user_input}) | |
| messages.append({"role": "assistant", "content": local_resp}) | |
| elif choice == "b": | |
| db.save(user_input, local_resp, 1, source_model=local_model_ref.name, provider="local") | |
| db.save(user_input, cloud_resp, 1, source_model=cloud_model_ref.name, provider=cloud_model_ref.provider) | |
| console.print(f" [yellow]👍 都好[/]") | |
| messages.append({"role": "user", "content": user_input}) | |
| messages.append({"role": "assistant", "content": cloud_resp}) | |
| elif choice == "x": | |
| db.save(user_input, local_resp, 0, source_model=local_model_ref.name, provider="local") | |
| db.save(user_input, cloud_resp, 0, source_model=cloud_model_ref.name, provider=cloud_model_ref.provider) | |
| console.print(f" [red]都差[/]") | |
| else: | |
| messages.append({"role": "user", "content": user_input}) | |
| messages.append({"role": "assistant", "content": cloud_resp}) | |
| ctx.save_session(messages) | |
| continue | |
| # ---- 正常模式:單模型 + 工具循環 + 錯誤恢復 ---- | |
| messages.append({"role": "user", "content": user_input}) | |
| full_response = "" | |
| tools_used_this_turn = [] # 追蹤這輪用了哪些工具 | |
| for rnd in range(10): | |
| try: | |
| if use_streaming and rnd == 0 and provider_key == "local": | |
| # P2-2: Streaming 輸出(第一輪,本地模型) | |
| response = stream_local_chat(model, messages, console) | |
| else: | |
| with console.status(f"[bold cyan]{'思考中' if rnd == 0 else f'工具 round {rnd+1}'}..."): | |
| response = chat_with_recovery( | |
| model, messages, ctx=ctx, console=console, | |
| fallback_model=local_model_ref if provider_key != "local" else None) | |
| except Exception as e: | |
| console.print(f"[red]❌ 所有重試失敗: {e}[/]") | |
| break | |
| tool_calls = parse_tool_calls(response) | |
| text_parts = TOOL_PATTERN.sub("", response).strip() | |
| if text_parts and not (use_streaming and rnd == 0): | |
| # streaming 模式已經顯示過了,不重複 | |
| console.print(f"\n[bold blue]🤖 CodePilot:[/]") | |
| console.print(Markdown(text_parts)) | |
| full_response += response + "\n" | |
| if not tool_calls: break | |
| messages.append({"role": "assistant", "content": response}) | |
| results = [] | |
| for call in tool_calls: | |
| console.print(f" [dim]🔧 {call['tool']}[/]") | |
| # P1-1: 審批檢查 | |
| if not check_approval(call["tool"], call["params"], approval_mode, console): | |
| results.append(f"[{call['tool']}] ⛔ 用戶拒絕執行") | |
| continue | |
| result = execute_tool(tools, call) # 已含 P0-2 截斷 | |
| tools_used_this_turn.append(call["tool"]) | |
| # 追蹤修改的文件 | |
| if call["tool"] in ("edit_file", "write_file") and "✅" in result: | |
| fpath = call["params"].get("path", "") | |
| if fpath: | |
| full_path = os.path.join(tools.cwd, fpath) if not os.path.isabs(fpath) else fpath | |
| if full_path not in edited_files_this_session: | |
| edited_files_this_session.append(full_path) | |
| # P0-Bonus: 觸發 post-edit hook | |
| hook_result = hooks.run(f"post_{call['tool']}", {"file": full_path}) | |
| if hook_result: | |
| console.print(f" [dim]🪝 Hook: {hook_result[:100]}[/]") | |
| # 顯示結果 | |
| if call["tool"] == "edit_file" and "✅" in result: | |
| d = result.split("\n", 1)[1] if "\n" in result else "" | |
| if d: console.print(Syntax(d, "diff", theme="monokai")) | |
| elif call["tool"] == "run_command": | |
| console.print(Panel(result[:500], title="Terminal", border_style="dim")) | |
| else: console.print(f" [dim]{result[:200]}[/]") | |
| results.append(f"[{call['tool']}] {result}") | |
| messages.append({"role": "user", "content": "Tool results:\n" + "\n\n".join(results)}) | |
| # P0-4: 自動驗證 — 如果這輪有修改文件,自動跑測試 | |
| if any(t in ("edit_file", "write_file") for t in tools_used_this_turn): | |
| if edited_files_this_session: | |
| console.print(f"\n[dim]🔍 Auto-verify ({len(edited_files_this_session)} files modified)...[/]") | |
| run_verification(model, tools, console, edited_files_this_session) | |
| # 回饋 | |
| console.print(f"\n[dim][green]y[/]=👍 [red]n[/]=👎 [yellow]e[/]=✏️ Enter=跳過[/]") | |
| fb = Prompt.ask(" ", choices=["y","n","e",""], default="", show_choices=False) | |
| if fb == "y": | |
| db.save(user_input, full_response, 1, source_model=getattr(model,"name",""), provider=provider_key) | |
| console.print(" [green]👍[/]") | |
| elif fb == "n": | |
| db.save(user_input, full_response, 0, source_model=getattr(model,"name",""), provider=provider_key) | |
| console.print(" [red]👎[/]") | |
| elif fb == "e": | |
| console.print(" [yellow]貼上修改版(END結束):[/]"); lines = [] | |
| while True: | |
| try: | |
| l = input() | |
| if l.strip() == "END": break | |
| lines.append(l) | |
| except EOFError: break | |
| edited = "\n".join(lines) | |
| if edited.strip(): | |
| db.save(user_input, full_response, 1, edited=edited, source_model=getattr(model,"name",""), provider=provider_key) | |
| console.print(" [yellow]✏️[/]") | |
| messages.append({"role": "assistant", "content": full_response}) | |
| # L4: 自動壓縮檢查 | |
| messages = ctx.check_compact(messages, model_chat_fn=model.chat if hasattr(model, 'chat') else None) | |
| ctx.save_session(messages) | |
| mcp.cleanup() | |
| console.print("\n[cyan]👋[/]") | |
| # ============================================================ | |
| # TRAINING | |
| # ============================================================ | |
def trigger_training(db, console, args):
    """Fine-tune a 4-bit QLoRA adapter on the collected feedback data.

    Training rows: SFT rows from ``db.export_sft(only_cloud=True)`` ("distilled")
    are used when any exist, otherwise every row from ``db.export_sft()``.
    DPO pairs are counted and reported but are not trained on here.

    Args:
        db: FeedbackDB-like object exposing ``count()``, ``export_sft()`` and
            ``export_dpo()``.
        console: rich ``Console`` used for progress/status output.
        args: parsed CLI namespace; ``args.model`` overrides the base model
            (falls back to ``DEFAULT_LOCAL_MODEL``).

    The adapter is written to ``CONFIG_DIR/adapter_<YYYYmmdd_HHMM>`` and the
    final message prints the ``codepilot --adapter`` command for it. Returns
    early (printing a warning) when there is no data or no SFT rows.
    """
    stats = db.count()
    if stats["total"] == 0:
        console.print("[yellow]⚠️ 無數據[/]")
        return

    cloud_sft = db.export_sft(only_cloud=True)
    all_sft = db.export_sft()
    dpo = db.export_dpo()
    console.print(f"\n[bold]🚀 數據[/] ⚗️蒸餾SFT:{len(cloud_sft)} 📊DPO:{len(dpo)} 📚全SFT:{len(all_sft)}")

    train_rows = cloud_sft or all_sft
    if not train_rows:
        # Nothing to fine-tune on (e.g. only DPO pairs / 👎 feedback). Stop here
        # instead of announcing an adapter directory that was never written,
        # and avoid importing the heavy ML stack for nothing.
        console.print("[yellow]⚠️ 無 SFT 數據,跳過訓練[/]")
        return

    # Heavy optional dependencies are imported lazily so the rest of the CLI
    # keeps working when they are not installed.
    from datasets import Dataset
    from peft import LoraConfig, prepare_model_for_kbit_training
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from trl import SFTConfig, SFTTrainer

    base_model = args.model or DEFAULT_LOCAL_MODEL
    out_dir = os.path.join(CONFIG_DIR, f"adapter_{datetime.now().strftime('%Y%m%d_%H%M')}")
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    lora_cfg = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    console.print(f"[bold]📚 {'⚗️蒸餾' if cloud_sft else ''} SFT ({len(train_rows)})...[/]")
    model = AutoModelForCausalLM.from_pretrained(
        base_model, quantization_config=bnb, device_map="auto", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token
    model = prepare_model_for_kbit_training(model)

    trainer = SFTTrainer(
        model=model,
        args=SFTConfig(
            output_dir=out_dir, learning_rate=2e-4, num_train_epochs=3,
            per_device_train_batch_size=1, gradient_accumulation_steps=8, max_seq_length=1024,
            gradient_checkpointing=True, bf16=True, optim="paged_adamw_8bit", logging_steps=5,
            save_total_limit=1, logging_strategy="steps", logging_first_step=True,
        ),
        processing_class=tokenizer,
        train_dataset=Dataset.from_list(train_rows),
        peft_config=lora_cfg,
    )
    trainer.train()
    # Given peft_config, SFTTrainer wraps the model in a PeftModel stored as
    # trainer.model; saving that wrapper writes the LoRA adapter files
    # (adapter_config.json + adapter weights). The bare `model` object is not
    # the PeftModel, so saving it does not produce an adapter checkpoint.
    # NOTE(review): assumes LocalModel loads --adapter via PEFT — confirm.
    trainer.model.save_pretrained(out_dir)
    del trainer, model
    torch.cuda.empty_cache()

    console.print(f"\n[bold green]🎉[/] {out_dir}\n codepilot --adapter {out_dir}")
def show_stats():
    """Print a feedback summary table for the ``--stats`` CLI flag.

    Shows the same breakdown as the interactive ``/status`` command: total
    feedback rows, 👍 count, 👎 count (total minus 👍), ✏️ edit count and
    the number of exported DPO pairs.
    """
    from rich.console import Console
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    stats = db.count()
    table = Table(title="📊 CodePilot")
    table.add_column("", style="cyan")
    table.add_column("", style="green")
    table.add_row("Total", str(stats["total"]))
    table.add_row("👍", str(stats["up"]))
    table.add_row("👎", str(stats["total"] - stats["up"]))
    table.add_row("✏️", str(stats["edits"]))
    table.add_row("DPO", str(len(db.export_dpo())))
    console.print(table)
def main(argv=None):
    """CLI entry point: parse arguments and dispatch to the selected mode.

    Args:
        argv: optional argument list; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``main()`` calls are unchanged.

    Modes are checked in priority order: ``--stats``, ``--train``,
    ``--grind``, otherwise the interactive agent loop.
    """
    parser = argparse.ArgumentParser(description="CodePilot v4")
    parser.add_argument("--model", type=str)
    parser.add_argument("--adapter", type=str)
    # The interactive loop reads args.project (falling back to os.getcwd());
    # without this option that attribute does not exist and plain
    # `codepilot` fails with AttributeError.
    parser.add_argument("--project", type=str, default=None, help="專案目錄(預設: 目前工作目錄)")
    parser.add_argument("--provider", type=str, choices=list(PROVIDER_CONFIGS.keys()),
                        help="模型: local, openai, anthropic, openrouter, ollama, codex")
    parser.add_argument("--api-key", type=str)
    parser.add_argument("--cloud-model", type=str)
    parser.add_argument("--duel", action="store_true", help="啟動時開啟 Duel 模式")
    parser.add_argument("--approval", type=str, choices=["auto", "auto-edit", "ask"], default="auto",
                        help="審批模式: auto=全自動, auto-edit=指令要確認, ask=全部確認")
    parser.add_argument("--distill", action="store_true")
    parser.add_argument("--grind", action="store_true", help="LeetCode 自動刷題")
    parser.add_argument("--grind-count", type=int, default=100, help="刷幾題")
    parser.add_argument("--stream", action="store_true", help="啟用 streaming 輸出(本地模型)")
    parser.add_argument("--stats", action="store_true")
    parser.add_argument("--train", action="store_true")
    args = parser.parse_args(argv)

    if args.stats:
        show_stats()
    elif args.train:
        from rich.console import Console
        trigger_training(FeedbackDB(), Console(), args)
    elif args.grind:
        run_grind(args, args.grind_count)
    else:
        run_agent_loop(args)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()