【代码优化】AI:思维导入、写作的生成
This commit is contained in:
parent
ecb50c6511
commit
68ed8cd6f8
|
@ -111,7 +111,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
||||||
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
||||||
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
||||||
|
|
||||||
// 3.2 创建 chat 需要的 Prompt
|
// 3.2 构建 Prompt,并进行调用
|
||||||
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
|
||||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||||
|
|
||||||
|
|
|
@ -32,13 +32,12 @@ import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
|
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
|
||||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI 写作 Service 实现类
|
* AI 思维导图 Service 实现类
|
||||||
*
|
*
|
||||||
* @author xiaoxin
|
* @author xiaoxin
|
||||||
*/
|
*/
|
||||||
|
@ -58,30 +57,28 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
|
public Flux<CommonResult<String>> generateMindMap(AiMindMapGenerateReqVO generateReqVO, Long userId) {
|
||||||
// 1 获取脑图模型 尝试获取思维导图助手角色,如果没有则使用默认模型
|
// 1. 获取脑图模型。尝试获取思维导图助手角色,如果没有则使用默认模型
|
||||||
AiChatRoleDO mindMapRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
|
AiChatRoleDO role = CollUtil.getFirst(
|
||||||
|
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
|
||||||
// 1.1 获取脑图执行模型
|
// 1.1 获取脑图执行模型
|
||||||
AiChatModelDO model = getModel(mindMapRole);
|
AiChatModelDO model = getModel(role);
|
||||||
// 1.2 获取角色设定消息
|
// 1.2 获取角色设定消息
|
||||||
String systemMessage = Objects.nonNull(mindMapRole) && StrUtil.isNotBlank(mindMapRole.getSystemMessage())
|
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
|
||||||
? mindMapRole.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
|
? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
|
||||||
// 1.3 校验平台
|
// 1.3 校验平台
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
||||||
|
|
||||||
// 2 插入思维导图信息
|
// 2. 插入思维导图信息
|
||||||
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
|
AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
|
||||||
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||||
mindMapMapper.insert(mindMapDO);
|
mindMapMapper.insert(mindMapDO);
|
||||||
|
|
||||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
// 3.1 构建 Prompt,并进行调用
|
||||||
// 3.1 角色设定
|
Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
|
||||||
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
|
||||||
// 3.3 构建提示词
|
|
||||||
Prompt prompt = new Prompt(chatMessages, chatOptions);
|
|
||||||
|
|
||||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||||
// 3.4 流式返回
|
|
||||||
|
// 3.2 流式返回
|
||||||
StringBuffer contentBuffer = new StringBuffer();
|
StringBuffer contentBuffer = new StringBuffer();
|
||||||
return streamResponse.map(chunk -> {
|
return streamResponse.map(chunk -> {
|
||||||
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
||||||
|
@ -102,24 +99,32 @@ public class AiMindMapServiceImpl implements AiMindMapService {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
|
||||||
|
// 1. 构建 message 列表
|
||||||
|
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
||||||
|
// 2. 构建 options 对象
|
||||||
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
|
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||||
|
return new Prompt(chatMessages, options);
|
||||||
|
}
|
||||||
|
|
||||||
private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
|
private static List<Message> buildMessages(AiMindMapGenerateReqVO generateReqVO, String systemMessage) {
|
||||||
List<Message> chatMessages = new ArrayList<>();
|
List<Message> chatMessages = new ArrayList<>();
|
||||||
|
// 1. 角色设定
|
||||||
if (StrUtil.isNotBlank(systemMessage)) {
|
if (StrUtil.isNotBlank(systemMessage)) {
|
||||||
// 1.1 角色设定
|
|
||||||
chatMessages.add(new SystemMessage(systemMessage));
|
chatMessages.add(new SystemMessage(systemMessage));
|
||||||
}
|
}
|
||||||
// 1.2 用户输入
|
// 2. 用户输入
|
||||||
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
|
chatMessages.add(new UserMessage(generateReqVO.getPrompt()));
|
||||||
return chatMessages;
|
return chatMessages;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO 芋艿:这里脑图、写作都用到了,是不是可以抽哪里去
|
private AiChatModelDO getModel(AiChatRoleDO role) {
|
||||||
private AiChatModelDO getModel(AiChatRoleDO chatRoleDO) {
|
|
||||||
AiChatModelDO model = null;
|
AiChatModelDO model = null;
|
||||||
if (Objects.nonNull(chatRoleDO) && Objects.nonNull(chatRoleDO.getModelId())) {
|
if (role != null && role.getModelId() != null) {
|
||||||
model = chatModalService.getChatModel(chatRoleDO.getModelId());
|
model = chatModalService.getChatModel(role.getModelId());
|
||||||
}
|
}
|
||||||
if (Objects.isNull(model)) {
|
if (model != null) {
|
||||||
model = chatModalService.getRequiredDefaultChatModel();
|
model = chatModalService.getRequiredDefaultChatModel();
|
||||||
}
|
}
|
||||||
Assert.notNull(model, "[AI] 获取不到模型");
|
Assert.notNull(model, "[AI] 获取不到模型");
|
||||||
|
|
|
@ -68,8 +68,9 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
||||||
// 1 获取写作模型 尝试获取写作助手角色,没有则使用默认模型
|
// 1 获取写作模型。尝试获取写作助手角色,没有则使用默认模型
|
||||||
AiChatRoleDO writeRole = CollUtil.getFirst(chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
|
AiChatRoleDO writeRole = CollUtil.getFirst(
|
||||||
|
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_WRITE_ROLE.getName()));
|
||||||
// 1.1 获取写作执行模型
|
// 1.1 获取写作执行模型
|
||||||
AiChatModelDO model = getModel(writeRole);
|
AiChatModelDO model = getModel(writeRole);
|
||||||
// 1.2 获取角色设定消息
|
// 1.2 获取角色设定消息
|
||||||
|
@ -84,16 +85,11 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
|
write -> write.setUserId(userId).setPlatform(platform.getPlatform()).setModel(model.getModel()));
|
||||||
writeMapper.insert(writeDO);
|
writeMapper.insert(writeDO);
|
||||||
|
|
||||||
// 3. 调用大模型,写作生成
|
// 3.1 构建 Prompt,并进行调用
|
||||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
Prompt prompt = buildPrompt(generateReqVO, model, systemMessage);
|
||||||
// 3.1 构建消息列表
|
|
||||||
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
|
||||||
// 3.2 构建提示词
|
|
||||||
Prompt prompt = new Prompt(chatMessages, chatOptions);
|
|
||||||
// 3.3 流式调用
|
|
||||||
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
||||||
|
|
||||||
// 4. 流式返回
|
// 3.2 流式返回
|
||||||
StringBuffer contentBuffer = new StringBuffer();
|
StringBuffer contentBuffer = new StringBuffer();
|
||||||
return streamResponse.map(chunk -> {
|
return streamResponse.map(chunk -> {
|
||||||
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
||||||
|
@ -125,6 +121,15 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Prompt buildPrompt(AiWriteGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
|
||||||
|
// 1. 构建 message 列表
|
||||||
|
List<Message> chatMessages = buildMessages(generateReqVO, systemMessage);
|
||||||
|
// 2. 构建 options 对象
|
||||||
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
|
ChatOptions options = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||||
|
return new Prompt(chatMessages, options);
|
||||||
|
}
|
||||||
|
|
||||||
private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
|
private List<Message> buildMessages(AiWriteGenerateReqVO generateReqVO, String systemMessage) {
|
||||||
List<Message> chatMessages = new ArrayList<>();
|
List<Message> chatMessages = new ArrayList<>();
|
||||||
if (StrUtil.isNotBlank(systemMessage)) {
|
if (StrUtil.isNotBlank(systemMessage)) {
|
||||||
|
@ -132,11 +137,11 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
chatMessages.add(new SystemMessage(systemMessage));
|
chatMessages.add(new SystemMessage(systemMessage));
|
||||||
}
|
}
|
||||||
// 1.2 用户输入
|
// 1.2 用户输入
|
||||||
chatMessages.add(new UserMessage(buildWritingPrompt(generateReqVO)));
|
chatMessages.add(new UserMessage(buildUserMessage(generateReqVO)));
|
||||||
return chatMessages;
|
return chatMessages;
|
||||||
}
|
}
|
||||||
|
|
||||||
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
private String buildUserMessage(AiWriteGenerateReqVO generateReqVO) {
|
||||||
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
|
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
|
||||||
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
|
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
|
||||||
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
|
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
|
||||||
|
|
Loading…
Reference in New Issue