You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

152 lines
4.7 KiB
Go

package interceptor
import (
"errors"
"testing"
"github.com/noahlann/nnet/internal/codec"
"github.com/noahlann/nnet/internal/request"
"github.com/noahlann/nnet/internal/response"
ctxpkg "github.com/noahlann/nnet/pkg/context"
interceptorpkg "github.com/noahlann/nnet/pkg/interceptor"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestInterceptorChain checks that a chain of pass-through interceptors runs
// them in registration order and hands back the original payload untouched.
func TestInterceptorChain(t *testing.T) {
	// Build the test Context around a mock connection.
	conn := &mockConnection{id: "test-conn"}
	ctx := ctxpkg.New(
		nil,
		conn,
		request.New([]byte("test"), nil),
		response.New(conn, nil, codec.NewRegistry(), "json"),
	)

	// Each interceptor records its position, then delegates downstream.
	callOrder := []int{}
	first := interceptorpkg.HandlerFunc(func(payload []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		callOrder = append(callOrder, 1)
		return next.Next(payload, c)
	})
	second := interceptorpkg.HandlerFunc(func(payload []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		callOrder = append(callOrder, 2)
		return next.Next(payload, c)
	})
	chain := NewChain(first, second)

	// Drive the chain from the top.
	data := []byte("test")
	result, proceed, err := chain.Next(data, ctx)
	require.NoError(t, err, "Expected no error")
	assert.True(t, proceed, "Expected continue to be true")
	assert.Equal(t, data, result, "Expected data to be unchanged")

	// Registration order must equal execution order.
	assert.Equal(t, []int{1, 2}, callOrder, "Expected call order to be 1, 2")
}
// TestInterceptorChainEmpty checks that a chain with no interceptors is a
// no-op: the payload comes back as-is, with no error, and processing continues.
func TestInterceptorChainEmpty(t *testing.T) {
	conn := &mockConnection{id: "test-conn"}
	ctx := ctxpkg.New(
		nil,
		conn,
		request.New([]byte("test"), nil),
		response.New(conn, nil, codec.NewRegistry(), "json"),
	)

	// With nothing registered, Next should simply echo its input.
	empty := NewChain()
	payload := []byte("test")
	out, proceed, err := empty.Next(payload, ctx)

	require.NoError(t, err, "Expected no error")
	assert.True(t, proceed, "Expected continue to be true")
	assert.Equal(t, payload, out, "Expected data to be unchanged")
}
// TestInterceptorChainError verifies that an error returned by an interceptor
// is propagated out of the chain and that the chain reports it must not continue.
func TestInterceptorChainError(t *testing.T) {
	conn := &mockConnection{id: "test-conn"}
	req := request.New([]byte("test"), nil)
	codecRegistry := codec.NewRegistry()
	resp := response.New(conn, nil, codecRegistry, "json")
	ctx := ctxpkg.New(nil, conn, req, resp)

	// An interceptor that always fails without calling the rest of the chain.
	testError := errors.New("interceptor error")
	failing := interceptorpkg.HandlerFunc(func(_ []byte, _ ctxpkg.Context, _ interceptorpkg.Chain) ([]byte, bool, error) {
		return nil, false, testError
	})
	chain := NewChain(failing)

	// Run the chain; it is expected to surface the interceptor's error.
	data := []byte("test")
	_, cont, err := chain.Next(data, ctx)
	// ErrorIs (errors.Is semantics) also fails on a nil error, and unlike a
	// direct equality check it still matches if the chain wraps the error
	// with additional context.
	assert.ErrorIs(t, err, testError, "Expected interceptor error")
	assert.False(t, cont, "Expected continue to be false")
}
// TestInterceptorChainStop verifies that an interceptor can short-circuit the
// chain: its returned data is passed back to the caller, continue is false, and
// interceptors registered after it are never invoked.
func TestInterceptorChainStop(t *testing.T) {
	conn := &mockConnection{id: "test-conn"}
	req := request.New([]byte("test"), nil)
	codecRegistry := codec.NewRegistry()
	resp := response.New(conn, nil, codecRegistry, "json")
	ctx := ctxpkg.New(nil, conn, req, resp)

	// The first interceptor stops the chain: it returns modified data with
	// continue=false and never calls next.
	stopper := interceptorpkg.HandlerFunc(func(_ []byte, _ ctxpkg.Context, _ interceptorpkg.Chain) ([]byte, bool, error) {
		return []byte("modified"), false, nil
	})
	// A downstream interceptor that must not be reached once the chain stops;
	// without it this test could not detect a chain that ignores the stop.
	downstreamCalled := false
	downstream := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		downstreamCalled = true
		return next.Next(data, c)
	})
	chain := NewChain(stopper, downstream)

	// Run the chain.
	data := []byte("test")
	result, cont, err := chain.Next(data, ctx)
	require.NoError(t, err, "Expected no error")
	assert.False(t, cont, "Expected continue to be false")
	assert.Equal(t, []byte("modified"), result, "Expected modified data")
	assert.False(t, downstreamCalled, "Expected downstream interceptor not to be called")
}
// TestExecute verifies that Execute runs the given interceptor and, for a
// pass-through interceptor, returns the input unchanged with continue=true.
func TestExecute(t *testing.T) {
	conn := &mockConnection{id: "test-conn"}
	req := request.New([]byte("test"), nil)
	codecRegistry := codec.NewRegistry()
	resp := response.New(conn, nil, codecRegistry, "json")
	ctx := ctxpkg.New(nil, conn, req, resp)

	// Record the invocation so the test fails if Execute never calls the
	// interceptor (a pure echo of the input would otherwise still pass).
	called := false
	interceptor1 := interceptorpkg.HandlerFunc(func(data []byte, c ctxpkg.Context, next interceptorpkg.Chain) ([]byte, bool, error) {
		called = true
		return next.Next(data, c)
	})

	data := []byte("test")
	result, cont, err := Execute(data, ctx, interceptor1)
	require.NoError(t, err, "Expected no error")
	assert.True(t, called, "Expected interceptor to be invoked")
	assert.True(t, cont, "Expected continue to be true")
	assert.Equal(t, data, result, "Expected data to be unchanged")
}
// mockConnection is a minimal in-memory connection stub for tests: it reports
// a fixed ID and fixed loopback addresses, and its Write/Close are no-ops.
type mockConnection struct {
	id string // value returned by ID()
}

// ID returns the identifier the mock was constructed with.
func (c *mockConnection) ID() string { return c.id }

// RemoteAddr always reports "127.0.0.1:6995".
func (c *mockConnection) RemoteAddr() string { return "127.0.0.1:6995" }

// LocalAddr reports the same fixed address as RemoteAddr.
func (c *mockConnection) LocalAddr() string { return "127.0.0.1:6995" }

// Write discards data and always succeeds.
func (c *mockConnection) Write(data []byte) error { return nil }

// Close does nothing and always succeeds.
func (c *mockConnection) Close() error { return nil }