package integration

import (
	"fmt"
	"net"
	"strings"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// connMgrRequestTimeout bounds every client request issued by the
// connection-manager tests in this file.
const connMgrRequestTimeout = 3 * time.Second

// newConnMgrTestConfig returns a server config whose Addr points at a loopback
// TCP port that was free at the time of the call.
//
// The port is discovered by listening on 127.0.0.1:0 and closing the listener
// again. Another process could in principle take the port before the server
// binds it; that window is accepted because nnet.Config is configured with an
// address string rather than an already-open listener.
func newConnMgrTestConfig(t *testing.T) *nnet.Config {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	require.NoError(t, listener.Close())

	return &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
	}
}

// TestConnectionManagerFind verifies that the connection ID a server reports to
// its client can be looked up through the server's ConnectionManager and
// resolves to a connection with that same ID.
func TestConnectionManagerFind(t *testing.T) {
	cfg := newConnMgrTestConfig(t)

	ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
		// "getid" replies with the calling connection's ID, newline-terminated.
		srv.Router().RegisterString("getid", func(ctx nnet.Context) error {
			connID := ctx.Connection().ID()
			return ctx.Response().WriteBytes([]byte(connID + "\n"))
		})
	})
	defer CleanupTestServer(t, ts)

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	// Short pause, presumably so the server has finished registering the
	// connection before it is queried.
	time.Sleep(100 * time.Millisecond)

	// Ask the server which ID it assigned to this connection.
	resp := RequestWithTimeout(t, client, []byte("getid"), connMgrRequestTimeout)
	// TrimSuffix (instead of slicing off the last byte) cannot panic on an
	// empty response and leaves an unterminated ID intact.
	connID := strings.TrimSuffix(string(resp), "\n")
	require.NotEmpty(t, connID, "Server should report a connection ID")
	t.Logf("Connection ID: %s", connID)

	// The reported ID must resolve to the live connection on the server side.
	connMgr := ts.Server.ConnectionManager()
	conn, err := connMgr.Get(connID)
	require.NoError(t, err, "Connection should be found")
	require.NotNil(t, conn, "Connection should not be nil")
	assert.Equal(t, connID, conn.ID(), "Connection ID should match")
}

// TestConnectionManagerGroup verifies that connections can join and leave a
// named group and that the server-side group membership tracks those changes.
func TestConnectionManagerGroup(t *testing.T) {
	const groupID = "test-group"

	cfg := newConnMgrTestConfig(t)

	ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
		connMgr := srv.ConnectionManager()

		// "join" adds the calling connection to groupID and reports the outcome.
		srv.Router().RegisterString("join", func(ctx nnet.Context) error {
			connID := ctx.Connection().ID()
			if err := connMgr.AddToGroup(groupID, connID); err != nil {
				return ctx.Response().WriteBytes([]byte("error: " + err.Error() + "\n"))
			}
			return ctx.Response().WriteBytes([]byte("joined\n"))
		})

		// "leave" removes the calling connection from groupID and reports the outcome.
		srv.Router().RegisterString("leave", func(ctx nnet.Context) error {
			connID := ctx.Connection().ID()
			if err := connMgr.RemoveFromGroup(groupID, connID); err != nil {
				return ctx.Response().WriteBytes([]byte("error: " + err.Error() + "\n"))
			}
			return ctx.Response().WriteBytes([]byte("left\n"))
		})

		// "count" replies with the current number of members in groupID.
		srv.Router().RegisterString("count", func(ctx nnet.Context) error {
			group := connMgr.GetGroup(groupID)
			count := len(group)
			return ctx.Response().WriteBytes([]byte(fmt.Sprintf("count: %d\n", count)))
		})
	})
	defer CleanupTestServer(t, ts)

	// Two independent clients, i.e. two server-side connections.
	client1 := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client1)
	ConnectTestClient(t, client1)

	client2 := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client2)
	ConnectTestClient(t, client2)

	time.Sleep(100 * time.Millisecond)

	// Both clients join the group. These are preconditions for the membership
	// checks below, so stop immediately if either join fails.
	resp := RequestWithTimeout(t, client1, []byte("join"), connMgrRequestTimeout)
	require.Contains(t, string(resp), "joined", "Client1 should join group")

	resp = RequestWithTimeout(t, client2, []byte("join"), connMgrRequestTimeout)
	require.Contains(t, string(resp), "joined", "Client2 should join group")

	// Verify membership directly on the server.
	connMgr := ts.Server.ConnectionManager()
	group := connMgr.GetGroup(groupID)
	assert.Len(t, group, 2, "Group should have 2 connections")

	// One client leaves the group.
	resp = RequestWithTimeout(t, client1, []byte("leave"), connMgrRequestTimeout)
	require.Contains(t, string(resp), "left", "Client1 should leave group")

	// Only the remaining client should still be a member.
	group = connMgr.GetGroup(groupID)
	assert.Len(t, group, 1, "Group should have 1 connection")
}

// TestConnectionManagerBroadcast verifies that two connections can join a group
// and that a group broadcast can be triggered from one of them.
//
// Receipt of the broadcast payload by each client is not asserted yet; see the
// TODO at the end of the test.
func TestConnectionManagerBroadcast(t *testing.T) {
	const groupID = "broadcast-group"

	cfg := newConnMgrTestConfig(t)

	ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
		connMgr := srv.ConnectionManager()

		// "join" adds the calling connection to groupID and acknowledges it,
		// mirroring TestConnectionManagerGroup, so the client receives an
		// explicit reply that can be checked.
		srv.Router().RegisterString("join", func(ctx nnet.Context) error {
			connID := ctx.Connection().ID()
			if err := connMgr.AddToGroup(groupID, connID); err != nil {
				return ctx.Response().WriteBytes([]byte("error: " + err.Error() + "\n"))
			}
			return ctx.Response().WriteBytes([]byte("joined\n"))
		})

		// "broadcast" pushes a fixed message to every member of groupID; it
		// writes no direct reply of its own.
		srv.Router().RegisterString("broadcast", func(ctx nnet.Context) error {
			message := []byte("broadcast message\n")
			return connMgr.BroadcastToGroup(groupID, message)
		})
	})
	defer CleanupTestServer(t, ts)

	// Two independent clients, i.e. two server-side connections.
	client1 := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client1)
	ConnectTestClient(t, client1)

	client2 := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client2)
	ConnectTestClient(t, client2)

	time.Sleep(100 * time.Millisecond)

	// Both clients join the group. The handler adds the member before
	// replying, so a successful reply means membership is already recorded.
	resp := RequestWithTimeout(t, client1, []byte("join"), connMgrRequestTimeout)
	require.Contains(t, string(resp), "joined", "Client1 should join group")

	resp = RequestWithTimeout(t, client2, []byte("join"), connMgrRequestTimeout)
	require.Contains(t, string(resp), "joined", "Client2 should join group")

	// Both members must be registered server-side before broadcasting;
	// otherwise the broadcast would quietly reach fewer clients.
	connMgr := ts.Server.ConnectionManager()
	require.Len(t, connMgr.GetGroup(groupID), 2, "Group should have 2 connections")

	// Trigger the broadcast from client1. The reply is not inspected:
	// "broadcast" writes no reply itself, so whatever client1 reads here is the
	// pushed group message (client1 is also a member).
	RequestWithTimeout(t, client1, []byte("broadcast"), connMgrRequestTimeout)

	// Pause before the deferred cleanups run, presumably so the pushed
	// messages are flushed before the connections are closed.
	time.Sleep(200 * time.Millisecond)

	// TODO: assert that both clients actually received "broadcast message".
	// This requires the test client to support receiving asynchronous
	// (unsolicited) messages; if it does not, the broadcast has to be
	// verified some other way.
	t.Logf("Broadcast sent, clients should receive message")
}