aamrinder committed on
Commit
7f60dea
·
verified ·
1 Parent(s): 225e725

Upload folder using huggingface_hub

Browse files
openenv_subtext_arena.egg-info/PKG-INFO CHANGED
@@ -3,7 +3,7 @@ Name: openenv-subtext_arena
3
  Version: 0.1.0
4
  Summary: Subtext Arena environment for OpenEnv
5
  Requires-Python: >=3.10
6
- Requires-Dist: openenv-core[core]>=0.2.2
7
  Provides-Extra: dev
8
  Requires-Dist: pytest>=8.0.0; extra == "dev"
9
  Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
 
3
  Version: 0.1.0
4
  Summary: Subtext Arena environment for OpenEnv
5
  Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core[core]>=0.2.3
7
  Provides-Extra: dev
8
  Requires-Dist: pytest>=8.0.0; extra == "dev"
9
  Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_subtext_arena.egg-info/SOURCES.txt CHANGED
@@ -17,4 +17,10 @@ server/app.py
17
  server/audio_tools.py
18
  server/grader.py
19
  server/scenarios.py
20
- server/subtext_arena_environment.py
 
 
 
 
 
 
 
17
  server/audio_tools.py
18
  server/grader.py
19
  server/scenarios.py
20
+ server/subtext_arena_environment.py
21
+ train/__init__.py
22
+ train/curate_pivot_set.py
23
+ train/eval_pivot_set.py
24
+ train/hour1_smoke.py
25
+ train/plot_reward_decomp.py
26
+ train/train_grpo.py
openenv_subtext_arena.egg-info/requires.txt CHANGED
@@ -1,4 +1,4 @@
1
- openenv-core[core]>=0.2.2
2
 
3
  [dev]
4
  pytest>=8.0.0
 
1
+ openenv-core[core]>=0.2.3
2
 
3
  [dev]
4
  pytest>=8.0.0
pyproject.toml CHANGED
@@ -35,5 +35,5 @@ server = "subtext_arena.server.app:main"
35
 
36
  [tool.setuptools]
37
  include-package-data = true
38
- packages = ["subtext_arena", "subtext_arena.server"]
39
- package-dir = { "subtext_arena" = ".", "subtext_arena.server" = "server" }
 
35
 
36
  [tool.setuptools]
37
  include-package-data = true
38
+ packages = ["subtext_arena", "subtext_arena.server", "subtext_arena.train"]
39
+ package-dir = { "subtext_arena" = ".", "subtext_arena.server" = "server", "subtext_arena.train" = "train" }
train/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Subtext Arena training scripts (GRPO + eval + plotting + Pivot Set curation)."""
train/eval_pivot_set.py CHANGED
@@ -28,17 +28,25 @@ import sys
28
  from pathlib import Path
29
  from typing import Any, Dict, List
30
 
31
- ROOT = Path(__file__).resolve().parent.parent
32
- if str(ROOT) not in sys.path:
33
- sys.path.insert(0, str(ROOT))
34
-
35
- from server.scenarios import load_scenarios
36
- from train.train_grpo import (
37
- SYSTEM_PROMPT,
38
- build_full_observation,
39
- parse_final,
40
- reward_decomposition,
41
- )
 
 
 
 
 
 
 
 
42
 
43
 
44
  def main():
 
28
  from pathlib import Path
29
  from typing import Any, Dict, List
30
 
31
+ try:
32
+ from subtext_arena.server.scenarios import load_scenarios
33
+ from subtext_arena.train.train_grpo import (
34
+ SYSTEM_PROMPT,
35
+ build_full_observation,
36
+ parse_final,
37
+ reward_decomposition,
38
+ )
39
+ except ImportError:
40
+ ROOT = Path(__file__).resolve().parent.parent
41
+ if str(ROOT) not in sys.path:
42
+ sys.path.insert(0, str(ROOT))
43
+ from server.scenarios import load_scenarios # type: ignore[no-redef]
44
+ from train.train_grpo import ( # type: ignore[no-redef]
45
+ SYSTEM_PROMPT,
46
+ build_full_observation,
47
+ parse_final,
48
+ reward_decomposition,
49
+ )
50
 
51
 
52
  def main():
train/hour1_smoke.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hour-1 smoke test for Path A.
2
+
3
+ Validates the entire training stack on a T4 in ~10-15 minutes:
4
+ 1. Unsloth + Qwen2.5-3B loads with 4-bit + LoRA
5
+ 2. Our env package installs from the HF Space and prompts build correctly
6
+ 3. TRL GRPOTrainer runs 2 steps end-to-end
7
+ 4. Reward function fires; rewards are non-zero
8
+ 5. LoRA weights actually update
9
+
10
+ If this passes -> commit to the full 200-step run.
11
+ If this fails -> the error tells us exactly what to fix before spending more.
12
+
13
+ Run on HF Jobs T4-medium ($0.60/hr, ~$0.15 for this script):
14
+ hf jobs uv run --flavor t4-medium -s HF_TOKEN \\
15
+ --with unsloth --with "trl>=0.11" --with datasets --with accelerate \\
16
+ --with "git+https://huggingface.co/spaces/aamrinder/subtext-arena" \\
17
+ -- python -m subtext_arena.train.hour1_smoke
18
+ """
19
+ from __future__ import annotations
20
+
21
+ import sys
22
+ import time
23
+ import traceback
24
+
25
+
26
+ def main():
27
+ t_start = time.time()
28
+ print("=" * 60)
29
+ print("Subtext Arena hour-1 smoke test (Path A)")
30
+ print("=" * 60)
31
+
32
+ # 1. PyTorch + GPU
33
+ print("\n[1/6] checking PyTorch + GPU")
34
+ try:
35
+ import torch
36
+ assert torch.cuda.is_available(), "CUDA not available"
37
+ gpu_name = torch.cuda.get_device_name(0)
38
+ gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
39
+ print(f" ✓ {gpu_name} ({gpu_mem:.1f} GB)")
40
+ except Exception as e:
41
+ print(f" ✗ {e}")
42
+ traceback.print_exc()
43
+ sys.exit(1)
44
+
45
+ # 2. Unsloth + TRL imports
46
+ print("\n[2/6] importing Unsloth + TRL")
47
+ try:
48
+ from unsloth import FastLanguageModel
49
+ from trl import GRPOTrainer, GRPOConfig
50
+ from datasets import Dataset
51
+ print(" ✓ Unsloth + TRL + datasets imported")
52
+ except Exception as e:
53
+ print(f" ✗ {e}")
54
+ traceback.print_exc()
55
+ sys.exit(1)
56
+
57
+ # 3. Subtext Arena env package
58
+ print("\n[3/6] importing subtext_arena package")
59
+ try:
60
+ from subtext_arena import SubtextArenaEnv, SubtextArenaAction
61
+ from subtext_arena.server.scenarios import load_scenarios
62
+ from subtext_arena.train.train_grpo import (
63
+ SYSTEM_PROMPT, build_dataset, make_reward_fn, reward_decomposition,
64
+ )
65
+ scenarios = load_scenarios()
66
+ print(f" ✓ {len(scenarios)} MUStARD scenarios loaded")
67
+ except Exception as e:
68
+ print(f" ✗ {e}")
69
+ traceback.print_exc()
70
+ sys.exit(1)
71
+
72
+ # 4. Build a TINY dataset (8 rows is enough for smoke)
73
+ print("\n[4/6] building tiny dataset (8 rows)")
74
+ try:
75
+ ds = build_dataset(scenarios, n_rows=8, seed=0)
76
+ print(f" ✓ dataset cols={ds.column_names}, len={len(ds)}")
77
+ print(f" first prompt user-msg first 200 chars: {ds[0]['prompt'][1]['content'][:200]!r}")
78
+ except Exception as e:
79
+ print(f" ✗ {e}")
80
+ traceback.print_exc()
81
+ sys.exit(1)
82
+
83
+ # 5. Load Qwen2.5-3B-Instruct + LoRA
84
+ print("\n[5/6] loading Qwen2.5-3B-Instruct (4-bit + LoRA)")
85
+ try:
86
+ model, tokenizer = FastLanguageModel.from_pretrained(
87
+ model_name="unsloth/Qwen2.5-3B-Instruct",
88
+ max_seq_length=2048, # smaller than full 4096 for speed
89
+ load_in_4bit=True,
90
+ )
91
+ model = FastLanguageModel.get_peft_model(
92
+ model,
93
+ r=8, # smaller r for the smoke test
94
+ lora_alpha=16,
95
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
96
+ use_gradient_checkpointing="unsloth",
97
+ )
98
+ n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
99
+ print(f" ✓ model loaded; {n_trainable / 1e6:.1f}M LoRA params trainable")
100
+ except Exception as e:
101
+ print(f" ✗ {e}")
102
+ traceback.print_exc()
103
+ sys.exit(1)
104
+
105
+ # 6. Run 2 GRPO steps
106
+ print("\n[6/6] running 2 GRPO steps")
107
+ try:
108
+ reward_fn = make_reward_fn()
109
+
110
+ # Wrap reward_fn to print per-rollout decomposition for the smoke test
111
+ last_rewards = []
112
+ def smoke_reward_fn(prompts, completions, **kwargs):
113
+ rewards = reward_fn(prompts, completions, **kwargs)
114
+ last_rewards.append(list(rewards))
115
+ for i, (c, r, gold) in enumerate(zip(completions, rewards, kwargs.get("gold", []))):
116
+ text = c[0]["content"] if isinstance(c, list) else str(c)
117
+ d = reward_decomposition(text, gold)
118
+ print(f" rollout {i}: reward={r:.3f} pred={d['_predicted']!s:>10} gold={gold:>10} "
119
+ f"correct={d['_correct']} well_formed={d['_well_formed']}")
120
+ return rewards
121
+
122
+ config = GRPOConfig(
123
+ output_dir="/tmp/smoke_out",
124
+ num_generations=2, # keep small for speed
125
+ max_completion_length=384,
126
+ per_device_train_batch_size=1,
127
+ learning_rate=5e-6,
128
+ max_steps=2,
129
+ logging_steps=1,
130
+ save_steps=10, # never saves in 2 steps
131
+ bf16=True,
132
+ report_to="none",
133
+ gradient_checkpointing=True,
134
+ )
135
+
136
+ trainer = GRPOTrainer(
137
+ model=model,
138
+ reward_funcs=smoke_reward_fn,
139
+ args=config,
140
+ train_dataset=ds,
141
+ processing_class=tokenizer,
142
+ )
143
+ trainer.train()
144
+ print(f" ✓ 2 GRPO steps completed")
145
+
146
+ if last_rewards:
147
+ all_r = [r for batch in last_rewards for r in batch]
148
+ mean_r = sum(all_r) / len(all_r)
149
+ n_well_formed = sum(1 for r in all_r if r > 0.05)
150
+ print(f" ✓ {len(all_r)} rollouts, mean reward = {mean_r:.3f}, {n_well_formed} well-formed")
151
+ if n_well_formed == 0:
152
+ print(" ⚠ WARNING: zero well-formed completions. The base model isn't following the format.")
153
+ print(" Consider an SFT warmup pass before GRPO.")
154
+ except Exception as e:
155
+ print(f" ✗ {e}")
156
+ traceback.print_exc()
157
+ sys.exit(1)
158
+
159
+ elapsed = time.time() - t_start
160
+ print()
161
+ print("=" * 60)
162
+ print(f"✓ ALL CHECKS PASS in {elapsed:.1f}s — Path A stack is alive")
163
+ print("=" * 60)
164
+
165
+
166
+ if __name__ == "__main__":
167
+ main()
train/train_grpo.py CHANGED
@@ -36,19 +36,25 @@ import sys
36
  from pathlib import Path
37
  from typing import Any, Dict, List, Optional
38
 
39
- # We import the env's prosody/transcript renderers and scenario loader so
40
- # the prompt format the model trains on is IDENTICAL to what an inference-
41
- # time agent would see if it called the tools.
42
- ROOT = Path(__file__).resolve().parent.parent
43
- if str(ROOT) not in sys.path:
44
- sys.path.insert(0, str(ROOT))
45
-
46
- from server.scenarios import load_scenarios
47
- from server.audio_tools import (
48
- render_transcript,
49
- render_prosody_features,
50
- render_pitch_contour,
51
- )
 
 
 
 
 
 
52
 
53
 
54
  # ---------------------------------------------------------------------------
 
36
  from pathlib import Path
37
  from typing import Any, Dict, List, Optional
38
 
39
+ # Dual import path: works whether this script is run locally (with
40
+ # subtext_arena/ on sys.path) or after `pip install` (subtext_arena.* package).
41
+ try:
42
+ from subtext_arena.server.scenarios import load_scenarios
43
+ from subtext_arena.server.audio_tools import (
44
+ render_transcript,
45
+ render_prosody_features,
46
+ render_pitch_contour,
47
+ )
48
+ except ImportError:
49
+ ROOT = Path(__file__).resolve().parent.parent
50
+ if str(ROOT) not in sys.path:
51
+ sys.path.insert(0, str(ROOT))
52
+ from server.scenarios import load_scenarios # type: ignore[no-redef]
53
+ from server.audio_tools import ( # type: ignore[no-redef]
54
+ render_transcript,
55
+ render_prosody_features,
56
+ render_pitch_contour,
57
+ )
58
 
59
 
60
  # ---------------------------------------------------------------------------