diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
index f561dacb52..b95982ed3e 100644
--- a/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/factory/AiModelFactoryImpl.java
@@ -55,8 +55,6 @@ public class AiModelFactoryImpl implements AiModelFactory {
         return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
             //noinspection EnhancedSwitchMigration
             switch (platform) {
-                case OLLAMA:
-                    return buildOllamaChatClient(url);
                 case YI_YAN:
                     return buildYiYanChatClient(apiKey);
                 case XING_HUO:
@@ -67,6 +65,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
                     return buildDeepSeekChatClient(apiKey);
                 case OPENAI:
                     return buildOpenAiChatModel(apiKey, url);
+                case OLLAMA:
+                    return buildOllamaChatModel(url);
                 default:
                     throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
             }
@@ -163,7 +163,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
     /**
      * 可参考 {@link OllamaAutoConfiguration}
      */
-    private static OllamaChatModel buildOllamaChatClient(String url) {
+    private static OllamaChatModel buildOllamaChatModel(String url) {
         OllamaApi ollamaApi = new OllamaApi(url);
         return new OllamaChatModel(ollamaApi);
     }
diff --git a/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java
new file mode 100644
index 0000000000..c6b99f287b
--- /dev/null
+++ b/yudao-module-ai/yudao-spring-boot-starter-ai/src/test/java/cn/iocoder/yudao/framework/ai/chat/LlamaChatModelTests.java
@@ -0,0 +1,63 @@
+package cn.iocoder.yudao.framework.ai.chat;
+
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import org.springframework.ai.chat.model.ChatResponse;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.ollama.OllamaChatModel;
+import org.springframework.ai.ollama.api.OllamaApi;
+import org.springframework.ai.ollama.api.OllamaModel;
+import org.springframework.ai.ollama.api.OllamaOptions;
+import reactor.core.publisher.Flux;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * {@link OllamaChatModel} 集成测试
+ *
+ * @author 芋道源码
+ */
+public class LlamaChatModelTests {
+
+    private final OllamaApi ollamaApi = new OllamaApi(
+            "http://127.0.0.1:11434");
+    private final OllamaChatModel chatModel = new OllamaChatModel(ollamaApi,
+            OllamaOptions.create().withModel(OllamaModel.LLAMA3.getModelName()));
+
+    @Test
+    @Disabled
+    public void testCall() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        ChatResponse response = chatModel.call(new Prompt(messages));
+        // 打印结果
+        System.out.println(response);
+        System.out.println(response.getResult().getOutput());
+    }
+
+    @Test
+    @Disabled
+    public void testStream() {
+        // 准备参数
+        List<Message> messages = new ArrayList<>();
+        messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
+        messages.add(new UserMessage("1 + 1 = ?"));
+
+        // 调用
+        Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
+        // 打印结果
+        flux.doOnNext(response -> {
+//            System.out.println(response);
+            System.out.println(response.getResult().getOutput());
+        }).then().block();
+    }
+
+}