# sub / restore_from_dataset.py
# gallyg's picture
# Update restore_from_dataset.py
# c4f56ae verified
#!/usr/bin/env python3
import argparse
import gzip
import json
import os
import shutil
import subprocess
import sys
from pathlib import Path
from huggingface_hub import HfFileSystem
def env(name: str, default: str | None = None, required: bool = False) -> str:
value = os.getenv(name, default)
if required and not value:
raise RuntimeError(f"Missing required environment variable: {name}")
return value or ""
def run_sql(sql: str) -> str:
    """Execute *sql* through ``psql -tAc`` and return the stripped stdout.

    Connection parameters come from the DATABASE_* environment variables,
    falling back to the POSTGRES_* equivalents and finally to local defaults.
    The password is passed via PGPASSWORD in the child environment.
    """
    db_host = env("DATABASE_HOST", "127.0.0.1")
    db_port = env("DATABASE_PORT", "5432")
    db_user = env("DATABASE_USER", env("POSTGRES_USER", "sub2api"))
    db_password = env("DATABASE_PASSWORD", env("POSTGRES_PASSWORD", ""))
    db_name = env("DATABASE_DBNAME", env("POSTGRES_DB", "sub2api"))
    psql_cmd = [
        "psql",
        "-h", db_host,
        "-p", db_port,
        "-U", db_user,
        "-d", db_name,
        "-tAc",
        sql,
    ]
    child_env = os.environ.copy()
    child_env["PGPASSWORD"] = db_password
    output = subprocess.check_output(psql_cmd, env=child_env, text=True)
    return output.strip()
def users_table_exists() -> bool:
    """Return True when the ``public.users`` table exists in the database.

    Any failure while querying (e.g. the DB is unreachable) is reported and
    treated as "does not exist".
    """
    sql = """
    SELECT CASE
      WHEN EXISTS (
        SELECT 1
        FROM information_schema.tables
        WHERE table_schema='public' AND table_name='users'
      ) THEN 1 ELSE 0 END;
    """
    try:
        result = run_sql(sql)
    except Exception as exc:
        print(f"[restore] failed to check users table existence: {exc}")
        return False
    return result == "1"
def is_business_data_present() -> bool:
    """Return True when the ``users`` table already contains rows.

    Query failures are reported and treated as "no data present".
    """
    # If the users table does not even exist, migrations/initialization have
    # not run yet, so treat the database as empty.
    if not users_table_exists():
        print("[restore] users table does not exist yet, treating DB as empty")
        return False
    try:
        result = run_sql("SELECT CASE WHEN EXISTS (SELECT 1 FROM users LIMIT 1) THEN 1 ELSE 0 END;")
    except Exception as exc:
        print(f"[restore] failed to check business data presence: {exc}")
        return False
    return result == "1"
def download_latest_metadata(fs: HfFileSystem, dataset_repo_id: str, workdir: Path) -> dict | None:
    """Download and parse ``postgres/latest.json`` from the dataset repo.

    Returns the parsed metadata dict, or None when the file is missing,
    unreadable, or not valid JSON (each case is reported to stdout).
    """
    remote_latest = f"datasets/{dataset_repo_id}/postgres/latest.json"
    local_latest = workdir / "latest.json"

    # Stream the remote file to disk first so parse errors leave evidence.
    try:
        with fs.open(remote_latest, "rb") as remote_fp, local_latest.open("wb") as local_fp:
            shutil.copyfileobj(remote_fp, local_fp)
    except Exception as exc:
        print(f"[restore] latest.json not found or unreadable: {exc}")
        return None

    try:
        return json.loads(local_latest.read_text(encoding="utf-8"))
    except Exception as exc:
        print(f"[restore] latest.json parse failed: {exc}")
        return None
def download_backup(fs: HfFileSystem, dataset_repo_id: str, remote_sql_path: str, local_gz: Path) -> None:
    """Stream a gzipped SQL backup from the dataset repo to *local_gz*."""
    remote_path = f"datasets/{dataset_repo_id}/{remote_sql_path}"
    print(f"[restore] downloading {remote_path}")
    with fs.open(remote_path, "rb") as remote_fp, local_gz.open("wb") as local_fp:
        shutil.copyfileobj(remote_fp, local_fp)
def gunzip_file(src: Path, dst: Path) -> None:
    """Decompress the gzip file at *src*, writing the plain bytes to *dst*."""
    with gzip.open(src, "rb") as compressed:
        with dst.open("wb") as plain:
            shutil.copyfileobj(compressed, plain)
def restore_sql(sql_path: Path) -> None:
    """Replay the SQL file at *sql_path* into the database via psql.

    Uses ``ON_ERROR_STOP=1`` so the first failing statement aborts the
    restore; raises subprocess.CalledProcessError on psql failure.
    """
    db_host = env("DATABASE_HOST", "127.0.0.1")
    db_port = env("DATABASE_PORT", "5432")
    db_user = env("DATABASE_USER", env("POSTGRES_USER", "sub2api"))
    db_password = env("DATABASE_PASSWORD", env("POSTGRES_PASSWORD", ""))
    db_name = env("DATABASE_DBNAME", env("POSTGRES_DB", "sub2api"))
    psql_cmd = [
        "psql",
        "-h", db_host,
        "-p", db_port,
        "-U", db_user,
        "-d", db_name,
        "-v", "ON_ERROR_STOP=1",
        "-f", str(sql_path),
    ]
    child_env = os.environ.copy()
    child_env["PGPASSWORD"] = db_password
    # Log the command without the trailing file argument.
    print(f"[restore] running: {' '.join(psql_cmd[:-1])} <sql>")
    subprocess.run(psql_cmd, check=True, env=child_env)
def main() -> int:
    """Restore the newest dataset backup unless the DB already has data.

    Returns 0 on success or when the restore is skipped (no repo configured,
    data already present, or no backup metadata), and 1 on failure.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--restore-latest", action="store_true")
    arg_parser.parse_args()

    dataset_repo_id = env("DATASET_REPO_ID", "")
    if not dataset_repo_id:
        print("[restore] DATASET_REPO_ID not set, skipping restore")
        return 0

    # Never clobber a database that already holds business data.
    if is_business_data_present():
        print("[restore] users table already has data, skipping restore")
        return 0

    hf_token = os.getenv("HF_TOKEN")
    fs = HfFileSystem(token=hf_token) if hf_token else HfFileSystem()

    workdir = Path("/tmp/sub2api_restore")
    workdir.mkdir(parents=True, exist_ok=True)
    try:
        metadata = download_latest_metadata(fs, dataset_repo_id, workdir)
        if not metadata:
            print("[restore] no backup metadata available, skipping restore")
            return 0

        # Older metadata may lack remote_sql_path; derive it from the timestamp.
        remote_sql_path = metadata.get("remote_sql_path") or f"postgres/{metadata['timestamp_utc']}.sql.gz"
        gz_path = workdir / Path(remote_sql_path).name
        sql_path = workdir / gz_path.stem  # strips the trailing ".gz"

        download_backup(fs, dataset_repo_id, remote_sql_path, gz_path)
        gunzip_file(gz_path, sql_path)
        restore_sql(sql_path)

        # Persist the restored metadata so later runs can inspect what was applied.
        marker = Path(env("SUB2API_DATA_DIR", "/app/data")) / "restore_last.json"
        marker.parent.mkdir(parents=True, exist_ok=True)
        marker.write_text(json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8")
        print(f"[restore] restored backup {remote_sql_path}")
        return 0
    except Exception as exc:
        print(f"[restore] failed: {exc}", file=sys.stderr)
        return 1
    finally:
        # Best-effort cleanup of the downloaded artifacts.
        for leftover in workdir.glob("*"):
            try:
                leftover.unlink()
            except Exception:
                pass
if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())