|
|
package connection
|
|
|
|
|
|
import (
|
|
|
"hash/fnv"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/pkg/errors"
|
|
|
)
|
|
|
|
|
|
// ShardedManager is a connection manager that partitions connections across
// several independently locked shards in order to reduce lock contention.
type ShardedManager struct {
	// shards holds the per-shard state. A connection always lives in the
	// shard selected by hashing its ID (see getShard).
	shards []*shard
	// shardCount is len(shards), cached for the modulo used by getShard.
	shardCount int
	// maxConns is the upper bound on the total number of connections;
	// Add only enforces it when the value is > 0.
	maxConns int
	// totalConns is the manager-wide connection count. It is an atomic so
	// Count can read it without locking any shard.
	totalConns atomic.Int64
	// strategies is the registry of named group strategies. It is shared by
	// all shards rather than partitioned.
	strategies map[string]GroupStrategy
	// strategyMu guards strategies.
	strategyMu sync.RWMutex
}
|
|
|
|
|
|
// shard is one partition of the manager's state. Every map below is read and
// written only while mu is held (read lock for lookups, write lock for
// mutations).
type shard struct {
	// connections maps connection ID -> connection for the IDs hashed to
	// this shard.
	connections map[string]ConnectionInterface
	// groups maps group ID -> (connection ID -> connection), restricted to
	// this shard's connections; a single group can therefore be spread over
	// several shards.
	groups map[string]map[string]ConnectionInterface
	// indexes is the index system: indexKey -> indexValue -> connID. An entry
	// is stored in the shard of the connection it points to, so looking up a
	// key/value pair requires scanning every shard.
	indexes map[string]map[string]string
	// reverseIndexes: connID -> indexKey -> indexValue, kept so that a
	// connection's index entries can be removed without scanning indexes.
	reverseIndexes map[string]map[string]string
	// mu guards all of the maps above.
	mu sync.RWMutex
}
|
|
|
|
|
|
// NewShardedManager 创建分段连接管理器
|
|
|
func NewShardedManager(maxConns int, shardCount int) *ShardedManager {
|
|
|
if shardCount <= 0 {
|
|
|
shardCount = 16 // 默认16个分片
|
|
|
}
|
|
|
if shardCount > 256 {
|
|
|
shardCount = 256 // 最大256个分片
|
|
|
}
|
|
|
|
|
|
shards := make([]*shard, shardCount)
|
|
|
for i := range shards {
|
|
|
shards[i] = &shard{
|
|
|
connections: make(map[string]ConnectionInterface),
|
|
|
groups: make(map[string]map[string]ConnectionInterface),
|
|
|
indexes: make(map[string]map[string]string),
|
|
|
reverseIndexes: make(map[string]map[string]string),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return &ShardedManager{
|
|
|
shards: shards,
|
|
|
shardCount: shardCount,
|
|
|
maxConns: maxConns,
|
|
|
strategies: make(map[string]GroupStrategy),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// getShard 获取连接对应的分片
|
|
|
func (m *ShardedManager) getShard(connID string) *shard {
|
|
|
h := fnv.New32a()
|
|
|
h.Write([]byte(connID))
|
|
|
idx := h.Sum32() % uint32(m.shardCount)
|
|
|
return m.shards[idx]
|
|
|
}
|
|
|
|
|
|
// Add 添加连接
|
|
|
func (m *ShardedManager) Add(conn ConnectionInterface) error {
|
|
|
if conn == nil {
|
|
|
return errors.New("connection cannot be nil")
|
|
|
}
|
|
|
|
|
|
shard := m.getShard(conn.ID())
|
|
|
shard.mu.Lock()
|
|
|
defer shard.mu.Unlock()
|
|
|
|
|
|
// 检查是否已存在(避免重复添加)
|
|
|
if _, exists := shard.connections[conn.ID()]; exists {
|
|
|
return errors.New("connection already exists")
|
|
|
}
|
|
|
|
|
|
if m.maxConns > 0 {
|
|
|
for {
|
|
|
current := m.totalConns.Load()
|
|
|
if int(current) >= m.maxConns {
|
|
|
return errors.New("max connections exceeded")
|
|
|
}
|
|
|
if m.totalConns.CompareAndSwap(current, current+1) {
|
|
|
break
|
|
|
}
|
|
|
}
|
|
|
} else {
|
|
|
m.totalConns.Add(1)
|
|
|
}
|
|
|
|
|
|
shard.connections[conn.ID()] = conn
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Remove 移除连接
|
|
|
func (m *ShardedManager) Remove(connID string) error {
|
|
|
shard := m.getShard(connID)
|
|
|
shard.mu.Lock()
|
|
|
defer shard.mu.Unlock()
|
|
|
|
|
|
conn, ok := shard.connections[connID]
|
|
|
if !ok {
|
|
|
return errors.ErrConnectionNotFound
|
|
|
}
|
|
|
|
|
|
delete(shard.connections, connID)
|
|
|
m.totalConns.Add(-1)
|
|
|
|
|
|
// 从所有分组中移除
|
|
|
for groupID := range shard.groups {
|
|
|
delete(shard.groups[groupID], connID)
|
|
|
if len(shard.groups[groupID]) == 0 {
|
|
|
delete(shard.groups, groupID)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 从所有索引中移除
|
|
|
if indexKeys, exists := shard.reverseIndexes[connID]; exists {
|
|
|
for indexKey, indexValue := range indexKeys {
|
|
|
if indexMap, exists := shard.indexes[indexKey]; exists {
|
|
|
delete(indexMap, indexValue)
|
|
|
if len(indexMap) == 0 {
|
|
|
delete(shard.indexes, indexKey)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
delete(shard.reverseIndexes, connID)
|
|
|
}
|
|
|
|
|
|
// 关闭连接
|
|
|
return conn.Close()
|
|
|
}
|
|
|
|
|
|
// Get 获取连接
|
|
|
func (m *ShardedManager) Get(connID string) (ConnectionInterface, error) {
|
|
|
shard := m.getShard(connID)
|
|
|
shard.mu.RLock()
|
|
|
defer shard.mu.RUnlock()
|
|
|
|
|
|
conn, ok := shard.connections[connID]
|
|
|
if !ok {
|
|
|
return nil, errors.ErrConnectionNotFound
|
|
|
}
|
|
|
|
|
|
return conn, nil
|
|
|
}
|
|
|
|
|
|
// Count 获取连接数
|
|
|
func (m *ShardedManager) Count() int {
|
|
|
return int(m.totalConns.Load())
|
|
|
}
|
|
|
|
|
|
// GetAll 获取所有连接
|
|
|
func (m *ShardedManager) GetAll() []ConnectionInterface {
|
|
|
total := m.Count()
|
|
|
if total < 0 {
|
|
|
total = 0
|
|
|
}
|
|
|
allConns := make([]ConnectionInterface, 0, total)
|
|
|
|
|
|
for _, shard := range m.shards {
|
|
|
shard.mu.RLock()
|
|
|
for _, conn := range shard.connections {
|
|
|
allConns = append(allConns, conn)
|
|
|
}
|
|
|
shard.mu.RUnlock()
|
|
|
}
|
|
|
|
|
|
return allConns
|
|
|
}
|
|
|
|
|
|
// AddToGroup 将连接添加到分组
|
|
|
func (m *ShardedManager) AddToGroup(groupID string, connID string) error {
|
|
|
shard := m.getShard(connID)
|
|
|
shard.mu.Lock()
|
|
|
defer shard.mu.Unlock()
|
|
|
|
|
|
conn, ok := shard.connections[connID]
|
|
|
if !ok {
|
|
|
return errors.ErrConnectionNotFound
|
|
|
}
|
|
|
|
|
|
if shard.groups[groupID] == nil {
|
|
|
shard.groups[groupID] = make(map[string]ConnectionInterface)
|
|
|
}
|
|
|
|
|
|
shard.groups[groupID][connID] = conn
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// RemoveFromGroup 从分组中移除连接
|
|
|
func (m *ShardedManager) RemoveFromGroup(groupID string, connID string) error {
|
|
|
shard := m.getShard(connID)
|
|
|
shard.mu.Lock()
|
|
|
defer shard.mu.Unlock()
|
|
|
|
|
|
if shard.groups[groupID] == nil {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
delete(shard.groups[groupID], connID)
|
|
|
if len(shard.groups[groupID]) == 0 {
|
|
|
delete(shard.groups, groupID)
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// GetGroup 获取分组
|
|
|
func (m *ShardedManager) GetGroup(groupID string) []ConnectionInterface {
|
|
|
var conns []ConnectionInterface
|
|
|
|
|
|
// 需要在所有分片中查找(因为分组可能跨分片)
|
|
|
for _, shard := range m.shards {
|
|
|
shard.mu.RLock()
|
|
|
if shard.groups[groupID] != nil {
|
|
|
for _, conn := range shard.groups[groupID] {
|
|
|
conns = append(conns, conn)
|
|
|
}
|
|
|
}
|
|
|
shard.mu.RUnlock()
|
|
|
}
|
|
|
|
|
|
return conns
|
|
|
}
|
|
|
|
|
|
// BroadcastToGroup 向分组广播消息
|
|
|
func (m *ShardedManager) BroadcastToGroup(groupID string, data []byte) error {
|
|
|
if groupID == "" {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
var mu sync.Mutex
|
|
|
var errs []error
|
|
|
|
|
|
for _, shard := range m.shards {
|
|
|
shard.mu.RLock()
|
|
|
group := shard.groups[groupID]
|
|
|
if len(group) == 0 {
|
|
|
shard.mu.RUnlock()
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
conns := make([]ConnectionInterface, 0, len(group))
|
|
|
for _, conn := range group {
|
|
|
conns = append(conns, conn)
|
|
|
}
|
|
|
shard.mu.RUnlock()
|
|
|
|
|
|
if len(conns) == 0 {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
wg.Add(1)
|
|
|
go func(connections []ConnectionInterface) {
|
|
|
defer wg.Done()
|
|
|
var localErrs []error
|
|
|
for _, conn := range connections {
|
|
|
if err := conn.Write(data); err != nil {
|
|
|
localErrs = append(localErrs, err)
|
|
|
}
|
|
|
}
|
|
|
if len(localErrs) > 0 {
|
|
|
mu.Lock()
|
|
|
errs = append(errs, localErrs...)
|
|
|
mu.Unlock()
|
|
|
}
|
|
|
}(conns)
|
|
|
}
|
|
|
|
|
|
wg.Wait()
|
|
|
|
|
|
if len(errs) > 0 {
|
|
|
return errors.Newf("broadcast failed: %d errors", len(errs))
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// CleanupInactive 清理非活动连接
|
|
|
func (m *ShardedManager) CleanupInactive(timeout interface{}) {
|
|
|
timeoutDuration, ok := timeout.(time.Duration)
|
|
|
if !ok {
|
|
|
return
|
|
|
}
|
|
|
now := time.Now()
|
|
|
|
|
|
for _, shard := range m.shards {
|
|
|
shard.mu.Lock()
|
|
|
for connID, conn := range shard.connections {
|
|
|
// 检查连接是否有LastActive方法
|
|
|
if activeConn, ok := conn.(interface{ LastActive() time.Time }); ok {
|
|
|
if now.Sub(activeConn.LastActive()) > timeoutDuration {
|
|
|
delete(shard.connections, connID)
|
|
|
conn.Close()
|
|
|
m.totalConns.Add(-1)
|
|
|
|
|
|
// 从所有分组中移除
|
|
|
for groupID := range shard.groups {
|
|
|
delete(shard.groups[groupID], connID)
|
|
|
if len(shard.groups[groupID]) == 0 {
|
|
|
delete(shard.groups, groupID)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 从所有索引中移除
|
|
|
if indexKeys, exists := shard.reverseIndexes[connID]; exists {
|
|
|
for indexKey, indexValue := range indexKeys {
|
|
|
if indexMap, exists := shard.indexes[indexKey]; exists {
|
|
|
delete(indexMap, indexValue)
|
|
|
if len(indexMap) == 0 {
|
|
|
delete(shard.indexes, indexKey)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
delete(shard.reverseIndexes, connID)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
shard.mu.Unlock()
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// AddIndex 添加索引(通过业务数据查询连接)
|
|
|
func (m *ShardedManager) AddIndex(key, value string, connID string) error {
|
|
|
shard := m.getShard(connID)
|
|
|
shard.mu.Lock()
|
|
|
defer shard.mu.Unlock()
|
|
|
|
|
|
// 验证连接存在
|
|
|
if _, exists := shard.connections[connID]; !exists {
|
|
|
return errors.ErrConnectionNotFound
|
|
|
}
|
|
|
|
|
|
// 创建索引映射(如果不存在)
|
|
|
if shard.indexes[key] == nil {
|
|
|
shard.indexes[key] = make(map[string]string)
|
|
|
}
|
|
|
|
|
|
// 添加索引(如果值已存在,会覆盖)
|
|
|
shard.indexes[key][value] = connID
|
|
|
|
|
|
// 维护反向索引
|
|
|
if shard.reverseIndexes[connID] == nil {
|
|
|
shard.reverseIndexes[connID] = make(map[string]string)
|
|
|
}
|
|
|
shard.reverseIndexes[connID][key] = value
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// RemoveIndex 移除索引
|
|
|
func (m *ShardedManager) RemoveIndex(key, value string, connID string) error {
|
|
|
shard := m.getShard(connID)
|
|
|
shard.mu.Lock()
|
|
|
defer shard.mu.Unlock()
|
|
|
|
|
|
// 从索引中移除
|
|
|
if indexMap, exists := shard.indexes[key]; exists {
|
|
|
delete(indexMap, value)
|
|
|
if len(indexMap) == 0 {
|
|
|
delete(shard.indexes, key)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 从反向索引中移除
|
|
|
if reverseIndex, exists := shard.reverseIndexes[connID]; exists {
|
|
|
delete(reverseIndex, key)
|
|
|
if len(reverseIndex) == 0 {
|
|
|
delete(shard.reverseIndexes, connID)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// FindByIndex 通过索引查找连接(需要遍历所有分片)
|
|
|
func (m *ShardedManager) FindByIndex(key, value string) (ConnectionInterface, error) {
|
|
|
// 需要在所有分片中查找
|
|
|
for _, shard := range m.shards {
|
|
|
shard.mu.RLock()
|
|
|
indexMap, exists := shard.indexes[key]
|
|
|
if exists {
|
|
|
if connID, exists := indexMap[value]; exists {
|
|
|
if conn, exists := shard.connections[connID]; exists {
|
|
|
shard.mu.RUnlock()
|
|
|
return conn, nil
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
shard.mu.RUnlock()
|
|
|
}
|
|
|
|
|
|
return nil, errors.ErrConnectionNotFound
|
|
|
}
|
|
|
|
|
|
// FindByIndexKey 通过索引键查找所有连接(需要遍历所有分片)
|
|
|
func (m *ShardedManager) FindByIndexKey(key string) ([]ConnectionInterface, error) {
|
|
|
var conns []ConnectionInterface
|
|
|
seen := make(map[string]bool) // 去重
|
|
|
|
|
|
// 需要在所有分片中查找
|
|
|
for _, shard := range m.shards {
|
|
|
shard.mu.RLock()
|
|
|
indexMap, exists := shard.indexes[key]
|
|
|
if exists {
|
|
|
for _, connID := range indexMap {
|
|
|
if seen[connID] {
|
|
|
continue
|
|
|
}
|
|
|
if conn, exists := shard.connections[connID]; exists {
|
|
|
conns = append(conns, conn)
|
|
|
seen[connID] = true
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
shard.mu.RUnlock()
|
|
|
}
|
|
|
|
|
|
return conns, nil
|
|
|
}
|
|
|
|
|
|
// AddToGroupByStrategy 使用分组策略将连接添加到分组
|
|
|
func (m *ShardedManager) AddToGroupByStrategy(groupID string, connID string, strategy GroupStrategy) error {
|
|
|
shard := m.getShard(connID)
|
|
|
shard.mu.Lock()
|
|
|
defer shard.mu.Unlock()
|
|
|
|
|
|
conn, ok := shard.connections[connID]
|
|
|
if !ok {
|
|
|
return errors.ErrConnectionNotFound
|
|
|
}
|
|
|
|
|
|
// 如果策略支持自动分组,使用策略获取分组ID
|
|
|
if strategy != nil {
|
|
|
if autoGroupID := strategy.GetGroupID(conn); autoGroupID != "" {
|
|
|
groupID = autoGroupID
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if shard.groups[groupID] == nil {
|
|
|
shard.groups[groupID] = make(map[string]ConnectionInterface)
|
|
|
}
|
|
|
|
|
|
shard.groups[groupID][connID] = conn
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// GetGroupByStrategy 使用分组策略获取分组(需要遍历所有分片)
|
|
|
func (m *ShardedManager) GetGroupByStrategy(groupID string, strategy GroupStrategy) []ConnectionInterface {
|
|
|
var conns []ConnectionInterface
|
|
|
|
|
|
// 如果使用策略,通过策略匹配所有连接(需要遍历所有分片)
|
|
|
if strategy != nil {
|
|
|
for _, shard := range m.shards {
|
|
|
shard.mu.RLock()
|
|
|
for _, conn := range shard.connections {
|
|
|
if strategy.Match(conn, groupID) {
|
|
|
conns = append(conns, conn)
|
|
|
}
|
|
|
}
|
|
|
shard.mu.RUnlock()
|
|
|
}
|
|
|
return conns
|
|
|
}
|
|
|
|
|
|
// 否则使用传统的字符串分组(需要遍历所有分片)
|
|
|
for _, shard := range m.shards {
|
|
|
shard.mu.RLock()
|
|
|
if shard.groups[groupID] != nil {
|
|
|
for _, conn := range shard.groups[groupID] {
|
|
|
conns = append(conns, conn)
|
|
|
}
|
|
|
}
|
|
|
shard.mu.RUnlock()
|
|
|
}
|
|
|
|
|
|
return conns
|
|
|
}
|
|
|
|
|
|
// RegisterGroupStrategy 注册分组策略
|
|
|
func (m *ShardedManager) RegisterGroupStrategy(name string, strategy GroupStrategy) {
|
|
|
m.strategyMu.Lock()
|
|
|
defer m.strategyMu.Unlock()
|
|
|
|
|
|
if m.strategies == nil {
|
|
|
m.strategies = make(map[string]GroupStrategy)
|
|
|
}
|
|
|
m.strategies[name] = strategy
|
|
|
}
|
|
|
|
|
|
// GetGroupStrategy 获取已注册的分组策略
|
|
|
func (m *ShardedManager) GetGroupStrategy(name string) GroupStrategy {
|
|
|
m.strategyMu.RLock()
|
|
|
defer m.strategyMu.RUnlock()
|
|
|
|
|
|
return m.strategies[name]
|
|
|
}
|