package client import ( "net" "sync" "time" "github.com/noahlann/nnet/pkg/client" ) // unixClient Unix Domain Socket客户端实现 type unixClient struct { config *client.Config conn net.Conn mu sync.RWMutex connected bool } // NewUnixClient 创建Unix客户端 func NewUnixClient(config *client.Config) client.Client { if config == nil { config = client.DefaultConfig() } return &unixClient{ config: config, connected: false, } } // Connect 连接服务器 func (c *unixClient) Connect() error { c.mu.Lock() defer c.mu.Unlock() if c.connected { return nil } // 解析地址 addr := c.config.Addr if len(addr) > 7 && addr[:7] == "unix://" { addr = addr[7:] } // 创建Unix连接 dialer := &net.Dialer{ Timeout: c.config.ConnectTimeout, } conn, err := dialer.Dial("unix", addr) if err != nil { return client.NewErrorf("failed to dial Unix socket: %v", err) } c.conn = conn c.connected = true return nil } // Disconnect 断开连接 func (c *unixClient) Disconnect() error { c.mu.Lock() defer c.mu.Unlock() if !c.connected { return nil } if c.conn != nil { c.conn.Close() c.conn = nil } c.connected = false return nil } // Send 发送数据 func (c *unixClient) Send(data []byte) error { c.mu.RLock() defer c.mu.RUnlock() if !c.connected || c.conn == nil { return client.NewError("not connected") } if c.config.WriteTimeout > 0 { c.conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)) } _, err := c.conn.Write(data) return err } // Receive 接收数据 func (c *unixClient) Receive() ([]byte, error) { c.mu.RLock() defer c.mu.RUnlock() if !c.connected || c.conn == nil { return nil, client.NewError("not connected") } if c.config.ReadTimeout > 0 { c.conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout)) } buffer := make([]byte, 4096) n, err := c.conn.Read(buffer) if err != nil { return nil, err } result := make([]byte, n) copy(result, buffer[:n]) return result, nil } // Request 请求-响应(带超时) func (c *unixClient) Request(data []byte, timeout time.Duration) ([]byte, error) { // 发送请求 if err := c.Send(data); err != nil { return nil, err } // 设置读取超时 oldTimeout := c.config.ReadTimeout c.config.ReadTimeout = timeout defer func() { c.config.ReadTimeout = oldTimeout }() // 接收响应 return c.Receive() } // IsConnected 检查是否已连接 func (c *unixClient) IsConnected() bool { c.mu.RLock() defer c.mu.RUnlock() return c.connected } // Close 关闭客户端 func (c *unixClient) Close() error { return c.Disconnect() }