You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

268 lines
7.3 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package connection
import (
"testing"
"time"
"github.com/noahlann/nnet/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConnectionManager exercises the basic Add/Get/Count/Remove lifecycle
// of a single connection on a Manager.
func TestConnectionManager(t *testing.T) {
	manager := NewManager(100)

	// Add a connection.
	conn := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn)
	require.NoError(t, err, "Expected no error when adding connection")

	// Look it up by ID.
	retrievedConn, err := manager.Get("conn1")
	require.NoError(t, err, "Expected no error when getting connection")
	assert.Equal(t, "conn1", retrievedConn.ID(), "Expected conn1")

	// The manager should now track exactly one connection.
	assert.Equal(t, 1, manager.Count(), "Expected 1 connection")

	// Remove it again.
	err = manager.Remove("conn1")
	require.NoError(t, err, "Expected no error when removing connection")
	assert.Equal(t, 0, manager.Count(), "Expected 0 connections")

	// A count of zero does not by itself prove that this particular ID was
	// dropped, so also check that it can no longer be looked up.
	_, err = manager.Get("conn1")
	assert.Error(t, err, "Expected error when getting a removed connection")
}
func TestConnectionManagerAddNil(t *testing.T) {
manager := NewManager(100)
// 测试添加 nil 连接
err := manager.Add(nil)
assert.Error(t, err, "Expected error for nil connection")
}
// TestConnectionManagerMaxConnections verifies that a Manager created with a
// limit of 2 accepts two connections, rejects a third, and does not store the
// rejected one.
func TestConnectionManagerMaxConnections(t *testing.T) {
	manager := NewManager(2)

	// The first connection fits within the limit.
	conn1 := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn1)
	require.NoError(t, err, "Expected no error when adding first connection")

	// The second connection reaches the limit exactly.
	conn2 := &mockConnectionInterface{id: "conn2"}
	err = manager.Add(conn2)
	require.NoError(t, err, "Expected no error when adding second connection")

	// A third connection exceeds the limit and must be rejected.
	conn3 := &mockConnectionInterface{id: "conn3"}
	err = manager.Add(conn3)
	assert.Error(t, err, "Expected error for exceeding max connections")

	// A rejected Add must leave the manager unchanged: still two connections,
	// and conn3 must not be retrievable.
	assert.Equal(t, 2, manager.Count(), "Expected count to stay at 2 after rejected add")
	_, err = manager.Get("conn3")
	assert.Error(t, err, "Expected rejected conn3 not to be retrievable")
}
// TestConnectionManagerGetNotFound verifies that looking up an unknown ID
// fails with errors.ErrConnectionNotFound.
func TestConnectionManagerGetNotFound(t *testing.T) {
	manager := NewManager(100)

	// Look up an ID that was never added.
	_, err := manager.Get("nonexistent")
	require.Error(t, err, "Expected error for nonexistent connection")
	// Match with errors.Is semantics rather than equality, so the test stays
	// valid if Get ever wraps the sentinel with extra context (e.g. %w).
	assert.ErrorIs(t, err, errors.ErrConnectionNotFound, "Expected ErrConnectionNotFound")
}
// TestConnectionManagerGetAll registers three connections and checks that
// GetAll reports all of them.
func TestConnectionManagerGetAll(t *testing.T) {
	manager := NewManager(100)

	// Register conn1..conn3 in order.
	for _, c := range []*mockConnectionInterface{
		{id: "conn1"},
		{id: "conn2"},
		{id: "conn3"},
	} {
		require.NoError(t, manager.Add(c))
	}

	// Every registered connection should be returned.
	assert.Len(t, manager.GetAll(), 3, "Expected 3 connections")
}
// TestConnectionGroup adds two registered connections to the same group and
// checks that GetGroup returns both of them.
func TestConnectionGroup(t *testing.T) {
	manager := NewManager(100)
	ids := []string{"conn1", "conn2"}

	// Register every connection before grouping.
	for _, id := range ids {
		require.NoError(t, manager.Add(&mockConnectionInterface{id: id}))
	}

	// Put each registered connection into "group1". The format arguments
	// expand to the per-connection failure message.
	for _, id := range ids {
		require.NoError(t, manager.AddToGroup("group1", id),
			"Expected no error when adding %s to group", id)
	}

	// Both members should be visible through GetGroup.
	assert.Len(t, manager.GetGroup("group1"), 2, "Expected 2 connections in group")
}
// TestConnectionGroupRemoveFromGroup verifies that RemoveFromGroup takes a
// connection out of a group without unregistering it from the manager.
func TestConnectionGroupRemoveFromGroup(t *testing.T) {
	manager := NewManager(100)
	conn1 := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn1)
	require.NoError(t, err)

	// Add the connection to a group.
	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err)

	// Remove it from that group.
	err = manager.RemoveFromGroup("group1", "conn1")
	require.NoError(t, err, "Expected no error when removing from group")

	// assert.Empty accepts both a nil and a zero-length result. Unlike
	// assert.True on a hand-written condition, it prints the leftover
	// members when it fails.
	assert.Empty(t, manager.GetGroup("group1"), "Expected empty group after removal")

	// Leaving a group should not remove the connection from the manager itself.
	assert.Equal(t, 1, manager.Count(), "Expected connection to stay registered after leaving the group")
}
// TestConnectionManagerCleanupInactive verifies that CleanupInactive removes
// a connection whose LastActive is older than the timeout, keeps a recent one,
// and also drops the removed connection from its groups.
func TestConnectionManagerCleanupInactive(t *testing.T) {
	manager := NewManager(100)
	// conn1 was last active 40s ago, which is stale for the 30s timeout used
	// below. conn2 is fresh.
	conn1 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		lastActive:              time.Now().Add(-40 * time.Second),
	}
	conn2 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn2"},
		lastActive:              time.Now(),
	}
	err := manager.Add(conn1)
	require.NoError(t, err)
	err = manager.Add(conn2)
	require.NoError(t, err)

	// Put conn1 into a group so that group cleanup is checked as well.
	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err)

	// Clean up connections that have been inactive for more than 30 seconds.
	manager.CleanupInactive(30 * time.Second)

	// conn1 should be removed and conn2 retained.
	assert.Equal(t, 1, manager.Count(), "Expected 1 connection after cleanup")
	_, err = manager.Get("conn1")
	assert.Error(t, err, "Expected conn1 to be removed")
	_, err = manager.Get("conn2")
	assert.NoError(t, err, "Expected conn2 to be retained")

	// conn1 should also be gone from group1. assert.Empty accepts both nil and
	// zero-length results and prints any leftovers on failure.
	assert.Empty(t, manager.GetGroup("group1"), "Expected empty group after cleanup")
}
// TestConnectionManagerRemoveFromGroups verifies that Manager.Remove also
// removes the connection from every group it belonged to.
func TestConnectionManagerRemoveFromGroups(t *testing.T) {
	manager := NewManager(100)
	conn1 := &mockConnectionInterface{id: "conn1"}
	err := manager.Add(conn1)
	require.NoError(t, err)

	// Add the connection to two groups.
	err = manager.AddToGroup("group1", "conn1")
	require.NoError(t, err)
	err = manager.AddToGroup("group2", "conn1")
	require.NoError(t, err)

	// Remove the connection from the manager entirely.
	err = manager.Remove("conn1")
	require.NoError(t, err)

	// It must no longer appear in any of its former groups. assert.Empty
	// treats nil and zero-length results alike and prints leftovers on failure.
	assert.Empty(t, manager.GetGroup("group1"), "Expected empty group1 after removal")
	assert.Empty(t, manager.GetGroup("group2"), "Expected empty group2 after removal")
}
// TestConnectionManagerBroadcastError verifies that BroadcastToGroup returns
// an error when the write to one member of the group fails.
func TestConnectionManagerBroadcastError(t *testing.T) {
	manager := NewManager(100)

	// conn1 is configured so that every Write fails. conn2 uses the plain
	// mock, whose Write always succeeds.
	failing := &mockConnectionInterfaceWithError{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		writeError:              true,
	}
	healthy := &mockConnectionInterface{id: "conn2"}

	require.NoError(t, manager.Add(failing))
	require.NoError(t, manager.Add(healthy))
	for _, id := range []string{"conn1", "conn2"} {
		require.NoError(t, manager.AddToGroup("group1", id))
	}

	// Broadcast to the group. The write to conn1 is expected to fail.
	assert.Error(t, manager.BroadcastToGroup("group1", []byte("test")),
		"Expected error for broadcast with write failure")
}
// All ShardedManager tests moved to sharded_manager_test.go
// mockConnectionInterface is a minimal connection stand-in for tests. It
// carries only an ID, reports a fixed loopback address for both ends, and its
// Write and Close are no-ops that always succeed.
type mockConnectionInterface struct {
	id string
}

// ID returns the identifier the mock was constructed with.
func (c *mockConnectionInterface) ID() string { return c.id }

// RemoteAddr always reports the fixed address "127.0.0.1:6995".
func (c *mockConnectionInterface) RemoteAddr() string { return "127.0.0.1:6995" }

// LocalAddr always reports the same fixed address as RemoteAddr.
func (c *mockConnectionInterface) LocalAddr() string { return "127.0.0.1:6995" }

// Write discards the payload and never fails.
func (c *mockConnectionInterface) Write(_ []byte) error { return nil }

// Close does nothing and never fails.
func (c *mockConnectionInterface) Close() error { return nil }
// mockConnectionInterfaceWithLastActive extends mockConnectionInterface with
// a fixed last-activity timestamp, exposed through LastActive.
type mockConnectionInterfaceWithLastActive struct {
	mockConnectionInterface
	lastActive time.Time
}

// LastActive returns the stored lastActive timestamp unchanged.
func (c *mockConnectionInterfaceWithLastActive) LastActive() time.Time { return c.lastActive }
// mockConnectionInterfaceWithError extends mockConnectionInterface with a
// flag that makes Write fail.
type mockConnectionInterfaceWithError struct {
	mockConnectionInterface
	writeError bool
}

// Write ignores the payload. It returns nil unless writeError is set, in
// which case it returns an error with the message "write error".
func (c *mockConnectionInterfaceWithError) Write(_ []byte) error {
	if !c.writeError {
		return nil
	}
	return errors.New("write error")
}