parent
							
								
									c65fd5961b
								
							
						
					
					
						commit
						ebcbd0f88f
					
				@ -0,0 +1,23 @@
 | 
			
		||||
package nnet
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/connection"
 | 
			
		||||
	"git.noahlan.cn/noahlan/ntools-go/core/nlog"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Dial 连接服务器
 | 
			
		||||
func (ngin *Engine) Dial(addr string) (*connection.Connection, error) {
 | 
			
		||||
	err := ngin.setup()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		nlog.Errorf("%s failed to setup server, err:%v", ngin.LogPrefix(), err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn, err := net.Dial("tcp", addr)
 | 
			
		||||
	nlog.Must(err)
 | 
			
		||||
 | 
			
		||||
	nlog.Infof("%s now connect to %s...", ngin.LogPrefix(), addr)
 | 
			
		||||
 | 
			
		||||
	return ngin.handle(conn), nil
 | 
			
		||||
}
 | 
			
		||||
@ -1,286 +0,0 @@
 | 
			
		||||
package core
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/config"
 | 
			
		||||
	conn2 "git.noahlan.cn/noahlan/nnet/conn"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/entity"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/lifetime"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/packet"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/pipeline"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/scheduler"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/serialize"
 | 
			
		||||
	"git.noahlan.cn/noahlan/ntools-go/core/nlog"
 | 
			
		||||
	"git.noahlan.cn/noahlan/ntools-go/core/pool"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NotFound(conn entity.NetworkEntity, _ packet.IPacket) {
 | 
			
		||||
	nlog.Error("handler not found")
 | 
			
		||||
	_ = conn.SendBytes([]byte("handler not found"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NotFoundHandler() Handler {
 | 
			
		||||
	return HandlerFunc(NotFound)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type (
 | 
			
		||||
	// engine TCP-engine
 | 
			
		||||
	engine struct {
 | 
			
		||||
		conf               config.EngineConf // conf 配置
 | 
			
		||||
		taskTimerPrecision time.Duration
 | 
			
		||||
 | 
			
		||||
		middlewares []Middleware // 中间件
 | 
			
		||||
		routes      []Route      // 路由
 | 
			
		||||
		// handler 消息处理器
 | 
			
		||||
		handler Handler
 | 
			
		||||
		// dieChan 应用程序退出信号
 | 
			
		||||
		dieChan chan struct{}
 | 
			
		||||
 | 
			
		||||
		pipeline pipeline.Pipeline // 消息管道
 | 
			
		||||
		lifetime *lifetime.Mgr     // 连接的生命周期管理器
 | 
			
		||||
 | 
			
		||||
		packerFn   packet.NewPackerFunc // 封包、拆包器
 | 
			
		||||
		serializer serialize.Serializer // 消息 序列化/反序列化
 | 
			
		||||
 | 
			
		||||
		wsOpt wsOptions // websocket
 | 
			
		||||
 | 
			
		||||
		connManager *conn2.Manager
 | 
			
		||||
		sessIdMgr   *sessionIDMgr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	wsOptions struct {
 | 
			
		||||
		IsWebsocket    bool                     // 是否为websocket服务端
 | 
			
		||||
		WebsocketPath  string                   // ws地址(ws://127.0.0.1/WebsocketPath)
 | 
			
		||||
		TLSCertificate string                   // TLS 证书地址 (websocket)
 | 
			
		||||
		TLSKey         string                   // TLS 证书key地址 (websocket)
 | 
			
		||||
		CheckOrigin    func(*http.Request) bool // check origin
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func newEngine(conf config.EngineConf) *engine {
 | 
			
		||||
	s := &engine{
 | 
			
		||||
		conf:               conf,
 | 
			
		||||
		dieChan:            make(chan struct{}),
 | 
			
		||||
		pipeline:           pipeline.New(),
 | 
			
		||||
		middlewares:        make([]Middleware, 0),
 | 
			
		||||
		routes:             make([]Route, 0),
 | 
			
		||||
		taskTimerPrecision: conf.TaskTimerPrecision,
 | 
			
		||||
		connManager:        conn2.NewManager(),
 | 
			
		||||
		sessIdMgr:          newSessionIDMgr(),
 | 
			
		||||
		lifetime:           lifetime.NewLifetime(),
 | 
			
		||||
	}
 | 
			
		||||
	pool.InitPool(conf.Pool)
 | 
			
		||||
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) shallLogDebug() bool {
 | 
			
		||||
	return config.ShallLogDebug(ng.conf.Mode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) logPrefix() string {
 | 
			
		||||
	return fmt.Sprintf("[NNet-%s]", ng.conf.Name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) use(middleware ...Middleware) {
 | 
			
		||||
	ng.middlewares = append(ng.middlewares, middleware...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) addRoutes(route ...Route) {
 | 
			
		||||
	ng.routes = append(ng.routes, route...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) bindRoutes(router Router) error {
 | 
			
		||||
	for _, fr := range ng.routes {
 | 
			
		||||
		if err := ng.bindRoute(router, fr); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) bindRoute(router Router, route Route) error {
 | 
			
		||||
	// TODO default middleware
 | 
			
		||||
	chain := newChain()
 | 
			
		||||
	// build chain
 | 
			
		||||
	for _, middleware := range ng.middlewares {
 | 
			
		||||
		chain.Append(convertMiddleware(middleware))
 | 
			
		||||
	}
 | 
			
		||||
	return router.Register(route.Matches, route.Handler)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func convertMiddleware(ware Middleware) func(Handler) Handler {
 | 
			
		||||
	return func(next Handler) Handler {
 | 
			
		||||
		return ware(next.Handle)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) dial(addr string, router Router) (entity.NetworkEntity, error) {
 | 
			
		||||
	ng.handler = router
 | 
			
		||||
 | 
			
		||||
	if err := ng.bindRoutes(router); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	go scheduler.Schedule(ng.taskTimerPrecision)
 | 
			
		||||
 | 
			
		||||
	// connection
 | 
			
		||||
	conn, err := net.Dial("tcp", addr)
 | 
			
		||||
	nlog.Must(err)
 | 
			
		||||
 | 
			
		||||
	c := newConnection(ng, conn)
 | 
			
		||||
	c.serve()
 | 
			
		||||
	// hook
 | 
			
		||||
	ng.lifetime.Open(c)
 | 
			
		||||
	// connection manager
 | 
			
		||||
	err = ng.connManager.Store(conn2.DefaultGroupName, c)
 | 
			
		||||
	nlog.Must(err)
 | 
			
		||||
 | 
			
		||||
	// 连接成功,客户端已启动
 | 
			
		||||
	if ng.shallLogDebug() {
 | 
			
		||||
		nlog.Debugf("now connect to %s.", addr)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) serve(router Router) error {
 | 
			
		||||
	ng.handler = router
 | 
			
		||||
 | 
			
		||||
	if err := ng.bindRoutes(router); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	go scheduler.Schedule(ng.taskTimerPrecision)
 | 
			
		||||
	defer func() {
 | 
			
		||||
		nlog.Infof("%s is stopping...", ng.logPrefix())
 | 
			
		||||
 | 
			
		||||
		ng.shutdown()
 | 
			
		||||
		scheduler.Close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	if ng.wsOpt.IsWebsocket {
 | 
			
		||||
		if len(ng.wsOpt.TLSCertificate) != 0 && len(ng.wsOpt.TLSKey) != 0 {
 | 
			
		||||
			ng.listenAndServeWSTLS()
 | 
			
		||||
		} else {
 | 
			
		||||
			ng.listenAndServeWS()
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		ng.listenAndServe()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) close() {
 | 
			
		||||
	close(ng.dieChan)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) shutdown() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) listenAndServe() {
 | 
			
		||||
	listener, err := net.Listen(ng.conf.Protocol, ng.conf.Addr)
 | 
			
		||||
	nlog.Must(err)
 | 
			
		||||
 | 
			
		||||
	// 监听成功,服务已启动
 | 
			
		||||
	if ng.shallLogDebug() {
 | 
			
		||||
		nlog.Debugf("%s now listening %s at %s.", ng.logPrefix(), ng.conf.Protocol, ng.conf.Addr)
 | 
			
		||||
	}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = listener.Close()
 | 
			
		||||
		ng.close()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		conn, err := listener.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if errors.Is(err, net.ErrClosed) {
 | 
			
		||||
				nlog.Errorf("%s 服务器网络错误 %+v", ng.logPrefix(), err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			nlog.Errorf("%s 监听错误 %v", ng.logPrefix(), err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = pool.Submit(func() {
 | 
			
		||||
			ng.handle(conn)
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			nlog.Errorf("%s submit conn pool err: %ng", ng.logPrefix(), err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) listenAndServeWS() {
 | 
			
		||||
	ng.setupWS()
 | 
			
		||||
	if ng.shallLogDebug() {
 | 
			
		||||
		nlog.Debugf("%s now listening websocket at %s.", ng.logPrefix(), ng.conf.Addr)
 | 
			
		||||
	}
 | 
			
		||||
	if err := http.ListenAndServe(ng.conf.Addr, nil); err != nil {
 | 
			
		||||
		log.Fatal(err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) listenAndServeWSTLS() {
 | 
			
		||||
	ng.setupWS()
 | 
			
		||||
	if ng.shallLogDebug() {
 | 
			
		||||
		nlog.Debugf("%s now listening websocket with tls at %s.", ng.logPrefix(), ng.conf.Addr)
 | 
			
		||||
	}
 | 
			
		||||
	if err := http.ListenAndServeTLS(ng.conf.Addr, ng.wsOpt.TLSCertificate, ng.wsOpt.TLSKey, nil); err != nil {
 | 
			
		||||
		log.Fatal(err.Error())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) setupWS() {
 | 
			
		||||
	upgrade := websocket.Upgrader{
 | 
			
		||||
		ReadBufferSize:  1024,
 | 
			
		||||
		WriteBufferSize: 1024,
 | 
			
		||||
		CheckOrigin:     ng.wsOpt.CheckOrigin,
 | 
			
		||||
	}
 | 
			
		||||
	http.HandleFunc("/"+strings.TrimPrefix(ng.wsOpt.WebsocketPath, "/"), func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		conn, err := upgrade.Upgrade(w, r, nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			nlog.Errorf("%s Upgrade failure, URI=%ng, Error=%ng", ng.logPrefix(), r.RequestURI, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = pool.Submit(func() {
 | 
			
		||||
			ng.handleWS(conn)
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatalf("%s submit conn pool err: %v", ng.logPrefix(), err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) handleWS(conn *websocket.Conn) {
 | 
			
		||||
	c := newWSConn(conn)
 | 
			
		||||
	ng.handle(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) handle(conn net.Conn) {
 | 
			
		||||
	c := newConnection(ng, conn)
 | 
			
		||||
	err := ng.connManager.Store(conn2.DefaultGroupName, c)
 | 
			
		||||
	nlog.Must(err)
 | 
			
		||||
 | 
			
		||||
	c.serve()
 | 
			
		||||
	// hook
 | 
			
		||||
	ng.lifetime.Open(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ng *engine) notFoundHandler(next Handler) Handler {
 | 
			
		||||
	return HandlerFunc(func(entity entity.NetworkEntity, packet packet.IPacket) {
 | 
			
		||||
		h := next
 | 
			
		||||
		if next == nil {
 | 
			
		||||
			h = NotFoundHandler()
 | 
			
		||||
		}
 | 
			
		||||
		// TODO write to client
 | 
			
		||||
		h.Handle(entity, packet)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@ -1,20 +0,0 @@
 | 
			
		||||
package core
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/entity"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/packet"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type (
 | 
			
		||||
	Handler interface {
 | 
			
		||||
		Handle(entity entity.NetworkEntity, pkg packet.IPacket)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	HandlerFunc func(entity entity.NetworkEntity, pkg packet.IPacket)
 | 
			
		||||
 | 
			
		||||
	Middleware func(next HandlerFunc) HandlerFunc
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (f HandlerFunc) Handle(entity entity.NetworkEntity, pkg packet.IPacket) {
 | 
			
		||||
	f(entity, pkg)
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,135 @@
 | 
			
		||||
package nnet
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/config"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/connection"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/lifetime"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/packet"
 | 
			
		||||
	rt "git.noahlan.cn/noahlan/nnet/router"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/scheduler"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/serialize"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/session"
 | 
			
		||||
	"git.noahlan.cn/noahlan/ntools-go/core/nlog"
 | 
			
		||||
	"github.com/panjf2000/ants/v2"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Engine 引擎
 | 
			
		||||
type Engine struct {
 | 
			
		||||
	config.EngineConf                      // 引擎配置
 | 
			
		||||
	middlewares       []rt.Middleware      // 中间件
 | 
			
		||||
	routes            []rt.Route           // 路由
 | 
			
		||||
	router            rt.Router            // 消息处理器
 | 
			
		||||
	dieChan           chan struct{}        // 应用程序退出信号
 | 
			
		||||
	pipeline          connection.Pipeline  // 消息管道
 | 
			
		||||
	packerBuilder     packet.PackerBuilder // 封包、拆包器
 | 
			
		||||
	serializer        serialize.Serializer // 消息 序列化/反序列化
 | 
			
		||||
	goPool            *ants.Pool           // goroutine池
 | 
			
		||||
	connManager       *connection.Manager  // 连接管理器
 | 
			
		||||
	lifetime          *lifetime.Mgr        // 生命周期
 | 
			
		||||
	sessIdMgr         *session.IDMgr       // SessionId管理器
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewEngine(conf config.EngineConf, opts ...RunOption) *Engine {
 | 
			
		||||
	ngin := &Engine{
 | 
			
		||||
		EngineConf:    conf,
 | 
			
		||||
		middlewares:   make([]rt.Middleware, 0),
 | 
			
		||||
		routes:        make([]rt.Route, 0),
 | 
			
		||||
		router:        rt.NewDefaultRouter(),
 | 
			
		||||
		packerBuilder: nil,
 | 
			
		||||
		serializer:    nil,
 | 
			
		||||
		dieChan:       make(chan struct{}),
 | 
			
		||||
		pipeline:      connection.NewPipeline(),
 | 
			
		||||
		connManager:   connection.NewManager(),
 | 
			
		||||
		lifetime:      lifetime.NewLifetime(),
 | 
			
		||||
		sessIdMgr:     session.NewSessionIDMgr(),
 | 
			
		||||
		goPool:        nil,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, opt := range opts {
 | 
			
		||||
		opt(ngin)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ngin.goPool == nil {
 | 
			
		||||
		ngin.goPool, _ = ants.NewPool(math.MaxInt32)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ngin
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) Use(middleware ...rt.Middleware) {
 | 
			
		||||
	ngin.middlewares = append(ngin.middlewares, middleware...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) AddRoutes(rs ...rt.Route) {
 | 
			
		||||
	ngin.routes = append(ngin.routes, rs...)
 | 
			
		||||
	err := ngin.bindRoutes()
 | 
			
		||||
	nlog.Must(err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) bindRoutes() error {
 | 
			
		||||
	for _, fr := range ngin.routes {
 | 
			
		||||
		if err := ngin.bindRoute(fr); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) bindRoute(route rt.Route) error {
 | 
			
		||||
	// TODO default middleware
 | 
			
		||||
	chain := rt.NewChain()
 | 
			
		||||
	// build chain
 | 
			
		||||
	for _, middleware := range ngin.middlewares {
 | 
			
		||||
		chain.Append(rt.ConvertMiddleware(middleware))
 | 
			
		||||
	}
 | 
			
		||||
	return ngin.router.Register(route.Matches, route.Handler)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) setup() error {
 | 
			
		||||
	if err := ngin.bindRoutes(); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if err := ngin.goPool.Submit(func() {
 | 
			
		||||
		scheduler.Schedule(ngin.TaskTimerPrecision)
 | 
			
		||||
	}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) Stop() {
 | 
			
		||||
	nlog.Infof("%s is stopping...", ngin.LogPrefix())
 | 
			
		||||
	close(ngin.dieChan)
 | 
			
		||||
	scheduler.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) handle(conn net.Conn) *connection.Connection {
 | 
			
		||||
	nc := connection.NewConnection(
 | 
			
		||||
		ngin.sessIdMgr.SessionID(),
 | 
			
		||||
		conn,
 | 
			
		||||
		connection.Config{LogDebug: ngin.ShallLogDebug(), LogPrefix: ngin.LogPrefix()},
 | 
			
		||||
		ngin.packerBuilder, ngin.serializer, ngin.pipeline,
 | 
			
		||||
		ngin.router.Handle,
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	nc.Serve()
 | 
			
		||||
 | 
			
		||||
	err := ngin.connManager.Store(connection.DefaultGroupName, nc)
 | 
			
		||||
	nlog.Must(err)
 | 
			
		||||
 | 
			
		||||
	// dieChan
 | 
			
		||||
	go func() {
 | 
			
		||||
		// lifetime
 | 
			
		||||
		ngin.lifetime.Open(nc)
 | 
			
		||||
 | 
			
		||||
		select {
 | 
			
		||||
		case <-nc.DieChan():
 | 
			
		||||
			scheduler.PushTask(func() { ngin.lifetime.Close(nc) })
 | 
			
		||||
			_ = ngin.connManager.Remove(nc)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	return nc
 | 
			
		||||
}
 | 
			
		||||
@ -1,26 +0,0 @@
 | 
			
		||||
package entity
 | 
			
		||||
 | 
			
		||||
type Session interface {
 | 
			
		||||
	// ID 获取 Session ID
 | 
			
		||||
	ID() int64
 | 
			
		||||
	// UID 获取UID
 | 
			
		||||
	UID() string
 | 
			
		||||
	// Bind 绑定uid
 | 
			
		||||
	Bind(uid string)
 | 
			
		||||
	// Attribute 获取指定key对应参数
 | 
			
		||||
	Attribute(key string) interface{}
 | 
			
		||||
	// Keys 获取所有参数key
 | 
			
		||||
	Keys() []string
 | 
			
		||||
	// Exists 指定key是否存在
 | 
			
		||||
	Exists(key string) bool
 | 
			
		||||
	// Attributes 获取所有参数
 | 
			
		||||
	Attributes() map[string]interface{}
 | 
			
		||||
	// RemoveAttribute 移除指定key对应参数
 | 
			
		||||
	RemoveAttribute(key string)
 | 
			
		||||
	// SetAttribute 设置参数
 | 
			
		||||
	SetAttribute(key string, value interface{})
 | 
			
		||||
	// Invalidate 清理
 | 
			
		||||
	Invalidate()
 | 
			
		||||
	// Close 关闭 Session
 | 
			
		||||
	Close()
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,43 @@
 | 
			
		||||
package modbus
 | 
			
		||||
 | 
			
		||||
import "sync"
 | 
			
		||||
 | 
			
		||||
var crcTable []uint16
 | 
			
		||||
var mu sync.Mutex
 | 
			
		||||
 | 
			
		||||
// crcModbus 计算modbus的crc
 | 
			
		||||
func crcModbus(data []byte) (crc uint16) {
 | 
			
		||||
	if crcTable == nil {
 | 
			
		||||
		mu.Lock()
 | 
			
		||||
		if crcTable == nil {
 | 
			
		||||
			initCrcTable()
 | 
			
		||||
		}
 | 
			
		||||
		mu.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
	crc = 0xffff
 | 
			
		||||
	for _, v := range data {
 | 
			
		||||
		crc = (crc >> 8) ^ crcTable[(crc^uint16(v))&0x00FF]
 | 
			
		||||
	}
 | 
			
		||||
	return crc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// initCrcTable 初始化crcTable
 | 
			
		||||
func initCrcTable() {
 | 
			
		||||
	crc16IBM := uint16(0xA001)
 | 
			
		||||
	crcTable = make([]uint16, 256)
 | 
			
		||||
 | 
			
		||||
	for i := uint16(0); i < 256; i++ {
 | 
			
		||||
		crc := uint16(0)
 | 
			
		||||
		c := i
 | 
			
		||||
 | 
			
		||||
		for j := uint16(0); j < 8; j++ {
 | 
			
		||||
			if ((crc ^ c) & 0x0001) > 0 {
 | 
			
		||||
				crc = (crc >> 1) ^ crc16IBM
 | 
			
		||||
			} else {
 | 
			
		||||
				crc = crc >> 1
 | 
			
		||||
			}
 | 
			
		||||
			c = c >> 1
 | 
			
		||||
		}
 | 
			
		||||
		crcTable[i] = crc
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,14 @@
 | 
			
		||||
package modbus
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/protocol/modbus/internal"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestCRC(t *testing.T) {
 | 
			
		||||
	got := crcModbus([]byte{0x01, 0x04, 0x02, 0xFF, 0xFF})
 | 
			
		||||
	expect := uint16(0x80B8)
 | 
			
		||||
 | 
			
		||||
	assert := internal.NewAssert(t, "TestCRC")
 | 
			
		||||
	assert.Equal(expect, got)
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,167 @@
 | 
			
		||||
package internal
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"runtime"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	compareNotEqual int = iota - 2
 | 
			
		||||
	compareLess
 | 
			
		||||
	compareEqual
 | 
			
		||||
	compareGreater
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Assert is a simple implementation of assertion, only for internal usage
 | 
			
		||||
type Assert struct {
 | 
			
		||||
	T        *testing.T
 | 
			
		||||
	CaseName string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewAssert return instance of Assert
 | 
			
		||||
func NewAssert(t *testing.T, caseName string) *Assert {
 | 
			
		||||
	return &Assert{T: t, CaseName: caseName}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Equal check if expected is equal with actual
 | 
			
		||||
func (a *Assert) Equal(expected, actual any) {
 | 
			
		||||
	if compare(expected, actual) != compareEqual {
 | 
			
		||||
		makeTestFailed(a.T, a.CaseName, expected, actual)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NotEqual check if expected is not equal with actual
 | 
			
		||||
func (a *Assert) NotEqual(expected, actual any) {
 | 
			
		||||
	if compare(expected, actual) == compareEqual {
 | 
			
		||||
		expectedInfo := fmt.Sprintf("not %v", expected)
 | 
			
		||||
		makeTestFailed(a.T, a.CaseName, expectedInfo, actual)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Greater check if expected is greate than actual
 | 
			
		||||
func (a *Assert) Greater(expected, actual any) {
 | 
			
		||||
	if compare(expected, actual) != compareGreater {
 | 
			
		||||
		expectedInfo := fmt.Sprintf("> %v", expected)
 | 
			
		||||
		makeTestFailed(a.T, a.CaseName, expectedInfo, actual)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GreaterOrEqual check if expected is greate than or equal with actual
 | 
			
		||||
func (a *Assert) GreaterOrEqual(expected, actual any) {
 | 
			
		||||
	isGreatOrEqual := compare(expected, actual) == compareGreater || compare(expected, actual) == compareEqual
 | 
			
		||||
	if !isGreatOrEqual {
 | 
			
		||||
		expectedInfo := fmt.Sprintf(">= %v", expected)
 | 
			
		||||
		makeTestFailed(a.T, a.CaseName, expectedInfo, actual)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Less check if expected is less than actual
 | 
			
		||||
func (a *Assert) Less(expected, actual any) {
 | 
			
		||||
	if compare(expected, actual) != compareLess {
 | 
			
		||||
		expectedInfo := fmt.Sprintf("< %v", expected)
 | 
			
		||||
		makeTestFailed(a.T, a.CaseName, expectedInfo, actual)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LessOrEqual check if expected is less than or equal with actual
 | 
			
		||||
func (a *Assert) LessOrEqual(expected, actual any) {
 | 
			
		||||
	isLessOrEqual := compare(expected, actual) == compareLess || compare(expected, actual) == compareEqual
 | 
			
		||||
	if !isLessOrEqual {
 | 
			
		||||
		expectedInfo := fmt.Sprintf("<= %v", expected)
 | 
			
		||||
		makeTestFailed(a.T, a.CaseName, expectedInfo, actual)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsNil check if value is nil
 | 
			
		||||
func (a *Assert) IsNil(value any) {
 | 
			
		||||
	if value != nil {
 | 
			
		||||
		makeTestFailed(a.T, a.CaseName, nil, value)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsNotNil check if value is not nil
 | 
			
		||||
func (a *Assert) IsNotNil(value any) {
 | 
			
		||||
	if value == nil {
 | 
			
		||||
		makeTestFailed(a.T, a.CaseName, "not nil", value)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// compare x and y return :
 | 
			
		||||
// x > y -> 1, x < y -> -1, x == y -> 0, x != y -> -2
 | 
			
		||||
func compare(x, y any) int {
 | 
			
		||||
	vx := reflect.ValueOf(x)
 | 
			
		||||
	vy := reflect.ValueOf(y)
 | 
			
		||||
 | 
			
		||||
	if vx.Type() != vy.Type() {
 | 
			
		||||
		return compareNotEqual
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch vx.Kind() {
 | 
			
		||||
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 | 
			
		||||
		xInt := vx.Int()
 | 
			
		||||
		yInt := vy.Int()
 | 
			
		||||
		if xInt > yInt {
 | 
			
		||||
			return compareGreater
 | 
			
		||||
		}
 | 
			
		||||
		if xInt == yInt {
 | 
			
		||||
			return compareEqual
 | 
			
		||||
		}
 | 
			
		||||
		if xInt < yInt {
 | 
			
		||||
			return compareLess
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 | 
			
		||||
		xUint := vx.Uint()
 | 
			
		||||
		yUint := vy.Uint()
 | 
			
		||||
		if xUint > yUint {
 | 
			
		||||
			return compareGreater
 | 
			
		||||
		}
 | 
			
		||||
		if xUint == yUint {
 | 
			
		||||
			return compareEqual
 | 
			
		||||
		}
 | 
			
		||||
		if xUint < yUint {
 | 
			
		||||
			return compareLess
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.Float32, reflect.Float64:
 | 
			
		||||
		xFloat := vx.Float()
 | 
			
		||||
		yFloat := vy.Float()
 | 
			
		||||
		if xFloat > yFloat {
 | 
			
		||||
			return compareGreater
 | 
			
		||||
		}
 | 
			
		||||
		if xFloat == yFloat {
 | 
			
		||||
			return compareEqual
 | 
			
		||||
		}
 | 
			
		||||
		if xFloat < yFloat {
 | 
			
		||||
			return compareLess
 | 
			
		||||
		}
 | 
			
		||||
	case reflect.String:
 | 
			
		||||
		xString := vx.String()
 | 
			
		||||
		yString := vy.String()
 | 
			
		||||
		if xString > yString {
 | 
			
		||||
			return compareGreater
 | 
			
		||||
		}
 | 
			
		||||
		if xString == yString {
 | 
			
		||||
			return compareEqual
 | 
			
		||||
		}
 | 
			
		||||
		if xString < yString {
 | 
			
		||||
			return compareLess
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		if reflect.DeepEqual(x, y) {
 | 
			
		||||
			return compareEqual
 | 
			
		||||
		}
 | 
			
		||||
		return compareNotEqual
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return compareNotEqual
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// logFailedInfo make test failed and log error info
 | 
			
		||||
func makeTestFailed(t *testing.T, caseName string, expected, actual any) {
 | 
			
		||||
	_, file, line, _ := runtime.Caller(2)
 | 
			
		||||
	errInfo := fmt.Sprintf("Case %v failed. file: %v, line: %v, expected: %v, actual: %v.", caseName, file, line, expected, actual)
 | 
			
		||||
	t.Error(errInfo)
 | 
			
		||||
	t.FailNow()
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,50 @@
 | 
			
		||||
package internal
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestAssert(t *testing.T) {
 | 
			
		||||
	assert := NewAssert(t, "TestAssert")
 | 
			
		||||
	assert.Equal(0, 0)
 | 
			
		||||
	assert.NotEqual(1, 0)
 | 
			
		||||
 | 
			
		||||
	assert.NotEqual("1", 1)
 | 
			
		||||
	var uInt1 uint
 | 
			
		||||
	var uInt2 uint
 | 
			
		||||
	var uInt8 uint8
 | 
			
		||||
	var uInt16 uint16
 | 
			
		||||
	var uInt32 uint32
 | 
			
		||||
	var uInt64 uint64
 | 
			
		||||
	assert.NotEqual(uInt1, uInt8)
 | 
			
		||||
	assert.NotEqual(uInt8, uInt16)
 | 
			
		||||
	assert.NotEqual(uInt16, uInt32)
 | 
			
		||||
	assert.NotEqual(uInt32, uInt64)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(uInt1, uInt2)
 | 
			
		||||
 | 
			
		||||
	uInt1 = 1
 | 
			
		||||
	uInt2 = 2
 | 
			
		||||
	assert.Less(uInt1, uInt2)
 | 
			
		||||
 | 
			
		||||
	assert.Greater(1, 0)
 | 
			
		||||
	assert.GreaterOrEqual(1, 1)
 | 
			
		||||
	assert.Less(0, 1)
 | 
			
		||||
	assert.LessOrEqual(0, 0)
 | 
			
		||||
 | 
			
		||||
	assert.Equal(0.1, 0.1)
 | 
			
		||||
	assert.Greater(1.1, 0.1)
 | 
			
		||||
	assert.Less(0.1, 1.1)
 | 
			
		||||
 | 
			
		||||
	assert.Equal("abc", "abc")
 | 
			
		||||
	assert.NotEqual("abc", "abd")
 | 
			
		||||
	assert.Less("abc", "abd")
 | 
			
		||||
	assert.Greater("abd", "abc")
 | 
			
		||||
 | 
			
		||||
	assert.Equal([]int{1, 2, 3}, []int{1, 2, 3})
 | 
			
		||||
	assert.NotEqual([]int{1, 2, 3}, []int{1, 2})
 | 
			
		||||
 | 
			
		||||
	assert.IsNil(nil)
 | 
			
		||||
	assert.IsNotNil("abc")
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
@ -1,15 +1,15 @@
 | 
			
		||||
package plain
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/core"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/packet"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func WithPlainProtocol() []core.RunOption {
 | 
			
		||||
	opts := []core.RunOption{
 | 
			
		||||
func WithPlainProtocol() []nnet.RunOption {
 | 
			
		||||
	opts := []nnet.RunOption{
 | 
			
		||||
		withPipeline(),
 | 
			
		||||
		core.WithRouter(NewRouter()),
 | 
			
		||||
		core.WithPacker(func() packet.Packer { return NewPacker() }),
 | 
			
		||||
		nnet.WithRouter(NewRouter()),
 | 
			
		||||
		nnet.WithPackerBuilder(func() packet.Packer { return NewPacker() }),
 | 
			
		||||
	}
 | 
			
		||||
	return opts
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,30 @@
 | 
			
		||||
package router
 | 
			
		||||
 | 
			
		||||
// ToMiddleware converts the given handler to a Middleware.
 | 
			
		||||
func ToMiddleware(handler func(next Handler) Handler) Middleware {
 | 
			
		||||
	return func(next HandlerFunc) HandlerFunc {
 | 
			
		||||
		return handler(next).Handle
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithMiddlewares adds given middlewares to given routes.
 | 
			
		||||
func WithMiddlewares(ms []Middleware, rs ...Route) []Route {
 | 
			
		||||
	for i := len(ms) - 1; i >= 0; i-- {
 | 
			
		||||
		rs = WithMiddleware(ms[i], rs...)
 | 
			
		||||
	}
 | 
			
		||||
	return rs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithMiddleware adds given middleware to given route.
 | 
			
		||||
func WithMiddleware(middleware Middleware, rs ...Route) []Route {
 | 
			
		||||
	routes := make([]Route, len(rs))
 | 
			
		||||
 | 
			
		||||
	for i := range rs {
 | 
			
		||||
		route := rs[i]
 | 
			
		||||
		routes[i] = Route{
 | 
			
		||||
			Matches: route.Matches,
 | 
			
		||||
			Handler: middleware(route.Handler),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return routes
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,7 @@
 | 
			
		||||
package router
 | 
			
		||||
 | 
			
		||||
func ConvertMiddleware(ware Middleware) func(Handler) Handler {
 | 
			
		||||
	return func(next Handler) Handler {
 | 
			
		||||
		return ware(next.Handle)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,46 @@
 | 
			
		||||
package nnet
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/config"
 | 
			
		||||
	"git.noahlan.cn/noahlan/ntools-go/core/nlog"
 | 
			
		||||
	"net"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) ListenTCP(conf config.TCPServerConf) error {
 | 
			
		||||
	err := ngin.setup()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		nlog.Errorf("%s failed to setup server, err:%v", ngin.LogPrefix(), err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	listener, err := net.Listen(conf.Protocol, conf.Addr)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		nlog.Errorf("%s failed to listening at [%s %s] %v", ngin.LogPrefix(), conf.Protocol, conf.Addr, err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	nlog.Infof("%s now listening %s at %s...", ngin.LogPrefix(), conf.Protocol, conf.Addr)
 | 
			
		||||
	defer func() {
 | 
			
		||||
		_ = listener.Close()
 | 
			
		||||
		ngin.Stop()
 | 
			
		||||
	}()
 | 
			
		||||
	for {
 | 
			
		||||
		conn, err := listener.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if errors.Is(err, net.ErrClosed) {
 | 
			
		||||
				nlog.Errorf("%s connection closed, err:%v", ngin.LogPrefix(), err)
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			nlog.Errorf("%s accept connection failed, err:%v", ngin.LogPrefix(), err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		err = ngin.goPool.Submit(func() {
 | 
			
		||||
			ngin.handle(conn)
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			nlog.Errorf("%s submit conn pool err: %ng", ngin.LogPrefix(), err.Error())
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,66 @@
 | 
			
		||||
package nnet
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/config"
 | 
			
		||||
	"git.noahlan.cn/noahlan/nnet/connection"
 | 
			
		||||
	"git.noahlan.cn/noahlan/ntools-go/core/nlog"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ListenWebsocket 开始监听Websocket
 | 
			
		||||
func (ngin *Engine) ListenWebsocket(conf config.WSServerConf) error {
 | 
			
		||||
	err := ngin.setup()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		nlog.Errorf("%s failed to setup server, err:%v", ngin.LogPrefix(), err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	nlog.Infof("%s now listening websocket at %s.", ngin.LogPrefix(), conf.Addr)
 | 
			
		||||
	ngin.upgradeWebsocket(conf)
 | 
			
		||||
	if conf.IsTLS() {
 | 
			
		||||
		if err := http.ListenAndServeTLS(conf.Addr, conf.TLSCertificate, conf.TLSKey, nil); err != nil {
 | 
			
		||||
			nlog.Errorf("%s failed to listening websocket with TLS at %s %v", ngin.LogPrefix(), conf.Addr, err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		if err := http.ListenAndServe(conf.Addr, nil); err != nil {
 | 
			
		||||
			nlog.Errorf("%s failed to listening websocket at %s %v", ngin.LogPrefix(), conf.Addr, err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) handleWS(conn *websocket.Conn) {
 | 
			
		||||
	wsConn := connection.NewWSConn(conn)
 | 
			
		||||
	ngin.handle(wsConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ngin *Engine) upgradeWebsocket(conf config.WSServerConf) {
 | 
			
		||||
	upgrade := websocket.Upgrader{
 | 
			
		||||
		HandshakeTimeout:  conf.HandshakeTimeout,
 | 
			
		||||
		ReadBufferSize:    conf.ReadBufferSize,
 | 
			
		||||
		WriteBufferSize:   conf.WriteBufferSize,
 | 
			
		||||
		CheckOrigin:       conf.CheckOrigin,
 | 
			
		||||
		EnableCompression: conf.Compression,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	path := fmt.Sprintf("/%s", strings.TrimPrefix(conf.Path, "/"))
 | 
			
		||||
	http.HandleFunc(path, func(writer http.ResponseWriter, request *http.Request) {
 | 
			
		||||
		conn, err := upgrade.Upgrade(writer, request, nil)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			nlog.Errorf("%s Upgrade failure, URI=%ng, Error=%ng", ngin.LogPrefix(), request.RequestURI, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = ngin.goPool.Submit(func() {
 | 
			
		||||
			ngin.handleWS(conn)
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			nlog.Errorf("%s submit conn pool err: %v", ngin.LogPrefix(), err.Error())
 | 
			
		||||
			os.Exit(1)
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,38 @@
 | 
			
		||||
package session
 | 
			
		||||
 | 
			
		||||
import "sync/atomic"
 | 
			
		||||
 | 
			
		||||
type IDMgr struct {
 | 
			
		||||
	count int64
 | 
			
		||||
	sid   int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSessionIDMgr() *IDMgr {
 | 
			
		||||
	return &IDMgr{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Increment the connection count
 | 
			
		||||
func (c *IDMgr) Increment() {
 | 
			
		||||
	atomic.AddInt64(&c.count, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Decrement the connection count
 | 
			
		||||
func (c *IDMgr) Decrement() {
 | 
			
		||||
	atomic.AddInt64(&c.count, -1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Count returns the connection numbers in current
 | 
			
		||||
func (c *IDMgr) Count() int64 {
 | 
			
		||||
	return atomic.LoadInt64(&c.count)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Reset the connection service status
 | 
			
		||||
func (c *IDMgr) Reset() {
 | 
			
		||||
	atomic.StoreInt64(&c.count, 0)
 | 
			
		||||
	atomic.StoreInt64(&c.sid, 0)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SessionID returns the session id
 | 
			
		||||
func (c *IDMgr) SessionID() int64 {
 | 
			
		||||
	return atomic.AddInt64(&c.sid, 1)
 | 
			
		||||
}
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue