diff --git a/.gitattributes b/.gitattributes index da5ad41327c1a4ebe37fcc3a6ec0b58d77173f77..7dc96f47694079f647b7f44d30bd40e102f37ac6 100644 --- a/.gitattributes +++ b/.gitattributes @@ -26,4 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zstandard filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text -model.safetensors filter=lfs diff=lfs merge=lfs -text \ No newline at end of file +model.safetensors filter=lfs diff=lfs merge=lfs -text diff --git a/MANIFEST.in b/MANIFEST.in index 7899a8fa49ff82e5a26f56212587d43308eddeb4..4f3b32b88e11d3674683a387329cf95239103f01 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ -include geneformer/gene_median_dictionary.pkl -include geneformer/token_dictionary.pkl -include geneformer/gene_name_id_dict.pkl +include geneformer/gene_median_dictionary_95m.pkl +include geneformer/token_dictionary_95m.pkl +include geneformer/gene_name_id_dict_95m.pkl diff --git a/config.json b/config.json index d131b7026d684013f988cc9e3dcae2e5a284bc0e..86e20c35e6f257f0daeb00ebb92a0751d12d8fff 100644 --- a/config.json +++ b/config.json @@ -3,21 +3,22 @@ "BertForMaskedLM" ], "attention_probs_dropout_prob": 0.02, - "gradient_checkpointing": false, + "classifier_dropout": null, "hidden_act": "relu", "hidden_dropout_prob": 0.02, - "hidden_size": 256, + "hidden_size": 512, "initializer_range": 0.02, - "intermediate_size": 512, + "intermediate_size": 1024, "layer_norm_eps": 1e-12, - "max_position_embeddings": 2048, + "max_position_embeddings": 4096, "model_type": "bert", - "num_attention_heads": 4, - "num_hidden_layers": 6, + "num_attention_heads": 8, + "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", - "transformers_version": "4.6.0", + "torch_dtype": "float32", + "transformers_version": "4.37.1", "type_vocab_size": 2, "use_cache": true, - "vocab_size": 25426 + "vocab_size": 20275 } diff --git a/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json new file mode 100755 index 0000000000000000000000000000000000000000..bc8099f84af0bd3e35d700a7135dd417e38f6bea --- /dev/null +++ b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.02, + "classifier_dropout": null, + "hidden_act": "relu", + "hidden_dropout_prob": 0.02, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 1024, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 4096, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "torch_dtype": "float32", + "transformers_version": "4.37.2", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 20275 +} diff --git a/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin new file mode 100755 index 0000000000000000000000000000000000000000..87625b1b8fe02c6aa0fc3ffd8c746275570e589d --- /dev/null +++ b/fine_tuned_models/gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:07b28d8c7bb789d59755c42d32f6182cc04d2cf34aafaa6397aa50e4fdf1a9b4 +size 152363342 diff --git 
a/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/config.json b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json similarity index 100% rename from fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/config.json rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/config.json diff --git a/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/optimizer.pt b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt similarity index 100% rename from fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/optimizer.pt rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/optimizer.pt diff --git a/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/pytorch_model.bin b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin similarity index 100% rename from fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/pytorch_model.bin rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/pytorch_model.bin diff --git a/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/rng_state.pth b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth similarity index 100% rename from fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/rng_state.pth rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/rng_state.pth diff --git a/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/scheduler.pt b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt similarity index 100% rename from fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/scheduler.pt rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/scheduler.pt diff --git a/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/trainer_state.json b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json similarity index 100% rename from fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/trainer_state.json rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/trainer_state.json diff --git a/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/training_args.bin b/fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin similarity index 100% rename from fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224/training_args.bin rename to fine_tuned_models/gf-6L-30M-i2048_CellClassifier_cardiomyopathies_220224/training_args.bin diff --git a/geneformer/__init__.py b/geneformer/__init__.py index 693072ff5f0b96aa77a84fab8b8e4ca8ffc166a8..4fd1ffac058f01fcf07ca1daf7ab7dd270f18873 100644 --- a/geneformer/__init__.py +++ b/geneformer/__init__.py @@ -1,10 +1,12 @@ # ruff: noqa: F401 from pathlib import Path +import warnings +warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*") # noqa # isort:skip -GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl" -TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl" -ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl" -ENSEMBL_MAPPING_FILE = Path(__file__).parent 
/ "ensembl_mapping_dict.pkl" +GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary_gc95M.pkl" +TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary_gc95M.pkl" +ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict_gc95M.pkl" +ENSEMBL_MAPPING_FILE = Path(__file__).parent / "ensembl_mapping_dict_gc95M.pkl" from . import ( collator_for_classification, @@ -25,4 +27,7 @@ from .pretrainer import GeneformerPretrainer from .tokenizer import TranscriptomeTokenizer from . import classifier # noqa # isort:skip -from .classifier import Classifier # noqa # isort:skip \ No newline at end of file +from .classifier import Classifier # noqa # isort:skip + +from . import mtl_classifier # noqa # isort:skip +from .mtl_classifier import MTLClassifier # noqa # isort:skip \ No newline at end of file diff --git a/geneformer/classifier.py b/geneformer/classifier.py index b586c91cc2cbdefdfc9d93d6cbbecbc102ad1099..76d2ff1faf2fb9728b7ac876876f5fbd9370b2c0 100644 --- a/geneformer/classifier.py +++ b/geneformer/classifier.py @@ -72,6 +72,7 @@ logger = logging.getLogger(__name__) class Classifier: valid_option_dict = { "classifier": {"cell", "gene"}, + "quantize": {bool, dict}, "cell_state_dict": {None, dict}, "gene_class_dict": {None, dict}, "filter_data": {None, dict}, @@ -93,6 +94,7 @@ class Classifier: def __init__( self, classifier=None, + quantize=False, cell_state_dict=None, gene_class_dict=None, filter_data=None, @@ -118,6 +120,13 @@ class Classifier: classifier : {"cell", "gene"} | Whether to fine-tune a cell state or gene classifier. + quantize : bool, dict + | Whether to fine-tune a quantized model. + | If True and no config provided, will use default. + | Will use custom config if provided. + | Configs should be provided as dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft). + | For example: {"bnb_config": BitsAndBytesConfig(...), + | "peft_config": LoraConfig(...)} cell_state_dict : None, dict | Cell states to fine-tune model to distinguish. | Two-item dictionary with keys: state_key and states @@ -191,6 +200,7 @@ class Classifier: self.model_type = "CellClassifier" elif self.classifier == "gene": self.model_type = "GeneClassifier" + self.quantize = quantize self.cell_state_dict = cell_state_dict self.gene_class_dict = gene_class_dict self.filter_data = filter_data @@ -256,7 +266,7 @@ class Classifier: f"Genes to classify {missing_genes} are not in token dictionary." ) self.gene_class_dict = { - k: set([self.gene_token_dict.get(gene) for gene in v]) + k: list(set([self.gene_token_dict.get(gene) for gene in v])) for k, v in self.gene_class_dict.items() } empty_classes = [] @@ -403,6 +413,15 @@ class Classifier: "Column name 'labels' must be reserved for class IDs. Please rename column." ) raise + + if (attr_to_split is not None) and (attr_to_balance is None): + logger.error( + "Splitting by attribute while balancing confounders requires both attr_to_split and attr_to_balance to be defined." + ) + raise + + if not isinstance(attr_to_balance, list): + attr_to_balance = [attr_to_balance] if self.classifier == "cell": # remove cell states representing < rare_threshold of cells @@ -505,6 +524,7 @@ class Classifier: output_directory, output_prefix, save_eval_output=True, + gene_balance=False, ): """ Train cell state or gene classifier using all data. 
@@ -525,13 +545,20 @@ class Classifier: save_eval_output : bool | Whether to save cross-fold eval output | Saves as pickle file of dictionary of eval metrics - + gene_balance : None, bool + | Whether to automatically balance genes in training set. + | Only available for binary gene classifications. + **Output** Returns trainer after fine-tuning with all data. """ + if (gene_balance is True) and (len(self.gene_class_dict.values())!=2): + logger.error("Automatically balancing gene sets for training is only available for binary gene classifications.") + raise + ##### Load data and prepare output directory ##### # load numerical id to class dictionary (id:class) with open(id_class_dict_file, "rb") as f: @@ -563,7 +590,7 @@ class Classifier: ) assert len(targets) == len(labels) data = cu.prep_gene_classifier_all_data( - data, targets, labels, self.max_ncells, self.nproc + data, targets, labels, self.max_ncells, self.nproc, gene_balance ) trainer = self.train_classifier( @@ -582,12 +609,15 @@ class Classifier: split_id_dict=None, attr_to_split=None, attr_to_balance=None, + gene_balance=False, max_trials=100, pval_threshold=0.1, save_eval_output=True, predict_eval=True, predict_trainer=False, n_hyperopt_trials=0, + save_gene_split_datasets=True, + debug_gene_split_datasets=False, ): """ (Cross-)validate cell state or gene classifier. @@ -622,6 +652,9 @@ class Classifier: attr_to_balance : None, list | List of attribute keys on which to balance data while splitting on attr_to_split | e.g. ["age", "sex"] for balancing these characteristics while splitting by patient + gene_balance : None, bool + | Whether to automatically balance genes in training set. + | Only available for binary gene classifications. max_trials : None, int | Maximum number of trials of random splitting to try to achieve balanced other attribute | If no split is found without significant (p < pval_threshold) differences in other attributes, will select best @@ -640,11 +673,17 @@ class Classifier: n_hyperopt_trials : int | Number of trials to run for hyperparameter optimization | If 0, will not optimize hyperparameters + save_gene_split_datasets : bool + | Whether or not to save train, valid, and test gene-labeled datasets """ if self.num_crossval_splits == 0: logger.error("num_crossval_splits must be 1 or 5 to validate.") raise - + + if (gene_balance is True) and (len(self.gene_class_dict.values())!=2): + logger.error("Automatically balancing gene sets for training is only available for binary gene classifications.") + raise + # ensure number of genes in each class is > 5 if validating model if self.classifier == "gene": insuff_classes = [k for k, v in self.gene_class_dict.items() if len(v) < 5] @@ -725,7 +764,7 @@ class Classifier: else: # 5-fold cross-validate num_cells = len(data) - fifth_cells = num_cells * 0.2 + fifth_cells = int(np.floor(num_cells * 0.2)) num_eval = min((self.eval_size * num_cells), fifth_cells) start = i * fifth_cells end = start + num_eval @@ -804,8 +843,19 @@ class Classifier: self.max_ncells, iteration_num, self.nproc, + gene_balance, ) - + + if save_gene_split_datasets is True: + for split_name in ["train", "valid"]: + labeled_dataset_output_path = ( + Path(output_dir) / f"{output_prefix}_{split_name}_gene_labeled_ksplit{iteration_num}" + ).with_suffix(".dataset") + if split_name == "train": + train_data.save_to_disk(str(labeled_dataset_output_path)) + elif split_name == "valid": + eval_data.save_to_disk(str(labeled_dataset_output_path)) + if self.oos_test_size > 0: test_data = 
cu.prep_gene_classifier_split( data, @@ -817,7 +867,14 @@ class Classifier: iteration_num, self.nproc, ) - + if save_gene_split_datasets is True: + test_labeled_dataset_output_path = ( + Path(output_dir) / f"{output_prefix}_test_gene_labeled_ksplit{iteration_num}" + ).with_suffix(".dataset") + test_data.save_to_disk(str(test_labeled_dataset_output_path)) + if debug_gene_split_datasets is True: + logger.error("Exiting after saving gene split datasets given debug_gene_split_datasets = True.") + raise if n_hyperopt_trials == 0: trainer = self.train_classifier( model_directory, @@ -966,7 +1023,7 @@ class Classifier: subprocess.call(f"mkdir {output_directory}", shell=True) ##### Load model and training args ##### - model = pu.load_model(self.model_type, num_classes, model_directory, "train") + model = pu.load_model(self.model_type, num_classes, model_directory, "train", quantize=self.quantize) def_training_args, def_freeze_layers = cu.get_default_train_args( model, self.classifier, train_data, output_directory ) @@ -990,14 +1047,14 @@ class Classifier: ##### Fine-tune the model ##### # define the data collator if self.classifier == "cell": - data_collator = DataCollatorForCellClassification() + data_collator = DataCollatorForCellClassification(token_dictionary=self.token_dictionary) elif self.classifier == "gene": - data_collator = DataCollatorForGeneClassification() + data_collator = DataCollatorForGeneClassification(token_dictionary=self.token_dictionary) # define function to initiate model def model_init(): model = pu.load_model( - self.model_type, num_classes, model_directory, "train" + self.model_type, num_classes, model_directory, "train", quantize=self.quantize ) if self.freeze_layers is not None: @@ -1009,7 +1066,8 @@ class Classifier: for param in module.parameters(): param.requires_grad = False - model = model.to("cuda:0") + if self.quantize is False: + model = model.to("cuda:0") return model # create the trainer @@ -1122,7 +1180,7 @@ class Classifier: subprocess.call(f"mkdir {output_directory}", shell=True) ##### Load model and training args ##### - model = pu.load_model(self.model_type, num_classes, model_directory, "train") + model = pu.load_model(self.model_type, num_classes, model_directory, "train", quantize=self.quantize) def_training_args, def_freeze_layers = cu.get_default_train_args( model, self.classifier, train_data, output_directory @@ -1152,9 +1210,9 @@ class Classifier: ##### Fine-tune the model ##### # define the data collator if self.classifier == "cell": - data_collator = DataCollatorForCellClassification() + data_collator = DataCollatorForCellClassification(token_dictionary=self.token_dictionary) elif self.classifier == "gene": - data_collator = DataCollatorForGeneClassification() + data_collator = DataCollatorForGeneClassification(token_dictionary=self.token_dictionary) # create the trainer trainer = Trainer( @@ -1276,7 +1334,7 @@ class Classifier: test_data = pu.load_and_filter(None, self.nproc, test_data_file) # load previously fine-tuned model - model = pu.load_model(self.model_type, num_classes, model_directory, "eval") + model = pu.load_model(self.model_type, num_classes, model_directory, "eval", quantize=self.quantize) # evaluate the model result = self.evaluate_model( diff --git a/geneformer/classifier_utils.py b/geneformer/classifier_utils.py index 92d131e94987694f502ba24fa1b30ba4a630a0f6..ba8818141da739898d435e23a380a780f1b55ad1 100644 --- a/geneformer/classifier_utils.py +++ b/geneformer/classifier_utils.py @@ -137,21 +137,22 @@ def 
label_gene_classes(example, class_id_dict, gene_class_dict): def prep_gene_classifier_train_eval_split( - data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc + data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc, balance=False ): # generate cross-validation splits train_data = prep_gene_classifier_split( - data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc + data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc, balance ) eval_data = prep_gene_classifier_split( - data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc + data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc, balance ) return train_data, eval_data def prep_gene_classifier_split( - data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc + data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc, balance=False ): + # generate cross-validation splits targets = np.array(targets) labels = np.array(labels) @@ -172,6 +173,10 @@ def prep_gene_classifier_split( f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n" ) + # balance gene subsets if train + if (subset_name == "train") and (balance is True): + subset_data, label_dict_subset = balance_gene_split(subset_data, label_dict_subset, num_proc) + # subsample to max_ncells subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None) @@ -187,7 +192,7 @@ def prep_gene_classifier_split( return subset_data -def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc): +def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc, balance=False): targets = np.array(targets) labels = np.array(labels) label_dict_train = dict(zip(targets, labels)) @@ -205,6 +210,9 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc): f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n" ) + if balance is True: + train_data, label_dict_train = balance_gene_split(train_data, label_dict_train, num_proc) + # subsample to max_ncells train_data = downsample_and_shuffle(train_data, max_ncells, None, None) @@ -220,6 +228,110 @@ def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc): return train_data +def balance_gene_split(subset_data, label_dict_subset, num_proc): + # count occurrence of genes in each label category + label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset, num_proc) + label_ratio_0to1 = label0_counts/label1_counts + + if 8/10 <= label_ratio_0to1 <= 10/8: + # gene sets already balanced + logger.info( + "Gene sets were already balanced within 0.8-1.25 fold and did not require balancing.\n" + ) + return subset_data, label_dict_subset + else: + label_ratio_0to1_orig = label_ratio_0to1+0 + label_dict_subset_orig = label_dict_subset.copy() + # balance gene sets + max_ntrials = 25 + boost = 1 + if label_ratio_0to1 > 10/8: + # downsample label 0 + for i in range(max_ntrials): + label0 = 0 + label0_genes = [k for k,v in label_dict_subset.items() if v == label0] + label0_ngenes = len(label0_genes) + label0_nremove = max(1,int(np.floor(label0_ngenes - label0_ngenes/(label_ratio_0to1*boost)))) + random.seed(i) + label0_remove_genes = random.sample(label0_genes, label0_nremove) + label_dict_subset_new = {k:v for k,v in label_dict_subset.items() if k not in label0_remove_genes} + label0_counts, label1_counts 
= count_genes_for_balancing(subset_data, label_dict_subset_new, num_proc) + label_ratio_0to1 = label0_counts/label1_counts + if 8/10 <= label_ratio_0to1 <= 10/8: + # if gene sets now balanced, return new filtered data and new label_dict_subset + return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc) + elif label_ratio_0to1 > 10/8: + boost = boost*1.1 + elif label_ratio_0to1 < 8/10: + boost = boost*0.9 + else: + # downsample label 1 + for i in range(max_ntrials): + label1 = 1 + label1_genes = [k for k,v in label_dict_subset.items() if v == label1] + label1_ngenes = len(label1_genes) + label1_nremove = max(1,int(np.floor(label1_ngenes - label1_ngenes/((1/label_ratio_0to1)*boost)))) + random.seed(i) + label1_remove_genes = random.sample(label1_genes, label1_nremove) + label_dict_subset_new = {k:v for k,v in label_dict_subset.items() if k not in label1_remove_genes} + label0_counts, label1_counts = count_genes_for_balancing(subset_data, label_dict_subset_new, num_proc) + label_ratio_0to1 = label0_counts/label1_counts + if 8/10 <= label_ratio_0to1 <= 10/8: + # if gene sets now balanced, return new filtered data and new label_dict_subset + return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc) + elif label_ratio_0to1 < 8/10: + boost = boost*1.1 + elif label_ratio_0to1 > 10/8: + boost = boost*0.9 + + assert i+1 == max_ntrials + if (label_ratio_0to1 <= label_ratio_0to1_orig < 8/10) or (10/8 > label_ratio_0to1_orig >= label_ratio_0to1): + label_ratio_0to1 = label_ratio_0to1_orig + label_dict_subset_new = label_dict_subset_orig + logger.warning( + f"Gene sets were not able to be balanced within 0.8-1.25 fold after {max_ntrials} trials. Imbalance level: {label_ratio_0to1}\n" + ) + return filter_data_balanced_genes(subset_data, label_dict_subset_new, num_proc) + + +def count_genes_for_balancing(subset_data, label_dict_subset, num_proc): + def count_targets(example): + labels = [ + label_dict_subset.get(token_id, -100) for token_id in example["input_ids"] + ] + counter_labels = Counter(labels) + # get count of labels 0 or 1, or if absent, return 0 + example["labels_counts"] = [counter_labels.get(0,0),counter_labels.get(1,0)] + return example + + subset_data = subset_data.map(count_targets, num_proc=num_proc) + + label0_counts = sum([counts[0] for counts in subset_data["labels_counts"]]) + label1_counts = sum([counts[1] for counts in subset_data["labels_counts"]]) + + subset_data = subset_data.remove_columns("labels_counts") + + return label0_counts, label1_counts + + +def filter_data_balanced_genes(subset_data, label_dict_subset, num_proc): + # function to filter by whether contains labels + def if_contains_subset_label(example): + a = list(label_dict_subset.keys()) + b = example["input_ids"] + return not set(a).isdisjoint(b) + + # filter dataset for examples containing classes for this split + logger.info("Filtering data for balanced genes") + subset_data_len_orig = len(subset_data) + subset_data = subset_data.filter(if_contains_subset_label, num_proc=num_proc) + logger.info( + f"Filtered {round((1-len(subset_data)/subset_data_len_orig)*100)}%; {len(subset_data)} remain\n" + ) + + return subset_data, label_dict_subset + + def balance_attr_splits( data, attr_to_split, diff --git a/geneformer/collator_for_classification.py b/geneformer/collator_for_classification.py index ba5c13d6dfad6274d0ab0cb670422a3cc7c15aee..3dfd934d6e162878ef36b6cf6f483a43b1054f9c 100644 --- a/geneformer/collator_for_classification.py +++ b/geneformer/collator_for_classification.py 
@@ -18,12 +18,6 @@ from transformers import ( from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj from transformers.utils.generic import _is_tensorflow, _is_torch -from . import TOKEN_DICTIONARY_FILE - -# load token dictionary (Ensembl IDs:token) -with open(TOKEN_DICTIONARY_FILE, "rb") as f: - token_dictionary = pickle.load(f) - EncodedInput = List[int] logger = logging.get_logger(__name__) VERY_LARGE_INTEGER = int( @@ -85,16 +79,18 @@ class TensorType(ExplicitEnum): class PrecollatorForGeneAndCellClassification(SpecialTokensMixin): - mask_token = "<mask>" - mask_token_id = token_dictionary.get("<mask>") - pad_token = "<pad>" - pad_token_id = token_dictionary.get("<pad>") - padding_side = "right" - all_special_ids = [ - token_dictionary.get("<mask>"), - token_dictionary.get("<pad>") - ] - model_input_names = ["input_ids"] + def __init__(self, *args, **kwargs) -> None: + super().__init__(mask_token="<mask>", pad_token="<pad>") + + self.token_dictionary = kwargs.get("token_dictionary") + self.padding_side = "right" + self.model_input_names = ["input_ids"] + self.mask_token_id = self.token_dictionary.get("<mask>") + self.pad_token_id = self.token_dictionary.get("<pad>") + self.all_special_ids = [ + self.token_dictionary.get("<mask>"), + self.token_dictionary.get("<pad>") + ] def _get_padding_truncation_strategies( self, padding=True, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs @@ -550,8 +546,7 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification): label_pad_token_id (:obj:`int`, `optional`, defaults to -100): The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions). """ - - tokenizer = PrecollatorForGeneAndCellClassification() + class_type = "gene" padding: Union[bool, str, PaddingStrategy] = True max_length: Optional[int] = None @@ -559,8 +554,9 @@ class DataCollatorForGeneClassification(DataCollatorForTokenClassification): label_pad_token_id: int = -100 def __init__(self, *args, **kwargs) -> None: + self.token_dictionary = kwargs.pop("token_dictionary") super().__init__( - tokenizer=self.tokenizer, + tokenizer=PrecollatorForGeneAndCellClassification(token_dictionary=self.token_dictionary), padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, diff --git a/geneformer/emb_extractor.py b/geneformer/emb_extractor.py index 241139d9cc4b8a772867fb2fb0d1cb84fd378b26..bd5451c8ceddac0e37867f0deb46db82dade1a2b 100644 --- a/geneformer/emb_extractor.py +++ b/geneformer/emb_extractor.py @@ -286,12 +286,20 @@ def plot_umap(embs_df, emb_dims, label, output_file, kwargs_dict, seed=0): sc.tl.umap(adata, random_state=seed) sns.set(rc={"figure.figsize": (10, 10)}, font_scale=2.3) sns.set_style("white") - default_kwargs_dict = {"palette": "Set2", "size": 200} + default_kwargs_dict = {"size": 200} if kwargs_dict is not None: default_kwargs_dict.update(kwargs_dict) - with plt.rc_context(): - sc.pl.umap(adata, color=label, **default_kwargs_dict) + cats = set(embs_df[label]) + + with plt.rc_context(): + ax = sc.pl.umap(adata, color=label, show=False, **default_kwargs_dict) + ax.legend(markerscale=2, + frameon=False, + loc="center left", + bbox_to_anchor=(1, 0.5), + ncol=(1 if len(cats) <= 14 else 2 if len(cats) <= 30 else 3)) + plt.show() plt.savefig(output_file, bbox_inches="tight") @@ -470,7 +478,6 @@ class EmbExtractor: ... emb_mode="cell", ... filter_data={"cell_type":["cardiomyocyte"]}, ... max_ncells=1000, - ... max_ncells_to_plot=1000, ... emb_layer=-1, ...
emb_label=["disease", "cell_type"], ... labels_to_plot=["disease", "cell_type"]) @@ -783,15 +790,15 @@ class EmbExtractor: logger.error("Plotting UMAP requires 'labels_to_plot'. ") raise - if max_ncells_to_plot > self.max_ncells: - max_ncells_to_plot = self.max_ncells - logger.warning( - "max_ncells_to_plot must be <= max_ncells. " - f"Changing max_ncells_to_plot to {self.max_ncells}." - ) - - if (max_ncells_to_plot is not None) and (max_ncells_to_plot < self.max_ncells): - embs = embs.sample(max_ncells_to_plot, axis=0) + if max_ncells_to_plot is not None: + if max_ncells_to_plot > self.max_ncells: + max_ncells_to_plot = self.max_ncells + logger.warning( + "max_ncells_to_plot must be <= max_ncells. " + f"Changing max_ncells_to_plot to {self.max_ncells}." + ) + elif max_ncells_to_plot < self.max_ncells: + embs = embs.sample(max_ncells_to_plot, axis=0) if self.emb_label is None: label_len = 0 diff --git a/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl b/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl new file mode 100644 index 0000000000000000000000000000000000000000..b2bda1a2d693fb4987842d068471d3cc3592686d --- /dev/null +++ b/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3b589bb5ec75040d05fc44dd6bf0184cf87f3c362cf158d196a6ed3b7fe5f39 +size 940965 diff --git a/geneformer/gene_name_id_dict.pkl b/geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl similarity index 100% rename from geneformer/gene_name_id_dict.pkl rename to geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl diff --git a/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl b/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9238d4f76c3546871229f31e0794273e7fa9d2c3 --- /dev/null +++ b/geneformer/gene_dictionaries_30m/token_dictionary_gc30M.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab9dc40973fa5224d77b793e2fd114cacf3d08423ed9c4c49caf0ba9c7f218f1 +size 788424 diff --git a/geneformer/gene_median_dictionary.pkl b/geneformer/gene_median_dictionary.pkl deleted file mode 100644 index a0b5a900cdca5fd50aa6970e4df4465986a06873..0000000000000000000000000000000000000000 Binary files a/geneformer/gene_median_dictionary.pkl and /dev/null differ diff --git a/geneformer/in_silico_perturber.py b/geneformer/in_silico_perturber.py index ad79d99e830500d48d22feb7f6e3fae5eaa6a44f..13b4cf2ca544633a317bf01c552799acd7393e87 100644 --- a/geneformer/in_silico_perturber.py +++ b/geneformer/in_silico_perturber.py @@ -63,7 +63,7 @@ class InSilicoPerturber: "anchor_gene": {None, str}, "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"}, "num_classes": {int}, - "emb_mode": {"cell", "cell_and_gene"}, + "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"}, "cell_emb_style": {"mean_pool"}, "filter_data": {None, dict}, "cell_states_to_model": {None, dict}, @@ -71,6 +71,7 @@ class InSilicoPerturber: "max_ncells": {None, int}, "cell_inds_to_perturb": {"all", dict}, "emb_layer": {-1, 0}, + "token_dictionary_file" : {None, str}, "forward_batch_size": {int}, "nproc": {int}, } @@ -94,7 +95,8 @@ class InSilicoPerturber: emb_layer=-1, forward_batch_size=100, nproc=4, - token_dictionary_file=TOKEN_DICTIONARY_FILE, + token_dictionary_file=None, + clear_mem_ncells=1000, ): """ Initialize in silico perturber. 
@@ -129,16 +131,16 @@ class InSilicoPerturber: | ENSEMBL ID of gene to use as anchor in combination perturbations. | For example, if combos=1 and anchor_gene="ENSG00000148400": | anchor gene will be perturbed in combination with each other gene. - model_type : {"Pretrained", "GeneClassifier", "CellClassifier"} - | Whether model is the pretrained Geneformer or a fine-tuned gene or cell classifier. + model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "MTLCellClassifier", "MTLCellClassifier-Quantized"} + | Whether model is the pretrained Geneformer or a fine-tuned gene, cell, or multitask cell classifier (+/- 8bit quantization). num_classes : int | If model is a gene or cell classifier, specify number of classes it was trained to classify. | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier. - emb_mode : {"cell", "cell_and_gene"} - | Whether to output impact of perturbation on cell and/or gene embeddings. + emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"} + | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings. | Gene embedding shifts only available as compared to original cell, not comparing to goal state. cell_emb_style : "mean_pool" - | Method for summarizing cell embeddings. + | Method for summarizing cell embeddings if not using CLS token. | Currently only option is mean pooling of gene embeddings for given cell. filter_data : None, dict | Default is to use all input data for in silico perturbation study. @@ -183,6 +185,8 @@ class InSilicoPerturber: | Number of CPU processes to use. token_dictionary_file : Path | Path to pickle file containing token dictionary (Ensembl ID:token). + clear_mem_ncells : int + | Clear memory every n cells. """ try: set_start_method("spawn") @@ -219,15 +223,31 @@ class InSilicoPerturber: self.emb_layer = emb_layer self.forward_batch_size = forward_batch_size self.nproc = nproc + self.token_dictionary_file = token_dictionary_file + self.clear_mem_ncells = clear_mem_ncells self.validate_options() # load token dictionary (Ensembl IDs:token) + if self.token_dictionary_file is None: + token_dictionary_file = TOKEN_DICTIONARY_FILE with open(token_dictionary_file, "rb") as f: self.gene_token_dict = pickle.load(f) self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()} self.pad_token_id = self.gene_token_dict.get("<pad>") + self.cls_token_id = self.gene_token_dict.get("<cls>") + self.eos_token_id = self.gene_token_dict.get("<eos>") + + + # Identify if special token is present in the token dictionary + if (self.cls_token_id is not None) and (self.eos_token_id is not None): + self.special_token = True + else: + if "cls" in self.emb_mode: + logger.error(f"emb_mode set to {self.emb_mode} but <cls> or <eos> token not in token dictionary.") + raise + self.special_token = False if self.anchor_gene is None: self.anchor_token = None @@ -285,7 +305,7 @@ class InSilicoPerturber: continue valid_type = False for option in valid_options: - if (option in [bool, int, list, dict]) and isinstance( + if (option in [bool, int, list, dict, str]) and isinstance( attr_value, option ): valid_type = True @@ -426,22 +446,46 @@ class InSilicoPerturber: self.max_len = pu.get_model_input_size(model) layer_to_quant = pu.quant_layers(model) + self.emb_layer - ### filter input data ### # general filtering of input data based on filter_data argument filtered_input_data = pu.load_and_filter( self.filter_data, self.nproc, input_data_file ) + + # Ensure emb_mode is cls if first token of the filtered input data is cls token +
if self.special_token: + if (filtered_input_data["input_ids"][0][0] == self.cls_token_id) and ("cls" not in self.emb_mode): + logger.error( + "Emb mode 'cls' or 'cls_and_gene' required when first token is ." + ) + raise + if ("cls" in self.emb_mode): + if (filtered_input_data["input_ids"][0][0] != self.cls_token_id) or (filtered_input_data["input_ids"][0][-1] != self.eos_token_id): + logger.error( + "Emb mode 'cls' and 'cls_and_gene' require that first token is and last token is ." + ) + raise + filtered_input_data = self.apply_additional_filters(filtered_input_data) if self.perturb_group is True: - self.isp_perturb_set( - model, filtered_input_data, layer_to_quant, output_path_prefix - ) + if (self.special_token) and ("cls" in self.emb_mode): + self.isp_perturb_set_special( + model, filtered_input_data, layer_to_quant, output_path_prefix + ) + else: + self.isp_perturb_set( + model, filtered_input_data, layer_to_quant, output_path_prefix + ) else: - self.isp_perturb_all( - model, filtered_input_data, layer_to_quant, output_path_prefix - ) + if (self.special_token) and ("cls" in self.emb_mode): + self.isp_perturb_all_special( + model, filtered_input_data, layer_to_quant, output_path_prefix + ) + else: + self.isp_perturb_all( + model, filtered_input_data, layer_to_quant, output_path_prefix + ) def apply_additional_filters(self, filtered_input_data): # additional filtering of input data dependent on isp mode @@ -486,6 +530,7 @@ class InSilicoPerturber: layer_to_quant: int, output_path_prefix: str, ): + def make_group_perturbation_batch(example): example_input_ids = example["input_ids"] example["tokens_to_perturb"] = self.tokens_to_perturb @@ -504,7 +549,7 @@ class InSilicoPerturber: if self.perturb_type == "delete": example = pu.delete_indices(example) elif self.perturb_type == "overexpress": - example = pu.overexpress_tokens(example, self.max_len) + example = pu.overexpress_tokens(example, self.max_len, self.special_token) example["n_overflow"] = pu.calc_n_overflow( self.max_len, example["length"], @@ -678,8 +723,6 @@ class InSilicoPerturber: cos_sims_dict = self.update_perturbation_dictionary( cos_sims_dict, cos_sims_data, - filtered_input_data, - indices_to_perturb, gene_list, ) else: @@ -688,8 +731,6 @@ class InSilicoPerturber: cos_sims_dict[state] = self.update_perturbation_dictionary( cos_sims_dict[state], cos_sims_data[state], - filtered_input_data, - indices_to_perturb, gene_list, ) del minibatch @@ -711,6 +752,256 @@ class InSilicoPerturber: f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}", ) + + def isp_perturb_set_special( + self, + model, + filtered_input_data: Dataset, + layer_to_quant: int, + output_path_prefix: str, + ): + + def make_group_perturbation_batch(example): + example_input_ids = example["input_ids"] + example["tokens_to_perturb"] = self.tokens_to_perturb + indices_to_perturb = [ + example_input_ids.index(token) if token in example_input_ids else None + for token in self.tokens_to_perturb + ] + indices_to_perturb = [ + item for item in indices_to_perturb if item is not None + ] + if len(indices_to_perturb) > 0: + example["perturb_index"] = indices_to_perturb + else: + # -100 indicates tokens to overexpress are not present in rank value encoding + example["perturb_index"] = [-100] + if self.perturb_type == "delete": + example = pu.delete_indices(example) + elif self.perturb_type == "overexpress": + example = pu.overexpress_tokens(example, self.max_len, self.special_token) + example["n_overflow"] = pu.calc_n_overflow( + self.max_len, + example["length"], 
+ self.tokens_to_perturb, + indices_to_perturb, + ) + return example + + total_batch_length = len(filtered_input_data) + if self.cell_states_to_model is None: + cos_sims_dict = defaultdict(list) + else: + cos_sims_dict = { + state: defaultdict(list) + for state in pu.get_possible_states(self.cell_states_to_model) + } + + perturbed_data = filtered_input_data.map( + make_group_perturbation_batch, num_proc=self.nproc + ) + + if self.perturb_type == "overexpress": + filtered_input_data = filtered_input_data.add_column( + "n_overflow", perturbed_data["n_overflow"] + ) + filtered_input_data = filtered_input_data.map( + pu.truncate_by_n_overflow_special, num_proc=self.nproc + ) + + if self.emb_mode == "cls_and_gene": + stored_gene_embs_dict = defaultdict(list) + + # iterate through batches + for i in trange(0, total_batch_length, self.forward_batch_size): + max_range = min(i + self.forward_batch_size, total_batch_length) + inds_select = [i for i in range(i, max_range)] + + minibatch = filtered_input_data.select(inds_select) + perturbation_batch = perturbed_data.select(inds_select) + + ##### CLS Embedding Mode ##### + if self.emb_mode == "cls": + indices_to_perturb = perturbation_batch["perturb_index"] + + original_cls_emb = get_embs( + model, + minibatch, + "cls", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + token_gene_dict=self.token_gene_dict, + summary_stat=None, + silent=True, + ) + + perturbation_cls_emb = get_embs( + model, + perturbation_batch, + "cls", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + token_gene_dict=self.token_gene_dict, + summary_stat=None, + silent=True, + ) + + # Calculate the cosine similarities + cls_cos_sims = pu.quant_cos_sims( + perturbation_cls_emb, + original_cls_emb, + self.cell_states_to_model, + self.state_embs_dict, + emb_mode="cell") + + # Update perturbation dictionary + if self.cell_states_to_model is None: + cos_sims_dict = self.update_perturbation_dictionary( + cos_sims_dict, + cls_cos_sims, + gene_list = None, + ) + else: + for state in cos_sims_dict.keys(): + cos_sims_dict[state] = self.update_perturbation_dictionary( + cos_sims_dict[state], + cls_cos_sims[state], + gene_list = None, + ) + + ##### CLS and Gene Embedding Mode ##### + elif self.emb_mode == "cls_and_gene": + full_original_emb = get_embs( + model, + minibatch, + "gene", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + self.token_gene_dict, + summary_stat=None, + silent=True, + ) + indices_to_perturb = perturbation_batch["perturb_index"] + # remove indices that were perturbed + original_emb = pu.remove_perturbed_indices_set( + full_original_emb, + self.perturb_type, + indices_to_perturb, + self.tokens_to_perturb, + minibatch["length"], + ) + full_perturbation_emb = get_embs( + model, + perturbation_batch, + "gene", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + self.token_gene_dict, + summary_stat=None, + silent=True, + ) + + # remove special tokens and padding + original_emb = original_emb[:, 1:-1, :] + if self.perturb_type == "overexpress": + perturbation_emb = full_perturbation_emb[:,1+len(self.tokens_to_perturb):-1,:] + elif self.perturb_type == "delete": + perturbation_emb = full_perturbation_emb[:,1:max(perturbation_batch["length"])-1,:] + + n_perturbation_genes = perturbation_emb.size()[1] + + gene_cos_sims = pu.quant_cos_sims( + perturbation_emb, + original_emb, + self.cell_states_to_model, + self.state_embs_dict, + emb_mode="gene", + ) + + # get cls emb + original_cls_emb = full_original_emb[:,0,:] + 
perturbation_cls_emb = full_perturbation_emb[:,0,:] + + cls_cos_sims = pu.quant_cos_sims( + perturbation_cls_emb, + original_cls_emb, + self.cell_states_to_model, + self.state_embs_dict, + emb_mode="cell", + ) + + # get cosine similarities in gene embeddings + # since getting gene embeddings, need gene names + + gene_list = minibatch["input_ids"] + # need to truncate gene_list + genes_to_exclude = self.tokens_to_perturb + [self.cls_token_id, self.eos_token_id] + gene_list = [ + [g for g in genes if g not in genes_to_exclude][ + :n_perturbation_genes + ] + for genes in gene_list + ] + + for cell_i, genes in enumerate(gene_list): + for gene_j, affected_gene in enumerate(genes): + if len(self.genes_to_perturb) > 1: + tokens_to_perturb = tuple(self.tokens_to_perturb) + else: + tokens_to_perturb = self.tokens_to_perturb[0] + + # fill in the gene cosine similarities + try: + stored_gene_embs_dict[ + (tokens_to_perturb, affected_gene) + ].append(gene_cos_sims[cell_i, gene_j].item()) + except KeyError: + stored_gene_embs_dict[ + (tokens_to_perturb, affected_gene) + ] = gene_cos_sims[cell_i, gene_j].item() + + if self.cell_states_to_model is None: + cos_sims_dict = self.update_perturbation_dictionary( + cos_sims_dict, + cls_cos_sims, + gene_list = None, + ) + else: + for state in cos_sims_dict.keys(): + cos_sims_dict[state] = self.update_perturbation_dictionary( + cos_sims_dict[state], + cls_cos_sims[state], + gene_list = None, + ) + del full_original_emb + del original_emb + del full_perturbation_emb + del perturbation_emb + del gene_cos_sims + + del original_cls_emb + del perturbation_cls_emb + del cls_cos_sims + del minibatch + del perturbation_batch + + torch.cuda.empty_cache() + + pu.write_perturbation_dictionary( + cos_sims_dict, + f"{output_path_prefix}_cell_embs_dict_{self.tokens_to_perturb}", + ) + + if self.emb_mode == "cls_and_gene": + pu.write_perturbation_dictionary( + stored_gene_embs_dict, + f"{output_path_prefix}_gene_embs_dict_{self.tokens_to_perturb}", + ) + def isp_perturb_all( self, model, @@ -729,8 +1020,10 @@ class InSilicoPerturber: if self.emb_mode == "cell_and_gene": stored_gene_embs_dict = defaultdict(list) - for i in trange(len(filtered_input_data)): - example_cell = filtered_input_data.select([i]) + + num_inds_perturbed = 1 + self.combos + for h in trange(len(filtered_input_data)): + example_cell = filtered_input_data.select([h]) full_original_emb = get_embs( model, example_cell, @@ -738,18 +1031,33 @@ class InSilicoPerturber: layer_to_quant, self.pad_token_id, self.forward_batch_size, - token_gene_dict=self.token_gene_dict, + self.token_gene_dict, summary_stat=None, silent=True, ) - + + if self.cell_states_to_model is not None: + original_cell_emb = pu.compute_nonpadded_cell_embedding( + full_original_emb, "mean_pool" + ) + # gene_list is used to assign cos sims back to genes - # need to remove the anchor gene gene_list = example_cell["input_ids"][0][:] + # need to remove the anchor gene if self.anchor_token is not None: for token in self.anchor_token: gene_list.remove(token) - + # index 0 is not overexpressed so remove + if self.perturb_type == "overexpress": + gene_list = gene_list[ + num_inds_perturbed: + ] + # remove perturbed index for gene list dict + perturbed_gene_dict = { + gene: gene_list[:i] + gene_list[i + 1 :] + for i, gene in enumerate(gene_list) + } + perturbation_batch, indices_to_perturb = pu.make_perturbation_batch( example_cell, self.perturb_type, @@ -759,148 +1067,430 @@ class InSilicoPerturber: self.nproc, ) - full_perturbation_emb = get_embs( - 
model, - perturbation_batch, - "gene", - layer_to_quant, - self.pad_token_id, - self.forward_batch_size, - token_gene_dict=self.token_gene_dict, - summary_stat=None, - silent=True, - ) - - num_inds_perturbed = 1 + self.combos - # need to remove overexpressed gene to quantify cosine shifts - if self.perturb_type == "overexpress": - perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :] - gene_list = gene_list[ - num_inds_perturbed: - ] # index 0 is not overexpressed - - elif self.perturb_type == "delete": - perturbation_emb = full_perturbation_emb + ispall_total_batch_length = len(perturbation_batch) + for i in trange(0, ispall_total_batch_length, self.forward_batch_size, leave=False): + ispall_max_range = min(i + self.forward_batch_size, ispall_total_batch_length) + perturbation_minibatch = perturbation_batch.select([i for i in range(i, ispall_max_range)]) + indices_to_perturb_mini = indices_to_perturb[i : ispall_max_range] + gene_list_mini = gene_list[i : ispall_max_range] # only perturbed genes from this minibatch - original_batch = pu.make_comparison_batch( - full_original_emb, indices_to_perturb, perturb_group=False - ) - - if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene": - gene_cos_sims = pu.quant_cos_sims( - perturbation_emb, - original_batch, - self.cell_states_to_model, - self.state_embs_dict, - emb_mode="gene", - ) - if self.cell_states_to_model is not None: - original_cell_emb = pu.compute_nonpadded_cell_embedding( - full_original_emb, "mean_pool" - ) - perturbation_cell_emb = pu.compute_nonpadded_cell_embedding( - full_perturbation_emb, "mean_pool" + full_perturbation_emb = get_embs( + model, + perturbation_minibatch, + "gene", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + self.token_gene_dict, + summary_stat=None, + silent=True, ) + + del perturbation_minibatch + + # need to remove overexpressed gene to quantify cosine shifts + if self.perturb_type == "overexpress": + perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :] + + elif self.perturb_type == "delete": + perturbation_emb = full_perturbation_emb + + + if self.cell_states_to_model is None or self.emb_mode == "cell_and_gene": + original_emb_minibatch = pu.make_comparison_batch( + full_original_emb, indices_to_perturb_mini, perturb_group=False + ) + gene_cos_sims = pu.quant_cos_sims( + perturbation_emb, + original_emb_minibatch, + self.cell_states_to_model, + self.state_embs_dict, + emb_mode="gene", + ) + del original_emb_minibatch + + if self.cell_states_to_model is not None: + perturbation_cell_emb = pu.compute_nonpadded_cell_embedding( + full_perturbation_emb, "mean_pool" + ) + + cell_cos_sims = pu.quant_cos_sims( + perturbation_cell_emb, + original_cell_emb, + self.cell_states_to_model, + self.state_embs_dict, + emb_mode="cell", + ) + del perturbation_cell_emb + + if self.emb_mode == "cell_and_gene": - cell_cos_sims = pu.quant_cos_sims( - perturbation_cell_emb, - original_cell_emb, - self.cell_states_to_model, - self.state_embs_dict, - emb_mode="cell", + for perturbation_i, perturbed_gene in enumerate(gene_list_mini): + for gene_j, affected_gene in enumerate( + perturbed_gene_dict[perturbed_gene] + ): + try: + stored_gene_embs_dict[ + (perturbed_gene, affected_gene) + ].append(gene_cos_sims[perturbation_i, gene_j].item()) + except KeyError: + stored_gene_embs_dict[ + (perturbed_gene, affected_gene) + ] = gene_cos_sims[perturbation_i, gene_j].item() + + del full_perturbation_emb + + if self.cell_states_to_model is None: + cos_sims_data = 
torch.mean(gene_cos_sims, dim=1) + cos_sims_dict = self.update_perturbation_dictionary( + cos_sims_dict, + cos_sims_data, + gene_list_mini, + ) + else: + cos_sims_data = cell_cos_sims + for state in cos_sims_dict.keys(): + cos_sims_dict[state] = self.update_perturbation_dictionary( + cos_sims_dict[state], + cos_sims_data[state], + gene_list_mini, + ) + + # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells + if i % self.clear_mem_ncells/10 == 0: + pu.write_perturbation_dictionary( + cos_sims_dict, + f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}", + ) + if self.emb_mode == "cell_and_gene": + pu.write_perturbation_dictionary( + stored_gene_embs_dict, + f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}", + ) + + # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell + if i % self.clear_mem_ncells == 0: + pickle_batch += 1 + if self.cell_states_to_model is None: + cos_sims_dict = defaultdict(list) + else: + cos_sims_dict = { + state: defaultdict(list) + for state in pu.get_possible_states(self.cell_states_to_model) + } + + if self.emb_mode == "cell_and_gene": + stored_gene_embs_dict = defaultdict(list) + + torch.cuda.empty_cache() + + pu.write_perturbation_dictionary( + cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}" + ) + + if self.emb_mode == "cell_and_gene": + pu.write_perturbation_dictionary( + stored_gene_embs_dict, + f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}", ) - if self.emb_mode == "cell_and_gene": - # remove perturbed index for gene list - perturbed_gene_dict = { - gene: gene_list[:i] + gene_list[i + 1 :] - for i, gene in enumerate(gene_list) + pickle_batch = -1 + if self.cell_states_to_model is None: + cos_sims_dict = defaultdict(list) + else: + cos_sims_dict = { + state: defaultdict(list) + for state in pu.get_possible_states(self.cell_states_to_model) } - for perturbation_i, perturbed_gene in enumerate(gene_list): - for gene_j, affected_gene in enumerate( - perturbed_gene_dict[perturbed_gene] - ): - try: - stored_gene_embs_dict[ - (perturbed_gene, affected_gene) - ].append(gene_cos_sims[perturbation_i, gene_j].item()) - except KeyError: - stored_gene_embs_dict[ - (perturbed_gene, affected_gene) - ] = gene_cos_sims[perturbation_i, gene_j].item() + if self.emb_mode == "cell_and_gene": + stored_gene_embs_dict = defaultdict(list) - if self.cell_states_to_model is None: - cos_sims_data = torch.mean(gene_cos_sims, dim=1) - cos_sims_dict = self.update_perturbation_dictionary( - cos_sims_dict, - cos_sims_data, - filtered_input_data, - indices_to_perturb, - gene_list, - ) - else: - cos_sims_data = cell_cos_sims - for state in cos_sims_dict.keys(): - cos_sims_dict[state] = self.update_perturbation_dictionary( - cos_sims_dict[state], - cos_sims_data[state], - filtered_input_data, - indices_to_perturb, - gene_list, - ) + # clear memory between cells + del perturbation_batch + del full_original_emb + if self.cell_states_to_model is not None: + del original_cell_emb + torch.cuda.empty_cache() + + def isp_perturb_all_special( + self, + model, + filtered_input_data: Dataset, + layer_to_quant: int, + output_path_prefix: str, + ): + pickle_batch = -1 + if self.cell_states_to_model is None: + cos_sims_dict = defaultdict(list) + else: + cos_sims_dict = { + state: defaultdict(list) + for state in pu.get_possible_states(self.cell_states_to_model) + } - # save dict to disk every 100 cells - if i % 100 == 0: - pu.write_perturbation_dictionary( 
- cos_sims_dict, - f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}", - ) - if self.emb_mode == "cell_and_gene": - pu.write_perturbation_dictionary( - stored_gene_embs_dict, - f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}", - ) + if self.emb_mode == "cls_and_gene": + stored_gene_embs_dict = defaultdict(list) - # reset and clear memory every 1000 cells - if i % 1000 == 0: - pickle_batch += 1 - if self.cell_states_to_model is None: - cos_sims_dict = defaultdict(list) - else: - cos_sims_dict = { - state: defaultdict(list) - for state in pu.get_possible_states(self.cell_states_to_model) - } + num_inds_perturbed = 1 + self.combos + for h in trange(len(filtered_input_data)): + example_cell = filtered_input_data.select([h]) - if self.emb_mode == "cell_and_gene": - stored_gene_embs_dict = defaultdict(list) + # get original example cell cls and/or gene embs for comparison + if self.emb_mode == "cls": + original_cls_emb = get_embs( + model, + example_cell, + "cls", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + self.token_gene_dict, + summary_stat=None, + silent=True, + ) + elif self.emb_mode == "cls_and_gene": + full_original_emb = get_embs( + model, + example_cell, + "gene", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + self.token_gene_dict, + summary_stat=None, + silent=True, + ) + original_cls_emb = full_original_emb[:,0,:].clone().detach() + + # gene_list is used to assign cos sims back to genes + gene_list = example_cell["input_ids"][0][:] - torch.cuda.empty_cache() + # need to remove special tokens + for token in [self.cls_token_id, self.eos_token_id]: + gene_list.remove(token) + # need to remove the anchor gene + if self.anchor_token is not None: + for token in self.anchor_token: + gene_list.remove(token) + # index 0 is not overexpressed so remove + if self.perturb_type == "overexpress": + gene_list = gene_list[ + num_inds_perturbed: + ] + # remove perturbed index for gene list dict + perturbed_gene_dict = { + gene: gene_list[:i] + gene_list[i + 1 :] + for i, gene in enumerate(gene_list) + } - pu.write_perturbation_dictionary( - cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_1Kbatch{pickle_batch}" - ) + perturbation_batch, indices_to_perturb = pu.make_perturbation_batch_special( + example_cell, + self.perturb_type, + self.tokens_to_perturb, + self.anchor_token, + self.combos, + self.nproc, + ) - if self.emb_mode == "cell_and_gene": + ispall_total_batch_length = len(perturbation_batch) + for i in trange(0, ispall_total_batch_length, self.forward_batch_size, leave=False): + ispall_max_range = min(i + self.forward_batch_size, ispall_total_batch_length) + perturbation_minibatch = perturbation_batch.select([i for i in range(i, ispall_max_range)]) + indices_to_perturb_mini = indices_to_perturb[i : ispall_max_range] + gene_list_mini = gene_list[i : ispall_max_range] # only perturbed genes from this minibatch + + ##### CLS Embedding Mode ##### + if self.emb_mode == "cls": + # Extract cls embeddings from perturbed cells + perturbation_cls_emb = get_embs( + model, + perturbation_minibatch, + "cls", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + self.token_gene_dict, + summary_stat=None, + silent=True, + ) + + # Calculate cosine similarities + cls_cos_sims = pu.quant_cos_sims( + perturbation_cls_emb, + original_cls_emb, + self.cell_states_to_model, + self.state_embs_dict, + emb_mode="cell", + ) + + if self.cell_states_to_model is None: + cos_sims_dict = self.update_perturbation_dictionary( + cos_sims_dict, + 
cls_cos_sims, + gene_list_mini, + ) + else: + + for state in cos_sims_dict.keys(): + cos_sims_dict[state] = self.update_perturbation_dictionary( + cos_sims_dict[state], + cls_cos_sims[state], + gene_list_mini, + ) + + del perturbation_minibatch + del perturbation_cls_emb + del cls_cos_sims + + ##### CLS and Gene Embedding Mode ##### + elif self.emb_mode == "cls_and_gene": + full_perturbation_emb = get_embs( + model, + perturbation_minibatch, + "gene", + layer_to_quant, + self.pad_token_id, + self.forward_batch_size, + self.token_gene_dict, + summary_stat=None, + silent=True, + ) + + # need to remove overexpressed gene and cls/eos to quantify cosine shifts + if self.perturb_type == "overexpress": + perturbation_emb = full_perturbation_emb[:, 1+num_inds_perturbed:-1, :].clone().detach() + elif self.perturb_type == "delete": + perturbation_emb = full_perturbation_emb[:, 1:-1, :].clone().detach() + + original_emb_minibatch = pu.make_comparison_batch( + full_original_emb, indices_to_perturb_mini, perturb_group=False + ) + + original_emb_minibatch = original_emb_minibatch[:, 1:-1, :].clone().detach() + gene_cos_sims = pu.quant_cos_sims( + perturbation_emb, + original_emb_minibatch, + self.cell_states_to_model, + self.state_embs_dict, + emb_mode="gene", + ) + + for perturbation_i, perturbed_gene in enumerate(gene_list_mini): + for gene_j, affected_gene in enumerate( + perturbed_gene_dict[perturbed_gene] + ): + try: + stored_gene_embs_dict[ + (perturbed_gene, affected_gene) + ].append(gene_cos_sims[perturbation_i, gene_j].item()) + except KeyError: + stored_gene_embs_dict[ + (perturbed_gene, affected_gene) + ] = gene_cos_sims[perturbation_i, gene_j].item() + + # get cls emb + perturbation_cls_emb = full_perturbation_emb[:,0,:].clone().detach() + + cls_cos_sims = pu.quant_cos_sims( + perturbation_cls_emb, + original_cls_emb, + self.cell_states_to_model, + self.state_embs_dict, + emb_mode="cell", + ) + + if self.cell_states_to_model is None: + cos_sims_dict = self.update_perturbation_dictionary( + cos_sims_dict, + cls_cos_sims, + gene_list_mini, + ) + else: + for state in cos_sims_dict.keys(): + cos_sims_dict[state] = self.update_perturbation_dictionary( + cos_sims_dict[state], + cls_cos_sims[state], + gene_list_mini, + ) + + del perturbation_minibatch + del original_emb_minibatch + del full_perturbation_emb + del perturbation_emb + del perturbation_cls_emb + del cls_cos_sims + del gene_cos_sims + + # save dict to disk every self.clear_mem_ncells/10 (default 100) simulated cells + if i % max(1,self.clear_mem_ncells/10) == 0: + pu.write_perturbation_dictionary( + cos_sims_dict, + f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}", + ) + if self.emb_mode == "cls_and_gene": + pu.write_perturbation_dictionary( + stored_gene_embs_dict, + f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}", + ) + + # reset and clear memory every self.clear_mem_ncells (default 1000) simulated cells or at the end of the example cell + if i % self.clear_mem_ncells == 0: + pickle_batch += 1 + if self.cell_states_to_model is None: + cos_sims_dict = defaultdict(list) + else: + cos_sims_dict = { + state: defaultdict(list) + for state in pu.get_possible_states(self.cell_states_to_model) + } + + if self.emb_mode == "cls_and_gene": + stored_gene_embs_dict = defaultdict(list) + + torch.cuda.empty_cache() + pu.write_perturbation_dictionary( - stored_gene_embs_dict, - f"{output_path_prefix}_dict_gene_embs_1Kbatch{pickle_batch}", + cos_sims_dict, f"{output_path_prefix}_dict_cell_embs_{h}batch{pickle_batch}" ) + + 
if self.emb_mode == "cls_and_gene": + pu.write_perturbation_dictionary( + stored_gene_embs_dict, + f"{output_path_prefix}_dict_gene_embs_{h}batch{pickle_batch}", + ) + + pickle_batch = -1 + if self.cell_states_to_model is None: + cos_sims_dict = defaultdict(list) + else: + cos_sims_dict = { + state: defaultdict(list) + for state in pu.get_possible_states(self.cell_states_to_model) + } + + if self.emb_mode == "cls_and_gene": + stored_gene_embs_dict = defaultdict(list) + + # clear memory between cells + del perturbation_batch + del original_cls_emb + if self.emb_mode == "cls_and_gene": + del full_original_emb + torch.cuda.empty_cache() + def update_perturbation_dictionary( self, cos_sims_dict: defaultdict, cos_sims_data: torch.Tensor, - filtered_input_data: Dataset, - indices_to_perturb: List[List[int]], gene_list=None, ): if gene_list is not None and cos_sims_data.shape[0] != len(gene_list): logger.error( f"len(cos_sims_data.shape[0]) != len(gene_list). \n \ - cos_sims_data.shape[0] = {cos_sims_data.shape[0]}.\n \ - len(gene_list) = {len(gene_list)}." + {cos_sims_data.shape[0]=}.\n \ + {len(gene_list)=}." ) raise @@ -924,4 +1514,4 @@ class InSilicoPerturber: for i, cos in enumerate(cos_sims_data.tolist()): cos_sims_dict[(gene_list[i], "cell_emb")].append(cos) - return cos_sims_dict + return cos_sims_dict \ No newline at end of file diff --git a/geneformer/in_silico_perturber_stats.py b/geneformer/in_silico_perturber_stats.py index 1e10d64f0706c009760bdbd44c2a57f6c7a82ed2..373697127a4585b433ff289b947772361643a18d 100644 --- a/geneformer/in_silico_perturber_stats.py +++ b/geneformer/in_silico_perturber_stats.py @@ -114,6 +114,7 @@ def read_dictionaries( state_dict[state_value][key] += new_dict[key] except KeyError: state_dict[state_value][key] = new_dict[key] + if not file_found: logger.error( "No raw data for processing found within provided directory. " @@ -237,13 +238,16 @@ def find(variable, x): def isp_aggregate_gene_shifts( - cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict + cos_sims_df, dict_list, gene_token_id_dict, gene_id_name_dict, token_dtype ): cos_shift_data = dict() for i in trange(cos_sims_df.shape[0]): token = cos_sims_df["Gene"][i] for dict_i in dict_list: - affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)] + if token_dtype == "nontuple": + affected_pairs = [k for k, v in dict_i.items() if k[0] == token] + else: + affected_pairs = [k for k, v in dict_i.items() if find(k[0], token)] for key in affected_pairs: if key in cos_shift_data.keys(): cos_shift_data[key] += dict_i.get(key, []) @@ -256,11 +260,11 @@ def isp_aggregate_gene_shifts( cos_sims_full_df = pd.DataFrame() cos_sims_full_df["Perturbed"] = [k[0] for k, v in cos_data_mean.items()] cos_sims_full_df["Gene_name"] = [ - cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"][0] + cos_sims_df[cos_sims_df["Gene"] == k[0]]["Gene_name"].item() for k, v in cos_data_mean.items() ] cos_sims_full_df["Ensembl_ID"] = [ - cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"][0] + cos_sims_df[cos_sims_df["Gene"] == k[0]]["Ensembl_ID"].item() for k, v in cos_data_mean.items() ] @@ -690,7 +694,7 @@ class InSilicoPerturberStats: | Default is assuming genes_to_perturb in isp experiment was "all" (each gene in each cell). | Otherwise, may provide a list of ENSEMBL IDs of genes perturbed as a group all together. combos : {0,1,2} - | Whether to perturb genes individually (0), in pairs (1), or in triplets (2). 
+ | Whether genes perturbed in the isp experiment were perturbed individually (0), in pairs (1), or in triplets (2). anchor_gene : None, str + | ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes. + | For example, if combos=1 and anchor_gene="ENSG00000136574": @@ -1014,7 +1018,7 @@ class InSilicoPerturberStats: }, index=[i for i in range(len(gene_list))], ) - + if self.mode == "goal_state_shift": cos_sims_df = isp_stats_to_goal_state( cos_sims_df_initial, @@ -1045,11 +1049,23 @@ class InSilicoPerturberStats: cos_sims_df = isp_aggregate_grouped_perturb(cos_sims_df_initial, dict_list, self.genes_perturbed) elif self.mode == "aggregate_gene_shifts": + if (self.genes_perturbed == "all") and (self.combos == 0): + tuple_types = [True if isinstance(genes, tuple) else False for genes in gene_list] + if all(tuple_types): + token_dtype = "tuple" + elif not any(tuple_types): + token_dtype = "nontuple" + else: + token_dtype = "mix" + else: + token_dtype = "mix" + cos_sims_df = isp_aggregate_gene_shifts( cos_sims_df_initial, dict_list, self.gene_token_id_dict, self.gene_id_name_dict, + token_dtype ) # save perturbation stats to output_path diff --git a/geneformer/mtl/__init__.py b/geneformer/mtl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/geneformer/mtl/collators.py b/geneformer/mtl/collators.py new file mode 100644 index 0000000000000000000000000000000000000000..ecc56b459a4d2b6ad5481393d2d8c7cfeba770d8 --- /dev/null +++ b/geneformer/mtl/collators.py @@ -0,0 +1,66 @@ +#imports +import torch + +from ..collator_for_classification import DataCollatorForGeneClassification + +""" +Geneformer collator for multi-task cell classification.
+""" + +class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification): + class_type = "cell" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def _prepare_batch(self, features): + # Process inputs as usual + batch = self.tokenizer.pad( + features, + class_type=self.class_type, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + + # Check if labels are present + if "label" in features[0]: + # Initialize labels dictionary for all tasks + labels = {task: [] for task in features[0]["label"].keys()} + + # Populate labels for each task + for feature in features: + for task, label in feature["label"].items(): + labels[task].append(label) + + # Convert label lists to tensors, handling dictionaries appropriately + for task in labels: + if isinstance(labels[task][0], (list, torch.Tensor)): + dtype = torch.long + labels[task] = torch.tensor(labels[task], dtype=dtype) + elif isinstance(labels[task][0], dict): + # Handle dict specifically if needed + pass # Resolve nested data structure + + # Update the batch to include task-specific labels + batch["labels"] = labels + else: + # If no labels are present, create empty labels for all tasks + batch["labels"] = {task: torch.tensor([], dtype=torch.long) for task in features[0]["input_ids"].keys()} + + return batch + + def __call__(self, features): + batch = self._prepare_batch(features) + + for k, v in batch.items(): + if torch.is_tensor(v): + batch[k] = v.clone().detach() + elif isinstance(v, dict): + # Assuming nested structure needs conversion + batch[k] = {task: torch.tensor(labels, dtype=torch.int64) for task, labels in v.items()} + else: + batch[k] = torch.tensor(v, dtype=torch.int64) + + return batch \ No newline at end of file diff --git a/geneformer/mtl/data.py b/geneformer/mtl/data.py new file mode 100644 index 0000000000000000000000000000000000000000..a8cb294ca97f73261c72b6dfbc022ebaf856e736 --- /dev/null +++ b/geneformer/mtl/data.py @@ -0,0 +1,116 @@ +from .imports import * +import os +from .collators import DataCollatorForMultitaskCellClassification + +def load_and_preprocess_data(dataset_path, config, is_test=False, dataset_type=""): + try: + dataset = load_from_disk(dataset_path) + + task_names = [f"task{i+1}" for i in range(len(config["task_columns"]))] + task_to_column = dict(zip(task_names, config["task_columns"])) + config["task_names"] = task_names + + if not is_test: + available_columns = set(dataset.column_names) + for column in task_to_column.values(): + if column not in available_columns: + raise KeyError(f"Column {column} not found in the dataset. 
Available columns: {list(available_columns)}") + + label_mappings = {} + task_label_mappings = {} + cell_id_mapping = {} + num_labels_list = [] + + # Load or create task label mappings + if not is_test: + for task, column in task_to_column.items(): + unique_values = sorted(set(dataset[column])) # Ensure consistency + label_mappings[column] = {label: idx for idx, label in enumerate(unique_values)} + task_label_mappings[task] = label_mappings[column] + num_labels_list.append(len(unique_values)) + + # Print the mappings for each task with dataset type prefix + for task, mapping in task_label_mappings.items(): + print(f"{dataset_type.capitalize()} mapping for {task}: {mapping}") # sanity check, for train/validation splits + + # Save the task label mappings as a pickle file + with open(f"{config['results_dir']}/task_label_mappings.pkl", "wb") as f: + pickle.dump(task_label_mappings, f) + else: + # Load task label mappings from pickle file for test data + with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f: + task_label_mappings = pickle.load(f) + + # Infer num_labels_list from task_label_mappings + for task, mapping in task_label_mappings.items(): + num_labels_list.append(len(mapping)) + + # Store unique cell IDs in a separate dictionary + for idx, record in enumerate(dataset): + cell_id = record.get('unique_cell_id', idx) + cell_id_mapping[idx] = cell_id + + # Transform records to the desired format + transformed_dataset = [] + for idx, record in enumerate(dataset): + transformed_record = {} + transformed_record['input_ids'] = torch.tensor(record['input_ids'], dtype=torch.long) + + # Use index-based cell ID for internal tracking + transformed_record['cell_id'] = idx + + if not is_test: + # Prepare labels + label_dict = {} + for task, column in task_to_column.items(): + label_value = record[column] + label_index = task_label_mappings[task][label_value] + label_dict[task] = label_index + transformed_record['label'] = label_dict + else: + # Create dummy labels for test data + label_dict = {task: -1 for task in config["task_names"]} + transformed_record['label'] = label_dict + + transformed_dataset.append(transformed_record) + + return transformed_dataset, cell_id_mapping, num_labels_list + except KeyError as e: + print(f"Missing configuration or dataset key: {e}") + except Exception as e: + print(f"An error occurred while loading or preprocessing data: {e}") + return None, None, None + +def preload_and_process_data(config): + # Load and preprocess data once + train_dataset, train_cell_id_mapping, num_labels_list = load_and_preprocess_data(config["train_path"], config, dataset_type="train") + val_dataset, val_cell_id_mapping, _ = load_and_preprocess_data(config["val_path"], config, dataset_type="validation") + return train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list + +def get_data_loader(preprocessed_dataset, batch_size): + nproc = os.cpu_count() ### I/O operations + + data_collator = DataCollatorForMultitaskCellClassification() + + loader = DataLoader(preprocessed_dataset, batch_size=batch_size, shuffle=True, + collate_fn=data_collator, num_workers=nproc, pin_memory=True) + return loader +def preload_data(config): + # Preprocessing the data before the Optuna trials start + train_loader = get_data_loader("train", config) + val_loader = get_data_loader("val", config) + return train_loader, val_loader + +def load_and_preprocess_test_data(config): + """ + Load and preprocess test data, treating it as unlabeled. 
+ """ + return load_and_preprocess_data(config["test_path"], config, is_test=True) + +def prepare_test_loader(config): + """ + Prepare DataLoader for the test dataset. + """ + test_dataset, cell_id_mapping, num_labels_list = load_and_preprocess_test_data(config) + test_loader = get_data_loader(test_dataset, config['batch_size']) + return test_loader, cell_id_mapping, num_labels_list \ No newline at end of file diff --git a/geneformer/mtl/eval_utils.py b/geneformer/mtl/eval_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f885313db9610a93a9feb2b7703861ab4db816f5 --- /dev/null +++ b/geneformer/mtl/eval_utils.py @@ -0,0 +1,81 @@ +from .imports import * +import pandas as pd +from .data import prepare_test_loader +from .model import GeneformerMultiTask + +def evaluate_test_dataset(model, device, test_loader, cell_id_mapping, config): + task_pred_labels = {task_name: [] for task_name in config["task_names"]} + task_pred_probs = {task_name: [] for task_name in config["task_names"]} + cell_ids = [] + + # Load task label mappings from pickle file + with open(f"{config['results_dir']}/task_label_mappings.pkl", "rb") as f: + task_label_mappings = pickle.load(f) + + model.eval() + with torch.no_grad(): + for batch in test_loader: + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + _, logits, _ = model(input_ids, attention_mask) + for sample_idx in range(len(batch['input_ids'])): + cell_id = cell_id_mapping[batch['cell_id'][sample_idx].item()] + cell_ids.append(cell_id) + for i, task_name in enumerate(config["task_names"]): + pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item() + pred_prob = torch.softmax(logits[i][sample_idx], dim=-1).cpu().numpy() + task_pred_labels[task_name].append(pred_label) + task_pred_probs[task_name].append(pred_prob) + + # Save test predictions with cell IDs and probabilities to CSV + test_results_dir = config["results_dir"] + os.makedirs(test_results_dir, exist_ok=True) + test_preds_file = os.path.join(test_results_dir, "test_preds.csv") + + rows = [] + for sample_idx in range(len(cell_ids)): + row = {'Cell ID': cell_ids[sample_idx]} + for task_name in config["task_names"]: + row[f'{task_name} Prediction'] = task_pred_labels[task_name][sample_idx] + row[f'{task_name} Probabilities'] = ','.join(map(str, task_pred_probs[task_name][sample_idx])) + rows.append(row) + + df = pd.DataFrame(rows) + df.to_csv(test_preds_file, index=False) + print(f"Test predictions saved to {test_preds_file}") + +def load_and_evaluate_test_model(config): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + test_loader, cell_id_mapping, num_labels_list = prepare_test_loader(config) + model_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") + hyperparams_path = os.path.join(model_directory, "hyperparameters.json") + + # Load the saved best hyperparameters + with open(hyperparams_path, 'r') as f: + best_hyperparams = json.load(f) + + # Extract the task weights if present, otherwise set to None + task_weights = best_hyperparams.get("task_weights", None) + normalized_task_weights = task_weights if task_weights else [] + + # Print the loaded hyperparameters + print("Loaded hyperparameters:") + for param, value in best_hyperparams.items(): + if param == "task_weights": + print(f"normalized_task_weights: {value}") + else: + print(f"{param}: {value}") + + best_model_path = os.path.join(model_directory, "pytorch_model.bin") + best_model = GeneformerMultiTask( + 
config["pretrained_path"], + num_labels_list, + dropout_rate=best_hyperparams["dropout_rate"], + use_task_weights=config["use_task_weights"], + task_weights=normalized_task_weights + ) + best_model.load_state_dict(torch.load(best_model_path)) + best_model.to(device) + + evaluate_test_dataset(best_model, device, test_loader, cell_id_mapping, config) + print("Evaluation completed.") diff --git a/geneformer/mtl/imports.py b/geneformer/mtl/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..7a300fdcff1d190192d5e132ea78886949bf3e71 --- /dev/null +++ b/geneformer/mtl/imports.py @@ -0,0 +1,46 @@ +import numpy as np +import pickle +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from itertools import chain +import warnings +from enum import Enum +from typing import Dict, List, Optional, Union +import sys +import os +import json +import gc +import functools +import pandas as pd + +from sklearn.metrics import f1_score, accuracy_score, roc_auc_score, roc_curve +from sklearn.preprocessing import LabelEncoder +from sklearn.model_selection import train_test_split + +import optuna + +from transformers import ( + BertConfig, + BertModel, + AdamW, + get_linear_schedule_with_warmup, + get_cosine_schedule_with_warmup, + DataCollatorForTokenClassification, + SpecialTokensMixin, + BatchEncoding, + get_scheduler, +) +from transformers.utils import logging, to_py_obj + +from datasets import load_from_disk + +# local modules +from .data import preload_and_process_data, get_data_loader +from .model import GeneformerMultiTask +from .utils import save_model +from .optuna_utils import create_optuna_study +from .collators import DataCollatorForMultitaskCellClassification \ No newline at end of file diff --git a/geneformer/mtl/model.py b/geneformer/mtl/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ce8620ab6080985eab9839f4ada1c5303f477e --- /dev/null +++ b/geneformer/mtl/model.py @@ -0,0 +1,84 @@ +from transformers import BertModel, BertConfig +import torch +import torch.nn as nn + +class AttentionPool(nn.Module): + """Attention-based pooling layer.""" + def __init__(self, hidden_size): + super(AttentionPool, self).__init__() + self.attention_weights = nn.Parameter(torch.randn(hidden_size, 1)) + nn.init.xavier_uniform_(self.attention_weights) # https://pytorch.org/docs/stable/nn.init.html + + def forward(self, hidden_states): + attention_scores = torch.matmul(hidden_states, self.attention_weights) + attention_scores = torch.softmax(attention_scores, dim=1) + pooled_output = torch.sum(hidden_states * attention_scores, dim=1) + return pooled_output + +class GeneformerMultiTask(nn.Module): + def __init__(self, pretrained_path, num_labels_list, dropout_rate=0.1, use_task_weights=False, task_weights=None, max_layers_to_freeze=0, use_attention_pooling=False): + super(GeneformerMultiTask, self).__init__() + self.config = BertConfig.from_pretrained(pretrained_path) + self.bert = BertModel(self.config) + self.num_labels_list = num_labels_list + self.use_task_weights = use_task_weights + self.dropout = nn.Dropout(dropout_rate) + self.use_attention_pooling = use_attention_pooling + + if use_task_weights and (task_weights is None or len(task_weights) != len(num_labels_list)): + raise ValueError("Task weights must be defined and match the number of tasks when 'use_task_weights' is True.") + self.task_weights = task_weights if use_task_weights else [1.0] * len(num_labels_list) + + 
# Freeze the specified initial layers + for layer in self.bert.encoder.layer[:max_layers_to_freeze]: + for param in layer.parameters(): + param.requires_grad = False + + self.attention_pool = AttentionPool(self.config.hidden_size) if use_attention_pooling else None + + self.classification_heads = nn.ModuleList([ + nn.Linear(self.config.hidden_size, num_labels) for num_labels in num_labels_list + ]) + # initialization of the classification heads: https://pytorch.org/docs/stable/nn.init.html + for head in self.classification_heads: + nn.init.xavier_uniform_(head.weight) + nn.init.zeros_(head.bias) + + def forward(self, input_ids, attention_mask, labels=None): + try: + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + except Exception as e: + raise RuntimeError(f"Error during BERT forward pass: {e}") + + sequence_output = outputs.last_hidden_state + + try: + pooled_output = self.attention_pool(sequence_output) if self.use_attention_pooling else sequence_output[:, 0, :] + pooled_output = self.dropout(pooled_output) + except Exception as e: + raise RuntimeError(f"Error during pooling and dropout: {e}") + + total_loss = 0 + logits = [] + losses = [] + + for task_id, (head, num_labels) in enumerate(zip(self.classification_heads, self.num_labels_list)): + try: + task_logits = head(pooled_output) + except Exception as e: + raise RuntimeError(f"Error during forward pass of classification head {task_id}: {e}") + + logits.append(task_logits) + + if labels is not None: + try: + loss_fct = nn.CrossEntropyLoss() + task_loss = loss_fct(task_logits.view(-1, num_labels), labels[task_id].view(-1)) + if self.use_task_weights: + task_loss *= self.task_weights[task_id] + total_loss += task_loss + losses.append(task_loss.item()) + except Exception as e: + raise RuntimeError(f"Error during loss computation for task {task_id}: {e}") + + return total_loss, logits, losses if labels is not None else logits diff --git a/geneformer/mtl/optuna_utils.py b/geneformer/mtl/optuna_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62f2964d6e6c7a8034cd425d60f3ee1fc001d535 --- /dev/null +++ b/geneformer/mtl/optuna_utils.py @@ -0,0 +1,21 @@ +import optuna +from optuna.integration import TensorBoardCallback + +def save_trial_callback(study, trial, trials_result_path): + with open(trials_result_path, "a") as f: + f.write(f"Trial {trial.number}: Value (F1 Macro): {trial.value}, Params: {trial.params}\n") + +def create_optuna_study(objective, n_trials, trials_result_path, tensorboard_log_dir): + study = optuna.create_study(direction="maximize") + + # init TensorBoard callback + tensorboard_callback = TensorBoardCallback(dirname=tensorboard_log_dir, metric_name="F1 Macro") + + # callback and TensorBoard callback + callbacks = [ + lambda study, trial: save_trial_callback(study, trial, trials_result_path), + tensorboard_callback + ] + + study.optimize(objective, n_trials=n_trials, callbacks=callbacks) + return study \ No newline at end of file diff --git a/geneformer/mtl/train.py b/geneformer/mtl/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4845cbe59f22e100848859411f7c12ee4d75868d --- /dev/null +++ b/geneformer/mtl/train.py @@ -0,0 +1,242 @@ +from .imports import * +from .data import preload_and_process_data, get_data_loader +from .model import GeneformerMultiTask +from .utils import calculate_task_specific_metrics +from torch.utils.tensorboard import SummaryWriter +import pandas as pd +import os +from tqdm import tqdm +import random +import numpy as 
np +import torch + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def initialize_wandb(config): + if config.get("use_wandb", False): + import wandb + wandb.init(project=config["wandb_project"], config=config) + print("Weights & Biases (wandb) initialized and will be used for logging.") + else: + print("Weights & Biases (wandb) is not enabled. Logging will use other methods.") + +def create_model(config, num_labels_list, device): + model = GeneformerMultiTask( + config["pretrained_path"], + num_labels_list, + dropout_rate=config["dropout_rate"], + use_task_weights=config["use_task_weights"], + task_weights=config["task_weights"], + max_layers_to_freeze=config["max_layers_to_freeze"], + use_attention_pooling=config["use_attention_pooling"] + ) + if config["use_data_parallel"]: + model = nn.DataParallel(model) + return model.to(device) + +def setup_optimizer_and_scheduler(model, config, total_steps): + optimizer = AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"]) + warmup_steps = int(config["warmup_ratio"] * total_steps) + + if config["lr_scheduler_type"] == "linear": + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) + elif config["lr_scheduler_type"] == "cosine": + scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, num_cycles=0.5) + + return optimizer, scheduler + +def train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch): + model.train() + progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}") + for batch_idx, batch in enumerate(progress_bar): + optimizer.zero_grad() + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + labels = [batch['labels'][task_name].to(device) for task_name in config["task_names"]] + + loss, _, _ = model(input_ids, attention_mask, labels) + loss.backward() + + if config["gradient_clipping"]: + torch.nn.utils.clip_grad_norm_(model.parameters(), config["max_grad_norm"]) + + optimizer.step() + scheduler.step() + + writer.add_scalar('Training Loss', loss.item(), epoch * len(train_loader) + batch_idx) + if config.get("use_wandb", False): + wandb.log({'Training Loss': loss.item()}) + + # Update progress bar + progress_bar.set_postfix({'loss': f"{loss.item():.4f}"}) + + return loss.item() # Return the last batch loss + +def validate_model(model, val_loader, device, config): + model.eval() + val_loss = 0.0 + task_true_labels = {task_name: [] for task_name in config["task_names"]} + task_pred_labels = {task_name: [] for task_name in config["task_names"]} + task_pred_probs = {task_name: [] for task_name in config["task_names"]} + + with torch.no_grad(): + for batch in val_loader: + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + labels = [batch['labels'][task_name].to(device) for task_name in config["task_names"]] + loss, logits, _ = model(input_ids, attention_mask, labels) + val_loss += loss.item() + + for sample_idx in range(len(batch['input_ids'])): + for i, task_name in enumerate(config["task_names"]): + true_label = batch['labels'][task_name][sample_idx].item() + pred_label = torch.argmax(logits[i][sample_idx], dim=-1).item() + pred_prob = torch.softmax(logits[i][sample_idx], 
dim=-1).cpu().numpy() + task_true_labels[task_name].append(true_label) + task_pred_labels[task_name].append(pred_label) + task_pred_probs[task_name].append(pred_prob) + + val_loss /= len(val_loader) + return val_loss, task_true_labels, task_pred_labels, task_pred_probs + +def log_metrics(task_metrics, val_loss, config, writer, epochs): + for task_name, metrics in task_metrics.items(): + print(f"{task_name} - Validation F1 Macro: {metrics['f1']:.4f}, Validation Accuracy: {metrics['accuracy']:.4f}") + if config.get("use_wandb", False): + import wandb + wandb.log({ + f'{task_name} Validation F1 Macro': metrics['f1'], + f'{task_name} Validation Accuracy': metrics['accuracy'] + }) + + writer.add_scalar('Validation Loss', val_loss, epochs) + for task_name, metrics in task_metrics.items(): + writer.add_scalar(f'{task_name} - Validation F1 Macro', metrics['f1'], epochs) + writer.add_scalar(f'{task_name} - Validation Accuracy', metrics['accuracy'], epochs) + +def save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config, trial_number=None): + if trial_number is not None: + trial_results_dir = os.path.join(config["results_dir"], f"trial_{trial_number}") + os.makedirs(trial_results_dir, exist_ok=True) + val_preds_file = os.path.join(trial_results_dir, "val_preds.csv") + else: + val_preds_file = os.path.join(config["results_dir"], "manual_run_val_preds.csv") + + rows = [] + for sample_idx in range(len(val_cell_id_mapping)): + row = {'Cell ID': val_cell_id_mapping[sample_idx]} + for task_name in config["task_names"]: + row[f'{task_name} True'] = task_true_labels[task_name][sample_idx] + row[f'{task_name} Pred'] = task_pred_labels[task_name][sample_idx] + row[f'{task_name} Probabilities'] = ','.join(map(str, task_pred_probs[task_name][sample_idx])) + rows.append(row) + + df = pd.DataFrame(rows) + df.to_csv(val_preds_file, index=False) + print(f"Validation predictions saved to {val_preds_file}") + + +def train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list): + set_seed(config["seed"]) + initialize_wandb(config) + + model = create_model(config, num_labels_list, device) + total_steps = len(train_loader) * config["epochs"] + optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps) + + log_dir = os.path.join(config["tensorboard_log_dir"], "manual_run") + writer = SummaryWriter(log_dir=log_dir) + + epoch_progress = tqdm(range(config["epochs"]), desc="Training Progress") + for epoch in epoch_progress: + last_loss = train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch) + epoch_progress.set_postfix({'last_loss': f"{last_loss:.4f}"}) + + val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(model, val_loader, device, config) + task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels) + + log_metrics(task_metrics, val_loss, config, writer, config["epochs"]) + writer.close() + + save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config) + + if config.get("use_wandb", False): + import wandb + wandb.finish() + + print(f"\nFinal Validation Loss: {val_loss:.4f}") + return val_loss, model # Return both the validation loss and the trained model + +def objective(trial, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list, config, device): + set_seed(config["seed"]) # Set the seed before each trial + initialize_wandb(config) + 
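The trial.suggest_* calls just below read their search ranges from config["hyperparameters"]. For reference, a minimal example of that structure (values are illustrative and mirror the defaults set in MTLClassifier later in this patch):

# Example search-space fragment consumed by objective(); values are illustrative only.
example_config_fragment = {
    "hyperparameters": {
        "learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True},
        "warmup_ratio": {"type": "float", "low": 0.005, "high": 0.01},
        "weight_decay": {"type": "float", "low": 0.01, "high": 0.1},
        "dropout_rate": {"type": "float", "low": 0.0, "high": 0.7},
        "lr_scheduler_type": {"type": "categorical", "choices": ["cosine"]},
        "task_weights": {"type": "float", "low": 0.1, "high": 2.0},
    }
}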
+ # Hyperparameters + config["learning_rate"] = trial.suggest_float("learning_rate", config["hyperparameters"]["learning_rate"]["low"], config["hyperparameters"]["learning_rate"]["high"], log=config["hyperparameters"]["learning_rate"]["log"]) + config["warmup_ratio"] = trial.suggest_float("warmup_ratio", config["hyperparameters"]["warmup_ratio"]["low"], config["hyperparameters"]["warmup_ratio"]["high"]) + config["weight_decay"] = trial.suggest_float("weight_decay", config["hyperparameters"]["weight_decay"]["low"], config["hyperparameters"]["weight_decay"]["high"]) + config["dropout_rate"] = trial.suggest_float("dropout_rate", config["hyperparameters"]["dropout_rate"]["low"], config["hyperparameters"]["dropout_rate"]["high"]) + config["lr_scheduler_type"] = trial.suggest_categorical("lr_scheduler_type", config["hyperparameters"]["lr_scheduler_type"]["choices"]) + config["use_attention_pooling"] = trial.suggest_categorical("use_attention_pooling", [True, False]) + + if config["use_task_weights"]: + config["task_weights"] = [trial.suggest_float(f"task_weight_{i}", config["hyperparameters"]["task_weights"]["low"], config["hyperparameters"]["task_weights"]["high"]) for i in range(len(num_labels_list))] + weight_sum = sum(config["task_weights"]) + config["task_weights"] = [weight / weight_sum for weight in config["task_weights"]] + else: + config["task_weights"] = None + + # Fix for max_layers_to_freeze + if isinstance(config["max_layers_to_freeze"], dict): + config["max_layers_to_freeze"] = trial.suggest_int("max_layers_to_freeze", config["max_layers_to_freeze"]["min"], config["max_layers_to_freeze"]["max"]) + elif isinstance(config["max_layers_to_freeze"], int): + # If it's already an int, we don't need to suggest it + pass + else: + raise ValueError("Invalid type for max_layers_to_freeze. 
Expected dict or int.") + + model = create_model(config, num_labels_list, device) + total_steps = len(train_loader) * config["epochs"] + optimizer, scheduler = setup_optimizer_and_scheduler(model, config, total_steps) + + log_dir = os.path.join(config["tensorboard_log_dir"], f"trial_{trial.number}") + writer = SummaryWriter(log_dir=log_dir) + + for epoch in range(config["epochs"]): + train_epoch(model, train_loader, optimizer, scheduler, device, config, writer, epoch) + + val_loss, task_true_labels, task_pred_labels, task_pred_probs = validate_model(model, val_loader, device, config) + task_metrics = calculate_task_specific_metrics(task_true_labels, task_pred_labels) + + log_metrics(task_metrics, val_loss, config, writer, config["epochs"]) + writer.close() + + save_validation_predictions(val_cell_id_mapping, task_true_labels, task_pred_labels, task_pred_probs, config, trial.number) + + trial.set_user_attr("model_state_dict", model.state_dict()) + trial.set_user_attr("task_weights", config["task_weights"]) + + trial.report(val_loss, config["epochs"]) + + if trial.should_prune(): + raise optuna.TrialPruned() + + if config.get("use_wandb", False): + import wandb + wandb.log({ + "trial_number": trial.number, + "val_loss": val_loss, + **{f"{task_name}_f1": metrics['f1'] for task_name, metrics in task_metrics.items()}, + **{f"{task_name}_accuracy": metrics['accuracy'] for task_name, metrics in task_metrics.items()}, + **{k: v for k, v in config.items() if k in ["learning_rate", "warmup_ratio", "weight_decay", "dropout_rate", "lr_scheduler_type", "use_attention_pooling", "max_layers_to_freeze"]} + }) + wandb.finish() + + return val_loss \ No newline at end of file diff --git a/geneformer/mtl/train_utils.py b/geneformer/mtl/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3558baee0d36ba9bee17b474a21d09b6cba96c96 --- /dev/null +++ b/geneformer/mtl/train_utils.py @@ -0,0 +1,126 @@ +from .imports import * +from .data import preload_and_process_data, get_data_loader +from .train import objective, train_model +from .model import GeneformerMultiTask +from .utils import save_model +import random + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def run_manual_tuning(config): + # Set seed for reproducibility + set_seed(config["seed"]) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list = preload_and_process_data(config) + train_loader = get_data_loader(train_dataset, config['batch_size']) + val_loader = get_data_loader(val_dataset, config['batch_size']) + + # Print the manual hyperparameters being used + print("\nManual hyperparameters being used:") + for key, value in config["manual_hyperparameters"].items(): + print(f"{key}: {value}") + print() # Add an empty line for better readability + + # Use the manual hyperparameters + for key, value in config["manual_hyperparameters"].items(): + config[key] = value + + # Train the model + val_loss, trained_model = train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list) + + print(f"\nValidation loss with manual hyperparameters: {val_loss}") + + # Save the trained model + model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") + save_model(trained_model, 
model_save_directory) + + # Save the hyperparameters + hyperparams_to_save = { + **config["manual_hyperparameters"], + "dropout_rate": config["dropout_rate"], + "use_task_weights": config["use_task_weights"], + "task_weights": config["task_weights"], + "max_layers_to_freeze": config["max_layers_to_freeze"], + "use_attention_pooling": config["use_attention_pooling"] + } + hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json") + with open(hyperparams_path, 'w') as f: + json.dump(hyperparams_to_save, f) + print(f"Manual hyperparameters saved to {hyperparams_path}") + + return val_loss + +def run_optuna_study(config): + # Set seed for reproducibility + set_seed(config["seed"]) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + train_dataset, train_cell_id_mapping, val_dataset, val_cell_id_mapping, num_labels_list = preload_and_process_data(config) + train_loader = get_data_loader(train_dataset, config['batch_size']) + val_loader = get_data_loader(val_dataset, config['batch_size']) + + if config["use_manual_hyperparameters"]: + train_model(config, device, train_loader, val_loader, train_cell_id_mapping, val_cell_id_mapping, num_labels_list) + else: + objective_with_config_and_data = functools.partial( + objective, + train_loader=train_loader, + val_loader=val_loader, + train_cell_id_mapping=train_cell_id_mapping, + val_cell_id_mapping=val_cell_id_mapping, + num_labels_list=num_labels_list, + config=config, + device=device + ) + + study = optuna.create_study( + direction='minimize', # Minimize validation loss + study_name=config["study_name"], + #storage=config["storage"], + load_if_exists=True + ) + + study.optimize( + objective_with_config_and_data, + n_trials=config["n_trials"] + ) + + # After finding the best trial + best_params = study.best_trial.params + best_task_weights = study.best_trial.user_attrs["task_weights"] + print("Saving the best model and its hyperparameters...") + + # Saving model as before + best_model = GeneformerMultiTask( + config["pretrained_path"], + num_labels_list, + dropout_rate=best_params["dropout_rate"], + use_task_weights=config["use_task_weights"], + task_weights=best_task_weights + ) + + # Get the best model state dictionary + best_model_state_dict = study.best_trial.user_attrs["model_state_dict"] + + # Remove the "module." 
prefix from the state dictionary keys if present + best_model_state_dict = {k.replace("module.", ""): v for k, v in best_model_state_dict.items()} + + # Load the modified state dictionary into the model, skipping unexpected keys + best_model.load_state_dict(best_model_state_dict, strict=False) + + model_save_directory = os.path.join(config["model_save_path"], "GeneformerMultiTask") + save_model(best_model, model_save_directory) + + # Additionally, save the best hyperparameters and task weights + hyperparams_path = os.path.join(model_save_directory, "hyperparameters.json") + + with open(hyperparams_path, 'w') as f: + json.dump({**best_params, "task_weights": best_task_weights}, f) + print(f"Best hyperparameters and task weights saved to {hyperparams_path}") \ No newline at end of file diff --git a/geneformer/mtl/utils.py b/geneformer/mtl/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9db8fc93fb0e5801bb917f793981479f121e4527 --- /dev/null +++ b/geneformer/mtl/utils.py @@ -0,0 +1,106 @@ +from .imports import * +from sklearn.metrics import f1_score, accuracy_score +from sklearn.preprocessing import LabelEncoder +from transformers import BertModel, BertConfig, AutoConfig +import os +import shutil + +def save_model(model, model_save_directory): + if not os.path.exists(model_save_directory): + os.makedirs(model_save_directory) + + # Get the state dict + if isinstance(model, nn.DataParallel): + model_state_dict = model.module.state_dict() # Use model.module to access the underlying model + else: + model_state_dict = model.state_dict() + + # Remove the "module." prefix from the keys if present + model_state_dict = {k.replace("module.", ""): v for k, v in model_state_dict.items()} + + model_save_path = os.path.join(model_save_directory, "pytorch_model.bin") + torch.save(model_state_dict, model_save_path) + + # Save the model configuration + if isinstance(model, nn.DataParallel): + model.module.config.to_json_file(os.path.join(model_save_directory, "config.json")) + else: + model.config.to_json_file(os.path.join(model_save_directory, "config.json")) + + print(f"Model and configuration saved to {model_save_directory}") + +def calculate_task_specific_metrics(task_true_labels, task_pred_labels): + task_metrics = {} + for task_name in task_true_labels.keys(): + true_labels = task_true_labels[task_name] + pred_labels = task_pred_labels[task_name] + f1 = f1_score(true_labels, pred_labels, average='macro') + accuracy = accuracy_score(true_labels, pred_labels) + task_metrics[task_name] = {'f1': f1, 'accuracy': accuracy} + return task_metrics + +def calculate_combined_f1(combined_labels, combined_preds): + # Initialize the LabelEncoder + le = LabelEncoder() + + # Fit and transform combined labels and predictions to numerical values + le.fit(combined_labels + combined_preds) + encoded_true_labels = le.transform(combined_labels) + encoded_pred_labels = le.transform(combined_preds) + + # Print out the mapping for sanity check + print("\nLabel Encoder Mapping:") + for index, class_label in enumerate(le.classes_): + print(f"'{class_label}': {index}") + + # Calculate accuracy + accuracy = accuracy_score(encoded_true_labels, encoded_pred_labels) + + # Calculate F1 Macro score + f1 = f1_score(encoded_true_labels, encoded_pred_labels, average='macro') + + return f1, accuracy + +def save_model_without_heads(original_model_save_directory): + # Create a new directory for the model without heads + new_model_save_directory = original_model_save_directory + "_No_Heads" + if not 
os.path.exists(new_model_save_directory): + os.makedirs(new_model_save_directory) + + # Load the model state dictionary + model_state_dict = torch.load(os.path.join(original_model_save_directory, "pytorch_model.bin")) + + # Initialize a new BERT model without the classification heads + config = BertConfig.from_pretrained(os.path.join(original_model_save_directory, "config.json")) + model_without_heads = BertModel(config) + + # Filter the state dict to exclude classification heads + model_without_heads_state_dict = {k: v for k, v in model_state_dict.items() if not k.startswith("classification_heads")} + + # Load the filtered state dict into the model + model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False) + + # Save the model without heads + model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin") + torch.save(model_without_heads.state_dict(), model_save_path) + + # Copy the configuration file + shutil.copy(os.path.join(original_model_save_directory, "config.json"), new_model_save_directory) + + print(f"Model without classification heads saved to {new_model_save_directory}") + + +def get_layer_freeze_range(pretrained_path): + """ + Dynamically determines the number of layers to freeze based on the model depth from its configuration. + Args: + pretrained_path (str): Path to the pretrained model directory or model identifier. + Returns: + dict: A dictionary with 'min' and 'max' keys indicating the range of layers to freeze. + """ + if pretrained_path: + config = AutoConfig.from_pretrained(pretrained_path) + total_layers = config.num_hidden_layers + return {"min": 0, "max": total_layers - 1} + else: + return {"min": 0, "max": 0} \ No newline at end of file diff --git a/geneformer/mtl_classifier.py b/geneformer/mtl_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..7674a64b92f58416cea9c3ad2e3e5b446d7bc7f1 --- /dev/null +++ b/geneformer/mtl_classifier.py @@ -0,0 +1,338 @@ +""" +Geneformer multi-task cell classifier. + +**Input data:** + +| Single-cell transcriptomes as Geneformer rank value encodings with cell state labels for each task in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py). Must contain "unique_cell_id" column for logging. + +**Usage:** + +.. code-block :: python + + >>> from geneformer import MTLClassifier + >>> mc = MTLClassifier(task_columns = ["task1", "task2"], + ... study_name = "mtl", + ... pretrained_path = "/path/pretrained/model", + ... train_path = "/path/train/set", + ... val_path = "/path/eval/set", + ... test_path = "/path/test/set", + ... model_save_path = "/results/directory/save_path", + ... trials_result_path = "/results/directory/results.txt", + ... results_dir = "/results/directory", + ... tensorboard_log_dir = "/results/tblogdir", + ... 
hyperparameters = hyperparameters) + >>> mc.run_optuna_study() + >>> mc.load_and_evaluate_test_model() + >>> mc.save_model_without_heads() +""" + +import logging +import os +from .mtl import train_utils +from .mtl import utils +from .mtl import eval_utils + +logger = logging.getLogger(__name__) + + +class MTLClassifier: + valid_option_dict = { + "task_columns": {list}, + "train_path": {None, str}, + "val_path": {None, str}, + "test_path": {None, str}, + "pretrained_path": {None, str}, + "model_save_path": {None, str}, + "results_dir": {None, str}, + "batch_size": {None, int}, + "n_trials": {None, int}, + "study_name": {None, str}, + "max_layers_to_freeze": {None, dict}, + "epochs": {None, int}, + "tensorboard_log_dir": {None, str}, + "use_data_parallel": {None, bool}, + "use_attention_pooling": {None, bool}, + "use_task_weights": {None, bool}, + "hyperparameters": {None, dict}, + "manual_hyperparameters": {None, dict}, + "use_manual_hyperparameters": {None, bool}, + "use_wandb": {None, bool}, + "wandb_project": {None, str}, + "gradient_clipping": {None, bool}, + "max_grad_norm": {None, int, float}, + "seed": {None, int}, + "trials_result_path": {None, str}, + } + + def __init__( + self, + task_columns=None, + train_path=None, + val_path=None, + test_path=None, + pretrained_path=None, + model_save_path=None, + results_dir=None, + trials_result_path=None, + batch_size=4, + n_trials=15, + study_name="mtl", + max_layers_to_freeze=None, + epochs=1, + tensorboard_log_dir="/results/tblogdir", + use_data_parallel=False, + use_attention_pooling=True, + use_task_weights=True, + hyperparameters=None, # Default is None + manual_hyperparameters=None, # Default is None + use_manual_hyperparameters=False, # Default is False + use_wandb=False, + wandb_project=None, + gradient_clipping=False, + max_grad_norm=None, + seed=42 # Default seed value + ): + + """ + Initialize Geneformer multi-task classifier. + **Parameters:** + task_columns : list + | List of tasks for cell state classification + | Input data columns are labeled with corresponding task names + study_name : None, str + | Study name for labeling output files + pretrained_path : None, str + | Path to pretrained model + train_path : None, str + | Path to training dataset with task columns and "unique_cell_id" column + val_path : None, str + | Path to validation dataset with task columns and "unique_cell_id" column + test_path : None, str + | Path to test dataset with task columns and "unique_cell_id" column + model_save_path : None, str + | Path to directory to save output model (either full model or model without heads) + trials_result_path : None, str + | Path to directory to save hyperparameter tuning trial results + results_dir : None, str + | Path to directory to save results + tensorboard_log_dir : None, str + | Path to directory for Tensorboard logging results + use_data_parallel : None, bool + | Whether to use data parallelization + use_attention_pooling : None, bool + | Whether to use attention pooling + use_task_weights : None, bool + | Whether to use task weights + batch_size : None, int + | Batch size to use + n_trials : None, int + | Number of trials for hyperparameter tuning + epochs : None, int + | Number of epochs for training + max_layers_to_freeze : None, dict + | Dictionary with keys "min" and "max" indicating the min and max layers to freeze from fine-tuning (int) + | 0: no layers will be frozen; 2: first two layers will be frozen; etc. 
+ hyperparameters : None, dict + | Dictionary of categorical max and min for each hyperparameter for tuning + | For example: + | {"learning_rate": {"type":"float", "low":"1e-5", "high":"1e-3", "log":True}, "task_weights": {...}, ...} + manual_hyperparameters : None, dict + | Dictionary of manually set value for each hyperparameter + | For example: + | {"learning_rate": 0.001, "task_weights": [1, 1], ...} + use_manual_hyperparameters : None, bool + | Whether to use manually set hyperparameters + use_wandb : None, bool + | Whether to use Weights & Biases for logging + wandb_project : None, str + | Weights & Biases project name + gradient_clipping : None, bool + | Whether to use gradient clipping + max_grad_norm : None, int, float + | Maximum norm for gradient clipping + seed : None, int + | Random seed + """ + + self.task_columns = task_columns + self.train_path = train_path + self.val_path = val_path + self.test_path = test_path + self.pretrained_path = pretrained_path + self.model_save_path = model_save_path + self.results_dir = results_dir + self.trials_result_path = trials_result_path + self.batch_size = batch_size + self.n_trials = n_trials + self.study_name = study_name + + if max_layers_to_freeze is None: + # Dynamically determine the range of layers to freeze + layer_freeze_range = utils.get_layer_freeze_range(pretrained_path) + self.max_layers_to_freeze = {"min": 1, "max": layer_freeze_range['max']} + else: + self.max_layers_to_freeze = max_layers_to_freeze + + self.epochs = epochs + self.tensorboard_log_dir = tensorboard_log_dir + self.use_data_parallel = use_data_parallel + self.use_attention_pooling = use_attention_pooling + self.use_task_weights = use_task_weights + self.hyperparameters = hyperparameters if hyperparameters is not None else { + "learning_rate": { + "type": "float", + "low": 1e-5, + "high": 1e-3, + "log": True + }, + "warmup_ratio": { + "type": "float", + "low": 0.005, + "high": 0.01 + }, + "weight_decay": { + "type": "float", + "low": 0.01, + "high": 0.1 + }, + "dropout_rate": { + "type": "float", + "low": 0.0, + "high": 0.7 + }, + "lr_scheduler_type": { + "type": "categorical", + "choices": ["cosine"] + }, + "task_weights": { + "type": "float", + "low": 0.1, + "high": 2.0 + } + } + self.manual_hyperparameters = manual_hyperparameters if manual_hyperparameters is not None else { + "learning_rate": 0.001, + "warmup_ratio": 0.01, + "weight_decay": 0.1, + "dropout_rate": 0.1, + "lr_scheduler_type": "cosine", + "use_attention_pooling": False, + "task_weights": [1, 1], + "max_layers_to_freeze": 2 + } + self.use_manual_hyperparameters = use_manual_hyperparameters + self.use_wandb = use_wandb + self.wandb_project = wandb_project + self.gradient_clipping = gradient_clipping + self.max_grad_norm = max_grad_norm + self.seed = seed + + if self.use_manual_hyperparameters: + logger.warning( + "Hyperparameter tuning is highly recommended for optimal results." 
+ ) + + self.validate_options() + + # set up output directories + if self.results_dir is not None: + self.trials_results_path = f"{self.results_dir}/results.txt".replace("//","/") + + for output_dir in [self.model_save_path, self.results_dir]: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + self.config = {key: value for key, value in self.__dict__.items() if key in self.valid_option_dict} + + def validate_options(self): + # confirm arguments are within valid options and compatible with each other + for attr_name, valid_options in self.valid_option_dict.items(): + attr_value = self.__dict__[attr_name] + if not isinstance(attr_value, (list, dict)): + if attr_value in valid_options: + continue + valid_type = False + for option in valid_options: + if (option in [int, float, list, dict, bool, str]) and isinstance( + attr_value, option + ): + valid_type = True + break + if valid_type: + continue + logger.error( + f"Invalid option for {attr_name}. " + f"Valid options for {attr_name}: {valid_options}" + ) + raise ValueError(f"Invalid option for {attr_name}. Valid options for {attr_name}: {valid_options}") + + def run_manual_tuning(self): + """ + Manual hyperparameter tuning and multi-task fine-tuning of pretrained model. + """ + required_variable_names = ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"] + required_variables = [self.train_path, self.val_path, self.pretrained_path, self.model_save_path, self.results_dir] + req_var_dict = dict(zip(required_variable_names, required_variables)) + self.validate_additional_options(req_var_dict) + + if not self.use_manual_hyperparameters: + raise ValueError("Manual hyperparameters are not enabled. Set use_manual_hyperparameters to True.") + + # Ensure manual_hyperparameters are set in the config + self.config["manual_hyperparameters"] = self.manual_hyperparameters + self.config["use_manual_hyperparameters"] = True + + train_utils.run_manual_tuning(self.config) + + def validate_additional_options(self, req_var_dict): + missing_variable = False + for variable_name, variable in req_var_dict.items(): + if variable is None: + logger.warning( + f"Please provide value to MTLClassifier for required variable {variable_name}" + ) + missing_variable = True + if missing_variable is True: + raise ValueError("Missing required variables for MTLClassifier") + + def run_optuna_study( + self, + ): + """ + Hyperparameter optimization and/or multi-task fine-tuning of pretrained model. + """ + + required_variable_names = ["train_path", "val_path", "pretrained_path", "model_save_path", "results_dir"] + required_variables = [self.train_path, self.val_path, self.pretrained_path, self.model_save_path, self.results_dir] + req_var_dict = dict(zip(required_variable_names, required_variables)) + self.validate_additional_options(req_var_dict) + + train_utils.run_optuna_study(self.config) + + def load_and_evaluate_test_model( + self, + ): + """ + Loads previously fine-tuned multi-task model and evaluates on test data. + """ + + required_variable_names = ["test_path", "model_save_path", "results_dir"] + required_variables = [self.test_path, self.model_save_path, self.results_dir] + req_var_dict = dict(zip(required_variable_names, required_variables)) + self.validate_additional_options(req_var_dict) + + eval_utils.load_and_evaluate_test_model(self.config) + + def save_model_without_heads( + self, + ): + """ + Save previously fine-tuned multi-task model without classification heads. 
+ """ + + required_variable_names = ["model_save_path"] + required_variables = [self.model_save_path] + req_var_dict = dict(zip(required_variable_names, required_variables)) + self.validate_additional_options(req_var_dict) + + utils.save_model_without_heads(os.path.join(self.model_save_path, "GeneformerMultiTask")) diff --git a/geneformer/perturber_utils.py b/geneformer/perturber_utils.py index 833a4792cd5144492146d7d3c5141c01fdc288b8..ae1f4930ee815db4cc3c5e1d63acb42c4ca02b32 100644 --- a/geneformer/perturber_utils.py +++ b/geneformer/perturber_utils.py @@ -12,13 +12,17 @@ import pandas as pd import seaborn as sns import torch from datasets import Dataset, load_from_disk +from peft import LoraConfig, get_peft_model from transformers import ( BertForMaskedLM, BertForSequenceClassification, BertForTokenClassification, + BitsAndBytesConfig, ) -from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_DICTIONARY_FILE +GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl" +TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl" +ENSEMBL_DICTIONARY_FILE = Path(__file__).parent / "gene_name_id_dict.pkl" logger = logging.getLogger(__name__) @@ -111,17 +115,49 @@ def slice_by_inds_to_perturb(filtered_input_data, cell_inds_to_perturb): # load model to GPU -def load_model(model_type, num_classes, model_directory, mode): +def load_model(model_type, num_classes, model_directory, mode, quantize=False): + if model_type == "MTLCellClassifier-Quantized": + model_type = "MTLCellClassifier" + quantize = True + if mode == "eval": output_hidden_states = True elif mode == "train": output_hidden_states = False + if quantize is True: + if model_type == "MTLCellClassifier": + quantize = { + "peft_config": None, + "bnb_config": BitsAndBytesConfig( + load_in_8bit=True, + ) + } + else: + quantize = { + "peft_config": LoraConfig( + lora_alpha=128, + lora_dropout=0.1, + r=64, + bias="none", + task_type="TokenClassification", + ), + "bnb_config": BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16 + ) + } + elif quantize is False: + quantize = {"bnb_config": None} + if model_type == "Pretrained": model = BertForMaskedLM.from_pretrained( model_directory, output_hidden_states=output_hidden_states, output_attentions=False, + quantization_config=quantize["bnb_config"], ) elif model_type == "GeneClassifier": model = BertForTokenClassification.from_pretrained( @@ -129,6 +165,7 @@ def load_model(model_type, num_classes, model_directory, mode): num_labels=num_classes, output_hidden_states=output_hidden_states, output_attentions=False, + quantization_config=quantize["bnb_config"], ) elif model_type == "CellClassifier": model = BertForSequenceClassification.from_pretrained( @@ -136,11 +173,24 @@ def load_model(model_type, num_classes, model_directory, mode): num_labels=num_classes, output_hidden_states=output_hidden_states, output_attentions=False, + quantization_config=quantize["bnb_config"], + ) + elif model_type == "MTLCellClassifier": + model = BertForMaskedLM.from_pretrained( + model_directory, + num_labels=num_classes, + output_hidden_states=output_hidden_states, + output_attentions=False, + quantization_config=quantize["bnb_config"], ) # if eval mode, put the model in eval mode for fwd pass if mode == "eval": model.eval() - model = model.to("cuda") + if (quantize is False) or (quantize == {'bnb_config': None}) or (model_type == "MTLCellClassifier"): + model = model.to("cuda") + else: + 
model.enable_input_require_grads() + model = get_peft_model(model, quantize["peft_config"]) return model @@ -222,27 +272,47 @@ def overexpress_indices(example): indices = example["perturb_index"] if any(isinstance(el, list) for el in indices): indices = flatten_list(indices) - for index in sorted(indices, reverse=True): - example["input_ids"].insert(0, example["input_ids"].pop(index)) - + insert_pos = 0 + for index in sorted(indices, reverse=False): + example["input_ids"].insert(insert_pos, example["input_ids"].pop(index)) + insert_pos += 1 example["length"] = len(example["input_ids"]) return example +# if CLS token present, move to 1st rather than 0th position +def overexpress_indices_special(example): + indices = example["perturb_index"] + if any(isinstance(el, list) for el in indices): + indices = flatten_list(indices) + insert_pos = 1 # Insert starting after CLS token + for index in sorted(indices, reverse=False): + example["input_ids"].insert(insert_pos, example["input_ids"].pop(index)) + insert_pos += 1 + example["length"] = len(example["input_ids"]) + return example # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell -def overexpress_tokens(example, max_len): +def overexpress_tokens(example, max_len, special_token): # -100 indicates tokens to overexpress are not present in rank value encoding if example["perturb_index"] != [-100]: example = delete_indices(example) - [ - example["input_ids"].insert(0, token) - for token in example["tokens_to_perturb"][::-1] - ] + if special_token: + [ + example["input_ids"].insert(1, token) + for token in example["tokens_to_perturb"][::-1] + ] + else: + [ + example["input_ids"].insert(0, token) + for token in example["tokens_to_perturb"][::-1] + ] # truncate to max input size, must also truncate original emb to be comparable if len(example["input_ids"]) > max_len: - example["input_ids"] = example["input_ids"][0:max_len] - + if special_token: + example["input_ids"] = example["input_ids"][0:max_len-1]+[example["input_ids"][-1]] + else: + example["input_ids"] = example["input_ids"][0:max_len] example["length"] = len(example["input_ids"]) return example @@ -259,6 +329,13 @@ def truncate_by_n_overflow(example): example["length"] = len(example["input_ids"]) return example +def truncate_by_n_overflow_special(example): + if example["n_overflow"] > 0: + new_max_len = example["length"] - example["n_overflow"] + example["input_ids"] = example["input_ids"][0:new_max_len-1]+[example["input_ids"][-1]] + example["length"] = len(example["input_ids"]) + return example + def remove_indices_from_emb(emb, indices_to_remove, gene_dim): # indices_to_remove is list of indices to remove @@ -392,7 +469,81 @@ def make_perturbation_batch( return perturbation_dataset, indices_to_perturb -# perturbed cell emb removing the activated/overexpressed/inhibited gene emb +def make_perturbation_batch_special( + example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc +) -> tuple[Dataset, List[int]]: + if combo_lvl == 0 and tokens_to_perturb == "all": + if perturb_type in ["overexpress", "activate"]: + range_start = 1 + elif perturb_type in ["delete", "inhibit"]: + range_start = 0 + range_start += 1 # Starting after the CLS token + indices_to_perturb = [ + [i] for i in range(range_start, example_cell["length"][0]-1) # And excluding the EOS token + ] + + # elif combo_lvl > 0 and anchor_token is None: + ## to implement + elif combo_lvl > 0 and (anchor_token is not None): + example_input_ids = example_cell["input_ids"][0] 
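The overexpress_indices_special helper added earlier in this file implements overexpression in the rank value encoding by moving each perturbed token to the top rank while keeping the CLS token at position 0. A toy walk-through of that logic (token IDs are made up; 0 and 2 stand in for the CLS and EOS tokens):

# Illustrative only: 0 = <cls>, 2 = <eos>, the rest are gene token IDs.
example = {"input_ids": [0, 11, 12, 13, 14, 2], "perturb_index": [3]}

insert_pos = 1  # insert starting after the CLS token, as in overexpress_indices_special
for index in sorted(example["perturb_index"]):
    example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
    insert_pos += 1

assert example["input_ids"] == [0, 13, 11, 12, 14, 2]  # token 13 now holds the top rank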
+ anchor_index = example_input_ids.index(anchor_token[0]) + indices_to_perturb = [ + sorted([anchor_index, i]) if i != anchor_index else None + for i in range(1, example_cell["length"][0]-1) # Exclude CLS and EOS tokens + ] + indices_to_perturb = [item for item in indices_to_perturb if item is not None] + else: + example_input_ids = example_cell["input_ids"][0] + indices_to_perturb = [ + [example_input_ids.index(token)] if token in example_input_ids else None + for token in tokens_to_perturb + ] + indices_to_perturb = [item for item in indices_to_perturb if item is not None] + + # create all permutations of combo_lvl of modifiers from tokens_to_perturb + if combo_lvl > 0 and (anchor_token is None): + if tokens_to_perturb != "all": + if len(tokens_to_perturb) == combo_lvl + 1: + indices_to_perturb = [ + list(x) for x in it.combinations(indices_to_perturb, combo_lvl + 1) + ] + else: + all_indices = [[i] for i in range(1, example_cell["length"][0]-1)] # Exclude CLS and EOS tokens + all_indices = [ + index for index in all_indices if index not in indices_to_perturb + ] + indices_to_perturb = [ + [[j for i in indices_to_perturb for j in i], x] for x in all_indices + ] + + length = len(indices_to_perturb) + perturbation_dataset = Dataset.from_dict( + { + "input_ids": example_cell["input_ids"] * length, + "perturb_index": indices_to_perturb, + } + ) + + if length < 400: + num_proc_i = 1 + else: + num_proc_i = num_proc + + if perturb_type == "delete": + perturbation_dataset = perturbation_dataset.map( + delete_indices, num_proc=num_proc_i + ) + elif perturb_type == "overexpress": + perturbation_dataset = perturbation_dataset.map( + overexpress_indices_special, num_proc=num_proc_i + ) + + perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i) + + return perturbation_dataset, indices_to_perturb + + +# original cell emb removing the activated/overexpressed/inhibited gene emb # so that only non-perturbed gene embeddings are compared to each other # in original or perturbed context def make_comparison_batch(original_emb_batch, indices_to_perturb, perturb_group): @@ -589,9 +740,10 @@ def quant_cos_sims( cos = torch.nn.CosineSimilarity(dim=1) # if emb_mode == "gene", can only calculate gene cos sims - # against original cell anyways + # against original cell if cell_states_to_model is None or emb_mode == "gene": cos_sims = cos(perturbation_emb, original_emb).to("cuda") + elif cell_states_to_model is not None and emb_mode == "cell": possible_states = get_possible_states(cell_states_to_model) cos_sims = dict(zip(possible_states, [[] for _ in range(len(possible_states))])) @@ -758,4 +910,4 @@ class GeneIdHandler: return self.ens_to_symbol(self.token_to_ens(token)) def symbol_to_token(self, symbol): - return self.ens_to_token(self.symbol_to_ens(symbol)) + return self.ens_to_token(self.symbol_to_ens(symbol)) \ No newline at end of file diff --git a/geneformer/pretrainer.py b/geneformer/pretrainer.py index a615ef71b1a6a65343364006f6b897dea983f4fa..93a47b0363f73fb06f5abc9ca4e67cf95abf0166 100644 --- a/geneformer/pretrainer.py +++ b/geneformer/pretrainer.py @@ -32,8 +32,6 @@ from transformers.training_args import ParallelMode from transformers.utils import is_tf_available, is_torch_available, logging, to_py_obj from transformers.utils.generic import _is_tensorflow, _is_torch -from . 
import TOKEN_DICTIONARY_FILE - logger = logging.get_logger(__name__) EncodedInput = List[int] VERY_LARGE_INTEGER = int( @@ -52,9 +50,6 @@ _is_torch_generator_available = False if version.parse(torch.__version__) >= version.parse("1.6"): _is_torch_generator_available = True -with open(TOKEN_DICTIONARY_FILE, "rb") as f: - token_dictionary = pickle.load(f) - class ExplicitEnum(Enum): """ @@ -109,15 +104,7 @@ class GeneformerPreCollator(SpecialTokensMixin): super().__init__(mask_token="", pad_token="") self.token_dictionary = kwargs.get("token_dictionary") - # self.mask_token = "" - # self.mask_token_id = self.token_dictionary.get("") - # self.pad_token = "" - # self.pad_token_id = self.token_dictionary.get("") self.padding_side = "right" - # self.all_special_ids = [ - # self.token_dictionary.get(""), - # self.token_dictionary.get(""), - # ] self.model_input_names = ["input_ids"] def convert_ids_to_tokens(self, value): diff --git a/geneformer/token_dictionary.pkl b/geneformer/token_dictionary.pkl deleted file mode 100644 index e879153d2fa7a53486d7d0888663d8bb82599836..0000000000000000000000000000000000000000 Binary files a/geneformer/token_dictionary.pkl and /dev/null differ diff --git a/geneformer/token_dictionary_gc95M.pkl b/geneformer/token_dictionary_gc95M.pkl index 78ef583010dba9e79413539e5ed22c7be2ac7469..b56e406e79c255328f84d9ca00c5c3da2dd04811 100644 Binary files a/geneformer/token_dictionary_gc95M.pkl and b/geneformer/token_dictionary_gc95M.pkl differ diff --git a/generation_config.json b/generation_config.json new file mode 100644 index 0000000000000000000000000000000000000000..6f690c1f39b5b262e6b898b8891afd9d44978f11 --- /dev/null +++ b/generation_config.json @@ -0,0 +1,5 @@ +{ + "_from_model_config": true, + "pad_token_id": 0, + "transformers_version": "4.37.1" +} diff --git a/geneformer-12L-30M/config.json b/gf-12L-30M-i2048/config.json similarity index 100% rename from geneformer-12L-30M/config.json rename to gf-12L-30M-i2048/config.json diff --git a/geneformer-12L-30M/pytorch_model.bin b/gf-12L-30M-i2048/pytorch_model.bin similarity index 100% rename from geneformer-12L-30M/pytorch_model.bin rename to gf-12L-30M-i2048/pytorch_model.bin diff --git a/geneformer-12L-30M/training_args.bin b/gf-12L-30M-i2048/training_args.bin similarity index 100% rename from geneformer-12L-30M/training_args.bin rename to gf-12L-30M-i2048/training_args.bin diff --git a/gf-12L-95M-i4096/config.json b/gf-12L-95M-i4096/config.json new file mode 100755 index 0000000000000000000000000000000000000000..86e20c35e6f257f0daeb00ebb92a0751d12d8fff --- /dev/null +++ b/gf-12L-95M-i4096/config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.02, + "classifier_dropout": null, + "hidden_act": "relu", + "hidden_dropout_prob": 0.02, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 1024, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 4096, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "torch_dtype": "float32", + "transformers_version": "4.37.1", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 20275 +} diff --git a/gf-12L-95M-i4096/generation_config.json b/gf-12L-95M-i4096/generation_config.json new file mode 100755 index 0000000000000000000000000000000000000000..6f690c1f39b5b262e6b898b8891afd9d44978f11 --- /dev/null +++ b/gf-12L-95M-i4096/generation_config.json @@ -0,0 +1,5 @@ +{ + "_from_model_config": true, + 
"pad_token_id": 0, + "transformers_version": "4.37.1" +} diff --git a/gf-12L-95M-i4096/model.safetensors b/gf-12L-95M-i4096/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..1069352219a29bed65fa8e13feb77004128174fa --- /dev/null +++ b/gf-12L-95M-i4096/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c +size 152012980 diff --git a/gf-12L-95M-i4096/training_args.bin b/gf-12L-95M-i4096/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..18802f485a03e0262866d1ef7a3e4748a3b14ed3 --- /dev/null +++ b/gf-12L-95M-i4096/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d +size 4920 diff --git a/gf-12L-95M-i4096_CLcancer/config.json b/gf-12L-95M-i4096_CLcancer/config.json new file mode 100755 index 0000000000000000000000000000000000000000..a7793eb2ea27b28f1f4c5b9974d30c98b4afe8a6 --- /dev/null +++ b/gf-12L-95M-i4096_CLcancer/config.json @@ -0,0 +1,25 @@ +{ + "_name_or_path": "/gladstone/theodoris/lab/pretrained_models/encoder/240402_194213_geneformer_94M_L12_emb512_SL4096_E3_B4_LR0.0005_LScosine_WU5000_Oadamw_DS8/models", + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.02, + "classifier_dropout": null, + "hidden_act": "relu", + "hidden_dropout_prob": 0.02, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 1024, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 4096, + "model_type": "bert", + "num_attention_heads": 8, + "num_hidden_layers": 12, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "torch_dtype": "float32", + "transformers_version": "4.37.1", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 20275 +} diff --git a/gf-12L-95M-i4096_CLcancer/generation_config.json b/gf-12L-95M-i4096_CLcancer/generation_config.json new file mode 100755 index 0000000000000000000000000000000000000000..6f690c1f39b5b262e6b898b8891afd9d44978f11 --- /dev/null +++ b/gf-12L-95M-i4096_CLcancer/generation_config.json @@ -0,0 +1,5 @@ +{ + "_from_model_config": true, + "pad_token_id": 0, + "transformers_version": "4.37.1" +} diff --git a/gf-12L-95M-i4096_CLcancer/model.safetensors b/gf-12L-95M-i4096_CLcancer/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..cc620ee4b4243b7ab6d83ad518563e1425eab45b --- /dev/null +++ b/gf-12L-95M-i4096_CLcancer/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2451adeed240c165634fea60ccba17063da8a2843ea9fcdcc0ce185720bf0dc2 +size 152012980 diff --git a/gf-12L-95M-i4096_CLcancer/training_args.bin b/gf-12L-95M-i4096_CLcancer/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..1669f5848710ca4a53db6e118e50b816f85381b7 --- /dev/null +++ b/gf-12L-95M-i4096_CLcancer/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37074f3ea62a6ba0a312c38526c20c2dccbb068a2c7ee8c7c73b435dd90ab7b1 +size 5048 diff --git a/gf-20L-95M-i4096/config.json b/gf-20L-95M-i4096/config.json new file mode 100755 index 0000000000000000000000000000000000000000..db949ba1ae442ad3b9e52fd8b7922c6b936ef98c --- /dev/null +++ b/gf-20L-95M-i4096/config.json @@ -0,0 +1,24 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.02, + "classifier_dropout": null, + "hidden_act": "relu", + 
"hidden_dropout_prob": 0.02, + "hidden_size": 896, + "initializer_range": 0.02, + "intermediate_size": 1792, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 4096, + "model_type": "bert", + "num_attention_heads": 14, + "num_hidden_layers": 20, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "torch_dtype": "float32", + "transformers_version": "4.37.1", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 20275 +} diff --git a/gf-20L-95M-i4096/generation_config.json b/gf-20L-95M-i4096/generation_config.json new file mode 100755 index 0000000000000000000000000000000000000000..6f690c1f39b5b262e6b898b8891afd9d44978f11 --- /dev/null +++ b/gf-20L-95M-i4096/generation_config.json @@ -0,0 +1,5 @@ +{ + "_from_model_config": true, + "pad_token_id": 0, + "transformers_version": "4.37.1" +} diff --git a/gf-20L-95M-i4096/model.safetensors b/gf-20L-95M-i4096/model.safetensors new file mode 100755 index 0000000000000000000000000000000000000000..37212863afb501a17425dd48766d71d534537d24 --- /dev/null +++ b/gf-20L-95M-i4096/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db85c081a6d392448955c7d0185e26aba74507518df991ca8c69ee9108ce8bbf +size 605292732 diff --git a/gf-20L-95M-i4096/training_args.bin b/gf-20L-95M-i4096/training_args.bin new file mode 100755 index 0000000000000000000000000000000000000000..3db61b0b99d299afb7c4a237d2b531baa253e5d3 --- /dev/null +++ b/gf-20L-95M-i4096/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5afed602918d6f0c4916c1b9335bcdb619bca2c6fd6c7e0dd2a86d195264b8cc +size 5048 diff --git a/gf-6L-30M-i2048/config.json b/gf-6L-30M-i2048/config.json new file mode 100644 index 0000000000000000000000000000000000000000..d131b7026d684013f988cc9e3dcae2e5a284bc0e --- /dev/null +++ b/gf-6L-30M-i2048/config.json @@ -0,0 +1,23 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.02, + "gradient_checkpointing": false, + "hidden_act": "relu", + "hidden_dropout_prob": 0.02, + "hidden_size": 256, + "initializer_range": 0.02, + "intermediate_size": 512, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 2048, + "model_type": "bert", + "num_attention_heads": 4, + "num_hidden_layers": 6, + "pad_token_id": 0, + "position_embedding_type": "absolute", + "transformers_version": "4.6.0", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 25426 +} diff --git a/gf-6L-30M-i2048/model.safetensors b/gf-6L-30M-i2048/model.safetensors new file mode 100644 index 0000000000000000000000000000000000000000..c06bc0c9f7517d5db759187f65d27bacc76eb631 --- /dev/null +++ b/gf-6L-30M-i2048/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5e33a757431643b3697de7ef6127950cdc49e06e58d4266b3a3ab191b683f14 +size 41183536 diff --git a/pytorch_model.bin b/gf-6L-30M-i2048/pytorch_model.bin similarity index 100% rename from pytorch_model.bin rename to gf-6L-30M-i2048/pytorch_model.bin diff --git a/gf-6L-30M-i2048/training_args.bin b/gf-6L-30M-i2048/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..3e03ccc99722f70224937e7b2e46f8faab774e23 --- /dev/null +++ b/gf-6L-30M-i2048/training_args.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f0ec3459454205174c9d2e4d6c6930f6b0fbf3364fc03a6f4d99c4d3add2012b +size 2607 diff --git a/model.safetensors b/model.safetensors index c06bc0c9f7517d5db759187f65d27bacc76eb631..1069352219a29bed65fa8e13feb77004128174fa 100644 --- 
a/model.safetensors +++ b/model.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a5e33a757431643b3697de7ef6127950cdc49e06e58d4266b3a3ab191b683f14 -size 41183536 +oid sha256:4365ba23e393fcfa0e65a94ac64a0983cd788bd23a8d4914f4ab66f85cfe043c +size 152012980 diff --git a/requirements.txt b/requirements.txt index b148bc1fc59c72a5939c3aaa94fafff9dcb1da8b..77f06bc9b6f449eb8b89de3bb05fac1ee6bf52c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,8 @@ hyperopt>=0.2 loompy>=3.0 matplotlib>=3.7 numpy>=1.23 +optuna>=3.6 +optuna-integration>=3.6 packaging>=23.0 pandas>=2.0 pyarrow>=12.0 diff --git a/training_args.bin b/training_args.bin index 3e03ccc99722f70224937e7b2e46f8faab774e23..18802f485a03e0262866d1ef7a3e4748a3b14ed3 100644 --- a/training_args.bin +++ b/training_args.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f0ec3459454205174c9d2e4d6c6930f6b0fbf3364fc03a6f4d99c4d3add2012b -size 2607 +oid sha256:21a45980734b138029422e95a5601def858821a9ec02cd473938b9f525ac108d +size 4920
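Taken together, the directory renames and the updated root-level LFS pointers above make the 12-layer, 95M-parameter, 4096-context checkpoint (gf-12L-95M-i4096) the repository's default model, while the older 6-layer/30M variants and the larger 20-layer/95M variant remain available in their own subfolders. The following is a minimal loading sketch for illustration only: it assumes the Hugging Face repository id ctheodoris/Geneformer and the subfolder names introduced in this diff, and uses only the standard transformers from_pretrained API.

    # Illustrative sketch; assumes the repository id "ctheodoris/Geneformer" and the
    # subfolder names introduced in this diff (e.g. "gf-20L-95M-i4096").
    from transformers import AutoConfig, AutoModelForMaskedLM

    # The repository root now resolves to the 12-layer / hidden_size 512 / 4096-context
    # model, matching the updated config.json and model.safetensors pointers above.
    default_model = AutoModelForMaskedLM.from_pretrained("ctheodoris/Geneformer")

    # The deeper 20-layer variant is loaded from its subfolder.
    cfg_20l = AutoConfig.from_pretrained(
        "ctheodoris/Geneformer", subfolder="gf-20L-95M-i4096"
    )
    print(cfg_20l.num_hidden_layers, cfg_20l.hidden_size)  # 20 and 896 per the config above

    model_20l = AutoModelForMaskedLM.from_pretrained(
        "ctheodoris/Geneformer", subfolder="gf-20L-95M-i4096"
    )

Note that the fine-tuned and cancer-tuned checkpoints added here (gf-12L-95M-i4096_MTLCellClassifier_CELLxGENE_240522, gf-12L-95M-i4096_CLcancer) share the same architecture as the default model and would load the same way from their respective subfolders.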