package protocol

import (
	"context"
	"fmt"
	"testing"
	"time"

	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
	unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestProtocolManager registers a single protocol and checks that it can be
// looked up both with an explicit version and with an empty version.
func TestProtocolManager(t *testing.T) {
	manager := NewManager()

	// Register a protocol.
	mockProtocol := &mockProtocol{
		name:    "test",
		version: "1.0",
	}
	err := manager.Register(mockProtocol)
	require.NoError(t, err, "Expected no error when registering protocol")

	// Look it up by name and explicit version.
	proto, err := manager.Get("test", "1.0")
	require.NoError(t, err, "Expected no error when getting protocol")
	assert.Equal(t, "test", proto.Name(), "Expected protocol name to match")
	assert.Equal(t, "1.0", proto.Version(), "Expected protocol version to match")

	// Omitting the version is expected to fall back to the first registered
	// version; with a single version registered that is 1.0.
	proto, err = manager.Get("test", "")
	require.NoError(t, err, "Expected no error when getting protocol without version")
	assert.Equal(t, "test", proto.Name())
	assert.Equal(t, "1.0", proto.Version())
}

// TestProtocolManagerGetDefault checks that, with exactly one protocol
// registered and no explicit default set, GetDefault returns that protocol.
func TestProtocolManagerGetDefault(t *testing.T) {
	manager := NewManager()

	mockProtocol := &mockProtocol{
		name:    "test",
		version: "1.0",
	}
	err := manager.Register(mockProtocol)
	require.NoError(t, err)

	defaultProto := manager.GetDefault()
	require.NotNil(t, defaultProto, "Expected default protocol to be non-nil")
	assert.Equal(t, "test", defaultProto.Name(), "Expected default protocol name to match")
}

// TestProtocolManagerRegisterNil checks that registering a nil protocol is
// rejected with an error.
func TestProtocolManagerRegisterNil(t *testing.T) {
	manager := NewManager()

	err := manager.Register(nil)
	assert.Error(t, err, "Expected error for nil protocol")
}

// TestProtocolManagerGetNotFound checks that looking up a protocol that was
// never registered returns an error.
func TestProtocolManagerGetNotFound(t *testing.T) {
	manager := NewManager()

	_, err := manager.Get("nonexistent", "1.0")
	assert.Error(t, err, "Expected error for nonexistent protocol")
}

// TestProtocolManagerSetDefault registers two protocols, selects the second
// one as the default and checks that GetDefault reflects the choice.
func TestProtocolManagerSetDefault(t *testing.T) {
	manager := NewManager()

	// Register several protocols.
	proto1 := &mockProtocol{name: "test1", version: "1.0"}
	proto2 := &mockProtocol{name: "test2", version: "1.0"}

	err := manager.Register(proto1)
	require.NoError(t, err)
	err = manager.Register(proto2)
	require.NoError(t, err)

	// Select the default protocol explicitly.
	err = manager.SetDefault("test2", "1.0")
	require.NoError(t, err, "Expected no error when setting default protocol")

	// Guard against a nil default so a regression fails the test instead of
	// panicking on the Name() call below.
	defaultProto := manager.GetDefault()
	require.NotNil(t, defaultProto, "Expected default protocol to be non-nil")
	assert.Equal(t, "test2", defaultProto.Name(), "Expected default protocol to be test2")
}

// TestProtocolManagerList registers two protocols and checks that List
// reports at least that many entries.
func TestProtocolManagerList(t *testing.T) {
	manager := NewManager()

	// Register several protocols; a failed registration must fail the test
	// here rather than surface later as a confusing count mismatch.
	proto1 := &mockProtocol{name: "test1", version: "1.0"}
	proto2 := &mockProtocol{name: "test2", version: "1.0"}

	require.NoError(t, manager.Register(proto1))
	require.NoError(t, manager.Register(proto2))

	// List everything that is registered.
	protocols := manager.List()
	assert.GreaterOrEqual(t, len(protocols), 2, "Expected at least 2 protocols")
}

// TestProtocolManagerConcurrent registers distinct protocols from several
// goroutines at once, checks that every registration succeeds and that each
// protocol can be retrieved afterwards. Run with -race to also detect
// unsynchronized access inside the manager.
func TestProtocolManagerConcurrent(t *testing.T) {
	t.Parallel()

	// The worker count is kept small so the test stays cheap.
	const workers = 10

	manager := NewManager()

	// Buffered to the worker count so no goroutine blocks on its send (and
	// leaks) if the test gives up on the timeout below.
	results := make(chan error, workers)
	for i := 0; i < workers; i++ {
		go func(id int) {
			results <- manager.Register(&mockProtocol{
				name:    fmt.Sprintf("test%d", id),
				version: "1.0",
			})
		}(i)
	}

	// One timer bounds the whole wait, so a deadlock in the manager fails the
	// test quickly instead of hanging the run.
	timeout := time.NewTimer(5 * time.Second)
	defer timeout.Stop()

	for i := 0; i < workers; i++ {
		select {
		case err := <-results:
			// Asserted on the test goroutine; names are unique, so every
			// registration is expected to succeed.
			assert.NoError(t, err, "Expected no error when registering protocol concurrently")
		case <-timeout.C:
			t.Fatal("Test timeout: concurrent registration took too long")
		}
	}

	// Every concurrently registered protocol must be retrievable afterwards.
	for i := 0; i < workers; i++ {
		name := fmt.Sprintf("test%d", i)
		proto, err := manager.Get(name, "1.0")
		require.NoError(t, err, "Expected protocol %s to be registered", name)
		assert.Equal(t, name, proto.Name(), "Expected protocol name to match")
	}
}

// mockProtocol is a minimal protocol stub for the manager tests. It carries
// only a name and a version; every data-path method is a pass-through.
type mockProtocol struct {
	name    string
	version string
}

// Name returns the configured protocol name.
func (m *mockProtocol) Name() string {
	return m.name
}

// Version returns the configured protocol version.
func (m *mockProtocol) Version() string {
	return m.version
}

// HasHeader always reports false.
func (m *mockProtocol) HasHeader() bool {
	return false
}

// Encode returns data unchanged and ignores header.
func (m *mockProtocol) Encode(data []byte, header protocolpkg.FrameHeader) ([]byte, error) {
	return data, nil
}

// Decode returns a nil header together with data unchanged.
func (m *mockProtocol) Decode(data []byte) (protocolpkg.FrameHeader, []byte, error) {
	return nil, data, nil
}

// Handle echoes data back and ignores ctx.
func (m *mockProtocol) Handle(ctx context.Context, data []byte) ([]byte, error) {
	return data, nil
}

// Unpacker always returns nil; the mock provides no unpacker.
func (m *mockProtocol) Unpacker() unpackerpkg.Unpacker {
	return nil
}