【增加】Midjourney Proxy 回调通知

This commit is contained in:
cherishsince 2024-05-31 17:05:28 +08:00
parent 36b6fee0ec
commit 56e8707e38
6 changed files with 67 additions and 40 deletions

View File

@ -11,7 +11,10 @@ import lombok.Data;
* @since 1.0 * @since 1.0
*/ */
@Data @Data
public class MidjourneyNotifyVO { public class MidjourneyNotifyReqVO {
@Schema(description = "job id")
private String id;
@Schema(description = "任务类型") @Schema(description = "任务类型")
private MidjourneyTaskActionEnum action; private MidjourneyTaskActionEnum action;

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.image;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
@ -13,7 +14,6 @@ import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
@ -74,9 +74,9 @@ public class AiImageController {
return success(aiImageService.deleteIdMy(id, getLoginUserId())); return success(aiImageService.deleteIdMy(id, getLoginUserId()));
} }
@Operation(summary = "删除【我的】绘画记录") @Operation(summary = "midjourney proxy - 回调通知")
@RequestMapping("/midjourney-notify") @RequestMapping("/midjourney-notify")
public CommonResult<Boolean> midjourneyNotify(HttpServletRequest request) { public CommonResult<Boolean> midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
return success(true); return success(aiImageService.midjourneyNotify(getLoginUserId(), notifyReqVO));
} }
} }

View File

@ -28,6 +28,9 @@ public class AiImageDO extends BaseDO {
@Schema(description = "用户编号") @Schema(description = "用户编号")
private Long userId; private Long userId;
@Schema(description = "midjourney proxy 关联的 job id")
private String jobId;
@Schema(description = "提示词") @Schema(description = "提示词")
private String prompt; private String prompt;

View File

@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
import org.springframework.stereotype.Repository;
/** /**
* AI 绘图 Mapper * AI 绘图 Mapper
@ -26,4 +25,14 @@ public interface AiImageMapper extends BaseMapperX<AiImageDO> {
return; return;
} }
/**
* 查询 - 根据 job id
*
* @param id
* @return
*/
default AiImageDO selectByJobId(String id) {
return this.selectOne(new LambdaQueryWrapperX<AiImageDO>().eq(AiImageDO::getJobId, id));
}
} }

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.image; package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
@ -65,4 +66,12 @@ public interface AiImageService {
*/ */
Boolean deleteIdMy(Long id, Long loginUserId); Boolean deleteIdMy(Long id, Long loginUserId);
/**
* midjourney proxy - 回调通知
*
* @param loginUserId
* @param notifyReqVO
* @return
*/
Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO);
} }

View File

@ -1,5 +1,7 @@
package cn.iocoder.yudao.module.ai.service.image; package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
@ -14,9 +16,14 @@ import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient; import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum; import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum; import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO; import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO; import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*; import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper; import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum; import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
@ -36,15 +43,9 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
// TODO @fan注释优化下哈
/** /**
* AI 绘画(接入 dall2/dall3midjourney) * AI 绘画(接入 dall2/dall3midjourney)
* *
@ -56,9 +57,6 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
@Slf4j @Slf4j
public class AiImageServiceImpl implements AiImageService { public class AiImageServiceImpl implements AiImageService {
// TODO @fan使用 @Resource 注入
// TODO @fanimageMapper
@Resource @Resource
private AiImageMapper imageMapper; private AiImageMapper imageMapper;
@Resource @Resource
@ -173,19 +171,16 @@ public class AiImageServiceImpl implements AiImageService {
// 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)) // 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
String updateStatus = null; String updateStatus = null;
String errorMessage = null; String errorMessage = null;
Map<String, Object> drawResponse = new HashMap<>();
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) { if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
updateStatus = AiImageStatusEnum.FAIL.getStatus(); updateStatus = AiImageStatusEnum.FAIL.getStatus();
errorMessage = submitRespVO.getDescription(); errorMessage = submitRespVO.getDescription();
} else {
drawResponse.put("jobId", submitRespVO.getResult());
} }
imageMapper.updateById(new AiImageDO() imageMapper.updateById(new AiImageDO()
.setId(aiImageDO.getId()) .setId(aiImageDO.getId())
.setStatus(updateStatus) .setStatus(updateStatus)
.setErrorMessage(errorMessage) .setErrorMessage(errorMessage)
.setDrawResponse(drawResponse) .setJobId(submitRespVO.getResult())
); );
return aiImageDO.getId(); return aiImageDO.getId();
} }
@ -228,28 +223,36 @@ public class AiImageServiceImpl implements AiImageService {
return imageMapper.deleteById(id) > 0; return imageMapper.deleteById(id) > 0;
} }
private void validateMessageId(String mjMessageId, String messageId) { @Override
if (!mjMessageId.equals(messageId)) { public Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_MESSAGE_ID_INCORRECT); // 1根据 job id 查询关联的 image
AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
if (image == null) {
log.warn("midjourneyNotify 回调的 jobId 不存在! jobId: {}", notifyReqVO.getId());
return false;
} }
} //
String imageStatus = null;
private AiImageMidjourneyOperationsVO validateMidjourneyOperationsExists(List<AiImageMidjourneyOperationsVO> midjourneyOperations, String operateId) { if (MidjourneyTaskStatusEnum.SUCCESS == notifyReqVO.getStatus()) {
for (AiImageMidjourneyOperationsVO midjourneyOperation : midjourneyOperations) { imageStatus = AiImageStatusEnum.COMPLETE.getStatus();
if (midjourneyOperation.getCustom_id().equals(operateId)) { } else if (MidjourneyTaskStatusEnum.FAILURE == notifyReqVO.getStatus()) {
return midjourneyOperation; imageStatus = AiImageStatusEnum.FAIL.getStatus();
}
} }
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_OPERATION_NOT_EXISTS); // 2上传图片
} String filePath = null;
if (!StrUtil.isBlank(notifyReqVO.getImageUrl())) {
filePath = fileApi.createFile(HttpUtil.downloadBytes(notifyReqVO.getImageUrl()));
private List<AiImageMidjourneyOperationsVO> getMidjourneyOperations(AiImageDO aiImageDO) { }
// if (StrUtil.isBlank(aiImageDO.getMjOperations())) { // 2更新 image 状态
// return Collections.emptyList(); imageMapper.updateById(
// } new AiImageDO()
// return JsonUtils.parseArray(aiImageDO.getMjOperations(), AiImageMidjourneyOperationsVO.class); .setId(image.getId())
return null; .setStatus(imageStatus)
.setPicUrl(filePath)
.setOriginalPicUrl(notifyReqVO.getImageUrl())
.setDrawResponse(BeanUtil.beanToMap(notifyReqVO))
);
return true;
} }
private AiImageDO validateExists(Long id) { private AiImageDO validateExists(Long id) {