You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

226 lines
5.6 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package integration
import (
"encoding/json"
"fmt"
"net"
"testing"
"time"
"github.com/noahlann/nnet/pkg/nnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestWebSocketServerClient checks a single request/response round trip over
// WebSocket: the client sends the "hello" route and expects the JSON object
// written by the registered handler.
func TestWebSocketServerClient(t *testing.T) {
	// Grab an ephemeral port by binding to :0 and releasing it right away.
	// NOTE(review): the port is only probably free; another process could take
	// it between Close and the server binding it.
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	freePort := probe.Addr().(*net.TCPAddr).Port
	probe.Close()

	// The same ws:// URL serves as the server's listen address and as the
	// client's dial target.
	endpoint := fmt.Sprintf("ws://127.0.0.1:%d", freePort)

	cfg := &nnet.Config{
		Addr: endpoint,
		Codec: &nnet.CodecConfig{
			DefaultCodec:         "json",
			EnableProtocolEncode: false,
		},
	}

	srv, err := nnet.NewWebSocketServer(cfg)
	require.NoError(t, err)

	// Route "hello" answers with a fixed JSON object.
	srv.Router().RegisterString("hello", func(ctx nnet.Context) error {
		reply := map[string]any{"message": "hello from websocket server"}
		return ctx.Response().Write(reply)
	})

	// Run the server in the background; ts.wg tracks the goroutine so the
	// cleanup helper can presumably wait for it (CleanupTestServer is defined
	// elsewhere — confirm).
	ts := &TestServer{Server: srv, Addr: cfg.Addr, stopCh: make(chan struct{})}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if startErr := srv.Start(); startErr != nil {
			t.Logf("Server error: %v", startErr)
		}
	}()

	require.Eventually(t, srv.Started, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")
	// Extra settle time after Started() reports true, before sending traffic.
	time.Sleep(200 * time.Millisecond)
	defer CleanupTestServer(t, ts)

	client := NewTestClient(t, endpoint, &nnet.ClientConfig{TransportProtocol: "websocket"})
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)

	// Give the server a moment to finish setting up the new connection.
	time.Sleep(100 * time.Millisecond)

	raw := RequestWithTimeout(t, client, []byte("hello"), 3*time.Second)
	t.Logf("Response: %q", string(raw))

	var decoded map[string]any
	assert.NoError(t, json.Unmarshal(raw, &decoded), "Response should be valid JSON")
	assert.Equal(t, "hello from websocket server", decoded["message"], "Message should match")
}
// TestWebSocketMultipleClients checks that several WebSocket clients connected
// to one server can each complete a request on the "test" route.
func TestWebSocketMultipleClients(t *testing.T) {
	const clientCount = 3

	// Reserve an ephemeral port (bind to :0, then release it).
	// NOTE(review): racy — the port could be reused before the server binds it.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
		Codec: &nnet.CodecConfig{
			DefaultCodec: "json",
		},
	}
	server, err := nnet.NewWebSocketServer(cfg)
	require.NoError(t, err)
	server.Router().RegisterString("test", func(ctx nnet.Context) error {
		return ctx.Response().Write(map[string]any{
			"status": "ok",
		})
	})

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if startErr := server.Start(); startErr != nil {
			t.Logf("Server error: %v", startErr)
		}
	}()
	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start")
	time.Sleep(200 * time.Millisecond)
	defer CleanupTestServer(t, ts)

	// Create and connect every client. Cleanup is deferred right after each
	// client is created. If a connect or request helper aborts the test (e.g.
	// via t.FailNow), the clients created so far are still closed; an explicit
	// cleanup loop at the end of the function would be skipped. These defers
	// run before the server cleanup registered above.
	clients := make([]nnet.Client, 0, clientCount)
	for i := 0; i < clientCount; i++ {
		client := NewTestClient(t, cfg.Addr, &nnet.ClientConfig{
			TransportProtocol: "websocket",
		})
		defer CleanupTestClient(t, client)
		ConnectTestClient(t, client)
		clients = append(clients, client)
	}
	time.Sleep(100 * time.Millisecond)

	// Every client must be able to send a request and get a valid reply.
	for i, client := range clients {
		resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second)
		t.Logf("Response for client %d: %q", i, string(resp))

		var result map[string]any
		// Stop here on invalid JSON; the status check would only add noise.
		require.NoError(t, json.Unmarshal(resp, &result), "Response should be valid JSON")
		assert.Equal(t, "ok", result["status"], "Status should be ok")
	}
}
// TestWebSocketEcho checks the "echo" route. The handler replies with the raw
// request bytes under the "echo" key, and the test verifies the message text
// comes back.
func TestWebSocketEcho(t *testing.T) {
	// Reserve an ephemeral port (bind to :0, then release it).
	// NOTE(review): racy — the port could be reused before the server binds it.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
		Codec: &nnet.CodecConfig{
			DefaultCodec: "json",
		},
	}
	server, err := nnet.NewWebSocketServer(cfg)
	require.NoError(t, err)
	server.Router().RegisterString("echo", func(ctx nnet.Context) error {
		data := ctx.Request().Raw()
		return ctx.Response().Write(map[string]any{
			"echo": string(data),
		})
	})

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if startErr := server.Start(); startErr != nil {
			t.Logf("Server error: %v", startErr)
		}
	}()
	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start")
	time.Sleep(200 * time.Millisecond)
	defer CleanupTestServer(t, ts)

	// Create the WebSocket client (dials the same ws:// address).
	client := NewTestClient(t, cfg.Addr, &nnet.ClientConfig{
		TransportProtocol: "websocket",
	})
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	time.Sleep(100 * time.Millisecond)

	// Send the echo request.
	testData := "echo test message"
	resp := RequestWithTimeout(t, client, []byte(testData), 3*time.Second)
	t.Logf("Response: %q", string(resp))

	var result map[string]any
	// require (not assert): the checks below are meaningless on invalid JSON.
	require.NoError(t, json.Unmarshal(resp, &result), "Response should be valid JSON")
	// Use a comma-ok type assertion. A missing or non-string "echo" field must
	// fail this test, not panic and abort the whole test binary.
	echo, ok := result["echo"].(string)
	require.True(t, ok, "Response should contain a string \"echo\" field, got %T", result["echo"])
	assert.Contains(t, echo, "test message", "Echo should contain the message")
}