|
|
@@ -67,6 +67,7 @@ class RemoteBGEModel:
|
|
|
for i, content in enumerate(texts):
|
|
|
if not isinstance(content, str):
|
|
|
raise ValueError(f"Text must not be empty, pos:{i}, content{content}")
|
|
|
+ # data = {"inputs":texts, "normalize":normalize, "return_dense": True, "return_sparse": True, "return_colbert_vecs": False}
|
|
|
data = {"inputs":texts, "normalize":normalize}
|
|
|
|
|
|
return self._access_remote_model(
|
|
|
@@ -74,8 +75,8 @@ class RemoteBGEModel:
|
|
|
data=data
|
|
|
)
|
|
|
|
|
|
- def compute_score(self, pairs: List[Tuple[str, str]]):
|
|
|
- """调用远程bge-reranker计算相关性, 并按照原位置输出分数"""
|
|
|
+ def compute_score(self, pairs: List[Tuple[str, str]], rerank_limit:int=32):
|
|
|
+ """调用远程bge-reranker计算相关性"""
|
|
|
# 类型检查
|
|
|
if not isinstance(pairs, list):
|
|
|
raise TypeError("Pairs must be list",pairs)
|
|
|
@@ -99,18 +100,22 @@ class RemoteBGEModel:
|
|
|
if pairs[0][0] != pairs[1][0]:
|
|
|
raise ValueError("Pairs must have the same query", pairs)
|
|
|
texts = [t for q, t in pairs]
|
|
|
- data = {
|
|
|
- "query": pairs[0][0], # 对于bge-reranker,query字段可为空
|
|
|
- "texts": texts
|
|
|
- }
|
|
|
-
|
|
|
- # 返回rerank结果
|
|
|
- res = self._access_remote_model(
|
|
|
- url=self.reranker_url,
|
|
|
- data=data
|
|
|
- )
|
|
|
- # 按照原有位置输出score
|
|
|
- score = [_["score"] for _ in sorted(res, key=lambda x: x["index"])]
|
|
|
+ # 为了处理超过32个输入的情形,我们需要为texts进行分组,然后计算排序分数,并将分数按照index的顺序排序,最后拼接所有的结果作为最终的结果
|
|
|
+ # 分组
|
|
|
+ group_texts = [texts[i:i + rerank_limit] for i in range(0, len(texts), rerank_limit)]
|
|
|
+ score = []
|
|
|
+ for single_group in group_texts:
|
|
|
+ data = {
|
|
|
+ "query": pairs[0][0], # 对于bge-reranker,query字段可为空
|
|
|
+ "texts": single_group # 输入长度不能超过32, rerank_limit为远程模型有效的输入长度
|
|
|
+ }
|
|
|
+ # 多次调用远程模型,并返回rerank结果
|
|
|
+ res = self._access_remote_model(
|
|
|
+ url=self.reranker_url,
|
|
|
+ data=data
|
|
|
+ )
|
|
|
+ # 按照原有位置整理结果,并拼接不同组的分数
|
|
|
+ score.extend([_["score"] for _ in sorted(res, key=lambda x: x["index"])])
|
|
|
return score
|
|
|
|
|
|
|