| package handler |
|
|
| import ( |
| "strconv" |
| "strings" |
| "time" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/handler/dto" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/response" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" |
| middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" |
| "github.com/Wei-Shaw/sub2api/internal/service" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
| |
| type UsageHandler struct { |
| usageService *service.UsageService |
| apiKeyService *service.APIKeyService |
| } |
|
|
| |
| func NewUsageHandler(usageService *service.UsageService, apiKeyService *service.APIKeyService) *UsageHandler { |
| return &UsageHandler{ |
| usageService: usageService, |
| apiKeyService: apiKeyService, |
| } |
| } |
|
|
| |
| |
| func (h *UsageHandler) List(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| page, pageSize := response.ParsePagination(c) |
|
|
| var apiKeyID int64 |
| if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" { |
| id, err := strconv.ParseInt(apiKeyIDStr, 10, 64) |
| if err != nil { |
| response.BadRequest(c, "Invalid api_key_id") |
| return |
| } |
|
|
| |
| apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
| if apiKey.UserID != subject.UserID { |
| response.Forbidden(c, "Not authorized to access this API key's usage records") |
| return |
| } |
|
|
| apiKeyID = id |
| } |
|
|
| |
| model := c.Query("model") |
|
|
| var requestType *int16 |
| var stream *bool |
| if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { |
| parsed, err := service.ParseUsageRequestType(requestTypeStr) |
| if err != nil { |
| response.BadRequest(c, err.Error()) |
| return |
| } |
| value := int16(parsed) |
| requestType = &value |
| } else if streamStr := c.Query("stream"); streamStr != "" { |
| val, err := strconv.ParseBool(streamStr) |
| if err != nil { |
| response.BadRequest(c, "Invalid stream value, use true or false") |
| return |
| } |
| stream = &val |
| } |
|
|
| var billingType *int8 |
| if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { |
| val, err := strconv.ParseInt(billingTypeStr, 10, 8) |
| if err != nil { |
| response.BadRequest(c, "Invalid billing_type") |
| return |
| } |
| bt := int8(val) |
| billingType = &bt |
| } |
|
|
| |
| var startTime, endTime *time.Time |
| userTZ := c.Query("timezone") |
| if startDateStr := c.Query("start_date"); startDateStr != "" { |
| t, err := timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ) |
| if err != nil { |
| response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") |
| return |
| } |
| startTime = &t |
| } |
|
|
| if endDateStr := c.Query("end_date"); endDateStr != "" { |
| t, err := timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ) |
| if err != nil { |
| response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") |
| return |
| } |
| |
| t = t.AddDate(0, 0, 1) |
| endTime = &t |
| } |
|
|
| params := pagination.PaginationParams{Page: page, PageSize: pageSize} |
| filters := usagestats.UsageLogFilters{ |
| UserID: subject.UserID, |
| APIKeyID: apiKeyID, |
| Model: model, |
| RequestType: requestType, |
| Stream: stream, |
| BillingType: billingType, |
| StartTime: startTime, |
| EndTime: endTime, |
| } |
|
|
| records, result, err := h.usageService.ListWithFilters(c.Request.Context(), params, filters) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| out := make([]dto.UsageLog, 0, len(records)) |
| for i := range records { |
| out = append(out, *dto.UsageLogFromService(&records[i])) |
| } |
| response.Paginated(c, out, result.Total, page, pageSize) |
| } |
|
|
| |
| |
| func (h *UsageHandler) GetByID(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| usageID, err := strconv.ParseInt(c.Param("id"), 10, 64) |
| if err != nil { |
| response.BadRequest(c, "Invalid usage ID") |
| return |
| } |
|
|
| record, err := h.usageService.GetByID(c.Request.Context(), usageID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| |
| if record.UserID != subject.UserID { |
| response.Forbidden(c, "Not authorized to access this record") |
| return |
| } |
|
|
| response.Success(c, dto.UsageLogFromService(record)) |
| } |
|
|
| |
| |
| func (h *UsageHandler) Stats(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| var apiKeyID int64 |
| if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" { |
| id, err := strconv.ParseInt(apiKeyIDStr, 10, 64) |
| if err != nil { |
| response.BadRequest(c, "Invalid api_key_id") |
| return |
| } |
|
|
| |
| apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), id) |
| if err != nil { |
| response.NotFound(c, "API key not found") |
| return |
| } |
| if apiKey.UserID != subject.UserID { |
| response.Forbidden(c, "Not authorized to access this API key's statistics") |
| return |
| } |
|
|
| apiKeyID = id |
| } |
|
|
| |
| userTZ := c.Query("timezone") |
| now := timezone.NowInUserLocation(userTZ) |
| var startTime, endTime time.Time |
|
|
| |
| startDateStr := c.Query("start_date") |
| endDateStr := c.Query("end_date") |
|
|
| if startDateStr != "" && endDateStr != "" { |
| |
| var err error |
| startTime, err = timezone.ParseInUserLocation("2006-01-02", startDateStr, userTZ) |
| if err != nil { |
| response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") |
| return |
| } |
| endTime, err = timezone.ParseInUserLocation("2006-01-02", endDateStr, userTZ) |
| if err != nil { |
| response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") |
| return |
| } |
| |
| endTime = endTime.AddDate(0, 0, 1) |
| } else { |
| |
| period := c.DefaultQuery("period", "today") |
| switch period { |
| case "today": |
| startTime = timezone.StartOfDayInUserLocation(now, userTZ) |
| case "week": |
| startTime = now.AddDate(0, 0, -7) |
| case "month": |
| startTime = now.AddDate(0, -1, 0) |
| default: |
| startTime = timezone.StartOfDayInUserLocation(now, userTZ) |
| } |
| endTime = now |
| } |
|
|
| var stats *service.UsageStats |
| var err error |
| if apiKeyID > 0 { |
| stats, err = h.usageService.GetStatsByAPIKey(c.Request.Context(), apiKeyID, startTime, endTime) |
| } else { |
| stats, err = h.usageService.GetStatsByUser(c.Request.Context(), subject.UserID, startTime, endTime) |
| } |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, stats) |
| } |
|
|
| |
| |
| func parseUserTimeRange(c *gin.Context) (time.Time, time.Time) { |
| userTZ := c.Query("timezone") |
| now := timezone.NowInUserLocation(userTZ) |
| startDate := c.Query("start_date") |
| endDate := c.Query("end_date") |
|
|
| var startTime, endTime time.Time |
|
|
| if startDate != "" { |
| if t, err := timezone.ParseInUserLocation("2006-01-02", startDate, userTZ); err == nil { |
| startTime = t |
| } else { |
| startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ) |
| } |
| } else { |
| startTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, -7), userTZ) |
| } |
|
|
| if endDate != "" { |
| if t, err := timezone.ParseInUserLocation("2006-01-02", endDate, userTZ); err == nil { |
| endTime = t.Add(24 * time.Hour) |
| } else { |
| endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ) |
| } |
| } else { |
| endTime = timezone.StartOfDayInUserLocation(now.AddDate(0, 0, 1), userTZ) |
| } |
|
|
| return startTime, endTime |
| } |
|
|
| |
| |
| func (h *UsageHandler) DashboardStats(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| stats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, stats) |
| } |
|
|
| |
| |
| func (h *UsageHandler) DashboardTrend(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| startTime, endTime := parseUserTimeRange(c) |
| granularity := c.DefaultQuery("granularity", "day") |
|
|
| trend, err := h.usageService.GetUserUsageTrendByUserID(c.Request.Context(), subject.UserID, startTime, endTime, granularity) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, gin.H{ |
| "trend": trend, |
| "start_date": startTime.Format("2006-01-02"), |
| "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), |
| "granularity": granularity, |
| }) |
| } |
|
|
| |
| |
| func (h *UsageHandler) DashboardModels(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| startTime, endTime := parseUserTimeRange(c) |
|
|
| stats, err := h.usageService.GetUserModelStats(c.Request.Context(), subject.UserID, startTime, endTime) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, gin.H{ |
| "models": stats, |
| "start_date": startTime.Format("2006-01-02"), |
| "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), |
| }) |
| } |
|
|
| |
// BatchAPIKeysUsageRequest is the JSON request body read by
// DashboardAPIKeysUsage.
type BatchAPIKeysUsageRequest struct {
	// APIKeyIDs lists the API keys whose usage statistics are requested.
	// It is carried in the "api_key_ids" JSON field and tagged as required
	// for request binding.
	APIKeyIDs []int64 `json:"api_key_ids" binding:"required"`
}
|
|
| |
| |
| func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { |
| subject, ok := middleware2.GetAuthSubjectFromContext(c) |
| if !ok { |
| response.Unauthorized(c, "User not authenticated") |
| return |
| } |
|
|
| var req BatchAPIKeysUsageRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| response.BadRequest(c, "Invalid request: "+err.Error()) |
| return |
| } |
|
|
| if len(req.APIKeyIDs) == 0 { |
| response.Success(c, gin.H{"stats": map[string]any{}}) |
| return |
| } |
|
|
| |
| if len(req.APIKeyIDs) > 100 { |
| response.BadRequest(c, "Too many API key IDs (maximum 100 allowed)") |
| return |
| } |
|
|
| validAPIKeyIDs, err := h.apiKeyService.VerifyOwnership(c.Request.Context(), subject.UserID, req.APIKeyIDs) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| if len(validAPIKeyIDs) == 0 { |
| response.Success(c, gin.H{"stats": map[string]any{}}) |
| return |
| } |
|
|
| stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{}) |
| if err != nil { |
| response.ErrorFrom(c, err) |
| return |
| } |
|
|
| response.Success(c, gin.H{"stats": stats}) |
| } |
|
|