File size: 10,276 Bytes
d347708
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""

仓库管理器 - 自动下载 Git 仓库



功能:

- 从环境变量读取仓库列表

- 自动下载并解压 ZIP 文件

- 支持多个仓库并行处理

"""

import logging
import os
import shutil
import tempfile
import threading
import urllib.error
import urllib.request
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import List

logger = logging.getLogger(__name__)


@dataclass
class RepoConfig:
    """Configuration of a single repository.

    Attributes:
        name: Repository name; also used as the directory name.
        url: Git URL of the repository (GitHub HTTPS URLs are supported).
        branch: Branch name.
        auto_update: Whether the repository is updated automatically.
    """

    name: str
    url: str
    branch: str = "main"
    auto_update: bool = True


class RepoManager:
    """Download GitHub repositories as ZIP archives and manage the local copies.

    Each repository is fetched from ``github.com`` as a branch archive and
    unpacked into its own sub-directory of ``base_dir``.
    """

    # Values of REPO_<n>_AUTO_UPDATE that enable auto update (case-insensitive).
    _TRUTHY = frozenset({"1", "true", "yes", "on"})

    def __init__(self, base_dir: str = "./repos"):
        """Create the manager and make sure ``base_dir`` exists.

        Args:
            base_dir: Directory that holds one sub-directory per repository.
        """
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.repos: List["RepoConfig"] = []
        # Guards the shared result dict while repositories sync in parallel.
        self._lock = threading.Lock()

    def load_from_env(self):
        """Load the repository configuration from environment variables.

        Two formats are read, in this order:

        1. ``REPO_URLS`` - a comma separated list of URLs; the name is derived
           from each URL and the remaining settings use their defaults.
        2. ``REPO_<n>_URL`` with optional ``REPO_<n>_NAME``,
           ``REPO_<n>_BRANCH`` and ``REPO_<n>_AUTO_UPDATE``, for n = 1, 2, ...
           Reading stops at the first ``n`` whose URL is unset or empty.

        Surrounding whitespace is ignored, an empty branch falls back to
        ``main``, and an entry whose name was already registered is skipped
        with a warning, because two entries sharing a name would write to the
        same directory.

        Returns:
            The list of loaded ``RepoConfig`` objects (also stored in
            ``self.repos``).
        """
        loaded: List["RepoConfig"] = []
        seen_names = set()

        def register(config):
            if config.name in seen_names:
                logger.warning(f"仓库名称重复,已忽略: {config.name} ({config.url})")
                return
            seen_names.add(config.name)
            loaded.append(config)

        # Format 1: comma separated URL list.
        for raw_url in os.getenv("REPO_URLS", "").split(","):
            url = raw_url.strip()
            if url:
                register(RepoConfig(name=self._extract_repo_name(url), url=url))

        # Format 2: numbered entries.
        idx = 1
        while True:
            url = os.getenv(f"REPO_{idx}_URL", "").strip()
            if not url:
                break

            name = os.getenv(f"REPO_{idx}_NAME", "").strip() or self._extract_repo_name(url)
            branch = os.getenv(f"REPO_{idx}_BRANCH", "").strip() or "main"
            flag = os.getenv(f"REPO_{idx}_AUTO_UPDATE", "true").strip().lower()

            register(RepoConfig(
                name=name,
                url=url,
                branch=branch,
                auto_update=flag in self._TRUTHY,
            ))
            idx += 1

        self.repos = loaded
        logger.info(f"加载了 {len(self.repos)} 个仓库配置")
        return self.repos

    def _extract_repo_name(self, url: str) -> str:
        """Return the repository name encoded in ``url``.

        The name is the last path segment with surrounding whitespace, trailing
        slashes and a ``.git`` suffix removed. The SSH form
        ``git@github.com:user/repo.git`` is supported as well.
        """
        cleaned = url.strip().replace("git@github.com:", "https://github.com/")
        last_segment = cleaned.rstrip("/").rsplit("/", 1)[-1]
        return last_segment.removesuffix(".git")

    def _get_repo_dir(self, name: str) -> Path:
        """Return the directory of repository ``name`` inside ``base_dir``.

        Raises:
            ValueError: If ``name`` is empty, is ``.`` or ``..``, or contains a
                path separator. Such a name would resolve to ``base_dir``
                itself or to a location outside of it, and callers delete the
                returned directory.
        """
        if (
            not name
            or name in (".", "..")
            or any(ch in name for ch in ("/", "\\", "\0"))
        ):
            raise ValueError(f"非法的仓库名称: {name!r}")
        return self.base_dir / name

    def _get_zip_url(self, url: str, branch: str) -> str:
        """Convert a Git URL into the GitHub ZIP download URL of ``branch``.

        Examples of accepted input::

            git@github.com:user/repo.git
            https://github.com/user/repo.git
            user/repo

        The result always points at ``github.com``; the owner and repository
        are the last two path segments of ``url``.

        Raises:
            ValueError: If ``url`` does not contain an owner and a repository.
        """
        normalized = url.strip().replace("git@github.com:", "https://github.com/")
        # Strip the trailing slash first so "…/repo.git/" loses its suffix too.
        normalized = normalized.rstrip("/").removesuffix(".git")

        # Drop "scheme://host" so the host is never mistaken for the owner.
        if "://" in normalized:
            path = normalized.split("://", 1)[1].partition("/")[2]
        else:
            path = normalized

        segments = [segment for segment in path.split("/") if segment]
        if len(segments) < 2:
            raise ValueError(f"无法解析仓库 URL: {url}")

        owner, repo = segments[-2:]
        return f"https://github.com/{owner}/{repo}/archive/refs/heads/{branch}.zip"

    @staticmethod
    def _is_populated(path: Path) -> bool:
        """Return True if ``path`` is a directory with at least one entry."""
        try:
            return path.is_dir() and any(path.iterdir())
        except OSError:
            # The directory vanished or is unreadable: treat it as missing.
            return False

    def download_repo(self, repo: "RepoConfig", force: bool = False) -> bool:
        """Download the repository archive and unpack it into its directory.

        The archive is downloaded and extracted in a temporary directory and
        only moved into place once extraction succeeded, so a failed download
        leaves an existing copy untouched.

        Args:
            repo: Repository configuration.
            force: Download again even if a non-empty copy already exists.

        Returns:
            True on success (including when an existing copy was kept),
            False if the download or extraction failed.
        """
        try:
            repo_dir = self._get_repo_dir(repo.name)
        except ValueError as e:
            logger.error(f"下载仓库 {repo.name} 失败: {e}")
            return False

        # A non-empty directory counts as a valid copy; nothing to do.
        if not force and self._is_populated(repo_dir):
            logger.info(f"仓库已存在,跳过下载: {repo.name}")
            return True

        # True only while the old copy is being swapped for the new one.
        replacing = False
        try:
            zip_url = self._get_zip_url(repo.url, repo.branch)

            with tempfile.TemporaryDirectory(prefix="repo-download-") as tmp:
                staging = Path(tmp)
                zip_path = staging / "repo.zip"

                logger.info(f"下载仓库: {repo.name} ({zip_url})")
                req = urllib.request.Request(zip_url, headers={"User-Agent": "Mozilla/5.0"})
                with urllib.request.urlopen(req, timeout=120) as response, \
                        open(zip_path, "wb") as archive:
                    # Stream to disk instead of holding the archive in memory.
                    shutil.copyfileobj(response, archive)

                logger.info(f"解压仓库: {repo.name}")
                unpacked = staging / "unpacked"
                unpacked.mkdir()
                with zipfile.ZipFile(zip_path, "r") as zip_ref:
                    zip_ref.extractall(unpacked)

                # The archive wraps everything in one top-level folder. Detect
                # it from the extracted content instead of guessing its name,
                # which differs when repo.name is customised or the branch
                # contains a slash.
                entries = list(unpacked.iterdir())
                if len(entries) == 1 and entries[0].is_dir():
                    root = entries[0]
                else:
                    root = unpacked

                replacing = True
                if repo_dir.is_dir():
                    logger.info(f"删除旧版本: {repo.name}")
                    shutil.rmtree(repo_dir)
                elif repo_dir.exists():
                    repo_dir.unlink()
                self.base_dir.mkdir(parents=True, exist_ok=True)
                shutil.move(str(root), str(repo_dir))
                replacing = False

            logger.info(f"仓库 {repo.name} 下载成功")
            return True

        except Exception as e:
            logger.error(f"下载仓库 {repo.name} 失败: {e}")
            if replacing:
                # The swap was interrupted: do not leave a partial tree that a
                # later non-forced sync would mistake for a valid copy.
                shutil.rmtree(repo_dir, ignore_errors=True)
            return False

    def sync_all(self, parallel: bool = True, force: bool = False) -> dict:
        """Synchronise every configured repository.

        The configuration is loaded from the environment first if no
        repository has been loaded yet.

        Args:
            parallel: Download each repository in its own thread.
            force: Download every repository again, even if it exists.

        Returns:
            dict: ``{"success": [...], "skipped": [...], "failed": [...]}``
            holding repository names.
        """
        if not self.repos:
            self.load_from_env()

        results = {"success": [], "skipped": [], "failed": []}
        if not self.repos:
            logger.warning("没有配置任何仓库")
            return results

        if parallel:
            threads = [
                threading.Thread(target=self._sync_single, args=(repo, results, force))
                for repo in self.repos
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
        else:
            for repo in self.repos:
                self._sync_single(repo, results, force)

        logger.info(f"仓库同步完成: 成功 {len(results['success'])}, 跳过 {len(results['skipped'])}, 失败 {len(results['failed'])}")
        return results

    def _sync_single(self, repo: "RepoConfig", results: dict, force: bool = False):
        """Synchronise one repository and record the outcome in ``results``.

        The repository name is appended to exactly one of the ``"success"``,
        ``"skipped"`` or ``"failed"`` lists while holding ``self._lock``. Any
        unexpected error is recorded as a failure instead of escaping from the
        worker thread.
        """
        outcome = "failed"
        try:
            if not force and self._is_populated(self._get_repo_dir(repo.name)):
                outcome = "skipped"
            elif self.download_repo(repo, force):
                outcome = "success"
        except Exception as e:
            logger.error(f"同步仓库 {repo.name} 失败: {e}")

        with self._lock:
            results[outcome].append(repo.name)

    def get_repo_list(self) -> List[dict]:
        """Return the downloaded repositories, sorted by directory name.

        Returns:
            One ``{"name": ..., "path": ...}`` dict per sub-directory of
            ``base_dir``; an empty list if ``base_dir`` no longer exists.
        """
        if not self.base_dir.is_dir():
            return []
        return [
            {"name": item.name, "path": str(item)}
            for item in sorted(self.base_dir.iterdir())
            if item.is_dir()
        ]

    def remove_repo(self, name: str) -> bool:
        """Delete the local copy of repository ``name``.

        Returns:
            True if the directory existed and was removed; False if the name
            is invalid, the directory does not exist, or removal failed.
        """
        try:
            repo_dir = self._get_repo_dir(name)
        except ValueError as e:
            logger.error(f"删除仓库 {name} 失败: {e}")
            return False

        if not repo_dir.exists():
            return False

        try:
            shutil.rmtree(repo_dir)
        except Exception as e:
            logger.error(f"删除仓库 {name} 失败: {e}")
            return False

        logger.info(f"删除仓库: {name}")
        return True

    def clear_all(self) -> int:
        """Delete every sub-directory of ``base_dir``.

        Returns:
            The number of directories that were removed.
        """
        count = 0
        # Snapshot the listing first: entries are deleted while looping.
        items = list(self.base_dir.iterdir()) if self.base_dir.is_dir() else []
        for item in items:
            if not item.is_dir():
                continue
            try:
                shutil.rmtree(item)
                count += 1
            except Exception as e:
                logger.warning(f"删除 {item.name} 失败: {e}")
        logger.info(f"清空仓库: 删除 {count} 个")
        return count


def sync_repos_on_startup():
    """Synchronise the repositories configured in the environment at startup.

    A ``RepoManager`` with the default base directory is created and the
    repositories are synchronised in parallel.

    Returns:
        dict: ``{"success": [...], "skipped": [...], "failed": [...]}``; all
        three lists are empty when no repository is configured, so the result
        has the same keys as the one returned by ``RepoManager.sync_all``.
    """
    manager = RepoManager()
    repos = manager.load_from_env()

    if not repos:
        logger.info("未配置仓库,跳过同步")
        return {"success": [], "skipped": [], "failed": []}

    logger.info(f"启动同步 {len(repos)} 个仓库...")
    return manager.sync_all(parallel=True)


if __name__ == "__main__":
    # Manual smoke test.
    logging.basicConfig(level=logging.INFO)

    # Example configuration: set the environment variables, then sync.
    os.environ["REPO_1_URL"] = "https://github.com/psyche/astronomy.git"
    os.environ["REPO_1_NAME"] = "astronomy"
    os.environ["REPO_1_BRANCH"] = "main"

    manager = RepoManager("./test-code")
    manager.load_from_env()

    print(f"配置了 {len(manager.repos)} 个仓库:")
    for repo in manager.repos:
        print(f"  - {repo.name}: {repo.url}")

    results = manager.sync_all()
    # Report skipped repositories too; otherwise a re-run, where everything
    # already exists, prints zero successes and zero failures.
    print(f"\n同步结果: 成功 {len(results['success'])}, 跳过 {len(results['skipped'])}, 失败 {len(results['failed'])}")