聊天对话,增加 创建对话、还是继续对话逻辑
This commit is contained in:
parent
a2bd9b710e
commit
7794992225
|
@ -0,0 +1,34 @@
|
|||
package cn.iocoder.yudao.module.ai.enums;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
/**
|
||||
* 聊天类型
|
||||
*
|
||||
* @author fansili
|
||||
* @time 2024/4/14 17:58
|
||||
* @since 1.0
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@Getter
|
||||
public enum ChatTypeEnum {
|
||||
|
||||
ROLE_CHAT("roleChat", "角色模板聊天"),
|
||||
USER_CHAT("userChat", "用户普通聊天"),
|
||||
|
||||
;
|
||||
|
||||
private String type;
|
||||
|
||||
private String name;
|
||||
|
||||
public static ChatTypeEnum valueOfType(String type) {
|
||||
for (ChatTypeEnum itemEnum : ChatTypeEnum.values()) {
|
||||
if (itemEnum.getType().equals(type)) {
|
||||
return itemEnum;
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException("Invalid MessageType value: " + type);
|
||||
}
|
||||
}
|
|
@ -23,12 +23,12 @@ public class AiChatMessageDO {
|
|||
/**
|
||||
* 聊天ID,关联到特定的会话或对话
|
||||
*/
|
||||
private Long chatId;
|
||||
private Long chatConversationId;
|
||||
|
||||
/**
|
||||
* 角色ID,用于标识发送消息的用户或系统的身份
|
||||
*/
|
||||
private String userId;
|
||||
private Long userId;
|
||||
|
||||
/**
|
||||
* 消息具体内容,存储用户的发言或者系统响应的文字信息
|
||||
|
@ -38,7 +38,7 @@ public class AiChatMessageDO {
|
|||
/**
|
||||
* 消息类型,枚举值可能包括'system'(系统消息)、'user'(用户消息)和'assistant'(助手消息)
|
||||
*/
|
||||
private Double messageType;
|
||||
private String messageType;
|
||||
|
||||
/**
|
||||
* 在生成消息时采用的Top-K采样大小,
|
||||
|
|
|
@ -1,14 +1,28 @@
|
|||
package cn.iocoder.yudao.module.ai.service.impl;
|
||||
|
||||
import cn.hutool.core.exceptions.ExceptionUtil;
|
||||
import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
|
||||
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
|
||||
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
|
||||
import cn.iocoder.yudao.framework.ai.config.AiClient;
|
||||
import cn.iocoder.yudao.framework.common.exception.ServerException;
|
||||
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
|
||||
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatConversationDO;
|
||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatMessageDO;
|
||||
import cn.iocoder.yudao.module.ai.dataobject.AiChatRoleDO;
|
||||
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
|
||||
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
|
||||
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
|
||||
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
|
||||
import cn.iocoder.yudao.module.ai.service.ChatService;
|
||||
import cn.iocoder.yudao.module.ai.vo.ChatReq;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
|
@ -24,6 +38,10 @@ import reactor.core.publisher.Flux;
|
|||
public class ChatServiceImpl implements ChatService {
|
||||
|
||||
private final AiClient aiClient;
|
||||
private final AiChatRoleMapper aiChatRoleMapper;
|
||||
private final AiChatMessageMapper aiChatMessageMapper;
|
||||
private final AiChatConversationMapper aiChatConversationMapper;
|
||||
|
||||
|
||||
/**
|
||||
* chat
|
||||
|
@ -31,16 +49,84 @@ public class ChatServiceImpl implements ChatService {
|
|||
* @param req
|
||||
* @return
|
||||
*/
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public String chat(ChatReq req) {
|
||||
// 获取 client 类型
|
||||
AiClientNameEnum clientNameEnum = AiClientNameEnum.valueOfName(req.getModal());
|
||||
// 创建 chat 需要的 Prompt
|
||||
Prompt prompt = new Prompt(req.getPrompt());
|
||||
req.setTopK(req.getTopK());
|
||||
req.setTopP(req.getTopP());
|
||||
req.setTemperature(req.getTemperature());
|
||||
// 发送 call 调用
|
||||
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
|
||||
return call.getResult().getOutput().getContent();
|
||||
// 获取 对话类型(新建还是继续)
|
||||
ChatConversationTypeEnum chatConversationTypeEnum = ChatConversationTypeEnum.valueOfType(req.getConversationType());
|
||||
|
||||
AiChatConversationDO aiChatConversationDO;
|
||||
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
|
||||
if (ChatConversationTypeEnum.NEW == chatConversationTypeEnum) {
|
||||
// 创建一个新的对话
|
||||
aiChatConversationDO = createNewChatConversation(req, loginUserId);
|
||||
} else {
|
||||
// 继续对话
|
||||
if (req.getConversationId() == null) {
|
||||
throw new ServerException(ErrorCodeConstants.AI_CHAT_CONTINUE_CONVERSATION_ID_NOT_NULL);
|
||||
}
|
||||
aiChatConversationDO = aiChatConversationMapper.selectById(req.getConversationId());
|
||||
}
|
||||
|
||||
String content;
|
||||
try {
|
||||
// 创建 chat 需要的 Prompt
|
||||
Prompt prompt = new Prompt(req.getPrompt());
|
||||
req.setTopK(req.getTopK());
|
||||
req.setTopP(req.getTopP());
|
||||
req.setTemperature(req.getTemperature());
|
||||
// 发送 call 调用
|
||||
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
|
||||
content = call.getResult().getOutput().getContent();
|
||||
} catch (Exception e) {
|
||||
content = ExceptionUtil.getMessage(e);
|
||||
}
|
||||
|
||||
// 增加 chat message 记录
|
||||
aiChatMessageMapper.insert(
|
||||
new AiChatMessageDO()
|
||||
.setId(null)
|
||||
.setChatConversationId(aiChatConversationDO.getId())
|
||||
.setUserId(loginUserId)
|
||||
.setMessage(req.getPrompt())
|
||||
.setMessageType(MessageType.USER.getValue())
|
||||
.setTopK(req.getTopK())
|
||||
.setTopP(req.getTopP())
|
||||
.setTemperature(req.getTemperature())
|
||||
);
|
||||
|
||||
// chat count 先+1
|
||||
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
|
||||
return content;
|
||||
}
|
||||
|
||||
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
|
||||
// 获取 chat 角色
|
||||
String chatRoleName = null;
|
||||
ChatTypeEnum chatTypeEnum = null;
|
||||
Long chatRoleId = req.getChatRoleId();
|
||||
if (req.getChatRoleId() != null) {
|
||||
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId);
|
||||
if (aiChatRoleDO == null) {
|
||||
throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT);
|
||||
}
|
||||
chatTypeEnum = ChatTypeEnum.ROLE_CHAT;
|
||||
chatRoleName = aiChatRoleDO.getRoleName();
|
||||
} else {
|
||||
chatTypeEnum = ChatTypeEnum.USER_CHAT;
|
||||
}
|
||||
//
|
||||
AiChatConversationDO insertChatConversation = new AiChatConversationDO()
|
||||
.setId(null)
|
||||
.setUserId(loginUserId)
|
||||
.setChatRoleId(req.getChatRoleId())
|
||||
.setChatRoleName(chatRoleName)
|
||||
.setChatType(chatTypeEnum.getType())
|
||||
.setChatCount(1)
|
||||
.setChatTitle(req.getPrompt().substring(0, 20) + "...");
|
||||
aiChatConversationMapper.insert(insertChatConversation);
|
||||
return insertChatConversation;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -24,19 +24,29 @@ public class ChatReq {
|
|||
@Schema(description = "填入固定值,1 issues, 2 pr")
|
||||
private String prompt;
|
||||
|
||||
@Schema(description = "chat角色模板")
|
||||
private Long chatRoleId;
|
||||
|
||||
@Schema(description = "用于控制随机性和多样性的温度参数")
|
||||
private Float temperature;
|
||||
private Double temperature;
|
||||
|
||||
@Schema(description = "生成时,核采样方法的概率阈值。例如,取值为0.8时,仅保留累计概率之和大于等于0.8的概率分布中的token,\n" +
|
||||
" * 作为随机采样的候选集。取值范围为(0,1.0),取值越大,生成的随机性越高;取值越低,生成的随机性越低。\n" +
|
||||
" * 默认值为0.8。注意,取值不要大于等于1\n")
|
||||
private Float topP;
|
||||
private Double topP;
|
||||
|
||||
@Schema(description = "在生成消息时采用的Top-K采样大小,表示模型生成回复时考虑的候选项集合的大小")
|
||||
private Integer topK;
|
||||
private Double topK;
|
||||
|
||||
@Schema(description = "ai模型(查看 AiClientNameEnum)")
|
||||
@NotNull(message = "模型不能为空!")
|
||||
@Size(max = 30, message = "模型字符最大30个字符!")
|
||||
private String modal;
|
||||
|
||||
@Schema(description = "对话类型(new、continue)")
|
||||
@NotNull(message = "对话类型,不能为空!")
|
||||
private String conversationType;
|
||||
|
||||
@Schema(description = "对话Id")
|
||||
private Long conversationId;
|
||||
}
|
||||
|
|
|
@ -59,7 +59,7 @@ public abstract class AbstractMessage implements Message {
|
|||
}
|
||||
|
||||
protected AbstractMessage(MessageType messageType, String textContent, List<MediaData> mediaData,
|
||||
Map<String, Object> messageProperties) {
|
||||
Map<String, Object> messageProperties) {
|
||||
|
||||
Assert.notNull(messageType, "Message type must not be null");
|
||||
Assert.notNull(textContent, "Content must not be null");
|
||||
|
|
Loading…
Reference in New Issue