|
|
package integration
|
|
|
|
|
|
import (
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"sync"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/pkg/nnet"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
// TestConnectionLifecycleHooks 测试连接生命周期钩子
|
|
|
func TestConnectionLifecycleHooks(t *testing.T) {
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
require.NoError(t, err)
|
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
|
listener.Close()
|
|
|
|
|
|
cfg := &nnet.Config{
|
|
|
Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
|
|
|
}
|
|
|
|
|
|
// 连接生命周期钩子状态
|
|
|
var mu sync.Mutex
|
|
|
var onOpenCalled bool
|
|
|
var onCloseCalled bool
|
|
|
var onOpenConnID string
|
|
|
var onCloseConnID string
|
|
|
var onOpenAddr string
|
|
|
|
|
|
// 创建服务器
|
|
|
server, err := nnet.NewServer(cfg)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 注册连接生命周期钩子
|
|
|
hook := nnet.NewConnectionHookFunc(
|
|
|
func(connID string, remoteAddr string) error {
|
|
|
mu.Lock()
|
|
|
onOpenCalled = true
|
|
|
onOpenConnID = connID
|
|
|
onOpenAddr = remoteAddr
|
|
|
mu.Unlock()
|
|
|
t.Logf("OnOpen called: connID=%s, remoteAddr=%s", connID, remoteAddr)
|
|
|
return nil
|
|
|
},
|
|
|
nil, // OnTraffic
|
|
|
func(connID string, err error) error {
|
|
|
mu.Lock()
|
|
|
onCloseCalled = true
|
|
|
onCloseConnID = connID
|
|
|
mu.Unlock()
|
|
|
t.Logf("OnClose called: connID=%s, err=%v", connID, err)
|
|
|
return nil
|
|
|
},
|
|
|
nil, // OnError
|
|
|
)
|
|
|
|
|
|
// 注意:这里需要检查服务器是否支持注册连接生命周期钩子
|
|
|
// 如果Server接口没有这个方法,我们需要使用内部实现
|
|
|
// 检查服务器是否支持RegisterConnectionLifecycleHook方法
|
|
|
if serverWithHooks, ok := server.(interface {
|
|
|
RegisterConnectionLifecycleHook(hook nnet.ConnectionLifecycleHook)
|
|
|
}); ok {
|
|
|
serverWithHooks.RegisterConnectionLifecycleHook(hook)
|
|
|
} else {
|
|
|
// 如果接口不支持,跳过这个测试
|
|
|
t.Skip("Server does not support connection lifecycle hooks")
|
|
|
}
|
|
|
|
|
|
// 注册路由
|
|
|
server.Router().RegisterString("test", func(ctx nnet.Context) error {
|
|
|
return ctx.Response().WriteBytes([]byte("ok\n"))
|
|
|
})
|
|
|
|
|
|
// 启动服务器
|
|
|
ts := &TestServer{
|
|
|
Server: server,
|
|
|
Addr: cfg.Addr,
|
|
|
stopCh: make(chan struct{}),
|
|
|
}
|
|
|
|
|
|
ts.wg.Add(1)
|
|
|
go func() {
|
|
|
defer ts.wg.Done()
|
|
|
if err := server.Start(); err != nil {
|
|
|
t.Logf("Server error: %v", err)
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
require.Eventually(t, func() bool {
|
|
|
return server.Started()
|
|
|
}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")
|
|
|
time.Sleep(200 * time.Millisecond)
|
|
|
|
|
|
defer CleanupTestServer(t, ts)
|
|
|
|
|
|
// 创建客户端并连接
|
|
|
client := NewTestClient(t, ts.Addr, nil)
|
|
|
defer CleanupTestClient(t, client)
|
|
|
|
|
|
ConnectTestClient(t, client)
|
|
|
|
|
|
// 等待OnOpen被调用
|
|
|
require.Eventually(t, func() bool {
|
|
|
mu.Lock()
|
|
|
defer mu.Unlock()
|
|
|
return onOpenCalled
|
|
|
}, 2*time.Second, 50*time.Millisecond, "OnOpen should be called")
|
|
|
|
|
|
// 验证OnOpen被调用
|
|
|
mu.Lock()
|
|
|
assert.True(t, onOpenCalled, "OnOpen should be called")
|
|
|
assert.NotEmpty(t, onOpenConnID, "OnOpen connID should not be empty")
|
|
|
assert.NotEmpty(t, onOpenAddr, "OnOpen remoteAddr should not be empty")
|
|
|
mu.Unlock()
|
|
|
|
|
|
// 发送请求
|
|
|
resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second)
|
|
|
assert.Contains(t, string(resp), "ok", "Response should contain 'ok'")
|
|
|
|
|
|
// 断开连接
|
|
|
client.Disconnect()
|
|
|
|
|
|
// 等待OnClose被调用
|
|
|
require.Eventually(t, func() bool {
|
|
|
mu.Lock()
|
|
|
defer mu.Unlock()
|
|
|
return onCloseCalled
|
|
|
}, 2*time.Second, 50*time.Millisecond, "OnClose should be called")
|
|
|
|
|
|
// 验证OnClose被调用
|
|
|
mu.Lock()
|
|
|
assert.True(t, onCloseCalled, "OnClose should be called")
|
|
|
assert.Equal(t, onOpenConnID, onCloseConnID, "OnClose connID should match OnOpen connID")
|
|
|
mu.Unlock()
|
|
|
}
|
|
|
|