【增加】AI 写作:支持撰写
This commit is contained in:
parent
aa4c1cb268
commit
77ead4859c
|
@ -42,4 +42,9 @@ public interface ErrorCodeConstants {
|
||||||
// ========== API 音乐 1-040-006-000 ==========
|
// ========== API 音乐 1-040-006-000 ==========
|
||||||
ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");
|
ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");
|
||||||
|
|
||||||
|
|
||||||
|
// ========== API 写作 1-022-007-000 ==========
|
||||||
|
ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!");
|
||||||
|
ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "Stream 对话异常!");
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import lombok.Getter;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI 音乐状态的枚举
|
* AI 音乐生成模式的枚举
|
||||||
*
|
*
|
||||||
* @author xiaoxin
|
* @author xiaoxin
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -0,0 +1,37 @@
|
||||||
|
package cn.iocoder.yudao.module.ai.enums.write;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI 写作类型的枚举
|
||||||
|
*
|
||||||
|
* @author xiaoxin
|
||||||
|
*/
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Getter
|
||||||
|
public enum AiWriteTypeEnum implements IntArrayValuable {
|
||||||
|
|
||||||
|
DESCRIPTION(1, "撰写"),
|
||||||
|
LYRIC(2, "回复");
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 类型
|
||||||
|
*/
|
||||||
|
private final Integer type;
|
||||||
|
/**
|
||||||
|
* 类型名
|
||||||
|
*/
|
||||||
|
private final String name;
|
||||||
|
|
||||||
|
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiWriteTypeEnum::getType).toArray();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int[] array() {
|
||||||
|
return ARRAYS;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -15,6 +15,8 @@ import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
|
||||||
|
|
||||||
@Tag(name = "管理后台 - AI 写作")
|
@Tag(name = "管理后台 - AI 写作")
|
||||||
@RestController
|
@RestController
|
||||||
@RequestMapping("/ai/write")
|
@RequestMapping("/ai/write")
|
||||||
|
@ -27,6 +29,6 @@ public class AiWriteController {
|
||||||
@PermitAll
|
@PermitAll
|
||||||
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
|
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
|
||||||
public Flux<CommonResult<String>> generateComposition(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
|
public Flux<CommonResult<String>> generateComposition(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
|
||||||
return writeService.generateComposition(generateReqVO);
|
return writeService.generateWriteContent(generateReqVO, getLoginUserId());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,14 +8,14 @@ import lombok.Data;
|
||||||
@Data
|
@Data
|
||||||
public class AiWriteGenerateReqVO {
|
public class AiWriteGenerateReqVO {
|
||||||
|
|
||||||
@Schema(description = "写作内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马")
|
@Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "田忌赛马")
|
||||||
private String content;
|
private String contentPrompt;
|
||||||
|
|
||||||
@Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职")
|
@Schema(description = "原文", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "领导我要辞职")
|
||||||
private String originalContent;
|
private String originalContent;
|
||||||
|
|
||||||
@Schema(description = "回复内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "准了")
|
@Schema(description = "回复内容", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "准了")
|
||||||
private String replyContent;
|
private String replyContentPrompt;
|
||||||
|
|
||||||
@Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "中等")
|
@Schema(description = "长度", requiredMode = Schema.RequiredMode.REQUIRED, example = "中等")
|
||||||
@NotBlank(message = "长度不能为空")
|
@NotBlank(message = "长度不能为空")
|
||||||
|
@ -35,5 +35,5 @@ public class AiWriteGenerateReqVO {
|
||||||
|
|
||||||
|
|
||||||
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||||
private Integer writeType;
|
private Integer writeType; //参见 AiWriteTypeEnum 枚举
|
||||||
}
|
}
|
|
@ -0,0 +1,97 @@
|
||||||
|
package cn.iocoder.yudao.module.ai.dal.dataobject.write;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||||
|
import com.baomidou.mybatisplus.annotation.IdType;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
|
import lombok.Data;
|
||||||
|
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI 写作 DO
|
||||||
|
*
|
||||||
|
* @author xiaoxin
|
||||||
|
*/
|
||||||
|
@TableName(value = "ai_write", autoResultMap = true)
|
||||||
|
@Data
|
||||||
|
public class AiWriteDO extends BaseDO {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 编号
|
||||||
|
*/
|
||||||
|
@TableId(type = IdType.AUTO)
|
||||||
|
private Long id;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 用户编号
|
||||||
|
*/
|
||||||
|
private Long userId;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 写作类型
|
||||||
|
* <p>
|
||||||
|
* 枚举 {@link AiWriteTypeEnum}
|
||||||
|
*/
|
||||||
|
private Integer writeType;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 撰写内容提示
|
||||||
|
*/
|
||||||
|
private String contentPrompt;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成的撰写内容
|
||||||
|
*/
|
||||||
|
private String generatedContent;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 原文
|
||||||
|
*/
|
||||||
|
private String originalContent;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 回复内容提示
|
||||||
|
*/
|
||||||
|
private String replyContentPrompt;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成的回复内容
|
||||||
|
*/
|
||||||
|
private String generatedReplyContent;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 长度提示词
|
||||||
|
*/
|
||||||
|
private String length;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式提示词
|
||||||
|
*/
|
||||||
|
private String format;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 语气提示词
|
||||||
|
*/
|
||||||
|
private String tone;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 语言提示词
|
||||||
|
*/
|
||||||
|
private String language;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 模型
|
||||||
|
*/
|
||||||
|
private String model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 平台
|
||||||
|
*/
|
||||||
|
private String platform;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 错误信息
|
||||||
|
*/
|
||||||
|
private String errorMessage;
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
package cn.iocoder.yudao.module.ai.dal.mysql.write;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
|
||||||
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI 音乐 Mapper
|
||||||
|
*
|
||||||
|
* @author xiaoxin
|
||||||
|
*/
|
||||||
|
@Mapper
|
||||||
|
public interface AiWriteMapper extends BaseMapperX<AiWriteDO> {
|
||||||
|
}
|
|
@ -12,7 +12,14 @@ import reactor.core.publisher.Flux;
|
||||||
public interface AiWriteService {
|
public interface AiWriteService {
|
||||||
|
|
||||||
|
|
||||||
Flux<CommonResult<String>> generateComposition(AiWriteGenerateReqVO generateReqVO);
|
/**
|
||||||
|
* 生成写作内容
|
||||||
|
*
|
||||||
|
* @param generateReqVO 作文生成请求参数
|
||||||
|
* @param userId 用户编号
|
||||||
|
* @return 生成结果
|
||||||
|
*/
|
||||||
|
Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,22 +2,27 @@ package cn.iocoder.yudao.module.ai.service.write;
|
||||||
|
|
||||||
import cn.hutool.core.util.StrUtil;
|
import cn.hutool.core.util.StrUtil;
|
||||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
|
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
||||||
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
|
|
||||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.dataobject.write.AiWriteDO;
|
||||||
|
import cn.iocoder.yudao.module.ai.dal.mysql.write.AiWriteMapper;
|
||||||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||||
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||||
|
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||||
import jakarta.annotation.Resource;
|
import jakarta.annotation.Resource;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.ai.chat.ChatResponse;
|
import org.springframework.ai.chat.model.ChatResponse;
|
||||||
import org.springframework.ai.chat.StreamingChatClient;
|
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||||
import org.springframework.ai.chat.prompt.Prompt;
|
import org.springframework.ai.chat.prompt.Prompt;
|
||||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||||
|
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import reactor.core.publisher.Flux;
|
import reactor.core.publisher.Flux;
|
||||||
|
|
||||||
|
@ -35,16 +40,29 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
|
|
||||||
@Resource
|
@Resource
|
||||||
private AiApiKeyService apiKeyService;
|
private AiApiKeyService apiKeyService;
|
||||||
|
@Resource
|
||||||
|
private AiChatModelService chatModalService;
|
||||||
|
@Resource
|
||||||
|
private AiWriteMapper writeMapper;
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Flux<CommonResult<String>> generateComposition(AiWriteGenerateReqVO generateReqVO) {
|
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
||||||
StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(6L);
|
//TODO 芋艿 写作的模型配置放哪好 先用千问测试
|
||||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform("QianWen");
|
// 1.1 校验模型
|
||||||
ChatOptions chatOptions = buildChatOptions(platform, "qwen-72b-chat", 1.0, 1000);
|
AiChatModelDO model = chatModalService.validateChatModel(14L);
|
||||||
|
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
|
||||||
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||||
|
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||||
|
|
||||||
|
//1.2 插入写作信息
|
||||||
|
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
|
||||||
|
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||||
|
|
||||||
|
//2.1 构建提示词
|
||||||
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
|
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
|
||||||
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
||||||
// 3.3 流式返回
|
// 2.2 流式返回
|
||||||
StringBuffer contentBuffer = new StringBuffer();
|
StringBuffer contentBuffer = new StringBuffer();
|
||||||
return streamResponse.map(chunk -> {
|
return streamResponse.map(chunk -> {
|
||||||
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
|
||||||
|
@ -53,17 +71,17 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
// 响应结果
|
// 响应结果
|
||||||
return success(newContent);
|
return success(newContent);
|
||||||
}).doOnComplete(() -> {
|
}).doOnComplete(() -> {
|
||||||
log.info("generateComposition complete, content: {}", contentBuffer);
|
writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setGeneratedContent(contentBuffer.toString()));
|
||||||
}).onErrorResume(error -> {
|
}).doOnError(throwable -> {
|
||||||
log.error("[AI 写作] 发生异常", error);
|
log.error("[AI Write][generateReqVO({}) 发生异常]", generateReqVO, throwable);
|
||||||
return Flux.just(error(ErrorCodeConstants.AI_CHAT_STREAM_ERROR));
|
writeMapper.updateById(new AiWriteDO().setId(writeDO.getId()).setErrorMessage(throwable.getMessage()));
|
||||||
});
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
||||||
String template = "请直接写一篇关于 [{}] 的文章,格式为:{},语气为:{},语言为:{},长度为:{}。请确保涵盖主要内容,不需要任何额外的解释或道歉。";
|
String template = "请直接写一篇关于 [{}] 的文章,格式为:{},语气为:{},语言为:{},长度为:{}。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
||||||
String content = generateReqVO.getContent();
|
String content = generateReqVO.getContentPrompt();
|
||||||
String format = generateReqVO.getFormat();
|
String format = generateReqVO.getFormat();
|
||||||
String tone = generateReqVO.getTone();
|
String tone = generateReqVO.getTone();
|
||||||
String language = generateReqVO.getLanguage();
|
String language = generateReqVO.getLanguage();
|
||||||
|
@ -81,14 +99,14 @@ public class AiWriteServiceImpl implements AiWriteService {
|
||||||
case OLLAMA:
|
case OLLAMA:
|
||||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
||||||
case YI_YAN:
|
case YI_YAN:
|
||||||
// TODO @fan:增加一个 model
|
// TODO 芋艿:貌似 model 只要一设置,就报错
|
||||||
return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
|
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||||
|
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||||
case XING_HUO:
|
case XING_HUO:
|
||||||
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
||||||
.setMaxTokens(maxTokens);
|
.setMaxTokens(maxTokens);
|
||||||
case QIAN_WEN:
|
case QIAN_WEN:
|
||||||
// TODO @fan:增加 model、temperature 参数
|
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
||||||
return new QianWenOptions().setMaxTokens(maxTokens);
|
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue