diff --git a/bilibili/live.go b/bilibili/live.go index 5fb566e..80d7c2d 100644 --- a/bilibili/live.go +++ b/bilibili/live.go @@ -70,7 +70,7 @@ func (l *LiveBilibili) preConnect() (url string, err error) { return "", err } switch cfg.Type { - case config.BilibiliTypeOfficial: + case config.TypeOfficial: err = l.initWebsocketInfo() if err != nil { logger.SLog.Errorf("获取wss信息失败, err:%+v", err) @@ -110,7 +110,7 @@ func (l *LiveBilibili) preConnect() (url string, err error) { msg_handler.NewGuardBuyOfficialHandler(l.roomInfo.RoomId), msg_handler.NewGiftOfficialHandler(l.roomInfo.RoomId), ) - case config.BilibiliTypeCustom: + case config.TypeCustom: url = cfg.Custom.Url l.registerMessageHandler( msg_handler.NewDanmakuHandler(l.roomInfo.RoomId), @@ -209,12 +209,12 @@ func (l *LiveBilibili) initWebsocketInfo() error { func (l *LiveBilibili) Init(conn *ws.NWebsocket) (err error) { cfg := config.Config.Bilibili switch cfg.Type { - case config.BilibiliTypeOfficial: + case config.TypeOfficial: if err = l.auth(conn); err != nil { logger.SLog.Error(err) return } - case config.BilibiliTypeCustom: + case config.TypeCustom: if err = l.joinRoom(conn); err != nil { logger.SLog.Error(err) return diff --git a/config-dev-sy.yml b/config-dev-sy.yml index f6f592f..2e86989 100644 --- a/config-dev-sy.yml +++ b/config-dev-sy.yml @@ -1,5 +1,5 @@ Bilibili: - Enabled: true + Enabled: false Type: custom Official: Api: https://live-open.biliapi.com diff --git a/config-dev.yml b/config-dev.yml index 9b720ff..1883409 100644 --- a/config-dev.yml +++ b/config-dev.yml @@ -1,5 +1,5 @@ Bilibili: - Enabled: true + Enabled: false Type: custom Official: Api: https://live-open.biliapi.com @@ -11,6 +11,14 @@ Bilibili: HeartbeatInterval: 30 UserId: 111222 RoomId: 8722013 +Douyu: + Enabled: true + Type: custom + Custom: + Url: wss://danmuproxy.douyu.com:8501/ + HeartbeatInterval: 45 + UserId: 111222 + RoomId: 2947432 # 10984327 Log: Console: Level: debug diff --git a/config.yml b/config.yml index f983189..ca3741d 100644 --- a/config.yml +++ b/config.yml @@ -1,10 +1,24 @@ Bilibili: - Enabled: true - Url: wss://broadcastlv.chat.bilibili.com:2245/sub + Enabled: false + Type: custom + Official: + Api: https://live-open.biliapi.com + AkId: 1rVL1YRSoii28LJ1O7KIQqqQ + AkSecret: 5USwZt2bygTIE8a2cXahZrzsdKtbxd + Custom: + Url: wss://broadcastlv.chat.bilibili.com:2245/sub GetRoomUrl: https://api.live.bilibili.com/room/v1/Room/room_init?id= HeartbeatInterval: 30 UserId: 111222 RoomId: 8722013 +Douyu: + Enabled: true + Type: custom + Custom: + Url: wss://danmuproxy.douyu.com:8501/ + HeartbeatInterval: 45 + UserId: 111222 + RoomId: 10984327 # Log: Console: Level: info @@ -21,7 +35,7 @@ Log: Compress: true # 压缩日志 Kafka: Danmaku: - Addr: ["127.0.0.1:9093"] + Addr: [ "127.0.0.1:9093" ] Topic: "danmaku" Gift: Addr: [ "127.0.0.1:9093" ] diff --git a/config/config.go b/config/config.go index f6dcfe0..885107f 100644 --- a/config/config.go +++ b/config/config.go @@ -10,11 +10,11 @@ import ( var Config config -type BilibiliType string +type ConnectType string const ( - BilibiliTypeOfficial BilibiliType = "official" - BilibiliTypeCustom BilibiliType = "custom" + TypeOfficial ConnectType = "official" + TypeCustom ConnectType = "custom" ) type ( @@ -24,8 +24,8 @@ type ( } config struct { Bilibili struct { - Enabled bool // 是否启用 - Type BilibiliType // 类型 + Enabled bool // 是否启用 + Type ConnectType // 类型 Official struct { Api string // API 地址 AkId string // accessKeyId @@ -40,6 +40,16 @@ type ( UserId int64 // 用于连接的userId,0则随机生成 HeartbeatInterval time.Duration // 心跳间隔 单位s } + Douyu struct { + Enabled bool // 是否启用 + Type ConnectType // 类型 + Custom struct { + Url string // 弹幕服务器url + } + RoomId int64 // 待连接roomId + UserId int64 // 用于连接的userId,0则随机生成 + HeartbeatInterval time.Duration // 心跳间隔 单位s + } // Log 日志配置 Log struct { File logger.FileConfig diff --git a/douyu/codec.go b/douyu/codec.go new file mode 100644 index 0000000..3cfac72 --- /dev/null +++ b/douyu/codec.go @@ -0,0 +1,57 @@ +package douyu + +import ( + "bytes" + "fmt" + "github.com/pkg/errors" + "live-gateway/douyu/stt" + "live-gateway/ws" +) + +type CodecDouyu struct { +} + +func NewCodecDouyu() ws.Codec { + return &CodecDouyu{} +} + +// Encode encodes TypedData to douyu.WsEntry +func (c *CodecDouyu) Encode(v interface{}) (interface{}, error) { + data, ok := v.(TypedData) + if !ok { + return nil, errors.New("[Codec-Douyu] 写入值类型必须实现 douyu.TypedData 接口") + } + + bf := bytes.NewBuffer([]byte{}) + bf.WriteString(fmt.Sprintf("type@=%s/", data.DataType())) + + sttData, err := stt.Marshal(v) + if err != nil { + return nil, err + } + bf.Write(sttData) + + resp := &WsEntry{ + data: bf.Bytes(), + msgType: TypeMessageToServer, + } + //logger.SLog.Debugf("发送消息: %s", string(resp.data)) + return resp, nil +} + +func (c *CodecDouyu) Decode(customEntry interface{}) (interface{}, error) { + // 处理data + entry, ok := customEntry.(*WsEntry) + if !ok { + return nil, errors.New(fmt.Sprintf("[Codec-Douyu] 写入值类型必须为%T", WsEntry{})) + } + var typed struct { + Type string `stt:"type"` + } + err := stt.Unmarshal(entry.data, &typed) + if err != nil { + return nil, errors.New("[Codec-Douyu] 获取type字段失败") + } + entry.dataType = typed.Type + return entry, nil +} diff --git a/douyu/codec_test.go b/douyu/codec_test.go new file mode 100644 index 0000000..0f69ed7 --- /dev/null +++ b/douyu/codec_test.go @@ -0,0 +1,33 @@ +package douyu + +import ( + "fmt" + jsoniter "github.com/json-iterator/go" + "github.com/mitchellh/mapstructure" + "live-gateway/douyu/msg_handler" + "testing" +) + +func TestDecodeData(t *testing.T) { + codec := NewCodecDouyu() + + // type@=chatmsg/rid@=58839/ct@=8/hashid@=9LA18ePx4dqW/nn@=test/txt@=666/cid@=1111/ic@=icon/sahf@=0/level@=1/nl@=0/nc@=0/cmt@=0/gt@=0/col@=0/rg@=0/pg@=0/dlv@=0/dc@=0/bdlv@=0/gatin@=0/chtin@=0/repin@=0/bnn@=test/bl@=0/brid@=58839/hc@=0/ol@=0/rev@=0/hl@=0/ifs@=0/p2p@=0/el@=eid@AA=1@ASetp@AA=1@ASsc@AA=1@AS/ + // type@=chatmsg/rid@=3484/ct@=14/uid@=53882869/nn@=.........................../txt@=......................../cid@=0e7aca52aede49eb2e9a2a0100000000/ic@=avatar@Sdefault@S11/level@=30/sahf@=0/col@=1/cst@=1657543712943/bnn@=....../bl@=18/brid@=3484/hc@=709494ed025488b6f77db71902e124cd/diaf@=1/hl@=1/ifs@=1/el@=/lk@=/fl@=18/dms@=3/pdg@=54/pdk@=87/ext@=/. + var wsEntry interface{} = &WsEntry{ + data: []byte("type@=chatmsg/rid@=3484/ct@=14/uid@=53882869/nn@=.........................../txt@=......................../cid@=0e7aca52aede49eb2e9a2a0100000000/ic@=avatar@Sdefault@S11/level@=30/sahf@=0/col@=1/cst@=1657543712943/bnn@=....../bl@=18/brid@=3484/hc@=709494ed025488b6f77db71902e124cd/diaf@=1/hl@=1/ifs@=1/el@=/lk@=/fl@=18/dms@=3/pdg@=54/pdk@=87/ext@=/."), + } + var err error + wsEntry, err = codec.Decode(wsEntry) + if err != nil { + panic(err) + } + var dm msg_handler.MsgDanmaku + err = mapstructure.Decode(wsEntry.(*WsEntry).MapData, &dm) + + var cmd struct { + CMD string `json:"cmd"` + } + jsoniter.Unmarshal([]byte(`{"cmd":"stringcmd"}`), &cmd) + + fmt.Printf("%+v", wsEntry) +} diff --git a/douyu/live.go b/douyu/live.go new file mode 100644 index 0000000..b4eb122 --- /dev/null +++ b/douyu/live.go @@ -0,0 +1,155 @@ +package douyu + +import ( + "git.noahlan.cn/northlan/ntools-go/logger" + "github.com/pkg/errors" + "live-gateway/config" + "live-gateway/douyu/msg_handler" + "live-gateway/live" + "live-gateway/ws" + "time" +) + +type MsgHandler interface { + TypedData + HandlerMessage(data []byte) +} + +// 实现 live.Handler 接口 +var _ live.Handler = (*LiveDouyu)(nil) + +type ( + LiveDouyu struct { + *live.Live + sequenceId uint32 + + msgHandlerMapper map[string]MsgHandler + + //loginResChan chan struct{} // login res + entered chan struct{} // join group + } +) + +func NewLiveDouyu() *LiveDouyu { + bl := &LiveDouyu{ + msgHandlerMapper: make(map[string]MsgHandler, 6), + entered: make(chan struct{}), + } + + l := live.NewLive( + live.WithWsOptions( + ws.WithPacker(NewPackDouyu()), + ws.WithCodec(NewCodecDouyu()), + ), + ) + l.PreConnect(bl.preConnect) + + l.Init(bl.Init) + l.Handler(bl) + + bl.Live = l + return bl +} + +func (l *LiveDouyu) registerMessageHandler(h ...MsgHandler) { + for _, handler := range h { + l.msgHandlerMapper[handler.DataType()] = handler + } +} + +func (l *LiveDouyu) preConnect() (url string, err error) { + cfg := config.Config.Douyu + url = cfg.Custom.Url + // 注册监听器 + l.registerMessageHandler( + msg_handler.NewMsgDanmakuHandler(cfg.RoomId), + msg_handler.NewMsgGiftHandler(cfg.RoomId), + ) + return +} + +func (l *LiveDouyu) Init(conn *ws.NWebsocket) (err error) { + cfg := config.Config.Douyu + switch cfg.Type { + case config.TypeOfficial: + case config.TypeCustom: + // 1. login + err = conn.SendBinaryMessage(&MsgLoginReq{ + RoomId: cfg.RoomId, + //Dfl: "sn@AA=105@ASss@AA=1", + //Dfl: MsgLoginReqDfl{}, + // sn@=105/ss@=1 + Username: "auto_EHUwJCggl7", + Uid: cfg.UserId, + Ver: "20190610", + AVer: "218101901", + Ct: 0, + }) + if err != nil { + return errors.Wrap(err, "发送login消息到wss失败") + } + logger.SLog.Debug("发送登录消息") + // 2. wait login resp + <-l.entered + + err = conn.SendBinaryMessage(&MsgJoinGroup{ + RoomId: cfg.RoomId, + GId: -9999, // 海量弹幕模式 + }) + if err != nil { + return errors.Wrap(err, "发送joingroup消息到wss失败") + } + logger.SLog.Debug("发送入组消息") + } + // 心跳 + go l.heartbeat(conn, cfg.HeartbeatInterval*time.Second) + logger.SLog.Debug("开始心跳...") + return +} + +func (l *LiveDouyu) heartbeat(conn *ws.NWebsocket, t time.Duration) { + hb := func(conn *ws.NWebsocket) { + //logger.SLog.Debug("heartbeat !!!") + data := &MsgHeartbeat{} + l.sequenceId++ + err := conn.SendBinaryMessage(data) + if err != nil { + return + } + } + hb(conn) + + ticker := time.NewTicker(t) + defer ticker.Stop() + for { + select { + case <-ticker.C: + hb(conn) + } + } +} + +func (l *LiveDouyu) HandlerMessage(v interface{}) { + entry, ok := v.(*WsEntry) + if !ok { + logger.SLog.Warnf("读取消息错误, 数据类型不匹配 %T %T", v, &WsEntry{}) + return + } + //logger.SLog.Debugf("接收消息 %s", string(entry.data)) + switch entry.dataType { + case "loginres": + go func() { + select { + case l.entered <- struct{}{}: + } + }() + default: + handler, ok := l.msgHandlerMapper[entry.dataType] + if !ok { + return + } + handler.HandlerMessage(entry.data) + //logger.SLog.Infof("接收消息 %v\n%v", string(entry.data), entry.MapData) + } + +} diff --git a/douyu/message.go b/douyu/message.go new file mode 100644 index 0000000..4631ce9 --- /dev/null +++ b/douyu/message.go @@ -0,0 +1,187 @@ +package douyu + +// 弹幕服务器端相应消息类型 +const ( + // TypeLoginRes 登录响应消息 字段说明 + // type 表示为“登出”消息,固定为 loginres + // userid 用户 ID + // roomgroup 房间权限组 + // pg 平台权限组 + // sessionid 会话 ID + // username 用户名 + // nickname 用户昵称 + // live_stat 直播状态 + // is_illegal 是否合规 + // is_signed 是否已在房间签到 + // signed_count 日总签到次数 + // npv 是否需要手机验证 + // best_dlev 最高酬勤等级 + // cur_lev 酬勤等级 + TypeLoginRes = "loginres" + + // TypeNewGift 赠送礼物消息 + // 用户在房间赠送礼物时,服务端发送此消息给客户端。完整的数据部分应包含的 字段如下: + // 字段说明 + // type 表示为“赠送礼物”消息,固定为 dgb + // rid 房间 ID + // gid 弹幕分组 ID + // gfid 礼物 id + // gs 礼物显示样式 + // uid 用户 id + // nn 用户昵称 + // str 用户战斗力 + // level 用户等级 + // dw 主播体重 + // gfcnt 礼物个数:默认值 1(表示 1 个礼物) + // hits 礼物连击次数:默认值 1(表示 1 连击) + // dlv 酬勤头衔:默认值 0(表示没有酬勤) + // dc 酬勤个数:默认值 0(表示没有酬勤数量) + // bdl 全站最高酬勤等级:默认值 0(表示全站都没有酬勤) + // rg 房间身份组:默认值 1(表示普通权限用户) + // pg 平台身份组:默认值 1(表示普通权限用户) + // rpid 红包 id:默认值 0(表示没有红包) + // slt 红包开启剩余时间:默认值 0(表示没有红包) + // elt 红包销毁剩余时间:默认值 0(表示没有红包) + TypeNewGift = "dgb" + + // TypeUserEnter 特殊用户进房通知消息 + // 具有特殊属性的用户进入直播间时,服务端发送此消息至客户端。完整的数据部 分应包含的字段如下: + // 字段说明 + // type 表示为“用户进房通知”消息,固定为 uenter + // rid 房间 ID + // gid 弹幕分组 ID + // uid 用户 ID + // nn 用户昵称 + // str 战斗力 + // level 新用户等级 + // gt 礼物头衔:默认值 0(表示没有头衔) + // rg 房间权限组:默认值 1(表示普通权限用户) + // pg 平台身份组:默认值 1(表示普通权限用户) + // dlv 酬勤等级:默认值 0(表示没有酬勤) + // dc 酬勤数量:默认值 0(表示没有酬勤数量) + // bdlv 最高酬勤等级:默认值 0(表示全站都没有酬勤) + TypeUserEnter = "uenter" + + // TypeNewDeserve 用户赠送酬勤通知消息 + // 用户赠送酬勤时,服务端发送此消息至客户端。完整的数据部分应包含的字段如 下: + // 字段说明 + // type 表示为“赠送酬勤通知”消息,固定为 bc_buy_deserve + // rid 房间 ID + // gid 弹幕分组 ID + // level 用户等级 + // cnt 赠送数量 + // hits 赠送连击次数 + // lev 酬勤等级 + // sui 用户信息序列化字符串,详见下文。注意,此处为嵌套序列化,需注 意符号的转义变换。(转义符号参见 2.2 序列化) + TypeNewDeserve = "bc_buy_deserve" + + // TypeLiveStatusChange 房间开关播提醒消息 + // 房间开播提醒主要部分应包含的字段如下: + // 字段说明 + // type 表示为“房间开播提醒”消息,固定为 rss + // rid 房间 id + // gid 弹幕分组 id + // ss 直播状态,0-没有直播,1-正在直播 + // code 类型 + // rt 开关播原因:0-主播开关播,其他值-其他原因 + // notify 通知类型 + // endtime 关播时间(仅关播时有效) + TypeLiveStatusChange = "rss" + + TypeRankList = "ranklist" // 广播排行榜消息 + + // TypeMsgToAll 超级弹幕消息(如,火箭弹幕) + // 超级弹幕主要部分应包含的字段如下: + // 字段说明 + // type 表示为“超级弹幕”消息,固定为 ssd + // rid 房间 id + // gid 弹幕分组 id + // sdid 超级弹幕 id + // trid 跳转房间 id + // content 超级弹幕的内容 + TypeMsgToAll = "ssd" + + // TypeMsgToRoom 房间内礼物广播 + // 房间内赠送礼物成功后效果主要部分应包含的字段如下: + // 字段说明 + // type 表示为“房间内礼物广播”,固定为 spbc + // rid 房间 id + // gid 弹幕分组 id + // sn 赠送者昵称 + // dn 受赠者昵称 + // gn 礼物名称 + // gc 礼物数量 + // drid 赠送房间 + // gs 广播样式 + // gb 是否有礼包(0-无礼包,1-有礼包) + // es 广播展现样式(1-火箭,2-飞机) + // gfid 礼物 id + // eid 特效 id + TypeMsgToRoom = "spbc" + + // TypeNewRedPacket 房间用户抢红包 + // 房间赠送礼物成功后效果(赠送礼物效果,连击数)主要部分应包含的字段如下: + // 字段说明 + // type 表示“房间用户抢红包”信息,固定为 ggbb + // rid 房间 id + // gid 弹幕分组 id + // sl 抢到的鱼丸数量 + // sid 礼包产生者 id + // did 抢礼包者 id + // snk 礼包产生者昵称 + // dnk 抢礼包者昵称 + // rpt 礼包类型 + TypeNewRedPacket = "ggbb" + + // TypeRoomRankChange 房间内top10变化消息 + // 房间内 top10 排行榜变化后,广播。主要部分应包含的字段如下: + // 字段说明 + // type 表示为“房间 top10 排行榜变换”,固定为 rankup + // rid 房间 id + // gid 弹幕分组 id + // uid 用户 id + // drid 目标房间 id + // rt 房间所属栏目类型 + // bt 广播类型:1-房间内广播,2-栏目广播,4-全站广播 + // sz 展示区域:1-聊天区展示,2-flash 展示,3-都显示 + // nk 用户昵称 + // rkt top10 榜的类型 1-周榜 2-总榜 4-日榜 + // rn 上升后的排名 + TypeRoomRankChange = "rankup" +) + +// 消息体类型 +const ( + TypeMessageToServer uint16 = 689 + TypeMessageFromServer = 690 +) + +const EndCharacter byte = 0 + +// TagName 序列化tag +const TagName = "stt" + +type TypedData interface { + DataType() string +} + +// WsEntry douyu websocket 结构 小端序 +type WsEntry struct { + // 消息正文 斗鱼STT协议 外部使用工具将 struct 转 map + // 键 key 和值 value 直接采用‘@=’分割 + // 数组采用‘/’分割 + // 如果 key 或者 value 中含有字符‘/’,则使用‘@S’转义 + // 如果 key 或者 value 中含有字符‘@’,使用‘@A’转义举例: + // (1) 多个键值对数据:key1@=value1/key2@=value2/key3@=value3/ + // (2) 数组数据:value1/value2/value3/ + data []byte + dataType string + + // msgType 消息类型,2字节小端整数,表示消息类型。取值如下: + // 689 客户端发送给弹幕服务器的文本格式数据 + // 690 弹幕服务器发送给客户端的文本格式数据。 + msgType uint16 + secret byte // 加密字段,1字节,暂时未用,默认为0 + reserved byte // 保留字段,1字节,暂时未用,默认为0 + //ending byte // 结尾字段,1字节 必须'\0' +} diff --git a/douyu/msg_handler/common.go b/douyu/msg_handler/common.go new file mode 100644 index 0000000..2a761fc --- /dev/null +++ b/douyu/msg_handler/common.go @@ -0,0 +1,9 @@ +package msg_handler + +// FansMedal 斗鱼粉丝牌 +type FansMedal struct { + MedalName string `stt:"bnn"` // 粉丝勋章名称 + MedalLevel int64 `stt:"bl"` // 粉丝勋章等级 + MedalRoomId int64 `stt:"brid"` // 粉丝勋章房间ID + Hc string `stt:"hc"` // 粉丝勋章校验码 +} diff --git a/douyu/msg_handler/msg_danmaku.go b/douyu/msg_handler/msg_danmaku.go new file mode 100644 index 0000000..f2de812 --- /dev/null +++ b/douyu/msg_handler/msg_danmaku.go @@ -0,0 +1,72 @@ +package msg_handler + +import ( + "git.noahlan.cn/northlan/ntools-go/kafka" + "git.noahlan.cn/northlan/ntools-go/logger" + "live-gateway/config" + "live-gateway/douyu/stt" + pbMq "live-gateway/pb/mq" + pbVars "live-gateway/pb/vars" + kfk "live-gateway/pkg/kafka" + "strconv" +) + +// MsgDanmaku 用户在房间发送弹幕时,服务端发此消息给客户端,完整的数据部分应包含的字段如下(挑选有用的) +type ( + MsgDanmaku struct { + RoomId int64 `stt:"rid"` // 房间ID + Uid int64 `stt:"uid"` // 用户ID + Nickname string `stt:"nn"` // 昵称 + Txt string `stt:"txt"` // 弹幕 + Level int32 `stt:"level"` // 用户等级 + Rg int32 `stt:"rg"` // 房间权限组,默认值1(表示普通权限用户) UP:5 + Avatar string `stt:"ic"` // 头像地址,后需要添加 _big.jpg 或 _small.jpg 头像获取 https://apic.douyucdn.cn/upload/ + FansMedal // 粉丝勋章 + Ol int32 `stt:"ol"` // 主播等级 + Cid string `stt:"cid"` // 弹幕唯一ID + Cst int64 `stt:"cst"` // 时间戳 ms + Ct int32 `stt:"ct"` // 客户端类型,默认值0 + Sahf string `stt:"sahf"` // 扩展字段,一般不用 + } + MsgDanmakuHandler struct { + producer *kafka.Producer + liveRoomId int64 + } +) + +func NewMsgDanmakuHandler(liveRoomId int64) *MsgDanmakuHandler { + cfg := config.Config.Kafka.Danmaku + return &MsgDanmakuHandler{ + producer: kafka.NewKafkaProducer(kfk.DefaultProducerConfig, cfg.Addr, cfg.Topic), + liveRoomId: liveRoomId, + } +} + +func (m *MsgDanmakuHandler) DataType() string { + return "chatmsg" +} + +func (m *MsgDanmakuHandler) HandlerMessage(data []byte) { + var ret MsgDanmaku + err := stt.Unmarshal(data, &ret) + if err != nil { + return + } + logger.SLog.Debugf("%s 说: %s", ret.Nickname, ret.Txt) + + dmMsg := &pbMq.MqDanmaku{ + Platform: pbVars.Platform_name[int32(pbVars.Platform_Douyu)], + LiveRoomId: m.liveRoomId, + Uid: ret.Uid, + Uname: ret.Nickname, + Avatar: "https://apic.douyucdn.cn/upload/" + ret.Avatar + "_small.jpg", // TODO 暂时组合,应该保留原始,然后客户端配置 + Msg: ret.Txt, + MsgId: ret.Cid, + Timestamp: ret.Cst, + + FansMedalWearingStatus: ret.FansMedal.MedalRoomId == m.liveRoomId, + FansMedalName: ret.FansMedal.MedalName, + FansMedalLevel: ret.FansMedal.MedalLevel, + } + _ = m.producer.SendMessageAsync(dmMsg, strconv.FormatInt(ret.Uid, 10)) +} diff --git a/douyu/msg_handler/msg_gift.go b/douyu/msg_handler/msg_gift.go new file mode 100644 index 0000000..ea2b9e2 --- /dev/null +++ b/douyu/msg_handler/msg_gift.go @@ -0,0 +1,209 @@ +package msg_handler + +import ( + "fmt" + "git.noahlan.cn/northlan/ntools-go/kafka" + "git.noahlan.cn/northlan/ntools-go/logger" + jsoniter "github.com/json-iterator/go" + "io" + "live-gateway/config" + "live-gateway/douyu/stt" + kfk "live-gateway/pkg/kafka" + "net/http" + "strconv" +) + +const ( + giftTypeYuChi = "YUCHI" + giftTypeYuWan = "YUWAN" +) + +type ( + MsgGift struct { + RoomId int64 `stt:"rid"` // 房间ID + GiftId int64 `stt:"gfid"` // 礼物ID + GiftStyle string `stt:"gs"` // 礼物显示样式(不确定数字类型?用string避免错误) + UID int64 `stt:"uid"` // 用户ID + Nickname string `stt:"nn"` // 用户昵称 + Avatar string `stt:"ic"` // 用户头像 + + // ... ignore eid & eic + + Level int32 `stt:"level"` // 用户等级 + Dw int64 `stt:"dw"` // 主播体重 + GiftCount int64 `stt:"gfcnt"` // 礼物个数:默认值 1(表示 1 个礼物) + Hits int32 `stt:"hits"` // 礼物连击次数:默认值 1(表示 1 连击) + + // ... ingore bcnd & bst & ct & el & cm & + + FansMedal // 粉丝勋章 + + Sahf string `stt:"sahf"` // 扩展字段,一般不使用,可忽略 + Fc int64 `stt:"fc"` // 攻击道具的攻击力 + + // ... ignore gpf & pid & bnid & bnl + + ReceiveUID int64 `stt:"receive_uid"` // 接收礼物用户ID + ReceiveNickname string `stt:"receive_nn"` // 接收礼物用户昵称 + + // ... ignore from & pfm & pma & mss & bcst + + //GroupId int64 `stt:"gid"` // 弹幕分组ID + //BigGift int32 `stt:"bg"` // 大礼物标识:默认值为 0(表示是小礼物) + //Dlv int32 `stt:"dlv"` // 酬勤头衔:默认值 0(表示没有酬勤) + //Dc int32 `stt:"dc"` // 酬勤个数:默认值 0(表示没有酬勤数量) + //Bdl int32 `stt:"bdl"` // 全站最高酬勤等级:默认值 0(表示全站都没有酬勤) + //Rg int32 `stt:"rg"` // 房间身份组:默认值 1(表示普通权限用户)5:UP + //Pg int32 `stt:"pg"` // 平台身份组:默认值 1(表示普通权限用户) + //Nl int32 `stt:"nl"` // 贵族等级:默认值 0(表示不是贵族) + } + MsgGiftHandler struct { + producer *kafka.Producer + liveRoomId int64 + giftMap map[int64]GiftConfig + } + GiftConfig struct { + ID int64 `json:"id"` // 礼物ID + Name string `json:"name"` // 礼物名 + Price int64 `json:"price"` // 价格 收费礼物则是 鱼翅*10,免费礼物则是 鱼丸 + IsPaid bool `json:"-"` // 是否收费礼物 + } +) + +func NewMsgGiftHandler(liveRoomId int64) *MsgGiftHandler { + cfg := config.Config.Kafka.Gift + ret := &MsgGiftHandler{ + producer: kafka.NewKafkaProducer(kfk.DefaultProducerConfig, cfg.Addr, cfg.Topic), + liveRoomId: liveRoomId, + giftMap: map[int64]GiftConfig{}, + } + + ret.setupGiftConfig() + return ret +} + +func (m *MsgGiftHandler) setupGiftConfig() { + logger.SLog.Info("初始化 Douyu 礼物配置...") + // bagGiftData + { + httpResp, err := http.Get("https://webconf.douyucdn.cn/resource/common/prop_gift_list/prop_gift_config.json") + if err != nil { + return + } + httpBodyResp, err := io.ReadAll(httpResp.Body) + if err != nil { + return + } + httpBodyResp = httpBodyResp[len("DYConfigCallback("):] + httpBodyResp = httpBodyResp[:len(httpBodyResp)-2] + + var gf struct { + Data map[string]struct { + Name string `json:"name"` + Price int64 `json:"pc"` + } `json:"data"` + } + err = jsoniter.Unmarshal(httpBodyResp, &gf) + if err != nil { + return + } + for id, g := range gf.Data { + idInt, err := strconv.ParseInt(id, 10, 64) + if err != nil { + continue + } + gfCfg := GiftConfig{ + ID: idInt, + Name: g.Name, + IsPaid: true, + Price: g.Price, + } + m.giftMap[gfCfg.ID] = gfCfg + } + } + // roomGiftData + { + httpResp, err := http.Get(fmt.Sprintf("https://gift.douyucdn.cn/api/gift/v3/web/list?rid=%d", m.liveRoomId)) + if err != nil { + return + } + httpBodyResp, err := io.ReadAll(httpResp.Body) + if err != nil { + return + } + type GiftPriceInfo struct { + Price int64 `json:"price"` + PriceType string `json:"priceType"` + } + var gf struct { + Data struct { + GiftList []struct { + ID int64 `json:"id"` // 礼物ID + Name string `json:"name"` // 礼物名 + GiftPriceInfo `json:"priceInfo"` // 价格参数 + } `json:"giftList"` + } `json:"data"` + } + err = jsoniter.Unmarshal(httpBodyResp, &gf) + if err != nil { + return + } + for _, g := range gf.Data.GiftList { + gfCfg := GiftConfig{ + ID: g.ID, + Name: g.Name, + IsPaid: false, + Price: g.Price, + } + // 处理价格 + if g.PriceType == giftTypeYuChi { + gfCfg.IsPaid = true + } + m.giftMap[gfCfg.ID] = gfCfg + } + } + + logger.SLog.Infof("Douyu 礼物配置读取成功 num:%d", len(m.giftMap)) +} + +func (m *MsgGiftHandler) DataType() string { + return "dgb" +} + +func (m *MsgGiftHandler) HandlerMessage(data []byte) { + var ret MsgGift + err := stt.Unmarshal(data, &ret) + if err != nil { + return + } + + if len(m.giftMap) == 0 { + m.setupGiftConfig() + } + + giftConfig := m.giftMap[ret.GiftId] + logger.SLog.Debugf("%s 赠送: %+vx%d", ret.Nickname, giftConfig, ret.GiftCount) + + //dmMsg := &pbMq.MqGift{ + // Platform: pbVars.Platform_name[int32(pbVars.Platform_Douyu)], + // LiveRoomId: m.liveRoomId, + // MsgId: "", + // Timestamp: time.Now().UnixMicro(), + // Uid: ret.UID, + // Uname: ret.Nickname, + // Avatar: "https://apic.douyucdn.cn/upload/" + ret.Avatar, // TODO 暂时组合,应该保留原始,然后客户端配置 + // NobilityLevel: 0, + // GiftId: ret.GiftId, + // GiftName: giftConfig.Name, + // GiftNum: ret.GiftCount, + // Price: giftConfig.Price, + // IsPaid: giftConfig.IsPaid, + // Type: pbMq.MqGift_NORMAL, + // PackGift: nil, + // + // FansMedalWearingStatus: ret.FansMedal.MedalRoomId == m.liveRoomId, + // FansMedalName: ret.FansMedal.MedalName, + // FansMedalLevel: ret.FansMedal.MedalLevel, + //} + //_ = m.producer.SendMessageAsync(dmMsg, strconv.FormatInt(dmMsg.Uid, 10)) +} diff --git a/douyu/msg_wss.go b/douyu/msg_wss.go new file mode 100644 index 0000000..39028aa --- /dev/null +++ b/douyu/msg_wss.go @@ -0,0 +1,46 @@ +package douyu + +var _ TypedData = (*MsgLoginReq)(nil) + +type ( + MsgLoginReq struct { + RoomId int64 `stt:"roomid"` // 直播间ID + //Dfl MsgLoginReqDfl `stt:"dfl"` // 不清楚含义,给默认值 + Dfl string `stt:"dfl"` + + Username string `stt:"username"` // 用户名,随机生成就好 + Uid int64 `stt:"uid"` // 用户ID + Ver string `stt:"ver"` // 版本 + AVer string `stt:"aver"` // 另一个版本 + Ct int `stt:"ct"` // 客户端类型 通常为0 + } + MsgLoginReqDfl struct { + Sn int32 `stt:"sn,omitempty"` + Ss int32 `stt:"ss,omitempty"` + } +) + +func (m *MsgLoginReq) DataType() string { + return "loginreq" +} + +var _ TypedData = (*MsgJoinGroup)(nil) + +type MsgJoinGroup struct { + RoomId int64 `stt:"rid"` + GId int `stt:"gid"` +} + +func (m *MsgJoinGroup) DataType() string { + return "joingroup" +} + +var _ TypedData = (*MsgHeartbeat)(nil) + +// MsgHeartbeat 服务端心跳消息,空消息 +type MsgHeartbeat struct { +} + +func (m *MsgHeartbeat) DataType() string { + return "mrkl" +} diff --git a/douyu/packer.go b/douyu/packer.go new file mode 100644 index 0000000..50e353b --- /dev/null +++ b/douyu/packer.go @@ -0,0 +1,129 @@ +package douyu + +import ( + "bytes" + "encoding/binary" + "fmt" + "github.com/gorilla/websocket" + "github.com/pkg/errors" + "live-gateway/ws" +) + +const ( + wsPackageLen = 4 + wsMsgTypeLen = 2 + wsSecretLen = 1 + wsReservedLen = 1 + wsEndingLen = 1 +) + +var _ ws.Packer = (*PackDouyu)(nil) + +// PackDouyu douyu packer +// +// | segment | type | size | offset | remark | +// | ---------- | ------ | ------- | -------| ------------------------------------------- | +// | `packageLen` | uint32 | 4 | 0 | 消息长度 4字节 | +// | `packageLen` | uint32 | 4 | 4 | 消息长度 4字节 | +// | `msgType` | uint16 | 2 | 8 | 消息类型 2字节 小端整数 | +// | `secret` | byte | 1 | 10 | 加密字段 1字节 暂时未用,默认为 0 | +// | `reserved` | byte | 1 | 11 | 1 | +// | `data` | []byte | dynamic | 12 | :数据内容,结尾'\0' | +// WsEntry -> ws.Entry#Raw +type PackDouyu struct { +} + +func NewPackDouyu() ws.Packer { + return &PackDouyu{} +} + +func (*PackDouyu) byteOrder() binary.ByteOrder { + return binary.LittleEndian +} + +// Unpack unpacks douyu.WsEntry to ws.Entry +func (p *PackDouyu) Unpack(v interface{}) (*ws.Entry, error) { + result := &ws.Entry{ + MessageType: websocket.BinaryMessage, + } + // v must be bilibili.WsEntry + entry, ok := v.(*WsEntry) + if !ok { + return nil, fmt.Errorf("[Pack-Douyu] 写入值类型必须为 %T", WsEntry{}) + } + var err error + + length := wsPackageLen + wsMsgTypeLen + wsSecretLen + + wsReservedLen + len(entry.data) + wsEndingLen // 8+len(data)+1 + + buffer := bytes.NewBuffer([]byte{}) + byteOrder := p.byteOrder() + err = binary.Write(buffer, byteOrder, int32(length)) + err = binary.Write(buffer, byteOrder, int32(length)) + err = binary.Write(buffer, byteOrder, entry.msgType) + err = binary.Write(buffer, byteOrder, entry.secret) + err = binary.Write(buffer, byteOrder, entry.reserved) + err = binary.Write(buffer, byteOrder, entry.data) + err = binary.Write(buffer, byteOrder, EndCharacter) + if err != nil { + return nil, err + } + result.Raw = buffer.Bytes() + return result, nil +} + +// Pack packs ws.Entry to douyu.WsEntry +func (p *PackDouyu) Pack(entry *ws.Entry) (interface{}, error) { + defer func() { + if err := recover(); err != nil { + } + }() + if entry.MessageType != websocket.BinaryMessage { + return nil, errors.New("err of msg") + } + var ( + offsetLen = wsPackageLen + offsetMsgType = offsetLen + offsetLen + offsetSecret = offsetMsgType + wsMsgTypeLen + offsetReserved = offsetSecret + wsSecretLen + offsetData = offsetReserved + wsReservedLen + ) + byteOrder := p.byteOrder() + raw := entry.Raw + + l := int(byteOrder.Uint32(raw[:offsetLen])) + wsPackageLen + + if len(raw) > l { + // 粘包 + var slice []*ws.Entry + for len(raw) > 0 { + tmp := &ws.Entry{ + MessageType: entry.MessageType, + } + length := int(byteOrder.Uint32(raw[:offsetLen])) + wsPackageLen + var ll = length + if len(raw) < length { + ll = len(raw) + } + tmp.Raw = raw[:ll] + + raw = raw[ll:] + slice = append(slice, tmp) + } + return slice, nil + } else { + // 独立包 + ent := &WsEntry{} + + ent.msgType = byteOrder.Uint16(raw[offsetMsgType : offsetMsgType+wsMsgTypeLen]) + ent.secret = raw[offsetSecret] + ent.reserved = raw[offsetReserved] + + endOffset := len(entry.Raw) + if entry.Raw[endOffset-1] == 0 { + endOffset -= 1 + } + ent.data = raw[offsetData:endOffset] + return ent, nil + } +} diff --git a/douyu/stt/config.go b/douyu/stt/config.go new file mode 100644 index 0000000..4aa32b7 --- /dev/null +++ b/douyu/stt/config.go @@ -0,0 +1,182 @@ +package stt + +import ( + "github.com/modern-go/concurrent" + "github.com/modern-go/reflect2" + "io" + "sync" +) + +// Config customize how the API should behave. +// The API is created from Config by frozenConfig. +type Config struct { + TagKey string // 自定义 tagKey 默认 frozenConfig + OnlyTaggedField bool // 是否仅限 tag 标记字段 + CaseSensitive bool // key是否大小写敏感 默认true +} + +// API the public interface of this package. +// Primary Marshal and Unmarshal. +type API interface { + IteratorPool + StreamPool + + Marshal(v any) ([]byte, error) + Unmarshal(data []byte, v any) error + RegisterExtension(extension Extension) + DecoderOf(typ reflect2.Type) ValDecoder + EncoderOf(typ reflect2.Type) ValEncoder + //NewEncoder(writer io.Writer) *Encoder + //NewDecoder(reader io.Reader) *Decoder +} + +var ConfigDefault = Config{}.Froze() + +type frozenConfig struct { + configBeforeFrozen Config + onlyTaggedField bool + caseSensitive bool + decoderCache *concurrent.Map + encoderCache *concurrent.Map + encoderExtension Extension + decoderExtension Extension + extraExtensions []Extension + streamPool *sync.Pool // marshal stream + iteratorPool *sync.Pool // unmarshal iter pool +} + +func (cfg *frozenConfig) initCache() { + cfg.decoderCache = concurrent.NewMap() + cfg.encoderCache = concurrent.NewMap() +} + +func (cfg *frozenConfig) addDecoderToCache(cacheKey uintptr, decoder ValDecoder) { + cfg.decoderCache.Store(cacheKey, decoder) +} + +func (cfg *frozenConfig) addEncoderToCache(cacheKey uintptr, encoder ValEncoder) { + cfg.encoderCache.Store(cacheKey, encoder) +} + +func (cfg *frozenConfig) getDecoderFromCache(cacheKey uintptr) ValDecoder { + decoder, found := cfg.decoderCache.Load(cacheKey) + if found { + return decoder.(ValDecoder) + } + return nil +} + +func (cfg *frozenConfig) getEncoderFromCache(cacheKey uintptr) ValEncoder { + encoder, found := cfg.encoderCache.Load(cacheKey) + if found { + return encoder.(ValEncoder) + } + return nil +} + +var cfgCache = concurrent.NewMap() + +func getFrozenConfigFromCache(cfg Config) *frozenConfig { + obj, found := cfgCache.Load(cfg) + if found { + return obj.(*frozenConfig) + } + return nil +} + +func addFrozenConfigToCache(cfg Config, frozenConfig *frozenConfig) { + cfgCache.Store(cfg, frozenConfig) +} + +///////////// Config + +func (cfg Config) Froze() API { + api := &frozenConfig{ + onlyTaggedField: cfg.OnlyTaggedField, + caseSensitive: cfg.CaseSensitive, + } + api.streamPool = &sync.Pool{ + New: func() interface{} { + return NewStream(api, nil, 512) + }, + } + api.iteratorPool = &sync.Pool{ + New: func() interface{} { + return NewIterator(api) + }, + } + api.initCache() + encoderExtension := EncoderExtension{} + decoderExtension := DecoderExtension{} + + api.encoderExtension = encoderExtension + api.decoderExtension = decoderExtension + api.configBeforeFrozen = cfg + return api +} + +func (cfg Config) frozeWithCacheReuse(extraExtensions []Extension) *frozenConfig { + api := getFrozenConfigFromCache(cfg) + if api != nil { + return api + } + api = cfg.Froze().(*frozenConfig) + for _, extension := range extraExtensions { + api.RegisterExtension(extension) + } + addFrozenConfigToCache(cfg, api) + return api +} + +func (cfg *frozenConfig) getTagKey() string { + tagKey := cfg.configBeforeFrozen.TagKey + if tagKey == "" { + return "stt" + } + return tagKey +} + +func (cfg *frozenConfig) RegisterExtension(extension Extension) { + cfg.extraExtensions = append(cfg.extraExtensions, extension) + copied := cfg.configBeforeFrozen + cfg.configBeforeFrozen = copied +} + +func (cfg *frozenConfig) Marshal(v interface{}) ([]byte, error) { + stream := cfg.BorrowStream(nil) + defer cfg.ReturnStream(stream) + stream.WriteVal(v) + if stream.Error != nil { + return nil, stream.Error + } + result := stream.Buffer() + copied := make([]byte, len(result)) + copy(copied, result) + return copied, nil +} + +func (cfg *frozenConfig) Unmarshal(data []byte, v interface{}) error { + // decrypt + validData := cfg.getValidData(data) + //iter := cfg.BorrowIterator(decryptBytes(validData, -1)) + iter := cfg.BorrowIterator(validData) + defer cfg.ReturnIterator(iter) + iter.ReadVal(v) + c := iter.nextToken() + if c == 0 { + if iter.Error == io.EOF { + return nil + } + return iter.Error + } + iter.ReportError("Unmarshal", "there are bytes left after unmarshal") + return iter.Error +} + +func (cfg *frozenConfig) getValidData(data []byte) []byte { + c := data[len(data)-1] + if c != '/' { + data = append(data, '/') + } + return data +} diff --git a/douyu/stt/iter.go b/douyu/stt/iter.go new file mode 100644 index 0000000..40144c2 --- /dev/null +++ b/douyu/stt/iter.go @@ -0,0 +1,294 @@ +package stt + +import ( + "fmt" + "io" +) + +// ValueType the type for JSON element +type ValueType int + +const ( + // InvalidValue invalid element + InvalidValue ValueType = iota + // StringValue element "string" + StringValue + // NumberValue element 100 or 0.10 + NumberValue + // NilValue element null + NilValue + // BoolValue element true or false + BoolValue +) + +var valueTypes []ValueType + +func init() { + valueTypes = make([]ValueType, 256) + for i := 0; i < len(valueTypes); i++ { + valueTypes[i] = StringValue + } + valueTypes['-'] = NumberValue + valueTypes['0'] = NumberValue + valueTypes['1'] = NumberValue + valueTypes['2'] = NumberValue + valueTypes['3'] = NumberValue + valueTypes['4'] = NumberValue + valueTypes['5'] = NumberValue + valueTypes['6'] = NumberValue + valueTypes['7'] = NumberValue + valueTypes['8'] = NumberValue + valueTypes['9'] = NumberValue + valueTypes['t'] = BoolValue + valueTypes['f'] = BoolValue + valueTypes['n'] = NilValue +} + +// Iterator is an io.Reader like object, with STT specific read functions. +// Error is not returned as return value, but stored as Error member on this iterator instance. +type Iterator struct { + cfg *frozenConfig + reader io.Reader + buf []byte + head int + tail int + depth int + captureStartedAt int + captured []byte + Error error + Attachment interface{} // open for customized decoder +} + +// NewIterator creates an empty Iterator instance +func NewIterator(cfg API) *Iterator { + return &Iterator{ + cfg: cfg.(*frozenConfig), + reader: nil, + buf: nil, + head: 0, + tail: 0, + depth: 0, + } +} + +// Parse creates an Iterator instance from io.Reader +func Parse(cfg API, reader io.Reader, bufSize int) *Iterator { + return &Iterator{ + cfg: cfg.(*frozenConfig), + reader: reader, + buf: make([]byte, bufSize), + head: 0, + tail: 0, + depth: 0, + } +} + +// ParseBytes creates an Iterator instance from byte array +func ParseBytes(cfg API, input []byte) *Iterator { + return &Iterator{ + cfg: cfg.(*frozenConfig), + reader: nil, + buf: input, + head: 0, + tail: len(input), + depth: 0, + } +} + +// ParseString creates an Iterator instance from string +func ParseString(cfg API, input string) *Iterator { + return ParseBytes(cfg, []byte(input)) +} + +// Pool returns a pool can provide more iterator with same configuration +func (iter *Iterator) Pool() IteratorPool { + return iter.cfg +} + +// Reset reuse iterator instance by specifying another reader +func (iter *Iterator) Reset(reader io.Reader) *Iterator { + iter.reader = reader + iter.head = 0 + iter.tail = 0 + iter.depth = 0 + return iter +} + +// ResetBytes reuse iterator instance by specifying another byte array as input +func (iter *Iterator) ResetBytes(input []byte) *Iterator { + iter.reader = nil + iter.buf = input + iter.head = 0 + iter.tail = len(input) + iter.depth = 0 + return iter +} + +// WhatIsNext gets ValueType of relatively next json element +func (iter *Iterator) WhatIsNext() ValueType { + valueType := valueTypes[iter.nextToken()] + iter.unreadByte() + return valueType +} + +func (iter *Iterator) skipWhitespacesWithoutLoadMore() bool { + for i := iter.head; i < iter.tail; i++ { + c := iter.buf[i] + switch c { + case ' ', '\n', '\t', '\r': + continue + } + iter.head = i + return false + } + return true +} + +func (iter *Iterator) nextToken() byte { + // a variation of skip whitespaces, returning the next non-whitespace token + for { + for i := iter.head; i < iter.tail; i++ { + c := iter.buf[i] + switch c { + case ' ', '\n', '\t', '\r': + continue + } + iter.head = i + 1 + return c + } + if !iter.loadMore() { + return 0 + } + } +} + +// ReportError record an error in iterator instance with current position. +func (iter *Iterator) ReportError(operation string, msg string) { + if iter.Error != nil { + if iter.Error != io.EOF { + return + } + } + peekStart := iter.head - 10 + if peekStart < 0 { + peekStart = 0 + } + peekEnd := iter.head + 10 + if peekEnd > iter.tail { + peekEnd = iter.tail + } + parsing := string(iter.buf[peekStart:peekEnd]) + contextStart := iter.head - 50 + if contextStart < 0 { + contextStart = 0 + } + contextEnd := iter.head + 50 + if contextEnd > iter.tail { + contextEnd = iter.tail + } + context := string(iter.buf[contextStart:contextEnd]) + iter.Error = fmt.Errorf("%s: %s, error found in #%v byte of ...|%s|..., bigger context ...|%s|...", + operation, msg, iter.head-peekStart, parsing, context) +} + +// CurrentBuffer gets current buffer as string for debugging purpose +func (iter *Iterator) CurrentBuffer() string { + peekStart := iter.head - 10 + if peekStart < 0 { + peekStart = 0 + } + return fmt.Sprintf("parsing #%v byte, around ...|%s|..., whole buffer ...|%s|...", iter.head, + string(iter.buf[peekStart:iter.head]), string(iter.buf[0:iter.tail])) +} + +func (iter *Iterator) readByte() (ret byte) { + if iter.head == iter.tail { + if iter.loadMore() { + ret = iter.buf[iter.head] + iter.head++ + return ret + } + return 0 + } + ret = iter.buf[iter.head] + iter.head++ + return ret +} + +func (iter *Iterator) loadMore() bool { + if iter.reader == nil { + if iter.Error == nil { + iter.head = iter.tail + iter.Error = io.EOF + } + return false + } + if iter.captured != nil { + iter.captured = append(iter.captured, + iter.buf[iter.captureStartedAt:iter.tail]...) + iter.captureStartedAt = 0 + } + for { + n, err := iter.reader.Read(iter.buf) + if n == 0 { + if err != nil { + if iter.Error == nil { + iter.Error = err + } + return false + } + } else { + iter.head = 0 + iter.tail = n + return true + } + } +} + +func (iter *Iterator) unreadByte() { + if iter.Error != nil { + return + } + iter.head-- + return +} + +// Read the next STT element as generic interface{}. +func (iter *Iterator) Read() interface{} { + valueType := iter.WhatIsNext() + switch valueType { + case StringValue: + return iter.ReadString() + case NumberValue: + return iter.ReadFloat64() + case NilValue: + iter.skipBytes('n', 'u', 'l', 'l') + return nil + case BoolValue: + return iter.ReadBool() + default: + iter.ReportError("Read", fmt.Sprintf("unexpected value type: %v", valueType)) + return nil + } +} + +// limit maximum depth of nesting +const maxDepth = 20 + +func (iter *Iterator) incrementDepth() (success bool) { + iter.depth++ + if iter.depth <= maxDepth { + return true + } + iter.ReportError("incrementDepth", "exceeded max depth") + return false +} + +func (iter *Iterator) decrementDepth() (success bool) { + iter.depth-- + if iter.depth >= 0 { + return true + } + iter.ReportError("decrementDepth", "unexpected negative nesting") + return false +} diff --git a/douyu/stt/iter_float.go b/douyu/stt/iter_float.go new file mode 100644 index 0000000..bd2bbea --- /dev/null +++ b/douyu/stt/iter_float.go @@ -0,0 +1,300 @@ +package stt + +import ( + "io" + "strconv" + "strings" + "unsafe" +) + +var floatDigits []int8 + +const invalidCharForNumber = int8(-1) +const endOfNumber = int8(-2) +const dotInNumber = int8(-3) + +func init() { + floatDigits = make([]int8, 256) + for i := 0; i < len(floatDigits); i++ { + floatDigits[i] = invalidCharForNumber + } + for i := int8('0'); i <= int8('9'); i++ { + floatDigits[i] = i - int8('0') + } + floatDigits[','] = endOfNumber + floatDigits['/'] = endOfNumber + floatDigits[' '] = endOfNumber + floatDigits['\t'] = endOfNumber + floatDigits['\n'] = endOfNumber + floatDigits['.'] = dotInNumber +} + +//ReadFloat32 read float32 +func (iter *Iterator) ReadFloat32() (ret float32) { + c := iter.nextToken() + if c == '-' { + return -iter.readPositiveFloat32() + } + iter.unreadByte() + return iter.readPositiveFloat32() +} + +func (iter *Iterator) readPositiveFloat32() (ret float32) { + i := iter.head + // first char + if i == iter.tail { + return iter.readFloat32SlowPath() + } + c := iter.buf[i] + i++ + ind := floatDigits[c] + switch ind { + case invalidCharForNumber: + return iter.readFloat32SlowPath() + case endOfNumber: + iter.ReportError("readFloat32", "empty number") + return + case dotInNumber: + iter.ReportError("readFloat32", "leading dot is invalid") + return + case 0: + if i == iter.tail { + return iter.readFloat32SlowPath() + } + c = iter.buf[i] + switch c { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + iter.ReportError("readFloat32", "leading zero is invalid") + return + } + } + value := uint64(ind) + // chars before dot +non_decimal_loop: + for ; i < iter.tail; i++ { + c = iter.buf[i] + ind := floatDigits[c] + switch ind { + case invalidCharForNumber: + return iter.readFloat32SlowPath() + case endOfNumber: + iter.head = i + return float32(value) + case dotInNumber: + break non_decimal_loop + } + if value > uint64SafeToMultiple10 { + return iter.readFloat32SlowPath() + } + value = (value << 3) + (value << 1) + uint64(ind) // value = value * 10 + ind; + } + // chars after dot + if c == '.' { + i++ + decimalPlaces := 0 + if i == iter.tail { + return iter.readFloat32SlowPath() + } + for ; i < iter.tail; i++ { + c = iter.buf[i] + ind := floatDigits[c] + switch ind { + case endOfNumber: + if decimalPlaces > 0 && decimalPlaces < len(pow10) { + iter.head = i + return float32(float64(value) / float64(pow10[decimalPlaces])) + } + // too many decimal places + return iter.readFloat32SlowPath() + case invalidCharForNumber, dotInNumber: + return iter.readFloat32SlowPath() + } + decimalPlaces++ + if value > uint64SafeToMultiple10 { + return iter.readFloat32SlowPath() + } + value = (value << 3) + (value << 1) + uint64(ind) + } + } + return iter.readFloat32SlowPath() +} + +func (iter *Iterator) readFloat32SlowPath() (ret float32) { + str := iter.readNumberAsString() + if iter.Error != nil && iter.Error != io.EOF { + return + } + errMsg := validateFloat(str) + if errMsg != "" { + iter.ReportError("readFloat32SlowPath", errMsg) + return + } + val, err := strconv.ParseFloat(str, 32) + if err != nil { + iter.Error = err + return + } + return float32(val) +} + +func (iter *Iterator) readNumberAsString() (ret string) { + strBuf := [16]byte{} + str := strBuf[0:0] +load_loop: + for { + for i := iter.head; i < iter.tail; i++ { + c := iter.buf[i] + switch c { + case '+', '-', '.', 'e', 'E', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + str = append(str, c) + continue + default: + iter.head = i + break load_loop + } + } + if !iter.loadMore() { + break + } + } + if iter.Error != nil && iter.Error != io.EOF { + return + } + if len(str) == 0 { + iter.ReportError("readNumberAsString", "invalid number") + } + return *(*string)(unsafe.Pointer(&str)) +} + +// ReadFloat64 read float64 +func (iter *Iterator) ReadFloat64() (ret float64) { + c := iter.nextToken() + if c == '-' { + return -iter.readPositiveFloat64() + } + iter.unreadByte() + return iter.readPositiveFloat64() +} + +func (iter *Iterator) readPositiveFloat64() (ret float64) { + i := iter.head + // first char + if i == iter.tail { + return iter.readFloat64SlowPath() + } + c := iter.buf[i] + i++ + ind := floatDigits[c] + switch ind { + case invalidCharForNumber: + return iter.readFloat64SlowPath() + case endOfNumber: + iter.ReportError("readFloat64", "empty number") + return + case dotInNumber: + iter.ReportError("readFloat64", "leading dot is invalid") + return + case 0: + if i == iter.tail { + return iter.readFloat64SlowPath() + } + c = iter.buf[i] + switch c { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + iter.ReportError("readFloat64", "leading zero is invalid") + return + } + } + value := uint64(ind) + // chars before dot +non_decimal_loop: + for ; i < iter.tail; i++ { + c = iter.buf[i] + ind := floatDigits[c] + switch ind { + case invalidCharForNumber: + return iter.readFloat64SlowPath() + case endOfNumber: + iter.head = i + return float64(value) + case dotInNumber: + break non_decimal_loop + } + if value > uint64SafeToMultiple10 { + return iter.readFloat64SlowPath() + } + value = (value << 3) + (value << 1) + uint64(ind) // value = value * 10 + ind; + } + // chars after dot + if c == '.' { + i++ + decimalPlaces := 0 + if i == iter.tail { + return iter.readFloat64SlowPath() + } + for ; i < iter.tail; i++ { + c = iter.buf[i] + ind := floatDigits[c] + switch ind { + case endOfNumber: + if decimalPlaces > 0 && decimalPlaces < len(pow10) { + iter.head = i + return float64(value) / float64(pow10[decimalPlaces]) + } + // too many decimal places + return iter.readFloat64SlowPath() + case invalidCharForNumber, dotInNumber: + return iter.readFloat64SlowPath() + } + decimalPlaces++ + if value > uint64SafeToMultiple10 { + return iter.readFloat64SlowPath() + } + value = (value << 3) + (value << 1) + uint64(ind) + if value > maxFloat64 { + return iter.readFloat64SlowPath() + } + } + } + return iter.readFloat64SlowPath() +} + +func (iter *Iterator) readFloat64SlowPath() (ret float64) { + str := iter.readNumberAsString() + if iter.Error != nil && iter.Error != io.EOF { + return + } + errMsg := validateFloat(str) + if errMsg != "" { + iter.ReportError("readFloat64SlowPath", errMsg) + return + } + val, err := strconv.ParseFloat(str, 64) + if err != nil { + iter.Error = err + return + } + return val +} + +func validateFloat(str string) string { + // strconv.ParseFloat is not validating `1.` or `1.e1` + if len(str) == 0 { + return "empty number" + } + if str[0] == '-' { + return "-- is not valid" + } + dotPos := strings.IndexByte(str, '.') + if dotPos != -1 { + if dotPos == len(str)-1 { + return "dot can not be last character" + } + switch str[dotPos+1] { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + default: + return "missing digit after dot" + } + } + return "" +} diff --git a/douyu/stt/iter_int.go b/douyu/stt/iter_int.go new file mode 100644 index 0000000..2d6b9a3 --- /dev/null +++ b/douyu/stt/iter_int.go @@ -0,0 +1,346 @@ +package stt + +import ( + "math" + "strconv" +) + +var intDigits []int8 + +const uint32SafeToMultiply10 = uint32(0xffffffff)/10 - 1 +const uint64SafeToMultiple10 = uint64(0xffffffffffffffff)/10 - 1 +const maxFloat64 = 1<<53 - 1 + +func init() { + intDigits = make([]int8, 256) + for i := 0; i < len(intDigits); i++ { + intDigits[i] = invalidCharForNumber + } + for i := int8('0'); i <= int8('9'); i++ { + intDigits[i] = i - int8('0') + } +} + +// ReadUint read uint +func (iter *Iterator) ReadUint() uint { + if strconv.IntSize == 32 { + return uint(iter.ReadUint32()) + } + return uint(iter.ReadUint64()) +} + +// ReadInt read int +func (iter *Iterator) ReadInt() int { + if strconv.IntSize == 32 { + return int(iter.ReadInt32()) + } + return int(iter.ReadInt64()) +} + +// ReadInt8 read int8 +func (iter *Iterator) ReadInt8() (ret int8) { + c := iter.nextToken() + if c == '-' { + val := iter.readUint32(iter.readByte()) + if val > math.MaxInt8+1 { + iter.ReportError("ReadInt8", "overflow: "+strconv.FormatInt(int64(val), 10)) + return + } + return -int8(val) + } + val := iter.readUint32(c) + if val > math.MaxInt8 { + iter.ReportError("ReadInt8", "overflow: "+strconv.FormatInt(int64(val), 10)) + return + } + return int8(val) +} + +// ReadUint8 read uint8 +func (iter *Iterator) ReadUint8() (ret uint8) { + val := iter.readUint32(iter.nextToken()) + if val > math.MaxUint8 { + iter.ReportError("ReadUint8", "overflow: "+strconv.FormatInt(int64(val), 10)) + return + } + return uint8(val) +} + +// ReadInt16 read int16 +func (iter *Iterator) ReadInt16() (ret int16) { + c := iter.nextToken() + if c == '-' { + val := iter.readUint32(iter.readByte()) + if val > math.MaxInt16+1 { + iter.ReportError("ReadInt16", "overflow: "+strconv.FormatInt(int64(val), 10)) + return + } + return -int16(val) + } + val := iter.readUint32(c) + if val > math.MaxInt16 { + iter.ReportError("ReadInt16", "overflow: "+strconv.FormatInt(int64(val), 10)) + return + } + return int16(val) +} + +// ReadUint16 read uint16 +func (iter *Iterator) ReadUint16() (ret uint16) { + val := iter.readUint32(iter.nextToken()) + if val > math.MaxUint16 { + iter.ReportError("ReadUint16", "overflow: "+strconv.FormatInt(int64(val), 10)) + return + } + return uint16(val) +} + +// ReadInt32 read int32 +func (iter *Iterator) ReadInt32() (ret int32) { + c := iter.nextToken() + if c == '-' { + val := iter.readUint32(iter.readByte()) + if val > math.MaxInt32+1 { + iter.ReportError("ReadInt32", "overflow: "+strconv.FormatInt(int64(val), 10)) + return + } + return -int32(val) + } + val := iter.readUint32(c) + if val > math.MaxInt32 { + iter.ReportError("ReadInt32", "overflow: "+strconv.FormatInt(int64(val), 10)) + return + } + return int32(val) +} + +// ReadUint32 read uint32 +func (iter *Iterator) ReadUint32() (ret uint32) { + return iter.readUint32(iter.nextToken()) +} + +func (iter *Iterator) readUint32(c byte) (ret uint32) { + ind := intDigits[c] + if ind == 0 { + iter.assertInteger() + return 0 // single zero + } + if ind == invalidCharForNumber { + iter.ReportError("readUint32", "unexpected character: "+string([]byte{byte(ind)})) + return + } + value := uint32(ind) + if iter.tail-iter.head > 10 { + i := iter.head + ind2 := intDigits[iter.buf[i]] + if ind2 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value + } + i++ + ind3 := intDigits[iter.buf[i]] + if ind3 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*10 + uint32(ind2) + } + //iter.head = i + 1 + //value = value * 100 + uint32(ind2) * 10 + uint32(ind3) + i++ + ind4 := intDigits[iter.buf[i]] + if ind4 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*100 + uint32(ind2)*10 + uint32(ind3) + } + i++ + ind5 := intDigits[iter.buf[i]] + if ind5 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*1000 + uint32(ind2)*100 + uint32(ind3)*10 + uint32(ind4) + } + i++ + ind6 := intDigits[iter.buf[i]] + if ind6 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*10000 + uint32(ind2)*1000 + uint32(ind3)*100 + uint32(ind4)*10 + uint32(ind5) + } + i++ + ind7 := intDigits[iter.buf[i]] + if ind7 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*100000 + uint32(ind2)*10000 + uint32(ind3)*1000 + uint32(ind4)*100 + uint32(ind5)*10 + uint32(ind6) + } + i++ + ind8 := intDigits[iter.buf[i]] + if ind8 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*1000000 + uint32(ind2)*100000 + uint32(ind3)*10000 + uint32(ind4)*1000 + uint32(ind5)*100 + uint32(ind6)*10 + uint32(ind7) + } + i++ + ind9 := intDigits[iter.buf[i]] + value = value*10000000 + uint32(ind2)*1000000 + uint32(ind3)*100000 + uint32(ind4)*10000 + uint32(ind5)*1000 + uint32(ind6)*100 + uint32(ind7)*10 + uint32(ind8) + iter.head = i + if ind9 == invalidCharForNumber { + iter.assertInteger() + return value + } + } + for { + for i := iter.head; i < iter.tail; i++ { + ind = intDigits[iter.buf[i]] + if ind == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value + } + if value > uint32SafeToMultiply10 { + value2 := (value << 3) + (value << 1) + uint32(ind) + if value2 < value { + iter.ReportError("readUint32", "overflow") + return + } + value = value2 + continue + } + value = (value << 3) + (value << 1) + uint32(ind) + } + if !iter.loadMore() { + iter.assertInteger() + return value + } + } +} + +// ReadInt64 read int64 +func (iter *Iterator) ReadInt64() (ret int64) { + c := iter.nextToken() + if c == '-' { + val := iter.readUint64(iter.readByte()) + if val > math.MaxInt64+1 { + iter.ReportError("ReadInt64", "overflow: "+strconv.FormatUint(uint64(val), 10)) + return + } + return -int64(val) + } + val := iter.readUint64(c) + if val > math.MaxInt64 { + iter.ReportError("ReadInt64", "overflow: "+strconv.FormatUint(uint64(val), 10)) + return + } + return int64(val) +} + +// ReadUint64 read uint64 +func (iter *Iterator) ReadUint64() uint64 { + return iter.readUint64(iter.nextToken()) +} + +func (iter *Iterator) readUint64(c byte) (ret uint64) { + ind := intDigits[c] + if ind == 0 { + iter.assertInteger() + return 0 // single zero + } + if ind == invalidCharForNumber { + iter.ReportError("readUint64", "unexpected character: "+string([]byte{byte(ind)})) + return + } + value := uint64(ind) + if iter.tail-iter.head > 10 { + i := iter.head + ind2 := intDigits[iter.buf[i]] + if ind2 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value + } + i++ + ind3 := intDigits[iter.buf[i]] + if ind3 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*10 + uint64(ind2) + } + //iter.head = i + 1 + //value = value * 100 + uint32(ind2) * 10 + uint32(ind3) + i++ + ind4 := intDigits[iter.buf[i]] + if ind4 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*100 + uint64(ind2)*10 + uint64(ind3) + } + i++ + ind5 := intDigits[iter.buf[i]] + if ind5 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*1000 + uint64(ind2)*100 + uint64(ind3)*10 + uint64(ind4) + } + i++ + ind6 := intDigits[iter.buf[i]] + if ind6 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*10000 + uint64(ind2)*1000 + uint64(ind3)*100 + uint64(ind4)*10 + uint64(ind5) + } + i++ + ind7 := intDigits[iter.buf[i]] + if ind7 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*100000 + uint64(ind2)*10000 + uint64(ind3)*1000 + uint64(ind4)*100 + uint64(ind5)*10 + uint64(ind6) + } + i++ + ind8 := intDigits[iter.buf[i]] + if ind8 == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value*1000000 + uint64(ind2)*100000 + uint64(ind3)*10000 + uint64(ind4)*1000 + uint64(ind5)*100 + uint64(ind6)*10 + uint64(ind7) + } + i++ + ind9 := intDigits[iter.buf[i]] + value = value*10000000 + uint64(ind2)*1000000 + uint64(ind3)*100000 + uint64(ind4)*10000 + uint64(ind5)*1000 + uint64(ind6)*100 + uint64(ind7)*10 + uint64(ind8) + iter.head = i + if ind9 == invalidCharForNumber { + iter.assertInteger() + return value + } + } + for { + for i := iter.head; i < iter.tail; i++ { + ind = intDigits[iter.buf[i]] + if ind == invalidCharForNumber { + iter.head = i + iter.assertInteger() + return value + } + if value > uint64SafeToMultiple10 { + value2 := (value << 3) + (value << 1) + uint64(ind) + if value2 < value { + iter.ReportError("readUint64", "overflow") + return + } + value = value2 + continue + } + value = (value << 3) + (value << 1) + uint64(ind) + } + if !iter.loadMore() { + iter.assertInteger() + return value + } + } +} + +func (iter *Iterator) assertInteger() { + if iter.head < iter.tail && iter.buf[iter.head] == '.' { + iter.ReportError("assertInteger", "can not decode float as int") + } +} diff --git a/douyu/stt/iter_skip.go b/douyu/stt/iter_skip.go new file mode 100644 index 0000000..9246328 --- /dev/null +++ b/douyu/stt/iter_skip.go @@ -0,0 +1,68 @@ +package stt + +import "fmt" + +// ReadNil reads a json object as nil and +// returns whether it's a nil or not +func (iter *Iterator) ReadNil() (ret bool) { + c := iter.nextToken() + if c == 'n' { + iter.skipBytes('u', 'l', 'l') // null + return true + } + iter.unreadByte() + return false +} + +// ReadBool reads t/true/f/false/1/0 as BoolValue +func (iter *Iterator) ReadBool() (ret bool) { + c := iter.nextToken() + if c == 't' { + iter.skipBytes('r', 'u', 'e') + return true + } + if c == 'f' { + iter.skipBytes('a', 'l', 's', 'e') + return false + } + if c == '0' { + return false + } + if c == '1' { + return true + } + iter.ReportError("ReadBool", "expect t/true/1 or f/false/0, but found "+string([]byte{c})) + return +} + +// Skip skips an object and positions to relatively the next object +func (iter *Iterator) Skip() { + iter.skipString() +} + +func (iter *Iterator) skipString() { + if !iter.trySkipString() { + iter.unreadByte() + iter.ReadString() + } +} + +func (iter *Iterator) trySkipString() bool { + for i := iter.head; i < iter.tail; i++ { + c := iter.buf[i] + if c == '/' { + iter.head = i + return true // valid + } + } + return false +} + +func (iter *Iterator) skipBytes(bytes ...byte) { + for _, b := range bytes { + if iter.readByte() != b { + iter.ReportError("skipBytes", fmt.Sprintf("expect %s", string(bytes))) + return + } + } +} diff --git a/douyu/stt/iter_str.go b/douyu/stt/iter_str.go new file mode 100644 index 0000000..f2a3088 --- /dev/null +++ b/douyu/stt/iter_str.go @@ -0,0 +1,44 @@ +package stt + +import "bytes" + +// ReadString read string from iterator +func (iter *Iterator) ReadString() (ret string) { + depthEnding := encryptBytes([]byte{'/'}, iter.depth) + depthEndingLen := len(depthEnding) + for i := iter.head; i < iter.tail; i++ { + if iter.depth > 1 { + if i+depthEndingLen >= iter.tail { + return + } + endings := iter.buf[i : i+depthEndingLen] + if bytes.Equal(endings, depthEnding) { + ret = string(iter.buf[iter.head:i]) + iter.head = i + break + } + } else { + c := iter.buf[i] + if c == '/' { + ret = string(iter.buf[iter.head:i]) + iter.head = i + break + } + } + } + ret = string(decryptBytes([]byte(ret), iter.depth)) + return +} + +func (iter *Iterator) ReadFieldName() (ret string) { + // key@=value/ + for i := iter.head; i < iter.tail; i++ { + c := iter.buf[i] + if c == '@' { + ret = string(iter.buf[iter.head:i]) + iter.head = i + return ret + } + } + return +} diff --git a/douyu/stt/iter_str_test.go b/douyu/stt/iter_str_test.go new file mode 100644 index 0000000..614f58a --- /dev/null +++ b/douyu/stt/iter_str_test.go @@ -0,0 +1,28 @@ +package stt + +import ( + "fmt" + "testing" +) + +func TestReadFieldName(t *testing.T) { + //str := "type@=chatmsg/a@=false/bbbbbb@=123/c@=cid@AA=xx1@AScnn@AA=xx2@AS/" + //str := "type@=chatmsg/a@=false/bbbbbb@=123/c@=/d@=/" + str := "type@=noble_num_info/sum@=28/vn@=16944/rid@=2947432/list@=lev@AA=8@ASnum@AA=4@AS@Slev@AA=5@ASnum@AA=1@AS@Slev@AA=4@ASnum@AA=3@AS@Slev@AA=3@ASnum@AA=1@AS@Slev@AA=1@ASnum@AA=3@AS@Slev@AA=7@ASnum@AA=16@AS@S/" + //iter := ConfigDefault.BorrowIterator([]byte(str)) + //fieldName := iter.ReadFieldName() + var s struct { + //Type string `stt:"type"` + A bool `stt:"a"` + B int32 `stt:"bbbbbb"` + C struct { + Cid string `stt:"cid"` + Cnn string `stt:"cnn"` + } `stt:"c,omitempty"` + } + err := Unmarshal([]byte(str), &s) + fmt.Println(s, err) + + ret, err := Marshal(&s) + fmt.Println(string(ret), err) +} diff --git a/douyu/stt/pool.go b/douyu/stt/pool.go new file mode 100644 index 0000000..5d9468c --- /dev/null +++ b/douyu/stt/pool.go @@ -0,0 +1,40 @@ +package stt + +import "io" + +// IteratorPool a thread safe pool of iterators with same configuration +type IteratorPool interface { + BorrowIterator(data []byte) *Iterator + ReturnIterator(iter *Iterator) +} + +// StreamPool a thread safe pool of streams with same configuration +type StreamPool interface { + BorrowStream(writer io.Writer) *Stream + ReturnStream(stream *Stream) +} + +func (cfg *frozenConfig) BorrowStream(writer io.Writer) *Stream { + stream := cfg.streamPool.Get().(*Stream) + stream.Reset(writer) + return stream +} + +func (cfg *frozenConfig) ReturnStream(stream *Stream) { + stream.out = nil + stream.Error = nil + stream.Attachment = nil + cfg.streamPool.Put(stream) +} + +func (cfg *frozenConfig) BorrowIterator(data []byte) *Iterator { + iter := cfg.iteratorPool.Get().(*Iterator) + iter.ResetBytes(data) + return iter +} + +func (cfg *frozenConfig) ReturnIterator(iter *Iterator) { + iter.Error = nil + iter.Attachment = nil + cfg.iteratorPool.Put(iter) +} diff --git a/douyu/stt/reflect.go b/douyu/stt/reflect.go new file mode 100644 index 0000000..537f62e --- /dev/null +++ b/douyu/stt/reflect.go @@ -0,0 +1,299 @@ +package stt + +import ( + "fmt" + "github.com/modern-go/reflect2" + "reflect" + "unsafe" +) + +// ValDecoder is an internal type registered to cache as needed. +// 反射获取类型并缓存 +// 尽可能地避免反射获取值,reflect.Value本身具有alloc,但具有如下例外 +// 1. 创建新值的实例,如: *int 将需要 int +// 2. slice的append,如果容量(cap)不够,将使用Reflect.New从而alloc +// 3. 将值分配给map,key/value都是reflect.Value +// 对于简单struct的绑定,应该尽量避免reflect.Value和其本身的alloc开销 +type ValDecoder interface { + Decode(ptr unsafe.Pointer, iter *Iterator) +} + +// ValEncoder is an internal type registered to cache as needed. +type ValEncoder interface { + IsEmpty(ptr unsafe.Pointer) bool + Encode(ptr unsafe.Pointer, stream *Stream) +} + +type checkIsEmpty interface { + IsEmpty(ptr unsafe.Pointer) bool +} + +type ctx struct { + *frozenConfig + prefix string + encoders map[reflect2.Type]ValEncoder + decoders map[reflect2.Type]ValDecoder +} + +func (b *ctx) caseSensitive() bool { + if b.frozenConfig == nil { + // default is case-insensitive + return false + } + return b.frozenConfig.caseSensitive +} + +func (b *ctx) append(prefix string) *ctx { + return &ctx{ + frozenConfig: b.frozenConfig, + prefix: b.prefix + " " + prefix, + encoders: b.encoders, + decoders: b.decoders, + } +} + +// ReadVal copy the underlying STT into go interface +func (iter *Iterator) ReadVal(obj interface{}) { + depth := iter.depth + cacheKey := reflect2.RTypeOf(obj) + decoder := iter.cfg.getDecoderFromCache(cacheKey) + if decoder == nil { + typ := reflect2.TypeOf(obj) + if typ == nil || typ.Kind() != reflect.Ptr { + iter.ReportError("ReadVal", "can only unmarshal into pointer") + return + } + decoder = iter.cfg.DecoderOf(typ) + } + ptr := reflect2.PtrOf(obj) + if ptr == nil { + iter.ReportError("ReadVal", "can not read into nil pointer") + return + } + decoder.Decode(ptr, iter) + if iter.depth != depth { + iter.ReportError("ReadVal", "unexpected mismatched nesting") + return + } +} + +// WriteVal copy the go interface into underlying JSON, same as json.Marshal +func (stream *Stream) WriteVal(val interface{}) { + if nil == val { + stream.WriteNil() + return + } + cacheKey := reflect2.RTypeOf(val) + encoder := stream.cfg.getEncoderFromCache(cacheKey) + if encoder == nil { + typ := reflect2.TypeOf(val) + encoder = stream.cfg.EncoderOf(typ) + } + encoder.Encode(reflect2.PtrOf(val), stream) +} + +func (cfg *frozenConfig) DecoderOf(typ reflect2.Type) ValDecoder { + cacheKey := typ.RType() + decoder := cfg.getDecoderFromCache(cacheKey) + if decoder != nil { + return decoder + } + ctx := &ctx{ + frozenConfig: cfg, + prefix: "", + decoders: map[reflect2.Type]ValDecoder{}, + encoders: map[reflect2.Type]ValEncoder{}, + } + ptrType := typ.(*reflect2.UnsafePtrType) + decoder = decoderOfType(ctx, ptrType.Elem()) + cfg.addDecoderToCache(cacheKey, decoder) + return decoder +} + +func decoderOfType(ctx *ctx, typ reflect2.Type) ValDecoder { + decoder := getTypeDecoderFromExtension(ctx, typ) + if decoder != nil { + return decoder + } + decoder = createDecoderOfType(ctx, typ) + for _, extension := range extensions { + decoder = extension.DecorateDecoder(typ, decoder) + } + decoder = ctx.decoderExtension.DecorateDecoder(typ, decoder) + for _, extension := range ctx.extraExtensions { + decoder = extension.DecorateDecoder(typ, decoder) + } + return decoder +} + +func createDecoderOfType(ctx *ctx, typ reflect2.Type) ValDecoder { + decoder := ctx.decoders[typ] + if decoder != nil { + return decoder + } + placeholder := &placeholderDecoder{} + ctx.decoders[typ] = placeholder + decoder = _createDecoderOfType(ctx, typ) + placeholder.decoder = decoder + return decoder +} + +func _createDecoderOfType(ctx *ctx, typ reflect2.Type) ValDecoder { + decoder := createDecoderOfNative(ctx, typ) + if decoder != nil { + return decoder + } + switch typ.Kind() { + case reflect.Interface: + ifaceType, isIFace := typ.(*reflect2.UnsafeIFaceType) + if isIFace { + return &ifaceDecoder{valType: ifaceType} + } + return &efaceDecoder{} + case reflect.Struct: + return decoderOfStruct(ctx, typ) + case reflect.Array: + return decoderOfArray(ctx, typ) + case reflect.Slice: + return decoderOfSlice(ctx, typ) + case reflect.Map: + return decoderOfMap(ctx, typ) + case reflect.Ptr: + return decoderOfOptional(ctx, typ) + default: + return &lazyErrorDecoder{err: fmt.Errorf("%s%s is unsupported type", ctx.prefix, typ.String())} + } +} + +func (cfg *frozenConfig) EncoderOf(typ reflect2.Type) ValEncoder { + cacheKey := typ.RType() + encoder := cfg.getEncoderFromCache(cacheKey) + if encoder != nil { + return encoder + } + ctx := &ctx{ + frozenConfig: cfg, + prefix: "", + decoders: map[reflect2.Type]ValDecoder{}, + encoders: map[reflect2.Type]ValEncoder{}, + } + encoder = encoderOfType(ctx, typ) + if typ.LikePtr() { + encoder = &onePtrEncoder{encoder} + } + cfg.addEncoderToCache(cacheKey, encoder) + return encoder +} + +type onePtrEncoder struct { + encoder ValEncoder +} + +func (encoder *onePtrEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return encoder.encoder.IsEmpty(unsafe.Pointer(&ptr)) +} + +func (encoder *onePtrEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + encoder.encoder.Encode(unsafe.Pointer(&ptr), stream) +} + +func encoderOfType(ctx *ctx, typ reflect2.Type) ValEncoder { + encoder := getTypeEncoderFromExtension(ctx, typ) + if encoder != nil { + return encoder + } + encoder = createEncoderOfType(ctx, typ) + for _, extension := range extensions { + encoder = extension.DecorateEncoder(typ, encoder) + } + encoder = ctx.encoderExtension.DecorateEncoder(typ, encoder) + for _, extension := range ctx.extraExtensions { + encoder = extension.DecorateEncoder(typ, encoder) + } + return encoder +} + +func createEncoderOfType(ctx *ctx, typ reflect2.Type) ValEncoder { + encoder := ctx.encoders[typ] + if encoder != nil { + return encoder + } + placeholder := &placeholderEncoder{} + ctx.encoders[typ] = placeholder + encoder = _createEncoderOfType(ctx, typ) + placeholder.encoder = encoder + return encoder +} +func _createEncoderOfType(ctx *ctx, typ reflect2.Type) ValEncoder { + encoder := createEncoderOfNative(ctx, typ) + if encoder != nil { + return encoder + } + kind := typ.Kind() + switch kind { + case reflect.Interface: + return &dynamicEncoder{typ} + case reflect.Struct: + return encoderOfStruct(ctx, typ) + case reflect.Array: + return encoderOfArray(ctx, typ) + case reflect.Slice: + return encoderOfSlice(ctx, typ) + case reflect.Map: + return encoderOfMap(ctx, typ) + case reflect.Ptr: + return encoderOfOptional(ctx, typ) + default: + return &lazyErrorEncoder{err: fmt.Errorf("%s%s is unsupported type", ctx.prefix, typ.String())} + } +} + +type lazyErrorDecoder struct { + err error +} + +func (decoder *lazyErrorDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + if iter.WhatIsNext() != NilValue { + if iter.Error == nil { + iter.Error = decoder.err + } + } else { + iter.Skip() + } +} + +type lazyErrorEncoder struct { + err error +} + +func (encoder *lazyErrorEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + if ptr == nil { + stream.WriteNil() + } else if stream.Error == nil { + stream.Error = encoder.err + } +} + +func (encoder *lazyErrorEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + +type placeholderDecoder struct { + decoder ValDecoder +} + +func (decoder *placeholderDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + decoder.decoder.Decode(ptr, iter) +} + +type placeholderEncoder struct { + encoder ValEncoder +} + +func (encoder *placeholderEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + encoder.encoder.Encode(ptr, stream) +} + +func (encoder *placeholderEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return encoder.encoder.IsEmpty(ptr) +} diff --git a/douyu/stt/reflect_array.go b/douyu/stt/reflect_array.go new file mode 100644 index 0000000..ed84150 --- /dev/null +++ b/douyu/stt/reflect_array.go @@ -0,0 +1,104 @@ +package stt + +import ( + "fmt" + "github.com/modern-go/reflect2" + "io" + "unsafe" +) + +func decoderOfArray(ctx *ctx, typ reflect2.Type) ValDecoder { + arrayType := typ.(*reflect2.UnsafeArrayType) + decoder := decoderOfType(ctx.append("[arrayElem]"), arrayType.Elem()) + return &arrayDecoder{arrayType, decoder} +} + +func encoderOfArray(ctx *ctx, typ reflect2.Type) ValEncoder { + arrayType := typ.(*reflect2.UnsafeArrayType) + if arrayType.Len() == 0 { + return emptyArrayEncoder{} + } + encoder := encoderOfType(ctx.append("[arrayElem]"), arrayType.Elem()) + return &arrayEncoder{arrayType, encoder} +} + +type emptyArrayEncoder struct{} + +func (encoder emptyArrayEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + //stream.WriteEmptyArray() +} + +func (encoder emptyArrayEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return true +} + +type arrayEncoder struct { + arrayType *reflect2.UnsafeArrayType + elemEncoder ValEncoder +} + +func (encoder *arrayEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + //stream.WriteArrayStart() + //elemPtr := unsafe.Pointer(ptr) + //encoder.elemEncoder.Encode(elemPtr, stream) + //for i := 1; i < encoder.arrayType.Len(); i++ { + // stream.WriteMore() + // elemPtr = encoder.arrayType.UnsafeGetIndex(ptr, i) + // encoder.elemEncoder.Encode(elemPtr, stream) + //} + //stream.WriteArrayEnd() + //if stream.Error != nil && stream.Error != io.EOF { + // stream.Error = fmt.Errorf("%v: %s", encoder.arrayType, stream.Error.Error()) + //} +} + +func (encoder *arrayEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + +type arrayDecoder struct { + arrayType *reflect2.UnsafeArrayType + elemDecoder ValDecoder +} + +func (decoder *arrayDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + decoder.doDecode(ptr, iter) + if iter.Error != nil && iter.Error != io.EOF { + iter.Error = fmt.Errorf("%v: %s", decoder.arrayType, iter.Error.Error()) + } +} + +func (decoder *arrayDecoder) doDecode(ptr unsafe.Pointer, iter *Iterator) { + c := iter.nextToken() + arrayType := decoder.arrayType + if c == 'n' { + iter.skipBytes('u', 'l', 'l') + return + } + if c != '[' { + iter.ReportError("decode array", "expect [ or n, but found "+string([]byte{c})) + return + } + c = iter.nextToken() + if c == ']' { + return + } + iter.unreadByte() + elemPtr := arrayType.UnsafeGetIndex(ptr, 0) + decoder.elemDecoder.Decode(elemPtr, iter) + length := 1 + for c = iter.nextToken(); c == ','; c = iter.nextToken() { + if length >= arrayType.Len() { + iter.Skip() + continue + } + idx := length + length += 1 + elemPtr = arrayType.UnsafeGetIndex(ptr, idx) + decoder.elemDecoder.Decode(elemPtr, iter) + } + if c != ']' { + iter.ReportError("decode array", "expect ], but found "+string([]byte{c})) + return + } +} diff --git a/douyu/stt/reflect_dynamic.go b/douyu/stt/reflect_dynamic.go new file mode 100644 index 0000000..87088d9 --- /dev/null +++ b/douyu/stt/reflect_dynamic.go @@ -0,0 +1,70 @@ +package stt + +import ( + "github.com/modern-go/reflect2" + "reflect" + "unsafe" +) + +type dynamicEncoder struct { + valType reflect2.Type +} + +func (encoder *dynamicEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + obj := encoder.valType.UnsafeIndirect(ptr) + stream.WriteVal(obj) +} + +func (encoder *dynamicEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return encoder.valType.UnsafeIndirect(ptr) == nil +} + +type efaceDecoder struct { +} + +func (decoder *efaceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + pObj := (*interface{})(ptr) + obj := *pObj + if obj == nil { + *pObj = iter.Read() + return + } + typ := reflect2.TypeOf(obj) + if typ.Kind() != reflect.Ptr { + *pObj = iter.Read() + return + } + ptrType := typ.(*reflect2.UnsafePtrType) + ptrElemType := ptrType.Elem() + if iter.WhatIsNext() == NilValue { + if ptrElemType.Kind() != reflect.Ptr { + iter.skipBytes('n', 'u', 'l', 'l') + *pObj = nil + return + } + } + if reflect2.IsNil(obj) { + obj := ptrElemType.New() + iter.ReadVal(obj) + *pObj = obj + return + } + iter.ReadVal(obj) +} + +type ifaceDecoder struct { + valType *reflect2.UnsafeIFaceType +} + +func (decoder *ifaceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + if iter.ReadNil() { + decoder.valType.UnsafeSet(ptr, decoder.valType.UnsafeNew()) + return + } + obj := decoder.valType.UnsafeIndirect(ptr) + if reflect2.IsNil(obj) { + iter.ReportError("decode non empty interface", "can not unmarshal into nil") + return + } + iter.ReadVal(obj) +} diff --git a/douyu/stt/reflect_extension.go b/douyu/stt/reflect_extension.go new file mode 100644 index 0000000..6725d71 --- /dev/null +++ b/douyu/stt/reflect_extension.go @@ -0,0 +1,476 @@ +package stt + +import ( + "fmt" + "github.com/modern-go/reflect2" + "reflect" + "sort" + "strings" + "unicode" + "unsafe" +) + +var typeDecoders = map[string]ValDecoder{} +var fieldDecoders = map[string]ValDecoder{} +var typeEncoders = map[string]ValEncoder{} +var fieldEncoders = map[string]ValEncoder{} +var extensions = []Extension{} + +// StructDescriptor describe how should we encode/decode the struct +type StructDescriptor struct { + Type reflect2.Type + Fields []*Binding +} + +// GetField get one field from the descriptor by its name. +// Can not use map here to keep field orders. +func (structDescriptor *StructDescriptor) GetField(fieldName string) *Binding { + for _, binding := range structDescriptor.Fields { + if binding.Field.Name() == fieldName { + return binding + } + } + return nil +} + +// Binding describe how should we encode/decode the struct field +type Binding struct { + levels []int + Field reflect2.StructField + FromNames []string + ToNames []string + Encoder ValEncoder + Decoder ValDecoder +} + +// Extension the one for all SPI. Customize encoding/decoding by specifying alternate encoder/decoder. +// Can also rename fields by UpdateStructDescriptor. +type Extension interface { + UpdateStructDescriptor(structDescriptor *StructDescriptor) + CreateMapKeyDecoder(typ reflect2.Type) ValDecoder + CreateMapKeyEncoder(typ reflect2.Type) ValEncoder + CreateDecoder(typ reflect2.Type) ValDecoder + CreateEncoder(typ reflect2.Type) ValEncoder + DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder + DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder +} + +// DummyExtension embed this type get dummy implementation for all methods of Extension +type DummyExtension struct { +} + +// UpdateStructDescriptor No-op +func (extension *DummyExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) { +} + +// CreateMapKeyDecoder No-op +func (extension *DummyExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder { + return nil +} + +// CreateMapKeyEncoder No-op +func (extension *DummyExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder { + return nil +} + +// CreateDecoder No-op +func (extension *DummyExtension) CreateDecoder(typ reflect2.Type) ValDecoder { + return nil +} + +// CreateEncoder No-op +func (extension *DummyExtension) CreateEncoder(typ reflect2.Type) ValEncoder { + return nil +} + +// DecorateDecoder No-op +func (extension *DummyExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder { + return decoder +} + +// DecorateEncoder No-op +func (extension *DummyExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder { + return encoder +} + +type EncoderExtension map[reflect2.Type]ValEncoder + +// UpdateStructDescriptor No-op +func (extension EncoderExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) { +} + +// CreateDecoder No-op +func (extension EncoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder { + return nil +} + +// CreateEncoder get encoder from map +func (extension EncoderExtension) CreateEncoder(typ reflect2.Type) ValEncoder { + return extension[typ] +} + +// CreateMapKeyDecoder No-op +func (extension EncoderExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder { + return nil +} + +// CreateMapKeyEncoder No-op +func (extension EncoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder { + return nil +} + +// DecorateDecoder No-op +func (extension EncoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder { + return decoder +} + +// DecorateEncoder No-op +func (extension EncoderExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder { + return encoder +} + +type DecoderExtension map[reflect2.Type]ValDecoder + +// UpdateStructDescriptor No-op +func (extension DecoderExtension) UpdateStructDescriptor(structDescriptor *StructDescriptor) { +} + +// CreateMapKeyDecoder No-op +func (extension DecoderExtension) CreateMapKeyDecoder(typ reflect2.Type) ValDecoder { + return nil +} + +// CreateMapKeyEncoder No-op +func (extension DecoderExtension) CreateMapKeyEncoder(typ reflect2.Type) ValEncoder { + return nil +} + +// CreateDecoder get decoder from map +func (extension DecoderExtension) CreateDecoder(typ reflect2.Type) ValDecoder { + return extension[typ] +} + +// CreateEncoder No-op +func (extension DecoderExtension) CreateEncoder(typ reflect2.Type) ValEncoder { + return nil +} + +// DecorateDecoder No-op +func (extension DecoderExtension) DecorateDecoder(typ reflect2.Type, decoder ValDecoder) ValDecoder { + return decoder +} + +// DecorateEncoder No-op +func (extension DecoderExtension) DecorateEncoder(typ reflect2.Type, encoder ValEncoder) ValEncoder { + return encoder +} + +type funcDecoder struct { + fun DecoderFunc +} + +func (decoder *funcDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + decoder.fun(ptr, iter) +} + +type funcEncoder struct { + fun EncoderFunc + isEmptyFunc func(ptr unsafe.Pointer) bool +} + +func (encoder *funcEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + encoder.fun(ptr, stream) +} + +func (encoder *funcEncoder) IsEmpty(ptr unsafe.Pointer) bool { + if encoder.isEmptyFunc == nil { + return false + } + return encoder.isEmptyFunc(ptr) +} + +// DecoderFunc the function form of TypeDecoder +type DecoderFunc func(ptr unsafe.Pointer, iter *Iterator) + +// EncoderFunc the function form of TypeEncoder +type EncoderFunc func(ptr unsafe.Pointer, stream *Stream) + +// RegisterTypeDecoderFunc register TypeDecoder for a type with function +func RegisterTypeDecoderFunc(typ string, fun DecoderFunc) { + typeDecoders[typ] = &funcDecoder{fun} +} + +// RegisterTypeDecoder register TypeDecoder for a typ +func RegisterTypeDecoder(typ string, decoder ValDecoder) { + typeDecoders[typ] = decoder +} + +// RegisterFieldDecoderFunc register TypeDecoder for a struct field with function +func RegisterFieldDecoderFunc(typ string, field string, fun DecoderFunc) { + RegisterFieldDecoder(typ, field, &funcDecoder{fun}) +} + +// RegisterFieldDecoder register TypeDecoder for a struct field +func RegisterFieldDecoder(typ string, field string, decoder ValDecoder) { + fieldDecoders[fmt.Sprintf("%s/%s", typ, field)] = decoder +} + +// RegisterTypeEncoderFunc register TypeEncoder for a type with encode/isEmpty function +func RegisterTypeEncoderFunc(typ string, fun EncoderFunc, isEmptyFunc func(unsafe.Pointer) bool) { + typeEncoders[typ] = &funcEncoder{fun, isEmptyFunc} +} + +// RegisterTypeEncoder register TypeEncoder for a type +func RegisterTypeEncoder(typ string, encoder ValEncoder) { + typeEncoders[typ] = encoder +} + +// RegisterFieldEncoderFunc register TypeEncoder for a struct field with encode/isEmpty function +func RegisterFieldEncoderFunc(typ string, field string, fun EncoderFunc, isEmptyFunc func(unsafe.Pointer) bool) { + RegisterFieldEncoder(typ, field, &funcEncoder{fun, isEmptyFunc}) +} + +// RegisterFieldEncoder register TypeEncoder for a struct field +func RegisterFieldEncoder(typ string, field string, encoder ValEncoder) { + fieldEncoders[fmt.Sprintf("%s/%s", typ, field)] = encoder +} + +// RegisterExtension register extension +func RegisterExtension(extension Extension) { + extensions = append(extensions, extension) +} + +func getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder { + decoder := _getTypeDecoderFromExtension(ctx, typ) + if decoder != nil { + for _, extension := range extensions { + decoder = extension.DecorateDecoder(typ, decoder) + } + decoder = ctx.decoderExtension.DecorateDecoder(typ, decoder) + for _, extension := range ctx.extraExtensions { + decoder = extension.DecorateDecoder(typ, decoder) + } + } + return decoder +} +func _getTypeDecoderFromExtension(ctx *ctx, typ reflect2.Type) ValDecoder { + for _, extension := range extensions { + decoder := extension.CreateDecoder(typ) + if decoder != nil { + return decoder + } + } + decoder := ctx.decoderExtension.CreateDecoder(typ) + if decoder != nil { + return decoder + } + for _, extension := range ctx.extraExtensions { + decoder := extension.CreateDecoder(typ) + if decoder != nil { + return decoder + } + } + typeName := typ.String() + decoder = typeDecoders[typeName] + if decoder != nil { + return decoder + } + if typ.Kind() == reflect.Ptr { + ptrType := typ.(*reflect2.UnsafePtrType) + decoder := typeDecoders[ptrType.Elem().String()] + if decoder != nil { + return &OptionalDecoder{ptrType.Elem(), decoder} + } + } + return nil +} + +func getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder { + encoder := _getTypeEncoderFromExtension(ctx, typ) + if encoder != nil { + for _, extension := range extensions { + encoder = extension.DecorateEncoder(typ, encoder) + } + encoder = ctx.encoderExtension.DecorateEncoder(typ, encoder) + for _, extension := range ctx.extraExtensions { + encoder = extension.DecorateEncoder(typ, encoder) + } + } + return encoder +} + +func _getTypeEncoderFromExtension(ctx *ctx, typ reflect2.Type) ValEncoder { + for _, extension := range extensions { + encoder := extension.CreateEncoder(typ) + if encoder != nil { + return encoder + } + } + encoder := ctx.encoderExtension.CreateEncoder(typ) + if encoder != nil { + return encoder + } + for _, extension := range ctx.extraExtensions { + encoder := extension.CreateEncoder(typ) + if encoder != nil { + return encoder + } + } + typeName := typ.String() + encoder = typeEncoders[typeName] + if encoder != nil { + return encoder + } + if typ.Kind() == reflect.Ptr { + typePtr := typ.(*reflect2.UnsafePtrType) + encoder := typeEncoders[typePtr.Elem().String()] + if encoder != nil { + return &OptionalEncoder{encoder} + } + } + return nil +} + +func describeStruct(ctx *ctx, typ reflect2.Type) *StructDescriptor { + structType := typ.(*reflect2.UnsafeStructType) + embeddedBindings := []*Binding{} + bindings := []*Binding{} + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + tag, hasTag := field.Tag().Lookup(ctx.getTagKey()) + if ctx.onlyTaggedField && !hasTag && !field.Anonymous() { + continue + } + if tag == "-" || field.Name() == "_" { + continue + } + tagParts := strings.Split(tag, ",") + if field.Anonymous() && (tag == "" || tagParts[0] == "") { + if field.Type().Kind() == reflect.Struct { + structDescriptor := describeStruct(ctx, field.Type()) + for _, binding := range structDescriptor.Fields { + binding.levels = append([]int{i}, binding.levels...) + omitempty := binding.Encoder.(*structFieldEncoder).omitempty + binding.Encoder = &structFieldEncoder{field, binding.Encoder, omitempty} + binding.Decoder = &structFieldDecoder{field, binding.Decoder} + embeddedBindings = append(embeddedBindings, binding) + } + continue + } else if field.Type().Kind() == reflect.Ptr { + ptrType := field.Type().(*reflect2.UnsafePtrType) + if ptrType.Elem().Kind() == reflect.Struct { + structDescriptor := describeStruct(ctx, ptrType.Elem()) + for _, binding := range structDescriptor.Fields { + binding.levels = append([]int{i}, binding.levels...) + omitempty := binding.Encoder.(*structFieldEncoder).omitempty + binding.Encoder = &dereferenceEncoder{binding.Encoder} + binding.Encoder = &structFieldEncoder{field, binding.Encoder, omitempty} + binding.Decoder = &dereferenceDecoder{ptrType.Elem(), binding.Decoder} + binding.Decoder = &structFieldDecoder{field, binding.Decoder} + embeddedBindings = append(embeddedBindings, binding) + } + continue + } + } + } + fieldNames := calcFieldNames(field.Name(), tagParts[0], tag) + fieldCacheKey := fmt.Sprintf("%s/%s", typ.String(), field.Name()) + decoder := fieldDecoders[fieldCacheKey] + if decoder == nil { + decoder = decoderOfType(ctx.append(field.Name()), field.Type()) + } + encoder := fieldEncoders[fieldCacheKey] + if encoder == nil { + encoder = encoderOfType(ctx.append(field.Name()), field.Type()) + } + binding := &Binding{ + Field: field, + FromNames: fieldNames, + ToNames: fieldNames, + Decoder: decoder, + Encoder: encoder, + } + binding.levels = []int{i} + bindings = append(bindings, binding) + } + return createStructDescriptor(ctx, typ, bindings, embeddedBindings) +} + +func createStructDescriptor(ctx *ctx, typ reflect2.Type, bindings []*Binding, embeddedBindings []*Binding) *StructDescriptor { + structDescriptor := &StructDescriptor{ + Type: typ, + Fields: bindings, + } + for _, extension := range extensions { + extension.UpdateStructDescriptor(structDescriptor) + } + ctx.encoderExtension.UpdateStructDescriptor(structDescriptor) + ctx.decoderExtension.UpdateStructDescriptor(structDescriptor) + for _, extension := range ctx.extraExtensions { + extension.UpdateStructDescriptor(structDescriptor) + } + processTags(structDescriptor, ctx.frozenConfig) + // merge normal & embedded bindings & sort with original order + allBindings := sortableBindings(append(embeddedBindings, structDescriptor.Fields...)) + sort.Sort(allBindings) + structDescriptor.Fields = allBindings + return structDescriptor +} + +type sortableBindings []*Binding + +func (bindings sortableBindings) Len() int { + return len(bindings) +} + +func (bindings sortableBindings) Less(i, j int) bool { + left := bindings[i].levels + right := bindings[j].levels + k := 0 + for { + if left[k] < right[k] { + return true + } else if left[k] > right[k] { + return false + } + k++ + } +} + +func (bindings sortableBindings) Swap(i, j int) { + bindings[i], bindings[j] = bindings[j], bindings[i] +} + +func processTags(structDescriptor *StructDescriptor, cfg *frozenConfig) { + for _, binding := range structDescriptor.Fields { + shouldOmitEmpty := false + tagParts := strings.Split(binding.Field.Tag().Get(cfg.getTagKey()), ",") + for _, tagPart := range tagParts[1:] { + if tagPart == "omitempty" { + shouldOmitEmpty = true + } + } + binding.Decoder = &structFieldDecoder{binding.Field, binding.Decoder} + binding.Encoder = &structFieldEncoder{binding.Field, binding.Encoder, shouldOmitEmpty} + } +} + +func calcFieldNames(originalFieldName string, tagProvidedFieldName string, wholeTag string) []string { + // ignore? + if wholeTag == "-" { + return []string{} + } + // rename? + var fieldNames []string + if tagProvidedFieldName == "" { + fieldNames = []string{originalFieldName} + } else { + fieldNames = []string{tagProvidedFieldName} + } + // private? + isNotExported := unicode.IsLower(rune(originalFieldName[0])) || originalFieldName[0] == '_' + if isNotExported { + fieldNames = []string{} + } + return fieldNames +} diff --git a/douyu/stt/reflect_map.go b/douyu/stt/reflect_map.go new file mode 100644 index 0000000..42031b4 --- /dev/null +++ b/douyu/stt/reflect_map.go @@ -0,0 +1,306 @@ +package stt + +import ( + "fmt" + "github.com/modern-go/reflect2" + "reflect" + "unsafe" +) + +func decoderOfMap(ctx *ctx, typ reflect2.Type) ValDecoder { + mapType := typ.(*reflect2.UnsafeMapType) + keyDecoder := decoderOfMapKey(ctx.append("[mapKey]"), mapType.Key()) + elemDecoder := decoderOfType(ctx.append("[mapElem]"), mapType.Elem()) + return &mapDecoder{ + mapType: mapType, + keyType: mapType.Key(), + elemType: mapType.Elem(), + keyDecoder: keyDecoder, + elemDecoder: elemDecoder, + } +} + +func encoderOfMap(ctx *ctx, typ reflect2.Type) ValEncoder { + mapType := typ.(*reflect2.UnsafeMapType) + //if ctx.sortMapKeys { + // return &sortKeysMapEncoder{ + // mapType: mapType, + // keyEncoder: encoderOfMapKey(ctx.append("[mapKey]"), mapType.Key()), + // elemEncoder: encoderOfType(ctx.append("[mapElem]"), mapType.Elem()), + // } + //} + return &mapEncoder{ + mapType: mapType, + keyEncoder: encoderOfMapKey(ctx.append("[mapKey]"), mapType.Key()), + elemEncoder: encoderOfType(ctx.append("[mapElem]"), mapType.Elem()), + } +} + +func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder { + decoder := ctx.decoderExtension.CreateMapKeyDecoder(typ) + if decoder != nil { + return decoder + } + for _, extension := range ctx.extraExtensions { + decoder := extension.CreateMapKeyDecoder(typ) + if decoder != nil { + return decoder + } + } + + switch typ.Kind() { + case reflect.String: + return decoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String)) + case reflect.Bool, + reflect.Uint8, reflect.Int8, + reflect.Uint16, reflect.Int16, + reflect.Uint32, reflect.Int32, + reflect.Uint64, reflect.Int64, + reflect.Uint, reflect.Int, + reflect.Float32, reflect.Float64, + reflect.Uintptr: + typ = reflect2.DefaultTypeOfKind(typ.Kind()) + return &numericMapKeyDecoder{decoderOfType(ctx, typ)} + default: + return &lazyErrorDecoder{err: fmt.Errorf("unsupported map key type: %v", typ)} + } +} + +func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder { + encoder := ctx.encoderExtension.CreateMapKeyEncoder(typ) + if encoder != nil { + return encoder + } + for _, extension := range ctx.extraExtensions { + encoder := extension.CreateMapKeyEncoder(typ) + if encoder != nil { + return encoder + } + } + + switch typ.Kind() { + case reflect.String: + return encoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String)) + case reflect.Bool, + reflect.Uint8, reflect.Int8, + reflect.Uint16, reflect.Int16, + reflect.Uint32, reflect.Int32, + reflect.Uint64, reflect.Int64, + reflect.Uint, reflect.Int, + reflect.Float32, reflect.Float64, + reflect.Uintptr: + typ = reflect2.DefaultTypeOfKind(typ.Kind()) + return &numericMapKeyEncoder{encoderOfType(ctx, typ)} + default: + if typ.Kind() == reflect.Interface { + return &dynamicMapKeyEncoder{ctx, typ} + } + return &lazyErrorEncoder{err: fmt.Errorf("unsupported map key type: %v", typ)} + } +} + +type mapDecoder struct { + mapType *reflect2.UnsafeMapType + keyType reflect2.Type + elemType reflect2.Type + keyDecoder ValDecoder + elemDecoder ValDecoder +} + +func (decoder *mapDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + mapType := decoder.mapType + c := iter.nextToken() + if c == 'n' { + iter.skipBytes('u', 'l', 'l') + *(*unsafe.Pointer)(ptr) = nil + mapType.UnsafeSet(ptr, mapType.UnsafeNew()) + return + } + if mapType.UnsafeIsNil(ptr) { + mapType.UnsafeSet(ptr, mapType.UnsafeMakeMap(0)) + } + if c != '{' { + iter.ReportError("ReadMapCB", `expect { or n, but found `+string([]byte{c})) + return + } + c = iter.nextToken() + if c == '}' { + return + } + iter.unreadByte() + key := decoder.keyType.UnsafeNew() + decoder.keyDecoder.Decode(key, iter) + c = iter.nextToken() + if c != ':' { + iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c})) + return + } + elem := decoder.elemType.UnsafeNew() + decoder.elemDecoder.Decode(elem, iter) + decoder.mapType.UnsafeSetIndex(ptr, key, elem) + for c = iter.nextToken(); c == ','; c = iter.nextToken() { + key := decoder.keyType.UnsafeNew() + decoder.keyDecoder.Decode(key, iter) + c = iter.nextToken() + if c != ':' { + iter.ReportError("ReadMapCB", "expect : after object field, but found "+string([]byte{c})) + return + } + elem := decoder.elemType.UnsafeNew() + decoder.elemDecoder.Decode(elem, iter) + decoder.mapType.UnsafeSetIndex(ptr, key, elem) + } + if c != '}' { + iter.ReportError("ReadMapCB", `expect }, but found `+string([]byte{c})) + } +} + +type numericMapKeyDecoder struct { + decoder ValDecoder +} + +func (decoder *numericMapKeyDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + c := iter.nextToken() + if c != '"' { + iter.ReportError("ReadMapCB", `expect ", but found `+string([]byte{c})) + return + } + decoder.decoder.Decode(ptr, iter) + c = iter.nextToken() + if c != '"' { + iter.ReportError("ReadMapCB", `expect ", but found `+string([]byte{c})) + return + } +} + +type numericMapKeyEncoder struct { + encoder ValEncoder +} + +func (encoder *numericMapKeyEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.writeBytes('"') + encoder.encoder.Encode(ptr, stream) + stream.writeBytes('"') +} + +func (encoder *numericMapKeyEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + +type dynamicMapKeyEncoder struct { + ctx *ctx + valType reflect2.Type +} + +func (encoder *dynamicMapKeyEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + obj := encoder.valType.UnsafeIndirect(ptr) + encoderOfMapKey(encoder.ctx, reflect2.TypeOf(obj)).Encode(reflect2.PtrOf(obj), stream) +} + +func (encoder *dynamicMapKeyEncoder) IsEmpty(ptr unsafe.Pointer) bool { + obj := encoder.valType.UnsafeIndirect(ptr) + return encoderOfMapKey(encoder.ctx, reflect2.TypeOf(obj)).IsEmpty(reflect2.PtrOf(obj)) +} + +type mapEncoder struct { + mapType *reflect2.UnsafeMapType + keyEncoder ValEncoder + elemEncoder ValEncoder +} + +func (encoder *mapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + if *(*unsafe.Pointer)(ptr) == nil { + stream.WriteNil() + return + } + //stream.WriteObjectStart() + //iter := encoder.mapType.UnsafeIterate(ptr) + //for i := 0; iter.HasNext(); i++ { + // if i != 0 { + // stream.WriteMore() + // } + // key, elem := iter.UnsafeNext() + // encoder.keyEncoder.Encode(key, stream) + // if stream.indention > 0 { + // stream.writeTwoBytes(byte(':'), byte(' ')) + // } else { + // stream.writeByte(':') + // } + // encoder.elemEncoder.Encode(elem, stream) + //} + //stream.WriteObjectEnd() +} + +func (encoder *mapEncoder) IsEmpty(ptr unsafe.Pointer) bool { + iter := encoder.mapType.UnsafeIterate(ptr) + return !iter.HasNext() +} + +//type sortKeysMapEncoder struct { +// mapType *reflect2.UnsafeMapType +// keyEncoder ValEncoder +// elemEncoder ValEncoder +//} +// +//func (encoder *sortKeysMapEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { +// if *(*unsafe.Pointer)(ptr) == nil { +// stream.WriteNil() +// return +// } +// stream.WriteObjectStart() +// mapIter := encoder.mapType.UnsafeIterate(ptr) +// subStream := stream.cfg.BorrowStream(nil) +// subStream.Attachment = stream.Attachment +// subIter := stream.cfg.BorrowIterator(nil) +// keyValues := encodedKeyValues{} +// for mapIter.HasNext() { +// key, elem := mapIter.UnsafeNext() +// subStreamIndex := subStream.Buffered() +// encoder.keyEncoder.Encode(key, subStream) +// if subStream.Error != nil && subStream.Error != io.EOF && stream.Error == nil { +// stream.Error = subStream.Error +// } +// encodedKey := subStream.Buffer()[subStreamIndex:] +// subIter.ResetBytes(encodedKey) +// decodedKey := subIter.ReadString() +// if stream.indention > 0 { +// subStream.writeTwoBytes(byte(':'), byte(' ')) +// } else { +// subStream.writeByte(':') +// } +// encoder.elemEncoder.Encode(elem, subStream) +// keyValues = append(keyValues, encodedKV{ +// key: decodedKey, +// keyValue: subStream.Buffer()[subStreamIndex:], +// }) +// } +// sort.Sort(keyValues) +// for i, keyValue := range keyValues { +// if i != 0 { +// stream.WriteMore() +// } +// stream.Write(keyValue.keyValue) +// } +// if subStream.Error != nil && stream.Error == nil { +// stream.Error = subStream.Error +// } +// stream.WriteObjectEnd() +// stream.cfg.ReturnStream(subStream) +// stream.cfg.ReturnIterator(subIter) +//} +// +//func (encoder *sortKeysMapEncoder) IsEmpty(ptr unsafe.Pointer) bool { +// iter := encoder.mapType.UnsafeIterate(ptr) +// return !iter.HasNext() +//} + +type encodedKeyValues []encodedKV + +type encodedKV struct { + key string + keyValue []byte +} + +func (sv encodedKeyValues) Len() int { return len(sv) } +func (sv encodedKeyValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } +func (sv encodedKeyValues) Less(i, j int) bool { return sv[i].key < sv[j].key } diff --git a/douyu/stt/reflect_native.go b/douyu/stt/reflect_native.go new file mode 100644 index 0000000..a4ecc2d --- /dev/null +++ b/douyu/stt/reflect_native.go @@ -0,0 +1,396 @@ +package stt + +import ( + "github.com/modern-go/reflect2" + "reflect" + "strconv" + "unsafe" +) + +const ptrSize = 32 << uintptr(^uintptr(0)>>63) + +func createEncoderOfNative(ctx *ctx, typ reflect2.Type) ValEncoder { + typeName := typ.String() + kind := typ.Kind() + switch kind { + case reflect.String: + if typeName != "string" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*string)(nil)).Elem()) + } + return &stringCodec{} + case reflect.Int: + if typeName != "int" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*int)(nil)).Elem()) + } + if strconv.IntSize == 32 { + return &int32Codec{} + } + return &int64Codec{} + case reflect.Int8: + if typeName != "int8" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*int8)(nil)).Elem()) + } + return &int8Codec{} + case reflect.Int16: + if typeName != "int16" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*int16)(nil)).Elem()) + } + return &int16Codec{} + case reflect.Int32: + if typeName != "int32" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*int32)(nil)).Elem()) + } + return &int32Codec{} + case reflect.Int64: + if typeName != "int64" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*int64)(nil)).Elem()) + } + return &int64Codec{} + case reflect.Uint: + if typeName != "uint" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*uint)(nil)).Elem()) + } + if strconv.IntSize == 32 { + return &uint32Codec{} + } + return &uint64Codec{} + case reflect.Uint8: + if typeName != "uint8" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*uint8)(nil)).Elem()) + } + return &uint8Codec{} + case reflect.Uint16: + if typeName != "uint16" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*uint16)(nil)).Elem()) + } + return &uint16Codec{} + case reflect.Uint32: + if typeName != "uint32" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*uint32)(nil)).Elem()) + } + return &uint32Codec{} + case reflect.Uintptr: + if typeName != "uintptr" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*uintptr)(nil)).Elem()) + } + if ptrSize == 32 { + return &uint32Codec{} + } + return &uint64Codec{} + case reflect.Uint64: + if typeName != "uint64" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*uint64)(nil)).Elem()) + } + return &uint64Codec{} + case reflect.Float32: + if typeName != "float32" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*float32)(nil)).Elem()) + } + return &float32Codec{} + case reflect.Float64: + if typeName != "float64" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*float64)(nil)).Elem()) + } + return &float64Codec{} + case reflect.Bool: + if typeName != "bool" { + return encoderOfType(ctx, reflect2.TypeOfPtr((*bool)(nil)).Elem()) + } + return &boolCodec{} + } + return nil +} + +func createDecoderOfNative(ctx *ctx, typ reflect2.Type) ValDecoder { + typeName := typ.String() + switch typ.Kind() { + case reflect.String: + if typeName != "string" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*string)(nil)).Elem()) + } + return &stringCodec{} + case reflect.Int: + if typeName != "int" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*int)(nil)).Elem()) + } + if strconv.IntSize == 32 { + return &int32Codec{} + } + return &int64Codec{} + case reflect.Int8: + if typeName != "int8" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*int8)(nil)).Elem()) + } + return &int8Codec{} + case reflect.Int16: + if typeName != "int16" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*int16)(nil)).Elem()) + } + return &int16Codec{} + case reflect.Int32: + if typeName != "int32" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*int32)(nil)).Elem()) + } + return &int32Codec{} + case reflect.Int64: + if typeName != "int64" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*int64)(nil)).Elem()) + } + return &int64Codec{} + case reflect.Uint: + if typeName != "uint" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*uint)(nil)).Elem()) + } + if strconv.IntSize == 32 { + return &uint32Codec{} + } + return &uint64Codec{} + case reflect.Uint8: + if typeName != "uint8" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*uint8)(nil)).Elem()) + } + return &uint8Codec{} + case reflect.Uint16: + if typeName != "uint16" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*uint16)(nil)).Elem()) + } + return &uint16Codec{} + case reflect.Uint32: + if typeName != "uint32" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*uint32)(nil)).Elem()) + } + return &uint32Codec{} + case reflect.Uintptr: + if typeName != "uintptr" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*uintptr)(nil)).Elem()) + } + if ptrSize == 32 { + return &uint32Codec{} + } + return &uint64Codec{} + case reflect.Uint64: + if typeName != "uint64" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*uint64)(nil)).Elem()) + } + return &uint64Codec{} + case reflect.Float32: + if typeName != "float32" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*float32)(nil)).Elem()) + } + return &float32Codec{} + case reflect.Float64: + if typeName != "float64" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*float64)(nil)).Elem()) + } + return &float64Codec{} + case reflect.Bool: + if typeName != "bool" { + return decoderOfType(ctx, reflect2.TypeOfPtr((*bool)(nil)).Elem()) + } + return &boolCodec{} + } + return nil +} + +type stringCodec struct { +} + +func (codec *stringCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { + *((*string)(ptr)) = iter.ReadString() +} + +func (codec *stringCodec) Encode(ptr unsafe.Pointer, stream *Stream) { + str := *((*string)(ptr)) + stream.WriteString(str) +} + +func (codec *stringCodec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*string)(ptr)) == "" +} + +type int8Codec struct { +} + +func (codec *int8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*int8)(ptr)) = iter.ReadInt8() + } +} + +func (codec *int8Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteInt8(*((*int8)(ptr))) +} + +func (codec *int8Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*int8)(ptr)) == 0 +} + +type int16Codec struct { +} + +func (codec *int16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*int16)(ptr)) = iter.ReadInt16() + } +} + +func (codec *int16Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteInt16(*((*int16)(ptr))) +} + +func (codec *int16Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*int16)(ptr)) == 0 +} + +type int32Codec struct { +} + +func (codec *int32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*int32)(ptr)) = iter.ReadInt32() + } +} + +func (codec *int32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteInt32(*((*int32)(ptr))) +} + +func (codec *int32Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*int32)(ptr)) == 0 +} + +type int64Codec struct { +} + +func (codec *int64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*int64)(ptr)) = iter.ReadInt64() + } +} + +func (codec *int64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteInt64(*((*int64)(ptr))) +} + +func (codec *int64Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*int64)(ptr)) == 0 +} + +type uint8Codec struct { +} + +func (codec *uint8Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*uint8)(ptr)) = iter.ReadUint8() + } +} + +func (codec *uint8Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint8(*((*uint8)(ptr))) +} + +func (codec *uint8Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*uint8)(ptr)) == 0 +} + +type uint16Codec struct { +} + +func (codec *uint16Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*uint16)(ptr)) = iter.ReadUint16() + } +} + +func (codec *uint16Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint16(*((*uint16)(ptr))) +} + +func (codec *uint16Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*uint16)(ptr)) == 0 +} + +type uint32Codec struct { +} + +func (codec *uint32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*uint32)(ptr)) = iter.ReadUint32() + } +} + +func (codec *uint32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint32(*((*uint32)(ptr))) +} + +func (codec *uint32Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*uint32)(ptr)) == 0 +} + +type uint64Codec struct { +} + +func (codec *uint64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*uint64)(ptr)) = iter.ReadUint64() + } +} + +func (codec *uint64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteUint64(*((*uint64)(ptr))) +} + +func (codec *uint64Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*uint64)(ptr)) == 0 +} + +type float32Codec struct { +} + +func (codec *float32Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*float32)(ptr)) = iter.ReadFloat32() + } +} + +func (codec *float32Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteFloat32(*((*float32)(ptr))) +} + +func (codec *float32Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*float32)(ptr)) == 0 +} + +type float64Codec struct { +} + +func (codec *float64Codec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*float64)(ptr)) = iter.ReadFloat64() + } +} + +func (codec *float64Codec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteFloat64(*((*float64)(ptr))) +} + +func (codec *float64Codec) IsEmpty(ptr unsafe.Pointer) bool { + return *((*float64)(ptr)) == 0 +} + +type boolCodec struct { +} + +func (codec *boolCodec) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.ReadNil() { + *((*bool)(ptr)) = iter.ReadBool() + } +} + +func (codec *boolCodec) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteBool(*((*bool)(ptr))) +} + +func (codec *boolCodec) IsEmpty(ptr unsafe.Pointer) bool { + return !(*((*bool)(ptr))) +} diff --git a/douyu/stt/reflect_optional.go b/douyu/stt/reflect_optional.go new file mode 100644 index 0000000..794f3fe --- /dev/null +++ b/douyu/stt/reflect_optional.go @@ -0,0 +1,129 @@ +package stt + +import ( + "github.com/modern-go/reflect2" + "unsafe" +) + +func decoderOfOptional(ctx *ctx, typ reflect2.Type) ValDecoder { + ptrType := typ.(*reflect2.UnsafePtrType) + elemType := ptrType.Elem() + decoder := decoderOfType(ctx, elemType) + return &OptionalDecoder{elemType, decoder} +} + +func encoderOfOptional(ctx *ctx, typ reflect2.Type) ValEncoder { + ptrType := typ.(*reflect2.UnsafePtrType) + elemType := ptrType.Elem() + elemEncoder := encoderOfType(ctx, elemType) + encoder := &OptionalEncoder{elemEncoder} + return encoder +} + +type OptionalDecoder struct { + ValueType reflect2.Type + ValueDecoder ValDecoder +} + +func (decoder *OptionalDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + if iter.ReadNil() { + *((*unsafe.Pointer)(ptr)) = nil + } else { + if *((*unsafe.Pointer)(ptr)) == nil { + //pointer to null, we have to allocate memory to hold the value + newPtr := decoder.ValueType.UnsafeNew() + decoder.ValueDecoder.Decode(newPtr, iter) + *((*unsafe.Pointer)(ptr)) = newPtr + } else { + //reuse existing instance + decoder.ValueDecoder.Decode(*((*unsafe.Pointer)(ptr)), iter) + } + } +} + +type dereferenceDecoder struct { + // only to deference a pointer + valueType reflect2.Type + valueDecoder ValDecoder +} + +func (decoder *dereferenceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + if *((*unsafe.Pointer)(ptr)) == nil { + //pointer to null, we have to allocate memory to hold the value + newPtr := decoder.valueType.UnsafeNew() + decoder.valueDecoder.Decode(newPtr, iter) + *((*unsafe.Pointer)(ptr)) = newPtr + } else { + //reuse existing instance + decoder.valueDecoder.Decode(*((*unsafe.Pointer)(ptr)), iter) + } +} + +type OptionalEncoder struct { + ValueEncoder ValEncoder +} + +func (encoder *OptionalEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + if *((*unsafe.Pointer)(ptr)) == nil { + stream.WriteNil() + } else { + encoder.ValueEncoder.Encode(*((*unsafe.Pointer)(ptr)), stream) + } +} + +func (encoder *OptionalEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return *((*unsafe.Pointer)(ptr)) == nil +} + +type dereferenceEncoder struct { + ValueEncoder ValEncoder +} + +func (encoder *dereferenceEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + if *((*unsafe.Pointer)(ptr)) == nil { + stream.WriteNil() + } else { + encoder.ValueEncoder.Encode(*((*unsafe.Pointer)(ptr)), stream) + } +} + +func (encoder *dereferenceEncoder) IsEmpty(ptr unsafe.Pointer) bool { + dePtr := *((*unsafe.Pointer)(ptr)) + if dePtr == nil { + return true + } + return encoder.ValueEncoder.IsEmpty(dePtr) +} + +func (encoder *dereferenceEncoder) IsEmbeddedPtrNil(ptr unsafe.Pointer) bool { + deReferenced := *((*unsafe.Pointer)(ptr)) + if deReferenced == nil { + return true + } + isEmbeddedPtrNil, converted := encoder.ValueEncoder.(IsEmbeddedPtrNil) + if !converted { + return false + } + fieldPtr := unsafe.Pointer(deReferenced) + return isEmbeddedPtrNil.IsEmbeddedPtrNil(fieldPtr) +} + +type referenceEncoder struct { + encoder ValEncoder +} + +func (encoder *referenceEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + encoder.encoder.Encode(unsafe.Pointer(&ptr), stream) +} + +func (encoder *referenceEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return encoder.encoder.IsEmpty(unsafe.Pointer(&ptr)) +} + +type referenceDecoder struct { + decoder ValDecoder +} + +func (decoder *referenceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + decoder.decoder.Decode(unsafe.Pointer(&ptr), iter) +} diff --git a/douyu/stt/reflect_slice.go b/douyu/stt/reflect_slice.go new file mode 100644 index 0000000..19356ba --- /dev/null +++ b/douyu/stt/reflect_slice.go @@ -0,0 +1,99 @@ +package stt + +import ( + "fmt" + "github.com/modern-go/reflect2" + "io" + "unsafe" +) + +func decoderOfSlice(ctx *ctx, typ reflect2.Type) ValDecoder { + sliceType := typ.(*reflect2.UnsafeSliceType) + decoder := decoderOfType(ctx.append("[sliceElem]"), sliceType.Elem()) + return &sliceDecoder{sliceType, decoder} +} + +func encoderOfSlice(ctx *ctx, typ reflect2.Type) ValEncoder { + sliceType := typ.(*reflect2.UnsafeSliceType) + encoder := encoderOfType(ctx.append("[sliceElem]"), sliceType.Elem()) + return &sliceEncoder{sliceType, encoder} +} + +type sliceEncoder struct { + sliceType *reflect2.UnsafeSliceType + elemEncoder ValEncoder +} + +func (encoder *sliceEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + //if encoder.sliceType.UnsafeIsNil(ptr) { + // stream.WriteNil() + // return + //} + //length := encoder.sliceType.UnsafeLengthOf(ptr) + //if length == 0 { + // stream.WriteEmptyArray() + // return + //} + //stream.WriteArrayStart() + //encoder.elemEncoder.Encode(encoder.sliceType.UnsafeGetIndex(ptr, 0), stream) + //for i := 1; i < length; i++ { + // stream.WriteMore() + // elemPtr := encoder.sliceType.UnsafeGetIndex(ptr, i) + // encoder.elemEncoder.Encode(elemPtr, stream) + //} + //stream.WriteArrayEnd() + //if stream.Error != nil && stream.Error != io.EOF { + // stream.Error = fmt.Errorf("%v: %s", encoder.sliceType, stream.Error.Error()) + //} +} + +func (encoder *sliceEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return encoder.sliceType.UnsafeLengthOf(ptr) == 0 +} + +type sliceDecoder struct { + sliceType *reflect2.UnsafeSliceType + elemDecoder ValDecoder +} + +func (decoder *sliceDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + decoder.doDecode(ptr, iter) + if iter.Error != nil && iter.Error != io.EOF { + iter.Error = fmt.Errorf("%v: %s", decoder.sliceType, iter.Error.Error()) + } +} + +func (decoder *sliceDecoder) doDecode(ptr unsafe.Pointer, iter *Iterator) { + c := iter.nextToken() + sliceType := decoder.sliceType + if c == 'n' { + iter.skipBytes('u', 'l', 'l') + sliceType.UnsafeSetNil(ptr) + return + } + if c != '[' { + iter.ReportError("decode slice", "expect [ or n, but found "+string([]byte{c})) + return + } + c = iter.nextToken() + if c == ']' { + sliceType.UnsafeSet(ptr, sliceType.UnsafeMakeSlice(0, 0)) + return + } + iter.unreadByte() + sliceType.UnsafeGrow(ptr, 1) + elemPtr := sliceType.UnsafeGetIndex(ptr, 0) + decoder.elemDecoder.Decode(elemPtr, iter) + length := 1 + for c = iter.nextToken(); c == ','; c = iter.nextToken() { + idx := length + length += 1 + sliceType.UnsafeGrow(ptr, length) + elemPtr = sliceType.UnsafeGetIndex(ptr, idx) + decoder.elemDecoder.Decode(elemPtr, iter) + } + if c != ']' { + iter.ReportError("decode slice", "expect ], but found "+string([]byte{c})) + return + } +} diff --git a/douyu/stt/reflect_struct_decoder.go b/douyu/stt/reflect_struct_decoder.go new file mode 100644 index 0000000..e70764d --- /dev/null +++ b/douyu/stt/reflect_struct_decoder.go @@ -0,0 +1,202 @@ +package stt + +import ( + "bytes" + "fmt" + "github.com/modern-go/reflect2" + "io" + "strings" + "unsafe" +) + +func decoderOfStruct(ctx *ctx, typ reflect2.Type) ValDecoder { + bindings := map[string]*Binding{} + structDescriptor := describeStruct(ctx, typ) + for _, binding := range structDescriptor.Fields { + for _, fromName := range binding.FromNames { + old := bindings[fromName] + if old == nil { + bindings[fromName] = binding + continue + } + ignoreOld, ignoreNew := resolveConflictBinding(ctx.frozenConfig, old, binding) + if ignoreOld { + delete(bindings, fromName) + } + if !ignoreNew { + bindings[fromName] = binding + } + } + } + fields := map[string]*structFieldDecoder{} + for k, binding := range bindings { + fields[k] = binding.Decoder.(*structFieldDecoder) + } + return createStructDecoder(ctx, typ, fields) +} + +func createStructDecoder(ctx *ctx, typ reflect2.Type, fields map[string]*structFieldDecoder) ValDecoder { + switch len(fields) { + case 0: + return &skipObjectDecoder{typ} + } + return &generalStructDecoder{typ, fields} +} + +type generalStructDecoder struct { + typ reflect2.Type + fields map[string]*structFieldDecoder +} + +func (decoder *generalStructDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + if !iter.incrementDepth() { + return + } + var splitBytes = encryptBytes([]byte{'/'}, iter.depth) + cc := splitBytes + var c byte = '/' + for { + if iter.head >= iter.tail { + break + } + if iter.depth > 1 { + if bytes.Equal(cc, splitBytes) { + nc := iter.nextToken() + iter.unreadByte() + if nc != '/' && nc != '@' { + decoder.decodeOneField(ptr, iter) + cc = cc[:0] + for i, sbLen := 0, len(splitBytes); i < sbLen; i++ { + cc = append(cc, iter.nextToken()) + } + } else { + break + } + } else { + cc = cc[:1] + for i, sbLen := 0, len(splitBytes); i < sbLen; i++ { + cc = append(cc, iter.nextToken()) + } + continue + } + } else { + if c == '/' { + nc := iter.nextToken() + iter.unreadByte() + if nc != '/' { + decoder.decodeOneField(ptr, iter) + c = iter.nextToken() + } else { + break + } + } else { + c = iter.nextToken() + continue + } + } + } + if iter.Error != nil && iter.Error != io.EOF && len(decoder.typ.Type1().Name()) != 0 { + iter.Error = fmt.Errorf("%v.%s", decoder.typ, iter.Error.Error()) + } + if c != '/' { + iter.ReportError("struct Decode", `expect /, but found `+string([]byte{c})) + } + iter.decrementDepth() +} + +func (decoder *generalStructDecoder) decodeOneField(ptr unsafe.Pointer, iter *Iterator) { + var field string + var fieldDecoder *structFieldDecoder + + field = iter.ReadFieldName() + fieldDecoder = decoder.fields[field] + if fieldDecoder == nil && !iter.cfg.caseSensitive { + fieldDecoder = decoder.fields[strings.ToLower(field)] + } + + if fieldDecoder == nil { + c := iter.nextToken() + if c != '@' { + iter.ReportError("ReadObject", "expect @ after object field, but found "+string([]byte{c})) + } + iter.skipBytes('=') + iter.Skip() + return + } + c := iter.nextToken() + if c != '@' { + iter.ReportError("ReadObject", "expect @ after object field, but found "+string([]byte{c})) + } + iter.unreadByte() + if iter.depth > 1 { + iter.skipBytes(encryptBytes([]byte("@="), iter.depth)...) + } else { + iter.skipBytes('@', '=') + } + fieldDecoder.Decode(ptr, iter) +} + +type skipObjectDecoder struct { + typ reflect2.Type +} + +func (decoder *skipObjectDecoder) Decode(_ unsafe.Pointer, iter *Iterator) { + valueType := iter.WhatIsNext() + if valueType != NilValue { + iter.ReportError("skipObjectDecoder", "expect object or null") + return + } + iter.Skip() +} + +type structFieldDecoder struct { + field reflect2.StructField + fieldDecoder ValDecoder +} + +func (decoder *structFieldDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + fieldPtr := decoder.field.UnsafeGet(ptr) + decoder.fieldDecoder.Decode(fieldPtr, iter) + if iter.Error != nil && iter.Error != io.EOF { + iter.Error = fmt.Errorf("%s: %s", decoder.field.Name(), iter.Error.Error()) + } +} + +type stringModeStringDecoder struct { + elemDecoder ValDecoder + cfg *frozenConfig +} + +func (decoder *stringModeStringDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + decoder.elemDecoder.Decode(ptr, iter) + str := *((*string)(ptr)) + tempIter := decoder.cfg.BorrowIterator([]byte(str)) + defer decoder.cfg.ReturnIterator(tempIter) + *((*string)(ptr)) = tempIter.ReadString() +} + +type stringModeNumberDecoder struct { + elemDecoder ValDecoder +} + +func (decoder *stringModeNumberDecoder) Decode(ptr unsafe.Pointer, iter *Iterator) { + if iter.WhatIsNext() == NilValue { + decoder.elemDecoder.Decode(ptr, iter) + return + } + + c := iter.nextToken() + if c != '"' { + iter.ReportError("stringModeNumberDecoder", `expect ", but found `+string([]byte{c})) + return + } + decoder.elemDecoder.Decode(ptr, iter) + if iter.Error != nil { + return + } + c = iter.readByte() + if c != '"' { + iter.ReportError("stringModeNumberDecoder", `expect ", but found `+string([]byte{c})) + return + } +} diff --git a/douyu/stt/reflect_struct_encoder.go b/douyu/stt/reflect_struct_encoder.go new file mode 100644 index 0000000..f43a33b --- /dev/null +++ b/douyu/stt/reflect_struct_encoder.go @@ -0,0 +1,183 @@ +package stt + +import ( + "fmt" + "github.com/modern-go/reflect2" + "io" + "reflect" + "unsafe" +) + +func encoderOfStruct(ctx *ctx, typ reflect2.Type) ValEncoder { + type bindingTo struct { + binding *Binding + toName string + ignored bool + } + orderedBindings := []*bindingTo{} + structDescriptor := describeStruct(ctx, typ) + for _, binding := range structDescriptor.Fields { + for _, toName := range binding.ToNames { + new := &bindingTo{ + binding: binding, + toName: toName, + } + for _, old := range orderedBindings { + if old.toName != toName { + continue + } + old.ignored, new.ignored = resolveConflictBinding(ctx.frozenConfig, old.binding, new.binding) + } + orderedBindings = append(orderedBindings, new) + } + } + if len(orderedBindings) == 0 { + return &emptyStructEncoder{} + } + finalOrderedFields := []structFieldTo{} + for _, bindingTo := range orderedBindings { + if !bindingTo.ignored { + finalOrderedFields = append(finalOrderedFields, structFieldTo{ + encoder: bindingTo.binding.Encoder.(*structFieldEncoder), + toName: bindingTo.toName, + }) + } + } + return &structEncoder{typ, finalOrderedFields} +} + +func createCheckIsEmpty(ctx *ctx, typ reflect2.Type) checkIsEmpty { + encoder := createEncoderOfNative(ctx, typ) + if encoder != nil { + return encoder + } + kind := typ.Kind() + switch kind { + case reflect.Interface: + return &dynamicEncoder{typ} + case reflect.Struct: + return &structEncoder{typ: typ} + case reflect.Array: + return &arrayEncoder{} + case reflect.Slice: + return &sliceEncoder{} + case reflect.Map: + return encoderOfMap(ctx, typ) + case reflect.Ptr: + return &OptionalEncoder{} + default: + return &lazyErrorEncoder{err: fmt.Errorf("unsupported type: %v", typ)} + } +} + +func resolveConflictBinding(cfg *frozenConfig, old, new *Binding) (ignoreOld, ignoreNew bool) { + newTagged := new.Field.Tag().Get(cfg.getTagKey()) != "" + oldTagged := old.Field.Tag().Get(cfg.getTagKey()) != "" + if newTagged { + if oldTagged { + if len(old.levels) > len(new.levels) { + return true, false + } else if len(new.levels) > len(old.levels) { + return false, true + } else { + return true, true + } + } else { + return true, false + } + } else { + if oldTagged { + return true, false + } + if len(old.levels) > len(new.levels) { + return true, false + } else if len(new.levels) > len(old.levels) { + return false, true + } else { + return true, true + } + } +} + +type structFieldEncoder struct { + field reflect2.StructField + fieldEncoder ValEncoder + omitempty bool +} + +func (encoder *structFieldEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + fieldPtr := encoder.field.UnsafeGet(ptr) + encoder.fieldEncoder.Encode(fieldPtr, stream) + if stream.Error != nil && stream.Error != io.EOF { + stream.Error = fmt.Errorf("%s: %s", encoder.field.Name(), stream.Error.Error()) + } +} + +func (encoder *structFieldEncoder) IsEmpty(ptr unsafe.Pointer) bool { + fieldPtr := encoder.field.UnsafeGet(ptr) + return encoder.fieldEncoder.IsEmpty(fieldPtr) +} + +func (encoder *structFieldEncoder) IsEmbeddedPtrNil(ptr unsafe.Pointer) bool { + isEmbeddedPtrNil, converted := encoder.fieldEncoder.(IsEmbeddedPtrNil) + if !converted { + return false + } + fieldPtr := encoder.field.UnsafeGet(ptr) + return isEmbeddedPtrNil.IsEmbeddedPtrNil(fieldPtr) +} + +type IsEmbeddedPtrNil interface { + IsEmbeddedPtrNil(ptr unsafe.Pointer) bool +} + +type structEncoder struct { + typ reflect2.Type + fields []structFieldTo +} + +type structFieldTo struct { + encoder *structFieldEncoder + toName string +} + +func (encoder *structEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + if !stream.incrementDepth() { + return + } + isNotFirst := false + for _, field := range encoder.fields { + if field.encoder.omitempty && field.encoder.IsEmpty(ptr) { + continue + } + if field.encoder.IsEmbeddedPtrNil(ptr) { + continue + } + if isNotFirst { + stream.WriteMore() + } + stream.WriteObjectField(field.toName) + field.encoder.Encode(ptr, stream) + isNotFirst = true + } + stream.WriteObjectEnd() + if stream.Error != nil && stream.Error != io.EOF { + stream.Error = fmt.Errorf("%v.%s", encoder.typ, stream.Error.Error()) + } + stream.decrementDepth() +} + +func (encoder *structEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + +type emptyStructEncoder struct { +} + +func (encoder *emptyStructEncoder) Encode(ptr unsafe.Pointer, stream *Stream) { + stream.WriteEmptyObject() +} + +func (encoder *emptyStructEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} diff --git a/douyu/stt/stream.go b/douyu/stt/stream.go new file mode 100644 index 0000000..e8db889 --- /dev/null +++ b/douyu/stt/stream.go @@ -0,0 +1,179 @@ +package stt + +import ( + "io" +) + +// Stream is an io.Writer like object, with STT specific write functions. +// Error is not returned as return value, but stored as Error member on this stream instance. +type Stream struct { + cfg *frozenConfig + out io.Writer + depth int + buf []byte + Error error + Attachment interface{} // open for customized encoder +} + +// NewStream create new stream instance. +// cfg can be stt.ConfigDefault. +// out can be nil if write to internal buffer. +// bufSize is the initial size for the internal buffer in bytes. +func NewStream(cfg API, out io.Writer, bufSize int) *Stream { + return &Stream{ + cfg: cfg.(*frozenConfig), + out: out, + depth: 0, + buf: make([]byte, 0, bufSize), + Error: nil, + } +} + +// Pool returns a pool can provide more stream with same configuration +func (stream *Stream) Pool() StreamPool { + return stream.cfg +} + +// Reset reuse this stream instance by assign a new writer +func (stream *Stream) Reset(out io.Writer) { + stream.depth = 0 + stream.out = out + stream.buf = stream.buf[:0] +} + +// Available returns how many bytes are unused in the buffer. +func (stream *Stream) Available() int { + return cap(stream.buf) - len(stream.buf) +} + +// Buffered returns the number of bytes that have been written into the current buffer. +func (stream *Stream) Buffered() int { + return len(stream.buf) +} + +// Buffer if writer is nil, use this method to take the result +func (stream *Stream) Buffer() []byte { + return stream.buf +} + +// SetBuffer allows to append to the internal buffer directly +func (stream *Stream) SetBuffer(buf []byte) { + stream.buf = buf +} + +// Write writes the contents of p into the buffer. +// It returns the number of bytes written. +// If nn < len(p), it also returns an error explaining +// why write is short. +func (stream *Stream) Write(p []byte) (nn int, err error) { + stream.buf = append(stream.buf, p...) + if stream.out != nil { + nn, err = stream.out.Write(stream.buf) + stream.buf = stream.buf[nn:] + return + } + return len(p), nil +} + +// WriteByte writes a single byte. +func (stream *Stream) writeBytes(c ...byte) { + stream.buf = append(stream.buf, c...) +} + +// Flush writes any buffered data to the underlying io.Writer. +func (stream *Stream) Flush() error { + if stream.out == nil { + return nil + } + if stream.Error != nil { + return stream.Error + } + _, err := stream.out.Write(stream.buf) + if err != nil { + if stream.Error == nil { + stream.Error = err + } + return err + } + stream.buf = stream.buf[:0] + return nil +} + +// WriteRaw write string out without quotes, just like []byte +func (stream *Stream) WriteRaw(s string) { + stream.buf = append(stream.buf, s...) +} + +// WriteNil write null to stream +func (stream *Stream) WriteNil() { + stream.writeBytes('n', 'u', 'l', 'l') +} + +// WriteTrue write true to stream +func (stream *Stream) WriteTrue() { + stream.writeBytes('t', 'r', 'u', 'e') +} + +// WriteFalse write false to stream +func (stream *Stream) WriteFalse() { + stream.writeBytes('f', 'a', 'l', 's', 'e') +} + +// WriteBool write true or false into stream +func (stream *Stream) WriteBool(val bool) { + if val { + stream.WriteTrue() + } else { + stream.WriteFalse() + } +} + +// WriteObjectField write field@= with possible indention +func (stream *Stream) WriteObjectField(field string) { + stream.WriteString(field) + // TODO 尚不清楚斗鱼STT的转义规则,大体来看是2层深度转义2次,1层深度作为operator是不转义的? + bb := []byte("@=") + if stream.depth > 1 { + bb = encryptBytes(bb, stream.depth) + } + stream.writeBytes(bb...) +} + +// WriteObjectEnd write / with possible indention +func (stream *Stream) WriteObjectEnd() { + if stream.depth > 1 { + stream.writeBytes(encryptBytes([]byte{'/'}, 1)...) + } else { + stream.writeBytes('/') + } +} + +// WriteEmptyObject write / +func (stream *Stream) WriteEmptyObject() { + //stream.writeBytes('/') +} + +// WriteMore write / with possible indention +func (stream *Stream) WriteMore() { + if stream.depth > 1 { + stream.writeBytes(encryptBytes([]byte{'/'}, 1)...) + } else { + stream.writeBytes('/') + } +} + +func (stream *Stream) incrementDepth() (success bool) { + stream.depth++ + if stream.depth <= maxDepth { + return true + } + return false +} + +func (stream *Stream) decrementDepth() (success bool) { + stream.depth-- + if stream.depth >= 0 { + return true + } + return false +} diff --git a/douyu/stt/stream_float.go b/douyu/stt/stream_float.go new file mode 100644 index 0000000..4b12e32 --- /dev/null +++ b/douyu/stt/stream_float.go @@ -0,0 +1,47 @@ +package stt + +import ( + "fmt" + "math" + "strconv" +) + +var pow10 []uint64 + +func init() { + pow10 = []uint64{1, 10, 100, 1000, 10000, 100000, 1000000} +} + +// WriteFloat32 write float32 to stream +func (stream *Stream) WriteFloat32(val float32) { + if math.IsInf(float64(val), 0) || math.IsNaN(float64(val)) { + stream.Error = fmt.Errorf("unsupported value: %f", val) + return + } + abs := math.Abs(float64(val)) + ff := byte('f') + // Note: Must use float32 comparisons for underlying float32 value to get precise cutoffs right. + if abs != 0 { + if float32(abs) < 1e-6 || float32(abs) >= 1e21 { + ff = 'e' + } + } + stream.buf = strconv.AppendFloat(stream.buf, float64(val), ff, -1, 32) +} + +// WriteFloat64 write float64 to stream +func (stream *Stream) WriteFloat64(val float64) { + if math.IsInf(val, 0) || math.IsNaN(val) { + stream.Error = fmt.Errorf("unsupported value: %f", val) + return + } + abs := math.Abs(val) + ff := byte('f') + // Note: Must use float32 comparisons for underlying float32 value to get precise cutoffs right. + if abs != 0 { + if abs < 1e-6 || abs >= 1e21 { + ff = 'e' + } + } + stream.buf = strconv.AppendFloat(stream.buf, val, ff, -1, 64) +} diff --git a/douyu/stt/stream_int.go b/douyu/stt/stream_int.go new file mode 100644 index 0000000..5b62364 --- /dev/null +++ b/douyu/stt/stream_int.go @@ -0,0 +1,190 @@ +package stt + +var digits []uint32 + +func init() { + digits = make([]uint32, 1000) + for i := uint32(0); i < 1000; i++ { + digits[i] = (((i / 100) + '0') << 16) + ((((i / 10) % 10) + '0') << 8) + i%10 + '0' + if i < 10 { + digits[i] += 2 << 24 + } else if i < 100 { + digits[i] += 1 << 24 + } + } +} + +func writeFirstBuf(space []byte, v uint32) []byte { + start := v >> 24 + if start == 0 { + space = append(space, byte(v>>16), byte(v>>8)) + } else if start == 1 { + space = append(space, byte(v>>8)) + } + space = append(space, byte(v)) + return space +} + +func writeBuf(buf []byte, v uint32) []byte { + return append(buf, byte(v>>16), byte(v>>8), byte(v)) +} + +// WriteUint8 write uint8 to stream +func (stream *Stream) WriteUint8(val uint8) { + stream.buf = writeFirstBuf(stream.buf, digits[val]) +} + +// WriteInt8 write int8 to stream +func (stream *Stream) WriteInt8(nval int8) { + var val uint8 + if nval < 0 { + val = uint8(-nval) + stream.buf = append(stream.buf, '-') + } else { + val = uint8(nval) + } + stream.buf = writeFirstBuf(stream.buf, digits[val]) +} + +// WriteUint16 write uint16 to stream +func (stream *Stream) WriteUint16(val uint16) { + q1 := val / 1000 + if q1 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[val]) + return + } + r1 := val - q1*1000 + stream.buf = writeFirstBuf(stream.buf, digits[q1]) + stream.buf = writeBuf(stream.buf, digits[r1]) + return +} + +// WriteInt16 write int16 to stream +func (stream *Stream) WriteInt16(nval int16) { + var val uint16 + if nval < 0 { + val = uint16(-nval) + stream.buf = append(stream.buf, '-') + } else { + val = uint16(nval) + } + stream.WriteUint16(val) +} + +// WriteUint32 write uint32 to stream +func (stream *Stream) WriteUint32(val uint32) { + q1 := val / 1000 + if q1 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[val]) + return + } + r1 := val - q1*1000 + q2 := q1 / 1000 + if q2 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[q1]) + stream.buf = writeBuf(stream.buf, digits[r1]) + return + } + r2 := q1 - q2*1000 + q3 := q2 / 1000 + if q3 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[q2]) + } else { + r3 := q2 - q3*1000 + stream.buf = append(stream.buf, byte(q3+'0')) + stream.buf = writeBuf(stream.buf, digits[r3]) + } + stream.buf = writeBuf(stream.buf, digits[r2]) + stream.buf = writeBuf(stream.buf, digits[r1]) +} + +// WriteInt32 write int32 to stream +func (stream *Stream) WriteInt32(nval int32) { + var val uint32 + if nval < 0 { + val = uint32(-nval) + stream.buf = append(stream.buf, '-') + } else { + val = uint32(nval) + } + stream.WriteUint32(val) +} + +// WriteUint64 write uint64 to stream +func (stream *Stream) WriteUint64(val uint64) { + q1 := val / 1000 + if q1 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[val]) + return + } + r1 := val - q1*1000 + q2 := q1 / 1000 + if q2 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[q1]) + stream.buf = writeBuf(stream.buf, digits[r1]) + return + } + r2 := q1 - q2*1000 + q3 := q2 / 1000 + if q3 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[q2]) + stream.buf = writeBuf(stream.buf, digits[r2]) + stream.buf = writeBuf(stream.buf, digits[r1]) + return + } + r3 := q2 - q3*1000 + q4 := q3 / 1000 + if q4 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[q3]) + stream.buf = writeBuf(stream.buf, digits[r3]) + stream.buf = writeBuf(stream.buf, digits[r2]) + stream.buf = writeBuf(stream.buf, digits[r1]) + return + } + r4 := q3 - q4*1000 + q5 := q4 / 1000 + if q5 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[q4]) + stream.buf = writeBuf(stream.buf, digits[r4]) + stream.buf = writeBuf(stream.buf, digits[r3]) + stream.buf = writeBuf(stream.buf, digits[r2]) + stream.buf = writeBuf(stream.buf, digits[r1]) + return + } + r5 := q4 - q5*1000 + q6 := q5 / 1000 + if q6 == 0 { + stream.buf = writeFirstBuf(stream.buf, digits[q5]) + } else { + stream.buf = writeFirstBuf(stream.buf, digits[q6]) + r6 := q5 - q6*1000 + stream.buf = writeBuf(stream.buf, digits[r6]) + } + stream.buf = writeBuf(stream.buf, digits[r5]) + stream.buf = writeBuf(stream.buf, digits[r4]) + stream.buf = writeBuf(stream.buf, digits[r3]) + stream.buf = writeBuf(stream.buf, digits[r2]) + stream.buf = writeBuf(stream.buf, digits[r1]) +} + +// WriteInt64 write int64 to stream +func (stream *Stream) WriteInt64(nval int64) { + var val uint64 + if nval < 0 { + val = uint64(-nval) + stream.buf = append(stream.buf, '-') + } else { + val = uint64(nval) + } + stream.WriteUint64(val) +} + +// WriteInt write int to stream +func (stream *Stream) WriteInt(val int) { + stream.WriteInt64(int64(val)) +} + +// WriteUint write uint to stream +func (stream *Stream) WriteUint(val uint) { + stream.WriteUint64(uint64(val)) +} diff --git a/douyu/stt/stream_str.go b/douyu/stt/stream_str.go new file mode 100644 index 0000000..6d83765 --- /dev/null +++ b/douyu/stt/stream_str.go @@ -0,0 +1,178 @@ +package stt + +import "unicode/utf8" + +// safeSet holds the value true if the ASCII character with the given array +// position can be represented inside a JSON string without any further +// escaping. +// +// All values are true except for the ASCII control characters (0-31), the +// double quote ("), and the backslash character ("\"). +var safeSet = [utf8.RuneSelf]bool{ + ' ': true, + '!': true, + '"': false, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '(': true, + ')': true, + '*': true, + '+': true, + ',': true, + '-': true, + '.': true, + '/': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + ':': true, + ';': true, + '<': true, + '=': true, + '>': true, + '?': true, + '@': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'V': true, + 'W': true, + 'X': true, + 'Y': true, + 'Z': true, + '[': true, + '\\': true, + ']': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '{': true, + '|': true, + '}': true, + '~': true, + '\u007f': true, +} + +var hex = "0123456789abcdef" + +// WriteString write string to stream without html escape +func (stream *Stream) WriteString(s string) { + valLen := len(s) + // write string, the fast path, without utf8 and escape support + i := 0 + for ; i < valLen; i++ { + c := s[i] + if c > 31 { + switch c { + case '@': + stream.buf = append(stream.buf, '@', 'A') + case '/': + stream.buf = append(stream.buf, '@', 'S') + default: + stream.buf = append(stream.buf, c) + } + } else { + break + } + } + if i == valLen { + return + } + writeStringSlowPath(stream, i, s, valLen) +} + +func writeStringSlowPath(stream *Stream, i int, s string, valLen int) { + start := i + // for the remaining parts, we process them char by char + for i < valLen { + if b := s[i]; b < utf8.RuneSelf { + if safeSet[b] { + i++ + continue + } + if start < i { + stream.WriteRaw(s[start:i]) + } + switch b { + case '\\', '"': + stream.writeBytes('\\', b) + case '\n': + stream.writeBytes('\\', 'n') + case '\r': + stream.writeBytes('\\', 'r') + case '\t': + stream.writeBytes('\\', 't') + default: + // This encodes bytes < 0x20 except for \t, \n and \r. + // If escapeHTML is set, it also escapes <, >, and & + // because they can lead to security holes when + // user-controlled strings are rendered into JSON + // and served to some browsers. + stream.WriteRaw(`\u00`) + stream.writeBytes(hex[b>>4], hex[b&0xF]) + } + i++ + start = i + continue + } + i++ + continue + } + if start < len(s) { + stream.WriteRaw(s[start:]) + } +} diff --git a/douyu/stt/stt.go b/douyu/stt/stt.go new file mode 100644 index 0000000..cbab411 --- /dev/null +++ b/douyu/stt/stt.go @@ -0,0 +1,39 @@ +package stt + +import "bytes" + +func Marshal(v any) ([]byte, error) { + return ConfigDefault.Marshal(v) +} + +func Unmarshal(data []byte, v any) error { + return ConfigDefault.Unmarshal(data, v) +} + +func encryptBytes(c []byte, num int) []byte { + for i := 0; i < num; i++ { + c = bytes.ReplaceAll(c, []byte{'@'}, []byte("@A")) + c = bytes.ReplaceAll(c, []byte{'/'}, []byte("@S")) + } + + return c +} + +func decryptBytes(c []byte, num int) []byte { + if num == -1 { + for { + if bytes.Contains(c, []byte("@S")) || bytes.Contains(c, []byte("@A")) { + c = bytes.ReplaceAll(c, []byte("@S"), []byte{'/'}) + c = bytes.ReplaceAll(c, []byte("@A"), []byte{'@'}) + } else { + break + } + } + } else { + for i := 0; i < num; i++ { + c = bytes.ReplaceAll(c, []byte("@S"), []byte{'/'}) + c = bytes.ReplaceAll(c, []byte("@A"), []byte{'@'}) + } + } + return c +} diff --git a/go.mod b/go.mod index 24f0314..3e57511 100644 --- a/go.mod +++ b/go.mod @@ -31,10 +31,13 @@ require ( github.com/jcmturner/gokrb5/v8 v8.4.2 // indirect github.com/jcmturner/rpc/v2 v2.0.3 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.15.1 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/mapstructure v1.4.3 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect github.com/natefinch/lumberjack v2.0.0+incompatible // indirect github.com/pierrec/lz4 v2.6.1+incompatible // indirect github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect diff --git a/go.sum b/go.sum index d610f40..9f4c985 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,7 @@ github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.14.4/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A= @@ -106,7 +107,9 @@ github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7/go.mod h1:ZX github.com/mitchellh/mapstructure v1.4.3 h1:OVowDSCllw/YjdLkam3/sm7wEtOy59d8ndGgCcyj8cs= github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/natefinch/lumberjack v2.0.0+incompatible h1:4QJd3OLAMgj7ph+yZTuX13Ld4UpgHp07nNdFX7mqFfM= github.com/natefinch/lumberjack v2.0.0+incompatible/go.mod h1:Wi9p2TTF5DG5oU+6YfsmYQpsTIOm0B1VNzQg9Mw6nPk= diff --git a/main.go b/main.go index b4a79e8..7eab1a4 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "git.noahlan.cn/northlan/ntools-go/logger" "live-gateway/bilibili" "live-gateway/config" + "live-gateway/douyu" "sync" "time" ) @@ -45,5 +46,26 @@ func main() { } }() } + if config.Config.Douyu.Enabled { + bLive := douyu.NewLiveDouyu() + wg.Add(1) + go func() { + if err := bLive.Serve(); err != nil { + logger.SLog.Error("err: ", err) + wg.Done() + } + }() + + //go func() { + // // timer + // timer := time.NewTimer(time.Second * config.Config.Douyu.ResetInterval) + // for { + // select { + // case <-timer.C: + // bLive.ReConnect() + // } + // } + //}() + } wg.Wait() } diff --git a/ws/connection.go b/ws/connection.go index 1922f16..1e4a107 100644 --- a/ws/connection.go +++ b/ws/connection.go @@ -116,6 +116,12 @@ func WithPacker(packer Packer) ConnectionOption { } } +func WithCodec(codec Codec) ConnectionOption { + return func(options *ConnectionOptions) { + options.Codec = codec + } +} + func WithBackoff(b *BackoffOptions) ConnectionOption { return func(options *ConnectionOptions) { options.BackoffOptions = b @@ -323,8 +329,8 @@ func (c *NWebsocket) readLoop() { if c.onReceiveError != nil { c.onReceiveError(errors.Wrapf(err, "decode msg err: %+v", err)) } + break } - break } if c.onBinaryMessageReceived != nil { go c.onBinaryMessageReceived(msg)