!1013 新增文心一言、智谱 AI 的绘图能力

Merge pull request !1013 from 芋道源码/master-jdk21-ai
This commit is contained in:
芋道源码 2024-07-13 02:47:49 +00:00 committed by Gitee
commit 3aad33c3cd
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 62 additions and 2 deletions

View File

@ -33,6 +33,7 @@ import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.openai.OpenAiImageOptions;
import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@ -104,7 +105,9 @@ public class AiImageServiceImpl implements AiImageService {
ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
// 2. 上传到文件服务
byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json());
String b64Json = response.getResult().getOutput().getB64Json();
byte[] fileContent = StrUtil.isNotEmpty(b64Json) ? Base64.decode(b64Json)
: HttpUtil.downloadBytes(response.getResult().getOutput().getUrl());
String filePath = fileApi.createFile(fileContent);
// 3. 更新数据库
@ -148,6 +151,10 @@ public class AiImageServiceImpl implements AiImageService {
.withModel(draw.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
return ZhiPuAiImageOptions.builder()
.withModel(draw.getModel())
.build();
}
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
}

View File

@ -30,6 +30,7 @@ import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.model.function.FunctionCallbackContext;
@ -47,7 +48,9 @@ import org.springframework.ai.qianfan.api.QianFanImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
@ -118,6 +121,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return SpringUtil.getBean(TongYiImagesModel.class);
case YI_YAN:
return SpringUtil.getBean(QianFanImageModel.class);
case ZHI_PU:
return SpringUtil.getBean(ZhiPuAiImageModel.class);
case OPENAI:
return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION:
@ -135,6 +140,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
return buildTongYiImagesModel(apiKey);
case YI_YAN:
return buildQianFanImageModel(apiKey);
case ZHI_PU:
return buildZhiPuAiImageModel(apiKey, url);
case OPENAI:
return buildOpenAiImageModel(apiKey, url);
case STABLE_DIFFUSION:
@ -222,7 +229,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
* ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
*/
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
@ -230,6 +238,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
return new ZhiPuAiChatModel(zhiPuAiApi);
}
/**
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
* ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
ZhiPuAiImageApi zhiPuAiApi = new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
return new ZhiPuAiImageModel(zhiPuAiApi);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
*/

View File

@ -0,0 +1,35 @@
package cn.iocoder.yudao.framework.ai.image;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
/**
* {@link ZhiPuAiImageModel} 集成测试
*/
public class ZhiPuAiImageModelTests {
private final ZhiPuAiImageApi imageApi = new ZhiPuAiImageApi(
"78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(imageApi);
@Test
@Disabled
public void testCall() {
// 准备参数
ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
.withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
.build();
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
// 方法调用
ImageResponse response = imageModel.call(prompt);
// 打印结果
System.out.println(response);
}
}