You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

190 lines
5.8 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package server
import (
"context"
"testing"
"time"
"github.com/noahlann/nnet/internal/codec"
codecpkg "github.com/noahlann/nnet/pkg/codec"
ctxpkg "github.com/noahlann/nnet/pkg/context"
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
requestpkg "github.com/noahlann/nnet/pkg/request"
responsepkg "github.com/noahlann/nnet/pkg/response"
)
// --- minimal fakes for Context/Request/Header

// fakeHeader is an in-memory requestpkg.FrameHeader. Values live in a plain
// map that is allocated on the first Set; Encode/Decode are no-ops.
type fakeHeader struct{ m map[string]interface{} }

// Get returns the value stored under key, or nil when the key is absent
// (indexing a nil map is safe in Go).
func (h *fakeHeader) Get(key string) interface{} {
	val := h.m[key]
	return val
}

// Set stores value under key, creating the backing map on first use.
func (h *fakeHeader) Set(key string, value interface{}) {
	if h.m != nil {
		h.m[key] = value
		return
	}
	h.m = map[string]interface{}{key: value}
}

// Encode always produces a nil payload and no error.
func (h *fakeHeader) Encode() ([]byte, error) {
	var payload []byte
	return payload, nil
}

// Decode ignores its input and always succeeds.
func (h *fakeHeader) Decode([]byte) error {
	return nil
}

// Clone returns a new fakeHeader holding a shallow copy of the key/value map.
func (h *fakeHeader) Clone() requestpkg.FrameHeader {
	dup := make(map[string]interface{}, len(h.m))
	for key, val := range h.m {
		dup[key] = val
	}
	return &fakeHeader{m: dup}
}
// fakeRequest is a requestpkg.Request stub that only carries a frame header.
// Every other accessor returns its zero value, and Bind always succeeds.
type fakeRequest struct {
	header requestpkg.FrameHeader
}

// Header returns the header the fake was constructed with.
func (fr *fakeRequest) Header() requestpkg.FrameHeader {
	return fr.header
}

func (fr *fakeRequest) Raw() []byte {
	return nil
}

func (fr *fakeRequest) Data() interface{} {
	return nil
}

func (fr *fakeRequest) Bind(_ interface{}) error {
	return nil
}

func (fr *fakeRequest) DataBytes() []byte {
	return nil
}

func (fr *fakeRequest) Protocol() protocolpkg.Protocol {
	return nil
}

func (fr *fakeRequest) ProtocolVersion() string {
	return ""
}
// fakeConn is a no-op connection used as the ctxpkg.Connection of fakeCtx.
// It reports empty identifiers and addresses; Write and Close always succeed
// without doing anything.
type fakeConn struct{}

func (*fakeConn) ID() string {
	return ""
}

func (*fakeConn) RemoteAddr() string {
	return ""
}

func (*fakeConn) LocalAddr() string {
	return ""
}

func (*fakeConn) Write(_ []byte) error {
	return nil
}

func (*fakeConn) Close() error {
	return nil
}
// fakeCtx is a minimal ctxpkg.Context test double. Its context.Context methods
// forward to the ctx field (context.Background() when ctx is nil), Request and
// Connection return the configured fakes, and Set/Get plus the typed getters
// use a lazily allocated key/value map.
type fakeCtx struct {
	ctx  context.Context
	req  requestpkg.Request
	vals map[string]interface{}
	conn ctxpkg.Connection
}

// parent returns the wrapped context, or context.Background() when none was
// set, so the forwarding methods below never call through a nil interface.
func (f *fakeCtx) parent() context.Context {
	if f.ctx != nil {
		return f.ctx
	}
	return context.Background()
}

// Deadline, Done, Err and Value forward to the wrapped context so a test can
// supply a cancellable or value-carrying context and have it observed.
func (f *fakeCtx) Deadline() (time.Time, bool) { return f.parent().Deadline() }
func (f *fakeCtx) Done() <-chan struct{} { return f.parent().Done() }
func (f *fakeCtx) Err() error { return f.parent().Err() }
func (f *fakeCtx) Value(key interface{}) interface{} { return f.parent().Value(key) }

// Request returns the configured fake request.
func (f *fakeCtx) Request() requestpkg.Request { return f.req }

// Response always returns nil.
func (f *fakeCtx) Response() responsepkg.Response { return nil }

// Connection returns the configured fake connection.
func (f *fakeCtx) Connection() ctxpkg.Connection { return f.conn }

// Set stores v under k, allocating the map on first use.
func (f *fakeCtx) Set(k string, v interface{}) {
	if f.vals == nil {
		f.vals = map[string]interface{}{}
	}
	f.vals[k] = v
}

// Get returns the value stored under k, or nil if absent.
func (f *fakeCtx) Get(k string) interface{} { return f.vals[k] }

// MustGet returns the value stored under k; this fake returns nil for a
// missing key rather than panicking.
func (f *fakeCtx) MustGet(k string) interface{} { return f.vals[k] }

// GetString returns the value under k if it is a string, else "".
func (f *fakeCtx) GetString(k string) string {
	if v, ok := f.vals[k].(string); ok {
		return v
	}
	return ""
}

// GetInt returns the value under k if it is an int, else 0.
func (f *fakeCtx) GetInt(k string) int {
	if v, ok := f.vals[k].(int); ok {
		return v
	}
	return 0
}

// GetBool returns the value under k if it is a bool, else false.
func (f *fakeCtx) GetBool(k string) bool {
	if v, ok := f.vals[k].(bool); ok {
		return v
	}
	return false
}
// newCtxWithHeaderCodec builds a fake context around context.Background()
// whose request header maps the key "codec" to name, with a no-op connection.
func newCtxWithHeaderCodec(name string) ctxpkg.Context {
	hdr := &fakeHeader{}
	hdr.Set("codec", name)
	return &fakeCtx{
		ctx:  context.Background(),
		req:  &fakeRequest{header: hdr},
		conn: &fakeConn{},
	}
}
// Resolver implementations used by the tests below.

// testHeaderResolver selects a codec by the name stored under the "codec"
// header key.
type testHeaderResolver struct{}

// ResolveForDecode reads the codec name from the request header carried by
// ctx (the header argument is not consulted). It returns (nil, nil) when ctx,
// its request, or the request header is nil, or when the "codec" value is not
// a non-empty string; otherwise it returns registry.Get(name).
func (r *testHeaderResolver) ResolveForDecode(registry codecpkg.Registry, ctx ctxpkg.Context, raw []byte, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
	if ctx == nil || ctx.Request() == nil {
		return nil, nil
	}
	reqHeader := ctx.Request().Header()
	if reqHeader == nil {
		return nil, nil
	}
	codecName, _ := reqHeader.Get("codec").(string)
	if codecName == "" {
		return nil, nil
	}
	return registry.Get(codecName)
}

// ResolveForEncode reads the codec name from the header argument. It returns
// (nil, nil) when header is nil or its "codec" value is not a non-empty string.
func (r *testHeaderResolver) ResolveForEncode(registry codecpkg.Registry, ctx ctxpkg.Context, data interface{}, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
	if header == nil {
		return nil, nil
	}
	codecName, _ := header.Get("codec").(string)
	if codecName == "" {
		return nil, nil
	}
	return registry.Get(codecName)
}
// testRawDataResolver sniffs the raw payload and selects the "json" codec when
// it looks like a JSON object or array.
type testRawDataResolver struct{}

// ResolveForDecode trims JSON insignificant whitespace (space, tab, LF, CR)
// from both ends of raw. It returns registry.Get("json") when what remains
// starts with '{' and ends with '}', or starts with '[' and ends with ']'.
// Otherwise it returns (nil, nil).
func (r *testRawDataResolver) ResolveForDecode(registry codecpkg.Registry, ctx ctxpkg.Context, raw []byte, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
	// Trim first: encoding/json's Encoder.Encode appends '\n' after each value,
	// so encoder-produced payloads would otherwise fail the bracket check.
	isJSONSpace := func(b byte) bool {
		return b == ' ' || b == '\t' || b == '\n' || b == '\r'
	}
	lo, hi := 0, len(raw)
	for lo < hi && isJSONSpace(raw[lo]) {
		lo++
	}
	for hi > lo && isJSONSpace(raw[hi-1]) {
		hi--
	}
	body := raw[lo:hi]
	if len(body) < 2 {
		return nil, nil
	}
	first, last := body[0], body[len(body)-1]
	if (first == '{' && last == '}') || (first == '[' && last == ']') {
		return registry.Get("json")
	}
	return nil, nil
}

// ResolveForEncode always returns (nil, nil): there is no raw payload to
// inspect when encoding.
func (r *testRawDataResolver) ResolveForEncode(registry codecpkg.Registry, ctx ctxpkg.Context, data interface{}, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
	return nil, nil
}
// TestResolveCodec_HeaderPreferred checks that, with a header-based resolver
// added and an empty default codec name, the chain returns the codec named by
// the request's "codec" header ("json").
func TestResolveCodec_HeaderPreferred(t *testing.T) {
	reg := codec.NewRegistry()
	if err := reg.Register("json", codec.NewJSONCodec()); err != nil {
		t.Fatalf("register json codec: %v", err)
	}
	chain := codecpkg.NewResolverChain(reg, "")
	// Add the header-based test resolver.
	chain.Add(&testHeaderResolver{})

	ctx := newCtxWithHeaderCodec("json")
	c, err := chain.ResolveForDecode(ctx, nil, ctx.Request().Header())
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	// Guard against a nil codec so a failure is reported instead of panicking.
	if c == nil {
		t.Fatal("want json, got nil codec")
	}
	if c.Name() != "json" {
		t.Fatalf("want json, got %s", c.Name())
	}
}
// TestResolveCodec_DefaultFallback checks that a chain with no resolvers and
// an empty registry falls back to the default codec name it was created with
// ("binary") when given no context, payload, or header.
func TestResolveCodec_DefaultFallback(t *testing.T) {
	reg := codec.NewRegistry()
	chain := codecpkg.NewResolverChain(reg, "binary")

	c, err := chain.ResolveForDecode(nil, nil, nil)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	// Guard against a nil codec so a failure is reported instead of panicking.
	if c == nil {
		t.Fatal("want binary, got nil codec")
	}
	if c.Name() != "binary" {
		t.Fatalf("want binary, got %s", c.Name())
	}
}
// TestResolveCodec_JSONHeuristic checks that, with the raw-data sniffing
// resolver added and an empty default codec name, a payload shaped like a JSON
// object resolves to the "json" codec.
func TestResolveCodec_JSONHeuristic(t *testing.T) {
	reg := codec.NewRegistry()
	if err := reg.Register("json", codec.NewJSONCodec()); err != nil {
		t.Fatalf("register json codec: %v", err)
	}
	chain := codecpkg.NewResolverChain(reg, "")
	// Add the raw-data sniffing test resolver.
	chain.Add(&testRawDataResolver{})

	raw := []byte(`{"a":1}`)
	c, err := chain.ResolveForDecode(nil, raw, nil)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	// Guard against a nil codec so a failure is reported instead of panicking.
	if c == nil {
		t.Fatal("want json by heuristic, got nil codec")
	}
	if c.Name() != "json" {
		t.Fatalf("want json by heuristic, got %s", c.Name())
	}
}