File size: 4,686 Bytes
daa8246 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | package relay
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/logger"
relaycommon "github.com/QuantumNous/new-api/relay/common"
"github.com/QuantumNous/new-api/relay/helper"
"github.com/QuantumNous/new-api/service"
"github.com/QuantumNous/new-api/setting/model_setting"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// ImageHelper relays an image request through the adaptor selected by
// info.ApiType and then settles quota for it.
//
// Steps, in order:
//  1. Deep-copy the *dto.ImageRequest held in info.Request and apply model
//     mapping to the copy.
//  2. Build the upstream body. When pass-through is enabled (globally or on the
//     channel) the stored client body is forwarded as-is. Otherwise the
//     adaptor converts the request; a *bytes.Buffer result is sent unchanged,
//     anything else is marshalled and has info.ParamOverride applied.
//  3. Send the request. Any status other than 200 is turned into an error,
//     except 201 on the Replicate API type, which is treated as success.
//  4. Let the adaptor handle the response, then consume quota. If the reported
//     usage has zero total or prompt tokens, the image count is used instead.
//
// It returns nil on success, or a *types.NewAPIError describing the first
// failure. Errors produced from the upstream response have their status code
// remapped via the "status_code_mapping" context value.
func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
	info.InitChannelMeta(c)

	imageReq, ok := info.Request.(*dto.ImageRequest)
	if !ok {
		return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
	}

	// Work on a copy so model mapping does not mutate the request stored on info.
	request, err := common.DeepCopy(imageReq)
	if err != nil {
		return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
	}

	err = helper.ModelMappedHelper(c, info, request)
	if err != nil {
		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
	}

	adaptor := GetAdaptor(info.ApiType)
	if adaptor == nil {
		return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
	}
	adaptor.Init(info)

	var requestBody io.Reader
	if model_setting.GetGlobalSettings().PassThroughRequestEnabled || info.ChannelSetting.PassThroughBodyEnabled {
		// Pass-through: forward the stored client body without conversion.
		storage, err := common.GetBodyStorage(c)
		if err != nil {
			return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
		}
		requestBody = common.ReaderOnly(storage)
	} else {
		convertedRequest, err := adaptor.ConvertImageRequest(c, info, *request)
		if err != nil {
			return types.NewError(err, types.ErrorCodeConvertRequestFailed)
		}
		relaycommon.AppendRequestConversionFromRequest(info, convertedRequest)

		switch body := convertedRequest.(type) {
		case *bytes.Buffer:
			// The adaptor already produced the final body; send it unchanged
			// (param overrides are not applied on this path).
			requestBody = body
		default:
			jsonData, err := common.Marshal(convertedRequest)
			if err != nil {
				return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
			}
			// Apply param overrides to the marshalled body.
			if len(info.ParamOverride) > 0 {
				jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info)
				if err != nil {
					return newAPIErrorFromParamOverride(err)
				}
			}
			if common.DebugEnabled {
				logger.LogDebug(c, fmt.Sprintf("image request body: %s", jsonData))
			}
			requestBody = bytes.NewBuffer(jsonData)
		}
	}

	statusCodeMappingStr := c.GetString("status_code_mapping")

	resp, err := adaptor.DoRequest(c, info, requestBody)
	if err != nil {
		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
	}

	var httpResp *http.Response
	if resp != nil {
		// Checked assertion: an unexpected dynamic type becomes an error
		// instead of a panic.
		httpResp, ok = resp.(*http.Response)
		if !ok {
			return types.NewOpenAIError(fmt.Errorf("unexpected response type %T from adaptor", resp), types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
		}
	}
	// Guard on the pointer itself: a typed-nil *http.Response inside a non-nil
	// interface passes the check above but must not be dereferenced.
	if httpResp != nil {
		info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream")
		if httpResp.StatusCode != http.StatusOK {
			if httpResp.StatusCode == http.StatusCreated && info.ApiType == constant.APITypeReplicate {
				// The Replicate channel returns 201 Created when using
				// "Prefer: wait"; treat it as success.
				httpResp.StatusCode = http.StatusOK
			} else {
				newAPIError = service.RelayErrorHandler(c.Request.Context(), httpResp, false)
				// Remap the status code per the channel's mapping.
				service.ResetStatusCode(newAPIError, statusCodeMappingStr)
				return newAPIError
			}
		}
	}

	usage, newAPIError := adaptor.DoResponse(c, httpResp, info)
	if newAPIError != nil {
		// Remap the status code per the channel's mapping.
		service.ResetStatusCode(newAPIError, statusCodeMappingStr)
		return newAPIError
	}

	// Assert the usage once. DoResponse has already succeeded here, so a nil
	// or unexpected usage value falls back to an empty usage (billed per image
	// below) rather than panicking and skipping quota consumption.
	imageUsage, ok := usage.(*dto.Usage)
	if !ok || imageUsage == nil {
		imageUsage = &dto.Usage{}
	}

	imageN := uint(1)
	if request.N != nil {
		imageN = *request.N
	}
	// When no token counts were reported, use the image count in their place.
	if imageUsage.TotalTokens == 0 {
		imageUsage.TotalTokens = int(imageN)
	}
	if imageUsage.PromptTokens == 0 {
		imageUsage.PromptTokens = int(imageN)
	}

	// Only "hd" is logged as such; every other value is logged as "standard".
	quality := "standard"
	if request.Quality == "hd" {
		quality = "hd"
	}

	var logContent []string
	if len(request.Size) > 0 {
		logContent = append(logContent, fmt.Sprintf("大小 %s", request.Size))
	}
	// quality is never empty here, so it is always logged.
	logContent = append(logContent, fmt.Sprintf("品质 %s", quality))
	if imageN > 0 {
		logContent = append(logContent, fmt.Sprintf("生成数量 %d", imageN))
	}

	postConsumeQuota(c, info, imageUsage, logContent...)
	return nil
}
|