You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

172 lines
4.2 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package protocol
import (
"context"
"fmt"
"testing"
"time"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestProtocolManager registers a single protocol and checks that it can be
// looked up both with an explicit version and with an empty version, which
// falls back to the first version registered under that name.
func TestProtocolManager(t *testing.T) {
	mgr := NewManager()

	// Register "test" at version "1.0".
	stub := &mockProtocol{name: "test", version: "1.0"}
	require.NoError(t, mgr.Register(stub), "Expected no error when registering protocol")

	// Explicit name + version lookup.
	exact, err := mgr.Get("test", "1.0")
	require.NoError(t, err, "Expected no error when getting protocol")
	assert.Equal(t, "test", exact.Name(), "Expected protocol name to match")
	assert.Equal(t, "1.0", exact.Version(), "Expected protocol version to match")

	// Empty version: the manager should fall back to the first registered version.
	fallback, err := mgr.Get("test", "")
	require.NoError(t, err, "Expected no error when getting protocol without version")
	assert.Equal(t, "test", fallback.Name())
	assert.Equal(t, "1.0", fallback.Version())
}
// TestProtocolManagerGetDefault checks that, when no default has been set
// explicitly, GetDefault returns the first protocol that was registered.
func TestProtocolManagerGetDefault(t *testing.T) {
	mgr := NewManager()

	first := &mockProtocol{name: "test", version: "1.0"}
	require.NoError(t, mgr.Register(first))

	got := mgr.GetDefault()
	require.NotNil(t, got, "Expected default protocol to be non-nil")
	assert.Equal(t, "test", got.Name(), "Expected default protocol name to match")
}
// TestProtocolManagerRegisterNil checks that registering a nil protocol is
// rejected with an error.
func TestProtocolManagerRegisterNil(t *testing.T) {
	err := NewManager().Register(nil)
	assert.Error(t, err, "Expected error for nil protocol")
}
// TestProtocolManagerGetNotFound checks that looking up a protocol name that
// was never registered returns an error.
func TestProtocolManagerGetNotFound(t *testing.T) {
	_, err := NewManager().Get("nonexistent", "1.0")
	assert.Error(t, err, "Expected error for nonexistent protocol")
}
// TestProtocolManagerSetDefault checks that SetDefault overrides the default
// protocol (which would otherwise be the first one registered) with an
// explicitly chosen name/version.
func TestProtocolManagerSetDefault(t *testing.T) {
	manager := NewManager()

	// Register two protocols; test1 is registered first.
	require.NoError(t, manager.Register(&mockProtocol{name: "test1", version: "1.0"}))
	require.NoError(t, manager.Register(&mockProtocol{name: "test2", version: "1.0"}))

	// Promote the second protocol to be the default.
	err := manager.SetDefault("test2", "1.0")
	require.NoError(t, err, "Expected no error when setting default protocol")

	// Check for nil first so a regression fails cleanly instead of panicking
	// on the Name() call below.
	defaultProto := manager.GetDefault()
	require.NotNil(t, defaultProto, "Expected default protocol to be non-nil")
	assert.Equal(t, "test2", defaultProto.Name(), "Expected default protocol to be test2")
}
// TestProtocolManagerList checks that List reports the registered protocols.
//
// GreaterOrEqual (rather than Equal) keeps the assertion valid even if
// NewManager were to pre-register protocols of its own.
// NOTE(review): tighten to Equal if NewManager is known to start empty.
func TestProtocolManagerList(t *testing.T) {
	manager := NewManager()

	// Fail immediately if registration itself breaks; otherwise the List
	// assertion below would report a misleading cause.
	require.NoError(t, manager.Register(&mockProtocol{name: "test1", version: "1.0"}))
	require.NoError(t, manager.Register(&mockProtocol{name: "test2", version: "1.0"}))

	protocols := manager.List()
	assert.GreaterOrEqual(t, len(protocols), 2, "Expected at least 2 protocols")
}
// TestProtocolManagerConcurrent registers several distinct protocols from
// separate goroutines at once. It checks that every registration succeeds and
// that each protocol is retrievable afterwards. Run with -race to also catch
// unsynchronized access inside the manager.
func TestProtocolManagerConcurrent(t *testing.T) {
	t.Parallel()
	manager := NewManager()

	const workers = 10

	// Buffered to the worker count so no goroutine blocks on send, even if
	// the test aborts early on timeout.
	errs := make(chan error, workers)
	for i := 0; i < workers; i++ {
		go func(id int) {
			errs <- manager.Register(&mockProtocol{
				name:    fmt.Sprintf("test%d", id),
				version: "1.0",
			})
		}(i)
	}

	// Bound the wait so a deadlock inside Register fails the test instead of
	// hanging it.
	timeout := time.NewTimer(5 * time.Second)
	defer timeout.Stop()
	for i := 0; i < workers; i++ {
		select {
		case err := <-errs:
			assert.NoError(t, err, "Expected no error when registering protocol concurrently")
		case <-timeout.C:
			t.Fatal("Test timeout: concurrent registration took too long")
		}
	}

	// Every concurrently registered protocol must be visible afterwards.
	for i := 0; i < workers; i++ {
		name := fmt.Sprintf("test%d", i)
		proto, err := manager.Get(name, "1.0")
		if assert.NoError(t, err, "Expected registered protocol %s to be retrievable", name) {
			assert.Equal(t, name, proto.Name())
		}
	}
}
// mockProtocol is a minimal protocol stub used by the manager tests. It
// reports a fixed name and version, has no frame header, passes payloads
// through unchanged and provides no unpacker.
type mockProtocol struct {
	name    string
	version string
}

// Name returns the configured protocol name.
func (p *mockProtocol) Name() string { return p.name }

// Version returns the configured protocol version.
func (p *mockProtocol) Version() string { return p.version }

// HasHeader always reports false.
func (p *mockProtocol) HasHeader() bool { return false }

// Encode returns data unchanged; the header argument is ignored.
func (p *mockProtocol) Encode(data []byte, _ protocolpkg.FrameHeader) ([]byte, error) {
	return data, nil
}

// Decode returns a nil header and the input unchanged as the body.
func (p *mockProtocol) Decode(data []byte) (protocolpkg.FrameHeader, []byte, error) {
	return nil, data, nil
}

// Handle echoes data back unchanged; the context is ignored.
func (p *mockProtocol) Handle(_ context.Context, data []byte) ([]byte, error) {
	return data, nil
}

// Unpacker always returns nil.
func (p *mockProtocol) Unpacker() unpackerpkg.Unpacker { return nil }