You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

412 lines
10 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package server
import (
"context"
"encoding/json"
"io"
"net/http"
"sync"
"github.com/noahlann/nnet/internal/connection"
"github.com/noahlann/nnet/internal/logger"
internalrouter "github.com/noahlann/nnet/internal/router"
codecpkg "github.com/noahlann/nnet/pkg/codec"
"github.com/noahlann/nnet/pkg/config"
"github.com/noahlann/nnet/pkg/errors"
"github.com/noahlann/nnet/pkg/health"
metricspkg "github.com/noahlann/nnet/pkg/metrics"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
routerpkg "github.com/noahlann/nnet/pkg/router"
"go.bug.st/serial"
)
// SerialServer is a server bound to a single serial port. Start opens the
// port and runs a read loop (handleConnection) that registers the port as one
// connection, unpacks incoming bytes into messages and dispatches them through
// msgHandler.
type SerialServer struct {
	// config is the validated server configuration passed to NewSerialServer.
	config *config.Config
	logger logger.Logger
	// connManager tracks the serial connection while the read loop runs.
	connManager connection.ManagerInterface
	// router is the same instance handed to msgHandler at construction.
	router routerpkg.Router
	// port is the open serial device; assigned by Start, closed by Stop.
	port serial.Port
	// mu is held by Start and Stop while they change started and port;
	// Started reads started under the read lock.
	mu sync.RWMutex
	// ctx is cancelled by Stop (via cancel); the read loop returns once it
	// observes ctx.Done().
	ctx     context.Context
	cancel  context.CancelFunc
	started bool
	// Codec lookup: registry, resolver chain (no resolvers pre-registered)
	// and the default codec name from the config helper.
	codecRegistry      codecpkg.Registry
	codecResolverChain *codecpkg.ResolverChain
	defaultCodec       string
	protocolManager    protocolpkg.Manager
	// unpackerManager holds one unpacker per connection ID.
	unpackerManager *unpackerManager
	metrics         metricspkg.Metrics
	// healthChecker has a "connections" check registered by NewSerialServer.
	healthChecker health.Checker
	msgHandler    *messageHandler
}
// NewSerialServer builds a SerialServer from cfg without opening the port.
//
// A nil cfg is replaced with config.DefaultConfig(), and the resulting config
// must pass Validate (its error is returned unchanged). The constructor wires
// a logger, a sharded connection manager, the router, codec registry and
// resolver chain, the protocol manager, metrics and a health checker with a
// "connections" check. Call Start to actually open the serial device.
func NewSerialServer(cfg *config.Config) (*SerialServer, error) {
	if cfg == nil {
		cfg = config.DefaultConfig()
	}
	if err := cfg.Validate(); err != nil {
		return nil, err
	}

	// Logger: text to stdout at info level unless the config supplies one.
	logConfig := &logger.Config{Level: "info", Format: "text", Output: "stdout"}
	if lc := cfg.Logger; lc != nil {
		logConfig = &logger.Config{
			Level:  lc.Level,
			Format: lc.Format,
			Output: lc.Output,
		}
	}
	lg := logger.New(logConfig)

	// Sharded connection manager; 16 shards when the config leaves it unset.
	shards := cfg.ConnectionManagerShards
	if shards <= 0 {
		shards = 16
	}
	conns := connection.NewShardedManager(cfg.MaxConnections, shards)

	rtr := internalrouter.NewRouter()
	registry := initCodecRegistry(cfg)
	protocols := initProtocolManager()

	// The resolver chain starts empty; users register their own resolvers.
	helper := newServerConfigHelper(cfg)
	codecName := helper.DefaultCodec()
	resolverChain := codecpkg.NewResolverChain(registry, codecName)

	stats := metricspkg.NewMetrics()
	checker := health.NewChecker()
	// The Register error is discarded.
	_ = checker.Register("connections", &connectionHealthCheck{
		connManager: conns,
		maxConns:    cfg.MaxConnections,
	})

	ctx, cancel := context.WithCancel(context.Background())
	cloneHeader := helper.CloneHeader()

	srv := &SerialServer{
		config:             cfg,
		logger:             lg,
		connManager:        conns,
		router:             rtr,
		codecRegistry:      registry,
		codecResolverChain: resolverChain,
		defaultCodec:       codecName,
		protocolManager:    protocols,
		unpackerManager:    newUnpackerManager(),
		metrics:            stats,
		healthChecker:      checker,
		ctx:                ctx,
		cancel:             cancel,
		msgHandler:         newMessageHandler(lg, registry, resolverChain, codecName, rtr, cloneHeader),
	}
	return srv, nil
}
// Start opens the configured serial port and launches the read loop
// (handleConnection) in a new goroutine.
//
// Device: config.Addr with an optional "serial:" prefix removed; an empty
// address falls back to "/dev/ttyUSB0".
// Line settings: config.Serial, or 9600 baud 8N1 when it is nil. Data bits
// outside 5-8 become 8, stop bits other than 2 become 1, a non-positive baud
// rate becomes 9600, and an unrecognised parity string means no parity.
//
// Start may be called again after Stop: Stop cancels the server context, so a
// fresh one is created here; otherwise the new read loop would exit at once.
//
// It returns errors.ErrServerAlreadyStarted when already running, or an error
// wrapping the serial.Open failure.
func (s *SerialServer) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.started {
		return errors.ErrServerAlreadyStarted
	}

	// Resolve the device path.
	addr := s.config.Addr
	if len(addr) >= 7 && addr[:7] == "serial:" {
		addr = addr[7:]
	}
	if addr == "" {
		addr = "/dev/ttyUSB0" // default serial device
	}

	// NOTE(review): ReadTimeout/WriteTimeout from the serial config are not
	// applied to the port here.
	serialConfig := s.config.Serial
	if serialConfig == nil {
		serialConfig = &config.SerialConfig{
			BaudRate:     9600,
			DataBits:     8,
			StopBits:     1,
			Parity:       "None",
			ReadTimeout:  1000,
			WriteTimeout: 1000,
		}
	}

	baudRate := serialConfig.BaudRate
	if baudRate <= 0 {
		baudRate = 9600
	}

	// Data bits must be within 5-8.
	dataBits := serialConfig.DataBits
	if dataBits < 5 || dataBits > 8 {
		dataBits = 8
	}

	// serial.StopBits is an enum (OneStopBit=0, OnePointFiveStopBits=1,
	// TwoStopBits=2), not a count: converting the configured number directly
	// would turn "1" into 1.5 stop bits. Map the count instead; anything other
	// than 2 means one stop bit.
	stopBits := serial.OneStopBit
	if serialConfig.StopBits == 2 {
		stopBits = serial.TwoStopBits
	}

	parity := serial.NoParity
	switch serialConfig.Parity {
	case "Odd":
		parity = serial.OddParity
	case "Even":
		parity = serial.EvenParity
	case "Mark":
		parity = serial.MarkParity
	case "Space":
		parity = serial.SpaceParity
	}

	mode := &serial.Mode{
		BaudRate: baudRate,
		DataBits: dataBits,
		StopBits: stopBits,
		Parity:   parity,
	}

	port, err := serial.Open(addr, mode)
	if err != nil {
		return errors.New("failed to open serial port").WithCause(err)
	}

	// Stop cancels s.ctx; replace it so a restarted read loop keeps running.
	if s.ctx == nil || s.ctx.Err() != nil {
		s.ctx, s.cancel = context.WithCancel(context.Background())
	}
	s.port = port
	s.started = true
	s.logger.Info("Serial server started on %s", addr)

	go s.handleConnection()
	return nil
}
// Stop shuts the server down: it cancels the server context, closes the
// serial port and marks the server as not started.
//
// A failure to close the port is logged, not returned. It returns
// errors.ErrServerNotStarted if the server is not running.
// NOTE(review): Stop does not wait for the read goroutine to exit.
func (s *SerialServer) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.started {
		return errors.ErrServerNotStarted
	}
	s.logger.Info("Stopping Serial server...")

	// Cancel before closing so the read loop sees ctx.Done after its pending
	// Read fails on the closed port (assumed behaviour of Close).
	s.cancel()
	if p := s.port; p != nil {
		if closeErr := p.Close(); closeErr != nil {
			s.logger.Error("Failed to close serial port: %v", closeErr)
		}
	}

	s.started = false
	s.logger.Info("Serial server stopped")
	return nil
}
// handleConnection runs the read loop for the serial port opened by Start.
//
// The port is registered with the connection manager as a single connection
// (its ID is generated by connection.NewSerialConnection). When the loop
// exits, the connection and its unpacker are removed again. Each chunk read
// from the port goes through the connection's unpacker, and every complete
// message is passed to msgHandler with the server context, so in-flight
// handlers observe cancellation when Stop is called.
//
// The loop returns when the server context is cancelled or when Read fails
// with an error other than io.EOF.
func (s *SerialServer) handleConnection() {
	// Snapshot this run's context and port under the lock instead of reading
	// the fields on every iteration. Start holds s.mu while spawning this
	// goroutine, so these are the values it just installed.
	s.mu.RLock()
	ctx := s.ctx
	port := s.port
	s.mu.RUnlock()

	serialConn := connection.NewSerialConnection("", port)
	connID := serialConn.ID()

	if err := s.connManager.Add(serialConn); err != nil {
		s.logger.Error("Failed to add connection: %v", err)
		return
	}
	s.metrics.IncConnections()
	defer func() {
		s.connManager.Remove(connID)
		s.unpackerManager.removeUnpacker(connID)
		s.metrics.DecConnections()
	}()

	// Protocol selection and handler timeout are fixed for this loop, so they
	// are resolved once up front.
	configHelper := newServerConfigHelper(s.config)
	handlerTimeout := configHelper.HandlerTimeout()
	var protocol protocolpkg.Protocol
	if configHelper.IsProtocolEncodeEnabled() {
		// NOTE(review): a lookup error is ignored and protocol stays as
		// returned by Get.
		protocol, _ = s.protocolManager.Get(s.config.ApplicationProtocol, "")
	}

	unpacker := s.unpackerManager.getOrCreateUnpacker(connID, protocol)
	ctxConn := toContextConnection(serialConn)

	// A zero-length buffer can never receive data (the loop would spin) and a
	// negative size makes make panic, so fall back to a sane default.
	bufSize := s.config.ReadBufferSize
	if bufSize <= 0 {
		bufSize = 4096
	}
	buffer := make([]byte, bufSize)

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		n, err := port.Read(buffer)
		if err != nil {
			if err == io.EOF {
				continue
			}
			if ctx.Err() != nil {
				// Expected: Stop closed the port underneath the pending Read.
				s.logger.Debug("Serial read error: %v", err)
			} else {
				// Unexpected: the read loop ends while the server still
				// reports itself as started.
				s.logger.Error("Serial read error: %v", err)
			}
			return
		}
		if n == 0 {
			continue
		}
		s.metrics.AddBytesReceived(int64(n))

		// Copy out of the reusable read buffer so the next Read cannot
		// overwrite bytes the unpacker may still hold (assumed; verify
		// processDataWithUnpacker).
		data := make([]byte, n)
		copy(data, buffer[:n])

		messages, _, hasRemaining, err := processDataWithUnpacker(data, unpacker)
		if err != nil {
			s.logger.Error("Unpack error: %v", err)
			s.metrics.IncErrors()
			continue
		}

		// No complete message yet: wait for more data.
		if len(messages) == 0 {
			if !hasRemaining {
				// Nothing buffered either; possibly an error condition.
				s.logger.Debug("No messages and no remaining data")
			}
			continue
		}

		for _, message := range messages {
			s.metrics.IncRequests()
			// Error metrics are counted inside handleMessageWithContext.
			s.msgHandler.handleMessageWithContext(
				ctx,
				ctxConn,
				message,
				protocol,
				s.codecRegistry,
				handlerTimeout,
			)
		}
	}
}
// Router returns the server's router (the instance given to the message
// handler at construction).
func (s *SerialServer) Router() routerpkg.Router { return s.router }

// ConnectionManager returns the manager that tracks the serial connection.
func (s *SerialServer) ConnectionManager() connection.ManagerInterface { return s.connManager }

// Config returns the configuration the server was built with.
func (s *SerialServer) Config() *config.Config { return s.config }

// Started reports whether Start has succeeded without a subsequent Stop.
func (s *SerialServer) Started() bool {
	s.mu.RLock()
	running := s.started
	s.mu.RUnlock()
	return running
}

// CodecRegistry returns the codec registry.
func (s *SerialServer) CodecRegistry() codecpkg.Registry { return s.codecRegistry }

// ProtocolManager returns the protocol manager.
func (s *SerialServer) ProtocolManager() protocolpkg.Manager { return s.protocolManager }

// Metrics returns the metrics collector.
func (s *SerialServer) Metrics() metricspkg.Metrics { return s.metrics }

// HealthChecker returns the health checker.
func (s *SerialServer) HealthChecker() health.Checker { return s.healthChecker }
// ExportMetrics writes the server metrics to w in Prometheus text format and
// returns any error reported by the exporter.
func (s *SerialServer) ExportMetrics(w io.Writer) error {
	return metricspkg.NewPrometheusExporter(s.metrics).Export(w)
}

// MetricsHandler returns an http.Handler that serves the server metrics in
// Prometheus text format.
func (s *SerialServer) MetricsHandler() http.Handler { return http.HandlerFunc(s.handleMetrics) }

// HealthHandler returns an http.Handler that serves the health-check result
// as JSON.
func (s *SerialServer) HealthHandler() http.Handler { return http.HandlerFunc(s.handleHealth) }
// handleMetrics serves the Prometheus text exposition of the server metrics.
// If export fails, it replies with 500 and the error text; part of the
// metrics body may already have been written by then.
func (s *SerialServer) handleMetrics(w http.ResponseWriter, r *http.Request) {
	exp := metricspkg.NewPrometheusExporter(s.metrics)
	w.Header().Set("Content-Type", "text/plain; version=0.0.4")

	exportErr := exp.Export(w)
	if exportErr == nil {
		return
	}
	http.Error(w, exportErr.Error(), http.StatusInternalServerError)
}
// handleHealth runs the health checks and writes {"status", "results"} as
// JSON. Healthy and degraded statuses answer 200; anything else answers 503.
// The JSON encoding error is discarded because the status line has already
// been sent.
func (s *SerialServer) handleHealth(w http.ResponseWriter, r *http.Request) {
	status, results := s.healthChecker.Check()

	code := http.StatusServiceUnavailable
	if status == health.StatusHealthy || status == health.StatusDegraded {
		code = http.StatusOK
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)

	body := map[string]interface{}{
		"status":  status,
		"results": results,
	}
	_ = json.NewEncoder(w).Encode(body)
}