You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

508 lines
12 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package connection
import (
"hash/fnv"
"sync"
"sync/atomic"
"time"
"github.com/noahlann/nnet/pkg/errors"
)
// ShardedManager is a connection manager that spreads connections over
// several independently locked shards, chosen by hashing the connection ID,
// in order to reduce lock contention.
type ShardedManager struct {
	// shards holds the per-shard state; shardCount is len(shards).
	shards     []*shard
	shardCount int

	// maxConns caps the total number of connections (values <= 0 disable the
	// cap, see Add); totalConns is the live count shared by every shard.
	maxConns   int
	totalConns atomic.Int64

	// strategies is the registry of named group strategies. It is shared by
	// all shards rather than sharded, and is guarded by strategyMu.
	strategies map[string]GroupStrategy
	strategyMu sync.RWMutex
}
// shard is one partition of a ShardedManager. Every map below is read under
// mu.RLock and mutated under mu.Lock.
type shard struct {
	// connections maps connID -> connection.
	connections map[string]ConnectionInterface
	// groups maps groupID -> (connID -> connection) for the members stored in
	// this shard.
	groups map[string]map[string]ConnectionInterface
	// indexes is the forward index: indexKey -> indexValue -> connID.
	indexes map[string]map[string]string
	// reverseIndexes is the reverse index: connID -> indexKey -> indexValue,
	// used to drop all index entries of a connection quickly.
	reverseIndexes map[string]map[string]string

	mu sync.RWMutex
}
// NewShardedManager builds a ShardedManager that admits at most maxConns
// connections (a value <= 0 disables the limit, as enforced by Add) spread
// over shardCount shards.
//
// shardCount is clamped to the range [1, 256]; a non-positive value selects
// the default of 16 shards.
func NewShardedManager(maxConns int, shardCount int) *ShardedManager {
	switch {
	case shardCount <= 0:
		shardCount = 16 // default shard count
	case shardCount > 256:
		shardCount = 256 // upper bound on shard count
	}

	m := &ShardedManager{
		shards:     make([]*shard, 0, shardCount),
		shardCount: shardCount,
		maxConns:   maxConns,
		strategies: make(map[string]GroupStrategy),
	}
	for len(m.shards) < shardCount {
		m.shards = append(m.shards, &shard{
			connections:    make(map[string]ConnectionInterface),
			groups:         make(map[string]map[string]ConnectionInterface),
			indexes:        make(map[string]map[string]string),
			reverseIndexes: make(map[string]map[string]string),
		})
	}
	return m
}
// getShard returns the shard that owns connID. The choice is the 32-bit
// FNV-1a hash of the ID modulo shardCount, so a given ID always maps to the
// same shard.
func (m *ShardedManager) getShard(connID string) *shard {
	hasher := fnv.New32a()
	_, _ = hasher.Write([]byte(connID)) // hash.Hash.Write never returns an error
	return m.shards[hasher.Sum32()%uint32(m.shardCount)]
}
// Add registers conn with the manager.
//
// It fails if conn is nil, if a connection with the same ID is already
// registered, or if maxConns > 0 and the global limit has been reached. With a
// limit configured, the global counter is reserved through a compare-and-swap
// loop, so concurrent Adds landing on different shards cannot push the total
// past maxConns.
func (m *ShardedManager) Add(conn ConnectionInterface) error {
	if conn == nil {
		return errors.New("connection cannot be nil")
	}
	id := conn.ID()
	s := m.getShard(id)
	s.mu.Lock()
	defer s.mu.Unlock()

	// Reject duplicates before touching the global counter.
	if _, dup := s.connections[id]; dup {
		return errors.New("connection already exists")
	}

	if m.maxConns <= 0 {
		// No limit configured: just count the connection.
		m.totalConns.Add(1)
	} else {
		for reserved := false; !reserved; {
			cur := m.totalConns.Load()
			if int(cur) >= m.maxConns {
				return errors.New("max connections exceeded")
			}
			reserved = m.totalConns.CompareAndSwap(cur, cur+1)
		}
	}

	s.connections[id] = conn
	return nil
}
// Remove unregisters the connection identified by connID and closes it.
//
// The connection is deleted from its shard, along with every group membership
// and index entry it owns, and the global connection count is decremented.
// Remove returns errors.ErrConnectionNotFound when connID is not registered.
// Otherwise it returns the result of conn.Close().
//
// Close is deliberately called only after the shard lock has been released.
// A slow Close (for example one that flushes to the network) would otherwise
// stall every other operation on the shard. A Close implementation that calls
// back into this manager would also deadlock on the non-reentrant RWMutex.
func (m *ShardedManager) Remove(connID string) error {
	s := m.getShard(connID)

	conn, ok := func() (ConnectionInterface, bool) {
		s.mu.Lock()
		defer s.mu.Unlock()

		c, found := s.connections[connID]
		if !found {
			return nil, false
		}
		delete(s.connections, connID)
		m.totalConns.Add(-1)

		// Leave every group in this shard, pruning groups that become empty.
		for groupID, members := range s.groups {
			delete(members, connID)
			if len(members) == 0 {
				delete(s.groups, groupID)
			}
		}

		// Drop every index entry recorded for this connection. A forward
		// entry is only deleted while it still points at connID, so an entry
		// that has since been re-assigned to another connection survives.
		for indexKey, indexValue := range s.reverseIndexes[connID] {
			byValue, exists := s.indexes[indexKey]
			if !exists || byValue[indexValue] != connID {
				continue
			}
			delete(byValue, indexValue)
			if len(byValue) == 0 {
				delete(s.indexes, indexKey)
			}
		}
		delete(s.reverseIndexes, connID)

		return c, true
	}()
	if !ok {
		return errors.ErrConnectionNotFound
	}

	// The connection is no longer reachable through the manager, so it is
	// safe to close it without holding the shard lock.
	return conn.Close()
}
// Get returns the connection registered under connID. If there is none, it
// returns errors.ErrConnectionNotFound.
func (m *ShardedManager) Get(connID string) (ConnectionInterface, error) {
	s := m.getShard(connID)

	s.mu.RLock()
	conn, found := s.connections[connID]
	s.mu.RUnlock()

	if found {
		return conn, nil
	}
	return nil, errors.ErrConnectionNotFound
}
// Count reports the current value of the global connection counter. Add,
// Remove and CleanupInactive maintain this counter.
func (m *ShardedManager) Count() int {
	n := m.totalConns.Load()
	return int(n)
}
// GetAll returns a snapshot of every registered connection across all shards.
//
// Each shard is read-locked only while it is being copied, so the result is
// not an atomic view of the whole manager. Order is unspecified. The returned
// slice is never nil.
func (m *ShardedManager) GetAll() []ConnectionInterface {
	capHint := m.Count()
	if capHint < 0 {
		capHint = 0
	}
	result := make([]ConnectionInterface, 0, capHint)

	for i := range m.shards {
		s := m.shards[i]
		s.mu.RLock()
		for _, c := range s.connections {
			result = append(result, c)
		}
		s.mu.RUnlock()
	}
	return result
}
// AddToGroup makes connection connID a member of group groupID.
//
// Membership is recorded in the shard that owns connID. Re-adding an existing
// member simply overwrites its entry. Returns errors.ErrConnectionNotFound if
// connID is not registered.
func (m *ShardedManager) AddToGroup(groupID string, connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	conn, ok := s.connections[connID]
	if !ok {
		return errors.ErrConnectionNotFound
	}

	members := s.groups[groupID]
	if members == nil {
		members = make(map[string]ConnectionInterface)
		s.groups[groupID] = members
	}
	members[connID] = conn
	return nil
}
// RemoveFromGroup removes connection connID from group groupID in the shard
// that owns connID. A group left without members in that shard is deleted.
// Removing from an unknown group, or removing a non-member, is a no-op. The
// returned error is always nil.
func (m *ShardedManager) RemoveFromGroup(groupID string, connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	if members := s.groups[groupID]; members != nil {
		delete(members, connID)
		if len(members) == 0 {
			delete(s.groups, groupID)
		}
	}
	return nil
}
// GetGroup returns every member of groupID. Members are spread over the shards
// that own them, so all shards are scanned, each under its own read lock. It
// returns nil when the group has no members.
func (m *ShardedManager) GetGroup(groupID string) []ConnectionInterface {
	var members []ConnectionInterface
	for _, s := range m.shards {
		s.mu.RLock()
		// Ranging over a missing (nil) group map yields nothing.
		for _, c := range s.groups[groupID] {
			members = append(members, c)
		}
		s.mu.RUnlock()
	}
	return members
}
// BroadcastToGroup writes data to every member of groupID across all shards.
//
// Each shard's member list is copied under that shard's read lock. The copy is
// then written to by one goroutine per non-empty shard, so no shard lock is
// held during I/O. The call returns once all writes have finished. An empty
// groupID is a no-op.
//
// If any Write fails, the returned error reports how many writes failed and
// includes one of the underlying errors. Which one is unspecified, because
// shards are written concurrently.
func (m *ShardedManager) BroadcastToGroup(groupID string, data []byte) error {
	if groupID == "" {
		return nil
	}

	var (
		wg    sync.WaitGroup
		errMu sync.Mutex
		errs  []error
	)
	for _, s := range m.shards {
		s.mu.RLock()
		group := s.groups[groupID]
		conns := make([]ConnectionInterface, 0, len(group))
		for _, conn := range group {
			conns = append(conns, conn)
		}
		s.mu.RUnlock()

		if len(conns) == 0 {
			continue
		}

		wg.Add(1)
		go func(targets []ConnectionInterface) {
			defer wg.Done()
			var failed []error
			for _, conn := range targets {
				if err := conn.Write(data); err != nil {
					failed = append(failed, err)
				}
			}
			if len(failed) > 0 {
				errMu.Lock()
				errs = append(errs, failed...)
				errMu.Unlock()
			}
		}(conns)
	}
	wg.Wait()

	if len(errs) > 0 {
		// Surface at least one concrete cause instead of only a count.
		return errors.Newf("broadcast failed: %d errors, first: %v", len(errs), errs[0])
	}
	return nil
}
// CleanupInactive unregisters and closes every connection whose last activity
// is older than timeout.
//
// timeout must hold a time.Duration; any other dynamic type makes the call a
// no-op. (NOTE(review): the interface{} parameter presumably exists to satisfy
// a shared manager interface; confirm.) Only connections that expose a
// `LastActive() time.Time` method are considered. Others are never evicted.
//
// Shards are processed one at a time. Within a shard, idle connections are
// detached (connection map, groups, indexes, global counter) while the shard
// lock is held. They are closed only after that lock is released, so a slow or
// re-entrant Close cannot stall or deadlock the shard. Close errors are
// ignored because cleanup is best-effort.
func (m *ShardedManager) CleanupInactive(timeout interface{}) {
	timeoutDuration, ok := timeout.(time.Duration)
	if !ok {
		return
	}
	now := time.Now()
	for _, s := range m.shards {
		for _, conn := range m.evictInactiveFromShard(s, now, timeoutDuration) {
			_ = conn.Close() // best-effort: nothing useful to do with the error here
		}
	}
}

// evictInactiveFromShard detaches from s every connection whose LastActive
// time is more than timeout before now. It returns the detached connections
// so the caller can close them outside the shard lock.
func (m *ShardedManager) evictInactiveFromShard(s *shard, now time.Time, timeout time.Duration) []ConnectionInterface {
	s.mu.Lock()
	defer s.mu.Unlock()

	var evicted []ConnectionInterface
	for connID, conn := range s.connections {
		activeConn, ok := conn.(interface{ LastActive() time.Time })
		if !ok || now.Sub(activeConn.LastActive()) <= timeout {
			continue
		}

		// Deleting from a map while ranging over it is permitted in Go.
		delete(s.connections, connID)
		m.totalConns.Add(-1)

		// Leave every group in this shard, pruning groups that become empty.
		for groupID, members := range s.groups {
			delete(members, connID)
			if len(members) == 0 {
				delete(s.groups, groupID)
			}
		}

		// Drop the connection's index entries. A forward entry is only
		// deleted while it still points at connID.
		for indexKey, indexValue := range s.reverseIndexes[connID] {
			byValue, exists := s.indexes[indexKey]
			if !exists || byValue[indexValue] != connID {
				continue
			}
			delete(byValue, indexValue)
			if len(byValue) == 0 {
				delete(s.indexes, indexKey)
			}
		}
		delete(s.reverseIndexes, connID)

		evicted = append(evicted, conn)
	}
	return evicted
}
// AddIndex associates the business pair key=value with connection connID, so
// the connection can later be looked up by FindByIndex / FindByIndexKey.
//
// The pair is stored in the shard that owns connID, in both the forward index
// (key -> value -> connID) and the reverse index (connID -> key -> value).
// A connection holds at most one value per key: re-indexing connID under key
// with a new value drops the forward entry for its previous value.
// If value is already indexed in this shard to a different connection, the
// new call wins, and the previous owner's reverse entry is cleared. This keeps
// a later Remove of the previous owner from deleting connID's entry.
//
// Returns errors.ErrConnectionNotFound if connID is not registered.
func (m *ShardedManager) AddIndex(key, value string, connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, exists := s.connections[connID]; !exists {
		return errors.ErrConnectionNotFound
	}

	byValue := s.indexes[key]
	if byValue == nil {
		byValue = make(map[string]string)
		s.indexes[key] = byValue
	}

	// Forget connID's previous value for this key. Otherwise the old forward
	// entry would keep pointing at connID after the reverse index moved on.
	if rev := s.reverseIndexes[connID]; rev != nil {
		if old, had := rev[key]; had && old != value && byValue[old] == connID {
			delete(byValue, old)
		}
	}

	// If another connection owned this value, detach it from that owner's
	// reverse index before taking the value over.
	if prev, taken := byValue[value]; taken && prev != connID {
		if prevRev := s.reverseIndexes[prev]; prevRev != nil && prevRev[key] == value {
			delete(prevRev, key)
			if len(prevRev) == 0 {
				delete(s.reverseIndexes, prev)
			}
		}
	}

	byValue[value] = connID

	rev := s.reverseIndexes[connID]
	if rev == nil {
		rev = make(map[string]string)
		s.reverseIndexes[connID] = rev
	}
	rev[key] = value
	return nil
}
// RemoveIndex removes the pair key=value from connection connID's index
// entries in the shard that owns connID.
//
// The forward entry (key -> value) is only removed while it still points at
// connID, because a later AddIndex may have re-assigned value to a different
// connection. Likewise, the reverse entry is only removed while it still
// records this exact value. Maps left empty are pruned. The returned error is
// always nil; removing a pair that is not indexed is a no-op.
func (m *ShardedManager) RemoveIndex(key, value string, connID string) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	if byValue, ok := s.indexes[key]; ok {
		if owner, indexed := byValue[value]; indexed && owner == connID {
			delete(byValue, value)
			if len(byValue) == 0 {
				delete(s.indexes, key)
			}
		}
	}

	if rev, ok := s.reverseIndexes[connID]; ok {
		if cur, has := rev[key]; has && cur == value {
			delete(rev, key)
			if len(rev) == 0 {
				delete(s.reverseIndexes, connID)
			}
		}
	}
	return nil
}
// FindByIndex returns a registered connection indexed under key=value.
//
// Index entries live in the shard of the connection they point to, so every
// shard is searched in turn, each under its own read lock. The first live
// match wins. Returns errors.ErrConnectionNotFound when no shard has a live
// match.
func (m *ShardedManager) FindByIndex(key, value string) (ConnectionInterface, error) {
	for _, s := range m.shards {
		var (
			conn ConnectionInterface
			live bool
		)
		s.mu.RLock()
		// Indexing a nil inner map is safe and reports "not present".
		if connID, indexed := s.indexes[key][value]; indexed {
			conn, live = s.connections[connID]
		}
		s.mu.RUnlock()

		if live {
			return conn, nil
		}
	}
	return nil, errors.ErrConnectionNotFound
}
// FindByIndexKey returns every registered connection that has any value
// indexed under key. Each connection appears at most once.
//
// Every shard is scanned under its own read lock. Entries whose connection is
// no longer registered are skipped. The result is nil when nothing matches,
// and the returned error is always nil.
func (m *ShardedManager) FindByIndexKey(key string) ([]ConnectionInterface, error) {
	var found []ConnectionInterface
	visited := make(map[string]struct{})

	for _, s := range m.shards {
		s.mu.RLock()
		for _, connID := range s.indexes[key] {
			if _, dup := visited[connID]; dup {
				continue
			}
			conn, ok := s.connections[connID]
			if !ok {
				continue
			}
			visited[connID] = struct{}{}
			found = append(found, conn)
		}
		s.mu.RUnlock()
	}
	return found, nil
}
// AddToGroupByStrategy adds connection connID to a group, optionally letting a
// strategy choose the group.
//
// When strategy is non-nil and its GetGroupID returns a non-empty ID for the
// connection, that ID replaces groupID. Otherwise groupID is used as given.
// Returns errors.ErrConnectionNotFound if connID is not registered.
func (m *ShardedManager) AddToGroupByStrategy(groupID string, connID string, strategy GroupStrategy) error {
	s := m.getShard(connID)
	s.mu.Lock()
	defer s.mu.Unlock()

	conn, ok := s.connections[connID]
	if !ok {
		return errors.ErrConnectionNotFound
	}

	target := groupID
	if strategy != nil {
		if derived := strategy.GetGroupID(conn); derived != "" {
			target = derived
		}
	}

	members := s.groups[target]
	if members == nil {
		members = make(map[string]ConnectionInterface)
		s.groups[target] = members
	}
	members[connID] = conn
	return nil
}
// GetGroupByStrategy returns the members of groupID.
//
// With a non-nil strategy, membership is decided dynamically: every
// registered connection for which strategy.Match(conn, groupID) is true is
// returned. With a nil strategy, the explicit group membership is returned
// (as in GetGroup). Either way, all shards are scanned, each under its own
// read lock. The result is nil when nothing matches.
func (m *ShardedManager) GetGroupByStrategy(groupID string, strategy GroupStrategy) []ConnectionInterface {
	var matched []ConnectionInterface
	for _, s := range m.shards {
		s.mu.RLock()
		if strategy == nil {
			for _, c := range s.groups[groupID] {
				matched = append(matched, c)
			}
		} else {
			for _, c := range s.connections {
				if strategy.Match(c, groupID) {
					matched = append(matched, c)
				}
			}
		}
		s.mu.RUnlock()
	}
	return matched
}
// RegisterGroupStrategy stores strategy under name, replacing any strategy
// previously registered with that name. The registry map is created lazily if
// it is nil.
func (m *ShardedManager) RegisterGroupStrategy(name string, strategy GroupStrategy) {
	m.strategyMu.Lock()
	if m.strategies == nil {
		m.strategies = map[string]GroupStrategy{}
	}
	m.strategies[name] = strategy
	m.strategyMu.Unlock()
}
// GetGroupStrategy returns the strategy registered under name, or nil if no
// strategy has that name.
func (m *ShardedManager) GetGroupStrategy(name string) GroupStrategy {
	m.strategyMu.RLock()
	strategy := m.strategies[name]
	m.strategyMu.RUnlock()
	return strategy
}