package server

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"github.com/noahlann/nnet/internal/connection"
	"github.com/noahlann/nnet/internal/logger"
	codecpkg "github.com/noahlann/nnet/pkg/codec"
	"github.com/noahlann/nnet/pkg/config"
	"github.com/noahlann/nnet/pkg/errors"
	"github.com/noahlann/nnet/pkg/health"
	"github.com/noahlann/nnet/pkg/lifecycle"
	metricspkg "github.com/noahlann/nnet/pkg/metrics"
	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
	routerpkg "github.com/noahlann/nnet/pkg/router"
)

// Server is a TCP server built on the shared gnetServer. It adds metrics,
// health checking, and server/connection lifecycle hooks on top of it.
type Server struct {
	*gnetServer
	metrics              metricspkg.Metrics
	healthChecker        health.Checker
	serverLifecycleHooks []lifecycle.ServerLifecycleHook
	connLifecycleHooks   []lifecycle.ConnectionLifecycleHook
}

// NewServer creates a new Server backed by a gnetServer running in TCP mode.
// It returns any error reported while constructing the gnetServer.
func NewServer(cfg *config.Config) (*Server, error) {
	gnetSrv, err := newGnetServer(cfg, ProtocolTCP)
	if err != nil {
		return nil, err
	}
	return newServerWithGnet(gnetSrv, cfg), nil
}

// newServerWithGnet wraps an existing gnetServer in a Server, wiring up a
// fresh metrics instance and a health checker. When cfg is non-nil a
// "connections" health check bounded by cfg.MaxConnections is registered.
func newServerWithGnet(gnetSrv *gnetServer, cfg *config.Config) *Server {
	metrics := metricspkg.NewMetrics()
	healthChecker := health.NewChecker()
	if cfg != nil {
		// The checker was just created, so registering the first check is not
		// expected to conflict; the error is deliberately ignored.
		_ = healthChecker.Register("connections", &connectionHealthCheck{
			connManager: gnetSrv.connManager,
			maxConns:    cfg.MaxConnections,
		})
	}

	s := &Server{
		gnetServer:           gnetSrv,
		metrics:              metrics,
		healthChecker:        healthChecker,
		serverLifecycleHooks: make([]lifecycle.ServerLifecycleHook, 0),
		connLifecycleHooks:   make([]lifecycle.ConnectionLifecycleHook, 0),
	}
	if gnetSrv.eventHandler != nil {
		gnetSrv.eventHandler.connLifecycleHooks = s.connLifecycleHooks
	}
	return s
}

// Start starts the server by delegating to StartAsync.
// NOTE(review): the original doc called this "blocking"; whether it blocks
// depends entirely on gnetServer.Start — confirm.
func (s *Server) Start() error {
	return s.StartAsync()
}

// StartAsync runs every registered OnInit hook (aborting on the first error),
// then every OnStart hook (errors are only logged), and finally starts the
// underlying gnetServer.
func (s *Server) StartAsync() error {
	// Snapshot under the same lock RegisterServerLifecycleHook uses, so a
	// concurrent registration cannot race with the iteration below.
	s.mu.Lock()
	hooks := append([]lifecycle.ServerLifecycleHook(nil), s.serverLifecycleHooks...)
	s.mu.Unlock()

	for _, hook := range hooks {
		if err := hook.OnInit(); err != nil {
			return errors.New("failed to execute OnInit hook").WithCause(err)
		}
	}

	for _, hook := range hooks {
		if err := hook.OnStart(); err != nil {
			s.gnetServer.logger.Error("OnStart hook error: %v", err)
		}
	}

	return s.gnetServer.Start()
}

// Stop shuts the server down gracefully. It waits up to
// config.ShutdownTimeout (30s when unset or non-positive) for open
// connections to drain, force-closing any that remain. It then runs every
// OnStop hook and stops the underlying gnetServer.
//
// It returns errors.ErrServerNotStarted if the server is not running, and a
// wrapped error if stopping the gnetServer fails.
func (s *Server) Stop() error {
	s.mu.Lock()
	if !s.started {
		s.mu.Unlock()
		return errors.ErrServerNotStarted
	}
	hooks := append([]lifecycle.ServerLifecycleHook(nil), s.serverLifecycleHooks...)
	s.mu.Unlock()

	lg := s.gnetServer.logger
	lg.Info("Starting graceful shutdown...")

	// 1. Stop accepting new connections.
	if h := s.gnetServer.eventHandler; h != nil {
		if engine, engineSet := h.getEngine(); engineSet {
			// NOTE(review): the engine is not actually stopped here, so new
			// connections may still be accepted while draining. This is
			// pending a suitable gnet v2 engine-stop API.
			lg.Info("Stopping gnet engine...")
			_ = engine
		}
	}

	// 2. Resolve the shutdown timeout; config may be nil (see newServerWithGnet).
	shutdownTimeout := 30 * time.Second
	if cfg := s.gnetServer.config; cfg != nil && cfg.ShutdownTimeout > 0 {
		shutdownTimeout = cfg.ShutdownTimeout
	}

	// 3. Wait for existing connections to finish, forcing closure on timeout.
	lg.Info("Waiting for connections to close (timeout: %v)...", shutdownTimeout)
	s.awaitConnectionDrain(shutdownTimeout)

	// 4. Run OnStop hooks; failures are logged but do not abort shutdown.
	for _, hook := range hooks {
		if err := hook.OnStop(); err != nil {
			lg.Error("OnStop hook error: %v", err)
		}
	}

	// 5. Stop the underlying gnet server and report failure to the caller.
	if err := s.gnetServer.Stop(); err != nil {
		lg.Error("Failed to stop gnet server: %v", err)
		return errors.New("failed to stop gnet server").WithCause(err)
	}

	lg.Info("Server stopped gracefully")
	return nil
}

// awaitConnectionDrain blocks until the connection manager reports zero
// connections or timeout elapses. On timeout, every remaining connection is
// force-closed.
func (s *Server) awaitConnectionDrain(timeout time.Duration) {
	const pollInterval = 100 * time.Millisecond

	lg := s.gnetServer.logger
	start := time.Now()

	// Check immediately so an idle server does not wait a full poll interval.
	if s.gnetServer.connManager.Count() == 0 {
		lg.Info("All connections closed (took %v)", time.Since(start))
		return
	}

	ticker := time.NewTicker(pollInterval)
	defer ticker.Stop()
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-deadline.C:
			lg.Warn("Shutdown timeout reached, forcing connection closure")
			s.forceCloseAllConnections()
			return
		case <-ticker.C:
			if s.gnetServer.connManager.Count() == 0 {
				lg.Info("All connections closed (took %v)", time.Since(start))
				return
			}
		}
	}
}

// forceCloseAllConnections closes every connection currently tracked by the
// connection manager, logging (but not returning) individual close errors.
func (s *Server) forceCloseAllConnections() {
	conns := s.gnetServer.connManager.GetAll()
	s.gnetServer.logger.Warn("Force closing %d connections", len(conns))
	for _, conn := range conns {
		if err := conn.Close(); err != nil {
			s.gnetServer.logger.Error("Failed to close connection %s: %v", conn.ID(), err)
		}
	}
}

// Router returns the server's router.
func (s *Server) Router() routerpkg.Router {
	return s.gnetServer.router
}

// ConnectionManager returns the server's connection manager.
func (s *Server) ConnectionManager() connection.ManagerInterface {
	return s.gnetServer.connManager
}

// Config returns the server configuration.
func (s *Server) Config() *config.Config {
	return s.gnetServer.config
}

// Logger returns the server logger.
func (s *Server) Logger() logger.Logger {
	return s.gnetServer.logger
}

// Started reports whether the underlying gnetServer has been started.
func (s *Server) Started() bool {
	return s.gnetServer.Started()
}

// Metrics returns the server's metrics collector.
func (s *Server) Metrics() metricspkg.Metrics {
	return s.metrics
}

// HealthChecker returns the server's health checker.
func (s *Server) HealthChecker() health.Checker {
	return s.healthChecker
}

// ExportMetrics writes the current metrics to w in Prometheus text format.
func (s *Server) ExportMetrics(w io.Writer) error {
	exporter := metricspkg.NewPrometheusExporter(s.metrics)
	return exporter.Export(w)
}

// MetricsHandler returns the default HTTP handler that serves metrics.
func (s *Server) MetricsHandler() http.Handler {
	return http.HandlerFunc(s.handleMetrics)
}

// HealthHandler returns the default HTTP handler that serves health checks.
func (s *Server) HealthHandler() http.Handler {
	return http.HandlerFunc(s.handleHealth)
}

// RegisterServerLifecycleHook appends a server lifecycle hook. Hooks
// registered after StartAsync/Stop has taken its snapshot are not run by that call.
func (s *Server) RegisterServerLifecycleHook(hook lifecycle.ServerLifecycleHook) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.serverLifecycleHooks = append(s.serverLifecycleHooks, hook)
}

// RegisterConnectionLifecycleHook appends a connection lifecycle hook and
// publishes the updated list to the event handler, if one exists.
// NOTE(review): the eventHandler field is written under s.mu here; verify the
// event handler reads it under the same lock, otherwise this is a data race.
func (s *Server) RegisterConnectionLifecycleHook(hook lifecycle.ConnectionLifecycleHook) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.connLifecycleHooks = append(s.connLifecycleHooks, hook)
	if s.gnetServer.eventHandler != nil {
		s.gnetServer.eventHandler.connLifecycleHooks = s.connLifecycleHooks
	}
}

// CodecRegistry returns the codec registry of the underlying gnetServer.
func (s *Server) CodecRegistry() codecpkg.Registry {
	return s.gnetServer.CodecRegistry()
}

// ProtocolManager returns the protocol manager of the underlying gnetServer.
func (s *Server) ProtocolManager() protocolpkg.Manager {
	return s.gnetServer.ProtocolManager()
}

// handleMetrics serves metrics in Prometheus text format. The export is
// buffered first so a failure can still be reported with a clean 500 instead
// of being appended to a partially written 200 response.
func (s *Server) handleMetrics(w http.ResponseWriter, _ *http.Request) {
	var buf bytes.Buffer
	if err := s.ExportMetrics(&buf); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "text/plain; version=0.0.4")
	// A write failure here means the client went away; nothing useful to do.
	_, _ = buf.WriteTo(w)
}

// handleHealth runs all health checks and writes a JSON body with the overall
// status and per-check results. It responds 200 for healthy or degraded and
// 503 otherwise.
func (s *Server) handleHealth(w http.ResponseWriter, _ *http.Request) {
	status, results := s.healthChecker.Check()
	w.Header().Set("Content-Type", "application/json")

	response := map[string]interface{}{
		"status":  status,
		"results": results,
	}

	switch status {
	case health.StatusHealthy, health.StatusDegraded:
		w.WriteHeader(http.StatusOK)
	default:
		w.WriteHeader(http.StatusServiceUnavailable)
	}

	// The status line is already sent, so an encode failure can only be logged.
	if err := json.NewEncoder(w).Encode(response); err != nil {
		s.gnetServer.logger.Error("Failed to encode health response: %v", err)
	}
}

// connectionHealthCheck reports health based on the current connection count
// relative to a configured maximum.
type connectionHealthCheck struct {
	connManager connection.ManagerInterface
	maxConns    int
}

// Check returns StatusUnhealthy at >= 90% of maxConns, StatusDegraded at
// >= 70%, and StatusHealthy otherwise. A non-positive maxConns means
// "unlimited", which is always healthy.
func (c *connectionHealthCheck) Check() (health.Status, string) {
	count := c.connManager.Count()
	if c.maxConns <= 0 {
		return health.StatusHealthy, fmt.Sprintf("connections: %d", count)
	}

	percentage := float64(count) / float64(c.maxConns) * 100
	switch {
	case percentage >= 90:
		return health.StatusUnhealthy, fmt.Sprintf("connections at %d%% capacity", int(percentage))
	case percentage >= 70:
		return health.StatusDegraded, fmt.Sprintf("connections at %d%% capacity", int(percentage))
	default:
		return health.StatusHealthy, fmt.Sprintf("connections: %d/%d", count, c.maxConns)
	}
}