package integration

import (
	"strings"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// StartTestServer starts a test server without registering any extra routes.
// It is a thin wrapper around StartTestServerWithRoutes (defined elsewhere in
// this package, e.g. helper.go).
func StartTestServer(t *testing.T, cfg *nnet.Config) *TestServer {
	t.Helper()
	return StartTestServerWithRoutes(t, cfg, nil)
}

// NewTestClient creates (but does not connect) a client for addr.
//
// cfg may be nil. It is copied before use, so the caller's config is never
// modified. The copy's Addr is set to addr, and any of ConnectTimeout,
// ReadTimeout and WriteTimeout left at zero default to 3 seconds.
// The test fails immediately if nnet.NewClient returns nil.
func NewTestClient(t *testing.T, addr string, cfg *nnet.ClientConfig) nnet.Client {
	t.Helper()

	// Work on a copy so callers can safely reuse one config for several
	// clients without Addr or the defaulted timeouts leaking between them.
	var c nnet.ClientConfig
	if cfg != nil {
		c = *cfg
	}
	c.Addr = addr
	if c.ConnectTimeout == 0 {
		c.ConnectTimeout = 3 * time.Second
	}
	if c.ReadTimeout == 0 {
		c.ReadTimeout = 3 * time.Second
	}
	if c.WriteTimeout == 0 {
		c.WriteTimeout = 3 * time.Second
	}

	client := nnet.NewClient(&c)
	require.NotNil(t, client, "Failed to create client")
	return client
}

// ConnectTestClient connects client. The test fails immediately if Connect
// returns an error, and a (non-fatal) assertion checks that IsConnected then
// reports true.
func ConnectTestClient(t *testing.T, client nnet.Client) {
	t.Helper()
	err := client.Connect()
	require.NoError(t, err, "Failed to connect client")
	assert.True(t, client.IsConnected(), "Client should be connected")
}

// RequestWithTimeout sends data through client.Request and returns the
// response. A non-positive timeout is replaced with a 5 second default.
// The test fails immediately if the request returns an error; the failure
// message includes the request payload.
func RequestWithTimeout(t *testing.T, client nnet.Client, data []byte, timeout time.Duration) []byte {
	t.Helper()
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	resp, err := client.Request(data, timeout)
	// A single report carries both the error and the payload, instead of
	// logging the error and then reporting it a second time.
	require.NoError(t, err, "Request should succeed, data: %q", string(data))
	return resp
}

// CleanupTestServer stops ts. A nil ts is a no-op, which makes this safe to
// use from defer or t.Cleanup even when server start-up failed.
func CleanupTestServer(t *testing.T, ts *TestServer) {
	t.Helper()
	if ts == nil {
		return
	}
	ts.Stop()
}

// CleanupTestClient disconnects client (if it is still connected) and closes
// it. A nil client is a no-op. Cleanup is best-effort: any values returned by
// Disconnect or Close are discarded.
func CleanupTestClient(t *testing.T, client nnet.Client) {
	t.Helper()
	if client == nil {
		return
	}
	if client.IsConnected() {
		client.Disconnect()
	}
	client.Close()
}

// WaitForServer polls server.Started every 100ms until it reports true.
// The test fails immediately if that does not happen within timeout.
func WaitForServer(t *testing.T, server nnet.Server, timeout time.Duration) {
	t.Helper()
	require.Eventually(t, func() bool {
		return server.Started()
	}, timeout, 100*time.Millisecond, "Server should start within timeout")
}

// ExtractPort returns the ":port" suffix of a "tcp://host:port" address.
// For example, both "tcp://:6995" and "tcp://127.0.0.1:6995" yield ":6995".
//
// It returns "" when addr does not start with "tcp://" or contains no colon
// after that prefix. The suffix starts at the last colon, so bracketed IPv6
// hosts such as "tcp://[::1]:6995" also yield ":6995". The port itself is not
// validated.
func ExtractPort(addr string) string {
	const scheme = "tcp://"
	if !strings.HasPrefix(addr, scheme) {
		return ""
	}
	hostPort := addr[len(scheme):]
	if i := strings.LastIndexByte(hostPort, ':'); i >= 0 {
		return hostPort[i:]
	}
	return ""
}