| from email import message |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.responses import HTMLResponse, FileResponse |
| import uvicorn |
| import json |
| import asyncio |
| import os |
| from pathlib import Path |
| from datetime import datetime |
| from bw_utils import get_grandchild_folders, is_image, load_json_file |
| from BookWorld import BookWorld |
# Resolve the directory containing this file *before* changing into it, so the
# result is correct even when __file__ is a relative path.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Relative paths used in this module (config.json, ./frontend/..., index.html)
# are resolved against this directory.
os.chdir(_BASE_DIR)


app = FastAPI()
default_icon_path = './frontend/assets/images/default-icon.jpg'
config = load_json_file('config.json')
# Preset file name without directory or extension
# (".../experiment_x.json" -> "experiment_x"). `preset_path` may be absent or
# empty when `genre` is configured instead, so read it defensively rather than
# raising KeyError at import time.
experiment_name = os.path.basename(config.get("preset_path") or "").split(".")[0]

# Export every non-empty config entry whose key contains "API_KEY" as an
# environment variable (os.environ only accepts string values).
for key, value in config.items():
    if "API_KEY" in key and value:
        os.environ[key] = str(value)

static_file_abspath = os.path.join(_BASE_DIR, 'frontend')
app.mount("/frontend", StaticFiles(directory=static_file_abspath), name="frontend")

# Directory listed by /api/list-presets and read by /api/load-preset.
PRESETS_DIR = os.path.join(_BASE_DIR, 'experiment_presets')
|
|
class ConnectionManager:
    """Track WebSocket clients and run one story-generation task per client.

    All clients share a single world object, ``self.bw``, which is built from
    the module-level ``config`` when the manager is created.
    """

    # Development switch: when True, use the ``BookWorld_test`` stub instead of
    # building a real BookWorld from ``config``.
    USE_TEST_BACKEND = False

    def __init__(self):
        # client_id -> accepted WebSocket for that client.
        self.active_connections: dict[str, WebSocket] = {}
        # client_id -> background task running generate_story() for that client.
        self.story_tasks: dict[str, asyncio.Task] = {}

        if self.USE_TEST_BACKEND:
            from BookWorld_test import BookWorld_test
            self.bw = BookWorld_test()
            return

        preset_path = self._resolve_preset_path()
        self.bw = BookWorld(preset_path=preset_path,
                            world_llm_name=config["world_llm_name"],
                            role_llm_name=config["role_llm_name"],
                            embedding_name=config["embedding_model_name"])
        self.bw.set_generator(rounds=config["rounds"],
                              save_dir=config["save_dir"],
                              if_save=config["if_save"],
                              mode=config["mode"],
                              scene_mode=config["scene_mode"])

    @staticmethod
    def _resolve_preset_path():
        """Pick the preset file from ``config``: ``preset_path`` first, else ``genre``.

        Returns:
            The configured ``preset_path`` if set, otherwise
            ``<this dir>/config/experiment_<genre>.json``.

        Raises:
            ValueError: if ``preset_path`` is set but does not exist, or if
                neither ``preset_path`` nor ``genre`` is configured.
        """
        preset_path = config.get("preset_path")
        if preset_path:
            if not os.path.exists(preset_path):
                raise ValueError(f"The preset path {preset_path} does not exist.")
            return preset_path
        genre = config.get("genre")
        if genre:
            return os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                f"./config/experiment_{genre}.json")
        raise ValueError("Please set the preset_path in `config.json`.")

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the handshake and register ``websocket`` under ``client_id``.

        A socket already registered under the same id is replaced.
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Forget ``client_id``'s socket (if registered) and cancel its story task."""
        self.active_connections.pop(client_id, None)
        self.stop_story(client_id)

    def stop_story(self, client_id: str):
        """Cancel and unregister ``client_id``'s story task; no-op if none."""
        task = self.story_tasks.pop(client_id, None)
        if task is not None:
            task.cancel()

    async def start_story(self, client_id: str):
        """(Re)start continuous story generation for ``client_id``.

        Any task already running for this client is cancelled first.
        """
        self.stop_story(client_id)
        self.story_tasks[client_id] = asyncio.create_task(
            self.generate_story(client_id)
        )

    async def generate_story(self, client_id: str):
        """Coroutine that keeps generating story messages for one client.

        Each round sends a ``message`` frame followed by a ``status_update``
        frame, then sleeps 2 seconds. It stops when the client is no longer
        in ``active_connections``, when the task is cancelled, or on any
        error. Errors are printed, not propagated.
        """
        try:
            while client_id in self.active_connections:
                message, status = await self.get_next_message()
                websocket = self.active_connections[client_id]
                await websocket.send_json({
                    'type': 'message',
                    'data': message
                })
                await websocket.send_json({
                    'type': 'status_update',
                    'data': status
                })
                await asyncio.sleep(2)
        except asyncio.CancelledError:
            print(f"Story generation cancelled for client {client_id}")
        except Exception as e:
            print(f"Error in generate_story: {e}")
        finally:
            # Drop our own registry entry once we stop (e.g. after an error),
            # but never an entry a newer task for this client has taken over.
            if self.story_tasks.get(client_id) is asyncio.current_task():
                del self.story_tasks[client_id]

    async def get_initial_data(self):
        """Collect the snapshot sent to a client when it connects.

        Returns:
            dict with ``characters``, ``map``, ``settings`` and ``status``
            taken from ``self.bw``, plus an always-empty ``history_messages``
            list.
        """
        return {
            'characters': self.bw.get_characters_info(),
            'map': self.bw.get_map_info(),
            'settings': self.bw.get_settings_info(),
            'status': self.bw.get_current_status(),
            'history_messages': [],
        }

    async def get_next_message(self):
        """Fetch the next generated message from BookWorld, plus the current status.

        The message's ``icon`` is replaced with ``default_icon_path`` when it
        is missing or empty, does not exist on disk, or is not an image.

        Returns:
            ``(message, status)`` tuple.
        """
        # NOTE(review): generate_next_message() runs synchronously on the
        # event loop, so no other request or WebSocket frame is handled while
        # it runs. If it is slow (presumably LLM calls), consider offloading
        # it, after confirming BookWorld is safe to call from another thread.
        message = self.bw.generate_next_message()
        icon = message.get("icon")
        if not icon or not os.path.exists(icon) or not is_image(icon):
            message["icon"] = default_icon_path
        status = self.bw.get_current_status()
        return message, status
|
|
# Single process-wide ConnectionManager used by the route handlers below.
# Constructing it builds the BookWorld from `config` at import time, so a bad
# preset/genre configuration fails here, before the server starts serving.
manager = ConnectionManager()
|
|
@app.get("/")
async def get():
    """Return the frontend page, read from ``index.html`` in the working directory."""
    page_source = Path("index.html").read_text(encoding="utf-8")
    return HTMLResponse(page_source)
|
|
@app.get("/data/{full_path:path}")
async def get_file(full_path: str):
    """Serve a file stored under one of the data base directories.

    Every base path is tried in order. A request that does not name an
    existing file *inside* a base directory falls back to the default icon.
    This includes attempts to escape the directory via ``..`` segments or an
    absolute path.

    Raises:
        HTTPException: 404 if nothing matched and the default icon file
            itself is missing.
    """
    base_paths = [
        Path("/data/")
    ]

    for base_path in base_paths:
        # Collapse "..", "." and absolute components lexically so the
        # containment check below sees the path that would actually be opened.
        file_path = Path(os.path.normpath(base_path / full_path))
        try:
            file_path.relative_to(base_path)
        except ValueError:
            # Resolves outside this base directory: never serve it.
            continue
        if file_path.is_file():
            return FileResponse(file_path)

    # Fall back only after every base path has been tried.
    if os.path.isfile(default_icon_path):
        return FileResponse(default_icon_path)
    raise HTTPException(status_code=404, detail="File not found")
|
|
@app.get("/api/list-presets")
async def list_presets():
    """List the ``.json`` preset files in ``PRESETS_DIR``.

    Returns:
        ``{"presets": [file names]}``, sorted so clients see a stable order
        (``os.listdir`` order is arbitrary). Only regular files are included.

    Raises:
        HTTPException: 500 if the presets directory cannot be read.
    """
    try:
        entries = os.listdir(PRESETS_DIR)
    except OSError as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    presets = sorted(
        name for name in entries
        if name.endswith('.json') and os.path.isfile(os.path.join(PRESETS_DIR, name))
    )
    return {"presets": presets}
|
|
@app.post("/api/load-preset")
async def load_preset(request: Request):
    """Replace the shared BookWorld with one built from a named preset.

    Expects a JSON body ``{"preset": "<file name in PRESETS_DIR>"}``. On
    success, ``manager.bw``, ``config["preset_path"]`` and the module-level
    ``experiment_name`` are updated. The new world's initial data is
    returned as ``{"success": True, "data": ...}``.

    Raises:
        HTTPException: 400 for a non-JSON body or a missing/invalid preset
            name, 404 if the preset file does not exist, 500 if the new
            BookWorld fails to initialize.
    """
    global experiment_name

    try:
        data = await request.json()
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(status_code=400, detail="Invalid JSON body") from e

    preset_name = data.get('preset') if isinstance(data, dict) else None
    if not preset_name:
        raise HTTPException(status_code=400, detail="No preset specified")
    # Accept only bare file names inside PRESETS_DIR (what /api/list-presets
    # returns). Absolute paths, ".." or sub-paths could read arbitrary files.
    if (not isinstance(preset_name, str)
            or preset_name in ('.', '..')
            or os.path.basename(preset_name) != preset_name):
        raise HTTPException(status_code=400, detail="Invalid preset name")

    preset_path = os.path.join(PRESETS_DIR, preset_name)
    print(f"Loading preset from: {preset_path}")

    if not os.path.exists(preset_path):
        raise HTTPException(status_code=404, detail=f"Preset not found: {preset_path}")

    try:
        # Build and configure the new world completely before publishing it,
        # so a failure leaves the current manager.bw and config untouched.
        new_bw = BookWorld(
            preset_path=preset_path,
            world_llm_name=config["world_llm_name"],
            role_llm_name=config["role_llm_name"],
            embedding_name=config["embedding_model_name"]
        )
        new_bw.set_generator(
            rounds=config["rounds"],
            save_dir=config["save_dir"],
            if_save=config["if_save"],
            mode=config["mode"],
            scene_mode=config["scene_mode"]
        )

        manager.bw = new_bw
        config["preset_path"] = preset_path
        experiment_name = os.path.basename(preset_path).split(".")[0]

        initial_data = await manager.get_initial_data()
    except Exception as e:
        print(f"Error initializing BookWorld: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error initializing BookWorld: {str(e)}") from e

    return {
        "success": True,
        "data": initial_data
    }
|
|
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Serve one client's WebSocket session.

    Sends an ``initial_data`` frame on connect, then dispatches each incoming
    JSON object by its ``type`` field:

    - ``user_message``: echoed back to this client.
    - ``control``: ``start``, ``pause`` or ``stop`` story generation.
    - ``edit_message``: forwarded to ``manager.bw.handle_message_edit``.
    - ``request_scene_characters``: selects a scene and replies with its
      characters.
    - ``generate_story``: replies with ``manager.bw.generate_story()``.

    Frames that are not JSON objects are skipped instead of ending the session.
    """
    await manager.connect(websocket, client_id)
    try:
        initial_data = await manager.get_initial_data()
        await websocket.send_json({
            'type': 'initial_data',
            'data': initial_data
        })

        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                print(f"Ignoring non-JSON frame from client {client_id}")
                continue
            if not isinstance(message, dict):
                print(f"Ignoring non-object frame from client {client_id}")
                continue

            msg_type = message.get('type')

            if msg_type == 'user_message':
                # Echo the user's text back to this client with the default icon.
                await websocket.send_json({
                    'type': 'message',
                    'data': {
                        'username': 'User',
                        'timestamp': message['timestamp'],
                        'text': message['text'],
                        'icon': default_icon_path,
                    }
                })

            elif msg_type == 'control':
                action = message.get('action')
                if action == 'start':
                    await manager.start_story(client_id)
                elif action in ('pause', 'stop'):
                    # Pause and stop both cancel the generation task.
                    manager.stop_story(client_id)

            elif msg_type == 'edit_message':
                edit_data = message['data']
                manager.bw.handle_message_edit(
                    record_id=edit_data['uuid'],
                    new_text=edit_data['text']
                )

            elif msg_type == 'request_scene_characters':
                manager.bw.select_scene(message['scene'])
                scene_characters = manager.bw.get_characters_info()
                await websocket.send_json({
                    'type': 'scene_characters',
                    'data': scene_characters
                })

            elif msg_type == 'generate_story':
                story_text = manager.bw.generate_story()
                await websocket.send_json({
                    'type': 'message',
                    'data': {
                        'username': 'System',
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'text': story_text,
                        'icon': default_icon_path,
                        'type': 'story'
                    }
                })

    except WebSocketDisconnect:
        # Normal client-initiated close: nothing to report.
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Only tear down if this socket is still the one registered for the
        # id. A reconnect reusing client_id may already have replaced it, and
        # its session (and story task) must not be killed.
        if manager.active_connections.get(client_id) is websocket:
            manager.disconnect(client_id)
|
|
@app.post("/api/save-config")
async def save_config(request: Request):
    """Switch both LLM roles to a new model and store its provider API key.

    Expects JSON ``{"provider": ..., "model": ..., "apiKey": ...}``. Both
    ``config['role_llm_name']`` and ``config['world_llm_name']`` are set to
    ``model``. The key is exported to the provider's environment variable,
    matched by a case-insensitive substring of ``provider``: openai,
    anthropic, alibaba (DASHSCOPE) or openrouter. Unrecognized providers
    leave the environment unchanged. Finally the running BookWorld's server
    is asked to reset its LLMs.

    Raises:
        HTTPException: 400 if the body is not an object with all required
            fields, 500 on any other failure.
    """
    # (substring of provider name, env var for its API key). Checked in
    # order; the first substring found in the provider name wins.
    provider_env_vars = (
        ('openai', 'OPENAI_API_KEY'),
        ('anthropic', 'ANTHROPIC_API_KEY'),
        ('alibaba', 'DASHSCOPE_API_KEY'),
        ('openrouter', 'OPENROUTER_API_KEY'),
    )
    try:
        config_data = await request.json()

        required_fields = ('provider', 'model', 'apiKey')
        if (not isinstance(config_data, dict)
                or any(field not in config_data for field in required_fields)):
            raise HTTPException(status_code=400, detail="缺少必要的字段")

        llm_provider = config_data['provider']
        model = config_data['model']
        api_key = config_data['apiKey']
        # `config` is mutated in place, so no `global` declaration is needed.
        config['role_llm_name'] = model
        config['world_llm_name'] = model

        provider = llm_provider.lower()
        for needle, env_var in provider_env_vars:
            if needle in provider:
                os.environ[env_var] = api_key
                break

        manager.bw.server.reset_llm(model, model)
        return {"status": "success", "message": llm_provider + " 配置已保存"}

    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) reach the client
        # with their own status instead of being rewrapped as a 500.
        raise
    except Exception as e:
        print(f"保存配置失败: {e}")
        raise HTTPException(status_code=500, detail="保存配置失败") from e
|
|
if __name__ == "__main__":
    # uvicorn's reload mode requires an import string rather than the app
    # object. Derive the module name from this file so that renaming it does
    # not break startup. The working directory is this file's directory
    # (set at import), so the module is importable by name.
    uvicorn.run(f"{Path(__file__).stem}:app", host="0.0.0.0", port=8000, reload=True)
|
|