You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

269 lines
7.4 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package integration
import (
"encoding/json"
"fmt"
"net"
"testing"
"time"
internalsession "github.com/noahlann/nnet/internal/session"
"github.com/noahlann/nnet/internal/session/storage"
"github.com/noahlann/nnet/pkg/nnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestSessionMemoryStorage 测试Session内存存储
// TestSessionMemoryStorage exercises the full lifecycle of a session held in the
// in-memory storage: create, set values, save, read back, delete a key, clear the
// session, and finally delete the session from the storage.
func TestSessionMemoryStorage(t *testing.T) {
	// Named memStorage (not storage) so the imported storage package is not
	// shadowed for the rest of the function.
	memStorage := storage.NewMemoryStorage(30 * time.Minute)

	// A freshly created session is expected to report IsNew() == true.
	sessionID := "test-session-001"
	session, err := memStorage.Create(sessionID)
	require.NoError(t, err, "Should create session")
	require.NotNil(t, session, "Session should not be nil")
	assert.True(t, session.IsNew(), "Session should be new")

	// Store values of two different types.
	err = session.Set("key1", "value1")
	require.NoError(t, err, "Should set value")
	err = session.Set("key2", 42)
	require.NoError(t, err, "Should set value")

	// Persist the session into the storage.
	err = memStorage.Save(session)
	require.NoError(t, err, "Should save session")

	// A session loaded back from the storage is expected to no longer be new.
	retrievedSession, err := memStorage.Get(sessionID)
	require.NoError(t, err, "Should get session")
	require.NotNil(t, retrievedSession, "Retrieved session should not be nil")
	assert.False(t, retrievedSession.IsNew(), "Retrieved session should not be new")

	// Both values must survive the save/get round trip with their original types.
	val1, err := retrievedSession.Get("key1")
	require.NoError(t, err, "Should get key1")
	assert.Equal(t, "value1", val1, "Value1 should match")
	val2, err := retrievedSession.Get("key2")
	require.NoError(t, err, "Should get key2")
	assert.Equal(t, 42, val2, "Value2 should match")

	// After Delete, Get on that key is expected to yield (nil, nil), not an error.
	err = retrievedSession.Delete("key1")
	require.NoError(t, err, "Should delete key1")
	val1, err = retrievedSession.Get("key1")
	require.NoError(t, err, "Should get key1 after delete")
	assert.Nil(t, val1, "Value1 should be nil after delete")

	// After Clear, the remaining key is expected to be gone as well.
	err = retrievedSession.Clear()
	require.NoError(t, err, "Should clear session")
	val2, err = retrievedSession.Get("key2")
	require.NoError(t, err, "Should get key2 after clear")
	assert.Nil(t, val2, "Value2 should be nil after clear")

	// Once the session is deleted from storage, Get is expected to return (nil, nil).
	err = memStorage.Delete(sessionID)
	require.NoError(t, err, "Should delete session")
	retrievedSession, err = memStorage.Get(sessionID)
	require.NoError(t, err, "Should get session after delete")
	assert.Nil(t, retrievedSession, "Session should be nil after delete")
}
// TestSessionFileStorage 测试Session文件存储
// TestSessionFileStorage checks the file-backed session storage end to end.
// It creates a session inside a per-test temporary directory, stores a value,
// persists the session, loads it back, verifies the value, and then removes it.
func TestSessionFileStorage(t *testing.T) {
	// t.TempDir is removed automatically by the testing package when the test ends.
	dir := t.TempDir()

	fileStore, err := storage.NewFileStorage(dir, 30*time.Minute)
	require.NoError(t, err, "Should create file storage")

	const id = "test-session-002"

	sess, err := fileStore.Create(id)
	require.NoError(t, err, "Should create session")
	require.NotNil(t, sess, "Session should not be nil")

	// Populate and persist.
	require.NoError(t, sess.Set("key1", "value1"), "Should set value")
	require.NoError(t, fileStore.Save(sess), "Should save session")

	// Load back from storage and confirm the stored value survived.
	loaded, err := fileStore.Get(id)
	require.NoError(t, err, "Should get session")
	require.NotNil(t, loaded, "Retrieved session should not be nil")

	got, err := loaded.Get("key1")
	require.NoError(t, err, "Should get key1")
	assert.Equal(t, "value1", got, "Value1 should match")

	// Remove the session from storage.
	require.NoError(t, fileStore.Delete(id), "Should delete session")
}
// TestSessionInContext 测试在Context中使用Session
// TestSessionInContext checks that a route handler can keep per-connection state
// in a session. The "session" route derives a session ID from the connection ID,
// loads or creates that session, increments a stored counter, saves it, and
// replies with the new count. Three requests over the same connection must then
// observe the counts 1, 2 and 3.
func TestSessionInContext(t *testing.T) {
	// Pick a free port by binding to :0 and releasing it.
	// NOTE(review): another process could claim the port between Close and server start.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	port := listener.Addr().(*net.TCPAddr).Port
	require.NoError(t, listener.Close())

	sessionStorage := storage.NewMemoryStorage(30 * time.Minute)

	cfg := &nnet.Config{
		Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
		Codec: &nnet.CodecConfig{
			DefaultCodec: "json",
		},
	}

	ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
		// The handler is invoked by the server (presumably off the test goroutine),
		// so it reports failures by returning errors rather than via require.
		srv.Router().RegisterString("session", func(ctx nnet.Context) error {
			// One session per connection, keyed by the connection ID.
			sessionID := "session-" + ctx.Connection().ID()

			// Load the session, creating it on the first request.
			session, err := sessionStorage.Get(sessionID)
			if err != nil {
				return err
			}
			if session == nil {
				if session, err = sessionStorage.Create(sessionID); err != nil {
					return err
				}
			}

			count, err := session.Get("count")
			if err != nil {
				return err
			}
			// Accept int as well as float64 (the latter presumably in case the
			// storage round-trips values through JSON); a missing or other-typed
			// value counts as zero.
			countInt := 0
			switch c := count.(type) {
			case int:
				countInt = c
			case float64:
				countInt = int(c)
			}

			countInt++
			if err := session.Set("count", countInt); err != nil {
				return err
			}
			if err := sessionStorage.Save(session); err != nil {
				return err
			}

			return ctx.Response().Write(map[string]any{
				"session_id": sessionID,
				"count":      countInt,
			})
		})
	})
	defer CleanupTestServer(t, ts)

	client := NewTestClient(t, ts.Addr, nil)
	defer CleanupTestClient(t, client)
	ConnectTestClient(t, client)
	// Short grace period after connecting before sending requests.
	// NOTE(review): a readiness signal would be less timing-dependent than a sleep.
	time.Sleep(100 * time.Millisecond)

	// Repeated requests on the same connection must see the counter persist.
	for i := 1; i <= 3; i++ {
		resp := RequestWithTimeout(t, client, []byte("session"), 3*time.Second)
		t.Logf("Response %d: %q", i, string(resp))

		var result map[string]any
		// Stop immediately on malformed JSON instead of reporting a confusing
		// second failure on result["count"].
		require.NoError(t, json.Unmarshal(resp, &result), "Response should be valid JSON")
		// encoding/json decodes JSON numbers into float64 when the target is any.
		assert.Equal(t, float64(i), result["count"], "Count should be %d", i)
	}
}
// TestSessionExpiration 测试Session过期
// TestSessionExpiration checks that a session saved into an in-memory storage
// with a very short expiry is readable immediately, but no longer returned by
// Get once the expiry duration has elapsed.
func TestSessionExpiration(t *testing.T) {
	// Named memStorage (not storage) so the imported storage package is not
	// shadowed for the rest of the function. The 100ms expiry keeps the test fast.
	memStorage := storage.NewMemoryStorage(100 * time.Millisecond)

	sessionID := "test-session-expire"
	session, err := memStorage.Create(sessionID)
	require.NoError(t, err, "Should create session")

	err = session.Set("key1", "value1")
	require.NoError(t, err, "Should set value")

	err = memStorage.Save(session)
	require.NoError(t, err, "Should save session")

	// Read back right away: the session must still exist.
	retrievedSession, err := memStorage.Get(sessionID)
	require.NoError(t, err, "Should get session")
	require.NotNil(t, retrievedSession, "Session should exist")

	// Sleep past the 100ms expiry.
	time.Sleep(150 * time.Millisecond)

	// An expired session is expected to yield (nil, nil) rather than an error.
	retrievedSession, err = memStorage.Get(sessionID)
	require.NoError(t, err, "Should get session after expiration")
	assert.Nil(t, retrievedSession, "Session should be nil after expiration")
}
// TestSessionManager 测试Session管理器
// TestSessionManager exercises a session created directly with
// internalsession.NewSession (no storage involved). It covers ID, IsNew,
// Set/Get, Delete and Clear.
func TestSessionManager(t *testing.T) {
	sessionID := "test-session-manager"
	session := internalsession.NewSession(sessionID)

	// A newly constructed session carries its ID and reports itself as new.
	assert.Equal(t, sessionID, session.ID(), "Session ID should match")
	assert.True(t, session.IsNew(), "Session should be new")

	// Set then Get returns the stored value.
	err := session.Set("key1", "value1")
	require.NoError(t, err, "Should set value")
	val1, err := session.Get("key1")
	require.NoError(t, err, "Should get value")
	assert.Equal(t, "value1", val1, "Value should match")

	// After Delete, Get on that key is expected to yield (nil, nil).
	err = session.Delete("key1")
	require.NoError(t, err, "Should delete value")
	val1, err = session.Get("key1")
	require.NoError(t, err, "Should get value after delete")
	assert.Nil(t, val1, "Value should be nil after delete")

	// Clear: populate a key first so the effect of Clear is actually observable.
	err = session.Set("key2", "value2")
	require.NoError(t, err, "Should set value before clear")
	err = session.Clear()
	require.NoError(t, err, "Should clear session")
	val2, err := session.Get("key2")
	require.NoError(t, err, "Should get value after clear")
	assert.Nil(t, val2, "Value should be nil after clear")
}