You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

124 lines
3.9 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package nnet
import (
"testing"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNNetProtocol_DecodeHeader_IncompleteData feeds DecodeHeader buffers that are
// shorter than the 9-byte frame header (Magic(4) + Version(1) + Length(4)) and
// checks that no header is produced and that 9 is reported as the minimum number
// of bytes required.
func TestNNetProtocol_DecodeHeader_IncompleteData(t *testing.T) {
	p := NewNNetProtocol("1.0").(*NNetProtocol)

	tests := []struct {
		name    string
		data    []byte
		wantErr bool
	}{
		// Fewer bytes than the 4-byte magic "NNET" itself.
		// NOTE(review): this case expects an error while the longer partial
		// buffers below do not — confirm this asymmetry is intended by DecodeHeader.
		{name: "partial magic", data: []byte("NNE"), wantErr: true},
		// Magic only; Version and Length still missing.
		{name: "magic only", data: []byte("NNET")},
		// Magic + Version; Length still missing.
		{name: "magic and version", data: []byte("NNET\x01")},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			header, minBytes, err := p.DecodeHeader(tc.data)
			if tc.wantErr {
				assert.Error(t, err, "Expected error for incomplete data")
			} else {
				assert.NoError(t, err, "Should not error for partial data")
			}
			assert.Nil(t, header, "Should return nil header when data is incomplete")
			assert.Equal(t, 9, minBytes, "Expected min bytes to be 9")
		})
	}
}
// TestNNetProtocol_DecodeHeader_CompleteHeader decodes a hand-built, complete
// 9-byte frame header and checks the parsed fields plus the total frame length
// that DecodeHeader reports.
func TestNNetProtocol_DecodeHeader_CompleteHeader(t *testing.T) {
	proto := NewNNetProtocol("1.0").(*NNetProtocol)

	// Frame header layout: Magic(4) + Version(1) + Length(4, big-endian) = 9 bytes.
	const (
		headerLen   = 9
		bodyLen     = 5
		checksumLen = 2
	)
	frame := make([]byte, 0, headerLen)
	frame = append(frame, "NNET"...)
	frame = append(frame, 1)                // version
	frame = append(frame, 0, 0, 0, bodyLen) // payload length

	hdr, total, err := proto.DecodeHeader(frame)
	require.NoError(t, err, "Expected no error")
	require.NotNil(t, hdr, "Expected header")

	assert.Equal(t, "NNET", hdr.Get("magic"))
	assert.Equal(t, byte(1), hdr.Get("version"))
	assert.Equal(t, uint32(bodyLen), hdr.Get("length"))

	// Whole message = header(9) + payload(5) + checksum(2) = 16 bytes.
	assert.Equal(t, headerLen+bodyLen+checksumLen, total, "Expected total length to be 16")
}
// TestNNetProtocol_DecodeBody_WithHeader encodes a payload, then decodes it in two
// steps: DecodeHeader first, followed by DecodeBody reusing the parsed header.
func TestNNetProtocol_DecodeBody_WithHeader(t *testing.T) {
	proto := NewNNetProtocol("1.0").(*NNetProtocol)

	payload := []byte("hello")
	frame, err := proto.Encode(payload, nil)
	require.NoError(t, err)

	// Step 1: parse only the frame header.
	hdr, _, err := proto.DecodeHeader(frame)
	require.NoError(t, err)
	require.NotNil(t, hdr)

	// Step 2: decode the body, passing in the header parsed above so it is reused.
	got, err := proto.DecodeBody(frame, hdr)
	require.NoError(t, err)
	assert.Equal(t, payload, got, "Expected body to match original data")
}
// TestNNetProtocol_DecodeBody_WithoutHeader encodes a payload and decodes it with
// DecodeBody alone, passing a nil header so the header must be parsed from the
// frame bytes themselves.
func TestNNetProtocol_DecodeBody_WithoutHeader(t *testing.T) {
	proto := NewNNetProtocol("1.0").(*NNetProtocol)

	payload := []byte("hello")
	frame, err := proto.Encode(payload, nil)
	require.NoError(t, err)

	got, err := proto.DecodeBody(frame, nil) // nil header: parse it from frame
	require.NoError(t, err)
	assert.Equal(t, payload, got, "Expected body to match original data")
}
// incrementalDecoder is the two-step decoding method set these tests expect
// NNetProtocol to provide: DecodeHeader parses only the frame header, and
// DecodeBody extracts the payload, either reusing an already-parsed header or,
// when given nil, parsing the header from data itself.
type incrementalDecoder interface {
	// DecodeHeader parses the frame header at the start of data. Per the tests
	// in this file, it yields a nil header plus the minimum byte count while
	// data is incomplete, and the total frame length once the header is complete.
	DecodeHeader(data []byte) (header protocolpkg.FrameHeader, n int, err error)

	// DecodeBody returns the payload of the frame contained in data.
	DecodeBody(data []byte, header protocolpkg.FrameHeader) (body []byte, err error)
}
func TestNNetProtocol_IncrementalDecoder_Interface(t *testing.T) {
p := NewNNetProtocol("1.0")
// 验证NNetProtocol实现了IncrementalDecoder接口
var _ incrementalDecoder = p.(*NNetProtocol)
// 测试增量解码的完整流程
data := []byte("hello")
packet, err := p.Encode(data, nil)
require.NoError(t, err)
// 转换为IncrementalDecoder类型
incDecoder := p.(*NNetProtocol)
// 1. 先解析帧头(增量解析)
header, totalLength, err := incDecoder.DecodeHeader(packet)
require.NoError(t, err)
require.NotNil(t, header)
assert.Equal(t, len(packet), totalLength, "Expected total length to match packet length")
// 2. 使用已解析的帧头解析数据体(避免重复解析帧头)
bodyBytes, err := incDecoder.DecodeBody(packet, header)
require.NoError(t, err)
assert.Equal(t, data, bodyBytes, "Expected body to match original data")
}