from bge.remote_model import RemoteBGEModel
import os
import faiss
import numpy as np
import json


class TextVectorDatabase:
    """A FAISS inner-product vector knowledge base built from Markdown files.

    Every ``.md`` file under ``text_root_dir`` is split into chunks on
    ``self.split_char``; each chunk is embedded with a remote BGE model and
    stored in a ``faiss.IndexFlatIP`` index. Chunk metadata (content + source
    file) and full-document text are persisted as JSON next to the index so
    subsequent runs can reload instead of re-embedding.
    """

    def __init__(self, text_root_dir: str, name: str):
        self.model = RemoteBGEModel()
        # Warm-up inference: the first call stabilizes latency of later calls
        # and its output tells us the embedding dimension.
        self.vector_dim = self.model.encode(["这是一段预热文字,首次推理通过预测保证后续推理的稳定性和性能。"])[0].shape[0]
        self.all_files = []
        self.text_root_dir = text_root_dir
        self.name = name
        self.split_char = "。。"  # delimiter used to split documents into chunks

        # Load persisted metadata and the vector index, when present.
        # NOTE: "mata" is a historical typo kept so existing attribute users
        # and on-disk file names stay valid.
        self.text_mata_path = f"./db/{self.name}_text_meta.json"
        if os.path.exists(self.text_mata_path):
            with open(self.text_mata_path, "r", encoding="utf-8") as f:
                self.text_meta = json.load(f)
            print(f"从{self.text_mata_path}加载元数据成功")
        else:
            # Per-document metadata: file name -> full document text.
            self.text_meta = {}

        self.chunk_meta_path = f"./db/{self.name}_chunk_meta.json"
        # Global chunk position counter; must match FAISS insertion order.
        self.counter_chunk_encode = 0
        if os.path.exists(self.chunk_meta_path):
            with open(self.chunk_meta_path, "r", encoding="utf-8") as f:
                self.chunk_meta = json.load(f)
            print(f"从{self.chunk_meta_path}加载元数据成功")
        else:
            # Per-chunk metadata: str(position) -> {"content": chunk text,
            # "father": source file name}.
            self.chunk_meta = {}

        self.database_path = f'./db/{self.name}_knowledge.faiss'
        if os.path.exists(self.database_path):
            self.knowledge_database = faiss.read_index(self.database_path)
            print(f"从{self.database_path}加载向量数据库成功")
        else:
            self.knowledge_database = faiss.IndexFlatIP(self.vector_dim)

        # If ANY persisted artifact is missing, rebuild everything from
        # scratch. BUGFIX: reset all state first — previously a partially
        # present state (e.g. index on disk but chunk metadata missing) was
        # appended to with the counter restarted at 0, desynchronizing FAISS
        # positions from chunk_meta keys.
        if not (os.path.exists(self.database_path)
                and os.path.exists(self.text_mata_path)
                and os.path.exists(self.chunk_meta_path)):
            self.text_meta = {}
            self.chunk_meta = {}
            self.counter_chunk_encode = 0
            self.knowledge_database = faiss.IndexFlatIP(self.vector_dim)
            self.text_embedding(self.text_root_dir)

        # Number of ANN candidates fetched before reranking
        # (32 is the model's single-batch input limit).
        self.indicators = 16

    def get_all_texts_directory(self, root_dir):
        """Recursively collect all ``.md`` file paths under *root_dir*.

        :param root_dir: directory to walk
        :return: None — paths are appended to ``self.all_files``
        """
        for root, dirs, files in os.walk(root_dir):
            for file in files:
                if file.endswith(".md"):
                    self.all_files.append(os.path.join(root, file))

    def read_md(self, file_path):
        """Read a Markdown file and split it into chunks.

        :param file_path: path of the ``.md`` file
        :return: ``(full_text, chunks)`` where *chunks* is *full_text* split
                 on ``self.split_char``; ``None`` if the file does not exist
        """
        if not os.path.exists(file_path):
            return None
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        # Split the text on the configured chunk delimiter.
        return text, text.split(self.split_char)

    def add_meta_data(self, file_path):
        """Register one file: record text/chunk metadata and add chunk embeddings.

        :param file_path: path of the ``.md`` file to ingest
        """
        result = self.read_md(file_path)
        if result is None:
            # BUGFIX: read_md returns None for a missing file; the original
            # unconditionally unpacked it and raised TypeError.
            return
        text, chunks = result
        chunks.pop(0)  # drop the leading title chunk
        if not chunks:
            return
        # Record the full-document metadata.
        self.text_meta[os.path.basename(file_path)] = text
        embeddings = []
        for chunk in chunks:
            # Record the chunk metadata keyed by its global position.
            self.chunk_meta[str(self.counter_chunk_encode)] = {
                "content": chunk,
                "father": os.path.basename(file_path),
            }
            self.counter_chunk_encode += 1
            # NOTE(review): encode() is called elsewhere with a list; the
            # remote model presumably also accepts a bare string — confirm.
            embedding = self.model.encode(chunk, normalize=True)[0]
            embeddings.append(embedding)
        # FAISS stores vectors in insertion order; search results reference
        # that order, which is why counter_chunk_encode must track it exactly.
        self.knowledge_database.add(np.array(embeddings, dtype=np.float32))

    def text_embedding(self, text_root_dir):
        """Embed every text in the knowledge base in one pass and persist results.

        :param text_root_dir: root directory of the source ``.md`` files
        :raises FileNotFoundError: if *text_root_dir* does not exist
        """
        print("正在创建知识库...")
        if not os.path.exists(text_root_dir):
            raise FileNotFoundError(f"{text_root_dir}不存在")
        if not self.all_files:
            self.get_all_texts_directory(text_root_dir)
        for file in self.all_files:
            # Ingest each file's content into metadata + vector index.
            self.add_meta_data(file)
            print(f"完成处理文件:{file}")
        # Persist the vector index.
        faiss.write_index(self.knowledge_database, self.database_path)
        print(f"向量数据库创建成功,保存至{self.database_path}")
        # Persist the metadata.
        with open(self.text_mata_path, "w", encoding="utf-8") as f:
            json.dump(self.text_meta, f, ensure_ascii=False, indent=4)
        print(f"元数据保存成功,保存至{self.text_mata_path}")
        with open(self.chunk_meta_path, "w", encoding="utf-8") as f:
            json.dump(self.chunk_meta, f, ensure_ascii=False, indent=4)
        print(f"元数据保存成功,保存至{self.chunk_meta_path}")

    def search(self, query: str, top_k: int = 3, confidence_threshold: float = 0.20):
        """Retrieve documents from the knowledge base for *query*.

        ANN search over chunk embeddings, followed by a cross-encoder rerank
        of the candidates; results below *confidence_threshold* are dropped.

        :param query: the search query
        :param top_k: maximum number of results returned
        :param confidence_threshold: minimum rerank score to keep a result
        :return: list of ``{"title", "content", "score"}`` dicts (possibly
                 empty); ``None`` if *query* is not a non-empty string
        """
        if not isinstance(query, str) or query == '':
            return None
        query_embedding = np.array(self.model.encode([query], normalize=True), dtype=np.float32)
        distances, indices = self.knowledge_database.search(query_embedding, self.indicators)
        distances, indices = distances[0], indices[0]
        # FAISS pads with -1 when the index holds fewer than self.indicators
        # vectors; keep only valid hits, pairing each index with its distance.
        candidates = [(int(ind), float(dis)) for ind, dis in zip(indices, distances) if ind > -1]
        if not candidates:
            # BUGFIX: previously an empty candidate set reached
            # compute_score([]) and crashed.
            return []
        # BUGFIX: rerank scores must stay aligned with the FILTERED candidate
        # list. The original indexed raw indices/distances by the position in
        # the filtered score list, desynchronizing them whenever -1 padding
        # was present.
        rerank_score = self.model.compute_score(
            [(query, self.chunk_meta[str(ind)]['content']) for ind, _ in candidates])
        rerank_results = sorted(
            [(ind, dis, score) for (ind, dis), score in zip(candidates, rerank_score)],
            key=lambda x: x[2],
            reverse=True,
        )[:top_k]
        search_results = []
        for ind, dis, score in rerank_results:
            if score < confidence_threshold:
                continue
            title = self.chunk_meta[str(ind)]['father']
            # Return the FULL parent document, with chunk delimiters removed.
            content = self.text_meta[title].replace(self.split_char, ' ')
            search_results.append({
                "title": title,
                "content": content,
                "score": score,
            })
        return search_results


# Module-level singletons: one database per document collection.
knowledge_comp_rule_regulation_db = TextVectorDatabase(name='rule_regulation', text_root_dir=r'D:\code\repository\RAG资料库—上线\公司规章制度_语义块')  # company rules & regulations
knowledge_comp_propaganda_db = TextVectorDatabase(name='propaganda', text_root_dir=r'D:\code\repository\RAG资料库—上线\公司宣传及产品介绍_语义块')  # company PR & product intro
knowledge_comp_water_treatment_db = TextVectorDatabase(name='water_treatment', text_root_dir=r'D:\code\repository\RAG资料库—上线\污水处理工艺知识_语义块')  # wastewater treatment process knowledge
knowledge_comp_project_plan_db = TextVectorDatabase(name='project_plan', text_root_dir=r'D:\code\repository\RAG资料库—上线\污水处理项目方案_语义块')  # wastewater treatment project plans
knowledge_comp_operation_report_db = TextVectorDatabase(name='operation_report', text_root_dir=r'D:\code\repository\RAG资料库—上线\水厂运行报告_语义块')  # plant operation reports

if __name__ == '__main__':
    res = knowledge_comp_rule_regulation_db.search('金科环境的实习期有多长?', top_k=3)
    print(len(res), res)
    res = knowledge_comp_propaganda_db.search('介绍一下新水岛产品', top_k=3)
    print(len(res), res)
    res = knowledge_comp_water_treatment_db.search('反渗透RO如何计算', top_k=3)
    print(len(res), res)
    res = knowledge_comp_project_plan_db.search('请给我一份新水岛实施项目方案大纲', top_k=3)
    print(len(res), res)
    res = knowledge_comp_operation_report_db.search('水厂运行周报告', top_k=3)
    print(len(res), res)