|
|
package storage
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"time"
|
|
|
|
|
|
sessionimpl "github.com/noahlann/nnet/internal/session"
|
|
|
sessionpkg "github.com/noahlann/nnet/pkg/session"
|
|
|
)
|
|
|
|
|
|
// RedisStorage Redis存储接口(需要用户提供Redis客户端实现)
|
|
|
type RedisStorage struct {
|
|
|
client RedisClient
|
|
|
prefix string
|
|
|
expiration time.Duration
|
|
|
ctx context.Context
|
|
|
}
|
|
|
|
|
|
// RedisClient Redis客户端接口(避免直接依赖Redis库)
|
|
|
type RedisClient interface {
|
|
|
Get(ctx context.Context, key string) (string, error)
|
|
|
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error
|
|
|
Delete(ctx context.Context, key string) error
|
|
|
Keys(ctx context.Context, pattern string) ([]string, error)
|
|
|
}
|
|
|
|
|
|
// NewRedisStorage 创建Redis存储
|
|
|
func NewRedisStorage(client RedisClient, prefix string, expiration time.Duration) sessionpkg.Storage {
|
|
|
if prefix == "" {
|
|
|
prefix = "nnet:session:"
|
|
|
}
|
|
|
return &RedisStorage{
|
|
|
client: client,
|
|
|
prefix: prefix,
|
|
|
expiration: expiration,
|
|
|
ctx: context.Background(),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Get loads the session stored under sessionID.
//
// It returns (nil, nil) when no usable session exists:
//   - the client's Get returns an error, which is treated as "key does not exist";
//   - the stored value is empty;
//   - expiration is positive and the stored session has been idle longer than
//     it. In that case the key is also deleted, best-effort.
//
// A stored value that is not valid JSON yields an error. On success the
// session's access time is refreshed and it is written back via Save.
//
// NOTE(review): because RedisClient has no dedicated "not found" error, client
// failures (e.g. network errors) are indistinguishable from missing sessions
// here. Consider adding a not-found sentinel to the client contract.
func (s *RedisStorage) Get(sessionID string) (sessionpkg.Session, error) {
	key := s.getKey(sessionID)

	data, err := s.client.Get(s.ctx, key)
	if err != nil || data == "" {
		// Any client error means "missing key" under the current contract.
		// Some clients report a missing key as ("", nil); decoding "" would
		// otherwise fail with "unexpected end of JSON input".
		return nil, nil
	}

	var sessionData struct {
		ID         string                 `json:"id"`
		Data       map[string]interface{} `json:"data"`
		CreatedAt  time.Time              `json:"created_at"`
		AccessedAt time.Time              `json:"accessed_at"`
		IsNew      bool                   `json:"is_new"`
	}
	if err := json.Unmarshal([]byte(data), &sessionData); err != nil {
		return nil, fmt.Errorf("failed to unmarshal session data: %w", err)
	}

	// The Redis TTL should already have evicted idle keys; re-check in case
	// the TTL was not applied (e.g. a zero or unsupported expiration).
	if s.expiration > 0 && time.Since(sessionData.AccessedAt) > s.expiration {
		// Best-effort removal: the session is reported as missing either way.
		_ = s.client.Delete(s.ctx, key)
		return nil, nil
	}

	// A record stored with "data": null decodes to a nil map; give the
	// session an empty map instead, in case it later writes into it.
	if sessionData.Data == nil {
		sessionData.Data = make(map[string]interface{})
	}

	// Fall back to the requested ID if the stored record lacks one.
	id := sessionData.ID
	if id == "" {
		id = sessionID
	}

	created := sessionimpl.NewSession(id)
	sess, ok := created.(*sessionimpl.SessionImpl)
	if !ok {
		return nil, fmt.Errorf("unexpected session type %T", created)
	}
	sess.SetNew(sessionData.IsNew)
	sess.SetCreatedAt(sessionData.CreatedAt)
	// SetData goes last: it also refreshes the access time, so the session is
	// marked as accessed now before being persisted.
	sess.SetData(sessionData.Data)

	if err := s.Save(sess); err != nil {
		return nil, err
	}
	return sess, nil
}
|
|
|
|
|
|
// Create builds a new session for sessionID, persists it via Save and returns
// it. If persisting fails, Save's error is returned together with a nil
// session.
//
// An unexpected implementation returned by sessionimpl.NewSession is reported
// as an error rather than causing a panic.
func (s *RedisStorage) Create(sessionID string) (sessionpkg.Session, error) {
	created := sessionimpl.NewSession(sessionID)
	sess, ok := created.(*sessionimpl.SessionImpl)
	if !ok {
		return nil, fmt.Errorf("unexpected session type %T", created)
	}
	if err := s.Save(sess); err != nil {
		return nil, err
	}
	return sess, nil
}
|
|
|
|
|
|
// Delete 删除Session
|
|
|
func (s *RedisStorage) Delete(sessionID string) error {
|
|
|
key := s.getKey(sessionID)
|
|
|
return s.client.Delete(s.ctx, key)
|
|
|
}
|
|
|
|
|
|
// Save serializes sess to JSON and writes it to Redis under the key derived
// from its ID, using the storage's configured expiration.
//
// Only *sessionimpl.SessionImpl values are accepted, because the stored record
// includes the creation and last-access timestamps exposed by that type. Any
// other implementation, a nil interface, or a nil *SessionImpl pointer yields
// an error.
func (s *RedisStorage) Save(sess sessionpkg.Session) error {
	impl, ok := sess.(*sessionimpl.SessionImpl)
	if !ok {
		return fmt.Errorf("invalid session type: %T", sess)
	}
	if impl == nil {
		// A typed nil pointer passes the assertion above, and its accessors
		// below are not known to be nil-safe.
		return fmt.Errorf("invalid session: nil %T", sess)
	}

	key := s.getKey(sess.ID())

	sessionData := struct {
		ID         string                 `json:"id"`
		Data       map[string]interface{} `json:"data"`
		CreatedAt  time.Time              `json:"created_at"`
		AccessedAt time.Time              `json:"accessed_at"`
		IsNew      bool                   `json:"is_new"`
	}{
		ID:         sess.ID(),
		Data:       impl.Data(),
		CreatedAt:  impl.CreatedAt(),
		AccessedAt: impl.AccessedAt(),
		IsNew:      sess.IsNew(),
	}

	data, err := json.Marshal(sessionData)
	if err != nil {
		return fmt.Errorf("failed to marshal session data: %w", err)
	}

	return s.client.Set(s.ctx, key, string(data), s.expiration)
}
|
|
|
|
|
|
// Cleanup 清理过期Session(Redis会自动清理,这里可以做一些额外的清理工作)
|
|
|
func (s *RedisStorage) Cleanup() error {
|
|
|
// Redis使用TTL自动清理过期key,这里可以做一些额外的清理工作
|
|
|
// 例如清理无效的session key等
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// getKey 获取Redis key
|
|
|
func (s *RedisStorage) getKey(sessionID string) string {
|
|
|
return s.prefix + sessionID
|
|
|
}
|
|
|
|