【优化】处理百度 system 角色定制失效问题。

This commit is contained in:
cherishsince 2024-04-27 18:29:58 +08:00
parent c811f3a4c2
commit 10a94c3ef2
3 changed files with 77 additions and 26 deletions

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.framework.ai.chatyiyan;
import cn.hutool.core.bean.BeanUtil; import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.chat.*; import cn.iocoder.yudao.framework.ai.chat.*;
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
import cn.iocoder.yudao.framework.ai.chat.messages.MessageType;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions; import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi; import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanApi;
@ -9,6 +11,7 @@ import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletion;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest; import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException; import cn.iocoder.yudao.framework.ai.chatyiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback; import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext; import org.springframework.retry.RetryContext;
@ -18,10 +21,11 @@ import reactor.core.publisher.Flux;
import java.time.Duration; import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
/** /**
* 文心一言 * 文心一言
* * <p>
* author: fansili * author: fansili
* time: 2024/3/8 19:11 * time: 2024/3/8 19:11
*/ */
@ -52,7 +56,9 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
public <T extends Object, E extends Throwable> void onError(RetryContext context, public <T extends Object, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) { RetryCallback<T, E> callback, Throwable throwable) {
log.warn("重试异常:" + context.getRetryCount(), throwable); log.warn("重试异常:" + context.getRetryCount(), throwable);
}; }
;
}) })
.build(); .build();
@ -92,6 +98,42 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
} }
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) { private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 获取配置
YiYanOptions useOptions = getYiYanOptions(prompt);
// 创建 request
// tip: 百度的 system 不在 message 里面
// tip百度的 message 只有 user assistant
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
// 获取 user assistant
List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream()
// 过滤 system
.filter(msg -> MessageType.SYSTEM != msg.getMessageType())
.map(msg -> new YiYanChatCompletionRequest.Message()
.setRole(msg.getMessageType().getValue())
.setContent(msg.getContent())
).toList();
// 获取 system
String systemPrompt = prompt.getInstructions().stream()
.filter(msg -> MessageType.SYSTEM == msg.getMessageType())
.map(Message::getContent)
.collect(Collectors.joining());
YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
// 复制 qianWenOptions 属性取 request这里 options 属性和 request 基本保持一致
// top: 由于遵循 spring-ai规范支持在构建client的时候传入默认的 chatOptions
BeanUtil.copyProperties(useOptions, request);
request.setTop_p(useOptions.getTopP());
request.setMax_output_tokens(useOptions.getMaxOutputTokens());
request.setTemperature(useOptions.getTemperature());
request.setSystem(systemPrompt);
// 设置 stream
request.setStream(stream);
return request;
}
private @NotNull YiYanOptions getYiYanOptions(Prompt prompt) {
// 两个都为null 则没有配置文件 // 两个都为null 则没有配置文件
if (yiYanOptions == null && prompt.getOptions() == null) { if (yiYanOptions == null && prompt.getOptions() == null) {
throw new ChatException("ChatOptions 未配置参数!"); throw new ChatException("ChatOptions 未配置参数!");
@ -106,19 +148,7 @@ public class YiYanChatClient implements ChatClient, StreamingChatClient {
throw new ChatException("Prompt 传入的不是 YiYanOptions!"); throw new ChatException("Prompt 传入的不是 YiYanOptions!");
} }
// 转换 YiYanOptions // 转换 YiYanOptions
YiYanOptions qianWenOptions = (YiYanOptions) options; YiYanOptions useOptions = (YiYanOptions) options;
// 创建 request return useOptions;
List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream().map(
msg -> new YiYanChatCompletionRequest.Message()
.setRole(msg.getMessageType().getValue())
.setContent(msg.getContent())
).toList();
YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
// 复制 qianWenOptions 属性取 request这里 options 属性和 request 基本保持一致
// top: 由于遵循 spring-ai规范支持在构建client的时候传入默认的 chatOptions
BeanUtil.copyProperties(qianWenOptions, request);
// 设置 stream
request.setStream(stream);
return request;
} }
} }

View File

@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.chatyiyan;
import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions; import cn.iocoder.yudao.framework.ai.chat.prompt.ChatOptions;
import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest; import cn.iocoder.yudao.framework.ai.chatyiyan.api.YiYanChatCompletionRequest;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
@ -40,7 +39,7 @@ public class YiYanOptions implements ChatOptions {
* 2默认0.8取值范围 [0, 1.0] * 2默认0.8取值范围 [0, 1.0]
* 必填 * 必填
*/ */
private Float top_p; private Float topP;
/** /**
* 通过对已生成的token增加惩罚减少重复生成的现象说明 * 通过对已生成的token增加惩罚减少重复生成的现象说明
* 1值越大表示惩罚越大 * 1值越大表示惩罚越大
@ -84,7 +83,7 @@ public class YiYanOptions implements ChatOptions {
* 指定模型最大输出token数范围[2, 2048] * 指定模型最大输出token数范围[2, 2048]
* 必填 * 必填
*/ */
private Integer max_output_tokens; private Integer maxOutputTokens;
/** /**
* 指定响应内容的格式说明 * 指定响应内容的格式说明
* 1可选值 * 1可选值
@ -122,12 +121,12 @@ public class YiYanOptions implements ChatOptions {
@Override @Override
public Float getTopP() { public Float getTopP() {
return top_p; return topP;
} }
@Override @Override
public void setTopP(Float topP) { public void setTopP(Float topP) {
this.top_p = topP; this.topP = topP;
} }
// 百度么有 topK // 百度么有 topK
@ -139,6 +138,5 @@ public class YiYanOptions implements ChatOptions {
@Override @Override
public void setTopK(Integer topK) { public void setTopK(Integer topK) {
} }
} }

View File

@ -1,5 +1,8 @@
package cn.iocoder.yudao.framework.ai.chat; package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.chat.messages.Message;
import cn.iocoder.yudao.framework.ai.chat.messages.SystemMessage;
import cn.iocoder.yudao.framework.ai.chat.messages.UserMessage;
import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt; import cn.iocoder.yudao.framework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient; import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel; import cn.iocoder.yudao.framework.ai.chatyiyan.YiYanChatModel;
@ -9,11 +12,13 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner; import java.util.Scanner;
/** /**
* chat 文心一言 * chat 文心一言
* * <p>
* author: fansili * author: fansili
* time: 2024/3/12 20:59 * time: 2024/3/12 20:59
*/ */
@ -29,18 +34,36 @@ public class YiYanChatTests {
YiYanChatModel.ERNIE4_3_5_8K, YiYanChatModel.ERNIE4_3_5_8K,
86400 86400
); );
yiYanChatClient = new YiYanChatClient(yiYanApi, new YiYanOptions().setMax_output_tokens(2048)); YiYanOptions yiYanOptions = new YiYanOptions();
yiYanOptions.setMaxOutputTokens(2048);
yiYanOptions.setTopP(0.6f);
yiYanOptions.setTemperature(0.85f);
yiYanChatClient = new YiYanChatClient(
yiYanApi,
yiYanOptions
);
} }
@Test @Test
public void callTest() { public void callTest() {
ChatResponse call = yiYanChatClient.call(new Prompt("什么编程语言最好?"));
// tip: 百度的message 有特殊规则(最后一个message为当前请求的信息前面的message为历史对话信息)
// tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
messages.add(new UserMessage("长沙怎么样?"));
ChatResponse call = yiYanChatClient.call(new Prompt(messages));
System.err.println(call.getResult()); System.err.println(call.getResult());
} }
@Test @Test
public void streamTest() { public void streamTest() {
Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt("用java帮我写一个快排算法")); List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
messages.add(new UserMessage("长沙怎么样?"));
Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent())); fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
// 阻止退出 // 阻止退出
Scanner scanner = new Scanner(System.in); Scanner scanner = new Scanner(System.in);