|
|
package server
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/internal/codec"
|
|
|
codecpkg "github.com/noahlann/nnet/pkg/codec"
|
|
|
ctxpkg "github.com/noahlann/nnet/pkg/context"
|
|
|
protocolpkg "github.com/noahlann/nnet/pkg/protocol"
|
|
|
requestpkg "github.com/noahlann/nnet/pkg/request"
|
|
|
responsepkg "github.com/noahlann/nnet/pkg/response"
|
|
|
)
|
|
|
|
|
|
// --- minimal fakes for Context, Request, Header and Connection
|
|
|
|
|
|
type fakeHeader struct{ m map[string]interface{} }
|
|
|
|
|
|
func (h *fakeHeader) Get(key string) interface{} { return h.m[key] }
|
|
|
func (h *fakeHeader) Set(key string, value interface{}) {
|
|
|
if h.m == nil {
|
|
|
h.m = map[string]interface{}{}
|
|
|
}
|
|
|
h.m[key] = value
|
|
|
}
|
|
|
func (h *fakeHeader) Encode() ([]byte, error) { return nil, nil }
|
|
|
func (h *fakeHeader) Decode([]byte) error { return nil }
|
|
|
func (h *fakeHeader) Clone() requestpkg.FrameHeader {
|
|
|
c := &fakeHeader{m: make(map[string]interface{}, len(h.m))}
|
|
|
for k, v := range h.m {
|
|
|
c.m[k] = v
|
|
|
}
|
|
|
return c
|
|
|
}
|
|
|
|
|
|
type fakeRequest struct{ header requestpkg.FrameHeader }
|
|
|
|
|
|
func (r *fakeRequest) Raw() []byte { return nil }
|
|
|
func (r *fakeRequest) Data() interface{} { return nil }
|
|
|
func (r *fakeRequest) Bind(v interface{}) error { return nil }
|
|
|
func (r *fakeRequest) Header() requestpkg.FrameHeader { return r.header }
|
|
|
func (r *fakeRequest) DataBytes() []byte { return nil }
|
|
|
func (r *fakeRequest) Protocol() protocolpkg.Protocol { return nil }
|
|
|
func (r *fakeRequest) ProtocolVersion() string { return "" }
|
|
|
|
|
|
// fakeConn is a no-op connection: every address/identifier accessor returns
// an empty string, and Write and Close always succeed without side effects.
type fakeConn struct{}

// ID returns an empty connection identifier.
func (*fakeConn) ID() string {
	return ""
}

// RemoteAddr returns an empty remote address.
func (*fakeConn) RemoteAddr() string {
	return ""
}

// LocalAddr returns an empty local address.
func (*fakeConn) LocalAddr() string {
	return ""
}

// Write discards its input and reports success.
func (*fakeConn) Write(_ []byte) error {
	return nil
}

// Close reports success; there is nothing to release.
func (*fakeConn) Close() error {
	return nil
}
|
|
|
|
|
|
type fakeCtx struct {
|
|
|
ctx context.Context
|
|
|
req requestpkg.Request
|
|
|
vals map[string]interface{}
|
|
|
conn ctxpkg.Connection
|
|
|
}
|
|
|
|
|
|
func (f *fakeCtx) Deadline() (time.Time, bool) { return time.Time{}, false }
|
|
|
func (f *fakeCtx) Done() <-chan struct{} { return nil }
|
|
|
func (f *fakeCtx) Err() error { return nil }
|
|
|
func (f *fakeCtx) Value(key interface{}) interface{} { return nil }
|
|
|
func (f *fakeCtx) Request() requestpkg.Request { return f.req }
|
|
|
func (f *fakeCtx) Response() responsepkg.Response { return nil }
|
|
|
func (f *fakeCtx) Connection() ctxpkg.Connection { return f.conn }
|
|
|
func (f *fakeCtx) Set(k string, v interface{}) {
|
|
|
if f.vals == nil {
|
|
|
f.vals = map[string]interface{}{}
|
|
|
}
|
|
|
f.vals[k] = v
|
|
|
}
|
|
|
func (f *fakeCtx) Get(k string) interface{} { return f.vals[k] }
|
|
|
func (f *fakeCtx) MustGet(k string) interface{} { return f.vals[k] }
|
|
|
func (f *fakeCtx) GetString(k string) string {
|
|
|
if v, ok := f.vals[k].(string); ok {
|
|
|
return v
|
|
|
}
|
|
|
return ""
|
|
|
}
|
|
|
func (f *fakeCtx) GetInt(k string) int {
|
|
|
if v, ok := f.vals[k].(int); ok {
|
|
|
return v
|
|
|
}
|
|
|
return 0
|
|
|
}
|
|
|
func (f *fakeCtx) GetBool(k string) bool {
|
|
|
if v, ok := f.vals[k].(bool); ok {
|
|
|
return v
|
|
|
}
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
func newCtxWithHeaderCodec(name string) ctxpkg.Context {
|
|
|
req := &fakeRequest{header: &fakeHeader{m: map[string]interface{}{"codec": name}}}
|
|
|
return &fakeCtx{ctx: context.Background(), req: req, conn: &fakeConn{}}
|
|
|
}
|
|
|
|
|
|
// Resolver implementations used by the tests.
|
|
|
type testHeaderResolver struct{}
|
|
|
|
|
|
func (r *testHeaderResolver) ResolveForDecode(registry codecpkg.Registry, ctx ctxpkg.Context, raw []byte, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
|
|
|
if ctx != nil && ctx.Request() != nil && ctx.Request().Header() != nil {
|
|
|
if v := ctx.Request().Header().Get("codec"); v != nil {
|
|
|
if name, ok := v.(string); ok && name != "" {
|
|
|
return registry.Get(name)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
func (r *testHeaderResolver) ResolveForEncode(registry codecpkg.Registry, ctx ctxpkg.Context, data interface{}, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
|
|
|
// 编码时也尝试从header获取
|
|
|
if header != nil {
|
|
|
if v := header.Get("codec"); v != nil {
|
|
|
if name, ok := v.(string); ok && name != "" {
|
|
|
return registry.Get(name)
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
type testRawDataResolver struct{}
|
|
|
|
|
|
func (r *testRawDataResolver) ResolveForDecode(registry codecpkg.Registry, ctx ctxpkg.Context, raw []byte, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
|
|
|
if len(raw) > 0 {
|
|
|
b0 := raw[0]
|
|
|
if len(raw) > 1 {
|
|
|
bl := raw[len(raw)-1]
|
|
|
if (b0 == '{' && bl == '}') || (b0 == '[' && bl == ']') {
|
|
|
return registry.Get("json")
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
func (r *testRawDataResolver) ResolveForEncode(registry codecpkg.Registry, ctx ctxpkg.Context, data interface{}, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
|
|
|
// 编码时无法从原始数据判断,返回nil
|
|
|
return nil, nil
|
|
|
}
|
|
|
|
|
|
func TestResolveCodec_HeaderPreferred(t *testing.T) {
|
|
|
reg := codec.NewRegistry()
|
|
|
_ = reg.Register("json", codec.NewJSONCodec())
|
|
|
chain := codecpkg.NewResolverChain(reg, "")
|
|
|
// 添加一个测试resolver
|
|
|
chain.Add(&testHeaderResolver{})
|
|
|
ctx := newCtxWithHeaderCodec("json")
|
|
|
c, err := chain.ResolveForDecode(ctx, []byte(nil), ctx.Request().Header())
|
|
|
if err != nil {
|
|
|
t.Fatalf("unexpected err: %v", err)
|
|
|
}
|
|
|
if c.Name() != "json" {
|
|
|
t.Fatalf("want json, got %s", c.Name())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestResolveCodec_DefaultFallback(t *testing.T) {
|
|
|
reg := codec.NewRegistry()
|
|
|
chain := codecpkg.NewResolverChain(reg, "binary")
|
|
|
c, err := chain.ResolveForDecode(nil, []byte(nil), nil)
|
|
|
if err != nil {
|
|
|
t.Fatalf("unexpected err: %v", err)
|
|
|
}
|
|
|
if c.Name() != "binary" {
|
|
|
t.Fatalf("want binary, got %s", c.Name())
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func TestResolveCodec_JSONHeuristic(t *testing.T) {
|
|
|
reg := codec.NewRegistry()
|
|
|
_ = reg.Register("json", codec.NewJSONCodec())
|
|
|
chain := codecpkg.NewResolverChain(reg, "")
|
|
|
// 添加一个测试resolver
|
|
|
chain.Add(&testRawDataResolver{})
|
|
|
raw := []byte(`{"a":1}`)
|
|
|
c, err := chain.ResolveForDecode(nil, raw, nil)
|
|
|
if err != nil {
|
|
|
t.Fatalf("unexpected err: %v", err)
|
|
|
}
|
|
|
if c.Name() != "json" {
|
|
|
t.Fatalf("want json by heuristic, got %s", c.Name())
|
|
|
}
|
|
|
}
|