# LlaMol / preprocess_dataset.py
import argparse
import json
import multiprocessing
import os
import pickle
import random
from functools import partial
from multiprocessing import Pool
from typing import List

import numpy as np
import pandas as pd
import requests
import swifter
import torch
import torch.distributed as dist
from rdkit import Chem
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map, thread_map

from fragment_creator import BaseFragmentCreator, BricksFragmentCreator, Fragment
from tokenizer import SmilesTokenizer

DATA_CACHE_DIR = "data"

def _tokenize_smiles(
    smi: str,
tokenizer: SmilesTokenizer = None,
max_smiles_len=256,
log_output=True,
):
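    """Encode a single SMILES string with the tokenizer.

    Returns the token ids, or None if the tokenized length exceeds
    max_smiles_len (such molecules are dropped from the dataset).
    """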
tokens = tokenizer.encode(smi)
if len(tokens) > max_smiles_len:
if log_output:
print(f"Removing to long {smi} with smiles len of {len(tokens)} ")
return None
return tokens
def _tokenize_scaffolds(smi: str, tokenizer=None, max_smiles_len=256, log_output=True):
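    """Compute the Murcko scaffold of a SMILES string and encode it.

    The leading [CLS] and trailing [SEP] tokens are stripped; returns None if
    the scaffold tokenization exceeds max_smiles_len.
    """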
smi = MurckoScaffoldSmiles(smi)
tokens = tokenizer.encode(smi)
    tokens = tokens[1:-1]  # remove the [CLS] and [SEP] tokens
if len(tokens) > max_smiles_len:
if log_output:
print(f"Removing to long {smi} with smiles len of {len(tokens)} ")
return None
return tokens
def pad_batch(src, pad_idx):
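    """Right-pad a list of token sequences with pad_idx and return the result
    transposed to shape (max_len, batch_size)."""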
max_len = max([len(d) for d in src])
padded_src = np.ones([len(src), max_len]) * pad_idx
for i, j in enumerate(src):
padded_src[i][0 : len(j)] = j
    # try to predict the next token from the previous tokens,
    # essentially reconstructing the src sequence from the embeddings and the previous tokens
padded_src = padded_src.T
return padded_src
def pretokenize(
data_file=os.path.join(
DATA_CACHE_DIR, "FULL_combined_zinc_pubchemqc_qm9_pc9_reddb_chembl.parquet"
),
tokenizer=SmilesTokenizer(),
limit=None,
context=["logp", "sascore", "mol_weight"],
out_name: str = "processed_dataset",
remove_nan_context_rows: bool = False,
):
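    """Tokenize the SMILES column of a parquet dataset and cache it as a pickle.

    Multi-fragment SMILES (containing '.'), rows that fail tokenization, and
    optionally rows with NaN context values are dropped. If an "origin" column
    exists, it is encoded as integer codes (mapping saved to *_origins.json).
    The token lists, raw SMILES, scaffold tokens, and the selected context
    columns are written to .cache/<out_name>_<limit>.pkl.
    """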
df = pd.read_parquet(data_file)
if limit is not None:
# smiles_list = df.smiles[:limit]
df = df.sample(n=limit) # df[:limit]
# NOTE: Set here if necessary, but for memory efficiency not duplicating millions of smiles
# smiles_list = df.smiles
else:
# shuffle the rows
df = df.sample(frac=1.0)
    cpu_count = multiprocessing.cpu_count()  # alternatively: min(int(multiprocessing.cpu_count() * 0.8), 8)
    print(f"Running on {cpu_count} CPUs")
tqdm.pandas()
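    # Drop multi-fragment molecules: SMILES containing '.' are mapped to None
    # and filtered out by the NaN masks below.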
df["scaffolds"] = df["smiles"].progress_map(lambda s: None if "." in s else s)
df["smiles"] = df["scaffolds"].copy()
orig_len = len(df)
if context is not None:
if df.get("origin") is not None:
origins = df["origin"].unique()
origin_dics = {}
for i, o in enumerate(origins):
df.loc[df["origin"] == o, "origin"] = i
origin_dics[o] = i
df["origin"] = df["origin"].astype(float)
with open(
os.path.join(
DATA_CACHE_DIR, os.path.basename(data_file) + "_origins.json"
),
"w",
) as f:
json.dump(origin_dics, f)
mask = (
~df["smiles"].isna()
& (
(~df[context].isna()).all(axis=1)
if remove_nan_context_rows
else np.ones(len(df["smiles"]), dtype=bool)
)
& ~df["scaffolds"].isna()
)
else:
mask = ~df["smiles"].isna()
error_count = np.count_nonzero(~mask)
df = df[mask]
# print("HELLO")
# print("***"*10)
# tokenizer.batch_encode_plus()
# df["scaffolds"] = df["scaffolds"].swifter.apply(
# partial(_tokenize_scaffolds, tokenizer=tokenizer, log_output=False)
# )
# df["scaffolds"] = df["scaffolds"].swifter.apply(
# partial(_tokenize_scaffolds, tokenizer=tokenizer, log_output=False)
# )
df["tokens"] = df["smiles"].swifter.apply(
partial(_tokenize_smiles, tokenizer=tokenizer, log_output=False)
)
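    # The scaffold column currently reuses the full-molecule tokens; the
    # separate Murcko-scaffold tokenization above is left disabled.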
df["scaffolds"] = df["tokens"].copy()
mask = ~df["tokens"].isna() & ~df["scaffolds"].isna()
df = df[mask]
error_count += np.count_nonzero(~mask)
# Shuffle the data
df = df.sample(frac=1).reset_index(drop=True)
# with Pool(cpu_count) as p:
# df["scaffolds"] = list(
# p.map(partial( _tokenize_scaffolds ,tokenizer=tokenizer, log_output=False), tqdm(df.smiles.to_numpy(),total=len(df)), chunksize=1000),
# )
# df["smiles"] = list(
# p.map(partial( _tokenize_smiles ,tokenizer=tokenizer, log_output=False), tqdm(df.smiles.to_numpy(),total=len(df)), chunksize=1000),
# )
if context is not None:
context_list = df[context].to_numpy()
context_dict = {k: context_list[:, i] for i, k in enumerate(context)}
else:
context_dict = {}
print(f"Error count: {error_count} / {orig_len} = {error_count/orig_len}")
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
os.makedirs(cache_path, exist_ok=True)
out_path = os.path.join(cache_path, f"{out_name}_{limit}.pkl")
with open(out_path, "wb") as f:
pickle.dump(
{
"tokens": df["tokens"].tolist(),
"smiles": df["smiles"].tolist(),
"scaf": df["scaffolds"].tolist(),
**context_dict,
},
f,
)
print(f"Saved to {out_path}")
print("Done.")
class PretokDataset(torch.utils.data.Dataset):
"""Loads pretokenized example from disk and returns them as PyTorch tensors."""
def __init__(self, split, pad_token_id, dataset="processed_dataset.pkl"):
super().__init__()
self.split = split
self.dataset = dataset
self.pad_token_id = pad_token_id
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
with open(os.path.join(cache_path, self.dataset), "rb") as f:
self.data_dict = pickle.load(f)
# split out 10% of the data for validation
split_ix = int(len(self.data_dict["tokens"]) * 0.9)
if self.split == "train":
self.data_dict = {k: self.data_dict[k][:split_ix] for k in self.data_dict}
elif self.split == "val":
self.data_dict = {k: self.data_dict[k][split_ix:] for k in self.data_dict}
else:
            raise RuntimeError(f"Unknown split {self.split!r}; expected 'train' or 'val'")
def __len__(self):
return len(self.data_dict["tokens"])
def __getitem__(self, idx):
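        """Return a single example: padded token/scaffold tensors, the raw
        SMILES, and any context properties as float tensors."""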
m = self.data_dict
start = idx
end = idx + 1
# calling .astype will copy the data into a new numpy array, now in RAM
padded_tokens = pad_batch(m["tokens"][start:end], self.pad_token_id)
chunk = torch.from_numpy((padded_tokens).astype(np.int64))
padded_scaffolds = torch.from_numpy(
pad_batch(m["scaf"][start:end], self.pad_token_id).astype(np.int64)
)
item = {
"seq": chunk,
"scaf": padded_scaffolds,
"smiles": m["smiles"][start:end],
**{
k: torch.tensor(m[k][start:end], dtype=torch.float32)
for k in m
if k != "scaf" and k != "tokens" and k != "smiles"
},
}
return item
def padding_collate_fn(
data, tokenizer: SmilesTokenizer, fragment_creator: BaseFragmentCreator
):
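    """Collate a list of dataset items into one padded batch.

    Token sequences are padded to the longest sequence in the batch and
    transposed to (seq_len, batch_size). The fragment conditioning is either
    taken from the precomputed scaffold tokens or rebuilt on the fly by
    fragment_creator.
    """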
# data = list of dicts
pad_idx = tokenizer.pad_token_id
src = [d["seq"] for d in data]
max_len = max([len(d) for d in src])
padded_src = np.ones([len(src), max_len]) * pad_idx
for i, j in enumerate(src):
padded_src[i][0 : len(j)] = j.ravel()
if fragment_creator is None:
smiles_context = [d["scaf"] for d in data]
else:
        # Remove the start and end tokens added by the tokenizer with [1:-1]
smiles_context = []
for d in data:
s = d["smiles"][0]
tokens = d["seq"]
frag = fragment_creator.create_fragment(Fragment(smiles=s, tokens=tokens))
if frag.tokens is not None:
smiles_context.append(frag.tokens)
else:
smiles_context.append(
torch.tensor(
tokenizer.encode(frag.smiles)[1:-1],
dtype=torch.long,
device=tokens.device,
)
)
max_len_ctx = max([len(d) for d in smiles_context])
padded_smiles_context = np.ones([len(smiles_context), max_len_ctx]) * pad_idx
for i, j in enumerate(smiles_context):
padded_smiles_context[i][0 : len(j)] = j.ravel()
    # try to predict the next token from the previous tokens,
    # essentially reconstructing the src sequence from the embeddings and the previous tokens
padded_src = padded_src.T
original_context_keys = [
k for k in data[0].keys() if k != "seq" and k != "scaf" and k != "smiles"
]
context_out_dict = {k: [] for k in original_context_keys}
for k in original_context_keys:
val_list = []
for d in data:
val_list.append(d[k])
context_out_dict[k] = torch.concat(val_list, dim=0)
return {
"src": torch.tensor(padded_src, dtype=torch.long), # for (seq_len, batch_size)
"fragment": torch.tensor(padded_smiles_context.T, dtype=torch.long),
"context": context_out_dict,
}
class SmilesTask:
@staticmethod
def iter_batches(
split,
batch_size,
device,
context_keys: List[str],
num_workers=0,
dataset="processed_dataset.pkl",
fragment_creator: BaseFragmentCreator = BricksFragmentCreator(),
):
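        """Yield batches from the pretokenized dataset, moved to `device`.

        Each batch contains teacher-forcing inputs/targets ("src"/"tgt"), the
        fragment conditioning tensor, and the requested context properties.
        """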
tokenizer = SmilesTokenizer()
ds = PretokDataset(split, tokenizer.pad_token_id, dataset=dataset)
is_ddp = int(os.environ.get("RANK", -1)) != -1
dl = torch.utils.data.DataLoader(
ds,
batch_size=batch_size,
pin_memory=True,
num_workers=num_workers,
shuffle=False,
sampler=DistributedSampler(ds) if is_ddp else None,
collate_fn=lambda batch: padding_collate_fn(
batch, tokenizer, fragment_creator
),
)
for data in dl:
data["src"] = data["src"].to(device, non_blocking=True)
data["tgt"] = data["src"].to(device, non_blocking=True)
data["src"] = data["src"][:-1, :].T # batch_size, seq_len
data["tgt"] = data["tgt"][1:, :].T # batch_size, seq_len
data["fragment"] = (
data["fragment"].to(device, non_blocking=True).T
) # batch_size, seq_len
keys = list(data["context"].keys())
for d in keys:
if d not in context_keys:
del data["context"][d]
else:
data["context"][d] = data["context"][d].to(
device, non_blocking=True
)
yield data
if __name__ == "__main__":
pretokenize(
data_file=os.path.join(
DATA_CACHE_DIR,
"OrganiX13.parquet",
),
        limit=None,  # how many molecules to process; if None, all molecules are processed
context=["logp", "sascore", "mol_weight"],
out_name="processed_dataset",
remove_nan_context_rows=False,
)
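# Example (sketch, not part of the original script): consuming the pretokenized
# batches during training. Assumes the pickle produced by the call above exists
# as .cache/processed_dataset_None.pkl (out_name="processed_dataset", limit=None).
#
#     batch_iter = SmilesTask.iter_batches(
#         split="train",
#         batch_size=32,
#         device="cpu",
#         context_keys=["logp", "sascore", "mol_weight"],
#         dataset="processed_dataset_None.pkl",
#     )
#     batch = next(batch_iter)
#     # batch["src"], batch["tgt"]: (batch_size, seq_len) token id tensors
#     # batch["fragment"]: (batch_size, frag_len) conditioning tokens
#     # batch["context"]: dict of (batch_size,) float tensors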