|
|
@@ -0,0 +1,464 @@
|
|
|
+import os
|
|
|
+
|
|
|
+from sympy.solvers.diophantine.diophantine import equivalent
|
|
|
+
|
|
|
+script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
+import sys
|
|
|
+sys.path.append(script_dir)
|
|
|
+import pandas as pd
|
|
|
+import jieba
|
|
|
+import jieba.posseg as pseg
|
|
|
+import re
|
|
|
+import numpy as np
|
|
|
+import json
|
|
|
+import textdistance
|
|
|
+import faiss
|
|
|
+from remote_model import RemoteBGEModel
|
|
|
+
|
|
|
+
|
|
|
+class PLCMatch:
|
|
|
+ """通过关键词+语义相似度的方式,从用户输入中匹配PLC点位"""
|
|
|
+ def __init__(self, project_id:int):
|
|
|
+ # 水厂id
|
|
|
+ self.project_id = str(project_id)
|
|
|
+ # 路径
|
|
|
+ self.script_dir = os.path.dirname(os.path.abspath(__file__)) # 脚本绝对路径
|
|
|
+ # 水厂的词典根路径
|
|
|
+ self.plc_dict_root_dir = os.path.join(self.script_dir, f'plc_dictionary/{self.project_id}_plc_dictionary')
|
|
|
+ # 读取pcl点位文件,生成name-code映射字典
|
|
|
+ self.name_2_code_dict = self.__read_pcl()
|
|
|
+
|
|
|
+ # 加载用户自定义词典,添加到jieba词库
|
|
|
+ user_dictionary_dir = os.path.join(self.script_dir, 'user_maintain_dictionary', 'jieba_words')
|
|
|
+ user_dict_list = [os.path.join(user_dictionary_dir, _) for _ in os.listdir(user_dictionary_dir) if _.split('.')[-1] == 'txt'] # 用户词典
|
|
|
+ self.user_dict_list = user_dict_list
|
|
|
+ self.__load_user_dict()
|
|
|
+
|
|
|
+ # 生成二级字典
|
|
|
+ self.dict_level_2 = self.__make_level_two_dictionary()
|
|
|
+
|
|
|
+ # 生成一级字典
|
|
|
+ self.dict_level_1 = self.__make_level_one_dictionary()
|
|
|
+
|
|
|
+ # 等价词映射表
|
|
|
+ self.equivalent_wordmap_txt = os.path.join(self.script_dir,'user_maintain_dictionary','equivalent_words', 'equivalent_wordmap.txt')
|
|
|
+ self.dict_equivalent_wordmap = self.__construct_equivalent_wordmap()
|
|
|
+
|
|
|
+ # 生成知识库,PLC点位数据库中文字段
|
|
|
+ # 加载bge-m3和bge-reranker远程模型
|
|
|
+ self.plc_database_name_template_list = list(self.name_2_code_dict.keys())
|
|
|
+ self.model = RemoteBGEModel('dev')
|
|
|
+ self.knowledge = self.__load_faiss_database()
|
|
|
+
|
|
|
+
|
|
|
+ def __load_faiss_database(self):
|
|
|
+ """从本地加载向量数据库"""
|
|
|
+ # 水厂的数据库字段知识库
|
|
|
+ faiss_path = os.path.join(self.plc_dict_root_dir, f'{self.project_id}_knowledge.faiss')
|
|
|
+ # 尝试从本地加载
|
|
|
+ if os.path.exists(faiss_path):
|
|
|
+ print('PLC点位查询功能从本地加载点位字段向量知识库...')
|
|
|
+ return faiss.read_index(faiss_path)
|
|
|
+
|
|
|
+ # 如果不存在就尝试重新创建
|
|
|
+ # 首先,我们需要拿到数据库的点位名称,可以直接从name-code映射字典当中获取
|
|
|
+ plc_database_name_template_list = self.plc_database_name_template_list
|
|
|
+ # 调用远程embedding模型,one by one 地处理,远程模型通过配置参数进行归一化
|
|
|
+ embeddings = [self.model.encode([temp], normalize=True)[0] for temp in plc_database_name_template_list]
|
|
|
+ for _ in embeddings:
|
|
|
+ if _ is None:
|
|
|
+ raise RuntimeError('为plc数据库中文字段构建向量数据库时发生异常,embeddings不能存在None')
|
|
|
+ # 要求embeddings是一个二维矩阵,类型为float32
|
|
|
+ embeddings = np.array(embeddings, dtype=np.float32)
|
|
|
+ # 创建 FAISS 索引
|
|
|
+ dimension = embeddings[0].shape[0]
|
|
|
+ local_faiss = faiss.IndexFlatIP(dimension) # 建立内积索引
|
|
|
+ local_faiss.add(embeddings) # 添加索引
|
|
|
+ # 保存未来使用
|
|
|
+ faiss.write_index(local_faiss, faiss_path)
|
|
|
+ return local_faiss
|
|
|
+
|
|
|
+
|
|
|
+ def __read_pcl(self):
|
|
|
+ """
|
|
|
+ 读取pcl文件,生成name2code词典
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ # name-code映射词典路径
|
|
|
+ dict_name2code_path = os.path.join(self.plc_dict_root_dir, f'{self.project_id}_dict_name_2_code.json')
|
|
|
+ # 尝试从本地加载name-code映射字典
|
|
|
+ if os.path.exists(dict_name2code_path):
|
|
|
+ with open(dict_name2code_path, 'r', encoding='utf-8') as f:
|
|
|
+ dict_name2code = json.load(f)
|
|
|
+ return dict_name2code
|
|
|
+
|
|
|
+ # 如果本地没有就重新生成
|
|
|
+ # 检查点位文件是否存在
|
|
|
+ pcl_file_path = os.path.join(self.plc_dict_root_dir, f'{self.project_id}_点位.xlsx') # 点位文件路径
|
|
|
+ if not os.path.exists(pcl_file_path):
|
|
|
+ raise FileNotFoundError(f'{pcl_file_path} does not exist')
|
|
|
+ # 读点位
|
|
|
+ points = pd.read_excel(pcl_file_path)
|
|
|
+ # 列名称,name | code
|
|
|
+ column_label_alias, column_label_code = points.columns.tolist()
|
|
|
+ # 中英文匹配
|
|
|
+ names = points.loc[:, column_label_alias].to_numpy()
|
|
|
+ codes = points.loc[:, column_label_code].to_numpy()
|
|
|
+ # 对齐命名规范, 按照中荷水厂命名风格,将1#UF或1#RO统一替换为UF1,RO1,将所有反渗透文字替换为RO,所有超滤文字替换为UF
|
|
|
+ names = [s.replace('超滤','UF').replace('反渗透','RO') for s in names]
|
|
|
+ names = [self.field_align(s) for s in names]
|
|
|
+ # 名到英文的字典
|
|
|
+ dict_name2code = dict(zip(names, codes))
|
|
|
+ # name-code映射字典保存到本地文件
|
|
|
+ with open(dict_name2code_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(dict_name2code, f, ensure_ascii=False)
|
|
|
+ return dict_name2code
|
|
|
+
|
|
|
+ def __load_user_dict(self):
|
|
|
+ """加载用户词典,添加到jieba词库"""
|
|
|
+ # 删除
|
|
|
+ jieba.del_word('反渗透')
|
|
|
+ jieba.del_word('超滤')
|
|
|
+ for user_dict_txt in self.user_dict_list:
|
|
|
+ # 检查文件是否存在
|
|
|
+ if not os.path.exists(user_dict_txt):
|
|
|
+ raise FileNotFoundError(f'{user_dict_txt} does not exist')
|
|
|
+ # 检查文件后缀名是否合法
|
|
|
+ if os.path.splitext(user_dict_txt)[1] != '.txt':
|
|
|
+ continue
|
|
|
+ # 分词库加载用户字典
|
|
|
+ jieba.load_userdict(user_dict_txt)
|
|
|
+
|
|
|
+ def __construct_equivalent_wordmap(self):
|
|
|
+ """构建等价词汇映射表,等价词汇的使用方式是将备查词的所有等效说法都纳入备查序列,从而保证了搜索的高召回率"""
|
|
|
+ # 检查文件是否存在
|
|
|
+ equivalent_wordmap_path = os.path.join(self.script_dir, 'user_maintain_dictionary','equivalent_words', 'dict_equivalent_wordmap.json')
|
|
|
+ if os.path.exists(equivalent_wordmap_path):
|
|
|
+ with open(equivalent_wordmap_path, 'r', encoding='utf-8') as f:
|
|
|
+ equivalent_wordmap = json.load(f)
|
|
|
+ return equivalent_wordmap
|
|
|
+ # 如果本地不存在等价词典json文件,那么就尝试创建
|
|
|
+ if not os.path.exists(self.equivalent_wordmap_txt):
|
|
|
+ raise FileNotFoundError(f'{self.equivalent_wordmap_txt} does not exist')
|
|
|
+
|
|
|
+ with open(self.equivalent_wordmap_txt, 'r', encoding='utf-8') as f:
|
|
|
+ all_lines = [_.strip() for _ in f.readlines()]
|
|
|
+ # 创建等价词汇映射表
|
|
|
+ dict_equi_wordmap = {}
|
|
|
+ for line in all_lines:
|
|
|
+ split_list = line.split('=')
|
|
|
+ for i in range(len(split_list)):
|
|
|
+ dict_equi_wordmap[split_list[i]] = split_list
|
|
|
+ with open(equivalent_wordmap_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(dict_equi_wordmap,f,ensure_ascii=False)
|
|
|
+ return dict_equi_wordmap
|
|
|
+
|
|
|
+ def __make_level_two_dictionary(self):
|
|
|
+ """创建二级字典,对点位所有字段进行正则匹配中文,将中文一样的字段聚合为同一个字典键值对,键为正则提取的中文字符"""
|
|
|
+ group_dict = {}
|
|
|
+ # 尝试从本地加载二级字典
|
|
|
+ dict_level2_dict_path = os.path.join(self.plc_dict_root_dir, f'{self.project_id}_dict_level_2.json')
|
|
|
+ if os.path.exists(dict_level2_dict_path):
|
|
|
+ with open(dict_level2_dict_path, 'r', encoding='utf-8') as f:
|
|
|
+ group_dict = json.load(f)
|
|
|
+ return group_dict
|
|
|
+
|
|
|
+ if self.name_2_code_dict is None:
|
|
|
+ raise ValueError(f'name_2_code_dict is None', self.name_2_code_dict)
|
|
|
+ data = self.name_2_code_dict.keys()
|
|
|
+
|
|
|
+ # 创建二级字典
|
|
|
+ for item in data:
|
|
|
+ k = re.sub(r'[^\u4e00-\u9fa5]', '', item)
|
|
|
+ # 处理没有汉字的字段
|
|
|
+ if k == '':
|
|
|
+ k = "无"
|
|
|
+ if k not in group_dict.keys():
|
|
|
+ group_dict[k] = [item]
|
|
|
+ else:
|
|
|
+ group_dict[k].append(item)
|
|
|
+
|
|
|
+ # 保存二级字典到本地
|
|
|
+ with open(dict_level2_dict_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(group_dict, f, ensure_ascii=False)
|
|
|
+ return group_dict
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def cut_compair(arr_a: str, arr_b: str, condition='nz') -> str:
|
|
|
+ """
|
|
|
+ :param condition: 词性
|
|
|
+ :param arr_a:
|
|
|
+ :param arr_b:
|
|
|
+ :return: 第一个相同nz词
|
|
|
+ """
|
|
|
+ # a: w1,f1 w2,f2 w3, f3
|
|
|
+ # b: w1,f1 w2,f2 w3, f3
|
|
|
+
|
|
|
+ cut_arr_a = [list(_) for _ in pseg.lcut(arr_a)]
|
|
|
+ cut_arr_b = [list(_) for _ in pseg.lcut(arr_b)]
|
|
|
+ for i in range(len(cut_arr_a)):
|
|
|
+ for j in range(i, len(cut_arr_b)):
|
|
|
+ # 只比较nz词性
|
|
|
+ if cut_arr_a[i][1] != condition or cut_arr_b[j][1] != condition:
|
|
|
+ continue
|
|
|
+ if cut_arr_a[i][0] == cut_arr_b[j][0] and cut_arr_a[i][1] == cut_arr_b[j][1]:
|
|
|
+ return cut_arr_a[i][0]
|
|
|
+ return ''
|
|
|
+
|
|
|
+ def __make_level_one_dictionary(self):
|
|
|
+ """创建一级字典"""
|
|
|
+ group_dict = {} # 存放二次分组的结果
|
|
|
+ # 尝试从本地加载一级字典
|
|
|
+ dict_level_1_path = os.path.join(self.plc_dict_root_dir, f'{self.project_id}_dict_level_1.json')
|
|
|
+ if os.path.exists(dict_level_1_path):
|
|
|
+ with open(dict_level_1_path, 'r', encoding='utf-8') as f:
|
|
|
+ group_dict = json.load(f)
|
|
|
+ return group_dict
|
|
|
+
|
|
|
+ if self.dict_level_2.keys() is None:
|
|
|
+ raise ValueError(f'dict_lev2 is None', self.dict_level_2)
|
|
|
+ # 提取二级字典的所有key
|
|
|
+ data = self.dict_level_2.keys()
|
|
|
+
|
|
|
+ # 如果不存在就重新生成一级字典
|
|
|
+ # 根据用户词典进行分词,筛选出所有带nz词的字段
|
|
|
+ no_nz_list = [] # 没有nz词的字段
|
|
|
+ nz_list = [] # 有nz词的字段
|
|
|
+ for item in data:
|
|
|
+ # 判断是否存在nz名词
|
|
|
+ is_exist_n = False
|
|
|
+ for w, f in pseg.lcut(item):
|
|
|
+ if f == 'nz': # 查看词性
|
|
|
+ is_exist_n = True
|
|
|
+ break
|
|
|
+ if is_exist_n: # 存在词
|
|
|
+ nz_list.append(item)
|
|
|
+ else: # 不存在nz词
|
|
|
+ no_nz_list.append(item)
|
|
|
+
|
|
|
+ # 聚合具有相同nz名词的字段
|
|
|
+ while len(nz_list) > 0:
|
|
|
+ pos = [1 for _ in range(len(nz_list))] # 0表示不被取,1表示需要被取,默认都要被取,用来更新nz_list给下次判断使用
|
|
|
+ pos[0] = 0 # 标记第一个单词为不需要处理
|
|
|
+ for i in range(len(nz_list)):
|
|
|
+ # 查看是否存在相同的nz词
|
|
|
+ same_nz_word = self.cut_compair(nz_list[0], nz_list[i])
|
|
|
+ if same_nz_word:
|
|
|
+ # 执行聚合
|
|
|
+ if same_nz_word not in group_dict.keys():
|
|
|
+ # 首次聚合,与自身比较,创建自身类别
|
|
|
+ group_dict[same_nz_word] = [nz_list[i]]
|
|
|
+ else:
|
|
|
+ group_dict[same_nz_word].append(nz_list[i])
|
|
|
+
|
|
|
+ pos[i] = 0
|
|
|
+ # 处理完一趟就要变更nz_list
|
|
|
+ nz_list = np.array(nz_list)[np.array(pos, dtype=np.bool)].tolist()
|
|
|
+
|
|
|
+ # 聚合不包含nz的名词, 单独占一个类别
|
|
|
+ for item in no_nz_list:
|
|
|
+ group_dict[item] = [item]
|
|
|
+
|
|
|
+ with open(dict_level_1_path, 'w', encoding='utf-8') as f:
|
|
|
+ json.dump(group_dict, f, ensure_ascii=False)
|
|
|
+
|
|
|
+ return group_dict
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def field_align(input_str:str)->str:
|
|
|
+ """按照锡山中荷命名规范对齐字段,1#UF替换为UF1,1#RO替换为RO1,保持统一"""
|
|
|
+ sources_uf = re.findall(r'\d+#UF', input_str, re.IGNORECASE) # 匹配1#UF
|
|
|
+ sources_ro = re.findall(r'\d+#RO', input_str, re.IGNORECASE) # 匹配1#RO
|
|
|
+ sources = sources_uf + sources_ro
|
|
|
+ for sou in sources:
|
|
|
+ number_, flag_ = sou.split('#')
|
|
|
+ input_str = input_str.replace(sou, flag_.upper() + number_) # 统一转为大写
|
|
|
+ return input_str
|
|
|
+
|
|
|
+ @ staticmethod
|
|
|
+ def quicksort_up_part(arr:list, start:int, end:int)-> int:
|
|
|
+ """升序排序"""
|
|
|
+ # 双指针
|
|
|
+ low = start
|
|
|
+ high = end
|
|
|
+ pivot = arr[start][1] # 基准值
|
|
|
+ # 大数放在基准值右边,小数放在基准值左边
|
|
|
+ while low < high:
|
|
|
+ # 先从右向左找比基准值小的
|
|
|
+ while low< high and arr[high][1] >= pivot:
|
|
|
+ high -= 1
|
|
|
+ # 此时high指向值小于基准值,交换
|
|
|
+ if low < high:
|
|
|
+ arr[low], arr[high] = arr[high], arr[low]
|
|
|
+ low +=1
|
|
|
+ # 现在开始从左向右找,比基准值大的数
|
|
|
+ while low < high and arr[low][1] <= pivot:
|
|
|
+ low += 1
|
|
|
+ # 此时low指向值大于基准值,交换
|
|
|
+ if low < high:
|
|
|
+ arr[high], arr[low] = arr[low], arr[high]
|
|
|
+ high -= 1
|
|
|
+ return low
|
|
|
+
|
|
|
+
|
|
|
+ def quicksort_up(self, arr:list, start:int, end:int):
|
|
|
+ """按照元组第二个元素值大小进行升序排序"""
|
|
|
+ if start >= end:
|
|
|
+ return
|
|
|
+ # 先排一次获得基准值位置
|
|
|
+ mid = self.quicksort_up_part(arr, start, end)
|
|
|
+ # 排左面
|
|
|
+ self.quicksort_up(arr, start, mid - 1)
|
|
|
+ # 排右面
|
|
|
+ self.quicksort_up(arr, mid + 1, end)
|
|
|
+
|
|
|
+ def words_similarity_score_sorted(self, query:str, candidates:list)->list:
|
|
|
+ """计算输入语句与候选词的相似度并按照相似度分值进行排序"""
|
|
|
+ # 选择算法(示例使用Levenshtein,归一化到0-1)
|
|
|
+ candidates = candidates.copy()
|
|
|
+ jarowinkler = textdistance.JaroWinkler()
|
|
|
+ key_score_list = [(candidate, jarowinkler.normalized_similarity(query, candidate)) for candidate in candidates]
|
|
|
+ self.quicksort_up(key_score_list, 0, len(key_score_list) - 1) # 升序排序
|
|
|
+ key_sorted_list = [tuple_element[0] for tuple_element in key_score_list] # 取出key
|
|
|
+ key_sorted_list = key_sorted_list[::-1] # 反转,变为降序
|
|
|
+ return key_sorted_list
|
|
|
+
|
|
|
+ def words_similarity_score_sorted_v2(self, query:str, candidates:list)->list:
|
|
|
+ """通过rerank的方式为候选词进行相似度排序"""
|
|
|
+ # 调用远程reranker模型
|
|
|
+ n = len(candidates) # 候选词数量
|
|
|
+ group_query = [(query, i) for i in candidates]
|
|
|
+ score = self.model.compute_score(group_query)
|
|
|
+ key_score_list = [(candidates[i], score[i]) for i in range(n)]
|
|
|
+ self.quicksort_up(key_score_list, 0, len(key_score_list) - 1) # 升序排序
|
|
|
+ key_sorted_list = [tuple_element[0] for tuple_element in key_score_list] # 取出key
|
|
|
+ key_sorted_list = key_sorted_list[::-1] # 反转,变为降序
|
|
|
+ return key_sorted_list
|
|
|
+
|
|
|
+ def match_v2_on(self, promt: str,is_agent:bool=False):
|
|
|
+ """
|
|
|
+ 模糊匹配v2
|
|
|
+ :param is_agent:
|
|
|
+ :param promt:
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ print("=" * 50)
|
|
|
+ # 命名风格转换
|
|
|
+ print("原始查询:", promt)
|
|
|
+ promt = promt.replace('超滤', 'UF').replace('反渗透', 'RO').replace('号', '#').replace('组', '#')
|
|
|
+ promt = self.field_align(promt)
|
|
|
+ print("转换查询:", promt)
|
|
|
+ # 输入分词
|
|
|
+ nz_words = []
|
|
|
+ for w, f in pseg.lcut(promt):
|
|
|
+ print(f'{w}({f})', end="")
|
|
|
+ if f == 'nz':
|
|
|
+ nz_words.append(w)
|
|
|
+ print('\n备查nz词:', nz_words)
|
|
|
+
|
|
|
+ # 处理专有名词的等价词,为了保证高召回率,我们将备查词的所有等价说法都放入备查序列
|
|
|
+ equivalent_words = []
|
|
|
+ for nz_idx, nz in enumerate(nz_words):
|
|
|
+ # 首先判断nz词是否在等价词汇表中,如果不在根本无法替换
|
|
|
+ if nz in self.dict_equivalent_wordmap.keys():
|
|
|
+ # 然后把等价的说法都添加进去就好了
|
|
|
+ equivalent_words = self.dict_equivalent_wordmap.get(nz, [])
|
|
|
+ if equivalent_words:
|
|
|
+ nz_words += equivalent_words
|
|
|
+ nz_words = list(set(nz_words))
|
|
|
+ print('等价备查nz词:', nz_words)
|
|
|
+ del equivalent_words
|
|
|
+
|
|
|
+ # 进行一级查询,根据nz词是否包含于词典
|
|
|
+ query_level_one = []
|
|
|
+ for i in range(len(nz_words)): # 为第i个nz词进行初次匹配
|
|
|
+ result = []
|
|
|
+ # 如果nz词包含在一级词典中就算匹配成功
|
|
|
+ for dict_level_1_key in self.dict_level_1.keys():
|
|
|
+ if nz_words[i] in dict_level_1_key: # 如果nz词包含在一级词典内
|
|
|
+ result+= self.dict_level_1.get(dict_level_1_key)
|
|
|
+ query_level_one.append(result) # 放入一级查询结果中
|
|
|
+
|
|
|
+ # 进行二级查询
|
|
|
+ query_level_two = []
|
|
|
+ for idx_nz, i_nz_query_result in enumerate(query_level_one): # 遍历每个nz词的查询结果
|
|
|
+ result = [] # 为第i个nz词进行二次匹配
|
|
|
+ # 如果第i个nz词一级查询不为空
|
|
|
+ if i_nz_query_result: # 第i个nz词的查询结果list
|
|
|
+ for res_word_level_one in i_nz_query_result:
|
|
|
+ if res_word_level_one in self.dict_level_2.keys():
|
|
|
+ result += self.dict_level_2.get(res_word_level_one) # self.dict_level_2的value本身就是字典,所以用+=拼接
|
|
|
+ # 虽然一级查询失败,但是并不意味着映射词典里没有,因为一级词典忽略英文。
|
|
|
+ else: # 如果一级查询失败,就直接在name2code字典中查询
|
|
|
+ if nz_words[idx_nz] in self.name_2_code_dict.keys():# 如果第i个nz词在2级词典,就直接添加到结果中
|
|
|
+ result.append(nz_words[idx_nz])
|
|
|
+ # 如果第i个nz词的一级查询结果为空,则添加空列表占位
|
|
|
+ query_level_two.append(result)
|
|
|
+
|
|
|
+ # 常规精确匹配结束,如果匹配成功,结构为二维列表,否则为空列表
|
|
|
+ matched_keys = query_level_two # 获取已匹配的字段
|
|
|
+ # 备查词合并,我们约定所有备查词进行统一的查询,后面怎么用这些结果取决于外部的应用,对于agent模式,将会输出许多结果,对月非agent只会输出概率最高的结果
|
|
|
+ tem_matched_keys = []
|
|
|
+ for item in matched_keys:
|
|
|
+ tem_matched_keys += item
|
|
|
+ matched_keys = [list(set(tem_matched_keys))]
|
|
|
+ del tem_matched_keys
|
|
|
+
|
|
|
+ # 如果精确匹配失败,没有匹配到任何结果则按照语义进行模糊匹配,返回满足条件的置信度最高的结果
|
|
|
+ # if not nz_words or ([] in matched_keys):
|
|
|
+ # 比起手动维护词典,我们更相信语义相似度
|
|
|
+ top_k = 5
|
|
|
+ confi = 0.2 # 置信度阈值
|
|
|
+ print(f'进入模糊匹配,召回Top:{top_k} 置信度阈值:{confi}...')
|
|
|
+ # 调用远程bge-m3模型进行embedding
|
|
|
+ query_embedding = np.array(self.model.encode([promt], normalize=True), dtype=np.float32) # 要求query_embedding是一个二维矩阵,形状为(1, 1024)
|
|
|
+ distances, indices = self.knowledge.search(query_embedding, top_k)
|
|
|
+ group_query = [(promt, self.plc_database_name_template_list[indices[0][i]]) for i in range(top_k)]
|
|
|
+ # 我们更愿意相信bge,因此把词典关键词匹配的结果一并放进去重排序
|
|
|
+ group_query_manuel = [(promt, k) for keys in matched_keys for k in keys]
|
|
|
+ group_query += group_query_manuel
|
|
|
+ del group_query_manuel
|
|
|
+ group_query = list(set(group_query)) # 去重
|
|
|
+ # 调用远程bge-reranker模型
|
|
|
+ score = self.model.compute_score(group_query)
|
|
|
+ rerank_result = sorted([(group_query[i][1], score[i]) for i in range(len(group_query))], key=lambda x: x[1], reverse=True)
|
|
|
+ print(F'打印前top{top_k}候选词结果:', rerank_result[:top_k])
|
|
|
+ print(f'首元素模糊匹配到{rerank_result[0][0]}, 置信度为{rerank_result[0][1]}')
|
|
|
+ # matched_keys 为最终结果,保持形状为二维列表
|
|
|
+ matched_keys = [[i[0] for i in rerank_result]]
|
|
|
+ # 每个匹配结果的置信度
|
|
|
+ matched_keys_score = [[i[1] for i in rerank_result]]
|
|
|
+
|
|
|
+ # 为结果创建映射字典
|
|
|
+ result_list = []
|
|
|
+ for i_nz_keys in matched_keys:
|
|
|
+ result_list.append([{key: self.name_2_code_dict.get(key)} for key in i_nz_keys])
|
|
|
+ print(f"查询到{len([_ for _ in result_list if _])}个结果:")
|
|
|
+
|
|
|
+ if not is_agent:
|
|
|
+ # 非agent模式每个匹配结果只取第一个元素的英文
|
|
|
+ tem_list = []
|
|
|
+ for res in result_list:
|
|
|
+ if res:
|
|
|
+ for k, v in res[0].items(): # 每个nz词的查询结果都是一个list,每个list可能包含多个字典
|
|
|
+ tem_list.append(f'{k}:{v}')
|
|
|
+ result_list = tem_list
|
|
|
+ print('以非agent模式返回:', result_list)
|
|
|
+ return result_list
|
|
|
+
|
|
|
+ print('以agent模式返回:', result_list)
|
|
|
+ print('='*50)
|
|
|
+ return result_list, matched_keys_score
|
|
|
+
|
|
|
if __name__ == '__main__':
    # Manual smoke run: build the matcher for plant 92 and run one query in
    # non-agent mode (pass is_agent=True to get candidates with scores).
    pj = 92
    pcl_helper = PLCMatch(project_id=pj)
    my_promt = "我想要查询锡山中荷进水电导率"
    query_res = pcl_helper.match_v2_on(my_promt, is_agent=False)
|
|
|
+
|
|
|
+
|
|
|
+
|