You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

185 lines
4.7 KiB
Go

package integration
import (
"encoding/json"
"fmt"
"net"
"testing"
"time"
internalprotocol "github.com/noahlann/nnet/internal/protocol/nnet"
"github.com/noahlann/nnet/pkg/nnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestProtocolFrameHeader tests frame-header route matching: a client sends an
// nnet-encoded request whose payload is {"op":"ping"}, and the test expects the
// handler registered with RegisterFrameHeader("op", "==", "ping", ...) to
// answer with an nnet-encoded JSON body of {"op":"pong","status":"ok"}.
func TestProtocolFrameHeader(t *testing.T) {
	// Pick a free ephemeral port, then release it so the server can bind it.
	// NOTE(review): the port could be taken by someone else in between; this is
	// a common, accepted race in tests.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr:                fmt.Sprintf("tcp://127.0.0.1:%d", port),
		ApplicationProtocol: "nnet",
		Codec: &nnet.CodecConfig{
			DefaultCodec:         "json",
			EnableProtocolEncode: true,
		},
	}
	server, err := nnet.NewServer(cfg)
	require.NoError(t, err)

	// Register the protocol. The same proto instance is reused below to encode
	// the client request and decode the server response.
	pm := server.ProtocolManager()
	proto := internalprotocol.NewNNetProtocol("1.0")
	require.NoError(t, pm.Register(proto), "Should register protocol")

	// Register frame-header routes.
	server.Router().RegisterFrameHeader("op", "==", "ping", func(ctx nnet.Context) error {
		return ctx.Response().Write(map[string]any{
			"op":     "pong",
			"status": "ok",
		})
	})
	// The "echo" route is registered alongside "ping" but is not exercised by
	// the request below.
	server.Router().RegisterFrameHeader("op", "==", "echo", func(ctx nnet.Context) error {
		data := ctx.Request().Data()
		return ctx.Response().Write(map[string]any{
			"op":   "echo",
			"data": data,
		})
	})

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if err := server.Start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()
	// Register cleanup as soon as the server goroutine exists. If the start-up
	// wait below fails (FailNow exits the test), CleanupTestServer still runs
	// instead of the goroutine being left behind, possibly logging via t after
	// the test has completed.
	defer CleanupTestServer(t, ts)

	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start")
	// Extra settle time after Started() reports true.
	time.Sleep(200 * time.Millisecond)

	// Create and connect the client.
	client := NewTestClient(t, ts.Addr, &nnet.ClientConfig{
		ApplicationProtocol: "nnet",
	})
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	time.Sleep(100 * time.Millisecond)

	// Encode the request with the nnet protocol (nil is passed as the second
	// Encode argument).
	pingData := []byte(`{"op":"ping"}`)
	pingPacket, err := proto.Encode(pingData, nil)
	require.NoError(t, err, "Failed to encode ping request")

	resp := RequestWithTimeout(t, client, pingPacket, 3*time.Second)
	t.Logf("Response: %q", string(resp))

	// Decode the response and inspect its JSON payload.
	_, respPayload, err := proto.Decode(resp)
	require.NoError(t, err, "Failed to decode response")

	var result map[string]any
	err = json.Unmarshal(respPayload, &result)
	// Stop here on bad JSON: the field checks below are meaningless on a nil map.
	require.NoError(t, err, "Response should be valid JSON")
	assert.Equal(t, "pong", result["op"], "Op should be pong")
	assert.Equal(t, "ok", result["status"], "Status should be ok")
}
// TestProtocolFrameData tests frame-data route matching: a client sends an
// nnet-encoded request whose payload is {"action":"test"}, and the test expects
// the handler registered with RegisterFrameData("action", "==", "test", ...)
// to answer with an nnet-encoded JSON body of
// {"action":"test","result":"success"}.
func TestProtocolFrameData(t *testing.T) {
	// Pick a free ephemeral port, then release it so the server can bind it.
	// NOTE(review): the port could be taken by someone else in between; this is
	// a common, accepted race in tests.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	listener.Close()

	cfg := &nnet.Config{
		Addr:                fmt.Sprintf("tcp://127.0.0.1:%d", port),
		ApplicationProtocol: "nnet",
		Codec: &nnet.CodecConfig{
			DefaultCodec:         "json",
			EnableProtocolEncode: true,
		},
	}
	server, err := nnet.NewServer(cfg)
	require.NoError(t, err)

	// Register the protocol. The same proto instance is reused below to encode
	// the client request and decode the server response.
	pm := server.ProtocolManager()
	proto := internalprotocol.NewNNetProtocol("1.0")
	require.NoError(t, pm.Register(proto), "Should register protocol")

	// Register the frame-data route.
	server.Router().RegisterFrameData("action", "==", "test", func(ctx nnet.Context) error {
		return ctx.Response().Write(map[string]any{
			"action": "test",
			"result": "success",
		})
	})

	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		if err := server.Start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()
	// Register cleanup as soon as the server goroutine exists. If the start-up
	// wait below fails (FailNow exits the test), CleanupTestServer still runs
	// instead of the goroutine being left behind, possibly logging via t after
	// the test has completed.
	defer CleanupTestServer(t, ts)

	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start")
	// Extra settle time after Started() reports true.
	time.Sleep(200 * time.Millisecond)

	// Create and connect the client.
	client := NewTestClient(t, ts.Addr, &nnet.ClientConfig{
		ApplicationProtocol: "nnet",
	})
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	time.Sleep(100 * time.Millisecond)

	// Encode the request with the nnet protocol (nil is passed as the second
	// Encode argument).
	testData := []byte(`{"action":"test"}`)
	testPacket, err := proto.Encode(testData, nil)
	require.NoError(t, err, "Failed to encode test request")

	resp := RequestWithTimeout(t, client, testPacket, 3*time.Second)
	t.Logf("Response: %q", string(resp))

	// Decode the response and inspect its JSON payload.
	_, respPayload, err := proto.Decode(resp)
	require.NoError(t, err, "Failed to decode response")

	var result map[string]any
	err = json.Unmarshal(respPayload, &result)
	// Stop here on bad JSON: the field checks below are meaningless on a nil map.
	require.NoError(t, err, "Response should be valid JSON")
	assert.Equal(t, "test", result["action"], "Action should be test")
	assert.Equal(t, "success", result["result"], "Result should be success")
}