File size: 10,804 Bytes
8ede856 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 | import base64
import logging
import os
import shutil
import socket
import ssl
import time
import uuid
import zipfile
from pathlib import Path
import aiohttp
import certifi
import psutil
from PIL import Image
from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path
# Module-wide logger; uses the shared "astrbot" logger name rather than __name__.
logger = logging.getLogger("astrbot")
def on_error(func, path, exc_info) -> None:
    """Error handler for ``shutil.rmtree(..., onerror=on_error)``.

    When ``path`` is not writable, its mode is set to ``stat.S_IWUSR``
    (owner-write only) and the failing operation ``func`` is retried on it.
    Otherwise the original exception carried in ``exc_info`` is re-raised.
    """
    import stat

    if os.access(path, os.W_OK):
        # Write permission is not the cause of the failure: propagate it.
        raise exc_info[1]
    os.chmod(path, stat.S_IWUSR)
    func(path)
def remove_dir(file_path: str) -> bool:
    """Recursively delete the directory tree at ``file_path``.

    A path that does not exist counts as already removed. Errors raised
    while deleting are handed to ``on_error``, which either retries or
    re-raises. Returns ``True`` whenever it returns normally.
    """
    if os.path.exists(file_path):
        shutil.rmtree(file_path, onerror=on_error)
    return True
def port_checker(port: int, host: str = "localhost") -> bool:
    """Return True if a TCP (IPv4) connection to ``host:port`` succeeds within 1 s.

    Any exception raised while connecting makes the result False. Examples
    are a refused connection, a timeout, an unresolvable host, or an
    out-of-range port. The probe socket is always closed.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.settimeout(1)
        try:
            probe.connect((host, port))
        except Exception:
            return False
        return True
def save_temp_img(img: Image.Image | bytes) -> str:
    """Write an image to a uniquely named ``.jpg`` file in the AstrBot temp dir.

    Args:
        img: Either a PIL image, which Pillow encodes as JPEG (inferred from
            the ``.jpg`` extension), or raw bytes, which are written verbatim.
            Raw bytes are not re-encoded, so their content may not actually
            be JPEG despite the extension.

    Returns:
        The path of the written file.
    """
    temp_dir = get_astrbot_temp_path()
    # Make sure the target directory exists before writing into it.
    os.makedirs(temp_dir, exist_ok=True)
    # Timestamp plus a random suffix makes file-name collisions unlikely.
    timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    p = os.path.join(temp_dir, f"io_temp_img_{timestamp}.jpg")
    if isinstance(img, Image.Image):
        # The JPEG writer only accepts these modes. Others (RGBA, LA, P, ...)
        # make Pillow raise OSError, so flatten them to RGB first.
        if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            img = img.convert("RGB")
        img.save(p)
    else:
        with open(p, "wb") as f:
            f.write(img)
    return p
async def _fetch_image(session, url, post, post_data, path, **request_kwargs) -> str:
    """Perform the GET/POST for ``download_image_by_url`` and store the body.

    Extra ``request_kwargs`` (e.g. ``ssl=``) are forwarded to the request.
    Returns ``path`` when one is given; otherwise the body is written by
    ``save_temp_img`` and its returned path is used.
    """
    if post:
        request = session.post(url, json=post_data, **request_kwargs)
    else:
        request = session.get(url, **request_kwargs)
    async with request as resp:
        # Read the whole body before touching the destination, so a failed
        # read does not leave a truncated file behind.
        data = await resp.read()
    if not path:
        return save_temp_img(data)
    with open(path, "wb") as f:
        f.write(data)
    return path


async def download_image_by_url(
    url: str,
    post: bool = False,
    post_data: dict | None = None,
    path: str | None = None,
) -> str:
    """Download an image and return the local path it was saved to.

    Args:
        url: Address to fetch.
        post: Send a POST with ``post_data`` as a JSON body instead of a GET.
        post_data: JSON payload used when ``post`` is True.
        path: Destination file. When empty, a temp file is created via
            ``save_temp_img``.

    Returns:
        The path of the saved file.

    The first attempt verifies certificates against certifi's CA bundle. If
    that fails with an SSL/certificate connector error, the request is
    retried once with verification disabled, and a warning is logged.
    Any other exception propagates to the caller.
    """
    try:
        # Verify server certificates against certifi's CA bundle.
        ssl_context = ssl.create_default_context(cafile=certifi.where())
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(
            trust_env=True,
            connector=connector,
        ) as session:
            return await _fetch_image(session, url, post, post_data, path)
    except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
        # Disable SSL verification (fallback used only when certificate
        # verification failed).
        logger.warning(
            f"SSL certificate verification failed for {url}. "
            "Disabling SSL verification (CERT_NONE) as a fallback. "
            "This is insecure and exposes the application to man-in-the-middle attacks. "
            "Please investigate and resolve certificate issues."
        )
        insecure_ctx = ssl.create_default_context()
        insecure_ctx.check_hostname = False
        insecure_ctx.verify_mode = ssl.CERT_NONE
        # Keep trust_env so proxy settings from the environment still apply,
        # matching the first attempt.
        async with aiohttp.ClientSession(trust_env=True) as session:
            return await _fetch_image(
                session, url, post, post_data, path, ssl=insecure_ctx
            )
async def _save_download_response(resp, url: str, path: str, show_progress: bool) -> None:
    """Stream the body of ``resp`` into ``path`` in 8 KiB chunks.

    When ``show_progress`` is True, the file size and a live progress/speed
    line are printed.

    Raises:
        Exception: if the HTTP status is not 200.
    """
    if resp.status != 200:
        raise Exception(f"下载文件失败: {resp.status}")
    # content-length may be absent (e.g. chunked transfer); 0 means "unknown".
    total_size = int(resp.headers.get("content-length", 0))
    downloaded_size = 0
    start_time = time.time()
    if show_progress:
        print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
    with open(path, "wb") as f:
        async for chunk in resp.content.iter_chunked(8192):
            f.write(chunk)
            downloaded_size += len(chunk)
            if show_progress:
                elapsed_time = time.time() - start_time
                # Avoid dividing by a zero (or clock-skewed negative) interval.
                if elapsed_time <= 0:
                    elapsed_time = 1
                speed = downloaded_size / 1024 / elapsed_time  # KB/s
                if total_size > 0:
                    progress = f"{downloaded_size / total_size:.2%}"
                else:
                    # Unknown total size: report the amount received instead.
                    progress = f"{downloaded_size / 1024:.2f} KB"
                print(f"\r下载进度: {progress} 速度: {speed:.2f} KB/s", end="")


async def download_file(url: str, path: str, show_progress: bool = False) -> None:
    """Download ``url`` into the file at ``path``.

    Args:
        url: Address to fetch (GET).
        path: Destination file; it is overwritten.
        show_progress: Print file size, progress and speed to stdout.

    Raises:
        Exception: if the server answers with a status other than 200.

    The first attempt verifies certificates against certifi's CA bundle. On
    an SSL/certificate connector error, the download is retried once with
    verification disabled, and warnings are logged. Both attempts share the
    same 1800 s total timeout and the same status check.
    """
    timeout = aiohttp.ClientTimeout(total=1800)
    try:
        # Verify server certificates against certifi's CA bundle.
        ssl_context = ssl.create_default_context(cafile=certifi.where())
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(
            trust_env=True,
            connector=connector,
        ) as session:
            async with session.get(url, timeout=timeout) as resp:
                await _save_download_response(resp, url, path, show_progress)
    except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
        # Disable SSL verification (fallback used only when certificate
        # verification failed).
        logger.warning(
            "SSL 证书验证失败,已关闭 SSL 验证(不安全,仅用于临时下载)。请检查目标服务器的证书配置。"
        )
        logger.warning(
            f"SSL certificate verification failed for {url}. "
            "Falling back to unverified connection (CERT_NONE). "
            "This is insecure and exposes the application to man-in-the-middle attacks. "
            "Please investigate certificate issues with the remote server."
        )
        insecure_ctx = ssl.create_default_context()
        insecure_ctx.check_hostname = False
        insecure_ctx.verify_mode = ssl.CERT_NONE
        # Keep trust_env so environment proxy settings apply here too.
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(url, ssl=insecure_ctx, timeout=timeout) as resp:
                await _save_download_response(resp, url, path, show_progress)
    if show_progress:
        # Terminate the carriage-return progress line.
        print()
def file_to_base64(file_path: str) -> str:
    """Read ``file_path`` as bytes and return its base64 text prefixed with ``base64://``."""
    with open(file_path, "rb") as handle:
        encoded = base64.b64encode(handle.read())
    return f"base64://{encoded.decode()}"
def get_local_ip_addresses():
    """Return the IPv4 addresses of every local network interface, as a list.

    Addresses come from ``psutil.net_if_addrs()`` in its enumeration order;
    duplicates are kept.
    """
    # Filter on socket.AF_INET (used instead of a psutil constant).
    return [
        addr.address
        for addrs in psutil.net_if_addrs().values()
        for addr in addrs
        if addr.family == socket.AF_INET
    ]
async def get_dashboard_version():
    """Return the WebUI version read from ``<dist>/assets/version``, or None.

    The ``dist`` directory under the user data path takes precedence.
    Manually updated or downloaded dashboards live there. The copy bundled
    at ``<astrbot path>/astrbot/dashboard/dist`` is used only when the
    former does not exist. Returns None when neither directory exists or
    the chosen one has no version file.
    """
    dist = Path(get_astrbot_data_path()) / "dist"
    if not dist.exists():
        bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist"
        if not bundled.exists():
            return None
        dist = bundled
    version_file = dist / "assets" / "version"
    if not version_file.exists():
        return None
    with version_file.open(encoding="utf-8") as fh:
        return fh.read().strip()
async def download_dashboard(
    path: str | None = None,
    extract_path: str = "data",
    latest: bool = True,
    version: str | None = None,
    proxy: str | None = None,
) -> None:
    """Download the AstrBot WebUI ``dist.zip`` and extract it.

    Args:
        path: Where to save the zip; defaults to ``<data path>/dashboard.zip``.
        extract_path: Directory the archive is extracted into.
        latest: Download the latest release instead of ``version``.
        version: Release name. A 40-character value (presumably a commit
            hash; verify with callers) selects the astrbot-release-harbour
            build instead.
        proxy: Optional prefix for GitHub URLs, applied as ``{proxy}/{url}``.

    For release downloads, the registry mirror is tried first. If that
    raises, GitHub releases are used instead. Cancellation and
    KeyboardInterrupt are not intercepted.
    """
    if path is None:
        zip_path = Path(get_astrbot_data_path()).absolute() / "dashboard.zip"
    else:
        zip_path = Path(path).absolute()

    if latest or len(str(version)) != 40:
        ver_name = "latest" if latest else version
        dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
        logger.info(
            f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}",
        )
        try:
            await download_file(
                dashboard_release_url,
                str(zip_path),
                show_progress=True,
            )
        except Exception as e:
            # Catch Exception, not BaseException: asyncio.CancelledError and
            # KeyboardInterrupt must propagate instead of starting a second download.
            logger.warning(
                f"Failed to download WebUI from {dashboard_release_url}: {e!r}; "
                "falling back to GitHub releases."
            )
            if latest:
                dashboard_release_url = "https://github.com/AstrBotDevs/AstrBot/releases/latest/download/dist.zip"
            else:
                dashboard_release_url = f"https://github.com/AstrBotDevs/AstrBot/releases/download/{version}/dist.zip"
            if proxy:
                dashboard_release_url = f"{proxy}/{dashboard_release_url}"
            await download_file(
                dashboard_release_url,
                str(zip_path),
                show_progress=True,
            )
    else:
        url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
        logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
        if proxy:
            url = f"{proxy}/{url}"
        await download_file(url, str(zip_path), show_progress=True)

    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(extract_path)
|