| package helper |
|
|
| import ( |
| "bufio" |
| "context" |
| "fmt" |
| "io" |
| "net/http" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/logger" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| "github.com/QuantumNous/new-api/setting/operation_setting" |
|
|
| "github.com/bytedance/gopkg/util/gopool" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
// Tunables for the stream scanner.
const (
	// InitialScannerBufferSize is the size in bytes (64 KiB) of the buffer
	// initially handed to the line scanner.
	InitialScannerBufferSize = 64 * 1024
	// DefaultMaxScannerBufferSize is the upper bound in bytes (64 MiB) for a
	// single scanned line when constant.StreamScannerMaxBufferMB is not set to
	// a positive value.
	DefaultMaxScannerBufferSize = 64 * 1024 * 1024
	// DefaultPingInterval is the ping period used when the configured ping
	// interval is not positive.
	DefaultPingInterval = time.Second * 10
)
|
|
| func getScannerBufferSize() int { |
| if constant.StreamScannerMaxBufferMB > 0 { |
| return constant.StreamScannerMaxBufferMB << 20 |
| } |
| return DefaultMaxScannerBufferSize |
| } |
|
|
| func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, dataHandler func(data string) bool) { |
|
|
| if resp == nil || dataHandler == nil { |
| return |
| } |
|
|
| |
| defer func() { |
| if resp.Body != nil { |
| resp.Body.Close() |
| } |
| }() |
|
|
| streamingTimeout := time.Duration(constant.StreamingTimeout) * time.Second |
|
|
| var ( |
| stopChan = make(chan bool, 3) |
| scanner = bufio.NewScanner(resp.Body) |
| ticker = time.NewTicker(streamingTimeout) |
| pingTicker *time.Ticker |
| writeMutex sync.Mutex |
| wg sync.WaitGroup |
| ) |
|
|
| generalSettings := operation_setting.GetGeneralSetting() |
| pingEnabled := generalSettings.PingIntervalEnabled && !info.DisablePing |
| pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second |
| if pingInterval <= 0 { |
| pingInterval = DefaultPingInterval |
| } |
|
|
| if pingEnabled { |
| pingTicker = time.NewTicker(pingInterval) |
| } |
|
|
| if common.DebugEnabled { |
| |
| println("relay timeout seconds:", common.RelayTimeout) |
| println("relay max idle conns:", common.RelayMaxIdleConns) |
| println("relay max idle conns per host:", common.RelayMaxIdleConnsPerHost) |
| println("streaming timeout seconds:", int64(streamingTimeout.Seconds())) |
| println("ping interval seconds:", int64(pingInterval.Seconds())) |
| } |
|
|
| |
| defer func() { |
| |
| common.SafeSendBool(stopChan, true) |
|
|
| ticker.Stop() |
| if pingTicker != nil { |
| pingTicker.Stop() |
| } |
|
|
| |
| done := make(chan struct{}) |
| gopool.Go(func() { |
| wg.Wait() |
| close(done) |
| }) |
|
|
| select { |
| case <-done: |
| case <-time.After(5 * time.Second): |
| logger.LogError(c, "timeout waiting for goroutines to exit") |
| } |
|
|
| close(stopChan) |
| }() |
|
|
| scanner.Buffer(make([]byte, InitialScannerBufferSize), getScannerBufferSize()) |
| scanner.Split(bufio.ScanLines) |
| SetEventStreamHeaders(c) |
|
|
| ctx, cancel := context.WithCancel(context.Background()) |
| defer cancel() |
|
|
| ctx = context.WithValue(ctx, "stop_chan", stopChan) |
|
|
| |
| if pingEnabled && pingTicker != nil { |
| wg.Add(1) |
| gopool.Go(func() { |
| defer func() { |
| wg.Done() |
| if r := recover(); r != nil { |
| logger.LogError(c, fmt.Sprintf("ping goroutine panic: %v", r)) |
| common.SafeSendBool(stopChan, true) |
| } |
| if common.DebugEnabled { |
| println("ping goroutine exited") |
| } |
| }() |
|
|
| |
| maxPingDuration := 30 * time.Minute |
| pingTimeout := time.NewTimer(maxPingDuration) |
| defer pingTimeout.Stop() |
|
|
| for { |
| select { |
| case <-pingTicker.C: |
| |
| done := make(chan error, 1) |
| gopool.Go(func() { |
| writeMutex.Lock() |
| defer writeMutex.Unlock() |
| done <- PingData(c) |
| }) |
|
|
| select { |
| case err := <-done: |
| if err != nil { |
| logger.LogError(c, "ping data error: "+err.Error()) |
| return |
| } |
| if common.DebugEnabled { |
| println("ping data sent") |
| } |
| case <-time.After(10 * time.Second): |
| logger.LogError(c, "ping data send timeout") |
| return |
| case <-ctx.Done(): |
| return |
| case <-stopChan: |
| return |
| } |
| case <-ctx.Done(): |
| return |
| case <-stopChan: |
| return |
| case <-c.Request.Context().Done(): |
| |
| return |
| case <-pingTimeout.C: |
| logger.LogError(c, "ping goroutine max duration reached") |
| return |
| } |
| } |
| }) |
| } |
|
|
| dataChan := make(chan string, 10) |
|
|
| wg.Add(1) |
| gopool.Go(func() { |
| defer func() { |
| wg.Done() |
| if r := recover(); r != nil { |
| logger.LogError(c, fmt.Sprintf("data handler goroutine panic: %v", r)) |
| } |
| common.SafeSendBool(stopChan, true) |
| }() |
| for data := range dataChan { |
| writeMutex.Lock() |
| success := dataHandler(data) |
| writeMutex.Unlock() |
| if !success { |
| return |
| } |
| } |
| }) |
|
|
| |
| wg.Add(1) |
| common.RelayCtxGo(ctx, func() { |
| defer func() { |
| close(dataChan) |
| wg.Done() |
| if r := recover(); r != nil { |
| logger.LogError(c, fmt.Sprintf("scanner goroutine panic: %v", r)) |
| } |
| common.SafeSendBool(stopChan, true) |
| if common.DebugEnabled { |
| println("scanner goroutine exited") |
| } |
| }() |
|
|
| for scanner.Scan() { |
| |
| select { |
| case <-stopChan: |
| return |
| case <-ctx.Done(): |
| return |
| case <-c.Request.Context().Done(): |
| return |
| default: |
| } |
|
|
| ticker.Reset(streamingTimeout) |
| data := scanner.Text() |
| if common.DebugEnabled { |
| println(data) |
| } |
|
|
| if len(data) < 6 { |
| continue |
| } |
| if data[:5] != "data:" && data[:6] != "[DONE]" { |
| continue |
| } |
| data = data[5:] |
| data = strings.TrimSpace(data) |
| if data == "" { |
| continue |
| } |
| if !strings.HasPrefix(data, "[DONE]") { |
| info.SetFirstResponseTime() |
| info.ReceivedResponseCount++ |
|
|
| select { |
| case dataChan <- data: |
| case <-ctx.Done(): |
| return |
| case <-stopChan: |
| return |
| } |
| } else { |
| |
| if common.DebugEnabled { |
| println("received [DONE], stopping scanner") |
| } |
| return |
| } |
| } |
|
|
| if err := scanner.Err(); err != nil { |
| if err != io.EOF { |
| logger.LogError(c, "scanner error: "+err.Error()) |
| } |
| } |
| }) |
|
|
| |
| select { |
| case <-ticker.C: |
| |
| logger.LogError(c, "streaming timeout") |
| case <-stopChan: |
| |
| logger.LogInfo(c, "streaming finished") |
| case <-c.Request.Context().Done(): |
| |
| logger.LogInfo(c, "client disconnected") |
| } |
| } |
|
|