File size: 3,433 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e64e71
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a001a97
 
 
 
 
 
 
 
 
 
5dd1bb4
 
 
 
 
 
 
 
d9759a5
5dd1bb4
 
 
 
 
 
 
 
d9759a5
 
 
 
 
 
 
 
5dd1bb4
 
 
 
d9759a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
FastAPI application for the SQLEnv environment.

Exposes the SQLEnvironment over HTTP and WebSocket endpoints,
compatible with the OpenEnv EnvClient.

Usage:
    # Development (with auto-reload):
    uv run uvicorn server.app:app --reload --host 0.0.0.0 --port 8000

    # Via the "server" script entry point (runs main() below):
    uv run server
"""

import os
from pathlib import Path

# Optionally load environment variables from a ".env" file before the rest of
# the module runs. The file is looked up in the parent of the directory that
# contains this file. This lets settings read from os.environ further down in
# this module (TOKENIZER_NAME, QUESTIONS_PATH, DB_DIR) be supplied through it.
try:
    from dotenv import load_dotenv

    env_file = Path(__file__).parent.parent / ".env"
    # Only call load_dotenv when the file is actually present.
    if env_file.exists():
        load_dotenv(env_file)
except ImportError:
    pass  # python-dotenv is an optional dependency; rely on the process env

from openenv.core.env_server import create_app

try:
    from sql_env.models import SQLAction, SQLObservation
    from sql_env.server.sql_environment import SQLEnvironment
except ImportError:
    # Fallback for Docker where PYTHONPATH=/app/env
    from models import SQLAction, SQLObservation  # type: ignore[no-redef]
    from server.sql_environment import SQLEnvironment  # type: ignore[no-redef]


def get_tokenizer():
    """Return the tokenizer for the SQL environment, or a mock for testing.

    The tokenizer name is read from the ``TOKENIZER_NAME`` environment
    variable and defaults to ``"mistralai/Mistral-7B-Instruct-v0.1"``.

    Returns:
        The object produced by ``AutoTokenizer.from_pretrained`` when
        ``transformers`` can be imported; otherwise a ``MockTokenizer``
        instance (intended for testing only).

    Raises:
        ImportError: If ``transformers`` is unavailable and the mock
            tokenizer module cannot be imported under either supported
            package layout.
    """
    tokenizer_name = os.environ.get(
        "TOKENIZER_NAME", "mistralai/Mistral-7B-Instruct-v0.1"
    )

    try:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        print(f"Loaded tokenizer: {tokenizer_name}")
        return tokenizer
    except ImportError:
        print(
            "Warning: transformers not installed, using mock tokenizer for testing only"
        )
        # Mirror the dual-layout imports at the top of this module: try the
        # installed ``sql_env`` package first, then the flat layout where
        # ``server`` is importable as a top-level package (e.g. Docker).
        try:
            from sql_env.server.mock_tokenizer import MockTokenizer
        except ImportError:
            from server.mock_tokenizer import MockTokenizer  # type: ignore[no-redef]

        return MockTokenizer()


def create_sql_environment():
    """Build an SQLEnvironment wired up with a tokenizer and data paths.

    The questions file and database directory can be overridden with the
    ``QUESTIONS_PATH`` and ``DB_DIR`` environment variables. When unset,
    they default to locations under the ``data`` directory that sits in
    the parent of this file's directory.

    Returns:
        A new ``SQLEnvironment`` instance.
    """
    tokenizer = get_tokenizer()

    # Both defaults live under the same "data" directory.
    data_root = Path(__file__).parent.parent / "data"
    default_questions = data_root / "questions" / "student_assessment.json"
    default_db_dir = data_root / "databases"

    return SQLEnvironment(
        questions_path=os.environ.get("QUESTIONS_PATH", str(default_questions)),
        db_dir=os.environ.get("DB_DIR", str(default_db_dir)),
        tokenizer=tokenizer,
    )


# Create the FastAPI app.
#
# The environment factory is passed uncalled, so environment construction is
# left to openenv-core (presumably once per session; confirm against
# create_app's documentation).
#
# Note: the hosted Space is single-session. External users running TRL's
# GRPOTrainer against https://hjerpe-sql-env.hf.space with
# num_generations > 1 will hit openenv-core's default 1-session cap.
# The fix requires (a) auditing SQLEnvironment for shared mutable state
# across sessions, (b) declaring SUPPORTS_CONCURRENT_SESSIONS=True on
# the class, and (c) passing max_concurrent_envs=64 here. Deferred as a
# post-launch follow-up. Our own training uses an in-process
# SQLEnvironment via SQLEnvTRL, so this does not affect internal runs.
app = create_app(
    create_sql_environment,
    SQLAction,
    SQLObservation,
    env_name="sql_env",
)


def _port_from_cli() -> int:
    """Parse ``--port`` from the command line, defaulting to 8000."""
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=8000)
    return cli.parse_args().port


def main(host: str = "0.0.0.0", port: int | None = None):
    """Start the server with uvicorn.

    This is the entry point behind:
        uv run server
        python -m sql_env.server.app

    Args:
        host: Interface to bind to.
        port: Port to listen on. When ``None``, the value is taken from the
            ``--port`` command-line option (default 8000).
    """
    import uvicorn

    # Only consult the command line when the caller did not choose a port.
    listen_port = _port_from_cli() if port is None else port

    uvicorn.run(app, host=host, port=listen_port)


# Start the server when this file is executed as a script rather than imported.
if __name__ == "__main__":
    main()