package server

import (
	"context"
	"crypto/tls"
	"strings"
	"sync"
	"time"

	"github.com/noahlann/nnet/internal/connection"
	"github.com/noahlann/nnet/internal/logger"
	internalrouter "github.com/noahlann/nnet/internal/router"
	codecpkg "github.com/noahlann/nnet/pkg/codec"
	"github.com/noahlann/nnet/pkg/config"
	"github.com/noahlann/nnet/pkg/errors"
	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
	routerpkg "github.com/noahlann/nnet/pkg/router"
	"github.com/panjf2000/gnet/v2"
)

// TransportProtocol identifies the transport-layer protocol a server uses.
type TransportProtocol int

const (
	ProtocolTCP TransportProtocol = iota
	ProtocolUDP
	ProtocolUnix
)

// String returns the display name of the protocol ("Unknown" for
// unrecognised values).
func (p TransportProtocol) String() string {
	switch p {
	case ProtocolTCP:
		return "TCP"
	case ProtocolUDP:
		return "UDP"
	case ProtocolUnix:
		return "Unix"
	default:
		return "Unknown"
	}
}

// NetworkString returns the network scheme used when building the gnet
// listen address. Unrecognised values fall back to "tcp".
func (p TransportProtocol) NetworkString() string {
	switch p {
	case ProtocolTCP:
		return "tcp"
	case ProtocolUDP:
		return "udp"
	case ProtocolUnix:
		return "unix"
	default:
		return "tcp"
	}
}

// IsConnectionOriented reports whether the protocol is connection-oriented
// (TCP or Unix).
func (p TransportProtocol) IsConnectionOriented() bool {
	return p == ProtocolTCP || p == ProtocolUnix
}

// IsDatagram reports whether the protocol is datagram-based (UDP).
func (p TransportProtocol) IsDatagram() bool {
	return p == ProtocolUDP
}

// startupWaitTimeout bounds how long Start waits for the boot signal.
const startupWaitTimeout = 5 * time.Second

// gnetServer is the shared gnet-backed server used for TCP, UDP and Unix
// transports.
//
// mu guards started. stopMu guards stopCh and stopped: stopCh is closed by
// the run goroutine when gnet.Run returns, and stopped records that the
// current stopCh has already been closed. bootCh and bootOnce are handed to
// the event handler via setBootSignal and are recreated on every Start.
type gnetServer struct {
	config             *config.Config
	logger             logger.Logger
	connManager        connection.ManagerInterface
	router             routerpkg.Router
	codecRegistry      codecpkg.Registry
	codecResolverChain *codecpkg.ResolverChain
	protocolManager    protocolpkg.Manager
	eventHandler       *unifiedEventHandler
	mu                 sync.RWMutex
	ctx                context.Context
	cancel             context.CancelFunc
	started            bool
	stopCh             chan struct{}
	stopped            bool
	stopMu             sync.Mutex
	tlsConfig          *tls.Config
	protocol           TransportProtocol // transport-layer protocol
	bootCh             chan error
	bootOnce           *sync.Once
}

// newGnetServer builds a gnetServer for the given transport protocol.
//
// A nil cfg is replaced by config.DefaultConfig(). The configuration is
// validated, and the validation error is returned unchanged on failure. When
// TLS is enabled and a TLS section is present, the TLS configuration is
// loaded and a load failure is returned wrapped.
func newGnetServer(cfg *config.Config, protocol TransportProtocol) (*gnetServer, error) {
	if cfg == nil {
		cfg = config.DefaultConfig()
	}
	if err := cfg.Validate(); err != nil {
		return nil, err
	}

	// Logger: start from the defaults and override when a logger section is
	// configured.
	logConfig := &logger.Config{
		Level:  "info",
		Format: "text",
		Output: "stdout",
	}
	if cfg.Logger != nil {
		logConfig = &logger.Config{
			Level:  cfg.Logger.Level,
			Format: cfg.Logger.Format,
			Output: cfg.Logger.Output,
		}
	}
	log := logger.New(logConfig)

	// Connection manager: fall back to 16 shards when the configured count
	// is not positive.
	shardCount := cfg.ConnectionManagerShards
	if shardCount <= 0 {
		shardCount = 16
	}
	connManager := connection.NewShardedManager(cfg.MaxConnections, shardCount)

	// Router, codec registry and protocol manager.
	r := internalrouter.NewRouter()
	codecRegistry := initCodecRegistry(cfg)
	protocolManager := initProtocolManager()

	// Codec resolver chain seeded with the configured default codec.
	configHelper := newServerConfigHelper(cfg)
	defaultCodec := configHelper.DefaultCodec()
	codecResolverChain := codecpkg.NewResolverChain(codecRegistry, defaultCodec)

	// TLS configuration (only when enabled and present).
	var tlsConfig *tls.Config
	if cfg.TLSEnabled && cfg.TLS != nil {
		var err error
		tlsConfig, err = loadTLSConfig(cfg.TLS)
		if err != nil {
			return nil, errors.New("failed to load TLS config").WithCause(err)
		}
		// UDP would need DTLS; for now the TLS config is kept as-is and the
		// limitation is only logged.
		if protocol == ProtocolUDP {
			log.Info("DTLS support will be added in future version")
		}
	}

	// Event handler.
	handlerTimeout := configHelper.HandlerTimeout()
	cloneHeader := configHelper.CloneHeader()
	enableProtocolEncode := configHelper.IsProtocolEncodeEnabled()
	protocolName := cfg.ApplicationProtocol
	bootCh := make(chan error, 1)
	bootOnce := &sync.Once{}
	eventHandler := newUnifiedEventHandler(
		connManager, r, log, codecRegistry, codecResolverChain, defaultCodec,
		protocolManager, protocolName, enableProtocolEncode, cloneHeader,
		handlerTimeout, protocol, bootCh, bootOnce,
	)

	ctx, cancel := context.WithCancel(context.Background())

	return &gnetServer{
		config:             cfg,
		logger:             log,
		connManager:        connManager,
		router:             r,
		codecRegistry:      codecRegistry,
		codecResolverChain: codecResolverChain,
		protocolManager:    protocolManager,
		eventHandler:       eventHandler,
		ctx:                ctx,
		cancel:             cancel,
		stopCh:             make(chan struct{}),
		tlsConfig:          tlsConfig,
		protocol:           protocol,
		bootCh:             bootCh,
		bootOnce:           bootOnce,
	}, nil
}

// Start launches gnet.Run in a background goroutine and waits for the boot
// signal.
//
// It returns errors.ErrServerAlreadyStarted when the server is already
// running, a wrapped error when a non-nil error arrives on the boot channel,
// and a timeout error when no boot signal arrives within startupWaitTimeout.
// s.mu is held for the whole call, so the run goroutine must never need s.mu
// before it has reported a startup failure.
func (s *gnetServer) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.started {
		return errors.ErrServerAlreadyStarted
	}

	// Resolve the listen address.
	addr := formatGnetAddr(s.config.Addr, s.protocol.NetworkString())

	s.logger.Info("Starting %s server on %s", s.protocol.String(), addr)

	// Fresh boot synchronisation for this run. The goroutine below uses the
	// local copies so a later Start cannot swap them out from under it.
	bootCh := make(chan error, 1)
	bootOnce := &sync.Once{}
	s.bootCh = bootCh
	s.bootOnce = bootOnce
	s.eventHandler.setBootSignal(bootCh, bootOnce)

	// Re-arm the stop signal if a previous run already closed it; otherwise
	// a later Stop would see the stale closed channel and return at once.
	// The channel is only replaced when it is already closed, so no waiter
	// can be stranded on the old one.
	s.stopMu.Lock()
	if s.stopped {
		s.stopCh = make(chan struct{})
		s.stopped = false
	}
	s.stopMu.Unlock()

	// gnet options.
	options := []gnet.Option{
		gnet.WithMulticore(s.config.Multicore),
		gnet.WithReadBufferCap(s.config.ReadBufferSize),
		gnet.WithWriteBufferCap(s.config.WriteBufferSize),
	}

	// TLS is not wired into gnet yet; the configuration is only reported.
	if s.tlsConfig != nil {
		if s.protocol.IsConnectionOriented() {
			// TODO: implement TLS support on top of gnet v2 for TCP/Unix.
			s.logger.Info("TLS enabled, but gnet v2 TLS integration needs to be implemented")
		} else {
			// UDP would need DTLS (e.g. via a third-party library).
			s.logger.Warn("DTLS support is not yet implemented, TLS config will be ignored for UDP")
		}
	}

	// Run the engine in the background.
	go func() {
		err := gnet.Run(s.eventHandler, addr, options...)
		if err != nil {
			// Report the failure BEFORE touching s.mu: Start still holds
			// s.mu while it waits on bootCh, so locking first would stall
			// until the startup timeout and hide the real error.
			bootOnce.Do(func() {
				select {
				case bootCh <- err:
				default:
				}
			})

			s.mu.Lock()
			s.started = false
			s.mu.Unlock()

			s.logger.Error("Server error: %v", err)
		}

		// Signal that the run loop has exited, closing stopCh at most once.
		s.stopMu.Lock()
		if !s.stopped {
			close(s.stopCh)
			s.stopped = true
		}
		s.stopMu.Unlock()
	}()

	// Wait for the boot signal or give up after startupWaitTimeout.
	// NOTE(review): on timeout the run goroutine is left running; confirm
	// whether it should be torn down.
	select {
	case bootErr := <-bootCh:
		if bootErr != nil {
			return errors.New("failed to start server").WithCause(bootErr)
		}
	case <-time.After(startupWaitTimeout):
		return errors.New("server start timeout")
	}

	s.started = true
	s.logger.Info("%s server started on %s", s.protocol.String(), addr)
	return nil
}

// Stop cancels the server context and waits up to five seconds for the run
// goroutine to signal that gnet.Run has returned.
//
// It returns errors.ErrServerNotStarted when the server is not running. A
// wait timeout is only logged; the server is marked as not started and nil is
// returned in either case.
// NOTE(review): nothing in this file stops the gnet engine directly;
// presumably shutdown is driven from the context or the event handler
// elsewhere. Confirm.
func (s *gnetServer) Stop() error {
	s.mu.RLock()
	running := s.started
	s.mu.RUnlock()
	if !running {
		return errors.ErrServerNotStarted
	}

	s.logger.Info("Stopping %s server...", s.protocol.String())

	// Cancel the server context.
	s.cancel()

	// Read the channel under stopMu because Start may replace it.
	s.stopMu.Lock()
	stopCh := s.stopCh
	s.stopMu.Unlock()

	// Wait for the run goroutine to finish.
	select {
	case <-stopCh:
		s.logger.Info("%s server stopped", s.protocol.String())
	case <-time.After(5 * time.Second):
		s.logger.Warn("%s server stop timeout", s.protocol.String())
	}

	s.mu.Lock()
	s.started = false
	s.mu.Unlock()

	return nil
}

// formatGnetAddr normalises addr into the "<scheme>://<address>" form that
// is passed to gnet.Run.
//
// Rules:
//   - An address that already carries the requested scheme is returned as is.
//   - For "unix", addr is treated as a socket path and prefixed.
//   - For other schemes, an empty or foreign "xxx://" scheme is stripped.
//   - "" and ":" map to the default port 6995.
//   - A value without a colon is treated as a bare port.
//   - Any other value ("host:port", ":port", "[v6]:port") is kept intact, so
//     a configured host is honoured rather than silently widened to all
//     interfaces.
func formatGnetAddr(addr, protocolStr string) string {
	const defaultPort = ":6995"

	prefix := protocolStr + "://"
	if len(addr) > len(prefix) && strings.HasPrefix(addr, prefix) {
		return addr
	}

	// Unix domain sockets: the address is a filesystem path.
	if protocolStr == "unix" {
		return prefix + strings.TrimPrefix(addr, prefix)
	}

	// TCP/UDP: drop an empty or mismatching scheme before normalising.
	if i := strings.Index(addr, "://"); i >= 0 {
		addr = addr[i+len("://"):]
	}

	switch {
	case addr == "" || addr == ":":
		return prefix + defaultPort
	case !strings.Contains(addr, ":"):
		// Bare port such as "8080".
		return prefix + ":" + addr
	default:
		return prefix + addr
	}
}

// Router returns the server's router.
func (s *gnetServer) Router() routerpkg.Router {
	return s.router
}

// ConnectionManager returns the server's connection manager.
func (s *gnetServer) ConnectionManager() connection.ManagerInterface {
	return s.connManager
}

// Config returns the configuration the server was created with.
func (s *gnetServer) Config() *config.Config {
	return s.config
}

// Started reports whether the server is currently marked as started.
func (s *gnetServer) Started() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.started
}

// CodecRegistry returns the server's codec registry.
func (s *gnetServer) CodecRegistry() codecpkg.Registry {
	return s.codecRegistry
}

// ProtocolManager returns the server's protocol manager.
func (s *gnetServer) ProtocolManager() protocolpkg.Manager {
	return s.protocolManager
}