|
|
package server
|
|
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"github.com/noahlann/nnet/internal/connection"
	"github.com/noahlann/nnet/internal/logger"
	codecpkg "github.com/noahlann/nnet/pkg/codec"
	"github.com/noahlann/nnet/pkg/config"
	"github.com/noahlann/nnet/pkg/errors"
	"github.com/noahlann/nnet/pkg/health"
	"github.com/noahlann/nnet/pkg/lifecycle"
	metricspkg "github.com/noahlann/nnet/pkg/metrics"
	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
	routerpkg "github.com/noahlann/nnet/pkg/router"
)
|
|
|
|
|
|
// Server is a network server built on the shared gnetServer implementation
// (NewServer constructs it in TCP mode).
//
// It embeds *gnetServer, so gnetServer's fields and methods (for example mu
// and started, used by Stop and the Register* methods) are promoted onto
// Server. On top of that it adds metrics, health checking and lifecycle-hook
// registration.
type Server struct {
	*gnetServer

	// metrics is the collector exposed via Metrics, ExportMetrics and MetricsHandler.
	metrics metricspkg.Metrics

	// healthChecker backs HealthChecker and HealthHandler.
	healthChecker health.Checker

	// serverLifecycleHooks are run by StartAsync (OnInit, OnStart) and Stop (OnStop).
	// RegisterServerLifecycleHook appends to it while holding mu.
	serverLifecycleHooks []lifecycle.ServerLifecycleHook

	// connLifecycleHooks is mirrored into gnetServer.eventHandler when that is non-nil.
	connLifecycleHooks []lifecycle.ConnectionLifecycleHook
}
|
|
|
|
|
|
// NewServer 创建新服务器(基于统一的gnetServer,TCP模式)
|
|
|
func NewServer(cfg *config.Config) (*Server, error) {
|
|
|
// 创建统一的gnet服务器(TCP模式)
|
|
|
gnetSrv, err := newGnetServer(cfg, ProtocolTCP)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
return newServerWithGnet(gnetSrv, cfg), nil
|
|
|
}
|
|
|
|
|
|
func newServerWithGnet(gnetSrv *gnetServer, cfg *config.Config) *Server {
|
|
|
metrics := metricspkg.NewMetrics()
|
|
|
healthChecker := health.NewChecker()
|
|
|
|
|
|
if cfg != nil {
|
|
|
_ = healthChecker.Register("connections", &connectionHealthCheck{
|
|
|
connManager: gnetSrv.connManager,
|
|
|
maxConns: cfg.MaxConnections,
|
|
|
})
|
|
|
}
|
|
|
|
|
|
s := &Server{
|
|
|
gnetServer: gnetSrv,
|
|
|
metrics: metrics,
|
|
|
healthChecker: healthChecker,
|
|
|
serverLifecycleHooks: make([]lifecycle.ServerLifecycleHook, 0),
|
|
|
connLifecycleHooks: make([]lifecycle.ConnectionLifecycleHook, 0),
|
|
|
}
|
|
|
|
|
|
if gnetSrv.eventHandler != nil {
|
|
|
gnetSrv.eventHandler.connLifecycleHooks = s.connLifecycleHooks
|
|
|
}
|
|
|
|
|
|
return s
|
|
|
}
|
|
|
|
|
|
// Start 启动服务器(阻塞)
|
|
|
func (s *Server) Start() error {
|
|
|
return s.StartAsync()
|
|
|
}
|
|
|
|
|
|
// StartAsync 异步启动服务器
|
|
|
func (s *Server) StartAsync() error {
|
|
|
// 执行服务器生命周期钩子
|
|
|
for _, hook := range s.serverLifecycleHooks {
|
|
|
if err := hook.OnInit(); err != nil {
|
|
|
return errors.New("failed to execute OnInit hook").WithCause(err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 执行OnStart钩子
|
|
|
for _, hook := range s.serverLifecycleHooks {
|
|
|
if err := hook.OnStart(); err != nil {
|
|
|
s.gnetServer.logger.Error("OnStart hook error: %v", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 使用gnetServer的Start方法
|
|
|
return s.gnetServer.Start()
|
|
|
}
|
|
|
|
|
|
// Stop 停止服务器(优雅关闭)
|
|
|
func (s *Server) Stop() error {
|
|
|
s.mu.Lock()
|
|
|
if !s.started {
|
|
|
s.mu.Unlock()
|
|
|
return errors.ErrServerNotStarted
|
|
|
}
|
|
|
s.mu.Unlock()
|
|
|
|
|
|
s.gnetServer.logger.Info("Starting graceful shutdown...")
|
|
|
|
|
|
// 1. 停止接受新连接(通过停止gnet引擎)
|
|
|
if s.gnetServer.eventHandler != nil {
|
|
|
engine, engineSet := s.gnetServer.eventHandler.getEngine()
|
|
|
if engineSet {
|
|
|
// gnet v2中,Engine接口可能没有Stop方法
|
|
|
// 我们需要通过其他方式停止服务器
|
|
|
// 暂时使用context取消来停止服务器
|
|
|
s.gnetServer.logger.Info("Stopping gnet engine...")
|
|
|
_ = engine // 暂时不使用,等待gnet API支持
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 2. 获取关闭超时时间
|
|
|
shutdownTimeout := s.gnetServer.config.ShutdownTimeout
|
|
|
if shutdownTimeout <= 0 {
|
|
|
shutdownTimeout = 30 * time.Second
|
|
|
}
|
|
|
|
|
|
// 3. 等待现有连接完成处理(设置超时)
|
|
|
s.gnetServer.logger.Info("Waiting for connections to close (timeout: %v)...", shutdownTimeout)
|
|
|
startTime := time.Now()
|
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
timeout := time.NewTimer(shutdownTimeout)
|
|
|
defer timeout.Stop()
|
|
|
|
|
|
for {
|
|
|
select {
|
|
|
case <-timeout.C:
|
|
|
// 超时,强制关闭所有连接
|
|
|
s.gnetServer.logger.Warn("Shutdown timeout reached, forcing connection closure")
|
|
|
s.forceCloseAllConnections()
|
|
|
goto done
|
|
|
case <-ticker.C:
|
|
|
// 检查连接数
|
|
|
connCount := s.gnetServer.connManager.Count()
|
|
|
if connCount == 0 {
|
|
|
s.gnetServer.logger.Info("All connections closed (took %v)", time.Since(startTime))
|
|
|
goto done
|
|
|
}
|
|
|
// 继续等待
|
|
|
}
|
|
|
}
|
|
|
|
|
|
done:
|
|
|
|
|
|
// 4. 执行OnStop钩子
|
|
|
for _, hook := range s.serverLifecycleHooks {
|
|
|
if err := hook.OnStop(); err != nil {
|
|
|
s.gnetServer.logger.Error("OnStop hook error: %v", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 5. 使用gnetServer的Stop方法
|
|
|
if err := s.gnetServer.Stop(); err != nil {
|
|
|
s.gnetServer.logger.Error("Failed to stop gnet server: %v", err)
|
|
|
}
|
|
|
|
|
|
s.gnetServer.logger.Info("Server stopped gracefully")
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// forceCloseAllConnections 强制关闭所有连接
|
|
|
func (s *Server) forceCloseAllConnections() {
|
|
|
conns := s.gnetServer.connManager.GetAll()
|
|
|
s.gnetServer.logger.Warn("Force closing %d connections", len(conns))
|
|
|
for _, conn := range conns {
|
|
|
if err := conn.Close(); err != nil {
|
|
|
s.gnetServer.logger.Error("Failed to close connection %s: %v", conn.ID(), err)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Router 获取路由器
|
|
|
func (s *Server) Router() routerpkg.Router {
|
|
|
return s.gnetServer.router
|
|
|
}
|
|
|
|
|
|
// ConnectionManager 获取连接管理器
|
|
|
func (s *Server) ConnectionManager() connection.ManagerInterface {
|
|
|
return s.gnetServer.connManager
|
|
|
}
|
|
|
|
|
|
// Config 获取配置
|
|
|
func (s *Server) Config() *config.Config {
|
|
|
return s.gnetServer.config
|
|
|
}
|
|
|
|
|
|
// Logger 获取日志器
|
|
|
func (s *Server) Logger() logger.Logger {
|
|
|
return s.gnetServer.logger
|
|
|
}
|
|
|
|
|
|
// Started 检查服务器是否已启动
|
|
|
func (s *Server) Started() bool {
|
|
|
return s.gnetServer.Started()
|
|
|
}
|
|
|
|
|
|
// Metrics 获取Metrics
|
|
|
func (s *Server) Metrics() metricspkg.Metrics {
|
|
|
return s.metrics
|
|
|
}
|
|
|
|
|
|
// HealthChecker 获取健康检查器
|
|
|
func (s *Server) HealthChecker() health.Checker {
|
|
|
return s.healthChecker
|
|
|
}
|
|
|
|
|
|
// ExportMetrics 以Prometheus文本格式导出当前指标。
|
|
|
func (s *Server) ExportMetrics(w io.Writer) error {
|
|
|
exporter := metricspkg.NewPrometheusExporter(s.metrics)
|
|
|
return exporter.Export(w)
|
|
|
}
|
|
|
|
|
|
// MetricsHandler 返回默认的Metrics HTTP处理器。
|
|
|
func (s *Server) MetricsHandler() http.Handler {
|
|
|
return http.HandlerFunc(s.handleMetrics)
|
|
|
}
|
|
|
|
|
|
// HealthHandler 返回默认的健康检查HTTP处理器。
|
|
|
func (s *Server) HealthHandler() http.Handler {
|
|
|
return http.HandlerFunc(s.handleHealth)
|
|
|
}
|
|
|
|
|
|
// RegisterServerLifecycleHook 注册服务器生命周期钩子
|
|
|
func (s *Server) RegisterServerLifecycleHook(hook lifecycle.ServerLifecycleHook) {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
s.serverLifecycleHooks = append(s.serverLifecycleHooks, hook)
|
|
|
}
|
|
|
|
|
|
// RegisterConnectionLifecycleHook 注册连接生命周期钩子
|
|
|
func (s *Server) RegisterConnectionLifecycleHook(hook lifecycle.ConnectionLifecycleHook) {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
s.connLifecycleHooks = append(s.connLifecycleHooks, hook)
|
|
|
if s.gnetServer.eventHandler != nil {
|
|
|
s.gnetServer.eventHandler.connLifecycleHooks = s.connLifecycleHooks
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// CodecRegistry 获取编解码器注册表
|
|
|
func (s *Server) CodecRegistry() codecpkg.Registry {
|
|
|
return s.gnetServer.CodecRegistry()
|
|
|
}
|
|
|
|
|
|
// ProtocolManager 获取协议管理器
|
|
|
func (s *Server) ProtocolManager() protocolpkg.Manager {
|
|
|
return s.gnetServer.ProtocolManager()
|
|
|
}
|
|
|
|
|
|
// handleMetrics 处理Metrics请求
|
|
|
func (s *Server) handleMetrics(w http.ResponseWriter, r *http.Request) {
|
|
|
exporter := metricspkg.NewPrometheusExporter(s.metrics)
|
|
|
w.Header().Set("Content-Type", "text/plain; version=0.0.4")
|
|
|
if err := exporter.Export(w); err != nil {
|
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// handleHealth 处理健康检查请求
|
|
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
|
|
status, results := s.healthChecker.Check()
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
|
|
response := map[string]interface{}{
|
|
|
"status": status,
|
|
|
"results": results,
|
|
|
}
|
|
|
|
|
|
switch status {
|
|
|
case health.StatusHealthy, health.StatusDegraded:
|
|
|
w.WriteHeader(http.StatusOK)
|
|
|
default:
|
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
|
}
|
|
|
|
|
|
json.NewEncoder(w).Encode(response)
|
|
|
}
|
|
|
|
|
|
// connectionHealthCheck is a health check that reports status based on the
// number of open connections relative to a configured maximum.
type connectionHealthCheck struct {
	// connManager supplies the current connection count.
	connManager connection.ManagerInterface
	// maxConns is the connection limit. A value <= 0 disables the capacity thresholds.
	maxConns int
}
|
|
|
|
|
|
func (c *connectionHealthCheck) Check() (health.Status, string) {
|
|
|
count := c.connManager.Count()
|
|
|
if c.maxConns <= 0 {
|
|
|
return health.StatusHealthy, fmt.Sprintf("connections: %d", count)
|
|
|
}
|
|
|
percentage := float64(count) / float64(c.maxConns) * 100
|
|
|
|
|
|
if percentage >= 90 {
|
|
|
return health.StatusUnhealthy, fmt.Sprintf("connections at %d%% capacity", int(percentage))
|
|
|
} else if percentage >= 70 {
|
|
|
return health.StatusDegraded, fmt.Sprintf("connections at %d%% capacity", int(percentage))
|
|
|
}
|
|
|
return health.StatusHealthy, fmt.Sprintf("connections: %d/%d", count, c.maxConns)
|
|
|
}
|