package storage import ( "context" "encoding/json" "fmt" "time" sessionimpl "github.com/noahlann/nnet/internal/session" sessionpkg "github.com/noahlann/nnet/pkg/session" ) // RedisStorage Redis存储接口(需要用户提供Redis客户端实现) type RedisStorage struct { client RedisClient prefix string expiration time.Duration ctx context.Context } // RedisClient Redis客户端接口(避免直接依赖Redis库) type RedisClient interface { Get(ctx context.Context, key string) (string, error) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error Delete(ctx context.Context, key string) error Keys(ctx context.Context, pattern string) ([]string, error) } // NewRedisStorage 创建Redis存储 func NewRedisStorage(client RedisClient, prefix string, expiration time.Duration) sessionpkg.Storage { if prefix == "" { prefix = "nnet:session:" } return &RedisStorage{ client: client, prefix: prefix, expiration: expiration, ctx: context.Background(), } } // Get 获取Session func (s *RedisStorage) Get(sessionID string) (sessionpkg.Session, error) { key := s.getKey(sessionID) data, err := s.client.Get(s.ctx, key) if err != nil { // Redis返回nil表示key不存在 return nil, nil } var sessionData struct { ID string `json:"id"` Data map[string]interface{} `json:"data"` CreatedAt time.Time `json:"created_at"` AccessedAt time.Time `json:"accessed_at"` IsNew bool `json:"is_new"` } if err := json.Unmarshal([]byte(data), &sessionData); err != nil { return nil, fmt.Errorf("failed to unmarshal session data: %w", err) } // 检查是否过期(Redis的TTL应该已经处理了,但这里再做一次检查) if s.expiration > 0 && time.Since(sessionData.AccessedAt) > s.expiration { s.client.Delete(s.ctx, key) return nil, nil } sess := sessionimpl.NewSession(sessionData.ID).(*sessionimpl.SessionImpl) sess.SetData(sessionData.Data) sess.SetNew(sessionData.IsNew) sess.SetCreatedAt(sessionData.CreatedAt) // 更新访问时间并保存 sess.SetData(sessionData.Data) // 这会更新AccessedAt if err := s.Save(sess); err != nil { return nil, err } return sess, nil } // Create 创建Session func (s *RedisStorage) Create(sessionID string) (sessionpkg.Session, error) { sess := sessionimpl.NewSession(sessionID).(*sessionimpl.SessionImpl) if err := s.Save(sess); err != nil { return nil, err } return sess, nil } // Delete 删除Session func (s *RedisStorage) Delete(sessionID string) error { key := s.getKey(sessionID) return s.client.Delete(s.ctx, key) } // Save 保存Session func (s *RedisStorage) Save(sess sessionpkg.Session) error { impl, ok := sess.(*sessionimpl.SessionImpl) if !ok { return fmt.Errorf("invalid session type") } key := s.getKey(sess.ID()) sessionData := struct { ID string `json:"id"` Data map[string]interface{} `json:"data"` CreatedAt time.Time `json:"created_at"` AccessedAt time.Time `json:"accessed_at"` IsNew bool `json:"is_new"` }{ ID: sess.ID(), Data: impl.Data(), CreatedAt: impl.CreatedAt(), AccessedAt: impl.AccessedAt(), IsNew: sess.IsNew(), } data, err := json.Marshal(sessionData) if err != nil { return fmt.Errorf("failed to marshal session data: %w", err) } return s.client.Set(s.ctx, key, string(data), s.expiration) } // Cleanup 清理过期Session(Redis会自动清理,这里可以做一些额外的清理工作) func (s *RedisStorage) Cleanup() error { // Redis使用TTL自动清理过期key,这里可以做一些额外的清理工作 // 例如清理无效的session key等 return nil } // getKey 获取Redis key func (s *RedisStorage) getKey(sessionID string) string { return s.prefix + sessionID }