diff --git a/server/websocket/internal/logic/datatransferlogic.go b/server/websocket/internal/logic/datatransferlogic.go index d3725a5b..6c724f6a 100644 --- a/server/websocket/internal/logic/datatransferlogic.go +++ b/server/websocket/internal/logic/datatransferlogic.go @@ -3,7 +3,6 @@ package logic import ( "bytes" "encoding/json" - "fmt" "fusenapi/constants" "fusenapi/utils/auth" "fusenapi/utils/id_generator" @@ -99,16 +98,8 @@ func (l *DataTransferLogic) DataTransfer(w http.ResponseWriter, r *http.Request) ) isAuth, userInfo = l.checkAuth(r) if !isAuth { - time.Sleep(time.Second * 1) //兼容下火狐(直接发回去收不到第一条消息:有待研究) - rsp := websocket_data.DataTransferData{ - T: constants.WEBSOCKET_UNAUTH, - D: nil, - } - b, _ := json.Marshal(rsp) - //先发一条正常信息 - _ = conn.WriteMessage(websocket.TextMessage, b) - //发送关闭信息 - _ = conn.WriteMessage(websocket.CloseMessage, nil) + //未授权响应消息 + l.unAuthResponse(conn) return } //设置连接 @@ -162,18 +153,13 @@ func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo auth.User // 获取唯一id func (l *DataTransferLogic) getUniqueId(userInfo auth.UserInfo) string { //后面拼接上用户id - uniqueId := uuid.New().String() + getUserPart(userInfo.UserId, userInfo.GuestId) + uniqueId := uuid.New().String() + getUserJoinPart(userInfo.UserId, userInfo.GuestId) if _, ok := mapConnPool.Load(uniqueId); ok { uniqueId = l.getUniqueId(userInfo) } return uniqueId } -// 获取用户拼接部分 -func getUserPart(userId, guestId int64) string { - return fmt.Sprintf("_%d_%d", userId, guestId) -} - // 鉴权 func (l *DataTransferLogic) checkAuth(r *http.Request) (isAuth bool, userInfo *auth.UserInfo) { // 解析JWT token,并对空用户进行判断 @@ -200,6 +186,22 @@ func (l *DataTransferLogic) checkAuth(r *http.Request) (isAuth bool, userInfo *a return false, nil } +// 鉴权失败通知 +func (l *DataTransferLogic) unAuthResponse(conn *websocket.Conn) { + time.Sleep(time.Second * 1) //兼容下火狐(直接发回去收不到第一条消息:有待研究) + rsp := websocket_data.DataTransferData{ + T: constants.WEBSOCKET_UNAUTH, + D: nil, + } + b, _ := json.Marshal(rsp) + //先发一条正常信息 + _ = conn.WriteMessage(websocket.TextMessage, b) + //发送关闭信息 + _ = conn.WriteMessage(websocket.CloseMessage, nil) + //关闭连接 + conn.Close() +} + // 心跳 func (w *wsConnectItem) heartbeat() { tick := time.Tick(time.Second * 5) diff --git a/server/websocket/internal/logic/ws_reuse_last_connect.go b/server/websocket/internal/logic/ws_reuse_last_connect.go index 8d591a98..b31ee320 100644 --- a/server/websocket/internal/logic/ws_reuse_last_connect.go +++ b/server/websocket/internal/logic/ws_reuse_last_connect.go @@ -2,6 +2,7 @@ package logic import ( "encoding/json" + "fmt" "fusenapi/constants" "github.com/zeromicro/go-zero/core/logx" ) @@ -22,7 +23,7 @@ func (w *wsConnectItem) reuseLastConnect(data []byte) { return } //合成client后缀,不是同个后缀的不能复用 - userPart := getUserPart(w.userId, w.guestId) + userPart := getUserJoinPart(w.userId, w.guestId) lenUserPart := len(userPart) if lenClientId <= lenUserPart { w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR, "length of client id is to short")) @@ -60,3 +61,8 @@ func (w *wsConnectItem) reuseLastConnect(data []byte) { w.sendToOutChan(rsp) return } + +// 获取用户拼接部分(复用标识用到) +func getUserJoinPart(userId, guestId int64) string { + return fmt.Sprintf("_%d_%d", userId, guestId) +}