|
|
package nnet
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"encoding/binary"
|
|
|
"fmt"
|
|
|
"sync"
|
|
|
|
|
|
internalprotocol "github.com/noahlann/nnet/internal/protocol"
|
|
|
internalunpacker "github.com/noahlann/nnet/internal/unpacker"
|
|
|
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
|
|
|
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
|
|
|
)
|
|
|
|
|
|
// NNetProtocol nnet协议实现
|
|
|
type NNetProtocol struct {
|
|
|
version string
|
|
|
unpacker unpackerpkg.Unpacker
|
|
|
once sync.Once
|
|
|
}
|
|
|
|
|
|
// NewNNetProtocol 创建nnet协议
|
|
|
func NewNNetProtocol(version string) protocolpkg.Protocol {
|
|
|
return &NNetProtocol{
|
|
|
version: version,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Name 获取协议名称
|
|
|
func (p *NNetProtocol) Name() string {
|
|
|
return "nnet"
|
|
|
}
|
|
|
|
|
|
// Version 获取协议版本
|
|
|
func (p *NNetProtocol) Version() string {
|
|
|
return p.version
|
|
|
}
|
|
|
|
|
|
// HasHeader 是否有帧头(nnet协议有帧头)
|
|
|
func (p *NNetProtocol) HasHeader() bool {
|
|
|
return true
|
|
|
}
|
|
|
|
|
|
// Encode 编码数据
|
|
|
func (p *NNetProtocol) Encode(data []byte, header protocolpkg.FrameHeader) ([]byte, error) {
|
|
|
// nnet协议格式:
|
|
|
// [Magic(4 bytes)][Version(1 byte)][Length(4 bytes)][Data(N bytes)][Checksum(2 bytes)]
|
|
|
|
|
|
magic := []byte("NNET")
|
|
|
versionByte := byte(1) // 版本1
|
|
|
if header != nil {
|
|
|
if versionVal := header.Get("version"); versionVal != nil {
|
|
|
if v, ok := versionVal.(byte); ok {
|
|
|
versionByte = v
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
dataLength := uint32(len(data))
|
|
|
|
|
|
// 计算校验和(简单实现:数据字节和)
|
|
|
checksum := uint16(0)
|
|
|
for _, b := range data {
|
|
|
checksum += uint16(b)
|
|
|
}
|
|
|
|
|
|
// 构建消息
|
|
|
packet := make([]byte, 0, 11+len(data))
|
|
|
packet = append(packet, magic...)
|
|
|
packet = append(packet, versionByte)
|
|
|
|
|
|
// 长度(4字节,大端序)
|
|
|
lengthBytes := make([]byte, 4)
|
|
|
binary.BigEndian.PutUint32(lengthBytes, dataLength)
|
|
|
packet = append(packet, lengthBytes...)
|
|
|
|
|
|
// 数据
|
|
|
packet = append(packet, data...)
|
|
|
|
|
|
// 校验和(2字节,大端序)
|
|
|
checksumBytes := make([]byte, 2)
|
|
|
binary.BigEndian.PutUint16(checksumBytes, checksum)
|
|
|
packet = append(packet, checksumBytes...)
|
|
|
|
|
|
return packet, nil
|
|
|
}
|
|
|
|
|
|
// Decode 解码数据
|
|
|
// 优化:当数据来自unpacker时(数据已经完整),可以跳过长度验证,因为unpacker已经验证过
|
|
|
// 但为了保持接口的通用性,我们仍然进行基本的验证
|
|
|
func (p *NNetProtocol) Decode(data []byte) (protocolpkg.FrameHeader, []byte, error) {
|
|
|
// 检查最小长度(Magic 4 + Version 1 + Length 4 + 最小数据 0 + Checksum 2 = 11)
|
|
|
if len(data) < 11 {
|
|
|
return nil, nil, fmt.Errorf("invalid packet length: %d < 11", len(data))
|
|
|
}
|
|
|
|
|
|
// 检查Magic
|
|
|
if len(data) < 4 {
|
|
|
return nil, nil, fmt.Errorf("invalid packet: too short for magic")
|
|
|
}
|
|
|
magic := data[0:4]
|
|
|
if string(magic) != "NNET" {
|
|
|
return nil, nil, fmt.Errorf("invalid magic: %s", string(magic))
|
|
|
}
|
|
|
|
|
|
// 创建帧头
|
|
|
header := internalprotocol.NewFrameHeader()
|
|
|
header.Set("magic", string(magic))
|
|
|
|
|
|
// 读取版本
|
|
|
if len(data) < 5 {
|
|
|
return nil, nil, fmt.Errorf("invalid packet: too short for version")
|
|
|
}
|
|
|
version := data[4]
|
|
|
header.Set("version", version)
|
|
|
// 注意:这里不检查版本值,因为版本识别由版本识别器或服务器逻辑处理
|
|
|
// 这样可以支持多版本协议
|
|
|
|
|
|
// 读取长度字段(偏移5-8)
|
|
|
if len(data) < 9 {
|
|
|
return nil, nil, fmt.Errorf("invalid packet: too short for length field")
|
|
|
}
|
|
|
dataLength := binary.BigEndian.Uint32(data[5:9])
|
|
|
header.Set("length", dataLength)
|
|
|
|
|
|
// 计算预期的总长度:Magic(4) + Version(1) + Length(4) + Data(dataLength) + Checksum(2)
|
|
|
expectedTotalLength := 9 + int(dataLength) + 2
|
|
|
|
|
|
// 优化:如果数据长度正好等于预期长度,说明数据来自unpacker(已经完整),可以跳过长度验证
|
|
|
// 否则,需要进行长度验证(数据可能不完整)
|
|
|
if len(data) != expectedTotalLength {
|
|
|
// 数据长度不匹配,可能是数据不完整或数据错误
|
|
|
if len(data) < expectedTotalLength {
|
|
|
return nil, nil, fmt.Errorf("invalid data length: expected %d, got %d", expectedTotalLength, len(data))
|
|
|
}
|
|
|
// 如果数据长度大于预期,可能是多个包,但我们只处理第一个包
|
|
|
// 这种情况应该由unpacker处理,不应该到达这里
|
|
|
}
|
|
|
|
|
|
// 读取数据部分(偏移9到9+dataLength)
|
|
|
messageDataStart := 9
|
|
|
messageDataEnd := 9 + int(dataLength)
|
|
|
if len(data) < messageDataEnd {
|
|
|
return nil, nil, fmt.Errorf("invalid packet: data section incomplete")
|
|
|
}
|
|
|
messageData := data[messageDataStart:messageDataEnd]
|
|
|
|
|
|
// 读取校验和(偏移9+dataLength到11+dataLength)
|
|
|
checksumStart := messageDataEnd
|
|
|
checksumEnd := checksumStart + 2
|
|
|
if len(data) < checksumEnd {
|
|
|
return nil, nil, fmt.Errorf("invalid packet: checksum incomplete")
|
|
|
}
|
|
|
checksum := binary.BigEndian.Uint16(data[checksumStart:checksumEnd])
|
|
|
header.Set("checksum", checksum)
|
|
|
|
|
|
// 验证校验和
|
|
|
calculatedChecksum := uint16(0)
|
|
|
for _, b := range messageData {
|
|
|
calculatedChecksum += uint16(b)
|
|
|
}
|
|
|
if checksum != calculatedChecksum {
|
|
|
return nil, nil, fmt.Errorf("checksum mismatch: expected %d, got %d", calculatedChecksum, checksum)
|
|
|
}
|
|
|
|
|
|
return header, messageData, nil
|
|
|
}
|
|
|
|
|
|
// Handle 处理消息
|
|
|
func (p *NNetProtocol) Handle(ctx context.Context, data []byte) ([]byte, error) {
|
|
|
// nnet协议的处理逻辑
|
|
|
// 这里可以添加协议特定的处理逻辑
|
|
|
return data, nil
|
|
|
}
|
|
|
|
|
|
// Unpacker 获取协议的拆包器
|
|
|
// nnet协议使用LengthFieldUnpacker来处理粘包拆包
|
|
|
func (p *NNetProtocol) Unpacker() unpackerpkg.Unpacker {
|
|
|
p.once.Do(func() {
|
|
|
// nnet协议格式:[Magic(4)][Version(1)][Length(4)][Data(N)][Checksum(2)]
|
|
|
// 长度字段在偏移5的位置(Magic 4字节 + Version 1字节)
|
|
|
// 长度字段是4字节,表示Data部分的长度
|
|
|
// 总长度 = 5(Magic+Version) + 4(Length字段) + Length(数据长度) + 2(Checksum)
|
|
|
config := unpackerpkg.LengthFieldUnpacker{
|
|
|
LengthFieldOffset: 5, // Magic(4) + Version(1) = 5
|
|
|
LengthFieldLength: 4, // Length字段是4字节
|
|
|
LengthAdjustment: 2, // 需要加上Checksum(2字节)
|
|
|
InitialBytesToStrip: 0, // 不跳过任何字节,保留完整包
|
|
|
}
|
|
|
p.unpacker = internalunpacker.NewLengthFieldUnpacker(config)
|
|
|
})
|
|
|
return p.unpacker
|
|
|
}
|
|
|
|
|
|
// DecodeHeader 解码帧头(增量解析,即使数据不完整)
|
|
|
// 实现IncrementalDecoder接口,支持在数据不完整时解析帧头
|
|
|
func (p *NNetProtocol) DecodeHeader(data []byte) (protocolpkg.FrameHeader, int, error) {
|
|
|
// nnet协议帧头格式:[Magic(4)][Version(1)][Length(4)]
|
|
|
// 完整的帧头需要9字节:Magic(4) + Version(1) + Length(4)
|
|
|
minHeaderLength := 9
|
|
|
|
|
|
// 检查最小长度
|
|
|
if len(data) < 4 {
|
|
|
return nil, minHeaderLength, fmt.Errorf("invalid packet: too short for magic, need at least 4 bytes")
|
|
|
}
|
|
|
|
|
|
// 检查Magic
|
|
|
magic := data[0:4]
|
|
|
if string(magic) != "NNET" {
|
|
|
return nil, minHeaderLength, fmt.Errorf("invalid magic: %s", string(magic))
|
|
|
}
|
|
|
|
|
|
// 如果数据不足9字节,返回需要的字节数
|
|
|
if len(data) < minHeaderLength {
|
|
|
return nil, minHeaderLength, nil
|
|
|
}
|
|
|
|
|
|
// 创建帧头
|
|
|
header := internalprotocol.NewFrameHeader()
|
|
|
header.Set("magic", string(magic))
|
|
|
|
|
|
// 读取版本
|
|
|
version := data[4]
|
|
|
header.Set("version", version)
|
|
|
// 注意:这里不检查版本值,因为版本识别由版本识别器或服务器逻辑处理
|
|
|
// 这样可以支持多版本协议
|
|
|
|
|
|
// 读取长度字段
|
|
|
dataLength := binary.BigEndian.Uint32(data[5:9])
|
|
|
header.Set("length", dataLength)
|
|
|
|
|
|
// 计算完整消息需要的总字节数
|
|
|
// 总长度 = 9(帧头) + dataLength(数据) + 2(Checksum)
|
|
|
totalLength := minHeaderLength + int(dataLength) + 2
|
|
|
|
|
|
// 返回解析的帧头和完整消息需要的总字节数
|
|
|
return header, totalLength, nil
|
|
|
}
|
|
|
|
|
|
// DecodeBody 解码消息体(假设帧头已经解析)
|
|
|
// 实现IncrementalDecoder接口,支持重用已解析的帧头
|
|
|
// 优化:如果header不为nil,可以跳过帧头解析,直接解析数据体和校验和,避免重复解析帧头
|
|
|
func (p *NNetProtocol) DecodeBody(data []byte, header protocolpkg.FrameHeader) ([]byte, error) {
|
|
|
// 如果header为nil,需要从data中解析帧头(回退到标准Decode方法)
|
|
|
if header == nil {
|
|
|
_, body, err := p.Decode(data)
|
|
|
return body, err
|
|
|
}
|
|
|
|
|
|
// 优化:从header中获取长度(避免重复读取长度字段)
|
|
|
lengthVal := header.Get("length")
|
|
|
if lengthVal == nil {
|
|
|
return nil, fmt.Errorf("header missing length field")
|
|
|
}
|
|
|
dataLength, ok := lengthVal.(uint32)
|
|
|
if !ok {
|
|
|
return nil, fmt.Errorf("invalid length field type in header")
|
|
|
}
|
|
|
|
|
|
// 计算预期的总长度:Magic(4) + Version(1) + Length(4) + Data(dataLength) + Checksum(2)
|
|
|
expectedTotalLength := 9 + int(dataLength) + 2
|
|
|
|
|
|
// 验证数据长度(数据应该来自unpacker,已经完整)
|
|
|
if len(data) != expectedTotalLength {
|
|
|
if len(data) < expectedTotalLength {
|
|
|
return nil, fmt.Errorf("invalid data length: expected %d, got %d", expectedTotalLength, len(data))
|
|
|
}
|
|
|
// 如果数据长度大于预期,可能是多个包,但我们只处理第一个包
|
|
|
// 这种情况应该由unpacker处理,不应该到达这里
|
|
|
}
|
|
|
|
|
|
// 读取数据部分(偏移9到9+dataLength)
|
|
|
// 优化:直接使用header中的长度信息,避免重复读取长度字段
|
|
|
messageDataStart := 9
|
|
|
messageDataEnd := 9 + int(dataLength)
|
|
|
messageData := data[messageDataStart:messageDataEnd]
|
|
|
|
|
|
// 读取校验和(偏移9+dataLength到11+dataLength)
|
|
|
checksumStart := messageDataEnd
|
|
|
checksumEnd := checksumStart + 2
|
|
|
checksum := binary.BigEndian.Uint16(data[checksumStart:checksumEnd])
|
|
|
|
|
|
// 将校验和设置到header中(保持与Decode方法的一致性)
|
|
|
header.Set("checksum", checksum)
|
|
|
|
|
|
// 验证校验和
|
|
|
calculatedChecksum := uint16(0)
|
|
|
for _, b := range messageData {
|
|
|
calculatedChecksum += uint16(b)
|
|
|
}
|
|
|
if checksum != calculatedChecksum {
|
|
|
return nil, fmt.Errorf("checksum mismatch: expected %d, got %d", calculatedChecksum, checksum)
|
|
|
}
|
|
|
|
|
|
return messageData, nil
|
|
|
}
|