You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ngs/client/client.go

285 lines
5.1 KiB
Go

package client
import (
"git.noahlan.cn/northlan/ngs/internal/codec"
"git.noahlan.cn/northlan/ngs/internal/log"
"git.noahlan.cn/northlan/ngs/internal/message"
"git.noahlan.cn/northlan/ngs/internal/packet"
"google.golang.org/protobuf/proto"
"net"
"sync"
)
// Pre-encoded handshake packets. They are built once in init and the same
// byte slices are reused every time a handshake is sent.
var (
	// hsd is the encoded Handshake packet that Start queues right after dialing.
	hsd []byte

	// had is the encoded HandshakeAck packet sent in reply to the server's
	// Handshake packet.
	had []byte
)
// init encodes the Handshake and HandshakeAck packets (both with no body)
// once, so they can be sent without re-encoding. An encoding failure here
// would be a programming error, so it panics.
func init() {
	var err error
	if hsd, err = codec.Encode(packet.Handshake, nil); err != nil {
		panic(err)
	}
	if had, err = codec.Encode(packet.HandshakeAck, nil); err != nil {
		panic(err)
	}
}
// Callback is invoked with the payload (msg.Data) of an incoming message:
// either a server push for an event registered with On, or the response to
// a Request.
type Callback func(data interface{})

// Client is a tiny Ngs client that talks to a server over one TCP
// connection. Create it with NewClient, register handlers with On and
// OnConnected, then call Start.
type Client struct {
	conn   net.Conn       // underlying TCP connection, set by Start
	codec  *codec.Decoder // decodes bytes read from conn into packets
	die    chan struct{}  // closed by Close; stops the writer goroutine
	chSend chan []byte    // outgoing encoded packets, drained by the writer goroutine
	mid    uint64         // ID for the next request; advanced by sendMessage

	// Handlers for server pushes, keyed by route; guarded by muEvents.
	muEvents sync.RWMutex
	events   map[string]Callback

	// One-shot handlers for request responses, keyed by message ID;
	// guarded by muResponses.
	muResponses sync.RWMutex
	responses   map[uint64]Callback

	connectedCallback func() // invoked when the server's Handshake packet arrives
}
// NewClient returns a Client that is ready to have handlers registered and
// to be started with Start. Request IDs start at 1.
func NewClient() *Client {
	c := new(Client)
	c.die = make(chan struct{})
	c.codec = codec.NewDecoder()
	c.chSend = make(chan []byte, 64) // up to 64 packets may be queued ahead of the writer
	c.mid = 1
	c.events = make(map[string]Callback)
	c.responses = make(map[uint64]Callback)
	return c
}
// Start dials addr over TCP and brings the client online. It launches the
// writer goroutine, queues the handshake packet, and launches the reader
// goroutine that processes everything the server sends. Only the dial
// error is returned; later I/O errors are logged and close the client.
func (c *Client) Start(addr string) error {
	nc, err := net.Dial("tcp", addr)
	if err != nil {
		return err
	}
	c.conn = nc

	go c.write() // drains chSend onto the connection
	c.send(hsd)  // queue our Handshake; processPacket handles the server's
	go c.read()  // decodes incoming packets and dispatches them
	return nil
}
// OnConnected registers callback to be run, on the reader goroutine,
// whenever a Handshake packet is received from the server. It replaces any
// callback registered earlier.
func (c *Client) OnConnected(callback func()) {
	c.connectedCallback = callback
}
// Request serializes v, sends it to route as a Request message, and
// registers callback to receive the payload of the matching response.
//
// The callback is keyed by the message ID current at call time. It is
// unregistered again if the message cannot be encoded and queued, and the
// encoding error is returned.
func (c *Client) Request(route string, v proto.Message, callback Callback) error {
	data, err := serialize(v)
	if err != nil {
		return err
	}

	// Read the ID exactly once. sendMessage advances c.mid, so re-reading it
	// afterwards (as the error path used to) can name a different request
	// and unregister that request's handler instead of ours.
	id := c.mid
	msg := &message.Message{
		Type:  message.Request,
		Route: route,
		ID:    id,
		Data:  data,
	}

	// Register before sending so a fast response cannot arrive before its
	// handler exists.
	c.setResponseHandler(id, callback)
	if err := c.sendMessage(msg); err != nil {
		c.setResponseHandler(id, nil)
		return err
	}
	return nil
}
// Notify serializes v and sends it to route as a fire-and-forget Notify
// message. No response handler is registered.
func (c *Client) Notify(route string, v proto.Message) error {
	payload, err := serialize(v)
	if err != nil {
		return err
	}
	return c.sendMessage(&message.Message{
		Type:  message.Notify,
		Route: route,
		Data:  payload,
	})
}
// On registers callback for server pushes whose route equals event,
// replacing any callback previously registered for it. The handler map is
// guarded by muEvents, so On may be called while the client is running.
func (c *Client) On(event string, callback Callback) {
	c.muEvents.Lock()
	defer c.muEvents.Unlock()

	c.events[event] = callback
}
// clientCloseMu makes the "is die already closed?" check and the close of
// die in Client.Close a single atomic step. Close is reachable from user
// code, the reader goroutine (I/O error, Kick packet) and the writer
// goroutine, and closing a channel twice panics. A per-client sync.Once
// field would be a tidier home for this; a package-level mutex keeps the
// fix local to Close. It is held only for the check-and-close.
var clientCloseMu sync.Mutex

// Close shuts the client down. It signals the writer goroutine to stop and
// closes the underlying connection, which makes the reader goroutine's
// pending Read fail and return.
//
// Close is safe to call more than once, including concurrently; every call
// after the first is a no-op. It is also safe to call before Start.
func (c *Client) Close() {
	clientCloseMu.Lock()
	select {
	case <-c.die:
		// Already closed by an earlier call.
		clientCloseMu.Unlock()
		return
	default:
		close(c.die)
	}
	clientCloseMu.Unlock()

	// conn is nil if Start was never called or failed to dial.
	if c.conn != nil {
		_ = c.conn.Close() // best effort: the client is going away regardless
	}
}
// eventHandler returns the push handler registered with On for event and
// reports whether one exists.
func (c *Client) eventHandler(event string) (cb Callback, ok bool) {
	c.muEvents.RLock()
	cb, ok = c.events[event]
	c.muEvents.RUnlock()
	return cb, ok
}
// responseHandler returns the response handler registered for message ID
// mid and reports whether one exists.
func (c *Client) responseHandler(mid uint64) (cb Callback, ok bool) {
	c.muResponses.RLock()
	cb, ok = c.responses[mid]
	c.muResponses.RUnlock()
	return cb, ok
}
// setResponseHandler registers cb as the handler for the response to
// message ID mid. Passing a nil cb removes any registered handler instead.
func (c *Client) setResponseHandler(mid uint64, cb Callback) {
	c.muResponses.Lock()
	defer c.muResponses.Unlock()

	if cb != nil {
		c.responses[mid] = cb
		return
	}
	delete(c.responses, mid)
}
// sendMessage encodes msg, frames it as a Data packet and queues it for the
// writer goroutine. The message-ID counter is advanced only after both
// encoding steps succeed, so a message that fails to encode does not
// consume an ID.
//
// NOTE(review): c.mid is read and incremented without synchronization;
// concurrent Request/Notify calls race on it.
func (c *Client) sendMessage(msg *message.Message) error {
	body, err := msg.Encode()
	if err != nil {
		return err
	}

	frame, err := codec.Encode(packet.Data, body)
	if err != nil {
		return err
	}

	c.mid++
	c.send(frame)
	return nil
}
// write is the writer goroutine started by Start. It copies each queued
// payload from chSend to the connection until the client is closed or a
// write fails. A failed write is logged and closes the client.
//
// It deliberately does not close chSend. This goroutine is the receiving
// side, and closing the channel here would make any later send (for
// example a Notify issued after Close) panic with "send on closed channel".
func (c *Client) write() {
	for {
		select {
		case data := <-c.chSend:
			if _, err := c.conn.Write(data); err != nil {
				log.Println(err.Error())
				c.Close()
				// The connection is unusable. Stop here rather than looping
				// back and trying to write further payloads to it.
				return
			}
		case <-c.die:
			return
		}
	}
}
// send queues an already-encoded packet for the writer goroutine. If the
// queue is full it waits for room. Once the client has been closed it gives
// up and the data is dropped, so callers can never block forever on a
// client whose writer goroutine has exited.
func (c *Client) send(data []byte) {
	select {
	case c.chSend <- data:
	case <-c.die:
	}
}
// read is the reader goroutine started by Start. It reads raw bytes from
// the connection, feeds them to the packet decoder and dispatches every
// decoded packet in order via processPacket. It returns on the first read
// or decode error, closing the client unless it is already closed.
func (c *Client) read() {
	// stop handles a terminal error. If the client is already closed (by
	// Close, or by a Kick packet handled in processPacket), the error is
	// just the expected result of the connection being torn down. In that
	// case it is neither logged nor answered with a second Close.
	stop := func(err error) {
		select {
		case <-c.die:
			return
		default:
		}
		log.Println(err.Error())
		c.Close()
	}

	buf := make([]byte, 2048)
	for {
		n, err := c.conn.Read(buf)
		if err != nil {
			stop(err)
			return
		}
		packets, err := c.codec.Decode(buf[:n])
		if err != nil {
			stop(err)
			return
		}
		for _, p := range packets {
			c.processPacket(p)
		}
	}
}
// processPacket handles one decoded packet on the reader goroutine:
//   - Handshake: reply with the pre-encoded HandshakeAck, then run the
//     OnConnected callback if one was registered;
//   - Data: decode the message and dispatch it via processMessage
//     (decode errors are logged and the packet is dropped);
//   - Kick: close the client.
//
// Other packet types are ignored.
func (c *Client) processPacket(p *packet.Packet) {
	switch p.Type {
	case packet.Handshake:
		c.send(had)
		// OnConnected is optional. Calling a nil func would panic on the
		// reader goroutine and take the whole process down.
		if cb := c.connectedCallback; cb != nil {
			cb()
		}
	case packet.Data:
		msg, err := message.Decode(p.Data)
		if err != nil {
			log.Println(err.Error())
			return
		}
		c.processMessage(msg)
	case packet.Kick:
		c.Close()
	}
}
// processMessage dispatches a decoded message to its handler.
//   - Pushes go to the handler registered with On for msg.Route.
//   - Responses go to the handler registered by Request for msg.ID; that
//     handler is removed after it runs.
//
// Messages with no matching handler are logged and dropped. Other message
// types are ignored.
func (c *Client) processMessage(msg *message.Message) {
	switch msg.Type {
	case message.Push:
		if handler, found := c.eventHandler(msg.Route); found {
			handler(msg.Data)
			return
		}
		log.Println("event handler not found", msg.Route)

	case message.Response:
		handler, found := c.responseHandler(msg.ID)
		if !found {
			log.Println("response handler not found", msg.ID)
			return
		}
		handler(msg.Data)
		c.setResponseHandler(msg.ID, nil) // responses are one-shot
	}
}
// serialize marshals v to protobuf wire format. On error it returns nil
// bytes rather than whatever partial output proto.Marshal produced.
func serialize(v proto.Message) ([]byte, error) {
	var (
		out []byte
		err error
	)
	if out, err = proto.Marshal(v); err != nil {
		return nil, err
	}
	return out, nil
}