LlaMol / fragment_creator.py
doammii's picture
Add LlaMol codes
55d9b0c verified
raw
history blame contribute delete
No virus
3.61 kB
from abc import ABC
from dataclasses import dataclass
from typing import List, Union
import numpy as np
from rdkit import Chem
from rdkit.Chem.BRICS import BRICSDecompose
from rdkit.Chem.Recap import RecapDecompose
import random
@dataclass
class Fragment:
    """
    A molecule or molecular fragment, described by its SMILES string and/or
    its token ids. Either view may be ``None`` when it is not available.
    """

    # SMILES text of the (fragment) molecule, or None when only tokens exist.
    smiles: Union[None, str]
    # Token ids of the (fragment) molecule, or None when only SMILES exists.
    tokens: Union[None, List[int]]
class BaseFragmentCreator(ABC):
    """
    Root of the fragment-creator hierarchy.

    The base implementation performs no fragmentation: ``create_fragment``
    ignores its argument and returns an empty string. Concrete creators
    override ``create_fragment``.
    """

    def __init__(self) -> None:
        # Stateless; exists so subclasses can chain ``super().__init__()``.
        return None

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Placeholder fragmentation that ignores ``frag``.

        NOTE(review): returns ``""`` although the annotation says ``Fragment``.
        ``""`` is also the value the SMILES-based subclasses return for
        unparsable input. Confirm this is the intended "no fragment" value
        before changing it.
        """
        no_fragment = ""
        return no_fragment
# This is the method used in the paper
class RandomSubsliceFragmentCreator(BaseFragmentCreator):
    """
    Fragments a molecule by taking a random contiguous slice of its token ids.
    """

    def __init__(self, max_fragment_size=50) -> None:
        """
        :param max_fragment_size: exclusive upper bound on the slice length.
            Returned slices hold at most ``max_fragment_size - 1`` tokens, and
            the value must be at least 2 for inputs of two or more tokens.
        """
        super().__init__()
        self.max_fragment_size = max_fragment_size

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Creates the random sub slice fragments from the tokens.

        Returns a ``Fragment`` with ``smiles=None`` whose ``tokens`` are a
        non-empty contiguous slice ``frag.tokens[start:end]``. Token sequences
        with fewer than two tokens cannot be sub-sliced by the sampling below.
        They are returned whole, as ``frag.tokens[:]``.
        """
        tokens = frag.tokens
        if len(tokens) < 2:
            # The start draw below needs low < high, i.e. len(tokens) >= 2.
            # Without this guard, 0- and 1-token inputs fail inside numpy with
            # an opaque "low >= high" ValueError.
            return Fragment(smiles=None, tokens=tokens[:])
        # np.random.randint's upper bound is exclusive, so start is drawn from
        # [0, len - 2] and end is at most len - 1. The final token is therefore
        # never part of a fragment.
        # NOTE(review): kept unchanged because this is the sampling used in
        # the paper. Confirm that excluding the last token (perhaps an
        # end-of-sequence marker) is intended.
        startIdx = np.random.randint(0, len(tokens) - 1)
        endIdx = np.random.randint(
            startIdx + 1, min(len(tokens), startIdx + self.max_fragment_size)
        )
        return Fragment(smiles=None, tokens=tokens[startIdx:endIdx])
class BricksFragmentCreator(BaseFragmentCreator):
    """
    Fragments a molecule with RDKit's BRICS decomposition and picks one
    fragment at random.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> str:
        """
        Creates the BRICS fragments of ``frag.smiles`` and takes one randomly.

        Returns the SMILES string of the chosen fragment (a ``str``, not a
        ``Fragment``). Returns ``""`` when ``frag.smiles`` is ``None``, cannot
        be parsed, or yields no fragments.
        """
        smiles = frag.smiles
        if smiles is None:
            # Treat a missing SMILES like an unparsable one instead of letting
            # RDKit raise on a None argument.
            return ""
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""
        # BRICSDecompose yields a set of SMILES. Iteration order of a set of
        # strings varies between processes (hash randomisation), so sort it so
        # the pick depends only on the `random` module state.
        res = sorted(BRICSDecompose(m, minFragmentSize=3))
        if not res:
            return ""
        return random.choice(res)
class RecapFragmentCreator(BaseFragmentCreator):
    """
    Fragments a molecule with RDKit's RECAP decomposition and picks one
    fragment at random.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> str:
        """
        Creates the Recap fragments of ``frag.smiles`` and takes one randomly.

        Returns the SMILES string of one randomly chosen node of the RECAP
        hierarchy below the root (a ``str``, not a ``Fragment``). When no
        RECAP bond is cut, returns the whole molecule's SMILES. Returns ``""``
        when ``frag.smiles`` is ``None`` or cannot be parsed.
        """
        smiles = frag.smiles
        if smiles is None:
            return ""
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""
        tree = RecapDecompose(m, minFragmentSize=3)
        # GetAllChildren() returns a dict keyed by fragment SMILES.
        # random.choice indexes its argument by integer position, which fails
        # on such a dict, so choose from the keys. Sorting makes the pick
        # reproducible for a given `random` state.
        children = sorted(tree.GetAllChildren())
        if not children:
            # Nothing was cut; fall back to the intact molecule.
            return Chem.MolToSmiles(m)
        return random.choice(children)
class MolFragsFragmentCreator(BaseFragmentCreator):
    """
    Splits a molecule into its disconnected components (as found by
    ``Chem.rdmolops.GetMolFrags``) and picks one at random.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> str:
        """
        Creates the disconnected-component fragments of ``frag.smiles`` and
        takes one randomly.

        Returns the SMILES string of the chosen component (a ``str``, not a
        ``Fragment``). Returns ``""`` when ``frag.smiles`` is ``None``, cannot
        be parsed, or has no components.
        """
        smiles = frag.smiles
        if smiles is None:
            return ""
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""
        pieces = Chem.rdmolops.GetMolFrags(m, asMols=True)
        res = [Chem.MolToSmiles(piece) for piece in pieces]
        if not res:
            # e.g. a molecule with no atoms: random.choice([]) would raise.
            return ""
        return random.choice(res)
def fragment_creator_factory(key: Union[str, None]):
    """
    Map a configuration key to a freshly constructed fragment creator.

    ``None`` yields ``None``. The recognised keys are ``"mol_frags"``,
    ``"recap"``, ``"bricks"`` and ``"rss"``. Any other key raises
    ``ValueError``.
    """
    if key is None:
        return None
    # Ordered (key, builder) pairs. The lambdas defer the class lookups until
    # a key actually matches.
    builders = (
        ("mol_frags", lambda: MolFragsFragmentCreator()),
        ("recap", lambda: RecapFragmentCreator()),
        ("bricks", lambda: BricksFragmentCreator()),
        ("rss", lambda: RandomSubsliceFragmentCreator()),
    )
    for name, build in builders:
        if key == name:
            return build()
    raise ValueError(f"Do not have factory for the given key: {key}")
if __name__ == "__main__":
    # Demo: fragment one molecule and round-trip the result through the
    # SMILES tokenizer.
    from tokenizer import SmilesTokenizer

    tokenizer = SmilesTokenizer()
    creator = BricksFragmentCreator()
    # creator = MolFragsFragmentCreator()
    # creator = RecapFragmentCreator()
    # create_fragment reads `frag.smiles`, so the input must be wrapped in a
    # Fragment. A bare SMILES string raises AttributeError.
    frag = creator.create_fragment(
        Fragment(smiles="CC(=O)NC1=CC=C(C=C1)O", tokens=None)
    )
    print(frag)
    tokens = tokenizer.encode(frag)
    print(tokens)
    print([tokenizer._convert_id_to_token(t) for t in tokens])