Justin-lee commited on
Commit
d341b49
·
verified ·
1 Parent(s): dc43d1a

Add CodePilot CLI tool with feedback collection

Browse files
Files changed (1) hide show
  1. codepilot.py +248 -0
codepilot.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CodePilot — your personal AI coding assistant
=============================================

A terminal CLI tool in the spirit of Claude Code, featuring:
🤖 local Qwen2.5-Coder model inference
📁 reading/modifying your project files
👍👎 one-keystroke feedback with automatic training-data collection
🔄 periodic training on your feedback so the model improves with use

Install:
    pip install transformers peft bitsandbytes accelerate trl datasets rich

Usage:
    python codepilot.py                         # start
    python codepilot.py --project ~/my-project  # point at a project
    python codepilot.py --adapter ./my-adapter  # use a fine-tuned adapter
    python codepilot.py --stats                 # feedback statistics
    python codepilot.py --train                 # train the model
"""
import argparse, json, os, sqlite3, subprocess, sys, textwrap, torch
from datetime import datetime
from pathlib import Path

# Base model used whenever --model is not given.
DEFAULT_MODEL = "Qwen/Qwen2.5-Coder-3B-Instruct"
# Per-user state (feedback DB, trained adapters) lives under ~/.codepilot.
CONFIG_DIR = os.path.expanduser("~/.codepilot")
DB_PATH = os.path.join(CONFIG_DIR, "feedback.db")
# Every this-many accumulated feedback records, the CLI nudges the user to retrain.
AUTO_TRAIN_THRESHOLD = 50
31
+
32
class FeedbackDB:
    """SQLite-backed store for user feedback on model completions.

    Each row records a prompt/completion pair, a thumbs-up/down label,
    an optional user-edited completion, and project context. Export
    helpers reshape the rows into TRL-compatible KTO / SFT / DPO records.
    """

    def __init__(self, db_path=DB_PATH):
        """Open (creating if necessary) the feedback database at *db_path*."""
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        self.conn = sqlite3.connect(db_path)
        self.conn.execute("""CREATE TABLE IF NOT EXISTS feedback (
            id INTEGER PRIMARY KEY, timestamp TEXT, prompt TEXT, completion TEXT,
            label INTEGER, edited_completion TEXT, project_dir TEXT, files_context TEXT)""")
        self.conn.commit()

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def save(self, prompt, completion, label, edited=None, project_dir=None, files=None):
        """Persist one feedback record.

        *label* is truthy for thumbs-up; *edited* is the user's corrected
        completion (if any); *files* (a list) is serialized to JSON.
        """
        # Name the columns explicitly so the insert survives schema evolution
        # (the original positional form breaks if a column is ever added).
        self.conn.execute(
            "INSERT INTO feedback (timestamp, prompt, completion, label, "
            "edited_completion, project_dir, files_context) VALUES (?,?,?,?,?,?,?)",
            (datetime.now().isoformat(), prompt, completion, int(label), edited,
             project_dir, json.dumps(files) if files else None))
        self.conn.commit()

    def count(self):
        """Return aggregate stats: total rows, thumbs-up count, edited count."""
        r = self.conn.execute(
            "SELECT COUNT(*), SUM(label), "
            "SUM(CASE WHEN edited_completion IS NOT NULL THEN 1 ELSE 0 END) "
            "FROM feedback").fetchone()
        # SUM() is NULL on an empty table — coalesce everything to 0.
        return {"total": r[0] or 0, "thumbs_up": r[1] or 0, "edits": r[2] or 0}

    def export_kto(self):
        """Export every row in TRL KTO format (prompt / completion / bool label)."""
        rows = self.conn.execute("SELECT prompt, completion, label FROM feedback").fetchall()
        return [{"prompt": [{"role": "user", "content": p}],
                 "completion": [{"role": "assistant", "content": c}],
                 "label": bool(l)} for p, c, l in rows]

    def export_sft(self):
        """Export user-edited completions as SFT chat transcripts."""
        rows = self.conn.execute(
            "SELECT prompt, edited_completion FROM feedback "
            "WHERE edited_completion IS NOT NULL").fetchall()
        return [{"messages": [{"role": "user", "content": p},
                              {"role": "assistant", "content": c}]} for p, c in rows]

    def export_dpo(self):
        """Export edited rows as DPO pairs: the edit is chosen, the original rejected."""
        rows = self.conn.execute(
            "SELECT prompt, completion, edited_completion FROM feedback "
            "WHERE edited_completion IS NOT NULL").fetchall()
        return [{"prompt": [{"role": "user", "content": p}],
                 "chosen": [{"role": "assistant", "content": e}],
                 "rejected": [{"role": "assistant", "content": o}]} for p, o, e in rows]
62
+
63
class CodeModel:
    """Wraps a causal LM and its tokenizer for chat-style code generation."""

    def __init__(self, model_name=DEFAULT_MODEL, adapter_path=None):
        """Load *model_name* (optionally layering a LoRA adapter) in eval mode."""
        from transformers import AutoTokenizer, AutoModelForCausalLM

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16,
            device_map="auto", trust_remote_code=True)
        # Apply a fine-tuned adapter on top of the base weights when one exists.
        if adapter_path and os.path.exists(adapter_path):
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model, adapter_path)
        self.model.eval()

    def generate(self, user_message, system_prompt=None, file_context=None, max_tokens=2048):
        """Run one chat turn and return the assistant's decoded reply."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if file_context:
            # Inject the file contents as a prior exchange so the model treats
            # them as established conversational context.
            messages.append({"role": "user", "content": f"相關文件:\n\n{file_context}"})
            messages.append({"role": "assistant", "content": "已了解。請問需要什麼幫助?"})
        messages.append({"role": "user", "content": user_message})

        chat_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        encoded = self.tokenizer(chat_text, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated = self.model.generate(
                **encoded, max_new_tokens=max_tokens, do_sample=True,
                temperature=0.7, top_p=0.9, repetition_penalty=1.1,
                pad_token_id=self.tokenizer.pad_token_id)
        # Strip the prompt tokens; decode only the newly generated tail.
        prompt_len = encoded["input_ids"].shape[1]
        return self.tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
86
+
87
def extract_code_blocks(text):
    """Return ``(language, code)`` pairs for every ``` fenced block in *text*.

    Splitting on ``` leaves fenced bodies at the odd indices; the first
    line of each body names the language, defaulting to "python".
    """
    blocks = []
    for fenced in text.split("```")[1::2]:
        header, _, body = fenced.partition("\n")
        blocks.append((header.strip() or "python", body.strip()))
    return blocks
96
+
97
def run_cli(args):
    """Interactive REPL: chat with the model and collect 👍/👎/edit feedback.

    Slash commands: /file loads a file into context, /apply saves code
    blocks from the last response, /stats shows feedback counts, /train
    launches fine-tuning, /quit exits.
    """
    from rich.console import Console
    from rich.markdown import Markdown
    from rich.panel import Panel
    from rich.prompt import Prompt
    from rich.syntax import Syntax
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    console.print(Panel.fit(
        "[bold cyan]CodePilot[/] — 你的專屬 AI 程式助手\n"
        + f"[dim]Model: {args.model or DEFAULT_MODEL}[/]",
        border_style="cyan"))

    with console.status("[bold green]載入模型中..."):
        model = CodeModel(args.model or DEFAULT_MODEL, args.adapter)
    console.print("[green]✅ 模型載入完成[/]\n")
    console.print("[dim]指令: /file <path> 讀文件 | /apply 套用code | /stats 統計 | /train 訓練 | /quit 退出[/]\n")

    system_prompt = ("You are CodePilot, an expert programming assistant. "
                     "Write clean, efficient, well-documented code. "
                     "When modifying existing code, show the complete modified version.")
    project_dir = args.project or os.getcwd()
    file_context = current_response = current_prompt = None

    while True:
        try:
            user_input = Prompt.ask("\n[bold green]🧑 You")
        except (EOFError, KeyboardInterrupt):
            break
        if not user_input.strip():
            continue
        cmd = user_input.strip().lower()

        if cmd in ("/quit", "/exit"):
            break
        elif cmd == "/stats":
            stats = db.count()
            t = Table(title="📊 回饋統計")
            t.add_column("指標", style="cyan")
            t.add_column("數值", style="green")
            t.add_row("總回饋", str(stats["total"]))
            t.add_row("👍", str(stats["thumbs_up"]))
            t.add_row("👎", str(stats["total"] - stats["thumbs_up"]))
            t.add_row("✏️修改", str(stats["edits"]))
            console.print(t)
            continue
        elif cmd == "/train":
            trigger_cli_training(db, console, args)
            continue
        elif cmd.startswith("/file "):
            fp = os.path.join(project_dir, user_input[6:].strip())
            if os.path.exists(fp):
                with open(fp) as f:
                    content = f.read()
                file_context = f"--- {fp} ---\n{content}\n--- END ---"
                console.print(f"[green]📁 已讀取: {fp} ({len(content)} chars)[/]")
            else:
                console.print(f"[red]❌ 不存在: {fp}[/]")
            continue
        elif cmd == "/apply":
            if current_response:
                # Show each extracted block and offer to save it to a file.
                for lang, code in extract_code_blocks(current_response):
                    console.print(Syntax(code, lang or "python", theme="monokai"))
                    fp = Prompt.ask(" 儲存到? (Enter跳過)")
                    if fp.strip():
                        full = os.path.join(project_dir, fp)
                        os.makedirs(os.path.dirname(full) or ".", exist_ok=True)
                        # Context manager guarantees the file is flushed and
                        # closed (the original left the handle dangling).
                        with open(full, "w") as f:
                            f.write(code)
                        console.print(f" [green]✅ {full}[/]")
            continue

        # Anything else is a chat turn for the model.
        current_prompt = user_input
        with console.status("[bold cyan]思考中..."):
            current_response = model.generate(
                user_input, system_prompt=system_prompt, file_context=file_context)
        console.print("\n[bold blue]🤖 CodePilot:[/]")
        console.print(Markdown(current_response))
        console.print("\n[dim][green]y[/]=👍 [red]n[/]=👎 [yellow]e[/]=✏️修改 Enter=跳過[/]")
        fb = Prompt.ask(" ", choices=["y", "n", "e", ""], default="", show_choices=False)

        if fb == "y":
            db.save(current_prompt, current_response, label=1, project_dir=project_dir)
            s = db.count()
            console.print(f" [green]👍 +1 (累計:{s['total']})[/]")
        elif fb == "n":
            db.save(current_prompt, current_response, label=0, project_dir=project_dir)
            s = db.count()
            console.print(f" [red]👎 +1 (累計:{s['total']})[/]")
        elif fb == "e":
            # Collect a multi-line corrected completion, terminated by "END".
            console.print(" [yellow]貼上修改版 (END 結束):[/]")
            lines = []
            while True:
                try:
                    l = input()
                except EOFError:
                    break
                if l.strip() == "END":
                    break
                lines.append(l)
            edited = "\n".join(lines)
            if edited.strip():
                db.save(current_prompt, current_response, label=1,
                        edited=edited, project_dir=project_dir)
                s = db.count()
                console.print(f" [yellow]✏️ +1 (累計:{s['total']}, 修改:{s['edits']})[/]")

        # One COUNT query instead of three (the original re-queried per check).
        total = db.count()["total"]
        if total > 0 and total % AUTO_TRAIN_THRESHOLD == 0:
            console.print(f"\n [bold yellow]🔔 累積 {total} 條!codepilot --train[/]")

    console.print("\n[cyan]👋 再見![/]")
183
+
184
def trigger_cli_training(db, console, args):
    """Fine-tune the base model on collected feedback.

    Prefers SFT on user-edited completions when any exist; otherwise falls
    back to KTO on thumbs-up/down labels once at least 10 are available.
    The trained LoRA adapter is written under CONFIG_DIR.
    """
    stats = db.count()
    if stats["total"] == 0:
        console.print("[yellow]⚠️ 無數據[/]")
        return
    console.print(f"\n[bold]🚀 訓練[/] 👍:{stats['thumbs_up']} 👎:{stats['total']-stats['thumbs_up']} ✏️:{stats['edits']}")

    from datasets import Dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    from peft import LoraConfig, prepare_model_for_kbit_training

    model_name = args.model or DEFAULT_MODEL
    output_dir = os.path.join(CONFIG_DIR, f"adapter_{datetime.now().strftime('%Y%m%d_%H%M')}")
    # 4-bit NF4 quantization keeps the base model small enough for one GPU.
    bnb = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
    peft_cfg = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"])

    sft_data = db.export_sft()
    kto_data = db.export_kto()
    trained = False  # fix: the original reported success even when no branch ran

    if sft_data:
        console.print(f"\n[bold]📚 SFT ({len(sft_data)} edits)...[/]")
        from trl import SFTTrainer, SFTConfig
        ds = Dataset.from_list(sft_data)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=bnb, device_map="auto", trust_remote_code=True)
        tok = AutoTokenizer.from_pretrained(model_name)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        model = prepare_model_for_kbit_training(model)
        trainer = SFTTrainer(
            model=model,
            args=SFTConfig(output_dir=output_dir, learning_rate=2e-4, num_train_epochs=3,
                           per_device_train_batch_size=1, gradient_accumulation_steps=8,
                           max_seq_length=1024, gradient_checkpointing=True, bf16=True,
                           optim="paged_adamw_8bit", logging_steps=5, save_total_limit=1,
                           logging_strategy="steps", logging_first_step=True),
            processing_class=tok, train_dataset=ds, peft_config=peft_cfg)
        trainer.train()
        trainer.save_model(output_dir)
        console.print("[green]✅ SFT 完成[/]")
        # Free GPU memory before anything else runs.
        del model
        torch.cuda.empty_cache()
        trained = True
    elif len(kto_data) >= 10:
        console.print(f"\n[bold]📚 KTO ({len(kto_data)} feedbacks)...[/]")
        from trl import KTOConfig, KTOTrainer
        ds = Dataset.from_list(kto_data)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=bnb, device_map="auto", trust_remote_code=True)
        tok = AutoTokenizer.from_pretrained(model_name)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        trainer = KTOTrainer(
            model=model,
            args=KTOConfig(output_dir=output_dir, learning_rate=1e-5, num_train_epochs=1,
                           per_device_train_batch_size=1, gradient_accumulation_steps=8,
                           max_length=1024, gradient_checkpointing=True, bf16=True,
                           logging_steps=5, logging_strategy="steps", logging_first_step=True),
            processing_class=tok, train_dataset=ds, peft_config=peft_cfg)
        trainer.train()
        trainer.save_model(output_dir)
        console.print("[green]✅ KTO 完成[/]")
        trained = True

    if trained:
        console.print(f"\n[bold green]🎉 訓練完成![/]\n Adapter: {output_dir}\n 重啟: codepilot --adapter {output_dir}")
    else:
        # No edits and fewer than 10 labeled feedbacks: nothing was trained,
        # so don't advertise a nonexistent adapter.
        console.print("[yellow]⚠️ 數據不足:需要修改範例或至少 10 條回饋才能訓練[/]")
222
+
223
def show_stats():
    """Print a summary table of collected feedback plus the acceptance rate."""
    from rich.console import Console
    from rich.table import Table

    console = Console()
    db = FeedbackDB()
    s = db.count()
    up = s["thumbs_up"]
    down = s["total"] - up

    def bar(n):
        # Crude horizontal bar: one cell per two records, capped at 40.
        return "█" * min(n // 2, 40)

    t = Table(title="📊 CodePilot 回饋統計")
    t.add_column("指標", style="cyan")
    t.add_column("數值", style="green")
    t.add_column("", style="dim")
    t.add_row("總回饋", str(s["total"]), bar(s["total"]))
    t.add_row("👍", str(up), bar(up))
    t.add_row("👎", str(down), bar(down))
    t.add_row("✏️修改", str(s["edits"]), bar(s["edits"]))
    console.print(t)

    if s["total"] > 0:
        r = up / s["total"] * 100
        console.print(f"\n 接受率: {r:.0f}%")
        if r < 50:
            console.print(" [yellow]💡 接受率低,建議 --train[/]")
235
+
236
def main():
    """Parse CLI flags and dispatch to stats display, training, or the REPL."""
    parser = argparse.ArgumentParser(description="CodePilot — 你的專屬 AI 程式助手")
    parser.add_argument("--model", type=str, default=None, help=f"模型 (預設:{DEFAULT_MODEL})")
    parser.add_argument("--adapter", type=str, default=None, help="LoRA adapter")
    parser.add_argument("--project", type=str, default=None, help="專案目錄")
    parser.add_argument("--stats", action="store_true", help="回饋統計")
    parser.add_argument("--train", action="store_true", help="訓練模型")
    a = parser.parse_args()

    if a.stats:
        show_stats()
    elif a.train:
        from rich.console import Console
        trigger_cli_training(FeedbackDB(), Console(), a)
    else:
        run_cli(a)
247
+
248
if __name__ == "__main__":
    main()