You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
479 lines
11 KiB
Go
479 lines
11 KiB
Go
package ws
|
|
|
|
import (
|
|
"github.com/gorilla/websocket"
|
|
"github.com/jpillora/backoff"
|
|
"github.com/pkg/errors"
|
|
"math/rand"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// event stores the optional user callbacks of an NWebsocket. Any field may
// be nil; every place that invokes a callback checks for nil first.
type event struct {
	// Connection lifecycle.
	onConnected    func()                      // a dial succeeded
	onConnectError func(err error)             // a dial attempt failed
	onDisconnected func(err error)             // reading from the connection failed (network error, server gone, ...)
	onClose        func(code int, text string) // the server sent a close frame, or the client closed the connection

	// Outgoing messages.
	onTextMessageSent   func(message string) // a text message was sent successfully
	onBinaryMessageSent func(data []byte)    // a binary message was sent successfully
	onSentError         func(err error)      // sending a message failed

	// Control frames.
	onPingReceived func(appData string) // a ping frame arrived
	onPongReceived func(appData string) // a pong frame arrived

	// Incoming messages.
	onTextMessageReceived   func(message string) // a text message arrived
	onBinaryMessageReceived func(v interface{})  // a non-text message arrived, after Packer/Codec processing
	onReceiveError          func(err error)      // a received message could not be processed
}
|
|
|
|
type NWebsocket struct {
|
|
Config *Config // 配置
|
|
webSocket *WebSocket // 底层 webSocket
|
|
*event // evt
|
|
}
|
|
|
|
type WebSocket struct {
|
|
wsConn *websocket.Conn // 底层 webSocket 连接
|
|
url string // 连接地址,用于重连
|
|
dialer *websocket.Dialer // dialer
|
|
requestHeader http.Header // 连接请求header
|
|
httpResponse *http.Response // 连接失败的返回消息
|
|
|
|
packer Packer // Packer 打包解包器
|
|
codec Codec // 编解码
|
|
|
|
readChan chan *Entry // 读channel(队列)
|
|
sendChan chan *Entry // 写channel(队列)
|
|
sendLock sync.Mutex // 写消息锁
|
|
|
|
isConnected bool // 是否已连接
|
|
connLock sync.RWMutex // 连接锁,保证只被执行一次
|
|
}
|
|
|
|
type Config struct {
|
|
*BackoffOptions
|
|
ReadBufferSize int // 最大读缓冲区大小, 默认1024条消息
|
|
SendBufferSize int // 最大写缓冲区大小, 默认512条消息
|
|
|
|
ReadLimit int64 `json:",default=8192"` // 单条消息支持的最大消息长度,默认 8MB
|
|
WriteDeadline time.Duration // 写超时,默认 5s
|
|
ReadDeadline time.Duration // 读超时,控制断线检测
|
|
}
|
|
|
|
type BackoffOptions struct {
|
|
MinRecTime time.Duration // 最小重连时间间隔
|
|
MaxRecTime time.Duration // 最大重连时间间隔
|
|
RecFactor float64 // 每次重连失败继续重连的时间间隔递增的乘数因子,递增到最大重连时间间隔为止
|
|
}
|
|
|
|
type ConnectionOptions struct {
|
|
*Config
|
|
Packer Packer // 打包解包器
|
|
Codec Codec // 编解码器
|
|
}
|
|
|
|
// ConnectionOption customizes the ConnectionOptions that NewWsConnection
// builds; see WithPacker and WithBackoff.
type ConnectionOption func(*ConnectionOptions)
|
|
|
|
// applyOpts returns ConnectionOptions pre-filled with the package defaults
// and then modified by each opt in order.
//
// Defaults: backoff 2s..60s with factor 1.5, read queue 1024, send queue 512,
// read limit 8 MiB, write and read deadlines 5s.
//
// Nil options are skipped. If an option leaves Config or its BackoffOptions
// nil, the defaults are restored, because Connect dereferences both without
// checking.
func applyOpts(opts ...ConnectionOption) *ConnectionOptions {
	newDefaultBackoff := func() *BackoffOptions {
		return &BackoffOptions{
			MinRecTime: 2 * time.Second,
			MaxRecTime: 60 * time.Second,
			RecFactor:  1.5,
		}
	}
	newDefaultConfig := func() *Config {
		return &Config{
			BackoffOptions: newDefaultBackoff(),
			ReadBufferSize: 1024,
			SendBufferSize: 512,
			ReadLimit:      8 << 20, // 8 MiB
			WriteDeadline:  5 * time.Second,
			ReadDeadline:   5 * time.Second,
		}
	}

	result := &ConnectionOptions{Config: newDefaultConfig()}
	for _, opt := range opts {
		if opt != nil {
			opt(result)
		}
	}

	// Guard against options that cleared required settings.
	if result.Config == nil {
		result.Config = newDefaultConfig()
	}
	if result.BackoffOptions == nil {
		result.BackoffOptions = newDefaultBackoff()
	}
	return result
}
|
|
|
|
func WithPacker(packer Packer) ConnectionOption {
|
|
return func(options *ConnectionOptions) {
|
|
options.Packer = packer
|
|
}
|
|
}
|
|
|
|
func WithBackoff(b *BackoffOptions) ConnectionOption {
|
|
return func(options *ConnectionOptions) {
|
|
options.BackoffOptions = b
|
|
}
|
|
}
|
|
|
|
var (
|
|
CloseErr = errors.New("connection closed")
|
|
BufferErr = errors.New("message buffer is full")
|
|
PackerNotFoundErr = errors.New("packer not found")
|
|
//PackErr = errors.New("packs message err")
|
|
)
|
|
|
|
// NewWsConnection builds a client from the defaults plus opts. The returned
// client is not connected yet; call Connect to dial.
func NewWsConnection(opts ...ConnectionOption) *NWebsocket {
	options := applyOpts(opts...)

	// isConnected starts at its zero value, false.
	ws := &WebSocket{
		dialer:        websocket.DefaultDialer,
		requestHeader: http.Header{},
		packer:        options.Packer,
		codec:         options.Codec,
	}

	return &NWebsocket{
		Config:    options.Config,
		webSocket: ws,
		event:     new(event),
	}
}
|
|
|
|
func (c *NWebsocket) OnConnected(f func()) {
|
|
c.onConnected = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnConnectError(f func(err error)) {
|
|
c.onConnectError = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnDisconnected(f func(err error)) {
|
|
c.onDisconnected = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnClose(f func(code int, text string)) {
|
|
c.onClose = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnTextMessageSent(f func(message string)) {
|
|
c.onTextMessageSent = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnBinaryMessageSent(f func(data []byte)) {
|
|
c.onBinaryMessageSent = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnSentError(f func(err error)) {
|
|
c.onSentError = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnPingReceived(f func(appData string)) {
|
|
c.onPingReceived = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnPongReceived(f func(appData string)) {
|
|
c.onPongReceived = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnTextMessageReceived(f func(message string)) {
|
|
c.onTextMessageReceived = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnBinaryMessageReceived(f func(v interface{})) {
|
|
c.onBinaryMessageReceived = f
|
|
}
|
|
|
|
func (c *NWebsocket) OnReceiveError(f func(err error)) {
|
|
c.onReceiveError = f
|
|
}
|
|
|
|
func (c *NWebsocket) Connect(url string) {
|
|
c.webSocket.sendChan = make(chan *Entry, c.Config.SendBufferSize)
|
|
c.webSocket.readChan = make(chan *Entry, c.Config.ReadBufferSize)
|
|
c.webSocket.url = url
|
|
|
|
b := &backoff.Backoff{
|
|
Min: c.Config.MinRecTime,
|
|
Max: c.Config.MaxRecTime,
|
|
Factor: c.Config.RecFactor,
|
|
Jitter: true,
|
|
}
|
|
rand.Seed(time.Now().UTC().UnixNano())
|
|
|
|
for {
|
|
var err error
|
|
nextRec := b.Duration()
|
|
c.webSocket.wsConn, c.webSocket.httpResponse, err = c.webSocket.dialer.Dial(c.webSocket.url, c.webSocket.requestHeader)
|
|
if err != nil {
|
|
if c.onConnectError != nil {
|
|
c.onConnectError(err)
|
|
}
|
|
// 连接重试
|
|
time.Sleep(nextRec)
|
|
continue
|
|
}
|
|
|
|
// 设置连接状态
|
|
c.webSocket.connLock.Lock()
|
|
c.webSocket.isConnected = true
|
|
c.webSocket.connLock.Unlock()
|
|
|
|
// 连接成功回调
|
|
if c.onConnected != nil {
|
|
go c.onConnected()
|
|
}
|
|
// 设置支持接受的消息最大长度
|
|
c.webSocket.wsConn.SetReadLimit(c.Config.ReadLimit)
|
|
// 连接关闭回调
|
|
defaultCloseHandler := c.webSocket.wsConn.CloseHandler()
|
|
c.webSocket.wsConn.SetCloseHandler(func(code int, text string) error {
|
|
result := defaultCloseHandler(code, text)
|
|
c.clean()
|
|
if c.onClose != nil {
|
|
c.onClose(code, text)
|
|
}
|
|
return result
|
|
})
|
|
|
|
// 收到ping回调
|
|
defaultPingHandler := c.webSocket.wsConn.PingHandler()
|
|
c.webSocket.wsConn.SetPingHandler(func(appData string) error {
|
|
if c.onPingReceived != nil {
|
|
c.onPingReceived(appData)
|
|
}
|
|
return defaultPingHandler(appData)
|
|
})
|
|
// 收到pong回调
|
|
defaultPongHandler := c.webSocket.wsConn.PongHandler()
|
|
c.webSocket.wsConn.SetPongHandler(func(appData string) error {
|
|
if c.onPongReceived != nil {
|
|
c.onPongReceived(appData)
|
|
}
|
|
return defaultPongHandler(appData)
|
|
})
|
|
|
|
// 读写
|
|
go c.readInternalLoop()
|
|
go c.readLoop()
|
|
go c.sendLoop()
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
func (c *NWebsocket) SendTextMessage(message string) (err error) {
|
|
return c.write(&Entry{
|
|
MessageType: websocket.TextMessage,
|
|
Raw: []byte(message),
|
|
})
|
|
}
|
|
|
|
func (c *NWebsocket) SendBinaryMessage(v interface{}) (err error) {
|
|
// wrap message
|
|
msg, err := c.wrapSendMessage(v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.write(msg)
|
|
}
|
|
|
|
func (c *NWebsocket) readLoop() {
|
|
if c.webSocket.packer == nil {
|
|
if c.onReceiveError != nil {
|
|
c.onReceiveError(PackerNotFoundErr)
|
|
}
|
|
return
|
|
}
|
|
for {
|
|
// pack: ws.Entry -> CustomEntry
|
|
// decode: CustomEntry -> AnyType
|
|
select {
|
|
case entry, ok := <-c.webSocket.readChan:
|
|
if !ok {
|
|
break
|
|
}
|
|
if entry.MessageType == websocket.TextMessage {
|
|
if c.onTextMessageReceived != nil {
|
|
go c.onTextMessageReceived(string(entry.Raw))
|
|
}
|
|
break
|
|
}
|
|
packed, err := c.webSocket.packer.Pack(entry)
|
|
if err != nil {
|
|
if c.onReceiveError != nil {
|
|
c.onReceiveError(errors.Wrapf(err, "pack msg err: %+v", err))
|
|
}
|
|
break
|
|
}
|
|
if e, ok := packed.([]*Entry); ok {
|
|
// 解包之后是 Entry Slice,全部丢回队列里继续处理
|
|
c.pushIn(e...)
|
|
break
|
|
}
|
|
msg := packed
|
|
if c.webSocket.codec != nil {
|
|
msg, err = c.webSocket.codec.Decode(packed)
|
|
if err != nil {
|
|
if c.onReceiveError != nil {
|
|
c.onReceiveError(errors.Wrapf(err, "decode msg err: %+v", err))
|
|
}
|
|
}
|
|
break
|
|
}
|
|
if c.onBinaryMessageReceived != nil {
|
|
go c.onBinaryMessageReceived(msg)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *NWebsocket) write(entry *Entry) (err error) {
|
|
if c.Closed() {
|
|
return CloseErr
|
|
}
|
|
// 写缓冲区
|
|
select {
|
|
case c.webSocket.sendChan <- entry:
|
|
default:
|
|
err = BufferErr
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *NWebsocket) Closed() bool {
|
|
c.webSocket.connLock.RLock()
|
|
defer c.webSocket.connLock.RUnlock()
|
|
return !c.webSocket.isConnected
|
|
}
|
|
|
|
func (c *NWebsocket) Close() {
|
|
c.CloseWithMsg("")
|
|
}
|
|
|
|
func (c *NWebsocket) CloseWithMsg(msg string) {
|
|
if c.Closed() {
|
|
return
|
|
}
|
|
_ = c.send(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, msg))
|
|
c.clean()
|
|
if c.onClose != nil {
|
|
c.onClose(websocket.CloseNormalClosure, msg)
|
|
}
|
|
}
|
|
|
|
func (c *NWebsocket) wrapSendMessage(v interface{}) (*Entry, error) {
|
|
if c.webSocket.packer == nil {
|
|
return nil, PackerNotFoundErr
|
|
}
|
|
customEntry := v
|
|
if c.webSocket.codec != nil {
|
|
encode, err := c.webSocket.codec.Encode(v)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
customEntry = encode
|
|
}
|
|
|
|
msg, err := c.webSocket.packer.Unpack(customEntry)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
// closeAndReConnect 断线重连
|
|
func (c *NWebsocket) closeAndReConnect() {
|
|
if c.Closed() {
|
|
return
|
|
}
|
|
c.clean()
|
|
go c.Connect(c.webSocket.url)
|
|
}
|
|
|
|
// readInternalLoop reads frames from the connection that was current when it
// started and queues them for readLoop via pushIn. Each read is bounded by
// Config.ReadDeadline, so a silent peer counts as a read failure.
//
// On the first read error it returns. The error is reported to onDisconnected
// and a reconnect is started only if the client still counts as connected on
// this same connection. After Close or a server close frame, clean() has
// already run and onClose is used instead. If a newer connection has
// replaced this one, the loop must not tear that connection down.
func (c *NWebsocket) readInternalLoop() {
	conn := c.webSocket.wsConn
	for {
		// An error here only means the connection is unusable, which the
		// ReadMessage call below reports anyway.
		_ = conn.SetReadDeadline(time.Now().Add(c.Config.ReadDeadline))

		msgType, data, err := conn.ReadMessage()
		if err != nil {
			if c.Closed() || c.webSocket.wsConn != conn {
				return
			}
			if c.onDisconnected != nil {
				c.onDisconnected(err)
			}
			c.closeAndReConnect()
			return
		}

		c.pushIn(&Entry{MessageType: msgType, Raw: data})
	}
}
|
|
|
|
func (c *NWebsocket) pushIn(entries ...*Entry) {
|
|
pushTimeout := func(entry *Entry) {
|
|
// 5s timeout
|
|
after := time.NewTicker(5 * time.Second)
|
|
defer after.Stop()
|
|
|
|
select {
|
|
case <-after.C:
|
|
return
|
|
case c.webSocket.readChan <- entry:
|
|
return
|
|
}
|
|
}
|
|
|
|
// 批量 push, 顺序无关
|
|
for _, entry := range entries {
|
|
go pushTimeout(entry)
|
|
}
|
|
}
|
|
|
|
func (c *NWebsocket) sendLoop() {
|
|
for {
|
|
select {
|
|
case data, ok := <-c.webSocket.sendChan:
|
|
if !ok {
|
|
return
|
|
}
|
|
if err := c.send(data.MessageType, data.Raw); err != nil {
|
|
if c.onSentError != nil {
|
|
c.onSentError(err)
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// send writes one frame to the connection. It holds sendLock so that only one
// goroutine writes at a time, and uses Config.WriteDeadline as the write
// timeout. It returns CloseErr when the client is not connected.
func (c *NWebsocket) send(messageType int, data []byte) error {
	ws := c.webSocket
	ws.sendLock.Lock()
	defer ws.sendLock.Unlock()

	if c.Closed() {
		return CloseErr
	}

	deadline := time.Now().Add(c.Config.WriteDeadline)
	_ = ws.wsConn.SetWriteDeadline(deadline) // error deliberately ignored
	return ws.wsConn.WriteMessage(messageType, data)
}
|
|
|
|
// clean 关闭连接 并清理所有能清理的channel
|
|
func (c *NWebsocket) clean() {
|
|
c.webSocket.connLock.Lock()
|
|
_ = c.webSocket.wsConn.Close()
|
|
close(c.webSocket.readChan)
|
|
close(c.webSocket.sendChan)
|
|
c.webSocket.isConnected = false
|
|
c.webSocket.connLock.Unlock()
|
|
}
|