| package handler |
|
|
| import ( |
| "context" |
| "crypto/rand" |
| "encoding/json" |
| "errors" |
| "fmt" |
| "net/http" |
| "strconv" |
| "strings" |
| "sync/atomic" |
| "time" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/config" |
| "github.com/Wei-Shaw/sub2api/internal/domain" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/claude" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" |
| pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" |
| pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/ip" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/logger" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/openai" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" |
| middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" |
| "github.com/Wei-Shaw/sub2api/internal/service" |
|
|
| "github.com/gin-gonic/gin" |
| "go.uber.org/zap" |
| ) |
|
|
| const gatewayCompatibilityMetricsLogInterval = 1024 |
|
|
| var gatewayCompatibilityMetricsLogCounter atomic.Uint64 |
|
|
| |
| type GatewayHandler struct { |
| gatewayService *service.GatewayService |
| geminiCompatService *service.GeminiMessagesCompatService |
| antigravityGatewayService *service.AntigravityGatewayService |
| userService *service.UserService |
| billingCacheService *service.BillingCacheService |
| usageService *service.UsageService |
| apiKeyService *service.APIKeyService |
| usageRecordWorkerPool *service.UsageRecordWorkerPool |
| errorPassthroughService *service.ErrorPassthroughService |
| concurrencyHelper *ConcurrencyHelper |
| userMsgQueueHelper *UserMsgQueueHelper |
| maxAccountSwitches int |
| maxAccountSwitchesGemini int |
| cfg *config.Config |
| settingService *service.SettingService |
| } |
|
|
| |
| func NewGatewayHandler( |
| gatewayService *service.GatewayService, |
| geminiCompatService *service.GeminiMessagesCompatService, |
| antigravityGatewayService *service.AntigravityGatewayService, |
| userService *service.UserService, |
| concurrencyService *service.ConcurrencyService, |
| billingCacheService *service.BillingCacheService, |
| usageService *service.UsageService, |
| apiKeyService *service.APIKeyService, |
| usageRecordWorkerPool *service.UsageRecordWorkerPool, |
| errorPassthroughService *service.ErrorPassthroughService, |
| userMsgQueueService *service.UserMessageQueueService, |
| cfg *config.Config, |
| settingService *service.SettingService, |
| ) *GatewayHandler { |
| pingInterval := time.Duration(0) |
| maxAccountSwitches := 10 |
| maxAccountSwitchesGemini := 3 |
| if cfg != nil { |
| pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second |
| if cfg.Gateway.MaxAccountSwitches > 0 { |
| maxAccountSwitches = cfg.Gateway.MaxAccountSwitches |
| } |
| if cfg.Gateway.MaxAccountSwitchesGemini > 0 { |
| maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini |
| } |
| } |
|
|
| |
| var umqHelper *UserMsgQueueHelper |
| if userMsgQueueService != nil && cfg != nil { |
| umqHelper = NewUserMsgQueueHelper(userMsgQueueService, SSEPingFormatClaude, pingInterval) |
| } |
|
|
| return &GatewayHandler{ |
| gatewayService: gatewayService, |
| geminiCompatService: geminiCompatService, |
| antigravityGatewayService: antigravityGatewayService, |
| userService: userService, |
| billingCacheService: billingCacheService, |
| usageService: usageService, |
| apiKeyService: apiKeyService, |
| usageRecordWorkerPool: usageRecordWorkerPool, |
| errorPassthroughService: errorPassthroughService, |
| concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), |
| userMsgQueueHelper: umqHelper, |
| maxAccountSwitches: maxAccountSwitches, |
| maxAccountSwitchesGemini: maxAccountSwitchesGemini, |
| cfg: cfg, |
| settingService: settingService, |
| } |
| } |
|
|
| |
| |
| func (h *GatewayHandler) Messages(c *gin.Context) { |
| |
| apiKey, ok := middleware2.GetAPIKeyFromContext(c) |
| if !ok { |
| h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") |
| return |
| } |
|
|
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") |
| return |
| } |
| reqLog := requestLogger( |
| c, |
| "handler.gateway.messages", |
| zap.Int64("user_id", subject.UserID), |
| zap.Int64("api_key_id", apiKey.ID), |
| zap.Any("group_id", apiKey.GroupID), |
| ) |
| defer h.maybeLogCompatibilityFallbackMetrics(reqLog) |
|
|
| |
| body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) |
| if err != nil { |
| if maxErr, ok := extractMaxBytesError(err); ok { |
| h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) |
| return |
| } |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") |
| return |
| } |
|
|
| if len(body) == 0 { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") |
| return |
| } |
|
|
| setOpsRequestContext(c, "", false, body) |
|
|
| parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) |
| if err != nil { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") |
| return |
| } |
| reqModel := parsedReq.Model |
| reqStream := parsedReq.Stream |
| reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) |
|
|
| |
| |
| if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { |
| ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled()) |
| c.Request = c.Request.WithContext(ctx) |
| } |
|
|
| |
| SetClaudeCodeClientContext(c, body, parsedReq) |
| isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context()) |
|
|
| |
| if !h.checkClaudeCodeVersion(c) { |
| return |
| } |
|
|
| |
| c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) |
|
|
| setOpsRequestContext(c, reqModel, reqStream, body) |
|
|
| |
| if reqModel == "" { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") |
| return |
| } |
|
|
| |
| streamStarted := false |
|
|
| |
| if h.errorPassthroughService != nil { |
| service.BindErrorPassthroughService(c, h.errorPassthroughService) |
| } |
|
|
| |
| subscription, _ := middleware2.GetSubscriptionFromContext(c) |
|
|
| |
| maxWait := service.CalculateMaxWait(subject.Concurrency) |
| canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) |
| waitCounted := false |
| if err != nil { |
| reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err)) |
| |
| } else if !canWait { |
| reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait)) |
| h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") |
| return |
| } |
| if err == nil && canWait { |
| waitCounted = true |
| } |
| |
| defer func() { |
| if waitCounted { |
| h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) |
| } |
| }() |
|
|
| |
| userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) |
| if err != nil { |
| reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err)) |
| h.handleConcurrencyError(c, err, "user", streamStarted) |
| return |
| } |
| |
| if waitCounted { |
| h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) |
| waitCounted = false |
| } |
| |
| userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) |
| if userReleaseFunc != nil { |
| defer userReleaseFunc() |
| } |
|
|
| |
| if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { |
| reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) |
| status, code, message := billingErrorDetails(err) |
| h.handleStreamingAwareError(c, status, code, message, streamStarted) |
| return |
| } |
|
|
| |
| parsedReq.SessionContext = &service.SessionContext{ |
| ClientIP: ip.GetClientIP(c), |
| UserAgent: c.GetHeader("User-Agent"), |
| APIKeyID: apiKey.ID, |
| } |
| sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) |
|
|
| |
| platform := "" |
| if forcePlatform, ok := middleware2.GetForcePlatformFromContext(c); ok { |
| platform = forcePlatform |
| } else if apiKey.Group != nil { |
| platform = apiKey.Group.Platform |
| } |
| sessionKey := sessionHash |
| if platform == service.PlatformGemini && sessionHash != "" { |
| sessionKey = "gemini:" + sessionHash |
| } |
|
|
| |
| var sessionBoundAccountID int64 |
| if sessionKey != "" { |
| sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) |
| if sessionBoundAccountID > 0 { |
| prefetchedGroupID := int64(0) |
| if apiKey.GroupID != nil { |
| prefetchedGroupID = *apiKey.GroupID |
| } |
| ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) |
| c.Request = c.Request.WithContext(ctx) |
| } |
| } |
| |
| hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 |
|
|
| if platform == service.PlatformGemini { |
| fs := NewFailoverState(h.maxAccountSwitchesGemini, hasBoundSession) |
|
|
| |
| |
| if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { |
| ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) |
| c.Request = c.Request.WithContext(ctx) |
| } |
|
|
| for { |
| selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "") |
| if err != nil { |
| if len(fs.FailedAccountIDs) == 0 { |
| h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) |
| return |
| } |
| action := fs.HandleSelectionExhausted(c.Request.Context()) |
| switch action { |
| case FailoverContinue: |
| ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) |
| c.Request = c.Request.WithContext(ctx) |
| continue |
| case FailoverCanceled: |
| return |
| default: |
| if fs.LastFailoverErr != nil { |
| h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) |
| } else { |
| h.handleFailoverExhaustedSimple(c, 502, streamStarted) |
| } |
| return |
| } |
| } |
| account := selection.Account |
| setOpsSelectedAccount(c, account.ID, account.Platform) |
|
|
| |
| if account.IsInterceptWarmupEnabled() { |
| interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) |
| if interceptType != InterceptTypeNone { |
| if selection.Acquired && selection.ReleaseFunc != nil { |
| selection.ReleaseFunc() |
| } |
| if reqStream { |
| sendMockInterceptStream(c, reqModel, interceptType) |
| } else { |
| sendMockInterceptResponse(c, reqModel, interceptType) |
| } |
| return |
| } |
| } |
|
|
| |
| accountReleaseFunc := selection.ReleaseFunc |
| if !selection.Acquired { |
| if selection.WaitPlan == nil { |
| h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) |
| return |
| } |
| accountWaitCounted := false |
| canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) |
| if err != nil { |
| reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| } else if !canWait { |
| reqLog.Info("gateway.account_wait_queue_full", |
| zap.Int64("account_id", account.ID), |
| zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), |
| ) |
| h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) |
| return |
| } |
| if err == nil && canWait { |
| accountWaitCounted = true |
| } |
| releaseWait := func() { |
| if accountWaitCounted { |
| h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) |
| accountWaitCounted = false |
| } |
| } |
|
|
| accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( |
| c, |
| account.ID, |
| selection.WaitPlan.MaxConcurrency, |
| selection.WaitPlan.Timeout, |
| reqStream, |
| &streamStarted, |
| ) |
| if err != nil { |
| reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| releaseWait() |
| h.handleConcurrencyError(c, err, "account", streamStarted) |
| return |
| } |
| |
| releaseWait() |
| if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { |
| reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| } |
| } |
| |
| accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) |
|
|
| |
| var result *service.ForwardResult |
| requestCtx := c.Request.Context() |
| if fs.SwitchCount > 0 { |
| requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) |
| } |
| |
| writerSizeBeforeForward := c.Writer.Size() |
| if account.Platform == service.PlatformAntigravity { |
| result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) |
| } else { |
| result, err = h.geminiCompatService.Forward(requestCtx, c, account, body) |
| } |
| if accountReleaseFunc != nil { |
| accountReleaseFunc() |
| } |
| if err != nil { |
| var failoverErr *service.UpstreamFailoverError |
| if errors.As(err, &failoverErr) { |
| |
| if c.Writer.Size() != writerSizeBeforeForward { |
| h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true) |
| return |
| } |
| action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) |
| switch action { |
| case FailoverContinue: |
| continue |
| case FailoverExhausted: |
| h.handleFailoverExhausted(c, fs.LastFailoverErr, service.PlatformGemini, streamStarted) |
| return |
| case FailoverCanceled: |
| return |
| } |
| } |
| wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) |
| reqLog.Error("gateway.forward_failed", |
| zap.Int64("account_id", account.ID), |
| zap.Bool("fallback_error_response_written", wroteFallback), |
| zap.Error(err), |
| ) |
| return |
| } |
|
|
| |
| |
| |
| if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { |
| if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { |
| reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| } |
| } |
|
|
| |
| userAgent := c.GetHeader("User-Agent") |
| clientIP := ip.GetClientIP(c) |
| requestPayloadHash := service.HashUsageRequestPayload(body) |
| inboundEndpoint := GetInboundEndpoint(c) |
| upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) |
|
|
| if result.ReasoningEffort == nil { |
| result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) |
| } |
|
|
| |
| h.submitUsageRecordTask(func(ctx context.Context) { |
| if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ |
| Result: result, |
| APIKey: apiKey, |
| User: apiKey.User, |
| Account: account, |
| Subscription: subscription, |
| InboundEndpoint: inboundEndpoint, |
| UpstreamEndpoint: upstreamEndpoint, |
| UserAgent: userAgent, |
| IPAddress: clientIP, |
| RequestPayloadHash: requestPayloadHash, |
| ForceCacheBilling: fs.ForceCacheBilling, |
| APIKeyService: h.apiKeyService, |
| }); err != nil { |
| logger.L().With( |
| zap.String("component", "handler.gateway.messages"), |
| zap.Int64("user_id", subject.UserID), |
| zap.Int64("api_key_id", apiKey.ID), |
| zap.Any("group_id", apiKey.GroupID), |
| zap.String("model", reqModel), |
| zap.Int64("account_id", account.ID), |
| ).Error("gateway.record_usage_failed", zap.Error(err)) |
| } |
| }) |
| return |
| } |
| } |
|
|
| currentAPIKey := apiKey |
| currentSubscription := subscription |
| var fallbackGroupID *int64 |
| if apiKey.Group != nil { |
| fallbackGroupID = apiKey.Group.FallbackGroupIDOnInvalidRequest |
| } |
| fallbackUsed := false |
|
|
| |
| |
| if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) { |
| ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) |
| c.Request = c.Request.WithContext(ctx) |
| } |
|
|
| for { |
| fs := NewFailoverState(h.maxAccountSwitches, hasBoundSession) |
| retryWithFallback := false |
|
|
| for { |
| |
| selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID) |
| if err != nil { |
| if len(fs.FailedAccountIDs) == 0 { |
| h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) |
| return |
| } |
| action := fs.HandleSelectionExhausted(c.Request.Context()) |
| switch action { |
| case FailoverContinue: |
| ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) |
| c.Request = c.Request.WithContext(ctx) |
| continue |
| case FailoverCanceled: |
| return |
| default: |
| if fs.LastFailoverErr != nil { |
| h.handleFailoverExhausted(c, fs.LastFailoverErr, platform, streamStarted) |
| } else { |
| h.handleFailoverExhaustedSimple(c, 502, streamStarted) |
| } |
| return |
| } |
| } |
| account := selection.Account |
| setOpsSelectedAccount(c, account.ID, account.Platform) |
|
|
| |
| if account.IsInterceptWarmupEnabled() { |
| interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient) |
| if interceptType != InterceptTypeNone { |
| if selection.Acquired && selection.ReleaseFunc != nil { |
| selection.ReleaseFunc() |
| } |
| if reqStream { |
| sendMockInterceptStream(c, reqModel, interceptType) |
| } else { |
| sendMockInterceptResponse(c, reqModel, interceptType) |
| } |
| return |
| } |
| } |
|
|
| |
| accountReleaseFunc := selection.ReleaseFunc |
| if !selection.Acquired { |
| if selection.WaitPlan == nil { |
| h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) |
| return |
| } |
| accountWaitCounted := false |
| canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) |
| if err != nil { |
| reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| } else if !canWait { |
| reqLog.Info("gateway.account_wait_queue_full", |
| zap.Int64("account_id", account.ID), |
| zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), |
| ) |
| h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) |
| return |
| } |
| if err == nil && canWait { |
| accountWaitCounted = true |
| } |
| releaseWait := func() { |
| if accountWaitCounted { |
| h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) |
| accountWaitCounted = false |
| } |
| } |
|
|
| accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( |
| c, |
| account.ID, |
| selection.WaitPlan.MaxConcurrency, |
| selection.WaitPlan.Timeout, |
| reqStream, |
| &streamStarted, |
| ) |
| if err != nil { |
| reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| releaseWait() |
| h.handleConcurrencyError(c, err, "account", streamStarted) |
| return |
| } |
| |
| releaseWait() |
| if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { |
| reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| } |
| } |
| |
| accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) |
|
|
| |
| var queueRelease func() |
| umqMode := h.getUserMsgQueueMode(account, parsedReq) |
|
|
| switch umqMode { |
| case config.UMQModeSerialize: |
| |
| baseRPM := account.GetBaseRPM() |
| release, qErr := h.userMsgQueueHelper.AcquireWithWait( |
| c, account.ID, baseRPM, reqStream, &streamStarted, |
| h.cfg.Gateway.UserMessageQueue.WaitTimeout(), |
| reqLog, |
| ) |
| if qErr != nil { |
| |
| reqLog.Warn("gateway.umq_acquire_failed", |
| zap.Int64("account_id", account.ID), |
| zap.Error(qErr), |
| ) |
| } else { |
| queueRelease = release |
| } |
|
|
| case config.UMQModeThrottle: |
| |
| baseRPM := account.GetBaseRPM() |
| if tErr := h.userMsgQueueHelper.ThrottleWithPing( |
| c, account.ID, baseRPM, reqStream, &streamStarted, |
| h.cfg.Gateway.UserMessageQueue.WaitTimeout(), |
| reqLog, |
| ); tErr != nil { |
| reqLog.Warn("gateway.umq_throttle_failed", |
| zap.Int64("account_id", account.ID), |
| zap.Error(tErr), |
| ) |
| } |
|
|
| default: |
| if umqMode != "" { |
| reqLog.Warn("gateway.umq_unknown_mode", |
| zap.String("mode", umqMode), |
| zap.Int64("account_id", account.ID), |
| ) |
| } |
| } |
|
|
| |
| queueRelease = wrapReleaseOnDone(c.Request.Context(), queueRelease) |
| |
| parsedReq.OnUpstreamAccepted = queueRelease |
| |
|
|
| |
| var result *service.ForwardResult |
| requestCtx := c.Request.Context() |
| if fs.SwitchCount > 0 { |
| requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) |
| } |
| |
| writerSizeBeforeForward := c.Writer.Size() |
| if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { |
| result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) |
| } else { |
| result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) |
| } |
|
|
| |
| if queueRelease != nil { |
| queueRelease() |
| } |
| |
| parsedReq.OnUpstreamAccepted = nil |
|
|
| if accountReleaseFunc != nil { |
| accountReleaseFunc() |
| } |
| if err != nil { |
| |
| var betaBlockedErr *service.BetaBlockedError |
| if errors.As(err, &betaBlockedErr) { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", betaBlockedErr.Message) |
| return |
| } |
|
|
| var promptTooLongErr *service.PromptTooLongError |
| if errors.As(err, &promptTooLongErr) { |
| reqLog.Warn("gateway.prompt_too_long_from_antigravity", |
| zap.Any("current_group_id", currentAPIKey.GroupID), |
| zap.Any("fallback_group_id", fallbackGroupID), |
| zap.Bool("fallback_used", fallbackUsed), |
| ) |
| if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { |
| fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) |
| if err != nil { |
| reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err)) |
| _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) |
| return |
| } |
| if fallbackGroup.Platform != service.PlatformAnthropic || |
| fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription || |
| fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { |
| reqLog.Warn("gateway.fallback_group_invalid", |
| zap.Int64("fallback_group_id", fallbackGroup.ID), |
| zap.String("fallback_platform", fallbackGroup.Platform), |
| zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType), |
| ) |
| _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) |
| return |
| } |
| fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) |
| if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { |
| status, code, message := billingErrorDetails(err) |
| h.handleStreamingAwareError(c, status, code, message, streamStarted) |
| return |
| } |
| |
| ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") |
| c.Request = c.Request.WithContext(ctx) |
| currentAPIKey = fallbackAPIKey |
| currentSubscription = nil |
| fallbackUsed = true |
| retryWithFallback = true |
| break |
| } |
| _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) |
| return |
| } |
| var failoverErr *service.UpstreamFailoverError |
| if errors.As(err, &failoverErr) { |
| |
| if c.Writer.Size() != writerSizeBeforeForward { |
| h.handleFailoverExhausted(c, failoverErr, account.Platform, true) |
| return |
| } |
| action := fs.HandleFailoverError(c.Request.Context(), h.gatewayService, account.ID, account.Platform, failoverErr) |
| switch action { |
| case FailoverContinue: |
| continue |
| case FailoverExhausted: |
| h.handleFailoverExhausted(c, fs.LastFailoverErr, account.Platform, streamStarted) |
| return |
| case FailoverCanceled: |
| return |
| } |
| } |
| wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) |
| reqLog.Error("gateway.forward_failed", |
| zap.Int64("account_id", account.ID), |
| zap.Bool("fallback_error_response_written", wroteFallback), |
| zap.Error(err), |
| ) |
| return |
| } |
|
|
| |
| |
| |
| if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { |
| if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { |
| reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| } |
| } |
|
|
| |
| userAgent := c.GetHeader("User-Agent") |
| clientIP := ip.GetClientIP(c) |
| requestPayloadHash := service.HashUsageRequestPayload(body) |
| inboundEndpoint := GetInboundEndpoint(c) |
| upstreamEndpoint := GetUpstreamEndpoint(c, account.Platform) |
|
|
| if result.ReasoningEffort == nil { |
| result.ReasoningEffort = service.NormalizeClaudeOutputEffort(parsedReq.OutputEffort) |
| } |
|
|
| |
| h.submitUsageRecordTask(func(ctx context.Context) { |
| if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ |
| Result: result, |
| APIKey: currentAPIKey, |
| User: currentAPIKey.User, |
| Account: account, |
| Subscription: currentSubscription, |
| InboundEndpoint: inboundEndpoint, |
| UpstreamEndpoint: upstreamEndpoint, |
| UserAgent: userAgent, |
| IPAddress: clientIP, |
| RequestPayloadHash: requestPayloadHash, |
| ForceCacheBilling: fs.ForceCacheBilling, |
| APIKeyService: h.apiKeyService, |
| }); err != nil { |
| logger.L().With( |
| zap.String("component", "handler.gateway.messages"), |
| zap.Int64("user_id", subject.UserID), |
| zap.Int64("api_key_id", currentAPIKey.ID), |
| zap.Any("group_id", currentAPIKey.GroupID), |
| zap.String("model", reqModel), |
| zap.Int64("account_id", account.ID), |
| ).Error("gateway.record_usage_failed", zap.Error(err)) |
| } |
| }) |
| return |
| } |
| if !retryWithFallback { |
| return |
| } |
| } |
| } |
|
|
| |
| |
| |
| |
| func (h *GatewayHandler) Models(c *gin.Context) { |
| apiKey, _ := middleware2.GetAPIKeyFromContext(c) |
|
|
| var groupID *int64 |
| var platform string |
|
|
| if apiKey != nil && apiKey.Group != nil { |
| groupID = &apiKey.Group.ID |
| platform = apiKey.Group.Platform |
| } |
| if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" { |
| platform = forcedPlatform |
| } |
|
|
| if platform == service.PlatformSora { |
| c.JSON(http.StatusOK, gin.H{ |
| "object": "list", |
| "data": service.DefaultSoraModels(h.cfg), |
| }) |
| return |
| } |
|
|
| |
| availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") |
|
|
| if len(availableModels) > 0 { |
| |
| models := make([]claude.Model, 0, len(availableModels)) |
| for _, modelID := range availableModels { |
| models = append(models, claude.Model{ |
| ID: modelID, |
| Type: "model", |
| DisplayName: modelID, |
| CreatedAt: "2024-01-01T00:00:00Z", |
| }) |
| } |
| c.JSON(http.StatusOK, gin.H{ |
| "object": "list", |
| "data": models, |
| }) |
| return |
| } |
|
|
| |
| if platform == "openai" { |
| c.JSON(http.StatusOK, gin.H{ |
| "object": "list", |
| "data": openai.DefaultModels, |
| }) |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "object": "list", |
| "data": claude.DefaultModels, |
| }) |
| } |
|
|
| |
| |
| func (h *GatewayHandler) AntigravityModels(c *gin.Context) { |
| c.JSON(http.StatusOK, gin.H{ |
| "object": "list", |
| "data": antigravity.DefaultModels(), |
| }) |
| } |
|
|
| func cloneAPIKeyWithGroup(apiKey *service.APIKey, group *service.Group) *service.APIKey { |
| if apiKey == nil || group == nil { |
| return apiKey |
| } |
| cloned := *apiKey |
| groupID := group.ID |
| cloned.GroupID = &groupID |
| cloned.Group = group |
| return &cloned |
| } |
|
|
| |
| |
| |
| |
| |
| |
| func (h *GatewayHandler) Usage(c *gin.Context) { |
| apiKey, ok := middleware2.GetAPIKeyFromContext(c) |
| if !ok { |
| h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") |
| return |
| } |
|
|
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") |
| return |
| } |
|
|
| ctx := c.Request.Context() |
|
|
| |
| startTime, endTime := h.parseUsageDateRange(c) |
|
|
| |
| usageData := h.buildUsageData(ctx, apiKey.ID) |
|
|
| |
| var modelStats any |
| if h.usageService != nil { |
| if stats, err := h.usageService.GetAPIKeyModelStats(ctx, apiKey.ID, startTime, endTime); err == nil && len(stats) > 0 { |
| modelStats = stats |
| } |
| } |
|
|
| |
| isQuotaLimited := apiKey.Quota > 0 || apiKey.HasRateLimits() |
|
|
| if isQuotaLimited { |
| h.usageQuotaLimited(c, ctx, apiKey, usageData, modelStats) |
| return |
| } |
|
|
| h.usageUnrestricted(c, ctx, apiKey, subject, usageData, modelStats) |
| } |
|
|
| |
| func (h *GatewayHandler) parseUsageDateRange(c *gin.Context) (time.Time, time.Time) { |
| now := timezone.Now() |
| endTime := now |
| startTime := now.AddDate(0, 0, -30) |
|
|
| if s := c.Query("start_date"); s != "" { |
| if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { |
| startTime = t |
| } |
| } |
| if s := c.Query("end_date"); s != "" { |
| if t, err := timezone.ParseInLocation("2006-01-02", s); err == nil { |
| endTime = t.AddDate(0, 0, 1) |
| } |
| } |
| return startTime, endTime |
| } |
|
|
| |
| func (h *GatewayHandler) buildUsageData(ctx context.Context, apiKeyID int64) gin.H { |
| if h.usageService == nil { |
| return nil |
| } |
| dashStats, err := h.usageService.GetAPIKeyDashboardStats(ctx, apiKeyID) |
| if err != nil || dashStats == nil { |
| return nil |
| } |
| return gin.H{ |
| "today": gin.H{ |
| "requests": dashStats.TodayRequests, |
| "input_tokens": dashStats.TodayInputTokens, |
| "output_tokens": dashStats.TodayOutputTokens, |
| "cache_creation_tokens": dashStats.TodayCacheCreationTokens, |
| "cache_read_tokens": dashStats.TodayCacheReadTokens, |
| "total_tokens": dashStats.TodayTokens, |
| "cost": dashStats.TodayCost, |
| "actual_cost": dashStats.TodayActualCost, |
| }, |
| "total": gin.H{ |
| "requests": dashStats.TotalRequests, |
| "input_tokens": dashStats.TotalInputTokens, |
| "output_tokens": dashStats.TotalOutputTokens, |
| "cache_creation_tokens": dashStats.TotalCacheCreationTokens, |
| "cache_read_tokens": dashStats.TotalCacheReadTokens, |
| "total_tokens": dashStats.TotalTokens, |
| "cost": dashStats.TotalCost, |
| "actual_cost": dashStats.TotalActualCost, |
| }, |
| "average_duration_ms": dashStats.AverageDurationMs, |
| "rpm": dashStats.Rpm, |
| "tpm": dashStats.Tpm, |
| } |
| } |
|
|
| |
| func (h *GatewayHandler) usageQuotaLimited(c *gin.Context, ctx context.Context, apiKey *service.APIKey, usageData gin.H, modelStats any) { |
| resp := gin.H{ |
| "mode": "quota_limited", |
| "isValid": apiKey.Status == service.StatusAPIKeyActive || apiKey.Status == service.StatusAPIKeyQuotaExhausted || apiKey.Status == service.StatusAPIKeyExpired, |
| "status": apiKey.Status, |
| } |
|
|
| |
| if apiKey.Quota > 0 { |
| remaining := apiKey.GetQuotaRemaining() |
| resp["quota"] = gin.H{ |
| "limit": apiKey.Quota, |
| "used": apiKey.QuotaUsed, |
| "remaining": remaining, |
| "unit": "USD", |
| } |
| resp["remaining"] = remaining |
| resp["unit"] = "USD" |
| } |
|
|
| |
| if apiKey.HasRateLimits() && h.apiKeyService != nil { |
| rateLimitData, err := h.apiKeyService.GetRateLimitData(ctx, apiKey.ID) |
| if err == nil && rateLimitData != nil { |
| var rateLimits []gin.H |
| if apiKey.RateLimit5h > 0 { |
| used := rateLimitData.EffectiveUsage5h() |
| entry := gin.H{ |
| "window": "5h", |
| "limit": apiKey.RateLimit5h, |
| "used": used, |
| "remaining": max(0, apiKey.RateLimit5h-used), |
| "window_start": rateLimitData.Window5hStart, |
| } |
| if rateLimitData.Window5hStart != nil && !service.IsWindowExpired(rateLimitData.Window5hStart, service.RateLimitWindow5h) { |
| entry["reset_at"] = rateLimitData.Window5hStart.Add(service.RateLimitWindow5h) |
| } |
| rateLimits = append(rateLimits, entry) |
| } |
| if apiKey.RateLimit1d > 0 { |
| used := rateLimitData.EffectiveUsage1d() |
| entry := gin.H{ |
| "window": "1d", |
| "limit": apiKey.RateLimit1d, |
| "used": used, |
| "remaining": max(0, apiKey.RateLimit1d-used), |
| "window_start": rateLimitData.Window1dStart, |
| } |
| if rateLimitData.Window1dStart != nil && !service.IsWindowExpired(rateLimitData.Window1dStart, service.RateLimitWindow1d) { |
| entry["reset_at"] = rateLimitData.Window1dStart.Add(service.RateLimitWindow1d) |
| } |
| rateLimits = append(rateLimits, entry) |
| } |
| if apiKey.RateLimit7d > 0 { |
| used := rateLimitData.EffectiveUsage7d() |
| entry := gin.H{ |
| "window": "7d", |
| "limit": apiKey.RateLimit7d, |
| "used": used, |
| "remaining": max(0, apiKey.RateLimit7d-used), |
| "window_start": rateLimitData.Window7dStart, |
| } |
| if rateLimitData.Window7dStart != nil && !service.IsWindowExpired(rateLimitData.Window7dStart, service.RateLimitWindow7d) { |
| entry["reset_at"] = rateLimitData.Window7dStart.Add(service.RateLimitWindow7d) |
| } |
| rateLimits = append(rateLimits, entry) |
| } |
| if len(rateLimits) > 0 { |
| resp["rate_limits"] = rateLimits |
| } |
| } |
| } |
|
|
| |
| if apiKey.ExpiresAt != nil { |
| resp["expires_at"] = apiKey.ExpiresAt |
| resp["days_until_expiry"] = apiKey.GetDaysUntilExpiry() |
| } |
|
|
| if usageData != nil { |
| resp["usage"] = usageData |
| } |
| if modelStats != nil { |
| resp["model_stats"] = modelStats |
| } |
|
|
| c.JSON(http.StatusOK, resp) |
| } |
|
|
| |
| func (h *GatewayHandler) usageUnrestricted(c *gin.Context, ctx context.Context, apiKey *service.APIKey, subject middleware2.AuthSubject, usageData gin.H, modelStats any) { |
| |
| if apiKey.Group != nil && apiKey.Group.IsSubscriptionType() { |
| resp := gin.H{ |
| "mode": "unrestricted", |
| "isValid": true, |
| "planName": apiKey.Group.Name, |
| "unit": "USD", |
| } |
|
|
| |
| subscription, ok := middleware2.GetSubscriptionFromContext(c) |
| if ok { |
| remaining := h.calculateSubscriptionRemaining(apiKey.Group, subscription) |
| resp["remaining"] = remaining |
| resp["subscription"] = gin.H{ |
| "daily_usage_usd": subscription.DailyUsageUSD, |
| "weekly_usage_usd": subscription.WeeklyUsageUSD, |
| "monthly_usage_usd": subscription.MonthlyUsageUSD, |
| "daily_limit_usd": apiKey.Group.DailyLimitUSD, |
| "weekly_limit_usd": apiKey.Group.WeeklyLimitUSD, |
| "monthly_limit_usd": apiKey.Group.MonthlyLimitUSD, |
| "expires_at": subscription.ExpiresAt, |
| } |
| } |
|
|
| if usageData != nil { |
| resp["usage"] = usageData |
| } |
| if modelStats != nil { |
| resp["model_stats"] = modelStats |
| } |
| c.JSON(http.StatusOK, resp) |
| return |
| } |
|
|
| |
| latestUser, err := h.userService.GetByID(ctx, subject.UserID) |
| if err != nil { |
| h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to get user info") |
| return |
| } |
|
|
| resp := gin.H{ |
| "mode": "unrestricted", |
| "isValid": true, |
| "planName": "钱包余额", |
| "remaining": latestUser.Balance, |
| "unit": "USD", |
| "balance": latestUser.Balance, |
| } |
| if usageData != nil { |
| resp["usage"] = usageData |
| } |
| if modelStats != nil { |
| resp["model_stats"] = modelStats |
| } |
| c.JSON(http.StatusOK, resp) |
| } |
|
|
| |
| |
| |
| |
| func (h *GatewayHandler) calculateSubscriptionRemaining(group *service.Group, sub *service.UserSubscription) float64 { |
| var remainingValues []float64 |
|
|
| |
| if group.HasDailyLimit() { |
| remaining := *group.DailyLimitUSD - sub.DailyUsageUSD |
| if remaining <= 0 { |
| return 0 |
| } |
| remainingValues = append(remainingValues, remaining) |
| } |
|
|
| |
| if group.HasWeeklyLimit() { |
| remaining := *group.WeeklyLimitUSD - sub.WeeklyUsageUSD |
| if remaining <= 0 { |
| return 0 |
| } |
| remainingValues = append(remainingValues, remaining) |
| } |
|
|
| |
| if group.HasMonthlyLimit() { |
| remaining := *group.MonthlyLimitUSD - sub.MonthlyUsageUSD |
| if remaining <= 0 { |
| return 0 |
| } |
| remainingValues = append(remainingValues, remaining) |
| } |
|
|
| |
| if len(remainingValues) == 0 { |
| return -1 |
| } |
|
|
| |
| min := remainingValues[0] |
| for _, v := range remainingValues[1:] { |
| if v < min { |
| min = v |
| } |
| } |
| return min |
| } |
|
|
| |
| func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { |
| h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", |
| fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) |
| } |
|
|
| func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { |
| statusCode := failoverErr.StatusCode |
| responseBody := failoverErr.ResponseBody |
|
|
| |
| if h.errorPassthroughService != nil && len(responseBody) > 0 { |
| if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil { |
| |
| respCode := statusCode |
| if !rule.PassthroughCode && rule.ResponseCode != nil { |
| respCode = *rule.ResponseCode |
| } |
|
|
| |
| msg := service.ExtractUpstreamErrorMessage(responseBody) |
| if !rule.PassthroughBody && rule.CustomMessage != nil { |
| msg = *rule.CustomMessage |
| } |
|
|
| if rule.SkipMonitoring { |
| c.Set(service.OpsSkipPassthroughKey, true) |
| } |
|
|
| h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) |
| return |
| } |
| } |
|
|
| |
| upstreamMsg := service.ExtractUpstreamErrorMessage(responseBody) |
| service.SetOpsUpstreamError(c, statusCode, upstreamMsg, "") |
|
|
| |
| status, errType, errMsg := h.mapUpstreamError(statusCode) |
| h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) |
| } |
|
|
| |
| func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { |
| status, errType, errMsg := h.mapUpstreamError(statusCode) |
| service.SetOpsUpstreamError(c, statusCode, errMsg, "") |
| h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) |
| } |
|
|
| func (h *GatewayHandler) mapUpstreamError(statusCode int) (int, string, string) { |
| switch statusCode { |
| case 401: |
| return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" |
| case 403: |
| return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" |
| case 429: |
| return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" |
| case 529: |
| return http.StatusServiceUnavailable, "overloaded_error", "Upstream service overloaded, please retry later" |
| case 500, 502, 503, 504: |
| return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" |
| default: |
| return http.StatusBadGateway, "upstream_error", "Upstream request failed" |
| } |
| } |
|
|
| |
| func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) { |
| if streamStarted { |
| |
| flusher, ok := c.Writer.(http.Flusher) |
| if ok { |
| |
| errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" |
| if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { |
| _ = c.Error(err) |
| } |
| flusher.Flush() |
| } |
| return |
| } |
|
|
| |
| h.errorResponse(c, status, errType, message) |
| } |
|
|
| |
| func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { |
| if c == nil || c.Writer == nil || c.Writer.Written() { |
| return false |
| } |
| h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) |
| return true |
| } |
|
|
| |
| |
| func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool { |
| ctx := c.Request.Context() |
| if !service.IsClaudeCodeClient(ctx) { |
| return true |
| } |
|
|
| |
| if strings.HasSuffix(c.Request.URL.Path, "/count_tokens") { |
| return true |
| } |
|
|
| minVersion, maxVersion := h.settingService.GetClaudeCodeVersionBounds(ctx) |
| if minVersion == "" && maxVersion == "" { |
| return true |
| } |
|
|
| clientVersion := service.GetClaudeCodeVersion(ctx) |
| if clientVersion == "" { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", |
| "Unable to determine Claude Code version. Please update Claude Code: npm update -g @anthropic-ai/claude-code") |
| return false |
| } |
|
|
| if minVersion != "" && service.CompareVersions(clientVersion, minVersion) < 0 { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", |
| fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code", |
| clientVersion, minVersion)) |
| return false |
| } |
|
|
| if maxVersion != "" && service.CompareVersions(clientVersion, maxVersion) > 0 { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", |
| fmt.Sprintf("Your Claude Code version (%s) exceeds the maximum allowed version (%s). "+ |
| "Please downgrade: npm install -g @anthropic-ai/claude-code@%s && "+ |
| "set CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC=1 to prevent auto-upgrade", |
| clientVersion, maxVersion, maxVersion)) |
| return false |
| } |
|
|
| return true |
| } |
|
|
| |
| func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { |
| c.JSON(status, gin.H{ |
| "type": "error", |
| "error": gin.H{ |
| "type": errType, |
| "message": message, |
| }, |
| }) |
| } |
|
|
| |
| |
| |
| func (h *GatewayHandler) CountTokens(c *gin.Context) { |
| |
| apiKey, ok := middleware2.GetAPIKeyFromContext(c) |
| if !ok { |
| h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") |
| return |
| } |
|
|
| _, ok = middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") |
| return |
| } |
| reqLog := requestLogger( |
| c, |
| "handler.gateway.count_tokens", |
| zap.Int64("api_key_id", apiKey.ID), |
| zap.Any("group_id", apiKey.GroupID), |
| ) |
| defer h.maybeLogCompatibilityFallbackMetrics(reqLog) |
|
|
| |
| body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) |
| if err != nil { |
| if maxErr, ok := extractMaxBytesError(err); ok { |
| h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) |
| return |
| } |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") |
| return |
| } |
|
|
| if len(body) == 0 { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") |
| return |
| } |
|
|
| setOpsRequestContext(c, "", false, body) |
|
|
| parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) |
| if err != nil { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") |
| return |
| } |
| |
| SetClaudeCodeClientContext(c, body, parsedReq) |
| reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream)) |
| |
| c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) |
|
|
| |
| if parsedReq.Model == "" { |
| h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") |
| return |
| } |
|
|
| setOpsRequestContext(c, parsedReq.Model, parsedReq.Stream, body) |
|
|
| |
| subscription, _ := middleware2.GetSubscriptionFromContext(c) |
|
|
| |
| |
| if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { |
| status, code, message := billingErrorDetails(err) |
| h.errorResponse(c, status, code, message) |
| return |
| } |
|
|
| |
| parsedReq.SessionContext = &service.SessionContext{ |
| ClientIP: ip.GetClientIP(c), |
| UserAgent: c.GetHeader("User-Agent"), |
| APIKeyID: apiKey.ID, |
| } |
| sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) |
|
|
| |
| account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) |
| if err != nil { |
| reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err)) |
| h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable") |
| return |
| } |
| setOpsSelectedAccount(c, account.ID, account.Platform) |
|
|
| |
| if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { |
| reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) |
| |
| return |
| } |
| } |
|
|
| |
| type InterceptType int |
|
|
| const ( |
| InterceptTypeNone InterceptType = iota |
| InterceptTypeWarmup |
| InterceptTypeSuggestionMode |
| InterceptTypeMaxTokensOneHaiku |
| ) |
|
|
| |
| func isHaikuModel(model string) bool { |
| return strings.Contains(strings.ToLower(model), "haiku") |
| } |
|
|
| |
| |
| |
| func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool { |
| return maxTokens == 1 && isHaikuModel(model) && !isStream |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType { |
| |
| if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) { |
| return InterceptTypeMaxTokensOneHaiku |
| } |
|
|
| |
| bodyStr := string(body) |
| hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:") |
| hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup") |
|
|
| if !hasSuggestionMode && !hasWarmupKeyword { |
| return InterceptTypeNone |
| } |
|
|
| |
| var req struct { |
| Messages []struct { |
| Role string `json:"role"` |
| Content []struct { |
| Type string `json:"type"` |
| Text string `json:"text"` |
| } `json:"content"` |
| } `json:"messages"` |
| System []struct { |
| Text string `json:"text"` |
| } `json:"system"` |
| } |
| if err := json.Unmarshal(body, &req); err != nil { |
| return InterceptTypeNone |
| } |
|
|
| |
| if hasSuggestionMode && len(req.Messages) > 0 { |
| lastMsg := req.Messages[len(req.Messages)-1] |
| if lastMsg.Role == "user" && len(lastMsg.Content) > 0 && |
| lastMsg.Content[0].Type == "text" && |
| strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") { |
| return InterceptTypeSuggestionMode |
| } |
| } |
|
|
| |
| if hasWarmupKeyword { |
| |
| for _, msg := range req.Messages { |
| for _, content := range msg.Content { |
| if content.Type == "text" { |
| if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") || |
| content.Text == "Warmup" { |
| return InterceptTypeWarmup |
| } |
| } |
| } |
| } |
| |
| for _, sys := range req.System { |
| if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") { |
| return InterceptTypeWarmup |
| } |
| } |
| } |
|
|
| return InterceptTypeNone |
| } |
|
|
| |
| func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) { |
| c.Header("Content-Type", "text/event-stream") |
| c.Header("Cache-Control", "no-cache") |
| c.Header("Connection", "keep-alive") |
| c.Header("X-Accel-Buffering", "no") |
|
|
| |
| var msgID string |
| var outputTokens int |
| var textDeltas []string |
|
|
| switch interceptType { |
| case InterceptTypeSuggestionMode: |
| msgID = "msg_mock_suggestion" |
| outputTokens = 1 |
| textDeltas = []string{""} |
| default: |
| msgID = "msg_mock_warmup" |
| outputTokens = 2 |
| textDeltas = []string{"New", " Conversation"} |
| } |
|
|
| |
| messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}` |
|
|
| |
| events := []string{ |
| `event: message_start` + "\n" + `data: ` + string(messageStartJSON), |
| `event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`, |
| } |
|
|
| |
| for _, text := range textDeltas { |
| deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}` |
| events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON)) |
| } |
|
|
| |
| messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}` |
|
|
| events = append(events, |
| `event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`, |
| `event: message_delta`+"\n"+`data: `+string(messageDeltaJSON), |
| `event: message_stop`+"\n"+`data: {"type":"message_stop"}`, |
| ) |
|
|
| for _, event := range events { |
| _, _ = c.Writer.WriteString(event + "\n\n") |
| c.Writer.Flush() |
| time.Sleep(20 * time.Millisecond) |
| } |
| } |
|
|
| |
| |
| func generateRealisticMsgID() string { |
| const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" |
| const idLen = 24 |
| randomBytes := make([]byte, idLen) |
| if _, err := rand.Read(randomBytes); err != nil { |
| return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano()) |
| } |
| b := make([]byte, idLen) |
| for i := range b { |
| b[i] = charset[int(randomBytes[i])%len(charset)] |
| } |
| return "msg_bdrk_" + string(b) |
| } |
|
|
| |
| func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) { |
| var msgID, text, stopReason string |
| var outputTokens int |
|
|
| switch interceptType { |
| case InterceptTypeSuggestionMode: |
| msgID = "msg_mock_suggestion" |
| text = "" |
| outputTokens = 1 |
| stopReason = "end_turn" |
| case InterceptTypeMaxTokensOneHaiku: |
| msgID = generateRealisticMsgID() |
| text = "#" |
| outputTokens = 1 |
| stopReason = "max_tokens" |
| default: |
| msgID = "msg_mock_warmup" |
| text = "New Conversation" |
| outputTokens = 2 |
| stopReason = "end_turn" |
| } |
|
|
| |
| response := gin.H{ |
| "model": model, |
| "id": msgID, |
| "type": "message", |
| "role": "assistant", |
| "content": []gin.H{{"type": "text", "text": text}}, |
| "stop_reason": stopReason, |
| "stop_sequence": nil, |
| "usage": gin.H{ |
| "input_tokens": 10, |
| "cache_creation_input_tokens": 0, |
| "cache_read_input_tokens": 0, |
| "cache_creation": gin.H{ |
| "ephemeral_5m_input_tokens": 0, |
| "ephemeral_1h_input_tokens": 0, |
| }, |
| "output_tokens": outputTokens, |
| "total_tokens": 10 + outputTokens, |
| }, |
| } |
|
|
| c.JSON(http.StatusOK, response) |
| } |
|
|
| func billingErrorDetails(err error) (status int, code, message string) { |
| if errors.Is(err, service.ErrBillingServiceUnavailable) { |
| msg := pkgerrors.Message(err) |
| if msg == "" { |
| msg = "Billing service temporarily unavailable. Please retry later." |
| } |
| return http.StatusServiceUnavailable, "billing_service_error", msg |
| } |
| if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) { |
| msg := pkgerrors.Message(err) |
| return http.StatusTooManyRequests, "rate_limit_exceeded", msg |
| } |
| if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) { |
| msg := pkgerrors.Message(err) |
| return http.StatusTooManyRequests, "rate_limit_exceeded", msg |
| } |
| if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) { |
| msg := pkgerrors.Message(err) |
| return http.StatusTooManyRequests, "rate_limit_exceeded", msg |
| } |
| msg := pkgerrors.Message(err) |
| if msg == "" { |
| logger.L().With( |
| zap.String("component", "handler.gateway.billing"), |
| zap.Error(err), |
| ).Warn("gateway.billing_error_missing_message") |
| msg = "Billing error" |
| } |
| return http.StatusForbidden, "billing_error", msg |
| } |
|
|
| func (h *GatewayHandler) metadataBridgeEnabled() bool { |
| if h == nil || h.cfg == nil { |
| return true |
| } |
| return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled |
| } |
|
|
| func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) { |
| if reqLog == nil { |
| return |
| } |
| if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 { |
| return |
| } |
| metrics := service.SnapshotOpenAICompatibilityFallbackMetrics() |
| reqLog.Info("gateway.compatibility_fallback_metrics", |
| zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal), |
| zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit), |
| zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal), |
| zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate), |
| zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal), |
| ) |
| } |
|
|
| func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { |
| if task == nil { |
| return |
| } |
| if h.usageRecordWorkerPool != nil { |
| h.usageRecordWorkerPool.Submit(task) |
| return |
| } |
| |
| ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
| defer cancel() |
| defer func() { |
| if recovered := recover(); recovered != nil { |
| logger.L().With( |
| zap.String("component", "handler.gateway.messages"), |
| zap.Any("panic", recovered), |
| ).Error("gateway.usage_record_task_panic_recovered") |
| } |
| }() |
| task(ctx) |
| } |
|
|
| |
| |
| func (h *GatewayHandler) getUserMsgQueueMode(account *service.Account, parsed *service.ParsedRequest) string { |
| if h.userMsgQueueHelper == nil { |
| return "" |
| } |
| |
| if !account.IsAnthropicOAuthOrSetupToken() { |
| return "" |
| } |
| if !service.IsRealUserMessage(parsed) { |
| return "" |
| } |
| |
| mode := account.GetUserMsgQueueMode() |
| if mode == "" { |
| mode = h.cfg.Gateway.UserMessageQueue.GetEffectiveMode() |
| } |
| return mode |
| } |
|
|