mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-07-20 02:18:02 +00:00
feat: TProxy
This commit is contained in:
138
app/internal/tproxy/udp_linux.go
Normal file
138
app/internal/tproxy/udp_linux.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package tproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/apernet/go-tproxy"
|
||||
"github.com/apernet/hysteria/core/client"
|
||||
)
|
||||
|
||||
const (
|
||||
udpBufferSize = 4096
|
||||
defaultTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
type UDPTProxy struct {
|
||||
HyClient client.Client
|
||||
Timeout time.Duration
|
||||
EventLogger UDPEventLogger
|
||||
}
|
||||
|
||||
type UDPEventLogger interface {
|
||||
Connect(addr, reqAddr net.Addr)
|
||||
Error(addr, reqAddr net.Addr, err error)
|
||||
}
|
||||
|
||||
func (r *UDPTProxy) ListenAndServe(laddr *net.UDPAddr) error {
|
||||
conn, err := tproxy.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, udpBufferSize)
|
||||
for {
|
||||
// We will only get the first packet of each src/dst pair here,
|
||||
// because newPair will create a TProxy connection and take over
|
||||
// the src/dst pair. Later packets will be sent there instead of here.
|
||||
n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(conn, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.newPair(srcAddr, dstAddr, buf[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UDPTProxy) newPair(srcAddr, dstAddr *net.UDPAddr, initPkt []byte) {
|
||||
if r.EventLogger != nil {
|
||||
r.EventLogger.Connect(srcAddr, dstAddr)
|
||||
}
|
||||
var closeErr error
|
||||
defer func() {
|
||||
// If closeErr is nil, it means we at least successfully sent the first packet
|
||||
// and started forwarding, in which case we don't call the error logger.
|
||||
if r.EventLogger != nil && closeErr != nil {
|
||||
r.EventLogger.Error(srcAddr, dstAddr, closeErr)
|
||||
}
|
||||
}()
|
||||
conn, err := tproxy.DialUDP("udp", dstAddr, srcAddr)
|
||||
if err != nil {
|
||||
closeErr = err
|
||||
return
|
||||
}
|
||||
hyConn, err := r.HyClient.UDP()
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
closeErr = err
|
||||
return
|
||||
}
|
||||
// Send the first packet
|
||||
err = hyConn.Send(initPkt, dstAddr.String())
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
_ = hyConn.Close()
|
||||
closeErr = err
|
||||
return
|
||||
}
|
||||
// Start forwarding
|
||||
go r.forwarding(conn, hyConn, dstAddr.String())
|
||||
}
|
||||
|
||||
func (r *UDPTProxy) forwarding(conn *net.UDPConn, hyConn client.HyUDPConn, dst string) {
|
||||
errChan := make(chan error, 2)
|
||||
// Local <- Remote
|
||||
go func() {
|
||||
for {
|
||||
bs, _, err := hyConn.Receive()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(bs)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
_ = r.updateConnDeadline(conn)
|
||||
}
|
||||
}()
|
||||
// Local -> Remote
|
||||
go func() {
|
||||
buf := make([]byte, udpBufferSize)
|
||||
for {
|
||||
_ = r.updateConnDeadline(conn)
|
||||
n, err := conn.Read(buf)
|
||||
if n > 0 {
|
||||
err := hyConn.Send(buf[:n], dst)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
err := <-errChan
|
||||
_ = conn.Close()
|
||||
_ = hyConn.Close()
|
||||
if r.EventLogger != nil {
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
// We don't consider deadline exceeded (timeout) an error
|
||||
err = nil
|
||||
}
|
||||
r.EventLogger.Error(conn.LocalAddr(), conn.RemoteAddr(), err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *UDPTProxy) updateConnDeadline(conn *net.UDPConn) error {
|
||||
if r.Timeout == 0 {
|
||||
return conn.SetReadDeadline(time.Now().Add(defaultTimeout))
|
||||
} else {
|
||||
return conn.SetReadDeadline(time.Now().Add(r.Timeout))
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user