You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
151 lines
2.6 KiB
Go
151 lines
2.6 KiB
Go
package client
|
|
|
|
import (
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/noahlann/nnet/pkg/client"
|
|
)
|
|
|
|
// unixClient Unix Domain Socket客户端实现
|
|
type unixClient struct {
|
|
config *client.Config
|
|
conn net.Conn
|
|
mu sync.RWMutex
|
|
connected bool
|
|
}
|
|
|
|
// NewUnixClient 创建Unix客户端
|
|
func NewUnixClient(config *client.Config) client.Client {
|
|
if config == nil {
|
|
config = client.DefaultConfig()
|
|
}
|
|
|
|
return &unixClient{
|
|
config: config,
|
|
connected: false,
|
|
}
|
|
}
|
|
|
|
// Connect 连接服务器
|
|
func (c *unixClient) Connect() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.connected {
|
|
return nil
|
|
}
|
|
|
|
// 解析地址
|
|
addr := c.config.Addr
|
|
if len(addr) > 7 && addr[:7] == "unix://" {
|
|
addr = addr[7:]
|
|
}
|
|
|
|
// 创建Unix连接
|
|
dialer := &net.Dialer{
|
|
Timeout: c.config.ConnectTimeout,
|
|
}
|
|
|
|
conn, err := dialer.Dial("unix", addr)
|
|
if err != nil {
|
|
return client.NewErrorf("failed to dial Unix socket: %v", err)
|
|
}
|
|
|
|
c.conn = conn
|
|
c.connected = true
|
|
|
|
return nil
|
|
}
|
|
|
|
// Disconnect 断开连接
|
|
func (c *unixClient) Disconnect() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if !c.connected {
|
|
return nil
|
|
}
|
|
|
|
if c.conn != nil {
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
|
|
c.connected = false
|
|
return nil
|
|
}
|
|
|
|
// Send 发送数据
|
|
func (c *unixClient) Send(data []byte) error {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
if !c.connected || c.conn == nil {
|
|
return client.NewError("not connected")
|
|
}
|
|
|
|
if c.config.WriteTimeout > 0 {
|
|
c.conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout))
|
|
}
|
|
|
|
_, err := c.conn.Write(data)
|
|
return err
|
|
}
|
|
|
|
// Receive 接收数据
|
|
func (c *unixClient) Receive() ([]byte, error) {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
|
|
if !c.connected || c.conn == nil {
|
|
return nil, client.NewError("not connected")
|
|
}
|
|
|
|
if c.config.ReadTimeout > 0 {
|
|
c.conn.SetReadDeadline(time.Now().Add(c.config.ReadTimeout))
|
|
}
|
|
|
|
buffer := make([]byte, 4096)
|
|
n, err := c.conn.Read(buffer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make([]byte, n)
|
|
copy(result, buffer[:n])
|
|
return result, nil
|
|
}
|
|
|
|
// Request 请求-响应(带超时)
|
|
func (c *unixClient) Request(data []byte, timeout time.Duration) ([]byte, error) {
|
|
// 发送请求
|
|
if err := c.Send(data); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 设置读取超时
|
|
oldTimeout := c.config.ReadTimeout
|
|
c.config.ReadTimeout = timeout
|
|
defer func() {
|
|
c.config.ReadTimeout = oldTimeout
|
|
}()
|
|
|
|
// 接收响应
|
|
return c.Receive()
|
|
}
|
|
|
|
// IsConnected 检查是否已连接
|
|
func (c *unixClient) IsConnected() bool {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
return c.connected
|
|
}
|
|
|
|
// Close 关闭客户端
|
|
func (c *unixClient) Close() error {
|
|
return c.Disconnect()
|
|
}
|
|
|