You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

141 lines
3.9 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package storage
import (
"context"
"encoding/json"
"fmt"
"time"
sessionimpl "github.com/noahlann/nnet/internal/session"
sessionpkg "github.com/noahlann/nnet/pkg/session"
)
// RedisStorage is a session store backed by Redis. It does not depend on a
// particular Redis library: the caller supplies a RedisClient implementation.
type RedisStorage struct {
	// client performs the actual Redis operations.
	client RedisClient
	// prefix is prepended to every session ID to form the Redis key.
	prefix string
	// expiration is passed to the client on every Set, and Get uses it to
	// reject sessions whose last access is older than this duration.
	expiration time.Duration
	// ctx is the context handed to every client call.
	ctx context.Context
}
// RedisClient is the set of Redis operations this package expects from the
// caller. Declaring it here avoids a direct dependency on a Redis library.
type RedisClient interface {
	// Get returns the string stored under key. RedisStorage.Get treats any
	// error returned here as "key does not exist".
	Get(ctx context.Context, key string) (string, error)
	// Set stores value under key. RedisStorage passes its configured
	// expiration as the last argument (presumably used as the key's TTL;
	// confirm against the concrete client).
	Set(ctx context.Context, key string, value interface{}, expiration time.Duration) error
	// Delete removes key.
	Delete(ctx context.Context, key string) error
	// Keys returns the keys matching pattern. It is not called by the code
	// visible in this file.
	Keys(ctx context.Context, pattern string) ([]string, error)
}
// NewRedisStorage builds a Redis-backed session store around client.
//
// An empty prefix falls back to "nnet:session:". expiration is forwarded to
// the client on every write. All client calls made by the returned store use
// context.Background().
func NewRedisStorage(client RedisClient, prefix string, expiration time.Duration) sessionpkg.Storage {
	store := &RedisStorage{
		client:     client,
		prefix:     "nnet:session:",
		expiration: expiration,
		ctx:        context.Background(),
	}
	if prefix != "" {
		store.prefix = prefix
	}
	return store
}
// Get loads the session stored under sessionID and refreshes it in Redis.
//
// It returns (nil, nil) when no usable session exists: the client reported an
// error, the stored payload is empty, or the session's last access is older
// than the configured expiration. A non-nil error is returned when the stored
// payload cannot be decoded, when the session implementation has an unexpected
// type, or when writing the refreshed session back fails.
func (s *RedisStorage) Get(sessionID string) (sessionpkg.Session, error) {
	key := s.getKey(sessionID)

	data, err := s.client.Get(s.ctx, key)
	if err != nil || data == "" {
		// RedisClient exposes no "not found" sentinel, so every read error is
		// treated as a missing key. An empty payload can never be a valid
		// session document, so it is handled the same way instead of
		// surfacing as a JSON decoding error.
		return nil, nil
	}

	var sessionData struct {
		ID         string                 `json:"id"`
		Data       map[string]interface{} `json:"data"`
		CreatedAt  time.Time              `json:"created_at"`
		AccessedAt time.Time              `json:"accessed_at"`
		IsNew      bool                   `json:"is_new"`
	}
	if err := json.Unmarshal([]byte(data), &sessionData); err != nil {
		return nil, fmt.Errorf("failed to unmarshal session data: %w", err)
	}

	// The expiration handed to Set is expected to let Redis drop stale keys
	// on its own; this is a second line of defence based on the stored
	// access time.
	if s.expiration > 0 && time.Since(sessionData.AccessedAt) > s.expiration {
		// Best-effort removal: the session is reported as absent whether or
		// not the delete succeeds, so its error is deliberately discarded.
		_ = s.client.Delete(s.ctx, key)
		return nil, nil
	}

	// Checked assertion: report an unexpected implementation as an error
	// rather than panicking, matching the check performed in Save.
	sess, ok := sessionimpl.NewSession(sessionData.ID).(*sessionimpl.SessionImpl)
	if !ok {
		return nil, fmt.Errorf("invalid session type")
	}
	sess.SetNew(sessionData.IsNew)
	sess.SetCreatedAt(sessionData.CreatedAt)
	// SetData runs last and only once. Per the original author's note it also
	// refreshes AccessedAt (not verifiable from this file), which is why the
	// session is written back immediately afterwards.
	sess.SetData(sessionData.Data)

	if err := s.Save(sess); err != nil {
		return nil, err
	}
	return sess, nil
}
// Create builds a fresh session for sessionID, persists it through Save and
// returns it. It returns an error if the session implementation has an
// unexpected type or if saving fails.
func (s *RedisStorage) Create(sessionID string) (sessionpkg.Session, error) {
	// Checked assertion: an unexpected implementation is reported as an error
	// instead of panicking, matching the check performed in Save.
	sess, ok := sessionimpl.NewSession(sessionID).(*sessionimpl.SessionImpl)
	if !ok {
		return nil, fmt.Errorf("invalid session type")
	}
	if err := s.Save(sess); err != nil {
		return nil, err
	}
	return sess, nil
}
// Delete removes the session stored under sessionID and returns whatever
// error the client reports.
func (s *RedisStorage) Delete(sessionID string) error {
	return s.client.Delete(s.ctx, s.getKey(sessionID))
}
// Save serializes sess to JSON and writes it to Redis under the session's
// key, passing the configured expiration to the client.
//
// Only *sessionimpl.SessionImpl values are accepted; any other implementation
// yields an error. Marshalling failures are wrapped and returned, and the
// client's Set error is returned unchanged.
func (s *RedisStorage) Save(sess sessionpkg.Session) error {
	impl, isImpl := sess.(*sessionimpl.SessionImpl)
	if !isImpl {
		return fmt.Errorf("invalid session type")
	}

	// record is the JSON document written to Redis.
	type record struct {
		ID         string                 `json:"id"`
		Data       map[string]interface{} `json:"data"`
		CreatedAt  time.Time              `json:"created_at"`
		AccessedAt time.Time              `json:"accessed_at"`
		IsNew      bool                   `json:"is_new"`
	}

	redisKey := s.getKey(sess.ID())
	snapshot := record{
		ID:         sess.ID(),
		Data:       impl.Data(),
		CreatedAt:  impl.CreatedAt(),
		AccessedAt: impl.AccessedAt(),
		IsNew:      sess.IsNew(),
	}

	encoded, err := json.Marshal(snapshot)
	if err != nil {
		return fmt.Errorf("failed to marshal session data: %w", err)
	}
	return s.client.Set(s.ctx, redisKey, string(encoded), s.expiration)
}
// Cleanup is the hook for purging expired sessions. It currently does nothing
// and always returns nil.
func (s *RedisStorage) Cleanup() error {
	// Author's note: Redis is expected to expire keys through their TTL, so no
	// work is required here. Extra housekeeping, such as removing invalid
	// session keys, could be added in this method.
	return nil
}
// getKey returns the Redis key for sessionID: the configured prefix followed
// by the session ID.
func (s *RedisStorage) getKey(sessionID string) string {
	redisKey := s.prefix + sessionID
	return redisKey
}