| package handler |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "fmt" |
| "math/rand/v2" |
| "net/http" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/service" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
| |
| var claudeCodeValidator = service.NewClaudeCodeValidator() |
|
|
| const claudeCodeParsedRequestContextKey = "claude_code_parsed_request" |
|
|
| |
| |
| func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) { |
| if c == nil || c.Request == nil { |
| return |
| } |
| if parsedReq != nil { |
| c.Set(claudeCodeParsedRequestContextKey, parsedReq) |
| } |
|
|
| ua := c.GetHeader("User-Agent") |
| |
| if !claudeCodeValidator.ValidateUserAgent(ua) { |
| ctx := service.SetClaudeCodeClient(c.Request.Context(), false) |
| c.Request = c.Request.WithContext(ctx) |
| return |
| } |
|
|
| isClaudeCode := false |
| if !strings.Contains(c.Request.URL.Path, "messages") { |
| |
| isClaudeCode = true |
| } else { |
| |
| bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq) |
| if bodyMap == nil { |
| bodyMap = claudeCodeBodyMapFromContextCache(c) |
| } |
| if bodyMap == nil && len(body) > 0 { |
| _ = json.Unmarshal(body, &bodyMap) |
| } |
| isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap) |
| } |
|
|
| |
| ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) |
|
|
| |
| if isClaudeCode { |
| if version := claudeCodeValidator.ExtractVersion(ua); version != "" { |
| ctx = service.SetClaudeCodeVersion(ctx, version) |
| } |
| } |
|
|
| c.Request = c.Request.WithContext(ctx) |
| } |
|
|
| func claudeCodeBodyMapFromParsedRequest(parsedReq *service.ParsedRequest) map[string]any { |
| if parsedReq == nil { |
| return nil |
| } |
| bodyMap := map[string]any{ |
| "model": parsedReq.Model, |
| } |
| if parsedReq.System != nil || parsedReq.HasSystem { |
| bodyMap["system"] = parsedReq.System |
| } |
| if parsedReq.MetadataUserID != "" { |
| bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID} |
| } |
| return bodyMap |
| } |
|
|
| func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any { |
| if c == nil { |
| return nil |
| } |
| if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok { |
| if bodyMap, ok := cached.(map[string]any); ok { |
| return bodyMap |
| } |
| } |
| if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok { |
| switch v := cached.(type) { |
| case *service.ParsedRequest: |
| return claudeCodeBodyMapFromParsedRequest(v) |
| case service.ParsedRequest: |
| return claudeCodeBodyMapFromParsedRequest(&v) |
| } |
| } |
| return nil |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
// Timing parameters for waiting on a concurrency slot.
const (
	// maxConcurrencyWait is the timeout waitForSlotWithPing applies.
	maxConcurrencyWait = 30 * time.Second
	// defaultPingInterval replaces a non-positive ping interval.
	defaultPingInterval = 10 * time.Second

	// Retry backoff between acquisition attempts (see nextBackoff): the first
	// delay is initialBackoff, each subsequent delay grows by
	// backoffMultiplier (before jitter) and never exceeds maxBackoff.
	initialBackoff    = 100 * time.Millisecond
	maxBackoff        = 2 * time.Second
	backoffMultiplier = 1.5
)
|
|
| |
// SSEPingFormat is the raw text written to a streaming response as a keep-alive
// while the request waits for a concurrency slot. The empty value disables
// keep-alive pings.
type SSEPingFormat string

// Available keep-alive payloads.
const (
	// SSEPingFormatNone disables keep-alive pings.
	SSEPingFormatNone SSEPingFormat = ""
	// SSEPingFormatComment is a bare SSE comment line followed by a blank line.
	SSEPingFormatComment SSEPingFormat = ":\n\n"
	// SSEPingFormatClaude is an SSE data event whose JSON payload has type "ping".
	SSEPingFormatClaude SSEPingFormat = `data: {"type": "ping"}` + "\n\n"
)
|
|
| |
// ConcurrencyError reports that a concurrency slot was not obtained, either
// because the limit was reached or because waiting for a slot ended without
// acquiring one.
type ConcurrencyError struct {
	SlotType  string // kind of slot, e.g. "user" or "account"
	IsTimeout bool   // true when the wait for a slot ended unsuccessfully
}

// Error implements the error interface.
func (e *ConcurrencyError) Error() string {
	format := "%s concurrency limit reached"
	if e.IsTimeout {
		format = "timeout waiting for %s concurrency slot"
	}
	return fmt.Sprintf(format, e.SlotType)
}
|
|
| |
| type ConcurrencyHelper struct { |
| concurrencyService *service.ConcurrencyService |
| pingFormat SSEPingFormat |
| pingInterval time.Duration |
| } |
|
|
| |
| func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat, pingInterval time.Duration) *ConcurrencyHelper { |
| if pingInterval <= 0 { |
| pingInterval = defaultPingInterval |
| } |
| return &ConcurrencyHelper{ |
| concurrencyService: concurrencyService, |
| pingFormat: pingFormat, |
| pingInterval: pingInterval, |
| } |
| } |
|
|
| |
| |
| |
// wrapReleaseOnDone returns an idempotent wrapper around releaseFunc. The
// wrapper also runs automatically once ctx is done, so the resource is
// released even if the caller never calls it. Calling the wrapper before ctx
// is done unregisters the automatic call. It returns nil when releaseFunc is nil.
//
// stop (the handle from context.AfterFunc) is guarded by mu. If ctx is already
// done, or becomes done right after registration, AfterFunc runs the wrapper on
// its own goroutine. That goroutine would otherwise read stop while this
// goroutine is still assigning it, which is a data race.
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
	if releaseFunc == nil {
		return nil
	}

	var (
		once sync.Once
		mu   sync.Mutex
		stop func() bool
	)

	release := func() {
		once.Do(func() {
			mu.Lock()
			stopAfterFunc := stop
			mu.Unlock()
			if stopAfterFunc != nil {
				// Returns false if the AfterFunc has already started (e.g. we
				// are running inside it); either way it will not run again.
				_ = stopAfterFunc()
			}
			releaseFunc()
		})
	}

	// Hold mu across registration: an AfterFunc that fires immediately blocks
	// on mu until stop has been assigned. AfterFunc never calls f synchronously,
	// so this cannot deadlock.
	mu.Lock()
	stop = context.AfterFunc(ctx, release)
	mu.Unlock()

	return release
}
|
|
| |
| func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { |
| return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait) |
| } |
|
|
| |
| func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) { |
| h.concurrencyService.DecrementWaitCount(ctx, userID) |
| } |
|
|
| |
| func (h *ConcurrencyHelper) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { |
| return h.concurrencyService.IncrementAccountWaitCount(ctx, accountID, maxWait) |
| } |
|
|
| |
| func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accountID int64) { |
| h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) |
| } |
|
|
| |
| |
| func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) { |
| result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) |
| if err != nil { |
| return nil, false, err |
| } |
| if !result.Acquired { |
| return nil, false, nil |
| } |
| return result.ReleaseFunc, true, nil |
| } |
|
|
| |
| |
| func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) { |
| result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) |
| if err != nil { |
| return nil, false, err |
| } |
| if !result.Acquired { |
| return nil, false, nil |
| } |
| return result.ReleaseFunc, true, nil |
| } |
|
|
| |
| |
| |
| func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { |
| ctx := c.Request.Context() |
|
|
| |
| releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency) |
| if err != nil { |
| return nil, err |
| } |
|
|
| if acquired { |
| return releaseFunc, nil |
| } |
|
|
| |
| return h.waitForSlotWithPing(c, "user", userID, maxConcurrency, isStream, streamStarted) |
| } |
|
|
| |
| |
| |
| func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { |
| ctx := c.Request.Context() |
|
|
| |
| releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency) |
| if err != nil { |
| return nil, err |
| } |
|
|
| if acquired { |
| return releaseFunc, nil |
| } |
|
|
| |
| return h.waitForSlotWithPing(c, "account", accountID, maxConcurrency, isStream, streamStarted) |
| } |
|
|
| |
| |
| func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { |
| return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false) |
| } |
|
|
| |
| func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) { |
| ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) |
| defer cancel() |
|
|
| acquireSlot := func() (*service.AcquireResult, error) { |
| if slotType == "user" { |
| return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) |
| } |
| return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) |
| } |
|
|
| if tryImmediate { |
| result, err := acquireSlot() |
| if err != nil { |
| return nil, err |
| } |
| if result.Acquired { |
| return result.ReleaseFunc, nil |
| } |
| } |
|
|
| |
| needPing := isStream && h.pingFormat != "" |
|
|
| var flusher http.Flusher |
| if needPing { |
| var ok bool |
| flusher, ok = c.Writer.(http.Flusher) |
| if !ok { |
| return nil, fmt.Errorf("streaming not supported") |
| } |
| } |
|
|
| |
| var pingCh <-chan time.Time |
| if needPing { |
| pingTicker := time.NewTicker(h.pingInterval) |
| defer pingTicker.Stop() |
| pingCh = pingTicker.C |
| } |
|
|
| backoff := initialBackoff |
| timer := time.NewTimer(backoff) |
| defer timer.Stop() |
|
|
| for { |
| select { |
| case <-ctx.Done(): |
| return nil, &ConcurrencyError{ |
| SlotType: slotType, |
| IsTimeout: true, |
| } |
|
|
| case <-pingCh: |
| |
| if !*streamStarted { |
| c.Header("Content-Type", "text/event-stream") |
| c.Header("Cache-Control", "no-cache") |
| c.Header("Connection", "keep-alive") |
| c.Header("X-Accel-Buffering", "no") |
| *streamStarted = true |
| } |
| if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { |
| return nil, err |
| } |
| flusher.Flush() |
|
|
| case <-timer.C: |
| |
| result, err := acquireSlot() |
| if err != nil { |
| return nil, err |
| } |
|
|
| if result.Acquired { |
| return result.ReleaseFunc, nil |
| } |
| backoff = nextBackoff(backoff) |
| timer.Reset(backoff) |
| } |
| } |
| } |
|
|
| |
| func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { |
| return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true) |
| } |
|
|
| |
| |
| |
| |
| func nextBackoff(current time.Duration) time.Duration { |
| |
| next := time.Duration(float64(current) * backoffMultiplier) |
| if next > maxBackoff { |
| next = maxBackoff |
| } |
| |
| |
| jitter := 0.8 + rand.Float64()*0.4 |
| jittered := time.Duration(float64(next) * jitter) |
| if jittered < initialBackoff { |
| return initialBackoff |
| } |
| if jittered > maxBackoff { |
| return maxBackoff |
| } |
| return jittered |
| } |
|
|