package ws import ( "github.com/gorilla/websocket" "github.com/jpillora/backoff" "github.com/pkg/errors" "math/rand" "net/http" "sync" "time" ) type event struct { // 连接成功回调 onConnected func() // 连接异常回调,在准备进行连接的过程中发生异常时触发 onConnectError func(err error) // 连接断开回调,网络异常,服务端掉线等情况时触发 onDisconnected func(err error) // 连接关闭回调,服务端发起关闭信号或客户端主动关闭时触发 onClose func(code int, text string) // 发送Text消息成功回调 onTextMessageSent func(message string) // 发送Binary消息成功回调 onBinaryMessageSent func(data []byte) // 发送消息异常回调 onSentError func(err error) // 接受到Ping消息回调 onPingReceived func(appData string) // 接受到Pong消息回调 onPongReceived func(appData string) // 接受到Text消息回调 onTextMessageReceived func(message string) // 接受到Binary消息回调 onBinaryMessageReceived func(v interface{}) // 接收消息异常回调 onReceiveError func(err error) } type NWebsocket struct { Config *Config // 配置 webSocket *WebSocket // 底层 webSocket *event // evt } type WebSocket struct { wsConn *websocket.Conn // 底层 webSocket 连接 url string // 连接地址,用于重连 dialer *websocket.Dialer // dialer requestHeader http.Header // 连接请求header httpResponse *http.Response // 连接失败的返回消息 packer Packer // Packer 打包解包器 codec Codec // 编解码 readChan chan *Entry // 读channel(队列) sendChan chan *Entry // 写channel(队列) sendLock sync.Mutex // 写消息锁 isConnected bool // 是否已连接 connLock sync.RWMutex // 连接锁,保证只被执行一次 } type Config struct { *BackoffOptions ReadBufferSize int // 最大读缓冲区大小, 默认1024条消息 SendBufferSize int // 最大写缓冲区大小, 默认512条消息 ReadLimit int64 `json:",default=8192"` // 单条消息支持的最大消息长度,默认 8MB WriteDeadline time.Duration // 写超时,默认 5s ReadDeadline time.Duration // 读超时,控制断线检测 } type BackoffOptions struct { MinRecTime time.Duration // 最小重连时间间隔 MaxRecTime time.Duration // 最大重连时间间隔 RecFactor float64 // 每次重连失败继续重连的时间间隔递增的乘数因子,递增到最大重连时间间隔为止 } type ConnectionOptions struct { *Config Packer Packer // 打包解包器 Codec Codec // 编解码器 } type ConnectionOption func(*ConnectionOptions) func applyOpts(opts ...ConnectionOption) *ConnectionOptions { result := &ConnectionOptions{ Config: &Config{ BackoffOptions: &BackoffOptions{ MinRecTime: 2 * time.Second, MaxRecTime: 60 * time.Second, RecFactor: 1.5, }, ReadBufferSize: 1024, SendBufferSize: 512, ReadLimit: 8 << 10 << 10, WriteDeadline: 5 * time.Second, ReadDeadline: 60 * time.Second, }, } for _, opt := range opts { opt(result) } return result } func WithPacker(packer Packer) ConnectionOption { return func(options *ConnectionOptions) { options.Packer = packer } } func WithCodec(codec Codec) ConnectionOption { return func(options *ConnectionOptions) { options.Codec = codec } } func WithBackoff(b *BackoffOptions) ConnectionOption { return func(options *ConnectionOptions) { options.BackoffOptions = b } } var ( CloseErr = errors.New("connection closed") BufferErr = errors.New("message buffer is full") PackerNotFoundErr = errors.New("packer not found") //PackErr = errors.New("packs message err") ) func NewWsConnection(opts ...ConnectionOption) *NWebsocket { opt := applyOpts(opts...) return &NWebsocket{ Config: opt.Config, webSocket: &WebSocket{ dialer: websocket.DefaultDialer, requestHeader: http.Header{}, packer: opt.Packer, codec: opt.Codec, isConnected: false, }, event: &event{}, } } func (c *NWebsocket) OnConnected(f func()) { c.onConnected = f } func (c *NWebsocket) OnConnectError(f func(err error)) { c.onConnectError = f } func (c *NWebsocket) OnDisconnected(f func(err error)) { c.onDisconnected = f } func (c *NWebsocket) OnClose(f func(code int, text string)) { c.onClose = f } func (c *NWebsocket) OnTextMessageSent(f func(message string)) { c.onTextMessageSent = f } func (c *NWebsocket) OnBinaryMessageSent(f func(data []byte)) { c.onBinaryMessageSent = f } func (c *NWebsocket) OnSentError(f func(err error)) { c.onSentError = f } func (c *NWebsocket) OnPingReceived(f func(appData string)) { c.onPingReceived = f } func (c *NWebsocket) OnPongReceived(f func(appData string)) { c.onPongReceived = f } func (c *NWebsocket) OnTextMessageReceived(f func(message string)) { c.onTextMessageReceived = f } func (c *NWebsocket) OnBinaryMessageReceived(f func(v interface{})) { c.onBinaryMessageReceived = f } func (c *NWebsocket) OnReceiveError(f func(err error)) { c.onReceiveError = f } func (c *NWebsocket) Connect(url string) { c.webSocket.sendChan = make(chan *Entry, c.Config.SendBufferSize) c.webSocket.readChan = make(chan *Entry, c.Config.ReadBufferSize) c.webSocket.url = url b := &backoff.Backoff{ Min: c.Config.MinRecTime, Max: c.Config.MaxRecTime, Factor: c.Config.RecFactor, Jitter: true, } rand.Seed(time.Now().UTC().UnixNano()) for { var err error nextRec := b.Duration() c.webSocket.wsConn, c.webSocket.httpResponse, err = c.webSocket.dialer.Dial(c.webSocket.url, c.webSocket.requestHeader) if err != nil { if c.onConnectError != nil { c.onConnectError(err) } // 连接重试 time.Sleep(nextRec) continue } // 设置连接状态 c.webSocket.connLock.Lock() c.webSocket.isConnected = true c.webSocket.connLock.Unlock() // 连接成功回调 if c.onConnected != nil { go c.onConnected() } // 设置支持接受的消息最大长度 c.webSocket.wsConn.SetReadLimit(c.Config.ReadLimit) // 连接关闭回调 defaultCloseHandler := c.webSocket.wsConn.CloseHandler() c.webSocket.wsConn.SetCloseHandler(func(code int, text string) error { result := defaultCloseHandler(code, text) c.clean() if c.onClose != nil { c.onClose(code, text) } return result }) // 收到ping回调 defaultPingHandler := c.webSocket.wsConn.PingHandler() c.webSocket.wsConn.SetPingHandler(func(appData string) error { if c.onPingReceived != nil { c.onPingReceived(appData) } return defaultPingHandler(appData) }) // 收到pong回调 defaultPongHandler := c.webSocket.wsConn.PongHandler() c.webSocket.wsConn.SetPongHandler(func(appData string) error { if c.onPongReceived != nil { c.onPongReceived(appData) } return defaultPongHandler(appData) }) // 读写 go c.readInternalLoop() go c.readLoop() go c.sendLoop() return } } func (c *NWebsocket) SendTextMessage(message string) (err error) { return c.write(&Entry{ MessageType: websocket.TextMessage, Raw: []byte(message), }) } func (c *NWebsocket) SendBinaryMessage(v interface{}) (err error) { // wrap message msg, err := c.wrapSendMessage(v) if err != nil { return err } return c.write(msg) } func (c *NWebsocket) readLoop() { if c.webSocket.packer == nil { if c.onReceiveError != nil { c.onReceiveError(PackerNotFoundErr) } return } for { // pack: ws.Entry -> CustomEntry // decode: CustomEntry -> AnyType select { case entry, ok := <-c.webSocket.readChan: if !ok { break } if entry.MessageType == websocket.TextMessage { if c.onTextMessageReceived != nil { go c.onTextMessageReceived(string(entry.Raw)) } break } packed, err := c.webSocket.packer.Pack(entry) if err != nil { if c.onReceiveError != nil { c.onReceiveError(errors.Wrapf(err, "pack msg err: %+v", err)) } break } if e, ok := packed.([]*Entry); ok { // 解包之后是 Entry Slice,全部丢回队列里继续处理 c.pushIn(e...) break } msg := packed if c.webSocket.codec != nil { msg, err = c.webSocket.codec.Decode(packed) if err != nil { if c.onReceiveError != nil { c.onReceiveError(errors.Wrapf(err, "decode msg err: %+v", err)) } break } } if c.onBinaryMessageReceived != nil { go c.onBinaryMessageReceived(msg) } } } } func (c *NWebsocket) write(entry *Entry) (err error) { if c.Closed() { return CloseErr } // 写缓冲区 select { case c.webSocket.sendChan <- entry: default: err = BufferErr } return nil } func (c *NWebsocket) Closed() bool { c.webSocket.connLock.RLock() defer c.webSocket.connLock.RUnlock() return !c.webSocket.isConnected } func (c *NWebsocket) Close() { c.CloseWithMsg("") } func (c *NWebsocket) CloseWithMsg(msg string) { if c.Closed() { return } _ = c.send(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, msg)) c.clean() if c.onClose != nil { c.onClose(websocket.CloseNormalClosure, msg) } } func (c *NWebsocket) wrapSendMessage(v interface{}) (*Entry, error) { if c.webSocket.packer == nil { return nil, PackerNotFoundErr } customEntry := v if c.webSocket.codec != nil { encode, err := c.webSocket.codec.Encode(v) if err != nil { return nil, err } customEntry = encode } msg, err := c.webSocket.packer.Unpack(customEntry) if err != nil { return nil, err } return msg, nil } // CloseAndReConnect 断线重连 func (c *NWebsocket) CloseAndReConnect() { if c.Closed() { return } c.clean() go c.Connect(c.webSocket.url) } func (c *NWebsocket) readInternalLoop() { for { select { default: //_ = c.webSocket.wsConn.SetReadDeadline(time.Now().Add(c.Config.ReadDeadline)) msgType, data, err := c.webSocket.wsConn.ReadMessage() if err != nil { if c.onDisconnected != nil { c.onDisconnected(err) } // 断线重连 c.CloseAndReConnect() return } entry := &Entry{ MessageType: msgType, Raw: data, } c.pushIn(entry) } } } func (c *NWebsocket) pushIn(entries ...*Entry) { pushTimeout := func(entry *Entry) { // 5s timeout after := time.NewTicker(5 * time.Second) defer after.Stop() select { case <-after.C: return case c.webSocket.readChan <- entry: return } } // 批量 push, 顺序无关 for _, entry := range entries { go pushTimeout(entry) } } func (c *NWebsocket) sendLoop() { for { select { case data, ok := <-c.webSocket.sendChan: if !ok { return } if err := c.send(data.MessageType, data.Raw); err != nil { if c.onSentError != nil { c.onSentError(err) } continue } } } } func (c *NWebsocket) send(messageType int, data []byte) error { c.webSocket.sendLock.Lock() defer c.webSocket.sendLock.Unlock() if c.Closed() { return CloseErr } _ = c.webSocket.wsConn.SetWriteDeadline(time.Now().Add(c.Config.WriteDeadline)) return c.webSocket.wsConn.WriteMessage(messageType, data) } // clean 关闭连接 并清理所有能清理的channel func (c *NWebsocket) clean() { c.webSocket.connLock.Lock() _ = c.webSocket.wsConn.Close() close(c.webSocket.readChan) close(c.webSocket.sendChan) c.webSocket.isConnected = false c.webSocket.connLock.Unlock() }