Upload in_silico_perturber.py

#187
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +8 -5
geneformer/in_silico_perturber.py CHANGED
@@ -396,19 +396,22 @@ def quant_cos_sims(model,
396
  original_minibatch = original_emb.select([i for i in range(i, max_range)])
397
  original_minibatch_lengths = original_minibatch["length"]
398
  original_minibatch_length_set = set(original_minibatch["length"])
 
 
 
399
  if perturb_type == "overexpress":
400
  new_max_len = model_input_size - len(tokens_to_perturb)
401
  else:
402
  new_max_len = model_input_size
403
  if (len(original_minibatch_length_set) > 1) or (max(original_minibatch_length_set) > new_max_len):
404
- original_max_len = min(max(original_minibatch_length_set),new_max_len)
405
  def pad_or_trunc_example(example):
406
- example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id, original_max_len)
407
  return example
408
  original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
409
  original_minibatch.set_format(type="torch")
410
  original_input_data_minibatch = original_minibatch["input_ids"]
411
- attention_mask = gen_attention_mask(original_minibatch, original_max_len)
412
  # extract embeddings for original minibatch
413
  with torch.no_grad():
414
  original_outputs = model(
@@ -429,7 +432,7 @@ def quant_cos_sims(model,
429
  # exclude overexpression due to case when genes are not expressed but being overexpressed
430
  if perturb_type != "overexpress":
431
  original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
432
- indices_to_perturb,
433
  gene_dim)
434
 
435
  # cosine similarity between original emb and batch items
@@ -438,7 +441,7 @@ def quant_cos_sims(model,
438
  minibatch_comparison = comparison_batch[i:max_range]
439
  elif perturb_group == True:
440
  minibatch_comparison = make_comparison_batch(original_minibatch_emb,
441
- indices_to_perturb,
442
  perturb_group)
443
 
444
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
 
396
  original_minibatch = original_emb.select([i for i in range(i, max_range)])
397
  original_minibatch_lengths = original_minibatch["length"]
398
  original_minibatch_length_set = set(original_minibatch["length"])
399
+
400
+ indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
401
+
402
  if perturb_type == "overexpress":
403
  new_max_len = model_input_size - len(tokens_to_perturb)
404
  else:
405
  new_max_len = model_input_size
406
  if (len(original_minibatch_length_set) > 1) or (max(original_minibatch_length_set) > new_max_len):
407
+ new_max_len = min(max(original_minibatch_length_set),new_max_len)
408
  def pad_or_trunc_example(example):
409
+ example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id, new_max_len)
410
  return example
411
  original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
412
  original_minibatch.set_format(type="torch")
413
  original_input_data_minibatch = original_minibatch["input_ids"]
414
+ attention_mask = gen_attention_mask(original_minibatch, new_max_len)
415
  # extract embeddings for original minibatch
416
  with torch.no_grad():
417
  original_outputs = model(
 
432
  # exclude overexpression due to case when genes are not expressed but being overexpressed
433
  if perturb_type != "overexpress":
434
  original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
435
+ indices_to_perturb_minibatch,
436
  gene_dim)
437
 
438
  # cosine similarity between original emb and batch items
 
441
  minibatch_comparison = comparison_batch[i:max_range]
442
  elif perturb_group == True:
443
  minibatch_comparison = make_comparison_batch(original_minibatch_emb,
444
+ indices_to_perturb_minibatch,
445
  perturb_group)
446
 
447
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]