# File: astrbot_help/src/repo_manager.py
# (Hosting-page residue, not part of the module: "qa1145's picture",
#  "Upload 28 files", "d347708 verified".)
"""
仓库管理器 - 自动下载 Git 仓库
功能:
- 从环境变量读取仓库列表
- 自动下载并解压 ZIP 文件
- 支持多个仓库并行处理
"""
import logging
import os
import shutil
import tempfile
import threading
import urllib.error
import urllib.request
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import List
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@dataclass
class RepoConfig:
    """Settings for a single repository to download.

    Attributes:
        name: Repository name; also used as the directory name.
        url: Git URL (GitHub HTTPS is supported).
        branch: Branch name.
        auto_update: Whether the repository should be updated automatically.
    """

    name: str
    url: str
    branch: str = "main"
    auto_update: bool = True
class RepoManager:
    """Download GitHub repositories as ZIP archives into a local directory.

    Each configured repository is unpacked into ``base_dir/<name>``.
    Configurations are normally read from environment variables with
    ``load_from_env``.
    """

    # Prefix of the temporary directories used while a download is in
    # progress. They are created inside ``base_dir`` and are hidden from
    # ``get_repo_list``.
    _STAGING_PREFIX = ".repo-staging-"

    def __init__(self, base_dir: str = "./repos"):
        """Create the manager and make sure ``base_dir`` exists.

        Args:
            base_dir: Directory under which every repository is stored.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.repos: List["RepoConfig"] = []
        # Guards the shared result lists that worker threads append to.
        self._lock = threading.Lock()

    def load_from_env(self) -> List["RepoConfig"]:
        """Rebuild ``self.repos`` from environment variables.

        Two formats are read, in this order, and their results are combined:

        1. ``REPO_URLS``: comma-separated repository URLs. The directory
           name is derived from each URL; every other setting keeps the
           ``RepoConfig`` default.
        2. ``REPO_<n>_URL`` for n = 1, 2, ... with the optional companions
           ``REPO_<n>_NAME``, ``REPO_<n>_BRANCH`` (default ``main``) and
           ``REPO_<n>_AUTO_UPDATE`` (enabled only by the value ``true``,
           in any letter case). Scanning stops at the first missing or
           empty ``REPO_<n>_URL``.

        Returns:
            The freshly built list, which is also stored in ``self.repos``.
        """
        self.repos = []

        # Format 1: comma-separated URL list.
        for url in os.getenv("REPO_URLS", "").split(","):
            url = url.strip()
            if url:
                self.repos.append(
                    RepoConfig(name=self._extract_repo_name(url), url=url)
                )

        # Format 2: numbered entries.
        idx = 1
        while True:
            # Surrounding whitespace would otherwise end up in the
            # download URL.
            url = os.getenv(f"REPO_{idx}_URL", "").strip()
            if not url:
                break
            name = os.getenv(f"REPO_{idx}_NAME", "") or self._extract_repo_name(url)
            branch = os.getenv(f"REPO_{idx}_BRANCH", "main")
            auto_update = os.getenv(f"REPO_{idx}_AUTO_UPDATE", "true").lower() == "true"
            self.repos.append(RepoConfig(
                name=name,
                url=url,
                branch=branch,
                auto_update=auto_update,
            ))
            idx += 1

        logger.info(f"加载了 {len(self.repos)} 个仓库配置")
        return self.repos

    def _extract_repo_name(self, url: str) -> str:
        """Return the last path segment of ``url`` without a ``.git`` suffix.

        The SSH form ``git@github.com:user/repo.git`` is accepted as well.
        A URL that contains no ``/`` is returned unchanged.
        """
        url = url.replace("git@github.com:", "https://github.com/").rstrip("/")
        if "/" not in url:
            return url
        return url.rsplit("/", 1)[-1].removesuffix(".git")

    def _get_repo_dir(self, name: str) -> Path:
        """Return the directory in which repository ``name`` is stored."""
        return self.base_dir / name

    def _get_zip_url(self, url: str, branch: str) -> str:
        """Convert a Git URL into the GitHub ZIP download URL of ``branch``.

        Both ``git@github.com:user/repo.git`` and
        ``https://github.com/user/repo.git`` become
        ``https://github.com/user/repo/archive/refs/heads/<branch>.zip``.

        Raises:
            ValueError: If no owner/repository pair can be taken from ``url``.
        """
        url = url.strip().replace("git@github.com:", "https://github.com/")
        # Drop trailing slashes first; otherwise ".git" is not at the end of
        # "…/repo.git/" and would survive into the archive URL.
        url = url.rstrip("/").removesuffix(".git")

        parts = url.split("/")
        if len(parts) >= 2 and parts[-2] and parts[-1]:
            owner, repo = parts[-2], parts[-1]
            return f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"
        raise ValueError(f"无法解析仓库 URL: {url}")

    @staticmethod
    def _has_content(path: Path) -> bool:
        """Return True when ``path`` is a directory with at least one entry."""
        return path.is_dir() and any(path.iterdir())

    @staticmethod
    def _unpack_archive(zip_path: Path, dest_dir: Path) -> Path:
        """Extract ``zip_path`` into ``dest_dir`` and locate the repository root.

        Returns:
            The single top-level directory of the archive when there is
            exactly one (the wrapper directory GitHub archives carry),
            otherwise ``dest_dir`` itself.
        """
        dest_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(dest_dir)

        # Detect the wrapper from what was actually extracted instead of
        # guessing its name from the configured name and branch.
        entries = list(dest_dir.iterdir())
        if len(entries) == 1 and entries[0].is_dir():
            return entries[0]
        return dest_dir

    def download_repo(self, repo: "RepoConfig", force: bool = False) -> bool:
        """Download and unpack one repository.

        The archive is downloaded and extracted in a staging directory; the
        existing copy is only replaced once extraction has succeeded, so a
        failed download leaves a previously downloaded copy untouched.

        Args:
            repo: Repository configuration.
            force: Download again even if the repository directory already
                has content.

        Returns:
            True on success (including the "already present" case), False
            if anything went wrong.
        """
        repo_dir = self._get_repo_dir(repo.name)

        # Already present and non-empty: nothing to do unless forced.
        if not force and self._has_content(repo_dir):
            logger.info(f"仓库已存在,跳过下载: {repo.name}")
            return True

        # Set once the existing copy is about to be touched; only then may a
        # failure leave a partial directory that has to be cleaned up.
        replacing = False
        try:
            zip_url = self._get_zip_url(repo.url, repo.branch)

            # Stage inside base_dir so the final step is a same-filesystem move.
            self.base_dir.mkdir(parents=True, exist_ok=True)
            staging_dir = Path(
                tempfile.mkdtemp(prefix=self._STAGING_PREFIX, dir=self.base_dir)
            )
            try:
                zip_path = staging_dir / "repo.zip"
                logger.info(f"下载仓库: {repo.name} ({zip_url})")
                request = urllib.request.Request(
                    zip_url, headers={"User-Agent": "Mozilla/5.0"}
                )
                with urllib.request.urlopen(request, timeout=120) as response:
                    with open(zip_path, "wb") as zip_file:
                        # Stream to disk instead of holding the archive in memory.
                        shutil.copyfileobj(response, zip_file)

                logger.info(f"解压仓库: {repo.name}")
                content_root = self._unpack_archive(zip_path, staging_dir / "content")

                replacing = True
                if repo_dir.exists():
                    logger.info(f"删除旧版本: {repo.name}")
                    shutil.rmtree(repo_dir)
                repo_dir.parent.mkdir(parents=True, exist_ok=True)
                shutil.move(str(content_root), str(repo_dir))
            finally:
                # Removes the ZIP file and whatever is left of the staging area.
                shutil.rmtree(staging_dir, ignore_errors=True)

            logger.info(f"仓库 {repo.name} 下载成功")
            return True
        except Exception as e:
            logger.error(f"下载仓库 {repo.name} 失败: {e}")
            if replacing:
                # Never leave a half-replaced directory behind; it would be
                # mistaken for a valid repository on the next sync.
                shutil.rmtree(repo_dir, ignore_errors=True)
            return False

    def sync_all(self, parallel: bool = True, force: bool = False) -> dict:
        """Synchronise every configured repository.

        If no configuration has been loaded yet, ``load_from_env`` is called
        first.

        Args:
            parallel: Download with one thread per repository instead of
                sequentially.
            force: Download every repository again even if it is present.

        Returns:
            dict: ``{"success": [...], "skipped": [...], "failed": [...]}``
            holding repository names.
        """
        if not self.repos:
            self.load_from_env()

        results = {"success": [], "skipped": [], "failed": []}
        if not self.repos:
            logger.warning("没有配置任何仓库")
            return results

        if parallel:
            threads = [
                threading.Thread(target=self._sync_single, args=(repo, results, force))
                for repo in self.repos
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
        else:
            for repo in self.repos:
                self._sync_single(repo, results, force)

        logger.info(f"仓库同步完成: 成功 {len(results['success'])}, 跳过 {len(results['skipped'])}, 失败 {len(results['failed'])}")
        return results

    def _sync_single(self, repo: "RepoConfig", results: dict, force: bool = False):
        """Synchronise one repository and record its name in ``results``.

        The name is appended to exactly one of ``results["skipped"]``,
        ``results["success"]`` or ``results["failed"]`` while holding
        ``self._lock``, so the method can run in several threads at once.
        """
        try:
            if not force and self._has_content(self._get_repo_dir(repo.name)):
                outcome = "skipped"
            elif self.download_repo(repo, force):
                outcome = "success"
            else:
                outcome = "failed"
        except Exception as e:
            # An exception escaping a worker thread would otherwise drop the
            # repository from every result list.
            logger.error(f"下载仓库 {repo.name} 失败: {e}")
            outcome = "failed"

        with self._lock:
            results[outcome].append(repo.name)

    def get_repo_list(self) -> List[dict]:
        """Return ``{"name", "path"}`` entries for the downloaded repositories.

        Every sub-directory of ``base_dir`` counts as a repository, except
        the staging directories of downloads that are in progress.
        """
        return [
            {"name": item.name, "path": str(item)}
            for item in self.base_dir.iterdir()
            if item.is_dir() and not item.name.startswith(self._STAGING_PREFIX)
        ]

    def remove_repo(self, name: str) -> bool:
        """Delete the repository directory ``name``.

        Returns:
            True if the directory existed and was removed; False if it does
            not exist, lies outside ``base_dir`` or could not be removed.
        """
        repo_dir = self._get_repo_dir(name)

        # "" resolves to base_dir itself and ".." to its parent; never delete
        # anything that is not strictly inside base_dir.
        if self.base_dir.resolve() not in repo_dir.resolve().parents:
            logger.warning(f"拒绝删除仓库目录之外的路径: {name}")
            return False

        if not repo_dir.exists():
            return False
        try:
            shutil.rmtree(repo_dir)
        except Exception as e:
            logger.error(f"删除仓库 {name} 失败: {e}")
            return False
        logger.info(f"删除仓库: {name}")
        return True

    def clear_all(self) -> int:
        """Delete every sub-directory of ``base_dir``.

        Directories that cannot be removed are logged and skipped.

        Returns:
            The number of directories that were removed.
        """
        count = 0
        for item in self.base_dir.iterdir():
            if not item.is_dir():
                continue
            try:
                shutil.rmtree(item)
            except Exception as e:
                logger.warning(f"删除 {item.name} 失败: {e}")
            else:
                count += 1
        logger.info(f"清空仓库: 删除 {count} 个")
        return count
def sync_repos_on_startup():
    """Load the repository list from the environment and sync it once.

    Returns:
        dict: ``{"success": [...], "skipped": [...], "failed": [...]}``.
        This is the result of ``RepoManager.sync_all`` when repositories are
        configured, or the same three keys with empty lists when none are.
    """
    manager = RepoManager()
    repos = manager.load_from_env()
    if not repos:
        logger.info("未配置仓库,跳过同步")
        # Same shape as the sync_all result, so callers can read "skipped"
        # without checking which branch produced the dict.
        return {"success": [], "skipped": [], "failed": []}

    logger.info(f"启动同步 {len(repos)} 个仓库...")
    return manager.sync_all(parallel=True)
if __name__ == "__main__":
    # Manual smoke test: configure one example repository through the
    # environment, then load and sync it into ./test-code.
    logging.basicConfig(level=logging.INFO)

    os.environ.update({
        "REPO_1_URL": "https://github.com/psyche/astronomy.git",
        "REPO_1_NAME": "astronomy",
        "REPO_1_BRANCH": "main",
    })

    manager = RepoManager("./test-code")
    manager.load_from_env()

    # Show what was configured before downloading anything.
    print(f"配置了 {len(manager.repos)} 个仓库:")
    for repo in manager.repos:
        print(f" - {repo.name}: {repo.url}")

    results = manager.sync_all()
    print(f"\n同步结果: 成功 {len(results['success'])}, 失败 {len(results['failed'])}")