增加创建role对话

This commit is contained in:
cherishsince 2024-04-23 16:29:26 +08:00
parent dc44fe1cf9
commit 1ab1538afe
9 changed files with 125 additions and 69 deletions

View File

@ -15,8 +15,9 @@ import lombok.Getter;
@Getter
public enum ChatConversationTypeEnum {
NEW("new", "新建对话"),
CONTINUE("continue", "继续对话"),
// roleChatuserChat
ROLE_CHAT("roleChat", "角色对话"),
USER_CHAT("userChat", "用户对话"),
;

View File

@ -2,7 +2,8 @@ package cn.iocoder.yudao.module.ai.controller;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.service.ChatConversationService;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
import io.swagger.v3.oas.annotations.Operation;
@ -30,10 +31,16 @@ public class ChatConversationController {
private final ChatConversationService chatConversationService;
@Operation(summary = "创建 - 对话")
@PostMapping("/create")
public CommonResult<ChatConversationRes> create(@RequestBody @Validated ChatConversationCreateReq req) {
return CommonResult.success(chatConversationService.create(req));
@Operation(summary = "创建 - 对话普通对话")
@PostMapping("/createConversation")
public CommonResult<ChatConversationRes> createConversation(@RequestBody @Validated ChatConversationCreateUserReq req) {
return CommonResult.success(chatConversationService.createConversation(req));
}
@Operation(summary = "创建 - 对话角色对话")
@PostMapping("/createRoleConversation")
public CommonResult<ChatConversationRes> createRoleConversation(@RequestBody @Validated ChatConversationCreateRoleReq req) {
return CommonResult.success(chatConversationService.createRoleConversation(req));
}
@Operation(summary = "获取 - 获取对话")

View File

@ -24,7 +24,11 @@ import java.util.List;
@Mapper
public interface AiChatConversationMapper extends BaseMapperX<AiChatConversationDO> {
/**
* 更新 - chat count
*
* @param id
*/
@Update("update ai_chat_conversation set chat_count = chat_count + 1 where id = #{id}")
void updateIncrChatCount(@Param("id") Long id);

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
@ -15,12 +16,21 @@ import java.util.List;
public interface ChatConversationService {
/**
* 对话 - 创建
* 对话 - 创建普通对话
*
* @param req
* @return
*/
ChatConversationRes create(ChatConversationCreateReq req);
ChatConversationRes createConversation(ChatConversationCreateUserReq req);
/**
* 对话 - 创建role对话
*
* @param req
* @return
*/
ChatConversationRes createRoleConversation(ChatConversationCreateRoleReq req);
/**
* 获取 - 对话
@ -44,4 +54,5 @@ public interface ChatConversationService {
* @param id
*/
void delete(Long id);
}

View File

@ -5,13 +5,18 @@ import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.convert.ChatConversationConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.ChatConversationTypeEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
import cn.iocoder.yudao.module.ai.service.ChatConversationService;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateRoleReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationCreateUserReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationListReq;
import cn.iocoder.yudao.module.ai.vo.ChatConversationRes;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.stereotype.Service;
import java.util.List;
@ -27,10 +32,11 @@ import java.util.List;
@AllArgsConstructor
public class ChatConversationServiceImpl implements ChatConversationService {
private final AiChatRoleMapper aiChatRoleMapper;
private final AiChatConversationMapper aiChatConversationMapper;
@Override
public ChatConversationRes create(ChatConversationCreateReq req) {
public ChatConversationRes createConversation(ChatConversationCreateUserReq req) {
// 获取用户id
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询最新的对话
@ -40,19 +46,47 @@ public class ChatConversationServiceImpl implements ChatConversationService {
return ChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
}
// 创建新的 Conversation
AiChatConversationDO insertConversation = new AiChatConversationDO();
insertConversation.setId(null);
insertConversation.setUserId(loginUserId);
insertConversation.setChatRoleId(null);
insertConversation.setChatRoleName(null);
insertConversation.setTitle(null);
insertConversation.setChatCount(0);
insertConversation.setType(req.getChatType());
aiChatConversationMapper.insert(insertConversation);
AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
null, null, ChatConversationTypeEnum.USER_CHAT);
// 转换 res
return ChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
}
@Override
public ChatConversationRes createRoleConversation(ChatConversationCreateRoleReq req) {
// 获取用户id
Long loginUserId = SecurityFrameworkUtils.getLoginUserId();
// 查询最新的对话
AiChatConversationDO latestConversation = aiChatConversationMapper.selectLatestConversation(loginUserId);
// 如果有对话没有被使用过那就返回这个
if (latestConversation != null && latestConversation.getChatCount() <= 0) {
return ChatConversationConvert.INSTANCE.covnertChatConversationRes(latestConversation);
}
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(req.getChatRoleId());
// 创建新的 Conversation
AiChatConversationDO insertConversation = saveConversation(req.getTitle(), loginUserId,
req.getChatRoleId(), aiChatRoleDO.getRoleName(), ChatConversationTypeEnum.ROLE_CHAT);
// 转换 res
return ChatConversationConvert.INSTANCE.covnertChatConversationRes(insertConversation);
}
private @NotNull AiChatConversationDO saveConversation(String title,
Long userId,
Long chatRoleId,
String chatRoleName,
ChatConversationTypeEnum typeEnum) {
AiChatConversationDO insertConversation = new AiChatConversationDO();
insertConversation.setId(null);
insertConversation.setUserId(userId);
insertConversation.setChatRoleId(chatRoleId);
insertConversation.setChatRoleName(chatRoleName);
insertConversation.setTitle(title);
insertConversation.setChatCount(0);
insertConversation.setType(typeEnum.getType());
aiChatConversationMapper.insert(insertConversation);
return insertConversation;
}
@Override
public ChatConversationRes getConversation(Long id) {
AiChatConversationDO aiChatConversationDO = aiChatConversationMapper.selectById(id);

View File

@ -5,15 +5,10 @@ import cn.iocoder.yudao.framework.ai.chat.ChatResponse;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.config.AiClient;
import cn.iocoder.yudao.framework.common.exception.ServerException;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.module.ai.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.controller.Utf8SseEmitter;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.enums.AiClientNameEnum;
import cn.iocoder.yudao.module.ai.enums.ChatTypeEnum;
import cn.iocoder.yudao.module.ai.mapper.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatMessageMapper;
import cn.iocoder.yudao.module.ai.mapper.AiChatRoleMapper;
@ -49,7 +44,6 @@ public class ChatServiceImpl implements ChatService {
private final AiChatConversationMapper aiChatConversationMapper;
private final ChatConversationService chatConversationService;
/**
* chat
*
@ -64,7 +58,7 @@ public class ChatServiceImpl implements ChatService {
// 获取对话信息
ChatConversationRes conversationRes = chatConversationService.getConversation(req.getConversationId());
// 保存 chat message
saveChatMessage(req, conversationRes.getId(), loginUserId);
saveChatMessage(req, conversationRes, loginUserId);
String content = null;
try {
// 创建 chat 需要的 Prompt
@ -75,16 +69,19 @@ public class ChatServiceImpl implements ChatService {
// 发送 call 调用
ChatResponse call = aiClient.call(prompt, clientNameEnum.getName());
content = call.getResult().getOutput().getContent();
// 更新 conversation
} catch (Exception e) {
content = ExceptionUtil.getMessage(e);
} finally {
// 保存 chat message
saveSystemChatMessage(req, conversationRes.getId(), loginUserId, content);
saveSystemChatMessage(req, conversationRes, loginUserId, content);
}
return content;
}
private void saveChatMessage(ChatReq req, Long chatConversationId, Long loginUserId) {
private void saveChatMessage(ChatReq req, ChatConversationRes conversationRes, Long loginUserId) {
Long chatConversationId = conversationRes.getId();
// 增加 chat message 记录
aiChatMessageMapper.insert(
new AiChatMessageDO()
@ -97,12 +94,12 @@ public class ChatServiceImpl implements ChatService {
.setTopP(req.getTopP())
.setTemperature(req.getTemperature())
);
// chat count +1
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
}
public void saveSystemChatMessage(ChatReq req, Long chatConversationId, Long loginUserId, String systemPrompts) {
public void saveSystemChatMessage(ChatReq req, ChatConversationRes conversationRes, Long loginUserId, String systemPrompts) {
Long chatConversationId = conversationRes.getId();
// 增加 chat message 记录
aiChatMessageMapper.insert(
new AiChatMessageDO()
@ -120,34 +117,6 @@ public class ChatServiceImpl implements ChatService {
aiChatConversationMapper.updateIncrChatCount(req.getConversationId());
}
private AiChatConversationDO createNewChatConversation(ChatReq req, Long loginUserId) {
// 获取 chat 角色
String chatRoleName = null;
ChatTypeEnum chatTypeEnum = null;
Long chatRoleId = req.getChatRoleId();
if (req.getChatRoleId() != null) {
AiChatRoleDO aiChatRoleDO = aiChatRoleMapper.selectById(chatRoleId);
if (aiChatRoleDO == null) {
throw new ServerException(ErrorCodeConstants.AI_CHAT_ROLE_NOT_EXISTENT);
}
chatTypeEnum = ChatTypeEnum.ROLE_CHAT;
chatRoleName = aiChatRoleDO.getRoleName();
} else {
chatTypeEnum = ChatTypeEnum.USER_CHAT;
}
//
AiChatConversationDO insertChatConversation = new AiChatConversationDO()
.setId(null)
.setUserId(loginUserId)
.setChatRoleId(req.getChatRoleId())
.setChatRoleName(chatRoleName)
.setType(chatTypeEnum.getType())
.setChatCount(1)
.setTitle(req.getPrompt().substring(0, 20) + "...");
aiChatConversationMapper.insert(insertChatConversation);
return insertChatConversation;
}
/**
* chat stream
*
@ -168,7 +137,7 @@ public class ChatServiceImpl implements ChatService {
req.setTopP(req.getTopP());
req.setTemperature(req.getTemperature());
// 保存 chat message
saveChatMessage(req, conversationRes.getId(), loginUserId);
saveChatMessage(req, conversationRes, loginUserId);
Flux<ChatResponse> streamResponse = aiClient.stream(prompt, clientNameEnum.getName());
StringBuffer contentBuffer = new StringBuffer();
@ -195,7 +164,7 @@ public class ChatServiceImpl implements ChatService {
log.info("发送完成!");
sseEmitter.complete();
// 保存 chat message
saveSystemChatMessage(req, conversationRes.getId(), loginUserId, contentBuffer.toString());
saveSystemChatMessage(req, conversationRes, loginUserId, contentBuffer.toString());
}
);
}

View File

@ -0,0 +1,26 @@
package cn.iocoder.yudao.module.ai.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import lombok.experimental.Accessors;
/**
* 聊天对话
*
* @author fansili
* @time 2024/4/18 16:24
* @since 1.0
*/
@Data
@Accessors(chain = true)
public class ChatConversationCreateRoleReq {
@Schema(description = "chat角色Id")
@NotNull(message = "聊天角色id不能为空!")
private Long chatRoleId;
@Schema(description = "标题(有程序自动生成)")
@NotNull(message = "标题不能为空!")
private String title;
}

View File

@ -14,10 +14,9 @@ import lombok.experimental.Accessors;
*/
@Data
@Accessors(chain = true)
public class ChatConversationCreateReq {
@Schema(description = "对话类型(roleChat、userChat)")
@NotNull(message = "聊天类型不能为空!")
private String chatType;
public class ChatConversationCreateUserReq {
@Schema(description = "对话标题")
@NotNull(message = "标题不能为空!")
private String title;
}

View File

@ -16,7 +16,12 @@ GET {{baseUrl}}/ai/chat/conversation/1781604279872581644
Authorization: {{token}}
### 对话 - id获取
### 对话 - list
GET {{baseUrl}}/ai/chat/conversation/list
Authorization: {{token}}
### 对话 - 删除
DELETE {{baseUrl}}/ai/chat/conversation/1781604279872581644
Authorization: {{token}}