Update geneformer/in_silico_perturber.py
Browse files
geneformer/in_silico_perturber.py
CHANGED
@@ -604,7 +604,7 @@ class InSilicoPerturber:
|
|
604 |
"filter_data": {None, dict},
|
605 |
"cell_states_to_model": {None, dict},
|
606 |
"max_ncells": {None, int},
|
607 |
-
"
|
608 |
"emb_layer": {-1, 0},
|
609 |
"forward_batch_size": {int},
|
610 |
"nproc": {int},
|
@@ -623,7 +623,7 @@ class InSilicoPerturber:
|
|
623 |
filter_data=None,
|
624 |
cell_states_to_model=None,
|
625 |
max_ncells=None,
|
626 |
-
|
627 |
emb_layer=-1,
|
628 |
forward_batch_size=100,
|
629 |
nproc=4,
|
@@ -689,9 +689,9 @@ class InSilicoPerturber:
|
|
689 |
max_ncells : None, int
|
690 |
Maximum number of cells to test.
|
691 |
If None, will test all cells.
|
692 |
-
|
693 |
Default is perturbing each cell in the dataset.
|
694 |
-
Otherwise, may provide a dict of indices of
|
695 |
start_ind: the first index to perturb.
|
696 |
end_ind: the last index to perturb (exclusive).
|
697 |
Indices will be selected *after* the filter_data criteria and sorting.
|
@@ -732,7 +732,7 @@ class InSilicoPerturber:
|
|
732 |
self.filter_data = filter_data
|
733 |
self.cell_states_to_model = cell_states_to_model
|
734 |
self.max_ncells = max_ncells
|
735 |
-
self.
|
736 |
self.emb_layer = emb_layer
|
737 |
self.forward_batch_size = forward_batch_size
|
738 |
self.nproc = nproc
|
@@ -908,15 +908,15 @@ class InSilicoPerturber:
|
|
908 |
"Values in filter_data dict must be lists. " \
|
909 |
f"Changing {key} value to list ([{value}]).")
|
910 |
|
911 |
-
if self.
|
912 |
-
if set(self.
|
913 |
logger.error(
|
914 |
-
"If
|
915 |
)
|
916 |
raise
|
917 |
-
if self.
|
918 |
logger.error(
|
919 |
-
'
|
920 |
)
|
921 |
raise
|
922 |
|
@@ -1017,15 +1017,15 @@ class InSilicoPerturber:
|
|
1017 |
cos_sims_dict = defaultdict(list)
|
1018 |
pickle_batch = -1
|
1019 |
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
1020 |
-
if self.
|
1021 |
-
if self.
|
1022 |
-
logger.error("
|
1023 |
raise
|
1024 |
-
if self.
|
1025 |
-
logger.warning("
|
1026 |
Setting to the end of the filtered dataset.")
|
1027 |
-
self.
|
1028 |
-
filtered_input_data = filtered_input_data.select([i for i in range(self.
|
1029 |
|
1030 |
# make perturbation batch w/ single perturbation in multiple cells
|
1031 |
if self.perturb_group == True:
|
|
|
604 |
"filter_data": {None, dict},
|
605 |
"cell_states_to_model": {None, dict},
|
606 |
"max_ncells": {None, int},
|
607 |
+
"cell_inds_to_perturb": {"all", dict},
|
608 |
"emb_layer": {-1, 0},
|
609 |
"forward_batch_size": {int},
|
610 |
"nproc": {int},
|
|
|
623 |
filter_data=None,
|
624 |
cell_states_to_model=None,
|
625 |
max_ncells=None,
|
626 |
+
cell_inds_to_perturb="all",
|
627 |
emb_layer=-1,
|
628 |
forward_batch_size=100,
|
629 |
nproc=4,
|
|
|
689 |
max_ncells : None, int
|
690 |
Maximum number of cells to test.
|
691 |
If None, will test all cells.
|
692 |
+
cell_inds_to_perturb : "all", list
|
693 |
Default is perturbing each cell in the dataset.
|
694 |
+
Otherwise, may provide a dict of indices of cells to perturb with keys start_ind and end_ind.
|
695 |
start_ind: the first index to perturb.
|
696 |
end_ind: the last index to perturb (exclusive).
|
697 |
Indices will be selected *after* the filter_data criteria and sorting.
|
|
|
732 |
self.filter_data = filter_data
|
733 |
self.cell_states_to_model = cell_states_to_model
|
734 |
self.max_ncells = max_ncells
|
735 |
+
self.cell_inds_to_perturb = cell_inds_to_perturb
|
736 |
self.emb_layer = emb_layer
|
737 |
self.forward_batch_size = forward_batch_size
|
738 |
self.nproc = nproc
|
|
|
908 |
"Values in filter_data dict must be lists. " \
|
909 |
f"Changing {key} value to list ([{value}]).")
|
910 |
|
911 |
+
if self.cell_inds_to_perturb != "all":
|
912 |
+
if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
|
913 |
logger.error(
|
914 |
+
"If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
915 |
)
|
916 |
raise
|
917 |
+
if self.cell_inds_to_perturb["start"] < 0 or self.cell_inds_to_perturb["end"] < 0:
|
918 |
logger.error(
|
919 |
+
'cell_inds_to_perturb must be positive.'
|
920 |
)
|
921 |
raise
|
922 |
|
|
|
1017 |
cos_sims_dict = defaultdict(list)
|
1018 |
pickle_batch = -1
|
1019 |
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
1020 |
+
if self.cell_inds_to_perturb != "all":
|
1021 |
+
if self.cell_inds_to_perturb["start"] >= len(filtered_input_data):
|
1022 |
+
logger.error("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
|
1023 |
raise
|
1024 |
+
if self.cell_inds_to_perturb["end"] > len(filtered_input_data):
|
1025 |
+
logger.warning("cell_inds_to_perturb['end'] is larger than the filtered dataset. \
|
1026 |
Setting to the end of the filtered dataset.")
|
1027 |
+
self.cell_inds_to_perturb["end"] = len(filtered_input_data)
|
1028 |
+
filtered_input_data = filtered_input_data.select([i for i in range(self.cell_inds_to_perturb["start"], self.cell_inds_to_perturb["end"])])
|
1029 |
|
1030 |
# make perturbation batch w/ single perturbation in multiple cells
|
1031 |
if self.perturb_group == True:
|