|
|
package nnet
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"testing"
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
"github.com/stretchr/testify/require"
|
|
|
)
|
|
|
|
|
|
func TestNNetVersionIdentifier_Identify(t *testing.T) {
|
|
|
// 创建版本识别器
|
|
|
identifier := NewNNetVersionIdentifier(map[byte]string{
|
|
|
1: "1.0",
|
|
|
2: "2.0",
|
|
|
3: "3.0",
|
|
|
})
|
|
|
|
|
|
// 测试版本1识别
|
|
|
t.Run("Identify version 1.0", func(t *testing.T) {
|
|
|
// 创建版本1的数据包(Magic + Version(1) + Length(4) + Data + Checksum)
|
|
|
data := []byte("NNET")
|
|
|
data = append(data, byte(1)) // version
|
|
|
data = append(data, []byte{0, 0, 0, 5}...) // length
|
|
|
data = append(data, []byte("hello")...) // data
|
|
|
data = append(data, []byte{0, 0}...) // checksum (简化)
|
|
|
|
|
|
version, err := identifier.Identify(data, context.Background())
|
|
|
require.NoError(t, err)
|
|
|
assert.Equal(t, "1.0", version)
|
|
|
})
|
|
|
|
|
|
// 测试版本2识别
|
|
|
t.Run("Identify version 2.0", func(t *testing.T) {
|
|
|
data := []byte("NNET")
|
|
|
data = append(data, byte(2)) // version
|
|
|
data = append(data, []byte{0, 0, 0, 5}...) // length
|
|
|
data = append(data, []byte("hello")...) // data
|
|
|
data = append(data, []byte{0, 0}...) // checksum (简化)
|
|
|
|
|
|
version, err := identifier.Identify(data, context.Background())
|
|
|
require.NoError(t, err)
|
|
|
assert.Equal(t, "2.0", version)
|
|
|
})
|
|
|
|
|
|
// 测试数据不足
|
|
|
t.Run("Data too short", func(t *testing.T) {
|
|
|
data := []byte("NNE") // 只有3字节,不足5字节
|
|
|
|
|
|
version, err := identifier.Identify(data, context.Background())
|
|
|
require.Error(t, err)
|
|
|
assert.Empty(t, version)
|
|
|
})
|
|
|
|
|
|
// 测试无效Magic
|
|
|
t.Run("Invalid magic", func(t *testing.T) {
|
|
|
data := []byte("BAD!")
|
|
|
data = append(data, byte(1)) // version
|
|
|
|
|
|
version, err := identifier.Identify(data, context.Background())
|
|
|
require.Error(t, err)
|
|
|
assert.Empty(t, version)
|
|
|
})
|
|
|
|
|
|
// 测试未知版本
|
|
|
t.Run("Unknown version", func(t *testing.T) {
|
|
|
data := []byte("NNET")
|
|
|
data = append(data, byte(99)) // unknown version
|
|
|
data = append(data, []byte{0, 0, 0, 5}...) // length
|
|
|
|
|
|
version, err := identifier.Identify(data, context.Background())
|
|
|
require.Error(t, err)
|
|
|
assert.Empty(t, version)
|
|
|
})
|
|
|
}
|
|
|
|
|
|
func TestNNetVersionIdentifier_DefaultMapping(t *testing.T) {
|
|
|
// 测试默认映射
|
|
|
identifier := NewNNetVersionIdentifier(nil)
|
|
|
|
|
|
data := []byte("NNET")
|
|
|
data = append(data, byte(1)) // version
|
|
|
data = append(data, []byte{0, 0, 0, 5}...) // length
|
|
|
data = append(data, []byte("hello")...) // data
|
|
|
data = append(data, []byte{0, 0}...) // checksum (简化)
|
|
|
|
|
|
version, err := identifier.Identify(data, context.Background())
|
|
|
require.NoError(t, err)
|
|
|
assert.Equal(t, "1.0", version)
|
|
|
}
|
|
|
|