package connection

import (
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestShardedManager covers the basic Add / Get / Count / Remove lifecycle.
func TestShardedManager(t *testing.T) {
	manager := NewShardedManager(100, 4)

	// Add a connection.
	conn1 := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn1)
	require.NoError(t, err, "Expected no error when adding connection")

	// Get the connection back.
	retrievedConn, err := manager.Get("conn1")
	require.NoError(t, err, "Expected no error when getting connection")
	assert.Equal(t, "conn1", retrievedConn.ID(), "Expected conn1")

	// Connection count.
	assert.Equal(t, 1, manager.Count(), "Expected 1 connection")

	// Remove the connection.
	err = manager.Remove("conn1")
	require.NoError(t, err, "Expected no error when removing connection")
	assert.Equal(t, 0, manager.Count(), "Expected 0 connections")
}

// TestShardedManagerMaxConnections checks that Add rejects connections once
// the configured maximum is reached, and that a rejected Add does not change
// the count.
func TestShardedManagerMaxConnections(t *testing.T) {
	manager := NewShardedManager(2, 4)

	// Add the first connection.
	conn1 := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn1)
	require.NoError(t, err, "Expected no error when adding first connection")

	// Add the second connection.
	conn2 := &mockConnectionInterface{id: "conn2"}
	err = manager.Add(conn2)
	require.NoError(t, err, "Expected no error when adding second connection")

	// Try adding a third connection (should fail).
	conn3 := &mockConnectionInterface{id: "conn3"}
	err = manager.Add(conn3)
	assert.Error(t, err, "Expected error for exceeding max connections")
	assert.Equal(t, 2, manager.Count(), "Expected rejected Add to leave count unchanged")
}

// TestShardedManagerMaxConnectionsConcurrent races many Add calls against a
// small limit and checks that exactly maxConns succeed and that every failure
// is a "max connections exceeded" rejection.
func TestShardedManagerMaxConnectionsConcurrent(t *testing.T) {
	const (
		maxConns      = 10
		totalAttempts = 100
	)
	manager := NewShardedManager(maxConns, 4)

	var (
		wg      sync.WaitGroup
		success atomic.Int64
	)
	// Collect unexpected errors through a channel rather than atomic.Value:
	// atomic.Value.Store panics if goroutines store errors of different
	// concrete types, and it would keep only the last one. The buffer holds
	// one slot per attempt, so no sender ever blocks.
	unexpected := make(chan error, totalAttempts)

	for i := 0; i < totalAttempts; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			conn := &mockConnectionInterface{id: fmt.Sprintf("conn-%d", idx)}
			if err := manager.Add(conn); err != nil {
				if !strings.Contains(err.Error(), "max connections exceeded") {
					unexpected <- err
				}
				return
			}
			success.Add(1)
		}(i)
	}
	wg.Wait()
	close(unexpected)

	var unexpectedErrs []error
	for err := range unexpected {
		unexpectedErrs = append(unexpectedErrs, err)
	}
	require.Empty(t, unexpectedErrs, "unexpected errors from concurrent Add")

	assert.Equal(t, maxConns, manager.Count(), "expected count to respect max connections under concurrency")
	assert.Equal(t, int64(maxConns), success.Load(), "expected only maxConns successful additions")
}

// TestShardedManagerGroups checks that connections can be added to a group and
// retrieved through GetGroup.
func TestShardedManagerGroups(t *testing.T) {
	manager := NewShardedManager(100, 4)

	conn1 := &mockConnectionInterface{id: "conn1"}
	conn2 := &mockConnectionInterface{id: "conn2"}
	require.NoError(t, manager.Add(conn1))
	require.NoError(t, manager.Add(conn2))

	// Add both connections to a group.
	err := manager.AddToGroup("group1", "conn1")
	require.NoError(t, err, "Expected no error when adding conn1 to group")
	err = manager.AddToGroup("group1", "conn2")
	require.NoError(t, err, "Expected no error when adding conn2 to group")

	// Fetch the group.
	group := manager.GetGroup("group1")
	assert.Len(t, group, 2, "Expected 2 connections in group")
}

// TestShardedManagerCleanupInactive checks that CleanupInactive removes
// connections idle longer than the timeout and keeps the fresh ones.
func TestShardedManagerCleanupInactive(t *testing.T) {
	manager := NewShardedManager(100, 4)

	conn1 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		lastActive:              time.Now().Add(-40 * time.Second),
	}
	conn2 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn2"},
		lastActive:              time.Now(),
	}
	require.NoError(t, manager.Add(conn1))
	require.NoError(t, manager.Add(conn2))

	// Put conn1 in a group so cleanup also has group membership to deal with.
	require.NoError(t, manager.AddToGroup("group1", "conn1"))

	// Clean up inactive connections.
	manager.CleanupInactive(30 * time.Second)

	// conn1 (idle 40s) should be gone; conn2 (fresh) should remain.
	assert.Equal(t, 1, manager.Count(), "Expected 1 connection after cleanup")
	_, err := manager.Get("conn1")
	assert.Error(t, err, "Expected conn1 to be removed")
	_, err = manager.Get("conn2")
	assert.NoError(t, err, "Expected conn2 to be retained")
}

// TestShardedManagerDefaultShardCount exercises a shard count of 0.
// Checking the chosen default directly would require reflection or a getter,
// so this is a behavioral test: the manager must still accept and count
// connections.
func TestShardedManagerDefaultShardCount(t *testing.T) {
	manager := NewShardedManager(100, 0)
	conn1 := &mockConnectionInterface{id: "conn1"}
	require.NoError(t, manager.Add(conn1))
	assert.Equal(t, 1, manager.Count())
}

// TestShardedManagerMaxShardCount exercises a very large shard count
// (behavioral test).
func TestShardedManagerMaxShardCount(t *testing.T) {
	manager := NewShardedManager(100, 500)
	conn1 := &mockConnectionInterface{id: "conn1"}
	require.NoError(t, manager.Add(conn1))
	assert.Equal(t, 1, manager.Count())
}

// TestShardedManagerGetAll checks that GetAll returns every added connection.
func TestShardedManagerGetAll(t *testing.T) {
	manager := NewShardedManager(100, 4)

	for _, id := range []string{"conn1", "conn2", "conn3"} {
		require.NoError(t, manager.Add(&mockConnectionInterface{id: id}))
	}

	allConns := manager.GetAll()
	assert.Len(t, allConns, 3, "Expected 3 connections")
}

// TestShardedManagerRemoveFromGroup checks that removing the only member
// leaves the group empty.
func TestShardedManagerRemoveFromGroup(t *testing.T) {
	manager := NewShardedManager(100, 4)

	conn1 := &mockConnectionInterface{id: "conn1"}
	require.NoError(t, manager.Add(conn1))
	require.NoError(t, manager.AddToGroup("group1", "conn1"))

	err := manager.RemoveFromGroup("group1", "conn1")
	require.NoError(t, err, "Expected no error when removing from group")

	group := manager.GetGroup("group1")
	assert.Empty(t, group, "Expected empty group after removal")
}

// TestShardedManagerGetNotFound checks that Get reports ErrConnectionNotFound
// for an unknown ID.
func TestShardedManagerGetNotFound(t *testing.T) {
	manager := NewShardedManager(100, 4)
	_, err := manager.Get("nonexistent")
	assert.Error(t, err, "Expected error for nonexistent connection")
	// ErrorIs (rather than Equal) keeps passing if the sentinel is ever wrapped.
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}

// TestShardedManagerRemoveNotFound checks that Remove reports
// ErrConnectionNotFound for an unknown ID.
func TestShardedManagerRemoveNotFound(t *testing.T) {
	manager := NewShardedManager(100, 4)
	err := manager.Remove("nonexistent")
	assert.Error(t, err, "Expected error for nonexistent connection")
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}

// TestShardedManagerAddToGroupNotFound checks that AddToGroup reports
// ErrConnectionNotFound for an unknown connection ID.
func TestShardedManagerAddToGroupNotFound(t *testing.T) {
	manager := NewShardedManager(100, 4)
	err := manager.AddToGroup("group1", "nonexistent")
	assert.Error(t, err, "Expected error for nonexistent connection")
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}

// TestShardedManagerAddNil checks that adding a nil connection is rejected.
func TestShardedManagerAddNil(t *testing.T) {
	manager := NewShardedManager(100, 4)
	err := manager.Add(nil)
	assert.Error(t, err, "Expected error for nil connection")
}

// TestShardedManagerCleanupInactiveInvalidTimeout passes a non-duration
// timeout and expects no connection to be removed.
func TestShardedManagerCleanupInactiveInvalidTimeout(t *testing.T) {
	manager := NewShardedManager(100, 4)

	conn1 := &mockConnectionInterface{id: "conn1"}
	require.NoError(t, manager.Add(conn1))

	// Invalid timeout type. This may need adjusting to match the actual
	// implementation.
	manager.CleanupInactive("invalid")

	// The connection should not have been removed (assuming the
	// implementation type-checks the timeout).
	assert.Equal(t, 1, manager.Count(), "Expected 1 connection")
}

// TestShardedManagerBroadcastError checks that BroadcastToGroup surfaces an
// error when one group member's write fails.
func TestShardedManagerBroadcastError(t *testing.T) {
	manager := NewShardedManager(100, 4)

	conn1 := &mockConnectionInterfaceWithError{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		writeError:              true,
	}
	conn2 := &mockConnectionInterface{id: "conn2"}
	require.NoError(t, manager.Add(conn1))
	require.NoError(t, manager.Add(conn2))
	require.NoError(t, manager.AddToGroup("group1", "conn1"))
	require.NoError(t, manager.AddToGroup("group1", "conn2"))

	// Broadcast a message (conn1's write will fail).
	err := manager.BroadcastToGroup("group1", []byte("test"))
	assert.Error(t, err, "Expected error for broadcast with write failure")
}

// TestShardedManagerConcurrent is intentionally skipped.
func TestShardedManagerConcurrent(t *testing.T) {
	t.Skip("Skipping concurrent test to avoid potential hangs")
	// This test could hang in some situations, so it was skipped manually.
	// manager := NewShardedManager(1000, 16)
	// ... (test code)
}

// TestShardedManagerGetGroupEmpty checks that GetGroup returns an empty
// result for a group that was never created.
func TestShardedManagerGetGroupEmpty(t *testing.T) {
	manager := NewShardedManager(100, 4)

	// Fetch a group that does not exist.
	group := manager.GetGroup("nonexistent")
	assert.Empty(t, group, "Expected empty group for nonexistent group")
}