发布日期：2024年1月 | 作者：Python高级工程师
异步编程与WebSocket技术背景
在现代Web应用中，实时数据推送已成为基本需求。传统的HTTP轮询方式存在延迟高、服务器压力大等问题，而WebSocket提供了真正的全双工通信能力。
技术栈选择
- Asyncio: Python原生异步IO框架
- WebSockets: 轻量级WebSocket库
- Redis Pub/Sub: 消息广播机制
- JSON Web Tokens: 连接认证
系统架构设计
组件架构图
客户端 → WebSocket连接管理器 → 消息路由 → Redis发布订阅 → 数据处理器
              ↑                                              ↓
          认证中心 ←── 连接状态监控 ←── 异常处理 ←── 日志记录
目录结构
websocket_system/
├── app/
│ ├── __init__.py
│ ├── main.py # 应用入口
│ ├── websocket/
│ │ ├── manager.py # 连接管理器
│ │ ├── handler.py # 消息处理器
│ │ └── auth.py # 认证模块
│ ├── services/
│ │ ├── redis_service.py # Redis服务
│ │ └── message_service.py # 消息服务
│ ├── monitoring/
│ │ └── connection_monitor.py # 连接健康监控
│ └── models/
│ ├── connection.py # 连接模型
│ └── message.py # 消息模型
├── config/
│ ├── settings.py # 配置文件
│ └── constants.py # 常量定义
└── requirements.txt
核心功能实现
1. WebSocket服务器主程序
import asyncio
import websockets
import json
import jwt
from datetime import datetime
from typing import Dict, Set
class WebSocketServer:
    """JWT-authenticated WebSocket server that fans messages out per channel.

    Protocol: the first frame a client sends must be a JSON object
    ``{"token": <JWT>, "channels": [<name>, ...]}``. After successful
    authentication the client may send JSON objects whose ``type`` is
    ``subscribe``, ``unsubscribe`` or ``broadcast`` (each with a ``channel``).
    """

    def __init__(self, host: str = "localhost", port: int = 8765,
                 secret_key: str = "your-secret-key",
                 auth_timeout: float = 30.0):
        """Create the server (it does not start listening until ``start_server``).

        Args:
            host: Interface to bind.
            port: TCP port to bind.
            secret_key: HS256 key used to verify client JWTs. The default is a
                placeholder kept for backward compatibility; pass a real secret
                loaded from configuration in production.
            auth_timeout: Seconds a new client has to send its auth message.
        """
        self.host = host
        self.port = port
        self.secret_key = secret_key
        self.auth_timeout = auth_timeout
        # channel name -> sockets currently subscribed to that channel
        self.connections: Dict[str, Set[websockets.WebSocketServerProtocol]] = {}
        # id(websocket) -> {'user_id', 'connected_at', 'channels'}
        self.authenticated_clients: Dict[int, dict] = {}

    async def handle_connection(self, websocket, path=None):
        """Serve one client from authentication until disconnect.

        ``path`` is optional because recent ``websockets`` releases call the
        handler with the connection only, while older ones also pass the path.
        """
        client_id = id(websocket)
        print(f"新连接建立: {client_id}")
        try:
            # The first frame must be the authentication message.
            auth_message = await asyncio.wait_for(
                websocket.recv(), timeout=self.auth_timeout
            )
            try:
                auth_data = json.loads(auth_message)
            except ValueError:
                # Malformed auth payload is an auth failure, not a server error.
                auth_data = None
            if isinstance(auth_data, dict) and await self.authenticate_client(auth_data, websocket):
                await self.register_client(websocket, auth_data)
                await self.handle_messages(websocket)
            else:
                await websocket.close(1008, "认证失败")
        except asyncio.TimeoutError:
            await websocket.close(1008, "认证超时")
        except websockets.exceptions.ConnectionClosed:
            pass  # client went away; cleanup happens in ``finally``
        except Exception as e:
            print(f"连接处理错误: {e}")
            await websocket.close(1011, "服务器内部错误")
        finally:
            # Always drop channel memberships, whatever ended the connection.
            await self.unregister_client(websocket)

    async def authenticate_client(self, auth_data: dict, websocket) -> bool:
        """Verify the JWT in ``auth_data`` and record the client on success.

        Returns:
            True if the token decodes with ``secret_key`` (HS256) and carries a
            truthy ``user_id``; False otherwise.
        """
        if not isinstance(auth_data, dict):
            return False
        token = auth_data.get('token')
        if not token:
            return False
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
        except jwt.InvalidTokenError:
            return False
        user_id = payload.get('user_id')
        if not user_id:
            return False
        self.authenticated_clients[id(websocket)] = {
            'user_id': user_id,
            'connected_at': datetime.now(),
            'channels': self._parse_channels(auth_data.get('channels', [])),
        }
        return True

    @staticmethod
    def _parse_channels(raw) -> Set[str]:
        """Return the non-empty string channel names contained in ``raw``.

        A bare string is treated as one channel (not split into characters),
        and non-string entries are ignored so an unhashable value cannot crash
        registration.
        """
        if isinstance(raw, str):
            raw = [raw]
        if not isinstance(raw, (list, tuple, set)):
            return set()
        return {channel for channel in raw if isinstance(channel, str) and channel}

    async def register_client(self, websocket, auth_data):
        """Add an authenticated client to its channels and confirm success."""
        client_info = self.authenticated_clients[id(websocket)]
        for channel in client_info['channels']:
            self.connections.setdefault(channel, set()).add(websocket)
        await websocket.send(json.dumps({
            'type': 'auth_success',
            'message': '认证成功',
            'user_id': client_info['user_id'],
        }))

    async def _send_error(self, websocket, message: str):
        """Send an ``{'type': 'error'}`` frame to a single client."""
        await websocket.send(json.dumps({
            'type': 'error',
            'message': message,
        }))

    async def handle_messages(self, websocket):
        """Read frames until the connection closes and dispatch each one."""
        async for message in websocket:
            try:
                data = json.loads(message)
            except ValueError:  # JSONDecodeError or undecodable bytes
                await self._send_error(websocket, '消息格式错误')
                continue
            if not isinstance(data, dict):
                await self._send_error(websocket, '消息格式错误')
                continue
            await self.process_message(data, websocket)

    async def process_message(self, data: dict, websocket):
        """Route a decoded client message by its ``type`` field."""
        message_type = data.get('type')
        if message_type == 'subscribe':
            await self.handle_subscribe(data, websocket)
        elif message_type == 'unsubscribe':
            await self.handle_unsubscribe(data, websocket)
        elif message_type == 'broadcast':
            await self.handle_broadcast(data, websocket)
        else:
            await self._send_error(websocket, '不支持的消息类型')

    @staticmethod
    def _valid_channel(channel) -> bool:
        """True if ``channel`` is a usable (non-empty string) channel name."""
        return isinstance(channel, str) and bool(channel)

    async def handle_subscribe(self, data: dict, websocket):
        """Add the client to ``data['channel']`` and acknowledge."""
        channel = data.get('channel')
        client_info = self.authenticated_clients.get(id(websocket))
        if client_info is None:
            await self._send_error(websocket, '未认证')
            return
        if not self._valid_channel(channel):
            await self._send_error(websocket, '缺少频道名称')
            return
        client_info['channels'].add(channel)
        self.connections.setdefault(channel, set()).add(websocket)
        await websocket.send(json.dumps({'type': 'subscribed', 'channel': channel}))

    async def handle_unsubscribe(self, data: dict, websocket):
        """Remove the client from ``data['channel']`` and acknowledge."""
        channel = data.get('channel')
        client_info = self.authenticated_clients.get(id(websocket))
        if client_info is None:
            await self._send_error(websocket, '未认证')
            return
        if not self._valid_channel(channel):
            await self._send_error(websocket, '缺少频道名称')
            return
        client_info['channels'].discard(channel)
        self._remove_from_channel(channel, websocket)
        await websocket.send(json.dumps({'type': 'unsubscribed', 'channel': channel}))

    async def handle_broadcast(self, data: dict, websocket):
        """Relay ``data['message']`` to every subscriber of ``data['channel']``.

        Broadcasts to channels that currently have no subscribers are dropped.
        """
        channel = data.get('channel')
        if not self._valid_channel(channel):
            await self._send_error(websocket, '缺少频道名称')
            return
        if channel not in self.connections:
            return
        await self.broadcast_to_channel(channel, {
            'type': 'message',
            'from_user': self.authenticated_clients[id(websocket)]['user_id'],
            'channel': channel,
            'message': data.get('message'),
            'timestamp': datetime.now().isoformat(),
        })

    def _remove_from_channel(self, channel: str, websocket) -> None:
        """Remove ``websocket`` from ``channel``; drop the channel once empty."""
        members = self.connections.get(channel)
        if members is None:
            return
        members.discard(websocket)
        if not members:
            del self.connections[channel]

    async def broadcast_to_channel(self, channel: str, message: dict):
        """Send ``message`` (JSON-encoded once) to every subscriber of ``channel``.

        Subscribers whose connection turns out to be closed are removed.
        """
        members = self.connections.get(channel)
        if not members:
            return
        message_json = json.dumps(message)
        disconnected = []
        # Iterate over a snapshot: every ``send`` yields to the event loop and
        # other coroutines may add/remove subscribers meanwhile, which would
        # raise "Set changed size during iteration" on the live set.
        for websocket in list(members):
            try:
                await websocket.send(message_json)
            except websockets.exceptions.ConnectionClosed:
                disconnected.append(websocket)
        for websocket in disconnected:
            self._remove_from_channel(channel, websocket)

    async def unregister_client(self, websocket):
        """Forget a client and remove it from every channel it joined."""
        client_id = id(websocket)
        client_info = self.authenticated_clients.pop(client_id, None)
        if client_info is not None:
            for channel in client_info['channels']:
                self._remove_from_channel(channel, websocket)
        print(f"连接关闭: {client_id}")

    async def start_server(self):
        """Listen on ``host:port`` and serve connections until cancelled."""
        print(f"启动WebSocket服务器在 {self.host}:{self.port}")
        async with websockets.serve(
            self.handle_connection,
            self.host,
            self.port
        ):
            await asyncio.Future()  # run forever
2. Redis消息广播服务
import asyncio
import redis.asyncio as redis
import json
from typing import Callable
class RedisBroadcastService:
    """Bridge Redis Pub/Sub channels to the WebSocket server.

    JSON messages published on a subscribed Redis channel are decoded and
    passed to the coroutine handler registered for that channel.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis_url = redis_url
        self.redis_client = None
        self.pubsub = None
        # Redis channel name -> async callable receiving the decoded message
        self.message_handlers = {}

    async def connect(self):
        """Create the Redis client and a Pub/Sub object bound to it."""
        self.redis_client = await redis.from_url(self.redis_url)
        self.pubsub = self.redis_client.pubsub()

    async def close(self):
        """Release the Pub/Sub object and the Redis client (best effort).

        Prefers ``aclose()`` and falls back to ``close()`` on library versions
        that do not provide it. Safe to call when never connected.
        """
        for resource in (self.pubsub, self.redis_client):
            if resource is None:
                continue
            closer = getattr(resource, "aclose", None) or getattr(resource, "close", None)
            if closer is None:
                continue
            result = closer()
            if asyncio.iscoroutine(result):
                await result
        self.pubsub = None
        self.redis_client = None

    async def subscribe(self, channel: str, handler: Callable):
        """Register ``handler`` for ``channel`` and subscribe to it.

        The handler is registered first so no message can arrive for a
        channel without a handler.
        """
        self.message_handlers[channel] = handler
        await self.pubsub.subscribe(channel)

    async def publish(self, channel: str, message: dict):
        """Publish ``message`` as JSON; does nothing if not connected."""
        if self.redis_client:
            await self.redis_client.publish(
                channel,
                json.dumps(message)
            )

    async def listen_messages(self):
        """Dispatch incoming Pub/Sub messages to their handlers until the stream ends.

        A malformed payload or a failing handler is logged and skipped so that
        one bad message cannot stop the listener.
        """
        async for message in self.pubsub.listen():
            if message.get('type') != 'message':
                continue  # subscribe/unsubscribe confirmations etc.
            channel = message['channel']
            # bytes unless the client was created with decode_responses=True
            if isinstance(channel, bytes):
                channel = channel.decode()
            handler = self.message_handlers.get(channel)
            if handler is None:
                continue
            try:
                data = json.loads(message['data'])
            except (TypeError, ValueError) as e:
                print(f"Redis消息解析失败 [{channel}]: {e}")
                continue
            try:
                await handler(data)
            except Exception as e:
                print(f"Redis消息处理错误 [{channel}]: {e}")

    async def start_broadcast_service(self, websocket_server):
        """Connect, subscribe the system channels and listen forever."""
        await self.connect()
        await self.subscribe(
            'system_notifications',
            lambda msg: self.handle_system_notification(msg, websocket_server))
        await self.subscribe(
            'user_messages',
            lambda msg: self.handle_user_message(msg, websocket_server))
        await self.listen_messages()

    async def handle_system_notification(self, message: dict, server):
        """Forward a system notification to the WebSocket ``system`` channel."""
        await server.broadcast_to_channel('system', {
            'type': 'system_notification',
            'message': message.get('content'),
            'timestamp': message.get('timestamp')
        })

    async def handle_user_message(self, message: dict, server):
        """Forward a user message to its ``target_channel`` (if given)."""
        target_channel = message.get('target_channel')
        if target_channel:
            await server.broadcast_to_channel(target_channel, {
                'type': 'user_message',
                'user_id': message.get('user_id'),
                'content': message.get('content'),
                'timestamp': message.get('timestamp')
            })
3. 连接健康监控
import asyncio
from datetime import datetime, timedelta
class ConnectionMonitor:
    """Periodically report connection statistics of a WebSocket server.

    The monitored server must expose ``authenticated_clients`` (client id ->
    info dict with an optional ``connected_at`` datetime), ``connections``
    (channel -> subscribers) and an async ``broadcast_to_channel``.
    """

    def __init__(self, websocket_server, check_interval: int = 60,
                 max_connection_age: timedelta = timedelta(hours=24)):
        """
        Args:
            websocket_server: Server object to inspect.
            check_interval: Seconds between health checks.
            max_connection_age: Connections older than this are reported as
                abnormal by ``check_connections_health``.
        """
        self.server = websocket_server
        self.check_interval = check_interval
        self.max_connection_age = max_connection_age
        self.is_monitoring = False

    async def start_monitoring(self):
        """Run a health check every ``check_interval`` seconds until stopped."""
        self.is_monitoring = True
        print("启动连接健康监控...")
        while self.is_monitoring:
            try:
                await self.check_connections_health()
            except Exception as e:
                # One failed check must not end monitoring for good.
                print(f"连接监控错误: {e}")
            await asyncio.sleep(self.check_interval)

    async def check_connections_health(self):
        """Log connection/channel counts and flag over-age connections.

        Connection age stands in for a real heartbeat check. Clients without a
        ``connected_at`` timestamp are skipped.

        Returns:
            List of client ids connected longer than ``max_connection_age``.
        """
        current_time = datetime.now()
        stale_clients = []
        for client_id, client_info in self.server.authenticated_clients.items():
            connected_at = client_info.get('connected_at')
            if connected_at is None:
                continue
            if current_time - connected_at > self.max_connection_age:
                stale_clients.append(client_id)

        total_connections = len(self.server.authenticated_clients)
        total_channels = len(self.server.connections)
        print(f"连接监控 - 总连接数: {total_connections}, 总频道数: {total_channels}")
        if stale_clients:
            print(f"发现 {len(stale_clients)} 个异常连接")
        return stale_clients

    async def send_heartbeat(self):
        """Broadcast a heartbeat frame to every channel on the server.

        NOTE: a client subscribed to several channels receives one heartbeat
        per channel; clients with no channel subscription receive none.
        """
        heartbeat = {
            'type': 'heartbeat',
            'timestamp': datetime.now().isoformat(),
        }
        # Snapshot channel names: broadcasting awaits, and the server may add
        # or remove channels while this coroutine is suspended.
        for channel in list(self.server.connections):
            await self.server.broadcast_to_channel(channel, heartbeat)

    def stop_monitoring(self):
        """Ask the monitoring loop to exit; it notices after its current sleep."""
        self.is_monitoring = False
高级特性实现
1. 消息队列集成
import asyncio
from concurrent.futures import ThreadPoolExecutor
import pickle
class MessageQueueProcessor:
    """Offload blocking jobs to a thread pool and push results to clients."""

    def __init__(self, websocket_server, max_workers: int = 4,
                 simulated_delay: float = 2.0):
        """
        Args:
            websocket_server: Object with an async ``broadcast_to_channel``.
            max_workers: Size of the worker thread pool.
            simulated_delay: Seconds each job sleeps to simulate slow work
                (demo placeholder); 0 disables the sleep.
        """
        self.server = websocket_server
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.simulated_delay = simulated_delay
        # Tasks started by ``submit``; held so they are not garbage-collected
        # before finishing.
        self.processing_tasks = set()

    def submit(self, message_type: str, data: dict) -> asyncio.Task:
        """Schedule ``process_in_background`` without awaiting it.

        Must be called from a coroutine running on the event loop.
        """
        task = asyncio.get_running_loop().create_task(
            self.process_in_background(message_type, data)
        )
        self.processing_tasks.add(task)
        task.add_done_callback(self.processing_tasks.discard)
        return task

    async def process_in_background(self, message_type: str, data: dict):
        """Run ``_process_message`` on the pool, then push its result.

        Returns:
            The result dict. It is also broadcast when ``data`` named a
            ``target_channel``.
        """
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            self.executor,
            self._process_message,
            message_type,
            data
        )
        await self._push_result_to_clients(result)
        return result

    def _process_message(self, message_type: str, data: dict):
        """Blocking job body; runs on a worker thread.

        ``data['target_channel']`` (if present) is copied into the result so
        ``_push_result_to_clients`` can route it back to subscribers.
        """
        # Local imports keep this snippet self-contained.
        import time
        from datetime import datetime

        if self.simulated_delay > 0:
            time.sleep(self.simulated_delay)  # placeholder for real blocking work

        if message_type == 'data_processing':
            result = {
                'type': 'processing_result',
                'status': 'completed',
                'input_data': data,
                'result': {'processed': True, 'timestamp': datetime.now().isoformat()}
            }
        elif message_type == 'file_operation':
            result = {
                'type': 'file_result',
                'status': 'success',
                'operation': data.get('operation')
            }
        else:
            result = {'status': 'unknown_operation'}

        target_channel = data.get('target_channel') if isinstance(data, dict) else None
        if target_channel:
            result['target_channel'] = target_channel
        return result

    async def _push_result_to_clients(self, result: dict):
        """Broadcast ``result`` to its ``target_channel``, if it has one."""
        target_channel = result.get('target_channel')
        if target_channel:
            await self.server.broadcast_to_channel(target_channel, result)

    def close(self, wait: bool = True):
        """Shut down the worker pool (optionally waiting for running jobs)."""
        self.executor.shutdown(wait=wait)
2. 连接限流与防护
from collections import defaultdict
from datetime import datetime, timedelta
class RateLimiter:
    """Sliding-window rate limiter keyed by client identifier."""

    def __init__(self, max_requests: int = 100, window_seconds: int = 60):
        """
        Args:
            max_requests: Requests allowed per client within the window.
            window_seconds: Length of the sliding window in seconds.
        """
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # client_id -> datetimes of that client's requests inside the window
        self.requests = defaultdict(list)

    def _recent(self, client_id, now: datetime) -> list:
        """Return ``client_id``'s timestamps newer than the window start.

        Uses ``.get`` so that merely looking up an unknown client does not
        create an entry in the defaultdict.
        """
        window_start = now - timedelta(seconds=self.window_seconds)
        return [req_time for req_time in self.requests.get(client_id, ())
                if req_time > window_start]

    def _store(self, client_id, recent: list) -> None:
        """Save ``recent`` for ``client_id``, dropping the key when empty."""
        if recent:
            self.requests[client_id] = recent
        else:
            self.requests.pop(client_id, None)

    async def is_rate_limited(self, client_id: str) -> bool:
        """Record a request and return True if ``client_id`` exceeded its quota.

        A rejected request is not recorded.
        """
        now = datetime.now()
        recent = self._recent(client_id, now)
        limited = len(recent) >= self.max_requests
        if not limited:
            recent.append(now)
        self._store(client_id, recent)
        return limited

    def get_remaining_requests(self, client_id: str) -> int:
        """Return how many more requests ``client_id`` may make right now."""
        recent = self._recent(client_id, datetime.now())
        return max(0, self.max_requests - len(recent))

    def cleanup(self) -> int:
        """Purge expired timestamps for all clients.

        Returns:
            Number of clients removed because they had no recent requests.
        """
        now = datetime.now()
        removed = 0
        for client_id in list(self.requests):
            recent = self._recent(client_id, now)
            if not recent:
                removed += 1
            self._store(client_id, recent)
        return removed
class ConnectionGuard:
    """Admission control for new connections: per-IP cap plus rate limiting."""

    def __init__(self, max_connections_per_ip: int = 10, rate_limiter=None):
        """
        Args:
            max_connections_per_ip: Concurrent connections allowed per IP.
            rate_limiter: Object with an async ``is_rate_limited(key)``;
                defaults to a new ``RateLimiter()``.
        """
        self.max_connections_per_ip = max_connections_per_ip
        # ip -> id()s of the currently registered connections from that IP
        self.connections_per_ip = defaultdict(set)
        self.rate_limiter = rate_limiter if rate_limiter is not None else RateLimiter()

    async def can_accept_connection(self, websocket, ip_address: str) -> bool:
        """Return True if a new connection from ``ip_address`` may be accepted.

        Checks the per-IP connection cap first, then the rate limiter. The
        rate-limiter call counts as a request for ``ip_address``.
        """
        # .get avoids creating an empty entry for every probed IP.
        if len(self.connections_per_ip.get(ip_address, ())) >= self.max_connections_per_ip:
            return False
        return not await self.rate_limiter.is_rate_limited(ip_address)

    def register_connection(self, websocket, ip_address: str):
        """Count ``websocket`` against ``ip_address``."""
        self.connections_per_ip[ip_address].add(id(websocket))

    def unregister_connection(self, websocket, ip_address: str):
        """Stop counting ``websocket``; forget the IP once it has none left."""
        ids = self.connections_per_ip.get(ip_address)
        if ids is None:
            return
        ids.discard(id(websocket))
        if not ids:
            del self.connections_per_ip[ip_address]
性能优化与部署
1. 异步性能优化
import asyncio
import time
from functools import wraps
def async_timing_decorator(func):
    """Decorate a coroutine function so each call prints its elapsed time.

    Uses ``time.perf_counter`` (monotonic, high resolution), so the measurement
    is unaffected by system clock adjustments. The time is printed even when
    the coroutine raises; the exception propagates unchanged.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return await func(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            print(f"{func.__name__} 执行时间: {elapsed:.4f}秒")
    return wrapper
class PerformanceOptimizer:
    """Coalesce outgoing channel messages into periodic batch frames.

    Buffered messages are flushed as one ``batch_messages`` frame per channel
    when ``batch_size`` messages are pending, or ``batch_timeout`` seconds
    after a flush timer is started, whichever comes first. Messages for
    channels absent from ``server.connections`` are dropped at flush time.
    """

    def __init__(self, server=None, batch_size: int = 10, batch_timeout: float = 0.1):
        """
        Args:
            server: Object exposing ``connections`` and an async
                ``broadcast_to_channel`` (e.g. the WebSocket server); must be
                set before any batch is flushed.
            batch_size: Pending-message count that triggers an immediate flush.
            batch_timeout: Seconds to wait before flushing a partial batch.
        """
        self.server = server
        self.message_batch = []
        self.batch_size = batch_size
        self.batch_timeout = batch_timeout
        self._batch_task = None  # pending timer task, if any

    async def batch_broadcast(self, channel: str, message: dict):
        """Queue ``message`` for ``channel``; flush now if the batch is full."""
        self.message_batch.append((channel, message))
        if len(self.message_batch) >= self.batch_size:
            await self._flush_batch()
        elif self._batch_task is None or self._batch_task.done():
            # Start a timer so a partial batch is not held indefinitely.
            self._batch_task = asyncio.create_task(self._schedule_flush())

    async def _schedule_flush(self):
        """Flush after ``batch_timeout``; repeat while messages keep arriving.

        Messages queued while a flush is in progress get their own timeout
        instead of waiting for the next ``batch_broadcast`` call.
        """
        while True:
            await asyncio.sleep(self.batch_timeout)
            await self._flush_batch()
            if not self.message_batch:
                return

    async def _flush_batch(self):
        """Send all pending messages, grouped into one frame per channel."""
        if not self.message_batch:
            return
        # Take ownership of the pending messages before awaiting: messages
        # queued during the broadcasts below belong to the next batch and must
        # not be discarded.
        batch, self.message_batch = self.message_batch, []

        channel_messages = {}
        for channel, message in batch:
            channel_messages.setdefault(channel, []).append(message)

        for channel, messages in channel_messages.items():
            if channel in self.server.connections:
                await self.server.broadcast_to_channel(channel, {
                    'type': 'batch_messages',
                    'messages': messages,
                    'count': len(messages)
                })
2. 系统启动与配置
import asyncio
import signal
import sys
from app.websocket.manager import WebSocketServer
from app.services.redis_service import RedisBroadcastService
from app.monitoring.connection_monitor import ConnectionMonitor
class Application:
    """Wire together the WebSocket server, Redis bridge and connection monitor."""

    def __init__(self):
        self.websocket_server = WebSocketServer()
        self.redis_service = RedisBroadcastService()
        self.connection_monitor = ConnectionMonitor(self.websocket_server)
        self.is_running = False
        # Long-running service tasks created by ``startup``; cancelled by ``shutdown``.
        self._tasks = []

    async def startup(self):
        """Start all services and wait until they finish or ``shutdown`` runs.

        If one service raises, the others are cancelled and the exception
        propagates. Cancellation caused by ``shutdown`` returns normally.
        """
        print("启动WebSocket实时数据推送系统...")
        self._tasks = [
            # Redis broadcast service
            asyncio.create_task(
                self.redis_service.start_broadcast_service(self.websocket_server)
            ),
            # connection health monitor
            asyncio.create_task(self.connection_monitor.start_monitoring()),
            # WebSocket server
            asyncio.create_task(self.websocket_server.start_server()),
        ]
        self.is_running = True
        print("系统启动完成")
        try:
            await asyncio.gather(*self._tasks)
        except asyncio.CancelledError:
            # shutdown() cancels the service tasks on purpose; only propagate
            # cancellation that did not come from a requested shutdown.
            if self.is_running:
                raise
        finally:
            # Never leave sibling services running unattended.
            for task in self._tasks:
                if not task.done():
                    task.cancel()

    async def shutdown(self):
        """Stop monitoring, cancel every service task and wait for them to end."""
        print("正在关闭系统...")
        self.is_running = False
        self.connection_monitor.stop_monitoring()
        for task in self._tasks:
            task.cancel()
        # Exceptions (including CancelledError) are collected, not raised.
        await asyncio.gather(*self._tasks, return_exceptions=True)
        print("系统已关闭")
# Strong references to shutdown tasks scheduled by handle_shutdown: the event
# loop keeps only weak references to tasks, so an unreferenced task could be
# garbage-collected before it completes.
_shutdown_tasks = set()


def handle_shutdown(signum, frame=None):
    """Schedule ``app.shutdown()`` in response to a termination signal.

    Usable both as a ``signal.signal`` handler (``signum``, ``frame``) and as a
    ``loop.add_signal_handler`` callback. Signals arriving while a shutdown is
    already pending are ignored, and nothing is scheduled when no event loop
    is running in this thread. ``app`` is the module-level global set by
    ``main``.
    """
    print("接收到关闭信号,正在优雅关闭...")
    if _shutdown_tasks:
        return  # a shutdown is already in progress
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return  # no running loop to schedule the shutdown on
    task = loop.create_task(app.shutdown())
    _shutdown_tasks.add(task)
    task.add_done_callback(_shutdown_tasks.discard)
async def main():
    """Create the application, install signal handlers and run until shutdown.

    SIGINT/SIGTERM are registered with ``loop.add_signal_handler`` so the
    handler runs inside the event loop. Where that is unsupported (e.g.
    event loops without signal support on Windows), the code falls back to
    ``signal.signal``. Loop-level handlers are removed again on exit.
    """
    global app
    app = Application()

    loop = asyncio.get_running_loop()
    installed = []
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, handle_shutdown, sig, None)
            installed.append(sig)
        except (NotImplementedError, RuntimeError):
            signal.signal(sig, handle_shutdown)

    try:
        await app.startup()
    except KeyboardInterrupt:
        await app.shutdown()
    except Exception as e:
        print(f"系统运行错误: {e}")
        await app.shutdown()
    finally:
        for sig in installed:
            loop.remove_signal_handler(sig)


if __name__ == "__main__":
    asyncio.run(main())
// Click-to-copy for code blocks: clicking any <pre><code> element copies its
// text to the clipboard and logs the outcome to the console.
document.addEventListener('DOMContentLoaded', function () {
    const codeBlocks = document.querySelectorAll('pre code');

    // Fallback when the async Clipboard API is unavailable (e.g. page not in a
    // secure context): copy through a temporary, selected <textarea>.
    function legacyCopy(text) {
        const textArea = document.createElement('textarea');
        textArea.value = text;
        document.body.appendChild(textArea);
        textArea.select();
        try {
            // execCommand reports failure via its return value, not only by throwing.
            if (document.execCommand('copy')) {
                console.log('Python代码已复制到剪贴板');
            } else {
                console.error('复制失败:', new Error('execCommand returned false'));
            }
        } catch (err) {
            console.error('复制失败:', err);
        } finally {
            document.body.removeChild(textArea);
        }
    }

    codeBlocks.forEach(block => {
        block.addEventListener('click', function () {
            const text = this.textContent;
            if (navigator.clipboard && window.isSecureContext) {
                navigator.clipboard.writeText(text)
                    .then(() => console.log('Python代码已复制到剪贴板'))
                    .catch(err => console.error('复制失败:', err));
            } else {
                legacyCopy(text);
            }
        });
        block.title = '点击复制Python代码';
    });
});