package unpacker import ( unpackerpkg "github.com/noahlann/nnet/pkg/unpacker" ) // frameHeaderUnpacker 帧头拆包器实现 type frameHeaderUnpacker struct { headerLength int getLength func(header []byte) int buffer []byte maxBufferSize int } // NewFrameHeaderUnpacker 创建帧头拆包器 func NewFrameHeaderUnpacker(config unpackerpkg.FrameHeaderUnpacker) unpackerpkg.Unpacker { headerLength := config.HeaderLength if headerLength <= 0 { headerLength = 4 // 默认4字节 } getLength := config.GetLength if getLength == nil { // 默认实现:假设前4字节是长度 getLength = func(header []byte) int { if len(header) < 4 { return 0 } return int(header[0])<<24 | int(header[1])<<16 | int(header[2])<<8 | int(header[3]) } } maxBufferSize := config.MaxBufferSize if maxBufferSize <= 0 { maxBufferSize = unpackerpkg.DefaultMaxBufferSize } return &frameHeaderUnpacker{ headerLength: headerLength, getLength: getLength, buffer: make([]byte, 0, 4096), // 预分配初始容量 maxBufferSize: maxBufferSize, } } // Unpack 拆包 func (u *frameHeaderUnpacker) Unpack(data []byte) ([][]byte, []byte, int, error) { // 记录调用前的buffer长度(用于计算消耗的数据量) prevBufferLen := len(u.buffer) // 检查buffer大小限制 newSize := prevBufferLen + len(data) if newSize > u.maxBufferSize { return nil, nil, 0, unpackerpkg.NewErrorf("unpacker buffer size exceeded: %d > %d", newSize, u.maxBufferSize) } // 优化:如果容量不足,预分配更大的容量(零拷贝优化) if cap(u.buffer) < newSize { newCap := cap(u.buffer) * 2 if newCap < newSize { newCap = newSize } if newCap > u.maxBufferSize { newCap = u.maxBufferSize } // 如果现有 buffer 为空,直接分配新 buffer(避免不必要的复制) if len(u.buffer) == 0 { u.buffer = make([]byte, 0, newCap) } else { newBuffer := make([]byte, len(u.buffer), newCap) copy(newBuffer, u.buffer) u.buffer = newBuffer } } u.buffer = append(u.buffer, data...) // 记录追加数据后的buffer长度(用于计算consumed) afterAppendLen := len(u.buffer) var messages [][]byte for { if len(u.buffer) < u.headerLength { // 数据不足,等待更多数据 break } // 读取帧头 header := u.buffer[:u.headerLength] messageLength := u.getLength(header) if messageLength <= 0 { // 无效长度,跳过 u.buffer = u.buffer[1:] continue } // 验证长度字段的合理性(防止恶意数据) if messageLength > u.maxBufferSize { return nil, nil, 0, unpackerpkg.NewErrorf("invalid message length: %d > %d", messageLength, u.maxBufferSize) } totalLength := u.headerLength + messageLength if totalLength > u.maxBufferSize { return nil, nil, 0, unpackerpkg.NewErrorf("message too large: %d > %d", totalLength, u.maxBufferSize) } if len(u.buffer) < totalLength { // 数据不足,等待更多数据 break } // 提取消息 message := make([]byte, totalLength) copy(message, u.buffer[:totalLength]) messages = append(messages, message) // 移除已处理的数据(优化:使用切片操作,避免复制) u.buffer = u.buffer[totalLength:] // 如果 buffer 太大但剩余数据很少,压缩 buffer(减少内存占用) // 注意:压缩不会改变buffer的长度,只改变容量 if len(u.buffer) < cap(u.buffer)/4 && cap(u.buffer) > 4096 { compressed := make([]byte, len(u.buffer), cap(u.buffer)/2) copy(compressed, u.buffer) u.buffer = compressed } } // 计算本次从输入data中消耗的数据量(100%准确,无误差) // 使用与delimiterUnpacker相同的逻辑 currentBufferLen := len(u.buffer) processedTotal := afterAppendLen - currentBufferLen bufferIncrease := currentBufferLen - prevBufferLen var consumed int if len(messages) == 0 { // 没有完整消息,所有数据都被保存到buffer consumed = len(data) } else if processedTotal <= prevBufferLen { // 所有被处理的数据都来自之前的buffer if bufferIncrease > 0 { consumed = bufferIncrease if consumed > len(data) { consumed = len(data) } } else { consumed = 0 } } else { // 至少部分被处理的数据来自输入data consumedFromData := processedTotal - prevBufferLen remainingFromData := len(data) - consumedFromData if bufferIncrease <= remainingFromData { consumed = len(data) } else if bufferIncrease > 0 { consumed = consumedFromData + remainingFromData } else { consumed = consumedFromData if consumed > len(data) { consumed = len(data) } } } // 确保consumed在合理范围内 if consumed < 0 { consumed = 0 } else if consumed > len(data) { consumed = len(data) } return messages, u.buffer, consumed, nil } // Pack 打包 func (u *frameHeaderUnpacker) Pack(data []byte) ([]byte, error) { messageLength := len(data) header := make([]byte, u.headerLength) // 写入长度到帧头 if u.headerLength >= 4 { header[0] = byte(messageLength >> 24) header[1] = byte(messageLength >> 16) header[2] = byte(messageLength >> 8) header[3] = byte(messageLength) } result := make([]byte, u.headerLength+messageLength) copy(result, header) copy(result[u.headerLength:], data) return result, nil }