【增加】Midjourney Proxy 回调通知

This commit is contained in:
cherishsince 2024-05-31 17:05:28 +08:00
parent 36b6fee0ec
commit 56e8707e38
6 changed files with 67 additions and 40 deletions

View File

@ -11,7 +11,10 @@ import lombok.Data;
* @since 1.0
*/
@Data
public class MidjourneyNotifyVO {
public class MidjourneyNotifyReqVO {
@Schema(description = "job id")
private String id;
@Schema(description = "任务类型")
private MidjourneyTaskActionEnum action;

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.image;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
@ -13,7 +14,6 @@ import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
@ -74,9 +74,9 @@ public class AiImageController {
return success(aiImageService.deleteIdMy(id, getLoginUserId()));
}
@Operation(summary = "删除【我的】绘画记录")
@Operation(summary = "midjourney proxy - 回调通知")
@RequestMapping("/midjourney-notify")
public CommonResult<Boolean> midjourneyNotify(HttpServletRequest request) {
return success(true);
public CommonResult<Boolean> midjourneyNotify(MidjourneyNotifyReqVO notifyReqVO) {
return success(aiImageService.midjourneyNotify(getLoginUserId(), notifyReqVO));
}
}

View File

@ -28,6 +28,9 @@ public class AiImageDO extends BaseDO {
@Schema(description = "用户编号")
private Long userId;
@Schema(description = "midjourney proxy 关联的 job id")
private String jobId;
@Schema(description = "提示词")
private String prompt;

View File

@ -4,7 +4,6 @@ import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import org.apache.ibatis.annotations.Mapper;
import org.springframework.stereotype.Repository;
/**
* AI 绘图 Mapper
@ -26,4 +25,14 @@ public interface AiImageMapper extends BaseMapperX<AiImageDO> {
return;
}
/**
* 查询 - 根据 job id
*
* @param id
* @return
*/
default AiImageDO selectByJobId(String id) {
return this.selectOne(new LambdaQueryWrapperX<AiImageDO>().eq(AiImageDO::getJobId, id));
}
}

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
@ -65,4 +66,12 @@ public interface AiImageService {
*/
Boolean deleteIdMy(Long id, Long loginUserId);
/**
* midjourney proxy - 回调通知
*
* @param loginUserId
* @param notifyReqVO
* @return
*/
Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO);
}

View File

@ -1,5 +1,7 @@
package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
@ -14,9 +16,14 @@ import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyModelEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneySubmitCodeEnum;
import cn.iocoder.yudao.module.ai.client.enums.MidjourneyTaskStatusEnum;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyNotifyReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
@ -36,15 +43,9 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
// TODO @fan注释优化下哈
/**
* AI 绘画(接入 dall2/dall3midjourney)
*
@ -56,9 +57,6 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
@Slf4j
public class AiImageServiceImpl implements AiImageService {
// TODO @fan使用 @Resource 注入
// TODO @fanimageMapper
@Resource
private AiImageMapper imageMapper;
@Resource
@ -173,19 +171,16 @@ public class AiImageServiceImpl implements AiImageService {
// 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
String updateStatus = null;
String errorMessage = null;
Map<String, Object> drawResponse = new HashMap<>();
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
updateStatus = AiImageStatusEnum.FAIL.getStatus();
errorMessage = submitRespVO.getDescription();
} else {
drawResponse.put("jobId", submitRespVO.getResult());
}
imageMapper.updateById(new AiImageDO()
.setId(aiImageDO.getId())
.setStatus(updateStatus)
.setErrorMessage(errorMessage)
.setDrawResponse(drawResponse)
.setJobId(submitRespVO.getResult())
);
return aiImageDO.getId();
}
@ -228,28 +223,36 @@ public class AiImageServiceImpl implements AiImageService {
return imageMapper.deleteById(id) > 0;
}
private void validateMessageId(String mjMessageId, String messageId) {
if (!mjMessageId.equals(messageId)) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_MESSAGE_ID_INCORRECT);
@Override
public Boolean midjourneyNotify(Long loginUserId, MidjourneyNotifyReqVO notifyReqVO) {
// 1根据 job id 查询关联的 image
AiImageDO image = imageMapper.selectByJobId(notifyReqVO.getId());
if (image == null) {
log.warn("midjourneyNotify 回调的 jobId 不存在! jobId: {}", notifyReqVO.getId());
return false;
}
}
private AiImageMidjourneyOperationsVO validateMidjourneyOperationsExists(List<AiImageMidjourneyOperationsVO> midjourneyOperations, String operateId) {
for (AiImageMidjourneyOperationsVO midjourneyOperation : midjourneyOperations) {
if (midjourneyOperation.getCustom_id().equals(operateId)) {
return midjourneyOperation;
}
//
String imageStatus = null;
if (MidjourneyTaskStatusEnum.SUCCESS == notifyReqVO.getStatus()) {
imageStatus = AiImageStatusEnum.COMPLETE.getStatus();
} else if (MidjourneyTaskStatusEnum.FAILURE == notifyReqVO.getStatus()) {
imageStatus = AiImageStatusEnum.FAIL.getStatus();
}
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_OPERATION_NOT_EXISTS);
}
private List<AiImageMidjourneyOperationsVO> getMidjourneyOperations(AiImageDO aiImageDO) {
// if (StrUtil.isBlank(aiImageDO.getMjOperations())) {
// return Collections.emptyList();
// }
// return JsonUtils.parseArray(aiImageDO.getMjOperations(), AiImageMidjourneyOperationsVO.class);
return null;
// 2上传图片
String filePath = null;
if (!StrUtil.isBlank(notifyReqVO.getImageUrl())) {
filePath = fileApi.createFile(HttpUtil.downloadBytes(notifyReqVO.getImageUrl()));
}
// 2更新 image 状态
imageMapper.updateById(
new AiImageDO()
.setId(image.getId())
.setStatus(imageStatus)
.setPicUrl(filePath)
.setOriginalPicUrl(notifyReqVO.getImageUrl())
.setDrawResponse(BeanUtil.beanToMap(notifyReqVO))
);
return true;
}
private AiImageDO validateExists(Long id) {