|
|
package response
|
|
|
|
|
|
import (
|
|
|
"fmt"
|
|
|
|
|
|
"github.com/noahlann/nnet/pkg/codec"
|
|
|
ctxpkg "github.com/noahlann/nnet/pkg/context"
|
|
|
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
|
|
|
responsepkg "github.com/noahlann/nnet/pkg/response"
|
|
|
)
|
|
|
|
|
|
// connectionWriter is the minimal view of a connection that a response needs:
// a sink that accepts an already-framed byte slice. It is declared locally
// instead of importing the concrete connection type, to avoid an import cycle.
type connectionWriter interface {
	// Write sends p (one complete frame) to the peer.
	Write(p []byte) error
}
|
|
|
|
|
|
// responseImpl is the implementation of responsepkg.Response returned by New and
// NewWithResolverChain. It encodes values with a codec (picked by the resolver
// chain, a named default codec, or the registry default), wraps the encoded
// bytes with the protocol framing, and writes the result to conn.
type responseImpl struct {
	// conn receives the final framed bytes.
	conn connectionWriter

	// protocol frames encoded bodies; when nil, bodies are written unframed.
	protocol protocolpkg.Protocol

	// codecRegistry supplies codecs by name and a default codec.
	codecRegistry codec.Registry

	// codecResolverChain chooses a codec per Write call; it is consulted only
	// when ctx has been set.
	codecResolverChain *codec.ResolverChain

	// header is the response frame header; Header() creates a default one
	// lazily for protocols that carry a header.
	header responsepkg.FrameHeader

	ctx ctxpkg.Context // lets resolvers access the context (set via SetContext)

	// defaultCodecName names the registry codec used when the resolver chain
	// does not yield one.
	defaultCodecName string
}
|
|
|
|
|
|
// New 创建新的响应对象,使用默认编码器名称配置resolver链
|
|
|
func New(conn connectionWriter, protocol protocolpkg.Protocol, codecRegistry codec.Registry, defaultCodec string) responsepkg.Response {
|
|
|
var resolverChain *codec.ResolverChain
|
|
|
if codecRegistry != nil {
|
|
|
resolverChain = codec.NewResolverChain(codecRegistry, defaultCodec)
|
|
|
}
|
|
|
return newResponseImpl(conn, protocol, codecRegistry, resolverChain, defaultCodec)
|
|
|
}
|
|
|
|
|
|
// NewWithResolverChain 使用自定义resolver链创建响应对象
|
|
|
func NewWithResolverChain(conn connectionWriter, protocol protocolpkg.Protocol, codecRegistry codec.Registry, resolverChain *codec.ResolverChain, defaultCodec string) responsepkg.Response {
|
|
|
return newResponseImpl(conn, protocol, codecRegistry, resolverChain, defaultCodec)
|
|
|
}
|
|
|
|
|
|
func newResponseImpl(conn connectionWriter, protocol protocolpkg.Protocol, codecRegistry codec.Registry, resolverChain *codec.ResolverChain, defaultCodec string) *responseImpl {
|
|
|
return &responseImpl{
|
|
|
conn: conn,
|
|
|
protocol: protocol,
|
|
|
codecRegistry: codecRegistry,
|
|
|
codecResolverChain: resolverChain,
|
|
|
defaultCodecName: defaultCodec,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// SetContext 设置上下文(供resolver使用)
|
|
|
func (r *responseImpl) SetContext(ctx ctxpkg.Context) {
|
|
|
r.ctx = ctx
|
|
|
}
|
|
|
|
|
|
// Write 写入数据(自动resolve codec进行编码和协议封装)
|
|
|
func (r *responseImpl) Write(data interface{}) error {
|
|
|
// 使用resolver解析codec用于编码
|
|
|
var codec codec.Codec
|
|
|
var err error
|
|
|
if r.codecResolverChain != nil && r.ctx != nil {
|
|
|
// 获取header
|
|
|
var header protocolpkg.FrameHeader
|
|
|
if r.ctx.Request() != nil {
|
|
|
header = r.ctx.Request().Header()
|
|
|
}
|
|
|
codec, err = r.codecResolverChain.ResolveForEncode(r.ctx, data, header)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to resolve codec for encode: %w", err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 如果resolver无法解析,使用默认
|
|
|
if codec == nil {
|
|
|
if r.codecRegistry != nil && r.defaultCodecName != "" {
|
|
|
if namedCodec, err := r.codecRegistry.Get(r.defaultCodecName); err == nil {
|
|
|
codec = namedCodec
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
if codec == nil && r.codecRegistry != nil {
|
|
|
codec = r.codecRegistry.Default()
|
|
|
}
|
|
|
|
|
|
return r.writeWithCodec(data, codec)
|
|
|
}
|
|
|
|
|
|
// WriteWithCodec 使用指定的编解码器写入数据(手动选择codec)
|
|
|
func (r *responseImpl) WriteWithCodec(data interface{}, codecName string) error {
|
|
|
// 获取编解码器
|
|
|
var codec codec.Codec
|
|
|
var err error
|
|
|
if codecName == "" {
|
|
|
codec = r.codecRegistry.Default()
|
|
|
} else {
|
|
|
codec, err = r.codecRegistry.Get(codecName)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to get codec %s: %w", codecName, err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return r.writeWithCodec(data, codec)
|
|
|
}
|
|
|
|
|
|
// writeWithCodec 使用指定的codec写入数据(内部方法)
|
|
|
func (r *responseImpl) writeWithCodec(data interface{}, codec codec.Codec) error {
|
|
|
// 编解码器编码(Go对象 → 字节)
|
|
|
bodyBytes, err := codec.Encode(data)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to encode data: %w", err)
|
|
|
}
|
|
|
|
|
|
// 协议编码(添加帧头)
|
|
|
var frameData []byte
|
|
|
var protocolHeader protocolpkg.FrameHeader
|
|
|
|
|
|
// 如果有响应帧头,转换为协议帧头
|
|
|
if r.header != nil {
|
|
|
protocolHeader = convertResponseHeaderToProtocolHeader(r.header)
|
|
|
}
|
|
|
|
|
|
if r.protocol != nil && r.protocol.HasHeader() {
|
|
|
// 有帧头协议,使用协议编码
|
|
|
frameData, err = r.protocol.Encode(bodyBytes, protocolHeader)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to encode with protocol: %w", err)
|
|
|
}
|
|
|
} else {
|
|
|
// 无帧头协议,直接使用数据(可能添加分隔符)
|
|
|
if r.protocol != nil {
|
|
|
frameData, err = r.protocol.Encode(bodyBytes, nil)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to encode with protocol: %w", err)
|
|
|
}
|
|
|
} else {
|
|
|
frameData = bodyBytes
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 写入连接
|
|
|
return r.conn.Write(frameData)
|
|
|
}
|
|
|
|
|
|
// WriteString 写入字符串数据(便捷方法)
|
|
|
func (r *responseImpl) WriteString(data string) error {
|
|
|
return r.WriteBytes([]byte(data))
|
|
|
}
|
|
|
|
|
|
// WriteBytes 写入原始字节(绕过自动编码,但会进行协议封装)
|
|
|
func (r *responseImpl) WriteBytes(data []byte) error {
|
|
|
// 协议编码(添加帧头)
|
|
|
var frameData []byte
|
|
|
var err error
|
|
|
var protocolHeader protocolpkg.FrameHeader
|
|
|
|
|
|
// 如果有响应帧头,转换为协议帧头
|
|
|
if r.header != nil {
|
|
|
protocolHeader = convertResponseHeaderToProtocolHeader(r.header)
|
|
|
}
|
|
|
|
|
|
if r.protocol != nil && r.protocol.HasHeader() {
|
|
|
frameData, err = r.protocol.Encode(data, protocolHeader)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to encode with protocol: %w", err)
|
|
|
}
|
|
|
} else {
|
|
|
if r.protocol != nil {
|
|
|
frameData, err = r.protocol.Encode(data, nil)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to encode with protocol: %w", err)
|
|
|
}
|
|
|
} else {
|
|
|
frameData = data
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return r.conn.Write(frameData)
|
|
|
}
|
|
|
|
|
|
// Header 获取协议帧头
|
|
|
func (r *responseImpl) Header() responsepkg.FrameHeader {
|
|
|
if r.header == nil && r.protocol != nil && r.protocol.HasHeader() {
|
|
|
// 创建默认帧头
|
|
|
r.header = NewFrameHeader()
|
|
|
}
|
|
|
return r.header
|
|
|
}
|
|
|
|
|
|
// SetHeader 设置协议帧头
|
|
|
func (r *responseImpl) SetHeader(header responsepkg.FrameHeader) {
|
|
|
r.header = header
|
|
|
}
|
|
|
|
|
|
// Protocol 获取协议信息
|
|
|
func (r *responseImpl) Protocol() protocolpkg.Protocol {
|
|
|
return r.protocol
|
|
|
}
|
|
|
|
|
|
// Reset 重置Response状态(用于对象池)
|
|
|
func (r *responseImpl) Reset() {
|
|
|
r.conn = nil
|
|
|
r.protocol = nil
|
|
|
r.codecRegistry = nil
|
|
|
r.codecResolverChain = nil
|
|
|
r.header = nil
|
|
|
r.ctx = nil
|
|
|
}
|