| package controller |
|
|
| import ( |
| "bytes" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "io" |
| "math" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "strconv" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/middleware" |
| "github.com/QuantumNous/new-api/model" |
| "github.com/QuantumNous/new-api/relay" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| relayconstant "github.com/QuantumNous/new-api/relay/constant" |
| "github.com/QuantumNous/new-api/relay/helper" |
| "github.com/QuantumNous/new-api/service" |
| "github.com/QuantumNous/new-api/setting/operation_setting" |
| "github.com/QuantumNous/new-api/setting/ratio_setting" |
| "github.com/QuantumNous/new-api/types" |
|
|
| "github.com/bytedance/gopkg/util/gopool" |
| "github.com/samber/lo" |
| "github.com/tidwall/gjson" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
// testResult bundles the outcome of one channel test run: the gin context
// used for the synthetic request (may be nil on very early failures), a
// local error suitable for display, and the structured API error when the
// failure produced one.
type testResult struct {
	context     *gin.Context
	localErr    error
	newAPIError *types.NewAPIError
}
|
|
| func normalizeChannelTestEndpoint(channel *model.Channel, modelName, endpointType string) string { |
| normalized := strings.TrimSpace(endpointType) |
| if normalized != "" { |
| return normalized |
| } |
| if strings.HasSuffix(modelName, ratio_setting.CompactModelSuffix) { |
| return string(constant.EndpointTypeOpenAIResponseCompact) |
| } |
| if channel != nil && channel.Type == constant.ChannelTypeCodex { |
| return string(constant.EndpointTypeOpenAIResponse) |
| } |
| return normalized |
| } |
|
|
// testChannel exercises one channel end-to-end: it fabricates an HTTP
// request for an appropriate endpoint/model, relays it through the
// channel's adaptor, validates the upstream response, and records a
// consumption log entry. No real client is involved; an httptest recorder
// captures everything the pipeline writes.
func testChannel(channel *model.Channel, testModel string, endpointType string, isStream bool) testResult {
	tik := time.Now()
	// Channel families whose protocols have no simple one-shot test request.
	var unsupportedTestChannelTypes = []int{
		constant.ChannelTypeMidjourney,
		constant.ChannelTypeMidjourneyPlus,
		constant.ChannelTypeSunoAPI,
		constant.ChannelTypeKling,
		constant.ChannelTypeJimeng,
		constant.ChannelTypeDoubaoVideo,
		constant.ChannelTypeVidu,
	}
	if lo.Contains(unsupportedTestChannelTypes, channel.Type) {
		channelTypeName := constant.GetChannelTypeName(channel.Type)
		return testResult{
			localErr: fmt.Errorf("%s channel test is not supported", channelTypeName),
		}
	}
	// Fake gin context backed by a recorder; nothing reaches the network
	// until the adaptor performs the upstream request itself.
	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	// Resolve the model to test: explicit argument > channel's configured
	// test model > first model in the channel's model list > fallback.
	testModel = strings.TrimSpace(testModel)
	if testModel == "" {
		if channel.TestModel != nil && *channel.TestModel != "" {
			testModel = strings.TrimSpace(*channel.TestModel)
		} else {
			models := channel.GetModels()
			if len(models) > 0 {
				testModel = strings.TrimSpace(models[0])
			}
			if testModel == "" {
				testModel = "gpt-4o-mini"
			}
		}
	}

	endpointType = normalizeChannelTestEndpoint(channel, testModel, endpointType)

	requestPath := "/v1/chat/completions"

	// An explicit endpoint type dictates the request path directly.
	if endpointType != "" {
		if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
			requestPath = endpointInfo.Path
		}
	} else {
		// Otherwise infer the path from model-name heuristics. Later
		// checks intentionally override earlier ones.
		if strings.Contains(strings.ToLower(testModel), "rerank") {
			requestPath = "/v1/rerank"
		}

		// Embedding-style models (and MokaAI channels) use /v1/embeddings.
		if strings.Contains(strings.ToLower(testModel), "embedding") ||
			strings.HasPrefix(testModel, "m3e") ||
			strings.Contains(testModel, "bge-") ||
			strings.Contains(testModel, "embed") ||
			channel.Type == constant.ChannelTypeMokaAI {
			requestPath = "/v1/embeddings"
		}

		// VolcEngine "seedream" models are image generators.
		if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
			requestPath = "/v1/images/generations"
		}

		// Codex models use the Responses API.
		if strings.Contains(strings.ToLower(testModel), "codex") {
			requestPath = "/v1/responses"
		}

		// Compact-suffixed models use the responses-compaction endpoint.
		if strings.HasSuffix(testModel, ratio_setting.CompactModelSuffix) {
			requestPath = "/v1/responses/compact"
		}
	}
	// Compaction requests must always carry the compact model suffix.
	if strings.HasPrefix(requestPath, "/v1/responses/compact") {
		testModel = ratio_setting.WithCompactModelSuffix(testModel)
	}

	// Synthetic request; Body is filled in just before DoRequest below.
	c.Request = &http.Request{
		Method: "POST",
		URL:    &url.URL{Path: requestPath},
		Body:   nil,
		Header: make(http.Header),
	}

	// Channel tests always run as user id 1.
	cache, err := model.GetUserCache(1)
	if err != nil {
		return testResult{
			localErr:    err,
			newAPIError: nil,
		}
	}
	cache.WriteContext(c)

	// Populate the context the way the normal middleware chain would.
	c.Request.Header.Set("Content-Type", "application/json")
	c.Set("channel", channel.Type)
	c.Set("base_url", channel.GetBaseURL())
	group, _ := model.GetUserGroup(1, false)
	c.Set("group", group)

	newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
	if newAPIError != nil {
		return testResult{
			context:     c,
			localErr:    newAPIError,
			newAPIError: newAPIError,
		}
	}

	// Map the endpoint type (or, lacking one, the inferred request path)
	// to the relay format driving conversion and response handling.
	var relayFormat types.RelayFormat
	if endpointType != "" {
		switch constant.EndpointType(endpointType) {
		case constant.EndpointTypeOpenAI:
			relayFormat = types.RelayFormatOpenAI
		case constant.EndpointTypeOpenAIResponse:
			relayFormat = types.RelayFormatOpenAIResponses
		case constant.EndpointTypeOpenAIResponseCompact:
			relayFormat = types.RelayFormatOpenAIResponsesCompaction
		case constant.EndpointTypeAnthropic:
			relayFormat = types.RelayFormatClaude
		case constant.EndpointTypeGemini:
			relayFormat = types.RelayFormatGemini
		case constant.EndpointTypeJinaRerank:
			relayFormat = types.RelayFormatRerank
		case constant.EndpointTypeImageGeneration:
			relayFormat = types.RelayFormatOpenAIImage
		case constant.EndpointTypeEmbeddings:
			relayFormat = types.RelayFormatEmbedding
		default:
			relayFormat = types.RelayFormatOpenAI
		}
	} else {
		// Path-based fallback; later matches override earlier ones.
		relayFormat = types.RelayFormatOpenAI
		if c.Request.URL.Path == "/v1/embeddings" {
			relayFormat = types.RelayFormatEmbedding
		}
		if c.Request.URL.Path == "/v1/images/generations" {
			relayFormat = types.RelayFormatOpenAIImage
		}
		if c.Request.URL.Path == "/v1/messages" {
			relayFormat = types.RelayFormatClaude
		}
		if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
			relayFormat = types.RelayFormatGemini
		}
		if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
			relayFormat = types.RelayFormatRerank
		}
		if c.Request.URL.Path == "/v1/responses" {
			relayFormat = types.RelayFormatOpenAIResponses
		}
		if strings.HasPrefix(c.Request.URL.Path, "/v1/responses/compact") {
			relayFormat = types.RelayFormatOpenAIResponsesCompaction
		}
	}

	request := buildTestRequest(testModel, endpointType, channel, isStream)

	info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)

	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
		}
	}

	info.IsChannelTest = true
	info.InitChannelMeta(c)

	// Apply any channel-level model mapping before request conversion.
	err = helper.ModelMappedHelper(c, info, request)
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
		}
	}

	// The mapped upstream model name becomes the effective test model.
	testModel = info.UpstreamModelName
	request.SetModelName(testModel)

	apiType, _ := common.ChannelType2APIType(channel.Type)
	// Responses compaction is only implemented for OpenAI/Codex API types.
	if info.RelayMode == relayconstant.RelayModeResponsesCompact &&
		apiType != constant.APITypeOpenAI &&
		apiType != constant.APITypeCodex {
		return testResult{
			context:     c,
			localErr:    fmt.Errorf("responses compaction test only supports openai/codex channels, got api type %d", apiType),
			newAPIError: types.NewError(fmt.Errorf("unsupported api type: %d", apiType), types.ErrorCodeInvalidApiType),
		}
	}
	adaptor := relay.GetAdaptor(apiType)
	if adaptor == nil {
		return testResult{
			context:     c,
			localErr:    fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
			newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
		}
	}

	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))

	priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
		}
	}

	adaptor.Init(info)

	// Convert the generic test request into the adaptor-specific payload
	// for the relay mode in effect. Each case type-asserts the concrete
	// request produced by buildTestRequest.
	var convertedRequest any
	switch info.RelayMode {
	case relayconstant.RelayModeEmbeddings:
		if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
			convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
		} else {
			return testResult{
				context:     c,
				localErr:    errors.New("invalid embedding request type"),
				newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
			}
		}
	case relayconstant.RelayModeImagesGenerations:
		if imageReq, ok := request.(*dto.ImageRequest); ok {
			convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
		} else {
			return testResult{
				context:     c,
				localErr:    errors.New("invalid image request type"),
				newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
			}
		}
	case relayconstant.RelayModeRerank:
		if rerankReq, ok := request.(*dto.RerankRequest); ok {
			convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
		} else {
			return testResult{
				context:     c,
				localErr:    errors.New("invalid rerank request type"),
				newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
			}
		}
	case relayconstant.RelayModeResponses:
		if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
			convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
		} else {
			return testResult{
				context:     c,
				localErr:    errors.New("invalid response request type"),
				newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
			}
		}
	case relayconstant.RelayModeResponsesCompact:
		// Compaction requests are converted via the regular responses
		// converter; a plain responses request is also accepted here.
		switch req := request.(type) {
		case *dto.OpenAIResponsesCompactionRequest:
			convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, dto.OpenAIResponsesRequest{
				Model:              req.Model,
				Input:              req.Input,
				Instructions:       req.Instructions,
				PreviousResponseID: req.PreviousResponseID,
			})
		case *dto.OpenAIResponsesRequest:
			convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *req)
		default:
			return testResult{
				context:     c,
				localErr:    errors.New("invalid response compaction request type"),
				newAPIError: types.NewError(errors.New("invalid response compaction request type"), types.ErrorCodeConvertRequestFailed),
			}
		}
	default:
		// Chat completions and anything else use the general converter.
		if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
			convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
		} else {
			return testResult{
				context:     c,
				localErr:    errors.New("invalid general request type"),
				newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
			}
		}
	}

	// err here is whichever conversion error the switch above produced.
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
		}
	}
	jsonData, err := common.Marshal(convertedRequest)
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
		}
	}

	// Channel-level parameter overrides are applied to the raw JSON body.
	if len(info.ParamOverride) > 0 {
		jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
		if err != nil {
			if fixedErr, ok := relaycommon.AsParamOverrideReturnError(err); ok {
				return testResult{
					context:     c,
					localErr:    fixedErr,
					newAPIError: relaycommon.NewAPIErrorFromParamOverride(fixedErr),
				}
			}
			return testResult{
				context:     c,
				localErr:    err,
				newAPIError: types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid),
			}
		}
	}

	// Two independent readers over the same bytes: one for the adaptor,
	// one on the request so downstream code can re-read the body.
	requestBody := bytes.NewBuffer(jsonData)
	c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
		}
	}
	var httpResp *http.Response
	if resp != nil {
		httpResp = resp.(*http.Response)
		// Non-200 upstream status: run it through the relay error handler
		// and log the details before failing the test.
		if httpResp.StatusCode != http.StatusOK {
			err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
			common.SysError(fmt.Sprintf(
				"channel test bad response: channel_id=%d name=%s type=%d model=%s endpoint_type=%s status=%d err=%v",
				channel.Id,
				channel.Name,
				channel.Type,
				testModel,
				endpointType,
				httpResp.StatusCode,
				err,
			))
			return testResult{
				context:     c,
				localErr:    err,
				newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
			}
		}
	}
	usageA, respErr := adaptor.DoResponse(c, httpResp, info)
	if respErr != nil {
		return testResult{
			context:     c,
			localErr:    respErr,
			newAPIError: respErr,
		}
	}
	// Normalize/validate whatever usage shape the adaptor returned.
	usage, usageErr := coerceTestUsage(usageA, isStream, info.GetEstimatePromptTokens())
	if usageErr != nil {
		return testResult{
			context:     c,
			localErr:    usageErr,
			newAPIError: types.NewOpenAIError(usageErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
		}
	}
	// Inspect what the pipeline wrote to the recorder; some upstreams
	// return errors inside a 200 body (plain JSON or SSE data lines).
	result := w.Result()
	respBody, err := readTestResponseBody(result.Body, isStream)
	if err != nil {
		return testResult{
			context:     c,
			localErr:    err,
			newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
		}
	}
	if bodyErr := detectErrorFromTestResponseBody(respBody); bodyErr != nil {
		return testResult{
			context:     c,
			localErr:    bodyErr,
			newAPIError: types.NewOpenAIError(bodyErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
		}
	}
	info.SetEstimatePromptTokens(usage.PromptTokens)

	// Compute the quota consumed by this test, mirroring normal billing:
	// token-based pricing unless the model uses a fixed per-call price.
	quota := 0
	if !priceData.UsePrice {
		quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
		quota = int(math.Round(float64(quota) * priceData.ModelRatio))
		// Charge at least 1 when a non-zero ratio rounded down to zero.
		if priceData.ModelRatio != 0 && quota <= 0 {
			quota = 1
		}
	} else {
		quota = int(priceData.ModelPrice * common.QuotaPerUnit)
	}
	tok := time.Now()
	milliseconds := tok.Sub(tik).Milliseconds()
	consumedTime := float64(milliseconds) / 1000.0
	other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
		usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
	// Record the consumption log against user id 1 ("模型测试" = model test).
	model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
		ChannelId:        channel.Id,
		PromptTokens:     usage.PromptTokens,
		CompletionTokens: usage.CompletionTokens,
		ModelName:        info.OriginModelName,
		TokenName:        "模型测试",
		Quota:            quota,
		Content:          "模型测试",
		UseTimeSeconds:   int(consumedTime),
		IsStream:         info.IsStream,
		Group:            info.UsingGroup,
		Other:            other,
	})
	common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
	return testResult{
		context:     c,
		localErr:    nil,
		newAPIError: nil,
	}
}
|
|
| func coerceTestUsage(usageAny any, isStream bool, estimatePromptTokens int) (*dto.Usage, error) { |
| switch u := usageAny.(type) { |
| case *dto.Usage: |
| return u, nil |
| case dto.Usage: |
| return &u, nil |
| case nil: |
| if !isStream { |
| return nil, errors.New("usage is nil") |
| } |
| usage := &dto.Usage{ |
| PromptTokens: estimatePromptTokens, |
| } |
| usage.TotalTokens = usage.PromptTokens |
| return usage, nil |
| default: |
| if !isStream { |
| return nil, fmt.Errorf("invalid usage type: %T", usageAny) |
| } |
| usage := &dto.Usage{ |
| PromptTokens: estimatePromptTokens, |
| } |
| usage.TotalTokens = usage.PromptTokens |
| return usage, nil |
| } |
| } |
|
|
| func readTestResponseBody(body io.ReadCloser, isStream bool) ([]byte, error) { |
| defer func() { _ = body.Close() }() |
| const maxStreamLogBytes = 8 << 10 |
| if isStream { |
| return io.ReadAll(io.LimitReader(body, maxStreamLogBytes)) |
| } |
| return io.ReadAll(body) |
| } |
|
|
| func detectErrorFromTestResponseBody(respBody []byte) error { |
| b := bytes.TrimSpace(respBody) |
| if len(b) == 0 { |
| return nil |
| } |
| if message := detectErrorMessageFromJSONBytes(b); message != "" { |
| return fmt.Errorf("upstream error: %s", message) |
| } |
|
|
| for _, line := range bytes.Split(b, []byte{'\n'}) { |
| line = bytes.TrimSpace(line) |
| if len(line) == 0 { |
| continue |
| } |
| if !bytes.HasPrefix(line, []byte("data:")) { |
| continue |
| } |
| payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:"))) |
| if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) { |
| continue |
| } |
| if message := detectErrorMessageFromJSONBytes(payload); message != "" { |
| return fmt.Errorf("upstream error: %s", message) |
| } |
| } |
|
|
| return nil |
| } |
|
|
| func detectErrorMessageFromJSONBytes(jsonBytes []byte) string { |
| if len(jsonBytes) == 0 { |
| return "" |
| } |
| if jsonBytes[0] != '{' && jsonBytes[0] != '[' { |
| return "" |
| } |
| errVal := gjson.GetBytes(jsonBytes, "error") |
| if !errVal.Exists() || errVal.Type == gjson.Null { |
| return "" |
| } |
|
|
| message := gjson.GetBytes(jsonBytes, "error.message").String() |
| if message == "" { |
| message = gjson.GetBytes(jsonBytes, "error.error.message").String() |
| } |
| if message == "" && errVal.Type == gjson.String { |
| message = errVal.String() |
| } |
| if message == "" { |
| message = errVal.Raw |
| } |
| message = strings.TrimSpace(message) |
| if message == "" { |
| return "upstream returned error payload" |
| } |
| return message |
| } |
|
|
| func buildTestRequest(model string, endpointType string, channel *model.Channel, isStream bool) dto.Request { |
| testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`) |
|
|
| |
| if endpointType != "" { |
| switch constant.EndpointType(endpointType) { |
| case constant.EndpointTypeEmbeddings: |
| |
| return &dto.EmbeddingRequest{ |
| Model: model, |
| Input: []any{"hello world"}, |
| } |
| case constant.EndpointTypeImageGeneration: |
| |
| return &dto.ImageRequest{ |
| Model: model, |
| Prompt: "a cute cat", |
| N: lo.ToPtr(uint(1)), |
| Size: "1024x1024", |
| } |
| case constant.EndpointTypeJinaRerank: |
| |
| return &dto.RerankRequest{ |
| Model: model, |
| Query: "What is Deep Learning?", |
| Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."}, |
| TopN: lo.ToPtr(2), |
| } |
| case constant.EndpointTypeOpenAIResponse: |
| |
| return &dto.OpenAIResponsesRequest{ |
| Model: model, |
| Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), |
| Stream: lo.ToPtr(isStream), |
| } |
| case constant.EndpointTypeOpenAIResponseCompact: |
| |
| return &dto.OpenAIResponsesCompactionRequest{ |
| Model: model, |
| Input: testResponsesInput, |
| } |
| case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI: |
| |
| maxTokens := uint(16) |
| if constant.EndpointType(endpointType) == constant.EndpointTypeGemini { |
| maxTokens = 3000 |
| } |
| req := &dto.GeneralOpenAIRequest{ |
| Model: model, |
| Stream: lo.ToPtr(isStream), |
| Messages: []dto.Message{ |
| { |
| Role: "user", |
| Content: "hi", |
| }, |
| }, |
| MaxTokens: lo.ToPtr(maxTokens), |
| } |
| if isStream { |
| req.StreamOptions = &dto.StreamOptions{IncludeUsage: true} |
| } |
| return req |
| } |
| } |
|
|
| |
| if strings.Contains(strings.ToLower(model), "rerank") { |
| return &dto.RerankRequest{ |
| Model: model, |
| Query: "What is Deep Learning?", |
| Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."}, |
| TopN: lo.ToPtr(2), |
| } |
| } |
|
|
| |
| if strings.Contains(strings.ToLower(model), "embedding") || |
| strings.HasPrefix(model, "m3e") || |
| strings.Contains(model, "bge-") { |
| |
| return &dto.EmbeddingRequest{ |
| Model: model, |
| Input: []any{"hello world"}, |
| } |
| } |
|
|
| |
| if strings.HasSuffix(model, ratio_setting.CompactModelSuffix) { |
| return &dto.OpenAIResponsesCompactionRequest{ |
| Model: model, |
| Input: testResponsesInput, |
| } |
| } |
|
|
| |
| if strings.Contains(strings.ToLower(model), "codex") { |
| return &dto.OpenAIResponsesRequest{ |
| Model: model, |
| Input: json.RawMessage(`[{"role":"user","content":"hi"}]`), |
| Stream: lo.ToPtr(isStream), |
| } |
| } |
|
|
| |
| testRequest := &dto.GeneralOpenAIRequest{ |
| Model: model, |
| Stream: lo.ToPtr(isStream), |
| Messages: []dto.Message{ |
| { |
| Role: "user", |
| Content: "hi", |
| }, |
| }, |
| } |
| if isStream { |
| testRequest.StreamOptions = &dto.StreamOptions{IncludeUsage: true} |
| } |
|
|
| if strings.HasPrefix(model, "o") { |
| testRequest.MaxCompletionTokens = lo.ToPtr(uint(16)) |
| } else if strings.Contains(model, "thinking") { |
| if !strings.Contains(model, "claude") { |
| testRequest.MaxTokens = lo.ToPtr(uint(50)) |
| } |
| } else if strings.Contains(model, "gemini") { |
| testRequest.MaxTokens = lo.ToPtr(uint(3000)) |
| } else { |
| testRequest.MaxTokens = lo.ToPtr(uint(16)) |
| } |
|
|
| return testRequest |
| } |
|
|
| func TestChannel(c *gin.Context) { |
| channelId, err := strconv.Atoi(c.Param("id")) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| channel, err := model.CacheGetChannel(channelId) |
| if err != nil { |
| channel, err = model.GetChannelById(channelId, true) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| } |
| |
| |
| |
| |
| |
| testModel := c.Query("model") |
| endpointType := c.Query("endpoint_type") |
| isStream, _ := strconv.ParseBool(c.Query("stream")) |
| tik := time.Now() |
| result := testChannel(channel, testModel, endpointType, isStream) |
| if result.localErr != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": result.localErr.Error(), |
| "time": 0.0, |
| }) |
| return |
| } |
| tok := time.Now() |
| milliseconds := tok.Sub(tik).Milliseconds() |
| go channel.UpdateResponseTime(milliseconds) |
| consumedTime := float64(milliseconds) / 1000.0 |
| if result.newAPIError != nil { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": result.newAPIError.Error(), |
| "time": consumedTime, |
| }) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "time": consumedTime, |
| }) |
| } |
|
|
// testAllChannelsLock guards testAllChannelsRunning, which is true while a
// bulk channel-test sweep is in flight. The redundant "bool = false"
// initializer was dropped (the zero value already is false) and the two
// related declarations were grouped.
var (
	testAllChannelsLock    sync.Mutex
	testAllChannelsRunning bool
)
|
|
| func testAllChannels(notify bool) error { |
|
|
| testAllChannelsLock.Lock() |
| if testAllChannelsRunning { |
| testAllChannelsLock.Unlock() |
| return errors.New("测试已在运行中") |
| } |
| testAllChannelsRunning = true |
| testAllChannelsLock.Unlock() |
| channels, getChannelErr := model.GetAllChannels(0, 0, true, false) |
| if getChannelErr != nil { |
| return getChannelErr |
| } |
| var disableThreshold = int64(common.ChannelDisableThreshold * 1000) |
| if disableThreshold == 0 { |
| disableThreshold = 10000000 |
| } |
| gopool.Go(func() { |
| |
| defer func() { |
| testAllChannelsLock.Lock() |
| testAllChannelsRunning = false |
| testAllChannelsLock.Unlock() |
| }() |
|
|
| for _, channel := range channels { |
| if channel.Status == common.ChannelStatusManuallyDisabled { |
| continue |
| } |
| isChannelEnabled := channel.Status == common.ChannelStatusEnabled |
| tik := time.Now() |
| result := testChannel(channel, "", "", false) |
| tok := time.Now() |
| milliseconds := tok.Sub(tik).Milliseconds() |
|
|
| shouldBanChannel := false |
| newAPIError := result.newAPIError |
| |
| if newAPIError != nil { |
| shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError) |
| } |
|
|
| |
| if common.AutomaticDisableChannelEnabled && !shouldBanChannel { |
| if milliseconds > disableThreshold { |
| err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0) |
| newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout) |
| shouldBanChannel = true |
| } |
| } |
|
|
| |
| if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { |
| processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError) |
| } |
|
|
| |
| if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) { |
| service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name) |
| } |
|
|
| channel.UpdateResponseTime(milliseconds) |
| time.Sleep(common.RequestInterval) |
| } |
|
|
| if notify { |
| service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") |
| } |
| }) |
| return nil |
| } |
|
|
| func TestAllChannels(c *gin.Context) { |
| err := testAllChannels(true) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| }) |
| } |
|
|
// autoTestChannelsOnce ensures the automatic test loop is started at most once.
var autoTestChannelsOnce sync.Once
|
|
| func AutomaticallyTestChannels() { |
| |
| if !common.IsMasterNode { |
| return |
| } |
| autoTestChannelsOnce.Do(func() { |
| for { |
| if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled { |
| time.Sleep(1 * time.Minute) |
| continue |
| } |
| for { |
| frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes |
| time.Sleep(time.Duration(int(math.Round(frequency))) * time.Minute) |
| common.SysLog(fmt.Sprintf("automatically test channels with interval %f minutes", frequency)) |
| common.SysLog("automatically testing all channels") |
| _ = testAllChannels(false) |
| common.SysLog("automatically channel test finished") |
| if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled { |
| break |
| } |
| } |
| } |
| }) |
| } |
|
|