|
|
package storage
|
|
|
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
sessionimpl "github.com/noahlann/nnet/internal/session"
|
|
|
sessionpkg "github.com/noahlann/nnet/pkg/session"
|
|
|
)
|
|
|
|
|
|
// FileStorage is a session storage that caches sessions in an in-memory map
// and persists each one as a "<id>.json" file inside the directory path.
type FileStorage struct {
	// path is the directory that holds the session JSON files.
	path string
	// sessions caches created/loaded sessions keyed by session ID. Accesses
	// are guarded by mu, except during construction (loadSessions runs before
	// the storage is returned to callers).
	sessions map[string]*sessionimpl.SessionImpl
	// mu guards sessions.
	mu sync.RWMutex
	// expiration is the idle timeout compared against a session's AccessedAt;
	// a value <= 0 disables expiry.
	expiration time.Duration
}
|
|
|
|
|
|
// NewFileStorage 创建文件存储
|
|
|
func NewFileStorage(path string, expiration time.Duration) (sessionpkg.Storage, error) {
|
|
|
// 确保目录存在
|
|
|
if err := os.MkdirAll(path, 0755); err != nil {
|
|
|
return nil, fmt.Errorf("failed to create storage directory: %w", err)
|
|
|
}
|
|
|
|
|
|
storage := &FileStorage{
|
|
|
path: path,
|
|
|
sessions: make(map[string]*sessionimpl.SessionImpl),
|
|
|
expiration: expiration,
|
|
|
}
|
|
|
|
|
|
// 加载已存在的session文件
|
|
|
if err := storage.loadSessions(); err != nil {
|
|
|
return nil, fmt.Errorf("failed to load sessions: %w", err)
|
|
|
}
|
|
|
|
|
|
// 启动清理goroutine
|
|
|
go storage.cleanup()
|
|
|
|
|
|
return storage, nil
|
|
|
}
|
|
|
|
|
|
// Get 获取Session
|
|
|
func (s *FileStorage) Get(sessionID string) (sessionpkg.Session, error) {
|
|
|
s.mu.RLock()
|
|
|
sess, ok := s.sessions[sessionID]
|
|
|
s.mu.RUnlock()
|
|
|
|
|
|
if !ok {
|
|
|
// 尝试从文件加载
|
|
|
return s.loadSessionFromFile(sessionID)
|
|
|
}
|
|
|
|
|
|
// 检查是否过期
|
|
|
if s.expiration > 0 && time.Since(sess.AccessedAt()) > s.expiration {
|
|
|
s.mu.Lock()
|
|
|
delete(s.sessions, sessionID)
|
|
|
s.mu.Unlock()
|
|
|
s.deleteSessionFile(sessionID)
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
sess.SetNew(false)
|
|
|
return sess, nil
|
|
|
}
|
|
|
|
|
|
// Create 创建Session
|
|
|
func (s *FileStorage) Create(sessionID string) (sessionpkg.Session, error) {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
sess := sessionimpl.NewSession(sessionID).(*sessionimpl.SessionImpl)
|
|
|
s.sessions[sessionID] = sess
|
|
|
|
|
|
// 保存到文件
|
|
|
if err := s.saveSessionToFile(sess); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
return sess, nil
|
|
|
}
|
|
|
|
|
|
// Delete 删除Session
|
|
|
func (s *FileStorage) Delete(sessionID string) error {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
delete(s.sessions, sessionID)
|
|
|
s.deleteSessionFile(sessionID)
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Save 保存Session
|
|
|
func (s *FileStorage) Save(sess sessionpkg.Session) error {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
if impl, ok := sess.(*sessionimpl.SessionImpl); ok {
|
|
|
s.sessions[sess.ID()] = impl
|
|
|
return s.saveSessionToFile(impl)
|
|
|
}
|
|
|
return fmt.Errorf("invalid session type")
|
|
|
}
|
|
|
|
|
|
// Cleanup 清理过期Session
|
|
|
func (s *FileStorage) Cleanup() error {
|
|
|
s.mu.Lock()
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
now := time.Now()
|
|
|
for id, sess := range s.sessions {
|
|
|
if s.expiration > 0 && now.Sub(sess.AccessedAt()) > s.expiration {
|
|
|
delete(s.sessions, id)
|
|
|
s.deleteSessionFile(id)
|
|
|
}
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// loadSessions 加载所有Session
|
|
|
func (s *FileStorage) loadSessions() error {
|
|
|
entries, err := os.ReadDir(s.path)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
for _, entry := range entries {
|
|
|
if entry.IsDir() {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
if filepath.Ext(entry.Name()) != ".json" {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
sessionID := entry.Name()[:len(entry.Name())-5] // 去掉 .json 扩展名
|
|
|
sess, err := s.loadSessionFromFile(sessionID)
|
|
|
if err != nil {
|
|
|
continue // 跳过损坏的文件
|
|
|
}
|
|
|
|
|
|
if sess != nil {
|
|
|
s.sessions[sessionID] = sess.(*sessionimpl.SessionImpl)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// loadSessionFromFile 从文件加载Session
|
|
|
func (s *FileStorage) loadSessionFromFile(sessionID string) (sessionpkg.Session, error) {
|
|
|
filePath := s.getSessionFilePath(sessionID)
|
|
|
data, err := os.ReadFile(filePath)
|
|
|
if err != nil {
|
|
|
if os.IsNotExist(err) {
|
|
|
return nil, nil
|
|
|
}
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
var sessionData struct {
|
|
|
ID string `json:"id"`
|
|
|
Data map[string]interface{} `json:"data"`
|
|
|
CreatedAt time.Time `json:"created_at"`
|
|
|
AccessedAt time.Time `json:"accessed_at"`
|
|
|
IsNew bool `json:"is_new"`
|
|
|
}
|
|
|
|
|
|
if err := json.Unmarshal(data, &sessionData); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// 检查是否过期
|
|
|
if s.expiration > 0 && time.Since(sessionData.AccessedAt) > s.expiration {
|
|
|
s.deleteSessionFile(sessionID)
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
sess := sessionimpl.NewSession(sessionData.ID).(*sessionimpl.SessionImpl)
|
|
|
sess.SetData(sessionData.Data)
|
|
|
sess.SetNew(sessionData.IsNew)
|
|
|
sess.SetCreatedAt(sessionData.CreatedAt)
|
|
|
|
|
|
s.mu.Lock()
|
|
|
s.sessions[sessionID] = sess
|
|
|
s.mu.Unlock()
|
|
|
|
|
|
return sess, nil
|
|
|
}
|
|
|
|
|
|
// saveSessionToFile 保存Session到文件
|
|
|
func (s *FileStorage) saveSessionToFile(sess *sessionimpl.SessionImpl) error {
|
|
|
filePath := s.getSessionFilePath(sess.ID())
|
|
|
|
|
|
sessionData := struct {
|
|
|
ID string `json:"id"`
|
|
|
Data map[string]interface{} `json:"data"`
|
|
|
CreatedAt time.Time `json:"created_at"`
|
|
|
AccessedAt time.Time `json:"accessed_at"`
|
|
|
IsNew bool `json:"is_new"`
|
|
|
}{
|
|
|
ID: sess.ID(),
|
|
|
Data: sess.Data(),
|
|
|
CreatedAt: sess.CreatedAt(),
|
|
|
AccessedAt: sess.AccessedAt(),
|
|
|
IsNew: sess.IsNew(),
|
|
|
}
|
|
|
|
|
|
data, err := json.Marshal(sessionData)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
return os.WriteFile(filePath, data, 0644)
|
|
|
}
|
|
|
|
|
|
// deleteSessionFile 删除Session文件
|
|
|
func (s *FileStorage) deleteSessionFile(sessionID string) {
|
|
|
filePath := s.getSessionFilePath(sessionID)
|
|
|
os.Remove(filePath)
|
|
|
}
|
|
|
|
|
|
// getSessionFilePath 获取Session文件路径
|
|
|
func (s *FileStorage) getSessionFilePath(sessionID string) string {
|
|
|
return filepath.Join(s.path, sessionID+".json")
|
|
|
}
|
|
|
|
|
|
// cleanup 定期清理(后台goroutine)
|
|
|
func (s *FileStorage) cleanup() {
|
|
|
ticker := time.NewTicker(1 * time.Minute)
|
|
|
defer ticker.Stop()
|
|
|
|
|
|
for range ticker.C {
|
|
|
s.Cleanup()
|
|
|
}
|
|
|
}
|
|
|
|