【优化】AI 写作:1. 优先获取写作角色;2. 优化写作提示词

This commit is contained in:
xiaoxin 2024-07-09 22:05:42 +08:00
parent ae69b137be
commit 6e71b721e8
2 changed files with 56 additions and 22 deletions

View File

@ -1,5 +1,7 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
@ -15,8 +17,8 @@ import java.util.Arrays;
@Getter
public enum AiWriteTypeEnum implements IntArrayValuable {
WRITING(1, "撰写"),
REPLY(2, "回复");
WRITING(1, "撰写", "请撰写一篇关于 [{}] 的文章。文章的内容格式:{},语气:{},语言:{},长度:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"),
REPLY(2, "回复", "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复格式:{},语气:{},语言:{},长度:{}。不需要除了正文内容外的其他回复,如标题、开头、额外的解释或道歉。");
/**
* 类型
@ -27,6 +29,11 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
*/
private final String name;
/**
* 模版
*/
private final String template;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray();
@Override
@ -34,4 +41,9 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
return ARRAYS;
}
public static void validateType(Integer type) {
if (ArrayUtil.contains(ARRAYS, type)) return;
throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type));
}
}

View File

@ -1,13 +1,17 @@
package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
@ -15,6 +19,7 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
@ -25,6 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.List;
import java.util.Objects;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
@ -43,6 +49,8 @@ public class AiWriteServiceImpl implements AiWriteService {
private AiApiKeyService apiKeyService;
@Resource
private AiChatModelService chatModalService;
@Resource
private AiChatRoleService chatRoleService;
@Resource
private DictDataApi dictDataApi;
@ -52,15 +60,22 @@ public class AiWriteServiceImpl implements AiWriteService {
@Override
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok那可以有限拿 chatRole 的角色如果没有则获取默认的
AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 1.1 获取写作模型 尝试获取写作助手角色如果没有则使用默认模型
AiChatRoleDO writeRole = selectOneWriteRole();
AiChatModelDO model;
if (Objects.nonNull(writeRole)) {
model = chatModalService.getChatModel(writeRole.getModelId());
} else {
model = chatModalService.getRequiredDefaultChatModel();
}
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
// 1.2 插入写作信息
// TODO @xin建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform())写在 toBean consumer 原因是让这个 set 保持完整性
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
writeMapper.insert(writeDO);
// 2.1 构建提示词
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
@ -87,23 +102,30 @@ public class AiWriteServiceImpl implements AiWriteService {
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
}
private AiChatRoleDO selectOneWriteRole() {
AiChatRoleDO chatRoleDO = null;
PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手"));
List<AiChatRoleDO> list = writeRolePage.getList();
if (CollUtil.isNotEmpty(list)) {
chatRoleDO = list.get(0);
}
return chatRoleDO;
}
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
String template;
Integer writeType = generateReqVO.getType();
Integer type = generateReqVO.getType();
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
// TODO @xin建议改成 if return 更简洁
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
// TODO @xin写成静态枚举哈
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
template = "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
return StrUtil.format(template, generateReqVO.getOriginalContent(), generateReqVO.getPrompt(), format, tone, language, length);
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
String prompt = generateReqVO.getPrompt();
// 校验写作类型是否合法
AiWriteTypeEnum.validateType(type);
if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) {
return StrUtil.format(AiWriteTypeEnum.WRITING.getTemplate(), prompt, format, tone, language, length);
} else {
throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", writeType));
return StrUtil.format(AiWriteTypeEnum.REPLY.getTemplate(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
}
}