package integration

import (
	"encoding/json"
	"fmt"
	"net"
	"testing"
	"time"

	internalprotocol "github.com/noahlann/nnet/internal/protocol/nnet"
	"github.com/noahlann/nnet/pkg/nnet"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestProtocolVersion tests protocol version management.
//
// It registers two versions of the nnet protocol ("1.0" and "2.0") on the
// server's protocol manager and sets "1.0" as the default. It then checks that
// a "version" route, which reports the manager's current default protocol,
// returns protocol "nnet" and version "1.0" to a connected client.
func TestProtocolVersion(t *testing.T) {
	// Reserve a free port by binding to :0, then release it so the server can
	// bind it. Another process could grab the port in between; that race is
	// accepted for this test.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	require.NoError(t, listener.Close())

	cfg := &nnet.Config{
		Addr:                fmt.Sprintf("tcp://127.0.0.1:%d", port),
		ApplicationProtocol: "nnet",
		Codec: &nnet.CodecConfig{
			DefaultCodec:         "json",
			EnableProtocolEncode: true,
		},
	}

	// Create the server; protocols and routes are registered before Start.
	server, err := nnet.NewServer(cfg)
	require.NoError(t, err)

	// Register multiple versions of the protocol.
	pm := server.ProtocolManager()

	protoV1 := internalprotocol.NewNNetProtocol("1.0")
	require.NoError(t, pm.Register(protoV1), "Should register protocol v1.0")

	protoV2 := internalprotocol.NewNNetProtocol("2.0")
	require.NoError(t, pm.Register(protoV2), "Should register protocol v2.0")

	// Make v1.0 the default; the handler below reports whichever is default.
	require.NoError(t, pm.SetDefault("nnet", "1.0"), "Should set default protocol")

	// The "version" route replies with the name/version of the current
	// default protocol, or "unknown" for both when none is set.
	server.Router().RegisterString("version", func(ctx nnet.Context) error {
		proto := pm.GetDefault()
		if proto != nil {
			return ctx.Response().Write(map[string]any{
				"protocol": proto.Name(),
				"version":  proto.Version(),
			})
		}
		return ctx.Response().Write(map[string]any{
			"protocol": "unknown",
			"version":  "unknown",
		})
	})

	// Start the server in the background.
	ts := &TestServer{
		Server: server,
		Addr:   cfg.Addr,
		stopCh: make(chan struct{}),
	}
	ts.wg.Add(1)
	go func() {
		defer ts.wg.Done()
		// Shadowed err on purpose: writing the outer err from this goroutine
		// would be a data race.
		if err := server.Start(); err != nil {
			t.Logf("Server error: %v", err)
		}
	}()
	// Register cleanup immediately after launching the goroutine. The
	// require.* readiness checks below call t.FailNow on failure; deferring
	// cleanup only after them would leak the server goroutine, which could
	// then call t.Logf after the test has finished.
	// NOTE(review): assumes CleanupTestServer stops the server and waits on
	// ts.wg, and tolerates a server that never finished starting — confirm.
	defer CleanupTestServer(t, ts)

	require.Eventually(t, func() bool {
		return server.Started()
	}, 3*time.Second, 50*time.Millisecond, "Server should start within 3 seconds")

	// Wait until the port actually accepts TCP connections.
	require.Eventually(t, func() bool {
		testConn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 500*time.Millisecond)
		if err != nil {
			return false
		}
		testConn.Close()
		return true
	}, 5*time.Second, 100*time.Millisecond, "Server should be ready")

	client := NewTestClient(t, ts.Addr, &nnet.ClientConfig{
		ApplicationProtocol: "nnet",
	})
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)

	// Short pause after connecting before sending the request.
	// NOTE(review): the server is already confirmed listening above; this
	// presumably lets the server finish setting up the new connection —
	// confirm whether it is still required.
	time.Sleep(100 * time.Millisecond)

	// Encode the request with the v1.0 nnet protocol.
	requestPacket, err := protoV1.Encode([]byte("version"), nil)
	require.NoError(t, err, "Failed to encode request with nnet protocol")

	respPacket := RequestWithTimeout(t, client, requestPacket, 3*time.Second)
	t.Logf("Response for version (raw): %q", string(respPacket))

	// Decode the response frame, then its JSON payload.
	_, respPayload, err := protoV1.Decode(respPacket)
	require.NoError(t, err, "Failed to decode response packet")

	var result map[string]any
	// require (not assert): the field checks below are meaningless if the
	// payload is not valid JSON.
	require.NoError(t, json.Unmarshal(respPayload, &result), "Response should be valid JSON")
	assert.Equal(t, "nnet", result["protocol"], "Protocol should be nnet")
	assert.Equal(t, "1.0", result["version"], "Version should be 1.0")
}