dance-classifier / train.py
waidhoferj's picture
Refactor config style and reorganize files
557fb53
raw
history blame
1.05 kB
from typing import Callable
import importlib
import yaml
from argparse import ArgumentParser
import os
ROOT_DIR = os.path.basename(os.path.dirname(__file__))
def get_training_fn(id: str) -> Callable:
module_name, fn_name = id.rsplit(".", 1)
module = importlib.import_module("models." + module_name, ROOT_DIR)
return getattr(module, fn_name)
def get_config(filepath: str) -> dict:
with open(filepath, "r") as f:
config = yaml.safe_load(f)
return config
if __name__ == "__main__":
parser = ArgumentParser(
description="Trains models on the dance dataset and saves weights."
)
parser.add_argument(
"--config",
help="Path to the yaml file that defines the training configuration.",
default="models/config/train_local.yaml",
)
args = parser.parse_args()
config = get_config(args.config)
training_fn_path = config["training_fn"]
print(f"Config: {args.config}\nTrainer Id: {training_fn_path}")
train = get_training_fn(training_fn_path)
train(config)