sayehs committed • Commit c1efd44 • Parent(s): 94a9749

WIP

Files changed:
- config.json +0 -4
- modeling_gptj.py +11 -165
- test.py +13 -0
config.json CHANGED
@@ -3,10 +3,6 @@
   "architectures": [
     "GPTJForCausalLM"
   ],
-  "auto_map": {
-    "AutoConfig": "configuration_gptj.GPTJConfig",
-    "AutoModelForCausalLM": "modeling_gptj.GPTJForCausalLM"
-  },
   "attn_pdrop": 0.0,
   "bos_token_id": 50256,
   "embd_pdrop": 0.0,
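The removed `auto_map` block is what pointed the `Auto*` loaders at this repo's custom `configuration_gptj.py` / `modeling_gptj.py` when `trust_remote_code=True` was passed. A minimal sketch of the effect (not part of the commit; the repo id and auth flag are taken from `test.py` below, and access to the repo is assumed):

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Without "auto_map" in config.json, the Auto classes fall back to the stock
# transformers GPT-J implementation instead of this repo's modeling_gptj.py.
config = AutoConfig.from_pretrained("d-matrix/gpt-j-6b", use_auth_token=True)
print(getattr(config, "auto_map", None))  # expected to be None after this change

# trust_remote_code=True is only needed while the repo still advertises custom
# classes via "auto_map"; test.py below accordingly leaves it commented out.
model = AutoModelForCausalLM.from_pretrained("d-matrix/gpt-j-6b", use_auth_token=True)
```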
modeling_gptj.py CHANGED
@@ -38,7 +38,6 @@ from transformers.utils import (
     is_torch_fx_proxy,
     logging,
 )
-from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
 from .configuration_gptj import GPTJConfig
 
 
@@ -453,54 +452,6 @@ GPTJ_INPUTS_DOCSTRING = r"""
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """
 
-PARALLELIZE_DOCSTRING = r"""
-    This is an experimental feature and is a subject to change at a moment's notice. Uses a device map to distribute
-    attention modules of the model across several devices. If no device map is given, it will evenly distribute blocks
-    across all devices.
-
-    Args:
-        device_map (`Dict[int, list]`, optional, defaults to None):
-            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
-            automatically mapped to the first device (for esoteric reasons). That means that the first device should
-            have fewer attention modules mapped to it than other devices. For reference, the GPT-J models have the
-            following number of attention modules:
-
-                - gpt-j-6B: 28
-
-    Example:
-
-    ```python
-    # Here is an example of a device map on a machine with 4 GPUs using gpt-j-6B, which has a total of 28 attention modules:
-    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
-    device_map = {
-        0: [0, 1, 2, 3, 4, 5, 6],
-        1: [7, 8, 9, 10, 11, 12, 13],
-        2: [14, 15, 16, 17, 18, 19, 20],
-        3: [21, 22, 23, 24, 25, 26, 27],
-    }
-    model.parallelize(device_map)
-    ```
-"""
-
-DEPARALLELIZE_DOCSTRING = r"""
-    Moves the model to CPU from a model parallel state.
-
-    Example:
-
-    ```python
-    # On a 4 GPU machine with gpt-j-6B:
-    model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
-    device_map = {
-        0: [0, 1, 2, 3, 4, 5, 6],
-        1: [7, 8, 9, 10, 11, 12, 13],
-        2: [14, 15, 16, 17, 18, 19, 20],
-        3: [21, 22, 23, 24, 25, 26, 27],
-    }
-    model.parallelize(device_map) # Splits the model across several devices
-    model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
-    ```
-"""
-
 
 @add_start_docstrings(
     "The bare GPT-J Model transformer outputting raw hidden-states without any specific head on top.",
@@ -517,62 +468,11 @@ class GPTJModel(GPTJPreTrainedModel):
         self.h = nn.ModuleList([GPTJBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
-    def parallelize(self, device_map=None):
-        warnings.warn(
-            "`GPTJModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
-            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
-            " ...}",
-            FutureWarning,
-        )
-        # Check validity of device_map
-        self.device_map = (
-            get_device_map(len(self.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.h))
-        self.model_parallel = True
-        self.first_device = (
-            "cpu"
-            if "cpu" in self.device_map.keys()
-            else "cuda:" + str(min(self.device_map.keys()))
-        )
-        self.last_device = "cuda:" + str(max(self.device_map.keys()))
-        self.wte = self.wte.to(self.first_device)
-        # Load onto devices
-        for k, v in self.device_map.items():
-            for block in v:
-                cuda_device = "cuda:" + str(k)
-                self.h[block] = self.h[block].to(cuda_device)
-        # ln_f to last
-        self.ln_f = self.ln_f.to(self.last_device)
-
-    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-    def deparallelize(self):
-        warnings.warn(
-            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-            FutureWarning,
-        )
-        self.model_parallel = False
-        self.device_map = None
-        self.first_device = "cpu"
-        self.last_device = "cpu"
-        self.wte = self.wte.to("cpu")
-        for index in range(len(self.h)):
-            self.h[index] = self.h[index].to("cpu")
-        self.ln_f = self.ln_f.to("cpu")
-        torch.cuda.empty_cache()
-
     def get_input_embeddings(self):
         return self.wte
 
@@ -702,19 +602,16 @@ class GPTJModel(GPTJPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure layer_past is on same device as hidden_states (might not be correct)
-                if layer_past is not None:
-                    layer_past = tuple(
-                        past_state.to(hidden_states.device) for past_state in layer_past
-                    )
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if isinstance(head_mask, torch.Tensor):
-                    head_mask = head_mask.to(hidden_states.device)
+            # Ensure layer_past is on same device as hidden_states (might not be correct)
+            if layer_past is not None:
+                layer_past = tuple(
+                    past_state.to(hidden_states.device) for past_state in layer_past
+                )
+            # Ensure that attention_mask is always on the same device as hidden_states
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(hidden_states.device)
+            if isinstance(head_mask, torch.Tensor):
+                head_mask = head_mask.to(hidden_states.device)
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -749,12 +646,6 @@ class GPTJModel(GPTJPreTrainedModel):
                     outputs[2 if use_cache else 1],
                 )
 
-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
         hidden_states = self.ln_f(hidden_states)
 
         hidden_states = hidden_states.view(output_shape)
@@ -796,44 +687,9 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         self.transformer = GPTJModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
         # Initialize weights and apply final processing
         self.post_init()
 
-    @add_start_docstrings(PARALLELIZE_DOCSTRING)
-    def parallelize(self, device_map=None):
-        warnings.warn(
-            "`GPTJForCausalLM.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
-            " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
-            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
-            " 0, 'transformer.h.1': 1, ...}",
-            FutureWarning,
-        )
-        self.device_map = (
-            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
-            if device_map is None
-            else device_map
-        )
-        assert_device_map(self.device_map, len(self.transformer.h))
-        self.transformer.parallelize(self.device_map)
-        self.lm_head = self.lm_head.to(self.transformer.first_device)
-        self.model_parallel = True
-
-    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
-    def deparallelize(self):
-        warnings.warn(
-            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
-            FutureWarning,
-        )
-        self.transformer.deparallelize()
-        self.transformer = self.transformer.to("cpu")
-        self.lm_head = self.lm_head.to("cpu")
-        self.model_parallel = False
-        torch.cuda.empty_cache()
-
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -937,9 +793,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         # Set device for model parallelism
-        if self.model_parallel:
-            torch.cuda.set_device(self.transformer.first_device)
-            hidden_states = hidden_states.to(self.lm_head.weight.device)
+        hidden_states = hidden_states.to(self.lm_head.weight.device)
 
         # make sure sampling in fp16 works correctly and
         # compute loss in fp32 to match with mesh-tf version
@@ -1013,10 +867,6 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         self.transformer = GPTJModel(config)
         self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1154,10 +1004,6 @@ class GPTJForQuestionAnswering(GPTJPreTrainedModel):
         self.transformer = GPTJModel(config)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
-
         # Initialize weights and apply final processing
         self.post_init()
 
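Taken together, the modeling_gptj.py changes drop the deprecated `parallelize()`/`deparallelize()` model-parallel path and keep only unconditional per-layer device moves in `forward`. The deprecation warnings in the deleted code already name the replacement: pass a `device_map` to `from_pretrained` and let accelerate place the modules, which is what `test.py` below does with `device_map="auto"`. A rough sketch of that replacement flow (not part of the commit), assuming the `d-matrix/gpt-j-6b` repo id from `test.py`, repo access, at least two visible GPUs, and `accelerate` installed:

```python
from transformers import AutoModelForCausalLM

# Automatic placement across the available devices (what test.py uses):
model = AutoModelForCausalLM.from_pretrained(
    "d-matrix/gpt-j-6b", device_map="auto", use_auth_token=True
)

# Or a hand-written module_name -> device dict, the format the removed
# deprecation warning describes ({'transformer.h.0': 0, 'transformer.h.1': 1, ...}).
# GPT-J 6B has 28 transformer blocks (per the deleted PARALLELIZE_DOCSTRING);
# this illustrative split puts the first half on GPU 0 and the rest on GPU 1.
manual_map = {"transformer.wte": 0, "transformer.ln_f": 1, "lm_head": 1}
manual_map.update({f"transformer.h.{i}": 0 if i < 14 else 1 for i in range(28)})
model = AutoModelForCausalLM.from_pretrained(
    "d-matrix/gpt-j-6b", device_map=manual_map, use_auth_token=True
)
```

With accelerate-style placement there is no `self.model_parallel` flag to key off, which is why the forward loop now moves `layer_past`, `attention_mask`, and `head_mask` to `hidden_states.device` unconditionally.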
test.py ADDED
@@ -0,0 +1,13 @@
+import evaluate
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM
+
+perplexity = evaluate.load("d-matrix/dmx_perplexity", module_type="metric")
+input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10]
+model = AutoModelForCausalLM.from_pretrained(
+    pretrained_model_name_or_path="d-matrix/gpt-j-6b",
+    # trust_remote_code=True,
+    device_map="auto",
+    use_auth_token=True,
+)
+results = perplexity.compute(model=model, references=input_texts)
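The new script scores the first ten WikiText-2 test lines with the `d-matrix/dmx_perplexity` evaluate module. As an independent sanity check (not part of the commit, and assuming the repo also ships the standard GPT-J tokenizer), perplexity can be computed directly from the model's causal-LM loss:

```python
import math

import torch
from transformers import AutoTokenizer

# Tokenizer repo is an assumption; EleutherAI/gpt-j-6B would be the usual fallback.
tokenizer = AutoTokenizer.from_pretrained("d-matrix/gpt-j-6b", use_auth_token=True)


@torch.no_grad()
def manual_perplexity(model, texts, max_length=1024):
    nll, n_tokens = 0.0, 0
    for text in texts:
        if not text.strip():
            continue  # WikiText contains many empty lines; skip them
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
        input_ids = enc.input_ids.to(model.device)
        # With labels == input_ids, the model returns the mean next-token cross-entropy.
        loss = model(input_ids, labels=input_ids).loss
        n = input_ids.shape[1] - 1  # number of predicted tokens in this sequence
        nll += loss.item() * n
        n_tokens += n
    return math.exp(nll / n_tokens)


# model and input_texts come from test.py above.
print(manual_perplexity(model, input_texts))
```

The result will not match `dmx_perplexity` exactly unless that metric uses the same tokenization, truncation, and token-weighted averaging, so treat this only as a rough cross-check.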