diff --git a/app/mydispatcher/default.go b/app/mydispatcher/default.go index 0c4a8ba..a7efa63 100644 --- a/app/mydispatcher/default.go +++ b/app/mydispatcher/default.go @@ -234,7 +234,7 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network, sn // Speed Limit and Device Limit bucket, ok, reject := d.Limiter.GetUserBucket(sessionInbound.Tag, user.Email, sessionInbound.Source.Address.IP().String()) if reject { - newError("Devices reach the limit: ", user.Email).AtError().WriteToLog() + newError("Devices reach the limit: ", user.Email).AtWarning().WriteToLog() common.Close(outboundLink.Writer) common.Close(inboundLink.Writer) common.Interrupt(outboundLink.Reader) diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index 9d41fc3..951e797 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -114,7 +114,8 @@ func (l *Limiter) DeleteInboundLimiter(tag string) error { } func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) { - onlineUser := make([]api.OnlineUser, 0) + var onlineUser []api.OnlineUser + if value, ok := l.InboundInfo.Load(tag); ok { inboundInfo := value.(*InboundInfo) // Clear Speed Limiter bucket for users who are not online @@ -128,13 +129,11 @@ func (l *Limiter) GetOnlineDevice(tag string) (*[]api.OnlineUser, error) { inboundInfo.UserOnlineIP.Range(func(key, value interface{}) bool { ipMap := value.(*sync.Map) ipMap.Range(func(key, value interface{}) bool { - ip := key.(string) uid := value.(int) + ip := key.(string) onlineUser = append(onlineUser, api.OnlineUser{UID: uid, IP: ip}) return true }) - email := key.(string) - inboundInfo.UserOnlineIP.Delete(email) // Reset online device return true }) } else { @@ -187,8 +186,8 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r // If any device is online if v, ok := inboundInfo.UserOnlineIP.LoadOrStore(email, ipMap); ok { ipMap := v.(*sync.Map) - // If this ip is a new device - if _, ok := ipMap.LoadOrStore(ip, uid); !ok { + // If this is a new ip + if _, ok := ipMap.LoadOrStore(ip, uid); !ok || l.g.enable { counter := 0 ipMap.Range(func(key, value interface{}) bool { counter++ diff --git a/common/mylego/model.go b/common/mylego/model.go index 5e565b1..270907b 100644 --- a/common/mylego/model.go +++ b/common/mylego/model.go @@ -1,14 +1,14 @@ package mylego type CertConfig struct { - CertMode string `mapstructure:"CertMode"` // none, file, http, dns - CertDomain string `mapstructure:"CertDomain"` - CertFile string `mapstructure:"CertFile"` - KeyFile string `mapstructure:"KeyFile"` - Provider string `mapstructure:"Provider"` // alidns, cloudflare, gandi, godaddy.... - Email string `mapstructure:"Email"` - DNSEnv map[string]string `mapstructure:"DNSEnv"` - RejectUnknownSni bool `mapstructure:"RejectUnknownSni"` + CertMode string `mapstructure:"CertMode"` // none, file, http, dns + CertDomain string `mapstructure:"CertDomain"` + CertFile string `mapstructure:"CertFile"` + KeyFile string `mapstructure:"KeyFile"` + Provider string `mapstructure:"Provider"` // alidns, cloudflare, gandi, godaddy.... + Email string `mapstructure:"Email"` + DNSEnv map[string]string `mapstructure:"DNSEnv"` + RejectUnknownSni bool `mapstructure:"RejectUnknownSni"` } type LegoCMD struct { diff --git a/main/main.go b/main/main.go index ce96527..f8c141b 100644 --- a/main/main.go +++ b/main/main.go @@ -24,7 +24,7 @@ var ( ) var ( - version = "0.8.6" + version = "0.8.7" codename = "XrayR" intro = "A Xray backend that supports many panels" ) diff --git a/service/controller/control.go b/service/controller/control.go index cd8b885..e06f4fe 100644 --- a/service/controller/control.go +++ b/service/controller/control.go @@ -16,12 +16,12 @@ import ( ) func (c *Controller) removeInbound(tag string) error { - err := c.ihm.RemoveHandler(context.Background(), tag) + err := c.ibm.RemoveHandler(context.Background(), tag) return err } func (c *Controller) removeOutbound(tag string) error { - err := c.ohm.RemoveHandler(context.Background(), tag) + err := c.obm.RemoveHandler(context.Background(), tag) return err } @@ -34,7 +34,7 @@ func (c *Controller) addInbound(config *core.InboundHandlerConfig) error { if !ok { return fmt.Errorf("not an InboundHandler: %s", err) } - if err := c.ihm.AddHandler(context.Background(), handler); err != nil { + if err := c.ibm.AddHandler(context.Background(), handler); err != nil { return err } return nil @@ -49,14 +49,14 @@ func (c *Controller) addOutbound(config *core.OutboundHandlerConfig) error { if !ok { return fmt.Errorf("not an InboundHandler: %s", err) } - if err := c.ohm.AddHandler(context.Background(), handler); err != nil { + if err := c.obm.AddHandler(context.Background(), handler); err != nil { return err } return nil } func (c *Controller) addUsers(users []*protocol.User, tag string) error { - handler, err := c.ihm.GetHandler(context.Background(), tag) + handler, err := c.ibm.GetHandler(context.Background(), tag) if err != nil { return fmt.Errorf("no such inbound tag: %s", err) } @@ -83,7 +83,7 @@ func (c *Controller) addUsers(users []*protocol.User, tag string) error { } func (c *Controller) removeUsers(users []string, tag string) error { - handler, err := c.ihm.GetHandler(context.Background(), tag) + handler, err := c.ibm.GetHandler(context.Background(), tag) if err != nil { return fmt.Errorf("no such inbound tag: %s", err) } diff --git a/service/controller/controller.go b/service/controller/controller.go index ae33fa3..01a7b1c 100644 --- a/service/controller/controller.go +++ b/service/controller/controller.go @@ -43,8 +43,8 @@ type Controller struct { limitedUsers map[api.UserInfo]LimitInfo warnedUsers map[api.UserInfo]int panelType string - ihm inbound.Manager - ohm outbound.Manager + ibm inbound.Manager + obm outbound.Manager stm stats.Manager dispatcher *mydispatcher.DefaultDispatcher startAt time.Time @@ -63,8 +63,8 @@ func New(server *core.Instance, api api.API, config *Config, panelType string) * config: config, apiClient: api, panelType: panelType, - ihm: server.GetFeature(inbound.ManagerType()).(inbound.Manager), - ohm: server.GetFeature(outbound.ManagerType()).(outbound.Manager), + ibm: server.GetFeature(inbound.ManagerType()).(inbound.Manager), + obm: server.GetFeature(outbound.ManagerType()).(outbound.Manager), stm: server.GetFeature(stats.ManagerType()).(stats.Manager), dispatcher: server.GetFeature(routing.DispatcherType()).(*mydispatcher.DefaultDispatcher), startAt: time.Now(), @@ -104,17 +104,19 @@ func (c *Controller) Start() error { return err } + // sync controller userList + c.userList = userInfo + err = c.addNewUser(userInfo, newNodeInfo) if err != nil { return err } - // sync controller userList - c.userList = userInfo // Add Limiter if err := c.AddInboundLimiter(c.Tag, newNodeInfo.SpeedLimit, userInfo, c.config.GlobalDeviceLimitConfig); err != nil { log.Print(err) } + // Add Rule Manager if !c.config.DisableGetRule { if ruleList, err := c.apiClient.GetNodeRule(); err != nil { @@ -138,26 +140,30 @@ func (c *Controller) Start() error { // Add periodic tasks c.tasks = append(c.tasks, periodicTask{ - tag: "node", + tag: "node monitor", Periodic: &task.Periodic{ Interval: time.Duration(c.config.UpdatePeriodic) * time.Second, Execute: c.nodeInfoMonitor, }}, periodicTask{ - tag: "user", + tag: "user monitor", Periodic: &task.Periodic{ Interval: time.Duration(c.config.UpdatePeriodic) * time.Second, Execute: c.userInfoMonitor, }}, ) + + // Check cert service in need if c.nodeInfo.NodeType != "Shadowsocks" { c.tasks = append(c.tasks, periodicTask{ - tag: "cert", + tag: "cert monitor", Periodic: &task.Periodic{ Interval: time.Duration(c.config.UpdatePeriodic) * time.Second * 60, Execute: c.certMonitor, }}) } + + // Check global limit in need if c.config.GlobalDeviceLimitConfig.Enable { c.tasks = append(c.tasks, periodicTask{ @@ -169,6 +175,16 @@ func (c *Controller) Start() error { }) } + // Reset online user + c.tasks = append(c.tasks, + periodicTask{ + tag: "reset online user", + Periodic: &task.Periodic{ + Interval: time.Duration(c.config.UpdatePeriodic) * time.Second * 15, + Execute: c.resetOnlineUser, + }, + }) + // Start periodic tasks for i := range c.tasks { log.Printf("%s Start %s periodic task", c.logPrefix(), c.tasks[i].tag) @@ -663,9 +679,9 @@ func (c *Controller) globalLimitFetch() (err error) { } else { for k := range cmdMap { ips := cmdMap[k].Val() + ipMap := new(sync.Map) for i := range ips { uid, _ := strconv.Atoi(ips[i]) - ipMap := new(sync.Map) ipMap.Store(i, uid) inboundInfo.UserOnlineIP.LoadOrStore(k, ipMap) } @@ -680,3 +696,21 @@ func (c *Controller) globalLimitFetch() (err error) { return nil } + +func (c *Controller) resetOnlineUser() error { + // delay to start + if time.Since(c.startAt) < time.Duration(c.config.UpdatePeriodic)*time.Second*15 { + return nil + } + + if value, ok := c.dispatcher.Limiter.InboundInfo.Load(c.Tag); ok { + inboundInfo := value.(*limiter.InboundInfo) + inboundInfo.UserOnlineIP.Range(func(key, value interface{}) bool { + email := key.(string) + inboundInfo.UserOnlineIP.Delete(email) // Reset online device + return true + }) + } + + return nil +}