| package handler |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "net/http/httptest" |
| "strconv" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/logger" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/response" |
| middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" |
| "github.com/Wei-Shaw/sub2api/internal/service" |
| "github.com/gin-gonic/gin" |
| ) |
|
|
// Lifetimes of the cached Sora model-family list kept on SoraClientHandler.
const (
	// modelCacheTTL applies when the cached list was obtained from upstream.
	modelCacheTTL = time.Hour
	// modelCacheFailedTTL applies when the cached list is the local fallback
	// stored after an upstream fetch failed, so a retry happens sooner.
	modelCacheFailedTTL = 2 * time.Minute
)
|
|
| |
| type SoraClientHandler struct { |
| genService *service.SoraGenerationService |
| quotaService *service.SoraQuotaService |
| s3Storage *service.SoraS3Storage |
| soraGatewayService *service.SoraGatewayService |
| gatewayService *service.GatewayService |
| mediaStorage *service.SoraMediaStorage |
| apiKeyService *service.APIKeyService |
|
|
| |
| modelCacheMu sync.RWMutex |
| cachedFamilies []service.SoraModelFamily |
| modelCacheTime time.Time |
| modelCacheUpstream bool |
| } |
|
|
| |
| func NewSoraClientHandler( |
| genService *service.SoraGenerationService, |
| quotaService *service.SoraQuotaService, |
| s3Storage *service.SoraS3Storage, |
| soraGatewayService *service.SoraGatewayService, |
| gatewayService *service.GatewayService, |
| mediaStorage *service.SoraMediaStorage, |
| apiKeyService *service.APIKeyService, |
| ) *SoraClientHandler { |
| return &SoraClientHandler{ |
| genService: genService, |
| quotaService: quotaService, |
| s3Storage: s3Storage, |
| soraGatewayService: soraGatewayService, |
| gatewayService: gatewayService, |
| mediaStorage: mediaStorage, |
| apiKeyService: apiKeyService, |
| } |
| } |
|
|
| |
// GenerateRequest is the JSON body accepted by Generate.
type GenerateRequest struct {
	// Model is required by the binding tag.
	Model string `json:"model" binding:"required"`
	// Prompt is required by the binding tag.
	Prompt string `json:"prompt" binding:"required"`
	// MediaType is treated as "video" by Generate when left empty.
	MediaType string `json:"media_type"`
	// VideoCount is normalized by Generate via normalizeVideoCount.
	VideoCount int `json:"video_count,omitempty"`
	// ImageInput is forwarded upstream as "image_input" when non-empty.
	ImageInput string `json:"image_input,omitempty"`
	// APIKeyID optionally selects one of the caller's API keys.
	APIKeyID *int64 `json:"api_key_id,omitempty"`
}
|
|
| |
| |
| func (h *SoraClientHandler) Generate(c *gin.Context) { |
| userID := getUserIDFromContext(c) |
| if userID == 0 { |
| response.Error(c, http.StatusUnauthorized, "未登录") |
| return |
| } |
|
|
| var req GenerateRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) |
| return |
| } |
|
|
| if req.MediaType == "" { |
| req.MediaType = "video" |
| } |
| req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) |
|
|
| |
| activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
| if activeCount >= 3 { |
| response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") |
| return |
| } |
|
|
| |
| if h.quotaService != nil { |
| if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil { |
| var quotaErr *service.QuotaExceededError |
| if errors.As(err, "aErr) { |
| response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") |
| return |
| } |
| response.Error(c, http.StatusForbidden, err.Error()) |
| return |
| } |
| } |
|
|
| |
| var apiKeyID *int64 |
| var groupID *int64 |
|
|
| if req.APIKeyID != nil && h.apiKeyService != nil { |
| |
| apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID) |
| if err != nil { |
| response.Error(c, http.StatusBadRequest, "API Key 不存在") |
| return |
| } |
| if apiKey.UserID != userID { |
| response.Error(c, http.StatusForbidden, "API Key 不属于当前用户") |
| return |
| } |
| if apiKey.Status != service.StatusAPIKeyActive { |
| response.Error(c, http.StatusForbidden, "API Key 不可用") |
| return |
| } |
| apiKeyID = &apiKey.ID |
| groupID = apiKey.GroupID |
| } else if id, ok := c.Get("api_key_id"); ok { |
| |
| if v, ok := id.(int64); ok { |
| apiKeyID = &v |
| } |
| } |
|
|
| gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType) |
| if err != nil { |
| if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) { |
| response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") |
| return |
| } |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| |
| go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount) |
|
|
| response.Success(c, gin.H{ |
| "generation_id": gen.ID, |
| "status": gen.Status, |
| }) |
| } |
|
|
| |
| |
| func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) { |
| ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) |
| defer cancel() |
|
|
| |
| if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil { |
| if errors.Is(err, service.ErrSoraGenerationStateConflict) { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) |
| return |
| } |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) |
| return |
| } |
|
|
| logger.LegacyPrintf( |
| "handler.sora_client", |
| "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", |
| genID, |
| userID, |
| groupIDForLog(groupID), |
| model, |
| mediaType, |
| videoCount, |
| strings.TrimSpace(imageInput) != "", |
| len(strings.TrimSpace(prompt)), |
| ) |
|
|
| |
| if groupID == nil { |
| ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) |
| } |
|
|
| if h.gatewayService == nil { |
| _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") |
| return |
| } |
|
|
| |
| account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) |
| if err != nil { |
| logger.LegacyPrintf( |
| "handler.sora_client", |
| "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", |
| genID, |
| userID, |
| groupIDForLog(groupID), |
| model, |
| err, |
| ) |
| _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) |
| return |
| } |
| logger.LegacyPrintf( |
| "handler.sora_client", |
| "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", |
| genID, |
| userID, |
| groupIDForLog(groupID), |
| model, |
| account.ID, |
| account.Name, |
| account.Platform, |
| account.Type, |
| ) |
|
|
| |
| body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) |
|
|
| if h.soraGatewayService == nil { |
| _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") |
| return |
| } |
|
|
| |
| recorder := httptest.NewRecorder() |
| mockGinCtx, _ := gin.CreateTestContext(recorder) |
| mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) |
|
|
| |
| result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, account, body, false) |
| if err != nil { |
| logger.LegacyPrintf( |
| "handler.sora_client", |
| "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v", |
| genID, |
| account.ID, |
| model, |
| recorder.Code, |
| trimForLog(recorder.Body.String(), 400), |
| err, |
| ) |
| |
| gen, _ := h.genService.GetByID(ctx, genID, userID) |
| if gen != nil && gen.Status == service.SoraGenStatusCancelled { |
| return |
| } |
| _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error()) |
| return |
| } |
|
|
| |
| mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder) |
| if mediaURL == "" { |
| logger.LegacyPrintf( |
| "handler.sora_client", |
| "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s", |
| genID, |
| account.ID, |
| model, |
| recorder.Code, |
| trimForLog(recorder.Body.String(), 400), |
| ) |
| _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL") |
| return |
| } |
|
|
| |
| gen, _ := h.genService.GetByID(ctx, genID, userID) |
| if gen != nil && gen.Status == service.SoraGenStatusCancelled { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID) |
| return |
| } |
|
|
| |
| storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs) |
|
|
| usageAdded := false |
| if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil { |
| if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil { |
| h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) |
| var quotaErr *service.QuotaExceededError |
| if errors.As(err, "aErr) { |
| _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间") |
| return |
| } |
| _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error()) |
| return |
| } |
| usageAdded = true |
| } |
|
|
| |
| gen, _ = h.genService.GetByID(ctx, genID, userID) |
| if gen != nil && gen.Status == service.SoraGenStatusCancelled { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) |
| h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) |
| if usageAdded && h.quotaService != nil { |
| _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) |
| } |
| return |
| } |
|
|
| |
| if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { |
| if errors.Is(err, service.ErrSoraGenerationStateConflict) { |
| h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) |
| if usageAdded && h.quotaService != nil { |
| _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) |
| } |
| return |
| } |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) |
| return |
| } |
|
|
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) |
| } |
|
|
| |
| func (h *SoraClientHandler) storeMediaWithDegradation( |
| ctx context.Context, userID int64, mediaType string, |
| mediaURL string, mediaURLs []string, |
| ) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { |
| urls := mediaURLs |
| if len(urls) == 0 { |
| urls = []string{mediaURL} |
| } |
|
|
| |
| if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { |
| keys := make([]string, 0, len(urls)) |
| var totalSize int64 |
| allOK := true |
| for _, u := range urls { |
| key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) |
| if err != nil { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) |
| allOK = false |
| |
| if len(keys) > 0 { |
| _ = h.s3Storage.DeleteObjects(ctx, keys) |
| } |
| break |
| } |
| keys = append(keys, key) |
| totalSize += size |
| } |
| if allOK && len(keys) > 0 { |
| accessURLs := make([]string, 0, len(keys)) |
| for _, key := range keys { |
| accessURL, err := h.s3Storage.GetAccessURL(ctx, key) |
| if err != nil { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) |
| _ = h.s3Storage.DeleteObjects(ctx, keys) |
| allOK = false |
| break |
| } |
| accessURLs = append(accessURLs, accessURL) |
| } |
| if allOK && len(accessURLs) > 0 { |
| return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize |
| } |
| } |
| } |
|
|
| |
| if h.mediaStorage != nil && h.mediaStorage.Enabled() { |
| storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) |
| if err == nil && len(storedPaths) > 0 { |
| firstPath := storedPaths[0] |
| totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) |
| if sizeErr != nil { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) |
| } |
| return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize |
| } |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) |
| } |
|
|
| |
| return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 |
| } |
|
|
| |
// buildAsyncRequestBody encodes the non-streaming upstream request as JSON: a
// single user message carrying prompt, plus "image_input" when imageInput is
// non-empty and "video_count" when videoCount is greater than 1.
func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte {
	userMessage := map[string]string{"role": "user", "content": prompt}
	payload := map[string]any{
		"model":    model,
		"messages": []map[string]string{userMessage},
		"stream":   false,
	}
	if imageInput != "" {
		payload["image_input"] = imageInput
	}
	if videoCount > 1 {
		payload["video_count"] = videoCount
	}
	// The payload holds only strings, a bool and an int, so the marshal
	// error is ignored.
	encoded, _ := json.Marshal(payload)
	return encoded
}
|
|
// normalizeVideoCount clamps videoCount to the range 1..3 for media type
// "video" and returns 1 for every other media type.
func normalizeVideoCount(mediaType string, videoCount int) int {
	switch {
	case mediaType != "video", videoCount <= 0:
		return 1
	case videoCount > 3:
		return 3
	default:
		return videoCount
	}
}
|
|
| |
| |
| |
| func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { |
| |
| if result != nil && result.MediaURL != "" { |
| |
| if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { |
| return urls[0], urls |
| } |
| return result.MediaURL, []string{result.MediaURL} |
| } |
|
|
| |
| if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { |
| return urls[0], urls |
| } |
|
|
| return "", nil |
| } |
|
|
| |
// parseMediaURLsFromBody extracts media URLs from a JSON object body.
//
// The non-empty string entries of the "media_urls" array are preferred; if
// there are none, a non-empty "media_url" string is returned as a
// single-element list. It returns nil for an empty body, for a body that does
// not decode into a JSON object, and when no URL is found.
func parseMediaURLsFromBody(body []byte) []string {
	if len(body) == 0 {
		return nil
	}
	var payload map[string]any
	if json.Unmarshal(body, &payload) != nil {
		return nil
	}

	if list, isList := payload["media_urls"].([]any); isList {
		found := make([]string, 0, len(list))
		for _, entry := range list {
			// Entries that are not strings, or are empty, are skipped.
			if s, isString := entry.(string); isString && s != "" {
				found = append(found, s)
			}
		}
		if len(found) > 0 {
			return found
		}
	}

	if single, isString := payload["media_url"].(string); isString && single != "" {
		return []string{single}
	}
	return nil
}
|
|
| |
| |
| func (h *SoraClientHandler) ListGenerations(c *gin.Context) { |
| userID := getUserIDFromContext(c) |
| if userID == 0 { |
| response.Error(c, http.StatusUnauthorized, "未登录") |
| return |
| } |
|
|
| page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) |
| pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) |
|
|
| params := service.SoraGenerationListParams{ |
| UserID: userID, |
| Status: c.Query("status"), |
| StorageType: c.Query("storage_type"), |
| MediaType: c.Query("media_type"), |
| Page: page, |
| PageSize: pageSize, |
| } |
|
|
| gens, total, err := h.genService.List(c.Request.Context(), params) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| |
| for _, gen := range gens { |
| _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) |
| } |
|
|
| response.Success(c, gin.H{ |
| "data": gens, |
| "total": total, |
| "page": page, |
| }) |
| } |
|
|
| |
| |
| func (h *SoraClientHandler) GetGeneration(c *gin.Context) { |
| userID := getUserIDFromContext(c) |
| if userID == 0 { |
| response.Error(c, http.StatusUnauthorized, "未登录") |
| return |
| } |
|
|
| id, err := strconv.ParseInt(c.Param("id"), 10, 64) |
| if err != nil { |
| response.Error(c, http.StatusBadRequest, "无效的 ID") |
| return |
| } |
|
|
| gen, err := h.genService.GetByID(c.Request.Context(), id, userID) |
| if err != nil { |
| response.Error(c, http.StatusNotFound, err.Error()) |
| return |
| } |
|
|
| _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) |
| response.Success(c, gen) |
| } |
|
|
| |
| |
| func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { |
| userID := getUserIDFromContext(c) |
| if userID == 0 { |
| response.Error(c, http.StatusUnauthorized, "未登录") |
| return |
| } |
|
|
| id, err := strconv.ParseInt(c.Param("id"), 10, 64) |
| if err != nil { |
| response.Error(c, http.StatusBadRequest, "无效的 ID") |
| return |
| } |
|
|
| gen, err := h.genService.GetByID(c.Request.Context(), id, userID) |
| if err != nil { |
| response.Error(c, http.StatusNotFound, err.Error()) |
| return |
| } |
|
|
| |
| if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { |
| paths := gen.MediaURLs |
| if len(paths) == 0 && gen.MediaURL != "" { |
| paths = []string{gen.MediaURL} |
| } |
| if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) |
| } |
| } |
|
|
| if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { |
| response.Error(c, http.StatusNotFound, err.Error()) |
| return |
| } |
|
|
| response.Success(c, gin.H{"message": "已删除"}) |
| } |
|
|
| |
| |
| func (h *SoraClientHandler) GetQuota(c *gin.Context) { |
| userID := getUserIDFromContext(c) |
| if userID == 0 { |
| response.Error(c, http.StatusUnauthorized, "未登录") |
| return |
| } |
|
|
| if h.quotaService == nil { |
| response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) |
| return |
| } |
|
|
| quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
| response.Success(c, quota) |
| } |
|
|
| |
| |
| func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { |
| userID := getUserIDFromContext(c) |
| if userID == 0 { |
| response.Error(c, http.StatusUnauthorized, "未登录") |
| return |
| } |
|
|
| id, err := strconv.ParseInt(c.Param("id"), 10, 64) |
| if err != nil { |
| response.Error(c, http.StatusBadRequest, "无效的 ID") |
| return |
| } |
|
|
| |
| gen, err := h.genService.GetByID(c.Request.Context(), id, userID) |
| if err != nil { |
| response.Error(c, http.StatusNotFound, err.Error()) |
| return |
| } |
| _ = gen |
|
|
| if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { |
| if errors.Is(err, service.ErrSoraGenerationNotActive) { |
| response.Error(c, http.StatusConflict, "任务已结束,无法取消") |
| return |
| } |
| response.Error(c, http.StatusBadRequest, err.Error()) |
| return |
| } |
|
|
| response.Success(c, gin.H{"message": "已取消"}) |
| } |
|
|
| |
| |
| func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { |
| userID := getUserIDFromContext(c) |
| if userID == 0 { |
| response.Error(c, http.StatusUnauthorized, "未登录") |
| return |
| } |
|
|
| id, err := strconv.ParseInt(c.Param("id"), 10, 64) |
| if err != nil { |
| response.Error(c, http.StatusBadRequest, "无效的 ID") |
| return |
| } |
|
|
| gen, err := h.genService.GetByID(c.Request.Context(), id, userID) |
| if err != nil { |
| response.Error(c, http.StatusNotFound, err.Error()) |
| return |
| } |
|
|
| if gen.StorageType != service.SoraStorageTypeUpstream { |
| response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") |
| return |
| } |
| if gen.MediaURL == "" { |
| response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") |
| return |
| } |
|
|
| if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) { |
| response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员") |
| return |
| } |
|
|
| sourceURLs := gen.MediaURLs |
| if len(sourceURLs) == 0 && gen.MediaURL != "" { |
| sourceURLs = []string{gen.MediaURL} |
| } |
| if len(sourceURLs) == 0 { |
| response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") |
| return |
| } |
|
|
| uploadedKeys := make([]string, 0, len(sourceURLs)) |
| accessURLs := make([]string, 0, len(sourceURLs)) |
| var totalSize int64 |
|
|
| for _, sourceURL := range sourceURLs { |
| objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL) |
| if uploadErr != nil { |
| if len(uploadedKeys) > 0 { |
| _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) |
| } |
| var upstreamErr *service.UpstreamDownloadError |
| if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) { |
| response.Error(c, http.StatusGone, "媒体链接已过期,无法保存") |
| return |
| } |
| response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error()) |
| return |
| } |
| accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey) |
| if err != nil { |
| uploadedKeys = append(uploadedKeys, objectKey) |
| _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) |
| response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error()) |
| return |
| } |
| uploadedKeys = append(uploadedKeys, objectKey) |
| accessURLs = append(accessURLs, accessURL) |
| totalSize += fileSize |
| } |
|
|
| usageAdded := false |
| if totalSize > 0 && h.quotaService != nil { |
| if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil { |
| _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) |
| var quotaErr *service.QuotaExceededError |
| if errors.As(err, "aErr) { |
| response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") |
| return |
| } |
| response.Error(c, http.StatusInternalServerError, "配额更新失败: "+err.Error()) |
| return |
| } |
| usageAdded = true |
| } |
|
|
| if err := h.genService.UpdateStorageForCompleted( |
| c.Request.Context(), |
| id, |
| accessURLs[0], |
| accessURLs, |
| service.SoraStorageTypeS3, |
| uploadedKeys, |
| totalSize, |
| ); err != nil { |
| _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) |
| if usageAdded && h.quotaService != nil { |
| _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) |
| } |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, gin.H{ |
| "message": "已保存到 S3", |
| "object_key": uploadedKeys[0], |
| "object_keys": uploadedKeys, |
| }) |
| } |
|
|
| |
| |
| func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { |
| s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) |
| s3Healthy := false |
| if s3Enabled { |
| s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) |
| } |
| localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() |
| response.Success(c, gin.H{ |
| "s3_enabled": s3Enabled, |
| "s3_healthy": s3Healthy, |
| "local_enabled": localEnabled, |
| }) |
| } |
|
|
| func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { |
| switch storageType { |
| case service.SoraStorageTypeS3: |
| if h.s3Storage != nil && len(s3Keys) > 0 { |
| if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) |
| } |
| } |
| case service.SoraStorageTypeLocal: |
| if h.mediaStorage != nil && len(localPaths) > 0 { |
| if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) |
| } |
| } |
| } |
| } |
|
|
| |
| func getUserIDFromContext(c *gin.Context) int64 { |
| if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { |
| return subject.UserID |
| } |
|
|
| if id, ok := c.Get("user_id"); ok { |
| switch v := id.(type) { |
| case int64: |
| return v |
| case float64: |
| return int64(v) |
| case string: |
| n, _ := strconv.ParseInt(v, 10, 64) |
| return n |
| } |
| } |
| |
| if id, ok := c.Get("userID"); ok { |
| if v, ok := id.(int64); ok { |
| return v |
| } |
| } |
| return 0 |
| } |
|
|
// groupIDForLog returns the group ID for logging, or 0 when groupID is nil.
func groupIDForLog(groupID *int64) int64 {
	if groupID != nil {
		return *groupID
	}
	return 0
}
|
|
// trimForLog trims surrounding whitespace from raw and, when maxLen is
// positive and the result is longer than maxLen bytes, truncates it and
// appends "...(truncated)".
//
// The cut never splits a multi-byte UTF-8 sequence: it backs off to the start
// of the rune that straddles maxLen, so the kept prefix may be up to three
// bytes shorter than maxLen.
func trimForLog(raw string, maxLen int) string {
	trimmed := strings.TrimSpace(raw)
	if maxLen <= 0 || len(trimmed) <= maxLen {
		return trimmed
	}

	// trimmed[cut] is the first dropped byte. If it is a UTF-8 continuation
	// byte (bit pattern 10xxxxxx), the cut lies inside a rune; step back to
	// that rune's first byte. A rune has at most three continuation bytes,
	// which bounds the loop even for invalid input.
	cut := maxLen
	for back := 0; back < 3 && cut > 0 && trimmed[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return trimmed[:cut] + "...(truncated)"
}
|
|
| |
| |
| |
| func (h *SoraClientHandler) GetModels(c *gin.Context) { |
| families := h.getModelFamilies(c.Request.Context()) |
| response.Success(c, families) |
| } |
|
|
| |
| func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { |
| |
| h.modelCacheMu.RLock() |
| ttl := modelCacheTTL |
| if !h.modelCacheUpstream { |
| ttl = modelCacheFailedTTL |
| } |
| if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { |
| families := h.cachedFamilies |
| h.modelCacheMu.RUnlock() |
| return families |
| } |
| h.modelCacheMu.RUnlock() |
|
|
| |
| h.modelCacheMu.Lock() |
| defer h.modelCacheMu.Unlock() |
|
|
| |
| ttl = modelCacheTTL |
| if !h.modelCacheUpstream { |
| ttl = modelCacheFailedTTL |
| } |
| if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { |
| return h.cachedFamilies |
| } |
|
|
| |
| families, err := h.fetchUpstreamModels(ctx) |
| if err != nil { |
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) |
| families = service.BuildSoraModelFamilies() |
| h.cachedFamilies = families |
| h.modelCacheTime = time.Now() |
| h.modelCacheUpstream = false |
| return families |
| } |
|
|
| logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", len(families)) |
| h.cachedFamilies = families |
| h.modelCacheTime = time.Now() |
| h.modelCacheUpstream = true |
| return families |
| } |
|
|
| |
| func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { |
| if h.gatewayService == nil { |
| return nil, fmt.Errorf("gatewayService 未初始化") |
| } |
|
|
| |
| ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) |
|
|
| |
| account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") |
| if err != nil { |
| return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) |
| } |
|
|
| |
| if account.Type != service.AccountTypeAPIKey { |
| return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) |
| } |
|
|
| apiKey := account.GetCredential("api_key") |
| if apiKey == "" { |
| return nil, fmt.Errorf("账号缺少 api_key") |
| } |
|
|
| baseURL := account.GetBaseURL() |
| if baseURL == "" { |
| return nil, fmt.Errorf("账号缺少 base_url") |
| } |
|
|
| |
| modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" |
|
|
| reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) |
| defer cancel() |
|
|
| req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) |
| if err != nil { |
| return nil, fmt.Errorf("创建请求失败: %w", err) |
| } |
| req.Header.Set("Authorization", "Bearer "+apiKey) |
|
|
| client := &http.Client{Timeout: 10 * time.Second} |
| resp, err := client.Do(req) |
| if err != nil { |
| return nil, fmt.Errorf("请求上游失败: %w", err) |
| } |
| defer func() { |
| _ = resp.Body.Close() |
| }() |
|
|
| if resp.StatusCode != http.StatusOK { |
| return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) |
| } |
|
|
| body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) |
| if err != nil { |
| return nil, fmt.Errorf("读取响应失败: %w", err) |
| } |
|
|
| |
| var modelsResp struct { |
| Data []struct { |
| ID string `json:"id"` |
| } `json:"data"` |
| } |
| if err := json.Unmarshal(body, &modelsResp); err != nil { |
| return nil, fmt.Errorf("解析响应失败: %w", err) |
| } |
|
|
| if len(modelsResp.Data) == 0 { |
| return nil, fmt.Errorf("上游返回空模型列表") |
| } |
|
|
| |
| modelIDs := make([]string, 0, len(modelsResp.Data)) |
| for _, m := range modelsResp.Data { |
| modelIDs = append(modelIDs, m.ID) |
| } |
|
|
| |
| families := service.BuildSoraModelFamiliesFromIDs(modelIDs) |
| if len(families) == 0 { |
| return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") |
| } |
|
|
| return families, nil |
| } |
|
|