You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
3.3 KiB
Go
125 lines
3.3 KiB
Go
package integration
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
internalprotocol "github.com/noahlann/nnet/internal/protocol/nnet"
|
|
"github.com/noahlann/nnet/pkg/nnet"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// TestProtocolVersion 测试协议版本管理
|
|
func TestProtocolVersion(t *testing.T) {
|
|
// 获取随机端口
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
listener.Close()
|
|
|
|
cfg := &nnet.Config{
|
|
Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
|
|
ApplicationProtocol: "nnet",
|
|
Codec: &nnet.CodecConfig{
|
|
DefaultCodec: "json",
|
|
EnableProtocolEncode: true,
|
|
},
|
|
}
|
|
|
|
// 创建服务器并在启动前注册协议和路由
|
|
server, err := nnet.NewServer(cfg)
|
|
require.NoError(t, err)
|
|
|
|
// 注册多个版本的协议
|
|
pm := server.ProtocolManager()
|
|
|
|
// 注册 v1.0 版本
|
|
protoV1 := internalprotocol.NewNNetProtocol("1.0")
|
|
require.NoError(t, pm.Register(protoV1), "Should register protocol v1.0")
|
|
|
|
// 注册 v2.0 版本
|
|
protoV2 := internalprotocol.NewNNetProtocol("2.0")
|
|
require.NoError(t, pm.Register(protoV2), "Should register protocol v2.0")
|
|
|
|
// 设置默认协议版本为 v1.0
|
|
require.NoError(t, pm.SetDefault("nnet", "1.0"), "Should set default protocol")
|
|
|
|
// 注册路由
|
|
server.Router().RegisterString("version", func(ctx nnet.Context) error {
|
|
proto := pm.GetDefault()
|
|
if proto != nil {
|
|
return ctx.Response().Write(map[string]any{
|
|
"protocol": proto.Name(),
|
|
"version": proto.Version(),
|
|
})
|
|
}
|
|
return ctx.Response().Write(map[string]any{
|
|
"protocol": "unknown",
|
|
"version": "unknown",
|
|
})
|
|
})
|
|
|
|
// 启动服务器
|
|
ts := &TestServer{
|
|
Server: server,
|
|
Addr: cfg.Addr,
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
|
|
ts.wg.Add(1)
|
|
go func() {
|
|
defer ts.wg.Done()
|
|
if err := server.Start(); err != nil {
|
|
t.Logf("Server error: %v", err)
|
|
}
|
|
}()
|
|
|
|
require.Eventually(t, func() bool {
|
|
return server.Started()
|
|
}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")
|
|
|
|
// 等待服务器准备好
|
|
require.Eventually(t, func() bool {
|
|
testConn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
testConn.Close()
|
|
return true
|
|
}, 5*time.Second, 100*time.Millisecond, "Server should be ready")
|
|
|
|
defer CleanupTestServer(t, ts)
|
|
|
|
client := NewTestClient(t, ts.Addr, &nnet.ClientConfig{
|
|
ApplicationProtocol: "nnet",
|
|
})
|
|
defer CleanupTestClient(t, client)
|
|
|
|
ConnectTestClient(t, client)
|
|
|
|
// 等待服务器准备好
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// 使用nnet协议编码请求
|
|
requestPacket, err := protoV1.Encode([]byte("version"), nil)
|
|
require.NoError(t, err, "Failed to encode request with nnet protocol")
|
|
|
|
respPacket := RequestWithTimeout(t, client, requestPacket, 3*time.Second)
|
|
t.Logf("Response for version (raw): %q", string(respPacket))
|
|
|
|
// 解码响应
|
|
_, respPayload, err := protoV1.Decode(respPacket)
|
|
require.NoError(t, err, "Failed to decode response packet")
|
|
|
|
var result map[string]any
|
|
err = json.Unmarshal(respPayload, &result)
|
|
assert.NoError(t, err, "Response should be valid JSON")
|
|
assert.Equal(t, "nnet", result["protocol"], "Protocol should be nnet")
|
|
assert.Equal(t, "1.0", result["version"], "Version should be 1.0")
|
|
}
|
|
|