File size: 10,459 Bytes
3a5cf48 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 | import json
import socket
import struct
import sys
import threading
import os
import atexit
from typing import Callable, Any, Union, Tuple, Optional, List
class StatusServer:
    """Serve JSON status snapshots to every connected client over a stream socket.

    Each message is framed as a 4-byte little-endian length header followed by
    that many bytes of JSON. A client receives ``get_status_func(True)`` right
    after connecting, then every payload pushed by ``update_status()``.
    """

    def __init__(self, get_status_func: Callable[[bool], Any], server_address: Union[str, Tuple[str, int]]):
        """
        Initialize the server; no socket is opened until start_server() runs.

        :param get_status_func: returns the current status. It is called with ``True``
            to build the snapshot sent to a newly connected client, and with ``False``
            when update_status() is called without explicit data. The result must be
            JSON-serializable.
        :param server_address: Unix-domain socket path (str) or (host, port) tuple (TCP).
        """
        self.get_status_func = get_status_func
        self.server_address = server_address
        self.clients: List[socket.socket] = []
        # Guards self.clients. update_status() also sends while holding it, so
        # concurrent update_status() calls never interleave their writes.
        self.lock = threading.Lock()
        self.running = False
        self.server_socket: Optional[socket.socket] = None

    @staticmethod
    def _pack(status: Any) -> bytes:
        """Encode *status* as JSON and prepend its 4-byte little-endian length."""
        payload = json.dumps(status).encode()
        return len(payload).to_bytes(4, 'little') + payload

    @staticmethod
    def _close_socket(sock: socket.socket) -> None:
        """Shut down and close *sock*.

        shutdown() is issued first because close() alone may not wake another
        thread that is blocked in recv()/accept() on the same socket.
        """
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # best effort: peer already gone, or socket not connected (e.g. a listener)
        sock.close()

    def handle_client(self, client_socket: socket.socket) -> None:
        """Serve one connected client until it disconnects or the server stops.

        Sends the initial status and then registers the socket in self.clients so
        that update_status() reaches it. After that it blocks reading from the
        client, ignoring incoming data, only to detect disconnection. The socket
        is always closed and unregistered on exit.
        """
        try:
            packed_data = self._pack(self.get_status_func(True))
            try:
                client_socket.sendall(packed_data)
            except OSError as e:
                print(f"Failed to send update to client: {e}")
                return  # the finally clause closes the socket
            with self.lock:
                self.clients.append(client_socket)
            # Keep the connection open for later updates. recv() returning b''
            # (or failing) means the client has gone away.
            while self.running:
                try:
                    data = client_socket.recv(1024)
                except OSError:
                    break
                if not data:
                    break
        finally:
            client_socket.close()
            with self.lock:
                if client_socket in self.clients:
                    self.clients.remove(client_socket)

    def _bind_listener(self) -> socket.socket:
        """Create, bind and listen on the server socket; close it again if any step fails."""
        is_unix = isinstance(self.server_address, str)
        if is_unix:
            # Remove whatever file already exists at the path (e.g. a stale socket
            # from an earlier run). Re-raise only if it is still there afterwards.
            try:
                os.unlink(self.server_address)
            except OSError:
                if os.path.exists(self.server_address):
                    raise
        listener = socket.socket(socket.AF_UNIX if is_unix else socket.AF_INET, socket.SOCK_STREAM)
        try:
            if not is_unix:
                listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listener.bind(self.server_address)
            listener.listen(5)
        except BaseException:
            listener.close()  # do not leak the socket when bind/listen fails
            raise
        return listener

    def start_server(self) -> None:
        """Listen on server_address and serve clients until stop() is called or Ctrl-C.

        Blocks the calling thread. Each accepted client is handled by
        handle_client() on a new thread, and stop() runs on exit. If binding
        fails, the error propagates and the server is left not running.
        """
        self.server_socket = self._bind_listener()
        self.running = True
        print(f"Server is listening on {self.server_address}...")
        try:
            while self.running:
                # Local copy: stop() may reset self.server_socket from another thread.
                listener = self.server_socket
                if listener is None:
                    break
                try:
                    client_socket, _ = listener.accept()
                except OSError:
                    if not self.running:
                        break  # the listening socket was closed by stop()
                    raise
                print("Client connected")
                # Serve the client on its own thread
                client_thread = threading.Thread(target=self.handle_client, args=(client_socket,))
                client_thread.start()
        except KeyboardInterrupt:
            print("Shutting down server...")
        finally:
            self.stop()

    def stop(self) -> None:
        """Stop the server and release its resources; a no-op if it is not running.

        Closes every client and the listening socket, then removes the Unix
        socket file if server_address is a path.
        """
        if not self.running:
            return
        self.running = False
        with self.lock:
            for client in self.clients:
                self._close_socket(client)
            self.clients.clear()
        if self.server_socket:
            self._close_socket(self.server_socket)
            self.server_socket = None
        # Clean up the Unix socket file
        if isinstance(self.server_address, str) and os.path.exists(self.server_address):
            os.remove(self.server_address)
            print(f"Socket file removed: {self.server_address}")

    def update_status(self, update_data: Optional[dict] = None) -> None:
        """Push a status snapshot to every connected client.

        :param update_data: status to send. If None, get_status_func(False) is
            called instead. An explicit empty dict is sent as-is.

        Clients whose send fails are closed and dropped from self.clients.
        """
        new_status = self.get_status_func(False) if update_data is None else update_data
        packed_data = self._pack(new_status)
        with self.lock:
            # Iterate over a snapshot: failed clients are removed from self.clients
            # inside the loop, and mutating the list being iterated would skip the
            # next client.
            for client in list(self.clients):
                print("Sending update to client...")
                try:
                    client.sendall(packed_data)
                except OSError as e:
                    print(f"Failed to send update to client: {e}")
                    client.close()
                    self.clients.remove(client)
class StatusClient:
    """Client for StatusServer that receives framed JSON status messages on a background thread.

    Each message is a 4-byte little-endian length header followed by that many
    bytes of JSON. Every decoded message is passed to ``callback``.
    """

    def __init__(self, server_address, callback=None):
        """
        Initialize the client; no connection is made until connect() is called.

        :param server_address: Unix-domain socket path (str) or TCP address tuple (host, port).
        :param callback: called from the receive thread with each decoded status
            (a dict when the server sends a JSON object).
        """
        self.server_address = server_address
        self.callback = callback
        self.sock: Optional[socket.socket] = None
        self.running = False
        self.receive_thread: Optional[threading.Thread] = None

    def connect(self):
        """Connect to the server and start the daemon receive thread.

        :raises OSError: if the connection fails. The socket created for the
            attempt is closed and ``self.sock`` is left unchanged.
        """
        if isinstance(self.server_address, str):
            print("Connecting to Unix socket...", self.server_address)
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        else:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect(self.server_address)
        except BaseException:
            sock.close()  # do not leak a half-open socket
            raise
        self.sock = sock
        print("Connected to server.")
        # Start the receive thread
        self.running = True
        self.receive_thread = threading.Thread(target=self._receive_loop, daemon=True)
        self.receive_thread.start()

    def _receive_loop(self):
        """Read length-prefixed JSON messages until the connection ends or the client stops.

        Malformed JSON is reported and skipped. Any other error, including an
        exception raised by the callback, disconnects the client. Errors are not
        printed once ``self.running`` is False, i.e. after a deliberate
        disconnect()/stop().
        """
        # Local reference: disconnect() may reset self.sock from another thread.
        sock = self.sock
        buffer = b''
        while self.running:
            try:
                # Read the 4-byte length header.
                while len(buffer) < 4:
                    data = sock.recv(4)
                    if not data:
                        raise ConnectionResetError("Server closed the connection")
                    buffer += data
                length = int.from_bytes(buffer[:4], 'little')
                buffer = buffer[4:]
                # Read the complete payload.
                while len(buffer) < length:
                    data = sock.recv(length - len(buffer))
                    if not data:
                        raise ConnectionResetError("Server closed the connection")
                    buffer += data
                message, buffer = buffer[:length], buffer[length:]
                # Decode the JSON payload.
                status = json.loads(message.decode())
                if self.callback:
                    self.callback(status)
            except ConnectionResetError as e:
                if self.running:
                    print("连接中断:", e)
                self.disconnect()
                break
            except json.JSONDecodeError as e:
                print("JSON解析失败:", e)
                continue
            except Exception as e:
                if self.running:
                    print("接收错误:", e)
                self.disconnect()
                break

    def disconnect(self):
        """Stop receiving and close the connection; safe to call more than once."""
        self.running = False
        sock, self.sock = self.sock, None
        if sock is None:
            return
        try:
            # Wakes the receive thread if it is blocked in recv(); close() alone may not.
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # best effort: the connection is already gone
        sock.close()
        print("Disconnected from server.")

    def stop(self):
        """Disconnect and wait for the receive thread to finish (unless called from that thread)."""
        self.disconnect()
        self._join_receive_thread()
        print("Client stopped.")

    def wait_receive(self):
        """Block until the receive thread exits; returns immediately when called from it."""
        self._join_receive_thread()

    def _join_receive_thread(self):
        """Join the receive thread if it is alive; a thread cannot join itself (e.g. from the callback)."""
        thread = self.receive_thread
        if thread and thread.is_alive() and thread is not threading.current_thread():
            thread.join()
# Register an exit-time cleanup hook for a server instance.
def register_cleanup(server_instance):
    """Arrange for ``server_instance.stop()`` to run when the interpreter exits.

    The ``stop`` attribute is looked up when the hook fires, not at registration.
    """
    atexit.register(lambda: server_instance.stop())
# # 示例使用
# if __name__ == '__main__' and "server" in sys.argv:
#
# import time
#
# # 模拟的状态存储
# process_status = {
# 'process1': 'running',
# 'process2': 'stopped',
# "big_data": "<AAAAAAAAAAAAAAAAAFFFFFFFFFFFFFFFFFFAAAAFFFFFFFFFFFFFFFAAAAAAAAAAAAAAAAAAAAAAAAA>"
# }
#
# def get_status(init=False):
# return process_status
#
# # 设置 Unix 域套接字地址
# server_address = './socket_filetransfer.sock'
#
# # 创建服务端实例
# server = StatusServer(get_status, server_address)
# register_cleanup(server) # 注册退出时清理
#
# # 启动服务端线程
# server_thread = threading.Thread(target=server.start_server)
# server_thread.daemon = True
# server_thread.start()
#
# # 模拟状态更新
# try:
# while True:
# print(">>>>>>>change<<<<<<<<<<<<<<<")
# time.sleep(5)
# process_status['process1'] = 'stopped' if process_status['process1'] == 'running' else 'running'
# server.update_status()
#
# time.sleep(5)
# process_status['process2'] = 'running' if process_status['process2'] == 'stopped' else 'stopped'
# server.update_status()
# except KeyboardInterrupt:
# pass
#
# # 示例使用
# if __name__ == '__main__' and "client" in sys.argv:
# # Unix 域套接字示例:
# server_address = './socket_filetransfer.sock'
#
# # 示例回调函数
# def on_status_update(status_dict):
# print("[Callback] New status received:")
# for k, v in status_dict.items():
# print(f" - {k}: {v}")
#
#
# client = StatusClient(server_address, callback=on_status_update)
# try:
# client.connect()
#
# # 主线程保持运行,防止程序退出
# client.wait_receive()
# except KeyboardInterrupt:
# print("Client shutting down...")
# client.stop()
|