# gate / app.py
# harii66's picture
# Upload 23 files
# b4edbc0 verified
import time
import json
import secrets
import asyncio
import httpx
from pathlib import Path
from datetime import datetime, timedelta
from fastapi import FastAPI, Request, Header, Depends, HTTPException, status
from fastapi.responses import JSONResponse, Response, FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel
from typing import Optional, List
from config import Config
from cache_manager import cache
from user_manager import user_manager, User, AVAILABLE_BADGES
from proxy_handler import (
proxy_media,
proxy_live_stream_direct,
proxy_playback_stream,
get_live_m3u8_url
)
from utils import get_auth, get_channels, get_jst_date, fetch_epg, get_all_epg
# Application instance; title/version/description come from Config and are
# shown in the generated OpenAPI docs.
app = FastAPI(
    title=Config.APP_NAME,
    version=Config.APP_VERSION,
    description=Config.APP_DESCRIPTION
)
# Permissive CORS so browser players on other origins can call the API and
# fetch media. expose_headers lets client JS read the length/range/filename
# headers used for seeking and downloads.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive for an app that also serves admin APIs — confirm intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Content-Length", "Content-Range", "Accept-Ranges", "Content-Disposition"]
)
# Optional gzip compression for responses of at least 1000 bytes.
if Config.ENABLE_GZIP:
    app.add_middleware(GZipMiddleware, minimum_size=1000)
# Serve bundled frontend assets under /static when the directory exists.
static_path = Path(__file__).parent / "static"
if static_path.exists():
    app.mount("/static", StaticFiles(directory=str(static_path)), name="static")
# In-memory admin session store mapping token -> expiry time. Contents are
# lost on restart and are not shared between worker processes.
admin_tokens = {}
# Lifetime of an admin session token.
ADMIN_TOKEN_TTL = timedelta(hours=24)


def create_admin_token() -> str:
    """Issue a new random admin session token valid for ``ADMIN_TOKEN_TTL``.

    Returns:
        The URL-safe token string, which is recorded in ``admin_tokens``.
    """
    token = secrets.token_urlsafe(32)
    admin_tokens[token] = datetime.now() + ADMIN_TOKEN_TTL
    return token


def verify_admin_token(token: str) -> bool:
    """Return True if *token* is a known, unexpired admin session token.

    Side effect: every expired token is purged from ``admin_tokens`` so the
    store cannot grow without bound.
    """
    if not token:
        return False
    now = datetime.now()
    for expired in [t for t, exp in admin_tokens.items() if exp < now]:
        del admin_tokens[expired]
    # Every surviving token expires at or after `now`, so membership alone
    # decides validity; a second clock read/expiry check is unnecessary.
    return token in admin_tokens
def get_admin_token(authorization: Optional[str]) -> Optional[str]:
    """Extract the token from an ``Authorization`` header value.

    Accepts ``"Bearer <token>"`` (the scheme is matched case-insensitively,
    as HTTP authentication schemes are case-insensitive) or a bare token.

    Returns:
        The token string, or None when no header value was supplied.
    """
    if not authorization:
        return None
    scheme, separator, credentials = authorization.partition(" ")
    if separator and scheme.lower() == "bearer":
        return credentials.strip()
    return authorization
def get_current_admin_token(authorization: Optional[str] = Header(None)) -> str:
    """FastAPI dependency that requires a valid admin token.

    Returns:
        The token taken from the ``Authorization`` header.

    Raises:
        HTTPException: 401 when the header is missing/empty or the token is
        unknown or expired.
    """
    token = get_admin_token(authorization)
    failure_detail = None
    if not token:
        failure_detail = "No token provided"
    elif not verify_admin_token(token):
        failure_detail = "Invalid or expired token"
    if failure_detail is not None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=failure_detail,
        )
    return token
# Body of POST /api/verify-password (unified login for the Config admin and
# stored users).
class PasswordVerify(BaseModel):
    username: str
    # Compared against Config.ADMIN_PASSWORD_HASH and passed to
    # user_manager.verify_user().
    password_hash: str


# Body of POST /api/admin/login (Config admin account only).
class AdminLogin(BaseModel):
    username: str
    password_hash: str  # compared against Config.ADMIN_PASSWORD_HASH


# Body of POST /api/admin/users; fields are forwarded to user_manager.create_user().
class CreateUserRequest(BaseModel):
    username: str
    # NOTE(review): presumably generated by user_manager when omitted
    # (create_user returns a plain password) — confirm.
    password: Optional[str] = None
    expires_days: Optional[int] = None  # presumably None = no expiry — confirm in user_manager
    notes: str = ""
    badge: Optional[str] = None
    is_admin: bool = False


# Body of POST /api/admin/users/{username}/extend.
class ExtendExpiryRequest(BaseModel):
    days: int  # passed to user_manager.extend_expiry(); sign is not validated here


# Body of POST /api/admin/users/{username}/badge.
class SetBadgeRequest(BaseModel):
    badge: Optional[str] = None  # presumably None clears the badge — confirm


# Per-user settings shape.
# NOTE(review): not referenced by the endpoints visible in this chunk, and
# playback_history is typed as a dict here while track_user_behavior treats
# the stored playback_history as a list — reconcile.
class UserSettings(BaseModel):
    favorite_channels: Optional[List[str]] = None
    playback_history: Optional[dict] = None
    program_reminders: Optional[List[dict]] = None
    download_concurrency: Optional[int] = None
    batch_download_concurrency: Optional[int] = None
    fab_position: Optional[dict] = None
    other_settings: Optional[dict] = None
@app.middleware("http")
async def protocol_middleware(request: Request, call_next):
    """Apply ``X-Forwarded-Proto/Host/Port`` headers set by a reverse proxy.

    Rewrites ``scope['scheme']`` (and ``scope['server']`` when a forwarded
    host is present) so absolute URLs built from ``request.url`` use the
    public scheme. When a chain of proxies appends comma-separated values,
    only the left-most (client-facing) entry is used.

    NOTE(review): these headers are trusted unconditionally; deploy only
    behind a proxy that overwrites them. Starlette's ``request.url`` prefers
    the Host header over ``scope['server']``, so the host rewrite may not
    affect ``request.url`` — verify.
    """
    def first_value(name: str) -> str:
        # "https, http" -> "https"
        return request.headers.get(name, '').split(',')[0].strip()

    # URL schemes are case-insensitive; normalise so the 'https' check works.
    forwarded_proto = first_value('X-Forwarded-Proto').lower()
    forwarded_host = first_value('X-Forwarded-Host')
    forwarded_port = first_value('X-Forwarded-Port')
    if forwarded_proto:
        request.scope['scheme'] = forwarded_proto
    if forwarded_host:
        port = 443 if forwarded_proto == 'https' else 80
        if forwarded_port:
            try:
                port = int(forwarded_port)
            except ValueError:
                # Malformed port header: keep the scheme-derived default.
                pass
        request.scope['server'] = (forwarded_host, port)
    return await call_next(request)
@app.middleware("http")
async def performance_middleware(request: Request, call_next):
    """Add timing, static-cache and CORS headers to responses.

    - ``X-Response-Time``: milliseconds until the downstream app returned its
      response object (for streaming responses this excludes body transfer).
    - ``Cache-Control: public, max-age=86400`` for ``/static/`` assets.
    - Explicit CORS headers for ``/api/``, ``/live/`` and ``/vod/`` paths.
    """
    # perf_counter is monotonic, so the measurement is unaffected by system
    # clock adjustments (time.time() is not).
    started = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = int((time.perf_counter() - started) * 1000)
    response.headers["X-Response-Time"] = f"{elapsed_ms}ms"
    path = request.url.path
    if path.startswith('/static/'):
        response.headers['Cache-Control'] = 'public, max-age=86400'
    if path.startswith(('/api/', '/live/', '/vod/')):
        response.headers['Access-Control-Allow-Origin'] = '*'
        response.headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS, DELETE'
        response.headers['Access-Control-Allow-Headers'] = 'Authorization, Content-Type, Range'
    return response
def _serve_static_html(filename: str, missing_message: str):
    """Serve ``static/<filename>`` or a JSON placeholder if it is absent.

    Args:
        filename: HTML file inside the ``static`` directory next to this module.
        missing_message: text returned as ``{"message": ...}`` (HTTP 200, as
            before) when the file does not exist.
    """
    html_path = Path(__file__).parent / "static" / filename
    if html_path.exists():
        return FileResponse(html_path)
    return {"message": missing_message}


@app.get("/")
async def root():
    """Serve the single-page frontend entry point."""
    return _serve_static_html("index.html", "Frontend not found")


# Client-side routes of the frontend: each serves the same index.html so that
# deep links keep working on reload.
@app.get("/channels")
async def channels_page():
    return await root()


@app.get("/player")
async def player_page():
    return await root()


@app.get("/epg")
async def epg_page():
    return await root()


@app.get("/cache")
async def cache_page():
    return await root()


@app.get("/api-test")
async def api_test_page():
    return await root()


@app.get("/admin")
async def admin_page():
    """Serve the admin dashboard page."""
    return _serve_static_html("admin.html", "Admin page not found")


@app.get("/admin/login")
async def admin_login_page():
    """Serve the admin login page."""
    return _serve_static_html("admin-login.html", "Admin login page not found")
def _is_config_admin(username: str, password_hash: str) -> bool:
    """Check credentials against the admin account defined in Config.

    Uses ``secrets.compare_digest`` on UTF-8 bytes so the comparison time does
    not reveal how much of the secret matched. Both fields are always
    compared, with no short-circuit, for the same reason. An admin account with
    an unset (None) username or hash can never match.
    """
    expected_username = Config.ADMIN_USERNAME
    expected_hash = Config.ADMIN_PASSWORD_HASH
    if expected_username is None or expected_hash is None:
        return False
    username_ok = secrets.compare_digest(
        str(username).encode("utf-8"), str(expected_username).encode("utf-8")
    )
    hash_ok = secrets.compare_digest(
        str(password_hash).encode("utf-8"), str(expected_hash).encode("utf-8")
    )
    return username_ok and hash_ok


@app.post("/api/verify-password")
async def verify_password(data: PasswordVerify):
    """Unified login for the Config admin account and stored users.

    Returns ``{"success": True, "user": {...}}`` on success (plus
    ``user_data`` for stored users), ``{"success": False, ...}`` for bad
    credentials, or a 500 JSON response on unexpected errors.
    """
    try:
        # Admin account defined in the configuration
        if _is_config_admin(data.username, data.password_hash):
            return {
                "success": True,
                "message": "Admin login successful",
                "user": {
                    "username": data.username,
                    "is_admin": True,  # configuration-file admin
                    "badge": None
                }
            }
        # Users stored by user_manager
        if data.username and user_manager.verify_user(data.username, data.password_hash):
            user = user_manager.get_user(data.username)
            if not user:
                return {"success": False, "message": "User not found"}
            user_data = user_manager.get_user_data(data.username)
            return {
                "success": True,
                "message": "User login successful",
                "user": {
                    "username": data.username,
                    "is_admin": user.is_admin,  # read from the stored user record
                    "badge": user.badge or None
                },
                "user_data": user_data
            }
        return {"success": False, "message": "Invalid username or password"}
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": str(e)},
            status_code=500
        )


@app.get("/api/badges")
async def get_badges():
    """Public list of available user badges."""
    return {
        "success": True,
        "badges": AVAILABLE_BADGES
    }


@app.post("/api/admin/login")
async def admin_login(data: AdminLogin):
    """Exchange Config admin credentials for a session token (401 on mismatch)."""
    try:
        if _is_config_admin(data.username, data.password_hash):
            token = create_admin_token()
            return {
                "success": True,
                "token": token,
                "message": "Login successful"
            }
        return JSONResponse(
            content={"success": False, "message": "Invalid credentials"},
            status_code=401
        )
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": str(e)},
            status_code=500
        )
@app.get("/api/admin/check")
async def admin_check(authorization: Optional[str] = Header(None)):
    """Report whether the Authorization header carries a valid admin token.

    Returns 200 ``{"authenticated": True}`` or 401 ``{"authenticated": False}``.
    """
    candidate = get_admin_token(authorization)
    is_authenticated = bool(candidate) and verify_admin_token(candidate)
    if not is_authenticated:
        return JSONResponse(
            content={"authenticated": False},
            status_code=401
        )
    return {"authenticated": True}
@app.get("/api/admin/badges")
async def admin_get_badges(token: str = Depends(get_current_admin_token)):
    """Admin-only list of available badges."""
    # Building this literal cannot raise, so no error handling is needed.
    return {"success": True, "badges": AVAILABLE_BADGES}


@app.get("/api/admin/stats")
async def admin_stats(token: str = Depends(get_current_admin_token)):
    """Return user_manager statistics, or a 500 JSON error."""
    try:
        return user_manager.get_stats()
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.get("/api/admin/users")
async def admin_list_users(token: str = Depends(get_current_admin_token)):
    """Return every user serialised via ``.dict()``, with a count."""
    try:
        members = user_manager.list_users()
        body = {
            "success": True,
            "count": len(members),
            "users": [member.dict() for member in members],
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )
    return body
@app.post("/api/admin/users")
async def admin_create_user(data: CreateUserRequest, token: str = Depends(get_current_admin_token)):
    """Create a user (admin only) and return it together with its plain password.

    Responds 400 when Config.MAX_USERS is reached or user_manager raises
    ValueError, and 500 on any other error.
    """
    try:
        if len(user_manager.users) >= Config.MAX_USERS:
            return JSONResponse(
                content={"error": f"Maximum {Config.MAX_USERS} users allowed"},
                status_code=400
            )
        new_user, plain_password = user_manager.create_user(
            username=data.username,
            password=data.password,
            expires_days=data.expires_days,
            notes=data.notes,
            badge=data.badge,
            is_admin=data.is_admin
        )
        return {"success": True, "user": new_user.dict(), "password": plain_password}
    except Exception as e:
        # ValueError from user_manager signals bad input; anything else is a server error.
        error_status = 400 if isinstance(e, ValueError) else 500
        return JSONResponse(content={"error": str(e)}, status_code=error_status)
def _user_not_found_response() -> JSONResponse:
    """Standard 404 body shared by the per-user admin actions in this group."""
    return JSONResponse(content={"error": "User not found"}, status_code=404)


@app.delete("/api/admin/users/{username}")
async def admin_delete_user(username: str, token: str = Depends(get_current_admin_token)):
    """Delete a user and their stored settings."""
    if not user_manager.delete_user(username):
        return _user_not_found_response()
    # Also remove the user's settings
    user_manager.delete_user_settings(username)
    return {"success": True, "message": f"User {username} deleted"}


@app.post("/api/admin/users/{username}/activate")
async def admin_activate_user(username: str, token: str = Depends(get_current_admin_token)):
    """Mark a user as active."""
    if not user_manager.activate_user(username):
        return _user_not_found_response()
    return {"success": True, "message": f"User {username} activated"}


@app.post("/api/admin/users/{username}/deactivate")
async def admin_deactivate_user(username: str, token: str = Depends(get_current_admin_token)):
    """Mark a user as inactive."""
    if not user_manager.deactivate_user(username):
        return _user_not_found_response()
    return {"success": True, "message": f"User {username} deactivated"}


@app.post("/api/admin/users/{username}/extend")
async def admin_extend_expiry(username: str, data: ExtendExpiryRequest, token: str = Depends(get_current_admin_token)):
    """Extend a user's expiry by ``data.days`` (value passed through unvalidated)."""
    if not user_manager.extend_expiry(username, data.days):
        return _user_not_found_response()
    return {
        "success": True,
        "message": f"Extended {username} expiry by {data.days} days"
    }
@app.post("/api/admin/users/{username}/badge")
async def admin_set_badge(username: str, data: SetBadgeRequest, token: str = Depends(get_current_admin_token)):
    """Set a user's badge.

    Responds 404 if the user is unknown, 400 when user_manager raises
    ValueError (presumably an unknown badge), and 500 on other errors.
    """
    try:
        updated = user_manager.set_badge(username, data.badge)
    except ValueError as e:
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
    if not updated:
        return JSONResponse(content={"error": "User not found"}, status_code=404)
    return {"success": True, "message": f"Badge updated for {username}"}
# ==================== User settings API ====================
@app.get("/api/user/{username}/settings")
async def get_user_settings(username: str):
    """Return the stored settings of *username*, logging the request to stdout.

    NOTE(review): no authentication is performed, so any client can read any
    user's settings — confirm this is acceptable.
    """
    separator = "=" * 80
    print("\n" + separator)
    print(f"📥 [API] 收到读取请求")
    print(f" URL: /api/user/{username}/settings")
    print(f" 用户名: {username}")
    print(separator)
    try:
        settings = user_manager.get_user_settings(username)
        print(f"📤 [API] 返回数据: {list(settings.keys())}")
        print(separator + "\n")
        return {"success": True, "settings": settings}
    except Exception as e:
        import traceback
        print(f"❌ [API] 异常: {e}")
        traceback.print_exc()
        print(separator + "\n")
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )
# ==================== User data sync endpoint (internal use) ====================
class UserDataSync(BaseModel):
    username: str  # target user
    data: dict  # fields passed to user_manager.update_user_data()


@app.post("/api/user/data/sync")
async def sync_user_data(payload: UserDataSync):
    """Sync user data to Redis (internal endpoint).

    NOTE(review): although described as internal, the route enforces no
    authentication, so any client can overwrite any user's data.
    """
    name = payload.username
    print(f"📡 [SYNC] 收到用户数据同步请求: {name}")
    print(f" 数据字段: {list(payload.data.keys())}")
    try:
        if not user_manager.update_user_data(name, payload.data):
            print(f"❌ [SYNC] 用户 {name} 不存在")
            return JSONResponse(
                content={"success": False, "error": "用户不存在"},
                status_code=404
            )
        print(f"✅ [SYNC] 用户 {name} 数据同步成功")
        return {"success": True, "message": "数据已实时同步到Redis"}
    except Exception as e:
        import traceback
        print(f"❌ [SYNC] 同步失败: {e}")
        traceback.print_exc()
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )
# ==================== User behaviour tracking API ====================
class UserBehaviorLog(BaseModel):
    username: str
    action: str  # 'play', 'download', 'favorite', 'search', 'setting_change', etc.
    data: dict  # action-specific payload


# Maximum number of playback-history entries kept per user.
_MAX_PLAYBACK_HISTORY = 100
# Setting keys that a 'setting_change' event is allowed to update.
_TRACKABLE_SETTINGS = ('download_concurrency', 'batch_download_concurrency', 'fab_position')


def _build_behavior_update(action: str, data: dict, user_data: dict) -> dict:
    """Translate one behaviour event into the user-data fields to persist.

    Returns an empty dict when the event carries nothing storable: an unknown
    action, a missing expected field, or stored history of an unexpected type.
    Does not mutate *user_data*.
    """
    update = {}
    if action == 'play':
        history = user_data.get('playback_history')
        if history is None:
            history = []
        if isinstance(history, list):
            entry = {
                'timestamp': datetime.now().isoformat(),
                'channel_id': data.get('channel_id'),
                'channel_name': data.get('channel_name'),
                'duration': data.get('duration', 0)
            }
            # Newest first, capped at the most recent entries.
            update['playback_history'] = ([entry] + history)[:_MAX_PLAYBACK_HISTORY]
        # Otherwise the stored history has an unexpected shape; leave it
        # untouched rather than overwrite it (the event is rejected as invalid).
    elif action == 'favorite':
        # Only overwrite when the list was actually sent; a missing key used
        # to wipe all favourites.
        if 'favorite_channels' in data:
            update['favorite_channels'] = data['favorite_channels']
    elif action == 'setting_change':
        update.update({key: value for key, value in data.items() if key in _TRACKABLE_SETTINGS})
    elif action == 'reminder':
        # Same guard as 'favorite': never clear reminders implicitly.
        if 'program_reminders' in data:
            update['program_reminders'] = data['program_reminders']
    return update


@app.post("/api/user/behavior/track")
async def track_user_behavior(payload: UserBehaviorLog):
    """Record a user behaviour event by updating the user's stored data.

    Responds 404 when the user is unknown, 400 when the event yields nothing
    to store, and 500 on unexpected errors.
    """
    print(f"📊 [BEHAVIOR] 跟踪用户行为: {payload.username} - {payload.action}")
    try:
        user_data = user_manager.get_user_data(payload.username)
        if not user_data:
            return JSONResponse(
                content={"success": False, "error": "用户不存在"},
                status_code=404
            )
        update_data = _build_behavior_update(payload.action, payload.data, user_data)
        if not update_data:
            return JSONResponse(
                content={"success": False, "error": "无效的行为数据"},
                status_code=400
            )
        if not user_manager.update_user_data(payload.username, update_data):
            # A failed update is reported as "user not found", matching
            # sync_user_data, instead of being mislabelled as invalid data.
            return JSONResponse(
                content={"success": False, "error": "用户不存在"},
                status_code=404
            )
        print(f"✅ [BEHAVIOR] 用户 {payload.username} 行为数据已实时保存")
        return {
            "success": True,
            "message": f"用户行为 '{payload.action}' 已实时保存到Redis"
        }
    except Exception as e:
        print(f"❌ [BEHAVIOR] 行为跟踪失败: {e}")
        import traceback
        traceback.print_exc()
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )
@app.get("/health")
async def health_check():
    """Report service status, configuration validity, cache stats and features."""
    stats = cache.get_stats()
    is_valid, missing = Config.validate()
    user_count = len(user_manager.users)
    storage_type = stats['storage_type']

    cache_summary = {
        key: stats[key]
        for key in ('storage_type', 'cid', 'auth', 'channels', 'streams', 'epg')
    }
    cache_summary['epg_detail'] = stats.get('epg_detail')

    features = {
        "streaming": True,
        "download": True,
        "live_recording": True,
        "recording_mode": "Frontend Sequential Recording",
        "user_management": True,
        "admin_features": True,
        "unified_login": True,
        "cache_persistence": storage_type in ['redis', 'disk'],
        "user_settings_sync": True,
        "auto_refresh": {
            "cid": "1 day (auto refresh on expire)",
            "auth": "3 hours (auto refresh on expire or 401/403)",
            "storage": storage_type.upper()
        }
    }

    return {
        "name": Config.APP_NAME,
        "version": Config.APP_VERSION,
        "description": Config.APP_DESCRIPTION,
        "status": "running" if is_valid else "configuration_error",
        "config_valid": is_valid,
        "missing_config": [] if is_valid else missing,
        "password_protected": user_count > 0,
        "total_users": user_count,
        "cache": cache_summary,
        "features": features
    }
@app.get("/api/refresh")
async def refresh_cache(type: str = "all"):
    """Clear a cache bucket; for 'auth'/'all' and 'cid', re-fetch it immediately.

    NOTE(review): unauthenticated GET with side effects — any client can
    flush the caches.
    """
    cache.clear_cache(type)
    label = type.capitalize()
    if type in ('auth', 'all'):
        try:
            await get_auth(force=True)
        except Exception as e:
            message = f"{label} cache cleared, but refresh failed: {str(e)}"
        else:
            message = f"{label} cache cleared and refreshed"
    elif type == 'cid':
        try:
            from utils import get_cid
            await get_cid(force=True)
        except Exception as e:
            message = f"CID cache cleared, but refresh failed: {str(e)}"
        else:
            message = "CID cache cleared and refreshed"
    else:
        message = f"{label} cache cleared"
    return {"success": True, "message": message}
@app.get("/api/list")
async def list_channels(request: Request):
    """List channels, each with a ``playUrl`` pointing at this gateway's live endpoint."""
    try:
        auth = await get_auth()
        channels = await get_channels(auth)
        gateway_base = f"{request.url.scheme}://{request.url.netloc}"
        rewritten = []
        for ch in channels:
            entry = dict(ch)
            entry["playUrl"] = f"{gateway_base}/api/live/{ch['no']}"
            rewritten.append(entry)
        return {
            "success": True,
            "count": len(rewritten),
            "channels": rewritten,
            "cached": cache.get_channels() is not None
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )
@app.get("/api/epg")
async def get_epg(vid: str, date: str):
    """Return one channel's EPG for one day, preferring the cache.

    Args:
        vid: channel id passed to ``fetch_epg``.
        date: day key passed to ``fetch_epg`` / ``cache.get_epg``.

    The ``cached`` field reports whether the entry was cached *before* this
    request. It used to be evaluated after ``fetch_epg``, which presumably
    stores its result, so it was effectively always true.
    """
    try:
        if not vid or not date:
            return JSONResponse(
                content={"success": False, "error": "Missing vid or date"},
                status_code=400
            )
        was_cached = cache.get_epg(vid, date) is not None
        auth = await get_auth()
        # fetch_epg handles the cache lookup/store itself
        epg_data = await fetch_epg(vid, date, auth)
        return {
            "success": True,
            "vid": vid,
            "date": date,
            "count": len(epg_data),
            "epg": epg_data,
            "cached": was_cached
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )
@app.get("/api/epg/all")
async def get_all_epg_data():
    """Return the EPG of all channels, preferring the full-EPG cache.

    ``cached`` reports whether the full EPG was cached *before* this request.
    Previously it was read after ``get_all_epg`` had populated the cache.
    """
    try:
        was_cached = cache.get_epg('_all_', 'full') is not None
        auth = await get_auth()
        # get_all_epg handles the cache lookup/store itself
        all_epg = await get_all_epg(auth, force=False)
        return {
            "success": True,
            "total_channels": len(all_epg),
            "total_programs": sum(len(programs) for programs in all_epg.values()),
            "data": all_epg,
            "cached": was_cached
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )
# Strong references to running background tasks. asyncio keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it completes.
_background_tasks = set()


def _schedule_background_epg_fetch(auth: dict) -> None:
    """Start a background full-EPG fetch unless one is already running.

    This stops every poorly-cached search from launching its own full fetch.
    """
    if _background_tasks:
        return
    task = asyncio.create_task(background_fetch_all_epg(auth))
    _background_tasks.add(task)
    # Release the reference when done so a later search can schedule again.
    task.add_done_callback(_background_tasks.discard)


def _epg_search_hit(channel_id, channel_info: dict, program: dict, date_str: str) -> dict:
    """Shape one search result entry."""
    return {
        'channel_id': channel_id,
        'channel_name': channel_info['name'],
        'channel_no': channel_info['no'],
        'program': program,
        'date': date_str
    }


@app.get("/api/epg/search")
async def search_epg(keyword: str, days: int = 30):
    """Search program titles over today and the previous *days* days.

    Matching is a case-insensitive substring test on ``title`` (or ``name``).
    The full-EPG cache is scanned when present. Otherwise only cached per-day
    entries are searched, and a background full fetch is scheduled when
    coverage is poor so later searches are faster. Results are sorted newest
    first.
    """
    try:
        if not keyword:
            return JSONResponse(
                content={"success": False, "error": "Missing keyword"},
                status_code=400
            )
        auth = await get_auth()
        channels_list = await get_channels(auth)
        channel_map = {ch['id']: ch for ch in channels_list}
        now = datetime.now()
        # Newest first for per-day lookups; the set gives O(1) membership tests.
        date_list = [get_jst_date(now - timedelta(days=i)) for i in range(days + 1)]
        date_set = set(date_list)
        results = []
        keyword_lower = keyword.lower()
        cache_hits = 0
        cache_misses = 0

        full_cache = cache.get_epg('_all_', 'full')
        if full_cache:
            # Full cache available: scan it directly (fastest path). In this
            # branch cache_hits counts matching programs.
            for channel_id, programs in full_cache.items():
                channel_info = channel_map.get(channel_id)
                if not channel_info:
                    continue
                for program in programs:
                    # A missing/None start time falls back to 0 instead of crashing.
                    program_time = program.get('time') or 0
                    program_date = get_jst_date(datetime.fromtimestamp(program_time))
                    if program_date not in date_set:
                        continue
                    title = program.get('title') or program.get('name') or ''
                    if keyword_lower in title.lower():
                        results.append(_epg_search_hit(channel_id, channel_info, program, program_date))
                        cache_hits += 1
        else:
            # No full cache: search whatever per-day entries are cached and
            # count (channel, day) coverage.
            for channel_id, channel_info in channel_map.items():
                for date_str in date_list:
                    cached_epg = cache.get_epg(channel_id, date_str)
                    if cached_epg is None:
                        cache_misses += 1
                        continue
                    cache_hits += 1
                    for program in cached_epg:
                        title = program.get('title') or program.get('name') or ''
                        if keyword_lower in title.lower():
                            results.append(_epg_search_hit(channel_id, channel_info, program, date_str))
            # Poor coverage: warm the full cache in the background.
            if cache_hits == 0 or cache_misses > cache_hits:
                _schedule_background_epg_fetch(auth)

        # Newest first; programs without a start time sort last.
        results.sort(key=lambda hit: hit['program'].get('time') or 0, reverse=True)
        total_lookups = cache_hits + cache_misses
        hit_rate = cache_hits * 100 // total_lookups if total_lookups > 0 else 0
        return {
            "success": True,
            "keyword": keyword,
            "days": days,
            "total": len(results),
            "results": results,
            "cache_stats": {
                "hits": cache_hits,
                "misses": cache_misses,
                "strategy": "full_cache" if full_cache else "partial_cache",
                "hit_rate": f"{hit_rate}%"
            },
            "message": "后台正在缓存数据,下次搜索会更快" if not full_cache and cache_misses > 0 else None
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


async def background_fetch_all_epg(auth: dict):
    """Background task: fetch and cache the full EPG via ``get_all_epg``.

    Failures must not affect the already-answered search, so they are not
    raised, but they are logged instead of silently discarded.
    """
    try:
        await get_all_epg(auth, force=False)
    except Exception as e:
        print(f"⚠️ [EPG] background full EPG fetch failed: {e}")
@app.get("/api/live/{chid}")
async def live_stream_info(chid: str, request: Request):
    """Describe a live channel: metadata, proxied m3u8 URL and upstream URL.

    ``chid`` is matched against the channel number (``ch['no']``).
    """
    try:
        auth = await get_auth()
        channels = await get_channels(auth)
        channel = None
        for candidate in channels:
            if str(candidate['no']) == chid:
                channel = candidate
                break
        if not channel:
            return JSONResponse(
                content={"success": False, "error": f"Channel {chid} not found"},
                status_code=404
            )
        scheme = request.url.scheme
        gateway_base = f"{scheme}://{request.url.netloc}"
        upstream_m3u8 = await get_live_m3u8_url(chid, auth)
        channel_summary = {key: channel[key] for key in ('id', 'no', 'name')}
        return {
            "success": True,
            "channel": channel_summary,
            "stream": {
                "m3u8": f"{gateway_base}/stream/live/{chid}.m3u8",
                "direct": upstream_m3u8
            },
            "info": {
                "protocol": scheme,
                "cached": cache.get_stream(f"live_{chid}") is not None
            }
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.get("/stream/live/{chid}.m3u8")
async def live_stream_m3u8(chid: str, request: Request):
    """Proxy the live playlist for channel *chid* through proxy_handler."""
    response = await proxy_live_stream_direct(chid, request)
    return response
@app.get("/api/playback/{path:path}")
async def playback_stream_info(path: str, request: Request):
    """Describe a playback (VOD) stream and its proxied m3u8 URL.

    A bare id such as ``abc`` is mapped to ``query/abc``. Multi-segment paths
    are used as given.
    """
    try:
        # Awaited so auth failures surface as a 500 here; the value itself is
        # not used (presumably this also warms the auth cache).
        await get_auth()
        scheme = request.url.scheme
        gateway_base = f"{scheme}://{request.url.netloc}"
        # strip('/') already removes every leading slash.
        clean_path = path.strip('/')
        needs_query_prefix = not clean_path.startswith('query/') and '/' not in clean_path
        if needs_query_prefix:
            clean_path = f"query/{clean_path}"
        return {
            "success": True,
            "playback": {
                "path": f"/{clean_path}",
                "m3u8": f"{gateway_base}/stream/playback/{clean_path}.m3u8",
                "original_path": path
            },
            "info": {"protocol": scheme, "type": "playback"}
        }
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )


@app.get("/stream/playback/{path:path}.m3u8")
async def playback_stream_m3u8(path: str, request: Request):
    """Proxy a playback playlist through proxy_handler."""
    response = await proxy_playback_stream(path, request)
    return response
# Days of EPG (counting back from today, JST) scanned to identify a recording.
_PLAYBACK_EPG_LOOKBACK_DAYS = 30
# TS segments fetched concurrently per batch; also bounds buffered memory.
_DOWNLOAD_BATCH_SIZE = 10
# Maximum length of each sanitized filename component.
_FILENAME_PART_MAX_LENGTH = 150


def _normalize_playback_path(raw: str) -> str:
    """Reduce a playback path to its bare id.

    Strips whitespace, one leading '/', one leading 'query/' and a trailing
    '.m3u8'. The suffix is 5 characters; the previous ``[:-6]`` slice also
    removed the id's last character, so ids differing only in that character
    collided.
    """
    cleaned = raw.strip()
    if cleaned.startswith('/'):
        cleaned = cleaned[1:]
    if cleaned.startswith('query/'):
        cleaned = cleaned[len('query/'):]
    if cleaned.endswith('.m3u8'):
        cleaned = cleaned[:-len('.m3u8')]
    return cleaned


def _sanitize_filename_part(text) -> str:
    """Make *text* safe as one filename component.

    Replaces characters forbidden on common filesystems with '_', removes
    control characters, collapses runs of '_', and truncates to
    ``_FILENAME_PART_MAX_LENGTH``, preferring to cut at a 【】 bracket. Returns
    "unknown" when nothing remains.
    """
    import re
    text = str(text).strip()
    cleaned = re.sub(r'[<>:"/\\|?*]', '_', text)
    cleaned = re.sub(r'[\x00-\x1f\x7f-\x9f]', '', cleaned)
    cleaned = re.sub(r'_+', '_', cleaned)
    cleaned = cleaned.strip('_').strip()
    if len(cleaned) > _FILENAME_PART_MAX_LENGTH:
        head = cleaned[:_FILENAME_PART_MAX_LENGTH]
        if '】' in head:
            cleaned = head[:head.rfind('】') + 1]
        elif '【' in head:
            cleaned = head[:head.rfind('【')]
        else:
            cleaned = head
    return cleaned if cleaned else "unknown"


async def _find_program(channel_id, target_path: str, auth: dict, now_jst: datetime):
    """Find the EPG program whose normalized path equals *target_path*.

    Scans today and previous days (up to ``_PLAYBACK_EPG_LOOKBACK_DAYS``).
    Returns ``(title, start_datetime)`` with the start in ``now_jst``'s
    timezone, or ``(None, None)`` if nothing matches. Any error while
    processing a day is ignored and the scan moves on (best effort).
    """
    for days_ago in range(_PLAYBACK_EPG_LOOKBACK_DAYS):
        check_date = (now_jst - timedelta(days=days_ago)).strftime('%Y-%m-%d')
        try:
            epg_list = await fetch_epg(channel_id, check_date, auth)
            for prog in epg_list or []:
                if prog.get('path') and _normalize_playback_path(prog['path']) == target_path:
                    title = prog.get('title') or prog.get('name') or 'Unknown'
                    return title, datetime.fromtimestamp(prog['time'], tz=now_jst.tzinfo)
        except Exception:
            continue
    return None, None


async def _fetch_playlist_text(client, url: str, headers: dict, error_prefix: str) -> str:
    """GET *url* and return its text; raise ``Exception`` on a non-200 status."""
    resp = await client.get(url, headers=headers)
    if resp.status_code != 200:
        raise Exception(f"{error_prefix}: HTTP {resp.status_code}")
    return resp.text


def _segment_urls(playlist_content: str, playlist_url: str) -> list:
    """Return absolute URLs of the media segments listed in an M3U8 playlist.

    ``urljoin`` resolves relative, root-relative ("/x.ts") and absolute URIs
    against the playlist URL. Hand-joining with the last '/' broke on
    root-relative entries and on query strings containing '/'.
    """
    from urllib.parse import urljoin
    urls = []
    for line in playlist_content.splitlines():
        line = line.strip()
        if line and not line.startswith('#'):
            urls.append(urljoin(playlist_url, line))
    return urls


async def _stream_segments(ts_urls: list, headers: dict):
    """Yield segment bytes in playlist order, fetching batches concurrently.

    Each batch is yielded as soon as it completes, so memory is bounded by
    one batch. Previously the whole recording was buffered before the first
    byte was sent. Failed or non-200 segments are skipped, as before.
    """
    async with httpx.AsyncClient(
        timeout=60.0,
        limits=httpx.Limits(max_keepalive_connections=20, max_connections=30)
    ) as client:
        for start in range(0, len(ts_urls), _DOWNLOAD_BATCH_SIZE):
            batch = ts_urls[start:start + _DOWNLOAD_BATCH_SIZE]
            responses = await asyncio.gather(
                *(client.get(url, headers=headers, timeout=60.0) for url in batch),
                return_exceptions=True
            )
            for resp in responses:
                if isinstance(resp, BaseException):
                    continue
                if resp.status_code == 200 and resp.content:
                    yield resp.content


@app.get("/api/download/playback/")
async def download_playback_by_path(
    request: Request,
    path: str,
    channel: str
):
    """Download a playback recording as one concatenated ``.ts`` file.

    Args:
        path: playback path, with or without a leading '/', 'query/' prefix or
            '.m3u8' suffix.
        channel: channel number (``ch['no']``) the recording belongs to.

    The filename is ``<YYYYmmdd_HHMM>_<channel>_<program>.ts``, using recent
    EPG data to find the program title and start time. It falls back to the
    current JST time and ``Playback_<channel>``. Errors before streaming
    starts yield a 500 JSON response. Once streaming has begun, errors can
    only truncate the body.
    """
    from datetime import timezone
    from urllib.parse import quote
    from utils import extract_playlist_url
    try:
        auth = await get_auth()
        channels = await get_channels(auth)
        target_channel = next((ch for ch in channels if str(ch['no']) == str(channel)), None)
        if not target_channel:
            raise ValueError(f"频道 {channel} 不存在")

        clean_path = _normalize_playback_path(path)
        now_jst = datetime.now(timezone(timedelta(hours=9)))
        program_title, program_time = await _find_program(
            target_channel['id'], clean_path, auth, now_jst
        )
        if not program_time:
            program_time = now_jst
            program_title = f"Playback_{target_channel['name']}"

        time_str = program_time.strftime('%Y%m%d_%H%M')
        channel_name = _sanitize_filename_part(target_channel['name'])
        program_name = _sanitize_filename_part(program_title)
        filename = f"{time_str}_{channel_name}_{program_name}.ts"

        # Upstream path always lives under 'query/'. Any '.m3u8' suffix was
        # stripped during normalisation because one is appended below.
        playback_path = f"query/{clean_path}"
        vod_host = Config.UPSTREAM_HOSTS['vod']
        access_token = quote(auth['access_token'])
        upstream_m3u8 = f"{vod_host}/{playback_path}.m3u8?type=vod&__cross_domain_user={access_token}"
        headers = {
            'Referer': Config.REQUIRED_REFERER,
            'User-Agent': 'Mozilla/5.0'
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            m3u8_content = await _fetch_playlist_text(client, upstream_m3u8, headers, "M3U8获取失败")
            # A master playlist may point to a media playlist; follow it once.
            playlist_url = extract_playlist_url(m3u8_content, upstream_m3u8)
            if not playlist_url or playlist_url == upstream_m3u8:
                playlist_content = m3u8_content
                playlist_url = upstream_m3u8
            else:
                playlist_content = await _fetch_playlist_text(client, playlist_url, headers, "播放列表获取失败")

        ts_urls = _segment_urls(playlist_content, playlist_url)
        if not ts_urls:
            raise Exception("未找到TS分段")

        encoded_filename = quote(filename)
        return StreamingResponse(
            _stream_segments(ts_urls, headers),
            media_type="video/mp2t",
            headers={
                "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
                "Cache-Control": "no-cache",
            }
        )
    except Exception as e:
        return JSONResponse(
            content={"success": False, "error": str(e)},
            status_code=500
        )
@app.options("/live/{path:path}")
@app.options("/vod/{path:path}")
@app.options("/query/{path:path}")
@app.options("/stream/{path:path}")
@app.options("/api/{path:path}")
async def options_handler():
    """Answer CORS preflight requests with permissive headers cacheable for an hour."""
    preflight_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, POST, OPTIONS, DELETE',
        'Access-Control-Allow-Headers': 'Authorization, Content-Type, Range',
        'Access-Control-Max-Age': '3600'
    }
    return Response(status_code=200, headers=preflight_headers)
@app.get("/live/{path:path}")
async def proxy_live_media(path: str, request: Request):
    """Proxy a live media resource through proxy_handler.proxy_media."""
    upstream_path = f"/live/{path}"
    return await proxy_media(request, upstream_path)


@app.get("/vod/{path:path}")
async def proxy_vod_media(path: str, request: Request):
    """Proxy a VOD media resource through proxy_handler.proxy_media."""
    upstream_path = f"/vod/{path}"
    return await proxy_media(request, upstream_path)


@app.get("/query/{path:path}")
async def proxy_query_media(path: str, request: Request):
    """Proxy a playback query resource through proxy_handler.proxy_media."""
    upstream_path = f"/query/{path}"
    return await proxy_media(request, upstream_path)
@app.exception_handler(404)
async def not_found_handler(request: Request, exc):
    """JSON 404 body that echoes the requested path."""
    body = {"error": "Not Found", "path": request.url.path}
    return JSONResponse(content=body, status_code=404)


@app.exception_handler(500)
async def server_error_handler(request: Request, exc):
    """Generic JSON 500 body that hides exception details from clients."""
    body = {"error": "Internal Server Error", "detail": "An error occurred"}
    return JSONResponse(content=body, status_code=500)
@app.on_event("startup")
async def startup_event():
    """Print a startup banner with storage/config status and pre-warm caches."""
    rule = "=" * 60
    print(rule)
    print("🚀 Media Gateway 启动")
    print(rule)

    # Cache storage status
    stats = cache.get_stats()
    storage = stats['storage_type']
    print(f"📦 存储类型: {storage.upper()}")
    if storage == 'redis':
        print(" ✅ Redis 持久化已启用")
    elif storage == 'disk':
        print(f" ✅ 磁盘缓存已启用: {cache.cache_dir}")
        print(f" 📊 EPG 缓存: {stats.get('epg', 0)} 条")
    else:
        print(" ⚠️ 仅使用内存缓存(重启后丢失)")

    # User storage backend
    print("👥 用户数据: Redis 持久化" if user_manager.redis else "👥 用户数据: 内存存储")

    # Configuration validation
    is_valid, missing = Config.validate()
    print("✅ 配置验证通过" if is_valid else f"⚠️ 缺少配置: {', '.join(missing)}")

    # Optional cache warm-up; failures are reported but do not block startup.
    try:
        print("🔄 预加载数据...")
        from utils import get_cid
        await get_cid()
        auth = await get_auth()
        channels = await get_channels(auth)
        print(f" ✅ 频道列表: {len(channels)} 个")
    except Exception as e:
        print(f" ⚠️ 预加载失败: {e}")

    print(rule)
    print("✅ 启动完成!")
    print(rule)
@app.on_event("shutdown")
async def shutdown_event():
    """Flush the disk cache and print a shutdown banner."""
    rule = "=" * 60
    print("\n" + rule)
    print("🛑 Media Gateway 关闭中...")
    print(rule)

    # Persist the cache when it is disk-backed
    if cache.storage_type == 'disk':
        cache._save_to_disk(force=True)
        print(f"💾 磁盘缓存已保存 ({len(cache.epg)} 条 EPG)")

    # NOTE(review): nothing is written here for non-Redis user data; the line
    # below only reports the user count. Confirm whether user_manager persists
    # elsewhere, otherwise this "saved" message is misleading.
    uses_local_users = not user_manager.redis and hasattr(user_manager, 'users')
    if uses_local_users:
        print(f"💾 用户数据已保存 ({len(user_manager.users)} 个用户)")

    print("✅ 关闭完成")
    print(rule)
if __name__ == "__main__":
    # Local entry point: serve on all interfaces, port 7860, errors-only logging.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "error"}
    uvicorn.run(app, **server_options)