Spaces:
Sleeping
Sleeping
Add /compare command for automatic DPO pair generation
Browse files- codepilot_v3.py +96 -3
codepilot_v3.py
CHANGED
|
@@ -437,7 +437,23 @@ def run_agent_loop(args):
|
|
| 437 |
if git_ctx != "(not a git repo)":
|
| 438 |
console.print(Panel(git_ctx, title="📂 Project", border_style="dim"))
|
| 439 |
|
| 440 |
-
console.print("[dim]指令: /ls /git /clear /switch /status /train /quit[/]\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
|
| 442 |
system_prompt = build_system_prompt(tools)
|
| 443 |
messages = [{"role": "system", "content": system_prompt}]
|
|
@@ -473,17 +489,94 @@ def run_agent_loop(args):
|
|
| 473 |
console.print(Panel(tools.git_context(), title="Git", border_style="dim")); continue
|
| 474 |
elif cmd.startswith("/ls"):
|
| 475 |
console.print(tools.list_files(cmd[3:].strip() or "*")); continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
elif cmd == "/switch":
|
| 477 |
console.print("可用: local, openai, anthropic, openrouter, ollama")
|
| 478 |
new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys()))
|
| 479 |
if new_p == "local":
|
| 480 |
with console.status("載入本地模型..."):
|
| 481 |
model = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
|
| 482 |
-
provider_key = "local"
|
| 483 |
else:
|
| 484 |
key = args.api_key or Prompt.ask("API Key")
|
| 485 |
cm = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"])
|
| 486 |
-
model = CloudModel(new_p, key, cm); provider_key = new_p
|
| 487 |
console.print(f"[green]✅ 切換到 {provider_key}[/]"); continue
|
| 488 |
|
| 489 |
messages.append({"role": "user", "content": user_input})
|
|
|
|
| 437 |
if git_ctx != "(not a git repo)":
|
| 438 |
console.print(Panel(git_ctx, title="📂 Project", border_style="dim"))
|
| 439 |
|
| 440 |
+
console.print("[dim]指令: /ls /git /clear /switch /compare /status /train /quit[/]\n")
|
| 441 |
+
|
| 442 |
+
# 保存模型參照,讓 /compare 可以用
|
| 443 |
+
local_model_ref = None
|
| 444 |
+
cloud_model_ref = None
|
| 445 |
+
if provider_key == "local":
|
| 446 |
+
local_model_ref = model
|
| 447 |
+
else:
|
| 448 |
+
cloud_model_ref = model
|
| 449 |
+
# 蒸餾/compare 模式下,也嘗試載入本地模型
|
| 450 |
+
if args.adapter or distill_mode:
|
| 451 |
+
try:
|
| 452 |
+
with console.status("[dim]同時載入本地模型 (for /compare)..."):
|
| 453 |
+
local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
|
| 454 |
+
console.print("[dim]✅ 本地模型也已載入,可用 /compare[/]")
|
| 455 |
+
except Exception:
|
| 456 |
+
console.print("[dim]⚠️ 本地模型載入失敗,/compare 不可用[/]")
|
| 457 |
|
| 458 |
system_prompt = build_system_prompt(tools)
|
| 459 |
messages = [{"role": "system", "content": system_prompt}]
|
|
|
|
| 489 |
console.print(Panel(tools.git_context(), title="Git", border_style="dim")); continue
|
| 490 |
elif cmd.startswith("/ls"):
|
| 491 |
console.print(tools.list_files(cmd[3:].strip() or "*")); continue
|
| 492 |
+
elif cmd == "/compare" or cmd.startswith("/compare "):
|
| 493 |
+
# /compare 模式:同一問題自動送給本地+雲端,並排比較,一鍵產生 DPO
|
| 494 |
+
compare_question = cmd[8:].strip() if cmd.startswith("/compare ") else None
|
| 495 |
+
if not compare_question:
|
| 496 |
+
compare_question = Prompt.ask(" 輸入要比較的問題")
|
| 497 |
+
if not compare_question.strip():
|
| 498 |
+
continue
|
| 499 |
+
|
| 500 |
+
need_local = local_model_ref or (provider_key == "local" and model)
|
| 501 |
+
need_cloud = cloud_model_ref or (provider_key != "local" and model)
|
| 502 |
+
|
| 503 |
+
if not need_local or not need_cloud:
|
| 504 |
+
console.print("[yellow]⚠️ /compare 需要同時有本地和雲端模型[/]")
|
| 505 |
+
console.print("[dim]啟動方式: codepilot --provider openrouter --api-key sk-xxx --adapter ./adapter[/]")
|
| 506 |
+
continue
|
| 507 |
+
|
| 508 |
+
lm = local_model_ref if local_model_ref else model
|
| 509 |
+
cm = cloud_model_ref if cloud_model_ref else model
|
| 510 |
+
|
| 511 |
+
compare_msgs = [
|
| 512 |
+
{"role": "system", "content": system_prompt},
|
| 513 |
+
{"role": "user", "content": compare_question},
|
| 514 |
+
]
|
| 515 |
+
|
| 516 |
+
# 同時生成兩個回答
|
| 517 |
+
with console.status("[bold cyan]本地模型思考中..."):
|
| 518 |
+
try: local_resp = lm.chat(compare_msgs)
|
| 519 |
+
except Exception as e: local_resp = f"(錯誤: {e})"
|
| 520 |
+
|
| 521 |
+
with console.status("[bold magenta]雲端模型思考中..."):
|
| 522 |
+
try: cloud_resp = cm.chat(compare_msgs)
|
| 523 |
+
except Exception as e: cloud_resp = f"(錯誤: {e})"
|
| 524 |
+
|
| 525 |
+
# 並排顯示
|
| 526 |
+
console.print(f"\n[bold]🔄 Compare: {compare_question[:60]}{'...' if len(compare_question)>60 else ''}[/]\n")
|
| 527 |
+
console.print(Panel(Markdown(local_resp), title=f"🏠 Local ({lm.name})", border_style="blue"))
|
| 528 |
+
console.print(Panel(Markdown(cloud_resp), title=f"☁️ Cloud ({cm.name})", border_style="magenta"))
|
| 529 |
+
|
| 530 |
+
# 選擇
|
| 531 |
+
console.print(f"\n [green]1[/] = 本地較好 [magenta]2[/] = 雲端較好 [yellow]b[/] = 都好 [red]x[/] = 都差 Enter = 跳過")
|
| 532 |
+
choice = Prompt.ask(" ", choices=["1","2","b","x",""], default="", show_choices=False)
|
| 533 |
+
|
| 534 |
+
if choice == "2":
|
| 535 |
+
# 雲端好,本地差 → DPO: chosen=cloud, rejected=local
|
| 536 |
+
db.save(compare_question, cloud_resp, 1, project=project_dir,
|
| 537 |
+
source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
|
| 538 |
+
db.save(compare_question, local_resp, 0, project=project_dir,
|
| 539 |
+
source_model=getattr(lm, "name", "local"), provider="local")
|
| 540 |
+
dpo_count = len(db.export_dpo())
|
| 541 |
+
console.print(f" [magenta]☁️ 雲端勝 → DPO +1[/] (累計 DPO 對: {dpo_count})")
|
| 542 |
+
|
| 543 |
+
elif choice == "1":
|
| 544 |
+
# 本地好 → 記錄本地為正面
|
| 545 |
+
db.save(compare_question, local_resp, 1, project=project_dir,
|
| 546 |
+
source_model=getattr(lm, "name", "local"), provider="local")
|
| 547 |
+
db.save(compare_question, cloud_resp, 0, project=project_dir,
|
| 548 |
+
source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
|
| 549 |
+
console.print(f" [green]🏠 本地勝!你的模型在進步![/]")
|
| 550 |
+
|
| 551 |
+
elif choice == "b":
|
| 552 |
+
# 都好 → 兩個都記為 SFT
|
| 553 |
+
db.save(compare_question, local_resp, 1, project=project_dir,
|
| 554 |
+
source_model=getattr(lm, "name", "local"), provider="local")
|
| 555 |
+
db.save(compare_question, cloud_resp, 1, project=project_dir,
|
| 556 |
+
source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
|
| 557 |
+
console.print(f" [yellow]👍 都好 → SFT +2[/]")
|
| 558 |
+
|
| 559 |
+
elif choice == "x":
|
| 560 |
+
# 都差
|
| 561 |
+
db.save(compare_question, local_resp, 0, project=project_dir,
|
| 562 |
+
source_model=getattr(lm, "name", "local"), provider="local")
|
| 563 |
+
db.save(compare_question, cloud_resp, 0, project=project_dir,
|
| 564 |
+
source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
|
| 565 |
+
console.print(f" [red]👎 都差[/]")
|
| 566 |
+
|
| 567 |
+
continue
|
| 568 |
+
|
| 569 |
elif cmd == "/switch":
|
| 570 |
console.print("可用: local, openai, anthropic, openrouter, ollama")
|
| 571 |
new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys()))
|
| 572 |
if new_p == "local":
|
| 573 |
with console.status("載入本地模型..."):
|
| 574 |
model = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
|
| 575 |
+
local_model_ref = model; provider_key = "local"
|
| 576 |
else:
|
| 577 |
key = args.api_key or Prompt.ask("API Key")
|
| 578 |
cm = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"])
|
| 579 |
+
model = CloudModel(new_p, key, cm); cloud_model_ref = model; provider_key = new_p
|
| 580 |
console.print(f"[green]✅ 切換到 {provider_key}[/]"); continue
|
| 581 |
|
| 582 |
messages.append({"role": "user", "content": user_input})
|