diff --git a/app/cmd/client.go b/app/cmd/client.go index 2f127bb..2bbd5ea 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -310,7 +310,7 @@ func runClient(cmd *cobra.Command, args []string) { logger.Fatal("failed to load client config", zap.Error(err)) } - c, err := client.NewClient(hyConfig) + c, err := client.NewReconnectableClient(hyConfig, connectLog, false) if err != nil { logger.Fatal("failed to initialize client", zap.Error(err)) } @@ -487,6 +487,15 @@ func (f *obfsConnFactory) New(addr net.Addr) (net.PacketConn, error) { return obfs.WrapPacketConn(conn, f.Obfuscator), nil } +func connectLog(count int) { + if count == 1 { + logger.Info("connected to server") + } else { + // Not the first time, we have reconnected + logger.Info("reconnected to server", zap.Int("count", count)) + } +} + type socks5Logger struct{} func (l *socks5Logger) TCPRequest(addr net.Addr, reqAddr string) { diff --git a/core/client/client.go b/core/client/client.go index 11fe2f6..b269c4d 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -35,7 +35,7 @@ type HyUDPConn interface { } func NewClient(config *Config) (Client, error) { - if err := config.fill(); err != nil { + if err := config.verifyAndFill(); err != nil { return nil, err } c := &clientImpl{ diff --git a/core/client/config.go b/core/client/config.go index 2fad846..d6f366b 100644 --- a/core/client/config.go +++ b/core/client/config.go @@ -24,11 +24,16 @@ type Config struct { QUICConfig QUICConfig BandwidthConfig BandwidthConfig FastOpen bool + + filled bool // whether the fields have been verified and filled } -// fill fills the fields that are not set by the user with default values when possible, -// and returns an error if the user has not set a required field. -func (c *Config) fill() error { +// verifyAndFill fills the fields that are not set by the user with default values when possible, +// and returns an error if the user has not set a required field or has set an invalid value. +func (c *Config) verifyAndFill() error { + if c.filled { + return nil + } if c.ConnFactory == nil { c.ConnFactory = &udpConnFactory{} } @@ -66,6 +71,8 @@ func (c *Config) fill() error { return errors.ConfigError{Field: "QUICConfig.KeepAlivePeriod", Reason: "must be between 2s and 60s"} } c.QUICConfig.DisablePathMTUDiscovery = c.QUICConfig.DisablePathMTUDiscovery || pmtud.DisablePathMTUDiscovery + + c.filled = true return nil } diff --git a/core/client/reconnect.go b/core/client/reconnect.go new file mode 100644 index 0000000..08da63f --- /dev/null +++ b/core/client/reconnect.go @@ -0,0 +1,115 @@ +package client + +import ( + "net" + "sync" + + coreErrs "github.com/apernet/hysteria/core/errors" +) + +// reconnectableClientImpl is a wrapper of Client, which can reconnect when the connection is closed, +// except when the caller explicitly calls Close() to permanently close this client. +type reconnectableClientImpl struct { + config *Config + client Client + count int + connectedFunc func(int) // called when successfully connected + m sync.Mutex + closed bool // permanent close +} + +func NewReconnectableClient(config *Config, connectedFunc func(int), lazy bool) (Client, error) { + // Make sure we capture any error in config and return it here, + // so that the caller doesn't have to wait until the first call + // to TCP() or UDP() to get the error (when lazy is true). + if err := config.verifyAndFill(); err != nil { + return nil, err + } + rc := &reconnectableClientImpl{ + config: config, + connectedFunc: connectedFunc, + } + if !lazy { + if err := rc.reconnect(); err != nil { + return nil, err + } + } + return rc, nil +} + +func (rc *reconnectableClientImpl) reconnect() error { + if rc.client != nil { + _ = rc.client.Close() + } + var err error + rc.client, err = NewClient(rc.config) + if err != nil { + return err + } else { + rc.count++ + if rc.connectedFunc != nil { + rc.connectedFunc(rc.count) + } + return nil + } +} + +func (rc *reconnectableClientImpl) TCP(addr string) (net.Conn, error) { + rc.m.Lock() + defer rc.m.Unlock() + if rc.closed { + return nil, coreErrs.ClosedError{} + } + if rc.client == nil { + // First time + if err := rc.reconnect(); err != nil { + return nil, err + } + } + conn, err := rc.client.TCP(addr) + if _, ok := err.(coreErrs.ClosedError); ok { + // Connection closed, reconnect + if err := rc.reconnect(); err != nil { + return nil, err + } + return rc.client.TCP(addr) + } else { + // OK or some other temporary error + return conn, err + } +} + +func (rc *reconnectableClientImpl) UDP() (HyUDPConn, error) { + rc.m.Lock() + defer rc.m.Unlock() + if rc.closed { + return nil, coreErrs.ClosedError{} + } + if rc.client == nil { + // First time + if err := rc.reconnect(); err != nil { + return nil, err + } + } + conn, err := rc.client.UDP() + if _, ok := err.(coreErrs.ClosedError); ok { + // Connection closed, reconnect + if err := rc.reconnect(); err != nil { + return nil, err + } + return rc.client.UDP() + } else { + // OK or some other temporary error + return conn, err + } +} + +func (rc *reconnectableClientImpl) Close() error { + rc.m.Lock() + defer rc.m.Unlock() + rc.closed = true + if rc.client != nil { + return rc.client.Close() + } + return nil +}