【增加】增加 midjourney 提交任务

This commit is contained in:
cherishsince 2024-05-30 16:29:28 +08:00
parent 782f108516
commit a1f738dd81
6 changed files with 126 additions and 68 deletions

View File

@ -0,0 +1,36 @@
package cn.iocoder.yudao.module.ai.client.vo;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.List;
/**
* Midjourney 提交任务 code 枚举
*
* @author fansili
* @time 2024/5/30 14:33
* @since 1.0
*/
@Getter
@AllArgsConstructor
public enum MidjourneySubmitCodeEnum {
// 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
SUBMIT_SUCCESS("1", "提交成功"),
ALREADY_EXISTS("1", "已存在"),
QUEUING("22", "排队中"),
;
public static final List<String> SUCCESS_CODES = Lists.newArrayList(
SUBMIT_SUCCESS.code,
ALREADY_EXISTS.code,
QUEUING.code
);
private String code;
private String name;
}

View File

@ -3,6 +3,8 @@ package cn.iocoder.yudao.module.ai.client.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.Map;
/**
* MidjourneyImagine 请求
*
@ -20,7 +22,7 @@ public class MidjourneySubmitRespVO {
private String description;
@Schema(description = "扩展字段")
private String properties;
private Map<String, Object> properties;
@Schema(description = "任务ID")
private String result;

View File

@ -3,13 +3,17 @@ package cn.iocoder.yudao.module.ai.controller.admin.image;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageMyRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.service.image.AiImageService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
@ -49,32 +53,17 @@ public class AiImageController {
}
// TODO @fan建议把 dallDrawingmidjourney 融合成一个 draw 接口异步绘制然后返回一个 id 给前端前端通过 get 接口轮询直到获取到生成成功
// TODO @芋艿: 参数差异较大
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
@PostMapping("/dall")
public CommonResult<Long> dall(@Validated @RequestBody AiImageDallReqVO req) {
return success(aiImageService.dall(getLoginUserId(), req));
}
@Operation(summary = "midjourney绘画", description = "midjourney图片绘画流程1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
@PostMapping("/midjourney")
public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReqVO req) {
aiImageService.midjourney(req);
return success(null);
}
@Operation(summary = "midjourney绘画操作", description = "一般有选择图片、放大、换一批...")
@PostMapping("/midjourney-operate")
public CommonResult<Void> midjourneyOperate(@Validated @RequestBody AiImageMidjourneyOperateReqVO req) {
aiImageService.midjourneyOperate(req);
return success(null);
}
// TODO @fan要不先不要 midjourneyOperatecancelMidjourney 接口哈
@Operation(summary = "取消 midjourney 绘画", description = "取消 midjourney 绘画")
@PostMapping("/cancel-midjourney")
public CommonResult<Void> cancelMidjourney(@RequestParam("id") Long id) {
// @范 这里实现mj取消逻辑
return success(null);
@Operation(summary = "midjourney-imagine 绘画", description = "...")
@PostMapping("/midjourney/imagine")
public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) {
return success(aiImageService.midjourneyImagine(getLoginUserId(), req));
}
@Operation(summary = "删除【我的】绘画记录")
@ -83,4 +72,10 @@ public class AiImageController {
public CommonResult<Boolean> deleteIdMy(@RequestParam("id") Long id) {
return success(aiImageService.deleteIdMy(id, getLoginUserId()));
}
@Operation(summary = "删除【我的】绘画记录")
@RequestMapping("/midjourney-notify")
public CommonResult<Boolean> midjourneyNotify(HttpServletRequest request) {
return success(true);
}
}

View File

@ -1,9 +1,12 @@
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
/**
* midjourney req
*
@ -13,17 +16,15 @@ import lombok.experimental.Accessors;
*/
@Data
@Accessors(chain = true)
public class AiImageMidjourneyReqVO {
public class AiImageMidjourneyImagineReqVO {
@Schema(description = "提示词")
@NotNull(message = "提示词不能为空!")
private String prompt;
@Schema(description = "绘画比例 1:1、3:4、4:3、9:16、16:9")
private String size;
@Schema(description = "模型(midjourney、niji)")
private String model;
@Schema(description = "风格")
private String style;
@Schema(description = "参考图")
private String referImage;
@Schema(description = "垫图(参考图)base64数组")
private List<String> base64Array;
}

View File

@ -3,8 +3,8 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
/**
@ -44,10 +44,11 @@ public interface AiImageService {
/**
* midjourney 图片生成
*
* @param loginUserId
* @param req
* @return
*/
void midjourney(AiImageMidjourneyReqVO req);
Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req);
/**
* midjourney 操作(u1u2放大换一批...)

View File

@ -1,7 +1,7 @@
package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.util.IdUtil;
import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
import cn.iocoder.yudao.framework.ai.core.exception.AiException;
@ -11,6 +11,10 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.AiCommonConstants;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitCodeEnum;
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
@ -18,21 +22,22 @@ import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import com.google.common.collect.ImmutableMap;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.image.ImageGeneration;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -59,28 +64,11 @@ public class AiImageServiceImpl implements AiImageService {
private FileApi fileApi;
@Resource
private OpenAiImageClient openAiImageClient;
@Resource
private MidjourneyWebSocketStarter midjourneyWebSocketStarter;
@Resource
private MidjourneyInteractionsApi midjourneyInteractionsApi;
@Autowired
private MidjourneyProxyClient midjourneyProxyClient;
// TODO @fan mj proxy
@PostConstruct
public void startMidjourney() {
// todo @fan 暂时注释掉
// log.info("midjourney web socket starter...");
// midjourneyWebSocketStarter.start(new WssNotify() {
// @Override
// public void notify(int code, String message) {
// log.info("code: {}, message: {}", code, message);
// if (message.contains("Authentication failed")) {
// // TODO 芋艿这里看怎么处理token无效的时候会认证失败
// // 认证失败
// log.error("midjourney socket 认证失败检查token是否失效!");
// }
// }
// });
}
@Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
private String midjourneyNotifyUrl;
@Override
public PageResult<AiImageDO> getImagePageMy(Long loginUserId, AiImageListReqVO req) {
@ -143,18 +131,53 @@ public class AiImageServiceImpl implements AiImageService {
@Override
@Transactional(rollbackFor = Exception.class)
public void midjourney(AiImageMidjourneyReqVO req) {
// 保存数据库
String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
// todo
// AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
// null, null, AiImageStatusEnum.SUBMIT, null,
// messageId, null, null);
// 提交 midjourney 任务
Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
if (!imagine) {
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) {
// 1构建 AiImageDO
AiImageDO aiImageDO = new AiImageDO();
aiImageDO.setId(null);
aiImageDO.setUserId(loginUserId);
aiImageDO.setPrompt(req.getPrompt());
aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
// todo @范 平台需要转换(mj 模型一般分版本)
aiImageDO.setModel(null);
aiImageDO.setWidth(null);
aiImageDO.setHeight(null);
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
aiImageDO.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus());
aiImageDO.setPicUrl(null);
aiImageDO.setOriginalPicUrl(null);
aiImageDO.setDrawRequest(null);
aiImageDO.setDrawResponse(null);
aiImageDO.setErrorMessage(null);
// 2保存 image
imageMapper.insert(aiImageDO);
// 3调用 MidjourneyProxy 提交任务
MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
imagineReqVO.setState(String.valueOf(aiImageDO.getId()));
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
// 4保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
String updateStatus = null;
String errorMessage = null;
Map<String, Object> drawResponse = new HashMap<>();
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
updateStatus = AiImageStatusEnum.FAIL.getStatus();
errorMessage = submitRespVO.getDescription();
} else {
drawResponse.put("jobId", submitRespVO.getResult());
}
imageMapper.updateById(new AiImageDO()
.setId(aiImageDO.getId())
.setStatus(updateStatus)
.setErrorMessage(errorMessage)
.setDrawResponse(drawResponse)
);
return aiImageDO.getId();
}
@Transactional(rollbackFor = Exception.class)