"""Training entry point: fit a VerseFxClassifier on VerseDataModule.

Reads settings from ``config.yaml``, optionally logs to Weights & Biases,
and trains with PyTorch Lightning using early stopping on ``val/F1``.
"""
import matplotlib

# Select the non-interactive Agg backend before any other module gets a
# chance to import pyplot with a different backend.
matplotlib.use('Agg')

import os

import wandb
import pytorch_lightning as pl

from data import VerseDataModule
from model import VerseFxClassifier
from utils.config import get_config


def main(config_path="config.yaml"):
    """Load the config, set up logging and callbacks, and run training.

    Args:
        config_path: Path passed to ``get_config``; defaults to the
            ``config.yaml`` the script has always used.
    """
    config = get_config(config_path)

    # Control keys are popped *before* the config is handed to wandb.init,
    # so the API key and the script-level switches are not recorded as run
    # hyperparameters. Missing keys fall back to safe defaults instead of
    # raising KeyError.
    wandb_mode = 'online' if config.pop('USE_WANDB', False) else 'disabled'
    wandb_api_key = config.pop('WANDB_API_KEY', None)
    save_model = config.pop('SAVE_MODEL', False)

    # Only authenticate when we will actually talk to the W&B backend;
    # a disabled run needs no credentials.
    if wandb_mode != 'disabled':
        wandb.login(key=wandb_api_key)

    run = wandb.init(
        project=f'fx-{config["task"]}-baseline-3d',
        entity='ifl-diva',
        config=config,
        mode=wandb_mode,
    )

    # Read hyperparameters back through wandb.config (rather than the raw
    # config object) and hand that to the model and the data module.
    hparams = wandb.config

    wandb_logger = pl.loggers.WandbLogger()

    model = VerseFxClassifier(hparams)
    data = VerseDataModule(hparams)

    # Stop training once validation F1 has not improved for the configured
    # number of validation checks.
    callbacks = [
        pl.callbacks.EarlyStopping(
            monitor="val/F1",
            mode="max",
            patience=hparams.early_stopping_patience,
        ),
    ]

    if save_model:
        # auto_insert_metric_name=False keeps the metric name out of the
        # file name automatically; the template below spells it out as
        # "val_F1=" instead of the raw "val/F1".
        callbacks.append(
            pl.callbacks.model_checkpoint.ModelCheckpoint(
                monitor='val/F1',
                mode="max",
                dirpath='saved_models',
                filename=f"{wandb.run.name}-epoch{{epoch}}-val_F1={{val/F1:.3f}}",
                auto_insert_metric_name=False,
            )
        )

    # NOTE(review): `gpus=` was deprecated in Lightning 1.7 and removed in
    # 2.0 (replaced by accelerator=/devices=). Left unchanged to stay
    # compatible with the Lightning version this project pins — confirm.
    trainer = pl.Trainer(
        gpus=1,
        logger=wandb_logger,
        log_every_n_steps=2,
        callbacks=callbacks,
    )

    # Using the run as a context manager finishes it when fit() returns or
    # raises.
    with run:
        trainer.fit(model, data)


if __name__ == '__main__':
    main()