|
|
package test
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/internal/codec"
|
|
|
"github.com/noahlann/nnet/internal/request"
|
|
|
"github.com/noahlann/nnet/internal/response"
|
|
|
ctxpkg "github.com/noahlann/nnet/pkg/context"
|
|
|
)
|
|
|
|
|
|
func TestContext(t *testing.T) {
|
|
|
parentCtx := context.Background()
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
|
|
|
// 创建 mock Request 和 Response
|
|
|
req := request.New([]byte("test request"), nil)
|
|
|
codecRegistry := codec.NewRegistry()
|
|
|
resp := response.New(conn, nil, codecRegistry, "json")
|
|
|
|
|
|
ctx := ctxpkg.New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试 Set 和 Get
|
|
|
ctx.Set("key1", "value1")
|
|
|
val := ctx.Get("key1")
|
|
|
if val == nil {
|
|
|
t.Error("Expected key1 to be found")
|
|
|
}
|
|
|
if val.(string) != "value1" {
|
|
|
t.Error("Expected 'value1', got:", val)
|
|
|
}
|
|
|
|
|
|
// 测试 GetString
|
|
|
strVal := ctx.GetString("key1")
|
|
|
if strVal != "value1" {
|
|
|
t.Error("Expected 'value1', got:", strVal)
|
|
|
}
|
|
|
|
|
|
// 测试 GetInt
|
|
|
ctx.Set("int1", 42)
|
|
|
intVal := ctx.GetInt("int1")
|
|
|
if intVal != 42 {
|
|
|
t.Error("Expected 42, got:", intVal)
|
|
|
}
|
|
|
|
|
|
// 测试 GetBool
|
|
|
ctx.Set("bool1", true)
|
|
|
boolVal := ctx.GetBool("bool1")
|
|
|
if !boolVal {
|
|
|
t.Error("Expected true, got:", boolVal)
|
|
|
}
|
|
|
|
|
|
// 测试 Request
|
|
|
if string(ctx.Request().Raw()) != "test request" {
|
|
|
t.Error("Expected 'test request'")
|
|
|
}
|
|
|
|
|
|
// 测试 Response
|
|
|
if ctx.Response() == nil {
|
|
|
t.Error("Expected Response to be non-nil")
|
|
|
}
|
|
|
|
|
|
// 测试 Connection
|
|
|
if ctx.Connection().ID() != "test-conn" {
|
|
|
t.Error("Expected 'test-conn'")
|
|
|
}
|
|
|
|
|
|
// 测试 MustGet
|
|
|
val = ctx.MustGet("key1")
|
|
|
if val.(string) != "value1" {
|
|
|
t.Error("Expected 'value1'")
|
|
|
}
|
|
|
|
|
|
// 测试 MustGet panic
|
|
|
func() {
|
|
|
defer func() {
|
|
|
if r := recover(); r == nil {
|
|
|
t.Error("Expected panic for missing key")
|
|
|
}
|
|
|
}()
|
|
|
ctx.MustGet("nonexistent")
|
|
|
}()
|
|
|
|
|
|
// 测试 Value (context.Context interface)
|
|
|
val = ctx.Value("key1")
|
|
|
if val == nil {
|
|
|
t.Error("Expected key1 to be found via Value")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestContextWithTimeout(t *testing.T) {
|
|
|
parentCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
|
defer cancel()
|
|
|
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := request.New([]byte("test"), nil)
|
|
|
codecRegistry := codec.NewRegistry()
|
|
|
resp := response.New(conn, nil, codecRegistry, "json")
|
|
|
|
|
|
ctx := ctxpkg.New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试 Deadline
|
|
|
deadline, ok := ctx.Deadline()
|
|
|
if !ok {
|
|
|
t.Error("Expected deadline to be set")
|
|
|
}
|
|
|
if deadline.IsZero() {
|
|
|
t.Error("Expected deadline to be non-zero")
|
|
|
}
|
|
|
|
|
|
// 测试 Done
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
t.Error("Context should not be done yet")
|
|
|
case <-time.After(50 * time.Millisecond):
|
|
|
// OK
|
|
|
}
|
|
|
|
|
|
// 等待超时
|
|
|
time.Sleep(150 * time.Millisecond)
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
// OK
|
|
|
default:
|
|
|
t.Error("Context should be done after timeout")
|
|
|
}
|
|
|
|
|
|
// 测试 Err
|
|
|
if ctx.Err() == nil {
|
|
|
t.Error("Expected error after timeout")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestContextGetIntTypes(t *testing.T) {
|
|
|
parentCtx := context.Background()
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := request.New([]byte("test"), nil)
|
|
|
codecRegistry := codec.NewRegistry()
|
|
|
resp := response.New(conn, nil, codecRegistry, "json")
|
|
|
|
|
|
ctx := ctxpkg.New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试 int32
|
|
|
ctx.Set("int32", int32(32))
|
|
|
if ctx.GetInt("int32") != 32 {
|
|
|
t.Error("Expected 32 for int32")
|
|
|
}
|
|
|
|
|
|
// 测试 int64
|
|
|
ctx.Set("int64", int64(64))
|
|
|
if ctx.GetInt("int64") != 64 {
|
|
|
t.Error("Expected 64 for int64")
|
|
|
}
|
|
|
|
|
|
// 测试无效类型
|
|
|
ctx.Set("invalid", "not an int")
|
|
|
if ctx.GetInt("invalid") != 0 {
|
|
|
t.Error("Expected 0 for invalid type")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestContextGetStringInvalidType(t *testing.T) {
|
|
|
parentCtx := context.Background()
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := request.New([]byte("test"), nil)
|
|
|
codecRegistry := codec.NewRegistry()
|
|
|
resp := response.New(conn, nil, codecRegistry, "json")
|
|
|
|
|
|
ctx := ctxpkg.New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试无效类型
|
|
|
ctx.Set("invalid", 123)
|
|
|
if ctx.GetString("invalid") != "" {
|
|
|
t.Error("Expected empty string for invalid type")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestContextGetBoolInvalidType(t *testing.T) {
|
|
|
parentCtx := context.Background()
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := request.New([]byte("test"), nil)
|
|
|
codecRegistry := codec.NewRegistry()
|
|
|
resp := response.New(conn, nil, codecRegistry, "json")
|
|
|
|
|
|
ctx := ctxpkg.New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试无效类型
|
|
|
ctx.Set("invalid", "not a bool")
|
|
|
if ctx.GetBool("invalid") {
|
|
|
t.Error("Expected false for invalid type")
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestContextValue(t *testing.T) {
|
|
|
// 使用非字符串 key 测试父 context
|
|
|
parentKey := struct{ name string }{"parent_key"}
|
|
|
parentCtx := context.WithValue(context.Background(), parentKey, "parent_value")
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := request.New([]byte("test"), nil)
|
|
|
codecRegistry := codec.NewRegistry()
|
|
|
resp := response.New(conn, nil, codecRegistry, "json")
|
|
|
|
|
|
ctx := ctxpkg.New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试字符串 key(从内部 values map 获取)
|
|
|
ctx.Set("ctx_key", "ctx_value")
|
|
|
val := ctx.Value("ctx_key")
|
|
|
if val == nil {
|
|
|
t.Error("Expected ctx_key to be found")
|
|
|
}
|
|
|
if val.(string) != "ctx_value" {
|
|
|
t.Errorf("Expected 'ctx_value', got: %v", val)
|
|
|
}
|
|
|
|
|
|
// 测试非字符串 key(应该从父 context 获取)
|
|
|
val = ctx.Value(parentKey)
|
|
|
if val == nil {
|
|
|
t.Error("Expected parent_key to be found from parent context")
|
|
|
}
|
|
|
if val.(string) != "parent_value" {
|
|
|
t.Errorf("Expected 'parent_value', got: %v", val)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// mockConnection is a minimal stand-in for a connection, used by the context
// tests. Only ID carries per-instance data. Both address methods return the
// same fixed loopback address, and Write and Close are no-ops that always
// return nil.
type mockConnection struct {
	id string // returned verbatim by ID
}

// ID returns the identifier the mock was constructed with.
func (m *mockConnection) ID() string { return m.id }

// RemoteAddr reports a fixed loopback address.
func (m *mockConnection) RemoteAddr() string { return "127.0.0.1:6995" }

// LocalAddr reports the same fixed loopback address as RemoteAddr.
func (m *mockConnection) LocalAddr() string { return "127.0.0.1:6995" }

// Write discards data and always succeeds.
func (m *mockConnection) Write(data []byte) error { return nil }

// Close does nothing and always succeeds.
func (m *mockConnection) Close() error { return nil }
|