diff --git a/cmd/main.go b/cmd/main.go index dbaba52..0d5a052 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -131,7 +131,7 @@ func initApp(c *cli.Context) error { "version", "url", "config", "file", "mode", "addr", "src", "dst", "session", "action", - "error", + "msg", "error", }, TimestampFormat: c.String("log-timestamp"), }) diff --git a/cmd/server.go b/cmd/server.go index a659391..a4529bb 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -128,6 +128,20 @@ func server(config *serverConfig) { default: logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode") } + connectFunc := func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) { + ok, msg := authFunc(addr, auth, sSend, sRecv) + if !ok { + logrus.WithFields(logrus.Fields{ + "src": addr, + "msg": msg, + }).Info("Authentication failed, client rejected") + } else { + logrus.WithFields(logrus.Fields{ + "src": addr, + }).Info("Client connected") + } + return ok, msg + } // Obfuscator var obfuscator obfs.Obfuscator if len(config.Obfs) > 0 { @@ -169,7 +183,7 @@ func server(config *serverConfig) { uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, config.DisableUDP, aclEngine, obfuscator, authFunc, + }, config.DisableUDP, aclEngine, obfuscator, connectFunc, disconnectFunc, tcpRequestFunc, tcpErrorFunc, udpRequestFunc, udpErrorFunc, promReg) if err != nil { logrus.WithField("error", err).Fatal("Failed to initialize server") @@ -181,6 +195,13 @@ func server(config *serverConfig) { logrus.WithField("error", err).Fatal("Server shutdown") } +func disconnectFunc(addr net.Addr, auth []byte, err error) { + logrus.WithFields(logrus.Fields{ + "src": addr, + "error": err, + }).Info("Client disconnected") +} + func tcpRequestFunc(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) { logrus.WithFields(logrus.Fields{ "src": addr.String(), diff --git a/pkg/core/server.go b/pkg/core/server.go index 3be0a00..76b2565 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -14,7 +14,8 @@ import ( "net" ) -type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) +type ConnectFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) +type DisconnectFunc func(addr net.Addr, auth []byte, err error) type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error) type UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32) @@ -27,7 +28,8 @@ type Server struct { disableUDP bool aclEngine *acl.Engine - authFunc AuthFunc + connectFunc ConnectFunc + disconnectFunc DisconnectFunc tcpRequestFunc TCPRequestFunc tcpErrorFunc TCPErrorFunc udpRequestFunc UDPRequestFunc @@ -41,7 +43,8 @@ type Server struct { func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig *quic.Config, transport transport2.Transport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine, - obfuscator obfs.Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, + obfuscator obfs.Obfuscator, connectFunc ConnectFunc, disconnectFunc DisconnectFunc, + tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc, promRegistry *prometheus.Registry) (*Server, error) { pktConn, err := transport.QUICPacketConn(protocol, true, addr, "", obfuscator) if err != nil { @@ -60,7 +63,8 @@ func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig * congestionFactory: congestionFactory, disableUDP: disableUDP, aclEngine: aclEngine, - authFunc: authFunc, + connectFunc: connectFunc, + disconnectFunc: disconnectFunc, tcpRequestFunc: tcpRequestFunc, tcpErrorFunc: tcpErrorFunc, udpRequestFunc: udpRequestFunc, @@ -118,8 +122,9 @@ func (s *Server) handleClient(cs quic.Session) { sc := newServerClient(cs, s.transport, auth, s.disableUDP, s.aclEngine, s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc, s.upCounterVec, s.downCounterVec, s.connGaugeVec) - sc.Run() + err = sc.Run() _ = cs.CloseWithError(closeErrorCodeGeneric, "") + s.disconnectFunc(cs.RemoteAddr(), auth, err) } // Auth & negotiate speed @@ -151,7 +156,7 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byt serverRecvBPS = s.recvBPS } // Auth - ok, msg := s.authFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS) + ok, msg := s.connectFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS) // Response err = struc.Pack(stream, &serverHello{ OK: ok, diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index 50f49c4..857bb12 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -64,7 +64,7 @@ func newServerClient(cs quic.Session, transport transport.Transport, auth []byte return sc } -func (c *serverClient) Run() { +func (c *serverClient) Run() error { if !c.DisableUDP { go func() { for { @@ -79,7 +79,7 @@ func (c *serverClient) Run() { for { stream, err := c.CS.AcceptStream(context.Background()) if err != nil { - break + return err } if c.ConnGauge != nil { c.ConnGauge.Inc()