You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

161 lines
4.4 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package integration
import (
"encoding/json"
"fmt"
"net"
"testing"
"time"
internalinterceptor "github.com/noahlann/nnet/internal/interceptor"
"github.com/noahlann/nnet/internal/interceptor/builtin"
"github.com/noahlann/nnet/pkg/interceptor"
"github.com/noahlann/nnet/pkg/nnet"
routerpkg "github.com/noahlann/nnet/pkg/router"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestInterceptorValidation checks that an interceptor chain executed inside a
// route handler accepts a valid payload and rejects one that is too short.
//
// The route matches any raw message that starts with "validate". The handler
// strips that prefix plus leading spaces/newlines and runs the remaining payload
// through MinLength(5) -> MaxLength(100) -> a custom interceptor. It then replies
// with a JSON object containing "valid" and either "error" or "data"/"intercepted".
func TestInterceptorValidation(t *testing.T) {
	// Reserve a free port by binding to :0, then release it for the server.
	// NOTE(review): another process could take the port in between; accepted here.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	require.NoError(t, listener.Close())

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
		Codec: &nnet.CodecConfig{
			DefaultCodec:         "json",
			EnableProtocolEncode: false,
		},
	}

	// Create the server; routes must be registered before it is started.
	server, err := nnet.NewServer(cfg)
	require.NoError(t, err)

	// prefix selects this test's route; the payload follows it.
	const prefix = "validate"

	// Build the interceptors executed by the handler below.
	minLenInterceptor := builtin.MinLengthInterceptor(5)
	maxLenInterceptor := builtin.MaxLengthInterceptor(100)
	customInterceptor := interceptor.HandlerFunc(func(data []byte, ctx nnet.Context, next interceptor.Chain) ([]byte, bool, error) {
		if string(data) == "forbidden" {
			return nil, false, interceptor.New("forbidden data")
		}
		// Flag the context so the handler can report that this interceptor ran.
		ctx.Set("intercepted", true)
		return next.Next(data, ctx)
	})

	// Register a custom route that runs the interceptor chain manually.
	server.Router().RegisterCustom(
		func(input routerpkg.MatchInput, ctx nnet.Context) bool {
			data := input.Raw
			return len(data) >= len(prefix) && string(data[:len(prefix)]) == prefix
		},
		func(ctx nnet.Context) error {
			// Extract the payload: drop the routing prefix and any leading
			// spaces/newlines that separate it from the payload.
			rawData := ctx.Request().Raw()
			var actualData []byte
			if len(rawData) > len(prefix) {
				actualData = rawData[len(prefix):]
				for len(actualData) > 0 && (actualData[0] == ' ' || actualData[0] == '\n') {
					actualData = actualData[1:]
				}
			}
			// An empty payload falls back to a default 24-byte string, which is
			// within the 5..100 byte bounds configured above.
			if len(actualData) == 0 {
				actualData = []byte("test data for validation")
			}

			// Execute the interceptor chain in order.
			interceptors := []interceptor.Interceptor{
				minLenInterceptor,
				maxLenInterceptor,
				customInterceptor,
			}
			data, continueProcessing, err := internalinterceptor.Execute(actualData, ctx, interceptors...)
			if err != nil {
				return ctx.Response().Write(map[string]any{
					"error": err.Error(),
					"valid": false,
				})
			}
			if !continueProcessing {
				return ctx.Response().Write(map[string]any{
					"error": "processing stopped",
					"valid": false,
				})
			}
			intercepted := ctx.GetBool("intercepted")
			return ctx.Response().Write(map[string]any{
				"valid":       true,
				"data":        string(data),
				"intercepted": intercepted,
			})
		},
	)

	// Start the server in the background.
	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if err := server.Start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()
	// Register cleanup immediately after launching the goroutine. A failing
	// readiness check below (require.* aborts the test) would otherwise leave
	// the server running past the end of the test.
	defer CleanupTestServer(t, ts)

	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")

	// Wait until the listener actually accepts TCP connections.
	require.Eventually(t, func() bool {
		testConn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond)
		if err != nil {
			return false
		}
		testConn.Close()
		return true
	}, 5*time.Second, 100*time.Millisecond, "Server should be ready")

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)

	// Give the server a moment to finish setting up the new connection.
	time.Sleep(100 * time.Millisecond)

	// decode parses a response into a fresh map. json.Unmarshal keeps existing
	// entries when given a non-nil map, so reusing one map would let keys from an
	// earlier response leak into later assertions. require stops the test on
	// malformed JSON instead of continuing with a nil map.
	decode := func(resp []byte) map[string]any {
		var result map[string]any
		require.NoError(t, json.Unmarshal(resp, &result), "Response should be valid JSON")
		return result
	}

	// Valid payload: "hello world" is within the 5..100 byte bounds.
	validData := "validate hello world"
	resp := RequestWithTimeout(t, client, []byte(validData), 3*time.Second)
	t.Logf("Response for valid data: %q", string(resp))
	result := decode(resp)
	// assert.Equal instead of a bare type assertion: a missing or non-bool
	// field fails the test rather than panicking.
	assert.Equal(t, true, result["valid"], "Data should be valid")

	// Short payload: "hi" is 2 bytes, below the minimum length of 5, so the
	// MinLength interceptor must reject it.
	shortData := "validate hi"
	resp = RequestWithTimeout(t, client, []byte(shortData), 3*time.Second)
	t.Logf("Response for short data: %q", string(resp))
	result = decode(resp)
	assert.Equal(t, false, result["valid"], "Data should be invalid")
	assert.Contains(t, result["error"], "too short", "Error should indicate data is too short")
}