File size: 5,483 Bytes
daa8246 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | package model
import (
"errors"
"time"
"gorm.io/gorm"
)
// UserOAuthBinding records which account at a custom OAuth provider is linked
// to which local user.
//
// The gorm tags declare two composite unique indexes:
//   - ux_user_provider (UserId, ProviderId): one binding per user per provider.
//   - ux_provider_userid (ProviderId, ProviderUserId): one binding per
//     provider-side account.
type UserOAuthBinding struct {
	// Id is the primary key of the binding row.
	Id int `json:"id" gorm:"primaryKey"`
	// UserId is the ID of the user that owns the binding.
	UserId int `json:"user_id" gorm:"not null;uniqueIndex:ux_user_provider"`
	// ProviderId is the custom OAuth provider ID; it participates in both unique indexes.
	ProviderId int `json:"provider_id" gorm:"not null;uniqueIndex:ux_user_provider;uniqueIndex:ux_provider_userid"`
	// ProviderUserId is the user ID as reported by the OAuth provider.
	ProviderUserId string `json:"provider_user_id" gorm:"type:varchar(256);not null;uniqueIndex:ux_provider_userid"`
	// CreatedAt is the creation timestamp of the binding.
	CreatedAt time.Time `json:"created_at"`
}
// TableName returns the name of the database table that holds
// UserOAuthBinding rows.
func (UserOAuthBinding) TableName() string {
	const table = "user_oauth_bindings"
	return table
}
// GetUserOAuthBindingsByUserId loads every OAuth binding whose user_id equals
// userId. The slice is returned together with whatever error the query
// reported.
func GetUserOAuthBindingsByUserId(userId int) ([]*UserOAuthBinding, error) {
	var found []*UserOAuthBinding
	query := DB.Where("user_id = ?", userId).Find(&found)
	return found, query.Error
}
// GetUserOAuthBinding fetches the first binding matching the given user and
// provider. When the lookup reports an error, the binding is nil and that
// error is returned unchanged.
func GetUserOAuthBinding(userId, providerId int) (*UserOAuthBinding, error) {
	found := &UserOAuthBinding{}
	if err := DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetUserByOAuthBinding resolves a provider-side account to a user.
//
// It first looks up the binding for (providerId, providerUserId) and then
// loads the User whose primary key is the binding's UserId. An error from
// either step is returned as-is with a nil user.
func GetUserByOAuthBinding(providerId int, providerUserId string) (*User, error) {
	link := new(UserOAuthBinding)
	lookup := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId).First(link)
	if lookup.Error != nil {
		return nil, lookup.Error
	}

	owner := new(User)
	if err := DB.First(owner, link.UserId).Error; err != nil {
		return nil, err
	}
	return owner, nil
}
// IsProviderUserIdTaken reports whether at least one binding exists for the
// given provider and provider user ID.
//
// The error of the count query is not inspected, so a failed query leaves the
// counter at zero and the function reports false.
func IsProviderUserIdTaken(providerId int, providerUserId string) bool {
	var matches int64
	query := DB.Model(&UserOAuthBinding{}).
		Where("provider_id = ? AND provider_user_id = ?", providerId, providerUserId)
	query.Count(&matches)
	return matches >= 1
}
// CreateUserOAuthBinding validates and inserts a new OAuth binding.
//
// binding must be non-nil and carry a non-zero UserId, a non-zero ProviderId
// and a non-empty ProviderUserId; otherwise a descriptive error is returned
// and nothing is written. Before inserting, the function counts existing
// bindings for the same (ProviderId, ProviderUserId) pair and refuses to
// create a duplicate. An error from that count query is returned to the
// caller instead of being treated as "not taken". On success CreatedAt is set
// to the current time and the row is inserted.
func CreateUserOAuthBinding(binding *UserOAuthBinding) error {
	switch {
	case binding == nil:
		return errors.New("binding is required")
	case binding.UserId == 0:
		return errors.New("user ID is required")
	case binding.ProviderId == 0:
		return errors.New("provider ID is required")
	case binding.ProviderUserId == "":
		return errors.New("provider user ID is required")
	}

	// Run the uniqueness pre-check inline so a failed query surfaces as an
	// error rather than silently passing the check.
	var count int64
	err := DB.Model(&UserOAuthBinding{}).
		Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).
		Count(&count).Error
	if err != nil {
		return err
	}
	if count > 0 {
		return errors.New("this OAuth account is already bound to another user")
	}

	binding.CreatedAt = time.Now()
	return DB.Create(binding).Error
}
// CreateUserOAuthBindingWithTx validates and inserts a new OAuth binding using
// the supplied transaction handle.
//
// binding must be non-nil and carry a non-zero UserId, a non-zero ProviderId
// and a non-empty ProviderUserId; otherwise a descriptive error is returned
// and nothing is written. The duplicate check for (ProviderId,
// ProviderUserId) and the insert are both issued through tx. An error from
// the count query is returned to the caller instead of being ignored. On
// success CreatedAt is set to the current time and the row is inserted.
func CreateUserOAuthBindingWithTx(tx *gorm.DB, binding *UserOAuthBinding) error {
	switch {
	case binding == nil:
		return errors.New("binding is required")
	case binding.UserId == 0:
		return errors.New("user ID is required")
	case binding.ProviderId == 0:
		return errors.New("provider ID is required")
	case binding.ProviderUserId == "":
		return errors.New("provider user ID is required")
	}

	// Check whether this provider user ID is already taken, using tx so the
	// check runs on the same handle as the insert.
	var count int64
	err := tx.Model(&UserOAuthBinding{}).
		Where("provider_id = ? AND provider_user_id = ?", binding.ProviderId, binding.ProviderUserId).
		Count(&count).Error
	if err != nil {
		return err
	}
	if count > 0 {
		return errors.New("this OAuth account is already bound to another user")
	}

	binding.CreatedAt = time.Now()
	return tx.Create(binding).Error
}
// UpdateUserOAuthBinding points the user's binding for providerId at
// newProviderUserId (e.g. rebinding to a different OAuth account), creating
// the binding if the user has none for that provider yet.
//
// It returns an error when newProviderUserId is empty, when that provider
// account is already bound to a different user, or when any database
// operation fails. Only gorm.ErrRecordNotFound is interpreted as "no such
// row"; every other lookup error is returned to the caller rather than being
// mistaken for a missing record.
func UpdateUserOAuthBinding(userId, providerId int, newProviderUserId string) error {
	// The create path rejects an empty provider user ID; enforce the same rule
	// here so the update path cannot store one.
	if newProviderUserId == "" {
		return errors.New("provider user ID is required")
	}

	// Check whether the new provider user ID is already taken by another user.
	var existing UserOAuthBinding
	err := DB.Where("provider_id = ? AND provider_user_id = ?", providerId, newProviderUserId).First(&existing).Error
	switch {
	case err == nil:
		if existing.UserId != userId {
			return errors.New("this OAuth account is already bound to another user")
		}
	case !errors.Is(err, gorm.ErrRecordNotFound):
		// A real query failure: the ownership check could not be performed.
		return err
	}

	// Check whether the user already has a binding for this provider.
	var binding UserOAuthBinding
	err = DB.Where("user_id = ? AND provider_id = ?", userId, providerId).First(&binding).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		// No existing binding, create a new one.
		return CreateUserOAuthBinding(&UserOAuthBinding{
			UserId:         userId,
			ProviderId:     providerId,
			ProviderUserId: newProviderUserId,
		})
	}
	if err != nil {
		return err
	}

	// Update the existing binding in place.
	return DB.Model(&binding).Update("provider_user_id", newProviderUserId).Error
}
// DeleteUserOAuthBinding issues a delete for the binding rows matching the
// given user and provider and returns the error reported by that operation.
func DeleteUserOAuthBinding(userId, providerId int) error {
	scoped := DB.Where("user_id = ? AND provider_id = ?", userId, providerId)
	result := scoped.Delete(&UserOAuthBinding{})
	return result.Error
}
// DeleteUserOAuthBindingsByUserId issues a delete for every binding row whose
// user_id equals userId and returns the error reported by that operation.
func DeleteUserOAuthBindingsByUserId(userId int) error {
	scoped := DB.Where("user_id = ?", userId)
	result := scoped.Delete(&UserOAuthBinding{})
	return result.Error
}
// GetBindingCountByProviderId counts the bindings whose provider_id equals
// providerId. The count is returned together with whatever error the query
// reported.
func GetBindingCountByProviderId(providerId int) (int64, error) {
	var total int64
	result := DB.Model(&UserOAuthBinding{}).Where("provider_id = ?", providerId).Count(&total)
	return total, result.Error
}
|