【优化】AI:依赖从文心一言,使用 spring ai 替代接入

This commit is contained in:
YunaiV 2024-06-29 18:33:16 +08:00
parent 7dfa7a1573
commit c5db930603
16 changed files with 89 additions and 841 deletions

View File

@ -7,7 +7,6 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -33,6 +32,7 @@ import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
@ -191,8 +191,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO @fan增加一个 model
return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
// TODO 芋艿貌似 model 只要一设置就报错
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);

View File

@ -43,6 +43,13 @@
<artifactId>yudao-common</artifactId>
</dependency>
<!-- TODO 芋艿:等 spring-ai 官方发布后,需要把 groupId 改下 -->
<dependency>
<groupId>group.springframework.ai</groupId>
<artifactId>spring-ai-qianfan-spring-boot-starter</artifactId>
<version>1.1.0</version>
</dependency>
<!-- 阿里云 通义千问 -->
<dependency>
<groupId>com.alibaba</groupId>

View File

@ -11,9 +11,6 @@ import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
@ -78,27 +75,6 @@ public class YudaoAiAutoConfiguration {
);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.yiyan.enable", havingValue = "true")
public YiYanChatClient yiYanChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.YiYanProperties yiYanProperties = yudaoAiProperties.getYiyan();
// 转换配置
YiYanChatOptions yiYanOptions = new YiYanChatOptions();
// yiYanOptions.setTopK(yiYanProperties.getTopK()); TODO 芋艿后续弄
yiYanOptions.setTopP(yiYanProperties.getTopP());
yiYanOptions.setTemperature(yiYanProperties.getTemperature());
yiYanOptions.setMaxOutputTokens(yiYanProperties.getMaxTokens());
return new YiYanChatClient(
new YiYanApi(
yiYanProperties.getAppKey(),
yiYanProperties.getSecretKey(),
yiYanProperties.getModel(),
yiYanProperties.getRefreshTokenSecondTime()
),
yiYanOptions
);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {

View File

@ -3,7 +3,6 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.ai.autoconfigure.openai.OpenAiImageProperties;
@ -23,7 +22,6 @@ public class YudaoAiProperties {
private QianWenProperties qianwen;
private XingHuoProperties xinghuo;
private YiYanProperties yiyan;
private OpenAiImageProperties openAiImage;
private MidjourneyProperties midjourney;
private SunoProperties suno;
@ -63,7 +61,6 @@ public class YudaoAiProperties {
}
@Data
@Accessors(chain = true)
public static class XingHuoProperties extends ChatProperties {
private String appId;
@ -73,28 +70,6 @@ public class YudaoAiProperties {
}
@Data
@Accessors(chain = true)
public static class YiYanProperties extends ChatProperties {
/**
* appKey
*/
private String appKey;
/**
* secretKey
*/
private String secretKey;
/**
* 模型
*/
private YiYanChatModel model = YiYanChatModel.ERNIE4_3_5_8K;
/**
* token 刷新时间(默认 86400 = 24小时)
*/
private int refreshTokenSecondTime = 86400;
}
@Data
public static class MidjourneyProperties {

View File

@ -16,10 +16,11 @@ import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.ollama.OllamaChatModel;
@ -29,8 +30,12 @@ import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.qianfan.QianFanChatModel;
import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import java.util.List;
@ -75,7 +80,7 @@ public class AiClientFactoryImpl implements AiClientFactory {
case OLLAMA:
return SpringUtil.getBean(OllamaChatModel.class);
case YI_YAN:
return SpringUtil.getBean(YiYanChatClient.class);
return SpringUtil.getBean(QianFanChatModel.class);
case XING_HUO:
return SpringUtil.getBean(XingHuoChatClient.class);
case QIAN_WEN:
@ -153,15 +158,16 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#yiYanChatClient(YudaoAiProperties)}
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private static YiYanChatClient buildYiYanChatClient(String key) {
private static QianFanChatModel buildYiYanChatClient(String key) {
// TODO 芋艿貌似目前设置request 势必会报错
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
YiYanApi yiYanApi = new YiYanApi(appKey, secretKey, YiYanApi.DEFAULT_CHAT_MODEL, 0);
return new YiYanChatClient(yiYanApi);
QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
return new QianFanChatModel(qianFanApi);
}
/**

View File

@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.core.model.tongyi;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message;
@ -59,7 +58,7 @@ public class QianWenChatClient implements ChatModel, StreamingChatModel {
public final RetryTemplate retryTemplate = RetryTemplate.builder()
// 最大重试次数 10
.maxAttempts(10)
.retryOn(YiYanApiException.class)
.retryOn(Exception.class) // TODO 芋艿临时这么写
// 最大重试5次第一次间隔3000ms第二次3000ms * 2第三次3000ms * 3以此类推最大间隔3 * 60000ms
.exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() {

View File

@ -1,159 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionRequest;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionResponse;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
/**
* 文心一言的 {@link ChatClient} 实现类
*
* @author fansili
*/
@Slf4j
public class YiYanChatClient implements ChatModel, StreamingChatModel {
private final YiYanApi yiYanApi;
private YiYanChatOptions defaultOptions;
// TODO @fan参考 OpenAiChatClient 调整下 retryTemplate使用 RetryUtils.DEFAULT_RETRY_TEMPLATE加允许传入
public YiYanChatClient(YiYanApi yiYanApi) {
this.yiYanApi = yiYanApi;
// TODO @fan这个情况是不是搞个 defaultOptionsOpenAiChatOptions.builder().withModel(OpenAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build()
}
public YiYanChatClient(YiYanApi yiYanApi, YiYanChatOptions defaultOptions) {
Assert.notNull(yiYanApi, "OllamaApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.yiYanApi = yiYanApi;
this.defaultOptions = defaultOptions;
}
public final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(YiYanApiException.class)
.exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() {
@Override
public <T, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
log.warn("重试异常:" + context.getRetryCount(), throwable);
}
})
.build();
@Override
public ChatResponse call(Prompt prompt) {
YiYanChatCompletionRequest request = createRequest(prompt, false);
return this.retryTemplate.execute(ctx -> {
// 发送请求
ResponseEntity<YiYanChatCompletionResponse> response = yiYanApi.chatCompletionEntity(request);
// 获取结果封装 ChatResponse
YiYanChatCompletionResponse chatCompletion = response.getBody();
if (chatCompletion == null) {
log.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
} else {
// TODO @fanchatResponseMetadata参考 OpenAiChatResponseMetadata.from(completionEntity.getBody())
return new ChatResponse(List.of(new Generation(chatCompletion.getResult())));
}
});
}
@Override
public ChatOptions getDefaultOptions() {
// TODO 芋艿需要跟进下
throw new UnsupportedOperationException();
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
YiYanChatCompletionRequest request = this.createRequest(prompt, true);
return this.retryTemplate.execute(ctx -> {
// 调用 callWithFunctionSupport 发送请求
Flux<YiYanChatCompletionResponse> response = this.yiYanApi.chatCompletionStream(request);
return response.map(chunk -> {
// TODO @fanChatResponseMetadata chatResponseMetadata
return new ChatResponse(List.of(new Generation(chunk.getResult())));
});
});
}
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 文档system 是独立字段
// 1.1 获取 user assistant
List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream()
// 过滤 system
.filter(msg -> MessageType.SYSTEM != msg.getMessageType())
.map(message -> new YiYanChatCompletionRequest.Message()
.setRole(message.getMessageType().getValue()).setContent(message.getContent())
).toList();
// 1.2 获取 system
String systemPrompt = prompt.getInstructions().stream()
.filter(message -> MessageType.SYSTEM == message.getMessageType())
.map(Message::getContent)
.collect(Collectors.joining());
// 3. 创建 request
YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
// 复制 YiYanOptions 属性 request 这里 options 属性和 request 基本保持一致
YiYanChatOptions useOptions = getYiYanOptions(prompt);
BeanUtil.copyProperties(useOptions, request);
request.setTopP(useOptions.getTopP())
.setMaxOutputTokens(useOptions.getMaxOutputTokens())
.setTemperature(useOptions.getTemperature())
.setSystem(systemPrompt)
.setStream(stream);
return request;
}
// TODO @fanOptions 的处理参考下 OpenAiChatClient createRequest
private YiYanChatOptions getYiYanOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (defaultOptions == null && prompt.getOptions() == null) {
// TODO @fanIllegalArgumentException 参数更好哈
throw new ChatException("ChatOptions 未配置参数!");
}
// 优先使用 Prompt 里面的 ChatOptions
ChatOptions options = defaultOptions;
if (prompt.getOptions() != null) {
options = (ChatOptions) prompt.getOptions();
}
// Prompt 里面是一个 ChatOptions用户可以随意传入这里做一下判断
if (!(options instanceof YiYanChatOptions)) {
// TODO @fanIllegalArgumentException 参数更好哈
// TODO @fan需要兼容 ChatOptionsBuilder 创建出来的
throw new ChatException("Prompt 传入的不是 YiYanOptions!");
}
// 转换 YiYanOptions
return (YiYanChatOptions) options;
}
}

View File

@ -1,91 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionRequest;
import lombok.Data;
import org.springframework.ai.chat.prompt.ChatOptions;
import java.util.List;
/**
* 文心一言的 {@link ChatOptions} 实现类
*
* 字段说明参考 <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t">ERNIE-4.0-8K</a>
*
* @author fansili
*/
@Data
public class YiYanChatOptions implements ChatOptions {
/**
* functions 函数
*/
private List<YiYanChatCompletionRequest.Function> functions;
/**
* temperature
*/
private Float temperature;
/**
* topP
*/
private Float topP;
/**
* 通过对已生成的token增加惩罚减少重复生成的现象
*/
private Float penaltyScore;
/**
* stream 模式请求
*/
private Boolean stream;
/**
* system 提示
*/
private String system;
/**
* 生成停止标识当模型生成结果以stop中某个元素结尾时停止文本生成
*/
private List<String> stop;
/**
* 是否强制关闭实时搜索功能
*/
private Boolean disableSearch;
/**
* 是否开启上角标返回
*/
private Boolean enableCitation;
/**
* 输出最大 token
*/
private Integer maxOutputTokens;
/**
* 响应格式 textjson_object
*/
private String responseFormat;
/**
* 用户id
*/
private String userId;
/**
* 在函数调用场景下提示大模型选择指定的函数非强制说明指定的函数名必须在functions中存在
* tip: ERNIE-4.0-8K 模型没有这个字段
*/
private String toolChoice;
@Override
public Float getTemperature() {
return this.temperature;
}
@Override
public Float getTopP() {
return topP;
}
/**
* 百度么有 topK
*/
@Override
public Integer getTopK() {
return null;
}
}

View File

@ -1,106 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.ResponseEntity;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
* 文心一言 API
*
* @author fansili
*/
public class YiYanApi {
private static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
private static final String AUTH_2_TOKEN_URI = "/oauth/2.0/token";
public static final YiYanChatModel DEFAULT_CHAT_MODEL = YiYanChatModel.ERNIE4_0;
private final String appKey;
private final String secretKey;
/**
* TODO fan这个是不是要有个刷新机制哈如果目前不需要可以删除掉 refreshTokenSecondTime整体更简洁
*/
private final String token;
/**
* token 刷新时间()
*/
private int refreshTokenSecondTime;
/**
* 发送请求 webClient
*/
private final WebClient webClient;
/**
* 使用的模型
*/
private final YiYanChatModel useChatModel;
// TODO fan看看是不是去掉 refreshTokenSecondTime 字段
public YiYanApi(String appKey, String secretKey, YiYanChatModel useChatModel, int refreshTokenSecondTime) {
this.appKey = appKey;
this.secretKey = secretKey;
this.useChatModel = useChatModel;
this.refreshTokenSecondTime = refreshTokenSecondTime;
this.webClient = WebClient.builder().baseUrl(DEFAULT_BASE_URL).build();
// 获取访问令牌
token = getToken();
}
/**
* 获得访问令牌
*
* @see <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Ilkkrb0i5">文档地址</>
* @return 访问令牌
*/
private String getToken() {
ResponseEntity<YiYanAuthResponse> response = this.webClient.post()
.uri(uriBuilder -> uriBuilder.path(AUTH_2_TOKEN_URI)
.queryParam("grant_type", "client_credentials")
.queryParam("client_id", appKey)
.queryParam("client_secret", secretKey)
.build()
)
.retrieve()
.toEntity(YiYanAuthResponse.class)
.block();
// 检查请求状态
// TODO @fan可以使用 response.getStatusCode().is2xxSuccessful()
if (HttpStatusCode.valueOf(200) != response.getStatusCode()
|| response.getBody() == null) {
// TODO @fan可以使用 IllegalStateException 替代另外最好打印下返回方便排错
throw new YiYanApiException("一言认证失败! apihttps://aip.baidubce.com/oauth/2.0/token 请检查 client_id、client_secret 是否正确!");
}
return response.getBody().getAccess_token();
}
public ResponseEntity<YiYanChatCompletionResponse> chatCompletionEntity(YiYanChatCompletionRequest request) {
// TODO: 2024/3/10 小范 这里错误信息返回的结构不一样
// {"error_code":17,"error_msg":"Open api daily request limit reached"}
return this.webClient.post()
.uri(uriBuilder
-> uriBuilder.path(useChatModel.getUri())
.queryParam("access_token", token)
.build())
.body(Mono.just(JsonUtils.toJsonString(request)), String.class)
.retrieve()
.toEntity(YiYanChatCompletionResponse.class)
.block();
}
public Flux<YiYanChatCompletionResponse> chatCompletionStream(YiYanChatCompletionRequest request) {
return this.webClient.post()
.uri(uriBuilder
-> uriBuilder.path(useChatModel.getUri())
.queryParam("access_token", token)
.build())
.body(Mono.just(request), YiYanChatCompletionRequest.class)
.retrieve()
.bodyToFlux(YiYanChatCompletionResponse.class);
}
}

View File

@ -1,48 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.Data;
// TODO @fan字段驼峰字段注释都可以删除贴个链接就好
/**
* 获取文心一言的 access_token Response
*
* @author fansili
*/
@Data
public class YiYanAuthResponse {
/**
* 访问凭证
*/
private String access_token;
/**
* 有效期Access Token的有效期
* 说明单位是秒有效期30天
*/
private int expires_in;
/**
* 错误码说明响应失败时返回该字段成功时不返回
*/
private String error;
/**
* 错误描述信息帮助理解和解决发生的错误
* 说明响应失败时返回该字段成功时不返回
*/
private String error_description;
/**
* 暂时未使用可忽略
*/
private String session_key;
/**
* 暂时未使用可忽略
*/
private String refresh_token;
/**
* 暂时未使用可忽略
*/
private String scope;
/**
* 暂时未使用可忽略
*/
private String session_secret;
}

View File

@ -1,154 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
/**
* 文心一言 Completion Request
*
* 百度千帆文档https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
*
* @author fansili
*/
@Data
public class YiYanChatCompletionRequest {
public YiYanChatCompletionRequest(List<Message> messages) {
this.messages = messages;
}
/**
* 聊天上下文信息
*/
private List<Message> messages;
/**
* functions 函数
*/
private List<Function> functions;
/**
* temperature
*/
private Float temperature;
/**
* topP
*/
@JsonProperty("top_p")
private Float topP;
/**
* 通过对已生成的token增加惩罚减少重复生成的现象
*/
@JsonProperty("penalty_score")
private Float penaltyScore;
/**
* stream 模式
*/
private Boolean stream;
/**
* system 预设角色
*/
private String system;
/**
* 生成停止标识当模型生成结果以stop中某个元素结尾时停止文本生成
*/
private List<String> stop;
/**
* 是否强制关闭实时搜索功能
*/
@JsonProperty("disable_search")
private Boolean disableSearch;
/**
* 是否开启上角标返回
*/
@JsonProperty("enable_citation")
private Boolean enableCitation;
/**
* 最大输出 token
*/
@JsonProperty("max_output_tokens")
private Integer maxOutputTokens;
/**
* 返回格式 textjson_object
*/
@JsonProperty("response_format")
private String responseFormat;
/**
* 用户 id
*/
@JsonProperty("user_id")
private String userId;
/**
* 在函数调用场景下提示大模型选择指定的函数非强制说明指定的函数名必须在functions中存在
* tip: ERNIE-4.0-8K 模型没有这个字段
*/
@JsonProperty("tool_choice")
private String toolChoice;
@Data
public static class Message {
private String role;
private String content;
}
@Data
public static class ToolChoice {
/**
* 指定工具类型function
* 必填:
*/
private String type;
/**
* 指定要使用的函数
* 必填:
*/
private Function function;
/**
* 指定要使用的函数名
* 必填:
*/
private String name;
}
@Data
public static class Function {
/**
* 函数名
* 必填:
*/
private String name;
/**
* 函数描述
* 必填:
*/
private String description;
/**
* 函数请求参数说明
* 1JSON Schema 格式参考JSON Schema描述
* 2如果函数没有请求参数parameters值格式如下
* {"type": "object","properties": {}}
* 必填:
*/
private String parameters;
/**
* 函数响应参数JSON Schema 格式参考JSON Schema描述
* 必填:
*/
private String responses;
/**
* function调用的一些历史示例说明
* 1可以提供正例正常触发和反例无需触发的example
* ·正例从历史请求数据中获取
* ·反例
* 当role = user不会触发请求的query
* 当role = assistant有固定的格式function_call的name为空arguments是空对象:"{}"thought可以填固定的:"我不需要调用任何工具"
* 2兼容之前的 List(example) 格式
*/
private String examples;
}
}

View File

@ -1,92 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.Data;
/**
* 文心一言 Completion Response
*
* 百度链接: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
*
* @author fansili
*/
@Data
public class YiYanChatCompletionResponse {
/**
* 本轮对话的id
*/
private String id;
/**
* 回包类型chat.completion多轮对话返回
*/
private String object;
/**
* 时间戳
*/
private int created;
/**
* 表示当前子句的序号只有在流式接口模式下会返回该字段
*/
private int sentence_id;
/**
* 表示当前子句是否是最后一句只有在流式接口模式下会返回该字段
*/
private boolean is_end;
/**
* 当前生成的结果是否被截断
*/
private boolean is_truncated;
/**
* 输出内容标识说明
* · normal输出内容完全由大模型生成未触发截断替换
* · stop输出结果命中入参stop中指定的字段后被截断
* · length达到了最大的token数根据EB返回结果is_truncated来截断
* · content_filter输出内容被截断兜底替换为**
*/
private String finish_reason;
/**
* 搜索数据当请求参数enable_citation为true并且触发搜索时会返回该字段
*/
private String search_info;
/**
* 对话返回结果
*/
private String result;
/**
* 表示用户输入是否存在安全是否关闭当前会话清理历史会话信息
* true表示用户输入存在安全风险建议关闭当前会话清理历史会话信息
* false表示用户输入无安全风险
*/
private boolean need_clear_history;
/**
* 说明
* · 0正常返回
* · 其他非正常
*/
private int flag;
/**
* 当need_clear_history为true时此字段会告知第几轮对话有敏感信息如果是当前问题ban_round=-1
*/
private int ban_round;
/**
* token统计信息
*/
private Usage usage;
@Data
public static class Usage {
/**
* 问题tokens数
*/
private int prompt_tokens;
/**
* 回答tokens数
*/
private int completion_tokens;
/**
* tokens总数
*/
private int total_tokens;
}
}

View File

@ -1,42 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 文心一言的模型枚举
*
* 可参考 <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t">百度文档</>
*
* @author fansili
*/
@Getter
@AllArgsConstructor
public enum YiYanChatModel {
ERNIE4_0("ERNIE 4.0", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"),
ERNIE4_3_5_8K("ERNIE-3.5-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"),
ERNIE4_3_5_8K_0205("ERNIE-3.5-8K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205"),
ERNIE4_3_5_8K_1222("ERNIE-3.5-8K-1222", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222"),
ERNIE4_BOT_8K("ERNIE-Bot-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"),
ERNIE4_3_5_4K_0205("ERNIE-3.5-4K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205"),
;
/**
* 模型名
*/
private final String model;
/**
* API URL
*/
private final String uri;
public static YiYanChatModel valueOfModel(String model) {
for (YiYanChatModel modelEnum : YiYanChatModel.values()) {
if (modelEnum.getModel().equals(model)) {
return modelEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
}

View File

@ -1,16 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.exception;
/**
* 一言 api 调用异常
*/
public class YiYanApiException extends RuntimeException {
public YiYanApiException(String message) {
super(message);
}
public YiYanApiException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -1,21 +1,21 @@
package cn.iocoder.yudao.framework.ai.chat;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import org.junit.Before;
import org.junit.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
//import org.junit.Before;
//import org.junit.Test;
//import org.springframework.ai.chat.messages.Message;
//import org.springframework.ai.chat.messages.SystemMessage;
//import org.springframework.ai.chat.messages.UserMessage;
//import org.springframework.ai.chat.model.ChatResponse;
//import org.springframework.ai.chat.prompt.Prompt;
//import reactor.core.publisher.Flux;
//
//import java.util.ArrayList;
//import java.util.List;
//import java.util.Scanner;
// TODO 芋艿整理单测
/**
@ -26,49 +26,49 @@ import java.util.Scanner;
*/
public class YiYanChatTests {
private YiYanChatClient yiYanChatClient;
@Before
public void setup() {
YiYanApi yiYanApi = new YiYanApi(
"x0cuLZ7XsaTCU08vuJWO87Lg",
"R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK",
YiYanChatModel.ERNIE4_3_5_8K,
86400
);
YiYanChatOptions yiYanOptions = new YiYanChatOptions();
yiYanOptions.setMaxOutputTokens(2048);
yiYanOptions.setTopP(0.6f);
yiYanOptions.setTemperature(0.85f);
yiYanChatClient = new YiYanChatClient(
yiYanApi,
yiYanOptions
);
}
@Test
public void callTest() {
// tip: 百度的message 有特殊规则(最后一个message为当前请求的信息前面的message为历史对话信息)
// tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
messages.add(new UserMessage("长沙怎么样?"));
ChatResponse call = yiYanChatClient.call(new Prompt(messages));
System.err.println(call.getResult());
}
@Test
public void streamTest() {
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
messages.add(new UserMessage("长沙怎么样?"));
Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
// 阻止退出
Scanner scanner = new Scanner(System.in);
scanner.nextLine();
}
// private YiYanChatClient yiYanChatClient;
//
// @Before
// public void setup() {
// YiYanApi yiYanApi = new YiYanApi(
// "x0cuLZ7XsaTCU08vuJWO87Lg",
// "R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK",
// YiYanChatModel.ERNIE4_3_5_8K,
// 86400
// );
// YiYanChatOptions yiYanOptions = new YiYanChatOptions();
// yiYanOptions.setMaxOutputTokens(2048);
// yiYanOptions.setTopP(0.6f);
// yiYanOptions.setTemperature(0.85f);
// yiYanChatClient = new YiYanChatClient(
// yiYanApi,
// yiYanOptions
// );
// }
//
// @Test
// public void callTest() {
//
// // tip: 百度的message 有特殊规则(最后一个message为当前请求的信息前面的message为历史对话信息)
// // tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
// List<Message> messages = new ArrayList<>();
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// ChatResponse call = yiYanChatClient.call(new Prompt(messages));
// System.err.println(call.getResult());
// }
//
// @Test
// public void streamTest() {
// List<Message> messages = new ArrayList<>();
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
// fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
// // 阻止退出
// Scanner scanner = new Scanner(System.in);
// scanner.nextLine();
// }
}

View File

@ -160,19 +160,11 @@ spring:
gemini:
project-id: 1 # TODO 芋艿:缺配置
location: 2
qianfan: # 文心一言
api-key: x0cuLZ7XsaTCU08vuJWO87Lg
secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
yudao.ai:
yiyan:
enable: true
aiPlatform: YI_YAN # TODO @fan建议每个都独立配置属性类
max-tokens: 1500
temperature: 0.85
topP: 0.8
topK: 0
appKey: x0cuLZ7XsaTCU08vuJWO87Lg
secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
refreshTokenSecondTime: 86400
model: ERNIE4_3_5_8K
xinghuo:
enable: true
aiPlatform: XING_HUO # TODO @fan建议每个都独立配置属性类