diff --git a/conn/conn_mgr.go b/conn/conn_mgr.go index 8629a56..368bcfb 100644 --- a/conn/conn_mgr.go +++ b/conn/conn_mgr.go @@ -35,7 +35,22 @@ func (m *Manager) Store(groupName string, s entity.NetworkEntity) error { return group.Add(s) } -func (m *Manager) Remove(groupName string, s entity.NetworkEntity) error { +func (m *Manager) Remove(s entity.NetworkEntity) error { + m.Lock() + defer m.Unlock() + delete(m.conns, s.Session().ID()) + + // 从所有group中移除 + for _, group := range m.groups { + err := group.Leave(s) + if err != nil { + return err + } + } + return nil +} + +func (m *Manager) RemoveFromGroup(groupName string, s entity.NetworkEntity) error { m.Lock() delete(m.conns, s.Session().ID()) m.Unlock() diff --git a/conn/group.go b/conn/group.go index b9d5b59..2df6a89 100644 --- a/conn/group.go +++ b/conn/group.go @@ -203,15 +203,15 @@ func (c *Group) LeaveAll() error { } // 使用移位法移除group中与name匹配的元素 -func (c *Group) removeGroupAttr(group []string) []string { +func (c *Group) removeGroupAttr(groups []string) []string { j := 0 - for _, v := range group { + for _, v := range groups { if v != c.name { - group[j] = v + groups[j] = v j++ } } - return group[:j] + return groups[:j] } // Count get current member amount in the group diff --git a/core/connection.go b/core/connection.go index d8e7b90..5df9fc4 100644 --- a/core/connection.go +++ b/core/connection.go @@ -3,7 +3,6 @@ package core import ( "errors" "fmt" - "git.noahlan.cn/noahlan/nnet/conn" "git.noahlan.cn/noahlan/nnet/entity" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/nnet/scheduler" @@ -267,9 +266,9 @@ func (r *connection) Close() error { case <-r.chDie: default: close(r.chDie) - scheduler.PushTask(func() { Lifetime.Close(r) }) + scheduler.PushTask(func() { r.ngin.lifetime.Close(r) }) } - _ = r.ngin.connManager.Remove(conn.DefaultGroupName, r) + _ = r.ngin.connManager.Remove(r) r.session.Close() return r.conn.Close() diff --git a/core/engine.go b/core/engine.go index aadf1c0..99f3c45 100644 --- a/core/engine.go +++ b/core/engine.go @@ -6,6 +6,7 @@ import ( "git.noahlan.cn/noahlan/nnet/config" conn2 "git.noahlan.cn/noahlan/nnet/conn" "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/lifetime" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/nnet/pipeline" "git.noahlan.cn/noahlan/nnet/scheduler" @@ -43,6 +44,7 @@ type ( dieChan chan struct{} pipeline pipeline.Pipeline // 消息管道 + lifetime *lifetime.Mgr // 连接的生命周期管理器 packerFn packet.NewPackerFunc // 封包、拆包器 serializer serialize.Serializer // 消息 序列化/反序列化 @@ -72,6 +74,7 @@ func newEngine(conf config.EngineConf) *engine { taskTimerPrecision: conf.TaskTimerPrecision, connManager: conn2.NewManager(), sessIdMgr: newSessionIDMgr(), + lifetime: lifetime.NewLifetime(), } pool.InitPool(conf.Pool) @@ -134,7 +137,7 @@ func (ng *engine) dial(addr string, router Router) (entity.NetworkEntity, error) c := newConnection(ng, conn) c.serve() // hook - Lifetime.Open(c) + ng.lifetime.Open(c) // connection manager err = ng.connManager.Store(conn2.DefaultGroupName, c) nlog.Must(err) @@ -268,7 +271,7 @@ func (ng *engine) handle(conn net.Conn) { c.serve() // hook - Lifetime.Open(c) + ng.lifetime.Open(c) } func (ng *engine) notFoundHandler(next Handler) Handler { diff --git a/core/lifetime.go b/core/lifetime.go deleted file mode 100644 index 97c2ebe..0000000 --- a/core/lifetime.go +++ /dev/null @@ -1,42 +0,0 @@ -package core - -import "git.noahlan.cn/noahlan/nnet/entity" - -type ( - LifetimeHandler func(entity entity.NetworkEntity) - - lifetime struct { - onOpen []LifetimeHandler - onClosed []LifetimeHandler - } -) - -var Lifetime = &lifetime{} - -func (lt *lifetime) OnClosed(h LifetimeHandler) { - lt.onClosed = append(lt.onClosed, h) -} - -func (lt *lifetime) OnOpen(h LifetimeHandler) { - lt.onOpen = append(lt.onOpen, h) -} - -func (lt *lifetime) Open(entity entity.NetworkEntity) { - if len(lt.onOpen) <= 0 { - return - } - - for _, handler := range lt.onOpen { - handler(entity) - } -} - -func (lt *lifetime) Close(entity entity.NetworkEntity) { - if len(lt.onClosed) <= 0 { - return - } - - for _, handler := range lt.onClosed { - handler(entity) - } -} diff --git a/core/nnet.go b/core/nnet.go index c8c1ee6..deb34ba 100644 --- a/core/nnet.go +++ b/core/nnet.go @@ -4,6 +4,7 @@ import ( "git.noahlan.cn/noahlan/nnet/config" "git.noahlan.cn/noahlan/nnet/conn" "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/lifetime" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/nnet/pipeline" "git.noahlan.cn/noahlan/nnet/serialize" @@ -110,6 +111,11 @@ func (s *NNet) Pipeline() pipeline.Pipeline { return s.ngin.pipeline } +// Lifetime returns lifetime interface. +func (s *NNet) Lifetime() lifetime.Lifetime { + return s.ngin.lifetime +} + // ConnManager returns connection manager func (s *NNet) ConnManager() *conn.Manager { return s.ngin.connManager diff --git a/lifetime/lifetime.go b/lifetime/lifetime.go new file mode 100644 index 0000000..aef5ddb --- /dev/null +++ b/lifetime/lifetime.go @@ -0,0 +1,52 @@ +package lifetime + +import "git.noahlan.cn/noahlan/nnet/entity" + +type ( + Handler func(entity entity.NetworkEntity) + + Lifetime interface { + OnClosed(h Handler) + OnOpen(h Handler) + } + + Mgr struct { + onOpen []Handler + onClosed []Handler + } +) + +func NewLifetime() *Mgr { + return &Mgr{ + onOpen: make([]Handler, 0), + onClosed: make([]Handler, 0), + } +} + +func (lt *Mgr) OnClosed(h Handler) { + lt.onClosed = append(lt.onClosed, h) +} + +func (lt *Mgr) OnOpen(h Handler) { + lt.onOpen = append(lt.onOpen, h) +} + +func (lt *Mgr) Open(entity entity.NetworkEntity) { + if len(lt.onOpen) <= 0 { + return + } + + for _, handler := range lt.onOpen { + handler(entity) + } +} + +func (lt *Mgr) Close(entity entity.NetworkEntity) { + if len(lt.onClosed) <= 0 { + return + } + + for _, handler := range lt.onClosed { + handler(entity) + } +} diff --git a/middleware/heartbeat.go b/middleware/heartbeat.go index 5c2e4e7..d06e98b 100644 --- a/middleware/heartbeat.go +++ b/middleware/heartbeat.go @@ -25,9 +25,10 @@ func WithHeartbeat(interval time.Duration, hbdFn func(entity entity.NetworkEntit nlog.Error("dataFn must not be nil") panic("dataFn must not be nil") } - core.Lifetime.OnOpen(m.start) return func(server *core.NNet) { + server.Lifetime().OnOpen(m.start) + server.Use(func(next core.HandlerFunc) core.HandlerFunc { return func(entity entity.NetworkEntity, pkg packet.IPacket) { m.handle(entity, pkg)