You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

338 lines
8.3 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package test
import (
"testing"
"time"
"github.com/noahlann/nnet/internal/connection"
"github.com/noahlann/nnet/pkg/errors"
)
// TestConnectionManager walks a single connection through the manager's basic
// lifecycle: Add, Get, Count and Remove.
func TestConnectionManager(t *testing.T) {
	manager := connection.NewManager(100)

	// Add a connection. Every later step depends on it, so abort on failure.
	conn := &mockConnectionInterface{id: "conn1"}
	if err := manager.Add(conn); err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}

	// Look the connection up again. A failed lookup must stop the test here:
	// the value returned alongside an error may be nil, and calling ID() on
	// it would panic instead of reporting a clean failure.
	retrievedConn, err := manager.Get("conn1")
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}
	if retrievedConn.ID() != "conn1" {
		t.Error("Expected conn1")
	}

	// Exactly one connection should be registered.
	if manager.Count() != 1 {
		t.Errorf("Expected 1 connection, got: %d", manager.Count())
	}

	// Remove the connection and confirm the manager is empty again.
	if err := manager.Remove("conn1"); err != nil {
		t.Errorf("Expected no error, got: %v", err)
	}
	if manager.Count() != 0 {
		t.Errorf("Expected 0 connections, got: %d", manager.Count())
	}
}
// TestConnectionManagerAddNil checks that Add reports an error when it is
// handed a nil connection.
func TestConnectionManagerAddNil(t *testing.T) {
	mgr := connection.NewManager(100)

	// Adding nil must not succeed silently.
	if err := mgr.Add(nil); err == nil {
		t.Error("Expected error for nil connection")
	}
}
// TestConnectionManagerMaxConnections builds a manager with NewManager(2),
// expects the first two Add calls to succeed, and expects the third to fail.
func TestConnectionManagerMaxConnections(t *testing.T) {
	mgr := connection.NewManager(2)

	// The first and second connections are expected to be accepted.
	for _, id := range []string{"conn1", "conn2"} {
		if err := mgr.Add(&mockConnectionInterface{id: id}); err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
	}

	// A third connection should be refused.
	if err := mgr.Add(&mockConnectionInterface{id: "conn3"}); err == nil {
		t.Error("Expected error for exceeding max connections")
	}
}
// TestConnectionManagerGetNotFound checks that Get on an unknown ID fails
// with errors.ErrConnectionNotFound.
func TestConnectionManagerGetNotFound(t *testing.T) {
	mgr := connection.NewManager(100)

	// Look up an ID that was never added.
	_, getErr := mgr.Get("nonexistent")

	if getErr == nil {
		t.Error("Expected error for nonexistent connection")
	}
	// The error must be the package's not-found sentinel.
	if getErr != errors.ErrConnectionNotFound {
		t.Errorf("Expected ErrConnectionNotFound, got: %v", getErr)
	}
}
// TestConnectionManagerRemoveNotFound checks that Remove on an unknown ID
// fails with errors.ErrConnectionNotFound.
func TestConnectionManagerRemoveNotFound(t *testing.T) {
	mgr := connection.NewManager(100)

	// Try to remove an ID that was never added.
	removeErr := mgr.Remove("nonexistent")

	if removeErr == nil {
		t.Error("Expected error for nonexistent connection")
	}
	// The error must be the package's not-found sentinel.
	if removeErr != errors.ErrConnectionNotFound {
		t.Errorf("Expected ErrConnectionNotFound, got: %v", removeErr)
	}
}
// TestConnectionManagerGetAll registers three connections and checks that
// GetAll returns all of them.
func TestConnectionManagerGetAll(t *testing.T) {
	manager := connection.NewManager(100)

	// Register several connections. A failed Add would make the count check
	// below misleading, so treat it as a fatal setup error.
	for _, id := range []string{"conn1", "conn2", "conn3"} {
		if err := manager.Add(&mockConnectionInterface{id: id}); err != nil {
			t.Fatalf("Setup: adding %s failed: %v", id, err)
		}
	}

	// Fetch every registered connection.
	allConns := manager.GetAll()
	if len(allConns) != 3 {
		t.Errorf("Expected 3 connections, got: %d", len(allConns))
	}
}
// TestConnectionGroup adds two registered connections to one group and checks
// that GetGroup reports both members.
func TestConnectionGroup(t *testing.T) {
	manager := connection.NewManager(100)
	ids := []string{"conn1", "conn2"}

	// Register the connections first. Grouping is meaningless if this fails,
	// so stop immediately on a setup error.
	for _, id := range ids {
		if err := manager.Add(&mockConnectionInterface{id: id}); err != nil {
			t.Fatalf("Setup: adding %s failed: %v", id, err)
		}
	}

	// Put both connections into the same group.
	for _, id := range ids {
		if err := manager.AddToGroup("group1", id); err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
	}

	// Read the group back and verify its size.
	group := manager.GetGroup("group1")
	if len(group) != 2 {
		t.Errorf("Expected 2 connections in group, got: %d", len(group))
	}
}
// TestConnectionGroupNotFound checks that GetGroup yields no members for a
// group name that was never created.
func TestConnectionGroupNotFound(t *testing.T) {
	mgr := connection.NewManager(100)

	// len of a nil result is 0, so a single length check covers both the nil
	// and the empty case.
	if members := mgr.GetGroup("nonexistent"); len(members) != 0 {
		t.Error("Expected empty group for nonexistent group")
	}
}
// TestConnectionGroupAddToGroupNotFound checks that AddToGroup fails with
// errors.ErrConnectionNotFound when the connection ID is not registered.
func TestConnectionGroupAddToGroupNotFound(t *testing.T) {
	mgr := connection.NewManager(100)

	// Try to group a connection that was never added to the manager.
	addErr := mgr.AddToGroup("group1", "nonexistent")

	if addErr == nil {
		t.Error("Expected error for nonexistent connection")
	}
	// The error must be the package's not-found sentinel.
	if addErr != errors.ErrConnectionNotFound {
		t.Errorf("Expected ErrConnectionNotFound, got: %v", addErr)
	}
}
// TestConnectionGroupRemoveFromGroup adds a connection to a group, removes it
// again, and checks that the group ends up empty.
func TestConnectionGroupRemoveFromGroup(t *testing.T) {
	manager := connection.NewManager(100)

	// Setup: register the connection and place it in the group. If either
	// step fails, the removal check below would pass or fail for the wrong
	// reason, so abort right away.
	conn1 := &mockConnectionInterface{id: "conn1"}
	if err := manager.Add(conn1); err != nil {
		t.Fatalf("Setup: adding conn1 failed: %v", err)
	}
	if err := manager.AddToGroup("group1", "conn1"); err != nil {
		t.Fatalf("Setup: adding conn1 to group1 failed: %v", err)
	}

	// Remove the connection from the group.
	if err := manager.RemoveFromGroup("group1", "conn1"); err != nil {
		t.Errorf("Expected no error, got: %v", err)
	}

	// The group must now have no members (a nil result also has length 0).
	group := manager.GetGroup("group1")
	if len(group) != 0 {
		t.Error("Expected empty group after removal")
	}
}
// TestConnectionGroupRemoveFromGroupNotFound checks that RemoveFromGroup
// returns no error when the named group does not exist.
func TestConnectionGroupRemoveFromGroupNotFound(t *testing.T) {
	mgr := connection.NewManager(100)

	// Removing from a group that was never created is expected to succeed.
	if err := mgr.RemoveFromGroup("nonexistent", "conn1"); err != nil {
		t.Errorf("Expected no error for nonexistent group, got: %v", err)
	}
}
// TestConnectionGroupBroadcast puts two connections into one group, broadcasts
// a payload to the group, and checks that each connection received it.
func TestConnectionGroupBroadcast(t *testing.T) {
	manager := connection.NewManager(100)

	// Each mock gets a one-slot buffered channel so that a single Write can
	// complete before this test starts reading.
	conns := []*mockConnectionInterface{
		{id: "conn1", writeData: make(chan []byte, 1)},
		{id: "conn2", writeData: make(chan []byte, 1)},
	}

	// Setup: register both connections, then add both to the group. A failure
	// here would surface later as a confusing timeout, so abort immediately.
	for _, c := range conns {
		if err := manager.Add(c); err != nil {
			t.Fatalf("Setup: adding %s failed: %v", c.id, err)
		}
	}
	for _, c := range conns {
		if err := manager.AddToGroup("group1", c.id); err != nil {
			t.Fatalf("Setup: adding %s to group1 failed: %v", c.id, err)
		}
	}

	// Broadcast one message to the whole group.
	data := []byte("broadcast message")
	if err := manager.BroadcastToGroup("group1", data); err != nil {
		t.Errorf("Expected no error, got: %v", err)
	}

	// Every member must receive the payload; wait at most 100ms per member.
	for _, c := range conns {
		select {
		case msg := <-c.writeData:
			if string(msg) != string(data) {
				t.Errorf("Expected %s, got %s", string(data), string(msg))
			}
		case <-time.After(100 * time.Millisecond):
			t.Errorf("Timeout waiting for message on %s", c.id)
		}
	}
}
// TestConnectionGroupBroadcastEmptyGroup checks that BroadcastToGroup returns
// no error for a group that has no members.
func TestConnectionGroupBroadcastEmptyGroup(t *testing.T) {
	mgr := connection.NewManager(100)

	// Broadcasting to a group nobody joined is expected to succeed.
	if err := mgr.BroadcastToGroup("empty_group", []byte("test")); err != nil {
		t.Errorf("Expected no error for empty group, got: %v", err)
	}
}
// TestConnectionManagerCleanupInactive registers one connection last active
// 40 seconds ago and one active now, runs CleanupInactive with a 30-second
// timeout, and checks that only the stale connection was removed.
func TestConnectionManagerCleanupInactive(t *testing.T) {
	manager := connection.NewManager(100)

	// conn1 is stale: its last activity is older than the timeout used below.
	conn1 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		lastActive:              time.Now().Add(-40 * time.Second),
	}
	// conn2 is fresh: it was active just now.
	conn2 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn2"},
		lastActive:              time.Now(),
	}

	// Setup: both connections must be registered, otherwise the count and
	// lookup checks after the cleanup prove nothing.
	if err := manager.Add(conn1); err != nil {
		t.Fatalf("Setup: adding conn1 failed: %v", err)
	}
	if err := manager.Add(conn2); err != nil {
		t.Fatalf("Setup: adding conn2 failed: %v", err)
	}

	// Clean up inactive connections using a 30-second timeout.
	manager.CleanupInactive(30 * time.Second)

	// conn1 should be gone and conn2 should remain.
	if manager.Count() != 1 {
		t.Errorf("Expected 1 connection after cleanup, got: %d", manager.Count())
	}
	_, err := manager.Get("conn1")
	if err == nil {
		t.Error("Expected conn1 to be removed")
	}
	_, err = manager.Get("conn2")
	if err != nil {
		t.Error("Expected conn2 to be retained")
	}
}
// TestConnectionManagerCleanupInactiveInvalidTimeout passes a non-duration
// value to CleanupInactive and checks that no connection is removed.
func TestConnectionManagerCleanupInactiveInvalidTimeout(t *testing.T) {
	manager := connection.NewManager(100)

	// Setup: the connection must actually be registered, otherwise the count
	// check below fails for a reason unrelated to CleanupInactive.
	conn1 := &mockConnectionInterface{id: "conn1"}
	if err := manager.Add(conn1); err != nil {
		t.Fatalf("Setup: adding conn1 failed: %v", err)
	}

	// Call the cleanup with a timeout of the wrong type.
	manager.CleanupInactive("invalid")

	// The connection must still be there.
	if manager.Count() != 1 {
		t.Errorf("Expected 1 connection, got: %d", manager.Count())
	}
}
// mockConnectionInterface is a minimal in-memory stand-in for a connection,
// used by the tests in this file.
type mockConnectionInterface struct {
	// id is the value reported by ID.
	id string
	// writeData, when non-nil, receives every payload passed to Write.
	writeData chan []byte
}

// ID returns the identifier the mock was constructed with.
func (m *mockConnectionInterface) ID() string {
	return m.id
}

// RemoteAddr returns a fixed loopback address.
func (m *mockConnectionInterface) RemoteAddr() string {
	return "127.0.0.1:6995"
}

// LocalAddr returns a fixed loopback address (the same one as RemoteAddr).
func (m *mockConnectionInterface) LocalAddr() string {
	return "127.0.0.1:6995"
}

// Write forwards data to the writeData channel and always returns nil.
// When no channel is configured the payload is discarded. The send is a
// plain channel send, so it blocks until the channel can accept the value.
func (m *mockConnectionInterface) Write(data []byte) error {
	if m.writeData == nil {
		return nil
	}
	m.writeData <- data
	return nil
}

// Close does nothing and always returns nil.
func (m *mockConnectionInterface) Close() error {
	return nil
}
// mockConnectionInterfaceWithLastActive is a mock connection that also
// carries a last-activity timestamp. It embeds mockConnectionInterface and
// adds a LastActive accessor on top of the embedded methods.
type mockConnectionInterfaceWithLastActive struct {
	mockConnectionInterface

	// lastActive is the timestamp handed back by LastActive.
	lastActive time.Time
}

// LastActive returns the stored last-activity timestamp unchanged.
func (m *mockConnectionInterfaceWithLastActive) LastActive() time.Time {
	return m.lastActive
}