You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

151 lines
2.6 KiB
Go

package client
import (
"net"
"sync"
"time"
"github.com/noahlann/nnet/pkg/client"
)
// unixClient is a client.Client implementation that communicates with a
// server over a Unix domain socket.
type unixClient struct {
	// mu is write-locked by Connect/Disconnect and read-locked by the other
	// methods while they inspect conn and connected.
	mu sync.RWMutex

	config *client.Config // connection settings; never nil after NewUnixClient

	// conn is the live socket; it is nil whenever connected is false.
	conn      net.Conn
	connected bool
}
// NewUnixClient creates a Unix domain socket client. A nil config is replaced
// with client.DefaultConfig(). The returned client is not yet connected; call
// Connect before sending or receiving.
func NewUnixClient(config *client.Config) client.Client {
	cfg := config
	if cfg == nil {
		cfg = client.DefaultConfig()
	}

	c := new(unixClient) // zero value: connected == false, conn == nil
	c.config = cfg
	return c
}
// Connect dials the Unix socket named by config.Addr. A leading "unix://"
// scheme is stripped before dialing (only when something follows it). If the
// client is already connected, Connect does nothing and returns nil. The dial
// is bounded by config.ConnectTimeout; per net.Dialer, zero means no timeout.
func (c *unixClient) Connect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.connected {
		return nil
	}

	const scheme = "unix://"
	path := c.config.Addr
	if len(path) > len(scheme) && path[:len(scheme)] == scheme {
		path = path[len(scheme):]
	}

	d := net.Dialer{Timeout: c.config.ConnectTimeout}
	conn, err := d.Dial("unix", path)
	if err != nil {
		return client.NewErrorf("failed to dial Unix socket: %v", err)
	}

	c.conn, c.connected = conn, true
	return nil
}
// Disconnect closes the underlying socket and marks the client disconnected.
// It returns nil without doing anything if the client is not connected.
//
// The client always ends up in the disconnected state. Any error from closing
// the socket is returned to the caller instead of being silently discarded.
func (c *unixClient) Disconnect() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.connected {
		return nil
	}

	var err error
	if c.conn != nil {
		err = c.conn.Close()
		c.conn = nil
	}
	c.connected = false
	return err
}
// Send writes data to the server. If the client is not connected, it returns
// a "not connected" error.
//
// When config.WriteTimeout is positive, the write must finish within that
// duration. Otherwise any earlier write deadline is cleared and the write may
// block.
//
// The lock is held only long enough to snapshot the connection. Holding it
// across a potentially blocking Write would stall Disconnect, which needs the
// write lock, until the write returned. With the lock released, Disconnect can
// close the connection, and that makes a pending Write fail promptly.
func (c *unixClient) Send(data []byte) error {
	c.mu.RLock()
	conn, connected := c.conn, c.connected
	timeout := c.config.WriteTimeout
	c.mu.RUnlock()

	if !connected || conn == nil {
		return client.NewError("not connected")
	}

	var deadline time.Time // zero value clears any previously set deadline
	if timeout > 0 {
		deadline = time.Now().Add(timeout)
	}
	if err := conn.SetWriteDeadline(deadline); err != nil {
		return err
	}

	_, err := conn.Write(data)
	return err
}
// Receive reads the next chunk of data from the server. It uses
// config.ReadTimeout as the read timeout; a non-positive value means no
// deadline.
//
// It performs a single Read of up to 4096 bytes. One call may therefore return
// only part of a larger message, or bytes from several messages; framing is
// left to the caller. The returned slice is a fresh copy owned by the caller.
// If the client is not connected, Receive returns a "not connected" error.
func (c *unixClient) Receive() ([]byte, error) {
	return c.receiveWithin(c.config.ReadTimeout)
}

// receiveWithin performs one Read on the current connection with the given
// timeout. A positive timeout sets a read deadline of now+timeout. A
// non-positive timeout clears any deadline left by an earlier call, so a
// stale (possibly expired) deadline cannot fail this read.
//
// The lock is held only while snapshotting the connection, not across the
// blocking Read. This lets Disconnect acquire the write lock and close the
// socket, which unblocks the Read with an error.
func (c *unixClient) receiveWithin(timeout time.Duration) ([]byte, error) {
	c.mu.RLock()
	conn, connected := c.conn, c.connected
	c.mu.RUnlock()

	if !connected || conn == nil {
		return nil, client.NewError("not connected")
	}

	var deadline time.Time // zero value clears any previously set deadline
	if timeout > 0 {
		deadline = time.Now().Add(timeout)
	}
	if err := conn.SetReadDeadline(deadline); err != nil {
		return nil, err
	}

	buf := make([]byte, 4096)
	n, err := conn.Read(buf)
	if err != nil {
		return nil, err
	}

	// Copy out so the returned slice does not pin the full 4 KiB buffer.
	out := make([]byte, n)
	copy(out, buf[:n])
	return out, nil
}

// Request sends data, then waits for a single response chunk. It uses
// timeout, rather than config.ReadTimeout, as the read timeout; a
// non-positive timeout means wait without a deadline.
//
// The timeout is applied directly to this read. The alternative, temporarily
// rewriting config.ReadTimeout, would race with concurrent Receive calls and
// would mutate a *client.Config the caller may share with other clients.
//
// NOTE(review): the send and the following read are not atomic. Concurrent
// Request calls on the same client may therefore receive each other's
// responses.
func (c *unixClient) Request(data []byte, timeout time.Duration) ([]byte, error) {
	if err := c.Send(data); err != nil {
		return nil, err
	}
	return c.receiveWithin(timeout)
}
// IsConnected reports whether Connect has succeeded and the client has not
// been disconnected since.
func (c *unixClient) IsConnected() bool {
	c.mu.RLock()
	connected := c.connected
	c.mu.RUnlock()
	return connected
}
// Close shuts down the client by delegating to Disconnect and returning its
// result. Its signature matches io.Closer.
func (c *unixClient) Close() error {
	err := c.Disconnect()
	return err
}