You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

485 lines
11 KiB
Go

package ws
import (
"github.com/gorilla/websocket"
"github.com/jpillora/backoff"
"github.com/pkg/errors"
"math/rand"
"net/http"
"sync"
"time"
)
// event holds the optional user callbacks of an NWebsocket. A nil field means
// no callback is registered; every call site checks for nil before calling.
type event struct {
	// Connected callback: run in a new goroutine after each successful dial.
	onConnected func()
	// Connect-error callback: receives the dial error of each failed
	// connection attempt, before the retry delay.
	onConnectError func(err error)
	// Disconnected callback: receives the read error when reading from the
	// connection fails (network error, server gone, etc.).
	onDisconnected func(err error)
	// Close callback: fired when the server sends a close frame, or when the
	// client closes the connection itself via CloseWithMsg.
	onClose func(code int, text string)
	// Callback for a successfully sent text message.
	// NOTE(review): not invoked anywhere in the code shown here — confirm.
	onTextMessageSent func(message string)
	// Callback for a successfully sent binary message.
	// NOTE(review): not invoked anywhere in the code shown here — confirm.
	onBinaryMessageSent func(data []byte)
	// Send-error callback: receives the error when writing a queued message fails.
	onSentError func(err error)
	// Ping callback: receives each ping payload before the default ping handler runs.
	onPingReceived func(appData string)
	// Pong callback: receives each pong payload before the default pong handler runs.
	onPongReceived func(appData string)
	// Text-message callback: run in a new goroutine for each received text frame.
	onTextMessageReceived func(message string)
	// Binary-message callback: run in a new goroutine with the value produced
	// by the Packer (and by the Codec, when one is configured).
	onBinaryMessageReceived func(v interface{})
	// Receive-error callback: reports a missing packer or a pack/decode failure.
	onReceiveError func(err error)
}
// NWebsocket is a WebSocket client that retries dialing until it connects and
// reconnects after read errors. It bundles the configuration, the low-level
// connection state and the registered callbacks.
type NWebsocket struct {
	Config    *Config    // configuration
	webSocket *WebSocket // low-level connection state
	*event               // registered callbacks (embedded, so c.onXxx resolves here)
}
// WebSocket groups the low-level connection state used by NWebsocket.
type WebSocket struct {
	wsConn        *websocket.Conn   // underlying gorilla connection
	url           string            // dial address, kept so a reconnect can reuse it
	dialer        *websocket.Dialer // dialer used for every (re)connect attempt
	requestHeader http.Header       // headers sent with the handshake request
	httpResponse  *http.Response    // handshake response returned by the most recent Dial
	packer        Packer            // converts between *Entry frames and custom values (Pack on read, Unpack on write)
	codec         Codec             // optional codec (Decode on read, Encode on write)
	readChan      chan *Entry       // inbound queue, fed by pushIn and drained by readLoop
	sendChan      chan *Entry       // outbound queue, fed by write and drained by sendLoop
	sendLock      sync.Mutex        // serializes frame writes in send
	isConnected   bool              // connection state; guarded by connLock
	connLock      sync.RWMutex      // guards isConnected; clean() holds it while closing wsConn and both queues
}
// Config holds the connection settings. The defaults quoted below are the
// ones applyOpts sets.
type Config struct {
	*BackoffOptions
	ReadBufferSize int // capacity of the inbound queue, in messages (default 1024)
	SendBufferSize int // capacity of the outbound queue, in messages (default 512)
	// ReadLimit is the maximum size of a single incoming message; Connect
	// passes it to SetReadLimit (applyOpts default: 8 MiB).
	// NOTE(review): the struct tag says default=8192, which disagrees with the
	// 8 MiB default set in applyOpts — confirm which one is intended.
	ReadLimit     int64         `json:",default=8192"`
	WriteDeadline time.Duration // per-write timeout applied in send (default 5s)
	// ReadDeadline is meant to bound reads so dead connections are detected
	// (default 60s). NOTE(review): the SetReadDeadline call in
	// readInternalLoop is commented out, so this value appears unused.
	ReadDeadline time.Duration
}
// BackoffOptions configures the delay between reconnect attempts in Connect.
type BackoffOptions struct {
	MinRecTime time.Duration // minimum delay between reconnect attempts
	MaxRecTime time.Duration // maximum delay between reconnect attempts
	RecFactor  float64       // multiplier applied to the delay after each failed attempt, capped at MaxRecTime
}
// ConnectionOptions collects the settings assembled from the options passed
// to NewWsConnection.
type ConnectionOptions struct {
	*Config
	Packer Packer // packer/unpacker for binary messages
	Codec  Codec  // optional encoder/decoder applied around the packer
}

// ConnectionOption mutates ConnectionOptions; see WithPacker, WithCodec and WithBackoff.
type ConnectionOption func(*ConnectionOptions)
// applyOpts builds ConnectionOptions populated with the defaults and then
// runs each option in order, so later options override earlier ones.
func applyOpts(opts ...ConnectionOption) *ConnectionOptions {
	backoffDefaults := &BackoffOptions{
		MinRecTime: 2 * time.Second,
		MaxRecTime: 60 * time.Second,
		RecFactor:  1.5,
	}
	cfg := &Config{
		BackoffOptions: backoffDefaults,
		ReadBufferSize: 1024,
		SendBufferSize: 512,
		ReadLimit:      8 << 20, // 8 MiB
		WriteDeadline:  5 * time.Second,
		ReadDeadline:   60 * time.Second,
	}
	options := &ConnectionOptions{Config: cfg}
	for i := range opts {
		opts[i](options)
	}
	return options
}
// WithPacker sets the Packer that converts between raw frames and custom
// values. Without one, binary messages cannot be sent or received.
func WithPacker(packer Packer) ConnectionOption {
	apply := func(o *ConnectionOptions) { o.Packer = packer }
	return apply
}
// WithCodec sets the Codec applied around the Packer: Encode before Unpack
// when sending, Decode after Pack when receiving.
func WithCodec(codec Codec) ConnectionOption {
	apply := func(o *ConnectionOptions) { o.Codec = codec }
	return apply
}
// WithBackoff replaces the reconnect backoff settings. A nil b is ignored and
// the defaults are kept: storing nil would leave Config with a nil embedded
// *BackoffOptions, and Connect would panic when it reads MinRecTime.
func WithBackoff(b *BackoffOptions) ConnectionOption {
	return func(options *ConnectionOptions) {
		if b == nil {
			return
		}
		options.BackoffOptions = b
	}
}
// Sentinel errors returned or reported by NWebsocket.
var (
	// CloseErr is returned when writing to a client that is not connected.
	CloseErr = errors.New("connection closed")
	// BufferErr signals that the outbound message buffer is full.
	BufferErr = errors.New("message buffer is full")
	// PackerNotFoundErr is used when a binary message needs a Packer but none
	// was configured.
	PackerNotFoundErr = errors.New("packer not found")
)
// NewWsConnection creates an unconnected client configured by opts (see
// applyOpts for the defaults). It uses websocket.DefaultDialer and an empty
// handshake header; call Connect to dial.
func NewWsConnection(opts ...ConnectionOption) *NWebsocket {
	options := applyOpts(opts...)
	ws := &WebSocket{
		dialer:        websocket.DefaultDialer,
		requestHeader: make(http.Header),
		packer:        options.Packer,
		codec:         options.Codec,
	}
	return &NWebsocket{
		Config:    options.Config,
		webSocket: ws,
		event:     new(event),
	}
}
// OnConnected registers f to run in a new goroutine after each successful
// connect, including reconnects.
func (c *NWebsocket) OnConnected(f func()) {
	c.event.onConnected = f
}
// OnConnectError registers f to receive the dial error of every failed
// connection attempt; Connect keeps retrying afterwards.
func (c *NWebsocket) OnConnectError(f func(err error)) {
	c.event.onConnectError = f
}
// OnDisconnected registers f to receive the read error when reading from the
// connection fails; CloseAndReConnect is invoked right after it.
func (c *NWebsocket) OnDisconnected(f func(err error)) {
	c.event.onDisconnected = f
}
// OnClose registers f to receive the close code and text when the server
// sends a close frame, or when the client closes via CloseWithMsg.
func (c *NWebsocket) OnClose(f func(code int, text string)) {
	c.event.onClose = f
}
// OnTextMessageSent registers f as the callback for a successfully sent text
// message.
// NOTE(review): nothing in the code shown here invokes this callback — confirm.
func (c *NWebsocket) OnTextMessageSent(f func(message string)) {
	c.event.onTextMessageSent = f
}
// OnBinaryMessageSent registers f as the callback for a successfully sent
// binary message.
// NOTE(review): nothing in the code shown here invokes this callback — confirm.
func (c *NWebsocket) OnBinaryMessageSent(f func(data []byte)) {
	c.event.onBinaryMessageSent = f
}
// OnSentError registers f to receive the error whenever writing a queued
// message to the connection fails.
func (c *NWebsocket) OnSentError(f func(err error)) {
	c.event.onSentError = f
}
// OnPingReceived registers f to receive each ping payload; it runs before the
// connection's default ping handler.
func (c *NWebsocket) OnPingReceived(f func(appData string)) {
	c.event.onPingReceived = f
}
// OnPongReceived registers f to receive each pong payload; it runs before the
// connection's default pong handler.
func (c *NWebsocket) OnPongReceived(f func(appData string)) {
	c.event.onPongReceived = f
}
// OnTextMessageReceived registers f to run in a new goroutine for each
// received text frame.
func (c *NWebsocket) OnTextMessageReceived(f func(message string)) {
	c.event.onTextMessageReceived = f
}
// OnBinaryMessageReceived registers f to run in a new goroutine with each
// non-text message, after it has been converted by the Packer and, when one is
// configured, decoded by the Codec.
func (c *NWebsocket) OnBinaryMessageReceived(f func(v interface{})) {
	c.event.onBinaryMessageReceived = f
}
// OnReceiveError registers f to receive errors hit while processing incoming
// messages: a missing Packer, a Pack failure or a Decode failure.
func (c *NWebsocket) OnReceiveError(f func(err error)) {
	c.event.onReceiveError = f
}
// Connect dials url and blocks until a connection is established. Failed
// dials are reported through onConnectError and retried indefinitely, waiting
// between attempts according to a jittered backoff built from
// Config.BackoffOptions (MinRecTime, MaxRecTime, RecFactor).
//
// On success it marks the client as connected, starts onConnected in a new
// goroutine, applies Config.ReadLimit, wraps the connection's close, ping and
// pong handlers so the user callbacks also run, and finally starts the
// readInternalLoop, readLoop and sendLoop goroutines.
//
// The inbound and outbound queues are recreated on every call, which is what
// allows CloseAndReConnect to call Connect again after clean() closed them.
// NOTE(review): the retry loop cannot be cancelled; if url never becomes
// reachable, Connect never returns.
func (c *NWebsocket) Connect(url string) {
	c.webSocket.sendChan = make(chan *Entry, c.Config.SendBufferSize)
	c.webSocket.readChan = make(chan *Entry, c.Config.ReadBufferSize)
	c.webSocket.url = url
	b := &backoff.Backoff{
		Min:    c.Config.MinRecTime,
		Max:    c.Config.MaxRecTime,
		Factor: c.Config.RecFactor,
		Jitter: true,
	}
	// Seeds the global math/rand source, presumably for the backoff jitter.
	// NOTE(review): rand.Seed is deprecated since Go 1.20 — confirm whether
	// the backoff package still relies on the global source.
	rand.Seed(time.Now().UTC().UnixNano())
	for {
		var err error
		// Advance the backoff on every attempt; the delay is only used on failure.
		nextRec := b.Duration()
		c.webSocket.wsConn, c.webSocket.httpResponse, err = c.webSocket.dialer.Dial(c.webSocket.url, c.webSocket.requestHeader)
		if err != nil {
			if c.onConnectError != nil {
				c.onConnectError(err)
			}
			// Wait, then retry the dial.
			time.Sleep(nextRec)
			continue
		}
		// Mark the client as connected.
		c.webSocket.connLock.Lock()
		c.webSocket.isConnected = true
		c.webSocket.connLock.Unlock()
		// Connected callback, run asynchronously so it cannot block setup.
		if c.onConnected != nil {
			go c.onConnected()
		}
		// Maximum accepted size of a single incoming message.
		c.webSocket.wsConn.SetReadLimit(c.Config.ReadLimit)
		// Close handler: run the default handler first, then tear down via
		// clean() and notify onClose; the default handler's result is returned.
		defaultCloseHandler := c.webSocket.wsConn.CloseHandler()
		c.webSocket.wsConn.SetCloseHandler(func(code int, text string) error {
			result := defaultCloseHandler(code, text)
			c.clean()
			if c.onClose != nil {
				c.onClose(code, text)
			}
			return result
		})
		// Ping handler: notify onPingReceived, then delegate to the default handler.
		defaultPingHandler := c.webSocket.wsConn.PingHandler()
		c.webSocket.wsConn.SetPingHandler(func(appData string) error {
			if c.onPingReceived != nil {
				c.onPingReceived(appData)
			}
			return defaultPingHandler(appData)
		})
		// Pong handler: notify onPongReceived, then delegate to the default handler.
		defaultPongHandler := c.webSocket.wsConn.PongHandler()
		c.webSocket.wsConn.SetPongHandler(func(appData string) error {
			if c.onPongReceived != nil {
				c.onPongReceived(appData)
			}
			return defaultPongHandler(appData)
		})
		// Start the reader, dispatcher and writer goroutines for this connection.
		go c.readInternalLoop()
		go c.readLoop()
		go c.sendLoop()
		return
	}
}
// SendTextMessage queues message to be sent as a WebSocket text frame. The
// string is sent as-is (neither Codec nor Packer is involved); the result of
// write is returned.
func (c *NWebsocket) SendTextMessage(message string) (err error) {
	entry := &Entry{Raw: []byte(message), MessageType: websocket.TextMessage}
	return c.write(entry)
}
// SendBinaryMessage converts v into an outbound entry via wrapSendMessage
// (Codec, then Packer) and queues it for sending. Wrapping errors are returned
// unchanged; otherwise the result of write is returned.
func (c *NWebsocket) SendBinaryMessage(v interface{}) (err error) {
	var entry *Entry
	if entry, err = c.wrapSendMessage(v); err != nil {
		return err
	}
	return c.write(entry)
}
// readLoop drains the inbound queue of the current connection and dispatches
// each entry:
//   - text frames are delivered to onTextMessageReceived (no Packer needed);
//   - any other frame is converted by the Packer (ws.Entry -> custom value).
//     If Pack yields []*Entry, those entries are re-queued via pushIn;
//     otherwise the value is decoded by the Codec (when configured) and
//     delivered to onBinaryMessageReceived.
//
// Message callbacks run in their own goroutines. Failures (missing Packer,
// Pack or Decode errors) are reported synchronously through onReceiveError,
// and the loop moves on to the next entry. A missing Packer only affects
// binary frames, so text messages are still delivered without one.
//
// The loop ranges over the channel it was started with and therefore returns
// as soon as clean() closes it. Re-reading the field, or breaking only out of a
// select, would leave the goroutine spinning on the closed channel or attached
// to the queue of a later connection.
func (c *NWebsocket) readLoop() {
	readChan := c.webSocket.readChan
	for entry := range readChan {
		if entry.MessageType == websocket.TextMessage {
			if c.onTextMessageReceived != nil {
				go c.onTextMessageReceived(string(entry.Raw))
			}
			continue
		}
		if c.webSocket.packer == nil {
			if c.onReceiveError != nil {
				c.onReceiveError(PackerNotFoundErr)
			}
			continue
		}
		packed, err := c.webSocket.packer.Pack(entry)
		if err != nil {
			if c.onReceiveError != nil {
				c.onReceiveError(errors.Wrap(err, "pack msg err"))
			}
			continue
		}
		if entries, ok := packed.([]*Entry); ok {
			// Packing split the frame into several entries; queue them for another pass.
			c.pushIn(entries...)
			continue
		}
		msg := packed
		if c.webSocket.codec != nil {
			msg, err = c.webSocket.codec.Decode(packed)
			if err != nil {
				if c.onReceiveError != nil {
					c.onReceiveError(errors.Wrap(err, "decode msg err"))
				}
				continue
			}
		}
		if c.onBinaryMessageReceived != nil {
			go c.onBinaryMessageReceived(msg)
		}
	}
}
// write queues entry on the outbound buffer for sendLoop without blocking.
// It returns CloseErr if the client is not connected, and BufferErr if the
// buffer is full, so the caller learns the message was dropped.
//
// The connection check and the channel send both happen under connLock's read
// lock. clean() closes sendChan while holding the write lock, so the channel
// cannot be closed between the check and the send; a send on a closed channel
// would panic. The send is non-blocking, so the read lock is held only briefly.
func (c *NWebsocket) write(entry *Entry) (err error) {
	c.webSocket.connLock.RLock()
	defer c.webSocket.connLock.RUnlock()
	if !c.webSocket.isConnected {
		return CloseErr
	}
	select {
	case c.webSocket.sendChan <- entry:
		return nil
	default:
		return BufferErr
	}
}
// Closed reports whether the client is currently disconnected, reading the
// connection flag under connLock's read lock.
func (c *NWebsocket) Closed() bool {
	ws := c.webSocket
	ws.connLock.RLock()
	connected := ws.isConnected
	ws.connLock.RUnlock()
	return !connected
}
// Close closes the connection with an empty close reason; it is shorthand for
// CloseWithMsg("").
func (c *NWebsocket) Close() {
	const noReason = ""
	c.CloseWithMsg(noReason)
}
// CloseWithMsg performs a client-initiated close. It writes a close frame with
// code CloseNormalClosure and reason msg (ignoring any write error), tears down
// the connection and queues via clean, then fires onClose with the same code
// and reason. It does nothing if the client is already closed.
func (c *NWebsocket) CloseWithMsg(msg string) {
	if !c.Closed() {
		payload := websocket.FormatCloseMessage(websocket.CloseNormalClosure, msg)
		// Best effort: the connection is torn down whether or not the frame was written.
		_ = c.send(websocket.CloseMessage, payload)
		c.clean()
		if cb := c.onClose; cb != nil {
			cb(websocket.CloseNormalClosure, msg)
		}
	}
}
// wrapSendMessage turns a user value into an outbound *Entry. The value is
// first encoded by the Codec (when one is configured), and the result is then
// unpacked into an Entry by the Packer. It returns PackerNotFoundErr when no
// Packer is configured; Codec and Packer errors are returned unchanged.
func (c *NWebsocket) wrapSendMessage(v interface{}) (*Entry, error) {
	ws := c.webSocket
	if ws.packer == nil {
		return nil, PackerNotFoundErr
	}
	payload := v
	if ws.codec != nil {
		var err error
		if payload, err = ws.codec.Encode(v); err != nil {
			return nil, err
		}
	}
	entry, err := ws.packer.Unpack(payload)
	if err != nil {
		return nil, err
	}
	return entry, nil
}
// CloseAndReConnect tears down the current connection via clean and then
// starts Connect to the same URL in a background goroutine. It does nothing
// when the client is already closed.
func (c *NWebsocket) CloseAndReConnect() {
	if !c.Closed() {
		c.clean()
		url := c.webSocket.url
		go c.Connect(url)
	}
}
// readInternalLoop reads frames from the connection it was started with and
// feeds each one into the inbound queue via pushIn. On the first read error it
// reports onDisconnected, calls CloseAndReConnect and exits. Ping, pong and
// close handlers installed by Connect run inside these reads.
func (c *NWebsocket) readInternalLoop() {
	conn := c.webSocket.wsConn
	for {
		// NOTE(review): no read deadline is set, so Config.ReadDeadline has no
		// effect here. The original code had this call disabled:
		//   _ = conn.SetReadDeadline(time.Now().Add(c.Config.ReadDeadline))
		msgType, data, err := conn.ReadMessage()
		if err != nil {
			if c.onDisconnected != nil {
				c.onDisconnected(err)
			}
			// The connection is gone: tear it down and reconnect.
			c.CloseAndReConnect()
			return
		}
		c.pushIn(&Entry{MessageType: msgType, Raw: data})
	}
}
// pushIn enqueues entries onto the inbound queue for readLoop. Each entry is
// sent from its own goroutine, so the caller never blocks and the relative
// order of entries is not preserved. An entry that cannot be queued within 5s
// is dropped.
//
// Each sender checks the connection and performs the send while holding
// connLock's read lock, and drops the entry if the client is already closed.
// clean() takes the write lock before closing readChan, so a sender can never
// hit a closed channel; that would panic and crash the process. The trade-off
// is that clean() may wait for in-flight senders, for at most the 5s timeout.
func (c *NWebsocket) pushIn(entries ...*Entry) {
	const pushTimeout = 5 * time.Second
	push := func(entry *Entry) {
		c.webSocket.connLock.RLock()
		defer c.webSocket.connLock.RUnlock()
		if !c.webSocket.isConnected {
			return
		}
		// One-shot timeout: a Timer, not a Ticker.
		timer := time.NewTimer(pushTimeout)
		defer timer.Stop()
		select {
		case c.webSocket.readChan <- entry:
		case <-timer.C:
		}
	}
	for _, entry := range entries {
		go push(entry)
	}
}
// sendLoop drains the outbound queue of the current connection and writes each
// entry with send. Write failures are reported through onSentError and the
// loop continues with the next entry.
//
// It ranges over the channel it was started with, so it returns once clean()
// closes that channel. Re-reading c.webSocket.sendChan on every iteration
// could attach a loop from a previous connection to the queue created by a
// reconnect. That leaked the goroutine and let two loops interleave writes.
func (c *NWebsocket) sendLoop() {
	sendChan := c.webSocket.sendChan
	for entry := range sendChan {
		if err := c.send(entry.MessageType, entry.Raw); err != nil && c.onSentError != nil {
			c.onSentError(err)
		}
	}
}
// send writes one frame of the given type directly to the connection,
// serialized by sendLock. It returns CloseErr if the client is closed.
// Otherwise it sets a write deadline of Config.WriteDeadline from now (the
// SetWriteDeadline error is ignored) and returns WriteMessage's error.
func (c *NWebsocket) send(messageType int, data []byte) error {
	ws := c.webSocket
	ws.sendLock.Lock()
	defer ws.sendLock.Unlock()
	if c.Closed() {
		return CloseErr
	}
	deadline := time.Now().Add(c.Config.WriteDeadline)
	_ = ws.wsConn.SetWriteDeadline(deadline)
	return ws.wsConn.WriteMessage(messageType, data)
}
// clean closes the current connection and both queues (inbound and outbound)
// and marks the client as disconnected, all under connLock's write lock.
//
// It is idempotent: if the client is already disconnected it does nothing.
// clean can be reached concurrently or repeatedly (close handler,
// CloseWithMsg, CloseAndReConnect). Without this guard a second call would
// close the channels again and panic.
func (c *NWebsocket) clean() {
	c.webSocket.connLock.Lock()
	defer c.webSocket.connLock.Unlock()
	if !c.webSocket.isConnected {
		return
	}
	c.webSocket.isConnected = false
	// The Close error is ignored: the connection is being discarded either way.
	_ = c.webSocket.wsConn.Close()
	close(c.webSocket.readChan)
	close(c.webSocket.sendChan)
}