【代码评审】AI:写作实现
This commit is contained in:
parent
9ddd2eddf8
commit
f20c27a7ef
|
@ -33,7 +33,7 @@ public interface ErrorCodeConstants {
|
|||
// ========== API 聊天消息 1-040-004-000 ==========
|
||||
|
||||
ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
|
||||
ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!");
|
||||
ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "对话生成异常!");
|
||||
|
||||
// ========== API 绘画 1-040-005-000 ==========
|
||||
|
||||
|
@ -48,6 +48,6 @@ public interface ErrorCodeConstants {
|
|||
|
||||
// ========== API 写作 1-022-007-000 ==========
|
||||
ErrorCode WRITE_NOT_EXISTS = new ErrorCode(1_022_007_000, "作文不存在!");
|
||||
ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "Stream 对话异常!");
|
||||
ErrorCode WRITE_STREAM_ERROR = new ErrorCode(1_022_07_001, "写作生成异常!");
|
||||
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import lombok.Getter;
|
|||
|
||||
import java.util.Arrays;
|
||||
|
||||
// TODO @xin:写作的几个,不用枚举类哈。直接搞字段就好了。AiWriteTypeEnum 还是需要的哈
|
||||
@AllArgsConstructor
|
||||
@Getter
|
||||
public enum AiLanguageEnum implements IntArrayValuable {
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package cn.iocoder.yudao.module.ai.controller.admin.write.vo;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.validation.InEnum;
|
||||
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
@ -8,6 +10,11 @@ import lombok.Data;
|
|||
@Data
|
||||
public class AiWriteGenerateReqVO {
|
||||
|
||||
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
@InEnum(AiWriteTypeEnum.class)
|
||||
private Integer type;
|
||||
|
||||
// TODO @xin:如果非必填,可以不用写 requiredMode
|
||||
@Schema(description = "写作内容提示", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "1.撰写:田忌赛马;2.回复:不批")
|
||||
private String prompt;
|
||||
|
||||
|
@ -30,7 +37,4 @@ public class AiWriteGenerateReqVO {
|
|||
@NotNull(message = "语言不能为空")
|
||||
private Integer language;
|
||||
|
||||
|
||||
@Schema(description = "写作类型", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
private Integer type; //参见 AiWriteTypeEnum 枚举
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
package cn.iocoder.yudao.module.ai.dal.dataobject.write;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
|
@ -34,6 +35,18 @@ public class AiWriteDO extends BaseDO {
|
|||
*/
|
||||
private Integer type;
|
||||
|
||||
/**
|
||||
* 模型
|
||||
*/
|
||||
private String model;
|
||||
|
||||
/**
|
||||
* 平台
|
||||
*
|
||||
* 枚举 {@link AiPlatformEnum}
|
||||
*/
|
||||
private String platform;
|
||||
|
||||
/**
|
||||
* 生成内容提示
|
||||
*/
|
||||
|
@ -69,16 +82,6 @@ public class AiWriteDO extends BaseDO {
|
|||
*/
|
||||
private Integer language;
|
||||
|
||||
/**
|
||||
* 模型
|
||||
*/
|
||||
private String model;
|
||||
|
||||
/**
|
||||
* 平台
|
||||
*/
|
||||
private String platform;
|
||||
|
||||
/**
|
||||
* 错误信息
|
||||
*/
|
||||
|
|
|
@ -11,7 +11,6 @@ import reactor.core.publisher.Flux;
|
|||
*/
|
||||
public interface AiWriteService {
|
||||
|
||||
|
||||
/**
|
||||
* 生成写作内容
|
||||
*
|
||||
|
@ -21,5 +20,4 @@ public interface AiWriteService {
|
|||
*/
|
||||
Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId);
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -46,23 +46,22 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||
@Resource
|
||||
private AiChatModelService chatModalService;
|
||||
@Resource
|
||||
private AiWriteMapper writeMapper;
|
||||
|
||||
private AiWriteMapper writeMapper; // TODO @xin:上面空一行;因为同类之间不要空行,非同类空行;
|
||||
|
||||
@Override
|
||||
public Flux<CommonResult<String>> generateWriteContent(AiWriteGenerateReqVO generateReqVO, Long userId) {
|
||||
//TODO 芋艿 写作的模型配置放哪好 先用千问测试
|
||||
// 1.1 校验模型
|
||||
// TODO @xin:可以约定大于配置先,查询某个名字。例如说,写作助手!然后写作助手,上面是有个 model 的,可以使用它。
|
||||
AiChatModelDO model = chatModalService.validateChatModel(14L);
|
||||
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
|
||||
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
||||
ChatOptions chatOptions = buildChatOptions(platform, model.getModel(), model.getTemperature(), model.getMaxTokens());
|
||||
|
||||
//1.2 插入写作信息
|
||||
// 1.2 插入写作信息
|
||||
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class);
|
||||
writeMapper.insert(writeDO.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
|
||||
|
||||
//2.1 构建提示词
|
||||
// 2.1 构建提示词
|
||||
Prompt prompt = new Prompt(buildWritingPrompt(generateReqVO), chatOptions);
|
||||
Flux<ChatResponse> streamResponse = chatClient.stream(prompt);
|
||||
// 2.2 流式返回
|
||||
|
@ -81,7 +80,6 @@ public class AiWriteServiceImpl implements AiWriteService {
|
|||
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.WRITE_STREAM_ERROR)));
|
||||
}
|
||||
|
||||
|
||||
private String buildWritingPrompt(AiWriteGenerateReqVO generateReqVO) {
|
||||
String template;
|
||||
Integer writeType = generateReqVO.getType();
|
||||
|
|
|
@ -9,12 +9,17 @@ import lombok.Data;
|
|||
import lombok.Getter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.openai.api.ApiUtils;
|
||||
import org.springframework.http.HttpRequest;
|
||||
import org.springframework.http.HttpStatusCode;
|
||||
import org.springframework.web.reactive.function.client.ClientResponse;
|
||||
import org.springframework.web.reactive.function.client.WebClient;
|
||||
import reactor.core.publisher.Mono;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* Midjourney API
|
||||
|
@ -25,6 +30,16 @@ import java.util.Map;
|
|||
@Slf4j
|
||||
public class MidjourneyApi {
|
||||
|
||||
private final Predicate<HttpStatusCode> STATUS_PREDICATE = status -> !status.is2xxSuccessful();
|
||||
|
||||
private final Function<Object, Function<ClientResponse, Mono<? extends Throwable>>> EXCEPTION_FUNCTION =
|
||||
reqParam -> response -> response.bodyToMono(String.class).handle((responseBody, sink) -> {
|
||||
HttpRequest request = response.request();
|
||||
log.error("[midjourney-api] 调用失败!请求方式:[{}],请求地址:[{}],请求参数:[{}],响应数据: [{}]",
|
||||
request.getMethod(), request.getURI(), reqParam, responseBody);
|
||||
sink.error(new IllegalStateException("[midjourney-api] 调用失败!"));
|
||||
});
|
||||
|
||||
private final WebClient webClient;
|
||||
|
||||
/**
|
||||
|
@ -80,17 +95,11 @@ public class MidjourneyApi {
|
|||
}
|
||||
|
||||
private String post(String uri, Object body) {
|
||||
// 1、发送 post 请求
|
||||
return webClient.post()
|
||||
.uri(uri)
|
||||
.body(Mono.just(JsonUtils.toJsonString(body)), String.class)
|
||||
.retrieve()
|
||||
.onStatus(status -> !status.is2xxSuccessful(),
|
||||
response -> response.bodyToMono(String.class)
|
||||
.handle((respBody, sink) -> {
|
||||
log.error("【Midjourney api】调用失败!resp: 【{}】", respBody);
|
||||
sink.error(new IllegalStateException("【Midjourney api】调用失败!"));
|
||||
}))
|
||||
.onStatus(STATUS_PREDICATE, EXCEPTION_FUNCTION.apply(body))
|
||||
.bodyToMono(String.class)
|
||||
.block();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue