|
|
package connection
|
|
|
|
|
|
import (
|
|
|
"fmt"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/pkg/errors"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
func TestShardedManager(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
// 测试添加连接
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err, "Expected no error when adding connection")
|
|
|
|
|
|
// 测试获取连接
|
|
|
retrievedConn, err := manager.Get("conn1")
|
|
|
require.NoError(t, err, "Expected no error when getting connection")
|
|
|
assert.Equal(t, "conn1", retrievedConn.ID(), "Expected conn1")
|
|
|
|
|
|
// 测试连接数
|
|
|
assert.Equal(t, 1, manager.Count(), "Expected 1 connection")
|
|
|
|
|
|
// 测试移除连接
|
|
|
err = manager.Remove("conn1")
|
|
|
require.NoError(t, err, "Expected no error when removing connection")
|
|
|
assert.Equal(t, 0, manager.Count(), "Expected 0 connections")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerMaxConnections(t *testing.T) {
|
|
|
manager := NewShardedManager(2, 4)
|
|
|
|
|
|
// 添加第一个连接
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err, "Expected no error when adding first connection")
|
|
|
|
|
|
// 添加第二个连接
|
|
|
conn2 := &mockConnectionInterface{id: "conn2"}
|
|
|
err = manager.Add(conn2)
|
|
|
require.NoError(t, err, "Expected no error when adding second connection")
|
|
|
|
|
|
// 尝试添加第三个连接(应该失败)
|
|
|
conn3 := &mockConnectionInterface{id: "conn3"}
|
|
|
err = manager.Add(conn3)
|
|
|
assert.Error(t, err, "Expected error for exceeding max connections")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerMaxConnectionsConcurrent(t *testing.T) {
|
|
|
const maxConns = 10
|
|
|
manager := NewShardedManager(maxConns, 4)
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
var success atomic.Int64
|
|
|
var unexpectedError atomic.Value
|
|
|
|
|
|
totalAttempts := 100
|
|
|
for i := 0; i < totalAttempts; i++ {
|
|
|
wg.Add(1)
|
|
|
go func(idx int) {
|
|
|
defer wg.Done()
|
|
|
conn := &mockConnectionInterface{id: fmt.Sprintf("conn-%d", idx)}
|
|
|
err := manager.Add(conn)
|
|
|
if err != nil {
|
|
|
if !strings.Contains(err.Error(), "max connections exceeded") {
|
|
|
unexpectedError.Store(err)
|
|
|
}
|
|
|
return
|
|
|
}
|
|
|
success.Add(1)
|
|
|
}(i)
|
|
|
}
|
|
|
|
|
|
wg.Wait()
|
|
|
if v := unexpectedError.Load(); v != nil {
|
|
|
t.Fatalf("unexpected error: %v", v)
|
|
|
}
|
|
|
assert.Equal(t, maxConns, manager.Count(), "expected count to respect max connections under concurrency")
|
|
|
assert.Equal(t, int64(maxConns), success.Load(), "expected only maxConns successful additions")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerGroups(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
conn2 := &mockConnectionInterface{id: "conn2"}
|
|
|
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
err = manager.Add(conn2)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 添加到分组
|
|
|
err = manager.AddToGroup("group1", "conn1")
|
|
|
require.NoError(t, err, "Expected no error when adding conn1 to group")
|
|
|
err = manager.AddToGroup("group1", "conn2")
|
|
|
require.NoError(t, err, "Expected no error when adding conn2 to group")
|
|
|
|
|
|
// 获取分组
|
|
|
group := manager.GetGroup("group1")
|
|
|
assert.Equal(t, 2, len(group), "Expected 2 connections in group")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerCleanupInactive(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
conn1 := &mockConnectionInterfaceWithLastActive{
|
|
|
mockConnectionInterface: mockConnectionInterface{id: "conn1"},
|
|
|
lastActive: time.Now().Add(-40 * time.Second),
|
|
|
}
|
|
|
conn2 := &mockConnectionInterfaceWithLastActive{
|
|
|
mockConnectionInterface: mockConnectionInterface{id: "conn2"},
|
|
|
lastActive: time.Now(),
|
|
|
}
|
|
|
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
err = manager.Add(conn2)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 将 conn1 添加到分组
|
|
|
err = manager.AddToGroup("group1", "conn1")
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 清理非活动连接
|
|
|
manager.CleanupInactive(30 * time.Second)
|
|
|
|
|
|
// 验证 conn1 被移除
|
|
|
assert.Equal(t, 1, manager.Count(), "Expected 1 connection after cleanup")
|
|
|
|
|
|
_, err = manager.Get("conn1")
|
|
|
assert.Error(t, err, "Expected conn1 to be removed")
|
|
|
|
|
|
_, err = manager.Get("conn2")
|
|
|
assert.NoError(t, err, "Expected conn2 to be retained")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerDefaultShardCount(t *testing.T) {
|
|
|
// 测试默认分片数 - 需要通过反射或添加getter方法
|
|
|
// 这里我们通过行为测试:创建manager并添加连接,验证它正常工作
|
|
|
manager := NewShardedManager(100, 0)
|
|
|
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
assert.Equal(t, 1, manager.Count())
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerMaxShardCount(t *testing.T) {
|
|
|
// 测试最大分片数 - 通过行为测试
|
|
|
manager := NewShardedManager(100, 500)
|
|
|
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
assert.Equal(t, 1, manager.Count())
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerGetAll(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
conn2 := &mockConnectionInterface{id: "conn2"}
|
|
|
conn3 := &mockConnectionInterface{id: "conn3"}
|
|
|
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
err = manager.Add(conn2)
|
|
|
require.NoError(t, err)
|
|
|
err = manager.Add(conn3)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
allConns := manager.GetAll()
|
|
|
assert.Equal(t, 3, len(allConns), "Expected 3 connections")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerRemoveFromGroup(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
err = manager.AddToGroup("group1", "conn1")
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
err = manager.RemoveFromGroup("group1", "conn1")
|
|
|
require.NoError(t, err, "Expected no error when removing from group")
|
|
|
|
|
|
group := manager.GetGroup("group1")
|
|
|
assert.True(t, group == nil || len(group) == 0, "Expected empty group after removal")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerGetNotFound(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
_, err := manager.Get("nonexistent")
|
|
|
assert.Error(t, err, "Expected error for nonexistent connection")
|
|
|
assert.Equal(t, errors.ErrConnectionNotFound, err, "Expected ErrConnectionNotFound")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerRemoveNotFound(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
err := manager.Remove("nonexistent")
|
|
|
assert.Error(t, err, "Expected error for nonexistent connection")
|
|
|
assert.Equal(t, errors.ErrConnectionNotFound, err, "Expected ErrConnectionNotFound")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerAddToGroupNotFound(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
err := manager.AddToGroup("group1", "nonexistent")
|
|
|
assert.Error(t, err, "Expected error for nonexistent connection")
|
|
|
assert.Equal(t, errors.ErrConnectionNotFound, err, "Expected ErrConnectionNotFound")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerAddNil(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
err := manager.Add(nil)
|
|
|
assert.Error(t, err, "Expected error for nil connection")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerCleanupInactiveInvalidTimeout(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 测试无效的超时类型 - 这个测试可能需要根据实际实现调整
|
|
|
manager.CleanupInactive("invalid")
|
|
|
|
|
|
// 验证连接没有被移除(如果实现支持类型检查)
|
|
|
assert.Equal(t, 1, manager.Count(), "Expected 1 connection")
|
|
|
}
|
|
|
|
|
|
func TestShardedManagerBroadcastError(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
conn1 := &mockConnectionInterfaceWithError{
|
|
|
mockConnectionInterface: mockConnectionInterface{id: "conn1"},
|
|
|
writeError: true,
|
|
|
}
|
|
|
conn2 := &mockConnectionInterface{id: "conn2"}
|
|
|
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
err = manager.Add(conn2)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
err = manager.AddToGroup("group1", "conn1")
|
|
|
require.NoError(t, err)
|
|
|
err = manager.AddToGroup("group1", "conn2")
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 广播消息(conn1 会失败)
|
|
|
err = manager.BroadcastToGroup("group1", []byte("test"))
|
|
|
assert.Error(t, err, "Expected error for broadcast with write failure")
|
|
|
}
|
|
|
|
|
|
// TestShardedManagerConcurrent is a placeholder for a mixed-operation
// concurrency test. It was skipped by hand because the original version could
// hang in some runs.
//
// NOTE(review): the original body is not present here. Concurrent Add
// behavior is covered by TestShardedManagerMaxConnectionsConcurrent.
func TestShardedManagerConcurrent(t *testing.T) {
	t.Skip("Skipping concurrent test to avoid potential hangs")
}
|
|
|
|
|
|
func TestShardedManagerGetGroupEmpty(t *testing.T) {
|
|
|
manager := NewShardedManager(100, 4)
|
|
|
|
|
|
// 测试获取不存在的分组
|
|
|
group := manager.GetGroup("nonexistent")
|
|
|
assert.True(t, group == nil || len(group) == 0, "Expected empty group for nonexistent group")
|
|
|
}
|