You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

181 lines
4.3 KiB
Go

package integration
import (
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/noahlann/nnet/pkg/nnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConcurrentRequests issues several requests to the "test" route at the
// same time over a single connected client. The test passes when more than
// half of the requests come back with the expected "ok\n" payload.
func TestConcurrentRequests(t *testing.T) {
	// Bind to port 0 to have the OS pick a free loopback port, then release
	// it so the test server can be configured with that port number.
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	freePort := probe.Addr().(*net.TCPAddr).Port
	probe.Close()

	cfg := &nnet.Config{Addr: fmt.Sprintf("tcp://127.0.0.1:%d", freePort)}
	registerRoutes := func(srv nnet.Server) {
		srv.Router().RegisterString("test", func(ctx nnet.Context) error {
			return ctx.Response().WriteBytes([]byte("ok\n"))
		})
	}
	ts := StartTestServerWithRoutes(t, cfg, registerRoutes)
	defer CleanupTestServer(t, ts)

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	// Short pause after connecting; presumably gives the connection time to
	// settle before traffic starts — TODO confirm it is still required.
	time.Sleep(100 * time.Millisecond)

	// Fire all requests concurrently. Each goroutine writes only to its own
	// slot of results, and wg.Wait() orders those writes before the reads below.
	const numRequests = 10
	results := make([]bool, numRequests)

	var wg sync.WaitGroup
	for i := 0; i < numRequests; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			resp, reqErr := client.Request([]byte("test"), 3*time.Second)
			if reqErr != nil {
				// The slot keeps its zero value (false) for a failed request.
				t.Logf("Request %d failed: %v", idx, reqErr)
				return
			}
			results[idx] = string(resp) == "ok\n"
		}(i)
	}
	wg.Wait()

	// Tally the outcomes, logging every request that did not succeed.
	successCount := 0
	for idx := range results {
		if !results[idx] {
			t.Logf("Request %d failed", idx)
			continue
		}
		successCount++
	}

	assert.Greater(t, successCount, numRequests/2, "At least half of requests should succeed")
	t.Logf("Success rate: %d/%d", successCount, numRequests)
}
// TestConcurrentClients creates and connects several clients in parallel,
// then has every connected client send one request to the "test" route. The
// test passes when at least one client receives the expected "ok\n" payload.
func TestConcurrentClients(t *testing.T) {
	// Bind to port 0 to have the OS pick a free loopback port, then release
	// it so the test server can be configured with that port number.
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	freePort := probe.Addr().(*net.TCPAddr).Port
	probe.Close()

	cfg := &nnet.Config{Addr: fmt.Sprintf("tcp://127.0.0.1:%d", freePort)}
	registerRoutes := func(srv nnet.Server) {
		srv.Router().RegisterString("test", func(ctx nnet.Context) error {
			return ctx.Response().WriteBytes([]byte("ok\n"))
		})
	}
	ts := StartTestServerWithRoutes(t, cfg, registerRoutes)
	defer CleanupTestServer(t, ts)

	// Create and connect the clients concurrently. Each goroutine writes only
	// to its own slot; a slot stays nil when that client failed to connect.
	// NOTE(review): a client whose Connect fails is never handed to
	// CleanupTestClient — confirm that it holds nothing that needs releasing.
	const numClients = 5
	clients := make([]nnet.Client, numClients)

	var wg sync.WaitGroup
	for i := 0; i < numClients; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			c := NewTestClient(t, ts.Addr, nil)
			if connErr := c.Connect(); connErr != nil {
				t.Logf("Client %d failed to connect: %v", idx, connErr)
				return
			}
			clients[idx] = c
		}(i)
	}
	wg.Wait()
	time.Sleep(100 * time.Millisecond)

	// Every client that is still connected sends a single request, one
	// client after another.
	successCount := 0
	for i, c := range clients {
		if c == nil || !c.IsConnected() {
			continue
		}
		resp, reqErr := c.Request([]byte("test"), 3*time.Second)
		if reqErr != nil || string(resp) != "ok\n" {
			t.Logf("Client %d request failed: %v", i, reqErr)
			continue
		}
		successCount++
	}

	// Release every client that connected successfully.
	for _, c := range clients {
		if c == nil {
			continue
		}
		CleanupTestClient(t, c)
	}

	assert.Greater(t, successCount, 0, "At least one client should succeed")
	t.Logf("Success rate: %d/%d", successCount, numClients)
}
// TestServerStop verifies the shutdown path: after one successful request the
// server is stopped, and the test then expects ts.Server.Started() to report
// false within 3 seconds and client.IsConnected() to report false within
// 2 seconds.
//
// Cleanup is deferred so that the client and the server are released even
// when a require.* check or a helper aborts the test early.
func TestServerStop(t *testing.T) {
	// Bind to port 0 to have the OS pick a free loopback port, then release
	// it so the test server can be configured with that port number.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
	}
	ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
		srv.Router().RegisterString("test", func(ctx nnet.Context) error {
			return ctx.Response().WriteBytes([]byte("ok\n"))
		})
	})
	// Safety net: if the test aborts before the explicit stop below, make
	// sure the server is still shut down. The flag prevents a second call.
	serverStopped := false
	defer func() {
		if !serverStopped {
			CleanupTestServer(t, ts)
		}
	}()

	// Create a client and connect it. Deferring its cleanup keeps the
	// success path identical (cleanup still runs last) while also covering
	// early exits. Deferred calls run in LIFO order, so on failure the client
	// is released before the server, as in the other tests of this file.
	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	time.Sleep(100 * time.Millisecond)

	// A request must succeed while the server is running.
	resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second)
	assert.Contains(t, string(resp), "ok", "Request should succeed")

	// Stop the server. The flag is set first so the deferred safety net does
	// not invoke the cleanup a second time, even if this call aborts the test.
	serverStopped = true
	CleanupTestServer(t, ts)

	// Wait for the server to report that it is no longer started.
	require.Eventually(t, func() bool {
		return !ts.Server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should stop")

	// The client is expected to notice the shutdown and disconnect.
	require.Eventually(t, func() bool {
		return !client.IsConnected()
	}, 2*time.Second, 50*time.Millisecond, "Client should disconnect")
}