File size: 3,732 Bytes
6dedae0
 
b2e8697
 
 
6dedae0
 
 
b2e8697
 
09dada1
6dedae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2e8697
 
6dedae0
 
 
 
 
 
b2e8697
6dedae0
b2e8697
6dedae0
09dada1
6dedae0
09dada1
 
 
6dedae0
 
 
 
09dada1
6dedae0
 
 
 
b2e8697
6dedae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2e8697
6dedae0
 
b2e8697
5a09aa9
6dedae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2e8697
6dedae0
b2e8697
6dedae0
b2e8697
 
6dedae0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import annotations

import os
import subprocess
import threading
import time
from pathlib import Path

from huggingface_hub import HfApi, snapshot_download


# Directory layout: the wrapped application lives in ./grok2api next to this
# launcher; its data and log directories can be relocated via environment.
BASE_DIR = Path(__file__).resolve().parent
PROJECT_DIR = BASE_DIR / "grok2api"
DATA_DIR = Path(os.getenv("DATA_DIR", str(PROJECT_DIR / "data"))).expanduser()
LOG_DIR = Path(os.getenv("LOG_DIR", str(PROJECT_DIR / "logs"))).expanduser()

# Hugging Face dataset used as the persistence backend. Both default to the
# empty string, which the sync functions treat as "not configured".
HF_TOKEN = os.getenv("HF_TOKEN", "")
DATASET_ID = os.getenv("DATASET_ID", "")


def _env_int(name: str, default: int, minimum: int) -> int:
    """Read an integer environment variable without crashing on bad input.

    Args:
        name: Environment variable to read.
        default: Value used when the variable is unset, blank, or not an
            integer literal.
        minimum: Lower bound applied to the final result.

    Returns:
        The parsed (or default) value, never smaller than ``minimum``.
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        return max(default, minimum)
    try:
        value = int(raw)
    except ValueError:
        # A typo in the configuration should not prevent the launcher from
        # starting; report it and carry on with the default.
        print(
            f"[HF-Space] 环境变量 {name}={raw!r} 无效,使用默认值 {default}。",
            flush=True,
        )
        value = default
    return max(value, minimum)


# Seconds between periodic uploads; clamped to at least 60.
SYNC_INTERVAL = _env_int("HF_SYNC_INTERVAL", default=1800, minimum=60)

# Server settings are kept as strings because they are passed straight to the
# server command line.
SERVER_HOST = os.getenv("SERVER_HOST", "0.0.0.0")
SERVER_PORT = os.getenv("SERVER_PORT") or os.getenv("PORT") or "8000"
SERVER_WORKERS = os.getenv("SERVER_WORKERS", "1")

# Patterns are relative to PROJECT_DIR: only data/ is synchronised, minus
# lock files, temporary files, logs and bytecode caches.
SYNC_ALLOW_PATTERNS = ["data/**"]
SYNC_IGNORE_PATTERNS = [
    "data/.locks/**",
    "data/tmp/**",
    "logs/**",
    "**/__pycache__/**",
]


def log(message: str) -> None:
    """Print ``message`` to stdout with the ``[HF-Space]`` prefix, flushing at once."""
    line = "[HF-Space] {}".format(message)
    print(line, flush=True)


def ensure_local_dirs() -> None:
    """Create the data and log directories, including parents, if missing."""
    for directory in (DATA_DIR, LOG_DIR):
        directory.mkdir(parents=True, exist_ok=True)


def download_data() -> None:
    """Pull the synchronised files from the configured dataset into PROJECT_DIR.

    When DATASET_ID is empty the download is skipped with a log line. Any
    exception raised during the download is logged and swallowed so that
    startup can continue with whatever data is already present locally.
    """
    if DATASET_ID:
        try:
            log(f"开始从 Dataset 拉取数据: {DATASET_ID}")
            request = dict(
                repo_id=DATASET_ID,
                repo_type="dataset",
                local_dir=str(PROJECT_DIR),
                # An empty token is passed as None rather than "".
                token=HF_TOKEN or None,
                allow_patterns=SYNC_ALLOW_PATTERNS,
                ignore_patterns=SYNC_IGNORE_PATTERNS,
            )
            snapshot_download(**request)
            log("数据拉取完成。")
        except Exception as error:
            # Best effort: a failed pull must not block the local start.
            log(f"数据拉取失败,继续本地启动: {error}")
        return

    log("未配置 DATASET_ID,跳过启动数据同步。")


def _log_async_upload_result(future) -> None:
    """Done-callback that reports a failed background upload.

    Args:
        future: The finished future returned by ``upload_folder`` when it is
            called with ``run_as_future=True``.
    """
    if future.cancelled():
        return
    exc = future.exception()
    # Same filter as the synchronous path: "nothing changed" is not an error.
    if exc is not None and "No files have been modified" not in str(exc):
        log(f"数据上传失败: {exc}")


def upload_data(run_as_future: bool) -> None:
    """Upload the synchronised files under PROJECT_DIR to the configured dataset.

    The upload is skipped silently when DATASET_ID is empty, and skipped with
    a log line when DATASET_ID is set but HF_TOKEN is not.

    Args:
        run_as_future: Forwarded to ``HfApi.upload_folder``. When true, the
            upload is submitted as a background task and this function
            returns without waiting for it; failures are then reported
            through a done-callback. When false, the call blocks until the
            upload finishes.

    Exceptions are logged and never propagated, except that errors whose
    text contains "No files have been modified" are ignored entirely.
    """
    if not DATASET_ID:
        return

    if not HF_TOKEN:
        log("已配置 DATASET_ID,但未配置 HF_TOKEN,跳过数据上传。")
        return

    try:
        api = HfApi(token=HF_TOKEN)
        result = api.upload_folder(
            folder_path=str(PROJECT_DIR),
            repo_id=DATASET_ID,
            repo_type="dataset",
            commit_message="chore: sync Grok2API data from Space",
            allow_patterns=SYNC_ALLOW_PATTERNS,
            ignore_patterns=SYNC_IGNORE_PATTERNS,
            run_as_future=run_as_future,
        )
        if run_as_future:
            log("已提交后台数据同步任务。")
            # Errors raised inside the background task never reach the
            # ``except`` below, so attach a callback to surface them.
            result.add_done_callback(_log_async_upload_result)
        else:
            log("退出前数据同步完成。")
    except Exception as exc:
        if "No files have been modified" not in str(exc):
            log(f"数据上传失败: {exc}")


def upload_loop() -> None:
    """Submit a background upload every SYNC_INTERVAL seconds, forever.

    The first upload happens only after one full interval has elapsed. The
    function never returns; it is meant to be the target of a daemon thread.
    """
    while True:
        # Sleep first so a freshly started process does not upload at once.
        time.sleep(SYNC_INTERVAL)
        upload_data(run_as_future=True)


def init_storage() -> None:
    """Run ``scripts/init_storage.sh`` with ``sh`` from inside PROJECT_DIR.

    The child receives a copy of the current environment.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status.
    """
    script_command = ["sh", "scripts/init_storage.sh"]
    child_env = dict(os.environ)
    subprocess.run(script_command, env=child_env, check=True, cwd=PROJECT_DIR)


def run_server() -> None:
    """Start the ASGI app ``main:app`` under granian and wait for it to exit.

    The child runs in PROJECT_DIR with a copy of the current environment in
    which DATA_DIR, LOG_DIR and LOG_FILE_ENABLED are filled in only if they
    are not already set.

    Raises:
        subprocess.CalledProcessError: If the server exits with a non-zero
            status.
    """
    child_env = dict(os.environ)
    fallbacks = (
        ("DATA_DIR", str(DATA_DIR)),
        ("LOG_DIR", str(LOG_DIR)),
        ("LOG_FILE_ENABLED", "false"),
    )
    for key, fallback in fallbacks:
        # Values already present in the environment take precedence.
        child_env.setdefault(key, fallback)

    argv = ["granian", "--interface", "asgi"]
    argv += ["--host", SERVER_HOST]
    argv += ["--port", SERVER_PORT]
    argv += ["--workers", SERVER_WORKERS]
    argv.append("main:app")

    banner = (
        f"启动 Grok2API: host={SERVER_HOST} port={SERVER_PORT} workers={SERVER_WORKERS}"
    )
    log(banner)
    subprocess.run(argv, cwd=PROJECT_DIR, check=True, env=child_env)


if __name__ == "__main__":
    # Imported here because only the script entry point needs it.
    import signal

    def _exit_on_sigterm(signum, frame):
        """Turn SIGTERM into SystemExit so the ``finally`` block below runs."""
        # Ignore further SIGTERMs so the final upload is not interrupted.
        signal.signal(signal.SIGTERM, signal.SIG_IGN)
        raise SystemExit(128 + signum)

    # Startup order: make sure directories exist, restore persisted data,
    # then let the project initialise its storage.
    ensure_local_dirs()
    download_data()
    init_storage()

    # Daemon thread: the periodic sync must never keep the process alive.
    backup_thread = threading.Thread(target=upload_loop, daemon=True)
    backup_thread.start()

    # Without a handler, SIGTERM ends the process without unwinding the
    # stack, so the exit-time upload in ``finally`` would be skipped.
    signal.signal(signal.SIGTERM, _exit_on_sigterm)
    try:
        run_server()
    finally:
        # Blocking upload so the data is pushed before the process exits.
        upload_data(run_as_future=False)