remote_model.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # version: 2025.12.04
  2. import requests
  3. from typing import List, Tuple, Optional
  4. import os
  5. import json
  6. import time
  7. import numpy as np
  8. from FlagEmbedding import FlagAutoModel, FlagReranker
  9. script_dir = os.path.dirname(os.path.abspath(__file__))
  10. class RemoteBGEModel:
  11. def __init__(self, branch:str='dev', timeout:int=3, max_retries:int=3):
  12. # 加载网址配置文件
  13. self.branch = branch.strip().lower()
  14. if not self.branch in ['dev', 'test', 'master', 'main', 'local']: # 输入参数合法
  15. raise ValueError("Param 'branch' must be dev test master or main",branch)
  16. self.url_file = os.path.join(script_dir, 'url_config.json')
  17. self.embedding_url, self.reranker_url = self.load_url()
  18. self.timeout = timeout
  19. self.max_retries = max_retries
  20. # 构建请求头
  21. self.headers = {"Content-Type": "application/json"}
  22. def load_url(self):
  23. """加载url"""
  24. if not os.path.exists(self.url_file):
  25. raise FileNotFoundError("File not exist", self.url_file)
  26. # 读取json配置文件
  27. with open(self.url_file, 'r', encoding='utf-8') as f:
  28. json_data = json.load(f)
  29. if self.branch == 'dev' or self.branch == 'test':
  30. embed_url = json_data['dev_embed_url'] + '/embed'
  31. rerank_url = json_data['dev_reranker_url'] + '/rerank'
  32. elif self.branch == 'main' or self.branch == 'master':
  33. embed_url = json_data['master_embed_url'] + '/embed'
  34. rerank_url = json_data['master_reranker_url'] + '/rerank'
  35. else:
  36. embed_url = json_data['local_embed_url'] + '/embed'
  37. rerank_url = json_data['local_reranker_url'] + '/rerank'
  38. return embed_url, rerank_url
  39. def _access_remote_model(self, url:str, data:dict):
  40. """调用bge-m3,embedding"""
  41. # 类型检查
  42. time.sleep(0.08) # 方式频繁调用接口
  43. for attempt in range(self.max_retries):
  44. try:
  45. response = requests.post(url=url, headers=self.headers, json=data)
  46. if response.status_code == 200:
  47. return np.array(response.json())
  48. except Exception as e:
  49. print('请求embedding模型失败', e)
  50. time.sleep(1)
  51. return None
  52. return None
  53. def encode(self,texts: List[str], normalize: bool = True):
  54. """调用bge-m3,embedding"""
  55. # 类型检查
  56. if not isinstance(texts, list) and not isinstance(texts, str):
  57. raise TypeError("Text must be list or string",texts)
  58. if isinstance(texts, List):
  59. if not texts:
  60. raise ValueError("Text must not be empty",texts)
  61. for i, content in enumerate(texts):
  62. if not isinstance(content, str):
  63. raise ValueError(f"Text must not be empty, pos:{i}, content{content}")
  64. # data = {"inputs":texts, "normalize":normalize, "return_dense": True, "return_sparse": True, "return_colbert_vecs": False}
  65. data = {"inputs":texts, "normalize":normalize}
  66. return self._access_remote_model(
  67. url=self.embedding_url,
  68. data=data
  69. )
  70. def compute_score(self, pairs: List[Tuple[str, str]], rerank_limit:int=32):
  71. """调用远程bge-reranker计算相关性"""
  72. # 类型检查
  73. if not isinstance(pairs, list):
  74. raise TypeError("Pairs must be list",pairs)
  75. if not pairs:
  76. raise ValueError("Pairs must not be empty",pairs)
  77. if len(pairs[0]) != 2:
  78. raise ValueError("Pairs must not be empty",pairs)
  79. i = 0
  80. for j, k in pairs:
  81. if not isinstance(j, str) or not isinstance(k, str):
  82. raise TypeError(f"Elements of every pairs must not be str, pos:{i}, ({j}, {k})")
  83. i+=1
  84. # 判断pairs的每个query是否为一致
  85. if len(pairs) >= 3:
  86. for i in range(1, len(pairs), len(pairs) - 1):
  87. if pairs[i - 1][0] != pairs[i][0] or pairs[i-1][0] != pairs[i+1][0]:
  88. raise ValueError("Pairs must have the same query", pairs)
  89. elif len(pairs) == 2:
  90. if pairs[0][0] != pairs[1][0]:
  91. raise ValueError("Pairs must have the same query", pairs)
  92. texts = [t for q, t in pairs]
  93. # 为了处理超过32个输入的情形,我们需要为texts进行分组,然后计算排序分数,并将分数按照index的顺序排序,最后拼接所有的结果作为最终的结果
  94. # 分组
  95. group_texts = [texts[i:i + rerank_limit] for i in range(0, len(texts), rerank_limit)]
  96. score = []
  97. for single_group in group_texts:
  98. data = {
  99. "query": pairs[0][0], # 对于bge-reranker,query字段可为空
  100. "texts": single_group # 输入长度不能超过32, rerank_limit为远程模型有效的输入长度
  101. }
  102. # 多次调用远程模型,并返回rerank结果
  103. res = self._access_remote_model(
  104. url=self.reranker_url,
  105. data=data
  106. )
  107. # 按照原有位置整理结果,并拼接不同组的分数
  108. score.extend([_["score"] for _ in sorted(res, key=lambda x: x["index"])])
  109. return score
  110. if __name__ == "__main__":
  111. timeout = 3
  112. max_retries = 3
  113. bge_model = RemoteBGEModel('dev', timeout, max_retries)
  114. t = bge_model.encode(["hello"], normalize=True)
  115. tt = bge_model.compute_score([("你好呀我的名字叫做汤姆","今天世界杯中国得了冠军"),
  116. ("你好呀我的名字叫做汤姆","你好呀我的名字叫做山姆"),
  117. ("你好呀我的名字叫做汤姆","你好呀我的名字叫做汤姆?"),
  118. ("你好呀我的名字叫做汤姆","我今天非常的开心,你呢?")])
  119. # reranker = FlagReranker(os.path.join(script_dir, 'bge-reranker-v2-m3'), use_fp16=True, local_files_only=True,
  120. # devices=["cuda:0"])
  121. # ttt = reranker.compute_score([("你好呀我的名字叫做汤姆","今天世界杯中国得了冠军"),
  122. # ("你好呀我的名字叫做汤姆","你好呀我的名字叫做山姆"),
  123. # ("你好呀我的名字叫做汤姆","你好呀我的名字叫做汤姆?"),
  124. # ("你好呀我的名字叫做汤姆","我今天非常的开心,你呢?")], normalize=True)
  125. pass