madhavanvenkatesh committed on
Commit
e384cc6
1 Parent(s): fe1640b

"save_model_without_heads" is redundant

Browse files

The perturber and emb-extractor work even when provided a pytorch_model.bin with the classification weights/biases (no need to remove classification heads to do downstream tasks:
the weights reported as newly initialized — ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight'] — are treated correctly with BertForMaskedLM).

Files changed (1) hide show
  1. geneformer/mtl_classifier.py +15 -15
geneformer/mtl_classifier.py CHANGED
@@ -344,18 +344,18 @@ class MTLClassifier:
344
 
345
  eval_utils.load_and_evaluate_test_model(self.config)
346
 
347
- def save_model_without_heads(
348
- self,
349
- ):
350
- """
351
- Save previously fine-tuned multi-task model without classification heads.
352
- """
353
-
354
- required_variable_names = ["model_save_path"]
355
- required_variables = [self.model_save_path]
356
- req_var_dict = dict(zip(required_variable_names, required_variables))
357
- self.validate_additional_options(req_var_dict)
358
-
359
- utils.save_model_without_heads(
360
- os.path.join(self.model_save_path, "GeneformerMultiTask")
361
- )
 
344
 
345
  eval_utils.load_and_evaluate_test_model(self.config)
346
 
347
+ # def save_model_without_heads(
348
+ # self,
349
+ # ):
350
+ # """
351
+ # Save previously fine-tuned multi-task model without classification heads.
352
+ # """
353
+
354
+ # required_variable_names = ["model_save_path"]
355
+ # required_variables = [self.model_save_path]
356
+ # req_var_dict = dict(zip(required_variable_names, required_variables))
357
+ # self.validate_additional_options(req_var_dict)
358
+
359
+ # utils.save_model_without_heads(
360
+ # os.path.join(self.model_save_path, "GeneformerMultiTask")
361
+ # )