【优化】AI 写作:1. 优先获取写作角色;2. 优化写作提示词
This commit is contained in:
parent
ae69b137be
commit
6e71b721e8
|
@ -1,5 +1,7 @@
|
||||||
package cn.iocoder.yudao.module.ai.enums.write;
|
package cn.iocoder.yudao.module.ai.enums.write;
|
||||||
|
|
||||||
|
import cn.hutool.core.util.ArrayUtil;
|
||||||
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
|
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
@ -15,8 +17,8 @@ import java.util.Arrays;
|
||||||
@Getter
|
@Getter
|
||||||
public enum AiWriteTypeEnum implements IntArrayValuable {
|
public enum AiWriteTypeEnum implements IntArrayValuable {
|
||||||
|
|
||||||
WRITING(1, "撰写"),
|
WRITING(1, "撰写", "请撰写一篇关于 [{}] 的文章。文章的内容格式:{},语气:{},语言:{},长度:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。"),
|
||||||
REPLY(2, "回复");
|
REPLY(2, "回复", "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复格式:{},语气:{},语言:{},长度:{}。不需要除了正文内容外的其他回复,如标题、开头、额外的解释或道歉。");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 类型
|
* 类型
|
||||||
|
@ -27,6 +29,11 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
|
||||||
*/
|
*/
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 模版
|
||||||
|
*/
|
||||||
|
private final String template;
|
||||||
|
|
||||||
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray();
|
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -34,4 +41,9 @@ public enum AiWriteTypeEnum implements IntArrayValuable {
|
||||||
return ARRAYS;
|
return ARRAYS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void validateType(Integer type) {
|
||||||
|
if (ArrayUtil.contains(ARRAYS, type)) return;
|
||||||
|
throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", type));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,17 @@
|
||||||
package cn.iocoder.yudao.module.ai.service.write;
|
package cn.iocoder.yudao.module.ai.service.write;
|
||||||
|
|
||||||
|
import cn.hutool.core.collection.CollUtil;
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
|
||||||
|
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRolePageReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
|
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
|
||||||
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
|
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
|
||||||
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
|
import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
|
||||||
|
@ -15,6 +19,7 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||||
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||||
|
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
|
||||||
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
|
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -25,6 +30,7 @@ import org.springframework.ai.chat.prompt.Prompt;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
|
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
|
||||||
|
@ -43,6 +49,8 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
private AiApiKeyService apiKeyService;
|
private AiApiKeyService apiKeyService;
|
||||||
@Resource
|
@Resource
|
||||||
private AiChatModelService chatModalService;
|
private AiChatModelService chatModalService;
|
||||||
|
@Resource
|
||||||
|
private AiChatRoleService chatRoleService;
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private DictDataApi dictDataApi;
|
private DictDataApi dictDataApi;
|
||||||
|
@ -52,15 +60,22 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
||||||
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的;
|
// 1.1 获取写作模型 尝试获取写作助手角色,如果没有则使用默认模型
|
||||||
AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
|
AiChatRoleDO writeRole = selectOneWriteRole();
|
||||||
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
AiChatModelDO model;
|
||||||
|
if (Objects.nonNull(writeRole)) {
|
||||||
|
model = chatModalService.getChatModel(writeRole.getModelId());
|
||||||
|
} else {
|
||||||
|
model = chatModalService.getRequiredDefaultChatModel();
|
||||||
|
}
|
||||||
|
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
|
|
||||||
|
StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
|
||||||
|
|
||||||
// 1.2 插入写作信息
|
// 1.2 插入写作信息
|
||||||
// TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性
|
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, e -> e.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||||
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
|
writeMapper.insert(writeDO);
|
||||||
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
|
||||||
|
|
||||||
// 2.1 构建提示词
|
// 2.1 构建提示词
|
||||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||||
|
@ -87,23 +102,30 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private AiChatRoleDO selectOneWriteRole() {
|
||||||
|
AiChatRoleDO chatRoleDO = null;
|
||||||
|
PageResult<AiChatRoleDO> writeRolePage = chatRoleService.getChatRolePage(new AiChatRolePageReqVO().setName("写作助手"));
|
||||||
|
List<AiChatRoleDO> list = writeRolePage.getList();
|
||||||
|
if (CollUtil.isNotEmpty(list)) {
|
||||||
|
chatRoleDO = list.get(0);
|
||||||
|
}
|
||||||
|
return chatRoleDO;
|
||||||
|
}
|
||||||
|
|
||||||
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
||||||
String template;
|
Integer type = generateReqVO.getType();
|
||||||
Integer writeType = generateReqVO.getType();
|
|
||||||
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
|
String format = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_FORMAT, generateReqVO.getFormat());
|
||||||
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
|
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getTone());
|
||||||
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
|
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getLanguage());
|
||||||
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
|
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getLength());
|
||||||
// TODO @xin:建议改成 if return 哈;更简洁;
|
String prompt = generateReqVO.getPrompt();
|
||||||
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
|
// 校验写作类型是否合法
|
||||||
// TODO @xin:写成静态枚举哈
|
AiWriteTypeEnum.validateType(type);
|
||||||
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
|
||||||
return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
|
if (Objects.equals(type, AiWriteTypeEnum.WRITING.getType())) {
|
||||||
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
|
return StrUtil.format(AiWriteTypeEnum.WRITING.getTemplate(), prompt, format, tone, language, length);
|
||||||
template = "请针对如下内容:[{}] 做个回复。回复内容参考:[{}], 回复的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
|
||||||
return StrUtil.format(template, generateReqVO.getOriginalContent(), generateReqVO.getPrompt(), format, tone, language, length);
|
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知写作类型({})", writeType));
|
return StrUtil.format(AiWriteTypeEnum.REPLY.getTemplate(), generateReqVO.getOriginalContent(), prompt, format, tone, language, length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue