package test

import (
	stderrors "errors"
	"testing"
	"time"

	"github.com/noahlann/nnet/internal/connection"
	"github.com/noahlann/nnet/pkg/errors"
)

// TestConnectionManager walks one connection through the basic lifecycle:
// Add, Get, Count, Remove, and Count again.
func TestConnectionManager(t *testing.T) {
	manager := connection.NewManager(100)

	// Add a connection. Every later step depends on it, so fail fast.
	conn := &mockConnectionInterface{id: "conn1"}
	if err := manager.Add(conn); err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}

	// Look the connection up again. Abort on error: the returned value must
	// not be dereferenced when the lookup failed.
	retrievedConn, err := manager.Get("conn1")
	if err != nil {
		t.Fatalf("Expected no error, got: %v", err)
	}
	if retrievedConn.ID() != "conn1" {
		t.Error("Expected conn1")
	}

	// Exactly one connection is registered.
	if got := manager.Count(); got != 1 {
		t.Errorf("Expected 1 connection, got: %d", got)
	}

	// Remove the connection and confirm the manager is empty again.
	if err := manager.Remove("conn1"); err != nil {
		t.Errorf("Expected no error, got: %v", err)
	}
	if got := manager.Count(); got != 0 {
		t.Errorf("Expected 0 connections, got: %d", got)
	}
}

// TestConnectionManagerAddNil verifies that Add rejects a nil connection.
func TestConnectionManagerAddNil(t *testing.T) {
	manager := connection.NewManager(100)

	if err := manager.Add(nil); err == nil {
		t.Error("Expected error for nil connection")
	}
}

// TestConnectionManagerMaxConnections verifies that a manager created with
// NewManager(2) accepts two connections and rejects a third.
func TestConnectionManagerMaxConnections(t *testing.T) {
	manager := connection.NewManager(2)

	// The first two connections are preconditions and must be accepted.
	for _, id := range []string{"conn1", "conn2"} {
		if err := manager.Add(&mockConnectionInterface{id: id}); err != nil {
			t.Fatalf("Expected no error, got: %v", err)
		}
	}

	// The third connection is expected to be rejected.
	conn3 := &mockConnectionInterface{id: "conn3"}
	if err := manager.Add(conn3); err == nil {
		t.Error("Expected error for exceeding max connections")
	}
}

// TestConnectionManagerGetNotFound verifies that Get reports
// errors.ErrConnectionNotFound for an unknown ID.
func TestConnectionManagerGetNotFound(t *testing.T) {
	manager := connection.NewManager(100)

	_, err := manager.Get("nonexistent")
	if err == nil {
		t.Fatal("Expected error for nonexistent connection")
	}
	if !stderrors.Is(err, errors.ErrConnectionNotFound) {
		t.Errorf("Expected ErrConnectionNotFound, got: %v", err)
	}
}

// TestConnectionManagerRemoveNotFound verifies that Remove reports
// errors.ErrConnectionNotFound for an unknown ID.
func TestConnectionManagerRemoveNotFound(t *testing.T) {
	manager := connection.NewManager(100)

	err := manager.Remove("nonexistent")
	if err == nil {
		t.Fatal("Expected error for nonexistent connection")
	}
	if !stderrors.Is(err, errors.ErrConnectionNotFound) {
		t.Errorf("Expected ErrConnectionNotFound, got: %v", err)
	}
}

// TestConnectionManagerGetAll verifies that GetAll returns every registered
// connection.
func TestConnectionManagerGetAll(t *testing.T) {
	manager := connection.NewManager(100)

	// Register several connections; a failure here invalidates the test.
	for _, id := range []string{"conn1", "conn2", "conn3"} {
		if err := manager.Add(&mockConnectionInterface{id: id}); err != nil {
			t.Fatalf("setup: Add(%q) failed: %v", id, err)
		}
	}

	// Fetch all connections.
	allConns := manager.GetAll()
	if len(allConns) != 3 {
		t.Errorf("Expected 3 connections, got: %d", len(allConns))
	}
}

// TestConnectionGroup verifies that connections added to a group are returned
// by GetGroup.
func TestConnectionGroup(t *testing.T) {
	manager := connection.NewManager(100)

	ids := []string{"conn1", "conn2"}
	for _, id := range ids {
		if err := manager.Add(&mockConnectionInterface{id: id}); err != nil {
			t.Fatalf("setup: Add(%q) failed: %v", id, err)
		}
	}

	// Put both connections into the same group.
	for _, id := range ids {
		if err := manager.AddToGroup("group1", id); err != nil {
			t.Errorf("Expected no error, got: %v", err)
		}
	}

	// Read the group back.
	group := manager.GetGroup("group1")
	if len(group) != 2 {
		t.Errorf("Expected 2 connections in group, got: %d", len(group))
	}
}

// TestConnectionGroupNotFound verifies that GetGroup yields an empty result
// for a group that was never created.
func TestConnectionGroupNotFound(t *testing.T) {
	manager := connection.NewManager(100)

	// len of a nil result is 0, so no separate nil check is needed.
	if group := manager.GetGroup("nonexistent"); len(group) != 0 {
		t.Error("Expected empty group for nonexistent group")
	}
}

// TestConnectionGroupAddToGroupNotFound verifies that AddToGroup reports
// errors.ErrConnectionNotFound when the connection ID is unknown.
func TestConnectionGroupAddToGroupNotFound(t *testing.T) {
	manager := connection.NewManager(100)

	err := manager.AddToGroup("group1", "nonexistent")
	if err == nil {
		t.Fatal("Expected error for nonexistent connection")
	}
	if !stderrors.Is(err, errors.ErrConnectionNotFound) {
		t.Errorf("Expected ErrConnectionNotFound, got: %v", err)
	}
}

// TestConnectionGroupRemoveFromGroup verifies that a group is empty after its
// only member has been removed from it.
func TestConnectionGroupRemoveFromGroup(t *testing.T) {
	manager := connection.NewManager(100)

	conn1 := &mockConnectionInterface{id: "conn1"}
	if err := manager.Add(conn1); err != nil {
		t.Fatalf("setup: Add(conn1) failed: %v", err)
	}

	// Add the connection to the group.
	if err := manager.AddToGroup("group1", "conn1"); err != nil {
		t.Fatalf("setup: AddToGroup(group1, conn1) failed: %v", err)
	}

	// Remove it from the group again.
	if err := manager.RemoveFromGroup("group1", "conn1"); err != nil {
		t.Errorf("Expected no error, got: %v", err)
	}

	// The group must now be empty.
	if group := manager.GetGroup("group1"); len(group) != 0 {
		t.Error("Expected empty group after removal")
	}
}

// TestConnectionGroupRemoveFromGroupNotFound verifies that removing a
// connection from a group that does not exist is not an error.
func TestConnectionGroupRemoveFromGroupNotFound(t *testing.T) {
	manager := connection.NewManager(100)

	if err := manager.RemoveFromGroup("nonexistent", "conn1"); err != nil {
		t.Errorf("Expected no error for nonexistent group, got: %v", err)
	}
}

// TestConnectionGroupBroadcast verifies that BroadcastToGroup delivers the
// payload to every connection in the group.
func TestConnectionGroupBroadcast(t *testing.T) {
	manager := connection.NewManager(100)

	// Each mock gets a one-slot buffered channel so a single Write does not
	// block the broadcaster.
	conns := []*mockConnectionInterface{
		{id: "conn1", writeData: make(chan []byte, 1)},
		{id: "conn2", writeData: make(chan []byte, 1)},
	}
	for _, c := range conns {
		if err := manager.Add(c); err != nil {
			t.Fatalf("setup: Add(%q) failed: %v", c.id, err)
		}
	}

	// Add both connections to the group.
	for _, c := range conns {
		if err := manager.AddToGroup("group1", c.id); err != nil {
			t.Fatalf("setup: AddToGroup(group1, %q) failed: %v", c.id, err)
		}
	}

	// Broadcast a message.
	data := []byte("broadcast message")
	if err := manager.BroadcastToGroup("group1", data); err != nil {
		t.Errorf("Expected no error, got: %v", err)
	}

	// Every connection must have received the message; the timeout keeps the
	// test from hanging if a delivery never happens.
	for _, c := range conns {
		select {
		case msg := <-c.writeData:
			if string(msg) != string(data) {
				t.Errorf("Expected %s, got %s", string(data), string(msg))
			}
		case <-time.After(100 * time.Millisecond):
			t.Errorf("Timeout waiting for message on %s", c.id)
		}
	}
}

// TestConnectionGroupBroadcastEmptyGroup verifies that broadcasting to a group
// without members is not an error.
func TestConnectionGroupBroadcastEmptyGroup(t *testing.T) {
	manager := connection.NewManager(100)

	if err := manager.BroadcastToGroup("empty_group", []byte("test")); err != nil {
		t.Errorf("Expected no error for empty group, got: %v", err)
	}
}

// TestConnectionManagerCleanupInactive verifies that CleanupInactive removes a
// connection whose LastActive is older than the timeout and keeps a fresh one.
func TestConnectionManagerCleanupInactive(t *testing.T) {
	manager := connection.NewManager(100)

	// conn1 was last active 40s ago (beyond the 30s timeout); conn2 is fresh.
	conn1 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn1"},
		lastActive:              time.Now().Add(-40 * time.Second),
	}
	conn2 := &mockConnectionInterfaceWithLastActive{
		mockConnectionInterface: mockConnectionInterface{id: "conn2"},
		lastActive:              time.Now(),
	}
	for _, c := range []*mockConnectionInterfaceWithLastActive{conn1, conn2} {
		if err := manager.Add(c); err != nil {
			t.Fatalf("setup: Add(%q) failed: %v", c.id, err)
		}
	}

	// Clean up inactive connections with a 30 second timeout.
	manager.CleanupInactive(30 * time.Second)

	// conn1 must be gone and conn2 must still be registered.
	if got := manager.Count(); got != 1 {
		t.Errorf("Expected 1 connection after cleanup, got: %d", got)
	}
	if _, err := manager.Get("conn1"); err == nil {
		t.Error("Expected conn1 to be removed")
	}
	if _, err := manager.Get("conn2"); err != nil {
		t.Error("Expected conn2 to be retained")
	}
}

// TestConnectionManagerCleanupInactiveInvalidTimeout verifies that
// CleanupInactive removes nothing when given a timeout of an invalid type.
func TestConnectionManagerCleanupInactiveInvalidTimeout(t *testing.T) {
	manager := connection.NewManager(100)

	conn1 := &mockConnectionInterface{id: "conn1"}
	if err := manager.Add(conn1); err != nil {
		t.Fatalf("setup: Add(conn1) failed: %v", err)
	}

	// Pass a string where a duration is used elsewhere in this file.
	manager.CleanupInactive("invalid")

	// The connection must not have been removed.
	if got := manager.Count(); got != 1 {
		t.Errorf("Expected 1 connection, got: %d", got)
	}
}

// mockConnectionInterface is a minimal in-memory connection used by the tests.
// When writeData is non-nil, every Write is forwarded to it so tests can
// observe what was sent.
type mockConnectionInterface struct {
	id        string
	writeData chan []byte
}

// ID returns the identifier the mock was constructed with.
func (m *mockConnectionInterface) ID() string {
	return m.id
}

// RemoteAddr returns a fixed loopback address.
func (m *mockConnectionInterface) RemoteAddr() string {
	return "127.0.0.1:6995"
}

// LocalAddr returns a fixed loopback address.
func (m *mockConnectionInterface) LocalAddr() string {
	return "127.0.0.1:6995"
}

// Write forwards data to writeData when the channel is set and always returns
// nil. The send blocks if the channel has no free buffer slot and no receiver.
func (m *mockConnectionInterface) Write(data []byte) error {
	if m.writeData != nil {
		m.writeData <- data
	}
	return nil
}

// Close is a no-op that always succeeds.
func (m *mockConnectionInterface) Close() error {
	return nil
}

// mockConnectionInterfaceWithLastActive extends mockConnectionInterface with a
// configurable last-activity timestamp.
type mockConnectionInterfaceWithLastActive struct {
	mockConnectionInterface
	lastActive time.Time
}

// LastActive returns the configured last-activity timestamp.
func (m *mockConnectionInterfaceWithLastActive) LastActive() time.Time {
	return m.lastActive
}