|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
|
|
from augments import RandAug |
|
|
|
|
|
class TextlineDataset(Dataset):
    """Dataset pairing text-line images with tokenized transcriptions.

    Each item is a dict with two entries:

    * ``"pixel_values"``: the ``pixel_values`` produced by ``processor`` for
      the image, with its leading dimension squeezed away.
    * ``"labels"``: a tensor of token ids for the transcription in which every
      occurrence of the tokenizer's pad token id is replaced by ``-100``.

    Args:
        root_dir: String prefix concatenated directly with each ``file_name``
            to form the image path. No separator is inserted, so it normally
            needs to end with one.
        df: DataFrame with ``"file_name"`` and ``"text"`` columns. Rows are
            addressed by position, so any index is accepted.
        processor: Callable taking ``(image, return_tensors="pt")`` whose
            result has a ``pixel_values`` attribute, and which exposes a
            ``tokenizer`` attribute with a ``pad_token_id``.
        augment: When true, each image is passed through ``RandAug`` before
            it is handed to the processor.
        max_target_length: ``max_length`` forwarded to the tokenizer.
    """

    def __init__(self, root_dir, df, processor, augment=False, max_target_length=128):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.augment = augment
        # Created unconditionally so the attribute exists even when
        # augment is False.
        self.augmentator = RandAug()
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of rows in the DataFrame."""
        return len(self.df)

    def _lookup_row(self, idx):
        """Return ``(file_name, text)`` for the row at position ``idx``."""
        # .iloc is positional: plain ``df[col][idx]`` is label-based and picks
        # the wrong row (or raises KeyError) when the index is not 0..n-1,
        # e.g. after filtering or splitting without reset_index().
        file_name = self.df['file_name'].iloc[idx]
        text = self.df['text'].iloc[idx]
        return file_name, text

    def __getitem__(self, idx):
        """Load, optionally augment, and encode the sample at position ``idx``.

        Returns:
            dict with ``"pixel_values"`` and ``"labels"`` as described on the
            class.
        """
        file_name, text = self._lookup_row(idx)

        # The context manager releases the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(self.root_dir + file_name) as raw_image:
            image = raw_image.convert("RGB")

        if self.augment:
            image = self.augmentator(image)

        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        tokenizer = self.processor.tokenizer
        token_ids = tokenizer(
            str(text),
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length,
        ).input_ids

        # Replace padding with -100 (PyTorch CrossEntropyLoss's default
        # ignore_index) -- presumably so the loss skips padded positions.
        pad_id = tokenizer.pad_token_id
        labels = [-100 if token == pad_id else token for token in token_ids]

        # squeeze(0) removes only the leading dimension (assumed to be a
        # batch dimension of size 1 -- TODO confirm for the processor in use);
        # a bare squeeze() would also drop any other size-1 dimension.
        return {
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels),
        }
|
|