""" 仓库管理器 - 自动下载 Git 仓库 功能: - 从环境变量读取仓库列表 - 自动下载并解压 ZIP 文件 - 支持多个仓库并行处理 """ import os import shutil import threading import zipfile import urllib.request import urllib.error from pathlib import Path from typing import List from dataclasses import dataclass import logging logger = logging.getLogger(__name__) @dataclass class RepoConfig: """仓库配置""" name: str # 仓库名称(目录名) url: str # Git URL (支持 GitHub HTTPS) branch: str = "main" # 分支名 auto_update: bool = True # 是否自动更新 class RepoManager: """仓库管理器""" def __init__(self, base_dir: str = "./repos"): self.base_dir = Path(base_dir) self.base_dir.mkdir(parents=True, exist_ok=True) self.repos: List[RepoConfig] = [] self._lock = threading.Lock() def load_from_env(self): """从环境变量加载仓库配置 环境变量格式: REPO_URLS=repo1,repo2,repo3 REPO_1_URL=https://github.com/user/repo1.git REPO_2_URL=https://github.com/user/repo2.git """ self.repos = [] # 方式1:逗号分隔的 URL 列表 repo_urls = os.getenv("REPO_URLS", "") if repo_urls: for url in repo_urls.split(","): url = url.strip() if url: name = self._extract_repo_name(url) self.repos.append(RepoConfig(name=name, url=url)) # 方式2:带编号的配置 idx = 1 while True: url_key = f"REPO_{idx}_URL" url = os.getenv(url_key, "") if not url: break name = os.getenv(f"REPO_{idx}_NAME", "") if not name: name = self._extract_repo_name(url) branch = os.getenv(f"REPO_{idx}_BRANCH", "main") auto_update = os.getenv(f"REPO_{idx}_AUTO_UPDATE", "true").lower() == "true" self.repos.append(RepoConfig( name=name, url=url, branch=branch, auto_update=auto_update )) idx += 1 logger.info(f"加载了 {len(self.repos)} 个仓库配置") return self.repos def _extract_repo_name(self, url: str) -> str: """从 URL 提取仓库名称""" # 统一处理 URL url = url.replace("git@github.com:", "https://github.com/") url = url.rstrip("/") # 提取最后一部分 if "/" in url: name = url.split("/")[-1] # 去除 .git 后缀 if name.endswith(".git"): name = name[:-4] return name return url def _get_repo_dir(self, name: str) -> Path: """获取仓库目录路径""" return self.base_dir / name def _get_zip_url(self, url: str, branch: str) -> str: """将 Git URL 转换为 ZIP 下载 URL""" # git@github.com:user/repo.git -> https://github.com/user/repo/archive/refs/heads/branch.zip # https://github.com/user/repo.git -> https://github.com/user/repo/archive/refs/heads/branch.zip url = url.replace("git@github.com:", "https://github.com/") url = url.removesuffix(".git") # 提取 owner/repo parts = url.rstrip("/").split("/") if len(parts) >= 2: owner = parts[-2] repo = parts[-1] return f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip" raise ValueError(f"无法解析仓库 URL: {url}") def download_repo(self, repo: RepoConfig, force: bool = False) -> bool: """下载并解压仓库 Args: repo: 仓库配置 force: 是否强制重新下载 Returns: bool: 是否成功 """ repo_dir = self._get_repo_dir(repo.name) # 检查仓库是否已存在且有效(无需重新下载) if not force and repo_dir.exists() and any(repo_dir.iterdir()): logger.info(f"仓库已存在,跳过下载: {repo.name}") return True try: # 删除已存在的目录 if repo_dir.exists(): logger.info(f"删除旧版本: {repo.name}") shutil.rmtree(repo_dir) repo_dir.mkdir(parents=True, exist_ok=True) # 下载 ZIP zip_url = self._get_zip_url(repo.url, repo.branch) zip_path = repo_dir / "repo.zip" logger.info(f"下载仓库: {repo.name} ({zip_url})") headers = {"User-Agent": "Mozilla/5.0"} req = urllib.request.Request(zip_url, headers=headers) with urllib.request.urlopen(req, timeout=120) as response: with open(zip_path, 'wb') as f: f.write(response.read()) # 解压 logger.info(f"解压仓库: {repo.name}") with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(repo_dir) # ZIP 包会多一层目录,需要移动内容 extracted_dir = repo_dir / f"{repo.name}-{repo.branch}" if extracted_dir.exists(): for item in extracted_dir.iterdir(): item.rename(repo_dir / item.name) extracted_dir.rmdir() # 删除 ZIP 文件 zip_path.unlink() logger.info(f"仓库 {repo.name} 下载成功") return True except Exception as e: logger.error(f"下载仓库 {repo.name} 失败: {e}") # 清理失败的目录 if repo_dir.exists(): shutil.rmtree(repo_dir) return False def sync_all(self, parallel: bool = True, force: bool = False) -> dict: """同步所有仓库 Args: parallel: 是否并行下载 force: 是否强制重新下载所有仓库 Returns: dict: { "success": [成功列表], "skipped": [跳过列表], "failed": [失败列表] } """ if not self.repos: self.load_from_env() if not self.repos: logger.warning("没有配置任何仓库") return {"success": [], "skipped": [], "failed": []} results = {"success": [], "skipped": [], "failed": []} if parallel: # 并行下载 threads = [] for repo in self.repos: t = threading.Thread(target=self._sync_single, args=(repo, results, force)) t.start() threads.append(t) for t in threads: t.join() else: # 顺序下载 for repo in self.repos: self._sync_single(repo, results, force) logger.info(f"仓库同步完成: 成功 {len(results['success'])}, 跳过 {len(results['skipped'])}, 失败 {len(results['failed'])}") return results def _sync_single(self, repo: RepoConfig, results: dict, force: bool = False): """同步单个仓库(线程安全)""" # 先检查是否已存在且有效(非强制模式) if not force: repo_dir = self._get_repo_dir(repo.name) if repo_dir.exists() and any(repo_dir.iterdir()): with self._lock: results["skipped"].append(repo.name) return if self.download_repo(repo, force): with self._lock: results["success"].append(repo.name) else: with self._lock: results["failed"].append(repo.name) def get_repo_list(self) -> List[dict]: """获取已下载的仓库列表""" repos = [] for item in self.base_dir.iterdir(): if item.is_dir(): repos.append({ "name": item.name, "path": str(item) }) return repos def remove_repo(self, name: str) -> bool: """删除仓库""" repo_dir = self._get_repo_dir(name) if repo_dir.exists(): try: shutil.rmtree(repo_dir) logger.info(f"删除仓库: {name}") return True except Exception as e: logger.error(f"删除仓库 {name} 失败: {e}") return False return False def clear_all(self) -> int: """清空所有仓库""" count = 0 for item in self.base_dir.iterdir(): if item.is_dir(): try: shutil.rmtree(item) count += 1 except Exception as e: logger.warning(f"删除 {item.name} 失败: {e}") logger.info(f"清空仓库: 删除 {count} 个") return count def sync_repos_on_startup(): """启动时同步仓库""" manager = RepoManager() repos = manager.load_from_env() if repos: logger.info(f"启动同步 {len(repos)} 个仓库...") results = manager.sync_all(parallel=True) return results else: logger.info("未配置仓库,跳过同步") return {"success": [], "failed": []} if __name__ == "__main__": # 测试 logging.basicConfig(level=logging.INFO) # 示例:设置环境变量后测试 os.environ["REPO_1_URL"] = "https://github.com/psyche/astronomy.git" os.environ["REPO_1_NAME"] = "astronomy" os.environ["REPO_1_BRANCH"] = "main" manager = RepoManager("./test-code") manager.load_from_env() print(f"配置了 {len(manager.repos)} 个仓库:") for repo in manager.repos: print(f" - {repo.name}: {repo.url}") results = manager.sync_all() print(f"\n同步结果: 成功 {len(results['success'])}, 失败 {len(results['failed'])}")