You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

209 lines
6.1 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package response
import (
"fmt"
"github.com/noahlann/nnet/pkg/codec"
ctxpkg "github.com/noahlann/nnet/pkg/context"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
responsepkg "github.com/noahlann/nnet/pkg/response"
)
// connectionWriter is the minimal view of a connection that a response needs:
// a sink for fully framed bytes. It is declared locally (instead of importing
// the concrete connection type) to avoid a circular package dependency.
type connectionWriter interface {
	Write(frame []byte) error
}
// responseImpl is the concrete response returned by New and
// NewWithResolverChain. It encodes values with a codec, frames the bytes
// with the configured protocol and writes them to the connection.
type responseImpl struct {
	// conn receives the final framed bytes.
	conn connectionWriter
	// protocol frames outgoing payloads; may be nil (payload is written as-is).
	protocol protocolpkg.Protocol
	// codecRegistry supplies codecs by name and a default codec; may be nil.
	codecRegistry codec.Registry
	// codecResolverChain picks a codec per write in Write; may be nil.
	codecResolverChain *codec.ResolverChain
	// header is the outgoing frame header, created lazily by Header().
	header responsepkg.FrameHeader
	// ctx is handed to the resolver chain; it is set through SetContext.
	ctx ctxpkg.Context
	// defaultCodecName is tried in the registry when the resolver yields nothing.
	defaultCodecName string
}
// New creates a Response that writes to conn and frames data with protocol.
// When codecRegistry is non-nil, a resolver chain is built over it (passing
// defaultCodec to codec.NewResolverChain); with a nil registry no resolver
// chain is configured. defaultCodec is also stored as the fallback codec name.
func New(conn connectionWriter, protocol protocolpkg.Protocol, codecRegistry codec.Registry, defaultCodec string) responsepkg.Response {
	if codecRegistry == nil {
		return newResponseImpl(conn, protocol, codecRegistry, nil, defaultCodec)
	}
	chain := codec.NewResolverChain(codecRegistry, defaultCodec)
	return newResponseImpl(conn, protocol, codecRegistry, chain, defaultCodec)
}
// NewWithResolverChain creates a Response that uses the caller-supplied
// resolverChain (nil disables resolver-based codec selection) instead of
// building one from codecRegistry.
func NewWithResolverChain(conn connectionWriter, protocol protocolpkg.Protocol, codecRegistry codec.Registry, resolverChain *codec.ResolverChain, defaultCodec string) responsepkg.Response {
	impl := newResponseImpl(conn, protocol, codecRegistry, resolverChain, defaultCodec)
	return impl
}
// newResponseImpl allocates a responseImpl with the given collaborators.
// The header and context start out nil; they are set later via
// Header/SetHeader and SetContext.
func newResponseImpl(conn connectionWriter, protocol protocolpkg.Protocol, codecRegistry codec.Registry, resolverChain *codec.ResolverChain, defaultCodec string) *responseImpl {
	r := new(responseImpl)
	r.conn = conn
	r.protocol = protocol
	r.codecRegistry = codecRegistry
	r.codecResolverChain = resolverChain
	r.defaultCodecName = defaultCodec
	return r
}
// SetContext stores the request context. Write only consults the resolver
// chain when a context has been set here.
func (r *responseImpl) SetContext(ctx ctxpkg.Context) {
	r.ctx = ctx
}
// Write encodes data with an automatically selected codec, frames it with the
// protocol and writes it to the connection.
//
// Codec selection order:
//  1. the resolver chain, when both a chain and a context (SetContext) are set;
//     the request's frame header is passed along when the context has a request;
//  2. the registry codec named by the default codec name, if the lookup succeeds;
//  3. the registry's Default() codec.
//
// It returns an error if the resolver fails, if no codec could be selected,
// or if encoding or writing fails.
func (r *responseImpl) Write(data interface{}) error {
	// Named c (not codec) so the imported codec package is not shadowed.
	var c codec.Codec

	if r.codecResolverChain != nil && r.ctx != nil {
		var header protocolpkg.FrameHeader
		if req := r.ctx.Request(); req != nil {
			header = req.Header()
		}
		resolved, err := r.codecResolverChain.ResolveForEncode(r.ctx, data, header)
		if err != nil {
			return fmt.Errorf("failed to resolve codec for encode: %w", err)
		}
		c = resolved
	}

	// Best-effort fallback: a failed lookup of the default name is not an
	// error, we simply move on to the registry's default codec.
	if c == nil && r.codecRegistry != nil && r.defaultCodecName != "" {
		if named, err := r.codecRegistry.Get(r.defaultCodecName); err == nil {
			c = named
		}
	}
	if c == nil && r.codecRegistry != nil {
		c = r.codecRegistry.Default()
	}

	// Without this check a missing codec would panic on a nil-interface call
	// inside writeWithCodec.
	if c == nil {
		return fmt.Errorf("failed to resolve codec for encode: no codec available")
	}

	return r.writeWithCodec(data, c)
}
// WriteWithCodec encodes data with an explicitly chosen codec and writes it.
// An empty codecName selects the registry's Default() codec; otherwise the
// codec is looked up by name. It returns an error when no registry is
// configured, the lookup fails, or no codec is found, instead of panicking.
func (r *responseImpl) WriteWithCodec(data interface{}, codecName string) error {
	if r.codecRegistry == nil {
		return fmt.Errorf("failed to get codec %s: no codec registry configured", codecName)
	}

	// Named c (not codec) so the imported codec package is not shadowed.
	var c codec.Codec
	if codecName == "" {
		c = r.codecRegistry.Default()
	} else {
		named, err := r.codecRegistry.Get(codecName)
		if err != nil {
			return fmt.Errorf("failed to get codec %s: %w", codecName, err)
		}
		c = named
	}

	// Default() may have nothing registered; Get may return a nil codec
	// without an error. Either would panic on Encode.
	if c == nil {
		return fmt.Errorf("failed to get codec %q: no codec registered", codecName)
	}

	return r.writeWithCodec(data, c)
}
// writeWithCodec encodes data with c (Go value -> bytes), then hands the bytes
// to WriteBytes, which applies protocol framing and writes to the connection.
// A nil codec yields an error rather than a panic.
func (r *responseImpl) writeWithCodec(data interface{}, c codec.Codec) error {
	if c == nil {
		return fmt.Errorf("failed to encode data: codec is nil")
	}

	body, err := c.Encode(data)
	if err != nil {
		return fmt.Errorf("failed to encode data: %w", err)
	}

	// Framing of codec output is exactly the same as framing raw bytes, so
	// reuse WriteBytes rather than duplicating the protocol-encoding branches.
	return r.WriteBytes(body)
}
// WriteString is a convenience wrapper that sends the string as raw bytes
// through WriteBytes: no codec encoding, but protocol framing still applies.
func (r *responseImpl) WriteString(data string) error {
	raw := []byte(data)
	return r.WriteBytes(raw)
}
// WriteBytes writes raw bytes to the connection, bypassing codec encoding but
// still applying protocol framing:
//   - no protocol: the bytes are written unchanged;
//   - protocol with a frame header: Encode receives the converted response
//     header (or nil if no header was set);
//   - protocol without a frame header: Encode receives a nil header (the
//     protocol may still add e.g. a delimiter).
//
// It returns an error if the response has no connection (for example after
// Reset), instead of panicking on a nil interface.
func (r *responseImpl) WriteBytes(data []byte) error {
	if r.conn == nil {
		return fmt.Errorf("failed to write: no connection")
	}

	frame := data
	if r.protocol != nil {
		// The zero value of FrameHeader is nil, which is what header-less
		// protocols (and header protocols with no header set) receive.
		var hdr protocolpkg.FrameHeader
		if r.protocol.HasHeader() && r.header != nil {
			hdr = convertResponseHeaderToProtocolHeader(r.header)
		}
		encoded, err := r.protocol.Encode(data, hdr)
		if err != nil {
			return fmt.Errorf("failed to encode with protocol: %w", err)
		}
		frame = encoded
	}

	return r.conn.Write(frame)
}
// Header returns the outgoing frame header. If none has been set and the
// protocol uses frame headers, a default header (NewFrameHeader) is created
// and stored first. Otherwise the stored header is returned as-is, which may
// be nil.
func (r *responseImpl) Header() responsepkg.FrameHeader {
	if r.header != nil || r.protocol == nil || !r.protocol.HasHeader() {
		return r.header
	}
	r.header = NewFrameHeader()
	return r.header
}
// SetHeader replaces the outgoing frame header. It is converted to a protocol
// header when data is written.
func (r *responseImpl) SetHeader(header responsepkg.FrameHeader) {
	r.header = header
}
// Protocol returns the protocol used to frame outgoing data; nil means
// payloads are written without framing.
func (r *responseImpl) Protocol() protocolpkg.Protocol {
	return r.protocol
}
// Reset clears all response state so the object can be reused from an object
// pool. Zeroing the whole struct (instead of listing fields one by one) also
// clears defaultCodecName, and any field added later. A recycled response
// therefore cannot inherit the previous user's connection, protocol, codec
// configuration, header or context.
func (r *responseImpl) Reset() {
	*r = responseImpl{}
}