pay 缓存,使用 guava 替代 job 扫描,目的:提升启动速度,加快缓存失效

This commit is contained in:
YunaiV 2023-09-16 23:22:58 +08:00
parent a51579a77d
commit 222daa5366
14 changed files with 146 additions and 162 deletions

View File

@ -48,7 +48,7 @@ public abstract class AbstractPayClient<Config extends PayClientConfig> implemen
*/
public final void init() {
doInit();
log.info("[init][客户端({}) 初始化完成]", getId());
log.debug("[init][客户端({}) 初始化完成]", getId());
}
/**

View File

@ -233,14 +233,38 @@ public class FileConfigServiceImplTest extends BaseDbUnitTest {
@Test
public void testGetFileClient() {
// mock 数据
FileConfigDO fileConfig = randomFileConfigDO().setMaster(false);
fileConfigMapper.insert(fileConfig);
// 准备参数
Long id = randomLongId();
Long id = fileConfig.getId();
// mock 获得 Client
FileClient fileClient = new LocalFileClient(id, new LocalFileClientConfig());
when(fileClientFactory.getFileClient(eq(id))).thenReturn(fileClient);
// 调用并断言
assertSame(fileClient, fileConfigService.getFileClient(id));
// 断言缓存
verify(fileClientFactory).createOrUpdateFileClient(eq(id), eq(fileConfig.getStorage()),
eq(fileConfig.getConfig()));
}
@Test
public void testGetMasterFileClient() {
// mock 数据
FileConfigDO fileConfig = randomFileConfigDO().setMaster(true);
fileConfigMapper.insert(fileConfig);
// 准备参数
Long id = fileConfig.getId();
// mock 获得 Client
FileClient fileClient = new LocalFileClient(id, new LocalFileClientConfig());
when(fileClientFactory.getFileClient(eq(0L))).thenReturn(fileClient);
// 调用并断言
assertSame(fileClient, fileConfigService.getMasterFileClient());
// 断言缓存
verify(fileClientFactory).createOrUpdateFileClient(eq(0L), eq(fileConfig.getStorage()),
eq(fileConfig.getConfig()));
}
private FileConfigDO randomFileConfigDO() {

View File

@ -5,7 +5,6 @@ import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.operatelog.core.annotations.OperateLog;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientFactory;
import cn.iocoder.yudao.framework.pay.core.client.dto.order.PayOrderRespDTO;
import cn.iocoder.yudao.framework.pay.core.client.dto.refund.PayRefundRespDTO;
import cn.iocoder.yudao.module.pay.controller.admin.notify.vo.PayNotifyTaskDetailRespVO;
@ -16,6 +15,7 @@ import cn.iocoder.yudao.module.pay.dal.dataobject.app.PayAppDO;
import cn.iocoder.yudao.module.pay.dal.dataobject.notify.PayNotifyLogDO;
import cn.iocoder.yudao.module.pay.dal.dataobject.notify.PayNotifyTaskDO;
import cn.iocoder.yudao.module.pay.service.app.PayAppService;
import cn.iocoder.yudao.module.pay.service.channel.PayChannelService;
import cn.iocoder.yudao.module.pay.service.notify.PayNotifyService;
import cn.iocoder.yudao.module.pay.service.order.PayOrderService;
import cn.iocoder.yudao.module.pay.service.refund.PayRefundService;
@ -53,9 +53,8 @@ public class PayNotifyController {
private PayNotifyService notifyService;
@Resource
private PayAppService appService;
@Resource
private PayClientFactory payClientFactory;
private PayChannelService channelService;
@PostMapping(value = "/order/{channelId}")
@Operation(summary = "支付渠道的统一【支付】回调")
@ -66,7 +65,7 @@ public class PayNotifyController {
@RequestBody(required = false) String body) {
log.info("[notifyOrder][channelId({}) 回调数据({}/{})]", channelId, params, body);
// 1. 校验支付渠道是否存在
PayClient payClient = payClientFactory.getPayClient(channelId);
PayClient payClient = channelService.getPayClient(channelId);
if (payClient == null) {
log.error("[notifyCallback][渠道编号({}) 找不到对应的支付客户端]", channelId);
throw exception(CHANNEL_NOT_FOUND);
@ -87,7 +86,7 @@ public class PayNotifyController {
@RequestBody(required = false) String body) {
log.info("[notifyRefund][channelId({}) 回调数据({}/{})]", channelId, params, body);
// 1. 校验支付渠道是否存在
PayClient payClient = payClientFactory.getPayClient(channelId);
PayClient payClient = channelService.getPayClient(channelId);
if (payClient == null) {
log.error("[notifyCallback][渠道编号({}) 找不到对应的支付客户端]", channelId);
throw exception(CHANNEL_NOT_FOUND);

View File

@ -28,7 +28,4 @@ public interface PayChannelMapper extends BaseMapperX<PayChannelDO> {
.eq(PayChannelDO::getStatus, status));
}
@Select("SELECT COUNT(*) FROM pay_channel WHERE update_time > #{maxUpdateTime}")
Long selectCountByUpdateTimeGt(LocalDateTime maxUpdateTime);
}

View File

@ -15,7 +15,6 @@ import javax.annotation.Resource;
* @author 芋道源码
*/
@Component
@TenantJob // 多租户
@Slf4j
public class PayNotifyJob implements JobHandler {
@ -23,6 +22,7 @@ public class PayNotifyJob implements JobHandler {
private PayNotifyService payNotifyService;
@Override
@TenantJob
public String execute(String param) throws Exception {
int notifyCount = payNotifyService.executeNotify();
return String.format("执行支付通知 %s 个", notifyCount);

View File

@ -16,13 +16,13 @@ import javax.annotation.Resource;
* @author 芋道源码
*/
@Component
@TenantJob
public class PayOrderExpireJob implements JobHandler {
@Resource
private PayOrderService orderService;
@Override
@TenantJob
public String execute(String param) {
int count = orderService.expireOrder();
return StrUtil.format("支付过期 {} 个", count);

View File

@ -18,7 +18,6 @@ import java.time.LocalDateTime;
* @author 芋道源码
*/
@Component
@TenantJob
public class PayOrderSyncJob implements JobHandler {
/**
@ -34,6 +33,7 @@ public class PayOrderSyncJob implements JobHandler {
private PayOrderService orderService;
@Override
@TenantJob
public String execute(String param) {
LocalDateTime minCreateTime = LocalDateTime.now().minus(CREATE_TIME_DURATION_BEFORE);
int count = orderService.syncOrder(minCreateTime);

View File

@ -16,13 +16,13 @@ import javax.annotation.Resource;
* @author 芋道源码
*/
@Component
@TenantJob
public class PayRefundSyncJob implements JobHandler {
@Resource
private PayRefundService refundService;
@Override
@TenantJob
public String execute(String param) {
int count = refundService.syncRefund();
return StrUtil.format("同步退款订单 {} 个", count);

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.pay.service.channel;
import cn.iocoder.yudao.framework.common.exception.ServiceException;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.module.pay.controller.admin.channel.vo.PayChannelCreateReqVO;
import cn.iocoder.yudao.module.pay.controller.admin.channel.vo.PayChannelUpdateReqVO;
import cn.iocoder.yudao.module.pay.dal.dataobject.channel.PayChannelDO;
@ -92,4 +93,12 @@ public interface PayChannelService {
*/
List<PayChannelDO> getEnableChannelList(Long appId);
/**
* 获得指定编号的支付客户端
*
* @param id 编号
* @return 支付客户端
*/
PayClient getPayClient(Long id);
}

View File

@ -1,38 +1,33 @@
package cn.iocoder.yudao.module.pay.service.channel;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjectUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.PayClientFactory;
import cn.iocoder.yudao.framework.pay.core.enums.channel.PayChannelEnum;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.pay.controller.admin.channel.vo.PayChannelCreateReqVO;
import cn.iocoder.yudao.module.pay.controller.admin.channel.vo.PayChannelUpdateReqVO;
import cn.iocoder.yudao.module.pay.convert.channel.PayChannelConvert;
import cn.iocoder.yudao.module.pay.dal.dataobject.channel.PayChannelDO;
import cn.iocoder.yudao.module.pay.dal.mysql.channel.PayChannelMapper;
import cn.iocoder.yudao.module.pay.framework.pay.wallet.WalletPayClient;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.validation.Validator;
import java.time.LocalDateTime;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.cache.CacheUtils.buildAsyncReloadingCache;
import static cn.iocoder.yudao.module.pay.enums.ErrorCodeConstants.*;
/**
@ -45,71 +40,34 @@ import static cn.iocoder.yudao.module.pay.enums.ErrorCodeConstants.*;
@Validated
public class PayChannelServiceImpl implements PayChannelService {
@Getter // 为了方便测试这里提供 getter 方法
@Setter
private volatile List<PayChannelDO> channelCache;
/**
* {@link PayClient} 缓存通过它异步清空 smsClientFactory
*/
@Getter
private final LoadingCache<Long, PayClient> clientCache = buildAsyncReloadingCache(Duration.ofSeconds(10L),
new CacheLoader<Long, PayClient>() {
@Override
public PayClient load(Long id) {
// 查询然后尝试清空
PayChannelDO channel = payChannelMapper.selectById(id);
if (channel != null) {
payClientFactory.createOrUpdatePayClient(channel.getId(), channel.getCode(), channel.getConfig());
}
return payClientFactory.getPayClient(id);
}
});
@Resource
private PayClientFactory payClientFactory;
@Resource
private PayChannelMapper channelMapper;
private PayChannelMapper payChannelMapper;
@Resource
private Validator validator;
/**
* 初始化 {@link #payClientFactory} 缓存
*/
@PostConstruct
public void initLocalCache() {
// 注册钱包支付 Class
payClientFactory.registerPayClientClass(PayChannelEnum.WALLET, WalletPayClient.class);
// 注意忽略自动多租户因为要全局初始化缓存
TenantUtils.executeIgnore(() -> {
// 第一步查询数据
List<PayChannelDO> channels = Collections.emptyList();
try {
channels = channelMapper.selectList();
} catch (Throwable ex) {
if (!ex.getMessage().contains("doesn't exist")) {
throw ex;
}
log.error("[支付模块 yudao-module-pay - 表结构未导入][参考 https://doc.iocoder.cn/pay/build/ 开启]");
}
log.info("[initLocalCache][缓存支付渠道,数量为:{}]", channels.size());
// 第二步构建缓存创建或更新支付 Client
channels.forEach(payChannel -> payClientFactory.createOrUpdatePayClient(payChannel.getId(),
payChannel.getCode(), payChannel.getConfig()));
this.channelCache = channels;
});
}
/**
* 通过定时任务轮询刷新缓存
*
* 目的多节点部署时通过轮询通知所有节点进行刷新
*/
@Scheduled(initialDelay = 60, fixedRate = 60, timeUnit = TimeUnit.SECONDS)
public void refreshLocalCache() {
// 注意忽略自动多租户因为要全局初始化缓存
TenantUtils.executeIgnore(() -> {
// 情况一如果缓存里没有数据则直接刷新缓存
if (CollUtil.isEmpty(channelCache)) {
initLocalCache();
return;
}
// 情况二如果缓存里数据则通过 updateTime 判断是否有数据变更有变更则刷新缓存
LocalDateTime maxTime = CollectionUtils.getMaxValue(channelCache, PayChannelDO::getUpdateTime);
if (channelMapper.selectCountByUpdateTimeGt(maxTime) > 0) {
initLocalCache();
}
});
}
@Override
public Long createChannel(PayChannelCreateReqVO reqVO) {
// 断言是否有重复的
@ -121,10 +79,7 @@ public class PayChannelServiceImpl implements PayChannelService {
// 新增渠道
PayChannelDO channel = PayChannelConvert.INSTANCE.convert(reqVO)
.setConfig(parseConfig(reqVO.getCode(), reqVO.getConfig()));
channelMapper.insert(channel);
// 刷新缓存
initLocalCache();
payChannelMapper.insert(channel);
return channel.getId();
}
@ -136,10 +91,10 @@ public class PayChannelServiceImpl implements PayChannelService {
// 更新
PayChannelDO channel = PayChannelConvert.INSTANCE.convert(updateReqVO)
.setConfig(parseConfig(dbChannel.getCode(), updateReqVO.getConfig()));
channelMapper.updateById(channel);
payChannelMapper.updateById(channel);
// 刷新缓存
initLocalCache();
// 清空缓存
clearCache(channel.getId());
}
/**
@ -169,14 +124,23 @@ public class PayChannelServiceImpl implements PayChannelService {
validateChannelExists(id);
// 删除
channelMapper.deleteById(id);
payChannelMapper.deleteById(id);
// 刷新缓存
initLocalCache();
// 清空缓存
clearCache(id);
}
/**
* 删除缓存
*
* @param id 渠道编号
*/
private void clearCache(Long id) {
clientCache.invalidate(id);
}
private PayChannelDO validateChannelExists(Long id) {
PayChannelDO channel = channelMapper.selectById(id);
PayChannelDO channel = payChannelMapper.selectById(id);
if (channel == null) {
throw exception(CHANNEL_NOT_FOUND);
}
@ -185,29 +149,29 @@ public class PayChannelServiceImpl implements PayChannelService {
@Override
public PayChannelDO getChannel(Long id) {
return channelMapper.selectById(id);
return payChannelMapper.selectById(id);
}
@Override
public List<PayChannelDO> getChannelListByAppIds(Collection<Long> appIds) {
return channelMapper.selectListByAppIds(appIds);
return payChannelMapper.selectListByAppIds(appIds);
}
@Override
public PayChannelDO getChannelByAppIdAndCode(Long appId, String code) {
return channelMapper.selectByAppIdAndCode(appId, code);
return payChannelMapper.selectByAppIdAndCode(appId, code);
}
@Override
public PayChannelDO validPayChannel(Long id) {
PayChannelDO channel = channelMapper.selectById(id);
PayChannelDO channel = payChannelMapper.selectById(id);
validPayChannel(channel);
return channel;
}
@Override
public PayChannelDO validPayChannel(Long appId, String code) {
PayChannelDO channel = channelMapper.selectByAppIdAndCode(appId, code);
PayChannelDO channel = payChannelMapper.selectByAppIdAndCode(appId, code);
validPayChannel(channel);
return channel;
}
@ -223,7 +187,12 @@ public class PayChannelServiceImpl implements PayChannelService {
@Override
public List<PayChannelDO> getEnableChannelList(Long appId) {
return channelMapper.selectListByAppId(appId, CommonStatusEnum.ENABLE.getStatus());
return payChannelMapper.selectListByAppId(appId, CommonStatusEnum.ENABLE.getStatus());
}
@Override
public PayClient getPayClient(Long id) {
return clientCache.getUnchecked(id);
}
}

View File

@ -8,7 +8,6 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.date.LocalDateTimeUtils;
import cn.iocoder.yudao.framework.common.util.number.MoneyUtils;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientFactory;
import cn.iocoder.yudao.framework.pay.core.client.dto.order.PayOrderRespDTO;
import cn.iocoder.yudao.framework.pay.core.client.dto.order.PayOrderUnifiedReqDTO;
import cn.iocoder.yudao.framework.pay.core.enums.order.PayOrderStatusRespEnum;
@ -60,9 +59,6 @@ public class PayOrderServiceImpl implements PayOrderService {
@Resource
private PayProperties payProperties;
@Resource
private PayClientFactory payClientFactory;
@Resource
private PayOrderMapper orderMapper;
@Resource
@ -134,7 +130,7 @@ public class PayOrderServiceImpl implements PayOrderService {
PayOrderDO order = validateOrderCanSubmit(reqVO.getId());
// 1.32 校验支付渠道是否有效
PayChannelDO channel = validateChannelCanSubmit(order.getAppId(), reqVO.getChannelCode());
PayClient client = payClientFactory.getPayClient(channel.getId());
PayClient client = channelService.getPayClient(channel.getId());
// 2. 插入 PayOrderExtensionDO
String no = noRedisDAO.generate(payProperties.getOrderNoPrefix());
@ -205,7 +201,7 @@ public class PayOrderServiceImpl implements PayOrderService {
throw exception(ORDER_EXTENSION_IS_PAID);
}
// 情况二调用三方接口查询支付单状态是不是已支付
PayClient payClient = payClientFactory.getPayClient(orderExtension.getChannelId());
PayClient payClient = channelService.getPayClient(orderExtension.getChannelId());
if (payClient == null) {
log.error("[validateOrderCanSubmit][渠道编号({}) 找不到对应的支付客户端]", orderExtension.getChannelId());
return;
@ -224,7 +220,7 @@ public class PayOrderServiceImpl implements PayOrderService {
appService.validPayApp(appId);
// 校验支付渠道是否有效
PayChannelDO channel = channelService.validPayChannel(appId, channelCode);
PayClient client = payClientFactory.getPayClient(channel.getId());
PayClient client = channelService.getPayClient(channel.getId());
if (client == null) {
log.error("[validatePayChannelCanSubmit][渠道编号({}) 找不到对应的支付客户端]", channel.getId());
throw exception(CHANNEL_NOT_FOUND);
@ -458,7 +454,7 @@ public class PayOrderServiceImpl implements PayOrderService {
private boolean syncOrder(PayOrderExtensionDO orderExtension) {
try {
// 1.1 查询支付订单信息
PayClient payClient = payClientFactory.getPayClient(orderExtension.getChannelId());
PayClient payClient = channelService.getPayClient(orderExtension.getChannelId());
if (payClient == null) {
log.error("[syncOrder][渠道编号({}) 找不到对应的支付客户端]", orderExtension.getChannelId());
return false;
@ -513,7 +509,7 @@ public class PayOrderServiceImpl implements PayOrderService {
return false;
}
// 情况二调用三方接口查询支付单状态是不是已支付/已退款
PayClient payClient = payClientFactory.getPayClient(orderExtension.getChannelId());
PayClient payClient = channelService.getPayClient(orderExtension.getChannelId());
if (payClient == null) {
log.error("[expireOrder][渠道编号({}) 找不到对应的支付客户端]", orderExtension.getChannelId());
return false;

View File

@ -4,7 +4,6 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientFactory;
import cn.iocoder.yudao.framework.pay.core.client.dto.refund.PayRefundRespDTO;
import cn.iocoder.yudao.framework.pay.core.client.dto.refund.PayRefundUnifiedReqDTO;
import cn.iocoder.yudao.framework.pay.core.enums.refund.PayRefundStatusRespEnum;
@ -52,9 +51,6 @@ public class PayRefundServiceImpl implements PayRefundService {
@Resource
private PayProperties payProperties;
@Resource
private PayClientFactory payClientFactory;
@Resource
private PayRefundMapper refundMapper;
@Resource
@ -102,7 +98,7 @@ public class PayRefundServiceImpl implements PayRefundService {
PayOrderDO order = validatePayOrderCanRefund(reqDTO);
// 1.3 校验支付渠道是否有效
PayChannelDO channel = channelService.validPayChannel(order.getChannelId());
PayClient client = payClientFactory.getPayClient(channel.getId());
PayClient client = channelService.getPayClient(channel.getId());
if (client == null) {
log.error("[refund][渠道编号({}) 找不到对应的支付客户端]", channel.getId());
throw exception(CHANNEL_NOT_FOUND);
@ -305,7 +301,7 @@ public class PayRefundServiceImpl implements PayRefundService {
private boolean syncRefund(PayRefundDO refund) {
try {
// 1.1 查询退款订单信息
PayClient payClient = payClientFactory.getPayClient(refund.getChannelId());
PayClient payClient = channelService.getPayClient(refund.getChannelId());
if (payClient == null) {
log.error("[syncRefund][渠道编号({}) 找不到对应的支付客户端]", refund.getChannelId());
return false;

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.pay.service.channel;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import cn.iocoder.yudao.framework.pay.core.client.PayClient;
import cn.iocoder.yudao.framework.pay.core.client.PayClientFactory;
import cn.iocoder.yudao.framework.pay.core.client.impl.alipay.AlipayPayClientConfig;
import cn.iocoder.yudao.framework.pay.core.client.impl.weixin.WxPayClientConfig;
@ -12,30 +13,26 @@ import cn.iocoder.yudao.module.pay.controller.admin.channel.vo.PayChannelUpdateR
import cn.iocoder.yudao.module.pay.dal.dataobject.channel.PayChannelDO;
import cn.iocoder.yudao.module.pay.dal.mysql.channel.PayChannelMapper;
import com.alibaba.fastjson.JSON;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.annotation.Import;
import javax.annotation.Resource;
import javax.validation.Validator;
import java.time.Duration;
import java.util.Collections;
import java.util.List;
import static cn.iocoder.yudao.framework.common.util.date.LocalDateTimeUtils.addTime;
import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertPojoEquals;
import static cn.iocoder.yudao.framework.test.core.util.AssertUtils.assertServiceException;
import static cn.iocoder.yudao.framework.test.core.util.RandomUtils.*;
import static cn.iocoder.yudao.module.pay.enums.ErrorCodeConstants.*;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;
@Import({PayChannelServiceImpl.class})
public class PayChannelServiceTest extends BaseDbUnitTest {
private static final String ALIPAY_SERVER_URL = "https://openapi.alipay.com/gateway.do";
@Resource
private PayChannelServiceImpl channelService;
@ -47,45 +44,6 @@ public class PayChannelServiceTest extends BaseDbUnitTest {
@MockBean
private Validator validator;
@BeforeEach
public void setUp() {
channelService.setChannelCache(null);
}
@Test
public void testInitLocalCache() {
// mock 数据
PayChannelDO dbChannel = randomPojo(PayChannelDO.class,
o -> o.setConfig(randomWxPayClientConfig()));
channelMapper.insert(dbChannel);// @Sql: 先插入出一条存在的数据
// 调用
channelService.initLocalCache();
// 校验缓存
assertEquals(1, channelService.getChannelCache().size());
assertEquals(dbChannel, channelService.getChannelCache().get(0));
}
@Test
public void testRefreshLocalCache() {
// mock 数据 01
PayChannelDO dbChannel = randomPojo(PayChannelDO.class,
o -> o.setConfig(randomWxPayClientConfig()).setUpdateTime(addTime(Duration.ofMinutes(-2))));
channelMapper.insert(dbChannel);// @Sql: 先插入出一条存在的数据
channelService.initLocalCache();
// mock 数据 02
PayChannelDO dbChannel02 = randomPojo(PayChannelDO.class,
o -> o.setConfig(randomWxPayClientConfig()));
channelMapper.insert(dbChannel02);// @Sql: 先插入出一条存在的数据
// 调用
channelService.refreshLocalCache();
// 校验缓存
assertEquals(2, channelService.getChannelCache().size());
assertEquals(dbChannel, channelService.getChannelCache().get(0));
assertEquals(dbChannel02, channelService.getChannelCache().get(1));
}
@Test
public void testCreateChannel_success() {
// 准备参数
@ -103,8 +61,7 @@ public class PayChannelServiceTest extends BaseDbUnitTest {
assertPojoEquals(reqVO, channel, "config");
assertPojoEquals(config, channel.getConfig());
// 校验缓存
assertEquals(1, channelService.getChannelCache().size());
assertEquals(channel, channelService.getChannelCache().get(0));
assertNull(channelService.getClientCache().getIfPresent(channelId));
}
@Test
@ -146,8 +103,7 @@ public class PayChannelServiceTest extends BaseDbUnitTest {
assertPojoEquals(reqVO, channel, "config");
assertPojoEquals(config, channel.getConfig());
// 校验缓存
assertEquals(1, channelService.getChannelCache().size());
assertEquals(channel, channelService.getChannelCache().get(0));
assertNull(channelService.getClientCache().getIfPresent(channel.getId()));
}
@Test
@ -179,7 +135,7 @@ public class PayChannelServiceTest extends BaseDbUnitTest {
// 校验数据不存在了
assertNull(channelMapper.selectById(id));
// 校验缓存
assertEquals(0, channelService.getChannelCache().size());
assertNull(channelService.getClientCache().getIfPresent(id));
}
@Test
@ -344,6 +300,28 @@ public class PayChannelServiceTest extends BaseDbUnitTest {
assertPojoEquals(channel, dbChannel03);
}
@Test
public void testGetPayClient() {
// mock 数据
PayChannelDO channel = randomPojo(PayChannelDO.class, o -> {
o.setCode(PayChannelEnum.ALIPAY_APP.getCode());
o.setConfig(randomAlipayPayClientConfig());
});
channelMapper.insert(channel);
// mock 参数
Long id = channel.getId();
// mock 方法
PayClient mockClient = mock(PayClient.class);
when(payClientFactory.getPayClient(eq(id))).thenReturn(mockClient);
// 调用
PayClient client = channelService.getPayClient(id);
// 断言
assertSame(client, mockClient);
verify(payClientFactory).createOrUpdatePayClient(eq(id), eq(channel.getCode()),
eq(channel.getConfig()));
}
public WxPayClientConfig randomWxPayClientConfig() {
return new WxPayClientConfig()
.setAppId(randomString())

View File

@ -18,6 +18,7 @@ import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;
import org.springframework.validation.annotation.Validated;
import javax.annotation.PostConstruct;
@ -42,6 +43,11 @@ import static cn.iocoder.yudao.module.system.enums.ErrorCodeConstants.SENSITIVE_
@Validated
public class SensitiveWordServiceImpl implements SensitiveWordService {
/**
* 是否开启敏感词功能
*/
private static final Boolean ENABLED = false;
/**
* 敏感词列表缓存
*/
@ -75,6 +81,10 @@ public class SensitiveWordServiceImpl implements SensitiveWordService {
*/
@PostConstruct
public void initLocalCache() {
if (!ENABLED) {
return;
}
// 第一步查询数据
List<SensitiveWordDO> sensitiveWords = sensitiveWordMapper.selectList();
log.info("[initLocalCache][缓存敏感词,数量为:{}]", sensitiveWords.size());
@ -216,6 +226,9 @@ public class SensitiveWordServiceImpl implements SensitiveWordService {
@Override
public List<String> validateText(String text, List<String> tags) {
Assert.isTrue(ENABLED, "敏感词功能未开启,请将 ENABLED 设置为 true");
// 无标签时默认所有
if (CollUtil.isEmpty(tags)) {
return defaultSensitiveWordTrie.validate(text);
}
@ -233,6 +246,9 @@ public class SensitiveWordServiceImpl implements SensitiveWordService {
@Override
public boolean isTextValid(String text, List<String> tags) {
Assert.isTrue(ENABLED, "敏感词功能未开启,请将 ENABLED 设置为 true");
// 无标签时默认所有
if (CollUtil.isEmpty(tags)) {
return defaultSensitiveWordTrie.isValid(text);
}