package integration

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net"
	"testing"
	"time"

	internalinterceptor "github.com/noahlann/nnet/internal/interceptor"
	"github.com/noahlann/nnet/internal/interceptor/builtin"
	"github.com/noahlann/nnet/pkg/interceptor"
	"github.com/noahlann/nnet/pkg/nnet"
	routerpkg "github.com/noahlann/nnet/pkg/router"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestInterceptorValidation runs an interceptor chain (minimum length 5,
// maximum length 100, plus a custom interceptor) from inside a custom route
// handler. It checks that an acceptable payload is reported as valid and that
// a payload shorter than the minimum is rejected with a JSON error response.
func TestInterceptorValidation(t *testing.T) {
	// Reserve a free port by binding to :0, then release it for the server.
	// NOTE(review): another process could take the port between Close and
	// server.Start; acceptable for a test but a possible source of rare flakes.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	require.NoError(t, listener.Close())

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
		Codec: &nnet.CodecConfig{
			DefaultCodec:         "json",
			EnableProtocolEncode: false,
		},
	}

	// Create the server; routes are registered below, before it is started.
	server, err := nnet.NewServer(cfg)
	require.NoError(t, err)

	// The custom interceptor rejects the literal payload "forbidden" and
	// otherwise marks the context so the handler can report that it ran.
	customInterceptor := interceptor.HandlerFunc(func(data []byte, ctx nnet.Context, next interceptor.Chain) ([]byte, bool, error) {
		if string(data) == "forbidden" {
			return nil, false, interceptor.New("forbidden data")
		}
		ctx.Set("intercepted", true)
		return next.Next(data, ctx)
	})

	// The chain is the same for every request, so build it once instead of
	// once per handled message.
	interceptors := []interceptor.Interceptor{
		builtin.MinLengthInterceptor(5),
		builtin.MaxLengthInterceptor(100),
		customInterceptor,
	}

	// Messages addressed to this route start with this command prefix.
	const cmdPrefix = "validate"

	server.Router().RegisterCustom(
		func(input routerpkg.MatchInput, ctx nnet.Context) bool {
			data := input.Raw
			return len(data) >= len(cmdPrefix) && string(data[:len(cmdPrefix)]) == cmdPrefix
		},
		func(ctx nnet.Context) error {
			// Strip the command prefix and any leading spaces/newlines to get
			// the payload fed through the interceptor chain; fall back to a
			// default payload when nothing follows the prefix.
			var payload []byte
			if raw := ctx.Request().Raw(); len(raw) > len(cmdPrefix) {
				payload = bytes.TrimLeft(raw[len(cmdPrefix):], " \n")
			}
			if len(payload) == 0 {
				payload = []byte("test data for validation")
			}

			data, continueProcessing, err := internalinterceptor.Execute(payload, ctx, interceptors...)
			if err != nil {
				return ctx.Response().Write(map[string]any{
					"error": err.Error(),
					"valid": false,
				})
			}
			if !continueProcessing {
				return ctx.Response().Write(map[string]any{
					"error": "processing stopped",
					"valid": false,
				})
			}

			intercepted := ctx.GetBool("intercepted")
			return ctx.Response().Write(map[string]any{
				"valid":       true,
				"data":        string(data),
				"intercepted": intercepted,
			})
		},
	)

	// Start the server in the background.
	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if err := server.Start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()
	// Register cleanup immediately so the server is stopped even if one of
	// the readiness checks below fails the test; otherwise the server
	// goroutine would outlive the test.
	defer CleanupTestServer(t, ts)

	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")

	// Wait until the listener actually accepts TCP connections.
	serverAddr := fmt.Sprintf("127.0.0.1:%d", port)
	require.Eventually(t, func() bool {
		conn, err := net.DialTimeout("tcp", serverAddr, 500*time.Millisecond)
		if err != nil {
			return false
		}
		conn.Close()
		return true
	}, 5*time.Second, 100*time.Millisecond, "Server should be ready")

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)

	// Give the server a moment to finish setting up the new connection.
	time.Sleep(100 * time.Millisecond)

	// request sends payload and decodes the JSON reply into a fresh map.
	// A fresh map per call matters: json.Unmarshal into an existing map keeps
	// its old entries, so keys from a previous response would leak through.
	// Decoding failures stop the test, since later lookups would be meaningless.
	request := func(label, payload string) map[string]any {
		t.Helper()
		resp := RequestWithTimeout(t, client, []byte(payload), 3*time.Second)
		t.Logf("Response for %s: %q", label, string(resp))
		var result map[string]any
		require.NoError(t, json.Unmarshal(resp, &result), "Response should be valid JSON")
		return result
	}

	// A payload within [5, 100] bytes passes every interceptor.
	result := request("valid data", "validate hello world")
	// Compare against the interface value rather than type-asserting, so a
	// missing or non-bool field fails the assertion instead of panicking.
	assert.Equal(t, true, result["valid"], "Data should be valid")

	// "validate hi" leaves the 2-byte payload "hi", below the 5-byte minimum,
	// so the min-length interceptor must reject it.
	result = request("short data", "validate hi")
	assert.Equal(t, false, result["valid"], "Data should be invalid")
	errMsg, _ := result["error"].(string) // missing/non-string -> "" and Contains fails
	assert.Contains(t, errMsg, "too short", "Error should indicate data is too short")
}