# app.py: last commit "Update app.py" (fea1ca5) by Noelia Ferruz; file size 18.7 kB.
import gradio as gr
import numpy as np
import os
import ray
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
#from transformers import pipeline as pl
from transformers import GPT2LMHeadModel , GPT2Tokenizer
from GPUtil import showUtilization as gpu_usage
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import plotly.graph_objects as go
import torch
import gc
import jax
from numba import cuda
import math
# Startup diagnostics: report whether PyTorch can see a CUDA device and print
# the current working directory (useful when debugging relative file paths).
print('GPU available',torch.cuda.is_available())
#print('__CUDA Device Name:',torch.cuda.get_device_name(0))
print(os.getcwd())
# Put the AlphaFold checkout on the import path so the `alphafold.*` imports
# that follow can be resolved (presumably the checkout contains the
# `alphafold` package; confirm against the deployed layout). The membership
# check keeps the path from being appended twice.
if "/home/user/app/alphafold" not in sys.path:
    sys.path.append("/home/user/app/alphafold")
from alphafold.common import protein
from alphafold.data import pipeline
from alphafold.data import templates
from alphafold.model import data
from alphafold.model import config
from alphafold.model import model
def mk_mock_template(query_sequence):
    """Build a single blank (all-gap) template feature set for AlphaFold.

    Only the length of ``query_sequence`` is used. Atom positions and masks
    are all zeros, and the template sequence is all gap characters.

    Args:
        query_sequence: The query amino-acid sequence.

    Returns:
        dict with the keys ``template_all_atom_positions``
        (shape ``(1, L, atom_type_num, 3)``), ``template_all_atom_masks``
        (shape ``(1, L, atom_type_num)``), ``template_aatype`` (one-hot of the
        gap sequence with a leading template axis) and
        ``template_domain_names`` (``[b"none"]``).
    """
    constants = templates.residue_constants
    num_res = len(query_sequence)
    num_atoms = constants.atom_type_num
    gap_sequence = "-" * num_res
    onehot = constants.sequence_to_onehot(gap_sequence, constants.HHBLITS_AA_TO_ID)
    return {
        # A leading axis of size 1 means "exactly one template".
        "template_all_atom_positions": np.zeros((1, num_res, num_atoms, 3)),
        "template_all_atom_masks": np.zeros((1, num_res, num_atoms)),
        "template_aatype": np.array(onehot)[None],
        "template_domain_names": [b"none"],
    }
def predict_structure(prefix, feature_dict, model_runners, random_seed=0,
                      output_dir="/home/user/app"):
    """Run every AlphaFold model runner on ``feature_dict`` and save PDB files.

    For each entry of ``model_runners`` the features are processed, the model
    is run, and the unrelaxed structure is written to
    ``<output_dir>/<prefix>_unrelaxed_<model_name>.pdb`` with the per-residue
    pLDDT passed as B-factors.

    Args:
        prefix: File-name prefix for the written PDB files.
        feature_dict: Raw AlphaFold input feature dict.
        model_runners: Mapping of model name -> ``RunModel`` instance. Every
            entry is run; the caller decides which models are included.
        random_seed: Seed passed to ``process_features``.
        output_dir: Directory for the PDB files. Defaults to the previously
            hard-coded ``/home/user/app``.

    Returns:
        dict mapping model name -> per-residue pLDDT array.
    """
    plddts = {}
    for model_name, model_runner in model_runners.items():
        processed_feature_dict = model_runner.process_features(
            feature_dict, random_seed=random_seed
        )
        prediction_result = model_runner.predict(processed_feature_dict)
        plddt = prediction_result["plddt"]
        # Broadcast each residue's pLDDT over that residue's atom mask so every
        # present atom carries its residue's confidence as a B-factor.
        b_factors = (
            plddt[:, None] * prediction_result["structure_module"]["final_atom_mask"]
        )
        unrelaxed_protein = protein.from_prediction(
            processed_feature_dict, prediction_result, b_factors
        )
        plddts[model_name] = plddt
        print(f"{model_name} {plddt.mean()}")
        unrelaxed_pdb_path = os.path.join(
            output_dir, f"{prefix}_unrelaxed_{model_name}.pdb"
        )
        with open(unrelaxed_pdb_path, "w", encoding="utf-8") as f:
            f.write(protein.to_pdb(unrelaxed_protein))
    return plddts
def compute_perplexity(model, tokenizer, sequence):
    """Return the perplexity of ``sequence`` under a causal language model.

    The tokenised sequence is passed as both ``input_ids`` and ``labels``.
    The perplexity is ``exp(loss)``. For Hugging Face causal LMs that loss is
    the mean token cross-entropy.

    Args:
        model: Callable causal LM (e.g. ``GPT2LMHeadModel``) that accepts
            ``labels=`` and returns the loss as its first output.
        tokenizer: Object whose ``encode`` returns a list of token ids.
        sequence: Text to score.

    Returns:
        Perplexity as a float. Returns ``math.inf`` when ``sequence`` encodes
        to no tokens, because no loss can be computed for an empty input.
    """
    token_ids = tokenizer.encode(sequence)
    if len(token_ids) == 0:
        return math.inf
    input_ids = torch.tensor(token_ids).unsqueeze(0)  # add a batch axis of 1
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss = outputs[0]
    return math.exp(float(loss))
# ProtGPT2's end-of-text token (token id 0, used as eos_token_id below).
_PROTGPT2_EOS = "<|endoftext|>"


def _clean_generated_sequence(decoded_seq):
    """Return the first complete sequence of a decoded ProtGPT2 sample, or None.

    A leading end-of-text token (present when generation was seeded with it)
    and leading newlines are removed first. The sample is rejected when its
    first 60 characters contain a newline, or when fewer than two end-of-text
    tokens remain (presumably meaning generation was cut off by max_length).
    """
    if decoded_seq.startswith(_PROTGPT2_EOS):
        decoded_seq = decoded_seq[len(_PROTGPT2_EOS):]
    decoded_seq = decoded_seq.lstrip("\n")
    if "\n" in decoded_seq[:60] or decoded_seq.count(_PROTGPT2_EOS) < 2:
        return None
    clean_seq = decoded_seq.split(_PROTGPT2_EOS)[0]
    return clean_seq or None


@ray.remote(num_gpus=1, max_calls=1)
def run_protgpt2(startsequence, length, repetitionPenalty, top_k_poolsize, max_seqs):
    """Sample protein sequences with ProtGPT2 and return the best-scoring ones.

    The function samples ``max_seqs * 10`` candidates and drops those rejected
    by ``_clean_generated_sequence``. The rest are scored by perplexity divided
    by sequence length (lower is better), and the ``max_seqs`` best are kept.

    Args:
        startsequence: Prompt text. When empty, generation is seeded with the
            end-of-text token so the model writes a complete sequence.
        length: Maximum generation length in tokens (cast to int, since the
            UI presumably passes a float).
        repetitionPenalty: ``repetition_penalty`` for sampling.
        top_k_poolsize: ``top_k`` for sampling (cast to int).
        max_seqs: Number of sequences to return (cast to int).

    Returns:
        List of ``(sequence, score)`` tuples in ascending score order. It may
        hold fewer than ``max_seqs`` items if too few candidates pass the
        filters.
    """
    print("running protgpt2")
    gpu_usage()  # prints the utilisation table itself
    max_seqs = int(max_seqs)
    seqs_to_sample = max_seqs * 10  # oversample, then keep the best 10%
    model = GPT2LMHeadModel.from_pretrained("nferruz/ProtGPT2")
    tokenizer = GPT2Tokenizer.from_pretrained("nferruz/ProtGPT2")
    # An empty prompt would give generate() zero input tokens. The ProtGPT2
    # model card seeds de novo generation with the end-of-text token instead.
    prompt = (startsequence or "").strip() or _PROTGPT2_EOS
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    sequences = model.generate(
        input_ids,
        max_length=int(length),
        do_sample=True,
        top_k=int(top_k_poolsize),
        repetition_penalty=repetitionPenalty,
        num_return_sequences=seqs_to_sample,
        eos_token_id=0,
    )
    scored = []
    for token_ids in sequences:
        clean_seq = _clean_generated_sequence(tokenizer.decode(token_ids))
        if clean_seq is None:
            continue
        ppl = compute_perplexity(model, tokenizer, clean_seq)
        scored.append((clean_seq, ppl / len(clean_seq)))
    scored.sort(key=lambda item: item[1])
    selected_sequences = scored[:max_seqs]
    if len(selected_sequences) < max_seqs:
        print(
            f"Warning: only {len(selected_sequences)} of {max_seqs} requested "
            "sequences passed the filters"
        )
    print("Cleaning up after protGPT2")
    return selected_sequences
def _parse_query_sequence(text):
    """Return the upper-cased residues of the first FASTA-style record in ``text``.

    Lines starting with ``>`` are treated as headers. Residue lines have all
    whitespace removed and are concatenated. Reading stops at a second header,
    so pasting several generated sequences folds only the first one.
    """
    residues = []
    for line in text.splitlines():
        stripped = line.strip()
        if stripped.startswith(">"):
            if residues:
                break
            continue
        chunk = "".join(stripped.split())
        if chunk:
            residues.append(chunk)
    return "".join(residues).upper()


@ray.remote(num_gpus=1, max_calls=1)
def run_alphafold(startsequence):
    """Predict the structure of ``startsequence`` with AlphaFold model_1.

    This runs in single-sequence mode: the MSA contains only the query, and
    the template is a blank mock. The PDB file is written by
    ``predict_structure`` with the prefix ``"test"``.

    Args:
        startsequence: Sequence text as typed or pasted by the user. FASTA
            headers and whitespace are tolerated (see ``_parse_query_sequence``).

    Returns:
        Per-residue pLDDT array for model_1.

    Raises:
        ValueError: If no residues remain after cleaning the input.
    """
    gpu_usage()  # prints the utilisation table itself
    models = ["model_1"]  # "model_2" ... "model_5" are not run
    model_runners = {}
    for model_name in models:
        model_config = config.model_config(model_name)
        model_config.data.eval.num_ensemble = 1
        model_params = data.get_model_haiku_params(
            model_name=model_name, data_dir="/home/user/app/"
        )
        model_runners[model_name] = model.RunModel(model_config, model_params)

    query_sequence = _parse_query_sequence(startsequence)
    if not query_sequence:
        raise ValueError("No amino-acid sequence provided")
    num_res = len(query_sequence)
    feature_dict = {
        **pipeline.make_sequence_features(
            sequence=query_sequence, description="none", num_res=num_res
        ),
        **pipeline.make_msa_features(
            msas=[[query_sequence]], deletion_matrices=[[[0] * num_res]]
        ),
        **mk_mock_template(query_sequence),
    }
    plddts = predict_structure("test", feature_dict, model_runners)
    print("AF2 done")
    return plddts[models[0]]
def update_protGPT2(inp, length, repetitionPenalty, top_k_poolsize, max_seqs):
    """Gradio callback: generate sequences with ProtGPT2 and format them as FASTA.

    Args:
        inp: Start sequence (may be empty).
        length: Maximum sequence length in tokens.
        repetitionPenalty: Repetition penalty used for sampling.
        top_k_poolsize: Top-k sampling pool size.
        max_seqs: Number of sequences requested.

    Returns:
        FASTA-style text with 70 residues per line for the results textbox.
        A note is appended when fewer sequences than requested came back.
    """
    generated = ray.get(
        run_protgpt2.remote(inp, length, repetitionPenalty, top_k_poolsize, max_seqs)
    )
    sel_seqs = []
    for item in generated:
        # Accept (sequence, score) tuples, transformers text-generation
        # pipeline dicts ({"generated_text": ...}) and bare strings.
        if isinstance(item, dict):
            sel_seqs.append(item["generated_text"])
        elif isinstance(item, (tuple, list)):
            sel_seqs.append(item[0])
        else:
            sel_seqs.append(item)
    print(sel_seqs)
    records = []
    for idx, seq in enumerate(sel_seqs):
        s = seq.replace("\n", "")
        wrapped = "\n".join(s[pos:pos + 70] for pos in range(0, len(s), 70))
        records.append(f">seq{idx}, {len(s)} residues \n{wrapped}\n\n")
    sequencestxt = "".join(records)
    requested = int(max_seqs)
    if len(sel_seqs) < requested:
        sequencestxt += (
            f"Only {len(sel_seqs)} of {requested} requested sequences passed "
            "the quality filters; try a larger max sequence length or "
            "generate again.\n"
        )
    return sequencestxt
def update(inp):
    """Gradio callback: fold ``inp`` with AlphaFold and build the result views.

    Args:
        inp: Sequence text from the "Chosen sequence" box.

    Returns:
        Tuple of (3Dmol.js viewer HTML for ``test_unrelaxed_model_1.pdb``,
        Plotly figure of pLDDT per residue, "mean ± std" pLDDT text).
    """
    print("Running AF on", inp)
    # AlphaFold runs in a separate Ray worker that reserves the GPU.
    plddts = ray.get(run_alphafold.remote(inp))
    print(plddts)
    fig = go.Figure(
        data=go.Scatter(
            x=np.arange(len(plddts)),
            y=plddts,
            hovertemplate="<i>pLDDT</i>: %{y:.2f} <br><i>Residue index:</i> %{x}",
        )
    )
    fig.update_layout(
        title="pLDDT",
        xaxis_title="Residue index",
        yaxis_title="pLDDT",
        height=500,
        template="simple_white",
    )
    # Relative path: assumes the working directory is where run_alphafold
    # wrote its "test" prefixed PDB. TODO confirm for non-Spaces deployments.
    return (
        molecule("test_unrelaxed_model_1.pdb"),
        fig,
        # Plus-minus sign; it was previously mis-encoded as "Β±".
        f"{np.mean(plddts):.1f} ± {np.std(plddts):.1f}",
    )
def read_mol(molpath):
    """Return the entire text content of the file at ``molpath``.

    Used to load a PDB file so it can be embedded in the viewer page.
    A single ``read()`` replaces the original line-by-line ``+=`` loop,
    which was quadratic in the number of lines.
    """
    with open(molpath, "r", encoding="utf-8") as fp:
        return fp.read()
def molecule(pdb):
    """Build an ``<iframe>`` that renders the PDB file ``pdb`` with 3Dmol.js.

    The structure is drawn as a cartoon coloured by its B-factor column
    (interpreted as AlphaFold pLDDT). The page also has a side-chain toggle,
    a download button and a pLDDT colour legend.

    Args:
        pdb: Path of the PDB file to embed.

    Returns:
        HTML string (an ``<iframe>`` with a ``srcdoc`` document) for a Gradio
        ``HTML`` component.
    """
    import html  # local import: only needed to escape the srcdoc attribute

    # The PDB text is spliced into a JavaScript template literal. Escape the
    # characters that would end it (`) or start an interpolation (${).
    mol = (
        read_mol(pdb)
        .replace("\\", "\\\\")
        .replace("`", "\\`")
        .replace("${", "\\${")
    )
    page = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<link rel="stylesheet" href="https://unpkg.com/flowbite@1.4.5/dist/flowbite.min.css" />
<style>
body{
font-family:sans-serif
}
.mol-container {
width: 100%;
height: 800px;
position: relative;
}
.space-x-2 > * + *{
margin-left: 0.5rem;
}
.p-1{
padding:0.5rem;
}
.flex{
display:flex;
align-items: center;
}
.w-4{
width:1rem;
}
.h-4{
height:1rem;
}
.mt-4{
margin-top:1rem;
}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<div id="container" class="mol-container"></div>
<div class="flex">
<div class="px-4">
<label for="sidechain" class="relative inline-flex items-center mb-4 cursor-pointer ">
<input id="sidechain" type="checkbox" class="sr-only peer">
<div class="w-11 h-6 bg-gray-200 rounded-full peer peer-focus:ring-4 peer-focus:ring-blue-300 dark:peer-focus:ring-blue-800 dark:bg-gray-700 peer-checked:after:translate-x-full peer-checked:after:border-white after:absolute after:top-0.5 after:left-[2px] after:bg-white after:border-gray-300 after:border after:rounded-full after:h-5 after:w-5 after:transition-all dark:border-gray-600 peer-checked:bg-blue-600"></div>
<span class="ml-3 text-sm font-medium text-gray-900 dark:text-gray-300">Show side chains</span>
</label>
</div>
<button type="button" class="text-gray-900 bg-white hover:bg-gray-100 border border-gray-200 focus:ring-4 focus:outline-none focus:ring-gray-100 font-medium rounded-lg text-sm px-5 py-2.5 text-center inline-flex items-center dark:focus:ring-gray-600 dark:bg-gray-800 dark:border-gray-700 dark:text-white dark:hover:bg-gray-700 mr-2 mb-2" id="download">
<svg class="w-6 h-6 mr-2 -ml-1" fill="none" stroke="currentColor" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"></path></svg>
Download predicted structure
</button>
</div>
<div class="text-sm">
<div class="font-medium mt-4"><b>AlphaFold model confidence:</b></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(0, 83, 214);">&nbsp;</span><span class="legendlabel">Very high
(pLDDT &gt; 90)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(101, 203, 243);">&nbsp;</span><span class="legendlabel">Confident
(90 &gt; pLDDT &gt; 70)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(255, 219, 19);">&nbsp;</span><span class="legendlabel">Low (70 &gt;
pLDDT &gt; 50)</span></div>
<div class="flex space-x-2 py-1"><span class="w-4 h-4"
style="background-color: rgb(255, 125, 69);">&nbsp;</span><span class="legendlabel">Very low
(pLDDT &lt; 50)</span></div>
<div class="row column legendDesc"> AlphaFold produces a per-residue confidence
score (pLDDT) between 0 and 100. Some regions below 50 pLDDT may be unstructured in isolation.
</div>
</div>
<script>
let viewer = null;
let voldata = null;
// Plain DOM APIs: current 3Dmol.js builds no longer bundle jQuery ($).
document.addEventListener("DOMContentLoaded", function () {
let element = document.getElementById("container");
let config = { backgroundColor: "white" };
viewer = $3Dmol.createViewer( element, config );
if (viewer.ui) {
viewer.ui.initiateUI();
}
let data = `"""
        + mol
        + """`
viewer.addModel( data, "pdb" );
//AlphaFold code from https://gist.github.com/piroyon/30d1c1099ad488a7952c3b21a5bebc96
let colorAlpha = function (atom) {
if (atom.b < 50) {
return "OrangeRed";
} else if (atom.b < 70) {
return "Gold";
} else if (atom.b < 90) {
return "MediumTurquoise";
} else {
return "Blue";
}
};
viewer.setStyle({}, { cartoon: { colorfunc: colorAlpha } });
viewer.zoomTo();
viewer.render();
viewer.zoom(0.8, 2000);
viewer.getModel(0).setHoverable({}, true,
function (atom, viewer, event, container) {
console.log(atom)
if (!atom.label) {
atom.label = viewer.addLabel(atom.resn+atom.resi+" pLDDT=" + atom.b, { position: atom, backgroundColor: "mintcream", fontColor: "black" });
}
},
function (atom, viewer) {
if (atom.label) {
viewer.removeLabel(atom.label);
delete atom.label;
}
}
);
document.getElementById("sidechain").addEventListener("change", function () {
if (this.checked) {
const BB = ["C", "O", "N"];
viewer.setStyle( {"and": [{resn: ["GLY", "PRO"], invert: true},{atom: BB, invert: true},]},{stick: {colorscheme: "WhiteCarbon", radius: 0.3}, cartoon: { colorfunc: colorAlpha }});
viewer.render()
} else {
viewer.setStyle({cartoon: { colorfunc: colorAlpha }});
viewer.render()
}
});
document.getElementById("download").addEventListener("click", function () {
download("gradioFold_model1.pdb", data);
});
});
function download(filename, text) {
var element = document.createElement("a");
element.setAttribute("href", "data:text/plain;charset=utf-8," + encodeURIComponent(text));
element.setAttribute("download", filename);
element.style.display = "none";
document.body.appendChild(element);
element.click();
document.body.removeChild(element);
}
</script>
</body></html>"""
    )
    # srcdoc is a single-quoted HTML attribute. Escaping the whole document
    # stops any quote or "&" in it from ending or corrupting the attribute.
    # The browser un-escapes the value, so the page itself is unchanged.
    srcdoc = html.escape(page, quote=True)
    return f"""<iframe style="width: 800px; height: 1200px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def change_sequence(chosenSeq):
    """Identity callback: hand the chosen sequence back unchanged."""
    selected = chosenSeq
    return selected
# Gradio UI: a protGPT2 section that generates candidate sequences and an
# AlphaFold section that folds a chosen sequence. The layout is declared
# inside the Blocks context, then Ray is started and the app is launched.
proteindream = gr.Blocks()
with proteindream:
    gr.Markdown("# GradioFold")
    gr.Markdown(
        """GradioFold is a web-based tool that combines a large language model trained on natural protein sequences (protGPT2) with structure prediction using AlphaFold.
Type a start sequence that protGPT2 can complete or let protGPT2 generate a complete sequence without a start token."""
    )
    gr.Markdown("## protGPT2")
    gr.Markdown(
        """
Enter a start sequence and have the language model complete it OR leave empty.
"""
    )
    with gr.Box():
        with gr.Row():
            inp = gr.Textbox(placeholder="MTYKLILNGKTLKGETTT", label="Start sequence")
            length = gr.Number(value=100, label="Max sequence length")
        with gr.Row():
            repetitionPenalty = gr.Slider(minimum=1, maximum=5, value=1.2, label="Repetition penalty")
            top_k_poolsize = gr.Slider(minimum=700, maximum=52056, value=950, label="Top-K sampling pool size")
            max_seqs = gr.Slider(minimum=2, maximum=20, value=5, step=1, label="Number of sequences to generate")
        btn = gr.Button("Predict sequences using protGPT2")
    results = gr.Textbox(label="Results", lines=15)
    btn.click(
        fn=update_protGPT2,
        inputs=[inp, length, repetitionPenalty, top_k_poolsize, max_seqs],
        outputs=results,
    )
    gr.Markdown("## AlphaFold")
    gr.Markdown(
        "Select a generated sequence above and copy it into the field below for structure prediction using AlphaFold2. You can also edit the sequence. Predictions will take around 2-5 minutes to be processed. Proteins larger than about 1000 residues will not fit into memory."
    )
    with gr.Group():
        chosenSeq = gr.Textbox(label="Chosen sequence")
        btn2 = gr.Button("Predict structure")
    with gr.Group():
        meanpLDDT = gr.Textbox(label="Mean pLDDT of chosen sequence")
        with gr.Row():
            mol = gr.HTML()
            plot = gr.Plot(label="pLDDT")
    # Emoji below were previously mis-encoded (shown as "πŸ“„" / "πŸ’»").
    gr.Markdown(
        """## Acknowledgements
More information about the algorithms used can be found below.
All code is available on [Hugging Face](https://huggingface.co/spaces/simonduerr/protGPT2_gradioFold/blob/main) and licensed under MIT license.

- ProtGPT2: Ferruz et al. 📄[BioRxiv](https://doi.org/10.1101/2022.03.09.483666) 💻[Code](https://huggingface.co/nferruz/ProtGPT2)
- AlphaFold2: Jumper et al. 📄[Paper](https://doi.org/10.1038/s41586-021-03819-2) 💻[Code](https://github.com/deepmind/alphafold) Model parameters released under CC BY 4.0
- ColabFold: Mirdita et al. 📄[Paper](https://doi.org/10.1101/2021.08.15.456425) 💻[Code](https://github.com/sokrypton/ColabFold)
- 3Dmol.js: Rego & Koes 📄[Paper](https://academic.oup.com/bioinformatics/article/31/8/1322/213186) 💻[Code](https://github.com/3dmol/3Dmol.js)

Created by [@simonduerr](https://twitter.com/simonduerr)

Thanks to the Hugging Face team for sponsoring a free GPU for this demo.
"""
    )
    btn2.click(fn=update, inputs=chosenSeq, outputs=[mol, plot, meanpLDDT])

# Ray ships ./alphafold to its workers so the remote functions can import it.
ray.init(runtime_env={"working_dir": "./alphafold"})
proteindream.launch(share=False)