# LlaMol training config: config/train/llama2-M-Full.yaml
io:
  # I/O
  out_dir : "out"
  eval_interval : 500
  log_interval : 10
  eval_iters : 10
  eval_only : false # if True, script exits right after the first eval
  always_save_checkpoint : false # if True, always save a checkpoint after each eval
  init_from : "scratch" # 'scratch' or 'resume'
  resume_when_snapshot_available: true
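  # Evaluation cadence implied by the values above, assuming max_iters = 100000
  # (set under optimizer below): eval runs roughly 200 times per training run
  # (100000 / 500), averaging the loss over eval_iters = 10 batches each time.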
loader:
  batch_size : 384 # if gradient_accumulation_steps > 1, this is the micro-batch size
  max_seq_len : 768
  dataset : "smiles"
  processed_dataset_ckpt : "processed_dataset_None.pkl"
  fragment_creator : null
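  # Approximate tokens per optimizer step, assuming sequences are packed/padded to
  # max_seq_len and micro-batches are accumulated via
  # optimizer.gradient_accumulation_steps = 4 below:
  #   384 (batch_size) * 768 (max_seq_len) * 4 (grad accum) = 1,179,648 tokens per update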
model:
  dim : 256
  n_layers : 8
  n_heads : 8
  multiple_of : 128
  dropout : 0.1
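  # With dim = 256 and n_heads = 8, each attention head is 32-dimensional; multiple_of
  # rounds the feed-forward hidden size up to a multiple of 128, assuming the
  # llama2.c-style MLP sizing this config appears to follow.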
context:
  context_keys: ["logp", "sascore", "mol_weight"]
  context_dims : [1,1,1]
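  # Each context key conditions generation on one scalar molecular property
  # (logP, synthetic accessibility score, molecular weight), hence a dimension of 1
  # per key; this reading assumes the keys map one-to-one onto context_dims above.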
optimizer:
  gradient_accumulation_steps : 4 # used to simulate larger batch sizes
  learning_rate : 1e-4 # max learning rate
  max_iters : 100000 # total number of training iterations
  weight_decay : 1e-1
  beta1 : 0.9
  beta2 : 0.95
  grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
  # learning rate decay settings
  decay_lr : true # whether to decay the learning rate
  warmup_iters : 1000 # how many steps to warm up for
  lr_decay_iters : 100000 # should be ~= max_iters per Chinchilla
  min_lr : 0.0 # minimum learning rate; the usual guideline is ~= learning_rate/10 per Chinchilla, here the decay goes all the way to 0.0
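  # Sketch of the resulting schedule, assuming the usual nanoGPT-style warmup plus
  # cosine decay: linear ramp from ~0 to 1e-4 over the first 1000 iters, then cosine
  # decay down to min_lr = 0.0 by iter 100000 (the decay horizon equals max_iters here).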
dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
compile : false # Use torch.compile, but in my test this is really slow
label : "llama2-M-Full"