You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

197 lines
5.7 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package server
import (
	"context"
	"crypto/tls"
	"fmt"
	"time"

	"github.com/noahlann/nnet/internal/request"
	"github.com/noahlann/nnet/internal/response"
	codecpkg "github.com/noahlann/nnet/pkg/codec"
	"github.com/noahlann/nnet/pkg/config"
	ctxpkg "github.com/noahlann/nnet/pkg/context"
	"github.com/noahlann/nnet/pkg/errors"
	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
	unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
)
// ============================================================================
// 配置辅助函数
// ============================================================================
// serverConfigHelper wraps a *config.Config so that server configuration
// lookups (and their fallback defaults) live in one place instead of being
// repeated at every call site.
type serverConfigHelper struct {
	cfg *config.Config
}

// newServerConfigHelper returns a helper around cfg. cfg may be nil; the
// accessor methods check for a nil config before dereferencing it.
func newServerConfigHelper(cfg *config.Config) *serverConfigHelper {
	helper := new(serverConfigHelper)
	helper.cfg = cfg
	return helper
}
// DefaultCodec returns the name of the default codec. It falls back to
// "binary" when there is no config, no codec section, or the configured
// name is empty.
func (h *serverConfigHelper) DefaultCodec() string {
	if h.cfg == nil || h.cfg.Codec == nil || h.cfg.Codec.DefaultCodec == "" {
		return "binary"
	}
	return h.cfg.Codec.DefaultCodec
}
// IsProtocolEncodeEnabled reports whether protocol-layer encoding is turned
// on. It is false when the config or its codec section is missing.
func (h *serverConfigHelper) IsProtocolEncodeEnabled() bool {
	return h.cfg != nil && h.cfg.Codec != nil && h.cfg.Codec.EnableProtocolEncode
}
// CloneHeader reports whether headers should be copied rather than shared.
// It defaults to false (share the reference) when the config or its codec
// section is missing.
func (h *serverConfigHelper) CloneHeader() bool {
	return h.cfg != nil && h.cfg.Codec != nil && h.cfg.Codec.CloneHeader
}
// HandlerTimeout returns the timeout applied to each handler invocation.
//
// Precedence: a positive cfg.HandlerTimeout, then a positive cfg.ReadTimeout,
// otherwise 30 seconds (also used when there is no config at all).
// Non-positive configured values are treated as "unset", so this method never
// returns zero or a negative duration. Configuration alone therefore cannot
// disable the handler timeout through this accessor.
func (h *serverConfigHelper) HandlerTimeout() time.Duration {
	const fallback = 30 * time.Second
	if h.cfg == nil {
		return fallback
	}
	switch {
	case h.cfg.HandlerTimeout > 0:
		return h.cfg.HandlerTimeout
	case h.cfg.ReadTimeout > 0:
		return h.cfg.ReadTimeout
	default:
		return fallback
	}
}
// Config returns the raw *config.Config wrapped by the helper, for callers
// that need settings without a dedicated accessor. It returns nil if the
// helper was built from a nil config.
func (h *serverConfigHelper) Config() *config.Config {
return h.cfg
}
// ============================================================================
// 上下文辅助函数
// ============================================================================
// createContext builds the per-message handler context.
//
// It creates a Request from rawData and protocol, and a Response that writes
// back through conn using codecRegistry, resolverChain and defaultCodec. It
// wraps both in a new ctxpkg.Context. When the Response supports
// SetContext, the new context is attached to it so resolvers can reach it.
//
// If handlerTimeout > 0, the underlying context.Context is derived from
// parentCtx with that timeout; otherwise parentCtx is used as-is. A nil
// parentCtx is replaced by context.Background(), because context.WithTimeout
// panics on a nil parent.
//
// The returned CancelFunc is never nil. With a timeout it releases the
// timeout's resources; without one it is a no-op. Callers should invoke it
// (e.g. via defer) as soon as the handler finishes. It is NOT called
// automatically on garbage collection. An uncalled timeout context only
// releases its resources when the timer fires or the parent is canceled.
//
// cloneHeader is currently not used by this function. It is kept for
// signature compatibility. The response header is not set here; it is set
// after parseProtocolHeader has parsed the request header.
func createContext(
	parentCtx context.Context,
	conn ctxpkg.Connection,
	rawData []byte,
	protocol protocolpkg.Protocol,
	codecRegistry codecpkg.Registry,
	resolverChain *codecpkg.ResolverChain,
	defaultCodec string,
	cloneHeader bool,
	handlerTimeout time.Duration,
) (ctxpkg.Context, context.CancelFunc) {
	req := request.New(rawData, protocol)
	// The response writes through an adapter over conn and resolves codecs via resolverChain.
	resp := response.NewWithResolverChain(newConnectionWriterAdapter(conn), protocol, codecRegistry, resolverChain, defaultCodec)

	if parentCtx == nil {
		parentCtx = context.Background()
	}

	// Start with a no-op cancel so callers can always `defer cancel()` safely.
	ctx := parentCtx
	cancel := context.CancelFunc(func() {})
	if handlerTimeout > 0 {
		ctx, cancel = context.WithTimeout(parentCtx, handlerTimeout)
	}

	ctxObj := ctxpkg.New(ctx, conn, req, resp)

	// Attach the context to the response when it supports it, for use by resolvers.
	if respImpl, ok := resp.(interface{ SetContext(ctxpkg.Context) }); ok {
		respImpl.SetContext(ctxObj)
	}
	return ctxObj, cancel
}
// connectionWriterAdapter exposes a ctxpkg.Connection through a plain
// Write([]byte) error method, so the connection can be handed to the
// response layer as its writer.
type connectionWriterAdapter struct {
	conn ctxpkg.Connection
}

// newConnectionWriterAdapter wraps conn in a connectionWriterAdapter.
func newConnectionWriterAdapter(conn ctxpkg.Connection) *connectionWriterAdapter {
	adapter := new(connectionWriterAdapter)
	adapter.conn = conn
	return adapter
}

// Write forwards data unchanged to the wrapped connection and returns the
// connection's error.
func (w *connectionWriterAdapter) Write(data []byte) error {
	return w.conn.Write(data)
}
// ============================================================================
// 拆包器辅助函数
// ============================================================================
// processDataWithUnpacker splits data into complete messages using unpacker.
//
// It returns, in order:
//   - the complete messages,
//   - the number of bytes consumed (as reported by the unpacker),
//   - whether unconsumed data remains,
//   - any error from the unpacker.
//
// With a nil unpacker, data is returned unchanged as exactly one message with
// all bytes consumed (this applies even when data is empty). When the unpacker
// yields no complete message, it reports 0 bytes consumed and remaining=true,
// signalling the caller to wait for more data. On an unpacker error it
// reports 0 bytes consumed and remaining=false.
func processDataWithUnpacker(data []byte, unpacker unpackerpkg.Unpacker) ([][]byte, int, bool, error) {
	if unpacker == nil {
		return [][]byte{data}, len(data), false, nil
	}

	msgs, rest, used, err := unpacker.Unpack(data)
	switch {
	case err != nil:
		return nil, 0, false, err
	case len(msgs) == 0:
		// No full frame yet: keep everything buffered.
		return nil, 0, true, nil
	default:
		// Trust the unpacker's own consumed count rather than recomputing it.
		return msgs, used, len(rest) > 0, nil
	}
}
// ============================================================================
// TLS辅助函数
// ============================================================================
// loadTLSConfig builds a *tls.Config, intended for use with gnet, from the
// certificate and key files named in cfg.
//
// It returns an error if cfg is nil, if either path is empty, or if the key
// pair cannot be loaded. A load failure is wrapped with both file paths so
// the failing file can be identified. The underlying error stays reachable
// through errors.Is / errors.As.
func loadTLSConfig(cfg *config.TLSConfig) (*tls.Config, error) {
	if cfg == nil {
		return nil, errors.New("TLS config is nil")
	}
	if cfg.CertFile == "" || cfg.KeyFile == "" {
		return nil, errors.New("TLS cert file or key file is empty")
	}
	cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
	if err != nil {
		return nil, fmt.Errorf("load TLS key pair (cert %q, key %q): %w", cfg.CertFile, cfg.KeyFile, err)
	}
	// NOTE(review): MinVersion is left unset, so crypto/tls chooses the
	// minimum protocol version; consider pinning it explicitly.
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
	}, nil
}