package connection

import (
	"hash/fnv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/noahlann/nnet/pkg/errors"
)

// ShardedManager is a connection manager that spreads connections over a
// fixed number of independently locked shards to reduce lock contention.
//
// A connection, its group memberships and its index entries all live in the
// shard selected by hashing the connection ID. Operations keyed by a
// connection ID touch one shard; group and index lookups scan every shard.
type ShardedManager struct {
	shards     []*shard
	shardCount int
	maxConns   int // <= 0 means unlimited
	totalConns atomic.Int64

	strategies map[string]GroupStrategy // shared by all shards, guarded by strategyMu
	strategyMu sync.RWMutex
}

// shard is one independently locked partition of the manager's state.
//
// Index entries are stored in the shard of the connection that owns them, so
// key/value uniqueness is enforced only within a shard: connections in
// different shards may register the same pair, and FindByIndex then returns
// whichever shard is scanned first.
type shard struct {
	connections    map[string]ConnectionInterface            // connID -> conn
	groups         map[string]map[string]ConnectionInterface // groupID -> connID -> conn
	indexes        map[string]map[string]string              // key -> value -> owning connID
	reverseIndexes map[string]map[string]string              // connID -> key -> value, for cleanup without scanning

	mu sync.RWMutex
}

// NewShardedManager creates a ShardedManager that holds at most maxConns
// connections (unlimited when maxConns <= 0) spread over shardCount shards.
// shardCount defaults to 16 when it is <= 0 and is capped at 256.
func NewShardedManager(maxConns int, shardCount int) *ShardedManager {
	if shardCount <= 0 {
		shardCount = 16
	}
	if shardCount > 256 {
		shardCount = 256
	}
	shards := make([]*shard, shardCount)
	for i := range shards {
		shards[i] = &shard{
			connections:    make(map[string]ConnectionInterface),
			groups:         make(map[string]map[string]ConnectionInterface),
			indexes:        make(map[string]map[string]string),
			reverseIndexes: make(map[string]map[string]string),
		}
	}
	return &ShardedManager{
		shards:     shards,
		shardCount: shardCount,
		maxConns:   maxConns,
		strategies: make(map[string]GroupStrategy),
	}
}

// getShard returns the shard responsible for connID: the FNV-1a hash of the
// ID modulo the shard count.
func (m *ShardedManager) getShard(connID string) *shard {
	h := fnv.New32a()
	_, _ = h.Write([]byte(connID)) // hash.Hash.Write never returns an error
	return m.shards[h.Sum32()%uint32(m.shardCount)]
}

// reserveSlot increments the connection counter and reports true, or reports
// false without incrementing when the manager is already at maxConns. The CAS
// loop keeps the limit check and the increment atomic even though shards are
// locked independently.
func (m *ShardedManager) reserveSlot() bool {
	if m.maxConns <= 0 {
		m.totalConns.Add(1)
		return true
	}
	limit := int64(m.maxConns)
	for {
		cur := m.totalConns.Load()
		if cur >= limit {
			return false
		}
		if m.totalConns.CompareAndSwap(cur, cur+1) {
			return true
		}
	}
}

// Add registers conn under conn.ID().
//
// It fails if conn is nil, if a connection with the same ID is already
// registered, or if the manager is at its maxConns limit.
func (m *ShardedManager) Add(conn ConnectionInterface) error {
	if conn == nil {
		return errors.New("connection cannot be nil")
	}
	id := conn.ID()
	s := m.getShard(id)
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, exists := s.connections[id]; exists {
		return errors.New("connection already exists")
	}
	if !m.reserveSlot() {
		return errors.New("max connections exceeded")
	}
	s.connections[id] = conn
	return nil
}

// detachLocked removes connID from the connection table, from every group and
// from every index entry it owns in s. The caller must hold s.mu for writing
// and is responsible for adjusting the manager's connection counter.
func (s *shard) detachLocked(connID string) {
	delete(s.connections, connID)
	for groupID, members := range s.groups {
		delete(members, connID)
		if len(members) == 0 {
			delete(s.groups, groupID)
		}
	}
	for key, value := range s.reverseIndexes[connID] {
		s.dropForwardLocked(key, value, connID)
	}
	delete(s.reverseIndexes, connID)
}

// Remove unregisters connID, drops it from every group and index, and then
// closes it, returning the error from Close. It returns
// errors.ErrConnectionNotFound if connID is not registered.
//
// Close runs after the shard lock is released, so a Close implementation that
// calls back into the manager (for example Remove on itself) cannot deadlock.
func (m *ShardedManager) Remove(connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	conn, ok := s.connections[connID]
	if ok {
		s.detachLocked(connID)
		m.totalConns.Add(-1)
	}
	s.mu.Unlock()

	if !ok {
		return errors.ErrConnectionNotFound
	}
	return conn.Close()
}

// Get returns the connection registered under connID, or
// errors.ErrConnectionNotFound.
func (m *ShardedManager) Get(connID string) (ConnectionInterface, error) {
	s := m.getShard(connID)
	s.mu.RLock()
	defer s.mu.RUnlock()

	if conn, ok := s.connections[connID]; ok {
		return conn, nil
	}
	return nil, errors.ErrConnectionNotFound
}

// Count returns the number of registered connections. It reads an atomic
// counter and takes no shard locks.
func (m *ShardedManager) Count() int {
	return int(m.totalConns.Load())
}

// GetAll returns a snapshot of all registered connections. Shards are read one
// at a time, so the result is not an atomic view across shards.
func (m *ShardedManager) GetAll() []ConnectionInterface {
	capacity := m.Count()
	if capacity < 0 {
		capacity = 0
	}
	all := make([]ConnectionInterface, 0, capacity)
	for _, s := range m.shards {
		s.mu.RLock()
		for _, conn := range s.connections {
			all = append(all, conn)
		}
		s.mu.RUnlock()
	}
	return all
}

// joinGroupLocked adds conn (registered as connID) to groupID, creating the
// group if needed. The caller must hold s.mu for writing.
func (s *shard) joinGroupLocked(groupID, connID string, conn ConnectionInterface) {
	members := s.groups[groupID]
	if members == nil {
		members = make(map[string]ConnectionInterface)
		s.groups[groupID] = members
	}
	members[connID] = conn
}

// AddToGroup adds the registered connection connID to groupID. It returns
// errors.ErrConnectionNotFound if connID is not registered.
func (m *ShardedManager) AddToGroup(groupID string, connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	conn, ok := s.connections[connID]
	if !ok {
		return errors.ErrConnectionNotFound
	}
	s.joinGroupLocked(groupID, connID, conn)
	return nil
}

// RemoveFromGroup removes connID from groupID, deleting the group from the
// shard once it is empty. Removing a non-member is a no-op; it always returns
// nil.
func (m *ShardedManager) RemoveFromGroup(groupID string, connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	members, ok := s.groups[groupID]
	if !ok {
		return nil
	}
	delete(members, connID)
	if len(members) == 0 {
		delete(s.groups, groupID)
	}
	return nil
}

// GetGroup returns the members of groupID. A group's members may live in any
// shard, so every shard is scanned. Returns nil if the group has no members.
func (m *ShardedManager) GetGroup(groupID string) []ConnectionInterface {
	var conns []ConnectionInterface
	for _, s := range m.shards {
		s.mu.RLock()
		for _, conn := range s.groups[groupID] {
			conns = append(conns, conn)
		}
		s.mu.RUnlock()
	}
	return conns
}

// snapshotGroup returns a copy of the members of groupID held by s, or nil if
// there are none. It takes s.mu for reading.
func (s *shard) snapshotGroup(groupID string) []ConnectionInterface {
	s.mu.RLock()
	defer s.mu.RUnlock()

	members := s.groups[groupID]
	if len(members) == 0 {
		return nil
	}
	out := make([]ConnectionInterface, 0, len(members))
	for _, conn := range members {
		out = append(out, conn)
	}
	return out
}

// BroadcastToGroup writes data to every member of groupID. Each shard's
// members are snapshotted under its read lock and then written, without any
// lock held, from one goroutine per non-empty shard. Every member is
// attempted; if any write fails, an error reporting the failure count is
// returned. An empty groupID is a no-op.
func (m *ShardedManager) BroadcastToGroup(groupID string, data []byte) error {
	if groupID == "" {
		return nil
	}

	var wg sync.WaitGroup
	var failures atomic.Int64
	for _, s := range m.shards {
		targets := s.snapshotGroup(groupID)
		if len(targets) == 0 {
			continue
		}
		wg.Add(1)
		go func(targets []ConnectionInterface) {
			defer wg.Done()
			for _, conn := range targets {
				if err := conn.Write(data); err != nil {
					failures.Add(1)
				}
			}
		}(targets)
	}
	wg.Wait()

	if n := failures.Load(); n > 0 {
		return errors.Newf("broadcast failed: %d errors", n)
	}
	return nil
}

// CleanupInactive unregisters and closes every connection whose LastActive
// time is more than timeout before now. timeout must be a time.Duration;
// any other type makes the call a no-op. Connections that do not expose a
// LastActive() time.Time method are never evicted.
//
// Connections are closed after their shard lock is released, so Close may
// safely call back into the manager.
func (m *ShardedManager) CleanupInactive(timeout any) {
	maxIdle, ok := timeout.(time.Duration)
	if !ok {
		return
	}
	now := time.Now()
	for _, s := range m.shards {
		for _, conn := range m.evictInactive(s, now, maxIdle) {
			// Best effort: the connection is already unregistered and this
			// method has no way to report a Close failure.
			_ = conn.Close()
		}
	}
}

// evictInactive unregisters every connection in s that has been idle longer
// than maxIdle as of now, and returns the evicted connections so the caller
// can close them without holding s.mu.
func (m *ShardedManager) evictInactive(s *shard, now time.Time, maxIdle time.Duration) []ConnectionInterface {
	s.mu.Lock()
	defer s.mu.Unlock()

	var evicted []ConnectionInterface
	for connID, conn := range s.connections {
		ac, ok := conn.(interface{ LastActive() time.Time })
		if !ok || now.Sub(ac.LastActive()) <= maxIdle {
			continue
		}
		// Deleting map entries during range is permitted in Go.
		s.detachLocked(connID)
		m.totalConns.Add(-1)
		evicted = append(evicted, conn)
	}
	return evicted
}

// dropForwardLocked deletes indexes[key][value] if, and only if, it still
// points at connID, and prunes key when it becomes empty. The caller must hold
// s.mu for writing.
func (s *shard) dropForwardLocked(key, value, connID string) {
	byValue, ok := s.indexes[key]
	if !ok {
		return
	}
	if owner, ok := byValue[value]; !ok || owner != connID {
		return
	}
	delete(byValue, value)
	if len(byValue) == 0 {
		delete(s.indexes, key)
	}
}

// dropReverseLocked deletes reverseIndexes[connID][key] if, and only if, it
// still records value, and prunes connID when it becomes empty. The caller
// must hold s.mu for writing.
func (s *shard) dropReverseLocked(connID, key, value string) {
	byKey, ok := s.reverseIndexes[connID]
	if !ok {
		return
	}
	if cur, ok := byKey[key]; !ok || cur != value {
		return
	}
	delete(byKey, key)
	if len(byKey) == 0 {
		delete(s.reverseIndexes, connID)
	}
}

// AddIndex maps the business key/value pair to connID so the connection can be
// found later with FindByIndex or FindByIndexKey. It returns
// errors.ErrConnectionNotFound if connID is not registered.
//
// A connection holds at most one value per key: indexing connID again under
// key replaces its previous value. If another connection in the same shard
// already owns key/value, ownership moves to connID and the previous owner's
// record of it is dropped.
func (m *ShardedManager) AddIndex(key, value string, connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, exists := s.connections[connID]; !exists {
		return errors.ErrConnectionNotFound
	}

	// Drop connID's old value for key. The reverse index only remembers one
	// value per key, so that forward entry would otherwise be orphaned.
	if old, ok := s.reverseIndexes[connID][key]; ok && old != value {
		s.dropForwardLocked(key, old, connID)
	}
	// Detach any previous owner of key/value. Otherwise removing that owner
	// later would delete the entry that now belongs to connID.
	if prev, ok := s.indexes[key][value]; ok && prev != connID {
		s.dropReverseLocked(prev, key, value)
	}

	byValue := s.indexes[key]
	if byValue == nil {
		byValue = make(map[string]string)
		s.indexes[key] = byValue
	}
	byValue[value] = connID

	byKey := s.reverseIndexes[connID]
	if byKey == nil {
		byKey = make(map[string]string)
		s.reverseIndexes[connID] = byKey
	}
	byKey[key] = value
	return nil
}

// RemoveIndex deletes the key/value index entry owned by connID. An entry
// owned by a different connection, or a different value recorded for connID,
// is left untouched. It always returns nil.
func (m *ShardedManager) RemoveIndex(key, value string, connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	s.dropForwardLocked(key, value, connID)
	s.dropReverseLocked(connID, key, value)
	return nil
}

// lookupIndex returns the live connection indexed under key/value in s, if
// any. It takes s.mu for reading.
func (s *shard) lookupIndex(key, value string) (ConnectionInterface, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	connID, ok := s.indexes[key][value]
	if !ok {
		return nil, false
	}
	conn, ok := s.connections[connID]
	return conn, ok
}

// FindByIndex returns a connection indexed under key/value, scanning shards in
// order and returning the first match. It returns
// errors.ErrConnectionNotFound when no shard has a match.
func (m *ShardedManager) FindByIndex(key, value string) (ConnectionInterface, error) {
	for _, s := range m.shards {
		if conn, ok := s.lookupIndex(key, value); ok {
			return conn, nil
		}
	}
	return nil, errors.ErrConnectionNotFound
}

// FindByIndexKey returns every live connection that has a value indexed under
// key, scanning all shards. The error is always nil.
//
// A connection's index entries live only in its own shard, and it holds at
// most one value per key, so each connection appears at most once.
func (m *ShardedManager) FindByIndexKey(key string) ([]ConnectionInterface, error) {
	var conns []ConnectionInterface
	for _, s := range m.shards {
		s.mu.RLock()
		for _, connID := range s.indexes[key] {
			if conn, ok := s.connections[connID]; ok {
				conns = append(conns, conn)
			}
		}
		s.mu.RUnlock()
	}
	return conns, nil
}

// AddToGroupByStrategy adds connID to a group. When strategy is non-nil and
// strategy.GetGroupID returns a non-empty ID for the connection, that ID is
// used instead of groupID. It returns errors.ErrConnectionNotFound if connID
// is not registered.
func (m *ShardedManager) AddToGroupByStrategy(groupID string, connID string, strategy GroupStrategy) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	conn, ok := s.connections[connID]
	if !ok {
		return errors.ErrConnectionNotFound
	}
	if strategy != nil {
		if derived := strategy.GetGroupID(conn); derived != "" {
			groupID = derived
		}
	}
	s.joinGroupLocked(groupID, connID, conn)
	return nil
}

// GetGroupByStrategy returns the members of groupID. With a nil strategy it
// behaves like GetGroup. Otherwise it returns every registered connection for
// which strategy.Match(conn, groupID) is true, ignoring explicit group
// membership. strategy.Match is called while the shard's read lock is held.
func (m *ShardedManager) GetGroupByStrategy(groupID string, strategy GroupStrategy) []ConnectionInterface {
	if strategy == nil {
		return m.GetGroup(groupID)
	}
	var conns []ConnectionInterface
	for _, s := range m.shards {
		s.mu.RLock()
		for _, conn := range s.connections {
			if strategy.Match(conn, groupID) {
				conns = append(conns, conn)
			}
		}
		s.mu.RUnlock()
	}
	return conns
}

// RegisterGroupStrategy registers strategy under name, replacing any strategy
// previously registered under the same name.
func (m *ShardedManager) RegisterGroupStrategy(name string, strategy GroupStrategy) {
	m.strategyMu.Lock()
	defer m.strategyMu.Unlock()

	if m.strategies == nil {
		m.strategies = make(map[string]GroupStrategy)
	}
	m.strategies[name] = strategy
}

// GetGroupStrategy returns the strategy registered under name, or nil if none
// is registered.
func (m *ShardedManager) GetGroupStrategy(name string) GroupStrategy {
	m.strategyMu.RLock()
	defer m.strategyMu.RUnlock()

	return m.strategies[name]
}