You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

141 lines
3.3 KiB
Go

package middleware
import (
"errors"
"testing"
"github.com/noahlann/nnet/internal/codec"
"github.com/noahlann/nnet/internal/request"
"github.com/noahlann/nnet/internal/response"
ctxpkg "github.com/noahlann/nnet/pkg/context"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMiddlewareChain checks that middlewares registered via NewChain run
// front-to-back and that the wrapped handler runs after all of them.
func TestMiddlewareChain(t *testing.T) {
	// Assemble a Context around a mock connection.
	mockConn := &mockConnection{id: "test-conn"}
	testCtx := ctxpkg.New(
		nil,
		mockConn,
		request.New([]byte("test"), nil),
		response.New(mockConn, nil, codec.NewRegistry(), "json"),
	)

	// record builds a step that appends its tag to trace when invoked, so the
	// execution order of middlewares and handler can be observed.
	trace := make([]int, 0)
	record := func(tag int) func(ctxpkg.Context) error {
		return func(ctxpkg.Context) error {
			trace = append(trace, tag)
			return nil
		}
	}

	chain := NewChain(record(1), record(2))
	run := chain.Then(record(3))
	require.NoError(t, run(testCtx), "Expected no error")

	// Middlewares execute in registration order, followed by the handler.
	assert.Equal(t, []int{1, 2, 3}, trace, "Expected call order to be 1, 2, 3")
}
func TestMiddlewareChainEmpty(t *testing.T) {
chain := NewChain()
handler := func(ctx ctxpkg.Context) error {
return nil
}
// 空的中间件链应该直接返回handler
finalHandler := chain.Then(handler)
assert.NotNil(t, finalHandler, "Expected handler to be returned")
}
// TestMiddlewareChainError verifies that an error returned by a middleware is
// propagated out of the chain and stops execution before the handler runs.
func TestMiddlewareChainError(t *testing.T) {
	// Create a test Context.
	conn := &mockConnection{id: "test-conn"}
	req := request.New([]byte("test"), nil)
	codecRegistry := codec.NewRegistry()
	resp := response.New(conn, nil, codecRegistry, "json")
	ctx := ctxpkg.New(nil, conn, req, resp)

	// A middleware that always fails.
	testError := errors.New("middleware error")
	middleware1 := func(ctx ctxpkg.Context) error {
		return testError
	}
	chain := NewChain(middleware1)

	handlerCalled := false
	handler := func(ctx ctxpkg.Context) error {
		handlerCalled = true
		return nil
	}

	// Run the chain; the middleware error must come back to the caller.
	finalHandler := chain.Then(handler)
	err := finalHandler(ctx)
	require.Error(t, err, "Expected error from middleware")
	// ErrorIs (errors.Is semantics) still matches if the chain wraps the error.
	assert.ErrorIs(t, err, testError, "Expected middleware error")
	assert.False(t, handlerCalled, "Expected handler not to run after a middleware error")
}
func TestMiddlewareChainAppend(t *testing.T) {
chain := NewChain()
middleware1 := func(ctx ctxpkg.Context) error {
return nil
}
middleware2 := func(ctx ctxpkg.Context) error {
return nil
}
chain.Append(middleware1, middleware2)
assert.Equal(t, 2, len(chain.middlewares), "Expected 2 middlewares")
}
func TestMiddlewareChainPrepend(t *testing.T) {
chain := NewChain()
middleware1 := func(ctx ctxpkg.Context) error {
return nil
}
middleware2 := func(ctx ctxpkg.Context) error {
return nil
}
chain.Append(middleware1)
chain.Prepend(middleware2)
// Prepend应该在前面
assert.Equal(t, 2, len(chain.middlewares), "Expected 2 middlewares")
}
// mockConnection is a minimal stand-in for a connection, used by the
// middleware tests. It reports the fixed address "127.0.0.1:6995" for both
// ends, and Write and Close are no-ops that always succeed.
type mockConnection struct {
	id string
}

// ID returns the identifier the mock was constructed with.
func (c *mockConnection) ID() string { return c.id }

// RemoteAddr always reports "127.0.0.1:6995".
func (c *mockConnection) RemoteAddr() string { return "127.0.0.1:6995" }

// LocalAddr reports the same fixed address as RemoteAddr.
func (c *mockConnection) LocalAddr() string { return "127.0.0.1:6995" }

// Write discards its input and reports success.
func (c *mockConnection) Write(_ []byte) error { return nil }

// Close does nothing and reports success.
func (c *mockConnection) Close() error { return nil }