# text_vector_database.py
  1. from bge.remote_model import RemoteBGEModel
  2. import os
  3. import faiss
  4. import numpy as np
  5. import json
  6. class TextVectorDatabase:
  7. def __init__(self, text_root_dir:str, name:str):
  8. self.model = RemoteBGEModel()
  9. # 模型预热
  10. self.vector_dim = self.model.encode(["这是一段预热文字,首次推理通过预测保证后续推理的稳定性和性能。"])[0].shape[0]
  11. self.all_files = []
  12. self.text_root_dir = text_root_dir
  13. self.name = name
  14. self.split_char = "。。"
  15. # 加载元数据和向量知识库
  16. self.text_mata_path = f"./db/{self.name}_text_meta.json"
  17. if os.path.exists(self.text_mata_path):
  18. with open(self.text_mata_path, "r", encoding="utf-8") as f:
  19. self.text_meta = json.load(f)
  20. print(f"从{self.text_mata_path}加载元数据成功")
  21. else:
  22. self.text_meta = {} # 存储每个文本的元数据,文件名称为key, value为全文内容
  23. self.chunk_meta_path = f"./db/{self.name}_chunk_meta.json"
  24. self.counter_chunk_encode = 0 # 全局chunk位置索引
  25. if os.path.exists(self.chunk_meta_path):
  26. with open(self.chunk_meta_path, "r", encoding="utf-8") as f:
  27. self.chunk_meta = json.load(f)
  28. print(f"从{self.chunk_meta_path}加载元数据成功")
  29. else:
  30. self.chunk_meta = {} # 存储每个chunk的元数据, key为位置编码,value为字典,包括content为chunk内容,和father为文件名
  31. self.database_path = f'./db/{self.name}_knowledge.faiss'
  32. self.knowledge_database = faiss.read_index(self.database_path) if os.path.exists(self.database_path) else faiss.IndexFlatIP(self.vector_dim)
  33. if os.path.exists(self.database_path): print(f"从{self.database_path}加载向量数据库成功")
  34. # 如果这些元数据和数据库有不存在的就直接重新生成
  35. if not os.path.exists(self.database_path) or not os.path.exists(self.text_mata_path) or not os.path.exists(self.chunk_meta_path):
  36. self.text_embedding(self.text_root_dir)
  37. # 搜索候选数量
  38. self.indicators = 16 # 32刚好是模型的单次输入上限
  39. def get_all_texts_directory(self, root_dir):
  40. """
  41. 获取所有文本的目录
  42. :return: 所有文本的目录
  43. """
  44. for root, dirs, files in os.walk(root_dir):
  45. for file in files:
  46. if file.endswith(".md"):
  47. self.all_files.append(os.path.join(root, file))
  48. def read_md(self, file_path):
  49. """
  50. 读取md文件
  51. :param file_path: md文件的路径
  52. :return: md文件的内容
  53. """
  54. if not os.path.exists(file_path):
  55. return None
  56. with open(file_path, "r", encoding="utf-8") as f:
  57. text = f.read()
  58. # 将文本按照指定的分割符进行分割
  59. return text, text.split(self.split_char)
  60. def add_meta_data(self, file_path):
  61. """为每个chunk创建父文本元数据"""
  62. text, chunks = self.read_md(file_path)
  63. chunks.pop(0) # 去除标题
  64. if chunks:
  65. # 添加本文元数据
  66. self.text_meta[os.path.basename(file_path)] = text
  67. embeddings = []
  68. for chunk in chunks:
  69. # 添加chunk元数据
  70. self.chunk_meta[str(self.counter_chunk_encode)] = {"content": chunk, "father": os.path.basename(file_path)} # content是chunk内容,father是文件名
  71. self.counter_chunk_encode += 1
  72. # 将chunk添加到向量数据库
  73. embedding = self.model.encode(chunk, normalize=True)[0]
  74. embeddings.append(embedding)
  75. embeddings = np.array(embeddings, dtype=np.float32)
  76. # print()
  77. # print(embeddings.shape)
  78. self.knowledge_database.add(embeddings) # 向量会按照添加顺序存储在数据库中,查询时也会返回这个顺序作为索引
  79. def text_embedding(self, text_root_dir):
  80. """"一次性对知识库所有文本进行embedding"""
  81. print("正在创建知识库...")
  82. if not os.path.exists(text_root_dir):
  83. raise FileNotFoundError(f"{text_root_dir}不存在")
  84. if not self.all_files:
  85. self.get_all_texts_directory(text_root_dir)
  86. for file in self.all_files:
  87. # 将文件内容添加到元数据
  88. self.add_meta_data(file)
  89. print(f"完成处理文件:{file}")
  90. # 保存向量数据库
  91. faiss.write_index(self.knowledge_database, self.database_path)
  92. print(f"向量数据库创建成功,保存至{self.database_path}")
  93. # 保存元数据
  94. with open(self.text_mata_path, "w", encoding="utf-8") as f:
  95. json.dump(self.text_meta, f, ensure_ascii=False, indent=4)
  96. print(f"元数据保存成功,保存至{self.text_mata_path}")
  97. with open(self.chunk_meta_path, "w", encoding="utf-8") as f:
  98. json.dump(self.chunk_meta, f, ensure_ascii=False, indent=4)
  99. print(f"元数据保存成功,保存至{self.chunk_meta_path}")
  100. def search(self, query:str, top_k:int=3, confidence_threshold:float=0.20, *args,**kwargs):
  101. """从知识库中检索"""
  102. if not isinstance(query, str) or query == '':
  103. return None
  104. query_embedding = np.array(self.model.encode([query], normalize=True), dtype=np.float32)
  105. distances, indices = self.knowledge_database.search(query_embedding, self.indicators)
  106. distances = distances[0]
  107. indices = indices[0]
  108. rerank_score = self.model.compute_score([(query, self.chunk_meta[str(indices[i])]['content']) for i in range(self.indicators) if indices[i] > -1])
  109. rerank_results = sorted([(indices[_], distances[_], rerank_score[_]) for _ in range(len(rerank_score))], key=lambda x: x[2], reverse=True)[:top_k]
  110. search_results = []
  111. for ind, dis, score in rerank_results:
  112. if score < confidence_threshold:
  113. continue
  114. title = self.chunk_meta[str(int(ind))]['father']
  115. content = self.text_meta[title].replace(self.split_char, ' ')
  116. search_results.append({
  117. "title": title,
  118. "content": content,
  119. "score": score,
  120. })
  121. return search_results
  122. knowledge_comp_rule_regulation_db = TextVectorDatabase(name='rule_regulation',text_root_dir=r'D:\code\repository\RAG资料库—上线\公司规章制度_语义块') # 公司规章制度
  123. knowledge_comp_propaganda_db = TextVectorDatabase(name='propaganda',text_root_dir=r'D:\code\repository\RAG资料库—上线\公司宣传及产品介绍_语义块') # 公司宣传及产品介绍
  124. knowledge_comp_water_treatment_db = TextVectorDatabase(name='water_treatment',text_root_dir=r'D:\code\repository\RAG资料库—上线\污水处理工艺知识_语义块') # 污水处理工艺知识
  125. knowledge_comp_project_plan_db = TextVectorDatabase(name='project_plan',text_root_dir=r'D:\code\repository\RAG资料库—上线\污水处理项目方案_语义块') # 污水处理项目方案
  126. knowledge_comp_operation_report_db = TextVectorDatabase(name='operation_report',text_root_dir=r'D:\code\repository\RAG资料库—上线\水厂运行报告_语义块') # 水厂运行报告
  127. if __name__ == '__main__':
  128. res = knowledge_comp_rule_regulation_db.search('金科环境的实习期有多长?', top_k=3)
  129. print(len(res), res)
  130. res = knowledge_comp_propaganda_db.search('介绍一下新水岛产品', top_k=3)
  131. print(len(res), res)
  132. res = knowledge_comp_water_treatment_db.search('反渗透RO如何计算', top_k=3)
  133. print(len(res), res)
  134. res = knowledge_comp_project_plan_db.search('请给我一份新水岛实施项目方案大纲', top_k=3)
  135. print(len(res), res)
  136. res = knowledge_comp_operation_report_db.search('水厂运行周报告', top_k=3)
  137. print(len(res), res)