You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

361 lines
6.8 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package client
import (
"context"
"net"
"sync"
"time"
"github.com/noahlann/nnet/pkg/client"
)
// tcpClient is the TCP implementation of client.Client.
//
// mu is taken by the connection methods and setters when they touch conn,
// connected, reconnectAttempts, onMessage, onError and receiverStarted.
type tcpClient struct {
// config is the pointer passed to NewTCPClient (or DefaultConfig() when nil).
// NOTE(review): Close writes config.AutoReconnect, so the change is visible to
// anyone else holding this pointer.
config *client.Config
// conn is the current connection; nil when not connected.
conn net.Conn
mu sync.RWMutex
// connected is set together with conn by connect and cleared on disconnect.
connected bool
// ctx is the client lifetime. cancel (called by Close) aborts in-flight dials
// and makes reconnectLoop and messageReceiver return.
ctx context.Context
cancel context.CancelFunc
// readBuffer is the read buffer for Receive; NewTCPClient allocates 4096 bytes.
readBuffer []byte
// Auto-reconnect state.
// reconnectAttempts is incremented before each reconnect attempt and reset to
// zero after a successful connect.
reconnectAttempts int
// reconnectCh carries reconnect requests (buffered with capacity 1; senders
// drop the signal if one is already pending).
reconnectCh chan struct{}
// reconnectStopCh is closed by Close to stop reconnectLoop.
reconnectStopCh chan struct{}
// Asynchronous message delivery.
// messageCh receives copies of messages read by messageReceiver (capacity 100;
// messages are dropped when it is full).
messageCh chan []byte
// NOTE(review): messageErrCh is created in NewTCPClient but is never written
// to or read from anywhere in this file.
messageErrCh chan error
// onMessage and onError are the user callbacks set via SetOnMessage/SetOnError.
onMessage func([]byte)
onError func(error)
// receiverStarted records whether the messageReceiver goroutine has been started.
// It is started lazily (on the first SetOnMessage) so that it does not compete
// with synchronous Receive calls for reads on the same connection.
receiverStarted bool
}
// NewTCPClient builds a TCP client from config, falling back to
// client.DefaultConfig() when config is nil.
//
// No connection is made here; call Connect. When AutoReconnect is set, the
// reconnect goroutine is started immediately. The asynchronous message
// receiver is NOT started here: it is started on the first SetOnMessage so it
// cannot steal reads from synchronous Receive calls.
func NewTCPClient(config *client.Config) client.Client {
	cfg := config
	if cfg == nil {
		cfg = client.DefaultConfig()
	}

	lifetime, stop := context.WithCancel(context.Background())
	tc := &tcpClient{
		config:          cfg,
		ctx:             lifetime,
		cancel:          stop,
		readBuffer:      make([]byte, 4096),
		reconnectCh:     make(chan struct{}, 1),
		reconnectStopCh: make(chan struct{}),
		messageCh:       make(chan []byte, 100),
		messageErrCh:    make(chan error, 10),
	}

	if cfg.AutoReconnect {
		go tc.reconnectLoop()
	}
	return tc
}
// Connect dials the configured server. It returns nil without dialing when the
// client is already connected.
func (c *tcpClient) Connect() error {
	err := c.connect()
	return err
}
// connect establishes the TCP connection described by c.config.
//
// The address may be given as "host:port" or "tcp://host:port". Dialing is
// bounded by config.ConnectTimeout and aborted when the client context is
// cancelled (Close). On dial failure with AutoReconnect enabled, a reconnect
// request is queued on reconnectCh.
//
// The dial is performed WITHOUT holding c.mu, so IsConnected, Send, Receive
// and the setters are not stalled for the whole dial. After the dial the lock is
// re-taken and two races are resolved:
//   - Close ran meanwhile: the new connection is closed and ctx's error returned;
//   - another connect won: the new connection is closed and nil returned.
func (c *tcpClient) connect() error {
	c.mu.RLock()
	if c.connected {
		c.mu.RUnlock()
		return nil
	}
	addr := c.config.Addr
	dialTimeout := c.config.ConnectTimeout
	c.mu.RUnlock()

	// Accept an optional "tcp://" scheme prefix.
	if len(addr) > 6 && addr[:6] == "tcp://" {
		addr = addr[6:]
	}

	dialer := &net.Dialer{Timeout: dialTimeout}
	conn, err := dialer.DialContext(c.ctx, "tcp", addr)

	c.mu.Lock()
	defer c.mu.Unlock()

	if err != nil {
		if c.config.AutoReconnect {
			// Non-blocking: at most one reconnect request is kept pending.
			select {
			case c.reconnectCh <- struct{}{}:
			default:
			}
		}
		return err
	}

	// Close may have completed while we were dialing; don't resurrect the client.
	if ctxErr := c.ctx.Err(); ctxErr != nil {
		conn.Close()
		return ctxErr
	}

	// A concurrent connect already installed a connection; keep that one.
	if c.connected {
		conn.Close()
		return nil
	}

	c.conn = conn
	c.connected = true
	c.reconnectAttempts = 0 // successful connect resets the reconnect budget
	return nil
}
// Disconnect closes the current connection (if any) and marks the client as
// disconnected. It always returns nil.
//
// NOTE(review): when AutoReconnect is enabled this immediately queues a
// reconnect request, so an explicit Disconnect is not sticky — confirm this
// is intended.
func (c *tcpClient) Disconnect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		if conn := c.conn; conn != nil {
			conn.Close()
			c.conn = nil
		}
		c.connected = false

		if c.config.AutoReconnect {
			// Non-blocking: drop the signal if a reconnect is already pending.
			select {
			case c.reconnectCh <- struct{}{}:
			default:
			}
		}
	}
	return nil
}
// Send writes data to the current connection, applying config.WriteTimeout as
// a write deadline when it is positive.
//
// Returns a "not connected" error when there is no connection, otherwise the
// error from setting the deadline or from Write.
//
// The connection is snapshotted under c.mu and the lock is released before the
// (possibly blocking) Write: holding the read lock across network I/O would
// block Close/Disconnect, and with a writer waiting, every new RLock caller too.
// net.Conn is safe for concurrent use, and a concurrent Close simply makes
// Write return an error.
func (c *tcpClient) Send(data []byte) error {
	c.mu.RLock()
	conn, connected := c.conn, c.connected
	writeTimeout := c.config.WriteTimeout
	c.mu.RUnlock()

	if !connected || conn == nil {
		return client.NewError("not connected")
	}

	if writeTimeout > 0 {
		if err := conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
			return err
		}
	}

	_, err := conn.Write(data)
	return err
}
// Receive performs one Read on the current connection and returns a copy of
// the bytes read (at most len(c.readBuffer)). config.ReadTimeout is applied as
// a read deadline when positive.
//
// Returns a "not connected" error when there is no connection, otherwise the
// error from setting the deadline or from Read. If the background
// messageReceiver has been started (SetOnMessage), it competes with this call
// for data on the same connection.
//
// Two changes from holding c.mu for the whole call:
//   - the lock is released before Read, so a blocked Read no longer prevents
//     Close/Disconnect (which need the write lock) from proceeding;
//   - each call reads into its own buffer: the shared readBuffer was being
//     overwritten by concurrent Receive calls, since RLock is shared.
func (c *tcpClient) Receive() ([]byte, error) {
	c.mu.RLock()
	conn, connected := c.conn, c.connected
	readTimeout := c.config.ReadTimeout
	c.mu.RUnlock()

	if !connected || conn == nil {
		return nil, client.NewError("not connected")
	}

	if readTimeout > 0 {
		if err := conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
			return nil, err
		}
	}

	buf := make([]byte, len(c.readBuffer))
	n, err := conn.Read(buf)
	if err != nil {
		return nil, err
	}

	// Copy out so the caller doesn't retain the full-size read buffer.
	result := make([]byte, n)
	copy(result, buf[:n])
	return result, nil
}
// Request sends data and then performs one Read for the response, bounded by
// timeout. A timeout <= 0 clears any read deadline left on the connection and
// waits without one.
//
// The timeout is applied to this call's read only. The previous implementation
// temporarily overwrote the shared c.config.ReadTimeout. That raced with
// Receive and messageReceiver, and let two concurrent Requests "restore" each
// other's value, permanently changing the configuration.
//
// Returns the error from Send, a "not connected" error, or the error from
// setting the deadline or from Read.
func (c *tcpClient) Request(data []byte, timeout time.Duration) ([]byte, error) {
	if err := c.Send(data); err != nil {
		return nil, err
	}

	c.mu.RLock()
	conn, connected := c.conn, c.connected
	c.mu.RUnlock()

	if !connected || conn == nil {
		return nil, client.NewError("not connected")
	}

	// Zero time.Time means "no deadline"; it also resets a stale deadline
	// set by an earlier read.
	var deadline time.Time
	if timeout > 0 {
		deadline = time.Now().Add(timeout)
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		return nil, err
	}

	buf := make([]byte, len(c.readBuffer))
	n, err := conn.Read(buf)
	if err != nil {
		return nil, err
	}

	result := make([]byte, n)
	copy(result, buf[:n])
	return result, nil
}
// IsConnected reports whether the client currently holds an established
// connection.
func (c *tcpClient) IsConnected() bool {
	c.mu.RLock()
	connected := c.connected
	c.mu.RUnlock()
	return connected
}
// Close shuts the client down permanently. It does the following:
//   - cancels the client context, which aborts an in-flight dial and makes
//     reconnectLoop and messageReceiver return;
//   - stops reconnectLoop via reconnectStopCh;
//   - disables auto-reconnect;
//   - closes the current connection.
//
// Close is idempotent and always returns nil. Previously a second call
// panicked on closing reconnectStopCh twice. The channel is now closed under
// c.mu after an "already closed?" check, so concurrent Close calls cannot both
// reach close().
//
// NOTE(review): setting c.config.AutoReconnect mutates the *client.Config that
// was passed to NewTCPClient, so the caller's config is changed as well.
func (c *tcpClient) Close() error {
	c.cancel() // safe to call repeatedly

	c.mu.Lock()
	defer c.mu.Unlock()

	select {
	case <-c.reconnectStopCh:
		// Already closed by an earlier Close.
	default:
		close(c.reconnectStopCh)
	}

	c.config.AutoReconnect = false
	if c.conn != nil {
		c.conn.Close()
		c.conn = nil
	}
	c.connected = false
	return nil
}
// reconnectLoop services reconnect requests posted on reconnectCh. NewTCPClient
// starts it in a goroutine when AutoReconnect is set.
//
// For each request it skips the attempt if already connected. It gives up
// (reporting via onError) once MaxReconnectAttempts (> 0) is reached.
// Otherwise it waits ReconnectInterval, increments reconnectAttempts and calls
// connect. On failure it reports the error and schedules another request after
// one second.
//
// The loop exits when the client context is cancelled, reconnectStopCh is
// closed, or AutoReconnect has been turned off.
//
// Compared with the previous version:
//   - the interval wait and the delayed retry both honour cancellation, so
//     Close stops the loop promptly instead of after a full sleep;
//   - onError is read under c.mu, matching SetOnError, which writes it under
//     the lock.
func (c *tcpClient) reconnectLoop() {
	for {
		select {
		case <-c.ctx.Done():
			return
		case <-c.reconnectStopCh:
			return
		case <-c.reconnectCh:
		}

		c.mu.RLock()
		autoReconnect := c.config.AutoReconnect
		connected := c.connected
		attempts := c.reconnectAttempts
		onError := c.onError
		c.mu.RUnlock()

		if !autoReconnect {
			return
		}
		if connected {
			continue
		}

		if limit := c.config.MaxReconnectAttempts; limit > 0 && attempts >= limit {
			if onError != nil {
				onError(client.NewError("max reconnect attempts reached"))
			}
			return
		}

		// Wait for the reconnect interval, but abort promptly on Close.
		wait := time.NewTimer(c.config.ReconnectInterval)
		select {
		case <-c.ctx.Done():
			wait.Stop()
			return
		case <-c.reconnectStopCh:
			wait.Stop()
			return
		case <-wait.C:
		}

		c.mu.Lock()
		c.reconnectAttempts++
		c.mu.Unlock()

		if err := c.connect(); err != nil {
			c.mu.RLock()
			onError = c.onError
			c.mu.RUnlock()
			if onError != nil {
				onError(err)
			}

			// Re-queue a request after a delay (rather than retrying
			// immediately); the goroutine exits early if the client is closed.
			go func() {
				retry := time.NewTimer(1 * time.Second)
				defer retry.Stop()
				select {
				case <-c.ctx.Done():
					return
				case <-retry.C:
				}
				select {
				case c.reconnectCh <- struct{}{}:
				default:
				}
			}()
		}
	}
}
// messageReceiver is the background read loop started by the first
// SetOnMessage. Each successful read is copied into a fresh slice, offered to
// messageCh (dropped if the channel is full) and passed to onMessage.
// While disconnected it polls every 100ms. It returns when the client context
// is cancelled.
//
// On a read error:
//   - if the client has been closed, the loop exits without reporting;
//   - if AutoReconnect is on, or the error is not a read-deadline timeout, the
//     failed connection is torn down (only if it is still the current one) and,
//     with AutoReconnect, a reconnect request is queued;
//   - the error is then passed to onError.
//
// Changes from the previous version:
//   - without AutoReconnect, a dead connection (e.g. EOF) is no longer re-read
//     in a hot loop that floods onError; it is torn down and the loop idles;
//   - a newer connection installed by a concurrent reconnect is no longer closed
//     by mistake;
//   - callbacks and config.AutoReconnect are read under c.mu;
//   - the read buffer is allocated once instead of per read.
//
// It competes with Receive/Request for data on the same connection.
func (c *tcpClient) messageReceiver() {
	// Reusable because every message is copied out before the next Read.
	buffer := make([]byte, 4096)
	for {
		if c.ctx.Err() != nil {
			return
		}

		c.mu.RLock()
		conn, connected := c.conn, c.connected
		readTimeout := c.config.ReadTimeout
		c.mu.RUnlock()

		if !connected || conn == nil {
			select {
			case <-c.ctx.Done():
				return
			case <-time.After(100 * time.Millisecond):
			}
			continue
		}

		if readTimeout > 0 {
			// Best effort: if this fails (e.g. conn closed), Read reports it.
			_ = conn.SetReadDeadline(time.Now().Add(readTimeout))
		}
		n, err := conn.Read(buffer)

		// Snapshot after the (possibly long) blocking read so callback changes
		// made meanwhile via SetOnMessage/SetOnError are picked up.
		c.mu.RLock()
		onMessage, onError := c.onMessage, c.onError
		autoReconnect := c.config.AutoReconnect
		c.mu.RUnlock()

		if err != nil {
			if c.ctx.Err() != nil {
				return // closed by Close; the read error is expected
			}

			ne, isNetErr := err.(net.Error)
			timedOut := isNetErr && ne.Timeout()
			if autoReconnect || !timedOut {
				c.mu.Lock()
				if c.conn == conn {
					c.conn.Close()
					c.conn = nil
					c.connected = false
				}
				c.mu.Unlock()

				if autoReconnect {
					select {
					case c.reconnectCh <- struct{}{}:
					default:
					}
				}
			}

			if onError != nil {
				onError(err)
			}
			continue
		}

		message := make([]byte, n)
		copy(message, buffer[:n])
		select {
		case c.messageCh <- message:
		default:
			// Channel full: drop the message.
		}

		if onMessage != nil {
			onMessage(message)
		}
	}
}
// SetOnMessage installs fn as the callback for incoming messages. The first
// call also starts the messageReceiver goroutine; later calls only replace the
// callback.
func (c *tcpClient) SetOnMessage(fn func([]byte)) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.onMessage = fn
	if c.receiverStarted {
		return
	}
	c.receiverStarted = true
	go c.messageReceiver()
}
// SetOnError installs fn as the callback that receives connection, read and
// reconnect errors.
func (c *tcpClient) SetOnError(fn func(error)) {
	c.mu.Lock()
	c.onError = fn
	c.mu.Unlock()
}
// MessageChannel returns the receive-only side of the buffered channel to which
// messageReceiver publishes incoming messages. Messages are dropped when the
// channel is full. Nothing is delivered until SetOnMessage has started the
// receiver goroutine.
func (c *tcpClient) MessageChannel() <-chan []byte {
	var out <-chan []byte = c.messageCh
	return out
}