package nnet

import (
	"context"
	"fmt"

	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
)

// NNetVersionIdentifier identifies the version of an NNet protocol frame.
// It reads the one-byte version field located at offset 4 of the frame
// header, immediately after the 4-byte magic.
type NNetVersionIdentifier struct {
	// versionMapping translates the raw version byte into a version string,
	// for example {1: "1.0", 2: "2.0"}.
	versionMapping map[byte]string
}

// NewNNetVersionIdentifier creates a version identifier for the NNet protocol.
//
// versionMapping maps the raw version byte to its version string. When it is
// nil, the default mapping {1: "1.0", 2: "2.0"} is used. A non-nil map is
// stored as-is, without copying.
func NewNNetVersionIdentifier(versionMapping map[byte]string) protocolpkg.VersionIdentifier {
	identifier := &NNetVersionIdentifier{versionMapping: versionMapping}
	if identifier.versionMapping == nil {
		// No mapping supplied: fall back to the built-in defaults.
		identifier.versionMapping = map[byte]string{
			1: "1.0",
			2: "2.0",
		}
	}
	return identifier
}

// Identify extracts the protocol version from the start of an NNet frame.
//
// Frame layout: [Magic(4)][Version(1)][Length(4)][Data(N)][Checksum(2)].
// Only the magic and the version byte are inspected here; ctx is not used.
//
// It returns the version string mapped to the version byte. An error built
// with protocolpkg.NewError is returned when data is shorter than 5 bytes,
// when the first 4 bytes are not "NNET", or when the version byte has no
// entry in the mapping.
func (i *NNetVersionIdentifier) Identify(data []byte, ctx context.Context) (string, error) {
	const (
		magicLen  = 4            // the version byte sits right after the magic
		headerLen = magicLen + 1 // magic plus the version byte itself
	)

	// Guard the slicing and indexing below against truncated input.
	if got := len(data); got < headerLen {
		return "", protocolpkg.NewError(fmt.Sprintf("data too short: need at least %d bytes, got %d", headerLen, got))
	}

	// Reject frames that do not start with the NNet magic.
	if prefix := string(data[:magicLen]); prefix != "NNET" {
		return "", protocolpkg.NewError(fmt.Sprintf("invalid magic: %s", prefix))
	}

	// Translate the raw version byte through the configured mapping.
	raw := data[magicLen]
	if version, known := i.versionMapping[raw]; known {
		return version, nil
	}
	return "", protocolpkg.NewError(fmt.Sprintf("unknown version byte: %d", raw))
}