package server

import (
	"context"
	"testing"
	"time"

	"github.com/noahlann/nnet/internal/codec"
	codecpkg "github.com/noahlann/nnet/pkg/codec"
	ctxpkg "github.com/noahlann/nnet/pkg/context"
	protocolpkg "github.com/noahlann/nnet/pkg/protocol"
	requestpkg "github.com/noahlann/nnet/pkg/request"
	responsepkg "github.com/noahlann/nnet/pkg/response"
)

// --- Minimal fakes for Context / Request / Header / Connection.

// Compile-time checks that each fake satisfies the interface it stands in for.
var (
	_ requestpkg.FrameHeader = (*fakeHeader)(nil)
	_ requestpkg.Request     = (*fakeRequest)(nil)
	_ ctxpkg.Connection      = (*fakeConn)(nil)
	_ ctxpkg.Context         = (*fakeCtx)(nil)
)

// fakeHeader is a map-backed requestpkg.FrameHeader. Encode and Decode are
// no-ops; only Get, Set and Clone carry state.
type fakeHeader struct{ m map[string]interface{} }

// Get returns the value stored under key, or nil when the key is absent.
func (h *fakeHeader) Get(key string) interface{} { return h.m[key] }

// Set stores value under key, allocating the backing map on first use.
func (h *fakeHeader) Set(key string, value interface{}) {
	if h.m == nil {
		h.m = map[string]interface{}{}
	}
	h.m[key] = value
}

// Encode is a no-op that returns no bytes and no error.
func (h *fakeHeader) Encode() ([]byte, error) { return nil, nil }

// Decode is a no-op that ignores its input.
func (h *fakeHeader) Decode([]byte) error { return nil }

// Clone returns a new fakeHeader holding a shallow copy of the entries.
// The clone always has a non-nil map, even when h's map is nil.
func (h *fakeHeader) Clone() requestpkg.FrameHeader {
	c := &fakeHeader{m: make(map[string]interface{}, len(h.m))}
	for k, v := range h.m {
		c.m[k] = v
	}
	return c
}

// fakeRequest is a requestpkg.Request whose only meaningful field is its
// header. Every other accessor returns a zero value.
type fakeRequest struct{ header requestpkg.FrameHeader }

func (r *fakeRequest) Raw() []byte                      { return nil }
func (r *fakeRequest) Data() interface{}                { return nil }
func (r *fakeRequest) Bind(v interface{}) error         { return nil }
func (r *fakeRequest) Header() requestpkg.FrameHeader   { return r.header }
func (r *fakeRequest) DataBytes() []byte                { return nil }
func (r *fakeRequest) Protocol() protocolpkg.Protocol   { return nil }
func (r *fakeRequest) ProtocolVersion() string          { return "" }

// fakeConn is an inert ctxpkg.Connection: empty identifiers, and writes
// and closes always succeed without doing anything.
type fakeConn struct{}

func (f *fakeConn) ID() string         { return "" }
func (f *fakeConn) RemoteAddr() string { return "" }
func (f *fakeConn) LocalAddr() string  { return "" }
func (f *fakeConn) Write([]byte) error { return nil }
func (f *fakeConn) Close() error       { return nil }

// fakeCtx is a ctxpkg.Context backed by an optional standard context,
// a request, a connection and a lazily allocated key/value store.
type fakeCtx struct {
	ctx  context.Context
	req  requestpkg.Request
	vals map[string]interface{}
	conn ctxpkg.Connection
}

// Deadline delegates to the wrapped context; with none it reports no deadline.
func (f *fakeCtx) Deadline() (time.Time, bool) {
	if f.ctx == nil {
		return time.Time{}, false
	}
	return f.ctx.Deadline()
}

// Done delegates to the wrapped context; with none it returns a nil channel.
func (f *fakeCtx) Done() <-chan struct{} {
	if f.ctx == nil {
		return nil
	}
	return f.ctx.Done()
}

// Err delegates to the wrapped context; with none it returns nil.
func (f *fakeCtx) Err() error {
	if f.ctx == nil {
		return nil
	}
	return f.ctx.Err()
}

// Value delegates to the wrapped context; with none it returns nil.
// It does not consult the Set/Get store.
func (f *fakeCtx) Value(key interface{}) interface{} {
	if f.ctx == nil {
		return nil
	}
	return f.ctx.Value(key)
}

func (f *fakeCtx) Request() requestpkg.Request     { return f.req }
func (f *fakeCtx) Response() responsepkg.Response  { return nil }
func (f *fakeCtx) Connection() ctxpkg.Connection   { return f.conn }

// Set stores v under k, allocating the store on first use.
func (f *fakeCtx) Set(k string, v interface{}) {
	if f.vals == nil {
		f.vals = map[string]interface{}{}
	}
	f.vals[k] = v
}

// Get returns the value stored under k, or nil when absent.
func (f *fakeCtx) Get(k string) interface{} { return f.vals[k] }

// MustGet behaves exactly like Get; this fake does not panic on a missing key.
func (f *fakeCtx) MustGet(k string) interface{} { return f.vals[k] }

// GetString returns the value under k if it is a string, otherwise "".
func (f *fakeCtx) GetString(k string) string {
	if v, ok := f.vals[k].(string); ok {
		return v
	}
	return ""
}

// GetInt returns the value under k if it is an int, otherwise 0.
func (f *fakeCtx) GetInt(k string) int {
	if v, ok := f.vals[k].(int); ok {
		return v
	}
	return 0
}

// GetBool returns the value under k if it is a bool, otherwise false.
func (f *fakeCtx) GetBool(k string) bool {
	if v, ok := f.vals[k].(bool); ok {
		return v
	}
	return false
}

// newCtxWithHeaderCodec builds a fake context around context.Background()
// whose request header carries "codec" = name.
func newCtxWithHeaderCodec(name string) ctxpkg.Context {
	req := &fakeRequest{header: &fakeHeader{m: map[string]interface{}{"codec": name}}}
	return &fakeCtx{ctx: context.Background(), req: req, conn: &fakeConn{}}
}

// --- Test resolver implementations.

// testHeaderCodecName reports the codec name held in a header "codec" value.
// The second result is true only when v is a non-empty string.
func testHeaderCodecName(v interface{}) (string, bool) {
	name, ok := v.(string)
	return name, ok && name != ""
}

// testHeaderResolver picks a codec by the "codec" header entry.
type testHeaderResolver struct{}

// ResolveForDecode looks up the codec named by the request header reached via
// ctx; the header parameter is not consulted. It returns (nil, nil) when
// ctx, its request, or the header is missing, or when no usable name is set.
func (r *testHeaderResolver) ResolveForDecode(registry codecpkg.Registry, ctx ctxpkg.Context, raw []byte, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
	if ctx == nil || ctx.Request() == nil {
		return nil, nil
	}
	h := ctx.Request().Header()
	if h == nil {
		return nil, nil
	}
	if name, ok := testHeaderCodecName(h.Get("codec")); ok {
		return registry.Get(name)
	}
	return nil, nil
}

// ResolveForEncode looks up the codec named by the supplied frame header.
// It returns (nil, nil) when the header is nil or carries no usable name.
func (r *testHeaderResolver) ResolveForEncode(registry codecpkg.Registry, ctx ctxpkg.Context, data interface{}, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
	if header == nil {
		return nil, nil
	}
	if name, ok := testHeaderCodecName(header.Get("codec")); ok {
		return registry.Get(name)
	}
	return nil, nil
}

// testRawDataResolver sniffs raw payloads for JSON object/array framing.
type testRawDataResolver struct{}

// ResolveForDecode returns the "json" codec when raw is at least two bytes
// long and is wrapped in {...} or [...]; otherwise it returns (nil, nil).
func (r *testRawDataResolver) ResolveForDecode(registry codecpkg.Registry, ctx ctxpkg.Context, raw []byte, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
	if len(raw) < 2 {
		return nil, nil
	}
	first, last := raw[0], raw[len(raw)-1]
	if (first == '{' && last == '}') || (first == '[' && last == ']') {
		return registry.Get("json")
	}
	return nil, nil
}

// ResolveForEncode always returns (nil, nil): there is no raw payload to
// inspect when encoding. Presumably the chain then tries other resolvers or
// its default — confirm against ResolverChain.
func (r *testRawDataResolver) ResolveForEncode(registry codecpkg.Registry, ctx ctxpkg.Context, data interface{}, header protocolpkg.FrameHeader) (codecpkg.Codec, error) {
	return nil, nil
}

// assertResolvedCodec fails the test unless c is non-nil and named want.
func assertResolvedCodec(t *testing.T, c interface{ Name() string }, want string) {
	t.Helper()
	if c == nil {
		t.Fatalf("want codec %s, got nil", want)
	}
	if got := c.Name(); got != want {
		t.Fatalf("want %s, got %s", want, got)
	}
}

// TestResolveCodec_HeaderPreferred checks that a header-driven resolver
// selects the codec named in the request header.
func TestResolveCodec_HeaderPreferred(t *testing.T) {
	reg := codec.NewRegistry()
	if err := reg.Register("json", codec.NewJSONCodec()); err != nil {
		t.Fatalf("register json codec: %v", err)
	}
	chain := codecpkg.NewResolverChain(reg, "")
	// Install a resolver that reads the codec name from the request header.
	chain.Add(&testHeaderResolver{})

	ctx := newCtxWithHeaderCodec("json")
	c, err := chain.ResolveForDecode(ctx, []byte(nil), ctx.Request().Header())
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	assertResolvedCodec(t, c, "json")
}

// TestResolveCodec_DefaultFallback checks that a chain with no resolvers
// yields its configured default codec.
func TestResolveCodec_DefaultFallback(t *testing.T) {
	reg := codec.NewRegistry()
	chain := codecpkg.NewResolverChain(reg, "binary")

	c, err := chain.ResolveForDecode(nil, []byte(nil), nil)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	assertResolvedCodec(t, c, "binary")
}

// TestResolveCodec_JSONHeuristic checks that the raw-data resolver detects a
// JSON object payload and selects the json codec.
func TestResolveCodec_JSONHeuristic(t *testing.T) {
	reg := codec.NewRegistry()
	if err := reg.Register("json", codec.NewJSONCodec()); err != nil {
		t.Fatalf("register json codec: %v", err)
	}
	chain := codecpkg.NewResolverChain(reg, "")
	// Install a resolver that sniffs the raw payload for JSON framing.
	chain.Add(&testRawDataResolver{})

	raw := []byte(`{"a":1}`)
	c, err := chain.ResolveForDecode(nil, raw, nil)
	if err != nil {
		t.Fatalf("unexpected err: %v", err)
	}
	assertResolvedCodec(t, c, "json")
}