diff --git a/app/cmd/client.go b/app/cmd/client.go index b20facb..9a785d2 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -179,7 +179,7 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error { if hyConfig.ServerAddr.Network() == "udphop" { hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr) newFunc = func(addr net.Addr) (net.PacketConn, error) { - return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval) + return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, nil) } } else { newFunc = func(addr net.Addr) (net.PacketConn, error) { diff --git a/extras/transport/udphop/conn.go b/extras/transport/udphop/conn.go index ccb0b38..f20c583 100644 --- a/extras/transport/udphop/conn.go +++ b/extras/transport/udphop/conn.go @@ -17,9 +17,10 @@ const ( ) type udpHopPacketConn struct { - Addr net.Addr - Addrs []net.Addr - HopInterval time.Duration + Addr net.Addr + Addrs []net.Addr + HopInterval time.Duration + ListenUDPFunc ListenUDPFunc connMutex sync.RWMutex prevConn net.PacketConn @@ -43,29 +44,37 @@ type udpPacket struct { Err error } -func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration) (net.PacketConn, error) { +type ListenUDPFunc func() (net.PacketConn, error) + +func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) { if hopInterval == 0 { hopInterval = defaultHopInterval } else if hopInterval < 5*time.Second { return nil, errors.New("hop interval must be at least 5 seconds") } + if listenUDPFunc == nil { + listenUDPFunc = func() (net.PacketConn, error) { + return net.ListenUDP("udp", nil) + } + } addrs, err := addr.addrs() if err != nil { return nil, err } - curConn, err := net.ListenUDP("udp", nil) + curConn, err := listenUDPFunc() if err != nil { return nil, err } hConn := &udpHopPacketConn{ - Addr: addr, - Addrs: addrs, - HopInterval: hopInterval, - prevConn: nil, - currentConn: curConn, - addrIndex: rand.Intn(len(addrs)), - recvQueue: make(chan *udpPacket, packetQueueSize), - closeChan: make(chan struct{}), + Addr: addr, + Addrs: addrs, + HopInterval: hopInterval, + ListenUDPFunc: listenUDPFunc, + prevConn: nil, + currentConn: curConn, + addrIndex: rand.Intn(len(addrs)), + recvQueue: make(chan *udpPacket, packetQueueSize), + closeChan: make(chan struct{}), bufPool: sync.Pool{ New: func() interface{} { return make([]byte, udpBufferSize) @@ -121,7 +130,7 @@ func (u *udpHopPacketConn) hop() { if u.closed { return } - newConn, err := net.ListenUDP("udp", nil) + newConn, err := u.ListenUDPFunc() if err != nil { // Could be temporary, just skip this hop return