You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

370 lines
7.9 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package client
import (
"net"
"testing"
"time"
clientpkg "github.com/noahlann/nnet/pkg/client"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestPool_Basic exercises the basic pool lifecycle against a local TCP echo
// server. It creates a pool, acquires a client, round-trips a message through
// it, releases the client, and checks that the pool size stays within
// [MinSize, MaxSize].
func TestPool_Basic(t *testing.T) {
	// Start a local TCP echo server on an ephemeral port.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	addr := listener.Addr().String()

	// Accept connections until the listener is closed. Each connection echoes
	// back whatever it reads until a read or a write fails.
	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				buf := make([]byte, 1024)
				for {
					n, err := c.Read(buf)
					if err != nil {
						return
					}
					if _, err := c.Write(buf[:n]); err != nil {
						return
					}
				}
			}(conn)
		}
	}()

	config := &PoolConfig{
		MaxSize: 5,
		MinSize: 2,
		ClientConfig: &clientpkg.Config{
			Addr:           "tcp://" + addr,
			ConnectTimeout: 5 * time.Second,
			ReadTimeout:    5 * time.Second,
			WriteTimeout:   5 * time.Second,
		},
		AcquireTimeout: 5 * time.Second,
		IdleTimeout:    1 * time.Minute,
	}

	pool, err := NewPool(config)
	require.NoError(t, err)
	defer pool.Close()

	// Give the pool a moment to establish its initial connections.
	// NOTE(review): a fixed sleep is timing-dependent; kept from the original.
	time.Sleep(100 * time.Millisecond)

	client, err := pool.Acquire()
	require.NoError(t, err)
	// Use require, not assert: the method calls below would panic on a nil client.
	require.NotNil(t, client)
	assert.True(t, client.IsConnected())

	// The echo server must return exactly what was sent.
	require.NoError(t, client.Send([]byte("hello")))
	resp, err := client.Receive()
	require.NoError(t, err)
	assert.Equal(t, "hello", string(resp))

	require.NoError(t, pool.Release(client))

	assert.GreaterOrEqual(t, pool.Size(), config.MinSize)
	assert.LessOrEqual(t, pool.Size(), config.MaxSize)
}
// TestPool_MaxSize checks that the pool never hands out more than MaxSize
// clients at once. Once MaxSize clients are checked out, a further Acquire may
// fail or block, but it must not succeed.
func TestPool_MaxSize(t *testing.T) {
	// Start a local TCP echo server on an ephemeral port.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	addr := listener.Addr().String()

	// Accept connections until the listener is closed. Each connection echoes
	// back whatever it reads until a read or a write fails.
	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				buf := make([]byte, 1024)
				for {
					n, err := c.Read(buf)
					if err != nil {
						return
					}
					if _, err := c.Write(buf[:n]); err != nil {
						return
					}
				}
			}(conn)
		}
	}()

	// A deliberately small pool: at most 2 connections.
	config := &PoolConfig{
		MaxSize: 2,
		MinSize: 1,
		ClientConfig: &clientpkg.Config{
			Addr:           "tcp://" + addr,
			ConnectTimeout: 5 * time.Second,
			ReadTimeout:    5 * time.Second,
			WriteTimeout:   5 * time.Second,
		},
		AcquireTimeout: 5 * time.Second,
		IdleTimeout:    1 * time.Minute,
	}

	pool, err := NewPool(config)
	require.NoError(t, err)
	defer pool.Close()

	// Give the pool a moment to establish its initial connections.
	time.Sleep(100 * time.Millisecond)

	// Check out every connection the pool is allowed to hold.
	clients := make([]clientpkg.Client, 0, config.MaxSize)
	for i := 0; i < config.MaxSize; i++ {
		client, err := pool.Acquire()
		require.NoError(t, err)
		clients = append(clients, client)
	}

	// The outcome of the over-limit Acquire is reported whether it succeeds
	// or fails, so that a success can be detected and cleaned up.
	type acquireResult struct {
		client clientpkg.Client
		err    error
	}
	// Buffered so the goroutine can always deliver its result and exit, even
	// if the test has stopped listening.
	extra := make(chan acquireResult, 1)
	go func() {
		c, err := pool.Acquire()
		extra <- acquireResult{client: c, err: err}
	}()

	pending := false
	select {
	case res := <-extra:
		// Acquire returned while the pool was full; only an error is acceptable.
		if res.err == nil {
			t.Errorf("Acquire succeeded although MaxSize=%d clients were already checked out", config.MaxSize)
			assert.NoError(t, pool.Release(res.client))
		}
	case <-time.After(1 * time.Second):
		// Still blocked: also acceptable, because the pool is full.
		pending = true
	}

	for _, client := range clients {
		require.NoError(t, pool.Release(client))
	}

	// A still-blocked Acquire may now complete (a slot was freed) or time out.
	// Wait for it so the goroutine does not outlive the test, and return any
	// client it obtained to the pool.
	if pending {
		select {
		case res := <-extra:
			if res.err == nil {
				assert.NoError(t, pool.Release(res.client))
			}
		case <-time.After(config.AcquireTimeout + time.Second):
			t.Errorf("Acquire did not return within AcquireTimeout (%v)", config.AcquireTimeout)
		}
	}
}
// TestPool_Close checks three things about closing a pool:
//   - Close succeeds after a client has been acquired and released.
//   - IsClosed reports true afterwards.
//   - Acquire on the closed pool returns an error whose message contains "closed".
func TestPool_Close(t *testing.T) {
	// Start a local TCP echo server on an ephemeral port.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	addr := listener.Addr().String()

	// Accept connections until the listener is closed. Each connection echoes
	// back whatever it reads until a read or a write fails.
	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				buf := make([]byte, 1024)
				for {
					n, err := c.Read(buf)
					if err != nil {
						return
					}
					if _, err := c.Write(buf[:n]); err != nil {
						return
					}
				}
			}(conn)
		}
	}()

	config := &PoolConfig{
		MaxSize: 5,
		MinSize: 2,
		ClientConfig: &clientpkg.Config{
			Addr:           "tcp://" + addr,
			ConnectTimeout: 5 * time.Second,
			ReadTimeout:    5 * time.Second,
			WriteTimeout:   5 * time.Second,
		},
		AcquireTimeout: 5 * time.Second,
		IdleTimeout:    1 * time.Minute,
	}

	pool, err := NewPool(config)
	require.NoError(t, err)
	// Safety net: if a require aborts the test before the explicit Close
	// below, still close the pool. The IsClosed guard skips a second Close;
	// the error is ignored because this is best-effort cleanup only.
	defer func() {
		if !pool.IsClosed() {
			_ = pool.Close()
		}
	}()

	client, err := pool.Acquire()
	require.NoError(t, err)
	// Release the client before closing the pool.
	require.NoError(t, pool.Release(client))

	require.NoError(t, pool.Close())
	assert.True(t, pool.IsClosed())

	// Acquire on a closed pool must fail. Use require, not assert:
	// err.Error() below would panic on a nil error.
	_, err = pool.Acquire()
	require.Error(t, err)
	assert.Contains(t, err.Error(), "closed")
}
// TestPool_InvalidConnection 测试无效连接处理
func TestPool_InvalidConnection(t *testing.T) {
// 创建连接池配置(连接到不存在的服务器)
config := &PoolConfig{
MaxSize: 5,
MinSize: 0, // 不预创建连接
ClientConfig: &clientpkg.Config{
Addr: "tcp://127.0.0.1:99999", // 无效地址
ConnectTimeout: 1 * time.Second,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
},
AcquireTimeout: 1 * time.Second,
IdleTimeout: 1 * time.Minute,
}
// 创建连接池应该成功因为MinSize为0
pool, err := NewPool(config)
require.NoError(t, err)
defer pool.Close()
// 尝试获取连接(应该失败)
_, err = pool.Acquire()
assert.Error(t, err)
}
// TestPool_Concurrent runs numGoroutines goroutines against the pool at once.
// Each goroutine performs numOps acquire/send/release cycles. The test
// requires every operation to succeed and the pool size to end within
// [MinSize, MaxSize].
func TestPool_Concurrent(t *testing.T) {
	// Start a local TCP echo server on an ephemeral port.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()
	addr := listener.Addr().String()

	// Accept connections until the listener is closed. Each connection echoes
	// back whatever it reads until a read or a write fails.
	go func() {
		for {
			conn, err := listener.Accept()
			if err != nil {
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				buf := make([]byte, 1024)
				for {
					n, err := c.Read(buf)
					if err != nil {
						return
					}
					if _, err := c.Write(buf[:n]); err != nil {
						return
					}
				}
			}(conn)
		}
	}()

	config := &PoolConfig{
		MaxSize: 10,
		MinSize: 2,
		ClientConfig: &clientpkg.Config{
			Addr:           "tcp://" + addr,
			ConnectTimeout: 5 * time.Second,
			ReadTimeout:    5 * time.Second,
			WriteTimeout:   5 * time.Second,
		},
		AcquireTimeout: 5 * time.Second,
		IdleTimeout:    1 * time.Minute,
	}

	pool, err := NewPool(config)
	require.NoError(t, err)
	defer pool.Close()

	// Give the pool a moment to establish its initial connections.
	time.Sleep(100 * time.Millisecond)

	const (
		numGoroutines = 20
		numOps        = 10
	)
	// One completion signal per goroutine; buffered so senders never block.
	done := make(chan struct{}, numGoroutines)
	for i := 0; i < numGoroutines; i++ {
		go func(id int) {
			defer func() { done <- struct{}{} }()
			for j := 0; j < numOps; j++ {
				// t.Errorf, unlike t.Fatalf, may be called from goroutines other
				// than the test goroutine. It marks the test failed and lets this
				// goroutine carry on.
				client, err := pool.Acquire()
				if err != nil {
					t.Errorf("Goroutine %d: Failed to acquire connection: %v", id, err)
					continue
				}
				// The echoed reply is not read here.
				if err := client.Send([]byte("test")); err != nil {
					t.Errorf("Goroutine %d: Failed to send: %v", id, err)
				}
				if err := pool.Release(client); err != nil {
					t.Errorf("Goroutine %d: Failed to release connection: %v", id, err)
				}
				time.Sleep(10 * time.Millisecond)
			}
		}(i)
	}

	// Wait for every goroutine before inspecting the pool. This also makes
	// sure no t.Errorf call happens after the test returns.
	for i := 0; i < numGoroutines; i++ {
		<-done
	}

	assert.GreaterOrEqual(t, pool.Size(), config.MinSize)
	assert.LessOrEqual(t, pool.Size(), config.MaxSize)
}
// TestDefaultPoolConfig pins the default values returned by DefaultPoolConfig.
func TestDefaultPoolConfig(t *testing.T) {
	config := DefaultPoolConfig()
	// Use require, not assert: the field accesses below would panic on nil.
	require.NotNil(t, config)
	assert.Equal(t, 10, config.MaxSize)
	assert.Equal(t, 2, config.MinSize)
	assert.Equal(t, 10*time.Second, config.AcquireTimeout)
	assert.Equal(t, 5*time.Minute, config.IdleTimeout)
	assert.NotNil(t, config.ClientConfig)
}