| package handler |
|
|
| import ( |
| "net/http" |
| "net/http/httptest" |
| "sync" |
| "testing" |
|
|
| middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" |
| "github.com/Wei-Shaw/sub2api/internal/service" |
| "github.com/gin-gonic/gin" |
| "github.com/stretchr/testify/require" |
| ) |
|
|
| func resetOpsErrorLoggerStateForTest(t *testing.T) { |
| t.Helper() |
|
|
| opsErrorLogMu.Lock() |
| ch := opsErrorLogQueue |
| opsErrorLogQueue = nil |
| opsErrorLogStopping = true |
| opsErrorLogMu.Unlock() |
|
|
| if ch != nil { |
| close(ch) |
| } |
| opsErrorLogWorkersWg.Wait() |
|
|
| opsErrorLogOnce = sync.Once{} |
| opsErrorLogStopOnce = sync.Once{} |
| opsErrorLogWorkersWg = sync.WaitGroup{} |
| opsErrorLogMu = sync.RWMutex{} |
| opsErrorLogStopping = false |
|
|
| opsErrorLogQueueLen.Store(0) |
| opsErrorLogEnqueued.Store(0) |
| opsErrorLogDropped.Store(0) |
| opsErrorLogProcessed.Store(0) |
| opsErrorLogSanitized.Store(0) |
| opsErrorLogLastDropLogAt.Store(0) |
|
|
| opsErrorLogShutdownCh = make(chan struct{}) |
| opsErrorLogShutdownOnce = sync.Once{} |
| opsErrorLogDrained.Store(false) |
| } |
|
|
| func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) { |
| resetOpsErrorLoggerStateForTest(t) |
| gin.SetMode(gin.TestMode) |
|
|
| rec := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(rec) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) |
|
|
| raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`) |
| setOpsRequestContext(c, "claude-3", false, raw) |
|
|
| entry := &service.OpsInsertErrorLogInput{} |
| attachOpsRequestBodyToEntry(c, entry) |
|
|
| require.NotNil(t, entry.RequestBodyBytes) |
| require.Equal(t, len(raw), *entry.RequestBodyBytes) |
| require.NotNil(t, entry.RequestBodyJSON) |
| require.NotContains(t, *entry.RequestBodyJSON, "secret-token") |
| require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]") |
| require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) |
| } |
|
|
| func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) { |
| resetOpsErrorLoggerStateForTest(t) |
| gin.SetMode(gin.TestMode) |
|
|
| rec := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(rec) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) |
|
|
| raw := []byte("not-json") |
| setOpsRequestContext(c, "claude-3", false, raw) |
|
|
| entry := &service.OpsInsertErrorLogInput{} |
| attachOpsRequestBodyToEntry(c, entry) |
|
|
| require.Nil(t, entry.RequestBodyJSON) |
| require.NotNil(t, entry.RequestBodyBytes) |
| require.Equal(t, len(raw), *entry.RequestBodyBytes) |
| require.False(t, entry.RequestBodyTruncated) |
| require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) |
| } |
|
|
| func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) { |
| resetOpsErrorLoggerStateForTest(t) |
|
|
| |
| opsErrorLogOnce.Do(func() {}) |
|
|
| opsErrorLogMu.Lock() |
| opsErrorLogQueue = make(chan opsErrorLogJob, 1) |
| opsErrorLogMu.Unlock() |
|
|
| ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) |
| entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} |
|
|
| enqueueOpsErrorLog(ops, entry) |
| enqueueOpsErrorLog(ops, entry) |
|
|
| require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal()) |
| require.Equal(t, int64(1), OpsErrorLogDroppedTotal()) |
| require.Equal(t, int64(1), OpsErrorLogQueueLength()) |
| } |
|
|
| func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) { |
| resetOpsErrorLoggerStateForTest(t) |
| gin.SetMode(gin.TestMode) |
|
|
| entry := &service.OpsInsertErrorLogInput{} |
| attachOpsRequestBodyToEntry(nil, entry) |
| attachOpsRequestBodyToEntry(&gin.Context{}, nil) |
|
|
| rec := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(rec) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) |
|
|
| |
| attachOpsRequestBodyToEntry(c, entry) |
| require.Nil(t, entry.RequestBodyJSON) |
| require.Nil(t, entry.RequestBodyBytes) |
| require.False(t, entry.RequestBodyTruncated) |
|
|
| |
| c.Set(opsRequestBodyKey, "not-bytes") |
| attachOpsRequestBodyToEntry(c, entry) |
| require.Nil(t, entry.RequestBodyJSON) |
| require.Nil(t, entry.RequestBodyBytes) |
|
|
| |
| c.Set(opsRequestBodyKey, []byte{}) |
| attachOpsRequestBodyToEntry(c, entry) |
| require.Nil(t, entry.RequestBodyJSON) |
| require.Nil(t, entry.RequestBodyBytes) |
|
|
| require.Equal(t, int64(0), OpsErrorLogSanitizedTotal()) |
| } |
|
|
| func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) { |
| resetOpsErrorLoggerStateForTest(t) |
|
|
| ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) |
| entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} |
|
|
| |
| enqueueOpsErrorLog(nil, entry) |
| enqueueOpsErrorLog(ops, nil) |
| require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) |
|
|
| |
| close(opsErrorLogShutdownCh) |
| enqueueOpsErrorLog(ops, entry) |
| require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) |
|
|
| |
| resetOpsErrorLoggerStateForTest(t) |
| opsErrorLogMu.Lock() |
| opsErrorLogStopping = true |
| opsErrorLogMu.Unlock() |
| enqueueOpsErrorLog(ops, entry) |
| require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) |
|
|
| |
| resetOpsErrorLoggerStateForTest(t) |
| opsErrorLogOnce.Do(func() {}) |
| opsErrorLogMu.Lock() |
| opsErrorLogQueue = nil |
| opsErrorLogMu.Unlock() |
| enqueueOpsErrorLog(ops, entry) |
| require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) |
| } |
|
|
| func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| rec := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(rec) |
| c.Request = httptest.NewRequest(http.MethodGet, "/test", nil) |
|
|
| writer := acquireOpsCaptureWriter(c.Writer) |
| require.NotNil(t, writer) |
| _, err := writer.buf.WriteString("temp-error-body") |
| require.NoError(t, err) |
|
|
| releaseOpsCaptureWriter(writer) |
|
|
| reused := acquireOpsCaptureWriter(c.Writer) |
| defer releaseOpsCaptureWriter(reused) |
|
|
| require.Zero(t, reused.buf.Len(), "writer should be reset before reuse") |
| } |
|
|
| func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| r := gin.New() |
| r.Use(middleware2.Recovery()) |
| r.Use(middleware2.RequestLogger()) |
| r.Use(middleware2.Logger()) |
| r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) { |
| c.Status(http.StatusNoContent) |
| }) |
|
|
| rec := httptest.NewRecorder() |
| req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil) |
|
|
| require.NotPanics(t, func() { |
| r.ServeHTTP(rec, req) |
| }) |
| require.Equal(t, http.StatusNoContent, rec.Code) |
| } |
|
|
| func TestIsKnownOpsErrorType(t *testing.T) { |
| known := []string{ |
| "invalid_request_error", |
| "authentication_error", |
| "rate_limit_error", |
| "billing_error", |
| "subscription_error", |
| "upstream_error", |
| "overloaded_error", |
| "api_error", |
| "not_found_error", |
| "forbidden_error", |
| } |
| for _, k := range known { |
| require.True(t, isKnownOpsErrorType(k), "expected known: %s", k) |
| } |
|
|
| unknown := []string{"<nil>", "null", "", "random_error", "some_new_type", "<nil>\u003e"} |
| for _, u := range unknown { |
| require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u) |
| } |
| } |
|
|
| func TestNormalizeOpsErrorType(t *testing.T) { |
| tests := []struct { |
| name string |
| errType string |
| code string |
| want string |
| }{ |
| |
| {"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"}, |
| {"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"}, |
| {"known upstream_error", "upstream_error", "", "upstream_error"}, |
|
|
| |
| {"nil literal from upstream", "<nil>", "", "api_error"}, |
| {"null string", "null", "", "api_error"}, |
| {"random string", "something_weird", "", "api_error"}, |
|
|
| |
| {"nil with INSUFFICIENT_BALANCE code", "<nil>", "INSUFFICIENT_BALANCE", "billing_error"}, |
| {"nil with USAGE_LIMIT_EXCEEDED code", "<nil>", "USAGE_LIMIT_EXCEEDED", "subscription_error"}, |
|
|
| |
| {"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"}, |
| {"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"}, |
| {"empty type no code", "", "", "api_error"}, |
|
|
| |
| {"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"}, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| got := normalizeOpsErrorType(tt.errType, tt.code) |
| require.Equal(t, tt.want, got) |
| }) |
| } |
| } |
|
|