package nnet import ( "encoding/json" "errors" "fmt" "git.noahlan.cn/noahlan/nnet" "git.noahlan.cn/noahlan/nnet/connection" "git.noahlan.cn/noahlan/nnet/packet" "git.noahlan.cn/noahlan/ntool/nlog" ) type ( HandshakeValidatorFunc func(*HandshakeReq) error HandshakeAckPayloadFunc func() interface{} ) func withNNetPipeline( handshakeResp *HandshakeResp, validator HandshakeValidatorFunc, packer packet.Packer, ) nnet.RunOption { return func(ngin *nnet.Engine) { ngin.Pipeline().Inbound().PushFront(func(conn *connection.Connection, v interface{}) error { pkg, ok := v.(*Packet) if !ok { return packet.ErrWrongPacketType } nc, _ := conn.Conn() switch pkg.PacketType { case Handshake: var handshakeData HandshakeReq err := json.Unmarshal(pkg.Data, &handshakeData) nlog.Must(err) if err := validator(&handshakeData); err != nil { return err } handshakeResp.Payload = handshakeData.Payload data, err := json.Marshal(handshakeResp) nlog.Must(err) hrd, _ := packer.Pack(Header{ PacketType: Handshake, MessageHeader: MessageHeader{}, }, data) if err := conn.SendBytes(hrd); err != nil { return err } conn.SetStatus(connection.StatusPrepare) nlog.Debugf("connection handshake Id=%d, Remote=%s", conn.Session().ID(), nc.RemoteAddr()) case HandshakeAck: conn.SetStatus(connection.StatusPending) nlog.Debugf("receive handshake ACK Id=%d, Remote=%s", conn.Session().ID(), nc.RemoteAddr()) case Data: if conn.Status() < connection.StatusPending { return errors.New(fmt.Sprintf("receive data on socket which not yet ACK, session will be closed immediately, remote=%s", nc.RemoteAddr())) } conn.SetStatus(connection.StatusWorking) var lastMid uint64 switch pkg.MsgType { case Request: lastMid = pkg.ID case Notify: lastMid = 0 default: return fmt.Errorf("Invalid message type: %s ", pkg.MsgType.String()) } conn.SetLastMID(lastMid) } return nil }) } }