| package controller |
|
|
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "net/http" |
| "net/url" |
| "strconv" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/model" |
| "github.com/QuantumNous/new-api/relay/channel/codex" |
| "github.com/QuantumNous/new-api/service" |
|
|
| "github.com/gin-contrib/sessions" |
| "github.com/gin-gonic/gin" |
| ) |
|
|
// codexOAuthCompleteRequest is the JSON body accepted by the OAuth completion
// handlers.
type codexOAuthCompleteRequest struct {
	// Input is the raw text pasted by the user; it is decoded from the
	// "input" JSON field and later handed to parseCodexAuthorizationInput.
	Input string `json:"input"`
}
|
|
// codexOAuthSessionKey builds the session key under which one piece of OAuth
// flow data is stored. The key has the form "codex_oauth_<field>_<channelID>",
// so values for different channel IDs do not collide within a session.
func codexOAuthSessionKey(channelID int, field string) string {
	const keyLayout = "codex_oauth_%s_%d"
	return fmt.Sprintf(keyLayout, field, channelID)
}
|
|
// parseCodexAuthorizationInput extracts the OAuth authorization code and state
// from the text a user pasted back after authorizing.
//
// Accepted forms, tried in this order:
//   - a callback URL or a bare query string that carries a "code" parameter,
//     e.g. "http://host/cb?code=C&state=S" or "code=C&state=S";
//   - "CODE#STATE";
//   - a bare authorization code, which is returned with an empty state.
//
// Surrounding whitespace is trimmed from the input and from every extracted
// value. The only error returned is for input that is empty after trimming;
// callers are expected to validate that code and state are non-empty.
func parseCodexAuthorizationInput(input string) (code string, state string, err error) {
	v := strings.TrimSpace(input)
	if v == "" {
		return "", "", errors.New("empty input")
	}

	// Query-style input is examined before the "#" form so that a callback URL
	// ending in a fragment is not mistaken for "CODE#STATE".
	if strings.Contains(v, "code=") {
		if c, s, ok := codexCodeStateFromQuery(v); ok {
			return c, s, nil
		}
	}

	if i := strings.Index(v, "#"); i >= 0 {
		return strings.TrimSpace(v[:i]), strings.TrimSpace(v[i+1:]), nil
	}

	return v, "", nil
}

// codexCodeStateFromQuery reports the trimmed "code" and "state" query
// parameters found in v, treating v first as a URL and then as a bare query
// string. ok is false when neither interpretation contains a "code" key.
func codexCodeStateFromQuery(v string) (code string, state string, ok bool) {
	pick := func(q url.Values) bool {
		if _, present := q["code"]; !present {
			return false
		}
		code = strings.TrimSpace(q.Get("code"))
		state = strings.TrimSpace(q.Get("state"))
		return true
	}

	// url.Parse accepts almost any string: a bare "code=C&state=S" parses as a
	// path with no query. A successful parse without a "code" key therefore
	// must not end the search.
	if u, parseErr := url.Parse(v); parseErr == nil && pick(u.Query()) {
		return code, state, true
	}

	// Bare query string; anything after "#" is a fragment, not a parameter.
	raw := v
	if i := strings.Index(raw, "#"); i >= 0 {
		raw = raw[:i]
	}
	if q, parseErr := url.ParseQuery(raw); parseErr == nil && pick(q) {
		return code, state, true
	}
	return "", "", false
}
|
|
| func StartCodexOAuth(c *gin.Context) { |
| startCodexOAuthWithChannelID(c, 0) |
| } |
|
|
| func StartCodexOAuthForChannel(c *gin.Context) { |
| channelID, err := strconv.Atoi(c.Param("id")) |
| if err != nil { |
| common.ApiError(c, fmt.Errorf("invalid channel id: %w", err)) |
| return |
| } |
| startCodexOAuthWithChannelID(c, channelID) |
| } |
|
|
| func startCodexOAuthWithChannelID(c *gin.Context, channelID int) { |
| if channelID > 0 { |
| ch, err := model.GetChannelById(channelID, false) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| if ch == nil { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"}) |
| return |
| } |
| if ch.Type != constant.ChannelTypeCodex { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"}) |
| return |
| } |
| } |
|
|
| flow, err := service.CreateCodexOAuthAuthorizationFlow() |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| session := sessions.Default(c) |
| session.Set(codexOAuthSessionKey(channelID, "state"), flow.State) |
| session.Set(codexOAuthSessionKey(channelID, "verifier"), flow.Verifier) |
| session.Set(codexOAuthSessionKey(channelID, "created_at"), time.Now().Unix()) |
| _ = session.Save() |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": gin.H{ |
| "authorize_url": flow.AuthorizeURL, |
| }, |
| }) |
| } |
|
|
| func CompleteCodexOAuth(c *gin.Context) { |
| completeCodexOAuthWithChannelID(c, 0) |
| } |
|
|
| func CompleteCodexOAuthForChannel(c *gin.Context) { |
| channelID, err := strconv.Atoi(c.Param("id")) |
| if err != nil { |
| common.ApiError(c, fmt.Errorf("invalid channel id: %w", err)) |
| return |
| } |
| completeCodexOAuthWithChannelID(c, channelID) |
| } |
|
|
| func completeCodexOAuthWithChannelID(c *gin.Context, channelID int) { |
| req := codexOAuthCompleteRequest{} |
| if err := c.ShouldBindJSON(&req); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| code, state, err := parseCodexAuthorizationInput(req.Input) |
| if err != nil { |
| common.SysError("failed to parse codex authorization input: " + err.Error()) |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "解析授权信息失败,请检查输入格式"}) |
| return |
| } |
| if strings.TrimSpace(code) == "" { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing authorization code"}) |
| return |
| } |
| if strings.TrimSpace(state) == "" { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "missing state in input"}) |
| return |
| } |
|
|
| channelProxy := "" |
| if channelID > 0 { |
| ch, err := model.GetChannelById(channelID, false) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| if ch == nil { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel not found"}) |
| return |
| } |
| if ch.Type != constant.ChannelTypeCodex { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "channel type is not Codex"}) |
| return |
| } |
| channelProxy = ch.GetSetting().Proxy |
| } |
|
|
| session := sessions.Default(c) |
| expectedState, _ := session.Get(codexOAuthSessionKey(channelID, "state")).(string) |
| verifier, _ := session.Get(codexOAuthSessionKey(channelID, "verifier")).(string) |
| if strings.TrimSpace(expectedState) == "" || strings.TrimSpace(verifier) == "" { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "oauth flow not started or session expired"}) |
| return |
| } |
| if state != expectedState { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "state mismatch"}) |
| return |
| } |
|
|
| ctx, cancel := context.WithTimeout(c.Request.Context(), 15*time.Second) |
| defer cancel() |
|
|
| tokenRes, err := service.ExchangeCodexAuthorizationCodeWithProxy(ctx, code, verifier, channelProxy) |
| if err != nil { |
| common.SysError("failed to exchange codex authorization code: " + err.Error()) |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "授权码交换失败,请重试"}) |
| return |
| } |
|
|
| accountID, ok := service.ExtractCodexAccountIDFromJWT(tokenRes.AccessToken) |
| if !ok { |
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "failed to extract account_id from access_token"}) |
| return |
| } |
| email, _ := service.ExtractEmailFromJWT(tokenRes.AccessToken) |
|
|
| key := codex.OAuthKey{ |
| AccessToken: tokenRes.AccessToken, |
| RefreshToken: tokenRes.RefreshToken, |
| AccountID: accountID, |
| LastRefresh: time.Now().Format(time.RFC3339), |
| Expired: tokenRes.ExpiresAt.Format(time.RFC3339), |
| Email: email, |
| Type: "codex", |
| } |
| encoded, err := common.Marshal(key) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| session.Delete(codexOAuthSessionKey(channelID, "state")) |
| session.Delete(codexOAuthSessionKey(channelID, "verifier")) |
| session.Delete(codexOAuthSessionKey(channelID, "created_at")) |
| _ = session.Save() |
|
|
| if channelID > 0 { |
| if err := model.DB.Model(&model.Channel{}).Where("id = ?", channelID).Update("key", string(encoded)).Error; err != nil { |
| common.ApiError(c, err) |
| return |
| } |
| model.InitChannelCache() |
| service.ResetProxyClientCache() |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "saved", |
| "data": gin.H{ |
| "channel_id": channelID, |
| "account_id": accountID, |
| "email": email, |
| "expires_at": key.Expired, |
| "last_refresh": key.LastRefresh, |
| }, |
| }) |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "generated", |
| "data": gin.H{ |
| "key": string(encoded), |
| "account_id": accountID, |
| "email": email, |
| "expires_at": key.Expired, |
| "last_refresh": key.LastRefresh, |
| }, |
| }) |
| } |
|
|