From 839b15c22ca2f5c307bb388efd16f3fb6f111793 Mon Sep 17 00:00:00 2001 From: pocketW <104479902+pocketW@users.noreply.github.com> Date: Sat, 3 Dec 2022 14:28:32 +1100 Subject: [PATCH] fix: fix condition of check cert service fix: check devices on redis first --- common/limiter/limiter.go | 11 ++++++++--- service/controller/controller.go | 8 ++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index 1e5bb2e..fda2144 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -187,10 +187,10 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r ipMap.Delete(ip) return nil, false, true } - go pushIP(email, ip, inboundInfo.GlobalLimit) + go pushIP(email, ip, deviceLimit, inboundInfo.GlobalLimit) } } else { - go pushIP(email, ip, inboundInfo.GlobalLimit) + go pushIP(email, ip, deviceLimit, inboundInfo.GlobalLimit) } } @@ -214,10 +214,15 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r } // Push new IP to redis -func pushIP(email string, ip string, g *GlobalLimit) { +func pushIP(email string, ip string, deviceLimit int, g *GlobalLimit) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(g.Timeout)) defer cancel() + // First check whether the device in redis reach the limit + if g.R.SCard(ctx, email).Val() >= int64(deviceLimit) { + return + } + if err := g.R.SAdd(ctx, email, ip).Err(); err != nil { newError(fmt.Errorf("redis: %v", err)).AtError().WriteToLog() } diff --git a/service/controller/controller.go b/service/controller/controller.go index a307d4e..0f27f23 100644 --- a/service/controller/controller.go +++ b/service/controller/controller.go @@ -143,7 +143,7 @@ func (c *Controller) Start() error { ) // Check cert service in need - if c.nodeInfo.NodeType != "Shadowsocks" { + if c.nodeInfo.EnableTLS { c.tasks = append(c.tasks, periodicTask{ tag: "cert monitor", Periodic: &task.Periodic{ @@ -672,12 +672,12 @@ func (c *Controller) globalLimitFetch() (err error) { newError(fmt.Errorf("redis: %v", err)).AtError().WriteToLog() } else { inboundInfo.GlobalLimit.OnlineIP = new(sync.Map) - for k := range cmdMap { - ips := cmdMap[k].Val() + for email := range cmdMap { + ips := cmdMap[email].Val() ipMap := new(sync.Map) for i := range ips { ipMap.Store(ips[i], 0) - inboundInfo.GlobalLimit.OnlineIP.Store(k, ipMap) + inboundInfo.GlobalLimit.OnlineIP.Store(email, ipMap) } } }