# LlaMol/config/train/llama2-Debug.yaml (uploaded by doammii, commit 55d9b0c "Add LlaMol codes")
io:
  # I/O settings for the debug run. The keys must be nested under `io`;
  # at column 0 they would become top-level keys and `io` would be null.
  out_dir: "debug"
  eval_interval: 10
  log_interval: 10
  eval_iters: 5
  eval_only: false  # if true, script exits right after the first eval
  always_save_checkpoint: true  # if true, always save a checkpoint after each eval
  init_from: "scratch"  # 'scratch' or 'resume'
  resume_when_snapshot_available: false
loader:
  # Data-loading settings; nested under `loader` so they form one mapping.
  batch_size: 4  # if gradient_accumulation_steps > 1, this is the micro-batch size
  max_seq_len: 768
  dataset: "smiles"
  # NOTE(review): presumably a path/filename of a pre-processed dataset pickle
  # resolved by the loader - confirm where it is looked up.
  processed_dataset_ckpt: "processed_dataset_500000.pkl"
  fragment_creator: "rss"
model:
  # Model hyper-parameters for the debug run (single layer, single head);
  # nested under `model` so they form one mapping.
  dim: 32
  n_layers: 1
  n_heads: 1
  multiple_of: 16
  dropout: 0.1
context:
  # Conditioning properties; nested under `context` so they form one mapping.
  context_keys: ["logp", "sascore", "mol_weight"]
  # One entry per context key above.
  # NOTE(review): presumably paired with context_keys by position - confirm.
  context_dims: [1, 1, 1]
optimizer:
  # Optimizer settings; nested under `optimizer` so they form one mapping.
  gradient_accumulation_steps: 4  # used to simulate larger batch sizes
  # Exponent floats are written with a decimal point (1.0e-4, not 1e-4) so that
  # YAML 1.1 loaders such as PyYAML resolve them as floats instead of strings.
  learning_rate: 1.0e-4  # max learning rate
  max_iters: 20  # total number of training iterations
  weight_decay: 1.0e-1
  beta1: 0.9
  beta2: 0.95
  grad_clip: 1.0  # clip gradients at this value, or disable if == 0.0
  # learning rate decay settings
  decay_lr: true  # whether to decay the learning rate
  warmup_iters: 10  # how many steps to warm up for
  # NOTE(review): lr_decay_iters (100) is larger than max_iters (20), which
  # contradicts the guidance in the inline comment - confirm this is intended
  # for the debug configuration.
  lr_decay_iters: 100  # should be ~= max_iters per Chinchilla
  min_lr: 0.0  # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
# Training precision: float16 here; float32 or bfloat16 can be used instead.
dtype: "float16"
# Whether to use torch.compile; the author notes it was really slow in their test.
compile: false
label: "llama2-Debug"
profile: false