| package model |
|
|
| import ( |
| "errors" |
| "fmt" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| ) |
|
|
type (
	// accessPolicyPayload is the decoded form of CustomOAuthProvider.AccessPolicy.
	// Logic names how Conditions and nested Groups are combined;
	// validateAccessPolicyPayload accepts "and", "or" or an empty string
	// (which it treats as "and"). Groups nest recursively.
	accessPolicyPayload struct {
		Logic      string                `json:"logic"`
		Conditions []accessConditionItem `json:"conditions"`
		Groups     []accessPolicyPayload `json:"groups"`
	}

	// accessConditionItem is one field/op/value predicate of a policy. Value is
	// kept as `any` because its required shape depends on Op: "in" and
	// "not_in" require a JSON array.
	accessConditionItem struct {
		Field string `json:"field"`
		Op    string `json:"op"`
		Value any    `json:"value"`
	}
)

// supportedAccessPolicyOps is the set of operator names a condition's "op" may
// use. validateAccessPolicyPayload lower-cases and trims an op before checking
// it against this set.
var supportedAccessPolicyOps = map[string]struct{}{
	"eq": {}, "ne": {},
	"gt": {}, "gte": {}, "lt": {}, "lte": {},
	// These two also require the condition value to be an array.
	"in": {}, "not_in": {},
	"contains": {}, "not_contains": {},
	"exists": {}, "not_exists": {},
}
|
|
| |
// CustomOAuthProvider is an administrator-configured OAuth 2.0 identity
// provider, persisted in the custom_oauth_providers table (see TableName).
// Its column defaults (scopes "openid profile email", user-id field "sub")
// follow OpenID Connect conventions.
type CustomOAuthProvider struct {
	Id   int    `json:"id" gorm:"primaryKey"`
	Name string `json:"name" gorm:"type:varchar(64);not null"`
	// Slug is the provider's unique identifier (unique index).
	// validateCustomOAuthProvider lower-cases it and allows only a-z, 0-9 and '-'.
	Slug    string `json:"slug" gorm:"type:varchar(64);uniqueIndex;not null"`
	Icon    string `json:"icon" gorm:"type:varchar(128);default:''"`
	Enabled bool   `json:"enabled" gorm:"default:false"`
	// ClientId / ClientSecret are the OAuth client credentials. The secret is
	// tagged json:"-", so it is never emitted in or read from JSON payloads.
	ClientId     string `json:"client_id" gorm:"type:varchar(256)"`
	ClientSecret string `json:"-" gorm:"type:varchar(512)"`
	// Endpoint URLs; all three are required by validateCustomOAuthProvider.
	AuthorizationEndpoint string `json:"authorization_endpoint" gorm:"type:varchar(512)"`
	TokenEndpoint         string `json:"token_endpoint" gorm:"type:varchar(512)"`
	UserInfoEndpoint      string `json:"user_info_endpoint" gorm:"type:varchar(512)"`
	// Scopes defaults to "openid profile email" (column default, and also
	// applied by validateCustomOAuthProvider when empty).
	Scopes string `json:"scopes" gorm:"type:varchar(256);default:'openid profile email'"`

	// Field-name mappings. Presumably these are the keys read from the
	// user-info response to build the local identity; confirm against the
	// OAuth handler. Empty values are replaced with the column defaults by
	// validateCustomOAuthProvider.
	UserIdField      string `json:"user_id_field" gorm:"type:varchar(128);default:'sub'"`
	UsernameField    string `json:"username_field" gorm:"type:varchar(128);default:'preferred_username'"`
	DisplayNameField string `json:"display_name_field" gorm:"type:varchar(128);default:'name'"`
	EmailField       string `json:"email_field" gorm:"type:varchar(128);default:'email'"`

	// WellKnown presumably holds a discovery-document URL. TODO confirm; it is
	// not validated in this file.
	WellKnown string `json:"well_known" gorm:"type:varchar(512)"`
	// AuthStyle is an integer code (default 0). NOTE(review): its meaning is
	// defined by the OAuth client code, not visible here.
	AuthStyle int `json:"auth_style" gorm:"default:0"`
	// AccessPolicy is optional JSON text. When non-blank,
	// validateCustomOAuthProvider requires it to decode into
	// accessPolicyPayload and pass validateAccessPolicyPayload.
	AccessPolicy string `json:"access_policy" gorm:"type:text"`
	// AccessDeniedMessage is presumably shown to users rejected by
	// AccessPolicy. Verify against the login handler.
	AccessDeniedMessage string `json:"access_denied_message" gorm:"type:varchar(512)"`

	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
|
|
| func (CustomOAuthProvider) TableName() string { |
| return "custom_oauth_providers" |
| } |
|
|
| |
| func GetAllCustomOAuthProviders() ([]*CustomOAuthProvider, error) { |
| var providers []*CustomOAuthProvider |
| err := DB.Order("id asc").Find(&providers).Error |
| return providers, err |
| } |
|
|
| |
| func GetEnabledCustomOAuthProviders() ([]*CustomOAuthProvider, error) { |
| var providers []*CustomOAuthProvider |
| err := DB.Where("enabled = ?", true).Order("id asc").Find(&providers).Error |
| return providers, err |
| } |
|
|
| |
| func GetCustomOAuthProviderById(id int) (*CustomOAuthProvider, error) { |
| var provider CustomOAuthProvider |
| err := DB.First(&provider, id).Error |
| if err != nil { |
| return nil, err |
| } |
| return &provider, nil |
| } |
|
|
| |
| func GetCustomOAuthProviderBySlug(slug string) (*CustomOAuthProvider, error) { |
| var provider CustomOAuthProvider |
| err := DB.Where("slug = ?", slug).First(&provider).Error |
| if err != nil { |
| return nil, err |
| } |
| return &provider, nil |
| } |
|
|
| |
| func CreateCustomOAuthProvider(provider *CustomOAuthProvider) error { |
| if err := validateCustomOAuthProvider(provider); err != nil { |
| return err |
| } |
| return DB.Create(provider).Error |
| } |
|
|
| |
| func UpdateCustomOAuthProvider(provider *CustomOAuthProvider) error { |
| if err := validateCustomOAuthProvider(provider); err != nil { |
| return err |
| } |
| return DB.Save(provider).Error |
| } |
|
|
| |
| func DeleteCustomOAuthProvider(id int) error { |
| |
| if err := DB.Where("provider_id = ?", id).Delete(&UserOAuthBinding{}).Error; err != nil { |
| return err |
| } |
| return DB.Delete(&CustomOAuthProvider{}, id).Error |
| } |
|
|
| |
| |
| func IsSlugTaken(slug string, excludeId int) bool { |
| var count int64 |
| query := DB.Model(&CustomOAuthProvider{}).Where("slug = ?", slug) |
| if excludeId > 0 { |
| query = query.Where("id != ?", excludeId) |
| } |
| res := query.Count(&count) |
| if res.Error != nil { |
| |
| return true |
| } |
| return count > 0 |
| } |
|
|
| |
| func validateCustomOAuthProvider(provider *CustomOAuthProvider) error { |
| if provider.Name == "" { |
| return errors.New("provider name is required") |
| } |
| if provider.Slug == "" { |
| return errors.New("provider slug is required") |
| } |
| |
| slug := strings.ToLower(provider.Slug) |
| for _, c := range slug { |
| if !((c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-') { |
| return errors.New("provider slug must contain only lowercase letters, numbers, and hyphens") |
| } |
| } |
| provider.Slug = slug |
|
|
| if provider.ClientId == "" { |
| return errors.New("client ID is required") |
| } |
| if provider.AuthorizationEndpoint == "" { |
| return errors.New("authorization endpoint is required") |
| } |
| if provider.TokenEndpoint == "" { |
| return errors.New("token endpoint is required") |
| } |
| if provider.UserInfoEndpoint == "" { |
| return errors.New("user info endpoint is required") |
| } |
|
|
| |
| if provider.UserIdField == "" { |
| provider.UserIdField = "sub" |
| } |
| if provider.UsernameField == "" { |
| provider.UsernameField = "preferred_username" |
| } |
| if provider.DisplayNameField == "" { |
| provider.DisplayNameField = "name" |
| } |
| if provider.EmailField == "" { |
| provider.EmailField = "email" |
| } |
| if provider.Scopes == "" { |
| provider.Scopes = "openid profile email" |
| } |
| if strings.TrimSpace(provider.AccessPolicy) != "" { |
| var policy accessPolicyPayload |
| if err := common.UnmarshalJsonStr(provider.AccessPolicy, &policy); err != nil { |
| return errors.New("access_policy must be valid JSON") |
| } |
| if err := validateAccessPolicyPayload(&policy); err != nil { |
| return fmt.Errorf("access_policy is invalid: %w", err) |
| } |
| } |
|
|
| return nil |
| } |
|
|
| func validateAccessPolicyPayload(policy *accessPolicyPayload) error { |
| if policy == nil { |
| return errors.New("policy is nil") |
| } |
|
|
| logic := strings.ToLower(strings.TrimSpace(policy.Logic)) |
| if logic == "" { |
| logic = "and" |
| } |
| if logic != "and" && logic != "or" { |
| return fmt.Errorf("unsupported logic: %s", logic) |
| } |
|
|
| if len(policy.Conditions) == 0 && len(policy.Groups) == 0 { |
| return errors.New("policy requires at least one condition or group") |
| } |
|
|
| for index, condition := range policy.Conditions { |
| field := strings.TrimSpace(condition.Field) |
| if field == "" { |
| return fmt.Errorf("condition[%d].field is required", index) |
| } |
| op := strings.ToLower(strings.TrimSpace(condition.Op)) |
| if _, ok := supportedAccessPolicyOps[op]; !ok { |
| return fmt.Errorf("condition[%d].op is unsupported: %s", index, op) |
| } |
| if op == "in" || op == "not_in" { |
| if _, ok := condition.Value.([]any); !ok { |
| return fmt.Errorf("condition[%d].value must be an array for op %s", index, op) |
| } |
| } |
| } |
|
|
| for index := range policy.Groups { |
| if err := validateAccessPolicyPayload(&policy.Groups[index]); err != nil { |
| return fmt.Errorf("group[%d]: %w", index, err) |
| } |
| } |
|
|
| return nil |
| } |
|
|