【新增】AI:发送消息时,增加上下文
This commit is contained in:
parent
3a10fedddb
commit
802dee2fc3
|
@ -19,4 +19,7 @@ public class AiChatMessageSendReqVO {
|
|||
@NotEmpty(message = "聊天内容不能为空")
|
||||
private String content;
|
||||
|
||||
@Schema(description = "是否携带上下文", example = "true")
|
||||
private Boolean useContext;
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import org.springframework.ai.chat.messages.MessageType;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||
|
@ -27,14 +28,23 @@ public class AiChatMessageDO extends BaseDO {
|
|||
/**
|
||||
* 编号,作为每条聊天记录的唯一标识符
|
||||
*/
|
||||
@TableId
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* 会话编号
|
||||
*
|
||||
* 关联 {@link AiChatConversationDO#getId()}
|
||||
* 关联 {@link AiChatConversationDO#getId()} 字段
|
||||
*/
|
||||
private Long conversationId;
|
||||
/**
|
||||
* 回复消息编号
|
||||
*
|
||||
* 关联 {@link #id} 字段
|
||||
*
|
||||
* 大模型回复的消息编号,用于“问答”的关联
|
||||
*/
|
||||
private Long replyId;
|
||||
|
||||
/**
|
||||
* 消息类型
|
||||
|
@ -75,6 +85,9 @@ public class AiChatMessageDO extends BaseDO {
|
|||
*/
|
||||
private String content;
|
||||
|
||||
// TODO 芋艿:是否作为上下文语料?use_context,待定
|
||||
/**
|
||||
* 是否携带上下文
|
||||
*/
|
||||
private Boolean useContext;
|
||||
|
||||
}
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
package cn.iocoder.yudao.module.ai.service.impl;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.collection.ListUtil;
|
||||
import cn.hutool.core.util.ArrayUtil;
|
||||
import cn.hutool.core.util.BooleanUtil;
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
|
@ -109,15 +113,14 @@ public class AiChatServiceImpl implements AiChatService {
|
|||
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
|
||||
|
||||
// 2. 插入 user 发送消息
|
||||
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), model,
|
||||
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent());
|
||||
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
|
||||
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
|
||||
|
||||
// 3.1 插入 assistant 接收消息
|
||||
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), model,
|
||||
userId, conversation.getRoleId(), MessageType.ASSISTANT, "");
|
||||
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
||||
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
||||
|
||||
// 3.2 创建 chat 需要的 Prompt
|
||||
// TODO 消息上下文
|
||||
Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO);
|
||||
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
||||
|
||||
|
@ -139,32 +142,66 @@ public class AiChatServiceImpl implements AiChatService {
|
|||
}
|
||||
|
||||
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, AiChatMessageSendReqVO sendReqVO) {
|
||||
// TODO 芋艿:1)保留 n 个上下文;2)每一轮 token 数量
|
||||
// if (conversation.getMaxContexts() != null && messages.size() > conversation.getMaxContexts()) {
|
||||
//
|
||||
// }
|
||||
// 1. 构建 Prompt Message 列表
|
||||
List<Message> chatMessages = new ArrayList<>();
|
||||
// 1.1 system context 角色设定
|
||||
chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
|
||||
// 1.2 history message 历史消息
|
||||
messages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
|
||||
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
||||
contextMessages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
|
||||
// 1.3 user message 新发送消息
|
||||
chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
||||
|
||||
// 2. 构建 ChatOptions 对象 TODO 芋艿:临时注释掉;等文心一言兼容了;
|
||||
// TODO 每一轮 token 数量
|
||||
// ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build();
|
||||
// return new Prompt(chatMessages, null);
|
||||
return new Prompt(chatMessages);
|
||||
}
|
||||
|
||||
private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model,
|
||||
Long userId, Long roleId,
|
||||
MessageType messageType, String content) {
|
||||
AiChatMessageDO message = new AiChatMessageDO()
|
||||
.setConversationId(conversationId).setModel(model.getModel()).setModelId(model.getId())
|
||||
.setUserId(userId).setRoleId(roleId)
|
||||
.setType(messageType.getValue()).setContent(content);
|
||||
/**
|
||||
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
||||
*
|
||||
* n 组:指的是 user + assistant 形成一组
|
||||
*
|
||||
* @param messages 消息列表
|
||||
* @param conversation 会话
|
||||
* @param sendReqVO 发送请求
|
||||
* @return 消息上下文
|
||||
*/
|
||||
private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, AiChatConversationDO conversation, AiChatMessageSendReqVO sendReqVO) {
|
||||
if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<AiChatMessageDO> contextMessages = new ArrayList<>(conversation.getMaxContexts() * 2);
|
||||
for (int i = messages.size() - 1; i >= 0; i--) {
|
||||
AiChatMessageDO assistantMessage = CollUtil.get(messages, i);
|
||||
if (assistantMessage == null || assistantMessage.getReplyId() == null) {
|
||||
continue;
|
||||
}
|
||||
AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
|
||||
if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
|
||||
|| StrUtil.isEmpty(assistantMessage.getContent())) {
|
||||
continue;
|
||||
}
|
||||
// 由于后续要 reverse 反转,所以先添加 assistantMessage
|
||||
contextMessages.add(assistantMessage);
|
||||
contextMessages.add(userMessage);
|
||||
// 超过最大上下文,结束
|
||||
if (contextMessages.size() >= conversation.getMaxContexts() * 2) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Collections.reverse(contextMessages);
|
||||
return contextMessages;
|
||||
}
|
||||
|
||||
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
|
||||
AiChatModelDO model, Long userId, Long roleId,
|
||||
MessageType messageType, String content, Boolean useContext) {
|
||||
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
|
||||
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
|
||||
.setType(messageType.getValue()).setContent(content).setUseContext(useContext);
|
||||
message.setCreateTime(LocalDateTime.now());
|
||||
chatMessageMapper.insert(message);
|
||||
return message;
|
||||
|
|
Loading…
Reference in New Issue