You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

155 lines
2.8 KiB
Go

package client
import (
"net/url"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/noahlann/nnet/pkg/client"
)
// websocketClient is a client implementation that talks to the server over a
// single gorilla/websocket connection.
type websocketClient struct {
	// config holds the address, TLS flag and timeouts read by the methods.
	config *client.Config

	// mu guards conn and connected.
	mu sync.RWMutex
	// conn is the live connection; Disconnect resets it to nil.
	conn *websocket.Conn
	// connected is set by Connect and cleared by Disconnect.
	connected bool
}
// NewWebSocketClient returns a WebSocket-backed client.Client in the
// disconnected state. A nil config is replaced by client.DefaultConfig().
func NewWebSocketClient(config *client.Config) client.Client {
	cfg := config
	if cfg == nil {
		cfg = client.DefaultConfig()
	}

	// connected and conn keep their zero values (false / nil) until Connect.
	wsc := new(websocketClient)
	wsc.config = cfg
	return wsc
}
// Connect dials the WebSocket server named by the configured address.
//
// The address may be a bare host (for example "host:port") or carry an
// explicit "ws://" or "wss://" prefix. An explicit prefix decides the scheme
// and is stripped before the URL is built; without one the scheme is "wss"
// when TLSEnabled is set and "ws" otherwise. The request path is always "/".
// ConnectTimeout is used as the handshake timeout.
//
// Calling Connect on an already connected client is a no-op. A dial failure
// is reported through client.NewErrorf.
func (c *websocketClient) Connect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		return nil
	}

	const (
		wsPrefix  = "ws://"
		wssPrefix = "wss://"
	)

	// Resolve the scheme and strip any explicit prefix. The slice bounds are
	// derived from the prefix literals themselves so they cannot drift out of
	// step with the strings being compared.
	addr := c.config.Addr
	scheme := "ws"
	switch {
	case len(addr) >= len(wssPrefix) && addr[:len(wssPrefix)] == wssPrefix:
		scheme, addr = "wss", addr[len(wssPrefix):]
	case len(addr) >= len(wsPrefix) && addr[:len(wsPrefix)] == wsPrefix:
		addr = addr[len(wsPrefix):]
	case c.config.TLSEnabled:
		scheme = "wss"
	}

	u := url.URL{Scheme: scheme, Host: addr, Path: "/"}

	dialer := websocket.Dialer{
		HandshakeTimeout: c.config.ConnectTimeout,
	}
	conn, _, err := dialer.Dial(u.String(), nil)
	if err != nil {
		return client.NewErrorf("failed to dial WebSocket: %v", err)
	}

	c.conn = conn
	c.connected = true
	return nil
}
// Disconnect closes the current connection, if any, and marks the client as
// disconnected.
//
// Calling it on a client that is not connected is a no-op and returns nil.
// Otherwise it returns whatever error closing the connection produced; the
// client is marked disconnected in either case.
func (c *websocketClient) Disconnect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.connected {
		return nil
	}

	// Reset the state before closing so the client never keeps a reference
	// to a connection whose Close call failed.
	conn := c.conn
	c.conn = nil
	c.connected = false

	if conn == nil {
		return nil
	}
	return conn.Close()
}
// Send writes data to the server as a single WebSocket text message.
//
// It returns a "not connected" error when no connection is established, or
// the error reported by the underlying write.
//
// NOTE(review): only the shared read lock is held here, so several goroutines
// can be inside WriteMessage at the same time. gorilla/websocket documents
// that a connection supports a single concurrent writer — confirm that
// callers serialize their Send calls.
func (c *websocketClient) Send(data []byte) error {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if ready := c.connected && c.conn != nil; !ready {
		return client.NewError("not connected")
	}

	err := c.conn.WriteMessage(websocket.TextMessage, data)
	return err
}
// Receive blocks until the next message arrives and returns its payload.
//
// When the configured ReadTimeout is positive it is applied as the read
// deadline for this call. It returns a "not connected" error when no
// connection is established, and otherwise any error from setting the
// deadline or from reading the message.
func (c *websocketClient) Receive() ([]byte, error) {
	// Snapshot the connection and release the lock before the blocking read.
	// Holding the read lock across ReadMessage would keep Disconnect/Close
	// from acquiring the write lock until a message arrives or the deadline
	// fires.
	c.mu.RLock()
	conn := c.conn
	connected := c.connected
	readTimeout := c.config.ReadTimeout
	c.mu.RUnlock()

	if !connected || conn == nil {
		return nil, client.NewError("not connected")
	}

	if readTimeout > 0 {
		if err := conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
			return nil, err
		}
	}

	_, data, err := conn.ReadMessage()
	if err != nil {
		return nil, err
	}
	return data, nil
}
// Request sends data and then waits for the next incoming message, which it
// returns as the response.
//
// timeout bounds only the wait for the response. When it is positive it is
// applied as the read deadline for this call; when it is zero or negative no
// deadline is set and any deadline already on the connection is left as is.
// The timeout is applied directly to the connection rather than by
// temporarily rewriting the shared config, so concurrent calls and the
// caller-supplied config are not affected.
//
// It returns the error from Send, a "not connected" error if the connection
// went away after sending, or any error from setting the deadline or reading.
func (c *websocketClient) Request(data []byte, timeout time.Duration) ([]byte, error) {
	if err := c.Send(data); err != nil {
		return nil, err
	}

	// Snapshot the connection and release the lock before the blocking read
	// so Disconnect/Close are not held up while the response is pending.
	c.mu.RLock()
	conn := c.conn
	connected := c.connected
	c.mu.RUnlock()

	if !connected || conn == nil {
		return nil, client.NewError("not connected")
	}

	if timeout > 0 {
		if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
			return nil, err
		}
	}

	_, resp, err := conn.ReadMessage()
	if err != nil {
		return nil, err
	}
	return resp, nil
}
// IsConnected reports whether the client currently considers itself
// connected, i.e. Connect has succeeded and Disconnect has not run since.
func (c *websocketClient) IsConnected() bool {
	c.mu.RLock()
	state := c.connected
	c.mu.RUnlock()
	return state
}
// Close shuts the client down. It simply delegates to Disconnect and returns
// that call's result.
func (c *websocketClient) Close() error {
	err := c.Disconnect()
	return err
}