Justin-lee commited on
Commit
1e8fbb1
·
verified ·
1 Parent(s): 66bbfd0

Add /compare command for automatic DPO pair generation

Browse files
Files changed (1) hide show
  1. codepilot_v3.py +96 -3
codepilot_v3.py CHANGED
@@ -437,7 +437,23 @@ def run_agent_loop(args):
437
  if git_ctx != "(not a git repo)":
438
  console.print(Panel(git_ctx, title="📂 Project", border_style="dim"))
439
 
440
- console.print("[dim]指令: /ls /git /clear /switch /status /train /quit[/]\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
  system_prompt = build_system_prompt(tools)
443
  messages = [{"role": "system", "content": system_prompt}]
@@ -473,17 +489,94 @@ def run_agent_loop(args):
473
  console.print(Panel(tools.git_context(), title="Git", border_style="dim")); continue
474
  elif cmd.startswith("/ls"):
475
  console.print(tools.list_files(cmd[3:].strip() or "*")); continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  elif cmd == "/switch":
477
  console.print("可用: local, openai, anthropic, openrouter, ollama")
478
  new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys()))
479
  if new_p == "local":
480
  with console.status("載入本地模型..."):
481
  model = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
482
- provider_key = "local"
483
  else:
484
  key = args.api_key or Prompt.ask("API Key")
485
  cm = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"])
486
- model = CloudModel(new_p, key, cm); provider_key = new_p
487
  console.print(f"[green]✅ 切換到 {provider_key}[/]"); continue
488
 
489
  messages.append({"role": "user", "content": user_input})
 
437
  if git_ctx != "(not a git repo)":
438
  console.print(Panel(git_ctx, title="📂 Project", border_style="dim"))
439
 
440
+ console.print("[dim]指令: /ls /git /clear /switch /compare /status /train /quit[/]\n")
441
+
442
+ # 保存模型參照,讓 /compare 可以用
443
+ local_model_ref = None
444
+ cloud_model_ref = None
445
+ if provider_key == "local":
446
+ local_model_ref = model
447
+ else:
448
+ cloud_model_ref = model
449
+ # 蒸餾/compare 模式下,也嘗試載入本地模型
450
+ if args.adapter or distill_mode:
451
+ try:
452
+ with console.status("[dim]同時載入本地模型 (for /compare)..."):
453
+ local_model_ref = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
454
+ console.print("[dim]✅ 本地模型也已載入,可用 /compare[/]")
455
+ except Exception:
456
+ console.print("[dim]⚠️ 本地模型載入失敗,/compare 不可用[/]")
457
 
458
  system_prompt = build_system_prompt(tools)
459
  messages = [{"role": "system", "content": system_prompt}]
 
489
  console.print(Panel(tools.git_context(), title="Git", border_style="dim")); continue
490
  elif cmd.startswith("/ls"):
491
  console.print(tools.list_files(cmd[3:].strip() or "*")); continue
492
+ elif cmd == "/compare" or cmd.startswith("/compare "):
493
+ # /compare 模式:同一問題自動送給本地+雲端,並排比較,一鍵產生 DPO
494
+ compare_question = cmd[8:].strip() if cmd.startswith("/compare ") else None
495
+ if not compare_question:
496
+ compare_question = Prompt.ask(" 輸入要比較的問題")
497
+ if not compare_question.strip():
498
+ continue
499
+
500
+ need_local = local_model_ref or (provider_key == "local" and model)
501
+ need_cloud = cloud_model_ref or (provider_key != "local" and model)
502
+
503
+ if not need_local or not need_cloud:
504
+ console.print("[yellow]⚠️ /compare 需要同時有本地和雲端模型[/]")
505
+ console.print("[dim]啟動方式: codepilot --provider openrouter --api-key sk-xxx --adapter ./adapter[/]")
506
+ continue
507
+
508
+ lm = local_model_ref if local_model_ref else model
509
+ cm = cloud_model_ref if cloud_model_ref else model
510
+
511
+ compare_msgs = [
512
+ {"role": "system", "content": system_prompt},
513
+ {"role": "user", "content": compare_question},
514
+ ]
515
+
516
+ # 同時生成兩個回答
517
+ with console.status("[bold cyan]本地模型思考中..."):
518
+ try: local_resp = lm.chat(compare_msgs)
519
+ except Exception as e: local_resp = f"(錯誤: {e})"
520
+
521
+ with console.status("[bold magenta]雲端模型思考中..."):
522
+ try: cloud_resp = cm.chat(compare_msgs)
523
+ except Exception as e: cloud_resp = f"(錯誤: {e})"
524
+
525
+ # 並排顯示
526
+ console.print(f"\n[bold]🔄 Compare: {compare_question[:60]}{'...' if len(compare_question)>60 else ''}[/]\n")
527
+ console.print(Panel(Markdown(local_resp), title=f"🏠 Local ({lm.name})", border_style="blue"))
528
+ console.print(Panel(Markdown(cloud_resp), title=f"☁️ Cloud ({cm.name})", border_style="magenta"))
529
+
530
+ # 選擇
531
+ console.print(f"\n [green]1[/] = 本地較好 [magenta]2[/] = 雲端較好 [yellow]b[/] = 都好 [red]x[/] = 都差 Enter = 跳過")
532
+ choice = Prompt.ask(" ", choices=["1","2","b","x",""], default="", show_choices=False)
533
+
534
+ if choice == "2":
535
+ # 雲端好,本地差 → DPO: chosen=cloud, rejected=local
536
+ db.save(compare_question, cloud_resp, 1, project=project_dir,
537
+ source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
538
+ db.save(compare_question, local_resp, 0, project=project_dir,
539
+ source_model=getattr(lm, "name", "local"), provider="local")
540
+ dpo_count = len(db.export_dpo())
541
+ console.print(f" [magenta]☁️ 雲端勝 → DPO +1[/] (累計 DPO 對: {dpo_count})")
542
+
543
+ elif choice == "1":
544
+ # 本地好 → 記錄本地為正面
545
+ db.save(compare_question, local_resp, 1, project=project_dir,
546
+ source_model=getattr(lm, "name", "local"), provider="local")
547
+ db.save(compare_question, cloud_resp, 0, project=project_dir,
548
+ source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
549
+ console.print(f" [green]🏠 本地勝!你的模型在進步![/]")
550
+
551
+ elif choice == "b":
552
+ # 都好 → 兩個都記為 SFT
553
+ db.save(compare_question, local_resp, 1, project=project_dir,
554
+ source_model=getattr(lm, "name", "local"), provider="local")
555
+ db.save(compare_question, cloud_resp, 1, project=project_dir,
556
+ source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
557
+ console.print(f" [yellow]👍 都好 → SFT +2[/]")
558
+
559
+ elif choice == "x":
560
+ # 都差
561
+ db.save(compare_question, local_resp, 0, project=project_dir,
562
+ source_model=getattr(lm, "name", "local"), provider="local")
563
+ db.save(compare_question, cloud_resp, 0, project=project_dir,
564
+ source_model=getattr(cm, "name", "cloud"), provider=getattr(cm, "provider", "cloud"))
565
+ console.print(f" [red]👎 都差[/]")
566
+
567
+ continue
568
+
569
  elif cmd == "/switch":
570
  console.print("可用: local, openai, anthropic, openrouter, ollama")
571
  new_p = Prompt.ask("切換到", choices=list(PROVIDER_CONFIGS.keys()))
572
  if new_p == "local":
573
  with console.status("載入本地模型..."):
574
  model = LocalModel(args.model or DEFAULT_LOCAL_MODEL, args.adapter)
575
+ local_model_ref = model; provider_key = "local"
576
  else:
577
  key = args.api_key or Prompt.ask("API Key")
578
  cm = Prompt.ask("模型", default=PROVIDER_CONFIGS[new_p]["default_model"])
579
+ model = CloudModel(new_p, key, cm); cloud_model_ref = model; provider_key = new_p
580
  console.print(f"[green]✅ 切換到 {provider_key}[/]"); continue
581
 
582
  messages.append({"role": "user", "content": user_input})