| package service |
|
|
| import ( |
| "fmt" |
| "strings" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/logger" |
| "github.com/QuantumNous/new-api/model" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| "github.com/QuantumNous/new-api/setting/model_setting" |
| "github.com/QuantumNous/new-api/types" |
|
|
| "github.com/shopspring/decimal" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
// String constants used to recognise violation-fee errors.
const (
	// ViolationFeeCodePrefix is the prefix that IsViolationFeeCode looks for
	// at the start of an error code.
	ViolationFeeCodePrefix = "violation_fee."

	// CSAMViolationMarker is a substring that HasCSAMViolationMarker searches
	// for in an error's text.
	CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE"

	// ContentViolatesUsageMarker is a second substring that
	// HasCSAMViolationMarker searches for in an error's text.
	ContentViolatesUsageMarker = "Content violates usage guidelines"
)
|
|
| func IsViolationFeeCode(code types.ErrorCode) bool { |
| return strings.HasPrefix(string(code), ViolationFeeCodePrefix) |
| } |
|
|
| func HasCSAMViolationMarker(err *types.NewAPIError) bool { |
| if err == nil { |
| return false |
| } |
| if strings.Contains(err.Error(), CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) { |
| return true |
| } |
| msg := err.ToOpenAIError().Message |
| return strings.Contains(msg, CSAMViolationMarker) || strings.Contains(err.Error(), ContentViolatesUsageMarker) |
| } |
|
|
| func WrapAsViolationFeeGrokCSAM(err *types.NewAPIError) *types.NewAPIError { |
| if err == nil { |
| return nil |
| } |
| oai := err.ToOpenAIError() |
| oai.Type = string(types.ErrorCodeViolationFeeGrokCSAM) |
| oai.Code = string(types.ErrorCodeViolationFeeGrokCSAM) |
| return types.WithOpenAIError(oai, err.StatusCode, types.ErrOptionWithSkipRetry()) |
| } |
|
|
| |
| |
| |
| |
| |
| func NormalizeViolationFeeError(err *types.NewAPIError) *types.NewAPIError { |
| if err == nil { |
| return nil |
| } |
|
|
| if HasCSAMViolationMarker(err) { |
| return WrapAsViolationFeeGrokCSAM(err) |
| } |
|
|
| if IsViolationFeeCode(err.GetErrorCode()) { |
| oai := err.ToOpenAIError() |
| return types.WithOpenAIError(oai, err.StatusCode, types.ErrOptionWithSkipRetry()) |
| } |
|
|
| return err |
| } |
|
|
| func shouldChargeViolationFee(err *types.NewAPIError) bool { |
| if err == nil { |
| return false |
| } |
| if err.GetErrorCode() == types.ErrorCodeViolationFeeGrokCSAM { |
| return true |
| } |
| |
| return HasCSAMViolationMarker(err) |
| } |
|
|
| func calcViolationFeeQuota(amount, groupRatio float64) int { |
| if amount <= 0 { |
| return 0 |
| } |
| if groupRatio <= 0 { |
| return 0 |
| } |
| quota := decimal.NewFromFloat(amount). |
| Mul(decimal.NewFromFloat(common.QuotaPerUnit)). |
| Mul(decimal.NewFromFloat(groupRatio)). |
| Round(0). |
| IntPart() |
| if quota <= 0 { |
| return 0 |
| } |
| return int(quota) |
| } |
|
|
| |
| |
| func ChargeViolationFeeIfNeeded(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, apiErr *types.NewAPIError) bool { |
| if ctx == nil || relayInfo == nil || apiErr == nil { |
| return false |
| } |
| |
| |
| |
| if !shouldChargeViolationFee(apiErr) { |
| return false |
| } |
|
|
| settings := model_setting.GetGrokSettings() |
| if settings == nil || !settings.ViolationDeductionEnabled { |
| return false |
| } |
|
|
| groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio |
| feeQuota := calcViolationFeeQuota(settings.ViolationDeductionAmount, groupRatio) |
| if feeQuota <= 0 { |
| return false |
| } |
|
|
| if err := PostConsumeQuota(relayInfo, feeQuota, 0, true); err != nil { |
| logger.LogError(ctx, fmt.Sprintf("failed to charge violation fee: %s", err.Error())) |
| return false |
| } |
|
|
| model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, feeQuota) |
| model.UpdateChannelUsedQuota(relayInfo.ChannelId, feeQuota) |
|
|
| useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() |
| tokenName := ctx.GetString("token_name") |
| oai := apiErr.ToOpenAIError() |
|
|
| other := map[string]any{ |
| "violation_fee": true, |
| "violation_fee_code": string(types.ErrorCodeViolationFeeGrokCSAM), |
| "fee_quota": feeQuota, |
| "base_amount": settings.ViolationDeductionAmount, |
| "group_ratio": groupRatio, |
| "status_code": apiErr.StatusCode, |
| "upstream_error_type": oai.Type, |
| "upstream_error_code": fmt.Sprintf("%v", oai.Code), |
| "violation_fee_marker": CSAMViolationMarker, |
| } |
|
|
| model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{ |
| ChannelId: relayInfo.ChannelId, |
| ModelName: relayInfo.OriginModelName, |
| TokenName: tokenName, |
| Quota: feeQuota, |
| Content: "Violation fee charged", |
| TokenId: relayInfo.TokenId, |
| UseTimeSeconds: int(useTimeSeconds), |
| IsStream: relayInfo.IsStream, |
| Group: relayInfo.UsingGroup, |
| Other: other, |
| }) |
|
|
| return true |
| } |
|
|