feat: 暂且使用。

main v0.1.0
NorthLan 2 years ago
parent a2ed3090e7
commit dfbc5cbd63

@ -0,0 +1,5 @@
# NNet 轻量级 TCP/WS/UDP 网络库
===
> 封装了

@ -1,16 +1,16 @@
package component package component
import ( import (
"git.noahlan.cn/northlan/nnet/nface"
"reflect" "reflect"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
) )
var ( var (
typeOfError = reflect.TypeOf((*error)(nil)).Elem() typeOfError = reflect.TypeOf((*error)(nil)).Elem()
typeOfBytes = reflect.TypeOf(([]byte)(nil)) typeOfBytes = reflect.TypeOf(([]byte)(nil))
typeOfConnection = reflect.TypeOf((nface.IConnection)(nil)) // TODO cycle not allow IConnection
typeOfConnection = reflect.TypeOf(([]byte)(nil))
) )
func isExported(name string) bool { func isExported(name string) bool {

@ -0,0 +1,26 @@
package core
const (
// DevMode means development mode.
DevMode = "dev"
// TestMode means test mode.
TestMode = "test"
// ProductionMode means production mode.
ProductionMode = "prod"
)
type (
EngineConf struct {
// Protocol 协议名
// "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
// 若只想开启IPv4, 使用tcp4即可
Protocol string
// Addr 服务地址
// 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP
// 如果端口号不填或端口号为0例如"127.0.0.1:" 或 ":0",服务端将选择随机可用端口
Addr string
// Name 服务端名称默认为n-net
Name string
Mode string
}
)

@ -0,0 +1,283 @@
package core
import (
"errors"
"fmt"
"git.noahlan.cn/northlan/nnet/internal/log"
"git.noahlan.cn/northlan/nnet/internal/packet"
"git.noahlan.cn/northlan/nnet/internal/pool"
"git.noahlan.cn/northlan/nnet/scheduler"
"git.noahlan.cn/northlan/nnet/session"
"net"
"sync/atomic"
"time"
)
var (
ErrCloseClosedSession = errors.New("close closed session")
// ErrBrokenPipe represents the low-level connection has broken.
ErrBrokenPipe = errors.New("broken low-level pipe")
// ErrBufferExceed indicates that the current session buffer is full and
// can not receive more data.
ErrBufferExceed = errors.New("session send buffer exceed")
)
const (
// StatusStart 开始阶段
StatusStart int32 = iota + 1
// StatusPrepare 准备阶段
StatusPrepare
// StatusPending 等待工作阶段
StatusPending
// StatusWorking 工作阶段
StatusWorking
// StatusClosed 连接关闭
StatusClosed
)
type (
Connection struct {
session *session.Session // Session
ngin *engine // engine
conn net.Conn // low-level conn fd
status int32 // 连接状态
lastMid uint64 // 最近一次消息ID
// TODO 考虑独立出去作为一个中间件
lastHeartbeatAt int64 // 最近一次心跳时间
chDie chan struct{} // 停止通道
chSend chan pendingMessage // 消息发送通道(结构化消息)
chWrite chan []byte // 消息发送通道(二进制消息)
}
pendingMessage struct {
header interface{}
payload interface{}
}
)
func newConn(server *engine, conn net.Conn) *Connection {
r := &Connection{
conn: conn,
ngin: server,
status: StatusStart,
lastHeartbeatAt: time.Now().Unix(),
chDie: make(chan struct{}),
chSend: make(chan pendingMessage, 128),
chWrite: make(chan []byte, 128),
}
// binding session
r.session = session.NewSession()
return r
}
func (r *Connection) Send(header, payload interface{}) (err error) {
defer func() {
if e := recover(); e != nil {
err = ErrBrokenPipe
}
}()
r.chSend <- pendingMessage{
header: header,
payload: payload,
}
return err
}
func (r *Connection) SendBytes(data []byte) (err error) {
defer func() {
if e := recover(); e != nil {
err = ErrBrokenPipe
}
}()
r.chWrite <- data
return err
}
func (r *Connection) Status() int32 {
return atomic.LoadInt32(&r.status)
}
func (r *Connection) SetStatus(s int32) {
atomic.StoreInt32(&r.status, s)
}
func (r *Connection) Conn() net.Conn {
return r.conn
}
func (r *Connection) ID() int64 {
return r.session.ID()
}
func (r *Connection) SetLastHeartbeatAt(t int64) {
atomic.StoreInt64(&r.lastHeartbeatAt, t)
}
func (r *Connection) Session() *session.Session {
return r.session
}
func (r *Connection) LastMID() uint64 {
return r.lastMid
}
func (r *Connection) SetLastMID(mid uint64) {
atomic.StoreUint64(&r.lastMid, mid)
}
func (r *Connection) serve() {
_ = pool.SubmitConn(func() {
r.write()
})
_ = pool.SubmitWorker(func() {
r.read()
})
}
func (r *Connection) write() {
ticker := time.NewTicker(r.ngin.heartbeatInterval)
defer func() {
ticker.Stop()
close(r.chSend)
close(r.chWrite)
_ = r.Close()
log.Debugf("Connection write goroutine exit, ConnID=%d, SessionUID=%s", r.ID(), r.session.UID())
}()
for {
select {
case <-ticker.C:
// TODO heartbeat enable control
deadline := time.Now().Add(-2 * r.ngin.heartbeatInterval).Unix()
if atomic.LoadInt64(&r.lastHeartbeatAt) < deadline {
log.Debugf("Session heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&r.lastHeartbeatAt), deadline)
return
}
// TODO heartbeat data
r.chWrite <- []byte{}
case data := <-r.chSend:
// marshal packet body (data)
if r.ngin.serializer == nil {
if _, ok := data.payload.([]byte); !ok {
log.Errorf("serializer is nil, but payload type not []byte")
break
}
} else {
payload, err := r.ngin.serializer.Marshal(data.payload)
if err != nil {
log.Errorf("message body marshal err: %v", err)
break
}
data.payload = payload
}
// invoke pipeline
if pipe := r.ngin.pipeline; pipe != nil {
err := pipe.Outbound().Process(r, data)
if err != nil {
log.Errorf("broken pipeline err: %s", err.Error())
break
}
}
// packet pack data
p, err := r.ngin.packer.Pack(data.header, data.payload.([]byte))
if err != nil {
log.Error(err.Error())
break
}
r.chWrite <- p
case data := <-r.chWrite:
// 回写数据
if _, err := r.conn.Write(data); err != nil {
log.Error(err.Error())
return
}
case <-r.chDie: // connection close signal
return
case <-r.ngin.dieChan: // application quit signal
return
}
}
}
func (r *Connection) read() {
defer func() {
r.Close()
}()
buf := make([]byte, 4096)
for {
n, err := r.conn.Read(buf)
if err != nil {
log.Errorf("Read message error: %s, session will be closed immediately", err.Error())
return
}
if r.ngin.packer == nil {
log.Errorf("unexpected error: packer is nil")
return
}
// warning: 为性能考虑复用slice处理数据buf传入后必须要copy再处理
packets, err := r.ngin.packer.Unpack(buf[:n])
if err != nil {
log.Error(err.Error())
}
// packets 处理
for _, p := range packets {
if err := r.processPacket(p); err != nil {
log.Error(err.Error())
continue
}
}
}
}
func (r *Connection) processPacket(packet packet.IPacket) error {
if pipe := r.ngin.pipeline; pipe != nil {
err := pipe.Inbound().Process(r, packet)
if err != nil {
return errors.New(fmt.Sprintf("pipeline process failed: %v", err.Error()))
}
}
// packet processor
err := r.ngin.processor.Process(r, packet)
if err != nil {
return err
}
if r.Status() == StatusWorking {
// HandleFunc
_ = pool.SubmitWorker(func() {
r.ngin.handler.Handle(r, packet)
})
}
return err
}
func (r *Connection) Close() error {
if r.Status() == StatusClosed {
return ErrCloseClosedSession
}
r.SetStatus(StatusClosed)
log.Debugf("close connection, ID: %d", r.ID())
select {
case <-r.chDie:
default:
close(r.chDie)
scheduler.PushTask(func() { Lifetime.Close(r) })
}
return r.conn.Close()
}

@ -0,0 +1,249 @@
package core
import (
"errors"
"git.noahlan.cn/northlan/nnet/internal/log"
"git.noahlan.cn/northlan/nnet/internal/packet"
"git.noahlan.cn/northlan/nnet/internal/pool"
"git.noahlan.cn/northlan/nnet/scheduler"
"git.noahlan.cn/northlan/nnet/serialize"
"git.noahlan.cn/northlan/nnet/session"
"github.com/gorilla/websocket"
"net"
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"
)
func NotFound(conn *Connection, packet packet.IPacket) {
log.Error("handler not found")
conn.SendBytes([]byte("handler not found"))
}
func NotFoundHandler() Handler {
return HandlerFunc(NotFound)
}
type (
// engine TCP-engine
engine struct {
conf EngineConf // conf 配置
middlewares []Middleware // 中间件
routes []Route // 路由
// handler 消息处理器
handler Handler
// dieChan 应用程序退出信号
dieChan chan struct{}
// sessionMgr session管理器
sessionMgr *session.Manager
pipeline Pipeline // 消息管道
packer packet.Packer // 封包、拆包器
processor Processor // 数据包处理器
serializer serialize.Serializer // 消息 序列化/反序列化
retryInterval time.Duration // 消息重试间隔时长
heartbeatInterval time.Duration // 心跳间隔0表示不进行心跳
wsOpt wsOptions // websocket
}
wsOptions struct {
IsWebsocket bool // 是否为websocket服务端
WebsocketPath string // ws地址(ws://127.0.0.1/WebsocketPath)
TLSCertificate string // TLS 证书地址 (websocket)
TLSKey string // TLS 证书key地址 (websocket)
CheckOrigin func(*http.Request) bool // check origin
}
)
func newEngine(conf EngineConf) *engine {
s := &engine{
conf: conf,
dieChan: make(chan struct{}),
sessionMgr: session.NewSessionMgr(),
}
pool.InitPool(10000)
return s
}
func (ng *engine) use(middleware Middleware) {
ng.middlewares = append(ng.middlewares, middleware)
}
func (ng *engine) addRoutes(route ...Route) {
ng.routes = append(ng.routes, route...)
}
func (ng *engine) bindRoutes(router Router) error {
for _, fr := range ng.routes {
if err := ng.bindRoute(router, fr); err != nil {
return err
}
}
return nil
}
func (ng *engine) bindRoute(router Router, route Route) error {
// TODO default middleware
chain := newChain()
// build chain
for _, middleware := range ng.middlewares {
chain.Append(convertMiddleware(middleware))
}
return router.Register(route.Matches, route.Handler)
}
func convertMiddleware(ware Middleware) func(Handler) Handler {
return func(next Handler) Handler {
return ware(next.Handle)
}
}
func (ng *engine) serve(router Router) error {
ng.handler = router
if err := ng.bindRoutes(router); err != nil {
return err
}
go func() {
if ng.wsOpt.IsWebsocket {
if len(ng.wsOpt.TLSCertificate) != 0 && len(ng.wsOpt.TLSKey) != 0 {
ng.listenAndServeWSTLS()
} else {
ng.listenAndServeWS()
}
} else {
ng.listenAndServe()
}
}()
go scheduler.Schedule()
sg := make(chan os.Signal)
signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM)
select {
case <-ng.dieChan:
log.Info("Server will shutdown in a few seconds")
case s := <-sg:
log.Infof("Server got signal: %ng", s)
}
log.Info("Server is stopping...")
ng.shutdown()
scheduler.Close()
return nil
}
func (ng *engine) close() {
close(ng.dieChan)
}
func (ng *engine) shutdown() {
}
func (ng *engine) listenAndServe() {
listener, err := net.Listen(ng.conf.Protocol, ng.conf.Addr)
if err != nil {
panic(err)
}
// 监听成功,服务已启动
log.Infof("now listening %s on %s.", ng.conf.Protocol, ng.conf.Addr)
defer func() {
listener.Close()
ng.close()
}()
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Errorf("服务器网络错误 %+v", err)
return
}
log.Errorf("监听错误 %v", err)
continue
}
err = pool.SubmitConn(func() {
ng.handle(conn)
})
if err != nil {
log.Errorf("submit conn pool err: %ng", err.Error())
continue
}
}
}
func (ng *engine) listenAndServeWS() {
ng.setupWS()
if err := http.ListenAndServe(ng.conf.Addr, nil); err != nil {
log.Fatal(err.Error())
}
}
func (ng *engine) listenAndServeWSTLS() {
ng.setupWS()
if err := http.ListenAndServeTLS(ng.conf.Addr, ng.wsOpt.TLSCertificate, ng.wsOpt.TLSKey, nil); err != nil {
log.Fatal(err.Error())
}
}
func (ng *engine) setupWS() {
upgrade := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: ng.wsOpt.CheckOrigin,
}
http.HandleFunc("/"+strings.TrimPrefix(ng.wsOpt.WebsocketPath, "/"), func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrade.Upgrade(w, r, nil)
if err != nil {
log.Errorf("Upgrade failure, URI=%ng, Error=%ng", r.RequestURI, err.Error())
return
}
err = pool.SubmitConn(func() {
ng.handleWS(conn)
})
if err != nil {
log.Fatalf("submit conn pool err: %ng", err.Error())
}
})
}
func (ng *engine) handleWS(conn *websocket.Conn) {
c, err := newWSConn(conn)
if err != nil {
log.Error(err)
return
}
ng.handle(c)
}
func (ng *engine) handle(conn net.Conn) {
connection := newConn(ng, conn)
ng.sessionMgr.StoreSession(connection.Session())
connection.serve()
// hook
}
func (ng *engine) notFoundHandler(next Handler) Handler {
return HandlerFunc(func(conn *Connection, packet packet.IPacket) {
h := next
if next == nil {
h = NotFoundHandler()
}
// TODO write to client
h.Handle(conn, packet)
})
}

@ -0,0 +1,40 @@
package core
type (
LifetimeHandler func(conn *Connection)
lifetime struct {
onOpen []LifetimeHandler
onClosed []LifetimeHandler
}
)
var Lifetime = &lifetime{}
func (lt *lifetime) OnClosed(h LifetimeHandler) {
lt.onClosed = append(lt.onClosed, h)
}
func (lt *lifetime) OnOpen(h LifetimeHandler) {
lt.onOpen = append(lt.onOpen, h)
}
func (lt *lifetime) Open(conn *Connection) {
if len(lt.onOpen) <= 0 {
return
}
for _, handler := range lt.onOpen {
handler(conn)
}
}
func (lt *lifetime) Close(conn *Connection) {
if len(lt.onClosed) <= 0 {
return
}
for _, handler := range lt.onClosed {
handler(conn)
}
}

@ -0,0 +1,49 @@
package core
import (
"errors"
"git.noahlan.cn/northlan/nnet/internal/log"
"git.noahlan.cn/northlan/nnet/internal/packet"
)
type nNetRouter struct {
handlers map[string]Handler
notFound Handler
}
func NewRouter() Router {
return &nNetRouter{
handlers: make(map[string]Handler),
}
}
func (r *nNetRouter) Handle(conn *Connection, p packet.IPacket) {
pkg, ok := p.(*packet.Packet)
if !ok {
log.Error(packet.ErrWrongMessage)
return
}
handler, ok := r.handlers[pkg.Header.Route]
if !ok {
if r.notFound == nil {
log.Error("message handler not found")
return
}
r.notFound.Handle(conn, p)
return
}
handler.Handle(conn, p)
}
func (r *nNetRouter) Register(matches interface{}, handler Handler) error {
route, ok := matches.(string)
if !ok {
return errors.New("the type of matches must be string")
}
r.handlers[route] = handler
return nil
}
func (r *nNetRouter) SetNotFoundHandler(handler Handler) {
r.notFound = handler
}

@ -1,12 +1,11 @@
package pipeline package core
import ( import (
"git.noahlan.cn/northlan/nnet/nface"
"sync" "sync"
) )
type ( type (
Func func(nface.IConnection) error Func func(conn *Connection, v interface{}) error
// Pipeline 消息管道 // Pipeline 消息管道
Pipeline interface { Pipeline interface {
@ -21,7 +20,7 @@ type (
Channel interface { Channel interface {
PushFront(h Func) PushFront(h Func)
PushBack(h Func) PushBack(h Func)
Process(nface.IConnection) error Process(conn *Connection, v interface{}) error
} }
pipelineChannel struct { pipelineChannel struct {
@ -66,7 +65,7 @@ func (p *pipelineChannel) PushBack(h Func) {
} }
// Process 处理所有的pipeline方法 // Process 处理所有的pipeline方法
func (p *pipelineChannel) Process(conn nface.IConnection) error { func (p *pipelineChannel) Process(conn *Connection, v interface{}) error {
p.mu.RLock() p.mu.RLock()
defer p.mu.RUnlock() defer p.mu.RUnlock()
@ -75,7 +74,7 @@ func (p *pipelineChannel) Process(conn nface.IConnection) error {
} }
for _, handler := range p.handlers { for _, handler := range p.handlers {
err := handler(conn) err := handler(conn, v)
if err != nil { if err != nil {
return err return err
} }

@ -0,0 +1,9 @@
package core
import "git.noahlan.cn/northlan/nnet/internal/packet"
type (
Processor interface {
Process(conn *Connection, packet packet.IPacket) error
}
)

@ -0,0 +1,72 @@
package core
import (
"encoding/json"
"errors"
"fmt"
"git.noahlan.cn/northlan/nnet/internal/log"
"git.noahlan.cn/northlan/nnet/internal/packet"
"time"
)
var (
hrd []byte // handshake response data
hbd []byte // heartbeat packet data
)
type NNetProcessor struct {
}
func NewNNetProcessor() *NNetProcessor {
// TODO custom hrd hbd
data, _ := json.Marshal(map[string]interface{}{
"code": 200,
"sys": map[string]float64{"heartbeat": time.Second.Seconds()},
})
packer := packet.NewNNetPacker()
hrd, _ = packer.Pack(packet.Handshake, data)
return &NNetProcessor{}
}
func (n *NNetProcessor) Process(conn *Connection, p packet.IPacket) error {
h, ok := p.(*packet.Packet)
if !ok {
return packet.ErrWrongPacketType
}
switch h.PacketType {
case packet.Handshake:
// TODO validate handshake
if err := conn.SendBytes(hrd); err != nil {
return err
}
conn.SetStatus(StatusPrepare)
log.Debugf("connection handshake Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr())
case packet.HandshakeAck:
conn.SetStatus(StatusPending)
log.Debugf("receive handshake ACK Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr())
case packet.Heartbeat:
// Expected
case packet.Data:
if conn.Status() < StatusPending {
return errors.New(fmt.Sprintf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s",
conn.Conn().RemoteAddr()))
}
conn.SetStatus(StatusWorking)
var lastMid uint64
switch h.MsgType {
case packet.Request:
lastMid = h.ID
case packet.Notify:
lastMid = 0
default:
return fmt.Errorf("Invalid message type: %s ", h.MsgType.String())
}
conn.SetLastMID(lastMid)
}
conn.SetLastHeartbeatAt(time.Now().Unix())
return nil
}

@ -0,0 +1,54 @@
package core
type (
Middleware func(next HandlerFunc) HandlerFunc
Route struct {
Matches interface{} // 用于匹配的关键字段
Handler HandlerFunc // 处理方法
}
Router interface {
Handler
Register(matches interface{}, handler Handler) error
SetNotFoundHandler(handler Handler)
}
Constructor func(Handler) Handler
)
type Chain struct {
constructors []Constructor
}
func newChain(constructors ...Constructor) Chain {
return Chain{append(([]Constructor)(nil), constructors...)}
}
func (c Chain) Then(h Handler) Handler {
// TODO nil
for i := range c.constructors {
h = c.constructors[len(c.constructors)-1-i](h)
}
return h
}
func (c Chain) ThenFunc(fn HandlerFunc) Handler {
if fn == nil {
return c.Then(nil)
}
return c.Then(fn)
}
func (c Chain) Append(constructors ...Constructor) Chain {
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
newCons = append(newCons, c.constructors...)
newCons = append(newCons, constructors...)
return Chain{newCons}
}
func (c Chain) Extend(chain Chain) Chain {
return c.Append(chain.constructors...)
}

@ -0,0 +1,191 @@
package core
import (
"git.noahlan.cn/northlan/nnet/internal/env"
"git.noahlan.cn/northlan/nnet/internal/log"
"git.noahlan.cn/northlan/nnet/internal/packet"
"git.noahlan.cn/northlan/nnet/serialize"
"net/http"
"time"
)
type (
// RunOption defines the method to customize a Server.
RunOption func(*Server)
Server struct {
ngin *engine
router Router
}
)
// NewServer returns a server with given config of c and options defined in opts.
// Be aware that later RunOption might overwrite previous one that write the same option.
func NewServer(c EngineConf, opts ...RunOption) *Server {
s := &Server{
ngin: newEngine(c),
router: NewRouter(),
}
opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...)
for _, opt := range opts {
opt(s)
}
return s
}
// AddRoutes add given routes into the Server.
func (s *Server) AddRoutes(rs []Route) {
s.ngin.addRoutes(rs...)
}
// AddRoute adds given route into the Server.
func (s *Server) AddRoute(r Route) {
s.AddRoutes([]Route{r})
}
// Start starts the Server.
// Graceful shutdown is enabled by default.
func (s *Server) Start() {
if err := s.ngin.serve(s.router); err != nil {
log.Error(err)
panic(err)
}
}
// Stop stops the Server.
func (s *Server) Stop() {
s.ngin.close()
}
// Use adds the given middleware in the Server.
func (s *Server) Use(middleware Middleware) {
s.ngin.use(middleware)
}
// ToMiddleware converts the given handler to a Middleware.
func ToMiddleware(handler func(next Handler) Handler) Middleware {
return func(next HandlerFunc) HandlerFunc {
return handler(next).Handle
}
}
// WithMiddlewares adds given middlewares to given routes.
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
for i := len(ms) - 1; i >= 0; i-- {
rs = WithMiddleware(ms[i], rs...)
}
return rs
}
// WithMiddleware adds given middleware to given route.
func WithMiddleware(middleware Middleware, rs ...Route) []Route {
routes := make([]Route, len(rs))
for i := range rs {
route := rs[i]
routes[i] = Route{
Matches: route.Matches,
Handler: middleware(route.Handler),
}
}
return routes
}
// WithNotFoundHandler returns a RunOption with not found handler set to given handler.
func WithNotFoundHandler(handler Handler) RunOption {
return func(server *Server) {
notFoundHandler := server.ngin.notFoundHandler(handler)
server.router.SetNotFoundHandler(notFoundHandler)
}
}
func WithRouter(router Router) RunOption {
return func(server *Server) {
server.router = router
}
}
func WithPacker(packer packet.Packer) RunOption {
return func(server *Server) {
server.ngin.packer = packer
}
}
func WithProcessor(p Processor) RunOption {
return func(server *Server) {
server.ngin.processor = p
}
}
func WithSerializer(s serialize.Serializer) RunOption {
return func(server *Server) {
server.ngin.serializer = s
}
}
func WithLogger(logger log.Logger) RunOption {
return func(_ *Server) {
log.SetLogger(logger)
}
}
// WithTimerPrecision 设置Timer精度
// 注精度需大于1ms, 并且不能在运行时更改
// 默认精度是 time.Second
func WithTimerPrecision(precision time.Duration) RunOption {
if precision < time.Millisecond {
panic("time precision can not less than a Millisecond")
}
return func(_ *Server) {
env.TimerPrecision = precision
}
}
func WithPipeline(pipeline Pipeline) RunOption {
return func(server *Server) {
server.ngin.pipeline = pipeline
}
}
func WithHeartbeatInterval(d time.Duration) RunOption {
return func(server *Server) {
server.ngin.heartbeatInterval = d
}
}
type WSOption func(opts *wsOptions)
// WithWebsocket 开启Websocket, 参数是websocket的相关参数 nnet.WSOption
func WithWebsocket(wsOpts ...WSOption) RunOption {
return func(server *Server) {
for _, opt := range wsOpts {
opt(&server.ngin.wsOpt)
}
server.ngin.wsOpt.IsWebsocket = true
}
}
// WithWSPath 设置websocket的path
func WithWSPath(path string) WSOption {
return func(opts *wsOptions) {
opts.WebsocketPath = path
}
}
// WithWSTLSConfig 设置websocket的证书和密钥
func WithWSTLSConfig(certificate, key string) WSOption {
return func(opts *wsOptions) {
opts.TLSCertificate = certificate
opts.TLSKey = key
}
}
func WithWSCheckOriginFunc(fn func(*http.Request) bool) WSOption {
return func(opts *wsOptions) {
if fn != nil {
opts.CheckOrigin = fn
}
}
}

@ -0,0 +1,51 @@
package core
import (
"fmt"
"git.noahlan.cn/northlan/nnet/internal/log"
"git.noahlan.cn/northlan/nnet/internal/packet"
"testing"
"time"
)
func TestServer(t *testing.T) {
server := NewServer(EngineConf{
Protocol: "tcp",
Addr: ":12345",
Name: "N-Net",
Mode: DevMode,
},
WithPacker(packet.NewNNetPacker()),
WithSerializer(nil),
WithHeartbeatInterval(time.Hour),
WithProcessor(NewNNetProcessor()),
)
server.AddRoute(Route{
Matches: "test",
Handler: func(conn *Connection, pkg packet.IPacket) {
fmt.Println(pkg)
p, ok := pkg.(*packet.Packet)
if !ok {
log.Error("wrong packet type")
return
}
bd := []byte("服务器接收到数据为: " + string(p.GetBody()))
// 注Response类型数据不需要Route原地返回客户端需等待
conn.Send(packet.Header{
PacketType: packet.Data,
Length: uint32(len(bd)),
MessageHeader: packet.MessageHeader{
MsgType: packet.Response,
ID: p.ID,
Route: p.Route,
},
}, bd)
},
})
defer server.Stop()
server.Start()
}

@ -0,0 +1,15 @@
package core
import "git.noahlan.cn/northlan/nnet/internal/packet"
type (
Handler interface {
Handle(conn *Connection, pkg packet.IPacket)
}
HandlerFunc func(conn *Connection, pkg packet.IPacket)
)
func (f HandlerFunc) Handle(conn *Connection, pkg packet.IPacket) {
f(conn, pkg)
}

@ -1,4 +1,4 @@
package nnet package core
import ( import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -64,12 +64,12 @@ func (c *wsConn) Close() error {
return c.conn.Close() return c.conn.Close()
} }
// LocalAddr returns the local network address. // LocalAddr returns the local network Addr.
func (c *wsConn) LocalAddr() net.Addr { func (c *wsConn) LocalAddr() net.Addr {
return c.conn.LocalAddr() return c.conn.LocalAddr()
} }
// RemoteAddr returns the remote network address. // RemoteAddr returns the remote network Addr.
func (c *wsConn) RemoteAddr() net.Addr { func (c *wsConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr() return c.conn.RemoteAddr()
} }

@ -56,33 +56,33 @@ func newInnerLogger() Logger {
} }
func (i *innerLogger) Debugf(format string, v ...interface{}) { func (i *innerLogger) Debugf(format string, v ...interface{}) {
i.log.Printf(format+"\n", v) i.log.Printf(format+"\n", v...)
} }
func (i *innerLogger) Debug(v ...interface{}) { func (i *innerLogger) Debug(v ...interface{}) {
i.log.Println(v) i.log.Println(v...)
} }
func (i *innerLogger) Info(v ...interface{}) { func (i *innerLogger) Info(v ...interface{}) {
i.log.Println(v) i.log.Println(v...)
} }
func (i *innerLogger) Infof(format string, v ...interface{}) { func (i *innerLogger) Infof(format string, v ...interface{}) {
i.log.Printf(format+"\n", v) i.log.Printf(format+"\n", v...)
} }
func (i *innerLogger) Error(v ...interface{}) { func (i *innerLogger) Error(v ...interface{}) {
i.log.Println(v) i.log.Println(v...)
} }
func (i *innerLogger) Errorf(format string, v ...interface{}) { func (i *innerLogger) Errorf(format string, v ...interface{}) {
i.log.Printf(format+"\n", v) i.log.Printf(format+"\n", v...)
} }
func (i *innerLogger) Fatal(v ...interface{}) { func (i *innerLogger) Fatal(v ...interface{}) {
i.log.Fatal(v) i.log.Fatal(v...)
} }
func (i *innerLogger) Fatalf(format string, v ...interface{}) { func (i *innerLogger) Fatalf(format string, v ...interface{}) {
i.log.Fatalf(format+"\n", v) i.log.Fatalf(format+"\n", v...)
} }

@ -0,0 +1,247 @@
package packet
import (
"bytes"
"encoding/binary"
"errors"
)
type NNetPacker struct {
buf *bytes.Buffer
size int // 最近一次 length
typ byte // 最近一次 packet type
flag byte // 最近一次 flag
}
// packer constants.
const (
headLength = 5
maxPacketSize = 64 * 1024
msgRouteCompressMask = 0x01 // 0000 0001 last bit
msgTypeMask = 0x07 // 0000 0111 1-3 bit (需要>>)
msgRouteLengthMask = 0xFF // 1111 1111 last 8 bit
msgHeadLength = 0x02 // 0000 0010 2 bit
)
var (
ErrPacketSizeExceed = errors.New("packer: packet size exceed")
ErrWrongMessageType = errors.New("wrong message type")
ErrRouteInfoNotFound = errors.New("route info not found in dictionary")
ErrWrongMessage = errors.New("wrong message")
// ErrWrongPacketType represents a wrong packet type.
ErrWrongPacketType = errors.New("wrong packet type")
)
var (
routes = make(map[string]uint16) // route map to code
codes = make(map[uint16]string) // code map to route
)
func NewNNetPacker() *NNetPacker {
p := &NNetPacker{
buf: bytes.NewBuffer(nil),
}
p.resetFlags()
return p
}
func (d *NNetPacker) resetFlags() {
d.size = -1
d.typ = byte(Unknown)
d.flag = 0x00
}
func (d *NNetPacker) routable(t MsgType) bool {
return t == Request || t == Notify || t == Push
}
func (d *NNetPacker) invalidType(t MsgType) bool {
return t < Request || t > Push
}
func (d *NNetPacker) Pack(header interface{}, data []byte) ([]byte, error) {
h, ok := header.(Header)
if !ok {
return nil, ErrWrongPacketType
}
typ := h.PacketType
if typ < Handshake || typ > Kick {
return nil, ErrWrongPacketType
}
if d.invalidType(h.MsgType) {
return nil, ErrWrongMessageType
}
buf := make([]byte, 0)
// packet type
buf = append(buf, byte(h.PacketType))
// length
buf = append(buf, d.intToBytes(uint32(len(data)))...)
// flag
flag := byte(h.MsgType << 1) // 编译器提示,此处 byte 转换不能删
code, compressed := routes[h.Route]
if compressed {
flag |= msgRouteCompressMask
}
buf = append(buf, flag)
// msg id
if h.MsgType == Request || h.MsgType == Response {
n := h.ID
// variant length encode
for {
b := byte(n % 128)
n >>= 7
if n != 0 {
buf = append(buf, b+128)
} else {
buf = append(buf, b)
break
}
}
}
// route
if d.routable(h.MsgType) {
if compressed {
buf = append(buf, byte((code>>8)&0xFF))
buf = append(buf, byte(code&0xFF))
} else {
buf = append(buf, byte(len(h.Route)))
buf = append(buf, []byte(h.Route)...)
}
}
// body
buf = append(buf, data...)
return buf, nil
}
// Encode packet data length to bytes(Big end)
func (d *NNetPacker) intToBytes(n uint32) []byte {
buf := make([]byte, 3)
buf[0] = byte((n >> 16) & 0xFF)
buf[1] = byte((n >> 8) & 0xFF)
buf[2] = byte(n & 0xFF)
return buf
}
func (d *NNetPacker) Unpack(data []byte) ([]IPacket, error) {
d.buf.Write(data) // copy
var (
packets []IPacket
err error
)
// 检查包长度
if d.buf.Len() < headLength {
return nil, err
}
// 第一次拆包
if d.size < 0 {
if err = d.readHeader(); err != nil {
return nil, err
}
}
for d.size <= d.buf.Len() {
// 读取
p := newPacket(Type(d.typ))
p.MsgType = MsgType((d.flag >> 1) & msgTypeMask)
if d.invalidType(p.MsgType) {
return nil, ErrWrongMessageType
}
if p.MsgType == Request || p.MsgType == Response {
id := uint64(0)
// little end byte order
// WARNING: must be stored in 64 bits integer
// variant length encode
c := 0
for {
b, err := d.buf.ReadByte()
if err != nil {
break
}
id += uint64(b&0x7F) << uint64(7*c)
if b < 128 {
break
}
c++
}
p.ID = id
}
if d.routable(p.MsgType) {
if d.flag&msgRouteCompressMask == 1 {
p.compressed = true
code := binary.BigEndian.Uint16(d.buf.Next(2))
route, ok := codes[code]
if !ok {
return nil, ErrRouteInfoNotFound
}
p.Route = route
} else {
p.compressed = false
rl, _ := d.buf.ReadByte()
if int(rl) > d.buf.Len() {
return nil, ErrWrongMessage
}
p.Route = string(d.buf.Next(int(rl)))
}
}
p.Length = uint32(d.size)
p.Data = d.buf.Next(d.size)
packets = append(packets, p)
// 剩余数据不满足至少一个数据帧,重置数据帧长度
// 数据缓存内存 保留至 下一次进入本方法以继续拆包
if d.buf.Len() < headLength {
d.resetFlags()
break
}
// 读取下一个包 next
if err = d.readHeader(); err != nil {
return packets, err
}
}
if packets == nil || len(packets) <= 0 {
d.resetFlags()
d.buf.Reset()
}
return packets, nil
}
func (d *NNetPacker) readHeader() error {
header := d.buf.Next(headLength)
d.typ = header[0]
if d.typ < Handshake || d.typ > Kick {
return ErrWrongPacketType
}
d.size = d.bytesToInt(header[1 : len(header)-1])
d.flag = header[len(header)-1]
// 最大包限定
if d.size > maxPacketSize {
return ErrPacketSizeExceed
}
return nil
}
// Decode packet data length byte to int(Big end)
func (d *NNetPacker) bytesToInt(b []byte) int {
result := 0
for _, v := range b {
result = result<<8 + int(v)
}
return result
}

@ -0,0 +1,98 @@
package packet
import (
"encoding/hex"
"fmt"
"testing"
)
func TestPacker(t *testing.T) {
p := NewNNetPacker()
body := []byte("")
header := Header{
PacketType: Handshake,
Length: uint32(len(body)),
MessageHeader: MessageHeader{
MsgType: Request,
ID: 1,
Route: "",
compressed: false,
},
}
pack, err := p.Pack(header, body)
if err != nil {
return
}
fmt.Println(hex.EncodeToString(pack))
// handshake ack
body = []byte("")
header = Header{
PacketType: HandshakeAck,
Length: uint32(len(body)),
MessageHeader: MessageHeader{
MsgType: Response,
ID: 1,
Route: "",
compressed: false,
},
}
pack, err = p.Pack(header, body)
if err != nil {
return
}
fmt.Println(hex.EncodeToString(pack))
// data
body = []byte("123")
header = Header{
PacketType: Data,
Length: uint32(len(body)),
MessageHeader: MessageHeader{
MsgType: Request,
ID: 2,
Route: "",
compressed: false,
},
}
pack, err = p.Pack(header, body)
if err != nil {
return
}
fmt.Println(hex.EncodeToString(pack))
// data -> route: test -> handler
// data
body = []byte("ni hao")
header = Header{
PacketType: Data,
Length: uint32(len(body)),
MessageHeader: MessageHeader{
MsgType: Request,
ID: 3,
Route: "test",
compressed: false,
},
}
pack, err = p.Pack(header, body)
if err != nil {
return
}
fmt.Println(hex.EncodeToString(pack))
}
func TestUnPack(t *testing.T) {
data := []byte{0x04, 0x00, 0x00, 0x23, 0x04, 0x03, 0xE6, 0x9C, 0x8D, 0xE5, 0x8A, 0xA1, 0xE5, 0x99, 0xA8, 0xE6, 0x8E, 0xA5, 0xE6, 0x94, 0xB6, 0xE5, 0x88, 0xB0, 0xE6, 0x95, 0xB0, 0xE6, 0x8D, 0xAE, 0xE4, 0xB8, 0xBA, 0x3A, 0x20, 0x6E, 0x69, 0x20, 0x68, 0x61, 0x6F}
p := NewNNetPacker()
unPacked, err := p.Unpack(data)
if err != nil {
panic(err)
}
fmt.Println(unPacked)
}

@ -0,0 +1,19 @@
package packet
type (
// IPacket 数据帧
IPacket interface {
GetHeader() interface{} // 数据帧头部 Header
GetLen() uint32 // 数据帧长度 4bytes - 32bit 占位,根据实际情况进行转换
GetBody() []byte // 数据 Body
}
// Packer 数据帧 封包/解包
Packer interface {
// Pack 封包,将原始数据构造为二进制流数据帧
Pack(header interface{}, data []byte) ([]byte, error)
// Unpack 解包
Unpack(data []byte) ([]IPacket, error)
}
)

@ -0,0 +1,94 @@
package packet
import (
"encoding/hex"
"fmt"
)
// Type 数据帧类型,如:握手,心跳,数据 等
type Type byte
const (
// Unknown 未知包类型,无意义
Unknown Type = iota
// Handshake 握手数据(服务端主动发起)
Handshake = 0x01
// HandshakeAck 握手回复(客户端回复)
HandshakeAck = 0x02
// Heartbeat 心跳(服务端发起)
Heartbeat = 0x03
// Data 数据传输
Data = 0x04
// Kick 服务端主动断开连接
Kick = 0x05
)
type MsgType byte
// Message types
const (
Request MsgType = 0x00
Notify = 0x01
Response = 0x02
Push = 0x03
)
var msgTypes = map[MsgType]string{
Request: "Request",
Notify: "Notify",
Response: "Response",
Push: "Push",
}
func (t MsgType) String() string {
return msgTypes[t]
}
type (
Header struct {
PacketType Type // 数据帧 类型
Length uint32 // 数据长度
MessageHeader // 消息头
}
MessageHeader struct {
MsgType MsgType // message type (flag)
ID uint64 // unique id, zero while notify mode
Route string // route for locating service
compressed bool // if message compressed
}
Packet struct {
Header
Data []byte // 原始数据
}
)
func newPacket(typ Type) *Packet {
return &Packet{
Header: Header{
PacketType: typ,
MessageHeader: MessageHeader{},
},
}
}
func (p *Packet) GetHeader() interface{} {
return p.Header
}
func (p *Packet) GetLen() uint32 {
return p.Length
}
func (p *Packet) GetBody() []byte {
return p.Data
}
func (p *Packet) String() string {
return fmt.Sprintf("Packet[Type: %d, Len: %d] Message[Type: %s, ID: %d, Route: %s, Compressed: %v] BodyStr: [%s], BodyHex: [%s]",
p.PacketType, p.Length, p.MsgType, p.ID, p.Route, p.compressed, string(p.Data), hex.EncodeToString(p.Data))
}

@ -0,0 +1,27 @@
package pool
import "github.com/panjf2000/ants/v2"
var _pool *pool
type pool struct {
connPool *ants.Pool
workerPool *ants.Pool
}
func InitPool(size int) {
p := &pool{}
p.connPool, _ = ants.NewPool(size, ants.WithNonblocking(true))
p.workerPool, _ = ants.NewPool(size*2, ants.WithNonblocking(true))
_pool = p
}
func SubmitConn(h func()) error {
return _pool.connPool.Submit(h)
}
func SubmitWorker(h func()) error {
return _pool.workerPool.Submit(h)
}

@ -1,146 +0,0 @@
package message
import (
"encoding/binary"
"errors"
)
var _ Codec = (*NNetCodec)(nil)
const (
msgRouteCompressMask = 0x01 // 0000 0001 last bit
msgTypeMask = 0x07 // 0000 0111 1-3 bit (需要>>)
msgRouteLengthMask = 0xFF // 1111 1111 last 8 bit
msgHeadLength = 0x02 // 0000 0010 2 bit
)
// Errors that could be occurred in message codec
var (
ErrWrongMessageType = errors.New("wrong message type")
ErrInvalidMessage = errors.New("invalid message")
ErrRouteInfoNotFound = errors.New("route info not found in dictionary")
ErrWrongMessage = errors.New("wrong message")
)
var (
routes = make(map[string]uint16) // route map to code
codes = make(map[uint16]string) // code map to route
)
type NNetCodec struct{}
func (n *NNetCodec) routable(t Type) bool {
return t == Request || t == Notify || t == Push
}
func (n *NNetCodec) invalidType(t Type) bool {
return t < Request || t > Push
}
func (n *NNetCodec) Encode(v interface{}) ([]byte, error) {
m, ok := v.(*Message)
if !ok {
return nil, ErrWrongMessageType
}
if n.invalidType(m.Type) {
return nil, ErrWrongMessageType
}
buf := make([]byte, 0)
flag := byte(m.Type << 1) // 编译器提示,此处 byte 转换不能删
code, compressed := routes[m.Route]
if compressed {
flag |= msgRouteCompressMask
}
buf = append(buf, flag)
if m.Type == Request || m.Type == Response {
n := m.ID
// variant length encode
for {
b := byte(n % 128)
n >>= 7
if n != 0 {
buf = append(buf, b+128)
} else {
buf = append(buf, b)
break
}
}
}
if n.routable(m.Type) {
if compressed {
buf = append(buf, byte((code>>8)&0xFF))
buf = append(buf, byte(code&0xFF))
} else {
buf = append(buf, byte(len(m.Route)))
buf = append(buf, []byte(m.Route)...)
}
}
buf = append(buf, m.Data...)
return buf, nil
}
func (n *NNetCodec) Decode(data []byte) (interface{}, error) {
if len(data) < msgHeadLength {
return nil, ErrInvalidMessage
}
m := New()
flag := data[0]
offset := 1
m.Type = Type((flag >> 1) & msgTypeMask) // 编译器提示,此处Type转换不能删
if n.invalidType(m.Type) {
return nil, ErrWrongMessageType
}
if m.Type == Request || m.Type == Response {
id := uint64(0)
// little end byte order
// WARNING: must can be stored in 64 bits integer
// variant length encode
for i := offset; i < len(data); i++ {
b := data[i]
id += uint64(b&0x7F) << uint64(7*(i-offset))
if b < 128 {
offset = i + 1
break
}
}
m.ID = id
}
if offset >= len(data) {
return nil, ErrWrongMessage
}
if n.routable(m.Type) {
if flag&msgRouteCompressMask == 1 {
m.compressed = true
code := binary.BigEndian.Uint16(data[offset:(offset + 2)])
route, ok := codes[code]
if !ok {
return nil, ErrRouteInfoNotFound
}
m.Route = route
offset += 2
} else {
m.compressed = false
rl := data[offset]
offset++
if offset+int(rl) > len(data) {
return nil, ErrWrongMessage
}
m.Route = string(data[offset:(offset + int(rl))])
offset += int(rl)
}
}
if offset > len(data) {
return nil, ErrWrongMessage
}
m.Data = data[offset:]
return m, nil
}

@ -1,11 +0,0 @@
package message
type (
// Codec 消息编解码器
Codec interface {
// Encode 编码
Encode(v interface{}) ([]byte, error)
// Decode 解码
Decode(data []byte) (interface{}, error)
}
)

@ -1,46 +0,0 @@
package message
import (
"fmt"
)
// Type represents the type of message, which could be Request/Notify/Response/Push
type Type byte
// Message types
const (
Request Type = 0x00
Notify = 0x01
Response = 0x02
Push = 0x03
)
var types = map[Type]string{
Request: "Request",
Notify: "Notify",
Response: "Response",
Push: "Push",
}
func (t Type) String() string {
return types[t]
}
// Message represents an unmarshaler message or a message which to be marshaled
type Message struct {
Type Type // message type (flag)
ID uint64 // unique id, zero while notify mode
Route string // route for locating service
Data []byte // payload
compressed bool // if message compressed
}
// New returns a new message instance
func New() *Message {
return &Message{}
}
// String, implementation of fmt.Stringer interface
func (m *Message) String() string {
return fmt.Sprintf("%s %s (%dbytes)", types[m.Type], m.Route, len(m.Data))
}

@ -1,31 +0,0 @@
package nface
import "net"
const (
// StatusStart 开始阶段
StatusStart int32 = iota + 1
// StatusPrepare 准备阶段
StatusPrepare
// StatusWorking 工作阶段
StatusWorking
// StatusClosed 连接关闭
StatusClosed
)
type IConnection interface {
// Server 获取Server实例
Server() IServer
// Status 获取连接状态
Status() int32
// SetStatus 设置连接状态
SetStatus(s int32)
// Conn 获取底层网络连接
Conn() net.Conn
// ID 获取连接ID
ID() int64
// Session 获取当前连接绑定的Session
Session() ISession
// Close 关闭连接
Close() error
}

@ -1,5 +0,0 @@
package nface
// IRouter 路由接口
type IRouter interface {
}

@ -1,4 +0,0 @@
package nface
type IServer interface {
}

@ -1,31 +0,0 @@
package nface
// ISessionAttribute Session数据接口
type ISessionAttribute interface {
// Attribute 获取指定key的值
Attribute(key string) interface{}
// Keys 获取所有Key
Keys() []string
// Exists key是否存在
Exists(key string) bool
// Attributes 获取所有数据
Attributes() map[string]interface{}
// RemoveAttribute 通过key移除数据
RemoveAttribute(key string)
// SetAttribute 设置数据
SetAttribute(key string, value interface{})
// Invalidate 使当前Session无效并且解除所有与之绑定的对象
Invalidate()
}
// ISession session接口
type ISession interface {
// ID Session ID
ID() int64
// UID 用户自行绑定UID,默认与SessionID一致
UID() string
// Bind 绑定用户ID
Bind(uid string)
// ISessionAttribute Session数据抽象方法
ISessionAttribute
}

@ -1,187 +0,0 @@
package nnet
import (
"errors"
"git.noahlan.cn/northlan/nnet/log"
"git.noahlan.cn/northlan/nnet/nface"
"git.noahlan.cn/northlan/nnet/packet"
"git.noahlan.cn/northlan/nnet/pipeline"
"git.noahlan.cn/northlan/nnet/session"
"net"
"sync/atomic"
"time"
)
var (
_ nface.IConnection = (*Connection)(nil)
ErrCloseClosedSession = errors.New("close closed session")
// ErrBrokenPipe represents the low-level connection has broken.
ErrBrokenPipe = errors.New("broken low-level pipe")
// ErrBufferExceed indicates that the current session buffer is full and
// can not receive more data.
ErrBufferExceed = errors.New("session send buffer exceed")
)
type (
Connection struct {
session nface.ISession // Session
server *Server // Server 引用
conn net.Conn // low-level conn fd
status int32 // 连接状态
lastMid uint64 // 最近一次消息ID
lastHeartbeatAt int64 // 最近一次心跳时间
chDie chan struct{} // 停止通道
chSend chan pendingMessage // 消息发送通道
pipeline pipeline.Pipeline // 消息管道
}
pendingMessage struct {
typ interface{} // message type
route string // message route
mid uint64 // response message id
payload interface{} // payload
}
)
func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) *Connection {
r := &Connection{
conn: conn,
server: server,
status: nface.StatusStart,
lastHeartbeatAt: time.Now().Unix(),
chDie: make(chan struct{}),
chSend: make(chan pendingMessage, 512),
pipeline: pipeline,
}
// binding session
r.session = session.New()
return r
}
func (r *Connection) send(m pendingMessage) (err error) {
defer func() {
if e := recover(); e != nil {
err = ErrBrokenPipe
}
}()
r.chSend <- m
return err
}
func (r *Connection) Server() nface.IServer {
return r.server
}
func (r *Connection) Status() int32 {
return atomic.LoadInt32(&r.status)
}
func (r *Connection) SetStatus(s int32) {
atomic.StoreInt32(&r.status, s)
}
func (r *Connection) Conn() net.Conn {
return r.conn
}
func (r *Connection) ID() int64 {
return r.session.ID()
}
func (r *Connection) setLastHeartbeatAt(t int64) {
atomic.StoreInt64(&r.lastHeartbeatAt, t)
}
func (r *Connection) Session() nface.ISession {
return r.session
}
func (r *Connection) write() {
ticker := time.NewTicker(r.server.HeartbeatInterval)
chWrite := make(chan []byte, 1024)
defer func() {
ticker.Stop()
close(r.chSend)
close(chWrite)
_ = r.Close()
log.Debugf("Connection write goroutine exit, ConnID=%d, SessionUID=%d", r.ID(), r.session.UID())
}()
for {
select {
case <-ticker.C:
// TODO heartbeat enable control
deadline := time.Now().Add(-2 * r.server.HeartbeatInterval).Unix()
if atomic.LoadInt64(&r.lastHeartbeatAt) < deadline {
log.Debugf("Session heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&r.lastHeartbeatAt), deadline)
return
}
// TODO heartbeat data
chWrite <- []byte{}
case data := <-r.chSend:
// marshal packet body (data)
payload, err := r.server.Serializer.Marshal(data.payload)
if err != nil {
log.Errorf("message body marshal err: %v", err)
break
}
// TODO new message and pipeline
if pipe := r.pipeline; pipe != nil {
err := pipe.Outbound().Process(r)
if err != nil {
log.Errorf("broken pipeline err: %s", err.Error())
break
}
}
// TODO encode message ? message processor ?
// packet pack data
p, err := r.server.Packer.Pack(packet.Data, payload)
if err != nil {
log.Error(err.Error())
break
}
chWrite <- p
case data := <-chWrite:
// 回写数据
if _, err := r.conn.Write(data); err != nil {
log.Error(err.Error())
return
}
case <-r.chDie: // connection close signal
return
case <-r.server.DieChan: // application quit signal
return
}
}
}
func (r *Connection) Close() error {
if r.Status() == nface.StatusClosed {
return ErrCloseClosedSession
}
r.SetStatus(nface.StatusClosed)
log.Debugf("close connection, ID: %d", r.ID())
select {
case <-r.chDie:
default:
close(r.chDie)
// TODO lifetime
}
return r.conn.Close()
}

@ -1,113 +0,0 @@
package nnet
import (
"fmt"
"git.noahlan.cn/northlan/nnet/component"
"git.noahlan.cn/northlan/nnet/log"
"git.noahlan.cn/northlan/nnet/packet"
"git.noahlan.cn/northlan/nnet/pipeline"
"github.com/gorilla/websocket"
"net"
"time"
)
type Handler struct {
server *Server // Server 引用
pipeline pipeline.Pipeline // 通道
processor packet.Processor // 数据包处理器
allServices map[string]*component.Service // 所有注册的Service
allHandlers map[string]*component.Handler // 所有注册的Handler
}
func NewHandler(server *Server, pipeline pipeline.Pipeline, processor packet.Processor) *Handler {
return &Handler{
server: server,
pipeline: pipeline,
processor: processor,
allServices: make(map[string]*component.Service),
allHandlers: make(map[string]*component.Handler),
}
}
func (h *Handler) register(comp component.Component, opts []component.Option) error {
s := component.NewService(comp, opts)
if _, ok := h.allServices[s.Name]; ok {
return fmt.Errorf("handler: service already defined: %s", s.Name)
}
if err := s.ExtractHandler(); err != nil {
return err
}
h.allServices[s.Name] = s
// 拷贝一份所有handlers
for name, handler := range s.Handlers {
handleName := fmt.Sprintf("%s.%s", s.Name, name)
log.Debugf("register handler %s", handleName)
h.allHandlers[handleName] = handler
}
return nil
}
func (h *Handler) handleWS(conn *websocket.Conn) {
c, err := newWSConn(conn)
if err != nil {
log.Error(err)
return
}
h.handle(c)
}
func (h *Handler) handle(conn net.Conn) {
connection := newConnection(h.server, conn, h.pipeline)
h.server.sessionMgr.StoreSession(connection.Session())
_ = pool.SubmitConn(func() {
h.writeLoop(connection)
})
_ = pool.SubmitWorker(func() {
h.readLoop(connection)
})
// hook
}
func (h *Handler) writeLoop(conn *Connection) {
conn.write()
}
func (h *Handler) readLoop(conn *Connection) {
buf := make([]byte, 4096)
for {
n, err := conn.Conn().Read(buf)
if err != nil {
log.Errorf("Read message error: %s, session will be closed immediately", err.Error())
return
}
// warning: 为性能考虑复用slice处理数据buf传入后必须要copy再处理
packets, err := h.server.Packer.Unpack(buf[:n])
if err != nil {
log.Error(err.Error())
}
// packets 处理
for _, p := range packets {
if err := h.processPackets(conn, p); err != nil {
log.Error(err.Error())
return
}
}
}
}
func (h *Handler) processPackets(conn *Connection, packets interface{}) error {
err := h.processor.ProcessPacket(conn, packets)
conn.setLastHeartbeatAt(time.Now().Unix())
return err
}

@ -1,82 +0,0 @@
package nnet
import (
"git.noahlan.cn/northlan/nnet/component"
"git.noahlan.cn/northlan/nnet/env"
"git.noahlan.cn/northlan/nnet/log"
"git.noahlan.cn/northlan/nnet/pipeline"
"net/http"
"time"
)
type Option func(options *Options)
type WSOption func(opts *WSOptions)
func WithLogger(logger log.Logger) Option {
return func(_ *Options) {
log.SetLogger(logger)
}
}
func WithPipeline(pipeline pipeline.Pipeline) Option {
return func(options *Options) {
options.Pipeline = pipeline
}
}
func WithComponents(components *component.Components) Option {
return func(options *Options) {
options.Components = components
}
}
func WithHeartbeatInterval(d time.Duration) Option {
return func(options *Options) {
options.HeartbeatInterval = d
}
}
// WithTimerPrecision 设置Timer精度
// 注精度需大于1ms, 并且不能在运行时更改
// 默认精度是 time.Second
func WithTimerPrecision(precision time.Duration) Option {
if precision < time.Millisecond {
panic("time precision can not less than a Millisecond")
}
return func(_ *Options) {
env.TimerPrecision = precision
}
}
// WithWebsocket 开启Websocket, 参数是websocket的相关参数 nnet.WSOption
func WithWebsocket(wsOpts ...WSOption) Option {
return func(options *Options) {
for _, opt := range wsOpts {
opt(&options.WS)
}
options.WS.IsWebsocket = true
}
}
// WithWSPath 设置websocket的path
func WithWSPath(path string) WSOption {
return func(opts *WSOptions) {
opts.WebsocketPath = path
}
}
// WithWSTLSConfig 设置websocket的证书和密钥
func WithWSTLSConfig(certificate, key string) WSOption {
return func(opts *WSOptions) {
opts.TLSCertificate = certificate
opts.TLSKey = key
}
}
func WithWSCheckOriginFunc(fn func(*http.Request) bool) WSOption {
return func(opts *WSOptions) {
if fn != nil {
opts.CheckOrigin = fn
}
}
}

@ -1,27 +0,0 @@
package nnet
import "github.com/panjf2000/ants/v2"
var pool *Pool
type Pool struct {
connPool *ants.Pool
workerPool *ants.Pool
}
func initPool(size int) {
p := &Pool{}
p.connPool, _ = ants.NewPool(size, ants.WithNonblocking(true))
p.workerPool, _ = ants.NewPool(size*2, ants.WithNonblocking(true))
pool = p
}
func (p *Pool) SubmitConn(h func()) error {
return p.connPool.Submit(h)
}
func (p *Pool) SubmitWorker(h func()) error {
return p.workerPool.Submit(h)
}

@ -1,230 +0,0 @@
package nnet
import (
"errors"
"fmt"
"git.noahlan.cn/northlan/nnet/component"
"git.noahlan.cn/northlan/nnet/log"
"git.noahlan.cn/northlan/nnet/packet"
"git.noahlan.cn/northlan/nnet/pipeline"
"git.noahlan.cn/northlan/nnet/scheduler"
"git.noahlan.cn/northlan/nnet/serialize"
"git.noahlan.cn/northlan/nnet/session"
"github.com/gorilla/websocket"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
type (
Options struct {
Name string // 服务端名默认为n-net
Pipeline pipeline.Pipeline // 消息管道
RetryInterval time.Duration // 消息重试间隔时长
Components *component.Components // 组件库
Packer packet.Packer // 封包、拆包器
PacketProcessor packet.Processor // 数据包处理器
Serializer serialize.Serializer // 消息 序列化/反序列化
HeartbeatInterval time.Duration // 心跳间隔0表示不进行心跳
WS WSOptions // websocket
}
WSOptions struct {
IsWebsocket bool // 是否为websocket服务端
WebsocketPath string // ws地址(ws://127.0.0.1/WebsocketPath)
TLSCertificate string // TLS 证书地址 (websocket)
TLSKey string // TLS 证书key地址 (websocket)
CheckOrigin func(*http.Request) bool // check origin
}
)
// Server TCP-Server
type Server struct {
// Options 参数列表
Options
// protocol 协议名
// "tcp", "tcp4", "tcp6", "unix" or "unixpacket"
// 若只想开启IPv4, 使用tcp4即可
protocol string
// address 服务地址
// 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP
// 如果端口号不填或端口号为0例如"127.0.0.1:" 或 ":0",服务端将选择随机可用端口
address string
// DieChan 应用程序退出信号
DieChan chan struct{}
// handler 消息处理器
handler *Handler
// sessionMgr session管理器
sessionMgr *session.Manager
}
func NewServer(protocol, addr string, opts ...Option) *Server {
options := Options{
Components: &component.Components{},
WS: WSOptions{
CheckOrigin: func(_ *http.Request) bool { return true },
},
Packer: packet.NewNNetPacker(),
PacketProcessor: packet.NewNNetProcessor(),
}
s := &Server{
Options: options,
protocol: protocol,
address: addr,
DieChan: make(chan struct{}),
}
for _, opt := range opts {
opt(&s.Options)
}
s.handler = NewHandler(s, s.Options.Pipeline, s.Options.PacketProcessor)
s.sessionMgr = session.NewManager()
initPool(0)
return s
}
func (s *Server) Serve() {
components := s.Components.List()
for _, c := range components {
err := s.handler.register(c.Comp, c.Opts)
if err != nil {
// TODO Log and return
return
}
}
// Initialize components
for _, c := range components {
c.Comp.OnInit()
}
go func() {
if s.WS.IsWebsocket {
if len(s.WS.TLSCertificate) != 0 && len(s.WS.TLSKey) != 0 {
s.listenAndServeWSTLS()
} else {
s.listenAndServeWS()
}
} else {
s.listenAndServe()
}
}()
go scheduler.Schedule()
sg := make(chan os.Signal)
signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM)
select {
case <-s.DieChan:
log.Info("The server will shutdown in a few seconds")
case s := <-sg:
log.Infof("server got signal: %s", s)
}
log.Info("server is stopping...")
s.shutdown()
scheduler.Close()
}
func (s *Server) Close() {
close(s.DieChan)
}
func (s *Server) shutdown() {
components := s.Components.List()
compLen := len(components)
for i := compLen - 1; i >= 0; i-- {
components[i].Comp.OnShutdown()
}
}
func (s *Server) listenAndServe() {
listener, err := net.Listen(s.protocol, s.address)
if err != nil {
panic(err)
}
// 监听成功,服务已启动
log.Info("listening...")
defer func() {
listener.Close()
s.Close()
}()
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Errorf("服务器网络错误 %+v", err)
return
}
log.Errorf("监听错误 %v", err)
continue
}
err = pool.SubmitConn(func() {
s.handler.handle(conn)
})
if err != nil {
log.Errorf("submit conn pool err: %s", err.Error())
continue
}
}
}
func (s *Server) listenAndServeWS() {
upgrade := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: s.WS.CheckOrigin,
}
http.HandleFunc(fmt.Sprintf("/%s/", s.WS.WebsocketPath), func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrade.Upgrade(w, r, nil)
if err != nil {
log.Errorf("Upgrade failure, URI=%s, Error=%s", r.RequestURI, err.Error())
return
}
err = pool.SubmitConn(func() {
s.handler.handleWS(conn)
})
if err != nil {
log.Fatalf("submit conn pool err: %s", err.Error())
}
})
if err := http.ListenAndServe(s.address, nil); err != nil {
log.Fatal(err.Error())
}
}
func (s *Server) listenAndServeWSTLS() {
upgrade := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: s.WS.CheckOrigin,
}
http.HandleFunc(fmt.Sprintf("/%s/", s.WS.WebsocketPath), func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrade.Upgrade(w, r, nil)
if err != nil {
log.Errorf("Upgrade failure, URI=%s, Error=%s", r.RequestURI, err.Error())
return
}
err = pool.SubmitConn(func() {
s.handler.handleWS(conn)
})
if err != nil {
log.Fatalf("submit conn pool err: %s", err.Error())
}
})
if err := http.ListenAndServeTLS(s.address, s.WS.TLSCertificate, s.WS.TLSKey, nil); err != nil {
log.Fatal(err.Error())
}
}

@ -1,9 +0,0 @@
package nnet
import "testing"
func TestServer(t *testing.T) {
server := NewServer("tcp4", ":22112")
server.Serve()
}

@ -1,25 +0,0 @@
package packet
import (
"git.noahlan.cn/northlan/nnet/nface"
)
// Type 数据帧类型,如:握手,心跳,数据 等
type Type byte
type (
Packer interface {
// Pack 从原始raw bytes创建一个用于网络传输的 数据帧结构
Pack(typ Type, data []byte) ([]byte, error)
// Unpack 解包
Unpack(data []byte) ([]interface{}, error)
}
// Processor 数据帧处理器,拆包之后的处理
Processor interface {
// ProcessPacket 单个数据包处理方法
// packet 为实际数据包,是 packet.Packer 的Unpack方法拆包出来的数据指针
ProcessPacket(conn nface.IConnection, packet interface{}) error
}
)

@ -1,123 +0,0 @@
package packet
import (
"bytes"
"errors"
)
var _ Packer = (*NNetPacker)(nil)
type NNetPacker struct {
buf *bytes.Buffer
size int // 最近一次 长度
typ byte // 最近一次 数据帧类型
}
// Codec constants.
const (
headLength = 4
maxPacketSize = 64 * 1024
)
var ErrPacketSizeExceed = errors.New("codec: packet size exceed")
func NewNNetPacker() Packer {
return &NNetPacker{
buf: bytes.NewBuffer(nil),
size: -1,
}
}
func (d *NNetPacker) Pack(typ Type, data []byte) ([]byte, error) {
if typ < Handshake || typ > Kick {
return nil, ErrWrongPacketType
}
p := &Packet{Type: typ, Length: uint32(len(data))}
buf := make([]byte, p.Length+headLength)
// header
buf[0] = byte(p.Type)
copy(buf[1:headLength], d.intToBytes(p.Length))
// body
copy(buf[headLength:], data)
return buf, nil
}
// Encode packet data length to bytes(Big end)
func (d *NNetPacker) intToBytes(n uint32) []byte {
buf := make([]byte, 3)
buf[0] = byte((n >> 16) & 0xFF)
buf[1] = byte((n >> 8) & 0xFF)
buf[2] = byte(n & 0xFF)
return buf
}
func (d *NNetPacker) Unpack(data []byte) ([]interface{}, error) {
d.buf.Write(data) // copy
var (
packets []interface{}
err error
)
// 检查包长度
if d.buf.Len() < headLength {
return nil, err
}
// 第一次拆包
if d.size < 0 {
if err = d.readHeader(); err != nil {
return nil, err
}
}
for d.size <= d.buf.Len() {
// 读取
p := &Packet{
Type: Type(d.typ),
Length: uint32(d.size),
Data: d.buf.Next(d.size),
}
packets = append(packets, p)
// 剩余数据不满足至少一个数据帧,重置数据帧长度
// 数据缓存内存 保留至 下一次进入本方法以继续拆包
if d.buf.Len() < headLength {
d.size = -1
break
}
// 读取下一个包 next
if err = d.readHeader(); err != nil {
return packets, err
}
}
return packets, nil
}
func (d *NNetPacker) readHeader() error {
header := d.buf.Next(headLength)
d.typ = header[0]
if d.typ < Handshake || d.typ > Kick {
return ErrWrongPacketType
}
d.size = d.bytesToInt(header[1:])
// 最大包限定
if d.size > maxPacketSize {
return ErrPacketSizeExceed
}
return nil
}
// Decode packet data length byte to int(Big end)
func (d *NNetPacker) bytesToInt(b []byte) int {
result := 0
for _, v := range b {
result = result<<8 + int(v)
}
return result
}

@ -1,40 +0,0 @@
package packet
import (
"errors"
)
const (
// Default 默认,暂无意义
Default Type = iota
// Handshake 握手数据(服务端主动发起)
Handshake = 0x01
// HandshakeAck 握手回复(客户端回复)
HandshakeAck = 0x02
// Heartbeat 心跳(服务端发起)
Heartbeat = 0x03
// Data 数据传输
Data = 0x04
// Kick 服务端主动断开连接
Kick = 0x05
)
// ErrWrongPacketType represents a wrong packet type.
var ErrWrongPacketType = errors.New("wrong packet type")
type Packet struct {
Type Type // 数据帧 类型
Length uint32 // 数据长度
Data []byte // 原始数据
}
func New() *Packet {
return &Packet{
Type: Default,
}
}

@ -1,40 +0,0 @@
package packet
import (
"fmt"
"git.noahlan.cn/northlan/nnet/log"
"git.noahlan.cn/northlan/nnet/nface"
)
type NNetProcessor struct{}
func NewNNetProcessor() Processor {
return &NNetProcessor{}
}
func (d *NNetProcessor) ProcessPacket(conn nface.IConnection, packet interface{}) error {
p := packet.(*Packet)
switch p.Type {
case Handshake:
// TODO validate handshake
if _, err := conn.Conn().Write([]byte{}); err != nil {
return err
}
conn.SetStatus(nface.StatusPrepare)
log.Debugf("Connection handshake Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr())
case HandshakeAck:
conn.SetStatus(nface.StatusWorking)
log.Debugf("Receive handshake ACK Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr())
case Data:
if conn.Status() < nface.StatusWorking {
return fmt.Errorf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s",
conn.Conn().RemoteAddr())
}
// TODO message data 处理
case Heartbeat:
// expected
}
return nil
}

@ -1,8 +1,8 @@
package scheduler package scheduler
import ( import (
"git.noahlan.cn/northlan/nnet/env" "git.noahlan.cn/northlan/nnet/internal/env"
"git.noahlan.cn/northlan/nnet/log" "git.noahlan.cn/northlan/nnet/internal/log"
"runtime/debug" "runtime/debug"
"sync/atomic" "sync/atomic"
"time" "time"

@ -1,7 +1,7 @@
package scheduler package scheduler
import ( import (
"git.noahlan.cn/northlan/nnet/log" "git.noahlan.cn/northlan/nnet/internal/log"
"math" "math"
"runtime/debug" "runtime/debug"
"sync" "sync"

@ -2,7 +2,7 @@ package json
import ( import (
"encoding/json" "encoding/json"
"git.noahlan.cn/northlan/nnet/serialize" "git.noahlan.cn/northlan/nnet/core/serialize"
) )
type Serializer struct{} type Serializer struct{}

@ -2,7 +2,7 @@ package protobuf
import ( import (
"errors" "errors"
"git.noahlan.cn/northlan/nnet/serialize" "git.noahlan.cn/northlan/nnet/core/serialize"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )

@ -0,0 +1,203 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc v3.19.4
// source: test.proto
package testdata
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type Ping struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Content string `protobuf:"bytes,1,opt,name=Content,proto3" json:"Content,omitempty"`
}
func (x *Ping) Reset() {
*x = Ping{}
if protoimpl.UnsafeEnabled {
mi := &file_test_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Ping) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Ping) ProtoMessage() {}
func (x *Ping) ProtoReflect() protoreflect.Message {
mi := &file_test_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Ping.ProtoReflect.Descriptor instead.
func (*Ping) Descriptor() ([]byte, []int) {
return file_test_proto_rawDescGZIP(), []int{0}
}
func (x *Ping) GetContent() string {
if x != nil {
return x.Content
}
return ""
}
type Pong struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Content string `protobuf:"bytes,1,opt,name=Content,proto3" json:"Content,omitempty"`
}
func (x *Pong) Reset() {
*x = Pong{}
if protoimpl.UnsafeEnabled {
mi := &file_test_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Pong) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Pong) ProtoMessage() {}
func (x *Pong) ProtoReflect() protoreflect.Message {
mi := &file_test_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Pong.ProtoReflect.Descriptor instead.
func (*Pong) Descriptor() ([]byte, []int) {
return file_test_proto_rawDescGZIP(), []int{1}
}
func (x *Pong) GetContent() string {
if x != nil {
return x.Content
}
return ""
}
var File_test_proto protoreflect.FileDescriptor
var file_test_proto_rawDesc = []byte{
0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x74, 0x65,
0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x22, 0x20, 0x0a, 0x04, 0x50, 0x69, 0x6e, 0x67, 0x12, 0x18,
0x0a, 0x07, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x07, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x22, 0x20, 0x0a, 0x04, 0x50, 0x6f, 0x6e, 0x67,
0x12, 0x18, 0x0a, 0x07, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28,
0x09, 0x52, 0x07, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x42, 0x0b, 0x5a, 0x09, 0x2f, 0x74,
0x65, 0x73, 0x74, 0x64, 0x61, 0x74, 0x61, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_test_proto_rawDescOnce sync.Once
file_test_proto_rawDescData = file_test_proto_rawDesc
)
func file_test_proto_rawDescGZIP() []byte {
file_test_proto_rawDescOnce.Do(func() {
file_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_proto_rawDescData)
})
return file_test_proto_rawDescData
}
var file_test_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_test_proto_goTypes = []interface{}{
(*Ping)(nil), // 0: testdata.Ping
(*Pong)(nil), // 1: testdata.Pong
}
var file_test_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_test_proto_init() }
func file_test_proto_init() {
if File_test_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Ping); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_test_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Pong); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_test_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_test_proto_goTypes,
DependencyIndexes: file_test_proto_depIdxs,
MessageInfos: file_test_proto_msgTypes,
}.Build()
File_test_proto = out.File
file_test_proto_rawDesc = nil
file_test_proto_goTypes = nil
file_test_proto_depIdxs = nil
}

@ -1,13 +1,10 @@
package session package session
import ( import (
"git.noahlan.cn/northlan/nnet/nface"
"sync" "sync"
"sync/atomic" "sync/atomic"
) )
var _ nface.ISession = (*Session)(nil)
type Session struct { type Session struct {
sync.RWMutex // 数据锁 sync.RWMutex // 数据锁
@ -16,7 +13,7 @@ type Session struct {
data map[string]interface{} // session数据存储内存 data map[string]interface{} // session数据存储内存
} }
func New() nface.ISession { func NewSession() *Session {
return &Session{ return &Session{
id: sessionIDMgrInstance.SessionID(), id: sessionIDMgrInstance.SessionID(),
uid: "", uid: "",

@ -1,43 +1,42 @@
package session package session
import ( import (
"git.noahlan.cn/northlan/nnet/nface"
"sync" "sync"
) )
type Manager struct { type Manager struct {
sync.RWMutex sync.RWMutex
sessions map[int64]nface.ISession sessions map[int64]*Session
} }
func NewManager() *Manager { func NewSessionMgr() *Manager {
return &Manager{ return &Manager{
RWMutex: sync.RWMutex{}, RWMutex: sync.RWMutex{},
sessions: make(map[int64]nface.ISession), sessions: make(map[int64]*Session),
} }
} }
func (m *Manager) StoreSession(s nface.ISession) { func (m *Manager) StoreSession(s *Session) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
m.sessions[s.ID()] = s m.sessions[s.ID()] = s
} }
func (m *Manager) FindSession(sid int64) nface.ISession { func (m *Manager) FindSession(sid int64) *Session {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
return m.sessions[sid] return m.sessions[sid]
} }
func (m *Manager) FindOrCreateSession(sid int64) nface.ISession { func (m *Manager) FindOrCreateSession(sid int64) *Session {
m.RLock() m.RLock()
s, ok := m.sessions[sid] s, ok := m.sessions[sid]
m.RUnlock() m.RUnlock()
if !ok { if !ok {
s = New() s = NewSession()
m.Lock() m.Lock()
m.sessions[s.ID()] = s m.sessions[s.ID()] = s

Loading…
Cancel
Save