| package oauth |
|
|
import (
	"fmt"
	"sort"
	"sync"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/model"
)
|
|
| var ( |
| providers = make(map[string]Provider) |
| mu sync.RWMutex |
| |
| customProviderSlugs = make(map[string]bool) |
| ) |
|
|
| |
| func Register(name string, provider Provider) { |
| mu.Lock() |
| defer mu.Unlock() |
| providers[name] = provider |
| } |
|
|
| |
| func RegisterCustom(name string, provider Provider) { |
| mu.Lock() |
| defer mu.Unlock() |
| providers[name] = provider |
| customProviderSlugs[name] = true |
| } |
|
|
| |
| func Unregister(name string) { |
| mu.Lock() |
| defer mu.Unlock() |
| delete(providers, name) |
| delete(customProviderSlugs, name) |
| } |
|
|
| |
| func GetProvider(name string) Provider { |
| mu.RLock() |
| defer mu.RUnlock() |
| return providers[name] |
| } |
|
|
| |
| func GetAllProviders() map[string]Provider { |
| mu.RLock() |
| defer mu.RUnlock() |
| result := make(map[string]Provider, len(providers)) |
| for k, v := range providers { |
| result[k] = v |
| } |
| return result |
| } |
|
|
| |
| func GetEnabledCustomProviders() []*GenericOAuthProvider { |
| mu.RLock() |
| defer mu.RUnlock() |
| var result []*GenericOAuthProvider |
| for name, provider := range providers { |
| if customProviderSlugs[name] { |
| if gp, ok := provider.(*GenericOAuthProvider); ok && gp.IsEnabled() { |
| result = append(result, gp) |
| } |
| } |
| } |
| return result |
| } |
|
|
| |
| func IsProviderRegistered(name string) bool { |
| mu.RLock() |
| defer mu.RUnlock() |
| _, ok := providers[name] |
| return ok |
| } |
|
|
| |
| func IsCustomProvider(name string) bool { |
| mu.RLock() |
| defer mu.RUnlock() |
| return customProviderSlugs[name] |
| } |
|
|
| |
| func LoadCustomProviders() error { |
| |
| mu.Lock() |
| for name := range customProviderSlugs { |
| delete(providers, name) |
| } |
| customProviderSlugs = make(map[string]bool) |
| mu.Unlock() |
|
|
| |
| customProviders, err := model.GetAllCustomOAuthProviders() |
| if err != nil { |
| common.SysError("Failed to load custom OAuth providers: " + err.Error()) |
| return err |
| } |
|
|
| |
| for _, config := range customProviders { |
| provider := NewGenericOAuthProvider(config) |
| RegisterCustom(config.Slug, provider) |
| common.SysLog("Loaded custom OAuth provider: " + config.Name + " (" + config.Slug + ")") |
| } |
|
|
| common.SysLog(fmt.Sprintf("Loaded %d custom OAuth providers", len(customProviders))) |
| return nil |
| } |
|
|
| |
| func ReloadCustomProviders() error { |
| return LoadCustomProviders() |
| } |
|
|
| |
| func RegisterOrUpdateCustomProvider(config *model.CustomOAuthProvider) { |
| provider := NewGenericOAuthProvider(config) |
| mu.Lock() |
| defer mu.Unlock() |
| providers[config.Slug] = provider |
| customProviderSlugs[config.Slug] = true |
| } |
|
|
| |
| func UnregisterCustomProvider(slug string) { |
| Unregister(slug) |
| } |
|
|