Justin-lee committed on
Commit
dc43d1a
·
verified ·
1 Parent(s): 61f84a7

Add data flywheel collector system

Files changed (1)
  1. code_llm_collector.py +237 -0
code_llm_collector.py ADDED
@@ -0,0 +1,237 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Code LLM Data Flywheel System
+ =======================================
+
+ Collect data automatically while you use the model → accumulate to a threshold → trigger training automatically → the model keeps getting stronger.
+
+ Three collection modes:
+ 1. Interactive mode - ask the model to write code, then accept/reject/edit the answer → training data is generated automatically
+ 2. Git watch - watch your Git repo; new commits automatically become training data
+ 3. API service - deploy as an API; every request is logged automatically
+
+ Usage:
+     python code_llm_collector.py chat            # interactive mode
+     python code_llm_collector.py watch --repo .  # Git watch
+     python code_llm_collector.py status          # show status
+     python code_llm_collector.py train           # train on the collected data
+     python code_llm_collector.py export          # export to HuggingFace
+ """
+
+ import argparse, json, os, subprocess, sys, tempfile, time, hashlib, torch
+ from datetime import datetime
+ from pathlib import Path
+
+ BASE_MODEL = "Qwen/Qwen2.5-Coder-3B"
+ ADAPTER_PATH = None
+ HF_USERNAME = "YOUR_HF_USERNAME"
+ DATA_DIR = "./collected_data"
+ SFT_FILE = os.path.join(DATA_DIR, "sft_data.jsonl")
+ DPO_FILE = os.path.join(DATA_DIR, "dpo_data.jsonl")
+ GRPO_FILE = os.path.join(DATA_DIR, "grpo_data.jsonl")
+ META_FILE = os.path.join(DATA_DIR, "metadata.json")
+ AUTO_TRAIN_THRESHOLD = 100
+
+ def ensure_data_dir():
+     os.makedirs(DATA_DIR, exist_ok=True)
+     if not os.path.exists(META_FILE):
+         save_metadata({"total_sft": 0, "total_dpo": 0, "total_grpo": 0, "last_train": None, "train_count": 0, "created": datetime.now().isoformat()})
+
+ def load_metadata():
+     if os.path.exists(META_FILE):
+         with open(META_FILE) as f: return json.load(f)
+     return {}
+
+ def save_metadata(meta):
+     with open(META_FILE, "w") as f: json.dump(meta, f, indent=2, ensure_ascii=False)
+
+ def append_data(filepath, data):
+     with open(filepath, "a", encoding="utf-8") as f: f.write(json.dumps(data, ensure_ascii=False) + "\n")
+
+ def count_lines(filepath):
+     if not os.path.exists(filepath): return 0
+     with open(filepath) as f: return sum(1 for _ in f)
+
+ # ============================================================
+ # Interactive mode
+ # ============================================================
+ def run_chat():
+     from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+     from peft import PeftModel
+     ensure_data_dir()
+     print("""
+ ╔════════════════════════════════════════════════════════╗
+ ║  Code LLM interactive mode - collect data as you chat   ║
+ ╠════════════════════════════════════════════════════════╣
+ ║  Type a question → the model writes code                ║
+ ║  /accept  → accept (saved as SFT data)                  ║
+ ║  /edit    → paste your fix (makes an SFT + DPO pair)    ║
+ ║  /reject  → reject the answer                           ║
+ ║  /test    → add a pytest test (makes GRPO data)         ║
+ ║  /status  → show collection status                      ║
+ ║  /train   → train on the collected data                 ║
+ ║  /quit    → exit                                        ║
+ ╚════════════════════════════════════════════════════════╝
+ """)
+     print("📥 Loading model...")
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+     if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
+     bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
+     model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=bnb_config, device_map="auto", trust_remote_code=True)
+     if ADAPTER_PATH and os.path.exists(ADAPTER_PATH):
+         model = PeftModel.from_pretrained(model, ADAPTER_PATH); print(f"  LoRA: {ADAPTER_PATH}")
+     model.eval(); print("✅ Model loaded\n")
+     meta = load_metadata(); current_prompt = None; current_response = None
+
+     while True:
+         try: user_input = input("🧑 You: ").strip()
+         except (EOFError, KeyboardInterrupt): break
+         if not user_input: continue
+         if user_input == "/quit": break
+         elif user_input == "/status": show_status(); continue
+         elif user_input == "/train": trigger_training(); continue
+         elif user_input == "/accept":
+             if current_prompt and current_response:
+                 append_data(SFT_FILE, {"messages": [{"role": "user", "content": current_prompt}, {"role": "assistant", "content": current_response}], "timestamp": datetime.now().isoformat(), "source": "chat_accepted"})
+                 meta["total_sft"] = meta.get("total_sft", 0) + 1; save_metadata(meta)
+                 print(f"  ✅ SFT +1 (total: {meta['total_sft']})"); check_auto_train(meta)
+             continue
+         elif user_input == "/reject":
+             print("  ❌ Rejected"); current_response = None; continue
+         elif user_input == "/edit":
+             if current_prompt and current_response:
+                 print("  Paste the edited code (type END on its own line to finish):")
+                 edited_lines = []
+                 while True:
+                     line = input()
+                     if line.strip() == "END": break
+                     edited_lines.append(line)
+                 edited_code = "\n".join(edited_lines)
+                 if edited_code.strip():
+                     append_data(DPO_FILE, {"prompt": [{"role": "user", "content": current_prompt}], "chosen": [{"role": "assistant", "content": edited_code}], "rejected": [{"role": "assistant", "content": current_response}], "timestamp": datetime.now().isoformat(), "source": "chat_edited"})
+                     append_data(SFT_FILE, {"messages": [{"role": "user", "content": current_prompt}, {"role": "assistant", "content": edited_code}], "timestamp": datetime.now().isoformat(), "source": "chat_edited_sft"})
+                     meta["total_dpo"] = meta.get("total_dpo", 0) + 1; meta["total_sft"] = meta.get("total_sft", 0) + 1; save_metadata(meta)
+                     print(f"  ✅ DPO +1 / SFT +1 (DPO: {meta['total_dpo']} SFT: {meta['total_sft']})"); check_auto_train(meta)
+             continue
+         elif user_input == "/test":
+             if current_prompt and current_response:
+                 print("  Paste a pytest test (type END on its own line to finish):")
+                 test_lines = []
+                 while True:
+                     line = input()
+                     if line.strip() == "END": break
+                     test_lines.append(line)
+                 test_code = "\n".join(test_lines)
+                 if test_code.strip():
+                     append_data(GRPO_FILE, {"prompt": [{"role": "user", "content": current_prompt}], "solution": current_response, "test": test_code, "timestamp": datetime.now().isoformat(), "source": "chat_test"})
+                     meta["total_grpo"] = meta.get("total_grpo", 0) + 1; save_metadata(meta)
+                     print(f"  ✅ GRPO +1 (total: {meta['total_grpo']})")
+             continue
+
+         # Generate a response
+         current_prompt = user_input
+         messages = [{"role": "system", "content": "You are an exceptionally skilled programmer. Write clean, efficient, well-documented code."}, {"role": "user", "content": user_input}]
+         text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         inputs = tokenizer(text, return_tensors="pt").to(model.device)
+         with torch.no_grad():
+             outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7, top_p=0.9, pad_token_id=tokenizer.pad_token_id)
+         current_response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+         print(f"\n🤖 Model:\n{current_response}\n\n  💡 /accept | /edit | /reject | /test\n")
+
+ # ============================================================
+ # Git watch mode
+ # ============================================================
+ def run_watch(repo_path):
+     ensure_data_dir(); repo_path = os.path.abspath(repo_path)
+     print(f"  👀 Watching: {repo_path}"); meta = load_metadata()
+     seen_file = os.path.join(DATA_DIR, "seen_commits.json")
+     seen = set(json.load(open(seen_file))) if os.path.exists(seen_file) else set()
+     print(f"  Processed: {len(seen)} commits\n  Watching... (Ctrl+C to stop)\n")
+     while True:
+         try:
+             r = subprocess.run(["git", "log", "--oneline", "-20", "--format=%H %s"], cwd=repo_path, capture_output=True, text=True)
+             for line in r.stdout.strip().split("\n"):
+                 if not line.strip(): continue
+                 parts = line.split(" ", 1); h = parts[0]; msg = parts[1] if len(parts) > 1 else ""
+                 if h in seen: continue
+                 dr = subprocess.run(["git", "diff", f"{h}~1", h, "--name-only"], cwd=repo_path, capture_output=True, text=True)
+                 for f in [x for x in dr.stdout.strip().split("\n") if x.endswith(".py")]:
+                     try:
+                         fr = subprocess.run(["git", "show", f"{h}:{f}"], cwd=repo_path, capture_output=True, text=True)
+                         code = fr.stdout
+                         if 50 < len(code) < 10000:
+                             append_data(SFT_FILE, {"messages": [{"role": "user", "content": f"Write: {f}\nCommit: {msg}"}, {"role": "assistant", "content": code}], "timestamp": datetime.now().isoformat(), "source": "git", "commit": h[:8], "file": f})
+                             meta["total_sft"] = meta.get("total_sft", 0) + 1
+                             print(f"  📝 {h[:8]} | {f} → SFT ({meta['total_sft']})")
+                     except Exception: pass
+                 seen.add(h)
+             save_metadata(meta); json.dump(list(seen), open(seen_file, "w")); check_auto_train(meta)
+             time.sleep(30)
+         except KeyboardInterrupt: print("\n⏹️ Stopped"); break
+         except Exception as e: print(f"  ⚠️ {e}"); time.sleep(30)
+
+ def show_status():
+     ensure_data_dir(); meta = load_metadata()
+     s, d, g = count_lines(SFT_FILE), count_lines(DPO_FILE), count_lines(GRPO_FILE); t = s + d + g
+     print(f"""
+ 📊 Data collection status
+ ─────────────────────────────
+  SFT:  {s:>5} entries {'█'*min(s//5,30)}
+  DPO:  {d:>5} entries {'█'*min(d//5,30)}
+  GRPO: {g:>5} entries {'█'*min(g//5,30)}
+ ─────────────────────────────
+  Total: {t:>5} entries
+  Auto-train threshold: {AUTO_TRAIN_THRESHOLD} entries
+  Until next training run: {max(0,AUTO_TRAIN_THRESHOLD-t)} entries
+  Training runs completed: {meta.get('train_count',0)}
+ """)
+
+ def check_auto_train(meta):
+     total = count_lines(SFT_FILE) + count_lines(DPO_FILE) + count_lines(GRPO_FILE)
+     new = total - meta.get("last_train_total", 0)
+     if new >= AUTO_TRAIN_THRESHOLD:
+         print(f"\n  🔔 {new} new entries collected! Run: python code_llm_collector.py train")
+
+ def trigger_training():
+     ensure_data_dir(); meta = load_metadata()
+     s, d = count_lines(SFT_FILE), count_lines(DPO_FILE)
+     if s + d == 0: print("  ⚠️ No data collected yet"); return
+     print(f"\n🚀 Training... SFT: {s} DPO: {d}")
+     from datasets import Dataset
+     from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+     from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, PeftModel
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+     if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
+     bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
+     if s > 0:
+         from trl import SFTTrainer, SFTConfig
+         data = [json.loads(l) for l in open(SFT_FILE)]; ds = Dataset.from_list([{"messages": x["messages"]} for x in data])
+         model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, quantization_config=bnb, device_map="auto", trust_remote_code=True)
+         if ADAPTER_PATH and os.path.exists(ADAPTER_PATH): model = PeftModel.from_pretrained(model, ADAPTER_PATH, is_trainable=True)
+         else: model = prepare_model_for_kbit_training(model); model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]))
+         td = os.path.join(DATA_DIR, f"train_{datetime.now().strftime('%Y%m%d_%H%M')}")
+         trainer = SFTTrainer(model=model, args=SFTConfig(output_dir=td, learning_rate=2e-4, num_train_epochs=3, per_device_train_batch_size=1, gradient_accumulation_steps=8, max_seq_length=1024, gradient_checkpointing=True, bf16=True, optim="paged_adamw_8bit", logging_steps=10, save_total_limit=1, logging_strategy="steps", logging_first_step=True), processing_class=tokenizer, train_dataset=ds)
+         trainer.train(); trainer.save_model(td); print(f"  ✅ SFT adapter saved: {td}"); del model; torch.cuda.empty_cache()
+     meta["last_train"] = datetime.now().isoformat(); meta["train_count"] = meta.get("train_count", 0) + 1; meta["last_train_total"] = s + d; save_metadata(meta)
+     print(f"\n🎉 Training run #{meta['train_count']} complete!")
+
+ def export_dataset():
+     ensure_data_dir(); s, d = count_lines(SFT_FILE), count_lines(DPO_FILE)
+     if s + d == 0: print("  ⚠️ No data collected yet"); return
+     from datasets import Dataset
+     if s > 0:
+         ds = Dataset.from_list([json.loads(l) for l in open(SFT_FILE)]); n = f"{HF_USERNAME}/my-code-sft-data"
+         ds.push_to_hub(n, private=True); print(f"  ✅ SFT: https://huggingface.co/datasets/{n}")
+     if d > 0:
+         ds = Dataset.from_list([json.loads(l) for l in open(DPO_FILE)]); n = f"{HF_USERNAME}/my-code-dpo-data"
+         ds.push_to_hub(n, private=True); print(f"  ✅ DPO: https://huggingface.co/datasets/{n}")
+
+ def main():
+     parser = argparse.ArgumentParser(description="Code LLM data flywheel")
+     parser.add_argument("mode", choices=["chat", "watch", "status", "train", "export"])
+     parser.add_argument("--repo", type=str, default=".")
+     args = parser.parse_args()
+     {"chat": run_chat, "watch": lambda: run_watch(args.repo), "status": show_status, "train": trigger_training, "export": export_dataset}[args.mode]()
+
+ if __name__ == "__main__": main()
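
For reference, the records that /accept, /edit, /test, and the Git watcher append to the JSONL files follow the schemas of the append_data calls above. A minimal sketch of one record per file; the "content", "timestamp", and test values below are illustrative placeholders, not real collected data:

# sft_data.jsonl - one chat-format example per line (used for SFT)
sft_record = {
    "messages": [
        {"role": "user", "content": "Write a function that reverses a string"},
        {"role": "assistant", "content": "def reverse(s):\n    return s[::-1]"},
    ],
    "timestamp": "2025-01-01T12:00:00",
    "source": "chat_accepted",
}

# dpo_data.jsonl - preference pair: your edited code is "chosen", the raw model output is "rejected"
dpo_record = {
    "prompt": [{"role": "user", "content": "Write a function that reverses a string"}],
    "chosen": [{"role": "assistant", "content": "def reverse(s: str) -> str:\n    return s[::-1]"}],
    "rejected": [{"role": "assistant", "content": "def reverse(s):\n    out = ''\n    for c in s:\n        out = c + out\n    return out"}],
    "timestamp": "2025-01-01T12:05:00",
    "source": "chat_edited",
}

# grpo_data.jsonl - prompt, the accepted solution, and a pytest test a reward function could run
grpo_record = {
    "prompt": [{"role": "user", "content": "Write a function that reverses a string"}],
    "solution": "def reverse(s):\n    return s[::-1]",
    "test": "def test_reverse():\n    assert reverse('abc') == 'cba'",
    "timestamp": "2025-01-01T12:10:00",
    "source": "chat_test",
}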
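
One step the file leaves manual: ADAPTER_PATH starts as None, so an adapter produced by trigger_training (saved under ./collected_data/train_<timestamp>) is not picked up until the constant is pointed at that directory. A small sketch of how that could be automated; the latest_adapter helper is hypothetical and not part of this commit:

import glob
import os

DATA_DIR = "./collected_data"

def latest_adapter(data_dir: str = DATA_DIR):
    """Return the most recent train_* output directory, or None if no run exists yet."""
    # Directory names embed a zero-padded timestamp (train_YYYYMMDD_HHMM), so
    # lexicographic sort order matches chronological order.
    runs = sorted(glob.glob(os.path.join(data_dir, "train_*")))
    return runs[-1] if runs else None

# e.g. at the top of code_llm_collector.py, instead of ADAPTER_PATH = None:
# ADAPTER_PATH = latest_adapter()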