import logging
import os
from collections import OrderedDict
from typing import List

import onnx
import tensorrt as trt
from onnx import shape_inference


def vit_tagging_t2t(input_path="simple_model.onnx", output_path="vit.trt"):
    """Build an FP16 TensorRT engine from an ONNX model and write it to disk.

    The ONNX file at ``input_path`` is parsed into an explicit-batch TensorRT
    network. One optimization profile is registered for the tensor named
    ``"input"``: batch size 1 (min) to 70 (opt and max), each sample shaped
    3x224x224. The first network input and output are forced to float32, the
    engine is built with the FP16 builder flag set, and the serialized engine
    is written to ``output_path``.

    Args:
        input_path: Path of the ONNX model to convert.
        output_path: Path the serialized engine is written to.

    Raises:
        RuntimeError: If the ONNX parser rejects the model, or if TensorRT
            does not produce an engine.
    """
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    # Per-sample shape is fixed; only the batch dimension is dynamic.
    sample_shape = (3, 224, 224)
    min_batch, opt_batch, max_batch = 1, 70, 70

    trt_logger = trt.Logger()

    # Counters reported below. The per-layer FP32 override pass that would
    # update them (for "ReduceMean"/"Pow" layers) is currently disabled, so
    # both remain 0.
    all_count, mix_count = 0, 0

    # Read the model up front so a missing file fails before the builder
    # is created.
    with open(input_path, "rb") as f:
        onnx_bytes = f.read()

    with trt.Builder(trt_logger) as builder, \
            builder.create_network(explicit_batch) as network, \
            builder.create_builder_config() as config, \
            trt.OnnxParser(network, trt_logger) as parser:
        config.set_flag(trt.BuilderFlag.FP16)

        if not parser.parse(onnx_bytes):
            # Surface every parser diagnostic before aborting.
            for idx in range(parser.num_errors):
                print(parser.get_error(idx))
            raise RuntimeError("Failed to parse the ONNX file.")

        profile = builder.create_optimization_profile()
        profile.set_shape(
            "input",
            min=(min_batch, *sample_shape),
            opt=(opt_batch, *sample_shape),
            max=(max_batch, *sample_shape),
        )
        config.add_optimization_profile(profile)

        # Keep the engine's I/O bindings in float32 even though FP16 is
        # enabled for the build.
        network.get_input(0).dtype = trt.float32
        network.get_output(0).dtype = trt.float32

        print(all_count, mix_count)

        engine = builder.build_engine(network, config)
        if engine is None:
            # Fail before touching output_path so a failed build does not
            # leave an empty engine file behind.
            raise RuntimeError("Failed to build the TensorRT engine.")

        with open(output_path, "wb") as f:
            f.write(engine.serialize())


if __name__ == "__main__":
    # Script entry point: run the conversion with the default
    # input and output paths.
    vit_tagging_t2t()