package integration

import (
	"fmt"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// newLoopbackTestConfig returns a server config bound to a currently free
// TCP port on 127.0.0.1.
//
// The port is discovered by listening on port 0 and then closing the
// listener, so another process could in principle grab the port before the
// server binds it.
func newLoopbackTestConfig(t *testing.T) *nnet.Config {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	require.NoError(t, listener.Close())

	return &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
	}
}

// registerOKTestRoute registers the string route "test", whose handler
// writes the bytes "ok\n" to the response.
func registerOKTestRoute(srv nnet.Server) {
	srv.Router().RegisterString("test", func(ctx nnet.Context) error {
		return ctx.Response().WriteBytes([]byte("ok\n"))
	})
}

// TestConcurrentRequests issues several requests concurrently over a single
// client connection. It requires that more than half of them receive "ok\n".
func TestConcurrentRequests(t *testing.T) {
	ts := StartTestServerWithRoutes(t, newLoopbackTestConfig(t), registerOKTestRoute)
	defer CleanupTestServer(t, ts)

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	time.Sleep(100 * time.Millisecond)

	// Fire several requests concurrently on the same client.
	const numRequests = 10
	var wg sync.WaitGroup
	wg.Add(numRequests)
	// Each goroutine writes only its own index, so the slice needs no lock;
	// wg.Wait below orders those writes before the read loop.
	results := make([]bool, numRequests)
	for i := 0; i < numRequests; i++ {
		go func(idx int) {
			defer wg.Done()
			resp, err := client.Request([]byte("test"), 3*time.Second)
			switch {
			case err != nil:
				t.Logf("Request %d failed: %v", idx, err)
			case string(resp) != "ok\n":
				t.Logf("Request %d got unexpected response %q", idx, resp)
			default:
				results[idx] = true
			}
		}(i)
	}
	wg.Wait()

	successCount := 0
	for _, ok := range results {
		if ok {
			successCount++
		}
	}

	assert.Greater(t, successCount, numRequests/2, "At least half of requests should succeed")
	t.Logf("Success rate: %d/%d", successCount, numRequests)
}

// TestConcurrentClients connects several clients to the server at the same
// time. It then requires that at least one of them completes a request.
func TestConcurrentClients(t *testing.T) {
	ts := StartTestServerWithRoutes(t, newLoopbackTestConfig(t), registerOKTestRoute)
	defer CleanupTestServer(t, ts)

	// Connect several clients concurrently. Only clients whose Connect
	// succeeded are recorded (and later cleaned up).
	const numClients = 5
	var wg sync.WaitGroup
	wg.Add(numClients)
	clients := make([]nnet.Client, numClients)
	for i := 0; i < numClients; i++ {
		idx := i
		// Build the client on the test goroutine: NewTestClient receives t
		// and may call t.FailNow, which the testing package only allows from
		// the goroutine running the test.
		client := NewTestClient(t, ts.Addr, nil)
		go func() {
			defer wg.Done()
			if err := client.Connect(); err != nil {
				t.Logf("Client %d failed to connect: %v", idx, err)
				return
			}
			clients[idx] = client
		}()
	}
	wg.Wait()

	// Deferred so connected clients are released even if the test aborts.
	defer func() {
		for _, client := range clients {
			if client != nil {
				CleanupTestClient(t, client)
			}
		}
	}()

	time.Sleep(100 * time.Millisecond)

	// Every connected client should be able to send a request.
	successCount := 0
	for i, client := range clients {
		if client == nil || !client.IsConnected() {
			t.Logf("Client %d is not connected, skipping request", i)
			continue
		}
		resp, err := client.Request([]byte("test"), 3*time.Second)
		switch {
		case err != nil:
			t.Logf("Client %d request failed: %v", i, err)
		case string(resp) != "ok\n":
			t.Logf("Client %d got unexpected response %q", i, resp)
		default:
			successCount++
		}
	}

	assert.Greater(t, successCount, 0, "At least one client should succeed")
	t.Logf("Success rate: %d/%d", successCount, numClients)
}

// TestServerStop checks that stopping the server marks it as no longer
// started. It also checks that a connected client then observes the
// disconnect.
func TestServerStop(t *testing.T) {
	ts := StartTestServerWithRoutes(t, newLoopbackTestConfig(t), registerOKTestRoute)

	// Stop the server on early-exit paths too; the happy path stops it
	// explicitly below and sets the flag so it is not stopped twice.
	serverStopped := false
	defer func() {
		if !serverStopped {
			CleanupTestServer(t, ts)
		}
	}()

	// Connect a client and send a request.
	client := NewTestClient(t, ts.Addr, nil)
	// Deferred so the client is released even if a require.* check below
	// aborts the test.
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	time.Sleep(100 * time.Millisecond)

	resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second)
	assert.Contains(t, string(resp), "ok", "Request should succeed")

	// Stop the server.
	serverStopped = true
	CleanupTestServer(t, ts)

	// Wait for the server to report that it has stopped.
	require.Eventually(t, func() bool {
		return !ts.Server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should stop")

	// The client should notice the disconnect.
	require.Eventually(t, func() bool {
		return !client.IsConnected()
	}, 2*time.Second, 50*time.Millisecond, "Client should disconnect")
}