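"""Training entry point.

Wires a Hydra config into the trainer dataclasses, sets up rank-aware logging,
and launches a Trainer run, optionally under torch.profiler.
"""
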
from trainer import (
    IOConfig,
    LoaderConfig,
    Trainer,
    TrainerArgs,
    ModelArgs,
    ContextArgs,
    OptimizerConfig,
)
from torch.distributed.elastic.multiprocessing.errors import record

import hydra
from omegaconf import DictConfig, OmegaConf
import logging
import sys
import os
import torch


def setup_logger(run_name: str, log_path: str):
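    """Configure the root logger with a stdout handler and a per-run file handler.

    Under torchrun/DDP (detected via the RANK env var) every record is tagged
    with DDP[rank, local_rank, world_size] so interleaved output from multiple
    processes stays attributable.
    """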
    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
    if ddp:
        ddp_rank = int(os.environ["RANK"])
        ddp_local_rank = int(os.environ["LOCAL_RANK"])
        ddp_world_size = int(os.environ["WORLD_SIZE"])

        formatter = logging.Formatter(
            f"[%(levelname)s] DDP[{ddp_rank},{ddp_local_rank},{ddp_world_size}] %(asctime)s - [%(filename)s:%(lineno)d]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    else:
        formatter = logging.Formatter(
            r"[%(levelname)s] %(asctime)s - [%(filename)s:%(lineno)d]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)

    os.makedirs(log_path, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(log_path, f"train_{run_name}.log"))
    file_handler.setFormatter(formatter)

    logging.basicConfig(level=logging.INFO, handlers=[stream_handler, file_handler])

    return logging.getLogger()


# record() captures uncaught exceptions from this process and writes them to the
# torchelastic error file so torchrun can surface the failing worker's traceback.
@record
@hydra.main(version_base=None, config_path="config", config_name="config")
def main(cfg: DictConfig) -> None:
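    """Hydra entry point: maps the 'train' config group onto the trainer
    dataclasses, builds a Trainer, and runs training (optionally profiled).
    """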
    logger = setup_logger(
        cfg.get("run_name", "default"), cfg.get("io", {"out_dir": "out"})["out_dir"]
    )

    logger.info("Using config")
    logger.info(cfg)

    cfg = cfg["train"]
    io_conf = IOConfig(**cfg.get("io", {}))
    loader_conf = LoaderConfig(**cfg.get("loader", {}))
    model_args = ModelArgs(**cfg.get("model", {}))
    ctx_args = ContextArgs(**cfg.get("context", {}))
    optimizer_conf = OptimizerConfig(**cfg.get("optimizer", {}))
    train_args = TrainerArgs(
        io_conf=io_conf,
        loader_conf=loader_conf,
        model_conf=model_args,
        context_conf=ctx_args,
        optimizer_conf=optimizer_conf,
        run_name=cfg.get("label", "train_run"),
    )

    # Cap the thread count so CPU-only / test runs don't max out every core
    torch.set_num_threads(8)

    trainer = Trainer(
        train_args=train_args,
        dtype=cfg.get("dtype", "float16"),
        compile=cfg.get("compile", False),
    )
    should_profile = cfg.get("profile", False)

    if should_profile:
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        ) as p:
            trainer.train()

        print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

    else:
        trainer.train()


if __name__ == "__main__":
    # python train.py train=llama2-M-Full train.model.dim=1024
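    # The 'train' group is expected to provide sub-keys matching the dataclasses
    # above; a hypothetical sketch of the corresponding YAML (key names inferred
    # from this file, not taken from the actual config):
    #   train:
    #     label: train_run
    #     dtype: float16
    #     compile: false
    #     profile: false
    #     io: { out_dir: out }
    #     loader: { ... }
    #     model: { dim: 1024, ... }
    #     context: { ... }
    #     optimizer: { ... }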
    main()