|
@@ -0,0 +1,75 @@
|
|
|
+package com.qmth.themis.business.aspect;
|
|
|
+
|
|
|
+import com.qmth.themis.business.annotation.RedisLimitAnnotation;
|
|
|
+import com.qmth.themis.business.util.ServletUtil;
|
|
|
+import com.qmth.themis.common.util.IpUtil;
|
|
|
+import org.aspectj.lang.ProceedingJoinPoint;
|
|
|
+import org.aspectj.lang.annotation.Around;
|
|
|
+import org.aspectj.lang.annotation.Aspect;
|
|
|
+import org.aspectj.lang.annotation.Pointcut;
|
|
|
+import org.aspectj.lang.reflect.MethodSignature;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.data.redis.core.RedisTemplate;
|
|
|
+import org.springframework.data.redis.core.script.DefaultRedisScript;
|
|
|
+import org.springframework.stereotype.Component;
|
|
|
+
|
|
|
+import java.lang.reflect.Method;
|
|
|
+import java.util.Collections;
|
|
|
+import java.util.List;
|
|
|
+
|
|
|
+@Aspect
|
|
|
+@Component
|
|
|
+public class LimitRestAspect {
|
|
|
+ private final static Logger log = LoggerFactory.getLogger(LimitRestAspect.class);
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private RedisTemplate<String, Object> redisTemplate;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ private DefaultRedisScript<Long> redisluaScript;
|
|
|
+
|
|
|
+
|
|
|
+ @Pointcut(value = "@annotation(com.qmth.themis.business.annotation.RedisLimitAnnotation)")
|
|
|
+ public void rateLimit() {
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ @Around("rateLimit()")
|
|
|
+ public Object interceptor(ProceedingJoinPoint joinPoint) throws Throwable {
|
|
|
+ MethodSignature signature = (MethodSignature) joinPoint.getSignature();
|
|
|
+ Method method = signature.getMethod();
|
|
|
+ Class<?> targetClass = method.getDeclaringClass();
|
|
|
+ RedisLimitAnnotation rateLimit = method.getAnnotation(RedisLimitAnnotation.class);
|
|
|
+ if (rateLimit != null) {
|
|
|
+ String ipAddress = IpUtil.getRemoteIp(ServletUtil.getRequest());
|
|
|
+ StringBuffer stringBuffer = new StringBuffer();
|
|
|
+ stringBuffer.append(ipAddress).append("-")
|
|
|
+ .append(targetClass.getName()).append("- ")
|
|
|
+ .append(method.getName()).append("-")
|
|
|
+ .append(rateLimit.key());
|
|
|
+ List<String> keys = Collections.singletonList(stringBuffer.toString());
|
|
|
+ //调用lua脚本,获取返回结果,这里即为请求的次数
|
|
|
+ Long number = redisTemplate.execute(
|
|
|
+ redisluaScript,
|
|
|
+ // 此处传参只要能转为Object就行(因为数字不能直接强转为String,所以不能用String序列化)
|
|
|
+ //new GenericToStringSerializer<>(Object.class),
|
|
|
+ // 结果的类型需要根据脚本定义,此处是数字--定义的是Long类型
|
|
|
+ //new GenericToStringSerializer<>(Long.class)
|
|
|
+ keys,
|
|
|
+ rateLimit.count(),
|
|
|
+ rateLimit.period()
|
|
|
+ );
|
|
|
+ if (number != null && number.intValue() != 0 && number.intValue() <= rateLimit.count()) {
|
|
|
+ log.info("限流时间段内访问了第:{} 次", number.toString());
|
|
|
+ return joinPoint.proceed();
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ return joinPoint.proceed();
|
|
|
+ }
|
|
|
+ return null;
|
|
|
+// throw new RuntimeException("访问频率过快,被限流了");
|
|
|
+ }
|
|
|
+}
|
|
|
+
|