File size: 3,527 Bytes
package oauth
import (
"fmt"
"sync"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/model"
)
// Registry state shared by the functions in this file.
var (
	// mu is locked by the functions in this file before they read or modify
	// providers or customProviderSlugs.
	mu sync.RWMutex

	// providers maps a provider name (slug) to its registered implementation.
	providers = map[string]Provider{}

	// customProviderSlugs marks which names in providers were registered as
	// custom providers.
	customProviderSlugs = map[string]bool{}
)
// Register stores provider in the registry under name, replacing any provider
// previously stored under that name. The entry is not marked as custom.
func Register(name string, provider Provider) {
	mu.Lock()
	providers[name] = provider
	mu.Unlock()
}
// RegisterCustom stores provider in the registry under name and flags that
// name as custom, so IsCustomProvider reports true for it.
func RegisterCustom(name string, provider Provider) {
	mu.Lock()
	defer mu.Unlock()

	// Both maps are updated under the same write lock so they stay in sync.
	customProviderSlugs[name] = true
	providers[name] = provider
}
// Unregister drops name from the registry and clears its custom flag.
// It does not check whether name was registered as custom; an unknown name
// is a no-op.
func Unregister(name string) {
	mu.Lock()
	delete(customProviderSlugs, name)
	delete(providers, name)
	mu.Unlock()
}
// GetProvider looks up the provider registered under name. When no provider
// is registered under that name the zero value of Provider is returned.
func GetProvider(name string) Provider {
	mu.RLock()
	found := providers[name]
	mu.RUnlock()
	return found
}
// GetAllProviders returns a copy of the registry keyed by provider name.
// The returned map is a fresh allocation, so callers may modify it without
// affecting the registry.
func GetAllProviders() map[string]Provider {
	mu.RLock()
	defer mu.RUnlock()

	snapshot := make(map[string]Provider, len(providers))
	for slug, p := range providers {
		snapshot[slug] = p
	}
	return snapshot
}
// GetEnabledCustomProviders collects every registered provider that is
// flagged as custom, is a *GenericOAuthProvider, and reports IsEnabled.
// The result is nil when nothing matches; its order follows map iteration
// and is therefore not stable between calls.
func GetEnabledCustomProviders() []*GenericOAuthProvider {
	mu.RLock()
	defer mu.RUnlock()

	var enabled []*GenericOAuthProvider
	for slug, p := range providers {
		if !customProviderSlugs[slug] {
			continue
		}
		generic, isGeneric := p.(*GenericOAuthProvider)
		if !isGeneric || !generic.IsEnabled() {
			continue
		}
		enabled = append(enabled, generic)
	}
	return enabled
}
// IsProviderRegistered reports whether the registry has an entry for name,
// custom or not.
func IsProviderRegistered(name string) bool {
	mu.RLock()
	_, registered := providers[name]
	mu.RUnlock()
	return registered
}
// IsCustomProvider reports whether name is currently flagged as a custom
// provider. Names that were never registered yield false.
func IsCustomProvider(name string) bool {
	mu.RLock()
	custom := customProviderSlugs[name]
	mu.RUnlock()
	return custom
}
// LoadCustomProviders loads all custom OAuth providers from the database
func LoadCustomProviders() error {
// First, unregister all existing custom providers
mu.Lock()
for name := range customProviderSlugs {
delete(providers, name)
}
customProviderSlugs = make(map[string]bool)
mu.Unlock()
// Load all custom providers from database
customProviders, err := model.GetAllCustomOAuthProviders()
if err != nil {
common.SysError("Failed to load custom OAuth providers: " + err.Error())
return err
}
// Register each custom provider
for _, config := range customProviders {
provider := NewGenericOAuthProvider(config)
RegisterCustom(config.Slug, provider)
common.SysLog("Loaded custom OAuth provider: " + config.Name + " (" + config.Slug + ")")
}
common.SysLog(fmt.Sprintf("Loaded %d custom OAuth providers", len(customProviders)))
return nil
}
// ReloadCustomProviders re-runs LoadCustomProviders and passes its error,
// if any, back to the caller unchanged.
func ReloadCustomProviders() error {
	if err := LoadCustomProviders(); err != nil {
		return err
	}
	return nil
}
// RegisterOrUpdateCustomProvider builds a generic OAuth provider from config
// and stores it under config.Slug as a custom provider, replacing whatever
// was registered under that slug before.
func RegisterOrUpdateCustomProvider(config *model.CustomOAuthProvider) {
	// Construct the provider first, before the registry lock is taken.
	generic := NewGenericOAuthProvider(config)
	RegisterCustom(config.Slug, generic)
}
// UnregisterCustomProvider removes the provider registered under slug by
// delegating to Unregister.
// NOTE(review): Unregister does not check that slug is flagged as custom, so
// passing the name of a non-custom provider removes that provider as well.
func UnregisterCustomProvider(slug string) {
Unregister(slug)
}
|