| package common |
|
|
| import ( |
| "bytes" |
| "fmt" |
| "io" |
| "mime" |
| "mime/multipart" |
| "net/http" |
| "net/url" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/constant" |
| "github.com/pkg/errors" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
// Context keys used to cache the request body on a gin.Context.
const (
	// KeyRequestBody, when present, is read by GetRequestBody as raw body bytes ([]byte).
	KeyRequestBody = "key_request_body"
	// KeyBodyStorage holds the BodyStorage that GetRequestBody creates and reuses.
	KeyBodyStorage = "key_body_storage"
)

// ErrRequestBodyTooLarge is the sentinel wrapped by GetRequestBody when the
// request body exceeds the configured size cap.
var ErrRequestBodyTooLarge = errors.New("request body too large")

// IsRequestBodyTooLargeError reports whether err, or any error it wraps, is
// ErrRequestBodyTooLarge or an *http.MaxBytesError (the error produced by
// http.MaxBytesReader). A nil err is never "too large".
func IsRequestBodyTooLargeError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, ErrRequestBodyTooLarge):
		return true
	default:
		var maxBytesErr *http.MaxBytesError
		return errors.As(err, &maxBytesErr)
	}
}
|
|
| func GetRequestBody(c *gin.Context) (io.Seeker, error) { |
| |
| if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil { |
| if bs, ok := storage.(BodyStorage); ok { |
| if _, err := bs.Seek(0, io.SeekStart); err != nil { |
| return nil, fmt.Errorf("failed to seek body storage: %w", err) |
| } |
| return bs, nil |
| } |
| } |
|
|
| |
| cached, exists := c.Get(KeyRequestBody) |
| if exists && cached != nil { |
| if b, ok := cached.([]byte); ok { |
| bs, err := CreateBodyStorage(b) |
| if err != nil { |
| return nil, err |
| } |
| c.Set(KeyBodyStorage, bs) |
| return bs, nil |
| } |
| } |
|
|
| maxMB := constant.MaxRequestBodyMB |
| if maxMB <= 0 { |
| maxMB = 128 |
| } |
| maxBytes := int64(maxMB) << 20 |
|
|
| contentLength := c.Request.ContentLength |
|
|
| |
| storage, err := CreateBodyStorageFromReader(c.Request.Body, contentLength, maxBytes) |
| _ = c.Request.Body.Close() |
|
|
| if err != nil { |
| if IsRequestBodyTooLargeError(err) { |
| return nil, errors.Wrap(ErrRequestBodyTooLarge, fmt.Sprintf("request body exceeds %d MB", maxMB)) |
| } |
| return nil, err |
| } |
|
|
| |
| c.Set(KeyBodyStorage, storage) |
|
|
| return storage, nil |
| } |
|
|
| |
| func GetBodyStorage(c *gin.Context) (BodyStorage, error) { |
| seeker, err := GetRequestBody(c) |
| if err != nil { |
| return nil, err |
| } |
| bs, ok := seeker.(BodyStorage) |
| if !ok { |
| return nil, errors.New("unexpected body storage type") |
| } |
| return bs, nil |
| } |
|
|
| |
| func CleanupBodyStorage(c *gin.Context) { |
| if storage, exists := c.Get(KeyBodyStorage); exists && storage != nil { |
| if bs, ok := storage.(BodyStorage); ok { |
| bs.Close() |
| } |
| c.Set(KeyBodyStorage, nil) |
| } |
| } |
|
|
| func UnmarshalBodyReusable(c *gin.Context, v any) error { |
| storage, err := GetBodyStorage(c) |
| if err != nil { |
| return err |
| } |
| requestBody, err := storage.Bytes() |
| if err != nil { |
| return err |
| } |
| contentType := c.Request.Header.Get("Content-Type") |
| if strings.HasPrefix(contentType, "application/json") { |
| err = Unmarshal(requestBody, v) |
| } else if strings.Contains(contentType, gin.MIMEPOSTForm) { |
| err = parseFormData(requestBody, v) |
| } else if strings.Contains(contentType, gin.MIMEMultipartPOSTForm) { |
| err = parseMultipartFormData(c, requestBody, v) |
| } else { |
| |
| |
| } |
| if err != nil { |
| return err |
| } |
| |
| if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil { |
| return seekErr |
| } |
| c.Request.Body = io.NopCloser(storage) |
| return nil |
| } |
|
|
| func SetContextKey(c *gin.Context, key constant.ContextKey, value any) { |
| c.Set(string(key), value) |
| } |
|
|
| func GetContextKey(c *gin.Context, key constant.ContextKey) (any, bool) { |
| return c.Get(string(key)) |
| } |
|
|
| func GetContextKeyString(c *gin.Context, key constant.ContextKey) string { |
| return c.GetString(string(key)) |
| } |
|
|
| func GetContextKeyInt(c *gin.Context, key constant.ContextKey) int { |
| return c.GetInt(string(key)) |
| } |
|
|
| func GetContextKeyBool(c *gin.Context, key constant.ContextKey) bool { |
| return c.GetBool(string(key)) |
| } |
|
|
| func GetContextKeyStringSlice(c *gin.Context, key constant.ContextKey) []string { |
| return c.GetStringSlice(string(key)) |
| } |
|
|
| func GetContextKeyStringMap(c *gin.Context, key constant.ContextKey) map[string]any { |
| return c.GetStringMap(string(key)) |
| } |
|
|
| func GetContextKeyTime(c *gin.Context, key constant.ContextKey) time.Time { |
| return c.GetTime(string(key)) |
| } |
|
|
| func GetContextKeyType[T any](c *gin.Context, key constant.ContextKey) (T, bool) { |
| if value, ok := c.Get(string(key)); ok { |
| if v, ok := value.(T); ok { |
| return v, true |
| } |
| } |
| var t T |
| return t, false |
| } |
|
|
| func ApiError(c *gin.Context, err error) { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": err.Error(), |
| }) |
| } |
|
|
| func ApiErrorMsg(c *gin.Context, msg string) { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": msg, |
| }) |
| } |
|
|
| func ApiSuccess(c *gin.Context, data any) { |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": "", |
| "data": data, |
| }) |
| } |
|
|
| |
| |
| func ApiErrorI18n(c *gin.Context, key string, args ...map[string]any) { |
| msg := TranslateMessage(c, key, args...) |
| c.JSON(http.StatusOK, gin.H{ |
| "success": false, |
| "message": msg, |
| }) |
| } |
|
|
| |
| func ApiSuccessI18n(c *gin.Context, key string, data any, args ...map[string]any) { |
| msg := TranslateMessage(c, key, args...) |
| c.JSON(http.StatusOK, gin.H{ |
| "success": true, |
| "message": msg, |
| "data": data, |
| }) |
| } |
|
|
| |
| |
| |
| var TranslateMessage func(c *gin.Context, key string, args ...map[string]any) string |
|
|
| func init() { |
| |
| |
| TranslateMessage = func(c *gin.Context, key string, args ...map[string]any) string { |
| return key |
| } |
| } |
|
|
| func ParseMultipartFormReusable(c *gin.Context) (*multipart.Form, error) { |
| storage, err := GetBodyStorage(c) |
| if err != nil { |
| return nil, err |
| } |
| requestBody, err := storage.Bytes() |
| if err != nil { |
| return nil, err |
| } |
|
|
| |
| |
| var contentType string |
| if saved, ok := c.Get("_original_multipart_ct"); ok { |
| contentType = saved.(string) |
| } else { |
| contentType = c.Request.Header.Get("Content-Type") |
| c.Set("_original_multipart_ct", contentType) |
| } |
| boundary, err := parseBoundary(contentType) |
| if err != nil { |
| return nil, err |
| } |
|
|
| reader := multipart.NewReader(bytes.NewReader(requestBody), boundary) |
| form, err := reader.ReadForm(multipartMemoryLimit()) |
| if err != nil { |
| return nil, err |
| } |
|
|
| |
| if _, seekErr := storage.Seek(0, io.SeekStart); seekErr != nil { |
| return nil, seekErr |
| } |
| c.Request.Body = io.NopCloser(storage) |
| return form, nil |
| } |
|
|
| func processFormMap(formMap map[string]any, v any) error { |
| jsonData, err := Marshal(formMap) |
| if err != nil { |
| return err |
| } |
|
|
| err = Unmarshal(jsonData, v) |
| if err != nil { |
| return err |
| } |
|
|
| return nil |
| } |
|
|
| func parseFormData(data []byte, v any) error { |
| values, err := url.ParseQuery(string(data)) |
| if err != nil { |
| return err |
| } |
| formMap := make(map[string]any) |
| for key, vals := range values { |
| if len(vals) == 1 { |
| formMap[key] = vals[0] |
| } else { |
| formMap[key] = vals |
| } |
| } |
|
|
| return processFormMap(formMap, v) |
| } |
|
|
| func parseMultipartFormData(c *gin.Context, data []byte, v any) error { |
| var contentType string |
| if saved, ok := c.Get("_original_multipart_ct"); ok { |
| contentType = saved.(string) |
| } else { |
| contentType = c.Request.Header.Get("Content-Type") |
| c.Set("_original_multipart_ct", contentType) |
| } |
| boundary, err := parseBoundary(contentType) |
| if err != nil { |
| if errors.Is(err, errBoundaryNotFound) { |
| return Unmarshal(data, v) |
| } |
| return err |
| } |
|
|
| reader := multipart.NewReader(bytes.NewReader(data), boundary) |
| form, err := reader.ReadForm(multipartMemoryLimit()) |
| if err != nil { |
| return err |
| } |
| defer form.RemoveAll() |
| formMap := make(map[string]any) |
| for key, vals := range form.Value { |
| if len(vals) == 1 { |
| formMap[key] = vals[0] |
| } else { |
| formMap[key] = vals |
| } |
| } |
|
|
| return processFormMap(formMap, v) |
| } |
|
|
// errBoundaryNotFound reports that a Content-Type was empty or carried no
// non-empty "boundary" parameter.
var errBoundaryNotFound = errors.New("multipart boundary not found")

// parseBoundary extracts the multipart "boundary" parameter from a
// Content-Type value.
//
// It returns errBoundaryNotFound when contentType is empty or has no
// non-empty boundary. It returns mime.ParseMediaType's error when the value
// is malformed.
func parseBoundary(contentType string) (string, error) {
	if len(contentType) == 0 {
		return "", errBoundaryNotFound
	}

	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return "", err
	}
	if b := params["boundary"]; b != "" {
		return b, nil
	}
	return "", errBoundaryNotFound
}
|
|
| |
| func multipartMemoryLimit() int64 { |
| limitMB := constant.MaxFileDownloadMB |
| if limitMB <= 0 { |
| limitMB = 32 |
| } |
| return int64(limitMB) << 20 |
| } |
|
|