pointing dictionaries from the mtl module's init
#397
by
madhavanvenkatesh
- opened
geneformer/mtl/collators.py
CHANGED
@@ -1,18 +1,18 @@
|
|
1 |
# imports
|
2 |
import torch
|
3 |
-
|
4 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
|
|
5 |
|
6 |
"""
|
7 |
Geneformer collator for multi-task cell classification.
|
8 |
"""
|
9 |
|
10 |
-
|
11 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
12 |
class_type = "cell"
|
13 |
|
14 |
def __init__(self, *args, **kwargs) -> None:
|
15 |
-
|
|
|
16 |
|
17 |
def _prepare_batch(self, features):
|
18 |
# Process inputs as usual
|
|
|
1 |
# imports
|
2 |
import torch
|
|
|
3 |
from ..collator_for_classification import DataCollatorForGeneClassification
|
4 |
+
from . import TOKEN_DICTIONARY # import the token dictionary from the mtl module's init
|
5 |
|
6 |
"""
|
7 |
Geneformer collator for multi-task cell classification.
|
8 |
"""
|
9 |
|
|
|
10 |
class DataCollatorForMultitaskCellClassification(DataCollatorForGeneClassification):
|
11 |
class_type = "cell"
|
12 |
|
13 |
def __init__(self, *args, **kwargs) -> None:
|
14 |
+
# Use the loaded token dictionary from the mtl module's init
|
15 |
+
super().__init__(token_dictionary=TOKEN_DICTIONARY, *args, **kwargs)
|
16 |
|
17 |
def _prepare_batch(self, features):
|
18 |
# Process inputs as usual
|