ConvNeXt V2 fine-tuned for camera level classification
A ConvNeXt V2 base-size model fine-tuned to classify the camera level of a shot. The model was fine-tuned on the CineScale dataset for 20 epochs.
It classifies an image into one of six classes: aerial, eye, ground, hip, knee, shoulder.
Evaluation
On the test set (test.csv), the model reaches an accuracy of 90.20% and a macro-F1 of 82.28%.
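Accuracy is the fraction of all test images classified correctly, while macro-F1 averages the per-class F1 scores with equal weight per class. The two numbers can therefore diverge when performance is uneven across classes. The sketch below computes both metrics in plain Python from lists of true and predicted labels; the toy labels are made up for illustration and are not taken from test.csv.

```python
from collections import Counter


def accuracy(y_true: list[str], y_pred: list[str]) -> float:
    """Fraction of predictions that match the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def macro_f1(y_true: list[str], y_pred: list[str]) -> float:
    """Unweighted mean of per-class F1 over every class seen in y_true or y_pred."""
    classes = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # true positives per class
    pred_count = Counter(y_pred)
    true_count = Counter(y_true)
    f1_scores = []
    for c in classes:
        precision = tp[c] / pred_count[c] if pred_count[c] else 0.0
        recall = tp[c] / true_count[c] if true_count[c] else 0.0
        denom = precision + recall
        f1_scores.append(2 * precision * recall / denom if denom else 0.0)
    return sum(f1_scores) / len(f1_scores)


# Toy example (made-up labels): "eye" dominates, the rare "knee" image is missed
y_true = ["eye", "eye", "eye", "eye", "eye", "knee"]
y_pred = ["eye", "eye", "eye", "eye", "eye", "eye"]
print(f"accuracy = {accuracy(y_true, y_pred):.3f}")
print(f"macro-F1 = {macro_f1(y_true, y_pred):.3f}")
```

In this toy case accuracy stays high (5 of 6 correct), but macro-F1 is roughly halved because the missed class contributes an F1 of 0 with the same weight as the dominant class.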
How to use
```python
import torch
from torchvision.io import read_image, ImageReadMode
from torchvision.transforms import v2
from transformers import AutoModelForImageClassification

model = AutoModelForImageClassification.from_pretrained("gullalc/convnextv2-base-22k-224-cinescale-level")
im_size = 224

# https://www.pexels.com/photo/aerial-view-of-city-buildings-8783146/
image = read_image("demo/level_demo.jpg", mode=ImageReadMode.RGB)  # uint8 tensor, shape (3, H, W)

# Inference preprocessing: resize, scale pixels to [0, 1], normalize with ImageNet mean/std
transform = v2.Compose([
    v2.Resize((im_size, im_size), antialias=True),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
inputs = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values=inputs)

# Index of the highest logit for the single image in the batch, mapped to its label
predicted_label = model.config.id2label[outputs.logits.argmax(-1).item()]
print(predicted_label)
# --> aerial
```
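`outputs.logits` holds one unnormalized score per class. To see how confident the prediction is, or which camera level is the runner-up, the logits can be converted to probabilities with a softmax and ranked (with PyTorch this is `outputs.logits.softmax(-1)` followed by `torch.topk`). The sketch below shows the same steps in plain Python; the logits and the `id2label` mapping are hypothetical and do not necessarily match the model's actual label order, so use `model.config.id2label` and `outputs.logits[0].tolist()` with the real model.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_labels(logits: list[float], id2label: dict[int, str], k: int = 3) -> list[tuple[str, float]]:
    """Return the k most probable (label, probability) pairs, highest first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Hypothetical mapping and logits, for illustration only
id2label = {0: "aerial", 1: "eye", 2: "ground", 3: "hip", 4: "knee", 5: "shoulder"}
logits = [4.1, 1.3, 0.2, -0.5, -1.0, 0.9]

for label, p in top_k_labels(logits, id2label):
    print(f"{label}: {p:.3f}")
```

Subtracting the maximum logit does not change the softmax result but keeps `math.exp` from overflowing on large scores.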
Training Details
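```python
## Training transforms
import torch
from torchvision.transforms import v2

im_size = 224

# All five augmentations run in a random order on every call; flip, sharpness and
# grayscale additionally fire only with their own probability
randomorder = v2.RandomOrder([
    v2.RandomHorizontalFlip(),
    v2.GaussianBlur(5),
    v2.RandomAdjustSharpness(2),
    v2.RandomGrayscale(p=0.2),
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
])
train_transform = v2.Compose([
    v2.Resize((im_size, im_size), antialias=True),
    randomorder,
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```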
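`v2.RandomOrder` applies every transform in its list, but shuffles their order on each call, which matters because augmentations such as blur, sharpening and color jitter do not commute. The sketch below mimics that composition pattern with plain Python functions acting on a number instead of an image; `random_order` is a hypothetical helper that illustrates the idea and is not torchvision's implementation.

```python
import random
from typing import Callable

Transform = Callable[[float], float]


def random_order(transforms: list[Transform], rng: random.Random) -> Transform:
    """Compose transforms so that all of them run, in a new random order on every call."""
    def apply(x: float) -> float:
        order = list(range(len(transforms)))
        rng.shuffle(order)
        for i in order:
            x = transforms[i](x)
        return x
    return apply


def add_one(x: float) -> float:
    return x + 1


def double(x: float) -> float:
    return x * 2


# add_one and double do not commute, so the result depends on the order they run in
rng = random.Random(0)
augment = random_order([add_one, double], rng)
results = {augment(1.0) for _ in range(50)}
print(sorted(results))  # each value is either (1 + 1) * 2 = 4.0 or 1 * 2 + 1 = 3.0
```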
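```python
## Training Arguments
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="convnextv2-cinescale-level",  # hypothetical path; not given in the original card
    evaluation_strategy="epoch",  # renamed `eval_strategy` in newer transformers releases
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=128,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=128,
    num_train_epochs=30,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # requires a compute_metrics function that returns an "f1" key
    dataloader_num_workers=32,
    torch_compile=True,
)
```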
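With gradient accumulation, each optimizer update sees `per_device_train_batch_size × gradient_accumulation_steps × number of devices` images, i.e. 128 × 4 = 512 per device with the arguments above. The card does not state the training-set size or the number of GPUs, so the sketch below uses hypothetical values to show how the effective batch size, the number of optimizer steps and the warmup length implied by `warmup_ratio=0.1` relate. It approximates the `Trainer` bookkeeping (warmup steps as the ceiling of `warmup_ratio × total steps`); exact rounding may differ between transformers versions.

```python
import math


def schedule_summary(num_examples: int, per_device_batch: int, grad_accum: int,
                     num_devices: int, num_epochs: int, warmup_ratio: float) -> dict[str, int]:
    """Approximate batch/step bookkeeping for a Trainer-style run (last partial batch kept)."""
    effective_batch = per_device_batch * grad_accum * num_devices
    # Mini-batches per epoch on each device, then optimizer updates per epoch
    batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    updates_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = updates_per_epoch * num_epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "updates_per_epoch": updates_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
    }


# Hypothetical dataset size and device count; batch size, accumulation, epochs and
# warmup ratio match the training arguments above
summary = schedule_summary(num_examples=20_000, per_device_batch=128, grad_accum=4,
                           num_devices=1, num_epochs=30, warmup_ratio=0.1)
print(summary)
```

Because the learning-rate warmup is defined as a ratio, its length in steps scales automatically with the dataset size and the number of devices.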