|
|
package integration
|
|
|
|
|
|
import (
	"strings"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
|
|
|
|
|
|
// StartTestServer 启动测试服务器(简化版本,使用helper.go中的辅助函数)
|
|
|
func StartTestServer(t *testing.T, cfg *nnet.Config) *TestServer {
|
|
|
return StartTestServerWithRoutes(t, cfg, nil)
|
|
|
}
|
|
|
|
|
|
// NewTestClient 创建测试客户端
|
|
|
func NewTestClient(t *testing.T, addr string, cfg *nnet.ClientConfig) nnet.Client {
|
|
|
if cfg == nil {
|
|
|
cfg = &nnet.ClientConfig{}
|
|
|
}
|
|
|
cfg.Addr = addr
|
|
|
if cfg.ConnectTimeout == 0 {
|
|
|
cfg.ConnectTimeout = 3 * time.Second
|
|
|
}
|
|
|
if cfg.ReadTimeout == 0 {
|
|
|
cfg.ReadTimeout = 3 * time.Second
|
|
|
}
|
|
|
if cfg.WriteTimeout == 0 {
|
|
|
cfg.WriteTimeout = 3 * time.Second
|
|
|
}
|
|
|
|
|
|
client := nnet.NewClient(cfg)
|
|
|
require.NotNil(t, client, "Failed to create client")
|
|
|
|
|
|
return client
|
|
|
}
|
|
|
|
|
|
// ConnectTestClient 连接测试客户端
|
|
|
func ConnectTestClient(t *testing.T, client nnet.Client) {
|
|
|
err := client.Connect()
|
|
|
require.NoError(t, err, "Failed to connect client")
|
|
|
assert.True(t, client.IsConnected(), "Client should be connected")
|
|
|
}
|
|
|
|
|
|
// RequestWithTimeout 发送请求并等待响应
|
|
|
func RequestWithTimeout(t *testing.T, client nnet.Client, data []byte, timeout time.Duration) []byte {
|
|
|
if timeout == 0 {
|
|
|
timeout = 5 * time.Second // 默认超时时间
|
|
|
}
|
|
|
|
|
|
resp, err := client.Request(data, timeout)
|
|
|
if err != nil {
|
|
|
t.Logf("Request failed: %v, data: %q", err, string(data))
|
|
|
}
|
|
|
require.NoError(t, err, "Request should succeed")
|
|
|
return resp
|
|
|
}
|
|
|
|
|
|
// CleanupTestServer 清理测试服务器
|
|
|
func CleanupTestServer(t *testing.T, ts *TestServer) {
|
|
|
if ts != nil {
|
|
|
ts.Stop()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// CleanupTestClient 清理测试客户端
|
|
|
func CleanupTestClient(t *testing.T, client nnet.Client) {
|
|
|
if client != nil {
|
|
|
if client.IsConnected() {
|
|
|
client.Disconnect()
|
|
|
}
|
|
|
client.Close()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// WaitForServer 等待服务器启动
|
|
|
func WaitForServer(t *testing.T, server nnet.Server, timeout time.Duration) {
|
|
|
require.Eventually(t, func() bool {
|
|
|
return server.Started()
|
|
|
}, timeout, 100*time.Millisecond, "Server should start within timeout")
|
|
|
}
|
|
|
|
|
|
// ExtractPort returns the ":port" suffix of a "tcp://" address.
//
// For example, "tcp://:6995" and "tcp://127.0.0.1:6995" both yield ":6995".
// The port starts at the last ':' after the scheme, so a bracketed IPv6
// host such as "tcp://[::1]:6995" also yields ":6995". The empty string is
// returned when addr does not start with "tcp://" or has no ':' after it.
// The address is otherwise not validated.
func ExtractPort(addr string) string {
	const scheme = "tcp://"
	if !strings.HasPrefix(addr, scheme) {
		return ""
	}
	hostPort := addr[len(scheme):]
	if i := strings.LastIndexByte(hostPort, ':'); i >= 0 {
		return hostPort[i:]
	}
	return ""
}
|