【代码优化】AI:新增 AIUtils,用于对接 spring ai 各种对象的构建
This commit is contained in:
parent
41071fc689
commit
471968eaf2
|
@ -26,9 +26,10 @@ public class AiWriteController {
|
|||
private AiWriteService writeService;
|
||||
|
||||
@PostMapping(value = "/generate-stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
@PermitAll
|
||||
@Operation(summary = "写作生成(流式)", description = "流式返回,响应较快")
|
||||
@PermitAll // 解决 SSE 最终响应的时候,会被 Access Denied 拦截的问题
|
||||
public Flux<CommonResult<String>> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
|
||||
return writeService.generateWriteContent(generateReqVO, getLoginUserId());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -4,8 +4,7 @@ import cn.hutool.core.collection.CollUtil;
|
|||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
|
@ -19,7 +18,6 @@ import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
|
|||
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.messages.*;
|
||||
|
@ -28,9 +26,6 @@ import org.springframework.ai.chat.model.ChatResponse;
|
|||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
@ -148,46 +143,17 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|||
}
|
||||
// 1.2 history message 历史消息
|
||||
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
||||
contextMessages.forEach(message -> {
|
||||
// TODO @芋艿:看看有没优化空间
|
||||
if (MessageType.USER.getValue().equals(message.getType())) {
|
||||
chatMessages.add(new UserMessage(message.getContent()));
|
||||
} else {
|
||||
chatMessages.add(new AssistantMessage(message.getContent()));
|
||||
}
|
||||
});
|
||||
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
|
||||
// 1.3 user message 新发送消息
|
||||
chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
||||
|
||||
// 2. 构建 ChatOptions 对象
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(),
|
||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
|
||||
conversation.getTemperature(), conversation.getMaxTokens());
|
||||
return new Prompt(chatMessages, chatOptions);
|
||||
}
|
||||
|
||||
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case OPENAI:
|
||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case OLLAMA:
|
||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
||||
case YI_YAN:
|
||||
// TODO 芋艿:貌似 model 只要一设置,就报错
|
||||
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case XING_HUO:
|
||||
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
||||
.setMaxTokens(maxTokens);
|
||||
case QIAN_WEN:
|
||||
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从历史消息中,获得倒序的 n 组消息作为消息上下文
|
||||
*
|
||||
|
|
|
@ -2,8 +2,7 @@ package cn.iocoder.yudao.module.ai.service.write;
|
|||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
||||
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.write.vo.AiWriteGenerateReqVO;
|
||||
|
@ -16,16 +15,12 @@ import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
|||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
||||
import cn.iocoder.yudao.module.system.api.dict.DictDataApi;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.model.StreamingChatModel;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||
import org.springframework.stereotype.Service;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
|
@ -56,19 +51,21 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||
|
||||
@Override
|
||||
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
||||
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?
|
||||
// 1.1 校验模型 TODO 芋艿 是不是取默认的模型也ok?;那可以,有限拿 chatRole 的角色;如果没有,则获取默认的;
|
||||
AiChatModelDO model = chatModalService.getRequiredDefaultChatModel();
|
||||
StreamingChatModel chatClient = apiKeyService.getChatClient(model.getKeyId());
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||
|
||||
// 1.2 插入写作信息
|
||||
// TODO @xin:建议把 writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()),写在 toBean 的 consumer 里;原因是,让这个 set 保持完整性
|
||||
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
|
||||
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||
|
||||
// 2.1 构建提示词
|
||||
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
|
||||
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
||||
|
||||
// 2.2 流式返回
|
||||
StringBuffer contentBuffer = new StringBuffer();
|
||||
return streamResponse.map(chunk -> {
|
||||
|
@ -92,7 +89,9 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||
String tone = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_TONE, generateReqVO.getFormat());
|
||||
String language = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LANGUAGE, generateReqVO.getFormat());
|
||||
String length = dictDataApi.getDictDataLabel(DictTypeConstants.AI_WRITE_LENGTH, generateReqVO.getFormat());
|
||||
// TODO @xin:建议改成 if return 哈;更简洁;
|
||||
if (Objects.equals(writeType, AiWriteTypeEnum.WRITING.getType())) {
|
||||
// TODO @xin:写成静态枚举哈
|
||||
template = "请撰写一篇关于 [{}] 的文章。文章的内容格式为:[{}],语气为:[{}],语言为:[{}],长度为:[{}]。请确保涵盖主要内容,不需要除了正文内容外的其他回复,如标题、额外的解释或道歉。";
|
||||
return StrUtil.format(template, generateReqVO.getPrompt(), format, tone, language, length);
|
||||
} else if (Objects.equals(writeType, AiWriteTypeEnum.REPLY.getType())) {
|
||||
|
@ -103,27 +102,4 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO 芋艿:复用
|
||||
private static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case OPENAI:
|
||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case OLLAMA:
|
||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
||||
case YI_YAN:
|
||||
// TODO 芋艿:貌似 model 只要一设置,就报错
|
||||
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case XING_HUO:
|
||||
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
||||
.setMaxTokens(maxTokens);
|
||||
case QIAN_WEN:
|
||||
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.util;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
|
||||
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
|
||||
import org.springframework.ai.chat.messages.*;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.ollama.api.OllamaOptions;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||
|
||||
/**
|
||||
* Spring AI 工具类
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class AiUtils {
|
||||
|
||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
|
||||
Float temperatureF = temperature != null ? temperature.floatValue() : null;
|
||||
//noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case OPENAI:
|
||||
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case OLLAMA:
|
||||
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
|
||||
case YI_YAN:
|
||||
// TODO @xin:貌似 model 只要一设置,就报错;可以排查下
|
||||
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
|
||||
case XING_HUO:
|
||||
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
|
||||
.setMaxTokens(maxTokens);
|
||||
case QIAN_WEN:
|
||||
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
}
|
||||
|
||||
public static Message buildMessage(String type, String content) {
|
||||
if (MessageType.USER.getValue().equals(type)) {
|
||||
return new UserMessage(content);
|
||||
}
|
||||
if (MessageType.ASSISTANT.getValue().equals(type)) {
|
||||
return new AssistantMessage(content);
|
||||
}
|
||||
if (MessageType.SYSTEM.getValue().equals(type)) {
|
||||
return new SystemMessage(content);
|
||||
}
|
||||
if (MessageType.FUNCTION.getValue().equals(type)) {
|
||||
return new FunctionMessage(content);
|
||||
}
|
||||
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue