diff --git a/server/websocket/internal/logic/datatransferlogic.go b/server/websocket/internal/logic/datatransferlogic.go index 21ee7614..2b0f36d0 100644 --- a/server/websocket/internal/logic/datatransferlogic.go +++ b/server/websocket/internal/logic/datatransferlogic.go @@ -8,6 +8,7 @@ import ( "fusenapi/constants" "fusenapi/server/websocket/internal/websocket_data" "fusenapi/utils/auth" + "fusenapi/utils/encryption_decryption" "net/http" "sync" "time" @@ -158,13 +159,21 @@ func (l *DataTransferLogic) setConnPool(conn *websocket.Conn, userInfo auth.User } // 获取唯一id -func (l *DataTransferLogic) getUniqueId(userInfo auth.UserInfo) string { +func (l *DataTransferLogic) getUniqueId(userInfo auth.UserInfo) (uniqueId string, err error) { //后面拼接上用户id - uniqueId := hex.EncodeToString([]byte(uuid.New().String())) + getUserJoinPart(userInfo.UserId, userInfo.GuestId) + uniqueId = hex.EncodeToString([]byte(uuid.New().String())) + getUserJoinPart(userInfo.UserId, userInfo.GuestId) if _, ok := mapConnPool.Load(uniqueId); ok { - uniqueId = l.getUniqueId(userInfo) + uniqueId, err = l.getUniqueId(userInfo) + if err != nil { + return "", err + } } - return uniqueId + //加密 + uniqueId, err = encryption_decryption.CBCEncrypt(uniqueId) + if err != nil { + return "", err + } + return uniqueId, nil } // 鉴权 diff --git a/server/websocket/internal/logic/ws_reuse_last_connect.go b/server/websocket/internal/logic/ws_reuse_last_connect.go index dcf1b352..24fd174e 100644 --- a/server/websocket/internal/logic/ws_reuse_last_connect.go +++ b/server/websocket/internal/logic/ws_reuse_last_connect.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "fusenapi/constants" + "fusenapi/utils/encryption_decryption" "github.com/zeromicro/go-zero/core/logx" ) @@ -17,21 +18,22 @@ func (w *wsConnectItem) reuseLastConnect(data []byte) { w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR, "invalid format of client id")) return } - lenClientId := len(clientId) - //id长度太长 - if lenClientId > 500 { - w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR, "length of client id is to long")) + //解密 + decryptionClientId, err := encryption_decryption.CBCDecrypt(clientId) + if err != nil { + w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR, "invalid client id")) return } + lendecryptionClientId := len(decryptionClientId) //合成client后缀,不是同个后缀的不能复用 userPart := getUserJoinPart(w.userId, w.guestId) lenUserPart := len(userPart) - if lenClientId <= lenUserPart { + if lendecryptionClientId <= lenUserPart { w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR, "length of client id is to short")) return } //尾部不同不能复用 - if clientId[lenClientId-lenUserPart:] != userPart { + if decryptionClientId[lendecryptionClientId-lenUserPart:] != userPart { w.sendToOutChan(w.respondDataFormat(constants.WEBSOCKET_REQUEST_RESUME_LAST_CONNECT_ERR, "the client id is not belong you before")) return } @@ -65,5 +67,5 @@ func (w *wsConnectItem) reuseLastConnect(data []byte) { // 获取用户拼接部分(复用标识用到) func getUserJoinPart(userId, guestId int64) string { - return fmt.Sprintf("i%di%d", userId, guestId) + return fmt.Sprintf("|%d_%d", userId, guestId) }