Merge branch 'master-jdk21-ai' of https://gitee.com/cherishsince/ruoyi-vue-pro into master-jdk21-ai
This commit is contained in:
commit
c29430e754
|
@ -628,15 +628,6 @@
|
|||
<artifactId>ureport2-console</artifactId>
|
||||
<version>${ureport2.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- 添加ai模块 -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-bom</artifactId>
|
||||
<version>${spring-ai-bom.version}</version>
|
||||
<type>pom</type>
|
||||
<scope>import</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ public interface ErrorCodeConstants {
|
|||
|
||||
// ========== 模块 ai 错误码区间 [1-022-000-000 ~ 1-023-000-000) ==========
|
||||
|
||||
// TODO @fansili:1)类注释不太对;2)中英文之间,有个空格;例如说 AI 模型
|
||||
ErrorCode AI_MODULE_NOT_SUPPORTED = new ErrorCode(1_022_000_000, "AI模型暂不支持!");
|
||||
|
||||
}
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
package cn.iocoder.yudao.module.ai.enums;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* author: fansili
|
||||
* time: 2024/3/4 12:36
|
||||
*/
|
||||
@Getter
|
||||
public enum AiModelEnum {
|
||||
|
||||
OPEN_AI_GPT_3_5("gpt-3.5-turbo", "GPT3.5"),
|
||||
OPEN_AI_GPT_4("gpt-4-turbo", "GPT4")
|
||||
|
||||
;
|
||||
|
||||
AiModelEnum(String value, String message) {
|
||||
this.value = value;
|
||||
this.message = message;
|
||||
}
|
||||
|
||||
private String value;
|
||||
|
||||
private String message;
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package cn.iocoder.yudao.module.ai.enums;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
// TODO done @fansili:1)类注释要加下;2)author 和 time 用 javadoc,@author 和 @since;3)@AllArgsConstructor 使用这个注解,去掉构造方法;4)value 改成 model 字段,然后注释都写下哈;5)message 改成 name,然后注释都写下哈
|
||||
/**
|
||||
* @author: fansili
|
||||
* @time: 2024/3/4 12:36
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum OpenAiModelEnum {
|
||||
|
||||
/**
|
||||
* open ai 3.5模型
|
||||
*/
|
||||
OPEN_AI_GPT_3_5("gpt-3.5-turbo", "GPT3.5"),
|
||||
/**
|
||||
* open ai 4.0 收费模型
|
||||
*/
|
||||
OPEN_AI_GPT_4("gpt-4-turbo", "GPT4")
|
||||
|
||||
;
|
||||
|
||||
/**
|
||||
* 模型 - 用于参数传递
|
||||
*/
|
||||
private String model;
|
||||
/**
|
||||
* 模型名字 - 用于展示
|
||||
*/
|
||||
private String name;
|
||||
|
||||
}
|
|
@ -4,7 +4,7 @@ import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
|||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.vo.AiChatReqVO;
|
||||
import cn.iocoder.yudao.module.ai.enums.AiModelEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
|
@ -13,7 +13,6 @@ import org.springframework.ai.chat.ChatClient;
|
|||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatClient;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
|
@ -23,15 +22,9 @@ import org.springframework.web.bind.annotation.RequestMapping;
|
|||
import org.springframework.web.bind.annotation.RestController;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.Scanner;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
/**
|
||||
* AI模块
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/3/3 20:28
|
||||
*/
|
||||
// TODO done @fansili:有了 swagger 注释,就不用类注释了
|
||||
@Tag(name = "AI模块")
|
||||
@RestController
|
||||
@RequestMapping("/ai-api")
|
||||
|
@ -47,7 +40,7 @@ public class ChatController {
|
|||
ChatClient chatClient = getChatClient(reqVO.getAiModel());
|
||||
String res;
|
||||
try {
|
||||
res = chatClient.call(reqVO.getInputText());
|
||||
res = chatClient.call(reqVO.getPrompt());
|
||||
} catch (Exception e) {
|
||||
res = e.getMessage();
|
||||
}
|
||||
|
@ -58,33 +51,14 @@ public class ChatController {
|
|||
@Operation(summary = "对话聊天chatStream", description = "简单的ai聊天")
|
||||
public CommonResult chatStream(HttpServletResponse response, @RequestBody @Validated AiChatReqVO reqVO) throws InterruptedException {
|
||||
OpenAiChatClient chatClient = applicationContext.getBean(OpenAiChatClient.class);
|
||||
Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getInputText()));
|
||||
Flux<ChatResponse> chatResponse = chatClient.stream(new Prompt(reqVO.getPrompt()));
|
||||
chatResponse.subscribe(new Consumer<ChatResponse>() {
|
||||
@Override
|
||||
public void accept(ChatResponse chatResponse) {
|
||||
System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
|
||||
}
|
||||
});
|
||||
return CommonResult.success("1");
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
OpenAiChatClient openAiChatClient = new OpenAiChatClient(new OpenAiApi("openkey"));
|
||||
Flux<ChatResponse> responseFlux = openAiChatClient.stream(new Prompt("最好的编程语言!"));
|
||||
long now = System.currentTimeMillis();
|
||||
responseFlux.subscribe(new Consumer<ChatResponse>() {
|
||||
@Override
|
||||
public void accept(ChatResponse chatResponse) {
|
||||
if (chatResponse.getResults().get(0).getOutput() == null) {
|
||||
return;
|
||||
}
|
||||
System.err.println(chatResponse.getResults().get(0).getOutput().getContent());
|
||||
}
|
||||
});
|
||||
|
||||
// 阻止退出
|
||||
Scanner scanner = new Scanner(System.in);
|
||||
scanner.nextLine();
|
||||
return CommonResult.success(null);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -93,8 +67,8 @@ public class ChatController {
|
|||
* @param aiModelEnum
|
||||
* @return
|
||||
*/
|
||||
private ChatClient getChatClient(AiModelEnum aiModelEnum) {
|
||||
if (AiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
|
||||
private ChatClient getChatClient(OpenAiModelEnum aiModelEnum) {
|
||||
if (OpenAiModelEnum.OPEN_AI_GPT_3_5 == aiModelEnum) {
|
||||
return applicationContext.getBean(OpenAiChatClient.class);
|
||||
}
|
||||
// AI模型暂不支持
|
||||
|
|
|
@ -1,26 +1,21 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.vo;
|
||||
|
||||
import cn.iocoder.yudao.module.ai.enums.AiModelEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.OpenAiModelEnum;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* ai 聊天 req
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/3/4 12:33
|
||||
*/
|
||||
@Schema(description = "用户 App - 上传文件 Request VO")
|
||||
// TODO done @fansili 1)swagger 注释不太对;2)有了 swagger 注释,就不用类注释了
|
||||
@Data
|
||||
@Schema(description = "用户 App - 上传文件 Request VO")
|
||||
public class AiChatReqVO {
|
||||
|
||||
@Schema(description = "输入内容", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@NotNull(message = "输入内容不能为空")
|
||||
private String inputText;
|
||||
@Schema(description = "提示词", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@NotNull(message = "提示词不能为空!")
|
||||
private String prompt;
|
||||
|
||||
@Schema(description = "AI模型", requiredMode = Schema.RequiredMode.REQUIRED)
|
||||
@NotNull(message = "AI模型不能为空")
|
||||
private AiModelEnum aiModel;
|
||||
private OpenAiModelEnum aiModel;
|
||||
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# open ai
|
||||
# open ai TODO @fansili??????????????
|
||||
|
||||
# openAI https://openai.com/
|
||||
spring.ai.openai.api-key=${OPEN_AI_KEY}
|
||||
|
|
|
@ -149,6 +149,10 @@
|
|||
<!-- </exclusion>-->
|
||||
<!-- </exclusions>-->
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>cn.iocoder.boot</groupId>
|
||||
<artifactId>yudao-common</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
|
@ -6,6 +6,7 @@ import cn.iocoder.yudao.framework.ai.chat.*;
|
|||
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
|
||||
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
|
||||
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||
import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
|
||||
import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
|
||||
import com.aliyun.broadscope.bailian.sdk.models.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -29,7 +30,7 @@ import java.util.stream.Collectors;
|
|||
* time: 2024/3/13 21:06
|
||||
*/
|
||||
@Slf4j
|
||||
public class QianWenChatClient implements ChatClient, StreamingChatClient {
|
||||
public class QianWenChatClient implements ChatClient, StreamingChatClient {
|
||||
|
||||
private QianWenApi qianWenApi;
|
||||
|
||||
|
@ -44,6 +45,7 @@ public class QianWenChatClient implements ChatClient, StreamingChatClient {
|
|||
this.qianWenOptions = qianWenOptions;
|
||||
}
|
||||
|
||||
// TODO @fansili:看看咋公用出来,允许传入类似异常之类的参数;
|
||||
public final RetryTemplate retryTemplate = RetryTemplate.builder()
|
||||
// 最大重试次数 10
|
||||
.maxAttempts(10)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package cn.iocoder.yudao.framework.ai.chatqianwen;
|
||||
package cn.iocoder.yudao.framework.ai.chatqianwen.api;
|
||||
|
||||
import com.aliyun.broadscope.bailian.sdk.AccessTokenClient;
|
||||
import com.aliyun.broadscope.bailian.sdk.ApplicationClient;
|
||||
|
@ -9,6 +9,7 @@ import org.springframework.http.HttpStatusCode;
|
|||
import org.springframework.http.ResponseEntity;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
// TODO done @fansili:是不是挪到 api 包里?按照 spring ai 的结构;根目录只放 client 和 options
|
||||
/**
|
||||
* 阿里 通义千问
|
||||
*
|
|
@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.imageopenai;
|
|||
import cn.hutool.json.JSONUtil;
|
||||
import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageRequest;
|
||||
import cn.iocoder.yudao.framework.ai.imageopenai.api.OpenAiImageResponse;
|
||||
import cn.iocoder.yudao.framework.ai.util.JacksonUtil;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import io.netty.channel.ChannelOption;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.http.HttpEntity;
|
||||
|
@ -14,7 +14,6 @@ import org.apache.http.impl.client.CloseableHttpClient;
|
|||
import org.apache.http.impl.client.HttpClients;
|
||||
import org.apache.http.util.EntityUtils;
|
||||
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
|
||||
import org.springframework.web.reactive.function.BodyInserters;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.netty.http.client.HttpClient;
|
||||
|
||||
|
@ -55,7 +54,7 @@ public class OpenAiImageApi {
|
|||
httpPost.setURI(URI.create(DEFAULT_BASE_URL.concat("/v1/images/generations")));
|
||||
httpPost.setHeader("Content-Type", "application/json");
|
||||
httpPost.setHeader("Authorization", "Bearer " + apiKey);
|
||||
httpPost.setEntity(new StringEntity(JacksonUtil.toJson(request), "UTF-8"));
|
||||
httpPost.setEntity(new StringEntity(JsonUtils.toJsonString(request), "UTF-8"));
|
||||
|
||||
CloseableHttpResponse response= null;
|
||||
try {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
|
@ -7,7 +8,7 @@ import java.util.List;
|
|||
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class MjMessage {
|
||||
public class MidjourneyMessage {
|
||||
|
||||
/**
|
||||
* id是一个重要的字段,在同时生成多个的时候,可以区分生成信息
|
||||
|
@ -41,7 +42,7 @@ public class MjMessage {
|
|||
* 1、等待
|
||||
* 2、进行中
|
||||
* 3、完成
|
||||
* {@link cn.iocoder.yudao.framework.ai.midjourney.constants.MjGennerateStatusEnum}
|
||||
* {@link MidjourneyGennerateStatusEnum}
|
||||
*/
|
||||
private String generateStatus;
|
||||
|
|
@ -0,0 +1,83 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.api;
|
||||
|
||||
import cn.hutool.core.util.IdUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants;
|
||||
import com.google.common.collect.Maps;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
// TODO @fansili:按照 spring ai 的封装习惯,这个类是不是 MidjourneyApi
|
||||
|
||||
/**
|
||||
* 图片生成
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/3 17:36
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class MidjourneyInteractions {
|
||||
|
||||
// TODO done @fansili:静态变量,放在最前面哈;
|
||||
/**
|
||||
* header - referer 头信息
|
||||
*/
|
||||
private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
|
||||
/**
|
||||
* mj配置文件
|
||||
*/
|
||||
protected final MidjourneyConfig midjourneyConfig;
|
||||
|
||||
protected MidjourneyInteractions(MidjourneyConfig midjourneyConfig) {
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取headers - application json
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
protected HttpHeaders getHeadersOfAppJson() {
|
||||
// 设置header值
|
||||
HttpHeaders httpHeaders = new HttpHeaders();
|
||||
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
|
||||
httpHeaders.set("Authorization", midjourneyConfig.getToken());
|
||||
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
|
||||
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
|
||||
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
|
||||
return httpHeaders;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取headers - http form data
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
protected HttpHeaders getHeadersOfFormData() {
|
||||
// 设置header值
|
||||
HttpHeaders httpHeaders = new HttpHeaders();
|
||||
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
|
||||
httpHeaders.set("Authorization", midjourneyConfig.getToken());
|
||||
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
|
||||
httpHeaders.set("Cookie", MidjourneyConstants.HTTP_COOKIE);
|
||||
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
|
||||
return httpHeaders;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 - 默认参数
|
||||
* @return
|
||||
*/
|
||||
protected HashMap<String, String> getDefaultParams() {
|
||||
HashMap<String, String> requestParams = Maps.newHashMap();
|
||||
// TODO @fansili:感觉参数的组装,可以搞成一个公用的方法;就是 config + 入参的感觉;
|
||||
requestParams.put("guild_id", midjourneyConfig.getGuildId());
|
||||
requestParams.put("channel_id", midjourneyConfig.getChannelId());
|
||||
requestParams.put("session_id", midjourneyConfig.getSessionId());
|
||||
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId())); // TODO @fansili:建议用 uuid 之类的;nextId 跨进程未必合适哈;
|
||||
return requestParams;
|
||||
}
|
||||
}
|
|
@ -1,20 +1,16 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.interactions;
|
||||
package cn.iocoder.yudao.framework.ai.midjourney.api;
|
||||
|
||||
import cn.hutool.core.util.IdUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjConstants;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.util.MjUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.vo.UploadAttachmentsRes;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.req.AttachmentsReq;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.req.DescribeReq;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.res.UploadAttachmentsRes;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.core.io.FileSystemResource;
|
||||
import org.springframework.http.*;
|
||||
import org.springframework.util.LinkedMultiValueMap;
|
||||
|
@ -24,6 +20,7 @@ import org.springframework.web.client.RestTemplate;
|
|||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
|
||||
// TODO @fansili:按照 spring ai 的封装习惯,这个类是不是 MidjourneyApi
|
||||
/**
|
||||
* 图片生成
|
||||
*
|
||||
|
@ -31,17 +28,13 @@ import java.util.HashMap;
|
|||
* time: 2024/4/3 17:36
|
||||
*/
|
||||
@Slf4j
|
||||
public class MjInteractions {
|
||||
|
||||
public class MidjourneyInteractionsApi extends MidjourneyInteractions {
|
||||
|
||||
private final String url;
|
||||
private final MidjourneyConfig midjourneyConfig;
|
||||
private final RestTemplate restTemplate = new RestTemplate();
|
||||
private static final String HEADER_REFERER = "https://discord.com/channels/%s/%s";
|
||||
private final RestTemplate restTemplate = new RestTemplate(); // TODO @fansili:优先级低:后续搞到统一的管理
|
||||
|
||||
|
||||
public MjInteractions(MidjourneyConfig midjourneyConfig) {
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
public MidjourneyInteractionsApi(MidjourneyConfig midjourneyConfig) {
|
||||
super(midjourneyConfig);
|
||||
this.url = midjourneyConfig.getServerUrl().concat(midjourneyConfig.getApiInteractions());
|
||||
}
|
||||
|
||||
|
@ -49,45 +42,38 @@ public class MjInteractions {
|
|||
// 获取请求模板
|
||||
String requestTemplate = midjourneyConfig.getRequestTemplates().get("imagine");
|
||||
// 设置参数
|
||||
HashMap<String, String> requestParams = Maps.newHashMap();
|
||||
requestParams.put("guild_id", midjourneyConfig.getGuildId());
|
||||
requestParams.put("channel_id", midjourneyConfig.getChannelId());
|
||||
requestParams.put("session_id", midjourneyConfig.getSessionId());
|
||||
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
|
||||
HashMap<String, String> requestParams = getDefaultParams();
|
||||
requestParams.put("prompt", prompt);
|
||||
// 解析 template 参数占位符
|
||||
String requestBody = MjUtil.parseTemplate(requestTemplate, requestParams);
|
||||
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
|
||||
// 获取 header
|
||||
HttpHeaders httpHeaders = getHttpHeaders();
|
||||
HttpHeaders httpHeaders = getHeadersOfAppJson();
|
||||
// 发送请求
|
||||
HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
|
||||
String res = restTemplate.postForObject(url, requestEntity, String.class);
|
||||
// 这个 res 只要不返回值,就是成功!
|
||||
boolean isSuccess = StrUtil.isBlank(res);
|
||||
if (isSuccess) {
|
||||
// TODO @fansili:可以直接 if (StrUtil.isBlank(res))
|
||||
if (StrUtil.isBlank(res)) {
|
||||
return true;
|
||||
} else {
|
||||
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
|
||||
return false;
|
||||
}
|
||||
log.error("请求失败! 请求参数:{} 返回结果! {}", requestBody, res);
|
||||
return isSuccess;
|
||||
}
|
||||
// TODO done @fansili:方法和方法之间,空一行哈;
|
||||
|
||||
|
||||
|
||||
public Boolean reRoll(ReRoll reRoll) {
|
||||
public Boolean reRoll(ReRollReq reRoll) {
|
||||
// 获取请求模板
|
||||
String requestTemplate = midjourneyConfig.getRequestTemplates().get("reroll");
|
||||
// 设置参数
|
||||
HashMap<String, String> requestParams = Maps.newHashMap();
|
||||
requestParams.put("guild_id", midjourneyConfig.getGuildId());
|
||||
requestParams.put("channel_id", midjourneyConfig.getChannelId());
|
||||
requestParams.put("session_id", midjourneyConfig.getSessionId());
|
||||
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
|
||||
HashMap<String, String> requestParams = getDefaultParams();
|
||||
requestParams.put("custom_id", reRoll.getCustomId());
|
||||
requestParams.put("message_id", reRoll.getMessageId());
|
||||
// 获取 header
|
||||
HttpHeaders httpHeaders = getHttpHeaders();
|
||||
HttpHeaders httpHeaders = getHeadersOfAppJson();
|
||||
// 设置参数
|
||||
String requestBody = MjUtil.parseTemplate(requestTemplate, requestParams);
|
||||
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
|
||||
// 发送请求
|
||||
HttpEntity<String> requestEntity = new HttpEntity<>(requestBody, httpHeaders);
|
||||
String res = restTemplate.postForObject(url, requestEntity, String.class);
|
||||
|
@ -100,12 +86,13 @@ public class MjInteractions {
|
|||
return isSuccess;
|
||||
}
|
||||
|
||||
|
||||
public UploadAttachmentsRes uploadAttachments(Attachments attachments) {
|
||||
// TODO @fansili:搞成私有方法,可能会好点;
|
||||
public UploadAttachmentsRes uploadAttachments(AttachmentsReq attachments) {
|
||||
// file
|
||||
JSONObject fileObj = new JSONObject();
|
||||
fileObj.put("id", "0");
|
||||
fileObj.put("filename", attachments.getFileSystemResource().getFilename());
|
||||
// TODO @fansili:这块用 lombok 哪个异常处理,简化下代码;
|
||||
try {
|
||||
fileObj.put("file_size", attachments.getFileSystemResource().contentLength());
|
||||
} catch (IOException e) {
|
||||
|
@ -115,12 +102,7 @@ public class MjInteractions {
|
|||
MultiValueMap<String, Object> multipartRequest = new LinkedMultiValueMap<>();
|
||||
multipartRequest.put("files", Lists.newArrayList(fileObj));
|
||||
// 设置header值
|
||||
HttpHeaders httpHeaders = new HttpHeaders();
|
||||
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
|
||||
httpHeaders.set("Authorization", midjourneyConfig.getToken());
|
||||
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
|
||||
httpHeaders.set("Cookie", MjConstants.HTTP_COOKIE);
|
||||
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
|
||||
HttpHeaders httpHeaders = getHeadersOfAppJson();
|
||||
// 创建HttpEntity对象,包含表单数据和头部信息
|
||||
HttpEntity<MultiValueMap<String, Object>> multiValueMapHttpEntity = new HttpEntity<>(multipartRequest, httpHeaders);
|
||||
// 发送POST请求并接收响应
|
||||
|
@ -128,38 +110,26 @@ public class MjInteractions {
|
|||
String response = restTemplate.postForObject(midjourneyConfig.getServerUrl().concat(uri), multiValueMapHttpEntity, String.class);
|
||||
UploadAttachmentsRes uploadAttachmentsRes = JSON.parseObject(response, UploadAttachmentsRes.class);
|
||||
|
||||
|
||||
//
|
||||
// 上传文件
|
||||
String uploadUrl = uploadAttachmentsRes.getAttachments().getFirst().getUploadUrl();
|
||||
String uploadAttachmentsUrl = midjourneyConfig.getApiAttachmentsUpload().concat(uploadUrl);
|
||||
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
|
||||
HttpEntity<FileSystemResource> fileSystemResourceHttpEntity = new HttpEntity<>(attachments.getFileSystemResource(), httpHeaders);
|
||||
ResponseEntity<String> exchange = restTemplate.exchange(uploadUrl, HttpMethod.PUT, fileSystemResourceHttpEntity, String.class);
|
||||
String uploadRes = exchange.getBody();
|
||||
|
||||
return uploadAttachmentsRes;
|
||||
}
|
||||
|
||||
public Boolean describe(Describe describe) {
|
||||
public Boolean describe(DescribeReq describe) {
|
||||
// 获取请求模板
|
||||
String requestTemplate = midjourneyConfig.getRequestTemplates().get("describe");
|
||||
// 设置参数
|
||||
HashMap<String, String> requestParams = Maps.newHashMap();
|
||||
requestParams.put("guild_id", midjourneyConfig.getGuildId());
|
||||
requestParams.put("channel_id", midjourneyConfig.getChannelId());
|
||||
requestParams.put("session_id", midjourneyConfig.getSessionId());
|
||||
requestParams.put("nonce", String.valueOf(IdUtil.getSnowflakeNextId()));
|
||||
HashMap<String, String> requestParams = getDefaultParams();
|
||||
requestParams.put("file_name", describe.getFileName());
|
||||
requestParams.put("final_file_name", describe.getFinalFileName());
|
||||
// 设置 header
|
||||
HttpHeaders httpHeaders = new HttpHeaders();
|
||||
httpHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); // 设置内容类型为JSON
|
||||
httpHeaders.set("Authorization", midjourneyConfig.getToken());
|
||||
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
|
||||
httpHeaders.set("Cookie", MjConstants.HTTP_COOKIE);
|
||||
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
|
||||
String requestBody = MjUtil.parseTemplate(requestTemplate, requestParams);
|
||||
HttpHeaders httpHeaders = getHeadersOfFormData();
|
||||
String requestBody = MidjourneyUtil.parseTemplate(requestTemplate, requestParams);
|
||||
// 创建表单数据
|
||||
MultiValueMap<String, String> formData = new LinkedMultiValueMap<>();
|
||||
formData.add("payload_json", requestBody);
|
||||
|
@ -175,14 +145,4 @@ public class MjInteractions {
|
|||
return isSuccess;
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private HttpHeaders getHttpHeaders() {
|
||||
HttpHeaders httpHeaders = new HttpHeaders();
|
||||
httpHeaders.setContentType(MediaType.APPLICATION_JSON); // 设置内容类型为JSON
|
||||
httpHeaders.set("Authorization", midjourneyConfig.getToken());
|
||||
httpHeaders.set("User-Agent", midjourneyConfig.getUserAage());
|
||||
httpHeaders.set("Cookie", MjConstants.HTTP_COOKIE);
|
||||
httpHeaders.set("Referer", String.format(HEADER_REFERER, midjourneyConfig.getGuildId(), midjourneyConfig.getChannelId()));
|
||||
return httpHeaders;
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.vo;
|
||||
package cn.iocoder.yudao.framework.ai.midjourney.api.req;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
@ -12,7 +12,7 @@ import org.springframework.core.io.FileSystemResource;
|
|||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class Attachments {
|
||||
public class AttachmentsReq {
|
||||
|
||||
/**
|
||||
* 创建文件系统资源对象
|
|
@ -1,10 +1,8 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.vo;
|
||||
package cn.iocoder.yudao.framework.ai.midjourney.api.req;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
||||
import java.io.File;
|
||||
|
||||
/**
|
||||
* describe
|
||||
*
|
||||
|
@ -13,7 +11,7 @@ import java.io.File;
|
|||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class Describe {
|
||||
public class DescribeReq {
|
||||
|
||||
/**
|
||||
* 文件名字
|
|
@ -1,4 +1,4 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.vo;
|
||||
package cn.iocoder.yudao.framework.ai.midjourney.api.req;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
||||
|
@ -9,7 +9,7 @@ import lombok.experimental.Accessors;
|
|||
*/
|
||||
@Data
|
||||
@Accessors(chain = true)
|
||||
public class ReRoll {
|
||||
public class ReRollReq {
|
||||
|
||||
/**
|
||||
* socket 消息里面收到的 messageId
|
|
@ -1,4 +1,4 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.vo;
|
||||
package cn.iocoder.yudao.framework.ai.midjourney.api.res;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.experimental.Accessors;
|
|
@ -1,6 +1,6 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.constants;
|
||||
|
||||
public final class MjConstants {
|
||||
public final class MidjourneyConstants {
|
||||
|
||||
/**
|
||||
* 消息 - 编号
|
|
@ -0,0 +1,31 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.constants;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
// TODO done @fansili:1)Mj 缩写,还是搞成全称。。虽然长一点,但是感觉会相对清晰一些哈;2)lombok 相关的注解,可以用用哈;3)value 改 status;
|
||||
/**
|
||||
* mj 生成状态
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/6 21:07
|
||||
*/
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public enum MidjourneyGennerateStatusEnum {
|
||||
|
||||
WAITING("waiting", "等待..."),
|
||||
IN_PROGRESS("in_progress", "进行中"),
|
||||
COMPLETED("completed", "完成"),
|
||||
|
||||
;
|
||||
|
||||
/**
|
||||
* 状态
|
||||
*/
|
||||
private String status;
|
||||
/**
|
||||
* 状态信息
|
||||
*/
|
||||
private String message;
|
||||
}
|
|
@ -6,7 +6,7 @@ import lombok.Getter;
|
|||
* MJ 命令
|
||||
*/
|
||||
@Getter
|
||||
public enum MjInteractionsEnum {
|
||||
public enum MidjourneyInteractionsEnum {
|
||||
|
||||
IMAGINE("imagine", "生成图片"),
|
||||
DESCRIBE("describe", "生成描述"),
|
||||
|
@ -17,7 +17,7 @@ public enum MjInteractionsEnum {
|
|||
|
||||
;
|
||||
|
||||
MjInteractionsEnum(String value, String message) {
|
||||
MidjourneyInteractionsEnum(String value, String message) {
|
||||
this.value =value;
|
||||
this.message =message;
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.constants;
|
||||
|
||||
|
||||
public enum MjMessageTypeEnum {
|
||||
public enum MidjourneyMessageTypeEnum {
|
||||
/**
|
||||
* 创建.
|
||||
*/
|
||||
|
@ -15,7 +15,7 @@ public enum MjMessageTypeEnum {
|
|||
*/
|
||||
DELETE;
|
||||
|
||||
public static MjMessageTypeEnum of(String type) {
|
||||
public static MidjourneyMessageTypeEnum of(String type) {
|
||||
return switch (type) {
|
||||
case "MESSAGE_CREATE" -> CREATE;
|
||||
case "MESSAGE_UPDATE" -> UPDATE;
|
|
@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.midjourney.constants;
|
|||
import lombok.experimental.UtilityClass;
|
||||
|
||||
@UtilityClass
|
||||
public final class MjNotifyCode {
|
||||
public final class MidjourneyNotifyCode {
|
||||
/**
|
||||
* 成功.
|
||||
*/
|
|
@ -1,29 +0,0 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.constants;
|
||||
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* mj 生成状态
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/4/6 21:07
|
||||
*/
|
||||
@Getter
|
||||
public enum MjGennerateStatusEnum {
|
||||
|
||||
|
||||
WAITING("waiting", "等待..."),
|
||||
IN_PROGRESS("in_progress", "进行中"),
|
||||
COMPLETED("completed", "完成"),
|
||||
|
||||
;
|
||||
|
||||
MjGennerateStatusEnum(String value, String message) {
|
||||
this.value = value;
|
||||
this.message = message;
|
||||
}
|
||||
|
||||
private String value;
|
||||
|
||||
private String message;
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.util;
|
||||
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MjMessage;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
|
@ -13,7 +13,7 @@ import java.util.regex.Pattern;
|
|||
* author: fansili
|
||||
* time: 2024/4/6 19:00
|
||||
*/
|
||||
public class MjUtil {
|
||||
public class MidjourneyUtil {
|
||||
/**
|
||||
* content正则匹配prompt和进度.
|
||||
*/
|
||||
|
@ -26,12 +26,12 @@ public class MjUtil {
|
|||
* @param content
|
||||
* @return
|
||||
*/
|
||||
public static MjMessage.Content parseContent(String content) {
|
||||
public static MidjourneyMessage.Content parseContent(String content) {
|
||||
// 有三种格式。
|
||||
// 南极应该是什么样子?
|
||||
// "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)",
|
||||
// "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (fast, stealth)"
|
||||
MjMessage.Content mjContent = new MjMessage.Content();
|
||||
MidjourneyMessage.Content mjContent = new MidjourneyMessage.Content();
|
||||
if (CharSequenceUtil.isBlank(content)) {
|
||||
return null;
|
||||
}
|
|
@ -4,9 +4,9 @@ package cn.iocoder.yudao.framework.ai.midjourney.webSocket;
|
|||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.thread.ThreadUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjNotifyCode;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MjWebSocketHandler;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyNotifyCode;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.handler.MidjourneyWebSocketHandler;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener;
|
||||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.tomcat.websocket.Constants;
|
||||
|
@ -20,10 +20,10 @@ import org.springframework.web.socket.client.standard.StandardWebSocketClient;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.concurrent.TimeoutException;
|
||||
|
||||
// TODO @fansili:mj 这块 websocket 有点小复杂,虽然代码量 400 多行;感觉可以考虑,有没第三方 sdk,通过它透明接入 mj
|
||||
@Slf4j
|
||||
public class MjWebSocketStarter implements WebSocketStarter {
|
||||
public class MidjourneyWebSocketStarter implements WebSocketStarter {
|
||||
/**
|
||||
* 链接重试次数
|
||||
*/
|
||||
|
@ -35,7 +35,7 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||
/**
|
||||
* mj 监听(所有message 都会 callback到这里)
|
||||
*/
|
||||
private final MjMessageListener userMessageListener;
|
||||
private final MidjourneyMessageListener userMessageListener;
|
||||
/**
|
||||
* wss 服务器
|
||||
*/
|
||||
|
@ -57,10 +57,10 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||
*/
|
||||
private WebSocketSession webSocketSession = null;
|
||||
|
||||
public MjWebSocketStarter(String wssServer,
|
||||
String resumeWss,
|
||||
MidjourneyConfig midjourneyConfig,
|
||||
MjMessageListener userMessageListener) {
|
||||
public MidjourneyWebSocketStarter(String wssServer,
|
||||
String resumeWss,
|
||||
MidjourneyConfig midjourneyConfig,
|
||||
MidjourneyMessageListener userMessageListener) {
|
||||
this.wssServer = wssServer;
|
||||
this.resumeWss = resumeWss;
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
|
@ -82,7 +82,7 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||
headers.add("Sec-Websocket-Extensions", "permessage-deflate; client_max_window_bits");
|
||||
headers.add("User-Agent", this.midjourneyConfig.getUserAage());
|
||||
// 创建 mjHeader
|
||||
MjWebSocketHandler mjWebSocketHandler = new MjWebSocketHandler(
|
||||
MidjourneyWebSocketHandler mjWebSocketHandler = new MidjourneyWebSocketHandler(
|
||||
this.midjourneyConfig, this.userMessageListener, this::onSocketSuccess, this::onSocketFailure);
|
||||
//
|
||||
String gatewayUrl;
|
||||
|
@ -104,12 +104,12 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||
socketSessionFuture.addCallback(new ListenableFutureCallback<>() {
|
||||
@Override
|
||||
public void onFailure(@NotNull Throwable e) {
|
||||
onSocketFailure(MjWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage());
|
||||
onSocketFailure(MidjourneyWebSocketHandler.CLOSE_CODE_EXCEPTION, e.getMessage());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSuccess(WebSocketSession session) {
|
||||
MjWebSocketStarter.this.webSocketSession = session;
|
||||
MidjourneyWebSocketStarter.this.webSocketSession = session;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -117,7 +117,7 @@ public class MjWebSocketStarter implements WebSocketStarter {
|
|||
private void onSocketSuccess(String sessionId, Object sequence, String resumeGatewayUrl) {
|
||||
this.resumeData = new ResumeData(sessionId, sequence, resumeGatewayUrl);
|
||||
this.running = true;
|
||||
notifyWssLock(MjNotifyCode.SUCCESS, "");
|
||||
notifyWssLock(MidjourneyNotifyCode.SUCCESS, "");
|
||||
}
|
||||
|
||||
private void onSocketFailure(int code, String reason) {
|
|
@ -8,7 +8,7 @@ import cn.hutool.http.useragent.UserAgentUtil;
|
|||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.FailureCallback;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.SuccessCallback;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.dv8tion.jda.api.utils.data.DataArray;
|
||||
|
@ -29,7 +29,7 @@ import java.util.concurrent.ScheduledExecutorService;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
public class MjWebSocketHandler implements WebSocketHandler {
|
||||
public class MidjourneyWebSocketHandler implements WebSocketHandler {
|
||||
/**
|
||||
* close 错误码:重连
|
||||
*/
|
||||
|
@ -49,7 +49,7 @@ public class MjWebSocketHandler implements WebSocketHandler {
|
|||
/**
|
||||
* mj 消息监听
|
||||
*/
|
||||
private final MjMessageListener userMessageListener;
|
||||
private final MidjourneyMessageListener userMessageListener;
|
||||
/**
|
||||
* 成功回调
|
||||
*/
|
||||
|
@ -85,10 +85,10 @@ public class MjWebSocketHandler implements WebSocketHandler {
|
|||
*/
|
||||
private final Decompressor decompressor = new ZlibDecompressor(2048);
|
||||
|
||||
public MjWebSocketHandler(MidjourneyConfig account,
|
||||
MjMessageListener userMessageListener,
|
||||
SuccessCallback successCallback,
|
||||
FailureCallback failureCallback) {
|
||||
public MidjourneyWebSocketHandler(MidjourneyConfig account,
|
||||
MidjourneyMessageListener userMessageListener,
|
||||
SuccessCallback successCallback,
|
||||
FailureCallback failureCallback) {
|
||||
this.midjourneyConfig = account;
|
||||
this.userMessageListener = userMessageListener;
|
||||
this.successCallback = successCallback;
|
|
@ -0,0 +1,83 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener;
|
||||
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyMessage;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyConstants;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyGennerateStatusEnum;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MidjourneyMessageTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.dv8tion.jda.api.utils.data.DataObject;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class MidjourneyMessageListener {
|
||||
|
||||
private MidjourneyConfig midjourneyConfig;
|
||||
|
||||
public MidjourneyMessageListener(MidjourneyConfig midjourneyConfig) {
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
}
|
||||
|
||||
public void onMessage(DataObject raw) {
|
||||
MidjourneyMessageTypeEnum messageType = MidjourneyMessageTypeEnum.of(raw.getString("t"));
|
||||
if (messageType == null || MidjourneyMessageTypeEnum.DELETE == messageType) {
|
||||
return;
|
||||
}
|
||||
DataObject data = raw.getObject("d");
|
||||
if (ignoreAndLogMessage(data, messageType)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 转换几个重要的信息
|
||||
MidjourneyMessage mjMessage = new MidjourneyMessage();
|
||||
mjMessage.setId(data.getString(MidjourneyConstants.MSG_ID));
|
||||
mjMessage.setType(data.getInt(MidjourneyConstants.MSG_TYPE));
|
||||
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
|
||||
mjMessage.setContent(MidjourneyUtil.parseContent(data.getString(MidjourneyConstants.MSG_CONTENT)));
|
||||
// 转换 components
|
||||
if (!data.getArray(MidjourneyConstants.MSG_COMPONENTS).isEmpty()) {
|
||||
String componentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_COMPONENTS).toJson(), "UTF-8");
|
||||
List<MidjourneyMessage.ComponentType> components = JSON.parseArray(componentsJson, MidjourneyMessage.ComponentType.class);
|
||||
mjMessage.setComponents(components);
|
||||
}
|
||||
// 转换附件
|
||||
if (!data.getArray(MidjourneyConstants.MSG_ATTACHMENTS).isEmpty()) {
|
||||
String attachmentsJson = StrUtil.str(data.getArray(MidjourneyConstants.MSG_ATTACHMENTS).toJson(), "UTF-8");
|
||||
List<MidjourneyMessage.Attachment> attachments = JSON.parseArray(attachmentsJson, MidjourneyMessage.Attachment.class);
|
||||
mjMessage.setAttachments(attachments);
|
||||
}
|
||||
// 转换状态
|
||||
convertGenerateStatus(mjMessage);
|
||||
//
|
||||
log.info("message 信息 {}", JSONUtil.toJsonPrettyStr(mjMessage));
|
||||
System.err.println(JSONUtil.toJsonPrettyStr(mjMessage));
|
||||
}
|
||||
|
||||
private void convertGenerateStatus(MidjourneyMessage mjMessage) {
|
||||
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
|
||||
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.WAITING.getStatus());
|
||||
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
|
||||
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.IN_PROGRESS.getStatus());
|
||||
} else if (mjMessage.getType() == 0 && !CollUtil.isEmpty(mjMessage.getComponents())) {
|
||||
mjMessage.setGenerateStatus(MidjourneyGennerateStatusEnum.COMPLETED.getStatus());
|
||||
}
|
||||
}
|
||||
|
||||
private boolean ignoreAndLogMessage(DataObject data, MidjourneyMessageTypeEnum messageType) {
|
||||
String channelId = data.getString(MidjourneyConstants.MSG_CHANNEL_ID);
|
||||
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
|
||||
return true;
|
||||
}
|
||||
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
|
||||
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
|
||||
return false;
|
||||
}
|
||||
}
|
|
@ -1,83 +0,0 @@
|
|||
package cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener;
|
||||
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.text.CharSequenceUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MjMessage;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjConstants;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjGennerateStatusEnum;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.constants.MjMessageTypeEnum;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.util.MjUtil;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.dv8tion.jda.api.utils.data.DataObject;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class MjMessageListener {
|
||||
|
||||
private MidjourneyConfig midjourneyConfig;
|
||||
|
||||
public MjMessageListener(MidjourneyConfig midjourneyConfig) {
|
||||
this.midjourneyConfig = midjourneyConfig;
|
||||
}
|
||||
|
||||
public void onMessage(DataObject raw) {
|
||||
MjMessageTypeEnum messageType = MjMessageTypeEnum.of(raw.getString("t"));
|
||||
if (messageType == null || MjMessageTypeEnum.DELETE == messageType) {
|
||||
return;
|
||||
}
|
||||
DataObject data = raw.getObject("d");
|
||||
if (ignoreAndLogMessage(data, messageType)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 转换几个重要的信息
|
||||
MjMessage mjMessage = new MjMessage();
|
||||
mjMessage.setId(data.getString(MjConstants.MSG_ID));
|
||||
mjMessage.setType(data.getInt(MjConstants.MSG_TYPE));
|
||||
mjMessage.setRawData(StrUtil.str(raw.toJson(), "UTF-8"));
|
||||
mjMessage.setContent(MjUtil.parseContent(data.getString(MjConstants.MSG_CONTENT)));
|
||||
// 转换 components
|
||||
if (!data.getArray(MjConstants.MSG_COMPONENTS).isEmpty()) {
|
||||
String componentsJson = StrUtil.str(data.getArray(MjConstants.MSG_COMPONENTS).toJson(), "UTF-8");
|
||||
List<MjMessage.ComponentType> components = JSON.parseArray(componentsJson, MjMessage.ComponentType.class);
|
||||
mjMessage.setComponents(components);
|
||||
}
|
||||
// 转换附件
|
||||
if (!data.getArray(MjConstants.MSG_ATTACHMENTS).isEmpty()) {
|
||||
String attachmentsJson = StrUtil.str(data.getArray(MjConstants.MSG_ATTACHMENTS).toJson(), "UTF-8");
|
||||
List<MjMessage.Attachment> attachments = JSON.parseArray(attachmentsJson, MjMessage.Attachment.class);
|
||||
mjMessage.setAttachments(attachments);
|
||||
}
|
||||
// 转换状态
|
||||
convertGenerateStatus(mjMessage);
|
||||
//
|
||||
log.info("message 信息 {}", JSONUtil.toJsonPrettyStr(mjMessage));
|
||||
System.err.println(JSONUtil.toJsonPrettyStr(mjMessage));
|
||||
}
|
||||
|
||||
private void convertGenerateStatus(MjMessage mjMessage) {
|
||||
if (mjMessage.getType() == 20 && mjMessage.getContent().getStatus().contains("Waiting")) {
|
||||
mjMessage.setGenerateStatus(MjGennerateStatusEnum.WAITING.getValue());
|
||||
} else if (mjMessage.getType() == 20 && !StrUtil.isBlank(mjMessage.getContent().getProgress())) {
|
||||
mjMessage.setGenerateStatus(MjGennerateStatusEnum.IN_PROGRESS.getValue());
|
||||
} else if (mjMessage.getType() == 0 && !CollUtil.isEmpty(mjMessage.getComponents())) {
|
||||
mjMessage.setGenerateStatus(MjGennerateStatusEnum.COMPLETED.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
private boolean ignoreAndLogMessage(DataObject data, MjMessageTypeEnum messageType) {
|
||||
String channelId = data.getString(MjConstants.MSG_CHANNEL_ID);
|
||||
if (!CharSequenceUtil.equals(channelId, midjourneyConfig.getChannelId())) {
|
||||
return true;
|
||||
}
|
||||
String authorName = data.optObject("author").map(a -> a.getString("username")).orElse("System");
|
||||
log.debug("{} - {} - {}: {}", midjourneyConfig.getChannelId(), messageType.name(), authorName, data.opt("content").orElse(""));
|
||||
return false;
|
||||
}
|
||||
}
|
|
@ -1,5 +1,15 @@
|
|||
/**
|
||||
* author: fansili
|
||||
* time: 2024/3/12 20:29
|
||||
*
|
||||
* TODO @fansili:包的想法,需要重点看看
|
||||
*
|
||||
* 1. org.springframework.ai:包括 chat、image、model、parser、util 部分
|
||||
*
|
||||
* 2. yudao.framework.models
|
||||
* \qianwen 通义千问
|
||||
* \yiyan 文心一言
|
||||
* \xinghuo 星火
|
||||
* \midjourney
|
||||
*/
|
||||
package cn.iocoder.yudao.framework.ai;
|
|
@ -1,79 +0,0 @@
|
|||
package cn.iocoder.yudao.framework.ai.util;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||
import com.fasterxml.jackson.databind.module.SimpleModule;
|
||||
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Jackson工具类
|
||||
*
|
||||
* author: fansili
|
||||
* time: 2024/3/17 10:13
|
||||
*/
|
||||
public class JacksonUtil {
|
||||
|
||||
private static final ObjectMapper objectMapper = new ObjectMapper();
|
||||
|
||||
/**
|
||||
* 初始化 ObjectMapper 以美化输出(即格式化JSON内容)
|
||||
*/
|
||||
static {
|
||||
// 美化输出(缩进)
|
||||
objectMapper.enable(SerializationFeature.INDENT_OUTPUT);
|
||||
// 忽略值为 null 的属性
|
||||
objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
|
||||
// 配置一个模块来将 Long 类型转换为 String 类型
|
||||
SimpleModule module = new SimpleModule();
|
||||
module.addSerializer(Long.class, ToStringSerializer.instance);
|
||||
objectMapper.registerModule(module);
|
||||
}
|
||||
|
||||
/**
|
||||
* 将对象转换为 JSON 字符串
|
||||
*
|
||||
* @param obj 需要序列化的Java对象
|
||||
* @return 序列化后的 JSON 字符串
|
||||
* @throws JsonProcessingException 当 JSON 序列化过程中出现错误时抛出异常
|
||||
*/
|
||||
public static String toJson(Object obj) {
|
||||
try {
|
||||
return objectMapper.writeValueAsString(obj);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 JSON 字符串反序列化为指定类型的对象
|
||||
*
|
||||
* @param json JSON 字符串
|
||||
* @param clazz 目标类型 Class 对象
|
||||
* @param <T> 泛型类型参数
|
||||
* @return 反序列化后的 Java 对象
|
||||
* @throws IOException 当 JSON 解析过程中出现错误时抛出异常
|
||||
*/
|
||||
public static <T> T fromJson(String json, Class<T> clazz) {
|
||||
try {
|
||||
return objectMapper.readValue(json, clazz);
|
||||
} catch (JsonProcessingException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将对象转换为格式化的 JSON 字符串(已启用 INDENT_OUTPUT 功能,所以所有方法都会返回格式化后的 JSON)
|
||||
*
|
||||
* @param obj 需要序列化的Java对象
|
||||
* @return 格式化后的 JSON 字符串
|
||||
* @throws JsonProcessingException 当 JSON 序列化过程中出现错误时抛出异常
|
||||
*/
|
||||
public static String toFormattedJson(Object obj) {
|
||||
// 已在类初始化时设置了 SerializationFeature.INDENT_OUTPUT,此处无需额外操作
|
||||
return toJson(obj);
|
||||
}
|
||||
}
|
|
@ -1,10 +1,9 @@
|
|||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenApi;
|
||||
import cn.iocoder.yudao.framework.ai.chatqianwen.api.QianWenApi;
|
||||
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenChatClient;
|
||||
import cn.iocoder.yudao.framework.ai.chatqianwen.QianWenOptions;
|
||||
import com.aliyun.broadscope.bailian.sdk.models.CompletionsRequest;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
package cn.iocoder.yudao.framework.ai.mj;
|
||||
package cn.iocoder.yudao.framework.ai.midjourney;
|
||||
|
||||
import cn.hutool.core.io.FileUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.interactions.MjInteractions;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.vo.Attachments;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.vo.Describe;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.vo.ReRoll;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.vo.UploadAttachmentsRes;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.req.AttachmentsReq;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.req.DescribeReq;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.req.ReRollReq;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.api.res.UploadAttachmentsRes;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
@ -23,7 +22,7 @@ import java.util.Map;
|
|||
* author: fansili
|
||||
* time: 2024/4/4 18:59
|
||||
*/
|
||||
public class MjInteractionsTests {
|
||||
public class MidjourneyInteractionsTests {
|
||||
|
||||
private MidjourneyConfig midjourneyConfig;
|
||||
@Before
|
||||
|
@ -39,24 +38,24 @@ public class MjInteractionsTests {
|
|||
|
||||
@Test
|
||||
public void mjImageTest() {
|
||||
MjInteractions mjImagineInteractions = new MjInteractions(midjourneyConfig);
|
||||
MidjourneyInteractionsApi mjImagineInteractions = new MidjourneyInteractionsApi(midjourneyConfig);
|
||||
mjImagineInteractions.imagine("童话里应该是什么样子?");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void reRollTest() {
|
||||
MjInteractions mjImagineInteractions = new MjInteractions(midjourneyConfig);
|
||||
mjImagineInteractions.reRoll(new ReRoll()
|
||||
MidjourneyInteractionsApi mjImagineInteractions = new MidjourneyInteractionsApi(midjourneyConfig);
|
||||
mjImagineInteractions.reRoll(new ReRollReq()
|
||||
.setMessageId("1226165117448753243")
|
||||
.setCustomId("MJ::JOB::upsample::3::2aeefbef-43e2-4057-bcf1-43b5f39ab6f7"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void uploadAttachmentsTest() {
|
||||
MjInteractions mjImagineInteractions = new MjInteractions(midjourneyConfig);
|
||||
MidjourneyInteractionsApi mjImagineInteractions = new MidjourneyInteractionsApi(midjourneyConfig);
|
||||
UploadAttachmentsRes res = mjImagineInteractions.uploadAttachments(
|
||||
new Attachments().setFileSystemResource(
|
||||
new AttachmentsReq().setFileSystemResource(
|
||||
new FileSystemResource(new File("/Users/fansili/Downloads/DSC01402.JPG")))
|
||||
);
|
||||
System.err.println(JSON.toJSONString(res));
|
||||
|
@ -64,8 +63,8 @@ public class MjInteractionsTests {
|
|||
|
||||
@Test
|
||||
public void describeTest() {
|
||||
MjInteractions mjImagineInteractions = new MjInteractions(midjourneyConfig);
|
||||
mjImagineInteractions.describe(new Describe()
|
||||
MidjourneyInteractionsApi mjImagineInteractions = new MidjourneyInteractionsApi(midjourneyConfig);
|
||||
mjImagineInteractions.describe(new DescribeReq()
|
||||
.setFileName("DSC01402.JPG")
|
||||
.setFinalFileName("16826931-2873-45ec-8cfb-0ad81f1a075f/DSC01402.JPG")
|
||||
);
|
|
@ -1,6 +1,6 @@
|
|||
package cn.iocoder.yudao.framework.ai.mj;
|
||||
package cn.iocoder.yudao.framework.ai.midjourney;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.util.MjUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.util.MidjourneyUtil;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
|
@ -9,14 +9,14 @@ import org.junit.Test;
|
|||
* author: fansili
|
||||
* time: 2024/4/6 21:57
|
||||
*/
|
||||
public class MjUtilTests {
|
||||
public class MidjourneyUtilTests {
|
||||
|
||||
@Test
|
||||
public void parseContentTest() {
|
||||
String content1 = "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (32%) (fast, stealth)";
|
||||
String content2 = "**南极应该是什么样子? --v 6.0 --style raw** - <@972721304891453450> (fast, stealth)";
|
||||
|
||||
System.err.println(MjUtil.parseContent(content1));
|
||||
System.err.println(MjUtil.parseContent(content2));
|
||||
System.err.println(MidjourneyUtil.parseContent(content1));
|
||||
System.err.println(MidjourneyUtil.parseContent(content2));
|
||||
}
|
||||
}
|
|
@ -1,9 +1,8 @@
|
|||
package cn.iocoder.yudao.framework.ai.mj;
|
||||
package cn.iocoder.yudao.framework.ai.midjourney;
|
||||
|
||||
import cn.hutool.core.io.FileUtil;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.MidjourneyConfig;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MjMessageListener;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MjWebSocketStarter;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.listener.MidjourneyMessageListener;
|
||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -17,7 +16,7 @@ import java.util.Scanner;
|
|||
* author: fansili
|
||||
* time: 2024/4/3 16:40
|
||||
*/
|
||||
public class MjWebSocketTests {
|
||||
public class MidjourneyWebSocketTests {
|
||||
|
||||
private MidjourneyConfig midjourneyConfig;
|
||||
|
||||
|
@ -35,8 +34,8 @@ public class MjWebSocketTests {
|
|||
@Test
|
||||
public void startSocketTest() {
|
||||
String wssUrl = "wss://gateway.discord.gg";
|
||||
var messageListener = new MjMessageListener(midjourneyConfig);
|
||||
var webSocketStarter = new MjWebSocketStarter(wssUrl, null, midjourneyConfig, messageListener);
|
||||
var messageListener = new MidjourneyMessageListener(midjourneyConfig);
|
||||
var webSocketStarter = new MidjourneyWebSocketStarter(wssUrl, null, midjourneyConfig, messageListener);
|
||||
|
||||
try {
|
||||
webSocketStarter.start();
|
Loading…
Reference in New Issue