fix: fix condition of check cert service

fix: check devices on redis first
This commit is contained in:
pocketW 2022-12-03 14:28:32 +11:00
parent 5b45b8ffe8
commit 839b15c22c
2 changed files with 12 additions and 7 deletions

View File

@ -187,10 +187,10 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r
ipMap.Delete(ip) ipMap.Delete(ip)
return nil, false, true return nil, false, true
} }
go pushIP(email, ip, inboundInfo.GlobalLimit) go pushIP(email, ip, deviceLimit, inboundInfo.GlobalLimit)
} }
} else { } else {
go pushIP(email, ip, inboundInfo.GlobalLimit) go pushIP(email, ip, deviceLimit, inboundInfo.GlobalLimit)
} }
} }
@ -214,10 +214,15 @@ func (l *Limiter) GetUserBucket(tag string, email string, ip string) (limiter *r
} }
// Push new IP to redis // Push new IP to redis
func pushIP(email string, ip string, g *GlobalLimit) { func pushIP(email string, ip string, deviceLimit int, g *GlobalLimit) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(g.Timeout)) ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(g.Timeout))
defer cancel() defer cancel()
// First check whether the device in redis reach the limit
if g.R.SCard(ctx, email).Val() >= int64(deviceLimit) {
return
}
if err := g.R.SAdd(ctx, email, ip).Err(); err != nil { if err := g.R.SAdd(ctx, email, ip).Err(); err != nil {
newError(fmt.Errorf("redis: %v", err)).AtError().WriteToLog() newError(fmt.Errorf("redis: %v", err)).AtError().WriteToLog()
} }

View File

@ -143,7 +143,7 @@ func (c *Controller) Start() error {
) )
// Check cert service in need // Check cert service in need
if c.nodeInfo.NodeType != "Shadowsocks" { if c.nodeInfo.EnableTLS {
c.tasks = append(c.tasks, periodicTask{ c.tasks = append(c.tasks, periodicTask{
tag: "cert monitor", tag: "cert monitor",
Periodic: &task.Periodic{ Periodic: &task.Periodic{
@ -672,12 +672,12 @@ func (c *Controller) globalLimitFetch() (err error) {
newError(fmt.Errorf("redis: %v", err)).AtError().WriteToLog() newError(fmt.Errorf("redis: %v", err)).AtError().WriteToLog()
} else { } else {
inboundInfo.GlobalLimit.OnlineIP = new(sync.Map) inboundInfo.GlobalLimit.OnlineIP = new(sync.Map)
for k := range cmdMap { for email := range cmdMap {
ips := cmdMap[k].Val() ips := cmdMap[email].Val()
ipMap := new(sync.Map) ipMap := new(sync.Map)
for i := range ips { for i := range ips {
ipMap.Store(ips[i], 0) ipMap.Store(ips[i], 0)
inboundInfo.GlobalLimit.OnlineIP.Store(k, ipMap) inboundInfo.GlobalLimit.OnlineIP.Store(email, ipMap)
} }
} }
} }