You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

103 lines
2.8 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package nnet
import (
"github.com/noahlann/nnet/internal/client"
clientpkg "github.com/noahlann/nnet/pkg/client"
)
// Client 客户端类型别名
type Client = clientpkg.Client
// ClientConfig 客户端配置类型别名
type ClientConfig = clientpkg.Config
// ClientTLSConfig 客户端TLS配置类型别名
type ClientTLSConfig = clientpkg.TLSConfig
// ClientPool 客户端连接池类型别名
type ClientPool = client.Pool
// ClientPoolConfig 客户端连接池配置类型别名
type ClientPoolConfig = client.PoolConfig
// NewClientPool 创建客户端连接池
func NewClientPool(config *ClientPoolConfig) (*ClientPool, error) {
return client.NewPool(config)
}
// DefaultClientPoolConfig 返回默认客户端连接池配置
func DefaultClientPoolConfig() *ClientPoolConfig {
return client.DefaultPoolConfig()
}
// NewClient 创建客户端(根据配置或地址自动选择传输层协议)
// 传输层协议优先级:
// 1. 如果配置中指定了TransportProtocol使用指定的协议
// 2. 否则根据地址前缀判断tcp://, udp://, ws://, wss://, unix://, serial://
// 3. 如果无法识别默认使用TCP
// 应用层协议通过 ApplicationProtocol 配置指定如nnet
func NewClient(config *ClientConfig) Client {
if config == nil {
config = clientpkg.DefaultConfig()
}
// 优先使用配置中指定的传输层协议
if config.TransportProtocol != "" {
return newClientByProtocol(config, config.TransportProtocol)
}
// 根据地址前缀判断传输层协议
return newClientByAddr(config)
}
// newClientByProtocol 根据指定的传输层协议创建客户端
func newClientByProtocol(config *ClientConfig, protocol string) Client {
switch protocol {
case "tcp":
return client.NewTCPClient(config)
case "udp":
return client.NewUDPClient(config)
case "websocket", "ws", "wss":
return client.NewWebSocketClient(config)
case "unix":
return client.NewUnixClient(config)
case "serial":
return client.NewSerialClient(config)
default:
// 默认使用TCP客户端
return client.NewTCPClient(config)
}
}
// newClientByAddr 根据地址前缀判断传输层协议
func newClientByAddr(config *ClientConfig) Client {
addr := config.Addr
if len(addr) >= 8 {
if addr[:8] == "serial://" {
return client.NewSerialClient(config)
}
if len(addr) >= 7 && addr[:7] == "serial:" {
return client.NewSerialClient(config)
}
}
if len(addr) >= 6 {
if addr[:6] == "wss://" {
return client.NewWebSocketClient(config)
}
}
if len(addr) >= 5 {
if addr[:5] == "ws://" {
return client.NewWebSocketClient(config)
}
if addr[:5] == "udp://" {
return client.NewUDPClient(config)
}
if addr[:5] == "unix:" {
return client.NewUnixClient(config)
}
}
// 默认使用TCP客户端
return client.NewTCPClient(config)
}