|
|
package unpacker
|
|
|
|
|
|
import (
|
|
|
"encoding/binary"
|
|
|
|
|
|
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
|
|
|
)
|
|
|
|
|
|
// lengthFieldUnpacker splits a byte stream into frames whose size is carried
// in a fixed-width length field, and builds such frames from payloads.
type lengthFieldUnpacker struct {
	// Frame layout.

	// lengthFieldOffset is the index within a frame at which the length
	// field starts.
	lengthFieldOffset int
	// lengthFieldLength is the width of the length field in bytes; the
	// methods of this type handle widths of 1, 2, 4 and 8.
	lengthFieldLength int
	// lengthAdjustment is added to the decoded length field value to obtain
	// the number of bytes that follow the length field.
	lengthAdjustment int
	// initialBytesToStrip is the number of leading frame bytes dropped from
	// every message handed back by Unpack.
	initialBytesToStrip int

	// Decoding parameters.

	// byteOrder is used to read and write multi-byte length fields.
	byteOrder binary.ByteOrder
	// maxBufferSize bounds both the number of buffered bytes and the size of
	// a single frame.
	maxBufferSize int

	// State.

	// buffer holds bytes received so far that do not yet form a complete
	// frame. It is nil until the first Unpack call.
	buffer []byte
}
|
|
|
|
|
|
// NewLengthFieldUnpacker 创建长度字段拆包器
|
|
|
func NewLengthFieldUnpacker(config unpackerpkg.LengthFieldUnpacker) unpackerpkg.Unpacker {
|
|
|
lengthFieldLength := config.LengthFieldLength
|
|
|
if lengthFieldLength <= 0 {
|
|
|
lengthFieldLength = 4 // 默认4字节
|
|
|
}
|
|
|
|
|
|
maxBufferSize := config.MaxBufferSize
|
|
|
if maxBufferSize <= 0 {
|
|
|
maxBufferSize = unpackerpkg.DefaultMaxBufferSize
|
|
|
}
|
|
|
|
|
|
return &lengthFieldUnpacker{
|
|
|
lengthFieldOffset: config.LengthFieldOffset,
|
|
|
lengthFieldLength: lengthFieldLength,
|
|
|
lengthAdjustment: config.LengthAdjustment,
|
|
|
initialBytesToStrip: config.InitialBytesToStrip,
|
|
|
buffer: nil, // 延迟分配,使用buffer pool
|
|
|
byteOrder: binary.BigEndian,
|
|
|
maxBufferSize: maxBufferSize,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Unpack 拆包
|
|
|
func (u *lengthFieldUnpacker) Unpack(data []byte) ([][]byte, []byte, int, error) {
|
|
|
// 记录调用前的buffer长度(用于计算消耗的数据量)
|
|
|
prevBufferLen := len(u.buffer)
|
|
|
|
|
|
// 检查buffer大小限制
|
|
|
newSize := prevBufferLen + len(data)
|
|
|
if newSize > u.maxBufferSize {
|
|
|
return nil, nil, 0, unpackerpkg.NewErrorf("unpacker buffer size exceeded: %d > %d", newSize, u.maxBufferSize)
|
|
|
}
|
|
|
|
|
|
// 优化buffer管理:只在需要时分配和扩容
|
|
|
if u.buffer == nil {
|
|
|
// 首次使用,从buffer pool获取或分配
|
|
|
if newSize <= 4096 {
|
|
|
// 小数据,使用固定大小的buffer
|
|
|
u.buffer = make([]byte, 0, 4096)
|
|
|
} else {
|
|
|
// 大数据,直接分配所需容量(但不超过maxBufferSize)
|
|
|
cap := newSize
|
|
|
if cap > u.maxBufferSize {
|
|
|
cap = u.maxBufferSize
|
|
|
}
|
|
|
u.buffer = make([]byte, 0, cap)
|
|
|
}
|
|
|
} else if cap(u.buffer) < newSize {
|
|
|
// 需要扩容:计算新容量(至少是当前大小的2倍,但不超过maxBufferSize)
|
|
|
newCap := cap(u.buffer) * 2
|
|
|
if newCap < newSize {
|
|
|
newCap = newSize
|
|
|
}
|
|
|
if newCap > u.maxBufferSize {
|
|
|
newCap = u.maxBufferSize
|
|
|
}
|
|
|
|
|
|
// 扩容:创建新buffer并复制数据
|
|
|
newBuffer := make([]byte, len(u.buffer), newCap)
|
|
|
copy(newBuffer, u.buffer)
|
|
|
u.buffer = newBuffer
|
|
|
}
|
|
|
|
|
|
// 追加新数据
|
|
|
u.buffer = append(u.buffer, data...)
|
|
|
|
|
|
// 记录追加数据后的buffer长度(用于计算consumed)
|
|
|
afterAppendLen := len(u.buffer)
|
|
|
|
|
|
var messages [][]byte
|
|
|
|
|
|
for {
|
|
|
if len(u.buffer) < u.lengthFieldOffset+u.lengthFieldLength {
|
|
|
// 数据不足,等待更多数据
|
|
|
break
|
|
|
}
|
|
|
|
|
|
// 读取长度字段
|
|
|
lengthBytes := u.buffer[u.lengthFieldOffset : u.lengthFieldOffset+u.lengthFieldLength]
|
|
|
var length int
|
|
|
switch u.lengthFieldLength {
|
|
|
case 1:
|
|
|
length = int(lengthBytes[0])
|
|
|
case 2:
|
|
|
length = int(u.byteOrder.Uint16(lengthBytes))
|
|
|
case 4:
|
|
|
length = int(u.byteOrder.Uint32(lengthBytes))
|
|
|
case 8:
|
|
|
length = int(u.byteOrder.Uint64(lengthBytes))
|
|
|
default:
|
|
|
return nil, u.buffer, 0, unpackerpkg.NewError("unsupported length field length")
|
|
|
}
|
|
|
|
|
|
// 验证长度字段的合理性(防止恶意数据)
|
|
|
if length < 0 || length > u.maxBufferSize {
|
|
|
return nil, nil, 0, unpackerpkg.NewErrorf("invalid length field: %d", length)
|
|
|
}
|
|
|
|
|
|
// 计算实际消息长度
|
|
|
actualLength := length + u.lengthAdjustment
|
|
|
if actualLength < 0 {
|
|
|
return nil, nil, 0, unpackerpkg.NewErrorf("invalid actual length: %d", actualLength)
|
|
|
}
|
|
|
totalLength := u.lengthFieldOffset + u.lengthFieldLength + actualLength
|
|
|
|
|
|
if totalLength > u.maxBufferSize {
|
|
|
return nil, nil, 0, unpackerpkg.NewErrorf("message too large: %d > %d", totalLength, u.maxBufferSize)
|
|
|
}
|
|
|
|
|
|
if len(u.buffer) < totalLength {
|
|
|
// 数据不足,等待更多数据
|
|
|
break
|
|
|
}
|
|
|
|
|
|
// 提取消息(必须复制数据,因为buffer可能会被修改)
|
|
|
// 优化:预分配messages slice的容量以减少重新分配
|
|
|
start := u.initialBytesToStrip
|
|
|
end := totalLength
|
|
|
messageLen := end - start
|
|
|
if messages == nil {
|
|
|
// 预分配messages slice(假设至少有一个消息)
|
|
|
messages = make([][]byte, 0, 1)
|
|
|
}
|
|
|
message := make([]byte, messageLen)
|
|
|
copy(message, u.buffer[start:end])
|
|
|
messages = append(messages, message)
|
|
|
|
|
|
// 移除已处理的数据(使用切片操作,避免复制)
|
|
|
u.buffer = u.buffer[totalLength:]
|
|
|
|
|
|
// 优化:如果buffer使用率太低,压缩buffer以减少内存占用
|
|
|
// 使用率阈值:25%(即剩余数据 < 容量的25%)
|
|
|
// 注意:压缩不会改变buffer的长度,只改变容量
|
|
|
if len(u.buffer) < cap(u.buffer)/4 && cap(u.buffer) > 4096 {
|
|
|
// 压缩:新容量为当前容量的50%,但至少能容纳当前数据
|
|
|
newCap := cap(u.buffer) / 2
|
|
|
if newCap < len(u.buffer) {
|
|
|
newCap = len(u.buffer)
|
|
|
}
|
|
|
// 如果新容量仍然合理,执行压缩
|
|
|
if newCap >= 1024 {
|
|
|
compressed := make([]byte, len(u.buffer), newCap)
|
|
|
copy(compressed, u.buffer)
|
|
|
u.buffer = compressed
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 计算本次从输入data中消耗的数据量(100%准确,无误差)
|
|
|
// 使用与delimiterUnpacker相同的逻辑
|
|
|
currentBufferLen := len(u.buffer)
|
|
|
processedTotal := afterAppendLen - currentBufferLen
|
|
|
bufferIncrease := currentBufferLen - prevBufferLen
|
|
|
|
|
|
var consumed int
|
|
|
if len(messages) == 0 {
|
|
|
// 没有完整消息,所有数据都被保存到buffer
|
|
|
consumed = len(data)
|
|
|
} else if processedTotal <= prevBufferLen {
|
|
|
// 所有被处理的数据都来自之前的buffer
|
|
|
if bufferIncrease > 0 {
|
|
|
consumed = bufferIncrease
|
|
|
if consumed > len(data) {
|
|
|
consumed = len(data)
|
|
|
}
|
|
|
} else {
|
|
|
consumed = 0
|
|
|
}
|
|
|
} else {
|
|
|
// 至少部分被处理的数据来自输入data
|
|
|
consumedFromData := processedTotal - prevBufferLen
|
|
|
remainingFromData := len(data) - consumedFromData
|
|
|
if bufferIncrease <= remainingFromData {
|
|
|
consumed = len(data)
|
|
|
} else if bufferIncrease > 0 {
|
|
|
consumed = consumedFromData + remainingFromData
|
|
|
} else {
|
|
|
consumed = consumedFromData
|
|
|
if consumed > len(data) {
|
|
|
consumed = len(data)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// 确保consumed在合理范围内
|
|
|
if consumed < 0 {
|
|
|
consumed = 0
|
|
|
} else if consumed > len(data) {
|
|
|
consumed = len(data)
|
|
|
}
|
|
|
|
|
|
return messages, u.buffer, consumed, nil
|
|
|
}
|
|
|
|
|
|
// Pack 打包
|
|
|
func (u *lengthFieldUnpacker) Pack(data []byte) ([]byte, error) {
|
|
|
dataLength := len(data)
|
|
|
totalLength := u.lengthFieldOffset + u.lengthFieldLength + dataLength - u.lengthAdjustment
|
|
|
|
|
|
result := make([]byte, totalLength)
|
|
|
|
|
|
// 填充长度字段之前的字节(如果有)
|
|
|
if u.lengthFieldOffset > 0 {
|
|
|
// 通常为0,如果需要可以扩展
|
|
|
}
|
|
|
|
|
|
// 写入长度字段
|
|
|
lengthValue := dataLength - u.lengthAdjustment
|
|
|
lengthBytes := make([]byte, u.lengthFieldLength)
|
|
|
switch u.lengthFieldLength {
|
|
|
case 1:
|
|
|
lengthBytes[0] = byte(lengthValue)
|
|
|
case 2:
|
|
|
u.byteOrder.PutUint16(lengthBytes, uint16(lengthValue))
|
|
|
case 4:
|
|
|
u.byteOrder.PutUint32(lengthBytes, uint32(lengthValue))
|
|
|
case 8:
|
|
|
u.byteOrder.PutUint64(lengthBytes, uint64(lengthValue))
|
|
|
}
|
|
|
copy(result[u.lengthFieldOffset:], lengthBytes)
|
|
|
|
|
|
// 写入数据
|
|
|
copy(result[u.lengthFieldOffset+u.lengthFieldLength:], data[u.initialBytesToStrip:])
|
|
|
|
|
|
return result, nil
|
|
|
}
|