【代码评审】AI:文心一言的接入调整

This commit is contained in:
YunaiV 2024-05-18 09:34:17 +08:00
parent 276ef98ff1
commit 04021ce068
25 changed files with 300 additions and 308 deletions

View File

@ -5,7 +5,7 @@ import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.StreamingChatClient; import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.models.tongyi.QianWenChatClient; import org.springframework.ai.models.tongyi.QianWenChatClient;
import org.springframework.ai.models.xinghuo.XingHuoChatClient; import org.springframework.ai.models.xinghuo.XingHuoChatClient;
import org.springframework.ai.models.yiyan.YiYanChatClient; import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
@ -27,7 +27,7 @@ public class AiChatClientFactory {
public ChatClient getChatClient(AiPlatformEnum platformEnum) { public ChatClient getChatClient(AiPlatformEnum platformEnum) {
if (AiPlatformEnum.QIAN_WEN == platformEnum) { if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class); return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) { } else if (AiPlatformEnum.YIYAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class); return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) { } else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class); return applicationContext.getBean(XingHuoChatClient.class);
@ -42,7 +42,7 @@ public class AiChatClientFactory {
// } // }
if (AiPlatformEnum.QIAN_WEN == platformEnum) { if (AiPlatformEnum.QIAN_WEN == platformEnum) {
return applicationContext.getBean(QianWenChatClient.class); return applicationContext.getBean(QianWenChatClient.class);
} else if (AiPlatformEnum.YI_YAN == platformEnum) { } else if (AiPlatformEnum.YIYAN == platformEnum) {
return applicationContext.getBean(YiYanChatClient.class); return applicationContext.getBean(YiYanChatClient.class);
} else if (AiPlatformEnum.XING_HUO == platformEnum) { } else if (AiPlatformEnum.XING_HUO == platformEnum) {
return applicationContext.getBean(XingHuoChatClient.class); return applicationContext.getBean(XingHuoChatClient.class);

View File

@ -1,20 +1,14 @@
package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation; package cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO; import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import com.baomidou.mybatisplus.annotation.FieldFill;
import com.baomidou.mybatisplus.annotation.TableField;
import com.fhs.core.trans.anno.Trans; import com.fhs.core.trans.anno.Trans;
import com.fhs.core.trans.constant.TransType; import com.fhs.core.trans.constant.TransType;
import com.fhs.core.trans.vo.VO; import com.fhs.core.trans.vo.VO;
import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.LocalTime;
@Schema(description = "管理后台 - AI 聊天会话 Response VO") @Schema(description = "管理后台 - AI 聊天会话 Response VO")
@Data @Data
@ -58,7 +52,7 @@ public class AiChatConversationRespVO implements VO {
@Schema(description = "上下文的最大 Message 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "10") @Schema(description = "上下文的最大 Message 数量", requiredMode = Schema.RequiredMode.REQUIRED, example = "10")
private Integer maxContexts; private Integer maxContexts;
@Schema(description = "最后更新时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "2024-05-16") @Schema(description = "最后更新时间", requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime updateTime; private LocalDateTime updateTime;
// ========== 关联 role 信息 ========== // ========== 关联 role 信息 ==========

View File

@ -4,7 +4,7 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
import org.springframework.ai.models.xinghuo.XingHuoChatModel; import org.springframework.ai.models.xinghuo.XingHuoChatModel;
import org.springframework.ai.models.yiyan.YiYanChatModel; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
/** /**
* modal config * modal config

View File

@ -150,9 +150,9 @@ public class AiChatServiceImpl implements AiChatService {
// 1.3 user message 新发送消息 // 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent())); chatMessages.add(new UserMessage(sendReqVO.getContent()));
// 2. 构建 ChatOptions 对象 // 2. 构建 ChatOptions 对象 TODO 芋艿临时注释掉等文心一言兼容了
ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build(); // ChatOptions chatOptions = ChatOptionsBuilder.builder().withTemperature(conversation.getTemperature().floatValue()).build();
return new Prompt(chatMessages, chatOptions); return new Prompt(chatMessages, null);
} }
private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model, private AiChatMessageDO createChatMessage(Long conversationId, AiChatModelDO model,

View File

@ -1,6 +1,7 @@
package cn.iocoder.yudao.module.ai.service.model; package cn.iocoder.yudao.module.ai.service.model;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum; import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -15,6 +16,8 @@ import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
@ -134,7 +137,7 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Override @Override
public List<String> getChatRoleCategoryList() { public List<String> getChatRoleCategoryList() {
List<AiChatRoleDO> list = chatRoleMapper.selectListGroupByCategory(CommonStatusEnum.ENABLE.getStatus()); List<AiChatRoleDO> list = chatRoleMapper.selectListGroupByCategory(CommonStatusEnum.ENABLE.getStatus());
return convertList(list.stream().filter(Objects::nonNull).collect(Collectors.toList()), AiChatRoleDO::getCategory); return convertList(list, AiChatRoleDO::getCategory, role -> StrUtil.isNotBlank(role.getCategory()));
} }
} }

View File

@ -10,25 +10,18 @@
</parent> </parent>
<artifactId>yudao-spring-boot-starter-ai</artifactId> <artifactId>yudao-spring-boot-starter-ai</artifactId>
<!-- TODO 芋艿:这里需要进一步减少 -->
<dependencies> <dependencies>
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-core</artifactId>
<version>1.0.3</version>
</dependency>
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-openai</artifactId>
<version>1.0.3</version>
</dependency>
<dependency> <dependency>
<groupId>io.springboot.ai</groupId> <groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId> <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
<version>1.0.3</version> <version>1.0.3</version>
</dependency> </dependency>
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<version>1.0.3</version>
</dependency>
<dependency> <dependency>
<groupId>cn.iocoder.boot</groupId> <groupId>cn.iocoder.boot</groupId>

View File

@ -8,9 +8,9 @@ import org.springframework.ai.models.tongyi.api.QianWenApi;
import org.springframework.ai.models.xinghuo.XingHuoChatClient; import org.springframework.ai.models.xinghuo.XingHuoChatClient;
import org.springframework.ai.models.xinghuo.XingHuoOptions; import org.springframework.ai.models.xinghuo.XingHuoOptions;
import org.springframework.ai.models.xinghuo.api.XingHuoApi; import org.springframework.ai.models.xinghuo.api.XingHuoApi;
import org.springframework.ai.models.yiyan.YiYanChatClient; import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import org.springframework.ai.models.yiyan.YiYanOptions; import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import org.springframework.ai.models.yiyan.api.YiYanApi; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.springframework.ai.models.midjourney.MidjourneyConfig; import org.springframework.ai.models.midjourney.MidjourneyConfig;
import org.springframework.ai.models.midjourney.MidjourneyMessage; import org.springframework.ai.models.midjourney.MidjourneyMessage;
import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi; import org.springframework.ai.models.midjourney.api.MidjourneyInteractionsApi;
@ -91,7 +91,7 @@ public class YudaoAiAutoConfiguration {
public YiYanChatClient yiYanChatClient(YudaoAiProperties yudaoAiProperties) { public YiYanChatClient yiYanChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.YiYanProperties yiYanProperties = yudaoAiProperties.getYiyan(); YudaoAiProperties.YiYanProperties yiYanProperties = yudaoAiProperties.getYiyan();
// 转换配置 // 转换配置
YiYanOptions yiYanOptions = new YiYanOptions(); YiYanChatOptions yiYanOptions = new YiYanChatOptions();
// yiYanOptions.setTopK(yiYanProperties.getTopK()); TODO 芋艿后续弄 // yiYanOptions.setTopK(yiYanProperties.getTopK()); TODO 芋艿后续弄
yiYanOptions.setTopP(yiYanProperties.getTopP()); yiYanOptions.setTopP(yiYanProperties.getTopP());
yiYanOptions.setTemperature(yiYanProperties.getTemperature()); yiYanOptions.setTemperature(yiYanProperties.getTemperature());

View File

@ -3,7 +3,7 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.models.xinghuo.XingHuoChatModel; import org.springframework.ai.models.xinghuo.XingHuoChatModel;
import org.springframework.ai.models.xinghuo.XingHuoOptions; import org.springframework.ai.models.xinghuo.XingHuoOptions;
import org.springframework.ai.models.yiyan.YiYanChatModel; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;

View File

@ -2,7 +2,7 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum; import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.models.xinghuo.XingHuoChatModel; import org.springframework.ai.models.xinghuo.XingHuoChatModel;
import org.springframework.ai.models.yiyan.YiYanChatModel; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageModelEnum;
import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum; import cn.iocoder.yudao.framework.ai.core.enums.OpenAiImageStyleEnum;
import lombok.Data; import lombok.Data;

View File

@ -16,8 +16,8 @@ public enum AiPlatformEnum {
OPENAI("OpenAI", "OpenAI"), OPENAI("OpenAI", "OpenAI"),
OLLAMA("Ollama", "Ollama"), OLLAMA("Ollama", "Ollama"),
YIYAN("YiYan", "文心一言"),
YI_YAN("yiyan", "一言"),
QIAN_WEN("qianwen", "千问"), QIAN_WEN("qianwen", "千问"),
XING_HUO("xinghuo", "星火"), XING_HUO("xinghuo", "星火"),
OPEN_AI_DALL("dall", "dall"), OPEN_AI_DALL("dall", "dall"),

View File

@ -0,0 +1,12 @@
/**
* model 接入各种大模型对标 https://github.com/spring-projects/spring-ai/tree/main/models
*
* 1. yiyan 百度文心一言
* 2. TODO 芋艿
* tongyi 阿里通义千问对标 spring-cloud-alibaba 提供的 ai
* 2.2
* 2.3 xinghuo 讯飞星火自己实现
* 2.4 openai OpenAIChatGPT拷贝 spring-ai 提供的 models/openai
* 2.5 midjourney Midjourney参考 https://github.com/novicezk/midjourney-proxy 实现
*/
package cn.iocoder.yudao.framework.ai.core.model;

View File

@ -0,0 +1,154 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionResponse;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionRequest;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
/**
* 文心一言的 {@link ChatClient} 实现类
*
* @author fansili
*/
@Slf4j
public class YiYanChatClient implements ChatClient, StreamingChatClient {
private final YiYanApi yiYanApi;
private YiYanChatOptions defaultOptions;
// TODO @fan参考 OpenAiChatClient 调整下 retryTemplate使用 RetryUtils.DEFAULT_RETRY_TEMPLATE加允许传入
public YiYanChatClient(YiYanApi yiYanApi) {
this.yiYanApi = yiYanApi;
// TODO @fan这个情况是不是搞个 defaultOptionsOpenAiChatOptions.builder().withModel(OpenAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build()
}
public YiYanChatClient(YiYanApi yiYanApi, YiYanChatOptions defaultOptions) {
Assert.notNull(yiYanApi, "OllamaApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.yiYanApi = yiYanApi;
this.defaultOptions = defaultOptions;
}
public final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(YiYanApiException.class)
.exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() {
@Override
public <T, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
log.warn("重试异常:" + context.getRetryCount(), throwable);
}
})
.build();
@Override
public ChatResponse call(Prompt prompt) {
YiYanChatCompletionRequest request = createRequest(prompt, false);
return this.retryTemplate.execute(ctx -> {
// 发送请求
ResponseEntity<YiYanChatCompletionResponse> response = yiYanApi.chatCompletionEntity(request);
// 获取结果封装 ChatResponse
YiYanChatCompletionResponse chatCompletion = response.getBody();
// TODO @fan为空时参考 OpenAiChatClient 的封装
// TODO @fanchatResponseMetadata参考 OpenAiChatResponseMetadata.from(completionEntity.getBody())
return new ChatResponse(List.of(new Generation(chatCompletion.getResult())));
});
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
YiYanChatCompletionRequest request = this.createRequest(prompt, true);
// TODO @fanreturn this.retryTemplate.execute(ctx -> {
// 调用 callWithFunctionSupport 发送请求
Flux<YiYanChatCompletionResponse> response = this.yiYanApi.chatCompletionStream(request);
// TODO @fan下面的 doOnComplete 是不是可以删除哈
response.doOnComplete(new Runnable() {
@Override
public void run() {
String a = ";";
}
});
return response.map(chunk -> {
// TODO @fanChatResponseMetadata chatResponseMetadata
return new ChatResponse(List.of(new Generation(chunk.getResult())));
});
}
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 文档system 是独立字段
// 1.1 获取 user assistant
List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream()
// 过滤 system
.filter(msg -> MessageType.SYSTEM != msg.getMessageType())
.map(message -> new YiYanChatCompletionRequest.Message()
.setRole(message.getMessageType().getValue()).setContent(message.getContent())
).toList();
// 1.2 获取 system
String systemPrompt = prompt.getInstructions().stream()
.filter(message -> MessageType.SYSTEM == message.getMessageType())
.map(Message::getContent)
.collect(Collectors.joining());
// 3. 创建 request
YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
// 复制 YiYanOptions 属性 request 这里 options 属性和 request 基本保持一致
YiYanChatOptions useOptions = getYiYanOptions(prompt);
BeanUtil.copyProperties(useOptions, request);
request.setTop_p(useOptions.getTopP())
.setMax_output_tokens(useOptions.getMaxOutputTokens())
.setTemperature(useOptions.getTemperature())
.setSystem(systemPrompt)
.setStream(stream);
return request;
}
// TODO @fanOptions 的处理参考下 OpenAiChatClient createRequest
private YiYanChatOptions getYiYanOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (defaultOptions == null && prompt.getOptions() == null) {
// TODO @fanIllegalArgumentException 参数更好哈
throw new ChatException("ChatOptions 未配置参数!");
}
// 优先使用 Prompt 里面的 ChatOptions
ChatOptions options = defaultOptions;
if (prompt.getOptions() != null) {
options = (ChatOptions) prompt.getOptions();
}
// Prompt 里面是一个 ChatOptions用户可以随意传入这里做一下判断
if (!(options instanceof YiYanChatOptions)) {
// TODO @fanIllegalArgumentException 参数更好哈
// TODO @fan需要兼容 ChatOptionsBuilder 创建出来的
throw new ChatException("Prompt 传入的不是 YiYanOptions!");
}
// 转换 YiYanOptions
return (YiYanChatOptions) options;
}
}

View File

@ -1,23 +1,22 @@
package org.springframework.ai.models.yiyan; package cn.iocoder.yudao.framework.ai.core.model.yiyan;
import org.springframework.ai.chat.prompt.ChatOptions; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionRequest;
import org.springframework.ai.models.yiyan.api.YiYanChatCompletionRequest;
import lombok.Data; import lombok.Data;
import lombok.experimental.Accessors; import org.springframework.ai.chat.prompt.ChatOptions;
import java.util.List; import java.util.List;
// TODO @fan字段命名penalty_score 类似的建议改成驼峰原则
// TODO @fan字段的注释可以都删除掉让用户 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 即可
/** /**
* 百度 问心一言 * 文心一言的 {@link ChatOptions} 实现类
* *
* 文档地址https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t * 字段说明参考 <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t">ERNIE-4.0-8K</a>
* *
* author: fansili * @author fansili
* time: 2024/3/16 19:33
*/ */
@Data @Data
@Accessors(chain = true) public class YiYanChatOptions implements ChatOptions {
public class YiYanOptions implements ChatOptions {
/** /**
* 一个可触发函数的描述列表说明 * 一个可触发函数的描述列表说明
@ -106,37 +105,24 @@ public class YiYanOptions implements ChatOptions {
*/ */
private String tool_choice; private String tool_choice;
//
// 以下兼容 spring-ai ChatOptions 暂时没有其他地方用到
@Override @Override
public Float getTemperature() { public Float getTemperature() {
return this.temperature; return this.temperature;
} }
// @Override
// public void setTemperature(Float temperature) {
// this.temperature = temperature;
// }
@Override @Override
public Float getTopP() { public Float getTopP() {
return topP; return topP;
} }
// @Override /**
// public void setTopP(Float topP) { * 百度么有 topK
// this.topP = topP; *
// } * @return null
*/
// 百度么有 topK
@Override @Override
public Integer getTopK() { public Integer getTopK() {
return null; return null;
} }
// @Override
// public void setTopK(Integer topK) {
// }
} }

View File

@ -1,8 +1,6 @@
package org.springframework.ai.models.yiyan.api; package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import org.springframework.ai.models.yiyan.YiYanChatModel; import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import org.springframework.ai.models.yiyan.exception.YiYanApiException;
import lombok.Data;
import org.springframework.http.HttpStatusCode; import org.springframework.http.HttpStatusCode;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
@ -10,47 +8,55 @@ import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
/** /**
* 文心一言 * 文心一言 API
* <p> *
* author: fansili * @author fansili
* time: 2024/3/8 21:47
*/ */
@Data
public class YiYanApi { public class YiYanApi {
private static final String DEFAULT_BASE_URL = "https://aip.baidubce.com"; private static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
private static final String AUTH_2_TOKEN_URI = "/oauth/2.0/token"; private static final String AUTH_2_TOKEN_URI = "/oauth/2.0/token";
public static final String DEFAULT_CHAT_MODEL = "ERNIE 4.0"; public static final String DEFAULT_CHAT_MODEL = YiYanChatModel.ERNIE4_0.getModel();
// 获取access_token流程 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Ilkkrb0i5 private final String appKey;
private String appKey; private final String secretKey;
private String secretKey; /**
private String token; * TODO fan这个是不是要有个刷新机制哈如果目前不需要可以删除掉 refreshTokenSecondTime整体更简洁
// token刷新时间() */
private final String token;
/**
* token 刷新时间()
*/
private int refreshTokenSecondTime; private int refreshTokenSecondTime;
// 发送请求 webClient /**
* 发送请求 webClient
*/
private final WebClient webClient; private final WebClient webClient;
// 使用的模型 /**
private YiYanChatModel useChatModel; * 使用的模型
*/
private final YiYanChatModel useChatModel;
public YiYanApi(String appKey, String secretKey, YiYanChatModel useChatModel, int refreshTokenSecondTime) { public YiYanApi(String appKey, String secretKey, YiYanChatModel useChatModel, int refreshTokenSecondTime) {
this.appKey = appKey; this.appKey = appKey;
this.secretKey = secretKey; this.secretKey = secretKey;
this.useChatModel = useChatModel; this.useChatModel = useChatModel;
this.refreshTokenSecondTime = refreshTokenSecondTime; this.refreshTokenSecondTime = refreshTokenSecondTime;
this.webClient = WebClient.builder().baseUrl(DEFAULT_BASE_URL).build();
this.webClient = WebClient.builder() // 获取访问令牌
.baseUrl(DEFAULT_BASE_URL)
.build();
token = getToken(); token = getToken();
} }
/**
* 获得访问令牌
*
* @see <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Ilkkrb0i5">文档地址</>
* @return 访问令牌
*/
private String getToken() { private String getToken() {
// 文档地址: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Ilkkrb0i5 ResponseEntity<YiYanAuthResponse> response = this.webClient.post()
ResponseEntity<YiYanAuthRes> response = this.webClient.post()
.uri(uriBuilder -> uriBuilder.path(AUTH_2_TOKEN_URI) .uri(uriBuilder -> uriBuilder.path(AUTH_2_TOKEN_URI)
.queryParam("grant_type", "client_credentials") .queryParam("grant_type", "client_credentials")
.queryParam("client_id", appKey) .queryParam("client_id", appKey)
@ -58,17 +64,19 @@ public class YiYanApi {
.build() .build()
) )
.retrieve() .retrieve()
.toEntity(YiYanAuthRes.class) .toEntity(YiYanAuthResponse.class)
.block(); .block();
// 检查请求状态 // 检查请求状态
if (HttpStatusCode.valueOf(200) != response.getStatusCode()) { // TODO @fan可以使用 response.getStatusCode().is2xxSuccessful()
if (HttpStatusCode.valueOf(200) != response.getStatusCode()
|| response.getBody() == null) {
// TODO @fan可以使用 IllegalStateException 替代另外最好打印下返回方便排错
throw new YiYanApiException("一言认证失败! apihttps://aip.baidubce.com/oauth/2.0/token 请检查 client_id、client_secret 是否正确!"); throw new YiYanApiException("一言认证失败! apihttps://aip.baidubce.com/oauth/2.0/token 请检查 client_id、client_secret 是否正确!");
} }
YiYanAuthRes body = response.getBody(); return response.getBody().getAccess_token();
return body.getAccess_token();
} }
public ResponseEntity<YiYanChatCompletion> chatCompletionEntity(YiYanChatCompletionRequest request) { public ResponseEntity<YiYanChatCompletionResponse> chatCompletionEntity(YiYanChatCompletionRequest request) {
// TODO: 2024/3/10 小范 这里错误信息返回的结构不一样 // TODO: 2024/3/10 小范 这里错误信息返回的结构不一样
// {"error_code":17,"error_msg":"Open api daily request limit reached"} // {"error_code":17,"error_msg":"Open api daily request limit reached"}
return this.webClient.post() return this.webClient.post()
@ -78,11 +86,11 @@ public class YiYanApi {
.build()) .build())
.body(Mono.just(request), YiYanChatCompletionRequest.class) .body(Mono.just(request), YiYanChatCompletionRequest.class)
.retrieve() .retrieve()
.toEntity(YiYanChatCompletion.class) .toEntity(YiYanChatCompletionResponse.class)
.block(); .block();
} }
public Flux<YiYanChatCompletion> chatCompletionStream(YiYanChatCompletionRequest request) { public Flux<YiYanChatCompletionResponse> chatCompletionStream(YiYanChatCompletionRequest request) {
return this.webClient.post() return this.webClient.post()
.uri(uriBuilder .uri(uriBuilder
-> uriBuilder.path(useChatModel.getUri()) -> uriBuilder.path(useChatModel.getUri())
@ -90,6 +98,7 @@ public class YiYanApi {
.build()) .build())
.body(Mono.just(request), YiYanChatCompletionRequest.class) .body(Mono.just(request), YiYanChatCompletionRequest.class)
.retrieve() .retrieve()
.bodyToFlux(YiYanChatCompletion.class); .bodyToFlux(YiYanChatCompletionResponse.class);
} }
} }

View File

@ -1,15 +1,15 @@
package org.springframework.ai.models.yiyan.api; package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.Data; import lombok.Data;
// TODO @fan字段驼峰字段注释都可以删除贴个链接就好
/** /**
* 一言 获取access_token * 获取文心一言 access_token Response
* *
* author: fansili * @author fansili
* time: 2024/3/10 08:51
*/ */
@Data @Data
public class YiYanAuthRes { public class YiYanAuthResponse {
/** /**
* 访问凭证 * 访问凭证

View File

@ -1,16 +1,16 @@
package org.springframework.ai.models.yiyan.api; package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
// TODO @fan字段驼峰字段注释都可以删除贴个链接就好
/** /**
* 一言 Completion req * 文心一言 Completion Request
* *
* 百度千帆文档https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11 * 百度千帆文档https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
* *
* author: fansili * @author fansili
* time: 2024/3/9 10:34
*/ */
@Data @Data
public class YiYanChatCompletionRequest { public class YiYanChatCompletionRequest {
@ -114,9 +114,11 @@ public class YiYanChatCompletionRequest {
@Data @Data
public static class Message { public static class Message {
private String role; private String role;
private String content; private String content;
} }
@Data @Data

View File

@ -1,16 +1,16 @@
package org.springframework.ai.models.yiyan.api; package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.Data; import lombok.Data;
/** /**
* 聊天返回 * 文心一言 Completion Response
*
* 百度链接: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t * 百度链接: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
* *
* author: fansili * @author fansili
* time: 2024/3/9 10:34
*/ */
@Data @Data
public class YiYanChatCompletion { public class YiYanChatCompletionResponse {
/** /**
* 本轮对话的id * 本轮对话的id
@ -88,4 +88,5 @@ public class YiYanChatCompletion {
*/ */
private int total_tokens; private int total_tokens;
} }
} }

View File

@ -1,15 +1,14 @@
package org.springframework.ai.models.yiyan; package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
/** /**
* 一言模型 * 文心一言模型枚举
* *
* 可参考百度文档https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t * 可参考 <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t">百度文档</>
* *
* author: fansili * @author fansili
* time: 2024/3/9 12:01
*/ */
@Getter @Getter
@AllArgsConstructor @AllArgsConstructor
@ -18,21 +17,24 @@ public enum YiYanChatModel {
ERNIE4_0("ERNIE 4.0", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"), ERNIE4_0("ERNIE 4.0", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"),
ERNIE4_3_5_8K("ERNIE-3.5-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"), ERNIE4_3_5_8K("ERNIE-3.5-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"),
ERNIE4_3_5_8K_0205("ERNIE-3.5-8K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205"), ERNIE4_3_5_8K_0205("ERNIE-3.5-8K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205"),
ERNIE4_3_5_8K_1222("ERNIE-3.5-8K-1222", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222"), ERNIE4_3_5_8K_1222("ERNIE-3.5-8K-1222", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222"),
ERNIE4_BOT_8K("ERNIE-Bot-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"), ERNIE4_BOT_8K("ERNIE-Bot-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"),
ERNIE4_3_5_4K_0205("ERNIE-3.5-4K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205"), ERNIE4_3_5_4K_0205("ERNIE-3.5-4K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205"),
; ;
private String model; /**
* 模型名
private String uri; */
private final String model;
/**
* API URL
*/
private final String uri;
public static YiYanChatModel valueOfModel(String model) { public static YiYanChatModel valueOfModel(String model) {
for (YiYanChatModel itemEnum : YiYanChatModel.values()) { for (YiYanChatModel modelEnum : YiYanChatModel.values()) {
if (itemEnum.getModel().equals(model)) { if (modelEnum.getModel().equals(model)) {
return itemEnum; return modelEnum;
} }
} }
throw new IllegalArgumentException("Invalid MessageType value: " + model); throw new IllegalArgumentException("Invalid MessageType value: " + model);

View File

@ -1,4 +1,4 @@
package org.springframework.ai.models.yiyan.exception; package cn.iocoder.yudao.framework.ai.core.model.yiyan.exception;
/** /**
* 一言 api 调用异常 * 一言 api 调用异常

View File

@ -5,7 +5,7 @@ import org.springframework.ai.chat.*;
import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.models.tongyi.api.QianWenApi; import org.springframework.ai.models.tongyi.api.QianWenApi;
import org.springframework.ai.models.yiyan.exception.YiYanApiException; import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import com.alibaba.dashscope.aigc.generation.GenerationResult; import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam; import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message; import com.alibaba.dashscope.common.Message;

View File

@ -1,158 +0,0 @@
package org.springframework.ai.models.yiyan;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import org.springframework.ai.chat.*;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.models.yiyan.api.YiYanApi;
import org.springframework.ai.models.yiyan.api.YiYanChatCompletion;
import org.springframework.ai.models.yiyan.api.YiYanChatCompletionRequest;
import org.springframework.ai.models.yiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import reactor.core.publisher.Flux;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
/**
* 文心一言
* <p>
* author: fansili
* time: 2024/3/8 19:11
*/
@Slf4j
public class YiYanChatClient implements ChatClient, StreamingChatClient {
private YiYanApi yiYanApi;
private YiYanOptions yiYanOptions;
public YiYanChatClient(YiYanApi yiYanApi) {
this.yiYanApi = yiYanApi;
}
public YiYanChatClient(YiYanApi yiYanApi, YiYanOptions yiYanOptions) {
this.yiYanApi = yiYanApi;
this.yiYanOptions = yiYanOptions;
}
public final RetryTemplate retryTemplate = RetryTemplate.builder()
// 最大重试次数 10
.maxAttempts(10)
.retryOn(YiYanApiException.class)
// 最大重试5次第一次间隔3000ms第二次3000ms * 2第三次3000ms * 3以此类推最大间隔3 * 60000ms
.exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() {
@Override
public <T extends Object, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
log.warn("重试异常:" + context.getRetryCount(), throwable);
}
;
})
.build();
@Override
public String call(String message) {
return ChatClient.super.call(message);
}
@Override
public ChatResponse call(Prompt prompt) {
return this.retryTemplate.execute(ctx -> {
// ctx 会有重试的信息
// 创建 request 请求stream模式需要供应商支持
YiYanChatCompletionRequest request = this.createRequest(prompt, false);
// 调用 callWithFunctionSupport 发送请求
ResponseEntity<YiYanChatCompletion> response = yiYanApi.chatCompletionEntity(request);
// 获取结果封装 ChatResponse
YiYanChatCompletion chatCompletion = response.getBody();
return new ChatResponse(List.of(new Generation(chatCompletion.getResult())));
});
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
// ctx 会有重试的信息
// 创建 request 请求stream模式需要供应商支持
YiYanChatCompletionRequest request = this.createRequest(prompt, true);
// 调用 callWithFunctionSupport 发送请求
Flux<YiYanChatCompletion> response = this.yiYanApi.chatCompletionStream(request);
response.doOnComplete(new Runnable() {
@Override
public void run() {
String a = ";";
}
});
return response.map(res -> {
// TODO @fan这里缺少了 usage 的封装
return new ChatResponse(List.of(new Generation(res.getResult())));
});
}
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 获取配置
YiYanOptions useOptions = getYiYanOptions(prompt);
// 创建 request
// tip: 百度的 system 不在 message 里面
// tip百度的 message 只有 user assistant
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
// 获取 user assistant
List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream()
// 过滤 system
.filter(msg -> MessageType.SYSTEM != msg.getMessageType())
.map(msg -> new YiYanChatCompletionRequest.Message()
.setRole(msg.getMessageType().getValue())
.setContent(msg.getContent())
).toList();
// 获取 system
String systemPrompt = prompt.getInstructions().stream()
.filter(msg -> MessageType.SYSTEM == msg.getMessageType())
.map(Message::getContent)
.collect(Collectors.joining());
YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
// 复制 qianWenOptions 属性取 request这里 options 属性和 request 基本保持一致
// top: 由于遵循 spring-ai规范支持在构建client的时候传入默认的 chatOptions
BeanUtil.copyProperties(useOptions, request);
request.setTop_p(useOptions.getTopP());
request.setMax_output_tokens(useOptions.getMaxOutputTokens());
request.setTemperature(useOptions.getTemperature());
request.setSystem(systemPrompt);
// 设置 stream
request.setStream(stream);
return request;
}
private @NotNull YiYanOptions getYiYanOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (yiYanOptions == null && prompt.getOptions() == null) {
throw new ChatException("ChatOptions 未配置参数!");
}
// 优先使用 Prompt 里面的 ChatOptions
ChatOptions options = yiYanOptions;
if (prompt.getOptions() != null) {
options = (ChatOptions) prompt.getOptions();
}
// Prompt 里面是一个 ChatOptions用户可以随意传入这里做一下判断
if (!(options instanceof YiYanOptions)) {
throw new ChatException("Prompt 传入的不是 YiYanOptions!");
}
// 转换 YiYanOptions
YiYanOptions useOptions = (YiYanOptions) options;
return useOptions;
}
}

View File

@ -1,8 +0,0 @@
package org.springframework.ai.models.yiyan.api;
/**
* author: fansili
* time: 2024/3/9 10:37
*/
public class YiYanChatCompletionMessage {
}

View File

@ -5,10 +5,10 @@ import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.models.yiyan.YiYanChatClient; import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import org.springframework.ai.models.yiyan.YiYanChatModel; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import org.springframework.ai.models.yiyan.YiYanOptions; import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import org.springframework.ai.models.yiyan.api.YiYanApi; import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
@ -35,7 +35,7 @@ public class YiYanChatTests {
YiYanChatModel.ERNIE4_3_5_8K, YiYanChatModel.ERNIE4_3_5_8K,
86400 86400
); );
YiYanOptions yiYanOptions = new YiYanOptions(); YiYanChatOptions yiYanOptions = new YiYanChatOptions();
yiYanOptions.setMaxOutputTokens(2048); yiYanOptions.setMaxOutputTokens(2048);
yiYanOptions.setTopP(0.6f); yiYanOptions.setTopP(0.6f);
yiYanOptions.setTemperature(0.85f); yiYanOptions.setTemperature(0.85f);

View File

@ -230,17 +230,6 @@ yudao:
appKey: cb6415c19d6162cda07b47316fcb0416 appKey: cb6415c19d6162cda07b47316fcb0416
secretKey: Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh secretKey: Y2JiYTIxZjA3MDMxMjNjZjQzYzVmNzdh
model: XING_HUO_3_5 model: XING_HUO_3_5
yiyan:
enable: true
aiPlatform: YI_YAN
max-tokens: 1500
temperature: 0.85
topP: 0.8
topK: 0
appKey: x0cuLZ7XsaTCU08vuJWO87Lg
secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
refreshTokenSecondTime: 86400
model: ERNIE4_3_5_8K
openAiImage: openAiImage:
enable: true enable: true
api-key: ${OPEN_AI_KEY} api-key: ${OPEN_AI_KEY}

View File

@ -150,6 +150,19 @@ spring.ai:
chat: chat:
model: llama3 model: llama3
yudao.ai:
yiyan:
enable: true
aiPlatform: YIYAN # TODO @fan建议每个都独立配置属性类
max-tokens: 1500
temperature: 0.85
topP: 0.8
topK: 0
appKey: x0cuLZ7XsaTCU08vuJWO87Lg
secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
refreshTokenSecondTime: 86400
model: ERNIE4_3_5_8K
--- #################### 芋道相关配置 #################### --- #################### 芋道相关配置 ####################
yudao: yudao: