package test import ( "context" "testing" "time" "github.com/noahlann/nnet/internal/codec" "github.com/noahlann/nnet/internal/request" "github.com/noahlann/nnet/internal/response" ctxpkg "github.com/noahlann/nnet/pkg/context" ) func TestContext(t *testing.T) { parentCtx := context.Background() conn := &mockConnection{id: "test-conn"} // 创建 mock Request 和 Response req := request.New([]byte("test request"), nil) codecRegistry := codec.NewRegistry() resp := response.New(conn, nil, codecRegistry, "json") ctx := ctxpkg.New(parentCtx, conn, req, resp) // 测试 Set 和 Get ctx.Set("key1", "value1") val := ctx.Get("key1") if val == nil { t.Error("Expected key1 to be found") } if val.(string) != "value1" { t.Error("Expected 'value1', got:", val) } // 测试 GetString strVal := ctx.GetString("key1") if strVal != "value1" { t.Error("Expected 'value1', got:", strVal) } // 测试 GetInt ctx.Set("int1", 42) intVal := ctx.GetInt("int1") if intVal != 42 { t.Error("Expected 42, got:", intVal) } // 测试 GetBool ctx.Set("bool1", true) boolVal := ctx.GetBool("bool1") if !boolVal { t.Error("Expected true, got:", boolVal) } // 测试 Request if string(ctx.Request().Raw()) != "test request" { t.Error("Expected 'test request'") } // 测试 Response if ctx.Response() == nil { t.Error("Expected Response to be non-nil") } // 测试 Connection if ctx.Connection().ID() != "test-conn" { t.Error("Expected 'test-conn'") } // 测试 MustGet val = ctx.MustGet("key1") if val.(string) != "value1" { t.Error("Expected 'value1'") } // 测试 MustGet panic func() { defer func() { if r := recover(); r == nil { t.Error("Expected panic for missing key") } }() ctx.MustGet("nonexistent") }() // 测试 Value (context.Context interface) val = ctx.Value("key1") if val == nil { t.Error("Expected key1 to be found via Value") } } func TestContextWithTimeout(t *testing.T) { parentCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() conn := &mockConnection{id: "test-conn"} req := request.New([]byte("test"), nil) codecRegistry := codec.NewRegistry() resp := response.New(conn, nil, codecRegistry, "json") ctx := ctxpkg.New(parentCtx, conn, req, resp) // 测试 Deadline deadline, ok := ctx.Deadline() if !ok { t.Error("Expected deadline to be set") } if deadline.IsZero() { t.Error("Expected deadline to be non-zero") } // 测试 Done select { case <-ctx.Done(): t.Error("Context should not be done yet") case <-time.After(50 * time.Millisecond): // OK } // 等待超时 time.Sleep(150 * time.Millisecond) select { case <-ctx.Done(): // OK default: t.Error("Context should be done after timeout") } // 测试 Err if ctx.Err() == nil { t.Error("Expected error after timeout") } } func TestContextGetIntTypes(t *testing.T) { parentCtx := context.Background() conn := &mockConnection{id: "test-conn"} req := request.New([]byte("test"), nil) codecRegistry := codec.NewRegistry() resp := response.New(conn, nil, codecRegistry, "json") ctx := ctxpkg.New(parentCtx, conn, req, resp) // 测试 int32 ctx.Set("int32", int32(32)) if ctx.GetInt("int32") != 32 { t.Error("Expected 32 for int32") } // 测试 int64 ctx.Set("int64", int64(64)) if ctx.GetInt("int64") != 64 { t.Error("Expected 64 for int64") } // 测试无效类型 ctx.Set("invalid", "not an int") if ctx.GetInt("invalid") != 0 { t.Error("Expected 0 for invalid type") } } func TestContextGetStringInvalidType(t *testing.T) { parentCtx := context.Background() conn := &mockConnection{id: "test-conn"} req := request.New([]byte("test"), nil) codecRegistry := codec.NewRegistry() resp := response.New(conn, nil, codecRegistry, "json") ctx := ctxpkg.New(parentCtx, conn, req, resp) // 测试无效类型 ctx.Set("invalid", 123) if ctx.GetString("invalid") != "" { t.Error("Expected empty string for invalid type") } } func TestContextGetBoolInvalidType(t *testing.T) { parentCtx := context.Background() conn := &mockConnection{id: "test-conn"} req := request.New([]byte("test"), nil) codecRegistry := codec.NewRegistry() resp := response.New(conn, nil, codecRegistry, "json") ctx := ctxpkg.New(parentCtx, conn, req, resp) // 测试无效类型 ctx.Set("invalid", "not a bool") if ctx.GetBool("invalid") { t.Error("Expected false for invalid type") } } func TestContextValue(t *testing.T) { // 使用非字符串 key 测试父 context parentKey := struct{ name string }{"parent_key"} parentCtx := context.WithValue(context.Background(), parentKey, "parent_value") conn := &mockConnection{id: "test-conn"} req := request.New([]byte("test"), nil) codecRegistry := codec.NewRegistry() resp := response.New(conn, nil, codecRegistry, "json") ctx := ctxpkg.New(parentCtx, conn, req, resp) // 测试字符串 key(从内部 values map 获取) ctx.Set("ctx_key", "ctx_value") val := ctx.Value("ctx_key") if val == nil { t.Error("Expected ctx_key to be found") } if val.(string) != "ctx_value" { t.Errorf("Expected 'ctx_value', got: %v", val) } // 测试非字符串 key(应该从父 context 获取) val = ctx.Value(parentKey) if val == nil { t.Error("Expected parent_key to be found from parent context") } if val.(string) != "parent_value" { t.Errorf("Expected 'parent_value', got: %v", val) } } // mockConnection 模拟连接 type mockConnection struct { id string } func (m *mockConnection) ID() string { return m.id } func (m *mockConnection) RemoteAddr() string { return "127.0.0.1:6995" } func (m *mockConnection) LocalAddr() string { return "127.0.0.1:6995" } func (m *mockConnection) Write(data []byte) error { return nil } func (m *mockConnection) Close() error { return nil }