liumaolin commited on
Commit
8064340
·
1 Parent(s): 5357d86

feat(api): enhance task file management and download endpoints

Browse files

- Update `/tasks/{task_id}/outputs` to include inference configuration details
- Replace `/tasks/{task_id}/outputs/{filename}` with `/tasks/{task_id}/outputs/{file_type}/{filename}` for flexible file downloads
- Support downloading inference outputs, reference audio, and model files
- Add `FileType` validation and path security checks in `TaskService`
- Implement `download_file` method to handle various file types securely

api_server/app/api/v1/endpoints/tasks.py CHANGED
@@ -10,13 +10,13 @@ API 列表:
10
  - DELETE /tasks/{task_id} 取消任务
11
  - GET /tasks/{task_id}/progress SSE 进度订阅
12
  - GET /tasks/{task_id}/outputs 获取推理输出列表
13
- - GET /tasks/{task_id}/outputs/{filename} 下载推理输出文件
14
  """
15
 
16
  import json
17
- from typing import Optional
18
 
19
- from fastapi import APIRouter, Depends, HTTPException, Query
20
  from fastapi.responses import StreamingResponse, Response
21
 
22
  from ....models.schemas.task import (
@@ -238,18 +238,28 @@ async def subscribe_progress(
238
  response_model=InferenceOutputsResponse,
239
  summary="获取推理输出列表",
240
  description="""
241
- 获取任务的推理输出文件列表。
242
 
243
  训练任务完成后,推理阶段会生成测试音频文件。此端点返回所有生成的音频文件元信息,
244
- 包括文件名、使用的模型、文件大小等。
245
 
246
- **返回信息**:
 
 
 
 
 
247
  - `filename`: 文件名
248
  - `gpt_model`: 使用的 GPT 模型名称
249
  - `sovits_model`: 使用的 SoVITS 模型名称
250
- - `file_path`: 文件相对路径
 
 
251
  - `size_bytes`: 文件大小(字节)
252
  - `created_at`: 创建时间
 
 
 
253
  """,
254
  responses={
255
  200: {"model": InferenceOutputsResponse, "description": "推理输出列表"},
@@ -269,32 +279,47 @@ async def get_task_outputs(
269
  return result
270
 
271
 
 
 
 
 
272
  @router.get(
273
- "/{task_id}/outputs/{filename}",
274
- summary="下载推理输出文件",
275
  description="""
276
- 下载指定推理输出音频文件。
 
 
 
 
 
 
277
 
278
- 文件名可从 `/tasks/{task_id}/outputs` 端点获取。
 
 
 
 
279
 
280
  **返回**:
281
- - Content-Type: audio/wav
282
- - 音频文件二进制数据
283
  """,
284
  responses={
285
- 200: {"description": "音频文件", "content": {"audio/wav": {}}},
286
  404: {"model": ErrorResponse, "description": "任务或文件不存在"},
287
  },
288
  )
289
- async def download_task_output(
290
  task_id: str,
291
- filename: str,
 
292
  service: TaskService = Depends(get_task_service),
293
  ) -> Response:
294
  """
295
- 下载指定的推理输出文件
296
  """
297
- result = await service.download_inference_output(task_id, filename)
298
  if result is None:
299
  raise HTTPException(status_code=404, detail="任务或文件不存在")
300
 
 
10
  - DELETE /tasks/{task_id} 取消任务
11
  - GET /tasks/{task_id}/progress SSE 进度订阅
12
  - GET /tasks/{task_id}/outputs 获取推理输出列表
13
+ - GET /tasks/{task_id}/outputs/{file_type}/{filename} 下载任务相关文件
14
  """
15
 
16
  import json
17
+ from typing import Literal, Optional
18
 
19
+ from fastapi import APIRouter, Depends, HTTPException, Path, Query
20
  from fastapi.responses import StreamingResponse, Response
21
 
22
  from ....models.schemas.task import (
 
238
  response_model=InferenceOutputsResponse,
239
  summary="获取推理输出列表",
240
  description="""
241
+ 获取任务的推理输出文件列表及推理配置信息
242
 
243
  训练任务完成后,推理阶段会生成测试音频文件。此端点返回所有生成的音频文件元信息,
244
+ 包括文件名、使用的模型路径、文件大小等,以及推理使用的参考音频和文本信息
245
 
246
+ **推理配置**:
247
+ - `ref_text`: 参考音频的文本内容
248
+ - `ref_audio_path`: 参考音频文件路径
249
+ - `target_text`: 合成的目标文本
250
+
251
+ **输出文件信息**:
252
  - `filename`: 文件名
253
  - `gpt_model`: 使用的 GPT 模型名称
254
  - `sovits_model`: 使用的 SoVITS 模型名称
255
+ - `gpt_path`: GPT 模型完整路径
256
+ - `sovits_path`: SoVITS 模型完整路径
257
+ - `file_path`: 输出文件相对路径
258
  - `size_bytes`: 文件大小(字节)
259
  - `created_at`: 创建时间
260
+
261
+ **下载文件**:
262
+ 使用 `/tasks/{task_id}/outputs/{file_type}/{filename}` 端点下载相关文件。
263
  """,
264
  responses={
265
  200: {"model": InferenceOutputsResponse, "description": "推理输出列表"},
 
279
  return result
280
 
281
 
282
+ # 文件类型定义
283
+ FileType = Literal["output", "ref_audio", "gpt_model", "sovits_model"]
284
+
285
+
286
  @router.get(
287
+ "/{task_id}/outputs/{file_type}/{filename:path}",
288
+ summary="下载任务相关文件",
289
  description="""
290
+ 下载任务相关各类文件。
291
+
292
+ **文件类型 (file_type)**:
293
+ - `output` - 推理输出音频文件 (.wav)
294
+ - `ref_audio` - 参考音频文件 (.wav)
295
+ - `gpt_model` - GPT 模型文件 (.ckpt)
296
+ - `sovits_model` - SoVITS 模型文件 (.pth)
297
 
298
+ **文件名来源**:
299
+ - `output`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].filename` 获取
300
+ - `ref_audio`: 从 `/tasks/{task_id}/outputs` 端点的 `ref_audio_path` 获取
301
+ - `gpt_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].gpt_path` 获取文件名部分
302
+ - `sovits_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].sovits_path` 获取文件名部分
303
 
304
  **返回**:
305
+ - 音频文件: Content-Type: audio/wav
306
+ - 模型文件: Content-Type: application/octet-stream
307
  """,
308
  responses={
309
+ 200: {"description": "文件内容"},
310
  404: {"model": ErrorResponse, "description": "任务或文件不存在"},
311
  },
312
  )
313
+ async def download_task_file(
314
  task_id: str,
315
+ file_type: FileType = Path(..., description="文件类型: output/ref_audio/gpt_model/sovits_model"),
316
+ filename: str = Path(..., description="文件名或路径"),
317
  service: TaskService = Depends(get_task_service),
318
  ) -> Response:
319
  """
320
+ 下载任务相关文件(推理输出、参考音频、模型文件
321
  """
322
+ result = await service.download_file(task_id, file_type, filename)
323
  if result is None:
324
  raise HTTPException(status_code=404, detail="任务或文件不存在")
325
 
api_server/app/services/task_service.py CHANGED
@@ -447,6 +447,68 @@ class TaskService:
447
 
448
  return file_data, filename, "audio/wav"
449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  def _parse_inference_filename(self, filename: str, exp_name: str) -> Tuple[str, str]:
451
  """
452
  解析推理输出文件名,提取 GPT 和 SoVITS 模型名称
 
447
 
448
  return file_data, filename, "audio/wav"
449
 
450
+ async def download_file(
451
+ self,
452
+ task_id: str,
453
+ file_type: str,
454
+ filename: str
455
+ ) -> Optional[Tuple[bytes, str, str]]:
456
+ """
457
+ 下载指定类型的文件
458
+
459
+ Args:
460
+ task_id: 任务 ID
461
+ file_type: 文件类型 (output/ref_audio/gpt_model/sovits_model)
462
+ filename: 文件名
463
+
464
+ Returns:
465
+ (文件内容, 文件名, content_type) 或 None(不存在时)
466
+ """
467
+ # 获取任务
468
+ task = await self.db.get_task(task_id)
469
+ if not task:
470
+ return None
471
+
472
+ exp_name = task.exp_name
473
+ version = task.config.get("version", "v2")
474
+
475
+ # 根据文件类型确定路径和 content_type
476
+ if file_type == "output":
477
+ file_path = Path(settings.EXP_ROOT) / exp_name / "inference" / filename
478
+ content_type = "audio/wav"
479
+ elif file_type == "ref_audio":
480
+ # ref_audio 使用完整路径(filename 参数实际上是完整路径)
481
+ file_path = Path(filename)
482
+ content_type = "audio/wav"
483
+ elif file_type == "gpt_model":
484
+ gpt_dir = self._get_gpt_weight_dir(version)
485
+ file_path = Path(settings.EXP_ROOT) / exp_name / gpt_dir / filename
486
+ content_type = "application/octet-stream"
487
+ elif file_type == "sovits_model":
488
+ sovits_dir = self._get_sovits_weight_dir(version)
489
+ file_path = Path(settings.EXP_ROOT) / exp_name / sovits_dir / filename
490
+ content_type = "application/octet-stream"
491
+ else:
492
+ return None
493
+
494
+ # 安全检查:确保文件路径有效
495
+ try:
496
+ file_path = file_path.resolve()
497
+ except (ValueError, OSError):
498
+ return None
499
+
500
+ if not file_path.exists() or not file_path.is_file():
501
+ return None
502
+
503
+ # 读取文件内容
504
+ with open(file_path, "rb") as f:
505
+ file_data = f.read()
506
+
507
+ # 使用文件名(不含路径)作为下载文件名
508
+ download_filename = file_path.name
509
+
510
+ return file_data, download_filename, content_type
511
+
512
  def _parse_inference_filename(self, filename: str, exp_name: str) -> Tuple[str, str]:
513
  """
514
  解析推理输出文件名,提取 GPT 和 SoVITS 模型名称