|
|
package unpacker
|
|
|
|
|
|
import (
|
|
|
"testing"
|
|
|
|
|
|
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
func TestFrameHeaderUnpacker(t *testing.T) {
|
|
|
config := unpackerpkg.FrameHeaderUnpacker{
|
|
|
HeaderLength: 4,
|
|
|
GetLength: nil, // 使用默认实现
|
|
|
MaxBufferSize: 1024,
|
|
|
}
|
|
|
u := NewFrameHeaderUnpacker(config)
|
|
|
|
|
|
// 创建测试数据:4字节头部(前4字节是长度)+ 数据
|
|
|
data := make([]byte, 8)
|
|
|
data[0] = 0
|
|
|
data[1] = 0
|
|
|
data[2] = 0
|
|
|
data[3] = 4 // 长度 = 4
|
|
|
copy(data[4:], []byte("test"))
|
|
|
|
|
|
messages, _, consumed, err := u.Unpack(data)
|
|
|
require.NoError(t, err, "Expected no error")
|
|
|
assert.Equal(t, 1, len(messages), "Expected 1 message")
|
|
|
assert.Equal(t, 8, len(messages[0]), "Expected 8 bytes message")
|
|
|
assert.Equal(t, len(data), consumed, "Expected consumed equals input data length")
|
|
|
}
|
|
|
|
|
|
func TestFrameHeaderUnpackerCustomGetLength(t *testing.T) {
|
|
|
config := unpackerpkg.FrameHeaderUnpacker{
|
|
|
HeaderLength: 2,
|
|
|
GetLength: func(header []byte) int {
|
|
|
if len(header) < 2 {
|
|
|
return 0
|
|
|
}
|
|
|
return int(header[0])<<8 | int(header[1])
|
|
|
},
|
|
|
MaxBufferSize: 1024,
|
|
|
}
|
|
|
u := NewFrameHeaderUnpacker(config)
|
|
|
|
|
|
// 创建测试数据:2字节头部(大端序)+ 数据
|
|
|
data := make([]byte, 6)
|
|
|
data[0] = 0
|
|
|
data[1] = 4 // 长度 = 4
|
|
|
copy(data[2:], []byte("test"))
|
|
|
|
|
|
messages, remaining, consumed, err := u.Unpack(data)
|
|
|
require.NoError(t, err, "Expected no error")
|
|
|
assert.Equal(t, 1, len(messages), "Expected 1 message")
|
|
|
assert.Equal(t, 6, len(messages[0]), "Expected 6 bytes message")
|
|
|
assert.Equal(t, 0, len(remaining), "Expected no remaining data")
|
|
|
assert.Equal(t, len(data), consumed, "Expected consumed equals input data length")
|
|
|
}
|
|
|
|
|
|
func TestFrameHeaderUnpackerPack(t *testing.T) {
|
|
|
config := unpackerpkg.FrameHeaderUnpacker{
|
|
|
HeaderLength: 4,
|
|
|
GetLength: nil,
|
|
|
MaxBufferSize: 1024,
|
|
|
}
|
|
|
u := NewFrameHeaderUnpacker(config)
|
|
|
|
|
|
// 测试打包
|
|
|
data := []byte("test")
|
|
|
packed, err := u.Pack(data)
|
|
|
require.NoError(t, err, "Expected no error")
|
|
|
assert.Equal(t, 8, len(packed), "Expected 8 bytes")
|
|
|
assert.Equal(t, byte(4), packed[3], "Expected length 4")
|
|
|
assert.Equal(t, "test", string(packed[4:]), "Expected 'test'")
|
|
|
}
|
|
|
|
|
|
func TestFrameHeaderUnpackerIncompleteHeader(t *testing.T) {
|
|
|
config := unpackerpkg.FrameHeaderUnpacker{
|
|
|
HeaderLength: 4,
|
|
|
GetLength: nil,
|
|
|
MaxBufferSize: 1024,
|
|
|
}
|
|
|
u := NewFrameHeaderUnpacker(config)
|
|
|
|
|
|
// 测试不完整的头部
|
|
|
data := []byte{0, 0, 0} // 只有3字节,不够4字节头部
|
|
|
messages, remaining, consumed, err := u.Unpack(data)
|
|
|
require.NoError(t, err, "Expected no error")
|
|
|
assert.Equal(t, 0, len(messages), "Expected 0 messages")
|
|
|
assert.Equal(t, 3, len(remaining), "Expected 3 bytes remaining")
|
|
|
assert.Equal(t, len(data), consumed, "Expected all data consumed when no complete message")
|
|
|
}
|
|
|
|
|
|
func TestFrameHeaderUnpackerBufferSizeLimit(t *testing.T) {
|
|
|
config := unpackerpkg.FrameHeaderUnpacker{
|
|
|
HeaderLength: 4,
|
|
|
GetLength: nil,
|
|
|
MaxBufferSize: 10, // 小缓冲区
|
|
|
}
|
|
|
u := NewFrameHeaderUnpacker(config)
|
|
|
|
|
|
// 创建超过缓冲区大小的数据
|
|
|
data := make([]byte, 20)
|
|
|
data[0] = 0
|
|
|
data[1] = 0
|
|
|
data[2] = 0
|
|
|
data[3] = 20 // 长度 = 20
|
|
|
|
|
|
_, _, _, err := u.Unpack(data)
|
|
|
assert.Error(t, err, "Expected error for buffer size exceeded")
|
|
|
}
|
|
|
|