You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

251 lines
5.7 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package test
import (
"context"
"testing"
"time"
"github.com/noahlann/nnet/internal/codec"
"github.com/noahlann/nnet/internal/request"
"github.com/noahlann/nnet/internal/response"
ctxpkg "github.com/noahlann/nnet/pkg/context"
)
// TestContext exercises the key/value helpers (Set, Get, GetString, GetInt,
// GetBool, MustGet, Value) and the Request, Response and Connection accessors
// of a context created by ctxpkg.New.
func TestContext(t *testing.T) {
	parentCtx := context.Background()
	conn := &mockConnection{id: "test-conn"}
	// Build a mock Request and Response to attach to the context.
	req := request.New([]byte("test request"), nil)
	codecRegistry := codec.NewRegistry()
	resp := response.New(conn, nil, codecRegistry, "json")
	ctx := ctxpkg.New(parentCtx, conn, req, resp)

	// Set and Get. The comma-ok assertion turns a missing or non-string value
	// into a reported failure; a bare val.(string) on nil would panic.
	ctx.Set("key1", "value1")
	val := ctx.Get("key1")
	if val == nil {
		t.Error("Expected key1 to be found")
	} else if s, ok := val.(string); !ok || s != "value1" {
		t.Errorf("Get(%q) = %v, want %q", "key1", val, "value1")
	}

	// GetString
	if got := ctx.GetString("key1"); got != "value1" {
		t.Errorf("GetString(%q) = %q, want %q", "key1", got, "value1")
	}

	// GetInt
	ctx.Set("int1", 42)
	if got := ctx.GetInt("int1"); got != 42 {
		t.Errorf("GetInt(%q) = %v, want 42", "int1", got)
	}

	// GetBool
	ctx.Set("bool1", true)
	if !ctx.GetBool("bool1") {
		t.Errorf("GetBool(%q) = false, want true", "bool1")
	}

	// Request: the payload passed to request.New should be returned by Raw.
	if got := string(ctx.Request().Raw()); got != "test request" {
		t.Errorf("Request().Raw() = %q, want %q", got, "test request")
	}

	// Response
	if ctx.Response() == nil {
		t.Error("Expected Response to be non-nil")
	}

	// Connection
	if got := ctx.Connection().ID(); got != "test-conn" {
		t.Errorf("Connection().ID() = %q, want %q", got, "test-conn")
	}

	// MustGet on a key that is present.
	val = ctx.MustGet("key1")
	if s, ok := val.(string); !ok || s != "value1" {
		t.Errorf("MustGet(%q) = %v, want %q", "key1", val, "value1")
	}

	// MustGet on a missing key is expected to panic.
	func() {
		defer func() {
			if r := recover(); r == nil {
				t.Error("Expected panic for missing key")
			}
		}()
		ctx.MustGet("nonexistent")
	}()

	// Value (context.Context interface): keys stored via Set are visible.
	val = ctx.Value("key1")
	if val == nil {
		t.Error("Expected key1 to be found via Value")
	} else if s, ok := val.(string); !ok || s != "value1" {
		t.Errorf("Value(%q) = %v, want %q", "key1", val, "value1")
	}
}
// TestContextWithTimeout checks that a context created by ctxpkg.New from a
// parent with a timeout reports a deadline, is not done right away, becomes
// done once the timeout expires, and then returns a non-nil Err.
func TestContextWithTimeout(t *testing.T) {
	parentCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	conn := &mockConnection{id: "test-conn"}
	req := request.New([]byte("test"), nil)
	codecRegistry := codec.NewRegistry()
	resp := response.New(conn, nil, codecRegistry, "json")
	ctx := ctxpkg.New(parentCtx, conn, req, resp)

	// Deadline
	deadline, ok := ctx.Deadline()
	if !ok {
		t.Error("Expected deadline to be set")
	}
	if deadline.IsZero() {
		t.Error("Expected deadline to be non-zero")
	}

	// Done must not be closed immediately after construction. A non-blocking
	// check avoids racing a fixed wait against the 100ms timeout, which can
	// flake when the scheduler or timers are delayed under load.
	select {
	case <-ctx.Done():
		t.Error("Context should not be done yet")
	default:
	}

	// Block until the timeout fires instead of sleeping a fixed amount and
	// polling. The upper bound only guards against a Done channel that never
	// closes, and it is far above the 100ms timeout.
	select {
	case <-ctx.Done():
	case <-time.After(5 * time.Second):
		t.Fatal("Context should be done after timeout")
	}

	// Err
	if ctx.Err() == nil {
		t.Error("Expected error after timeout")
	}
}
// TestContextGetIntTypes verifies that GetInt returns the numeric value for
// int32 and int64 entries and 0 for an entry that is not an integer.
func TestContextGetIntTypes(t *testing.T) {
	conn := &mockConnection{id: "test-conn"}
	ctx := ctxpkg.New(context.Background(), conn,
		request.New([]byte("test"), nil),
		response.New(conn, nil, codec.NewRegistry(), "json"))

	// int32 entry
	ctx.Set("int32", int32(32))
	if got := ctx.GetInt("int32"); got != 32 {
		t.Error("Expected 32 for int32")
	}

	// int64 entry
	ctx.Set("int64", int64(64))
	if got := ctx.GetInt("int64"); got != 64 {
		t.Error("Expected 64 for int64")
	}

	// Non-integer entry falls back to zero.
	ctx.Set("invalid", "not an int")
	if got := ctx.GetInt("invalid"); got != 0 {
		t.Error("Expected 0 for invalid type")
	}
}
// TestContextGetStringInvalidType verifies that GetString yields the empty
// string when the stored value is not a string.
func TestContextGetStringInvalidType(t *testing.T) {
	conn := &mockConnection{id: "test-conn"}
	ctx := ctxpkg.New(context.Background(), conn,
		request.New([]byte("test"), nil),
		response.New(conn, nil, codec.NewRegistry(), "json"))

	// Store an int, then read it back as a string.
	ctx.Set("invalid", 123)
	if s := ctx.GetString("invalid"); s != "" {
		t.Error("Expected empty string for invalid type")
	}
}
// TestContextGetBoolInvalidType verifies that GetBool yields false when the
// stored value is not a bool.
func TestContextGetBoolInvalidType(t *testing.T) {
	conn := &mockConnection{id: "test-conn"}
	ctx := ctxpkg.New(context.Background(), conn,
		request.New([]byte("test"), nil),
		response.New(conn, nil, codec.NewRegistry(), "json"))

	// Store a string, then read it back as a bool.
	ctx.Set("invalid", "not a bool")
	if b := ctx.GetBool("invalid"); b {
		t.Error("Expected false for invalid type")
	}
}
// TestContextValue checks how Value resolves keys. A string key should be
// served from the values stored with Set, and a non-string key should fall
// through to the parent context.Context.
func TestContextValue(t *testing.T) {
	// Use a non-string key so the lookup has to reach the parent context.
	parentKey := struct{ name string }{"parent_key"}
	parentCtx := context.WithValue(context.Background(), parentKey, "parent_value")
	conn := &mockConnection{id: "test-conn"}
	req := request.New([]byte("test"), nil)
	codecRegistry := codec.NewRegistry()
	resp := response.New(conn, nil, codecRegistry, "json")
	ctx := ctxpkg.New(parentCtx, conn, req, resp)

	// String key: served from the values stored via Set.
	// Comma-ok assertions report a nil or non-string result as a failure
	// instead of panicking on a bare val.(string).
	ctx.Set("ctx_key", "ctx_value")
	if val := ctx.Value("ctx_key"); val == nil {
		t.Error("Expected ctx_key to be found")
	} else if s, ok := val.(string); !ok || s != "ctx_value" {
		t.Errorf("Expected 'ctx_value', got: %v", val)
	}

	// Non-string key: expected to be resolved by the parent context.
	if val := ctx.Value(parentKey); val == nil {
		t.Error("Expected parent_key to be found from parent context")
	} else if s, ok := val.(string); !ok || s != "parent_value" {
		t.Errorf("Expected 'parent_value', got: %v", val)
	}
}
// mockConnection is a minimal stand-in for the connection passed to
// ctxpkg.New and response.New in these tests. It reports a fixed ID and a
// fixed loopback address, and its Write and Close always succeed.
type mockConnection struct {
	id string // value reported by ID
}

// ID returns the identifier the mock was constructed with.
func (c *mockConnection) ID() string { return c.id }

// RemoteAddr returns a fixed loopback address.
func (c *mockConnection) RemoteAddr() string { return "127.0.0.1:6995" }

// LocalAddr returns the same fixed loopback address as RemoteAddr.
func (c *mockConnection) LocalAddr() string { return "127.0.0.1:6995" }

// Write ignores its input and always reports success.
func (c *mockConnection) Write(_ []byte) error { return nil }

// Close does nothing and always reports success.
func (c *mockConnection) Close() error { return nil }