You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

141 lines
3.5 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package integration
import (
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/noahlann/nnet/pkg/nnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConnectionLifecycleHooks 测试连接生命周期钩子
// TestConnectionLifecycleHooks verifies the connection lifecycle hook.
//
// A hook is registered on the server. It must receive OnOpen (with a
// non-empty connection ID and remote address) when a client connects. It
// must then receive OnClose with the same connection ID when that client
// disconnects.
//
// The test is skipped when the server returned by nnet.NewServer does not
// expose a RegisterConnectionLifecycleHook method.
func TestConnectionLifecycleHooks(t *testing.T) {
	// Reserve a free port by binding to :0 and releasing it right away.
	// Another process could grab the port before the server binds it. That
	// race is accepted here, as is common in integration tests.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
	}

	// Hook state. The hook callbacks write it, the test body reads it, and
	// every access goes through mu.
	var (
		mu            sync.Mutex
		onOpenCalled  bool
		onCloseCalled bool
		onOpenConnID  string
		onCloseConnID string
		onOpenAddr    string
	)

	server, err := nnet.NewServer(cfg)
	require.NoError(t, err)

	hook := nnet.NewConnectionHookFunc(
		// OnOpen: record that it fired and what it was given.
		func(connID string, remoteAddr string) error {
			mu.Lock()
			onOpenCalled = true
			onOpenConnID = connID
			onOpenAddr = remoteAddr
			mu.Unlock()
			t.Logf("OnOpen called: connID=%s, remoteAddr=%s", connID, remoteAddr)
			return nil
		},
		nil, // OnTraffic
		// OnClose: record that it fired and for which connection.
		func(connID string, err error) error {
			mu.Lock()
			onCloseCalled = true
			onCloseConnID = connID
			mu.Unlock()
			t.Logf("OnClose called: connID=%s, err=%v", connID, err)
			return nil
		},
		nil, // OnError
	)

	// Hook registration may not be part of the nnet.Server interface. Probe
	// for it with an anonymous interface assertion and skip if it is absent.
	hookable, ok := server.(interface {
		RegisterConnectionLifecycleHook(hook nnet.ConnectionLifecycleHook)
	})
	if !ok {
		t.Skip("Server does not support connection lifecycle hooks")
	}
	hookable.RegisterConnectionLifecycleHook(hook)

	server.Router().RegisterString("test", func(ctx nnet.Context) error {
		return ctx.Response().WriteBytes([]byte("ok\n"))
	})

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if err := server.Start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()
	// Register cleanup as soon as the server goroutine exists. If the
	// startup wait below fails, require calls t.FailNow, which runs only the
	// defers already registered. Registering later would leak the server.
	defer CleanupTestServer(t, ts)

	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")
	// NOTE(review): this settling delay is kept from the original. It
	// presumably covers a gap between Started() reporting true and the server
	// actually serving connections. Confirm, then remove it if unnecessary.
	time.Sleep(200 * time.Millisecond)

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)

	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return onOpenCalled
	}, 2*time.Second, 50*time.Millisecond, "OnOpen should be called")

	// Copy the state out under the lock and assert afterwards, so no
	// assertion runs while mu is held.
	mu.Lock()
	openConnID, openAddr := onOpenConnID, onOpenAddr
	mu.Unlock()
	assert.NotEmpty(t, openConnID, "OnOpen connID should not be empty")
	assert.NotEmpty(t, openAddr, "OnOpen remoteAddr should not be empty")

	resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second)
	assert.Contains(t, string(resp), "ok", "Response should contain 'ok'")

	client.Disconnect()

	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return onCloseCalled
	}, 2*time.Second, 50*time.Millisecond, "OnClose should be called")

	mu.Lock()
	closeConnID := onCloseConnID
	mu.Unlock()
	// Compare against the ID OnOpen reported for this same connection.
	assert.Equal(t, openConnID, closeConnID, "OnClose connID should match OnOpen connID")
}