Merge pull request #953 from apernet/wip-udphop-listenudpfunc

feat: allow set ListenUDP impl for udphop conn
This commit is contained in:
Toby 2024-02-29 16:17:40 -08:00 committed by GitHub
commit 982be5498b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 24 additions and 15 deletions

View File

@ -179,7 +179,7 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error {
if hyConfig.ServerAddr.Network() == "udphop" { if hyConfig.ServerAddr.Network() == "udphop" {
hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr) hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr)
newFunc = func(addr net.Addr) (net.PacketConn, error) { newFunc = func(addr net.Addr) (net.PacketConn, error) {
return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval) return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, nil)
} }
} else { } else {
newFunc = func(addr net.Addr) (net.PacketConn, error) { newFunc = func(addr net.Addr) (net.PacketConn, error) {

View File

@ -17,9 +17,10 @@ const (
) )
type udpHopPacketConn struct { type udpHopPacketConn struct {
Addr net.Addr Addr net.Addr
Addrs []net.Addr Addrs []net.Addr
HopInterval time.Duration HopInterval time.Duration
ListenUDPFunc ListenUDPFunc
connMutex sync.RWMutex connMutex sync.RWMutex
prevConn net.PacketConn prevConn net.PacketConn
@ -43,29 +44,37 @@ type udpPacket struct {
Err error Err error
} }
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) { type ListenUDPFunc func() (net.PacketConn, error)
func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) {
if hopInterval == 0 { if hopInterval == 0 {
hopInterval = defaultHopInterval hopInterval = defaultHopInterval
} else if hopInterval < 5*time.Second { } else if hopInterval < 5*time.Second {
return nil, errors.New("hop interval must be at least 5 seconds") return nil, errors.New("hop interval must be at least 5 seconds")
} }
if listenUDPFunc == nil {
listenUDPFunc = func() (net.PacketConn, error) {
return net.ListenUDP("udp", nil)
}
}
addrs, err := addr.addrs() addrs, err := addr.addrs()
if err != nil { if err != nil {
return nil, err return nil, err
} }
curConn, err := net.ListenUDP("udp", nil) curConn, err := listenUDPFunc()
if err != nil { if err != nil {
return nil, err return nil, err
} }
hConn := &udpHopPacketConn{ hConn := &udpHopPacketConn{
Addr: addr, Addr: addr,
Addrs: addrs, Addrs: addrs,
HopInterval: hopInterval, HopInterval: hopInterval,
prevConn: nil, ListenUDPFunc: listenUDPFunc,
currentConn: curConn, prevConn: nil,
addrIndex: rand.Intn(len(addrs)), currentConn: curConn,
recvQueue: make(chan *udpPacket, packetQueueSize), addrIndex: rand.Intn(len(addrs)),
closeChan: make(chan struct{}), recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}),
bufPool: sync.Pool{ bufPool: sync.Pool{
New: func() interface{} { New: func() interface{} {
return make([]byte, udpBufferSize) return make([]byte, udpBufferSize)
@ -121,7 +130,7 @@ func (u *udpHopPacketConn) hop() {
if u.closed { if u.closed {
return return
} }
newConn, err := net.ListenUDP("udp", nil) newConn, err := u.ListenUDPFunc()
if err != nil { if err != nil {
// Could be temporary, just skip this hop // Could be temporary, just skip this hop
return return