【新增】AI:通过 AiClientFactory 提供 chatclient

This commit is contained in:
YunaiV 2024-05-22 12:37:21 +08:00
parent cad1ce4852
commit 2fefcf8834
12 changed files with 289 additions and 108 deletions

View File

@ -1,57 +0,0 @@
package cn.iocoder.yudao.module.ai.config;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.StreamingChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;
/**
* factory
*
* @author fansili
* @time 2024/4/25 17:36
* @since 1.0
*/
@Component
public class AiChatClientFactory {
@Autowired
private ApplicationContext applicationContext;
public ChatClient getChatClient(AiPlatformEnum platformEnum) {
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class);
}
throw new IllegalArgumentException("不支持的 chat client!");
}
// TODO yunai 要不再加一个接口让他们拥有 ChatClientStreamingChatClient 功能
public StreamingChatClient getStreamingChatClient(AiPlatformEnum platformEnum) {
// if (true) {
// return applicationContext.getBean(OllamaChatClient.class);
// }
if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class);
} else if (AiPlatformEnum.OLLAMA == platformEnum) {
return applicationContext.getBean(OllamaChatClient.class);
} else if (AiPlatformEnum.OPENAI == platformEnum) {
return applicationContext.getBean(OpenAiChatClient.class);
}
throw new IllegalArgumentException("不支持的 chat client!");
}
}

View File

@ -4,13 +4,20 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjUtil; import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.config.AiChatClientFactory; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO; import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert; import cn.iocoder.yudao.module.ai.convert.AiChatMessageConvert;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO; import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
@ -19,12 +26,7 @@ import cn.iocoder.yudao.module.ai.service.AiChatService;
import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService; import cn.iocoder.yudao.module.ai.service.chat.AiChatConversationService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService; import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService; import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -46,16 +48,22 @@ import static cn.iocoder.yudao.module.ai.ErrorCodeConstants.CHAT_CONVERSATION_NO
*/ */
@Slf4j @Slf4j
@Service @Service
@AllArgsConstructor
public class AiChatServiceImpl implements AiChatService { public class AiChatServiceImpl implements AiChatService {
private final AiChatClientFactory chatClientFactory; @Resource
private AiChatMessageMapper chatMessageMapper;
private final AiChatMessageMapper chatMessageMapper; @Resource
private AiClientFactory clientFactory;
private final AiChatConversationService chatConversationService; @Resource
private final AiChatModelService chatModalService; private AiChatConversationService chatConversationService;
private final AiChatRoleService chatRoleService; @Resource
private AiChatModelService chatModalService;
@Resource
private AiChatRoleService chatRoleService;
@Resource
private AiApiKeyService apiKeyService;
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) { public AiChatMessageRespVO chat(AiChatMessageSendReqVO req) {
@ -106,8 +114,7 @@ public class AiChatServiceImpl implements AiChatService {
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectByConversationId(conversation.getId()); List<AiChatMessageDO> historyMessages = chatMessageMapper.selectByConversationId(conversation.getId());
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform()); StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
StreamingChatClient chatClient = chatClientFactory.getStreamingChatClient(platform);
// 2. 插入 user 发送消息 // 2. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model, AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -118,13 +125,13 @@ public class AiChatServiceImpl implements AiChatService {
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext()); userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
// 3.2 创建 chat 需要的 Prompt // 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt(conversation, historyMessages, sendReqVO); Prompt prompt = buildPrompt(conversation, historyMessages, model, sendReqVO);
Flux<ChatResponse> streamResponse = chatClient.stream(prompt); Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
// 3.3 流式返回 // 3.3 流式返回
// 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题 // 注意Schedulers.immediate() 目的是避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
StringBuffer contentBuffer = new StringBuffer(); StringBuffer contentBuffer = new StringBuffer();
return streamResponse.publishOn(Schedulers.immediate()).map(chunk -> { return streamResponse.publishOn(Schedulers.single()).map(chunk -> {
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null; String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 情况 newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 情况
contentBuffer.append(newContent); contentBuffer.append(newContent);
@ -144,7 +151,8 @@ public class AiChatServiceImpl implements AiChatService {
return chatMessageMapper.deleteByConversationId(conversationId) > 0; return chatMessageMapper.deleteByConversationId(conversationId) > 0;
} }
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages, AiChatMessageSendReqVO sendReqVO) { private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
// 1. 构建 Prompt Message 列表 // 1. 构建 Prompt Message 列表
List<Message> chatMessages = new ArrayList<>(); List<Message> chatMessages = new ArrayList<>();
// 1.1 system context 角色设定 // 1.1 system context 角色设定
@ -156,10 +164,11 @@ public class AiChatServiceImpl implements AiChatService {
chatMessages.add(new UserMessage(sendReqVO.getContent())); chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象 TODO 芋艿临时注释掉等文心一言兼容了 // 2. 构建 ChatOptions 对象 TODO 芋艿临时注释掉等文心一言兼容了
// TODO 每一轮 token 数量 AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
// ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build(); ChatOptions chatOptions = clientFactory.buildChatOptions(platform, model.getModel(),
// return new Prompt(chatMessages, null); conversation.getTemperature(), conversation.getMaxTokens());
return new Prompt(chatMessages); return new Prompt(chatMessages, chatOptions);
// return new Prompt(chatMessages);
} }
/** /**

View File

@ -5,6 +5,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageR
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO; import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.ai.chat.StreamingChatClient;
import java.util.List; import java.util.List;
@ -68,4 +69,14 @@ public interface AiApiKeyService {
*/ */
List<AiApiKeyDO> getApiKeyList(); List<AiApiKeyDO> getApiKeyList();
// ========== spring-ai 集成 ==========
/**
* 获得 StreamingChatClient 对象
*
* @param id 编号
* @return StreamingChatClient 对象
*/
StreamingChatClient getStreamingChatClient(Long id);
} }

View File

@ -1,5 +1,7 @@
package cn.iocoder.yudao.module.ai.service.model; package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -8,6 +10,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper; import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
@ -28,6 +31,9 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource @Resource
private AiApiKeyMapper apiKeyMapper; private AiApiKeyMapper apiKeyMapper;
@Resource
private AiClientFactory clientFactory;
@Override @Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) { public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
// 插入 // 插入
@ -86,4 +92,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return apiKeyMapper.selectList(); return apiKeyMapper.selectList();
} }
// ========== spring-ai 集成 ==========
@Override
public StreamingChatClient getStreamingChatClient(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
}
} }

View File

@ -1,6 +1,8 @@
package cn.iocoder.yudao.framework.ai.config; package cn.iocoder.yudao.framework.ai.config;
import cn.hutool.core.io.IoUtil; import cn.hutool.core.io.IoUtil;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
@ -36,17 +38,22 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
/** /**
* ai 自动配置 * 芋道 AI 自动配置
* *
* @author fansili * @author fansili
* @time 2024/4/12 16:29
* @since 1.0
*/ */
@Slf4j
@AutoConfiguration @AutoConfiguration
@EnableConfigurationProperties(YudaoAiProperties.class) @EnableConfigurationProperties(YudaoAiProperties.class)
@Slf4j
public class YudaoAiAutoConfiguration { public class YudaoAiAutoConfiguration {
@Bean
public AiClientFactory aiClientFactory() {
return new AiClientFactoryImpl();
}
// ========== 各种 AI Client 创建 ==========
@Bean @Bean
@ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true") @ConditionalOnProperty(value = "yudao.ai.xinghuo.enable", havingValue = "true")
public XingHuoChatClient xingHuoChatClient(YudaoAiProperties yudaoAiProperties) { public XingHuoChatClient xingHuoChatClient(YudaoAiProperties yudaoAiProperties) {
@ -107,21 +114,6 @@ public class YudaoAiAutoConfiguration {
); );
} }
@Bean
@ConditionalOnProperty(value = "yudao.ai.openAiImage.enable", havingValue = "true")
public OpenAiImageClient openAiImageClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.OpenAiImageProperties openAiImageProperties = yudaoAiProperties.getOpenAiImage();
OpenAiImageOptions openAiImageOptions = new OpenAiImageOptions();
openAiImageOptions.setModel(openAiImageProperties.getModel().getModel());
openAiImageOptions.setStyle(openAiImageProperties.getStyle().getStyle());
openAiImageOptions.setResponseFormat("url"); // TODO 芋艿OpenAiImageOptions.ResponseFormatEnum.URL.getValue()
// 创建 client
return new OpenAiImageClient(
new OpenAiImageApi(openAiImageProperties.getApiKey()),
openAiImageOptions,
RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
@Bean @Bean
@ConditionalOnMissingBean(value = MidjourneyMessageHandler.class) @ConditionalOnMissingBean(value = MidjourneyMessageHandler.class)
public MidjourneyMessageHandler defaultMidjourneyMessageHandler() { public MidjourneyMessageHandler defaultMidjourneyMessageHandler() {

View File

@ -0,0 +1,47 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
/**
* AI 客户端工厂的接口类
*
* @author fansili
*/
public interface AiClientFactory {
/**
* 基于指定配置获得 StreamingChatClient 对象
*
* 如果不存在则进行创建
*
* @param platform 平台
* @param apiKey API KEY
* @param url API URL
* @return StreamingChatClient 对象
*/
StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于默认配置获得 StreamingChatClient 对象
*
* 默认配置指的是在 application.yaml 配置文件中的 spring.ai 相关的配置
*
* @param platform 平台
* @return StreamingChatClient 对象
*/
StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform);
/**
* 创建 Chat 参数
*
* @param platform 平台
* @param model 模型
* @param temperature 温度
* @param maxTokens 生成的最大 Token
* @return Chat 参数
*/
ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens);
}

View File

@ -0,0 +1,167 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.lang.Singleton;
import cn.hutool.core.lang.func.Func0;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import java.util.List;
/**
* AI 客户端工厂的实现类
*
* @author 芋道源码
*/
public class AiClientFactoryImpl implements AiClientFactory {
@Override
public StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(StreamingChatClient.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<StreamingChatClient>) () -> {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return buildOpenAiChatClient(apiKey, url);
case OLLAMA:
return buildOllamaChatClient(url);
case YI_YAN:
return buildYiYanChatClient(apiKey);
case XING_HUO:
return buildXingHuoChatClient(apiKey);
case QIAN_WEN:
return buildQianWenChatClient(apiKey);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
});
}
@Override
public StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return SpringUtil.getBean(OpenAiChatClient.class);
case OLLAMA:
return SpringUtil.getBean(OllamaChatClient.class);
case YI_YAN:
return SpringUtil.getBean(YiYanChatClient.class);
case XING_HUO:
return SpringUtil.getBean(XingHuoChatClient.class);
case QIAN_WEN:
return SpringUtil.getBean(QianWenChatClient.class);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) {
return clazz.getName();
}
return StrUtil.format("{}#{}", clazz.getName(), ArrayUtil.join(params, "_"));
}
@Override
public ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
Float temperatureF = temperature != null ? temperature.floatValue() : null;
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return OpenAiChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO @fan增加一个 model
return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
// TODO @fan:增加 modeltemperature 参数
return new QianWenOptions().setMaxTokens(maxTokens);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
// ========== 各种创建 spring-ai 客户端的方法 ==========
/**
* 可参考 {@link OpenAiAutoConfiguration}
*/
private static OpenAiChatClient buildOpenAiChatClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
return new OpenAiChatClient(openAiApi);
}
/**
* 可参考 {@link OllamaAutoConfiguration}
*/
private static OllamaChatClient buildOllamaChatClient(String url) {
OllamaApi ollamaApi = new OllamaApi(url);
return new OllamaChatClient(ollamaApi);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#yiYanChatClient(YudaoAiProperties)}
*/
private static YiYanChatClient buildYiYanChatClient(String key) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
YiYanApi yiYanApi = new YiYanApi(appKey, secretKey, YiYanApi.DEFAULT_CHAT_MODEL, 0);
return new YiYanChatClient(yiYanApi);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
*/
private static XingHuoChatClient buildXingHuoChatClient(String key) {
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "XingHuoChatClient 的密钥需要 (appKey|secretKey) 格式");
String appId = keys.get(0);
String appKey = keys.get(1);
String secretKey = keys.get(2);
XingHuoApi xingHuoApi = new XingHuoApi(appId, appKey, secretKey);
return new XingHuoChatClient(xingHuoApi);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#qianWenChatClient(YudaoAiProperties)}
*/
private static QianWenChatClient buildQianWenChatClient(String key) {
QianWenApi qianWenApi = new QianWenApi(key, QianWenChatModal.QWEN_72B_CHAT);
return new QianWenChatClient(qianWenApi);
}
}

View File

@ -6,6 +6,8 @@ import lombok.experimental.Accessors;
import java.util.List; import java.util.List;
// TODO @fan增加一个 model 参数
// TODO @fan增加一个 Temperature 参数
/** /**
* 阿里云 千问 属性 * 阿里云 千问 属性
* *

View File

@ -14,6 +14,7 @@ import lombok.experimental.Accessors;
@Accessors(chain = true) @Accessors(chain = true)
public class XingHuoOptions implements ChatOptions { public class XingHuoOptions implements ChatOptions {
// TODO @fan这里 model 参数然后使用 string
/** /**
* https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E * https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
* <p> * <p>
@ -43,7 +44,6 @@ public class XingHuoOptions implements ChatOptions {
*/ */
private String chatId; private String chatId;
@Override @Override
public Float getTemperature() { public Float getTemperature() {
return this.temperature; return this.temperature;

View File

@ -6,6 +6,7 @@ import org.springframework.ai.chat.prompt.ChatOptions;
import java.util.List; import java.util.List;
// TODO @fan增加一个 model
// TODO @fan字段命名penalty_score 类似的建议改成驼峰原则 // TODO @fan字段命名penalty_score 类似的建议改成驼峰原则
// TODO @fan字段的注释可以都删除掉让用户 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 即可 // TODO @fan字段的注释可以都删除掉让用户 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 即可
/** /**

View File

@ -18,7 +18,7 @@ public class YiYanApi {
private static final String AUTH_2_TOKEN_URI = "/oauth/2.0/token"; private static final String AUTH_2_TOKEN_URI = "/oauth/2.0/token";
public static final String DEFAULT_CHAT_MODEL = YiYanChatModel.ERNIE4_0.getModel(); public static final YiYanChatModel DEFAULT_CHAT_MODEL = YiYanChatModel.ERNIE4_0;
private final String appKey; private final String appKey;
private final String secretKey; private final String secretKey;
@ -39,6 +39,7 @@ public class YiYanApi {
*/ */
private final YiYanChatModel useChatModel; private final YiYanChatModel useChatModel;
// TODO fan看看是不是去掉 refreshTokenSecondTime 字段
public YiYanApi(String appKey, String secretKey, YiYanChatModel useChatModel, int refreshTokenSecondTime) { public YiYanApi(String appKey, String secretKey, YiYanChatModel useChatModel, int refreshTokenSecondTime) {
this.appKey = appKey; this.appKey = appKey;
this.secretKey = secretKey; this.secretKey = secretKey;

View File

@ -150,15 +150,8 @@ spring.ai:
chat: chat:
model: llama3 model: llama3
openai: openai:
# api-key: sk-QmgIIPc5xiYd8lPb076b1b7774Ea49Af9eD2Ef172c8f7e43
# base-url: https://openkey.cloud
# api-key: sk-gkgfYxhX9FxyZJznwxRZSJwKeGQYNPDVWjhby2PRRf17GHeT
# base-url: https://api.chatanywhere.tech
api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z api-key: sk-yzKea6d8e8212c3bdd99f9f44ced1cae37c097e5aa3BTS7z
base-url: https://api.gptsapi.net base-url: https://api.gptsapi.net
# chat:
# options:
# model: gpt-4-0125-preview
yudao.ai: yudao.ai:
yiyan: yiyan: