package integration

import (
	"fmt"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestConnectionLifecycleHooks verifies that a registered connection
// lifecycle hook receives OnOpen when a client connects and OnClose, carrying
// the same connection ID, when that client disconnects.
//
// The test is skipped when the server returned by nnet.NewServer does not
// expose a RegisterConnectionLifecycleHook method.
func TestConnectionLifecycleHooks(t *testing.T) {
	// Pick a free loopback port by binding to :0 and releasing it right away.
	// Another process could take the port before the server binds it; that
	// race is accepted in this test.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
	}

	// State recorded by the lifecycle hook callbacks. The callbacks are
	// invoked by the server rather than by this goroutine, so every read and
	// write goes through mu.
	var (
		mu            sync.Mutex
		onOpenCalled  bool
		onCloseCalled bool
		onOpenConnID  string
		onCloseConnID string
		onOpenAddr    string
	)

	server, err := nnet.NewServer(cfg)
	require.NoError(t, err)

	// Callback positions: OnOpen, OnTraffic, OnClose, OnError.
	hook := nnet.NewConnectionHookFunc(
		func(connID string, remoteAddr string) error { // OnOpen
			mu.Lock()
			onOpenCalled = true
			onOpenConnID = connID
			onOpenAddr = remoteAddr
			mu.Unlock()
			t.Logf("OnOpen called: connID=%s, remoteAddr=%s", connID, remoteAddr)
			return nil
		},
		nil, // OnTraffic
		func(connID string, closeErr error) error { // OnClose
			mu.Lock()
			onCloseCalled = true
			onCloseConnID = connID
			mu.Unlock()
			t.Logf("OnClose called: connID=%s, err=%v", connID, closeErr)
			return nil
		},
		nil, // OnError
	)

	// Hook registration is probed via an ad-hoc interface rather than called
	// directly on the value returned by nnet.NewServer. Skip when the concrete
	// server does not provide it.
	serverWithHooks, ok := server.(interface {
		RegisterConnectionLifecycleHook(hook nnet.ConnectionLifecycleHook)
	})
	if !ok {
		t.Skip("Server does not support connection lifecycle hooks")
	}
	serverWithHooks.RegisterConnectionLifecycleHook(hook)

	server.Router().RegisterString("test", func(ctx nnet.Context) error {
		return ctx.Response().WriteBytes([]byte("ok\n"))
	})

	// Run the server in the background; ts.wg tracks that goroutine.
	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if err := server.Start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()
	// Register cleanup before waiting for startup. If the wait below fails,
	// require aborts via FailNow and only already-registered defers run.
	// Deferring any later would leak the server goroutine on that path.
	// NOTE(review): assumes CleanupTestServer tolerates a server that never
	// reported Started(); confirm against its definition.
	defer CleanupTestServer(t, ts)

	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")
	// Extra settle time after Started() reports true; presumably lets the
	// listener begin accepting connections. NOTE(review): confirm still needed.
	time.Sleep(200 * time.Millisecond)

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)

	// The hook fires asynchronously, so poll until OnOpen has been recorded.
	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return onOpenCalled
	}, 2*time.Second, 50*time.Millisecond, "OnOpen should be called")

	// Snapshot under the lock, then assert without holding it.
	mu.Lock()
	openCalled, openConnID, openAddr := onOpenCalled, onOpenConnID, onOpenAddr
	mu.Unlock()
	assert.True(t, openCalled, "OnOpen should be called")
	assert.NotEmpty(t, openConnID, "OnOpen connID should not be empty")
	assert.NotEmpty(t, openAddr, "OnOpen remoteAddr should not be empty")

	// Exercise the connection once before closing it.
	resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second)
	assert.Contains(t, string(resp), "ok", "Response should contain 'ok'")

	client.Disconnect()

	require.Eventually(t, func() bool {
		mu.Lock()
		defer mu.Unlock()
		return onCloseCalled
	}, 2*time.Second, 50*time.Millisecond, "OnClose should be called")

	mu.Lock()
	closeCalled, closeConnID := onCloseCalled, onCloseConnID
	mu.Unlock()
	assert.True(t, closeCalled, "OnClose should be called")
	assert.Equal(t, openConnID, closeConnID, "OnClose connID should match OnOpen connID")
}