wip: 又双叒加了一些新东西。

main
NorthLan 2 years ago
parent 115166cb11
commit a2ed3090e7

@ -1,6 +1,7 @@
package component package component
import ( import (
"git.noahlan.cn/northlan/nnet/nface"
"reflect" "reflect"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
@ -9,7 +10,7 @@ import (
var ( var (
typeOfError = reflect.TypeOf((*error)(nil)).Elem() typeOfError = reflect.TypeOf((*error)(nil)).Elem()
typeOfBytes = reflect.TypeOf(([]byte)(nil)) typeOfBytes = reflect.TypeOf(([]byte)(nil))
typeOfRequest = reflect.TypeOf(nnet.Request{}) typeOfConnection = reflect.TypeOf((nface.IConnection)(nil))
) )
func isExported(name string) bool { func isExported(name string) bool {
@ -44,8 +45,8 @@ func isHandlerMethod(method reflect.Method) bool {
return false return false
} }
// 第一个显式入参必须是*Request // 第一个显式入参必须是实现了IConnection的具体类的指针类型
if t1 := mt.In(1); t1.Kind() != reflect.Ptr || t1 != typeOfRequest { if t1 := mt.In(1); t1.Kind() != reflect.Ptr || t1 != typeOfConnection {
return false return false
} }

8
env/env.go vendored

@ -0,0 +1,8 @@
package env
import "time"
var (
// TimerPrecision indicates the precision of timer, default is time.Second
TimerPrecision = time.Second
)

@ -6,3 +6,5 @@ require (
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/panjf2000/ants/v2 v2.6.0 github.com/panjf2000/ants/v2 v2.6.0
) )
require google.golang.org/protobuf v1.28.1 // indirect

@ -1,8 +1,14 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/panjf2000/ants/v2 v2.6.0 h1:xOSpw42m+BMiJ2I33we7h6fYzG4DAlpE1xyI7VS2gxU= github.com/panjf2000/ants/v2 v2.6.0 h1:xOSpw42m+BMiJ2I33we7h6fYzG4DAlpE1xyI7VS2gxU=
github.com/panjf2000/ants/v2 v2.6.0/go.mod h1:cU93usDlihJZ5CfRGNDYsiBYvoilLvBF5Qp/BT2GNRE= github.com/panjf2000/ants/v2 v2.6.0/go.mod h1:cU93usDlihJZ5CfRGNDYsiBYvoilLvBF5Qp/BT2GNRE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=

@ -12,8 +12,8 @@ type Logger interface {
Infof(format string, v ...interface{}) Infof(format string, v ...interface{})
Error(v ...interface{}) Error(v ...interface{})
Errorf(format string, v ...interface{}) Errorf(format string, v ...interface{})
Panic(v ...interface{}) Fatal(v ...interface{})
Panicf(format string, v ...interface{}) Fatalf(format string, v ...interface{})
} }
func init() { func init() {
@ -27,8 +27,8 @@ var (
Infof func(format string, v ...interface{}) Infof func(format string, v ...interface{})
Error func(v ...interface{}) Error func(v ...interface{})
Errorf func(format string, v ...interface{}) Errorf func(format string, v ...interface{})
Panic func(v ...interface{}) Fatal func(v ...interface{})
Panicf func(format string, v ...interface{}) Fatalf func(format string, v ...interface{})
) )
func SetLogger(logger Logger) { func SetLogger(logger Logger) {
@ -41,8 +41,8 @@ func SetLogger(logger Logger) {
Infof = logger.Infof Infof = logger.Infof
Error = logger.Error Error = logger.Error
Errorf = logger.Errorf Errorf = logger.Errorf
Panic = logger.Panic Fatal = logger.Fatal
Panicf = logger.Panicf Fatalf = logger.Fatalf
} }
type innerLogger struct { type innerLogger struct {
@ -56,7 +56,7 @@ func newInnerLogger() Logger {
} }
func (i *innerLogger) Debugf(format string, v ...interface{}) { func (i *innerLogger) Debugf(format string, v ...interface{}) {
i.log.Printf(format, v) i.log.Printf(format+"\n", v)
} }
func (i *innerLogger) Debug(v ...interface{}) { func (i *innerLogger) Debug(v ...interface{}) {
@ -68,7 +68,7 @@ func (i *innerLogger) Info(v ...interface{}) {
} }
func (i *innerLogger) Infof(format string, v ...interface{}) { func (i *innerLogger) Infof(format string, v ...interface{}) {
i.log.Printf(format, v) i.log.Printf(format+"\n", v)
} }
func (i *innerLogger) Error(v ...interface{}) { func (i *innerLogger) Error(v ...interface{}) {
@ -76,13 +76,13 @@ func (i *innerLogger) Error(v ...interface{}) {
} }
func (i *innerLogger) Errorf(format string, v ...interface{}) { func (i *innerLogger) Errorf(format string, v ...interface{}) {
i.log.Printf(format, v) i.log.Printf(format+"\n", v)
} }
func (i *innerLogger) Panic(v ...interface{}) { func (i *innerLogger) Fatal(v ...interface{}) {
i.log.Panic(v) i.log.Fatal(v)
} }
func (i *innerLogger) Panicf(format string, v ...interface{}) { func (i *innerLogger) Fatalf(format string, v ...interface{}) {
i.log.Panicf(format, v) i.log.Fatalf(format+"\n", v)
} }

@ -1,18 +0,0 @@
package message
type BinarySerializer struct {
}
func NewBinarySerializer() Serializer {
return &BinarySerializer{}
}
func (b *BinarySerializer) Marshal(i interface{}) ([]byte, error) {
//TODO implement me
panic("implement me")
}
func (b *BinarySerializer) Unmarshal(bytes []byte, i interface{}) error {
//TODO implement me
panic("implement me")
}

@ -0,0 +1,146 @@
package message
import (
"encoding/binary"
"errors"
)
var _ Codec = (*NNetCodec)(nil)
const (
msgRouteCompressMask = 0x01 // 0000 0001 last bit
msgTypeMask = 0x07 // 0000 0111 1-3 bit (需要>>)
msgRouteLengthMask = 0xFF // 1111 1111 last 8 bit
msgHeadLength = 0x02 // 0000 0010 2 bit
)
// Errors that could be occurred in message codec
var (
ErrWrongMessageType = errors.New("wrong message type")
ErrInvalidMessage = errors.New("invalid message")
ErrRouteInfoNotFound = errors.New("route info not found in dictionary")
ErrWrongMessage = errors.New("wrong message")
)
var (
routes = make(map[string]uint16) // route map to code
codes = make(map[uint16]string) // code map to route
)
type NNetCodec struct{}
func (n *NNetCodec) routable(t Type) bool {
return t == Request || t == Notify || t == Push
}
func (n *NNetCodec) invalidType(t Type) bool {
return t < Request || t > Push
}
func (n *NNetCodec) Encode(v interface{}) ([]byte, error) {
m, ok := v.(*Message)
if !ok {
return nil, ErrWrongMessageType
}
if n.invalidType(m.Type) {
return nil, ErrWrongMessageType
}
buf := make([]byte, 0)
flag := byte(m.Type << 1) // 编译器提示,此处 byte 转换不能删
code, compressed := routes[m.Route]
if compressed {
flag |= msgRouteCompressMask
}
buf = append(buf, flag)
if m.Type == Request || m.Type == Response {
n := m.ID
// variant length encode
for {
b := byte(n % 128)
n >>= 7
if n != 0 {
buf = append(buf, b+128)
} else {
buf = append(buf, b)
break
}
}
}
if n.routable(m.Type) {
if compressed {
buf = append(buf, byte((code>>8)&0xFF))
buf = append(buf, byte(code&0xFF))
} else {
buf = append(buf, byte(len(m.Route)))
buf = append(buf, []byte(m.Route)...)
}
}
buf = append(buf, m.Data...)
return buf, nil
}
func (n *NNetCodec) Decode(data []byte) (interface{}, error) {
if len(data) < msgHeadLength {
return nil, ErrInvalidMessage
}
m := New()
flag := data[0]
offset := 1
m.Type = Type((flag >> 1) & msgTypeMask) // 编译器提示,此处Type转换不能删
if n.invalidType(m.Type) {
return nil, ErrWrongMessageType
}
if m.Type == Request || m.Type == Response {
id := uint64(0)
// little end byte order
// WARNING: must can be stored in 64 bits integer
// variant length encode
for i := offset; i < len(data); i++ {
b := data[i]
id += uint64(b&0x7F) << uint64(7*(i-offset))
if b < 128 {
offset = i + 1
break
}
}
m.ID = id
}
if offset >= len(data) {
return nil, ErrWrongMessage
}
if n.routable(m.Type) {
if flag&msgRouteCompressMask == 1 {
m.compressed = true
code := binary.BigEndian.Uint16(data[offset:(offset + 2)])
route, ok := codes[code]
if !ok {
return nil, ErrRouteInfoNotFound
}
m.Route = route
offset += 2
} else {
m.compressed = false
rl := data[offset]
offset++
if offset+int(rl) > len(data) {
return nil, ErrWrongMessage
}
m.Route = string(data[offset:(offset + int(rl))])
offset += int(rl)
}
}
if offset > len(data) {
return nil, ErrWrongMessage
}
m.Data = data[offset:]
return m, nil
}

@ -0,0 +1,11 @@
package message
type (
// Codec 消息编解码器
Codec interface {
// Encode 编码
Encode(v interface{}) ([]byte, error)
// Decode 解码
Decode(data []byte) (interface{}, error)
}
)

@ -1,11 +0,0 @@
package message
type Header struct {
}
type Message struct {
Type byte // 消息类型
ID uint64 // 消息ID
Header []byte // 消息头原始数据
Payload []byte // 数据
}

@ -0,0 +1,46 @@
package message
import (
"fmt"
)
// Type represents the type of message, which could be Request/Notify/Response/Push
type Type byte
// Message types
const (
Request Type = 0x00
Notify = 0x01
Response = 0x02
Push = 0x03
)
var types = map[Type]string{
Request: "Request",
Notify: "Notify",
Response: "Response",
Push: "Push",
}
func (t Type) String() string {
return types[t]
}
// Message represents an unmarshaler message or a message which to be marshaled
type Message struct {
Type Type // message type (flag)
ID uint64 // unique id, zero while notify mode
Route string // route for locating service
Data []byte // payload
compressed bool // if message compressed
}
// New returns a new message instance
func New() *Message {
return &Message{}
}
// String, implementation of fmt.Stringer interface
func (m *Message) String() string {
return fmt.Sprintf("%s %s (%dbytes)", types[m.Type], m.Route, len(m.Data))
}

@ -14,6 +14,8 @@ const (
) )
type IConnection interface { type IConnection interface {
// Server 获取Server实例
Server() IServer
// Status 获取连接状态 // Status 获取连接状态
Status() int32 Status() int32
// SetStatus 设置连接状态 // SetStatus 设置连接状态

@ -0,0 +1,4 @@
package nface
type IServer interface {
}

@ -7,14 +7,20 @@ import (
"git.noahlan.cn/northlan/nnet/packet" "git.noahlan.cn/northlan/nnet/packet"
"git.noahlan.cn/northlan/nnet/pipeline" "git.noahlan.cn/northlan/nnet/pipeline"
"git.noahlan.cn/northlan/nnet/session" "git.noahlan.cn/northlan/nnet/session"
"github.com/gorilla/websocket"
"net" "net"
"sync/atomic" "sync/atomic"
"time" "time"
) )
var ( var (
_ nface.IConnection = (*Connection)(nil)
ErrCloseClosedSession = errors.New("close closed session") ErrCloseClosedSession = errors.New("close closed session")
// ErrBrokenPipe represents the low-level connection has broken.
ErrBrokenPipe = errors.New("broken low-level pipe")
// ErrBufferExceed indicates that the current session buffer is full and
// can not receive more data.
ErrBufferExceed = errors.New("session send buffer exceed")
) )
type ( type (
@ -28,20 +34,20 @@ type (
lastHeartbeatAt int64 // 最近一次心跳时间 lastHeartbeatAt int64 // 最近一次心跳时间
chDie chan struct{} // 停止通道 chDie chan struct{} // 停止通道
chSend chan []byte // 消息发送通道 chSend chan pendingMessage // 消息发送通道
pipeline pipeline.Pipeline // 消息管道 pipeline pipeline.Pipeline // 消息管道
} }
pendingMessage struct { pendingMessage struct {
typ byte // message type typ interface{} // message type
route string // message route route string // message route
mid uint64 // response message id mid uint64 // response message id
payload interface{} // payload payload interface{} // payload
} }
) )
func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) nface.IConnection { func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) *Connection {
r := &Connection{ r := &Connection{
conn: conn, conn: conn,
server: server, server: server,
@ -50,7 +56,7 @@ func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) nf
lastHeartbeatAt: time.Now().Unix(), lastHeartbeatAt: time.Now().Unix(),
chDie: make(chan struct{}), chDie: make(chan struct{}),
chSend: make(chan pendingMessage, 2048), chSend: make(chan pendingMessage, 512),
pipeline: pipeline, pipeline: pipeline,
} }
@ -60,13 +66,18 @@ func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) nf
return r return r
} }
func newConnectionWS(server *Server, conn *websocket.Conn, pipeline pipeline.Pipeline) nface.IConnection { func (r *Connection) send(m pendingMessage) (err error) {
c, err := newWSConn(conn) defer func() {
if err != nil { if e := recover(); e != nil {
// TODO panic ? err = ErrBrokenPipe
panic(err) }
}()
r.chSend <- m
return err
} }
return newConnection(server, c, pipeline)
func (r *Connection) Server() nface.IServer {
return r.server
} }
func (r *Connection) Status() int32 { func (r *Connection) Status() int32 {
@ -85,6 +96,10 @@ func (r *Connection) ID() int64 {
return r.session.ID() return r.session.ID()
} }
func (r *Connection) setLastHeartbeatAt(t int64) {
atomic.StoreInt64(&r.lastHeartbeatAt, t)
}
func (r *Connection) Session() nface.ISession { func (r *Connection) Session() nface.ISession {
return r.session return r.session
} }
@ -100,12 +115,13 @@ func (r *Connection) write() {
close(chWrite) close(chWrite)
_ = r.Close() _ = r.Close()
log.Debugf("Session write goroutine exit, SessionID=%d, UID=%d", r.session.ID(), r.session.UID()) log.Debugf("Connection write goroutine exit, ConnID=%d, SessionUID=%d", r.ID(), r.session.UID())
}() }()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
// TODO heartbeat enable control
deadline := time.Now().Add(-2 * r.server.HeartbeatInterval).Unix() deadline := time.Now().Add(-2 * r.server.HeartbeatInterval).Unix()
if atomic.LoadInt64(&r.lastHeartbeatAt) < deadline { if atomic.LoadInt64(&r.lastHeartbeatAt) < deadline {
log.Debugf("Session heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&r.lastHeartbeatAt), deadline) log.Debugf("Session heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&r.lastHeartbeatAt), deadline)
@ -114,16 +130,22 @@ func (r *Connection) write() {
// TODO heartbeat data // TODO heartbeat data
chWrite <- []byte{} chWrite <- []byte{}
case data := <-r.chSend: case data := <-r.chSend:
// message marshal data // marshal packet body (data)
payload, err := r.server.Serializer.Marshal(data.payload) payload, err := r.server.Serializer.Marshal(data.payload)
if err != nil { if err != nil {
switch data.typ { log.Errorf("message body marshal err: %v", err)
}
break break
} }
// TODO new message and pipeline // TODO new message and pipeline
if pipe := r.pipeline; pipe != nil {
err := pipe.Outbound().Process(r)
if err != nil {
log.Errorf("broken pipeline err: %s", err.Error())
break
}
}
// TODO encode message ? message processor ? // TODO encode message ? message processor ?
// packet pack data // packet pack data
@ -141,16 +163,17 @@ func (r *Connection) write() {
} }
case <-r.chDie: // connection close signal case <-r.chDie: // connection close signal
return return
// TODO application quit signal case <-r.server.DieChan: // application quit signal
return
} }
} }
} }
func (r *Connection) Close() error { func (r *Connection) Close() error {
if r.Status() == StatusClosed { if r.Status() == nface.StatusClosed {
return ErrCloseClosedSession return ErrCloseClosedSession
} }
r.SetStatus(StatusClosed) r.SetStatus(nface.StatusClosed)
log.Debugf("close connection, ID: %d", r.ID()) log.Debugf("close connection, ID: %d", r.ID())

@ -6,6 +6,7 @@ import (
"git.noahlan.cn/northlan/nnet/log" "git.noahlan.cn/northlan/nnet/log"
"git.noahlan.cn/northlan/nnet/packet" "git.noahlan.cn/northlan/nnet/packet"
"git.noahlan.cn/northlan/nnet/pipeline" "git.noahlan.cn/northlan/nnet/pipeline"
"github.com/gorilla/websocket"
"net" "net"
"time" "time"
) )
@ -52,6 +53,15 @@ func (h *Handler) register(comp component.Component, opts []component.Option) er
return nil return nil
} }
func (h *Handler) handleWS(conn *websocket.Conn) {
c, err := newWSConn(conn)
if err != nil {
log.Error(err)
return
}
h.handle(c)
}
func (h *Handler) handle(conn net.Conn) { func (h *Handler) handle(conn net.Conn) {
connection := newConnection(h.server, conn, h.pipeline) connection := newConnection(h.server, conn, h.pipeline)
h.server.sessionMgr.StoreSession(connection.Session()) h.server.sessionMgr.StoreSession(connection.Session())
@ -68,18 +78,20 @@ func (h *Handler) handle(conn net.Conn) {
} }
func (h *Handler) writeLoop(conn *Connection) { func (h *Handler) writeLoop(conn *Connection) {
conn.write()
} }
func (h *Handler) readLoop(conn *Connection) { func (h *Handler) readLoop(conn *Connection) {
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
n, err := conn.conn.Read(buf) n, err := conn.Conn().Read(buf)
if err != nil { if err != nil {
log.Errorf("Read message error: %s, session will be closed immediately", err.Error()) log.Errorf("Read message error: %s, session will be closed immediately", err.Error())
return return
} }
packets, err := h.server.Packer.Unpack(buf)
// warning: 为性能考虑复用slice处理数据buf传入后必须要copy再处理
packets, err := h.server.Packer.Unpack(buf[:n])
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
} }
@ -96,6 +108,6 @@ func (h *Handler) readLoop(conn *Connection) {
func (h *Handler) processPackets(conn *Connection, packets interface{}) error { func (h *Handler) processPackets(conn *Connection, packets interface{}) error {
err := h.processor.ProcessPacket(conn, packets) err := h.processor.ProcessPacket(conn, packets)
conn.lastHeartbeatAt = time.Now().Unix() conn.setLastHeartbeatAt(time.Now().Unix())
return err return err
} }

@ -1 +0,0 @@
package nnet

@ -2,8 +2,10 @@ package nnet
import ( import (
"git.noahlan.cn/northlan/nnet/component" "git.noahlan.cn/northlan/nnet/component"
"git.noahlan.cn/northlan/nnet/env"
"git.noahlan.cn/northlan/nnet/log" "git.noahlan.cn/northlan/nnet/log"
"git.noahlan.cn/northlan/nnet/pipeline" "git.noahlan.cn/northlan/nnet/pipeline"
"net/http"
"time" "time"
) )
@ -34,6 +36,19 @@ func WithHeartbeatInterval(d time.Duration) Option {
} }
} }
// WithTimerPrecision 设置Timer精度
// 注精度需大于1ms, 并且不能在运行时更改
// 默认精度是 time.Second
func WithTimerPrecision(precision time.Duration) Option {
if precision < time.Millisecond {
panic("time precision can not less than a Millisecond")
}
return func(_ *Options) {
env.TimerPrecision = precision
}
}
// WithWebsocket 开启Websocket, 参数是websocket的相关参数 nnet.WSOption
func WithWebsocket(wsOpts ...WSOption) Option { func WithWebsocket(wsOpts ...WSOption) Option {
return func(options *Options) { return func(options *Options) {
for _, opt := range wsOpts { for _, opt := range wsOpts {
@ -43,15 +58,25 @@ func WithWebsocket(wsOpts ...WSOption) Option {
} }
} }
// WithWSPath 设置websocket的path
func WithWSPath(path string) WSOption { func WithWSPath(path string) WSOption {
return func(opts *WSOptions) { return func(opts *WSOptions) {
opts.WebsocketPath = path opts.WebsocketPath = path
} }
} }
func WithWSTLS(certificate, key string) WSOption { // WithWSTLSConfig 设置websocket的证书和密钥
func WithWSTLSConfig(certificate, key string) WSOption {
return func(opts *WSOptions) { return func(opts *WSOptions) {
opts.TLSCertificate = certificate opts.TLSCertificate = certificate
opts.TLSKey = key opts.TLSKey = key
} }
} }
func WithWSCheckOriginFunc(fn func(*http.Request) bool) WSOption {
return func(opts *WSOptions) {
if fn != nil {
opts.CheckOrigin = fn
}
}
}

@ -5,9 +5,10 @@ import (
"fmt" "fmt"
"git.noahlan.cn/northlan/nnet/component" "git.noahlan.cn/northlan/nnet/component"
"git.noahlan.cn/northlan/nnet/log" "git.noahlan.cn/northlan/nnet/log"
"git.noahlan.cn/northlan/nnet/message"
"git.noahlan.cn/northlan/nnet/packet" "git.noahlan.cn/northlan/nnet/packet"
"git.noahlan.cn/northlan/nnet/pipeline" "git.noahlan.cn/northlan/nnet/pipeline"
"git.noahlan.cn/northlan/nnet/scheduler"
"git.noahlan.cn/northlan/nnet/serialize"
"git.noahlan.cn/northlan/nnet/session" "git.noahlan.cn/northlan/nnet/session"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"net" "net"
@ -26,7 +27,7 @@ type (
Components *component.Components // 组件库 Components *component.Components // 组件库
Packer packet.Packer // 封包、拆包器 Packer packet.Packer // 封包、拆包器
PacketProcessor packet.Processor // 数据包处理器 PacketProcessor packet.Processor // 数据包处理器
Serializer message.Serializer // 消息 序列化/反序列化 Serializer serialize.Serializer // 消息 序列化/反序列化
HeartbeatInterval time.Duration // 心跳间隔0表示不进行心跳 HeartbeatInterval time.Duration // 心跳间隔0表示不进行心跳
WS WSOptions // websocket WS WSOptions // websocket
@ -37,6 +38,7 @@ type (
WebsocketPath string // ws地址(ws://127.0.0.1/WebsocketPath) WebsocketPath string // ws地址(ws://127.0.0.1/WebsocketPath)
TLSCertificate string // TLS 证书地址 (websocket) TLSCertificate string // TLS 证书地址 (websocket)
TLSKey string // TLS 证书key地址 (websocket) TLSKey string // TLS 证书key地址 (websocket)
CheckOrigin func(*http.Request) bool // check origin
} }
) )
@ -52,6 +54,8 @@ type Server struct {
// 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP // 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP
// 如果端口号不填或端口号为0例如"127.0.0.1:" 或 ":0",服务端将选择随机可用端口 // 如果端口号不填或端口号为0例如"127.0.0.1:" 或 ":0",服务端将选择随机可用端口
address string address string
// DieChan 应用程序退出信号
DieChan chan struct{}
// handler 消息处理器 // handler 消息处理器
handler *Handler handler *Handler
// sessionMgr session管理器 // sessionMgr session管理器
@ -61,13 +65,17 @@ type Server struct {
func NewServer(protocol, addr string, opts ...Option) *Server { func NewServer(protocol, addr string, opts ...Option) *Server {
options := Options{ options := Options{
Components: &component.Components{}, Components: &component.Components{},
WS: WSOptions{}, WS: WSOptions{
Packer: packet.NewDefaultPacker(), CheckOrigin: func(_ *http.Request) bool { return true },
},
Packer: packet.NewNNetPacker(),
PacketProcessor: packet.NewNNetProcessor(),
} }
s := &Server{ s := &Server{
Options: options, Options: options,
protocol: protocol, protocol: protocol,
address: addr, address: addr,
DieChan: make(chan struct{}),
} }
for _, opt := range opts { for _, opt := range opts {
@ -109,22 +117,29 @@ func (s *Server) Serve() {
} }
}() }()
go scheduler.Schedule()
sg := make(chan os.Signal) sg := make(chan os.Signal)
signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM) signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM)
select { select {
//case <-env.Die: case <-s.DieChan:
// log.Println("The app will shutdown in a few seconds") log.Info("The server will shutdown in a few seconds")
case s := <-sg: case s := <-sg:
log.Infof("server got signal", s) log.Infof("server got signal: %s", s)
} }
log.Infof("server is stopping...") log.Info("server is stopping...")
s.Shutdown()
// TODO close s.shutdown()
scheduler.Close()
} }
func (s *Server) Shutdown() { func (s *Server) Close() {
close(s.DieChan)
}
func (s *Server) shutdown() {
components := s.Components.List() components := s.Components.List()
compLen := len(components) compLen := len(components)
for i := compLen - 1; i >= 0; i-- { for i := compLen - 1; i >= 0; i-- {
@ -139,15 +154,17 @@ func (s *Server) listenAndServe() {
} }
// 监听成功,服务已启动 // 监听成功,服务已启动
// TODO log log.Info("listening...")
defer listener.Close() defer func() {
listener.Close()
s.Close()
}()
go func() {
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
log.Error("服务器网络错误", err) log.Errorf("服务器网络错误 %+v", err)
return return
} }
log.Errorf("监听错误 %v", err) log.Errorf("监听错误 %v", err)
@ -162,26 +179,29 @@ func (s *Server) listenAndServe() {
continue continue
} }
} }
}()
} }
func (s *Server) listenAndServeWS() { func (s *Server) listenAndServeWS() {
upgrade := websocket.Upgrader{ upgrade := websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
CheckOrigin: nil, CheckOrigin: s.WS.CheckOrigin,
EnableCompression: false,
} }
http.HandleFunc(fmt.Sprintf("/%s/", "path"), func(w http.ResponseWriter, r *http.Request) { http.HandleFunc(fmt.Sprintf("/%s/", s.WS.WebsocketPath), func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrade.Upgrade(w, r, nil) conn, err := upgrade.Upgrade(w, r, nil)
if err != nil { if err != nil {
// TODO upgrade failure, uri=r.requestURI err=err.Error() log.Errorf("Upgrade failure, URI=%s, Error=%s", r.RequestURI, err.Error())
return return
} }
// TODO s.handler.handleWS(conn) err = pool.SubmitConn(func() {
s.handler.handleWS(conn)
})
if err != nil {
log.Fatalf("submit conn pool err: %s", err.Error())
}
}) })
if err := http.ListenAndServe(s.address, nil); err != nil { if err := http.ListenAndServe(s.address, nil); err != nil {
panic(err) log.Fatal(err.Error())
} }
} }
@ -189,18 +209,22 @@ func (s *Server) listenAndServeWSTLS() {
upgrade := websocket.Upgrader{ upgrade := websocket.Upgrader{
ReadBufferSize: 1024, ReadBufferSize: 1024,
WriteBufferSize: 1024, WriteBufferSize: 1024,
CheckOrigin: nil, CheckOrigin: s.WS.CheckOrigin,
EnableCompression: false,
} }
http.HandleFunc(fmt.Sprintf("/%s/", "path"), func(w http.ResponseWriter, r *http.Request) { http.HandleFunc(fmt.Sprintf("/%s/", s.WS.WebsocketPath), func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrade.Upgrade(w, r, nil) conn, err := upgrade.Upgrade(w, r, nil)
if err != nil { if err != nil {
// TODO upgrade failure, uri=r.requestURI err=err.Error() log.Errorf("Upgrade failure, URI=%s, Error=%s", r.RequestURI, err.Error())
return return
} }
// TODO s.handler.handleWS(conn) err = pool.SubmitConn(func() {
s.handler.handleWS(conn)
})
if err != nil {
log.Fatalf("submit conn pool err: %s", err.Error())
}
}) })
if err := http.ListenAndServeTLS(s.address, "", "", nil); err != nil { if err := http.ListenAndServeTLS(s.address, s.WS.TLSCertificate, s.WS.TLSKey, nil); err != nil {
panic(err) log.Fatal(err.Error())
} }
} }

@ -0,0 +1,9 @@
package nnet
import "testing"
func TestServer(t *testing.T) {
server := NewServer("tcp4", ":22112")
server.Serve()
}

@ -1,13 +1,15 @@
package packet package packet
import "git.noahlan.cn/northlan/nnet/nface" import (
"git.noahlan.cn/northlan/nnet/nface"
)
// Type 数据帧类型,如:握手,心跳,数据 等 // Type 数据帧类型,如:握手,心跳,数据 等
type Type byte type Type byte
type ( type (
Packer interface { Packer interface {
// Pack 从原始raw bytes创建一个用于网络传输的 packet.Packet 结构 // Pack 从原始raw bytes创建一个用于网络传输的 数据帧结构
Pack(typ Type, data []byte) ([]byte, error) Pack(typ Type, data []byte) ([]byte, error)
// Unpack 解包 // Unpack 解包

@ -5,9 +5,9 @@ import (
"errors" "errors"
) )
var _ Packer = (*DefaultPacker)(nil) var _ Packer = (*NNetPacker)(nil)
type DefaultPacker struct { type NNetPacker struct {
buf *bytes.Buffer buf *bytes.Buffer
size int // 最近一次 长度 size int // 最近一次 长度
typ byte // 最近一次 数据帧类型 typ byte // 最近一次 数据帧类型
@ -21,14 +21,14 @@ const (
var ErrPacketSizeExceed = errors.New("codec: packet size exceed") var ErrPacketSizeExceed = errors.New("codec: packet size exceed")
func NewDefaultPacker() Packer { func NewNNetPacker() Packer {
return &DefaultPacker{ return &NNetPacker{
buf: bytes.NewBuffer(nil), buf: bytes.NewBuffer(nil),
size: -1, size: -1,
} }
} }
func (d *DefaultPacker) Pack(typ Type, data []byte) ([]byte, error) { func (d *NNetPacker) Pack(typ Type, data []byte) ([]byte, error) {
if typ < Handshake || typ > Kick { if typ < Handshake || typ > Kick {
return nil, ErrWrongPacketType return nil, ErrWrongPacketType
} }
@ -47,7 +47,7 @@ func (d *DefaultPacker) Pack(typ Type, data []byte) ([]byte, error) {
} }
// Encode packet data length to bytes(Big end) // Encode packet data length to bytes(Big end)
func (d *DefaultPacker) intToBytes(n uint32) []byte { func (d *NNetPacker) intToBytes(n uint32) []byte {
buf := make([]byte, 3) buf := make([]byte, 3)
buf[0] = byte((n >> 16) & 0xFF) buf[0] = byte((n >> 16) & 0xFF)
buf[1] = byte((n >> 8) & 0xFF) buf[1] = byte((n >> 8) & 0xFF)
@ -55,7 +55,7 @@ func (d *DefaultPacker) intToBytes(n uint32) []byte {
return buf return buf
} }
func (d *DefaultPacker) Unpack(data []byte) ([]interface{}, error) { func (d *NNetPacker) Unpack(data []byte) ([]interface{}, error) {
d.buf.Write(data) // copy d.buf.Write(data) // copy
var ( var (
@ -98,7 +98,7 @@ func (d *DefaultPacker) Unpack(data []byte) ([]interface{}, error) {
return packets, nil return packets, nil
} }
func (d *DefaultPacker) readHeader() error { func (d *NNetPacker) readHeader() error {
header := d.buf.Next(headLength) header := d.buf.Next(headLength)
d.typ = header[0] d.typ = header[0]
if d.typ < Handshake || d.typ > Kick { if d.typ < Handshake || d.typ > Kick {
@ -114,7 +114,7 @@ func (d *DefaultPacker) readHeader() error {
} }
// Decode packet data length byte to int(Big end) // Decode packet data length byte to int(Big end)
func (d *DefaultPacker) bytesToInt(b []byte) int { func (d *NNetPacker) bytesToInt(b []byte) int {
result := 0 result := 0
for _, v := range b { for _, v := range b {
result = result<<8 + int(v) result = result<<8 + int(v)

@ -6,13 +6,13 @@ import (
"git.noahlan.cn/northlan/nnet/nface" "git.noahlan.cn/northlan/nnet/nface"
) )
type DefaultProcessor struct{} type NNetProcessor struct{}
func NewDefaultProcessor() *DefaultProcessor { func NewNNetProcessor() Processor {
return &DefaultProcessor{} return &NNetProcessor{}
} }
func (d *DefaultProcessor) ProcessPacket(conn nface.IConnection, packet interface{}) error { func (d *NNetProcessor) ProcessPacket(conn nface.IConnection, packet interface{}) error {
p := packet.(*Packet) p := packet.(*Packet)
switch p.Type { switch p.Type {
case Handshake: case Handshake:
@ -25,12 +25,12 @@ func (d *DefaultProcessor) ProcessPacket(conn nface.IConnection, packet interfac
case HandshakeAck: case HandshakeAck:
conn.SetStatus(nface.StatusWorking) conn.SetStatus(nface.StatusWorking)
log.Debugf("Receive handshake ACK Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr()) log.Debugf("Receive handshake ACK Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr())
case Data: case Data:
if conn.Status() < nface.StatusWorking { if conn.Status() < nface.StatusWorking {
return fmt.Errorf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s", return fmt.Errorf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s",
conn.Conn().RemoteAddr()) conn.Conn().RemoteAddr())
} }
// TODO message data 处理 // TODO message data 处理
case Heartbeat: case Heartbeat:
// expected // expected

@ -1,11 +1,12 @@
package pipeline package pipeline
import ( import (
"git.noahlan.cn/northlan/nnet/nface"
"sync" "sync"
) )
type ( type (
Func func(request *nnet.Request) error Func func(nface.IConnection) error
// Pipeline 消息管道 // Pipeline 消息管道
Pipeline interface { Pipeline interface {
@ -20,7 +21,7 @@ type (
Channel interface { Channel interface {
PushFront(h Func) PushFront(h Func)
PushBack(h Func) PushBack(h Func)
Process(request *nnet.Request) error Process(nface.IConnection) error
} }
pipelineChannel struct { pipelineChannel struct {
@ -65,7 +66,7 @@ func (p *pipelineChannel) PushBack(h Func) {
} }
// Process 处理所有的pipeline方法 // Process 处理所有的pipeline方法
func (p *pipelineChannel) Process(request *nnet.Request) error { func (p *pipelineChannel) Process(conn nface.IConnection) error {
p.mu.RLock() p.mu.RLock()
defer p.mu.RUnlock() defer p.mu.RUnlock()
@ -74,7 +75,7 @@ func (p *pipelineChannel) Process(request *nnet.Request) error {
} }
for _, handler := range p.handlers { for _, handler := range p.handlers {
err := handler(request) err := handler(conn)
if err != nil { if err != nil {
return err return err
} }

@ -0,0 +1,78 @@
package scheduler
import (
"git.noahlan.cn/northlan/nnet/env"
"git.noahlan.cn/northlan/nnet/log"
"runtime/debug"
"sync/atomic"
"time"
)
const (
messageQueueBacklog = 1 << 10 // 1024
sessionCloseBacklog = 1 << 8 // 256
)
// LocalScheduler schedules task to a customized goroutine
type LocalScheduler interface {
Schedule(Task)
}
type Task func()
type Hook func()
var (
chDie = make(chan struct{})
chExit = make(chan struct{})
chTasks = make(chan Task, 1<<8)
started int32
closed int32
)
func try(f func()) {
defer func() {
if err := recover(); err != nil {
log.Infof("Handle message panic: %+v\n%s", err, debug.Stack())
}
}()
f()
}
func Schedule() {
if atomic.AddInt32(&started, 1) != 1 {
return
}
ticker := time.NewTicker(env.TimerPrecision)
defer func() {
ticker.Stop()
close(chExit)
}()
for {
select {
case <-ticker.C:
cron()
case f := <-chTasks:
try(f)
case <-chDie:
return
}
}
}
func Close() {
if atomic.AddInt32(&closed, 1) != 1 {
return
}
close(chDie)
<-chExit
log.Info("Scheduler stopped")
}
func PushTask(task Task) {
chTasks <- task
}

@ -0,0 +1,194 @@
package scheduler
import (
"git.noahlan.cn/northlan/nnet/log"
"math"
"runtime/debug"
"sync"
"sync/atomic"
"time"
)
const infinite = -1
var (
timerManager = &struct {
incrementID int64 // auto increment id
timers map[int64]*Timer // all timers
sync.Once
muClosingTimer sync.RWMutex // 关闭锁,避免重复关闭
closingTimer []int64 // 已关闭的timer id
muCreatedTimer sync.RWMutex // 创建锁,避免重复创建
createdTimer []*Timer // 已创建的Timer
}{}
)
type (
// TimerFunc represents a function which will be called periodically in main
// logic goroutine.
TimerFunc func()
// TimerCondition represents a checker that returns true when cron job needs
// to execute
TimerCondition interface {
Check(now time.Time) bool
}
// Timer represents a cron job
Timer struct {
id int64 // timer id
fn TimerFunc // function that execute
createAt int64 // timer create time
interval time.Duration // execution interval
condition TimerCondition // condition to cron job execution
elapse int64 // total elapse time
closed int32 // is timer closed
counter int // counter
}
)
func init() {
timerManager.timers = map[int64]*Timer{}
}
// ID returns id of current timer
func (t *Timer) ID() int64 {
return t.id
}
// Stop turns off a timer. After Stop, fn will not be called forever
func (t *Timer) Stop() {
if atomic.AddInt32(&t.closed, 1) != 1 {
return
}
t.counter = 0
}
// safeCall 安全调用,收集所有 fn 触发的 panic给与提示即可
func safeCall(_ int64, fn TimerFunc) {
defer func() {
if err := recover(); err != nil {
log.Infof("Handle timer panic: %+v\n%s", err, debug.Stack())
}
}()
fn()
}
func cron() {
if len(timerManager.createdTimer) > 0 {
timerManager.muCreatedTimer.Lock()
for _, t := range timerManager.createdTimer {
timerManager.timers[t.id] = t
}
timerManager.createdTimer = timerManager.createdTimer[:0]
timerManager.muCreatedTimer.Unlock()
}
if len(timerManager.timers) < 1 {
return
}
now := time.Now()
unn := now.UnixNano()
for id, t := range timerManager.timers {
if t.counter == infinite || t.counter > 0 {
// condition timer
if t.condition != nil {
if t.condition.Check(now) {
safeCall(id, t.fn)
}
continue
}
// execute job
if t.createAt+t.elapse <= unn {
safeCall(id, t.fn)
t.elapse += int64(t.interval)
// update timer counter
if t.counter != infinite && t.counter > 0 {
t.counter--
}
}
}
if t.counter == 0 {
timerManager.muClosingTimer.Lock()
timerManager.closingTimer = append(timerManager.closingTimer, t.id)
timerManager.muClosingTimer.Unlock()
continue
}
}
if len(timerManager.closingTimer) > 0 {
timerManager.muClosingTimer.Lock()
for _, id := range timerManager.closingTimer {
delete(timerManager.timers, id)
}
timerManager.closingTimer = timerManager.closingTimer[:0]
timerManager.muClosingTimer.Unlock()
}
}
// NewTimer returns a new Timer containing a function that will be called
// with a period specified by the duration argument. It adjusts the intervals
// for slow receivers.
// The duration d must be greater than zero; if not, NewTimer will panic.
// Stop the timer to release associated resources.
func NewTimer(interval time.Duration, fn TimerFunc) *Timer {
return NewCountTimer(interval, infinite, fn)
}
// NewCountTimer returns a new Timer containing a function that will be called
// with a period specified by the duration argument. After count times, timer
// will be stopped automatically, It adjusts the intervals for slow receivers.
// The duration d must be greater than zero; if not, NewCountTimer will panic.
// Stop the timer to release associated resources.
func NewCountTimer(interval time.Duration, count int, fn TimerFunc) *Timer {
if fn == nil {
panic("ngs/timer: nil timer function")
}
if interval <= 0 {
panic("non-positive interval for NewTimer")
}
t := &Timer{
id: atomic.AddInt64(&timerManager.incrementID, 1),
fn: fn,
createAt: time.Now().UnixNano(),
interval: interval,
elapse: int64(interval), // first execution will be after interval
counter: count,
}
timerManager.muCreatedTimer.Lock()
timerManager.createdTimer = append(timerManager.createdTimer, t)
timerManager.muCreatedTimer.Unlock()
return t
}
// NewAfterTimer returns a new Timer containing a function that will be called
// after duration that specified by the duration argument.
// The duration d must be greater than zero; if not, NewAfterTimer will panic.
// Stop the timer to release associated resources.
func NewAfterTimer(duration time.Duration, fn TimerFunc) *Timer {
return NewCountTimer(duration, 1, fn)
}
// NewCondTimer returns a new Timer containing a function that will be called
// when condition satisfied that specified by the condition argument.
// The duration d must be greater than zero; if not, NewCondTimer will panic.
// Stop the timer to release associated resources.
func NewCondTimer(condition TimerCondition, fn TimerFunc) *Timer {
if condition == nil {
panic("ngs/timer: nil condition")
}
t := NewCountTimer(time.Duration(math.MaxInt64), infinite, fn)
t.condition = condition
return t
}

@ -0,0 +1,84 @@
package scheduler
import (
"sync/atomic"
"testing"
"time"
)
func TestNewTimer(t *testing.T) {
var exists = struct {
timers int
createdTimes int
closingTimers int
}{
timers: len(timerManager.timers),
createdTimes: len(timerManager.createdTimer),
closingTimers: len(timerManager.closingTimer),
}
const tc = 1000
var counter int64
for i := 0; i < tc; i++ {
NewTimer(1*time.Millisecond, func() {
atomic.AddInt64(&counter, 1)
})
}
<-time.After(5 * time.Millisecond)
cron()
cron()
if counter != tc*2 {
t.Fatalf("expect: %d, got: %d", tc*2, counter)
}
if len(timerManager.timers) != exists.timers+tc {
t.Fatalf("timers: %d", len(timerManager.timers))
}
if len(timerManager.createdTimer) != exists.createdTimes {
t.Fatalf("createdTimer: %d", len(timerManager.createdTimer))
}
if len(timerManager.closingTimer) != exists.closingTimers {
t.Fatalf("closingTimer: %d", len(timerManager.closingTimer))
}
}
func TestNewAfterTimer(t *testing.T) {
var exists = struct {
timers int
createdTimes int
closingTimers int
}{
timers: len(timerManager.timers),
createdTimes: len(timerManager.createdTimer),
closingTimers: len(timerManager.closingTimer),
}
const tc = 1000
var counter int64
for i := 0; i < tc; i++ {
NewAfterTimer(1*time.Millisecond, func() {
atomic.AddInt64(&counter, 1)
})
}
<-time.After(5 * time.Millisecond)
cron()
if counter != tc {
t.Fatalf("expect: %d, got: %d", tc, counter)
}
if len(timerManager.timers) != exists.timers {
t.Fatalf("timers: %d", len(timerManager.timers))
}
if len(timerManager.createdTimer) != exists.createdTimes {
t.Fatalf("createdTimer: %d", len(timerManager.createdTimer))
}
if len(timerManager.closingTimer) != exists.closingTimers {
t.Fatalf("closingTimer: %d", len(timerManager.closingTimer))
}
}

@ -0,0 +1,20 @@
package json
import (
"encoding/json"
"git.noahlan.cn/northlan/nnet/serialize"
)
type Serializer struct{}
func NewSerializer() serialize.Serializer {
return &Serializer{}
}
func (s *Serializer) Marshal(i interface{}) ([]byte, error) {
return json.Marshal(i)
}
func (s *Serializer) Unmarshal(bytes []byte, i interface{}) error {
return json.Unmarshal(bytes, i)
}

@ -0,0 +1,62 @@
package json
import (
"reflect"
"testing"
)
type Message struct {
Code int `json:"code"`
Data string `json:"data"`
}
func TestSerializer_Serialize(t *testing.T) {
m := Message{1, "hello world"}
s := NewSerializer()
b, err := s.Marshal(m)
if err != nil {
t.Fail()
}
m2 := Message{}
if err := s.Unmarshal(b, &m2); err != nil {
t.Fail()
}
if !reflect.DeepEqual(m, m2) {
t.Fail()
}
}
func BenchmarkSerializer_Serialize(b *testing.B) {
m := &Message{100, "hell world"}
s := NewSerializer()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if _, err := s.Marshal(m); err != nil {
b.Fatalf("unmarshal failed: %v", err)
}
}
b.ReportAllocs()
}
func BenchmarkSerializer_Deserialize(b *testing.B) {
m := &Message{100, "hell world"}
s := NewSerializer()
d, err := s.Marshal(m)
if err != nil {
b.Error(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
m1 := &Message{}
if err := s.Unmarshal(d, m1); err != nil {
b.Fatalf("unmarshal failed: %v", err)
}
}
b.ReportAllocs()
}

@ -0,0 +1,32 @@
package protobuf
import (
"errors"
"git.noahlan.cn/northlan/nnet/serialize"
"google.golang.org/protobuf/proto"
)
// ErrWrongValueType is the error used for marshal the value with protobuf encoding.
var ErrWrongValueType = errors.New("protobuf: convert on wrong type value")
type Serializer struct{}
func NewSerializer() serialize.Serializer {
return &Serializer{}
}
func (s *Serializer) Marshal(v interface{}) ([]byte, error) {
pb, ok := v.(proto.Message)
if !ok {
return nil, ErrWrongValueType
}
return proto.Marshal(pb)
}
func (s *Serializer) Unmarshal(data []byte, v interface{}) error {
pb, ok := v.(proto.Message)
if !ok {
return ErrWrongValueType
}
return proto.Unmarshal(data, pb)
}

@ -0,0 +1,56 @@
package protobuf
import (
"git.noahlan.cn/northlan/nnet/serialize/protobuf/testdata"
"google.golang.org/protobuf/proto"
"testing"
)
func TestProtobufSerializer_Serialize(t *testing.T) {
m := &testdata.Ping{Content: "hello"}
s := NewSerializer()
b, err := s.Marshal(m)
if err != nil {
t.Error(err)
}
m1 := &testdata.Ping{}
if err := s.Unmarshal(b, m1); err != nil {
t.Fatalf("unmarshal failed: %v", err)
}
// refer: https://developers.google.com/protocol-buffers/docs/reference/go/faq#deepequal
if !proto.Equal(m, m1) {
t.Fail()
}
}
func BenchmarkSerializer_Serialize(b *testing.B) {
m := &testdata.Ping{Content: "hello"}
s := NewSerializer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := s.Marshal(m); err != nil {
b.Fatalf("unmarshal failed: %v", err)
}
}
}
func BenchmarkSerializer_Deserialize(b *testing.B) {
m := &testdata.Ping{Content: "hello"}
s := NewSerializer()
d, err := s.Marshal(m)
if err != nil {
b.Error(err)
}
b.ReportAllocs()
for i := 0; i < b.N; i++ {
m1 := &testdata.Ping{}
if err := s.Unmarshal(d, m1); err != nil {
b.Fatalf("unmarshal failed: %v", err)
}
}
}

@ -0,0 +1 @@
protoc --go_opt=paths=source_relative --go_out=. --proto_path=. *.proto

@ -0,0 +1,13 @@
syntax = "proto3";
package testdata;
option go_package = "/testdata";
message Ping {
string Content = 1;
}
message Pong {
string Content = 1;
}

@ -1,14 +1,14 @@
package message package serialize
type ( type (
// Marshaler 序列化 // Marshaler 序列化
Marshaler interface { Marshaler interface {
Marshal(interface{}) ([]byte, error) Marshal(v interface{}) ([]byte, error)
} }
// Unmarshaler 反序列化 // Unmarshaler 反序列化
Unmarshaler interface { Unmarshaler interface {
Unmarshal([]byte, interface{}) error Unmarshal(data []byte, v interface{}) error
} }
// Serializer 消息 序列化/反序列化仅针对payload // Serializer 消息 序列化/反序列化仅针对payload
Loading…
Cancel
Save