diff --git a/config/server_ws.go b/config/server_ws.go index 15947cf..bc098fa 100644 --- a/config/server_ws.go +++ b/config/server_ws.go @@ -25,6 +25,9 @@ type ( TLSCertificate string `json:",optional"` // TLS 证书key地址 TLSKey string `json:",optional"` + } + WSServerFullConf struct { + WSServerConf // check origin CheckOrigin func(*http.Request) bool `json:",optional"` } diff --git a/server_ws.go b/server_ws.go index eda6bab..fbba3e4 100644 --- a/server_ws.go +++ b/server_ws.go @@ -11,16 +11,16 @@ import ( "strings" ) -type WsConfOption func(conf config.WSServerConf) +type WsConfOption func(conf config.WSServerFullConf) func WithWSCheckOrigin(fn func(*http.Request) bool) WsConfOption { - return func(conf config.WSServerConf) { + return func(conf config.WSServerFullConf) { conf.CheckOrigin = fn } } // ListenWebsocket 开始监听Websocket -func (ngin *Engine) ListenWebsocket(conf config.WSServerConf, opts ...WsConfOption) error { +func (ngin *Engine) ListenWebsocket(conf config.WSServerFullConf, opts ...WsConfOption) error { for _, opt := range opts { opt(conf) } @@ -51,7 +51,7 @@ func (ngin *Engine) handleWS(conn *websocket.Conn) { ngin.handle(wsConn) } -func (ngin *Engine) upgradeWebsocket(conf config.WSServerConf) { +func (ngin *Engine) upgradeWebsocket(conf config.WSServerFullConf) { upgrade := websocket.Upgrader{ HandshakeTimeout: conf.HandshakeTimeout, ReadBufferSize: conf.ReadBufferSize,