From a2ed3090e75eaf3bdad55976f87a3f7ea7833c99 Mon Sep 17 00:00:00 2001 From: NorthLan <6995syu@163.com> Date: Mon, 7 Nov 2022 19:18:10 +0800 Subject: [PATCH] =?UTF-8?q?wip:=20=E5=8F=88=E5=8F=8C=E5=8F=92=E5=8A=A0?= =?UTF-8?q?=E4=BA=86=E4=B8=80=E4=BA=9B=E6=96=B0=E4=B8=9C=E8=A5=BF=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- component/util.go | 11 +- env/env.go | 8 + go.mod | 2 + go.sum | 6 + log/logger.go | 26 +-- message/binary_serializer.go | 18 -- message/codec_nnet.go | 146 ++++++++++++++++ message/interface.go | 11 ++ message/message.go | 11 -- message/message_nnet.go | 46 +++++ nface/i_connection.go | 2 + nface/i_server.go | 4 + nnet/connection.go | 65 ++++--- nnet/handler.go | 20 ++- nnet/interface.go | 1 - nnet/options.go | 27 ++- nnet/server.go | 128 ++++++++------ nnet/server_test.go | 9 + packet/interface.go | 8 +- packet/{packer.go => packer_nnet.go} | 18 +- packet/{packet.go => packet_nnet.go} | 0 packet/{processor.go => processor_nnet.go} | 10 +- pipeline/pipeline.go | 9 +- scheduler/scheduler.go | 78 +++++++++ scheduler/timer.go | 194 +++++++++++++++++++++ scheduler/timer_test.go | 84 +++++++++ serialize/json/json.go | 20 +++ serialize/json/json_test.go | 62 +++++++ serialize/protobuf/protobuf.go | 32 ++++ serialize/protobuf/protobuf_test.go | 56 ++++++ serialize/protobuf/testdata/gen_proto.bat | 1 + serialize/protobuf/testdata/test.proto | 13 ++ {message => serialize}/serializer.go | 6 +- 33 files changed, 982 insertions(+), 150 deletions(-) create mode 100644 env/env.go delete mode 100644 message/binary_serializer.go create mode 100644 message/codec_nnet.go create mode 100644 message/interface.go delete mode 100644 message/message.go create mode 100644 message/message_nnet.go create mode 100644 nface/i_server.go delete mode 100644 nnet/interface.go create mode 100644 nnet/server_test.go rename packet/{packer.go => packer_nnet.go} (82%) rename packet/{packet.go => packet_nnet.go} (100%) rename packet/{processor.go => processor_nnet.go} (80%) create mode 100644 scheduler/scheduler.go create mode 100644 scheduler/timer.go create mode 100644 scheduler/timer_test.go create mode 100644 serialize/json/json.go create mode 100644 serialize/json/json_test.go create mode 100644 serialize/protobuf/protobuf.go create mode 100644 serialize/protobuf/protobuf_test.go create mode 100644 serialize/protobuf/testdata/gen_proto.bat create mode 100644 serialize/protobuf/testdata/test.proto rename {message => serialize}/serializer.go (69%) diff --git a/component/util.go b/component/util.go index 964080c..f6c5d06 100644 --- a/component/util.go +++ b/component/util.go @@ -1,15 +1,16 @@ package component import ( + "git.noahlan.cn/northlan/nnet/nface" "reflect" "unicode" "unicode/utf8" ) var ( - typeOfError = reflect.TypeOf((*error)(nil)).Elem() - typeOfBytes = reflect.TypeOf(([]byte)(nil)) - typeOfRequest = reflect.TypeOf(nnet.Request{}) + typeOfError = reflect.TypeOf((*error)(nil)).Elem() + typeOfBytes = reflect.TypeOf(([]byte)(nil)) + typeOfConnection = reflect.TypeOf((nface.IConnection)(nil)) ) func isExported(name string) bool { @@ -44,8 +45,8 @@ func isHandlerMethod(method reflect.Method) bool { return false } - // 第一个显式入参必须是*Request - if t1 := mt.In(1); t1.Kind() != reflect.Ptr || t1 != typeOfRequest { + // 第一个显式入参必须是实现了IConnection的具体类的指针类型 + if t1 := mt.In(1); t1.Kind() != reflect.Ptr || t1 != typeOfConnection { return false } diff --git a/env/env.go b/env/env.go new file mode 100644 index 0000000..89d5adc --- /dev/null +++ b/env/env.go @@ -0,0 +1,8 @@ +package env + +import "time" + +var ( + // TimerPrecision indicates the precision of timer, default is time.Second + TimerPrecision = time.Second +) diff --git a/go.mod b/go.mod index f9567bd..9b6e618 100644 --- a/go.mod +++ b/go.mod @@ -6,3 +6,5 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/panjf2000/ants/v2 v2.6.0 ) + +require google.golang.org/protobuf v1.28.1 // indirect diff --git a/go.sum b/go.sum index d8816c7..de1adcb 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/panjf2000/ants/v2 v2.6.0 h1:xOSpw42m+BMiJ2I33we7h6fYzG4DAlpE1xyI7VS2gxU= github.com/panjf2000/ants/v2 v2.6.0/go.mod h1:cU93usDlihJZ5CfRGNDYsiBYvoilLvBF5Qp/BT2GNRE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= +google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= diff --git a/log/logger.go b/log/logger.go index a03f4be..f67041d 100644 --- a/log/logger.go +++ b/log/logger.go @@ -12,8 +12,8 @@ type Logger interface { Infof(format string, v ...interface{}) Error(v ...interface{}) Errorf(format string, v ...interface{}) - Panic(v ...interface{}) - Panicf(format string, v ...interface{}) + Fatal(v ...interface{}) + Fatalf(format string, v ...interface{}) } func init() { @@ -27,8 +27,8 @@ var ( Infof func(format string, v ...interface{}) Error func(v ...interface{}) Errorf func(format string, v ...interface{}) - Panic func(v ...interface{}) - Panicf func(format string, v ...interface{}) + Fatal func(v ...interface{}) + Fatalf func(format string, v ...interface{}) ) func SetLogger(logger Logger) { @@ -41,8 +41,8 @@ func SetLogger(logger Logger) { Infof = logger.Infof Error = logger.Error Errorf = logger.Errorf - Panic = logger.Panic - Panicf = logger.Panicf + Fatal = logger.Fatal + Fatalf = logger.Fatalf } type innerLogger struct { @@ -56,7 +56,7 @@ func newInnerLogger() Logger { } func (i *innerLogger) Debugf(format string, v ...interface{}) { - i.log.Printf(format, v) + i.log.Printf(format+"\n", v) } func (i *innerLogger) Debug(v ...interface{}) { @@ -68,7 +68,7 @@ func (i *innerLogger) Info(v ...interface{}) { } func (i *innerLogger) Infof(format string, v ...interface{}) { - i.log.Printf(format, v) + i.log.Printf(format+"\n", v) } func (i *innerLogger) Error(v ...interface{}) { @@ -76,13 +76,13 @@ func (i *innerLogger) Error(v ...interface{}) { } func (i *innerLogger) Errorf(format string, v ...interface{}) { - i.log.Printf(format, v) + i.log.Printf(format+"\n", v) } -func (i *innerLogger) Panic(v ...interface{}) { - i.log.Panic(v) +func (i *innerLogger) Fatal(v ...interface{}) { + i.log.Fatal(v) } -func (i *innerLogger) Panicf(format string, v ...interface{}) { - i.log.Panicf(format, v) +func (i *innerLogger) Fatalf(format string, v ...interface{}) { + i.log.Fatalf(format+"\n", v) } diff --git a/message/binary_serializer.go b/message/binary_serializer.go deleted file mode 100644 index 5d30d74..0000000 --- a/message/binary_serializer.go +++ /dev/null @@ -1,18 +0,0 @@ -package message - -type BinarySerializer struct { -} - -func NewBinarySerializer() Serializer { - return &BinarySerializer{} -} - -func (b *BinarySerializer) Marshal(i interface{}) ([]byte, error) { - //TODO implement me - panic("implement me") -} - -func (b *BinarySerializer) Unmarshal(bytes []byte, i interface{}) error { - //TODO implement me - panic("implement me") -} diff --git a/message/codec_nnet.go b/message/codec_nnet.go new file mode 100644 index 0000000..9ef4483 --- /dev/null +++ b/message/codec_nnet.go @@ -0,0 +1,146 @@ +package message + +import ( + "encoding/binary" + "errors" +) + +var _ Codec = (*NNetCodec)(nil) + +const ( + msgRouteCompressMask = 0x01 // 0000 0001 last bit + msgTypeMask = 0x07 // 0000 0111 1-3 bit (需要>>) + msgRouteLengthMask = 0xFF // 1111 1111 last 8 bit + msgHeadLength = 0x02 // 0000 0010 2 bit +) + +// Errors that could be occurred in message codec +var ( + ErrWrongMessageType = errors.New("wrong message type") + ErrInvalidMessage = errors.New("invalid message") + ErrRouteInfoNotFound = errors.New("route info not found in dictionary") + ErrWrongMessage = errors.New("wrong message") +) + +var ( + routes = make(map[string]uint16) // route map to code + codes = make(map[uint16]string) // code map to route +) + +type NNetCodec struct{} + +func (n *NNetCodec) routable(t Type) bool { + return t == Request || t == Notify || t == Push +} + +func (n *NNetCodec) invalidType(t Type) bool { + return t < Request || t > Push +} + +func (n *NNetCodec) Encode(v interface{}) ([]byte, error) { + m, ok := v.(*Message) + if !ok { + return nil, ErrWrongMessageType + } + if n.invalidType(m.Type) { + return nil, ErrWrongMessageType + } + buf := make([]byte, 0) + flag := byte(m.Type << 1) // 编译器提示,此处 byte 转换不能删 + + code, compressed := routes[m.Route] + if compressed { + flag |= msgRouteCompressMask + } + buf = append(buf, flag) + + if m.Type == Request || m.Type == Response { + n := m.ID + // variant length encode + for { + b := byte(n % 128) + n >>= 7 + if n != 0 { + buf = append(buf, b+128) + } else { + buf = append(buf, b) + break + } + } + } + + if n.routable(m.Type) { + if compressed { + buf = append(buf, byte((code>>8)&0xFF)) + buf = append(buf, byte(code&0xFF)) + } else { + buf = append(buf, byte(len(m.Route))) + buf = append(buf, []byte(m.Route)...) + } + } + + buf = append(buf, m.Data...) + return buf, nil +} + +func (n *NNetCodec) Decode(data []byte) (interface{}, error) { + if len(data) < msgHeadLength { + return nil, ErrInvalidMessage + } + m := New() + flag := data[0] + offset := 1 + m.Type = Type((flag >> 1) & msgTypeMask) // 编译器提示,此处Type转换不能删 + + if n.invalidType(m.Type) { + return nil, ErrWrongMessageType + } + + if m.Type == Request || m.Type == Response { + id := uint64(0) + // little end byte order + // WARNING: must can be stored in 64 bits integer + // variant length encode + for i := offset; i < len(data); i++ { + b := data[i] + id += uint64(b&0x7F) << uint64(7*(i-offset)) + if b < 128 { + offset = i + 1 + break + } + } + m.ID = id + } + + if offset >= len(data) { + return nil, ErrWrongMessage + } + + if n.routable(m.Type) { + if flag&msgRouteCompressMask == 1 { + m.compressed = true + code := binary.BigEndian.Uint16(data[offset:(offset + 2)]) + route, ok := codes[code] + if !ok { + return nil, ErrRouteInfoNotFound + } + m.Route = route + offset += 2 + } else { + m.compressed = false + rl := data[offset] + offset++ + if offset+int(rl) > len(data) { + return nil, ErrWrongMessage + } + m.Route = string(data[offset:(offset + int(rl))]) + offset += int(rl) + } + } + + if offset > len(data) { + return nil, ErrWrongMessage + } + m.Data = data[offset:] + return m, nil +} diff --git a/message/interface.go b/message/interface.go new file mode 100644 index 0000000..89d070e --- /dev/null +++ b/message/interface.go @@ -0,0 +1,11 @@ +package message + +type ( + // Codec 消息编解码器 + Codec interface { + // Encode 编码 + Encode(v interface{}) ([]byte, error) + // Decode 解码 + Decode(data []byte) (interface{}, error) + } +) diff --git a/message/message.go b/message/message.go deleted file mode 100644 index 4c5720e..0000000 --- a/message/message.go +++ /dev/null @@ -1,11 +0,0 @@ -package message - -type Header struct { -} - -type Message struct { - Type byte // 消息类型 - ID uint64 // 消息ID - Header []byte // 消息头原始数据 - Payload []byte // 数据 -} diff --git a/message/message_nnet.go b/message/message_nnet.go new file mode 100644 index 0000000..205d2ab --- /dev/null +++ b/message/message_nnet.go @@ -0,0 +1,46 @@ +package message + +import ( + "fmt" +) + +// Type represents the type of message, which could be Request/Notify/Response/Push +type Type byte + +// Message types +const ( + Request Type = 0x00 + Notify = 0x01 + Response = 0x02 + Push = 0x03 +) + +var types = map[Type]string{ + Request: "Request", + Notify: "Notify", + Response: "Response", + Push: "Push", +} + +func (t Type) String() string { + return types[t] +} + +// Message represents an unmarshaler message or a message which to be marshaled +type Message struct { + Type Type // message type (flag) + ID uint64 // unique id, zero while notify mode + Route string // route for locating service + Data []byte // payload + compressed bool // if message compressed +} + +// New returns a new message instance +func New() *Message { + return &Message{} +} + +// String, implementation of fmt.Stringer interface +func (m *Message) String() string { + return fmt.Sprintf("%s %s (%dbytes)", types[m.Type], m.Route, len(m.Data)) +} diff --git a/nface/i_connection.go b/nface/i_connection.go index 358df0f..2718722 100644 --- a/nface/i_connection.go +++ b/nface/i_connection.go @@ -14,6 +14,8 @@ const ( ) type IConnection interface { + // Server 获取Server实例 + Server() IServer // Status 获取连接状态 Status() int32 // SetStatus 设置连接状态 diff --git a/nface/i_server.go b/nface/i_server.go new file mode 100644 index 0000000..1635f02 --- /dev/null +++ b/nface/i_server.go @@ -0,0 +1,4 @@ +package nface + +type IServer interface { +} diff --git a/nnet/connection.go b/nnet/connection.go index 35bb6cd..62170f4 100644 --- a/nnet/connection.go +++ b/nnet/connection.go @@ -7,14 +7,20 @@ import ( "git.noahlan.cn/northlan/nnet/packet" "git.noahlan.cn/northlan/nnet/pipeline" "git.noahlan.cn/northlan/nnet/session" - "github.com/gorilla/websocket" "net" "sync/atomic" "time" ) var ( + _ nface.IConnection = (*Connection)(nil) + ErrCloseClosedSession = errors.New("close closed session") + // ErrBrokenPipe represents the low-level connection has broken. + ErrBrokenPipe = errors.New("broken low-level pipe") + // ErrBufferExceed indicates that the current session buffer is full and + // can not receive more data. + ErrBufferExceed = errors.New("session send buffer exceed") ) type ( @@ -27,21 +33,21 @@ type ( lastMid uint64 // 最近一次消息ID lastHeartbeatAt int64 // 最近一次心跳时间 - chDie chan struct{} // 停止通道 - chSend chan []byte // 消息发送通道 + chDie chan struct{} // 停止通道 + chSend chan pendingMessage // 消息发送通道 pipeline pipeline.Pipeline // 消息管道 } pendingMessage struct { - typ byte // message type + typ interface{} // message type route string // message route mid uint64 // response message id payload interface{} // payload } ) -func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) nface.IConnection { +func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) *Connection { r := &Connection{ conn: conn, server: server, @@ -50,7 +56,7 @@ func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) nf lastHeartbeatAt: time.Now().Unix(), chDie: make(chan struct{}), - chSend: make(chan pendingMessage, 2048), + chSend: make(chan pendingMessage, 512), pipeline: pipeline, } @@ -60,13 +66,18 @@ func newConnection(server *Server, conn net.Conn, pipeline pipeline.Pipeline) nf return r } -func newConnectionWS(server *Server, conn *websocket.Conn, pipeline pipeline.Pipeline) nface.IConnection { - c, err := newWSConn(conn) - if err != nil { - // TODO panic ? - panic(err) - } - return newConnection(server, c, pipeline) +func (r *Connection) send(m pendingMessage) (err error) { + defer func() { + if e := recover(); e != nil { + err = ErrBrokenPipe + } + }() + r.chSend <- m + return err +} + +func (r *Connection) Server() nface.IServer { + return r.server } func (r *Connection) Status() int32 { @@ -85,6 +96,10 @@ func (r *Connection) ID() int64 { return r.session.ID() } +func (r *Connection) setLastHeartbeatAt(t int64) { + atomic.StoreInt64(&r.lastHeartbeatAt, t) +} + func (r *Connection) Session() nface.ISession { return r.session } @@ -100,12 +115,13 @@ func (r *Connection) write() { close(chWrite) _ = r.Close() - log.Debugf("Session write goroutine exit, SessionID=%d, UID=%d", r.session.ID(), r.session.UID()) + log.Debugf("Connection write goroutine exit, ConnID=%d, SessionUID=%d", r.ID(), r.session.UID()) }() for { select { case <-ticker.C: + // TODO heartbeat enable control deadline := time.Now().Add(-2 * r.server.HeartbeatInterval).Unix() if atomic.LoadInt64(&r.lastHeartbeatAt) < deadline { log.Debugf("Session heartbeat timeout, LastTime=%d, Deadline=%d", atomic.LoadInt64(&r.lastHeartbeatAt), deadline) @@ -114,16 +130,22 @@ func (r *Connection) write() { // TODO heartbeat data chWrite <- []byte{} case data := <-r.chSend: - // message marshal data + // marshal packet body (data) payload, err := r.server.Serializer.Marshal(data.payload) if err != nil { - switch data.typ { - - } + log.Errorf("message body marshal err: %v", err) break } // TODO new message and pipeline + if pipe := r.pipeline; pipe != nil { + err := pipe.Outbound().Process(r) + if err != nil { + log.Errorf("broken pipeline err: %s", err.Error()) + break + } + } + // TODO encode message ? message processor ? // packet pack data @@ -141,16 +163,17 @@ func (r *Connection) write() { } case <-r.chDie: // connection close signal return - // TODO application quit signal + case <-r.server.DieChan: // application quit signal + return } } } func (r *Connection) Close() error { - if r.Status() == StatusClosed { + if r.Status() == nface.StatusClosed { return ErrCloseClosedSession } - r.SetStatus(StatusClosed) + r.SetStatus(nface.StatusClosed) log.Debugf("close connection, ID: %d", r.ID()) diff --git a/nnet/handler.go b/nnet/handler.go index 11aaed5..2600a44 100644 --- a/nnet/handler.go +++ b/nnet/handler.go @@ -6,6 +6,7 @@ import ( "git.noahlan.cn/northlan/nnet/log" "git.noahlan.cn/northlan/nnet/packet" "git.noahlan.cn/northlan/nnet/pipeline" + "github.com/gorilla/websocket" "net" "time" ) @@ -52,6 +53,15 @@ func (h *Handler) register(comp component.Component, opts []component.Option) er return nil } +func (h *Handler) handleWS(conn *websocket.Conn) { + c, err := newWSConn(conn) + if err != nil { + log.Error(err) + return + } + h.handle(c) +} + func (h *Handler) handle(conn net.Conn) { connection := newConnection(h.server, conn, h.pipeline) h.server.sessionMgr.StoreSession(connection.Session()) @@ -68,18 +78,20 @@ func (h *Handler) handle(conn net.Conn) { } func (h *Handler) writeLoop(conn *Connection) { - + conn.write() } func (h *Handler) readLoop(conn *Connection) { buf := make([]byte, 4096) for { - n, err := conn.conn.Read(buf) + n, err := conn.Conn().Read(buf) if err != nil { log.Errorf("Read message error: %s, session will be closed immediately", err.Error()) return } - packets, err := h.server.Packer.Unpack(buf) + + // warning: 为性能考虑,复用slice处理数据,buf传入后必须要copy再处理 + packets, err := h.server.Packer.Unpack(buf[:n]) if err != nil { log.Error(err.Error()) } @@ -96,6 +108,6 @@ func (h *Handler) readLoop(conn *Connection) { func (h *Handler) processPackets(conn *Connection, packets interface{}) error { err := h.processor.ProcessPacket(conn, packets) - conn.lastHeartbeatAt = time.Now().Unix() + conn.setLastHeartbeatAt(time.Now().Unix()) return err } diff --git a/nnet/interface.go b/nnet/interface.go deleted file mode 100644 index 93328a5..0000000 --- a/nnet/interface.go +++ /dev/null @@ -1 +0,0 @@ -package nnet diff --git a/nnet/options.go b/nnet/options.go index 4829a33..7233e7c 100644 --- a/nnet/options.go +++ b/nnet/options.go @@ -2,8 +2,10 @@ package nnet import ( "git.noahlan.cn/northlan/nnet/component" + "git.noahlan.cn/northlan/nnet/env" "git.noahlan.cn/northlan/nnet/log" "git.noahlan.cn/northlan/nnet/pipeline" + "net/http" "time" ) @@ -34,6 +36,19 @@ func WithHeartbeatInterval(d time.Duration) Option { } } +// WithTimerPrecision 设置Timer精度 +// 注:精度需大于1ms, 并且不能在运行时更改 +// 默认精度是 time.Second +func WithTimerPrecision(precision time.Duration) Option { + if precision < time.Millisecond { + panic("time precision can not less than a Millisecond") + } + return func(_ *Options) { + env.TimerPrecision = precision + } +} + +// WithWebsocket 开启Websocket, 参数是websocket的相关参数 nnet.WSOption func WithWebsocket(wsOpts ...WSOption) Option { return func(options *Options) { for _, opt := range wsOpts { @@ -43,15 +58,25 @@ func WithWebsocket(wsOpts ...WSOption) Option { } } +// WithWSPath 设置websocket的path func WithWSPath(path string) WSOption { return func(opts *WSOptions) { opts.WebsocketPath = path } } -func WithWSTLS(certificate, key string) WSOption { +// WithWSTLSConfig 设置websocket的证书和密钥 +func WithWSTLSConfig(certificate, key string) WSOption { return func(opts *WSOptions) { opts.TLSCertificate = certificate opts.TLSKey = key } } + +func WithWSCheckOriginFunc(fn func(*http.Request) bool) WSOption { + return func(opts *WSOptions) { + if fn != nil { + opts.CheckOrigin = fn + } + } +} diff --git a/nnet/server.go b/nnet/server.go index 5077e21..3685b76 100644 --- a/nnet/server.go +++ b/nnet/server.go @@ -5,9 +5,10 @@ import ( "fmt" "git.noahlan.cn/northlan/nnet/component" "git.noahlan.cn/northlan/nnet/log" - "git.noahlan.cn/northlan/nnet/message" "git.noahlan.cn/northlan/nnet/packet" "git.noahlan.cn/northlan/nnet/pipeline" + "git.noahlan.cn/northlan/nnet/scheduler" + "git.noahlan.cn/northlan/nnet/serialize" "git.noahlan.cn/northlan/nnet/session" "github.com/gorilla/websocket" "net" @@ -26,17 +27,18 @@ type ( Components *component.Components // 组件库 Packer packet.Packer // 封包、拆包器 PacketProcessor packet.Processor // 数据包处理器 - Serializer message.Serializer // 消息 序列化/反序列化 + Serializer serialize.Serializer // 消息 序列化/反序列化 HeartbeatInterval time.Duration // 心跳间隔,0表示不进行心跳 WS WSOptions // websocket } WSOptions struct { - IsWebsocket bool // 是否为websocket服务端 - WebsocketPath string // ws地址(ws://127.0.0.1/WebsocketPath) - TLSCertificate string // TLS 证书地址 (websocket) - TLSKey string // TLS 证书key地址 (websocket) + IsWebsocket bool // 是否为websocket服务端 + WebsocketPath string // ws地址(ws://127.0.0.1/WebsocketPath) + TLSCertificate string // TLS 证书地址 (websocket) + TLSKey string // TLS 证书key地址 (websocket) + CheckOrigin func(*http.Request) bool // check origin } ) @@ -52,6 +54,8 @@ type Server struct { // 地址可直接使用hostname,但强烈不建议这样做,可能会同时监听多个本地IP // 如果端口号不填或端口号为0,例如:"127.0.0.1:" 或 ":0",服务端将选择随机可用端口 address string + // DieChan 应用程序退出信号 + DieChan chan struct{} // handler 消息处理器 handler *Handler // sessionMgr session管理器 @@ -61,13 +65,17 @@ type Server struct { func NewServer(protocol, addr string, opts ...Option) *Server { options := Options{ Components: &component.Components{}, - WS: WSOptions{}, - Packer: packet.NewDefaultPacker(), + WS: WSOptions{ + CheckOrigin: func(_ *http.Request) bool { return true }, + }, + Packer: packet.NewNNetPacker(), + PacketProcessor: packet.NewNNetProcessor(), } s := &Server{ Options: options, protocol: protocol, address: addr, + DieChan: make(chan struct{}), } for _, opt := range opts { @@ -109,22 +117,29 @@ func (s *Server) Serve() { } }() + go scheduler.Schedule() + sg := make(chan os.Signal) signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM) select { - //case <-env.Die: - // log.Println("The app will shutdown in a few seconds") + case <-s.DieChan: + log.Info("The server will shutdown in a few seconds") case s := <-sg: - log.Infof("server got signal", s) + log.Infof("server got signal: %s", s) } - log.Infof("server is stopping...") - s.Shutdown() - // TODO close + log.Info("server is stopping...") + + s.shutdown() + scheduler.Close() } -func (s *Server) Shutdown() { +func (s *Server) Close() { + close(s.DieChan) +} + +func (s *Server) shutdown() { components := s.Components.List() compLen := len(components) for i := compLen - 1; i >= 0; i-- { @@ -139,68 +154,77 @@ func (s *Server) listenAndServe() { } // 监听成功,服务已启动 - // TODO log - defer listener.Close() + log.Info("listening...") + defer func() { + listener.Close() + s.Close() + }() - go func() { - for { - conn, err := listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - log.Error("服务器网络错误", err) - return - } - log.Errorf("监听错误 %v", err) - continue + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + log.Errorf("服务器网络错误 %+v", err) + return } + log.Errorf("监听错误 %v", err) + continue + } - err = pool.SubmitConn(func() { - s.handler.handle(conn) - }) - if err != nil { - log.Errorf("submit conn pool err: %s", err.Error()) - continue - } + err = pool.SubmitConn(func() { + s.handler.handle(conn) + }) + if err != nil { + log.Errorf("submit conn pool err: %s", err.Error()) + continue } - }() + } } func (s *Server) listenAndServeWS() { upgrade := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: nil, - EnableCompression: false, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: s.WS.CheckOrigin, } - http.HandleFunc(fmt.Sprintf("/%s/", "path"), func(w http.ResponseWriter, r *http.Request) { + http.HandleFunc(fmt.Sprintf("/%s/", s.WS.WebsocketPath), func(w http.ResponseWriter, r *http.Request) { conn, err := upgrade.Upgrade(w, r, nil) if err != nil { - // TODO upgrade failure, uri=r.requestURI err=err.Error() + log.Errorf("Upgrade failure, URI=%s, Error=%s", r.RequestURI, err.Error()) return } - // TODO s.handler.handleWS(conn) + err = pool.SubmitConn(func() { + s.handler.handleWS(conn) + }) + if err != nil { + log.Fatalf("submit conn pool err: %s", err.Error()) + } }) if err := http.ListenAndServe(s.address, nil); err != nil { - panic(err) + log.Fatal(err.Error()) } } func (s *Server) listenAndServeWSTLS() { upgrade := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: nil, - EnableCompression: false, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: s.WS.CheckOrigin, } - http.HandleFunc(fmt.Sprintf("/%s/", "path"), func(w http.ResponseWriter, r *http.Request) { + http.HandleFunc(fmt.Sprintf("/%s/", s.WS.WebsocketPath), func(w http.ResponseWriter, r *http.Request) { conn, err := upgrade.Upgrade(w, r, nil) if err != nil { - // TODO upgrade failure, uri=r.requestURI err=err.Error() + log.Errorf("Upgrade failure, URI=%s, Error=%s", r.RequestURI, err.Error()) return } - // TODO s.handler.handleWS(conn) + err = pool.SubmitConn(func() { + s.handler.handleWS(conn) + }) + if err != nil { + log.Fatalf("submit conn pool err: %s", err.Error()) + } }) - if err := http.ListenAndServeTLS(s.address, "", "", nil); err != nil { - panic(err) + if err := http.ListenAndServeTLS(s.address, s.WS.TLSCertificate, s.WS.TLSKey, nil); err != nil { + log.Fatal(err.Error()) } } diff --git a/nnet/server_test.go b/nnet/server_test.go new file mode 100644 index 0000000..e7fc2ea --- /dev/null +++ b/nnet/server_test.go @@ -0,0 +1,9 @@ +package nnet + +import "testing" + +func TestServer(t *testing.T) { + server := NewServer("tcp4", ":22112") + + server.Serve() +} diff --git a/packet/interface.go b/packet/interface.go index fdc2dc6..95b5b98 100644 --- a/packet/interface.go +++ b/packet/interface.go @@ -1,13 +1,15 @@ package packet -import "git.noahlan.cn/northlan/nnet/nface" +import ( + "git.noahlan.cn/northlan/nnet/nface" +) -// Type 数据帧类型,如:握手,心跳,数据等 +// Type 数据帧类型,如:握手,心跳,数据 等 type Type byte type ( Packer interface { - // Pack 从原始raw bytes创建一个用于网络传输的 packet.Packet 结构 + // Pack 从原始raw bytes创建一个用于网络传输的 数据帧结构 Pack(typ Type, data []byte) ([]byte, error) // Unpack 解包 diff --git a/packet/packer.go b/packet/packer_nnet.go similarity index 82% rename from packet/packer.go rename to packet/packer_nnet.go index 6f2f683..91470c3 100644 --- a/packet/packer.go +++ b/packet/packer_nnet.go @@ -5,9 +5,9 @@ import ( "errors" ) -var _ Packer = (*DefaultPacker)(nil) +var _ Packer = (*NNetPacker)(nil) -type DefaultPacker struct { +type NNetPacker struct { buf *bytes.Buffer size int // 最近一次 长度 typ byte // 最近一次 数据帧类型 @@ -21,14 +21,14 @@ const ( var ErrPacketSizeExceed = errors.New("codec: packet size exceed") -func NewDefaultPacker() Packer { - return &DefaultPacker{ +func NewNNetPacker() Packer { + return &NNetPacker{ buf: bytes.NewBuffer(nil), size: -1, } } -func (d *DefaultPacker) Pack(typ Type, data []byte) ([]byte, error) { +func (d *NNetPacker) Pack(typ Type, data []byte) ([]byte, error) { if typ < Handshake || typ > Kick { return nil, ErrWrongPacketType } @@ -47,7 +47,7 @@ func (d *DefaultPacker) Pack(typ Type, data []byte) ([]byte, error) { } // Encode packet data length to bytes(Big end) -func (d *DefaultPacker) intToBytes(n uint32) []byte { +func (d *NNetPacker) intToBytes(n uint32) []byte { buf := make([]byte, 3) buf[0] = byte((n >> 16) & 0xFF) buf[1] = byte((n >> 8) & 0xFF) @@ -55,7 +55,7 @@ func (d *DefaultPacker) intToBytes(n uint32) []byte { return buf } -func (d *DefaultPacker) Unpack(data []byte) ([]interface{}, error) { +func (d *NNetPacker) Unpack(data []byte) ([]interface{}, error) { d.buf.Write(data) // copy var ( @@ -98,7 +98,7 @@ func (d *DefaultPacker) Unpack(data []byte) ([]interface{}, error) { return packets, nil } -func (d *DefaultPacker) readHeader() error { +func (d *NNetPacker) readHeader() error { header := d.buf.Next(headLength) d.typ = header[0] if d.typ < Handshake || d.typ > Kick { @@ -114,7 +114,7 @@ func (d *DefaultPacker) readHeader() error { } // Decode packet data length byte to int(Big end) -func (d *DefaultPacker) bytesToInt(b []byte) int { +func (d *NNetPacker) bytesToInt(b []byte) int { result := 0 for _, v := range b { result = result<<8 + int(v) diff --git a/packet/packet.go b/packet/packet_nnet.go similarity index 100% rename from packet/packet.go rename to packet/packet_nnet.go diff --git a/packet/processor.go b/packet/processor_nnet.go similarity index 80% rename from packet/processor.go rename to packet/processor_nnet.go index 3a9b594..b81cb11 100644 --- a/packet/processor.go +++ b/packet/processor_nnet.go @@ -6,13 +6,13 @@ import ( "git.noahlan.cn/northlan/nnet/nface" ) -type DefaultProcessor struct{} +type NNetProcessor struct{} -func NewDefaultProcessor() *DefaultProcessor { - return &DefaultProcessor{} +func NewNNetProcessor() Processor { + return &NNetProcessor{} } -func (d *DefaultProcessor) ProcessPacket(conn nface.IConnection, packet interface{}) error { +func (d *NNetProcessor) ProcessPacket(conn nface.IConnection, packet interface{}) error { p := packet.(*Packet) switch p.Type { case Handshake: @@ -25,12 +25,12 @@ func (d *DefaultProcessor) ProcessPacket(conn nface.IConnection, packet interfac case HandshakeAck: conn.SetStatus(nface.StatusWorking) log.Debugf("Receive handshake ACK Id=%d, Remote=%s", conn.ID(), conn.Conn().RemoteAddr()) - case Data: if conn.Status() < nface.StatusWorking { return fmt.Errorf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s", conn.Conn().RemoteAddr()) } + // TODO message data 处理 case Heartbeat: // expected diff --git a/pipeline/pipeline.go b/pipeline/pipeline.go index 72b880d..ae30d04 100644 --- a/pipeline/pipeline.go +++ b/pipeline/pipeline.go @@ -1,11 +1,12 @@ package pipeline import ( + "git.noahlan.cn/northlan/nnet/nface" "sync" ) type ( - Func func(request *nnet.Request) error + Func func(nface.IConnection) error // Pipeline 消息管道 Pipeline interface { @@ -20,7 +21,7 @@ type ( Channel interface { PushFront(h Func) PushBack(h Func) - Process(request *nnet.Request) error + Process(nface.IConnection) error } pipelineChannel struct { @@ -65,7 +66,7 @@ func (p *pipelineChannel) PushBack(h Func) { } // Process 处理所有的pipeline方法 -func (p *pipelineChannel) Process(request *nnet.Request) error { +func (p *pipelineChannel) Process(conn nface.IConnection) error { p.mu.RLock() defer p.mu.RUnlock() @@ -74,7 +75,7 @@ func (p *pipelineChannel) Process(request *nnet.Request) error { } for _, handler := range p.handlers { - err := handler(request) + err := handler(conn) if err != nil { return err } diff --git a/scheduler/scheduler.go b/scheduler/scheduler.go new file mode 100644 index 0000000..3d8ce6e --- /dev/null +++ b/scheduler/scheduler.go @@ -0,0 +1,78 @@ +package scheduler + +import ( + "git.noahlan.cn/northlan/nnet/env" + "git.noahlan.cn/northlan/nnet/log" + "runtime/debug" + "sync/atomic" + "time" +) + +const ( + messageQueueBacklog = 1 << 10 // 1024 + sessionCloseBacklog = 1 << 8 // 256 +) + +// LocalScheduler schedules task to a customized goroutine +type LocalScheduler interface { + Schedule(Task) +} + +type Task func() + +type Hook func() + +var ( + chDie = make(chan struct{}) + chExit = make(chan struct{}) + chTasks = make(chan Task, 1<<8) + started int32 + closed int32 +) + +func try(f func()) { + defer func() { + if err := recover(); err != nil { + log.Infof("Handle message panic: %+v\n%s", err, debug.Stack()) + } + }() + f() +} + +func Schedule() { + if atomic.AddInt32(&started, 1) != 1 { + return + } + + ticker := time.NewTicker(env.TimerPrecision) + defer func() { + ticker.Stop() + close(chExit) + }() + + for { + select { + case <-ticker.C: + cron() + + case f := <-chTasks: + try(f) + + case <-chDie: + return + } + } +} + +func Close() { + if atomic.AddInt32(&closed, 1) != 1 { + return + } + close(chDie) + <-chExit + log.Info("Scheduler stopped") +} + +func PushTask(task Task) { + chTasks <- task +} diff --git a/scheduler/timer.go b/scheduler/timer.go new file mode 100644 index 0000000..e9815aa --- /dev/null +++ b/scheduler/timer.go @@ -0,0 +1,194 @@ +package scheduler + +import ( + "git.noahlan.cn/northlan/nnet/log" + "math" + "runtime/debug" + "sync" + "sync/atomic" + "time" +) + +const infinite = -1 + +var ( + timerManager = &struct { + incrementID int64 // auto increment id + timers map[int64]*Timer // all timers + + sync.Once + muClosingTimer sync.RWMutex // 关闭锁,避免重复关闭 + closingTimer []int64 // 已关闭的timer id + muCreatedTimer sync.RWMutex // 创建锁,避免重复创建 + createdTimer []*Timer // 已创建的Timer + }{} +) + +type ( + // TimerFunc represents a function which will be called periodically in main + // logic goroutine. + TimerFunc func() + + // TimerCondition represents a checker that returns true when cron job needs + // to execute + TimerCondition interface { + Check(now time.Time) bool + } + + // Timer represents a cron job + Timer struct { + id int64 // timer id + fn TimerFunc // function that execute + createAt int64 // timer create time + interval time.Duration // execution interval + condition TimerCondition // condition to cron job execution + elapse int64 // total elapse time + closed int32 // is timer closed + counter int // counter + } +) + +func init() { + timerManager.timers = map[int64]*Timer{} +} + +// ID returns id of current timer +func (t *Timer) ID() int64 { + return t.id +} + +// Stop turns off a timer. After Stop, fn will not be called forever +func (t *Timer) Stop() { + if atomic.AddInt32(&t.closed, 1) != 1 { + return + } + + t.counter = 0 +} + +// safeCall 安全调用,收集所有 fn 触发的 panic,给与提示即可 +func safeCall(_ int64, fn TimerFunc) { + defer func() { + if err := recover(); err != nil { + log.Infof("Handle timer panic: %+v\n%s", err, debug.Stack()) + } + }() + + fn() +} + +func cron() { + if len(timerManager.createdTimer) > 0 { + timerManager.muCreatedTimer.Lock() + for _, t := range timerManager.createdTimer { + timerManager.timers[t.id] = t + } + timerManager.createdTimer = timerManager.createdTimer[:0] + timerManager.muCreatedTimer.Unlock() + } + + if len(timerManager.timers) < 1 { + return + } + + now := time.Now() + unn := now.UnixNano() + for id, t := range timerManager.timers { + if t.counter == infinite || t.counter > 0 { + // condition timer + if t.condition != nil { + if t.condition.Check(now) { + safeCall(id, t.fn) + } + continue + } + + // execute job + if t.createAt+t.elapse <= unn { + safeCall(id, t.fn) + t.elapse += int64(t.interval) + + // update timer counter + if t.counter != infinite && t.counter > 0 { + t.counter-- + } + } + } + + if t.counter == 0 { + timerManager.muClosingTimer.Lock() + timerManager.closingTimer = append(timerManager.closingTimer, t.id) + timerManager.muClosingTimer.Unlock() + continue + } + } + + if len(timerManager.closingTimer) > 0 { + timerManager.muClosingTimer.Lock() + for _, id := range timerManager.closingTimer { + delete(timerManager.timers, id) + } + timerManager.closingTimer = timerManager.closingTimer[:0] + timerManager.muClosingTimer.Unlock() + } +} + +// NewTimer returns a new Timer containing a function that will be called +// with a period specified by the duration argument. It adjusts the intervals +// for slow receivers. +// The duration d must be greater than zero; if not, NewTimer will panic. +// Stop the timer to release associated resources. +func NewTimer(interval time.Duration, fn TimerFunc) *Timer { + return NewCountTimer(interval, infinite, fn) +} + +// NewCountTimer returns a new Timer containing a function that will be called +// with a period specified by the duration argument. After count times, timer +// will be stopped automatically, It adjusts the intervals for slow receivers. +// The duration d must be greater than zero; if not, NewCountTimer will panic. +// Stop the timer to release associated resources. +func NewCountTimer(interval time.Duration, count int, fn TimerFunc) *Timer { + if fn == nil { + panic("ngs/timer: nil timer function") + } + if interval <= 0 { + panic("non-positive interval for NewTimer") + } + + t := &Timer{ + id: atomic.AddInt64(&timerManager.incrementID, 1), + fn: fn, + createAt: time.Now().UnixNano(), + interval: interval, + elapse: int64(interval), // first execution will be after interval + counter: count, + } + + timerManager.muCreatedTimer.Lock() + timerManager.createdTimer = append(timerManager.createdTimer, t) + timerManager.muCreatedTimer.Unlock() + return t +} + +// NewAfterTimer returns a new Timer containing a function that will be called +// after duration that specified by the duration argument. +// The duration d must be greater than zero; if not, NewAfterTimer will panic. +// Stop the timer to release associated resources. +func NewAfterTimer(duration time.Duration, fn TimerFunc) *Timer { + return NewCountTimer(duration, 1, fn) +} + +// NewCondTimer returns a new Timer containing a function that will be called +// when condition satisfied that specified by the condition argument. +// The duration d must be greater than zero; if not, NewCondTimer will panic. +// Stop the timer to release associated resources. +func NewCondTimer(condition TimerCondition, fn TimerFunc) *Timer { + if condition == nil { + panic("ngs/timer: nil condition") + } + + t := NewCountTimer(time.Duration(math.MaxInt64), infinite, fn) + t.condition = condition + + return t +} diff --git a/scheduler/timer_test.go b/scheduler/timer_test.go new file mode 100644 index 0000000..5d8786a --- /dev/null +++ b/scheduler/timer_test.go @@ -0,0 +1,84 @@ +package scheduler + +import ( + "sync/atomic" + "testing" + "time" +) + +func TestNewTimer(t *testing.T) { + var exists = struct { + timers int + createdTimes int + closingTimers int + }{ + timers: len(timerManager.timers), + createdTimes: len(timerManager.createdTimer), + closingTimers: len(timerManager.closingTimer), + } + + const tc = 1000 + var counter int64 + for i := 0; i < tc; i++ { + NewTimer(1*time.Millisecond, func() { + atomic.AddInt64(&counter, 1) + }) + } + + <-time.After(5 * time.Millisecond) + cron() + cron() + if counter != tc*2 { + t.Fatalf("expect: %d, got: %d", tc*2, counter) + } + + if len(timerManager.timers) != exists.timers+tc { + t.Fatalf("timers: %d", len(timerManager.timers)) + } + + if len(timerManager.createdTimer) != exists.createdTimes { + t.Fatalf("createdTimer: %d", len(timerManager.createdTimer)) + } + + if len(timerManager.closingTimer) != exists.closingTimers { + t.Fatalf("closingTimer: %d", len(timerManager.closingTimer)) + } +} + +func TestNewAfterTimer(t *testing.T) { + var exists = struct { + timers int + createdTimes int + closingTimers int + }{ + timers: len(timerManager.timers), + createdTimes: len(timerManager.createdTimer), + closingTimers: len(timerManager.closingTimer), + } + + const tc = 1000 + var counter int64 + for i := 0; i < tc; i++ { + NewAfterTimer(1*time.Millisecond, func() { + atomic.AddInt64(&counter, 1) + }) + } + + <-time.After(5 * time.Millisecond) + cron() + if counter != tc { + t.Fatalf("expect: %d, got: %d", tc, counter) + } + + if len(timerManager.timers) != exists.timers { + t.Fatalf("timers: %d", len(timerManager.timers)) + } + + if len(timerManager.createdTimer) != exists.createdTimes { + t.Fatalf("createdTimer: %d", len(timerManager.createdTimer)) + } + + if len(timerManager.closingTimer) != exists.closingTimers { + t.Fatalf("closingTimer: %d", len(timerManager.closingTimer)) + } +} diff --git a/serialize/json/json.go b/serialize/json/json.go new file mode 100644 index 0000000..b75bd17 --- /dev/null +++ b/serialize/json/json.go @@ -0,0 +1,20 @@ +package json + +import ( + "encoding/json" + "git.noahlan.cn/northlan/nnet/serialize" +) + +type Serializer struct{} + +func NewSerializer() serialize.Serializer { + return &Serializer{} +} + +func (s *Serializer) Marshal(i interface{}) ([]byte, error) { + return json.Marshal(i) +} + +func (s *Serializer) Unmarshal(bytes []byte, i interface{}) error { + return json.Unmarshal(bytes, i) +} diff --git a/serialize/json/json_test.go b/serialize/json/json_test.go new file mode 100644 index 0000000..e4e097f --- /dev/null +++ b/serialize/json/json_test.go @@ -0,0 +1,62 @@ +package json + +import ( + "reflect" + "testing" +) + +type Message struct { + Code int `json:"code"` + Data string `json:"data"` +} + +func TestSerializer_Serialize(t *testing.T) { + m := Message{1, "hello world"} + s := NewSerializer() + b, err := s.Marshal(m) + if err != nil { + t.Fail() + } + + m2 := Message{} + if err := s.Unmarshal(b, &m2); err != nil { + t.Fail() + } + + if !reflect.DeepEqual(m, m2) { + t.Fail() + } +} + +func BenchmarkSerializer_Serialize(b *testing.B) { + m := &Message{100, "hell world"} + s := NewSerializer() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := s.Marshal(m); err != nil { + b.Fatalf("unmarshal failed: %v", err) + } + } + + b.ReportAllocs() +} + +func BenchmarkSerializer_Deserialize(b *testing.B) { + m := &Message{100, "hell world"} + s := NewSerializer() + + d, err := s.Marshal(m) + if err != nil { + b.Error(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m1 := &Message{} + if err := s.Unmarshal(d, m1); err != nil { + b.Fatalf("unmarshal failed: %v", err) + } + } + b.ReportAllocs() +} diff --git a/serialize/protobuf/protobuf.go b/serialize/protobuf/protobuf.go new file mode 100644 index 0000000..88bff2d --- /dev/null +++ b/serialize/protobuf/protobuf.go @@ -0,0 +1,32 @@ +package protobuf + +import ( + "errors" + "git.noahlan.cn/northlan/nnet/serialize" + "google.golang.org/protobuf/proto" +) + +// ErrWrongValueType is the error used for marshal the value with protobuf encoding. +var ErrWrongValueType = errors.New("protobuf: convert on wrong type value") + +type Serializer struct{} + +func NewSerializer() serialize.Serializer { + return &Serializer{} +} + +func (s *Serializer) Marshal(v interface{}) ([]byte, error) { + pb, ok := v.(proto.Message) + if !ok { + return nil, ErrWrongValueType + } + return proto.Marshal(pb) +} + +func (s *Serializer) Unmarshal(data []byte, v interface{}) error { + pb, ok := v.(proto.Message) + if !ok { + return ErrWrongValueType + } + return proto.Unmarshal(data, pb) +} diff --git a/serialize/protobuf/protobuf_test.go b/serialize/protobuf/protobuf_test.go new file mode 100644 index 0000000..16a8921 --- /dev/null +++ b/serialize/protobuf/protobuf_test.go @@ -0,0 +1,56 @@ +package protobuf + +import ( + "git.noahlan.cn/northlan/nnet/serialize/protobuf/testdata" + "google.golang.org/protobuf/proto" + "testing" +) + +func TestProtobufSerializer_Serialize(t *testing.T) { + m := &testdata.Ping{Content: "hello"} + s := NewSerializer() + + b, err := s.Marshal(m) + if err != nil { + t.Error(err) + } + + m1 := &testdata.Ping{} + if err := s.Unmarshal(b, m1); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + // refer: https://developers.google.com/protocol-buffers/docs/reference/go/faq#deepequal + if !proto.Equal(m, m1) { + t.Fail() + } +} + +func BenchmarkSerializer_Serialize(b *testing.B) { + m := &testdata.Ping{Content: "hello"} + s := NewSerializer() + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if _, err := s.Marshal(m); err != nil { + b.Fatalf("unmarshal failed: %v", err) + } + } +} + +func BenchmarkSerializer_Deserialize(b *testing.B) { + m := &testdata.Ping{Content: "hello"} + s := NewSerializer() + + d, err := s.Marshal(m) + if err != nil { + b.Error(err) + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + m1 := &testdata.Ping{} + if err := s.Unmarshal(d, m1); err != nil { + b.Fatalf("unmarshal failed: %v", err) + } + } +} diff --git a/serialize/protobuf/testdata/gen_proto.bat b/serialize/protobuf/testdata/gen_proto.bat new file mode 100644 index 0000000..c026490 --- /dev/null +++ b/serialize/protobuf/testdata/gen_proto.bat @@ -0,0 +1 @@ +protoc --go_opt=paths=source_relative --go_out=. --proto_path=. *.proto \ No newline at end of file diff --git a/serialize/protobuf/testdata/test.proto b/serialize/protobuf/testdata/test.proto new file mode 100644 index 0000000..df3f510 --- /dev/null +++ b/serialize/protobuf/testdata/test.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package testdata; + +option go_package = "/testdata"; + +message Ping { + string Content = 1; +} + +message Pong { + string Content = 1; +} \ No newline at end of file diff --git a/message/serializer.go b/serialize/serializer.go similarity index 69% rename from message/serializer.go rename to serialize/serializer.go index 7104335..a99ecab 100644 --- a/message/serializer.go +++ b/serialize/serializer.go @@ -1,14 +1,14 @@ -package message +package serialize type ( // Marshaler 序列化 Marshaler interface { - Marshal(interface{}) ([]byte, error) + Marshal(v interface{}) ([]byte, error) } // Unmarshaler 反序列化 Unmarshaler interface { - Unmarshal([]byte, interface{}) error + Unmarshal(data []byte, v interface{}) error } // Serializer 消息 序列化/反序列化,仅针对payload