【增加】AI 写作:支持撰写

This commit is contained in:
xiaoxin 2024-07-03 10:52:41 +08:00
parent aa4c1cb268
commit 77ead4859c
9 changed files with 207 additions and 27 deletions

View File

@ -42,4 +42,9 @@ public interface ErrorCodeConstants {
// ========== API 音乐 1-040-006-000 ========== // ========== API 音乐 1-040-006-000 ==========
ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!"); ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");
// ========== API 写作 1-022-007-000 ==========
ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!");
ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "Stream 对话异常!");
} }

View File

@ -7,7 +7,7 @@ import lombok.Getter;
import java.util.Arrays; import java.util.Arrays;
/** /**
* AI 音乐状态的枚举 * AI 音乐生成模式的枚举
* *
* @author xiaoxin * @author xiaoxin
*/ */

View File

@ -0,0 +1,37 @@
package cn.iocoder.yudao.module.ai.enums.write;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 写作类型的枚举
*
* @author xiaoxin
*/
@AllArgsConstructor
@Getter
public enum AiWriteTypeEnum implements IntArrayValuable {
DESCRIPTION(1, "撰写"),
LYRIC(2, "回复");
/**
* 类型
*/
private final Integer type;
/**
* 类型名
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray();
@Override
public int[] array() {
return ARRAYS;
}
}

View File

@ -15,6 +15,8 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
@Tag(name = "管理后台 - AI 写作") @Tag(name = "管理后台 - AI 写作")
@RestController @RestController
@RequestMapping("/ai/write") @RequestMapping("/ai/write")
@ -27,6 +29,6 @@ public class AiWriteController {
@PermitAll @PermitAll
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快") @Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
public Flux<CommonResult<String>> generateComposition(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) { public Flux<CommonResult<String>> generateComposition(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
return writeService.generateComposition(generateReqVO); return writeService.generateWriteContent(generateReqVO, getLoginUserId());
} }
} }

View File

@ -8,14 +8,14 @@ import lombok.Data;
@Data @Data
public class AiWriteGenerateReqVO { public class AiWriteGenerateReqVO {
@Schema(description = "写作内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马") @Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马")
private String content; private String contentPrompt;
@Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职") @Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职")
private String originalContent; private String originalContent;
@Schema(description = "回复内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "准了") @Schema(description = "回复内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "准了")
private String replyContent; private String replyContentPrompt;
@Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "中等") @Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "中等")
@NotBlank(message = "长度不能为空") @NotBlank(message = "长度不能为空")
@ -35,5 +35,5 @@ public class AiWriteGenerateReqVO {
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") @Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Integer writeType; private Integer writeType; //参见 AiWriteTypeEnum 枚举
} }

View File

@ -0,0 +1,97 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.write;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
/**
* AI 写作 DO
*
* @author xiaoxin
*/
@TableName(value = "ai_write", autoResultMap = true)
@Data
public class AiWriteDO extends BaseDO {
/**
* 编号
*/
@TableId(type = IdType.AUTO)
private Long id;
/**
* 用户编号
*/
private Long userId;
/**
* 写作类型
* <p>
* 枚举 {@link AiWriteTypeEnum}
*/
private Integer writeType;
/**
* 撰写内容提示
*/
private String contentPrompt;
/**
* 生成的撰写内容
*/
private String generatedContent;
/**
* 原文
*/
private String originalContent;
/**
* 回复内容提示
*/
private String replyContentPrompt;
/**
* 生成的回复内容
*/
private String generatedReplyContent;
/**
* 长度提示词
*/
private String length;
/**
* 格式提示词
*/
private String format;
/**
* 语气提示词
*/
private String tone;
/**
* 语言提示词
*/
private String language;
/**
* 模型
*/
private String model;
/**
* 平台
*/
private String platform;
/**
* 错误信息
*/
private String errorMessage;
}

View File

@ -0,0 +1,14 @@
package cn.iocoder.yudao.module.ai.dal.mysql.write;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
import org.apache.ibatis.annotations.Mapper;
/**
* AI 音乐 Mapper
*
* @author xiaoxin
*/
@Mapper
public interface AiWriteMapper extends BaseMapperX<AiWriteDO> {
}

View File

@ -12,7 +12,14 @@ import reactor.core.publisher.Flux;
public interface AiWriteService { public interface AiWriteService {
Flux<CommonResult<String>> generateComposition(AiWriteGenerateReqVO generateReqVO); /**
* 生成写作内容
*
* @param generateReqVO 作文生成请求参数
* @param userId 用户编号
* @return 生成结果
*/
Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId);
} }

View File

@ -2,22 +2,27 @@ package cn.iocoder.yudao.module.ai.service.write;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions; import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO; import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants; import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService; import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -35,16 +40,29 @@ public class AiWriteServiceImpl implements AiWriteService {
@Resource @Resource
private AiApiKeyService apiKeyService; private AiApiKeyService apiKeyService;
@Resource
private AiChatModelService chatModalService;
@Resource
private AiWriteMapper writeMapper;
@Override @Override
public Flux<CommonResult<String>> generateComposition(AiWriteGenerateReqVO generateReqVO) { public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(6L); //TODO 芋艿 写作的模型配置放哪好 先用千问测试
AiPlatformEnum platform = AiPlatformEnum.validatePlatform("QianWen"); // 1.1 校验模型
ChatOptions chatOptions = buildChatOptions(platform, "qwen-72b-chat", 1.0, 1000); AiChatModelDO model = chatModalService.validateChatModel(14L);
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
//1.2 插入写作信息
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
//2.1 构建提示词
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions); Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt); Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 3.3 流式返回 // 2.2 流式返回
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> { return streamResponse.map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
@ -53,17 +71,17 @@ public class AiWriteServiceImpl implements AiWriteService {
// 响应结果 // 响应结果
return success(newContent); return success(newContent);
}).doOnComplete(() -> { }).doOnComplete(() -> {
log.info("generateComposition complete, content: {}", contentBuffer); writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setGeneratedContent(contentBuffer.toString()));
}).onErrorResume(error -> { }).doOnError(throwable -> {
log.error("[AI 写作] 发生异常", error); log.error("[AI Write][generateReqVO({}) 发生异常]", generateReqVO, throwable);
return Flux.just(error(ErrorCodeConstants.AI_CHAT_STREAM_ERROR)); writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setErrorMessage(throwable.getMessage()));
}); }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
} }
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) { private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
String template = "请直接写一篇关于 [{}] 的文章,格式为:{},语气为:{},语言为:{},长度为:{}。请确保涵盖主要内容,不需要任何额外的解释或道歉。"; String template = "请直接写一篇关于 [{}] 的文章,格式为:{},语气为:{},语言为:{},长度为:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
String content = generateReqVO.getContent(); String content = generateReqVO.getContentPrompt();
String format = generateReqVO.getFormat(); String format = generateReqVO.getFormat();
String tone = generateReqVO.getTone(); String tone = generateReqVO.getTone();
String language = generateReqVO.getLanguage(); String language = generateReqVO.getLanguage();
@ -81,14 +99,14 @@ public class AiWriteServiceImpl implements AiWriteService {
case OLLAMA: case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens); return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN: case YI_YAN:
// TODO @fan增加一个 model // TODO 芋艿貌似 model 只要一设置就报错
return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens); // return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO: case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF) return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens); .setMaxTokens(maxTokens);
case QIAN_WEN: case QIAN_WEN:
// TODO @fan:增加 modeltemperature 参数 return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
return new QianWenOptions().setMaxTokens(maxTokens);
default: default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform)); throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
} }