| package service |
|
|
| import ( |
| "bytes" |
| "encoding/base64" |
| "fmt" |
| "image" |
| _ "image/gif" |
| _ "image/jpeg" |
| _ "image/png" |
| "io" |
| "net/http" |
| "strings" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/logger" |
| "github.com/QuantumNous/new-api/types" |
|
|
| "github.com/gin-gonic/gin" |
| "golang.org/x/image/webp" |
| ) |
|
|
| |
| |
|
|
| |
| func getContextCacheKey(url string) string { |
| return fmt.Sprintf("file_cache_%s", common.GenerateHMAC(url)) |
| } |
|
|
| |
| |
| func LoadFileSource(c *gin.Context, source *types.FileSource, reason ...string) (*types.CachedFileData, error) { |
| if source == nil { |
| return nil, fmt.Errorf("file source is nil") |
| } |
|
|
| if common.DebugEnabled { |
| logger.LogDebug(c, fmt.Sprintf("LoadFileSource starting for: %s", source.GetIdentifier())) |
| } |
|
|
| |
| if source.HasCache() { |
| |
| if c != nil { |
| registerSourceForCleanup(c, source) |
| } |
| return source.GetCache(), nil |
| } |
|
|
| |
| source.Mu().Lock() |
| defer source.Mu().Unlock() |
|
|
| |
| if source.HasCache() { |
| if c != nil { |
| registerSourceForCleanup(c, source) |
| } |
| return source.GetCache(), nil |
| } |
|
|
| |
| var contextKey string |
| if source.IsURL() && c != nil { |
| contextKey = getContextCacheKey(source.URL) |
| if cachedData, exists := c.Get(contextKey); exists { |
| data := cachedData.(*types.CachedFileData) |
| source.SetCache(data) |
| registerSourceForCleanup(c, source) |
| return data, nil |
| } |
| } |
|
|
| |
| var cachedData *types.CachedFileData |
| var err error |
|
|
| if source.IsURL() { |
| cachedData, err = loadFromURL(c, source.URL, reason...) |
| } else { |
| cachedData, err = loadFromBase64(source.Base64Data, source.MimeType) |
| } |
|
|
| if err != nil { |
| return nil, err |
| } |
|
|
| |
| source.SetCache(cachedData) |
| if contextKey != "" && c != nil { |
| c.Set(contextKey, cachedData) |
| } |
|
|
| |
| if c != nil { |
| registerSourceForCleanup(c, source) |
| } |
|
|
| return cachedData, nil |
| } |
|
|
| |
| func registerSourceForCleanup(c *gin.Context, source *types.FileSource) { |
| if source.IsRegistered() { |
| return |
| } |
|
|
| key := string(constant.ContextKeyFileSourcesToCleanup) |
| var sources []*types.FileSource |
| if existing, exists := c.Get(key); exists { |
| sources = existing.([]*types.FileSource) |
| } |
| sources = append(sources, source) |
| c.Set(key, sources) |
| source.SetRegistered(true) |
| } |
|
|
| |
| |
| func CleanupFileSources(c *gin.Context) { |
| key := string(constant.ContextKeyFileSourcesToCleanup) |
| if sources, exists := c.Get(key); exists { |
| for _, source := range sources.([]*types.FileSource) { |
| if cache := source.GetCache(); cache != nil { |
| cache.Close() |
| } |
| } |
| c.Set(key, nil) |
| } |
| } |
|
|
| |
| func loadFromURL(c *gin.Context, url string, reason ...string) (*types.CachedFileData, error) { |
| |
| var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024 |
|
|
| if common.DebugEnabled { |
| logger.LogDebug(c, "loadFromURL: initiating download") |
| } |
| resp, err := DoDownloadRequest(url, reason...) |
| if err != nil { |
| return nil, fmt.Errorf("failed to download file from %s: %w", url, err) |
| } |
| defer resp.Body.Close() |
|
|
| if resp.StatusCode != 200 { |
| return nil, fmt.Errorf("failed to download file, status code: %d", resp.StatusCode) |
| } |
|
|
| |
| if common.DebugEnabled { |
| logger.LogDebug(c, "loadFromURL: reading response body") |
| } |
| fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1))) |
| if err != nil { |
| return nil, fmt.Errorf("failed to read file content: %w", err) |
| } |
| if len(fileBytes) > maxFileSize { |
| return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB) |
| } |
|
|
| |
| base64Data := base64.StdEncoding.EncodeToString(fileBytes) |
|
|
| |
| mimeType := smartDetectMimeType(resp, url, fileBytes) |
|
|
| |
| base64Size := int64(len(base64Data)) |
| var cachedData *types.CachedFileData |
|
|
| if shouldUseDiskCache(base64Size) { |
| |
| diskPath, err := writeToDiskCache(base64Data) |
| if err != nil { |
| |
| logger.LogWarn(c, fmt.Sprintf("Failed to write to disk cache, falling back to memory: %v", err)) |
| cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes))) |
| } else { |
| cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(fileBytes))) |
| cachedData.DiskSize = base64Size |
| cachedData.OnClose = func(size int64) { |
| common.DecrementDiskFiles(size) |
| } |
| common.IncrementDiskFiles(base64Size) |
| if common.DebugEnabled { |
| logger.LogDebug(c, fmt.Sprintf("File cached to disk: %s, size: %d bytes", diskPath, base64Size)) |
| } |
| } |
| } else { |
| |
| cachedData = types.NewMemoryCachedData(base64Data, mimeType, int64(len(fileBytes))) |
| } |
|
|
| |
| if strings.HasPrefix(mimeType, "image/") { |
| if common.DebugEnabled { |
| logger.LogDebug(c, "loadFromURL: decoding image config") |
| } |
| config, format, err := decodeImageConfig(fileBytes) |
| if err == nil { |
| cachedData.ImageConfig = &config |
| cachedData.ImageFormat = format |
| |
| if mimeType == "application/octet-stream" || mimeType == "" { |
| cachedData.MimeType = "image/" + format |
| } |
| } |
| } |
|
|
| return cachedData, nil |
| } |
|
|
| |
| func shouldUseDiskCache(dataSize int64) bool { |
| return common.ShouldUseDiskCache(dataSize) |
| } |
|
|
| |
| func writeToDiskCache(base64Data string) (string, error) { |
| return common.WriteDiskCacheFileString(common.DiskCacheTypeFile, base64Data) |
| } |
|
|
| |
| func smartDetectMimeType(resp *http.Response, url string, fileBytes []byte) string { |
| |
| mimeType := resp.Header.Get("Content-Type") |
| if idx := strings.Index(mimeType, ";"); idx != -1 { |
| mimeType = strings.TrimSpace(mimeType[:idx]) |
| } |
| if mimeType != "" && mimeType != "application/octet-stream" { |
| return mimeType |
| } |
|
|
| |
| if cd := resp.Header.Get("Content-Disposition"); cd != "" { |
| parts := strings.Split(cd, ";") |
| for _, part := range parts { |
| part = strings.TrimSpace(part) |
| if strings.HasPrefix(strings.ToLower(part), "filename=") { |
| name := strings.TrimSpace(strings.TrimPrefix(part, "filename=")) |
| |
| if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' { |
| name = name[1 : len(name)-1] |
| } |
| if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) { |
| ext := strings.ToLower(name[dot+1:]) |
| if ext != "" { |
| mt := GetMimeTypeByExtension(ext) |
| if mt != "application/octet-stream" { |
| return mt |
| } |
| } |
| } |
| break |
| } |
| } |
| } |
|
|
| |
| mt := guessMimeTypeFromURL(url) |
| if mt != "application/octet-stream" { |
| return mt |
| } |
|
|
| |
| if len(fileBytes) > 0 { |
| sniffed := http.DetectContentType(fileBytes) |
| if sniffed != "" && sniffed != "application/octet-stream" { |
| |
| if idx := strings.Index(sniffed, ";"); idx != -1 { |
| sniffed = strings.TrimSpace(sniffed[:idx]) |
| } |
| return sniffed |
| } |
| } |
|
|
| |
| if len(fileBytes) > 0 { |
| if _, format, err := decodeImageConfig(fileBytes); err == nil && format != "" { |
| return "image/" + strings.ToLower(format) |
| } |
| } |
|
|
| |
| return "application/octet-stream" |
| } |
|
|
| |
| func loadFromBase64(base64String string, providedMimeType string) (*types.CachedFileData, error) { |
| var mimeType string |
| var cleanBase64 string |
|
|
| |
| if strings.HasPrefix(base64String, "data:") { |
| idx := strings.Index(base64String, ",") |
| if idx != -1 { |
| header := base64String[:idx] |
| cleanBase64 = base64String[idx+1:] |
|
|
| if strings.Contains(header, ":") && strings.Contains(header, ";") { |
| mimeStart := strings.Index(header, ":") + 1 |
| mimeEnd := strings.Index(header, ";") |
| if mimeStart < mimeEnd { |
| mimeType = header[mimeStart:mimeEnd] |
| } |
| } |
| } else { |
| cleanBase64 = base64String |
| } |
| } else { |
| cleanBase64 = base64String |
| } |
|
|
| if providedMimeType != "" { |
| mimeType = providedMimeType |
| } |
|
|
| decodedData, err := base64.StdEncoding.DecodeString(cleanBase64) |
| if err != nil { |
| return nil, fmt.Errorf("failed to decode base64 data: %w", err) |
| } |
|
|
| base64Size := int64(len(cleanBase64)) |
| var cachedData *types.CachedFileData |
|
|
| if shouldUseDiskCache(base64Size) { |
| diskPath, err := writeToDiskCache(cleanBase64) |
| if err != nil { |
| cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData))) |
| } else { |
| cachedData = types.NewDiskCachedData(diskPath, mimeType, int64(len(decodedData))) |
| cachedData.DiskSize = base64Size |
| cachedData.OnClose = func(size int64) { |
| common.DecrementDiskFiles(size) |
| } |
| common.IncrementDiskFiles(base64Size) |
| } |
| } else { |
| cachedData = types.NewMemoryCachedData(cleanBase64, mimeType, int64(len(decodedData))) |
| } |
|
|
| if mimeType == "" || strings.HasPrefix(mimeType, "image/") { |
| config, format, err := decodeImageConfig(decodedData) |
| if err == nil { |
| cachedData.ImageConfig = &config |
| cachedData.ImageFormat = format |
| if mimeType == "" { |
| cachedData.MimeType = "image/" + format |
| } |
| } |
| } |
|
|
| return cachedData, nil |
| } |
|
|
| |
| func GetImageConfig(c *gin.Context, source *types.FileSource) (image.Config, string, error) { |
| cachedData, err := LoadFileSource(c, source, "get_image_config") |
| if err != nil { |
| return image.Config{}, "", err |
| } |
|
|
| if cachedData.ImageConfig != nil { |
| return *cachedData.ImageConfig, cachedData.ImageFormat, nil |
| } |
|
|
| base64Str, err := cachedData.GetBase64Data() |
| if err != nil { |
| return image.Config{}, "", fmt.Errorf("failed to get base64 data: %w", err) |
| } |
| decodedData, err := base64.StdEncoding.DecodeString(base64Str) |
| if err != nil { |
| return image.Config{}, "", fmt.Errorf("failed to decode base64 for image config: %w", err) |
| } |
|
|
| config, format, err := decodeImageConfig(decodedData) |
| if err != nil { |
| return image.Config{}, "", err |
| } |
|
|
| cachedData.ImageConfig = &config |
| cachedData.ImageFormat = format |
|
|
| return config, format, nil |
| } |
|
|
| |
| func GetBase64Data(c *gin.Context, source *types.FileSource, reason ...string) (string, string, error) { |
| cachedData, err := LoadFileSource(c, source, reason...) |
| if err != nil { |
| return "", "", err |
| } |
| base64Str, err := cachedData.GetBase64Data() |
| if err != nil { |
| return "", "", fmt.Errorf("failed to get base64 data: %w", err) |
| } |
| return base64Str, cachedData.MimeType, nil |
| } |
|
|
| |
| func GetMimeType(c *gin.Context, source *types.FileSource) (string, error) { |
| if source.HasCache() { |
| return source.GetCache().MimeType, nil |
| } |
|
|
| if source.IsURL() { |
| mimeType, err := GetFileTypeFromUrl(c, source.URL, "get_mime_type") |
| if err == nil && mimeType != "" && mimeType != "application/octet-stream" { |
| return mimeType, nil |
| } |
| } |
|
|
| cachedData, err := LoadFileSource(c, source, "get_mime_type") |
| if err != nil { |
| return "", err |
| } |
| return cachedData.MimeType, nil |
| } |
|
|
| |
| func DetectFileType(mimeType string) types.FileType { |
| if strings.HasPrefix(mimeType, "image/") { |
| return types.FileTypeImage |
| } |
| if strings.HasPrefix(mimeType, "audio/") { |
| return types.FileTypeAudio |
| } |
| if strings.HasPrefix(mimeType, "video/") { |
| return types.FileTypeVideo |
| } |
| return types.FileTypeFile |
| } |
|
|
| |
| func decodeImageConfig(data []byte) (image.Config, string, error) { |
| reader := bytes.NewReader(data) |
|
|
| config, format, err := image.DecodeConfig(reader) |
| if err == nil { |
| return config, format, nil |
| } |
|
|
| reader.Seek(0, io.SeekStart) |
| config, err = webp.DecodeConfig(reader) |
| if err == nil { |
| return config, "webp", nil |
| } |
|
|
| return image.Config{}, "", fmt.Errorf("failed to decode image config: unsupported format") |
| } |
|
|
| |
| func guessMimeTypeFromURL(url string) string { |
| cleanedURL := url |
| if q := strings.Index(cleanedURL, "?"); q != -1 { |
| cleanedURL = cleanedURL[:q] |
| } |
|
|
| if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) { |
| last := cleanedURL[slash+1:] |
| if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) { |
| ext := strings.ToLower(last[dot+1:]) |
| return GetMimeTypeByExtension(ext) |
| } |
| } |
|
|
| return "application/octet-stream" |
| } |
|
|