package client

import (
	"net"
	"testing"
	"time"

	clientpkg "github.com/noahlann/nnet/pkg/client"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// startPoolEchoServer starts a TCP server on a random loopback port that
// writes every byte it reads back to the sender. It returns the listener's
// "host:port" address. The listener is closed when the test finishes, which
// makes Accept fail and ends the accept loop.
func startPoolEchoServer(t *testing.T) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	// Best-effort: a close error at cleanup time has no bearing on the test.
	t.Cleanup(func() { _ = listener.Close() })

	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				buf := make([]byte, 1024)
				for {
					n, err := c.Read(buf)
					if err != nil {
						return
					}
					// Stop echoing once the peer can no longer be written to.
					if _, err := c.Write(buf[:n]); err != nil {
						return
					}
				}
			}(conn)
		}
	}()

	return listener.Addr().String()
}

// TestPool_Basic verifies that a connection can be acquired, used for an
// echo round-trip, and released, and that the pool size stays within
// [MinSize, MaxSize] afterwards.
func TestPool_Basic(t *testing.T) {
	addr := startPoolEchoServer(t)

	config := &PoolConfig{
		MaxSize: 5,
		MinSize: 2,
		ClientConfig: &clientpkg.Config{
			Addr:           "tcp://" + addr,
			ConnectTimeout: 5 * time.Second,
			ReadTimeout:    5 * time.Second,
			WriteTimeout:   5 * time.Second,
		},
		AcquireTimeout: 5 * time.Second,
		IdleTimeout:    1 * time.Minute,
	}

	pool, err := NewPool(config)
	require.NoError(t, err)
	defer pool.Close()

	// Wait for the pool's initial connections to be established.
	time.Sleep(100 * time.Millisecond)

	client, err := pool.Acquire()
	require.NoError(t, err)
	// require, not assert: a nil client would panic on IsConnected below.
	require.NotNil(t, client)
	assert.True(t, client.IsConnected())

	// Round-trip through the echo server.
	err = client.Send([]byte("hello"))
	require.NoError(t, err)
	resp, err := client.Receive()
	require.NoError(t, err)
	assert.Equal(t, "hello", string(resp))

	err = pool.Release(client)
	require.NoError(t, err)

	assert.GreaterOrEqual(t, pool.Size(), config.MinSize)
	assert.LessOrEqual(t, pool.Size(), config.MaxSize)
}

// TestPool_MaxSize verifies that the pool never hands out more than MaxSize
// connections at once. While every connection is checked out, a further
// Acquire must fail (or stay blocked) rather than succeed.
func TestPool_MaxSize(t *testing.T) {
	addr := startPoolEchoServer(t)

	// Pool capped at two connections. A short AcquireTimeout keeps the
	// over-limit Acquire below from stalling the test.
	config := &PoolConfig{
		MaxSize: 2,
		MinSize: 1,
		ClientConfig: &clientpkg.Config{
			Addr:           "tcp://" + addr,
			ConnectTimeout: 5 * time.Second,
			ReadTimeout:    5 * time.Second,
			WriteTimeout:   5 * time.Second,
		},
		AcquireTimeout: 1 * time.Second,
		IdleTimeout:    1 * time.Minute,
	}

	pool, err := NewPool(config)
	require.NoError(t, err)
	defer pool.Close()

	// Wait for the pool's initial connections to be established.
	time.Sleep(100 * time.Millisecond)

	// Check out every connection the pool is allowed to hand out.
	clients := make([]clientpkg.Client, 0, config.MaxSize)
	for i := 0; i < config.MaxSize; i++ {
		client, err := pool.Acquire()
		require.NoError(t, err)
		clients = append(clients, client)
	}

	// Try one more Acquire in the background so a pool that blocks
	// indefinitely cannot hang the test. The channel is buffered, so the
	// goroutine can always deliver its result and exit.
	type acquireResult struct {
		client clientpkg.Client
		err    error
	}
	extra := make(chan acquireResult, 1)
	go func() {
		c, err := pool.Acquire()
		extra <- acquireResult{client: c, err: err}
	}()

	wait := config.AcquireTimeout + 2*time.Second
	pending := true
	select {
	case r := <-extra:
		pending = false
		if r.err == nil {
			// Return the over-limit connection so it is not leaked, then fail.
			assert.NoError(t, pool.Release(r.client))
			t.Errorf("Acquire succeeded with all %d connections checked out", config.MaxSize)
		}
	case <-time.After(wait):
		// Still blocked past AcquireTimeout: the size limit held, which is
		// what this test checks.
	}

	for _, client := range clients {
		require.NoError(t, pool.Release(client))
	}

	// A still-blocked Acquire may now succeed with a released connection;
	// hand it back so nothing is leaked when the pool closes.
	if pending {
		select {
		case r := <-extra:
			if r.err == nil {
				assert.NoError(t, pool.Release(r.client))
			}
		case <-time.After(wait):
			t.Log("over-limit Acquire still blocked after all connections were released")
		}
	}
}

// TestPool_Close verifies that Close succeeds, that IsClosed reports true
// afterwards, and that Acquire on a closed pool returns an error whose message
// contains "closed".
func TestPool_Close(t *testing.T) {
	addr := startPoolEchoServer(t)

	config := &PoolConfig{
		MaxSize: 5,
		MinSize: 2,
		ClientConfig: &clientpkg.Config{
			Addr:           "tcp://" + addr,
			ConnectTimeout: 5 * time.Second,
			ReadTimeout:    5 * time.Second,
			WriteTimeout:   5 * time.Second,
		},
		AcquireTimeout: 5 * time.Second,
		IdleTimeout:    1 * time.Minute,
	}

	pool, err := NewPool(config)
	require.NoError(t, err)

	client, err := pool.Acquire()
	require.NoError(t, err)

	// Return the connection before closing the pool.
	err = pool.Release(client)
	require.NoError(t, err)

	err = pool.Close()
	require.NoError(t, err)
	assert.True(t, pool.IsClosed())

	// Acquire on a closed pool must fail. Use require so that a nil error
	// stops the test instead of panicking on err.Error() below.
	_, err = pool.Acquire()
	require.Error(t, err)
	assert.Contains(t, err.Error(), "closed")
}
测试无效连接处理 func TestPool_InvalidConnection(t *testing.T) { // 创建连接池配置(连接到不存在的服务器) config := &PoolConfig{ MaxSize: 5, MinSize: 0, // 不预创建连接 ClientConfig: &clientpkg.Config{ Addr: "tcp://127.0.0.1:99999", // 无效地址 ConnectTimeout: 1 * time.Second, ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, }, AcquireTimeout: 1 * time.Second, IdleTimeout: 1 * time.Minute, } // 创建连接池(应该成功,因为MinSize为0) pool, err := NewPool(config) require.NoError(t, err) defer pool.Close() // 尝试获取连接(应该失败) _, err = pool.Acquire() assert.Error(t, err) } // TestPool_Concurrent 测试并发获取和释放连接 func TestPool_Concurrent(t *testing.T) { // 启动一个测试服务器 listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) defer listener.Close() addr := listener.Addr().String() // 接受连接 go func() { for { conn, err := listener.Accept() if err != nil { return } go func(c net.Conn) { defer c.Close() buf := make([]byte, 1024) for { n, err := c.Read(buf) if err != nil { return } c.Write(buf[:n]) } }(conn) } }() // 创建连接池配置 config := &PoolConfig{ MaxSize: 10, MinSize: 2, ClientConfig: &clientpkg.Config{ Addr: "tcp://" + addr, ConnectTimeout: 5 * time.Second, ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, }, AcquireTimeout: 5 * time.Second, IdleTimeout: 1 * time.Minute, } // 创建连接池 pool, err := NewPool(config) require.NoError(t, err) defer pool.Close() // 等待连接建立 time.Sleep(100 * time.Millisecond) // 并发获取和释放连接 const numGoroutines = 20 const numOps = 10 done := make(chan bool, numGoroutines) for i := 0; i < numGoroutines; i++ { go func(id int) { defer func() { done <- true }() for j := 0; j < numOps; j++ { client, err := pool.Acquire() if err != nil { t.Logf("Goroutine %d: Failed to acquire connection: %v", id, err) continue } // 使用连接 err = client.Send([]byte("test")) if err != nil { t.Logf("Goroutine %d: Failed to send: %v", id, err) } // 释放连接 err = pool.Release(client) if err != nil { t.Logf("Goroutine %d: Failed to release connection: %v", id, err) } time.Sleep(10 * time.Millisecond) } }(i) } // 
等待所有goroutine完成 for i := 0; i < numGoroutines; i++ { <-done } // 检查连接池状态 assert.GreaterOrEqual(t, pool.Size(), config.MinSize) assert.LessOrEqual(t, pool.Size(), config.MaxSize) } // TestDefaultPoolConfig 测试默认连接池配置 func TestDefaultPoolConfig(t *testing.T) { config := DefaultPoolConfig() assert.NotNil(t, config) assert.Equal(t, 10, config.MaxSize) assert.Equal(t, 2, config.MinSize) assert.Equal(t, 10*time.Second, config.AcquireTimeout) assert.Equal(t, 5*time.Minute, config.IdleTimeout) assert.NotNil(t, config.ClientConfig) }