from transformers import BertTokenizerFast
import os
import json
import time
import threading
from typing import Tuple, List

import numpy as np

try:
    import pycuda.driver as cuda
    cuda.init()
    import tensorrt as trt
    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
    TRT_AVAILABLE = True
except ImportError:
    TRT_AVAILABLE = False

try:
    import onnxruntime
    ONNXRUNTIME_AVAILABLE = True
except ImportError:
    ONNXRUNTIME_AVAILABLE = False

tokenizer = BertTokenizerFast.from_pretrained("hfl/chinese-macbert-base")


def preprocess_data(text="this is a sad thing", is_trt=False):
    # Tokenize the same sentence five times to build a fixed 5 x 128 batch.
    texts = [text for _ in range(5)]
    context = tokenizer(texts, padding="max_length", return_tensors='pt',
                        max_length=128, truncation=True, return_offsets_mapping=True)
    input_ids = context['input_ids'].detach().cpu().numpy()
    attention_mask = context['attention_mask'].detach().cpu().numpy()
    token_type_ids = context['token_type_ids'].detach().cpu().numpy()
    if is_trt:
        # TensorRT path expects a positional list of host arrays.
        return [input_ids, attention_mask, token_type_ids]
    else:
        # ONNX Runtime path expects a feed dict keyed by input name.
        return {"input_ids": input_ids,
                "attention_mask": attention_mask,
                "token_type_ids": token_type_ids}


class TrtModel:
    def __init__(self, model_name="detector_corrector", model_dir=".",
                 cached_engine=True, max_batch_size=5) -> None:
        self.cfx = cuda.Device(0).make_context()
        self.model_dir = model_dir
        self.model_name = model_name
        self.max_batch_size = max_batch_size
        self.cached_engine = cached_engine
        self.engine = self.load_model(os.path.join(model_dir, model_name + ".onnx"))
        self.input_binding_idxs, self.output_binding_idxs = self.get_binding_idxs()
        self.input_names = [self.engine.get_binding_name(binding_idx)
                            for binding_idx in self.input_binding_idxs]
        self.output_names = [self.engine.get_binding_name(binding_idx)
                             for binding_idx in self.output_binding_idxs]

    def __del__(self):
        self.cfx.detach()

    def load_model(self, model_path):
        return self.load_engine(model_path)

    def get_context(self):
        return self.engine.create_execution_context()

    def get_stream(self):
        return cuda.Stream()

    def predict(self, host_inputs):
        self.cfx.push()
        context = self.get_context()
        stream = self.get_stream()
        # Copy each host input into a freshly allocated device buffer.
        device_inputs = [cuda.mem_alloc(h_input.nbytes) for h_input in host_inputs]
        for h_input, d_input in zip(host_inputs, device_inputs):
            cuda.memcpy_htod_async(d_input, h_input, stream)
        host_outputs, device_outputs = self.gen_output_buffer(host_inputs, context)
        bindings = device_inputs + device_outputs
        exe_res = context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        if not exe_res:
            print(f"{self.__class__.__name__} execute_async_v2 error")
        for h_output, d_output in zip(host_outputs, device_outputs):
            cuda.memcpy_dtoh_async(h_output, d_output, stream)
        stream.synchronize()
        for b in bindings:
            b.free()
        self.cfx.pop()
        return host_outputs

    def gen_output_buffer(self, host_inputs: List[np.ndarray], context):
        assert context.all_binding_shapes_specified
        host_outputs = []
        device_outputs = []
        for binding_index in self.output_binding_idxs:
            output_shape = context.get_binding_shape(binding_index)
            # Allocate buffers to hold output results after copying back to host.
            buffer = np.empty(output_shape, dtype=np.float32)
            host_outputs.append(buffer)
            # Allocate output buffers on device.
            device_outputs.append(cuda.mem_alloc(buffer.nbytes))
        return host_outputs, device_outputs

    def get_binding_idxs(self):
        # Separate input and output binding indices for convenience.
        input_binding_idxs = []
        output_binding_idxs = []
        for binding_index in range(self.engine.num_bindings):
            if self.engine.binding_is_input(binding_index):
                input_binding_idxs.append(binding_index)
            else:
                output_binding_idxs.append(binding_index)
        return input_binding_idxs, output_binding_idxs

    def load_engine(self, onnx_file_path):
        runtime = trt.Runtime(TRT_LOGGER)
        cached_engine_path = os.path.join(self.model_dir, self.model_name + ".engine")
        if self.cached_engine and os.path.exists(cached_engine_path):
            with open(cached_engine_path, "rb") as f:
                serialized_engine = f.read()
            engine = runtime.deserialize_cuda_engine(serialized_engine)
            print(f"loaded engine from cache: {cached_engine_path} successfully")
            return engine
        EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(EXPLICIT_BATCH) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_batch_size = self.max_batch_size
            with open(onnx_file_path, 'rb') as model:
                if not parser.parse(model.read()):
                    for error in range(parser.num_errors):
                        print(parser.get_error(error))
            config = builder.create_builder_config()
            config.max_workspace_size = 1 << 30
            # config.set_flag(trt.BuilderFlag.FP16)
            serialized_engine = builder.build_serialized_network(network, config=config)
            engine = runtime.deserialize_cuda_engine(serialized_engine)
            print("built engine from {} successfully".format(onnx_file_path))
            if self.cached_engine:
                with open(cached_engine_path, "wb") as f:
                    f.write(serialized_engine)
                print(f"cached engine to: {cached_engine_path}")
            return engine


class OnnxModel:
    def __init__(self, model_name="detector_corrector", model_dir="."):
        if not model_name.endswith(".onnx"):
            model_name = model_name + ".onnx"
        model_path = os.path.join(model_dir, model_name)
        print(f"onnx model path is {model_path}")
        self.ort_session = self.load_model(model_path)

    def load_model(self, model_path):
        providers = ['CUDAExecutionProvider']
        # sess_options = onnxruntime.SessionOptions()
        # sess_options.intra_op_num_threads = 10
        # sess_options.inter_op_num_threads = 10
        print(f"onnxruntime get device {onnxruntime.get_device()}, "
              f"available providers {onnxruntime.get_available_providers()}")
        ort_session = onnxruntime.InferenceSession(model_path, providers=providers)
        print(f"onnxruntime session providers {ort_session.get_providers()}")
        return ort_session

    def predict(self, inputs):
        ort_outs = self.ort_session.run(None, inputs)
        return ort_outs


# Run the same batch through both backends and compare the two output tensors.
trt_model = TrtModel()
onnx_model = OnnxModel()
trt_output = trt_model.predict(preprocess_data(is_trt=True))
onnx_output = onnx_model.predict(preprocess_data(is_trt=False))
trt_detector_logits, trt_corrector_logits = trt_output
onnx_detector_logits, onnx_corrector_logits = onnx_output
assert np.allclose(trt_detector_logits, onnx_detector_logits)
assert np.allclose(trt_corrector_logits, onnx_corrector_logits)
If you run this code, the np.allclose assertions always fail: the TensorRT outputs never match the ONNX Runtime outputs. What could cause the mismatch?
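One thing worth checking first (a guess from the snippet, not a confirmed diagnosis) is the dtype of the host arrays handed to the TensorRT engine. The tokenizer returns int64 arrays, but TensorRT has no native INT64 support, so the ONNX parser usually downcasts the input tensors to int32; copying raw int64 bytes into an int32 binding then feeds the network garbage token ids, while ONNX Runtime still sees the correct int64 inputs. The sketch below casts the host inputs to the dtype each binding reports before the host-to-device copy. match_input_dtypes is a hypothetical helper name, and it assumes the legacy binding API (get_binding_dtype, get_binding_name) already used in the code above.

def match_input_dtypes(engine, input_binding_idxs, host_inputs):
    # Hypothetical helper: cast every host array to the numpy dtype the engine
    # expects for its binding, and report any mismatch it finds.
    matched = []
    for binding_idx, h_input in zip(input_binding_idxs, host_inputs):
        expected = trt.nptype(engine.get_binding_dtype(binding_idx))
        if h_input.dtype != np.dtype(expected):
            print(f"binding {engine.get_binding_name(binding_idx)}: "
                  f"host dtype {h_input.dtype}, engine expects {np.dtype(expected)}")
        # ascontiguousarray guarantees the flat layout memcpy_htod_async copies.
        matched.append(np.ascontiguousarray(h_input.astype(expected)))
    return matched

# Possible usage at the top of TrtModel.predict, before device buffers are
# sized from h_input.nbytes:
# host_inputs = match_input_dtypes(self.engine, self.input_binding_idxs, host_inputs)

If the dtypes already match, it is worth printing np.max(np.abs(trt_detector_logits - onnx_detector_logits)) to see whether the gap is small numerical noise (a looser rtol/atol in np.allclose would then pass) or the outputs are genuinely unrelated, which usually points back to the inputs or the binding order.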