| package channel |
|
|
| import ( |
| "net/http" |
| "net/http/httptest" |
| "testing" |
|
|
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| "github.com/gin-gonic/gin" |
| "github.com/stretchr/testify/require" |
| ) |
|
|
| func TestProcessHeaderOverride_ChannelTestSkipsPassthroughRules(t *testing.T) { |
| t.Parallel() |
|
|
| gin.SetMode(gin.TestMode) |
| recorder := httptest.NewRecorder() |
| ctx, _ := gin.CreateTestContext(recorder) |
| ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) |
| ctx.Request.Header.Set("X-Trace-Id", "trace-123") |
|
|
| info := &relaycommon.RelayInfo{ |
| IsChannelTest: true, |
| ChannelMeta: &relaycommon.ChannelMeta{ |
| HeadersOverride: map[string]any{ |
| "*": "", |
| }, |
| }, |
| } |
|
|
| headers, err := processHeaderOverride(info, ctx) |
| require.NoError(t, err) |
| require.Empty(t, headers) |
| } |
|
|
| func TestProcessHeaderOverride_ChannelTestSkipsClientHeaderPlaceholder(t *testing.T) { |
| t.Parallel() |
|
|
| gin.SetMode(gin.TestMode) |
| recorder := httptest.NewRecorder() |
| ctx, _ := gin.CreateTestContext(recorder) |
| ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) |
| ctx.Request.Header.Set("X-Trace-Id", "trace-123") |
|
|
| info := &relaycommon.RelayInfo{ |
| IsChannelTest: true, |
| ChannelMeta: &relaycommon.ChannelMeta{ |
| HeadersOverride: map[string]any{ |
| "X-Upstream-Trace": "{client_header:X-Trace-Id}", |
| }, |
| }, |
| } |
|
|
| headers, err := processHeaderOverride(info, ctx) |
| require.NoError(t, err) |
| _, ok := headers["x-upstream-trace"] |
| require.False(t, ok) |
| } |
|
|
| func TestProcessHeaderOverride_NonTestKeepsClientHeaderPlaceholder(t *testing.T) { |
| t.Parallel() |
|
|
| gin.SetMode(gin.TestMode) |
| recorder := httptest.NewRecorder() |
| ctx, _ := gin.CreateTestContext(recorder) |
| ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) |
| ctx.Request.Header.Set("X-Trace-Id", "trace-123") |
|
|
| info := &relaycommon.RelayInfo{ |
| IsChannelTest: false, |
| ChannelMeta: &relaycommon.ChannelMeta{ |
| HeadersOverride: map[string]any{ |
| "X-Upstream-Trace": "{client_header:X-Trace-Id}", |
| }, |
| }, |
| } |
|
|
| headers, err := processHeaderOverride(info, ctx) |
| require.NoError(t, err) |
| require.Equal(t, "trace-123", headers["x-upstream-trace"]) |
| } |
|
|
| func TestProcessHeaderOverride_RuntimeOverrideIsFinalHeaderMap(t *testing.T) { |
| t.Parallel() |
|
|
| gin.SetMode(gin.TestMode) |
| recorder := httptest.NewRecorder() |
| ctx, _ := gin.CreateTestContext(recorder) |
| ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) |
|
|
| info := &relaycommon.RelayInfo{ |
| IsChannelTest: false, |
| UseRuntimeHeadersOverride: true, |
| RuntimeHeadersOverride: map[string]any{ |
| "x-static": "runtime-value", |
| "x-runtime": "runtime-only", |
| }, |
| ChannelMeta: &relaycommon.ChannelMeta{ |
| HeadersOverride: map[string]any{ |
| "X-Static": "legacy-value", |
| "X-Legacy": "legacy-only", |
| }, |
| }, |
| } |
|
|
| headers, err := processHeaderOverride(info, ctx) |
| require.NoError(t, err) |
| require.Equal(t, "runtime-value", headers["x-static"]) |
| require.Equal(t, "runtime-only", headers["x-runtime"]) |
| _, exists := headers["x-legacy"] |
| require.False(t, exists) |
| } |
|
|
| func TestProcessHeaderOverride_PassthroughSkipsAcceptEncoding(t *testing.T) { |
| t.Parallel() |
|
|
| gin.SetMode(gin.TestMode) |
| recorder := httptest.NewRecorder() |
| ctx, _ := gin.CreateTestContext(recorder) |
| ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) |
| ctx.Request.Header.Set("X-Trace-Id", "trace-123") |
| ctx.Request.Header.Set("Accept-Encoding", "gzip") |
|
|
| info := &relaycommon.RelayInfo{ |
| IsChannelTest: false, |
| ChannelMeta: &relaycommon.ChannelMeta{ |
| HeadersOverride: map[string]any{ |
| "*": "", |
| }, |
| }, |
| } |
|
|
| headers, err := processHeaderOverride(info, ctx) |
| require.NoError(t, err) |
| require.Equal(t, "trace-123", headers["x-trace-id"]) |
|
|
| _, hasAcceptEncoding := headers["accept-encoding"] |
| require.False(t, hasAcceptEncoding) |
| } |
|
|
| func TestProcessHeaderOverride_PassHeadersTemplateSetsRuntimeHeaders(t *testing.T) { |
| t.Parallel() |
|
|
| gin.SetMode(gin.TestMode) |
| recorder := httptest.NewRecorder() |
| ctx, _ := gin.CreateTestContext(recorder) |
| ctx.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) |
| ctx.Request.Header.Set("Originator", "Codex CLI") |
| ctx.Request.Header.Set("Session_id", "sess-123") |
|
|
| info := &relaycommon.RelayInfo{ |
| IsChannelTest: false, |
| RequestHeaders: map[string]string{ |
| "Originator": "Codex CLI", |
| "Session_id": "sess-123", |
| }, |
| ChannelMeta: &relaycommon.ChannelMeta{ |
| ParamOverride: map[string]any{ |
| "operations": []any{ |
| map[string]any{ |
| "mode": "pass_headers", |
| "value": []any{"Originator", "Session_id", "X-Codex-Beta-Features"}, |
| }, |
| }, |
| }, |
| HeadersOverride: map[string]any{ |
| "X-Static": "legacy-value", |
| }, |
| }, |
| } |
|
|
| _, err := relaycommon.ApplyParamOverrideWithRelayInfo([]byte(`{"model":"gpt-4.1"}`), info) |
| require.NoError(t, err) |
| require.True(t, info.UseRuntimeHeadersOverride) |
| require.Equal(t, "Codex CLI", info.RuntimeHeadersOverride["originator"]) |
| require.Equal(t, "sess-123", info.RuntimeHeadersOverride["session_id"]) |
| _, exists := info.RuntimeHeadersOverride["x-codex-beta-features"] |
| require.False(t, exists) |
| require.Equal(t, "legacy-value", info.RuntimeHeadersOverride["x-static"]) |
|
|
| headers, err := processHeaderOverride(info, ctx) |
| require.NoError(t, err) |
| require.Equal(t, "Codex CLI", headers["originator"]) |
| require.Equal(t, "sess-123", headers["session_id"]) |
| _, exists = headers["x-codex-beta-features"] |
| require.False(t, exists) |
|
|
| upstreamReq := httptest.NewRequest(http.MethodPost, "https://example.com/v1/responses", nil) |
| applyHeaderOverrideToRequest(upstreamReq, headers) |
| require.Equal(t, "Codex CLI", upstreamReq.Header.Get("Originator")) |
| require.Equal(t, "sess-123", upstreamReq.Header.Get("Session_id")) |
| require.Empty(t, upstreamReq.Header.Get("X-Codex-Beta-Features")) |
| } |
|
|