anugrahhu commited on
Commit
8f805e2
·
verified ·
1 Parent(s): f82f913

fix: disable fast_inference (vLLM not installed) in training/evaluate.py

Browse files
Files changed (1) hide show
  1. training/evaluate.py +153 -152
training/evaluate.py CHANGED
@@ -1,152 +1,153 @@
1
- """Evaluate an LLM (with optional LoRA adapters) on CERNenv.
2
-
3
- Usage:
4
- python -m training.evaluate --model_name unsloth/Qwen2.5-3B-Instruct \\
5
- --difficulty easy --episodes 16 --tag pre_train \\
6
- --out training/runs/eval_pre_train.jsonl
7
-
8
- python -m training.evaluate --model_name unsloth/Qwen2.5-3B-Instruct \\
9
- --adapter_dir training/runs/unsloth-grpo --difficulty easy \\
10
- --episodes 16 --tag post_train --out training/runs/eval_post_train.jsonl
11
- """
12
-
13
- from __future__ import annotations
14
-
15
- import argparse
16
- import json
17
- import logging
18
- import os
19
- from dataclasses import asdict
20
- from pathlib import Path
21
- from typing import Any, Dict, List, Optional
22
-
23
-
24
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
25
- logger = logging.getLogger(__name__)
26
-
27
-
28
- def _build_generate_fn(
29
- *,
30
- model_name: str,
31
- adapter_dir: Optional[str],
32
- use_unsloth: bool,
33
- max_seq_length: int,
34
- ):
35
- if use_unsloth:
36
- from unsloth import FastLanguageModel # type: ignore
37
-
38
- model, tokenizer = FastLanguageModel.from_pretrained(
39
- model_name=model_name,
40
- max_seq_length=max_seq_length,
41
- load_in_4bit=True,
42
- fast_inference=True,
43
- )
44
- if adapter_dir:
45
- model.load_adapter(adapter_dir)
46
- FastLanguageModel.for_inference(model)
47
- else:
48
- import torch
49
- from transformers import AutoModelForCausalLM, AutoTokenizer
50
-
51
- tokenizer = AutoTokenizer.from_pretrained(model_name)
52
- model = AutoModelForCausalLM.from_pretrained(
53
- model_name,
54
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
55
- device_map="auto" if torch.cuda.is_available() else None,
56
- )
57
- if adapter_dir:
58
- from peft import PeftModel # type: ignore
59
- model = PeftModel.from_pretrained(model, adapter_dir)
60
-
61
- if tokenizer.pad_token is None:
62
- tokenizer.pad_token = tokenizer.eos_token
63
-
64
- def prompt_fn(chat: List[Dict[str, str]]) -> str:
65
- return tokenizer.apply_chat_template(
66
- chat, add_generation_prompt=True, tokenize=False
67
- )
68
-
69
- def generate_fn(prompt: str, config) -> str:
70
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
71
- outputs = model.generate(
72
- **inputs,
73
- max_new_tokens=config.max_new_tokens,
74
- do_sample=True,
75
- temperature=config.temperature,
76
- top_p=config.top_p,
77
- pad_token_id=tokenizer.pad_token_id,
78
- )
79
- gen = outputs[0][inputs["input_ids"].shape[1]:]
80
- return tokenizer.decode(gen, skip_special_tokens=True)
81
-
82
- return prompt_fn, generate_fn
83
-
84
-
85
- def main() -> None: # pragma: no cover
86
- parser = argparse.ArgumentParser()
87
- parser.add_argument("--model_name", required=True)
88
- parser.add_argument("--adapter_dir", default=None)
89
- parser.add_argument("--scenario", default=None)
90
- parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
91
- parser.add_argument("--episodes", type=int, default=16)
92
- parser.add_argument("--seed", type=int, default=1000)
93
- parser.add_argument("--max_steps", type=int, default=18)
94
- parser.add_argument("--max_seq_length", type=int, default=2048)
95
- parser.add_argument("--no_unsloth", action="store_true")
96
- parser.add_argument("--tag", default="eval")
97
- parser.add_argument("--out", required=True)
98
- args = parser.parse_args()
99
-
100
- from server.environment import CERNCollisionEnvironment
101
- from training.llm_agent import LLMAgentConfig
102
- from training.rollouts import collect_episode, save_episodes_jsonl
103
-
104
- use_unsloth = not args.no_unsloth
105
- try:
106
- prompt_fn, generate_fn = _build_generate_fn(
107
- model_name=args.model_name,
108
- adapter_dir=args.adapter_dir,
109
- use_unsloth=use_unsloth,
110
- max_seq_length=args.max_seq_length,
111
- )
112
- except ImportError as exc:
113
- logger.warning("Unsloth not available (%s); falling back to transformers.", exc)
114
- prompt_fn, generate_fn = _build_generate_fn(
115
- model_name=args.model_name,
116
- adapter_dir=args.adapter_dir,
117
- use_unsloth=False,
118
- max_seq_length=args.max_seq_length,
119
- )
120
-
121
- env = CERNCollisionEnvironment(max_steps=args.max_steps)
122
- cfg = LLMAgentConfig()
123
-
124
- episodes = []
125
- for ep in range(args.episodes):
126
- seed = args.seed + ep
127
- rec = collect_episode(
128
- env=env,
129
- seed=seed,
130
- scenario=args.scenario,
131
- difficulty=args.difficulty,
132
- prompt_fn=prompt_fn,
133
- generate_fn=generate_fn,
134
- config=cfg,
135
- )
136
- episodes.append(rec)
137
- logger.info(
138
- "[%s][%d/%d] reward=%+.3f discovered=%s mass=%s channel=%s",
139
- args.tag, ep + 1, args.episodes,
140
- rec.cumulative_reward, rec.discovered, rec.correct_mass, rec.correct_channel,
141
- )
142
-
143
- Path(args.out).parent.mkdir(parents=True, exist_ok=True)
144
- save_episodes_jsonl(episodes, args.out)
145
-
146
- rewards = [e.cumulative_reward for e in episodes]
147
- success = sum(1 for e in episodes if e.discovered) / len(episodes)
148
- logger.info("[%s] mean_reward=%.3f success_rate=%.2f", args.tag, sum(rewards) / len(rewards), success)
149
-
150
-
151
- if __name__ == "__main__": # pragma: no cover
152
- main()
 
 
1
+ """Evaluate an LLM (with optional LoRA adapters) on CERNenv.
2
+
3
+ Usage:
4
+ python -m training.evaluate --model_name unsloth/Qwen2.5-3B-Instruct \\
5
+ --difficulty easy --episodes 16 --tag pre_train \\
6
+ --out training/runs/eval_pre_train.jsonl
7
+
8
+ python -m training.evaluate --model_name unsloth/Qwen2.5-3B-Instruct \\
9
+ --adapter_dir training/runs/unsloth-grpo --difficulty easy \\
10
+ --episodes 16 --tag post_train --out training/runs/eval_post_train.jsonl
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import json
17
+ import logging
18
+ import os
19
+ from dataclasses import asdict
20
+ from pathlib import Path
21
+ from typing import Any, Dict, List, Optional
22
+
23
+
24
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ def _build_generate_fn(
29
+ *,
30
+ model_name: str,
31
+ adapter_dir: Optional[str],
32
+ use_unsloth: bool,
33
+ max_seq_length: int,
34
+ ):
35
+ if use_unsloth:
36
+ from unsloth import FastLanguageModel # type: ignore
37
+
38
+ model, tokenizer = FastLanguageModel.from_pretrained(
39
+ model_name=model_name,
40
+ max_seq_length=max_seq_length,
41
+ load_in_4bit=True,
42
+ # fast_inference requires vLLM, which is not in requirements; plain transformers generation is used instead. Re-enable after pinning vllm in space/training/requirements.txt.
43
+ fast_inference=False,
44
+ )
45
+ if adapter_dir:
46
+ model.load_adapter(adapter_dir)
47
+ FastLanguageModel.for_inference(model)
48
+ else:
49
+ import torch
50
+ from transformers import AutoModelForCausalLM, AutoTokenizer
51
+
52
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
53
+ model = AutoModelForCausalLM.from_pretrained(
54
+ model_name,
55
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
56
+ device_map="auto" if torch.cuda.is_available() else None,
57
+ )
58
+ if adapter_dir:
59
+ from peft import PeftModel # type: ignore
60
+ model = PeftModel.from_pretrained(model, adapter_dir)
61
+
62
+ if tokenizer.pad_token is None:
63
+ tokenizer.pad_token = tokenizer.eos_token
64
+
65
+ def prompt_fn(chat: List[Dict[str, str]]) -> str:
66
+ return tokenizer.apply_chat_template(
67
+ chat, add_generation_prompt=True, tokenize=False
68
+ )
69
+
70
+ def generate_fn(prompt: str, config) -> str:
71
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
72
+ outputs = model.generate(
73
+ **inputs,
74
+ max_new_tokens=config.max_new_tokens,
75
+ do_sample=True,
76
+ temperature=config.temperature,
77
+ top_p=config.top_p,
78
+ pad_token_id=tokenizer.pad_token_id,
79
+ )
80
+ gen = outputs[0][inputs["input_ids"].shape[1]:]
81
+ return tokenizer.decode(gen, skip_special_tokens=True)
82
+
83
+ return prompt_fn, generate_fn
84
+
85
+
86
+ def main() -> None: # pragma: no cover
87
+ parser = argparse.ArgumentParser()
88
+ parser.add_argument("--model_name", required=True)
89
+ parser.add_argument("--adapter_dir", default=None)
90
+ parser.add_argument("--scenario", default=None)
91
+ parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], default="easy")
92
+ parser.add_argument("--episodes", type=int, default=16)
93
+ parser.add_argument("--seed", type=int, default=1000)
94
+ parser.add_argument("--max_steps", type=int, default=18)
95
+ parser.add_argument("--max_seq_length", type=int, default=2048)
96
+ parser.add_argument("--no_unsloth", action="store_true")
97
+ parser.add_argument("--tag", default="eval")
98
+ parser.add_argument("--out", required=True)
99
+ args = parser.parse_args()
100
+
101
+ from server.environment import CERNCollisionEnvironment
102
+ from training.llm_agent import LLMAgentConfig
103
+ from training.rollouts import collect_episode, save_episodes_jsonl
104
+
105
+ use_unsloth = not args.no_unsloth
106
+ try:
107
+ prompt_fn, generate_fn = _build_generate_fn(
108
+ model_name=args.model_name,
109
+ adapter_dir=args.adapter_dir,
110
+ use_unsloth=use_unsloth,
111
+ max_seq_length=args.max_seq_length,
112
+ )
113
+ except ImportError as exc:
114
+ logger.warning("Unsloth not available (%s); falling back to transformers.", exc)
115
+ prompt_fn, generate_fn = _build_generate_fn(
116
+ model_name=args.model_name,
117
+ adapter_dir=args.adapter_dir,
118
+ use_unsloth=False,
119
+ max_seq_length=args.max_seq_length,
120
+ )
121
+
122
+ env = CERNCollisionEnvironment(max_steps=args.max_steps)
123
+ cfg = LLMAgentConfig()
124
+
125
+ episodes = []
126
+ for ep in range(args.episodes):
127
+ seed = args.seed + ep
128
+ rec = collect_episode(
129
+ env=env,
130
+ seed=seed,
131
+ scenario=args.scenario,
132
+ difficulty=args.difficulty,
133
+ prompt_fn=prompt_fn,
134
+ generate_fn=generate_fn,
135
+ config=cfg,
136
+ )
137
+ episodes.append(rec)
138
+ logger.info(
139
+ "[%s][%d/%d] reward=%+.3f discovered=%s mass=%s channel=%s",
140
+ args.tag, ep + 1, args.episodes,
141
+ rec.cumulative_reward, rec.discovered, rec.correct_mass, rec.correct_channel,
142
+ )
143
+
144
+ Path(args.out).parent.mkdir(parents=True, exist_ok=True)
145
+ save_episodes_jsonl(episodes, args.out)
146
+
147
+ rewards = [e.cumulative_reward for e in episodes]
148
+ success = sum(1 for e in episodes if e.discovered) / len(episodes)
149
+ logger.info("[%s] mean_reward=%.3f success_rate=%.2f", args.tag, sum(rewards) / len(rewards), success)
150
+
151
+
152
+ if __name__ == "__main__": # pragma: no cover
153
+ main()