聊天对话,增加 创建对话、还是继续对话逻辑

This commit is contained in:
cherishsince 2024-04-14 18:40:42 +08:00
parent a2bd9b710e
commit 7794992225
5 changed files with 145 additions and 15 deletions

View File

@ -0,0 +1,34 @@
package cn.iocoder.yudao.module.ai.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 聊天类型
*
* @author fansili
* @time 2024/4/14 17:58
* @since 1.0
*/
@AllArgsConstructor
@Getter
public enum ChatTypeEnum {
ROLE_CHAT("roleChat", "角色模板聊天"),
USER_CHAT("userChat", "用户普通聊天"),
;
private String type;
private String name;
public static ChatTypeEnum valueOfType(String type) {
for (ChatTypeEnum itemEnum : ChatTypeEnum.values()) {
if (itemEnum.getType().equals(type)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + type);
}
}

View File

@ -23,12 +23,12 @@ public class AiChatMessageDO {
/**
* 聊天ID关联到特定的会话或对话
*/
private Long chatId;
private Long chatConversationId;
/**
* 角色ID用于标识发送消息的用户或系统的身份
*/
private String userId;
private Long userId;
/**
* 消息具体内容存储用户的发言或者系统响应的文字信息
@ -38,7 +38,7 @@ public class AiChatMessageDO {
/**
* 消息类型枚举值可能包括'system'(系统消息)'user'(用户消息)'assistant'(助手消息)
*/
private Double messageType;
private String messageType;
/**
* 在生成消息时采用的Top-K采样大小

View File

@ -1,14 +1,28 @@
package cn.iocoder.yudao.module.ai.service.impl;
import cn.hutool.core.exceptions.ExceptionUtil;
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.framework.common.exception.ServerException;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum;
import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.service.ChatService;
import cn.iocoder.yudao.module.ai.vo.ChatReq;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
/**
@ -24,6 +38,10 @@ import reactor.core.publisher.Flux;
public class ChatServiceImpl implements ChatService {
private final AiClient aiClient;
private final AiChatRoleMapper aiChatRoleMapper;
private final AiChatMessageMapper aiChatMessageMapper;
private final AiChatConversationMapper aiChatConversationMapper;
/**
* chat
@ -31,16 +49,84 @@ public class ChatServiceImpl implements ChatService {
* @param req
* @return
*/
@Transactional(rollbackFor = Exception.class)
public String chat(ChatReq req) {
// 获取 client 类型
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt());
req.setTopK(req.getTopK());
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
// 发送 call 调用
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
return call.getResult().getOutput().getContent();
// 获取 对话类型(新建还是继续)
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
AiChatConversationDO aiChatConversationDO;
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
// 创建一个新的对话
aiChatConversationDO = createNewChatConversation(req, loginUserId);
} else {
// 继续对话
if (req.getConversationId() == null) {
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
}
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
}
String content;
try {
// 创建 chat 需要的 Prompt
Prompt prompt = new Prompt(req.getPrompt());
req.setTopK(req.getTopK());
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
// 发送 call 调用
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
content = call.getResult().getOutput().getContent();
} catch (Exception e) {
content = ExceptionUtil.getMessage(e);
}
// 增加 chat message 记录
aiChatMessageMapper.insert(
new AiChatMessageDO()
.setId(null)
.setChatConversationId(aiChatConversationDO.getId())
.setUserId(loginUserId)
.setMessage(req.getPrompt())
.setMessageType(MessageType.USER.getValue())
.setTopK(req.getTopK())
.setTopP(req.getTopP())
.setTemperature(req.getTemperature())
);
// chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
return content;
}
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
// 获取 chat 角色
String chatRoleName = null;
ChatTypeEnum chatTypeEnum = null;
Long chatRoleId = req.getChatRoleId();
if (req.getChatRoleId() != null) {
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId);
if (aiChatRoleDO == null) {
throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT);
}
chatTypeEnum = ChatTypeEnum.ROLE_CHAT;
chatRoleName = aiChatRoleDO.getRoleName();
} else {
chatTypeEnum = ChatTypeEnum.USER_CHAT;
}
//
AiChatConversationDO insertChatConversation = new AiChatConversationDO()
.setId(null)
.setUserId(loginUserId)
.setChatRoleId(req.getChatRoleId())
.setChatRoleName(chatRoleName)
.setChatType(chatTypeEnum.getType())
.setChatCount(1)
.setChatTitle(req.getPrompt().substring(0, 20) + "...");
aiChatConversationMapper.insert(insertChatConversation);
return insertChatConversation;
}
/**

View File

@ -24,19 +24,29 @@ public class ChatReq {
@Schema(description = "填入固定值1 issues, 2 pr")
private String prompt;
@Schema(description = "chat角色模板")
private Long chatRoleId;
@Schema(description = "用于控制随机性和多样性的温度参数")
private Float temperature;
private Double temperature;
@Schema(description = "生成时核采样方法的概率阈值。例如取值为0.8时仅保留累计概率之和大于等于0.8的概率分布中的token\n" +
" * 作为随机采样的候选集。取值范围为0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。\n" +
" * 默认值为0.8。注意取值不要大于等于1\n")
private Float topP;
private Double topP;
@Schema(description = "在生成消息时采用的Top-K采样大小表示模型生成回复时考虑的候选项集合的大小")
private Integer topK;
private Double topK;
@Schema(description = "ai模型(查看 AiClientNameEnum)")
@NotNull(message = "模型不能为空!")
@Size(max = 30, message = "模型字符最大30个字符!")
private String modal;
@Schema(description = "对话类型(new、continue)")
@NotNull(message = "对话类型,不能为空!")
private String conversationType;
@Schema(description = "对话Id")
private Long conversationId;
}

View File

@ -59,7 +59,7 @@ public abstract class AbstractMessage implements Message {
}
protected AbstractMessage(MessageType messageType, String textContent, List<MediaData> mediaData,
Map<String, Object> messageProperties) {
Map<String, Object> messageProperties) {
Assert.notNull(messageType, "Message type must not be null");
Assert.notNull(textContent, "Content must not be null");