You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

219 lines
6.2 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package context
import (
"context"
"testing"
"time"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
responsepkg "github.com/noahlann/nnet/pkg/response"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestContext exercises the basic accessors of the value returned by New:
// the key/value store (Set, Get, GetString, GetInt, GetBool, MustGet, Value)
// and the Request, Response and Connection getters.
func TestContext(t *testing.T) {
	parentCtx := context.Background()
	conn := &mockConnection{id: "test-conn"}

	// Build the stub Request and Response.
	req := &stubRequest{raw: []byte("test request")}
	resp := stubResponse{}

	ctx := New(parentCtx, conn, req, resp)

	// Set and Get.
	ctx.Set("key1", "value1")
	val := ctx.Get("key1")
	require.NotNil(t, val, "Expected key1 to be found")
	// Compare the interface value directly: an unchecked val.(string) would
	// panic on a type mismatch instead of reporting an assertion failure.
	assert.Equal(t, "value1", val, "Expected 'value1'")

	// GetString.
	strVal := ctx.GetString("key1")
	assert.Equal(t, "value1", strVal, "Expected 'value1'")

	// GetInt.
	ctx.Set("int1", 42)
	intVal := ctx.GetInt("int1")
	assert.Equal(t, 42, intVal, "Expected 42")

	// GetBool.
	ctx.Set("bool1", true)
	boolVal := ctx.GetBool("bool1")
	assert.True(t, boolVal, "Expected true")

	// Request.
	assert.Equal(t, "test request", string(ctx.Request().Raw()), "Expected 'test request'")

	// Response.
	assert.NotNil(t, ctx.Response(), "Expected Response to be non-nil")

	// Connection.
	assert.Equal(t, "test-conn", ctx.Connection().ID(), "Expected 'test-conn'")

	// MustGet with an existing key.
	val = ctx.MustGet("key1")
	assert.Equal(t, "value1", val, "Expected 'value1'")

	// MustGet with a missing key is expected to panic.
	assert.Panics(t, func() {
		ctx.MustGet("nonexistent")
	}, "Expected panic for missing key")

	// Value (the context.Context-style lookup).
	val = ctx.Value("key1")
	require.NotNil(t, val, "Expected key1 to be found via Value")
	assert.Equal(t, "value1", val, "Expected 'value1'")
}
// TestContextWithTimeout checks that the value returned by New exposes the
// deadline, Done channel and Err of a parent context created with a 100ms
// timeout.
func TestContextWithTimeout(t *testing.T) {
	parentCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	conn := &mockConnection{id: "test-conn"}
	req := &stubRequest{raw: []byte("test")}
	resp := stubResponse{}

	ctx := New(parentCtx, conn, req, resp)

	// Deadline.
	deadline, ok := ctx.Deadline()
	assert.True(t, ok, "Expected deadline to be set")
	assert.False(t, deadline.IsZero(), "Expected deadline to be non-zero")

	// Done must not fire during the first half of the timeout window.
	select {
	case <-ctx.Done():
		t.Error("Context should not be done yet")
	case <-time.After(50 * time.Millisecond):
		// OK
	}

	// Wait for the timeout. Blocking on Done with a generous upper bound
	// returns as soon as the deadline fires and tolerates a slow scheduler,
	// unlike a fixed sleep followed by a non-blocking poll.
	select {
	case <-ctx.Done():
		// OK
	case <-time.After(2 * time.Second):
		t.Error("Context should be done after timeout")
	}

	// Err.
	assert.NotNil(t, ctx.Err(), "Expected error after timeout")
}
func TestContextGetIntTypes(t *testing.T) {
parentCtx := context.Background()
conn := &mockConnection{id: "test-conn"}
req := &stubRequest{raw: []byte("test")}
resp := stubResponse{}
ctx := New(parentCtx, conn, req, resp)
// 测试 int32
ctx.Set("int32", int32(32))
assert.Equal(t, 32, ctx.GetInt("int32"), "Expected 32 for int32")
// 测试 int64
ctx.Set("int64", int64(64))
assert.Equal(t, 64, ctx.GetInt("int64"), "Expected 64 for int64")
// 测试无效类型
ctx.Set("invalid", "not an int")
assert.Equal(t, 0, ctx.GetInt("invalid"), "Expected 0 for invalid type")
}
func TestContextGetStringInvalidType(t *testing.T) {
parentCtx := context.Background()
conn := &mockConnection{id: "test-conn"}
req := &stubRequest{raw: []byte("test")}
resp := stubResponse{}
ctx := New(parentCtx, conn, req, resp)
// 测试无效类型
ctx.Set("invalid", 123)
assert.Equal(t, "", ctx.GetString("invalid"), "Expected empty string for invalid type")
}
func TestContextGetBoolInvalidType(t *testing.T) {
parentCtx := context.Background()
conn := &mockConnection{id: "test-conn"}
req := &stubRequest{raw: []byte("test")}
resp := stubResponse{}
ctx := New(parentCtx, conn, req, resp)
// 测试无效类型
ctx.Set("invalid", "not a bool")
assert.False(t, ctx.GetBool("invalid"), "Expected false for invalid type")
}
// TestContextValue checks Value lookups with two kinds of key: a string key
// that was stored through Set, and a non-string key that only exists on the
// parent context passed to New.
func TestContextValue(t *testing.T) {
	// Use a non-string key for the value stored on the parent context.
	parentKey := struct{ name string }{"parent_key"}
	parentCtx := context.WithValue(context.Background(), parentKey, "parent_value")

	conn := &mockConnection{id: "test-conn"}
	req := &stubRequest{raw: []byte("test")}
	resp := stubResponse{}

	ctx := New(parentCtx, conn, req, resp)

	// String key (should be resolved from the internal values map).
	ctx.Set("ctx_key", "ctx_value")
	val := ctx.Value("ctx_key")
	require.NotNil(t, val, "Expected ctx_key to be found")
	// Compare the interface value directly: an unchecked val.(string) would
	// panic on a type mismatch instead of reporting an assertion failure.
	assert.Equal(t, "ctx_value", val, "Expected 'ctx_value'")

	// Non-string key (should be resolved from the parent context).
	val = ctx.Value(parentKey)
	require.NotNil(t, val, "Expected parent_key to be found from parent context")
	assert.Equal(t, "parent_value", val, "Expected 'parent_value'")
}
// stubRequest is a minimal request double for these tests. Only the raw
// payload is configurable; every other accessor returns a zero value.
type stubRequest struct {
	raw []byte // bytes handed back by Raw
}

// Raw returns the bytes the stub was constructed with.
func (r *stubRequest) Raw() []byte {
	return r.raw
}

// Data always returns nil.
func (r *stubRequest) Data() interface{} {
	return nil
}

// Bind ignores its argument and always returns a nil error.
func (r *stubRequest) Bind(_ interface{}) error {
	return nil
}

// Header always returns nil.
func (r *stubRequest) Header() protocolpkg.FrameHeader {
	return nil
}

// DataBytes always returns nil.
func (r *stubRequest) DataBytes() []byte {
	return nil
}

// Protocol always returns nil.
func (r *stubRequest) Protocol() protocolpkg.Protocol {
	return nil
}

// ProtocolVersion always returns the empty string.
func (r *stubRequest) ProtocolVersion() string {
	return ""
}
// stubResponse is a stateless response double for these tests. Every write
// method discards its input and returns a nil error, and every accessor
// returns nil.
type stubResponse struct{}

// Write discards the value and returns a nil error.
func (stubResponse) Write(_ interface{}) error {
	return nil
}

// WriteString discards the string and returns a nil error.
func (stubResponse) WriteString(_ string) error {
	return nil
}

// WriteWithCodec discards both arguments and returns a nil error.
func (stubResponse) WriteWithCodec(_ interface{}, _ string) error {
	return nil
}

// WriteBytes discards the bytes and returns a nil error.
func (stubResponse) WriteBytes(_ []byte) error {
	return nil
}

// Header always returns nil.
func (stubResponse) Header() responsepkg.FrameHeader {
	return nil
}

// SetHeader is a no-op.
func (stubResponse) SetHeader(_ responsepkg.FrameHeader) {
}

// Protocol always returns nil.
func (stubResponse) Protocol() protocolpkg.Protocol {
	return nil
}
// mockConnection is a fake connection for these tests. It carries only an
// identifier; addresses are fixed strings and Write/Close do nothing.
type mockConnection struct {
	id string // value reported by ID
}

// ID returns the identifier the mock was constructed with.
func (c *mockConnection) ID() string {
	return c.id
}

// RemoteAddr returns a fixed loopback address.
func (c *mockConnection) RemoteAddr() string {
	return "127.0.0.1:6995"
}

// LocalAddr returns a fixed loopback address (the same string as RemoteAddr).
func (c *mockConnection) LocalAddr() string {
	return "127.0.0.1:6995"
}

// Write discards the payload and returns a nil error.
func (c *mockConnection) Write(_ []byte) error {
	return nil
}

// Close does nothing and returns a nil error.
func (c *mockConnection) Close() error {
	return nil
}