|
|
package storage
|
|
|
|
|
|
import (
	"fmt"
	"sync"
	"time"

	sessionimpl "github.com/noahlann/nnet/internal/session"
	sessionpkg "github.com/noahlann/nnet/pkg/session"
)
|
|
|
|
|
|
// MemoryStorage 内存存储
|
|
|
type MemoryStorage struct {
|
|
|
sessions map[string]*sessionimpl.SessionImpl
|
|
|
mu sync.RWMutex
|
|
|
expiration time.Duration
|
|
|
}
|
|
|
|
|
|
// NewMemoryStorage 创建内存存储
|
|
|
func NewMemoryStorage(expiration time.Duration) sessionpkg.Storage {
|
|
|
storage := &MemoryStorage{
|
|
|
sessions: make(map[string]*sessionimpl.SessionImpl),
|
|
|
expiration: expiration,
|
|
|
}
|
|
|
|
|
|
// 启动清理goroutine
|
|
|
go storage.cleanup()
|
|
|
|
|
|
return storage
|
|
|
}
|
|
|
|
|
|
// Get 获取Session
|
|
|
func (s *MemoryStorage) Get(sessionID string) (sessionpkg.Session, error) {
|
|
|
s.mu.RLock()
|
|
|
defer s.mu.RUnlock()
|
|
|
|
|
|
sess, ok := s.sessions[sessionID]
|
|
|
if !ok {
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
// 检查是否过期
|
|
|
if s.expiration > 0 && time.Since(sess.AccessedAt()) > s.expiration {
|
|
|
delete(s.sessions, sessionID)
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
sess.SetNew(false)
|
|
|
return sess, nil
|
|
|
}
|
|
|
|
|
|
// Create 创建Session
|
|
|
func (s *MemoryStorage) Create(sessionID string) (sessionpkg.Session, error) {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
sess := sessionimpl.NewSession(sessionID).(*sessionimpl.SessionImpl)
|
|
|
s.sessions[sessionID] = sess
|
|
|
return sess, nil
|
|
|
}
|
|
|
|
|
|
// Delete 删除Session
|
|
|
func (s *MemoryStorage) Delete(sessionID string) error {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
delete(s.sessions, sessionID)
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Save 保存Session
|
|
|
func (s *MemoryStorage) Save(sess sessionpkg.Session) error {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
if impl, ok := sess.(*sessionimpl.SessionImpl); ok {
|
|
|
s.sessions[sess.ID()] = impl
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Cleanup 清理过期Session
|
|
|
func (s *MemoryStorage) Cleanup() error {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
now := time.Now()
|
|
|
for id, sess := range s.sessions {
|
|
|
if s.expiration > 0 && now.Sub(sess.AccessedAt()) > s.expiration {
|
|
|
delete(s.sessions, id)
|
|
|
}
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// cleanup 定期清理(后台goroutine)
|
|
|
func (s *MemoryStorage) cleanup() {
|
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
for range ticker.C {
|
|
|
s.Cleanup()
|
|
|
}
|
|
|
}
|