package integration

import (
	"fmt"
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/require"
)

// StartTestServerWithRoutes creates an nnet server from cfg, lets setupRoutes
// register routes on it before it starts, runs Server.Start in a background
// goroutine, and blocks until the server reports Started() and accepts a TCP
// connection on its address.
//
// If cfg.Addr is empty, uses port 0, or is "tcp://:6995", it is replaced with
// "tcp://127.0.0.1:<port>" for a currently free loopback port. cfg is modified
// in place, so the chosen address is visible both in cfg.Addr and in the
// returned TestServer.Addr.
//
// The returned server is also registered with t.Cleanup, so it is stopped when
// the test finishes even if a readiness check fails. Calling Stop earlier is
// still fine; Stop is idempotent.
func StartTestServerWithRoutes(t *testing.T, cfg *nnet.Config, setupRoutes func(nnet.Server)) *TestServer {
	t.Helper()

	switch cfg.Addr {
	case "", "tcp://:0", "tcp://127.0.0.1:0", "tcp://:6995":
		// NOTE(review): "tcp://:6995" is presumably the default address used by
		// shared test configs; it is randomized like port 0 — confirm intent.
		listener, err := net.Listen("tcp", "127.0.0.1:0")
		require.NoError(t, err, "Failed to get random port")
		port := listener.Addr().(*net.TCPAddr).Port
		// The port is released here and re-bound by the server below. Another
		// process could take it in between; that race is accepted in tests.
		_ = listener.Close()
		cfg.Addr = fmt.Sprintf("tcp://127.0.0.1:%d", port)
	}

	server, err := nnet.NewServer(cfg)
	require.NoError(t, err, "Failed to create server")

	// Routes must be registered before the server is started.
	if setupRoutes != nil {
		setupRoutes(server)
	}

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}

	// Server.Start is run asynchronously; the WaitGroup lets Stop join it.
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if err := server.Start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()
	// Registered before the readiness checks below: if one of them fails,
	// require aborts this function, and without this the server and its
	// goroutine would never be stopped or joined.
	t.Cleanup(ts.Stop)

	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")

	// Started() may be reported before the listener is accepting, so also
	// poll with a real TCP dial. The "tcp://" scheme is not a valid dial
	// address and is stripped once up front.
	dialAddr := strings.TrimPrefix(cfg.Addr, "tcp://")
	require.Eventually(t, func() bool {
		conn, err := net.DialTimeout("tcp", dialAddr, 500*time.Millisecond)
		if err != nil {
			return false
		}
		_ = conn.Close()
		return true
	}, 5*time.Second, 100*time.Millisecond, "Server should be ready to accept connections")

	// Extra grace period after the first successful dial so the server is
	// fully settled before the test starts using it.
	time.Sleep(100 * time.Millisecond)

	return ts
}

// TestServer wraps an nnet server started by StartTestServerWithRoutes.
type TestServer struct {
	// Server is the underlying nnet server.
	Server nnet.Server
	// Addr is the configured listen address, e.g. "tcp://127.0.0.1:12345".
	Addr string

	// stopCh is closed by Stop.
	stopCh chan struct{}
	// stopOnce makes Stop idempotent.
	stopOnce sync.Once
	// wg tracks the goroutine that runs Server.Start.
	wg sync.WaitGroup
}

// Stop closes stopCh, stops the underlying server (if any), and waits for the
// goroutine running Server.Start to return. Only the first call has any
// effect; later calls return without doing anything.
func (ts *TestServer) Stop() {
	ts.stopOnce.Do(func() {
		close(ts.stopCh)
		if ts.Server != nil {
			ts.Server.Stop()
		}
		// Wait for Server.Start to return.
		ts.wg.Wait()
	})
}