| |
|
|
| package handler |
|
|
| import ( |
| "bytes" |
| "context" |
| "encoding/json" |
| "net/http/httptest" |
| "testing" |
| "time" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/config" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" |
| middleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" |
| "github.com/Wei-Shaw/sub2api/internal/service" |
|
|
| "github.com/gin-gonic/gin" |
| "github.com/stretchr/testify/require" |
| ) |
|
|
| |
| |
| |
|
|
| type fakeSchedulerCache struct { |
| accounts []*service.Account |
| } |
|
|
| func (f *fakeSchedulerCache) GetSnapshot(_ context.Context, _ service.SchedulerBucket) ([]*service.Account, bool, error) { |
| return f.accounts, true, nil |
| } |
| func (f *fakeSchedulerCache) SetSnapshot(_ context.Context, _ service.SchedulerBucket, _ []service.Account) error { |
| return nil |
| } |
| func (f *fakeSchedulerCache) GetAccount(_ context.Context, _ int64) (*service.Account, error) { |
| return nil, nil |
| } |
| func (f *fakeSchedulerCache) SetAccount(_ context.Context, _ *service.Account) error { return nil } |
| func (f *fakeSchedulerCache) DeleteAccount(_ context.Context, _ int64) error { return nil } |
| func (f *fakeSchedulerCache) UpdateLastUsed(_ context.Context, _ map[int64]time.Time) error { |
| return nil |
| } |
| func (f *fakeSchedulerCache) TryLockBucket(_ context.Context, _ service.SchedulerBucket, _ time.Duration) (bool, error) { |
| return true, nil |
| } |
| func (f *fakeSchedulerCache) ListBuckets(_ context.Context) ([]service.SchedulerBucket, error) { |
| return nil, nil |
| } |
| func (f *fakeSchedulerCache) GetOutboxWatermark(_ context.Context) (int64, error) { return 0, nil } |
| func (f *fakeSchedulerCache) SetOutboxWatermark(_ context.Context, _ int64) error { return nil } |
|
|
| type fakeGroupRepo struct { |
| group *service.Group |
| } |
|
|
| func (f *fakeGroupRepo) Create(context.Context, *service.Group) error { return nil } |
| func (f *fakeGroupRepo) GetByID(context.Context, int64) (*service.Group, error) { |
| return f.group, nil |
| } |
| func (f *fakeGroupRepo) GetByIDLite(context.Context, int64) (*service.Group, error) { |
| return f.group, nil |
| } |
| func (f *fakeGroupRepo) Update(context.Context, *service.Group) error { return nil } |
| func (f *fakeGroupRepo) Delete(context.Context, int64) error { return nil } |
| func (f *fakeGroupRepo) DeleteCascade(context.Context, int64) ([]int64, error) { return nil, nil } |
| func (f *fakeGroupRepo) List(context.Context, pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { |
| return nil, nil, nil |
| } |
| func (f *fakeGroupRepo) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]service.Group, *pagination.PaginationResult, error) { |
| return nil, nil, nil |
| } |
| func (f *fakeGroupRepo) ListActive(context.Context) ([]service.Group, error) { return nil, nil } |
| func (f *fakeGroupRepo) ListActiveByPlatform(context.Context, string) ([]service.Group, error) { |
| return nil, nil |
| } |
| func (f *fakeGroupRepo) ExistsByName(context.Context, string) (bool, error) { return false, nil } |
| func (f *fakeGroupRepo) GetAccountCount(context.Context, int64) (int64, int64, error) { return 0, 0, nil } |
| func (f *fakeGroupRepo) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { |
| return 0, nil |
| } |
| func (f *fakeGroupRepo) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { |
| return nil, nil |
| } |
| func (f *fakeGroupRepo) BindAccountsToGroup(context.Context, int64, []int64) error { return nil } |
| func (f *fakeGroupRepo) UpdateSortOrders(context.Context, []service.GroupSortOrderUpdate) error { |
| return nil |
| } |
|
|
| type fakeConcurrencyCache struct{} |
|
|
| func (f *fakeConcurrencyCache) AcquireAccountSlot(context.Context, int64, int, string) (bool, error) { |
| return true, nil |
| } |
| func (f *fakeConcurrencyCache) ReleaseAccountSlot(context.Context, int64, string) error { return nil } |
| func (f *fakeConcurrencyCache) GetAccountConcurrency(context.Context, int64) (int, error) { |
| return 0, nil |
| } |
| func (f *fakeConcurrencyCache) IncrementAccountWaitCount(context.Context, int64, int) (bool, error) { |
| return true, nil |
| } |
| func (f *fakeConcurrencyCache) DecrementAccountWaitCount(context.Context, int64) error { return nil } |
| func (f *fakeConcurrencyCache) GetAccountWaitingCount(context.Context, int64) (int, error) { |
| return 0, nil |
| } |
| func (f *fakeConcurrencyCache) AcquireUserSlot(context.Context, int64, int, string) (bool, error) { |
| return true, nil |
| } |
| func (f *fakeConcurrencyCache) ReleaseUserSlot(context.Context, int64, string) error { return nil } |
| func (f *fakeConcurrencyCache) GetUserConcurrency(context.Context, int64) (int, error) { return 0, nil } |
| func (f *fakeConcurrencyCache) IncrementWaitCount(context.Context, int64, int) (bool, error) { |
| return true, nil |
| } |
| func (f *fakeConcurrencyCache) DecrementWaitCount(context.Context, int64) error { return nil } |
| func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { |
| return map[int64]*service.AccountLoadInfo{}, nil |
| } |
| func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { |
| return map[int64]*service.UserLoadInfo{}, nil |
| } |
| func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { |
| result := make(map[int64]int, len(accountIDs)) |
| for _, id := range accountIDs { |
| result[id] = 0 |
| } |
| return result, nil |
| } |
| func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } |
| func (f *fakeConcurrencyCache) CleanupStaleProcessSlots(context.Context, string) error { return nil } |
|
|
| func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { |
| t.Helper() |
|
|
| schedulerCache := &fakeSchedulerCache{accounts: accounts} |
| schedulerSnapshot := service.NewSchedulerSnapshotService(schedulerCache, nil, nil, nil, nil) |
|
|
| gwSvc := service.NewGatewayService( |
| nil, |
| &fakeGroupRepo{group: group}, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| schedulerSnapshot, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| nil, |
| ) |
|
|
| |
| cfg := &config.Config{RunMode: config.RunModeSimple} |
| billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) |
|
|
| concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) |
| concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) |
|
|
| h := &GatewayHandler{ |
| gatewayService: gwSvc, |
| billingCacheService: billingCacheSvc, |
| concurrencyHelper: concurrencyHelper, |
| |
| maxAccountSwitches: 1, |
| maxAccountSwitchesGemini: 1, |
| } |
|
|
| cleanup := func() { |
| billingCacheSvc.Stop() |
| } |
| return h, cleanup |
| } |
|
|
| func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_MixedSchedulingV1(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| groupID := int64(2001) |
| accountID := int64(1001) |
|
|
| group := &service.Group{ |
| ID: groupID, |
| Hydrated: true, |
| Platform: service.PlatformAnthropic, |
| Status: service.StatusActive, |
| } |
|
|
| account := &service.Account{ |
| ID: accountID, |
| Name: "ag-1", |
| Platform: service.PlatformAntigravity, |
| Type: service.AccountTypeOAuth, |
| Credentials: map[string]any{ |
| "access_token": "tok_xxx", |
| "intercept_warmup_requests": true, |
| }, |
| Extra: map[string]any{ |
| "mixed_scheduling": true, |
| }, |
| Concurrency: 1, |
| Priority: 1, |
| Status: service.StatusActive, |
| Schedulable: true, |
| AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}}, |
| } |
|
|
| h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account}) |
| defer cleanup() |
|
|
| rec := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(rec) |
|
|
| body := []byte(`{ |
| "model": "claude-sonnet-4-5", |
| "max_tokens": 256, |
| "messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}] |
| }`) |
| req := httptest.NewRequest("POST", "/v1/messages", bytes.NewReader(body)) |
| req.Header.Set("Content-Type", "application/json") |
| req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, group)) |
| c.Request = req |
|
|
| apiKey := &service.APIKey{ |
| ID: 3001, |
| UserID: 4001, |
| GroupID: &groupID, |
| Status: service.StatusActive, |
| User: &service.User{ |
| ID: 4001, |
| Concurrency: 10, |
| Balance: 100, |
| }, |
| Group: group, |
| } |
|
|
| c.Set(string(middleware.ContextKeyAPIKey), apiKey) |
| c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10}) |
|
|
| h.Messages(c) |
|
|
| require.Equal(t, 200, rec.Code) |
|
|
| |
| selected, ok := c.Get(opsAccountIDKey) |
| require.True(t, ok) |
| require.Equal(t, accountID, selected) |
|
|
| var resp map[string]any |
| require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) |
| require.Equal(t, "msg_mock_warmup", resp["id"]) |
| require.Equal(t, "claude-sonnet-4-5", resp["model"]) |
|
|
| content, ok := resp["content"].([]any) |
| require.True(t, ok) |
| require.Len(t, content, 1) |
| first, ok := content[0].(map[string]any) |
| require.True(t, ok) |
| require.Equal(t, "New Conversation", first["text"]) |
| } |
|
|
| func TestGatewayHandlerMessages_InterceptWarmup_AntigravityAccount_ForcePlatform(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
|
|
| groupID := int64(2002) |
| accountID := int64(1002) |
|
|
| group := &service.Group{ |
| ID: groupID, |
| Hydrated: true, |
| Platform: service.PlatformAntigravity, |
| Status: service.StatusActive, |
| } |
|
|
| account := &service.Account{ |
| ID: accountID, |
| Name: "ag-2", |
| Platform: service.PlatformAntigravity, |
| Type: service.AccountTypeOAuth, |
| Credentials: map[string]any{ |
| "access_token": "tok_xxx", |
| "intercept_warmup_requests": true, |
| }, |
| Concurrency: 1, |
| Priority: 1, |
| Status: service.StatusActive, |
| Schedulable: true, |
| AccountGroups: []service.AccountGroup{{AccountID: accountID, GroupID: groupID}}, |
| } |
|
|
| h, cleanup := newTestGatewayHandler(t, group, []*service.Account{account}) |
| defer cleanup() |
|
|
| rec := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(rec) |
|
|
| body := []byte(`{ |
| "model": "claude-sonnet-4-5", |
| "max_tokens": 256, |
| "messages": [{"role":"user","content":[{"type":"text","text":"Warmup"}]}] |
| }`) |
| req := httptest.NewRequest("POST", "/antigravity/v1/messages", bytes.NewReader(body)) |
| req.Header.Set("Content-Type", "application/json") |
|
|
| |
| |
| |
| ctx := context.WithValue(req.Context(), ctxkey.Group, group) |
| ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformAntigravity) |
| req = req.WithContext(ctx) |
| c.Request = req |
| c.Set(string(middleware.ContextKeyForcePlatform), service.PlatformAntigravity) |
|
|
| apiKey := &service.APIKey{ |
| ID: 3002, |
| UserID: 4002, |
| GroupID: &groupID, |
| Status: service.StatusActive, |
| User: &service.User{ |
| ID: 4002, |
| Concurrency: 10, |
| Balance: 100, |
| }, |
| Group: group, |
| } |
|
|
| c.Set(string(middleware.ContextKeyAPIKey), apiKey) |
| c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: apiKey.UserID, Concurrency: 10}) |
|
|
| h.Messages(c) |
|
|
| require.Equal(t, 200, rec.Code) |
|
|
| selected, ok := c.Get(opsAccountIDKey) |
| require.True(t, ok) |
| require.Equal(t, accountID, selected) |
|
|
| var resp map[string]any |
| require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) |
| require.Equal(t, "msg_mock_warmup", resp["id"]) |
| require.Equal(t, "claude-sonnet-4-5", resp["model"]) |
| } |
|
|