package connection

import (
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestConnectionManager exercises the basic Add / Get / Count / Remove
// lifecycle of a single connection.
func TestConnectionManager(t *testing.T) {
	manager := NewManager(100)

	// Add a connection.
	conn := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn)
	require.NoError(t, err, "Expected no error when adding connection")

	// Look the connection up again by its ID.
	retrievedConn, err := manager.Get("conn1")
	require.NoError(t, err, "Expected no error when getting connection")
	// Guard against a nil result so a faulty Get fails the test instead of panicking it.
	require.NotNil(t, retrievedConn, "Expected a non-nil connection")
	assert.Equal(t, "conn1", retrievedConn.ID(), "Expected conn1")

	// The manager should now track exactly one connection.
	assert.Equal(t, 1, manager.Count(), "Expected 1 connection")

	// Remove the connection and verify the count drops back to zero.
	err = manager.Remove("conn1")
	require.NoError(t, err, "Expected no error when removing connection")
	assert.Equal(t, 0, manager.Count(), "Expected 0 connections")
}

// TestConnectionManagerAddNil verifies that adding a nil connection is rejected.
func TestConnectionManagerAddNil(t *testing.T) {
	manager := NewManager(100)

	err := manager.Add(nil)
	assert.Error(t, err, "Expected error for nil connection")
}

// TestConnectionManagerMaxConnections verifies that the limit passed to
// NewManager is enforced and that a rejected Add leaves the manager unchanged.
func TestConnectionManagerMaxConnections(t *testing.T) {
	manager := NewManager(2)

	// Fill the manager up to its limit.
	conn1 := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn1)
	require.NoError(t, err, "Expected no error when adding first connection")

	conn2 := &mockConnectionInterface{id: "conn2"}
	err = manager.Add(conn2)
	require.NoError(t, err, "Expected no error when adding second connection")

	// A third connection exceeds the limit and must be rejected.
	conn3 := &mockConnectionInterface{id: "conn3"}
	err = manager.Add(conn3)
	assert.Error(t, err, "Expected error for exceeding max connections")

	// The rejected connection must not have been registered.
	assert.Equal(t, 2, manager.Count(), "Expected connection count to stay at the limit")
	_, err = manager.Get("conn3")
	assert.Error(t, err, "Expected rejected connection not to be registered")
}

// TestConnectionManagerGetNotFound verifies that looking up an unknown ID
// yields ErrConnectionNotFound.
func TestConnectionManagerGetNotFound(t *testing.T) {
	manager := NewManager(100)

	_, err := manager.Get("nonexistent")
	// ErrorIs uses errors.Is semantics, so this still passes if the manager
	// wraps the sentinel with extra context; it also fails on a nil error.
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}

// TestConnectionManagerGetAll verifies that GetAll returns every added connection.
func TestConnectionManagerGetAll(t *testing.T) {
	manager := NewManager(100)

	conn1 := &mockConnectionInterface{id: "conn1"}
	conn2 := &mockConnectionInterface{id: "conn2"}
	conn3 := &mockConnectionInterface{id: "conn3"}

	err := manager.Add(conn1)
	require.NoError(t, err)
	err = manager.Add(conn2)
	require.NoError(t, err)
	err = manager.Add(conn3)
	require.NoError(t, err)

	allConns := manager.GetAll()
	assert.Len(t, allConns, 3, "Expected 3 connections")
}

// TestConnectionGroup verifies that registered connections can be added to a
// named group and are then returned by GetGroup.
func TestConnectionGroup(t *testing.T) {
	manager := NewManager(100)

	conn1 := &mockConnectionInterface{id: "conn1"}
	conn2 := &mockConnectionInterface{id: "conn2"}

	err := manager.Add(conn1)
	require.NoError(t, err)
	err = manager.Add(conn2)
	require.NoError(t, err)

	// Put both connections into the same group.
	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err, "Expected no error when adding conn1 to group")
	err = manager.AddToGroup("group1", "conn2")
	require.NoError(t, err, "Expected no error when adding conn2 to group")

	group := manager.GetGroup("group1")
	assert.Len(t, group, 2, "Expected 2 connections in group")
}

// TestConnectionGroupRemoveFromGroup verifies that RemoveFromGroup detaches a
// connection so the group reports empty afterwards.
func TestConnectionGroupRemoveFromGroup(t *testing.T) {
	manager := NewManager(100)

	conn1 := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn1)
	require.NoError(t, err)

	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err)

	err = manager.RemoveFromGroup("group1", "conn1")
	require.NoError(t, err, "Expected no error when removing from group")

	// The group may be reported as nil or as an empty collection; Empty accepts both.
	group := manager.GetGroup("group1")
	assert.Empty(t, group, "Expected empty group after removal")
}

// TestConnectionManagerCleanupInactive verifies that CleanupInactive drops a
// connection whose last activity is older than the timeout, keeps a recently
// active one, and also removes the dropped connection from its groups.
func TestConnectionManagerCleanupInactive(t *testing.T) {
	manager := NewManager(100)

	// conn1 was last active 40s ago, which is beyond the 30s timeout used below.
	conn1 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		lastActive:              time.Now().Add(-40 * time.Second),
	}
	// conn2 is active "now" and must survive cleanup.
	conn2 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn2"},
		lastActive:              time.Now(),
	}

	err := manager.Add(conn1)
	require.NoError(t, err)
	err = manager.Add(conn2)
	require.NoError(t, err)

	// Put conn1 in a group so we can check that cleanup also detaches it there.
	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err)

	// Clean up connections idle for longer than 30 seconds.
	manager.CleanupInactive(30 * time.Second)

	// conn1 must be gone, conn2 retained.
	assert.Equal(t, 1, manager.Count(), "Expected 1 connection after cleanup")
	_, err = manager.Get("conn1")
	assert.Error(t, err, "Expected conn1 to be removed")
	_, err = manager.Get("conn2")
	assert.NoError(t, err, "Expected conn2 to be retained")

	// conn1 must also have been removed from its group.
	group := manager.GetGroup("group1")
	assert.Empty(t, group, "Expected empty group after cleanup")
}

// TestConnectionManagerRemoveFromGroups verifies that Remove detaches a
// connection from every group it belonged to.
func TestConnectionManagerRemoveFromGroups(t *testing.T) {
	manager := NewManager(100)

	conn1 := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn1)
	require.NoError(t, err)

	// Register the same connection in two groups.
	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err)
	err = manager.AddToGroup("group2", "conn1")
	require.NoError(t, err)

	err = manager.Remove("conn1")
	require.NoError(t, err)

	// Both groups must no longer contain the connection.
	group1 := manager.GetGroup("group1")
	assert.Empty(t, group1, "Expected empty group1 after removal")
	group2 := manager.GetGroup("group2")
	assert.Empty(t, group2, "Expected empty group2 after removal")
}

// TestConnectionManagerBroadcastError verifies that BroadcastToGroup reports an
// error when writing to at least one group member fails.
func TestConnectionManagerBroadcastError(t *testing.T) {
	manager := NewManager(100)

	// conn1 always fails on Write; conn2 succeeds.
	conn1 := &mockConnectionInterfaceWithError{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		writeError:              true,
	}
	conn2 := &mockConnectionInterface{id: "conn2"}

	err := manager.Add(conn1)
	require.NoError(t, err)
	err = manager.Add(conn2)
	require.NoError(t, err)

	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err)
	err = manager.AddToGroup("group1", "conn2")
	require.NoError(t, err)

	// Broadcasting reaches conn1, whose Write fails.
	err = manager.BroadcastToGroup("group1", []byte("test"))
	assert.Error(t, err, "Expected error for broadcast with write failure")
}

// All ShardedManager tests moved to sharded_manager_test.go

// mockConnectionInterface is a minimal in-memory connection stub: it has a
// fixed ID and fixed addresses, and its Write and Close always succeed.
type mockConnectionInterface struct {
	id string
}

// ID returns the stub's configured identifier.
func (m *mockConnectionInterface) ID() string {
	return m.id
}

// RemoteAddr returns a fixed loopback address.
func (m *mockConnectionInterface) RemoteAddr() string {
	return "127.0.0.1:6995"
}

// LocalAddr returns a fixed loopback address.
func (m *mockConnectionInterface) LocalAddr() string {
	return "127.0.0.1:6995"
}

// Write discards data and always succeeds.
func (m *mockConnectionInterface) Write(data []byte) error {
	return nil
}

// Close is a no-op that always succeeds.
func (m *mockConnectionInterface) Close() error {
	return nil
}

// mockConnectionInterfaceWithLastActive extends the basic stub with a
// configurable last-activity timestamp, used by the CleanupInactive test.
type mockConnectionInterfaceWithLastActive struct {
	mockConnectionInterface
	lastActive time.Time
}

// LastActive returns the configured last-activity time.
func (m *mockConnectionInterfaceWithLastActive) LastActive() time.Time {
	return m.lastActive
}

// mockConnectionInterfaceWithError extends the basic stub with a Write that
// can be configured to fail.
type mockConnectionInterfaceWithError struct {
	mockConnectionInterface
	writeError bool
}

// Write returns an error when writeError is set, and succeeds otherwise.
func (m *mockConnectionInterfaceWithError) Write(data []byte) error {
	if m.writeError {
		return errors.New("write error")
	}
	return nil
}