You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

217 lines
5.1 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package protocol
import (
"context"
"fmt"
"sync"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
)
// manager is the protocolpkg.Manager implementation returned by NewManager.
//
// Every field below is read and written only while holding mu.
type manager struct {
	mu sync.RWMutex

	// protocols indexes registered protocols by name, then by version.
	// A protocol whose Version() is empty is stored under the key "default".
	protocols map[string]map[string]protocolpkg.Protocol
	// versionIdentifiers maps a protocol name to its version identifier.
	versionIdentifiers map[string]protocolpkg.VersionIdentifier
	// defaultProto is returned by GetDefault and by Get with an empty name;
	// nil until the first protocol is registered.
	defaultProto protocolpkg.Protocol
}
// NewManager returns an empty protocol manager: no protocols, no version
// identifiers and no default protocol.
func NewManager() protocolpkg.Manager {
	m := new(manager)
	m.protocols = make(map[string]map[string]protocolpkg.Protocol)
	m.versionIdentifiers = make(map[string]protocolpkg.VersionIdentifier)
	return m
}
// Register adds protocol under its Name() and Version(). An empty version is
// stored under the key "default". Registering a name/version pair that already
// exists replaces the previous entry.
//
// The first protocol ever registered becomes the default protocol. If a
// re-registration replaces the protocol that is currently the default, the
// default is switched to the new instance so that GetDefault and Get keep
// returning the same object.
//
// It returns an error if protocol is nil or its name is empty.
func (m *manager) Register(protocol protocolpkg.Protocol) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if protocol == nil {
		return protocolpkg.NewError("protocol cannot be nil")
	}
	name := protocol.Name()
	version := protocol.Version()
	if name == "" {
		return protocolpkg.NewError("protocol name cannot be empty")
	}
	if version == "" {
		version = "default"
	}

	versions := m.protocols[name]
	if versions == nil {
		versions = make(map[string]protocolpkg.Protocol)
		m.protocols[name] = versions
	}
	versions[version] = protocol

	if m.defaultProto == nil {
		// The first registered protocol becomes the default.
		m.defaultProto = protocol
		return nil
	}

	// If the default lives under the key that was just overwritten, point it
	// at the replacement; otherwise it would reference an instance that is no
	// longer registered. Compare by name/version rather than with == on the
	// interfaces, which panics for non-comparable dynamic types.
	defaultVersion := m.defaultProto.Version()
	if defaultVersion == "" {
		defaultVersion = "default"
	}
	if m.defaultProto.Name() == name && defaultVersion == version {
		m.defaultProto = protocol
	}
	return nil
}
// Get returns the protocol registered under name and version.
//
// An empty name returns the default protocol. If no default is set (nothing
// has been registered), it returns an error rather than a nil protocol with a
// nil error.
//
// An empty version selects the protocol registered under "default" (one whose
// Version() was empty). If there is none, the lexicographically smallest
// registered version is returned, so repeated calls give the same answer
// regardless of map iteration order.
//
// It returns an error if name or the requested version is not registered.
func (m *manager) Get(name, version string) (protocolpkg.Protocol, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if name == "" {
		if m.defaultProto == nil {
			return nil, protocolpkg.NewError("no default protocol set")
		}
		return m.defaultProto, nil
	}

	versions, ok := m.protocols[name]
	if !ok {
		return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name))
	}

	if version == "" {
		if protocol, ok := versions["default"]; ok {
			return protocol, nil
		}
		// No explicit "default" version: choose the smallest version key.
		// Taking "the first" entry of a map range would be random per call.
		var (
			chosen protocolpkg.Protocol
			best   string
			found  bool
		)
		for v, p := range versions {
			if !found || v < best {
				best, chosen, found = v, p, true
			}
		}
		if found {
			return chosen, nil
		}
		return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s has no registered versions", name))
	}

	protocol, ok := versions[version]
	if !ok {
		return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version))
	}
	return protocol, nil
}
// GetDefault returns the current default protocol, or nil if none has been
// set yet.
func (m *manager) GetDefault() protocolpkg.Protocol {
	m.mu.RLock()
	current := m.defaultProto
	m.mu.RUnlock()
	return current
}
// SetDefault makes the protocol registered under name and version the default
// returned by GetDefault and by Get with an empty name.
//
// Versions are resolved the same way Get resolves them: an empty version
// selects "default" and otherwise falls back to the smallest registered
// version. An empty name selects the protocol with the lexicographically
// smallest name; version is ignored in that case. Both choices are
// deterministic and do not depend on map iteration order.
//
// The lookup is done inline instead of calling Get because Get acquires m.mu
// itself, and sync.RWMutex is not reentrant.
//
// It returns an error if nothing is registered, or if name or version is not
// registered.
func (m *manager) SetDefault(name, version string) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if name == "" {
		if len(m.protocols) == 0 {
			return protocolpkg.NewError("no protocols registered")
		}
		first := true
		for n := range m.protocols {
			if first || n < name {
				name, first = n, false
			}
		}
		// The original behavior ignores the version when no name is given;
		// resolve the chosen protocol's default version instead.
		version = ""
	}

	versions, ok := m.protocols[name]
	if !ok {
		return protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name))
	}

	if version != "" {
		protocol, ok := versions[version]
		if !ok {
			return protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version))
		}
		m.defaultProto = protocol
		return nil
	}

	if protocol, ok := versions["default"]; ok {
		m.defaultProto = protocol
		return nil
	}
	// No explicit "default" version: pick the smallest version key (as Get does).
	var (
		chosen protocolpkg.Protocol
		best   string
		found  bool
	)
	for v, p := range versions {
		if !found || v < best {
			best, chosen, found = v, p, true
		}
	}
	if !found {
		return protocolpkg.NewError(fmt.Sprintf("protocol %s has no registered versions", name))
	}
	m.defaultProto = chosen
	return nil
}
// List returns every registered protocol across all names and versions.
// The order follows map iteration and is therefore unspecified; the result is
// nil when nothing has been registered.
func (m *manager) List() []protocolpkg.Protocol {
	m.mu.RLock()
	defer m.mu.RUnlock()

	var all []protocolpkg.Protocol
	for _, byVersion := range m.protocols {
		for _, p := range byVersion {
			all = append(all, p)
		}
	}
	return all
}
// RegisterVersionIdentifier associates identifier with the protocol name so
// IdentifyVersion can use it. A later call for the same name replaces the
// earlier identifier. The protocol does not have to be registered first.
//
// It returns an error if name is empty or identifier is nil.
func (m *manager) RegisterVersionIdentifier(name string, identifier protocolpkg.VersionIdentifier) error {
	// Argument validation touches no shared state, so do it before locking.
	switch {
	case name == "":
		return protocolpkg.NewError("protocol name cannot be empty")
	case identifier == nil:
		return protocolpkg.NewError("version identifier cannot be nil")
	}

	m.mu.Lock()
	m.versionIdentifiers[name] = identifier
	m.mu.Unlock()
	return nil
}
// GetVersionIdentifier returns the version identifier registered for name, or
// nil if there is none.
func (m *manager) GetVersionIdentifier(name string) protocolpkg.VersionIdentifier {
	m.mu.RLock()
	identifier := m.versionIdentifiers[name]
	m.mu.RUnlock()
	return identifier
}
// IdentifyVersion uses the version identifier registered for name to detect
// which version of the protocol data belongs to.
//
// It returns ("", nil) when no identifier is registered for name, or when the
// identifier reports an empty version; the caller is then expected to use the
// default version. It returns an error if the protocol name is not registered,
// if the identifier fails, or if the identified version is not registered.
//
// ctx is not the first parameter, presumably to match the protocolpkg.Manager
// interface this type implements (and VersionIdentifier.Identify, which takes
// the same order); it cannot be reordered here without changing that interface.
func (m *manager) IdentifyVersion(name string, data []byte, ctx context.Context) (string, error) {
	m.mu.RLock()
	identifier := m.versionIdentifiers[name]
	_, protocolExists := m.protocols[name]
	m.mu.RUnlock()

	if identifier == nil {
		// No identifier: the version cannot be detected; fall back to default.
		return "", nil
	}
	if !protocolExists {
		return "", protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name))
	}

	// Identify runs without m.mu held so that it never blocks writers.
	version, err := identifier.Identify(data, ctx)
	if err != nil {
		return "", err
	}
	if version == "" {
		return "", nil
	}

	// Re-acquire the lock for the existence check. Register writes into the
	// same inner version map under the write lock, so reading a map reference
	// captured earlier without the lock would be a data race.
	m.mu.RLock()
	_, exists := m.protocols[name][version]
	m.mu.RUnlock()
	if !exists {
		return "", protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version))
	}
	return version, nil
}