package integration import ( "encoding/json" "fmt" "net" "testing" "time" "github.com/noahlann/nnet/pkg/nnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // TestMiddlewareChain 测试中间件链 func TestMiddlewareChain(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) port := listener.Addr().(*net.TCPAddr).Port listener.Close() cfg := &nnet.Config{ Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port), Codec: &nnet.CodecConfig{ DefaultCodec: "json", }, } ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) { // 创建路由分组并添加多个中间件 group := srv.Router().Group() // 第一个中间件:设置值 group.Use(func(ctx nnet.Context) error { ctx.Set("middleware1", "executed") return nil }) // 第二个中间件:修改值 group.Use(func(ctx nnet.Context) error { ctx.Set("middleware2", "executed") if val := ctx.GetString("middleware1"); val != "" { ctx.Set("middleware1", val+"-modified") } return nil }) // 注册路由 group.RegisterString("test", func(ctx nnet.Context) error { mw1 := ctx.GetString("middleware1") mw2 := ctx.GetString("middleware2") return ctx.Response().Write(map[string]any{ "middleware1": mw1, "middleware2": mw2, }) }) }) defer CleanupTestServer(t, ts) client := NewTestClient(t, ts.Addr, nil) defer CleanupTestClient(t, client) ConnectTestClient(t, client) time.Sleep(100 * time.Millisecond) resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second) t.Logf("Response: %q", string(resp)) var result map[string]any err = json.Unmarshal(resp, &result) assert.NoError(t, err, "Response should be valid JSON") assert.Equal(t, "executed-modified", result["middleware1"], "Middleware1 should be executed and modified") assert.Equal(t, "executed", result["middleware2"], "Middleware2 should be executed") } // TestGlobalMiddleware 测试全局中间件 func TestGlobalMiddleware(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) port := listener.Addr().(*net.TCPAddr).Port listener.Close() cfg := &nnet.Config{ Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port), Codec: &nnet.CodecConfig{ DefaultCodec: "json", }, } ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) { // 创建路由分组并添加全局中间件 group := srv.Router().Group() group.Use(func(ctx nnet.Context) error { ctx.Set("global", "executed") return nil }) // 注册路由到分组 group.RegisterString("test", func(ctx nnet.Context) error { global := ctx.GetString("global") return ctx.Response().Write(map[string]any{ "global": global, }) }) }) defer CleanupTestServer(t, ts) client := NewTestClient(t, ts.Addr, nil) defer CleanupTestClient(t, client) ConnectTestClient(t, client) time.Sleep(100 * time.Millisecond) resp := RequestWithTimeout(t, client, []byte("test"), 3*time.Second) t.Logf("Response: %q", string(resp)) var result map[string]any err = json.Unmarshal(resp, &result) assert.NoError(t, err, "Response should be valid JSON") assert.Equal(t, "executed", result["global"], "Global middleware should be executed") } // TestMiddlewareError 测试中间件错误处理 func TestMiddlewareError(t *testing.T) { listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) port := listener.Addr().(*net.TCPAddr).Port listener.Close() cfg := &nnet.Config{ Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port), } ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) { group := srv.Router().Group() // 中间件返回错误 group.Use(func(ctx nnet.Context) error { return fmt.Errorf("middleware error") }) // 注册路由(不应该被执行) group.RegisterString("test", func(ctx nnet.Context) error { return ctx.Response().WriteBytes([]byte("ok\n")) }) }) defer CleanupTestServer(t, ts) client := NewTestClient(t, ts.Addr, nil) defer CleanupTestClient(t, client) ConnectTestClient(t, client) time.Sleep(100 * time.Millisecond) // 发送请求 resp, err := client.Request([]byte("test"), 3*time.Second) if err != nil { // 如果请求失败,这是预期的行为 t.Logf("Request failed as expected: %v", err) assert.Error(t, err, "Request should fail due to middleware error") } else { // 如果收到响应,应该是错误响应 t.Logf("Response: %q", string(resp)) assert.NotEmpty(t, resp, "Response should not be empty") } }