File size: 8,981 Bytes
daa8246 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 | package model
import (
"errors"
"fmt"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
)
// accessPolicyPayload is the decoded form of a provider's access_policy JSON.
// A payload may nest further payloads through Groups, forming a tree.
type accessPolicyPayload struct {
	// Logic names the combinator for this node. Validation accepts "and" or
	// "or" (case-insensitive, surrounding whitespace ignored) and treats an
	// empty value as "and".
	Logic string `json:"logic"`

	// Conditions are the leaf checks of this node. Validation requires at
	// least one condition or one group per node.
	Conditions []accessConditionItem `json:"conditions"`

	// Groups are nested sub-policies; each one is validated recursively.
	Groups []accessPolicyPayload `json:"groups"`
}
// accessConditionItem is a single leaf check inside an access policy.
type accessConditionItem struct {
	// Field identifies what the condition inspects; validation rejects a
	// blank value.
	Field string `json:"field"`

	// Op is the operator name; validation requires it to be one of the keys
	// of supportedAccessPolicyOps after trimming and lowercasing.
	Op string `json:"op"`

	// Value is the operand. Validation only constrains it for the "in" and
	// "not_in" operators, where it must decode to an array.
	Value any `json:"value"`
}
// supportedAccessPolicyOps is the set of operator names an access-policy
// condition may use. Only the names are registered here; how each operator is
// evaluated is not part of this declaration.
var supportedAccessPolicyOps = map[string]struct{}{
	"eq": {}, "ne": {},
	"gt": {}, "gte": {},
	"lt": {}, "lte": {},
	"in": {}, "not_in": {},
	"contains": {}, "not_contains": {},
	"exists": {}, "not_exists": {},
}
// CustomOAuthProvider is the persisted configuration of one custom OAuth
// provider: its identity, OAuth endpoints, user-info field mappings and
// advanced options.
type CustomOAuthProvider struct {
	// Id is the primary key.
	Id int `json:"id" gorm:"primaryKey"`

	// Name is the display name, e.g. "GitHub Enterprise".
	Name string `json:"name" gorm:"type:varchar(64);not null"`

	// Slug is the URL identifier, e.g. "github-enterprise"; it carries a
	// unique index.
	Slug string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"`

	// Icon is an icon name from @lobehub/icons.
	Icon string `json:"icon" gorm:"type:varchar(128);default:''"`

	// Enabled reports whether this provider is enabled.
	Enabled bool `json:"enabled" gorm:"default:false"`

	// ClientId is the OAuth client ID.
	ClientId string `json:"client_id" gorm:"type:varchar(256)"`

	// ClientSecret is the OAuth client secret. The json:"-" tag keeps it out
	// of JSON output so it is not returned to the frontend.
	ClientSecret string `json:"-" gorm:"type:varchar(512)"`

	// AuthorizationEndpoint is the authorization URL.
	AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"`

	// TokenEndpoint is the token exchange URL.
	TokenEndpoint string `json:"token_endpoint" gorm:"type:varchar(512)"`

	// UserInfoEndpoint is the user info URL.
	UserInfoEndpoint string `json:"user_info_endpoint" gorm:"type:varchar(512)"`

	// Scopes lists the OAuth scopes to request.
	Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"`

	// Field mapping configuration (supports JSONPath via gjson).

	// UserIdField is the user ID field path, e.g. "sub", "id", "data.user.id".
	UserIdField string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"`

	// UsernameField is the username field path.
	UsernameField string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"`

	// DisplayNameField is the display name field path.
	DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"`

	// EmailField is the email field path.
	EmailField string `json:"email_field" gorm:"type:varchar(128);default:'email'"`

	// Advanced options.

	// WellKnown is the optional OIDC discovery endpoint.
	WellKnown string `json:"well_known" gorm:"type:varchar(512)"`

	// AuthStyle selects how client credentials are sent:
	// 0=auto, 1=params, 2=header (Basic Auth).
	AuthStyle int `json:"auth_style" gorm:"default:0"`

	// AccessPolicy is a JSON policy for access control based on user info.
	AccessPolicy string `json:"access_policy" gorm:"type:text"`

	// AccessDeniedMessage is a custom error message template used when
	// access is denied.
	AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"`

	CreatedAt time.Time `json:"created_at"`

	UpdatedAt time.Time `json:"updated_at"`
}
// TableName returns the database table name used for CustomOAuthProvider
// rows. NOTE(review): presumably picked up by GORM's table-name hook — confirm.
func (CustomOAuthProvider) TableName() string {
	const table = "custom_oauth_providers"
	return table
}
// GetAllCustomOAuthProviders loads every custom OAuth provider, ordered by
// ascending id. The slice and the query error are returned together.
func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
	var all []*CustomOAuthProvider
	result := DB.Order("id asc").Find(&all)
	return all, result.Error
}
// GetEnabledCustomOAuthProviders loads the providers whose enabled column is
// true, ordered by ascending id. The slice and the query error are returned
// together.
func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) {
	var enabled []*CustomOAuthProvider
	query := DB.Where("enabled = ?", true).Order("id asc")
	result := query.Find(&enabled)
	return enabled, result.Error
}
// GetCustomOAuthProviderById looks up a single provider by primary key.
// On any query error it returns nil together with that error.
func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) {
	found := new(CustomOAuthProvider)
	if err := DB.First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetCustomOAuthProviderBySlug looks up a single provider whose slug column
// equals slug. On any query error it returns nil together with that error.
func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) {
	found := new(CustomOAuthProvider)
	bySlug := DB.Where("slug = ?", slug)
	if err := bySlug.First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// CreateCustomOAuthProvider validates provider (which also normalizes it in
// place) and, if validation passes, inserts it. A validation error is
// returned as-is without touching the database.
func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error {
	invalid := validateCustomOAuthProvider(provider)
	if invalid != nil {
		return invalid
	}
	created := DB.Create(provider)
	return created.Error
}
// UpdateCustomOAuthProvider validates provider (which also normalizes it in
// place) and, if validation passes, persists it with DB.Save. A validation
// error is returned as-is without touching the database.
func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error {
	invalid := validateCustomOAuthProvider(provider)
	if invalid != nil {
		return invalid
	}
	saved := DB.Save(provider)
	return saved.Error
}
// DeleteCustomOAuthProvider deletes the provider with the given id together
// with every UserOAuthBinding whose provider_id matches it.
//
// Both deletes run inside one transaction: if removing the provider fails,
// the binding deletions are rolled back instead of being left behind as a
// partial result. The first error encountered is returned.
func DeleteCustomOAuthProvider(id int) error {
	tx := DB.Begin()
	if tx.Error != nil {
		return tx.Error
	}

	// Remove the user bindings first so no binding outlives its provider.
	if err := tx.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil {
		// The delete error is the one worth reporting; a rollback failure
		// would only mask it.
		tx.Rollback()
		return err
	}

	if err := tx.Delete(&CustomOAuthProvider{}, id).Error; err != nil {
		tx.Rollback()
		return err
	}

	return tx.Commit().Error
}
// IsSlugTaken reports whether some provider already uses slug. When excludeId
// is positive, the provider with that id is left out of the check, so a
// provider does not conflict with itself.
//
// The check fails closed: if the count query errors, the slug is reported as
// taken so that a database problem cannot let a conflicting slug through.
func IsSlugTaken(slug string, excludeId int) bool {
	lookup := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug)
	if excludeId > 0 {
		lookup = lookup.Where("id != ?", excludeId)
	}

	var matches int64
	if err := lookup.Count(&matches).Error; err != nil {
		// Fail-closed on DB errors.
		return true
	}
	return matches > 0
}
// validateCustomOAuthProvider checks a provider configuration and normalizes
// it in place.
//
// It returns an error when provider is nil, or when the name, slug, client ID,
// authorization endpoint, token endpoint or user-info endpoint is empty. The
// slug is lowercased and must then consist only of a-z, 0-9 and '-'. Empty
// field mappings and scopes are filled with their defaults. A non-blank
// AccessPolicy must parse as JSON and pass validateAccessPolicyPayload.
//
// Normalization happens as the checks progress, so a provider that fails a
// later check may already carry its lowercased slug.
func validateCustomOAuthProvider(provider *CustomOAuthProvider) error {
	// Report a nil provider as an error instead of panicking on the first
	// field access.
	if provider == nil {
		return errors.New("provider is nil")
	}

	if provider.Name == "" {
		return errors.New("provider name is required")
	}
	if provider.Slug == "" {
		return errors.New("provider slug is required")
	}

	// The slug is normalized to lowercase, then restricted to alphanumeric
	// characters and hyphens.
	slug := strings.ToLower(provider.Slug)
	for _, c := range slug {
		if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') {
			return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens")
		}
	}
	provider.Slug = slug

	if provider.ClientId == "" {
		return errors.New("client ID is required")
	}
	if provider.AuthorizationEndpoint == "" {
		return errors.New("authorization endpoint is required")
	}
	if provider.TokenEndpoint == "" {
		return errors.New("token endpoint is required")
	}
	if provider.UserInfoEndpoint == "" {
		return errors.New("user info endpoint is required")
	}

	// Fill in defaults for field mappings and scopes left empty.
	if provider.UserIdField == "" {
		provider.UserIdField = "sub"
	}
	if provider.UsernameField == "" {
		provider.UsernameField = "preferred_username"
	}
	if provider.DisplayNameField == "" {
		provider.DisplayNameField = "name"
	}
	if provider.EmailField == "" {
		provider.EmailField = "email"
	}
	if provider.Scopes == "" {
		provider.Scopes = "openid profile email"
	}

	// A blank (empty or whitespace-only) policy means "no policy" and is
	// accepted without further checks.
	if strings.TrimSpace(provider.AccessPolicy) != "" {
		var policy accessPolicyPayload
		if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil {
			return errors.New("access_policy must be valid JSON")
		}
		if err := validateAccessPolicyPayload(&policy); err != nil {
			return fmt.Errorf("access_policy is invalid: %w", err)
		}
	}
	return nil
}
// validateAccessPolicyPayload checks one access-policy node and, recursively,
// all of its nested groups.
//
// A node is valid when it is non-nil, its logic is "and", "or" or empty
// (compared after trimming and lowercasing), it has at least one condition or
// group, every condition has a non-blank field and a supported operator, and
// the "in"/"not_in" operators carry an array value. Errors from nested groups
// are wrapped with the index of the offending group.
func validateAccessPolicyPayload(policy *accessPolicyPayload) error {
	if policy == nil {
		return errors.New("policy is nil")
	}

	// An empty logic is accepted and stands for "and".
	switch mode := strings.ToLower(strings.TrimSpace(policy.Logic)); mode {
	case "", "and", "or":
	default:
		return fmt.Errorf("unsupported logic: %s", mode)
	}

	if len(policy.Conditions)+len(policy.Groups) == 0 {
		return errors.New("policy requires at least one condition or group")
	}

	for i := range policy.Conditions {
		item := &policy.Conditions[i]
		if strings.TrimSpace(item.Field) == "" {
			return fmt.Errorf("condition[%d].field is required", i)
		}

		operator := strings.ToLower(strings.TrimSpace(item.Op))
		if _, known := supportedAccessPolicyOps[operator]; !known {
			return fmt.Errorf("condition[%d].op is unsupported: %s", i, operator)
		}

		// Membership operators need a list to test against.
		switch operator {
		case "in", "not_in":
			if _, isList := item.Value.([]any); !isList {
				return fmt.Errorf("condition[%d].value must be an array for op %s", i, operator)
			}
		}
	}

	for i := range policy.Groups {
		if err := validateAccessPolicyPayload(&policy.Groups[i]); err != nil {
			return fmt.Errorf("group[%d]: %w", i, err)
		}
	}
	return nil
}
|