“sayehs” committed
Commit
c1efd44
1 Parent(s): 94a9749
Files changed (3)
  1. config.json +0 -4
  2. modeling_gptj.py +11 -165
  3. test.py +13 -0
config.json CHANGED
@@ -3,10 +3,6 @@
   "architectures": [
     "GPTJForCausalLM"
   ],
- "auto_map": {
-   "AutoConfig": "configuration_gptj.GPTJConfig",
-   "AutoModelForCausalLM": "modeling_gptj.GPTJForCausalLM"
- },
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
modeling_gptj.py CHANGED
@@ -38,7 +38,6 @@ from transformers.utils import (
     is_torch_fx_proxy,
     logging,
 )
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 from .configuration_gptj import GPTJConfig
 
 
@@ -453,54 +452,6 @@ GPTJ_INPUTS_DOCSTRING = r"""
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
 
-PARALLELIZE_DOCSTRING = r"""
-    This is an experimental feature and is a subject to change at a moment's notice. Uses a device map to distribute
-    attention modules of the model across several devices. If no device map is given, it will evenly distribute blocks
-    across all devices.
-
-    Args:
-        device_map (`Dict[int, list]`, optional, defaults to None):
-            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
-            automatically mapped to the first device (for esoteric reasons). That means that the first device should
-            have fewer attention modules mapped to it than other devices. For reference, the GPT-J models have the
-            following number of attention modules:
-
-                - gpt-j-6B: 28
-
-    Example:
-
-    ```python
-    # Here is an example of a device map on a machine with 4 GPUs using gpt-j-6B, which has a total of 28 attention modules:
-    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
-    device_map = {
-        0: [0, 1, 2, 3, 4, 5, 6],
-        1: [7, 8, 9, 10, 11, 12, 13],
-        2: [14, 15, 16, 17, 18, 19, 20],
-        3: [21, 22, 23, 24, 25, 26, 27],
-    }
-    model.parallelize(device_map)
-    ```
-"""
-
-DEPARALLELIZE_DOCSTRING = r"""
-    Moves the model to CPU from a model parallel state.
-
-    Example:
-
-    ```python
-    # On a 4 GPU machine with gpt-j-6B:
-    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
-    device_map = {
-        0: [0, 1, 2, 3, 4, 5, 6],
-        1: [7, 8, 9, 10, 11, 12, 13],
-        2: [14, 15, 16, 17, 18, 19, 20],
-        3: [21, 22, 23, 24, 25, 26, 27],
-    }
-    model.parallelize(device_map)  # Splits the model across several devices
-    model.deparallelize()  # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
-    ```
-"""
-
 
 @add_start_docstrings(
     "The bare GPT-J Model transformer outputting raw hidden-states without any specific head on top.",
@@ -517,62 +468,11 @@ class GPTJModel(GPTJPreTrainedModel):
         self.h = nn.ModuleList([GPTJBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
-    def parallelize(self, device_map=None):
-        warnings.warn(
-            "`GPTJModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
-            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
-            " ...}",
-            FutureWarning,
-        )
-        # Check validity of device_map
-        self.device_map = (
-            get_device_map(len(self.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.h))
-        self.model_parallel = True
-        self.first_device = (
-            "cpu"
-            if "cpu" in self.device_map.keys()
-            else "cuda:" + str(min(self.device_map.keys()))
-        )
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
-        self.wte = self.wte.to(self.first_device)
-        # Load onto devices
-        for k, v in self.device_map.items():
-            for block in v:
-                cuda_device = "cuda:" + str(k)
-                self.h[block] = self.h[block].to(cuda_device)
-        # ln_f to last
-        self.ln_f = self.ln_f.to(self.last_device)
-
-    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-    def deparallelize(self):
-        warnings.warn(
-            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-            FutureWarning,
-        )
-        self.model_parallel = False
-        self.device_map = None
-        self.first_device = "cpu"
-        self.last_device = "cpu"
-        self.wte = self.wte.to("cpu")
-        for index in range(len(self.h)):
-            self.h[index] = self.h[index].to("cpu")
-        self.ln_f = self.ln_f.to("cpu")
-        torch.cuda.empty_cache()
-
     def get_input_embeddings(self):
         return self.wte
 
@@ -702,19 +602,16 @@ class GPTJModel(GPTJPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure layer_past is on same device as hidden_states (might not be correct)
-                if layer_past is not None:
-                    layer_past = tuple(
-                        past_state.to(hidden_states.device) for past_state in layer_past
-                    )
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if isinstance(head_mask, torch.Tensor):
-                    head_mask = head_mask.to(hidden_states.device)
+            # Ensure layer_past is on same device as hidden_states (might not be correct)
+            if layer_past is not None:
+                layer_past = tuple(
+                    past_state.to(hidden_states.device) for past_state in layer_past
+                )
+            # Ensure that attention_mask is always on the same device as hidden_states
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(hidden_states.device)
+            if isinstance(head_mask, torch.Tensor):
+                head_mask = head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -749,12 +646,6 @@
                     outputs[2 if use_cache else 1],
                 )
 
-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
         hidden_states = self.ln_f(hidden_states)
 
         hidden_states = hidden_states.view(output_shape)
@@ -796,44 +687,9 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         self.transformer = GPTJModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
         # Initialize weights and apply final processing
         self.post_init()
 
-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
-    def parallelize(self, device_map=None):
-        warnings.warn(
-            "`GPTJForCausalLM.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
-            " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
-            " 0, 'transformer.h.1': 1, ...}",
-            FutureWarning,
-        )
-        self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.transformer.h))
-        self.transformer.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.transformer.first_device)
-        self.model_parallel = True
-
-    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-    def deparallelize(self):
-        warnings.warn(
-            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-            FutureWarning,
-        )
-        self.transformer.deparallelize()
-        self.transformer = self.transformer.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
-        self.model_parallel = False
-        torch.cuda.empty_cache()
-
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -937,9 +793,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
+        hidden_states = hidden_states.to(self.lm_head.weight.device)
 
         # make sure sampling in fp16 works correctly and
         # compute loss in fp32 to match with mesh-tf version
@@ -1013,10 +867,6 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         self.transformer = GPTJModel(config)
         self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
        # Initialize weights and apply final processing
         self.post_init()
 
@@ -1154,10 +1004,6 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
         self.transformer = GPTJModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
         # Initialize weights and apply final processing
         self.post_init()
 
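The `parallelize()` / `deparallelize()` helpers deleted above were already deprecated; their own warning messages point to `device_map` in `from_pretrained` as the replacement. A hedged sketch of that replacement path (requires `accelerate`; the repo id mirrors `test.py`, and the dtype choice is only illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

# Instead of model.parallelize(device_map), let from_pretrained place the modules.
# "balanced" spreads the transformer blocks evenly over the visible GPUs; an explicit
# mapping such as {"transformer.h.0": 0, "transformer.h.1": 1, ...} is also accepted.
model = AutoModelForCausalLM.from_pretrained(
    "d-matrix/gpt-j-6b",
    device_map="balanced",
    torch_dtype=torch.float16,  # illustrative: fp16 weights so the 6B model fits on smaller GPUs
)

# There is no direct deparallelize() counterpart: to get a CPU-only copy, reload without a
# device_map (or with device_map={"": "cpu"}) and call torch.cuda.empty_cache() afterwards.
```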
 
test.py ADDED
@@ -0,0 +1,13 @@
+import evaluate
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM
+
+perplexity = evaluate.load("d-matrix/dmx_perplexity", module_type="metric")
+input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10]
+model = AutoModelForCausalLM.from_pretrained(
+    pretrained_model_name_or_path="d-matrix/gpt-j-6b",
+    # trust_remote_code=True,
+    device_map="auto",
+    use_auth_token=True,
+)
+results = perplexity.compute(model=model, references=input_texts)
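For context, `test.py` scores the checkpoint with the `d-matrix/dmx_perplexity` metric from the `evaluate` hub, so `results` is whatever dictionary that metric's `compute()` returns. The actual implementation lives in the metric repository; the sketch below is only the standard causal-LM perplexity computation it is presumably built on (exp of the mean per-token negative log-likelihood), with the tokenizer name and truncation length as illustrative assumptions:

```python
# Illustrative sketch of causal-LM perplexity (not the dmx_perplexity implementation):
# perplexity = exp(mean negative log-likelihood per predicted token).
import math

import torch
from transformers import AutoTokenizer


def perplexity_of(model, texts, tokenizer_name="EleutherAI/gpt-j-6B", max_length=1024):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)  # assumption: GPT-J's own tokenizer
    device = next(model.parameters()).device
    nll_sum, n_tokens = 0.0, 0
    for text in texts:
        if not text.strip():
            continue  # wikitext contains empty lines
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        input_ids = enc["input_ids"].to(device)
        if input_ids.size(1) < 2:
            continue  # need at least one predicted token
        with torch.no_grad():
            # With labels=input_ids the model returns the mean cross-entropy over the shifted
            # tokens; scale back to a sum so longer texts contribute proportionally.
            loss = model(input_ids, labels=input_ids).loss
        n_pred = input_ids.size(1) - 1
        nll_sum += loss.item() * n_pred
        n_tokens += n_pred
    return math.exp(nll_sum / n_tokens)
```

With the script's objects, `perplexity_of(model, input_texts)` should land in the same ballpark as `results`, assuming the metric aggregates tokens in a similar way.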