package nnet

import (
	"fmt"
	"net/http"
	"os"
	"strings"

	"git.noahlan.cn/noahlan/nnet/config"
	"git.noahlan.cn/noahlan/nnet/connection"
	"git.noahlan.cn/noahlan/ntool/nlog"
	"github.com/gorilla/websocket"
)

// WsConfOption is a functional option applied to the websocket server
// configuration by ListenWebsocket before the server starts.
//
// NOTE(review): the configuration is passed by value. Unless the fields an
// option touches live behind a pointer inside config.WSServerFullConf, the
// option only mutates its own copy and the change is lost — confirm against
// the config package before relying on options.
type WsConfOption func(conf config.WSServerFullConf)

// WithWSCheckOrigin returns an option that assigns fn to the CheckOrigin
// field of the configuration it receives.
//
// NOTE(review): see WsConfOption — because conf is a by-value parameter this
// assignment presumably never reaches the caller's configuration; verify.
func WithWSCheckOrigin(fn func(*http.Request) bool) WsConfOption {
	return func(conf config.WSServerFullConf) {
		conf.CheckOrigin = fn
	}
}

// ListenWebsocket starts listening for websocket connections.
//
// It applies opts to conf, runs the engine setup, registers the upgrade
// handler for conf.Path and then blocks serving HTTP on conf.Addr, using TLS
// (conf.TLSCertificate / conf.TLSKey) when conf.IsTLS() reports true.
//
// The handler is registered through http.HandleFunc and the server is started
// with a nil handler, so both go through http.DefaultServeMux.
//
// Any setup or listen error is logged and returned.
func (ngin *Engine) ListenWebsocket(conf config.WSServerFullConf, opts ...WsConfOption) error {
	for _, opt := range opts {
		opt(conf)
	}

	err := ngin.setup()
	if err != nil {
		nlog.Errorf("%s failed to setup server, err:%v", ngin.LogPrefix(), err)
		return err
	}
	nlog.Infof("%s now listening websocket at %s.", ngin.LogPrefix(), conf.Addr)

	ngin.upgradeWebsocket(conf)

	if conf.IsTLS() {
		if err := http.ListenAndServeTLS(conf.Addr, conf.TLSCertificate, conf.TLSKey, nil); err != nil {
			nlog.Errorf("%s failed to listening websocket with TLS at %s %v", ngin.LogPrefix(), conf.Addr, err)
			return err
		}
		return nil
	}

	if err := http.ListenAndServe(conf.Addr, nil); err != nil {
		nlog.Errorf("%s failed to listening websocket at %s %v", ngin.LogPrefix(), conf.Addr, err)
		return err
	}
	return nil
}

// handleWS wraps an upgraded websocket connection, chains the optional
// user-supplied ping/pong callbacks from conf in front of the connection's
// existing handlers, and hands the wrapped connection to the engine.
func (ngin *Engine) handleWS(conn *websocket.Conn, conf config.WSServerFullConf) {
	wsConn := connection.NewWSConn(conn)

	//defaultCloseHandler := conn.CloseHandler()
	//conn.SetCloseHandler(func(code int, text string) error {
	//	result := defaultCloseHandler(code, text)
	//	//wsConn.Close()
	//	return result
	//})

	// ping: notify the user callback first, then run the previously installed handler.
	defaultPingHandler := wsConn.PingHandler()
	wsConn.SetPingHandler(func(appData string) error {
		if conf.PingHandler != nil {
			conf.PingHandler(appData)
		}
		return defaultPingHandler(appData)
	})

	// pong: same chaining as ping.
	defaultPongHandler := wsConn.PongHandler()
	wsConn.SetPongHandler(func(appData string) error {
		if conf.PongHandler != nil {
			conf.PongHandler(appData)
		}
		return defaultPongHandler(appData)
	})

	ngin.handle(wsConn)
}

// upgradeWebsocket registers an HTTP handler on conf.Path that upgrades each
// incoming request to a websocket connection and submits the connection to
// the engine's goroutine pool for processing.
func (ngin *Engine) upgradeWebsocket(conf config.WSServerFullConf) {
	upgrade := websocket.Upgrader{
		HandshakeTimeout:  conf.HandshakeTimeout,
		ReadBufferSize:    conf.ReadBufferSize,
		WriteBufferSize:   conf.WriteBufferSize,
		CheckOrigin:       conf.CheckOrigin,
		EnableCompression: conf.Compression,
	}

	// Make sure the registered pattern starts with a slash whether or not
	// conf.Path was written with one.
	path := fmt.Sprintf("/%s", strings.TrimPrefix(conf.Path, "/"))

	http.HandleFunc(path, func(writer http.ResponseWriter, request *http.Request) {
		conn, err := upgrade.Upgrade(writer, request, nil)
		if err != nil {
			// %s (not the invalid %n verb) so the URI and error are actually printed.
			nlog.Errorf("%s Upgrade failure, URI=%s, Error=%s", ngin.LogPrefix(), request.RequestURI, err.Error())
			return
		}

		err = ngin.goPool.Submit(func() {
			ngin.handleWS(conn, conf)
		})
		if err != nil {
			nlog.Errorf("%s submit conn pool err: %v", ngin.LogPrefix(), err.Error())
			// NOTE(review): this terminates the whole process when a single
			// connection cannot be submitted to the pool; consider closing
			// conn and returning instead.
			os.Exit(1)
		}
	})
}