# We want to train a classification model on our own data

```python
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from sklearn.preprocessing import LabelEncoder
import numpy as np
from datasets import load_dataset
from joblib import cpu_count
from fire import Fire


def get_feature_function(preprocessor, encoder):
    """Build the on-the-fly transform applied via `datasets.set_transform`."""
    def feature_extraction_function(examples):
        data = {}
        data["pixel_values"] = preprocessor(examples["image"], return_tensors="pt").pixel_values
        # Plain Python ints so the default data collator builds a torch.long
        # "labels" tensor; integer labels make the model use single-label
        # cross-entropy (one-hot float labels would silently switch it to
        # multi-label BCE).
        data["label"] = encoder.transform(examples["product_subcategory_name"]).tolist()
        return data
    return feature_extraction_function


def train(
    dataset_id,
    hub_model_id,
    model_id="facebook/convnextv2-atto-1k-224",
    run_id="convnextv2-atto-dataset",
    logging_steps=100,
):
    dataset = load_dataset(dataset_id)
    preprocessor = AutoImageProcessor.from_pretrained(model_id)

    # Fit a label encoder on the class names seen in the training split.
    labels = np.unique(dataset["train"]["product_subcategory_name"])
    encoder = LabelEncoder().fit(labels)

    # `ignore_mismatched_sizes=True` lets us swap the 1000-class ImageNet head
    # for a freshly initialized head sized to our own classes.
    model = AutoModelForImageClassification.from_pretrained(
        model_id,
        ignore_mismatched_sizes=True,
        num_labels=len(encoder.classes_),
        id2label={i: label for i, label in enumerate(encoder.classes_)},
        label2id={label: i for i, label in enumerate(encoder.classes_)},
    )

    dataset.set_transform(get_feature_function(preprocessor, encoder))

    training_args = TrainingArguments(
        output_dir=f"results/{run_id}",
        # Keep the raw "image" column so the transform can still see it.
        remove_unused_columns=False,
        learning_rate=5e-5,
        per_device_train_batch_size=32,
        gradient_accumulation_steps=4,
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        warmup_ratio=0.1,
        report_to=["tensorboard"],
        run_name=run_id,
        logging_steps=logging_steps,
        eval_steps=logging_steps,
        save_steps=logging_steps,
        save_total_limit=1,
        save_strategy="steps",
        evaluation_strategy="steps",
        skip_memory_metrics=False,
        logging_first_step=True,
        push_to_hub=True,
        hub_model_id=hub_model_id,
        hub_private_repo=True,
        hub_strategy="every_save",
        save_safetensors=True,
        # Data loading: prefetch batches in worker processes and pin host
        # memory to keep GPU utilization efficient and stable.
        dataloader_num_workers=max(1, cpu_count() // 4),
        dataloader_pin_memory=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
    )
    trainer.train()


if __name__ == "__main__":
    Fire(train)
```
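
The `Trainer` above reports only the loss at each evaluation step. If you also want accuracy, a `compute_metrics` function can be passed in. Below is a minimal sketch using the `evaluate` library, which is an assumption here (it is not a dependency of the script above; any function with the same signature works):

```python
import numpy as np
import evaluate  # assumed extra dependency: pip install evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # The Trainer hands us (logits, labels) as numpy arrays.
    logits, labels = eval_pred
    return accuracy.compute(predictions=np.argmax(logits, axis=-1), references=labels)
```

Wire it in with `Trainer(..., compute_metrics=compute_metrics)`.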
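
Since `Fire(train)` turns the function's parameters into command-line flags, the job can be launched from the shell or called directly from Python. A sketch of the latter, with placeholder IDs you would replace with your own dataset and Hub repositories:

```python
# Run after importing/defining `train` from the script above.
# The IDs below are hypothetical placeholders.
train(
    dataset_id="my-org/product-images",
    hub_model_id="my-org/convnextv2-atto-products",
)
```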