package protocol import ( "context" "fmt" "sync" protocolpkg "github.com/noahlann/nnet/pkg/protocol" ) // manager 协议管理器实现 type manager struct { protocols map[string]map[string]protocolpkg.Protocol defaultProto protocolpkg.Protocol versionIdentifiers map[string]protocolpkg.VersionIdentifier mu sync.RWMutex } // NewManager 创建协议管理器 func NewManager() protocolpkg.Manager { return &manager{ protocols: make(map[string]map[string]protocolpkg.Protocol), versionIdentifiers: make(map[string]protocolpkg.VersionIdentifier), } } // Register 注册协议 func (m *manager) Register(protocol protocolpkg.Protocol) error { m.mu.Lock() defer m.mu.Unlock() if protocol == nil { return protocolpkg.NewError("protocol cannot be nil") } name := protocol.Name() version := protocol.Version() if name == "" { return protocolpkg.NewError("protocol name cannot be empty") } if version == "" { version = "default" } if m.protocols[name] == nil { m.protocols[name] = make(map[string]protocolpkg.Protocol) } m.protocols[name][version] = protocol // 如果是第一个协议,设置为默认协议 if m.defaultProto == nil { m.defaultProto = protocol } return nil } // Get 获取协议 func (m *manager) Get(name, version string) (protocolpkg.Protocol, error) { m.mu.RLock() defer m.mu.RUnlock() if name == "" { return m.defaultProto, nil } versions, ok := m.protocols[name] if !ok { return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name)) } if version == "" { version = "default" if protocol, ok := versions[version]; ok { return protocol, nil } // 如果未显式标记default版本,返回第一个可用版本 for _, protocol := range versions { return protocol, nil } return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s has no registered versions", name)) } protocol, ok := versions[version] if !ok { return nil, protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version)) } return protocol, nil } // GetDefault 获取默认协议 func (m *manager) GetDefault() protocolpkg.Protocol { m.mu.RLock() defer m.mu.RUnlock() return m.defaultProto } // SetDefault 设置默认协议 func (m *manager) SetDefault(name, version string) error { m.mu.Lock() defer m.mu.Unlock() // 直接在这里获取协议,避免调用Get方法导致死锁(Get也需要锁) if name == "" { // 如果名称为空,使用第一个协议作为默认协议 if len(m.protocols) == 0 { return protocolpkg.NewError("no protocols registered") } // 获取第一个协议 for _, versions := range m.protocols { for _, protocol := range versions { m.defaultProto = protocol return nil } } } versions, ok := m.protocols[name] if !ok { return protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name)) } if version == "" { version = "default" } protocol, ok := versions[version] if !ok { return protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version)) } m.defaultProto = protocol return nil } // List 列出所有协议 func (m *manager) List() []protocolpkg.Protocol { m.mu.RLock() defer m.mu.RUnlock() var protocols []protocolpkg.Protocol for _, versions := range m.protocols { for _, protocol := range versions { protocols = append(protocols, protocol) } } return protocols } // RegisterVersionIdentifier 注册版本识别器 func (m *manager) RegisterVersionIdentifier(name string, identifier protocolpkg.VersionIdentifier) error { m.mu.Lock() defer m.mu.Unlock() if name == "" { return protocolpkg.NewError("protocol name cannot be empty") } if identifier == nil { return protocolpkg.NewError("version identifier cannot be nil") } m.versionIdentifiers[name] = identifier return nil } // GetVersionIdentifier 获取版本识别器 func (m *manager) GetVersionIdentifier(name string) protocolpkg.VersionIdentifier { m.mu.RLock() defer m.mu.RUnlock() return m.versionIdentifiers[name] } // IdentifyVersion 识别协议版本 func (m *manager) IdentifyVersion(name string, data []byte, ctx context.Context) (string, error) { // 获取版本识别器和协议版本映射(一次性获取,避免多次加锁) m.mu.RLock() identifier := m.versionIdentifiers[name] versions, protocolExists := m.protocols[name] m.mu.RUnlock() if identifier == nil { // 如果没有版本识别器,返回空字符串(表示无法识别,使用默认版本) return "", nil } if !protocolExists { // 协议不存在,返回错误 return "", protocolpkg.NewError(fmt.Sprintf("protocol %s not found", name)) } // 使用版本识别器识别版本 version, err := identifier.Identify(data, ctx) if err != nil { // 识别失败,返回错误 return "", err } if version != "" { // 检查版本是否存在(不需要再次加锁,因为我们已经有了versions的引用) _, exists := versions[version] if !exists { // 版本不存在,返回错误 return "", protocolpkg.NewError(fmt.Sprintf("protocol %s version %s not found", name, version)) } } return version, nil }