【同步】AI:最新 MJ 的 code review

This commit is contained in:
YunaiV 2024-06-06 20:18:18 +08:00
parent 90f920f55b
commit e781129dbe
9 changed files with 48 additions and 22 deletions

View File

@ -17,7 +17,7 @@ import org.springframework.web.client.RestTemplate;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
// TODO @fan这个写到 starter-ai 里哈搞个 MidjourneyApi参考 https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java 的风格写哈 // TODO @fan高优这个写到 starter-ai 里哈搞个 MidjourneyApi参考 https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java 的风格写哈
/** /**
* Midjourney Proxy 客户端 * Midjourney Proxy 客户端
* *

View File

@ -11,18 +11,14 @@ import java.util.List;
* Midjourney 提交任务 code 枚举 * Midjourney 提交任务 code 枚举
* *
* @author fansili * @author fansili
* @time 2024/5/30 14:33
* @since 1.0
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
public enum MidjourneySubmitCodeEnum { public enum MidjourneySubmitCodeEnum {
// 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
SUBMIT_SUCCESS("1", "提交成功"), SUBMIT_SUCCESS("1", "提交成功"),
ALREADY_EXISTS("21", "已存在"), ALREADY_EXISTS("21", "已存在"),
QUEUING("22", "排队中"), QUEUING("22", "排队中"),
; ;
public static final List<String> SUCCESS_CODES = Lists.newArrayList( public static final List<String> SUCCESS_CODES = Lists.newArrayList(
@ -31,7 +27,7 @@ public enum MidjourneySubmitCodeEnum {
QUEUING.code QUEUING.code
); );
private String code; private final String code;
private String name; private final String name;
} }

View File

@ -63,25 +63,27 @@ public class AiImageController {
return success(true); return success(true);
} }
// ================ midjourney 接口 // ================ midjourney 接口 ================
@Operation(summary = "midjourney-imagine 绘画", description = "...") @Operation(summary = "Midjourney imagine绘画")
@PostMapping("/midjourney/imagine") @PostMapping("/midjourney/imagine")
public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) { public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) {
return success(imageService.midjourneyImagine(getLoginUserId(), req)); return success(imageService.midjourneyImagine(getLoginUserId(), req));
} }
@Operation(summary = "midjourney proxy - 回调通知") @Operation(summary = "Midjourney 回调通知", description = "由 Midjourney Proxy 回调")
@PostMapping("/midjourney-notify") @PostMapping("/midjourney-notify")
@PermitAll @PermitAll
public CommonResult<Boolean> midjourneyNotify(@RequestBody MidjourneyNotifyReqVO notifyReqVO) { public CommonResult<Boolean> midjourneyNotify(@RequestBody MidjourneyNotifyReqVO notifyReqVO) {
return success(imageService.midjourneyNotify(notifyReqVO)); return success(imageService.midjourneyNotify(notifyReqVO));
} }
@Operation(summary = "midjourney - action(放大、缩小、U1、U2...)") @Operation(summary = "Midjourney Action", description = "例如说放大、缩小、U1、U2 等")
@GetMapping("/midjourney/action") @GetMapping("/midjourney/action")
// TODO @fanidcustomerId swagger 注解
public CommonResult<Boolean> midjourneyAction(@RequestParam("id") Long imageId, public CommonResult<Boolean> midjourneyAction(@RequestParam("id") Long imageId,
@RequestParam("customId") String customId) { @RequestParam("customId") String customId) {
return success(imageService.midjourneyAction(getLoginUserId(), imageId, customId)); return success(imageService.midjourneyAction(getLoginUserId(), imageId, customId));
} }
} }

View File

@ -50,9 +50,11 @@ public class AiImageRespVO {
@Schema(description = "绘画 response") @Schema(description = "绘画 response")
private MidjourneyNotifyReqVO response; private MidjourneyNotifyReqVO response;
// TODO @fan进度是百分比还是一个数字哈感觉这个可以统一成通用字段
@Schema(description = "mj 进度") @Schema(description = "mj 进度")
private String progress; private String progress;
@Schema(description = "mj buttons 按钮") @Schema(description = "mj buttons 按钮")
private List<MidjourneyNotifyReqVO.Button> buttons; private List<MidjourneyNotifyReqVO.Button> buttons;
} }

View File

@ -123,6 +123,7 @@ public class AiImageDO extends BaseDO {
*/ */
private String errorMessage; private String errorMessage;
// TODO @芋艿看看是不是 MidjourneyNotifyReqVO.Button 搞到 MJ API
public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> { public static class ButtonTypeHandler extends AbstractJsonTypeHandler<Object> {
@Override @Override
@ -134,6 +135,7 @@ public class AiImageDO extends BaseDO {
protected String toJson(Object obj) { protected String toJson(Object obj) {
return JsonUtils.toJsonString(obj); return JsonUtils.toJsonString(obj);
} }
} }
} }

View File

@ -30,6 +30,7 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class MidjourneyJob implements JobHandler { public class MidjourneyJob implements JobHandler {
// TODO @fan@Resource
@Autowired @Autowired
private MidjourneyProxyClient midjourneyProxyClient; private MidjourneyProxyClient midjourneyProxyClient;
@Autowired @Autowired
@ -37,10 +38,13 @@ public class MidjourneyJob implements JobHandler {
@Autowired @Autowired
private AiImageService imageService; private AiImageService imageService;
// TODO @fan这个方法建议实现到 AiImageService例如说 midjourneySync返回 int 同步数量
@Override @Override
public String execute(String param) throws Exception { public String execute(String param) throws Exception {
// 1获取 midjourney 平台状态在 进行中 image // 1获取 midjourney 平台状态在 进行中 image
// TODO @fan43 51 其实有点重叠日志建议只打 51
log.info("Midjourney 同步 - 开始..."); log.info("Midjourney 同步 - 开始...");
// TODO @fanJobService 等业务层不要直接使用 LambdaUpdateWrapper这样会导致 mapper 穿透到逻辑层要收敛到 mapper
List<AiImageDO> imageList = imageMapper.selectList( List<AiImageDO> imageList = imageMapper.selectList(
new LambdaUpdateWrapper<AiImageDO>() new LambdaUpdateWrapper<AiImageDO>()
.eq(AiImageDO::getStatus, AiImageStatusEnum.IN_PROGRESS.getStatus()) .eq(AiImageDO::getStatus, AiImageStatusEnum.IN_PROGRESS.getStatus())
@ -48,11 +52,14 @@ public class MidjourneyJob implements JobHandler {
); );
log.info("Midjourney 同步 - 任务数量 {}!", imageList.size()); log.info("Midjourney 同步 - 任务数量 {}!", imageList.size());
if (CollUtil.isEmpty(imageList)) { if (CollUtil.isEmpty(imageList)) {
// TODO @fan51 54其实有点重叠建议 51 挪到 55 之后打
return "Midjourney 同步 - 数量为空!"; return "Midjourney 同步 - 数量为空!";
} }
// 2批量拉去 task 信息 // 2批量拉去 task 信息
// TODO @fanimageList.stream().map(AiImageDO::getTaskId).collect(Collectors.toSet()))可以使用 CollectionUtils.convertSet 简化
List<MidjourneyNotifyReqVO> taskList = midjourneyProxyClient List<MidjourneyNotifyReqVO> taskList = midjourneyProxyClient
.listByCondition(imageList.stream().map(AiImageDO::getTaskId).collect(Collectors.toSet())); .listByCondition(imageList.stream().map(AiImageDO::getTaskId).collect(Collectors.toSet()));
// TODO @fantaskList.stream().collect(Collectors.toMap(MidjourneyNotifyReqVO::getId, o -> o))也可以使用 CollectionUtils.convertMap本质上重用 setmap 转换 convert 简化
Map<String, MidjourneyNotifyReqVO> taskIdMap = taskList.stream().collect(Collectors.toMap(MidjourneyNotifyReqVO::getId, o -> o)); Map<String, MidjourneyNotifyReqVO> taskIdMap = taskList.stream().collect(Collectors.toMap(MidjourneyNotifyReqVO::getId, o -> o));
// 3更新 image 状态 // 3更新 image 状态
List<AiImageDO> updateImageList = new ArrayList<>(); List<AiImageDO> updateImageList = new ArrayList<>();
@ -62,13 +69,16 @@ public class MidjourneyJob implements JobHandler {
log.warn("Midjourney 同步 - {} 任务为空!", aiImageDO.getTaskId()); log.warn("Midjourney 同步 - {} 任务为空!", aiImageDO.getTaskId());
continue; continue;
} }
// TODO @ 3.1 3.2 是不是融合下get然后判空continue
// 3.2 获取通知对象 // 3.2 获取通知对象
MidjourneyNotifyReqVO notifyReqVO = taskIdMap.get(aiImageDO.getTaskId()); MidjourneyNotifyReqVO notifyReqVO = taskIdMap.get(aiImageDO.getTaskId());
// 3.2 构建更新对象 // 3.2 构建更新对象
// TODO @fan建议 List<MidjourneyNotifyReqVO> 作为 imageService 去更新
updateImageList.add(imageService.buildUpdateImage(aiImageDO.getId(), notifyReqVO)); updateImageList.add(imageService.buildUpdateImage(aiImageDO.getId(), notifyReqVO));
} }
// 4批了更新 updateImageList // 4批了更新 updateImageList
imageMapper.updateBatch(updateImageList); imageMapper.updateBatch(updateImageList);
return "Midjourney 同步 - ".concat(String.valueOf(updateImageList.size())).concat(" 任务!"); return "Midjourney 同步 - ".concat(String.valueOf(updateImageList.size())).concat(" 任务!");
} }
} }

View File

@ -36,17 +36,18 @@ public interface AiImageService {
* *
* @param userId 用户编号 * @param userId 用户编号
* @param drawReqVO 绘制请求 * @param drawReqVO 绘制请求
* @return 绘画编号
*/ */
Long drawImage(Long userId, AiImageDrawReqVO drawReqVO); Long drawImage(Long userId, AiImageDrawReqVO drawReqVO);
/** /**
* midjourney 图片生成 * Midjourney imagine绘画
* *
* @param loginUserId * @param userId 用户编号
* @param req * @param imagineReqVO 绘制请求
* @return * @return 绘画编号
*/ */
Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req); Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO imagineReqVO);
/** /**
* 删除我的绘画记录 * 删除我的绘画记录

View File

@ -133,10 +133,12 @@ public class AiImageServiceImpl implements AiImageService {
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) { public Long midjourneyImagine(Long userId, AiImageMidjourneyImagineReqVO req) {
// TODO @fan1 2 应该放在一个 1 里面不然 = = 一个逻辑就显得有很多 1/2/3/4这么分的原因是方便阅读的时候容易理解
// 1构建 AiImageDO // 1构建 AiImageDO
// TODO @fan1aiImageDO 可以缩写成 image 更简洁2可以链式调用把相同的放在一行里这样更容易分组
AiImageDO aiImageDO = new AiImageDO(); AiImageDO aiImageDO = new AiImageDO();
aiImageDO.setUserId(loginUserId); aiImageDO.setUserId(userId);
aiImageDO.setPrompt(req.getPrompt()); aiImageDO.setPrompt(req.getPrompt());
aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()); aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
aiImageDO.setModel(req.getModel()); aiImageDO.setModel(req.getModel());
@ -145,6 +147,7 @@ public class AiImageServiceImpl implements AiImageService {
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
// 2保存 image // 2保存 image
imageMapper.insert(aiImageDO); imageMapper.insert(aiImageDO);
// TODO @fan3 2 之间应该空一行因为这里是开始发起请求第三方是个单独的小块逻辑
// 3调用 MidjourneyProxy 提交任务 // 3调用 MidjourneyProxy 提交任务
MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class); MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
imagineReqVO.setNotifyHook(midjourneyNotifyUrl); imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
@ -152,12 +155,15 @@ public class AiImageServiceImpl implements AiImageService {
imagineReqVO.setState(buildParams(req.getWidth(), imagineReqVO.setState(buildParams(req.getWidth(),
req.getHeight(), req.getVersion(), MidjourneyModelEnum.valueOfModel(req.getModel()))); req.getHeight(), req.getVersion(), MidjourneyModelEnum.valueOfModel(req.getModel())));
// 5提交绘画请求 // 5提交绘画请求
// TODO @fan5 这里失败的情况到底抛出异常还是 RespVO可以参考 OpenAI API 封装
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO); MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
// 6保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)) // 6保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) { if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription()); throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
} }
// TODO @fan7 6 之间应该空一行这样最终这个逻辑就会有 2 个空行3 小块逻辑1插入2调用3更新
// 7构建 imageOptions 参数 // 7构建 imageOptions 参数
// TODO @fan1链式调用2其实可以直接使用 AiImageMidjourneyImagineReqVO不用单独一个 options 类哈
MidjourneyImageOptions imageOptions = new MidjourneyImageOptions() MidjourneyImageOptions imageOptions = new MidjourneyImageOptions()
.setWidth(req.getWidth()) .setWidth(req.getWidth())
.setHeight(req.getHeight()) .setHeight(req.getHeight())
@ -181,10 +187,11 @@ public class AiImageServiceImpl implements AiImageService {
if (ObjUtil.notEqual(image.getUserId(), userId)) { if (ObjUtil.notEqual(image.getUserId(), userId)) {
throw exception(AI_IMAGE_NOT_EXISTS); throw exception(AI_IMAGE_NOT_EXISTS);
} }
// 2删除记录 // 2. 删除记录
imageMapper.deleteById(id); imageMapper.deleteById(id);
} }
// TODO @fan建议返回 void然后如果不存在就抛出异常哈
@Override @Override
public Boolean midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) { public Boolean midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
// 1根据 job id 查询关联的 image // 1根据 job id 查询关联的 image
@ -228,15 +235,19 @@ public class AiImageServiceImpl implements AiImageService {
.setErrorMessage(notifyReqVO.getFailReason()); .setErrorMessage(notifyReqVO.getFailReason());
} }
// TODO @fan1不用返回 Boolean
@Override @Override
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class) // TODO @fan只操作一个 db不用事务哈
public Boolean midjourneyAction(Long loginUserId, Long imageId, String customId) { public Boolean midjourneyAction(Long loginUserId, Long imageId, String customId) {
// TODO @fan11 2可以写成 1.11.2都是在做校验2validateCustomId 可以直接抛出 AI_IMAGE_CUSTOM_ID_NOT_EXISTS 异常一般情况下validateXXX 都是失败抛出异常isXXXValid 返回 truefalse
// 1检查 image // 1检查 image
// TODO @fan1aiImageDO 缩写成 image
AiImageDO aiImageDO = validateImageExists(imageId); AiImageDO aiImageDO = validateImageExists(imageId);
// 2检查 customId // 2检查 customId
if (!validateCustomId(customId, aiImageDO.getButtons())) { if (!validateCustomId(customId, aiImageDO.getButtons())) {
throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS); throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
} }
// TODO @fan2 3 之间空一行
// 3调用 midjourney proxy // 3调用 midjourney proxy
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.action( MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.action(
new MidjourneyActionReqVO() new MidjourneyActionReqVO()
@ -248,8 +259,10 @@ public class AiImageServiceImpl implements AiImageService {
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) { if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription()); throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, submitRespVO.getDescription());
} }
// TODO 6 4 之间空一行
// 4新增 image 记录 // 4新增 image 记录
AiImageDO newImage = BeanUtils.toBean(aiImageDO, AiImageDO.class); AiImageDO newImage = BeanUtils.toBean(aiImageDO, AiImageDO.class);
// TODO @fan最好不要 copy 属性因为未来如果加属性可能会导致额外 copy 最好是 new 赋值下显示声明
// 4.1重置参数 // 4.1重置参数
newImage.setId(null); newImage.setId(null);
newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus()); newImage.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
@ -290,6 +303,7 @@ public class AiImageServiceImpl implements AiImageService {
return SpringUtil.getBean(getClass()); return SpringUtil.getBean(getClass());
} }
// TODO @fan这个是不是应该放在 MJ API 的封装里面搞哈
/** /**
* 构建 Midjourney 自定义参数 * 构建 Midjourney 自定义参数
* *

View File

@ -76,14 +76,13 @@ server:
enabled: true enabled: true
charset: UTF-8 charset: UTF-8
force: true force: true
# ai # ai TODO @fan这个融合到 yudao.ai 那好点哈
ai: ai:
midjourney-proxy: midjourney-proxy:
url: https://api.holdai.top/mj url: https://api.holdai.top/mj
notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify notifyUrl: http://61d61685.r21.cpolar.top/admin-api/ai/image/midjourney-notify
key: sk-c3qxUCVKsPfdQiYU8440E3Fc8dE5424d9cB124A4Ee2489E3 key: sk-c3qxUCVKsPfdQiYU8440E3Fc8dE5424d9cB124A4Ee2489E3
--- #################### 定时任务相关配置 #################### --- #################### 定时任务相关配置 ####################
# Quartz 配置项,对应 QuartzProperties 配置类 # Quartz 配置项,对应 QuartzProperties 配置类