|
|
|
package protocol
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/json"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"git.noahlan.cn/noahlan/nnet/core"
|
|
|
|
"git.noahlan.cn/noahlan/nnet/entity"
|
|
|
|
"git.noahlan.cn/noahlan/nnet/packet"
|
|
|
|
"git.noahlan.cn/noahlan/ntools-go/core/nlog"
|
|
|
|
)
|
|
|
|
|
|
|
|
type (
|
|
|
|
HandshakeValidatorFunc func(*HandshakeReq) error
|
|
|
|
HandshakeAckPayloadFunc func() interface{}
|
|
|
|
)
|
|
|
|
|
|
|
|
func WithNNetPipeline(
|
|
|
|
handshakeResp *HandshakeResp,
|
|
|
|
validator HandshakeValidatorFunc,
|
|
|
|
packer packet.Packer,
|
|
|
|
) core.RunOption {
|
|
|
|
return func(server *core.NNet) {
|
|
|
|
server.Pipeline().Inbound().PushFront(func(entity entity.NetworkEntity, v interface{}) error {
|
|
|
|
pkg, ok := v.(*NNetPacket)
|
|
|
|
if !ok {
|
|
|
|
return ErrWrongPacketType
|
|
|
|
}
|
|
|
|
conn, _ := entity.Conn()
|
|
|
|
|
|
|
|
switch pkg.PacketType {
|
|
|
|
case Handshake:
|
|
|
|
var handshakeData HandshakeReq
|
|
|
|
err := json.Unmarshal(pkg.Data, &handshakeData)
|
|
|
|
nlog.Must(err)
|
|
|
|
|
|
|
|
if err := validator(&handshakeData); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
handshakeResp.Payload = handshakeData.Payload
|
|
|
|
|
|
|
|
data, err := json.Marshal(handshakeResp)
|
|
|
|
nlog.Must(err)
|
|
|
|
|
|
|
|
hrd, _ := packer.Pack(Header{
|
|
|
|
PacketType: Handshake,
|
|
|
|
MessageHeader: MessageHeader{},
|
|
|
|
}, data)
|
|
|
|
if err := entity.SendBytes(hrd); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
entity.SetStatus(core.StatusPrepare)
|
|
|
|
nlog.Debugf("connection handshake Id=%d, Remote=%s", entity.Session().ID(), conn.RemoteAddr())
|
|
|
|
case HandshakeAck:
|
|
|
|
entity.SetStatus(core.StatusPending)
|
|
|
|
nlog.Debugf("receive handshake ACK Id=%d, Remote=%s", entity.Session().ID(), conn.RemoteAddr())
|
|
|
|
case Data:
|
|
|
|
if entity.Status() < core.StatusPending {
|
|
|
|
return errors.New(fmt.Sprintf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s",
|
|
|
|
conn.RemoteAddr()))
|
|
|
|
}
|
|
|
|
entity.SetStatus(core.StatusWorking)
|
|
|
|
|
|
|
|
var lastMid uint64
|
|
|
|
switch pkg.MsgType {
|
|
|
|
case Request:
|
|
|
|
lastMid = pkg.ID
|
|
|
|
case Notify:
|
|
|
|
lastMid = 0
|
|
|
|
default:
|
|
|
|
return fmt.Errorf("Invalid message type: %s ", pkg.MsgType.String())
|
|
|
|
}
|
|
|
|
entity.SetLastMID(lastMid)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|