【增加】chatglm 实现 spring ai 标准

This commit is contained in:
cherishsince 2024-07-12 15:03:34 +08:00
parent db7315b8cd
commit d865cc293b
5 changed files with 261 additions and 0 deletions

View File

@ -60,6 +60,13 @@
<version>2.14.0</version>
</dependency>
<!-- bigmodel -->
<dependency>
<groupId>cn.bigmodel.openapi</groupId>
<artifactId>oapi-java-sdk</artifactId>
<version>release-V4-2.0.2</version>
</dependency>
<!-- Test 测试相关 -->
<dependency>
<groupId>org.springframework.boot</groupId>

View File

@ -0,0 +1,75 @@
package cn.iocoder.yudao.framework.ai.core.model.chatglm;
import cn.iocoder.yudao.framework.ai.core.model.chatglm.api.ChatGlmResponseMetadata;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import org.springframework.ai.image.*;
import java.io.ByteArrayOutputStream;
import java.net.URL;
import java.util.Base64;
import java.util.stream.Collectors;
public class ChatGlmImageModel implements ImageModel {
private ClientV4 client;
public ChatGlmImageModel(String apiSecretKey) {
client = new ClientV4.Builder(apiSecretKey).build();
}
@Override
public ImageResponse call(ImagePrompt request) {
CreateImageRequest imageRequest = CreateImageRequest.builder()
.model(request.getOptions().getModel())
.prompt(request.getInstructions().get(0).getText())
.build();
return convert(client.createImage(imageRequest));
}
private ImageResponse convert(ImageApiResponse result) {
return new ImageResponse(
result.getData().getData().stream().map(item -> {
try {
String url = item.getUrl();
String base64Image = convertImageToBase64(url);
Image image = new Image(url, base64Image);
return new ImageGeneration(image);
} catch (Exception e) {
throw new RuntimeException(e);
}
}).collect(Collectors.toList()),
new ChatGlmResponseMetadata(result)
);
}
/**
* Convert image to base64.
* @param imageUrl the image url.
* @return the base64 image.
* @throws Exception the exception.
*/
public String convertImageToBase64(String imageUrl) throws Exception {
var url = new URL(imageUrl);
var inputStream = url.openStream();
var outputStream = new ByteArrayOutputStream();
var buffer = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, bytesRead);
}
var imageBytes = outputStream.toByteArray();
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
inputStream.close();
outputStream.close();
return base64Image;
}
}

View File

@ -0,0 +1,115 @@
package cn.iocoder.yudao.framework.ai.core.model.chatglm;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Setter;
import org.springframework.ai.image.ImageOptions;
/**
* chatglm
* api地址https://open.bigmodel.cn/dev/api#cogview
*/
@Setter
public class ChatGlmImageOptions implements ImageOptions {
@JsonProperty("n")
private Integer n;
@JsonProperty("model")
private String model = "cogview-3";
@JsonProperty("size_width")
private Integer width;
@JsonProperty("size_height")
private Integer height;
@JsonProperty("size")
private String size;
@JsonProperty("style")
private String style;
@JsonProperty("user_id")
private String user;
@JsonProperty("responseFormat")
private String responseFormat;
// ==== build
public static ChatGlmImageOptions.Builder builder() {
return new ChatGlmImageOptions.Builder();
}
public static class Builder {
private final ChatGlmImageOptions options;
private Builder() {
this.options = new ChatGlmImageOptions();
}
public ChatGlmImageOptions.Builder withN(Integer n) {
options.setN(n);
return this;
}
public ChatGlmImageOptions.Builder withModel(String model) {
options.setModel(model);
return this;
}
public ChatGlmImageOptions.Builder withWidth(Integer width) {
options.setWidth(width);
return this;
}
public ChatGlmImageOptions.Builder withHeight(Integer height) {
options.setHeight(height);
return this;
}
public ChatGlmImageOptions.Builder withStyle(String style) {
options.setStyle(style);
return this;
}
public ChatGlmImageOptions.Builder withUser(String user) {
options.setUser(user);
return this;
}
public ChatGlmImageOptions build() {
return options;
}
}
// ==== get
@Override
public Integer getN() {
return n;
}
@Override
public String getModel() {
return model;
}
@Override
public Integer getWidth() {
return width;
}
@Override
public Integer getHeight() {
return height;
}
@Override
public String getResponseFormat() {
return responseFormat;
}
}

View File

@ -0,0 +1,24 @@
package cn.iocoder.yudao.framework.ai.core.model.chatglm.api;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import org.springframework.ai.image.ImageResponseMetadata;
import java.util.HashMap;
public class ChatGlmResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
private Long created;
public ChatGlmResponseMetadata(ImageApiResponse result) {
created = result.getData().getCreated();
}
@Override
public Long getCreated() {
return created;
}
public void setCreated(Long created) {
this.created = created;
}
}

View File

@ -0,0 +1,40 @@
package cn.iocoder.yudao.framework.ai.image;
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import com.alibaba.fastjson.JSON;
import com.zhipu.oapi.ClientV4;
import com.zhipu.oapi.core.httpclient.ApacheHttpClientTransport;
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
import org.junit.jupiter.api.Test;
import org.springframework.ai.image.ImageOptionsBuilder;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.springframework.ai.qianfan.QianFanImageModel;
import org.springframework.ai.qianfan.QianFanImageOptions;
import org.springframework.ai.qianfan.api.QianFanImageApi;
/**
* 百度千帆 image
*/
public class ChatGlmImageModelTests {
@Test
public void callTest() {
ChatGlmImageModel model = new ChatGlmImageModel("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
ImageResponse call = model.call(new ImagePrompt("万里长城", ChatGlmImageOptions.builder().build()));
System.err.println(call.getResult().getOutput().getUrl());
}
@Test
public void createImageTest() {
ClientV4 client = new ClientV4.Builder("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy").build();
CreateImageRequest createImageRequest = new CreateImageRequest();
createImageRequest.setModel("cogview-3");
createImageRequest.setPrompt("长城!");
ImageApiResponse image = client.createImage(createImageRequest);
System.err.println(JSON.toJSONString(image));
}
}