You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

140 lines
3.8 KiB
Go

package integration
import (
"fmt"
"net"
"testing"
"time"
"github.com/noahlann/nnet/pkg/nnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestErrorHandling starts a server with one route whose handler returns an
// error and one route that writes "ok\n". It checks that the working route
// answers with "ok" and then exercises the failing route, accepting either a
// request error or a non-empty response.
func TestErrorHandling(t *testing.T) {
	// Listen on port 0 to obtain a port number, then close the listener so
	// the test server can be configured with that port.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	tcpAddr := ln.Addr().(*net.TCPAddr)
	freePort := tcpAddr.Port
	ln.Close()

	serverAddr := fmt.Sprintf("tcp://127.0.0.1:%d", freePort)
	conf := &nnet.Config{Addr: serverAddr}

	// Handler that always fails.
	failingHandler := func(ctx nnet.Context) error {
		return fmt.Errorf("handler error")
	}
	// Handler that writes a plain "ok" line.
	okHandler := func(ctx nnet.Context) error {
		return ctx.Response().WriteBytes([]byte("ok\n"))
	}

	server := StartTestServerWithRoutes(t, conf, func(srv nnet.Server) {
		srv.Router().RegisterString("error", failingHandler)
		srv.Router().RegisterString("ok", okHandler)
	})
	defer CleanupTestServer(t, server)

	cli := NewTestClient(t, server.Addr, nil)
	defer CleanupTestClient(t, cli)
	ConnectTestClient(t, cli)

	// Short pause before the first request (presumably to let the
	// connection settle).
	time.Sleep(100 * time.Millisecond)

	// The working route must answer with "ok".
	okReply := RequestWithTimeout(t, cli, []byte("ok"), 3*time.Second)
	assert.Contains(t, string(okReply), "ok", "Normal route should work")

	// The failing route may either answer with some payload or fail the
	// request outright; both outcomes are accepted.
	errReply, reqErr := cli.Request([]byte("error"), 3*time.Second)
	if reqErr == nil {
		t.Logf("Error route response: %q", string(errReply))
		assert.NotEmpty(t, errReply, "Error route should return error response")
		return
	}
	t.Logf("Error route failed as expected: %v", reqErr)
}
// TestConnectionError checks that a request succeeds on a connected client
// and that a request issued after client.Disconnect() returns an error.
func TestConnectionError(t *testing.T) {
	// Listen on port 0 to obtain a port number, then close the listener so
	// the test server can be configured with that port.
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	probeAddr := probe.Addr().(*net.TCPAddr)
	chosenPort := probeAddr.Port
	probe.Close()

	conf := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", chosenPort),
	}

	// Single route that writes a plain "ok" line.
	setupRoutes := func(srv nnet.Server) {
		srv.Router().RegisterString("test", func(ctx nnet.Context) error {
			payload := []byte("ok\n")
			return ctx.Response().WriteBytes(payload)
		})
	}
	server := StartTestServerWithRoutes(t, conf, setupRoutes)
	defer CleanupTestServer(t, server)

	// Create and connect the client.
	cli := NewTestClient(t, server.Addr, nil)
	defer CleanupTestClient(t, cli)
	ConnectTestClient(t, cli)

	const settle = 100 * time.Millisecond
	time.Sleep(settle)

	// A request on the live connection must succeed.
	reply := RequestWithTimeout(t, cli, []byte("test"), 3*time.Second)
	assert.Contains(t, string(reply), "ok", "Request should succeed")

	// Drop the connection and wait for it to close.
	cli.Disconnect()
	time.Sleep(settle)

	// A request on the disconnected client is expected to fail.
	_, reqErr := cli.Request([]byte("test"), time.Second)
	assert.Error(t, reqErr, "Request should fail after disconnect")
}
// TestInvalidData sends a malformed JSON payload to a server configured with
// "json" as its default codec. Either a request error or a non-empty response
// is accepted as the outcome.
func TestInvalidData(t *testing.T) {
	// Listen on port 0 to obtain a port number, then close the listener so
	// the test server can be configured with that port.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	boundAddr := ln.Addr().(*net.TCPAddr)
	portNum := boundAddr.Port
	ln.Close()

	jsonCodec := &nnet.CodecConfig{DefaultCodec: "json"}
	conf := &nnet.Config{
		Addr:  fmt.Sprintf("tcp://127.0.0.1:%d", portNum),
		Codec: jsonCodec,
	}

	// Single route that writes a plain "ok" line.
	okHandler := func(ctx nnet.Context) error {
		return ctx.Response().WriteBytes([]byte("ok\n"))
	}
	server := StartTestServerWithRoutes(t, conf, func(srv nnet.Server) {
		srv.Router().RegisterString("test", okHandler)
	})
	defer CleanupTestServer(t, server)

	cli := NewTestClient(t, server.Addr, nil)
	defer CleanupTestClient(t, cli)
	ConnectTestClient(t, cli)

	// Short pause before the first request (presumably to let the
	// connection settle).
	time.Sleep(100 * time.Millisecond)

	// Send a payload that is not valid JSON.
	malformed := []byte("invalid json{")
	reply, reqErr := cli.Request(malformed, 3*time.Second)
	switch {
	case reqErr != nil:
		// A failed request is an acceptable outcome.
		t.Logf("Invalid data request failed as expected: %v", reqErr)
	default:
		// If a response arrives, it must carry some content.
		t.Logf("Invalid data response: %q", string(reply))
		assert.NotEmpty(t, reply, "Server should handle invalid data")
	}
}