fix: global limit will failure on some situation (#124)

fix: typo
This commit is contained in:
thank243 2022-11-28 11:16:27 +08:00 committed by GitHub
parent 40ae48f507
commit d320aadb54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 65 additions and 32 deletions

View File

@ -234,7 +234,7 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sn
// Speed Limit and Device Limit // Speed Limit and Device Limit
bucket, ok, reject := d.Limiter.GetUserBucket(sessionInbound.Tag, user.Email, sessionInbound.Source.Address.IP().String()) bucket, ok, reject := d.Limiter.GetUserBucket(sessionInbound.Tag, user.Email, sessionInbound.Source.Address.IP().String())
if reject { if reject {
newError("Devices reach the limit: ", user.Email).AtError().WriteToLog() newError("Devices reach the limit: ", user.Email).AtWarning().WriteToLog()
common.Close(outboundLink.Writer) common.Close(outboundLink.Writer)
common.Close(inboundLink.Writer) common.Close(inboundLink.Writer)
common.Interrupt(outboundLink.Reader) common.Interrupt(outboundLink.Reader)

View File

@ -114,7 +114,8 @@ func (l *Limiter) DeleteInboundLimiter(tag string) error {
} }
func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) { func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) {
onlineUser := make([]api.OnlineUser, 0) var onlineUser []api.OnlineUser
if value, ok := l.InboundInfo.Load(tag); ok { if value, ok := l.InboundInfo.Load(tag); ok {
inboundInfo := value.(*InboundInfo) inboundInfo := value.(*InboundInfo)
// Clear Speed Limiter bucket for users who are not online // Clear Speed Limiter bucket for users who are not online
@ -128,13 +129,11 @@ func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) {
inboundInfo.UserOnlineIP.Range(func(key, value interface{}) bool { inboundInfo.UserOnlineIP.Range(func(key, value interface{}) bool {
ipMap := value.(*sync.Map) ipMap := value.(*sync.Map)
ipMap.Range(func(key, value interface{}) bool { ipMap.Range(func(key, value interface{}) bool {
ip := key.(string)
uid := value.(int) uid := value.(int)
ip := key.(string)
onlineUser = append(onlineUser, api.OnlineUser{UID: uid, IP: ip}) onlineUser = append(onlineUser, api.OnlineUser{UID: uid, IP: ip})
return true return true
}) })
email := key.(string)
inboundInfo.UserOnlineIP.Delete(email) // Reset online device
return true return true
}) })
} else { } else {
@ -187,8 +186,8 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r
// If any device is online // If any device is online
if v, ok := inboundInfo.UserOnlineIP.LoadOrStore(email, ipMap); ok { if v, ok := inboundInfo.UserOnlineIP.LoadOrStore(email, ipMap); ok {
ipMap := v.(*sync.Map) ipMap := v.(*sync.Map)
// If this ip is a new device // If this is a new ip
if _, ok := ipMap.LoadOrStore(ip, uid); !ok { if _, ok := ipMap.LoadOrStore(ip, uid); !ok || l.g.enable {
counter := 0 counter := 0
ipMap.Range(func(key, value interface{}) bool { ipMap.Range(func(key, value interface{}) bool {
counter++ counter++

View File

@ -1,14 +1,14 @@
package mylego package mylego
type CertConfig struct { type CertConfig struct {
CertMode string `mapstructure:"CertMode"` // none, file, http, dns CertMode string `mapstructure:"CertMode"` // none, file, http, dns
CertDomain string `mapstructure:"CertDomain"` CertDomain string `mapstructure:"CertDomain"`
CertFile string `mapstructure:"CertFile"` CertFile string `mapstructure:"CertFile"`
KeyFile string `mapstructure:"KeyFile"` KeyFile string `mapstructure:"KeyFile"`
Provider string `mapstructure:"Provider"` // alidns, cloudflare, gandi, godaddy.... Provider string `mapstructure:"Provider"` // alidns, cloudflare, gandi, godaddy....
Email string `mapstructure:"Email"` Email string `mapstructure:"Email"`
DNSEnv map[string]string `mapstructure:"DNSEnv"` DNSEnv map[string]string `mapstructure:"DNSEnv"`
RejectUnknownSni bool `mapstructure:"RejectUnknownSni"` RejectUnknownSni bool `mapstructure:"RejectUnknownSni"`
} }
type LegoCMD struct { type LegoCMD struct {

View File

@ -24,7 +24,7 @@ var (
) )
var ( var (
version = "0.8.6" version = "0.8.7"
codename = "XrayR" codename = "XrayR"
intro = "A Xray backend that supports many panels" intro = "A Xray backend that supports many panels"
) )

View File

@ -16,12 +16,12 @@ import (
) )
func (c *Controller) removeInbound(tag string) error { func (c *Controller) removeInbound(tag string) error {
err := c.ihm.RemoveHandler(context.Background(), tag) err := c.ibm.RemoveHandler(context.Background(), tag)
return err return err
} }
func (c *Controller) removeOutbound(tag string) error { func (c *Controller) removeOutbound(tag string) error {
err := c.ohm.RemoveHandler(context.Background(), tag) err := c.obm.RemoveHandler(context.Background(), tag)
return err return err
} }
@ -34,7 +34,7 @@ func (c *Controller) addInbound(config *core.InboundHandlerConfig) error {
if !ok { if !ok {
return fmt.Errorf("not an InboundHandler: %s", err) return fmt.Errorf("not an InboundHandler: %s", err)
} }
if err := c.ihm.AddHandler(context.Background(), handler); err != nil { if err := c.ibm.AddHandler(context.Background(), handler); err != nil {
return err return err
} }
return nil return nil
@ -49,14 +49,14 @@ func (c *Controller) addOutbound(config *core.OutboundHandlerConfig) error {
if !ok { if !ok {
return fmt.Errorf("not an InboundHandler: %s", err) return fmt.Errorf("not an InboundHandler: %s", err)
} }
if err := c.ohm.AddHandler(context.Background(), handler); err != nil { if err := c.obm.AddHandler(context.Background(), handler); err != nil {
return err return err
} }
return nil return nil
} }
func (c *Controller) addUsers(users []*protocol.User, tag string) error { func (c *Controller) addUsers(users []*protocol.User, tag string) error {
handler, err := c.ihm.GetHandler(context.Background(), tag) handler, err := c.ibm.GetHandler(context.Background(), tag)
if err != nil { if err != nil {
return fmt.Errorf("no such inbound tag: %s", err) return fmt.Errorf("no such inbound tag: %s", err)
} }
@ -83,7 +83,7 @@ func (c *Controller) addUsers(users []*protocol.User, tag string) error {
} }
func (c *Controller) removeUsers(users []string, tag string) error { func (c *Controller) removeUsers(users []string, tag string) error {
handler, err := c.ihm.GetHandler(context.Background(), tag) handler, err := c.ibm.GetHandler(context.Background(), tag)
if err != nil { if err != nil {
return fmt.Errorf("no such inbound tag: %s", err) return fmt.Errorf("no such inbound tag: %s", err)
} }

View File

@ -43,8 +43,8 @@ type Controller struct {
limitedUsers map[api.UserInfo]LimitInfo limitedUsers map[api.UserInfo]LimitInfo
warnedUsers map[api.UserInfo]int warnedUsers map[api.UserInfo]int
panelType string panelType string
ihm inbound.Manager ibm inbound.Manager
ohm outbound.Manager obm outbound.Manager
stm stats.Manager stm stats.Manager
dispatcher *mydispatcher.DefaultDispatcher dispatcher *mydispatcher.DefaultDispatcher
startAt time.Time startAt time.Time
@ -63,8 +63,8 @@ func New(server *core.Instance, api api.API, config *Config, panelType string) *
config: config, config: config,
apiClient: api, apiClient: api,
panelType: panelType, panelType: panelType,
ihm: server.GetFeature(inbound.ManagerType()).(inbound.Manager), ibm: server.GetFeature(inbound.ManagerType()).(inbound.Manager),
ohm: server.GetFeature(outbound.ManagerType()).(outbound.Manager), obm: server.GetFeature(outbound.ManagerType()).(outbound.Manager),
stm: server.GetFeature(stats.ManagerType()).(stats.Manager), stm: server.GetFeature(stats.ManagerType()).(stats.Manager),
dispatcher: server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher), dispatcher: server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher),
startAt: time.Now(), startAt: time.Now(),
@ -104,17 +104,19 @@ func (c *Controller) Start() error {
return err return err
} }
// sync controller userList
c.userList = userInfo
err = c.addNewUser(userInfo, newNodeInfo) err = c.addNewUser(userInfo, newNodeInfo)
if err != nil { if err != nil {
return err return err
} }
// sync controller userList
c.userList = userInfo
// Add Limiter // Add Limiter
if err := c.AddInboundLimiter(c.Tag, newNodeInfo.SpeedLimit, userInfo, c.config.GlobalDeviceLimitConfig); err != nil { if err := c.AddInboundLimiter(c.Tag, newNodeInfo.SpeedLimit, userInfo, c.config.GlobalDeviceLimitConfig); err != nil {
log.Print(err) log.Print(err)
} }
// Add Rule Manager // Add Rule Manager
if !c.config.DisableGetRule { if !c.config.DisableGetRule {
if ruleList, err := c.apiClient.GetNodeRule(); err != nil { if ruleList, err := c.apiClient.GetNodeRule(); err != nil {
@ -138,26 +140,30 @@ func (c *Controller) Start() error {
// Add periodic tasks // Add periodic tasks
c.tasks = append(c.tasks, c.tasks = append(c.tasks,
periodicTask{ periodicTask{
tag: "node", tag: "node monitor",
Periodic: &task.Periodic{ Periodic: &task.Periodic{
Interval: time.Duration(c.config.UpdatePeriodic) * time.Second, Interval: time.Duration(c.config.UpdatePeriodic) * time.Second,
Execute: c.nodeInfoMonitor, Execute: c.nodeInfoMonitor,
}}, }},
periodicTask{ periodicTask{
tag: "user", tag: "user monitor",
Periodic: &task.Periodic{ Periodic: &task.Periodic{
Interval: time.Duration(c.config.UpdatePeriodic) * time.Second, Interval: time.Duration(c.config.UpdatePeriodic) * time.Second,
Execute: c.userInfoMonitor, Execute: c.userInfoMonitor,
}}, }},
) )
// Check cert service in need
if c.nodeInfo.NodeType != "Shadowsocks" { if c.nodeInfo.NodeType != "Shadowsocks" {
c.tasks = append(c.tasks, periodicTask{ c.tasks = append(c.tasks, periodicTask{
tag: "cert", tag: "cert monitor",
Periodic: &task.Periodic{ Periodic: &task.Periodic{
Interval: time.Duration(c.config.UpdatePeriodic) * time.Second * 60, Interval: time.Duration(c.config.UpdatePeriodic) * time.Second * 60,
Execute: c.certMonitor, Execute: c.certMonitor,
}}) }})
} }
// Check global limit in need
if c.config.GlobalDeviceLimitConfig.Enable { if c.config.GlobalDeviceLimitConfig.Enable {
c.tasks = append(c.tasks, c.tasks = append(c.tasks,
periodicTask{ periodicTask{
@ -169,6 +175,16 @@ func (c *Controller) Start() error {
}) })
} }
// Reset online user
c.tasks = append(c.tasks,
periodicTask{
tag: "reset online user",
Periodic: &task.Periodic{
Interval: time.Duration(c.config.UpdatePeriodic) * time.Second * 15,
Execute: c.resetOnlineUser,
},
})
// Start periodic tasks // Start periodic tasks
for i := range c.tasks { for i := range c.tasks {
log.Printf("%s Start %s periodic task", c.logPrefix(), c.tasks[i].tag) log.Printf("%s Start %s periodic task", c.logPrefix(), c.tasks[i].tag)
@ -663,9 +679,9 @@ func (c *Controller) globalLimitFetch() (err error) {
} else { } else {
for k := range cmdMap { for k := range cmdMap {
ips := cmdMap[k].Val() ips := cmdMap[k].Val()
ipMap := new(sync.Map)
for i := range ips { for i := range ips {
uid, _ := strconv.Atoi(ips[i]) uid, _ := strconv.Atoi(ips[i])
ipMap := new(sync.Map)
ipMap.Store(i, uid) ipMap.Store(i, uid)
inboundInfo.UserOnlineIP.LoadOrStore(k, ipMap) inboundInfo.UserOnlineIP.LoadOrStore(k, ipMap)
} }
@ -680,3 +696,21 @@ func (c *Controller) globalLimitFetch() (err error) {
return nil return nil
} }
func (c *Controller) resetOnlineUser() error {
// delay to start
if time.Since(c.startAt) < time.Duration(c.config.UpdatePeriodic)*time.Second*15 {
return nil
}
if value, ok := c.dispatcher.Limiter.InboundInfo.Load(c.Tag); ok {
inboundInfo := value.(*limiter.InboundInfo)
inboundInfo.UserOnlineIP.Range(func(key, value interface{}) bool {
email := key.(string)
inboundInfo.UserOnlineIP.Delete(email) // Reset online device
return true
})
}
return nil
}