package integration

import (
	"encoding/json"
	"fmt"
	"net"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// wsFreePort asks the OS for an unused TCP port on 127.0.0.1 and returns it.
// The probe listener is closed before returning, so another process could in
// principle take the port before the server binds it; that race is accepted.
func wsFreePort(t *testing.T) int {
	t.Helper()
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	require.NoError(t, listener.Close())
	return port
}

// startWSTestServer runs start in a background goroutine tracked by ts.wg,
// registers CleanupTestServer(t, ts) via t.Cleanup, and then waits up to
// 3 seconds (polling every 50ms) for started to report true, failing the test
// with msg otherwise.
func startWSTestServer(t *testing.T, ts *TestServer, start func() error, started func() bool, msg string) {
	t.Helper()

	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if err := start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()

	// Registered before waiting so a failed Eventually still tears the server
	// down. NOTE(review): this assumes CleanupTestServer stops the server and
	// waits on ts.wg, so the goroutine above cannot log after the test ends.
	t.Cleanup(func() { CleanupTestServer(t, ts) })

	require.Eventually(t, started, 3*time.Second, 50*time.Millisecond, msg)
	// Extra settling delay after Started() reports true; presumably the server
	// needs a moment beyond the flag before accepting connections — TODO confirm.
	time.Sleep(200 * time.Millisecond)
}

// wsDecodeJSON unmarshals resp into a generic JSON object. It fails the test
// immediately when resp is not valid JSON, so later field checks never run on
// a half-decoded result.
func wsDecodeJSON(t *testing.T, resp []byte) map[string]any {
	t.Helper()
	var result map[string]any
	require.NoError(t, json.Unmarshal(resp, &result), "Response should be valid JSON")
	return result
}

// TestWebSocketServerClient tests a single WebSocket client round trip: a
// "hello" request must be answered with the handler's JSON message.
func TestWebSocketServerClient(t *testing.T) {
	port := wsFreePort(t)

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
		Codec: &nnet.CodecConfig{
			DefaultCodec:         "json",
			EnableProtocolEncode: false,
		},
	}

	server, err := nnet.NewWebSocketServer(cfg)
	require.NoError(t, err)

	server.Router().RegisterString("hello", func(ctx nnet.Context) error {
		return ctx.Response().Write(map[string]any{
			"message": "hello from websocket server",
		})
	})

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	startWSTestServer(t, ts,
		func() error { return server.Start() },
		func() bool { return server.Started() },
		"Server should start within 3 seconds")

	// The client dials the same ws:// address the server listens on.
	client := NewTestClient(t, cfg.Addr, &nnet.ClientConfig{
		TransportProtocol: "websocket",
	})
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)

	// Give the server a moment to be ready before the first request.
	time.Sleep(100 * time.Millisecond)

	resp := RequestWithTimeout(t, client, []byte("hello"), 3*time.Second)
	t.Logf("Response: %q", string(resp))

	result := wsDecodeJSON(t, resp)
	assert.Equal(t, "hello from websocket server", result["message"], "Message should match")
}

// TestWebSocketMultipleClients tests that several WebSocket clients connected
// to one server can each send a "test" request and receive {"status":"ok"}.
func TestWebSocketMultipleClients(t *testing.T) {
	port := wsFreePort(t)

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
		Codec: &nnet.CodecConfig{
			DefaultCodec: "json",
		},
	}

	server, err := nnet.NewWebSocketServer(cfg)
	require.NoError(t, err)

	server.Router().RegisterString("test", func(ctx nnet.Context) error {
		return ctx.Response().Write(map[string]any{
			"status": "ok",
		})
	})

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	startWSTestServer(t, ts,
		func() error { return server.Start() },
		func() bool { return server.Started() },
		"Server should start")

	const numClients = 3
	clients := make([]nnet.Client, numClients)
	for i := range clients {
		client := NewTestClient(t, cfg.Addr, &nnet.ClientConfig{
			TransportProtocol: "websocket",
		})
		// Registered per client, before connecting, so every client created so far
		// is cleaned up even if a later connect or request aborts the test.
		// t.Cleanup is LIFO, so clients are released before the server cleanup.
		t.Cleanup(func() { CleanupTestClient(t, client) })
		ConnectTestClient(t, client)
		clients[i] = client
	}

	// Give the server a moment to be ready before the first request.
	time.Sleep(100 * time.Millisecond)

	// Every client must be able to send a request and get a valid response.
	for i, client := range clients {
		resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second)
		t.Logf("Response for client %d: %q", i, string(resp))

		result := wsDecodeJSON(t, resp)
		assert.Equal(t, "ok", result["status"], "Status should be ok")
	}
}

// TestWebSocketEcho tests the echo route: the handler returns the raw request
// payload under the "echo" key, which must contain the sent message.
func TestWebSocketEcho(t *testing.T) {
	port := wsFreePort(t)

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
		Codec: &nnet.CodecConfig{
			DefaultCodec: "json",
		},
	}

	server, err := nnet.NewWebSocketServer(cfg)
	require.NoError(t, err)

	server.Router().RegisterString("echo", func(ctx nnet.Context) error {
		data := ctx.Request().Raw()
		return ctx.Response().Write(map[string]any{
			"echo": string(data),
		})
	})

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	startWSTestServer(t, ts,
		func() error { return server.Start() },
		func() bool { return server.Started() },
		"Server should start")

	client := NewTestClient(t, cfg.Addr, &nnet.ClientConfig{
		TransportProtocol: "websocket",
	})
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)

	// Give the server a moment to be ready before the first request.
	time.Sleep(100 * time.Millisecond)

	// The payload starts with the route key "echo"; the rest is the message body.
	const testData = "echo test message"
	resp := RequestWithTimeout(t, client, []byte(testData), 3*time.Second)
	t.Logf("Response: %q", string(resp))

	result := wsDecodeJSON(t, resp)
	// Checked assertion: a missing or non-string field fails the test cleanly
	// instead of panicking.
	echo, ok := result["echo"].(string)
	require.True(t, ok, "echo field should be a string, got %T", result["echo"])
	assert.Contains(t, echo, "test message", "Echo should contain the message")
}