"""Client helpers for the upstream streaming API.

Covers the CID -> login -> channel list / EPG request chain (with cache
fallbacks), per-JST-day EPG caching, and HLS (m3u8) playlist helpers used
when proxying streams.
"""
import json
import logging
from datetime import datetime, timedelta, timezone
from operator import itemgetter
from typing import Optional, Dict, Any, List
from urllib.parse import urljoin, urlparse

import httpx

from config import Config
from cache_manager import cache

logger = logging.getLogger(__name__)

# Japan Standard Time (fixed UTC+9); EPG entries are bucketed by JST calendar day.
_JST = timezone(timedelta(hours=9))

# HTTP statuses treated as "access token rejected" -> refresh auth and retry.
_AUTH_ERROR_STATUSES = (401, 403)

# Maximum number of re-auth / CID-refresh retries per request; bounds recursion.
_MAX_AUTH_RETRIES = 2

# Cache key / date pair under which the full multi-channel EPG map is stored.
_ALL_EPG_KEY = '_all_'
_ALL_EPG_DATE = 'full'


def _api_headers() -> Dict[str, str]:
    """Return the headers sent to the channel-list and EPG endpoints."""
    return {
        'Referer': Config.REQUIRED_REFERER,
        'User-Agent': 'Mozilla/5.0',
    }


async def get_cid(force: bool = False) -> str:
    """Return the CID used to build the login URL.

    Args:
        force: Skip the cache lookup and always query the CID endpoint.

    Returns:
        The CID, either from cache or freshly fetched (and then cached).

    Raises:
        ValueError: The response has no ``cid`` and no CID is held by the cache.
        httpx.HTTPError: Network/HTTP failure and no CID is held by the cache.
    """
    if not force:
        cached = cache.get_cid()
        if cached:
            return cached
    try:
        async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
            response = await client.get(Config.get_cid_url())
            response.raise_for_status()
            data = response.json()
        if 'cid' not in data:
            raise ValueError("CID not found in response")
        cid = data['cid']
        cache.set_cid(cid)
        return cid
    except Exception:
        # Best effort: fall back to whatever CID the cache object still holds.
        if cache.cid:
            logger.warning("CID fetch failed; using cached CID", exc_info=True)
            return cache.cid
        raise


async def get_auth(force: bool = False, retry_count: int = 0) -> Dict[str, Any]:
    """Log in and return ``{'access_token', 'vms_host', 'vms_uid'}``.

    Args:
        force: Skip the cache lookup and always perform a login.
        retry_count: Internal recursion depth. A non-zero value also forces a
            fresh CID, because retries are triggered by CID-related login errors.

    Returns:
        The auth dict (``vms_host`` has trailing slashes stripped).

    Raises:
        ValueError: Login was rejected or returned incomplete data, and no
            fallback applies.
        Exception: Any network/parse error when no fallback applies. Only
            the outermost call (``retry_count == 0``) falls back to
            ``cache.auth``.
    """
    if not force:
        cached = cache.get_auth()
        if cached:
            return cached
    try:
        cid = await get_cid(force=retry_count > 0)
        async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
            response = await client.get(Config.get_login_url(cid))
            response.raise_for_status()
            data = response.json()

        if data.get('code') != 'OK':
            # str(...) guards against a null/non-string message from the API.
            error_msg = str(data.get('message') or 'Unknown error')
            # An error mentioning the CID triggers a bounded retry with a fresh CID.
            if 'cid' in error_msg.lower() and retry_count < _MAX_AUTH_RETRIES:
                return await get_auth(force=True, retry_count=retry_count + 1)
            raise ValueError(f"Login failed: {error_msg}")

        # product_config arrives as a JSON-encoded string inside the JSON body.
        product_config = json.loads(data.get('product_config', '{}'))
        auth = {
            'access_token': data['access_token'],
            'vms_host': product_config['vms_host'].rstrip('/'),
            'vms_uid': product_config['vms_uid'],
        }
        if not all(auth.values()):
            raise ValueError("Incomplete auth data")
        cache.set_auth(auth)
        return auth
    except Exception:
        # Nested retries propagate so the outermost call makes the fallback decision.
        if cache.auth and retry_count == 0:
            logger.warning("Login failed; using cached auth", exc_info=True)
            return cache.auth
        raise


async def get_channels(auth: Dict[str, Any], force: bool = False,
                       retry_count: int = 0) -> list:
    """Return the channel list (entries having id, no, name and playpath).

    Args:
        auth: Auth dict from :func:`get_auth` (``vms_uid`` is used).
        force: Skip the cache lookup and always query the API.
        retry_count: Internal counter bounding re-auth retries on 401/403.

    Returns:
        The filtered channel list. On failure, falls back to the channels
        still held by the cache if there are any.

    Raises:
        ValueError: No usable channels were returned and none are cached.
        Exception: Any other failure when no cached channels exist.
    """
    if not force:
        cached = cache.get_channels()
        if cached:
            return cached
    try:
        url = Config.get_list_url(auth['vms_uid'], with_epg=False)
        async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
            response = await client.get(url, headers=_api_headers())
            response.raise_for_status()
            data = response.json()
        channels = [
            ch for ch in data.get('result', [])
            if ch.get('id') and ch.get('no') and ch.get('name') and ch.get('playpath')
        ]
        if not channels:
            raise ValueError("No channels found")
        cache.set_channels(channels)
        return channels
    except Exception as exc:
        # Token rejected: refresh auth and retry, but only a bounded number of
        # times so a permanently rejected token cannot recurse forever.
        if (isinstance(exc, httpx.HTTPStatusError)
                and exc.response.status_code in _AUTH_ERROR_STATUSES
                and retry_count < _MAX_AUTH_RETRIES):
            new_auth = await get_auth(force=True)
            return await get_channels(new_auth, force=True,
                                      retry_count=retry_count + 1)
        if cache.channels:
            logger.warning("Channel fetch failed; using cached channels",
                           exc_info=True)
            return cache.channels
        raise


def _fill_missing_end_times(programs: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep programmes that have a start time and a known (or inferable) end.

    Programmes without ``time`` are dropped. A missing/empty ``time_end`` is
    filled in place with the next programme's ``time``. If that cannot be
    done (last entry, or successor has no start), the programme is dropped.
    """
    kept = []
    for idx, program in enumerate(programs):
        if not program.get('time'):
            continue
        if not program.get('time_end'):
            next_start = programs[idx + 1].get('time') if idx + 1 < len(programs) else None
            if not next_start:
                continue
            program['time_end'] = next_start
        kept.append(program)
    return kept


def _group_by_jst_date(programs: List[Dict[str, Any]]) -> Dict[str, list]:
    """Bucket programmes by the JST calendar day of their ``time`` (Unix seconds)."""
    daily: Dict[str, list] = {}
    for program in programs:
        start = datetime.fromtimestamp(program['time'], tz=_JST)
        daily.setdefault(get_jst_date(start), []).append(program)
    return daily


def _cache_epg_by_day(channel_id: str, programs: List[Dict[str, Any]]) -> Dict[str, list]:
    """Split *programs* by JST day, sort each day by start time, and cache each day.

    Returns:
        Mapping of ``YYYY-MM-DD`` -> sorted programmes (the lists that were cached).
    """
    daily: Dict[str, list] = {}
    for day, day_programs in _group_by_jst_date(programs).items():
        daily[day] = sorted(day_programs, key=itemgetter('time'))
        cache.set_epg(channel_id, day, daily[day])
    return daily


def _decode_record_epg(raw: Any) -> list:
    """Decode a JSON ``record_epg`` string; return [] if empty, malformed, or not a list."""
    if not raw:
        return []
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return []
    return decoded if isinstance(decoded, list) else []


async def fetch_epg(vid: str, date: str, auth: dict, retry_count: int = 0) -> list:
    """Return EPG programmes for channel *vid* on JST day *date*, cache first.

    On a cache miss, the channel's EPG record is fetched. Every JST day it
    contains is cached, sorted by start time, and the requested day's
    programmes are returned. Days with no data are cached as ``[]``.

    Args:
        vid: Channel id.
        date: JST day as ``YYYY-MM-DD``.
        auth: Auth dict from :func:`get_auth` (``vms_uid`` is used).
        retry_count: Internal counter bounding re-auth retries on 401/403.

    Raises:
        httpx.HTTPError / json.JSONDecodeError: Fetch or decode failures
        propagate; there is no cache fallback here.
    """
    cached = cache.get_epg(vid, date)
    if cached is not None:
        return cached

    url = Config.get_epg_url(auth['vms_uid'], vid)
    async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
        response = await client.get(url, headers=_api_headers())
        if response.status_code in _AUTH_ERROR_STATUSES and retry_count < _MAX_AUTH_RETRIES:
            new_auth = await get_auth(force=True)
            return await fetch_epg(vid, date, new_auth, retry_count + 1)
        response.raise_for_status()
        data = response.json()

    records = data.get('result')
    if not records or not records[0].get('record_epg'):
        # Cache the empty result too, so later calls are served from cache.
        cache.set_epg(vid, date, [])
        return []

    programs = _fill_missing_end_times(json.loads(records[0]['record_epg']))
    daily_epg = _cache_epg_by_day(vid, programs)

    if date not in daily_epg:
        cache.set_epg(vid, date, [])
        return []
    # Return a copy so callers cannot mutate the list handed to the cache.
    return list(daily_epg[date])


async def get_all_epg(auth: Dict[str, Any], force: bool = False) -> Dict[str, list]:
    """Return ``{channel_id: programmes}`` for all channels, cache first.

    Each channel's programmes are also cached per JST day. On any failure the
    error is logged and the cached full map (or ``{}``) is returned.

    Args:
        auth: Auth dict from :func:`get_auth` (``vms_uid`` is used).
        force: Skip the cache lookup and always query the API.
    """
    if not force:
        cached = cache.get_epg(_ALL_EPG_KEY, _ALL_EPG_DATE)
        if cached:
            return cached
    try:
        url = Config.get_list_url(auth['vms_uid'], with_epg=True)
        async with httpx.AsyncClient(timeout=Config.TIMEOUT) as client:
            response = await client.get(url, headers=_api_headers())
            response.raise_for_status()
            data = response.json()

        result: Dict[str, list] = {}
        for channel in data.get('result', []):
            channel_id = channel.get('id')
            if not channel_id:
                continue
            programs = _fill_missing_end_times(_decode_record_epg(channel.get('record_epg')))
            _cache_epg_by_day(channel_id, programs)
            result[channel_id] = programs

        cache.set_epg(_ALL_EPG_KEY, _ALL_EPG_DATE, result)
        return result
    except Exception:
        # Deliberate best effort: log, then serve whatever is cached (or nothing).
        logger.exception("Full EPG fetch failed; falling back to cache")
        cached = cache.get_epg(_ALL_EPG_KEY, _ALL_EPG_DATE)
        return cached if cached else {}


def get_jst_date(dt: Optional[datetime] = None) -> str:
    """Format *dt* (default: now) as ``YYYY-MM-DD`` in JST.

    Naive datetimes are interpreted as local time (``datetime.astimezone``
    semantics).
    """
    if dt is None:
        dt = datetime.now(_JST)
    return dt.astimezone(_JST).strftime('%Y-%m-%d')


def rewrite_m3u8(content: str, current_path: str, worker_base: str) -> str:
    """Rewrite each URI line of an m3u8 playlist to go through *worker_base*.

    Args:
        content: Playlist text.
        current_path: Path (optionally with ``?query``) the playlist was
            requested under. Relative URIs resolve against its directory,
            and its query is appended to URIs that have none.
        worker_base: Prefix prepended to every rewritten path.

    Tag (``#...``) and blank lines are kept verbatim. For absolute and
    protocol-relative URLs only path + query are kept (the host is dropped).
    """
    # The query string starts at the FIRST '?'; it may itself contain '?'.
    path_only, _, query_part = current_path.partition('?')
    base_dir = path_only[:path_only.rfind('/') + 1]

    output = []
    for line in content.split('\n'):
        trimmed = line.strip()
        if not trimmed or trimmed.startswith('#'):
            output.append(line)
            continue

        if trimmed.startswith(('http://', 'https://', '//')):
            parsed = urlparse(trimmed)
            target_path = parsed.path
            if parsed.query:
                target_path += f"?{parsed.query}"
        elif trimmed.startswith('/'):
            target_path = trimmed
        else:
            target_path = base_dir + trimmed

        if '?' not in target_path and query_part:
            target_path += f"?{query_part}"
        output.append(worker_base + target_path)
    return '\n'.join(output)


def extract_playlist_url(content: str, base_url: str) -> Optional[str]:
    """Return the first playlist URI in *content*, resolved against *base_url*.

    Absolute ``http(s)://`` URIs are returned as-is. Otherwise the first URI
    whose path (ignoring any ``?query``) ends in ``.m3u8``, case-insensitively,
    is resolved with ``urljoin``. Returns ``None`` if nothing matches.
    """
    for line in content.split('\n'):
        trimmed = line.strip()
        if not trimmed or trimmed.startswith('#'):
            continue
        if trimmed.startswith(('http://', 'https://')):
            return trimmed
        uri_path = trimmed.split('?', 1)[0]
        if uri_path.lower().endswith('.m3u8'):
            return urljoin(base_url, trimmed)
    return None