| package controller |
|
|
| import ( |
| "context" |
| "io" |
| "net/http" |
| "net/url" |
| "strconv" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/model" |
| "github.com/QuantumNous/new-api/oauth" |
| "github.com/gin-gonic/gin" |
| ) |
|
|
| |
| |
// CustomOAuthProviderResponse is the JSON representation of a custom OAuth
// provider returned by the management endpoints in this file. It carries the
// same configurable fields that CreateCustomOAuthProvider stores, except
// ClientSecret, which has no counterpart here and so is never serialized back
// to API clients.
type CustomOAuthProviderResponse struct {
	// Identity and presentation.
	Id      int    `json:"id"`
	Name    string `json:"name"`
	Slug    string `json:"slug"`
	Icon    string `json:"icon"`
	Enabled bool   `json:"enabled"`

	// OAuth client configuration (secret intentionally absent from this type).
	ClientId              string `json:"client_id"`
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`
	Scopes                string `json:"scopes"`

	// Field-name mappings; presumably the keys looked up in the user-info
	// payload — NOTE(review): confirm against the oauth package.
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	// Remaining settings, passed through unchanged from the model.
	WellKnown           string `json:"well_known"`
	AuthStyle           int    `json:"auth_style"`
	AccessPolicy        string `json:"access_policy"`
	AccessDeniedMessage string `json:"access_denied_message"`
}
|
|
// UserOAuthBindingResponse describes one link between a user and a custom
// OAuth provider: the provider's id, name, slug and icon, plus the user's
// identifier on that provider's side.
type UserOAuthBindingResponse struct {
	ProviderId     int    `json:"provider_id"`
	ProviderName   string `json:"provider_name"`
	ProviderSlug   string `json:"provider_slug"`
	ProviderIcon   string `json:"provider_icon"`
	ProviderUserId string `json:"provider_user_id"` // id as reported by the external provider
}
|
|
| func toCustomOAuthProviderResponse(p *model.CustomOAuthProvider) *CustomOAuthProviderResponse { |
| return &CustomOAuthProviderResponse{ |
| Id: p.Id, |
| Name: p.Name, |
| Slug: p.Slug, |
| Icon: p.Icon, |
| Enabled: p.Enabled, |
| ClientId: p.ClientId, |
| AuthorizationEndpoint: p.AuthorizationEndpoint, |
| TokenEndpoint: p.TokenEndpoint, |
| UserInfoEndpoint: p.UserInfoEndpoint, |
| Scopes: p.Scopes, |
| UserIdField: p.UserIdField, |
| UsernameField: p.UsernameField, |
| DisplayNameField: p.DisplayNameField, |
| EmailField: p.EmailField, |
| WellKnown: p.WellKnown, |
| AuthStyle: p.AuthStyle, |
| AccessPolicy: p.AccessPolicy, |
| AccessDeniedMessage: p.AccessDeniedMessage, |
| } |
| } |
|
|
| |
| func GetCustomOAuthProviders(c *gin.Context) { |
| providers, err := model.GetAllCustomOAuthProviders() |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| response := make([]*CustomOAuthProviderResponse, len(providers)) |
| for i, p := range providers { |
| response[i] = toCustomOAuthProviderResponse(p) |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": response, |
| }) |
| } |
|
|
| |
| func GetCustomOAuthProvider(c *gin.Context) { |
| idStr := c.Param("id") |
| id, err := strconv.Atoi(idStr) |
| if err != nil { |
| common.ApiErrorMsg(c, "无效的 ID") |
| return |
| } |
|
|
| provider, err := model.GetCustomOAuthProviderById(id) |
| if err != nil { |
| common.ApiErrorMsg(c, "未找到该 OAuth 提供商") |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": toCustomOAuthProviderResponse(provider), |
| }) |
| } |
|
|
| |
// CreateCustomOAuthProviderRequest is the JSON body accepted by
// CreateCustomOAuthProvider. Fields tagged binding:"required" make
// ShouldBindJSON fail when they are missing or empty. All other fields are
// optional and default to their zero value (so Enabled defaults to false).
type CreateCustomOAuthProviderRequest struct {
	// Identity and presentation.
	Name    string `json:"name" binding:"required"`
	Slug    string `json:"slug" binding:"required"`
	Icon    string `json:"icon"`
	Enabled bool   `json:"enabled"`

	// OAuth client credentials and endpoints.
	ClientId              string `json:"client_id" binding:"required"`
	ClientSecret          string `json:"client_secret" binding:"required"`
	AuthorizationEndpoint string `json:"authorization_endpoint" binding:"required"`
	TokenEndpoint         string `json:"token_endpoint" binding:"required"`
	UserInfoEndpoint      string `json:"user_info_endpoint" binding:"required"`
	Scopes                string `json:"scopes"`

	// Optional field-name mappings.
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	// Optional extra settings, stored as given.
	WellKnown           string `json:"well_known"`
	AuthStyle           int    `json:"auth_style"`
	AccessPolicy        string `json:"access_policy"`
	AccessDeniedMessage string `json:"access_denied_message"`
}
|
|
// FetchCustomOAuthDiscoveryRequest is the JSON body accepted by
// FetchCustomOAuthDiscovery. At least one field must be non-blank. When both
// are set, WellKnownURL takes precedence over IssuerURL.
type FetchCustomOAuthDiscoveryRequest struct {
	WellKnownURL string `json:"well_known_url"` // full discovery document URL
	IssuerURL    string `json:"issuer_url"`     // issuer base; the well-known path is appended
}
|
|
| |
| func FetchCustomOAuthDiscovery(c *gin.Context) { |
| var req FetchCustomOAuthDiscoveryRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| common.ApiErrorMsg(c, "无效的请求参数: "+err.Error()) |
| return |
| } |
|
|
| wellKnownURL := strings.TrimSpace(req.WellKnownURL) |
| issuerURL := strings.TrimSpace(req.IssuerURL) |
|
|
| if wellKnownURL == "" && issuerURL == "" { |
| common.ApiErrorMsg(c, "请先填写 Discovery URL 或 Issuer URL") |
| return |
| } |
|
|
| targetURL := wellKnownURL |
| if targetURL == "" { |
| targetURL = strings.TrimRight(issuerURL, "/") + "/.well-known/openid-configuration" |
| } |
| targetURL = strings.TrimSpace(targetURL) |
|
|
| parsedURL, err := url.Parse(targetURL) |
| if err != nil || parsedURL.Host == "" || (parsedURL.Scheme != "http" && parsedURL.Scheme != "https") { |
| common.ApiErrorMsg(c, "Discovery URL 无效,仅支持 http/https") |
| return |
| } |
|
|
| ctx, cancel := context.WithTimeout(c.Request.Context(), 20*time.Second) |
| defer cancel() |
|
|
| httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) |
| if err != nil { |
| common.ApiErrorMsg(c, "创建 Discovery 请求失败: "+err.Error()) |
| return |
| } |
| httpReq.Header.Set("Accept", "application/json") |
|
|
| client := &http.Client{Timeout: 20 * time.Second} |
| resp, err := client.Do(httpReq) |
| if err != nil { |
| common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+err.Error()) |
| return |
| } |
| defer resp.Body.Close() |
|
|
| if resp.StatusCode != http.StatusOK { |
| body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) |
| message := strings.TrimSpace(string(body)) |
| if message == "" { |
| message = resp.Status |
| } |
| common.ApiErrorMsg(c, "获取 Discovery 配置失败: "+message) |
| return |
| } |
|
|
| var discovery map[string]any |
| if err = common.DecodeJson(resp.Body, &discovery); err != nil { |
| common.ApiErrorMsg(c, "解析 Discovery 配置失败: "+err.Error()) |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": gin.H{ |
| "well_known_url": targetURL, |
| "discovery": discovery, |
| }, |
| }) |
| } |
|
|
| |
| func CreateCustomOAuthProvider(c *gin.Context) { |
| var req CreateCustomOAuthProviderRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| common.ApiErrorMsg(c, "无效的请求参数: "+err.Error()) |
| return |
| } |
|
|
| |
| if model.IsSlugTaken(req.Slug, 0) { |
| common.ApiErrorMsg(c, "该 Slug 已被使用") |
| return |
| } |
|
|
| |
| if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) { |
| common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突") |
| return |
| } |
|
|
| provider := &model.CustomOAuthProvider{ |
| Name: req.Name, |
| Slug: req.Slug, |
| Icon: req.Icon, |
| Enabled: req.Enabled, |
| ClientId: req.ClientId, |
| ClientSecret: req.ClientSecret, |
| AuthorizationEndpoint: req.AuthorizationEndpoint, |
| TokenEndpoint: req.TokenEndpoint, |
| UserInfoEndpoint: req.UserInfoEndpoint, |
| Scopes: req.Scopes, |
| UserIdField: req.UserIdField, |
| UsernameField: req.UsernameField, |
| DisplayNameField: req.DisplayNameField, |
| EmailField: req.EmailField, |
| WellKnown: req.WellKnown, |
| AuthStyle: req.AuthStyle, |
| AccessPolicy: req.AccessPolicy, |
| AccessDeniedMessage: req.AccessDeniedMessage, |
| } |
|
|
| if err := model.CreateCustomOAuthProvider(provider); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| |
| oauth.RegisterOrUpdateCustomProvider(provider) |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "创建成功", |
| "data": toCustomOAuthProviderResponse(provider), |
| }) |
| } |
|
|
| |
// UpdateCustomOAuthProviderRequest is the JSON body accepted by
// UpdateCustomOAuthProvider. Every field is optional.
//
// Plain string fields are applied only when non-empty, so an omitted or blank
// value keeps the stored one; such fields therefore cannot be cleared.
// Pointer fields are applied whenever the key is present, which allows them
// to be set to "" / false / 0. A nil pointer means "leave unchanged".
type UpdateCustomOAuthProviderRequest struct {
	// Identity and presentation.
	Name    string  `json:"name"`
	Slug    string  `json:"slug"`
	Icon    *string `json:"icon"`
	Enabled *bool   `json:"enabled"`

	// OAuth client credentials and endpoints ("" keeps the current value,
	// so an existing secret need not be re-sent).
	ClientId              string `json:"client_id"`
	ClientSecret          string `json:"client_secret"`
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	UserInfoEndpoint      string `json:"user_info_endpoint"`
	Scopes                string `json:"scopes"`

	// Field-name mappings ("" keeps the current value).
	UserIdField      string `json:"user_id_field"`
	UsernameField    string `json:"username_field"`
	DisplayNameField string `json:"display_name_field"`
	EmailField       string `json:"email_field"`

	// Extra settings that may be explicitly reset to their zero value.
	WellKnown           *string `json:"well_known"`
	AuthStyle           *int    `json:"auth_style"`
	AccessPolicy        *string `json:"access_policy"`
	AccessDeniedMessage *string `json:"access_denied_message"`
}
|
|
| |
| func UpdateCustomOAuthProvider(c *gin.Context) { |
| idStr := c.Param("id") |
| id, err := strconv.Atoi(idStr) |
| if err != nil { |
| common.ApiErrorMsg(c, "无效的 ID") |
| return |
| } |
|
|
| var req UpdateCustomOAuthProviderRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| common.ApiErrorMsg(c, "无效的请求参数: "+err.Error()) |
| return |
| } |
|
|
| |
| provider, err := model.GetCustomOAuthProviderById(id) |
| if err != nil { |
| common.ApiErrorMsg(c, "未找到该 OAuth 提供商") |
| return |
| } |
|
|
| oldSlug := provider.Slug |
|
|
| |
| if req.Slug != "" && req.Slug != provider.Slug { |
| if model.IsSlugTaken(req.Slug, id) { |
| common.ApiErrorMsg(c, "该 Slug 已被使用") |
| return |
| } |
| |
| if oauth.IsProviderRegistered(req.Slug) && !oauth.IsCustomProvider(req.Slug) { |
| common.ApiErrorMsg(c, "该 Slug 与内置 OAuth 提供商冲突") |
| return |
| } |
| } |
|
|
| |
| if req.Name != "" { |
| provider.Name = req.Name |
| } |
| if req.Slug != "" { |
| provider.Slug = req.Slug |
| } |
| if req.Icon != nil { |
| provider.Icon = *req.Icon |
| } |
| if req.Enabled != nil { |
| provider.Enabled = *req.Enabled |
| } |
| if req.ClientId != "" { |
| provider.ClientId = req.ClientId |
| } |
| if req.ClientSecret != "" { |
| provider.ClientSecret = req.ClientSecret |
| } |
| if req.AuthorizationEndpoint != "" { |
| provider.AuthorizationEndpoint = req.AuthorizationEndpoint |
| } |
| if req.TokenEndpoint != "" { |
| provider.TokenEndpoint = req.TokenEndpoint |
| } |
| if req.UserInfoEndpoint != "" { |
| provider.UserInfoEndpoint = req.UserInfoEndpoint |
| } |
| if req.Scopes != "" { |
| provider.Scopes = req.Scopes |
| } |
| if req.UserIdField != "" { |
| provider.UserIdField = req.UserIdField |
| } |
| if req.UsernameField != "" { |
| provider.UsernameField = req.UsernameField |
| } |
| if req.DisplayNameField != "" { |
| provider.DisplayNameField = req.DisplayNameField |
| } |
| if req.EmailField != "" { |
| provider.EmailField = req.EmailField |
| } |
| if req.WellKnown != nil { |
| provider.WellKnown = *req.WellKnown |
| } |
| if req.AuthStyle != nil { |
| provider.AuthStyle = *req.AuthStyle |
| } |
| if req.AccessPolicy != nil { |
| provider.AccessPolicy = *req.AccessPolicy |
| } |
| if req.AccessDeniedMessage != nil { |
| provider.AccessDeniedMessage = *req.AccessDeniedMessage |
| } |
|
|
| if err := model.UpdateCustomOAuthProvider(provider); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| |
| if oldSlug != provider.Slug { |
| oauth.UnregisterCustomProvider(oldSlug) |
| } |
| oauth.RegisterOrUpdateCustomProvider(provider) |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "更新成功", |
| "data": toCustomOAuthProviderResponse(provider), |
| }) |
| } |
|
|
| |
| func DeleteCustomOAuthProvider(c *gin.Context) { |
| idStr := c.Param("id") |
| id, err := strconv.Atoi(idStr) |
| if err != nil { |
| common.ApiErrorMsg(c, "无效的 ID") |
| return |
| } |
|
|
| |
| provider, err := model.GetCustomOAuthProviderById(id) |
| if err != nil { |
| common.ApiErrorMsg(c, "未找到该 OAuth 提供商") |
| return |
| } |
|
|
| |
| count, err := model.GetBindingCountByProviderId(id) |
| if err != nil { |
| common.SysError("Failed to get binding count for provider " + strconv.Itoa(id) + ": " + err.Error()) |
| common.ApiErrorMsg(c, "检查用户绑定时发生错误,请稍后重试") |
| return |
| } |
| if count > 0 { |
| common.ApiErrorMsg(c, "该 OAuth 提供商还有用户绑定,无法删除。请先解除所有用户绑定。") |
| return |
| } |
|
|
| if err := model.DeleteCustomOAuthProvider(id); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| |
| oauth.UnregisterCustomProvider(provider.Slug) |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "删除成功", |
| }) |
| } |
|
|
| func buildUserOAuthBindingsResponse(userId int) ([]UserOAuthBindingResponse, error) { |
| bindings, err := model.GetUserOAuthBindingsByUserId(userId) |
| if err != nil { |
| return nil, err |
| } |
|
|
| response := make([]UserOAuthBindingResponse, 0, len(bindings)) |
| for _, binding := range bindings { |
| provider, err := model.GetCustomOAuthProviderById(binding.ProviderId) |
| if err != nil { |
| continue |
| } |
| response = append(response, UserOAuthBindingResponse{ |
| ProviderId: binding.ProviderId, |
| ProviderName: provider.Name, |
| ProviderSlug: provider.Slug, |
| ProviderIcon: provider.Icon, |
| ProviderUserId: binding.ProviderUserId, |
| }) |
| } |
|
|
| return response, nil |
| } |
|
|
| |
| func GetUserOAuthBindings(c *gin.Context) { |
| userId := c.GetInt("id") |
| if userId == 0 { |
| common.ApiErrorMsg(c, "未登录") |
| return |
| } |
|
|
| response, err := buildUserOAuthBindingsResponse(userId) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": response, |
| }) |
| } |
|
|
| func GetUserOAuthBindingsByAdmin(c *gin.Context) { |
| userIdStr := c.Param("id") |
| userId, err := strconv.Atoi(userIdStr) |
| if err != nil { |
| common.ApiErrorMsg(c, "invalid user id") |
| return |
| } |
|
|
| targetUser, err := model.GetUserById(userId, false) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| myRole := c.GetInt("role") |
| if myRole <= targetUser.Role && myRole != common.RoleRootUser { |
| common.ApiErrorMsg(c, "no permission") |
| return |
| } |
|
|
| response, err := buildUserOAuthBindingsResponse(userId) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": response, |
| }) |
| } |
|
|
| |
| func UnbindCustomOAuth(c *gin.Context) { |
| userId := c.GetInt("id") |
| if userId == 0 { |
| common.ApiErrorMsg(c, "未登录") |
| return |
| } |
|
|
| providerIdStr := c.Param("provider_id") |
| providerId, err := strconv.Atoi(providerIdStr) |
| if err != nil { |
| common.ApiErrorMsg(c, "无效的提供商 ID") |
| return |
| } |
|
|
| if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "解绑成功", |
| }) |
| } |
|
|
| func UnbindCustomOAuthByAdmin(c *gin.Context) { |
| userIdStr := c.Param("id") |
| userId, err := strconv.Atoi(userIdStr) |
| if err != nil { |
| common.ApiErrorMsg(c, "invalid user id") |
| return |
| } |
|
|
| targetUser, err := model.GetUserById(userId, false) |
| if err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| myRole := c.GetInt("role") |
| if myRole <= targetUser.Role && myRole != common.RoleRootUser { |
| common.ApiErrorMsg(c, "no permission") |
| return |
| } |
|
|
| providerIdStr := c.Param("provider_id") |
| providerId, err := strconv.Atoi(providerIdStr) |
| if err != nil { |
| common.ApiErrorMsg(c, "invalid provider id") |
| return |
| } |
|
|
| if err := model.DeleteUserOAuthBinding(userId, providerId); err != nil { |
| common.ApiError(c, err) |
| return |
| } |
|
|
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "success", |
| }) |
| } |
|
|