diff --git a/core/config.go b/config/config.go similarity index 66% rename from core/config.go rename to config/config.go index 4b5547d..c5600f1 100644 --- a/core/config.go +++ b/config/config.go @@ -1,4 +1,6 @@ -package core +package config + +import "git.noahlan.cn/noahlan/ntools-go/core/pool" const ( // DevMode means development mode. @@ -11,16 +13,21 @@ const ( type ( EngineConf struct { + ServerConf + Pool pool.Config + } + + ServerConf struct { // Protocol 协议名 // "tcp", "tcp4", "tcp6", "unix" or "unixpacket" // 若只想开启IPv4, 使用tcp4即可 - Protocol string + Protocol string `json:",default=tcp4"` // Addr 服务地址 // 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP // 如果端口号不填或端口号为0,例如:"127.0.0.1:" 或 ":0",服务端将选择随机可用端口 - Addr string + Addr string `json:",default=0.0.0.0"` // Name 服务端名称,默认为n-net - Name string - Mode string + Name string `json:",default=n-net"` + Mode string `json:",default=dev,options=[dev,test,prod]"` } ) diff --git a/conn/conn_mgr.go b/conn/conn_mgr.go new file mode 100644 index 0000000..8f0b131 --- /dev/null +++ b/conn/conn_mgr.go @@ -0,0 +1,98 @@ +package conn + +import ( + "git.noahlan.cn/noahlan/nnet/entity" + "sync" +) + +var ConnManager *Manager + +type Manager struct { + sync.RWMutex + + // 分组 + groups map[string]*Group + // 所有 Connection + conns map[int64]entity.NetworkEntity +} + +func init() { + ConnManager = &Manager{ + RWMutex: sync.RWMutex{}, + groups: make(map[string]*Group), + conns: make(map[int64]entity.NetworkEntity), + } +} + +// Store 保存连接,同时加入到指定分组,若给定分组名为空,则不进行分组操作 +func (m *Manager) Store(groupName string, s entity.NetworkEntity) { + m.Lock() + m.conns[s.Session().ID()] = s + m.Unlock() + + group, ok := m.FindGroup(groupName) + if !ok { + group = m.NewGroup(groupName) + } + _ = group.Add(s) +} + +// NewGroup 新增分组,若分组已存在,则返回现有分组 +func (m *Manager) NewGroup(name string) *Group { + m.Lock() + defer m.Unlock() + + group, ok := m.groups[name] + if ok { + return group + } + + group = NewGroup(name) + m.groups[name] = group + + return group +} + +// FindGroup 查找分组 +func (m *Manager) FindGroup(name string) (*Group, bool) { + m.RLock() + defer m.RUnlock() + + g, ok := m.groups[name] + return g, ok +} + +// FindConn 根据连接ID找到连接 +func (m *Manager) FindConn(id int64) (entity.NetworkEntity, bool) { + m.RLock() + defer m.RUnlock() + + e, ok := m.conns[id] + return e, ok +} + +// FindConnByUID 根据连接绑定的UID找到连接 +func (m *Manager) FindConnByUID(uid string) (entity.NetworkEntity, bool) { + m.RLock() + defer m.RUnlock() + + for _, e := range m.conns { + if e.Session().UID() == uid { + return e, true + } + } + return nil, false +} + +// PeekConn 循环所有连接 +// fn 返回true跳过循环,反之一直循环 +func (m *Manager) PeekConn(fn func(id int64, e entity.NetworkEntity) bool) { + m.RLock() + defer m.RUnlock() + + for id, e := range m.conns { + if fn(id, e) { + break + } + } +} diff --git a/conn/group.go b/conn/group.go new file mode 100644 index 0000000..3ad3aa1 --- /dev/null +++ b/conn/group.go @@ -0,0 +1,246 @@ +package conn + +import ( + "errors" + "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "sync" + "sync/atomic" +) + +const groupKey = "NNET_GROUP#" +const DefaultGroupName = "DEFAULT_GROUP" + +const ( + groupStatusWorking = 0 + groupStatusClosed = 1 +) + +var ( + ErrCloseClosedGroup = errors.New("close closed group") + ErrClosedGroup = errors.New("group closed") + DeleteDefaultGroupNotAllow = errors.New("delete default group not allow") +) + +type Group struct { + mu sync.RWMutex + + status int32 // group current status + name string // group name + conns map[int64]entity.NetworkEntity +} + +func NewGroup(name string) *Group { + return &Group{ + mu: sync.RWMutex{}, + status: groupStatusWorking, + name: name, + conns: make(map[int64]entity.NetworkEntity), + } +} + +// Member returns connection by specified uid +func (c *Group) Member(uid string) (entity.NetworkEntity, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + for _, e := range c.conns { + if e.Session().UID() == uid { + return e, true + } + } + + return nil, false +} + +// MemberBySID returns specified sId's connection +func (c *Group) MemberBySID(id int64) (entity.NetworkEntity, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + e, ok := c.conns[id] + return e, ok +} + +func (c *Group) Members() []entity.NetworkEntity { + var resp []entity.NetworkEntity + c.PeekMembers(func(_ int64, e entity.NetworkEntity) bool { + resp = append(resp, e) + return false + }) + return resp +} + +// PeekMembers returns all members in current group +// fn 返回true跳过循环,反之一直循环 +func (c *Group) PeekMembers(fn func(sId int64, e entity.NetworkEntity) bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + for sId, e := range c.conns { + if fn(sId, e) { + break + } + } +} + +// Contains check whether a UID is contained in current group or not +func (c *Group) Contains(uid string) bool { + _, ok := c.Member(uid) + return ok +} + +// Add session to group +func (c *Group) Add(e entity.NetworkEntity) error { + if c.isClosed() { + return ErrClosedGroup + } + + c.mu.Lock() + defer c.mu.Unlock() + + sess := e.Session() + id := sess.ID() + + // group attribute + if sess.Exists(groupKey) { + groups, ok := sess.Attribute(groupKey).([]string) + if !ok { + groups = make([]string, 0) + sess.SetAttribute(groupKey, groups) + } + contains := false + for _, g := range groups { + if g == c.name { + contains = true + break + } + } + if !contains { + groups = append(groups, c.name) + sess.SetAttribute(groupKey, groups) + } + } else { + sess.SetAttribute(groupKey, []string{c.name}) + } + + if _, ok := c.conns[id]; !ok { + c.conns[id] = e + } + + nlog.Debugf("Add connection to group %s, ID=%d, UID=%d", c.name, sess.ID(), sess.UID()) + return nil +} + +// Leave remove specified UID related session from group +func (c *Group) Leave(e entity.NetworkEntity) error { + if c.isClosed() { + return ErrClosedGroup + } + if e == nil { + return nil + } + sess := e.Session() + nlog.Debugf("Remove connection from group %s, UID=%d", c.name, sess.ID()) + + c.mu.Lock() + defer c.mu.Unlock() + + if sess.Exists(groupKey) { + groups, ok := sess.Attribute(groupKey).([]string) + if !ok { + groups = make([]string, 0) + sess.SetAttribute(groupKey, groups) + } + groups = c.removeGroupAttr(groups) + + if len(groups) == 0 { + sess.RemoveAttribute(groupKey) + } else { + sess.SetAttribute(groupKey, groups) + } + } + + delete(c.conns, sess.ID()) + return nil +} + +func (c *Group) LeaveByUID(uid string) error { + if c.isClosed() { + return ErrClosedGroup + } + member, _ := c.Member(uid) + return c.Leave(member) +} + +// LeaveAll clear all sessions in the group +func (c *Group) LeaveAll() error { + if c.isClosed() { + return ErrClosedGroup + } + + c.mu.Lock() + defer c.mu.Unlock() + + for _, e := range c.conns { + sess := e.Session() + + groups, ok := sess.Attribute(groupKey).([]string) + if !ok { + groups = make([]string, 0) + sess.SetAttribute(groupKey, groups) + } + groups = c.removeGroupAttr(groups) + + if len(groups) == 0 { + sess.RemoveAttribute(groupKey) + } else { + sess.SetAttribute(groupKey, groups) + } + } + c.conns = make(map[int64]entity.NetworkEntity) + return nil +} + +// 使用移位法移除group中与name匹配的元素 +func (c *Group) removeGroupAttr(group []string) []string { + j := 0 + for _, v := range group { + if v != c.name { + group[j] = v + j++ + } + } + return group[:j] +} + +// Count get current member amount in the group +func (c *Group) Count() int { + c.mu.RLock() + defer c.mu.RUnlock() + + return len(c.conns) +} + +func (c *Group) isClosed() bool { + if atomic.LoadInt32(&c.status) == groupStatusClosed { + return true + } + return false +} + +// Close destroy group, which will release all resource in the group +func (c *Group) Close() error { + if c.isClosed() { + return ErrCloseClosedGroup + } + if c.name == DefaultGroupName { + // 默认分组不允许删除 + return DeleteDefaultGroupNotAllow + } + + _ = c.LeaveAll() + + atomic.StoreInt32(&c.status, groupStatusClosed) + return nil +} diff --git a/core/connection.go b/core/connection.go index 7ea3a21..f029a65 100644 --- a/core/connection.go +++ b/core/connection.go @@ -3,23 +3,19 @@ package core import ( "errors" "fmt" - "git.noahlan.cn/noahlan/nnet/internal/pool" + "git.noahlan.cn/noahlan/nnet/entity" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/nnet/scheduler" - "git.noahlan.cn/noahlan/nnet/session" "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "git.noahlan.cn/noahlan/ntools-go/core/pool" "net" "sync/atomic" - "time" ) var ( ErrCloseClosedSession = errors.New("close closed session") // ErrBrokenPipe represents the low-level connection has broken. ErrBrokenPipe = errors.New("broken low-level pipe") - // ErrBufferExceed indicates that the current session buffer is full and - // can not receive more data. - ErrBufferExceed = errors.New("session send buffer exceed") ) const ( @@ -36,63 +32,62 @@ const ( ) type ( - Connection struct { - session *session.Session // Session - ngin *engine // engine + connection struct { + session *session // Session + ngin *engine // engine + + status int32 // 连接状态 + conn net.Conn // low-level conn fd + isWS bool // 是否为websocket - status int32 // 连接状态 - conn net.Conn // low-level conn fd packer packet.Packer // 封包、拆包器 lastMid uint64 // 最近一次消息ID - // TODO 考虑独立出去作为一个中间件 - lastHeartbeatAt int64 // 最近一次心跳时间 chDie chan struct{} // 停止通道 - chSend chan pendingMessage // 消息发送通道(结构化消息) + chSend chan PendingMessage // 消息发送通道(结构化消息) chWrite chan []byte // 消息发送通道(二进制消息) } - pendingMessage struct { + PendingMessage struct { header interface{} payload interface{} } ) -func newConn(server *engine, conn net.Conn) *Connection { - r := &Connection{ +func newConnection(server *engine, conn net.Conn) *connection { + r := &connection{ ngin: server, status: StatusStart, conn: conn, packer: server.packerFn(), - lastHeartbeatAt: time.Now().Unix(), - chDie: make(chan struct{}), - chSend: make(chan pendingMessage, 128), + chSend: make(chan PendingMessage, 128), chWrite: make(chan []byte, 128), } + _, r.isWS = conn.(*WSConn) // binding session - r.session = session.NewSession() + r.session = newSession(r) return r } -func (r *Connection) Send(header, payload interface{}) (err error) { +func (r *connection) Send(header, payload interface{}) (err error) { defer func() { if e := recover(); e != nil { err = ErrBrokenPipe } }() - r.chSend <- pendingMessage{ + r.chSend <- PendingMessage{ header: header, payload: payload, } return err } -func (r *Connection) SendBytes(data []byte) (err error) { +func (r *connection) SendBytes(data []byte) (err error) { defer func() { if e := recover(); e != nil { err = ErrBrokenPipe @@ -102,71 +97,55 @@ func (r *Connection) SendBytes(data []byte) (err error) { return err } -func (r *Connection) Status() int32 { +func (r *connection) Status() int32 { return atomic.LoadInt32(&r.status) } -func (r *Connection) SetStatus(s int32) { +func (r *connection) SetStatus(s int32) { atomic.StoreInt32(&r.status, s) } -func (r *Connection) Conn() net.Conn { - return r.conn +func (r *connection) Conn() (net.Conn, bool) { + return r.conn, r.isWS } -func (r *Connection) ID() int64 { +func (r *connection) ID() int64 { return r.session.ID() } -func (r *Connection) SetLastHeartbeatAt(t int64) { - atomic.StoreInt64(&r.lastHeartbeatAt, t) -} - -func (r *Connection) Session() *session.Session { +func (r *connection) Session() entity.Session { return r.session } -func (r *Connection) LastMID() uint64 { +func (r *connection) LastMID() uint64 { return r.lastMid } -func (r *Connection) SetLastMID(mid uint64) { +func (r *connection) SetLastMID(mid uint64) { atomic.StoreUint64(&r.lastMid, mid) } -func (r *Connection) serve() { - _ = pool.SubmitConn(func() { +func (r *connection) serve() { + _ = pool.Submit(func() { r.write() }) - _ = pool.SubmitWorker(func() { + _ = pool.Submit(func() { r.read() }) } -func (r *Connection) write() { - ticker := time.NewTicker(r.ngin.heartbeatInterval) - +func (r *connection) write() { defer func() { - ticker.Stop() close(r.chSend) close(r.chWrite) _ = r.Close() - nlog.Debugf("Connection write goroutine exit, ConnID=%d, SessionUID=%s", r.ID(), r.session.UID()) + nlog.Debugf("connection write goroutine exit, ConnID=%d, SessionUID=%s", r.ID(), r.session.UID()) }() for { select { - case <-ticker.C: - // TODO heartbeat enable control - deadline := time.Now().Add(-2 * r.ngin.heartbeatInterval).Unix() - if atomic.LoadInt64(&r.lastHeartbeatAt) < deadline { - nlog.Debugf("Session heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&r.lastHeartbeatAt), deadline) - return - } - // TODO heartbeat data - r.chWrite <- []byte{} case data := <-r.chSend: // marshal packet body (data) if r.ngin.serializer == nil { @@ -187,8 +166,7 @@ func (r *Connection) write() { if pipe := r.ngin.pipeline; pipe != nil { err := pipe.Outbound().Process(r, data) if err != nil { - nlog.Errorf("broken pipeline err: %s", err.Error()) - break + nlog.Errorf("pipeline err: %s", err.Error()) } } @@ -203,7 +181,7 @@ func (r *Connection) write() { // 回写数据 if _, err := r.conn.Write(data); err != nil { nlog.Error(err.Error()) - return + break } case <-r.chDie: // connection close signal return @@ -213,9 +191,9 @@ func (r *Connection) write() { } } -func (r *Connection) read() { +func (r *connection) read() { defer func() { - r.Close() + _ = r.Close() }() buf := make([]byte, 4096) @@ -246,7 +224,7 @@ func (r *Connection) read() { } } -func (r *Connection) processPacket(packet packet.IPacket) error { +func (r *connection) processPacket(packet packet.IPacket) error { if pipe := r.ngin.pipeline; pipe != nil { err := pipe.Inbound().Process(r, packet) if err != nil { @@ -254,22 +232,16 @@ func (r *Connection) processPacket(packet packet.IPacket) error { } } - // packet processor - err := r.ngin.processor.Process(r, packet) - if err != nil { - return err - } - if r.Status() == StatusWorking { // HandleFunc - _ = pool.SubmitWorker(func() { + _ = pool.Submit(func() { r.ngin.handler.Handle(r, packet) }) } - return err + return nil } -func (r *Connection) Close() error { +func (r *connection) Close() error { if r.Status() == StatusClosed { return ErrCloseClosedSession } @@ -283,5 +255,7 @@ func (r *Connection) Close() error { close(r.chDie) scheduler.PushTask(func() { Lifetime.Close(r) }) } + r.session.Close() + return r.conn.Close() } diff --git a/core/engine.go b/core/engine.go index b21a7b5..90c5a0c 100644 --- a/core/engine.go +++ b/core/engine.go @@ -2,12 +2,15 @@ package core import ( "errors" - "git.noahlan.cn/noahlan/nnet/internal/pool" + "git.noahlan.cn/noahlan/nnet/config" + conn2 "git.noahlan.cn/noahlan/nnet/conn" + "git.noahlan.cn/noahlan/nnet/entity" "git.noahlan.cn/noahlan/nnet/packet" + "git.noahlan.cn/noahlan/nnet/pipeline" "git.noahlan.cn/noahlan/nnet/scheduler" "git.noahlan.cn/noahlan/nnet/serialize" - "git.noahlan.cn/noahlan/nnet/session" "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "git.noahlan.cn/noahlan/ntools-go/core/pool" "github.com/gorilla/websocket" "log" "net" @@ -19,9 +22,9 @@ import ( "time" ) -func NotFound(conn *Connection, packet packet.IPacket) { +func NotFound(conn entity.NetworkEntity, _ packet.IPacket) { nlog.Error("handler not found") - conn.SendBytes([]byte("handler not found")) + _ = conn.SendBytes([]byte("handler not found")) } func NotFoundHandler() Handler { @@ -31,25 +34,21 @@ func NotFoundHandler() Handler { type ( // engine TCP-engine engine struct { - conf EngineConf // conf 配置 - middlewares []Middleware // 中间件 - routes []Route // 路由 + conf config.EngineConf // conf 配置 + middlewares []Middleware // 中间件 + routes []Route // 路由 // handler 消息处理器 handler Handler // dieChan 应用程序退出信号 dieChan chan struct{} - // sessionMgr session管理器 - sessionMgr *session.Manager - pipeline Pipeline // 消息管道 + pipeline pipeline.Pipeline // 消息管道 packerFn packet.NewPackerFunc // 封包、拆包器 - processor Processor // 数据包处理器 serializer serialize.Serializer // 消息 序列化/反序列化 - retryInterval time.Duration // 消息重试间隔时长 - heartbeatInterval time.Duration // 心跳间隔,0表示不进行心跳 - wsOpt wsOptions // websocket + retryInterval time.Duration // 消息重试间隔时长 + wsOpt wsOptions // websocket } wsOptions struct { @@ -61,20 +60,22 @@ type ( } ) -func newEngine(conf EngineConf) *engine { +func newEngine(conf config.EngineConf) *engine { s := &engine{ - conf: conf, - dieChan: make(chan struct{}), - sessionMgr: session.NewSessionMgr(), + conf: conf, + dieChan: make(chan struct{}), + pipeline: pipeline.New(), + middlewares: make([]Middleware, 0), + routes: make([]Route, 0), } - pool.InitPool(10000) + pool.InitPool(conf.Pool) return s } -func (ng *engine) use(middleware Middleware) { - ng.middlewares = append(ng.middlewares, middleware) +func (ng *engine) use(middleware ...Middleware) { + ng.middlewares = append(ng.middlewares, middleware...) } func (ng *engine) addRoutes(route ...Route) { @@ -112,20 +113,23 @@ func (ng *engine) serve(router Router) error { if err := ng.bindRoutes(router); err != nil { return err } + go scheduler.Schedule() + defer func() { + nlog.Info("Server is stopping...") - go func() { - if ng.wsOpt.IsWebsocket { - if len(ng.wsOpt.TLSCertificate) != 0 && len(ng.wsOpt.TLSKey) != 0 { - ng.listenAndServeWSTLS() - } else { - ng.listenAndServeWS() - } - } else { - ng.listenAndServe() - } + ng.shutdown() + scheduler.Close() }() - go scheduler.Schedule() + if ng.wsOpt.IsWebsocket { + if len(ng.wsOpt.TLSCertificate) != 0 && len(ng.wsOpt.TLSKey) != 0 { + ng.listenAndServeWSTLS() + } else { + ng.listenAndServeWS() + } + } else { + ng.listenAndServe() + } sg := make(chan os.Signal) signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM) @@ -137,11 +141,6 @@ func (ng *engine) serve(router Router) error { nlog.Infof("Server got signal: %s", s) } - nlog.Info("Server is stopping...") - - ng.shutdown() - scheduler.Close() - return nil } @@ -161,7 +160,8 @@ func (ng *engine) listenAndServe() { // 监听成功,服务已启动 nlog.Infof("now listening %s on %s.", ng.conf.Protocol, ng.conf.Addr) defer func() { - listener.Close() + _ = listener.Close() + ng.shutdown() ng.close() }() @@ -176,7 +176,7 @@ func (ng *engine) listenAndServe() { continue } - err = pool.SubmitConn(func() { + err = pool.Submit(func() { ng.handle(conn) }) if err != nil { @@ -212,7 +212,7 @@ func (ng *engine) setupWS() { nlog.Errorf("Upgrade failure, URI=%ng, Error=%ng", r.RequestURI, err.Error()) return } - err = pool.SubmitConn(func() { + err = pool.Submit(func() { ng.handleWS(conn) }) if err != nil { @@ -222,29 +222,26 @@ func (ng *engine) setupWS() { } func (ng *engine) handleWS(conn *websocket.Conn) { - c, err := newWSConn(conn) - if err != nil { - nlog.Error(err) - return - } + c := newWSConn(conn) ng.handle(c) } func (ng *engine) handle(conn net.Conn) { - connection := newConn(ng, conn) - ng.sessionMgr.StoreSession(connection.Session()) + c := newConnection(ng, conn) + conn2.ConnManager.Store(conn2.DefaultGroupName, c) - connection.serve() + c.serve() // hook + Lifetime.Open(c) } func (ng *engine) notFoundHandler(next Handler) Handler { - return HandlerFunc(func(conn *Connection, packet packet.IPacket) { + return HandlerFunc(func(entity entity.NetworkEntity, packet packet.IPacket) { h := next if next == nil { h = NotFoundHandler() } // TODO write to client - h.Handle(conn, packet) + h.Handle(entity, packet) }) } diff --git a/core/lifetime.go b/core/lifetime.go index 964abba..97c2ebe 100644 --- a/core/lifetime.go +++ b/core/lifetime.go @@ -1,7 +1,9 @@ package core +import "git.noahlan.cn/noahlan/nnet/entity" + type ( - LifetimeHandler func(conn *Connection) + LifetimeHandler func(entity entity.NetworkEntity) lifetime struct { onOpen []LifetimeHandler @@ -19,22 +21,22 @@ func (lt *lifetime) OnOpen(h LifetimeHandler) { lt.onOpen = append(lt.onOpen, h) } -func (lt *lifetime) Open(conn *Connection) { +func (lt *lifetime) Open(entity entity.NetworkEntity) { if len(lt.onOpen) <= 0 { return } for _, handler := range lt.onOpen { - handler(conn) + handler(entity) } } -func (lt *lifetime) Close(conn *Connection) { +func (lt *lifetime) Close(entity entity.NetworkEntity) { if len(lt.onClosed) <= 0 { return } for _, handler := range lt.onClosed { - handler(conn) + handler(entity) } } diff --git a/core/nnet_router.go b/core/nnet_router.go deleted file mode 100644 index 3763f7e..0000000 --- a/core/nnet_router.go +++ /dev/null @@ -1,49 +0,0 @@ -package core - -import ( - "errors" - "git.noahlan.cn/noahlan/nnet/packet" - "git.noahlan.cn/noahlan/ntools-go/core/nlog" -) - -type nNetRouter struct { - handlers map[string]Handler - notFound Handler -} - -func NewRouter() Router { - return &nNetRouter{ - handlers: make(map[string]Handler), - } -} - -func (r *nNetRouter) Handle(conn *Connection, p packet.IPacket) { - pkg, ok := p.(*packet.Packet) - if !ok { - nlog.Error(packet.ErrWrongMessage) - return - } - handler, ok := r.handlers[pkg.Header.Route] - if !ok { - if r.notFound == nil { - nlog.Error("message handler not found") - return - } - r.notFound.Handle(conn, p) - return - } - handler.Handle(conn, p) -} - -func (r *nNetRouter) Register(matches interface{}, handler Handler) error { - route, ok := matches.(string) - if !ok { - return errors.New("the type of matches must be string") - } - r.handlers[route] = handler - return nil -} - -func (r *nNetRouter) SetNotFoundHandler(handler Handler) { - r.notFound = handler -} diff --git a/core/processor.go b/core/processor.go deleted file mode 100644 index dd57f83..0000000 --- a/core/processor.go +++ /dev/null @@ -1,9 +0,0 @@ -package core - -import "git.noahlan.cn/noahlan/nnet/packet" - -type ( - Processor interface { - Process(conn *Connection, packet packet.IPacket) error - } -) diff --git a/core/processor_nnet.go b/core/processor_nnet.go deleted file mode 100644 index 48a1465..0000000 --- a/core/processor_nnet.go +++ /dev/null @@ -1,72 +0,0 @@ -package core - -import ( - "encoding/json" - "errors" - "fmt" - "git.noahlan.cn/noahlan/nnet/packet" - "git.noahlan.cn/noahlan/ntools-go/core/nlog" - "time" -) - -var ( - hrd []byte // handshake response data - hbd []byte // heartbeat packet data -) - -type NNetProcessor struct { -} - -func NewNNetProcessor() *NNetProcessor { - // TODO custom hrd hbd - data, _ := json.Marshal(map[string]interface{}{ - "code": 200, - "sys": map[string]float64{"heartbeat": time.Second.Seconds()}, - }) - packer := packet.NewNNetPacker() - hrd, _ = packer.Pack(packet.Handshake, data) - - return &NNetProcessor{} -} - -func (n *NNetProcessor) Process(conn *Connection, p packet.IPacket) error { - h, ok := p.(*packet.Packet) - if !ok { - return packet.ErrWrongPacketType - } - - switch h.PacketType { - case packet.Handshake: - // TODO validate handshake - if err := conn.SendBytes(hrd); err != nil { - return err - } - - conn.SetStatus(StatusPrepare) - nlog.Debugf("connection handshake Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr()) - case packet.HandshakeAck: - conn.SetStatus(StatusPending) - nlog.Debugf("receive handshake ACK Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr()) - case packet.Heartbeat: - // Expected - case packet.Data: - if conn.Status() < StatusPending { - return errors.New(fmt.Sprintf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s", - conn.Conn().RemoteAddr())) - } - conn.SetStatus(StatusWorking) - - var lastMid uint64 - switch h.MsgType { - case packet.Request: - lastMid = h.ID - case packet.Notify: - lastMid = 0 - default: - return fmt.Errorf("Invalid message type: %s ", h.MsgType.String()) - } - conn.SetLastMID(lastMid) - } - conn.SetLastHeartbeatAt(time.Now().Unix()) - return nil -} diff --git a/core/router.go b/core/router.go index fc4218e..39c23a2 100644 --- a/core/router.go +++ b/core/router.go @@ -1,8 +1,11 @@ package core -type ( - Middleware func(next HandlerFunc) HandlerFunc +import ( + "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/packet" +) +type ( Route struct { Matches interface{} // 用于匹配的关键字段 Handler HandlerFunc // 处理方法 @@ -52,3 +55,28 @@ func (c Chain) Append(constructors ...Constructor) Chain { func (c Chain) Extend(chain Chain) Chain { return c.Append(chain.constructors...) } + +type plainRouter struct { + handler Handler + notFound Handler +} + +func NewDefaultRouter() Router { + return &plainRouter{} +} + +func (p *plainRouter) Handle(entity entity.NetworkEntity, pkg packet.IPacket) { + if p.handler == nil { + return + } + p.handler.Handle(entity, pkg) +} + +func (p *plainRouter) Register(_ interface{}, handler Handler) error { + p.handler = handler + return nil +} + +func (p *plainRouter) SetNotFoundHandler(handler Handler) { + p.notFound = handler +} diff --git a/core/server.go b/core/server.go index 7e5fc32..b37e0f0 100644 --- a/core/server.go +++ b/core/server.go @@ -1,8 +1,10 @@ package core import ( + "git.noahlan.cn/noahlan/nnet/config" "git.noahlan.cn/noahlan/nnet/internal/env" "git.noahlan.cn/noahlan/nnet/packet" + "git.noahlan.cn/noahlan/nnet/pipeline" "git.noahlan.cn/noahlan/nnet/serialize" "git.noahlan.cn/noahlan/ntools-go/core/nlog" "net/http" @@ -21,10 +23,10 @@ type ( // NewServer returns a server with given config of c and options defined in opts. // Be aware that later RunOption might overwrite previous one that write the same option. -func NewServer(c EngineConf, opts ...RunOption) *Server { +func NewServer(c config.EngineConf, opts ...RunOption) *Server { s := &Server{ ngin: newEngine(c), - router: NewRouter(), + router: NewDefaultRouter(), } opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...) @@ -60,8 +62,13 @@ func (s *Server) Stop() { } // Use adds the given middleware in the Server. -func (s *Server) Use(middleware Middleware) { - s.ngin.use(middleware) +func (s *Server) Use(middleware ...Middleware) { + s.ngin.use(middleware...) +} + +// Pipeline returns inner pipeline +func (s *Server) Pipeline() pipeline.Pipeline { + return s.ngin.pipeline } // ToMiddleware converts the given handler to a Middleware. @@ -93,6 +100,12 @@ func WithMiddleware(middleware Middleware, rs ...Route) []Route { return routes } +func UseMiddleware(middleware ...Middleware) RunOption { + return func(server *Server) { + server.Use(middleware...) + } +} + // WithNotFoundHandler returns a RunOption with not found handler set to given handler. func WithNotFoundHandler(handler Handler) RunOption { return func(server *Server) { @@ -101,24 +114,21 @@ func WithNotFoundHandler(handler Handler) RunOption { } } +// WithRouter 设置消息路由 func WithRouter(router Router) RunOption { return func(server *Server) { server.router = router } } +// WithPacker 设置消息的 封包/解包 方式 func WithPacker(fn packet.NewPackerFunc) RunOption { return func(server *Server) { server.ngin.packerFn = fn } } -func WithProcessor(p Processor) RunOption { - return func(server *Server) { - server.ngin.processor = p - } -} - +// WithSerializer 设置消息的 序列化/反序列化 方式 func WithSerializer(s serialize.Serializer) RunOption { return func(server *Server) { server.ngin.serializer = s @@ -137,15 +147,19 @@ func WithTimerPrecision(precision time.Duration) RunOption { } } -func WithPipeline(pipeline Pipeline) RunOption { +func WithPipeline(pipeline pipeline.Pipeline) RunOption { return func(server *Server) { server.ngin.pipeline = pipeline } } -func WithHeartbeatInterval(d time.Duration) RunOption { +type PipelineOption func(opts pipeline.Pipeline) + +func WithPipelineOpt(opts ...func(pipeline.Pipeline)) RunOption { return func(server *Server) { - server.ngin.heartbeatInterval = d + for _, opt := range opts { + opt(server.ngin.pipeline) + } } } diff --git a/core/server_test.go b/core/server_test.go index 05cb694..c3e3573 100644 --- a/core/server_test.go +++ b/core/server_test.go @@ -25,7 +25,7 @@ func TestServer(t *testing.T) { server.AddRoute(Route{ Matches: "test", - Handler: func(conn *Connection, pkg packet.IPacket) { + Handler: func(conn *connection, pkg packet.IPacket) { fmt.Println(pkg) p, ok := pkg.(*packet.Packet) if !ok { diff --git a/session/session.go b/core/session.go similarity index 58% rename from session/session.go rename to core/session.go index ec1171e..9293fd0 100644 --- a/session/session.go +++ b/core/session.go @@ -1,46 +1,56 @@ -package session +package core import ( + "git.noahlan.cn/noahlan/nnet/entity" "sync" "sync/atomic" ) -type Session struct { +type session struct { sync.RWMutex // 数据锁 + // 网络单元 + entity entity.NetworkEntity + id int64 // Session全局唯一ID uid string // 用户ID,不绑定的情况下与sid一致 data map[string]interface{} // session数据存储(内存) } -func NewSession() *Session { - return &Session{ - id: sessionIDMgrInstance.SessionID(), - uid: "", - data: make(map[string]interface{}), +func newSession(entity entity.NetworkEntity) *session { + return &session{ + id: sessionIDMgrInstance.SessionID(), + entity: entity, + uid: "", + data: make(map[string]interface{}), } } -func (s *Session) ID() int64 { +// ID 获取 session ID +func (s *session) ID() int64 { return s.id } -func (s *Session) UID() string { +// UID 获取UID +func (s *session) UID() string { return s.uid } -func (s *Session) Bind(uid string) { +// Bind 绑定uid +func (s *session) Bind(uid string) { s.uid = uid } -func (s *Session) Attribute(key string) interface{} { +// Attribute 获取指定key对应参数 +func (s *session) Attribute(key string) interface{} { s.RLock() defer s.RUnlock() return s.data[key] } -func (s *Session) Keys() []string { +// Keys 获取所有参数key +func (s *session) Keys() []string { s.RLock() defer s.RUnlock() @@ -51,7 +61,8 @@ func (s *Session) Keys() []string { return keys } -func (s *Session) Exists(key string) bool { +// Exists 指定key是否存在 +func (s *session) Exists(key string) bool { s.RLock() defer s.RUnlock() @@ -59,28 +70,32 @@ func (s *Session) Exists(key string) bool { return has } -func (s *Session) Attributes() map[string]interface{} { +// Attributes 获取所有参数 +func (s *session) Attributes() map[string]interface{} { s.RLock() defer s.RUnlock() return s.data } -func (s *Session) RemoveAttribute(key string) { +// RemoveAttribute 移除指定key对应参数 +func (s *session) RemoveAttribute(key string) { s.Lock() defer s.Unlock() delete(s.data, key) } -func (s *Session) SetAttribute(key string, value interface{}) { +// SetAttribute 设置参数 +func (s *session) SetAttribute(key string, value interface{}) { s.Lock() defer s.Unlock() s.data[key] = value } -func (s *Session) Invalidate() { +// Invalidate 清理 +func (s *session) Invalidate() { s.Lock() defer s.Unlock() @@ -89,6 +104,12 @@ func (s *Session) Invalidate() { s.data = make(map[string]interface{}) } +// Close 关闭 +func (s *session) Close() { + //s.entity.Close() + s.Invalidate() +} + var sessionIDMgrInstance = newSessionIDMgr() type sessionIDMgr struct { diff --git a/core/types.go b/core/types.go index 9523c99..8a47ad3 100644 --- a/core/types.go +++ b/core/types.go @@ -1,15 +1,20 @@ package core -import "git.noahlan.cn/noahlan/nnet/packet" +import ( + "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/packet" +) type ( Handler interface { - Handle(conn *Connection, pkg packet.IPacket) + Handle(entity entity.NetworkEntity, pkg packet.IPacket) } - HandlerFunc func(conn *Connection, pkg packet.IPacket) + HandlerFunc func(entity entity.NetworkEntity, pkg packet.IPacket) + + Middleware func(next HandlerFunc) HandlerFunc ) -func (f HandlerFunc) Handle(conn *Connection, pkg packet.IPacket) { - f(conn, pkg) +func (f HandlerFunc) Handle(entity entity.NetworkEntity, pkg packet.IPacket) { + f(entity, pkg) } diff --git a/core/ws.go b/core/ws.go index 8116ad0..d78179b 100644 --- a/core/ws.go +++ b/core/ws.go @@ -3,54 +3,40 @@ package core import ( "github.com/gorilla/websocket" "io" - "net" "time" ) -// wsConn 封装 websocket.Conn 并实现所有 net.Conn 接口 +// WSConn 封装 websocket.Conn 并实现所有 net.Conn 接口 // 兼容所有使用 net.Conn 的方法 -type wsConn struct { - conn *websocket.Conn - typ int // message type - reader io.Reader +type WSConn struct { + *websocket.Conn } // newWSConn 新建wsConn -func newWSConn(conn *websocket.Conn) (*wsConn, error) { - c := &wsConn{conn: conn} - - t, r, err := conn.NextReader() - if err != nil { - return nil, err - } - c.typ = t - c.reader = r - return c, nil +func newWSConn(conn *websocket.Conn) *WSConn { + return &WSConn{Conn: conn} } // Read reads data from the connection. // Read can be made to time out and return an Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetReadDeadline. -func (c *wsConn) Read(b []byte) (int, error) { - n, err := c.reader.Read(b) +func (c *WSConn) Read(b []byte) (int, error) { + _, r, err := c.NextReader() + if err != nil { + return 0, err + } + n, err := r.Read(b) if err != nil && err != io.EOF { return n, err - } else if err == io.EOF { - _, r, err := c.conn.NextReader() - if err != nil { - return 0, err - } - c.reader = r } - return n, nil } // Write writes data to the connection. // Write can be made to time out and return an Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetWriteDeadline. -func (c *wsConn) Write(b []byte) (int, error) { - err := c.conn.WriteMessage(websocket.BinaryMessage, b) +func (c *WSConn) Write(b []byte) (int, error) { + err := c.WriteMessage(websocket.BinaryMessage, b) if err != nil { return 0, err } @@ -58,22 +44,6 @@ func (c *wsConn) Write(b []byte) (int, error) { return len(b), nil } -// Close closes the connection. -// Any blocked Read or Write operations will be unblocked and return errors. -func (c *wsConn) Close() error { - return c.conn.Close() -} - -// LocalAddr returns the local network Addr. -func (c *wsConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -// RemoteAddr returns the remote network Addr. -func (c *wsConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - // SetDeadline sets the read and write deadlines associated // with the connection. It is equivalent to calling both // SetReadDeadline and SetWriteDeadline. @@ -89,26 +59,10 @@ func (c *wsConn) RemoteAddr() net.Addr { // the deadline after successful Read or Write calls. // // A zero value for t means I/O operations will not time out. -func (c *wsConn) SetDeadline(t time.Time) error { - if err := c.conn.SetReadDeadline(t); err != nil { +func (c *WSConn) SetDeadline(t time.Time) error { + if err := c.SetReadDeadline(t); err != nil { return err } - return c.conn.SetWriteDeadline(t) -} - -// SetReadDeadline sets the deadline for future Read calls -// and any currently-blocked Read call. -// A zero value for t means Read will not time out. -func (c *wsConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -// SetWriteDeadline sets the deadline for future Write calls -// and any currently-blocked Write call. -// Even if write times out, it may return n > 0, indicating that -// some data was successfully written. -// A zero value for t means Write will not time out. -func (c *wsConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) + return c.SetWriteDeadline(t) } diff --git a/entity/network.go b/entity/network.go new file mode 100644 index 0000000..f686329 --- /dev/null +++ b/entity/network.go @@ -0,0 +1,24 @@ +package entity + +import "net" + +type NetworkEntity interface { + // Send 主动发送消息,支持自定义header,payload + Send(header, payload interface{}) error + // SendBytes 主动发送消息,消息需提前编码 + SendBytes(data []byte) error + // Status 获取当前连接状态 + Status() int32 + // SetStatus 设置当前连接状态 + SetStatus(s int32) + // Conn 获取当前底层连接(还需根据返回参数2决定是否转换为WSConn) + Conn() (net.Conn, bool) + // Session 获取当前连接 Session + Session() Session + // LastMID 最新消息ID + LastMID() uint64 + // SetLastMID 设置消息ID + SetLastMID(mid uint64) + // Close 关闭连接 + Close() error +} diff --git a/entity/session.go b/entity/session.go new file mode 100644 index 0000000..598a5bc --- /dev/null +++ b/entity/session.go @@ -0,0 +1,26 @@ +package entity + +type Session interface { + // ID 获取 Session ID + ID() int64 + // UID 获取UID + UID() string + // Bind 绑定uid + Bind(uid string) + // Attribute 获取指定key对应参数 + Attribute(key string) interface{} + // Keys 获取所有参数key + Keys() []string + // Exists 指定key是否存在 + Exists(key string) bool + // Attributes 获取所有参数 + Attributes() map[string]interface{} + // RemoveAttribute 移除指定key对应参数 + RemoveAttribute(key string) + // SetAttribute 设置参数 + SetAttribute(key string, value interface{}) + // Invalidate 清理 + Invalidate() + // Close 关闭 Session + Close() +} diff --git a/go.mod b/go.mod index 61fecb2..1467554 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,13 @@ go 1.20 require ( github.com/gorilla/websocket v1.5.0 - github.com/panjf2000/ants/v2 v2.6.0 + github.com/panjf2000/ants/v2 v2.7.3 ) require google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8 require ( - git.noahlan.cn/noahlan/ntools-go/core v1.1.1 + git.noahlan.cn/noahlan/ntools-go/core v1.1.3 github.com/fatih/color v1.15.0 // indirect github.com/gofrs/uuid/v5 v5.0.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index f5324f8..c5b02db 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,10 @@ git.noahlan.cn/noahlan/ntools-go/core v1.1.1 h1:icFPOTTpVYPa8NpNJteAwFBARPOuHE3695xZWNcAM2c= git.noahlan.cn/noahlan/ntools-go/core v1.1.1/go.mod h1:UN8UVL5WoyMgqNcxKoAu0/J9d+1hH2Yco64MUtPdjFk= +git.noahlan.cn/noahlan/ntools-go/core v1.1.3 h1:n4z0KaXmX/fmobavxCMc2vGJDoStbhNbm8AZugPEPGg= +git.noahlan.cn/noahlan/ntools-go/core v1.1.3/go.mod h1:pmwee9V76Cyp6nVr3dPj5TpePLvRpc8C0ZgAzFIFAKU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= @@ -19,7 +23,15 @@ github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPn github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/panjf2000/ants/v2 v2.6.0 h1:xOSpw42m+BMiJ2I33we7h6fYzG4DAlpE1xyI7VS2gxU= github.com/panjf2000/ants/v2 v2.6.0/go.mod h1:cU93usDlihJZ5CfRGNDYsiBYvoilLvBF5Qp/BT2GNRE= +github.com/panjf2000/ants/v2 v2.7.3/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= go.opentelemetry.io/otel v1.14.0 h1:/79Huy8wbf5DnIPhemGB+zEPVwnN6fuQybr/SRXa6hM= go.opentelemetry.io/otel v1.14.0/go.mod h1:o4buv+dJzx8rohcUeRmWUZhqupFvzWis188WlggnNeU= @@ -28,6 +40,7 @@ go.opentelemetry.io/otel/trace v1.14.0 h1:wp2Mmvj41tDsyAJXiWDWpfNsOiIyd38fy85pyK go.opentelemetry.io/otel/trace v1.14.0/go.mod h1:8avnQLK+CG77yNLUae4ea2JDQ6iT+gozhnZjy/rw9G8= golang.org/x/crypto v0.7.0 h1:AvwMYaRytfdeVt3u6mLaxYtErKYjxA2OXjJ1HHq6t3A= golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -37,4 +50,7 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8 h1:KR8+MyP7/qOlV+8Af01LtjL04bu7on42eVsxT4EyBQk= google.golang.org/protobuf v1.28.2-0.20220831092852-f930b1dc76e8/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/pool/pool.go b/internal/pool/pool.go deleted file mode 100644 index 2ec6e8a..0000000 --- a/internal/pool/pool.go +++ /dev/null @@ -1,27 +0,0 @@ -package pool - -import "github.com/panjf2000/ants/v2" - -var _pool *pool - -type pool struct { - connPool *ants.Pool - workerPool *ants.Pool -} - -func InitPool(size int) { - p := &pool{} - - p.connPool, _ = ants.NewPool(size, ants.WithNonblocking(true)) - p.workerPool, _ = ants.NewPool(size*2, ants.WithNonblocking(true)) - - _pool = p -} - -func SubmitConn(h func()) error { - return _pool.connPool.Submit(h) -} - -func SubmitWorker(h func()) error { - return _pool.workerPool.Submit(h) -} diff --git a/middleware/heartbeat.go b/middleware/heartbeat.go new file mode 100644 index 0000000..fa205e0 --- /dev/null +++ b/middleware/heartbeat.go @@ -0,0 +1,67 @@ +package middleware + +import ( + "git.noahlan.cn/noahlan/nnet/core" + "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/packet" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "sync/atomic" + "time" +) + +type HeartbeatMiddleware struct { + lastAt int64 + interval time.Duration + hbdFn func(entity entity.NetworkEntity) []byte +} + +func WithHeartbeat(interval time.Duration, dataFn func(entity entity.NetworkEntity) []byte) core.RunOption { + m := &HeartbeatMiddleware{ + lastAt: time.Now().Unix(), + interval: interval, + hbdFn: dataFn, + } + if dataFn == nil { + nlog.Error("dataFn must not be nil") + panic("dataFn must not be nil") + } + core.Lifetime.OnOpen(m.start) + + return func(server *core.Server) { + server.Use(func(next core.HandlerFunc) core.HandlerFunc { + return func(entity entity.NetworkEntity, pkg packet.IPacket) { + m.handle(entity, pkg) + + next(entity, pkg) + } + }) + } +} + +func (m *HeartbeatMiddleware) start(entity entity.NetworkEntity) { + ticker := time.NewTicker(m.interval) + + defer func() { + ticker.Stop() + }() + + for { + select { + case <-ticker.C: + deadline := time.Now().Add(-2 * m.interval).Unix() + if atomic.LoadInt64(&m.lastAt) < deadline { + nlog.Errorf("Heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&m.lastAt), deadline) + return + } + err := entity.SendBytes(m.hbdFn(entity)) + if err != nil { + nlog.Errorf("Heartbeat err: %v", err) + return + } + } + } +} + +func (m *HeartbeatMiddleware) handle(_ entity.NetworkEntity, _ packet.IPacket) { + atomic.StoreInt64(&m.lastAt, time.Now().Unix()) +} diff --git a/packet/packet.go b/packet/packet.go index cd90812..47889a9 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -4,7 +4,7 @@ type ( // IPacket 数据帧 IPacket interface { GetHeader() interface{} // 数据帧头部 Header - GetLen() uint32 // 数据帧长度 4bytes - 32bit 占位,根据实际情况进行转换 + GetLen() uint64 // 数据帧长度 8bytes,根据实际情况进行转换 GetBody() []byte // 数据 Body } diff --git a/core/pipeline.go b/pipeline/pipeline.go similarity index 80% rename from core/pipeline.go rename to pipeline/pipeline.go index 994ec50..7a7d8d7 100644 --- a/core/pipeline.go +++ b/pipeline/pipeline.go @@ -1,11 +1,12 @@ -package core +package pipeline import ( + "git.noahlan.cn/noahlan/nnet/entity" "sync" ) type ( - Func func(conn *Connection, v interface{}) error + Func func(entity entity.NetworkEntity, v interface{}) error // Pipeline 消息管道 Pipeline interface { @@ -20,7 +21,7 @@ type ( Channel interface { PushFront(h Func) PushBack(h Func) - Process(conn *Connection, v interface{}) error + Process(entity entity.NetworkEntity, v interface{}) error } pipelineChannel struct { @@ -65,16 +66,16 @@ func (p *pipelineChannel) PushBack(h Func) { } // Process 处理所有的pipeline方法 -func (p *pipelineChannel) Process(conn *Connection, v interface{}) error { - p.mu.RLock() - defer p.mu.RUnlock() - +func (p *pipelineChannel) Process(entity entity.NetworkEntity, v interface{}) error { if len(p.handlers) < 1 { return nil } + p.mu.RLock() + defer p.mu.RUnlock() + for _, handler := range p.handlers { - err := handler(conn, v) + err := handler(entity, v) if err != nil { return err } diff --git a/protocol/nnet.go b/protocol/nnet.go new file mode 100644 index 0000000..bcc6daa --- /dev/null +++ b/protocol/nnet.go @@ -0,0 +1,34 @@ +package protocol + +import ( + "git.noahlan.cn/noahlan/nnet/core" + "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/middleware" + "git.noahlan.cn/noahlan/nnet/packet" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "time" +) + +type NNetConfig struct { +} + +func WithNNetProtocol( + handshakeValidator func([]byte) error, + heartbeatInterval time.Duration, +) []core.RunOption { + if handshakeValidator == nil { + handshakeValidator = func(bytes []byte) error { return nil } + } + packer := NewNNetPacker() + hbd, err := packer.Pack(Handshake, nil) + nlog.Must(err) + + return []core.RunOption{ + WithNNetPipeline(handshakeValidator), + core.WithRouter(NewNNetRouter()), + core.WithPacker(func() packet.Packer { return NewNNetPacker() }), + middleware.WithHeartbeat(heartbeatInterval, func(_ entity.NetworkEntity) []byte { + return hbd + }), + } +} diff --git a/packet/packer_nnet.go b/protocol/packer_nnet.go similarity index 97% rename from packet/packer_nnet.go rename to protocol/packer_nnet.go index f5f4357..4cbb5e4 100644 --- a/packet/packer_nnet.go +++ b/protocol/packer_nnet.go @@ -1,9 +1,10 @@ -package packet +package protocol import ( "bytes" "encoding/binary" "errors" + "git.noahlan.cn/noahlan/nnet/packet" ) type NNetPacker struct { @@ -133,11 +134,11 @@ func (d *NNetPacker) intToBytes(n uint32) []byte { return buf } -func (d *NNetPacker) Unpack(data []byte) ([]IPacket, error) { +func (d *NNetPacker) Unpack(data []byte) ([]packet.IPacket, error) { d.buf.Write(data) // copy var ( - packets []IPacket + packets []packet.IPacket err error ) diff --git a/packet/packer_nnet_test.go b/protocol/packer_nnet_test.go similarity index 99% rename from packet/packer_nnet_test.go rename to protocol/packer_nnet_test.go index 6f4072c..2c3a04d 100644 --- a/packet/packer_nnet_test.go +++ b/protocol/packer_nnet_test.go @@ -1,4 +1,4 @@ -package packet +package protocol import ( "encoding/hex" diff --git a/packet/packet_nnet.go b/protocol/packet_nnet.go similarity index 77% rename from packet/packet_nnet.go rename to protocol/packet_nnet.go index 9311e92..339eefe 100644 --- a/packet/packet_nnet.go +++ b/protocol/packet_nnet.go @@ -1,4 +1,4 @@ -package packet +package protocol import ( "encoding/hex" @@ -61,14 +61,14 @@ type ( Route string // route for locating service compressed bool // if message compressed } - Packet struct { + NNetPacket struct { Header Data []byte // 原始数据 } ) -func newPacket(typ Type) *Packet { - return &Packet{ +func newPacket(typ Type) *NNetPacket { + return &NNetPacket{ Header: Header{ PacketType: typ, MessageHeader: MessageHeader{}, @@ -76,19 +76,19 @@ func newPacket(typ Type) *Packet { } } -func (p *Packet) GetHeader() interface{} { +func (p *NNetPacket) GetHeader() interface{} { return p.Header } -func (p *Packet) GetLen() uint32 { - return p.Length +func (p *NNetPacket) GetLen() uint64 { + return uint64(p.Length) } -func (p *Packet) GetBody() []byte { +func (p *NNetPacket) GetBody() []byte { return p.Data } -func (p *Packet) String() string { - return fmt.Sprintf("Packet[Type: %d, Len: %d] Message[Type: %s, ID: %d, Route: %s, Compressed: %v] BodyStr: [%s], BodyHex: [%s]", +func (p *NNetPacket) String() string { + return fmt.Sprintf("NNetPacket[Type: %d, Len: %d] Message[Type: %s, ID: %d, Route: %s, Compressed: %v] BodyStr: [%s], BodyHex: [%s]", p.PacketType, p.Length, p.MsgType, p.ID, p.Route, p.compressed, string(p.Data), hex.EncodeToString(p.Data)) } diff --git a/protocol/pipeline_nnet.go b/protocol/pipeline_nnet.go new file mode 100644 index 0000000..5879b50 --- /dev/null +++ b/protocol/pipeline_nnet.go @@ -0,0 +1,86 @@ +package protocol + +import ( + "encoding/json" + "fmt" + "git.noahlan.cn/noahlan/nnet/core" + "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "time" +) + +type ( + handshakeData struct { + Version string `json:"version"` // 客户端版本号,服务器以此判断是否合适与客户端通信 + Type string `json:"type"` // 客户端类型,与客户端版本号一起来确定客户端是否合适 + + // 透传信息 + Payload interface{} `json:"payload,optional,omitempty"` + } + handshakeAckData struct { + Heartbeat int64 `json:"heartbeat"` // 心跳间隔,单位秒 0表示不需要心跳 + // 路由 + Routes map[string]uint16 `json:"routes"` // route map to code + Codes map[uint16]string `json:"codes"` // code map to route + + // 服务端支持的body部分消息传输协议 + //Protocol string `json:"protocol,options=[plain,json,protobuf]"` // plain/json/protobuf + + // 透传信息 + Payload interface{} `json:"payload,optional,omitempty"` + } +) + +func WithNNetPipeline(heartbeatInterval time.Duration, handshakeValidator func([]byte) error) core.RunOption { + handshakeAck := &handshakeAckData{} + data, err := json.Marshal(handshakeAck) + nlog.Must(err) + + packer := NewNNetPacker() + hrd, _ := packer.Pack(Handshake, data) + + return func(server *core.Server) { + server.Pipeline().Inbound().PushFront(func(entity entity.NetworkEntity, v interface{}) error { + pkg, ok := v.(*NNetPacket) + if !ok { + return ErrWrongPacketType + } + conn, _ := entity.Conn() + + switch pkg.PacketType { + case Handshake: + if err := handshakeValidator(pkg.Data); err != nil { + return err + } + if err := entity.SendBytes(hrd); err != nil { + return err + } + entity.SetStatus(core.StatusPrepare) + nlog.Debugf("connection handshake Id=%d, Remote=%s", entity.Session().ID(), conn.RemoteAddr()) + case HandshakeAck: + entity.SetStatus(core.StatusPending) + nlog.Debugf("receive handshake ACK Id=%d, Remote=%s", entity.Session().ID(), conn.RemoteAddr()) + case Heartbeat: + // Expected + case Data: + if entity.Status() < core.StatusPending { + return errors.New(fmt.Sprintf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s", + conn.RemoteAddr())) + } + entity.SetStatus(core.StatusWorking) + + var lastMid uint64 + switch pkg.MsgType { + case Request: + lastMid = pkg.ID + case Notify: + lastMid = 0 + default: + return fmt.Errorf("Invalid message type: %s ", pkg.MsgType.String()) + } + entity.SetLastMID(lastMid) + } + return nil + }) + } +} diff --git a/protocol/router_nnet.go b/protocol/router_nnet.go new file mode 100644 index 0000000..a692621 --- /dev/null +++ b/protocol/router_nnet.go @@ -0,0 +1,51 @@ +package protocol + +import ( + "errors" + "git.noahlan.cn/noahlan/nnet/core" + "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/packet" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" +) + +type nNetRouter struct { + handlers map[string]core.Handler + notFound core.Handler +} + +func NewNNetRouter() core.Router { + return &nNetRouter{ + handlers: make(map[string]core.Handler), + } +} + +func (r *nNetRouter) Handle(entity entity.NetworkEntity, p packet.IPacket) { + pkg, ok := p.(*NNetPacket) + if !ok { + nlog.Error(ErrWrongMessage) + return + } + handler, ok := r.handlers[pkg.Header.Route] + if !ok { + if r.notFound == nil { + nlog.Error("message handler not found") + return + } + r.notFound.Handle(entity, p) + return + } + handler.Handle(entity, p) +} + +func (r *nNetRouter) Register(matches interface{}, handler core.Handler) error { + route, ok := matches.(string) + if !ok { + return errors.New("the type of matches must be string") + } + r.handlers[route] = handler + return nil +} + +func (r *nNetRouter) SetNotFoundHandler(handler core.Handler) { + r.notFound = handler +} diff --git a/session/session_mgr.go b/session/session_mgr.go deleted file mode 100644 index 45d3858..0000000 --- a/session/session_mgr.go +++ /dev/null @@ -1,46 +0,0 @@ -package session - -import ( - "sync" -) - -type Manager struct { - sync.RWMutex - sessions map[int64]*Session -} - -func NewSessionMgr() *Manager { - return &Manager{ - RWMutex: sync.RWMutex{}, - sessions: make(map[int64]*Session), - } -} - -func (m *Manager) StoreSession(s *Session) { - m.Lock() - defer m.Unlock() - - m.sessions[s.ID()] = s -} - -func (m *Manager) FindSession(sid int64) *Session { - m.RLock() - defer m.RUnlock() - - return m.sessions[sid] -} - -func (m *Manager) FindOrCreateSession(sid int64) *Session { - m.RLock() - s, ok := m.sessions[sid] - m.RUnlock() - - if !ok { - s = NewSession() - - m.Lock() - m.sessions[s.ID()] = s - m.Unlock() - } - return s -}