diff --git a/protocol/nnet.go b/protocol/nnet.go index bcc6daa..b230bd9 100644 --- a/protocol/nnet.go +++ b/protocol/nnet.go @@ -1,6 +1,7 @@ package protocol import ( + "encoding/json" "git.noahlan.cn/noahlan/nnet/core" "git.noahlan.cn/noahlan/nnet/entity" "git.noahlan.cn/noahlan/nnet/middleware" @@ -9,26 +10,62 @@ import ( "time" ) -type NNetConfig struct { -} +type ( + NNetConfig struct { + HeartbeatInterval time.Duration + HandshakeValidator HandshakeValidatorFunc + HandshakeAckBuilder HandshakeAckBuilderFunc + } + + handshakeData struct { + Version string `json:"version"` // 客户端版本,服务器以此判断是否合适与客户端通信 + Type string `json:"type"` // 客户端类型,与客户端版本号一起来确定客户端是否合适 + ClientId string `json:"clientId"` // 客户端ID,服务器以此取值 + ClientSecret string `json:"clientSecret"` // 客户端密钥,服务器以此判定客户端是否可用 + + // 透传信息 + Payload interface{} `json:"payload,optional,omitempty"` + } + + HandshakeAckData struct { + // 心跳间隔,单位秒 0表示不需要心跳 + Heartbeat int64 `json:"heartbeat"` -func WithNNetProtocol( - handshakeValidator func([]byte) error, - heartbeatInterval time.Duration, -) []core.RunOption { - if handshakeValidator == nil { - handshakeValidator = func(bytes []byte) error { return nil } + // 路由 + Routes map[string]uint16 `json:"routes"` // route map to code + Codes map[uint16]string `json:"codes"` // code map to route + + // 透传信息 + Payload interface{} `json:"payload,optional,omitempty"` } - packer := NewNNetPacker() - hbd, err := packer.Pack(Handshake, nil) - nlog.Must(err) +) - return []core.RunOption{ - WithNNetPipeline(handshakeValidator), +func WithNNetProtocol(config NNetConfig) []core.RunOption { + if config.HandshakeValidator == nil { + config.HandshakeValidator = func(bytes []byte) error { return nil } + } + if config.HandshakeAckBuilder == nil { + config.HandshakeAckBuilder = func() ([]byte, error) { + defaultData := &HandshakeAckData{} + return json.Marshal(defaultData) + } + } + + opts := []core.RunOption{ + WithNNetPipeline(config.HandshakeAckBuilder, config.HandshakeValidator), core.WithRouter(NewNNetRouter()), core.WithPacker(func() packet.Packer { return NewNNetPacker() }), - middleware.WithHeartbeat(heartbeatInterval, func(_ entity.NetworkEntity) []byte { + } + + if config.HeartbeatInterval.Seconds() > 0 { + packer := NewNNetPacker() + hbd, err := packer.Pack(Handshake, nil) + nlog.Must(err) + + opts = append(opts, middleware.WithHeartbeat(config.HeartbeatInterval, func(_ entity.NetworkEntity) []byte { return hbd - }), + })) } + + return opts } diff --git a/protocol/pipeline_nnet.go b/protocol/pipeline_nnet.go index 5879b50..4a83c16 100644 --- a/protocol/pipeline_nnet.go +++ b/protocol/pipeline_nnet.go @@ -1,44 +1,20 @@ package protocol import ( - "encoding/json" + "errors" "fmt" "git.noahlan.cn/noahlan/nnet/core" "git.noahlan.cn/noahlan/nnet/entity" "git.noahlan.cn/noahlan/ntools-go/core/nlog" - "time" ) type ( - handshakeData struct { - Version string `json:"version"` // 客户端版本号,服务器以此判断是否合适与客户端通信 - Type string `json:"type"` // 客户端类型,与客户端版本号一起来确定客户端是否合适 - - // 透传信息 - Payload interface{} `json:"payload,optional,omitempty"` - } - handshakeAckData struct { - Heartbeat int64 `json:"heartbeat"` // 心跳间隔,单位秒 0表示不需要心跳 - // 路由 - Routes map[string]uint16 `json:"routes"` // route map to code - Codes map[uint16]string `json:"codes"` // code map to route - - // 服务端支持的body部分消息传输协议 - //Protocol string `json:"protocol,options=[plain,json,protobuf]"` // plain/json/protobuf - - // 透传信息 - Payload interface{} `json:"payload,optional,omitempty"` - } + HandshakeValidatorFunc func([]byte) error + HandshakeAckBuilderFunc func() (interface{}, error) ) -func WithNNetPipeline(heartbeatInterval time.Duration, handshakeValidator func([]byte) error) core.RunOption { - handshakeAck := &handshakeAckData{} - data, err := json.Marshal(handshakeAck) - nlog.Must(err) - +func WithNNetPipeline(ackDataBuilder HandshakeAckBuilderFunc, validator HandshakeValidatorFunc) core.RunOption { packer := NewNNetPacker() - hrd, _ := packer.Pack(Handshake, data) - return func(server *core.Server) { server.Pipeline().Inbound().PushFront(func(entity entity.NetworkEntity, v interface{}) error { pkg, ok := v.(*NNetPacket) @@ -49,9 +25,13 @@ func WithNNetPipeline(heartbeatInterval time.Duration, handshakeValidator func([ switch pkg.PacketType { case Handshake: - if err := handshakeValidator(pkg.Data); err != nil { + if err := validator(pkg.Data); err != nil { return err } + data, err := ackDataBuilder() + nlog.Must(err) + + hrd, _ := packer.Pack(Handshake, data) if err := entity.SendBytes(hrd); err != nil { return err }