File size: 1,874 Bytes
b77d3c5
9f6f68c
 
b77d3c5
ae60795
 
b77d3c5
 
9f6f68c
 
 
 
 
 
b77d3c5
ae60795
 
b77d3c5
 
ae60795
b77d3c5
 
 
 
ae60795
 
 
 
b77d3c5
 
 
ae60795
b77d3c5
 
ae60795
 
b77d3c5
 
 
 
 
 
ae60795
b77d3c5
 
 
ae60795
 
 
 
 
b77d3c5
 
ae60795
 
 
 
 
b77d3c5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# training/test_rollout.py
# Works both as: python training/test_rollout.py
# And as:        python -m training.test_rollout

import sys
import os
import asyncio

# Put both this script's directory and its parent (the project root) on
# sys.path so the imports below resolve whether the file is run as a plain
# script or as a module with ``-m``.
_TRAINING = os.path.abspath(os.path.dirname(__file__))
_ROOT = os.path.abspath(os.path.join(_TRAINING, ".."))
# Order matters: the root is pushed first, so this script's directory ends up
# ahead of it whenever both entries have to be added.
for _p in (_ROOT, _TRAINING):
    if _p in sys.path:
        continue
    sys.path.insert(0, _p)

from transformers import AutoModelForCausalLM, AutoTokenizer
from rollout import run_episode


# Model identifier handed to both AutoTokenizer.from_pretrained and
# AutoModelForCausalLM.from_pretrained in main().
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
# Passed to run_episode() as ``env_url``.
# NOTE(review): presumably an environment server must already be listening
# here before the script is started — confirm against rollout.run_episode.
ENV_URL = "http://127.0.0.1:8000"


async def main():
    """Load the model, run one difficulty-1 episode and print a summary.

    The tokenizer and causal LM are loaded from ``MODEL_NAME``; a single
    episode is then awaited via ``run_episode`` against ``ENV_URL``.

    The episode result is read as a mapping with the keys ``total_reward``,
    ``violations``, ``steps_completed``, ``difficulty`` and ``trajectory``.
    When the trajectory is non-empty, its first entry is printed as well
    (keys ``action_type``, ``generated`` and ``reward``).
    """
    print("Loading model...")
    tok = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    load_kwargs = {"torch_dtype": "auto", "device_map": "auto"}
    lm = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)

    # Report where the first parameter tensor was placed.
    first_param = next(lm.parameters())
    print(f"Model on: {first_param.device}")
    print("Running single episode (difficulty=1)...\n")

    episode = await run_episode(
        model=lm,
        tokenizer=tok,
        env_url=ENV_URL,
        difficulty=1,
        message_timeout_s=300.0,
    )

    print("\n========== RESULT ==========")
    print(f"Total Reward:     {episode['total_reward']:.3f}")
    print(f"Violations:       {episode['violations']}")
    print(f"Steps Completed:  {episode['steps_completed']}")
    print(f"Difficulty:       {episode['difficulty']}")
    print("=============================\n")

    steps = episode["trajectory"]
    if not steps:
        # Nothing was generated, so there is no first step to show.
        return

    opening = steps[0]
    # Only the first 200 characters of the generated text are shown.
    preview = opening["generated"][:200]
    print("=== First Generation ===")
    print(f"ACTION:  {opening['action_type']}")
    print(f"CONTENT: {preview}")
    print(f"REWARD:  {opening['reward']:.3f}")


# Script entry point: start an event loop and run main() to completion.
# Nothing runs here when the module is merely imported.
if __name__ == "__main__":
    asyncio.run(main())