|
|
package connection
|
|
|
|
|
|
import (
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/pkg/errors"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
// TestConnectionManager walks a single connection through the basic
// Add -> Get -> Count -> Remove lifecycle of a Manager.
func TestConnectionManager(t *testing.T) {
	m := NewManager(100)

	// Adding a connection should succeed.
	c := &mockConnectionInterface{id: "conn1"}
	require.NoError(t, m.Add(c), "Expected no error when adding connection")

	// The connection must be retrievable by its ID.
	got, err := m.Get("conn1")
	require.NoError(t, err, "Expected no error when getting connection")
	assert.Equal(t, "conn1", got.ID(), "Expected conn1")

	// Exactly one connection is tracked.
	assert.Equal(t, 1, m.Count(), "Expected 1 connection")

	// Removing it should succeed and leave the manager empty.
	require.NoError(t, m.Remove("conn1"), "Expected no error when removing connection")
	assert.Equal(t, 0, m.Count(), "Expected 0 connections")
}
|
|
|
|
|
|
// TestConnectionManagerAddNil checks that Add rejects a nil connection.
func TestConnectionManagerAddNil(t *testing.T) {
	err := NewManager(100).Add(nil)
	assert.Error(t, err, "Expected error for nil connection")
}
|
|
|
|
|
|
func TestConnectionManagerMaxConnections(t *testing.T) {
|
|
|
manager := NewManager(2)
|
|
|
|
|
|
// 添加第一个连接
|
|
|
conn1 := &mockConnectionInterface{id: "conn1"}
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err, "Expected no error when adding first connection")
|
|
|
|
|
|
// 添加第二个连接
|
|
|
conn2 := &mockConnectionInterface{id: "conn2"}
|
|
|
err = manager.Add(conn2)
|
|
|
require.NoError(t, err, "Expected no error when adding second connection")
|
|
|
|
|
|
// 尝试添加第三个连接(应该失败)
|
|
|
conn3 := &mockConnectionInterface{id: "conn3"}
|
|
|
err = manager.Add(conn3)
|
|
|
assert.Error(t, err, "Expected error for exceeding max connections")
|
|
|
}
|
|
|
|
|
|
// TestConnectionManagerGetNotFound checks that looking up an ID that was never
// added fails with errors.ErrConnectionNotFound.
func TestConnectionManagerGetNotFound(t *testing.T) {
	manager := NewManager(100)

	// Look up an ID that was never added.
	_, err := manager.Get("nonexistent")
	require.Error(t, err, "Expected error for nonexistent connection")

	// ErrorIs (rather than Equal) still matches if the implementation wraps
	// the sentinel with additional context.
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}
|
|
|
|
|
|
// TestConnectionManagerGetAll adds several connections and checks that
// GetAll reports every one of them.
func TestConnectionManagerGetAll(t *testing.T) {
	manager := NewManager(100)

	// Register multiple connections.
	ids := []string{"conn1", "conn2", "conn3"}
	for _, id := range ids {
		require.NoError(t, manager.Add(&mockConnectionInterface{id: id}))
	}

	// GetAll must return all of them; assert.Len prints the contents on failure.
	allConns := manager.GetAll()
	assert.Len(t, allConns, len(ids), "Expected 3 connections")
}
|
|
|
|
|
|
// TestConnectionGroup adds two registered connections to the same group and
// checks that GetGroup reports both members.
func TestConnectionGroup(t *testing.T) {
	manager := NewManager(100)

	for _, id := range []string{"conn1", "conn2"} {
		require.NoError(t, manager.Add(&mockConnectionInterface{id: id}))
	}

	// Put both connections into the same group.
	require.NoError(t, manager.AddToGroup("group1", "conn1"), "Expected no error when adding conn1 to group")
	require.NoError(t, manager.AddToGroup("group1", "conn2"), "Expected no error when adding conn2 to group")

	// Fetch the group; assert.Len prints the contents on failure.
	group := manager.GetGroup("group1")
	assert.Len(t, group, 2, "Expected 2 connections in group")
}
|
|
|
|
|
|
// TestConnectionGroupRemoveFromGroup verifies that RemoveFromGroup detaches a
// connection so that the group no longer reports it.
func TestConnectionGroupRemoveFromGroup(t *testing.T) {
	manager := NewManager(100)
	require.NoError(t, manager.Add(&mockConnectionInterface{id: "conn1"}))

	// Add the connection to a group...
	require.NoError(t, manager.AddToGroup("group1", "conn1"))

	// ...then take it out again.
	require.NoError(t, manager.RemoveFromGroup("group1", "conn1"), "Expected no error when removing from group")

	// assert.Empty accepts both a nil and a zero-length group, matching the
	// previous "nil or len == 0" check while printing the contents on failure.
	assert.Empty(t, manager.GetGroup("group1"), "Expected empty group after removal")
}
|
|
|
|
|
|
func TestConnectionManagerCleanupInactive(t *testing.T) {
|
|
|
manager := NewManager(100)
|
|
|
|
|
|
conn1 := &mockConnectionInterfaceWithLastActive{
|
|
|
mockConnectionInterface: mockConnectionInterface{id: "conn1"},
|
|
|
lastActive: time.Now().Add(-40 * time.Second),
|
|
|
}
|
|
|
conn2 := &mockConnectionInterfaceWithLastActive{
|
|
|
mockConnectionInterface: mockConnectionInterface{id: "conn2"},
|
|
|
lastActive: time.Now(),
|
|
|
}
|
|
|
|
|
|
err := manager.Add(conn1)
|
|
|
require.NoError(t, err)
|
|
|
err = manager.Add(conn2)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 将 conn1 添加到分组
|
|
|
err = manager.AddToGroup("group1", "conn1")
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 清理非活动连接(30秒超时)
|
|
|
manager.CleanupInactive(30 * time.Second)
|
|
|
|
|
|
// 验证 conn1 被移除,conn2 保留
|
|
|
assert.Equal(t, 1, manager.Count(), "Expected 1 connection after cleanup")
|
|
|
|
|
|
_, err = manager.Get("conn1")
|
|
|
assert.Error(t, err, "Expected conn1 to be removed")
|
|
|
|
|
|
_, err = manager.Get("conn2")
|
|
|
assert.NoError(t, err, "Expected conn2 to be retained")
|
|
|
|
|
|
// 验证 conn1 从分组中移除
|
|
|
group := manager.GetGroup("group1")
|
|
|
assert.True(t, group == nil || len(group) == 0, "Expected empty group after cleanup")
|
|
|
}
|
|
|
|
|
|
// TestConnectionManagerRemoveFromGroups checks that removing a connection from
// the manager also removes it from every group it belonged to.
func TestConnectionManagerRemoveFromGroups(t *testing.T) {
	manager := NewManager(100)
	require.NoError(t, manager.Add(&mockConnectionInterface{id: "conn1"}))

	// Add the same connection to several groups.
	groups := []string{"group1", "group2"}
	for _, g := range groups {
		require.NoError(t, manager.AddToGroup(g, "conn1"))
	}

	// Remove the connection from the manager entirely.
	require.NoError(t, manager.Remove("conn1"))

	// Every group it belonged to must now be empty (nil or zero-length).
	for _, g := range groups {
		assert.Empty(t, manager.GetGroup(g), "Expected empty %s after removal", g)
	}
}
|
|
|
|
|
|
// TestConnectionManagerBroadcastError checks that BroadcastToGroup reports an
// error when a group member's Write fails, even though another member's Write
// succeeds.
func TestConnectionManagerBroadcastError(t *testing.T) {
	manager := NewManager(100)

	failing := &mockConnectionInterfaceWithError{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		writeError:              true,
	}
	healthy := &mockConnectionInterface{id: "conn2"}

	require.NoError(t, manager.Add(failing))
	require.NoError(t, manager.Add(healthy))

	for _, id := range []string{"conn1", "conn2"} {
		require.NoError(t, manager.AddToGroup("group1", id))
	}

	// Broadcast to the group; conn1's Write is rigged to fail.
	broadcastErr := manager.BroadcastToGroup("group1", []byte("test"))
	assert.Error(t, broadcastErr, "Expected error for broadcast with write failure")
}
|
|
|
|
|
|
// All ShardedManager tests moved to sharded_manager_test.go
|
|
|
|
|
|
// mockConnectionInterface is a minimal in-memory stand-in for a connection.
// Only the ID is configurable; both addresses are fixed, and Write and Close
// always succeed without doing anything.
type mockConnectionInterface struct {
	id string // value returned by ID()
}

// ID returns the configured identifier.
func (m *mockConnectionInterface) ID() string { return m.id }

// RemoteAddr returns a fixed loopback address.
func (m *mockConnectionInterface) RemoteAddr() string { return "127.0.0.1:6995" }

// LocalAddr returns the same fixed loopback address as RemoteAddr.
func (m *mockConnectionInterface) LocalAddr() string { return "127.0.0.1:6995" }

// Write discards the payload and reports success.
func (m *mockConnectionInterface) Write(_ []byte) error { return nil }

// Close is a no-op that reports success.
func (m *mockConnectionInterface) Close() error { return nil }
|
|
|
|
|
|
// mockConnectionInterfaceWithLastActive 带最后活动时间的连接
|
|
|
type mockConnectionInterfaceWithLastActive struct {
|
|
|
mockConnectionInterface
|
|
|
lastActive time.Time
|
|
|
}
|
|
|
|
|
|
func (m *mockConnectionInterfaceWithLastActive) LastActive() time.Time {
|
|
|
return m.lastActive
|
|
|
}
|
|
|
|
|
|
// mockConnectionInterfaceWithError 带错误的连接
|
|
|
type mockConnectionInterfaceWithError struct {
|
|
|
mockConnectionInterface
|
|
|
writeError bool
|
|
|
}
|
|
|
|
|
|
func (m *mockConnectionInterfaceWithError) Write(data []byte) error {
|
|
|
if m.writeError {
|
|
|
return errors.New("write error")
|
|
|
}
|
|
|
return nil
|
|
|
}
|