• 需求场景:在用户登陆后,与前端保持websocket的连接,进行前后交互,比如:消息通知
  • 参考官方文档:https://websockets.readthedocs.io/en/stable/intro.html

使用版本

python 3.8 
websockets==8.1
redis==3.5.3

websocket 安装

pip install websockets

主要设计方案

  • 用户:
  • 1 websocket 对象存储格式,存在内存中, 以用户id为key, value是websocket对象的列表,使用列表是因为我们的服务支持多点登陆,且需要在每个页面上都建立连接
{ user_id: [websocket object1, websocket object2]}
  • 2 每次登录,或刷新页面,需要重建连接(由前端控制,若连接断开,也由前端负责重连)
  • 3 每次重连时,都需要发送 login 给后端,后端通过cookie判断用户,并将此次的websocket对象存储下来,即第1步。
  • 单点部署服务端产生消息的逻辑:
  • 1 将所有消息存储在Redis中,使用格式为zset。
  • 2 producer_handler 每秒通过score值小于当前时间来获取所需推送的消息。
  • 需支持多点部署的服务端产生消息的逻辑:
  • 1 采用 Redis 的发布与订阅功能,每个websocket服务都订阅产生消息的频道。
  • 2 实时推送的消息直接publish
  • 3 将定时消息存储在Redis中,使用格式为zset , score值为后端推送的时间。此部分消息通过定时任务/脚本触发,通过score值小于当前时间来获取所需推送的消息并publish。
  • 4 producer_handler直接从订阅频道获取消息即可。
  • 前后端交互,数据格式约定参考:
1. 前端推送
通用字段: msg_type 消息类型
比如:前端发送消息,后端可通过header中的cookie来判断是哪个用户
登陆成功,前端发起ws事件通知  {"msg_type": "login"}
2. 后端推送
通用字段: 
msg_type 消息类型 
to_user_id 代表推送给哪个用户 (此字段可只用来给后端做判断,不传给前端)
data 消息内容

主要逻辑代码

import asyncio
import json
import websockets
from websockets import WebSocketServerProtocol
from websockets.exceptions import ConnectionClosed, ConnectionClosedOK
from typing import Callable, Optional, List


# 建立redis的连接实例,可从项目中导入
redis_store = redis.StrictRedis() 
redis_pubsub = redis_store.pubsub()


user_websockets_map = {}


def do_login(websocket: WebSocketServerProtocol, message: dict) -> None:
    user_id = message['user_id']
    websockets = user_websockets_map.get(user_id)
    if not websockets:
        user_websockets_map[user_id] = [websocket]
    elif websocket not in websockets:
        websockets.append(websocket)


async def consumer(websocket: WebSocketServerProtocol, message: str) -> None:
    # 处理前端发过来的消息
    # 如果消息被判定为登陆,则可以调用 do_login ,存入用户 webdocket 对象
    pass


async def producer() -> List[dict]:
    """
    每秒把订阅的消息全部取出
    message 中约定了固定字段:to_user_id,代表需要发送给的用户id
    单点部署,使用zset获取消息
    代码如下:
    await asyncio.sleep(1)
    current = get_current_timestamp()
    messages = redis_store.zrangebyscore(WebsocketUnit.TIMING_MSG_CACHE_KEY, 1, current)
    if messages:
        redis_store.zremrangebyscore(WebsocketUnit.TIMING_MSG_CACHE_KEY, 1, current)
    return [json.loads(message.dumps('utf-8')) for message in messages]
    """
    # 以下示例为 使用 Redis 发布订阅模式获取消息 
    await asyncio.sleep(1)
    result = []
    while True:
        message = redis_pubsub.get_message()
        if message:
            # 此处省略 初步处理订阅频道中的消息的逻辑,比如做消息格式的校验等等 
            message = json.loads(message.dumps('utf-8'))
            result.append(message)
        else:
            break
    return result



async def consumer_handler(websocket: WebSocketServerProtocol) -> None:
    async for message in websocket:
        await consumer(websocket, message)


async def producer_handler() -> None:
    while True:
        messages = await producer()
        for message in messages:
            # 定点发送给用户
            to_user_id = message.get('to_user_id')
            websockets = to_user_id and user_websockets_map.get(to_user_id) or []
            for ws in websockets:
                try:
                    await ws.send(json.dumps(message))
                except (ConnectionClosed, ConnectionClosedOK):
                    # 移除已关闭的websocket
                    websockets.remove(ws)


async def handler(websocket: WebSocketServerProtocol, path):
    
    # 作为前端发送消息的消费者
    consumer_task = asyncio.ensure_future(
        consumer_handler(websocket))

    # 作为推送给前端消息的生产者
    producer_task = asyncio.ensure_future(
        producer_handler())

    done, pending = await asyncio.wait(
        [consumer_task, producer_task],
        return_when=asyncio.ALL_COMPLETED,
    )
    for task in pending:
        task.cancel()


def main():
    # 订阅producer消息
    redis_pubsub.subscribe('websocket:producer:channel')

    # 开启websocket服务,端口号为9992
    start_server = websockets.serve(handler, host="0.0.0.0", port=9992)

    # 创建协程事件循环
    asyncio.get_event_loop().run_until_complete(start_server)
    asyncio.get_event_loop().run_forever()


if __name__ == '__main__':
    main()