# patch_intent_cls_local.py
# Load models from local disk for inference; both bge and the reranker use local models.
from FlagEmbedding import FlagAutoModel, FlagReranker
from intent_description_template import template, intent_code
import torch
import os
import faiss
import numpy as np
import re
  9. script_dir = os.path.dirname(os.path.abspath(__file__))
class IntentRecognizer:
    """Intent recognizer backed by local models.

    Pipeline: a local bge-m3 embedding model encodes queries, a FAISS
    inner-product index over intent-description templates retrieves
    candidates, and a local bge-reranker-v2-m3 cross-encoder reranks them.
    Both models are placed on cuda:0 with fp16.
    """

    def __init__(self):
        # Build metadata: map each description sentence -> its custom intent code.
        self.template_meta = {}
        for k, v in template.items():
            for desc in v:
                self.template_meta[desc] = k
        # Flat list of all description sentences; the list position doubles
        # as the row id inside the FAISS index built below.
        self.template_meta_list = list(self.template_meta.keys())
        # Load the embedding model from local files only (fp16, CLS pooling, cuda:0).
        self.model = FlagAutoModel.from_finetuned(os.path.join(script_dir, "bge-m3"),
                                                  query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
                                                  local_files_only=True,
                                                  use_fp16=True,
                                                  pooling_method="cls",
                                                  devices=["cuda:0"])
        self.reranker = FlagReranker(os.path.join(script_dir, 'bge-reranker-v2-m3'), use_fp16=True, local_files_only=True, devices=["cuda:0"])
        # Warm up both models so the first real request does not pay cold-start latency.
        print("模型预热中...")
        self.model.encode(["这是一段预热文字,首次推理通过预测保证后续推理的稳定性和性能。"])
        self.reranker.compute_score([("这是一段预热文字,首次推理通过预测保证后续推理的稳定性和性能。",
                                      "这是一段预热文字,首次推理通过预测保证后续推理的稳定性和性能。")])
        self.print_gpu()
        # Load the vector index from disk, building and persisting it on first run.
        self.database_index = None
        database_path = os.path.join(script_dir, "intent_index.faiss")
        if not os.path.exists(database_path):
            # FAISS requires a 2-D float32 matrix: take the dense vectors and cast.
            embeddings = self.model.encode(self.template_meta_list)['dense_vecs'].astype(np.float32)
            faiss.normalize_L2(embeddings)  # L2-normalize so inner product equals cosine similarity
            # Create FAISS index
            dimension = embeddings[0].shape[0]
            self.database_index = faiss.IndexFlatIP(dimension)  # exact inner-product index
            self.database_index.add(embeddings)  # row order matches template_meta_list
            # Save for future use
            faiss.write_index(self.database_index, database_path)
        if self.database_index is None:
            self.database_index = faiss.read_index(database_path)

    @staticmethod
    def print_gpu():
        # Print current CUDA memory usage (allocated / reserved, GB) if a GPU is present.
        if torch.cuda.is_available():
            print(f"allocated:{torch.cuda.memory_allocated()/1024**3:.2f}GB", end=' ')
            print(f"reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB")

    # Application inference stage
    def pick_out(self, query, top_k):
        """Retrieve top_k candidate templates for `query`, rerank them with
        the cross-encoder, and return (confidence, result) where
        `confidence` is the list of normalized rerank scores (best first)
        and `result` the corresponding intent codes from `intent_code`.
        """
        # encode() yields a 2-D matrix for the single query; presumably
        # shape (1, 1024) for bge-m3 — TODO confirm.
        query_embedding = self.model.encode([query])['dense_vecs'].astype(np.float32)
        faiss.normalize_L2(query_embedding)
        distances, indices = self.database_index.search(query_embedding, top_k)
        # Pair the query with each retrieved description for the reranker.
        group_query = [(query, self.template_meta_list[indices[0][i]]) for i in range(top_k)]
        score = self.reranker.compute_score(group_query, normalize=True)
        # Tuples of (distance, index, rerank_score), sorted best rerank score first.
        rerank_result = sorted([(distances[0][_], indices[0][_], score[_]) for _ in range(top_k)], key=lambda x: x[2],
                               reverse=True)  # distance, indices, rerank_score
        score_idx = 2  # rerank relevance position within each tuple
        meta_idx = 1  # template index position
        similarity_idx = 0  # vector similarity position
        print("***检索结果***:")
        # Print candidates from worst to best for easy visual scanning.
        for i in range(top_k - 1, -1, -1):
            print(
                f"***{i} 相关度:{rerank_result[i][score_idx]:.2f} 相似度:{rerank_result[i][similarity_idx]:.2f} 意图:{self.template_meta.get(self.template_meta_list[rerank_result[i][meta_idx]])} 关联:{self.template_meta_list[rerank_result[i][meta_idx]]}")
        # Unpack the ranked results into intent major/minor class codes.
        result = []  # recognized intent codes
        confidence = []  # confidence scores
        for i in range(top_k):
            # Description sentence of this candidate.
            description = self.template_meta_list[rerank_result[i][meta_idx]]
            # Its custom intent code.
            custom_number = self.template_meta[description]
            # Map to the major/minor class label.
            result.append(intent_code[custom_number])
            # Keep the rerank score as the confidence.
            confidence.append(rerank_result[i][score_idx])
        return confidence, result
  83. def quick_answer_q_a_v2(question, is_english=0):
  84. """快速问答分支,赋值metric"""
  85. metric = ''
  86. question += ' ' # 深拷贝
  87. if is_english == 0:
  88. pattern_res_open_qa = re.findall("(开启|打开).*?问答", question)
  89. pattern_res_close_qa = re.findall("(关闭|关掉).*?问答", question)
  90. else:
  91. question = question.strip().lower()
  92. pattern_res_open_qa = re.findall(
  93. r'\b(?:open|show me|enable)\b\W*.*?(?:quiz\W+with\W+prizes|enable\W+award\W*-\W*winning\W+q\W*&\W*a)',
  94. question)
  95. pattern_res_close_qa = re.findall(
  96. r'\b(?:close|disable)\b\W*.*?(?:quiz\W+with\W+prizes|enable\W+award\W*-\W*winning\W+q\W*&\W*a)', question)
  97. if len(pattern_res_open_qa) > 0:
  98. metric = "openQandA"
  99. if len(pattern_res_close_qa) > 0:
  100. metric = "closeQandA"
  101. return metric
# Singleton pattern: constructed once at import time (loads both models onto cuda:0
# and builds/loads the FAISS index as a side effect).
recognizer_bge = IntentRecognizer()