【优化】聊天 event stream 改为 flex 返回更加的优雅
This commit is contained in:
parent
708d66e8cf
commit
5579620140
|
@ -1,26 +0,0 @@
|
||||||
package cn.iocoder.yudao.module.ai.controller;
|
|
||||||
|
|
||||||
import org.springframework.http.HttpHeaders;
|
|
||||||
import org.springframework.http.MediaType;
|
|
||||||
import org.springframework.http.server.ServerHttpResponse;
|
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
||||||
|
|
||||||
import java.nio.charset.StandardCharsets;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 解决中文乱码
|
|
||||||
*
|
|
||||||
* @author fansili
|
|
||||||
* @time 2024/4/14 15:13
|
|
||||||
* @since 1.0
|
|
||||||
*/
|
|
||||||
public class Utf8SseEmitter extends SseEmitter {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void extendResponse(ServerHttpResponse outputMessage) {
|
|
||||||
super.extendResponse(outputMessage);
|
|
||||||
|
|
||||||
HttpHeaders headers = outputMessage.getHeaders();
|
|
||||||
headers.setContentType(new MediaType(MediaType.TEXT_EVENT_STREAM, StandardCharsets.UTF_8));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,10 +1,9 @@
|
||||||
package cn.iocoder.yudao.module.ai.controller.admin.chat;
|
package cn.iocoder.yudao.module.ai.controller.admin.chat;
|
||||||
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||||
|
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.Parameter;
|
import io.swagger.v3.oas.annotations.Parameter;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
|
@ -13,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.validation.annotation.Validated;
|
import org.springframework.validation.annotation.Validated;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -39,10 +38,8 @@ public class AiChatMessageController {
|
||||||
// TODO @fan:要不要使用 Flux 来返回;可以使用 Flux<AiChatMessageRespVO>
|
// TODO @fan:要不要使用 Flux 来返回;可以使用 Flux<AiChatMessageRespVO>
|
||||||
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
|
@Operation(summary = "发送消息(流式)", description = "流式返回,响应较快")
|
||||||
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
@PostMapping(value = "/send-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
public SseEmitter sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
public Flux<AiChatMessageRespVO> sendMessageStream(@Validated @RequestBody AiChatMessageSendReqVO sendReqVO) {
|
||||||
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
|
return chatService.chatStream(sendReqVO);
|
||||||
chatService.chatStream(sendReqVO, sseEmitter);
|
|
||||||
return sseEmitter;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Operation(summary = "获得指定会话的消息列表")
|
@Operation(summary = "获得指定会话的消息列表")
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
package cn.iocoder.yudao.module.ai.controller.admin.image;
|
package cn.iocoder.yudao.module.ai.controller.admin.image;
|
||||||
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
|
||||||
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
|
||||||
|
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
|
@ -14,7 +13,6 @@ import org.springframework.web.bind.annotation.PostMapping;
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
||||||
|
|
||||||
// TODO @芋艿:整理接口定义
|
// TODO @芋艿:整理接口定义
|
||||||
/**
|
/**
|
||||||
|
@ -35,10 +33,11 @@ public class AiImageController {
|
||||||
|
|
||||||
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
|
@Operation(summary = "dall2/dall3绘画", description = "openAi dall3是付费的!")
|
||||||
@PostMapping("/dallDrawing")
|
@PostMapping("/dallDrawing")
|
||||||
public SseEmitter dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) {
|
public void dallDrawing(@Validated @RequestBody AiImageDallDrawingReq req) {
|
||||||
Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
|
// Utf8SseEmitter sseEmitter = new Utf8SseEmitter();
|
||||||
aiImageService.dallDrawing(req, sseEmitter);
|
// aiImageService.dallDrawing(req, sseEmitter);
|
||||||
return sseEmitter;
|
// return sseEmitter;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
|
@Operation(summary = "midjourney", description = "midjourney图片绘画流程:1、提交任务 2、获取完成的任务 3、选择对应功能 4、获取最终结果")
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
package cn.iocoder.yudao.module.ai.service;
|
package cn.iocoder.yudao.module.ai.service;
|
||||||
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||||
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -26,11 +26,10 @@ public interface AiChatService {
|
||||||
/**
|
/**
|
||||||
* chat stream
|
* chat stream
|
||||||
*
|
*
|
||||||
* @param req
|
* @param sendReqVO
|
||||||
* @param sseEmitter
|
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter);
|
Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO sendReqVO);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取 - 获取对话 message list
|
* 获取 - 获取对话 message list
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
package cn.iocoder.yudao.module.ai.service;
|
package cn.iocoder.yudao.module.ai.service;
|
||||||
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
|
||||||
|
|
||||||
|
@ -17,9 +16,8 @@ public interface AiImageService {
|
||||||
* ai绘画 - dall2/dall3 绘画
|
* ai绘画 - dall2/dall3 绘画
|
||||||
*
|
*
|
||||||
* @param req
|
* @param req
|
||||||
* @param sseEmitter
|
|
||||||
*/
|
*/
|
||||||
void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter);
|
void dallDrawing(AiImageDallDrawingReq req);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* midjourney 图片生成
|
* midjourney 图片生成
|
||||||
|
|
|
@ -9,7 +9,6 @@ import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
|
||||||
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||||
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory;
|
||||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
|
||||||
|
@ -25,13 +24,12 @@ import cn.iocoder.yudao.module.ai.service.AiChatRoleService;
|
||||||
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
import cn.iocoder.yudao.module.ai.service.AiChatService;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.http.MediaType;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -76,6 +74,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||||
chatModal.getModel(), chatModal.getId(), req.getContent(),
|
chatModal.getModel(), chatModal.getId(), req.getContent(),
|
||||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||||
String content = null;
|
String content = null;
|
||||||
|
int tokens = 0;
|
||||||
try {
|
try {
|
||||||
// 创建 chat 需要的 Prompt
|
// 创建 chat 需要的 Prompt
|
||||||
Prompt prompt = new Prompt(req.getContent());
|
Prompt prompt = new Prompt(req.getContent());
|
||||||
|
@ -87,6 +86,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||||
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
|
ChatClient chatClient = aiChatClientFactory.getChatClient(platformEnum);
|
||||||
ChatResponse call = chatClient.call(prompt);
|
ChatResponse call = chatClient.call(prompt);
|
||||||
content = call.getResult().getOutput().getContent();
|
content = call.getResult().getOutput().getContent();
|
||||||
|
tokens = call.getResults().size();
|
||||||
// 更新 conversation
|
// 更新 conversation
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
content = ExceptionUtil.getMessage(e);
|
content = ExceptionUtil.getMessage(e);
|
||||||
|
@ -94,7 +94,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||||
// 保存 chat message
|
// 保存 chat message
|
||||||
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
||||||
chatModal.getModel(), chatModal.getId(), content,
|
chatModal.getModel(), chatModal.getId(), content,
|
||||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
tokens, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||||
}
|
}
|
||||||
return new AiChatMessageRespVO().setContent(content);
|
return new AiChatMessageRespVO().setContent(content);
|
||||||
}
|
}
|
||||||
|
@ -123,8 +123,7 @@ public class AiChatServiceImpl implements AiChatService {
|
||||||
return insertChatMessageDO;
|
return insertChatMessageDO;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public Flux<AiChatMessageRespVO> chatStream(AiChatMessageSendReqVO req) {
|
||||||
public void chatStream(AiChatMessageSendReqVO req, Utf8SseEmitter sseEmitter) {
|
|
||||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||||
// 查询对话
|
// 查询对话
|
||||||
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
AiChatConversationRespVO conversation = chatConversationService.getConversationOfValidate(req.getConversationId());
|
||||||
|
@ -144,47 +143,43 @@ public class AiChatServiceImpl implements AiChatService {
|
||||||
// req.setTopK(req.getTopK());
|
// req.setTopK(req.getTopK());
|
||||||
// req.setTopP(req.getTopP());
|
// req.setTopP(req.getTopP());
|
||||||
// req.setTemperature(req.getTemperature());
|
// req.setTemperature(req.getTemperature());
|
||||||
// 保存 chat message
|
|
||||||
// 保存 chat message
|
// 保存 chat message
|
||||||
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
insertChatMessage(conversation.getId(), MessageType.USER, loginUserId, conversation.getRoleId(),
|
||||||
chatModal.getModel(), chatModal.getId(), req.getContent(),
|
chatModal.getModel(), chatModal.getId(), req.getContent(),
|
||||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||||
|
|
||||||
// 获取 client 类型
|
// 获取 client 类型
|
||||||
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform());
|
AiPlatformEnum platformEnum = AiPlatformEnum.valueOfPlatform(chatModal.getPlatform());
|
||||||
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
StreamingChatClient streamingChatClient = aiChatClientFactory.getStreamingChatClient(platformEnum);
|
||||||
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
Flux<ChatResponse> streamResponse = streamingChatClient.stream(prompt);
|
||||||
|
// 转换 flex AiChatMessageRespVO
|
||||||
StringBuffer contentBuffer = new StringBuffer();
|
StringBuffer contentBuffer = new StringBuffer();
|
||||||
streamResponse.subscribe(
|
AtomicInteger tokens = new AtomicInteger(0);
|
||||||
new Consumer<ChatResponse>() {
|
return streamResponse.map(res -> {
|
||||||
|
AiChatMessageRespVO aiChatMessageRespVO = new AiChatMessageRespVO();
|
||||||
|
aiChatMessageRespVO.setContent(res.getResult().getOutput().getContent());
|
||||||
|
contentBuffer.append(res.getResult().getOutput().getContent());
|
||||||
|
tokens.incrementAndGet();
|
||||||
|
return aiChatMessageRespVO;
|
||||||
|
}
|
||||||
|
).doOnComplete(new Runnable() {
|
||||||
@Override
|
@Override
|
||||||
public void accept(ChatResponse chatResponse) {
|
public void run() {
|
||||||
String content = chatResponse.getResults().get(0).getOutput().getContent();
|
|
||||||
try {
|
|
||||||
contentBuffer.append(content);
|
|
||||||
sseEmitter.send(new AiChatMessageRespVO().setContent(content), MediaType.APPLICATION_JSON);
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.error("发送异常{}", ExceptionUtil.getMessage(e));
|
|
||||||
// 如果不是因为关闭而抛出异常,则重新连接
|
|
||||||
sseEmitter.completeWithError(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
error -> {
|
|
||||||
//
|
|
||||||
log.error("subscribe错误 {}", ExceptionUtil.getMessage(error));
|
|
||||||
},
|
|
||||||
() -> {
|
|
||||||
log.info("发送完成!");
|
log.info("发送完成!");
|
||||||
sseEmitter.complete();
|
|
||||||
// 保存 chat message
|
// 保存 chat message
|
||||||
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
||||||
chatModal.getModel(), chatModal.getId(), contentBuffer.toString(),
|
chatModal.getModel(), chatModal.getId(), contentBuffer.toString(),
|
||||||
null, conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||||
|
|
||||||
}
|
}
|
||||||
);
|
}).doOnError(new Consumer<Throwable>() {
|
||||||
|
@Override
|
||||||
|
public void accept(Throwable throwable) {
|
||||||
|
log.error("发送错误 {}!", throwable.getMessage());
|
||||||
|
// 保存 chat message
|
||||||
|
insertChatMessage(conversation.getId(), MessageType.SYSTEM, loginUserId, conversation.getRoleId(),
|
||||||
|
chatModal.getModel(), chatModal.getId(), throwable.getMessage(),
|
||||||
|
tokens.get(), conversation.getTemperature(), conversation.getMaxTokens(), conversation.getMaxContexts());
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -5,8 +5,8 @@ import cn.iocoder.yudao.framework.ai.image.ImageGeneration;
|
||||||
import cn.iocoder.yudao.framework.ai.image.ImagePrompt;
|
import cn.iocoder.yudao.framework.ai.image.ImagePrompt;
|
||||||
import cn.iocoder.yudao.framework.ai.image.ImageResponse;
|
import cn.iocoder.yudao.framework.ai.image.ImageResponse;
|
||||||
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
|
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageClient;
|
||||||
import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum;
|
|
||||||
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
|
import cn.iocoder.yudao.framework.ai.imageopenai.OpenAiImageOptions;
|
||||||
|
import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageModelEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum;
|
import cn.iocoder.yudao.framework.ai.imageopenai.enums.OpenAiImageStyleEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
|
import cn.iocoder.yudao.framework.ai.midjourney.api.MidjourneyInteractionsApi;
|
||||||
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
|
import cn.iocoder.yudao.framework.ai.midjourney.webSocket.MidjourneyWebSocketStarter;
|
||||||
|
@ -14,22 +14,18 @@ import cn.iocoder.yudao.framework.ai.midjourney.webSocket.WssNotify;
|
||||||
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
import cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil;
|
||||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||||
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
|
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
|
||||||
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
|
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
|
|
||||||
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDallDrawingReq;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
|
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageMidjourneyReq;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.mysql.AiImageMapper;
|
||||||
|
import cn.iocoder.yudao.module.ai.enums.AiChatDrawingStatusEnum;
|
||||||
|
import cn.iocoder.yudao.module.ai.service.AiImageService;
|
||||||
import jakarta.annotation.PostConstruct;
|
import jakarta.annotation.PostConstruct;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.http.MediaType;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ai 作图
|
* ai 作图
|
||||||
*
|
*
|
||||||
|
@ -64,7 +60,7 @@ public class AiImageServiceImpl implements AiImageService {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void dallDrawing(AiImageDallDrawingReq req, Utf8SseEmitter sseEmitter) {
|
public void dallDrawing(AiImageDallDrawingReq req) {
|
||||||
// 获取 model
|
// 获取 model
|
||||||
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal());
|
OpenAiImageModelEnum openAiImageModelEnum = OpenAiImageModelEnum.valueOfModel(req.getModal());
|
||||||
OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
|
OpenAiImageStyleEnum openAiImageStyleEnum = OpenAiImageStyleEnum.valueOfStyle(req.getStyle());
|
||||||
|
@ -79,7 +75,7 @@ public class AiImageServiceImpl implements AiImageService {
|
||||||
// 发送
|
// 发送
|
||||||
ImageGeneration imageGeneration = imageResponse.getResult();
|
ImageGeneration imageGeneration = imageResponse.getResult();
|
||||||
// 发送信息
|
// 发送信息
|
||||||
sendSseEmitter(sseEmitter, imageGeneration);
|
// sendSseEmitter(sseEmitter, imageGeneration);
|
||||||
// 保存数据库
|
// 保存数据库
|
||||||
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
||||||
imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
|
imageGeneration.getOutput().getUrl(), AiChatDrawingStatusEnum.COMPLETE, null);
|
||||||
|
@ -88,7 +84,7 @@ public class AiImageServiceImpl implements AiImageService {
|
||||||
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
doSave(req.getPrompt(), req.getSize(), req.getModal(),
|
||||||
null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
|
null, AiChatDrawingStatusEnum.FAIL, aiException.getMessage());
|
||||||
// 发送错误信息
|
// 发送错误信息
|
||||||
sendSseEmitter(sseEmitter, aiException.getMessage());
|
// sendSseEmitter(sseEmitter, aiException.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,16 +101,16 @@ public class AiImageServiceImpl implements AiImageService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
|
// private static void sendSseEmitter(Utf8SseEmitter sseEmitter, Object object) {
|
||||||
try {
|
// try {
|
||||||
sseEmitter.send(object, MediaType.APPLICATION_JSON);
|
// sseEmitter.send(object, MediaType.APPLICATION_JSON);
|
||||||
} catch (IOException e) {
|
// } catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
// throw new RuntimeException(e);
|
||||||
} finally {
|
// } finally {
|
||||||
// 发送 complete
|
// // 发送 complete
|
||||||
sseEmitter.complete();
|
// sseEmitter.complete();
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
private AiImageDO doSave(String prompt,
|
private AiImageDO doSave(String prompt,
|
||||||
String size,
|
String size,
|
||||||
|
|
|
@ -2,7 +2,6 @@ server:
|
||||||
port: 48080
|
port: 48080
|
||||||
|
|
||||||
--- #################### 数据库相关配置 ####################
|
--- #################### 数据库相关配置 ####################
|
||||||
|
|
||||||
spring:
|
spring:
|
||||||
# 数据源配置项
|
# 数据源配置项
|
||||||
autoconfigure:
|
autoconfigure:
|
||||||
|
@ -79,7 +78,12 @@ spring:
|
||||||
port: 6379 # 端口
|
port: 6379 # 端口
|
||||||
database: 0 # 数据库索引
|
database: 0 # 数据库索引
|
||||||
# password: dev # 密码,建议生产环境开启
|
# password: dev # 密码,建议生产环境开启
|
||||||
|
server:
|
||||||
|
servlet:
|
||||||
|
encoding:
|
||||||
|
enabled: true
|
||||||
|
charset: UTF-8
|
||||||
|
force: true
|
||||||
--- #################### 定时任务相关配置 ####################
|
--- #################### 定时任务相关配置 ####################
|
||||||
|
|
||||||
# Quartz 配置项,对应 QuartzProperties 配置类
|
# Quartz 配置项,对应 QuartzProperties 配置类
|
||||||
|
|
Loading…
Reference in New Issue