#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Code LLM data flywheel.

Use the model -> data is collected automatically -> once enough accumulates,
trigger training -> the model gets stronger.

Three collection modes:
1. Interactive chat — ask the model for code, then accept/reject/edit the
   answer; each decision becomes training data.
2. Git watcher      — monitor a Git repo; new commits become training data.
3. API service      — deploy as an API and log every request.

Usage:
    python code_llm_collector.py chat            # interactive mode
    python code_llm_collector.py watch --repo .  # Git watcher
    python code_llm_collector.py status          # show collection status
    python code_llm_collector.py train           # train on collected data
    python code_llm_collector.py export          # export to HuggingFace
"""

import argparse
import hashlib
import json
import os
import subprocess
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path

import torch

# Base model, optional LoRA adapter to stack on it, and HF account for export.
BASE_MODEL = "Qwen/Qwen2.5-Coder-3B"
ADAPTER_PATH = None
HF_USERNAME = "YOUR_HF_USERNAME"

# Collected samples are appended to these JSONL files (one object per line).
DATA_DIR = "./collected_data"
SFT_FILE = os.path.join(DATA_DIR, "sft_data.jsonl")
DPO_FILE = os.path.join(DATA_DIR, "dpo_data.jsonl")
GRPO_FILE = os.path.join(DATA_DIR, "grpo_data.jsonl")
META_FILE = os.path.join(DATA_DIR, "metadata.json")

# Number of NEW samples (since the last training run) that triggers the
# "time to train" reminder in check_auto_train().
AUTO_TRAIN_THRESHOLD = 100


def ensure_data_dir():
    """Create DATA_DIR and seed an initial metadata file if absent."""
    os.makedirs(DATA_DIR, exist_ok=True)
    if not os.path.exists(META_FILE):
        save_metadata({
            "total_sft": 0,
            "total_dpo": 0,
            "total_grpo": 0,
            "last_train": None,
            "train_count": 0,
            "created": datetime.now().isoformat(),
        })


def load_metadata():
    """Return the metadata dict, or {} when no metadata file exists yet."""
    if os.path.exists(META_FILE):
        # Explicit UTF-8: metadata may contain non-ASCII timestamps/notes and
        # the platform default encoding is not reliable.
        with open(META_FILE, encoding="utf-8") as f:
            return json.load(f)
    return {}


def save_metadata(meta):
    """Write *meta* to META_FILE as pretty-printed UTF-8 JSON."""
    with open(META_FILE, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)


def append_data(filepath, data):
    """Append one sample to a JSONL file (one JSON object per line)."""
    with open(filepath, "a", encoding="utf-8") as f:
        f.write(json.dumps(data, ensure_ascii=False) + "\n")


def count_lines(filepath):
    """Return the number of lines in *filepath*, or 0 if it does not exist."""
    if not os.path.exists(filepath):
        return 0
    with open(filepath, encoding="utf-8") as f:
        return sum(1 for _ in f)


def _read_multiline():
    """Read stdin lines until a line containing only END; join with newlines.

    Shared by the /edit and /test commands in run_chat().
    """
    lines = []
    while True:
        line = input()
        if line.strip() == "END":
            break
        lines.append(line)
    return "\n".join(lines)


# ============================================================
# Interactive mode
# ============================================================
def run_chat():
    """Interactive REPL: generate code with the model and record feedback.

    /accept stores the answer as an SFT sample; /edit stores the corrected
    answer as an SFT sample plus a DPO (chosen/rejected) pair; /test stores
    a GRPO sample with a pytest test; /reject discards the answer.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    from peft import PeftModel

    ensure_data_dir()
    print("""
╔════════════════════════════════════════════════════════════╗
║  Code LLM 互動模式 — 邊用邊收集數據                        ║
╠════════════════════════════════════════════════════════════╣
║  直接輸入問題 → 模型寫 code                                ║
║  /accept → 接受(存為 SFT 數據)                           ║
║  /edit   → 貼上修改版(產生 SFT + DPO 對)                 ║
║  /reject → 拒絕                                            ║
║  /test   → 加測試(產生 GRPO 數據)                        ║
║  /status → 查看收集狀態                                    ║
║  /train  → 用收集的數據訓練                                ║
║  /quit   → 退出                                            ║
╚════════════════════════════════════════════════════════════╝
""")
    print("📥 載入模型...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # 4-bit NF4 quantization so the model fits in modest GPU memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    if ADAPTER_PATH and os.path.exists(ADAPTER_PATH):
        model = PeftModel.from_pretrained(model, ADAPTER_PATH)
        print(f" LoRA: {ADAPTER_PATH}")
    model.eval()
    print("✅ 模型載入完成\n")

    meta = load_metadata()
    current_prompt = None
    current_response = None

    while True:
        try:
            user_input = input("🧑 你: ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not user_input:
            continue

        if user_input == "/quit":
            break
        elif user_input == "/status":
            show_status()
            continue
        elif user_input == "/train":
            trigger_training()
            continue
        elif user_input == "/accept":
            # Only meaningful when there is a pending question/answer pair.
            if current_prompt and current_response:
                append_data(SFT_FILE, {
                    "messages": [
                        {"role": "user", "content": current_prompt},
                        {"role": "assistant", "content": current_response},
                    ],
                    "timestamp": datetime.now().isoformat(),
                    "source": "chat_accepted",
                })
                meta["total_sft"] = meta.get("total_sft", 0) + 1
                save_metadata(meta)
                print(f" ✅ SFT +1 (累計: {meta['total_sft']})")
                check_auto_train(meta)
            continue
        elif user_input == "/reject":
            print(" ❌ 已拒絕")
            current_response = None
            continue
        elif user_input == "/edit":
            if current_prompt and current_response:
                print(" 貼上修改後的 code(輸入 END 結束):")
                edited_code = _read_multiline()
                if edited_code.strip():
                    # The human-edited version is "chosen", the raw model
                    # output "rejected" — a preference pair for DPO.
                    append_data(DPO_FILE, {
                        "prompt": [{"role": "user", "content": current_prompt}],
                        "chosen": [{"role": "assistant", "content": edited_code}],
                        "rejected": [{"role": "assistant", "content": current_response}],
                        "timestamp": datetime.now().isoformat(),
                        "source": "chat_edited",
                    })
                    # The edited version is also a good SFT target on its own.
                    append_data(SFT_FILE, {
                        "messages": [
                            {"role": "user", "content": current_prompt},
                            {"role": "assistant", "content": edited_code},
                        ],
                        "timestamp": datetime.now().isoformat(),
                        "source": "chat_edited_sft",
                    })
                    meta["total_dpo"] = meta.get("total_dpo", 0) + 1
                    meta["total_sft"] = meta.get("total_sft", 0) + 1
                    save_metadata(meta)
                    print(f" ✅ DPO +1 / SFT +1 (DPO:{meta['total_dpo']} SFT:{meta['total_sft']})")
                    check_auto_train(meta)
            continue
        elif user_input == "/test":
            if current_prompt and current_response:
                print(" 貼上 pytest 測試(輸入 END 結束):")
                test_code = _read_multiline()
                if test_code.strip():
                    append_data(GRPO_FILE, {
                        "prompt": [{"role": "user", "content": current_prompt}],
                        "solution": current_response,
                        "test": test_code,
                        "timestamp": datetime.now().isoformat(),
                        "source": "chat_test",
                    })
                    meta["total_grpo"] = meta.get("total_grpo", 0) + 1
                    save_metadata(meta)
                    print(f" ✅ GRPO +1 (累計: {meta['total_grpo']})")
            continue

        # Anything else is a question: generate an answer.
        current_prompt = user_input
        messages = [
            {"role": "system", "content": "You are an exceptionally skilled programmer. Write clean, efficient, well-documented code."},
            {"role": "user", "content": user_input},
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        current_response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
        )
        print(f"\n🤖 模型:\n{current_response}\n\n 💡 /accept | /edit | /reject | /test\n")


# ============================================================
# Git watcher mode
# ============================================================
def run_watch(repo_path):
    """Poll a Git repo every 30 s; new commits' .py files become SFT samples.

    Processed commit hashes are persisted in seen_commits.json so restarts
    do not re-ingest old commits.
    """
    ensure_data_dir()
    repo_path = os.path.abspath(repo_path)
    print(f" 👀 監控: {repo_path}")
    meta = load_metadata()
    seen_file = os.path.join(DATA_DIR, "seen_commits.json")
    if os.path.exists(seen_file):
        with open(seen_file, encoding="utf-8") as f:
            seen = set(json.load(f))
    else:
        seen = set()
    print(f" 已處理: {len(seen)} commits\n 監控中... (Ctrl+C 停止)\n")
    while True:
        try:
            r = subprocess.run(
                ["git", "log", "--oneline", "-20", "--format=%H %s"],
                cwd=repo_path, capture_output=True, text=True,
            )
            for line in r.stdout.strip().split("\n"):
                if not line.strip():
                    continue
                parts = line.split(" ", 1)
                h = parts[0]
                msg = parts[1] if len(parts) > 1 else ""
                if h in seen:
                    continue
                dr = subprocess.run(
                    ["git", "diff", f"{h}~1", h, "--name-only"],
                    cwd=repo_path, capture_output=True, text=True,
                )
                for fname in [x for x in dr.stdout.strip().split("\n") if x.endswith(".py")]:
                    try:
                        fr = subprocess.run(
                            ["git", "show", f"{h}:{fname}"],
                            cwd=repo_path, capture_output=True, text=True,
                        )
                        code = fr.stdout
                        # Skip trivially small and oversized files.
                        if 50 < len(code) < 10000:
                            append_data(SFT_FILE, {
                                "messages": [
                                    {"role": "user", "content": f"Write: {fname}\nCommit: {msg}"},
                                    {"role": "assistant", "content": code},
                                ],
                                "timestamp": datetime.now().isoformat(),
                                "source": "git",
                                "commit": h[:8],
                                "file": fname,
                            })
                            meta["total_sft"] = meta.get("total_sft", 0) + 1
                            print(f" 📝 {h[:8]} | {fname} → SFT ({meta['total_sft']})")
                    except OSError:
                        # Best effort: skip files git cannot show; a bare
                        # except here would also swallow KeyboardInterrupt.
                        pass
                seen.add(h)
            save_metadata(meta)
            with open(seen_file, "w", encoding="utf-8") as f:
                json.dump(list(seen), f)
            check_auto_train(meta)
            time.sleep(30)
        except KeyboardInterrupt:
            print("\n⏹️ 已停止")
            break
        except Exception as e:
            print(f" ⚠️ {e}")
            time.sleep(30)


def show_status():
    """Print sample counts per dataset and progress toward the train threshold."""
    ensure_data_dir()
    meta = load_metadata()
    s, d, g = count_lines(SFT_FILE), count_lines(DPO_FILE), count_lines(GRPO_FILE)
    t = s + d + g
    print(f"""
📊 數據收集狀態
─────────────────────────────
 SFT: {s:>5} 條 {'█'*min(s//5,30)}
 DPO: {d:>5} 條 {'█'*min(d//5,30)}
 GRPO: {g:>5} 條 {'█'*min(g//5,30)}
─────────────────────────────
 總計: {t:>5} 條
 自動訓練門檻: {AUTO_TRAIN_THRESHOLD} 條
 距下次訓練: {max(0,AUTO_TRAIN_THRESHOLD-t)} 條
 已訓練次數: {meta.get('train_count',0)} 次
""")


def check_auto_train(meta):
    """Remind the user to train once enough NEW samples have accumulated.

    'New' means samples added since the last training run, measured against
    the same SFT+DPO+GRPO total that trigger_training() records.
    """
    total = count_lines(SFT_FILE) + count_lines(DPO_FILE) + count_lines(GRPO_FILE)
    new = total - meta.get("last_train_total", 0)
    if new >= AUTO_TRAIN_THRESHOLD:
        print(f"\n 🔔 累積 {new} 條新數據!運行 python code_llm_collector.py train")


def trigger_training():
    """Run a 4-bit QLoRA SFT pass on the collected data and update metadata."""
    ensure_data_dir()
    meta = load_metadata()
    s, d = count_lines(SFT_FILE), count_lines(DPO_FILE)
    if s + d == 0:
        print(" ⚠️ 無數據")
        return
    print(f"\n🚀 訓練中... SFT:{s} DPO:{d}")
    from datasets import Dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, PeftModel

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    if s > 0:
        from trl import SFTTrainer, SFTConfig
        with open(SFT_FILE, encoding="utf-8") as f:
            data = [json.loads(line) for line in f]
        ds = Dataset.from_list([{"messages": x["messages"]} for x in data])
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL, quantization_config=bnb,
            device_map="auto", trust_remote_code=True,
        )
        if ADAPTER_PATH and os.path.exists(ADAPTER_PATH):
            # Continue training the existing adapter instead of starting fresh.
            model = PeftModel.from_pretrained(model, ADAPTER_PATH, is_trainable=True)
        else:
            model = prepare_model_for_kbit_training(model)
            model = get_peft_model(model, LoraConfig(
                r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
                task_type="CAUSAL_LM",
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                                "gate_proj", "up_proj", "down_proj"],
            ))
        td = os.path.join(DATA_DIR, f"train_{datetime.now().strftime('%Y%m%d_%H%M')}")
        trainer = SFTTrainer(
            model=model,
            args=SFTConfig(
                output_dir=td, learning_rate=2e-4, num_train_epochs=3,
                per_device_train_batch_size=1, gradient_accumulation_steps=8,
                max_seq_length=1024, gradient_checkpointing=True, bf16=True,
                optim="paged_adamw_8bit", logging_steps=10, save_total_limit=1,
                logging_strategy="steps", logging_first_step=True,
            ),
            processing_class=tokenizer,
            train_dataset=ds,
        )
        trainer.train()
        trainer.save_model(td)
        print(f" ✅ SFT: {td}")
        del model
        torch.cuda.empty_cache()
    meta["last_train"] = datetime.now().isoformat()
    meta["train_count"] = meta.get("train_count", 0) + 1
    # Record the FULL sample count (incl. GRPO) so check_auto_train's
    # "new since last train" delta is measured against the same total.
    meta["last_train_total"] = s + d + count_lines(GRPO_FILE)
    save_metadata(meta)
    print(f"\n🎉 第 {meta['train_count']} 次訓練完成!")


def export_dataset():
    """Push collected SFT/DPO data to private HuggingFace datasets."""
    ensure_data_dir()
    s, d = count_lines(SFT_FILE), count_lines(DPO_FILE)
    if s + d == 0:
        print(" ⚠️ 無數據")
        return
    from datasets import Dataset
    if s > 0:
        with open(SFT_FILE, encoding="utf-8") as f:
            ds = Dataset.from_list([json.loads(line) for line in f])
        n = f"{HF_USERNAME}/my-code-sft-data"
        ds.push_to_hub(n, private=True)
        print(f" ✅ SFT: https://huggingface.co/datasets/{n}")
    if d > 0:
        with open(DPO_FILE, encoding="utf-8") as f:
            ds = Dataset.from_list([json.loads(line) for line in f])
        n = f"{HF_USERNAME}/my-code-dpo-data"
        ds.push_to_hub(n, private=True)
        print(f" ✅ DPO: https://huggingface.co/datasets/{n}")


def main():
    """Parse the CLI mode and dispatch to the matching entry point."""
    parser = argparse.ArgumentParser(description="Code LLM 數據飛輪")
    parser.add_argument("mode", choices=["chat", "watch", "status", "train", "export"])
    parser.add_argument("--repo", type=str, default=".")
    args = parser.parse_args()
    {
        "chat": run_chat,
        "watch": lambda: run_watch(args.repo),
        "status": show_status,
        "train": trigger_training,
        "export": export_dataset,
    }[args.mode]()


if __name__ == "__main__":
    main()