|
|
package integration
|
|
|
|
|
|
import (
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
"net"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
internalsession "github.com/noahlann/nnet/internal/session"
|
|
|
"github.com/noahlann/nnet/internal/session/storage"
|
|
|
"github.com/noahlann/nnet/pkg/nnet"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
// TestSessionMemoryStorage 测试Session内存存储
|
|
|
func TestSessionMemoryStorage(t *testing.T) {
|
|
|
// 创建内存存储
|
|
|
storage := storage.NewMemoryStorage(30 * time.Minute)
|
|
|
|
|
|
// 创建Session
|
|
|
sessionID := "test-session-001"
|
|
|
session, err := storage.Create(sessionID)
|
|
|
require.NoError(t, err, "Should create session")
|
|
|
require.NotNil(t, session, "Session should not be nil")
|
|
|
assert.True(t, session.IsNew(), "Session should be new")
|
|
|
|
|
|
// 设置值
|
|
|
err = session.Set("key1", "value1")
|
|
|
require.NoError(t, err, "Should set value")
|
|
|
|
|
|
err = session.Set("key2", 42)
|
|
|
require.NoError(t, err, "Should set value")
|
|
|
|
|
|
// 保存Session
|
|
|
err = storage.Save(session)
|
|
|
require.NoError(t, err, "Should save session")
|
|
|
|
|
|
// 获取Session
|
|
|
retrievedSession, err := storage.Get(sessionID)
|
|
|
require.NoError(t, err, "Should get session")
|
|
|
require.NotNil(t, retrievedSession, "Retrieved session should not be nil")
|
|
|
assert.False(t, retrievedSession.IsNew(), "Retrieved session should not be new")
|
|
|
|
|
|
// 验证值
|
|
|
val1, err := retrievedSession.Get("key1")
|
|
|
require.NoError(t, err, "Should get key1")
|
|
|
assert.Equal(t, "value1", val1, "Value1 should match")
|
|
|
|
|
|
val2, err := retrievedSession.Get("key2")
|
|
|
require.NoError(t, err, "Should get key2")
|
|
|
assert.Equal(t, 42, val2, "Value2 should match")
|
|
|
|
|
|
// 删除值
|
|
|
err = retrievedSession.Delete("key1")
|
|
|
require.NoError(t, err, "Should delete key1")
|
|
|
|
|
|
val1, err = retrievedSession.Get("key1")
|
|
|
require.NoError(t, err, "Should get key1 after delete")
|
|
|
assert.Nil(t, val1, "Value1 should be nil after delete")
|
|
|
|
|
|
// 清空Session
|
|
|
err = retrievedSession.Clear()
|
|
|
require.NoError(t, err, "Should clear session")
|
|
|
|
|
|
val2, err = retrievedSession.Get("key2")
|
|
|
require.NoError(t, err, "Should get key2 after clear")
|
|
|
assert.Nil(t, val2, "Value2 should be nil after clear")
|
|
|
|
|
|
// 删除Session
|
|
|
err = storage.Delete(sessionID)
|
|
|
require.NoError(t, err, "Should delete session")
|
|
|
|
|
|
retrievedSession, err = storage.Get(sessionID)
|
|
|
require.NoError(t, err, "Should get session after delete")
|
|
|
assert.Nil(t, retrievedSession, "Session should be nil after delete")
|
|
|
}
|
|
|
|
|
|
// TestSessionFileStorage 测试Session文件存储
|
|
|
func TestSessionFileStorage(t *testing.T) {
|
|
|
// 创建临时目录
|
|
|
tempDir := t.TempDir()
|
|
|
|
|
|
// 创建文件存储
|
|
|
fileStorage, err := storage.NewFileStorage(tempDir, 30*time.Minute)
|
|
|
require.NoError(t, err, "Should create file storage")
|
|
|
|
|
|
// 创建Session
|
|
|
sessionID := "test-session-002"
|
|
|
session, err := fileStorage.Create(sessionID)
|
|
|
require.NoError(t, err, "Should create session")
|
|
|
require.NotNil(t, session, "Session should not be nil")
|
|
|
|
|
|
// 设置值
|
|
|
err = session.Set("key1", "value1")
|
|
|
require.NoError(t, err, "Should set value")
|
|
|
|
|
|
// 保存Session
|
|
|
err = fileStorage.Save(session)
|
|
|
require.NoError(t, err, "Should save session")
|
|
|
|
|
|
// 获取Session
|
|
|
retrievedSession, err := fileStorage.Get(sessionID)
|
|
|
require.NoError(t, err, "Should get session")
|
|
|
require.NotNil(t, retrievedSession, "Retrieved session should not be nil")
|
|
|
|
|
|
// 验证值
|
|
|
val1, err := retrievedSession.Get("key1")
|
|
|
require.NoError(t, err, "Should get key1")
|
|
|
assert.Equal(t, "value1", val1, "Value1 should match")
|
|
|
|
|
|
// 删除Session
|
|
|
err = fileStorage.Delete(sessionID)
|
|
|
require.NoError(t, err, "Should delete session")
|
|
|
}
|
|
|
|
|
|
// TestSessionInContext 测试在Context中使用Session
|
|
|
func TestSessionInContext(t *testing.T) {
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
require.NoError(t, err)
|
|
|
port := listener.Addr().(*net.TCPAddr).Port
|
|
|
listener.Close()
|
|
|
|
|
|
// 创建Session存储
|
|
|
sessionStorage := storage.NewMemoryStorage(30 * time.Minute)
|
|
|
|
|
|
cfg := &nnet.Config{
|
|
|
Addr: fmt.Sprintf("tcp://127.0.0.1:%d", port),
|
|
|
Codec: &nnet.CodecConfig{
|
|
|
DefaultCodec: "json",
|
|
|
},
|
|
|
}
|
|
|
|
|
|
ts := StartTestServerWithRoutes(t, cfg, func(srv nnet.Server) {
|
|
|
// 注册路由,使用Session
|
|
|
srv.Router().RegisterString("session", func(ctx nnet.Context) error {
|
|
|
// 从连接ID生成Session ID
|
|
|
connID := ctx.Connection().ID()
|
|
|
sessionID := "session-" + connID
|
|
|
|
|
|
// 获取或创建Session
|
|
|
session, err := sessionStorage.Get(sessionID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
if session == nil {
|
|
|
session, err = sessionStorage.Create(sessionID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 读取Session值
|
|
|
count, _ := session.Get("count")
|
|
|
countInt := 0
|
|
|
if count != nil {
|
|
|
if c, ok := count.(int); ok {
|
|
|
countInt = c
|
|
|
} else if c, ok := count.(float64); ok {
|
|
|
countInt = int(c)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 增加计数
|
|
|
countInt++
|
|
|
err = session.Set("count", countInt)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
// 保存Session
|
|
|
err = sessionStorage.Save(session)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
// 返回响应
|
|
|
return ctx.Response().Write(map[string]any{
|
|
|
"session_id": sessionID,
|
|
|
"count": countInt,
|
|
|
})
|
|
|
})
|
|
|
})
|
|
|
defer CleanupTestServer(t, ts)
|
|
|
|
|
|
client := NewTestClient(t, ts.Addr, nil)
|
|
|
defer CleanupTestClient(t, client)
|
|
|
|
|
|
ConnectTestClient(t, client)
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
|
|
// 发送多次请求,验证Session持久化
|
|
|
for i := 1; i <= 3; i++ {
|
|
|
resp := RequestWithTimeout(t, client, []byte("session"), 3*time.Second)
|
|
|
t.Logf("Response %d: %q", i, string(resp))
|
|
|
|
|
|
var result map[string]any
|
|
|
err = json.Unmarshal(resp, &result)
|
|
|
assert.NoError(t, err, "Response should be valid JSON")
|
|
|
assert.Equal(t, float64(i), result["count"], "Count should be %d", i)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// TestSessionExpiration 测试Session过期
|
|
|
func TestSessionExpiration(t *testing.T) {
|
|
|
// 创建内存存储,设置很短的过期时间
|
|
|
storage := storage.NewMemoryStorage(100 * time.Millisecond)
|
|
|
|
|
|
// 创建Session
|
|
|
sessionID := "test-session-expire"
|
|
|
session, err := storage.Create(sessionID)
|
|
|
require.NoError(t, err, "Should create session")
|
|
|
|
|
|
// 设置值
|
|
|
err = session.Set("key1", "value1")
|
|
|
require.NoError(t, err, "Should set value")
|
|
|
|
|
|
// 保存Session
|
|
|
err = storage.Save(session)
|
|
|
require.NoError(t, err, "Should save session")
|
|
|
|
|
|
// 立即获取Session(应该存在)
|
|
|
retrievedSession, err := storage.Get(sessionID)
|
|
|
require.NoError(t, err, "Should get session")
|
|
|
require.NotNil(t, retrievedSession, "Session should exist")
|
|
|
|
|
|
// 等待过期
|
|
|
time.Sleep(150 * time.Millisecond)
|
|
|
|
|
|
// 再次获取Session(应该过期)
|
|
|
retrievedSession, err = storage.Get(sessionID)
|
|
|
require.NoError(t, err, "Should get session after expiration")
|
|
|
assert.Nil(t, retrievedSession, "Session should be nil after expiration")
|
|
|
}
|
|
|
|
|
|
// TestSessionManager 测试Session管理器
|
|
|
func TestSessionManager(t *testing.T) {
|
|
|
// 创建Session实现
|
|
|
sessionID := "test-session-manager"
|
|
|
session := internalsession.NewSession(sessionID)
|
|
|
|
|
|
// 测试基本操作
|
|
|
assert.Equal(t, sessionID, session.ID(), "Session ID should match")
|
|
|
assert.True(t, session.IsNew(), "Session should be new")
|
|
|
|
|
|
// 设置值
|
|
|
err := session.Set("key1", "value1")
|
|
|
require.NoError(t, err, "Should set value")
|
|
|
|
|
|
// 获取值
|
|
|
val1, err := session.Get("key1")
|
|
|
require.NoError(t, err, "Should get value")
|
|
|
assert.Equal(t, "value1", val1, "Value should match")
|
|
|
|
|
|
// 删除值
|
|
|
err = session.Delete("key1")
|
|
|
require.NoError(t, err, "Should delete value")
|
|
|
|
|
|
val1, err = session.Get("key1")
|
|
|
require.NoError(t, err, "Should get value after delete")
|
|
|
assert.Nil(t, val1, "Value should be nil after delete")
|
|
|
|
|
|
// 清空
|
|
|
err = session.Clear()
|
|
|
require.NoError(t, err, "Should clear session")
|
|
|
}
|
|
|
|