File size: 18,989 Bytes
ca2b150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64b4a54
 
 
ca2b150
64b4a54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CodePilot v2 — Claude Code 風格的 AI 開發助手
==============================================

像 Claude Code 一樣在專案目錄中直接開發:
  📁 讀取、編輯、建立專案文件
  🖥️  執行終端指令(python, pytest, npm, etc.)
  🔍 搜尋程式碼(ripgrep/grep)
  📂 瀏覽專案結構
  🔀 Git 整合(狀態、diff、commit)
  👍👎 回饋收集 → 模型越用越強

Install:
    pip install transformers peft bitsandbytes accelerate trl datasets rich

Usage:
    cd ~/my-project && python codepilot_v2.py
    python codepilot_v2.py --stats
    python codepilot_v2.py --train
    python codepilot_v2.py --adapter ~/.codepilot/adapter_20260422
"""
import argparse,difflib,json,os,re,shutil,sqlite3,subprocess,sys,textwrap,torch
from datetime import datetime
from pathlib import Path
DEFAULT_MODEL="Qwen/Qwen2.5-Coder-3B-Instruct"  # base model used when --model is not supplied
CONFIG_DIR=os.path.expanduser("~/.codepilot")  # per-user state dir: feedback DB + trained adapters
DB_PATH=os.path.join(CONFIG_DIR,"feedback.db")  # SQLite file backing FeedbackDB
AUTO_TRAIN_THRESHOLD=50  # NOTE(review): not referenced anywhere in this file — presumably an intended auto-train trigger; confirm before removing
DANGEROUS_COMMANDS={"rm -rf /","git push --force","git reset --hard","dd if=","mkfs.","> /dev/sd"}  # substring blocklist checked by ProjectTools.run_command
class FeedbackDB:
    """SQLite-backed store for user feedback on assistant responses."""

    def __init__(self):
        """Open (creating if needed) the feedback database under CONFIG_DIR."""
        os.makedirs(CONFIG_DIR, exist_ok=True)
        self.conn = sqlite3.connect(DB_PATH)
        self.conn.execute("CREATE TABLE IF NOT EXISTS feedback(id INTEGER PRIMARY KEY,timestamp TEXT,prompt TEXT,completion TEXT,label INTEGER,edited_completion TEXT,project TEXT)")
        self.conn.commit()

    def save(self, prompt, completion, label, edited=None, project=None):
        """Persist one feedback record (label: 1 = thumbs-up, 0 = thumbs-down)."""
        row = (datetime.now().isoformat(), prompt, completion, int(label), edited, project)
        self.conn.execute("INSERT INTO feedback VALUES(NULL,?,?,?,?,?,?)", row)
        self.conn.commit()

    def count(self):
        """Return aggregate stats: total rows, thumbs-up count, edited count."""
        total, ups, edits = self.conn.execute(
            "SELECT COUNT(*),COALESCE(SUM(label),0),SUM(CASE WHEN edited_completion IS NOT NULL THEN 1 ELSE 0 END) FROM feedback"
        ).fetchone()
        return {"total": total, "up": int(ups), "edits": int(edits or 0)}

    def export_sft(self):
        """Export user-edited completions as chat-format SFT examples."""
        rows = self.conn.execute("SELECT prompt,edited_completion FROM feedback WHERE edited_completion IS NOT NULL").fetchall()
        return [
            {"messages": [{"role": "user", "content": prompt}, {"role": "assistant", "content": fixed}]}
            for prompt, fixed in rows
        ]

    def export_kto(self):
        """Export all feedback as KTO preference examples (label coerced to bool)."""
        rows = self.conn.execute("SELECT prompt,completion,label FROM feedback").fetchall()
        return [
            {"prompt": [{"role": "user", "content": p}],
             "completion": [{"role": "assistant", "content": c}],
             "label": bool(l)}
            for p, c, l in rows
        ]
class ProjectTools:
    """File, shell, search, and git operations rooted at a project directory.

    read_file records an mtime+content snapshot in ``read_cache``; edit_file
    uses it to enforce read-before-edit and to detect external modification.
    ``cwd`` tracks ``cd`` commands issued through run_command.
    """

    def __init__(self, project_dir):
        self.project_dir = os.path.abspath(project_dir)
        self.cwd = self.project_dir  # current working dir for commands/search
        self.read_cache = {}  # abs path -> {"time": mtime at read, "content": text}

    def _resolve(self, path):
        """Return *path* as absolute (relative paths resolve against self.cwd)."""
        if os.path.isabs(path):
            return path
        return os.path.normpath(os.path.join(self.cwd, path))

    def read_file(self, path, offset=1, limit=200):
        """Return up to *limit* numbered lines of *path* from 1-based *offset*.

        Side effect: caches the file so edit_file will accept later edits.
        All failures are reported as "❌ ..." strings instead of raising.
        """
        full = self._resolve(path)
        if not os.path.exists(full):
            return f"❌ 文件不存在: {path}"
        try:
            content = Path(full).read_text(encoding="utf-8", errors="replace")
            lines = content.splitlines()
            selected = lines[offset - 1:offset - 1 + limit]
            self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
            # Fix: separate the line number from the content with a space; the
            # previous f"{n:4d}{line}" fused them (e.g. "  10import os"), which
            # is ambiguous once numbers fill the field width.
            result = "\n".join(f"{i + offset:4d} {line}" for i, line in enumerate(selected))
            consumed = offset - 1 + limit  # lines covered, counting from file top
            # Fix: old test `offset + limit < len(lines)` was off by one and
            # dropped the trailer when exactly one line remained unseen.
            if consumed < len(lines):
                result += f"\n... ({len(lines) - consumed} more lines)"
            return result
        except Exception as e:
            return f"❌ {e}"

    def edit_file(self, path, old_string, new_string):
        """Replace exactly one occurrence of *old_string* with *new_string*.

        Preconditions (each reported as an "❌ ..." string): the file was read
        via read_file, still exists, is unchanged on disk since that read, and
        *old_string* matches exactly once. On success the file is rewritten,
        the cache refreshed, and a unified diff returned.
        """
        full = self._resolve(path)
        if full not in self.read_cache:
            return "❌ 必須先 read_file 才能編輯"
        if not os.path.exists(full):
            return f"❌ 不存在: {path}"
        content = Path(full).read_text(encoding="utf-8")
        if os.path.getmtime(full) != self.read_cache[full]["time"]:
            return "❌ 文件已被外部修改,請重新 read_file"
        occurrences = content.count(old_string)
        if occurrences == 0:
            return "❌ 找不到要替換的文字"
        if occurrences > 1:
            return f"❌ 找到 {occurrences} 處匹配,請提供更多上下文"
        new_content = content.replace(old_string, new_string, 1)
        diff = difflib.unified_diff(
            content.splitlines(keepends=True),
            new_content.splitlines(keepends=True),
            fromfile=f"a/{path}",
            tofile=f"b/{path}",
        )
        Path(full).write_text(new_content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": new_content}
        return "✅ 已修改:\n" + "".join(diff)

    def write_file(self, path, content):
        """Create or overwrite *path* with *content*, making parent dirs as needed."""
        full = self._resolve(path)
        os.makedirs(os.path.dirname(full) or ".", exist_ok=True)
        is_new = not os.path.exists(full)
        Path(full).write_text(content, encoding="utf-8")
        self.read_cache[full] = {"time": os.path.getmtime(full), "content": content}
        return f"✅ {'建立'if is_new else'覆寫'}: {path} ({len(content)} chars)"

    def run_command(self, command, timeout=120):
        """Run *command* through the shell in self.cwd, capturing stdout/stderr.

        Refuses any command containing a DANGEROUS_COMMANDS substring. Output
        is truncated to 10k chars. A leading ``cd <dir>`` also updates
        self.cwd (when the target is an existing directory) so later commands
        run there.
        """
        for banned in DANGEROUS_COMMANDS:
            if banned in command:
                return f"⛔ 危險指令: {command}"
        try:
            proc = subprocess.run(command, shell=True, cwd=self.cwd,
                                  capture_output=True, text=True, timeout=timeout)
            out = proc.stdout
            if proc.stderr:
                out += f"\nSTDERR:\n{proc.stderr}"
            if proc.returncode:
                out += f"\n(exit {proc.returncode})"
            cd = re.match(r"cd\s+(.+?)(\s*&&|\s*;|\s*$)", command)
            if cd:
                target = self._resolve(cd.group(1).strip())
                if os.path.isdir(target):
                    self.cwd = target
            return out[:10000]
        except subprocess.TimeoutExpired:
            return f"⏰ 超時 ({timeout}s)"
        except Exception as e:
            return f"❌ {e}"

    def search_files(self, pattern, glob_pattern=None):
        """Search the project for regex *pattern* via ripgrep (grep fallback).

        *glob_pattern* (e.g. "*.py") is honored only when ripgrep is present.
        Output truncated to 5k chars; "無匹配" when nothing matches.
        """
        rg = shutil.which("rg")
        cmd = [rg or "grep", "-rn"]
        if rg:
            cmd += ["--color=never", "--max-count=50"]
        if glob_pattern and rg:
            cmd += ["--glob", glob_pattern]
        cmd += [pattern, self.cwd]
        try:
            return subprocess.run(cmd, capture_output=True, text=True, timeout=30).stdout[:5000] or "無匹配"
        except Exception as e:
            return f"❌ {e}"

    def list_files(self, pattern="*", max_depth=3):
        """List up to 100 project-relative paths matching glob *pattern*.

        Uses the ``find`` binary when available; falls back to os.walk.
        Skips .git, node_modules and __pycache__ (plus .venv in the fallback).
        """
        try:
            proc = subprocess.run(
                ["find", self.cwd, "-maxdepth", str(max_depth), "-name", pattern,
                 "-not", "-path", "*/.git/*", "-not", "-path", "*/node_modules/*",
                 "-not", "-path", "*/__pycache__/*"],
                capture_output=True, text=True, timeout=10)
            names = (os.path.relpath(f, self.cwd) for f in proc.stdout.strip().split("\n") if f.strip())
            return "\n".join(sorted(names)[:100])
        except Exception:  # fix: was a bare except, which also swallowed KeyboardInterrupt/SystemExit
            files = []
            for root, dirs, fnames in os.walk(self.cwd):
                dirs[:] = [d for d in dirs if d not in {".git", "node_modules", "__pycache__", ".venv"}]
                if root.replace(self.cwd, "").count(os.sep) >= max_depth:
                    continue
                files.extend(os.path.relpath(os.path.join(root, f), self.cwd)
                             for f in fnames if Path(f).match(pattern))
            return "\n".join(sorted(files)[:100])

    def git_context(self):
        """Return branch, short status, and last 5 commits, or "(not a git repo)"."""
        try:
            def _git(*argv):
                return subprocess.run(["git", *argv], cwd=self.project_dir,
                                      capture_output=True, text=True).stdout.strip()
            branch = _git("branch", "--show-current")
            status = _git("status", "--short")
            log = _git("log", "--oneline", "-5")
            return f"Branch: {branch}\nStatus:\n{status}\nRecent:\n{log}"
        except Exception:  # fix: was a bare except; e.g. git binary missing still yields the fallback
            return "(not a git repo)"
# Tool-call markup: name on the first line, params body until the closing tag.
TOOL_PATTERN=re.compile(r'<tool>\s*(\w+)\s*\n(.*?)</tool>',re.DOTALL)
def parse_tool_calls(text):
    """Extract tool invocations from model output *text*.

    Each ``<tool>NAME\\n{json}</tool>`` span yields ``{"tool": NAME,
    "params": dict}``. The params body is parsed as JSON; if that fails,
    it falls back to lenient "key: value" line parsing, stripping one layer
    of surrounding quotes from values.
    """
    calls = []
    for m in TOOL_PATTERN.finditer(text):
        body = m.group(2).strip()
        try:
            params = json.loads(body)
        except json.JSONDecodeError:  # fix: was a bare except; only malformed JSON should trigger the fallback
            params = {}
            for line in body.split("\n"):
                if ":" in line:
                    k, v = line.split(":", 1)
                    params[k.strip()] = v.strip().strip('"').strip("'")
        calls.append({"tool": m.group(1), "params": params})
    return calls
def execute_tool(tools, call):
    """Dispatch one parsed tool call to the matching ProjectTools method.

    Returns the tool's string result; any lookup or argument-conversion
    error is rendered as an "❌ ..." string instead of raising.
    """
    name = call["tool"]
    params = call["params"]
    dispatch = {
        "read_file": lambda: tools.read_file(params.get("path", ""), int(params.get("offset", 1)), int(params.get("limit", 200))),
        "edit_file": lambda: tools.edit_file(params.get("path", ""), params.get("old_string", ""), params.get("new_string", "")),
        "write_file": lambda: tools.write_file(params.get("path", ""), params.get("content", "")),
        "run_command": lambda: tools.run_command(params.get("command", ""), int(params.get("timeout", 120))),
        "search_files": lambda: tools.search_files(params.get("pattern", ""), params.get("glob")),
        "list_files": lambda: tools.list_files(params.get("pattern", "*"), int(params.get("max_depth", 3))),
        "git_status": lambda: tools.git_context(),
    }
    try:
        handler = dispatch.get(name)
        if handler is None:
            return f"❌ 未知: {name}"
        return handler()
    except Exception as e:
        return f"❌ {e}"
def build_system_prompt(tools):
    """Build the system prompt: project context (cwd + git) plus tool-usage docs.

    Note: the caller evaluates this once at session start, so the embedded
    cwd/git snapshot can go stale as the session proceeds. The returned text
    is consumed by the model verbatim — do not alter it lightly.
    """
    return f"""You are CodePilot, an expert AI programming assistant working directly in the user's project.

## Current Project
Working directory: {tools.cwd}
{tools.git_context()}

## Available Tools
Use tools by wrapping in <tool></tool> tags:

### read_file — Read a file (ALWAYS do this before editing)
<tool>read_file
{{"path": "src/main.py", "offset": 1, "limit": 200}}
</tool>

### edit_file — Edit file by exact string replace (must read first, old_string must be unique)
<tool>edit_file
{{"path": "src/main.py", "old_string": "def old():\\n    pass", "new_string": "def new():\\n    return 42"}}
</tool>

### write_file — Create or overwrite a file
<tool>write_file
{{"path": "src/new.py", "content": "print('hello')"}}
</tool>

### run_command — Execute shell command
<tool>run_command
{{"command": "python -m pytest tests/", "timeout": 60}}
</tool>

### search_files — Search code with regex
<tool>search_files
{{"pattern": "def main", "glob": "*.py"}}
</tool>

### list_files — List project files
<tool>list_files
{{"pattern": "*.py", "max_depth": 3}}
</tool>

### git_status — Get git info
<tool>git_status
{{}}
</tool>

## Rules
1. ALWAYS read a file before editing it
2. old_string must EXACTLY match file content (whitespace matters)
3. Prefer edit_file over write_file for existing files
4. After changes, verify by reading file or running tests
5. For git: stage specific files, never git add -A
6. Be concise and actionable"""
class CodeModel:
    """Causal-LM wrapper (optionally with a LoRA adapter) exposing a chat() call."""

    def __init__(self, model_name=DEFAULT_MODEL, adapter_path=None):
        """Load tokenizer and model; attach a PEFT adapter when *adapter_path* exists."""
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
        if adapter_path and os.path.exists(adapter_path):
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model, adapter_path)
        self.model.eval()

    def chat(self, messages, max_tokens=4096):
        """Generate a reply for chat *messages*; returns only the newly generated text."""
        rendered = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        batch = self.tokenizer(rendered, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated = self.model.generate(
                **batch, max_new_tokens=max_tokens, do_sample=True, temperature=0.7,
                top_p=0.9, repetition_penalty=1.1, pad_token_id=self.tokenizer.pad_token_id)
        prompt_len = batch["input_ids"].shape[1]
        return self.tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
def run_agent_loop(args):
    """Interactive REPL: prompt the user, let the model call tools, collect feedback.

    Runs until /quit, /exit, EOF or Ctrl-C. Each user turn allows up to 10
    generate→tool-execute rounds; afterwards the user may rate the answer
    (y/n) or paste an edited version (e), which is stored in the feedback DB.
    """
    # UI dependencies are imported function-locally.
    from rich.console import Console;from rich.markdown import Markdown;from rich.panel import Panel;from rich.prompt import Prompt;from rich.syntax import Syntax
    console=Console();db=FeedbackDB();project_dir=args.project or os.getcwd();tools=ProjectTools(project_dir)
    console.print(Panel.fit(f"[bold cyan]CodePilot v2[/] — Claude Code 風格 AI 開發助手\n[dim]Model: {args.model or DEFAULT_MODEL}\nProject: {project_dir}[/]",border_style="cyan"))
    with console.status("[bold green]載入模型中..."):model=CodeModel(args.model or DEFAULT_MODEL,args.adapter)
    console.print("[green]✅ 模型載入完成[/]")
    git_ctx=tools.git_context()
    if git_ctx!="(not a git repo)":console.print(Panel(git_ctx,title="📂 Project",border_style="dim"))
    console.print("[dim]直接輸入需求 | /ls 列檔 | /git 狀態 | /clear 清除 | /status 統計 | /train 訓練 | /quit 退出[/]\n")
    # System prompt embeds a one-time snapshot of cwd/git state.
    system_prompt=build_system_prompt(tools);messages=[{"role":"system","content":system_prompt}]
    while True:
        try:user_input=Prompt.ask("\n[bold green]🧑 You")
        except(EOFError,KeyboardInterrupt):break
        if not user_input.strip():continue
        cmd=user_input.strip()
        # Slash commands are handled locally and never reach the model.
        if cmd in("/quit","/exit"):break
        elif cmd=="/status":s=db.count();console.print(f"  👍{s['up']} 👎{s['total']-s['up']} ✏️{s['edits']} Total:{s['total']}");continue
        elif cmd=="/train":trigger_training(db,console,args);continue
        elif cmd=="/clear":messages=[{"role":"system","content":system_prompt}];console.print("[dim]已清除[/]");continue
        elif cmd=="/git":console.print(Panel(tools.git_context(),title="Git",border_style="dim"));continue
        elif cmd.startswith("/ls"):console.print(tools.list_files(cmd[3:].strip()or"*"));continue
        messages.append({"role":"user","content":user_input});full_response=""
        # Agent loop: at most 10 model/tool rounds per user turn.
        for rnd in range(10):
            with console.status(f"[bold cyan]{'思考中'if rnd==0 else f'工具 round {rnd+1}'}..."):response=model.chat(messages)
            # <tool> blocks are stripped from the displayed text and executed below.
            tool_calls=parse_tool_calls(response);text_parts=TOOL_PATTERN.sub("",response).strip()
            if text_parts:console.print(f"\n[bold blue]🤖 CodePilot:[/]");console.print(Markdown(text_parts))
            full_response+=response+"\n"
            if not tool_calls:break
            messages.append({"role":"assistant","content":response});results=[]
            for call in tool_calls:
                console.print(f"\n  [dim]🔧 {call['tool']}({json.dumps(call['params'],ensure_ascii=False)[:100]})[/]")
                result=execute_tool(tools,call)
                # Per-tool rendering: diffs highlighted, terminal output in a
                # panel, reads as a line count, everything else truncated.
                if call["tool"]=="edit_file"and result.startswith("✅"):
                    d=result.split("\n",1)[1]if"\n"in result else""
                    if d:console.print(Syntax(d,"diff",theme="monokai"))
                    else:console.print(f"  [green]{result.split(chr(10))[0]}[/]")
                elif call["tool"]=="run_command":console.print(Panel(result[:500]+("..."if len(result)>500 else""),title="Terminal",border_style="dim"))
                elif call["tool"]=="read_file":console.print(f"  [dim]({result.count(chr(10))+1} lines)[/]")
                else:console.print(f"  [dim]{result[:200]}{'...'if len(result)>200 else''}[/]")
                results.append(f"[{call['tool']}] {result}")
            # Tool results go back to the model as a user-role message for the next round.
            messages.append({"role":"user","content":"Tool results:\n"+"\n\n".join(results)})
        console.print(f"\n[dim][green]y[/]=👍 [red]n[/]=👎 [yellow]e[/]=✏️ Enter=跳過[/]")
        fb=Prompt.ask("  ",choices=["y","n","e",""],default="",show_choices=False)
        if fb=="y":db.save(user_input,full_response,1,project=project_dir);console.print("  [green]👍[/]")
        elif fb=="n":db.save(user_input,full_response,0,project=project_dir);console.print("  [red]👎[/]")
        elif fb=="e":
            console.print("  [yellow]貼上修改版(END結束):[/]");lines=[]
            # Read pasted lines until a lone "END". Throwing StopIteration via a
            # fresh generator's .throw() is an obscure trick to escape the try
            # into the except clause below, ending the while loop.
            while True:
                try:l=input();(lines.append(l)if l.strip()!="END"else(_ for _ in()).throw(StopIteration))
                except(EOFError,StopIteration):break
            edited="\n".join(lines)
            # Edited answers are stored as positive examples with the correction attached.
            if edited.strip():db.save(user_input,full_response,1,edited=edited,project=project_dir);console.print("  [yellow]✏️[/]")
        # If the 10-round cap left dangling tool results, close the exchange
        # with an assistant message so the transcript stays role-alternating.
        if messages[-1]["role"]=="user"and"Tool results:"in messages[-1]["content"]:messages.append({"role":"assistant","content":full_response})
    console.print("\n[cyan]👋[/]")
def trigger_training(db,console,args):
    """Fine-tune a LoRA adapter from the collected feedback.

    Prefers SFT on user-edited completions when any exist; otherwise falls
    back to KTO on thumbs-up/down labels once at least 10 records exist.
    The adapter is saved under CONFIG_DIR with a timestamped directory name.
    """
    s=db.count()
    if s["total"]==0:console.print("[yellow]⚠️ 無數據[/]");return
    console.print(f"\n[bold]🚀[/] 👍:{s['up']} 👎:{s['total']-s['up']} ✏️:{s['edits']}")
    # Heavy ML dependencies are imported lazily, only when training is triggered.
    from datasets import Dataset;from transformers import AutoModelForCausalLM,AutoTokenizer,BitsAndBytesConfig;from peft import LoraConfig,prepare_model_for_kbit_training
    mn=args.model or DEFAULT_MODEL;od=os.path.join(CONFIG_DIR,f"adapter_{datetime.now().strftime('%Y%m%d_%H%M')}")
    # 4-bit NF4 quantization of the base model; LoRA on all attention + MLP projections.
    bnb=BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,bnb_4bit_use_double_quant=True)
    pc=LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"])
    sft,kto=db.export_sft(),db.export_kto()
    if sft:
        # SFT path: train on (prompt, user-corrected completion) pairs.
        console.print(f"[bold]📚 SFT({len(sft)})[/]");from trl import SFTTrainer,SFTConfig
        m=AutoModelForCausalLM.from_pretrained(mn,quantization_config=bnb,device_map="auto",trust_remote_code=True);t=AutoTokenizer.from_pretrained(mn)
        if t.pad_token is None:t.pad_token=t.eos_token
        m=prepare_model_for_kbit_training(m)
        SFTTrainer(model=m,args=SFTConfig(output_dir=od,learning_rate=2e-4,num_train_epochs=3,per_device_train_batch_size=1,gradient_accumulation_steps=8,max_seq_length=1024,gradient_checkpointing=True,bf16=True,optim="paged_adamw_8bit",logging_steps=5,save_total_limit=1,logging_strategy="steps",logging_first_step=True),processing_class=t,train_dataset=Dataset.from_list(sft),peft_config=pc).train()
        m.save_pretrained(od);del m;torch.cuda.empty_cache()
    elif len(kto)>=10:
        # KTO path: binary preference training from 👍/👎 labels.
        console.print(f"[bold]📚 KTO({len(kto)})[/]");from trl import KTOConfig,KTOTrainer
        m=AutoModelForCausalLM.from_pretrained(mn,quantization_config=bnb,device_map="auto",trust_remote_code=True);t=AutoTokenizer.from_pretrained(mn)
        if t.pad_token is None:t.pad_token=t.eos_token
        KTOTrainer(model=m,args=KTOConfig(output_dir=od,learning_rate=1e-5,num_train_epochs=1,per_device_train_batch_size=1,gradient_accumulation_steps=8,max_length=1024,gradient_checkpointing=True,bf16=True,logging_steps=5,logging_strategy="steps",logging_first_step=True),processing_class=t,train_dataset=Dataset.from_list(kto),peft_config=pc).train()
        m.save_pretrained(od)
    # NOTE(review): this success message also prints when neither branch trained
    # (no edits and fewer than 10 KTO records) — consider guarding it.
    console.print(f"\n[bold green]🎉[/] {od}\n   codepilot --adapter {od}")
def show_stats():
    """Print feedback statistics (total, up/down votes, edits) as a rich table."""
    from rich.console import Console
    from rich.table import Table
    stats = FeedbackDB().count()
    table = Table(title="📊 CodePilot")
    table.add_column("", style="cyan")
    table.add_column("", style="green")
    table.add_row("Total", str(stats["total"]))
    table.add_row("👍", str(stats["up"]))
    table.add_row("👎", str(stats["total"] - stats["up"]))
    table.add_row("✏️", str(stats["edits"]))
    Console().print(table)
def main():
    """CLI entry point: --stats prints statistics, --train fine-tunes, else start the REPL."""
    parser = argparse.ArgumentParser(description="CodePilot v2")
    parser.add_argument("--model", type=str)
    parser.add_argument("--adapter", type=str)
    parser.add_argument("--project", type=str)
    parser.add_argument("--stats", action="store_true")
    parser.add_argument("--train", action="store_true")
    opts = parser.parse_args()
    if opts.stats:
        show_stats()
    elif opts.train:
        from rich.console import Console
        trigger_training(FeedbackDB(), Console(), opts)
    else:
        run_agent_loop(opts)
if __name__=="__main__":main()