kai / app.py
Chrunos's picture
Update app.py
3a6e605 verified
import os
import tempfile
import urllib.parse
from datetime import datetime
from pathlib import Path
import logging
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.staticfiles import StaticFiles
import yt_dlp
import cloudscraper
from dotenv import load_dotenv
# Configure root logging at INFO level and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# System temp directory path; not referenced anywhere else in the visible source
tmp_dir = tempfile.gettempdir()
# Public base URL prepended to the /file/<name> links returned by the /hls endpoint
BASE_URL = "https://chrunos-kai.hf.space"
# Load variables from a .env file (if any) before the os.getenv calls below
load_dotenv()
app = FastAPI()
# Fresh temporary directory created at import time; /hls downloads are written
# here and the directory is mounted as static files under /file
global_download_dir = tempfile.mkdtemp()
# Base URLs of the extraction APIs queried by extract_video_info; None when unset
EXTRACT_API = os.getenv("EXTRACT_API")
ALT_API = os.getenv("ALT_API")
def extract_video_info(video_url: str) -> "list | dict":
    """Extract downloadable formats for a video URL using fallback APIs.

    ALT_API is queried first, then EXTRACT_API. The first API that answers
    with HTTP 200 decides the result; a non-200 status or any exception
    (network error, invalid JSON, unexpected payload) moves on to the next
    API. APIs whose base URL is not configured are skipped.

    Args:
        video_url: URL of the video page to resolve. It is percent-encoded
            before being placed in the query string, so characters such as
            ``&`` or ``#`` inside it reach the API intact.

    Returns:
        On success, a list of dicts:
          * when the API response has a ``formats`` key: one entry per format
            that has both a URL and a format id, with keys ``url``,
            ``format_id`` and ``cookies``;
          * otherwise, when the response has a ``url`` key: two entries with
            ``format_id`` "video" and "thumbnail".
        On failure, a dict with a single ``error`` key.
    """
    query = urllib.parse.urlencode({"url": video_url})
    # Skip APIs that are not configured instead of requesting "None?url=..."
    api_bases = [base for base in (ALT_API, EXTRACT_API) if base]

    for api_base in api_bases:
        api_url = f"{api_base}?{query}"
        logger.info(f"Trying API: {api_url}")
        try:
            # The context manager closes the scraper session's connections
            with cloudscraper.create_scraper() as session:
                response = session.get(api_url, timeout=20)

            if response.status_code != 200:
                logger.warning(f"Request failed with status code {response.status_code}, API: {api_url}")
                continue

            json_response = response.json()

            if 'formats' in json_response:
                # Keep only formats that carry both an id and a URL
                result = [
                    {
                        "url": format_item.get('url'),
                        "format_id": format_item.get('format_id'),
                        "cookies": format_item.get('cookies'),
                    }
                    for format_item in json_response['formats']
                    if format_item.get('format_id') and format_item.get('url')
                ]
                title = json_response.get('title')
                logger.info(f"Video title: {title}")

                # For this site only the HLS formats are returned
                if "ornhub.com" in video_url:
                    return [item for item in result if 'hls' in item['format_id']]

                # NOTE(review): a placeholder entry is appended when exactly one
                # format was found; presumably the client expects at least two
                # choices - confirm against the consumer of this endpoint.
                if len(result) == 1:
                    result.append({
                        "format_id": "This is Fake, Don't Choose This One",
                        "url": "none"
                    })
                return result

            if 'url' in json_response:
                # Single-URL response: expose the media URL and its thumbnail
                return [
                    {"url": json_response.get('url'), "format_id": "video"},
                    {"url": json_response.get('thumbnail'), "format_id": "thumbnail"},
                ]

            # A 200 response without usable data ends the search immediately;
            # the remaining API is not tried in this case.
            return {"error": "No formats available"}
        except Exception as e:
            logger.error(f"An error occurred with API {api_url}: {e}")

    return {"error": "Both APIs failed to provide valid results."}
@app.post("/pripper")
async def test_download(request: Request):
    """Extract video information for the URL given in the JSON request body.

    Expects a JSON object with a ``url`` key. Returns whatever
    ``extract_video_info`` returns for that URL, or a dict with an ``error``
    key when the URL is missing or processing fails.
    """
    try:
        data = await request.json()
        video_url = data.get('url')
        if not video_url:
            return {"error": "URL parameter is required"}
        # extract_video_info performs blocking HTTP requests (20 s timeout per
        # API), so it is run in the thread pool to keep the event loop free.
        response = await run_in_threadpool(extract_video_info, video_url)
        return response
    except Exception as e:
        logger.error(f"Error in test_download: {e}")
        return {"error": f"Failed to process request: {str(e)}"}
@app.post("/hls")
async def download_hls_video(request: Request):
    """Download an HLS video with yt-dlp and return a URL to the saved file.

    Expects a JSON object with a ``url`` key. The stream is downloaded into
    ``global_download_dir``, converted to mp4 by a yt-dlp postprocessor, and
    the response is ``{"url": <BASE_URL>/file/<filename>}``. Any failure is
    reported as a dict with an ``error`` key.
    """
    try:
        data = await request.json()
        hls_url = data.get('url')
        if not hls_url:
            return {"error": "URL parameter is required"}

        timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
        # The timestamp only has one-second resolution, so a random suffix is
        # added to keep two requests in the same second from sharing an
        # output file (and from matching each other in the glob fallback).
        file_stem = f"hls_video_{timestamp}_{os.urandom(4).hex()}"

        # Use a fixed prefix instead of %(title)s: for m3u8 inputs a title may
        # contain path-like characters that make the output path invalid.
        output_template = str(Path(global_download_dir) / f'{file_stem}.%(ext)s')

        ydl_opts = {
            'format': 'best',
            'outtmpl': output_template,
            'quiet': True,
            'no_warnings': True,
            'noprogress': True,
            # Force the final container with a postprocessor rather than
            # merge_output_format.
            'postprocessors': [{
                'key': 'FFmpegVideoConvertor',
                'preferedformat': 'mp4',
            }]
        }

        # extract_info(download=True) downloads and also returns the metadata
        # dict, which is used below to locate the resulting file.
        def _download_task():
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                return ydl.extract_info(hls_url, download=True)

        try:
            # yt-dlp is blocking, so run it in the thread pool
            info = await run_in_threadpool(_download_task)
        except Exception as e:
            logger.error(f"yt-dlp download failed: {e}")
            return {"error": f"Download failed: {str(e)}"}

        if not info:
            return {"error": "Could not extract video info"}

        # The final path after postprocessing is usually reported in
        # requested_downloads.
        filepath = None
        requested_downloads = info.get('requested_downloads') or []
        if requested_downloads:
            filepath = requested_downloads[0].get('filepath')

        # Fallback: the original filename recorded by yt-dlp
        if not filepath:
            filepath = info.get('_filename')

        if not filepath or not os.path.exists(filepath):
            # Last resort: look for any file with this request's stem
            downloaded_files = list(Path(global_download_dir).glob(f"{file_stem}.*"))
            if not downloaded_files:
                return {"error": "Download failed - no files generated"}
            filepath = str(downloaded_files[0])

        # Build the public URL from the file name
        filename = os.path.basename(filepath)
        encoded_filename = urllib.parse.quote(filename)
        download_url = f"{BASE_URL}/file/{encoded_filename}"

        logger.info(f"Download successful: {download_url}")
        return {"url": download_url}
    except Exception as e:
        logger.error(f"Error in download_hls_video: {e}")
        return {"error": f"Failed to process request: {str(e)}"}
@app.get("/health")
async def health_check():
    """Report that the service is up and responding."""
    status_payload = {
        "status": "healthy",
        "message": "Service is running",
    }
    return status_payload
# Expose the download directory's files under the /file URL prefix
app.mount(
    "/file",
    StaticFiles(directory=global_download_dir),
    name="downloads",
)

if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )