# Standard library
import json
import os

# Third-party
import faiss
import numpy as np

# Local
from bge.remote_model import RemoteBGEModel
class TextVectorDatabase:
    """FAISS-backed semantic search over a directory of markdown documents.

    Each ``.md`` file under ``text_root_dir`` is split on ``self.split_char``
    into chunks; every chunk is embedded with a remote BGE model and stored in
    an inner-product FAISS index. Three artifacts are persisted under ``./db``:

    * ``{name}_text_meta.json``  — filename -> full document text
    * ``{name}_chunk_meta.json`` — global chunk id -> {"content", "father"}
    * ``{name}_knowledge.faiss`` — the vector index (row i <-> chunk id i)

    ``search`` retrieves nearest-neighbour chunks, reranks them with the
    model's cross-encoder score, and returns the parent documents.
    """

    def __init__(self, text_root_dir: str, name: str):
        """Load the persisted knowledge base, or build it from scratch.

        :param text_root_dir: root directory containing the ``.md`` corpus
        :param name: logical name used to namespace the persisted artifacts
        """
        self.model = RemoteBGEModel()
        # Warm-up inference; it also reveals the embedding dimensionality.
        self.vector_dim = self.model.encode(["这是一段预热文字,首次推理通过预测保证后续推理的稳定性和性能。"])[0].shape[0]
        self.all_files: list[str] = []
        self.text_root_dir = text_root_dir
        self.name = name
        self.split_char = "。。"
        # Fix: ensure the persistence directory exists before any read/write;
        # previously a fresh checkout crashed on the first write_index/json.dump.
        os.makedirs("./db", exist_ok=True)
        # NOTE: "mata" is a historical typo kept for backward compatibility.
        self.text_mata_path = f"./db/{self.name}_text_meta.json"
        self.chunk_meta_path = f"./db/{self.name}_chunk_meta.json"
        self.database_path = f'./db/{self.name}_knowledge.faiss'

        artifacts = (self.text_mata_path, self.chunk_meta_path, self.database_path)
        if all(os.path.exists(p) for p in artifacts):
            # All three artifacts present: load them as a consistent snapshot.
            with open(self.text_mata_path, "r", encoding="utf-8") as f:
                self.text_meta = json.load(f)
            print(f"从{self.text_mata_path}加载元数据成功")
            with open(self.chunk_meta_path, "r", encoding="utf-8") as f:
                self.chunk_meta = json.load(f)
            print(f"从{self.chunk_meta_path}加载元数据成功")
            self.knowledge_database = faiss.read_index(self.database_path)
            print(f"从{self.database_path}加载向量数据库成功")
            # Fix: resume chunk numbering after the last persisted id instead of
            # restarting at 0 (which would overwrite existing chunk_meta keys).
            self.counter_chunk_encode = max(map(int, self.chunk_meta), default=-1) + 1
        else:
            # Fix: if ANY artifact is missing, rebuild everything from empty
            # state. The original loaded whichever artifacts happened to exist
            # and then re-embedded the corpus on top of them, appending
            # duplicate vectors to a non-empty index while chunk ids restarted
            # at 0 — corrupting the index<->metadata mapping.
            self.text_meta = {}   # filename -> full document text
            self.chunk_meta = {}  # chunk id (str) -> {"content": ..., "father": filename}
            self.counter_chunk_encode = 0  # next global chunk id
            self.knowledge_database = faiss.IndexFlatIP(self.vector_dim)
            self.text_embedding(self.text_root_dir)
        # Number of ANN candidates fetched before reranking
        # (16 chosen to stay under the rerank model's batch limit).
        self.indicators = 16

    def get_all_texts_directory(self, root_dir):
        """Recursively collect every ``.md`` file under *root_dir* into ``self.all_files``."""
        for root, _dirs, files in os.walk(root_dir):
            for file in files:
                if file.endswith(".md"):
                    self.all_files.append(os.path.join(root, file))

    def read_md(self, file_path):
        """Read a markdown file.

        :param file_path: path to the ``.md`` file
        :return: ``(full_text, chunks)`` where *chunks* is the text split on
            ``self.split_char``, or ``None`` if the file does not exist.
        """
        if not os.path.exists(file_path):
            return None
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        return text, text.split(self.split_char)

    def add_meta_data(self, file_path):
        """Index one document: record its metadata and add chunk embeddings to FAISS."""
        loaded = self.read_md(file_path)
        if loaded is None:
            # Fix: the original unpacked None and raised TypeError when a file
            # vanished between os.walk and read.
            return
        text, chunks = loaded
        chunks.pop(0)  # drop the leading title chunk
        if not chunks:
            return
        self.text_meta[os.path.basename(file_path)] = text
        embeddings = []
        for chunk in chunks:
            # Chunk id == row position in the FAISS index (insertion order).
            self.chunk_meta[str(self.counter_chunk_encode)] = {
                "content": chunk,
                "father": os.path.basename(file_path),
            }
            self.counter_chunk_encode += 1
            embeddings.append(self.model.encode(chunk, normalize=True)[0])
        # Vectors are stored in insertion order; search returns that order as ids.
        self.knowledge_database.add(np.array(embeddings, dtype=np.float32))

    def text_embedding(self, text_root_dir):
        """Embed the whole corpus in one pass and persist index + metadata.

        :param text_root_dir: root directory of the markdown corpus
        :raises FileNotFoundError: if *text_root_dir* does not exist
        """
        print("正在创建知识库...")
        if not os.path.exists(text_root_dir):
            raise FileNotFoundError(f"{text_root_dir}不存在")
        if not self.all_files:
            self.get_all_texts_directory(text_root_dir)
        for file in self.all_files:
            self.add_meta_data(file)
            print(f"完成处理文件:{file}")
        # Persist the vector index.
        faiss.write_index(self.knowledge_database, self.database_path)
        print(f"向量数据库创建成功,保存至{self.database_path}")
        # Persist both metadata maps.
        with open(self.text_mata_path, "w", encoding="utf-8") as f:
            json.dump(self.text_meta, f, ensure_ascii=False, indent=4)
        print(f"元数据保存成功,保存至{self.text_mata_path}")
        with open(self.chunk_meta_path, "w", encoding="utf-8") as f:
            json.dump(self.chunk_meta, f, ensure_ascii=False, indent=4)
        print(f"元数据保存成功,保存至{self.chunk_meta_path}")

    def search(self, query: str, top_k: int = 3, confidence_threshold: float = 0.20, *args, **kwargs):
        """Retrieve, rerank, and return the best-matching parent documents.

        :param query: the question text; non-string or empty input returns ``None``
        :param top_k: maximum number of results after reranking
        :param confidence_threshold: rerank scores below this are dropped
        :return: list of ``{"title", "content", "score"}`` dicts (possibly
            empty), or ``None`` for invalid input.
        """
        if not isinstance(query, str) or query == '':
            return None
        query_embedding = np.array(self.model.encode([query], normalize=True), dtype=np.float32)
        distances, indices = self.knowledge_database.search(query_embedding, self.indicators)
        distances, indices = distances[0], indices[0]
        # FAISS pads with -1 when the index holds fewer vectors than requested.
        valid = [i for i in range(self.indicators) if indices[i] > -1]
        if not valid:
            # Fix: the original passed an empty pair list to compute_score
            # when the index was empty.
            return []
        pairs = [(query, self.chunk_meta[str(indices[i])]['content']) for i in valid]
        rerank_score = self.model.compute_score(pairs)
        ranked = sorted(
            zip((indices[i] for i in valid), (distances[i] for i in valid), rerank_score),
            key=lambda item: item[2],
            reverse=True,
        )[:top_k]
        search_results = []
        for ind, _dis, score in ranked:
            if score < confidence_threshold:
                continue
            title = self.chunk_meta[str(int(ind))]['father']
            # Return the full parent document with split markers flattened out.
            content = self.text_meta[title].replace(self.split_char, ' ')
            search_results.append({
                "title": title,
                "content": content,
                "score": score,
            })
        return search_results
# One knowledge base per curated document collection.
knowledge_comp_rule_regulation_db = TextVectorDatabase(
    name='rule_regulation',
    text_root_dir=r'D:\code\repository\RAG资料库—上线\公司规章制度_语义块',
)  # company rules & regulations
knowledge_comp_propaganda_db = TextVectorDatabase(
    name='propaganda',
    text_root_dir=r'D:\code\repository\RAG资料库—上线\公司宣传及产品介绍_语义块',
)  # company publicity & product introductions
knowledge_comp_water_treatment_db = TextVectorDatabase(
    name='water_treatment',
    text_root_dir=r'D:\code\repository\RAG资料库—上线\污水处理工艺知识_语义块',
)  # wastewater-treatment process knowledge
knowledge_comp_project_plan_db = TextVectorDatabase(
    name='project_plan',
    text_root_dir=r'D:\code\repository\RAG资料库—上线\污水处理项目方案_语义块',
)  # wastewater-treatment project plans
knowledge_comp_operation_report_db = TextVectorDatabase(
    name='operation_report',
    text_root_dir=r'D:\code\repository\RAG资料库—上线\水厂运行报告_语义块',
)  # water-plant operation reports
if __name__ == '__main__':
    # Smoke-test each knowledge base with a representative query.
    _smoke_checks = [
        (knowledge_comp_rule_regulation_db, '金科环境的实习期有多长?'),
        (knowledge_comp_propaganda_db, '介绍一下新水岛产品'),
        (knowledge_comp_water_treatment_db, '反渗透RO如何计算'),
        (knowledge_comp_project_plan_db, '请给我一份新水岛实施项目方案大纲'),
        (knowledge_comp_operation_report_db, '水厂运行周报告'),
    ]
    for db, question in _smoke_checks:
        res = db.search(question, top_k=3)
        print(len(res), res)