package storage

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"time"

	sessionimpl "github.com/noahlann/nnet/internal/session"
	sessionpkg "github.com/noahlann/nnet/pkg/session"
)

const (
	// fileSessionExt is the extension of every persisted session file.
	fileSessionExt = ".json"

	// fileCleanupInterval is how often the background goroutine sweeps
	// expired sessions.
	fileCleanupInterval = time.Minute
)

// fileSessionRecord is the JSON document written to disk for one session.
type fileSessionRecord struct {
	ID         string                 `json:"id"`
	Data       map[string]interface{} `json:"data"`
	CreatedAt  time.Time              `json:"created_at"`
	AccessedAt time.Time              `json:"accessed_at"`
	IsNew      bool                   `json:"is_new"`
}

// FileStorage is a session store that keeps sessions in an in-memory map and
// mirrors each one to a "<id>.json" file inside a single directory.
type FileStorage struct {
	path       string
	sessions   map[string]*sessionimpl.SessionImpl
	mu         sync.RWMutex
	expiration time.Duration

	done      chan struct{} // closed by Close to stop the cleanup goroutine
	closeOnce sync.Once     // makes Close safe to call more than once
}

// NewFileStorage creates the storage directory if needed, loads every session
// file already present in it, and returns the resulting store.
//
// When expiration is positive a background goroutine sweeps expired sessions
// once per fileCleanupInterval; call Close on the concrete *FileStorage to
// stop it. With a non-positive expiration nothing can ever expire, so no
// goroutine is started.
func NewFileStorage(path string, expiration time.Duration) (sessionpkg.Storage, error) {
	if err := os.MkdirAll(path, 0755); err != nil {
		return nil, fmt.Errorf("failed to create storage directory: %w", err)
	}

	storage := &FileStorage{
		path:       path,
		sessions:   make(map[string]*sessionimpl.SessionImpl),
		expiration: expiration,
		done:       make(chan struct{}),
	}

	if err := storage.loadSessions(); err != nil {
		return nil, fmt.Errorf("failed to load sessions: %w", err)
	}

	if expiration > 0 {
		go storage.cleanup()
	}
	return storage, nil
}

// Close stops the background cleanup goroutine, if one is running. It is safe
// to call multiple times and always returns nil. Sessions already stored are
// left untouched.
func (s *FileStorage) Close() error {
	s.closeOnce.Do(func() {
		// done is nil when the struct was built without NewFileStorage.
		if s.done != nil {
			close(s.done)
		}
	})
	return nil
}

// Get returns the session with the given ID, or (nil, nil) when there is no
// such session or it has expired.
//
// A session found in memory is marked as not new before it is returned. On a
// memory miss the session file is read from disk. IDs that would resolve
// outside the storage directory are treated as "not found" rather than being
// looked up, because session IDs typically originate from the client.
func (s *FileStorage) Get(sessionID string) (sessionpkg.Session, error) {
	if !validFileSessionID(sessionID) {
		return nil, nil
	}

	s.mu.RLock()
	sess, ok := s.sessions[sessionID]
	s.mu.RUnlock()

	if !ok {
		return s.loadSessionFromFile(sessionID)
	}

	if s.expired(sess.AccessedAt()) {
		s.mu.Lock()
		// Re-check under the write lock: another goroutine may have replaced
		// the entry since it was read, and that replacement must survive.
		if cur, found := s.sessions[sessionID]; found && cur == sess {
			delete(s.sessions, sessionID)
			// Best effort: the session is gone from memory either way, and a
			// leftover file is discarded as expired the next time it is read.
			_ = s.deleteSessionFile(sessionID)
		}
		s.mu.Unlock()
		return nil, nil
	}

	sess.SetNew(false)
	return sess, nil
}

// Create makes a new session with the given ID, writes it to disk and
// registers it in memory. The session is registered only after the file write
// succeeds, so a failed Create leaves the store unchanged. It returns an error
// for IDs that would resolve outside the storage directory.
func (s *FileStorage) Create(sessionID string) (sessionpkg.Session, error) {
	if !validFileSessionID(sessionID) {
		return nil, fmt.Errorf("invalid session id %q", sessionID)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	sess := sessionimpl.NewSession(sessionID).(*sessionimpl.SessionImpl)
	if err := s.saveSessionToFile(sess); err != nil {
		return nil, err
	}
	s.sessions[sessionID] = sess
	return sess, nil
}

// Delete removes the session from memory and deletes its file. A session that
// does not exist is not an error; any other failure to remove the file is
// returned.
func (s *FileStorage) Delete(sessionID string) error {
	if !validFileSessionID(sessionID) {
		// Such an ID can never have been stored, so there is nothing to do.
		return nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.sessions, sessionID)
	return s.deleteSessionFile(sessionID)
}

// Save registers sess in memory and writes it to disk. It returns an error if
// sess is not a *sessionimpl.SessionImpl, if its ID would resolve outside the
// storage directory, or if the file cannot be written.
func (s *FileStorage) Save(sess sessionpkg.Session) error {
	impl, ok := sess.(*sessionimpl.SessionImpl)
	if !ok {
		return fmt.Errorf("invalid session type")
	}
	id := impl.ID()
	if !validFileSessionID(id) {
		return fmt.Errorf("invalid session id %q", id)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	s.sessions[id] = impl
	return s.saveSessionToFile(impl)
}

// Cleanup evicts every in-memory session whose last access is older than the
// configured expiration and deletes its file. The sweep always visits every
// session; if any file could not be removed, the first such error is returned.
func (s *FileStorage) Cleanup() error {
	if s.expiration <= 0 {
		return nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	var firstErr error
	for id, sess := range s.sessions {
		if !s.expired(sess.AccessedAt()) {
			continue
		}
		delete(s.sessions, id)
		if err := s.deleteSessionFile(id); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

// expired reports whether a session last accessed at accessedAt has outlived
// the configured expiration. A non-positive expiration disables expiry.
func (s *FileStorage) expired(accessedAt time.Time) bool {
	return s.expiration > 0 && time.Since(accessedAt) > s.expiration
}

// validFileSessionID reports whether sessionID maps to a file located directly
// inside the storage directory, i.e. the resulting file name contains no path
// separator or volume prefix.
func validFileSessionID(sessionID string) bool {
	name := sessionID + fileSessionExt
	return filepath.Base(name) == name
}

// loadSessions reads every "*.json" file in the storage directory into memory.
// Files that cannot be read or decoded are skipped.
func (s *FileStorage) loadSessions() error {
	entries, err := os.ReadDir(s.path)
	if err != nil {
		return err
	}

	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || filepath.Ext(name) != fileSessionExt {
			continue
		}
		sessionID := name[:len(name)-len(fileSessionExt)]
		// loadSessionFromFile registers the session itself; a damaged file is
		// deliberately skipped so it cannot prevent the store from starting.
		_, _ = s.loadSessionFromFile(sessionID)
	}
	return nil
}

// loadSessionFromFile reads the session file for sessionID, registers the
// decoded session in memory and returns it. It returns (nil, nil) when the
// file does not exist or the stored session has expired (in which case the
// file is deleted). It acquires s.mu itself, so the caller must not hold it.
func (s *FileStorage) loadSessionFromFile(sessionID string) (sessionpkg.Session, error) {
	if !validFileSessionID(sessionID) {
		return nil, fmt.Errorf("invalid session id %q", sessionID)
	}

	data, err := os.ReadFile(s.getSessionFilePath(sessionID))
	if err != nil {
		if os.IsNotExist(err) {
			return nil, nil
		}
		return nil, err
	}

	var rec fileSessionRecord
	if err := json.Unmarshal(data, &rec); err != nil {
		return nil, fmt.Errorf("decode session %q: %w", sessionID, err)
	}

	if s.expired(rec.AccessedAt) {
		// Best effort: an expired file that cannot be removed now is simply
		// rejected again on the next read.
		_ = s.deleteSessionFile(sessionID)
		return nil, nil
	}

	sess := sessionimpl.NewSession(rec.ID).(*sessionimpl.SessionImpl)
	sess.SetData(rec.Data)
	sess.SetNew(rec.IsNew)
	sess.SetCreatedAt(rec.CreatedAt)
	// NOTE(review): rec.AccessedAt is checked above but never written back to
	// the session, so the in-memory access time is whatever NewSession
	// assigns. If that is the current time, expiry restarts on every load.
	// Restore it here if SessionImpl offers a setter - TODO confirm.

	s.mu.Lock()
	s.sessions[sessionID] = sess
	s.mu.Unlock()

	return sess, nil
}

// saveSessionToFile writes sess to its JSON file, replacing any previous
// content.
func (s *FileStorage) saveSessionToFile(sess *sessionimpl.SessionImpl) error {
	data, err := json.Marshal(fileSessionRecord{
		ID:         sess.ID(),
		Data:       sess.Data(),
		CreatedAt:  sess.CreatedAt(),
		AccessedAt: sess.AccessedAt(),
		IsNew:      sess.IsNew(),
	})
	if err != nil {
		return fmt.Errorf("encode session %q: %w", sess.ID(), err)
	}
	return os.WriteFile(s.getSessionFilePath(sess.ID()), data, 0644)
}

// deleteSessionFile removes the file backing sessionID. A file that is already
// gone is not an error.
func (s *FileStorage) deleteSessionFile(sessionID string) error {
	err := os.Remove(s.getSessionFilePath(sessionID))
	if err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("remove session file: %w", err)
	}
	return nil
}

// getSessionFilePath returns the path of the file backing sessionID.
func (s *FileStorage) getSessionFilePath(sessionID string) string {
	return filepath.Join(s.path, sessionID+fileSessionExt)
}

// cleanup runs Cleanup once per fileCleanupInterval until Close is called.
func (s *FileStorage) cleanup() {
	ticker := time.NewTicker(fileCleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			// The periodic sweep has no caller to report to; anything left
			// behind is retried when the session is next read.
			_ = s.Cleanup()
		case <-s.done:
			return
		}
	}
}