package middleware

import (
	"errors"
	"testing"

	"github.com/noahlann/nnet/internal/codec"
	"github.com/noahlann/nnet/internal/request"
	"github.com/noahlann/nnet/internal/response"
	ctxpkg "github.com/noahlann/nnet/pkg/context"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// newChainTestContext builds a Context for driving middleware chains in tests.
// It is backed by a mockConnection, a request whose payload is "test", and a
// response created with a fresh codec registry and the "json" codec name.
func newChainTestContext(t *testing.T) ctxpkg.Context {
	t.Helper()
	conn := &mockConnection{id: "test-conn"}
	req := request.New([]byte("test"), nil)
	resp := response.New(conn, nil, codec.NewRegistry(), "json")
	return ctxpkg.New(nil, conn, req, resp)
}

// recordCall returns a middleware/handler function that appends id to *order
// and returns nil, so tests can assert on the order in which steps ran.
func recordCall(order *[]int, id int) func(ctxpkg.Context) error {
	return func(ctxpkg.Context) error {
		*order = append(*order, id)
		return nil
	}
}

func TestMiddlewareChain(t *testing.T) {
	ctx := newChainTestContext(t)

	var callOrder []int
	chain := NewChain(recordCall(&callOrder, 1), recordCall(&callOrder, 2))
	finalHandler := chain.Then(recordCall(&callOrder, 3))

	require.NoError(t, finalHandler(ctx), "Expected no error")

	// Middlewares run front to back, followed by the final handler.
	assert.Equal(t, []int{1, 2, 3}, callOrder, "Expected call order to be 1, 2, 3")
}

func TestMiddlewareChainEmpty(t *testing.T) {
	ctx := newChainTestContext(t)

	called := false
	handler := func(ctxpkg.Context) error {
		called = true
		return nil
	}

	// An empty chain should hand back the handler directly; invoking the
	// result must therefore run the handler.
	finalHandler := NewChain().Then(handler)
	require.NotNil(t, finalHandler, "Expected handler to be returned")
	require.NoError(t, finalHandler(ctx), "Expected no error")
	assert.True(t, called, "Expected handler to be invoked")
}

func TestMiddlewareChainError(t *testing.T) {
	ctx := newChainTestContext(t)

	// A middleware that always fails.
	testError := errors.New("middleware error")
	failing := func(ctxpkg.Context) error { return testError }
	handler := func(ctxpkg.Context) error { return nil }

	// The middleware's error must propagate out of the composed handler.
	err := NewChain(failing).Then(handler)(ctx)
	require.Error(t, err, "Expected error from middleware")
	assert.ErrorIs(t, err, testError, "Expected middleware error")
}

func TestMiddlewareChainAppend(t *testing.T) {
	var callOrder []int
	chain := NewChain()
	chain.Append(recordCall(&callOrder, 1), recordCall(&callOrder, 2))

	assert.Len(t, chain.middlewares, 2, "Expected 2 middlewares")

	// Appended middlewares keep their argument order and run before the handler.
	require.NoError(t, chain.Then(recordCall(&callOrder, 3))(newChainTestContext(t)), "Expected no error")
	assert.Equal(t, []int{1, 2, 3}, callOrder, "Expected appended middlewares to run in order")
}

func TestMiddlewareChainPrepend(t *testing.T) {
	var callOrder []int
	chain := NewChain()
	chain.Append(recordCall(&callOrder, 1))
	chain.Prepend(recordCall(&callOrder, 2))

	assert.Len(t, chain.middlewares, 2, "Expected 2 middlewares")

	// A prepended middleware must run before the one appended earlier.
	require.NoError(t, chain.Then(recordCall(&callOrder, 3))(newChainTestContext(t)), "Expected no error")
	assert.Equal(t, []int{2, 1, 3}, callOrder, "Expected prepended middleware to run first")
}

// mockConnection is a minimal connection stub for tests: it reports a fixed
// ID and loopback addresses, and its Write/Close are no-ops that return nil.
type mockConnection struct {
	id string
}

// ID returns the stub's configured identifier.
func (m *mockConnection) ID() string { return m.id }

// RemoteAddr returns a fixed loopback address.
func (m *mockConnection) RemoteAddr() string { return "127.0.0.1:6995" }

// LocalAddr returns a fixed loopback address.
func (m *mockConnection) LocalAddr() string { return "127.0.0.1:6995" }

// Write discards data and always succeeds.
func (m *mockConnection) Write(data []byte) error { return nil }

// Close is a no-op that always succeeds.
func (m *mockConnection) Close() error { return nil }