【新增】AI:发送消息时,增加上下文

This commit is contained in:
YunaiV 2024-05-21 08:30:51 +08:00
parent 3a10fedddb
commit 802dee2fc3
3 changed files with 72 additions and 19 deletions

View File

@ -19,4 +19,7 @@ public class AiChatMessageSendReqVO {
@NotEmpty(message = "聊天内容不能为空")
private String content;
@Schema(description = "是否携带上下文", example = "true")
private Boolean useContext;
}

View File

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import com.baomidou.mybatisplus.annotation.TableId;
import org.springframework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
@ -27,14 +28,23 @@ public class AiChatMessageDO extends BaseDO {
/**
* 编号作为每条聊天记录的唯一标识符
*/
@TableId
private Long id;
/**
* 会话编号
*
* 关联 {@link AiChatConversationDO#getId()}
* 关联 {@link AiChatConversationDO#getId()} 字段
*/
private Long conversationId;
/**
* 回复消息编号
*
* 关联 {@link #id} 字段
*
* 大模型回复的消息编号用于问答的关联
*/
private Long replyId;
/**
* 消息类型
@ -75,6 +85,9 @@ public class AiChatMessageDO extends BaseDO {
*/
private String content;
// TODO 芋艿是否作为上下文语料use_context待定
/**
* 是否携带上下文
*/
private Boolean useContext;
}

View File

@ -1,5 +1,9 @@
package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.BooleanUtil;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
@ -109,15 +113,14 @@ public class AiChatServiceImpl implements AiChatService {
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
// 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), model,
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent());
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
// 3.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "");
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 创建 chat 需要的 Prompt
// TODO 消息上下文
Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
@ -139,32 +142,66 @@ public class AiChatServiceImpl implements AiChatService {
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, AiChatMessageSendReqVO sendReqVO) {
// TODO 芋艿1保留 n 个上下文2每一轮 token 数量
// if (conversation.getMaxContexts() != null && messages.size() > conversation.getMaxContexts()) {
//
// }
// 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>();
// 1.1 system context 角色设定
chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
// 1.2 history message 历史消息
messages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
// 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象 TODO 芋艿临时注释掉等文心一言兼容了
// TODO 每一轮 token 数量
// ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build();
// return new Prompt(chatMessages, null);
return new Prompt(chatMessages);
}
private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model,
Long userId, Long roleId,
MessageType messageType, String content) {
AiChatMessageDO message = new AiChatMessageDO()
.setConversationId(conversationId).setModel(model.getModel()).setModelId(model.getId())
.setUserId(userId).setRoleId(roleId)
.setType(messageType.getValue()).setContent(content);
/**
* 从历史消息中获得倒序的 n 组消息作为消息上下文
*
* n 指的是 user + assistant 形成一组
*
* @param messages 消息列表
* @param conversation 会话
* @param sendReqVO 发送请求
* @return 消息上下文
*/
private List<AiChatMessageDO> filterContextMessages(List<AiChatMessageDO> messages, AiChatConversationDO conversation, AiChatMessageSendReqVO sendReqVO) {
if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
return Collections.emptyList();
}
List<AiChatMessageDO> contextMessages = new ArrayList<>(conversation.getMaxContexts() * 2);
for (int i = messages.size() - 1; i >= 0; i--) {
AiChatMessageDO assistantMessage = CollUtil.get(messages, i);
if (assistantMessage == null || assistantMessage.getReplyId() == null) {
continue;
}
AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
|| StrUtil.isEmpty(assistantMessage.getContent())) {
continue;
}
// 由于后续要 reverse 反转所以先添加 assistantMessage
contextMessages.add(assistantMessage);
contextMessages.add(userMessage);
// 超过最大上下文结束
if (contextMessages.size() >= conversation.getMaxContexts() * 2) {
break;
}
}
Collections.reverse(contextMessages);
return contextMessages;
}
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
AiChatModelDO model, Long userId, Long roleId,
MessageType messageType, String content, Boolean useContext) {
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
.setType(messageType.getValue()).setContent(content).setUseContext(useContext);
message.setCreateTime(LocalDateTime.now());
chatMessageMapper.insert(message);
return message;