You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

397 lines
11 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package server
import (
"context"
"fmt"
"net"
"sync"
"time"
"github.com/noahlann/nnet/internal/connection"
"github.com/noahlann/nnet/internal/logger"
codecpkg "github.com/noahlann/nnet/pkg/codec"
"github.com/noahlann/nnet/pkg/lifecycle"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
routerpkg "github.com/noahlann/nnet/pkg/router"
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
"github.com/panjf2000/gnet/v2"
)
// unifiedEventHandler is a single gnet event handler shared by TCP/Unix
// (connection-oriented) and UDP (datagram) transports; h.protocol selects
// which code paths are active.
type unifiedEventHandler struct {
// connManager tracks open connections: OnOpen adds, OnClose removes,
// OnTick cleans up inactive ones.
connManager connection.ManagerInterface
// router is handed to the message handler built in newUnifiedEventHandler.
router routerpkg.Router
logger logger.Logger
// codecRegistry, codecResolverChain and defaultCodec configure message
// decoding/encoding in the message handler.
codecRegistry codecpkg.Registry
codecResolverChain *codecpkg.ResolverChain
defaultCodec string
// protocolManager resolves protocol implementations by (protocolName, version).
protocolManager protocolpkg.Manager
protocolName string
// enableProtocolEncode: when false, no protocol or unpacker is resolved for
// connections and version identification is skipped.
enableProtocolEncode bool
// cloneHeader is forwarded to the message handler.
cloneHeader bool
// handlerTimeout is passed to handleMessageWithContext for every message.
handlerTimeout time.Duration
// protocol is the transport kind (connection-oriented vs datagram).
protocol TransportProtocol
// unpackerManager holds per-connection unpackers keyed by connection ID.
unpackerManager *unpackerManager
// connLifecycleHooks run in OnOpen/OnClose; their errors are only logged.
connLifecycleHooks []lifecycle.ConnectionLifecycleHook
// engine/engineSet are recorded in OnBoot and read via getEngine
// (used for graceful shutdown).
engine gnet.Engine
// engineMu guards engine/engineSet; setBootSignal also takes it when
// replacing bootCh/bootOnce.
engineMu sync.RWMutex
engineSet bool
// bootCh receives a single nil (non-blocking send, guarded by bootOnce)
// when the server boots.
bootCh chan error
bootOnce *sync.Once
// msgHandler dispatches each complete inbound message.
msgHandler *messageHandler
}
// newUnifiedEventHandler builds the shared TCP/UDP event handler.
//
// Besides storing its arguments, it creates a fresh unpacker manager, an
// empty (non-nil) lifecycle-hook slice and the message handler that
// dispatches decoded messages. The gnet engine is not known yet; it is
// recorded later in OnBoot.
func newUnifiedEventHandler(connManager connection.ManagerInterface, r routerpkg.Router, logger logger.Logger, codecRegistry codecpkg.Registry, codecResolverChain *codecpkg.ResolverChain, defaultCodec string, protocolManager protocolpkg.Manager, protocolName string, enableProtocolEncode bool, cloneHeader bool, handlerTimeout time.Duration, protocol TransportProtocol, bootCh chan error, bootOnce *sync.Once) *unifiedEventHandler {
	// Build the owned helpers first, in the same order as before.
	unpackers := newUnpackerManager()
	dispatcher := newMessageHandler(logger, codecRegistry, codecResolverChain, defaultCodec, r, cloneHeader)

	return &unifiedEventHandler{
		// Collaborators.
		connManager: connManager,
		router:      r,
		logger:      logger,

		// Codec configuration.
		codecRegistry:      codecRegistry,
		codecResolverChain: codecResolverChain,
		defaultCodec:       defaultCodec,

		// Protocol configuration.
		protocolManager:      protocolManager,
		protocolName:         protocolName,
		enableProtocolEncode: enableProtocolEncode,
		cloneHeader:          cloneHeader,
		handlerTimeout:       handlerTimeout,
		protocol:             protocol,

		// Owned state.
		unpackerManager:    unpackers,
		connLifecycleHooks: make([]lifecycle.ConnectionLifecycleHook, 0),
		msgHandler:         dispatcher,

		// Boot signalling.
		bootCh:   bootCh,
		bootOnce: bootOnce,
	}
}
// OnBoot is called by gnet once the server has started.
//
// It records eng (readable through getEngine) and, if a boot signal was
// configured, sends a single nil on the boot channel. The send is
// non-blocking and runs at most once thanks to the shared sync.Once.
//
// bootCh/bootOnce are read under engineMu, the same lock setBootSignal
// holds while replacing them, so a concurrent setBootSignal cannot race
// with this read.
func (h *unifiedEventHandler) OnBoot(eng gnet.Engine) (action gnet.Action) {
	h.engineMu.Lock()
	h.engine = eng
	h.engineSet = true
	// Snapshot the signal fields while the lock is held.
	bootCh, bootOnce := h.bootCh, h.bootOnce
	h.engineMu.Unlock()

	if bootCh != nil && bootOnce != nil {
		bootOnce.Do(func() {
			select {
			case bootCh <- nil:
			default:
				// Nobody is waiting / channel full: do not block the event loop.
			}
		})
	}
	h.logger.Info("%s server booted", h.protocol.String())
	return gnet.None
}
// setBootSignal installs the channel and once-guard that OnBoot uses to
// report a successful start. Both fields are replaced together while
// engineMu is held.
func (h *unifiedEventHandler) setBootSignal(ch chan error, once *sync.Once) {
	h.engineMu.Lock()
	defer h.engineMu.Unlock()
	h.bootCh, h.bootOnce = ch, once
}
// OnShutdown is called by gnet when the server shuts down; it only logs
// the event for the configured transport.
func (h *unifiedEventHandler) OnShutdown(eng gnet.Engine) {
	transport := h.protocol.String()
	h.logger.Info("%s server shutdown", transport)
}
// OnOpen is called by gnet when a connection-oriented (TCP/Unix) connection
// opens. gnet's Unix and TCP connections share the same interface.
//
// It performs the following steps:
//  1. Wraps c in a connection object and registers it with connManager.
//  2. Runs every lifecycle OnOpen hook. Hook errors are logged and do not
//     reject the connection.
//  3. Attaches per-connection state to c via setConnectionData: the
//     default-version protocol and its unpacker, when protocol encoding
//     is enabled.
//
// Returns gnet.Close if the connection cannot be registered. For datagram
// transports this is a no-op; gnet does not call it for UDP.
func (h *unifiedEventHandler) OnOpen(c gnet.Conn) ([]byte, gnet.Action) {
	if !h.protocol.IsConnectionOriented() {
		return nil, gnet.None
	}
	remoteAddr := c.RemoteAddr().String()
	conn := connection.NewConnection("", c)
	connID := conn.ID()
	if err := h.connManager.Add(conn); err != nil {
		h.logger.Error("Failed to add connection: %v", err)
		return nil, gnet.Close
	}
	for _, hook := range h.connLifecycleHooks {
		if err := hook.OnOpen(connID, remoteAddr); err != nil {
			h.logger.Error("OnOpen hook error: %v", err)
		}
	}

	// Start with the default protocol version (""); handleTraffic may switch
	// to a specific version once it identifies one from the inbound bytes.
	var protocol protocolpkg.Protocol
	if h.enableProtocolEncode {
		var err error
		if protocol, err = h.protocolManager.Get(h.protocolName, ""); err != nil {
			// Previously swallowed; the connection still proceeds without a
			// protocol, but the misconfiguration is now visible.
			h.logger.Debug("Failed to get protocol %s for connection %s: %v", h.protocolName, connID, err)
		}
	}
	var unpacker unpackerpkg.Unpacker
	if protocol != nil {
		unpacker = h.unpackerManager.getOrCreateUnpacker(connID, protocol)
	}
	setConnectionData(c, &connectionData{
		conn:     conn,
		unpacker: unpacker,
		protocol: protocol,
		// Unknown until handleTraffic identifies it.
		protocolVersion: "",
	})
	h.logger.Debug("Connection opened: %s from %s", connID, remoteAddr)
	return nil, gnet.None
}
// OnClose is called by gnet when a connection-oriented (TCP/Unix) connection
// closes; err is the close reason passed through to the lifecycle hooks.
//
// Connections without attached connection data (e.g. rejected in OnOpen)
// are ignored. Otherwise every lifecycle OnClose hook runs, with hook
// errors only logged. The connection is then removed from connManager and
// its unpacker is released. Datagram transports never reach this.
func (h *unifiedEventHandler) OnClose(c gnet.Conn, err error) (action gnet.Action) {
	if !h.protocol.IsConnectionOriented() {
		return gnet.None
	}
	state := getConnectionData(c)
	if state == nil {
		return gnet.None
	}

	id := state.conn.ID()
	for _, lc := range h.connLifecycleHooks {
		if hookErr := lc.OnClose(id, err); hookErr != nil {
			h.logger.Error("OnClose hook error: %v", hookErr)
		}
	}
	h.connManager.Remove(id)
	h.unpackerManager.removeUnpacker(id)
	h.logger.Debug("Connection closed: %s", id)
	return gnet.None
}
// OnTraffic is called by gnet when inbound data is available on c.
//
// gnet v2 has no separate React callback: UDP datagrams are delivered
// through OnTraffic as well. The previous early return for datagram
// transports therefore dropped all UDP traffic.
//
// Dispatch by transport:
//   - Connection-oriented (TCP/Unix): handled by handleTraffic, which peeks
//     and discards gnet's inbound buffer itself.
//   - Datagram: the packet is read with c.Next(-1) and forwarded to React.
//   - Anything else: ignored.
//
// NOTE(review): the UDP packet buffer presumably belongs to gnet and is only
// valid for the duration of this call. This assumes the message handler does
// not retain it asynchronously — confirm.
func (h *unifiedEventHandler) OnTraffic(c gnet.Conn) (action gnet.Action) {
	if h.protocol.IsConnectionOriented() {
		return h.handleTraffic(c, nil)
	}
	if !h.protocol.IsDatagram() {
		return gnet.None
	}
	packet, err := c.Next(-1)
	if err != nil {
		h.logger.Debug("Failed to read UDP packet: %v", err)
		return gnet.None
	}
	// React never produces an out payload; only its action matters.
	_, action = h.React(packet, c)
	return action
}
// React handles a single UDP datagram (datagram transports only).
//
// Non-datagram transports and empty packets are ignored. Otherwise the
// packet is handed to handleTraffic, which also builds the per-packet UDP
// connection. The out slice is always nil; replies, if any, are not
// returned through it.
func (h *unifiedEventHandler) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) {
	if h.protocol.IsDatagram() && len(packet) > 0 {
		return nil, h.handleTraffic(c, packet)
	}
	return nil, gnet.None
}
// handleTraffic is the data-processing path shared by TCP/Unix and UDP.
//
// Datagram transports: udpPacket is treated as one complete message. A fresh
// UDP connection wrapper is built per packet from c.RemoteAddr(), with the
// gnet conn stored on it under the "gnet_conn" attribute. Temporary
// per-packet connection data is built alongside it. Nothing is registered
// with connManager here.
//
// Connection-oriented transports: udpPacket is ignored. The steps are:
//  1. Peek (without consuming) all buffered inbound bytes.
//  2. Retry protocol-version identification on every call until it
//     succeeds.
//  3. Split the data into complete messages with processDataWithUnpacker.
//  4. After dispatch, discard only the bytes reported as processed, so
//     incomplete frames stay in gnet's buffer for the next call.
//
// On an unpack error the whole inbound buffer is discarded and the
// connection is closed.
//
// Every complete message is passed to msgHandler.handleMessageWithContext.
// Returns gnet.Close when connection data is missing or unpacking fails,
// gnet.None otherwise.
func (h *unifiedEventHandler) handleTraffic(c gnet.Conn, udpPacket []byte) gnet.Action {
	var data []byte
	var connData *connectionData
	var connID string
	var conn connection.ConnectionInterface
	if h.protocol.IsDatagram() {
		data = udpPacket
		if len(data) == 0 {
			return gnet.None
		}
		// Build a virtual UDP connection for this packet's sender.
		remoteAddr := c.RemoteAddr()
		if remoteAddr == nil {
			h.logger.Debug("UDP packet with nil remote address, ignoring")
			return gnet.None
		}
		udpAddr, ok := remoteAddr.(*net.UDPAddr)
		if !ok {
			addr, err := net.ResolveUDPAddr("udp", remoteAddr.String())
			if err != nil {
				h.logger.Debug("Failed to resolve UDP address: %v", err)
				return gnet.None
			}
			udpAddr = addr
		}
		udpConn := connection.NewUDPConnection("", udpAddr, nil)
		udpConn.SetAttribute("gnet_conn", c)
		conn = udpConn
		connID = udpConn.ID()

		var protocol protocolpkg.Protocol
		if h.enableProtocolEncode {
			var err error
			if protocol, err = h.protocolManager.Get(h.protocolName, ""); err != nil {
				// Previously swallowed; the packet is still processed without a protocol.
				h.logger.Debug("Failed to get protocol %s: %v", h.protocolName, err)
			}
		}
		connData = &connectionData{
			conn:            conn,
			unpacker:        nil, // datagrams are already complete messages
			protocol:        protocol,
			protocolVersion: "",
		}
	} else {
		// Stream transports: read from gnet's inbound buffer.
		connData = getConnectionData(c)
		if connData == nil {
			h.logger.Error("Connection data not found in context")
			return gnet.Close
		}
		connID = connData.conn.ID()
		conn = connData.conn
		if c.InboundBuffered() == 0 {
			return gnet.None
		}
		peekData, err := c.Peek(-1)
		if len(peekData) == 0 {
			if err != nil {
				h.logger.Debug("Failed to peek inbound data for %s: %v", connID, err)
			}
			return gnet.None
		}
		data = peekData
	}

	protocol := connData.protocol
	unpacker := connData.unpacker

	// Version identification applies to connection-oriented transports only.
	// It is retried on each call while protocolVersion is still empty.
	if h.protocol.IsConnectionOriented() && h.enableProtocolEncode && connData.protocolVersion == "" {
		if identifiedVersion, err := h.identifyProtocolVersion(connData, data, c); err == nil && identifiedVersion != "" {
			newProtocol, err := h.protocolManager.Get(h.protocolName, identifiedVersion)
			if err == nil {
				connData.protocol = newProtocol
				connData.protocolVersion = identifiedVersion
				if newProtocol != nil {
					// NOTE(review): OnOpen may already have created an unpacker for
					// connID with the default protocol. If getOrCreateUnpacker keys
					// only on connID, the old unpacker is kept here — confirm.
					connData.unpacker = h.unpackerManager.getOrCreateUnpacker(connID, newProtocol)
				}
				protocol = newProtocol
				unpacker = connData.unpacker
			}
		}
	}

	// Split into messages (streams only; a datagram is already one message).
	var messages [][]byte
	var totalProcessed int
	if h.protocol.IsConnectionOriented() {
		var err error
		messages, totalProcessed, _, err = processDataWithUnpacker(data, unpacker)
		if err != nil {
			// Previously closed silently; peers can trigger this, so log at Debug.
			h.logger.Debug("Failed to unpack data for connection %s: %v", connID, err)
			c.Discard(c.InboundBuffered())
			return gnet.Close
		}
		if len(messages) == 0 {
			// No complete message yet; gnet keeps the bytes buffered.
			return gnet.None
		}
	} else {
		messages = [][]byte{data}
		totalProcessed = len(data)
	}

	// NOTE(review): for streams the message slices presumably alias the buffer
	// returned by Peek, which gnet reuses after Discard / callback return.
	// This assumes handleMessageWithContext does not retain them — confirm.
	ctxConn := toContextConnection(conn)
	for _, message := range messages {
		h.msgHandler.handleMessageWithContext(
			context.Background(),
			ctxConn,
			message,
			protocol,
			h.codecRegistry,
			h.handlerTimeout,
		)
	}

	// Consume only what was processed (connection-oriented transports only).
	if h.protocol.IsConnectionOriented() && totalProcessed > 0 {
		c.Discard(totalProcessed)
	}
	return gnet.None
}
// OnTick is called by gnet on its timer. Each tick cleans up connections
// that have been inactive for 30s and schedules the next tick in 10s.
func (h *unifiedEventHandler) OnTick() (delay time.Duration, action gnet.Action) {
	const (
		idleLimit    = 30 * time.Second
		tickInterval = 10 * time.Second
	)
	h.connManager.CleanupInactive(idleLimit)
	delay, action = tickInterval, gnet.None
	return delay, action
}
// getEngine returns the gnet engine recorded by OnBoot, used for graceful
// shutdown. The boolean is false until OnBoot has run.
func (h *unifiedEventHandler) getEngine() (gnet.Engine, bool) {
	h.engineMu.RLock()
	eng, booted := h.engine, h.engineSet
	h.engineMu.RUnlock()
	return eng, booted
}
// identifyProtocolVersion determines which version of h.protocolName the
// peer speaks, based on data. It is used for connection-oriented
// transports only.
//
// Strategy, in order:
//  1. Decode data with connData.protocol and read the header's "version"
//     field. If connData.protocol is nil and encoding is enabled, the
//     default version "" is used instead.
//     - A string is used as-is.
//     - A byte b becomes "<b>.0".
//     - Any other value is formatted with %v.
//     The candidate is accepted only if protocolManager knows that version.
//  2. Otherwise, if a version identifier is registered for the protocol,
//     defer to protocolManager.IdentifyVersion.
//
// Returns ("", nil) when no version can be determined. This includes a
// decode failure, an identification failure, or a missing default protocol.
// The only error returned is a failure to fetch the default protocol.
// The c parameter is currently unused.
func (h *unifiedEventHandler) identifyProtocolVersion(connData *connectionData, data []byte, c gnet.Conn) (string, error) {
	proto := connData.protocol
	if proto == nil && h.enableProtocolEncode {
		var lookupErr error
		if proto, lookupErr = h.protocolManager.Get(h.protocolName, ""); lookupErr != nil {
			return "", fmt.Errorf("failed to get default protocol: %w", lookupErr)
		}
		if proto == nil {
			return "", nil
		}
	}

	// Step 1: version carried in the decoded header.
	if proto != nil {
		header, _, decodeErr := proto.Decode(data)
		if decodeErr != nil {
			// A decode failure is reported as "no version", not as an error.
			return "", nil
		}
		if header != nil {
			if raw := header.Get("version"); raw != nil {
				var candidate string
				switch typed := raw.(type) {
				case string:
					candidate = typed
				case byte:
					candidate = fmt.Sprintf("%d.0", typed)
				default:
					candidate = fmt.Sprintf("%v", typed)
				}
				if candidate != "" {
					if _, getErr := h.protocolManager.Get(h.protocolName, candidate); getErr == nil {
						return candidate, nil
					}
				}
			}
		}
	}

	// Step 2: a registered version identifier, if any.
	if h.protocolManager.GetVersionIdentifier(h.protocolName) == nil {
		return "", nil
	}
	detected, identErr := h.protocolManager.IdentifyVersion(h.protocolName, data, context.Background())
	if identErr != nil {
		return "", nil
	}
	return detected, nil
}