You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

165 lines
4.6 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package server
import (
"context"
"fmt"
"time"
"github.com/noahlann/nnet/internal/logger"
internalrequest "github.com/noahlann/nnet/internal/request"
codecpkg "github.com/noahlann/nnet/pkg/codec"
ctxpkg "github.com/noahlann/nnet/pkg/context"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
requestpkg "github.com/noahlann/nnet/pkg/request"
routerpkg "github.com/noahlann/nnet/pkg/router"
)
// messageHandler holds the configuration for per-message processing.
// It is shared by the transport-specific handlers through handleMessage and
// handleMessageWithContext.
type messageHandler struct {
	// logger receives debug/warn/error messages emitted while handling messages.
	logger logger.Logger
	// codecRegistry is the registry supplied at construction.
	// NOTE(review): handleMessageWithContext takes its own registry argument
	// instead of reading this field — confirm whether this field is used elsewhere.
	codecRegistry codecpkg.Registry
	// codecResolverChain selects the codec used to decode each incoming message
	// and is also handed to createContext.
	codecResolverChain *codecpkg.ResolverChain
	// defaultCodec is the codec name passed to createContext.
	defaultCodec string
	// router maps a decoded message to its route and handler.
	router routerpkg.Router
	// cloneHeader selects whether the response gets a clone of the request
	// header (true) or shares the same header value (false); it is also passed
	// to createContext.
	cloneHeader bool
}
// newMessageHandler assembles a messageHandler from its dependencies.
//
// Every argument is stored verbatim; no validation or defaulting is done, so
// nil dependencies are kept as nil.
func newMessageHandler(
	logger logger.Logger,
	codecRegistry codecpkg.Registry,
	codecResolverChain *codecpkg.ResolverChain,
	defaultCodec string,
	router routerpkg.Router,
	cloneHeader bool,
) *messageHandler {
	h := new(messageHandler)
	h.logger = logger
	h.codecRegistry = codecRegistry
	h.codecResolverChain = codecResolverChain
	h.defaultCodec = defaultCodec
	h.router = router
	h.cloneHeader = cloneHeader
	return h
}
// handleMessage runs the per-message pipeline shared by the TCP, UDP and
// serial transports:
//
//  1. If protocol is non-nil, parse the frame header into the request and
//     attach that header to the response (cloned when cloneHeader is set,
//     otherwise shared). A header parse failure is only logged at debug level.
//  2. Resolve the decode codec and best-effort pre-decode the request body so
//     the router can match on decoded data.
//  3. Ask the router for a route and handler.
//  4. If a route matched, decode the body for that route.
//  5. Invoke the handler.
//
// It returns a wrapped error when codec resolution, routing, route-specific
// decoding or the handler fails, or when the router yields no handler.
func (mh *messageHandler) handleMessage(
	ctx ctxpkg.Context,
	message []byte,
	protocol protocolpkg.Protocol,
) error {
	// The later steps already tolerate a nil request, so every use below is
	// guarded the same way (the header-copy step previously was not).
	req := ctx.Request()

	// 1. Protocol header decoding.
	if protocol != nil {
		if err := parseProtocolHeader(req, message, protocol); err != nil {
			mh.logger.Debug("Failed to parse protocol header: %v", err)
		} else if req != nil {
			// Give the response the request's header, shared or copied.
			if reqHeader := req.Header(); reqHeader != nil {
				if mh.cloneHeader {
					ctx.Response().SetHeader(reqHeader.Clone())
				} else {
					ctx.Response().SetHeader(reqHeader)
				}
			}
		}
	}

	// 2. Codec resolution and pre-decoding.
	var header protocolpkg.FrameHeader
	if req != nil {
		header = req.Header()
	}
	resolvedCodec, err := mh.codecResolverChain.ResolveForDecode(ctx, message, header)
	if err != nil {
		mh.logger.Error("Failed to resolve codec: %v", err)
		return fmt.Errorf("codec resolve error: %w", err)
	}
	if req != nil {
		bodyBytes := req.DataBytes()
		if len(bodyBytes) == 0 {
			// No separate payload on the request; decode the whole message instead.
			bodyBytes = message
		}
		// Best-effort: a pre-decode failure must not stop processing.
		if err := mh.preDecodeRequestBody(req, bodyBytes, resolvedCodec); err != nil {
			mh.logger.Debug("Failed to pre-decode request body: %v", err)
		}
	}

	// 3. Route matching.
	matchInput := routerpkg.MatchInput{Raw: message}
	if req != nil {
		matchInput.Header = req.Header()
		matchInput.DataBytes = req.DataBytes()
		matchInput.Data = req.Data()
	}
	route, handler, err := mh.router.Match(matchInput, ctx)
	if err != nil {
		// The raw message may be binary or peer-controlled; %q keeps the log
		// line printable and shows control bytes unambiguously.
		mh.logger.Warn("Route not found for: %q", matchInput.Raw)
		return fmt.Errorf("route not found: %w", err)
	}
	if handler == nil {
		// Calling a nil handler would panic; report it as a routing failure.
		mh.logger.Warn("Route matched without handler for: %q", matchInput.Raw)
		return fmt.Errorf("route not found: no handler")
	}

	// 4. Route-specific (strongly typed) decoding. The request is re-read from
	// ctx here, exactly as before, in case matching updated it.
	if route != nil {
		if err := parseRequestBodyWithCodec(ctx.Request(), route, protocol, resolvedCodec); err != nil {
			mh.logger.Error("Failed to parse request body: %v", err)
			return fmt.Errorf("failed to parse request body: %w", err)
		}
	}

	// 5. Handler execution.
	if err := handler(ctx); err != nil {
		mh.logger.Error("Handler error: %v", err)
		return fmt.Errorf("handler error: %w", err)
	}
	return nil
}
// preDecodeRequestBody decodes bodyBytes with codec and stores the result on
// req so that route matching can inspect decoded data.
//
// Behavior:
//   - If req does not expose a setter (AsRequestSetter returns nil), nothing is
//     stored and nil is returned.
//   - Codecs named "binary" or "plain" are not run: the raw bytes are stored
//     as-is.
//   - Any other codec decodes into an untyped value, which is stored on success.
//
// It returns an error when codec is nil or decoding fails. The caller in
// handleMessage only logs this error and continues.
func (mh *messageHandler) preDecodeRequestBody(req requestpkg.Request, bodyBytes []byte, codec codecpkg.Codec) error {
	reqImpl := internalrequest.AsRequestSetter(req)
	if reqImpl == nil {
		return nil
	}
	if codec == nil {
		// A resolver returning (nil, nil) would otherwise panic on codec.Name().
		return fmt.Errorf("pre-decode request body: nil codec")
	}

	name := codec.Name()
	if name == "binary" || name == "plain" {
		reqImpl.SetData(bodyBytes)
		return nil
	}

	// Decode into a generic value; the concrete type is only known once a
	// route has been matched.
	var generic interface{}
	if err := codec.Decode(bodyBytes, &generic); err != nil {
		return fmt.Errorf("pre-decode request body with codec %q: %w", name, err)
	}
	reqImpl.SetData(generic)
	return nil
}
// handleMessageWithContext creates a per-message context via createContext,
// runs handleMessage, and on failure writes a plain-text reply of the form
// "Error: <err>\n" to the response.
//
// Parameters:
//   - parentCtx: parent context for the per-message context.
//   - conn: connection the message arrived on.
//   - message: raw message bytes.
//   - protocol: framing protocol; may be nil, in which case header parsing is skipped.
//   - codecRegistry: registry passed to createContext (used instead of mh.codecRegistry).
//   - handlerTimeout: passed to createContext; presumably bounds handler
//     execution — TODO confirm in createContext.
//
// If createContext returns a non-nil cancel func, it is always called before
// this method returns.
func (mh *messageHandler) handleMessageWithContext(
	parentCtx context.Context,
	conn ctxpkg.Connection,
	message []byte,
	protocol protocolpkg.Protocol,
	codecRegistry codecpkg.Registry,
	handlerTimeout time.Duration,
) {
	ctx, cancel := createContext(parentCtx, conn, message, protocol, codecRegistry,
		mh.codecResolverChain, mh.defaultCodec, mh.cloneHeader, handlerTimeout)
	if cancel != nil {
		defer cancel()
	}

	err := mh.handleMessage(ctx, message, protocol)
	if err == nil {
		return
	}

	// Every failure is reported to the peer; the error type is not inspected.
	// NOTE(review): the reply contains the full wrapped error chain, which may
	// expose internal details to the remote side — consider a sanitized message.
	errorMsg := fmt.Sprintf("Error: %v\n", err)
	if werr := ctx.Response().WriteBytes([]byte(errorMsg)); werr != nil {
		// There is no further way to reach the peer here; record the failure
		// instead of discarding it silently.
		mh.logger.Debug("Failed to write error response: %v", werr)
	}
}