| package channel |
|
|
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "io" |
| "net/http" |
| "regexp" |
| "strings" |
| "sync" |
| "time" |
|
|
| common2 "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/logger" |
| "github.com/QuantumNous/new-api/relay/common" |
| "github.com/QuantumNous/new-api/relay/constant" |
| "github.com/QuantumNous/new-api/relay/helper" |
| "github.com/QuantumNous/new-api/service" |
| "github.com/QuantumNous/new-api/setting/operation_setting" |
| "github.com/QuantumNous/new-api/types" |
|
|
| "github.com/bytedance/gopkg/util/gopool" |
| "github.com/gin-gonic/gin" |
| "github.com/gorilla/websocket" |
| ) |
|
|
// SetupApiRequestHeader copies the client's Content-Type and Accept headers
// onto the upstream header set req. For streaming relays whose client sent no
// Accept header, Accept is set to "text/event-stream" instead.
//
// Audio transcription/translation and realtime relay modes are left untouched;
// presumably their adaptors build those headers themselves (TODO confirm).
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
	switch info.RelayMode {
	case constant.RelayModeAudioTranscription, constant.RelayModeAudioTranslation, constant.RelayModeRealtime:
		return
	}

	acceptValue := c.Request.Header.Get("Accept")
	if info.IsStream && acceptValue == "" {
		acceptValue = "text/event-stream"
	}
	req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	req.Set("Accept", acceptValue)
}
|
|
// Tokens understood in a channel's header-override configuration.
const (
	// clientHeaderPlaceholderPrefix opens a "{client_header:Name}" value, which
	// is replaced by the client's Name request header.
	clientHeaderPlaceholderPrefix = "{client_header:"

	// headerPassthroughAllKey, used as an override key, asks for every client
	// header to be copied upstream (minus the passthrough skip list).
	headerPassthroughAllKey = "*"

	// headerPassthroughRegexPrefix and headerPassthroughRegexPrefixV2 mark an
	// override key as a regex passthrough rule; the rest of the key is the
	// pattern.
	headerPassthroughRegexPrefix   = "re:"
	headerPassthroughRegexPrefixV2 = "regex:"
)
|
|
// passthroughSkipHeaderNamesLower is the set of lower-cased header names that
// shouldSkipPassthroughHeader rejects, so header passthrough ("*" and regex
// rules) never copies them from the client request even when a rule matches.
var passthroughSkipHeaderNamesLower = map[string]struct{}{
	// Hop-by-hop headers: they describe the client's connection to this
	// service, not the upstream connection.
	"connection":          {},
	"keep-alive":          {},
	"proxy-authenticate":  {},
	"proxy-authorization": {},
	"te":                  {},
	"trailer":             {},
	"transfer-encoding":   {},
	"upgrade":             {},

	// The client's cookies belong to this service, not to the upstream.
	"cookie": {},

	// Derived by net/http for the outgoing request (host from the URL, length
	// from the body, gzip negotiation by the transport).
	"host":            {},
	"content-length":  {},
	"accept-encoding": {},

	// Client credentials for this service; upstream credentials are the
	// adaptor's responsibility.
	"authorization":  {},
	"x-api-key":      {},
	"x-goog-api-key": {},

	// WebSocket handshake headers. gorilla/websocket's Dialer generates several
	// of these itself and rejects some when supplied. Sec-WebSocket-Protocol is
	// included too: clients may use it to carry their credential (e.g. the
	// "openai-insecure-api-key.<key>" subprotocol), and passing it through would
	// leak that credential upstream and overwrite any value the adaptor set.
	"sec-websocket-key":        {},
	"sec-websocket-version":    {},
	"sec-websocket-extensions": {},
	"sec-websocket-protocol":   {},
}
|
|
| var headerPassthroughRegexCache sync.Map |
|
|
| func getHeaderPassthroughRegex(pattern string) (*regexp.Regexp, error) { |
| pattern = strings.TrimSpace(pattern) |
| if pattern == "" { |
| return nil, errors.New("empty regex pattern") |
| } |
| if v, ok := headerPassthroughRegexCache.Load(pattern); ok { |
| if re, ok := v.(*regexp.Regexp); ok { |
| return re, nil |
| } |
| headerPassthroughRegexCache.Delete(pattern) |
| } |
| compiled, err := regexp.Compile(pattern) |
| if err != nil { |
| return nil, err |
| } |
| actual, _ := headerPassthroughRegexCache.LoadOrStore(pattern, compiled) |
| if re, ok := actual.(*regexp.Regexp); ok { |
| return re, nil |
| } |
| return compiled, nil |
| } |
|
|
// IsHeaderPassthroughRuleKey reports whether key in a header-override map is a
// passthrough rule rather than an explicit header name. It is the exported form
// of isHeaderPassthroughRuleKey.
func IsHeaderPassthroughRuleKey(key string) bool {
	return isHeaderPassthroughRuleKey(key)
}

// isHeaderPassthroughRuleKey reports whether key, ignoring surrounding
// whitespace, is exactly "*" or starts (case-insensitively) with "re:" or
// "regex:". Blank keys are never rules.
func isHeaderPassthroughRuleKey(key string) bool {
	candidate := strings.TrimSpace(key)
	switch candidate {
	case "":
		return false
	case headerPassthroughAllKey:
		return true
	}

	folded := strings.ToLower(candidate)
	for _, prefix := range [...]string{headerPassthroughRegexPrefix, headerPassthroughRegexPrefixV2} {
		if strings.HasPrefix(folded, prefix) {
			return true
		}
	}
	return false
}
|
|
// shouldSkipPassthroughHeader reports whether a client header named name must
// not be copied by header passthrough: blank names are always skipped, and
// otherwise the name is looked up case-insensitively in
// passthroughSkipHeaderNamesLower.
func shouldSkipPassthroughHeader(name string) bool {
	normalized := strings.TrimSpace(name)
	if normalized == "" {
		return true
	}
	_, blocked := passthroughSkipHeaderNamesLower[strings.ToLower(normalized)]
	return blocked
}
|
|
// applyHeaderOverridePlaceholders expands a single header-override value.
//
// Two forms are supported:
//   - "{client_header:Name}" (surrounding whitespace allowed) must make up the
//     whole value. It is replaced by the client's Name request header; when
//     that header is missing or blank the entry is dropped.
//   - Any other value has every "{api_key}" replaced by apiKey. A value that
//     is blank after substitution is dropped.
//
// It returns the expanded value, whether the header should be set at all, and
// an error when a client_header placeholder is malformed, names no header, or
// cannot be resolved because c carries no request.
func applyHeaderOverridePlaceholders(template string, c *gin.Context, apiKey string) (string, bool, error) {
	trimmed := strings.TrimSpace(template)

	if !strings.HasPrefix(trimmed, clientHeaderPlaceholderPrefix) {
		// ReplaceAll is a no-op when the token is absent.
		expanded := strings.ReplaceAll(template, "{api_key}", apiKey)
		if strings.TrimSpace(expanded) == "" {
			return "", false, nil
		}
		return expanded, true, nil
	}

	inner := trimmed[len(clientHeaderPlaceholderPrefix):]
	// The first '}' must also be the last character: nothing may follow the
	// placeholder.
	if closing := strings.Index(inner, "}"); closing < 0 || closing+1 != len(inner) {
		return "", false, fmt.Errorf("client_header placeholder must be the full value: %q", template)
	}

	headerName := strings.TrimSpace(strings.TrimSuffix(inner, "}"))
	if headerName == "" {
		return "", false, fmt.Errorf("client_header placeholder name is empty: %q", template)
	}
	if c == nil || c.Request == nil {
		return "", false, fmt.Errorf("missing request context for client_header placeholder")
	}

	forwarded := c.Request.Header.Get(headerName)
	if strings.TrimSpace(forwarded) == "" {
		return "", false, nil
	}
	return forwarded, true, nil
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
// processHeaderOverride resolves the channel's header-override configuration
// (common.GetEffectiveHeaderOverride) into the concrete headers to set on the
// upstream request. A nil info yields an empty map.
//
// Configuration keys come in three kinds:
//   - "*": copy every non-blank client header that shouldSkipPassthroughHeader
//     does not reject.
//   - "re:<pattern>" / "regex:<pattern>" (prefix matched case-insensitively):
//     copy client headers whose name matches <pattern>. The pattern keeps its
//     original case and is matched case-insensitively ("(?i)"), because header
//     names are case-insensitive and Go stores incoming ones in canonical form
//     (e.g. "X-Foo-Bar").
//   - anything else: an explicit header whose value must be a string. The value
//     is expanded by applyHeaderOverridePlaceholders, and entries that expand
//     to nothing are dropped.
//
// For channel tests (info.IsChannelTest), passthrough rules and
// {client_header:...} values are ignored. Explicit entries are written after
// passthrough copies and therefore win on conflict. Returned keys are
// lower-cased.
//
// Errors use types.ErrorCodeChannelHeaderOverrideInvalid. They cover an empty
// or invalid regex, a non-string explicit value, a malformed placeholder, and
// a passthrough rule when c has no request.
func processHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) {
	headerOverride := make(map[string]string)
	if info == nil {
		return headerOverride, nil
	}

	headerOverrideSource := common.GetEffectiveHeaderOverride(info)

	// cutPrefixFold removes prefix from s ignoring case and returns the rest of
	// s with its original case preserved.
	cutPrefixFold := func(s, prefix string) (string, bool) {
		if len(s) >= len(prefix) && strings.EqualFold(s[:len(prefix)], prefix) {
			return s[len(prefix):], true
		}
		return "", false
	}

	// Collect passthrough rules ("*" and regex keys); not used for channel tests.
	passAll := false
	var passthroughRegex []*regexp.Regexp
	if !info.IsChannelTest {
		for k := range headerOverrideSource {
			key := strings.TrimSpace(k)
			if key == "" {
				continue
			}
			if key == headerPassthroughAllKey {
				passAll = true
				continue
			}

			rest, isRule := cutPrefixFold(key, headerPassthroughRegexPrefix)
			if !isRule {
				rest, isRule = cutPrefixFold(key, headerPassthroughRegexPrefixV2)
			}
			if !isRule {
				continue
			}

			// Do not lower-case the pattern. Doing so makes it miss canonical
			// header names and silently turns escapes such as \D, \S, \W into
			// \d, \s, \w. Case-insensitivity comes from the (?i) flag instead.
			pattern := strings.TrimSpace(rest)
			if pattern == "" {
				return nil, types.NewError(fmt.Errorf("header passthrough regex pattern is empty: %q", k), types.ErrorCodeChannelHeaderOverrideInvalid)
			}
			compiled, err := getHeaderPassthroughRegex("(?i)" + pattern)
			if err != nil {
				return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
			}
			passthroughRegex = append(passthroughRegex, compiled)
		}
	}

	// Copy matching client headers (first value only), keyed lower-case.
	if passAll || len(passthroughRegex) > 0 {
		if c == nil || c.Request == nil {
			return nil, types.NewError(fmt.Errorf("missing request context for header passthrough"), types.ErrorCodeChannelHeaderOverrideInvalid)
		}
		for name := range c.Request.Header {
			if shouldSkipPassthroughHeader(name) {
				continue
			}
			if !passAll {
				matched := false
				for _, re := range passthroughRegex {
					if re.MatchString(name) {
						matched = true
						break
					}
				}
				if !matched {
					continue
				}
			}
			value := strings.TrimSpace(c.Request.Header.Get(name))
			if value == "" {
				continue
			}
			headerOverride[strings.ToLower(strings.TrimSpace(name))] = value
		}
	}

	// Apply explicit headers last so they override passthrough copies.
	for k, v := range headerOverrideSource {
		if isHeaderPassthroughRuleKey(k) {
			continue
		}
		key := strings.TrimSpace(strings.ToLower(k))
		if key == "" {
			continue
		}

		str, ok := v.(string)
		if !ok {
			return nil, types.NewError(fmt.Errorf("header override value for %q must be a string, got %T", k, v), types.ErrorCodeChannelHeaderOverrideInvalid)
		}
		// A channel test has no real client request to read headers from.
		if info.IsChannelTest && strings.HasPrefix(strings.TrimSpace(str), clientHeaderPlaceholderPrefix) {
			continue
		}

		value, include, err := applyHeaderOverridePlaceholders(str, c, info.ApiKey)
		if err != nil {
			return nil, types.NewError(err, types.ErrorCodeChannelHeaderOverrideInvalid)
		}
		if !include {
			continue
		}

		headerOverride[key] = value
	}
	return headerOverride, nil
}
|
|
| func ResolveHeaderOverride(info *common.RelayInfo, c *gin.Context) (map[string]string, error) { |
| return processHeaderOverride(info, c) |
| } |
|
|
// applyHeaderOverrideToRequest sets every entry of headerOverride on req's
// headers. A "Host" entry (any case) is also copied to req.Host, because
// net/http sends req.Host rather than the Host entry of the header map.
// A nil req is ignored.
func applyHeaderOverrideToRequest(req *http.Request, headerOverride map[string]string) {
	if req != nil {
		for name, val := range headerOverride {
			if strings.EqualFold(name, "Host") {
				req.Host = val
			}
			req.Header.Set(name, val)
		}
	}
}
|
|
// DoApiRequest sends requestBody upstream on behalf of adaptor a and returns
// the raw upstream response.
//
// The target URL comes from a.GetRequestURL and the method from the client
// request. Headers are first populated by a.SetupRequestHeader and then
// overlaid with the channel's header override. Errors from URL building,
// request creation, header setup and sending are wrapped with context;
// header-override errors are returned as-is.
func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	targetURL, err := a.GetRequestURL(info)
	if err != nil {
		return nil, fmt.Errorf("get request url failed: %w", err)
	}
	if common2.DebugEnabled {
		println("fullRequestURL:", targetURL)
	}

	upstreamReq, err := http.NewRequest(c.Request.Method, targetURL, requestBody)
	if err != nil {
		return nil, fmt.Errorf("new request failed: %w", err)
	}

	outHeader := upstreamReq.Header
	if err = a.SetupRequestHeader(c, &outHeader, info); err != nil {
		return nil, fmt.Errorf("setup request header failed: %w", err)
	}

	overrides, err := processHeaderOverride(info, c)
	if err != nil {
		return nil, err
	}
	applyHeaderOverrideToRequest(upstreamReq, overrides)

	upstreamResp, err := doRequest(c, upstreamReq, info)
	if err != nil {
		return nil, fmt.Errorf("do request failed: %w", err)
	}
	return upstreamResp, nil
}
|
|
// DoFormRequest is DoApiRequest for form bodies. Before the adaptor's header
// setup runs, it copies the client's Content-Type onto the upstream request
// verbatim; for a multipart body that value includes the boundary parameter.
// It then applies the channel header override and sends the request,
// returning the raw upstream response.
func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	targetURL, err := a.GetRequestURL(info)
	if err != nil {
		return nil, fmt.Errorf("get request url failed: %w", err)
	}
	if common2.DebugEnabled {
		println("fullRequestURL:", targetURL)
	}

	formReq, err := http.NewRequest(c.Request.Method, targetURL, requestBody)
	if err != nil {
		return nil, fmt.Errorf("new request failed: %w", err)
	}

	outHeader := formReq.Header
	outHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	if err = a.SetupRequestHeader(c, &outHeader, info); err != nil {
		return nil, fmt.Errorf("setup request header failed: %w", err)
	}

	overrides, err := processHeaderOverride(info, c)
	if err != nil {
		return nil, err
	}
	applyHeaderOverrideToRequest(formReq, overrides)

	upstreamResp, err := doRequest(c, formReq, info)
	if err != nil {
		return nil, fmt.Errorf("do request failed: %w", err)
	}
	return upstreamResp, nil
}
|
|
// DoWssRequest opens a WebSocket connection to the URL returned by
// a.GetRequestURL and returns it.
//
// The handshake headers are built by a.SetupRequestHeader, overlaid with the
// channel header override, and finally given the client's Content-Type, which
// therefore wins over any override of that header. requestBody is not used.
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
	wsURL, err := a.GetRequestURL(info)
	if err != nil {
		return nil, fmt.Errorf("get request url failed: %w", err)
	}

	dialHeader := make(http.Header)
	if err = a.SetupRequestHeader(c, &dialHeader, info); err != nil {
		return nil, fmt.Errorf("setup request header failed: %w", err)
	}

	overrides, err := processHeaderOverride(info, c)
	if err != nil {
		return nil, err
	}
	for name, val := range overrides {
		dialHeader.Set(name, val)
	}
	dialHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))

	upstreamConn, _, err := websocket.DefaultDialer.Dial(wsURL, dialHeader)
	if err != nil {
		return nil, fmt.Errorf("dial failed to %s: %w", wsURL, err)
	}
	return upstreamConn, nil
}
|
|
// startPingKeepAlive starts a background task (via gopool) that calls
// sendPingData every pingInterval to keep the client connection alive while
// the upstream request is pending. A pingInterval <= 0 falls back to
// helper.DefaultPingInterval.
//
// The task stops when the returned function is called, when the client
// request context is done, when a ping fails or times out, or after 120
// minutes. Panics inside the task are recovered.
//
// The returned function cancels the task and, if it has started, waits for it
// to exit. Once it returns, this loop is no longer writing pings, so the caller
// can write to c.Writer without racing it. The one exception is a write that
// sendPingData already abandoned on its own timeout. If the task has not been
// scheduled yet, it is not waited for; it sees the cancellation when it starts
// and exits without writing. Calling the function more than once is safe.
func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
	pingerCtx, cancelPinger := context.WithCancel(context.Background())
	started := make(chan struct{})  // closed as soon as the task begins running
	finished := make(chan struct{}) // closed after the task and all its cleanup end

	gopool.Go(func() {
		defer close(finished)
		close(started)

		defer func() {
			// A failing ping must never take down the process.
			if r := recover(); r != nil {
				if common2.DebugEnabled {
					println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
				}
			}
			if common2.DebugEnabled {
				println("SSE ping goroutine stopped.")
			}
		}()

		// Stopped before this task was scheduled: exit without writing.
		if pingerCtx.Err() != nil {
			return
		}

		if pingInterval <= 0 {
			pingInterval = helper.DefaultPingInterval
		}

		ticker := time.NewTicker(pingInterval)
		defer func() {
			ticker.Stop()
			if common2.DebugEnabled {
				println("SSE ping ticker stopped")
			}
		}()

		// Serializes the ping writes issued by this loop.
		var pingMutex sync.Mutex
		if common2.DebugEnabled {
			println("SSE ping goroutine started")
		}

		// Hard upper bound on how long one request is kept alive by pings.
		maxPingDuration := 120 * time.Minute
		pingTimeout := time.NewTimer(maxPingDuration)
		defer pingTimeout.Stop()

		for {
			select {
			case <-ticker.C:
				if err := sendPingData(c, &pingMutex); err != nil {
					if common2.DebugEnabled {
						println("SSE ping error, stopping goroutine:", err.Error())
					}
					return
				}

			case <-pingerCtx.Done():
				return

			case <-c.Request.Context().Done():
				return

			case <-pingTimeout.C:
				if common2.DebugEnabled {
					println("SSE ping goroutine timeout, stopping")
				}
				return
			}
		}
	})

	return func() {
		cancelPinger()
		select {
		case <-started:
			// The task is (or was) running: wait until it is fully stopped.
			// This is bounded by sendPingData's own timeout.
			<-finished
		default:
			// Not scheduled yet; it will observe the cancellation on start.
		}
	}
}
|
|
// sendPingData writes one SSE ping to the client via helper.PingData, holding
// mutex for the duration of the write. The write runs in its own goroutine so
// the caller waits at most 10 seconds, or until the client request context is
// done. Failures are logged with logger.LogError and returned.
//
// On timeout or cancellation the write goroutine is abandoned, not stopped. It
// keeps holding mutex until helper.PingData returns; its result goes to a
// buffered channel, so it does not block forever.
func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
	result := make(chan error, 1)

	go func() {
		mutex.Lock()
		defer mutex.Unlock()

		pingErr := helper.PingData(c)
		if pingErr != nil {
			logger.LogError(c, "SSE ping error: "+pingErr.Error())
		} else if common2.DebugEnabled {
			println("SSE ping data sent.")
		}
		result <- pingErr
	}()

	select {
	case <-c.Request.Context().Done():
		return errors.New("request context cancelled during ping")
	case <-time.After(10 * time.Second):
		return errors.New("SSE ping data send timeout")
	case pingErr := <-result:
		return pingErr
	}
}
|
|
// DoRequest is the exported form of doRequest. It sends req with the channel's
// HTTP client (proxy-aware when a proxy is configured), prepares SSE response
// headers and optional keep-alive pings for streaming relays, and returns the
// upstream response.
func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
	upstreamResp, err := doRequest(c, req, info)
	return upstreamResp, err
}
// doRequest sends req upstream and returns the response.
//
// It uses service.NewProxyHttpClient when info.ChannelSetting.Proxy is set and
// service.GetHttpClient() otherwise. For streaming relays (info.IsStream) it
// first sets SSE response headers on c. If pinging is enabled in the general
// operation settings and not disabled for this relay, it also runs a
// keep-alive pinger that is stopped by a deferred call when this function
// returns.
//
// Transport failures are logged and returned as types.ErrorCodeDoRequestFailed,
// with the upstream message hidden from the client. On success the outgoing
// and incoming request bodies are closed; the caller owns resp.Body.
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
	var client *http.Client
	var err error
	if info.ChannelSetting.Proxy != "" {
		client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
		if err != nil {
			return nil, fmt.Errorf("new proxy http client failed: %w", err)
		}
	} else {
		client = service.GetHttpClient()
	}

	var stopPinger context.CancelFunc
	if info.IsStream {
		helper.SetEventStreamHeaders(c)

		generalSettings := operation_setting.GetGeneralSetting()
		if generalSettings.PingIntervalEnabled && !info.DisablePing {
			pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
			stopPinger = startPingKeepAlive(c, pingInterval)
			// Stop pinging once the upstream response (or error) is in hand.
			defer func() {
				if stopPinger != nil {
					stopPinger()
					if common2.DebugEnabled {
						println("SSE ping goroutine stopped by defer")
					}
				}
			}()
		}
	}

	resp, err := client.Do(req)
	if err != nil {
		logger.LogError(c, "do request failed: "+err.Error())
		return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
	}
	if resp == nil {
		return nil, errors.New("resp is nil")
	}

	// http.Client.Do already closes req.Body; these closes are best-effort
	// cleanup. req.Body is nil for requests built without a body (and
	// c.Request.Body may be nil outside a real server), and calling Close on a
	// nil io.ReadCloser panics, so guard both.
	if req.Body != nil {
		_ = req.Body.Close()
	}
	if c.Request != nil && c.Request.Body != nil {
		_ = c.Request.Body.Close()
	}
	return resp, nil
}
|
|
// DoTaskApiRequest sends requestBody upstream on behalf of task adaptor a and
// returns the raw upstream response.
//
// The URL comes from a.BuildRequestURL (its error is returned unwrapped) and
// the method from the client request. Headers are set by a.BuildRequestHeader.
// The channel header override (processHeaderOverride) is not applied here.
// Sending goes through doRequest.
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
	fullRequestURL, err := a.BuildRequestURL(info)
	if err != nil {
		return nil, err
	}

	// http.NewRequest already installs a correct, replayable GetBody (a fresh
	// reader over a snapshot of the data) for *bytes.Buffer, *bytes.Reader and
	// *strings.Reader bodies. It is deliberately not overridden. A GetBody that
	// returns the same reader the first attempt already drained makes a 307/308
	// redirect or transport retry send an empty or truncated body, and for a
	// nil body the replayed reader would panic when read. For other reader
	// types GetBody stays nil, and net/http simply does not replay the body.
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
	if err != nil {
		return nil, fmt.Errorf("new request failed: %w", err)
	}

	err = a.BuildRequestHeader(c, req, info)
	if err != nil {
		return nil, fmt.Errorf("setup request header failed: %w", err)
	}
	resp, err := doRequest(c, req, info)
	if err != nil {
		return nil, fmt.Errorf("do request failed: %w", err)
	}
	return resp, nil
}
|
|