diff --git a/client_tcp.go b/client_tcp.go new file mode 100644 index 0000000..fa43f2f --- /dev/null +++ b/client_tcp.go @@ -0,0 +1,23 @@ +package nnet + +import ( + "git.noahlan.cn/noahlan/nnet/connection" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "net" +) + +// Dial 连接服务器 +func (ngin *Engine) Dial(addr string) (*connection.Connection, error) { + err := ngin.setup() + if err != nil { + nlog.Errorf("%s failed to setup server, err:%v", ngin.LogPrefix(), err) + return nil, err + } + + conn, err := net.Dial("tcp", addr) + nlog.Must(err) + + nlog.Infof("%s now connect to %s...", ngin.LogPrefix(), addr) + + return ngin.handle(conn), nil +} diff --git a/config/config.go b/config/config.go index 925fbaa..8788f1b 100644 --- a/config/config.go +++ b/config/config.go @@ -1,7 +1,7 @@ package config import ( - "git.noahlan.cn/noahlan/ntools-go/core/pool" + "fmt" "time" ) @@ -16,29 +16,20 @@ const ( type ( EngineConf struct { - ServerConf - Pool pool.Config - } - - ServerConf struct { - // Protocol 协议名 - // "tcp", "tcp4", "tcp6", "unix" or "unixpacket" - // 若只想开启IPv4, 使用tcp4即可 - Protocol string `json:",default=tcp4"` - // Addr 服务地址 - // 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP - // 如果端口号不填或端口号为0,例如:"127.0.0.1:" 或 ":0",服务端将选择随机可用端口 - Addr string `json:",default=0.0.0.0"` - // Name 服务端名称,默认为n-net - Name string `json:",default=n-net"` // TaskTimerPrecision // 全局任务的timer间隔 TaskTimerPrecision time.Duration `json:",default=1s"` // Mode 运行模式 Mode string `json:",default=dev,options=[dev,test,prod]"` + // Name 引擎名称 + Name string `json:",default=NL,env=ENGINE_NAME"` } ) // ShallLogDebug 是否应该打印 Debug 级别的日志,打印的首要条件是 nlog 的打印级别为 debug -func ShallLogDebug(mode string) bool { - return mode == DevMode || mode == TestMode +func (c EngineConf) ShallLogDebug() bool { + return c.Mode == DevMode || c.Mode == TestMode +} + +func (c EngineConf) LogPrefix() string { + return fmt.Sprintf("[NNet-%s]", c.Name) } diff --git a/config/server_tcp.go b/config/server_tcp.go new file mode 100644 index 0000000..52b3e3b --- /dev/null +++ b/config/server_tcp.go @@ -0,0 +1,14 @@ +package config + +type ( + TCPServerConf struct { + // Protocol 协议名 + // "tcp", "tcp4", "tcp6", "unix" or "unixpacket" + // 若只想开启IPv4, 使用tcp4即可 + Protocol string `json:",default=tcp4,env=TCP_PROTOCOL"` + // Addr 服务地址 + // 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP + // 如果端口号不填或端口号为0,例如:"127.0.0.1:" 或 ":0",服务端将选择随机可用端口 + Addr string `json:",default=0.0.0.0:9876,env=TCP_ADDR"` + } +) diff --git a/config/server_ws.go b/config/server_ws.go new file mode 100644 index 0000000..85ccc63 --- /dev/null +++ b/config/server_ws.go @@ -0,0 +1,35 @@ +package config + +import ( + "net/http" + "time" +) + +type ( + WSServerConf struct { + // Addr 服务地址 + // 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP + // 如果端口号不填或端口号为0,例如:"127.0.0.1:" 或 ":0",服务端将选择随机可用端口 + Addr string `json:",default=0.0.0.0:9876,env=WS_ADDR"` + // Path 监听路径 /WebsocketPath + Path string `json:",default=/,env=WS_PATH"` + // HandshakeTimeout 握手超时时间,默认0 + HandshakeTimeout time.Duration `json:",default=0"` + // ReadBufferSize 读缓冲区大小 + ReadBufferSize int `json:",default=2048"` + // WriteBufferSize 写缓冲区大小 + WriteBufferSize int `json:",default=2048"` + // Compression 是否使用压缩协议 + Compression bool `json:",default=false"` + // TLSCertificate 证书地址 + TLSCertificate string `json:",optional"` + // TLS 证书key地址 + TLSKey string `json:",optional"` + // check origin + CheckOrigin func(*http.Request) bool + } +) + +func (c WSServerConf) IsTLS() bool { + return len(c.TLSCertificate) > 0 && len(c.TLSKey) > 0 +} diff --git a/core/connection.go b/connection/connection.go similarity index 56% rename from core/connection.go rename to connection/connection.go index 5df9fc4..a9b1f1c 100644 --- a/core/connection.go +++ b/connection/connection.go @@ -1,11 +1,11 @@ -package core +package connection import ( "errors" "fmt" - "git.noahlan.cn/noahlan/nnet/entity" "git.noahlan.cn/noahlan/nnet/packet" - "git.noahlan.cn/noahlan/nnet/scheduler" + "git.noahlan.cn/noahlan/nnet/serialize" + "git.noahlan.cn/noahlan/nnet/session" "git.noahlan.cn/noahlan/ntools-go/core/nlog" "git.noahlan.cn/noahlan/ntools-go/core/pool" "net" @@ -31,16 +31,26 @@ const ( StatusClosed ) +type ConnType int + +const ( + ConnTypeTCP ConnType = iota // TCP connection + ConnTypeWS // Websocket connection +) + type ( - connection struct { - session *session // Session - ngin *engine // engine + Connection struct { + conf Config // 配置 + session *session.Session // Session status int32 // 连接状态 conn net.Conn // low-level conn fd - isWS bool // 是否为websocket + typ ConnType // 连接类型 - packer packet.Packer // 封包、拆包器 + packer packet.Packer // 封包、拆包器 + serializer serialize.Serializer // 消息序列化/反序列化器 + pipeline Pipeline // 连接生命周期管理 + handleFn func(conn *Connection, pkg packet.IPacket) // 消息处理方法 lastMid uint64 // 最近一次消息ID @@ -49,32 +59,54 @@ type ( chWrite chan []byte // 消息发送通道(二进制消息) } + packetFn func(conn *Connection, pkg packet.IPacket) + + Config struct { + LogDebug bool + LogPrefix string + } + PendingMessage struct { header interface{} payload interface{} } ) -func newConnection(server *engine, conn net.Conn) *connection { - r := &connection{ - ngin: server, +func NewConnection( + id int64, + conn net.Conn, + conf Config, + packerBuilder packet.PackerBuilder, + serializer serialize.Serializer, + pipeline Pipeline, + handleFn packetFn) *Connection { + r := &Connection{ + conf: conf, + session: session.NewSession(id), status: StatusStart, conn: conn, - packer: server.packerFn(), + typ: ConnTypeTCP, + + packer: packerBuilder(), + serializer: serializer, + pipeline: pipeline, + handleFn: handleFn, + + lastMid: 0, chDie: make(chan struct{}), chSend: make(chan PendingMessage, 128), chWrite: make(chan []byte, 128), } - _, r.isWS = conn.(*WSConn) - - // binding session - r.session = newSession(r, server.sessIdMgr.SessionID()) + _, ok := conn.(*WSConn) + if ok { + r.typ = ConnTypeWS + } return r } -func (r *connection) Send(header, payload interface{}) (err error) { +func (r *Connection) Send(header, payload interface{}) (err error) { defer func() { if e := recover(); e != nil { err = ErrBrokenPipe @@ -87,7 +119,7 @@ func (r *connection) Send(header, payload interface{}) (err error) { return err } -func (r *connection) SendBytes(data []byte) (err error) { +func (r *Connection) SendBytes(data []byte) (err error) { defer func() { if e := recover(); e != nil { err = ErrBrokenPipe @@ -97,35 +129,35 @@ func (r *connection) SendBytes(data []byte) (err error) { return err } -func (r *connection) Status() int32 { +func (r *Connection) Status() int32 { return atomic.LoadInt32(&r.status) } -func (r *connection) SetStatus(s int32) { +func (r *Connection) SetStatus(s int32) { atomic.StoreInt32(&r.status, s) } -func (r *connection) Conn() (net.Conn, bool) { - return r.conn, r.isWS +func (r *Connection) Conn() (net.Conn, ConnType) { + return r.conn, r.typ } -func (r *connection) ID() int64 { +func (r *Connection) ID() int64 { return r.session.ID() } -func (r *connection) Session() entity.Session { +func (r *Connection) Session() *session.Session { return r.session } -func (r *connection) LastMID() uint64 { +func (r *Connection) LastMID() uint64 { return r.lastMid } -func (r *connection) SetLastMID(mid uint64) { +func (r *Connection) SetLastMID(mid uint64) { atomic.StoreUint64(&r.lastMid, mid) } -func (r *connection) serve() { +func (r *Connection) Serve() { _ = pool.Submit(func() { r.write() }) @@ -135,15 +167,15 @@ func (r *connection) serve() { }) } -func (r *connection) write() { +func (r *Connection) write() { defer func() { close(r.chSend) close(r.chWrite) _ = r.Close() - if r.ngin.shallLogDebug() { + if r.conf.LogDebug { nlog.Debugf("%s [writeLoop] connection write goroutine exit, ConnID=%d, SessionUID=%s", - r.ngin.logPrefix(), r.ID(), r.session.UID()) + r.conf.LogPrefix, r.ID(), r.session.UID()) } }() @@ -151,51 +183,52 @@ func (r *connection) write() { select { case data := <-r.chSend: // marshal packet body (data) - if r.ngin.serializer == nil { + if r.serializer == nil { if _, ok := data.payload.([]byte); !ok { - nlog.Errorf("%s serializer is nil, but payload type not []byte", r.ngin.logPrefix()) + nlog.Errorf("%s serializer is nil, but payload type not []byte", r.conf.LogPrefix) break } } else { - payload, err := r.ngin.serializer.Marshal(data.payload) + payload, err := r.serializer.Marshal(data.payload) if err != nil { - nlog.Errorf("%s message body marshal err: %v", r.ngin.logPrefix(), err) + nlog.Errorf("%s message body marshal err: %v", r.conf.LogPrefix, err) break } data.payload = payload } // invoke pipeline - if pipe := r.ngin.pipeline; pipe != nil { + if pipe := r.pipeline; pipe != nil { err := pipe.Outbound().Process(r, data) if err != nil { - nlog.Errorf("%s pipeline err: %s", r.ngin.logPrefix(), err.Error()) + nlog.Errorf("%s pipeline err: %s", r.conf.LogPrefix, err.Error()) } } // packet pack data p, err := r.packer.Pack(data.header, data.payload.([]byte)) if err != nil { - nlog.Errorf("%s pack err: %s", r.ngin.logPrefix(), err.Error()) + nlog.Errorf("%s pack err: %s", r.conf.LogPrefix, err.Error()) break } r.chWrite <- p case data := <-r.chWrite: // 回写数据 if _, err := r.conn.Write(data); err != nil { - nlog.Errorf("%s write data err: %s", r.ngin.logPrefix(), err.Error()) + nlog.Errorf("%s write data err: %s", r.conf.LogPrefix, err.Error()) break } //nlog.Debugf("write data %v", data) case <-r.chDie: // connection close signal return - case <-r.ngin.dieChan: // application quit signal - return + // TODO + //case <-r.ngin.dieChan: // application quit signal + // return } } } -func (r *connection) read() { +func (r *Connection) read() { defer func() { _ = r.Close() }() @@ -205,38 +238,38 @@ func (r *connection) read() { //nlog.Debugf("receive data %v", buf[:n]) if err != nil { nlog.Errorf("%s [readLoop] Read message error: %s, session will be closed immediately", - r.ngin.logPrefix(), err.Error()) + r.conf.LogPrefix, err.Error()) return } if n == 0 { nlog.Errorf("%s [readLoop] Read empty message, session will be closed immediately", - r.ngin.logPrefix()) + r.conf.LogPrefix) return } if r.packer == nil { - nlog.Errorf("%s [readLoop] unexpected error: packer is nil", r.ngin.logPrefix()) + nlog.Errorf("%s [readLoop] unexpected error: packer is nil", r.conf.LogPrefix) return } // warning: 为性能考虑,复用slice处理数据,buf传入后必须要copy再处理 packets, err := r.packer.Unpack(buf[:n]) if err != nil { - nlog.Errorf("%s unpack err: %s", r.ngin.logPrefix(), err.Error()) + nlog.Errorf("%s unpack err: %s", r.conf.LogPrefix, err.Error()) } // packets 处理 for _, p := range packets { if err := r.processPacket(p); err != nil { - nlog.Errorf("%s process packet err: %s", r.ngin.logPrefix(), err.Error()) + nlog.Errorf("%s process packet err: %s", r.conf.LogPrefix, err.Error()) continue } } } } -func (r *connection) processPacket(packet packet.IPacket) error { - if pipe := r.ngin.pipeline; pipe != nil { +func (r *Connection) processPacket(packet packet.IPacket) error { + if pipe := r.pipeline; pipe != nil { err := pipe.Inbound().Process(r, packet) if err != nil { return errors.New(fmt.Sprintf("pipeline process failed: %v", err.Error())) @@ -244,32 +277,34 @@ func (r *connection) processPacket(packet packet.IPacket) error { } if r.Status() == StatusWorking { - // HandleFunc + // 处理包消息 _ = pool.Submit(func() { - r.ngin.handler.Handle(r, packet) + r.handleFn(r, packet) }) } return nil } -func (r *connection) Close() error { +func (r *Connection) DieChan() chan struct{} { + return r.chDie +} + +func (r *Connection) Close() error { if r.Status() == StatusClosed { return ErrCloseClosedSession } r.SetStatus(StatusClosed) - if r.ngin.shallLogDebug() { - nlog.Debugf("%s close connection, ID: %d", r.ngin.logPrefix(), r.ID()) + if r.conf.LogDebug { + nlog.Debugf("%s close connection, ID: %d", r.conf.LogPrefix, r.ID()) } select { case <-r.chDie: default: close(r.chDie) - scheduler.PushTask(func() { r.ngin.lifetime.Close(r) }) } - _ = r.ngin.connManager.Remove(r) - r.session.Close() + r.session.Close() return r.conn.Close() } diff --git a/conn/group.go b/connection/group.go similarity index 55% rename from conn/group.go rename to connection/group.go index 2df6a89..0228e6a 100644 --- a/conn/group.go +++ b/connection/group.go @@ -1,8 +1,7 @@ -package conn +package connection import ( "errors" - "git.noahlan.cn/noahlan/nnet/entity" "git.noahlan.cn/noahlan/ntools-go/core/nlog" "sync" "sync/atomic" @@ -27,7 +26,7 @@ type Group struct { status int32 // group current status name string // group name - conns map[int64]entity.NetworkEntity + conns map[int64]*Connection } func NewGroup(name string) *Group { @@ -35,16 +34,16 @@ func NewGroup(name string) *Group { mu: sync.RWMutex{}, status: groupStatusWorking, name: name, - conns: make(map[int64]entity.NetworkEntity), + conns: make(map[int64]*Connection), } } // Member returns connection by specified uid -func (c *Group) Member(uid string) (entity.NetworkEntity, bool) { - c.mu.RLock() - defer c.mu.RUnlock() +func (g *Group) Member(uid string) (*Connection, bool) { + g.mu.RLock() + defer g.mu.RUnlock() - for _, e := range c.conns { + for _, e := range g.conns { if e.Session().UID() == uid { return e, true } @@ -54,18 +53,18 @@ func (c *Group) Member(uid string) (entity.NetworkEntity, bool) { } // MemberBySID returns specified sId's connection -func (c *Group) MemberBySID(id int64) (entity.NetworkEntity, bool) { - c.mu.RLock() - defer c.mu.RUnlock() +func (g *Group) MemberBySID(id int64) (*Connection, bool) { + g.mu.RLock() + defer g.mu.RUnlock() - e, ok := c.conns[id] + e, ok := g.conns[id] return e, ok } -func (c *Group) Members() []entity.NetworkEntity { - var resp []entity.NetworkEntity - c.PeekMembers(func(_ int64, e entity.NetworkEntity) bool { - resp = append(resp, e) +func (g *Group) Members() []*Connection { + var resp []*Connection + g.PeekMembers(func(_ int64, c *Connection) bool { + resp = append(resp, c) return false }) return resp @@ -73,33 +72,33 @@ func (c *Group) Members() []entity.NetworkEntity { // PeekMembers returns all members in current group // fn 返回true跳过循环,反之一直循环 -func (c *Group) PeekMembers(fn func(sId int64, e entity.NetworkEntity) bool) { - c.mu.RLock() - defer c.mu.RUnlock() +func (g *Group) PeekMembers(fn func(sId int64, c *Connection) bool) { + g.mu.RLock() + defer g.mu.RUnlock() - for sId, e := range c.conns { - if fn(sId, e) { + for sId, c := range g.conns { + if fn(sId, c) { break } } } // Contains check whether a UID is contained in current group or not -func (c *Group) Contains(uid string) bool { - _, ok := c.Member(uid) +func (g *Group) Contains(uid string) bool { + _, ok := g.Member(uid) return ok } // Add session to group -func (c *Group) Add(e entity.NetworkEntity) error { - if c.isClosed() { +func (g *Group) Add(c *Connection) error { + if g.isClosed() { return ErrClosedGroup } - c.mu.Lock() - defer c.mu.Unlock() + g.mu.Lock() + defer g.mu.Unlock() - sess := e.Session() + sess := c.Session() id := sess.ID() // group attribute @@ -110,41 +109,41 @@ func (c *Group) Add(e entity.NetworkEntity) error { sess.SetAttribute(groupKey, groups) } contains := false - for _, g := range groups { - if g == c.name { + for _, group := range groups { + if group == g.name { contains = true break } } if !contains { - groups = append(groups, c.name) + groups = append(groups, g.name) sess.SetAttribute(groupKey, groups) } } else { - sess.SetAttribute(groupKey, []string{c.name}) + sess.SetAttribute(groupKey, []string{g.name}) } - if _, ok := c.conns[id]; !ok { - c.conns[id] = e + if _, ok := g.conns[id]; !ok { + g.conns[id] = c } - nlog.Debugf("Add connection to group %s, ID=%d, UID=%s", c.name, sess.ID(), sess.UID()) + nlog.Debugf("Add connection to group %s, ID=%d, UID=%s", g.name, sess.ID(), sess.UID()) return nil } // Leave remove specified UID related session from group -func (c *Group) Leave(e entity.NetworkEntity) error { - if c.isClosed() { +func (g *Group) Leave(c *Connection) error { + if g.isClosed() { return ErrClosedGroup } - if e == nil { + if c == nil { return nil } - sess := e.Session() - nlog.Debugf("Remove connection from group %s, UID=%s", c.name, sess.UID()) + sess := c.Session() + nlog.Debugf("Remove connection from group %s, UID=%s", g.name, sess.UID()) - c.mu.Lock() - defer c.mu.Unlock() + g.mu.Lock() + defer g.mu.Unlock() if sess.Exists(groupKey) { groups, ok := sess.Attribute(groupKey).([]string) @@ -152,7 +151,7 @@ func (c *Group) Leave(e entity.NetworkEntity) error { groups = make([]string, 0) sess.SetAttribute(groupKey, groups) } - groups = c.removeGroupAttr(groups) + groups = g.removeGroupAttr(groups) if len(groups) == 0 { sess.RemoveAttribute(groupKey) @@ -161,28 +160,28 @@ func (c *Group) Leave(e entity.NetworkEntity) error { } } - delete(c.conns, sess.ID()) + delete(g.conns, sess.ID()) return nil } -func (c *Group) LeaveByUID(uid string) error { - if c.isClosed() { +func (g *Group) LeaveByUID(uid string) error { + if g.isClosed() { return ErrClosedGroup } - member, _ := c.Member(uid) - return c.Leave(member) + member, _ := g.Member(uid) + return g.Leave(member) } // LeaveAll clear all sessions in the group -func (c *Group) LeaveAll() error { - if c.isClosed() { +func (g *Group) LeaveAll() error { + if g.isClosed() { return ErrClosedGroup } - c.mu.Lock() - defer c.mu.Unlock() + g.mu.Lock() + defer g.mu.Unlock() - for _, e := range c.conns { + for _, e := range g.conns { sess := e.Session() groups, ok := sess.Attribute(groupKey).([]string) @@ -190,7 +189,7 @@ func (c *Group) LeaveAll() error { groups = make([]string, 0) sess.SetAttribute(groupKey, groups) } - groups = c.removeGroupAttr(groups) + groups = g.removeGroupAttr(groups) if len(groups) == 0 { sess.RemoveAttribute(groupKey) @@ -198,15 +197,15 @@ func (c *Group) LeaveAll() error { sess.SetAttribute(groupKey, groups) } } - c.conns = make(map[int64]entity.NetworkEntity) + g.conns = make(map[int64]*Connection) return nil } // 使用移位法移除group中与name匹配的元素 -func (c *Group) removeGroupAttr(groups []string) []string { +func (g *Group) removeGroupAttr(groups []string) []string { j := 0 for _, v := range groups { - if v != c.name { + if v != g.name { groups[j] = v j++ } @@ -215,32 +214,32 @@ func (c *Group) removeGroupAttr(groups []string) []string { } // Count get current member amount in the group -func (c *Group) Count() int { - c.mu.RLock() - defer c.mu.RUnlock() +func (g *Group) Count() int { + g.mu.RLock() + defer g.mu.RUnlock() - return len(c.conns) + return len(g.conns) } -func (c *Group) isClosed() bool { - if atomic.LoadInt32(&c.status) == groupStatusClosed { +func (g *Group) isClosed() bool { + if atomic.LoadInt32(&g.status) == groupStatusClosed { return true } return false } // Close destroy group, which will release all resource in the group -func (c *Group) Close() error { - if c.isClosed() { +func (g *Group) Close() error { + if g.isClosed() { return ErrCloseClosedGroup } - if c.name == DefaultGroupName { + if g.name == DefaultGroupName { // 默认分组不允许删除 return DeleteDefaultGroupNotAllow } - _ = c.LeaveAll() + _ = g.LeaveAll() - atomic.StoreInt32(&c.status, groupStatusClosed) + atomic.StoreInt32(&g.status, groupStatusClosed) return nil } diff --git a/conn/conn_mgr.go b/connection/manager.go similarity index 66% rename from conn/conn_mgr.go rename to connection/manager.go index 368bcfb..9aebbef 100644 --- a/conn/conn_mgr.go +++ b/connection/manager.go @@ -1,7 +1,6 @@ -package conn +package connection import ( - "git.noahlan.cn/noahlan/nnet/entity" "sync" ) @@ -11,38 +10,38 @@ type Manager struct { // 分组 groups map[string]*Group // 所有 Connection - conns map[int64]entity.NetworkEntity + conns map[int64]*Connection } func NewManager() *Manager { return &Manager{ RWMutex: sync.RWMutex{}, groups: make(map[string]*Group), - conns: make(map[int64]entity.NetworkEntity), + conns: make(map[int64]*Connection), } } // Store 保存连接,同时加入到指定分组,若给定分组名为空,则不进行分组操作 -func (m *Manager) Store(groupName string, s entity.NetworkEntity) error { +func (m *Manager) Store(groupName string, c *Connection) error { m.Lock() - m.conns[s.Session().ID()] = s + m.conns[c.Session().ID()] = c m.Unlock() group, ok := m.FindGroup(groupName) if !ok { group = m.NewGroup(groupName) } - return group.Add(s) + return group.Add(c) } -func (m *Manager) Remove(s entity.NetworkEntity) error { +func (m *Manager) Remove(c *Connection) error { m.Lock() defer m.Unlock() - delete(m.conns, s.Session().ID()) + delete(m.conns, c.Session().ID()) // 从所有group中移除 for _, group := range m.groups { - err := group.Leave(s) + err := group.Leave(c) if err != nil { return err } @@ -50,9 +49,9 @@ func (m *Manager) Remove(s entity.NetworkEntity) error { return nil } -func (m *Manager) RemoveFromGroup(groupName string, s entity.NetworkEntity) error { +func (m *Manager) RemoveFromGroup(groupName string, c *Connection) error { m.Lock() - delete(m.conns, s.Session().ID()) + delete(m.conns, c.Session().ID()) m.Unlock() group, ok := m.FindGroup(groupName) @@ -60,7 +59,7 @@ func (m *Manager) RemoveFromGroup(groupName string, s entity.NetworkEntity) erro return nil } - return group.Leave(s) + return group.Leave(c) } // NewGroup 新增分组,若分组已存在,则返回现有分组 @@ -89,7 +88,7 @@ func (m *Manager) FindGroup(name string) (*Group, bool) { } // FindConn 根据连接ID找到连接 -func (m *Manager) FindConn(id int64) (entity.NetworkEntity, bool) { +func (m *Manager) FindConn(id int64) (*Connection, bool) { m.RLock() defer m.RUnlock() @@ -98,7 +97,7 @@ func (m *Manager) FindConn(id int64) (entity.NetworkEntity, bool) { } // FindConnByUID 根据连接绑定的UID找到连接 -func (m *Manager) FindConnByUID(uid string) (entity.NetworkEntity, bool) { +func (m *Manager) FindConnByUID(uid string) (*Connection, bool) { m.RLock() defer m.RUnlock() @@ -112,12 +111,12 @@ func (m *Manager) FindConnByUID(uid string) (entity.NetworkEntity, bool) { // PeekConn 循环所有连接 // fn 返回true跳过循环,反之一直循环 -func (m *Manager) PeekConn(fn func(id int64, e entity.NetworkEntity) bool) { +func (m *Manager) PeekConn(fn func(id int64, c *Connection) bool) { m.RLock() defer m.RUnlock() - for id, e := range m.conns { - if fn(id, e) { + for id, c := range m.conns { + if fn(id, c) { break } } diff --git a/pipeline/pipeline.go b/connection/pipeline.go similarity index 78% rename from pipeline/pipeline.go rename to connection/pipeline.go index 7a7d8d7..17de2d9 100644 --- a/pipeline/pipeline.go +++ b/connection/pipeline.go @@ -1,12 +1,11 @@ -package pipeline +package connection import ( - "git.noahlan.cn/noahlan/nnet/entity" "sync" ) type ( - Func func(entity entity.NetworkEntity, v interface{}) error + Func func(c *Connection, v interface{}) error // Pipeline 消息管道 Pipeline interface { @@ -21,7 +20,7 @@ type ( Channel interface { PushFront(h Func) PushBack(h Func) - Process(entity entity.NetworkEntity, v interface{}) error + Process(c *Connection, v interface{}) error } pipelineChannel struct { @@ -30,7 +29,7 @@ type ( } ) -func New() Pipeline { +func NewPipeline() Pipeline { return &pipeline{ outbound: &pipelineChannel{}, inbound: &pipelineChannel{}, @@ -66,7 +65,7 @@ func (p *pipelineChannel) PushBack(h Func) { } // Process 处理所有的pipeline方法 -func (p *pipelineChannel) Process(entity entity.NetworkEntity, v interface{}) error { +func (p *pipelineChannel) Process(c *Connection, v interface{}) error { if len(p.handlers) < 1 { return nil } @@ -75,7 +74,7 @@ func (p *pipelineChannel) Process(entity entity.NetworkEntity, v interface{}) er defer p.mu.RUnlock() for _, handler := range p.handlers { - err := handler(entity, v) + err := handler(c, v) if err != nil { return err } diff --git a/core/ws.go b/connection/ws.go similarity index 95% rename from core/ws.go rename to connection/ws.go index d78179b..3d2a514 100644 --- a/core/ws.go +++ b/connection/ws.go @@ -1,4 +1,4 @@ -package core +package connection import ( "github.com/gorilla/websocket" @@ -12,8 +12,8 @@ type WSConn struct { *websocket.Conn } -// newWSConn 新建wsConn -func newWSConn(conn *websocket.Conn) *WSConn { +// NewWSConn 新建wsConn +func NewWSConn(conn *websocket.Conn) *WSConn { return &WSConn{Conn: conn} } diff --git a/core/engine.go b/core/engine.go deleted file mode 100644 index 99f3c45..0000000 --- a/core/engine.go +++ /dev/null @@ -1,286 +0,0 @@ -package core - -import ( - "errors" - "fmt" - "git.noahlan.cn/noahlan/nnet/config" - conn2 "git.noahlan.cn/noahlan/nnet/conn" - "git.noahlan.cn/noahlan/nnet/entity" - "git.noahlan.cn/noahlan/nnet/lifetime" - "git.noahlan.cn/noahlan/nnet/packet" - "git.noahlan.cn/noahlan/nnet/pipeline" - "git.noahlan.cn/noahlan/nnet/scheduler" - "git.noahlan.cn/noahlan/nnet/serialize" - "git.noahlan.cn/noahlan/ntools-go/core/nlog" - "git.noahlan.cn/noahlan/ntools-go/core/pool" - "github.com/gorilla/websocket" - "log" - "net" - "net/http" - "strings" - "time" -) - -func NotFound(conn entity.NetworkEntity, _ packet.IPacket) { - nlog.Error("handler not found") - _ = conn.SendBytes([]byte("handler not found")) -} - -func NotFoundHandler() Handler { - return HandlerFunc(NotFound) -} - -type ( - // engine TCP-engine - engine struct { - conf config.EngineConf // conf 配置 - taskTimerPrecision time.Duration - - middlewares []Middleware // 中间件 - routes []Route // 路由 - // handler 消息处理器 - handler Handler - // dieChan 应用程序退出信号 - dieChan chan struct{} - - pipeline pipeline.Pipeline // 消息管道 - lifetime *lifetime.Mgr // 连接的生命周期管理器 - - packerFn packet.NewPackerFunc // 封包、拆包器 - serializer serialize.Serializer // 消息 序列化/反序列化 - - wsOpt wsOptions // websocket - - connManager *conn2.Manager - sessIdMgr *sessionIDMgr - } - - wsOptions struct { - IsWebsocket bool // 是否为websocket服务端 - WebsocketPath string // ws地址(ws://127.0.0.1/WebsocketPath) - TLSCertificate string // TLS 证书地址 (websocket) - TLSKey string // TLS 证书key地址 (websocket) - CheckOrigin func(*http.Request) bool // check origin - } -) - -func newEngine(conf config.EngineConf) *engine { - s := &engine{ - conf: conf, - dieChan: make(chan struct{}), - pipeline: pipeline.New(), - middlewares: make([]Middleware, 0), - routes: make([]Route, 0), - taskTimerPrecision: conf.TaskTimerPrecision, - connManager: conn2.NewManager(), - sessIdMgr: newSessionIDMgr(), - lifetime: lifetime.NewLifetime(), - } - pool.InitPool(conf.Pool) - - return s -} - -func (ng *engine) shallLogDebug() bool { - return config.ShallLogDebug(ng.conf.Mode) -} - -func (ng *engine) logPrefix() string { - return fmt.Sprintf("[NNet-%s]", ng.conf.Name) -} - -func (ng *engine) use(middleware ...Middleware) { - ng.middlewares = append(ng.middlewares, middleware...) -} - -func (ng *engine) addRoutes(route ...Route) { - ng.routes = append(ng.routes, route...) -} - -func (ng *engine) bindRoutes(router Router) error { - for _, fr := range ng.routes { - if err := ng.bindRoute(router, fr); err != nil { - return err - } - } - return nil -} - -func (ng *engine) bindRoute(router Router, route Route) error { - // TODO default middleware - chain := newChain() - // build chain - for _, middleware := range ng.middlewares { - chain.Append(convertMiddleware(middleware)) - } - return router.Register(route.Matches, route.Handler) -} - -func convertMiddleware(ware Middleware) func(Handler) Handler { - return func(next Handler) Handler { - return ware(next.Handle) - } -} - -func (ng *engine) dial(addr string, router Router) (entity.NetworkEntity, error) { - ng.handler = router - - if err := ng.bindRoutes(router); err != nil { - return nil, err - } - go scheduler.Schedule(ng.taskTimerPrecision) - - // connection - conn, err := net.Dial("tcp", addr) - nlog.Must(err) - - c := newConnection(ng, conn) - c.serve() - // hook - ng.lifetime.Open(c) - // connection manager - err = ng.connManager.Store(conn2.DefaultGroupName, c) - nlog.Must(err) - - // 连接成功,客户端已启动 - if ng.shallLogDebug() { - nlog.Debugf("now connect to %s.", addr) - } - - return c, nil -} - -func (ng *engine) serve(router Router) error { - ng.handler = router - - if err := ng.bindRoutes(router); err != nil { - return err - } - go scheduler.Schedule(ng.taskTimerPrecision) - defer func() { - nlog.Infof("%s is stopping...", ng.logPrefix()) - - ng.shutdown() - scheduler.Close() - }() - - if ng.wsOpt.IsWebsocket { - if len(ng.wsOpt.TLSCertificate) != 0 && len(ng.wsOpt.TLSKey) != 0 { - ng.listenAndServeWSTLS() - } else { - ng.listenAndServeWS() - } - } else { - ng.listenAndServe() - } - - return nil -} - -func (ng *engine) close() { - close(ng.dieChan) -} - -func (ng *engine) shutdown() { -} - -func (ng *engine) listenAndServe() { - listener, err := net.Listen(ng.conf.Protocol, ng.conf.Addr) - nlog.Must(err) - - // 监听成功,服务已启动 - if ng.shallLogDebug() { - nlog.Debugf("%s now listening %s at %s.", ng.logPrefix(), ng.conf.Protocol, ng.conf.Addr) - } - defer func() { - _ = listener.Close() - ng.close() - }() - - for { - conn, err := listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - nlog.Errorf("%s 服务器网络错误 %+v", ng.logPrefix(), err) - return - } - nlog.Errorf("%s 监听错误 %v", ng.logPrefix(), err) - continue - } - - err = pool.Submit(func() { - ng.handle(conn) - }) - if err != nil { - nlog.Errorf("%s submit conn pool err: %ng", ng.logPrefix(), err.Error()) - continue - } - } -} - -func (ng *engine) listenAndServeWS() { - ng.setupWS() - if ng.shallLogDebug() { - nlog.Debugf("%s now listening websocket at %s.", ng.logPrefix(), ng.conf.Addr) - } - if err := http.ListenAndServe(ng.conf.Addr, nil); err != nil { - log.Fatal(err.Error()) - } -} - -func (ng *engine) listenAndServeWSTLS() { - ng.setupWS() - if ng.shallLogDebug() { - nlog.Debugf("%s now listening websocket with tls at %s.", ng.logPrefix(), ng.conf.Addr) - } - if err := http.ListenAndServeTLS(ng.conf.Addr, ng.wsOpt.TLSCertificate, ng.wsOpt.TLSKey, nil); err != nil { - log.Fatal(err.Error()) - } -} - -func (ng *engine) setupWS() { - upgrade := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: ng.wsOpt.CheckOrigin, - } - http.HandleFunc("/"+strings.TrimPrefix(ng.wsOpt.WebsocketPath, "/"), func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrade.Upgrade(w, r, nil) - if err != nil { - nlog.Errorf("%s Upgrade failure, URI=%ng, Error=%ng", ng.logPrefix(), r.RequestURI, err.Error()) - return - } - err = pool.Submit(func() { - ng.handleWS(conn) - }) - if err != nil { - log.Fatalf("%s submit conn pool err: %v", ng.logPrefix(), err.Error()) - } - }) -} - -func (ng *engine) handleWS(conn *websocket.Conn) { - c := newWSConn(conn) - ng.handle(c) -} - -func (ng *engine) handle(conn net.Conn) { - c := newConnection(ng, conn) - err := ng.connManager.Store(conn2.DefaultGroupName, c) - nlog.Must(err) - - c.serve() - // hook - ng.lifetime.Open(c) -} - -func (ng *engine) notFoundHandler(next Handler) Handler { - return HandlerFunc(func(entity entity.NetworkEntity, packet packet.IPacket) { - h := next - if next == nil { - h = NotFoundHandler() - } - // TODO write to client - h.Handle(entity, packet) - }) -} diff --git a/core/nnet.go b/core/nnet.go deleted file mode 100644 index deb34ba..0000000 --- a/core/nnet.go +++ /dev/null @@ -1,249 +0,0 @@ -package core - -import ( - "git.noahlan.cn/noahlan/nnet/config" - "git.noahlan.cn/noahlan/nnet/conn" - "git.noahlan.cn/noahlan/nnet/entity" - "git.noahlan.cn/noahlan/nnet/lifetime" - "git.noahlan.cn/noahlan/nnet/packet" - "git.noahlan.cn/noahlan/nnet/pipeline" - "git.noahlan.cn/noahlan/nnet/serialize" - "git.noahlan.cn/noahlan/ntools-go/core/nlog" - "net/http" - "time" -) - -type ( - // RunOption defines the method to customize a NNet. - RunOption func(*NNet) - - Server struct { - *NNet - } - - Client struct { - *NNet - } - - NNet struct { - ngin *engine - router Router - } -) - -// NewServer returns a server with given config of c and options defined in opts. -// Be aware that later RunOption might overwrite previous one that write the same option. -func NewServer(c config.EngineConf, opts ...RunOption) *Server { - s := &Server{ - NNet: &NNet{ - ngin: newEngine(c), - router: NewDefaultRouter(), - }, - } - - opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...) - for _, opt := range opts { - opt(s.NNet) - } - - return s -} - -// NewClient returns a client with given config of c and options defined in opts. -// Be aware that later RunOption might overwrite previous one that write the same option. -func NewClient(c config.EngineConf, opts ...RunOption) *Client { - s := &Client{ - NNet: &NNet{ - ngin: newEngine(c), - router: NewDefaultRouter(), - }, - } - - opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...) - for _, opt := range opts { - opt(s.NNet) - } - - return s -} - -// Start starts the NNet. -// Graceful shutdown is enabled by default. -func (s *Server) Start() { - if err := s.ngin.serve(s.router); err != nil { - nlog.Error(err) - panic(err) - } -} - -// Dial start the NNet client. -// Graceful shutdown is enabled by default. -func (c *Client) Dial(addr string) entity.NetworkEntity { - e, err := c.ngin.dial(addr, c.router) - nlog.Must(err) - return e -} - -// AddRoutes add given routes into the NNet. -func (s *NNet) AddRoutes(rs []Route) { - s.ngin.addRoutes(rs...) - err := s.ngin.bindRoutes(s.router) - nlog.Must(err) -} - -// AddRoute adds given route into the NNet. -func (s *NNet) AddRoute(r Route) { - s.AddRoutes([]Route{r}) -} - -// Stop stops the NNet. -func (s *NNet) Stop() { - s.ngin.close() -} - -// Use adds the given middleware in the NNet. -func (s *NNet) Use(middleware ...Middleware) { - s.ngin.use(middleware...) -} - -// Pipeline returns inner pipeline -func (s *NNet) Pipeline() pipeline.Pipeline { - return s.ngin.pipeline -} - -// Lifetime returns lifetime interface. -func (s *NNet) Lifetime() lifetime.Lifetime { - return s.ngin.lifetime -} - -// ConnManager returns connection manager -func (s *NNet) ConnManager() *conn.Manager { - return s.ngin.connManager -} - -// ToMiddleware converts the given handler to a Middleware. -func ToMiddleware(handler func(next Handler) Handler) Middleware { - return func(next HandlerFunc) HandlerFunc { - return handler(next).Handle - } -} - -// WithMiddlewares adds given middlewares to given routes. -func WithMiddlewares(ms []Middleware, rs ...Route) []Route { - for i := len(ms) - 1; i >= 0; i-- { - rs = WithMiddleware(ms[i], rs...) - } - return rs -} - -// WithMiddleware adds given middleware to given route. -func WithMiddleware(middleware Middleware, rs ...Route) []Route { - routes := make([]Route, len(rs)) - - for i := range rs { - route := rs[i] - routes[i] = Route{ - Matches: route.Matches, - Handler: middleware(route.Handler), - } - } - return routes -} - -func UseMiddleware(middleware ...Middleware) RunOption { - return func(server *NNet) { - server.Use(middleware...) - } -} - -// WithNotFoundHandler returns a RunOption with not found handler set to given handler. -func WithNotFoundHandler(handler Handler) RunOption { - return func(server *NNet) { - notFoundHandler := server.ngin.notFoundHandler(handler) - server.router.SetNotFoundHandler(notFoundHandler) - } -} - -// WithRouter 设置消息路由 -func WithRouter(router Router) RunOption { - return func(server *NNet) { - server.router = router - } -} - -// WithPacker 设置消息的 封包/解包 方式 -func WithPacker(fn packet.NewPackerFunc) RunOption { - return func(server *NNet) { - server.ngin.packerFn = fn - } -} - -// WithSerializer 设置消息的 序列化/反序列化 方式 -func WithSerializer(s serialize.Serializer) RunOption { - return func(server *NNet) { - server.ngin.serializer = s - } -} - -// WithTimerPrecision 设置Timer精度,需在 Start 或 Dial 之前执行 -// 注:精度需大于1ms, 并且不能在运行时更改 -// 默认精度是 time.Second -func WithTimerPrecision(precision time.Duration) RunOption { - if precision < time.Millisecond { - panic("time precision can not less than a Millisecond") - } - return func(s *NNet) { - s.ngin.taskTimerPrecision = precision - } -} - -func WithPipeline(pipeline pipeline.Pipeline) RunOption { - return func(server *NNet) { - server.ngin.pipeline = pipeline - } -} - -type PipelineOption func(opts pipeline.Pipeline) - -func WithPipelineOpt(opts ...func(pipeline.Pipeline)) RunOption { - return func(server *NNet) { - for _, opt := range opts { - opt(server.ngin.pipeline) - } - } -} - -type WSOption func(opts *wsOptions) - -// WithWebsocket 开启Websocket, 参数是websocket的相关参数 nnet.WSOption -func WithWebsocket(wsOpts ...WSOption) RunOption { - return func(server *NNet) { - for _, opt := range wsOpts { - opt(&server.ngin.wsOpt) - } - server.ngin.wsOpt.IsWebsocket = true - } -} - -// WithWSPath 设置websocket的path -func WithWSPath(path string) WSOption { - return func(opts *wsOptions) { - opts.WebsocketPath = path - } -} - -// WithWSTLSConfig 设置websocket的证书和密钥 -func WithWSTLSConfig(certificate, key string) WSOption { - return func(opts *wsOptions) { - opts.TLSCertificate = certificate - opts.TLSKey = key - } -} - -func WithWSCheckOriginFunc(fn func(*http.Request) bool) WSOption { - return func(opts *wsOptions) { - if fn != nil { - opts.CheckOrigin = fn - } - } -} diff --git a/core/nnet_test.go b/core/nnet_test.go deleted file mode 100644 index 37d629d..0000000 --- a/core/nnet_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package core - -import ( - "fmt" - "git.noahlan.cn/noahlan/nnet/config" - "git.noahlan.cn/noahlan/nnet/entity" - "git.noahlan.cn/noahlan/nnet/packet" - "git.noahlan.cn/noahlan/nnet/protocol/nnet" - "git.noahlan.cn/noahlan/ntools-go/core/nlog" - "git.noahlan.cn/noahlan/ntools-go/core/pool" - "math" - "testing" - "time" -) - -func TestServer(t *testing.T) { - server := NewServer(config.EngineConf{ - ServerConf: config.ServerConf{ - Protocol: "tcp", - Addr: "0.0.0.0:6666", - Name: "testServer", - Mode: "dev", - }, - Pool: pool.Config{ - PoolSize: math.MaxInt32, - ExpiryDuration: time.Second, - PreAlloc: false, - MaxBlockingTasks: 0, - Nonblocking: false, - DisablePurge: false, - }, - }, nnet.WithNNetProtocol(nnet.Config{ - HeartbeatInterval: 0, - HandshakeValidator: nil, - })...) - - server.AddRoute(Route{ - Matches: nnet.Match{ - Route: "test", - Code: 1, - }, - Handler: func(entity entity.NetworkEntity, pkg packet.IPacket) { - fmt.Println(pkg) - p, ok := pkg.(*nnet.Packet) - if !ok { - nlog.Error("wrong packet type") - return - } - - bd := []byte("服务器接收到数据为: " + string(p.GetBody())) - // 注:Response类型数据不需要Route(原地返回,客户端需等待) - _ = entity.Send(nnet.Header{ - PacketType: nnet.Data, - Length: uint32(len(bd)), - MessageHeader: nnet.MessageHeader{ - MsgType: nnet.Response, - ID: p.ID, - Route: p.Route, - }, - }, bd) - }, - }) - - defer server.Stop() - server.Start() -} diff --git a/core/session.go b/core/session.go deleted file mode 100644 index c6c84bf..0000000 --- a/core/session.go +++ /dev/null @@ -1,146 +0,0 @@ -package core - -import ( - "git.noahlan.cn/noahlan/nnet/entity" - "sync" - "sync/atomic" -) - -type session struct { - sync.RWMutex // 数据锁 - - // 网络单元 - entity entity.NetworkEntity - - id int64 // Session全局唯一ID - uid string // 用户ID,不绑定的情况下与sid一致 - data map[string]interface{} // session数据存储(内存) -} - -func newSession(entity entity.NetworkEntity, id int64) *session { - return &session{ - id: id, - entity: entity, - uid: "", - data: make(map[string]interface{}), - } -} - -// ID 获取 session ID -func (s *session) ID() int64 { - return s.id -} - -// UID 获取UID -func (s *session) UID() string { - return s.uid -} - -// Bind 绑定uid -func (s *session) Bind(uid string) { - s.uid = uid -} - -// Attribute 获取指定key对应参数 -func (s *session) Attribute(key string) interface{} { - s.RLock() - defer s.RUnlock() - - return s.data[key] -} - -// Keys 获取所有参数key -func (s *session) Keys() []string { - s.RLock() - defer s.RUnlock() - - keys := make([]string, 0, len(s.data)) - for k := range s.data { - keys = append(keys, k) - } - return keys -} - -// Exists 指定key是否存在 -func (s *session) Exists(key string) bool { - s.RLock() - defer s.RUnlock() - - _, has := s.data[key] - return has -} - -// Attributes 获取所有参数 -func (s *session) Attributes() map[string]interface{} { - s.RLock() - defer s.RUnlock() - - return s.data -} - -// RemoveAttribute 移除指定key对应参数 -func (s *session) RemoveAttribute(key string) { - s.Lock() - defer s.Unlock() - - delete(s.data, key) -} - -// SetAttribute 设置参数 -func (s *session) SetAttribute(key string, value interface{}) { - s.Lock() - defer s.Unlock() - - s.data[key] = value -} - -// Invalidate 清理 -func (s *session) Invalidate() { - s.Lock() - defer s.Unlock() - - s.id = 0 - s.uid = "" - s.data = make(map[string]interface{}) -} - -// Close 关闭 -func (s *session) Close() { - //s.entity.Close() - s.Invalidate() -} - -type sessionIDMgr struct { - count int64 - sid int64 -} - -func newSessionIDMgr() *sessionIDMgr { - return &sessionIDMgr{} -} - -// Increment the connection count -func (c *sessionIDMgr) Increment() { - atomic.AddInt64(&c.count, 1) -} - -// Decrement the connection count -func (c *sessionIDMgr) Decrement() { - atomic.AddInt64(&c.count, -1) -} - -// Count returns the connection numbers in current -func (c *sessionIDMgr) Count() int64 { - return atomic.LoadInt64(&c.count) -} - -// Reset the connection service status -func (c *sessionIDMgr) Reset() { - atomic.StoreInt64(&c.count, 0) - atomic.StoreInt64(&c.sid, 0) -} - -// SessionID returns the session id -func (c *sessionIDMgr) SessionID() int64 { - return atomic.AddInt64(&c.sid, 1) -} diff --git a/core/types.go b/core/types.go deleted file mode 100644 index 8a47ad3..0000000 --- a/core/types.go +++ /dev/null @@ -1,20 +0,0 @@ -package core - -import ( - "git.noahlan.cn/noahlan/nnet/entity" - "git.noahlan.cn/noahlan/nnet/packet" -) - -type ( - Handler interface { - Handle(entity entity.NetworkEntity, pkg packet.IPacket) - } - - HandlerFunc func(entity entity.NetworkEntity, pkg packet.IPacket) - - Middleware func(next HandlerFunc) HandlerFunc -) - -func (f HandlerFunc) Handle(entity entity.NetworkEntity, pkg packet.IPacket) { - f(entity, pkg) -} diff --git a/engine.go b/engine.go new file mode 100644 index 0000000..6e1205c --- /dev/null +++ b/engine.go @@ -0,0 +1,135 @@ +package nnet + +import ( + "git.noahlan.cn/noahlan/nnet/config" + "git.noahlan.cn/noahlan/nnet/connection" + "git.noahlan.cn/noahlan/nnet/lifetime" + "git.noahlan.cn/noahlan/nnet/packet" + rt "git.noahlan.cn/noahlan/nnet/router" + "git.noahlan.cn/noahlan/nnet/scheduler" + "git.noahlan.cn/noahlan/nnet/serialize" + "git.noahlan.cn/noahlan/nnet/session" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "github.com/panjf2000/ants/v2" + "math" + "net" +) + +// Engine 引擎 +type Engine struct { + config.EngineConf // 引擎配置 + middlewares []rt.Middleware // 中间件 + routes []rt.Route // 路由 + router rt.Router // 消息处理器 + dieChan chan struct{} // 应用程序退出信号 + pipeline connection.Pipeline // 消息管道 + packerBuilder packet.PackerBuilder // 封包、拆包器 + serializer serialize.Serializer // 消息 序列化/反序列化 + goPool *ants.Pool // goroutine池 + connManager *connection.Manager // 连接管理器 + lifetime *lifetime.Mgr // 生命周期 + sessIdMgr *session.IDMgr // SessionId管理器 +} + +func NewEngine(conf config.EngineConf, opts ...RunOption) *Engine { + ngin := &Engine{ + EngineConf: conf, + middlewares: make([]rt.Middleware, 0), + routes: make([]rt.Route, 0), + router: rt.NewDefaultRouter(), + packerBuilder: nil, + serializer: nil, + dieChan: make(chan struct{}), + pipeline: connection.NewPipeline(), + connManager: connection.NewManager(), + lifetime: lifetime.NewLifetime(), + sessIdMgr: session.NewSessionIDMgr(), + goPool: nil, + } + + for _, opt := range opts { + opt(ngin) + } + + if ngin.goPool == nil { + ngin.goPool, _ = ants.NewPool(math.MaxInt32) + } + + return ngin +} + +func (ngin *Engine) Use(middleware ...rt.Middleware) { + ngin.middlewares = append(ngin.middlewares, middleware...) +} + +func (ngin *Engine) AddRoutes(rs ...rt.Route) { + ngin.routes = append(ngin.routes, rs...) + err := ngin.bindRoutes() + nlog.Must(err) +} + +func (ngin *Engine) bindRoutes() error { + for _, fr := range ngin.routes { + if err := ngin.bindRoute(fr); err != nil { + return err + } + } + return nil +} + +func (ngin *Engine) bindRoute(route rt.Route) error { + // TODO default middleware + chain := rt.NewChain() + // build chain + for _, middleware := range ngin.middlewares { + chain.Append(rt.ConvertMiddleware(middleware)) + } + return ngin.router.Register(route.Matches, route.Handler) +} + +func (ngin *Engine) setup() error { + if err := ngin.bindRoutes(); err != nil { + return err + } + if err := ngin.goPool.Submit(func() { + scheduler.Schedule(ngin.TaskTimerPrecision) + }); err != nil { + return err + } + return nil +} + +func (ngin *Engine) Stop() { + nlog.Infof("%s is stopping...", ngin.LogPrefix()) + close(ngin.dieChan) + scheduler.Close() +} + +func (ngin *Engine) handle(conn net.Conn) *connection.Connection { + nc := connection.NewConnection( + ngin.sessIdMgr.SessionID(), + conn, + connection.Config{LogDebug: ngin.ShallLogDebug(), LogPrefix: ngin.LogPrefix()}, + ngin.packerBuilder, ngin.serializer, ngin.pipeline, + ngin.router.Handle, + ) + + nc.Serve() + + err := ngin.connManager.Store(connection.DefaultGroupName, nc) + nlog.Must(err) + + // dieChan + go func() { + // lifetime + ngin.lifetime.Open(nc) + + select { + case <-nc.DieChan(): + scheduler.PushTask(func() { ngin.lifetime.Close(nc) }) + _ = ngin.connManager.Remove(nc) + } + }() + + return nc +} diff --git a/entity/network.go b/entity/network.go deleted file mode 100644 index f686329..0000000 --- a/entity/network.go +++ /dev/null @@ -1,24 +0,0 @@ -package entity - -import "net" - -type NetworkEntity interface { - // Send 主动发送消息,支持自定义header,payload - Send(header, payload interface{}) error - // SendBytes 主动发送消息,消息需提前编码 - SendBytes(data []byte) error - // Status 获取当前连接状态 - Status() int32 - // SetStatus 设置当前连接状态 - SetStatus(s int32) - // Conn 获取当前底层连接(还需根据返回参数2决定是否转换为WSConn) - Conn() (net.Conn, bool) - // Session 获取当前连接 Session - Session() Session - // LastMID 最新消息ID - LastMID() uint64 - // SetLastMID 设置消息ID - SetLastMID(mid uint64) - // Close 关闭连接 - Close() error -} diff --git a/entity/session.go b/entity/session.go deleted file mode 100644 index 598a5bc..0000000 --- a/entity/session.go +++ /dev/null @@ -1,26 +0,0 @@ -package entity - -type Session interface { - // ID 获取 Session ID - ID() int64 - // UID 获取UID - UID() string - // Bind 绑定uid - Bind(uid string) - // Attribute 获取指定key对应参数 - Attribute(key string) interface{} - // Keys 获取所有参数key - Keys() []string - // Exists 指定key是否存在 - Exists(key string) bool - // Attributes 获取所有参数 - Attributes() map[string]interface{} - // RemoveAttribute 移除指定key对应参数 - RemoveAttribute(key string) - // SetAttribute 设置参数 - SetAttribute(key string, value interface{}) - // Invalidate 清理 - Invalidate() - // Close 关闭 Session - Close() -} diff --git a/lifetime/lifetime.go b/lifetime/lifetime.go index aef5ddb..e47b654 100644 --- a/lifetime/lifetime.go +++ b/lifetime/lifetime.go @@ -1,9 +1,11 @@ package lifetime -import "git.noahlan.cn/noahlan/nnet/entity" +import ( + "git.noahlan.cn/noahlan/nnet/connection" +) type ( - Handler func(entity entity.NetworkEntity) + Handler func(conn *connection.Connection) Lifetime interface { OnClosed(h Handler) @@ -31,22 +33,22 @@ func (lt *Mgr) OnOpen(h Handler) { lt.onOpen = append(lt.onOpen, h) } -func (lt *Mgr) Open(entity entity.NetworkEntity) { +func (lt *Mgr) Open(conn *connection.Connection) { if len(lt.onOpen) <= 0 { return } for _, handler := range lt.onOpen { - handler(entity) + handler(conn) } } -func (lt *Mgr) Close(entity entity.NetworkEntity) { +func (lt *Mgr) Close(conn *connection.Connection) { if len(lt.onClosed) <= 0 { return } for _, handler := range lt.onClosed { - handler(entity) + handler(conn) } } diff --git a/middleware/heartbeat.go b/middleware/heartbeat.go index d06e98b..17a87cf 100644 --- a/middleware/heartbeat.go +++ b/middleware/heartbeat.go @@ -1,9 +1,10 @@ package middleware import ( - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" + rt "git.noahlan.cn/noahlan/nnet/router" "git.noahlan.cn/noahlan/ntools-go/core/nlog" "sync/atomic" "time" @@ -12,10 +13,10 @@ import ( type HeartbeatMiddleware struct { lastAt int64 interval time.Duration - hbdFn func(entity entity.NetworkEntity) []byte + hbdFn func(conn *connection.Connection) []byte } -func WithHeartbeat(interval time.Duration, hbdFn func(entity entity.NetworkEntity) []byte) core.RunOption { +func WithHeartbeat(interval time.Duration, hbdFn func(conn *connection.Connection) []byte) nnet.RunOption { m := &HeartbeatMiddleware{ lastAt: time.Now().Unix(), interval: interval, @@ -26,20 +27,20 @@ func WithHeartbeat(interval time.Duration, hbdFn func(entity entity.NetworkEntit panic("dataFn must not be nil") } - return func(server *core.NNet) { - server.Lifetime().OnOpen(m.start) + return func(ngin *nnet.Engine) { + ngin.Lifetime().OnOpen(m.start) - server.Use(func(next core.HandlerFunc) core.HandlerFunc { - return func(entity entity.NetworkEntity, pkg packet.IPacket) { - m.handle(entity, pkg) + ngin.Use(func(next rt.HandlerFunc) rt.HandlerFunc { + return func(conn *connection.Connection, pkg packet.IPacket) { + m.handle(conn, pkg) - next(entity, pkg) + next(conn, pkg) } }) } } -func (m *HeartbeatMiddleware) start(entity entity.NetworkEntity) { +func (m *HeartbeatMiddleware) start(conn *connection.Connection) { ticker := time.NewTicker(m.interval) defer func() { @@ -54,7 +55,7 @@ func (m *HeartbeatMiddleware) start(entity entity.NetworkEntity) { nlog.Errorf("Heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&m.lastAt), deadline) return } - err := entity.SendBytes(m.hbdFn(entity)) + err := conn.SendBytes(m.hbdFn(conn)) if err != nil { nlog.Errorf("Heartbeat err: %v", err) return @@ -63,6 +64,6 @@ func (m *HeartbeatMiddleware) start(entity entity.NetworkEntity) { } } -func (m *HeartbeatMiddleware) handle(_ entity.NetworkEntity, _ packet.IPacket) { +func (m *HeartbeatMiddleware) handle(_ *connection.Connection, _ packet.IPacket) { atomic.StoreInt64(&m.lastAt, time.Now().Unix()) } diff --git a/options.go b/options.go new file mode 100644 index 0000000..b9ddca1 --- /dev/null +++ b/options.go @@ -0,0 +1,114 @@ +package nnet + +import ( + "git.noahlan.cn/noahlan/nnet/connection" + "git.noahlan.cn/noahlan/nnet/lifetime" + "git.noahlan.cn/noahlan/nnet/packet" + rt "git.noahlan.cn/noahlan/nnet/router" + "git.noahlan.cn/noahlan/nnet/serialize" + "git.noahlan.cn/noahlan/ntools-go/core/pool" + "github.com/panjf2000/ants/v2" + "time" +) + +type ( + // RunOption defines the method to customize an Engine. + RunOption func(ngin *Engine) +) + +// Pipeline returns inner pipeline +func (ngin *Engine) Pipeline() connection.Pipeline { + return ngin.pipeline +} + +// Lifetime returns lifetime interface. +func (ngin *Engine) Lifetime() lifetime.Lifetime { + return ngin.lifetime +} + +// ConnManager returns connection manager +func (ngin *Engine) ConnManager() *connection.Manager { + return ngin.connManager +} + +//////////////////////// Options + +func WithMiddleware(middleware ...rt.Middleware) RunOption { + return func(ngin *Engine) { + ngin.Use(middleware...) + } +} + +// WithRouter 设置消息路由 +func WithRouter(router rt.Router) RunOption { + return func(ngin *Engine) { + ngin.router = router + } +} + +// WithNotFoundHandler returns a RunOption with not found handler set to given handler. +func WithNotFoundHandler(handler rt.Handler) RunOption { + return func(ngin *Engine) { + ngin.router.SetNotFoundHandler(rt.NotFoundHandler(handler)) + } +} + +// WithTimerPrecision 设置Timer精度,需在 Start 或 Dial 之前执行 +// 注:精度需大于1ms, 并且不能在运行时更改 +// 默认精度是 time.Second +func WithTimerPrecision(precision time.Duration) RunOption { + if precision < time.Millisecond { + panic("time precision can not less than a Millisecond") + } + return func(ngin *Engine) { + ngin.TaskTimerPrecision = precision + } +} + +// WithPackerBuilder 设置 消息的封包/解包构造器 +func WithPackerBuilder(fn packet.PackerBuilder) RunOption { + return func(ngin *Engine) { + ngin.packerBuilder = fn + } +} + +// WithSerializer 设置消息的 序列化/反序列化 方式 +func WithSerializer(s serialize.Serializer) RunOption { + return func(ngin *Engine) { + ngin.serializer = s + } +} + +// WithPool 设置使用自定义的工作池 +func WithPool(pl *ants.Pool) RunOption { + return func(ngin *Engine) { + ngin.goPool = pl + } +} + +// WithPoolCfg 设置工作池配置 +func WithPoolCfg(cfg pool.Config) RunOption { + return func(ngin *Engine) { + ngin.goPool, _ = ants.NewPool(cfg.PoolSize, ants.WithOptions(cfg.Options())) + } +} + +//////////////////// Pipeline + +// WithPipeline 使用自定义 pipeline +func WithPipeline(pipeline connection.Pipeline) RunOption { + return func(ngin *Engine) { + ngin.pipeline = pipeline + } +} + +type PipelineOption func(opts connection.Pipeline) + +// WithPipelineOpt 使用默认Pipeline并设置其配置 +func WithPipelineOpt(opts ...func(connection.Pipeline)) RunOption { + return func(ngin *Engine) { + for _, opt := range opts { + opt(ngin.pipeline) + } + } +} diff --git a/packet/packet.go b/packet/packet.go index 5707124..b24fcf0 100644 --- a/packet/packet.go +++ b/packet/packet.go @@ -24,5 +24,6 @@ type ( Unpack(data []byte) ([]IPacket, error) } - NewPackerFunc func() Packer + // PackerBuilder Packer构建器 + PackerBuilder func() Packer ) diff --git a/protocol/modbus/crc.go b/protocol/modbus/crc.go new file mode 100644 index 0000000..1ef332b --- /dev/null +++ b/protocol/modbus/crc.go @@ -0,0 +1,43 @@ +package modbus + +import "sync" + +var crcTable []uint16 +var mu sync.Mutex + +// crcModbus 计算modbus的crc +func crcModbus(data []byte) (crc uint16) { + if crcTable == nil { + mu.Lock() + if crcTable == nil { + initCrcTable() + } + mu.Unlock() + } + crc = 0xffff + for _, v := range data { + crc = (crc >> 8) ^ crcTable[(crc^uint16(v))&0x00FF] + } + return crc +} + +// initCrcTable 初始化crcTable +func initCrcTable() { + crc16IBM := uint16(0xA001) + crcTable = make([]uint16, 256) + + for i := uint16(0); i < 256; i++ { + crc := uint16(0) + c := i + + for j := uint16(0); j < 8; j++ { + if ((crc ^ c) & 0x0001) > 0 { + crc = (crc >> 1) ^ crc16IBM + } else { + crc = crc >> 1 + } + c = c >> 1 + } + crcTable[i] = crc + } +} diff --git a/protocol/modbus/crc_test.go b/protocol/modbus/crc_test.go new file mode 100644 index 0000000..b91bdf9 --- /dev/null +++ b/protocol/modbus/crc_test.go @@ -0,0 +1,14 @@ +package modbus + +import ( + "git.noahlan.cn/noahlan/nnet/protocol/modbus/internal" + "testing" +) + +func TestCRC(t *testing.T) { + got := crcModbus([]byte{0x01, 0x04, 0x02, 0xFF, 0xFF}) + expect := uint16(0x80B8) + + assert := internal.NewAssert(t, "TestCRC") + assert.Equal(expect, got) +} diff --git a/protocol/modbus/internal/assert.go b/protocol/modbus/internal/assert.go new file mode 100644 index 0000000..e6a490e --- /dev/null +++ b/protocol/modbus/internal/assert.go @@ -0,0 +1,167 @@ +package internal + +import ( + "fmt" + "reflect" + "runtime" + "testing" +) + +const ( + compareNotEqual int = iota - 2 + compareLess + compareEqual + compareGreater +) + +// Assert is a simple implementation of assertion, only for internal usage +type Assert struct { + T *testing.T + CaseName string +} + +// NewAssert return instance of Assert +func NewAssert(t *testing.T, caseName string) *Assert { + return &Assert{T: t, CaseName: caseName} +} + +// Equal check if expected is equal with actual +func (a *Assert) Equal(expected, actual any) { + if compare(expected, actual) != compareEqual { + makeTestFailed(a.T, a.CaseName, expected, actual) + } +} + +// NotEqual check if expected is not equal with actual +func (a *Assert) NotEqual(expected, actual any) { + if compare(expected, actual) == compareEqual { + expectedInfo := fmt.Sprintf("not %v", expected) + makeTestFailed(a.T, a.CaseName, expectedInfo, actual) + } +} + +// Greater check if expected is greate than actual +func (a *Assert) Greater(expected, actual any) { + if compare(expected, actual) != compareGreater { + expectedInfo := fmt.Sprintf("> %v", expected) + makeTestFailed(a.T, a.CaseName, expectedInfo, actual) + } +} + +// GreaterOrEqual check if expected is greate than or equal with actual +func (a *Assert) GreaterOrEqual(expected, actual any) { + isGreatOrEqual := compare(expected, actual) == compareGreater || compare(expected, actual) == compareEqual + if !isGreatOrEqual { + expectedInfo := fmt.Sprintf(">= %v", expected) + makeTestFailed(a.T, a.CaseName, expectedInfo, actual) + } +} + +// Less check if expected is less than actual +func (a *Assert) Less(expected, actual any) { + if compare(expected, actual) != compareLess { + expectedInfo := fmt.Sprintf("< %v", expected) + makeTestFailed(a.T, a.CaseName, expectedInfo, actual) + } +} + +// LessOrEqual check if expected is less than or equal with actual +func (a *Assert) LessOrEqual(expected, actual any) { + isLessOrEqual := compare(expected, actual) == compareLess || compare(expected, actual) == compareEqual + if !isLessOrEqual { + expectedInfo := fmt.Sprintf("<= %v", expected) + makeTestFailed(a.T, a.CaseName, expectedInfo, actual) + } +} + +// IsNil check if value is nil +func (a *Assert) IsNil(value any) { + if value != nil { + makeTestFailed(a.T, a.CaseName, nil, value) + } +} + +// IsNotNil check if value is not nil +func (a *Assert) IsNotNil(value any) { + if value == nil { + makeTestFailed(a.T, a.CaseName, "not nil", value) + } +} + +// compare x and y return : +// x > y -> 1, x < y -> -1, x == y -> 0, x != y -> -2 +func compare(x, y any) int { + vx := reflect.ValueOf(x) + vy := reflect.ValueOf(y) + + if vx.Type() != vy.Type() { + return compareNotEqual + } + + switch vx.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + xInt := vx.Int() + yInt := vy.Int() + if xInt > yInt { + return compareGreater + } + if xInt == yInt { + return compareEqual + } + if xInt < yInt { + return compareLess + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + xUint := vx.Uint() + yUint := vy.Uint() + if xUint > yUint { + return compareGreater + } + if xUint == yUint { + return compareEqual + } + if xUint < yUint { + return compareLess + } + case reflect.Float32, reflect.Float64: + xFloat := vx.Float() + yFloat := vy.Float() + if xFloat > yFloat { + return compareGreater + } + if xFloat == yFloat { + return compareEqual + } + if xFloat < yFloat { + return compareLess + } + case reflect.String: + xString := vx.String() + yString := vy.String() + if xString > yString { + return compareGreater + } + if xString == yString { + return compareEqual + } + if xString < yString { + return compareLess + } + default: + if reflect.DeepEqual(x, y) { + return compareEqual + } + return compareNotEqual + } + + return compareNotEqual + +} + +// logFailedInfo make test failed and log error info +func makeTestFailed(t *testing.T, caseName string, expected, actual any) { + _, file, line, _ := runtime.Caller(2) + errInfo := fmt.Sprintf("Case %v failed. file: %v, line: %v, expected: %v, actual: %v.", caseName, file, line, expected, actual) + t.Error(errInfo) + t.FailNow() +} diff --git a/protocol/modbus/internal/assert_test.go b/protocol/modbus/internal/assert_test.go new file mode 100644 index 0000000..ae0bcfa --- /dev/null +++ b/protocol/modbus/internal/assert_test.go @@ -0,0 +1,50 @@ +package internal + +import ( + "testing" +) + +func TestAssert(t *testing.T) { + assert := NewAssert(t, "TestAssert") + assert.Equal(0, 0) + assert.NotEqual(1, 0) + + assert.NotEqual("1", 1) + var uInt1 uint + var uInt2 uint + var uInt8 uint8 + var uInt16 uint16 + var uInt32 uint32 + var uInt64 uint64 + assert.NotEqual(uInt1, uInt8) + assert.NotEqual(uInt8, uInt16) + assert.NotEqual(uInt16, uInt32) + assert.NotEqual(uInt32, uInt64) + + assert.Equal(uInt1, uInt2) + + uInt1 = 1 + uInt2 = 2 + assert.Less(uInt1, uInt2) + + assert.Greater(1, 0) + assert.GreaterOrEqual(1, 1) + assert.Less(0, 1) + assert.LessOrEqual(0, 0) + + assert.Equal(0.1, 0.1) + assert.Greater(1.1, 0.1) + assert.Less(0.1, 1.1) + + assert.Equal("abc", "abc") + assert.NotEqual("abc", "abd") + assert.Less("abc", "abd") + assert.Greater("abd", "abc") + + assert.Equal([]int{1, 2, 3}, []int{1, 2, 3}) + assert.NotEqual([]int{1, 2, 3}, []int{1, 2}) + + assert.IsNil(nil) + assert.IsNotNil("abc") + +} diff --git a/protocol/nnet/client_pipeline_nnet.go b/protocol/nnet/client_pipeline_nnet.go index f33402d..ab83d81 100644 --- a/protocol/nnet/client_pipeline_nnet.go +++ b/protocol/nnet/client_pipeline_nnet.go @@ -4,22 +4,22 @@ import ( "encoding/json" "errors" "fmt" - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/ntools-go/core/nlog" ) type OnReadyFunc func() -func WithNNetClientPipeline(onReady OnReadyFunc, packer packet.Packer) core.RunOption { - return func(server *core.NNet) { - server.Pipeline().Inbound().PushFront(func(entity entity.NetworkEntity, v interface{}) error { +func WithNNetClientPipeline(onReady OnReadyFunc, packer packet.Packer) nnet.RunOption { + return func(ngin *nnet.Engine) { + ngin.Pipeline().Inbound().PushFront(func(conn *connection.Connection, v interface{}) error { pkg, ok := v.(*Packet) if !ok { return packet.ErrWrongPacketType } - conn, _ := entity.Conn() + nc, _ := conn.Conn() // Server to client switch pkg.PacketType { @@ -32,22 +32,22 @@ func WithNNetClientPipeline(onReady OnReadyFunc, packer packet.Packer) core.RunO PacketType: HandshakeAck, MessageHeader: MessageHeader{}, }, nil) - if err := entity.SendBytes(hrd); err != nil { + if err := conn.SendBytes(hrd); err != nil { return err } - entity.SetStatus(core.StatusWorking) + conn.SetStatus(connection.StatusWorking) // onReady if onReady != nil { onReady() } - nlog.Debugf("connection handshake Id=%d, Remote=%s", entity.Session().ID(), conn.RemoteAddr()) + nlog.Debugf("connection handshake Id=%d, Remote=%s", conn.Session().ID(), nc.RemoteAddr()) case Kick: - _ = entity.Close() + _ = conn.Close() case Data: - status := entity.Status() - if status != core.StatusWorking { + status := conn.Status() + if status != connection.StatusWorking { return errors.New(fmt.Sprintf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s", - conn.RemoteAddr())) + nc.RemoteAddr())) } var lastMid uint64 @@ -57,7 +57,7 @@ func WithNNetClientPipeline(onReady OnReadyFunc, packer packet.Packer) core.RunO case Notify: lastMid = 0 } - entity.SetLastMID(lastMid) + conn.SetLastMID(lastMid) } return nil }) diff --git a/protocol/nnet/nnet.go b/protocol/nnet/nnet.go index 48fa2ec..de8c224 100644 --- a/protocol/nnet/nnet.go +++ b/protocol/nnet/nnet.go @@ -1,8 +1,8 @@ package nnet import ( - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/middleware" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/ntools-go/core/nlog" @@ -36,19 +36,19 @@ type ( } ) -func WithNNetClientProtocol(onReady OnReadyFunc) []core.RunOption { +func WithNNetClientProtocol(onReady OnReadyFunc) []nnet.RunOption { router := NewRouter().(*nRouter) packer := NewPacker(router.routeMap) - opts := []core.RunOption{ + opts := []nnet.RunOption{ WithNNetClientPipeline(onReady, packer), - core.WithRouter(router), - core.WithPacker(func() packet.Packer { return NewPacker(router.routeMap) }), + nnet.WithRouter(router), + nnet.WithPackerBuilder(func() packet.Packer { return NewPacker(router.routeMap) }), } return opts } -func WithNNetProtocol(config Config) []core.RunOption { +func WithNNetProtocol(config Config) []nnet.RunOption { if config.HandshakeValidator == nil { config.HandshakeValidator = func(data *HandshakeReq) error { return nil @@ -61,17 +61,17 @@ func WithNNetProtocol(config Config) []core.RunOption { } packer := NewPacker(router.routeMap) - opts := []core.RunOption{ + opts := []nnet.RunOption{ withNNetPipeline(handshakeAckData, config.HandshakeValidator, packer), - core.WithRouter(router), - core.WithPacker(func() packet.Packer { return NewPacker(router.routeMap) }), + nnet.WithRouter(router), + nnet.WithPackerBuilder(func() packet.Packer { return NewPacker(router.routeMap) }), } if config.HeartbeatInterval.Seconds() > 0 { hbd, err := packer.Pack(Heartbeat, nil) nlog.Must(err) - opts = append(opts, middleware.WithHeartbeat(config.HeartbeatInterval, func(_ entity.NetworkEntity) []byte { + opts = append(opts, middleware.WithHeartbeat(config.HeartbeatInterval, func(_ *connection.Connection) []byte { return hbd })) } diff --git a/protocol/nnet/pipeline_nnet.go b/protocol/nnet/pipeline_nnet.go index 577af55..67d3d0c 100644 --- a/protocol/nnet/pipeline_nnet.go +++ b/protocol/nnet/pipeline_nnet.go @@ -4,8 +4,8 @@ import ( "encoding/json" "errors" "fmt" - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/ntools-go/core/nlog" ) @@ -19,14 +19,14 @@ func withNNetPipeline( handshakeResp *HandshakeResp, validator HandshakeValidatorFunc, packer packet.Packer, -) core.RunOption { - return func(server *core.NNet) { - server.Pipeline().Inbound().PushFront(func(entity entity.NetworkEntity, v interface{}) error { +) nnet.RunOption { + return func(ngin *nnet.Engine) { + ngin.Pipeline().Inbound().PushFront(func(conn *connection.Connection, v interface{}) error { pkg, ok := v.(*Packet) if !ok { return packet.ErrWrongPacketType } - conn, _ := entity.Conn() + nc, _ := conn.Conn() switch pkg.PacketType { case Handshake: @@ -46,20 +46,20 @@ func withNNetPipeline( PacketType: Handshake, MessageHeader: MessageHeader{}, }, data) - if err := entity.SendBytes(hrd); err != nil { + if err := conn.SendBytes(hrd); err != nil { return err } - entity.SetStatus(core.StatusPrepare) - nlog.Debugf("connection handshake Id=%d, Remote=%s", entity.Session().ID(), conn.RemoteAddr()) + conn.SetStatus(connection.StatusPrepare) + nlog.Debugf("connection handshake Id=%d, Remote=%s", conn.Session().ID(), nc.RemoteAddr()) case HandshakeAck: - entity.SetStatus(core.StatusPending) - nlog.Debugf("receive handshake ACK Id=%d, Remote=%s", entity.Session().ID(), conn.RemoteAddr()) + conn.SetStatus(connection.StatusPending) + nlog.Debugf("receive handshake ACK Id=%d, Remote=%s", conn.Session().ID(), nc.RemoteAddr()) case Data: - if entity.Status() < core.StatusPending { + if conn.Status() < connection.StatusPending { return errors.New(fmt.Sprintf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s", - conn.RemoteAddr())) + nc.RemoteAddr())) } - entity.SetStatus(core.StatusWorking) + conn.SetStatus(connection.StatusWorking) var lastMid uint64 switch pkg.MsgType { @@ -70,7 +70,7 @@ func withNNetPipeline( default: return fmt.Errorf("Invalid message type: %s ", pkg.MsgType.String()) } - entity.SetLastMID(lastMid) + conn.SetLastMID(lastMid) } return nil }) diff --git a/protocol/nnet/router_nnet.go b/protocol/nnet/router_nnet.go index cc14bc0..5c2a1a2 100644 --- a/protocol/nnet/router_nnet.go +++ b/protocol/nnet/router_nnet.go @@ -3,9 +3,9 @@ package nnet import ( "errors" "fmt" - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" + rt "git.noahlan.cn/noahlan/nnet/router" "git.noahlan.cn/noahlan/ntools-go/core/nlog" ) @@ -23,8 +23,8 @@ type ( nRouter struct { routeMap *RouteMap - handlers map[string]core.Handler - notFound core.Handler + handlers map[string]rt.Handler + notFound rt.Handler } ) @@ -35,14 +35,14 @@ func NewRouteMap() *RouteMap { } } -func NewRouter() core.Router { +func NewRouter() rt.Router { return &nRouter{ routeMap: NewRouteMap(), - handlers: make(map[string]core.Handler), + handlers: make(map[string]rt.Handler), } } -func (r *nRouter) Handle(entity entity.NetworkEntity, p packet.IPacket) { +func (r *nRouter) Handle(conn *connection.Connection, p packet.IPacket) { pkg, ok := p.(*Packet) if !ok { nlog.Error(packet.ErrWrongPacketType) @@ -54,13 +54,13 @@ func (r *nRouter) Handle(entity entity.NetworkEntity, p packet.IPacket) { nlog.Error("message handler not found") return } - r.notFound.Handle(entity, p) + r.notFound.Handle(conn, p) return } - handler.Handle(entity, p) + handler.Handle(conn, p) } -func (r *nRouter) Register(matches interface{}, handler core.Handler) error { +func (r *nRouter) Register(matches interface{}, handler rt.Handler) error { match, ok := matches.(Match) if !ok { return errors.New(fmt.Sprintf("the type of matches must be %T", Match{})) @@ -72,6 +72,6 @@ func (r *nRouter) Register(matches interface{}, handler core.Handler) error { return nil } -func (r *nRouter) SetNotFoundHandler(handler core.Handler) { +func (r *nRouter) SetNotFoundHandler(handler rt.Handler) { r.notFound = handler } diff --git a/protocol/plain/pipeline_plain.go b/protocol/plain/pipeline_plain.go index 015a951..f27ede3 100644 --- a/protocol/plain/pipeline_plain.go +++ b/protocol/plain/pipeline_plain.go @@ -1,20 +1,20 @@ package plain import ( - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" ) -func withPipeline() core.RunOption { - return func(net *core.NNet) { - net.Pipeline().Inbound().PushFront(func(et entity.NetworkEntity, v interface{}) error { +func withPipeline() nnet.RunOption { + return func(ngin *nnet.Engine) { + ngin.Pipeline().Inbound().PushFront(func(conn *connection.Connection, v interface{}) error { _, ok := v.(*Packet) if !ok { return packet.ErrWrongPacketType } - if et.Status() != core.StatusWorking { - et.SetStatus(core.StatusWorking) + if conn.Status() != connection.StatusWorking { + conn.SetStatus(connection.StatusWorking) } return nil }) diff --git a/protocol/plain/plain.go b/protocol/plain/plain.go index f851d84..b238bbf 100644 --- a/protocol/plain/plain.go +++ b/protocol/plain/plain.go @@ -1,15 +1,15 @@ package plain import ( - "git.noahlan.cn/noahlan/nnet/core" + "git.noahlan.cn/noahlan/nnet" "git.noahlan.cn/noahlan/nnet/packet" ) -func WithPlainProtocol() []core.RunOption { - opts := []core.RunOption{ +func WithPlainProtocol() []nnet.RunOption { + opts := []nnet.RunOption{ withPipeline(), - core.WithRouter(NewRouter()), - core.WithPacker(func() packet.Packer { return NewPacker() }), + nnet.WithRouter(NewRouter()), + nnet.WithPackerBuilder(func() packet.Packer { return NewPacker() }), } return opts } diff --git a/protocol/plain/router_plain.go b/protocol/plain/router_plain.go index 503304b..3416a0c 100644 --- a/protocol/plain/router_plain.go +++ b/protocol/plain/router_plain.go @@ -1,22 +1,22 @@ package plain import ( - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" + "git.noahlan.cn/noahlan/nnet/router" "git.noahlan.cn/noahlan/ntools-go/core/nlog" ) type Router struct { - plainHandler core.Handler - notFound core.Handler + plainHandler router.Handler + notFound router.Handler } -func NewRouter() core.Router { +func NewRouter() router.Router { return &Router{} } -func (r *Router) Handle(entity entity.NetworkEntity, pkg packet.IPacket) { +func (r *Router) Handle(conn *connection.Connection, pkg packet.IPacket) { p, ok := pkg.(*Packet) if !ok { nlog.Error(packet.ErrWrongPacketType) @@ -27,17 +27,17 @@ func (r *Router) Handle(entity entity.NetworkEntity, pkg packet.IPacket) { nlog.Error("message handler not found") return } - r.notFound.Handle(entity, p) + r.notFound.Handle(conn, p) return } - r.plainHandler.Handle(entity, p) + r.plainHandler.Handle(conn, p) } -func (r *Router) Register(_ interface{}, handler core.Handler) error { +func (r *Router) Register(_ interface{}, handler router.Handler) error { r.plainHandler = handler return nil } -func (r *Router) SetNotFoundHandler(handler core.Handler) { +func (r *Router) SetNotFoundHandler(handler router.Handler) { r.notFound = handler } diff --git a/router/options.go b/router/options.go new file mode 100644 index 0000000..7cbb289 --- /dev/null +++ b/router/options.go @@ -0,0 +1,30 @@ +package router + +// ToMiddleware converts the given handler to a Middleware. +func ToMiddleware(handler func(next Handler) Handler) Middleware { + return func(next HandlerFunc) HandlerFunc { + return handler(next).Handle + } +} + +// WithMiddlewares adds given middlewares to given routes. +func WithMiddlewares(ms []Middleware, rs ...Route) []Route { + for i := len(ms) - 1; i >= 0; i-- { + rs = WithMiddleware(ms[i], rs...) + } + return rs +} + +// WithMiddleware adds given middleware to given route. +func WithMiddleware(middleware Middleware, rs ...Route) []Route { + routes := make([]Route, len(rs)) + + for i := range rs { + route := rs[i] + routes[i] = Route{ + Matches: route.Matches, + Handler: middleware(route.Handler), + } + } + return routes +} diff --git a/core/router.go b/router/router.go similarity index 59% rename from core/router.go rename to router/router.go index 39c23a2..d5746c9 100644 --- a/core/router.go +++ b/router/router.go @@ -1,11 +1,20 @@ -package core +package router import ( - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" ) type ( + Handler interface { + Handle(c *connection.Connection, pkg packet.IPacket) + } + // HandlerFunc 消息处理方法 + HandlerFunc func(conn *connection.Connection, pkg packet.IPacket) + + Middleware func(next HandlerFunc) HandlerFunc + Route struct { Matches interface{} // 用于匹配的关键字段 Handler HandlerFunc // 处理方法 @@ -20,11 +29,31 @@ type ( Constructor func(Handler) Handler ) +func notFound(conn *connection.Connection, _ packet.IPacket) { + nlog.Error("handler not found") + _ = conn.SendBytes([]byte("404")) +} + +func NotFoundHandler(next Handler) Handler { + return HandlerFunc(func(c *connection.Connection, packet packet.IPacket) { + h := next + if next == nil { + h = HandlerFunc(notFound) + } + // TODO write to client + h.Handle(c, packet) + }) +} + +func (f HandlerFunc) Handle(c *connection.Connection, pkg packet.IPacket) { + f(c, pkg) +} + type Chain struct { constructors []Constructor } -func newChain(constructors ...Constructor) Chain { +func NewChain(constructors ...Constructor) Chain { return Chain{append(([]Constructor)(nil), constructors...)} } @@ -65,11 +94,11 @@ func NewDefaultRouter() Router { return &plainRouter{} } -func (p *plainRouter) Handle(entity entity.NetworkEntity, pkg packet.IPacket) { +func (p *plainRouter) Handle(c *connection.Connection, pkg packet.IPacket) { if p.handler == nil { return } - p.handler.Handle(entity, pkg) + p.handler.Handle(c, pkg) } func (p *plainRouter) Register(_ interface{}, handler Handler) error { diff --git a/router/util.go b/router/util.go new file mode 100644 index 0000000..fc5663f --- /dev/null +++ b/router/util.go @@ -0,0 +1,7 @@ +package router + +func ConvertMiddleware(ware Middleware) func(Handler) Handler { + return func(next Handler) Handler { + return ware(next.Handle) + } +} diff --git a/server_tcp.go b/server_tcp.go new file mode 100644 index 0000000..a889f30 --- /dev/null +++ b/server_tcp.go @@ -0,0 +1,46 @@ +package nnet + +import ( + "errors" + "git.noahlan.cn/noahlan/nnet/config" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "net" +) + +func (ngin *Engine) ListenTCP(conf config.TCPServerConf) error { + err := ngin.setup() + if err != nil { + nlog.Errorf("%s failed to setup server, err:%v", ngin.LogPrefix(), err) + return err + } + + listener, err := net.Listen(conf.Protocol, conf.Addr) + if err != nil { + nlog.Errorf("%s failed to listening at [%s %s] %v", ngin.LogPrefix(), conf.Protocol, conf.Addr, err) + return err + } + nlog.Infof("%s now listening %s at %s...", ngin.LogPrefix(), conf.Protocol, conf.Addr) + defer func() { + _ = listener.Close() + ngin.Stop() + }() + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + nlog.Errorf("%s connection closed, err:%v", ngin.LogPrefix(), err) + return err + } + nlog.Errorf("%s accept connection failed, err:%v", ngin.LogPrefix(), err) + continue + } + + err = ngin.goPool.Submit(func() { + ngin.handle(conn) + }) + if err != nil { + nlog.Errorf("%s submit conn pool err: %ng", ngin.LogPrefix(), err.Error()) + continue + } + } +} diff --git a/server_ws.go b/server_ws.go new file mode 100644 index 0000000..584721b --- /dev/null +++ b/server_ws.go @@ -0,0 +1,66 @@ +package nnet + +import ( + "fmt" + "git.noahlan.cn/noahlan/nnet/config" + "git.noahlan.cn/noahlan/nnet/connection" + "git.noahlan.cn/noahlan/ntools-go/core/nlog" + "github.com/gorilla/websocket" + "net/http" + "os" + "strings" +) + +// ListenWebsocket 开始监听Websocket +func (ngin *Engine) ListenWebsocket(conf config.WSServerConf) error { + err := ngin.setup() + if err != nil { + nlog.Errorf("%s failed to setup server, err:%v", ngin.LogPrefix(), err) + return err + } + nlog.Infof("%s now listening websocket at %s.", ngin.LogPrefix(), conf.Addr) + ngin.upgradeWebsocket(conf) + if conf.IsTLS() { + if err := http.ListenAndServeTLS(conf.Addr, conf.TLSCertificate, conf.TLSKey, nil); err != nil { + nlog.Errorf("%s failed to listening websocket with TLS at %s %v", ngin.LogPrefix(), conf.Addr, err) + return err + } + } else { + if err := http.ListenAndServe(conf.Addr, nil); err != nil { + nlog.Errorf("%s failed to listening websocket at %s %v", ngin.LogPrefix(), conf.Addr, err) + return err + } + } + return nil +} + +func (ngin *Engine) handleWS(conn *websocket.Conn) { + wsConn := connection.NewWSConn(conn) + ngin.handle(wsConn) +} + +func (ngin *Engine) upgradeWebsocket(conf config.WSServerConf) { + upgrade := websocket.Upgrader{ + HandshakeTimeout: conf.HandshakeTimeout, + ReadBufferSize: conf.ReadBufferSize, + WriteBufferSize: conf.WriteBufferSize, + CheckOrigin: conf.CheckOrigin, + EnableCompression: conf.Compression, + } + + path := fmt.Sprintf("/%s", strings.TrimPrefix(conf.Path, "/")) + http.HandleFunc(path, func(writer http.ResponseWriter, request *http.Request) { + conn, err := upgrade.Upgrade(writer, request, nil) + if err != nil { + nlog.Errorf("%s Upgrade failure, URI=%ng, Error=%ng", ngin.LogPrefix(), request.RequestURI, err.Error()) + return + } + err = ngin.goPool.Submit(func() { + ngin.handleWS(conn) + }) + if err != nil { + nlog.Errorf("%s submit conn pool err: %v", ngin.LogPrefix(), err.Error()) + os.Exit(1) + } + }) +} diff --git a/session/session.go b/session/session.go new file mode 100644 index 0000000..8745820 --- /dev/null +++ b/session/session.go @@ -0,0 +1,104 @@ +package session + +import ( + "sync" +) + +type Session struct { + sync.RWMutex // 数据锁 + + id int64 // Session全局唯一ID + uid string // 用户ID,不绑定的情况下与sid一致 + data map[string]interface{} // session数据存储(内存) +} + +func NewSession(id int64) *Session { + return &Session{ + id: id, + uid: "", + data: make(map[string]interface{}), + } +} + +// ID 获取 session ID +func (s *Session) ID() int64 { + return s.id +} + +// UID 获取UID +func (s *Session) UID() string { + return s.uid +} + +// Bind 绑定uid +func (s *Session) Bind(uid string) { + s.uid = uid +} + +// Attribute 获取指定key对应参数 +func (s *Session) Attribute(key string) interface{} { + s.RLock() + defer s.RUnlock() + + return s.data[key] +} + +// Keys 获取所有参数key +func (s *Session) Keys() []string { + s.RLock() + defer s.RUnlock() + + keys := make([]string, 0, len(s.data)) + for k := range s.data { + keys = append(keys, k) + } + return keys +} + +// Exists 指定key是否存在 +func (s *Session) Exists(key string) bool { + s.RLock() + defer s.RUnlock() + + _, has := s.data[key] + return has +} + +// Attributes 获取所有参数 +func (s *Session) Attributes() map[string]interface{} { + s.RLock() + defer s.RUnlock() + + return s.data +} + +// RemoveAttribute 移除指定key对应参数 +func (s *Session) RemoveAttribute(key string) { + s.Lock() + defer s.Unlock() + + delete(s.data, key) +} + +// SetAttribute 设置参数 +func (s *Session) SetAttribute(key string, value interface{}) { + s.Lock() + defer s.Unlock() + + s.data[key] = value +} + +// Invalidate 清理 +func (s *Session) Invalidate() { + s.Lock() + defer s.Unlock() + + s.id = 0 + s.uid = "" + s.data = make(map[string]interface{}) +} + +// Close 关闭 +func (s *Session) Close() { + s.Invalidate() +} diff --git a/session/session_id.go b/session/session_id.go new file mode 100644 index 0000000..90a8f3f --- /dev/null +++ b/session/session_id.go @@ -0,0 +1,38 @@ +package session + +import "sync/atomic" + +type IDMgr struct { + count int64 + sid int64 +} + +func NewSessionIDMgr() *IDMgr { + return &IDMgr{} +} + +// Increment the connection count +func (c *IDMgr) Increment() { + atomic.AddInt64(&c.count, 1) +} + +// Decrement the connection count +func (c *IDMgr) Decrement() { + atomic.AddInt64(&c.count, -1) +} + +// Count returns the connection numbers in current +func (c *IDMgr) Count() int64 { + return atomic.LoadInt64(&c.count) +} + +// Reset the connection service status +func (c *IDMgr) Reset() { + atomic.StoreInt64(&c.count, 0) + atomic.StoreInt64(&c.sid, 0) +} + +// SessionID returns the session id +func (c *IDMgr) SessionID() int64 { + return atomic.AddInt64(&c.sid, 1) +} diff --git a/test/test_nnet.go b/test/test_nnet.go index f2e1dc9..177a112 100644 --- a/test/test_nnet.go +++ b/test/test_nnet.go @@ -2,11 +2,12 @@ package main import ( "encoding/json" + "git.noahlan.cn/noahlan/nnet" "git.noahlan.cn/noahlan/nnet/config" - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" - "git.noahlan.cn/noahlan/nnet/protocol/nnet" + protocol_nnet "git.noahlan.cn/noahlan/nnet/protocol/nnet" + rt "git.noahlan.cn/noahlan/nnet/router" "git.noahlan.cn/noahlan/ntools-go/core/nlog" "git.noahlan.cn/noahlan/ntools-go/core/pool" "math" @@ -14,80 +15,89 @@ import ( ) func runServer(addr string) { - server := core.NewServer(config.EngineConf{ - ServerConf: config.ServerConf{ - Protocol: "tcp", - Addr: addr, - Name: "testServer", - Mode: "dev", - }, - Pool: pool.Config{ - PoolSize: math.MaxInt32, - ExpiryDuration: time.Second, - PreAlloc: false, - MaxBlockingTasks: 0, - Nonblocking: false, - DisablePurge: false, - }, - }, nnet.WithNNetProtocol(nnet.Config{ + nginOpts := make([]nnet.RunOption, 0) + nginOpts = append(nginOpts, nnet.WithPoolCfg(pool.Config{ + PoolSize: math.MaxInt32, + ExpiryDuration: time.Second, + PreAlloc: false, + MaxBlockingTasks: 0, + Nonblocking: false, + DisablePurge: false, + })) + nginOpts = append(nginOpts, protocol_nnet.WithNNetProtocol(protocol_nnet.Config{ HeartbeatInterval: 0, HandshakeValidator: nil, })...) - - server.AddRoutes([]core.Route{ - { - Matches: nnet.Match{ - Route: "ping", - Code: 1, - }, - Handler: func(et entity.NetworkEntity, pkg packet.IPacket) { - nlog.Info("client ping, server pong -> ") - err := et.Send(nnet.Header{ - PacketType: nnet.Data, - MessageHeader: nnet.MessageHeader{ - MsgType: nnet.Request, - ID: 1, - Route: "pong", - }, - }, []byte("1")) - nlog.Must(err) - }, + ngin := nnet.NewEngine(config.EngineConf{ + TaskTimerPrecision: 0, + Mode: "dev", + Name: "NL", + }, nginOpts...) + ngin.AddRoutes(rt.Route{ + Matches: protocol_nnet.Match{ + Route: "ping", + Code: 1, + }, + Handler: func(conn *connection.Connection, pkg packet.IPacket) { + nlog.Info("client ping, server pong -> ") + err := conn.Send(protocol_nnet.Header{ + PacketType: protocol_nnet.Data, + MessageHeader: protocol_nnet.MessageHeader{ + MsgType: protocol_nnet.Request, + ID: 1, + Route: "pong", + }, + }, []byte("1")) + nlog.Must(err) }, }) - defer server.Stop() - server.Start() + defer ngin.Stop() + + err := ngin.ListenTCP(config.TCPServerConf{ + Protocol: "tcp", + Addr: addr, + }) + if err != nil { + return + } + } -func runClient(addr string) (client *core.Client, et entity.NetworkEntity) { +func runClient(addr string) (*nnet.Engine, *connection.Connection) { chReady := make(chan struct{}) - client = core.NewClient(config.EngineConf{ - Pool: pool.Config{ - PoolSize: math.MaxInt32, - ExpiryDuration: time.Second, - PreAlloc: false, - MaxBlockingTasks: 0, - Nonblocking: false, - DisablePurge: false, - }, - }, nnet.WithNNetClientProtocol(func() { + + nginOpts := make([]nnet.RunOption, 0) + nginOpts = append(nginOpts, nnet.WithPoolCfg(pool.Config{ + PoolSize: math.MaxInt32, + ExpiryDuration: time.Second, + PreAlloc: false, + MaxBlockingTasks: 0, + Nonblocking: false, + DisablePurge: false, + })) + nginOpts = append(nginOpts, protocol_nnet.WithNNetClientProtocol(func() { chReady <- struct{}{} })...) - client.AddRoutes([]core.Route{ - { - Matches: nnet.Match{ - Route: "test.client", - Code: 1, - }, - Handler: func(et entity.NetworkEntity, pkg packet.IPacket) { - nlog.Info("client hahaha") - }, + ngin := nnet.NewEngine(config.EngineConf{ + TaskTimerPrecision: 0, + Mode: "dev", + Name: "NL", + }, nginOpts...) + ngin.AddRoutes(rt.Route{ + Matches: protocol_nnet.Match{ + Route: "test.client", + Code: 1, + }, + Handler: func(conn *connection.Connection, pkg packet.IPacket) { + nlog.Info("client hahaha") }, }) - et = client.Dial(addr) + conn, err := ngin.Dial(addr) + nlog.Must(err) - handshake, err := json.Marshal(&nnet.HandshakeReq{ + handshake, err := json.Marshal(&protocol_nnet.HandshakeReq{ Version: "1.0.0", Type: "test", ClientId: "a", @@ -98,10 +108,10 @@ func runClient(addr string) (client *core.Client, et entity.NetworkEntity) { }) nlog.Must(err) - packer := nnet.NewPacker(nnet.NewRouteMap()) - hsd, err := packer.Pack(nnet.Header{ - PacketType: nnet.Handshake, - MessageHeader: nnet.MessageHeader{ + packer := protocol_nnet.NewPacker(protocol_nnet.NewRouteMap()) + hsd, err := packer.Pack(protocol_nnet.Header{ + PacketType: protocol_nnet.Handshake, + MessageHeader: protocol_nnet.MessageHeader{ MsgType: 0, ID: 0, Route: "", @@ -109,9 +119,9 @@ func runClient(addr string) (client *core.Client, et entity.NetworkEntity) { }, handshake) nlog.Must(err) - err = et.SendBytes(hsd) + err = conn.SendBytes(hsd) nlog.Must(err) <-chReady - return + return ngin, conn } diff --git a/test/test_nnet_test.go b/test/test_nnet_test.go index bc8222d..001b1a3 100644 --- a/test/test_nnet_test.go +++ b/test/test_nnet_test.go @@ -1,10 +1,10 @@ package main import ( - "git.noahlan.cn/noahlan/nnet/core" - "git.noahlan.cn/noahlan/nnet/entity" + "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/nnet/protocol/nnet" + rt "git.noahlan.cn/noahlan/nnet/router" "git.noahlan.cn/noahlan/ntools-go/core/nlog" "sync" "testing" @@ -15,13 +15,13 @@ func TestServer(t *testing.T) { } func TestClient(t *testing.T) { - client, et := runClient("127.0.0.1:6666") - client.AddRoute(core.Route{ + ngin, et := runClient("127.0.0.1:6666") + ngin.AddRoutes(rt.Route{ Matches: nnet.Match{ Route: "pong", Code: 2, }, - Handler: func(et entity.NetworkEntity, pkg packet.IPacket) { + Handler: func(conn *connection.Connection, pkg packet.IPacket) { nlog.Info("server pong, client ping ->") _ = et.Send(nnet.Header{ PacketType: nnet.Data,