From d320aadb54fd19764d0f3404ef30468b01b40f03 Mon Sep 17 00:00:00 2001
From: thank243 <thank243@gmail.com>
Date: Mon, 28 Nov 2022 11:16:27 +0800
Subject: [PATCH] fix: global limit will failure on some situation (#124)

fix: typo
---
 app/mydispatcher/default.go      |  2 +-
 common/limiter/limiter.go        | 11 +++----
 common/mylego/model.go           | 16 +++++-----
 main/main.go                     |  2 +-
 service/controller/control.go    | 12 +++----
 service/controller/controller.go | 54 ++++++++++++++++++++++++++------
 6 files changed, 65 insertions(+), 32 deletions(-)

diff --git a/app/mydispatcher/default.go b/app/mydispatcher/default.go
index 0c4a8ba..a7efa63 100644
--- a/app/mydispatcher/default.go
+++ b/app/mydispatcher/default.go
@@ -234,7 +234,7 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sn
 		// Speed Limit and Device Limit
 		bucket, ok, reject := d.Limiter.GetUserBucket(sessionInbound.Tag, user.Email, sessionInbound.Source.Address.IP().String())
 		if reject {
-			newError("Devices reach the limit: ", user.Email).AtError().WriteToLog()
+			newError("Devices reach the limit: ", user.Email).AtWarning().WriteToLog()
 			common.Close(outboundLink.Writer)
 			common.Close(inboundLink.Writer)
 			common.Interrupt(outboundLink.Reader)
diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go
index 9d41fc3..951e797 100644
--- a/common/limiter/limiter.go
+++ b/common/limiter/limiter.go
@@ -114,7 +114,8 @@ func (l *Limiter) DeleteInboundLimiter(tag string) error {
 }
 
 func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) {
-	onlineUser := make([]api.OnlineUser, 0)
+	var onlineUser []api.OnlineUser
+
 	if value, ok := l.InboundInfo.Load(tag); ok {
 		inboundInfo := value.(*InboundInfo)
 		// Clear Speed Limiter bucket for users who are not online
@@ -128,13 +129,11 @@ func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) {
 		inboundInfo.UserOnlineIP.Range(func(key, value interface{}) bool {
 			ipMap := value.(*sync.Map)
 			ipMap.Range(func(key, value interface{}) bool {
-				ip := key.(string)
 				uid := value.(int)
+				ip := key.(string)
 				onlineUser = append(onlineUser, api.OnlineUser{UID: uid, IP: ip})
 				return true
 			})
-			email := key.(string)
-			inboundInfo.UserOnlineIP.Delete(email) // Reset online device
 			return true
 		})
 	} else {
@@ -187,8 +186,8 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r
 		// If any device is online
 		if v, ok := inboundInfo.UserOnlineIP.LoadOrStore(email, ipMap); ok {
 			ipMap := v.(*sync.Map)
-			// If this ip is a new device
-			if _, ok := ipMap.LoadOrStore(ip, uid); !ok {
+			// If this is a new ip
+			if _, ok := ipMap.LoadOrStore(ip, uid); !ok || l.g.enable {
 				counter := 0
 				ipMap.Range(func(key, value interface{}) bool {
 					counter++
diff --git a/common/mylego/model.go b/common/mylego/model.go
index 5e565b1..270907b 100644
--- a/common/mylego/model.go
+++ b/common/mylego/model.go
@@ -1,14 +1,14 @@
 package mylego
 
 type CertConfig struct {
-	CertMode                string            `mapstructure:"CertMode"` // none, file, http, dns
-	CertDomain              string            `mapstructure:"CertDomain"`
-	CertFile                string            `mapstructure:"CertFile"`
-	KeyFile                 string            `mapstructure:"KeyFile"`
-	Provider                string            `mapstructure:"Provider"` // alidns, cloudflare, gandi, godaddy....
-	Email                   string            `mapstructure:"Email"`
-	DNSEnv                  map[string]string `mapstructure:"DNSEnv"`
-	RejectUnknownSni        bool              `mapstructure:"RejectUnknownSni"`
+	CertMode         string            `mapstructure:"CertMode"` // none, file, http, dns
+	CertDomain       string            `mapstructure:"CertDomain"`
+	CertFile         string            `mapstructure:"CertFile"`
+	KeyFile          string            `mapstructure:"KeyFile"`
+	Provider         string            `mapstructure:"Provider"` // alidns, cloudflare, gandi, godaddy....
+	Email            string            `mapstructure:"Email"`
+	DNSEnv           map[string]string `mapstructure:"DNSEnv"`
+	RejectUnknownSni bool              `mapstructure:"RejectUnknownSni"`
 }
 
 type LegoCMD struct {
diff --git a/main/main.go b/main/main.go
index ce96527..f8c141b 100644
--- a/main/main.go
+++ b/main/main.go
@@ -24,7 +24,7 @@ var (
 )
 
 var (
-	version  = "0.8.6"
+	version  = "0.8.7"
 	codename = "XrayR"
 	intro    = "A Xray backend that supports many panels"
 )
diff --git a/service/controller/control.go b/service/controller/control.go
index cd8b885..e06f4fe 100644
--- a/service/controller/control.go
+++ b/service/controller/control.go
@@ -16,12 +16,12 @@ import (
 )
 
 func (c *Controller) removeInbound(tag string) error {
-	err := c.ihm.RemoveHandler(context.Background(), tag)
+	err := c.ibm.RemoveHandler(context.Background(), tag)
 	return err
 }
 
 func (c *Controller) removeOutbound(tag string) error {
-	err := c.ohm.RemoveHandler(context.Background(), tag)
+	err := c.obm.RemoveHandler(context.Background(), tag)
 	return err
 }
 
@@ -34,7 +34,7 @@ func (c *Controller) addInbound(config *core.InboundHandlerConfig) error {
 	if !ok {
 		return fmt.Errorf("not an InboundHandler: %s", err)
 	}
-	if err := c.ihm.AddHandler(context.Background(), handler); err != nil {
+	if err := c.ibm.AddHandler(context.Background(), handler); err != nil {
 		return err
 	}
 	return nil
@@ -49,14 +49,14 @@ func (c *Controller) addOutbound(config *core.OutboundHandlerConfig) error {
 	if !ok {
 		return fmt.Errorf("not an InboundHandler: %s", err)
 	}
-	if err := c.ohm.AddHandler(context.Background(), handler); err != nil {
+	if err := c.obm.AddHandler(context.Background(), handler); err != nil {
 		return err
 	}
 	return nil
 }
 
 func (c *Controller) addUsers(users []*protocol.User, tag string) error {
-	handler, err := c.ihm.GetHandler(context.Background(), tag)
+	handler, err := c.ibm.GetHandler(context.Background(), tag)
 	if err != nil {
 		return fmt.Errorf("no such inbound tag: %s", err)
 	}
@@ -83,7 +83,7 @@ func (c *Controller) addUsers(users []*protocol.User, tag string) error {
 }
 
 func (c *Controller) removeUsers(users []string, tag string) error {
-	handler, err := c.ihm.GetHandler(context.Background(), tag)
+	handler, err := c.ibm.GetHandler(context.Background(), tag)
 	if err != nil {
 		return fmt.Errorf("no such inbound tag: %s", err)
 	}
diff --git a/service/controller/controller.go b/service/controller/controller.go
index ae33fa3..01a7b1c 100644
--- a/service/controller/controller.go
+++ b/service/controller/controller.go
@@ -43,8 +43,8 @@ type Controller struct {
 	limitedUsers map[api.UserInfo]LimitInfo
 	warnedUsers  map[api.UserInfo]int
 	panelType    string
-	ihm          inbound.Manager
-	ohm          outbound.Manager
+	ibm          inbound.Manager
+	obm          outbound.Manager
 	stm          stats.Manager
 	dispatcher   *mydispatcher.DefaultDispatcher
 	startAt      time.Time
@@ -63,8 +63,8 @@ func New(server *core.Instance, api api.API, config *Config, panelType string) *
 		config:     config,
 		apiClient:  api,
 		panelType:  panelType,
-		ihm:        server.GetFeature(inbound.ManagerType()).(inbound.Manager),
-		ohm:        server.GetFeature(outbound.ManagerType()).(outbound.Manager),
+		ibm:        server.GetFeature(inbound.ManagerType()).(inbound.Manager),
+		obm:        server.GetFeature(outbound.ManagerType()).(outbound.Manager),
 		stm:        server.GetFeature(stats.ManagerType()).(stats.Manager),
 		dispatcher: server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher),
 		startAt:    time.Now(),
@@ -104,17 +104,19 @@ func (c *Controller) Start() error {
 		return err
 	}
 
+	// sync controller userList
+	c.userList = userInfo
+
 	err = c.addNewUser(userInfo, newNodeInfo)
 	if err != nil {
 		return err
 	}
-	// sync controller userList
-	c.userList = userInfo
 
 	// Add Limiter
 	if err := c.AddInboundLimiter(c.Tag, newNodeInfo.SpeedLimit, userInfo, c.config.GlobalDeviceLimitConfig); err != nil {
 		log.Print(err)
 	}
+
 	// Add Rule Manager
 	if !c.config.DisableGetRule {
 		if ruleList, err := c.apiClient.GetNodeRule(); err != nil {
@@ -138,26 +140,30 @@ func (c *Controller) Start() error {
 	// Add periodic tasks
 	c.tasks = append(c.tasks,
 		periodicTask{
-			tag: "node",
+			tag: "node monitor",
 			Periodic: &task.Periodic{
 				Interval: time.Duration(c.config.UpdatePeriodic) * time.Second,
 				Execute:  c.nodeInfoMonitor,
 			}},
 		periodicTask{
-			tag: "user",
+			tag: "user monitor",
 			Periodic: &task.Periodic{
 				Interval: time.Duration(c.config.UpdatePeriodic) * time.Second,
 				Execute:  c.userInfoMonitor,
 			}},
 	)
+
+	// Check cert service in need
 	if c.nodeInfo.NodeType != "Shadowsocks" {
 		c.tasks = append(c.tasks, periodicTask{
-			tag: "cert",
+			tag: "cert monitor",
 			Periodic: &task.Periodic{
 				Interval: time.Duration(c.config.UpdatePeriodic) * time.Second * 60,
 				Execute:  c.certMonitor,
 			}})
 	}
+
+	// Check global limit in need
 	if c.config.GlobalDeviceLimitConfig.Enable {
 		c.tasks = append(c.tasks,
 			periodicTask{
@@ -169,6 +175,16 @@ func (c *Controller) Start() error {
 			})
 	}
 
+	// Reset online user
+	c.tasks = append(c.tasks,
+		periodicTask{
+			tag: "reset online user",
+			Periodic: &task.Periodic{
+				Interval: time.Duration(c.config.UpdatePeriodic) * time.Second * 15,
+				Execute:  c.resetOnlineUser,
+			},
+		})
+
 	// Start periodic tasks
 	for i := range c.tasks {
 		log.Printf("%s Start %s periodic task", c.logPrefix(), c.tasks[i].tag)
@@ -663,9 +679,9 @@ func (c *Controller) globalLimitFetch() (err error) {
 			} else {
 				for k := range cmdMap {
 					ips := cmdMap[k].Val()
+					ipMap := new(sync.Map)
 					for i := range ips {
 						uid, _ := strconv.Atoi(ips[i])
-						ipMap := new(sync.Map)
 						ipMap.Store(i, uid)
 						inboundInfo.UserOnlineIP.LoadOrStore(k, ipMap)
 					}
@@ -680,3 +696,21 @@ func (c *Controller) globalLimitFetch() (err error) {
 
 	return nil
 }
+
+func (c *Controller) resetOnlineUser() error {
+	// delay to start
+	if time.Since(c.startAt) < time.Duration(c.config.UpdatePeriodic)*time.Second*15 {
+		return nil
+	}
+
+	if value, ok := c.dispatcher.Limiter.InboundInfo.Load(c.Tag); ok {
+		inboundInfo := value.(*limiter.InboundInfo)
+		inboundInfo.UserOnlineIP.Range(func(key, value interface{}) bool {
+			email := key.(string)
+			inboundInfo.UserOnlineIP.Delete(email) // Reset online device
+			return true
+		})
+	}
+
+	return nil
+}