ctheodoris madhavanvenkatesh commited on
Commit
7470753
1 Parent(s): beb62a4

pointing dictionaries from the mtl module's init (#397)

Browse files

- pointing dictionaries from the mtl module's init (5539d14469f84e2f0b13a7cb3f6054b2b0cbf1f3)


Co-authored-by: Madhavan Venkatesh <[email protected]>

Files changed (1) hide show
  1. geneformer/mtl/collators.py +3 -3
geneformer/mtl/collators.py CHANGED
@@ -1,18 +1,18 @@
1
  # imports
2
  import torch
3
-
4
  from ..collator_for_classification import DataCollatorForGeneClassification
 
5
 
6
  """
7
  Geneformer collator for multi-task cell classification.
8
  """
9
 
10
-
11
  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
12
  class_type = "cell"
13
 
14
  def __init__(self, *args, **kwargs) -> None:
15
- super().__init__(*args, **kwargs)
 
16
 
17
  def _prepare_batch(self, features):
18
  # Process inputs as usual
 
1
  # imports
2
  import torch
 
3
  from ..collator_for_classification import DataCollatorForGeneClassification
4
+ from . import TOKEN_DICTIONARY # import the token dictionary from the mtl module's init
5
 
6
  """
7
  Geneformer collator for multi-task cell classification.
8
  """
9
 
 
10
  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
11
  class_type = "cell"
12
 
13
  def __init__(self, *args, **kwargs) -> None:
14
+ # Use the loaded token dictionary from the mtl module's init
15
+ super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
16
 
17
  def _prepare_batch(self, features):
18
  # Process inputs as usual