File size: 4,811 Bytes
daa8246 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | package service
import (
"fmt"
"strings"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/model"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/types"
"github.com/shopspring/decimal"
"github.com/gin-gonic/gin"
)
// Constants used to recognize and classify errors that incur a violation fee.
const (
	// ViolationFeeCodePrefix is the prefix shared by every violation-fee error code
	// (checked by IsViolationFeeCode).
	ViolationFeeCodePrefix = "violation_fee."

	// CSAMViolationMarker is a substring searched for in upstream error text to
	// detect a safety-check violation.
	CSAMViolationMarker = "Failed check: SAFETY_CHECK_TYPE"

	// ContentViolatesUsageMarker is a second substring searched for in upstream
	// error text to detect a usage-guidelines violation.
	ContentViolatesUsageMarker = "Content violates usage guidelines"
)
// IsViolationFeeCode reports whether code belongs to the violation-fee family,
// i.e. whether its string form begins with ViolationFeeCodePrefix.
func IsViolationFeeCode(code types.ErrorCode) bool {
	s := string(code)
	n := len(ViolationFeeCodePrefix)
	return len(s) >= n && s[:n] == ViolationFeeCodePrefix
}
// HasCSAMViolationMarker reports whether err carries one of the known upstream
// violation markers (CSAMViolationMarker or ContentViolatesUsageMarker).
// Both err.Error() and the OpenAI-formatted message are inspected, because the
// two texts are not guaranteed to be identical. A nil err yields false.
func HasCSAMViolationMarker(err *types.NewAPIError) bool {
	if err == nil {
		return false
	}
	containsMarker := func(s string) bool {
		return strings.Contains(s, CSAMViolationMarker) || strings.Contains(s, ContentViolatesUsageMarker)
	}
	if containsMarker(err.Error()) {
		return true
	}
	// Check the formatted message for BOTH markers. Previously only the CSAM
	// marker was looked up here, and the usage-guidelines marker was
	// re-checked against err.Error() by mistake.
	return containsMarker(err.ToOpenAIError().Message)
}
// WrapAsViolationFeeGrokCSAM rebuilds err as an OpenAI-style error whose type and
// code are both types.ErrorCodeViolationFeeGrokCSAM. It keeps err's status code
// and applies the skip-retry option. A nil err yields nil.
func WrapAsViolationFeeGrokCSAM(err *types.NewAPIError) *types.NewAPIError {
	if err != nil {
		violationCode := string(types.ErrorCodeViolationFeeGrokCSAM)
		wrapped := err.ToOpenAIError()
		wrapped.Type = violationCode
		wrapped.Code = violationCode
		return types.WithOpenAIError(wrapped, err.StatusCode, types.ErrOptionWithSkipRetry())
	}
	return nil
}
// NormalizeViolationFeeError canonicalizes violation-fee errors so that retry
// handling treats them as final. It must be called before retry decision logic.
//
//   - If a violation marker is present, the error is rewritten with the stable
//     violation-fee code and skip-retry enabled.
//   - If the error code already has the violation-fee prefix, skip-retry is
//     enabled.
//   - Any other error, including nil, is returned unchanged.
func NormalizeViolationFeeError(err *types.NewAPIError) *types.NewAPIError {
	switch {
	case err == nil:
		return nil
	case HasCSAMViolationMarker(err):
		return WrapAsViolationFeeGrokCSAM(err)
	case IsViolationFeeCode(err.GetErrorCode()):
		return types.WithOpenAIError(err.ToOpenAIError(), err.StatusCode, types.ErrOptionWithSkipRetry())
	default:
		return err
	}
}
// shouldChargeViolationFee reports whether err warrants a violation fee. This is
// true when err already carries the Grok CSAM violation-fee code. As a safety
// net for callers that skipped NormalizeViolationFeeError, it is also true when
// err still contains a violation marker. A nil err never qualifies.
func shouldChargeViolationFee(err *types.NewAPIError) bool {
	if err == nil {
		return false
	}
	normalized := err.GetErrorCode() == types.ErrorCodeViolationFeeGrokCSAM
	return normalized || HasCSAMViolationMarker(err)
}
// calcViolationFeeQuota converts a fee amount into quota units as
// round(amount * common.QuotaPerUnit * groupRatio), using decimal arithmetic.
// It returns 0 when amount or groupRatio is non-positive, or when the rounded
// result is non-positive.
func calcViolationFeeQuota(amount, groupRatio float64) int {
	if amount <= 0 || groupRatio <= 0 {
		return 0
	}
	product := decimal.NewFromFloat(amount).Mul(decimal.NewFromFloat(common.QuotaPerUnit))
	product = product.Mul(decimal.NewFromFloat(groupRatio))
	rounded := product.Round(0).IntPart()
	if rounded > 0 {
		return int(rounded)
	}
	return 0
}
// ChargeViolationFeeIfNeeded charges an additional fee after the normal flow
// finishes (including refund). It uses the Grok settings
// (model_setting.GetGrokSettings) as the fee policy.
//
// A fee is charged only when all of the following hold:
//   - shouldChargeViolationFee(apiErr) is true;
//   - violation deduction is enabled in the settings;
//   - the computed fee quota is positive.
//
// On success it calls PostConsumeQuota, model.UpdateUserUsedQuotaAndRequestCount,
// model.UpdateChannelUsedQuota and model.RecordConsumeLog, then returns true.
// It returns false when no fee applies. It also returns false, after logging
// the error, when PostConsumeQuota fails.
func ChargeViolationFeeIfNeeded(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, apiErr *types.NewAPIError) bool {
	if ctx == nil || relayInfo == nil || apiErr == nil {
		return false
	}
	//if relayInfo.IsPlayground {
	//	return false
	//}
	if !shouldChargeViolationFee(apiErr) {
		return false
	}
	settings := model_setting.GetGrokSettings()
	if settings == nil || !settings.ViolationDeductionEnabled {
		return false
	}
	// The configured base amount is scaled by the request's group ratio.
	// A non-positive result means there is nothing to charge.
	groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
	feeQuota := calcViolationFeeQuota(settings.ViolationDeductionAmount, groupRatio)
	if feeQuota <= 0 {
		return false
	}
	if err := PostConsumeQuota(relayInfo, feeQuota, 0, true); err != nil {
		logger.LogError(ctx, fmt.Sprintf("failed to charge violation fee: %s", err.Error()))
		return false
	}
	model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, feeQuota)
	model.UpdateChannelUsedQuota(relayInfo.ChannelId, feeQuota)

	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
	tokenName := ctx.GetString("token_name")
	oai := apiErr.ToOpenAIError()
	other := map[string]any{
		"violation_fee":       true,
		"violation_fee_code":  string(types.ErrorCodeViolationFeeGrokCSAM),
		"fee_quota":           feeQuota,
		"base_amount":         settings.ViolationDeductionAmount,
		"group_ratio":         groupRatio,
		"status_code":         apiErr.StatusCode,
		"upstream_error_type": oai.Type,
		"upstream_error_code": fmt.Sprintf("%v", oai.Code),
		// Log the marker that actually matched. Previously the CSAM marker was
		// always recorded, even for usage-guidelines violations.
		"violation_fee_marker": matchedViolationMarker(apiErr.Error(), oai.Message),
	}
	model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
		ChannelId:      relayInfo.ChannelId,
		ModelName:      relayInfo.OriginModelName,
		TokenName:      tokenName,
		Quota:          feeQuota,
		Content:        "Violation fee charged",
		TokenId:        relayInfo.TokenId,
		UseTimeSeconds: int(useTimeSeconds),
		IsStream:       relayInfo.IsStream,
		Group:          relayInfo.UsingGroup,
		Other:          other,
	})
	return true
}

// matchedViolationMarker returns the first known violation marker found in any
// of texts, checking CSAMViolationMarker before ContentViolatesUsageMarker.
// It falls back to CSAMViolationMarker when no marker is present, e.g. when the
// error was classified solely by its violation-fee code.
func matchedViolationMarker(texts ...string) string {
	for _, marker := range [...]string{CSAMViolationMarker, ContentViolatesUsageMarker} {
		for _, text := range texts {
			if strings.Contains(text, marker) {
				return marker
			}
		}
	}
	return CSAMViolationMarker
}
|