File size: 2,854 Bytes
daa8246 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | package service
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/model"
)
// CodexCredentialRefreshOptions controls optional side effects of
// RefreshCodexChannelCredential.
type CodexCredentialRefreshOptions struct {
	// ResetCaches, when true, makes RefreshCodexChannelCredential call
	// model.InitChannelCache and ResetProxyClientCache after the refreshed
	// key has been written to the database.
	ResetCaches bool
}
// CodexOAuthKey is the JSON document stored as a Codex channel's key.
//
// RefreshCodexChannelCredential decodes it from the channel's Key field,
// updates the token-related fields, and writes the re-encoded JSON back to
// the channel. Every field carries `omitempty`, so blank fields are left out
// of the encoded document.
type CodexOAuthKey struct {
	// IDToken is not modified by RefreshCodexChannelCredential.
	IDToken      string `json:"id_token,omitempty"`
	// AccessToken is replaced with the token returned by a refresh. Its JWT
	// claims are used to fill AccountID and Email when those are blank.
	AccessToken  string `json:"access_token,omitempty"`
	// RefreshToken must be non-blank for RefreshCodexChannelCredential to
	// proceed. It is sent to the token refresh call and updated from its result.
	RefreshToken string `json:"refresh_token,omitempty"`
	// AccountID is filled from the access token's JWT only when blank.
	AccountID    string `json:"account_id,omitempty"`
	// LastRefresh is the local time of the most recent refresh, formatted as
	// RFC 3339.
	LastRefresh  string `json:"last_refresh,omitempty"`
	// Email is filled from the access token's JWT only when blank.
	Email        string `json:"email,omitempty"`
	// Type is set to "codex" during a refresh when it is blank.
	Type         string `json:"type,omitempty"`
	// Expired holds the expiry time reported by the token refresh, formatted
	// as RFC 3339. Despite its name, it is a timestamp, not a boolean flag.
	Expired      string `json:"expired,omitempty"`
}
// parseCodexOAuthKey decodes a Codex channel key, stored as JSON, into a
// CodexOAuthKey.
//
// Input that is empty or whitespace-only is rejected before any decoding is
// attempted. A decode failure is reported with a fixed message and the
// decoder's own error is dropped. NOTE(review): presumably this is so that
// fragments of the secret key never appear in error text; confirm before
// propagating the underlying error.
func parseCodexOAuthKey(raw string) (*CodexOAuthKey, error) {
	if len(strings.TrimSpace(raw)) == 0 {
		return nil, errors.New("codex channel: empty oauth key")
	}

	decoded := new(CodexOAuthKey)
	if common.Unmarshal([]byte(raw), decoded) != nil {
		return nil, errors.New("codex channel: invalid oauth key json")
	}
	return decoded, nil
}
// RefreshCodexChannelCredential refreshes the OAuth credential stored in the
// key of Codex channel channelID and persists the result.
//
// It performs these steps:
//   - load the channel, including its key, and require it to be a Codex channel;
//   - decode the key as a CodexOAuthKey and require a non-blank refresh_token;
//   - call RefreshCodexOAuthTokenWithProxy through the channel's configured
//     proxy, bounded by a 10-second timeout derived from ctx;
//   - update the token fields, fill Type/AccountID/Email when blank, and
//     write the re-encoded key to the channel's "key" column;
//   - if opts.ResetCaches is set, rebuild the channel cache and reset the
//     proxy client cache.
//
// On success it returns the updated key and the channel. The channel's Key
// field is set to the value just written. On failure both pointers are nil
// and the error describes the failing step.
func RefreshCodexChannelCredential(ctx context.Context, channelID int, opts CodexCredentialRefreshOptions) (*CodexOAuthKey, *model.Channel, error) {
	ch, err := model.GetChannelById(channelID, true)
	if err != nil {
		return nil, nil, err
	}
	if ch == nil {
		return nil, nil, errors.New("channel not found")
	}
	if ch.Type != constant.ChannelTypeCodex {
		return nil, nil, errors.New("channel type is not Codex")
	}

	oauthKey, err := parseCodexOAuthKey(strings.TrimSpace(ch.Key))
	if err != nil {
		return nil, nil, err
	}
	if strings.TrimSpace(oauthKey.RefreshToken) == "" {
		return nil, nil, errors.New("codex channel: refresh_token is required to refresh credential")
	}

	refreshCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()
	res, err := RefreshCodexOAuthTokenWithProxy(refreshCtx, oauthKey.RefreshToken, ch.GetSetting().Proxy)
	if err != nil {
		return nil, nil, err
	}

	oauthKey.AccessToken = res.AccessToken
	// A token endpoint is not required to rotate the refresh token
	// (RFC 6749 §6). Overwriting it with an empty value would persist a key
	// that can never be refreshed again, so the current token is kept unless
	// a new one was issued.
	if strings.TrimSpace(res.RefreshToken) != "" {
		oauthKey.RefreshToken = res.RefreshToken
	}
	oauthKey.LastRefresh = time.Now().Format(time.RFC3339)
	oauthKey.Expired = res.ExpiresAt.Format(time.RFC3339)

	// Fill metadata fields that are blank; existing values are kept as-is.
	if strings.TrimSpace(oauthKey.Type) == "" {
		oauthKey.Type = "codex"
	}
	if strings.TrimSpace(oauthKey.AccountID) == "" {
		if accountID, ok := ExtractCodexAccountIDFromJWT(oauthKey.AccessToken); ok {
			oauthKey.AccountID = accountID
		}
	}
	if strings.TrimSpace(oauthKey.Email) == "" {
		if email, ok := ExtractEmailFromJWT(oauthKey.AccessToken); ok {
			oauthKey.Email = email
		}
	}

	encoded, err := common.Marshal(oauthKey)
	if err != nil {
		return nil, nil, fmt.Errorf("codex channel: encode refreshed key: %w", err)
	}
	encodedKey := string(encoded)
	// NOTE(review): if this write fails after the endpoint rotated the refresh
	// token, the newly issued token is lost and the stored one may no longer
	// be accepted.
	if err := model.DB.Model(&model.Channel{}).Where("id = ?", ch.Id).Update("key", encodedKey).Error; err != nil {
		return nil, nil, fmt.Errorf("codex channel: save refreshed key: %w", err)
	}
	// Keep the returned channel consistent with what was just persisted, so
	// callers do not see the pre-refresh key.
	ch.Key = encodedKey

	if opts.ResetCaches {
		model.InitChannelCache()
		ResetProxyClientCache()
	}
	return oauthKey, ch, nil
}
|