【增加】chatglm 实现 spring ai 标准
This commit is contained in:
parent
db7315b8cd
commit
d865cc293b
|
@ -60,6 +60,13 @@
|
|||
<version>2.14.0</version>
|
||||
</dependency>
|
||||
|
||||
<!-- bigmodel -->
|
||||
<dependency>
|
||||
<groupId>cn.bigmodel.openapi</groupId>
|
||||
<artifactId>oapi-java-sdk</artifactId>
|
||||
<version>release-V4-2.0.2</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Test 测试相关 -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.model.chatglm;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.chatglm.api.ChatGlmResponseMetadata;
|
||||
import com.zhipu.oapi.ClientV4;
|
||||
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
|
||||
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
||||
import org.springframework.ai.image.*;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.net.URL;
|
||||
import java.util.Base64;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ChatGlmImageModel implements ImageModel {
|
||||
|
||||
private ClientV4 client;
|
||||
|
||||
public ChatGlmImageModel(String apiSecretKey) {
|
||||
client = new ClientV4.Builder(apiSecretKey).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageResponse call(ImagePrompt request) {
|
||||
CreateImageRequest imageRequest = CreateImageRequest.builder()
|
||||
.model(request.getOptions().getModel())
|
||||
.prompt(request.getInstructions().get(0).getText())
|
||||
.build();
|
||||
return convert(client.createImage(imageRequest));
|
||||
}
|
||||
|
||||
private ImageResponse convert(ImageApiResponse result) {
|
||||
return new ImageResponse(
|
||||
result.getData().getData().stream().map(item -> {
|
||||
try {
|
||||
String url = item.getUrl();
|
||||
String base64Image = convertImageToBase64(url);
|
||||
Image image = new Image(url, base64Image);
|
||||
return new ImageGeneration(image);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}).collect(Collectors.toList()),
|
||||
new ChatGlmResponseMetadata(result)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Convert image to base64.
|
||||
* @param imageUrl the image url.
|
||||
* @return the base64 image.
|
||||
* @throws Exception the exception.
|
||||
*/
|
||||
public String convertImageToBase64(String imageUrl) throws Exception {
|
||||
|
||||
var url = new URL(imageUrl);
|
||||
var inputStream = url.openStream();
|
||||
var outputStream = new ByteArrayOutputStream();
|
||||
var buffer = new byte[4096];
|
||||
int bytesRead;
|
||||
|
||||
while ((bytesRead = inputStream.read(buffer)) != -1) {
|
||||
outputStream.write(buffer, 0, bytesRead);
|
||||
}
|
||||
|
||||
var imageBytes = outputStream.toByteArray();
|
||||
|
||||
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
|
||||
|
||||
inputStream.close();
|
||||
outputStream.close();
|
||||
|
||||
return base64Image;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,115 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.model.chatglm;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Setter;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
|
||||
/**
|
||||
* chatglm
|
||||
* api地址:https://open.bigmodel.cn/dev/api#cogview
|
||||
*/
|
||||
@Setter
|
||||
public class ChatGlmImageOptions implements ImageOptions {
|
||||
|
||||
@JsonProperty("n")
|
||||
private Integer n;
|
||||
|
||||
@JsonProperty("model")
|
||||
private String model = "cogview-3";
|
||||
|
||||
@JsonProperty("size_width")
|
||||
private Integer width;
|
||||
|
||||
@JsonProperty("size_height")
|
||||
private Integer height;
|
||||
|
||||
@JsonProperty("size")
|
||||
private String size;
|
||||
|
||||
@JsonProperty("style")
|
||||
private String style;
|
||||
|
||||
@JsonProperty("user_id")
|
||||
private String user;
|
||||
|
||||
@JsonProperty("responseFormat")
|
||||
private String responseFormat;
|
||||
|
||||
// ==== build
|
||||
|
||||
|
||||
public static ChatGlmImageOptions.Builder builder() {
|
||||
return new ChatGlmImageOptions.Builder();
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private final ChatGlmImageOptions options;
|
||||
|
||||
private Builder() {
|
||||
this.options = new ChatGlmImageOptions();
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withN(Integer n) {
|
||||
options.setN(n);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withModel(String model) {
|
||||
options.setModel(model);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withWidth(Integer width) {
|
||||
options.setWidth(width);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withHeight(Integer height) {
|
||||
options.setHeight(height);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withStyle(String style) {
|
||||
options.setStyle(style);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions.Builder withUser(String user) {
|
||||
options.setUser(user);
|
||||
return this;
|
||||
}
|
||||
|
||||
public ChatGlmImageOptions build() {
|
||||
return options;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ==== get
|
||||
|
||||
@Override
|
||||
public Integer getN() {
|
||||
return n;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getWidth() {
|
||||
return width;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getHeight() {
|
||||
return height;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getResponseFormat() {
|
||||
return responseFormat;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,24 @@
|
|||
package cn.iocoder.yudao.framework.ai.core.model.chatglm.api;
|
||||
|
||||
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
||||
import org.springframework.ai.image.ImageResponseMetadata;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
public class ChatGlmResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
|
||||
|
||||
private Long created;
|
||||
|
||||
public ChatGlmResponseMetadata(ImageApiResponse result) {
|
||||
created = result.getData().getCreated();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getCreated() {
|
||||
return created;
|
||||
}
|
||||
|
||||
public void setCreated(Long created) {
|
||||
this.created = created;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.chatglm.ChatGlmImageOptions;
|
||||
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.zhipu.oapi.ClientV4;
|
||||
import com.zhipu.oapi.core.httpclient.ApacheHttpClientTransport;
|
||||
import com.zhipu.oapi.service.v4.image.CreateImageRequest;
|
||||
import com.zhipu.oapi.service.v4.image.ImageApiResponse;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImageOptionsBuilder;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
import org.springframework.ai.qianfan.QianFanImageModel;
|
||||
import org.springframework.ai.qianfan.QianFanImageOptions;
|
||||
import org.springframework.ai.qianfan.api.QianFanImageApi;
|
||||
|
||||
/**
|
||||
* 百度千帆 image
|
||||
*/
|
||||
public class ChatGlmImageModelTests {
|
||||
|
||||
@Test
|
||||
public void callTest() {
|
||||
ChatGlmImageModel model = new ChatGlmImageModel("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy");
|
||||
ImageResponse call = model.call(new ImagePrompt("万里长城", ChatGlmImageOptions.builder().build()));
|
||||
System.err.println(call.getResult().getOutput().getUrl());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void createImageTest() {
|
||||
ClientV4 client = new ClientV4.Builder("78d3228c1d9e5e342a3e1ab349e2dd7b.VXLoq5vrwK2ofboy").build();
|
||||
CreateImageRequest createImageRequest = new CreateImageRequest();
|
||||
createImageRequest.setModel("cogview-3");
|
||||
createImageRequest.setPrompt("长城!");
|
||||
ImageApiResponse image = client.createImage(createImageRequest);
|
||||
System.err.println(JSON.toJSONString(image));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue