package client

import (
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/gorilla/websocket"
	"github.com/noahlann/nnet/pkg/client"
)

// websocketClient is a client.Client implementation backed by a single
// gorilla/websocket connection.
type websocketClient struct {
	config *client.Config
	conn   *websocket.Conn

	// mu guards conn and connected.
	mu sync.RWMutex
	// writeMu and readMu serialize writers and readers respectively: a
	// websocket.Conn supports at most one concurrent writer and one
	// concurrent reader.
	writeMu sync.Mutex
	readMu  sync.Mutex

	connected bool
}

// NewWebSocketClient creates a WebSocket client. A nil config is replaced
// with client.DefaultConfig(). The returned client is not connected until
// Connect is called.
func NewWebSocketClient(config *client.Config) client.Client {
	if config == nil {
		config = client.DefaultConfig()
	}
	return &websocketClient{config: config}
}

// websocketDialURL turns a configured address into the URL to dial.
//
// addr may be a bare "host:port" or a full "ws://..." / "wss://..." URL. An
// explicit scheme in addr always wins; otherwise the scheme is "wss" when
// tlsEnabled is true and "ws" when it is false. An empty path is normalized
// to "/".
func websocketDialURL(addr string, tlsEnabled bool) (string, error) {
	raw := addr
	if !strings.HasPrefix(addr, "ws://") && !strings.HasPrefix(addr, "wss://") {
		scheme := "ws"
		if tlsEnabled {
			scheme = "wss"
		}
		raw = scheme + "://" + addr
	}

	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}
	if u.Path == "" {
		u.Path = "/"
	}
	return u.String(), nil
}

// Connect dials the server configured in config.Addr, using
// config.ConnectTimeout as the handshake timeout. It is a no-op when the
// client is already connected.
func (c *websocketClient) Connect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		return nil
	}

	target, err := websocketDialURL(c.config.Addr, c.config.TLSEnabled)
	if err != nil {
		return client.NewErrorf("invalid WebSocket address %q: %v", c.config.Addr, err)
	}

	dialer := websocket.Dialer{
		HandshakeTimeout: c.config.ConnectTimeout,
	}
	conn, _, err := dialer.Dial(target, nil)
	if err != nil {
		return client.NewErrorf("failed to dial WebSocket: %v", err)
	}

	c.conn = conn
	c.connected = true
	return nil
}

// Disconnect closes the connection, if any, and marks the client as
// disconnected. It returns the error reported by the underlying Close call;
// the client is considered disconnected either way. Calling it on a client
// that is not connected is a no-op.
func (c *websocketClient) Disconnect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.connected {
		return nil
	}

	var err error
	if c.conn != nil {
		err = c.conn.Close()
		c.conn = nil
	}
	c.connected = false
	return err
}

// activeConn returns the current connection, or a "not connected" error when
// there is none. The state lock is held only for the lookup so that blocking
// I/O on the returned connection never stalls Disconnect or IsConnected.
func (c *websocketClient) activeConn() (*websocket.Conn, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if !c.connected || c.conn == nil {
		return nil, client.NewError("not connected")
	}
	return c.conn, nil
}

// Send writes data to the server as a single text message. Concurrent calls
// are serialized.
func (c *websocketClient) Send(data []byte) error {
	conn, err := c.activeConn()
	if err != nil {
		return err
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	return conn.WriteMessage(websocket.TextMessage, data)
}

// Receive reads the next message from the server, waiting at most
// config.ReadTimeout when that value is positive and indefinitely otherwise.
func (c *websocketClient) Receive() ([]byte, error) {
	return c.receive(c.config.ReadTimeout)
}

// receive reads the next message, waiting at most timeout. A non-positive
// timeout clears any previously installed read deadline, so a deadline left
// over from an earlier call cannot expire this one.
func (c *websocketClient) receive(timeout time.Duration) ([]byte, error) {
	conn, err := c.activeConn()
	if err != nil {
		return nil, err
	}

	c.readMu.Lock()
	defer c.readMu.Unlock()

	var deadline time.Time // the zero value means "no deadline"
	if timeout > 0 {
		deadline = time.Now().Add(timeout)
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		return nil, err
	}

	_, data, err := conn.ReadMessage()
	if err != nil {
		return nil, err
	}
	return data, nil
}

// Request sends data and then reads the next incoming message, waiting at
// most timeout for it (indefinitely when timeout is not positive). The shared
// config is left untouched. The response is simply the next message read from
// the connection; no request/response correlation is performed here.
func (c *websocketClient) Request(data []byte, timeout time.Duration) ([]byte, error) {
	if err := c.Send(data); err != nil {
		return nil, err
	}
	return c.receive(timeout)
}

// IsConnected reports whether the client currently holds an open connection.
func (c *websocketClient) IsConnected() bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.connected
}

// Close shuts the client down; it is equivalent to Disconnect.
func (c *websocketClient) Close() error {
	return c.Disconnect()
}