package server import ( "context" "fmt" "net" "sync" "time" "github.com/noahlann/nnet/internal/connection" "github.com/noahlann/nnet/internal/logger" codecpkg "github.com/noahlann/nnet/pkg/codec" "github.com/noahlann/nnet/pkg/lifecycle" protocolpkg "github.com/noahlann/nnet/pkg/protocol" routerpkg "github.com/noahlann/nnet/pkg/router" unpackerpkg "github.com/noahlann/nnet/pkg/unpacker" "github.com/panjf2000/gnet/v2" ) // unifiedEventHandler 统一的事件处理器(支持TCP和UDP) type unifiedEventHandler struct { connManager connection.ManagerInterface router routerpkg.Router logger logger.Logger codecRegistry codecpkg.Registry codecResolverChain *codecpkg.ResolverChain defaultCodec string protocolManager protocolpkg.Manager protocolName string enableProtocolEncode bool cloneHeader bool handlerTimeout time.Duration protocol TransportProtocol unpackerManager *unpackerManager connLifecycleHooks []lifecycle.ConnectionLifecycleHook engine gnet.Engine engineMu sync.RWMutex engineSet bool bootCh chan error bootOnce *sync.Once msgHandler *messageHandler } // newUnifiedEventHandler 创建统一的事件处理器 func newUnifiedEventHandler(connManager connection.ManagerInterface, r routerpkg.Router, logger logger.Logger, codecRegistry codecpkg.Registry, codecResolverChain *codecpkg.ResolverChain, defaultCodec string, protocolManager protocolpkg.Manager, protocolName string, enableProtocolEncode bool, cloneHeader bool, handlerTimeout time.Duration, protocol TransportProtocol, bootCh chan error, bootOnce *sync.Once) *unifiedEventHandler { return &unifiedEventHandler{ connManager: connManager, router: r, logger: logger, codecRegistry: codecRegistry, codecResolverChain: codecResolverChain, defaultCodec: defaultCodec, protocolManager: protocolManager, protocolName: protocolName, enableProtocolEncode: enableProtocolEncode, cloneHeader: cloneHeader, handlerTimeout: handlerTimeout, protocol: protocol, unpackerManager: newUnpackerManager(), connLifecycleHooks: make([]lifecycle.ConnectionLifecycleHook, 0), bootCh: bootCh, bootOnce: bootOnce, msgHandler: newMessageHandler(logger, codecRegistry, codecResolverChain, defaultCodec, r, cloneHeader), } } // OnBoot 服务器启动时调用 func (h *unifiedEventHandler) OnBoot(eng gnet.Engine) (action gnet.Action) { h.engineMu.Lock() h.engine = eng h.engineSet = true h.engineMu.Unlock() if h.bootCh != nil && h.bootOnce != nil { h.bootOnce.Do(func() { select { case h.bootCh <- nil: default: } }) } h.logger.Info("%s server booted", h.protocol.String()) return gnet.None } // setBootSignal 设置启动信号通道 func (h *unifiedEventHandler) setBootSignal(ch chan error, once *sync.Once) { h.engineMu.Lock() h.bootCh = ch h.bootOnce = once h.engineMu.Unlock() } // OnShutdown 服务器关闭时调用 func (h *unifiedEventHandler) OnShutdown(eng gnet.Engine) { h.logger.Info("%s server shutdown", h.protocol.String()) } // OnOpen 连接打开时调用(TCP和Unix) func (h *unifiedEventHandler) OnOpen(c gnet.Conn) ([]byte, gnet.Action) { if !h.protocol.IsConnectionOriented() { // UDP下不会调用此方法 return nil, gnet.None } // TCP/Unix连接处理逻辑(gnet的Unix连接和TCP连接使用相同的接口) remoteAddr := c.RemoteAddr().String() conn := connection.NewConnection("", c) connID := conn.ID() if err := h.connManager.Add(conn); err != nil { h.logger.Error("Failed to add connection: %v", err) return nil, gnet.Close } // 执行连接生命周期钩子 for _, hook := range h.connLifecycleHooks { if err := hook.OnOpen(connID, remoteAddr); err != nil { h.logger.Error("OnOpen hook error: %v", err) } } var protocol protocolpkg.Protocol var protocolVersion string if h.enableProtocolEncode { protocol, _ = h.protocolManager.Get(h.protocolName, "") } var unpacker unpackerpkg.Unpacker if protocol != nil { unpacker = h.unpackerManager.getOrCreateUnpacker(connID, protocol) } connData := &connectionData{ conn: conn, unpacker: unpacker, protocol: protocol, protocolVersion: protocolVersion, } setConnectionData(c, connData) h.logger.Debug("Connection opened: %s from %s", connID, remoteAddr) return nil, gnet.None } // OnClose 连接关闭时调用(TCP和Unix) func (h *unifiedEventHandler) OnClose(c gnet.Conn, err error) (action gnet.Action) { if !h.protocol.IsConnectionOriented() { // UDP下不会调用此方法 return gnet.None } connData := getConnectionData(c) if connData == nil { return gnet.None } connID := connData.conn.ID() // 执行连接生命周期钩子 for _, hook := range h.connLifecycleHooks { if hookErr := hook.OnClose(connID, err); hookErr != nil { h.logger.Error("OnClose hook error: %v", hookErr) } } h.connManager.Remove(connID) h.unpackerManager.removeUnpacker(connID) h.logger.Debug("Connection closed: %s", connID) return gnet.None } // OnTraffic 数据到达时调用(TCP和Unix) func (h *unifiedEventHandler) OnTraffic(c gnet.Conn) (action gnet.Action) { if !h.protocol.IsConnectionOriented() { // UDP下不会调用此方法,使用React代替 return gnet.None } return h.handleTraffic(c, nil) } // React UDP数据到达时调用(仅UDP) func (h *unifiedEventHandler) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) { if !h.protocol.IsDatagram() { // TCP/Unix下不会调用此方法 return nil, gnet.None } if len(packet) == 0 { return nil, gnet.None } // UDP处理:直接调用handleTraffic,由handleTraffic统一处理UDP连接创建 return nil, h.handleTraffic(c, packet) } // handleTraffic 统一的数据处理逻辑(TCP和UDP共享) func (h *unifiedEventHandler) handleTraffic(c gnet.Conn, udpPacket []byte) gnet.Action { var data []byte var connData *connectionData var connID string var conn connection.ConnectionInterface if h.protocol.IsDatagram() { // UDP:直接使用传入的数据包 data = udpPacket if len(data) == 0 { return gnet.None } // 创建UDP虚拟连接 remoteAddr := c.RemoteAddr() if remoteAddr == nil { h.logger.Debug("UDP packet with nil remote address, ignoring") return gnet.None } udpAddr, ok := remoteAddr.(*net.UDPAddr) if !ok { addr, err := net.ResolveUDPAddr("udp", remoteAddr.String()) if err != nil { h.logger.Debug("Failed to resolve UDP address: %v", err) return gnet.None } udpAddr = addr } udpConn := connection.NewUDPConnection("", udpAddr, nil) udpConn.SetAttribute("gnet_conn", c) conn = udpConn connID = udpConn.ID() // UDP下创建临时连接数据 var protocol protocolpkg.Protocol if h.enableProtocolEncode { protocol, _ = h.protocolManager.Get(h.protocolName, "") } connData = &connectionData{ conn: conn, unpacker: nil, // UDP通常不需要unpacker protocol: protocol, protocolVersion: "", } } else { // TCP:从gnet buffer读取数据 connData = getConnectionData(c) if connData == nil { h.logger.Error("Connection data not found in context") return gnet.Close } connID = connData.conn.ID() conn = connData.conn inBuffer := c.InboundBuffered() if inBuffer == 0 { return gnet.None } peekData, _ := c.Peek(-1) if len(peekData) == 0 { return gnet.None } data = peekData } // 获取协议和unpacker protocol := connData.protocol unpacker := connData.unpacker // 版本识别(仅面向连接的协议,UDP通常不需要) if h.protocol.IsConnectionOriented() && h.enableProtocolEncode && connData.protocolVersion == "" { if identifiedVersion, err := h.identifyProtocolVersion(connData, data, c); err == nil && identifiedVersion != "" { newProtocol, err := h.protocolManager.Get(h.protocolName, identifiedVersion) if err == nil { connData.protocol = newProtocol connData.protocolVersion = identifiedVersion if newProtocol != nil { connData.unpacker = h.unpackerManager.getOrCreateUnpacker(connID, newProtocol) } protocol = newProtocol unpacker = connData.unpacker } } } // 处理数据:拆包(仅面向连接的协议需要,UDP数据包已经是完整的) var messages [][]byte var totalProcessed int if h.protocol.IsConnectionOriented() { // 使用统一的拆包处理函数 var err error messages, totalProcessed, _, err = processDataWithUnpacker(data, unpacker) if err != nil { c.Discard(c.InboundBuffered()) return gnet.Close } if len(messages) == 0 { // 没有完整消息,等待更多数据(gnet会自动保留数据在缓冲区中) return gnet.None } } else { // UDP数据包已经是完整的,不需要拆包 messages = [][]byte{data} totalProcessed = len(data) } // 处理每个完整的消息 ctxConn := toContextConnection(conn) for _, message := range messages { h.msgHandler.handleMessageWithContext( context.Background(), ctxConn, message, protocol, h.codecRegistry, h.handlerTimeout, ) } // 丢弃已处理的数据(仅面向连接的协议) if h.protocol.IsConnectionOriented() && totalProcessed > 0 { c.Discard(totalProcessed) } return gnet.None } // OnTick 定时器触发时调用 func (h *unifiedEventHandler) OnTick() (delay time.Duration, action gnet.Action) { h.connManager.CleanupInactive(30 * time.Second) return 10 * time.Second, gnet.None } // getEngine 获取engine引用(用于优雅关闭) func (h *unifiedEventHandler) getEngine() (gnet.Engine, bool) { h.engineMu.RLock() defer h.engineMu.RUnlock() return h.engine, h.engineSet } // identifyProtocolVersion 识别协议版本(仅TCP) func (h *unifiedEventHandler) identifyProtocolVersion(connData *connectionData, data []byte, c gnet.Conn) (string, error) { protocol := connData.protocol if protocol == nil && h.enableProtocolEncode { var err error protocol, err = h.protocolManager.Get(h.protocolName, "") if err != nil { return "", fmt.Errorf("failed to get default protocol: %w", err) } if protocol == nil { return "", nil } } if protocol != nil { header, _, err := protocol.Decode(data) if err != nil { return "", nil } if header != nil { versionVal := header.Get("version") if versionVal != nil { var version string switch v := versionVal.(type) { case string: version = v case byte: version = fmt.Sprintf("%d.0", v) default: version = fmt.Sprintf("%v", v) } if version != "" { _, err := h.protocolManager.Get(h.protocolName, version) if err == nil { return version, nil } } } } } identifier := h.protocolManager.GetVersionIdentifier(h.protocolName) if identifier == nil { return "", nil } ctx := context.Background() version, err := h.protocolManager.IdentifyVersion(h.protocolName, data, ctx) if err != nil { return "", nil } return version, nil }