package server

import (
	"context"
	"encoding/json"
	"io"
	"net/http"
	"sync"

	"github.com/noahlann/nnet/internal/connection"
	"github.com/noahlann/nnet/internal/logger"
	internalrouter "github.com/noahlann/nnet/internal/router"
	codecpkg "github.com/noahlann/nnet/pkg/codec"
	"github.com/noahlann/nnet/pkg/config"
	"github.com/noahlann/nnet/pkg/errors"
	"github.com/noahlann/nnet/pkg/health"
	metricspkg "github.com/noahlann/nnet/pkg/metrics"
	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
	routerpkg "github.com/noahlann/nnet/pkg/router"
	"go.bug.st/serial"
)

// SerialServer is a server that exchanges messages over a single serial port.
type SerialServer struct {
	config      *config.Config
	logger      logger.Logger
	connManager connection.ManagerInterface
	router      routerpkg.Router
	port        serial.Port

	// mu guards started, port, ctx and cancel.
	mu      sync.RWMutex
	ctx     context.Context
	cancel  context.CancelFunc
	started bool

	codecRegistry      codecpkg.Registry
	codecResolverChain *codecpkg.ResolverChain
	defaultCodec       string
	protocolManager    protocolpkg.Manager
	unpackerManager    *unpackerManager
	metrics            metricspkg.Metrics
	healthChecker      health.Checker
	msgHandler         *messageHandler
}

// NewSerialServer builds a serial server from cfg.
//
// A nil cfg is replaced by config.DefaultConfig(). The configuration is
// validated first and the validation error, if any, is returned unchanged.
// The serial port itself is not opened until Start is called.
func NewSerialServer(cfg *config.Config) (*SerialServer, error) {
	if cfg == nil {
		cfg = config.DefaultConfig()
	}
	if err := cfg.Validate(); err != nil {
		return nil, err
	}

	// Logger: use the configured settings, or info/text/stdout when absent.
	var logConfig *logger.Config
	if cfg.Logger != nil {
		logConfig = &logger.Config{
			Level:  cfg.Logger.Level,
			Format: cfg.Logger.Format,
			Output: cfg.Logger.Output,
		}
	} else {
		logConfig = &logger.Config{
			Level:  "info",
			Format: "text",
			Output: "stdout",
		}
	}
	log := logger.New(logConfig)

	// Connection manager backed by sharded locks.
	shardCount := cfg.ConnectionManagerShards
	if shardCount <= 0 {
		shardCount = 16 // default number of shards
	}
	connManager := connection.NewShardedManager(cfg.MaxConnections, shardCount)

	r := internalrouter.NewRouter()
	codecRegistry := initCodecRegistry(cfg)
	protocolManager := initProtocolManager()

	// Codec resolver chain; callers are expected to add their own resolvers.
	configHelper := newServerConfigHelper(cfg)
	defaultCodec := configHelper.DefaultCodec()
	codecResolverChain := codecpkg.NewResolverChain(codecRegistry, defaultCodec)

	metrics := metricspkg.NewMetrics()
	healthChecker := health.NewChecker()
	// The registration result is deliberately ignored, as in the original code:
	// this is the first check registered on a freshly created checker.
	_ = healthChecker.Register("connections", &connectionHealthCheck{
		connManager: connManager,
		maxConns:    cfg.MaxConnections,
	})

	ctx, cancel := context.WithCancel(context.Background())
	cloneHeader := configHelper.CloneHeader()

	return &SerialServer{
		config:             cfg,
		logger:             log,
		connManager:        connManager,
		router:             r,
		codecRegistry:      codecRegistry,
		codecResolverChain: codecResolverChain,
		defaultCodec:       defaultCodec,
		protocolManager:    protocolManager,
		unpackerManager:    newUnpackerManager(),
		metrics:            metrics,
		healthChecker:      healthChecker,
		ctx:                ctx,
		cancel:             cancel,
		msgHandler:         newMessageHandler(log, codecRegistry, codecResolverChain, defaultCodec, r, cloneHeader),
	}, nil
}

// Start opens the configured serial port and launches the read loop in a
// background goroutine.
//
// It returns errors.ErrServerAlreadyStarted when the server is already
// running, and a wrapped error when the port cannot be opened. Start may be
// called again after Stop; a fresh server context is created in that case.
func (s *SerialServer) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.started {
		return errors.ErrServerAlreadyStarted
	}

	// The address may carry a "serial:" scheme prefix; strip it to get the
	// device name handed to the serial library.
	const addrScheme = "serial:"
	addr := s.config.Addr
	if len(addr) >= len(addrScheme) && addr[:len(addrScheme)] == addrScheme {
		addr = addr[len(addrScheme):]
	}

	// Fall back to 9600 8N1 when no serial settings were supplied.
	serialConfig := s.config.Serial
	if serialConfig == nil {
		serialConfig = &config.SerialConfig{
			BaudRate:     9600,
			DataBits:     8,
			StopBits:     1,
			Parity:       "None",
			ReadTimeout:  1000,
			WriteTimeout: 1000,
		}
	}

	// Data bits must be within 5-8; anything else falls back to 8.
	dataBits := serialConfig.DataBits
	if dataBits < 5 || dataBits > 8 {
		dataBits = 8
	}

	// Stop bits are configured as a count (1 or 2), but serial.StopBits is an
	// enumeration (OneStopBit = 0, OnePointFiveStopBits = 1, TwoStopBits = 2).
	// Map the count explicitly: a plain conversion would turn "1 stop bit"
	// into 1.5 stop bits. Any value other than 2 falls back to one stop bit.
	stopBits := serial.OneStopBit
	if serialConfig.StopBits == 2 {
		stopBits = serial.TwoStopBits
	}

	// Unknown parity names (and "None") select no parity.
	var parity serial.Parity
	switch serialConfig.Parity {
	case "Odd":
		parity = serial.OddParity
	case "Even":
		parity = serial.EvenParity
	case "Mark":
		parity = serial.MarkParity
	case "Space":
		parity = serial.SpaceParity
	default:
		parity = serial.NoParity
	}

	mode := &serial.Mode{
		BaudRate: serialConfig.BaudRate,
		DataBits: dataBits,
		StopBits: stopBits,
		Parity:   parity,
	}

	port, err := serial.Open(addr, mode)
	if err != nil {
		return errors.New("failed to open serial port").WithCause(err)
	}

	// Stop cancels the server context. Without a fresh one, a restarted server
	// would report itself as started while its read loop exits immediately.
	if s.ctx.Err() != nil {
		s.ctx, s.cancel = context.WithCancel(context.Background())
	}

	s.port = port
	s.started = true
	s.logger.Info("Serial server started on %s", addr)

	// Read loop; it terminates when the context is cancelled or a read fails.
	go s.handleConnection()
	return nil
}

// Stop cancels the server context and closes the serial port.
//
// It returns errors.ErrServerNotStarted when the server is not running. A
// failure to close the port is logged rather than returned. Stop does not wait
// for the read loop goroutine to exit.
func (s *SerialServer) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.started {
		return errors.ErrServerNotStarted
	}

	s.logger.Info("Stopping Serial server...")
	s.cancel()

	if s.port != nil {
		if err := s.port.Close(); err != nil {
			s.logger.Error("Failed to close serial port: %v", err)
		}
	}

	s.started = false
	s.logger.Info("Serial server stopped")
	return nil
}

// handleConnection is the read loop for the serial port. It registers the port
// as a connection, reads raw bytes, splits them into messages with the
// connection's unpacker and dispatches every complete message to the message
// handler. It returns when the server context is cancelled or a read fails
// with anything other than io.EOF.
func (s *SerialServer) handleConnection() {
	// Snapshot the context and port this loop belongs to, so that a later
	// Stop/Start cycle cannot redirect an already running loop.
	s.mu.RLock()
	ctx, port := s.ctx, s.port
	s.mu.RUnlock()

	// The ID is left empty here; per the original note, NewSerialConnection
	// generates the connection ID itself.
	serialConn := connection.NewSerialConnection("", port)
	connID := serialConn.ID()

	if err := s.connManager.Add(serialConn); err != nil {
		s.logger.Error("Failed to add connection: %v", err)
		return
	}
	s.metrics.IncConnections()
	defer func() {
		s.connManager.Remove(connID)
		s.unpackerManager.removeUnpacker(connID)
		s.metrics.DecConnections()
	}()

	// Protocol and handler settings do not change while the loop runs, so
	// resolve them once up front.
	configHelper := newServerConfigHelper(s.config)
	handlerTimeout := configHelper.HandlerTimeout()

	var protocol protocolpkg.Protocol
	if configHelper.IsProtocolEncodeEnabled() {
		// The second result is ignored as before; protocol is used as returned.
		protocol, _ = s.protocolManager.Get(s.config.ApplicationProtocol, "")
	}

	unpacker := s.unpackerManager.getOrCreateUnpacker(connID, protocol)

	// Adapt the connection to the interface expected by the handler context.
	ctxConn := toContextConnection(serialConn)

	// A non-positive size would yield a zero-length buffer, making every read
	// return zero bytes and the loop spin.
	bufSize := s.config.ReadBufferSize
	if bufSize <= 0 {
		bufSize = 4096
	}
	buffer := make([]byte, bufSize)

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		n, err := port.Read(buffer)
		if err != nil {
			if err == io.EOF {
				continue
			}
			s.logger.Debug("Serial read error: %v", err)
			return
		}
		if n == 0 {
			continue
		}

		s.metrics.AddBytesReceived(int64(n))

		// Copy the bytes out because buffer is reused by the next read.
		data := make([]byte, n)
		copy(data, buffer[:n])

		messages, _, hasRemaining, err := processDataWithUnpacker(data, unpacker)
		if err != nil {
			s.logger.Error("Unpack error: %v", err)
			s.metrics.IncErrors()
			continue
		}

		// No complete message yet: wait for more data.
		if len(messages) == 0 {
			if !hasRemaining {
				// Nothing buffered either, which may indicate a problem.
				s.logger.Debug("No messages and no remaining data")
			}
			continue
		}

		for _, message := range messages {
			s.metrics.IncRequests()
			// Error metrics are expected to be recorded by the handler itself.
			s.msgHandler.handleMessageWithContext(
				context.Background(),
				ctxConn,
				message,
				protocol,
				s.codecRegistry,
				handlerTimeout,
			)
		}
	}
}

// Router returns the router used to dispatch messages.
func (s *SerialServer) Router() routerpkg.Router {
	return s.router
}

// ConnectionManager returns the connection manager.
func (s *SerialServer) ConnectionManager() connection.ManagerInterface {
	return s.connManager
}

// Config returns the configuration the server was created with.
func (s *SerialServer) Config() *config.Config {
	return s.config
}

// Started reports whether the server is currently started.
func (s *SerialServer) Started() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.started
}

// CodecRegistry returns the codec registry.
func (s *SerialServer) CodecRegistry() codecpkg.Registry {
	return s.codecRegistry
}

// ProtocolManager returns the protocol manager.
func (s *SerialServer) ProtocolManager() protocolpkg.Manager {
	return s.protocolManager
}

// Metrics returns the metrics collector.
func (s *SerialServer) Metrics() metricspkg.Metrics {
	return s.metrics
}

// HealthChecker returns the health checker.
func (s *SerialServer) HealthChecker() health.Checker {
	return s.healthChecker
}

// ExportMetrics writes the collected metrics to w in Prometheus format.
func (s *SerialServer) ExportMetrics(w io.Writer) error {
	exporter := metricspkg.NewPrometheusExporter(s.metrics)
	return exporter.Export(w)
}

// MetricsHandler returns the default HTTP handler that serves the metrics.
func (s *SerialServer) MetricsHandler() http.Handler {
	return http.HandlerFunc(s.handleMetrics)
}

// HealthHandler returns the default HTTP handler that serves the health status.
func (s *SerialServer) HealthHandler() http.Handler {
	return http.HandlerFunc(s.handleHealth)
}

// handleMetrics serves the metrics in Prometheus text format and answers with
// 500 Internal Server Error when the export fails.
func (s *SerialServer) handleMetrics(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/plain; version=0.0.4")
	if err := s.ExportMetrics(w); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}

// handleHealth runs the health checks and writes the overall status and the
// individual results as JSON. Healthy and degraded states answer 200 OK; any
// other state answers 503 Service Unavailable.
func (s *SerialServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	status, results := s.healthChecker.Check()

	w.Header().Set("Content-Type", "application/json")
	response := map[string]interface{}{
		"status":  status,
		"results": results,
	}

	switch status {
	case health.StatusHealthy, health.StatusDegraded:
		w.WriteHeader(http.StatusOK)
	default:
		w.WriteHeader(http.StatusServiceUnavailable)
	}

	// The status line is already sent, so an encoding failure cannot be
	// reported to the client any more.
	_ = json.NewEncoder(w).Encode(response)
}