package context import ( "context" "time" "github.com/noahlann/nnet/pkg/request" "github.com/noahlann/nnet/pkg/response" ) // Context 请求上下文接口 // Context只负责状态管理和流程控制,数据读写由Request和Response处理 type Context interface { context.Context // Request 获取请求对象(数据已自动解析) Request() request.Request // Response 获取响应对象(自动编码和协议封装) Response() response.Response // Connection 获取关联的连接 Connection() Connection // Set 设置上下文值 Set(key string, value interface{}) // Get 获取上下文值 Get(key string) interface{} // MustGet 获取上下文值,如果不存在则panic MustGet(key string) interface{} // GetString 获取字符串值 GetString(key string) string // GetInt 获取整数值 GetInt(key string) int // GetBool 获取布尔值 GetBool(key string) bool // Deadline 返回上下文截止时间 Deadline() (time.Time, bool) // Done 返回一个channel,当上下文取消时关闭 Done() <-chan struct{} // Err 返回上下文错误 Err() error // Value 获取上下文值(实现context.Context接口) Value(key interface{}) interface{} } // Connection 连接接口(避免循环依赖,在此定义简化版本) type Connection interface { ID() string RemoteAddr() string LocalAddr() string Write(data []byte) error Close() error } // contextImpl 上下文实现 type contextImpl struct { context.Context conn Connection req request.Request resp response.Response values map[string]interface{} } // New 创建新的上下文 func New(parent context.Context, conn Connection, req request.Request, resp response.Response) Context { return &contextImpl{ Context: parent, conn: conn, req: req, resp: resp, values: make(map[string]interface{}), } } // Connection 获取关联的连接 func (c *contextImpl) Connection() Connection { return c.conn } // Request 获取请求对象 func (c *contextImpl) Request() request.Request { return c.req } // Response 获取响应对象 func (c *contextImpl) Response() response.Response { return c.resp } // Set 设置上下文值 func (c *contextImpl) Set(key string, value interface{}) { if c.values == nil { c.values = make(map[string]interface{}) } c.values[key] = value } // Get 获取上下文值 func (c *contextImpl) Get(key string) interface{} { if c.values == nil { return nil } return c.values[key] } // MustGet 获取上下文值,如果不存在则panic func (c *contextImpl) MustGet(key string) interface{} { value := c.Get(key) if value == nil { panic("key " + key + " does not exist") } return value } // GetString 获取字符串值 func (c *contextImpl) GetString(key string) string { value := c.Get(key) if value == nil { return "" } if s, ok := value.(string); ok { return s } return "" } // GetInt 获取整数值 func (c *contextImpl) GetInt(key string) int { value := c.Get(key) if value == nil { return 0 } switch v := value.(type) { case int: return v case int32: return int(v) case int64: return int(v) } return 0 } // GetBool 获取布尔值 func (c *contextImpl) GetBool(key string) bool { value := c.Get(key) if value == nil { return false } if b, ok := value.(bool); ok { return b } return false } // Deadline 返回上下文截止时间 func (c *contextImpl) Deadline() (time.Time, bool) { return c.Context.Deadline() } // Done 返回一个channel,当上下文取消时关闭 func (c *contextImpl) Done() <-chan struct{} { return c.Context.Done() } // Err 返回上下文错误 func (c *contextImpl) Err() error { return c.Context.Err() } // Value 获取上下文值(实现context.Context接口) func (c *contextImpl) Value(key interface{}) interface{} { if keyStr, ok := key.(string); ok { return c.Get(keyStr) } return c.Context.Value(key) }