# version: 2025.12.04
"""HTTP client for remotely hosted BGE embedding (bge-m3) and reranker models.

Endpoint base URLs are read from ``url_config.json`` located next to this file.
"""
import requests
from typing import List, Tuple, Optional, Union
import os
import json
import time
import numpy as np
# NOTE(review): FlagAutoModel / FlagReranker are only referenced by the commented-out
# local comparison in ``__main__`` below; importing FlagEmbedding may be expensive.
from FlagEmbedding import FlagAutoModel, FlagReranker

# Directory containing this file; used to locate url_config.json.
script_dir = os.path.dirname(os.path.abspath(__file__))


class RemoteBGEModel:
    """Client for the remote BGE embedding and reranker HTTP services.

    Args:
        branch: Deployment whose endpoints to use (case-insensitive, surrounding
            whitespace ignored). ``'dev'``/``'test'`` read the ``dev_*`` config keys,
            ``'main'``/``'master'`` the ``master_*`` keys, ``'local'`` the ``local_*`` keys.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.
        max_retries: Number of attempts per remote call (at least one attempt is made).

    Raises:
        ValueError: If ``branch`` is not one of the supported names.
        FileNotFoundError: If ``url_config.json`` does not exist.
        KeyError: If the config lacks the URL keys for the selected branch.
    """

    # Accepted branch names -> key prefix used in url_config.json.
    _BRANCH_PREFIX = {
        'dev': 'dev',
        'test': 'dev',
        'main': 'master',
        'master': 'master',
        'local': 'local',
    }

    def __init__(self, branch: str = 'dev', timeout: int = 3, max_retries: int = 3):
        # Validate the branch name before touching the config file.
        self.branch = branch.strip().lower()
        if self.branch not in self._BRANCH_PREFIX:
            raise ValueError("Param 'branch' must be dev, test, master, main or local", branch)
        self.url_file = os.path.join(script_dir, 'url_config.json')
        self.embedding_url, self.reranker_url = self.load_url()
        self.timeout = timeout
        self.max_retries = max_retries
        # Both endpoints receive JSON request bodies.
        self.headers = {"Content-Type": "application/json"}

    def load_url(self) -> Tuple[str, str]:
        """Read ``self.url_file`` and build the embedding and reranker endpoint URLs.

        Returns:
            ``(embed_url, rerank_url)``: the branch's configured base URLs with
            ``/embed`` and ``/rerank`` appended respectively.

        Raises:
            FileNotFoundError: If the config file does not exist.
            KeyError: If ``<prefix>_embed_url`` or ``<prefix>_reranker_url`` is missing.
        """
        if not os.path.exists(self.url_file):
            raise FileNotFoundError("File not exist", self.url_file)
        with open(self.url_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        # Unknown branches fall back to the local endpoints, as the original else-branch did.
        prefix = self._BRANCH_PREFIX.get(self.branch, 'local')
        # Strip a trailing '/' so a base URL like "http://host:port/" does not yield "//embed".
        embed_url = json_data[f'{prefix}_embed_url'].rstrip('/') + '/embed'
        rerank_url = json_data[f'{prefix}_reranker_url'].rstrip('/') + '/rerank'
        return embed_url, rerank_url

    def _access_remote_model(self, url: str, data: dict) -> Optional[np.ndarray]:
        """POST ``data`` as JSON to ``url`` and return the decoded JSON body.

        Shared by :meth:`encode` (embedding endpoint) and :meth:`compute_score`
        (reranker endpoint). The request is attempted up to ``self.max_retries``
        times (at least once). Each attempt is bounded by ``self.timeout`` seconds,
        and there is a 1 s pause between failed attempts.

        Returns:
            ``np.array(response.json())`` for the first HTTP 200 response, or
            ``None`` if every attempt failed (network error, non-200 status, or
            an unparseable JSON body).
        """
        time.sleep(0.08)  # light throttle to avoid calling the service too frequently
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                response = requests.post(url=url, headers=self.headers, json=data, timeout=self.timeout)
                if response.status_code == 200:
                    return np.array(response.json())
                print('请求远程模型失败', url, f'attempt {attempt}/{attempts}', 'HTTP', response.status_code)
            except (requests.RequestException, ValueError) as e:
                # ValueError covers a response body that is not valid JSON.
                print('请求远程模型失败', url, f'attempt {attempt}/{attempts}', e)
            if attempt < attempts:
                time.sleep(1)  # back off before retrying
        return None

    def encode(self, texts: Union[str, List[str]], normalize: bool = True) -> Optional[np.ndarray]:
        """Embed ``texts`` with the remote bge-m3 embedding service.

        Args:
            texts: A single string or a non-empty list of strings.
            normalize: Forwarded to the service as the ``normalize`` field.

        Returns:
            The service's JSON response as a NumPy array, or ``None`` if the
            remote call failed after all retries.

        Raises:
            TypeError: If ``texts`` is neither a list nor a string.
            ValueError: If ``texts`` is an empty list or contains a non-str element.
        """
        if not isinstance(texts, (list, str)):
            raise TypeError("Text must be list or string", texts)
        if isinstance(texts, list):
            if not texts:
                raise ValueError("Text must not be empty", texts)
            for i, content in enumerate(texts):
                if not isinstance(content, str):
                    # ValueError kept (not TypeError) for compatibility with existing callers.
                    raise ValueError(f"Text must be str, pos:{i}, content:{content!r}")
        # The service may also accept "return_dense" / "return_sparse" / "return_colbert_vecs";
        # only the dense-embedding defaults are requested here.
        data = {"inputs": texts, "normalize": normalize}
        return self._access_remote_model(url=self.embedding_url, data=data)

    def compute_score(self, pairs: List[Tuple[str, str]], rerank_limit: int = 32) -> List[float]:
        """Score ``(query, text)`` pairs with the remote bge-reranker service.

        All pairs must share the same query. Texts are sent in batches of at most
        ``rerank_limit`` (the remote model's maximum input count). Each batch's
        results are reordered by their ``"index"`` field, and the batches are
        concatenated, so the returned scores line up with ``pairs``.

        Args:
            pairs: Non-empty list of ``(query, text)`` string pairs with one common query.
            rerank_limit: Maximum number of texts per remote request (must be >= 1).

        Returns:
            One ``"score"`` value per pair, in input order.

        Raises:
            TypeError: If ``pairs`` is not a list or a pair element is not a str.
            ValueError: If ``pairs`` is empty, a pair does not have exactly two
                elements, the queries differ, or ``rerank_limit`` < 1.
            RuntimeError: If a remote call fails after all retries or returns a
                result count that does not match the batch size.
        """
        if not isinstance(pairs, list):
            raise TypeError("Pairs must be list", pairs)
        if not pairs:
            raise ValueError("Pairs must not be empty", pairs)
        if rerank_limit < 1:
            raise ValueError("Param 'rerank_limit' must be a positive integer", rerank_limit)
        for i, pair in enumerate(pairs):
            if len(pair) != 2:
                raise ValueError(f"Every pair must contain exactly 2 elements, pos:{i}", pair)
            q, t = pair
            if not isinstance(q, str) or not isinstance(t, str):
                raise TypeError(f"Elements of every pairs must be str, pos:{i}, ({q}, {t})")

        # The service takes a single query per request, so every pair must share it.
        query = pairs[0][0]
        if any(q != query for q, _ in pairs):
            raise ValueError("Pairs must have the same query", pairs)

        texts = [t for _, t in pairs]
        scores: List[float] = []
        for start in range(0, len(texts), rerank_limit):
            batch = texts[start:start + rerank_limit]
            data = {
                "query": query,
                "texts": batch,
            }
            res = self._access_remote_model(url=self.reranker_url, data=data)
            if res is None:
                raise RuntimeError("Remote reranker request failed", self.reranker_url)
            if len(res) != len(batch):
                raise RuntimeError(
                    f"Reranker returned {len(res)} results for {len(batch)} texts", self.reranker_url
                )
            # Each result is expected to be {"index": <position within batch>, "score": ...}.
            scores.extend(item["score"] for item in sorted(res, key=lambda x: x["index"]))
        return scores


if __name__ == "__main__":
    timeout = 3
    max_retries = 3
    bge_model = RemoteBGEModel('dev', timeout, max_retries)
    t = bge_model.encode(["hello"], normalize=True)
    tt = bge_model.compute_score([("你好呀我的名字叫做汤姆", "今天世界杯中国得了冠军"),
                                  ("你好呀我的名字叫做汤姆", "你好呀我的名字叫做山姆"),
                                  ("你好呀我的名字叫做汤姆", "你好呀我的名字叫做汤姆?"),
                                  ("你好呀我的名字叫做汤姆", "我今天非常的开心,你呢?")])
    print(t)
    print(tt)
    # Local in-process reranker, kept for comparing against the remote scores:
    # reranker = FlagReranker(os.path.join(script_dir, 'bge-reranker-v2-m3'), use_fp16=True, local_files_only=True,
    #                         devices=["cuda:0"])
    # ttt = reranker.compute_score([("你好呀我的名字叫做汤姆","今天世界杯中国得了冠军"),
    #                               ("你好呀我的名字叫做汤姆","你好呀我的名字叫做山姆"),
    #                               ("你好呀我的名字叫做汤姆","你好呀我的名字叫做汤姆?"),
    #                               ("你好呀我的名字叫做汤姆","我今天非常的开心,你呢?")], normalize=True)