remote_model.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # version: 2025.12.04
  2. import requests
  3. from typing import List, Tuple, Optional
  4. import os
  5. import json
  6. import time
  7. import numpy as np
  8. from FlagEmbedding import FlagAutoModel, FlagReranker
  # Directory containing this file; url_config.json is looked up here.
  script_dir = os.path.dirname(os.path.abspath(__file__))


  class RemoteBGEModel:
      """HTTP client for remotely served BGE embedding and reranker models.

      Endpoint base URLs are read from ``url_config.json`` located next to this
      file; the ``branch`` argument selects which pair of URLs is used.
      """

      # Accepted branch names -> key prefix used in url_config.json.
      _URL_PREFIX = {
          'dev': 'dev',
          'test': 'dev',
          'master': 'master',
          'main': 'master',
          'local': 'local',
      }

      def __init__(self, branch: str = 'dev', timeout: int = 3, max_retries: int = 3):
          """Validate *branch* and load the endpoint URLs for it.

          Args:
              branch: One of 'dev', 'test', 'master', 'main' or 'local'
                  (case-insensitive; surrounding whitespace is ignored).
              timeout: Per-request timeout in seconds, passed to ``requests.post``.
              max_retries: Maximum number of POST attempts per remote call.

          Raises:
              ValueError: If *branch* is not one of the supported values.
              FileNotFoundError: If ``url_config.json`` does not exist.
          """
          self.branch = branch.strip().lower()
          if self.branch not in self._URL_PREFIX:
              raise ValueError(
                  "Param 'branch' must be one of dev, test, master, main or local", branch)
          self.url_file = os.path.join(script_dir, 'url_config.json')
          self.embedding_url, self.reranker_url = self.load_url()
          self.timeout = timeout
          self.max_retries = max_retries
          self.headers = {"Content-Type": "application/json"}

      def load_url(self) -> Tuple[str, str]:
          """Return ``(embed_url, rerank_url)`` for ``self.branch``.

          'dev'/'test' read the ``dev_*`` keys, 'main'/'master' read the
          ``master_*`` keys, and any other branch reads the ``local_*`` keys.
          A trailing '/' on a configured base URL is dropped so the endpoint
          path is not doubled (``http://host//embed``).

          Raises:
              FileNotFoundError: If ``self.url_file`` does not exist.
              KeyError: If the config file lacks the keys for the branch.
          """
          if not os.path.exists(self.url_file):
              raise FileNotFoundError("File not exist", self.url_file)
          with open(self.url_file, 'r', encoding='utf-8') as f:
              json_data = json.load(f)
          prefix = self._URL_PREFIX.get(self.branch, 'local')
          embed_url = json_data[f'{prefix}_embed_url'].rstrip('/') + '/embed'
          rerank_url = json_data[f'{prefix}_reranker_url'].rstrip('/') + '/rerank'
          return embed_url, rerank_url

      def _access_remote_model(self, url: str, data: dict) -> Optional[np.ndarray]:
          """POST *data* as JSON to *url* and return the decoded body as an array.

          Makes up to ``self.max_retries`` attempts, each bounded by
          ``self.timeout`` seconds, sleeping 1 s between failed attempts.

          Returns:
              ``np.array(response.json())`` for the first HTTP 200 response, or
              None if every attempt failed (exception or non-200 status).
          """
          time.sleep(0.08)  # brief pause to avoid calling the service too frequently
          for attempt in range(self.max_retries):
              try:
                  response = requests.post(url=url, headers=self.headers, json=data,
                                           timeout=self.timeout)
                  if response.status_code == 200:
                      return np.array(response.json())
                  print('请求embedding模型失败', f'HTTP {response.status_code}')
              except Exception as e:
                  # Best effort: report and retry rather than propagate transport errors.
                  print('请求embedding模型失败', e)
              if attempt < self.max_retries - 1:
                  time.sleep(1)
          return None

      def encode(self, texts: List[str], normalize: bool = True) -> Optional[np.ndarray]:
          """Embed *texts* with the remote bge-m3 embedding service.

          Args:
              texts: A single string or a non-empty list of strings.
              normalize: Forwarded to the service as the ``normalize`` field.

          Returns:
              The service's JSON response wrapped with ``np.array``, or None if
              the request failed.

          Raises:
              TypeError: If *texts* is neither a list nor a string.
              ValueError: If *texts* is an empty list or contains a non-string.
          """
          if not isinstance(texts, (list, str)):
              raise TypeError("Text must be list or string", texts)
          if isinstance(texts, list):
              if not texts:
                  raise ValueError("Text must not be empty", texts)
              for i, content in enumerate(texts):
                  if not isinstance(content, str):
                      raise ValueError(f"Text must be str, pos:{i}, content:{content!r}")
          data = {"inputs": texts, "normalize": normalize}
          return self._access_remote_model(url=self.embedding_url, data=data)

      def compute_score(self, pairs: List[Tuple[str, str]]) -> Optional[list]:
          """Score (query, text) pairs with the remote bge-reranker service.

          All pairs must share the same query; only the texts vary. Scores are
          returned in the same order as *pairs*.

          Returns:
              One score per pair, or None if the remote request failed.

          Raises:
              TypeError: If *pairs* is not a list or a pair element is not a str.
              ValueError: If *pairs* is empty, a pair does not have exactly two
                  elements, or the pairs do not all share the same query.
          """
          if not isinstance(pairs, list):
              raise TypeError("Pairs must be list", pairs)
          if not pairs:
              raise ValueError("Pairs must not be empty", pairs)
          for i, pair in enumerate(pairs):
              if len(pair) != 2:
                  raise ValueError(f"Every pair must have exactly 2 elements, pos:{i}, {pair!r}")
              q, t = pair
              if not isinstance(q, str) or not isinstance(t, str):
                  raise TypeError(f"Elements of every pair must be str, pos:{i}, ({q}, {t})")
          # The service takes a single query plus many texts, so every pair
          # (not just the first three) must carry the same query.
          query = pairs[0][0]
          if any(q != query for q, _ in pairs):
              raise ValueError("Pairs must have the same query", pairs)
          data = {
              "query": query,
              "texts": [t for _, t in pairs],
          }
          res = self._access_remote_model(url=self.reranker_url, data=data)
          if res is None:
              return None
          # Result items carry "index" and "score"; restore the input order.
          return [item["score"] for item in sorted(res, key=lambda item: item["index"])]
  if __name__ == "__main__":
      # Manual smoke test against the 'dev' endpoints; requires network access
      # and a url_config.json next to this file.
      timeout = 3
      max_retries = 3
      bge_model = RemoteBGEModel('dev', timeout, max_retries)

      embeddings = bge_model.encode(["hello"], normalize=True)
      print('embeddings:', None if embeddings is None else embeddings.shape)

      pairs = [("你好呀我的名字叫做汤姆", "今天世界杯中国得了冠军"),
               ("你好呀我的名字叫做汤姆", "你好呀我的名字叫做山姆"),
               ("你好呀我的名字叫做汤姆", "你好呀我的名字叫做汤姆?"),
               ("你好呀我的名字叫做汤姆", "我今天非常的开心,你呢?")]
      scores = bge_model.compute_score(pairs)
      print('rerank scores:', scores)

      # To cross-check against a locally loaded model, compare `scores` with
      # FlagReranker(os.path.join(script_dir, 'bge-reranker-v2-m3'), use_fp16=True,
      #              local_files_only=True, devices=["cuda:0"]).compute_score(pairs, normalize=True)