|
|
package nnet
|
|
|
|
|
|
import (
|
|
|
"testing"
|
|
|
|
|
|
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
func TestNNetProtocol_DecodeHeader_IncompleteData(t *testing.T) {
|
|
|
p := NewNNetProtocol("1.0").(*NNetProtocol)
|
|
|
|
|
|
// 测试数据不完整的情况(只有Magic)
|
|
|
data := []byte("NNE")
|
|
|
header, minBytes, err := p.DecodeHeader(data)
|
|
|
assert.Error(t, err, "Expected error for incomplete data")
|
|
|
assert.Nil(t, header)
|
|
|
assert.Equal(t, 9, minBytes, "Expected min bytes to be 9")
|
|
|
|
|
|
// 测试数据不完整的情况(只有Magic)
|
|
|
data = []byte("NNET")
|
|
|
header, minBytes, err = p.DecodeHeader(data)
|
|
|
assert.NoError(t, err, "Should not error for partial data")
|
|
|
assert.Nil(t, header, "Should return nil header when data is incomplete")
|
|
|
assert.Equal(t, 9, minBytes, "Expected min bytes to be 9")
|
|
|
|
|
|
// 测试数据不完整的情况(Magic + Version)
|
|
|
data = []byte("NNET\x01")
|
|
|
header, minBytes, err = p.DecodeHeader(data)
|
|
|
assert.NoError(t, err, "Should not error for partial data")
|
|
|
assert.Nil(t, header, "Should return nil header when data is incomplete")
|
|
|
assert.Equal(t, 9, minBytes, "Expected min bytes to be 9")
|
|
|
}
|
|
|
|
|
|
func TestNNetProtocol_DecodeHeader_CompleteHeader(t *testing.T) {
|
|
|
p := NewNNetProtocol("1.0").(*NNetProtocol)
|
|
|
|
|
|
// 创建完整的帧头数据:Magic(4) + Version(1) + Length(4) = 9字节
|
|
|
data := make([]byte, 9)
|
|
|
copy(data[0:4], []byte("NNET"))
|
|
|
data[4] = 1 // Version
|
|
|
// Length = 5 (数据部分长度)
|
|
|
data[5] = 0
|
|
|
data[6] = 0
|
|
|
data[7] = 0
|
|
|
data[8] = 5
|
|
|
|
|
|
header, totalLength, err := p.DecodeHeader(data)
|
|
|
require.NoError(t, err, "Expected no error")
|
|
|
require.NotNil(t, header, "Expected header")
|
|
|
|
|
|
assert.Equal(t, "NNET", header.Get("magic"))
|
|
|
assert.Equal(t, byte(1), header.Get("version"))
|
|
|
assert.Equal(t, uint32(5), header.Get("length"))
|
|
|
|
|
|
// 完整消息长度 = 9(帧头) + 5(数据) + 2(校验和) = 16
|
|
|
assert.Equal(t, 16, totalLength, "Expected total length to be 16")
|
|
|
}
|
|
|
|
|
|
func TestNNetProtocol_DecodeBody_WithHeader(t *testing.T) {
|
|
|
p := NewNNetProtocol("1.0").(*NNetProtocol)
|
|
|
|
|
|
// 创建完整的消息数据
|
|
|
data := []byte("hello")
|
|
|
packet, err := p.Encode(data, nil)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 先解析帧头
|
|
|
header, _, err := p.DecodeHeader(packet)
|
|
|
require.NoError(t, err)
|
|
|
require.NotNil(t, header)
|
|
|
|
|
|
// 使用DecodeBody解析数据体(重用已解析的帧头)
|
|
|
bodyBytes, err := p.DecodeBody(packet, header)
|
|
|
require.NoError(t, err)
|
|
|
assert.Equal(t, data, bodyBytes, "Expected body to match original data")
|
|
|
}
|
|
|
|
|
|
func TestNNetProtocol_DecodeBody_WithoutHeader(t *testing.T) {
|
|
|
p := NewNNetProtocol("1.0").(*NNetProtocol)
|
|
|
|
|
|
// 创建完整的消息数据
|
|
|
data := []byte("hello")
|
|
|
packet, err := p.Encode(data, nil)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 使用DecodeBody解析数据体(header为nil,需要从data中解析)
|
|
|
bodyBytes, err := p.DecodeBody(packet, nil)
|
|
|
require.NoError(t, err)
|
|
|
assert.Equal(t, data, bodyBytes, "Expected body to match original data")
|
|
|
}
|
|
|
|
|
|
type incrementalDecoder interface {
|
|
|
DecodeHeader([]byte) (protocolpkg.FrameHeader, int, error)
|
|
|
DecodeBody([]byte, protocolpkg.FrameHeader) ([]byte, error)
|
|
|
}
|
|
|
|
|
|
func TestNNetProtocol_IncrementalDecoder_Interface(t *testing.T) {
|
|
|
p := NewNNetProtocol("1.0")
|
|
|
|
|
|
// 验证NNetProtocol实现了IncrementalDecoder接口
|
|
|
var _ incrementalDecoder = p.(*NNetProtocol)
|
|
|
|
|
|
// 测试增量解码的完整流程
|
|
|
data := []byte("hello")
|
|
|
packet, err := p.Encode(data, nil)
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
// 转换为IncrementalDecoder类型
|
|
|
incDecoder := p.(*NNetProtocol)
|
|
|
|
|
|
// 1. 先解析帧头(增量解析)
|
|
|
header, totalLength, err := incDecoder.DecodeHeader(packet)
|
|
|
require.NoError(t, err)
|
|
|
require.NotNil(t, header)
|
|
|
assert.Equal(t, len(packet), totalLength, "Expected total length to match packet length")
|
|
|
|
|
|
// 2. 使用已解析的帧头解析数据体(避免重复解析帧头)
|
|
|
bodyBytes, err := incDecoder.DecodeBody(packet, header)
|
|
|
require.NoError(t, err)
|
|
|
assert.Equal(t, data, bodyBytes, "Expected body to match original data")
|
|
|
}
|