wip: 又双叒加了一些新东西。
							parent
							
								
									115166cb11
								
							
						
					
					
						commit
						a2ed3090e7
					
				@ -0,0 +1,8 @@
 | 
			
		||||
package env
 | 
			
		||||
 | 
			
		||||
import "time"
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// TimerPrecision indicates the precision of timer, default is time.Second
 | 
			
		||||
	TimerPrecision = time.Second
 | 
			
		||||
)
 | 
			
		||||
@ -1,8 +1,14 @@
 | 
			
		||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 | 
			
		||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 | 
			
		||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
			
		||||
github.com/panjf2000/ants/v2 v2.6.0 h1:xOSpw42m+BMiJ2I33we7h6fYzG4DAlpE1xyI7VS2gxU=
 | 
			
		||||
github.com/panjf2000/ants/v2 v2.6.0/go.mod h1:cU93usDlihJZ5CfRGNDYsiBYvoilLvBF5Qp/BT2GNRE=
 | 
			
		||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 | 
			
		||||
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
 | 
			
		||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 | 
			
		||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 | 
			
		||||
google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w=
 | 
			
		||||
google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
 | 
			
		||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
 | 
			
		||||
 | 
			
		||||
@ -1,18 +0,0 @@
 | 
			
		||||
package message
 | 
			
		||||
 | 
			
		||||
type BinarySerializer struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewBinarySerializer() Serializer {
 | 
			
		||||
	return &BinarySerializer{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *BinarySerializer) Marshal(i interface{}) ([]byte, error) {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
	panic("implement me")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *BinarySerializer) Unmarshal(bytes []byte, i interface{}) error {
 | 
			
		||||
	//TODO implement me
 | 
			
		||||
	panic("implement me")
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,146 @@
 | 
			
		||||
package message
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/binary"
 | 
			
		||||
	"errors"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ Codec = (*NNetCodec)(nil)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	msgRouteCompressMask = 0x01 // 0000 0001  last bit
 | 
			
		||||
	msgTypeMask          = 0x07 // 0000 0111  1-3 bit (需要>>)
 | 
			
		||||
	msgRouteLengthMask   = 0xFF // 1111 1111  last 8 bit
 | 
			
		||||
	msgHeadLength        = 0x02 // 0000 0010  2 bit
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Errors that could be occurred in message codec
 | 
			
		||||
var (
 | 
			
		||||
	ErrWrongMessageType  = errors.New("wrong message type")
 | 
			
		||||
	ErrInvalidMessage    = errors.New("invalid message")
 | 
			
		||||
	ErrRouteInfoNotFound = errors.New("route info not found in dictionary")
 | 
			
		||||
	ErrWrongMessage      = errors.New("wrong message")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	routes = make(map[string]uint16) // route map to code
 | 
			
		||||
	codes  = make(map[uint16]string) // code map to route
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type NNetCodec struct{}
 | 
			
		||||
 | 
			
		||||
func (n *NNetCodec) routable(t Type) bool {
 | 
			
		||||
	return t == Request || t == Notify || t == Push
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n *NNetCodec) invalidType(t Type) bool {
 | 
			
		||||
	return t < Request || t > Push
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n *NNetCodec) Encode(v interface{}) ([]byte, error) {
 | 
			
		||||
	m, ok := v.(*Message)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, ErrWrongMessageType
 | 
			
		||||
	}
 | 
			
		||||
	if n.invalidType(m.Type) {
 | 
			
		||||
		return nil, ErrWrongMessageType
 | 
			
		||||
	}
 | 
			
		||||
	buf := make([]byte, 0)
 | 
			
		||||
	flag := byte(m.Type << 1) // 编译器提示,此处 byte 转换不能删
 | 
			
		||||
 | 
			
		||||
	code, compressed := routes[m.Route]
 | 
			
		||||
	if compressed {
 | 
			
		||||
		flag |= msgRouteCompressMask
 | 
			
		||||
	}
 | 
			
		||||
	buf = append(buf, flag)
 | 
			
		||||
 | 
			
		||||
	if m.Type == Request || m.Type == Response {
 | 
			
		||||
		n := m.ID
 | 
			
		||||
		// variant length encode
 | 
			
		||||
		for {
 | 
			
		||||
			b := byte(n % 128)
 | 
			
		||||
			n >>= 7
 | 
			
		||||
			if n != 0 {
 | 
			
		||||
				buf = append(buf, b+128)
 | 
			
		||||
			} else {
 | 
			
		||||
				buf = append(buf, b)
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if n.routable(m.Type) {
 | 
			
		||||
		if compressed {
 | 
			
		||||
			buf = append(buf, byte((code>>8)&0xFF))
 | 
			
		||||
			buf = append(buf, byte(code&0xFF))
 | 
			
		||||
		} else {
 | 
			
		||||
			buf = append(buf, byte(len(m.Route)))
 | 
			
		||||
			buf = append(buf, []byte(m.Route)...)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	buf = append(buf, m.Data...)
 | 
			
		||||
	return buf, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n *NNetCodec) Decode(data []byte) (interface{}, error) {
 | 
			
		||||
	if len(data) < msgHeadLength {
 | 
			
		||||
		return nil, ErrInvalidMessage
 | 
			
		||||
	}
 | 
			
		||||
	m := New()
 | 
			
		||||
	flag := data[0]
 | 
			
		||||
	offset := 1
 | 
			
		||||
	m.Type = Type((flag >> 1) & msgTypeMask) // 编译器提示,此处Type转换不能删
 | 
			
		||||
 | 
			
		||||
	if n.invalidType(m.Type) {
 | 
			
		||||
		return nil, ErrWrongMessageType
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if m.Type == Request || m.Type == Response {
 | 
			
		||||
		id := uint64(0)
 | 
			
		||||
		// little end byte order
 | 
			
		||||
		// WARNING: must can be stored in 64 bits integer
 | 
			
		||||
		// variant length encode
 | 
			
		||||
		for i := offset; i < len(data); i++ {
 | 
			
		||||
			b := data[i]
 | 
			
		||||
			id += uint64(b&0x7F) << uint64(7*(i-offset))
 | 
			
		||||
			if b < 128 {
 | 
			
		||||
				offset = i + 1
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		m.ID = id
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if offset >= len(data) {
 | 
			
		||||
		return nil, ErrWrongMessage
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if n.routable(m.Type) {
 | 
			
		||||
		if flag&msgRouteCompressMask == 1 {
 | 
			
		||||
			m.compressed = true
 | 
			
		||||
			code := binary.BigEndian.Uint16(data[offset:(offset + 2)])
 | 
			
		||||
			route, ok := codes[code]
 | 
			
		||||
			if !ok {
 | 
			
		||||
				return nil, ErrRouteInfoNotFound
 | 
			
		||||
			}
 | 
			
		||||
			m.Route = route
 | 
			
		||||
			offset += 2
 | 
			
		||||
		} else {
 | 
			
		||||
			m.compressed = false
 | 
			
		||||
			rl := data[offset]
 | 
			
		||||
			offset++
 | 
			
		||||
			if offset+int(rl) > len(data) {
 | 
			
		||||
				return nil, ErrWrongMessage
 | 
			
		||||
			}
 | 
			
		||||
			m.Route = string(data[offset:(offset + int(rl))])
 | 
			
		||||
			offset += int(rl)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if offset > len(data) {
 | 
			
		||||
		return nil, ErrWrongMessage
 | 
			
		||||
	}
 | 
			
		||||
	m.Data = data[offset:]
 | 
			
		||||
	return m, nil
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,11 @@
 | 
			
		||||
package message
 | 
			
		||||
 | 
			
		||||
type (
 | 
			
		||||
	// Codec 消息编解码器
 | 
			
		||||
	Codec interface {
 | 
			
		||||
		// Encode 编码
 | 
			
		||||
		Encode(v interface{}) ([]byte, error)
 | 
			
		||||
		// Decode 解码
 | 
			
		||||
		Decode(data []byte) (interface{}, error)
 | 
			
		||||
	}
 | 
			
		||||
)
 | 
			
		||||
@ -1,11 +0,0 @@
 | 
			
		||||
package message
 | 
			
		||||
 | 
			
		||||
type Header struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Type    byte   // 消息类型
 | 
			
		||||
	ID      uint64 // 消息ID
 | 
			
		||||
	Header  []byte // 消息头原始数据
 | 
			
		||||
	Payload []byte // 数据
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,46 @@
 | 
			
		||||
package message
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Type represents the type of message, which could be Request/Notify/Response/Push
 | 
			
		||||
type Type byte
 | 
			
		||||
 | 
			
		||||
// Message types
 | 
			
		||||
const (
 | 
			
		||||
	Request  Type = 0x00
 | 
			
		||||
	Notify        = 0x01
 | 
			
		||||
	Response      = 0x02
 | 
			
		||||
	Push          = 0x03
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var types = map[Type]string{
 | 
			
		||||
	Request:  "Request",
 | 
			
		||||
	Notify:   "Notify",
 | 
			
		||||
	Response: "Response",
 | 
			
		||||
	Push:     "Push",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t Type) String() string {
 | 
			
		||||
	return types[t]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Message represents an unmarshaler message or a message which to be marshaled
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Type       Type   // message type (flag)
 | 
			
		||||
	ID         uint64 // unique id, zero while notify mode
 | 
			
		||||
	Route      string // route for locating service
 | 
			
		||||
	Data       []byte // payload
 | 
			
		||||
	compressed bool   // if message compressed
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// New returns a new message instance
 | 
			
		||||
func New() *Message {
 | 
			
		||||
	return &Message{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// String, implementation of fmt.Stringer interface
 | 
			
		||||
func (m *Message) String() string {
 | 
			
		||||
	return fmt.Sprintf("%s %s (%dbytes)", types[m.Type], m.Route, len(m.Data))
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,4 @@
 | 
			
		||||
package nface
 | 
			
		||||
 | 
			
		||||
type IServer interface {
 | 
			
		||||
}
 | 
			
		||||
@ -1 +0,0 @@
 | 
			
		||||
package nnet
 | 
			
		||||
@ -0,0 +1,9 @@
 | 
			
		||||
package nnet
 | 
			
		||||
 | 
			
		||||
import "testing"
 | 
			
		||||
 | 
			
		||||
func TestServer(t *testing.T) {
 | 
			
		||||
	server := NewServer("tcp4", ":22112")
 | 
			
		||||
 | 
			
		||||
	server.Serve()
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,78 @@
 | 
			
		||||
package scheduler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"git.noahlan.cn/northlan/nnet/env"
 | 
			
		||||
	"git.noahlan.cn/northlan/nnet/log"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	messageQueueBacklog = 1 << 10 // 1024
 | 
			
		||||
	sessionCloseBacklog = 1 << 8  // 256
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// LocalScheduler schedules task to a customized goroutine
 | 
			
		||||
type LocalScheduler interface {
 | 
			
		||||
	Schedule(Task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Task func()
 | 
			
		||||
 | 
			
		||||
type Hook func()
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	chDie   = make(chan struct{})
 | 
			
		||||
	chExit  = make(chan struct{})
 | 
			
		||||
	chTasks = make(chan Task, 1<<8)
 | 
			
		||||
	started int32
 | 
			
		||||
	closed  int32
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func try(f func()) {
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			log.Infof("Handle message panic: %+v\n%s", err, debug.Stack())
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	f()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Schedule() {
 | 
			
		||||
	if atomic.AddInt32(&started, 1) != 1 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ticker := time.NewTicker(env.TimerPrecision)
 | 
			
		||||
	defer func() {
 | 
			
		||||
		ticker.Stop()
 | 
			
		||||
		close(chExit)
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			cron()
 | 
			
		||||
 | 
			
		||||
		case f := <-chTasks:
 | 
			
		||||
			try(f)
 | 
			
		||||
 | 
			
		||||
		case <-chDie:
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Close() {
 | 
			
		||||
	if atomic.AddInt32(&closed, 1) != 1 {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	close(chDie)
 | 
			
		||||
	<-chExit
 | 
			
		||||
	log.Info("Scheduler stopped")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PushTask(task Task) {
 | 
			
		||||
	chTasks <- task
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,84 @@
 | 
			
		||||
package scheduler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestNewTimer(t *testing.T) {
 | 
			
		||||
	var exists = struct {
 | 
			
		||||
		timers        int
 | 
			
		||||
		createdTimes  int
 | 
			
		||||
		closingTimers int
 | 
			
		||||
	}{
 | 
			
		||||
		timers:        len(timerManager.timers),
 | 
			
		||||
		createdTimes:  len(timerManager.createdTimer),
 | 
			
		||||
		closingTimers: len(timerManager.closingTimer),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	const tc = 1000
 | 
			
		||||
	var counter int64
 | 
			
		||||
	for i := 0; i < tc; i++ {
 | 
			
		||||
		NewTimer(1*time.Millisecond, func() {
 | 
			
		||||
			atomic.AddInt64(&counter, 1)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	<-time.After(5 * time.Millisecond)
 | 
			
		||||
	cron()
 | 
			
		||||
	cron()
 | 
			
		||||
	if counter != tc*2 {
 | 
			
		||||
		t.Fatalf("expect: %d, got: %d", tc*2, counter)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(timerManager.timers) != exists.timers+tc {
 | 
			
		||||
		t.Fatalf("timers: %d", len(timerManager.timers))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(timerManager.createdTimer) != exists.createdTimes {
 | 
			
		||||
		t.Fatalf("createdTimer: %d", len(timerManager.createdTimer))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(timerManager.closingTimer) != exists.closingTimers {
 | 
			
		||||
		t.Fatalf("closingTimer: %d", len(timerManager.closingTimer))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewAfterTimer(t *testing.T) {
 | 
			
		||||
	var exists = struct {
 | 
			
		||||
		timers        int
 | 
			
		||||
		createdTimes  int
 | 
			
		||||
		closingTimers int
 | 
			
		||||
	}{
 | 
			
		||||
		timers:        len(timerManager.timers),
 | 
			
		||||
		createdTimes:  len(timerManager.createdTimer),
 | 
			
		||||
		closingTimers: len(timerManager.closingTimer),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	const tc = 1000
 | 
			
		||||
	var counter int64
 | 
			
		||||
	for i := 0; i < tc; i++ {
 | 
			
		||||
		NewAfterTimer(1*time.Millisecond, func() {
 | 
			
		||||
			atomic.AddInt64(&counter, 1)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	<-time.After(5 * time.Millisecond)
 | 
			
		||||
	cron()
 | 
			
		||||
	if counter != tc {
 | 
			
		||||
		t.Fatalf("expect: %d, got: %d", tc, counter)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(timerManager.timers) != exists.timers {
 | 
			
		||||
		t.Fatalf("timers: %d", len(timerManager.timers))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(timerManager.createdTimer) != exists.createdTimes {
 | 
			
		||||
		t.Fatalf("createdTimer: %d", len(timerManager.createdTimer))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(timerManager.closingTimer) != exists.closingTimers {
 | 
			
		||||
		t.Fatalf("closingTimer: %d", len(timerManager.closingTimer))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,20 @@
 | 
			
		||||
package json
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"git.noahlan.cn/northlan/nnet/serialize"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Serializer struct{}
 | 
			
		||||
 | 
			
		||||
func NewSerializer() serialize.Serializer {
 | 
			
		||||
	return &Serializer{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Serializer) Marshal(i interface{}) ([]byte, error) {
 | 
			
		||||
	return json.Marshal(i)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Serializer) Unmarshal(bytes []byte, i interface{}) error {
 | 
			
		||||
	return json.Unmarshal(bytes, i)
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,62 @@
 | 
			
		||||
package json
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Code int    `json:"code"`
 | 
			
		||||
	Data string `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSerializer_Serialize(t *testing.T) {
 | 
			
		||||
	m := Message{1, "hello world"}
 | 
			
		||||
	s := NewSerializer()
 | 
			
		||||
	b, err := s.Marshal(m)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fail()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	m2 := Message{}
 | 
			
		||||
	if err := s.Unmarshal(b, &m2); err != nil {
 | 
			
		||||
		t.Fail()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !reflect.DeepEqual(m, m2) {
 | 
			
		||||
		t.Fail()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkSerializer_Serialize(b *testing.B) {
 | 
			
		||||
	m := &Message{100, "hell world"}
 | 
			
		||||
	s := NewSerializer()
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		if _, err := s.Marshal(m); err != nil {
 | 
			
		||||
			b.Fatalf("unmarshal failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.ReportAllocs()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkSerializer_Deserialize(b *testing.B) {
 | 
			
		||||
	m := &Message{100, "hell world"}
 | 
			
		||||
	s := NewSerializer()
 | 
			
		||||
 | 
			
		||||
	d, err := s.Marshal(m)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		b.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.ResetTimer()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		m1 := &Message{}
 | 
			
		||||
		if err := s.Unmarshal(d, m1); err != nil {
 | 
			
		||||
			b.Fatalf("unmarshal failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	b.ReportAllocs()
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,32 @@
 | 
			
		||||
package protobuf
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"git.noahlan.cn/northlan/nnet/serialize"
 | 
			
		||||
	"google.golang.org/protobuf/proto"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ErrWrongValueType is the error used for marshal the value with protobuf encoding.
 | 
			
		||||
var ErrWrongValueType = errors.New("protobuf: convert on wrong type value")
 | 
			
		||||
 | 
			
		||||
type Serializer struct{}
 | 
			
		||||
 | 
			
		||||
func NewSerializer() serialize.Serializer {
 | 
			
		||||
	return &Serializer{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Serializer) Marshal(v interface{}) ([]byte, error) {
 | 
			
		||||
	pb, ok := v.(proto.Message)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, ErrWrongValueType
 | 
			
		||||
	}
 | 
			
		||||
	return proto.Marshal(pb)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Serializer) Unmarshal(data []byte, v interface{}) error {
 | 
			
		||||
	pb, ok := v.(proto.Message)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return ErrWrongValueType
 | 
			
		||||
	}
 | 
			
		||||
	return proto.Unmarshal(data, pb)
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,56 @@
 | 
			
		||||
package protobuf
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"git.noahlan.cn/northlan/nnet/serialize/protobuf/testdata"
 | 
			
		||||
	"google.golang.org/protobuf/proto"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestProtobufSerializer_Serialize(t *testing.T) {
 | 
			
		||||
	m := &testdata.Ping{Content: "hello"}
 | 
			
		||||
	s := NewSerializer()
 | 
			
		||||
 | 
			
		||||
	b, err := s.Marshal(m)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	m1 := &testdata.Ping{}
 | 
			
		||||
	if err := s.Unmarshal(b, m1); err != nil {
 | 
			
		||||
		t.Fatalf("unmarshal failed: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	// refer: https://developers.google.com/protocol-buffers/docs/reference/go/faq#deepequal
 | 
			
		||||
	if !proto.Equal(m, m1) {
 | 
			
		||||
		t.Fail()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkSerializer_Serialize(b *testing.B) {
 | 
			
		||||
	m := &testdata.Ping{Content: "hello"}
 | 
			
		||||
	s := NewSerializer()
 | 
			
		||||
 | 
			
		||||
	b.ReportAllocs()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		if _, err := s.Marshal(m); err != nil {
 | 
			
		||||
			b.Fatalf("unmarshal failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkSerializer_Deserialize(b *testing.B) {
 | 
			
		||||
	m := &testdata.Ping{Content: "hello"}
 | 
			
		||||
	s := NewSerializer()
 | 
			
		||||
 | 
			
		||||
	d, err := s.Marshal(m)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		b.Error(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.ReportAllocs()
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		m1 := &testdata.Ping{}
 | 
			
		||||
		if err := s.Unmarshal(d, m1); err != nil {
 | 
			
		||||
			b.Fatalf("unmarshal failed: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1 @@
 | 
			
		||||
protoc --go_opt=paths=source_relative --go_out=. --proto_path=. *.proto
 | 
			
		||||
@ -0,0 +1,13 @@
 | 
			
		||||
syntax = "proto3";
 | 
			
		||||
 | 
			
		||||
package testdata;
 | 
			
		||||
 | 
			
		||||
option go_package = "/testdata";
 | 
			
		||||
 | 
			
		||||
message Ping {
 | 
			
		||||
  string Content = 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
message Pong {
 | 
			
		||||
  string Content = 1;
 | 
			
		||||
}
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue