You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

184 lines
4.4 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package request
import (
"context"
"testing"
"github.com/noahlann/nnet/pkg/protocol"
unpackerpkg "github.com/noahlann/nnet/pkg/unpacker"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRequest checks that Raw returns exactly the bytes the request was built from.
func TestRequest(t *testing.T) {
	payload := []byte("test data")
	r := New(payload, nil)

	assert.Equal(t, payload, r.Raw(), "Expected raw data to match")
}
func TestRequestBody(t *testing.T) {
req := New([]byte("test"), nil).(*requestImpl)
// 测试Data初始为nil
assert.Nil(t, req.Data(), "Expected data to be nil initially")
// 设置Data
req.SetData("test data")
assert.Equal(t, "test data", req.Data(), "Expected data to be set")
}
// TestRequestBind checks that Bind fills the target struct from the value
// previously stored with SetData.
func TestRequestBind(t *testing.T) {
	type TestStruct struct {
		Name  string `json:"name"`
		Value int    `json:"value"`
	}

	impl := New([]byte("test"), nil).(*requestImpl)

	// Bind is expected to work from Data, so populate it first.
	want := TestStruct{Name: "test", Value: 42}
	impl.SetData(want)

	var got TestStruct
	require.NoError(t, impl.Bind(&got), "Expected no error when binding")
	assert.Equal(t, want.Name, got.Name, "Expected name to match")
	assert.Equal(t, want.Value, got.Value, "Expected value to match")
}
// TestRequestBindNoData checks that Bind reports an error when no Data has been set.
func TestRequestBindNoData(t *testing.T) {
	var target struct {
		Name string `json:"name"`
	}
	r := New([]byte("test"), nil)

	err := r.Bind(&target)
	assert.Error(t, err, "Expected error when no data available")
}
// TestRequestBindWithBody checks that Bind uses Data even when the raw body
// is non-empty. If the raw JSON "{}" were decoded instead, both fields would
// stay at their zero values and the assertions below would fail.
func TestRequestBindWithBody(t *testing.T) {
	type TestStruct struct {
		Name  string `json:"name"`
		Value int    `json:"value"`
	}

	impl := New([]byte("{}"), nil).(*requestImpl)

	want := TestStruct{Name: "test", Value: 42}
	impl.SetData(want)

	var got TestStruct
	err := impl.Bind(&got)
	require.NoError(t, err, "Expected no error when binding")
	assert.Equal(t, want.Name, got.Name, "Expected name to match")
	assert.Equal(t, want.Value, got.Value, "Expected value to match")
}
// TestRequestHeader checks that the header of a request built without a
// protocol can be queried. A nil header is accepted as valid.
func TestRequestHeader(t *testing.T) {
	r := New([]byte("test"), nil)

	h := r.Header()
	if h == nil {
		// No frame header is a normal state for header-less protocols.
		return
	}
	// Only exercises Get; the returned value is intentionally ignored.
	_ = h.Get("test")
}
// TestRequestBodyBytes checks that DataBytes starts out nil and returns the
// slice stored by SetDataBytes.
func TestRequestBodyBytes(t *testing.T) {
	impl := New([]byte("test"), nil).(*requestImpl)

	assert.Nil(t, impl.DataBytes(), "Expected dataBytes to be nil initially")

	payload := []byte("test data bytes")
	impl.SetDataBytes(payload)
	assert.Equal(t, payload, impl.DataBytes(), "Expected dataBytes to match")
}
// TestRequestProtocolVersion checks that ProtocolVersion reports the attached
// protocol's version, and an empty string when the request has no protocol.
func TestRequestProtocolVersion(t *testing.T) {
	t.Run("with protocol", func(t *testing.T) {
		// Named proto (not mockProtocol) so the mockProtocol type is not shadowed.
		proto := &mockProtocol{}
		req := New([]byte("test"), proto)
		assert.Equal(t, "1.0", req.ProtocolVersion(), "Expected version to match")
	})

	t.Run("nil protocol", func(t *testing.T) {
		req := New([]byte("test"), nil)
		assert.Equal(t, "", req.ProtocolVersion(), "Expected empty version for nil protocol")
	})
}
// TestRequestSetHeader checks that a header installed with SetHeader is
// returned by Header with its fields intact.
func TestRequestSetHeader(t *testing.T) {
	impl := New([]byte("test"), nil).(*requestImpl)

	hdr := NewFrameHeader()
	hdr.Set("message_id", uint64(123))
	hdr.Set("type", "request")
	impl.SetHeader(hdr)

	got := impl.Header()
	// require (not assert) stops the test here on failure. Otherwise the Get
	// calls below would run on a nil header and panic instead of reporting a
	// clean failure.
	require.NotNil(t, got, "Expected header to be non-nil")
	assert.Equal(t, uint64(123), got.Get("message_id"), "Expected message ID to match")
	assert.Equal(t, "request", got.Get("type"), "Expected type to match")
}
// TestRequestProtocol checks that Protocol returns the same instance the
// request was constructed with.
func TestRequestProtocol(t *testing.T) {
	// Named proto (not mockProtocol) so the mockProtocol type is not shadowed.
	proto := &mockProtocol{}
	req := New([]byte("test"), proto).(*requestImpl)

	// Same asserts pointer identity rather than deep equality of the
	// pointed-to values, which is what "returns the protocol it was given" means.
	assert.Same(t, proto, req.Protocol(), "Expected protocol to match")
}
// mockProtocol is a minimal stub protocol for tests. It reports the name
// "test" and version "1.0", declares no frame header, and passes payloads
// through Encode, Decode and Handle unchanged.
type mockProtocol struct{}

// Name returns the fixed protocol name "test".
func (*mockProtocol) Name() string { return "test" }

// Version returns the fixed protocol version "1.0".
func (*mockProtocol) Version() string { return "1.0" }

// HasHeader reports false: the stub uses no frame header.
func (*mockProtocol) HasHeader() bool { return false }

// Encode returns data unchanged; header is ignored.
func (*mockProtocol) Encode(data []byte, header protocol.FrameHeader) ([]byte, error) {
	return data, nil
}

// Decode returns a nil header and data unchanged.
func (*mockProtocol) Decode(data []byte) (protocol.FrameHeader, []byte, error) {
	return nil, data, nil
}

// Handle echoes data back; ctx is ignored.
func (*mockProtocol) Handle(ctx context.Context, data []byte) ([]byte, error) {
	return data, nil
}

// Unpacker returns nil: the stub provides no unpacker.
func (*mockProtocol) Unpacker() unpackerpkg.Unpacker { return nil }