Update with gene classifier, custom token dict, and str validate options
#329
by
hchen725
- opened
- geneformer/classifier.py +72 -38
geneformer/classifier.py
CHANGED
@@ -53,7 +53,6 @@ from pathlib import Path
|
|
53 |
import numpy as np
|
54 |
import pandas as pd
|
55 |
import seaborn as sns
|
56 |
-
from sklearn.model_selection import StratifiedKFold
|
57 |
from tqdm.auto import tqdm, trange
|
58 |
from transformers import Trainer
|
59 |
from transformers.training_args import TrainingArguments
|
@@ -86,6 +85,7 @@ class Classifier:
|
|
86 |
"no_eval": {bool},
|
87 |
"stratify_splits_col": {None, str},
|
88 |
"forward_batch_size": {int},
|
|
|
89 |
"nproc": {int},
|
90 |
"ngpu": {int},
|
91 |
}
|
@@ -107,6 +107,7 @@ class Classifier:
|
|
107 |
stratify_splits_col=None,
|
108 |
no_eval=False,
|
109 |
forward_batch_size=100,
|
|
|
110 |
nproc=4,
|
111 |
ngpu=1,
|
112 |
):
|
@@ -175,6 +176,9 @@ class Classifier:
|
|
175 |
| Otherwise, will perform eval during training.
|
176 |
forward_batch_size : int
|
177 |
| Batch size for forward pass (for evaluation, not training).
|
|
|
|
|
|
|
178 |
nproc : int
|
179 |
| Number of CPU processes to use.
|
180 |
ngpu : int
|
@@ -183,6 +187,10 @@ class Classifier:
|
|
183 |
"""
|
184 |
|
185 |
self.classifier = classifier
|
|
|
|
|
|
|
|
|
186 |
self.cell_state_dict = cell_state_dict
|
187 |
self.gene_class_dict = gene_class_dict
|
188 |
self.filter_data = filter_data
|
@@ -201,6 +209,7 @@ class Classifier:
|
|
201 |
self.stratify_splits_col = stratify_splits_col
|
202 |
self.no_eval = no_eval
|
203 |
self.forward_batch_size = forward_batch_size
|
|
|
204 |
self.nproc = nproc
|
205 |
self.ngpu = ngpu
|
206 |
|
@@ -222,7 +231,9 @@ class Classifier:
|
|
222 |
] = self.cell_state_dict["states"]
|
223 |
|
224 |
# load token dictionary (Ensembl IDs:token)
|
225 |
-
|
|
|
|
|
226 |
self.gene_token_dict = pickle.load(f)
|
227 |
|
228 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
@@ -267,7 +278,7 @@ class Classifier:
|
|
267 |
continue
|
268 |
valid_type = False
|
269 |
for option in valid_options:
|
270 |
-
if (option in [int, float, list, dict, bool]) and isinstance(
|
271 |
attr_value, option
|
272 |
):
|
273 |
valid_type = True
|
@@ -630,7 +641,6 @@ class Classifier:
|
|
630 |
| Number of trials to run for hyperparameter optimization
|
631 |
| If 0, will not optimize hyperparameters
|
632 |
"""
|
633 |
-
|
634 |
if self.num_crossval_splits == 0:
|
635 |
logger.error("num_crossval_splits must be 1 or 5 to validate.")
|
636 |
raise
|
@@ -772,17 +782,20 @@ class Classifier:
|
|
772 |
]
|
773 |
)
|
774 |
assert len(targets) == len(labels)
|
775 |
-
n_splits = int(1 / self.
|
776 |
-
skf =
|
777 |
# (Cross-)validate
|
778 |
-
|
|
|
|
|
|
|
779 |
print(
|
780 |
f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
|
781 |
)
|
782 |
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
783 |
# filter data for examples containing classes for this split
|
784 |
# subsample to max_ncells and relabel data in column "labels"
|
785 |
-
train_data, eval_data = cu.
|
786 |
data,
|
787 |
targets,
|
788 |
labels,
|
@@ -793,6 +806,18 @@ class Classifier:
|
|
793 |
self.nproc,
|
794 |
)
|
795 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
796 |
if n_hyperopt_trials == 0:
|
797 |
trainer = self.train_classifier(
|
798 |
model_directory,
|
@@ -802,6 +827,15 @@ class Classifier:
|
|
802 |
ksplit_output_dir,
|
803 |
predict_trainer,
|
804 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
805 |
else:
|
806 |
trainer = self.hyperopt_classifier(
|
807 |
model_directory,
|
@@ -811,20 +845,27 @@ class Classifier:
|
|
811 |
ksplit_output_dir,
|
812 |
n_trials=n_hyperopt_trials,
|
813 |
)
|
814 |
-
|
815 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
816 |
else:
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
|
822 |
-
id_class_dict,
|
823 |
-
eval_data,
|
824 |
-
predict_eval,
|
825 |
-
ksplit_output_dir,
|
826 |
-
output_prefix,
|
827 |
-
)
|
828 |
results += [result]
|
829 |
all_conf_mat = all_conf_mat + result["conf_mat"]
|
830 |
# break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
|
@@ -925,12 +966,7 @@ class Classifier:
|
|
925 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
926 |
|
927 |
##### Load model and training args #####
|
928 |
-
|
929 |
-
model_type = "CellClassifier"
|
930 |
-
elif self.classifier == "gene":
|
931 |
-
model_type = "GeneClassifier"
|
932 |
-
|
933 |
-
model = pu.load_model(model_type, num_classes, model_directory, "train")
|
934 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
935 |
model, self.classifier, train_data, output_directory
|
936 |
)
|
@@ -946,6 +982,9 @@ class Classifier:
|
|
946 |
if eval_data is None:
|
947 |
def_training_args["evaluation_strategy"] = "no"
|
948 |
def_training_args["load_best_model_at_end"] = False
|
|
|
|
|
|
|
949 |
training_args_init = TrainingArguments(**def_training_args)
|
950 |
|
951 |
##### Fine-tune the model #####
|
@@ -957,7 +996,9 @@ class Classifier:
|
|
957 |
|
958 |
# define function to initiate model
|
959 |
def model_init():
|
960 |
-
model = pu.load_model(
|
|
|
|
|
961 |
|
962 |
if self.freeze_layers is not None:
|
963 |
def_freeze_layers = self.freeze_layers
|
@@ -1018,6 +1059,7 @@ class Classifier:
|
|
1018 |
metric="eval_macro_f1",
|
1019 |
metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
|
1020 |
),
|
|
|
1021 |
)
|
1022 |
|
1023 |
return trainer
|
@@ -1080,11 +1122,7 @@ class Classifier:
|
|
1080 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1081 |
|
1082 |
##### Load model and training args #####
|
1083 |
-
|
1084 |
-
model_type = "CellClassifier"
|
1085 |
-
elif self.classifier == "gene":
|
1086 |
-
model_type = "GeneClassifier"
|
1087 |
-
model = pu.load_model(model_type, num_classes, model_directory, "train")
|
1088 |
|
1089 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1090 |
model, self.classifier, train_data, output_directory
|
@@ -1238,11 +1276,7 @@ class Classifier:
|
|
1238 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
1239 |
|
1240 |
# load previously fine-tuned model
|
1241 |
-
|
1242 |
-
model_type = "CellClassifier"
|
1243 |
-
elif self.classifier == "gene":
|
1244 |
-
model_type = "GeneClassifier"
|
1245 |
-
model = pu.load_model(model_type, num_classes, model_directory, "eval")
|
1246 |
|
1247 |
# evaluate the model
|
1248 |
result = self.evaluate_model(
|
|
|
53 |
import numpy as np
|
54 |
import pandas as pd
|
55 |
import seaborn as sns
|
|
|
56 |
from tqdm.auto import tqdm, trange
|
57 |
from transformers import Trainer
|
58 |
from transformers.training_args import TrainingArguments
|
|
|
85 |
"no_eval": {bool},
|
86 |
"stratify_splits_col": {None, str},
|
87 |
"forward_batch_size": {int},
|
88 |
+
"token_dictionary_file": {None, str},
|
89 |
"nproc": {int},
|
90 |
"ngpu": {int},
|
91 |
}
|
|
|
107 |
stratify_splits_col=None,
|
108 |
no_eval=False,
|
109 |
forward_batch_size=100,
|
110 |
+
token_dictionary_file=None,
|
111 |
nproc=4,
|
112 |
ngpu=1,
|
113 |
):
|
|
|
176 |
| Otherwise, will perform eval during training.
|
177 |
forward_batch_size : int
|
178 |
| Batch size for forward pass (for evaluation, not training).
|
179 |
+
token_dictionary_file : None, str
|
180 |
+
| Default is to use token dictionary file from Geneformer
|
181 |
+
| Otherwise, will load custom gene token dictionary.
|
182 |
nproc : int
|
183 |
| Number of CPU processes to use.
|
184 |
ngpu : int
|
|
|
187 |
"""
|
188 |
|
189 |
self.classifier = classifier
|
190 |
+
if self.classifier == "cell":
|
191 |
+
self.model_type = "CellClassifier"
|
192 |
+
elif self.classifier == "gene":
|
193 |
+
self.model_type = "GeneClassifier"
|
194 |
self.cell_state_dict = cell_state_dict
|
195 |
self.gene_class_dict = gene_class_dict
|
196 |
self.filter_data = filter_data
|
|
|
209 |
self.stratify_splits_col = stratify_splits_col
|
210 |
self.no_eval = no_eval
|
211 |
self.forward_batch_size = forward_batch_size
|
212 |
+
self.token_dictionary_file = token_dictionary_file
|
213 |
self.nproc = nproc
|
214 |
self.ngpu = ngpu
|
215 |
|
|
|
231 |
] = self.cell_state_dict["states"]
|
232 |
|
233 |
# load token dictionary (Ensembl IDs:token)
|
234 |
+
if self.token_dictionary_file is None:
|
235 |
+
self.token_dictionary_file = TOKEN_DICTIONARY_FILE
|
236 |
+
with open(token_dictionary_file, "rb") as f:
|
237 |
self.gene_token_dict = pickle.load(f)
|
238 |
|
239 |
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
|
|
278 |
continue
|
279 |
valid_type = False
|
280 |
for option in valid_options:
|
281 |
+
if (option in [int, float, list, dict, bool, str]) and isinstance(
|
282 |
attr_value, option
|
283 |
):
|
284 |
valid_type = True
|
|
|
641 |
| Number of trials to run for hyperparameter optimization
|
642 |
| If 0, will not optimize hyperparameters
|
643 |
"""
|
|
|
644 |
if self.num_crossval_splits == 0:
|
645 |
logger.error("num_crossval_splits must be 1 or 5 to validate.")
|
646 |
raise
|
|
|
782 |
]
|
783 |
)
|
784 |
assert len(targets) == len(labels)
|
785 |
+
n_splits = int(1 / (1 - self.train_size))
|
786 |
+
skf = cu.StratifiedKFold3(n_splits=n_splits, random_state=0, shuffle=True)
|
787 |
# (Cross-)validate
|
788 |
+
test_ratio = self.oos_test_size / (self.eval_size + self.oos_test_size)
|
789 |
+
for train_index, eval_index, test_index in tqdm(
|
790 |
+
skf.split(targets, labels, test_ratio)
|
791 |
+
):
|
792 |
print(
|
793 |
f"****** Validation split: {iteration_num}/{self.num_crossval_splits} ******\n"
|
794 |
)
|
795 |
ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
|
796 |
# filter data for examples containing classes for this split
|
797 |
# subsample to max_ncells and relabel data in column "labels"
|
798 |
+
train_data, eval_data = cu.prep_gene_classifier_train_eval_split(
|
799 |
data,
|
800 |
targets,
|
801 |
labels,
|
|
|
806 |
self.nproc,
|
807 |
)
|
808 |
|
809 |
+
if self.oos_test_size > 0:
|
810 |
+
test_data = cu.prep_gene_classifier_split(
|
811 |
+
data,
|
812 |
+
targets,
|
813 |
+
labels,
|
814 |
+
test_index,
|
815 |
+
"test",
|
816 |
+
self.max_ncells,
|
817 |
+
iteration_num,
|
818 |
+
self.nproc,
|
819 |
+
)
|
820 |
+
|
821 |
if n_hyperopt_trials == 0:
|
822 |
trainer = self.train_classifier(
|
823 |
model_directory,
|
|
|
827 |
ksplit_output_dir,
|
828 |
predict_trainer,
|
829 |
)
|
830 |
+
result = self.evaluate_model(
|
831 |
+
trainer.model,
|
832 |
+
num_classes,
|
833 |
+
id_class_dict,
|
834 |
+
eval_data,
|
835 |
+
predict_eval,
|
836 |
+
ksplit_output_dir,
|
837 |
+
output_prefix,
|
838 |
+
)
|
839 |
else:
|
840 |
trainer = self.hyperopt_classifier(
|
841 |
model_directory,
|
|
|
845 |
ksplit_output_dir,
|
846 |
n_trials=n_hyperopt_trials,
|
847 |
)
|
848 |
+
|
849 |
+
model = cu.load_best_model(
|
850 |
+
ksplit_output_dir, self.model_type, num_classes
|
851 |
+
)
|
852 |
+
|
853 |
+
if self.oos_test_size > 0:
|
854 |
+
result = self.evaluate_model(
|
855 |
+
model,
|
856 |
+
num_classes,
|
857 |
+
id_class_dict,
|
858 |
+
test_data,
|
859 |
+
predict_eval,
|
860 |
+
ksplit_output_dir,
|
861 |
+
output_prefix,
|
862 |
+
)
|
863 |
else:
|
864 |
+
if iteration_num == self.num_crossval_splits:
|
865 |
+
return
|
866 |
+
else:
|
867 |
+
iteration_num = iteration_num + 1
|
868 |
+
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
869 |
results += [result]
|
870 |
all_conf_mat = all_conf_mat + result["conf_mat"]
|
871 |
# break after 1 or 5 splits, each with train/eval proportions dictated by eval_size
|
|
|
966 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
967 |
|
968 |
##### Load model and training args #####
|
969 |
+
model = pu.load_model(self.model_type, num_classes, model_directory, "train")
|
|
|
|
|
|
|
|
|
|
|
970 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
971 |
model, self.classifier, train_data, output_directory
|
972 |
)
|
|
|
982 |
if eval_data is None:
|
983 |
def_training_args["evaluation_strategy"] = "no"
|
984 |
def_training_args["load_best_model_at_end"] = False
|
985 |
+
def_training_args.update(
|
986 |
+
{"save_strategy": "epoch", "save_total_limit": 1}
|
987 |
+
) # only save last model for each run
|
988 |
training_args_init = TrainingArguments(**def_training_args)
|
989 |
|
990 |
##### Fine-tune the model #####
|
|
|
996 |
|
997 |
# define function to initiate model
|
998 |
def model_init():
|
999 |
+
model = pu.load_model(
|
1000 |
+
self.model_type, num_classes, model_directory, "train"
|
1001 |
+
)
|
1002 |
|
1003 |
if self.freeze_layers is not None:
|
1004 |
def_freeze_layers = self.freeze_layers
|
|
|
1059 |
metric="eval_macro_f1",
|
1060 |
metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
|
1061 |
),
|
1062 |
+
local_dir=output_directory,
|
1063 |
)
|
1064 |
|
1065 |
return trainer
|
|
|
1122 |
subprocess.call(f"mkdir {output_directory}", shell=True)
|
1123 |
|
1124 |
##### Load model and training args #####
|
1125 |
+
model = pu.load_model(self.model_type, num_classes, model_directory, "train")
|
|
|
|
|
|
|
|
|
1126 |
|
1127 |
def_training_args, def_freeze_layers = cu.get_default_train_args(
|
1128 |
model, self.classifier, train_data, output_directory
|
|
|
1276 |
test_data = pu.load_and_filter(None, self.nproc, test_data_file)
|
1277 |
|
1278 |
# load previously fine-tuned model
|
1279 |
+
model = pu.load_model(self.model_type, num_classes, model_directory, "eval")
|
|
|
|
|
|
|
|
|
1280 |
|
1281 |
# evaluate the model
|
1282 |
result = self.evaluate_model(
|