You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

296 lines
9.6 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package nnet
import (
"context"
"encoding/binary"
"fmt"
"sync"
internalprotocol "github.com/noahlann/nnet/internal/protocol"
internalunpacker "github.com/noahlann/nnet/internal/unpacker"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
)
// NNetProtocol nnet协议实现
type NNetProtocol struct {
version string
unpacker unpackerpkg.Unpacker
once sync.Once
}
// NewNNetProtocol 创建nnet协议
func NewNNetProtocol(version string) protocolpkg.Protocol {
return &NNetProtocol{
version: version,
}
}
// Name 获取协议名称
func (p *NNetProtocol) Name() string {
return "nnet"
}
// Version 获取协议版本
func (p *NNetProtocol) Version() string {
return p.version
}
// HasHeader 是否有帧头nnet协议有帧头
func (p *NNetProtocol) HasHeader() bool {
return true
}
// Encode 编码数据
func (p *NNetProtocol) Encode(data []byte, header protocolpkg.FrameHeader) ([]byte, error) {
// nnet协议格式
// [Magic(4 bytes)][Version(1 byte)][Length(4 bytes)][Data(N bytes)][Checksum(2 bytes)]
magic := []byte("NNET")
versionByte := byte(1) // 版本1
if header != nil {
if versionVal := header.Get("version"); versionVal != nil {
if v, ok := versionVal.(byte); ok {
versionByte = v
}
}
}
dataLength := uint32(len(data))
// 计算校验和(简单实现:数据字节和)
checksum := uint16(0)
for _, b := range data {
checksum += uint16(b)
}
// 构建消息
packet := make([]byte, 0, 11+len(data))
packet = append(packet, magic...)
packet = append(packet, versionByte)
// 长度4字节大端序
lengthBytes := make([]byte, 4)
binary.BigEndian.PutUint32(lengthBytes, dataLength)
packet = append(packet, lengthBytes...)
// 数据
packet = append(packet, data...)
// 校验和2字节大端序
checksumBytes := make([]byte, 2)
binary.BigEndian.PutUint16(checksumBytes, checksum)
packet = append(packet, checksumBytes...)
return packet, nil
}
// Decode 解码数据
// 优化当数据来自unpacker时数据已经完整可以跳过长度验证因为unpacker已经验证过
// 但为了保持接口的通用性,我们仍然进行基本的验证
func (p *NNetProtocol) Decode(data []byte) (protocolpkg.FrameHeader, []byte, error) {
// 检查最小长度Magic 4 + Version 1 + Length 4 + 最小数据 0 + Checksum 2 = 11
if len(data) < 11 {
return nil, nil, fmt.Errorf("invalid packet length: %d < 11", len(data))
}
// 检查Magic
if len(data) < 4 {
return nil, nil, fmt.Errorf("invalid packet: too short for magic")
}
magic := data[0:4]
if string(magic) != "NNET" {
return nil, nil, fmt.Errorf("invalid magic: %s", string(magic))
}
// 创建帧头
header := internalprotocol.NewFrameHeader()
header.Set("magic", string(magic))
// 读取版本
if len(data) < 5 {
return nil, nil, fmt.Errorf("invalid packet: too short for version")
}
version := data[4]
header.Set("version", version)
// 注意:这里不检查版本值,因为版本识别由版本识别器或服务器逻辑处理
// 这样可以支持多版本协议
// 读取长度字段偏移5-8
if len(data) < 9 {
return nil, nil, fmt.Errorf("invalid packet: too short for length field")
}
dataLength := binary.BigEndian.Uint32(data[5:9])
header.Set("length", dataLength)
// 计算预期的总长度Magic(4) + Version(1) + Length(4) + Data(dataLength) + Checksum(2)
expectedTotalLength := 9 + int(dataLength) + 2
// 优化如果数据长度正好等于预期长度说明数据来自unpacker已经完整可以跳过长度验证
// 否则,需要进行长度验证(数据可能不完整)
if len(data) != expectedTotalLength {
// 数据长度不匹配,可能是数据不完整或数据错误
if len(data) < expectedTotalLength {
return nil, nil, fmt.Errorf("invalid data length: expected %d, got %d", expectedTotalLength, len(data))
}
// 如果数据长度大于预期,可能是多个包,但我们只处理第一个包
// 这种情况应该由unpacker处理不应该到达这里
}
// 读取数据部分偏移9到9+dataLength
messageDataStart := 9
messageDataEnd := 9 + int(dataLength)
if len(data) < messageDataEnd {
return nil, nil, fmt.Errorf("invalid packet: data section incomplete")
}
messageData := data[messageDataStart:messageDataEnd]
// 读取校验和偏移9+dataLength到11+dataLength
checksumStart := messageDataEnd
checksumEnd := checksumStart + 2
if len(data) < checksumEnd {
return nil, nil, fmt.Errorf("invalid packet: checksum incomplete")
}
checksum := binary.BigEndian.Uint16(data[checksumStart:checksumEnd])
header.Set("checksum", checksum)
// 验证校验和
calculatedChecksum := uint16(0)
for _, b := range messageData {
calculatedChecksum += uint16(b)
}
if checksum != calculatedChecksum {
return nil, nil, fmt.Errorf("checksum mismatch: expected %d, got %d", calculatedChecksum, checksum)
}
return header, messageData, nil
}
// Handle 处理消息
func (p *NNetProtocol) Handle(ctx context.Context, data []byte) ([]byte, error) {
// nnet协议的处理逻辑
// 这里可以添加协议特定的处理逻辑
return data, nil
}
// Unpacker 获取协议的拆包器
// nnet协议使用LengthFieldUnpacker来处理粘包拆包
func (p *NNetProtocol) Unpacker() unpackerpkg.Unpacker {
p.once.Do(func() {
// nnet协议格式[Magic(4)][Version(1)][Length(4)][Data(N)][Checksum(2)]
// 长度字段在偏移5的位置Magic 4字节 + Version 1字节
// 长度字段是4字节表示Data部分的长度
// 总长度 = 5Magic+Version + 4Length字段 + Length数据长度 + 2Checksum
config := unpackerpkg.LengthFieldUnpacker{
LengthFieldOffset: 5, // Magic(4) + Version(1) = 5
LengthFieldLength: 4, // Length字段是4字节
LengthAdjustment: 2, // 需要加上Checksum(2字节)
InitialBytesToStrip: 0, // 不跳过任何字节,保留完整包
}
p.unpacker = internalunpacker.NewLengthFieldUnpacker(config)
})
return p.unpacker
}
// DecodeHeader 解码帧头(增量解析,即使数据不完整)
// 实现IncrementalDecoder接口支持在数据不完整时解析帧头
func (p *NNetProtocol) DecodeHeader(data []byte) (protocolpkg.FrameHeader, int, error) {
// nnet协议帧头格式[Magic(4)][Version(1)][Length(4)]
// 完整的帧头需要9字节Magic(4) + Version(1) + Length(4)
minHeaderLength := 9
// 检查最小长度
if len(data) < 4 {
return nil, minHeaderLength, fmt.Errorf("invalid packet: too short for magic, need at least 4 bytes")
}
// 检查Magic
magic := data[0:4]
if string(magic) != "NNET" {
return nil, minHeaderLength, fmt.Errorf("invalid magic: %s", string(magic))
}
// 如果数据不足9字节返回需要的字节数
if len(data) < minHeaderLength {
return nil, minHeaderLength, nil
}
// 创建帧头
header := internalprotocol.NewFrameHeader()
header.Set("magic", string(magic))
// 读取版本
version := data[4]
header.Set("version", version)
// 注意:这里不检查版本值,因为版本识别由版本识别器或服务器逻辑处理
// 这样可以支持多版本协议
// 读取长度字段
dataLength := binary.BigEndian.Uint32(data[5:9])
header.Set("length", dataLength)
// 计算完整消息需要的总字节数
// 总长度 = 9帧头 + dataLength数据 + 2Checksum
totalLength := minHeaderLength + int(dataLength) + 2
// 返回解析的帧头和完整消息需要的总字节数
return header, totalLength, nil
}
// DecodeBody 解码消息体(假设帧头已经解析)
// 实现IncrementalDecoder接口支持重用已解析的帧头
// 优化如果header不为nil可以跳过帧头解析直接解析数据体和校验和避免重复解析帧头
func (p *NNetProtocol) DecodeBody(data []byte, header protocolpkg.FrameHeader) ([]byte, error) {
// 如果header为nil需要从data中解析帧头回退到标准Decode方法
if header == nil {
_, body, err := p.Decode(data)
return body, err
}
// 优化从header中获取长度避免重复读取长度字段
lengthVal := header.Get("length")
if lengthVal == nil {
return nil, fmt.Errorf("header missing length field")
}
dataLength, ok := lengthVal.(uint32)
if !ok {
return nil, fmt.Errorf("invalid length field type in header")
}
// 计算预期的总长度Magic(4) + Version(1) + Length(4) + Data(dataLength) + Checksum(2)
expectedTotalLength := 9 + int(dataLength) + 2
// 验证数据长度数据应该来自unpacker已经完整
if len(data) != expectedTotalLength {
if len(data) < expectedTotalLength {
return nil, fmt.Errorf("invalid data length: expected %d, got %d", expectedTotalLength, len(data))
}
// 如果数据长度大于预期,可能是多个包,但我们只处理第一个包
// 这种情况应该由unpacker处理不应该到达这里
}
// 读取数据部分偏移9到9+dataLength
// 优化直接使用header中的长度信息避免重复读取长度字段
messageDataStart := 9
messageDataEnd := 9 + int(dataLength)
messageData := data[messageDataStart:messageDataEnd]
// 读取校验和偏移9+dataLength到11+dataLength
checksumStart := messageDataEnd
checksumEnd := checksumStart + 2
checksum := binary.BigEndian.Uint16(data[checksumStart:checksumEnd])
// 将校验和设置到header中保持与Decode方法的一致性
header.Set("checksum", checksum)
// 验证校验和
calculatedChecksum := uint16(0)
for _, b := range messageData {
calculatedChecksum += uint16(b)
}
if checksum != calculatedChecksum {
return nil, fmt.Errorf("checksum mismatch: expected %d, got %d", calculatedChecksum, checksum)
}
return messageData, nil
}