|
|
package integration
|
|
|
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/pkg/nnet"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
// TestWebSocketServerClient 测试WebSocket服务器-客户端通信
|
|
|
func TestWebSocketServerClient(t *testing.T) {
|
|
|
// 获取随机端口
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
require.NoError(t, err)
|
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
|
listener.Close()
|
|
|
|
|
|
cfg := &nnet.Config{
|
|
|
Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
|
|
|
Codec: &nnet.CodecConfig{
|
|
|
DefaultCodec: "json",
|
|
|
EnableProtocolEncode: false,
|
|
|
},
|
|
|
}
|
|
|
|
|
|
// 创建WebSocket服务器
|
|
|
server, err := nnet.NewWebSocketServer(cfg)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 注册路由
|
|
|
server.Router().RegisterString("hello", func(ctx nnet.Context) error {
|
|
|
return ctx.Response().Write(map[string]any{
|
|
|
"message": "hello from websocket server",
|
|
|
})
|
|
|
})
|
|
|
|
|
|
// 启动服务器
|
|
|
ts := &TestServer{
|
|
|
Server: server,
|
|
|
Addr: cfg.Addr,
|
|
|
stopCh: make(chan struct{}),
|
|
|
}
|
|
|
|
|
|
ts.wg.Add(1)
|
|
|
go func() {
|
|
|
defer ts.wg.Done()
|
|
|
if err := server.Start(); err != nil {
|
|
|
t.Logf("Server error: %v", err)
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
require.Eventually(t, func() bool {
|
|
|
return server.Started()
|
|
|
}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")
|
|
|
time.Sleep(200 * time.Millisecond)
|
|
|
|
|
|
defer CleanupTestServer(t, ts)
|
|
|
|
|
|
// 创建WebSocket客户端(使用ws://地址格式)
|
|
|
wsAddr := fmt.Sprintf("ws://127.0.0.1:%d", port)
|
|
|
client := NewTestClient(t, wsAddr, &nnet.ClientConfig{
|
|
|
TransportProtocol: "websocket",
|
|
|
})
|
|
|
defer CleanupTestClient(t, client)
|
|
|
|
|
|
ConnectTestClient(t, client)
|
|
|
|
|
|
// 等待服务器准备好
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
|
|
// 发送请求
|
|
|
resp := RequestWithTimeout(t, client, []byte("hello"), 3*time.Second)
|
|
|
t.Logf("Response: %q", string(resp))
|
|
|
|
|
|
var result map[string]any
|
|
|
err = json.Unmarshal(resp, &result)
|
|
|
assert.NoError(t, err, "Response should be valid JSON")
|
|
|
assert.Equal(t, "hello from websocket server", result["message"], "Message should match")
|
|
|
}
|
|
|
|
|
|
// TestWebSocketMultipleClients 测试WebSocket多个客户端
|
|
|
func TestWebSocketMultipleClients(t *testing.T) {
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
require.NoError(t, err)
|
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
|
listener.Close()
|
|
|
|
|
|
cfg := &nnet.Config{
|
|
|
Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
|
|
|
Codec: &nnet.CodecConfig{
|
|
|
DefaultCodec: "json",
|
|
|
},
|
|
|
}
|
|
|
|
|
|
server, err := nnet.NewWebSocketServer(cfg)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
server.Router().RegisterString("test", func(ctx nnet.Context) error {
|
|
|
return ctx.Response().Write(map[string]any{
|
|
|
"status": "ok",
|
|
|
})
|
|
|
})
|
|
|
|
|
|
ts := &TestServer{
|
|
|
Server: server,
|
|
|
Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
|
|
|
stopCh: make(chan struct{}),
|
|
|
}
|
|
|
|
|
|
ts.wg.Add(1)
|
|
|
go func() {
|
|
|
defer ts.wg.Done()
|
|
|
if err := server.Start(); err != nil {
|
|
|
t.Logf("Server error: %v", err)
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
require.Eventually(t, func() bool {
|
|
|
return server.Started()
|
|
|
}, 3*time.Second, 50*time.Millisecond, "Server should start")
|
|
|
time.Sleep(200 * time.Millisecond)
|
|
|
|
|
|
defer CleanupTestServer(t, ts)
|
|
|
|
|
|
// 创建多个WebSocket客户端(使用ws://地址格式)
|
|
|
wsAddr := fmt.Sprintf("ws://127.0.0.1:%d", port)
|
|
|
clients := make([]nnet.Client, 3)
|
|
|
for i := 0; i < 3; i++ {
|
|
|
client := NewTestClient(t, wsAddr, &nnet.ClientConfig{
|
|
|
TransportProtocol: "websocket",
|
|
|
})
|
|
|
ConnectTestClient(t, client)
|
|
|
clients[i] = client
|
|
|
}
|
|
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
|
|
// 所有客户端都应该能够发送请求
|
|
|
for i, client := range clients {
|
|
|
resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second)
|
|
|
t.Logf("Response for client %d: %q", i, string(resp))
|
|
|
|
|
|
var result map[string]any
|
|
|
err := json.Unmarshal(resp, &result)
|
|
|
assert.NoError(t, err, "Response should be valid JSON")
|
|
|
assert.Equal(t, "ok", result["status"], "Status should be ok")
|
|
|
}
|
|
|
|
|
|
// 清理所有客户端
|
|
|
for _, client := range clients {
|
|
|
CleanupTestClient(t, client)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// TestWebSocketEcho 测试WebSocket Echo功能
|
|
|
func TestWebSocketEcho(t *testing.T) {
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
require.NoError(t, err)
|
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
|
listener.Close()
|
|
|
|
|
|
cfg := &nnet.Config{
|
|
|
Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
|
|
|
Codec: &nnet.CodecConfig{
|
|
|
DefaultCodec: "json",
|
|
|
},
|
|
|
}
|
|
|
|
|
|
server, err := nnet.NewWebSocketServer(cfg)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
server.Router().RegisterString("echo", func(ctx nnet.Context) error {
|
|
|
data := ctx.Request().Raw()
|
|
|
return ctx.Response().Write(map[string]any{
|
|
|
"echo": string(data),
|
|
|
})
|
|
|
})
|
|
|
|
|
|
ts := &TestServer{
|
|
|
Server: server,
|
|
|
Addr: fmt.Sprintf("ws://127.0.0.1:%d", port),
|
|
|
stopCh: make(chan struct{}),
|
|
|
}
|
|
|
|
|
|
ts.wg.Add(1)
|
|
|
go func() {
|
|
|
defer ts.wg.Done()
|
|
|
if err := server.Start(); err != nil {
|
|
|
t.Logf("Server error: %v", err)
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
require.Eventually(t, func() bool {
|
|
|
return server.Started()
|
|
|
}, 3*time.Second, 50*time.Millisecond, "Server should start")
|
|
|
time.Sleep(200 * time.Millisecond)
|
|
|
|
|
|
defer CleanupTestServer(t, ts)
|
|
|
|
|
|
// 创建WebSocket客户端(使用ws://地址格式)
|
|
|
wsAddr := fmt.Sprintf("ws://127.0.0.1:%d", port)
|
|
|
client := NewTestClient(t, wsAddr, &nnet.ClientConfig{
|
|
|
TransportProtocol: "websocket",
|
|
|
})
|
|
|
defer CleanupTestClient(t, client)
|
|
|
|
|
|
ConnectTestClient(t, client)
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
|
|
// 发送echo请求
|
|
|
testData := "echo test message"
|
|
|
resp := RequestWithTimeout(t, client, []byte(testData), 3*time.Second)
|
|
|
t.Logf("Response: %q", string(resp))
|
|
|
|
|
|
var result map[string]any
|
|
|
err = json.Unmarshal(resp, &result)
|
|
|
assert.NoError(t, err, "Response should be valid JSON")
|
|
|
assert.Contains(t, result["echo"].(string), "test message", "Echo should contain the message")
|
|
|
}
|
|
|
|