sandbox-5ca717e4 / code_llm_collector.py
Justin-lee's picture
Add data flywheel collector system
dc43d1a verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Code LLM 數據飛輪系統 (Data Flywheel)
=======================================
使用模型時自動收集數據 → 累積到一定量 → 自動觸發訓練 → 模型變更強
三種收集模式：
1. 互動模式 — 你問模型寫 code，接受/拒絕/修改回答 → 自動產生訓練數據
2. Git 監控 — 監控你的 Git repo，新 commit 自動變成訓練數據
3. API 服務 — 部署成 API，每次請求自動記錄（本腳本的命令列尚未提供此模式）
Usage:
python code_llm_collector.py chat # 互動模式
python code_llm_collector.py watch --repo . # Git 監控
python code_llm_collector.py status # 查看狀態
python code_llm_collector.py train # 用收集的數據訓練
python code_llm_collector.py export # 匯出到 HuggingFace
"""
import argparse, json, os, subprocess, sys, tempfile, time, hashlib, torch
from datetime import datetime
from pathlib import Path
# Base checkpoint loaded by both chat and training modes.
# Override with the CODE_LLM_BASE_MODEL environment variable.
BASE_MODEL = os.environ.get("CODE_LLM_BASE_MODEL", "Qwen/Qwen2.5-Coder-3B")
# Optional LoRA adapter directory applied on top of BASE_MODEL (chat/train only
# use it when the path exists). None means "base model only".
# Override with CODE_LLM_ADAPTER_PATH; an empty value is treated as unset.
ADAPTER_PATH = os.environ.get("CODE_LLM_ADAPTER_PATH") or None
# Hugging Face account that `export` pushes datasets to. The placeholder must be
# replaced, here or via the HF_USERNAME environment variable, before exporting.
HF_USERNAME = os.environ.get("HF_USERNAME", "YOUR_HF_USERNAME")
# Directory holding all collected JSONL data plus bookkeeping files.
DATA_DIR = "./collected_data"
SFT_FILE = os.path.join(DATA_DIR, "sft_data.jsonl")    # chat messages rows
DPO_FILE = os.path.join(DATA_DIR, "dpo_data.jsonl")    # prompt/chosen/rejected rows
GRPO_FILE = os.path.join(DATA_DIR, "grpo_data.jsonl")  # prompt/solution/test rows
META_FILE = os.path.join(DATA_DIR, "metadata.json")    # counters + training history
# Number of new rows (all three files) since the last training run after which
# check_auto_train prints a reminder to run `train`.
AUTO_TRAIN_THRESHOLD = 100
def ensure_data_dir():
    """Create DATA_DIR if needed and seed META_FILE the first time it is missing.

    The seeded metadata has zeroed counters, no last-training timestamp, a
    training count of 0 and the current time as its creation stamp.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    if os.path.exists(META_FILE):
        return
    initial_meta = {
        "total_sft": 0,
        "total_dpo": 0,
        "total_grpo": 0,
        "last_train": None,
        "train_count": 0,
        "created": datetime.now().isoformat(),
    }
    save_metadata(initial_meta)
def load_metadata():
    """Return the metadata dict stored in META_FILE, or {} when the file is absent.

    The file is read as UTF-8 explicitly, because save_metadata writes it with
    ensure_ascii=False and the platform default encoding may differ.
    """
    if not os.path.exists(META_FILE):
        return {}
    with open(META_FILE, encoding="utf-8") as f:
        return json.load(f)
def save_metadata(meta):
    """Persist *meta* to META_FILE as indented UTF-8 JSON (non-ASCII kept verbatim).

    The JSON is first written to a temporary file in the same directory. It is
    then moved over META_FILE with os.replace, so an interruption mid-write
    (e.g. Ctrl+C in watch mode) cannot leave a truncated, unparsable metadata
    file behind.
    """
    directory = os.path.dirname(META_FILE) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".metadata-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, META_FILE)
    finally:
        # Only left behind if something failed before os.replace.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
def append_data(filepath, data):
    """Append *data* to *filepath* as a single JSON line (UTF-8, non-ASCII unescaped)."""
    serialized = json.dumps(data, ensure_ascii=False)
    with open(filepath, "a", encoding="utf-8") as out:
        out.write(f"{serialized}\n")
def count_lines(filepath):
    """Return the number of lines in *filepath*, or 0 if it does not exist.

    The file is read in binary mode so the count never depends on the locale's
    default text encoding. append_data writes UTF-8 with non-ASCII characters
    unescaped, which a non-UTF-8 default codec (e.g. cp950/cp1252 on Windows)
    can fail to decode. Rows written by append_data end in "\\n", so the line
    count is unchanged.
    """
    if not os.path.exists(filepath):
        return 0
    with open(filepath, "rb") as f:
        return sum(1 for _ in f)
# ============================================================
# 互動模式 (interactive chat mode)
# ============================================================
def _read_until_end(prompt):
    """Print *prompt*, then read stdin lines until one equals "END" (after strip).

    Returns the collected lines joined with newlines. If input ends (EOF) or the
    user presses Ctrl+C before "END", returns "" so the caller treats it as a
    cancelled entry instead of the whole chat session crashing.
    """
    print(prompt)
    lines = []
    while True:
        try:
            line = input()
        except (EOFError, KeyboardInterrupt):
            return ""
        if line.strip() == "END":
            break
        lines.append(line)
    return "\n".join(lines)


def run_chat():
    """Interactive loop: the model writes code, and user feedback becomes training data.

    Commands:
      /accept  store the last (prompt, response) pair as an SFT row
      /edit    read a corrected version and store it as an SFT row plus a DPO
               pair (edited = chosen, model output = rejected)
      /reject  discard the current response
      /test    read a test snippet and store a GRPO row (prompt, solution, test)
      /status  call show_status();  /train  call trigger_training()
      /quit    exit (EOF / Ctrl+C at the prompt also exits)
    Any other input is sent to the model as a new prompt.

    A given response can be stored via /accept or /edit only once, so repeating
    a command does not write duplicate training rows.
    NOTE(review): /train loads a second model while the chat model is still
    resident; this may exhaust GPU memory.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    from peft import PeftModel
    ensure_data_dir()
    print("""
โ•”โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•—
โ•‘ Code LLM ไบ’ๅ‹•ๆจกๅผ โ€” ้‚Š็”จ้‚Šๆ”ถ้›†ๆ•ธๆ“š โ•‘
โ• โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•ฃ
โ•‘ ็›ดๆŽฅ่ผธๅ…ฅๅ•้กŒ โ†’ ๆจกๅž‹ๅฏซ code โ•‘
โ•‘ /accept โ†’ ๆŽฅๅ—๏ผˆๅญ˜็‚บ SFT ๆ•ธๆ“š๏ผ‰ โ•‘
โ•‘ /edit โ†’ ่ฒผไธŠไฟฎๆ”น็‰ˆ๏ผˆ็”ข็”Ÿ SFT + DPO ๅฐ๏ผ‰ โ•‘
โ•‘ /reject โ†’ ๆ‹’็ต• โ•‘
โ•‘ /test โ†’ ๅŠ ๆธฌ่ฉฆ๏ผˆ็”ข็”Ÿ GRPO ๆ•ธๆ“š๏ผ‰ โ•‘
โ•‘ /status โ†’ ๆŸฅ็œ‹ๆ”ถ้›†็‹€ๆ…‹ โ•‘
โ•‘ /train โ†’ ็”จๆ”ถ้›†็š„ๆ•ธๆ“š่จ“็ทด โ•‘
โ•‘ /quit โ†’ ้€€ๅ‡บ โ•‘
โ•šโ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
""")
    print("๐Ÿ“ฅ ่ผ‰ๅ…ฅๆจกๅž‹...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, quantization_config=bnb_config, device_map="auto", trust_remote_code=True
    )
    if ADAPTER_PATH and os.path.exists(ADAPTER_PATH):
        model = PeftModel.from_pretrained(model, ADAPTER_PATH)
        print(f" LoRA: {ADAPTER_PATH}")
    model.eval()
    print("โœ… ๆจกๅž‹่ผ‰ๅ…ฅๅฎŒๆˆ\n")

    already_recorded = " This response was already saved (/accept or /edit); ask a new question first."
    meta = load_metadata()
    current_prompt = None
    current_response = None
    # True once the current response has been stored via /accept or /edit.
    recorded = False
    while True:
        try:
            user_input = input("๐Ÿง‘ ไฝ : ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not user_input:
            continue
        if user_input == "/quit":
            break
        if user_input == "/status":
            show_status()
            continue
        if user_input == "/train":
            trigger_training()
            # trigger_training updates META_FILE itself; reload so the next save
            # here does not overwrite train_count / last_train_total with stale values.
            meta = load_metadata()
            continue
        if user_input == "/reject":
            print(" โŒ ๅทฒๆ‹’็ต•")
            current_response = None
            continue
        if user_input in ("/accept", "/edit", "/test"):
            if not (current_prompt and current_response):
                continue  # nothing generated yet, or it was rejected
            if user_input in ("/accept", "/edit") and recorded:
                print(already_recorded)
                continue

        if user_input == "/accept":
            append_data(SFT_FILE, {"messages": [{"role": "user", "content": current_prompt}, {"role": "assistant", "content": current_response}], "timestamp": datetime.now().isoformat(), "source": "chat_accepted"})
            meta["total_sft"] = meta.get("total_sft", 0) + 1
            save_metadata(meta)
            recorded = True
            print(f" โœ… SFT +1 (็ดฏ่จˆ: {meta['total_sft']})")
            check_auto_train(meta)
            continue
        if user_input == "/edit":
            edited_code = _read_until_end(" ่ฒผไธŠไฟฎๆ”นๅพŒ็š„ code๏ผˆ่ผธๅ…ฅ END ็ตๆŸ๏ผ‰:")
            if edited_code.strip():
                now = datetime.now().isoformat()
                # The user's edit is preferred over the model's original output.
                append_data(DPO_FILE, {"prompt": [{"role": "user", "content": current_prompt}], "chosen": [{"role": "assistant", "content": edited_code}], "rejected": [{"role": "assistant", "content": current_response}], "timestamp": now, "source": "chat_edited"})
                append_data(SFT_FILE, {"messages": [{"role": "user", "content": current_prompt}, {"role": "assistant", "content": edited_code}], "timestamp": now, "source": "chat_edited_sft"})
                meta["total_dpo"] = meta.get("total_dpo", 0) + 1
                meta["total_sft"] = meta.get("total_sft", 0) + 1
                save_metadata(meta)
                recorded = True
                print(f" โœ… DPO +1 / SFT +1 (DPO:{meta['total_dpo']} SFT:{meta['total_sft']})")
                check_auto_train(meta)
            continue
        if user_input == "/test":
            test_code = _read_until_end(" ่ฒผไธŠ pytest ๆธฌ่ฉฆ๏ผˆ่ผธๅ…ฅ END ็ตๆŸ๏ผ‰:")
            if test_code.strip():
                append_data(GRPO_FILE, {"prompt": [{"role": "user", "content": current_prompt}], "solution": current_response, "test": test_code, "timestamp": datetime.now().isoformat(), "source": "chat_test"})
                meta["total_grpo"] = meta.get("total_grpo", 0) + 1
                save_metadata(meta)
                print(f" โœ… GRPO +1 (็ดฏ่จˆ: {meta['total_grpo']})")
                # GRPO rows count toward the threshold too (see check_auto_train).
                check_auto_train(meta)
            continue

        # Any other input is a new prompt: generate a fresh response.
        current_prompt = user_input
        recorded = False
        messages = [
            {"role": "system", "content": "You are an exceptionally skilled programmer. Write clean, efficient, well-documented code."},
            {"role": "user", "content": user_input},
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7, top_p=0.9, pad_token_id=tokenizer.pad_token_id)
        # Decode only the newly generated tokens, skipping the echoed prompt.
        current_response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
        print(f"\n๐Ÿค– ๆจกๅž‹:\n{current_response}\n\n ๐Ÿ’ก /accept | /edit | /reject | /test\n")
# ============================================================
# Git 監控模式 (git watch mode)
# ============================================================
def _changed_py_files(repo_path, commit):
    """Return the .py paths (relative to the repo root) touched by *commit*.

    Normal commits are diffed against their first parent (``commit~1``). That
    diff fails for a root commit, which has no parent. In that case every file
    in the commit's tree (``git ls-tree -r``) is returned instead, or [] if that
    also fails. ``-z`` makes git emit NUL-separated, unquoted paths. Without it,
    names containing non-ASCII characters arrive wrapped in quotes and never
    pass the ``.py`` suffix check.
    """
    def _git(*args):
        return subprocess.run(["git", *args], cwd=repo_path, capture_output=True, encoding="utf-8", errors="replace")

    result = _git("diff", "-z", "--name-only", f"{commit}~1", commit)
    if result.returncode != 0:
        result = _git("ls-tree", "-r", "-z", "--name-only", commit)
        if result.returncode != 0:
            return []
    return [path for path in result.stdout.split("\0") if path.endswith(".py")]


def run_watch(repo_path):
    """Poll *repo_path* every 30 s and turn .py files from unseen commits into SFT rows.

    Each poll looks at the 20 most recent commits (``git log -20``). For every
    commit not yet listed in DATA_DIR/seen_commits.json, each changed .py file
    whose content at that commit is between 51 and 9,999 characters long is
    stored as an SFT example (prompt = file path + commit subject, answer = file
    content). The loop runs until Ctrl+C. Other errors are printed and polling
    continues.
    """
    ensure_data_dir()
    repo_path = os.path.abspath(repo_path)
    print(f" ๐Ÿ‘€ ็›ฃๆŽง: {repo_path}")
    meta = load_metadata()
    seen_file = os.path.join(DATA_DIR, "seen_commits.json")
    seen = set()
    if os.path.exists(seen_file):
        with open(seen_file, encoding="utf-8") as fh:
            seen = set(json.load(fh))
    print(f" ๅทฒ่™•็†: {len(seen)} commits\n ็›ฃๆŽงไธญ... (Ctrl+C ๅœๆญข)\n")
    try:
        while True:
            try:
                # errors="replace": a non-UTF-8 commit message must not stall the loop forever.
                log = subprocess.run(["git", "log", "--oneline", "-20", "--format=%H %s"], cwd=repo_path, capture_output=True, encoding="utf-8", errors="replace")
                if log.returncode != 0:
                    # e.g. not a git repository: report it instead of idling silently.
                    print(f" โš ๏ธ {log.stderr.strip()}")
                for line in log.stdout.splitlines():
                    if not line.strip():
                        continue
                    commit, _, msg = line.partition(" ")
                    if commit in seen:
                        continue
                    for path in _changed_py_files(repo_path, commit):
                        try:
                            # Strict UTF-8: binary / non-UTF-8 files raise and are skipped.
                            code = subprocess.run(["git", "show", f"{commit}:{path}"], cwd=repo_path, capture_output=True, encoding="utf-8").stdout
                            if 50 < len(code) < 10000:
                                append_data(SFT_FILE, {"messages": [{"role": "user", "content": f"Write: {path}\nCommit: {msg}"}, {"role": "assistant", "content": code}], "timestamp": datetime.now().isoformat(), "source": "git", "commit": commit[:8], "file": path})
                                meta["total_sft"] = meta.get("total_sft", 0) + 1
                                print(f" ๐Ÿ“ {commit[:8]} | {path} โ†’ SFT ({meta['total_sft']})")
                        except (OSError, subprocess.SubprocessError, UnicodeError) as e:
                            # Skip just this file; KeyboardInterrupt is no longer swallowed here.
                            print(f" โš ๏ธ {commit[:8]} {path}: {e}")
                    seen.add(commit)
                save_metadata(meta)
                with open(seen_file, "w", encoding="utf-8") as fh:
                    json.dump(list(seen), fh)
                check_auto_train(meta)
            except Exception as e:
                print(f" โš ๏ธ {e}")
            time.sleep(30)
    except KeyboardInterrupt:
        print("\nโน๏ธ ๅทฒๅœๆญข")
def show_status():
    """Print per-file row counts, progress toward the next training reminder, and the training count.

    "Remaining until next training" is measured from the rows added since the
    last training run (total rows minus meta["last_train_total"]). This is the
    same measure check_auto_train uses, so the figure resets after each run
    instead of staying at 0 forever once the lifetime total passes the threshold.
    """
    ensure_data_dir()
    meta = load_metadata()
    s, d, g = count_lines(SFT_FILE), count_lines(DPO_FILE), count_lines(GRPO_FILE)
    t = s + d + g
    new = max(0, t - meta.get("last_train_total", 0))
    print(f"""
๐Ÿ“Š ๆ•ธๆ“šๆ”ถ้›†็‹€ๆ…‹
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
SFT: {s:>5} ๆข {'โ–ˆ'*min(s//5,30)}
DPO: {d:>5} ๆข {'โ–ˆ'*min(d//5,30)}
GRPO: {g:>5} ๆข {'โ–ˆ'*min(g//5,30)}
โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
็ธฝ่จˆ: {t:>5} ๆข
่‡ชๅ‹•่จ“็ทด้–€ๆชป: {AUTO_TRAIN_THRESHOLD} ๆข
่ทไธ‹ๆฌก่จ“็ทด: {max(0,AUTO_TRAIN_THRESHOLD-new)} ๆข
ๅทฒ่จ“็ทดๆฌกๆ•ธ: {meta.get('train_count',0)} ๆฌก
""")
def check_auto_train(meta):
    """Print a reminder to run `train` once enough new rows have accumulated.

    New rows are the combined SFT + DPO + GRPO line count minus
    meta["last_train_total"] (0 if never trained). The reminder prints when that
    number reaches AUTO_TRAIN_THRESHOLD. No training is started here.
    """
    collected = sum(map(count_lines, (SFT_FILE, DPO_FILE, GRPO_FILE)))
    since_last = collected - meta.get("last_train_total", 0)
    if since_last < AUTO_TRAIN_THRESHOLD:
        return
    print(f"\n ๐Ÿ”” ็ดฏ็ฉ {since_last} ๆขๆ–ฐๆ•ธๆ“š๏ผ้‹่กŒ python code_llm_collector.py train")
def trigger_training():
    """Fine-tune a LoRA adapter on the collected SFT rows and record the run in metadata.

    Only SFT data is trained on. DPO and GRPO rows are collected but not used
    here, so with no SFT rows nothing is trained and no run is recorded.
    Training continues from ADAPTER_PATH when that path exists. Otherwise a
    fresh LoRA (r=16, all attention/MLP projections) is attached to the 4-bit
    quantized BASE_MODEL. The adapter is saved to a timestamped directory under
    DATA_DIR.
    """
    ensure_data_dir()
    meta = load_metadata()
    s, d = count_lines(SFT_FILE), count_lines(DPO_FILE)
    if s + d == 0:
        print(" โš ๏ธ ็„กๆ•ธๆ“š")
        return
    if s == 0:
        # Previously this path reported success and bumped train_count without training.
        print(f" Only DPO data found ({d} rows); DPO training is not implemented yet, so nothing was trained.")
        return
    print(f"\n๐Ÿš€ ่จ“็ทดไธญ... SFT:{s} DPO:{d}")
    from datasets import Dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
    from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, PeftModel
    from trl import SFTTrainer, SFTConfig
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)

    # Rows are written as UTF-8 by append_data; skip blank lines instead of crashing.
    with open(SFT_FILE, encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    ds = Dataset.from_list([{"messages": row["messages"]} for row in rows])

    model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=bnb, device_map="auto", trust_remote_code=True)
    if ADAPTER_PATH and os.path.exists(ADAPTER_PATH):
        model = PeftModel.from_pretrained(model, ADAPTER_PATH, is_trainable=True)
    else:
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]))
    td = os.path.join(DATA_DIR, f"train_{datetime.now().strftime('%Y%m%d_%H%M')}")
    trainer = SFTTrainer(
        model=model,
        args=SFTConfig(output_dir=td, learning_rate=2e-4, num_train_epochs=3, per_device_train_batch_size=1, gradient_accumulation_steps=8, max_seq_length=1024, gradient_checkpointing=True, bf16=True, optim="paged_adamw_8bit", logging_steps=10, save_total_limit=1, logging_strategy="steps", logging_first_step=True),
        processing_class=tokenizer,
        train_dataset=ds,
    )
    trainer.train()
    trainer.save_model(td)
    print(f" โœ… SFT: {td}")
    # The trainer also holds the model; drop both so empty_cache can release GPU memory.
    del trainer, model
    torch.cuda.empty_cache()

    meta["last_train"] = datetime.now().isoformat()
    meta["train_count"] = meta.get("train_count", 0) + 1
    # Record the combined size of all three files, matching how check_auto_train
    # and show_status count rows, so GRPO rows are not mistaken for "new" data.
    meta["last_train_total"] = s + d + count_lines(GRPO_FILE)
    save_metadata(meta)
    print(f"\n๐ŸŽ‰ ็ฌฌ {meta['train_count']} ๆฌก่จ“็ทดๅฎŒๆˆ๏ผ")
def _read_jsonl_rows(path):
    """Read a UTF-8 JSONL file into dicts that all share the union of keys in the file.

    datasets.Dataset.from_list takes its columns from the first row only, so
    fields that appear only in some rows (e.g. "commit"/"file" on git-sourced
    SFT rows) would otherwise be dropped. Missing keys are filled with None;
    blank lines are skipped.
    """
    with open(path, encoding="utf-8") as fh:
        rows = [json.loads(line) for line in fh if line.strip()]
    columns = list(dict.fromkeys(key for row in rows for key in row))
    return [{key: row.get(key) for key in columns} for row in rows]


def export_dataset():
    """Push the collected SFT and DPO files to private Hugging Face dataset repos.

    Targets f"{HF_USERNAME}/my-code-sft-data" and f"{HF_USERNAME}/my-code-dpo-data".
    A file is pushed only when it has rows. GRPO rows are not exported. Nothing
    is pushed while HF_USERNAME is still the "YOUR_HF_USERNAME" placeholder.
    """
    ensure_data_dir()
    s, d = count_lines(SFT_FILE), count_lines(DPO_FILE)
    if s + d == 0:
        print(" โš ๏ธ ็„กๆ•ธๆ“š")
        return
    if HF_USERNAME == "YOUR_HF_USERNAME":
        print(" Set HF_USERNAME to your Hugging Face account before exporting.")
        return
    from datasets import Dataset
    if s > 0:
        ds = Dataset.from_list(_read_jsonl_rows(SFT_FILE))
        n = f"{HF_USERNAME}/my-code-sft-data"
        ds.push_to_hub(n, private=True)
        print(f" โœ… SFT: https://huggingface.co/datasets/{n}")
    if d > 0:
        ds = Dataset.from_list(_read_jsonl_rows(DPO_FILE))
        n = f"{HF_USERNAME}/my-code-dpo-data"
        ds.push_to_hub(n, private=True)
        print(f" โœ… DPO: https://huggingface.co/datasets/{n}")
def main():
    """Parse the CLI mode (plus --repo for watch mode) and run the matching command."""
    parser = argparse.ArgumentParser(description="Code LLM ๆ•ธๆ“š้ฃ›่ผช")
    parser.add_argument("mode", choices=["chat", "watch", "status", "train", "export"])
    parser.add_argument("--repo", type=str, default=".")
    args = parser.parse_args()
    if args.mode == "watch":
        # The only command that takes an argument.
        run_watch(args.repo)
        return
    commands = {
        "chat": run_chat,
        "status": show_status,
        "train": trigger_training,
        "export": export_dataset,
    }
    commands[args.mode]()


if __name__ == "__main__":
    main()