You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
2.8 KiB
Go
155 lines
2.8 KiB
Go
package client
|
|
|
|
import (
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/noahlann/nnet/pkg/client"
|
|
)
|
|
|
|
// websocketClient WebSocket客户端实现
|
|
type websocketClient struct {
|
|
config *client.Config
|
|
conn *websocket.Conn
|
|
mu sync.RWMutex
|
|
connected bool
|
|
}
|
|
|
|
// NewWebSocketClient 创建WebSocket客户端
|
|
func NewWebSocketClient(config *client.Config) client.Client {
|
|
if config == nil {
|
|
config = client.DefaultConfig()
|
|
}
|
|
|
|
return &websocketClient{
|
|
config: config,
|
|
connected: false,
|
|
}
|
|
}
|
|
|
|
// Connect 连接服务器
|
|
func (c *websocketClient) Connect() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.connected {
|
|
return nil
|
|
}
|
|
|
|
// 解析地址
|
|
addr := c.config.Addr
|
|
scheme := "ws"
|
|
if len(addr) > 6 && addr[:6] == "ws://" {
|
|
addr = addr[6:]
|
|
scheme = "ws"
|
|
} else if len(addr) > 7 && addr[:7] == "wss://" {
|
|
addr = addr[7:]
|
|
scheme = "wss"
|
|
} else if c.config.TLSEnabled {
|
|
scheme = "wss"
|
|
}
|
|
|
|
// 构建URL
|
|
u := url.URL{Scheme: scheme, Host: addr, Path: "/"}
|
|
|
|
// 建立WebSocket连接
|
|
dialer := websocket.Dialer{
|
|
HandshakeTimeout: c.config.ConnectTimeout,
|
|
}
|
|
|
|
conn, _, err := dialer.Dial(u.String(), nil)
|
|
if err != nil {
|
|
return client.NewErrorf("failed to dial WebSocket: %v", err)
|
|
}
|
|
|
|
c.conn = conn
|
|
c.connected = true
|
|
|
|
return nil
|
|
}
|
|
|
|
// Disconnect 断开连接
|
|
func (c *websocketClient) Disconnect() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if !c.connected {
|
|
return nil
|
|
}
|
|
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
|
|
c.connected = false
|
|
return nil
|
|
}
|
|
|
|
// Send 发送数据
|
|
func (c *websocketClient) Send(data []byte) error {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
if !c.connected || c.conn == nil {
|
|
return client.NewError("not connected")
|
|
}
|
|
|
|
return c.conn.WriteMessage(websocket.TextMessage, data)
|
|
}
|
|
|
|
// Receive 接收数据
|
|
func (c *websocketClient) Receive() ([]byte, error) {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
if !c.connected || c.conn == nil {
|
|
return nil, client.NewError("not connected")
|
|
}
|
|
|
|
// 设置读取超时
|
|
if c.config.ReadTimeout > 0 {
|
|
c.conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout))
|
|
}
|
|
|
|
_, data, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
// Request 请求-响应(带超时)
|
|
func (c *websocketClient) Request(data []byte, timeout time.Duration) ([]byte, error) {
|
|
// 发送请求
|
|
if err := c.Send(data); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 设置读取超时
|
|
oldTimeout := c.config.ReadTimeout
|
|
c.config.ReadTimeout = timeout
|
|
defer func() {
|
|
c.config.ReadTimeout = oldTimeout
|
|
}()
|
|
|
|
// 接收响应
|
|
return c.Receive()
|
|
}
|
|
|
|
// IsConnected 检查是否已连接
|
|
func (c *websocketClient) IsConnected() bool {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
return c.connected
|
|
}
|
|
|
|
// Close 关闭客户端
|
|
func (c *websocketClient) Close() error {
|
|
return c.Disconnect()
|
|
}
|
|
|