| package main |
|
|
| import ( |
| "bytes" |
| "context" |
| "encoding/json" |
| "fmt" |
| "io" |
| "log" |
| "net/http" |
| "os" |
| "strings" |
| "sync" |
| "time" |
|
|
| "github.com/gin-gonic/gin" |
| "github.com/google/generative-ai-go/genai" |
| "google.golang.org/api/option" |
| ) |
|
|
| |
// Config holds the runtime settings that loadConfig reads from environment
// variables.
type Config struct {
	// AnthropicKey and GoogleKey are the API keys for the Claude and Gemini
	// token-counting backends; an empty value means that backend is unusable.
	AnthropicKey, GoogleKey string
	// ServiceURL is this service's own base URL, pinged periodically by
	// startKeepAlive; an empty value disables the keep-alive loop.
	ServiceURL string
	// DeepseekURL and OpenAIURL are base URLs of the token-counting services
	// that receive POST <url>/count_tokens requests.
	DeepseekURL, OpenAIURL string
}

// config is the process-wide configuration, populated once by loadConfig.
var config Config

// configOnce guarantees loadConfig reads the environment only on first use.
var configOnce sync.Once
|
|
| |
// Wire types shared by the HTTP handler and the backend clients.
type (
	// TokenCountRequest is the JSON body accepted by POST /count_tokens. The
	// Claude, Deepseek and OpenAI clients re-marshal it and send it upstream.
	// System is a pointer so that an absent system prompt is omitted from
	// the encoded JSON.
	TokenCountRequest struct {
		Model    string    `json:"model" binding:"required"`
		Messages []Message `json:"messages" binding:"required"`
		System   *string   `json:"system,omitempty"`
	}

	// Message is a single chat turn. Role is expected to be "user" or
	// "assistant"; countTokensWithClaude logs a warning for any other role.
	Message struct {
		Role    string `json:"role" binding:"required"`
		Content string `json:"content" binding:"required"`
	}

	// TokenCountResponse is the success body returned to clients. It is also
	// used to decode the Claude and Deepseek backend replies.
	TokenCountResponse struct {
		InputTokens int `json:"input_tokens"`
	}

	// ErrorResponse is the JSON body used for plain error replies.
	ErrorResponse struct {
		Error string `json:"error"`
	}
)
|
|
| |
// ModelRule maps a model name to Target when the lower-cased name contains
// every entry of Keywords as a substring.
type ModelRule struct {
	Keywords []string
	Target   string
}

// modelRules is the fallback table that matchModelName consults after its
// hard-coded checks. Rules are tried in order and the first match wins, so a
// more specific rule must precede any broader rule it overlaps with.
//
// NOTE(review): matching is by plain substring, so digit keywords such as
// "3", "5" or "7" can also match dates embedded in model IDs.
// NOTE(review): "gemini 2.5" maps to gemini-2.0-flash — confirm intentional.
var modelRules = []ModelRule{
	{Keywords: []string{"deepseek"}, Target: "deepseek-v3"},

	// Claude 3.x family.
	{Keywords: []string{"claude", "3", "5", "sonnet"}, Target: "claude-3-5-sonnet-latest"},
	{Keywords: []string{"claude", "3", "5", "haiku"}, Target: "claude-3-5-haiku-latest"},
	{Keywords: []string{"claude", "3", "7"}, Target: "claude-3-7-sonnet-latest"},
	{Keywords: []string{"claude", "3", "opus"}, Target: "claude-3-opus-latest"},
	{Keywords: []string{"claude", "3", "haiku"}, Target: "claude-3-haiku-20240307"},

	{Keywords: []string{"claude", "3", "sonnet"}, Target: "claude-3-sonnet-20240229"},

	// Gemini family.
	{Keywords: []string{"gemini", "2.0"}, Target: "gemini-2.0-flash"},
	{Keywords: []string{"gemini", "2.5"}, Target: "gemini-2.0-flash"},
	{Keywords: []string{"gemini", "1.5"}, Target: "gemini-1.5-flash"},
}
|
|
| |
| func matchModelName(input string) string { |
| |
| input = strings.ToLower(input) |
| log.Printf("正在匹配模型名称: %s", input) |
|
|
| |
| if strings.Contains(input, "claude") && strings.Contains(input, "3.5") || |
| strings.Contains(input, "claude") && strings.Contains(input, "3") && strings.Contains(input, "5") { |
| if strings.Contains(input, "sonnet") { |
| log.Printf("匹配到Claude 3.5 Sonnet") |
| return "claude-3-5-sonnet-latest" |
| } else if strings.Contains(input, "haiku") { |
| log.Printf("匹配到Claude 3.5 Haiku") |
| return "claude-3-5-haiku-latest" |
| } else { |
| |
| log.Printf("匹配到Claude 3.5 (默认使用Sonnet)") |
| return "claude-3-5-sonnet-latest" |
| } |
| } |
|
|
| |
| if strings.Contains(input, "claude") && strings.Contains(input, "3.7") || |
| strings.Contains(input, "claude") && strings.Contains(input, "3") && strings.Contains(input, "7") { |
| log.Printf("匹配到Claude 3.7") |
| return "claude-3-7-sonnet-latest" |
| } |
|
|
| |
| if (strings.Contains(input, "gpt") && strings.Contains(input, "4o")) || |
| strings.Contains(input, "o1") || |
| strings.Contains(input, "o3") { |
| log.Printf("匹配到GPT-4o") |
| return "gpt-4o" |
| } |
|
|
| |
| if (strings.Contains(input, "gpt") && strings.Contains(input, "3") && strings.Contains(input, "5")) || |
| (strings.Contains(input, "gpt") && strings.Contains(input, "4") && !strings.Contains(input, "4o")) { |
| log.Printf("匹配到GPT-4") |
| return "gpt-4" |
| } |
|
|
| |
| for _, rule := range modelRules { |
| matches := true |
| for _, keyword := range rule.Keywords { |
| if !strings.Contains(input, strings.ToLower(keyword)) { |
| matches = false |
| break |
| } |
| } |
| if matches { |
| log.Printf("通过规则匹配到: %s", rule.Target) |
| return rule.Target |
| } |
| } |
|
|
| |
| log.Printf("没有匹配到任何规则,使用原始输入: %s", input) |
| return input |
| } |
|
|
| |
| func loadConfig() Config { |
| configOnce.Do(func() { |
| |
| log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile) |
| log.Println("开始加载配置...") |
|
|
| config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY") |
| if config.AnthropicKey == "" { |
| log.Println("警告: ANTHROPIC_API_KEY 环境变量未设置,Claude模型将无法使用") |
| } else { |
| log.Println("Anthropic API Key已配置") |
| } |
|
|
| config.GoogleKey = os.Getenv("GOOGLE_API_KEY") |
| if config.GoogleKey == "" { |
| log.Println("警告: GOOGLE_API_KEY 环境变量未设置,Gemini模型将无法使用") |
| } else { |
| log.Println("Google API Key已配置") |
| } |
|
|
| |
| config.DeepseekURL = os.Getenv("DEEPSEEK_URL") |
| if config.DeepseekURL == "" { |
| config.DeepseekURL = "http://127.0.0.1:7861" |
| log.Println("使用默认Deepseek服务地址:", config.DeepseekURL) |
| } else { |
| log.Println("使用配置的Deepseek服务地址:", config.DeepseekURL) |
| } |
|
|
| |
| config.OpenAIURL = os.Getenv("OPENAI_URL") |
| if config.OpenAIURL == "" { |
| config.OpenAIURL = "http://127.0.0.1:7862" |
| log.Println("使用默认OpenAI服务地址:", config.OpenAIURL) |
| } else { |
| log.Println("使用配置的OpenAI服务地址:", config.OpenAIURL) |
| } |
|
|
| |
| config.ServiceURL = os.Getenv("SERVICE_URL") |
| if config.ServiceURL == "" { |
| log.Println("SERVICE_URL 未设置,防休眠功能将被禁用") |
| } else { |
| log.Println("防休眠URL已配置:", config.ServiceURL) |
| } |
|
|
| log.Println("配置加载完成") |
| }) |
| return config |
| } |
|
|
| |
| func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) { |
| |
| log.Printf("开始Claude API请求: 模型=%s, 消息数量=%d", req.Model, len(req.Messages)) |
|
|
| |
| var filteredMessages []Message |
| for i, msg := range req.Messages { |
| if msg.Content == "" { |
| log.Printf("警告: 消息 #%d 内容为空,将被过滤掉", i) |
| continue |
| } |
| if msg.Role != "user" && msg.Role != "assistant" { |
| log.Printf("警告: 消息 #%d 角色'%s'不是标准角色(user/assistant),可能导致请求失败", i, msg.Role) |
| } |
| filteredMessages = append(filteredMessages, msg) |
| } |
|
|
| if len(filteredMessages) == 0 { |
| log.Printf("错误: 过滤后没有有效消息") |
| return TokenCountResponse{}, fmt.Errorf("没有有效消息:所有消息内容都为空") |
| } |
|
|
| |
| filteredReq := TokenCountRequest{ |
| Model: req.Model, |
| Messages: filteredMessages, |
| System: req.System, |
| } |
|
|
| |
| if len(filteredMessages) != len(req.Messages) { |
| log.Printf("消息过滤: 原始消息数=%d, 过滤后消息数=%d", len(req.Messages), len(filteredMessages)) |
| } |
|
|
| client := &http.Client{} |
| data, err := json.Marshal(filteredReq) |
| if err != nil { |
| log.Printf("错误: 序列化Claude请求失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) |
| } |
|
|
| |
| if len(data) < 1000 { |
| log.Printf("Claude请求内容: %s", string(data)) |
| } else { |
| log.Printf("Claude请求内容较大,长度=%d字节", len(data)) |
| } |
|
|
| |
| request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data)) |
| if err != nil { |
| log.Printf("错误: 创建Claude请求失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) |
| } |
|
|
| |
| request.Header.Set("x-api-key", config.AnthropicKey) |
| request.Header.Set("anthropic-version", "2023-06-01") |
| request.Header.Set("content-type", "application/json") |
|
|
| |
| response, err := client.Do(request) |
| if err != nil { |
| log.Printf("错误: 发送请求到Anthropic API失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("发送请求到Anthropic API失败: %v", err) |
| } |
| defer response.Body.Close() |
|
|
| |
| if response.StatusCode != http.StatusOK { |
| |
| var errorBody []byte |
| errorBody, _ = io.ReadAll(response.Body) |
| log.Printf("错误: Claude API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) |
|
|
| |
| errorStr := string(errorBody) |
| if response.StatusCode == http.StatusUnauthorized || strings.Contains(errorStr, "invalid_api_key") { |
| log.Printf("错误: Claude API密钥无效或过期") |
| return TokenCountResponse{}, fmt.Errorf("Claude API验证失败,请检查API Key是否有效: %s", string(errorBody)) |
| } else if response.StatusCode == http.StatusBadRequest { |
| if strings.Contains(errorStr, "empty content") { |
| log.Printf("错误: 请求包含空内容的消息") |
| return TokenCountResponse{}, fmt.Errorf("请求格式错误: 消息不能有空内容: %s", string(errorBody)) |
| } else if strings.Contains(errorStr, "invalid_request_error") { |
| log.Printf("错误: 无效的请求格式") |
| return TokenCountResponse{}, fmt.Errorf("无效的请求格式: %s", string(errorBody)) |
| } |
| } |
|
|
| return TokenCountResponse{}, fmt.Errorf("Claude API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) |
| } |
|
|
| |
| var result TokenCountResponse |
| if err := json.NewDecoder(response.Body).Decode(&result); err != nil { |
| log.Printf("错误: 解码Claude响应失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) |
| } |
|
|
| log.Printf("Claude API请求成功: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens) |
| return result, nil |
| } |
|
|
| |
| func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) { |
| |
| log.Printf("开始Gemini API请求: 模型=%s, 消息数量=%d", req.Model, len(req.Messages)) |
|
|
| if config.GoogleKey == "" { |
| log.Printf("错误: Gemini API密钥未设置") |
| return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY 未设置") |
| } |
|
|
| |
| ctx := context.Background() |
| client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey)) |
| if err != nil { |
| log.Printf("错误: 创建Gemini客户端失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("创建Gemini客户端失败: %v", err) |
| } |
| defer client.Close() |
|
|
| |
| modelName := req.Model |
| log.Printf("使用Gemini模型: %s", modelName) |
|
|
| |
| model := client.GenerativeModel(modelName) |
|
|
| |
| var content string |
| if req.System != nil && *req.System != "" { |
| content += *req.System + "\n\n" |
| log.Printf("Gemini请求包含系统提示: %s", *req.System) |
| } |
|
|
| for _, msg := range req.Messages { |
| if msg.Role == "user" { |
| content += "用户: " + msg.Content + "\n" |
| } else if msg.Role == "assistant" { |
| content += "助手: " + msg.Content + "\n" |
| } else { |
| content += msg.Role + ": " + msg.Content + "\n" |
| } |
| } |
|
|
| |
| log.Printf("开始计算Gemini tokens...") |
| tokResp, err := model.CountTokens(ctx, genai.Text(content)) |
| if err != nil { |
| log.Printf("错误: 计算Gemini token失败: %v", err) |
| if strings.Contains(err.Error(), "invalid_api_key") || strings.Contains(err.Error(), "permission_denied") { |
| log.Printf("错误: Gemini API密钥可能无效或过期") |
| } |
| return TokenCountResponse{}, fmt.Errorf("计算Gemini token失败: %v", err) |
| } |
|
|
| log.Printf("Gemini API请求成功: 模型=%s, 输入tokens=%d", req.Model, tokResp.TotalTokens) |
| return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil |
| } |
|
|
| |
| func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) { |
| log.Printf("开始Deepseek API请求: 模型=%s, 消息数量=%d, 服务地址=%s", req.Model, len(req.Messages), config.DeepseekURL) |
|
|
| |
| client := &http.Client{} |
| data, err := json.Marshal(req) |
| if err != nil { |
| log.Printf("错误: 序列化Deepseek请求失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) |
| } |
|
|
| |
| requestURL := config.DeepseekURL + "/count_tokens" |
| log.Printf("发送请求到Deepseek服务: %s", requestURL) |
| request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data)) |
| if err != nil { |
| log.Printf("错误: 创建Deepseek请求失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) |
| } |
|
|
| |
| request.Header.Set("Content-Type", "application/json") |
|
|
| |
| response, err := client.Do(request) |
| if err != nil { |
| log.Printf("错误: 发送请求到Deepseek服务失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("发送请求到Deepseek服务失败: %v", err) |
| } |
| defer response.Body.Close() |
|
|
| |
| if response.StatusCode != http.StatusOK { |
| |
| var errorBody []byte |
| errorBody, _ = io.ReadAll(response.Body) |
| log.Printf("错误: Deepseek API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) |
| return TokenCountResponse{}, fmt.Errorf("Deepseek API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) |
| } |
|
|
| |
| var result TokenCountResponse |
| if err := json.NewDecoder(response.Body).Decode(&result); err != nil { |
| log.Printf("错误: 解码Deepseek响应失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) |
| } |
|
|
| log.Printf("Deepseek API请求成功: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens) |
| return result, nil |
| } |
|
|
| |
| func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) { |
| log.Printf("开始OpenAI API请求: 模型=%s, 消息数量=%d, 服务地址=%s", req.Model, len(req.Messages), config.OpenAIURL) |
|
|
| |
| client := &http.Client{} |
| data, err := json.Marshal(req) |
| if err != nil { |
| log.Printf("错误: 序列化OpenAI请求失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) |
| } |
|
|
| |
| requestURL := config.OpenAIURL + "/count_tokens" |
| log.Printf("发送请求到OpenAI服务: %s", requestURL) |
| request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data)) |
| if err != nil { |
| log.Printf("错误: 创建OpenAI请求失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) |
| } |
|
|
| |
| request.Header.Set("Content-Type", "application/json") |
|
|
| |
| response, err := client.Do(request) |
| if err != nil { |
| log.Printf("错误: 发送请求到OpenAI服务失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("发送请求到OpenAI服务失败: %v", err) |
| } |
| defer response.Body.Close() |
|
|
| |
| if response.StatusCode != http.StatusOK { |
| |
| var errorBody []byte |
| errorBody, _ = io.ReadAll(response.Body) |
| log.Printf("错误: OpenAI API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) |
| return TokenCountResponse{}, fmt.Errorf("OpenAI API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) |
| } |
|
|
| |
| var result struct { |
| InputTokens int `json:"input_tokens"` |
| Model string `json:"model"` |
| Encoding string `json:"encoding"` |
| } |
| if err := json.NewDecoder(response.Body).Decode(&result); err != nil { |
| log.Printf("错误: 解码OpenAI响应失败: %v", err) |
| return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) |
| } |
|
|
| log.Printf("OpenAI API请求成功: 模型=%s(实际使用=%s), 编码=%s, 输入tokens=%d", |
| req.Model, result.Model, result.Encoding, result.InputTokens) |
| return TokenCountResponse{InputTokens: result.InputTokens}, nil |
| } |
|
|
| |
| func countTokens(c *gin.Context) { |
| var req TokenCountRequest |
| if err := c.ShouldBindJSON(&req); err != nil { |
| log.Printf("错误: 无效的请求格式: %v", err) |
| c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()}) |
| return |
| } |
|
|
| |
| systemPrompt := "无" |
| if req.System != nil && *req.System != "" { |
| systemPrompt = *req.System |
| } |
| log.Printf("收到token计算请求: 原始模型=%s, 消息数量=%d, 系统提示=%s", |
| req.Model, len(req.Messages), systemPrompt) |
|
|
| |
| originalModel := req.Model |
|
|
| |
| isUnsupportedModel := true |
|
|
| |
| modelLower := strings.ToLower(req.Model) |
| if strings.Contains(modelLower, "gpt") || strings.Contains(modelLower, "openai") || |
| strings.Contains(modelLower, "o1") || strings.Contains(modelLower, "o3") || |
| strings.HasPrefix(modelLower, "claude") || |
| strings.Contains(modelLower, "gemini") || |
| strings.Contains(modelLower, "deepseek") { |
| isUnsupportedModel = false |
| } |
|
|
| |
| req.Model = matchModelName(req.Model) |
| log.Printf("模型名称匹配结果: 原始=%s -> 匹配=%s", originalModel, req.Model) |
|
|
| var result TokenCountResponse |
| var err error |
|
|
| |
| if strings.Contains(strings.ToLower(req.Model), "deepseek") { |
| log.Printf("使用Deepseek API计算token") |
| |
| result, err = countTokensWithDeepseek(req) |
| } else if strings.Contains(strings.ToLower(req.Model), "gpt") || strings.Contains(strings.ToLower(req.Model), "openai") { |
| log.Printf("使用OpenAI API计算token") |
| |
| result, err = countTokensWithOpenAI(req) |
| } else if strings.HasPrefix(strings.ToLower(req.Model), "claude") { |
| log.Printf("使用Claude API计算token") |
| |
| if config.AnthropicKey == "" { |
| log.Printf("错误: ANTHROPIC_API_KEY未设置") |
| c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY 未设置,无法使用Claude模型"}) |
| return |
| } |
| result, err = countTokensWithClaude(req) |
| } else if strings.Contains(strings.ToLower(req.Model), "gemini") { |
| log.Printf("使用Gemini API计算token") |
| |
| if config.GoogleKey == "" { |
| log.Printf("错误: GOOGLE_API_KEY未设置") |
| c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY 未设置,无法使用Gemini模型"}) |
| return |
| } |
| result, err = countTokensWithGemini(req) |
| } else if isUnsupportedModel { |
| log.Printf("不支持的模型: %s, 将使用GPT-4o估算", originalModel) |
| |
| |
| gptReq := req |
| gptReq.Model = "gpt-4o" |
|
|
| |
| estimatedResult, estimateErr := countTokensWithOpenAI(gptReq) |
|
|
| if estimateErr == nil { |
| log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens) |
| |
| c.JSON(http.StatusBadRequest, gin.H{ |
| "input_tokens": estimatedResult.InputTokens, |
| "warning": fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel), |
| "estimated_with": "gpt-4o", |
| "error": fmt.Sprintf("Unsupported model: %s", originalModel), |
| }) |
| return |
| } else { |
| log.Printf("使用GPT-4o估算失败: %v", estimateErr) |
| c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("Failed to estimate tokens for unsupported model: %s", originalModel)}) |
| return |
| } |
| } else { |
| log.Printf("完全不支持的模型: %s, 将尝试使用GPT-4o估算", originalModel) |
| |
| |
| gptReq := req |
| gptReq.Model = "gpt-4o" |
|
|
| estimatedResult, estimateErr := countTokensWithOpenAI(gptReq) |
| if estimateErr == nil { |
| log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens) |
| c.JSON(http.StatusBadRequest, gin.H{ |
| "input_tokens": estimatedResult.InputTokens, |
| "warning": fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel), |
| "estimated_with": "gpt-4o", |
| "error": fmt.Sprintf("Unsupported model: %s", originalModel), |
| }) |
| } else { |
| log.Printf("使用GPT-4o估算失败: %v", estimateErr) |
| c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)}) |
| } |
| return |
| } |
|
|
| if err != nil { |
| log.Printf("计算token失败: %v", err) |
|
|
| |
| log.Printf("API调用失败,尝试使用GPT-4o估算: 原始模型=%s, 错误=%v", req.Model, err) |
|
|
| |
| gptReq := req |
| gptReq.Model = "gpt-4o" |
|
|
| |
| estimatedResult, estimateErr := countTokensWithOpenAI(gptReq) |
|
|
| if estimateErr == nil { |
| log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens) |
|
|
| |
| c.JSON(http.StatusBadRequest, gin.H{ |
| "input_tokens": estimatedResult.InputTokens, |
| "warning": fmt.Sprintf("Token calculation for model '%s' failed. This is an estimation based on gpt-4o and may not be accurate.", originalModel), |
| "estimated_with": "gpt-4o", |
| "error": err.Error(), |
| }) |
| return |
| } else { |
| log.Printf("使用GPT-4o估算也失败: %v", estimateErr) |
| |
| c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()}) |
| return |
| } |
| } |
|
|
| |
| log.Printf("成功计算token: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens) |
| c.JSON(http.StatusOK, result) |
| } |
|
|
| |
| func healthCheck(c *gin.Context) { |
| c.JSON(http.StatusOK, gin.H{ |
| "status": "healthy", |
| "time": time.Now().Format(time.RFC3339), |
| }) |
| } |
|
|
| |
| func startKeepAlive() { |
| if config.ServiceURL == "" { |
| return |
| } |
|
|
| healthURL := fmt.Sprintf("%s/health", config.ServiceURL) |
| ticker := time.NewTicker(10 * time.Hour) |
|
|
| |
| go func() { |
| log.Printf("Starting keep-alive checks to %s", healthURL) |
| for { |
| resp, err := http.Get(healthURL) |
| if err != nil { |
| log.Printf("Keep-alive check failed: %v", err) |
| } else { |
| resp.Body.Close() |
| log.Printf("Keep-alive check successful") |
| } |
|
|
| |
| <-ticker.C |
| } |
| }() |
| } |
|
|
| func main() { |
| |
| loadConfig() |
| log.Println("=== Token计算服务启动 ===") |
|
|
| |
| gin.SetMode(gin.ReleaseMode) |
| log.Println("设置Gin为发布模式") |
|
|
| |
| r := gin.Default() |
| log.Println("创建Gin路由") |
|
|
| |
| r.Use(gin.Recovery()) |
| r.Use(func(c *gin.Context) { |
| |
| startTime := time.Now() |
|
|
| |
| log.Printf("收到请求: %s %s 来自 %s", c.Request.Method, c.Request.URL.Path, c.ClientIP()) |
|
|
| c.Writer.Header().Set("Access-Control-Allow-Origin", "*") |
| c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") |
| c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type") |
| if c.Request.Method == "OPTIONS" { |
| c.AbortWithStatus(204) |
| return |
| } |
|
|
| |
| c.Next() |
|
|
| |
| endTime := time.Now() |
| latency := endTime.Sub(startTime) |
|
|
| |
| log.Printf("请求完成: %s %s 状态=%d 耗时=%v", |
| c.Request.Method, c.Request.URL.Path, c.Writer.Status(), latency) |
| }) |
|
|
| |
| r.GET("/health", healthCheck) |
| r.POST("/count_tokens", countTokens) |
| log.Println("配置路由: GET /health, POST /count_tokens") |
|
|
| |
| port := os.Getenv("PORT") |
| if port == "" { |
| port = "7860" |
| log.Println("使用默认端口: 7860") |
| } else { |
| log.Println("使用配置端口:", port) |
| } |
|
|
| |
| startKeepAlive() |
|
|
| |
| log.Printf("=== 服务器启动在端口 %s ===", port) |
| if err := r.Run(":" + port); err != nil { |
| log.Fatalf("服务器启动失败: %v", err) |
| } |
| } |
|
|