| package service |
|
|
| import ( |
| "context" |
| "encoding/json" |
| "net/http" |
| "os" |
| "testing" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/model" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| "github.com/glebarez/sqlite" |
| "github.com/stretchr/testify/assert" |
| "github.com/stretchr/testify/require" |
| "gorm.io/gorm" |
| ) |
|
|
| func TestMain(m *testing.M) { |
| db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) |
| if err != nil { |
| panic("failed to open test db: " + err.Error()) |
| } |
| sqlDB, err := db.DB() |
| if err != nil { |
| panic("failed to get sql.DB: " + err.Error()) |
| } |
| sqlDB.SetMaxOpenConns(1) |
|
|
| model.DB = db |
| model.LOG_DB = db |
|
|
| common.UsingSQLite = true |
| common.RedisEnabled = false |
| common.BatchUpdateEnabled = false |
| common.LogConsumeEnabled = true |
|
|
| if err := db.AutoMigrate( |
| &model.Task{}, |
| &model.User{}, |
| &model.Token{}, |
| &model.Log{}, |
| &model.Channel{}, |
| &model.UserSubscription{}, |
| ); err != nil { |
| panic("failed to migrate: " + err.Error()) |
| } |
|
|
| os.Exit(m.Run()) |
| } |
|
|
| |
| |
| |
|
|
| func truncate(t *testing.T) { |
| t.Helper() |
| t.Cleanup(func() { |
| model.DB.Exec("DELETE FROM tasks") |
| model.DB.Exec("DELETE FROM users") |
| model.DB.Exec("DELETE FROM tokens") |
| model.DB.Exec("DELETE FROM logs") |
| model.DB.Exec("DELETE FROM channels") |
| model.DB.Exec("DELETE FROM user_subscriptions") |
| }) |
| } |
|
|
| func seedUser(t *testing.T, id int, quota int) { |
| t.Helper() |
| user := &model.User{Id: id, Username: "test_user", Quota: quota, Status: common.UserStatusEnabled} |
| require.NoError(t, model.DB.Create(user).Error) |
| } |
|
|
| func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) { |
| t.Helper() |
| token := &model.Token{ |
| Id: id, |
| UserId: userId, |
| Key: key, |
| Name: "test_token", |
| Status: common.TokenStatusEnabled, |
| RemainQuota: remainQuota, |
| UsedQuota: 0, |
| } |
| require.NoError(t, model.DB.Create(token).Error) |
| } |
|
|
| func seedSubscription(t *testing.T, id int, userId int, amountTotal int64, amountUsed int64) { |
| t.Helper() |
| sub := &model.UserSubscription{ |
| Id: id, |
| UserId: userId, |
| AmountTotal: amountTotal, |
| AmountUsed: amountUsed, |
| Status: "active", |
| StartTime: time.Now().Unix(), |
| EndTime: time.Now().Add(30 * 24 * time.Hour).Unix(), |
| } |
| require.NoError(t, model.DB.Create(sub).Error) |
| } |
|
|
| func seedChannel(t *testing.T, id int) { |
| t.Helper() |
| ch := &model.Channel{Id: id, Name: "test_channel", Key: "sk-test", Status: common.ChannelStatusEnabled} |
| require.NoError(t, model.DB.Create(ch).Error) |
| } |
|
|
| func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task { |
| return &model.Task{ |
| TaskID: "task_" + time.Now().Format("150405.000"), |
| UserId: userId, |
| ChannelId: channelId, |
| Quota: quota, |
| Status: model.TaskStatus(model.TaskStatusInProgress), |
| Group: "default", |
| Data: json.RawMessage(`{}`), |
| CreatedAt: time.Now().Unix(), |
| UpdatedAt: time.Now().Unix(), |
| Properties: model.Properties{ |
| OriginModelName: "test-model", |
| }, |
| PrivateData: model.TaskPrivateData{ |
| BillingSource: billingSource, |
| SubscriptionId: subscriptionId, |
| TokenId: tokenId, |
| BillingContext: &model.TaskBillingContext{ |
| ModelPrice: 0.02, |
| GroupRatio: 1.0, |
| OriginModelName: "test-model", |
| }, |
| }, |
| } |
| } |
|
|
| |
| |
| |
|
|
| func getUserQuota(t *testing.T, id int) int { |
| t.Helper() |
| var user model.User |
| require.NoError(t, model.DB.Select("quota").Where("id = ?", id).First(&user).Error) |
| return user.Quota |
| } |
|
|
| func getTokenRemainQuota(t *testing.T, id int) int { |
| t.Helper() |
| var token model.Token |
| require.NoError(t, model.DB.Select("remain_quota").Where("id = ?", id).First(&token).Error) |
| return token.RemainQuota |
| } |
|
|
| func getTokenUsedQuota(t *testing.T, id int) int { |
| t.Helper() |
| var token model.Token |
| require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&token).Error) |
| return token.UsedQuota |
| } |
|
|
| func getSubscriptionUsed(t *testing.T, id int) int64 { |
| t.Helper() |
| var sub model.UserSubscription |
| require.NoError(t, model.DB.Select("amount_used").Where("id = ?", id).First(&sub).Error) |
| return sub.AmountUsed |
| } |
|
|
| func getLastLog(t *testing.T) *model.Log { |
| t.Helper() |
| var log model.Log |
| err := model.LOG_DB.Order("id desc").First(&log).Error |
| if err != nil { |
| return nil |
| } |
| return &log |
| } |
|
|
| func countLogs(t *testing.T) int64 { |
| t.Helper() |
| var count int64 |
| model.LOG_DB.Model(&model.Log{}).Count(&count) |
| return count |
| } |
|
|
| |
| |
| |
|
|
| func TestRefundTaskQuota_Wallet(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 1, 1, 1 |
| const initQuota, preConsumed = 10000, 3000 |
| const tokenRemain = 5000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-test-key", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
|
|
| RefundTaskQuota(ctx, task, "task failed: upstream error") |
|
|
| |
| assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) |
|
|
| |
| assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) |
| assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID)) |
|
|
| |
| log := getLastLog(t) |
| require.NotNil(t, log) |
| assert.Equal(t, model.LogTypeRefund, log.Type) |
| assert.Equal(t, preConsumed, log.Quota) |
| assert.Equal(t, "test-model", log.ModelName) |
| } |
|
|
| func TestRefundTaskQuota_Subscription(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID, subID = 2, 2, 2, 1 |
| const preConsumed = 2000 |
| const subTotal, subUsed int64 = 100000, 50000 |
| const tokenRemain = 8000 |
|
|
| seedUser(t, userID, 0) |
| seedToken(t, tokenID, userID, "sk-sub-key", tokenRemain) |
| seedChannel(t, channelID) |
| seedSubscription(t, subID, userID, subTotal, subUsed) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) |
|
|
| RefundTaskQuota(ctx, task, "subscription task failed") |
|
|
| |
| assert.Equal(t, subUsed-int64(preConsumed), getSubscriptionUsed(t, subID)) |
|
|
| |
| assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) |
|
|
| log := getLastLog(t) |
| require.NotNil(t, log) |
| assert.Equal(t, model.LogTypeRefund, log.Type) |
| } |
|
|
| func TestRefundTaskQuota_ZeroQuota(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID = 3 |
| seedUser(t, userID, 5000) |
|
|
| task := makeTask(userID, 0, 0, 0, BillingSourceWallet, 0) |
|
|
| RefundTaskQuota(ctx, task, "zero quota task") |
|
|
| |
| assert.Equal(t, 5000, getUserQuota(t, userID)) |
|
|
| |
| assert.Equal(t, int64(0), countLogs(t)) |
| } |
|
|
| func TestRefundTaskQuota_NoToken(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, channelID = 4, 4 |
| const initQuota, preConsumed = 10000, 1500 |
|
|
| seedUser(t, userID, initQuota) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) |
|
|
| RefundTaskQuota(ctx, task, "no token task failed") |
|
|
| |
| assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) |
|
|
| |
| log := getLastLog(t) |
| require.NotNil(t, log) |
| assert.Equal(t, model.LogTypeRefund, log.Type) |
| } |
|
|
| |
| |
| |
|
|
| func TestRecalculate_PositiveDelta(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 10, 10, 10 |
| const initQuota, preConsumed = 10000, 2000 |
| const actualQuota = 3000 |
| const tokenRemain = 5000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
|
|
| RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") |
|
|
| |
| assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID)) |
|
|
| |
| assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID)) |
|
|
| |
| assert.Equal(t, actualQuota, task.Quota) |
|
|
| |
| log := getLastLog(t) |
| require.NotNil(t, log) |
| assert.Equal(t, model.LogTypeConsume, log.Type) |
| assert.Equal(t, actualQuota-preConsumed, log.Quota) |
| } |
|
|
| func TestRecalculate_NegativeDelta(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 11, 11, 11 |
| const initQuota, preConsumed = 10000, 5000 |
| const actualQuota = 3000 |
| const tokenRemain = 5000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
|
|
| RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") |
|
|
| |
| assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) |
|
|
| |
| assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) |
|
|
| |
| assert.Equal(t, actualQuota, task.Quota) |
|
|
| |
| log := getLastLog(t) |
| require.NotNil(t, log) |
| assert.Equal(t, model.LogTypeRefund, log.Type) |
| assert.Equal(t, preConsumed-actualQuota, log.Quota) |
| } |
|
|
| func TestRecalculate_ZeroDelta(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID = 12 |
| const initQuota, preConsumed = 10000, 3000 |
|
|
| seedUser(t, userID, initQuota) |
|
|
| task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0) |
|
|
| RecalculateTaskQuota(ctx, task, preConsumed, "exact match") |
|
|
| |
| assert.Equal(t, initQuota, getUserQuota(t, userID)) |
|
|
| |
| assert.Equal(t, int64(0), countLogs(t)) |
| } |
|
|
| func TestRecalculate_ActualQuotaZero(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID = 13 |
| const initQuota = 10000 |
|
|
| seedUser(t, userID, initQuota) |
|
|
| task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0) |
|
|
| RecalculateTaskQuota(ctx, task, 0, "zero actual") |
|
|
| |
| assert.Equal(t, initQuota, getUserQuota(t, userID)) |
| assert.Equal(t, int64(0), countLogs(t)) |
| } |
|
|
| func TestRecalculate_Subscription_NegativeDelta(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID, subID = 14, 14, 14, 2 |
| const preConsumed = 5000 |
| const actualQuota = 2000 |
| const subTotal, subUsed int64 = 100000, 50000 |
| const tokenRemain = 8000 |
|
|
| seedUser(t, userID, 0) |
| seedToken(t, tokenID, userID, "sk-sub-recalc", tokenRemain) |
| seedChannel(t, channelID) |
| seedSubscription(t, subID, userID, subTotal, subUsed) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) |
|
|
| RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge") |
|
|
| |
| assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID)) |
|
|
| |
| assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) |
|
|
| assert.Equal(t, actualQuota, task.Quota) |
|
|
| log := getLastLog(t) |
| require.NotNil(t, log) |
| assert.Equal(t, model.LogTypeRefund, log.Type) |
| } |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model.TaskStatus, actualQuota int) { |
| snap := task.Snapshot() |
|
|
| shouldRefund := false |
| shouldSettle := false |
| quota := task.Quota |
|
|
| task.Status = newStatus |
| switch string(newStatus) { |
| case model.TaskStatusSuccess: |
| task.Progress = "100%" |
| task.FinishTime = 9999 |
| shouldSettle = true |
| case model.TaskStatusFailure: |
| task.Progress = "100%" |
| task.FinishTime = 9999 |
| task.FailReason = "upstream error" |
| if quota != 0 { |
| shouldRefund = true |
| } |
| default: |
| task.Progress = "50%" |
| } |
|
|
| isDone := task.Status == model.TaskStatus(model.TaskStatusSuccess) || task.Status == model.TaskStatus(model.TaskStatusFailure) |
| if isDone && snap.Status != task.Status { |
| won, err := task.UpdateWithStatus(snap.Status) |
| if err != nil { |
| shouldRefund = false |
| shouldSettle = false |
| } else if !won { |
| shouldRefund = false |
| shouldSettle = false |
| } |
| } else if !snap.Equal(task.Snapshot()) { |
| _, _ = task.UpdateWithStatus(snap.Status) |
| } |
|
|
| if shouldSettle && actualQuota > 0 { |
| RecalculateTaskQuota(ctx, task, actualQuota, "test settle") |
| } |
| if shouldRefund { |
| RefundTaskQuota(ctx, task, task.FailReason) |
| } |
| } |
|
|
| func TestCASGuardedRefund_Win(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 20, 20, 20 |
| const initQuota, preConsumed = 10000, 4000 |
| const tokenRemain = 6000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-cas-refund-win", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
| task.Status = model.TaskStatus(model.TaskStatusInProgress) |
| require.NoError(t, model.DB.Create(task).Error) |
|
|
| simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) |
|
|
| |
| var reloaded model.Task |
| require.NoError(t, model.DB.First(&reloaded, task.ID).Error) |
| assert.EqualValues(t, model.TaskStatusFailure, reloaded.Status) |
|
|
| |
| assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) |
| assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) |
|
|
| log := getLastLog(t) |
| require.NotNil(t, log) |
| assert.Equal(t, model.LogTypeRefund, log.Type) |
| } |
|
|
| func TestCASGuardedRefund_Lose(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 21, 21, 21 |
| const initQuota, preConsumed = 10000, 4000 |
| const tokenRemain = 6000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-cas-refund-lose", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| |
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
| task.Status = model.TaskStatus(model.TaskStatusInProgress) |
| require.NoError(t, model.DB.Create(task).Error) |
|
|
| |
| model.DB.Model(&model.Task{}).Where("id = ?", task.ID).Update("status", model.TaskStatusFailure) |
|
|
| |
| |
| simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusFailure), 0) |
|
|
| |
| assert.Equal(t, initQuota, getUserQuota(t, userID)) |
| assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) |
|
|
| |
| assert.Equal(t, int64(0), countLogs(t)) |
| } |
|
|
| func TestCASGuardedSettle_Win(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 22, 22, 22 |
| const initQuota, preConsumed = 10000, 5000 |
| const actualQuota = 3000 |
| const tokenRemain = 8000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-cas-settle-win", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
| task.Status = model.TaskStatus(model.TaskStatusInProgress) |
| require.NoError(t, model.DB.Create(task).Error) |
|
|
| simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusSuccess), actualQuota) |
|
|
| |
| var reloaded model.Task |
| require.NoError(t, model.DB.First(&reloaded, task.ID).Error) |
| assert.EqualValues(t, model.TaskStatusSuccess, reloaded.Status) |
|
|
| |
| assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) |
| assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) |
|
|
| |
| assert.Equal(t, actualQuota, task.Quota) |
| } |
|
|
| func TestNonTerminalUpdate_NoBilling(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, channelID = 23, 23 |
| const initQuota, preConsumed = 10000, 3000 |
|
|
| seedUser(t, userID, initQuota) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, 0, BillingSourceWallet, 0) |
| task.Status = model.TaskStatus(model.TaskStatusInProgress) |
| task.Progress = "20%" |
| require.NoError(t, model.DB.Create(task).Error) |
|
|
| |
| simulatePollBilling(ctx, task, model.TaskStatus(model.TaskStatusInProgress), 0) |
|
|
| |
| assert.Equal(t, initQuota, getUserQuota(t, userID)) |
|
|
| |
| assert.Equal(t, int64(0), countLogs(t)) |
|
|
| |
| var reloaded model.Task |
| require.NoError(t, model.DB.First(&reloaded, task.ID).Error) |
| assert.Equal(t, "50%", reloaded.Progress) |
| } |
|
|
| |
| |
| |
|
|
| type mockAdaptor struct { |
| adjustReturn int |
| } |
|
|
| func (m *mockAdaptor) Init(_ *relaycommon.RelayInfo) {} |
| func (m *mockAdaptor) FetchTask(string, string, map[string]any, string) (*http.Response, error) { |
| return nil, nil |
| } |
| func (m *mockAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) { return nil, nil } |
| func (m *mockAdaptor) AdjustBillingOnComplete(_ *model.Task, _ *relaycommon.TaskInfo) int { |
| return m.adjustReturn |
| } |
|
|
| |
| |
| |
|
|
| func TestSettle_PerCallBilling_SkipsAdaptorAdjust(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 30, 30, 30 |
| const initQuota, preConsumed = 10000, 5000 |
| const tokenRemain = 8000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-percall-adaptor", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
| task.PrivateData.BillingContext.PerCallBilling = true |
|
|
| adaptor := &mockAdaptor{adjustReturn: 2000} |
| taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} |
|
|
| settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) |
|
|
| |
| assert.Equal(t, initQuota, getUserQuota(t, userID)) |
| assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) |
| assert.Equal(t, preConsumed, task.Quota) |
| assert.Equal(t, int64(0), countLogs(t)) |
| } |
|
|
| func TestSettle_PerCallBilling_SkipsTotalTokens(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 31, 31, 31 |
| const initQuota, preConsumed = 10000, 4000 |
| const tokenRemain = 7000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-percall-tokens", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
| task.PrivateData.BillingContext.PerCallBilling = true |
|
|
| adaptor := &mockAdaptor{adjustReturn: 0} |
| taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess, TotalTokens: 9999} |
|
|
| settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) |
|
|
| |
| assert.Equal(t, initQuota, getUserQuota(t, userID)) |
| assert.Equal(t, tokenRemain, getTokenRemainQuota(t, tokenID)) |
| assert.Equal(t, preConsumed, task.Quota) |
| assert.Equal(t, int64(0), countLogs(t)) |
| } |
|
|
| func TestSettle_NonPerCall_AdaptorAdjustWorks(t *testing.T) { |
| truncate(t) |
| ctx := context.Background() |
|
|
| const userID, tokenID, channelID = 32, 32, 32 |
| const initQuota, preConsumed = 10000, 5000 |
| const adaptorQuota = 3000 |
| const tokenRemain = 8000 |
|
|
| seedUser(t, userID, initQuota) |
| seedToken(t, tokenID, userID, "sk-nonpercall-adj", tokenRemain) |
| seedChannel(t, channelID) |
|
|
| task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) |
| |
|
|
| adaptor := &mockAdaptor{adjustReturn: adaptorQuota} |
| taskResult := &relaycommon.TaskInfo{Status: model.TaskStatusSuccess} |
|
|
| settleTaskBillingOnComplete(ctx, adaptor, task, taskResult) |
|
|
| |
| assert.Equal(t, initQuota+(preConsumed-adaptorQuota), getUserQuota(t, userID)) |
| assert.Equal(t, tokenRemain+(preConsumed-adaptorQuota), getTokenRemainQuota(t, tokenID)) |
| assert.Equal(t, adaptorQuota, task.Quota) |
|
|
| log := getLastLog(t) |
| require.NotNil(t, log) |
| assert.Equal(t, model.LogTypeRefund, log.Type) |
| } |
|
|