【代码评审】AI 大模型:知识库的逻辑

This commit is contained in:
YunaiV 2024-09-07 20:12:37 +08:00
parent 0b1d9ce251
commit 8e56b81a3a
14 changed files with 25 additions and 17 deletions

View File

@ -32,4 +32,5 @@ public class AiKnowledgeCreateMyReqVO {
@Schema(description = "topK", requiredMode = Schema.RequiredMode.REQUIRED, example = "3")
@NotNull(message = "topK 不能为空")
private Integer topK;
}

View File

@ -42,4 +42,5 @@ public class AiKnowledgeDocumentCreateReqVO {
@Schema(description = "分块是否保留分隔符", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
@NotNull(message = "分块是否保留分隔符不能为空")
private Boolean keepSeparator;
}

View File

@ -57,17 +57,16 @@ public class AiKnowledgeDO extends BaseDO {
* topK
*/
private Integer topK;
/**
* 相似度阈值
*/
private Double similarityThreshold;
/**
* 状态
* <p>
* 枚举 {@link CommonStatusEnum}
*/
private Integer status;
}

View File

@ -47,10 +47,12 @@ public class AiKnowledgeDocumentDO extends BaseDO {
* 字符数
*/
private Integer wordCount;
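// TODO @xin: chunk: 1) should this be "segment"? Keeping the naming consistent would be better. 2) Should "Size" be changed to "Tokens" for uniformity? 3) The names of defaultChunkSize, minChunkSizeChars and maxNumChunks probably need to be discussed together over WeChat; try to keep the naming style consistent.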
/**
* 每个文本块的目标 token
*/
private Integer defaultChunkSize;
// TODO @xin: SizeChars and wordCount seem to mean the same thing; should they be unified as well?
/**
* 每个文本块的最小字符数
*/

View File

@ -27,7 +27,7 @@ public class AiKnowledgeSegmentDO extends BaseDO {
/**
* 向量库的编号
*/
@TableField(updateStrategy = FieldStrategy.ALWAYS)
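@TableField(updateStrategy = FieldStrategy.ALWAYS) // TODO @xin: try to avoid needing this annotation; if a status field is added later and updated on its own, this may become a pitfall (presumably a partial update would then also overwrite vectorId, e.g. with null).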
private String vectorId;
/**
* 知识库编号

View File

@ -25,9 +25,11 @@ public interface AiKnowledgeSegmentMapper extends BaseMapperX<AiKnowledgeSegment
.orderByDesc(AiKnowledgeSegmentDO::getId));
}
// TODO @xin: name this selectListByXXX, after the field it filters on
/**
 * Queries the segments whose vector id is contained in the given list,
 * ordered by segment id descending.
 *
 * @param vectorIdList vector ids to match against the segment's vectorId column
 * @return the matching segments, highest id first
 */
default List<AiKnowledgeSegmentDO> selectList(List<String> vectorIdList) {
    // Build the condition step by step: vectorId IN (...) ORDER BY id DESC
    LambdaQueryWrapperX<AiKnowledgeSegmentDO> byVectorIds = new LambdaQueryWrapperX<>();
    byVectorIds.in(AiKnowledgeSegmentDO::getVectorId, vectorIdList);
    byVectorIds.orderByDesc(AiKnowledgeSegmentDO::getId);
    return selectList(byVectorIds);
}
}

View File

@ -83,7 +83,7 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
.setStatus(CommonStatusEnum.ENABLE.getStatus()));
segmentMapper.insertBatch(segmentDOList);
// 3.2 向量化并存储
// 3. 向量化并存储
segments.forEach(segment -> segment.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId()));
vectorStore.add(segments);
return documentId;

View File

@ -38,9 +38,8 @@ public interface AiKnowledgeSegmentService {
*/
void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO);
/**
* 段落召回
* 召回段落
*
* @param reqVO 召回请求信息
* @return 召回的段落

View File

@ -55,19 +55,19 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Override
public void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO) {
// 0 校验
// 1. 校验
AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
// 2.1 获取知识库向量实例
VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
// 2.2 删除原向量
vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
// 2.3 重新向量化
Document document = new Document(reqVO.getContent());
document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
vectorStore.add(List.of(document));
// 2.1 更新段落内容
// 3. 更新段落内容
AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
knowledgeSegment.setVectorId(document.getId());
segmentMapper.updateById(knowledgeSegment);
@ -98,14 +98,14 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
@Override
public List<AiKnowledgeSegmentDO> similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO) {
// 0. 校验
// 1. 校验
AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqVO.getKnowledgeId());
AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
// 1.1 获取向量存储实例
// 2. 获取向量存储实例
VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
// 1.2 向量检索
// 3.1 向量检索
List<Document> documentList = vectorStore.similaritySearch(SearchRequest.query(reqVO.getContent())
.withTopK(knowledge.getTopK())
.withSimilarityThreshold(knowledge.getSimilarityThreshold())
@ -113,11 +113,10 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
if (CollUtil.isEmpty(documentList)) {
return ListUtil.empty();
}
// 2.1 段落召回
// 3.2 段落召回
return segmentMapper.selectList(CollUtil.getFieldValues(documentList, "id", String.class));
}
/**
* 校验段落是否存在
*
@ -131,4 +130,5 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
}
return knowledgeSegment;
}
}

View File

@ -23,7 +23,6 @@ public interface AiKnowledgeService {
*/
Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId);
/**
* Update my knowledge base
*
@ -32,7 +31,6 @@ public interface AiKnowledgeService {
*/
void updateKnowledgeMy(AiKnowledgeUpdateMyReqVO updateReqVO, Long userId);
/**
* 校验知识库是否存在
*
@ -49,6 +47,7 @@ public interface AiKnowledgeService {
*/
PageResult<AiKnowledgeDO> getKnowledgePageMy(Long userId, PageParam pageReqVO);
// TODO @xin: is knowledgeId the same id as the one taken by validateKnowledgeExists? If so, name the parameter id as well, and keep the wording of the id in both Javadocs consistent.
/**
* 根据知识库编号获取向量存储实例
*
@ -56,4 +55,5 @@ public interface AiKnowledgeService {
* @return 向量存储实例
*/
VectorStore getVectorStoreById(Long knowledgeId);
}

View File

@ -38,6 +38,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
private AiChatModelService chatModelService;
@Resource
private AiApiKeyService apiKeyService;
// TODO @xin: chatModelService and apiKeyService could be moved up to follow chatModalService at line 33. Try to keep fields of the same kind together, e.g. Services in one group and Mappers in another.
@Override
public Long createKnowledgeMy(AiKnowledgeCreateMyReqVO createReqVO, Long userId) {
@ -85,6 +86,7 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
public VectorStore getVectorStoreById(Long knowledgeId) {
    // Validate the knowledge base first, then validate the chat model it points to
    AiChatModelDO chatModel = chatModelService.validateChatModel(
            validateKnowledgeExists(knowledgeId).getModelId());
    // Delegate to the API key service, which creates or fetches the VectorStore for the model's key
    Long modelKeyId = chatModel.getKeyId();
    return apiKeyService.getOrCreateVectorStore(modelKeyId);
}

View File

@ -146,6 +146,7 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
public VectorStore getOrCreateVectorStore(Long id) {
    // Validate the API key record, then validate the platform it declares
    AiApiKeyDO keyRecord = validateApiKey(id);
    AiPlatformEnum keyPlatform = AiPlatformEnum.validatePlatform(keyRecord.getPlatform());
    // Create or fetch the VectorStore from the embedding model, platform, key and URL
    return modelFactory.getOrCreateVectorStore(
            getEmbeddingModel(id),
            keyPlatform,
            keyRecord.getApiKey(),
            keyRecord.getUrl());
}

View File

@ -47,7 +47,7 @@
<!-- 向量化,基于 Redis 存储Tika 解析内容 -->
<!-- 暂不做经济型,先注释 -->
<!-- 暂不做经济型,先注释 TODO 经济型是啥呀? -->
<!-- <dependency>-->
<!-- <groupId>${spring-ai.groupId}</groupId>-->
<!-- <artifactId>spring-ai-transformers-spring-boot-starter</artifactId>-->

View File

@ -197,6 +197,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
});
}
// TODO @xin: it looks like a single large VectorStore could be created, with the matching data selected at search time through Filter.Expression
@Override
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel, AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);