From 39c1036c4aa8bc06bae32e02dd28416a549eb03b Mon Sep 17 00:00:00 2001
From: JoshuaCylinder <joshua@joshuacyl.work>
Date: Sat, 1 Oct 2022 06:49:51 +0000
Subject: [PATCH] Add AutoSpeedLimitConfig for some special node (For example:
 IPLC, IEPL, GameNode)

---
 main/config.yml.example          |  5 +++
 service/controller/config.go     | 34 ++++++++++------
 service/controller/controller.go | 69 ++++++++++++++++++++++++++++++++
 3 files changed, 95 insertions(+), 13 deletions(-)

diff --git a/main/config.yml.example b/main/config.yml.example
index 84845b2..d55874c 100644
--- a/main/config.yml.example
+++ b/main/config.yml.example
@@ -34,6 +34,11 @@ Nodes:
       DNSType: AsIs # AsIs, UseIP, UseIPv4, UseIPv6, DNS strategy
       EnableProxyProtocol: false # Only works for WebSocket and TCP
       EnableFallback: false # Only support for Trojan and Vless
+      AutoSpeedLimitConfig:
+        Limit: 0 # Warned speed. Set to 0 to disable AutoSpeedLimit (mbps)
+        WarnTimes: 0 # After (WarnTimes) consecutive warnings, the user will be limited. Set to 0 to punish overspeed user immediately.
+        LimitSpeed: 0 # The speedlimit of a limited user (unit: mbps)
+        LimitDuration: 0 # How many minutes will the limiting last (unit: minute)
       FallBackConfigs:  # Support multiple fallbacks
         -
           SNI: # TLS SNI(Server Name Indication), Empty for any
diff --git a/service/controller/config.go b/service/controller/config.go
index 36907aa..1791d74 100644
--- a/service/controller/config.go
+++ b/service/controller/config.go
@@ -1,19 +1,27 @@
 package controller
 
 type Config struct {
-	ListenIP             string            `mapstructure:"ListenIP"`
-	SendIP               string            `mapstructure:"SendIP"`
-	UpdatePeriodic       int               `mapstructure:"UpdatePeriodic"`
-	CertConfig           *CertConfig       `mapstructure:"CertConfig"`
-	EnableDNS            bool              `mapstructure:"EnableDNS"`
-	DNSType              string            `mapstructure:"DNSType"`
-	DisableUploadTraffic bool              `mapstructure:"DisableUploadTraffic"`
-	DisableGetRule       bool              `mapstructure:"DisableGetRule"`
-	EnableProxyProtocol  bool              `mapstructure:"EnableProxyProtocol"`
-	EnableFallback       bool              `mapstructure:"EnableFallback"`
-	DisableIVCheck       bool              `mapstructure:"DisableIVCheck"`
-	DisableSniffing      bool              `mapstructure:"DisableSniffing"`
-	FallBackConfigs      []*FallBackConfig `mapstructure:"FallBackConfigs"`
+	ListenIP             string                `mapstructure:"ListenIP"`
+	SendIP               string                `mapstructure:"SendIP"`
+	UpdatePeriodic       int                   `mapstructure:"UpdatePeriodic"`
+	CertConfig           *CertConfig           `mapstructure:"CertConfig"`
+	EnableDNS            bool                  `mapstructure:"EnableDNS"`
+	DNSType              string                `mapstructure:"DNSType"`
+	DisableUploadTraffic bool                  `mapstructure:"DisableUploadTraffic"`
+	DisableGetRule       bool                  `mapstructure:"DisableGetRule"`
+	EnableProxyProtocol  bool                  `mapstructure:"EnableProxyProtocol"`
+	EnableFallback       bool                  `mapstructure:"EnableFallback"`
+	DisableIVCheck       bool                  `mapstructure:"DisableIVCheck"`
+	DisableSniffing      bool                  `mapstructure:"DisableSniffing"`
+	AutoSpeedLimitConfig *AutoSpeedLimitConfig `mapstructure:"AutoSpeedLimitConfig"`
+	FallBackConfigs      []*FallBackConfig     `mapstructure:"FallBackConfigs"`
+}
+
+type AutoSpeedLimitConfig struct {
+	Limit         int `mapstructure:"Limit"` // mbps
+	WarnTimes     int `mapstructure:"WarnTimes"`
+	LimitSpeed    int `mapstructure:"LimitSpeed"`    // mbps
+	LimitDuration int `mapstructure:"LimitDuration"` // minute
 }
 
 type CertConfig struct {
diff --git a/service/controller/controller.go b/service/controller/controller.go
index 47e3f39..55fd334 100644
--- a/service/controller/controller.go
+++ b/service/controller/controller.go
@@ -19,6 +19,11 @@ import (
 	"github.com/xtls/xray-core/features/stats"
 )
 
+type LimitInfo struct {
+	end              int64
+	originSpeedLimit uint64
+}
+
 type Controller struct {
 	server                  *core.Instance
 	config                  *Config
@@ -29,6 +34,8 @@ type Controller struct {
 	userList                *[]api.UserInfo
 	nodeInfoMonitorPeriodic *task.Periodic
 	userReportPeriodic      *task.Periodic
+	limitedUsers            map[api.UserInfo]LimitInfo
+	warnedUsers             map[api.UserInfo]int
 	panelType               string
 	ihm                     inbound.Manager
 	ohm                     outbound.Manager
@@ -102,6 +109,10 @@ func (c *Controller) Start() error {
 		Interval: time.Duration(c.config.UpdatePeriodic) * time.Second,
 		Execute:  c.userInfoMonitor,
 	}
+	if c.config.AutoSpeedLimitConfig.Limit > 0 {
+		c.limitedUsers = make(map[api.UserInfo]LimitInfo)
+		c.warnedUsers = make(map[api.UserInfo]int)
+	}
 	log.Printf("[%s: %d] Start monitor node status", c.nodeInfo.NodeType, c.nodeInfo.NodeID)
 	// delay to start nodeInfoMonitor
 	go func() {
@@ -409,6 +420,16 @@ func compareUserList(old, new *[]api.UserInfo) (deleted, added []api.UserInfo) {
 	return deleted, added
 }
 
+func limitUser(c *Controller, user api.UserInfo, silentUsers *[]api.UserInfo) {
+	c.limitedUsers[user] = LimitInfo{
+		end:              time.Now().Unix() + int64(c.config.AutoSpeedLimitConfig.LimitDuration*60),
+		originSpeedLimit: user.SpeedLimit,
+	}
+	log.Printf("    User: %s Speed: %d End: %s", user.Email, user.SpeedLimit, time.Unix(c.limitedUsers[user].end, 0).Format("01-02 15:04:05"))
+	user.SpeedLimit = uint64(c.config.AutoSpeedLimitConfig.LimitSpeed) * 1024 * 1024 / 8
+	*silentUsers = append(*silentUsers, user)
+}
+
 func (c *Controller) userInfoMonitor() (err error) {
 	// Get server status
 	CPU, Mem, Disk, Uptime, err := serverstatus.GetSystemInfo()
@@ -425,14 +446,55 @@ func (c *Controller) userInfoMonitor() (err error) {
 	if err != nil {
 		log.Print(err)
 	}
+	// Unlock users
+	if c.config.AutoSpeedLimitConfig.Limit > 0 && len(c.limitedUsers) > 0 {
+		log.Printf("Limited users:")
+		toReleaseUsers := make([]api.UserInfo, 0)
+		for user, limitInfo := range c.limitedUsers {
+			if time.Now().Unix() > limitInfo.end {
+				user.SpeedLimit = limitInfo.originSpeedLimit
+				toReleaseUsers = append(toReleaseUsers, user)
+				log.Printf("    User: %s Speed: %d End: nil (Unlimit)", user.Email, user.SpeedLimit)
+				delete(c.limitedUsers, user)
+			} else {
+				log.Printf("    User: %s Speed: %d End: %s", user.Email, user.SpeedLimit, time.Unix(c.limitedUsers[user].end, 0).Format("01-02 15:04:05"))
+			}
+		}
+		if len(toReleaseUsers) > 0 {
+			if err := c.UpdateInboundLimiter(c.Tag, &toReleaseUsers); err != nil {
+				log.Print(err)
+			}
+		}
+	}
 
 	// Get User traffic
 	var userTraffic []api.UserTraffic
 	var upCounterList []stats.Counter
 	var downCounterList []stats.Counter
+	AutoSpeedLimit := int64(c.config.AutoSpeedLimitConfig.Limit)
+	UpdatePeriodic := int64(c.config.UpdatePeriodic)
+	limitedUsers := make([]api.UserInfo, 0)
 	for _, user := range *c.userList {
 		up, down, upCounter, downCounter := c.getTraffic(c.buildUserTag(&user))
 		if up > 0 || down > 0 {
+			// Over speed users
+			if AutoSpeedLimit > 0 {
+				if down > AutoSpeedLimit*1024*1024*UpdatePeriodic/8 {
+					if _, ok := c.limitedUsers[user]; !ok {
+						if c.config.AutoSpeedLimitConfig.WarnTimes == 0 {
+							limitUser(c, user, &limitedUsers)
+						} else {
+							c.warnedUsers[user] += 1
+							if c.warnedUsers[user] > c.config.AutoSpeedLimitConfig.WarnTimes {
+								limitUser(c, user, &limitedUsers)
+								delete(c.warnedUsers, user)
+							}
+						}
+					}
+				} else {
+					delete(c.warnedUsers, user)
+				}
+			}
 			userTraffic = append(userTraffic, api.UserTraffic{
 				UID:      user.UID,
 				Email:    user.Email,
@@ -445,6 +507,13 @@ func (c *Controller) userInfoMonitor() (err error) {
 			if downCounter != nil {
 				downCounterList = append(downCounterList, downCounter)
 			}
+		} else {
+			delete(c.warnedUsers, user)
+		}
+	}
+	if len(limitedUsers) > 0 {
+		if err := c.UpdateInboundLimiter(c.Tag, &limitedUsers); err != nil {
+			log.Print(err)
 		}
 	}
 	if len(userTraffic) > 0 {