diff --git a/proxyserver/main.go b/proxyserver/main.go index 066a16d1..05631968 100644 --- a/proxyserver/main.go +++ b/proxyserver/main.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/gorilla/websocket" "gopkg.in/yaml.v2" ) @@ -115,6 +116,7 @@ type Backend struct { HttpAddress string Client *http.Client Handler http.HandlerFunc + Dialer *websocket.Dialer } func NewBackend(mux *http.ServeMux, httpAddress string, muxPaths ...string) *Backend { @@ -142,14 +144,63 @@ func NewBackend(mux *http.ServeMux, httpAddress string, muxPaths ...string) *Bac }, } + dialer := &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + NetDial: func(network, addr string) (net.Conn, error) { + return net.Dial(network, addr) + }, + } + // 创建后端服务对象,包含地址和客户端 backend := &Backend{ HttpAddress: httpAddress, Client: client, + Dialer: dialer, } // 创建处理请求的函数 handleRequest := func(w http.ResponseWriter, r *http.Request) { + + if websocket.IsWebSocketUpgrade(r) { + //todo: 建立websocket的代理 + + target := url.URL{Scheme: "ws", Host: backend.HttpAddress, Path: r.URL.Path} + + var transfer = func(src, dest *websocket.Conn) { + for { + mType, msg, err := src.ReadMessage() + if err != nil { + break + } + + err = dest.WriteMessage(mType, msg) + if err != nil { + break + } + } + + src.Close() + dest.Close() + } + + proxyConn, _, err := backend.Dialer.Dial(target.String(), nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer proxyConn.Close() + + upgrader := websocket.Upgrader{} + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + + go transfer(proxyConn, conn) + go transfer(conn, proxyConn) + } + // 解析目标URL,包含了查询参数 targetURL, err := url.Parse(httpAddress + r.URL.String()) if err != nil {