package nnet import ( "context" "encoding/binary" "fmt" "sync" internalprotocol "github.com/noahlann/nnet/internal/protocol" internalunpacker "github.com/noahlann/nnet/internal/unpacker" protocolpkg "github.com/noahlann/nnet/pkg/protocol" unpackerpkg "github.com/noahlann/nnet/pkg/unpacker" ) // NNetProtocol nnet协议实现 type NNetProtocol struct { version string unpacker unpackerpkg.Unpacker once sync.Once } // NewNNetProtocol 创建nnet协议 func NewNNetProtocol(version string) protocolpkg.Protocol { return &NNetProtocol{ version: version, } } // Name 获取协议名称 func (p *NNetProtocol) Name() string { return "nnet" } // Version 获取协议版本 func (p *NNetProtocol) Version() string { return p.version } // HasHeader 是否有帧头(nnet协议有帧头) func (p *NNetProtocol) HasHeader() bool { return true } // Encode 编码数据 func (p *NNetProtocol) Encode(data []byte, header protocolpkg.FrameHeader) ([]byte, error) { // nnet协议格式: // [Magic(4 bytes)][Version(1 byte)][Length(4 bytes)][Data(N bytes)][Checksum(2 bytes)] magic := []byte("NNET") versionByte := byte(1) // 版本1 if header != nil { if versionVal := header.Get("version"); versionVal != nil { if v, ok := versionVal.(byte); ok { versionByte = v } } } dataLength := uint32(len(data)) // 计算校验和(简单实现:数据字节和) checksum := uint16(0) for _, b := range data { checksum += uint16(b) } // 构建消息 packet := make([]byte, 0, 11+len(data)) packet = append(packet, magic...) packet = append(packet, versionByte) // 长度(4字节,大端序) lengthBytes := make([]byte, 4) binary.BigEndian.PutUint32(lengthBytes, dataLength) packet = append(packet, lengthBytes...) // 数据 packet = append(packet, data...) // 校验和(2字节,大端序) checksumBytes := make([]byte, 2) binary.BigEndian.PutUint16(checksumBytes, checksum) packet = append(packet, checksumBytes...) return packet, nil } // Decode 解码数据 // 优化:当数据来自unpacker时(数据已经完整),可以跳过长度验证,因为unpacker已经验证过 // 但为了保持接口的通用性,我们仍然进行基本的验证 func (p *NNetProtocol) Decode(data []byte) (protocolpkg.FrameHeader, []byte, error) { // 检查最小长度(Magic 4 + Version 1 + Length 4 + 最小数据 0 + Checksum 2 = 11) if len(data) < 11 { return nil, nil, fmt.Errorf("invalid packet length: %d < 11", len(data)) } // 检查Magic if len(data) < 4 { return nil, nil, fmt.Errorf("invalid packet: too short for magic") } magic := data[0:4] if string(magic) != "NNET" { return nil, nil, fmt.Errorf("invalid magic: %s", string(magic)) } // 创建帧头 header := internalprotocol.NewFrameHeader() header.Set("magic", string(magic)) // 读取版本 if len(data) < 5 { return nil, nil, fmt.Errorf("invalid packet: too short for version") } version := data[4] header.Set("version", version) // 注意:这里不检查版本值,因为版本识别由版本识别器或服务器逻辑处理 // 这样可以支持多版本协议 // 读取长度字段(偏移5-8) if len(data) < 9 { return nil, nil, fmt.Errorf("invalid packet: too short for length field") } dataLength := binary.BigEndian.Uint32(data[5:9]) header.Set("length", dataLength) // 计算预期的总长度:Magic(4) + Version(1) + Length(4) + Data(dataLength) + Checksum(2) expectedTotalLength := 9 + int(dataLength) + 2 // 优化:如果数据长度正好等于预期长度,说明数据来自unpacker(已经完整),可以跳过长度验证 // 否则,需要进行长度验证(数据可能不完整) if len(data) != expectedTotalLength { // 数据长度不匹配,可能是数据不完整或数据错误 if len(data) < expectedTotalLength { return nil, nil, fmt.Errorf("invalid data length: expected %d, got %d", expectedTotalLength, len(data)) } // 如果数据长度大于预期,可能是多个包,但我们只处理第一个包 // 这种情况应该由unpacker处理,不应该到达这里 } // 读取数据部分(偏移9到9+dataLength) messageDataStart := 9 messageDataEnd := 9 + int(dataLength) if len(data) < messageDataEnd { return nil, nil, fmt.Errorf("invalid packet: data section incomplete") } messageData := data[messageDataStart:messageDataEnd] // 读取校验和(偏移9+dataLength到11+dataLength) checksumStart := messageDataEnd checksumEnd := checksumStart + 2 if len(data) < checksumEnd { return nil, nil, fmt.Errorf("invalid packet: checksum incomplete") } checksum := binary.BigEndian.Uint16(data[checksumStart:checksumEnd]) header.Set("checksum", checksum) // 验证校验和 calculatedChecksum := uint16(0) for _, b := range messageData { calculatedChecksum += uint16(b) } if checksum != calculatedChecksum { return nil, nil, fmt.Errorf("checksum mismatch: expected %d, got %d", calculatedChecksum, checksum) } return header, messageData, nil } // Handle 处理消息 func (p *NNetProtocol) Handle(ctx context.Context, data []byte) ([]byte, error) { // nnet协议的处理逻辑 // 这里可以添加协议特定的处理逻辑 return data, nil } // Unpacker 获取协议的拆包器 // nnet协议使用LengthFieldUnpacker来处理粘包拆包 func (p *NNetProtocol) Unpacker() unpackerpkg.Unpacker { p.once.Do(func() { // nnet协议格式:[Magic(4)][Version(1)][Length(4)][Data(N)][Checksum(2)] // 长度字段在偏移5的位置(Magic 4字节 + Version 1字节) // 长度字段是4字节,表示Data部分的长度 // 总长度 = 5(Magic+Version) + 4(Length字段) + Length(数据长度) + 2(Checksum) config := unpackerpkg.LengthFieldUnpacker{ LengthFieldOffset: 5, // Magic(4) + Version(1) = 5 LengthFieldLength: 4, // Length字段是4字节 LengthAdjustment: 2, // 需要加上Checksum(2字节) InitialBytesToStrip: 0, // 不跳过任何字节,保留完整包 } p.unpacker = internalunpacker.NewLengthFieldUnpacker(config) }) return p.unpacker } // DecodeHeader 解码帧头(增量解析,即使数据不完整) // 实现IncrementalDecoder接口,支持在数据不完整时解析帧头 func (p *NNetProtocol) DecodeHeader(data []byte) (protocolpkg.FrameHeader, int, error) { // nnet协议帧头格式:[Magic(4)][Version(1)][Length(4)] // 完整的帧头需要9字节:Magic(4) + Version(1) + Length(4) minHeaderLength := 9 // 检查最小长度 if len(data) < 4 { return nil, minHeaderLength, fmt.Errorf("invalid packet: too short for magic, need at least 4 bytes") } // 检查Magic magic := data[0:4] if string(magic) != "NNET" { return nil, minHeaderLength, fmt.Errorf("invalid magic: %s", string(magic)) } // 如果数据不足9字节,返回需要的字节数 if len(data) < minHeaderLength { return nil, minHeaderLength, nil } // 创建帧头 header := internalprotocol.NewFrameHeader() header.Set("magic", string(magic)) // 读取版本 version := data[4] header.Set("version", version) // 注意:这里不检查版本值,因为版本识别由版本识别器或服务器逻辑处理 // 这样可以支持多版本协议 // 读取长度字段 dataLength := binary.BigEndian.Uint32(data[5:9]) header.Set("length", dataLength) // 计算完整消息需要的总字节数 // 总长度 = 9(帧头) + dataLength(数据) + 2(Checksum) totalLength := minHeaderLength + int(dataLength) + 2 // 返回解析的帧头和完整消息需要的总字节数 return header, totalLength, nil } // DecodeBody 解码消息体(假设帧头已经解析) // 实现IncrementalDecoder接口,支持重用已解析的帧头 // 优化:如果header不为nil,可以跳过帧头解析,直接解析数据体和校验和,避免重复解析帧头 func (p *NNetProtocol) DecodeBody(data []byte, header protocolpkg.FrameHeader) ([]byte, error) { // 如果header为nil,需要从data中解析帧头(回退到标准Decode方法) if header == nil { _, body, err := p.Decode(data) return body, err } // 优化:从header中获取长度(避免重复读取长度字段) lengthVal := header.Get("length") if lengthVal == nil { return nil, fmt.Errorf("header missing length field") } dataLength, ok := lengthVal.(uint32) if !ok { return nil, fmt.Errorf("invalid length field type in header") } // 计算预期的总长度:Magic(4) + Version(1) + Length(4) + Data(dataLength) + Checksum(2) expectedTotalLength := 9 + int(dataLength) + 2 // 验证数据长度(数据应该来自unpacker,已经完整) if len(data) != expectedTotalLength { if len(data) < expectedTotalLength { return nil, fmt.Errorf("invalid data length: expected %d, got %d", expectedTotalLength, len(data)) } // 如果数据长度大于预期,可能是多个包,但我们只处理第一个包 // 这种情况应该由unpacker处理,不应该到达这里 } // 读取数据部分(偏移9到9+dataLength) // 优化:直接使用header中的长度信息,避免重复读取长度字段 messageDataStart := 9 messageDataEnd := 9 + int(dataLength) messageData := data[messageDataStart:messageDataEnd] // 读取校验和(偏移9+dataLength到11+dataLength) checksumStart := messageDataEnd checksumEnd := checksumStart + 2 checksum := binary.BigEndian.Uint16(data[checksumStart:checksumEnd]) // 将校验和设置到header中(保持与Decode方法的一致性) header.Set("checksum", checksum) // 验证校验和 calculatedChecksum := uint16(0) for _, b := range messageData { calculatedChecksum += uint16(b) } if checksum != calculatedChecksum { return nil, fmt.Errorf("checksum mismatch: expected %d, got %d", calculatedChecksum, checksum) } return messageData, nil }