#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CodePilot v2 — Claude Code 風格的 AI 開發助手
==============================================
像 Claude Code 一樣在專案目錄中直接開發:
📁 讀取、編輯、建立專案文件
🖥️ 執行終端指令(python, pytest, npm, etc.)
🔍 搜尋程式碼(ripgrep/grep)
📂 瀏覽專案結構
🔀 Git 整合(狀態、diff、commit)
👍👎 回饋收集 → 模型越用越強
Install:
pip install transformers peft bitsandbytes accelerate trl datasets rich
Usage:
cd ~/my-project && python codepilot_v2.py
python codepilot_v2.py --stats
python codepilot_v2.py --train
python codepilot_v2.py --adapter ~/.codepilot/adapter_20260422
"""
import argparse,difflib,json,os,re,shutil,sqlite3,subprocess,sys,textwrap,torch
from datetime import datetime
from pathlib import Path
DEFAULT_MODEL="Qwen/Qwen2.5-Coder-3B-Instruct"
CONFIG_DIR=os.path.expanduser("~/.codepilot")
DB_PATH=os.path.join(CONFIG_DIR,"feedback.db")
AUTO_TRAIN_THRESHOLD=50
DANGEROUS_COMMANDS={"rm -rf /","git push --force","git reset --hard","dd if=","mkfs.","> /dev/sd"}
class FeedbackDB:
    """SQLite-backed store for user feedback on assistant responses.

    Each row records the prompt/completion pair, a thumbs-up/down label,
    an optional user-edited completion, and the project the session ran in.
    The data later feeds training: edited completions become SFT examples,
    labeled pairs become KTO examples.
    """

    def __init__(self):
        # The database lives under ~/.codepilot so feedback persists across projects.
        os.makedirs(CONFIG_DIR, exist_ok=True)
        self.conn = sqlite3.connect(DB_PATH)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS feedback("
            "id INTEGER PRIMARY KEY,timestamp TEXT,prompt TEXT,"
            "completion TEXT,label INTEGER,edited_completion TEXT,project TEXT)"
        )
        self.conn.commit()

    def save(self, prompt, completion, label, edited=None, project=None):
        """Insert one feedback row; ``label`` is coerced to an integer 0/1."""
        # Name the columns explicitly so the INSERT keeps working if the
        # schema ever grows new columns (the old positional VALUES(NULL,...)
        # form would silently misalign).
        self.conn.execute(
            "INSERT INTO feedback"
            "(timestamp,prompt,completion,label,edited_completion,project)"
            " VALUES(?,?,?,?,?,?)",
            (datetime.now().isoformat(), prompt, completion, int(label), edited, project),
        )
        self.conn.commit()

    def count(self):
        """Return aggregate counts: {'total', 'up', 'edits'}."""
        r = self.conn.execute(
            "SELECT COUNT(*),COALESCE(SUM(label),0),"
            "SUM(CASE WHEN edited_completion IS NOT NULL THEN 1 ELSE 0 END) FROM feedback"
        ).fetchone()
        # SUM(...) over zero rows yields NULL -> guard with `or 0`.
        return {"total": r[0], "up": int(r[1]), "edits": int(r[2] or 0)}

    def export_sft(self):
        """Chat-format SFT examples built from user-edited completions."""
        rows = self.conn.execute(
            "SELECT prompt,edited_completion FROM feedback"
            " WHERE edited_completion IS NOT NULL"
        ).fetchall()
        return [
            {"messages": [{"role": "user", "content": p},
                          {"role": "assistant", "content": c}]}
            for p, c in rows
        ]

    def export_kto(self):
        """KTO-format examples: every (prompt, completion) with a boolean label."""
        rows = self.conn.execute(
            "SELECT prompt,completion,label FROM feedback"
        ).fetchall()
        return [
            {"prompt": [{"role": "user", "content": p}],
             "completion": [{"role": "assistant", "content": c}],
             "label": bool(l)}
            for p, c, l in rows
        ]

    def close(self):
        """Close the underlying SQLite connection (backward-compatible addition)."""
        self.conn.close()
class ProjectTools:
    """File, shell, search, and git tooling scoped to one project directory.

    ``cwd`` tracks the agent's working directory (it can move when the model
    runs ``cd`` via :meth:`run_command`); ``read_cache`` remembers the mtime
    of every file read or written so :meth:`edit_file` can insist on a fresh
    ``read_file`` before any modification.
    """

    def __init__(self, project_dir):
        self.project_dir = os.path.abspath(project_dir)
        self.cwd = self.project_dir
        # path -> {"time": mtime, "content": text} from the last read/write.
        self.read_cache = {}

    def _resolve(self, path):
        """Resolve a possibly-relative path against the current working dir."""
        return path if os.path.isabs(path) else os.path.normpath(os.path.join(self.cwd, path))

    def read_file(self, path, offset=1, limit=200):
        """Return numbered lines [offset, offset+limit) of a file.

        Also records the file's mtime so a later edit_file can verify the
        file has not changed underneath us.
        """
        full = self._resolve(path)
        if not os.path.exists(full):
            return f"❌ 文件不存在: {path}"
        try:
            content = Path(full).read_text(encoding="utf-8", errors="replace")
            lines = content.splitlines()
            selected = lines[offset - 1:offset - 1 + limit]
            self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
            result = "\n".join(f"{i+offset:4d} │ {line}" for i, line in enumerate(selected))
            # Fix: compare against the index of the last line actually shown
            # (offset-1+limit); the old `offset+limit < len(lines)` test hid
            # the hint when exactly one line remained.
            if offset - 1 + limit < len(lines):
                result += f"\n... ({len(lines)-offset-limit+1} more lines)"
            return result
        except Exception as e:
            return f"❌ {e}"

    def edit_file(self, path, old_string, new_string):
        """Replace one unique occurrence of old_string; requires a prior read."""
        full = self._resolve(path)
        if full not in self.read_cache:
            return "❌ 必須先 read_file 才能編輯"
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        content = Path(full).read_text(encoding="utf-8")
        # Refuse to edit a file that changed since we last read/wrote it.
        if os.path.getmtime(full) != self.read_cache[full]["time"]:
            return "❌ 文件已被外部修改,請重新 read_file"
        count = content.count(old_string)
        if count == 0:
            return "❌ 找不到要替換的文字"
        if count > 1:
            return f"❌ 找到 {count} 處匹配,請提供更多上下文"
        new_content = content.replace(old_string, new_string, 1)
        diff = list(difflib.unified_diff(
            content.splitlines(keepends=True),
            new_content.splitlines(keepends=True),
            fromfile=f"a/{path}", tofile=f"b/{path}"))
        Path(full).write_text(new_content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": new_content}
        return "✅ 已修改:\n" + "".join(diff)

    def write_file(self, path, content):
        """Create or overwrite a file, creating parent directories as needed."""
        full = self._resolve(path)
        os.makedirs(os.path.dirname(full) or ".", exist_ok=True)
        is_new = not os.path.exists(full)
        Path(full).write_text(content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
        return f"✅ {'建立'if is_new else'覆寫'}: {path} ({len(content)} chars)"

    def run_command(self, command, timeout=120):
        """Run a shell command in ``cwd``; refuses a small deny-list."""
        # Substring deny-list: coarse, but blocks the worst destructive commands.
        for d in DANGEROUS_COMMANDS:
            if d in command:
                return f"⛔ 危險指令: {command}"
        try:
            r = subprocess.run(command, shell=True, cwd=self.cwd,
                               capture_output=True, text=True, timeout=timeout)
            out = r.stdout \
                + (f"\nSTDERR:\n{r.stderr}" if r.stderr else "") \
                + (f"\n(exit {r.returncode})" if r.returncode else "")
            # Track `cd <dir> ...` so subsequent tool calls run in the new dir.
            m = re.match(r"cd\s+(.+?)(\s*&&|\s*;|\s*$)", command)
            if m:
                nd = self._resolve(m.group(1).strip())
                if os.path.isdir(nd):
                    self.cwd = nd
            return out[:10000]
        except subprocess.TimeoutExpired:
            return f"⏰ 超時 ({timeout}s)"
        except Exception as e:
            return f"❌ {e}"

    def search_files(self, pattern, glob_pattern=None):
        """Search the project with ripgrep when available, else grep -rn."""
        rg = shutil.which("rg")
        cmd = [rg or "grep", "-rn"]
        if rg:
            cmd += ["--color=never", "--max-count=50"]
        if glob_pattern and rg:
            cmd += ["--glob", glob_pattern]
        cmd += [pattern, self.cwd]
        try:
            return subprocess.run(cmd, capture_output=True, text=True,
                                  timeout=30).stdout[:5000] or "無匹配"
        except Exception as e:
            return f"❌ {e}"

    def list_files(self, pattern="*", max_depth=3):
        """List up to 100 files matching pattern, skipping VCS/vendor dirs."""
        try:
            r = subprocess.run(
                ["find", self.cwd, "-maxdepth", str(max_depth), "-name", pattern,
                 "-not", "-path", "*/.git/*", "-not", "-path", "*/node_modules/*",
                 "-not", "-path", "*/__pycache__/*"],
                capture_output=True, text=True, timeout=10)
            return "\n".join(sorted(os.path.relpath(f, self.cwd)
                                    for f in r.stdout.strip().split("\n")
                                    if f.strip())[:100])
        except Exception:
            # Portable fallback for platforms without `find` (e.g. Windows).
            files = []
            for root, dirs, fnames in os.walk(self.cwd):
                dirs[:] = [d for d in dirs
                           if d not in {".git", "node_modules", "__pycache__", ".venv"}]
                if root.replace(self.cwd, "").count(os.sep) >= max_depth:
                    continue
                files.extend(os.path.relpath(os.path.join(root, f), self.cwd)
                             for f in fnames if Path(f).match(pattern))
            return "\n".join(sorted(files)[:100])

    def git_context(self):
        """Summarize branch, short status, and recent commits for the project.

        Returns "(not a git repo)" when the directory is not a repository.
        """
        try:
            rb = subprocess.run(["git", "branch", "--show-current"],
                                cwd=self.project_dir, capture_output=True, text=True)
            # Fix: git exits non-zero outside a repo but subprocess.run does
            # not raise, so the old code returned an empty "Branch:" summary
            # instead of the sentinel the caller checks for.
            if rb.returncode != 0:
                return "(not a git repo)"
            b = rb.stdout.strip()
            s = subprocess.run(["git", "status", "--short"], cwd=self.project_dir,
                               capture_output=True, text=True).stdout.strip()
            l = subprocess.run(["git", "log", "--oneline", "-5"], cwd=self.project_dir,
                               capture_output=True, text=True).stdout.strip()
            return f"Branch: {b}\nStatus:\n{s}\nRecent:\n{l}"
        except Exception:
            return "(not a git repo)"
# Matches `<tool>name\n...params...</tool>` blocks emitted by the model.
TOOL_PATTERN = re.compile(r'<tool>\s*(\w+)\s*\n(.*?)</tool>', re.DOTALL)


def parse_tool_calls(text):
    """Extract tool calls from model output.

    Params are parsed as JSON first; on parse failure we fall back to a
    lenient line-oriented ``key: value`` parser so slightly malformed model
    output still yields usable params.

    Returns a list of ``{"tool": name, "params": dict}``.
    """
    calls = []
    for m in TOOL_PATTERN.finditer(text):
        raw = m.group(2).strip()
        try:
            params = json.loads(raw)
        except json.JSONDecodeError:
            # Narrowed from a bare `except:` — only JSON errors should
            # trigger the lenient fallback parser.
            params = {}
            for line in raw.split("\n"):
                if ":" in line:
                    k, v = line.split(":", 1)
                    params[k.strip()] = v.strip().strip('"').strip("'")
        calls.append({"tool": m.group(1), "params": params})
    return calls
def execute_tool(tools, call):
    """Dispatch one parsed tool call to the matching ProjectTools method.

    Unknown tools and any exception raised while running a tool (including
    bad int params) are reported as an error string rather than propagated,
    so the agent loop keeps going.
    """
    name, params = call["tool"], call["params"]
    handlers = {
        "read_file": lambda: tools.read_file(
            params.get("path", ""), int(params.get("offset", 1)),
            int(params.get("limit", 200))),
        "edit_file": lambda: tools.edit_file(
            params.get("path", ""), params.get("old_string", ""),
            params.get("new_string", "")),
        "write_file": lambda: tools.write_file(
            params.get("path", ""), params.get("content", "")),
        "run_command": lambda: tools.run_command(
            params.get("command", ""), int(params.get("timeout", 120))),
        "search_files": lambda: tools.search_files(
            params.get("pattern", ""), params.get("glob")),
        "list_files": lambda: tools.list_files(
            params.get("pattern", "*"), int(params.get("max_depth", 3))),
        "git_status": lambda: tools.git_context(),
    }
    handler = handlers.get(name)
    if handler is None:
        return f"❌ 未知: {name}"
    try:
        return handler()
    except Exception as e:
        return f"❌ {e}"
def build_system_prompt(tools):
    """Compose the system prompt: live project context plus the tool protocol.

    The example blocks below define the exact tag format that TOOL_PATTERN
    parses from model output — editing them changes what the model is taught
    to emit. Everything after the first line of the f-string is literal
    prompt text at column 0 on purpose; `{{`/`}}` are escaped braces.
    """
    return f"""You are CodePilot, an expert AI programming assistant working directly in the user's project.
## Current Project
Working directory: {tools.cwd}
{tools.git_context()}
## Available Tools
Use tools by wrapping in <tool></tool> tags:
### read_file — Read a file (ALWAYS do this before editing)
<tool>read_file
{{"path": "src/main.py", "offset": 1, "limit": 200}}
</tool>
### edit_file — Edit file by exact string replace (must read first, old_string must be unique)
<tool>edit_file
{{"path": "src/main.py", "old_string": "def old():\\n    pass", "new_string": "def new():\\n    return 42"}}
</tool>
### write_file — Create or overwrite a file
<tool>write_file
{{"path": "src/new.py", "content": "print('hello')"}}
</tool>
### run_command — Execute shell command
<tool>run_command
{{"command": "python -m pytest tests/", "timeout": 60}}
</tool>
### search_files — Search code with regex
<tool>search_files
{{"pattern": "def main", "glob": "*.py"}}
</tool>
### list_files — List project files
<tool>list_files
{{"pattern": "*.py", "max_depth": 3}}
</tool>
### git_status — Get git info
<tool>git_status
{{}}
</tool>
## Rules
1. ALWAYS read a file before editing it
2. old_string must EXACTLY match file content (whitespace matters)
3. Prefer edit_file over write_file for existing files
4. After changes, verify by reading file or running tests
5. For git: stage specific files, never git add -A
6. Be concise and actionable"""
class CodeModel:
    """Thin wrapper around a causal LM.

    Loads tokenizer and model (optionally with a LoRA adapter stacked on
    top) and exposes a single-turn :meth:`chat` helper that returns only the
    newly generated text.
    """

    def __init__(self, model_name=DEFAULT_MODEL, adapter_path=None):
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Some checkpoints ship without a pad token; reuse EOS so generate()
        # has something to pad with.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16,
            device_map="auto", trust_remote_code=True)
        if adapter_path and os.path.exists(adapter_path):
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model, adapter_path)
        self.model.eval()

    def chat(self, messages, max_tokens=4096):
        """Generate one assistant reply for a chat-format message list."""
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        encoded = self.tokenizer(prompt_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated = self.model.generate(
                **encoded, max_new_tokens=max_tokens, do_sample=True,
                temperature=0.7, top_p=0.9, repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id)
        # Strip the prompt tokens; return only the completion.
        prompt_len = encoded["input_ids"].shape[1]
        return self.tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
def run_agent_loop(args):
    """Interactive REPL: user request -> model -> tool rounds -> feedback.

    Each turn runs up to 10 model/tool rounds; after the turn, the user can
    rate the full response (y/n) or paste an edited version (e), which is
    stored in the feedback DB for later training.
    """
    from rich.console import Console;from rich.markdown import Markdown;from rich.panel import Panel;from rich.prompt import Prompt;from rich.syntax import Syntax
    console=Console();db=FeedbackDB();project_dir=args.project or os.getcwd();tools=ProjectTools(project_dir)
    console.print(Panel.fit(f"[bold cyan]CodePilot v2[/] — Claude Code 風格 AI 開發助手\n[dim]Model: {args.model or DEFAULT_MODEL}\nProject: {project_dir}[/]",border_style="cyan"))
    with console.status("[bold green]載入模型中..."):model=CodeModel(args.model or DEFAULT_MODEL,args.adapter)
    console.print("[green]✅ 模型載入完成[/]")
    git_ctx=tools.git_context()
    if git_ctx!="(not a git repo)":console.print(Panel(git_ctx,title="📂 Project",border_style="dim"))
    console.print("[dim]直接輸入需求 | /ls 列檔 | /git 狀態 | /clear 清除 | /status 統計 | /train 訓練 | /quit 退出[/]\n")
    system_prompt=build_system_prompt(tools);messages=[{"role":"system","content":system_prompt}]
    while True:
        # Ctrl-C / Ctrl-D at the prompt exits the loop cleanly.
        try:user_input=Prompt.ask("\n[bold green]🧑 You")
        except(EOFError,KeyboardInterrupt):break
        if not user_input.strip():continue
        cmd=user_input.strip()
        # Slash commands are handled locally and never sent to the model.
        if cmd in("/quit","/exit"):break
        elif cmd=="/status":s=db.count();console.print(f" 👍{s['up']} 👎{s['total']-s['up']} ✏️{s['edits']} Total:{s['total']}");continue
        elif cmd=="/train":trigger_training(db,console,args);continue
        elif cmd=="/clear":messages=[{"role":"system","content":system_prompt}];console.print("[dim]已清除[/]");continue
        elif cmd=="/git":console.print(Panel(tools.git_context(),title="Git",border_style="dim"));continue
        elif cmd.startswith("/ls"):console.print(tools.list_files(cmd[3:].strip()or"*"));continue
        messages.append({"role":"user","content":user_input});full_response=""
        # Agentic loop: at most 10 model/tool rounds per user turn.
        for rnd in range(10):
            with console.status(f"[bold cyan]{'思考中'if rnd==0 else f'工具 round {rnd+1}'}..."):response=model.chat(messages)
            # Separate prose from <tool> blocks; prose is rendered as Markdown.
            tool_calls=parse_tool_calls(response);text_parts=TOOL_PATTERN.sub("",response).strip()
            if text_parts:console.print(f"\n[bold blue]🤖 CodePilot:[/]");console.print(Markdown(text_parts))
            full_response+=response+"\n"
            # No tool calls -> the model is done with this turn.
            if not tool_calls:break
            messages.append({"role":"assistant","content":response});results=[]
            for call in tool_calls:
                console.print(f"\n [dim]🔧 {call['tool']}({json.dumps(call['params'],ensure_ascii=False)[:100]})[/]")
                result=execute_tool(tools,call)
                # Per-tool rendering: diffs as syntax, terminal output in a panel.
                if call["tool"]=="edit_file"and result.startswith("✅"):
                    d=result.split("\n",1)[1]if"\n"in result else""
                    if d:console.print(Syntax(d,"diff",theme="monokai"))
                    else:console.print(f" [green]{result.split(chr(10))[0]}[/]")
                elif call["tool"]=="run_command":console.print(Panel(result[:500]+("..."if len(result)>500 else""),title="Terminal",border_style="dim"))
                elif call["tool"]=="read_file":console.print(f" [dim]({result.count(chr(10))+1} lines)[/]")
                else:console.print(f" [dim]{result[:200]}{'...'if len(result)>200 else''}[/]")
                results.append(f"[{call['tool']}] {result}")
            # Feed tool output back to the model as a user message.
            messages.append({"role":"user","content":"Tool results:\n"+"\n\n".join(results)})
        # Collect feedback on the whole turn: y=good, n=bad, e=paste a fix.
        console.print(f"\n[dim][green]y[/]=👍 [red]n[/]=👎 [yellow]e[/]=✏️ Enter=跳過[/]")
        fb=Prompt.ask(" ",choices=["y","n","e",""],default="",show_choices=False)
        if fb=="y":db.save(user_input,full_response,1,project=project_dir);console.print(" [green]👍[/]")
        elif fb=="n":db.save(user_input,full_response,0,project=project_dir);console.print(" [red]👎[/]")
        elif fb=="e":
            console.print(" [yellow]貼上修改版(END結束):[/]");lines=[]
            # Read pasted lines until a bare "END"; the generator-throw trick
            # raises StopIteration from inside the expression to exit the try.
            while True:
                try:l=input();(lines.append(l)if l.strip()!="END"else(_ for _ in()).throw(StopIteration))
                except(EOFError,StopIteration):break
            edited="\n".join(lines)
            if edited.strip():db.save(user_input,full_response,1,edited=edited,project=project_dir);console.print(" [yellow]✏️[/]")
        # If the turn ended on unanswered tool results, close the transcript
        # with the assistant's accumulated response to keep roles alternating.
        if messages[-1]["role"]=="user"and"Tool results:"in messages[-1]["content"]:messages.append({"role":"assistant","content":full_response})
    console.print("\n[cyan]👋[/]")
def trigger_training(db,console,args):
    """Fine-tune a LoRA adapter from collected feedback.

    Prefers SFT on user-edited completions (strongest signal); otherwise
    falls back to KTO on thumbs-up/down labels once at least 10 labeled
    examples exist. The base model is loaded 4-bit (QLoRA-style) and the
    adapter is saved under ~/.codepilot/adapter_<timestamp>.
    """
    s=db.count()
    if s["total"]==0:console.print("[yellow]⚠️ 無數據[/]");return
    console.print(f"\n[bold]🚀[/] 👍:{s['up']} 👎:{s['total']-s['up']} ✏️:{s['edits']}")
    # Heavy imports deferred so the interactive loop starts fast.
    from datasets import Dataset;from transformers import AutoModelForCausalLM,AutoTokenizer,BitsAndBytesConfig;from peft import LoraConfig,prepare_model_for_kbit_training
    mn=args.model or DEFAULT_MODEL;od=os.path.join(CONFIG_DIR,f"adapter_{datetime.now().strftime('%Y%m%d_%H%M')}")
    # 4-bit NF4 quantization with bf16 compute (QLoRA recipe).
    bnb=BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,bnb_4bit_use_double_quant=True)
    # LoRA on all attention + MLP projections.
    pc=LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"])
    sft,kto=db.export_sft(),db.export_kto()
    if sft:
        console.print(f"[bold]📚 SFT({len(sft)})[/]");from trl import SFTTrainer,SFTConfig
        m=AutoModelForCausalLM.from_pretrained(mn,quantization_config=bnb,device_map="auto",trust_remote_code=True);t=AutoTokenizer.from_pretrained(mn)
        if t.pad_token is None:t.pad_token=t.eos_token
        m=prepare_model_for_kbit_training(m)
        SFTTrainer(model=m,args=SFTConfig(output_dir=od,learning_rate=2e-4,num_train_epochs=3,per_device_train_batch_size=1,gradient_accumulation_steps=8,max_seq_length=1024,gradient_checkpointing=True,bf16=True,optim="paged_adamw_8bit",logging_steps=5,save_total_limit=1,logging_strategy="steps",logging_first_step=True),processing_class=t,train_dataset=Dataset.from_list(sft),peft_config=pc).train()
        # NOTE(review): `m` here is the base (kbit-prepared) model, not the
        # trainer's PEFT-wrapped model — confirm save_pretrained(od) actually
        # writes the LoRA adapter rather than the full model.
        m.save_pretrained(od);del m;torch.cuda.empty_cache()
    elif len(kto)>=10:
        console.print(f"[bold]📚 KTO({len(kto)})[/]");from trl import KTOConfig,KTOTrainer
        m=AutoModelForCausalLM.from_pretrained(mn,quantization_config=bnb,device_map="auto",trust_remote_code=True);t=AutoTokenizer.from_pretrained(mn)
        if t.pad_token is None:t.pad_token=t.eos_token
        KTOTrainer(model=m,args=KTOConfig(output_dir=od,learning_rate=1e-5,num_train_epochs=1,per_device_train_batch_size=1,gradient_accumulation_steps=8,max_length=1024,gradient_checkpointing=True,bf16=True,logging_steps=5,logging_strategy="steps",logging_first_step=True),processing_class=t,train_dataset=Dataset.from_list(kto),peft_config=pc).train()
        # NOTE(review): same concern as the SFT branch — verify this saves
        # the adapter, not the quantized base model.
        m.save_pretrained(od)
    console.print(f"\n[bold green]🎉[/] {od}\n codepilot --adapter {od}")
def show_stats():
    """Print aggregate feedback counts (total / up / down / edits) as a table."""
    from rich.console import Console
    from rich.table import Table
    stats = FeedbackDB().count()
    table = Table(title="📊 CodePilot")
    table.add_column("", style="cyan")
    table.add_column("", style="green")
    downvotes = stats["total"] - stats["up"]
    for label, value in (("Total", stats["total"]),
                         ("👍", stats["up"]),
                         ("👎", downvotes),
                         ("✏️", stats["edits"])):
        table.add_row(label, str(value))
    Console().print(table)
def main():
    """CLI entry point: --stats / --train one-shots, else the interactive agent."""
    parser = argparse.ArgumentParser(description="CodePilot v2")
    parser.add_argument("--model", type=str)
    parser.add_argument("--adapter", type=str)
    parser.add_argument("--project", type=str)
    parser.add_argument("--stats", action="store_true")
    parser.add_argument("--train", action="store_true")
    opts = parser.parse_args()
    if opts.stats:
        show_stats()
    elif opts.train:
        from rich.console import Console
        trigger_training(FeedbackDB(), Console(), opts)
    else:
        run_agent_loop(opts)


if __name__ == "__main__":
    main()