package interceptor

import (
	"errors"
	"testing"

	"github.com/noahlann/nnet/internal/codec"
	"github.com/noahlann/nnet/internal/request"
	"github.com/noahlann/nnet/internal/response"
	ctxpkg "github.com/noahlann/nnet/pkg/context"
	interceptorpkg "github.com/noahlann/nnet/pkg/interceptor"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// newChainTestContext builds the Context shared by the tests in this file.
// It is backed by a mockConnection, a request whose payload is "test", and a
// response created with a fresh codec registry and the "json" codec name.
func newChainTestContext() ctxpkg.Context {
	conn := &mockConnection{id: "test-conn"}
	req := request.New([]byte("test"), nil)
	resp := response.New(conn, nil, codec.NewRegistry(), "json")
	return ctxpkg.New(nil, conn, req, resp)
}

// TestInterceptorChain checks that interceptors run in registration order and
// that the data passes through unchanged when every interceptor forwards it.
func TestInterceptorChain(t *testing.T) {
	ctx := newChainTestContext()

	// Each interceptor records its position before delegating to the next one.
	var callOrder []int
	first := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		callOrder = append(callOrder, 1)
		return next.Next(data, c)
	})
	second := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		callOrder = append(callOrder, 2)
		return next.Next(data, c)
	})
	chain := NewChain(first, second)

	data := []byte("test")
	result, cont, err := chain.Next(data, ctx)

	require.NoError(t, err, "Expected no error")
	assert.True(t, cont, "Expected continue to be true")
	assert.Equal(t, data, result, "Expected data to be unchanged")
	assert.Equal(t, []int{1, 2}, callOrder, "Expected call order to be 1, 2")
}

// TestInterceptorChainEmpty checks that a chain with no interceptors returns
// the input data unchanged and signals that processing should continue.
func TestInterceptorChainEmpty(t *testing.T) {
	ctx := newChainTestContext()
	chain := NewChain()

	data := []byte("test")
	result, cont, err := chain.Next(data, ctx)

	require.NoError(t, err, "Expected no error")
	assert.True(t, cont, "Expected continue to be true")
	assert.Equal(t, data, result, "Expected data to be unchanged")
}

// TestInterceptorChainError checks that an error returned by an interceptor is
// propagated to the caller, that continue is false, and that interceptors
// registered after the failing one are not invoked.
func TestInterceptorChainError(t *testing.T) {
	ctx := newChainTestContext()

	testError := errors.New("interceptor error")
	failing := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		return nil, false, testError
	})
	// Sentinel interceptor: it must never run because the previous one fails
	// without calling next.
	secondCalled := false
	second := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		secondCalled = true
		return next.Next(data, c)
	})
	chain := NewChain(failing, second)

	_, cont, err := chain.Next([]byte("test"), ctx)

	// ErrorIs (rather than Equal) keeps the test valid if the chain wraps errors.
	require.ErrorIs(t, err, testError, "Expected interceptor error")
	assert.False(t, cont, "Expected continue to be false")
	assert.False(t, secondCalled, "Expected chain to stop after the failing interceptor")
}

// TestInterceptorChainStop checks that an interceptor can halt the chain by
// returning continue=false together with replacement data, and that later
// interceptors are not invoked.
func TestInterceptorChainStop(t *testing.T) {
	ctx := newChainTestContext()

	// Stops the chain and substitutes the payload.
	stopper := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		return []byte("modified"), false, nil
	})
	secondCalled := false
	second := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		secondCalled = true
		return next.Next(data, c)
	})
	chain := NewChain(stopper, second)

	result, cont, err := chain.Next([]byte("test"), ctx)

	require.NoError(t, err, "Expected no error")
	assert.False(t, cont, "Expected continue to be false")
	assert.Equal(t, []byte("modified"), result, "Expected modified data")
	assert.False(t, secondCalled, "Expected chain to stop before the second interceptor")
}

// TestExecute checks that Execute runs the supplied interceptor and returns
// the forwarded data when the interceptor delegates to the rest of the chain.
func TestExecute(t *testing.T) {
	ctx := newChainTestContext()

	called := false
	passThrough := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		called = true
		return next.Next(data, c)
	})

	data := []byte("test")
	result, cont, err := Execute(data, ctx, passThrough)

	require.NoError(t, err)
	assert.True(t, called, "Expected interceptor to be invoked")
	assert.True(t, cont)
	assert.Equal(t, data, result)
}

// mockConnection is a minimal stand-in connection for these tests. It has a
// fixed ID, fixed loopback addresses, and no-op Write and Close.
type mockConnection struct {
	id string
}

// ID returns the identifier the mock was constructed with.
func (m *mockConnection) ID() string { return m.id }

// RemoteAddr returns a fixed loopback address.
func (m *mockConnection) RemoteAddr() string { return "127.0.0.1:6995" }

// LocalAddr returns a fixed loopback address.
func (m *mockConnection) LocalAddr() string { return "127.0.0.1:6995" }

// Write discards data and always succeeds.
func (m *mockConnection) Write(data []byte) error { return nil }

// Close is a no-op that always succeeds.
func (m *mockConnection) Close() error { return nil }