package nnet

import (
	"encoding/binary"
	"testing"

	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// incrementalDecoder is the two-phase decoding API exercised by these tests:
// DecodeHeader parses the frame header and reports a byte count, and
// DecodeBody extracts the payload, optionally reusing an already-parsed header.
type incrementalDecoder interface {
	DecodeHeader([]byte) (protocolpkg.FrameHeader, int, error)
	DecodeBody([]byte, protocolpkg.FrameHeader) ([]byte, error)
}

// Compile-time check that *NNetProtocol satisfies incrementalDecoder.
var _ incrementalDecoder = (*NNetProtocol)(nil)

// TestNNetProtocol_DecodeHeader_IncompleteData feeds DecodeHeader inputs that
// are shorter than the 9-byte frame header (Magic(4) + Version(1) + Length(4)).
// Every case must yield a nil header and report 9 as the minimum byte count.
func TestNNetProtocol_DecodeHeader_IncompleteData(t *testing.T) {
	p := NewNNetProtocol("1.0").(*NNetProtocol)

	tests := []struct {
		name    string
		data    []byte
		wantErr bool
	}{
		// Fewer bytes than the magic itself: an error is expected here.
		{name: "partial magic", data: []byte("NNE"), wantErr: true},
		// Complete magic but truncated header: "need more data", not an error.
		{name: "magic only", data: []byte("NNET")},
		{name: "magic and version", data: []byte("NNET\x01")},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			header, minBytes, err := p.DecodeHeader(tt.data)
			if tt.wantErr {
				assert.Error(t, err, "Expected error for incomplete data")
			} else {
				assert.NoError(t, err, "Should not error for partial data")
			}
			assert.Nil(t, header, "Should return nil header when data is incomplete")
			assert.Equal(t, 9, minBytes, "Expected min bytes to be 9")
		})
	}
}

// TestNNetProtocol_DecodeHeader_CompleteHeader checks that a full 9-byte frame
// header is parsed into its fields and that the reported total length covers
// header, payload and checksum.
func TestNNetProtocol_DecodeHeader_CompleteHeader(t *testing.T) {
	p := NewNNetProtocol("1.0").(*NNetProtocol)

	// Frame header: Magic(4) + Version(1) + Length(4) = 9 bytes.
	const payloadLen = 5
	data := make([]byte, 9)
	copy(data[0:4], "NNET")
	data[4] = 1 // Version
	// Payload length, encoded big-endian.
	binary.BigEndian.PutUint32(data[5:9], payloadLen)

	header, totalLength, err := p.DecodeHeader(data)
	require.NoError(t, err, "Expected no error")
	require.NotNil(t, header, "Expected header")

	assert.Equal(t, "NNET", header.Get("magic"))
	assert.Equal(t, byte(1), header.Get("version"))
	assert.Equal(t, uint32(payloadLen), header.Get("length"))

	// Full message length = 9 (header) + 5 (payload) + 2 (checksum) = 16.
	assert.Equal(t, 16, totalLength, "Expected total length to be 16")
}

// TestNNetProtocol_DecodeBody_WithHeader round-trips a payload through Encode,
// then decodes the body while reusing a header parsed by DecodeHeader.
func TestNNetProtocol_DecodeBody_WithHeader(t *testing.T) {
	p := NewNNetProtocol("1.0").(*NNetProtocol)

	data := []byte("hello")
	packet, err := p.Encode(data, nil)
	require.NoError(t, err)

	// Parse the frame header first.
	header, _, err := p.DecodeHeader(packet)
	require.NoError(t, err)
	require.NotNil(t, header)

	// Decode the body, reusing the already-parsed header.
	bodyBytes, err := p.DecodeBody(packet, header)
	require.NoError(t, err)
	assert.Equal(t, data, bodyBytes, "Expected body to match original data")
}

// TestNNetProtocol_DecodeBody_WithoutHeader checks that DecodeBody parses the
// header itself when it is given a nil header.
func TestNNetProtocol_DecodeBody_WithoutHeader(t *testing.T) {
	p := NewNNetProtocol("1.0").(*NNetProtocol)

	data := []byte("hello")
	packet, err := p.Encode(data, nil)
	require.NoError(t, err)

	// A nil header means DecodeBody must parse the header from packet.
	bodyBytes, err := p.DecodeBody(packet, nil)
	require.NoError(t, err)
	assert.Equal(t, data, bodyBytes, "Expected body to match original data")
}

// TestNNetProtocol_IncrementalDecoder_Interface runs the full incremental
// decoding flow (header first, then body) through the incrementalDecoder
// interface rather than the concrete type.
func TestNNetProtocol_IncrementalDecoder_Interface(t *testing.T) {
	p := NewNNetProtocol("1.0")

	data := []byte("hello")
	packet, err := p.Encode(data, nil)
	require.NoError(t, err)

	np, ok := p.(*NNetProtocol)
	require.True(t, ok, "NewNNetProtocol should return *NNetProtocol")

	// Call through the interface so the test exercises the decoder contract,
	// not just the concrete methods.
	var dec incrementalDecoder = np

	// 1. Parse the frame header (incremental step).
	header, totalLength, err := dec.DecodeHeader(packet)
	require.NoError(t, err)
	require.NotNil(t, header)
	assert.Equal(t, len(packet), totalLength, "Expected total length to match packet length")

	// 2. Decode the body with the parsed header, avoiding a second header parse.
	bodyBytes, err := dec.DecodeBody(packet, header)
	require.NoError(t, err)
	assert.Equal(t, data, bodyBytes, "Expected body to match original data")
}