package nnet

import (
	"context"
	"encoding/binary"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// makeIdentifierTestPacket builds a packet in the layout these tests exercise:
//
//	magic | version (1 byte) | payload length (4 bytes) | payload | checksum (2 bytes)
//
// The length is written big-endian, which reproduces the {0, 0, 0, 5} bytes
// used for a 5-byte payload. The checksum is a zero placeholder; the success
// cases pass with it, so Identify is not expected to validate it.
// NOTE(review): revisit if Identify ever starts verifying the checksum.
func makeIdentifierTestPacket(magic string, version byte, payload []byte) []byte {
	packet := make([]byte, 0, len(magic)+1+4+len(payload)+2)
	packet = append(packet, magic...)
	packet = append(packet, version)

	var length [4]byte
	binary.BigEndian.PutUint32(length[:], uint32(len(payload)))
	packet = append(packet, length[:]...)

	packet = append(packet, payload...)
	packet = append(packet, 0, 0) // checksum placeholder (simplified)
	return packet
}

// TestNNetVersionIdentifier_Identify checks that Identify maps the version byte
// of a well-formed packet to its configured version string, and that it rejects
// short input, a bad magic, and an unmapped version.
func TestNNetVersionIdentifier_Identify(t *testing.T) {
	identifier := NewNNetVersionIdentifier(map[byte]string{
		1: "1.0",
		2: "2.0",
		3: "3.0",
	})

	successCases := []struct {
		name    string
		version byte
		want    string
	}{
		{name: "Identify version 1.0", version: 1, want: "1.0"},
		{name: "Identify version 2.0", version: 2, want: "2.0"},
		{name: "Identify version 3.0", version: 3, want: "3.0"},
	}
	for _, tc := range successCases {
		t.Run(tc.name, func(t *testing.T) {
			data := makeIdentifierTestPacket("NNET", tc.version, []byte("hello"))

			version, err := identifier.Identify(data, context.Background())
			require.NoError(t, err)
			assert.Equal(t, tc.want, version)
		})
	}

	// Apart from the "Data too short" case, every failure input is a complete,
	// otherwise well-formed packet. That way the only thing wrong with it is
	// the condition the subtest names, and a length/size check cannot mask a
	// missing magic or version check.
	failureCases := []struct {
		name string
		data []byte
	}{
		// Only 3 bytes: fewer than the 5 bytes of magic + version.
		{name: "Data too short", data: []byte("NNE")},
		{name: "Invalid magic", data: makeIdentifierTestPacket("BAD!", 1, []byte("hello"))},
		{name: "Unknown version", data: makeIdentifierTestPacket("NNET", 99, []byte("hello"))},
	}
	for _, tc := range failureCases {
		t.Run(tc.name, func(t *testing.T) {
			version, err := identifier.Identify(tc.data, context.Background())
			require.Error(t, err)
			assert.Empty(t, version)
		})
	}
}

// TestNNetVersionIdentifier_DefaultMapping checks that an identifier built with
// a nil mapping uses its default mapping, under which version byte 1 identifies
// as "1.0".
func TestNNetVersionIdentifier_DefaultMapping(t *testing.T) {
	identifier := NewNNetVersionIdentifier(nil)

	data := makeIdentifierTestPacket("NNET", 1, []byte("hello"))

	version, err := identifier.Identify(data, context.Background())
	require.NoError(t, err)
	assert.Equal(t, "1.0", version)
}