diff --git a/common/limiter/limiter.go b/common/limiter/limiter.go index aab66d2..e2d96ca 100644 --- a/common/limiter/limiter.go +++ b/common/limiter/limiter.go @@ -4,6 +4,7 @@ package limiter import ( "context" "fmt" + "strconv" "strings" "sync" "time" @@ -63,16 +64,17 @@ func (l *Limiter) AddInboundLimiter(tag string, nodeSpeedLimit uint64, userList gs := goCacheStore.NewGoCache(goCache.New(time.Duration(globalLimit.Expiry)*time.Second, 10*time.Minute)) // init redis store - redisClient := redis.NewClient(&redis.Options{ - Addr: globalLimit.RedisAddr, - Password: globalLimit.RedisPassword, - DB: globalLimit.RedisDB, - }) - rs := redisStore.NewRedis(redisClient, store.WithExpiration(time.Duration(globalLimit.Expiry)*time.Second)) + rs := redisStore.NewRedis(redis.NewClient( + &redis.Options{ + Addr: globalLimit.RedisAddr, + Password: globalLimit.RedisPassword, + DB: globalLimit.RedisDB, + }), + store.WithExpiration(time.Duration(globalLimit.Expiry)*time.Second)) // init chained cache. First use local go-cache, if go-cache is nil, then use redis cache cacheManager := cache.NewChain[any]( - cache.New[any](gs), + cache.New[any](gs), // go-cache is priority cache.New[any](rs), ) inboundInfo.GlobalLimit.globalOnlineIP = marshaler.New(cacheManager) @@ -223,29 +225,30 @@ func globalLimit(inboundInfo *InboundInfo, email string, uid int, ip string, dev ctx, cancel := context.WithTimeout(context.Background(), time.Duration(inboundInfo.GlobalLimit.config.Timeout)*time.Second) defer cancel() - email := email[strings.Index(email, "|")+1:] // reformat email for unique key + // reformat email for unique key + email := strings.Replace(email, inboundInfo.Tag, strconv.Itoa(deviceLimit), 1) + v, err := inboundInfo.GlobalLimit.globalOnlineIP.Get(ctx, email, new(map[string]int)) if err != nil { - switch err.(type) { - case *store.NotFound: + if _, ok := err.(*store.NotFound); ok { // If the email is a new device - if v == nil { - go pushIP(inboundInfo, email, ip, uid) - } - default: + go pushIP(inboundInfo, email, &map[string]int{ip: uid}) + } else { newError("cache service").Base(err).AtError().WriteToLog() } return false } - ipMap := *v.(*map[string]int) + ipMap := v.(*map[string]int) // Reject device reach limit directly - if deviceLimit > 0 && len(ipMap) > deviceLimit { + if deviceLimit > 0 && len(*ipMap) > deviceLimit { return true } + // If the ip is not in cache - if _, ok := ipMap[ip]; !ok { - go pushIP(inboundInfo, email, ip, uid) + if _, ok := (*ipMap)[ip]; !ok { + (*ipMap)[ip] = uid + go pushIP(inboundInfo, email, ipMap) } } @@ -253,11 +256,11 @@ func globalLimit(inboundInfo *InboundInfo, email string, uid int, ip string, dev } // push the ip to cache -func pushIP(inboundInfo *InboundInfo, email string, ip string, uid int) { +func pushIP(inboundInfo *InboundInfo, email string, ipMap *map[string]int) { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(inboundInfo.GlobalLimit.config.Timeout)*time.Second) defer cancel() - if err := inboundInfo.GlobalLimit.globalOnlineIP.Set(ctx, email, map[string]int{ip: uid}); err != nil { + if err := inboundInfo.GlobalLimit.globalOnlineIP.Set(ctx, email, ipMap); err != nil { newError("cache service").Base(err).AtError().WriteToLog() } }