|
|
package integration
|
|
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"net"
	"testing"
	"time"

	internalinterceptor "github.com/noahlann/nnet/internal/interceptor"
	"github.com/noahlann/nnet/internal/interceptor/builtin"
	"github.com/noahlann/nnet/pkg/interceptor"
	"github.com/noahlann/nnet/pkg/nnet"
	routerpkg "github.com/noahlann/nnet/pkg/router"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)
|
|
|
|
|
|
// TestInterceptorValidation 测试拦截器验证
|
|
|
func TestInterceptorValidation(t *testing.T) {
|
|
|
// 获取随机端口
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
require.NoError(t, err)
|
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
|
listener.Close()
|
|
|
|
|
|
cfg := &nnet.Config{
|
|
|
Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
|
|
|
Codec: &nnet.CodecConfig{
|
|
|
DefaultCodec: "json",
|
|
|
EnableProtocolEncode: false,
|
|
|
},
|
|
|
}
|
|
|
|
|
|
// 创建服务器并在启动前注册路由
|
|
|
server, err := nnet.NewServer(cfg)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 创建拦截器
|
|
|
minLenInterceptor := builtin.MinLengthInterceptor(5)
|
|
|
maxLenInterceptor := builtin.MaxLengthInterceptor(100)
|
|
|
|
|
|
customInterceptor := interceptor.HandlerFunc(func(data []byte, ctx nnet.Context, next interceptor.Chain) ([]byte, bool, error) {
|
|
|
if string(data) == "forbidden" {
|
|
|
return nil, false, interceptor.New("forbidden data")
|
|
|
}
|
|
|
ctx.Set("intercepted", true)
|
|
|
return next.Next(data, ctx)
|
|
|
})
|
|
|
|
|
|
// 注册路由,使用拦截器链
|
|
|
server.Router().RegisterCustom(
|
|
|
func(input routerpkg.MatchInput, ctx nnet.Context) bool {
|
|
|
data := input.Raw
|
|
|
return len(data) >= 8 && string(data[:8]) == "validate"
|
|
|
},
|
|
|
func(ctx nnet.Context) error {
|
|
|
rawData := ctx.Request().Raw()
|
|
|
var actualData []byte
|
|
|
if len(rawData) > 8 {
|
|
|
actualData = rawData[8:]
|
|
|
for len(actualData) > 0 && (actualData[0] == ' ' || actualData[0] == '\n') {
|
|
|
actualData = actualData[1:]
|
|
|
}
|
|
|
}
|
|
|
if len(actualData) == 0 {
|
|
|
actualData = []byte("test data for validation")
|
|
|
}
|
|
|
|
|
|
// 执行拦截器链
|
|
|
interceptors := []interceptor.Interceptor{
|
|
|
minLenInterceptor,
|
|
|
maxLenInterceptor,
|
|
|
customInterceptor,
|
|
|
}
|
|
|
|
|
|
data, continueProcessing, err := internalinterceptor.Execute(actualData, ctx, interceptors...)
|
|
|
if err != nil {
|
|
|
return ctx.Response().Write(map[string]any{
|
|
|
"error": err.Error(),
|
|
|
"valid": false,
|
|
|
})
|
|
|
}
|
|
|
|
|
|
if !continueProcessing {
|
|
|
return ctx.Response().Write(map[string]any{
|
|
|
"error": "processing stopped",
|
|
|
"valid": false,
|
|
|
})
|
|
|
}
|
|
|
|
|
|
intercepted := ctx.GetBool("intercepted")
|
|
|
|
|
|
return ctx.Response().Write(map[string]any{
|
|
|
"valid": true,
|
|
|
"data": string(data),
|
|
|
"intercepted": intercepted,
|
|
|
})
|
|
|
},
|
|
|
)
|
|
|
|
|
|
// 启动服务器
|
|
|
ts := &TestServer{
|
|
|
Server: server,
|
|
|
Addr: cfg.Addr,
|
|
|
stopCh: make(chan struct{}),
|
|
|
}
|
|
|
|
|
|
ts.wg.Add(1)
|
|
|
go func() {
|
|
|
defer ts.wg.Done()
|
|
|
if err := server.Start(); err != nil {
|
|
|
t.Logf("Server error: %v", err)
|
|
|
}
|
|
|
}()
|
|
|
|
|
|
require.Eventually(t, func() bool {
|
|
|
return server.Started()
|
|
|
}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")
|
|
|
|
|
|
// 等待服务器准备好
|
|
|
require.Eventually(t, func() bool {
|
|
|
testConn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond)
|
|
|
if err != nil {
|
|
|
return false
|
|
|
}
|
|
|
testConn.Close()
|
|
|
return true
|
|
|
}, 5*time.Second, 100*time.Millisecond, "Server should be ready")
|
|
|
|
|
|
defer CleanupTestServer(t, ts)
|
|
|
|
|
|
client := NewTestClient(t, ts.Addr, nil)
|
|
|
defer CleanupTestClient(t, client)
|
|
|
|
|
|
ConnectTestClient(t, client)
|
|
|
|
|
|
// 等待服务器准备好
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
|
|
// 测试有效数据
|
|
|
validData := "validate hello world"
|
|
|
resp := RequestWithTimeout(t, client, []byte(validData), 3*time.Second)
|
|
|
t.Logf("Response for valid data: %q", string(resp))
|
|
|
|
|
|
var result map[string]any
|
|
|
err = json.Unmarshal(resp, &result)
|
|
|
assert.NoError(t, err, "Response should be valid JSON")
|
|
|
assert.True(t, result["valid"].(bool), "Data should be valid")
|
|
|
|
|
|
// 测试数据太短(会被拦截器拒绝)
|
|
|
shortData := "validate hi" // "hi" 只有2字节,小于最小长度5
|
|
|
resp = RequestWithTimeout(t, client, []byte(shortData), 3*time.Second)
|
|
|
t.Logf("Response for short data: %q", string(resp))
|
|
|
|
|
|
err = json.Unmarshal(resp, &result)
|
|
|
assert.NoError(t, err, "Response should be valid JSON")
|
|
|
assert.False(t, result["valid"].(bool), "Data should be invalid")
|
|
|
assert.Contains(t, result["error"].(string), "too short", "Error should indicate data is too short")
|
|
|
}
|