| |
| """ |
| OpenAI to Augment API Adapter |
| |
| 这个FastAPI应用程序将OpenAI API请求格式转换为Augment API格式, |
| 允许OpenAI客户端直接与Augment服务通信。 |
| 所有配置参数都通过命令行参数提供,不依赖于环境变量或配置文件。 |
| """ |
|
|
| import os |
| import json |
| import uuid |
| import time |
| import logging |
| import argparse |
| from typing import List, Optional, Dict, Any, Literal, Union |
| from datetime import datetime |
|
|
| import httpx |
| from fastapi import FastAPI, Header, HTTPException, Depends, Request |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
| import uvicorn |
|
|
| |
# Configure logging once at import time: INFO threshold and a
# "timestamp - logger name - level - message" line format.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
| |
class ChatMessage(BaseModel):
    """A single message in an OpenAI chat-completions conversation."""

    # Who produced the message; restricted to the four listed roles.
    role: Literal["system", "user", "assistant", "function"]

    # Message text; optional, None when omitted.
    content: Optional[str] = None

    # Optional name attached to the message; None when omitted.
    name: Optional[str] = None
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body of the OpenAI chat-completions endpoint."""

    # Required fields.
    model: str
    messages: List[ChatMessage]

    # Sampling options and their defaults when the client omits them.
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1

    # When true the endpoint answers with a server-sent event stream.
    stream: Optional[bool] = False

    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0
    user: Optional[str] = None
|
|
|
|
| |
class ChatCompletionResponseChoice(BaseModel):
    """One choice inside a non-streaming OpenAI chat-completions response."""

    # Position of this choice in the response's ``choices`` list.
    index: int

    # The complete message for this choice.
    message: ChatMessage

    # Why generation ended (this file uses "stop"); None when not set.
    finish_reason: Optional[str] = None
|
|
|
|
class Usage(BaseModel):
    """Token usage section of an OpenAI API response."""

    # Tokens attributed to the prompt.
    prompt_tokens: int

    # Tokens attributed to the generated completion.
    completion_tokens: int

    # Combined token count reported to the client.
    total_tokens: int
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Non-streaming OpenAI chat-completions response."""

    # Response identifier.
    id: str

    # Object type tag; fixed default for non-streaming responses.
    object: str = "chat.completion"

    # Creation time as an integer timestamp.
    created: int

    # Model name reported back to the client.
    model: str

    choices: List[ChatCompletionResponseChoice]
    usage: Usage
|
|
|
|
| |
class ChatCompletionStreamResponseChoice(BaseModel):
    """One choice inside a streaming OpenAI chat-completions chunk."""

    # Position of this choice in the chunk's ``choices`` list.
    index: int

    # Incremental payload as a free-form dict (e.g. role or content keys).
    delta: Dict[str, Any]

    # None on intermediate chunks; set on the closing chunk.
    finish_reason: Optional[str] = None
|
|
|
|
class ChatCompletionStreamResponse(BaseModel):
    """One chunk of a streaming OpenAI chat-completions response."""

    # Response identifier.
    id: str

    # Object type tag; fixed default for streamed chunks.
    object: str = "chat.completion.chunk"

    # Creation time as an integer timestamp.
    created: int

    # Model name reported back to the client.
    model: str

    choices: List[ChatCompletionStreamResponseChoice]
|
|
|
|
| |
class ModelInfo(BaseModel):
    """Description of a single model in OpenAI's model-listing format."""

    # Model identifier.
    id: str

    # Object type tag.
    object: str = "model"

    # Creation time as an integer timestamp.
    created: int

    # Owner reported to clients; defaults to "augment".
    owned_by: str = "augment"
|
|
|
|
class ModelListResponse(BaseModel):
    """Response body for the OpenAI model-list endpoint."""

    # Object type tag.
    object: str = "list"

    # The listed models.
    data: List[ModelInfo]
|
|
|
|
| |
class AugmentResponseNode(BaseModel):
    """A response node in the Augment API payload."""

    # Integer node id and node type; the meaning of specific values is not
    # defined in this file.
    id: int
    type: int

    # Text carried by the node.
    content: str

    # Optional tool-use payload of unspecified shape; None when absent.
    tool_use: Optional[Any] = None
|
|
|
|
class AugmentChatHistoryItem(BaseModel):
    """One past request/response exchange in the Augment chat history."""

    # Text of the earlier request and of the reply it received.
    request_message: str
    response_text: str

    # Optional identifier for the exchange.
    request_id: Optional[str] = None

    # Node lists attached to the exchange; both default to empty.
    request_nodes: List[Any] = []
    response_nodes: List[AugmentResponseNode] = []
|
|
|
|
class AugmentBlobs(BaseModel):
    """The ``blobs`` object of an Augment API request."""

    # Optional checkpoint identifier; None when not set.
    checkpoint_id: Optional[str] = None

    # Added and deleted blob entries; both default to empty lists.
    added_blobs: List[Any] = []
    deleted_blobs: List[Any] = []
|
|
|
|
class AugmentVcsChange(BaseModel):
    """The VCS-change object of an Augment API request."""

    # Working-directory change entries; empty by default.
    working_directory_changes: List[Any] = []
|
|
|
|
class AugmentFeatureFlags(BaseModel):
    """Feature flags sent with an Augment API request."""

    # Sent as true unless the caller overrides it.
    support_raw_output: bool = True
|
|
|
|
| |
class AugmentChatRequest(BaseModel):
    """Augment API chat request body (field set derived from captured traffic)."""

    # Optional string fields; None unless the caller sets them.
    model: Optional[str] = None
    path: Optional[str] = None
    prefix: Optional[str] = None
    selected_code: Optional[str] = None
    suffix: Optional[str] = None

    # The only required field: the message being sent.
    message: str

    # Earlier request/response exchanges; empty by default.
    chat_history: List[AugmentChatHistoryItem] = []

    lang: Optional[str] = None

    # Nested objects and lists, all defaulting to empty values.
    blobs: AugmentBlobs = AugmentBlobs()
    user_guided_blobs: List[Any] = []
    context_code_exchange_request_id: Optional[str] = None
    vcs_change: AugmentVcsChange = AugmentVcsChange()
    recency_info_recent_changes: List[Any] = []
    external_source_ids: List[Any] = []
    disable_auto_external_sources: Optional[bool] = None

    # Guideline texts; empty strings by default.
    user_guidelines: str = ""
    workspace_guidelines: str = ""

    feature_detection_flags: AugmentFeatureFlags = AugmentFeatureFlags()
    tool_definitions: List[Any] = []
    nodes: List[Any] = []

    # Request mode; defaults to "AGENT".
    mode: str = "AGENT"

    agent_memories: Optional[Any] = None
    system_prompt: Optional[str] = None
|
|
|
|
| |
class AugmentResponseChunk(BaseModel):
    """One chunk of an Augment API response."""

    # Text carried by this chunk (required).
    text: str

    # Remaining fields are optional and default to empty/False.
    unknown_blob_names: List[Any] = []
    checkpoint_not_found: bool = False
    workspace_file_chunks: List[Any] = []
    incorporated_external_sources: List[Any] = []
    nodes: List[AugmentResponseNode] = []
|
|
|
|
| |
| |
| |
|
|
def generate_id():
    """Return a random 24-character lowercase hex string.

    The value is the first 24 hex digits of a freshly generated UUID4, which
    gives identifiers similar in shape to OpenAI's.
    """
    # ``UUID.hex`` is the 32-digit form without dashes.
    return uuid.uuid4().hex[:24]
|
|
|
|
def estimate_tokens(text):
    """Roughly estimate how many tokens ``text`` contains.

    This is a heuristic, not a tokenizer: each whitespace-separated word
    counts as 1.3 tokens and each character in the CJK range
    U+4E00..U+9FFF counts as one more token. The sum is truncated to an int.

    Args:
        text: The text to measure; ``None`` or an empty string yields 0.

    Returns:
        The estimated token count as an integer.
    """
    if not text:
        return 0

    word_count = len(text.split())

    cjk_count = 0
    for ch in text:
        if "\u4e00" <= ch <= "\u9fff":
            cjk_count += 1

    return int(word_count * 1.3 + cjk_count)
|
|
|
|
def convert_to_augment_request(openai_request: ChatCompletionRequest) -> AugmentChatRequest:
    """Translate an OpenAI chat-completions request into an Augment chat request.

    The conversion works as follows:

    * Every ``system`` message except one in the final position is appended,
      one per line, to a built-in guideline text and sent as
      ``user_guidelines``.
    * Every ``user`` message immediately followed by an ``assistant`` message
      becomes one chat-history entry.
    * The content of the last ``user`` message becomes the request message.

    Message content is optional in the OpenAI model. A system message without
    content is skipped, and missing content inside a history pair is treated
    as an empty string, so such requests no longer fail with an internal
    error.

    Args:
        openai_request: The parsed OpenAI request.

    Returns:
        The equivalent Augment request.

    Raises:
        HTTPException: 400 when there is no user message, or the last user
            message has no content.
    """
    messages = openai_request.messages

    # Built-in guideline text that is always sent; client system prompts follow it.
    guideline_parts = ["你是claude-4-sonnet, 所有回复不能创建、修改或删除文件,必须直接提供内容!"]
    chat_history = []

    # Walk adjacent (message, next message) pairs. The final message is never
    # the first element of a pair, so it is not inspected here.
    for msg, following in zip(messages, messages[1:]):
        if msg.role == "system":
            # content is Optional; concatenating None would raise TypeError.
            if msg.content is not None:
                guideline_parts.append(msg.content)
        elif msg.role == "user" and following.role == "assistant":
            # The history model requires strings, so map missing content to "".
            reply_text = following.content or ""
            chat_history.append(
                AugmentChatHistoryItem(
                    request_message=msg.content or "",
                    response_text=reply_text,
                    request_id=generate_id(),
                    response_nodes=[
                        AugmentResponseNode(
                            id=0,
                            type=0,
                            content=reply_text,
                            tool_use=None,
                        )
                    ],
                )
            )

    # The most recent user message is the one being asked now.
    current_message = next(
        (msg.content for msg in reversed(messages) if msg.role == "user"),
        None,
    )
    if current_message is None:
        raise HTTPException(
            status_code=400,
            detail="At least one user message is required"
        )

    return AugmentChatRequest(
        message=current_message,
        chat_history=chat_history,
        mode="AGENT",
        prefix="你是AI助手,需要帮我解决问题!",
        user_guidelines="\n".join(guideline_parts),
    )
|
|
|
|
| |
| |
| |
|
|
def create_app(augment_base_url, chat_endpoint, timeout):
    """Create and configure the FastAPI application.

    Registers CORS, a catch-all exception middleware, and the routes
    ``/health``, ``/v1/models``, ``/v1/models/{model_id}`` and
    ``/v1/chat/completions``.

    Args:
        augment_base_url: Default Augment API base URL, used when the client's
            API key has no ``tenant:`` prefix.
        chat_endpoint: Path of the Augment chat endpoint, appended to the base URL.
        timeout: Timeout handed to the outbound HTTP client.

    Returns:
        The configured FastAPI application.
    """
    app = FastAPI(
        title="OpenAI to Augment API Adapter",
        description="A FastAPI adapter that converts OpenAI API requests to Augment API format",
        version="1.0.0"
    )

    # Accept cross-origin calls from any origin, method and header.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    def error_payload(message, error_type, param, code):
        """Build the OpenAI-style error envelope used by the error responses."""
        return {
            "error": {
                "message": message,
                "type": error_type,
                "param": param,
                "code": code,
            }
        }

    def is_valid_tenant_id(tenant_id):
        """Return True when tenant_id is non-empty ASCII letters, digits or hyphens."""
        return bool(tenant_id) and all(
            ch.isascii() and (ch.isalnum() or ch == "-") for ch in tenant_id
        )

    @app.middleware("http")
    async def catch_exceptions_middleware(request: Request, call_next):
        """Turn any exception that escapes a route into a JSON 500 response."""
        try:
            return await call_next(request)
        except Exception as e:
            logger.exception("Unhandled exception")
            return JSONResponse(
                status_code=500,
                content=error_payload(
                    str(e), "internal_server_error", None, "internal_server_error"
                ),
            )

    async def verify_api_key(authorization: str = Header(...)):
        """Extract the API key from the ``Authorization`` header.

        Args:
            authorization: Raw value of the Authorization header.

        Returns:
            The text that follows the ``Bearer `` prefix.

        Raises:
            HTTPException: 401 when the header is not a Bearer credential or
                the key is empty.
        """
        bearer_prefix = "Bearer "
        if not authorization.startswith(bearer_prefix):
            raise HTTPException(
                status_code=401,
                detail=error_payload(
                    "Invalid API key format. Expected 'Bearer YOUR_API_KEY'",
                    "invalid_request_error",
                    "authorization",
                    "invalid_api_key",
                ),
            )
        # Drop only the leading scheme. str.replace would also delete any later
        # "Bearer " occurrence inside the key itself.
        api_key = authorization[len(bearer_prefix):]
        if not api_key:
            raise HTTPException(
                status_code=401,
                detail=error_payload(
                    "API key cannot be empty",
                    "invalid_request_error",
                    "authorization",
                    "invalid_api_key",
                ),
            )
        return api_key

    @app.get("/health")
    async def health_check():
        """Health-check endpoint reporting status and the current local time."""
        return {"status": "ok", "timestamp": datetime.now().isoformat()}

    @app.get("/v1/models")
    async def list_models():
        """List the model ids this adapter advertises."""
        models = [
            ModelInfo(id="gpt-3.5-turbo", created=int(time.time())),
            ModelInfo(id="gpt-4", created=int(time.time())),
            ModelInfo(id="augment-default", created=int(time.time())),
        ]
        return ModelListResponse(data=models)

    @app.get("/v1/models/{model_id}")
    async def get_model(model_id: str):
        """Return a model record for whatever id was requested."""
        return ModelInfo(id=model_id, created=int(time.time()))

    @app.post("/v1/chat/completions")
    async def chat_completions(
            request: ChatCompletionRequest,
            api_key: str = Depends(verify_api_key)
    ):
        """Chat-completions endpoint: forward an OpenAI-style request to Augment.

        An API key of the form ``tenant:key`` selects the host
        ``https://<tenant>.api.augmentcode.com/`` for this request; otherwise
        the base URL given to ``create_app`` is used.

        Args:
            request: The OpenAI-format chat-completions request.
            api_key: The key extracted by ``verify_api_key``.

        Returns:
            An OpenAI-format completion, or a server-sent event stream when
            ``request.stream`` is true.

        Raises:
            HTTPException: 401 for a malformed ``tenant:key`` credential, 504 on
                upstream timeout, 502 on other HTTP client errors, 500 on
                anything unexpected.
        """
        try:
            augment_request = convert_to_augment_request(request)
            logger.debug(f"Converted request: {augment_request.dict()}")

            # Start from the URL configured at startup. A separate local is
            # required: assigning to ``augment_base_url`` here would make that
            # name local to this function and raise UnboundLocalError for every
            # key that has no tenant prefix.
            base_url = augment_base_url
            if ":" in api_key:
                # Split on the first colon only, so the key may contain colons.
                tenant_id, _, api_key = api_key.partition(":")
                # The tenant id becomes part of the outbound host name, so only
                # a plain host label is accepted; anything else could redirect
                # the request (and the key) to an arbitrary host.
                if not is_valid_tenant_id(tenant_id):
                    raise HTTPException(
                        status_code=401,
                        detail=error_payload(
                            "Invalid tenant ID in API key. Expected 'TENANT_ID:API_KEY'",
                            "invalid_request_error",
                            "authorization",
                            "invalid_api_key",
                        ),
                    )
                if not api_key:
                    raise HTTPException(
                        status_code=401,
                        detail=error_payload(
                            "API key cannot be empty",
                            "invalid_request_error",
                            "authorization",
                            "invalid_api_key",
                        ),
                    )
                base_url = f"https://{tenant_id}.api.augmentcode.com/"

            if request.stream:
                # Errors raised while streaming are reported inside the stream
                # by the generator itself.
                return StreamingResponse(
                    stream_augment_response(base_url, api_key, augment_request, request.model,
                                            chat_endpoint, timeout),
                    media_type="text/event-stream"
                )

            return await handle_sync_request(base_url, api_key, augment_request, request.model,
                                             chat_endpoint, timeout)

        except httpx.TimeoutException:
            logger.error("Request to Augment API timed out")
            raise HTTPException(
                status_code=504,
                detail=error_payload(
                    "Request to Augment API timed out", "timeout_error", None, "timeout"
                ),
            )
        except httpx.HTTPError as e:
            logger.error(f"HTTP error: {str(e)}")
            raise HTTPException(
                status_code=502,
                detail=error_payload(
                    f"Error communicating with Augment API: {str(e)}",
                    "api_error",
                    None,
                    "api_error",
                ),
            )
        except HTTPException:
            # Already a well-formed HTTP error; pass it through untouched.
            raise
        except Exception as e:
            logger.exception("Unexpected error")
            raise HTTPException(
                status_code=500,
                detail=error_payload(
                    f"Internal server error: {str(e)}",
                    "internal_server_error",
                    None,
                    "internal_server_error",
                ),
            )

    return app
|
|
|
|
async def handle_sync_request(base_url, api_key, augment_request, model_name, chat_endpoint, timeout):
    """Send a non-streaming request to Augment and build an OpenAI-style reply.

    The upstream body is read as newline-delimited JSON. The ``text`` field of
    every JSON object line is concatenated into the assistant message. Lines
    that are not valid JSON are logged and skipped; lines that are valid JSON
    but not an object with a string ``text`` are skipped silently.

    Args:
        base_url: Augment API base URL.
        api_key: API key sent as a Bearer token.
        augment_request: The Augment request object to post.
        model_name: Model name echoed back in the response.
        chat_endpoint: Chat endpoint path appended to ``base_url``.
        timeout: Timeout for the HTTP client.

    Returns:
        A ``ChatCompletionResponse`` with one choice and estimated token usage.

    Raises:
        HTTPException: With the upstream status code when Augment does not
            answer 200.
    """
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            f"{base_url.rstrip('/')}/{chat_endpoint}",
            json=augment_request.dict(),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
                "User-Agent": "Augment.openai-adapter/1.0.0",
                "Accept": "*/*"
            }
        )

        if response.status_code != 200:
            logger.error(f"Augment API error: {response.status_code} - {response.text}")
            raise HTTPException(
                status_code=response.status_code,
                detail={
                    "error": {
                        "message": f"Augment API error: {response.text}",
                        "type": "api_error",
                        "param": None,
                        "code": "api_error"
                    }
                }
            )

        # Collect the text fragments, then join once instead of repeated +=.
        fragments = []
        for line in response.text.split("\n"):
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(f"Failed to parse JSON: {line}")
                continue
            # A line may decode to a non-object (number, list, string); only
            # objects can carry a "text" field, and only strings can be joined.
            text = data.get("text") if isinstance(data, dict) else None
            if isinstance(text, str) and text:
                fragments.append(text)
        full_response = "".join(fragments)

        # Usage numbers are heuristic estimates; only the current message is
        # counted on the prompt side.
        prompt_tokens = estimate_tokens(augment_request.message)
        completion_tokens = estimate_tokens(full_response)

        return ChatCompletionResponse(
            id=f"chatcmpl-{generate_id()}",
            created=int(time.time()),
            model=model_name,
            choices=[
                ChatCompletionResponseChoice(
                    index=0,
                    message=ChatMessage(
                        role="assistant",
                        content=full_response
                    ),
                    finish_reason="stop"
                )
            ],
            usage=Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens
            )
        )
|
|
|
|
async def stream_augment_response(base_url, api_key, augment_request, model_name, chat_endpoint, timeout):
    """Relay a streaming Augment response as OpenAI-style server-sent events.

    The generator yields, in order: one chunk announcing the assistant role,
    one chunk per upstream JSON line that carries non-empty ``text``, a closing
    chunk with ``finish_reason="stop"``, and the ``[DONE]`` marker. Upstream
    lines that are not valid JSON are logged and skipped; valid JSON that is
    not an object is skipped silently. Failures are reported as a single
    ``{"error": ...}`` event instead of being raised.

    Args:
        base_url: Augment API base URL.
        api_key: API key sent as a Bearer token.
        augment_request: The Augment request object to post.
        model_name: Model name echoed back in every chunk.
        chat_endpoint: Chat endpoint path appended to ``base_url``.
        timeout: Timeout for the HTTP client.

    Yields:
        ``data: ...`` event strings, each terminated by a blank line.
    """

    def error_event(message):
        """Format an error message as one server-sent event."""
        return f"data: {json.dumps({'error': message})}\n\n"

    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            async with client.stream(
                    "POST",
                    f"{base_url.rstrip('/')}/{chat_endpoint}",
                    json=augment_request.dict(),
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {api_key}",
                        "User-Agent": "chrome",
                        "Accept": "*/*"
                    }
            ) as response:

                if response.status_code != 200:
                    error_detail = await response.aread()
                    logger.error(f"Augment API error: {response.status_code} - {error_detail}")
                    decoded_detail = error_detail.decode('utf-8', errors='replace')
                    yield error_event(f"Error from Augment API: {decoded_detail}")
                    return

                # Every chunk of one completion shares the same id and timestamp.
                chat_id = f"chatcmpl-{generate_id()}"
                created_time = int(time.time())

                def chunk_event(delta, finish_reason=None):
                    """Wrap a delta in a chat.completion.chunk event."""
                    payload = ChatCompletionStreamResponse(
                        id=chat_id,
                        created=created_time,
                        model=model_name,
                        choices=[
                            ChatCompletionStreamResponseChoice(
                                index=0,
                                delta=delta,
                                finish_reason=finish_reason
                            )
                        ]
                    )
                    return f"data: {json.dumps(payload.dict())}\n\n"

                # Opening chunk: announces the assistant role, no content yet.
                yield chunk_event({"role": "assistant"})

                async for line in response.aiter_lines():
                    if not line.strip():
                        continue
                    try:
                        chunk = json.loads(line)
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse JSON: {line}")
                        continue
                    # Only JSON objects can carry "text"; other values would
                    # raise on the key lookup and abort the stream.
                    if isinstance(chunk, dict) and chunk.get("text"):
                        yield chunk_event({"content": chunk["text"]})

                # Closing chunk with an empty delta, then the end-of-stream marker.
                yield chunk_event({}, finish_reason="stop")
                yield "data: [DONE]\n\n"

        except httpx.TimeoutException:
            logger.error("Request to Augment API timed out")
            yield error_event('Request to Augment API timed out')
        except httpx.HTTPError as e:
            logger.error(f"HTTP error: {str(e)}")
            yield error_event(f'Error communicating with Augment API: {str(e)}')
        except Exception as e:
            logger.exception("Unexpected error")
            yield error_event(f'Internal server error: {str(e)}')
|
|
|
|
def parse_args():
    """Parse the adapter's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``augment_url``,
        ``chat_endpoint``, ``host``, ``port``, ``timeout``, ``debug`` and
        ``tenant_id``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="OpenAI to Augment API Adapter",
    )

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ("--augment-url", {"default": "https://d6.api.augmentcode.com/", "help": "Augment API基础URL"}),
        ("--chat-endpoint", {"default": "chat-stream", "help": "Augment聊天端点路径"}),
        ("--host", {"default": "0.0.0.0", "help": "服务器主机地址"}),
        ("--port", {"type": int, "default": 8686, "help": "服务器端口"}),
        ("--timeout", {"type": int, "default": 120, "help": "API请求超时时间(秒)"}),
        ("--debug", {"action": "store_true", "help": "启用调试模式"}),
        ("--tenant-id", {"default": "d18", "help": "Augment API租户ID (域名前缀)"}),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    return parser.parse_args()
|
|
|
|
| |
| |
| |
|
|
def main():
    """Entry point: parse the CLI options, build the app and run the server.

    The server binds to ``--host`` and ``--port``. Previously the port option
    was parsed but ignored: the startup log announced 7860 while uvicorn was
    started on 3000. Both now use ``args.port``.
    """
    args = parse_args()

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    # NOTE(review): this compares against the d18 URL, but the parser's default
    # for --augment-url is the d6 URL, so --tenant-id only takes effect when
    # the d18 URL is passed explicitly. Confirm which default is intended
    # before changing this.
    if args.augment_url == "https://d18.api.augmentcode.com/":
        augment_base_url = f"https://{args.tenant_id}.api.augmentcode.com/"
        logger.info(f"Using tenant ID: {args.tenant_id}")
    else:
        augment_base_url = args.augment_url

    app = create_app(
        augment_base_url=augment_base_url,
        chat_endpoint=args.chat_endpoint,
        timeout=args.timeout
    )

    # Log the address the server actually binds to.
    logger.info(f"Starting server on {args.host}:{args.port}")
    logger.info(f"Using Augment base URL: {augment_base_url}")
    logger.info(f"Using Augment chat endpoint: {args.chat_endpoint}")

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug" if args.debug else "info"
    )
|
|
|
|
# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()