|
|
package client
|
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"net"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
|
"github.com/noahlann/nnet/pkg/client"
|
|
|
)
|
|
|
|
|
|
// tcpClient is the TCP implementation of client.Client.
//
// The connection state (conn, connected), the reconnect counter and the
// callbacks are written under mu by the methods in this file.
type tcpClient struct {
	// config is the pointer passed to NewTCPClient (shared with the caller,
	// not copied); Close flips AutoReconnect on it.
	config *client.Config
	// conn is the live TCP connection, or nil when disconnected.
	conn net.Conn
	// mu guards the mutable state below.
	mu sync.RWMutex
	// connected reports whether conn is currently usable.
	connected bool
	// ctx is cancelled by Close; background goroutines exit on ctx.Done and
	// connect dials with it, so cancellation aborts an in-flight dial.
	ctx context.Context
	// cancel cancels ctx.
	cancel context.CancelFunc
	// readBuffer is the buffer associated with synchronous Receive calls.
	readBuffer []byte

	// Auto-reconnect state.

	// reconnectAttempts counts reconnect attempts since the last successful
	// connect (connect resets it to 0).
	reconnectAttempts int
	// reconnectCh signals reconnectLoop that a reconnect is wanted; senders
	// use a non-blocking send so redundant requests are dropped.
	reconnectCh chan struct{}
	// reconnectStopCh is closed by Close to stop reconnectLoop.
	reconnectStopCh chan struct{}

	// Asynchronous message delivery.

	// messageCh receives data read by messageReceiver; when it is full,
	// new messages are dropped.
	messageCh chan []byte
	// messageErrCh is allocated by NewTCPClient.
	// NOTE(review): nothing in this file sends to or reads from it.
	messageErrCh chan error
	// onMessage is invoked by messageReceiver for each chunk read.
	onMessage func([]byte)
	// onError is invoked on read and reconnect failures.
	onError func(error)

	// receiverStarted records whether messageReceiver has been started. The
	// receiver is started lazily rather than in the constructor so that it
	// does not compete with synchronous Receive calls for reads.
	receiverStarted bool
}
|
|
|
|
|
|
// NewTCPClient 创建TCP客户端
|
|
|
func NewTCPClient(config *client.Config) client.Client {
|
|
|
if config == nil {
|
|
|
config = client.DefaultConfig()
|
|
|
}
|
|
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
|
|
c := &tcpClient{
|
|
|
config: config,
|
|
|
ctx: ctx,
|
|
|
cancel: cancel,
|
|
|
readBuffer: make([]byte, 4096),
|
|
|
reconnectCh: make(chan struct{}, 1),
|
|
|
reconnectStopCh: make(chan struct{}),
|
|
|
messageCh: make(chan []byte, 100),
|
|
|
messageErrCh: make(chan error, 10),
|
|
|
}
|
|
|
|
|
|
// 如果启用自动重连,启动重连goroutine
|
|
|
if config.AutoReconnect {
|
|
|
go c.reconnectLoop()
|
|
|
}
|
|
|
|
|
|
// 默认不启动消息接收goroutine,避免与同步Receive竞争读取
|
|
|
// 仅当设置了消息回调或显式需要异步接收时再启动
|
|
|
|
|
|
return c
|
|
|
}
|
|
|
|
|
|
// Connect 连接服务器
|
|
|
func (c *tcpClient) Connect() error {
|
|
|
return c.connect()
|
|
|
}
|
|
|
|
|
|
// connect 内部连接方法
|
|
|
func (c *tcpClient) connect() error {
|
|
|
c.mu.Lock()
|
|
|
defer c.mu.Unlock()
|
|
|
|
|
|
if c.connected {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// 解析地址
|
|
|
addr := c.config.Addr
|
|
|
if len(addr) > 6 && addr[:6] == "tcp://" {
|
|
|
addr = addr[6:]
|
|
|
}
|
|
|
|
|
|
// 创建连接
|
|
|
dialer := &net.Dialer{
|
|
|
Timeout: c.config.ConnectTimeout,
|
|
|
}
|
|
|
|
|
|
conn, err := dialer.DialContext(c.ctx, "tcp", addr)
|
|
|
if err != nil {
|
|
|
// 如果启用自动重连,触发重连
|
|
|
if c.config.AutoReconnect {
|
|
|
select {
|
|
|
case c.reconnectCh <- struct{}{}:
|
|
|
default:
|
|
|
}
|
|
|
}
|
|
|
return err
|
|
|
}
|
|
|
|
|
|
c.conn = conn
|
|
|
c.connected = true
|
|
|
c.reconnectAttempts = 0 // 重置重连次数
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Disconnect 断开连接
|
|
|
func (c *tcpClient) Disconnect() error {
|
|
|
c.mu.Lock()
|
|
|
defer c.mu.Unlock()
|
|
|
|
|
|
if !c.connected {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
if c.conn != nil {
|
|
|
c.conn.Close()
|
|
|
c.conn = nil
|
|
|
}
|
|
|
|
|
|
c.connected = false
|
|
|
|
|
|
// 如果启用自动重连,触发重连
|
|
|
if c.config.AutoReconnect {
|
|
|
select {
|
|
|
case c.reconnectCh <- struct{}{}:
|
|
|
default:
|
|
|
}
|
|
|
}
|
|
|
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// Send writes data to the connection.
//
// It returns a "not connected" error when there is no live connection. When
// c.config.WriteTimeout > 0, a write deadline of now+WriteTimeout is applied
// first.
//
// The connection is snapshotted under the read lock and the write happens
// outside it. A slow or stalled write therefore cannot block Disconnect/Close,
// which need the write lock; closing the connection unblocks the write.
func (c *tcpClient) Send(data []byte) error {
	c.mu.RLock()
	conn, connected := c.conn, c.connected
	writeTimeout := c.config.WriteTimeout
	c.mu.RUnlock()

	if !connected || conn == nil {
		return client.NewError("not connected")
	}

	if writeTimeout > 0 {
		if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
			return err
		}
	}

	_, err := conn.Write(data)
	return err
}
|
|
|
|
|
|
// Receive 接收数据
|
|
|
func (c *tcpClient) Receive() ([]byte, error) {
|
|
|
c.mu.RLock()
|
|
|
defer c.mu.RUnlock()
|
|
|
|
|
|
if !c.connected || c.conn == nil {
|
|
|
return nil, client.NewError("not connected")
|
|
|
}
|
|
|
|
|
|
if c.config.ReadTimeout > 0 {
|
|
|
c.conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout))
|
|
|
}
|
|
|
|
|
|
n, err := c.conn.Read(c.readBuffer)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
result := make([]byte, n)
|
|
|
copy(result, c.readBuffer[:n])
|
|
|
return result, nil
|
|
|
}
|
|
|
|
|
|
// Request 请求-响应(带超时)
|
|
|
func (c *tcpClient) Request(data []byte, timeout time.Duration) ([]byte, error) {
|
|
|
// 发送请求
|
|
|
if err := c.Send(data); err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
|
|
|
// 设置读取超时
|
|
|
oldTimeout := c.config.ReadTimeout
|
|
|
c.config.ReadTimeout = timeout
|
|
|
defer func() {
|
|
|
c.config.ReadTimeout = oldTimeout
|
|
|
}()
|
|
|
|
|
|
// 接收响应
|
|
|
return c.Receive()
|
|
|
}
|
|
|
|
|
|
// IsConnected 检查是否已连接
|
|
|
func (c *tcpClient) IsConnected() bool {
|
|
|
c.mu.RLock()
|
|
|
defer c.mu.RUnlock()
|
|
|
return c.connected
|
|
|
}
|
|
|
|
|
|
// Close 关闭客户端
|
|
|
func (c *tcpClient) Close() error {
|
|
|
c.cancel()
|
|
|
close(c.reconnectStopCh)
|
|
|
c.mu.Lock()
|
|
|
defer c.mu.Unlock()
|
|
|
c.config.AutoReconnect = false // 禁用自动重连
|
|
|
if c.conn != nil {
|
|
|
c.conn.Close()
|
|
|
c.conn = nil
|
|
|
}
|
|
|
c.connected = false
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
// reconnectLoop runs in the background when auto-reconnect is enabled and
// services requests posted on c.reconnectCh.
//
// For each request it does the following:
//   - It exits if auto-reconnect has been disabled.
//   - It ignores the request if already connected.
//   - It exits (after reporting via onError) once MaxReconnectAttempts (> 0) is
//     reached.
//   - Otherwise it waits ReconnectInterval, counts the attempt and calls connect.
//
// After a failed attempt, another request is posted after one second. All
// waits are interruptible: the loop and its retry helper exit promptly when
// the client is closed (ctx cancelled or reconnectStopCh closed).
//
// The onError callback is read under the lock, because SetOnError may replace
// it concurrently.
func (c *tcpClient) reconnectLoop() {
	for {
		select {
		case <-c.ctx.Done():
			return
		case <-c.reconnectStopCh:
			return
		case <-c.reconnectCh:
		}

		c.mu.RLock()
		autoReconnect := c.config.AutoReconnect
		connected := c.connected
		attempts := c.reconnectAttempts
		onError := c.onError
		c.mu.RUnlock()

		if !autoReconnect {
			return
		}

		// Already connected: nothing to do.
		if connected {
			continue
		}

		// Give up once the configured attempt limit is reached.
		if limit := c.config.MaxReconnectAttempts; limit > 0 && attempts >= limit {
			if onError != nil {
				onError(client.NewError("max reconnect attempts reached"))
			}
			return
		}

		// Wait the reconnect interval, unless the client is closed meanwhile.
		if !c.waitReconnect(c.config.ReconnectInterval) {
			return
		}

		c.mu.Lock()
		c.reconnectAttempts++
		onError = c.onError
		c.mu.Unlock()

		if err := c.connect(); err != nil {
			// A failure caused by Close cancelling the dial is not worth reporting.
			if c.ctx.Err() != nil {
				return
			}
			if onError != nil {
				onError(err)
			}
			// Re-trigger after a short delay rather than retrying immediately; the
			// helper goroutine also exits early if the client is closed.
			go func() {
				if c.waitReconnect(time.Second) {
					select {
					case c.reconnectCh <- struct{}{}:
					default:
					}
				}
			}()
		}
	}
}

// waitReconnect blocks for d and reports true, or returns false early if the
// client is closed (ctx cancelled or reconnectStopCh closed) during the wait.
func (c *tcpClient) waitReconnect(d time.Duration) bool {
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-timer.C:
		return true
	case <-c.ctx.Done():
		return false
	case <-c.reconnectStopCh:
		return false
	}
}
|
|
|
|
|
|
// messageReceiver is the asynchronous read loop. It runs until c.ctx is
// cancelled.
//
// Each chunk read from the connection is copied, offered to c.messageCh
// (dropped if the channel is full) and passed to the onMessage callback.
// While there is no live connection, it polls every 100ms.
//
// Error handling:
//   - A read-deadline timeout only means no data arrived within ReadTimeout. It
//     is ignored and the connection is kept.
//   - Errors caused by Close, or by a connection that was already closed or
//     replaced by someone else, are ignored.
//   - Any other error tears the connection down (so the loop does not spin on a
//     dead connection), requests a reconnect when auto-reconnect is enabled,
//     and is reported via onError.
//
// The callbacks and the read timeout are read under the lock, because the
// setters may change them concurrently.
func (c *tcpClient) messageReceiver() {
	// One scratch buffer for the loop's lifetime; every delivered message is a copy.
	buffer := make([]byte, 4096)

	for {
		select {
		case <-c.ctx.Done():
			return
		default:
		}

		c.mu.RLock()
		conn := c.conn
		connected := c.connected
		onMessage := c.onMessage
		onError := c.onError
		readTimeout := c.config.ReadTimeout
		c.mu.RUnlock()

		if !connected || conn == nil {
			select {
			case <-c.ctx.Done():
				return
			case <-time.After(100 * time.Millisecond):
			}
			continue
		}

		if readTimeout > 0 {
			// Best effort: if this fails, the Read below reports the underlying problem.
			_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
		}

		n, err := conn.Read(buffer)
		// Per the io.Reader contract, consume any bytes before looking at the error.
		if n > 0 {
			c.deliverMessage(buffer[:n], onMessage)
		}
		if err != nil {
			c.handleReceiveError(conn, err, onError)
		}
	}
}

// deliverMessage copies data and hands the copy to c.messageCh (without
// blocking) and to onMessage, if set.
func (c *tcpClient) deliverMessage(data []byte, onMessage func([]byte)) {
	message := make([]byte, len(data))
	copy(message, data)

	select {
	case c.messageCh <- message:
	default:
		// Channel full: drop the message rather than stall the read loop.
	}

	if onMessage != nil {
		onMessage(message)
	}
}

// handleReceiveError reacts to a failed read on conn; see messageReceiver for
// the policy.
func (c *tcpClient) handleReceiveError(conn net.Conn, err error, onError func(error)) {
	// The client is shutting down; the error is a consequence of Close.
	if c.ctx.Err() != nil {
		return
	}
	// A deadline expiry on an idle connection is not a connection failure.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		return
	}

	c.mu.Lock()
	// Only tear down the connection we actually read from; a reconnect may
	// already have installed a newer one.
	owned := c.conn == conn
	if owned {
		_ = conn.Close() // best effort: the connection is already unusable
		c.conn = nil
		c.connected = false
	}
	autoReconnect := c.config.AutoReconnect
	c.mu.Unlock()

	if !owned {
		// Closed or replaced by Disconnect/Close/reconnect; nothing to report.
		return
	}

	if autoReconnect {
		select {
		case c.reconnectCh <- struct{}{}:
		default:
		}
	}
	if onError != nil {
		onError(err)
	}
}
|
|
|
|
|
|
// SetOnMessage 设置消息回调
|
|
|
func (c *tcpClient) SetOnMessage(fn func([]byte)) {
|
|
|
c.mu.Lock()
|
|
|
defer c.mu.Unlock()
|
|
|
c.onMessage = fn
|
|
|
// 首次设置回调时,启动消息接收协程
|
|
|
if !c.receiverStarted {
|
|
|
go c.messageReceiver()
|
|
|
c.receiverStarted = true
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// SetOnError 设置错误回调
|
|
|
func (c *tcpClient) SetOnError(fn func(error)) {
|
|
|
c.mu.Lock()
|
|
|
defer c.mu.Unlock()
|
|
|
c.onError = fn
|
|
|
}
|
|
|
|
|
|
// MessageChannel 获取消息channel
|
|
|
func (c *tcpClient) MessageChannel() <-chan []byte {
|
|
|
return c.messageCh
|
|
|
}
|
|
|
|