You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

285 lines
8.3 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package connection
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/noahlann/nnet/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestShardedManager walks a single connection through the basic lifecycle of
// a ShardedManager: Add, Get by ID, Count, and Remove.
func TestShardedManager(t *testing.T) {
	m := NewShardedManager(100, 4)

	// Register one connection.
	c := &mockConnectionInterface{id: "conn1"}
	require.NoError(t, m.Add(c), "Expected no error when adding connection")

	// It must be retrievable by its ID.
	got, lookupErr := m.Get("conn1")
	require.NoError(t, lookupErr, "Expected no error when getting connection")
	assert.Equal(t, "conn1", got.ID(), "Expected conn1")

	// Exactly one connection is tracked.
	assert.Equal(t, 1, m.Count(), "Expected 1 connection")

	// Removing it brings the count back to zero.
	require.NoError(t, m.Remove("conn1"), "Expected no error when removing connection")
	assert.Equal(t, 0, m.Count(), "Expected 0 connections")
}
// TestShardedManagerMaxConnections fills a manager capped at two connections
// and checks that a third Add is rejected with an error.
func TestShardedManagerMaxConnections(t *testing.T) {
	m := NewShardedManager(2, 4)

	// Fill the manager up to its limit.
	for _, tc := range []struct{ id, msg string }{
		{"conn1", "Expected no error when adding first connection"},
		{"conn2", "Expected no error when adding second connection"},
	} {
		require.NoError(t, m.Add(&mockConnectionInterface{id: tc.id}), tc.msg)
	}

	// One more connection exceeds the cap and must be refused.
	overflow := &mockConnectionInterface{id: "conn3"}
	assert.Error(t, m.Add(overflow), "Expected error for exceeding max connections")
}
// TestShardedManagerMaxConnectionsConcurrent races totalAttempts goroutines
// calling Add against a manager capped at maxConns. It checks three things:
// the cap is never overshot, exactly maxConns adds succeed, and every rejected
// add fails with a "max connections exceeded" error.
func TestShardedManagerMaxConnectionsConcurrent(t *testing.T) {
	const (
		maxConns      = 10
		totalAttempts = 100
	)
	manager := NewShardedManager(maxConns, 4)

	var (
		wg      sync.WaitGroup
		success atomic.Int64

		// Unexpected errors are collected under a mutex instead of an
		// atomic.Value. atomic.Value.Store panics when values of different
		// concrete types are stored, and it would keep only the last error.
		mu         sync.Mutex
		unexpected []error
	)

	for i := 0; i < totalAttempts; i++ {
		wg.Add(1)
		go func(idx int) {
			defer wg.Done()
			conn := &mockConnectionInterface{id: fmt.Sprintf("conn-%d", idx)}
			if err := manager.Add(conn); err != nil {
				// Rejections for exceeding the cap are expected; anything else is a failure.
				// NOTE(review): matched by message text; prefer errors.Is against the
				// package's sentinel error if one exists.
				if !strings.Contains(err.Error(), "max connections exceeded") {
					mu.Lock()
					unexpected = append(unexpected, err)
					mu.Unlock()
				}
				return
			}
			success.Add(1)
		}(i)
	}
	// wg.Wait establishes happens-before with every goroutine, so reading
	// `unexpected` below without the mutex is safe.
	wg.Wait()

	if len(unexpected) > 0 {
		t.Fatalf("unexpected errors (%d): %v", len(unexpected), unexpected)
	}
	assert.Equal(t, maxConns, manager.Count(), "expected count to respect max connections under concurrency")
	assert.Equal(t, int64(maxConns), success.Load(), "expected only maxConns successful additions")
}
// TestShardedManagerGroups registers two connections, adds both to the same
// named group, and checks that GetGroup reports two members.
func TestShardedManagerGroups(t *testing.T) {
	m := NewShardedManager(100, 4)
	ids := []string{"conn1", "conn2"}

	for _, id := range ids {
		require.NoError(t, m.Add(&mockConnectionInterface{id: id}))
	}

	// Place both connections into "group1".
	for _, id := range ids {
		require.NoError(t, m.AddToGroup("group1", id), "Expected no error when adding %s to group", id)
	}

	// The group should now contain both of them.
	assert.Len(t, m.GetGroup("group1"), 2, "Expected 2 connections in group")
}
// TestShardedManagerCleanupInactive checks what CleanupInactive does with a
// 30s timeout. A connection idle for 40s must be evicted, a fresh one must be
// kept, and the evicted connection must no longer appear in the group it
// belonged to.
func TestShardedManagerCleanupInactive(t *testing.T) {
	manager := NewShardedManager(100, 4)

	// conn1 was last active 40s ago (stale for a 30s timeout); conn2 is fresh.
	conn1 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		lastActive:              time.Now().Add(-40 * time.Second),
	}
	conn2 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn2"},
		lastActive:              time.Now(),
	}
	err := manager.Add(conn1)
	require.NoError(t, err)
	err = manager.Add(conn2)
	require.NoError(t, err)

	// Put conn1 into a group so cleanup's effect on group membership can be verified.
	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err)

	// Evict connections idle for longer than 30s.
	manager.CleanupInactive(30 * time.Second)

	// Only conn1 should have been removed.
	assert.Equal(t, 1, manager.Count(), "Expected 1 connection after cleanup")
	_, err = manager.Get("conn1")
	assert.Error(t, err, "Expected conn1 to be removed")
	_, err = manager.Get("conn2")
	assert.NoError(t, err, "Expected conn2 to be retained")

	// The evicted connection must not linger in its group.
	group := manager.GetGroup("group1")
	assert.True(t, group == nil || len(group) == 0, "Expected conn1 to be removed from group1 after cleanup")
}
// TestShardedManagerDefaultShardCount builds a manager with a shard count of 0
// (presumably replaced by a default; confirm against NewShardedManager). The
// effective shard count is not exposed, so the test only checks behavior: one
// Add succeeds and Count reports 1.
func TestShardedManagerDefaultShardCount(t *testing.T) {
	m := NewShardedManager(100, 0)
	require.NoError(t, m.Add(&mockConnectionInterface{id: "conn1"}))
	assert.Equal(t, 1, m.Count())
}
// TestShardedManagerMaxShardCount builds a manager with a very large shard
// count (500; presumably clamped to an internal maximum, to be confirmed). It
// checks only that one Add succeeds and Count reports 1.
func TestShardedManagerMaxShardCount(t *testing.T) {
	m := NewShardedManager(100, 500)
	require.NoError(t, m.Add(&mockConnectionInterface{id: "conn1"}))
	assert.Equal(t, 1, m.Count())
}
// TestShardedManagerGetAll adds three connections and checks that GetAll
// returns all three, regardless of which shards they landed in.
func TestShardedManagerGetAll(t *testing.T) {
	m := NewShardedManager(100, 4)
	for _, id := range []string{"conn1", "conn2", "conn3"} {
		require.NoError(t, m.Add(&mockConnectionInterface{id: id}))
	}
	assert.Len(t, m.GetAll(), 3, "Expected 3 connections")
}
// TestShardedManagerRemoveFromGroup adds a connection to a group, removes it
// again, and checks that the group ends up nil or empty.
func TestShardedManagerRemoveFromGroup(t *testing.T) {
	m := NewShardedManager(100, 4)
	require.NoError(t, m.Add(&mockConnectionInterface{id: "conn1"}))
	require.NoError(t, m.AddToGroup("group1", "conn1"))

	require.NoError(t, m.RemoveFromGroup("group1", "conn1"), "Expected no error when removing from group")

	// assert.Empty accepts both a nil and a zero-length result.
	assert.Empty(t, m.GetGroup("group1"), "Expected empty group after removal")
}
// TestShardedManagerGetNotFound checks that Get on an unknown ID fails with
// errors.ErrConnectionNotFound. assert.ErrorIs (errors.Is semantics) is used
// instead of assert.Equal so the check still holds if the implementation wraps
// the sentinel with extra context. It also fails when err is nil.
func TestShardedManagerGetNotFound(t *testing.T) {
	manager := NewShardedManager(100, 4)
	_, err := manager.Get("nonexistent")
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}
// TestShardedManagerRemoveNotFound checks that removing an unknown ID fails
// with errors.ErrConnectionNotFound. It uses errors.Is semantics (via
// assert.ErrorIs) so a wrapped sentinel is still accepted, and it fails when
// err is nil.
func TestShardedManagerRemoveNotFound(t *testing.T) {
	manager := NewShardedManager(100, 4)
	err := manager.Remove("nonexistent")
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}
// TestShardedManagerAddToGroupNotFound checks that adding an unregistered
// connection ID to a group fails with errors.ErrConnectionNotFound. It uses
// errors.Is semantics (via assert.ErrorIs) so a wrapped sentinel is still
// accepted, and it fails when err is nil.
func TestShardedManagerAddToGroupNotFound(t *testing.T) {
	manager := NewShardedManager(100, 4)
	err := manager.AddToGroup("group1", "nonexistent")
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}
// TestShardedManagerAddNil ensures Add refuses a nil connection with an error.
func TestShardedManagerAddNil(t *testing.T) {
	m := NewShardedManager(100, 4)
	assert.Error(t, m.Add(nil), "Expected error for nil connection")
}
// TestShardedManagerCleanupInactiveInvalidTimeout passes a non-duration value
// ("invalid") as the timeout and expects CleanupInactive to leave the existing
// connection in place.
// NOTE(review): this relies on CleanupInactive taking an untyped timeout
// argument and ignoring values it cannot interpret. Adjust if the
// implementation changes.
func TestShardedManagerCleanupInactiveInvalidTimeout(t *testing.T) {
	m := NewShardedManager(100, 4)
	require.NoError(t, m.Add(&mockConnectionInterface{id: "conn1"}))

	m.CleanupInactive("invalid")

	// Nothing should have been evicted.
	assert.Equal(t, 1, m.Count(), "Expected 1 connection")
}
// TestShardedManagerBroadcastError puts a connection configured with
// writeError: true into a group alongside a normal one. It expects
// BroadcastToGroup to return an error.
func TestShardedManagerBroadcastError(t *testing.T) {
	m := NewShardedManager(100, 4)

	failing := &mockConnectionInterfaceWithError{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		writeError:              true,
	}
	healthy := &mockConnectionInterface{id: "conn2"}

	require.NoError(t, m.Add(failing))
	require.NoError(t, m.Add(healthy))
	for _, id := range []string{"conn1", "conn2"} {
		require.NoError(t, m.AddToGroup("group1", id))
	}

	// Broadcast a message (the write to conn1 is expected to fail).
	assert.Error(t, m.BroadcastToGroup("group1", []byte("test")), "Expected error for broadcast with write failure")
}
// TestShardedManagerConcurrent is a placeholder for a concurrent stress test
// of the manager. Its body was removed and it is always skipped, because the
// original author found it could hang in some cases.
func TestShardedManagerConcurrent(t *testing.T) {
	// Manually skipped by the author; restore the body only once the hang is understood.
	// Original outline:
	// manager := NewShardedManager(1000, 16)
	// ... (test code)
	t.Skip("Skipping concurrent test to avoid potential hangs")
}
// TestShardedManagerGetGroupEmpty checks that asking for a group that was
// never created yields a nil or empty result.
func TestShardedManagerGetGroupEmpty(t *testing.T) {
	m := NewShardedManager(100, 4)
	// assert.Empty accepts both a nil and a zero-length result.
	assert.Empty(t, m.GetGroup("nonexistent"), "Expected empty group for nonexistent group")
}