【增加】增加 midjourney 提交任务
This commit is contained in:
parent
782f108516
commit
a1f738dd81
|
@ -0,0 +1,36 @@
|
|||
package cn.iocoder.yudao.module.ai.client.vo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Midjourney 提交任务 code 枚举
|
||||
*
|
||||
* @author fansili
|
||||
* @time 2024/5/30 14:33
|
||||
* @since 1.0
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum MidjourneySubmitCodeEnum {
|
||||
|
||||
// 状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误)
|
||||
SUBMIT_SUCCESS("1", "提交成功"),
|
||||
ALREADY_EXISTS("1", "已存在"),
|
||||
QUEUING("22", "排队中"),
|
||||
|
||||
;
|
||||
|
||||
public static final List<String> SUCCESS_CODES = Lists.newArrayList(
|
||||
SUBMIT_SUCCESS.code,
|
||||
ALREADY_EXISTS.code,
|
||||
QUEUING.code
|
||||
);
|
||||
|
||||
private String code;
|
||||
private String name;
|
||||
|
||||
}
|
|
@ -3,6 +3,8 @@ package cn.iocoder.yudao.module.ai.client.vo;
|
|||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Midjourney:Imagine 请求
|
||||
*
|
||||
|
@ -20,7 +22,7 @@ public class MidjourneySubmitRespVO {
|
|||
private String description;
|
||||
|
||||
@Schema(description = "扩展字段")
|
||||
private String properties;
|
||||
private Map<String, Object> properties;
|
||||
|
||||
@Schema(description = "任务ID")
|
||||
private String result;
|
||||
|
|
|
@ -3,13 +3,17 @@ package cn.iocoder.yudao.module.ai.controller.admin.image;
|
|||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageMyRespVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
||||
import cn.iocoder.yudao.module.ai.service.image.AiImageService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
@ -49,32 +53,17 @@ public class AiImageController {
|
|||
}
|
||||
|
||||
// TODO @fan:建议把 dallDrawing、midjourney 融合成一个 draw 接口,异步绘制;然后返回一个 id 给前端;前端通过 get 接口轮询,直到获取到生成成功
|
||||
// TODO @芋艿: 参数差异较大
|
||||
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
|
||||
@PostMapping("/dall")
|
||||
public CommonResult<Long> dall(@Validated @RequestBody AiImageDallReqVO req) {
|
||||
return success(aiImageService.dall(getLoginUserId(), req));
|
||||
}
|
||||
|
||||
@Operation(summary = "midjourney绘画", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
|
||||
@PostMapping("/midjourney")
|
||||
public CommonResult<Void> midjourney(@Validated @RequestBody AiImageMidjourneyReqVO req) {
|
||||
aiImageService.midjourney(req);
|
||||
return success(null);
|
||||
}
|
||||
|
||||
@Operation(summary = "midjourney绘画操作", description = "一般有选择图片、放大、换一批...")
|
||||
@PostMapping("/midjourney-operate")
|
||||
public CommonResult<Void> midjourneyOperate(@Validated @RequestBody AiImageMidjourneyOperateReqVO req) {
|
||||
aiImageService.midjourneyOperate(req);
|
||||
return success(null);
|
||||
}
|
||||
|
||||
// TODO @fan:要不先不要 midjourneyOperate、cancelMidjourney 接口哈
|
||||
@Operation(summary = "取消 midjourney 绘画", description = "取消 midjourney 绘画")
|
||||
@PostMapping("/cancel-midjourney")
|
||||
public CommonResult<Void> cancelMidjourney(@RequestParam("id") Long id) {
|
||||
// @范 这里实现mj取消逻辑
|
||||
return success(null);
|
||||
@Operation(summary = "midjourney-imagine 绘画", description = "...")
|
||||
@PostMapping("/midjourney/imagine")
|
||||
public CommonResult<Long> midjourneyImagine(@Validated @RequestBody AiImageMidjourneyImagineReqVO req) {
|
||||
return success(aiImageService.midjourneyImagine(getLoginUserId(), req));
|
||||
}
|
||||
|
||||
@Operation(summary = "删除【我的】绘画记录")
|
||||
|
@ -83,4 +72,10 @@ public class AiImageController {
|
|||
public CommonResult<Boolean> deleteIdMy(@RequestParam("id") Long id) {
|
||||
return success(aiImageService.deleteIdMy(id, getLoginUserId()));
|
||||
}
|
||||
|
||||
@Operation(summary = "删除【我的】绘画记录")
|
||||
@RequestMapping("/midjourney-notify")
|
||||
public CommonResult<Boolean> midjourneyNotify(HttpServletRequest request) {
|
||||
return success(true);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.image.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* midjourney req
|
||||
*
|
||||
|
@ -13,17 +16,15 @@ import lombok.experimental.Accessors;
|
|||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class AiImageMidjourneyReqVO {
|
||||
public class AiImageMidjourneyImagineReqVO {
|
||||
|
||||
@Schema(description = "提示词")
|
||||
@NotNull(message = "提示词不能为空!")
|
||||
private String prompt;
|
||||
|
||||
@Schema(description = "绘画比例 1:1、3:4、4:3、9:16、16:9")
|
||||
private String size;
|
||||
@Schema(description = "模型(midjourney、niji)")
|
||||
private String model;
|
||||
|
||||
@Schema(description = "风格")
|
||||
private String style;
|
||||
|
||||
@Schema(description = "参考图")
|
||||
private String referImage;
|
||||
@Schema(description = "垫图(参考图)base64数组")
|
||||
private List<String> base64Array;
|
||||
}
|
|
@ -3,8 +3,8 @@ package cn.iocoder.yudao.module.ai.service.image;
|
|||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageListReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyImagineReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyOperateReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
||||
|
||||
/**
|
||||
|
@ -44,10 +44,11 @@ public interface AiImageService {
|
|||
/**
|
||||
* midjourney 图片生成
|
||||
*
|
||||
* @param loginUserId
|
||||
* @param req
|
||||
* @return
|
||||
*/
|
||||
void midjourney(AiImageMidjourneyReqVO req);
|
||||
Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req);
|
||||
|
||||
/**
|
||||
* midjourney 操作(u1、u2、放大、换一批...)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
package cn.iocoder.yudao.module.ai.service.image;
|
||||
|
||||
import cn.hutool.core.util.IdUtil;
|
||||
import cn.hutool.http.HttpUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.exception.AiException;
|
||||
|
@ -11,6 +11,10 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
|||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.module.ai.AiCommonConstants;
|
||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.client.MidjourneyProxyClient;
|
||||
import cn.iocoder.yudao.module.ai.client.vo.MidjourneyImagineReqVO;
|
||||
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitCodeEnum;
|
||||
import cn.iocoder.yudao.module.ai.client.vo.MidjourneySubmitRespVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
|
||||
|
@ -18,21 +22,22 @@ import cn.iocoder.yudao.module.ai.enums.AiImagePublicStatusEnum;
|
|||
import cn.iocoder.yudao.module.ai.enums.AiImageStatusEnum;
|
||||
import cn.iocoder.yudao.module.infra.api.file.FileApi;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.image.ImageGeneration;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
|
||||
import org.springframework.ai.models.midjourney.webSocket.MidjourneyWebSocketStarter;
|
||||
import org.springframework.ai.openai.OpenAiImageClient;
|
||||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
|
||||
|
@ -59,28 +64,11 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
private FileApi fileApi;
|
||||
@Resource
|
||||
private OpenAiImageClient openAiImageClient;
|
||||
@Resource
|
||||
private MidjourneyWebSocketStarter midjourneyWebSocketStarter;
|
||||
@Resource
|
||||
private MidjourneyInteractionsApi midjourneyInteractionsApi;
|
||||
@Autowired
|
||||
private MidjourneyProxyClient midjourneyProxyClient;
|
||||
|
||||
// TODO @fan:接 mj proxy
|
||||
@PostConstruct
|
||||
public void startMidjourney() {
|
||||
// todo @fan 暂时注释掉
|
||||
// log.info("midjourney web socket starter...");
|
||||
// midjourneyWebSocketStarter.start(new WssNotify() {
|
||||
// @Override
|
||||
// public void notify(int code, String message) {
|
||||
// log.info("code: {}, message: {}", code, message);
|
||||
// if (message.contains("Authentication failed")) {
|
||||
// // TODO 芋艿,这里看怎么处理,token无效的时候会认证失败!
|
||||
// // 认证失败
|
||||
// log.error("midjourney socket 认证失败,检查token是否失效!");
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
}
|
||||
@Value("${ai.midjourney-proxy.notifyUrl:http://127.0.0.1:48080/admin-api/ai/image/midjourney-notify}")
|
||||
private String midjourneyNotifyUrl;
|
||||
|
||||
@Override
|
||||
public PageResult<AiImageDO> getImagePageMy(Long loginUserId, AiImageListReqVO req) {
|
||||
|
@ -143,18 +131,53 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
|
||||
@Override
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public void midjourney(AiImageMidjourneyReqVO req) {
|
||||
// 保存数据库
|
||||
String messageId = String.valueOf(IdUtil.getSnowflakeNextId());
|
||||
// todo
|
||||
// AiImageDO aiImageDO = doSave(req.getPrompt(), null, "midjoureny",
|
||||
// null, null, AiImageStatusEnum.SUBMIT, null,
|
||||
// messageId, null, null);
|
||||
// 提交 midjourney 任务
|
||||
Boolean imagine = midjourneyInteractionsApi.imagine(messageId, req.getPrompt());
|
||||
if (!imagine) {
|
||||
throw ServiceExceptionUtil.exception(ErrorCodeConstants.AI_MIDJOURNEY_IMAGINE_FAIL);
|
||||
public Long midjourneyImagine(Long loginUserId, AiImageMidjourneyImagineReqVO req) {
|
||||
|
||||
// 1、构建 AiImageDO
|
||||
AiImageDO aiImageDO = new AiImageDO();
|
||||
aiImageDO.setId(null);
|
||||
aiImageDO.setUserId(loginUserId);
|
||||
aiImageDO.setPrompt(req.getPrompt());
|
||||
aiImageDO.setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
|
||||
// todo @范 平台需要转换(mj 模型一般分版本)
|
||||
aiImageDO.setModel(null);
|
||||
aiImageDO.setWidth(null);
|
||||
aiImageDO.setHeight(null);
|
||||
aiImageDO.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
|
||||
aiImageDO.setPublicStatus(AiImagePublicStatusEnum.PRIVATE.getStatus());
|
||||
aiImageDO.setPicUrl(null);
|
||||
aiImageDO.setOriginalPicUrl(null);
|
||||
aiImageDO.setDrawRequest(null);
|
||||
aiImageDO.setDrawResponse(null);
|
||||
aiImageDO.setErrorMessage(null);
|
||||
|
||||
// 2、保存 image
|
||||
imageMapper.insert(aiImageDO);
|
||||
|
||||
// 3、调用 MidjourneyProxy 提交任务
|
||||
MidjourneyImagineReqVO imagineReqVO = BeanUtils.toBean(req, MidjourneyImagineReqVO.class);
|
||||
imagineReqVO.setNotifyHook(midjourneyNotifyUrl);
|
||||
imagineReqVO.setState(String.valueOf(aiImageDO.getId()));
|
||||
MidjourneySubmitRespVO submitRespVO = midjourneyProxyClient.imagine(imagineReqVO);
|
||||
|
||||
// 4、保存任务 id (状态码: 1(提交成功), 21(已存在), 22(排队中), other(错误))
|
||||
String updateStatus = null;
|
||||
String errorMessage = null;
|
||||
Map<String, Object> drawResponse = new HashMap<>();
|
||||
|
||||
if (!MidjourneySubmitCodeEnum.SUCCESS_CODES.contains(submitRespVO.getCode())) {
|
||||
updateStatus = AiImageStatusEnum.FAIL.getStatus();
|
||||
errorMessage = submitRespVO.getDescription();
|
||||
} else {
|
||||
drawResponse.put("jobId", submitRespVO.getResult());
|
||||
}
|
||||
imageMapper.updateById(new AiImageDO()
|
||||
.setId(aiImageDO.getId())
|
||||
.setStatus(updateStatus)
|
||||
.setErrorMessage(errorMessage)
|
||||
.setDrawResponse(drawResponse)
|
||||
);
|
||||
return aiImageDO.getId();
|
||||
}
|
||||
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
|
|
Loading…
Reference in New Issue