You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

369 lines
9.0 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package server
import (
"context"
"crypto/tls"
"sync"
"time"
"github.com/noahlann/nnet/internal/connection"
"github.com/noahlann/nnet/internal/logger"
internalrouter "github.com/noahlann/nnet/internal/router"
codecpkg "github.com/noahlann/nnet/pkg/codec"
"github.com/noahlann/nnet/pkg/config"
"github.com/noahlann/nnet/pkg/errors"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
routerpkg "github.com/noahlann/nnet/pkg/router"
"github.com/panjf2000/gnet/v2"
)
// TransportProtocol enumerates the transport-layer protocols a server can
// listen on.
type TransportProtocol int

// Supported transport protocols. The numeric values are assigned by iota in
// declaration order.
const (
	ProtocolTCP TransportProtocol = iota
	ProtocolUDP
	ProtocolUnix
)

// String returns the display name of the protocol ("TCP", "UDP" or "Unix").
// Any value outside the declared constants yields "Unknown".
func (p TransportProtocol) String() string {
	if p == ProtocolTCP {
		return "TCP"
	}
	if p == ProtocolUDP {
		return "UDP"
	}
	if p == ProtocolUnix {
		return "Unix"
	}
	return "Unknown"
}

// NetworkString returns the lowercase network name used when building a gnet
// listen address. Values outside the declared constants fall back to "tcp".
func (p TransportProtocol) NetworkString() string {
	network := "tcp"
	switch p {
	case ProtocolUDP:
		network = "udp"
	case ProtocolUnix:
		network = "unix"
	}
	return network
}

// IsConnectionOriented reports whether the protocol is connection-oriented,
// which is the case for TCP and Unix.
func (p TransportProtocol) IsConnectionOriented() bool {
	switch p {
	case ProtocolTCP, ProtocolUnix:
		return true
	}
	return false
}

// IsDatagram reports whether the protocol is datagram-based, which is the
// case for UDP only.
func (p TransportProtocol) IsDatagram() bool {
	if p != ProtocolUDP {
		return false
	}
	return true
}
// startupWaitTimeout bounds how long Start waits for a boot signal before
// giving up with a start-timeout error.
const startupWaitTimeout = 5 * time.Second

// gnetServer is the unified gnet-based server shared by the TCP, UDP and Unix
// transports.
type gnetServer struct {
config *config.Config
logger logger.Logger
connManager connection.ManagerInterface
router routerpkg.Router
codecRegistry codecpkg.Registry
codecResolverChain *codecpkg.ResolverChain
protocolManager protocolpkg.Manager
eventHandler *unifiedEventHandler
// mu guards started.
mu sync.RWMutex
// ctx and cancel are created in newGnetServer; Stop calls cancel.
ctx context.Context
cancel context.CancelFunc
// started is set once Start has observed a successful boot signal.
started bool
// stopCh is closed by the goroutine launched in Start after gnet.Run returns.
stopCh chan struct{}
// stopped records that stopCh has already been closed; stopMu guards both.
stopped bool
stopMu sync.Mutex
// tlsConfig is nil unless TLS is enabled and its configuration was loaded.
tlsConfig *tls.Config
protocol TransportProtocol // transport-layer protocol type
// bootCh and bootOnce carry the startup result; Start recreates them on each
// call and hands them to eventHandler via setBootSignal.
bootCh chan error
bootOnce *sync.Once
}
// newGnetServer assembles a gnetServer for the given transport protocol.
//
// A nil cfg is replaced by config.DefaultConfig(). The configuration is
// validated first; a validation failure or a TLS loading failure is returned
// as the error and no server is created. The returned server has not been
// started.
func newGnetServer(cfg *config.Config, protocol TransportProtocol) (*gnetServer, error) {
	if cfg == nil {
		cfg = config.DefaultConfig()
	}
	if err := cfg.Validate(); err != nil {
		return nil, err
	}

	// Logger: start from the built-in defaults and replace them wholesale when
	// the configuration carries a logger section.
	logCfg := &logger.Config{
		Level:  "info",
		Format: "text",
		Output: "stdout",
	}
	if lc := cfg.Logger; lc != nil {
		logCfg = &logger.Config{
			Level:  lc.Level,
			Format: lc.Format,
			Output: lc.Output,
		}
	}
	lg := logger.New(logCfg)

	// Connection manager: a non-positive shard count falls back to 16.
	shards := cfg.ConnectionManagerShards
	if shards <= 0 {
		shards = 16
	}
	conns := connection.NewShardedManager(cfg.MaxConnections, shards)

	// Router, codec registry and protocol manager.
	rt := internalrouter.NewRouter()
	registry := initCodecRegistry(cfg)
	protocols := initProtocolManager()

	// Codec resolver chain seeded with the configured default codec.
	helper := newServerConfigHelper(cfg)
	fallbackCodec := helper.DefaultCodec()
	resolvers := codecpkg.NewResolverChain(registry, fallbackCodec)

	// TLS configuration is loaded only when TLS is enabled and configured.
	var tlsCfg *tls.Config
	if cfg.TLSEnabled && cfg.TLS != nil {
		loaded, err := loadTLSConfig(cfg.TLS)
		if err != nil {
			return nil, errors.New("failed to load TLS config").WithCause(err)
		}
		tlsCfg = loaded
		// UDP would need DTLS; for now the TLS config is kept and a notice is logged.
		if protocol == ProtocolUDP {
			lg.Info("DTLS support will be added in future version")
		}
	}

	// Event handler, wired to a fresh boot-signal pair.
	timeout := helper.HandlerTimeout()
	copyHeader := helper.CloneHeader()
	protocolEncode := helper.IsProtocolEncodeEnabled()
	appProtocol := cfg.ApplicationProtocol
	bootSignal := make(chan error, 1)
	bootGuard := &sync.Once{}
	handler := newUnifiedEventHandler(conns, rt, lg, registry, resolvers, fallbackCodec, protocols, appProtocol, protocolEncode, copyHeader, timeout, protocol, bootSignal, bootGuard)

	srv := &gnetServer{
		config:             cfg,
		logger:             lg,
		connManager:        conns,
		router:             rt,
		codecRegistry:      registry,
		codecResolverChain: resolvers,
		protocolManager:    protocols,
		eventHandler:       handler,
		stopCh:             make(chan struct{}),
		tlsConfig:          tlsCfg,
		protocol:           protocol,
		bootCh:             bootSignal,
		bootOnce:           bootGuard,
	}
	srv.ctx, srv.cancel = context.WithCancel(context.Background())
	return srv, nil
}
// Start launches the gnet engine in a background goroutine and blocks until
// the engine reports a boot result or startupWaitTimeout elapses.
//
// It returns errors.ErrServerAlreadyStarted if the server is already marked
// as started, a wrapped error if gnet.Run fails during startup, and a timeout
// error if no boot signal arrives in time. On success the server is marked as
// started.
func (s *gnetServer) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.started {
		return errors.ErrServerAlreadyStarted
	}

	// Build the gnet listen address from the configured address and protocol.
	addr := formatGnetAddr(s.config.Addr, s.protocol.NetworkString())
	s.logger.Info("Starting %s server on %s", s.protocol.String(), addr)

	// Reset the boot synchronizer for this attempt. Locals are kept so the
	// goroutine below never reads the struct fields, which a later Start call
	// may overwrite.
	bootCh := make(chan error, 1)
	bootOnce := &sync.Once{}
	s.bootCh = bootCh
	s.bootOnce = bootOnce
	s.eventHandler.setBootSignal(bootCh, bootOnce)

	// Build gnet options.
	options := []gnet.Option{
		gnet.WithMulticore(s.config.Multicore),
		gnet.WithReadBufferCap(s.config.ReadBufferSize),
		gnet.WithWriteBufferCap(s.config.WriteBufferSize),
	}

	// TLS is not wired into gnet yet; the loaded configuration is only
	// reported here.
	if s.tlsConfig != nil {
		if s.protocol.IsConnectionOriented() {
			// TODO: implement TLS support on top of gnet v2 for TCP/Unix.
			s.logger.Info("TLS enabled, but gnet v2 TLS integration needs to be implemented")
		} else {
			// UDP would require DTLS (e.g. a third-party library such as pion/dtls).
			s.logger.Warn("DTLS support is not yet implemented, TLS config will be ignored for UDP")
		}
	}

	// Run the engine in the background.
	go func() {
		err := gnet.Run(s.eventHandler, addr, options...)
		if err != nil {
			// Publish the failure BEFORE touching s.mu: Start holds s.mu while
			// it waits on bootCh, so locking first would stall this goroutine
			// until Start gives up and every startup failure would surface as
			// a timeout instead of the real error.
			bootOnce.Do(func() {
				select {
				case bootCh <- err:
				default:
				}
			})
			s.logger.Error("Server error: %v", err)

			s.mu.Lock()
			s.started = false
			s.mu.Unlock()
		}

		// Signal that the engine has exited; close stopCh at most once.
		s.stopMu.Lock()
		if !s.stopped {
			close(s.stopCh)
			s.stopped = true
		}
		s.stopMu.Unlock()
	}()

	// Wait for the boot result.
	select {
	case bootErr := <-bootCh:
		if bootErr != nil {
			return errors.New("failed to start server").WithCause(bootErr)
		}
	case <-time.After(startupWaitTimeout):
		return errors.New("server start timeout")
	}

	s.started = true
	s.logger.Info("%s server started on %s", s.protocol.String(), addr)
	return nil
}
// Stop cancels the server context and waits up to five seconds for the
// engine goroutine to signal that it has exited.
//
// It returns errors.ErrServerNotStarted when the server is not marked as
// started. A wait timeout is logged as a warning but is not returned as an
// error; in both cases the server is marked as not started afterwards.
func (s *gnetServer) Stop() error {
	s.mu.Lock()
	running := s.started
	s.mu.Unlock()
	if !running {
		return errors.ErrServerNotStarted
	}

	s.logger.Info("Stopping %s server...", s.protocol.String())

	// Cancel the server context.
	s.cancel()

	// Wait for the engine goroutine to close stopCh, bounded by a timer.
	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()
	select {
	case <-s.stopCh:
		s.logger.Info("%s server stopped", s.protocol.String())
	case <-deadline.C:
		s.logger.Warn("%s server stop timeout", s.protocol.String())
	}

	s.mu.Lock()
	s.started = false
	s.mu.Unlock()
	return nil
}
// formatGnetAddr normalizes addr into the "<protocolStr>://<address>" form
// expected by gnet.
//
// Rules, in order:
//   - addr already starting with "<protocolStr>://" followed by at least one
//     more character is returned unchanged;
//   - for "unix", addr is treated as a socket path and simply prefixed;
//   - an empty addr or a bare ":" selects the default port 6995;
//   - ":port" is prefixed as is;
//   - "host:port" keeps only the part from the last colon onward;
//   - anything without a colon is treated as a bare port.
func formatGnetAddr(addr, protocolStr string) string {
	const defaultPort = ":6995"

	prefix := protocolStr + "://"
	if len(addr) > len(prefix) && addr[:len(prefix)] == prefix {
		return addr
	}

	// Unix domain socket: addr is a filesystem path, so only the scheme is added.
	if protocolStr == "unix" {
		return prefix + addr
	}

	// TCP/UDP. The default-port check must run before the leading-colon check,
	// otherwise a bare ":" would be emitted as "<proto>://:" with no port.
	if addr == "" || addr == ":" {
		return prefix + defaultPort
	}
	if addr[0] == ':' {
		return prefix + addr
	}

	// NOTE(review): the host part of "host:port" is discarded here, so the
	// resulting address carries no host. Confirm this is intended before
	// relying on a host given in the configuration.
	for i := len(addr) - 1; i >= 0; i-- {
		if addr[i] == ':' {
			return prefix + addr[i:]
		}
	}

	// No colon at all: treat the whole value as a port.
	return prefix + ":" + addr
}
// Router returns the router held by the server.
func (s *gnetServer) Router() routerpkg.Router {
return s.router
}
// ConnectionManager returns the connection manager held by the server.
func (s *gnetServer) ConnectionManager() connection.ManagerInterface {
return s.connManager
}
// Config returns the configuration the server was created with. The pointer
// is returned as stored, not copied.
func (s *gnetServer) Config() *config.Config {
return s.config
}
// Started reports whether the server is currently marked as started. The
// flag is read under the read lock.
func (s *gnetServer) Started() bool {
	s.mu.RLock()
	running := s.started
	s.mu.RUnlock()
	return running
}
// CodecRegistry returns the codec registry held by the server.
func (s *gnetServer) CodecRegistry() codecpkg.Registry {
return s.codecRegistry
}
// ProtocolManager returns the protocol manager held by the server.
func (s *gnetServer) ProtocolManager() protocolpkg.Manager {
return s.protocolManager
}