diff --git a/middleware/heartbeat_ws.go b/middleware/heartbeat_ws.go new file mode 100644 index 0000000..da2a489 --- /dev/null +++ b/middleware/heartbeat_ws.go @@ -0,0 +1,78 @@ +package middleware + +import ( + "git.noahlan.cn/noahlan/nnet" + "git.noahlan.cn/noahlan/nnet/conn" + "git.noahlan.cn/noahlan/nnet/event" + "git.noahlan.cn/noahlan/nnet/packet" + rt "git.noahlan.cn/noahlan/nnet/router" + "git.noahlan.cn/noahlan/ntool/nlog" + "sync/atomic" + "time" +) + +type ( + HeartbeatWsMiddleware struct { + lastAt int64 + interval time.Duration + hbdFn WsHeartbeatFn + } + WsHeartbeatFn func(conn *conn.WSConn) error +) + +func WithHeartbeatWS(interval time.Duration, hbdFn WsHeartbeatFn) nnet.RunOption { + m := &HeartbeatWsMiddleware{ + lastAt: time.Now().Unix(), + interval: interval, + hbdFn: hbdFn, + } + if hbdFn == nil { + nlog.Error("dataFn must not be nil") + panic("dataFn must not be nil") + } + + return func(ngin *nnet.Engine) { + ngin.EventManager().RegisterEvent(event.EvtOnConnected, m.start) + + ngin.Use(func(next rt.HandlerFunc) rt.HandlerFunc { + return func(conn *conn.Connection, pkg packet.IPacket) { + m.handle(conn, pkg) + + next(conn, pkg) + } + }) + } +} + +func (m *HeartbeatWsMiddleware) start(nc *conn.Connection) { + ticker := time.NewTicker(m.interval) + + defer func() { + ticker.Stop() + }() + + for { + select { + case <-ticker.C: + if nc.Type() != conn.ConnTypeWS { + break + } + + deadline := time.Now().Add(-2 * m.interval).Unix() + if atomic.LoadInt64(&m.lastAt) < deadline { + nlog.Errorf("Heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&m.lastAt), deadline) + return + } + + err := m.hbdFn(nc.WsConn()) + if err != nil { + nlog.Errorf("Heartbeat err: %v", err) + return + } + } + } +} + +func (m *HeartbeatWsMiddleware) handle(_ *conn.Connection, _ packet.IPacket) { + atomic.StoreInt64(&m.lastAt, time.Now().Unix()) +}