package client

import (
	"context"
	"errors"
	"net"
	"strings"
	"sync"
	"time"

	"github.com/noahlann/nnet/pkg/client"
)

const (
	// tcpReadBufferSize is the maximum number of bytes returned by one read.
	tcpReadBufferSize = 4096
	// tcpIdlePollInterval is how long the async receiver waits before
	// re-checking for a connection while disconnected.
	tcpIdlePollInterval = 100 * time.Millisecond
	// tcpReconnectRetryDelay is the extra delay before a failed reconnect
	// attempt re-arms the reconnect loop.
	tcpReconnectRetryDelay = time.Second
)

// tcpClient is the TCP implementation of client.Client.
type tcpClient struct {
	config *client.Config

	// mu guards conn, connected, closed, reconnectAttempts, onMessage,
	// onError and receiverStarted. It is never held across blocking I/O on an
	// established connection, so Disconnect and Close can always interrupt a
	// pending read or write by closing the socket.
	mu        sync.RWMutex
	conn      net.Conn
	connected bool
	closed    bool

	// ctx is cancelled by Close; it aborts in-flight dials and stops the
	// background goroutines.
	ctx    context.Context
	cancel context.CancelFunc

	// readMu serializes synchronous reads, which share readBuffer.
	readMu     sync.Mutex
	readBuffer []byte

	// Automatic reconnection.
	reconnectAttempts int
	reconnectCh       chan struct{}
	reconnectStopCh   chan struct{}
	closeOnce         sync.Once

	// Asynchronous message delivery.
	messageCh    chan []byte
	messageErrCh chan error
	onMessage    func([]byte)
	onError      func(error)

	// receiverStarted records whether the async receiver goroutine is running.
	// It is started lazily so that it does not compete with synchronous
	// Receive calls for data on the same connection.
	receiverStarted bool
}

// NewTCPClient creates a TCP client. A nil config is replaced by
// client.DefaultConfig(). When config.AutoReconnect is set, the reconnect loop
// is started immediately; the async receiver is only started by SetOnMessage.
func NewTCPClient(config *client.Config) client.Client {
	if config == nil {
		config = client.DefaultConfig()
	}

	ctx, cancel := context.WithCancel(context.Background())
	c := &tcpClient{
		config:          config,
		ctx:             ctx,
		cancel:          cancel,
		readBuffer:      make([]byte, tcpReadBufferSize),
		reconnectCh:     make(chan struct{}, 1),
		reconnectStopCh: make(chan struct{}),
		messageCh:       make(chan []byte, 100),
		messageErrCh:    make(chan error, 10),
	}

	if config.AutoReconnect {
		go c.reconnectLoop()
	}

	return c
}

// Connect dials the configured server address. It is a no-op when the client
// is already connected.
func (c *tcpClient) Connect() error {
	return c.connect()
}

// connect dials the server and, on success, stores the connection and resets
// the reconnect attempt counter. On failure with auto-reconnect enabled it
// signals the reconnect loop.
func (c *tcpClient) connect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		return nil
	}

	// The address may be given as "tcp://host:port"; the dialer wants host:port.
	addr := strings.TrimPrefix(c.config.Addr, "tcp://")

	dialer := &net.Dialer{Timeout: c.config.ConnectTimeout}
	conn, err := dialer.DialContext(c.ctx, "tcp", addr)
	if err != nil {
		if c.config.AutoReconnect && !c.closed {
			c.triggerReconnect()
		}
		return err
	}

	c.conn = conn
	c.connected = true
	c.reconnectAttempts = 0
	return nil
}

// Disconnect closes the current connection, if any, and returns the error
// from closing the socket. With auto-reconnect enabled this also signals the
// reconnect loop, so the client will try to re-establish the connection; use
// Close to shut the client down for good.
func (c *tcpClient) Disconnect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.connected {
		return nil
	}

	var err error
	if c.conn != nil {
		err = c.conn.Close()
		c.conn = nil
	}
	c.connected = false

	if c.config.AutoReconnect && !c.closed {
		c.triggerReconnect()
	}
	return err
}

// Send writes data to the connection, applying config.WriteTimeout as a write
// deadline when it is positive. It returns a "not connected" error when there
// is no active connection.
func (c *tcpClient) Send(data []byte) error {
	conn, err := c.activeConn()
	if err != nil {
		return err
	}

	if wt := c.config.WriteTimeout; wt > 0 {
		if err := conn.SetWriteDeadline(time.Now().Add(wt)); err != nil {
			return err
		}
	}

	_, err = conn.Write(data)
	return err
}

// Receive performs a single read of at most tcpReadBufferSize bytes and
// returns a copy of the bytes read. config.ReadTimeout is applied as the read
// deadline when it is positive. No message framing is performed.
func (c *tcpClient) Receive() ([]byte, error) {
	return c.receive(c.config.ReadTimeout)
}

// Request sends data and then performs a single read, using timeout as the
// read deadline for this call only. A non-positive timeout means the read has
// no deadline.
func (c *tcpClient) Request(data []byte, timeout time.Duration) ([]byte, error) {
	if err := c.Send(data); err != nil {
		return nil, err
	}
	return c.receive(timeout)
}

// receive reads once from the active connection with the given timeout.
func (c *tcpClient) receive(timeout time.Duration) ([]byte, error) {
	conn, err := c.activeConn()
	if err != nil {
		return nil, err
	}

	c.readMu.Lock()
	defer c.readMu.Unlock()

	// A zero deadline clears any deadline left over from an earlier read, so a
	// call without a timeout is not cut short by a stale one.
	var deadline time.Time
	if timeout > 0 {
		deadline = time.Now().Add(timeout)
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		return nil, err
	}

	n, err := conn.Read(c.readBuffer)
	if err != nil {
		return nil, err
	}

	// Copy out of the shared buffer so the caller owns the returned slice.
	result := make([]byte, n)
	copy(result, c.readBuffer[:n])
	return result, nil
}

// activeConn returns the current connection, or a "not connected" error when
// the client has none.
func (c *tcpClient) activeConn() (net.Conn, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if !c.connected || c.conn == nil {
		return nil, client.NewError("not connected")
	}
	return c.conn, nil
}

// IsConnected reports whether the client currently holds a connection.
func (c *tcpClient) IsConnected() bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.connected
}

// Close shuts the client down: it cancels the client context, stops the
// background goroutines, disables further reconnects and closes the
// connection. It is safe to call more than once; only the first call does any
// work and later calls return nil.
func (c *tcpClient) Close() error {
	var err error
	c.closeOnce.Do(func() {
		// Cancel before taking the lock so an in-flight dial, which holds the
		// lock, is aborted instead of blocking shutdown.
		c.cancel()
		close(c.reconnectStopCh)

		c.mu.Lock()
		defer c.mu.Unlock()

		c.closed = true
		if c.conn != nil {
			err = c.conn.Close()
			c.conn = nil
		}
		c.connected = false
	})
	return err
}

// reconnectLoop waits for reconnect signals and re-dials the server after
// config.ReconnectInterval. It exits when the client is closed, when
// auto-reconnect has been disabled, or once config.MaxReconnectAttempts
// (if positive) has been reached.
func (c *tcpClient) reconnectLoop() {
	for {
		select {
		case <-c.ctx.Done():
			return
		case <-c.reconnectStopCh:
			return
		case <-c.reconnectCh:
		}

		c.mu.RLock()
		enabled := c.config.AutoReconnect && !c.closed
		connected := c.connected
		attempts := c.reconnectAttempts
		c.mu.RUnlock()

		if !enabled {
			return
		}
		if connected {
			continue
		}

		if limit := c.config.MaxReconnectAttempts; limit > 0 && attempts >= limit {
			c.reportError(client.NewError("max reconnect attempts reached"))
			return
		}

		if !c.sleep(c.config.ReconnectInterval) {
			return
		}

		c.mu.Lock()
		c.reconnectAttempts++
		c.mu.Unlock()

		if err := c.connect(); err != nil {
			c.reportError(err)
			// Re-arm the loop after a short delay in case the signal sent by
			// connect was coalesced with an earlier one.
			go func() {
				if c.sleep(tcpReconnectRetryDelay) {
					c.triggerReconnect()
				}
			}()
		}
	}
}

// messageReceiver is the async receive loop started by SetOnMessage. Every
// chunk read is offered to the message channel (dropped when the channel is
// full) and passed to the onMessage callback. The loop ends when the client
// is closed.
func (c *tcpClient) messageReceiver() {
	buf := make([]byte, tcpReadBufferSize)
	for {
		if c.ctx.Err() != nil {
			return
		}

		c.mu.RLock()
		conn := c.conn
		connected := c.connected
		c.mu.RUnlock()

		if !connected || conn == nil {
			if !c.sleep(tcpIdlePollInterval) {
				return
			}
			continue
		}

		var deadline time.Time
		if rt := c.config.ReadTimeout; rt > 0 {
			deadline = time.Now().Add(rt)
		}

		var n int
		err := conn.SetReadDeadline(deadline)
		if err == nil {
			n, err = conn.Read(buf)
		}
		if err != nil {
			// A read failing because the client was closed is not an error
			// worth reporting.
			if c.ctx.Err() != nil {
				return
			}
			c.handleReadError(conn, err)
			continue
		}

		// Copy out of the reused buffer; the message outlives this iteration.
		message := make([]byte, n)
		copy(message, buf[:n])

		select {
		case c.messageCh <- message:
		default:
			// Channel full: drop the message rather than block the reader.
		}

		c.mu.RLock()
		onMessage := c.onMessage
		c.mu.RUnlock()
		if onMessage != nil {
			onMessage(message)
		}
	}
}

// handleReadError reacts to a failed read on conn in the async receiver.
//
// With auto-reconnect enabled, any error (including a read timeout) tears the
// connection down and signals the reconnect loop. Without it, read timeouts
// leave the connection in place, while any other error marks the client as
// disconnected so the receiver does not spin on a dead socket. The error is
// always passed to the onError callback.
func (c *tcpClient) handleReadError(conn net.Conn, err error) {
	var netErr net.Error
	timedOut := errors.As(err, &netErr) && netErr.Timeout()

	c.mu.Lock()
	autoReconnect := c.config.AutoReconnect && !c.closed
	// Only tear down the connection that actually failed; c.conn may already
	// have been replaced by a reconnect.
	if c.conn == conn && (autoReconnect || !timedOut) {
		c.connected = false
		_ = conn.Close() // best effort: the connection is already broken
		c.conn = nil
	}
	c.mu.Unlock()

	if autoReconnect {
		c.triggerReconnect()
	}
	c.reportError(err)
}

// triggerReconnect signals the reconnect loop without blocking. Signals are
// coalesced: if one is already pending, this call is a no-op.
func (c *tcpClient) triggerReconnect() {
	select {
	case c.reconnectCh <- struct{}{}:
	default:
	}
}

// reportError passes err to the onError callback, if one is set. It must not
// be called while c.mu is held.
func (c *tcpClient) reportError(err error) {
	c.mu.RLock()
	onError := c.onError
	c.mu.RUnlock()

	if onError != nil {
		onError(err)
	}
}

// sleep waits for d and reports whether the full wait elapsed. It returns
// false early when the client is closed.
func (c *tcpClient) sleep(d time.Duration) bool {
	timer := time.NewTimer(d)
	defer timer.Stop()

	select {
	case <-c.ctx.Done():
		return false
	case <-c.reconnectStopCh:
		return false
	case <-timer.C:
		return true
	}
}

// SetOnMessage sets the callback invoked for every chunk read by the async
// receiver. The first call starts the receiver goroutine.
func (c *tcpClient) SetOnMessage(fn func([]byte)) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.onMessage = fn
	if !c.receiverStarted {
		c.receiverStarted = true
		go c.messageReceiver()
	}
}

// SetOnError sets the callback invoked for errors raised by the background
// goroutines (reconnect loop and async receiver).
func (c *tcpClient) SetOnError(fn func(error)) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.onError = fn
}

// MessageChannel returns the channel on which the async receiver publishes
// the chunks it reads. Nothing is published until the receiver has been
// started via SetOnMessage.
func (c *tcpClient) MessageChannel() <-chan []byte {
	return c.messageCh
}