pointing dictionaries from the mtl module's init

#397
Files changed (1) hide show
  1. geneformer/mtl/collators.py +3 -3
geneformer/mtl/collators.py CHANGED
@@ -1,18 +1,18 @@
1
  # imports
2
  import torch
3
-
4
  from ..collator_for_classification import DataCollatorForGeneClassification
 
5
 
6
  """
7
  Geneformer collator for multi-task cell classification.
8
  """
9
 
10
-
11
  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
12
  class_type = "cell"
13
 
14
  def __init__(self, *args, **kwargs) -> None:
15
- super().__init__(*args, **kwargs)
 
16
 
17
  def _prepare_batch(self, features):
18
  # Process inputs as usual
 
1
  # imports
2
  import torch
 
3
  from ..collator_for_classification import DataCollatorForGeneClassification
4
+ from . import TOKEN_DICTIONARY # import the token dictionary from the mtl module's init
5
 
6
  """
7
  Geneformer collator for multi-task cell classification.
8
  """
9
 
 
10
  class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
11
  class_type = "cell"
12
 
13
  def __init__(self, *args, **kwargs) -> None:
14
+ # Use the loaded token dictionary from the mtl module's init
15
+ super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
16
 
17
  def _prepare_batch(self, features):
18
  # Process inputs as usual