【修改todo】增加 sd 各种参数

This commit is contained in:
cherishsince 2024-06-30 16:10:56 +08:00
parent 246c233253
commit f300b8a1ae
3 changed files with 114 additions and 106 deletions

View File

@ -37,6 +37,7 @@ public interface ErrorCodeConstants {
ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!"); ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}"); ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}"); ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
ErrorCode IMAGE_FAIL = new ErrorCode(1_022_005_002, "图片绘画失败! {}");
// ========== API 音乐 1-040-006-000 ========== // ========== API 音乐 1-040-006-000 ==========
ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!"); ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");

View File

@ -123,9 +123,16 @@ public class AiImageServiceImpl implements AiImageService {
.withResponseFormat("b64_json") .withResponseFormat("b64_json")
.build(); .build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) { } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage // https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
return StabilityAiImageOptions.builder().withModel(draw.getModel()) return StabilityAiImageOptions.builder().withModel(draw.getModel())
.withHeight(draw.getHeight()).withWidth(draw.getWidth()) // TODO @范各种参数的接入 .withHeight(draw.getHeight()).withWidth(draw.getWidth())
.withSeed(Long.valueOf(draw.getOptions().get("seed")))
.withCfgScale(Float.valueOf(draw.getOptions().get("scale")))
.withSteps(Integer.valueOf(draw.getOptions().get("steps")))
.withSampler(String.valueOf(draw.getOptions().get("sampler")))
.withStylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
.withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
.build(); .build();
} }
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform()); throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());

View File

@ -1,105 +1,105 @@
package cn.iocoder.yudao.framework.ai.chat; //package cn.iocoder.yudao.framework.ai.chat;
//
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient; //import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal; //import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions; //import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi; //import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import com.alibaba.dashscope.aigc.generation.GenerationResult; //import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam; //import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message; //import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.MessageManager; //import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.common.Role; //import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.InputRequiredException; //import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException; //import com.alibaba.dashscope.exception.NoApiKeyException;
import org.junit.Before; //import org.junit.Before;
import org.junit.Test; //import org.junit.Test;
import org.springframework.ai.chat.messages.SystemMessage; //import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage; //import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse; //import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; //import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux; //import reactor.core.publisher.Flux;
//
import java.util.ArrayList; //import java.util.ArrayList;
import java.util.List; //import java.util.List;
import java.util.Scanner; //import java.util.Scanner;
import java.util.function.Consumer; //import java.util.function.Consumer;
//
// TODO 芋艿整理单测 //// TODO 芋艿整理单测
/** ///**
* author: fansili // * author: fansili
* time: 2024/3/13 21:37 // * time: 2024/3/13 21:37
*/ // */
public class QianWenChatClientTests { //public class QianWenChatClientTests {
//
private QianWenChatClient qianWenChatClient; // private QianWenChatClient qianWenChatClient;
//
@Before // @Before
public void setup() { // public void setup() {
QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT); // QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT);
QianWenOptions qianWenOptions = new QianWenOptions(); // QianWenOptions qianWenOptions = new QianWenOptions();
qianWenOptions.setTopP(0.8F); // qianWenOptions.setTopP(0.8F);
// qianWenOptions.setTopK(3); TODO 芋艿临时处理 //// qianWenOptions.setTopK(3); TODO 芋艿临时处理
// qianWenOptions.setTemperature(0.6F); TODO 芋艿临时处理 //// qianWenOptions.setTemperature(0.6F); TODO 芋艿临时处理
qianWenChatClient = new QianWenChatClient( // qianWenChatClient = new QianWenChatClient(
qianWenApi, // qianWenApi,
qianWenOptions // qianWenOptions
); // );
} // }
//
@Test // @Test
public void callTest() { // public void callTest() {
List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>(); // List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。")); // messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。"));
messages.add(new UserMessage("长沙怎么样?")); // messages.add(new UserMessage("长沙怎么样?"));
//
ChatResponse call = qianWenChatClient.call(new Prompt(messages)); // ChatResponse call = qianWenChatClient.call(new Prompt(messages));
System.err.println(call.getResult()); // System.err.println(call.getResult());
} // }
//
@Test // @Test
public void streamTest() { // public void streamTest() {
List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>(); // List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。")); // messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("长沙怎么样?")); // messages.add(new UserMessage("长沙怎么样?"));
//
Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages)); // Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages));
flux.subscribe(new Consumer<ChatResponse>() { // flux.subscribe(new Consumer<ChatResponse>() {
@Override // @Override
public void accept(ChatResponse chatResponse) { // public void accept(ChatResponse chatResponse) {
System.err.print(chatResponse.getResult().getOutput().getContent()); // System.err.print(chatResponse.getResult().getOutput().getContent());
} // }
}); // });
//
// 阻止退出 // // 阻止退出
Scanner scanner = new Scanner(System.in); // Scanner scanner = new Scanner(System.in);
scanner.nextLine(); // scanner.nextLine();
} // }
//
@Test // @Test
public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException { // public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException {
com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation(); // com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation();
MessageManager msgManager = new MessageManager(10); // MessageManager msgManager = new MessageManager(10);
Message systemMsg = // Message systemMsg =
Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build(); // Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build(); // Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build();
msgManager.add(systemMsg); // msgManager.add(systemMsg);
msgManager.add(userMsg); // msgManager.add(userMsg);
QwenParam param = // QwenParam param =
QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get()) // QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get())
.resultFormat(QwenParam.ResultFormat.MESSAGE) // .resultFormat(QwenParam.ResultFormat.MESSAGE)
.topP(0.8) // .topP(0.8)
/* set the random seed, optional, default to 1234 if not set */ // /* set the random seed, optional, default to 1234 if not set */
.seed(100) // .seed(100)
.apiKey("sk-Zsd81gZYg7") // .apiKey("sk-Zsd81gZYg7")
.build(); // .build();
GenerationResult result = gen.call(param); // GenerationResult result = gen.call(param);
System.out.println(result); // System.out.println(result);
System.out.println("-----------------"); // System.out.println("-----------------");
System.out.println("-----------------"); // System.out.println("-----------------");
msgManager.add(result); // msgManager.add(result);
param.setPrompt("能否缩短一些,只讲三点"); // param.setPrompt("能否缩短一些,只讲三点");
param.setMessages(msgManager.get()); // param.setMessages(msgManager.get());
result = gen.call(param); // result = gen.call(param);
System.out.println(result); // System.out.println(result);
} // }
} //}