package unpacker import ( "encoding/binary" unpackerpkg "github.com/noahlann/nnet/pkg/unpacker" ) // lengthFieldUnpacker 长度字段拆包器实现 type lengthFieldUnpacker struct { lengthFieldOffset int lengthFieldLength int lengthAdjustment int initialBytesToStrip int buffer []byte byteOrder binary.ByteOrder maxBufferSize int } // NewLengthFieldUnpacker 创建长度字段拆包器 func NewLengthFieldUnpacker(config unpackerpkg.LengthFieldUnpacker) unpackerpkg.Unpacker { lengthFieldLength := config.LengthFieldLength if lengthFieldLength <= 0 { lengthFieldLength = 4 // 默认4字节 } maxBufferSize := config.MaxBufferSize if maxBufferSize <= 0 { maxBufferSize = unpackerpkg.DefaultMaxBufferSize } return &lengthFieldUnpacker{ lengthFieldOffset: config.LengthFieldOffset, lengthFieldLength: lengthFieldLength, lengthAdjustment: config.LengthAdjustment, initialBytesToStrip: config.InitialBytesToStrip, buffer: nil, // 延迟分配,使用buffer pool byteOrder: binary.BigEndian, maxBufferSize: maxBufferSize, } } // Unpack 拆包 func (u *lengthFieldUnpacker) Unpack(data []byte) ([][]byte, []byte, int, error) { // 记录调用前的buffer长度(用于计算消耗的数据量) prevBufferLen := len(u.buffer) // 检查buffer大小限制 newSize := prevBufferLen + len(data) if newSize > u.maxBufferSize { return nil, nil, 0, unpackerpkg.NewErrorf("unpacker buffer size exceeded: %d > %d", newSize, u.maxBufferSize) } // 优化buffer管理:只在需要时分配和扩容 if u.buffer == nil { // 首次使用,从buffer pool获取或分配 if newSize <= 4096 { // 小数据,使用固定大小的buffer u.buffer = make([]byte, 0, 4096) } else { // 大数据,直接分配所需容量(但不超过maxBufferSize) cap := newSize if cap > u.maxBufferSize { cap = u.maxBufferSize } u.buffer = make([]byte, 0, cap) } } else if cap(u.buffer) < newSize { // 需要扩容:计算新容量(至少是当前大小的2倍,但不超过maxBufferSize) newCap := cap(u.buffer) * 2 if newCap < newSize { newCap = newSize } if newCap > u.maxBufferSize { newCap = u.maxBufferSize } // 扩容:创建新buffer并复制数据 newBuffer := make([]byte, len(u.buffer), newCap) copy(newBuffer, u.buffer) u.buffer = newBuffer } // 追加新数据 u.buffer = append(u.buffer, data...) // 记录追加数据后的buffer长度(用于计算consumed) afterAppendLen := len(u.buffer) var messages [][]byte for { if len(u.buffer) < u.lengthFieldOffset+u.lengthFieldLength { // 数据不足,等待更多数据 break } // 读取长度字段 lengthBytes := u.buffer[u.lengthFieldOffset : u.lengthFieldOffset+u.lengthFieldLength] var length int switch u.lengthFieldLength { case 1: length = int(lengthBytes[0]) case 2: length = int(u.byteOrder.Uint16(lengthBytes)) case 4: length = int(u.byteOrder.Uint32(lengthBytes)) case 8: length = int(u.byteOrder.Uint64(lengthBytes)) default: return nil, u.buffer, 0, unpackerpkg.NewError("unsupported length field length") } // 验证长度字段的合理性(防止恶意数据) if length < 0 || length > u.maxBufferSize { return nil, nil, 0, unpackerpkg.NewErrorf("invalid length field: %d", length) } // 计算实际消息长度 actualLength := length + u.lengthAdjustment if actualLength < 0 { return nil, nil, 0, unpackerpkg.NewErrorf("invalid actual length: %d", actualLength) } totalLength := u.lengthFieldOffset + u.lengthFieldLength + actualLength if totalLength > u.maxBufferSize { return nil, nil, 0, unpackerpkg.NewErrorf("message too large: %d > %d", totalLength, u.maxBufferSize) } if len(u.buffer) < totalLength { // 数据不足,等待更多数据 break } // 提取消息(必须复制数据,因为buffer可能会被修改) // 优化:预分配messages slice的容量以减少重新分配 start := u.initialBytesToStrip end := totalLength messageLen := end - start if messages == nil { // 预分配messages slice(假设至少有一个消息) messages = make([][]byte, 0, 1) } message := make([]byte, messageLen) copy(message, u.buffer[start:end]) messages = append(messages, message) // 移除已处理的数据(使用切片操作,避免复制) u.buffer = u.buffer[totalLength:] // 优化:如果buffer使用率太低,压缩buffer以减少内存占用 // 使用率阈值:25%(即剩余数据 < 容量的25%) // 注意:压缩不会改变buffer的长度,只改变容量 if len(u.buffer) < cap(u.buffer)/4 && cap(u.buffer) > 4096 { // 压缩:新容量为当前容量的50%,但至少能容纳当前数据 newCap := cap(u.buffer) / 2 if newCap < len(u.buffer) { newCap = len(u.buffer) } // 如果新容量仍然合理,执行压缩 if newCap >= 1024 { compressed := make([]byte, len(u.buffer), newCap) copy(compressed, u.buffer) u.buffer = compressed } } } // 计算本次从输入data中消耗的数据量(100%准确,无误差) // 使用与delimiterUnpacker相同的逻辑 currentBufferLen := len(u.buffer) processedTotal := afterAppendLen - currentBufferLen bufferIncrease := currentBufferLen - prevBufferLen var consumed int if len(messages) == 0 { // 没有完整消息,所有数据都被保存到buffer consumed = len(data) } else if processedTotal <= prevBufferLen { // 所有被处理的数据都来自之前的buffer if bufferIncrease > 0 { consumed = bufferIncrease if consumed > len(data) { consumed = len(data) } } else { consumed = 0 } } else { // 至少部分被处理的数据来自输入data consumedFromData := processedTotal - prevBufferLen remainingFromData := len(data) - consumedFromData if bufferIncrease <= remainingFromData { consumed = len(data) } else if bufferIncrease > 0 { consumed = consumedFromData + remainingFromData } else { consumed = consumedFromData if consumed > len(data) { consumed = len(data) } } } // 确保consumed在合理范围内 if consumed < 0 { consumed = 0 } else if consumed > len(data) { consumed = len(data) } return messages, u.buffer, consumed, nil } // Pack 打包 func (u *lengthFieldUnpacker) Pack(data []byte) ([]byte, error) { dataLength := len(data) totalLength := u.lengthFieldOffset + u.lengthFieldLength + dataLength - u.lengthAdjustment result := make([]byte, totalLength) // 填充长度字段之前的字节(如果有) if u.lengthFieldOffset > 0 { // 通常为0,如果需要可以扩展 } // 写入长度字段 lengthValue := dataLength - u.lengthAdjustment lengthBytes := make([]byte, u.lengthFieldLength) switch u.lengthFieldLength { case 1: lengthBytes[0] = byte(lengthValue) case 2: u.byteOrder.PutUint16(lengthBytes, uint16(lengthValue)) case 4: u.byteOrder.PutUint32(lengthBytes, uint32(lengthValue)) case 8: u.byteOrder.PutUint64(lengthBytes, uint64(lengthValue)) } copy(result[u.lengthFieldOffset:], lengthBytes) // 写入数据 copy(result[u.lengthFieldOffset+u.lengthFieldLength:], data[u.initialBytesToStrip:]) return result, nil }