【新增】AI:增加 ollama 模型的接入

This commit is contained in:
YunaiV 2024-05-17 22:16:57 +08:00
parent 7fca38ce1e
commit 9de9e938bf
5 changed files with 38 additions and 19 deletions

View File

@ -6,6 +6,7 @@ import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.models.tongyi.QianWenChatClient; import org.springframework.ai.models.tongyi.QianWenChatClient;
import org.springframework.ai.models.xinghuo.XingHuoChatClient; import org.springframework.ai.models.xinghuo.XingHuoChatClient;
import org.springframework.ai.models.yiyan.YiYanChatClient; import org.springframework.ai.models.yiyan.YiYanChatClient;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@ -36,12 +37,17 @@ public class AiChatClientFactory {
// TODO yunai 要不再加一个接口让他们拥有 ChatClientStreamingChatClient 功能 // TODO yunai 要不再加一个接口让他们拥有 ChatClientStreamingChatClient 功能
public StreamingChatClient getStreamingChatClient(AiPlatformEnum platformEnum) { public StreamingChatClient getStreamingChatClient(AiPlatformEnum platformEnum) {
// if (true) {
// return applicationContext.getBean(OllamaChatClient.class);
// }
if (AiPlatformEnum.QIAN_WEN == platformEnum) { if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class); return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) { } else if (AiPlatformEnum.YI_YAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class); return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) { } else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class); return applicationContext.getBean(XingHuoChatClient.class);
} else if (AiPlatformEnum.OLLAMA == platformEnum) {
return applicationContext.getBean(OllamaChatClient.class);
} }
throw new IllegalArgumentException("不支持的 chat client!"); throw new IllegalArgumentException("不支持的 chat client!");
} }

View File

@ -127,7 +127,7 @@ public class AiChatServiceImpl implements AiChatService {
// 1.1 校验对话存在 // 1.1 校验对话存在
AiChatConversationDO conversation = chatConversationService.validateExists(sendReqVO.getConversationId()); AiChatConversationDO conversation = chatConversationService.validateExists(sendReqVO.getConversationId());
if (ObjUtil.notEqual(conversation.getUserId(), userId)) { if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS); throw exception(CHAT_CONVERSATION_NOT_EXISTS); // TODO 芋艿异常情况的对接
} }
// 1.2 校验模型 // 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId()); AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());

View File

@ -24,6 +24,12 @@
<version>1.0.3</version> <version>1.0.3</version>
</dependency> </dependency>
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
<version>1.0.3</version>
</dependency>
<dependency> <dependency>
<groupId>cn.iocoder.boot</groupId> <groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-common</artifactId> <artifactId>yudao-common</artifactId>

View File

@ -1,11 +1,8 @@
package cn.iocoder.yudao.framework.ai.core.enums; package cn.iocoder.yudao.framework.ai.core.enums;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import java.util.List;
// TODO 芋艿这块看看要不要调整下 // TODO 芋艿这块看看要不要调整下
/** /**
* ai 模型平台 * ai 模型平台
@ -17,29 +14,31 @@ import java.util.List;
@AllArgsConstructor @AllArgsConstructor
public enum AiPlatformEnum { public enum AiPlatformEnum {
OPENAI("OpenAI", "OpenAI"),
OLLAMA("dall", "dall"),
YI_YAN("yiyan", "一言"), YI_YAN("yiyan", "一言"),
QIAN_WEN("qianwen", "千问"), QIAN_WEN("qianwen", "千问"),
XING_HUO("xinghuo", "星火"), XING_HUO("xinghuo", "星火"),
OPENAI("OpenAI", "OpenAI"),
OPEN_AI_DALL("dall", "dall"), OPEN_AI_DALL("dall", "dall"),
MIDJOURNEY("midjourney", "midjourney"), MIDJOURNEY("Ollama", "Ollama"),
; ;
private String platform; private final String platform;
private String name; private final String name;
public static List<AiPlatformEnum> CHAT_PLATFORM_LIST = Lists.newArrayList( // public static List<AiPlatformEnum> CHAT_PLATFORM_LIST = Lists.newArrayList(
AiPlatformEnum.YI_YAN, // AiPlatformEnum.YI_YAN,
AiPlatformEnum.QIAN_WEN, // AiPlatformEnum.QIAN_WEN,
AiPlatformEnum.XING_HUO, // AiPlatformEnum.XING_HUO,
AiPlatformEnum.OPENAI // AiPlatformEnum.OPENAI
); // );
//
public static List<AiPlatformEnum> IMAGE_PLATFORM_LIST = Lists.newArrayList( // public static List<AiPlatformEnum> IMAGE_PLATFORM_LIST = Lists.newArrayList(
AiPlatformEnum.OPEN_AI_DALL, // AiPlatformEnum.OPEN_AI_DALL,
AiPlatformEnum.MIDJOURNEY // AiPlatformEnum.MIDJOURNEY
); // );
public static AiPlatformEnum validatePlatform(String platform) { public static AiPlatformEnum validatePlatform(String platform) {
for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) { for (AiPlatformEnum platformEnum : AiPlatformEnum.values()) {

View File

@ -142,6 +142,14 @@ spring:
listener: listener:
missing-topics-fatal: false # 消费监听接口监听的主题不存在时,默认会报错。所以通过设置为 false ,解决报错 missing-topics-fatal: false # 消费监听接口监听的主题不存在时,默认会报错。所以通过设置为 false ,解决报错
--- #################### AI 相关配置 ####################
spring.ai:
ollama:
base-url: http://127.0.0.1:11434
chat:
model: llama3
--- #################### 芋道相关配置 #################### --- #################### 芋道相关配置 ####################
yudao: yudao: