| package middleware |
|
|
| import ( |
| "errors" |
| "fmt" |
| "net/http" |
| "slices" |
| "strconv" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/i18n" |
| "github.com/QuantumNous/new-api/model" |
| relayconstant "github.com/QuantumNous/new-api/relay/constant" |
| "github.com/QuantumNous/new-api/service" |
| "github.com/QuantumNous/new-api/setting/ratio_setting" |
| "github.com/QuantumNous/new-api/types" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
// ModelRequest is the minimal request-body shape used to extract the
// requested model name and, for playground requests, an explicit group.
type ModelRequest struct {
	Model string `json:"model"`
	Group string `json:"group,omitempty"`
}
|
|
// Distribute is the middleware that selects an upstream channel for the
// incoming request and stores the selection into the gin context.
//
// Flow:
//  1. Parse the requested model (and optional group) via getModelRequest.
//  2. If the token pins a specific channel id, load and validate it directly.
//  3. Otherwise enforce the token's model allow-list, resolve the group
//     (including a playground group override), try the affinity-preferred
//     channel first, and fall back to a random satisfied channel.
//  4. Populate the context for the selected channel, run the handler chain,
//     and record channel affinity after a successful response.
func Distribute() func(c *gin.Context) {
	return func(c *gin.Context) {
		var channel *model.Channel
		// A token may be bound to one specific channel; if so, skip selection.
		channelId, ok := common.GetContextKey(c, constant.ContextKeyTokenSpecificChannelId)
		modelRequest, shouldSelectChannel, err := getModelRequest(c)
		if err != nil {
			abortWithOpenAiMessage(c, http.StatusBadRequest, i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()}))
			return
		}
		if ok {
			// Token-pinned channel: it must parse, exist, and be enabled.
			id, err := strconv.Atoi(channelId.(string))
			if err != nil {
				abortWithOpenAiMessage(c, http.StatusBadRequest, i18n.T(c, i18n.MsgDistributorInvalidChannelId))
				return
			}
			channel, err = model.GetChannelById(id, true)
			if err != nil {
				abortWithOpenAiMessage(c, http.StatusBadRequest, i18n.T(c, i18n.MsgDistributorInvalidChannelId))
				return
			}
			if channel.Status != common.ChannelStatusEnabled {
				abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorChannelDisabled))
				return
			}
		} else {
			// Enforce the token's per-model allow-list, if enabled.
			modelLimitEnable := common.GetContextKeyBool(c, constant.ContextKeyTokenModelLimitEnabled)
			if modelLimitEnable {
				s, ok := common.GetContextKey(c, constant.ContextKeyTokenModelLimit)
				if !ok {
					// Limit flag set but no list present: deny access outright.
					abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorTokenNoModelAccess))
					return
				}
				var tokenModelLimit map[string]bool
				tokenModelLimit, ok = s.(map[string]bool)
				if !ok {
					// Unexpected stored type: treat as an empty allow-list.
					tokenModelLimit = map[string]bool{}
				}
				matchName := ratio_setting.FormatMatchingModelName(modelRequest.Model)
				if _, ok := tokenModelLimit[matchName]; !ok {
					abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorTokenModelForbidden, map[string]any{"Model": modelRequest.Model}))
					return
				}
			}

			if shouldSelectChannel {
				if modelRequest.Model == "" {
					abortWithOpenAiMessage(c, http.StatusBadRequest, i18n.T(c, i18n.MsgDistributorModelNameRequired))
					return
				}
				var selectGroup string
				usingGroup := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
				// Playground requests may carry an explicit group override;
				// it must be one of the user's usable groups (or the current one).
				if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
					playgroundRequest := &dto.PlayGroundRequest{}
					err = common.UnmarshalBodyReusable(c, playgroundRequest)
					if err != nil {
						abortWithOpenAiMessage(c, http.StatusBadRequest, i18n.T(c, i18n.MsgDistributorInvalidPlayground, map[string]any{"Error": err.Error()}))
						return
					}
					if playgroundRequest.Group != "" {
						if !service.GroupInUserUsableGroups(usingGroup, playgroundRequest.Group) && playgroundRequest.Group != usingGroup {
							abortWithOpenAiMessage(c, http.StatusForbidden, i18n.T(c, i18n.MsgDistributorGroupAccessDenied))
							return
						}
						usingGroup = playgroundRequest.Group
						common.SetContextKey(c, constant.ContextKeyUsingGroup, usingGroup)
					}
				}

				// Sticky routing: reuse the affinity-preferred channel when it is
				// still enabled for this group/model combination.
				if preferredChannelID, found := service.GetPreferredChannelByAffinity(c, modelRequest.Model, usingGroup); found {
					preferred, err := model.CacheGetChannel(preferredChannelID)
					if err == nil && preferred != nil && preferred.Status == common.ChannelStatusEnabled {
						if usingGroup == "auto" {
							// "auto" expands to the user's ordered group list; pick the
							// first group for which the preferred channel is usable.
							userGroup := common.GetContextKeyString(c, constant.ContextKeyUserGroup)
							autoGroups := service.GetUserAutoGroup(userGroup)
							for _, g := range autoGroups {
								if model.IsChannelEnabledForGroupModel(g, modelRequest.Model, preferred.Id) {
									selectGroup = g
									common.SetContextKey(c, constant.ContextKeyAutoGroup, g)
									channel = preferred
									service.MarkChannelAffinityUsed(c, g, preferred.Id)
									break
								}
							}
						} else if model.IsChannelEnabledForGroupModel(usingGroup, modelRequest.Model, preferred.Id) {
							channel = preferred
							selectGroup = usingGroup
							service.MarkChannelAffinityUsed(c, usingGroup, preferred.Id)
						}
					}
				}

				// Fall back to random selection when no affinity channel matched.
				if channel == nil {
					channel, selectGroup, err = service.CacheGetRandomSatisfiedChannel(&service.RetryParam{
						Ctx:        c,
						ModelName:  modelRequest.Model,
						TokenGroup: usingGroup,
						Retry:      common.GetPointer(0),
					})
					if err != nil {
						showGroup := usingGroup
						if usingGroup == "auto" {
							// Show which concrete group "auto" resolved to, if any.
							showGroup = fmt.Sprintf("auto(%s)", selectGroup)
						}
						message := i18n.T(c, i18n.MsgDistributorGetChannelFailed, map[string]any{"Group": showGroup, "Model": modelRequest.Model, "Error": err.Error()})
						abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message, types.ErrorCodeModelNotFound)
						return
					}
					if channel == nil {
						abortWithOpenAiMessage(c, http.StatusServiceUnavailable, i18n.T(c, i18n.MsgDistributorNoAvailableChannel, map[string]any{"Group": usingGroup, "Model": modelRequest.Model}), types.ErrorCodeModelNotFound)
						return
					}
				}
			}
		}
		common.SetContextKey(c, constant.ContextKeyRequestStartTime, time.Now())
		SetupContextForSelectedChannel(c, channel, modelRequest.Model)
		c.Next()
		// Record affinity only for responses below 400 (success paths).
		if channel != nil && c.Writer != nil && c.Writer.Status() < http.StatusBadRequest {
			service.RecordChannelAffinity(c, channel.Id)
		}
	}
}
|
|
| |
| |
| |
| |
| |
| func getModelFromRequest(c *gin.Context) (*ModelRequest, error) { |
| var modelRequest ModelRequest |
| err := common.UnmarshalBodyReusable(c, &modelRequest) |
| if err != nil { |
| return nil, errors.New(i18n.T(c, i18n.MsgDistributorInvalidRequest, map[string]any{"Error": err.Error()})) |
| } |
| return &modelRequest, nil |
| } |
|
|
// getModelRequest resolves the requested model name from the incoming request,
// dispatching on the URL path to handle the supported API shapes (Midjourney,
// Suno, video, Gemini, generic OpenAI-compatible JSON bodies, playground).
//
// It returns the parsed ModelRequest, a flag indicating whether a channel
// still needs to be selected (fetch/notify-style endpoints are served without
// an upstream channel), and an error suitable for returning to the client.
//
// NOTE: the branches below are order-dependent; the later generic checks only
// run when the earlier, more specific path matches did not fire.
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
	var modelRequest ModelRequest
	shouldSelectChannel := true
	var err error
	if strings.Contains(c.Request.URL.Path, "/mj/") {
		relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
		if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
			relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
			relayMode == relayconstant.RelayModeMidjourneyNotify ||
			relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
			// Task-query/notify endpoints do not need an upstream channel.
			shouldSelectChannel = false
		} else {
			midjourneyRequest := dto.MidjourneyRequest{}
			err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
			if err != nil {
				return nil, false, errors.New(i18n.T(c, i18n.MsgDistributorInvalidMidjourney, map[string]any{"Error": err.Error()}))
			}
			midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
			if mjErr != nil {
				return nil, false, fmt.Errorf("%s", mjErr.Description)
			}
			if midjourneyModel == "" {
				if !success {
					return nil, false, fmt.Errorf("%s", i18n.T(c, i18n.MsgDistributorInvalidParseModel))
				} else {
					// Parsed successfully but no model is required: skip selection.
					shouldSelectChannel = false
				}
			}
			modelRequest.Model = midjourneyModel
		}
		c.Set("relay_mode", relayMode)
	} else if strings.Contains(c.Request.URL.Path, "/suno/") {
		relayMode := relayconstant.Path2RelaySuno(c.Request.Method, c.Request.URL.Path)
		if relayMode == relayconstant.RelayModeSunoFetch ||
			relayMode == relayconstant.RelayModeSunoFetchByID {
			// Fetch endpoints do not need an upstream channel.
			shouldSelectChannel = false
		} else {
			// Map the task action (path parameter) to a model name.
			modelName := service.CoverTaskActionToModelName(constant.TaskPlatformSuno, c.Param("action"))
			modelRequest.Model = modelName
		}
		c.Set("platform", string(constant.TaskPlatformSuno))
		c.Set("relay_mode", relayMode)
	} else if strings.Contains(c.Request.URL.Path, "/v1/videos/") && strings.HasSuffix(c.Request.URL.Path, "/remix") {
		// Video remix operates on an existing task; no channel selection here.
		relayMode := relayconstant.RelayModeVideoSubmit
		c.Set("relay_mode", relayMode)
		shouldSelectChannel = false
	} else if strings.Contains(c.Request.URL.Path, "/v1/videos") {
		// OpenAI-style video endpoints: POST submits a task (model in body),
		// GET fetches an existing task by id (no channel needed).
		relayMode := relayconstant.RelayModeUnknown
		if c.Request.Method == http.MethodPost {
			relayMode = relayconstant.RelayModeVideoSubmit
			req, err := getModelFromRequest(c)
			if err != nil {
				return nil, false, err
			}
			if req != nil {
				modelRequest.Model = req.Model
			}
		} else if c.Request.Method == http.MethodGet {
			relayMode = relayconstant.RelayModeVideoFetchByID
			shouldSelectChannel = false
		}
		c.Set("relay_mode", relayMode)
	} else if strings.Contains(c.Request.URL.Path, "/v1/video/generations") {
		relayMode := relayconstant.RelayModeUnknown
		if c.Request.Method == http.MethodPost {
			req, err := getModelFromRequest(c)
			if err != nil {
				return nil, false, err
			}
			modelRequest.Model = req.Model
			relayMode = relayconstant.RelayModeVideoSubmit
		} else if c.Request.Method == http.MethodGet {
			relayMode = relayconstant.RelayModeVideoFetchByID
			shouldSelectChannel = false
		}
		// Do not overwrite a relay_mode set earlier in the pipeline.
		if _, ok := c.Get("relay_mode"); !ok {
			c.Set("relay_mode", relayMode)
		}
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") || strings.HasPrefix(c.Request.URL.Path, "/v1/models/") {
		// Gemini-style path: the model name is embedded in the URL, not the body.
		relayMode := relayconstant.RelayModeGemini
		modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
		if modelName != "" {
			modelRequest.Model = modelName
		}
		c.Set("relay_mode", relayMode)
	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.Contains(c.Request.Header.Get("Content-Type"), "multipart/form-data") {
		// Generic JSON-body endpoints: read the model from the request body.
		// Multipart uploads are handled by the audio/image branches below.
		req, err := getModelFromRequest(c)
		if err != nil {
			return nil, false, err
		}
		modelRequest.Model = req.Model
	}
	if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
		// The realtime API passes the model as a query parameter.
		modelRequest.Model = c.Query("model")
	}
	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
		if modelRequest.Model == "" {
			modelRequest.Model = "text-moderation-stable"
		}
	}
	if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
		if modelRequest.Model == "" {
			modelRequest.Model = c.Param("model")
		}
	}
	if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
		modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e")
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
		// Image edits may arrive as a form; pull the model from form fields.
		contentType := c.ContentType()
		if slices.Contains([]string{gin.MIMEPOSTForm, gin.MIMEMultipartPOSTForm}, contentType) {
			req, err := getModelFromRequest(c)
			if err == nil && req.Model != "" {
				modelRequest.Model = req.Model
			}
		}
	}
	if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
		relayMode := relayconstant.RelayModeAudioSpeech
		if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1")
		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
			// Best-effort body parse; fall back to the default model on failure.
			if req, err := getModelFromRequest(c); err == nil && req.Model != "" {
				modelRequest.Model = req.Model
			}
			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
			relayMode = relayconstant.RelayModeAudioTranslation
		} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
			// Best-effort body parse; fall back to the default model on failure.
			if req, err := getModelFromRequest(c); err == nil && req.Model != "" {
				modelRequest.Model = req.Model
			}
			modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1")
			relayMode = relayconstant.RelayModeAudioTranscription
		}
		c.Set("relay_mode", relayMode)
	}
	if strings.HasPrefix(c.Request.URL.Path, "/pg/chat/completions") {
		// Playground requests also carry an explicit group selection.
		req, err := getModelFromRequest(c)
		if err != nil {
			return nil, false, err
		}
		modelRequest.Model = req.Model
		modelRequest.Group = req.Group
		common.SetContextKey(c, constant.ContextKeyTokenGroup, modelRequest.Group)
	}

	// The compact-responses endpoint bills against a suffixed model variant.
	if strings.HasPrefix(c.Request.URL.Path, "/v1/responses/compact") && modelRequest.Model != "" {
		modelRequest.Model = ratio_setting.WithCompactModelSuffix(modelRequest.Model)
	}
	return &modelRequest, shouldSelectChannel, nil
}
|
|
// SetupContextForSelectedChannel copies everything the relay layer needs about
// the selected channel into the gin context: identity, settings, parameter and
// header overrides, the API key (rotated for multi-key channels), base URL,
// and channel-type specific extras carried in channel.Other.
//
// It returns a NewAPIError (marked skip-retry when channel is nil) on
// failure, otherwise nil.
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) *types.NewAPIError {
	c.Set("original_model", modelName)
	if channel == nil {
		return types.NewError(errors.New("channel is nil"), types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry())
	}
	common.SetContextKey(c, constant.ContextKeyChannelId, channel.Id)
	common.SetContextKey(c, constant.ContextKeyChannelName, channel.Name)
	common.SetContextKey(c, constant.ContextKeyChannelType, channel.Type)
	common.SetContextKey(c, constant.ContextKeyChannelCreateTime, channel.CreatedTime)
	common.SetContextKey(c, constant.ContextKeyChannelSetting, channel.GetSetting())
	common.SetContextKey(c, constant.ContextKeyChannelOtherSetting, channel.GetOtherSettings())
	paramOverride := channel.GetParamOverride()
	headerOverride := channel.GetHeaderOverride()
	// Affinity override templates may merge extra parameters on top of the
	// channel's configured parameter override.
	if mergedParam, applied := service.ApplyChannelAffinityOverrideTemplate(c, paramOverride); applied {
		paramOverride = mergedParam
	}
	common.SetContextKey(c, constant.ContextKeyChannelParamOverride, paramOverride)
	common.SetContextKey(c, constant.ContextKeyChannelHeaderOverride, headerOverride)
	if nil != channel.OpenAIOrganization && *channel.OpenAIOrganization != "" {
		common.SetContextKey(c, constant.ContextKeyChannelOrganization, *channel.OpenAIOrganization)
	}
	common.SetContextKey(c, constant.ContextKeyChannelAutoBan, channel.GetAutoBan())
	common.SetContextKey(c, constant.ContextKeyChannelModelMapping, channel.GetModelMapping())
	common.SetContextKey(c, constant.ContextKeyChannelStatusCodeMapping, channel.GetStatusCodeMapping())

	// Pick the next enabled key (rotates across keys for multi-key channels).
	key, index, newAPIError := channel.GetNextEnabledKey()
	if newAPIError != nil {
		return newAPIError
	}
	if channel.ChannelInfo.IsMultiKey {
		common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, true)
		common.SetContextKey(c, constant.ContextKeyChannelMultiKeyIndex, index)
	} else {
		common.SetContextKey(c, constant.ContextKeyChannelIsMultiKey, false)
	}

	common.SetContextKey(c, constant.ContextKeyChannelKey, key)
	common.SetContextKey(c, constant.ContextKeyChannelBaseUrl, channel.GetBaseURL())

	// Default: no system-prompt override until a later stage decides otherwise.
	common.SetContextKey(c, constant.ContextKeySystemPromptOverride, false)

	// channel.Other holds one channel-type-specific extra setting; map it to
	// the context key each adaptor expects.
	switch channel.Type {
	case constant.ChannelTypeAzure:
		c.Set("api_version", channel.Other)
	case constant.ChannelTypeVertexAi:
		c.Set("region", channel.Other)
	case constant.ChannelTypeXunfei:
		c.Set("api_version", channel.Other)
	case constant.ChannelTypeGemini:
		c.Set("api_version", channel.Other)
	case constant.ChannelTypeAli:
		c.Set("plugin", channel.Other)
	case constant.ChannelCloudflare:
		c.Set("api_version", channel.Other)
	case constant.ChannelTypeMokaAI:
		c.Set("api_version", channel.Other)
	case constant.ChannelTypeCoze:
		c.Set("bot_id", channel.Other)
	}
	return nil
}
|
|
| |
| |
| |
// extractModelNameFromGeminiPath extracts the model name from a Gemini-style
// URL path such as "/v1beta/models/gemini-pro:generateContent".
//
// The model name is the segment after the first "/models/" and before an
// optional ":action" suffix. It returns "" when no "/models/" segment is
// present or when nothing follows it.
func extractModelNameFromGeminiPath(path string) string {
	// Split once at the first "/models/"; everything after is the candidate.
	_, rest, found := strings.Cut(path, "/models/")
	if !found || rest == "" {
		return ""
	}
	// Strip the ":action" suffix (e.g. ":generateContent") if present;
	// Cut returns the whole string as the prefix when there is no colon.
	modelName, _, _ := strings.Cut(rest, ":")
	return modelName
}
|
|