| package service |
|
|
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "sort" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/logger" |
| "github.com/QuantumNous/new-api/model" |
| "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
|
|
| "github.com/samber/lo" |
| ) |
|
|
| |
| type TaskPollingAdaptor interface { |
| Init(info *relaycommon.RelayInfo) |
| FetchTask(baseURL string, key string, body map[string]any, proxy string) (*http.Response, error) |
| ParseTaskResult(body []byte) (*relaycommon.TaskInfo, error) |
| |
| |
| AdjustBillingOnComplete(task *model.Task, taskResult *relaycommon.TaskInfo) int |
| } |
|
|
| |
| |
| var GetTaskAdaptorFunc func(platform constant.TaskPlatform) TaskPollingAdaptor |
|
|
| |
| |
| |
| func sweepTimedOutTasks(ctx context.Context) { |
| if constant.TaskTimeoutMinutes <= 0 { |
| return |
| } |
| cutoff := time.Now().Unix() - int64(constant.TaskTimeoutMinutes)*60 |
| tasks := model.GetTimedOutUnfinishedTasks(cutoff, 100) |
| if len(tasks) == 0 { |
| return |
| } |
|
|
| const legacyTaskCutoff int64 = 1740182400 |
| reason := fmt.Sprintf("任务超时(%d分钟)", constant.TaskTimeoutMinutes) |
| legacyReason := "任务超时(旧系统遗留任务,不进行退款,请联系管理员)" |
| now := time.Now().Unix() |
| timedOutCount := 0 |
|
|
| for _, task := range tasks { |
| isLegacy := task.SubmitTime > 0 && task.SubmitTime < legacyTaskCutoff |
|
|
| oldStatus := task.Status |
| task.Status = model.TaskStatusFailure |
| task.Progress = "100%" |
| task.FinishTime = now |
| if isLegacy { |
| task.FailReason = legacyReason |
| } else { |
| task.FailReason = reason |
| } |
|
|
| won, err := task.UpdateWithStatus(oldStatus) |
| if err != nil { |
| logger.LogError(ctx, fmt.Sprintf("sweepTimedOutTasks CAS update error for task %s: %v", task.TaskID, err)) |
| continue |
| } |
| if !won { |
| logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: task %s already transitioned, skip", task.TaskID)) |
| continue |
| } |
| timedOutCount++ |
| if !isLegacy && task.Quota != 0 { |
| RefundTaskQuota(ctx, task, reason) |
| } |
| } |
|
|
| if timedOutCount > 0 { |
| logger.LogInfo(ctx, fmt.Sprintf("sweepTimedOutTasks: timed out %d tasks", timedOutCount)) |
| } |
| } |
|
|
| |
| func TaskPollingLoop() { |
| for { |
| time.Sleep(time.Duration(15) * time.Second) |
| common.SysLog("任务进度轮询开始") |
| ctx := context.TODO() |
| sweepTimedOutTasks(ctx) |
| allTasks := model.GetAllUnFinishSyncTasks(constant.TaskQueryLimit) |
| platformTask := make(map[constant.TaskPlatform][]*model.Task) |
| for _, t := range allTasks { |
| platformTask[t.Platform] = append(platformTask[t.Platform], t) |
| } |
| for platform, tasks := range platformTask { |
| if len(tasks) == 0 { |
| continue |
| } |
| taskChannelM := make(map[int][]string) |
| taskM := make(map[string]*model.Task) |
| nullTaskIds := make([]int64, 0) |
| for _, task := range tasks { |
| upstreamID := task.GetUpstreamTaskID() |
| if upstreamID == "" { |
| |
| nullTaskIds = append(nullTaskIds, task.ID) |
| continue |
| } |
| taskM[upstreamID] = task |
| taskChannelM[task.ChannelId] = append(taskChannelM[task.ChannelId], upstreamID) |
| } |
| if len(nullTaskIds) > 0 { |
| err := model.TaskBulkUpdateByID(nullTaskIds, map[string]any{ |
| "status": "FAILURE", |
| "progress": "100%", |
| }) |
| if err != nil { |
| logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err)) |
| } else { |
| logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds)) |
| } |
| } |
| if len(taskChannelM) == 0 { |
| continue |
| } |
|
|
| DispatchPlatformUpdate(platform, taskChannelM, taskM) |
| } |
| common.SysLog("任务进度轮询完成") |
| } |
| } |
|
|
| |
| func DispatchPlatformUpdate(platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) { |
| switch platform { |
| case constant.TaskPlatformMidjourney: |
| |
| case constant.TaskPlatformSuno: |
| _ = UpdateSunoTasks(context.Background(), taskChannelM, taskM) |
| default: |
| if err := UpdateVideoTasks(context.Background(), platform, taskChannelM, taskM); err != nil { |
| common.SysLog(fmt.Sprintf("UpdateVideoTasks fail: %s", err)) |
| } |
| } |
| } |
|
|
| |
| func UpdateSunoTasks(ctx context.Context, taskChannelM map[int][]string, taskM map[string]*model.Task) error { |
| for channelId, taskIds := range taskChannelM { |
| err := updateSunoTasks(ctx, channelId, taskIds, taskM) |
| if err != nil { |
| logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %s", channelId, err.Error())) |
| } |
| } |
| return nil |
| } |
|
|
| func updateSunoTasks(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error { |
| logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds))) |
| if len(taskIds) == 0 { |
| return nil |
| } |
| ch, err := model.CacheGetChannel(channelId) |
| if err != nil { |
| common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err)) |
| |
| var failedIDs []int64 |
| for _, upstreamID := range taskIds { |
| if t, ok := taskM[upstreamID]; ok { |
| failedIDs = append(failedIDs, t.ID) |
| } |
| } |
| err = model.TaskBulkUpdateByID(failedIDs, map[string]any{ |
| "fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId), |
| "status": "FAILURE", |
| "progress": "100%", |
| }) |
| if err != nil { |
| common.SysLog(fmt.Sprintf("UpdateSunoTask error: %v", err)) |
| } |
| return err |
| } |
| adaptor := GetTaskAdaptorFunc(constant.TaskPlatformSuno) |
| if adaptor == nil { |
| return errors.New("adaptor not found") |
| } |
| proxy := ch.GetSetting().Proxy |
| resp, err := adaptor.FetchTask(*ch.BaseURL, ch.Key, map[string]any{ |
| "ids": taskIds, |
| }, proxy) |
| if err != nil { |
| common.SysLog(fmt.Sprintf("Get Task Do req error: %v", err)) |
| return err |
| } |
| if resp.StatusCode != http.StatusOK { |
| logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode)) |
| return fmt.Errorf("Get Task status code: %d", resp.StatusCode) |
| } |
| defer resp.Body.Close() |
| responseBody, err := io.ReadAll(resp.Body) |
| if err != nil { |
| common.SysLog(fmt.Sprintf("Get Suno Task parse body error: %v", err)) |
| return err |
| } |
| var responseItems dto.TaskResponse[[]dto.SunoDataResponse] |
| err = common.Unmarshal(responseBody, &responseItems) |
| if err != nil { |
| logger.LogError(ctx, fmt.Sprintf("Get Suno Task parse body error2: %v, body: %s", err, string(responseBody))) |
| return err |
| } |
| if !responseItems.IsSuccess() { |
| common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %s", channelId, len(taskIds), string(responseBody))) |
| return err |
| } |
|
|
| for _, responseItem := range responseItems.Data { |
| task := taskM[responseItem.TaskID] |
| if !taskNeedsUpdate(task, responseItem) { |
| continue |
| } |
|
|
| task.Status = lo.If(model.TaskStatus(responseItem.Status) != "", model.TaskStatus(responseItem.Status)).Else(task.Status) |
| task.FailReason = lo.If(responseItem.FailReason != "", responseItem.FailReason).Else(task.FailReason) |
| task.SubmitTime = lo.If(responseItem.SubmitTime != 0, responseItem.SubmitTime).Else(task.SubmitTime) |
| task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime) |
| task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime) |
| if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure { |
| logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason) |
| task.Progress = "100%" |
| RefundTaskQuota(ctx, task, task.FailReason) |
| } |
| if responseItem.Status == model.TaskStatusSuccess { |
| task.Progress = "100%" |
| } |
| task.Data = responseItem.Data |
|
|
| err = task.Update() |
| if err != nil { |
| common.SysLog("UpdateSunoTask task error: " + err.Error()) |
| } |
| } |
| return nil |
| } |
|
|
| |
| func taskNeedsUpdate(oldTask *model.Task, newTask dto.SunoDataResponse) bool { |
| if oldTask.SubmitTime != newTask.SubmitTime { |
| return true |
| } |
| if oldTask.StartTime != newTask.StartTime { |
| return true |
| } |
| if oldTask.FinishTime != newTask.FinishTime { |
| return true |
| } |
| if string(oldTask.Status) != newTask.Status { |
| return true |
| } |
| if oldTask.FailReason != newTask.FailReason { |
| return true |
| } |
|
|
| if (oldTask.Status == model.TaskStatusFailure || oldTask.Status == model.TaskStatusSuccess) && oldTask.Progress != "100%" { |
| return true |
| } |
|
|
| oldData, _ := common.Marshal(oldTask.Data) |
| newData, _ := common.Marshal(newTask.Data) |
|
|
| sort.Slice(oldData, func(i, j int) bool { |
| return oldData[i] < oldData[j] |
| }) |
| sort.Slice(newData, func(i, j int) bool { |
| return newData[i] < newData[j] |
| }) |
|
|
| if string(oldData) != string(newData) { |
| return true |
| } |
| return false |
| } |
|
|
| |
| func UpdateVideoTasks(ctx context.Context, platform constant.TaskPlatform, taskChannelM map[int][]string, taskM map[string]*model.Task) error { |
| for channelId, taskIds := range taskChannelM { |
| if err := updateVideoTasks(ctx, platform, channelId, taskIds, taskM); err != nil { |
| logger.LogError(ctx, fmt.Sprintf("Channel #%d failed to update video async tasks: %s", channelId, err.Error())) |
| } |
| } |
| return nil |
| } |
|
|
| func updateVideoTasks(ctx context.Context, platform constant.TaskPlatform, channelId int, taskIds []string, taskM map[string]*model.Task) error { |
| logger.LogInfo(ctx, fmt.Sprintf("Channel #%d pending video tasks: %d", channelId, len(taskIds))) |
| if len(taskIds) == 0 { |
| return nil |
| } |
| cacheGetChannel, err := model.CacheGetChannel(channelId) |
| if err != nil { |
| |
| var failedIDs []int64 |
| for _, upstreamID := range taskIds { |
| if t, ok := taskM[upstreamID]; ok { |
| failedIDs = append(failedIDs, t.ID) |
| } |
| } |
| errUpdate := model.TaskBulkUpdateByID(failedIDs, map[string]any{ |
| "fail_reason": fmt.Sprintf("Failed to get channel info, channel ID: %d", channelId), |
| "status": "FAILURE", |
| "progress": "100%", |
| }) |
| if errUpdate != nil { |
| common.SysLog(fmt.Sprintf("UpdateVideoTask error: %v", errUpdate)) |
| } |
| return fmt.Errorf("CacheGetChannel failed: %w", err) |
| } |
| adaptor := GetTaskAdaptorFunc(platform) |
| if adaptor == nil { |
| return fmt.Errorf("video adaptor not found") |
| } |
| info := &relaycommon.RelayInfo{} |
| info.ChannelMeta = &relaycommon.ChannelMeta{ |
| ChannelBaseUrl: cacheGetChannel.GetBaseURL(), |
| } |
| info.ApiKey = cacheGetChannel.Key |
| adaptor.Init(info) |
| for _, taskId := range taskIds { |
| if err := updateVideoSingleTask(ctx, adaptor, cacheGetChannel, taskId, taskM); err != nil { |
| logger.LogError(ctx, fmt.Sprintf("Failed to update video task %s: %s", taskId, err.Error())) |
| } |
| |
| time.Sleep(1 * time.Second) |
| } |
| return nil |
| } |
|
|
| func updateVideoSingleTask(ctx context.Context, adaptor TaskPollingAdaptor, ch *model.Channel, taskId string, taskM map[string]*model.Task) error { |
| baseURL := constant.ChannelBaseURLs[ch.Type] |
| if ch.GetBaseURL() != "" { |
| baseURL = ch.GetBaseURL() |
| } |
| proxy := ch.GetSetting().Proxy |
|
|
| task := taskM[taskId] |
| if task == nil { |
| logger.LogError(ctx, fmt.Sprintf("Task %s not found in taskM", taskId)) |
| return fmt.Errorf("task %s not found", taskId) |
| } |
| key := ch.Key |
|
|
| privateData := task.PrivateData |
| if privateData.Key != "" { |
| key = privateData.Key |
| } |
| resp, err := adaptor.FetchTask(baseURL, key, map[string]any{ |
| "task_id": task.GetUpstreamTaskID(), |
| "action": task.Action, |
| }, proxy) |
| if err != nil { |
| return fmt.Errorf("fetchTask failed for task %s: %w", taskId, err) |
| } |
| defer resp.Body.Close() |
| responseBody, err := io.ReadAll(resp.Body) |
| if err != nil { |
| return fmt.Errorf("readAll failed for task %s: %w", taskId, err) |
| } |
|
|
| logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask response: %s", string(responseBody))) |
|
|
| snap := task.Snapshot() |
|
|
| taskResult := &relaycommon.TaskInfo{} |
| |
| var responseItems dto.TaskResponse[model.Task] |
| if err = common.Unmarshal(responseBody, &responseItems); err == nil && responseItems.IsSuccess() { |
| logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask parsed as new api response format: %+v", responseItems)) |
| t := responseItems.Data |
| taskResult.TaskID = t.TaskID |
| taskResult.Status = string(t.Status) |
| taskResult.Url = t.GetResultURL() |
| taskResult.Progress = t.Progress |
| taskResult.Reason = t.FailReason |
| task.Data = t.Data |
| } else if taskResult, err = adaptor.ParseTaskResult(responseBody); err != nil { |
| return fmt.Errorf("parseTaskResult failed for task %s: %w", taskId, err) |
| } |
|
|
| task.Data = redactVideoResponseBody(responseBody) |
|
|
| logger.LogDebug(ctx, fmt.Sprintf("updateVideoSingleTask taskResult: %+v", taskResult)) |
|
|
| now := time.Now().Unix() |
| if taskResult.Status == "" { |
| |
| errorResult := &dto.GeneralErrorResponse{} |
| if err = common.Unmarshal(responseBody, &errorResult); err == nil { |
| openaiError := errorResult.TryToOpenAIError() |
| if openaiError != nil { |
| |
| if openaiError.Code == "429" { |
| |
| return nil |
| } |
|
|
| |
| taskResult = relaycommon.FailTaskInfo("upstream returned error") |
| } else { |
| |
| logger.LogError(ctx, fmt.Sprintf("Task %s returned empty status with unrecognized error format, response: %s", taskId, string(responseBody))) |
| taskResult = relaycommon.FailTaskInfo("upstream returned unrecognized message") |
| } |
| } |
| } |
|
|
| shouldRefund := false |
| shouldSettle := false |
| quota := task.Quota |
|
|
| task.Status = model.TaskStatus(taskResult.Status) |
| switch taskResult.Status { |
| case model.TaskStatusSubmitted: |
| task.Progress = taskcommon.ProgressSubmitted |
| case model.TaskStatusQueued: |
| task.Progress = taskcommon.ProgressQueued |
| case model.TaskStatusInProgress: |
| task.Progress = taskcommon.ProgressInProgress |
| if task.StartTime == 0 { |
| task.StartTime = now |
| } |
| case model.TaskStatusSuccess: |
| task.Progress = taskcommon.ProgressComplete |
| if task.FinishTime == 0 { |
| task.FinishTime = now |
| } |
| if strings.HasPrefix(taskResult.Url, "data:") { |
| |
| task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) |
| } else if taskResult.Url != "" { |
| |
| task.PrivateData.ResultURL = taskResult.Url |
| } else { |
| |
| task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) |
| } |
| shouldSettle = true |
| case model.TaskStatusFailure: |
| logger.LogJson(ctx, fmt.Sprintf("Task %s failed", taskId), task) |
| task.Status = model.TaskStatusFailure |
| task.Progress = taskcommon.ProgressComplete |
| if task.FinishTime == 0 { |
| task.FinishTime = now |
| } |
| task.FailReason = taskResult.Reason |
| logger.LogInfo(ctx, fmt.Sprintf("Task %s failed: %s", task.TaskID, task.FailReason)) |
| taskResult.Progress = taskcommon.ProgressComplete |
| if quota != 0 { |
| shouldRefund = true |
| } |
| default: |
| return fmt.Errorf("unknown task status %s for task %s", taskResult.Status, task.TaskID) |
| } |
| if taskResult.Progress != "" { |
| task.Progress = taskResult.Progress |
| } |
|
|
| isDone := task.Status == model.TaskStatusSuccess || task.Status == model.TaskStatusFailure |
| if isDone && snap.Status != task.Status { |
| won, err := task.UpdateWithStatus(snap.Status) |
| if err != nil { |
| logger.LogError(ctx, fmt.Sprintf("UpdateWithStatus failed for task %s: %s", task.TaskID, err.Error())) |
| shouldRefund = false |
| shouldSettle = false |
| } else if !won { |
| logger.LogWarn(ctx, fmt.Sprintf("Task %s already transitioned by another process, skip billing", task.TaskID)) |
| shouldRefund = false |
| shouldSettle = false |
| } |
| } else if !snap.Equal(task.Snapshot()) { |
| if _, err := task.UpdateWithStatus(snap.Status); err != nil { |
| logger.LogError(ctx, fmt.Sprintf("Failed to update task %s: %s", task.TaskID, err.Error())) |
| } |
| } else { |
| |
| logger.LogDebug(ctx, fmt.Sprintf("No update needed for task %s", task.TaskID)) |
| } |
|
|
| if shouldSettle { |
| settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) |
| } |
| if shouldRefund { |
| RefundTaskQuota(ctx, task, task.FailReason) |
| } |
|
|
| return nil |
| } |
|
|
| func redactVideoResponseBody(body []byte) []byte { |
| var m map[string]any |
| if err := common.Unmarshal(body, &m); err != nil { |
| return body |
| } |
| resp, _ := m["response"].(map[string]any) |
| if resp != nil { |
| delete(resp, "bytesBase64Encoded") |
| if v, ok := resp["video"].(string); ok { |
| resp["video"] = truncateBase64(v) |
| } |
| if vs, ok := resp["videos"].([]any); ok { |
| for i := range vs { |
| if vm, ok := vs[i].(map[string]any); ok { |
| delete(vm, "bytesBase64Encoded") |
| } |
| } |
| } |
| } |
| b, err := common.Marshal(m) |
| if err != nil { |
| return body |
| } |
| return b |
| } |
|
|
// truncateBase64 returns s unchanged when it is at most 256 bytes long;
// otherwise it returns the first 256 bytes followed by "...". The cut is
// byte-based, which is safe for base64 text because it is pure ASCII.
func truncateBase64(s string) string {
	const limit = 256
	if len(s) > limit {
		return s[:limit] + "..."
	}
	return s
}
|
|
| |
| |
| |
| |
| |
| func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor, task *model.Task, taskResult *relaycommon.TaskInfo) { |
| |
| if bc := task.PrivateData.BillingContext; bc != nil && bc.PerCallBilling { |
| logger.LogInfo(ctx, fmt.Sprintf("任务 %s 按次计费,跳过差额结算", task.TaskID)) |
| return |
| } |
| |
| if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { |
| RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") |
| return |
| } |
| |
| if taskResult.TotalTokens > 0 { |
| RecalculateTaskQuotaByTokens(ctx, task, taskResult.TotalTokens) |
| return |
| } |
| |
| } |
|
|