| package service |
|
|
import (
	"context"
	"fmt"
	"sort"
	"strings"

	"github.com/QuantumNous/new-api/common"
	"github.com/QuantumNous/new-api/constant"
	"github.com/QuantumNous/new-api/logger"
	"github.com/QuantumNous/new-api/model"
	relaycommon "github.com/QuantumNous/new-api/relay/common"
	"github.com/QuantumNous/new-api/setting/ratio_setting"
	"github.com/gin-gonic/gin"
)
|
|
| |
| |
| func LogTaskConsumption(c *gin.Context, info *relaycommon.RelayInfo) { |
| tokenName := c.GetString("token_name") |
| logContent := fmt.Sprintf("操作 %s", info.Action) |
| |
| if common.StringsContains(constant.TaskPricePatches, info.OriginModelName) { |
| logContent = fmt.Sprintf("%s,按次计费", logContent) |
| } else { |
| if len(info.PriceData.OtherRatios) > 0 { |
| var contents []string |
| for key, ra := range info.PriceData.OtherRatios { |
| if 1.0 != ra { |
| contents = append(contents, fmt.Sprintf("%s: %.2f", key, ra)) |
| } |
| } |
| if len(contents) > 0 { |
| logContent = fmt.Sprintf("%s, 计算参数:%s", logContent, strings.Join(contents, ", ")) |
| } |
| } |
| } |
| other := make(map[string]interface{}) |
| other["request_path"] = c.Request.URL.Path |
| other["model_price"] = info.PriceData.ModelPrice |
| other["group_ratio"] = info.PriceData.GroupRatioInfo.GroupRatio |
| if info.PriceData.GroupRatioInfo.HasSpecialRatio { |
| other["user_group_ratio"] = info.PriceData.GroupRatioInfo.GroupSpecialRatio |
| } |
| if info.IsModelMapped { |
| other["is_model_mapped"] = true |
| other["upstream_model_name"] = info.UpstreamModelName |
| } |
| model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{ |
| ChannelId: info.ChannelId, |
| ModelName: info.OriginModelName, |
| TokenName: tokenName, |
| Quota: info.PriceData.Quota, |
| Content: logContent, |
| TokenId: info.TokenId, |
| Group: info.UsingGroup, |
| Other: other, |
| }) |
| model.UpdateUserUsedQuotaAndRequestCount(info.UserId, info.PriceData.Quota) |
| model.UpdateChannelUsedQuota(info.ChannelId, info.PriceData.Quota) |
| } |
|
|
| |
| |
| |
|
|
| |
| |
| func resolveTokenKey(ctx context.Context, tokenId int, taskID string) string { |
| token, err := model.GetTokenById(tokenId) |
| if err != nil { |
| logger.LogWarn(ctx, fmt.Sprintf("获取令牌 key 失败 (tokenId=%d, task=%s): %s", tokenId, taskID, err.Error())) |
| return "" |
| } |
| return token.Key |
| } |
|
|
| |
| func taskIsSubscription(task *model.Task) bool { |
| return task.PrivateData.BillingSource == BillingSourceSubscription && task.PrivateData.SubscriptionId > 0 |
| } |
|
|
| |
| func taskAdjustFunding(task *model.Task, delta int) error { |
| if taskIsSubscription(task) { |
| return model.PostConsumeUserSubscriptionDelta(task.PrivateData.SubscriptionId, int64(delta)) |
| } |
| if delta > 0 { |
| return model.DecreaseUserQuota(task.UserId, delta) |
| } |
| return model.IncreaseUserQuota(task.UserId, -delta, false) |
| } |
|
|
| |
| |
| func taskAdjustTokenQuota(ctx context.Context, task *model.Task, delta int) { |
| if task.PrivateData.TokenId <= 0 || delta == 0 { |
| return |
| } |
| tokenKey := resolveTokenKey(ctx, task.PrivateData.TokenId, task.TaskID) |
| if tokenKey == "" { |
| return |
| } |
| var err error |
| if delta > 0 { |
| err = model.DecreaseTokenQuota(task.PrivateData.TokenId, tokenKey, delta) |
| } else { |
| err = model.IncreaseTokenQuota(task.PrivateData.TokenId, tokenKey, -delta) |
| } |
| if err != nil { |
| logger.LogWarn(ctx, fmt.Sprintf("调整令牌额度失败 (delta=%d, task=%s): %s", delta, task.TaskID, err.Error())) |
| } |
| } |
|
|
| |
| func taskBillingOther(task *model.Task) map[string]interface{} { |
| other := make(map[string]interface{}) |
| if bc := task.PrivateData.BillingContext; bc != nil { |
| other["model_price"] = bc.ModelPrice |
| other["group_ratio"] = bc.GroupRatio |
| if len(bc.OtherRatios) > 0 { |
| for k, v := range bc.OtherRatios { |
| other[k] = v |
| } |
| } |
| } |
| props := task.Properties |
| if props.UpstreamModelName != "" && props.UpstreamModelName != props.OriginModelName { |
| other["is_model_mapped"] = true |
| other["upstream_model_name"] = props.UpstreamModelName |
| } |
| return other |
| } |
|
|
| |
| func taskModelName(task *model.Task) string { |
| if bc := task.PrivateData.BillingContext; bc != nil && bc.OriginModelName != "" { |
| return bc.OriginModelName |
| } |
| return task.Properties.OriginModelName |
| } |
|
|
| |
| |
| func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { |
| quota := task.Quota |
| if quota == 0 { |
| return |
| } |
|
|
| |
| if err := taskAdjustFunding(task, -quota); err != nil { |
| logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error())) |
| return |
| } |
|
|
| |
| taskAdjustTokenQuota(ctx, task, -quota) |
|
|
| |
| other := taskBillingOther(task) |
| other["task_id"] = task.TaskID |
| other["reason"] = reason |
| model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ |
| UserId: task.UserId, |
| LogType: model.LogTypeRefund, |
| Content: "", |
| ChannelId: task.ChannelId, |
| ModelName: taskModelName(task), |
| Quota: quota, |
| TokenId: task.PrivateData.TokenId, |
| Group: task.Group, |
| Other: other, |
| }) |
| } |
|
|
| |
| |
| |
| func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) { |
| if actualQuota <= 0 { |
| return |
| } |
| preConsumedQuota := task.Quota |
| quotaDelta := actualQuota - preConsumedQuota |
|
|
| if quotaDelta == 0 { |
| logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)", |
| task.TaskID, logger.LogQuota(actualQuota), reason)) |
| return |
| } |
|
|
| logger.LogInfo(ctx, fmt.Sprintf("任务 %s 差额结算:delta=%s(实际:%s,预扣:%s,%s)", |
| task.TaskID, |
| logger.LogQuota(quotaDelta), |
| logger.LogQuota(actualQuota), |
| logger.LogQuota(preConsumedQuota), |
| reason, |
| )) |
|
|
| |
| if err := taskAdjustFunding(task, quotaDelta); err != nil { |
| logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) |
| return |
| } |
|
|
| |
| taskAdjustTokenQuota(ctx, task, quotaDelta) |
|
|
| task.Quota = actualQuota |
|
|
| var logType int |
| var logQuota int |
| if quotaDelta > 0 { |
| logType = model.LogTypeConsume |
| logQuota = quotaDelta |
| model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) |
| model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) |
| } else { |
| logType = model.LogTypeRefund |
| logQuota = -quotaDelta |
| } |
| other := taskBillingOther(task) |
| other["task_id"] = task.TaskID |
| |
| other["pre_consumed_quota"] = preConsumedQuota |
| other["actual_quota"] = actualQuota |
| model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ |
| UserId: task.UserId, |
| LogType: logType, |
| Content: reason, |
| ChannelId: task.ChannelId, |
| ModelName: taskModelName(task), |
| Quota: logQuota, |
| TokenId: task.PrivateData.TokenId, |
| Group: task.Group, |
| Other: other, |
| }) |
| } |
|
|
| |
| |
| |
| func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTokens int) { |
| if totalTokens <= 0 { |
| return |
| } |
|
|
| modelName := taskModelName(task) |
|
|
| |
| modelRatio, hasRatioSetting, _ := ratio_setting.GetModelRatio(modelName) |
| |
| if !hasRatioSetting || modelRatio <= 0 { |
| return |
| } |
|
|
| |
| group := task.Group |
| if group == "" { |
| user, err := model.GetUserById(task.UserId, false) |
| if err == nil { |
| group = user.Group |
| } |
| } |
| if group == "" { |
| return |
| } |
|
|
| groupRatio := ratio_setting.GetGroupRatio(group) |
| userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(group, group) |
|
|
| var finalGroupRatio float64 |
| if hasUserGroupRatio { |
| finalGroupRatio = userGroupRatio |
| } else { |
| finalGroupRatio = groupRatio |
| } |
|
|
| |
| actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio) |
|
|
| reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f", totalTokens, modelRatio, finalGroupRatio) |
| RecalculateTaskQuota(ctx, task, actualQuota, reason) |
| } |
|
|