jiyuhang 2 bulan lalu
induk
melakukan
1ceaaced05
3 mengubah file dengan 33 tambahan dan 15 penghapusan
  1. 1 1
      patch_intent_cls.py
  2. 19 14
      remote_model.py
  3. 13 0
      tem.py

+ 1 - 1
patch_intent_cls.py

@@ -14,7 +14,7 @@ class IntentRecognizer:
 
     def __init__(self):
         # 加载元数据
-        self.template_meta = {}
+        self.template_meta = {}  # 字典是有序的
         for k, v in template.items():
             for desc in v:
                 self.template_meta[desc] = k

+ 19 - 14
remote_model.py

@@ -67,6 +67,7 @@ class RemoteBGEModel:
             for i, content in enumerate(texts):
                 if not isinstance(content, str):
                     raise ValueError(f"Text must not be empty, pos:{i}, content{content}")
+        # data = {"inputs":texts, "normalize":normalize, "return_dense": True, "return_sparse": True, "return_colbert_vecs": False}
         data = {"inputs":texts, "normalize":normalize}
 
         return self._access_remote_model(
@@ -74,8 +75,8 @@ class RemoteBGEModel:
             data=data
         )
 
-    def compute_score(self, pairs: List[Tuple[str, str]]):
-        """调用远程bge-reranker计算相关性, 并按照原位置输出分数"""
+    def compute_score(self, pairs: List[Tuple[str, str]], rerank_limit:int=32):
+        """调用远程bge-reranker计算相关性"""
         # 类型检查
         if not isinstance(pairs, list):
             raise TypeError("Pairs must be list",pairs)
@@ -99,18 +100,22 @@ class RemoteBGEModel:
             if pairs[0][0] != pairs[1][0]:
                 raise ValueError("Pairs must have the same query", pairs)
         texts = [t for q, t in pairs]
-        data = {
-            "query": pairs[0][0],  # 对于bge-reranker,query字段可为空
-            "texts": texts
-        }
-
-        # 返回rerank结果
-        res = self._access_remote_model(
-            url=self.reranker_url,
-            data=data
-        )
-        # 按照原有位置输出score
-        score = [_["score"] for _ in sorted(res, key=lambda x: x["index"])]
+        # 为了处理超过32个输入的情形,我们需要为texts进行分组,然后计算排序分数,并将分数按照index的顺序排序,最后拼接所有的结果作为最终的结果
+        # 分组
+        group_texts = [texts[i:i + rerank_limit] for i in range(0, len(texts), rerank_limit)]
+        score = []
+        for single_group in group_texts:
+            data = {
+                "query": pairs[0][0],  # 对于bge-reranker,query字段可为空
+                "texts": single_group  # 输入长度不能超过32, rerank_limit为远程模型有效的输入长度
+            }
+            # 多次调用远程模型,并返回rerank结果
+            res = self._access_remote_model(
+                url=self.reranker_url,
+                data=data
+            )
+            # 按照原有位置整理结果,并拼接不同组的分数
+            score.extend([_["score"] for _ in sorted(res, key=lambda x: x["index"])])
         return score
 
 

+ 13 - 0
tem.py

@@ -0,0 +1,13 @@
+import faiss
+from remote_model import RemoteBGEModel
+import numpy as np
+
+model = RemoteBGEModel()
+
+embedding = np.array(model.encode(["hello world"]), dtype=np.float32)
+
+db = faiss.IndexFlatIP(embedding[0].shape[0])
+
+db.add(embedding)
+dis, idx =db.search(embedding, 5)
+print(dis, idx)