diff --git a/pkg/core/client.go b/pkg/core/client.go index 70ae934..fc099d8 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -76,10 +76,7 @@ func (c *Client) connectToServer() error { return err } if c.obfuscator != nil { - pktConn = &obfsUDPConn{ - Orig: udpConn, - Obfuscator: c.obfuscator, - } + pktConn = newObfsUDPConn(udpConn, c.obfuscator) } else { pktConn = udpConn } @@ -89,10 +86,7 @@ func (c *Client) connectToServer() error { return err } if c.obfuscator != nil { - pktConn = &obfsPacketConn{ - Orig: ftcpConn, - Obfuscator: c.obfuscator, - } + pktConn = newObfsPacketConn(ftcpConn, c.obfuscator) } else { pktConn = ftcpConn } diff --git a/pkg/core/obfs.go b/pkg/core/obfs.go index 419a31a..d67b2b9 100644 --- a/pkg/core/obfs.go +++ b/pkg/core/obfs.go @@ -3,41 +3,60 @@ package core import ( "net" "os" + "sync" "syscall" "time" ) type Obfuscator interface { Deobfuscate(in []byte, out []byte) int - Obfuscate(p []byte) []byte + Obfuscate(in []byte, out []byte) int } type obfsUDPConn struct { - Orig *net.UDPConn - Obfuscator Obfuscator + orig *net.UDPConn + obfs Obfuscator + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func newObfsUDPConn(orig *net.UDPConn, obfs Obfuscator) *obfsUDPConn { + return &obfsUDPConn{ + orig: orig, + obfs: obfs, + readBuf: make([]byte, udpBufferSize), + writeBuf: make([]byte, udpBufferSize), + } } func (c *obfsUDPConn) ReadFrom(p []byte) (int, net.Addr, error) { - buf := make([]byte, udpBufferSize) for { - n, addr, err := c.Orig.ReadFrom(buf) + c.readMutex.Lock() + n, addr, err := c.orig.ReadFrom(c.readBuf) if n <= 0 { + c.readMutex.Unlock() return 0, addr, err } - newN := c.Obfuscator.Deobfuscate(buf[:n], p) + newN := c.obfs.Deobfuscate(c.readBuf[:n], p) + c.readMutex.Unlock() if newN > 0 { // Valid packet return newN, addr, err } else if err != nil { - // Not valid and Orig.ReadFrom had some error + // Not valid and orig.ReadFrom had some error return 0, addr, err } } } func (c *obfsUDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - np := c.Obfuscator.Obfuscate(p) - _, err = c.Orig.WriteTo(np, addr) + c.writeMutex.Lock() + bn := c.obfs.Obfuscate(p, c.writeBuf) + _, err = c.orig.WriteTo(c.writeBuf[:bn], addr) + c.writeMutex.Unlock() if err != nil { return 0, err } else { @@ -46,67 +65,85 @@ func (c *obfsUDPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } func (c *obfsUDPConn) Close() error { - return c.Orig.Close() + return c.orig.Close() } func (c *obfsUDPConn) LocalAddr() net.Addr { - return c.Orig.LocalAddr() + return c.orig.LocalAddr() } func (c *obfsUDPConn) SetDeadline(t time.Time) error { - return c.Orig.SetDeadline(t) + return c.orig.SetDeadline(t) } func (c *obfsUDPConn) SetReadDeadline(t time.Time) error { - return c.Orig.SetReadDeadline(t) + return c.orig.SetReadDeadline(t) } func (c *obfsUDPConn) SetWriteDeadline(t time.Time) error { - return c.Orig.SetWriteDeadline(t) + return c.orig.SetWriteDeadline(t) } func (c *obfsUDPConn) SetReadBuffer(bytes int) error { - return c.Orig.SetReadBuffer(bytes) + return c.orig.SetReadBuffer(bytes) } func (c *obfsUDPConn) SetWriteBuffer(bytes int) error { - return c.Orig.SetWriteBuffer(bytes) + return c.orig.SetWriteBuffer(bytes) } func (c *obfsUDPConn) SyscallConn() (syscall.RawConn, error) { - return c.Orig.SyscallConn() + return c.orig.SyscallConn() } func (c *obfsUDPConn) File() (f *os.File, err error) { - return c.Orig.File() + return c.orig.File() } type obfsPacketConn struct { - Orig net.PacketConn - Obfuscator Obfuscator + orig net.PacketConn + obfs Obfuscator + + readBuf []byte + readMutex sync.Mutex + writeBuf []byte + writeMutex sync.Mutex +} + +func newObfsPacketConn(orig net.PacketConn, obfs Obfuscator) *obfsPacketConn { + return &obfsPacketConn{ + orig: orig, + obfs: obfs, + readBuf: make([]byte, udpBufferSize), + writeBuf: make([]byte, udpBufferSize), + } } func (c *obfsPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { - buf := make([]byte, udpBufferSize) for { - n, addr, err := c.Orig.ReadFrom(buf) + c.readMutex.Lock() + n, addr, err := c.orig.ReadFrom(c.readBuf) if n <= 0 { + c.readMutex.Unlock() return 0, addr, err } - newN := c.Obfuscator.Deobfuscate(buf[:n], p) + newN := c.obfs.Deobfuscate(c.readBuf[:n], p) + c.readMutex.Unlock() if newN > 0 { // Valid packet return newN, addr, err } else if err != nil { - // Not valid and Orig.ReadFrom had some error + // Not valid and orig.ReadFrom had some error return 0, addr, err } } } func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - np := c.Obfuscator.Obfuscate(p) - _, err = c.Orig.WriteTo(np, addr) + c.writeMutex.Lock() + bn := c.obfs.Obfuscate(p, c.writeBuf) + _, err = c.orig.WriteTo(c.writeBuf[:bn], addr) + c.writeMutex.Unlock() if err != nil { return 0, err } else { @@ -115,21 +152,21 @@ func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { } func (c *obfsPacketConn) Close() error { - return c.Orig.Close() + return c.orig.Close() } func (c *obfsPacketConn) LocalAddr() net.Addr { - return c.Orig.LocalAddr() + return c.orig.LocalAddr() } func (c *obfsPacketConn) SetDeadline(t time.Time) error { - return c.Orig.SetDeadline(t) + return c.orig.SetDeadline(t) } func (c *obfsPacketConn) SetReadDeadline(t time.Time) error { - return c.Orig.SetReadDeadline(t) + return c.orig.SetReadDeadline(t) } func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error { - return c.Orig.SetWriteDeadline(t) + return c.orig.SetWriteDeadline(t) } diff --git a/pkg/core/server.go b/pkg/core/server.go index 79a1cb8..4887be5 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -53,10 +53,7 @@ func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig * return nil, err } if obfuscator != nil { - pktConn = &obfsUDPConn{ - Orig: udpConn, - Obfuscator: obfuscator, - } + pktConn = newObfsUDPConn(udpConn, obfuscator) } else { pktConn = udpConn } @@ -66,10 +63,7 @@ func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig * return nil, err } if obfuscator != nil { - pktConn = &obfsPacketConn{ - Orig: ftcpConn, - Obfuscator: obfuscator, - } + pktConn = newObfsPacketConn(ftcpConn, obfuscator) } else { pktConn = ftcpConn } diff --git a/pkg/obfs/xor.go b/pkg/obfs/xor.go deleted file mode 100644 index 1cee94e..0000000 --- a/pkg/obfs/xor.go +++ /dev/null @@ -1,20 +0,0 @@ -package obfs - -type XORObfuscator []byte - -func (x XORObfuscator) Deobfuscate(in []byte, out []byte) int { - l := len(x) - for i := range in { - out[i] = in[i] ^ x[i%l] - } - return len(in) -} - -func (x XORObfuscator) Obfuscate(p []byte) []byte { - np := make([]byte, len(p)) - l := len(x) - for i := range p { - np[i] = p[i] ^ x[i%l] - } - return np -} diff --git a/pkg/obfs/xplus.go b/pkg/obfs/xplus.go index b1484d6..dd63645 100644 --- a/pkg/obfs/xplus.go +++ b/pkg/obfs/xplus.go @@ -39,16 +39,14 @@ func (x *XPlusObfuscator) Deobfuscate(in []byte, out []byte) int { return pLen } -func (x *XPlusObfuscator) Obfuscate(p []byte) []byte { - pLen := len(p) - buf := make([]byte, saltLen+pLen) +func (x *XPlusObfuscator) Obfuscate(in []byte, out []byte) int { x.lk.Lock() - _, _ = x.RandSrc.Read(buf[:saltLen]) // salt + _, _ = x.RandSrc.Read(out[:saltLen]) // salt x.lk.Unlock() // Obfuscate the payload - key := sha256.Sum256(append(x.Key, buf[:saltLen]...)) - for i, c := range p { - buf[i+saltLen] = c ^ key[i%sha256.Size] + key := sha256.Sum256(append(x.Key, out[:saltLen]...)) + for i, c := range in { + out[i+saltLen] = c ^ key[i%sha256.Size] } - return buf + return len(in) + saltLen } diff --git a/pkg/obfs/xplus_test.go b/pkg/obfs/xplus_test.go index baff4c2..c1cf629 100644 --- a/pkg/obfs/xplus_test.go +++ b/pkg/obfs/xplus_test.go @@ -20,11 +20,11 @@ func TestXPlusObfuscator(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - bs := x.Obfuscate(tt.p) - outBs := make([]byte, len(bs)) - n := x.Deobfuscate(bs, outBs) - if !bytes.Equal(tt.p, outBs[:n]) { - t.Errorf("Inconsistent deobfuscate result: got %v, want %v", outBs[:n], tt.p) + buf := make([]byte, 10240) + n := x.Obfuscate(tt.p, buf) + n2 := x.Deobfuscate(buf[:n], buf[n:]) + if !bytes.Equal(tt.p, buf[n:n+n2]) { + t.Errorf("Inconsistent deobfuscate result: got %v, want %v", buf[n:n+n2], tt.p) } }) }