package context import ( "context" "testing" "time" protocolpkg "github.com/noahlann/nnet/pkg/protocol" responsepkg "github.com/noahlann/nnet/pkg/response" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestContext(t *testing.T) { parentCtx := context.Background() conn := &mockConnection{id: "test-conn"} // 创建 mock Request 和 Response req := &stubRequest{raw: []byte("test request")} resp := stubResponse{} ctx := New(parentCtx, conn, req, resp) // 测试 Set 和 Get ctx.Set("key1", "value1") val := ctx.Get("key1") require.NotNil(t, val, "Expected key1 to be found") assert.Equal(t, "value1", val.(string), "Expected 'value1'") // 测试 GetString strVal := ctx.GetString("key1") assert.Equal(t, "value1", strVal, "Expected 'value1'") // 测试 GetInt ctx.Set("int1", 42) intVal := ctx.GetInt("int1") assert.Equal(t, 42, intVal, "Expected 42") // 测试 GetBool ctx.Set("bool1", true) boolVal := ctx.GetBool("bool1") assert.True(t, boolVal, "Expected true") // 测试 Request assert.Equal(t, "test request", string(ctx.Request().Raw()), "Expected 'test request'") // 测试 Response assert.NotNil(t, ctx.Response(), "Expected Response to be non-nil") // 测试 Connection assert.Equal(t, "test-conn", ctx.Connection().ID(), "Expected 'test-conn'") // 测试 MustGet val = ctx.MustGet("key1") assert.Equal(t, "value1", val.(string), "Expected 'value1'") // 测试 MustGet panic assert.Panics(t, func() { ctx.MustGet("nonexistent") }, "Expected panic for missing key") // 测试 Value (context.Context interface) val = ctx.Value("key1") require.NotNil(t, val, "Expected key1 to be found via Value") assert.Equal(t, "value1", val.(string), "Expected 'value1'") } func TestContextWithTimeout(t *testing.T) { parentCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() conn := &mockConnection{id: "test-conn"} req := &stubRequest{raw: []byte("test")} resp := stubResponse{} ctx := New(parentCtx, conn, req, resp) // 测试 Deadline deadline, ok := ctx.Deadline() assert.True(t, ok, "Expected deadline to be set") assert.False(t, deadline.IsZero(), "Expected deadline to be non-zero") // 测试 Done select { case <-ctx.Done(): t.Error("Context should not be done yet") case <-time.After(50 * time.Millisecond): // OK } // 等待超时 time.Sleep(150 * time.Millisecond) select { case <-ctx.Done(): // OK default: t.Error("Context should be done after timeout") } // 测试 Err assert.NotNil(t, ctx.Err(), "Expected error after timeout") } func TestContextGetIntTypes(t *testing.T) { parentCtx := context.Background() conn := &mockConnection{id: "test-conn"} req := &stubRequest{raw: []byte("test")} resp := stubResponse{} ctx := New(parentCtx, conn, req, resp) // 测试 int32 ctx.Set("int32", int32(32)) assert.Equal(t, 32, ctx.GetInt("int32"), "Expected 32 for int32") // 测试 int64 ctx.Set("int64", int64(64)) assert.Equal(t, 64, ctx.GetInt("int64"), "Expected 64 for int64") // 测试无效类型 ctx.Set("invalid", "not an int") assert.Equal(t, 0, ctx.GetInt("invalid"), "Expected 0 for invalid type") } func TestContextGetStringInvalidType(t *testing.T) { parentCtx := context.Background() conn := &mockConnection{id: "test-conn"} req := &stubRequest{raw: []byte("test")} resp := stubResponse{} ctx := New(parentCtx, conn, req, resp) // 测试无效类型 ctx.Set("invalid", 123) assert.Equal(t, "", ctx.GetString("invalid"), "Expected empty string for invalid type") } func TestContextGetBoolInvalidType(t *testing.T) { parentCtx := context.Background() conn := &mockConnection{id: "test-conn"} req := &stubRequest{raw: []byte("test")} resp := stubResponse{} ctx := New(parentCtx, conn, req, resp) // 测试无效类型 ctx.Set("invalid", "not a bool") assert.False(t, ctx.GetBool("invalid"), "Expected false for invalid type") } func TestContextValue(t *testing.T) { // 使用非字符串 key 测试父 context parentKey := struct{ name string }{"parent_key"} parentCtx := context.WithValue(context.Background(), parentKey, "parent_value") conn := &mockConnection{id: "test-conn"} req := &stubRequest{raw: []byte("test")} resp := stubResponse{} ctx := New(parentCtx, conn, req, resp) // 测试字符串 key(从内部 values map 获取) ctx.Set("ctx_key", "ctx_value") val := ctx.Value("ctx_key") require.NotNil(t, val, "Expected ctx_key to be found") assert.Equal(t, "ctx_value", val.(string), "Expected 'ctx_value'") // 测试非字符串 key(应该从父 context 获取) val = ctx.Value(parentKey) require.NotNil(t, val, "Expected parent_key to be found from parent context") assert.Equal(t, "parent_value", val.(string), "Expected 'parent_value'") } type stubRequest struct { raw []byte } func (s *stubRequest) Raw() []byte { return s.raw } func (s *stubRequest) Data() interface{} { return nil } func (s *stubRequest) Bind(interface{}) error { return nil } func (s *stubRequest) Header() protocolpkg.FrameHeader { return nil } func (s *stubRequest) DataBytes() []byte { return nil } func (s *stubRequest) Protocol() protocolpkg.Protocol { return nil } func (s *stubRequest) ProtocolVersion() string { return "" } type stubResponse struct{} func (stubResponse) Write(interface{}) error { return nil } func (stubResponse) WriteString(string) error { return nil } func (stubResponse) WriteWithCodec(interface{}, string) error { return nil } func (stubResponse) WriteBytes([]byte) error { return nil } func (stubResponse) Header() responsepkg.FrameHeader { return nil } func (stubResponse) SetHeader(responsepkg.FrameHeader) {} func (stubResponse) Protocol() protocolpkg.Protocol { return nil } // mockConnection 模拟连接 type mockConnection struct { id string } func (m *mockConnection) ID() string { return m.id } func (m *mockConnection) RemoteAddr() string { return "127.0.0.1:6995" } func (m *mockConnection) LocalAddr() string { return "127.0.0.1:6995" } func (m *mockConnection) Write(data []byte) error { return nil } func (m *mockConnection) Close() error { return nil }