【代码优化】AI:ChatGlm 替换成 ZhiPuAiImage 实现
This commit is contained in:
parent
4311fe4517
commit
73502d565f
|
@ -9,7 +9,6 @@ import cn.hutool.core.util.StrUtil;
|
|||
import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.hutool.http.HttpUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
|
@ -34,6 +33,7 @@ import org.springframework.ai.image.ImageResponse;
|
|||
import org.springframework.ai.openai.OpenAiImageOptions;
|
||||
import org.springframework.ai.qianfan.QianFanImageOptions;
|
||||
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
@ -105,7 +105,9 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
|
||||
|
||||
// 2. 上传到文件服务
|
||||
byte[] fileContent = Base64.decode(response.getResult().getOutput().getB64Json());
|
||||
String b64Json = response.getResult().getOutput().getB64Json();
|
||||
byte[] fileContent = StrUtil.isNotEmpty(b64Json) ? Base64.decode(b64Json)
|
||||
: HttpUtil.downloadBytes(response.getResult().getOutput().getUrl());
|
||||
String filePath = fileApi.createFile(fileContent);
|
||||
|
||||
// 3. 更新数据库
|
||||
|
@ -149,8 +151,8 @@ public class AiImageServiceImpl implements AiImageService {
|
|||
.withModel(draw.getModel()).withN(1)
|
||||
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
|
||||
.build();
|
||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.CHATGLM.getPlatform())) {
|
||||
return ChatGlmImageOptions.builder()
|
||||
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
|
||||
return ZhiPuAiImageOptions.builder()
|
||||
.withModel(draw.getModel())
|
||||
.build();
|
||||
}
|
||||
|
|
|
@ -60,13 +60,6 @@
|
|||
<version>2.14.0</version>
|
||||
</dependency>
|
||||
|
||||
<!-- bigmodel -->
|
||||
<dependency>
|
||||
<groupId>cn.bigmodel.openapi</groupId>
|
||||
<artifactId>oapi-java-sdk</artifactId>
|
||||
<version>release-V4-2.0.2</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Test 测试相关 -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
|
|
|
@ -28,7 +28,6 @@ public enum AiPlatformEnum {
|
|||
STABLE_DIFFUSION("StableDiffusion", "StableDiffusion"), // Stability AI
|
||||
MIDJOURNEY("Midjourney", "Midjourney"), // Midjourney
|
||||
SUNO("Suno", "Suno"), // Suno AI
|
||||
CHATGLM("ChatGlm", "ChatGlm"), // Suno AI
|
||||
|
||||
;
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ import cn.hutool.extra.spring.SpringUtil;
|
|||
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
|
||||
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
|
@ -31,6 +30,7 @@ import org.springframework.ai.autoconfigure.qianfan.QianFanImageProperties;
|
|||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiChatProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiConnectionProperties;
|
||||
import org.springframework.ai.autoconfigure.zhipuai.ZhiPuAiImageProperties;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.image.ImageModel;
|
||||
import org.springframework.ai.model.function.FunctionCallbackContext;
|
||||
|
@ -48,7 +48,9 @@ import org.springframework.ai.qianfan.api.QianFanImageApi;
|
|||
import org.springframework.ai.stabilityai.StabilityAiImageModel;
|
||||
import org.springframework.ai.stabilityai.api.StabilityAiApi;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.web.client.ResponseErrorHandler;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
@ -119,6 +121,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
return SpringUtil.getBean(TongYiImagesModel.class);
|
||||
case YI_YAN:
|
||||
return SpringUtil.getBean(QianFanImageModel.class);
|
||||
case ZHI_PU:
|
||||
return SpringUtil.getBean(ZhiPuAiImageModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiImageModel.class);
|
||||
case STABLE_DIFFUSION:
|
||||
|
@ -136,12 +140,12 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
return buildTongYiImagesModel(apiKey);
|
||||
case YI_YAN:
|
||||
return buildQianFanImageModel(apiKey);
|
||||
case ZHI_PU:
|
||||
return buildZhiPuAiImageModel(apiKey, url);
|
||||
case OPENAI:
|
||||
return buildOpenAiImageModel(apiKey, url);
|
||||
case STABLE_DIFFUSION:
|
||||
return buildStabilityAiImageModel(apiKey, url);
|
||||
case CHATGLM:
|
||||
return buildChatGlmModel(apiKey);
|
||||
default:
|
||||
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
|
||||
}
|
||||
|
@ -225,7 +229,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
|
||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiChatModel(
|
||||
* ZhiPuAiConnectionProperties, ZhiPuAiChatProperties, RestClient.Builder, List, FunctionCallbackContext, RetryTemplate, ResponseErrorHandler)}
|
||||
*/
|
||||
private ZhiPuAiChatModel buildZhiPuChatModel(String apiKey, String url) {
|
||||
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
||||
|
@ -233,6 +238,16 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
return new ZhiPuAiChatModel(zhiPuAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link ZhiPuAiAutoConfiguration#zhiPuAiImageModel(
|
||||
* ZhiPuAiConnectionProperties, ZhiPuAiImageProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
|
||||
*/
|
||||
private ZhiPuAiImageModel buildZhiPuAiImageModel(String apiKey, String url) {
|
||||
url = StrUtil.blankToDefault(url, ZhiPuAiConnectionProperties.DEFAULT_BASE_URL);
|
||||
ZhiPuAiImageApi zhiPuAiApi = new ZhiPuAiImageApi(url, apiKey, RestClient.builder());
|
||||
return new ZhiPuAiImageModel(zhiPuAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link YudaoAiAutoConfiguration#xingHuoChatClient(YudaoAiProperties)}
|
||||
*/
|
||||
|
@ -276,7 +291,4 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
|||
return new StabilityAiImageModel(stabilityAiApi);
|
||||
}
|
||||
|
||||
private ChatGlmImageModel buildChatGlmModel(String apiKey) {
|
||||
return new ChatGlmImageModel(apiKey);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.model.chatglm;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.chatglm.api.ChatGlmResponseMetadata;
|
||||
import com.zhipu.oapi.ClientV4;
|
||||
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
|
||||
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
||||
import org.springframework.ai.image.*;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.net.URL;
|
||||
import java.util.Base64;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ChatGlmImageModel implements ImageModel {
|
||||
|
||||
private ClientV4 client;
|
||||
|
||||
public ChatGlmImageModel(String apiSecretKey) {
|
||||
client = new ClientV4.Builder(apiSecretKey).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageResponse call(ImagePrompt request) {
|
||||
CreateImageRequest imageRequest = CreateImageRequest.builder()
|
||||
.model(request.getOptions().getModel())
|
||||
.prompt(request.getInstructions().get(0).getText())
|
||||
.build();
|
||||
return convert(client.createImage(imageRequest));
|
||||
}
|
||||
|
||||
private ImageResponse convert(ImageApiResponse result) {
|
||||
return new ImageResponse(
|
||||
result.getData().getData().stream().map(item -> {
|
||||
try {
|
||||
String url = item.getUrl();
|
||||
String base64Image = convertImageToBase64(url);
|
||||
Image image = new Image(url, base64Image);
|
||||
return new ImageGeneration(image);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}).collect(Collectors.toList()),
|
||||
new ChatGlmResponseMetadata(result)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Convert image to base64.
|
||||
* @param imageUrl the image url.
|
||||
* @return the base64 image.
|
||||
* @throws Exception the exception.
|
||||
*/
|
||||
public String convertImageToBase64(String imageUrl) throws Exception {
|
||||
|
||||
var url = new URL(imageUrl);
|
||||
var inputStream = url.openStream();
|
||||
var outputStream = new ByteArrayOutputStream();
|
||||
var buffer = new byte[4096];
|
||||
int bytesRead;
|
||||
|
||||
while ((bytesRead = inputStream.read(buffer)) != -1) {
|
||||
outputStream.write(buffer, 0, bytesRead);
|
||||
}
|
||||
|
||||
var imageBytes = outputStream.toByteArray();
|
||||
|
||||
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
|
||||
|
||||
inputStream.close();
|
||||
outputStream.close();
|
||||
|
||||
return base64Image;
|
||||
}
|
||||
}
|
|
@ -1,115 +0,0 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.model.chatglm;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Setter;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
|
||||
/**
|
||||
* chatglm
|
||||
* api地址:https://open.bigmodel.cn/dev/api#cogview
|
||||
*/
|
||||
@Setter
|
||||
public class ChatGlmImageOptions implements ImageOptions {
|
||||
|
||||
@JsonProperty("n")
|
||||
private Integer n;
|
||||
|
||||
@JsonProperty("model")
|
||||
private String model = "cogview-3";
|
||||
|
||||
@JsonProperty("size_width")
|
||||
private Integer width;
|
||||
|
||||
@JsonProperty("size_height")
|
||||
private Integer height;
|
||||
|
||||
@JsonProperty("size")
|
||||
private String size;
|
||||
|
||||
@JsonProperty("style")
|
||||
private String style;
|
||||
|
||||
@JsonProperty("user_id")
|
||||
private String user;
|
||||
|
||||
@JsonProperty("responseFormat")
|
||||
private String responseFormat;
|
||||
|
||||
// ==== build
|
||||
|
||||
|
||||
public static ChatGlmImageOptions.Builder builder() {
|
||||
return new ChatGlmImageOptions.Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private final ChatGlmImageOptions options;
|
||||
|
||||
private Builder() {
|
||||
this.options = new ChatGlmImageOptions();
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withN(Integer n) {
|
||||
options.setN(n);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withModel(String model) {
|
||||
options.setModel(model);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withWidth(Integer width) {
|
||||
options.setWidth(width);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withHeight(Integer height) {
|
||||
options.setHeight(height);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withStyle(String style) {
|
||||
options.setStyle(style);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withUser(String user) {
|
||||
options.setUser(user);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions build() {
|
||||
return options;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ==== get
|
||||
|
||||
@Override
|
||||
public Integer getN() {
|
||||
return n;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getWidth() {
|
||||
return width;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getHeight() {
|
||||
return height;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getResponseFormat() {
|
||||
return responseFormat;
|
||||
}
|
||||
}
|
|
@ -1,24 +0,0 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.model.chatglm.api;
|
||||
|
||||
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
||||
import org.springframework.ai.image.ImageResponseMetadata;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
public class ChatGlmResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
|
||||
|
||||
private Long created;
|
||||
|
||||
public ChatGlmResponseMetadata(ImageApiResponse result) {
|
||||
created = result.getData().getCreated();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getCreated() {
|
||||
return created;
|
||||
}
|
||||
|
||||
public void setCreated(Long created) {
|
||||
this.created = created;
|
||||
}
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.zhipu.oapi.ClientV4;
|
||||
import com.zhipu.oapi.core.httpclient.ApacheHttpClientTransport;
|
||||
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
|
||||
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptionsBuilder;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.qianfan.QianFanImageModel;
|
||||
import org.springframework.ai.qianfan.QianFanImageOptions;
|
||||
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
||||
|
||||
/**
|
||||
* 百度千帆 image
|
||||
*/
|
||||
public class ChatGlmImageModelTests {
|
||||
|
||||
@Test
|
||||
public void callTest() {
|
||||
ChatGlmImageModel model = new ChatGlmImageModel("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
|
||||
ImageResponse call = model.call(new ImagePrompt("万里长城", ChatGlmImageOptions.builder().build()));
|
||||
System.err.println(call.getResult().getOutput().getUrl());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createImageTest() {
|
||||
ClientV4 client = new ClientV4.Builder("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy").build();
|
||||
CreateImageRequest createImageRequest = new CreateImageRequest();
|
||||
createImageRequest.setModel("cogview-3");
|
||||
createImageRequest.setPrompt("长城!");
|
||||
ImageApiResponse image = client.createImage(createImageRequest);
|
||||
System.err.println(JSON.toJSONString(image));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageModel;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiImageOptions;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
||||
|
||||
/**
|
||||
* {@link ZhiPuAiImageModel} 集成测试
|
||||
*/
|
||||
public class ZhiPuAiImageModelTests {
|
||||
|
||||
private final ZhiPuAiImageApi imageApi = new ZhiPuAiImageApi(
|
||||
"78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
|
||||
private final ZhiPuAiImageModel imageModel = new ZhiPuAiImageModel(imageApi);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
ZhiPuAiImageOptions imageOptions = ZhiPuAiImageOptions.builder()
|
||||
.withModel(ZhiPuAiImageApi.ImageModel.CogView_3.getValue())
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue