| package service |
|
|
| import ( |
| "fmt" |
| "net/http/httptest" |
| "testing" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/types" |
| "github.com/gin-gonic/gin" |
| "github.com/stretchr/testify/require" |
| ) |
|
|
| func buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP string) *gin.Context { |
| rec := httptest.NewRecorder() |
| ctx, _ := gin.CreateTestContext(rec) |
| setChannelAffinityContext(ctx, channelAffinityMeta{ |
| CacheKey: fmt.Sprintf("test:%s:%s:%s", ruleName, usingGroup, keyFP), |
| TTLSeconds: 600, |
| RuleName: ruleName, |
| UsingGroup: usingGroup, |
| KeyFingerprint: keyFP, |
| }) |
| return ctx |
| } |
|
|
| func TestObserveChannelAffinityUsageCacheByRelayFormat_ClaudeMode(t *testing.T) { |
| ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) |
| usingGroup := "default" |
| keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) |
| ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) |
|
|
| usage := &dto.Usage{ |
| PromptTokens: 100, |
| CompletionTokens: 40, |
| TotalTokens: 140, |
| PromptTokensDetails: dto.InputTokenDetails{ |
| CachedTokens: 30, |
| }, |
| } |
|
|
| ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatClaude) |
| stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) |
|
|
| require.EqualValues(t, 1, stats.Total) |
| require.EqualValues(t, 1, stats.Hit) |
| require.EqualValues(t, 100, stats.PromptTokens) |
| require.EqualValues(t, 40, stats.CompletionTokens) |
| require.EqualValues(t, 140, stats.TotalTokens) |
| require.EqualValues(t, 30, stats.CachedTokens) |
| require.Equal(t, cacheTokenRateModeCachedOverPromptPlusCached, stats.CachedTokenRateMode) |
| } |
|
|
| func TestObserveChannelAffinityUsageCacheByRelayFormat_MixedMode(t *testing.T) { |
| ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) |
| usingGroup := "default" |
| keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) |
| ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) |
|
|
| openAIUsage := &dto.Usage{ |
| PromptTokens: 100, |
| PromptTokensDetails: dto.InputTokenDetails{ |
| CachedTokens: 10, |
| }, |
| } |
| claudeUsage := &dto.Usage{ |
| PromptTokens: 80, |
| PromptTokensDetails: dto.InputTokenDetails{ |
| CachedTokens: 20, |
| }, |
| } |
|
|
| ObserveChannelAffinityUsageCacheByRelayFormat(ctx, openAIUsage, types.RelayFormatOpenAI) |
| ObserveChannelAffinityUsageCacheByRelayFormat(ctx, claudeUsage, types.RelayFormatClaude) |
| stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) |
|
|
| require.EqualValues(t, 2, stats.Total) |
| require.EqualValues(t, 2, stats.Hit) |
| require.EqualValues(t, 180, stats.PromptTokens) |
| require.EqualValues(t, 30, stats.CachedTokens) |
| require.Equal(t, cacheTokenRateModeMixed, stats.CachedTokenRateMode) |
| } |
|
|
| func TestObserveChannelAffinityUsageCacheByRelayFormat_UnsupportedModeKeepsEmpty(t *testing.T) { |
| ruleName := fmt.Sprintf("rule_%d", time.Now().UnixNano()) |
| usingGroup := "default" |
| keyFP := fmt.Sprintf("fp_%d", time.Now().UnixNano()) |
| ctx := buildChannelAffinityStatsContextForTest(ruleName, usingGroup, keyFP) |
|
|
| usage := &dto.Usage{ |
| PromptTokens: 100, |
| PromptTokensDetails: dto.InputTokenDetails{ |
| CachedTokens: 25, |
| }, |
| } |
|
|
| ObserveChannelAffinityUsageCacheByRelayFormat(ctx, usage, types.RelayFormatGemini) |
| stats := GetChannelAffinityUsageCacheStats(ruleName, usingGroup, keyFP) |
|
|
| require.EqualValues(t, 1, stats.Total) |
| require.EqualValues(t, 1, stats.Hit) |
| require.EqualValues(t, 25, stats.CachedTokens) |
| require.Equal(t, "", stats.CachedTokenRateMode) |
| } |
|
|