| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- # version: 2025.12.04
- import requests
- from typing import List, Tuple, Optional
- import os
- import json
- import time
- import numpy as np
- from FlagEmbedding import FlagAutoModel, FlagReranker
script_dir = os.path.dirname(os.path.abspath(__file__))


class RemoteBGEModel:
    """HTTP client for remotely deployed BGE embedding and reranker services.

    Endpoint base URLs are read from ``url_config.json`` in the same directory
    as this file; ``/embed`` and ``/rerank`` are appended to the base URLs
    configured for the selected branch.
    """

    # Accepted ``branch`` values -> prefix of the keys looked up in url_config.json.
    _BRANCH_PREFIX = {
        'dev': 'dev',
        'test': 'dev',
        'master': 'master',
        'main': 'master',
        'local': 'local',
    }

    def __init__(self, branch: str = 'dev', timeout: int = 3, max_retries: int = 3):
        """Create a client bound to one deployment branch.

        Args:
            branch: one of dev/test (dev endpoints), master/main (master
                endpoints) or local; surrounding whitespace and case are ignored.
            timeout: per-request timeout in seconds, passed to ``requests.post``.
            max_retries: total number of request attempts per remote call.

        Raises:
            ValueError: if ``branch`` is not one of the accepted values.
            FileNotFoundError: if ``url_config.json`` does not exist.
            KeyError: if the config lacks the URL keys for the branch.
        """
        self.branch = branch.strip().lower()
        if self.branch not in self._BRANCH_PREFIX:
            raise ValueError("Param 'branch' must be dev, test, master, main or local", branch)
        self.url_file = os.path.join(script_dir, 'url_config.json')
        self.embedding_url, self.reranker_url = self.load_url()
        self.timeout = timeout
        self.max_retries = max_retries
        self.headers = {"Content-Type": "application/json"}

    def load_url(self) -> Tuple[str, str]:
        """Read the embedding and reranker endpoint URLs for ``self.branch``.

        Returns:
            ``(embed_url, rerank_url)``: the configured base URLs with any
            trailing ``/`` removed and ``/embed`` / ``/rerank`` appended.

        Raises:
            FileNotFoundError: if ``self.url_file`` does not exist.
            KeyError: if ``<prefix>_embed_url`` or ``<prefix>_reranker_url``
                is missing from the config.
        """
        if not os.path.exists(self.url_file):
            raise FileNotFoundError("File not exist", self.url_file)
        with open(self.url_file, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        # Any branch not in the table falls back to the local endpoints.
        prefix = self._BRANCH_PREFIX.get(self.branch, 'local')
        # rstrip('/') keeps 'http://host/' from turning into 'http://host//embed'.
        embed_url = json_data[f'{prefix}_embed_url'].rstrip('/') + '/embed'
        rerank_url = json_data[f'{prefix}_reranker_url'].rstrip('/') + '/rerank'
        return embed_url, rerank_url

    def _access_remote_model(self, url: str, data: dict) -> Optional[np.ndarray]:
        """POST ``data`` as JSON to ``url`` and return the reply as a NumPy array.

        Shared by the embedding and reranker endpoints. Makes up to
        ``self.max_retries`` attempts. A failed attempt (network error,
        timeout, unparsable JSON, or non-200 status) is printed and followed
        by a 1 s pause before the next attempt.

        Returns:
            ``np.array(response.json())`` from the first HTTP 200 reply, or
            ``None`` when every attempt failed (or ``max_retries`` < 1).
        """
        time.sleep(0.08)  # throttle: avoid calling the service too frequently
        for attempt in range(self.max_retries):
            try:
                response = requests.post(url=url, headers=self.headers, json=data,
                                         timeout=self.timeout)
                if response.status_code == 200:
                    return np.array(response.json())
                print('请求embedding模型失败', f'HTTP {response.status_code}')
            except (requests.RequestException, ValueError) as e:
                # ValueError also covers a response body that is not valid JSON.
                print('请求embedding模型失败', e)
            # Pause before retrying, but not after the last attempt.
            if attempt < self.max_retries - 1:
                time.sleep(1)
        return None

    def encode(self, texts: List[str], normalize: bool = True):
        """Embed ``texts`` with the remote BGE-M3 embedding service.

        Args:
            texts: a single string or a non-empty list of strings.
            normalize: forwarded to the service as the ``normalize`` flag.

        Returns:
            The service reply as a NumPy array, or ``None`` if every request
            attempt failed.

        Raises:
            TypeError: if ``texts`` is neither a list nor a str.
            ValueError: if the list is empty or contains a non-str element.
        """
        if not isinstance(texts, (list, str)):
            raise TypeError("Text must be list or string", texts)
        if isinstance(texts, list):
            if not texts:
                raise ValueError("Text must not be empty", texts)
            for i, content in enumerate(texts):
                if not isinstance(content, str):
                    # ValueError (not TypeError) is kept for backward compatibility.
                    raise ValueError(f"Text must be str, pos:{i}, content:{content!r}")
        data = {"inputs": texts, "normalize": normalize}
        return self._access_remote_model(url=self.embedding_url, data=data)

    def compute_score(self, pairs: List[Tuple[str, str]], rerank_limit: int = 32):
        """Score ``(query, text)`` pairs with the remote BGE reranker.

        All pairs must share one query. Texts are sent in batches of at most
        ``rerank_limit`` (the maximum batch size the remote model accepts),
        and the scores are returned in the same order as ``pairs``.

        Returns:
            list of scores, one per pair, in input order.

        Raises:
            TypeError: if ``pairs`` is not a list or a pair holds a non-str.
            ValueError: if ``pairs`` is empty, a pair does not have exactly
                two elements, the queries differ, or ``rerank_limit`` < 1.
            RuntimeError: if a reranker request fails, or the reply does not
                contain one result per text in the batch.
        """
        if not isinstance(pairs, list):
            raise TypeError("Pairs must be list", pairs)
        if not pairs:
            raise ValueError("Pairs must not be empty", pairs)
        if rerank_limit < 1:
            raise ValueError("rerank_limit must be a positive integer", rerank_limit)
        for i, pair in enumerate(pairs):
            if len(pair) != 2:
                raise ValueError(f"Every pair must have exactly 2 elements, pos:{i}, {pair}")
            q, t = pair
            if not isinstance(q, str) or not isinstance(t, str):
                raise TypeError(f"Elements of every pairs must be str, pos:{i}, ({q}, {t})")
        # Every pair (not just the first three) must share the first pair's query.
        query = pairs[0][0]
        if any(q != query for q, _ in pairs):
            raise ValueError("Pairs must have the same query", pairs)

        texts = [t for _, t in pairs]
        scores = []
        # The remote model accepts at most rerank_limit texts per request, so batch
        # the texts, score each batch, and concatenate the results in order.
        for start in range(0, len(texts), rerank_limit):
            batch = texts[start:start + rerank_limit]
            res = self._access_remote_model(
                url=self.reranker_url,
                data={"query": query, "texts": batch},
            )
            if res is None:
                raise RuntimeError("Reranker request failed", self.reranker_url)
            if len(res) != len(batch):
                # Guard against silently misaligning scores with their texts.
                raise RuntimeError(
                    f"Reranker returned {len(res)} results for {len(batch)} texts")
            # Each result carries the batch-relative 'index' of its text; sort by it
            # to restore the input order before appending the scores.
            scores.extend(item["score"] for item in sorted(res, key=lambda x: x["index"]))
        return scores
if __name__ == "__main__":
    # Manual smoke test against the dev deployment; needs url_config.json next to
    # this file and network access to the configured endpoints.
    timeout = 3
    max_retries = 3
    bge_model = RemoteBGEModel('dev', timeout=timeout, max_retries=max_retries)
    t = bge_model.encode(["hello"], normalize=True)
    # encode() returns None when every request attempt failed.
    print('embedding shape:', None if t is None else t.shape)
    tt = bge_model.compute_score([("你好呀我的名字叫做汤姆","今天世界杯中国得了冠军"),
                                  ("你好呀我的名字叫做汤姆","你好呀我的名字叫做山姆"),
                                  ("你好呀我的名字叫做汤姆","你好呀我的名字叫做汤姆?"),
                                  ("你好呀我的名字叫做汤姆","我今天非常的开心,你呢?")])
    print('rerank scores:', tt)
    # Optional cross-check against a local reranker checkpoint (expects the
    # 'bge-reranker-v2-m3' directory next to this file and a CUDA device):
    # reranker = FlagReranker(os.path.join(script_dir, 'bge-reranker-v2-m3'), use_fp16=True, local_files_only=True,
    #                         devices=["cuda:0"])
    # ttt = reranker.compute_score([("你好呀我的名字叫做汤姆","今天世界杯中国得了冠军"),
    #                               ("你好呀我的名字叫做汤姆","你好呀我的名字叫做山姆"),
    #                               ("你好呀我的名字叫做汤姆","你好呀我的名字叫做汤姆?"),
    #                               ("你好呀我的名字叫做汤姆","我今天非常的开心,你呢?")], normalize=True)
|