package server import ( "context" "fmt" "time" "github.com/noahlann/nnet/internal/logger" internalrequest "github.com/noahlann/nnet/internal/request" codecpkg "github.com/noahlann/nnet/pkg/codec" ctxpkg "github.com/noahlann/nnet/pkg/context" protocolpkg "github.com/noahlann/nnet/pkg/protocol" requestpkg "github.com/noahlann/nnet/pkg/request" routerpkg "github.com/noahlann/nnet/pkg/router" ) // messageHandler 消息处理器配置 type messageHandler struct { logger logger.Logger codecRegistry codecpkg.Registry codecResolverChain *codecpkg.ResolverChain defaultCodec string router routerpkg.Router cloneHeader bool } // newMessageHandler 创建消息处理器 func newMessageHandler( logger logger.Logger, codecRegistry codecpkg.Registry, codecResolverChain *codecpkg.ResolverChain, defaultCodec string, router routerpkg.Router, cloneHeader bool, ) *messageHandler { return &messageHandler{ logger: logger, codecRegistry: codecRegistry, codecResolverChain: codecResolverChain, defaultCodec: defaultCodec, router: router, cloneHeader: cloneHeader, } } // handleMessage 处理单个消息(通用函数,TCP/UDP/串口共享) func (mh *messageHandler) handleMessage( ctx ctxpkg.Context, message []byte, protocol protocolpkg.Protocol, ) error { // 1. 协议解码 if protocol != nil { if err := parseProtocolHeader(ctx.Request(), message, protocol); err != nil { mh.logger.Debug("Failed to parse protocol header: %v", err) } else { // 设置响应header(共享或拷贝) if reqHeader := ctx.Request().Header(); reqHeader != nil { if mh.cloneHeader { ctx.Response().SetHeader(reqHeader.Clone()) } else { ctx.Response().SetHeader(reqHeader) } } } } // 2. Codec解析和预解码 var header protocolpkg.FrameHeader if ctx.Request() != nil { header = ctx.Request().Header() } resolvedCodec, err := mh.codecResolverChain.ResolveForDecode(ctx, message, header) if err != nil { mh.logger.Error("Failed to resolve codec: %v", err) return fmt.Errorf("codec resolve error: %w", err) } // 预解码请求体 if reqObj := ctx.Request(); reqObj != nil { bodyBytes := reqObj.DataBytes() if len(bodyBytes) == 0 { bodyBytes = message } if err := mh.preDecodeRequestBody(reqObj, bodyBytes, resolvedCodec); err != nil { mh.logger.Debug("Failed to pre-decode request body: %v", err) // 预解码失败不影响继续处理 } } // 3. 路由匹配 matchInput := routerpkg.MatchInput{Raw: message} if reqObj := ctx.Request(); reqObj != nil { matchInput.Header = reqObj.Header() matchInput.DataBytes = reqObj.DataBytes() matchInput.Data = reqObj.Data() } route, handler, err := mh.router.Match(matchInput, ctx) if err != nil { mh.logger.Warn("Route not found for: %s", string(matchInput.Raw)) return fmt.Errorf("route not found: %w", err) } // 4. 强类型解码(基于路由) if route != nil { if err := parseRequestBodyWithCodec(ctx.Request(), route, protocol, resolvedCodec); err != nil { mh.logger.Error("Failed to parse request body: %v", err) return fmt.Errorf("failed to parse request body: %w", err) } } // 5. 执行handler if err := handler(ctx); err != nil { mh.logger.Error("Handler error: %v", err) return fmt.Errorf("handler error: %w", err) } return nil } // preDecodeRequestBody 预解码请求体(用于路由匹配) func (mh *messageHandler) preDecodeRequestBody(req requestpkg.Request, bodyBytes []byte, codec codecpkg.Codec) error { reqImpl := internalrequest.AsRequestSetter(req) if reqImpl == nil { return nil } name := codec.Name() if name == "binary" || name == "plain" { reqImpl.SetData(bodyBytes) return nil } // 尝试解码为通用类型 var generic interface{} if err := codec.Decode(bodyBytes, &generic); err != nil { return err } reqImpl.SetData(generic) return nil } // handleMessageWithContext 处理消息并创建上下文(带错误响应) func (mh *messageHandler) handleMessageWithContext( parentCtx context.Context, conn ctxpkg.Connection, message []byte, protocol protocolpkg.Protocol, codecRegistry codecpkg.Registry, handlerTimeout time.Duration, ) { ctx, cancel := createContext(parentCtx, conn, message, protocol, codecRegistry, mh.codecResolverChain, mh.defaultCodec, mh.cloneHeader, handlerTimeout) if cancel != nil { defer cancel() } // 处理消息 if err := mh.handleMessage(ctx, message, protocol); err != nil { // 根据错误类型决定是否发送错误响应 errorMsg := fmt.Sprintf("Error: %v\n", err) _ = ctx.Response().WriteBytes([]byte(errorMsg)) } }