wip: 又双叒加了一些新东西。
parent
115166cb11
commit
a2ed3090e7
@ -0,0 +1,8 @@
|
|||||||
|
package env
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// TimerPrecision indicates the precision of timer, default is time.Second
|
||||||
|
TimerPrecision = time.Second
|
||||||
|
)
|
@ -1,8 +1,14 @@
|
|||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||||
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/panjf2000/ants/v2 v2.6.0 h1:xOSpw42m+BMiJ2I33we7h6fYzG4DAlpE1xyI7VS2gxU=
|
github.com/panjf2000/ants/v2 v2.6.0 h1:xOSpw42m+BMiJ2I33we7h6fYzG4DAlpE1xyI7VS2gxU=
|
||||||
github.com/panjf2000/ants/v2 v2.6.0/go.mod h1:cU93usDlihJZ5CfRGNDYsiBYvoilLvBF5Qp/BT2GNRE=
|
github.com/panjf2000/ants/v2 v2.6.0/go.mod h1:cU93usDlihJZ5CfRGNDYsiBYvoilLvBF5Qp/BT2GNRE=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
|
||||||
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||||
|
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
|
||||||
|
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
|
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
|
||||||
|
@ -1,18 +0,0 @@
|
|||||||
package message
|
|
||||||
|
|
||||||
type BinarySerializer struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBinarySerializer() Serializer {
|
|
||||||
return &BinarySerializer{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BinarySerializer) Marshal(i interface{}) ([]byte, error) {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BinarySerializer) Unmarshal(bytes []byte, i interface{}) error {
|
|
||||||
//TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
@ -0,0 +1,146 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ Codec = (*NNetCodec)(nil)
|
||||||
|
|
||||||
|
const (
|
||||||
|
msgRouteCompressMask = 0x01 // 0000 0001 last bit
|
||||||
|
msgTypeMask = 0x07 // 0000 0111 1-3 bit (需要>>)
|
||||||
|
msgRouteLengthMask = 0xFF // 1111 1111 last 8 bit
|
||||||
|
msgHeadLength = 0x02 // 0000 0010 2 bit
|
||||||
|
)
|
||||||
|
|
||||||
|
// Errors that could be occurred in message codec
|
||||||
|
var (
|
||||||
|
ErrWrongMessageType = errors.New("wrong message type")
|
||||||
|
ErrInvalidMessage = errors.New("invalid message")
|
||||||
|
ErrRouteInfoNotFound = errors.New("route info not found in dictionary")
|
||||||
|
ErrWrongMessage = errors.New("wrong message")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
routes = make(map[string]uint16) // route map to code
|
||||||
|
codes = make(map[uint16]string) // code map to route
|
||||||
|
)
|
||||||
|
|
||||||
|
type NNetCodec struct{}
|
||||||
|
|
||||||
|
func (n *NNetCodec) routable(t Type) bool {
|
||||||
|
return t == Request || t == Notify || t == Push
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NNetCodec) invalidType(t Type) bool {
|
||||||
|
return t < Request || t > Push
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NNetCodec) Encode(v interface{}) ([]byte, error) {
|
||||||
|
m, ok := v.(*Message)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrWrongMessageType
|
||||||
|
}
|
||||||
|
if n.invalidType(m.Type) {
|
||||||
|
return nil, ErrWrongMessageType
|
||||||
|
}
|
||||||
|
buf := make([]byte, 0)
|
||||||
|
flag := byte(m.Type << 1) // 编译器提示,此处 byte 转换不能删
|
||||||
|
|
||||||
|
code, compressed := routes[m.Route]
|
||||||
|
if compressed {
|
||||||
|
flag |= msgRouteCompressMask
|
||||||
|
}
|
||||||
|
buf = append(buf, flag)
|
||||||
|
|
||||||
|
if m.Type == Request || m.Type == Response {
|
||||||
|
n := m.ID
|
||||||
|
// variant length encode
|
||||||
|
for {
|
||||||
|
b := byte(n % 128)
|
||||||
|
n >>= 7
|
||||||
|
if n != 0 {
|
||||||
|
buf = append(buf, b+128)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, b)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.routable(m.Type) {
|
||||||
|
if compressed {
|
||||||
|
buf = append(buf, byte((code>>8)&0xFF))
|
||||||
|
buf = append(buf, byte(code&0xFF))
|
||||||
|
} else {
|
||||||
|
buf = append(buf, byte(len(m.Route)))
|
||||||
|
buf = append(buf, []byte(m.Route)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = append(buf, m.Data...)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *NNetCodec) Decode(data []byte) (interface{}, error) {
|
||||||
|
if len(data) < msgHeadLength {
|
||||||
|
return nil, ErrInvalidMessage
|
||||||
|
}
|
||||||
|
m := New()
|
||||||
|
flag := data[0]
|
||||||
|
offset := 1
|
||||||
|
m.Type = Type((flag >> 1) & msgTypeMask) // 编译器提示,此处Type转换不能删
|
||||||
|
|
||||||
|
if n.invalidType(m.Type) {
|
||||||
|
return nil, ErrWrongMessageType
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.Type == Request || m.Type == Response {
|
||||||
|
id := uint64(0)
|
||||||
|
// little end byte order
|
||||||
|
// WARNING: must can be stored in 64 bits integer
|
||||||
|
// variant length encode
|
||||||
|
for i := offset; i < len(data); i++ {
|
||||||
|
b := data[i]
|
||||||
|
id += uint64(b&0x7F) << uint64(7*(i-offset))
|
||||||
|
if b < 128 {
|
||||||
|
offset = i + 1
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.ID = id
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset >= len(data) {
|
||||||
|
return nil, ErrWrongMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.routable(m.Type) {
|
||||||
|
if flag&msgRouteCompressMask == 1 {
|
||||||
|
m.compressed = true
|
||||||
|
code := binary.BigEndian.Uint16(data[offset:(offset + 2)])
|
||||||
|
route, ok := codes[code]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrRouteInfoNotFound
|
||||||
|
}
|
||||||
|
m.Route = route
|
||||||
|
offset += 2
|
||||||
|
} else {
|
||||||
|
m.compressed = false
|
||||||
|
rl := data[offset]
|
||||||
|
offset++
|
||||||
|
if offset+int(rl) > len(data) {
|
||||||
|
return nil, ErrWrongMessage
|
||||||
|
}
|
||||||
|
m.Route = string(data[offset:(offset + int(rl))])
|
||||||
|
offset += int(rl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset > len(data) {
|
||||||
|
return nil, ErrWrongMessage
|
||||||
|
}
|
||||||
|
m.Data = data[offset:]
|
||||||
|
return m, nil
|
||||||
|
}
|
@ -0,0 +1,11 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
type (
|
||||||
|
// Codec 消息编解码器
|
||||||
|
Codec interface {
|
||||||
|
// Encode 编码
|
||||||
|
Encode(v interface{}) ([]byte, error)
|
||||||
|
// Decode 解码
|
||||||
|
Decode(data []byte) (interface{}, error)
|
||||||
|
}
|
||||||
|
)
|
@ -1,11 +0,0 @@
|
|||||||
package message
|
|
||||||
|
|
||||||
type Header struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Type byte // 消息类型
|
|
||||||
ID uint64 // 消息ID
|
|
||||||
Header []byte // 消息头原始数据
|
|
||||||
Payload []byte // 数据
|
|
||||||
}
|
|
@ -0,0 +1,46 @@
|
|||||||
|
package message
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Type represents the type of message, which could be Request/Notify/Response/Push
|
||||||
|
type Type byte
|
||||||
|
|
||||||
|
// Message types
|
||||||
|
const (
|
||||||
|
Request Type = 0x00
|
||||||
|
Notify = 0x01
|
||||||
|
Response = 0x02
|
||||||
|
Push = 0x03
|
||||||
|
)
|
||||||
|
|
||||||
|
var types = map[Type]string{
|
||||||
|
Request: "Request",
|
||||||
|
Notify: "Notify",
|
||||||
|
Response: "Response",
|
||||||
|
Push: "Push",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t Type) String() string {
|
||||||
|
return types[t]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message represents an unmarshaler message or a message which to be marshaled
|
||||||
|
type Message struct {
|
||||||
|
Type Type // message type (flag)
|
||||||
|
ID uint64 // unique id, zero while notify mode
|
||||||
|
Route string // route for locating service
|
||||||
|
Data []byte // payload
|
||||||
|
compressed bool // if message compressed
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new message instance
|
||||||
|
func New() *Message {
|
||||||
|
return &Message{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String, implementation of fmt.Stringer interface
|
||||||
|
func (m *Message) String() string {
|
||||||
|
return fmt.Sprintf("%s %s (%dbytes)", types[m.Type], m.Route, len(m.Data))
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
package nface
|
||||||
|
|
||||||
|
type IServer interface {
|
||||||
|
}
|
@ -1 +0,0 @@
|
|||||||
package nnet
|
|
@ -0,0 +1,9 @@
|
|||||||
|
package nnet
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestServer(t *testing.T) {
|
||||||
|
server := NewServer("tcp4", ":22112")
|
||||||
|
|
||||||
|
server.Serve()
|
||||||
|
}
|
@ -0,0 +1,78 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.noahlan.cn/northlan/nnet/env"
|
||||||
|
"git.noahlan.cn/northlan/nnet/log"
|
||||||
|
"runtime/debug"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
messageQueueBacklog = 1 << 10 // 1024
|
||||||
|
sessionCloseBacklog = 1 << 8 // 256
|
||||||
|
)
|
||||||
|
|
||||||
|
// LocalScheduler schedules task to a customized goroutine
|
||||||
|
type LocalScheduler interface {
|
||||||
|
Schedule(Task)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Task func()
|
||||||
|
|
||||||
|
type Hook func()
|
||||||
|
|
||||||
|
var (
|
||||||
|
chDie = make(chan struct{})
|
||||||
|
chExit = make(chan struct{})
|
||||||
|
chTasks = make(chan Task, 1<<8)
|
||||||
|
started int32
|
||||||
|
closed int32
|
||||||
|
)
|
||||||
|
|
||||||
|
func try(f func()) {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
log.Infof("Handle message panic: %+v\n%s", err, debug.Stack())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Schedule() {
|
||||||
|
if atomic.AddInt32(&started, 1) != 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(env.TimerPrecision)
|
||||||
|
defer func() {
|
||||||
|
ticker.Stop()
|
||||||
|
close(chExit)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
cron()
|
||||||
|
|
||||||
|
case f := <-chTasks:
|
||||||
|
try(f)
|
||||||
|
|
||||||
|
case <-chDie:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Close() {
|
||||||
|
if atomic.AddInt32(&closed, 1) != 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
close(chDie)
|
||||||
|
<-chExit
|
||||||
|
log.Info("Scheduler stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func PushTask(task Task) {
|
||||||
|
chTasks <- task
|
||||||
|
}
|
@ -0,0 +1,84 @@
|
|||||||
|
package scheduler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewTimer(t *testing.T) {
|
||||||
|
var exists = struct {
|
||||||
|
timers int
|
||||||
|
createdTimes int
|
||||||
|
closingTimers int
|
||||||
|
}{
|
||||||
|
timers: len(timerManager.timers),
|
||||||
|
createdTimes: len(timerManager.createdTimer),
|
||||||
|
closingTimers: len(timerManager.closingTimer),
|
||||||
|
}
|
||||||
|
|
||||||
|
const tc = 1000
|
||||||
|
var counter int64
|
||||||
|
for i := 0; i < tc; i++ {
|
||||||
|
NewTimer(1*time.Millisecond, func() {
|
||||||
|
atomic.AddInt64(&counter, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
<-time.After(5 * time.Millisecond)
|
||||||
|
cron()
|
||||||
|
cron()
|
||||||
|
if counter != tc*2 {
|
||||||
|
t.Fatalf("expect: %d, got: %d", tc*2, counter)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(timerManager.timers) != exists.timers+tc {
|
||||||
|
t.Fatalf("timers: %d", len(timerManager.timers))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(timerManager.createdTimer) != exists.createdTimes {
|
||||||
|
t.Fatalf("createdTimer: %d", len(timerManager.createdTimer))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(timerManager.closingTimer) != exists.closingTimers {
|
||||||
|
t.Fatalf("closingTimer: %d", len(timerManager.closingTimer))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAfterTimer(t *testing.T) {
|
||||||
|
var exists = struct {
|
||||||
|
timers int
|
||||||
|
createdTimes int
|
||||||
|
closingTimers int
|
||||||
|
}{
|
||||||
|
timers: len(timerManager.timers),
|
||||||
|
createdTimes: len(timerManager.createdTimer),
|
||||||
|
closingTimers: len(timerManager.closingTimer),
|
||||||
|
}
|
||||||
|
|
||||||
|
const tc = 1000
|
||||||
|
var counter int64
|
||||||
|
for i := 0; i < tc; i++ {
|
||||||
|
NewAfterTimer(1*time.Millisecond, func() {
|
||||||
|
atomic.AddInt64(&counter, 1)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
<-time.After(5 * time.Millisecond)
|
||||||
|
cron()
|
||||||
|
if counter != tc {
|
||||||
|
t.Fatalf("expect: %d, got: %d", tc, counter)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(timerManager.timers) != exists.timers {
|
||||||
|
t.Fatalf("timers: %d", len(timerManager.timers))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(timerManager.createdTimer) != exists.createdTimes {
|
||||||
|
t.Fatalf("createdTimer: %d", len(timerManager.createdTimer))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(timerManager.closingTimer) != exists.closingTimers {
|
||||||
|
t.Fatalf("closingTimer: %d", len(timerManager.closingTimer))
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,20 @@
|
|||||||
|
package json
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"git.noahlan.cn/northlan/nnet/serialize"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Serializer struct{}
|
||||||
|
|
||||||
|
func NewSerializer() serialize.Serializer {
|
||||||
|
return &Serializer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Serializer) Marshal(i interface{}) ([]byte, error) {
|
||||||
|
return json.Marshal(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Serializer) Unmarshal(bytes []byte, i interface{}) error {
|
||||||
|
return json.Unmarshal(bytes, i)
|
||||||
|
}
|
@ -0,0 +1,62 @@
|
|||||||
|
package json
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSerializer_Serialize(t *testing.T) {
|
||||||
|
m := Message{1, "hello world"}
|
||||||
|
s := NewSerializer()
|
||||||
|
b, err := s.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
m2 := Message{}
|
||||||
|
if err := s.Unmarshal(b, &m2); err != nil {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(m, m2) {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSerializer_Serialize(b *testing.B) {
|
||||||
|
m := &Message{100, "hell world"}
|
||||||
|
s := NewSerializer()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := s.Marshal(m); err != nil {
|
||||||
|
b.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSerializer_Deserialize(b *testing.B) {
|
||||||
|
m := &Message{100, "hell world"}
|
||||||
|
s := NewSerializer()
|
||||||
|
|
||||||
|
d, err := s.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
b.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m1 := &Message{}
|
||||||
|
if err := s.Unmarshal(d, m1); err != nil {
|
||||||
|
b.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.ReportAllocs()
|
||||||
|
}
|
@ -0,0 +1,32 @@
|
|||||||
|
package protobuf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"git.noahlan.cn/northlan/nnet/serialize"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrWrongValueType is the error used for marshal the value with protobuf encoding.
|
||||||
|
var ErrWrongValueType = errors.New("protobuf: convert on wrong type value")
|
||||||
|
|
||||||
|
type Serializer struct{}
|
||||||
|
|
||||||
|
func NewSerializer() serialize.Serializer {
|
||||||
|
return &Serializer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Serializer) Marshal(v interface{}) ([]byte, error) {
|
||||||
|
pb, ok := v.(proto.Message)
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrWrongValueType
|
||||||
|
}
|
||||||
|
return proto.Marshal(pb)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Serializer) Unmarshal(data []byte, v interface{}) error {
|
||||||
|
pb, ok := v.(proto.Message)
|
||||||
|
if !ok {
|
||||||
|
return ErrWrongValueType
|
||||||
|
}
|
||||||
|
return proto.Unmarshal(data, pb)
|
||||||
|
}
|
@ -0,0 +1,56 @@
|
|||||||
|
package protobuf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.noahlan.cn/northlan/nnet/serialize/protobuf/testdata"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProtobufSerializer_Serialize(t *testing.T) {
|
||||||
|
m := &testdata.Ping{Content: "hello"}
|
||||||
|
s := NewSerializer()
|
||||||
|
|
||||||
|
b, err := s.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m1 := &testdata.Ping{}
|
||||||
|
if err := s.Unmarshal(b, m1); err != nil {
|
||||||
|
t.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
// refer: https://developers.google.com/protocol-buffers/docs/reference/go/faq#deepequal
|
||||||
|
if !proto.Equal(m, m1) {
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSerializer_Serialize(b *testing.B) {
|
||||||
|
m := &testdata.Ping{Content: "hello"}
|
||||||
|
s := NewSerializer()
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if _, err := s.Marshal(m); err != nil {
|
||||||
|
b.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSerializer_Deserialize(b *testing.B) {
|
||||||
|
m := &testdata.Ping{Content: "hello"}
|
||||||
|
s := NewSerializer()
|
||||||
|
|
||||||
|
d, err := s.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
b.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m1 := &testdata.Ping{}
|
||||||
|
if err := s.Unmarshal(d, m1); err != nil {
|
||||||
|
b.Fatalf("unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1 @@
|
|||||||
|
protoc --go_opt=paths=source_relative --go_out=. --proto_path=. *.proto
|
@ -0,0 +1,13 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package testdata;
|
||||||
|
|
||||||
|
option go_package = "/testdata";
|
||||||
|
|
||||||
|
message Ping {
|
||||||
|
string Content = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Pong {
|
||||||
|
string Content = 1;
|
||||||
|
}
|
Loading…
Reference in New Issue