You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

104 lines
2.6 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package integration
import (
	"net"
	"strings"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// StartTestServer starts a test server for cfg with no additional routes.
// It is a simplified convenience wrapper around StartTestServerWithRoutes,
// which (per the original comment) lives in helper.go.
func StartTestServer(t *testing.T, cfg *nnet.Config) *TestServer {
	// Mark as helper so any failure is reported at the calling test's line.
	t.Helper()
	return StartTestServerWithRoutes(t, cfg, nil)
}
// NewTestClient creates (but does not connect) a client targeting addr.
//
// cfg may be nil. It is shallow-copied before use, so the caller's struct is
// never modified and one config can safely be reused for several clients.
// On the copy, Addr is set to addr, and any zero ConnectTimeout, ReadTimeout
// or WriteTimeout defaults to 3s. The test fails immediately if
// nnet.NewClient returns nil.
func NewTestClient(t *testing.T, addr string, cfg *nnet.ClientConfig) nnet.Client {
	t.Helper()

	const defaultTimeout = 3 * time.Second

	var c nnet.ClientConfig
	if cfg != nil {
		// Copy so the writes below never leak back into the caller's config
		// (NewClient may retain the pointer it is given).
		// NOTE(review): assumes ClientConfig holds no lock/no-copy fields.
		c = *cfg
	}
	c.Addr = addr
	if c.ConnectTimeout == 0 {
		c.ConnectTimeout = defaultTimeout
	}
	if c.ReadTimeout == 0 {
		c.ReadTimeout = defaultTimeout
	}
	if c.WriteTimeout == 0 {
		c.WriteTimeout = defaultTimeout
	}

	client := nnet.NewClient(&c)
	require.NotNil(t, client, "Failed to create client")
	return client
}
// ConnectTestClient connects client. The test fails immediately if Connect
// returns an error, and an assertion failure is recorded (without stopping
// the test) if the client does not report itself as connected afterwards.
func ConnectTestClient(t *testing.T, client nnet.Client) {
	// Mark as helper so failures point at the calling test, not this file.
	t.Helper()
	err := client.Connect()
	require.NoError(t, err, "Failed to connect client")
	assert.True(t, client.IsConnected(), "Client should be connected")
}
// RequestWithTimeout sends data via client.Request and returns the response.
//
// A non-positive timeout selects the default of 5s. If the request returns
// an error, the test fails immediately. The failure message includes the
// quoted request payload to aid debugging.
func RequestWithTimeout(t *testing.T, client nnet.Client, data []byte, timeout time.Duration) []byte {
	t.Helper()
	if timeout <= 0 {
		timeout = 5 * time.Second // default request timeout
	}
	resp, err := client.Request(data, timeout)
	// Report the payload in the single failure message rather than logging
	// it separately and then failing on the same error.
	require.NoError(t, err, "Request should succeed, data: %q", data)
	return resp
}
// CleanupTestServer stops ts. A nil ts is accepted and ignored, which makes
// the helper safe to defer unconditionally. t is currently unused.
func CleanupTestServer(t *testing.T, ts *TestServer) {
	if ts == nil {
		return
	}
	ts.Stop()
}
// CleanupTestClient tears down client. A nil client is ignored. Otherwise
// the client is disconnected first if it reports being connected, and is
// then closed. Any results returned by Disconnect/Close are discarded.
// t is currently unused.
func CleanupTestClient(t *testing.T, client nnet.Client) {
	if client == nil {
		return
	}
	if client.IsConnected() {
		client.Disconnect()
	}
	client.Close()
}
// WaitForServer blocks until server.Started() reports true, polling every
// 100ms. The test fails immediately if that does not happen within timeout.
// A non-positive timeout selects the default of 5s, matching
// RequestWithTimeout; previously such a value failed without waiting.
func WaitForServer(t *testing.T, server nnet.Server, timeout time.Duration) {
	t.Helper()
	if timeout <= 0 {
		timeout = 5 * time.Second
	}
	require.Eventually(t, func() bool {
		return server.Started()
	}, timeout, 100*time.Millisecond, "Server should start within timeout")
}
// ExtractPort returns the ":port" suffix of a "tcp://host:port" address.
//
// Examples:
//
//	"tcp://:6995"          -> ":6995"
//	"tcp://127.0.0.1:6995" -> ":6995"
//	"tcp://[::1]:6995"     -> ":6995"
//
// It returns "" in these cases:
//   - the address lacks the "tcp://" scheme;
//   - the host/port part cannot be split (e.g. no port, or an unbracketed
//     IPv6 literal);
//   - the port is empty.
func ExtractPort(addr string) string {
	const scheme = "tcp://"
	if !strings.HasPrefix(addr, scheme) {
		return ""
	}
	// net.SplitHostPort understands bracketed IPv6 literals, so "[::1]" is
	// rejected instead of yielding the bogus port ":1]".
	_, port, err := net.SplitHostPort(addr[len(scheme):])
	if err != nil || port == "" {
		return ""
	}
	return ":" + port
}