|
|
package context
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
|
|
|
responsepkg "github.com/noahlann/nnet/pkg/response"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
func TestContext(t *testing.T) {
|
|
|
parentCtx := context.Background()
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
|
|
|
// 创建 mock Request 和 Response
|
|
|
req := &stubRequest{raw: []byte("test request")}
|
|
|
resp := stubResponse{}
|
|
|
|
|
|
ctx := New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试 Set 和 Get
|
|
|
ctx.Set("key1", "value1")
|
|
|
val := ctx.Get("key1")
|
|
|
require.NotNil(t, val, "Expected key1 to be found")
|
|
|
assert.Equal(t, "value1", val.(string), "Expected 'value1'")
|
|
|
|
|
|
// 测试 GetString
|
|
|
strVal := ctx.GetString("key1")
|
|
|
assert.Equal(t, "value1", strVal, "Expected 'value1'")
|
|
|
|
|
|
// 测试 GetInt
|
|
|
ctx.Set("int1", 42)
|
|
|
intVal := ctx.GetInt("int1")
|
|
|
assert.Equal(t, 42, intVal, "Expected 42")
|
|
|
|
|
|
// 测试 GetBool
|
|
|
ctx.Set("bool1", true)
|
|
|
boolVal := ctx.GetBool("bool1")
|
|
|
assert.True(t, boolVal, "Expected true")
|
|
|
|
|
|
// 测试 Request
|
|
|
assert.Equal(t, "test request", string(ctx.Request().Raw()), "Expected 'test request'")
|
|
|
|
|
|
// 测试 Response
|
|
|
assert.NotNil(t, ctx.Response(), "Expected Response to be non-nil")
|
|
|
|
|
|
// 测试 Connection
|
|
|
assert.Equal(t, "test-conn", ctx.Connection().ID(), "Expected 'test-conn'")
|
|
|
|
|
|
// 测试 MustGet
|
|
|
val = ctx.MustGet("key1")
|
|
|
assert.Equal(t, "value1", val.(string), "Expected 'value1'")
|
|
|
|
|
|
// 测试 MustGet panic
|
|
|
assert.Panics(t, func() {
|
|
|
ctx.MustGet("nonexistent")
|
|
|
}, "Expected panic for missing key")
|
|
|
|
|
|
// 测试 Value (context.Context interface)
|
|
|
val = ctx.Value("key1")
|
|
|
require.NotNil(t, val, "Expected key1 to be found via Value")
|
|
|
assert.Equal(t, "value1", val.(string), "Expected 'value1'")
|
|
|
}
|
|
|
|
|
|
func TestContextWithTimeout(t *testing.T) {
|
|
|
parentCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
|
defer cancel()
|
|
|
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := &stubRequest{raw: []byte("test")}
|
|
|
resp := stubResponse{}
|
|
|
|
|
|
ctx := New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试 Deadline
|
|
|
deadline, ok := ctx.Deadline()
|
|
|
assert.True(t, ok, "Expected deadline to be set")
|
|
|
assert.False(t, deadline.IsZero(), "Expected deadline to be non-zero")
|
|
|
|
|
|
// 测试 Done
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
t.Error("Context should not be done yet")
|
|
|
case <-time.After(50 * time.Millisecond):
|
|
|
// OK
|
|
|
}
|
|
|
|
|
|
// 等待超时
|
|
|
time.Sleep(150 * time.Millisecond)
|
|
|
select {
|
|
|
case <-ctx.Done():
|
|
|
// OK
|
|
|
default:
|
|
|
t.Error("Context should be done after timeout")
|
|
|
}
|
|
|
|
|
|
// 测试 Err
|
|
|
assert.NotNil(t, ctx.Err(), "Expected error after timeout")
|
|
|
}
|
|
|
|
|
|
func TestContextGetIntTypes(t *testing.T) {
|
|
|
parentCtx := context.Background()
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := &stubRequest{raw: []byte("test")}
|
|
|
resp := stubResponse{}
|
|
|
|
|
|
ctx := New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试 int32
|
|
|
ctx.Set("int32", int32(32))
|
|
|
assert.Equal(t, 32, ctx.GetInt("int32"), "Expected 32 for int32")
|
|
|
|
|
|
// 测试 int64
|
|
|
ctx.Set("int64", int64(64))
|
|
|
assert.Equal(t, 64, ctx.GetInt("int64"), "Expected 64 for int64")
|
|
|
|
|
|
// 测试无效类型
|
|
|
ctx.Set("invalid", "not an int")
|
|
|
assert.Equal(t, 0, ctx.GetInt("invalid"), "Expected 0 for invalid type")
|
|
|
}
|
|
|
|
|
|
func TestContextGetStringInvalidType(t *testing.T) {
|
|
|
parentCtx := context.Background()
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := &stubRequest{raw: []byte("test")}
|
|
|
resp := stubResponse{}
|
|
|
|
|
|
ctx := New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试无效类型
|
|
|
ctx.Set("invalid", 123)
|
|
|
assert.Equal(t, "", ctx.GetString("invalid"), "Expected empty string for invalid type")
|
|
|
}
|
|
|
|
|
|
func TestContextGetBoolInvalidType(t *testing.T) {
|
|
|
parentCtx := context.Background()
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := &stubRequest{raw: []byte("test")}
|
|
|
resp := stubResponse{}
|
|
|
|
|
|
ctx := New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试无效类型
|
|
|
ctx.Set("invalid", "not a bool")
|
|
|
assert.False(t, ctx.GetBool("invalid"), "Expected false for invalid type")
|
|
|
}
|
|
|
|
|
|
func TestContextValue(t *testing.T) {
|
|
|
// 使用非字符串 key 测试父 context
|
|
|
parentKey := struct{ name string }{"parent_key"}
|
|
|
parentCtx := context.WithValue(context.Background(), parentKey, "parent_value")
|
|
|
conn := &mockConnection{id: "test-conn"}
|
|
|
req := &stubRequest{raw: []byte("test")}
|
|
|
resp := stubResponse{}
|
|
|
|
|
|
ctx := New(parentCtx, conn, req, resp)
|
|
|
|
|
|
// 测试字符串 key(从内部 values map 获取)
|
|
|
ctx.Set("ctx_key", "ctx_value")
|
|
|
val := ctx.Value("ctx_key")
|
|
|
require.NotNil(t, val, "Expected ctx_key to be found")
|
|
|
assert.Equal(t, "ctx_value", val.(string), "Expected 'ctx_value'")
|
|
|
|
|
|
// 测试非字符串 key(应该从父 context 获取)
|
|
|
val = ctx.Value(parentKey)
|
|
|
require.NotNil(t, val, "Expected parent_key to be found from parent context")
|
|
|
assert.Equal(t, "parent_value", val.(string), "Expected 'parent_value'")
|
|
|
}
|
|
|
|
|
|
type stubRequest struct {
|
|
|
raw []byte
|
|
|
}
|
|
|
|
|
|
func (s *stubRequest) Raw() []byte { return s.raw }
|
|
|
func (s *stubRequest) Data() interface{} { return nil }
|
|
|
func (s *stubRequest) Bind(interface{}) error { return nil }
|
|
|
func (s *stubRequest) Header() protocolpkg.FrameHeader { return nil }
|
|
|
func (s *stubRequest) DataBytes() []byte { return nil }
|
|
|
func (s *stubRequest) Protocol() protocolpkg.Protocol { return nil }
|
|
|
func (s *stubRequest) ProtocolVersion() string { return "" }
|
|
|
|
|
|
type stubResponse struct{}
|
|
|
|
|
|
func (stubResponse) Write(interface{}) error { return nil }
|
|
|
func (stubResponse) WriteString(string) error { return nil }
|
|
|
func (stubResponse) WriteWithCodec(interface{}, string) error { return nil }
|
|
|
func (stubResponse) WriteBytes([]byte) error { return nil }
|
|
|
func (stubResponse) Header() responsepkg.FrameHeader { return nil }
|
|
|
func (stubResponse) SetHeader(responsepkg.FrameHeader) {}
|
|
|
func (stubResponse) Protocol() protocolpkg.Protocol { return nil }
|
|
|
|
|
|
// mockConnection is a connection test double. It reports a fixed ID and a
// fixed loopback address for both ends, and its Write and Close always
// return nil.
type mockConnection struct{ id string }

func (c *mockConnection) ID() string              { return c.id }
func (c *mockConnection) RemoteAddr() string      { return "127.0.0.1:6995" }
func (c *mockConnection) LocalAddr() string       { return "127.0.0.1:6995" }
func (c *mockConnection) Write(data []byte) error { return nil }
func (c *mockConnection) Close() error            { return nil }
|