liumaolin commited on
Commit ·
8064340
1
Parent(s): 5357d86
feat(api): enhance task file management and download endpoints
Browse files- Update `/tasks/{task_id}/outputs` to include inference configuration details
- Replace `/tasks/{task_id}/outputs/{filename}` with `/tasks/{task_id}/outputs/{file_type}/{filename}` for flexible file downloads
- Support downloading inference outputs, reference audio, and model files
- Add `FileType` validation and path security checks in `TaskService`
- Implement `download_file` method to handle various file types securely
api_server/app/api/v1/endpoints/tasks.py
CHANGED
|
@@ -10,13 +10,13 @@ API 列表:
|
|
| 10 |
- DELETE /tasks/{task_id} 取消任务
|
| 11 |
- GET /tasks/{task_id}/progress SSE 进度订阅
|
| 12 |
- GET /tasks/{task_id}/outputs 获取推理输出列表
|
| 13 |
-
- GET /tasks/{task_id}/outputs/{filename} 下载
|
| 14 |
"""
|
| 15 |
|
| 16 |
import json
|
| 17 |
-
from typing import Optional
|
| 18 |
|
| 19 |
-
from fastapi import APIRouter, Depends, HTTPException, Query
|
| 20 |
from fastapi.responses import StreamingResponse, Response
|
| 21 |
|
| 22 |
from ....models.schemas.task import (
|
|
@@ -238,18 +238,28 @@ async def subscribe_progress(
|
|
| 238 |
response_model=InferenceOutputsResponse,
|
| 239 |
summary="获取推理输出列表",
|
| 240 |
description="""
|
| 241 |
-
获取任务的推理输出文件列表。
|
| 242 |
|
| 243 |
训练任务完成后,推理阶段会生成测试音频文件。此端点返回所有生成的音频文件元信息,
|
| 244 |
-
包括文件名、使用的模型、文件大小等。
|
| 245 |
|
| 246 |
-
**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
- `filename`: 文件名
|
| 248 |
- `gpt_model`: 使用的 GPT 模型名称
|
| 249 |
- `sovits_model`: 使用的 SoVITS 模型名称
|
| 250 |
-
- `
|
|
|
|
|
|
|
| 251 |
- `size_bytes`: 文件大小(字节)
|
| 252 |
- `created_at`: 创建时间
|
|
|
|
|
|
|
|
|
|
| 253 |
""",
|
| 254 |
responses={
|
| 255 |
200: {"model": InferenceOutputsResponse, "description": "推理输出列表"},
|
|
@@ -269,32 +279,47 @@ async def get_task_outputs(
|
|
| 269 |
return result
|
| 270 |
|
| 271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
@router.get(
|
| 273 |
-
"/{task_id}/outputs/{filename}",
|
| 274 |
-
summary="下载
|
| 275 |
description="""
|
| 276 |
-
下载
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
-
文件名
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
**返回**:
|
| 281 |
-
- Content-Type: audio/wav
|
| 282 |
-
-
|
| 283 |
""",
|
| 284 |
responses={
|
| 285 |
-
200: {"description": "
|
| 286 |
404: {"model": ErrorResponse, "description": "任务或文件不存在"},
|
| 287 |
},
|
| 288 |
)
|
| 289 |
-
async def
|
| 290 |
task_id: str,
|
| 291 |
-
|
|
|
|
| 292 |
service: TaskService = Depends(get_task_service),
|
| 293 |
) -> Response:
|
| 294 |
"""
|
| 295 |
-
下载
|
| 296 |
"""
|
| 297 |
-
result = await service.
|
| 298 |
if result is None:
|
| 299 |
raise HTTPException(status_code=404, detail="任务或文件不存在")
|
| 300 |
|
|
|
|
| 10 |
- DELETE /tasks/{task_id} 取消任务
|
| 11 |
- GET /tasks/{task_id}/progress SSE 进度订阅
|
| 12 |
- GET /tasks/{task_id}/outputs 获取推理输出列表
|
| 13 |
+
- GET /tasks/{task_id}/outputs/{file_type}/{filename} 下载任务相关文件
|
| 14 |
"""
|
| 15 |
|
| 16 |
import json
|
| 17 |
+
from typing import Literal, Optional
|
| 18 |
|
| 19 |
+
from fastapi import APIRouter, Depends, HTTPException, Path, Query
|
| 20 |
from fastapi.responses import StreamingResponse, Response
|
| 21 |
|
| 22 |
from ....models.schemas.task import (
|
|
|
|
| 238 |
response_model=InferenceOutputsResponse,
|
| 239 |
summary="获取推理输出列表",
|
| 240 |
description="""
|
| 241 |
+
获取任务的推理输出文件列表及推理配置信息。
|
| 242 |
|
| 243 |
训练任务完成后,推理阶段会生成测试音频文件。此端点返回所有生成的音频文件元信息,
|
| 244 |
+
包括文件名、使用的模型路径、文件大小等,以及推理使用的参考音频和文本信息。
|
| 245 |
|
| 246 |
+
**推理配置**:
|
| 247 |
+
- `ref_text`: 参考音频的文本内容
|
| 248 |
+
- `ref_audio_path`: 参考音频文件路径
|
| 249 |
+
- `target_text`: 合成的目标文本
|
| 250 |
+
|
| 251 |
+
**输出文件信息**:
|
| 252 |
- `filename`: 文件名
|
| 253 |
- `gpt_model`: 使用的 GPT 模型名称
|
| 254 |
- `sovits_model`: 使用的 SoVITS 模型名称
|
| 255 |
+
- `gpt_path`: GPT 模型完整路径
|
| 256 |
+
- `sovits_path`: SoVITS 模型完整路径
|
| 257 |
+
- `file_path`: 输出文件相对路径
|
| 258 |
- `size_bytes`: 文件大小(字节)
|
| 259 |
- `created_at`: 创建时间
|
| 260 |
+
|
| 261 |
+
**下载文件**:
|
| 262 |
+
使用 `/tasks/{task_id}/outputs/{file_type}/{filename}` 端点下载相关文件。
|
| 263 |
""",
|
| 264 |
responses={
|
| 265 |
200: {"model": InferenceOutputsResponse, "description": "推理输出列表"},
|
|
|
|
| 279 |
return result
|
| 280 |
|
| 281 |
|
| 282 |
+
# 文件类型定义
|
| 283 |
+
FileType = Literal["output", "ref_audio", "gpt_model", "sovits_model"]
|
| 284 |
+
|
| 285 |
+
|
| 286 |
@router.get(
|
| 287 |
+
"/{task_id}/outputs/{file_type}/{filename:path}",
|
| 288 |
+
summary="下载任务相关文件",
|
| 289 |
description="""
|
| 290 |
+
下载任务相关的各类文件。
|
| 291 |
+
|
| 292 |
+
**文件类型 (file_type)**:
|
| 293 |
+
- `output` - 推理输出音频文件 (.wav)
|
| 294 |
+
- `ref_audio` - 参考音频文件 (.wav)
|
| 295 |
+
- `gpt_model` - GPT 模型文件 (.ckpt)
|
| 296 |
+
- `sovits_model` - SoVITS 模型文件 (.pth)
|
| 297 |
|
| 298 |
+
**文件名来源**:
|
| 299 |
+
- `output`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].filename` 获取
|
| 300 |
+
- `ref_audio`: 从 `/tasks/{task_id}/outputs` 端点的 `ref_audio_path` 获取
|
| 301 |
+
- `gpt_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].gpt_path` 获取文件名部分
|
| 302 |
+
- `sovits_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].sovits_path` 获取文件名部分
|
| 303 |
|
| 304 |
**返回**:
|
| 305 |
+
- 音频文件: Content-Type: audio/wav
|
| 306 |
+
- 模型文件: Content-Type: application/octet-stream
|
| 307 |
""",
|
| 308 |
responses={
|
| 309 |
+
200: {"description": "文件内容"},
|
| 310 |
404: {"model": ErrorResponse, "description": "任务或文件不存在"},
|
| 311 |
},
|
| 312 |
)
|
| 313 |
+
async def download_task_file(
|
| 314 |
task_id: str,
|
| 315 |
+
file_type: FileType = Path(..., description="文件类型: output/ref_audio/gpt_model/sovits_model"),
|
| 316 |
+
filename: str = Path(..., description="文件名或路径"),
|
| 317 |
service: TaskService = Depends(get_task_service),
|
| 318 |
) -> Response:
|
| 319 |
"""
|
| 320 |
+
下载任务相关文件(推理输出、参考音频、模型文件)
|
| 321 |
"""
|
| 322 |
+
result = await service.download_file(task_id, file_type, filename)
|
| 323 |
if result is None:
|
| 324 |
raise HTTPException(status_code=404, detail="任务或文件不存在")
|
| 325 |
|
api_server/app/services/task_service.py
CHANGED
|
@@ -447,6 +447,68 @@ class TaskService:
|
|
| 447 |
|
| 448 |
return file_data, filename, "audio/wav"
|
| 449 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
def _parse_inference_filename(self, filename: str, exp_name: str) -> Tuple[str, str]:
|
| 451 |
"""
|
| 452 |
解析推理输出文件名,提取 GPT 和 SoVITS 模型名称
|
|
|
|
| 447 |
|
| 448 |
return file_data, filename, "audio/wav"
|
| 449 |
|
| 450 |
+
async def download_file(
|
| 451 |
+
self,
|
| 452 |
+
task_id: str,
|
| 453 |
+
file_type: str,
|
| 454 |
+
filename: str
|
| 455 |
+
) -> Optional[Tuple[bytes, str, str]]:
|
| 456 |
+
"""
|
| 457 |
+
下载指定类型的文件
|
| 458 |
+
|
| 459 |
+
Args:
|
| 460 |
+
task_id: 任务 ID
|
| 461 |
+
file_type: 文件类型 (output/ref_audio/gpt_model/sovits_model)
|
| 462 |
+
filename: 文件名
|
| 463 |
+
|
| 464 |
+
Returns:
|
| 465 |
+
(文件内容, 文件名, content_type) 或 None(不存在时)
|
| 466 |
+
"""
|
| 467 |
+
# 获取任务
|
| 468 |
+
task = await self.db.get_task(task_id)
|
| 469 |
+
if not task:
|
| 470 |
+
return None
|
| 471 |
+
|
| 472 |
+
exp_name = task.exp_name
|
| 473 |
+
version = task.config.get("version", "v2")
|
| 474 |
+
|
| 475 |
+
# 根据文件类型确定路径和 content_type
|
| 476 |
+
if file_type == "output":
|
| 477 |
+
file_path = Path(settings.EXP_ROOT) / exp_name / "inference" / filename
|
| 478 |
+
content_type = "audio/wav"
|
| 479 |
+
elif file_type == "ref_audio":
|
| 480 |
+
# ref_audio 使用完整路径(filename 参数实际上是完整路径)
|
| 481 |
+
file_path = Path(filename)
|
| 482 |
+
content_type = "audio/wav"
|
| 483 |
+
elif file_type == "gpt_model":
|
| 484 |
+
gpt_dir = self._get_gpt_weight_dir(version)
|
| 485 |
+
file_path = Path(settings.EXP_ROOT) / exp_name / gpt_dir / filename
|
| 486 |
+
content_type = "application/octet-stream"
|
| 487 |
+
elif file_type == "sovits_model":
|
| 488 |
+
sovits_dir = self._get_sovits_weight_dir(version)
|
| 489 |
+
file_path = Path(settings.EXP_ROOT) / exp_name / sovits_dir / filename
|
| 490 |
+
content_type = "application/octet-stream"
|
| 491 |
+
else:
|
| 492 |
+
return None
|
| 493 |
+
|
| 494 |
+
# 安全检查:确保文件路径有效
|
| 495 |
+
try:
|
| 496 |
+
file_path = file_path.resolve()
|
| 497 |
+
except (ValueError, OSError):
|
| 498 |
+
return None
|
| 499 |
+
|
| 500 |
+
if not file_path.exists() or not file_path.is_file():
|
| 501 |
+
return None
|
| 502 |
+
|
| 503 |
+
# 读取文件内容
|
| 504 |
+
with open(file_path, "rb") as f:
|
| 505 |
+
file_data = f.read()
|
| 506 |
+
|
| 507 |
+
# 使用文件名(不含路径)作为下载文件名
|
| 508 |
+
download_filename = file_path.name
|
| 509 |
+
|
| 510 |
+
return file_data, download_filename, content_type
|
| 511 |
+
|
| 512 |
def _parse_inference_filename(self, filename: str, exp_name: str) -> Tuple[str, str]:
|
| 513 |
"""
|
| 514 |
解析推理输出文件名,提取 GPT 和 SoVITS 模型名称
|