| package handler |
|
|
| import ( |
| "net/http" |
| "net/http/httptest" |
| "strings" |
| "testing" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/service" |
| "github.com/gin-gonic/gin" |
| "github.com/stretchr/testify/assert" |
| "github.com/stretchr/testify/require" |
| ) |
|
|
| |
| const partialMessageStartSSE = "event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_01\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-5\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"output_tokens\":1}}}\n\n" + |
| "event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n" |
|
|
| |
| |
| |
| |
| |
| |
| |
| func TestStreamWrittenGuard_MessagesPath_AbortFailoverOnSSEContentWritten(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) |
|
|
| |
| sizeBeforeForward := c.Writer.Size() |
| require.Equal(t, -1, sizeBeforeForward, "gin writer 初始 Size 应为 -1(未写入任何字节)") |
|
|
| |
| _, err := c.Writer.Write([]byte(partialMessageStartSSE)) |
| require.NoError(t, err) |
|
|
| |
| require.NotEqual(t, sizeBeforeForward, c.Writer.Size(), |
| "写入 SSE 内容后 writer size 必须增加,守卫条件应为 true") |
|
|
| |
| failoverErr := &service.UpstreamFailoverError{ |
| StatusCode: http.StatusForbidden, |
| ResponseBody: []byte(`{"error":{"type":"permission_error","message":"forbidden"}}`), |
| } |
|
|
| |
| h := &GatewayHandler{} |
| h.handleFailoverExhausted(c, failoverErr, service.PlatformAnthropic, true) |
|
|
| body := w.Body.String() |
|
|
| |
| require.Contains(t, body, "event: message_start", "响应体应包含已写入的 message_start SSE 事件") |
|
|
| |
| require.True(t, strings.HasSuffix(strings.TrimRight(body, "\n"), "}"), |
| "响应体应以 JSON 对象结尾(SSE error event 的 data 字段)") |
| require.Contains(t, body, `"type":"error"`, "响应体末尾必须包含 SSE 错误事件") |
|
|
| |
| firstIdx := strings.Index(body, "event: message_start") |
| lastIdx := strings.LastIndex(body, "event: message_start") |
| assert.Equal(t, firstIdx, lastIdx, |
| "响应体中 'event: message_start' 必须只出现一次,不得因 failover 拼接导致两次") |
| } |
|
|
| |
| |
| func TestStreamWrittenGuard_GeminiPath_AbortFailoverOnSSEContentWritten(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.0-flash:streamGenerateContent", nil) |
|
|
| sizeBeforeForward := c.Writer.Size() |
|
|
| _, err := c.Writer.Write([]byte(partialMessageStartSSE)) |
| require.NoError(t, err) |
|
|
| require.NotEqual(t, sizeBeforeForward, c.Writer.Size()) |
|
|
| failoverErr := &service.UpstreamFailoverError{ |
| StatusCode: http.StatusForbidden, |
| } |
|
|
| h := &GatewayHandler{} |
| h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, true) |
|
|
| body := w.Body.String() |
|
|
| require.Contains(t, body, "event: message_start") |
| require.Contains(t, body, `"type":"error"`) |
|
|
| firstIdx := strings.Index(body, "event: message_start") |
| lastIdx := strings.LastIndex(body, "event: message_start") |
| assert.Equal(t, firstIdx, lastIdx, "Gemini 路径不得出现双 message_start") |
| } |
|
|
| |
| |
| |
| func TestStreamWrittenGuard_NoByteWritten_GuardNotTriggered(t *testing.T) { |
| gin.SetMode(gin.TestMode) |
| w := httptest.NewRecorder() |
| c, _ := gin.CreateTestContext(w) |
| c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) |
|
|
| |
| sizeBeforeForward := c.Writer.Size() |
|
|
| |
| |
|
|
| |
| guardTriggered := c.Writer.Size() != sizeBeforeForward |
| require.False(t, guardTriggered, |
| "未写入任何字节时,守卫条件必须为 false,应允许正常 failover 继续") |
| } |
|
|