Christina Theodoris
commited on
Commit
•
5426788
1
Parent(s):
b73028f
Add Geneformer tokenizer and updated model card
Browse files- README.md +2 -5
- geneformer/__init__.py +0 -0
- geneformer/gene_median_dictionary.pkl +0 -0
- geneformer/token_dictionary.pkl +0 -0
- geneformer/tokenizer.py +204 -0
README.md
CHANGED
@@ -1,17 +1,14 @@
|
|
1 |
# Geneformer
|
2 |
Geneformer is a transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
|
3 |
|
4 |
-
<!---
|
5 |
See [our manuscript](manuscript_link) for details.
|
6 |
-
-->
|
7 |
|
8 |
# Model Description
|
9 |
Geneformer is transformer model pretrained on a [Genecorpus-30M](dataset_link), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
|
10 |
|
11 |
The rank value encoding of each single cell’s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
|
12 |
|
13 |
-
|
14 |
-
During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. Fine-tuning Geneformer towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modeling with limited patient data, Geneformer identified candidate therapeutic targets. Overall, Geneformer represents an invaluable pretrained model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.
|
15 |
|
16 |
# Application
|
17 |
-
The pretrained Geneformer model can be used directly, for example for in silico deletion analysis, but is best used by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
|
|
1 |
# Geneformer
|
2 |
Geneformer is a transformer model pretrained on a large-scale corpus of ~30 million single cell transcriptomes to enable context-aware predictions in settings with limited data in network biology.
|
3 |
|
|
|
4 |
See [our manuscript](manuscript_link) for details.
|
|
|
5 |
|
6 |
# Model Description
|
7 |
Geneformer is transformer model pretrained on a [Genecorpus-30M](dataset_link), a pretraining corpus comprised of ~30 million single cell transcriptomes from a broad range of human tissues. Each single cell’s transcriptome is presented to the model as a rank value encoding where genes are ranked by their expression in that cell normalized by their expression across the entire Genecorpus-30M. The rank value encoding provides a nonparametric representation of that cell’s transcriptome and takes advantage of the many observations of each gene’s expression across Genecorpus-30M to prioritize genes that distinguish cell state. Specifically, this method will deprioritize ubiquitously highly-expressed housekeeping genes by normalizing them to a lower rank. Conversely, genes such as transcription factors that may be lowly expressed when they are expressed but highly distinguish cell state will move to a higher rank within the encoding. Furthermore, this rank-based approach may be more robust against technical artifacts that may systematically bias the absolute transcript counts value while the overall relative ranking of genes within each cell remains more stable.
|
8 |
|
9 |
The rank value encoding of each single cell’s transcriptome then proceeds through six transformer encoder units. Pretraining was accomplished using a masked learning objective where 15% of the genes within each transcriptome were masked and the model was trained to predict which gene should be within each masked position in that specific cell state using the context of the remaining unmasked genes. A major strength of this approach is that it is entirely self-supervised and can be accomplished on completely unlabeled data, which allows the inclusion of large amounts of training data without being restricted to samples with accompanying labels.
|
10 |
|
11 |
+
We detail applications and results in [our manuscript](manuscript_link). During pretraining, Geneformer gained a fundamental understanding of network dynamics, encoding network hierarchy in the model’s attention weights in a completely self-supervised manner. Fine-tuning Geneformer towards a diverse panel of downstream tasks relevant to chromatin and network dynamics using limited task-specific data demonstrated that Geneformer consistently boosted predictive accuracy. Applied to disease modeling with limited patient data, Geneformer identified candidate therapeutic targets. Overall, Geneformer represents an invaluable pretrained model from which fine-tuning towards a broad range of downstream applications can be pursued to accelerate discovery of key network regulators and candidate therapeutic targets.
|
|
|
12 |
|
13 |
# Application
|
14 |
+
The pretrained Geneformer model can be used directly, for example for in silico deletion analysis, but is best used by fine-tuning towards the relevant downstream task, such as gene or cell state classification.
|
geneformer/__init__.py
ADDED
File without changes
|
geneformer/gene_median_dictionary.pkl
ADDED
Binary file (941 kB). View file
|
|
geneformer/token_dictionary.pkl
ADDED
Binary file (788 kB). View file
|
|
geneformer/tokenizer.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Geneformer tokenizer.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
from geneformer.tokenizer import Tokenizer
|
6 |
+
tk = Tokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
|
7 |
+
tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
|
8 |
+
"""
|
9 |
+
|
10 |
+
import pickle
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import loompy as lp
|
14 |
+
import numpy as np
|
15 |
+
from datasets import Dataset
|
16 |
+
|
17 |
+
GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
|
18 |
+
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"
|
19 |
+
|
20 |
+
|
21 |
+
def tokenize_cell(gene_vector, gene_tokens):
|
22 |
+
"""
|
23 |
+
Convert normalized gene expression vector to tokenized rank value encoding.
|
24 |
+
"""
|
25 |
+
# create array of gene vector with token indices
|
26 |
+
# mask undetected genes
|
27 |
+
nonzero_mask = np.nonzero(gene_vector)[0]
|
28 |
+
# sort by median-scaled gene values
|
29 |
+
sorted_indices = np.argsort(-gene_vector[nonzero_mask])
|
30 |
+
# tokenize
|
31 |
+
sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
|
32 |
+
return sentence_tokens
|
33 |
+
|
34 |
+
|
35 |
+
class Tokenizer:
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
custom_attr_name_dict,
|
39 |
+
nproc=1,
|
40 |
+
gene_median_file=GENE_MEDIAN_FILE,
|
41 |
+
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
42 |
+
):
|
43 |
+
"""
|
44 |
+
Initialize tokenizer.
|
45 |
+
|
46 |
+
Parameters
|
47 |
+
----------
|
48 |
+
custom_attr_name_dict : dict
|
49 |
+
Dictionary of custom attributes to be added to the dataset.
|
50 |
+
Keys are the names of the attributes in the loom file.
|
51 |
+
Values are the names of the attributes in the dataset.
|
52 |
+
nproc : int
|
53 |
+
Number of processes to use for dataset mapping.
|
54 |
+
gene_median_file : Path
|
55 |
+
Path to pickle file containing dictionary of non-zero median
|
56 |
+
gene expression values across Genecorpus-30M.
|
57 |
+
token_dictionary_file : Path
|
58 |
+
Path to pickle file containing token dictionary (Ensembl IDs:token).
|
59 |
+
"""
|
60 |
+
# dictionary of custom attributes {output dataset column name: input .loom column name}
|
61 |
+
self.custom_attr_name_dict = custom_attr_name_dict
|
62 |
+
|
63 |
+
# number of processes for dataset mapping
|
64 |
+
self.nproc = nproc
|
65 |
+
|
66 |
+
# load dictionary of gene normalization factors
|
67 |
+
# (non-zero median value of expression across Genecorpus-30M)
|
68 |
+
with open(gene_median_file, "rb") as f:
|
69 |
+
self.gene_median_dict = pickle.load(f)
|
70 |
+
|
71 |
+
# load token dictionary (Ensembl IDs:token)
|
72 |
+
with open(token_dictionary_file, "rb") as f:
|
73 |
+
self.gene_token_dict = pickle.load(f)
|
74 |
+
|
75 |
+
# gene keys for full vocabulary
|
76 |
+
self.gene_keys = list(self.gene_median_dict.keys())
|
77 |
+
|
78 |
+
# protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
|
79 |
+
self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
|
80 |
+
|
81 |
+
def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
|
82 |
+
"""
|
83 |
+
Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.
|
84 |
+
|
85 |
+
Parameters
|
86 |
+
----------
|
87 |
+
loom_data_directory : Path
|
88 |
+
Path to directory containing loom files
|
89 |
+
output_directory : Path
|
90 |
+
Path to directory where tokenized data will be saved as .dataset
|
91 |
+
output_prefix : str
|
92 |
+
Prefix for output .dataset
|
93 |
+
"""
|
94 |
+
tokenized_cells, cell_metadata = self.tokenize_files(loom_data_directory)
|
95 |
+
tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)
|
96 |
+
|
97 |
+
output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
|
98 |
+
tokenized_dataset.save_to_disk(output_path)
|
99 |
+
|
100 |
+
def tokenize_files(self, loom_data_directory):
|
101 |
+
tokenized_cells = []
|
102 |
+
cell_metadata = {attr_key: [] for attr_key in self.custom_attr_name_dict.keys()}
|
103 |
+
|
104 |
+
# loops through directories to tokenize .loom files
|
105 |
+
for loom_file_path in loom_data_directory.glob("*.loom"):
|
106 |
+
print(f"Tokenizing {loom_file_path}")
|
107 |
+
file_tokenized_cells, file_cell_metadata = self.tokenize_file(
|
108 |
+
loom_file_path
|
109 |
+
)
|
110 |
+
tokenized_cells += file_tokenized_cells
|
111 |
+
cell_metadata.update(file_cell_metadata)
|
112 |
+
|
113 |
+
return tokenized_cells, cell_metadata
|
114 |
+
|
115 |
+
def tokenize_file(self, loom_file_path):
|
116 |
+
file_cell_metadata = {
|
117 |
+
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
|
118 |
+
}
|
119 |
+
|
120 |
+
with lp.connect(str(loom_file_path)) as data:
|
121 |
+
# define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
|
122 |
+
coding_miRNA_loc = np.where(
|
123 |
+
[self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
|
124 |
+
)[0]
|
125 |
+
norm_factor_vector = np.array(
|
126 |
+
[
|
127 |
+
self.gene_median_dict[i]
|
128 |
+
for i in data.ra["ensembl_id"][coding_miRNA_loc]
|
129 |
+
]
|
130 |
+
)
|
131 |
+
coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
|
132 |
+
coding_miRNA_tokens = np.array(
|
133 |
+
[self.gene_token_dict[i] for i in coding_miRNA_ids]
|
134 |
+
)
|
135 |
+
|
136 |
+
# define coordinates of cells passing filters for inclusion (e.g. QC)
|
137 |
+
try:
|
138 |
+
data.ca["filter_pass"]
|
139 |
+
except NameError:
|
140 |
+
var_exists = False
|
141 |
+
else:
|
142 |
+
var_exists = True
|
143 |
+
|
144 |
+
if var_exists is True:
|
145 |
+
filter_pass_loc = np.where(
|
146 |
+
[True if i == 1 else False for i in data.ca["filter_pass"]]
|
147 |
+
)[0]
|
148 |
+
elif var_exists is False:
|
149 |
+
print(
|
150 |
+
f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
|
151 |
+
)
|
152 |
+
filter_pass_loc = np.array([i for i in range(data.shape[1])])
|
153 |
+
|
154 |
+
# scan through .loom files and tokenize cells
|
155 |
+
tokenized_cells = []
|
156 |
+
for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
|
157 |
+
# select subview with protein-coding and miRNA genes
|
158 |
+
subview = view.view[coding_miRNA_loc, :]
|
159 |
+
|
160 |
+
# normalize by total counts per cell and multiply by 10,000 to allocate bits to precision
|
161 |
+
# and normalize by gene normalization factors
|
162 |
+
subview_norm_array = (
|
163 |
+
subview[:, :]
|
164 |
+
/ subview.ca.n_counts
|
165 |
+
* 10_000
|
166 |
+
/ norm_factor_vector[:, None]
|
167 |
+
)
|
168 |
+
# tokenize subview gene vectors
|
169 |
+
tokenized_cells += [
|
170 |
+
tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
|
171 |
+
for i in range(subview_norm_array.shape[1])
|
172 |
+
]
|
173 |
+
|
174 |
+
# add custom attributes for subview to dict
|
175 |
+
for k in file_cell_metadata.keys():
|
176 |
+
file_cell_metadata[k] += subview.ca[k].tolist()
|
177 |
+
|
178 |
+
return tokenized_cells, file_cell_metadata
|
179 |
+
|
180 |
+
def create_dataset(self, tokenized_cells, cell_metadata):
|
181 |
+
# create dict for dataset creation
|
182 |
+
dataset_dict = {"input_ids": tokenized_cells}
|
183 |
+
dataset_dict.update(cell_metadata)
|
184 |
+
|
185 |
+
# create dataset
|
186 |
+
output_dataset = Dataset.from_dict(dataset_dict)
|
187 |
+
|
188 |
+
# truncate dataset
|
189 |
+
def truncate(example):
|
190 |
+
example["input_ids"] = example["input_ids"][0:2048]
|
191 |
+
return example
|
192 |
+
|
193 |
+
output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)
|
194 |
+
|
195 |
+
# measure lengths of dataset
|
196 |
+
def measure_length(example):
|
197 |
+
example["length"] = len(example["input_ids"])
|
198 |
+
return example
|
199 |
+
|
200 |
+
output_dataset_truncated_w_length = output_dataset_truncated.map(
|
201 |
+
measure_length, num_proc=self.nproc
|
202 |
+
)
|
203 |
+
|
204 |
+
return output_dataset_truncated_w_length
|