|
|
package protocol
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
"sync"
|
|
|
|
|
|
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
|
|
|
)
|
|
|
|
|
|
// manager 协议管理器实现
|
|
|
type manager struct {
|
|
|
protocols map[string]map[string]protocolpkg.Protocol
|
|
|
defaultProto protocolpkg.Protocol
|
|
|
versionIdentifiers map[string]protocolpkg.VersionIdentifier
|
|
|
mu sync.RWMutex
|
|
|
}
|
|
|
|
|
|
// NewManager 创建协议管理器
|
|
|
func NewManager() protocolpkg.Manager {
|
|
|
return &manager{
|
|
|
protocols: make(map[string]map[string]protocolpkg.Protocol),
|
|
|
versionIdentifiers: make(map[string]protocolpkg.VersionIdentifier),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Register 注册协议
|
|
|
func (m *manager) Register(protocol protocolpkg.Protocol) error {
|
|
|
m.mu.Lock()
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
|
if protocol == nil {
|
|
|
return protocolpkg.NewError("protocol cannot be nil")
|
|
|
}
|
|
|
|
|
|
name := protocol.Name()
|
|
|
version := protocol.Version()
|
|
|
|
|
|
if name == "" {
|
|
|
return protocolpkg.NewError("protocol name cannot be empty")
|
|
|
}
|
|
|
|
|
|
if version == "" {
|
|
|
version = "default"
|
|
|
}
|
|
|
|
|
|
if m.protocols[name] == nil {
|
|
|
m.protocols[name] = make(map[string]protocolpkg.Protocol)
|
|
|
}
|
|
|
|
|
|
m.protocols[name][version] = protocol
|
|
|
|
|
|
// 如果是第一个协议,设置为默认协议
|
|
|
if m.defaultProto == nil {
|
|
|
m.defaultProto = protocol
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Get 获取协议
|
|
|
func (m *manager) Get(name, version string) (protocolpkg.Protocol, error) {
|
|
|
m.mu.RLock()
|
|
|
defer m.mu.RUnlock()
|
|
|
|
|
|
if name == "" {
|
|
|
return m.defaultProto, nil
|
|
|
}
|
|
|
|
|
|
versions, ok := m.protocols[name]
|
|
|
if !ok {
|
|
|
return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name))
|
|
|
}
|
|
|
|
|
|
if version == "" {
|
|
|
version = "default"
|
|
|
if protocol, ok := versions[version]; ok {
|
|
|
return protocol, nil
|
|
|
}
|
|
|
// 如果未显式标记default版本,返回第一个可用版本
|
|
|
for _, protocol := range versions {
|
|
|
return protocol, nil
|
|
|
}
|
|
|
return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s has no registered versions", name))
|
|
|
}
|
|
|
|
|
|
protocol, ok := versions[version]
|
|
|
if !ok {
|
|
|
return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version))
|
|
|
}
|
|
|
|
|
|
return protocol, nil
|
|
|
}
|
|
|
|
|
|
// GetDefault 获取默认协议
|
|
|
func (m *manager) GetDefault() protocolpkg.Protocol {
|
|
|
m.mu.RLock()
|
|
|
defer m.mu.RUnlock()
|
|
|
return m.defaultProto
|
|
|
}
|
|
|
|
|
|
// SetDefault 设置默认协议
|
|
|
func (m *manager) SetDefault(name, version string) error {
|
|
|
m.mu.Lock()
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
|
// 直接在这里获取协议,避免调用Get方法导致死锁(Get也需要锁)
|
|
|
if name == "" {
|
|
|
// 如果名称为空,使用第一个协议作为默认协议
|
|
|
if len(m.protocols) == 0 {
|
|
|
return protocolpkg.NewError("no protocols registered")
|
|
|
}
|
|
|
// 获取第一个协议
|
|
|
for _, versions := range m.protocols {
|
|
|
for _, protocol := range versions {
|
|
|
m.defaultProto = protocol
|
|
|
return nil
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
versions, ok := m.protocols[name]
|
|
|
if !ok {
|
|
|
return protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name))
|
|
|
}
|
|
|
|
|
|
if version == "" {
|
|
|
version = "default"
|
|
|
}
|
|
|
|
|
|
protocol, ok := versions[version]
|
|
|
if !ok {
|
|
|
return protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version))
|
|
|
}
|
|
|
|
|
|
m.defaultProto = protocol
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// List 列出所有协议
|
|
|
func (m *manager) List() []protocolpkg.Protocol {
|
|
|
m.mu.RLock()
|
|
|
defer m.mu.RUnlock()
|
|
|
|
|
|
var protocols []protocolpkg.Protocol
|
|
|
for _, versions := range m.protocols {
|
|
|
for _, protocol := range versions {
|
|
|
protocols = append(protocols, protocol)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return protocols
|
|
|
}
|
|
|
|
|
|
// RegisterVersionIdentifier 注册版本识别器
|
|
|
func (m *manager) RegisterVersionIdentifier(name string, identifier protocolpkg.VersionIdentifier) error {
|
|
|
m.mu.Lock()
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
|
if name == "" {
|
|
|
return protocolpkg.NewError("protocol name cannot be empty")
|
|
|
}
|
|
|
|
|
|
if identifier == nil {
|
|
|
return protocolpkg.NewError("version identifier cannot be nil")
|
|
|
}
|
|
|
|
|
|
m.versionIdentifiers[name] = identifier
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// GetVersionIdentifier 获取版本识别器
|
|
|
func (m *manager) GetVersionIdentifier(name string) protocolpkg.VersionIdentifier {
|
|
|
m.mu.RLock()
|
|
|
defer m.mu.RUnlock()
|
|
|
|
|
|
return m.versionIdentifiers[name]
|
|
|
}
|
|
|
|
|
|
// IdentifyVersion 识别协议版本
|
|
|
func (m *manager) IdentifyVersion(name string, data []byte, ctx context.Context) (string, error) {
|
|
|
// 获取版本识别器和协议版本映射(一次性获取,避免多次加锁)
|
|
|
m.mu.RLock()
|
|
|
identifier := m.versionIdentifiers[name]
|
|
|
versions, protocolExists := m.protocols[name]
|
|
|
m.mu.RUnlock()
|
|
|
|
|
|
if identifier == nil {
|
|
|
// 如果没有版本识别器,返回空字符串(表示无法识别,使用默认版本)
|
|
|
return "", nil
|
|
|
}
|
|
|
|
|
|
if !protocolExists {
|
|
|
// 协议不存在,返回错误
|
|
|
return "", protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name))
|
|
|
}
|
|
|
|
|
|
// 使用版本识别器识别版本
|
|
|
version, err := identifier.Identify(data, ctx)
|
|
|
if err != nil {
|
|
|
// 识别失败,返回错误
|
|
|
return "", err
|
|
|
}
|
|
|
|
|
|
if version != "" {
|
|
|
// 检查版本是否存在(不需要再次加锁,因为我们已经有了versions的引用)
|
|
|
_, exists := versions[version]
|
|
|
if !exists {
|
|
|
// 版本不存在,返回错误
|
|
|
return "", protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version))
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return version, nil
|
|
|
}
|
|
|
|