"""Compute segmentation maps for images in the input folder.
"""
import os
import glob
import cv2
import argparse
import torch
import torch.nn.functional as F
import util.io
from torchvision.transforms import Compose
from dpt.models import DPTSegmentationModel
from dpt.transforms import Resize, NormalizeImage, PrepareForNet


def run(input_path, output_path, model_path, model_type="dpt_hybrid", optimize=True):
    """Run segmentation network

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        model_path (str): path to saved model
        model_type (str): model architecture, "dpt_large" or "dpt_hybrid"
        optimize (bool): use channels-last memory format and half precision on CUDA
    """
print("initialize")
# select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: %s" % device)
net_w = net_h = 480
    # load network
    if model_type == "dpt_large":
        model = DPTSegmentationModel(
            150,
            path=model_path,
            backbone="vitl16_384",
        )
    elif model_type == "dpt_hybrid":
        model = DPTSegmentationModel(
            150,
            path=model_path,
            backbone="vitb_rn50_384",
        )
    else:
        # fail loudly on unsupported model types (unlike an assert, this is not
        # stripped when Python runs with -O)
        raise ValueError(
            f"model_type '{model_type}' not implemented, use: --model_type [dpt_large|dpt_hybrid]"
        )

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
        ]
    )

    model.eval()

    if optimize and device == torch.device("cuda"):
        model = model.to(memory_format=torch.channels_last)
        model = model.half()

    model.to(device)

    # get input
    img_names = glob.glob(os.path.join(input_path, "*"))
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")

    for ind, img_name in enumerate(img_names):
        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))

        # input
        img = util.io.read_image(img_name)
        img_input = transform({"image": img})["image"]

        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)

            if optimize and device == torch.device("cuda"):
                sample = sample.to(memory_format=torch.channels_last)
                sample = sample.half()

            out = model.forward(sample)

            # upsample logits to the original image resolution, then take the
            # most likely class per pixel (+1 shifts to 1-based label ids)
            prediction = F.interpolate(
                out, size=img.shape[:2], mode="bicubic", align_corners=False
            )
            prediction = torch.argmax(prediction, dim=1) + 1
            prediction = prediction.squeeze().cpu().numpy()

        # output
        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        util.io.write_segm_img(filename, img, prediction, alpha=0.5)

    print("finished")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i", "--input_path", default="input", help="folder with input images"
)
parser.add_argument(
"-o", "--output_path", default="output_semseg", help="folder for output images"
)
parser.add_argument(
"-m",
"--model_weights",
default=None,
help="path to the trained weights of model",
)
# 'vit_large', 'vit_hybrid'
parser.add_argument("-t", "--model_type", default="dpt_hybrid", help="model type")
parser.add_argument("--optimize", dest="optimize", action="store_true")
parser.add_argument("--no-optimize", dest="optimize", action="store_false")
parser.set_defaults(optimize=True)
args = parser.parse_args()

    default_models = {
        "dpt_large": "weights/dpt_large-ade20k-b12dca68.pt",
        "dpt_hybrid": "weights/dpt_hybrid-ade20k-53898607.pt",
    }

    if args.model_weights is None:
        args.model_weights = default_models[args.model_type]

    # set torch options
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # compute segmentation maps
    run(
        args.input_path,
        args.output_path,
        args.model_weights,
        args.model_type,
        args.optimize,
    )