diff --git a/server/websocket/internal/logic/datatransferlogic.go b/server/websocket/internal/logic/datatransferlogic.go index 53a2d2c6..b8520c5a 100644 --- a/server/websocket/internal/logic/datatransferlogic.go +++ b/server/websocket/internal/logic/datatransferlogic.go @@ -79,6 +79,7 @@ var ( // 每个连接的连接基本属性 type wsConnectItem struct { conn *websocket.Conn //websocket的连接(基本属性) + userAgent string //用户代理头信息(基本属性,用于重连标识验证因素之一) logic *DataTransferLogic //logic(基本属性,用于获取上下文,配置或者操作数据库) closeChan chan struct{} //ws连接关闭chan(基本属性) isClose bool //是否已经关闭(基本属性) @@ -127,29 +128,29 @@ func (l *DataTransferLogic) DataTransfer(w http.ResponseWriter, r *http.Request) return } //设置连接 - ws, err := l.setConnPool(conn, userInfo, isFirefoxBrowser) + ws, err := l.setConnPool(conn, userInfo, isFirefoxBrowser, userAgent) if err != nil { conn.Close() return } //循环读客户端信息 - go ws.readLoop() - //循环把数据发送给客户端 - go ws.writeLoop() - //推消息到云渲染 - go ws.sendLoop() + go ws.acceptBrowserMessage() + //消费出口数据并发送浏览器端 + go ws.consumeOutChanData() + //消费入口数据 + go ws.consumeInChanData() //操作连接中渲染任务的增加/删除 go ws.operationRenderTask() //消费渲染缓冲队列 - go ws.renderImage() + go ws.consumeRenderImageData() //心跳 ws.heartbeat() } // 设置连接 -func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo *auth.UserInfo, isFirefoxBrowser bool) (wsConnectItem, error) { +func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo *auth.UserInfo, isFirefoxBrowser bool, userAgent string) (wsConnectItem, error) { //生成连接唯一标识(失败重试10次) - uniqueId, err := l.getUniqueId(userInfo, 10) + uniqueId, err := l.getUniqueId(userInfo, userAgent, 10) if err != nil { //发送获取唯一标识失败的消息 l.sendGetUniqueIdErrResponse(conn) @@ -157,11 +158,14 @@ func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo *auth.Use } ws := wsConnectItem{ conn: conn, + userAgent: userAgent, logic: l, - uniqueId: uniqueId, closeChan: make(chan struct{}, 1), + isClose: false, + uniqueId: uniqueId, inChan: make(chan []byte, websocketInChanLen), outChan: make(chan []byte, websocketOutChanLen), + mutex: sync.Mutex{}, userId: userInfo.UserId, guestId: userInfo.GuestId, extendRenderProperty: extendRenderProperty{ @@ -180,15 +184,15 @@ func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo *auth.Use } // 获取唯一id -func (l *DataTransferLogic) getUniqueId(userInfo *auth.UserInfo, retryTimes int) (uniqueId string, err error) { +func (l *DataTransferLogic) getUniqueId(userInfo *auth.UserInfo, userAgent string, retryTimes int) (uniqueId string, err error) { if retryTimes < 0 { return "", errors.New("failed to get unique id") } //后面拼接上用户id - uniqueId = hex.EncodeToString([]byte(uuid.New().String())) + getUserJoinPart(userInfo.UserId, userInfo.GuestId) + uniqueId = hex.EncodeToString([]byte(uuid.New().String())) + getUserJoinPart(userInfo.UserId, userInfo.GuestId, userAgent) //存在则从新获取 if _, ok := mapConnPool.Load(uniqueId); ok { - uniqueId, err = l.getUniqueId(userInfo, retryTimes-1) + uniqueId, err = l.getUniqueId(userInfo, userAgent, retryTimes-1) if err != nil { return "", err } @@ -277,7 +281,7 @@ func (w *wsConnectItem) close() { } // 读取出口缓冲队列数据输出返回给浏览器端 -func (w *wsConnectItem) writeLoop() { +func (w *wsConnectItem) consumeOutChanData() { defer func() { if err := recover(); err != nil { logx.Error("write loop panic:", err) @@ -297,8 +301,25 @@ func (w *wsConnectItem) writeLoop() { } } -// 接受客户端发来的消息并写入入口缓冲队列 -func (w *wsConnectItem) readLoop() { +// 消费websocket入口数据池中的数据 +func (w *wsConnectItem) consumeInChanData() { + defer func() { + if err := recover(); err != nil { + logx.Error("send loop panic:", err) + } + }() + for { + select { + case <-w.closeChan: + return + case data := <-w.inChan: + w.dealwithReciveData(data) + } + } +} + +// 接受浏览器端发来的消息并写入入口缓冲队列 +func (w *wsConnectItem) acceptBrowserMessage() { defer func() { if err := recover(); err != nil { logx.Error("read loop panic:", err) @@ -324,23 +345,6 @@ func (w *wsConnectItem) readLoop() { } } -// 消费websocket入口数据池中的数据 -func (w *wsConnectItem) sendLoop() { - defer func() { - if err := recover(); err != nil { - logx.Error("send loop panic:", err) - } - }() - for { - select { - case <-w.closeChan: - return - case data := <-w.inChan: - w.dealwithReciveData(data) - } - } -} - // 把要传递给客户端的数据放入出口缓冲队列 func (w *wsConnectItem) sendToOutChan(data []byte) { select { diff --git a/server/websocket/internal/logic/ws_render_image.go b/server/websocket/internal/logic/ws_render_image.go index d9a513f8..bab88207 100644 --- a/server/websocket/internal/logic/ws_render_image.go +++ b/server/websocket/internal/logic/ws_render_image.go @@ -56,8 +56,8 @@ func (w *wsConnectItem) sendToRenderChan(data []byte) { } } -// 渲染发送到组装数据组装数据(缓冲队列) -func (w *wsConnectItem) renderImage() { +// 消费渲染缓冲队列数据 +func (w *wsConnectItem) consumeRenderImageData() { defer func() { if err := recover(); err != nil { logx.Error("func renderImage err:", err) diff --git a/server/websocket/internal/logic/ws_reuse_last_connect.go b/server/websocket/internal/logic/ws_reuse_last_connect.go index 0790024e..7ef64ec7 100644 --- a/server/websocket/internal/logic/ws_reuse_last_connect.go +++ b/server/websocket/internal/logic/ws_reuse_last_connect.go @@ -26,7 +26,7 @@ func (w *wsConnectItem) reuseLastConnect(data []byte) { } lendecryptionWid := len(decryptionWid) //合成client后缀,不是同个后缀的不能复用 - userPart := getUserJoinPart(w.userId, w.guestId) + userPart := getUserJoinPart(w.userId, w.guestId, w.userAgent) lenUserPart := len(userPart) if lendecryptionWid <= lenUserPart { w.reuseLastConnErrResponse("length of client id is to short") @@ -34,7 +34,7 @@ func (w *wsConnectItem) reuseLastConnect(data []byte) { } //尾部不同不能复用 if decryptionWid[lendecryptionWid-lenUserPart:] != userPart { - w.reuseLastConnErrResponse("the client id is not belong you before") + w.reuseLastConnErrResponse("the client id is not belong to you before") return } //存在是不能给他申请重新绑定 @@ -62,6 +62,6 @@ func (w *wsConnectItem) reuseLastConnect(data []byte) { } // 获取用户拼接部分(复用标识用到) -func getUserJoinPart(userId, guestId int64) string { - return fmt.Sprintf("|_%d_%d_|", userId, guestId) +func getUserJoinPart(userId, guestId int64, userAgent string) string { + return fmt.Sprintf("|_%d_%d_|_%s_|", userId, guestId, userAgent) }