You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

304 lines
8.3 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package server
import (
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/noahlann/nnet/internal/connection"
"github.com/noahlann/nnet/internal/logger"
codecpkg "github.com/noahlann/nnet/pkg/codec"
"github.com/noahlann/nnet/pkg/config"
"github.com/noahlann/nnet/pkg/errors"
"github.com/noahlann/nnet/pkg/health"
"github.com/noahlann/nnet/pkg/lifecycle"
metricspkg "github.com/noahlann/nnet/pkg/metrics"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
routerpkg "github.com/noahlann/nnet/pkg/router"
)
// Server is the TCP server. It embeds *gnetServer (the shared server
// implementation; NewServer creates it in TCP mode) and adds metrics, health
// checks and lifecycle hooks on top of it. Fields used by Server's methods
// that are not declared here (mu, started, logger, config, router,
// connManager, eventHandler) are promoted from the embedded gnetServer.
type Server struct {
	*gnetServer
	// metrics is the collector returned by Metrics and exported by
	// ExportMetrics / MetricsHandler.
	metrics metricspkg.Metrics
	// healthChecker aggregates health checks; newServerWithGnet registers a
	// "connections" check on it when a config is supplied.
	healthChecker health.Checker
	// serverLifecycleHooks are run by StartAsync (OnInit, OnStart) and Stop
	// (OnStop); RegisterServerLifecycleHook appends to it while holding mu.
	serverLifecycleHooks []lifecycle.ServerLifecycleHook
	// connLifecycleHooks is copied into gnetServer.eventHandler.connLifecycleHooks
	// (when an event handler exists) by newServerWithGnet and
	// RegisterConnectionLifecycleHook.
	connLifecycleHooks []lifecycle.ConnectionLifecycleHook
}
// NewServer creates a Server backed by a gnetServer running in TCP mode
// (ProtocolTCP). Any error from constructing the gnetServer is returned
// unchanged.
func NewServer(cfg *config.Config) (*Server, error) {
	gs, err := newGnetServer(cfg, ProtocolTCP)
	if err != nil {
		return nil, err
	}
	srv := newServerWithGnet(gs, cfg)
	return srv, nil
}
// newServerWithGnet wraps an existing gnetServer in a Server.
//
// It creates a fresh metrics collector and health checker and starts with
// empty (non-nil) hook lists. When the gnetServer has an event handler, the
// connection-hook slice is handed to it. When cfg is non-nil, a
// "connections" health check bounded by cfg.MaxConnections is registered.
func newServerWithGnet(gnetSrv *gnetServer, cfg *config.Config) *Server {
	srv := &Server{
		gnetServer:           gnetSrv,
		metrics:              metricspkg.NewMetrics(),
		healthChecker:        health.NewChecker(),
		serverLifecycleHooks: []lifecycle.ServerLifecycleHook{},
		connLifecycleHooks:   []lifecycle.ConnectionLifecycleHook{},
	}

	if cfg != nil {
		// The registration error is discarded.
		// NOTE(review): confirm this is intentional; a failure would silently
		// leave the server without a connection-count health check.
		_ = srv.healthChecker.Register("connections", &connectionHealthCheck{
			connManager: gnetSrv.connManager,
			maxConns:    cfg.MaxConnections,
		})
	}

	if h := gnetSrv.eventHandler; h != nil {
		h.connLifecycleHooks = srv.connLifecycleHooks
	}
	return srv
}
// Start starts the server by delegating to StartAsync and returns its error.
//
// NOTE(review): Start is described as blocking and StartAsync as
// asynchronous, yet both run exactly the same code. Whether either one blocks
// depends on gnetServer.Start, which is defined elsewhere; confirm the
// intended semantics.
func (s *Server) Start() error {
	err := s.StartAsync()
	return err
}
// StartAsync runs the registered server lifecycle hooks and then starts the
// embedded gnetServer.
//
// The sequence is:
//   - Every hook's OnInit runs, in registration order. The first OnInit error
//     aborts startup (gnetServer.Start is not called) and is returned wrapped.
//   - Every hook's OnStart runs. OnStart errors are logged but do not abort
//     startup.
//   - The result of gnetServer.Start is returned.
//
// NOTE(review): despite the name, whether this returns before the server
// stops depends on gnetServer.Start, which is defined elsewhere.
func (s *Server) StartAsync() error {
	// RegisterServerLifecycleHook appends while holding s.mu, so take a
	// private copy under that lock instead of reading the shared slice
	// unsynchronized. The lock is released before any hook runs: a hook that
	// registers another hook would otherwise deadlock on s.mu.
	s.mu.Lock()
	hooks := append([]lifecycle.ServerLifecycleHook(nil), s.serverLifecycleHooks...)
	s.mu.Unlock()

	for _, hook := range hooks {
		if err := hook.OnInit(); err != nil {
			return errors.New("failed to execute OnInit hook").WithCause(err)
		}
	}

	for _, hook := range hooks {
		if err := hook.OnStart(); err != nil {
			s.gnetServer.logger.Error("OnStart hook error: %v", err)
		}
	}

	return s.gnetServer.Start()
}
// Stop shuts the server down gracefully.
//
// It returns errors.ErrServerNotStarted if the server has not been started.
// Otherwise it:
//  1. looks up the gnet engine. Stopping the engine is not implemented yet,
//     so new connections are not refused during the wait below.
//  2. waits for the connection count to reach zero. It polls every 100ms for
//     at most config.ShutdownTimeout, which defaults to 30s when the value is
//     non-positive or there is no config. Any connection still open at the
//     deadline is force-closed.
//  3. runs every registered OnStop hook, logging (not returning) their errors.
//  4. stops the embedded gnetServer. A failure there is returned wrapped
//     instead of being reported as a graceful stop.
func (s *Server) Stop() error {
	s.mu.Lock()
	if !s.started {
		s.mu.Unlock()
		return errors.ErrServerNotStarted
	}
	// Copy the hooks while holding mu, the lock RegisterServerLifecycleHook
	// appends under, so the OnStop loop below cannot race with registrations.
	hooks := append([]lifecycle.ServerLifecycleHook(nil), s.serverLifecycleHooks...)
	s.mu.Unlock()

	s.gnetServer.logger.Info("Starting graceful shutdown...")

	// 1. Stop accepting new connections by stopping the gnet engine.
	if s.gnetServer.eventHandler != nil {
		if engine, engineSet := s.gnetServer.eventHandler.getEngine(); engineSet {
			// The gnet v2 Engine may not offer a usable Stop method here, so the
			// server has to be stopped by other means (e.g. context cancellation).
			// TODO: actually stop the engine once the gnet API supports it; until
			// then the engine is intentionally unused.
			s.gnetServer.logger.Info("Stopping gnet engine...")
			_ = engine
		}
	}

	// 2. Resolve the shutdown timeout. A missing config falls back to the
	// default instead of dereferencing nil.
	shutdownTimeout := 30 * time.Second
	if cfg := s.gnetServer.config; cfg != nil && cfg.ShutdownTimeout > 0 {
		shutdownTimeout = cfg.ShutdownTimeout
	}

	// 3. Wait for existing connections to finish, bounded by the timeout.
	s.gnetServer.logger.Info("Waiting for connections to close (timeout: %v)...", shutdownTimeout)
	startTime := time.Now()
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	deadline := time.NewTimer(shutdownTimeout)
	defer deadline.Stop()
drain:
	for {
		// Check before waiting, so an already idle server does not sleep for a
		// full tick.
		if s.gnetServer.connManager.Count() == 0 {
			s.gnetServer.logger.Info("All connections closed (took %v)", time.Since(startTime))
			break
		}
		select {
		case <-deadline.C:
			// Timed out: force-close everything that is still open.
			s.gnetServer.logger.Warn("Shutdown timeout reached, forcing connection closure")
			s.forceCloseAllConnections()
			break drain
		case <-ticker.C:
			// Poll the connection count again.
		}
	}

	// 4. Run OnStop hooks. Errors are logged so every hook still gets to run.
	for _, hook := range hooks {
		if err := hook.OnStop(); err != nil {
			s.gnetServer.logger.Error("OnStop hook error: %v", err)
		}
	}

	// 5. Stop the embedded gnetServer. Only report a graceful stop if it
	// actually succeeded.
	if err := s.gnetServer.Stop(); err != nil {
		return errors.New("failed to stop gnet server").WithCause(err)
	}
	s.gnetServer.logger.Info("Server stopped gracefully")
	return nil
}
// forceCloseAllConnections closes every connection currently returned by the
// connection manager. A Close failure is logged and the remaining connections
// are still closed.
func (s *Server) forceCloseAllConnections() {
	lg := s.gnetServer.logger
	all := s.gnetServer.connManager.GetAll()
	lg.Warn("Force closing %d connections", len(all))
	for _, c := range all {
		closeErr := c.Close()
		if closeErr == nil {
			continue
		}
		lg.Error("Failed to close connection %s: %v", c.ID(), closeErr)
	}
}
// Router returns the router held by the embedded gnetServer.
func (s *Server) Router() routerpkg.Router { return s.router }

// ConnectionManager returns the connection manager held by the embedded
// gnetServer.
func (s *Server) ConnectionManager() connection.ManagerInterface { return s.connManager }

// Config returns the configuration held by the embedded gnetServer.
func (s *Server) Config() *config.Config { return s.config }

// Logger returns the logger held by the embedded gnetServer.
func (s *Server) Logger() logger.Logger { return s.logger }

// Started reports whether the embedded gnetServer considers itself started.
func (s *Server) Started() bool {
	// Must name gnetServer explicitly: s.Started() would recurse.
	return s.gnetServer.Started()
}

// Metrics returns the server's metrics collector.
func (s *Server) Metrics() metricspkg.Metrics { return s.metrics }

// HealthChecker returns the server's health checker.
func (s *Server) HealthChecker() health.Checker { return s.healthChecker }
// ExportMetrics writes the server's current metrics to w in Prometheus text
// format and returns any error reported by the exporter.
func (s *Server) ExportMetrics(w io.Writer) error {
	return metricspkg.NewPrometheusExporter(s.metrics).Export(w)
}

// MetricsHandler returns the default HTTP handler that serves the current
// metrics (implemented by handleMetrics).
func (s *Server) MetricsHandler() http.Handler {
	var h http.HandlerFunc = s.handleMetrics
	return h
}

// HealthHandler returns the default HTTP handler that reports health-check
// results (implemented by handleHealth).
func (s *Server) HealthHandler() http.Handler {
	var h http.HandlerFunc = s.handleHealth
	return h
}
// RegisterServerLifecycleHook appends hook to the server lifecycle hooks.
//
// Hooks are run in registration order: StartAsync calls OnInit and OnStart,
// and Stop calls OnStop. A nil hook is ignored, because calling a method on
// a nil interface value would panic during startup or shutdown.
func (s *Server) RegisterServerLifecycleHook(hook lifecycle.ServerLifecycleHook) {
	if hook == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.serverLifecycleHooks = append(s.serverLifecycleHooks, hook)
}
// RegisterConnectionLifecycleHook appends hook to the connection lifecycle
// hooks and, when the gnet event handler exists, hands it the updated slice.
//
// A nil hook is ignored: a nil interface value would panic if the event
// handler invoked it.
//
// NOTE(review): the event handler's slice is replaced while holding s.mu.
// Confirm that the event handler reads connLifecycleHooks safely with respect
// to this concurrent update.
func (s *Server) RegisterConnectionLifecycleHook(hook lifecycle.ConnectionLifecycleHook) {
	if hook == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.connLifecycleHooks = append(s.connLifecycleHooks, hook)
	if s.gnetServer.eventHandler != nil {
		s.gnetServer.eventHandler.connLifecycleHooks = s.connLifecycleHooks
	}
}
// CodecRegistry returns the codec registry of the embedded gnetServer.
func (s *Server) CodecRegistry() codecpkg.Registry {
	inner := s.gnetServer
	return inner.CodecRegistry()
}

// ProtocolManager returns the protocol manager of the embedded gnetServer.
func (s *Server) ProtocolManager() protocolpkg.Manager {
	inner := s.gnetServer
	return inner.ProtocolManager()
}
// handleMetrics serves the current metrics in Prometheus text format with
// Content-Type "text/plain; version=0.0.4". An export error is reported with
// http.Error and status 500.
//
// NOTE(review): metrics are streamed directly into w. If Export has already
// written anything before failing, the 200 status is committed and http.Error
// can only append text. Rendering into a buffer first would avoid this.
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "text/plain; version=0.0.4")
	err := s.ExportMetrics(w)
	if err == nil {
		return
	}
	http.Error(w, err.Error(), http.StatusInternalServerError)
}
// handleHealth serves the aggregated health-check result as JSON.
//
// The body is a JSON object with "status" and "results" keys, taken from
// healthChecker.Check. The HTTP status is 200 when the overall status is
// healthy or degraded, and 503 otherwise. The body is encoded before any
// header is written, so an encoding failure is reported as a 500 instead of
// being silently dropped after a success status has been sent.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	status, results := s.healthChecker.Check()
	body, err := json.Marshal(map[string]interface{}{
		"status":  status,
		"results": results,
	})
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Match json.Encoder.Encode output, which terminates each value with a
	// newline.
	body = append(body, '\n')

	code := http.StatusServiceUnavailable
	switch status {
	case health.StatusHealthy, health.StatusDegraded:
		code = http.StatusOK
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// A write error here means the client went away; nothing more can be
	// reported to it.
	_, _ = w.Write(body)
}
// connectionHealthCheck is a health check based on the current connection
// count relative to a configured maximum (see Check for the thresholds).
type connectionHealthCheck struct {
	// connManager supplies the current connection count.
	connManager connection.ManagerInterface
	// maxConns is the connection limit; a value <= 0 disables the capacity
	// thresholds, and Check then always reports healthy.
	maxConns int
}
// Check reports health based on how full the connection pool is.
//
// Without a positive maxConns limit the result is always healthy. Otherwise
// usage of at least 90% of maxConns is unhealthy, at least 70% is degraded,
// and anything lower is healthy.
func (c *connectionHealthCheck) Check() (health.Status, string) {
	current := c.connManager.Count()
	if c.maxConns <= 0 {
		return health.StatusHealthy, fmt.Sprintf("connections: %d", current)
	}

	// Keep the (count / max) * 100 evaluation order so threshold rounding is
	// unchanged.
	usage := float64(current) / float64(c.maxConns) * 100
	switch {
	case usage >= 90:
		return health.StatusUnhealthy, fmt.Sprintf("connections at %d%% capacity", int(usage))
	case usage >= 70:
		return health.StatusDegraded, fmt.Sprintf("connections at %d%% capacity", int(usage))
	default:
		return health.StatusHealthy, fmt.Sprintf("connections: %d/%d", current, c.maxConns)
	}
}