You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

181 lines
5.4 KiB
Go

package integration
import (
	"fmt"
	"net"
	"strings"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
// TestConnectionManagerFind verifies that a live connection can be looked up
// by ID through the server's ConnectionManager.
//
// The client asks the server for its own connection ID via the "getid" route.
// The test then fetches that ID from the ConnectionManager and checks that the
// returned connection reports the same ID.
func TestConnectionManagerFind(t *testing.T) {
	// Reserve a free loopback port, then release it so the server can bind it.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
	}
	ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
		// "getid" replies with the caller's connection ID, newline-terminated.
		srv.Router().RegisterString("getid", func(ctx nnet.Context) error {
			connID := ctx.Connection().ID()
			return ctx.Response().WriteBytes([]byte(connID + "\n"))
		})
	})
	defer CleanupTestServer(t, ts)

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	time.Sleep(100 * time.Millisecond)

	// Fetch the connection ID. TrimSuffix is used instead of slicing off the
	// last byte: slicing panics on an empty response and would chop a real
	// character if the trailing newline were ever missing.
	resp := RequestWithTimeout(t, client, []byte("getid"), 3*time.Second)
	connID := strings.TrimSuffix(string(resp), "\n")
	require.NotEmpty(t, connID, "Server should return a connection ID")
	t.Logf("Connection ID: %s", connID)

	// Look the connection up on the server side.
	connMgr := ts.Server.ConnectionManager()
	conn, err := connMgr.Get(connID)
	require.NoError(t, err, "Connection should be found")
	require.NotNil(t, conn, "Connection should not be nil")
	assert.Equal(t, connID, conn.ID(), "Connection ID should match")
}
// TestConnectionManagerGroup verifies that connections can be added to and
// removed from a named group through the server's ConnectionManager.
//
// Two clients join the group via the "join" route, and the group size is
// checked on the server side. Then one client leaves via the "leave" route,
// and the size is checked again.
func TestConnectionManagerGroup(t *testing.T) {
	const (
		groupID    = "test-group"
		reqTimeout = 3 * time.Second
	)

	// Reserve a free loopback port, then release it so the server can bind it.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := ln.Addr().(*net.TCPAddr).Port
	ln.Close()

	cfg := &nnet.Config{Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port)}

	ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
		mgr := srv.ConnectionManager()
		router := srv.Router()

		// reply writes msg back to the connection that issued the request.
		reply := func(ctx nnet.Context, msg string) error {
			return ctx.Response().WriteBytes([]byte(msg))
		}

		// "join" adds the caller to the group; failures are reported in-band.
		router.RegisterString("join", func(ctx nnet.Context) error {
			if err := mgr.AddToGroup(groupID, ctx.Connection().ID()); err != nil {
				return reply(ctx, "error: "+err.Error()+"\n")
			}
			return reply(ctx, "joined\n")
		})

		// "leave" removes the caller from the group; failures are reported in-band.
		router.RegisterString("leave", func(ctx nnet.Context) error {
			if err := mgr.RemoveFromGroup(groupID, ctx.Connection().ID()); err != nil {
				return reply(ctx, "error: "+err.Error()+"\n")
			}
			return reply(ctx, "left\n")
		})

		// "count" reports the current number of group members.
		router.RegisterString("count", func(ctx nnet.Context) error {
			return reply(ctx, fmt.Sprintf("count: %d\n", len(mgr.GetGroup(groupID))))
		})
	})
	defer CleanupTestServer(t, ts)

	// Connect two independent clients.
	c1 := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, c1)
	ConnectTestClient(t, c1)

	c2 := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, c2)
	ConnectTestClient(t, c2)

	time.Sleep(100 * time.Millisecond)

	// Both clients join; each reply confirms the server handled the request.
	joined1 := RequestWithTimeout(t, c1, []byte("join"), reqTimeout)
	assert.Contains(t, string(joined1), "joined", "Client1 should join group")
	joined2 := RequestWithTimeout(t, c2, []byte("join"), reqTimeout)
	assert.Contains(t, string(joined2), "joined", "Client2 should join group")

	// Verify membership directly on the server.
	mgr := ts.Server.ConnectionManager()
	assert.Len(t, mgr.GetGroup(groupID), 2, "Group should have 2 connections")

	// One client leaves, so the group should shrink to a single member.
	left := RequestWithTimeout(t, c1, []byte("leave"), reqTimeout)
	assert.Contains(t, string(left), "left", "Client1 should leave group")
	assert.Len(t, mgr.GetGroup(groupID), 1, "Group should have 1 connection")
}
// TestConnectionManagerBroadcast verifies group broadcast through the server's
// ConnectionManager.
//
// Two clients join a group, and each join is acknowledged. Group membership is
// checked on the server side. Then client1 triggers a broadcast to the group
// and is expected to receive the broadcast payload itself.
func TestConnectionManagerBroadcast(t *testing.T) {
	const (
		groupID    = "broadcast-group"
		reqTimeout = 3 * time.Second
	)

	// Reserve a free loopback port, then release it so the server can bind it.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
	}
	ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
		connMgr := srv.ConnectionManager()

		// "join" adds the caller to the group. The route always writes a reply
		// so the client's request has something to read back. Failures are
		// reported in-band instead of being returned with no reply written.
		srv.Router().RegisterString("join", func(ctx nnet.Context) error {
			if err := connMgr.AddToGroup(groupID, ctx.Connection().ID()); err != nil {
				return ctx.Response().WriteBytes([]byte("error: " + err.Error() + "\n"))
			}
			return ctx.Response().WriteBytes([]byte("joined\n"))
		})

		// "broadcast" sends a fixed message to every member of the group and
		// writes no reply of its own.
		srv.Router().RegisterString("broadcast", func(ctx nnet.Context) error {
			message := []byte("broadcast message\n")
			return connMgr.BroadcastToGroup(groupID, message)
		})
	})
	defer CleanupTestServer(t, ts)

	// Connect two independent clients.
	client1 := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client1)
	ConnectTestClient(t, client1)
	client2 := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client2)
	ConnectTestClient(t, client2)
	time.Sleep(100 * time.Millisecond)

	// Both clients join. The acknowledgement confirms the server processed the
	// join, so no extra sleep is needed before checking membership.
	resp := RequestWithTimeout(t, client1, []byte("join"), reqTimeout)
	require.Contains(t, string(resp), "joined", "Client1 should join group")
	resp = RequestWithTimeout(t, client2, []byte("join"), reqTimeout)
	require.Contains(t, string(resp), "joined", "Client2 should join group")

	// Server-side check: both connections must be members before broadcasting.
	group := ts.Server.ConnectionManager().GetGroup(groupID)
	require.Len(t, group, 2, "Group should have 2 connections")

	// client1 is itself a group member, and the "broadcast" route writes no
	// other reply, so the next message client1 reads should be the broadcast.
	// NOTE(review): assumes RequestWithTimeout returns the next message that
	// arrives on the connection - confirm against the test client helper.
	resp = RequestWithTimeout(t, client1, []byte("broadcast"), reqTimeout)
	assert.Contains(t, string(resp), "broadcast message", "Client1 should receive the broadcast")

	// client2 has no outstanding request. Observing the broadcast there would
	// need an asynchronous-receive API on the test client, so it is not
	// verified here.
	t.Logf("Broadcast sent, clients should receive message")
}