| package handler |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "errors" |
| "net/http" |
| "net/http/httptest" |
| "strings" |
| "testing" |
| "time" |
|
|
| pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" |
| "github.com/Wei-Shaw/sub2api/internal/server/middleware" |
| "github.com/Wei-Shaw/sub2api/internal/service" |
| coderws "github.com/coder/websocket" |
| "github.com/gin-gonic/gin" |
| "github.com/stretchr/testify/assert" |
| "github.com/stretchr/testify/require" |
| "github.com/tidwall/gjson" |
| "github.com/tidwall/sjson" |
| ) |
|
|
| func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) { |
| tests := []struct { |
| name string |
| errType string |
| message string |
| }{ |
| { |
| name: "包含双引号的消息", |
| errType: "server_error", |
| message: `upstream returned "invalid" response`, |
| }, |
| { |
| name: "包含反斜杠的消息", |
| errType: "server_error", |
| message: `path C:\Users\test\file.txt not found`, |
| }, |
| { |
| name: "包含双引号和反斜杠的消息", |
| errType: "upstream_error", |
| message: `error parsing "key\value": unexpected token`, |
| }, |
| { |
| name: "包含换行符的消息", |
| errType: "server_error", |
| message: "line1\nline2\ttab", |
| }, |
| { |
| name: "普通消息", |
| errType: "upstream_error", |
| message: "Upstream service temporarily unavailable", |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
|
|
| h := &OpenAIGatewayHandler{} |
| h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true) |
|
|
| body := w.Body.String() |
|
|
| |
| assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头") |
| assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾") |
|
|
| |
| lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") |
| require.Len(t, lines, 2, "应有 event 行和 data 行") |
| dataLine := lines[1] |
| require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头") |
| jsonStr := strings.TrimPrefix(dataLine, "data: ") |
|
|
| |
| var parsed map[string]any |
| err := json.Unmarshal([]byte(jsonStr), &parsed) |
| require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr) |
|
|
| |
| errorObj, ok := parsed["error"].(map[string]any) |
| require.True(t, ok, "应包含 error 对象") |
| assert.Equal(t, tt.errType, errorObj["type"]) |
| assert.Equal(t, tt.message, errorObj["message"]) |
| }) |
| } |
| } |
|
|
| func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
|
|
| h := &OpenAIGatewayHandler{} |
| h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false) |
|
|
| |
| assert.Equal(t, http.StatusBadGateway, w.Code) |
|
|
| var parsed map[string]any |
| err := json.Unmarshal(w.Body.Bytes(), &parsed) |
| require.NoError(t, err) |
| errorObj, ok := parsed["error"].(map[string]any) |
| require.True(t, ok) |
| assert.Equal(t, "upstream_error", errorObj["type"]) |
| assert.Equal(t, "test error", errorObj["message"]) |
| } |
|
|
| func TestReadRequestBodyWithPrealloc(t *testing.T) { |
| payload := `{"model":"gpt-5","input":"hello"}` |
| req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload)) |
| req.ContentLength = int64(len(payload)) |
|
|
| body, err := pkghttputil.ReadRequestBodyWithPrealloc(req) |
| require.NoError(t, err) |
| require.Equal(t, payload, string(body)) |
| } |
|
|
| func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) { |
| rec := httptest.NewRecorder() |
| req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8))) |
| req.Body = http.MaxBytesReader(rec, req.Body, 4) |
|
|
| _, err := pkghttputil.ReadRequestBodyWithPrealloc(req) |
| require.Error(t, err) |
| var maxErr *http.MaxBytesError |
| require.ErrorAs(t, err, &maxErr) |
| } |
|
|
| func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
|
|
| h := &OpenAIGatewayHandler{} |
| wrote := h.ensureForwardErrorResponse(c, false) |
|
|
| require.True(t, wrote) |
| require.Equal(t, http.StatusBadGateway, w.Code) |
|
|
| var parsed map[string]any |
| err := json.Unmarshal(w.Body.Bytes(), &parsed) |
| require.NoError(t, err) |
| errorObj, ok := parsed["error"].(map[string]any) |
| require.True(t, ok) |
| assert.Equal(t, "upstream_error", errorObj["type"]) |
| assert.Equal(t, "Upstream request failed", errorObj["message"]) |
| } |
|
|
| func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.String(http.StatusTeapot, "already written") |
|
|
| h := &OpenAIGatewayHandler{} |
| wrote := h.ensureForwardErrorResponse(c, false) |
|
|
| require.False(t, wrote) |
| require.Equal(t, http.StatusTeapot, w.Code) |
| assert.Equal(t, "already written", w.Body.String()) |
| } |
|
|
| func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| t.Run("fallback_written_should_not_downgrade", func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true)) |
| }) |
|
|
| t.Run("context_nil_should_not_downgrade", func(t *testing.T) { |
| require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false)) |
| }) |
|
|
| t.Run("response_not_written_should_not_downgrade", func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) |
| }) |
|
|
| t.Run("response_already_written_should_downgrade", func(t *testing.T) { |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/", nil) |
| c.String(http.StatusForbidden, "already written") |
| require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) |
| }) |
| } |
|
|
| func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) |
|
|
| h := &OpenAIGatewayHandler{} |
| streamStarted := false |
| require.NotPanics(t, func() { |
| func() { |
| defer h.recoverResponsesPanic(c, &streamStarted) |
| panic("test panic") |
| }() |
| }) |
|
|
| require.Equal(t, http.StatusBadGateway, w.Code) |
|
|
| var parsed map[string]any |
| err := json.Unmarshal(w.Body.Bytes(), &parsed) |
| require.NoError(t, err) |
|
|
| errorObj, ok := parsed["error"].(map[string]any) |
| require.True(t, ok) |
| assert.Equal(t, "upstream_error", errorObj["type"]) |
| assert.Equal(t, "Upstream request failed", errorObj["message"]) |
| } |
|
|
| func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) |
|
|
| h := &OpenAIGatewayHandler{} |
| streamStarted := false |
| require.NotPanics(t, func() { |
| func() { |
| defer h.recoverResponsesPanic(c, &streamStarted) |
| }() |
| }) |
|
|
| require.False(t, c.Writer.Written()) |
| assert.Equal(t, "", w.Body.String()) |
| } |
|
|
| func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) |
| c.String(http.StatusTeapot, "already written") |
|
|
| h := &OpenAIGatewayHandler{} |
| streamStarted := false |
| require.NotPanics(t, func() { |
| func() { |
| defer h.recoverResponsesPanic(c, &streamStarted) |
| panic("test panic") |
| }() |
| }) |
|
|
| require.Equal(t, http.StatusTeapot, w.Code) |
| assert.Equal(t, "already written", w.Body.String()) |
| } |
|
|
| func TestOpenAIMissingResponsesDependencies(t *testing.T) { |
| t.Run("nil_handler", func(t *testing.T) { |
| var h *OpenAIGatewayHandler |
| require.Equal(t, []string{"handler"}, h.missingResponsesDependencies()) |
| }) |
|
|
| t.Run("all_dependencies_missing", func(t *testing.T) { |
| h := &OpenAIGatewayHandler{} |
| require.Equal(t, |
| []string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"}, |
| h.missingResponsesDependencies(), |
| ) |
| }) |
|
|
| t.Run("all_dependencies_present", func(t *testing.T) { |
| h := &OpenAIGatewayHandler{ |
| gatewayService: &service.OpenAIGatewayService{}, |
| billingCacheService: &service.BillingCacheService{}, |
| apiKeyService: &service.APIKeyService{}, |
| concurrencyHelper: &ConcurrencyHelper{ |
| concurrencyService: &service.ConcurrencyService{}, |
| }, |
| } |
| require.Empty(t, h.missingResponsesDependencies()) |
| }) |
| } |
|
|
| func TestOpenAIEnsureResponsesDependencies(t *testing.T) { |
| t.Run("missing_dependencies_returns_503", func(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) |
|
|
| h := &OpenAIGatewayHandler{} |
| ok := h.ensureResponsesDependencies(c, nil) |
|
|
| require.False(t, ok) |
| require.Equal(t, http.StatusServiceUnavailable, w.Code) |
| var parsed map[string]any |
| err := json.Unmarshal(w.Body.Bytes(), &parsed) |
| require.NoError(t, err) |
| errorObj, exists := parsed["error"].(map[string]any) |
| require.True(t, exists) |
| assert.Equal(t, "api_error", errorObj["type"]) |
| assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) |
| }) |
|
|
| t.Run("already_written_response_not_overridden", func(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) |
| c.String(http.StatusTeapot, "already written") |
|
|
| h := &OpenAIGatewayHandler{} |
| ok := h.ensureResponsesDependencies(c, nil) |
|
|
| require.False(t, ok) |
| require.Equal(t, http.StatusTeapot, w.Code) |
| assert.Equal(t, "already written", w.Body.String()) |
| }) |
|
|
| t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) |
|
|
| h := &OpenAIGatewayHandler{ |
| gatewayService: &service.OpenAIGatewayService{}, |
| billingCacheService: &service.BillingCacheService{}, |
| apiKeyService: &service.APIKeyService{}, |
| concurrencyHelper: &ConcurrencyHelper{ |
| concurrencyService: &service.ConcurrencyService{}, |
| }, |
| } |
| ok := h.ensureResponsesDependencies(c, nil) |
|
|
| require.True(t, ok) |
| require.False(t, c.Writer.Written()) |
| assert.Equal(t, "", w.Body.String()) |
| }) |
| } |
|
|
| func TestResolveOpenAIForwardDefaultMappedModel(t *testing.T) { |
| t.Run("prefers_explicit_fallback_model", func(t *testing.T) { |
| apiKey := &service.APIKey{ |
| Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, |
| } |
| require.Equal(t, "gpt-5.2", resolveOpenAIForwardDefaultMappedModel(apiKey, " gpt-5.2 ")) |
| }) |
|
|
| t.Run("uses_group_default_on_normal_path", func(t *testing.T) { |
| apiKey := &service.APIKey{ |
| Group: &service.Group{DefaultMappedModel: "gpt-5.4"}, |
| } |
| require.Equal(t, "gpt-5.4", resolveOpenAIForwardDefaultMappedModel(apiKey, "")) |
| }) |
|
|
| t.Run("returns_empty_without_group_default", func(t *testing.T) { |
| require.Empty(t, resolveOpenAIForwardDefaultMappedModel(nil, "")) |
| require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{}, "")) |
| require.Empty(t, resolveOpenAIForwardDefaultMappedModel(&service.APIKey{ |
| Group: &service.Group{}, |
| }, "")) |
| }) |
| } |
|
|
| func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`)) |
| c.Request.Header.Set("Content-Type", "application/json") |
|
|
| groupID := int64(2) |
| c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ |
| ID: 10, |
| GroupID: &groupID, |
| }) |
| c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ |
| UserID: 1, |
| Concurrency: 1, |
| }) |
|
|
| |
| h := &OpenAIGatewayHandler{} |
| require.NotPanics(t, func() { |
| h.Responses(c) |
| }) |
|
|
| require.Equal(t, http.StatusServiceUnavailable, w.Code) |
|
|
| var parsed map[string]any |
| err := json.Unmarshal(w.Body.Bytes(), &parsed) |
| require.NoError(t, err) |
|
|
| errorObj, ok := parsed["error"].(map[string]any) |
| require.True(t, ok) |
| assert.Equal(t, "api_error", errorObj["type"]) |
| assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) |
| } |
|
|
| func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(`{"model":"gpt-5"}`)) |
| c.Request.Header.Set("Content-Type", "application/json") |
|
|
| h := &OpenAIGatewayHandler{} |
| h.Responses(c) |
|
|
| require.Equal(t, http.StatusUnauthorized, w.Code) |
| require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) |
| } |
|
|
| func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( |
| `{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`, |
| )) |
| c.Request.Header.Set("Content-Type", "application/json") |
|
|
| groupID := int64(2) |
| c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ |
| ID: 101, |
| GroupID: &groupID, |
| User: &service.User{ID: 1}, |
| }) |
| c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ |
| UserID: 1, |
| Concurrency: 1, |
| }) |
|
|
| h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) |
| h.Responses(c) |
|
|
| require.Equal(t, http.StatusBadRequest, w.Code) |
| require.Contains(t, w.Body.String(), "previous_response_id must be a response.id") |
| } |
|
|
| func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) |
| c.Request.Header.Set("Upgrade", "websocket") |
| c.Request.Header.Set("Connection", "Upgrade") |
|
|
| h := &OpenAIGatewayHandler{} |
| h.ResponsesWebSocket(c) |
|
|
| require.Equal(t, http.StatusUnauthorized, w.Code) |
| require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) |
| } |
|
|
| func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) |
|
|
| h := &OpenAIGatewayHandler{} |
| h.ResponsesWebSocket(c) |
|
|
| require.Equal(t, http.StatusUpgradeRequired, w.Code) |
| require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c)) |
| } |
|
|
| func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) |
| wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) |
| defer wsServer.Close() |
|
|
| dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) |
| clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) |
| cancelDial() |
| require.NoError(t, err) |
| defer func() { |
| _ = clientConn.CloseNow() |
| }() |
|
|
| writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) |
| err = clientConn.Write(writeCtx, coderws.MessageText, []byte( |
| `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`, |
| )) |
| cancelWrite() |
| require.NoError(t, err) |
|
|
| readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) |
| _, _, err = clientConn.Read(readCtx) |
| cancelRead() |
| require.Error(t, err) |
| var closeErr coderws.CloseError |
| require.ErrorAs(t, err, &closeErr) |
| require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) |
| require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id") |
| } |
|
|
| func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| cache := &concurrencyCacheMock{ |
| acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { |
| return false, errors.New("user slot unavailable") |
| }, |
| } |
| h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache) |
| wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) |
| defer wsServer.Close() |
|
|
| dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) |
| clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) |
| cancelDial() |
| require.NoError(t, err) |
| defer func() { |
| _ = clientConn.CloseNow() |
| }() |
|
|
| writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) |
| err = clientConn.Write(writeCtx, coderws.MessageText, []byte( |
| `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`, |
| )) |
| cancelWrite() |
| require.NoError(t, err) |
|
|
| readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) |
| _, _, err = clientConn.Read(readCtx) |
| cancelRead() |
| require.Error(t, err) |
| var closeErr coderws.CloseError |
| require.ErrorAs(t, err, &closeErr) |
| require.Equal(t, coderws.StatusInternalError, closeErr.Code) |
| require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") |
| } |
|
|
| func TestSetOpenAIClientTransportHTTP(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| setOpenAIClientTransportHTTP(c) |
| require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) |
| } |
|
|
| func TestSetOpenAIClientTransportWS(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
|
|
| setOpenAIClientTransportWS(c) |
| require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) |
| } |
|
|
| |
| func TestOpenAIHandler_GjsonExtraction(t *testing.T) { |
| tests := []struct { |
| name string |
| body string |
| wantModel string |
| wantStream bool |
| }{ |
| {"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true}, |
| {"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false}, |
| {"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false}, |
| {"model 缺失", `{"stream":true}`, "", true}, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| body := []byte(tt.body) |
| modelResult := gjson.GetBytes(body, "model") |
| model := "" |
| if modelResult.Type == gjson.String { |
| model = modelResult.String() |
| } |
| stream := gjson.GetBytes(body, "stream").Bool() |
| require.Equal(t, tt.wantModel, model) |
| require.Equal(t, tt.wantStream, stream) |
| }) |
| } |
| } |
|
|
| |
| func TestOpenAIHandler_GjsonValidation(t *testing.T) { |
| |
| require.False(t, gjson.ValidBytes([]byte(`{invalid json`))) |
|
|
| |
| body := []byte(`{"model":123}`) |
| modelResult := gjson.GetBytes(body, "model") |
| require.True(t, modelResult.Exists()) |
| require.NotEqual(t, gjson.String, modelResult.Type) |
|
|
| |
| body2 := []byte(`{"model":null}`) |
| modelResult2 := gjson.GetBytes(body2, "model") |
| require.True(t, modelResult2.Exists()) |
| require.NotEqual(t, gjson.String, modelResult2.Type) |
|
|
| |
| body3 := []byte(`{"model":"gpt-4","stream":"true"}`) |
| streamResult := gjson.GetBytes(body3, "stream") |
| require.True(t, streamResult.Exists()) |
| require.NotEqual(t, gjson.True, streamResult.Type) |
| require.NotEqual(t, gjson.False, streamResult.Type) |
|
|
| |
| body4 := []byte(`{"model":"gpt-4","stream":1}`) |
| streamResult2 := gjson.GetBytes(body4, "stream") |
| require.True(t, streamResult2.Exists()) |
| require.NotEqual(t, gjson.True, streamResult2.Type) |
| require.NotEqual(t, gjson.False, streamResult2.Type) |
| } |
|
|
| |
| func TestOpenAIHandler_InstructionsInjection(t *testing.T) { |
| |
| body := []byte(`{"model":"gpt-4"}`) |
| existing := gjson.GetBytes(body, "instructions").String() |
| require.Empty(t, existing) |
| newBody, err := sjson.SetBytes(body, "instructions", "test instruction") |
| require.NoError(t, err) |
| require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String()) |
|
|
| |
| body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`) |
| existing2 := gjson.GetBytes(body2, "instructions").String() |
| require.Equal(t, "existing", existing2) |
|
|
| |
| body3 := []byte(`{"model":"gpt-4","instructions":" "}`) |
| existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String()) |
| require.Empty(t, existing3) |
|
|
| |
| |
| validBody := []byte(`{"model":"gpt-4"}`) |
| result, setErr := sjson.SetBytes(validBody, "instructions", "hello") |
| require.NoError(t, setErr) |
| require.True(t, gjson.ValidBytes(result)) |
| } |
|
|
| func newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler { |
| t.Helper() |
| if cache == nil { |
| cache = &concurrencyCacheMock{ |
| acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { |
| return true, nil |
| }, |
| acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { |
| return true, nil |
| }, |
| } |
| } |
| return &OpenAIGatewayHandler{ |
| gatewayService: &service.OpenAIGatewayService{}, |
| billingCacheService: &service.BillingCacheService{}, |
| apiKeyService: &service.APIKeyService{}, |
| concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), |
| } |
| } |
|
|
| func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server { |
| t.Helper() |
| groupID := int64(2) |
| apiKey := &service.APIKey{ |
| ID: 101, |
| GroupID: &groupID, |
| User: &service.User{ID: subject.UserID}, |
| } |
| router := gin.New() |
| router.Use(func(c *gin.Context) { |
| c.Set(string(middleware.ContextKeyAPIKey), apiKey) |
| c.Set(string(middleware.ContextKeyUser), subject) |
| c.Next() |
| }) |
| router.GET("/openai/v1/responses", h.ResponsesWebSocket) |
| return httptest.NewServer(router) |
| } |
|
|