# court-records-htr / dataset.py
# Author: MikkoLipsanen — uploaded in commit 1838a16 ("Upload 3 files")
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from augments import RandAug
# Torch dataset
class TextlineDataset(Dataset):
    """Map-style dataset of text-line images paired with their transcriptions.

    Each item comes from one row of ``df``. The image at
    ``root_dir + file_name`` is loaded as RGB, optionally augmented, and
    turned into pixel values by ``processor``. The row's ``text`` is
    tokenized into fixed-length label ids.

    Args:
        root_dir: Prefix prepended to every ``file_name`` by plain string
            concatenation, so it should end with a path separator
            (e.g. ``"data/lines/"``).
        df: pandas DataFrame with ``file_name`` and ``text`` columns. Rows are
            addressed by position, so any index works, including one left
            over from a train/test split.
        processor: Object callable as ``processor(image, return_tensors="pt")``
            that returns something with ``.pixel_values`` and exposes a
            ``.tokenizer``. Presumably a Hugging Face TrOCR-style processor
            (confirm against caller).
        augment: If True, apply ``RandAug`` to each image before processing.
        max_target_length: Length every label sequence is padded/truncated to.
    """

    def __init__(self, root_dir, df, processor, augment=False, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.augment = augment
        # Always built, so flipping ``self.augment`` after construction still works.
        self.augmentator = RandAug()
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows (samples) in ``df``."""
        return len(self.df)

    def _get_record(self, idx):
        """Return ``(file_name, text)`` for the ``idx``-th row, by position.

        ``.iloc`` is used instead of ``df[col][idx]``. The latter looks up the
        index *label* ``idx``, which raises ``KeyError`` (or silently picks
        another row) when ``df`` does not carry a 0..n-1 index.
        """
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]
        return file_name, text

    def __getitem__(self, idx):
        """Load, optionally augment, and encode sample ``idx``.

        Returns:
            dict with
              ``"pixel_values"``: processor output with its leading batch
                axis removed (assumed (C, H, W); confirm with processor).
              ``"labels"``: 1-D tensor of token ids padded/truncated to
                ``max_target_length``, with PAD ids replaced by -100.
        """
        file_name, text = self._get_record(idx)

        # The context manager releases the file handle even for multi-frame
        # formats (e.g. TIFF) that Pillow keeps open after loading.
        # convert() returns a new image, so it remains valid after close.
        with Image.open(self.root_dir + file_name) as img:
            image = img.convert("RGB")

        if self.augment:
            image = self.augmentator(image)

        # The processor presumably resizes + normalizes and returns a batch of
        # one. Drop only that batch axis: a bare squeeze() would also remove
        # any other size-1 axis.
        pixel_values = self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

        # str() guards against non-string cells (e.g. NaN or numbers) in the
        # text column.
        token_ids = self.processor.tokenizer(
            str(text),
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length,
        ).input_ids

        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so
        # padding positions do not contribute to the loss.
        pad_id = self.processor.tokenizer.pad_token_id
        labels = [tok if tok != pad_id else -100 for tok in token_ids]

        return {"pixel_values": pixel_values, "labels": torch.tensor(labels)}