|
|
package server
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/internal/connection"
|
|
|
"github.com/noahlann/nnet/internal/logger"
|
|
|
codecpkg "github.com/noahlann/nnet/pkg/codec"
|
|
|
"github.com/noahlann/nnet/pkg/lifecycle"
|
|
|
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
|
|
|
routerpkg "github.com/noahlann/nnet/pkg/router"
|
|
|
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
|
|
|
"github.com/panjf2000/gnet/v2"
|
|
|
)
|
|
|
|
|
|
// unifiedEventHandler is the gnet event handler shared by all transports
// (TCP, Unix and UDP). Which code paths run is decided by the protocol field:
// connection-oriented transports use OnOpen/OnClose/OnTraffic with
// per-connection state, datagram transports go through React/handleTraffic
// with a transient connection per packet.
type unifiedEventHandler struct {
	// connManager receives connections in OnOpen (Add), loses them in OnClose
	// (Remove) and is swept periodically in OnTick (CleanupInactive).
	connManager connection.ManagerInterface

	// router is handed to the message handler built in newUnifiedEventHandler.
	router routerpkg.Router

	// logger is used for all lifecycle and error logging of this handler.
	logger logger.Logger

	// codecRegistry is passed to the message handler and to every
	// handleMessageWithContext call.
	codecRegistry codecpkg.Registry

	// codecResolverChain and defaultCodec are passed to the message handler.
	codecResolverChain *codecpkg.ResolverChain

	defaultCodec string

	// protocolManager looks up protocol implementations by (protocolName,
	// version) and can identify a protocol version from raw bytes.
	protocolManager protocolpkg.Manager

	// protocolName is the protocol name used for every protocolManager lookup.
	protocolName string

	// enableProtocolEncode turns on protocol lookup, unpacking and version
	// identification; when false no protocol is attached to connections.
	enableProtocolEncode bool

	// cloneHeader is forwarded to the message handler (NOTE(review): its exact
	// meaning is defined by newMessageHandler, not visible here).
	cloneHeader bool

	// handlerTimeout is passed to handleMessageWithContext for each message.
	handlerTimeout time.Duration

	// protocol is the transport in use; its IsConnectionOriented/IsDatagram
	// results select between the stream and datagram code paths.
	protocol TransportProtocol

	// unpackerManager holds per-connection unpackers keyed by connection ID.
	unpackerManager *unpackerManager

	// connLifecycleHooks are invoked on connection open and close; hook errors
	// are logged, not propagated.
	connLifecycleHooks []lifecycle.ConnectionLifecycleHook

	// engine is captured in OnBoot and returned by getEngine.
	engine gnet.Engine

	// engineMu guards engine and engineSet (and bootCh/bootOnce in setBootSignal).
	engineMu sync.RWMutex

	// engineSet reports whether OnBoot has stored engine yet.
	engineSet bool

	// bootCh receives a single nil (non-blocking send, guarded by bootOnce)
	// when the engine has booted.
	bootCh chan error

	bootOnce *sync.Once

	// msgHandler decodes and dispatches each complete message.
	msgHandler *messageHandler
}
|
|
|
|
|
|
// newUnifiedEventHandler 创建统一的事件处理器
|
|
|
func newUnifiedEventHandler(connManager connection.ManagerInterface, r routerpkg.Router, logger logger.Logger, codecRegistry codecpkg.Registry, codecResolverChain *codecpkg.ResolverChain, defaultCodec string, protocolManager protocolpkg.Manager, protocolName string, enableProtocolEncode bool, cloneHeader bool, handlerTimeout time.Duration, protocol TransportProtocol, bootCh chan error, bootOnce *sync.Once) *unifiedEventHandler {
|
|
|
return &unifiedEventHandler{
|
|
|
connManager: connManager,
|
|
|
router: r,
|
|
|
logger: logger,
|
|
|
codecRegistry: codecRegistry,
|
|
|
codecResolverChain: codecResolverChain,
|
|
|
defaultCodec: defaultCodec,
|
|
|
protocolManager: protocolManager,
|
|
|
protocolName: protocolName,
|
|
|
enableProtocolEncode: enableProtocolEncode,
|
|
|
cloneHeader: cloneHeader,
|
|
|
handlerTimeout: handlerTimeout,
|
|
|
protocol: protocol,
|
|
|
unpackerManager: newUnpackerManager(),
|
|
|
connLifecycleHooks: make([]lifecycle.ConnectionLifecycleHook, 0),
|
|
|
bootCh: bootCh,
|
|
|
bootOnce: bootOnce,
|
|
|
msgHandler: newMessageHandler(logger, codecRegistry, codecResolverChain, defaultCodec, r, cloneHeader),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// OnBoot 服务器启动时调用
|
|
|
func (h *unifiedEventHandler) OnBoot(eng gnet.Engine) (action gnet.Action) {
|
|
|
h.engineMu.Lock()
|
|
|
h.engine = eng
|
|
|
h.engineSet = true
|
|
|
h.engineMu.Unlock()
|
|
|
if h.bootCh != nil && h.bootOnce != nil {
|
|
|
h.bootOnce.Do(func() {
|
|
|
select {
|
|
|
case h.bootCh <- nil:
|
|
|
default:
|
|
|
}
|
|
|
})
|
|
|
}
|
|
|
h.logger.Info("%s server booted", h.protocol.String())
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
// setBootSignal 设置启动信号通道
|
|
|
func (h *unifiedEventHandler) setBootSignal(ch chan error, once *sync.Once) {
|
|
|
h.engineMu.Lock()
|
|
|
h.bootCh = ch
|
|
|
h.bootOnce = once
|
|
|
h.engineMu.Unlock()
|
|
|
}
|
|
|
|
|
|
// OnShutdown 服务器关闭时调用
|
|
|
func (h *unifiedEventHandler) OnShutdown(eng gnet.Engine) {
|
|
|
h.logger.Info("%s server shutdown", h.protocol.String())
|
|
|
}
|
|
|
|
|
|
// OnOpen 连接打开时调用(TCP和Unix)
|
|
|
func (h *unifiedEventHandler) OnOpen(c gnet.Conn) ([]byte, gnet.Action) {
|
|
|
if !h.protocol.IsConnectionOriented() {
|
|
|
// UDP下不会调用此方法
|
|
|
return nil, gnet.None
|
|
|
}
|
|
|
|
|
|
// TCP/Unix连接处理逻辑(gnet的Unix连接和TCP连接使用相同的接口)
|
|
|
remoteAddr := c.RemoteAddr().String()
|
|
|
conn := connection.NewConnection("", c)
|
|
|
connID := conn.ID()
|
|
|
|
|
|
if err := h.connManager.Add(conn); err != nil {
|
|
|
h.logger.Error("Failed to add connection: %v", err)
|
|
|
return nil, gnet.Close
|
|
|
}
|
|
|
|
|
|
// 执行连接生命周期钩子
|
|
|
for _, hook := range h.connLifecycleHooks {
|
|
|
if err := hook.OnOpen(connID, remoteAddr); err != nil {
|
|
|
h.logger.Error("OnOpen hook error: %v", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
var protocol protocolpkg.Protocol
|
|
|
var protocolVersion string
|
|
|
if h.enableProtocolEncode {
|
|
|
protocol, _ = h.protocolManager.Get(h.protocolName, "")
|
|
|
}
|
|
|
|
|
|
var unpacker unpackerpkg.Unpacker
|
|
|
if protocol != nil {
|
|
|
unpacker = h.unpackerManager.getOrCreateUnpacker(connID, protocol)
|
|
|
}
|
|
|
|
|
|
connData := &connectionData{
|
|
|
conn: conn,
|
|
|
unpacker: unpacker,
|
|
|
protocol: protocol,
|
|
|
protocolVersion: protocolVersion,
|
|
|
}
|
|
|
setConnectionData(c, connData)
|
|
|
|
|
|
h.logger.Debug("Connection opened: %s from %s", connID, remoteAddr)
|
|
|
return nil, gnet.None
|
|
|
}
|
|
|
|
|
|
// OnClose 连接关闭时调用(TCP和Unix)
|
|
|
func (h *unifiedEventHandler) OnClose(c gnet.Conn, err error) (action gnet.Action) {
|
|
|
if !h.protocol.IsConnectionOriented() {
|
|
|
// UDP下不会调用此方法
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
connData := getConnectionData(c)
|
|
|
if connData == nil {
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
connID := connData.conn.ID()
|
|
|
|
|
|
// 执行连接生命周期钩子
|
|
|
for _, hook := range h.connLifecycleHooks {
|
|
|
if hookErr := hook.OnClose(connID, err); hookErr != nil {
|
|
|
h.logger.Error("OnClose hook error: %v", hookErr)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
h.connManager.Remove(connID)
|
|
|
h.unpackerManager.removeUnpacker(connID)
|
|
|
|
|
|
h.logger.Debug("Connection closed: %s", connID)
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
// OnTraffic 数据到达时调用(TCP和Unix)
|
|
|
func (h *unifiedEventHandler) OnTraffic(c gnet.Conn) (action gnet.Action) {
|
|
|
if !h.protocol.IsConnectionOriented() {
|
|
|
// UDP下不会调用此方法,使用React代替
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
return h.handleTraffic(c, nil)
|
|
|
}
|
|
|
|
|
|
// React UDP数据到达时调用(仅UDP)
|
|
|
func (h *unifiedEventHandler) React(packet []byte, c gnet.Conn) (out []byte, action gnet.Action) {
|
|
|
if !h.protocol.IsDatagram() {
|
|
|
// TCP/Unix下不会调用此方法
|
|
|
return nil, gnet.None
|
|
|
}
|
|
|
|
|
|
if len(packet) == 0 {
|
|
|
return nil, gnet.None
|
|
|
}
|
|
|
|
|
|
// UDP处理:直接调用handleTraffic,由handleTraffic统一处理UDP连接创建
|
|
|
return nil, h.handleTraffic(c, packet)
|
|
|
}
|
|
|
|
|
|
// handleTraffic 统一的数据处理逻辑(TCP和UDP共享)
|
|
|
func (h *unifiedEventHandler) handleTraffic(c gnet.Conn, udpPacket []byte) gnet.Action {
|
|
|
var data []byte
|
|
|
var connData *connectionData
|
|
|
var connID string
|
|
|
var conn connection.ConnectionInterface
|
|
|
|
|
|
if h.protocol.IsDatagram() {
|
|
|
// UDP:直接使用传入的数据包
|
|
|
data = udpPacket
|
|
|
if len(data) == 0 {
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
// 创建UDP虚拟连接
|
|
|
remoteAddr := c.RemoteAddr()
|
|
|
if remoteAddr == nil {
|
|
|
h.logger.Debug("UDP packet with nil remote address, ignoring")
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
udpAddr, ok := remoteAddr.(*net.UDPAddr)
|
|
|
if !ok {
|
|
|
addr, err := net.ResolveUDPAddr("udp", remoteAddr.String())
|
|
|
if err != nil {
|
|
|
h.logger.Debug("Failed to resolve UDP address: %v", err)
|
|
|
return gnet.None
|
|
|
}
|
|
|
udpAddr = addr
|
|
|
}
|
|
|
|
|
|
udpConn := connection.NewUDPConnection("", udpAddr, nil)
|
|
|
udpConn.SetAttribute("gnet_conn", c)
|
|
|
conn = udpConn
|
|
|
connID = udpConn.ID()
|
|
|
|
|
|
// UDP下创建临时连接数据
|
|
|
var protocol protocolpkg.Protocol
|
|
|
if h.enableProtocolEncode {
|
|
|
protocol, _ = h.protocolManager.Get(h.protocolName, "")
|
|
|
}
|
|
|
connData = &connectionData{
|
|
|
conn: conn,
|
|
|
unpacker: nil, // UDP通常不需要unpacker
|
|
|
protocol: protocol,
|
|
|
protocolVersion: "",
|
|
|
}
|
|
|
} else {
|
|
|
// TCP:从gnet buffer读取数据
|
|
|
connData = getConnectionData(c)
|
|
|
if connData == nil {
|
|
|
h.logger.Error("Connection data not found in context")
|
|
|
return gnet.Close
|
|
|
}
|
|
|
|
|
|
connID = connData.conn.ID()
|
|
|
conn = connData.conn
|
|
|
|
|
|
inBuffer := c.InboundBuffered()
|
|
|
if inBuffer == 0 {
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
peekData, _ := c.Peek(-1)
|
|
|
if len(peekData) == 0 {
|
|
|
return gnet.None
|
|
|
}
|
|
|
data = peekData
|
|
|
}
|
|
|
|
|
|
// 获取协议和unpacker
|
|
|
protocol := connData.protocol
|
|
|
unpacker := connData.unpacker
|
|
|
|
|
|
// 版本识别(仅面向连接的协议,UDP通常不需要)
|
|
|
if h.protocol.IsConnectionOriented() && h.enableProtocolEncode && connData.protocolVersion == "" {
|
|
|
if identifiedVersion, err := h.identifyProtocolVersion(connData, data, c); err == nil && identifiedVersion != "" {
|
|
|
newProtocol, err := h.protocolManager.Get(h.protocolName, identifiedVersion)
|
|
|
if err == nil {
|
|
|
connData.protocol = newProtocol
|
|
|
connData.protocolVersion = identifiedVersion
|
|
|
if newProtocol != nil {
|
|
|
connData.unpacker = h.unpackerManager.getOrCreateUnpacker(connID, newProtocol)
|
|
|
}
|
|
|
protocol = newProtocol
|
|
|
unpacker = connData.unpacker
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 处理数据:拆包(仅面向连接的协议需要,UDP数据包已经是完整的)
|
|
|
var messages [][]byte
|
|
|
var totalProcessed int
|
|
|
if h.protocol.IsConnectionOriented() {
|
|
|
// 使用统一的拆包处理函数
|
|
|
var err error
|
|
|
messages, totalProcessed, _, err = processDataWithUnpacker(data, unpacker)
|
|
|
if err != nil {
|
|
|
c.Discard(c.InboundBuffered())
|
|
|
return gnet.Close
|
|
|
}
|
|
|
if len(messages) == 0 {
|
|
|
// 没有完整消息,等待更多数据(gnet会自动保留数据在缓冲区中)
|
|
|
return gnet.None
|
|
|
}
|
|
|
} else {
|
|
|
// UDP数据包已经是完整的,不需要拆包
|
|
|
messages = [][]byte{data}
|
|
|
totalProcessed = len(data)
|
|
|
}
|
|
|
|
|
|
// 处理每个完整的消息
|
|
|
ctxConn := toContextConnection(conn)
|
|
|
for _, message := range messages {
|
|
|
h.msgHandler.handleMessageWithContext(
|
|
|
context.Background(),
|
|
|
ctxConn,
|
|
|
message,
|
|
|
protocol,
|
|
|
h.codecRegistry,
|
|
|
h.handlerTimeout,
|
|
|
)
|
|
|
}
|
|
|
|
|
|
// 丢弃已处理的数据(仅面向连接的协议)
|
|
|
if h.protocol.IsConnectionOriented() && totalProcessed > 0 {
|
|
|
c.Discard(totalProcessed)
|
|
|
}
|
|
|
|
|
|
return gnet.None
|
|
|
}
|
|
|
|
|
|
// OnTick 定时器触发时调用
|
|
|
func (h *unifiedEventHandler) OnTick() (delay time.Duration, action gnet.Action) {
|
|
|
h.connManager.CleanupInactive(30 * time.Second)
|
|
|
return 10 * time.Second, gnet.None
|
|
|
}
|
|
|
|
|
|
// getEngine 获取engine引用(用于优雅关闭)
|
|
|
func (h *unifiedEventHandler) getEngine() (gnet.Engine, bool) {
|
|
|
h.engineMu.RLock()
|
|
|
defer h.engineMu.RUnlock()
|
|
|
return h.engine, h.engineSet
|
|
|
}
|
|
|
|
|
|
// identifyProtocolVersion 识别协议版本(仅TCP)
|
|
|
func (h *unifiedEventHandler) identifyProtocolVersion(connData *connectionData, data []byte, c gnet.Conn) (string, error) {
|
|
|
protocol := connData.protocol
|
|
|
if protocol == nil && h.enableProtocolEncode {
|
|
|
var err error
|
|
|
protocol, err = h.protocolManager.Get(h.protocolName, "")
|
|
|
if err != nil {
|
|
|
return "", fmt.Errorf("failed to get default protocol: %w", err)
|
|
|
}
|
|
|
if protocol == nil {
|
|
|
return "", nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if protocol != nil {
|
|
|
header, _, err := protocol.Decode(data)
|
|
|
if err != nil {
|
|
|
return "", nil
|
|
|
}
|
|
|
if header != nil {
|
|
|
versionVal := header.Get("version")
|
|
|
if versionVal != nil {
|
|
|
var version string
|
|
|
switch v := versionVal.(type) {
|
|
|
case string:
|
|
|
version = v
|
|
|
case byte:
|
|
|
version = fmt.Sprintf("%d.0", v)
|
|
|
default:
|
|
|
version = fmt.Sprintf("%v", v)
|
|
|
}
|
|
|
if version != "" {
|
|
|
_, err := h.protocolManager.Get(h.protocolName, version)
|
|
|
if err == nil {
|
|
|
return version, nil
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
identifier := h.protocolManager.GetVersionIdentifier(h.protocolName)
|
|
|
if identifier == nil {
|
|
|
return "", nil
|
|
|
}
|
|
|
|
|
|
ctx := context.Background()
|
|
|
version, err := h.protocolManager.IdentifyVersion(h.protocolName, data, ctx)
|
|
|
if err != nil {
|
|
|
return "", nil
|
|
|
}
|
|
|
|
|
|
return version, nil
|
|
|
}
|