| package relay |
|
|
| import ( |
| "bytes" |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "strconv" |
| "strings" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/model" |
| "github.com/QuantumNous/new-api/relay/channel" |
| "github.com/QuantumNous/new-api/relay/channel/task/taskcommon" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| relayconstant "github.com/QuantumNous/new-api/relay/constant" |
| "github.com/QuantumNous/new-api/relay/helper" |
| "github.com/QuantumNous/new-api/service" |
| "github.com/gin-gonic/gin" |
| ) |
|
|
| type TaskSubmitResult struct { |
| UpstreamTaskID string |
| TaskData []byte |
| Platform constant.TaskPlatform |
| Quota int |
| |
| } |
|
|
| |
| |
| |
| |
| |
| func ResolveOriginTask(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError { |
| |
| path := c.Request.URL.Path |
| if strings.Contains(path, "/v1/videos/") && strings.HasSuffix(path, "/remix") { |
| info.Action = constant.TaskActionRemix |
| } |
|
|
| |
| if info.Action == constant.TaskActionRemix { |
| videoID := c.Param("video_id") |
| if strings.TrimSpace(videoID) == "" { |
| return service.TaskErrorWrapperLocal(fmt.Errorf("video_id is required"), "invalid_request", http.StatusBadRequest) |
| } |
| info.OriginTaskID = videoID |
| } |
|
|
| if info.OriginTaskID == "" { |
| return nil |
| } |
|
|
| |
| originTask, exist, err := model.GetByTaskId(info.UserId, info.OriginTaskID) |
| if err != nil { |
| return service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError) |
| } |
| if !exist { |
| return service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest) |
| } |
|
|
| |
| if info.OriginModelName == "" { |
| if originTask.Properties.OriginModelName != "" { |
| info.OriginModelName = originTask.Properties.OriginModelName |
| } else if originTask.Properties.UpstreamModelName != "" { |
| info.OriginModelName = originTask.Properties.UpstreamModelName |
| } else { |
| var taskData map[string]interface{} |
| _ = common.Unmarshal(originTask.Data, &taskData) |
| if m, ok := taskData["model"].(string); ok && m != "" { |
| info.OriginModelName = m |
| } |
| } |
| } |
|
|
| |
| ch, err := model.GetChannelById(originTask.ChannelId, true) |
| if err != nil { |
| return service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest) |
| } |
| if ch.Status != common.ChannelStatusEnabled { |
| return service.TaskErrorWrapperLocal(errors.New("the channel of the origin task is disabled"), "task_channel_disable", http.StatusBadRequest) |
| } |
| info.LockedChannel = ch |
|
|
| if originTask.ChannelId != info.ChannelId { |
| key, _, newAPIError := ch.GetNextEnabledKey() |
| if newAPIError != nil { |
| return service.TaskErrorWrapper(newAPIError, "channel_no_available_key", newAPIError.StatusCode) |
| } |
| common.SetContextKey(c, constant.ContextKeyChannelKey, key) |
| common.SetContextKey(c, constant.ContextKeyChannelType, ch.Type) |
| common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, ch.GetBaseURL()) |
| common.SetContextKey(c, constant.ContextKeyChannelId, originTask.ChannelId) |
|
|
| info.ChannelBaseUrl = ch.GetBaseURL() |
| info.ChannelId = originTask.ChannelId |
| info.ChannelType = ch.Type |
| info.ApiKey = key |
| } |
|
|
| |
| if info.Action == constant.TaskActionRemix { |
| if originTask.PrivateData.BillingContext != nil { |
| |
| for s, f := range originTask.PrivateData.BillingContext.OtherRatios { |
| info.PriceData.AddOtherRatio(s, f) |
| } |
| } else { |
| |
| var taskData map[string]interface{} |
| _ = common.Unmarshal(originTask.Data, &taskData) |
| secondsStr, _ := taskData["seconds"].(string) |
| seconds, _ := strconv.Atoi(secondsStr) |
| if seconds <= 0 { |
| seconds = 4 |
| } |
| sizeStr, _ := taskData["size"].(string) |
| if info.PriceData.OtherRatios == nil { |
| info.PriceData.OtherRatios = map[string]float64{} |
| } |
| info.PriceData.OtherRatios["seconds"] = float64(seconds) |
| info.PriceData.OtherRatios["size"] = 1 |
| if sizeStr == "1792x1024" || sizeStr == "1024x1792" { |
| info.PriceData.OtherRatios["size"] = 1.666667 |
| } |
| } |
| } |
|
|
| return nil |
| } |
|
|
| |
| |
| |
| |
| |
| func RelayTaskSubmit(c *gin.Context, info *relaycommon.RelayInfo) (*TaskSubmitResult, *dto.TaskError) { |
| info.InitChannelMeta(c) |
|
|
| |
| platform := constant.TaskPlatform(c.GetString("platform")) |
| if platform == "" { |
| platform = GetTaskPlatform(c) |
| } |
| adaptor := GetTaskAdaptor(platform) |
| if adaptor == nil { |
| return nil, service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest) |
| } |
| adaptor.Init(info) |
| if taskErr := adaptor.ValidateRequestAndSetAction(c, info); taskErr != nil { |
| return nil, taskErr |
| } |
|
|
| |
| modelName := info.OriginModelName |
| if modelName == "" { |
| modelName = service.CoverTaskActionToModelName(platform, info.Action) |
| } |
|
|
| |
| info.OriginModelName = modelName |
| info.UpstreamModelName = modelName |
| if err := helper.ModelMappedHelper(c, info, nil); err != nil { |
| return nil, service.TaskErrorWrapperLocal(err, "model_mapping_failed", http.StatusBadRequest) |
| } |
|
|
| |
| if info.PublicTaskID == "" { |
| info.PublicTaskID = model.GenerateTaskID() |
| } |
|
|
| |
| info.OriginModelName = modelName |
| priceData, err := helper.ModelPriceHelperPerCall(c, info) |
| if err != nil { |
| return nil, service.TaskErrorWrapper(err, "model_price_error", http.StatusBadRequest) |
| } |
| info.PriceData = priceData |
|
|
| |
| |
| |
| if estimatedRatios := adaptor.EstimateBilling(c, info); len(estimatedRatios) > 0 { |
| for k, v := range estimatedRatios { |
| info.PriceData.AddOtherRatio(k, v) |
| } |
| } |
|
|
| |
| if !common.StringsContains(constant.TaskPricePatches, modelName) { |
| for _, ra := range info.PriceData.OtherRatios { |
| if ra != 1.0 { |
| info.PriceData.Quota = int(float64(info.PriceData.Quota) * ra) |
| } |
| } |
| } |
|
|
| |
| if info.Billing == nil && !info.PriceData.FreeModel { |
| info.ForcePreConsume = true |
| if apiErr := service.PreConsumeBilling(c, info.PriceData.Quota, info); apiErr != nil { |
| return nil, service.TaskErrorFromAPIError(apiErr) |
| } |
| } |
|
|
| |
| requestBody, err := adaptor.BuildRequestBody(c, info) |
| if err != nil { |
| return nil, service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError) |
| } |
|
|
| |
| resp, err := adaptor.DoRequest(c, info, requestBody) |
| if err != nil { |
| return nil, service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) |
| } |
| if resp != nil && resp.StatusCode != http.StatusOK { |
| responseBody, _ := io.ReadAll(resp.Body) |
| return nil, service.TaskErrorWrapper(fmt.Errorf("%s", string(responseBody)), "fail_to_fetch_task", resp.StatusCode) |
| } |
|
|
| |
| otherRatios := info.PriceData.OtherRatios |
| if otherRatios == nil { |
| otherRatios = map[string]float64{} |
| } |
| ratiosJSON, _ := common.Marshal(otherRatios) |
| c.Header("X-New-Api-Other-Ratios", string(ratiosJSON)) |
|
|
| |
| upstreamTaskID, taskData, taskErr := adaptor.DoResponse(c, resp, info) |
| if taskErr != nil { |
| return nil, taskErr |
| } |
|
|
| |
| finalQuota := info.PriceData.Quota |
| if adjustedRatios := adaptor.AdjustBillingOnSubmit(info, taskData); len(adjustedRatios) > 0 { |
| |
| finalQuota = recalcQuotaFromRatios(info, adjustedRatios) |
| info.PriceData.OtherRatios = adjustedRatios |
| info.PriceData.Quota = finalQuota |
| } |
|
|
| return &TaskSubmitResult{ |
| UpstreamTaskID: upstreamTaskID, |
| TaskData: taskData, |
| Platform: platform, |
| Quota: finalQuota, |
| }, nil |
| } |
|
|
| |
| |
| func recalcQuotaFromRatios(info *relaycommon.RelayInfo, ratios map[string]float64) int { |
| |
| baseQuota := info.PriceData.Quota |
| |
| for _, ra := range info.PriceData.OtherRatios { |
| if ra != 1.0 && ra > 0 { |
| baseQuota = int(float64(baseQuota) / ra) |
| } |
| } |
| |
| result := float64(baseQuota) |
| for _, ra := range ratios { |
| if ra != 1.0 { |
| result *= ra |
| } |
| } |
| return int(result) |
| } |
|
|
| var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){ |
| relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder, |
| relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder, |
| relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder, |
| } |
|
|
| func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) { |
| respBuilder, ok := fetchRespBuilders[relayMode] |
| if !ok { |
| taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest) |
| } |
|
|
| respBody, taskErr := respBuilder(c) |
| if taskErr != nil { |
| return taskErr |
| } |
| if len(respBody) == 0 { |
| respBody = []byte("{\"code\":\"success\",\"data\":null}") |
| } |
|
|
| c.Writer.Header().Set("Content-Type", "application/json") |
| _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)) |
| if err != nil { |
| taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) |
| return |
| } |
| return |
| } |
|
|
| func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { |
| userId := c.GetInt("id") |
| var condition = struct { |
| IDs []any `json:"ids"` |
| Action string `json:"action"` |
| }{} |
| err := c.BindJSON(&condition) |
| if err != nil { |
| taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest) |
| return |
| } |
| var tasks []any |
| if len(condition.IDs) > 0 { |
| taskModels, err := model.GetByTaskIds(userId, condition.IDs) |
| if err != nil { |
| taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError) |
| return |
| } |
| for _, task := range taskModels { |
| tasks = append(tasks, TaskModel2Dto(task)) |
| } |
| } else { |
| tasks = make([]any, 0) |
| } |
| respBody, err = common.Marshal(dto.TaskResponse[[]any]{ |
| Code: "success", |
| Data: tasks, |
| }) |
| return |
| } |
|
|
| func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { |
| taskId := c.Param("id") |
| userId := c.GetInt("id") |
|
|
| originTask, exist, err := model.GetByTaskId(userId, taskId) |
| if err != nil { |
| taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) |
| return |
| } |
| if !exist { |
| taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) |
| return |
| } |
|
|
| respBody, err = common.Marshal(dto.TaskResponse[any]{ |
| Code: "success", |
| Data: TaskModel2Dto(originTask), |
| }) |
| return |
| } |
|
|
| func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) { |
| taskId := c.Param("task_id") |
| if taskId == "" { |
| taskId = c.GetString("task_id") |
| } |
| userId := c.GetInt("id") |
|
|
| originTask, exist, err := model.GetByTaskId(userId, taskId) |
| if err != nil { |
| taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError) |
| return |
| } |
| if !exist { |
| taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest) |
| return |
| } |
|
|
| isOpenAIVideoAPI := strings.HasPrefix(c.Request.RequestURI, "/v1/videos/") |
|
|
| |
| if realtimeResp := tryRealtimeFetch(originTask, isOpenAIVideoAPI); len(realtimeResp) > 0 { |
| respBody = realtimeResp |
| return |
| } |
|
|
| |
| if isOpenAIVideoAPI { |
| adaptor := GetTaskAdaptor(originTask.Platform) |
| if adaptor == nil { |
| taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("invalid channel id: %d", originTask.ChannelId), "invalid_channel_id", http.StatusBadRequest) |
| return |
| } |
| if converter, ok := adaptor.(channel.OpenAIVideoConverter); ok { |
| openAIVideoData, err := converter.ConvertToOpenAIVideo(originTask) |
| if err != nil { |
| taskResp = service.TaskErrorWrapper(err, "convert_to_openai_video_failed", http.StatusInternalServerError) |
| return |
| } |
| respBody = openAIVideoData |
| return |
| } |
| taskResp = service.TaskErrorWrapperLocal(fmt.Errorf("not_implemented:%s", originTask.Platform), "not_implemented", http.StatusNotImplemented) |
| return |
| } |
|
|
| |
| respBody, err = common.Marshal(dto.TaskResponse[any]{ |
| Code: "success", |
| Data: TaskModel2Dto(originTask), |
| }) |
| if err != nil { |
| taskResp = service.TaskErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError) |
| } |
| return |
| } |
|
|
| |
| |
| |
| func tryRealtimeFetch(task *model.Task, isOpenAIVideoAPI bool) []byte { |
| channelModel, err := model.GetChannelById(task.ChannelId, true) |
| if err != nil { |
| return nil |
| } |
| if channelModel.Type != constant.ChannelTypeVertexAi && channelModel.Type != constant.ChannelTypeGemini { |
| return nil |
| } |
|
|
| baseURL := constant.ChannelBaseURLs[channelModel.Type] |
| if channelModel.GetBaseURL() != "" { |
| baseURL = channelModel.GetBaseURL() |
| } |
| proxy := channelModel.GetSetting().Proxy |
| adaptor := GetTaskAdaptor(constant.TaskPlatform(strconv.Itoa(channelModel.Type))) |
| if adaptor == nil { |
| return nil |
| } |
|
|
| resp, err := adaptor.FetchTask(baseURL, channelModel.Key, map[string]any{ |
| "task_id": task.GetUpstreamTaskID(), |
| "action": task.Action, |
| }, proxy) |
| if err != nil || resp == nil { |
| return nil |
| } |
| defer resp.Body.Close() |
| body, err := io.ReadAll(resp.Body) |
| if err != nil { |
| return nil |
| } |
|
|
| ti, err := adaptor.ParseTaskResult(body) |
| if err != nil || ti == nil { |
| return nil |
| } |
|
|
| snap := task.Snapshot() |
|
|
| |
| if ti.Status != "" { |
| task.Status = model.TaskStatus(ti.Status) |
| } |
| if ti.Progress != "" { |
| task.Progress = ti.Progress |
| } |
| if strings.HasPrefix(ti.Url, "data:") { |
| |
| } else if ti.Url != "" { |
| task.PrivateData.ResultURL = ti.Url |
| } else if task.Status == model.TaskStatusSuccess { |
| |
| task.PrivateData.ResultURL = taskcommon.BuildProxyURL(task.TaskID) |
| } |
|
|
| if !snap.Equal(task.Snapshot()) { |
| _, _ = task.UpdateWithStatus(snap.Status) |
| } |
|
|
| |
| if isOpenAIVideoAPI { |
| return nil |
| } |
|
|
| |
| format := detectVideoFormat(body) |
| out := map[string]any{ |
| "error": nil, |
| "format": format, |
| "metadata": nil, |
| "status": mapTaskStatusToSimple(task.Status), |
| "task_id": task.TaskID, |
| "url": task.GetResultURL(), |
| } |
| respBody, _ := common.Marshal(dto.TaskResponse[any]{ |
| Code: "success", |
| Data: out, |
| }) |
| return respBody |
| } |
|
|
| |
| func detectVideoFormat(rawBody []byte) string { |
| var raw map[string]any |
| if err := common.Unmarshal(rawBody, &raw); err != nil { |
| return "mp4" |
| } |
| respObj, ok := raw["response"].(map[string]any) |
| if !ok { |
| return "mp4" |
| } |
| vids, ok := respObj["videos"].([]any) |
| if !ok || len(vids) == 0 { |
| return "mp4" |
| } |
| v0, ok := vids[0].(map[string]any) |
| if !ok { |
| return "mp4" |
| } |
| mt, ok := v0["mimeType"].(string) |
| if !ok || mt == "" || strings.Contains(mt, "mp4") { |
| return "mp4" |
| } |
| return mt |
| } |
|
|
| |
| func mapTaskStatusToSimple(status model.TaskStatus) string { |
| switch status { |
| case model.TaskStatusSuccess: |
| return "succeeded" |
| case model.TaskStatusFailure: |
| return "failed" |
| case model.TaskStatusQueued, model.TaskStatusSubmitted: |
| return "queued" |
| default: |
| return "processing" |
| } |
| } |
|
|
| func TaskModel2Dto(task *model.Task) *dto.TaskDto { |
| return &dto.TaskDto{ |
| ID: task.ID, |
| CreatedAt: task.CreatedAt, |
| UpdatedAt: task.UpdatedAt, |
| TaskID: task.TaskID, |
| Platform: string(task.Platform), |
| UserId: task.UserId, |
| Group: task.Group, |
| ChannelId: task.ChannelId, |
| Quota: task.Quota, |
| Action: task.Action, |
| Status: string(task.Status), |
| FailReason: task.FailReason, |
| ResultURL: task.GetResultURL(), |
| SubmitTime: task.SubmitTime, |
| StartTime: task.StartTime, |
| FinishTime: task.FinishTime, |
| Progress: task.Progress, |
| Properties: task.Properties, |
| Username: task.Username, |
| Data: task.Data, |
| } |
| } |
|
|