Spaces:
Runtime error
Runtime error
| import argparse | |
| import asyncio | |
| import ast | |
| import os | |
| import re | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from training.curriculum import DEFAULT_CURRICULUM, sample_difficulty | |
| from training.rollout import run_episode | |
# Defaults for the --model-name and --env-url CLI options.
DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
DEFAULT_ENV_URL = "http://127.0.0.1:8000"

# Action names the GRPO reward function accepts as valid.
VALID_ACTIONS = {
    "PROSPECT", "QUALIFY", "PRESENT",
    "HANDLE_OBJECTION", "OFFER_DEMO", "NEGOTIATE",
    "CLOSE", "FOLLOW_UP", "DISQUALIFY",
}

# Ordered workflow per difficulty level, used when building synthetic GRPO
# prompts. An empty list is rendered as "Dynamic" in the prompt text.
WORKFLOW_MAP = {
    1: ["QUALIFY", "PRESENT", "CLOSE"],
    2: ["QUALIFY", "PRESENT", "HANDLE_OBJECTION", "OFFER_DEMO", "CLOSE"],
    3: [
        "QUALIFY",
        "PRESENT",
        "HANDLE_OBJECTION",
        "OFFER_DEMO",
        "HANDLE_OBJECTION",
        "NEGOTIATE",
        "CLOSE",
    ],
    4: [],
}
def _load_model_and_tokenizer(model_name: str, use_unsloth: bool = False):
    """Load a causal LM plus tokenizer and attach LoRA adapters when possible.

    With ``use_unsloth`` set, the Unsloth 4-bit loader is tried first. An
    ``ImportError`` raised anywhere on that path prints a warning and falls
    through to the standard ``transformers`` loader. On the standard path,
    LoRA is applied via ``peft`` if it can be imported; otherwise the base
    model is returned unchanged.

    On both paths the tokenizer's pad token falls back to its EOS token when
    it is unset.

    Args:
        model_name: Checkpoint identifier passed to the loader.
        use_unsloth: Whether to attempt the Unsloth loader first.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Same projection layers are adapted on both loading paths.
    lora_targets = (
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    )

    if use_unsloth:
        try:
            from unsloth import FastLanguageModel

            print("Loading with Unsloth in 4-bit + LoRA...")
            unsloth_model, unsloth_tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=2048,
                load_in_4bit=True,
                fast_inference=True,
                max_lora_rank=16,
            )
            # LoRA adapters keep the trainable parameter count (and VRAM) low.
            unsloth_model = FastLanguageModel.get_peft_model(
                unsloth_model,
                r=16,
                target_modules=list(lora_targets),
                lora_alpha=16,
                lora_dropout=0,
                bias="none",
                use_gradient_checkpointing="unsloth",
                random_state=3407,
            )
            if unsloth_tokenizer.pad_token is None:
                unsloth_tokenizer.pad_token = unsloth_tokenizer.eos_token
            return unsloth_model, unsloth_tokenizer
        except ImportError:
            print("Warning: unsloth not found. Falling back to standard HF loading (High VRAM).")

    # Standard Hugging Face loading (default path, or fallback from Unsloth).
    bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    load_dtype = torch.bfloat16 if bf16_supported else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=load_dtype,
        device_map="auto",
    )

    try:
        from peft import LoraConfig, get_peft_model

        print("Applying standard PEFT (LoRA) adapters...")
        lora_settings = LoraConfig(
            r=16,
            lora_alpha=16,
            target_modules=list(lora_targets),
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_settings)
    except ImportError:
        print("Warning: PEFT not found. Proceeding with full parameter tuning (May OOM).")

    return model, tokenizer
async def curriculum_train(
    model,
    tokenizer,
    env_url: str,
    total_steps: int = 100,
    print_every: int = 10,
):
    """Run a curriculum rollout loop to benchmark env + policy behavior.

    Each step samples a difficulty from ``DEFAULT_CURRICULUM`` based on the
    running mean reward, runs one episode through ``run_episode``, and records
    the outcome. The running mean covers the 20 most recent episode rewards.

    Args:
        model: Policy model forwarded to ``run_episode``.
        tokenizer: Tokenizer forwarded to ``run_episode``.
        env_url: Environment endpoint forwarded to ``run_episode``.
        total_steps: Number of episodes to run.
        print_every: Print a progress line every this many steps. Values
            ``<= 0`` disable periodic printing instead of raising
            ``ZeroDivisionError`` on the modulo check.

    Returns:
        A dict with ``mean_reward`` (mean of the last 20 rewards, ``0.0`` when
        no episode ran), ``reward_history`` (one float per episode) and
        ``run_log`` (one summary dict per episode).
    """
    mean_reward = 0.0
    reward_history: list[float] = []
    run_log: list[dict] = []

    for step in range(total_steps):
        difficulty = sample_difficulty(DEFAULT_CURRICULUM, mean_reward)
        result = await run_episode(
            model=model,
            tokenizer=tokenizer,
            env_url=env_url,
            difficulty=difficulty,
        )

        # Convert once and reuse for the history, the log entry and the printout.
        episode_reward = float(result["total_reward"])
        violation_count = len(result["violations"])

        reward_history.append(episode_reward)
        mean_reward = float(np.mean(reward_history[-20:]))
        run_log.append(
            {
                "step": step,
                "difficulty": difficulty,
                "reward": episode_reward,
                "violations": violation_count,
                "steps_completed": list(result["steps_completed"]),
            }
        )

        # Guard the modulo so print_every=0 silences logging rather than crashing.
        if print_every > 0 and step % print_every == 0:
            print(
                f"Step {step:04d} | Difficulty {difficulty} | "
                f"Reward {episode_reward:.3f} | Mean(20) {mean_reward:.3f} | "
                f"Violations {violation_count} | Steps {result['steps_completed']}"
            )

    return {
        "mean_reward": mean_reward,
        "reward_history": reward_history,
        "run_log": run_log,
    }
def _save_metrics(output_dir: str, metrics: dict):
    """Write ``metrics["reward_history"]`` to ``<output_dir>/reward_history.txt``.

    The directory is created when missing. Each output line holds the entry
    index and the reward (six decimal places), separated by a tab.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    rewards_path = target_dir / "reward_history.txt"
    with rewards_path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            f"{idx}\t{reward:.6f}\n"
            for idx, reward in enumerate(metrics["reward_history"])
        )
    print(f"Saved reward history to {rewards_path}")
def _extract_action_content(text: str) -> tuple[str, str]:
    """Parse the ``ACTION:`` word and the ``CONTENT:`` line from model output.

    Markers are matched case-insensitively. The action is returned upper-cased;
    the content is the text after ``CONTENT:`` up to the first newline (or the
    end of the text), stripped. A missing marker yields an empty string.
    """
    found_action = re.search(r"ACTION:\s*(\w+)", text, re.IGNORECASE)
    found_content = re.search(r"CONTENT:\s*(.+?)(?:\n|$)", text, re.IGNORECASE | re.DOTALL)

    action_type = ""
    if found_action is not None:
        action_type = found_action.group(1).upper()

    content = ""
    if found_content is not None:
        content = found_content.group(1).strip()

    return action_type, content
def _extract_steps_completed(prompt_text: str) -> list[str]:
    """Read the ``Steps completed: [...]`` list literal from a prompt.

    Returns the list items as upper-cased strings, or an empty list when the
    marker is missing or the bracketed text is not a valid Python list literal.
    """
    found = re.search(r"Steps completed:\s*(\[.*?\])", prompt_text, re.DOTALL)
    if found is None:
        return []

    try:
        value = ast.literal_eval(found.group(1))
    except Exception:
        # Malformed literal: treat as "no steps recorded".
        return []

    if not isinstance(value, list):
        return []
    return [str(item).upper() for item in value]
def _extract_required_workflow(prompt_text: str) -> list[str]:
    """Read the ``Required workflow steps (in order): A -> B`` line of a prompt.

    Returns the upper-cased step names in order. The result is empty when the
    line is missing or when its value starts with "dynamic" (any casing).
    """
    found = re.search(r"Required workflow steps \(in order\):\s*(.+)", prompt_text)
    if found is None:
        return []

    spec = found.group(1).strip()
    if spec.lower().startswith("dynamic"):
        return []

    trimmed = (piece.strip() for piece in spec.split("->"))
    return [name.upper() for name in trimmed if name]
def salespath_reward_func(prompts, completions, **kwargs):
    """Compute one lightweight GRPO reward per (prompt, completion) pair.

    The score combines output-format validity, content length, and basic
    workflow-order checks read back from the prompt text. Extra keyword
    arguments are accepted and ignored.

    Returns:
        A list of floats, one per pair, each within ``[-1.0, 1.0]``.
    """
    return [
        _score_single_completion(prompt, completion)
        for prompt, completion in zip(prompts, completions)
    ]


def _score_single_completion(prompt, completion) -> float:
    """Score a single completion against the state described in its prompt."""
    action_type, content = _extract_action_content(completion)
    steps_completed = _extract_steps_completed(prompt)
    required_workflow = _extract_required_workflow(prompt)

    upper_text = completion.upper()
    score = 0.0

    # Dense format credit: each marker earns a little on its own.
    if "ACTION:" in upper_text:
        score += 0.05
    if "CONTENT:" in upper_text:
        score += 0.05

    # An unknown action short-circuits to a flat penalty, discarding the
    # partial format credit accumulated above.
    if action_type not in VALID_ACTIONS:
        return -0.2
    score += 0.15

    if content:
        score += 0.1
    else:
        score -= 0.1

    # Favor concise messages so completions finish before the length cap.
    content_len = len(content)
    if 12 <= content_len <= 120:
        score += 0.05
    elif 120 < content_len <= 220:
        score -= 0.05
    elif content_len > 220:
        score -= 0.15

    # Discourage rambling, multi-paragraph completions.
    if completion.count("\n") > 4:
        score -= 0.1

    # Bonus for picking the next expected workflow step (index is clamped to
    # the last step once every step has been completed).
    if required_workflow:
        cursor = min(len(steps_completed), len(required_workflow) - 1)
        if required_workflow[cursor] == action_type:
            score += 0.2

    # Rule hints, applied in order: R06, R01, R02, R09.
    # NOTE(review): the fixed workflows in WORKFLOW_MAP start with QUALIFY and
    # level 1 has no OFFER_DEMO, so R06/R09 can cancel the expected-step bonus
    # above - confirm this is intended.
    rule_violations = (
        not steps_completed and action_type != "PROSPECT",  # R06
        action_type == "PRESENT" and "QUALIFY" not in steps_completed,  # R01
        action_type == "NEGOTIATE" and "OFFER_DEMO" not in steps_completed,  # R02
        action_type == "CLOSE" and "OFFER_DEMO" not in steps_completed,  # R09
    )
    for violated in rule_violations:
        if violated:
            score -= 0.2

    # Keep rewards bounded for training stability.
    return float(max(-1.0, min(1.0, score)))
def _build_grpo_dataset_rows(num_rows: int = 128):
    """Build ``num_rows`` synthetic prompt rows for GRPO training.

    Rows cycle deterministically through difficulty levels 1-4, a small pool
    of prospect replies, turn numbers 1-8, and zero to two already-completed
    workflow steps. Each row is a dict with a single ``"prompt"`` key.
    """
    prospect_snippets = (
        "We are evaluating options right now.",
        "Budget is tight this quarter.",
        "Can you explain implementation effort?",
        "Pricing seems high compared to alternatives.",
    )

    rows = []
    for row_idx in range(num_rows):
        workflow = WORKFLOW_MAP[(row_idx % 4) + 1]

        # Every third row starts fresh; the rest have one or two steps done
        # (capped at the workflow length, so an empty workflow stays empty).
        if row_idx % 3 == 0:
            steps_completed = []
        else:
            steps_completed = workflow[: min(len(workflow), row_idx % 2 + 1)]

        workflow_label = " -> ".join(workflow) if workflow else "Dynamic"
        stage = steps_completed[-1] if steps_completed else "START"
        turn = (row_idx % 8) + 1
        snippet = prospect_snippets[row_idx % len(prospect_snippets)]

        prompt = "".join(
            (
                "You are a B2B sales agent.\n\n",
                f"Required workflow steps (in order): {workflow_label}\n",
                f"Current stage: {stage}\n",
                f"Steps completed: {steps_completed}\n",
                f"Turn: {turn}/20\n",
                "Business rules: R01..R09 must be respected.\n",
                f"Prospect response: {snippet}\n\n",
                "Respond exactly with:\nACTION: <action>\nCONTENT: <message>",
            )
        )
        rows.append({"prompt": prompt})
    return rows
def run_grpo(args):
    """Run a TRL GRPO job on the synthetic SalesPath prompt dataset.

    Loads the model (via Unsloth when the model name starts with
    ``unsloth/``), builds the synthetic dataset, trains with
    ``salespath_reward_func``, saves the final model under
    ``<output_dir>/grpo_final``, writes ``reward_history.txt`` from the
    trainer's log history, and optionally pushes to the Hub.

    Args:
        args: Parsed CLI namespace as produced by ``parse_args``.

    Raises:
        RuntimeError: If the ``datasets`` / ``trl`` stack cannot be imported.
    """
    import inspect

    try:
        from datasets import Dataset
        from trl import GRPOConfig, GRPOTrainer
    except Exception as exc:
        raise RuntimeError(
            "Failed to initialize TRL GRPO stack. On this machine, this is usually due to "
            "Windows blocking pyarrow dataset binaries in the local virtualenv. "
            "Use the provided Colab notebook (`training/colab_train.ipynb`) for GRPO runs, "
            "or fix local pyarrow/datasets installation first."
        ) from exc

    # Try Unsloth first (4-bit saves VRAM), fall back to standard HF loading.
    use_unsloth = args.model_name.startswith("unsloth/")
    model, tokenizer = _load_model_and_tokenizer(args.model_name, use_unsloth=use_unsloth)

    train_dataset = Dataset.from_list(_build_grpo_dataset_rows(args.grpo_dataset_size))

    # Mixed precision only makes sense on a GPU: without CUDA the model is
    # loaded in float32, so neither fp16 nor bf16 is requested.
    cuda_available = torch.cuda.is_available()
    bf16_supported = cuda_available and torch.cuda.is_bf16_supported()

    # Point the trainer at the repo named in the final status message;
    # otherwise the upload target is derived from the output directory name.
    hub_kwargs = {"hub_model_id": args.hub_repo} if args.push_to_hub else {}

    config = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        temperature=args.temperature,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        max_steps=args.grpo_steps,
        report_to="none",
        fp16=cuda_available and not bf16_supported,
        bf16=bf16_supported,
        **hub_kwargs,
    )

    # TRL names the tokenizer argument `processing_class`; keep `tokenizer`
    # as a fallback for builds whose constructor does not expose that name.
    trainer_params = inspect.signature(GRPOTrainer.__init__).parameters
    tokenizer_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=salespath_reward_func,
        args=config,
        train_dataset=train_dataset,
        **{tokenizer_kwarg: tokenizer},
    )
    trainer.train()

    final_dir = Path(args.output_dir) / "grpo_final"
    trainer.save_model(str(final_dir))
    print(f"Saved GRPO model to {final_dir}")

    # --- Save reward history from trainer logs so plot_rewards.py works ---
    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    rewards_path = output_path / "reward_history.txt"

    log_history = trainer.state.log_history
    log_rewards = _collect_logged_rewards(log_history)
    if log_rewards:
        with rewards_path.open("w", encoding="utf-8") as f:
            f.writelines(f"{idx}\t{r:.6f}\n" for idx, r in enumerate(log_rewards))
        print(f"Saved reward history ({len(log_rewards)} entries) → {rewards_path}")
    else:
        # Fallback: write a placeholder so plot_rewards.py doesn't crash.
        print("Warning: no reward entries found in trainer logs. Writing placeholder.")
        with rewards_path.open("w", encoding="utf-8") as f:
            f.writelines(
                f"{entry.get('step', 0)}\t0.0\n"
                for entry in log_history
                if "loss" in entry
            )

    if args.push_to_hub:
        trainer.push_to_hub(dataset_name="salespath_synthetic_grpo")
        print(f"Pushed trainer model to hub repo: {args.hub_repo}")


def _collect_logged_rewards(log_history) -> list[float]:
    """Return the first reward-like value of every log entry that has one."""
    # TRL GRPO logs rewards under various key names depending on version.
    reward_keys = ("reward", "rewards", "mean_reward", "train/reward", "train/rewards")
    collected: list[float] = []
    for entry in log_history:
        key = next((name for name in reward_keys if name in entry), None)
        if key is not None:
            collected.append(float(entry[key]))
    return collected
def parse_args():
    """Build the SalesPath CLI parser and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` with general options (mode, model, env URL,
        output/hub settings) and the GRPO-specific knobs.
    """
    cli = argparse.ArgumentParser(description="SalesPath training entrypoint.")
    add = cli.add_argument

    add("--mode", choices=["curriculum", "grpo"], default="curriculum")
    add("--model-name", default=DEFAULT_MODEL)
    add("--env-url", default=DEFAULT_ENV_URL)
    add("--steps", type=int, default=100, help="Curriculum rollout steps.")
    add("--print-every", type=int, default=10)
    add("--output-dir", default="salespath_training_outputs")
    add("--hub-repo", default="Imsachin010/salespath-qwen25-0.5b")
    add("--push-to-hub", action="store_true")
    add("--push-merged", action="store_true")

    # GRPO-specific knobs: (flag, value type, default), registered in order.
    grpo_knobs = (
        ("--grpo-steps", int, 30),
        ("--grpo-dataset-size", int, 128),
        ("--learning-rate", float, 5e-6),
        ("--per-device-train-batch-size", int, 4),
        ("--gradient-accumulation-steps", int, 4),
        ("--num-generations", int, 4),
        ("--max-completion-length", int, 64),
        ("--temperature", float, 0.7),
        ("--logging-steps", int, 10),
        ("--save-steps", int, 100),
    )
    for flag, value_type, default in grpo_knobs:
        add(flag, type=value_type, default=default)

    return cli.parse_args()
async def _run_curriculum_mode(args):
    """Run the curriculum rollout loop, save metrics, optionally export merged weights.

    With ``--push-merged`` the model is saved as merged 16-bit weights when it
    exposes ``save_pretrained_merged``. It is pushed to ``args.hub_repo`` only
    if an ``HF_TOKEN`` / ``HUGGINGFACE_TOKEN`` environment variable is set and
    the model exposes ``push_to_hub_merged``.
    """
    print(f"Loading model: {args.model_name}")
    # NOTE(review): the loader is called without use_unsloth, so the returned
    # model presumably lacks the merged-save APIs checked below - confirm.
    model, tokenizer = _load_model_and_tokenizer(args.model_name)

    print(f"Starting curriculum loop against {args.env_url}")
    metrics = await curriculum_train(
        model=model,
        tokenizer=tokenizer,
        env_url=args.env_url,
        total_steps=args.steps,
        print_every=args.print_every,
    )
    print(f"Final mean reward (last 20): {metrics['mean_reward']:.4f}")
    _save_metrics(args.output_dir, metrics)

    if not args.push_merged:
        return

    hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

    if not hasattr(model, "save_pretrained_merged"):
        print(
            "Model does not support merged save APIs. "
            "Use an Unsloth merged-capable model to enable --push-merged."
        )
        return

    merged_dir = Path(args.output_dir) / "salespath_trained_merged"
    model.save_pretrained_merged(
        str(merged_dir),
        tokenizer,
        save_method="merged_16bit",
    )
    print(f"Saved merged model to {merged_dir}")

    # The push is skipped silently when no token is configured.
    can_push = hf_token and hasattr(model, "push_to_hub_merged")
    if can_push:
        model.push_to_hub_merged(
            args.hub_repo,
            tokenizer,
            save_method="merged_16bit",
            token=hf_token,
        )
        print(f"Pushed merged model to {args.hub_repo}")
async def _main():
    """Parse CLI arguments and dispatch to the selected training mode."""
    args = parse_args()

    if args.mode != "curriculum":
        # The only other accepted mode is "grpo"; it runs synchronously.
        print("Launching TRL GRPO mode...")
        run_grpo(args)
        return

    await _run_curriculum_mode(args)


if __name__ == "__main__":
    asyncio.run(_main())