Upload in_silico_perturber.py #187
by davidjwen - opened
geneformer/in_silico_perturber.py CHANGED
@@ -396,19 +396,22 @@ def quant_cos_sims(model,
         original_minibatch = original_emb.select([i for i in range(i, max_range)])
         original_minibatch_lengths = original_minibatch["length"]
         original_minibatch_length_set = set(original_minibatch["length"])
+
+        indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
+
         if perturb_type == "overexpress":
             new_max_len = model_input_size - len(tokens_to_perturb)
         else:
             new_max_len = model_input_size
         if (len(original_minibatch_length_set) > 1) or (max(original_minibatch_length_set) > new_max_len):
-
+            new_max_len = min(max(original_minibatch_length_set),new_max_len)
             def pad_or_trunc_example(example):
-                example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id,
+                example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id, new_max_len)
                 return example
             original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
         original_minibatch.set_format(type="torch")
         original_input_data_minibatch = original_minibatch["input_ids"]
-        attention_mask = gen_attention_mask(original_minibatch,
+        attention_mask = gen_attention_mask(original_minibatch, new_max_len)
         # extract embeddings for original minibatch
         with torch.no_grad():
             original_outputs = model(
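This hunk adds a per-minibatch slice of the perturbation indices (indices_to_perturb_minibatch), caps new_max_len at the longest sequence in the current minibatch, presumably to avoid padding every minibatch out to the full model_input_size, and passes new_max_len explicitly to pad_or_truncate_encoding and gen_attention_mask. The bodies of those two helpers are not shown in the diff; the following is only a rough sketch of the behavior the new call sites appear to assume, not the repository's implementation:

import torch

def pad_or_truncate_encoding(input_ids, pad_token_id, max_len):
    # Sketch: truncate encodings longer than max_len and pad shorter ones with
    # pad_token_id so every example in the minibatch ends up with length max_len.
    if len(input_ids) > max_len:
        return input_ids[:max_len]
    return input_ids + [pad_token_id] * (max_len - len(input_ids))

def gen_attention_mask(minibatch, max_len):
    # Sketch: 1 for real tokens, 0 for padding, clamped at max_len for
    # examples that were truncated above.
    return torch.tensor(
        [[1] * min(length, max_len) + [0] * (max_len - min(length, max_len))
         for length in minibatch["length"]]
    )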
@@ -429,7 +432,7 @@ def quant_cos_sims(model,
         # exclude overexpression due to case when genes are not expressed but being overexpressed
         if perturb_type != "overexpress":
             original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
-
+                                                                   indices_to_perturb_minibatch,
                                                                    gene_dim)

         # cosine similarity between original emb and batch items
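Here remove_indices_from_emb_batch now receives the per-minibatch slice indices_to_perturb_minibatch (the argument it replaces is not visible in this view). Its implementation is not part of the diff; a minimal sketch of the kind of operation the call suggests, assuming emb_batch has shape (batch, seq_len, hidden) and gene_dim is the sequence dimension:

import torch

def remove_indices_from_emb_batch(emb_batch, indices_to_perturb, gene_dim):
    # Sketch, assuming emb_batch is (batch, seq_len, hidden) and gene_dim == 1:
    # drop each example's perturbed token positions so the original embeddings
    # line up position-by-position with the perturbed embeddings.
    kept = []
    for example_emb, indices in zip(emb_batch, indices_to_perturb):
        mask = torch.ones(example_emb.size(0), dtype=torch.bool)
        mask[torch.tensor(indices)] = False
        kept.append(example_emb[mask])
    # stacking assumes every example drops the same number of positions
    return torch.stack(kept)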
@@ -438,7 +441,7 @@ def quant_cos_sims(model,
             minibatch_comparison = comparison_batch[i:max_range]
         elif perturb_group == True:
             minibatch_comparison = make_comparison_batch(original_minibatch_emb,
-
+                                                         indices_to_perturb_minibatch,
                                                          perturb_group)

         cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
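Both this hunk and the previous one consume the new indices_to_perturb_minibatch slice, which keeps the perturbation indices row-aligned with the minibatch selected at the top of the first hunk (original_emb.select(...) over rows i..max_range-1). A hypothetical illustration with made-up sizes:

# Hypothetical sizes, just to show the row alignment the slice provides.
forward_batch_size = 2
indices_to_perturb = [[0], [3], [1], [2], [5]]  # one entry per cell in original_emb
total_cells = len(indices_to_perturb)

for i in range(0, total_cells, forward_batch_size):
    max_range = min(i + forward_batch_size, total_cells)
    # rows i..max_range-1 of original_emb form original_minibatch, so the same
    # slice of indices_to_perturb must accompany it into the helper calls
    indices_to_perturb_minibatch = indices_to_perturb[i:i + forward_batch_size]
    print(list(range(i, max_range)), indices_to_perturb_minibatch)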