|
|
package protocol
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
|
|
|
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
func TestProtocolManager(t *testing.T) {
|
|
|
manager := NewManager()
|
|
|
|
|
|
// 测试注册协议
|
|
|
mockProtocol := &mockProtocol{
|
|
|
name: "test",
|
|
|
version: "1.0",
|
|
|
}
|
|
|
err := manager.Register(mockProtocol)
|
|
|
require.NoError(t, err, "Expected no error when registering protocol")
|
|
|
|
|
|
// 测试获取协议
|
|
|
proto, err := manager.Get("test", "1.0")
|
|
|
require.NoError(t, err, "Expected no error when getting protocol")
|
|
|
assert.Equal(t, "test", proto.Name(), "Expected protocol name to match")
|
|
|
assert.Equal(t, "1.0", proto.Version(), "Expected protocol version to match")
|
|
|
|
|
|
// 测试省略版本时自动回退到首个注册版本
|
|
|
proto, err = manager.Get("test", "")
|
|
|
require.NoError(t, err, "Expected no error when getting protocol without version")
|
|
|
assert.Equal(t, "test", proto.Name())
|
|
|
assert.Equal(t, "1.0", proto.Version())
|
|
|
}
|
|
|
|
|
|
func TestProtocolManagerGetDefault(t *testing.T) {
|
|
|
manager := NewManager()
|
|
|
|
|
|
// 测试获取默认协议(第一个注册的协议)
|
|
|
mockProtocol := &mockProtocol{
|
|
|
name: "test",
|
|
|
version: "1.0",
|
|
|
}
|
|
|
err := manager.Register(mockProtocol)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
defaultProto := manager.GetDefault()
|
|
|
require.NotNil(t, defaultProto, "Expected default protocol to be non-nil")
|
|
|
assert.Equal(t, "test", defaultProto.Name(), "Expected default protocol name to match")
|
|
|
}
|
|
|
|
|
|
func TestProtocolManagerRegisterNil(t *testing.T) {
|
|
|
manager := NewManager()
|
|
|
|
|
|
// 测试注册nil协议
|
|
|
err := manager.Register(nil)
|
|
|
assert.Error(t, err, "Expected error for nil protocol")
|
|
|
}
|
|
|
|
|
|
func TestProtocolManagerGetNotFound(t *testing.T) {
|
|
|
manager := NewManager()
|
|
|
|
|
|
// 测试获取不存在的协议
|
|
|
_, err := manager.Get("nonexistent", "1.0")
|
|
|
assert.Error(t, err, "Expected error for nonexistent protocol")
|
|
|
}
|
|
|
|
|
|
func TestProtocolManagerSetDefault(t *testing.T) {
|
|
|
manager := NewManager()
|
|
|
|
|
|
// 注册多个协议
|
|
|
proto1 := &mockProtocol{name: "test1", version: "1.0"}
|
|
|
proto2 := &mockProtocol{name: "test2", version: "1.0"}
|
|
|
|
|
|
err := manager.Register(proto1)
|
|
|
require.NoError(t, err)
|
|
|
err = manager.Register(proto2)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 设置默认协议
|
|
|
err = manager.SetDefault("test2", "1.0")
|
|
|
require.NoError(t, err, "Expected no error when setting default protocol")
|
|
|
|
|
|
// 验证默认协议
|
|
|
defaultProto := manager.GetDefault()
|
|
|
assert.Equal(t, "test2", defaultProto.Name(), "Expected default protocol to be test2")
|
|
|
}
|
|
|
|
|
|
func TestProtocolManagerList(t *testing.T) {
|
|
|
manager := NewManager()
|
|
|
|
|
|
// 注册多个协议
|
|
|
proto1 := &mockProtocol{name: "test1", version: "1.0"}
|
|
|
proto2 := &mockProtocol{name: "test2", version: "1.0"}
|
|
|
|
|
|
manager.Register(proto1)
|
|
|
manager.Register(proto2)
|
|
|
|
|
|
// 列出所有协议
|
|
|
protocols := manager.List()
|
|
|
assert.GreaterOrEqual(t, len(protocols), 2, "Expected at least 2 protocols")
|
|
|
}
|
|
|
|
|
|
func TestProtocolManagerConcurrent(t *testing.T) {
|
|
|
// 使用较短的超时,避免卡住
|
|
|
t.Parallel()
|
|
|
manager := NewManager()
|
|
|
|
|
|
// 并发注册协议(减少并发数,避免卡住)
|
|
|
done := make(chan bool, 10)
|
|
|
for i := 0; i < 10; i++ {
|
|
|
go func(id int) {
|
|
|
proto := &mockProtocol{
|
|
|
name: fmt.Sprintf("test%d", id),
|
|
|
version: "1.0",
|
|
|
}
|
|
|
manager.Register(proto)
|
|
|
done <- true
|
|
|
}(i)
|
|
|
}
|
|
|
|
|
|
// 等待所有goroutine完成,使用超时
|
|
|
timeout := time.NewTimer(5 * time.Second)
|
|
|
defer timeout.Stop()
|
|
|
|
|
|
for i := 0; i < 10; i++ {
|
|
|
select {
|
|
|
case <-done:
|
|
|
// OK
|
|
|
case <-timeout.C:
|
|
|
t.Fatal("Test timeout: concurrent registration took too long")
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// mockProtocol 模拟协议
|
|
|
type mockProtocol struct {
|
|
|
name string
|
|
|
version string
|
|
|
}
|
|
|
|
|
|
func (m *mockProtocol) Name() string {
|
|
|
return m.name
|
|
|
}
|
|
|
|
|
|
func (m *mockProtocol) Version() string {
|
|
|
return m.version
|
|
|
}
|
|
|
|
|
|
func (m *mockProtocol) HasHeader() bool {
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
func (m *mockProtocol) Encode(data []byte, header protocolpkg.FrameHeader) ([]byte, error) {
|
|
|
return data, nil
|
|
|
}
|
|
|
|
|
|
func (m *mockProtocol) Decode(data []byte) (protocolpkg.FrameHeader, []byte, error) {
|
|
|
return nil, data, nil
|
|
|
}
|
|
|
|
|
|
func (m *mockProtocol) Handle(ctx context.Context, data []byte) ([]byte, error) {
|
|
|
return data, nil
|
|
|
}
|
|
|
|
|
|
func (m *mockProtocol) Unpacker() unpackerpkg.Unpacker {
|
|
|
return nil
|
|
|
}
|