package protocol

import (
	"bytes"
	"encoding/binary"
	"errors"

	"git.noahlan.cn/noahlan/nnet/packet"
)

// NNetPacker encodes frames with Pack and decodes them with Unpack.
//
// Unpack is stateful: received bytes accumulate in buf, and the fields of the
// most recently read frame header are kept between calls.
type NNetPacker struct {
	buf  *bytes.Buffer
	size int  // body length from the most recently read header; negative when no header is pending
	typ  byte // packet type from the most recently read header
	flag byte // flag byte from the most recently read header
}

// packer constants.
const (
	headLength    = 5         // frame header: 1 byte packet type + 3 bytes body length + 1 byte flag
	maxPacketSize = 64 * 1024 // largest body length readHeader accepts

	msgRouteCompressMask = 0x01 // 0000 0001 lowest flag bit: the route is sent as a dictionary code
	msgTypeMask          = 0x07 // 0000 0111 message type; applied after shifting the flag right by one bit
	msgRouteLengthMask   = 0xFF // 1111 1111 an uncompressed route length is a single byte
	msgHeadLength        = 0x02 // 0000 0010 2 bit
)

var (
	ErrPacketSizeExceed  = errors.New("packer: packet size exceed")
	ErrWrongMessageType  = errors.New("wrong message type")
	ErrRouteInfoNotFound = errors.New("route info not found in dictionary")
	ErrWrongMessage      = errors.New("wrong message")

	// ErrWrongPacketType represents a wrong packet type.
	ErrWrongPacketType = errors.New("wrong packet type")
)

// NewNNetPacker returns a packer with an empty receive buffer and no pending
// frame header.
func NewNNetPacker() *NNetPacker {
	p := &NNetPacker{
		buf: bytes.NewBuffer(nil),
	}
	p.resetFlags()
	return p
}

// resetFlags forgets the most recently read header so that the next Unpack
// call starts by reading a new one. The receive buffer is left untouched.
func (d *NNetPacker) resetFlags() {
	d.size = -1
	d.typ = byte(Unknown)
	d.flag = 0x00
}

// abort drops all decoder state (pending header and buffered bytes) and
// returns err unchanged. It is used on decode errors: once a frame is
// malformed the position of the next frame boundary is unknown, so keeping
// the remaining bytes would only make later calls decode garbage.
func (d *NNetPacker) abort(err error) error {
	d.resetFlags()
	d.buf.Reset()
	return err
}

// routable reports whether messages of type t carry a route.
func (d *NNetPacker) routable(t MsgType) bool {
	return t == Request || t == Notify || t == Push
}

// invalidType reports whether t lies outside the range Request..Push.
func (d *NNetPacker) invalidType(t MsgType) bool {
	return t < Request || t > Push
}

// Pack encodes one frame.
//
// header must be a Header value. The frame layout is:
//
//	packet type (1 byte) | body length (3 bytes, big endian) | flag (1 byte)
//	| message id (varint, Request/Response only)
//	| route (routable message types only) | body
//
// The length field counts only the bytes of data, not the id or the route.
//
// Errors:
//   - ErrWrongPacketType: header is not a Header or its PacketType is out of range.
//   - ErrWrongMessageType: the MsgType is out of range.
//   - ErrPacketSizeExceed: data does not fit in the 3-byte length field.
//   - ErrWrongMessage: an uncompressed route is longer than 255 bytes and so
//     cannot be described by the single length byte.
func (d *NNetPacker) Pack(header interface{}, data []byte) ([]byte, error) {
	h, ok := header.(Header)
	if !ok {
		return nil, ErrWrongPacketType
	}
	if h.PacketType < Handshake || h.PacketType > Kick {
		return nil, ErrWrongPacketType
	}
	if d.invalidType(h.MsgType) {
		return nil, ErrWrongMessageType
	}

	// The length field is 3 bytes wide; anything larger would be silently
	// truncated by intToBytes and produce a corrupt frame.
	const maxEncodableLength = 1<<24 - 1
	if len(data) > maxEncodableLength {
		return nil, ErrPacketSizeExceed
	}

	routed := d.routable(h.MsgType)
	code, compressed := routeMap.Routes[h.Route]
	if routed && !compressed && len(h.Route) > msgRouteLengthMask {
		return nil, ErrWrongMessage
	}

	// Capacity: header + worst-case 64-bit varint + route length byte + route + body.
	buf := make([]byte, 0, headLength+10+1+len(h.Route)+len(data))

	// packet type
	buf = append(buf, byte(h.PacketType))

	// length
	buf = append(buf, d.intToBytes(uint32(len(data)))...)

	// flag: message type in bits 1-3, route-compression marker in bit 0.
	flag := byte(h.MsgType << 1) // the byte conversion must stay (per the original author's compiler note)
	if compressed {
		flag |= msgRouteCompressMask
	}
	buf = append(buf, flag)

	// msg id
	if h.MsgType == Request || h.MsgType == Response {
		n := h.ID
		// variable-length encoding: 7 bits per byte, least significant group
		// first, high bit set on every byte except the last.
		for {
			b := byte(n % 128)
			n >>= 7
			if n != 0 {
				buf = append(buf, b+128)
			} else {
				buf = append(buf, b)
				break
			}
		}
	}

	// route
	if routed {
		if compressed {
			buf = append(buf, byte((code>>8)&0xFF))
			buf = append(buf, byte(code&0xFF))
		} else {
			buf = append(buf, byte(len(h.Route)))
			buf = append(buf, []byte(h.Route)...)
		}
	}

	// body
	buf = append(buf, data...)
	return buf, nil
}

// intToBytes encodes the low 24 bits of n as 3 big-endian bytes.
func (d *NNetPacker) intToBytes(n uint32) []byte {
	buf := make([]byte, 3)
	buf[0] = byte((n >> 16) & 0xFF)
	buf[1] = byte((n >> 8) & 0xFF)
	buf[2] = byte(n & 0xFF)
	return buf
}

// Unpack appends data to the internal buffer and decodes as many frames as
// the buffer holds.
//
// It returns (nil, nil) when fewer than headLength bytes are buffered. On a
// decode error all decoder state is discarded; packets decoded earlier in the
// same call are returned only when the error comes from reading the header of
// a following frame, matching the previous behaviour.
//
// Each returned packet owns a private copy of its body, so it stays valid
// after later Unpack calls reuse the internal buffer.
//
// NOTE(review): when a call yields no packet (for example a header whose body
// has not fully arrived), the buffered bytes are dropped rather than kept for
// the next call. That is the pre-existing behaviour and is preserved here; it
// presumably assumes the transport hands over whole frames — confirm against
// the callers before relying on this with a raw byte stream.
func (d *NNetPacker) Unpack(data []byte) ([]packet.IPacket, error) {
	d.buf.Write(data) // copy

	// Not even a full header yet: keep the bytes and wait for more.
	if d.buf.Len() < headLength {
		return nil, nil
	}

	// No header pending from a previous call: read one now.
	if d.size < 0 {
		if err := d.readHeader(); err != nil {
			return nil, d.abort(err)
		}
	}

	var packets []packet.IPacket
	for d.size <= d.buf.Len() {
		p := newPacket(Type(d.typ))
		p.MsgType = MsgType((d.flag >> 1) & msgTypeMask)
		if d.invalidType(p.MsgType) {
			return nil, d.abort(ErrWrongMessageType)
		}

		if p.MsgType == Request || p.MsgType == Response {
			id, err := d.decodeMsgID()
			if err != nil {
				return nil, d.abort(err)
			}
			p.ID = id
		}

		if d.routable(p.MsgType) {
			route, compressed, err := d.decodeRoute()
			if err != nil {
				return nil, d.abort(err)
			}
			p.compressed = compressed
			p.Route = route
		}

		// The length field does not cover the id and route, so re-check that
		// the whole body is still available after consuming them.
		if d.buf.Len() < d.size {
			return nil, d.abort(ErrWrongMessage)
		}
		p.Length = uint32(d.size)
		// Copy: the slice returned by Next is only valid until the buffer is
		// next read from or written to.
		p.Data = append(make([]byte, 0, d.size), d.buf.Next(d.size)...)
		packets = append(packets, p)

		// The remainder cannot hold another header: forget the current header
		// and keep the leftover bytes for the next call.
		if d.buf.Len() < headLength {
			d.resetFlags()
			break
		}

		// Read the header of the next frame.
		if err := d.readHeader(); err != nil {
			return packets, d.abort(err)
		}
	}

	if len(packets) == 0 {
		d.resetFlags()
		d.buf.Reset()
	}
	return packets, nil
}

// decodeMsgID reads a variable-length message id from the buffer: 7 bits per
// byte, least significant group first, terminated by a byte below 128. The
// value is accumulated in 64 bits. It returns ErrWrongMessage if the buffer
// ends before the terminating byte.
func (d *NNetPacker) decodeMsgID() (uint64, error) {
	var id uint64
	for shift := uint(0); ; shift += 7 {
		b, err := d.buf.ReadByte()
		if err != nil {
			return 0, ErrWrongMessage
		}
		id += uint64(b&0x7F) << shift
		if b < 128 {
			return id, nil
		}
	}
}

// decodeRoute reads the route of the current frame from the buffer and
// reports whether it was sent in compressed form.
//
// A compressed route is a 2-byte big-endian code looked up in routeMap.Codes;
// an uncompressed route is one length byte followed by that many bytes.
// It returns ErrWrongMessage when the buffer is too short and
// ErrRouteInfoNotFound when the code is not in the dictionary.
func (d *NNetPacker) decodeRoute() (string, bool, error) {
	if d.flag&msgRouteCompressMask == 1 {
		raw := d.buf.Next(2)
		if len(raw) < 2 {
			return "", true, ErrWrongMessage
		}
		route, ok := routeMap.Codes[binary.BigEndian.Uint16(raw)]
		if !ok {
			return "", true, ErrRouteInfoNotFound
		}
		return route, true, nil
	}

	rl, err := d.buf.ReadByte()
	if err != nil {
		return "", false, ErrWrongMessage
	}
	if int(rl) > d.buf.Len() {
		return "", false, ErrWrongMessage
	}
	return string(d.buf.Next(int(rl))), false, nil
}

// readHeader consumes headLength bytes from the buffer and stores the packet
// type, body length and flag they contain.
//
// It returns ErrWrongMessage if fewer than headLength bytes are buffered,
// ErrWrongPacketType if the type is outside Handshake..Kick, and
// ErrPacketSizeExceed if the body length is larger than maxPacketSize.
func (d *NNetPacker) readHeader() error {
	header := d.buf.Next(headLength)
	if len(header) < headLength {
		return ErrWrongMessage
	}
	d.typ = header[0]
	if d.typ < Handshake || d.typ > Kick {
		return ErrWrongPacketType
	}
	d.size = d.bytesToInt(header[1 : len(header)-1])
	d.flag = header[len(header)-1]

	// enforce the maximum packet size
	if d.size > maxPacketSize {
		return ErrPacketSizeExceed
	}
	return nil
}

// bytesToInt decodes b as a big-endian unsigned integer.
func (d *NNetPacker) bytesToInt(b []byte) int {
	result := 0
	for _, v := range b {
		result = result<<8 + int(v)
	}
	return result
}