Block Diagram Symbol Detection Model
This model was introduced in the paper "Unveiling the Power of Integration: Block Diagram Summarization through Local-Global Fusion", accepted at ACL 2024. The full code is available in this BlockNet GitHub repository.
Model description
This model is an object detector based on YOLOv5, which offers essential capabilities for detecting various objects in an image. Using the CBD, FC_A, and FC_B datasets, which include annotations for different shapes and arrows in a diagram, we train the model to recognize six labels: arrow, terminator, process, decision, data, and text.
Training dataset
- YOLOv5 is fine-tuned with the annotations provided in this GitHub repository for symbol detection in block diagrams.
- 396 samples from the real-world English block diagram dataset (CBD)
- 357 samples from the handwritten English block diagram dataset (FC_A)
- 476 samples from the handwritten English block diagram dataset (FC_B)
How to use
Here is how to use this model in PyTorch:
import argparse
import os
from pathlib import Path
import torch
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh
from utils.plots import Annotator, colors, save_one_box, save_block_box
from utils.torch_utils import select_device, smart_inference_mode
def load_model(weights, device, dnn, data, fp16):
    """Build the detection model used by the inference helper.

    The requested ``device`` is first passed through ``select_device`` and the
    result is handed, together with the remaining arguments, to
    ``DetectMultiBackend``.

    Args:
        weights: Weights argument forwarded to ``DetectMultiBackend``
            (the example below passes a path to a ``.pt`` file).
        device: Device specifier forwarded to ``select_device``
            (the example below passes ``'cuda:0'``).
        dnn: Forwarded unchanged as ``dnn=``.
        data: Forwarded unchanged as ``data=`` (the example below passes a
            dataset YAML path).
        fp16: Forwarded unchanged as ``fp16=``.

    Returns:
        The constructed ``DetectMultiBackend`` instance.
    """
    resolved_device = select_device(device)
    return DetectMultiBackend(weights, device=resolved_device, dnn=dnn, data=data, fp16=fp16)
def run_single_image_inference(model, img_path, stride, names, pt, conf_thres=0.35, iou_thres=0.7, max_det=100, augment=True, visualize=False, line_thickness=1, hide_labels=False, hide_conf=False, save_conf=False, save_crop=False, save_block=True, imgsz=(640, 640), vid_stride=1, bs=1, classes=None, agnostic_nms=False, save_txt=True, save_img=True):
    """Run symbol detection on an image and return the detections, sorted.

    The image is loaded with ``LoadImages``, passed through ``model``,
    filtered with ``non_max_suppression`` and every surviving box is converted
    to an ``xywh`` list normalised by the original image width/height.

    Args:
        model: Model returned by ``load_model``; must expose ``warmup``,
            ``device``, ``fp16`` and ``triton`` and be callable.
        img_path: Source passed to ``LoadImages``.
        stride: Model stride, used for the size check and by the loader.
        names: Class-index -> class-name lookup used for box labels.
        pt: Forwarded to ``LoadImages`` as ``auto=`` and used to pick the
            warmup batch size.
        conf_thres, iou_thres, max_det, classes, agnostic_nms: Forwarded to
            ``non_max_suppression``.
        augment: Forwarded to the model call.
        line_thickness, hide_labels, hide_conf: Control the boxes drawn on a
            local copy of the image. That copy is neither saved nor returned.
        imgsz: Requested inference size; checked against ``stride`` first.
        vid_stride: Forwarded to ``LoadImages``.
        bs: Warmup batch size when neither ``pt`` nor ``model.triton`` is set.
        visualize, save_conf, save_crop, save_block, save_txt, save_img:
            Accepted for signature compatibility only; they do not influence
            the result. Nothing is written to disk by this function.

    Returns:
        tuple: ``(class_index, [x, y, w, h])`` pairs, where the list is the
        normalised output of ``xyxy2xywh``. Pairs are sorted per image by
        ``y`` first, then ``x``. An image with no detections contributes
        nothing, so the tuple may be empty.
        NOTE(review): the original comment called the sort key "top-left
        coordinates"; upstream YOLOv5's ``xyxy2xywh`` returns the box centre,
        so this is presumably a centre-based ordering - confirm.
    """
    # Check the size before building the loader so that the loader and the
    # warmup both work with the same, checked size.
    imgsz = check_img_size(imgsz, s=stride)
    dataset = LoadImages(img_path, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)

    sorted_data_list = []

    # Pure inference: no gradients are needed, so avoid building autograd
    # state for the forward passes.
    with torch.no_grad():
        model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))

        for _path, im, im0s, _vid_cap, _log in dataset:
            # Pre-process: numpy uint8 -> float tensor in [0, 1] with a batch dim.
            im = torch.from_numpy(im).to(model.device)
            im = im.half() if model.fp16 else im.float()
            im /= 255
            if len(im.shape) == 3:
                im = im[None]

            # NOTE(review): the original snippet forced ``visualize`` to False
            # regardless of the argument; that behaviour is kept here.
            pred = model(im, augment=augment, visualize=False)
            pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

            for det in pred:  # one entry per image in the batch
                im0 = im0s.copy()
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalisation gain whwh
                annotator = Annotator(im0, line_width=line_thickness, example=str(names))

                # Always defined, so an image without detections simply adds
                # nothing instead of hitting an unbound name.
                data_for_image = []
                if len(det):
                    # Rescale boxes from the inference size to the original image size.
                    det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                    for *xyxy, conf, cls in reversed(det):
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                        c = int(cls)
                        data_for_image.append((c, xywh))
                        label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
                        annotator.box_label(xyxy, label, color=colors(c, True))

                # Order by the second xywh component (y), then the first (x).
                sorted_data_list.extend(sorted(data_for_image, key=lambda item: (item[1][1], item[1][0])))

    return tuple(sorted_data_list)
# Paths to the detector label output, the trained weights and the dataset config.
# (object_detection_output_path is not used further in this snippet.)
object_detection_output_path = 'symbol_detection/runs/detect/exp/labels'
yolo_weights_path = 'symbol_detection/runs/train/best_all/weights/best.pt'
yolo_yaml_file = 'symbol_detection/data/mydata.yaml'
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the example also runs on machines without a GPU.
yolo_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
yolo_model = load_model(yolo_weights_path, device=yolo_device, dnn=False, data=yolo_yaml_file, fp16=False)
stride, names, pt = yolo_model.stride, yolo_model.names, yolo_model.pt
# Example usage: `labels` is a tuple of (class_index, xywh) pairs.
image_path = "image.png"
labels = run_single_image_inference(yolo_model, image_path, stride, names, pt)
Contact
If you have any questions about this work, please contact Shreyanshu Bhushan at the following email address: [email protected].
License
The content of this project itself is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International license (CC BY-NC-SA 4.0).