bmodel_test.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import argparse
  2. import sophon.sail as sail
  3. import cv2
  4. import logging
  5. import numpy as np
  6. class Predictor:
  7. def __init__(self):
  8. # 加载推理引擎
  9. self.net = sail.Engine(args.bmodel, args.dev_id, sail.IOMode.SYSIO)
  10. self.graph_name = self.net.get_graph_names()[0]
  11. self.input_names = self.net.get_input_names(self.graph_name)
  12. self.input_shapes = [self.net.get_input_shape(self.graph_name, name) for name in self.input_names]
  13. self.output_names = self.net.get_output_names(self.graph_name)
  14. self.output_shapes = [self.net.get_output_shape(self.graph_name, name) for name in self.output_names] # [[1, 2]]
  15. self.input_name = self.input_names[0]
  16. self.input_shape = self.input_shapes[0] # [1, 3, 256, 256]
  17. self.batch_size = self.input_shape[0]
  18. self.net_h = self.input_shape[2] # 输入图像的高
  19. self.net_w = self.input_shape[3] # 输入图像的宽
  20. # 归一化参数,采用imagenet预训练参数
  21. self.mean = [0.485, 0.456, 0.406]
  22. self.std = [0.229, 0.224, 0.225]
  23. self.print_network_info()
  24. def __call__(self, img):
  25. return self.predict(img)
  26. def print_network_info(self):
  27. info = {
  28. 'Graph Name': self.graph_name,
  29. 'Input Name': self.input_name,
  30. 'Output Names': self.output_names,
  31. 'Output Shapes': self.output_shapes,
  32. 'Input Shape': self.input_shape,
  33. 'Batch Size': self.batch_size,
  34. 'Height': self.net_h,
  35. 'Width': self.net_w,
  36. 'Mean': self.mean,
  37. 'Std': self.std
  38. }
  39. print("=" * 50)
  40. print("Network Configuration Info")
  41. print("=" * 50)
  42. for key, value in info.items():
  43. print(f"{key:<18}: {value}")
  44. print("=" * 50)
  45. def predict(self, input_img):
  46. input_data = {self.input_name: input_img}
  47. outputs = self.net.process(self.graph_name, input_data)
  48. print('predict fun:', outputs)
  49. print('predict fun return:', list(outputs.values())[0])
  50. return list(outputs.values())[0]
  51. def preprocess(self, img):
  52. h, w, _ = img.shape
  53. if h != 256 or w != 256:
  54. img = cv2.resize(img, (self.net_w, self.net_h))
  55. img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  56. img = img.astype('float32')
  57. img = (img / 255 - self.mean) / self.std # 这一步是很有必要的
  58. # img = img / 255. # 编译过程并不会帮你做归一化,所以这里要自己做归一化,否则预测数值可能会非常不准确
  59. img = np.transpose(img, (2, 0, 1))
  60. return img
  61. def postprocess(self, outputs):
  62. outputs_exp = np.exp(outputs)
  63. print('exp res:', outputs_exp)
  64. outputs = outputs_exp / np.sum(outputs_exp, axis=1)[:, None]
  65. print('softmax res:', outputs)
  66. predictions = np.argmax(outputs, axis=1)
  67. print('预测结果:', predictions)
  68. return outputs
  69. def main(args):
  70. predictor = Predictor()
  71. filename = args.input
  72. src_img = cv2.imread(filename, cv2.IMREAD_COLOR)
  73. src_img = predictor.preprocess(src_img)
  74. src_img = np.stack([src_img])
  75. print('图像输入shape:',src_img.shape)
  76. if src_img is None:
  77. logging.error("{} imread is None.".format(filename))
  78. return
  79. res = predictor(src_img)
  80. print('预测结果res:', res)
  81. predictor.postprocess(res)
  82. def argsparser():
  83. parser = argparse.ArgumentParser(prog=__file__)
  84. parser.add_argument('--input', type=str, default='./000000_256_512_1.jpg', help='path of input, must be image directory')
  85. parser.add_argument('--bmodel', type=str, default='./shufflenet_f32.bmodel', help='path of bmodel')
  86. parser.add_argument('--dev_id', type=int, default=0, help='tpu id')
  87. args = parser.parse_args()
  88. return args
  89. if __name__ == '__main__':
  90. args = argsparser()
  91. main(args)